From 093ba0ee2caf163f30b8a52a5760732a3e8e1ade Mon Sep 17 00:00:00 2001 From: ThibaultBee <37510686+ThibaultBee@users.noreply.github.com> Date: Mon, 16 Dec 2024 10:28:25 +0100 Subject: [PATCH] fix(ktx): close socket when connection failed to avoid a leak of thread --- .../srtdroid/ktx/CoroutineSrtSocket.kt | 68 ++++++++++++------- 1 file changed, 43 insertions(+), 25 deletions(-) diff --git a/srtdroid-ktx/src/main/java/io/github/thibaultbee/srtdroid/ktx/CoroutineSrtSocket.kt b/srtdroid-ktx/src/main/java/io/github/thibaultbee/srtdroid/ktx/CoroutineSrtSocket.kt index 93aa8d6..100e2a1 100644 --- a/srtdroid-ktx/src/main/java/io/github/thibaultbee/srtdroid/ktx/CoroutineSrtSocket.kt +++ b/srtdroid-ktx/src/main/java/io/github/thibaultbee/srtdroid/ktx/CoroutineSrtSocket.kt @@ -52,8 +52,6 @@ private constructor( ConfigurableSrtSocket, CoroutineScope { constructor() : this(SrtSocket()) - private var hasBeenConnected = false - val socketContext: CompletableJob = Job() override val coroutineContext = socketContext @@ -174,24 +172,24 @@ private constructor( val localPort: Int get() = socket.localPort + /** + * Internal listener to handle connection lost + */ + private val clientListener = object : SrtSocket.ClientListener { + override fun onConnectionLost( + ns: SrtSocket, + error: ErrorType, + peerAddress: InetSocketAddress, + token: Int + ) { + socket.close() + complete(ConnectException(error.toString())) + } + } + init { socket.setSockFlag(SockOpt.RCVSYN, false) socket.setSockFlag(SockOpt.SNDSYN, false) - - socket.clientListener = - object : SrtSocket.ClientListener { - override fun onConnectionLost( - ns: SrtSocket, - error: ErrorType, - peerAddress: InetSocketAddress, - token: Int - ) { - if (hasBeenConnected) { - socket.close() - complete(ConnectException(error.toString())) - } - } - } } private fun complete(t: Throwable? = null) { @@ -230,8 +228,14 @@ private constructor( * @throws BindException if bind has failed */ suspend fun bind(address: InetSocketAddress) = withContext(Dispatchers.IO) { - socket.bind(address) - hasBeenConnected = true + try { + socket.bind(address) + socket.clientListener = clientListener + } catch (t: Throwable) { + socket.close() + complete(t) + throw t + } } /** @@ -243,10 +247,16 @@ private constructor( * @throws ConnectException if connection has failed */ suspend fun connect(address: InetSocketAddress) { - execute(EpollOpt.OUT, onContinuation = { socket.connect(address) }) { - null + try { + execute(EpollOpt.OUT, onContinuation = { socket.connect(address) }) { + null + } + socket.clientListener = clientListener + } catch (t: Throwable) { + socket.close() + complete(t) + throw t } - hasBeenConnected = true } /** @@ -262,10 +272,18 @@ private constructor( localAddress: InetSocketAddress, remoteAddress: InetSocketAddress ) { - execute(EpollOpt.OUT, onContinuation = { socket.rendezVous(localAddress, remoteAddress) }) { - null + try { + execute( + EpollOpt.OUT, + onContinuation = { socket.rendezVous(localAddress, remoteAddress) }) { + null + } + socket.clientListener = clientListener + } catch (t: Throwable) { + socket.close() + complete(t) + throw t } - hasBeenConnected = true } /**