diff options
Diffstat (limited to 'kvision-modules/kvision-server-spring-boot/src/main/kotlin')
4 files changed, 313 insertions, 4 deletions
diff --git a/kvision-modules/kvision-server-spring-boot/src/main/kotlin/pl/treksoft/kvision/remote/KVController.kt b/kvision-modules/kvision-server-spring-boot/src/main/kotlin/pl/treksoft/kvision/remote/KVController.kt index 13b95b8b..23e00284 100644 --- a/kvision-modules/kvision-server-spring-boot/src/main/kotlin/pl/treksoft/kvision/remote/KVController.kt +++ b/kvision-modules/kvision-server-spring-boot/src/main/kotlin/pl/treksoft/kvision/remote/KVController.kt @@ -29,6 +29,9 @@ import org.springframework.web.bind.annotation.RequestMethod import javax.servlet.http.HttpServletRequest import javax.servlet.http.HttpServletResponse +/** + * Controller for handling automatic routes. + */ @Controller open class KVController { @@ -44,7 +47,7 @@ open class KVController { ) open fun kVMapping(req: HttpServletRequest, res: HttpServletResponse) { val routeUrl = req.requestURI - val route = services.mapNotNull { + val handler = services.mapNotNull { when (req.method) { "GET" -> it.getRequests[routeUrl] "POST" -> it.postRequests[routeUrl] @@ -54,8 +57,8 @@ open class KVController { else -> null } }.firstOrNull() - if (route != null) { - route.invoke(req, res, applicationContext) + if (handler != null) { + handler.invoke(req, res, applicationContext) } else { res.status = HttpServletResponse.SC_NOT_FOUND } diff --git a/kvision-modules/kvision-server-spring-boot/src/main/kotlin/pl/treksoft/kvision/remote/KVServiceManager.kt b/kvision-modules/kvision-server-spring-boot/src/main/kotlin/pl/treksoft/kvision/remote/KVServiceManager.kt index 66429acf..0606afef 100644 --- a/kvision-modules/kvision-server-spring-boot/src/main/kotlin/pl/treksoft/kvision/remote/KVServiceManager.kt +++ b/kvision-modules/kvision-server-spring-boot/src/main/kotlin/pl/treksoft/kvision/remote/KVServiceManager.kt @@ -23,14 +23,21 @@ package pl.treksoft.kvision.remote import com.fasterxml.jackson.module.kotlin.jacksonObjectMapper import kotlinx.coroutines.CoroutineStart +import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.ExperimentalCoroutinesApi import kotlinx.coroutines.GlobalScope +import kotlinx.coroutines.channels.Channel import kotlinx.coroutines.channels.ReceiveChannel import kotlinx.coroutines.channels.SendChannel +import kotlinx.coroutines.channels.filterNotNull +import kotlinx.coroutines.channels.map +import kotlinx.coroutines.coroutineScope import kotlinx.coroutines.launch import org.slf4j.Logger import org.slf4j.LoggerFactory import org.springframework.context.ApplicationContext +import org.springframework.web.context.support.GenericWebApplicationContext +import org.springframework.web.socket.WebSocketSession import javax.servlet.http.HttpServletRequest import javax.servlet.http.HttpServletResponse import kotlin.reflect.KClass @@ -55,6 +62,10 @@ actual open class KVServiceManager<T : Any> actual constructor(val serviceClass: mutableMapOf() val optionsRequests: MutableMap<String, (HttpServletRequest, HttpServletResponse, ApplicationContext) -> Unit> = mutableMapOf() + val webSocketsRequests: MutableMap<String, suspend ( + WebSocketSession, GenericWebApplicationContext, ReceiveChannel<String>, SendChannel<String> + ) -> Unit> = + mutableMapOf() val mapper = jacksonObjectMapper() var counter: Int = 0 @@ -402,7 +413,39 @@ actual open class KVServiceManager<T : Any> actual constructor(val serviceClass: noinline function: suspend T.(ReceiveChannel<PAR1>, SendChannel<PAR2>) -> Unit, route: String? ) { - TODO("Not implemented in Spring Boot module") + val routeDef = "route${this::class.simpleName}${counter++}" + webSocketsRequests[routeDef] = { webSocketSession, ctx, incoming, outgoing -> + val service = synchronized(this) { + WebSocketSessionHolder.webSocketSession = webSocketSession + ctx.getBean(serviceClass.java) + } + val requestChannel = incoming.map { + val jsonRpcRequest = getParameter<JsonRpcRequest>(it) + if (jsonRpcRequest.params.size == 1) { + getParameter<PAR1>(jsonRpcRequest.params[0]) + } else { + null + } + }.filterNotNull() + val responseChannel = Channel<PAR2>() + coroutineScope { + launch { + for (p in responseChannel) { + val text = mapper.writeValueAsString( + JsonRpcResponse( + id = 0, + result = mapper.writeValueAsString(p) + ) + ) + outgoing.send(text) + } + } + launch(start = CoroutineStart.UNDISPATCHED) { + function.invoke(service, requestChannel, responseChannel) + if (!responseChannel.isClosedForReceive) responseChannel.close() + } + } + } } /** @@ -454,6 +497,9 @@ actual open class KVServiceManager<T : Any> actual constructor(val serviceClass: } } + /** + * @suppress internal function + */ fun addRoute( method: HttpMethod, path: String, @@ -468,6 +514,9 @@ actual open class KVServiceManager<T : Any> actual constructor(val serviceClass: } } + /** + * @suppress internal function + */ protected inline fun <reified T> getParameter(str: String?): T { return str?.let { if (T::class == String::class) { @@ -479,6 +528,9 @@ actual open class KVServiceManager<T : Any> actual constructor(val serviceClass: } } +/** + * @suppress internal function + */ fun HttpServletResponse.writeJSON(json: String) { val out = this.outputStream this.contentType = "application/json" diff --git a/kvision-modules/kvision-server-spring-boot/src/main/kotlin/pl/treksoft/kvision/remote/KVWebSocketConfig.kt b/kvision-modules/kvision-server-spring-boot/src/main/kotlin/pl/treksoft/kvision/remote/KVWebSocketConfig.kt new file mode 100644 index 00000000..7bc2e612 --- /dev/null +++ b/kvision-modules/kvision-server-spring-boot/src/main/kotlin/pl/treksoft/kvision/remote/KVWebSocketConfig.kt @@ -0,0 +1,251 @@ +/* + * Copyright (c) 2017-present Robert Jaros + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +package pl.treksoft.kvision.remote + +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.GlobalScope +import kotlinx.coroutines.channels.Channel +import kotlinx.coroutines.channels.ReceiveChannel +import kotlinx.coroutines.channels.SendChannel +import kotlinx.coroutines.coroutineScope +import kotlinx.coroutines.launch +import org.springframework.beans.factory.annotation.Autowired +import org.springframework.beans.factory.config.ConfigurableBeanFactory +import org.springframework.context.annotation.Bean +import org.springframework.context.annotation.Configuration +import org.springframework.context.annotation.Scope +import org.springframework.http.HttpHeaders +import org.springframework.http.server.ServerHttpRequest +import org.springframework.http.server.ServerHttpResponse +import org.springframework.web.context.support.GenericWebApplicationContext +import org.springframework.web.socket.CloseStatus +import org.springframework.web.socket.TextMessage +import org.springframework.web.socket.WebSocketExtension +import org.springframework.web.socket.WebSocketHandler +import org.springframework.web.socket.WebSocketMessage +import org.springframework.web.socket.WebSocketSession +import org.springframework.web.socket.config.annotation.EnableWebSocket +import org.springframework.web.socket.config.annotation.WebSocketConfigurer +import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry +import org.springframework.web.socket.handler.TextWebSocketHandler +import org.springframework.web.socket.server.HandshakeInterceptor +import java.net.InetSocketAddress +import java.net.URI +import java.security.Principal +import java.util.concurrent.ConcurrentHashMap + +const val KV_ROUTE_ID_ATTRIBUTE = "KV_ROUTE_ID_ATTRIBUTE" + +/** + * Automatic websockets configuration. + */ +@Configuration +@EnableWebSocket +open class KVWebSocketConfig : WebSocketConfigurer { + + @Autowired + lateinit var services: List<KVServiceManager<*>> + + @Autowired + lateinit var applicationContext: GenericWebApplicationContext + + override fun registerWebSocketHandlers(registry: WebSocketHandlerRegistry) { + registry.addHandler(socketHandler(), "/kvws/*").setAllowedOrigins("*").addInterceptors(routeInterceptor()) + } + + @Bean + open fun routeInterceptor(): HandshakeInterceptor { + return KvHandshakeInterceptor() + } + + @Bean + open fun socketHandler(): WebSocketHandler { + return KvWebSocketHandler(services, applicationContext) + } + + @Bean + @Scope(value = ConfigurableBeanFactory.SCOPE_PROTOTYPE) + open fun webSocketSession(): WebSocketSession { + return WebSocketSessionHolder.webSocketSession + } +} + +object WebSocketSessionHolder { + var webSocketSession: WebSocketSession = DummyWebSocketSession() +} + +internal open class KvHandshakeInterceptor : HandshakeInterceptor { + override fun beforeHandshake( + request: ServerHttpRequest, + response: ServerHttpResponse, + wsHandler: WebSocketHandler, + attributes: MutableMap<String, Any> + ): Boolean { + val path = request.uri.path + val route = path.substring(path.lastIndexOf('/') + 1) + attributes[KV_ROUTE_ID_ATTRIBUTE] = route + return true + } + + override fun afterHandshake( + request: ServerHttpRequest, + response: ServerHttpResponse, + wsHandler: WebSocketHandler, + exception: Exception? + ) { + } +} + +@UseExperimental(ExperimentalCoroutinesApi::class) +internal open class KvWebSocketHandler( + private val services: List<KVServiceManager<*>>, + private val applicationContext: GenericWebApplicationContext +) : TextWebSocketHandler() { + + private val sessions = ConcurrentHashMap<String, Pair<Channel<String>, Channel<String>>>() + + private fun getHandler(session: WebSocketSession): (suspend ( + WebSocketSession, GenericWebApplicationContext, + ReceiveChannel<String>, SendChannel<String> + ) -> Unit)? { + val routeId = session.attributes[KV_ROUTE_ID_ATTRIBUTE] as String + return services.mapNotNull { + it.webSocketsRequests[routeId] + }.firstOrNull() + } + + private fun getSessionId(session: WebSocketSession): String { + val routeId = session.attributes[KV_ROUTE_ID_ATTRIBUTE] as String + return session.id + "###" + routeId + } + + override fun afterConnectionEstablished(session: WebSocketSession) { + getHandler(session)?.let { handler -> + val requestChannel = Channel<String>() + val responseChannel = Channel<String>() + GlobalScope.launch { + coroutineScope { + launch(Dispatchers.IO) { + for (text in responseChannel) { + session.sendMessage(TextMessage(text)) + } + session.close() + } + launch { + handler.invoke(session, applicationContext, requestChannel, responseChannel) + if (!responseChannel.isClosedForReceive) responseChannel.close() + } + sessions[getSessionId(session)] = responseChannel to requestChannel + } + } + } + } + + override fun handleTextMessage(session: WebSocketSession, message: TextMessage) { + getHandler(session)?.let { + sessions[getSessionId(session)]?.let { (_, requestChannel) -> + GlobalScope.launch { + requestChannel.send(message.payload) + } + } + } + } + + override fun afterConnectionClosed(session: WebSocketSession, status: CloseStatus) { + getHandler(session)?.let { + sessions[getSessionId(session)]?.let { (responseChannel, requestChannel) -> + GlobalScope.launch { + responseChannel.close() + requestChannel.close() + } + sessions.remove(getSessionId(session)) + } + } + } +} + +open class DummyWebSocketSession : WebSocketSession { + override fun getBinaryMessageSizeLimit(): Int { + return 0 + } + + override fun sendMessage(message: WebSocketMessage<*>) { + } + + override fun getAcceptedProtocol(): String? { + return null + } + + override fun getTextMessageSizeLimit(): Int { + return 0 + } + + override fun getLocalAddress(): InetSocketAddress? { + return null + } + + override fun getId(): String { + return "" + } + + override fun getExtensions(): MutableList<WebSocketExtension> { + return mutableListOf() + } + + override fun getUri(): URI? { + return null + } + + override fun setBinaryMessageSizeLimit(messageSizeLimit: Int) { + } + + override fun getAttributes(): MutableMap<String, Any> { + return mutableMapOf() + } + + override fun getHandshakeHeaders(): HttpHeaders { + return HttpHeaders.EMPTY + } + + override fun isOpen(): Boolean { + return false + } + + override fun getPrincipal(): Principal? { + return null + } + + override fun close() { + } + + override fun close(status: CloseStatus) { + } + + override fun setTextMessageSizeLimit(messageSizeLimit: Int) { + } + + override fun getRemoteAddress(): InetSocketAddress? { + return null + } + +} diff --git a/kvision-modules/kvision-server-spring-boot/src/main/kotlin/pl/treksoft/kvision/remote/Security.kt b/kvision-modules/kvision-server-spring-boot/src/main/kotlin/pl/treksoft/kvision/remote/Security.kt index abfef934..287a3a7e 100644 --- a/kvision-modules/kvision-server-spring-boot/src/main/kotlin/pl/treksoft/kvision/remote/Security.kt +++ b/kvision-modules/kvision-server-spring-boot/src/main/kotlin/pl/treksoft/kvision/remote/Security.kt @@ -23,6 +23,9 @@ package pl.treksoft.kvision.remote import org.springframework.web.servlet.config.annotation.InterceptorRegistration +/** + * A function to gather paths for an interceptor from a list of service managers. + */ fun InterceptorRegistration.addPathPatternsFromServices(services: List<KVServiceManager<*>>) { val paths = services.flatMap { it.postRequests.keys + it.putRequests.keys + it.optionsRequests.keys + it.optionsRequests.keys |