Simplify IdentifyRateLimiterImpl #913

Open · wants to merge 6 commits into main
237 changes: 64 additions & 173 deletions gateway/src/commonMain/kotlin/ratelimit/IdentifyRateLimiter.kt
@@ -2,17 +2,12 @@ package dev.kord.gateway.ratelimit

import dev.kord.gateway.*
import io.github.oshai.kotlinlogging.KotlinLogging
import kotlinx.atomicfu.atomic
import kotlinx.atomicfu.getAndUpdate
import kotlinx.atomicfu.loop
import kotlinx.atomicfu.update
import kotlinx.coroutines.*
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.channels.onSuccess
import kotlinx.coroutines.flow.*
import kotlinx.coroutines.selects.onTimeout
import kotlinx.coroutines.selects.select
import kotlin.jvm.JvmField
import kotlinx.coroutines.flow.SharedFlow
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.flow.onSubscription
import kotlinx.coroutines.sync.Mutex
import kotlin.coroutines.resume
import kotlin.time.Duration.Companion.seconds

/**
@@ -61,197 +56,93 @@ private class IdentifyRateLimiterImpl(
private val dispatcher: CoroutineDispatcher,
) : IdentifyRateLimiter {

private class IdentifyRequest(
@JvmField val shardId: Int,
@JvmField val events: SharedFlow<Event>,
private val permission: CompletableDeferred<Unit>,
) {
fun allow() = permission.complete(Unit)
}

private val IdentifyRequest.rateLimitKey get() = shardId % maxConcurrency

// doesn't need onUndeliveredElement for rejecting requests, we don't close or cancel the channel, receiving can't
// be cancelled (running in GlobalScope), only send can be cancelled, which is ok because permission isn't needed in
// that case
private val channel = Channel<IdentifyRequest>(capacity = maxConcurrency)

/**
* Can be
* - [NOT_RUNNING]: no coroutine ([launchRateLimiterCoroutine]) is running, or it is about to stop and will no
* longer process [IdentifyRequest]s
* - [RUNNING_WITH_NO_CONSUMERS]: the coroutine is running and ready to process [IdentifyRequest]s but there are no
* consumers that sent a request
* - unsigned number of consumers: number of concurrent [consume] invocations that wait for their [IdentifyRequest]
* to be processed (min [ONE_CONSUMER], max [MAX_CONSUMERS])
*/
private val state = atomic(initial = NOT_RUNNING)

private fun getOldStateAndIncrementConsumers() = state.getAndUpdate { current ->
when (current) {
NOT_RUNNING, RUNNING_WITH_NO_CONSUMERS -> ONE_CONSUMER // we are the first consumer
MAX_CONSUMERS -> error(
"Too many concurrent identify attempts, overflow happened. There are already ${current.toUInt()} " +
"other consume() invocations waiting. This is most likely a bug."
)
else -> current + 1 // increment number of consumers
}
}

private fun decrementConsumers() = state.update { current ->
when (current) {
NOT_RUNNING -> error("Should be running but was NOT_RUNNING")
RUNNING_WITH_NO_CONSUMERS -> error("Should have consumers but was RUNNING_WITH_NO_CONSUMERS")
ONE_CONSUMER -> RUNNING_WITH_NO_CONSUMERS // we were the last consumer
else -> current - 1 // decrement number of consumers
}
}

private fun stopIfHasNoConsumers(): Boolean = state.loop { current ->
when (current) {
NOT_RUNNING -> error("Should be running but was NOT_RUNNING")
RUNNING_WITH_NO_CONSUMERS -> // no new requests in sight -> try to stop
if (state.compareAndSet(expect = current, update = NOT_RUNNING)) return true
else -> return false // don't change number of consumers
// These mutexes are fair (according to the documentation of Mutex and its factory function), we rely on this to
// guarantee a fair rate limiter behavior.
// rateLimitKey is always in the range 0..<maxConcurrency, so size has to be maxConcurrency (one per rateLimitKey)
private val mutexesByRateLimitKey = Array(size = maxConcurrency) { Mutex() }

// scope.cancel() is never called, but that's ok: all coroutines that are launched in this scope complete after
// waiting for the previous one trying to lock the same mutex and a fixed delay of at most DELAY_AFTER_IDENTIFY +
// IDENTIFY_TIMEOUT -> no resource cleanup needed
// SupervisorJob is used to ensure an unexpected failure in one coroutine doesn't leave the rate limiter unusable
private val scope = CoroutineScope(
context = SupervisorJob() + dispatcher + CoroutineExceptionHandler { context, exception ->
// we can't be cancelled, so all exceptions are bugs
logger.error(exception) {
"IdentifyRateLimiter threw an exception in context $context, please report this, it should not happen"
}
}
}

)

override suspend fun consume(shardId: Int, events: SharedFlow<Event>) {
require(shardId >= 0) { "shardId must be non-negative but was $shardId" }

val oldState = getOldStateAndIncrementConsumers()
if (oldState == NOT_RUNNING) launchRateLimiterCoroutine() // if this throws we are screwed anyway

// this might throw because of cancellation of the coroutine that called consume(), which is ok
try {
val permission = CompletableDeferred<Unit>()
channel.send(IdentifyRequest(shardId, events, permission))
permission.await()
} finally {
decrementConsumers()
}
}


private fun launchRateLimiterCoroutine() {
// GlobalScope is ok here:
// - only one coroutine is launched at a time:
// previously running one will allow next one to start by setting state to NOT_RUNNING just before exiting
// - the coroutine will time out eventually when no identify requests are sent, a new one will be launched on
// demand which will then time out again etc.
// => no leaks
@OptIn(DelicateCoroutinesApi::class)
GlobalScope.launch(context = dispatcher + ExceptionLogger) {

// only read/written from sequential loop, not from launched concurrent coroutines
// request.rateLimitKey is always in the range 0..<maxConcurrency
val identifyWaiters = arrayOfNulls<Job>(size = maxConcurrency)

while (true) {
val batch = receiveSortedBatchOfRequestsOrNull() ?: when {
identifyWaiters.any { it != null && !it.isCompleted } -> continue // keep receiving while waiting
stopIfHasNoConsumers() -> return@launch // no consumers and no waiters -> stop (only exit point)
else -> continue // has consumers -> there will be new requests soon
}

for (request in batch) {
val key = request.rateLimitKey
val previousWaiter = identifyWaiters[key]
identifyWaiters[key] = launchIdentifyWaiter(previousWaiter, request)
return suspendCancellableCoroutine { continuation ->
val waiter = launchIdentifyWaiter(shardId, events, continuation)
// this will be invoked if the coroutine that called consume() is cancelled
continuation.invokeOnCancellation { cause ->
// stop the waiter, so we don't hold the mutex and waste time for other consume() calls (the Gateway
// won't try to identify if it was cancelled at this point)
waiter.cancel("Identify waiter was cancelled because consume() was cancelled", cause)
logger.debug(cause) {
"Identifying on shard $shardId with rate_limit_key ${shardId % maxConcurrency} was cancelled"
}
}
}
}

/** Returns `null` on [RECEIVE_TIMEOUT]. */
private suspend fun receiveSortedBatchOfRequestsOrNull(): List<IdentifyRequest>? {

// first receive suspends until we get a request or time out
// using select instead of withTimeoutOrNull here, so we can't remove the request from the channel on timeout
val firstRequest = select<IdentifyRequest?> {
channel.onReceive { it }
@OptIn(ExperimentalCoroutinesApi::class) onTimeout(RECEIVE_TIMEOUT) { null }
} ?: return null

yield() // give other requests the chance to arrive if they were sent at the same time

return buildList {
add(firstRequest)

do { // now we receive other requests that are immediately available
val result = channel.tryReceive().onSuccess(::add)
} while (result.isSuccess)

sortWith(ShardIdComparator) // sort requests in this batch
}
}

private fun CoroutineScope.launchIdentifyWaiter(previousWaiter: Job?, request: IdentifyRequest) = launch {
if (previousWaiter != null) {
private fun launchIdentifyWaiter(
shardId: Int,
events: SharedFlow<Event>,
continuation: CancellableContinuation<Unit>,
) = scope.launch {
val rateLimitKey = shardId % maxConcurrency
val mutex = mutexesByRateLimitKey[rateLimitKey]
val wasLocked = !mutex.tryLock()
if (wasLocked) {
logger.debug {
"Waiting for other shard(s) with rate_limit_key ${request.rateLimitKey} to identify " +
"before identifying on shard ${request.shardId}"
"Waiting for other shard(s) with rate_limit_key $rateLimitKey to identify before identifying on " +
"shard $shardId"
}
previousWaiter.join()
mutex.lock()
}

// using a timeout so a broken gateway won't block its rate_limit_key for a long time
val responseToIdentify = withTimeoutOrNull(IDENTIFY_TIMEOUT) {
request.events
.onSubscription { // onSubscription ensures we don't miss events
try { // in case something terrible happens, ensure the mutex is unlocked
// using a timeout so a broken gateway won't block its rate_limit_key for a long time
val responseToIdentify = withTimeoutOrNull(IDENTIFY_TIMEOUT) {
events.onSubscription { // onSubscription ensures we don't miss events
logger.debug {
"${
if (previousWaiter != null)
"Waited for other shard(s) with rate_limit_key ${request.rateLimitKey} to identify, i"
if (wasLocked) "Waited for other shard(s) with rate_limit_key $rateLimitKey to identify, i"
else "I"
}dentifying on shard ${request.shardId} with rate_limit_key ${request.rateLimitKey}..."
}dentifying on shard $shardId with rate_limit_key $rateLimitKey..."
}
request.allow() // notify gateway waiting in consume -> it will try to identify -> wait for event
}
.first { it is Ready || it is InvalidSession || it is Close }
}

logger.debug {
when (responseToIdentify) {
null -> "Identifying on shard ${request.shardId} timed out"
is Ready -> "Identified on shard ${request.shardId}"
is InvalidSession -> "Identifying on shard ${request.shardId} failed, session could not be initialized"
is Close -> "Shard ${request.shardId} was stopped before it could identify"
else -> "Unexpected responseToIdentify on shard ${request.shardId}: $responseToIdentify"
} + ", delaying $DELAY_AFTER_IDENTIFY before freeing up rate_limit_key ${request.rateLimitKey}"
// notify gateway waiting in consume -> it will try to identify -> wait for event
continuation.resume(Unit)
}.first { it is Ready || it is InvalidSession || it is Close }
}
logger.debug {
when (responseToIdentify) {
null -> "Identifying on shard $shardId timed out"
is Ready -> "Identified on shard $shardId"
is InvalidSession -> "Identifying on shard $shardId failed, session could not be initialized"
is Close -> "Shard $shardId was stopped before it could identify"
else -> "Unexpected responseToIdentify on shard $shardId: $responseToIdentify"
} + ", delaying $DELAY_AFTER_IDENTIFY before freeing up rate_limit_key $rateLimitKey"
}
delay(DELAY_AFTER_IDENTIFY) // delay before unlocking mutex to free up the current rateLimitKey
} finally {
mutex.unlock()
}

// next waiter for the current rate_limit_key has to wait for this delay before it can identify
delay(DELAY_AFTER_IDENTIFY)
}


override fun toString() = "IdentifyRateLimiter(maxConcurrency=$maxConcurrency, dispatcher=$dispatcher)"


private companion object {
// https://discord.com/developers/docs/topics/gateway#rate-limiting:
// Apps also have a limit for concurrent Identify requests allowed per 5 seconds.
// -> for each rate_limit_key: delay after identify
private val DELAY_AFTER_IDENTIFY = 5.seconds

private val RECEIVE_TIMEOUT = DELAY_AFTER_IDENTIFY * 2
private val IDENTIFY_TIMEOUT = DELAY_AFTER_IDENTIFY / 2

// states
private const val MAX_CONSUMERS = -2 // interpreted as UInt.MAX_VALUE - 1u
private const val NOT_RUNNING = -1
private const val RUNNING_WITH_NO_CONSUMERS = 0
private const val ONE_CONSUMER = 1

private val ShardIdComparator = Comparator<IdentifyRequest> { r1, r2 -> r1.shardId.compareTo(r2.shardId) }

private val ExceptionLogger = CoroutineExceptionHandler { context, exception ->
// we can't be cancelled (GlobalScope) and we never close the channel, so all exceptions are bugs
logger.error(exception) {
"IdentifyRateLimiter threw an exception in context $context, please report this, it should not happen"
}
}
}
}
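
For readers skimming the diff, here is a minimal, self-contained sketch of the locking pattern the new implementation relies on: one fair Mutex per rate_limit_key (shardId % maxConcurrency), held for DELAY_AFTER_IDENTIFY after each identify. The ToyIdentifyRateLimiter class and the identify callback below are hypothetical illustrations only; the real IdentifyRateLimiterImpl additionally resumes the suspended consume() call and waits (with IDENTIFY_TIMEOUT) for a Ready, InvalidSession, or Close event before starting the delay.

import kotlinx.coroutines.*
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import kotlin.time.Duration.Companion.seconds

// Toy stand-in for IdentifyRateLimiterImpl: one fair Mutex per rate_limit_key
// (shardId % maxConcurrency), held for a delay after each identify. The event
// waiting and timeout logic of the real implementation is omitted here.
private class ToyIdentifyRateLimiter(private val maxConcurrency: Int) {
    private val delayAfterIdentify = 5.seconds
    private val mutexes = Array(maxConcurrency) { Mutex() } // Mutex() is fair -> FIFO per key

    suspend fun consume(shardId: Int, identify: suspend () -> Unit) {
        mutexes[shardId % maxConcurrency].withLock {
            identify()                // caller identifies while its rate_limit_key is held
            delay(delayAfterIdentify) // keep the key busy so the next shard on it waits
        }
    }
}

fun main() = runBlocking {
    val limiter = ToyIdentifyRateLimiter(maxConcurrency = 2)
    // Shards 0 and 2 share rate_limit_key 0, shard 1 uses key 1:
    // shards 0 and 1 identify immediately, shard 2 waits ~5 seconds behind shard 0.
    (0..2).map { shard ->
        launch { limiter.consume(shard) { println("identify on shard $shard") } }
    }.joinAll()
}

Compared to the previous channel-plus-state-machine approach, fairness comes directly from the fair Mutex per key rather than from batching and sorting incoming requests by shardId.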