Skip to content

Commit

Permalink
Merge pull request #1 from seniorjoinu/fix/concurrent-transmission
Browse files Browse the repository at this point in the history
Fix/concurrent transmission
  • Loading branch information
seniorjoinu authored Apr 25, 2019
2 parents dda29f5 + 1a936ba commit af341d1
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 39 deletions.
61 changes: 33 additions & 28 deletions src/main/kotlin/net/joinu/nioudp/NonBlockingUDPSocket.kt
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
package net.joinu.nioudp

import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.delay
import kotlinx.coroutines.asCoroutineDispatcher
import kotlinx.coroutines.launch
import kotlinx.coroutines.supervisorScope
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import mu.KotlinLogging
import java.net.InetSocketAddress
import java.nio.ByteBuffer
import java.nio.channels.DatagramChannel
import java.util.*
import kotlin.coroutines.CoroutineContext
import java.util.concurrent.Executors


interface NioSocket {
Expand All @@ -22,15 +22,17 @@ interface NioSocket {
suspend fun close()
}

class NonBlockingUDPSocket(
override val coroutineContext: CoroutineContext = Dispatchers.IO
) : NioSocket, CoroutineScope {
const val ALLOCATION_THRESHOLD_BYTES = 1400

class NonBlockingUDPSocket : NioSocket {

private val logger = KotlinLogging.logger("NonBlockingUDPSocket-${Random().nextInt()}")

lateinit var channel: DatagramChannel
private val channelMutex = Mutex()

private val listenDispatcher = Executors.newSingleThreadExecutor().asCoroutineDispatcher()

var onMessageHandler: NetworkMessageHandler? = null

private var state = SocketState.UNBOUND
Expand Down Expand Up @@ -91,35 +93,38 @@ class NonBlockingUDPSocket(
logger.trace { "Listening" }
state = SocketState.LISTENING

var count = 0
while (!isClosed()) {
count++
if (count == 10) {
count = 0
delay(1)
}
supervisorScope {
while (!isClosed()) {
val remoteAddress = channelMutex.withLock {
if (channel.isOpen) channel.receive(buf)
else null
}

val remoteAddress = channelMutex.withLock {
if (channel.isOpen) channel.receive(buf)
else null
}
if (buf.position() == 0) continue

val size = buf.position()

if (buf.position() == 0) continue
val size = buf.position()
buf.flip()

buf.flip()
// when allocating byte buffer follow the next rule: if data.size < ~1400 bytes - use on heap buffer, else - use off heap buffer
val data = if (size < ALLOCATION_THRESHOLD_BYTES)
ByteBuffer.allocate(size)
else
ByteBuffer.allocateDirect(size)

val data = ByteBuffer.allocateDirect(size)
data.put(buf)
data.flip()
data.put(buf)
data.flip()

buf.clear()
buf.clear()

val from = InetSocketAddress::class.java.cast(remoteAddress)
launch {
val from = InetSocketAddress::class.java.cast(remoteAddress)

logger.trace { "Received data packet from $from, invoking onMessage handler" }
logger.trace { "Received data packet from $from, invoking onMessage handler" }

onMessageHandler?.invoke(data, from)
onMessageHandler?.invoke(data, from)
}
}
}
}

Expand Down
10 changes: 7 additions & 3 deletions src/main/kotlin/net/joinu/rudp/ConfigurableRUDPSocket.kt
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class ConfigurableRUDPSocket(mtu: Int) {

// TODO: clean up acks eventually
val acks = ConcurrentHashMap<InetSocketAddress, ConcurrentSkipListSet<Long>>()
val received = ConcurrentHashMap<InetSocketAddress, ConcurrentSkipListSet<Long>>()

var onMessageHandler: NetworkMessageHandler? = null
val repairBlockSizeBytes = mtu - RepairBlock.METADATA_SIZE_BYTES
Expand Down Expand Up @@ -233,7 +234,7 @@ class ConfigurableRUDPSocket(mtu: Int) {

logger.trace { "Received REPAIR_BLOCK message for threadId: ${block.threadId} blockId: ${block.blockId} from: $from" }

if (ackReceived(from, block.threadId)) {
if (packetReceived(from, block.threadId)) {
logger.trace { "Received a repair block for already received threadId: ${block.threadId}, skipping..." }
sendACK(block.threadId, from)
return@onMessage
Expand All @@ -254,8 +255,7 @@ class ConfigurableRUDPSocket(mtu: Int) {
val message = ByteBuffer.allocateDirect(block.messageBytes)
decoder.recover(message as DirectBuffer, block.messageBytes)

// TODO: put ack somewhere else, because it's incorrect
putACK(from, block.threadId)
markPacketReceived(from, block.threadId)

decoder.close()
decoders[from]?.remove(block.threadId)
Expand Down Expand Up @@ -303,6 +303,10 @@ class ConfigurableRUDPSocket(mtu: Int) {
private fun putACK(from: InetSocketAddress, threadId: Long) =
acks.getOrPut(from) { ConcurrentSkipListSet() }.add(threadId)

private fun packetReceived(from: InetSocketAddress, threadId: Long) = received[from]?.contains(threadId) == true
private fun markPacketReceived(from: InetSocketAddress, threadId: Long) =
received.getOrPut(from) { ConcurrentSkipListSet() }.add(threadId)

private fun throwIfTransmissionTimeoutElapsed(trtBefore: Long, trtTimeoutMs: Long, threadId: Long) {
if (System.currentTimeMillis() - trtBefore > trtTimeoutMs)
throw TransmissionTimeoutException("Transmission timeout for threadId: $threadId elapsed")
Expand Down
35 changes: 27 additions & 8 deletions src/test/kotlin/net/joinu/RUDPSocketTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -43,40 +43,59 @@ class RUDPSocketTest {
launch(Dispatchers.IO) {
rudp1.listen()
}

launch(Dispatchers.IO) {
rudp2.listen()
}

val n = 100

var receive = 1
var receive1 = 1
rudp2.onMessage { buffer, from ->
receive++
receive1++
}

var sent = 1
var sent1 = 1
for (i in (0 until n)) {
launch {
launch(Dispatchers.IO) {
rudp1.send(
net1Content.toDirectByteBuffer(),
net2Addr,
fctTimeoutMsProvider = { 50 },
windowSizeProvider = { 1016 }
windowSizeProvider = { 1000 }
)
sent1++
}
}

var receive2 = 1
rudp1.onMessage { buffer, from ->
receive2++
}

var sent2 = 1
for (i in (0 until n)) {
launch(Dispatchers.IO) {
rudp2.send(
net1Content.toDirectByteBuffer(),
net1Addr,
fctTimeoutMsProvider = { 50 },
windowSizeProvider = { 1000 }
)
sent++
sent2++
}
}

while (true) {
if (sent >= n && receive >= n) {
if (sent1 >= n && receive1 >= n && sent2 >= n && receive2 >= n) {
delay(100)

rudp1.close()
rudp2.close()
break
}
delay(1)
println("$sent, $receive")
println("$sent1, $receive1, $sent2, $receive2")
}
}
}
Expand Down

0 comments on commit af341d1

Please sign in to comment.