From e854d05e47616bb59469063bed88f25ea189d621 Mon Sep 17 00:00:00 2001
From: lukellmann
Date: Mon, 29 Jan 2024 22:50:00 +0100
Subject: [PATCH 1/4] Simplify IdentifyRateLimiterImpl

Instead of a complicated channel structure, simply use a Mutex for each
rateLimitKey as a way to queue the consume calls for the same key.
---
 .../kotlin/ratelimit/IdentifyRateLimiter.kt   | 235 +++++-------------
 1 file changed, 60 insertions(+), 175 deletions(-)

diff --git a/gateway/src/commonMain/kotlin/ratelimit/IdentifyRateLimiter.kt b/gateway/src/commonMain/kotlin/ratelimit/IdentifyRateLimiter.kt
index d9228a1a47d5..c75444ef8df7 100644
--- a/gateway/src/commonMain/kotlin/ratelimit/IdentifyRateLimiter.kt
+++ b/gateway/src/commonMain/kotlin/ratelimit/IdentifyRateLimiter.kt
@@ -2,17 +2,13 @@ package dev.kord.gateway.ratelimit
 
 import dev.kord.gateway.*
 import io.github.oshai.kotlinlogging.KotlinLogging
-import kotlinx.atomicfu.atomic
-import kotlinx.atomicfu.getAndUpdate
-import kotlinx.atomicfu.loop
-import kotlinx.atomicfu.update
 import kotlinx.coroutines.*
-import kotlinx.coroutines.channels.Channel
-import kotlinx.coroutines.channels.onSuccess
-import kotlinx.coroutines.flow.*
-import kotlinx.coroutines.selects.onTimeout
-import kotlinx.coroutines.selects.select
-import kotlin.jvm.JvmField
+import kotlinx.coroutines.flow.SharedFlow
+import kotlinx.coroutines.flow.first
+import kotlinx.coroutines.flow.onSubscription
+import kotlinx.coroutines.sync.Mutex
+import kotlinx.coroutines.sync.withLock
+import kotlin.coroutines.resume
 import kotlin.time.Duration.Companion.seconds
 
 /**
@@ -61,197 +57,86 @@ private class IdentifyRateLimiterImpl(
     private val dispatcher: CoroutineDispatcher,
 ) : IdentifyRateLimiter {
 
-    private class IdentifyRequest(
-        @JvmField val shardId: Int,
-        @JvmField val events: SharedFlow<Event>,
-        private val permission: CompletableDeferred<Unit>,
-    ) {
-        fun allow() = permission.complete(Unit)
-    }
-
-    private val IdentifyRequest.rateLimitKey get() = shardId % maxConcurrency
-
-    // doesn't need onUndeliveredElement for rejecting requests, we don't close or cancel the channel, receiving can't
-    // be cancelled (running in GlobalScope), only send can be cancelled, which is ok because permission isn't needed in
-    // that case
-    private val channel = Channel<IdentifyRequest>(capacity = maxConcurrency)
-
-    /**
-     * Can be
-     * - [NOT_RUNNING]: no coroutine ([launchRateLimiterCoroutine]) is running, or it is about to stop and will no
-     *   longer process [IdentifyRequest]s
-     * - [RUNNING_WITH_NO_CONSUMERS]: the coroutine is running and ready to process [IdentifyRequest]s but there are no
-     *   consumers that sent a request
-     * - unsigned number of consumers: number of concurrent [consume] invocations that wait for their [IdentifyRequest]
-     *   to be processed (min [ONE_CONSUMER], max [MAX_CONSUMERS])
-     */
-    private val state = atomic(initial = NOT_RUNNING)
-
-    private fun getOldStateAndIncrementConsumers() = state.getAndUpdate { current ->
-        when (current) {
-            NOT_RUNNING, RUNNING_WITH_NO_CONSUMERS -> ONE_CONSUMER // we are the first consumer
-            MAX_CONSUMERS -> error(
-                "Too many concurrent identify attempts, overflow happened. There are already ${current.toUInt()} " +
-                    "other consume() invocations waiting. This is most likely a bug."
-            )
-            else -> current + 1 // increment number of consumers
-        }
-    }
-
-    private fun decrementConsumers() = state.update { current ->
-        when (current) {
-            NOT_RUNNING -> error("Should be running but was NOT_RUNNING")
-            RUNNING_WITH_NO_CONSUMERS -> error("Should have consumers but was RUNNING_WITH_NO_CONSUMERS")
-            ONE_CONSUMER -> RUNNING_WITH_NO_CONSUMERS // we were the last consumer
-            else -> current - 1 // decrement number of consumers
-        }
-    }
-
-    private fun stopIfHasNoConsumers(): Boolean = state.loop { current ->
-        when (current) {
-            NOT_RUNNING -> error("Should be running but was NOT_RUNNING")
-            RUNNING_WITH_NO_CONSUMERS -> // no new requests in sight -> try to stop
-                if (state.compareAndSet(expect = current, update = NOT_RUNNING)) return true
-            else -> return false // don't change number of consumers
+    // These mutexes are fair (according to the documentation of Mutex and its factory function), we rely on this to
+    // guarantee a fair rate limiter behavior.
+    // rateLimitKey is always in the range 0..<maxConcurrency
+    private val mutexesByRateLimitKey = List(size = maxConcurrency) { Mutex() }
+
+    // the scope is never cancelled -> no resource cleanup needed
+    // SupervisorJob is used to ensure an unexpected failure in one coroutine doesn't leave the rate limiter unusable
+    private val scope = CoroutineScope(
+        context = SupervisorJob() + dispatcher + CoroutineExceptionHandler { context, exception ->
+            // we can't be cancelled, so all exceptions are bugs
+            logger.error(exception) {
+                "IdentifyRateLimiter threw an exception in context $context, please report this, it should not happen"
+            }
         }
-    }
+    )
 
     override suspend fun consume(shardId: Int, events: SharedFlow<Event>) {
         require(shardId >= 0) { "shardId must be non-negative but was $shardId" }
 
-        val oldState = getOldStateAndIncrementConsumers()
-        if (oldState == NOT_RUNNING) launchRateLimiterCoroutine() // if this throws we are screwed anyway
-
-        // this might throw because of cancellation of the coroutine that called consume(), which is ok
-        try {
-            val permission = CompletableDeferred<Unit>()
-            channel.send(IdentifyRequest(shardId, events, permission))
-            permission.await()
-        } finally {
-            decrementConsumers()
+        // if the coroutine that called consume() is cancelled, the CancellableContinuation makes sure the waiting is
+        // stopped (the Gateway won't try to identify), so we don't need to hold the mutex and waste time for other
+        // calls
+        suspendCancellableCoroutine { continuation ->
+            val job = launchIdentifyWaiter(shardId, events, continuation)
+            continuation.invokeOnCancellation { job.cancel() }
         }
     }
 
-    private fun launchRateLimiterCoroutine() {
-        // GlobalScope is ok here:
-        // - only one coroutine is launched at a time:
-        //   previously running one will allow next one to start by setting state to NOT_RUNNING just before exiting
-        // - the coroutine will time out eventually when no identify requests are sent, a new one will be launched on
-        //   demand which will then time out again etc.
-        // => no leaks
-        @OptIn(DelicateCoroutinesApi::class)
-        GlobalScope.launch(context = dispatcher + ExceptionLogger) {
-
-            // only read/written from sequential loop, not from launched concurrent coroutines
-            // request.rateLimitKey is always in the range 0..<maxConcurrency
-            val identifyWaiters = arrayOfNulls<Job>(size = maxConcurrency)
-
-            while (true) {
-                val batch = receiveSortedBatchOfRequestsOrNull() ?: when {
-                    identifyWaiters.any { it != null && !it.isCompleted } -> continue // keep receiving while waiting
-                    stopIfHasNoConsumers() -> return@launch // no consumers and no waiters -> stop (only exit point)
-                    else -> continue // has consumers -> there will be new requests soon
-                }
-
-                for (request in batch) {
-                    val key = request.rateLimitKey
-                    val previousWaiter = identifyWaiters[key]
-                    identifyWaiters[key] = launchIdentifyWaiter(previousWaiter, request)
-                }
-            }
-        }
-    }
-
-    /** Returns `null` on [RECEIVE_TIMEOUT]. */
-    private suspend fun receiveSortedBatchOfRequestsOrNull(): List<IdentifyRequest>? {
-
-        // first receive suspends until we get a request or time out
-        // using select instead of withTimeoutOrNull here, so we can't remove the request from the channel on timeout
-        val firstRequest = select<IdentifyRequest?> {
-            channel.onReceive { it }
-            @OptIn(ExperimentalCoroutinesApi::class) onTimeout(RECEIVE_TIMEOUT) { null }
-        } ?: return null
-
-        yield() // give other requests the chance to arrive if they were sent at the same time
-
-        return buildList {
-            add(firstRequest)
-
-            do { // now we receive other requests that are immediately available
-                val result = channel.tryReceive().onSuccess(::add)
-            } while (result.isSuccess)
-
-            sortWith(ShardIdComparator) // sort requests in this batch
-        }
-    }
 
-    private fun CoroutineScope.launchIdentifyWaiter(previousWaiter: Job?, request: IdentifyRequest) = launch {
-        if (previousWaiter != null) {
+    private fun launchIdentifyWaiter(
+        shardId: Int,
+        events: SharedFlow<Event>,
+        continuation: CancellableContinuation<Unit>,
+    ) = scope.launch {
+        val rateLimitKey = shardId % maxConcurrency
+        val mutex = mutexesByRateLimitKey[rateLimitKey]
+        // best effort, only used for logging (might be false even if mutex.withLock suspends later if we are unlucky)
+        val wasLocked = mutex.isLocked
+        if (wasLocked) {
             logger.debug {
-                "Waiting for other shard(s) with rate_limit_key ${request.rateLimitKey} to identify " +
-                    "before identifying on shard ${request.shardId}"
+                "Waiting for other shard(s) with rate_limit_key $rateLimitKey to identify before identifying on " +
+                    "shard $shardId"
             }
-            previousWaiter.join()
         }
-
-        // using a timeout so a broken gateway won't block its rate_limit_key for a long time
-        val responseToIdentify = withTimeoutOrNull(IDENTIFY_TIMEOUT) {
-            request.events
-                .onSubscription { // onSubscription ensures we don't miss events
+        mutex.withLock { // in case something terrible happens, ensure the mutex is unlocked
+            // using a timeout so a broken gateway won't block its rate_limit_key for a long time
+            val responseToIdentify = withTimeoutOrNull(IDENTIFY_TIMEOUT) {
+                events.onSubscription { // onSubscription ensures we don't miss events
                     logger.debug {
                         "${
-                            if (previousWaiter != null)
-                                "Waited for other shard(s) with rate_limit_key ${request.rateLimitKey} to identify, i"
+                            if (wasLocked) "Waited for other shard(s) with rate_limit_key $rateLimitKey to identify, i"
                             else "I"
-                        }dentifying on shard ${request.shardId} with rate_limit_key ${request.rateLimitKey}..."
+                        }dentifying on shard $shardId with rate_limit_key $rateLimitKey..."
                     }
-                    request.allow() // notify gateway waiting in consume -> it will try to identify -> wait for event
-                }
-                .first { it is Ready || it is InvalidSession || it is Close }
-        }
-
-        logger.debug {
-            when (responseToIdentify) {
-                null -> "Identifying on shard ${request.shardId} timed out"
-                is Ready -> "Identified on shard ${request.shardId}"
-                is InvalidSession -> "Identifying on shard ${request.shardId} failed, session could not be initialized"
-                is Close -> "Shard ${request.shardId} was stopped before it could identify"
-                else -> "Unexpected responseToIdentify on shard ${request.shardId}: $responseToIdentify"
-            } + ", delaying $DELAY_AFTER_IDENTIFY before freeing up rate_limit_key ${request.rateLimitKey}"
+                    // notify gateway waiting in consume -> it will try to identify -> wait for event
+                    continuation.resume(Unit)
+                }.first { it is Ready || it is InvalidSession || it is Close }
+            }
+            logger.debug {
+                when (responseToIdentify) {
+                    null -> "Identifying on shard $shardId timed out"
+                    is Ready -> "Identified on shard $shardId"
+                    is InvalidSession -> "Identifying on shard $shardId failed, session could not be initialized"
+                    is Close -> "Shard $shardId was stopped before it could identify"
+                    else -> "Unexpected responseToIdentify on shard $shardId: $responseToIdentify"
+                } + ", delaying $DELAY_AFTER_IDENTIFY before freeing up rate_limit_key $rateLimitKey"
+            }
+            delay(DELAY_AFTER_IDENTIFY) // delay before unlocking mutex to free up the current rateLimitKey
         }
-
-        // next waiter for the current rate_limit_key has to wait for this delay before it can identify
-        delay(DELAY_AFTER_IDENTIFY)
     }
 
     override fun toString() = "IdentifyRateLimiter(maxConcurrency=$maxConcurrency, dispatcher=$dispatcher)"
 
     private companion object {
         // https://discord.com/developers/docs/topics/gateway#rate-limiting:
         // Apps also have a limit for concurrent Identify requests allowed per 5 seconds.
         // -> for each rate_limit_key: delay after identify
         private val DELAY_AFTER_IDENTIFY = 5.seconds
-        private val RECEIVE_TIMEOUT = DELAY_AFTER_IDENTIFY * 2
         private val IDENTIFY_TIMEOUT = DELAY_AFTER_IDENTIFY / 2
-
-        // states
-        private const val MAX_CONSUMERS = -2 // interpreted as UInt.MAX_VALUE - 1u
-        private const val NOT_RUNNING = -1
-        private const val RUNNING_WITH_NO_CONSUMERS = 0
-        private const val ONE_CONSUMER = 1
-
-        private val ShardIdComparator = Comparator<IdentifyRequest> { r1, r2 -> r1.shardId.compareTo(r2.shardId) }
-
-        private val ExceptionLogger = CoroutineExceptionHandler { context, exception ->
-            // we can't be cancelled (GlobalScope) and we never close the channel, so all exceptions are bugs
-            logger.error(exception) {
-                "IdentifyRateLimiter threw an exception in context $context, please report this, it should not happen"
-            }
-        }
     }
 }

From 3096b64435e1dfd32a37e4124966cf9d6f3fc54a Mon Sep 17 00:00:00 2001
From: lukellmann
Date: Tue, 30 Jan 2024 21:20:07 +0100
Subject: [PATCH 2/4] Make sure IdentifyRateLimiterImpl.consume does a
 tail-call
---
 gateway/src/commonMain/kotlin/ratelimit/IdentifyRateLimiter.kt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/gateway/src/commonMain/kotlin/ratelimit/IdentifyRateLimiter.kt b/gateway/src/commonMain/kotlin/ratelimit/IdentifyRateLimiter.kt
index c75444ef8df7..c3e7b008391e 100644
--- a/gateway/src/commonMain/kotlin/ratelimit/IdentifyRateLimiter.kt
+++ b/gateway/src/commonMain/kotlin/ratelimit/IdentifyRateLimiter.kt
@@ -81,7 +81,7 @@ private class IdentifyRateLimiterImpl(
         // if the coroutine that called consume() is cancelled, the CancellableContinuation makes sure the waiting is
         // stopped (the Gateway won't try to identify), so we don't need to hold the mutex and waste time for other
         // calls
-        suspendCancellableCoroutine { continuation ->
+        return suspendCancellableCoroutine { continuation ->
             val job = launchIdentifyWaiter(shardId, events, continuation)
             continuation.invokeOnCancellation { job.cancel() }
         }

From bdb3fa4355e54eba7fb05ee4f596ae55a71a4b14 Mon Sep 17 00:00:00 2001
From: Luca Kellermann
Date: Sat, 10 Aug 2024 14:56:29 +0200
Subject: [PATCH 3/4] Ensure waiting is logged
---
 .../commonMain/kotlin/ratelimit/IdentifyRateLimiter.kt | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/gateway/src/commonMain/kotlin/ratelimit/IdentifyRateLimiter.kt b/gateway/src/commonMain/kotlin/ratelimit/IdentifyRateLimiter.kt
index c3e7b008391e..b0743bbd2ee9 100644
--- a/gateway/src/commonMain/kotlin/ratelimit/IdentifyRateLimiter.kt
+++ b/gateway/src/commonMain/kotlin/ratelimit/IdentifyRateLimiter.kt
@@ -7,7 +7,6 @@ import kotlinx.coroutines.flow.SharedFlow
 import kotlinx.coroutines.flow.first
 import kotlinx.coroutines.flow.onSubscription
 import kotlinx.coroutines.sync.Mutex
-import kotlinx.coroutines.sync.withLock
 import kotlin.coroutines.resume
 import kotlin.time.Duration.Companion.seconds
 
@@ -94,15 +93,15 @@ private class IdentifyRateLimiterImpl(
     ) = scope.launch {
         val rateLimitKey = shardId % maxConcurrency
         val mutex = mutexesByRateLimitKey[rateLimitKey]
-        // best effort, only used for logging (might be false even if mutex.withLock suspends later if we are unlucky)
-        val wasLocked = mutex.isLocked
+        val wasLocked = !mutex.tryLock()
         if (wasLocked) {
             logger.debug {
                 "Waiting for other shard(s) with rate_limit_key $rateLimitKey to identify before identifying on " +
                     "shard $shardId"
             }
+            mutex.lock()
         }
-        mutex.withLock { // in case something terrible happens, ensure the mutex is unlocked
+        try { // in case something terrible happens, ensure the mutex is unlocked
             // using a timeout so a broken gateway won't block its rate_limit_key for a long time
             val responseToIdentify = withTimeoutOrNull(IDENTIFY_TIMEOUT) {
                 events.onSubscription { // onSubscription ensures we don't miss events
@@ -126,6 +125,8 @@ private class IdentifyRateLimiterImpl(
                 } + ", delaying $DELAY_AFTER_IDENTIFY before freeing up rate_limit_key $rateLimitKey"
             }
             delay(DELAY_AFTER_IDENTIFY) // delay before unlocking mutex to free up the current rateLimitKey
+        } finally {
+            mutex.unlock()
         }
     }

From 5cbbf87c3700ed3549d96d2f257117a711dba7ee Mon Sep 17 00:00:00 2001
From: Luca Kellermann
Date: Sun, 11 Aug 2024 03:52:17 +0200
Subject: [PATCH 4/4] Add more debug information on consume cancellation
---
 .../kotlin/ratelimit/IdentifyRateLimiter.kt | 15 ++++++++++-----
 1 file changed, 10 insertions(+), 5 deletions(-)

diff --git a/gateway/src/commonMain/kotlin/ratelimit/IdentifyRateLimiter.kt b/gateway/src/commonMain/kotlin/ratelimit/IdentifyRateLimiter.kt
index b0743bbd2ee9..e6790fee38d1 100644
--- a/gateway/src/commonMain/kotlin/ratelimit/IdentifyRateLimiter.kt
+++ b/gateway/src/commonMain/kotlin/ratelimit/IdentifyRateLimiter.kt
@@ -77,12 +77,17 @@ private class IdentifyRateLimiterImpl(
     override suspend fun consume(shardId: Int, events: SharedFlow<Event>) {
         require(shardId >= 0) { "shardId must be non-negative but was $shardId" }
 
-        // if the coroutine that called consume() is cancelled, the CancellableContinuation makes sure the waiting is
-        // stopped (the Gateway won't try to identify), so we don't need to hold the mutex and waste time for other
-        // calls
         return suspendCancellableCoroutine { continuation ->
-            val job = launchIdentifyWaiter(shardId, events, continuation)
-            continuation.invokeOnCancellation { job.cancel() }
+            val waiter = launchIdentifyWaiter(shardId, events, continuation)
+            // this will be invoked if the coroutine that called consume() is cancelled
+            continuation.invokeOnCancellation { cause ->
+                // stop the waiter, so we don't hold the mutex and waste time for other consume() calls (the Gateway
+                // won't try to identify if it was cancelled at this point)
+                waiter.cancel("Identify waiter was cancelled because consume() was cancelled", cause)
+                logger.debug(cause) {
+                    "Identifying on shard $shardId with rate_limit_key ${shardId % maxConcurrency} was cancelled"
+                }
+            }
         }
     }
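
Restated outside the diff context, the core idea of the series is: one fair Mutex per rate_limit_key replaces the channel, batching, and waiter-tracking machinery, because a fair Mutex already queues suspended callers in FIFO order. The following is a minimal, self-contained Kotlin sketch of that pattern, not Kord's actual code; KeyedIdentifyLimiter and withIdentifyPermit are made-up names for illustration, and the 5-second delay is the per-key identify interval from the Discord docs cited in patch 1.

import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import kotlin.time.Duration.Companion.seconds

// Hypothetical sketch of the per-key queueing pattern used in this series.
class KeyedIdentifyLimiter(private val maxConcurrency: Int) {
    // Mutex() is fair: suspended waiters acquire the lock in FIFO order, so callers
    // for the same rate_limit_key queue up without any explicit queue structure.
    private val mutexes = List(maxConcurrency) { Mutex() }

    suspend fun withIdentifyPermit(shardId: Int, identify: suspend () -> Unit) {
        val mutex = mutexes[shardId % maxConcurrency] // shardId % maxConcurrency = rate_limit_key
        mutex.withLock {
            identify()       // identify while holding the key's lock
            delay(5.seconds) // keep the key locked for the mandated delay afterwards
        }
    }
}

// Usage: with maxConcurrency = 2, shards 0 and 2 share rate_limit_key 0 and identify
// 5 seconds apart, while shard 1 (key 1) proceeds immediately.
suspend fun identifyAllShards(limiter: KeyedIdentifyLimiter) = coroutineScope {
    repeat(3) { shard ->
        launch { limiter.withIdentifyPermit(shard) { println("identify on shard $shard") } }
    }
}

Holding the lock across the post-identify delay is the design point patches 3 and 4 then refine: tryLock-then-lock makes the waiting observable for logging, and cancelling the waiter when consume() is cancelled keeps a dead caller from occupying the key's queue.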