Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: implement microsoft key caching to combat rate limiting #109

Draft
wants to merge 4 commits into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions src/common/Logger.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,32 @@ export const writeLogMessage = (event: APIGatewayTokenAuthorizerEvent, log: ILog
}
return log;
};

export enum LogLevel {
DEBUG = "DEBUG",
INFO = "INFO",
WARN = "WARN",
ERROR = "ERROR",
}

export const envLogger = (level: LogLevel, ...messages: string[]) => {
if (process.env.DEBUG === "true") {
switch (level) {
case LogLevel.DEBUG:
console.debug(messages);
break;
case LogLevel.INFO:
console.info(messages);
break;
case LogLevel.WARN:
console.warn(messages);
break;
case LogLevel.ERROR:
console.error(messages);
break;
default:
console.log(messages);
return;
}
}
};
14 changes: 13 additions & 1 deletion src/functions/authorizer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import { generatePolicy as generateFunctionalPolicy } from "./functionalPolicyFa
import { getValidJwt } from "../services/tokens";
import { JWT_MESSAGE } from "../models/enums";
import { ILogEvent } from "../models/ILogEvent";
import { writeLogMessage } from "../common/Logger";
import { envLogger, LogLevel, writeLogMessage } from "../common/Logger";
import newPolicyDocument from "./newPolicyDocument";
import { Jwt, JwtPayload } from "jsonwebtoken";

Expand All @@ -20,25 +20,35 @@ import { Jwt, JwtPayload } from "jsonwebtoken";
export const authorizer = async (event: APIGatewayTokenAuthorizerEvent, context: Context): Promise<APIGatewayAuthorizerResult> => {
const logEvent: ILogEvent = {};

envLogger(LogLevel.DEBUG, "Invoked authoriser");

if (!process.env.AZURE_TENANT_ID || !process.env.AZURE_CLIENT_ID) {
writeLogMessage(event, logEvent, JWT_MESSAGE.INVALID_ID_SETUP);
return unauthorisedPolicy();
}

envLogger(LogLevel.DEBUG, "AZURE_TENANT_ID and AZURE_CLIENT_ID are set");

try {
initialiseLogEvent(event);

envLogger(LogLevel.INFO, "Getting valid JWT");
const jwt = await getValidJwt(event.authorizationToken, logEvent, process.env.AZURE_TENANT_ID, process.env.AZURE_CLIENT_ID);

envLogger(LogLevel.INFO, "Generating role policy");
const policy = generateRolePolicy(jwt, logEvent) ?? generateFunctionalPolicy(jwt, logEvent);

if (policy !== undefined) {
envLogger(LogLevel.INFO, "Role policy generated");
return policy;
}

reportNoValidRoles(jwt, event, context, logEvent);
writeLogMessage(event, logEvent, JWT_MESSAGE.INVALID_ROLES);

return unauthorisedPolicy();
} catch (error: any) {
envLogger(LogLevel.ERROR, "Catch - Error occurred", error);
writeLogMessage(event, logEvent, error);
return unauthorisedPolicy();
}
Expand Down Expand Up @@ -67,6 +77,8 @@ const reportNoValidRoles = (jwt: Jwt, event: APIGatewayTokenAuthorizerEvent, con
* @param event
*/
const initialiseLogEvent = (event: APIGatewayTokenAuthorizerEvent): ILogEvent => {
envLogger(LogLevel.DEBUG, "Init log event");

return {
requestUrl: event.methodArn,
timeOfRequest: new Date().toISOString(),
Expand Down
18 changes: 17 additions & 1 deletion src/services/azure.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,20 @@
import { KeyResponse } from "../models/KeyResponse";
import { envLogger, LogLevel } from "../common/Logger";

const cache: Map<string, Map<string, string>> = new Map();

export const getCertificateChain = async (tenantId: string, keyId: string): Promise<string> => {
const keys: Map<string, string> = await getKeys(tenantId);
const cacheKeys = cache.get(tenantId);

envLogger(LogLevel.DEBUG, `Cache ${cacheKeys ? "hit" : "not hit"}`);

const keys: Map<string, string> = cacheKeys ?? (await getKeys(tenantId));

envLogger(LogLevel.DEBUG, "Public keys read");

if (!cache.has(tenantId)) {
cache.set(tenantId, keys);
}

const certificateChain = keys.get(keyId);

Expand All @@ -25,9 +38,12 @@ const getKeys = async (tenantId: string): Promise<Map<string, string>> => {

map.set(keyId, certificateChain);
}

envLogger(LogLevel.DEBUG, "Key Map Created");
return map;
};

export const fetchKeys = (tenantId: string) => {
envLogger(LogLevel.DEBUG, `Fetching keys from https://login.microsoftonline.com/${tenantId}/discovery/keys`);
return fetch(`https://login.microsoftonline.com/${tenantId}/discovery/keys`);
};
3 changes: 3 additions & 0 deletions src/services/signature-check.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import * as JWT from "jsonwebtoken";
import { getCertificateChain } from "./azure";
import { envLogger, LogLevel } from "../common/Logger";

export const checkSignature = async (encodedToken: string, decodedToken: JWT.Jwt, tenantId: string, clientId: string): Promise<void> => {
// tid = tenant ID, kid = key ID
envLogger(LogLevel.DEBUG, "Getting cert chain");
const certificate = await getCertificateChain(tenantId, decodedToken.header.kid as string);

envLogger(LogLevel.INFO, "Verifying token");
JWT.verify(encodedToken, certificate, {
audience: clientId.split(","),
issuer: [`https://sts.windows.net/${tenantId}/`, `https://login.microsoftonline.com/${tenantId}/v2.0`],
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/services/azure.unitTest.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,6 @@ describe("getCertificateChain()", () => {
it("should throw an error if no key matches the given key ID", async (): Promise<void> => {
fetchSpy("somethingElse", "mySuperSecurePublicKey");

await expect(azure.getCertificateChain("tenantId", "keyToTheKingdom")).rejects.toThrow("no public key");
await expect(azure.getCertificateChain("tenantId", "otherKeyToTheKingdom")).rejects.toThrow("no public key");
});
});
Loading