summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/main/kotlin/MailServer.kt30
-rw-r--r--src/main/kotlin/Main.kt24
-rw-r--r--src/main/kotlin/Protocol.kt157
-rw-r--r--src/main/kotlin/RFC822Parser.kt58
-rw-r--r--src/main/kotlin/SMTPProtocol.kt144
-rw-r--r--src/test/kotlin/moe/nea89/mail/util/ProtocolSpec.kt41
6 files changed, 300 insertions, 154 deletions
diff --git a/src/main/kotlin/MailServer.kt b/src/main/kotlin/MailServer.kt
new file mode 100644
index 0000000..94a53c9
--- /dev/null
+++ b/src/main/kotlin/MailServer.kt
@@ -0,0 +1,30 @@
+import kotlinx.coroutines.*
+import java.net.ServerSocket
+import java.net.Socket
+import kotlin.coroutines.EmptyCoroutineContext
+
+class MailServer(
+ val localhostName: String,
+ val scope: CoroutineScope = CoroutineScope(EmptyCoroutineContext)
+) {
+
+ fun createAndLaunchHandlerFor(socket: Socket): Job {
+ val protocol = SMTPReceiveProtocol(localhostName, socket.inetAddress)
+ return protocol.executeAsync(scope + CoroutineName("connection handler from ${socket.inetAddress}"), Protocol.IO.FromSocket(socket))
+ }
+
+ suspend fun createServer(port: Int) {
+ listenToServerSocket(ServerSocket(port))
+ }
+
+ suspend fun listenToServerSocket(serverSocket: ServerSocket) {
+ withContext(Dispatchers.Unconfined) {
+ while (true) {
+ val newIncomingConnection =
+ withContext(Dispatchers.IO) { serverSocket.accept() }
+ createAndLaunchHandlerFor(newIncomingConnection)
+ }
+ }
+ }
+
+}
diff --git a/src/main/kotlin/Main.kt b/src/main/kotlin/Main.kt
index 577c3fb..7f19f90 100644
--- a/src/main/kotlin/Main.kt
+++ b/src/main/kotlin/Main.kt
@@ -1,31 +1,21 @@
-import kotlinx.coroutines.CoroutineScope
-import kotlinx.coroutines.Dispatchers
-import kotlinx.coroutines.Job
import kotlinx.coroutines.runBlocking
-import java.net.ServerSocket
+import kotlin.system.exitProcess
object Main {
@JvmStatic
fun main(args: Array<String>) {
if (args.size != 1) {
- System.err.println("Use ./javamailteste run/install")
+ System.err.println("Use ./javamailteste run")
+ exitProcess(1)
}
when (args[0]) {
- "run" -> runServer(2500)
+ "run" -> runServer(args.getOrElse(1) { "2500" }.toInt())
}
}
- fun runServer(port: Int) = runBlocking(Dispatchers.Default) {
- val ss = ServerSocket(port)
- val jobs = mutableListOf<Job>()
- println("Starting SMTP socket on port $port")
- while (true) {
- val scope = CoroutineScope(Dispatchers.Default)
- val socket = with(Dispatchers.IO) { ss.accept() }
- val prot = SMTPReceiveProtocol("nea89.moe", socket.inetAddress)
- jobs.add(prot.executeAsync(scope, Protocol.IO.FromSocket(socket)))
- println("jobs: $jobs")
- }
+ fun runServer(port: Int) {
+ val mailServer = MailServer("nea89.moe")
+ runBlocking { mailServer.createServer(port) }
}
}
diff --git a/src/main/kotlin/Protocol.kt b/src/main/kotlin/Protocol.kt
new file mode 100644
index 0000000..019e172
--- /dev/null
+++ b/src/main/kotlin/Protocol.kt
@@ -0,0 +1,157 @@
+import kotlinx.coroutines.*
+import java.io.InputStream
+import java.io.OutputStream
+import java.net.Socket
+
+class Invalidatable {
+ var isInvalid: Boolean = false
+ fun checkValid() {
+ if (isInvalid)
+ throw IllegalStateException("Accessed invalid object")
+ }
+
+ fun invalidate() {
+ isInvalid = true
+ }
+}
+
+abstract class Protocol {
+ interface IO {
+ fun isOpen(): Boolean
+ suspend fun pushBack(data: ByteArray)
+ suspend fun readBytes(into: ByteArray): Int
+ suspend fun send(bytes: ByteArray)
+ suspend fun close()
+
+ class FromSocket(val socket: Socket) : FromStreams(socket.getInputStream(), socket.getOutputStream()) {
+ override suspend fun close() {
+ super.close()
+ withContext(Dispatchers.IO) {
+ socket.close()
+ }
+ }
+ }
+
+ open class FromStreams(val inputStream: InputStream, val outputStream: OutputStream) : IO {
+ private val i = Invalidatable()
+ override fun isOpen(): Boolean =
+ !i.isInvalid
+
+
+ val readBuffer = mutableListOf<ByteArray>()
+ override suspend fun pushBack(data: ByteArray) {
+ i.checkValid()
+ if (data.isEmpty()) return
+ readBuffer.add(0, data)
+ }
+
+ override suspend fun send(bytes: ByteArray) {
+ i.checkValid()
+ withContext(Dispatchers.IO) {
+ outputStream.write(bytes)
+ outputStream.flush()
+ }
+ }
+
+ override suspend fun close() {
+ i.checkValid()
+ i.invalidate()
+ withContext(Dispatchers.IO) {
+ inputStream.close()
+ outputStream.close()
+ }
+ }
+
+ override suspend fun readBytes(into: ByteArray): Int {
+ i.checkValid()
+ val rb = readBuffer.removeFirstOrNull()
+ if (rb != null) {
+ val w = minOf(rb.size, into.size)
+ rb.copyInto(into, 0, 0, w)
+ return w
+ }
+ return withContext(Dispatchers.IO) {
+ inputStream.read(into)
+ }
+ }
+ }
+ }
+
+ protected abstract suspend fun IO.execute()
+
+ fun executeAsync(scope: CoroutineScope, io: Protocol.IO): Job {
+ return scope.launch {
+ io.execute()
+ }
+ }
+}
+
+suspend fun Protocol.IO.readAll(): ByteArray {
+ var ret = ByteArray(0)
+ val buffer = ByteArray(4096)
+ while (true) {
+ val read = readBytes(buffer)
+ if (read == -1) {
+ return ret
+ }
+ val oldSize = ret.size
+ ret = ret.copyOf(oldSize + read)
+ buffer.copyInto(ret, oldSize, endIndex = read)
+ }
+}
+
+suspend fun Protocol.IO.send(string: String) = send(string.encodeToByteArray())
+suspend fun Protocol.IO.readLine(): String = readUntil(CRLF).decodeToString()
+suspend fun Protocol.IO.readUntil(search: ByteArray, errorOnEOF: Boolean = true): ByteArray {
+ var ret = ByteArray(0)
+ val buffer = ByteArray(4096)
+ while (true) {
+ val read = readBytes(buffer)
+ if (read == -1) {
+ if (errorOnEOF) {
+ throw IllegalStateException("End of Protocol.IO")
+ } else {
+ return ret
+ }
+ }
+ val oldSize = ret.size
+ ret = ret.copyOf(oldSize + read)
+ buffer.copyInto(ret, oldSize, endIndex = read)
+ val firstFoundIndex = ret.findSubarray(search, startIndex = (oldSize - search.size - 1).coerceAtLeast(0))
+ if (firstFoundIndex != null) {
+ pushBack(ret.copyOfRange(firstFoundIndex + search.size, ret.size))
+ return ret.copyOf(firstFoundIndex)
+ }
+ }
+}
+
+val CRLF = "\r\n".encodeToByteArray()
+
+fun ByteArray.findSubarray(subarray: ByteArray, startIndex: Int = 0): Int? {
+ if (subarray.size > size - startIndex) return null
+ for (i in startIndex..size - subarray.size) {
+ var isEqual = true
+ for (j in subarray.indices) {
+ if (this[i + j] != subarray[j]) {
+ isEqual = false
+ break
+ }
+ }
+ if (isEqual) {
+ return i
+ }
+ }
+ return null
+}
+
+suspend fun Protocol.IO.pushBack(string: String) = pushBack(string.encodeToByteArray())
+suspend fun Protocol.IO.lookahead(string: String): Boolean = lookahead(string.encodeToByteArray())
+suspend fun Protocol.IO.lookahead(bytes: ByteArray): Boolean {
+ val buffer = ByteArray(bytes.size)
+ val read = readBytes(buffer)
+ if (read != bytes.size || !buffer.contentEquals(bytes)) {
+ pushBack(buffer.copyOf(read))
+ return false
+ }
+ return true
+}
diff --git a/src/main/kotlin/RFC822Parser.kt b/src/main/kotlin/RFC822Parser.kt
new file mode 100644
index 0000000..4a6ce74
--- /dev/null
+++ b/src/main/kotlin/RFC822Parser.kt
@@ -0,0 +1,58 @@
+import kotlinx.coroutines.MainScope
+import java.io.ByteArrayOutputStream
+import java.io.File
+
+suspend fun main() {
+ val io = Protocol.IO.FromStreams(File("samplemail.txt").inputStream(), ByteArrayOutputStream())
+ val rfc = RFC822Parser()
+ rfc.executeAsync(MainScope(), io).join()
+ println(rfc.headers)
+ println("Content: ${rfc.content.decodeToString()}")
+}
+
+class RFC822Parser() : Protocol() {
+
+ class Header(val field: String, val value: ByteArray)
+
+ private val _headers = mutableListOf<Header>()
+ val headers: List<Header> get() = _headers
+ lateinit var content: ByteArray
+ private set
+
+ override suspend fun IO.execute() {
+ while (parseField()) Unit
+ content = readAll()
+ }
+
+ private suspend fun IO.parseField(): Boolean {
+ val read = readUntil(CRLF)
+ if (read.contentEquals(CRLF)) {
+ return false
+ }
+ val indexOfColon = read.indexOf(':'.code.toByte())
+ if (indexOfColon == -1) {
+ throw IllegalStateException("Expected : in MIME header")
+ }
+ val headerField = read.sliceArray(0 until indexOfColon).decodeToString().trim()
+ var data = read.sliceArray(indexOfColon + 1 until read.size)
+ while (true) {
+ val nextLine = readUntil(CRLF)
+ if (nextLine.isNotEmpty() && isWhitespaceCharacter(nextLine[0])) {
+ val oldSize = data.size
+ data = data.copyOf(oldSize + nextLine.size)
+ nextLine.copyInto(data, oldSize)
+ } else {
+ pushBack(CRLF)
+ pushBack(nextLine)
+ break
+ }
+ }
+ _headers.add(Header(headerField, data))
+ return true
+ }
+
+ fun isWhitespaceCharacter(char: Byte): Boolean {
+ val char = char.toInt().toChar()
+ return char == ' ' || char == '\t'
+ }
+}
diff --git a/src/main/kotlin/SMTPProtocol.kt b/src/main/kotlin/SMTPProtocol.kt
index 6f01bb2..e5da601 100644
--- a/src/main/kotlin/SMTPProtocol.kt
+++ b/src/main/kotlin/SMTPProtocol.kt
@@ -1,133 +1,8 @@
-import kotlinx.coroutines.CoroutineScope
-import kotlinx.coroutines.Dispatchers
-import kotlinx.coroutines.Job
-import kotlinx.coroutines.launch
-import java.io.InputStream
-import java.io.OutputStream
import java.net.InetAddress
-import java.net.Socket
import kotlin.contracts.ExperimentalContracts
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract
-class Invalidatable {
- var isInvalid: Boolean = false
- fun checkValid() {
- if (isInvalid)
- throw IllegalStateException("Accessed invalid object")
- }
-
- fun invalidate() {
- isInvalid = true
- }
-}
-
-abstract class Protocol {
- interface IO {
- fun isOpen(): Boolean
- suspend fun pushBack(data: ByteArray)
- suspend fun readBytes(into: ByteArray): Int
- suspend fun send(bytes: ByteArray)
- suspend fun close()
- class FromSocket(val socket: Socket) : FromStreams(socket.getInputStream(), socket.getOutputStream()) {
- override suspend fun close() {
- super.close()
- with(Dispatchers.IO) {
- socket.close()
- }
- }
- }
-
- open class FromStreams(val inputStream: InputStream, val outputStream: OutputStream) : IO {
- private val i = Invalidatable()
- override fun isOpen(): Boolean =
- !i.isInvalid
-
-
- val readBuffer = mutableListOf<ByteArray>()
- override suspend fun pushBack(data: ByteArray) {
- i.checkValid()
- if (data.isEmpty()) return
- readBuffer.add(0, data)
- }
-
- override suspend fun send(bytes: ByteArray) {
- i.checkValid()
- with(Dispatchers.IO) {
- outputStream.write(bytes)
- outputStream.flush()
- }
- }
-
- override suspend fun close() {
- i.checkValid()
- i.invalidate()
- with(Dispatchers.IO) {
- inputStream.close()
- outputStream.close()
- }
- }
-
- override suspend fun readBytes(into: ByteArray): Int {
- i.checkValid()
- val rb = readBuffer.removeFirstOrNull()
- if (rb != null) {
- val w = minOf(rb.size, into.size)
- rb.copyInto(into, 0, 0, w)
- return w
- }
- return with(Dispatchers.IO) {
- inputStream.read(into)
- }
- }
- }
- }
-
- protected abstract suspend fun IO.execute()
-
- fun executeAsync(scope: CoroutineScope, io: Protocol.IO): Job {
- return scope.launch {
- io.execute()
- }
- }
-}
-
-suspend fun Protocol.IO.send(string: String) = send(string.encodeToByteArray())
-suspend fun Protocol.IO.readLine(): String {
- val y = mutableListOf<String>()
- while (true) {
- val buffer = ByteArray(4096)
- val read = readBytes(buffer)
- val i = buffer.findCRLF()
- if (i in 0 until read) {
- y.add(buffer.copyOfRange(0, i).decodeToString())
- pushBack(buffer.copyOfRange(i + 2, read))
- break
- } else {
- y.add(buffer.copyOfRange(0, read).decodeToString())
- }
- }
- return y.joinToString("")
-}
-
-private fun ByteArray.findCRLF(): Int {
- return this.asSequence().zipWithNext().withIndex().firstOrNull { (idx, v) ->
- (v.first == '\r'.code.toByte()) and (v.second == '\n'.code.toByte())
- }?.index ?: -1
-}
-
-suspend fun Protocol.IO.pushBack(string: String) = pushBack(string.encodeToByteArray())
-suspend fun Protocol.IO.lookahead(string: String): Boolean = lookahead(string.encodeToByteArray())
-suspend fun Protocol.IO.lookahead(bytes: ByteArray): Boolean {
- val buffer = ByteArray(bytes.size)
- val read = readBytes(buffer)
- if (read != bytes.size || !buffer.contentEquals(bytes)) {
- pushBack(buffer.copyOf(read))
- return false
- }
- return true
-}
-
@OptIn(ExperimentalContracts::class)
class SMTPReceiveProtocol(val localHost: String, val inetAddress: InetAddress) : Protocol() {
@@ -136,20 +11,20 @@ class SMTPReceiveProtocol(val localHost: String, val inetAddress: InetAddress) :
var matched = false
- suspend inline fun command(vararg name: String, block: IO.(String) -> Unit) {
+ suspend inline fun command(vararg name: String, block: suspend IO.(String) -> Unit) {
contract { callsInPlace(block, InvocationKind.AT_MOST_ONCE) }
for (n in name) commandOnce(n, block)
}
- suspend inline fun commandOnce(name: String, block: IO.(String) -> Unit) {
+ suspend inline fun commandOnce(name: String, block: suspend IO.(String) -> Unit) {
contract { callsInPlace(block, InvocationKind.AT_MOST_ONCE) }
if (matched) return
- if (!line.startsWith(name)) return
+ if (!line.startsWith(name, ignoreCase = true)) return
matched = true
block(line.substring(name.length).trimStart())
}
- suspend inline fun otherwise(block: IO.(String) -> Unit) {
+ suspend inline fun otherwise(block: suspend IO.(String) -> Unit) {
contract { callsInPlace(block, InvocationKind.AT_MOST_ONCE) }
if (matched) return
matched = true
@@ -157,7 +32,7 @@ class SMTPReceiveProtocol(val localHost: String, val inetAddress: InetAddress) :
}
}
- suspend inline fun IO.commands(line: String, block: Commands.() -> Unit) {
+ suspend inline fun IO.commands(line: String, block: suspend Commands.() -> Unit) {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
Commands(line, this).block()
}
@@ -202,13 +77,8 @@ class SMTPReceiveProtocol(val localHost: String, val inetAddress: InetAddress) :
}
command("DATA") {
send("354 Enter mail, end with \".\" on a line by itself\r\n")
- var text = ""
- while (true) {
- val tmp = readLine()
- if (tmp == ".") break
- text += tmp + "\n"
- }
- messages.add(Mail(trans.sender!!, trans.recipients.toList(), text))
+ var text = readUntil("\r\n.\r\n".encodeToByteArray())
+ messages.add(Mail(trans.sender!!, trans.recipients.toList(), text.decodeToString()))
trans.reset()
send("250 Message accepted for delivery\r\n")
}
diff --git a/src/test/kotlin/moe/nea89/mail/util/ProtocolSpec.kt b/src/test/kotlin/moe/nea89/mail/util/ProtocolSpec.kt
new file mode 100644
index 0000000..9a4bcbf
--- /dev/null
+++ b/src/test/kotlin/moe/nea89/mail/util/ProtocolSpec.kt
@@ -0,0 +1,41 @@
+package moe.nea89.mail.util
+
+import Protocol
+import findSubarray
+import io.kotest.core.spec.style.FreeSpec
+import io.kotest.datatest.withData
+import readUntil
+import java.io.ByteArrayInputStream
+import java.io.ByteArrayOutputStream
+
+
+data class ByteArrayString(val hay: String, val needle: String, val position: Int?, val offset: Int = 0)
+class ProtocolSpec : FreeSpec({
+ "ByteArray.findSubarray" - {
+ "should work correctly" - {
+ withData(
+ listOf(
+ ByteArrayString("abc", "a", 0),
+ ByteArrayString("abca", "a", 0),
+ ByteArrayString("abca", "a", 3, 2),
+ ByteArrayString("abca", "a", 3, 3),
+ ByteArrayString("bc", "a", null),
+ ByteArrayString("acbcab", "ab", 4),
+ ByteArrayString("abcbcab", "ab", 5, 1),
+ )
+ ) { (hay, needle, position, offset) ->
+ val hay = hay.encodeToByteArray()
+ val needle = needle.encodeToByteArray()
+ assert(hay.findSubarray(needle, offset) == position)
+ }
+ }
+ }
+ "Protocol.IO" - {
+ "readUntil" {
+ val data = ("a".repeat(9) + "01" + "b".repeat(10) + "02").encodeToByteArray()
+ val protIO = Protocol.IO.FromStreams(ByteArrayInputStream(data), ByteArrayOutputStream())
+ assert(protIO.readUntil("01".encodeToByteArray()).decodeToString() == "a".repeat(9))
+ assert(protIO.readUntil("02".encodeToByteArray()).decodeToString() == "b".repeat(10))
+ }
+ }
+})