diff --git a/rsocket-transports/nodejs-tcp/src/jsMain/kotlin/io/rsocket/kotlin/transport/nodejs/tcp/FrameWithLengthAssembler.kt b/rsocket-transports/nodejs-tcp/src/jsMain/kotlin/io/rsocket/kotlin/transport/nodejs/tcp/FrameWithLengthAssembler.kt index b141081e..d91c45cc 100644 --- a/rsocket-transports/nodejs-tcp/src/jsMain/kotlin/io/rsocket/kotlin/transport/nodejs/tcp/FrameWithLengthAssembler.kt +++ b/rsocket-transports/nodejs-tcp/src/jsMain/kotlin/io/rsocket/kotlin/transport/nodejs/tcp/FrameWithLengthAssembler.kt @@ -24,10 +24,18 @@ internal fun ByteReadPacket.withLength(): ByteReadPacket = buildPacket { writePacket(this@withLength) } -internal class FrameWithLengthAssembler(private val onFrame: (frame: ByteReadPacket) -> Unit) { - private var expectedFrameLength = 0 //TODO atomic for native +internal class FrameWithLengthAssembler(private val onFrame: (frame: ByteReadPacket) -> Unit) : Closeable { + private var closed = false + private var expectedFrameLength = 0 private val packetBuilder: BytePacketBuilder = BytePacketBuilder() + + override fun close() { + packetBuilder.close() + closed = true + } + inline fun write(write: BytePacketBuilder.() -> Unit) { + if (closed) return packetBuilder.write() loop() } @@ -39,6 +47,7 @@ internal class FrameWithLengthAssembler(private val onFrame: (frame: ByteReadPac expectedFrameLength = it.readInt24() if (it.remaining >= expectedFrameLength) build(it) // if has length and frame } + packetBuilder.size < expectedFrameLength -> return // not enough bytes to read frame else -> withTemp { build(it) } // enough bytes to read frame } diff --git a/rsocket-transports/nodejs-tcp/src/jsMain/kotlin/io/rsocket/kotlin/transport/nodejs/tcp/NodejsTcpClientTransport.kt b/rsocket-transports/nodejs-tcp/src/jsMain/kotlin/io/rsocket/kotlin/transport/nodejs/tcp/NodejsTcpClientTransport.kt new file mode 100644 index 00000000..1fa43b41 --- /dev/null +++ b/rsocket-transports/nodejs-tcp/src/jsMain/kotlin/io/rsocket/kotlin/transport/nodejs/tcp/NodejsTcpClientTransport.kt @@ -0,0 +1,71 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.transport.nodejs.tcp + +import io.rsocket.kotlin.internal.io.* +import io.rsocket.kotlin.transport.* +import io.rsocket.kotlin.transport.nodejs.tcp.internal.* +import kotlinx.coroutines.* +import kotlin.coroutines.* + +public sealed interface NodejsTcpClientTransport : RSocketTransport { + public fun target(host: String, port: Int): RSocketClientTarget + + public companion object Factory : + RSocketTransportFactory(::NodejsTcpClientTransportBuilderImpl) +} + +public sealed interface NodejsTcpClientTransportBuilder : RSocketTransportBuilder { + public fun dispatcher(context: CoroutineContext) + public fun inheritDispatcher(): Unit = dispatcher(EmptyCoroutineContext) +} + +private class NodejsTcpClientTransportBuilderImpl : NodejsTcpClientTransportBuilder { + private var dispatcher: CoroutineContext = Dispatchers.Default + + override fun dispatcher(context: CoroutineContext) { + check(context[Job] == null) { "Dispatcher shouldn't contain job" } + this.dispatcher = context + } + + @RSocketTransportApi + override fun buildTransport(context: CoroutineContext): NodejsTcpClientTransport = NodejsTcpClientTransportImpl( + coroutineContext = context.supervisorContext() + dispatcher, + ) +} + +private class NodejsTcpClientTransportImpl( + override val coroutineContext: CoroutineContext, +) : NodejsTcpClientTransport { + override fun target(host: String, port: Int): RSocketClientTarget = NodejsTcpClientTargetImpl( + coroutineContext = coroutineContext.supervisorContext(), + host = host, + port = port + ) +} + +private class NodejsTcpClientTargetImpl( + override val coroutineContext: CoroutineContext, + private val host: String, + private val port: Int, +) : RSocketClientTarget { + @RSocketTransportApi + override fun connectClient(handler: RSocketConnectionHandler): Job = launch { + val socket = connect(port, host) + handler.handleNodejsTcpConnection(socket) + } +} diff --git a/rsocket-transports/nodejs-tcp/src/jsMain/kotlin/io/rsocket/kotlin/transport/nodejs/tcp/NodejsTcpConnection.kt b/rsocket-transports/nodejs-tcp/src/jsMain/kotlin/io/rsocket/kotlin/transport/nodejs/tcp/NodejsTcpConnection.kt new file mode 100644 index 00000000..7846f0f4 --- /dev/null +++ b/rsocket-transports/nodejs-tcp/src/jsMain/kotlin/io/rsocket/kotlin/transport/nodejs/tcp/NodejsTcpConnection.kt @@ -0,0 +1,85 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.transport.nodejs.tcp + +import io.ktor.utils.io.core.* +import io.ktor.utils.io.js.* +import io.rsocket.kotlin.internal.io.* +import io.rsocket.kotlin.transport.* +import io.rsocket.kotlin.transport.internal.* +import io.rsocket.kotlin.transport.nodejs.tcp.internal.* +import kotlinx.coroutines.* +import kotlinx.coroutines.channels.* +import org.khronos.webgl.* + +@RSocketTransportApi +internal suspend fun RSocketConnectionHandler.handleNodejsTcpConnection(socket: Socket): Unit = coroutineScope { + val outboundQueue = PrioritizationFrameQueue(Channel.BUFFERED) + val inbound = channelForCloseable(Channel.UNLIMITED) + + val closed = CompletableDeferred() + val frameAssembler = FrameWithLengthAssembler { inbound.trySend(it) } + socket.on( + onData = { frameAssembler.write { writeFully(it.buffer) } }, + onError = { closed.completeExceptionally(it) }, + onClose = { + frameAssembler.close() + if (!it) closed.complete(Unit) + } + ) + + val writerJob = launch { + while (true) socket.writeFrame(outboundQueue.dequeueFrame() ?: break) + }.onCompletion { outboundQueue.cancel() } + + try { + handleConnection(NodejsTcpConnection(outboundQueue, inbound)) + } finally { + inbound.cancel() + outboundQueue.close() // will cause `writerJob` completion + // even if it was cancelled, we still need to close socket and await it closure + withContext(NonCancellable) { + writerJob.join() + // close socket + socket.destroy() + closed.join() + } + } +} + +@RSocketTransportApi +private class NodejsTcpConnection( + private val outboundQueue: PrioritizationFrameQueue, + private val inbound: ReceiveChannel, +) : RSocketSequentialConnection { + override val isClosedForSend: Boolean get() = outboundQueue.isClosedForSend + override suspend fun sendFrame(streamId: Int, frame: ByteReadPacket) { + return outboundQueue.enqueueFrame(streamId, frame) + } + + override suspend fun receiveFrame(): ByteReadPacket? { + return inbound.receiveCatching().getOrNull() + } +} + +private fun Socket.writeFrame(frame: ByteReadPacket) { + val packet = buildPacket { + writeInt24(frame.remaining.toInt()) + writePacket(frame) + } + write(Uint8Array(packet.readArrayBuffer())) +} diff --git a/rsocket-transports/nodejs-tcp/src/jsMain/kotlin/io/rsocket/kotlin/transport/nodejs/tcp/NodejsTcpServerTransport.kt b/rsocket-transports/nodejs-tcp/src/jsMain/kotlin/io/rsocket/kotlin/transport/nodejs/tcp/NodejsTcpServerTransport.kt new file mode 100644 index 00000000..a89b2f1a --- /dev/null +++ b/rsocket-transports/nodejs-tcp/src/jsMain/kotlin/io/rsocket/kotlin/transport/nodejs/tcp/NodejsTcpServerTransport.kt @@ -0,0 +1,104 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.transport.nodejs.tcp + +import io.rsocket.kotlin.internal.io.* +import io.rsocket.kotlin.transport.* +import io.rsocket.kotlin.transport.nodejs.tcp.internal.* +import kotlinx.coroutines.* +import kotlin.coroutines.* + +public sealed interface NodejsTcpServerInstance : RSocketServerInstance { + public val host: String + public val port: Int +} + +public sealed interface NodejsTcpServerTransport : RSocketTransport { + public fun target(host: String, port: Int): RSocketServerTarget + + public companion object Factory : + RSocketTransportFactory({ NodejsTcpServerTransportBuilderImpl }) +} + +public sealed interface NodejsTcpServerTransportBuilder : RSocketTransportBuilder { + public fun dispatcher(context: CoroutineContext) + public fun inheritDispatcher(): Unit = dispatcher(EmptyCoroutineContext) +} + +private object NodejsTcpServerTransportBuilderImpl : NodejsTcpServerTransportBuilder { + private var dispatcher: CoroutineContext = Dispatchers.Default + + override fun dispatcher(context: CoroutineContext) { + check(context[Job] == null) { "Dispatcher shouldn't contain job" } + this.dispatcher = context + } + + @RSocketTransportApi + override fun buildTransport(context: CoroutineContext): NodejsTcpServerTransport = NodejsTcpServerTransportImpl( + coroutineContext = context.supervisorContext() + dispatcher, + ) +} + +private class NodejsTcpServerTransportImpl( + override val coroutineContext: CoroutineContext, +) : NodejsTcpServerTransport { + override fun target(host: String, port: Int): RSocketServerTarget = NodejsTcpServerTargetImpl( + coroutineContext = coroutineContext.supervisorContext(), + host = host, + port = port + ) +} + +private class NodejsTcpServerTargetImpl( + override val coroutineContext: CoroutineContext, + private val host: String, + private val port: Int, +) : RSocketServerTarget { + + @RSocketTransportApi + override suspend fun startServer(handler: RSocketConnectionHandler): NodejsTcpServerInstance { + currentCoroutineContext().ensureActive() + coroutineContext.ensureActive() + + val serverJob = launch { + val handlerScope = CoroutineScope(coroutineContext.supervisorContext()) + val server = createServer(port, host, { + coroutineContext.job.cancel("Server closed") + }) { + handlerScope.launch { handler.handleNodejsTcpConnection(it) } + } + try { + awaitCancellation() + } finally { + suspendCoroutine { cont -> server.close { cont.resume(Unit) } } + } + } + + return NodejsTcpServerInstanceImpl( + coroutineContext = coroutineContext + serverJob, + host = host, + port = port + ) + } +} + +@RSocketTransportApi +private class NodejsTcpServerInstanceImpl( + override val coroutineContext: CoroutineContext, + override val host: String, + override val port: Int, +) : NodejsTcpServerInstance diff --git a/rsocket-transports/nodejs-tcp/src/jsTest/kotlin/io/rsocket/kotlin/transport/nodejs/tcp/TcpTransportTest.kt b/rsocket-transports/nodejs-tcp/src/jsTest/kotlin/io/rsocket/kotlin/transport/nodejs/tcp/TcpTransportTest.kt index 0fd9a497..c2fc9163 100644 --- a/rsocket-transports/nodejs-tcp/src/jsTest/kotlin/io/rsocket/kotlin/transport/nodejs/tcp/TcpTransportTest.kt +++ b/rsocket-transports/nodejs-tcp/src/jsTest/kotlin/io/rsocket/kotlin/transport/nodejs/tcp/TcpTransportTest.kt @@ -34,3 +34,11 @@ class TcpTransportTest : TransportTest() { server.close() } } + +class NodejsTcpTransportTest : TransportTest() { + override suspend fun before() { + val port = PortProvider.next() + startServer(NodejsTcpServerTransport(testContext).target("127.0.0.1", port)) + client = connectClient(NodejsTcpClientTransport(testContext).target("127.0.0.1", port)) + } +}