diff options
author | Robert Jaros <rjaros@finn.pl> | 2019-04-10 18:24:14 +0200 |
---|---|---|
committer | Robert Jaros <rjaros@finn.pl> | 2019-04-10 18:24:14 +0200 |
commit | 9374d0df9c462493e9cb91c846e4f820b3325f7b (patch) | |
tree | b8654b6b6bb3d31fc2fae50195faa8096749f8a7 | |
parent | fc8023bf8eb4c17b7fc1e77bb057a7c3df19eb0d (diff) | |
download | kvision-9374d0df9c462493e9cb91c846e4f820b3325f7b.tar.gz kvision-9374d0df9c462493e9cb91c846e4f820b3325f7b.tar.bz2 kvision-9374d0df9c462493e9cb91c846e4f820b3325f7b.zip |
Websockets support for Spring Boot and Jooby.
12 files changed, 434 insertions, 22 deletions
diff --git a/kvision-modules/kvision-remote/src/main/kotlin/pl/treksoft/kvision/remote/KVRemoteAgent.kt b/kvision-modules/kvision-remote/src/main/kotlin/pl/treksoft/kvision/remote/KVRemoteAgent.kt index 9c6f8274..d3331b2a 100644 --- a/kvision-modules/kvision-remote/src/main/kotlin/pl/treksoft/kvision/remote/KVRemoteAgent.kt +++ b/kvision-modules/kvision-remote/src/main/kotlin/pl/treksoft/kvision/remote/KVRemoteAgent.kt @@ -21,6 +21,7 @@ */ package pl.treksoft.kvision.remote +import kotlinx.coroutines.ExperimentalCoroutinesApi import kotlinx.coroutines.Job import kotlinx.coroutines.asDeferred import kotlinx.coroutines.channels.Channel @@ -41,7 +42,7 @@ import kotlin.js.JSON as NativeJSON * Client side agent for JSON-RPC remote calls. */ @Suppress("LargeClass", "TooManyFunctions") -@UseExperimental(ImplicitReflectionSerializer::class) +@UseExperimental(ImplicitReflectionSerializer::class, ExperimentalCoroutinesApi::class) open class KVRemoteAgent<T : Any>(val serviceManager: KVServiceManager<T>) : RemoteAgent { val callAgent = CallAgent() @@ -407,7 +408,7 @@ open class KVRemoteAgent<T : Any>(val serviceManager: KVServiceManager<T>) : Rem responseJob = launch { while (true) { val str = socket.receiveOrNull() ?: break - val data = kotlin.js.JSON.parse<dynamic>(str).result + val data = kotlin.js.JSON.parse<JsonRpcResponse>(str).result ?: "" val par2 = try { @Suppress("UNCHECKED_CAST") deserialize<PAR2>(data, PAR2::class.js.name) @@ -477,7 +478,7 @@ open class KVRemoteAgent<T : Any>(val serviceManager: KVServiceManager<T>) : Rem responseJob = launch { while (true) { val str = socket.receiveOrNull() ?: break - val data = kotlin.js.JSON.parse<dynamic>(str).result + val data = kotlin.js.JSON.parse<JsonRpcResponse>(str).result ?: "" val par2 = try { deserializeList<PAR2>(data, PAR2::class.js.name) } catch (t: NotStandardTypeException) { diff --git a/kvision-modules/kvision-remote/src/main/kotlin/pl/treksoft/kvision/remote/KVServiceManager.kt b/kvision-modules/kvision-remote/src/main/kotlin/pl/treksoft/kvision/remote/KVServiceManager.kt index 0cf72f76..7db690b2 100644 --- a/kvision-modules/kvision-remote/src/main/kotlin/pl/treksoft/kvision/remote/KVServiceManager.kt +++ b/kvision-modules/kvision-remote/src/main/kotlin/pl/treksoft/kvision/remote/KVServiceManager.kt @@ -149,7 +149,7 @@ actual open class KVServiceManager<T : Any> actual constructor(serviceClass: KCl route: String? ) { val routeDef = "route${this::class.simpleName}${counter++}" - calls[function.toString().replace("\\s".toRegex(), "")] = Pair("/kv/$routeDef", HttpMethod.POST) + calls[function.toString().replace("\\s".toRegex(), "")] = Pair("/kvws/$routeDef", HttpMethod.POST) } /** diff --git a/kvision-modules/kvision-server-jooby/src/main/kotlin/pl/treksoft/kvision/remote/KVModules.kt b/kvision-modules/kvision-server-jooby/src/main/kotlin/pl/treksoft/kvision/remote/KVModules.kt index 34b2877f..b3da785d 100644 --- a/kvision-modules/kvision-server-jooby/src/main/kotlin/pl/treksoft/kvision/remote/KVModules.kt +++ b/kvision-modules/kvision-server-jooby/src/main/kotlin/pl/treksoft/kvision/remote/KVModules.kt @@ -24,6 +24,9 @@ package pl.treksoft.kvision.remote import org.jooby.Kooby import org.jooby.json.Jackson +/** + * Initialization function for Jooby server. + */ fun Kooby.kvisionInit() { assets("/", "/assets/index.html") assets("/**", "/assets/{0}").onMissing(0) diff --git a/kvision-modules/kvision-server-jooby/src/main/kotlin/pl/treksoft/kvision/remote/KVServiceManager.kt b/kvision-modules/kvision-server-jooby/src/main/kotlin/pl/treksoft/kvision/remote/KVServiceManager.kt index cfd9d467..ca08b080 100644 --- a/kvision-modules/kvision-server-jooby/src/main/kotlin/pl/treksoft/kvision/remote/KVServiceManager.kt +++ b/kvision-modules/kvision-server-jooby/src/main/kotlin/pl/treksoft/kvision/remote/KVServiceManager.kt @@ -24,9 +24,15 @@ package pl.treksoft.kvision.remote import com.fasterxml.jackson.module.kotlin.jacksonObjectMapper import com.google.inject.Injector import kotlinx.coroutines.CoroutineStart +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.ExperimentalCoroutinesApi import kotlinx.coroutines.GlobalScope +import kotlinx.coroutines.channels.Channel import kotlinx.coroutines.channels.ReceiveChannel import kotlinx.coroutines.channels.SendChannel +import kotlinx.coroutines.channels.filterNotNull +import kotlinx.coroutines.channels.map +import kotlinx.coroutines.coroutineScope import kotlinx.coroutines.launch import org.jooby.Kooby import org.jooby.Request @@ -39,6 +45,7 @@ import kotlin.reflect.KClass * Multiplatform service manager for Jooby. */ @Suppress("LargeClass") +@UseExperimental(ExperimentalCoroutinesApi::class) actual open class KVServiceManager<T : Any> actual constructor(val serviceClass: KClass<T>) { companion object { @@ -313,7 +320,7 @@ actual open class KVServiceManager<T : Any> actual constructor(val serviceClass: } /** - * Binds a given web socket connetion with a function of the receiver. + * Binds a given web socket connection with a function of the receiver. * @param function a function of the receiver * @param route a route */ @@ -321,10 +328,67 @@ actual open class KVServiceManager<T : Any> actual constructor(val serviceClass: noinline function: suspend T.(ReceiveChannel<PAR1>, SendChannel<PAR2>) -> Unit, route: String? ) { - TODO("Not implemented in Jooby module") + val routeDef = "route${this::class.simpleName}${counter++}" + routes.add { + ws("/kvws/$routeDef") { ws -> + val injector = ws.require(Injector::class.java) + val service = injector.getInstance(serviceClass.java) + val incoming = Channel<String>() + val outgoing = Channel<String>() + GlobalScope.launch { + coroutineScope { + launch(Dispatchers.IO) { + for (text in outgoing) { + ws.send(text) + } + ws.close() + } + launch { + val requestChannel = incoming.map { + val jsonRpcRequest = getParameter<JsonRpcRequest>(it) + if (jsonRpcRequest.params.size == 1) { + getParameter<PAR1>(jsonRpcRequest.params[0]) + } else { + null + } + }.filterNotNull() + val responseChannel = Channel<PAR2>() + coroutineScope { + launch(Dispatchers.IO) { + for (p in responseChannel) { + val text = mapper.writeValueAsString( + JsonRpcResponse( + id = 0, + result = mapper.writeValueAsString(p) + ) + ) + outgoing.send(text) + } + } + launch { + function.invoke(service, requestChannel, responseChannel) + if (!responseChannel.isClosedForReceive) responseChannel.close() + } + } + if (!outgoing.isClosedForReceive) outgoing.close() + } + } + } + ws.onClose { + GlobalScope.launch { + outgoing.close() + incoming.close() + } + } + ws.onMessage { msg -> + GlobalScope.launch { + incoming.send(msg.value()) + } + } + } + } } - /** * Binds a given function of the receiver as a select options source * @param function a function of the receiver @@ -363,6 +427,9 @@ actual open class KVServiceManager<T : Any> actual constructor(val serviceClass: } } + /** + * @suppress Internal method + */ fun call( method: HttpMethod, path: String, @@ -379,6 +446,9 @@ actual open class KVServiceManager<T : Any> actual constructor(val serviceClass: } } + /** + * @suppress Internal method + */ protected inline fun <reified T> getParameter(str: String?): T { return str?.let { if (T::class == String::class) { @@ -391,6 +461,9 @@ actual open class KVServiceManager<T : Any> actual constructor(val serviceClass: } +/** + * A function to generate routes based on definitions from the service manager. + */ fun <T : Any> Kooby.applyRoutes(serviceManager: KVServiceManager<T>) { serviceManager.routes.forEach { it.invoke(this@applyRoutes) diff --git a/kvision-modules/kvision-server-ktor/src/main/kotlin/pl/treksoft/kvision/remote/KVModules.kt b/kvision-modules/kvision-server-ktor/src/main/kotlin/pl/treksoft/kvision/remote/KVModules.kt index a5e22b2d..34fff02d 100644 --- a/kvision-modules/kvision-server-ktor/src/main/kotlin/pl/treksoft/kvision/remote/KVModules.kt +++ b/kvision-modules/kvision-server-ktor/src/main/kotlin/pl/treksoft/kvision/remote/KVModules.kt @@ -45,6 +45,9 @@ import kotlinx.coroutines.channels.ReceiveChannel import kotlinx.coroutines.channels.SendChannel import kotlin.coroutines.CoroutineContext +/** + * Initialization function for Ktor server. + */ fun Application.kvisionInit(vararg modules: Module) { install(ContentNegotiation) { jackson() @@ -69,18 +72,21 @@ val injectorKey = AttributeKey<Injector>("injector") val ApplicationCall.injector: Injector get() = attributes[injectorKey] -class CallModule(private val call: ApplicationCall) : AbstractModule() { +internal class CallModule(private val call: ApplicationCall) : AbstractModule() { override fun configure() { bind(ApplicationCall::class.java).toInstance(call) } } -class MainModule(private val application: Application) : AbstractModule() { +internal class MainModule(private val application: Application) : AbstractModule() { override fun configure() { bind(Application::class.java).toInstance(application) } } +/** + * @suppress internal class + */ class WsSessionModule(private val webSocketSession: WebSocketServerSession) : AbstractModule() { override fun configure() { @@ -88,12 +94,18 @@ class WsSessionModule(private val webSocketSession: WebSocketServerSession) : } } +/** + * @suppress internal class + */ class DummyWsSessionModule : AbstractModule() { override fun configure() { bind(WebSocketServerSession::class.java).toInstance(DummyWebSocketServerSession()) } } +/** + * @suppress internal class + */ @Suppress("UNUSED_PARAMETER") class DummyWebSocketServerSession : WebSocketServerSession { override val call: ApplicationCall diff --git a/kvision-modules/kvision-server-ktor/src/main/kotlin/pl/treksoft/kvision/remote/KVServiceManager.kt b/kvision-modules/kvision-server-ktor/src/main/kotlin/pl/treksoft/kvision/remote/KVServiceManager.kt index aca5d2f0..11771cfa 100644 --- a/kvision-modules/kvision-server-ktor/src/main/kotlin/pl/treksoft/kvision/remote/KVServiceManager.kt +++ b/kvision-modules/kvision-server-ktor/src/main/kotlin/pl/treksoft/kvision/remote/KVServiceManager.kt @@ -36,6 +36,7 @@ import io.ktor.routing.get import io.ktor.routing.options import io.ktor.routing.post import io.ktor.routing.put +import io.ktor.util.KtorExperimentalAPI import io.ktor.util.pipeline.PipelineContext import io.ktor.websocket.WebSocketServerSession import io.ktor.websocket.webSocket @@ -45,7 +46,6 @@ import kotlinx.coroutines.channels.ReceiveChannel import kotlinx.coroutines.channels.SendChannel import kotlinx.coroutines.channels.filterNotNull import kotlinx.coroutines.channels.map -import kotlinx.coroutines.channels.mapNotNull import kotlinx.coroutines.coroutineScope import kotlinx.coroutines.launch import org.slf4j.Logger @@ -55,6 +55,7 @@ import kotlin.reflect.KClass /** * Multiplatform service manager for Ktor. */ +@KtorExperimentalAPI @UseExperimental(ExperimentalCoroutinesApi::class) @Suppress("LargeClass") actual open class KVServiceManager<T : Any> actual constructor(val serviceClass: KClass<T>) { @@ -374,15 +375,17 @@ actual open class KVServiceManager<T : Any> actual constructor(val serviceClass: route: String? ) { val routeDef = "route${this::class.simpleName}${counter++}" - webSocketRequests["/kv/$routeDef"] = { + webSocketRequests["/kvws/$routeDef"] = { val wsInjector = call.injector.createChildInjector(WsSessionModule(this)) val service = wsInjector.getInstance(serviceClass.java) - val requestChannel = incoming.mapNotNull { it as? Frame.Text }.map { - val jsonRpcRequest = getParameter<JsonRpcRequest>(it.readText()) - if (jsonRpcRequest.params.size == 1) { - getParameter<PAR1>(jsonRpcRequest.params[0]) - } else { - null + val requestChannel = incoming.map { + (it as? Frame.Text)?.readText()?.let { text -> + val jsonRpcRequest = getParameter<JsonRpcRequest>(text) + if (jsonRpcRequest.params.size == 1) { + getParameter<PAR1>(jsonRpcRequest.params[0]) + } else { + null + } } }.filterNotNull() val responseChannel = Channel<PAR2>() @@ -452,6 +455,9 @@ actual open class KVServiceManager<T : Any> actual constructor(val serviceClass: } } + /** + * @suppress Internal method + */ fun addRoute( method: HttpMethod, path: String, @@ -466,6 +472,9 @@ actual open class KVServiceManager<T : Any> actual constructor(val serviceClass: } } + /** + * @suppress Internal method + */ protected inline fun <reified T> getParameter(str: String?): T { return str?.let { if (T::class == String::class) { @@ -477,6 +486,10 @@ actual open class KVServiceManager<T : Any> actual constructor(val serviceClass: } } +/** + * A function to generate routes based on definitions from the service manager. + */ +@KtorExperimentalAPI fun <T : Any> Route.applyRoutes(serviceManager: KVServiceManager<T>) { serviceManager.getRequests.forEach { (path, handler) -> get(path, handler) diff --git a/kvision-modules/kvision-server-spring-boot/build.gradle b/kvision-modules/kvision-server-spring-boot/build.gradle index 3bb2b2aa..6e79d007 100644 --- a/kvision-modules/kvision-server-spring-boot/build.gradle +++ b/kvision-modules/kvision-server-spring-boot/build.gradle @@ -12,6 +12,7 @@ dependencies { compile "org.jetbrains.kotlin:kotlin-reflect:$kotlinVersion" compile "org.springframework.boot:spring-boot-starter:$springBootVersion" compile "org.springframework.boot:spring-boot-starter-web:$springBootVersion" + compile "org.springframework.boot:spring-boot-starter-websocket:$springBootVersion" compile "org.pac4j:pac4j-core:$pac4jVersion" compile "com.fasterxml.jackson.module:jackson-module-kotlin:${jacksonModuleKotlinVersion}" testCompile "org.jetbrains.kotlin:kotlin-test:$kotlinVersion" diff --git a/kvision-modules/kvision-server-spring-boot/src/main/kotlin/pl/treksoft/kvision/remote/KVController.kt b/kvision-modules/kvision-server-spring-boot/src/main/kotlin/pl/treksoft/kvision/remote/KVController.kt index 13b95b8b..23e00284 100644 --- a/kvision-modules/kvision-server-spring-boot/src/main/kotlin/pl/treksoft/kvision/remote/KVController.kt +++ b/kvision-modules/kvision-server-spring-boot/src/main/kotlin/pl/treksoft/kvision/remote/KVController.kt @@ -29,6 +29,9 @@ import org.springframework.web.bind.annotation.RequestMethod import javax.servlet.http.HttpServletRequest import javax.servlet.http.HttpServletResponse +/** + * Controller for handling automatic routes. + */ @Controller open class KVController { @@ -44,7 +47,7 @@ open class KVController { ) open fun kVMapping(req: HttpServletRequest, res: HttpServletResponse) { val routeUrl = req.requestURI - val route = services.mapNotNull { + val handler = services.mapNotNull { when (req.method) { "GET" -> it.getRequests[routeUrl] "POST" -> it.postRequests[routeUrl] @@ -54,8 +57,8 @@ open class KVController { else -> null } }.firstOrNull() - if (route != null) { - route.invoke(req, res, applicationContext) + if (handler != null) { + handler.invoke(req, res, applicationContext) } else { res.status = HttpServletResponse.SC_NOT_FOUND } diff --git a/kvision-modules/kvision-server-spring-boot/src/main/kotlin/pl/treksoft/kvision/remote/KVServiceManager.kt b/kvision-modules/kvision-server-spring-boot/src/main/kotlin/pl/treksoft/kvision/remote/KVServiceManager.kt index 66429acf..0606afef 100644 --- a/kvision-modules/kvision-server-spring-boot/src/main/kotlin/pl/treksoft/kvision/remote/KVServiceManager.kt +++ b/kvision-modules/kvision-server-spring-boot/src/main/kotlin/pl/treksoft/kvision/remote/KVServiceManager.kt @@ -23,14 +23,21 @@ package pl.treksoft.kvision.remote import com.fasterxml.jackson.module.kotlin.jacksonObjectMapper import kotlinx.coroutines.CoroutineStart +import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.ExperimentalCoroutinesApi import kotlinx.coroutines.GlobalScope +import kotlinx.coroutines.channels.Channel import kotlinx.coroutines.channels.ReceiveChannel import kotlinx.coroutines.channels.SendChannel +import kotlinx.coroutines.channels.filterNotNull +import kotlinx.coroutines.channels.map +import kotlinx.coroutines.coroutineScope import kotlinx.coroutines.launch import org.slf4j.Logger import org.slf4j.LoggerFactory import org.springframework.context.ApplicationContext +import org.springframework.web.context.support.GenericWebApplicationContext +import org.springframework.web.socket.WebSocketSession import javax.servlet.http.HttpServletRequest import javax.servlet.http.HttpServletResponse import kotlin.reflect.KClass @@ -55,6 +62,10 @@ actual open class KVServiceManager<T : Any> actual constructor(val serviceClass: mutableMapOf() val optionsRequests: MutableMap<String, (HttpServletRequest, HttpServletResponse, ApplicationContext) -> Unit> = mutableMapOf() + val webSocketsRequests: MutableMap<String, suspend ( + WebSocketSession, GenericWebApplicationContext, ReceiveChannel<String>, SendChannel<String> + ) -> Unit> = + mutableMapOf() val mapper = jacksonObjectMapper() var counter: Int = 0 @@ -402,7 +413,39 @@ actual open class KVServiceManager<T : Any> actual constructor(val serviceClass: noinline function: suspend T.(ReceiveChannel<PAR1>, SendChannel<PAR2>) -> Unit, route: String? ) { - TODO("Not implemented in Spring Boot module") + val routeDef = "route${this::class.simpleName}${counter++}" + webSocketsRequests[routeDef] = { webSocketSession, ctx, incoming, outgoing -> + val service = synchronized(this) { + WebSocketSessionHolder.webSocketSession = webSocketSession + ctx.getBean(serviceClass.java) + } + val requestChannel = incoming.map { + val jsonRpcRequest = getParameter<JsonRpcRequest>(it) + if (jsonRpcRequest.params.size == 1) { + getParameter<PAR1>(jsonRpcRequest.params[0]) + } else { + null + } + }.filterNotNull() + val responseChannel = Channel<PAR2>() + coroutineScope { + launch { + for (p in responseChannel) { + val text = mapper.writeValueAsString( + JsonRpcResponse( + id = 0, + result = mapper.writeValueAsString(p) + ) + ) + outgoing.send(text) + } + } + launch(start = CoroutineStart.UNDISPATCHED) { + function.invoke(service, requestChannel, responseChannel) + if (!responseChannel.isClosedForReceive) responseChannel.close() + } + } + } } /** @@ -454,6 +497,9 @@ actual open class KVServiceManager<T : Any> actual constructor(val serviceClass: } } + /** + * @suppress internal function + */ fun addRoute( method: HttpMethod, path: String, @@ -468,6 +514,9 @@ actual open class KVServiceManager<T : Any> actual constructor(val serviceClass: } } + /** + * @suppress internal function + */ protected inline fun <reified T> getParameter(str: String?): T { return str?.let { if (T::class == String::class) { @@ -479,6 +528,9 @@ actual open class KVServiceManager<T : Any> actual constructor(val serviceClass: } } +/** + * @suppress internal function + */ fun HttpServletResponse.writeJSON(json: String) { val out = this.outputStream this.contentType = "application/json" diff --git a/kvision-modules/kvision-server-spring-boot/src/main/kotlin/pl/treksoft/kvision/remote/KVWebSocketConfig.kt b/kvision-modules/kvision-server-spring-boot/src/main/kotlin/pl/treksoft/kvision/remote/KVWebSocketConfig.kt new file mode 100644 index 00000000..7bc2e612 --- /dev/null +++ b/kvision-modules/kvision-server-spring-boot/src/main/kotlin/pl/treksoft/kvision/remote/KVWebSocketConfig.kt @@ -0,0 +1,251 @@ +/* + * Copyright (c) 2017-present Robert Jaros + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +package pl.treksoft.kvision.remote + +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.GlobalScope +import kotlinx.coroutines.channels.Channel +import kotlinx.coroutines.channels.ReceiveChannel +import kotlinx.coroutines.channels.SendChannel +import kotlinx.coroutines.coroutineScope +import kotlinx.coroutines.launch +import org.springframework.beans.factory.annotation.Autowired +import org.springframework.beans.factory.config.ConfigurableBeanFactory +import org.springframework.context.annotation.Bean +import org.springframework.context.annotation.Configuration +import org.springframework.context.annotation.Scope +import org.springframework.http.HttpHeaders +import org.springframework.http.server.ServerHttpRequest +import org.springframework.http.server.ServerHttpResponse +import org.springframework.web.context.support.GenericWebApplicationContext +import org.springframework.web.socket.CloseStatus +import org.springframework.web.socket.TextMessage +import org.springframework.web.socket.WebSocketExtension +import org.springframework.web.socket.WebSocketHandler +import org.springframework.web.socket.WebSocketMessage +import org.springframework.web.socket.WebSocketSession +import org.springframework.web.socket.config.annotation.EnableWebSocket +import org.springframework.web.socket.config.annotation.WebSocketConfigurer +import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry +import org.springframework.web.socket.handler.TextWebSocketHandler +import org.springframework.web.socket.server.HandshakeInterceptor +import java.net.InetSocketAddress +import java.net.URI +import java.security.Principal +import java.util.concurrent.ConcurrentHashMap + +const val KV_ROUTE_ID_ATTRIBUTE = "KV_ROUTE_ID_ATTRIBUTE" + +/** + * Automatic websockets configuration. + */ +@Configuration +@EnableWebSocket +open class KVWebSocketConfig : WebSocketConfigurer { + + @Autowired + lateinit var services: List<KVServiceManager<*>> + + @Autowired + lateinit var applicationContext: GenericWebApplicationContext + + override fun registerWebSocketHandlers(registry: WebSocketHandlerRegistry) { + registry.addHandler(socketHandler(), "/kvws/*").setAllowedOrigins("*").addInterceptors(routeInterceptor()) + } + + @Bean + open fun routeInterceptor(): HandshakeInterceptor { + return KvHandshakeInterceptor() + } + + @Bean + open fun socketHandler(): WebSocketHandler { + return KvWebSocketHandler(services, applicationContext) + } + + @Bean + @Scope(value = ConfigurableBeanFactory.SCOPE_PROTOTYPE) + open fun webSocketSession(): WebSocketSession { + return WebSocketSessionHolder.webSocketSession + } +} + +object WebSocketSessionHolder { + var webSocketSession: WebSocketSession = DummyWebSocketSession() +} + +internal open class KvHandshakeInterceptor : HandshakeInterceptor { + override fun beforeHandshake( + request: ServerHttpRequest, + response: ServerHttpResponse, + wsHandler: WebSocketHandler, + attributes: MutableMap<String, Any> + ): Boolean { + val path = request.uri.path + val route = path.substring(path.lastIndexOf('/') + 1) + attributes[KV_ROUTE_ID_ATTRIBUTE] = route + return true + } + + override fun afterHandshake( + request: ServerHttpRequest, + response: ServerHttpResponse, + wsHandler: WebSocketHandler, + exception: Exception? + ) { + } +} + +@UseExperimental(ExperimentalCoroutinesApi::class) +internal open class KvWebSocketHandler( + private val services: List<KVServiceManager<*>>, + private val applicationContext: GenericWebApplicationContext +) : TextWebSocketHandler() { + + private val sessions = ConcurrentHashMap<String, Pair<Channel<String>, Channel<String>>>() + + private fun getHandler(session: WebSocketSession): (suspend ( + WebSocketSession, GenericWebApplicationContext, + ReceiveChannel<String>, SendChannel<String> + ) -> Unit)? { + val routeId = session.attributes[KV_ROUTE_ID_ATTRIBUTE] as String + return services.mapNotNull { + it.webSocketsRequests[routeId] + }.firstOrNull() + } + + private fun getSessionId(session: WebSocketSession): String { + val routeId = session.attributes[KV_ROUTE_ID_ATTRIBUTE] as String + return session.id + "###" + routeId + } + + override fun afterConnectionEstablished(session: WebSocketSession) { + getHandler(session)?.let { handler -> + val requestChannel = Channel<String>() + val responseChannel = Channel<String>() + GlobalScope.launch { + coroutineScope { + launch(Dispatchers.IO) { + for (text in responseChannel) { + session.sendMessage(TextMessage(text)) + } + session.close() + } + launch { + handler.invoke(session, applicationContext, requestChannel, responseChannel) + if (!responseChannel.isClosedForReceive) responseChannel.close() + } + sessions[getSessionId(session)] = responseChannel to requestChannel + } + } + } + } + + override fun handleTextMessage(session: WebSocketSession, message: TextMessage) { + getHandler(session)?.let { + sessions[getSessionId(session)]?.let { (_, requestChannel) -> + GlobalScope.launch { + requestChannel.send(message.payload) + } + } + } + } + + override fun afterConnectionClosed(session: WebSocketSession, status: CloseStatus) { + getHandler(session)?.let { + sessions[getSessionId(session)]?.let { (responseChannel, requestChannel) -> + GlobalScope.launch { + responseChannel.close() + requestChannel.close() + } + sessions.remove(getSessionId(session)) + } + } + } +} + +open class DummyWebSocketSession : WebSocketSession { + override fun getBinaryMessageSizeLimit(): Int { + return 0 + } + + override fun sendMessage(message: WebSocketMessage<*>) { + } + + override fun getAcceptedProtocol(): String? { + return null + } + + override fun getTextMessageSizeLimit(): Int { + return 0 + } + + override fun getLocalAddress(): InetSocketAddress? { + return null + } + + override fun getId(): String { + return "" + } + + override fun getExtensions(): MutableList<WebSocketExtension> { + return mutableListOf() + } + + override fun getUri(): URI? { + return null + } + + override fun setBinaryMessageSizeLimit(messageSizeLimit: Int) { + } + + override fun getAttributes(): MutableMap<String, Any> { + return mutableMapOf() + } + + override fun getHandshakeHeaders(): HttpHeaders { + return HttpHeaders.EMPTY + } + + override fun isOpen(): Boolean { + return false + } + + override fun getPrincipal(): Principal? { + return null + } + + override fun close() { + } + + override fun close(status: CloseStatus) { + } + + override fun setTextMessageSizeLimit(messageSizeLimit: Int) { + } + + override fun getRemoteAddress(): InetSocketAddress? { + return null + } + +} diff --git a/kvision-modules/kvision-server-spring-boot/src/main/kotlin/pl/treksoft/kvision/remote/Security.kt b/kvision-modules/kvision-server-spring-boot/src/main/kotlin/pl/treksoft/kvision/remote/Security.kt index abfef934..287a3a7e 100644 --- a/kvision-modules/kvision-server-spring-boot/src/main/kotlin/pl/treksoft/kvision/remote/Security.kt +++ b/kvision-modules/kvision-server-spring-boot/src/main/kotlin/pl/treksoft/kvision/remote/Security.kt @@ -23,6 +23,9 @@ package pl.treksoft.kvision.remote import org.springframework.web.servlet.config.annotation.InterceptorRegistration +/** + * A function to gather paths for an interceptor from a list of service managers. + */ fun InterceptorRegistration.addPathPatternsFromServices(services: List<KVServiceManager<*>>) { val paths = services.flatMap { it.postRequests.keys + it.putRequests.keys + it.optionsRequests.keys + it.optionsRequests.keys diff --git a/kvision-modules/kvision-server-spring-boot/src/main/resources/META-INF/spring.factories b/kvision-modules/kvision-server-spring-boot/src/main/resources/META-INF/spring.factories index e31ff24d..17ca7a1d 100644 --- a/kvision-modules/kvision-server-spring-boot/src/main/resources/META-INF/spring.factories +++ b/kvision-modules/kvision-server-spring-boot/src/main/resources/META-INF/spring.factories @@ -1 +1 @@ -org.springframework.boot.autoconfigure.EnableAutoConfiguration=pl.treksoft.kvision.remote.KVController +org.springframework.boot.autoconfigure.EnableAutoConfiguration=pl.treksoft.kvision.remote.KVController,pl.treksoft.kvision.remote.KVWebSocketConfig |