Skip to content

Commit

Permalink
fix(ktx): close socket when connection failed to avoid a leak of thread
Browse files Browse the repository at this point in the history
  • Loading branch information
ThibaultBee committed Dec 16, 2024
1 parent 252c110 commit 093ba0e
Showing 1 changed file with 43 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,6 @@ private constructor(
ConfigurableSrtSocket, CoroutineScope {
constructor() : this(SrtSocket())

private var hasBeenConnected = false

val socketContext: CompletableJob = Job()

override val coroutineContext = socketContext
Expand Down Expand Up @@ -174,24 +172,24 @@ private constructor(
val localPort: Int
get() = socket.localPort

/**
* Internal listener to handle connection lost
*/
private val clientListener = object : SrtSocket.ClientListener {
override fun onConnectionLost(
ns: SrtSocket,
error: ErrorType,
peerAddress: InetSocketAddress,
token: Int
) {
socket.close()
complete(ConnectException(error.toString()))
}
}

init {
socket.setSockFlag(SockOpt.RCVSYN, false)
socket.setSockFlag(SockOpt.SNDSYN, false)

socket.clientListener =
object : SrtSocket.ClientListener {
override fun onConnectionLost(
ns: SrtSocket,
error: ErrorType,
peerAddress: InetSocketAddress,
token: Int
) {
if (hasBeenConnected) {
socket.close()
complete(ConnectException(error.toString()))
}
}
}
}

private fun complete(t: Throwable? = null) {
Expand Down Expand Up @@ -230,8 +228,14 @@ private constructor(
* @throws BindException if bind has failed
*/
suspend fun bind(address: InetSocketAddress) = withContext(Dispatchers.IO) {
socket.bind(address)
hasBeenConnected = true
try {
socket.bind(address)
socket.clientListener = clientListener
} catch (t: Throwable) {
socket.close()
complete(t)
throw t
}
}

/**
Expand All @@ -243,10 +247,16 @@ private constructor(
* @throws ConnectException if connection has failed
*/
suspend fun connect(address: InetSocketAddress) {
execute(EpollOpt.OUT, onContinuation = { socket.connect(address) }) {
null
try {
execute(EpollOpt.OUT, onContinuation = { socket.connect(address) }) {
null
}
socket.clientListener = clientListener
} catch (t: Throwable) {
socket.close()
complete(t)
throw t
}
hasBeenConnected = true
}

/**
Expand All @@ -262,10 +272,18 @@ private constructor(
localAddress: InetSocketAddress,
remoteAddress: InetSocketAddress
) {
execute(EpollOpt.OUT, onContinuation = { socket.rendezVous(localAddress, remoteAddress) }) {
null
try {
execute(
EpollOpt.OUT,
onContinuation = { socket.rendezVous(localAddress, remoteAddress) }) {
null
}
socket.clientListener = clientListener
} catch (t: Throwable) {
socket.close()
complete(t)
throw t
}
hasBeenConnected = true
}

/**
Expand Down

0 comments on commit 093ba0e

Please sign in to comment.