-
-
Notifications
You must be signed in to change notification settings - Fork 80
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'shadab14meb346-feat/subscription-based-rate-limit-custo…
…mization'
- Loading branch information
Showing
13 changed files
with
244 additions
and
46 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,5 @@ | ||
### Rate Limits | ||
|
||
Revert will return a rate limit error with the HTTP status code `429` when the request rate limit for an application or an individual IP address has been exceeded. | ||
Revert will return a rate limit error with the HTTP status code `429` when the request rate limit for a tenant has been exceeded. | ||
|
||
As a default, 4 requests per minute are allowed from a single IP address but if you're needs are higher please get in touch and we can tailor it to your use-case. | ||
As a default, 100 requests per minute per connection are allowed for a single tenant but if you're needs are higher please get in touch and we can tailor it to your use-case. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
import { Request, Response } from 'express'; | ||
import { RateLimiterRedis, IRateLimiterStoreOptions } from 'rate-limiter-flexible'; | ||
|
||
import redis from '../redis/client'; | ||
import { skipRateLimitRoutes } from './utils'; | ||
import config from '../config'; | ||
|
||
// In Memory Cache for storing RateLimiterRedis instances to prevent the creation of a new instance for each request. | ||
// The cache key is the 'rate_limit' value derived from the subscriptions table. Currently, we cache by the 'rate_limit' value for simplicity. | ||
// Using 'subscriptionId' as a key would be more precise but would add complexity in keeping the cache in sync with the database. | ||
const rateLimiters = new Map<number, RateLimiterRedis>(); | ||
|
||
//We can make this dynamic based on the subscription as well | ||
const RATE_LIMIT_DURATION_IN_MINUTES = 1; | ||
const FALL_BACK_DEFAULT_RATE_LIMIT = 100; | ||
|
||
const getRateLimiter = (rateLimit: number): RateLimiterRedis => { | ||
if (!rateLimiters.has(rateLimit)) { | ||
const opts: IRateLimiterStoreOptions = { | ||
storeClient: redis, | ||
points: rateLimit, // Points represent the maximum number of requests allowed within the set duration. | ||
duration: RATE_LIMIT_DURATION_IN_MINUTES * 60, // Converts minutes to seconds for the duration. | ||
}; | ||
rateLimiters.set(rateLimit, new RateLimiterRedis(opts)); | ||
} | ||
return rateLimiters.get(rateLimit)!; | ||
}; | ||
|
||
const rateLimitMiddleware = () => async (req: Request, res: Response, next: Function) => { | ||
if (skipRateLimitRoutes(req)) next(); | ||
try { | ||
const { 'x-revert-t-id': tenantId } = req.headers; | ||
const { subscription, id: accountId } = res.locals.account; // Subscription details are retrieved from response locals set earlier in the revertAuthMiddleware. | ||
const rateLimit = | ||
subscription?.rate_limit ?? config.DEFAULT_RATE_LIMIT_DEVELOPER_PLAN ?? FALL_BACK_DEFAULT_RATE_LIMIT; //incase subscription undefined, we will use the default rate limit this is to make sure backward compatibility as currently some accounts might not have subscription attached to them. We can remove the optional chaining and nullish coalescing once we are sure that all accounts have subscription attached to them. In case of DEFAULT_RATE_LIMIT_DEVELOPER_PLAN is missing in config, we will fallback to 1 | ||
const rateLimiter = getRateLimiter(rateLimit); | ||
//added accountId to make the key unique | ||
const uniqueKey = `${accountId}-${tenantId}`; | ||
// rate limit is per tenantId | ||
await rateLimiter.consume(uniqueKey, 1); // consume 1 point for each request | ||
next(); | ||
} catch (rejRes) { | ||
//currently not handling the redis failure here explicitly as it is getting handled in the redis client. If required in future we can handle it here so service does't go down due to rate limit middleware failure which might occur due to redis failure | ||
if (rejRes instanceof Error) { | ||
throw rejRes; | ||
} else { | ||
res.status(429).send('Rate limit reached.'); | ||
} | ||
} | ||
}; | ||
|
||
export default rateLimitMiddleware; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
import { Request } from 'express'; | ||
|
||
const skipRateLimitRoutes = (req: Request) => { | ||
const nonSecurePaths = ['/oauth-callback', '/oauth/refresh']; | ||
const nonSecurePathsPartialMatch = ['/integration-status', '/trello-request-token']; | ||
const allowedRoutes = ['/health-check']; | ||
if ( | ||
nonSecurePaths.includes(req.path) || | ||
nonSecurePathsPartialMatch.some((path) => req.path.includes(path)) || | ||
allowedRoutes.includes(req.baseUrl + req.path) | ||
) { | ||
return true; | ||
} | ||
return false; | ||
}; | ||
|
||
export { skipRateLimitRoutes }; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
22 changes: 22 additions & 0 deletions
22
packages/backend/prisma/migrations/20240508161507_create_subscriptions_table/migration.sql
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
-- AlterTable | ||
ALTER TABLE "accounts" ADD COLUMN "subscriptionId" TEXT; | ||
|
||
-- CreateTable | ||
CREATE TABLE "subscriptions" ( | ||
"id" TEXT NOT NULL, | ||
"name" TEXT NOT NULL, | ||
"rate_limit" INTEGER NOT NULL, | ||
"rate_limit_duration_in_minutes" INTEGER NOT NULL, | ||
"price" DOUBLE PRECISION NOT NULL, | ||
"features" JSONB, | ||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP, | ||
"updatedAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP, | ||
|
||
CONSTRAINT "subscriptions_pkey" PRIMARY KEY ("id") | ||
); | ||
|
||
-- CreateIndex | ||
CREATE UNIQUE INDEX "subscriptions_name_key" ON "subscriptions"("name"); | ||
|
||
-- AddForeignKey | ||
ALTER TABLE "accounts" ADD CONSTRAINT "accounts_subscriptionId_fkey" FOREIGN KEY ("subscriptionId") REFERENCES "subscriptions"("id") ON DELETE SET NULL ON UPDATE CASCADE; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
import { Response } from 'express'; | ||
import rateLimitMiddleware from '../helpers/rateLimitMiddleware'; | ||
import { skipRateLimitRoutes } from '../helpers/utils'; | ||
|
||
jest.mock('../redis/client', () => { | ||
return { | ||
createClient: jest.fn().mockReturnThis(), | ||
on: jest.fn(), | ||
connect: jest.fn(), | ||
}; | ||
}); | ||
jest.mock('../helpers/utils', () => { | ||
return { | ||
skipRateLimitRoutes: jest.fn(), | ||
}; | ||
}); | ||
jest.mock('rate-limiter-flexible', () => { | ||
return { | ||
RateLimiterRedis: jest.fn().mockImplementation(() => ({ | ||
consume: jest.fn(), | ||
})), | ||
}; | ||
}); | ||
//TODO: some more extensive test case are required | ||
describe('Rate Limit Middleware', () => { | ||
const mockRequest = (options = {}) => ({ | ||
headers: { 'x-revert-t-id': 'tenant123' }, | ||
...options, | ||
}); | ||
const mockResponse = () => { | ||
const res = { | ||
locals: { account: { subscription: { rate_limit: 100 }, id: 'account123' } }, | ||
} as unknown as Response; | ||
res.status = jest.fn().mockReturnValue(res); | ||
res.send = jest.fn().mockReturnValue(res); | ||
return res; | ||
}; | ||
const nextFunction = jest.fn(); | ||
|
||
beforeEach(() => { | ||
jest.clearAllMocks(); | ||
}); | ||
|
||
it('should call next if route should be skipped', async () => { | ||
//@ts-ignore | ||
skipRateLimitRoutes.mockReturnValue(true); | ||
const req = mockRequest(); | ||
const res = mockResponse(); | ||
|
||
//@ts-ignore | ||
await rateLimitMiddleware()(req, res, nextFunction); | ||
|
||
expect(skipRateLimitRoutes).toHaveBeenCalled(); | ||
expect(nextFunction).toHaveBeenCalled(); | ||
}); | ||
}); |
Oops, something went wrong.