Skip to content

Commit

Permalink
refactor: switched to a more advanced rate-limiter
Browse files Browse the repository at this point in the history
  • Loading branch information
alireza-sheikholmolouki committed Sep 20, 2024
1 parent 01223ed commit 0051835
Show file tree
Hide file tree
Showing 45 changed files with 235 additions and 181 deletions.
3 changes: 2 additions & 1 deletion src/aws-app/utils/createRateLimiterFactory.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import {RateLimiter, createRateLimiter} from '@/scanners'
import {createRateLimiter} from '@/scanners'
import {GetRateLimiterFunction} from '@/scanners/types'
import {RateLimiter} from '@/types'

/**
* Returns the getRateLimiter function for a given service and region
Expand Down
28 changes: 28 additions & 0 deletions src/scanners/common/AWSRateLimitExhaustionStrategy.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import {RetryOptions, RetryStrategy} from './RetryStrategy'

/**
 * Retry strategy that retries a call only when AWS rejected it for rate limiting.
 *
 * Defaults: up to 5 retries, starting with a 1000 ms delay that doubles on each
 * retry and is capped at 20000 ms. Every default can be overridden via `options`.
 */
export class AWSRateLimitExhaustionStrategy extends RetryStrategy {
  /** Error identifiers AWS services use to signal that a request was throttled. */
  private static readonly RATE_LIMIT_ERROR_CODES: ReadonlySet<string> = new Set([
    'ThrottlingException',
    'TooManyRequestsException',
    'RequestLimitExceeded',
    'Throttling',
    'RequestThrottled',
    'RequestThrottledException',
    'SlowDown',
    // DynamoDB's throttling error; the AWS SDK's own retry classification treats it as throttling too.
    'ProvisionedThroughputExceededException',
  ])

  /** HTTP status "Too Many Requests". */
  private static readonly TOO_MANY_REQUESTS = 429

  /**
   * @param options Overrides for any of the default retry settings.
   */
  constructor(options: Partial<RetryOptions> = {}) {
    super({
      maxRetries: 5,
      initialDelay: 1000,
      maxDelay: 20000,
      backoffFactor: 2,
      ...options,
    })
  }

  /**
   * Decides whether a failed call was rejected because of rate limiting.
   *
   * @param error The value the failed call rejected with.
   * @returns `true` when the error carries a known throttling identifier or an
   *   HTTP 429 status, `false` otherwise (including for `null`/`undefined`).
   */
  shouldRetry(error: any): boolean {
    // A promise may reject with null/undefined; reading properties off it would
    // throw a TypeError here and hide the original failure from the caller.
    if (error == null) {
      return false
    }

    const {RATE_LIMIT_ERROR_CODES, TOO_MANY_REQUESTS} = AWSRateLimitExhaustionStrategy

    // AWS SDK v3 (`@aws-sdk/client-*`) reports the identifier in `name`;
    // `code` is kept for errors shaped like the older SDK's.
    const identifiers: unknown[] = [error.name, error.code]
    if (identifiers.some((id) => typeof id === 'string' && RATE_LIMIT_ERROR_CODES.has(id))) {
      return true
    }

    // SDK v3 exposes the HTTP status under `$metadata.httpStatusCode`;
    // `statusCode` is kept for backward compatibility.
    return error.statusCode === TOO_MANY_REQUESTS || error.$metadata?.httpStatusCode === TOO_MANY_REQUESTS
  }
}
171 changes: 77 additions & 94 deletions src/scanners/common/RateLimiter.ts
Original file line number Diff line number Diff line change
@@ -1,120 +1,103 @@
const BASE_RETRY_DELAY = 1000 // 1 second
const MAX_RETRIES = 6 // last attempt will take 64 seconds (2^6 = 64 * base retry delay) and overall will take 127 seconds for all retries to complete, before giving up
const CAPACITY_USAGE_PERCENTAGE = 0.6
import {RateLimiter} from '@/types'
import {RetryStrategy} from './RetryStrategy'

const ONE_SECOND = 1000
export class RateLimiterImpl implements RateLimiter {
private executionTimesMs: number[] = []
private pendingQueue: (() => void)[] = []
private timeoutId: NodeJS.Timeout | null = null
private _isPaused = false
private readonly intervalMs: number
private retryStrategy: RetryStrategy | null

export const sleep = (ms: number) => new Promise((resolve) => setTimeout(resolve, ms))

export interface RateLimiter {
throttle<U>(fn: () => Promise<U>): Promise<U>
}

class RateLimiterImpl implements RateLimiter {
private allowance: number
private lastCheck: number
private maxUsage: number
private queue: (() => void)[]

constructor(rate: number) {
this.maxUsage = rate * CAPACITY_USAGE_PERCENTAGE
this.allowance = this.maxUsage
this.lastCheck = Date.now()
this.queue = []
constructor(
private _rate: number,
retryStrategy: RetryStrategy | null = null,
) {
if (!Number.isFinite(_rate) || _rate <= 0) {
throw new TypeError('Expected `rate` to be a positive finite number')
}
this.intervalMs = 1000 / _rate
this.retryStrategy = retryStrategy
}

// Acquire permission to proceed with a function call
private async acquire() {
const current = Date.now()
const timePassed = current - this.lastCheck
this.lastCheck = current
this.allowance += (timePassed / ONE_SECOND) * this.maxUsage
private scheduleNextExecution(): void {
if (this.timeoutId) {
clearTimeout(this.timeoutId)
}

if (this.allowance > this.maxUsage) {
this.allowance = this.maxUsage
if (this._isPaused || this.pendingQueue.length === 0) {
return
}

// If allowance is insufficient, wait for more capacity
if (this.allowance < 1) {
const waitTime = (1 - this.allowance) * (ONE_SECOND / this.maxUsage)
await sleep(waitTime)
this.allowance = 0
} else {
this.allowance -= 1
const now = Date.now()
let delay = 0

if (this.executionTimesMs.length > 0) {
const nextExecutionTime = this.executionTimesMs[this.executionTimesMs.length - 1] + this.intervalMs
delay = Math.max(0, nextExecutionTime - now)
}

this.timeoutId = setTimeout(() => {
const fn = this.pendingQueue.shift()
if (fn) {
this.executionTimesMs.push(Date.now())
if (this.executionTimesMs.length > 10) {
this.executionTimesMs.shift()
}
fn()
}
this.scheduleNextExecution()
}, delay)
}

// Throttle function execution
async throttle<U>(fn: () => Promise<U>): Promise<U> {
throttle<U>(fn: () => Promise<U>): Promise<U> {
return new Promise<U>((resolve, reject) => {
const execute = async () => {
try {
await this.acquire()
resolve(await fn())
} catch (error: any) {
if (this.isRequestLimitError(error)) {
this.retry(fn, resolve, reject, 1)
} else {
reject(error)
}
const wrappedFn = () => {
if (this.retryStrategy) {
this.retryStrategy.retry(fn).then(resolve).catch(reject)
} else {
fn().then(resolve).catch(reject)
}
}

this.queue.push(execute)

// If it's the only function in the queue, it means that we just added it. So we should start the process.
// But, if it's more in the queue, it means that we are already processing the queue, and the function will be processed when it's its turn.
if (this.queue.length === 1) {
this.dequeue()
}
this.pendingQueue.push(wrappedFn)
this.scheduleNextExecution()
})
}

// Retry function with exponential backoff
private async retry<U>(
fn: () => Promise<U>,
resolve: (value: U | PromiseLike<U>) => void,
reject: (reason?: any) => void,
attempt: number,
) {
if (attempt < MAX_RETRIES) {
// Exponential backoff
// https://en.wikipedia.org/wiki/Exponential_backoff
await sleep(BASE_RETRY_DELAY * Math.pow(2, attempt))
try {
// acquire permission to proceed and add enough sleep time
await this.acquire()
resolve(await fn())
} catch (error: any) {
if (this.isRequestLimitError(error)) {
this.retry(fn, resolve, reject, attempt + 1)
} else {
reject(error)
}
}
} else {
reject(new Error('Max retries reached'))
/**
 * Stops dispatching queued functions until `resume()` is called.
 * Entries stay in the pending queue; only the armed timer is cancelled.
 */
pause(): void {
this._isPaused = true
// Cancel the timer armed by scheduleNextExecution so nothing is dequeued while paused.
if (this.timeoutId) {
clearTimeout(this.timeoutId)
this.timeoutId = null
}
}

// Check if error is a request limit error
private isRequestLimitError(error: Error) {
return [
'TooManyRequestsException',
'ThrottlingException',
'ProvisionedThroughputExceededException',
'RequestLimitExceeded',
].includes(error.name)
/**
 * Clears the paused flag and re-arms the scheduler so queued functions are dispatched again.
 */
resume(): void {
this._isPaused = false
// scheduleNextExecution returns early while paused, so it must be called again after the flag is cleared.
this.scheduleNextExecution()
}

// Process the queue of function calls
// @idea Maybe we can add some lifecycle hooks here in the future?
private async dequeue() {
while (this.queue.length > 0) {
const fn = this.queue.shift()
if (fn) await fn()
/**
 * Drops every queued function, forgets the recorded execution times and cancels the armed timer.
 * The paused flag is left unchanged.
 *
 * NOTE(review): each dropped queue entry is the only code path that resolves or rejects
 * the promise returned by the matching `throttle()` call, so those promises never settle
 * after `abort()` — confirm that callers do not keep awaiting them.
 */
abort(): void {
this.pendingQueue = []
this.executionTimesMs = []
if (this.timeoutId) {
clearTimeout(this.timeoutId)
this.timeoutId = null
}
}

/** Number of functions still waiting in the pending queue (not yet dispatched). */
get queueSize(): number {
return this.pendingQueue.length
}

/** `true` between a call to `pause()` and the next call to `resume()`. */
get isPaused(): boolean {
return this._isPaused
}

/** The rate passed to the constructor; dispatches are spaced `1000 / rate` milliseconds apart. */
get rate(): number {
return this._rate
}
}

export const createRateLimiter = (rate: number): RateLimiter => new RateLimiterImpl(rate)
/**
 * Creates a rate limiter that spaces function executions `1000 / rate` milliseconds apart.
 *
 * @param rate Positive, finite number of executions allowed per second.
 * @param retryStrategy Optional strategy forwarded to the limiter; when given, each
 *   throttled function is run through `retryStrategy.retry`. Defaults to `null`
 *   (no retries), which is what the factory produced before this parameter existed.
 * @throws TypeError when `rate` is not a positive finite number (raised by the constructor).
 */
export const createRateLimiter = (rate: number, retryStrategy: RetryStrategy | null = null) =>
  new RateLimiterImpl(rate, retryStrategy)
28 changes: 28 additions & 0 deletions src/scanners/common/RetryStrategy.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/** Settings that control how `RetryStrategy.retry` backs off between attempts. */
export interface RetryOptions {
// Maximum number of retries after the first failed attempt.
maxRetries: number
// Delay before the first retry, in milliseconds (passed to setTimeout).
initialDelay: number
// Upper bound applied to every computed delay, in milliseconds.
maxDelay: number
// Multiplier applied per retry: delay = initialDelay * backoffFactor ^ retryCount.
backoffFactor: number
}

/**
 * Base class for retry policies with capped exponential backoff.
 * Subclasses decide which failures are worth retrying via `shouldRetry`.
 */
export abstract class RetryStrategy {
  constructor(protected options: RetryOptions) {}

  /** Returns `true` when the given failure should trigger another attempt. */
  abstract shouldRetry(error: any): boolean

  /**
   * Runs `fn`, re-running it after a backoff delay for as long as it keeps
   * failing with a retryable error and the retry budget is not used up.
   *
   * @param fn The operation to run.
   * @param retryCount Number of retries already consumed; also the exponent used for the first delay.
   * @returns The value `fn` resolves with.
   * @throws The last error from `fn` once it is not retryable or retries are exhausted.
   */
  async retry<T>(fn: () => Promise<T>, retryCount: number = 0): Promise<T> {
    for (let attempt = retryCount; ; attempt += 1) {
      try {
        return await fn()
      } catch (failure) {
        // Budget is checked first so shouldRetry is never consulted once retries are exhausted.
        const exhausted = attempt >= this.options.maxRetries
        if (exhausted || !this.shouldRetry(failure)) {
          throw failure
        }

        const uncapped = this.options.initialDelay * this.options.backoffFactor ** attempt
        const waitMs = Math.min(uncapped, this.options.maxDelay)
        await new Promise((wake) => setTimeout(wake, waitMs))
      }
    }
  }
}
10 changes: 8 additions & 2 deletions src/scanners/common/createGlobalScanner.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
import {Resources, ResourceDescription, GlobalScanFunction, Credentials, ScannerLifecycleHook} from '@/types'
import {
Resources,
ResourceDescription,
GlobalScanFunction,
Credentials,
ScannerLifecycleHook,
RateLimiter,
} from '@/types'
import {CreateGlobalScannerFunction, GetRateLimiterFunction} from '@/scanners/types'
import {RateLimiter} from './RateLimiter'

type GlobalScanResult<T extends ResourceDescription> = {
resources: Resources<T>
Expand Down
2 changes: 1 addition & 1 deletion src/scanners/common/createRegionalScanner.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ import {
RegionalScanFunction,
Credentials,
ScannerLifecycleHook,
RateLimiter,
} from '@/types'
import {CreateRegionalScannerFunction, GetRateLimiterFunction} from '@/scanners/types'
import {RateLimiter} from './RateLimiter'

type RegionScanResult<T extends ResourceDescription> = {
region: string
Expand Down
2 changes: 1 addition & 1 deletion src/scanners/index.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
export {RateLimiter, createRateLimiter} from './common/RateLimiter'
export {createRateLimiter} from './common/RateLimiter'
export {createGlobalScanner} from './common/createGlobalScanner'
export {createRegionalScanner} from './common/createRegionalScanner'
export {getAwsScanners} from './getAwsScanners'
Expand Down
3 changes: 1 addition & 2 deletions src/scanners/scan-functions/aws/_boilerplate.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import {DBInstance} from '@aws-sdk/client-rds'
import {Credentials, Resources} from '@/types'
import {RateLimiter} from '@/scanners/common/RateLimiter'
import {Credentials, Resources, RateLimiter} from '@/types'

/**
* 0️⃣ I put DBInstance just as an example of a real aws type.
Expand Down
3 changes: 1 addition & 2 deletions src/scanners/scan-functions/aws/athena-named-queries.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import {AthenaClient, ListNamedQueriesCommand, GetNamedQueryCommand, NamedQuery} from '@aws-sdk/client-athena'
import {Credentials, Resources} from '@/types'
import {RateLimiter} from '@/scanners/common/RateLimiter'
import {Credentials, Resources, RateLimiter} from '@/types'
import {buildARN} from './common/buildArn'
import {getAwsAccountId} from './common/getAwsAccountId'

Expand Down
3 changes: 1 addition & 2 deletions src/scanners/scan-functions/aws/autoscaling-groups.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import {AutoScalingClient, DescribeAutoScalingGroupsCommand, AutoScalingGroup} from '@aws-sdk/client-auto-scaling'
import {Credentials, Resources} from '@/types'
import {RateLimiter} from '@/scanners/common/RateLimiter'
import {Credentials, Resources, RateLimiter} from '@/types'

async function getAutoScalingGroups(
credentials: Credentials,
Expand Down
3 changes: 1 addition & 2 deletions src/scanners/scan-functions/aws/cloudfront-distributions.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import {CloudFrontClient, ListDistributionsCommand, DistributionSummary} from '@aws-sdk/client-cloudfront'
import {Credentials, Resources} from '@/types'
import {RateLimiter} from '@/scanners/common/RateLimiter'
import {Credentials, Resources, RateLimiter} from '@/types'

export async function getCloudFrontDistributions(
credentials: Credentials,
Expand Down
3 changes: 1 addition & 2 deletions src/scanners/scan-functions/aws/cloudtrail-trails.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import {CloudTrailClient, ListTrailsCommand, ListTrailsCommandOutput, TrailInfo} from '@aws-sdk/client-cloudtrail'
import {Credentials, Resources} from '@/types'
import {RateLimiter} from '@/scanners/common/RateLimiter'
import {Credentials, Resources, RateLimiter} from '@/types'

export async function getCloudTrailTrails(
credentials: Credentials,
Expand Down
3 changes: 1 addition & 2 deletions src/scanners/scan-functions/aws/cloudwatch-metric-alarms.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import {CloudWatchClient, DescribeAlarmsCommand, MetricAlarm} from '@aws-sdk/client-cloudwatch'
import {RateLimiter} from '@/scanners/common/RateLimiter'
import {Credentials, Resources} from '@/types'
import {Credentials, Resources, RateLimiter} from '@/types'

export async function getCloudWatchMetricAlarms(
credentials: Credentials,
Expand Down
3 changes: 1 addition & 2 deletions src/scanners/scan-functions/aws/cloudwatch-metric-streams.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import {CloudWatchClient, ListMetricStreamsCommand, MetricStreamEntry} from '@aws-sdk/client-cloudwatch'
import {RateLimiter} from '@/scanners/common/RateLimiter'
import {Credentials, Resources} from '@/types'
import {Credentials, Resources, RateLimiter} from '@/types'

export async function getCloudWatchMetricStreams(
credentials: Credentials,
Expand Down
3 changes: 1 addition & 2 deletions src/scanners/scan-functions/aws/dynamodb-tables.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import {DynamoDBClient, ListTablesCommand, DescribeTableCommand, TableDescription} from '@aws-sdk/client-dynamodb'
import {Credentials, Resources} from '@/types'
import {RateLimiter} from '@/scanners/common/RateLimiter'
import {Credentials, Resources, RateLimiter} from '@/types'

export async function getDynamoDBTables(
credentials: Credentials,
Expand Down
3 changes: 1 addition & 2 deletions src/scanners/scan-functions/aws/ec2-instances.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import {EC2Client, DescribeInstancesCommand, Instance} from '@aws-sdk/client-ec2'
import {RateLimiter} from '@/scanners/common/RateLimiter'
import {buildARN} from './common/buildArn'
import {Credentials, Resources} from '@/types'
import {Credentials, Resources, RateLimiter} from '@/types'
import {getAwsAccountId} from './common/getAwsAccountId'

export async function getEC2Instances(
Expand Down
3 changes: 1 addition & 2 deletions src/scanners/scan-functions/aws/ec2-internet-gateways.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import {EC2Client, DescribeInternetGatewaysCommand, InternetGateway} from '@aws-sdk/client-ec2'
import {RateLimiter} from '@/scanners/common/RateLimiter'
import {buildARN} from './common/buildArn'
import {Credentials, Resources} from '@/types'
import {Credentials, Resources, RateLimiter} from '@/types'
import {getAwsAccountId} from './common/getAwsAccountId'

export async function getEC2InternetGateways(
Expand Down
3 changes: 1 addition & 2 deletions src/scanners/scan-functions/aws/ec2-nat-gateways.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import {EC2Client, DescribeNatGatewaysCommand, NatGateway} from '@aws-sdk/client-ec2'
import {RateLimiter} from '@/scanners/common/RateLimiter'
import {buildARN} from './common/buildArn'
import {Credentials, Resources} from '@/types'
import {Credentials, Resources, RateLimiter} from '@/types'
import {getAwsAccountId} from './common/getAwsAccountId'

export async function getEC2NatGateways(
Expand Down
3 changes: 1 addition & 2 deletions src/scanners/scan-functions/aws/ec2-network-acls.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import {EC2Client, DescribeNetworkAclsCommand, NetworkAcl} from '@aws-sdk/client-ec2'
import {RateLimiter} from '@/scanners/common/RateLimiter'
import {buildARN} from './common/buildArn'
import {Credentials, Resources} from '@/types'
import {Credentials, Resources, RateLimiter} from '@/types'
import {getAwsAccountId} from './common/getAwsAccountId'

export async function getEC2NetworkAcls(
Expand Down
Loading

0 comments on commit 0051835

Please sign in to comment.