summaryrefslogtreecommitdiff
path: root/src/main/kotlin/Protocol.kt
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/kotlin/Protocol.kt')
-rw-r--r--src/main/kotlin/Protocol.kt157
1 files changed, 157 insertions, 0 deletions
diff --git a/src/main/kotlin/Protocol.kt b/src/main/kotlin/Protocol.kt
new file mode 100644
index 0000000..019e172
--- /dev/null
+++ b/src/main/kotlin/Protocol.kt
@@ -0,0 +1,157 @@
+import kotlinx.coroutines.*
+import java.io.InputStream
+import java.io.OutputStream
+import java.net.Socket
+
+class Invalidatable {
+ var isInvalid: Boolean = false
+ fun checkValid() {
+ if (isInvalid)
+ throw IllegalStateException("Accessed invalid object")
+ }
+
+ fun invalidate() {
+ isInvalid = true
+ }
+}
+
+abstract class Protocol {
+ interface IO {
+ fun isOpen(): Boolean
+ suspend fun pushBack(data: ByteArray)
+ suspend fun readBytes(into: ByteArray): Int
+ suspend fun send(bytes: ByteArray)
+ suspend fun close()
+
+ class FromSocket(val socket: Socket) : FromStreams(socket.getInputStream(), socket.getOutputStream()) {
+ override suspend fun close() {
+ super.close()
+ withContext(Dispatchers.IO) {
+ socket.close()
+ }
+ }
+ }
+
+ open class FromStreams(val inputStream: InputStream, val outputStream: OutputStream) : IO {
+ private val i = Invalidatable()
+ override fun isOpen(): Boolean =
+ !i.isInvalid
+
+
+ val readBuffer = mutableListOf<ByteArray>()
+ override suspend fun pushBack(data: ByteArray) {
+ i.checkValid()
+ if (data.isEmpty()) return
+ readBuffer.add(0, data)
+ }
+
+ override suspend fun send(bytes: ByteArray) {
+ i.checkValid()
+ withContext(Dispatchers.IO) {
+ outputStream.write(bytes)
+ outputStream.flush()
+ }
+ }
+
+ override suspend fun close() {
+ i.checkValid()
+ i.invalidate()
+ withContext(Dispatchers.IO) {
+ inputStream.close()
+ outputStream.close()
+ }
+ }
+
+ override suspend fun readBytes(into: ByteArray): Int {
+ i.checkValid()
+ val rb = readBuffer.removeFirstOrNull()
+ if (rb != null) {
+ val w = minOf(rb.size, into.size)
+ rb.copyInto(into, 0, 0, w)
+ return w
+ }
+ return withContext(Dispatchers.IO) {
+ inputStream.read(into)
+ }
+ }
+ }
+ }
+
+ protected abstract suspend fun IO.execute()
+
+ fun executeAsync(scope: CoroutineScope, io: Protocol.IO): Job {
+ return scope.launch {
+ io.execute()
+ }
+ }
+}
+
+suspend fun Protocol.IO.readAll(): ByteArray {
+ var ret = ByteArray(0)
+ val buffer = ByteArray(4096)
+ while (true) {
+ val read = readBytes(buffer)
+ if (read == -1) {
+ return ret
+ }
+ val oldSize = ret.size
+ ret = ret.copyOf(oldSize + read)
+ buffer.copyInto(ret, oldSize, endIndex = read)
+ }
+}
+
+suspend fun Protocol.IO.send(string: String) = send(string.encodeToByteArray())
+suspend fun Protocol.IO.readLine(): String = readUntil(CRLF).decodeToString()
+suspend fun Protocol.IO.readUntil(search: ByteArray, errorOnEOF: Boolean = true): ByteArray {
+ var ret = ByteArray(0)
+ val buffer = ByteArray(4096)
+ while (true) {
+ val read = readBytes(buffer)
+ if (read == -1) {
+ if (errorOnEOF) {
+ throw IllegalStateException("End of Protocol.IO")
+ } else {
+ return ret
+ }
+ }
+ val oldSize = ret.size
+ ret = ret.copyOf(oldSize + read)
+ buffer.copyInto(ret, oldSize, endIndex = read)
+ val firstFoundIndex = ret.findSubarray(search, startIndex = (oldSize - search.size - 1).coerceAtLeast(0))
+ if (firstFoundIndex != null) {
+ pushBack(ret.copyOfRange(firstFoundIndex + search.size, ret.size))
+ return ret.copyOf(firstFoundIndex)
+ }
+ }
+}
+
+val CRLF = "\r\n".encodeToByteArray()
+
+fun ByteArray.findSubarray(subarray: ByteArray, startIndex: Int = 0): Int? {
+ if (subarray.size > size - startIndex) return null
+ for (i in startIndex..size - subarray.size) {
+ var isEqual = true
+ for (j in subarray.indices) {
+ if (this[i + j] != subarray[j]) {
+ isEqual = false
+ break
+ }
+ }
+ if (isEqual) {
+ return i
+ }
+ }
+ return null
+}
+
+suspend fun Protocol.IO.pushBack(string: String) = pushBack(string.encodeToByteArray())
+suspend fun Protocol.IO.lookahead(string: String): Boolean = lookahead(string.encodeToByteArray())
+suspend fun Protocol.IO.lookahead(bytes: ByteArray): Boolean {
+ val buffer = ByteArray(bytes.size)
+ val read = readBytes(buffer)
+ if (read != bytes.size || !buffer.contentEquals(bytes)) {
+ pushBack(buffer.copyOf(read))
+ return false
+ }
+ return true
+}