From 9374d0df9c462493e9cb91c846e4f820b3325f7b Mon Sep 17 00:00:00 2001
From: Robert Jaros <rjaros@finn.pl>
Date: Wed, 10 Apr 2019 18:24:14 +0200
Subject: Websockets support for Spring Boot and Jooby.

---
 .../pl/treksoft/kvision/remote/KVRemoteAgent.kt    |   7 +-
 .../pl/treksoft/kvision/remote/KVServiceManager.kt |   2 +-
 .../kotlin/pl/treksoft/kvision/remote/KVModules.kt |   3 +
 .../pl/treksoft/kvision/remote/KVServiceManager.kt |  79 ++++++-
 .../kotlin/pl/treksoft/kvision/remote/KVModules.kt |  16 +-
 .../pl/treksoft/kvision/remote/KVServiceManager.kt |  29 ++-
 .../kvision-server-spring-boot/build.gradle        |   1 +
 .../pl/treksoft/kvision/remote/KVController.kt     |   9 +-
 .../pl/treksoft/kvision/remote/KVServiceManager.kt |  54 ++++-
 .../treksoft/kvision/remote/KVWebSocketConfig.kt   | 251 +++++++++++++++++++++
 .../kotlin/pl/treksoft/kvision/remote/Security.kt  |   3 +
 .../src/main/resources/META-INF/spring.factories   |   2 +-
 12 files changed, 434 insertions(+), 22 deletions(-)
 create mode 100644 kvision-modules/kvision-server-spring-boot/src/main/kotlin/pl/treksoft/kvision/remote/KVWebSocketConfig.kt

diff --git a/kvision-modules/kvision-remote/src/main/kotlin/pl/treksoft/kvision/remote/KVRemoteAgent.kt b/kvision-modules/kvision-remote/src/main/kotlin/pl/treksoft/kvision/remote/KVRemoteAgent.kt
index 9c6f8274..d3331b2a 100644
--- a/kvision-modules/kvision-remote/src/main/kotlin/pl/treksoft/kvision/remote/KVRemoteAgent.kt
+++ b/kvision-modules/kvision-remote/src/main/kotlin/pl/treksoft/kvision/remote/KVRemoteAgent.kt
@@ -21,6 +21,7 @@
  */
 package pl.treksoft.kvision.remote
 
+import kotlinx.coroutines.ExperimentalCoroutinesApi
 import kotlinx.coroutines.Job
 import kotlinx.coroutines.asDeferred
 import kotlinx.coroutines.channels.Channel
@@ -41,7 +42,7 @@ import kotlin.js.JSON as NativeJSON
  * Client side agent for JSON-RPC remote calls.
  */
 @Suppress("LargeClass", "TooManyFunctions")
-@UseExperimental(ImplicitReflectionSerializer::class)
+@UseExperimental(ImplicitReflectionSerializer::class, ExperimentalCoroutinesApi::class)
 open class KVRemoteAgent<T : Any>(val serviceManager: KVServiceManager<T>) : RemoteAgent {
 
     val callAgent = CallAgent()
@@ -407,7 +408,7 @@ open class KVRemoteAgent<T : Any>(val serviceManager: KVServiceManager<T>) : Rem
                 responseJob = launch {
                     while (true) {
                         val str = socket.receiveOrNull() ?: break
-                        val data = kotlin.js.JSON.parse<dynamic>(str).result
+                        val data = kotlin.js.JSON.parse<JsonRpcResponse>(str).result ?: ""
                         val par2 = try {
                             @Suppress("UNCHECKED_CAST")
                             deserialize<PAR2>(data, PAR2::class.js.name)
@@ -477,7 +478,7 @@ open class KVRemoteAgent<T : Any>(val serviceManager: KVServiceManager<T>) : Rem
                 responseJob = launch {
                     while (true) {
                         val str = socket.receiveOrNull() ?: break
-                        val data = kotlin.js.JSON.parse<dynamic>(str).result
+                        val data = kotlin.js.JSON.parse<JsonRpcResponse>(str).result ?: ""
                         val par2 = try {
                             deserializeList<PAR2>(data, PAR2::class.js.name)
                         } catch (t: NotStandardTypeException) {
diff --git a/kvision-modules/kvision-remote/src/main/kotlin/pl/treksoft/kvision/remote/KVServiceManager.kt b/kvision-modules/kvision-remote/src/main/kotlin/pl/treksoft/kvision/remote/KVServiceManager.kt
index 0cf72f76..7db690b2 100644
--- a/kvision-modules/kvision-remote/src/main/kotlin/pl/treksoft/kvision/remote/KVServiceManager.kt
+++ b/kvision-modules/kvision-remote/src/main/kotlin/pl/treksoft/kvision/remote/KVServiceManager.kt
@@ -149,7 +149,7 @@ actual open class KVServiceManager<T : Any> actual constructor(serviceClass: KCl
         route: String?
     ) {
         val routeDef = "route${this::class.simpleName}${counter++}"
-        calls[function.toString().replace("\\s".toRegex(), "")] = Pair("/kv/$routeDef", HttpMethod.POST)
+        calls[function.toString().replace("\\s".toRegex(), "")] = Pair("/kvws/$routeDef", HttpMethod.POST)
     }
 
     /**
diff --git a/kvision-modules/kvision-server-jooby/src/main/kotlin/pl/treksoft/kvision/remote/KVModules.kt b/kvision-modules/kvision-server-jooby/src/main/kotlin/pl/treksoft/kvision/remote/KVModules.kt
index 34b2877f..b3da785d 100644
--- a/kvision-modules/kvision-server-jooby/src/main/kotlin/pl/treksoft/kvision/remote/KVModules.kt
+++ b/kvision-modules/kvision-server-jooby/src/main/kotlin/pl/treksoft/kvision/remote/KVModules.kt
@@ -24,6 +24,9 @@ package pl.treksoft.kvision.remote
 import org.jooby.Kooby
 import org.jooby.json.Jackson
 
+/**
+ * Initialization function for Jooby server.
+ */
 fun Kooby.kvisionInit() {
     assets("/", "/assets/index.html")
     assets("/**", "/assets/{0}").onMissing(0)
diff --git a/kvision-modules/kvision-server-jooby/src/main/kotlin/pl/treksoft/kvision/remote/KVServiceManager.kt b/kvision-modules/kvision-server-jooby/src/main/kotlin/pl/treksoft/kvision/remote/KVServiceManager.kt
index cfd9d467..ca08b080 100644
--- a/kvision-modules/kvision-server-jooby/src/main/kotlin/pl/treksoft/kvision/remote/KVServiceManager.kt
+++ b/kvision-modules/kvision-server-jooby/src/main/kotlin/pl/treksoft/kvision/remote/KVServiceManager.kt
@@ -24,9 +24,15 @@ package pl.treksoft.kvision.remote
 import com.fasterxml.jackson.module.kotlin.jacksonObjectMapper
 import com.google.inject.Injector
 import kotlinx.coroutines.CoroutineStart
+import kotlinx.coroutines.Dispatchers
+import kotlinx.coroutines.ExperimentalCoroutinesApi
 import kotlinx.coroutines.GlobalScope
+import kotlinx.coroutines.channels.Channel
 import kotlinx.coroutines.channels.ReceiveChannel
 import kotlinx.coroutines.channels.SendChannel
+import kotlinx.coroutines.channels.filterNotNull
+import kotlinx.coroutines.channels.map
+import kotlinx.coroutines.coroutineScope
 import kotlinx.coroutines.launch
 import org.jooby.Kooby
 import org.jooby.Request
@@ -39,6 +45,7 @@ import kotlin.reflect.KClass
  * Multiplatform service manager for Jooby.
  */
 @Suppress("LargeClass")
+@UseExperimental(ExperimentalCoroutinesApi::class)
 actual open class KVServiceManager<T : Any> actual constructor(val serviceClass: KClass<T>) {
 
     companion object {
@@ -313,7 +320,7 @@ actual open class KVServiceManager<T : Any> actual constructor(val serviceClass:
     }
 
     /**
-     * Binds a given web socket connetion with a function of the receiver.
+     * Binds a given web socket connection with a function of the receiver.
      * @param function a function of the receiver
      * @param route a route
      */
@@ -321,10 +328,67 @@ actual open class KVServiceManager<T : Any> actual constructor(val serviceClass:
         noinline function: suspend T.(ReceiveChannel<PAR1>, SendChannel<PAR2>) -> Unit,
         route: String?
     ) {
-        TODO("Not implemented in Jooby module")
+        val routeDef = "route${this::class.simpleName}${counter++}"
+        routes.add {
+            ws("/kvws/$routeDef") { ws ->
+                val injector = ws.require(Injector::class.java)
+                val service = injector.getInstance(serviceClass.java)
+                val incoming = Channel<String>()
+                val outgoing = Channel<String>()
+                GlobalScope.launch {
+                    coroutineScope {
+                        launch(Dispatchers.IO) {
+                            for (text in outgoing) {
+                                ws.send(text)
+                            }
+                            ws.close()
+                        }
+                        launch {
+                            val requestChannel = incoming.map {
+                                val jsonRpcRequest = getParameter<JsonRpcRequest>(it)
+                                if (jsonRpcRequest.params.size == 1) {
+                                    getParameter<PAR1>(jsonRpcRequest.params[0])
+                                } else {
+                                    null
+                                }
+                            }.filterNotNull()
+                            val responseChannel = Channel<PAR2>()
+                            coroutineScope {
+                                launch(Dispatchers.IO) {
+                                    for (p in responseChannel) {
+                                        val text = mapper.writeValueAsString(
+                                            JsonRpcResponse(
+                                                id = 0,
+                                                result = mapper.writeValueAsString(p)
+                                            )
+                                        )
+                                        outgoing.send(text)
+                                    }
+                                }
+                                launch {
+                                    function.invoke(service, requestChannel, responseChannel)
+                                    if (!responseChannel.isClosedForReceive) responseChannel.close()
+                                }
+                            }
+                            if (!outgoing.isClosedForReceive) outgoing.close()
+                        }
+                    }
+                }
+                ws.onClose {
+                    GlobalScope.launch {
+                        outgoing.close()
+                        incoming.close()
+                    }
+                }
+                ws.onMessage { msg ->
+                    GlobalScope.launch {
+                        incoming.send(msg.value())
+                    }
+                }
+            }
+        }
     }
 
-
     /**
      * Binds a given function of the receiver as a select options source
      * @param function a function of the receiver
@@ -363,6 +427,9 @@ actual open class KVServiceManager<T : Any> actual constructor(val serviceClass:
         }
     }
 
+    /**
+     * @suppress Internal method
+     */
     fun call(
         method: HttpMethod,
         path: String,
@@ -379,6 +446,9 @@ actual open class KVServiceManager<T : Any> actual constructor(val serviceClass:
         }
     }
 
+    /**
+     * @suppress Internal method
+     */
     protected inline fun <reified T> getParameter(str: String?): T {
         return str?.let {
             if (T::class == String::class) {
@@ -391,6 +461,9 @@ actual open class KVServiceManager<T : Any> actual constructor(val serviceClass:
 
 }
 
+/**
+ * A function to generate routes based on definitions from the service manager.
+ */
 fun <T : Any> Kooby.applyRoutes(serviceManager: KVServiceManager<T>) {
     serviceManager.routes.forEach {
         it.invoke(this@applyRoutes)
diff --git a/kvision-modules/kvision-server-ktor/src/main/kotlin/pl/treksoft/kvision/remote/KVModules.kt b/kvision-modules/kvision-server-ktor/src/main/kotlin/pl/treksoft/kvision/remote/KVModules.kt
index a5e22b2d..34fff02d 100644
--- a/kvision-modules/kvision-server-ktor/src/main/kotlin/pl/treksoft/kvision/remote/KVModules.kt
+++ b/kvision-modules/kvision-server-ktor/src/main/kotlin/pl/treksoft/kvision/remote/KVModules.kt
@@ -45,6 +45,9 @@ import kotlinx.coroutines.channels.ReceiveChannel
 import kotlinx.coroutines.channels.SendChannel
 import kotlin.coroutines.CoroutineContext
 
+/**
+ * Initialization function for Ktor server.
+ */
 fun Application.kvisionInit(vararg modules: Module) {
     install(ContentNegotiation) {
         jackson()
@@ -69,18 +72,21 @@ val injectorKey = AttributeKey<Injector>("injector")
 
 val ApplicationCall.injector: Injector get() = attributes[injectorKey]
 
-class CallModule(private val call: ApplicationCall) : AbstractModule() {
+internal class CallModule(private val call: ApplicationCall) : AbstractModule() {
     override fun configure() {
         bind(ApplicationCall::class.java).toInstance(call)
     }
 }
 
-class MainModule(private val application: Application) : AbstractModule() {
+internal class MainModule(private val application: Application) : AbstractModule() {
     override fun configure() {
         bind(Application::class.java).toInstance(application)
     }
 }
 
+/**
+ * @suppress internal class
+ */
 class WsSessionModule(private val webSocketSession: WebSocketServerSession) :
     AbstractModule() {
     override fun configure() {
@@ -88,12 +94,18 @@ class WsSessionModule(private val webSocketSession: WebSocketServerSession) :
     }
 }
 
+/**
+ * @suppress internal class
+ */
 class DummyWsSessionModule : AbstractModule() {
     override fun configure() {
         bind(WebSocketServerSession::class.java).toInstance(DummyWebSocketServerSession())
     }
 }
 
+/**
+ * @suppress internal class
+ */
 @Suppress("UNUSED_PARAMETER")
 class DummyWebSocketServerSession : WebSocketServerSession {
     override val call: ApplicationCall
diff --git a/kvision-modules/kvision-server-ktor/src/main/kotlin/pl/treksoft/kvision/remote/KVServiceManager.kt b/kvision-modules/kvision-server-ktor/src/main/kotlin/pl/treksoft/kvision/remote/KVServiceManager.kt
index aca5d2f0..11771cfa 100644
--- a/kvision-modules/kvision-server-ktor/src/main/kotlin/pl/treksoft/kvision/remote/KVServiceManager.kt
+++ b/kvision-modules/kvision-server-ktor/src/main/kotlin/pl/treksoft/kvision/remote/KVServiceManager.kt
@@ -36,6 +36,7 @@ import io.ktor.routing.get
 import io.ktor.routing.options
 import io.ktor.routing.post
 import io.ktor.routing.put
+import io.ktor.util.KtorExperimentalAPI
 import io.ktor.util.pipeline.PipelineContext
 import io.ktor.websocket.WebSocketServerSession
 import io.ktor.websocket.webSocket
@@ -45,7 +46,6 @@ import kotlinx.coroutines.channels.ReceiveChannel
 import kotlinx.coroutines.channels.SendChannel
 import kotlinx.coroutines.channels.filterNotNull
 import kotlinx.coroutines.channels.map
-import kotlinx.coroutines.channels.mapNotNull
 import kotlinx.coroutines.coroutineScope
 import kotlinx.coroutines.launch
 import org.slf4j.Logger
@@ -55,6 +55,7 @@ import kotlin.reflect.KClass
 /**
  * Multiplatform service manager for Ktor.
  */
+@KtorExperimentalAPI
 @UseExperimental(ExperimentalCoroutinesApi::class)
 @Suppress("LargeClass")
 actual open class KVServiceManager<T : Any> actual constructor(val serviceClass: KClass<T>) {
@@ -374,15 +375,17 @@ actual open class KVServiceManager<T : Any> actual constructor(val serviceClass:
         route: String?
     ) {
         val routeDef = "route${this::class.simpleName}${counter++}"
-        webSocketRequests["/kv/$routeDef"] = {
+        webSocketRequests["/kvws/$routeDef"] = {
             val wsInjector = call.injector.createChildInjector(WsSessionModule(this))
             val service = wsInjector.getInstance(serviceClass.java)
-            val requestChannel = incoming.mapNotNull { it as? Frame.Text }.map {
-                val jsonRpcRequest = getParameter<JsonRpcRequest>(it.readText())
-                if (jsonRpcRequest.params.size == 1) {
-                    getParameter<PAR1>(jsonRpcRequest.params[0])
-                } else {
-                    null
+            val requestChannel = incoming.map {
+                (it as? Frame.Text)?.readText()?.let { text ->
+                    val jsonRpcRequest = getParameter<JsonRpcRequest>(text)
+                    if (jsonRpcRequest.params.size == 1) {
+                        getParameter<PAR1>(jsonRpcRequest.params[0])
+                    } else {
+                        null
+                    }
                 }
             }.filterNotNull()
             val responseChannel = Channel<PAR2>()
@@ -452,6 +455,9 @@ actual open class KVServiceManager<T : Any> actual constructor(val serviceClass:
         }
     }
 
+    /**
+     * @suppress Internal method
+     */
     fun addRoute(
         method: HttpMethod,
         path: String,
@@ -466,6 +472,9 @@ actual open class KVServiceManager<T : Any> actual constructor(val serviceClass:
         }
     }
 
+    /**
+     * @suppress Internal method
+     */
     protected inline fun <reified T> getParameter(str: String?): T {
         return str?.let {
             if (T::class == String::class) {
@@ -477,6 +486,10 @@ actual open class KVServiceManager<T : Any> actual constructor(val serviceClass:
     }
 }
 
+/**
+ * A function to generate routes based on definitions from the service manager.
+ */
+@KtorExperimentalAPI
 fun <T : Any> Route.applyRoutes(serviceManager: KVServiceManager<T>) {
     serviceManager.getRequests.forEach { (path, handler) ->
         get(path, handler)
diff --git a/kvision-modules/kvision-server-spring-boot/build.gradle b/kvision-modules/kvision-server-spring-boot/build.gradle
index 3bb2b2aa..6e79d007 100644
--- a/kvision-modules/kvision-server-spring-boot/build.gradle
+++ b/kvision-modules/kvision-server-spring-boot/build.gradle
@@ -12,6 +12,7 @@ dependencies {
     compile "org.jetbrains.kotlin:kotlin-reflect:$kotlinVersion"
     compile "org.springframework.boot:spring-boot-starter:$springBootVersion"
     compile "org.springframework.boot:spring-boot-starter-web:$springBootVersion"
+    compile "org.springframework.boot:spring-boot-starter-websocket:$springBootVersion"
     compile "org.pac4j:pac4j-core:$pac4jVersion"
     compile "com.fasterxml.jackson.module:jackson-module-kotlin:${jacksonModuleKotlinVersion}"
     testCompile "org.jetbrains.kotlin:kotlin-test:$kotlinVersion"
diff --git a/kvision-modules/kvision-server-spring-boot/src/main/kotlin/pl/treksoft/kvision/remote/KVController.kt b/kvision-modules/kvision-server-spring-boot/src/main/kotlin/pl/treksoft/kvision/remote/KVController.kt
index 13b95b8b..23e00284 100644
--- a/kvision-modules/kvision-server-spring-boot/src/main/kotlin/pl/treksoft/kvision/remote/KVController.kt
+++ b/kvision-modules/kvision-server-spring-boot/src/main/kotlin/pl/treksoft/kvision/remote/KVController.kt
@@ -29,6 +29,9 @@ import org.springframework.web.bind.annotation.RequestMethod
 import javax.servlet.http.HttpServletRequest
 import javax.servlet.http.HttpServletResponse
 
+/**
+ * Controller for handling automatic routes.
+ */
 @Controller
 open class KVController {
 
@@ -44,7 +47,7 @@ open class KVController {
     )
     open fun kVMapping(req: HttpServletRequest, res: HttpServletResponse) {
         val routeUrl = req.requestURI
-        val route = services.mapNotNull {
+        val handler = services.mapNotNull {
             when (req.method) {
                 "GET" -> it.getRequests[routeUrl]
                 "POST" -> it.postRequests[routeUrl]
@@ -54,8 +57,8 @@ open class KVController {
                 else -> null
             }
         }.firstOrNull()
-        if (route != null) {
-            route.invoke(req, res, applicationContext)
+        if (handler != null) {
+            handler.invoke(req, res, applicationContext)
         } else {
             res.status = HttpServletResponse.SC_NOT_FOUND
         }
diff --git a/kvision-modules/kvision-server-spring-boot/src/main/kotlin/pl/treksoft/kvision/remote/KVServiceManager.kt b/kvision-modules/kvision-server-spring-boot/src/main/kotlin/pl/treksoft/kvision/remote/KVServiceManager.kt
index 66429acf..0606afef 100644
--- a/kvision-modules/kvision-server-spring-boot/src/main/kotlin/pl/treksoft/kvision/remote/KVServiceManager.kt
+++ b/kvision-modules/kvision-server-spring-boot/src/main/kotlin/pl/treksoft/kvision/remote/KVServiceManager.kt
@@ -23,14 +23,21 @@ package pl.treksoft.kvision.remote
 
 import com.fasterxml.jackson.module.kotlin.jacksonObjectMapper
 import kotlinx.coroutines.CoroutineStart
+import kotlinx.coroutines.Dispatchers
 import kotlinx.coroutines.ExperimentalCoroutinesApi
 import kotlinx.coroutines.GlobalScope
+import kotlinx.coroutines.channels.Channel
 import kotlinx.coroutines.channels.ReceiveChannel
 import kotlinx.coroutines.channels.SendChannel
+import kotlinx.coroutines.channels.filterNotNull
+import kotlinx.coroutines.channels.map
+import kotlinx.coroutines.coroutineScope
 import kotlinx.coroutines.launch
 import org.slf4j.Logger
 import org.slf4j.LoggerFactory
 import org.springframework.context.ApplicationContext
+import org.springframework.web.context.support.GenericWebApplicationContext
+import org.springframework.web.socket.WebSocketSession
 import javax.servlet.http.HttpServletRequest
 import javax.servlet.http.HttpServletResponse
 import kotlin.reflect.KClass
@@ -55,6 +62,10 @@ actual open class KVServiceManager<T : Any> actual constructor(val serviceClass:
         mutableMapOf()
     val optionsRequests: MutableMap<String, (HttpServletRequest, HttpServletResponse, ApplicationContext) -> Unit> =
         mutableMapOf()
+    val webSocketsRequests: MutableMap<String, suspend (
+        WebSocketSession, GenericWebApplicationContext, ReceiveChannel<String>, SendChannel<String>
+    ) -> Unit> =
+        mutableMapOf()
 
     val mapper = jacksonObjectMapper()
     var counter: Int = 0
@@ -402,7 +413,39 @@ actual open class KVServiceManager<T : Any> actual constructor(val serviceClass:
         noinline function: suspend T.(ReceiveChannel<PAR1>, SendChannel<PAR2>) -> Unit,
         route: String?
     ) {
-        TODO("Not implemented in Spring Boot module")
+        val routeDef = "route${this::class.simpleName}${counter++}"
+        webSocketsRequests[routeDef] = { webSocketSession, ctx, incoming, outgoing ->
+            val service = synchronized(this) {
+                WebSocketSessionHolder.webSocketSession = webSocketSession
+                ctx.getBean(serviceClass.java)
+            }
+            val requestChannel = incoming.map {
+                val jsonRpcRequest = getParameter<JsonRpcRequest>(it)
+                if (jsonRpcRequest.params.size == 1) {
+                    getParameter<PAR1>(jsonRpcRequest.params[0])
+                } else {
+                    null
+                }
+            }.filterNotNull()
+            val responseChannel = Channel<PAR2>()
+            coroutineScope {
+                launch {
+                    for (p in responseChannel) {
+                        val text = mapper.writeValueAsString(
+                            JsonRpcResponse(
+                                id = 0,
+                                result = mapper.writeValueAsString(p)
+                            )
+                        )
+                        outgoing.send(text)
+                    }
+                }
+                launch(start = CoroutineStart.UNDISPATCHED) {
+                    function.invoke(service, requestChannel, responseChannel)
+                    if (!responseChannel.isClosedForReceive) responseChannel.close()
+                }
+            }
+        }
     }
 
     /**
@@ -454,6 +497,9 @@ actual open class KVServiceManager<T : Any> actual constructor(val serviceClass:
         }
     }
 
+    /**
+     * @suppress internal function
+     */
     fun addRoute(
         method: HttpMethod,
         path: String,
@@ -468,6 +514,9 @@ actual open class KVServiceManager<T : Any> actual constructor(val serviceClass:
         }
     }
 
+    /**
+     * @suppress internal function
+     */
     protected inline fun <reified T> getParameter(str: String?): T {
         return str?.let {
             if (T::class == String::class) {
@@ -479,6 +528,9 @@ actual open class KVServiceManager<T : Any> actual constructor(val serviceClass:
     }
 }
 
+/**
+ * @suppress internal function
+ */
 fun HttpServletResponse.writeJSON(json: String) {
     val out = this.outputStream
     this.contentType = "application/json"
diff --git a/kvision-modules/kvision-server-spring-boot/src/main/kotlin/pl/treksoft/kvision/remote/KVWebSocketConfig.kt b/kvision-modules/kvision-server-spring-boot/src/main/kotlin/pl/treksoft/kvision/remote/KVWebSocketConfig.kt
new file mode 100644
index 00000000..7bc2e612
--- /dev/null
+++ b/kvision-modules/kvision-server-spring-boot/src/main/kotlin/pl/treksoft/kvision/remote/KVWebSocketConfig.kt
@@ -0,0 +1,251 @@
+/*
+ * Copyright (c) 2017-present Robert Jaros
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+package pl.treksoft.kvision.remote
+
+import kotlinx.coroutines.Dispatchers
+import kotlinx.coroutines.ExperimentalCoroutinesApi
+import kotlinx.coroutines.GlobalScope
+import kotlinx.coroutines.channels.Channel
+import kotlinx.coroutines.channels.ReceiveChannel
+import kotlinx.coroutines.channels.SendChannel
+import kotlinx.coroutines.coroutineScope
+import kotlinx.coroutines.launch
+import org.springframework.beans.factory.annotation.Autowired
+import org.springframework.beans.factory.config.ConfigurableBeanFactory
+import org.springframework.context.annotation.Bean
+import org.springframework.context.annotation.Configuration
+import org.springframework.context.annotation.Scope
+import org.springframework.http.HttpHeaders
+import org.springframework.http.server.ServerHttpRequest
+import org.springframework.http.server.ServerHttpResponse
+import org.springframework.web.context.support.GenericWebApplicationContext
+import org.springframework.web.socket.CloseStatus
+import org.springframework.web.socket.TextMessage
+import org.springframework.web.socket.WebSocketExtension
+import org.springframework.web.socket.WebSocketHandler
+import org.springframework.web.socket.WebSocketMessage
+import org.springframework.web.socket.WebSocketSession
+import org.springframework.web.socket.config.annotation.EnableWebSocket
+import org.springframework.web.socket.config.annotation.WebSocketConfigurer
+import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry
+import org.springframework.web.socket.handler.TextWebSocketHandler
+import org.springframework.web.socket.server.HandshakeInterceptor
+import java.net.InetSocketAddress
+import java.net.URI
+import java.security.Principal
+import java.util.concurrent.ConcurrentHashMap
+
+const val KV_ROUTE_ID_ATTRIBUTE = "KV_ROUTE_ID_ATTRIBUTE"
+
+/**
+ * Automatic websockets configuration.
+ */
+@Configuration
+@EnableWebSocket
+open class KVWebSocketConfig : WebSocketConfigurer {
+
+    @Autowired
+    lateinit var services: List<KVServiceManager<*>>
+
+    @Autowired
+    lateinit var applicationContext: GenericWebApplicationContext
+
+    override fun registerWebSocketHandlers(registry: WebSocketHandlerRegistry) {
+        registry.addHandler(socketHandler(), "/kvws/*").setAllowedOrigins("*").addInterceptors(routeInterceptor())
+    }
+
+    @Bean
+    open fun routeInterceptor(): HandshakeInterceptor {
+        return KvHandshakeInterceptor()
+    }
+
+    @Bean
+    open fun socketHandler(): WebSocketHandler {
+        return KvWebSocketHandler(services, applicationContext)
+    }
+
+    @Bean
+    @Scope(value = ConfigurableBeanFactory.SCOPE_PROTOTYPE)
+    open fun webSocketSession(): WebSocketSession {
+        return WebSocketSessionHolder.webSocketSession
+    }
+}
+
+object WebSocketSessionHolder {
+    var webSocketSession: WebSocketSession = DummyWebSocketSession()
+}
+
+internal open class KvHandshakeInterceptor : HandshakeInterceptor {
+    override fun beforeHandshake(
+        request: ServerHttpRequest,
+        response: ServerHttpResponse,
+        wsHandler: WebSocketHandler,
+        attributes: MutableMap<String, Any>
+    ): Boolean {
+        val path = request.uri.path
+        val route = path.substring(path.lastIndexOf('/') + 1)
+        attributes[KV_ROUTE_ID_ATTRIBUTE] = route
+        return true
+    }
+
+    override fun afterHandshake(
+        request: ServerHttpRequest,
+        response: ServerHttpResponse,
+        wsHandler: WebSocketHandler,
+        exception: Exception?
+    ) {
+    }
+}
+
+@UseExperimental(ExperimentalCoroutinesApi::class)
+internal open class KvWebSocketHandler(
+    private val services: List<KVServiceManager<*>>,
+    private val applicationContext: GenericWebApplicationContext
+) : TextWebSocketHandler() {
+
+    private val sessions = ConcurrentHashMap<String, Pair<Channel<String>, Channel<String>>>()
+
+    private fun getHandler(session: WebSocketSession): (suspend (
+        WebSocketSession, GenericWebApplicationContext,
+        ReceiveChannel<String>, SendChannel<String>
+    ) -> Unit)? {
+        val routeId = session.attributes[KV_ROUTE_ID_ATTRIBUTE] as String
+        return services.mapNotNull {
+            it.webSocketsRequests[routeId]
+        }.firstOrNull()
+    }
+
+    private fun getSessionId(session: WebSocketSession): String {
+        val routeId = session.attributes[KV_ROUTE_ID_ATTRIBUTE] as String
+        return session.id + "###" + routeId
+    }
+
+    override fun afterConnectionEstablished(session: WebSocketSession) {
+        getHandler(session)?.let { handler ->
+            val requestChannel = Channel<String>()
+            val responseChannel = Channel<String>()
+            GlobalScope.launch {
+                coroutineScope {
+                    launch(Dispatchers.IO) {
+                        for (text in responseChannel) {
+                            session.sendMessage(TextMessage(text))
+                        }
+                        session.close()
+                    }
+                    launch {
+                        handler.invoke(session, applicationContext, requestChannel, responseChannel)
+                        if (!responseChannel.isClosedForReceive) responseChannel.close()
+                    }
+                    sessions[getSessionId(session)] = responseChannel to requestChannel
+                }
+            }
+        }
+    }
+
+    override fun handleTextMessage(session: WebSocketSession, message: TextMessage) {
+        getHandler(session)?.let {
+            sessions[getSessionId(session)]?.let { (_, requestChannel) ->
+                GlobalScope.launch {
+                    requestChannel.send(message.payload)
+                }
+            }
+        }
+    }
+
+    override fun afterConnectionClosed(session: WebSocketSession, status: CloseStatus) {
+        getHandler(session)?.let {
+            sessions[getSessionId(session)]?.let { (responseChannel, requestChannel) ->
+                GlobalScope.launch {
+                    responseChannel.close()
+                    requestChannel.close()
+                }
+                sessions.remove(getSessionId(session))
+            }
+        }
+    }
+}
+
+open class DummyWebSocketSession : WebSocketSession {
+    override fun getBinaryMessageSizeLimit(): Int {
+        return 0
+    }
+
+    override fun sendMessage(message: WebSocketMessage<*>) {
+    }
+
+    override fun getAcceptedProtocol(): String? {
+        return null
+    }
+
+    override fun getTextMessageSizeLimit(): Int {
+        return 0
+    }
+
+    override fun getLocalAddress(): InetSocketAddress? {
+        return null
+    }
+
+    override fun getId(): String {
+        return ""
+    }
+
+    override fun getExtensions(): MutableList<WebSocketExtension> {
+        return mutableListOf()
+    }
+
+    override fun getUri(): URI? {
+        return null
+    }
+
+    override fun setBinaryMessageSizeLimit(messageSizeLimit: Int) {
+    }
+
+    override fun getAttributes(): MutableMap<String, Any> {
+        return mutableMapOf()
+    }
+
+    override fun getHandshakeHeaders(): HttpHeaders {
+        return HttpHeaders.EMPTY
+    }
+
+    override fun isOpen(): Boolean {
+        return false
+    }
+
+    override fun getPrincipal(): Principal? {
+        return null
+    }
+
+    override fun close() {
+    }
+
+    override fun close(status: CloseStatus) {
+    }
+
+    override fun setTextMessageSizeLimit(messageSizeLimit: Int) {
+    }
+
+    override fun getRemoteAddress(): InetSocketAddress? {
+        return null
+    }
+
+}
diff --git a/kvision-modules/kvision-server-spring-boot/src/main/kotlin/pl/treksoft/kvision/remote/Security.kt b/kvision-modules/kvision-server-spring-boot/src/main/kotlin/pl/treksoft/kvision/remote/Security.kt
index abfef934..287a3a7e 100644
--- a/kvision-modules/kvision-server-spring-boot/src/main/kotlin/pl/treksoft/kvision/remote/Security.kt
+++ b/kvision-modules/kvision-server-spring-boot/src/main/kotlin/pl/treksoft/kvision/remote/Security.kt
@@ -23,6 +23,9 @@ package pl.treksoft.kvision.remote
 
 import org.springframework.web.servlet.config.annotation.InterceptorRegistration
 
+/**
+ * A function to gather paths for an interceptor from a list of service managers.
+ */
 fun InterceptorRegistration.addPathPatternsFromServices(services: List<KVServiceManager<*>>) {
     val paths = services.flatMap {
         it.postRequests.keys + it.putRequests.keys + it.optionsRequests.keys + it.optionsRequests.keys
diff --git a/kvision-modules/kvision-server-spring-boot/src/main/resources/META-INF/spring.factories b/kvision-modules/kvision-server-spring-boot/src/main/resources/META-INF/spring.factories
index e31ff24d..17ca7a1d 100644
--- a/kvision-modules/kvision-server-spring-boot/src/main/resources/META-INF/spring.factories
+++ b/kvision-modules/kvision-server-spring-boot/src/main/resources/META-INF/spring.factories
@@ -1 +1 @@
-org.springframework.boot.autoconfigure.EnableAutoConfiguration=pl.treksoft.kvision.remote.KVController
+org.springframework.boot.autoconfigure.EnableAutoConfiguration=pl.treksoft.kvision.remote.KVController,pl.treksoft.kvision.remote.KVWebSocketConfig
-- 
cgit