Skip to content

Commit

Permalink
use conditional dynamic import for aws-sdk (#558)
Browse files Browse the repository at this point in the history
  • Loading branch information
MCarlomagno authored May 1, 2024
1 parent cd7d7cf commit fee8405
Show file tree
Hide file tree
Showing 9 changed files with 967 additions and 28 deletions.
1 change: 1 addition & 0 deletions packages/base/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
"author": "OpenZeppelin Defender <[email protected]>",
"license": "MIT",
"devDependencies": {
"@aws-sdk/client-lambda": "^3.563.0",
"@types/async-retry": "^1.4.4",
"aws-sdk": "^2.1366.0"
},
Expand Down
49 changes: 28 additions & 21 deletions packages/base/src/autotask/index.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import Lambda, { _Blob } from 'aws-sdk/clients/lambda';
import { rateLimitModule, RateLimitModule } from '../utils/rate-limit';
import { getTimestampInSeconds } from '../utils/time';
import { getLambdaFromCredentials, isLambdaV3, LambdaLike, PayloadResponseV2, PayloadResponseV3 } from '../utils/lambda';

// do our best to get .errorMessage, but return object by default
function cleanError(payload?: _Blob): _Blob {
function cleanError(payload?: PayloadResponseV2 | PayloadResponseV3): PayloadResponseV2 | PayloadResponseV3 {
if (!payload) {
return 'Error occurred, but error payload was not defined';
}
Expand All @@ -17,7 +18,7 @@ function cleanError(payload?: _Blob): _Blob {
}

export abstract class BaseAutotaskClient {
private lambda: Lambda;
private lambda: LambdaLike;

private invocationRateLimit: RateLimitModule;

Expand All @@ -26,17 +27,25 @@ export abstract class BaseAutotaskClient {

this.invocationRateLimit = rateLimitModule.createCounterFor(arn, 300);

this.lambda = new Lambda(
creds
? {
credentials: {
accessKeyId: creds.AccessKeyId,
secretAccessKey: creds.SecretAccessKey,
sessionToken: creds.SessionToken,
},
}
: undefined,
);
this.lambda = getLambdaFromCredentials(credentials);
}

private async invoke(FunctionName: string, Payload: string) {
if (isLambdaV3(this.lambda)) {
return this.lambda.invoke({
FunctionName,
Payload,
InvocationType: 'RequestResponse',
});
} else {
return this.lambda
.invoke({
FunctionName,
Payload,
InvocationType: 'RequestResponse',
})
.promise();
}
}

// eslint-disable-next-line @typescript-eslint/ban-types
Expand All @@ -46,18 +55,16 @@ export abstract class BaseAutotaskClient {
this.invocationRateLimit.checkRateFor(invocationTimeStamp);
this.invocationRateLimit.incrementRateFor(invocationTimeStamp);

const invocationRequestResult = await this.lambda
.invoke({
FunctionName: this.arn,
Payload: JSON.stringify(request),
InvocationType: 'RequestResponse',
})
.promise();
const invocationRequestResult = await this.invoke(this.arn, JSON.stringify(request));

if (invocationRequestResult.FunctionError) {
throw new Error(`Error while attempting request: ${cleanError(invocationRequestResult.Payload)}`);
}

return JSON.parse(invocationRequestResult.Payload as string) as T;
return JSON.parse(
isLambdaV3(this.lambda)
? (invocationRequestResult.Payload as PayloadResponseV3).transformToString()
: (invocationRequestResult.Payload as string),
) as T;
}
}
78 changes: 78 additions & 0 deletions packages/base/src/utils/lambda.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import { version } from 'node:process';

const NODE_MIN_VERSION_FOR_V3 = 18;

export type InvokeResponse = {
FunctionError?: string;
Payload: PayloadResponseV2 | PayloadResponseV3;
};

export type InvokeResponseV2 = {
promise: () => Promise<InvokeResponse>;
};

export type PayloadResponseV3 = {
transformToString: () => string;
};

export type PayloadResponseV2 = string | Buffer | Uint8Array | Blob;

export type LambdaV2 = {
invoke: (params: { FunctionName: string; Payload: string; InvocationType: string }) => InvokeResponseV2;
};

export type LambdaV3 = {
invoke: (params: { FunctionName: string; Payload: string; InvocationType: string }) => Promise<InvokeResponse>;
};

export type LambdaLike = LambdaV2 | LambdaV3;

export type LambdaCredentials = {
AccessKeyId: string;
SecretAccessKey: string;
SessionToken: string;
};

function isLambdaV3Compatible(): boolean {
// example version: v14.17.0
const majorVersion = version.slice(1).split('.')[0];
if (!majorVersion) return false;
return parseInt(majorVersion, 10) >= NODE_MIN_VERSION_FOR_V3;
}

export function isLambdaV3(lambda: LambdaLike): lambda is LambdaV3 {
return isLambdaV3Compatible();
}

export function isV3ResponsePayload(payload: PayloadResponseV2 | PayloadResponseV3): payload is PayloadResponseV3 {
return (payload as PayloadResponseV3).transformToString !== undefined;
}

export function getLambdaFromCredentials(credentials: string): LambdaLike {
const creds: LambdaCredentials = credentials ? JSON.parse(credentials) : undefined;
if (isLambdaV3Compatible()) {
// eslint-disable-next-line @typescript-eslint/no-var-requires
const { Lambda } = require('@aws-sdk/client-lambda');
return new Lambda({
credentials: {
accessKeyId: creds.AccessKeyId,
secretAccessKey: creds.SecretAccessKey,
sessionToken: creds.SessionToken,
},
});
} else {
// eslint-disable-next-line @typescript-eslint/no-var-requires
const Lambda = require('aws-sdk/clients/lambda');
return new Lambda(
creds
? {
credentials: {
accessKeyId: creds.AccessKeyId,
secretAccessKey: creds.SecretAccessKey,
sessionToken: creds.SessionToken,
},
}
: undefined,
);
}
}
11 changes: 11 additions & 0 deletions packages/kvstore/__mocks__/@aws-sdk/client-lambda.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
const Lambda = jest.fn(() => ({
invoke: jest.fn(() =>
Promise.resolve({
Payload: {
transformToString: () => JSON.stringify({ result: 'result' }),
},
}),
),
}));

export { Lambda };
20 changes: 13 additions & 7 deletions packages/kvstore/src/autotask.test.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
import { KeyValueStoreAutotaskClient } from './autotask';
import Lambda from 'aws-sdk/clients/lambda';
import Lambda from '../__mocks__/aws-sdk/clients/lambda';
import { Lambda as LambdaV3 } from '../__mocks__/@aws-sdk/client-lambda';
jest.mock('node:process', () => ({
...jest.requireActual('node:process'),
version: 'v16.0.3',
}));

jest.mock('aws-sdk/clients/lambda', () => require('../__mocks__/aws-sdk/clients/lambda'));

type TestClient = Omit<KeyValueStoreAutotaskClient, 'lambda'> & { lambda: Lambda };
type TestClient = Omit<KeyValueStoreAutotaskClient, 'lambda'> & { lambda: typeof Lambda };

describe('KeyValueStoreAutotaskClient', () => {
const credentials = {
Expand All @@ -15,6 +19,8 @@ describe('KeyValueStoreAutotaskClient', () => {
let client: TestClient;

beforeEach(async function () {
jest.mock('aws-sdk/clients/lambda', () => Lambda);
jest.mock('@aws-sdk/client-lambda', () => ({ Lambda: LambdaV3 }));
client = new KeyValueStoreAutotaskClient({
credentials: JSON.stringify(credentials),
kvstoreARN: 'arn',
Expand All @@ -23,14 +29,14 @@ describe('KeyValueStoreAutotaskClient', () => {

describe('get', () => {
test('calls kvstore function', async () => {
(client.lambda.invoke as jest.Mock).mockImplementationOnce(() => ({
((client.lambda as any).invoke as jest.Mock).mockImplementationOnce(() => ({
promise: () => Promise.resolve({ Payload: JSON.stringify('myvalue') }),
}));

const result = await client.get('mykey');

expect(result).toEqual('myvalue');
expect(client.lambda.invoke).toBeCalledWith({
expect((client.lambda as any).invoke).toBeCalledWith({
FunctionName: 'arn',
InvocationType: 'RequestResponse',
Payload: '{"action":"get","key":"mykey"}',
Expand All @@ -41,7 +47,7 @@ describe('KeyValueStoreAutotaskClient', () => {
describe('del', () => {
test('calls kvstore function', async () => {
await client.del('mykey');
expect(client.lambda.invoke).toBeCalledWith({
expect((client.lambda as any).invoke).toBeCalledWith({
FunctionName: 'arn',
InvocationType: 'RequestResponse',
Payload: '{"action":"del","key":"mykey"}',
Expand All @@ -52,7 +58,7 @@ describe('KeyValueStoreAutotaskClient', () => {
describe('put', () => {
test('calls kvstore function', async () => {
await client.put('mykey', 'myvalue');
expect(client.lambda.invoke).toBeCalledWith({
expect((client.lambda as any).invoke).toBeCalledWith({
FunctionName: 'arn',
InvocationType: 'RequestResponse',
Payload: '{"action":"put","key":"mykey","value":"myvalue"}',
Expand Down
11 changes: 11 additions & 0 deletions packages/relay/src/__mocks__/@aws-sdk/client-lambda.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
const Lambda = jest.fn(() => ({
invoke: jest.fn(() =>
Promise.resolve({
Payload: {
transformToString: () => JSON.stringify({ result: 'result' }),
},
}),
),
}));

export { Lambda };
7 changes: 7 additions & 0 deletions packages/relay/src/autotask/index-rate.test.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
import { AutotaskRelayer } from '.';
import Lambda from 'aws-sdk/clients/lambda';
import { Lambda as LambdaV3 } from '../__mocks__/@aws-sdk/client-lambda';
jest.mock('node:process', () => ({
...jest.requireActual('node:process'),
version: 'v16.0.3',
}));

type TestAutotaskRelayer = Omit<AutotaskRelayer, 'lambda' | 'relayerARN'> & { lambda: Lambda; arn: string };

Expand Down Expand Up @@ -29,6 +34,8 @@ describe('AutotaskRelayer', () => {
let relayer: TestAutotaskRelayer;

beforeEach(async function () {
jest.mock('aws-sdk/clients/lambda', () => Lambda);
jest.mock('@aws-sdk/client-lambda', () => ({ Lambda: LambdaV3 }));
relayer = new AutotaskRelayer({
credentials: JSON.stringify(credentials),
relayerARN: 'arn',
Expand Down
7 changes: 7 additions & 0 deletions packages/relay/src/autotask/index.test.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
import { AutotaskRelayer } from '.';
import Lambda from 'aws-sdk/clients/lambda';
import { Lambda as LambdaV3 } from '../__mocks__/@aws-sdk/client-lambda';
jest.mock('node:process', () => ({
...jest.requireActual('node:process'),
version: 'v16.0.3',
}));

type TestAutotaskRelayer = Omit<AutotaskRelayer, 'lambda' | 'relayerARN'> & { lambda: Lambda; arn: string };

Expand All @@ -18,6 +23,8 @@ describe('AutotaskRelayer', () => {
let relayer: TestAutotaskRelayer;

beforeEach(async function () {
jest.mock('aws-sdk/clients/lambda', () => Lambda);
jest.mock('@aws-sdk/client-lambda', () => ({ Lambda: LambdaV3 }));
relayer = new AutotaskRelayer({
credentials: JSON.stringify(credentials),
relayerARN: 'arn',
Expand Down
Loading

0 comments on commit fee8405

Please sign in to comment.