diff options
Diffstat (limited to 'src/main/kotlin')
-rw-r--r-- | src/main/kotlin/MailServer.kt | 30 | ||||
-rw-r--r-- | src/main/kotlin/Main.kt | 24 | ||||
-rw-r--r-- | src/main/kotlin/Protocol.kt | 157 | ||||
-rw-r--r-- | src/main/kotlin/RFC822Parser.kt | 58 | ||||
-rw-r--r-- | src/main/kotlin/SMTPProtocol.kt | 144 |
5 files changed, 259 insertions, 154 deletions
diff --git a/src/main/kotlin/MailServer.kt b/src/main/kotlin/MailServer.kt new file mode 100644 index 0000000..94a53c9 --- /dev/null +++ b/src/main/kotlin/MailServer.kt @@ -0,0 +1,30 @@ +import kotlinx.coroutines.* +import java.net.ServerSocket +import java.net.Socket +import kotlin.coroutines.EmptyCoroutineContext + +class MailServer( + val localhostName: String, + val scope: CoroutineScope = CoroutineScope(EmptyCoroutineContext) +) { + + fun createAndLaunchHandlerFor(socket: Socket): Job { + val protocol = SMTPReceiveProtocol(localhostName, socket.inetAddress) + return protocol.executeAsync(scope + CoroutineName("connection handler from ${socket.inetAddress}"), Protocol.IO.FromSocket(socket)) + } + + suspend fun createServer(port: Int) { + listenToServerSocket(ServerSocket(port)) + } + + suspend fun listenToServerSocket(serverSocket: ServerSocket) { + withContext(Dispatchers.Unconfined) { + while (true) { + val newIncomingConnection = + withContext(Dispatchers.IO) { serverSocket.accept() } + createAndLaunchHandlerFor(newIncomingConnection) + } + } + } + +} diff --git a/src/main/kotlin/Main.kt b/src/main/kotlin/Main.kt index 577c3fb..7f19f90 100644 --- a/src/main/kotlin/Main.kt +++ b/src/main/kotlin/Main.kt @@ -1,31 +1,21 @@ -import kotlinx.coroutines.CoroutineScope -import kotlinx.coroutines.Dispatchers -import kotlinx.coroutines.Job import kotlinx.coroutines.runBlocking -import java.net.ServerSocket +import kotlin.system.exitProcess object Main { @JvmStatic fun main(args: Array<String>) { if (args.size != 1) { - System.err.println("Use ./javamailteste run/install") + System.err.println("Use ./javamailteste run") + exitProcess(1) } when (args[0]) { - "run" -> runServer(2500) + "run" -> runServer(args.getOrElse(1) { "2500" }.toInt()) } } - fun runServer(port: Int) = runBlocking(Dispatchers.Default) { - val ss = ServerSocket(port) - val jobs = mutableListOf<Job>() - println("Starting SMTP socket on port $port") - while (true) { - val scope = CoroutineScope(Dispatchers.Default) - val socket = with(Dispatchers.IO) { ss.accept() } - val prot = SMTPReceiveProtocol("nea89.moe", socket.inetAddress) - jobs.add(prot.executeAsync(scope, Protocol.IO.FromSocket(socket))) - println("jobs: $jobs") - } + fun runServer(port: Int) { + val mailServer = MailServer("nea89.moe") + runBlocking { mailServer.createServer(port) } } } diff --git a/src/main/kotlin/Protocol.kt b/src/main/kotlin/Protocol.kt new file mode 100644 index 0000000..019e172 --- /dev/null +++ b/src/main/kotlin/Protocol.kt @@ -0,0 +1,157 @@ +import kotlinx.coroutines.* +import java.io.InputStream +import java.io.OutputStream +import java.net.Socket + +class Invalidatable { + var isInvalid: Boolean = false + fun checkValid() { + if (isInvalid) + throw IllegalStateException("Accessed invalid object") + } + + fun invalidate() { + isInvalid = true + } +} + +abstract class Protocol { + interface IO { + fun isOpen(): Boolean + suspend fun pushBack(data: ByteArray) + suspend fun readBytes(into: ByteArray): Int + suspend fun send(bytes: ByteArray) + suspend fun close() + + class FromSocket(val socket: Socket) : FromStreams(socket.getInputStream(), socket.getOutputStream()) { + override suspend fun close() { + super.close() + withContext(Dispatchers.IO) { + socket.close() + } + } + } + + open class FromStreams(val inputStream: InputStream, val outputStream: OutputStream) : IO { + private val i = Invalidatable() + override fun isOpen(): Boolean = + !i.isInvalid + + + val readBuffer = mutableListOf<ByteArray>() + override suspend fun pushBack(data: ByteArray) { + i.checkValid() + if (data.isEmpty()) return + readBuffer.add(0, data) + } + + override suspend fun send(bytes: ByteArray) { + i.checkValid() + withContext(Dispatchers.IO) { + outputStream.write(bytes) + outputStream.flush() + } + } + + override suspend fun close() { + i.checkValid() + i.invalidate() + withContext(Dispatchers.IO) { + inputStream.close() + outputStream.close() + } + } + + override suspend fun readBytes(into: ByteArray): Int { + i.checkValid() + val rb = readBuffer.removeFirstOrNull() + if (rb != null) { + val w = minOf(rb.size, into.size) + rb.copyInto(into, 0, 0, w) + return w + } + return withContext(Dispatchers.IO) { + inputStream.read(into) + } + } + } + } + + protected abstract suspend fun IO.execute() + + fun executeAsync(scope: CoroutineScope, io: Protocol.IO): Job { + return scope.launch { + io.execute() + } + } +} + +suspend fun Protocol.IO.readAll(): ByteArray { + var ret = ByteArray(0) + val buffer = ByteArray(4096) + while (true) { + val read = readBytes(buffer) + if (read == -1) { + return ret + } + val oldSize = ret.size + ret = ret.copyOf(oldSize + read) + buffer.copyInto(ret, oldSize, endIndex = read) + } +} + +suspend fun Protocol.IO.send(string: String) = send(string.encodeToByteArray()) +suspend fun Protocol.IO.readLine(): String = readUntil(CRLF).decodeToString() +suspend fun Protocol.IO.readUntil(search: ByteArray, errorOnEOF: Boolean = true): ByteArray { + var ret = ByteArray(0) + val buffer = ByteArray(4096) + while (true) { + val read = readBytes(buffer) + if (read == -1) { + if (errorOnEOF) { + throw IllegalStateException("End of Protocol.IO") + } else { + return ret + } + } + val oldSize = ret.size + ret = ret.copyOf(oldSize + read) + buffer.copyInto(ret, oldSize, endIndex = read) + val firstFoundIndex = ret.findSubarray(search, startIndex = (oldSize - search.size - 1).coerceAtLeast(0)) + if (firstFoundIndex != null) { + pushBack(ret.copyOfRange(firstFoundIndex + search.size, ret.size)) + return ret.copyOf(firstFoundIndex) + } + } +} + +val CRLF = "\r\n".encodeToByteArray() + +fun ByteArray.findSubarray(subarray: ByteArray, startIndex: Int = 0): Int? { + if (subarray.size > size - startIndex) return null + for (i in startIndex..size - subarray.size) { + var isEqual = true + for (j in subarray.indices) { + if (this[i + j] != subarray[j]) { + isEqual = false + break + } + } + if (isEqual) { + return i + } + } + return null +} + +suspend fun Protocol.IO.pushBack(string: String) = pushBack(string.encodeToByteArray()) +suspend fun Protocol.IO.lookahead(string: String): Boolean = lookahead(string.encodeToByteArray()) +suspend fun Protocol.IO.lookahead(bytes: ByteArray): Boolean { + val buffer = ByteArray(bytes.size) + val read = readBytes(buffer) + if (read != bytes.size || !buffer.contentEquals(bytes)) { + pushBack(buffer.copyOf(read)) + return false + } + return true +} diff --git a/src/main/kotlin/RFC822Parser.kt b/src/main/kotlin/RFC822Parser.kt new file mode 100644 index 0000000..4a6ce74 --- /dev/null +++ b/src/main/kotlin/RFC822Parser.kt @@ -0,0 +1,58 @@ +import kotlinx.coroutines.MainScope +import java.io.ByteArrayOutputStream +import java.io.File + +suspend fun main() { + val io = Protocol.IO.FromStreams(File("samplemail.txt").inputStream(), ByteArrayOutputStream()) + val rfc = RFC822Parser() + rfc.executeAsync(MainScope(), io).join() + println(rfc.headers) + println("Content: ${rfc.content.decodeToString()}") +} + +class RFC822Parser() : Protocol() { + + class Header(val field: String, val value: ByteArray) + + private val _headers = mutableListOf<Header>() + val headers: List<Header> get() = _headers + lateinit var content: ByteArray + private set + + override suspend fun IO.execute() { + while (parseField()) Unit + content = readAll() + } + + private suspend fun IO.parseField(): Boolean { + val read = readUntil(CRLF) + if (read.contentEquals(CRLF)) { + return false + } + val indexOfColon = read.indexOf(':'.code.toByte()) + if (indexOfColon == -1) { + throw IllegalStateException("Expected : in MIME header") + } + val headerField = read.sliceArray(0 until indexOfColon).decodeToString().trim() + var data = read.sliceArray(indexOfColon + 1 until read.size) + while (true) { + val nextLine = readUntil(CRLF) + if (nextLine.isNotEmpty() && isWhitespaceCharacter(nextLine[0])) { + val oldSize = data.size + data = data.copyOf(oldSize + nextLine.size) + nextLine.copyInto(data, oldSize) + } else { + pushBack(CRLF) + pushBack(nextLine) + break + } + } + _headers.add(Header(headerField, data)) + return true + } + + fun isWhitespaceCharacter(char: Byte): Boolean { + val char = char.toInt().toChar() + return char == ' ' || char == '\t' + } +} diff --git a/src/main/kotlin/SMTPProtocol.kt b/src/main/kotlin/SMTPProtocol.kt index 6f01bb2..e5da601 100644 --- a/src/main/kotlin/SMTPProtocol.kt +++ b/src/main/kotlin/SMTPProtocol.kt @@ -1,133 +1,8 @@ -import kotlinx.coroutines.CoroutineScope -import kotlinx.coroutines.Dispatchers -import kotlinx.coroutines.Job -import kotlinx.coroutines.launch -import java.io.InputStream -import java.io.OutputStream import java.net.InetAddress -import java.net.Socket import kotlin.contracts.ExperimentalContracts import kotlin.contracts.InvocationKind import kotlin.contracts.contract -class Invalidatable { - var isInvalid: Boolean = false - fun checkValid() { - if (isInvalid) - throw IllegalStateException("Accessed invalid object") - } - - fun invalidate() { - isInvalid = true - } -} - -abstract class Protocol { - interface IO { - fun isOpen(): Boolean - suspend fun pushBack(data: ByteArray) - suspend fun readBytes(into: ByteArray): Int - suspend fun send(bytes: ByteArray) - suspend fun close() - class FromSocket(val socket: Socket) : FromStreams(socket.getInputStream(), socket.getOutputStream()) { - override suspend fun close() { - super.close() - with(Dispatchers.IO) { - socket.close() - } - } - } - - open class FromStreams(val inputStream: InputStream, val outputStream: OutputStream) : IO { - private val i = Invalidatable() - override fun isOpen(): Boolean = - !i.isInvalid - - - val readBuffer = mutableListOf<ByteArray>() - override suspend fun pushBack(data: ByteArray) { - i.checkValid() - if (data.isEmpty()) return - readBuffer.add(0, data) - } - - override suspend fun send(bytes: ByteArray) { - i.checkValid() - with(Dispatchers.IO) { - outputStream.write(bytes) - outputStream.flush() - } - } - - override suspend fun close() { - i.checkValid() - i.invalidate() - with(Dispatchers.IO) { - inputStream.close() - outputStream.close() - } - } - - override suspend fun readBytes(into: ByteArray): Int { - i.checkValid() - val rb = readBuffer.removeFirstOrNull() - if (rb != null) { - val w = minOf(rb.size, into.size) - rb.copyInto(into, 0, 0, w) - return w - } - return with(Dispatchers.IO) { - inputStream.read(into) - } - } - } - } - - protected abstract suspend fun IO.execute() - - fun executeAsync(scope: CoroutineScope, io: Protocol.IO): Job { - return scope.launch { - io.execute() - } - } -} - -suspend fun Protocol.IO.send(string: String) = send(string.encodeToByteArray()) -suspend fun Protocol.IO.readLine(): String { - val y = mutableListOf<String>() - while (true) { - val buffer = ByteArray(4096) - val read = readBytes(buffer) - val i = buffer.findCRLF() - if (i in 0 until read) { - y.add(buffer.copyOfRange(0, i).decodeToString()) - pushBack(buffer.copyOfRange(i + 2, read)) - break - } else { - y.add(buffer.copyOfRange(0, read).decodeToString()) - } - } - return y.joinToString("") -} - -private fun ByteArray.findCRLF(): Int { - return this.asSequence().zipWithNext().withIndex().firstOrNull { (idx, v) -> - (v.first == '\r'.code.toByte()) and (v.second == '\n'.code.toByte()) - }?.index ?: -1 -} - -suspend fun Protocol.IO.pushBack(string: String) = pushBack(string.encodeToByteArray()) -suspend fun Protocol.IO.lookahead(string: String): Boolean = lookahead(string.encodeToByteArray()) -suspend fun Protocol.IO.lookahead(bytes: ByteArray): Boolean { - val buffer = ByteArray(bytes.size) - val read = readBytes(buffer) - if (read != bytes.size || !buffer.contentEquals(bytes)) { - pushBack(buffer.copyOf(read)) - return false - } - return true -} - @OptIn(ExperimentalContracts::class) class SMTPReceiveProtocol(val localHost: String, val inetAddress: InetAddress) : Protocol() { @@ -136,20 +11,20 @@ class SMTPReceiveProtocol(val localHost: String, val inetAddress: InetAddress) : var matched = false - suspend inline fun command(vararg name: String, block: IO.(String) -> Unit) { + suspend inline fun command(vararg name: String, block: suspend IO.(String) -> Unit) { contract { callsInPlace(block, InvocationKind.AT_MOST_ONCE) } for (n in name) commandOnce(n, block) } - suspend inline fun commandOnce(name: String, block: IO.(String) -> Unit) { + suspend inline fun commandOnce(name: String, block: suspend IO.(String) -> Unit) { contract { callsInPlace(block, InvocationKind.AT_MOST_ONCE) } if (matched) return - if (!line.startsWith(name)) return + if (!line.startsWith(name, ignoreCase = true)) return matched = true block(line.substring(name.length).trimStart()) } - suspend inline fun otherwise(block: IO.(String) -> Unit) { + suspend inline fun otherwise(block: suspend IO.(String) -> Unit) { contract { callsInPlace(block, InvocationKind.AT_MOST_ONCE) } if (matched) return matched = true @@ -157,7 +32,7 @@ class SMTPReceiveProtocol(val localHost: String, val inetAddress: InetAddress) : } } - suspend inline fun IO.commands(line: String, block: Commands.() -> Unit) { + suspend inline fun IO.commands(line: String, block: suspend Commands.() -> Unit) { contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } Commands(line, this).block() } @@ -202,13 +77,8 @@ class SMTPReceiveProtocol(val localHost: String, val inetAddress: InetAddress) : } command("DATA") { send("354 Enter mail, end with \".\" on a line by itself\r\n") - var text = "" - while (true) { - val tmp = readLine() - if (tmp == ".") break - text += tmp + "\n" - } - messages.add(Mail(trans.sender!!, trans.recipients.toList(), text)) + var text = readUntil("\r\n.\r\n".encodeToByteArray()) + messages.add(Mail(trans.sender!!, trans.recipients.toList(), text.decodeToString())) trans.reset() send("250 Message accepted for delivery\r\n") } |