From 9cebce7af22833e497527cabc83647fcddff176c Mon Sep 17 00:00:00 2001 From: Frank Date: Mon, 16 Sep 2024 17:29:26 -0400 Subject: [PATCH] js sdk: use aws4fetch in vector client --- sdk/js/package.json | 3 +- sdk/js/src/vector/index.ts | 133 ++++++++++++++++++++++++++----------- 2 files changed, 96 insertions(+), 40 deletions(-) diff --git a/sdk/js/package.json b/sdk/js/package.json index 06daec727..93a1de6ed 100644 --- a/sdk/js/package.json +++ b/sdk/js/package.json @@ -48,9 +48,8 @@ }, "optionalDependencies": {}, "dependencies": { - "@aws-sdk/client-lambda": "3.478.0", "aws4fetch": "^1.0.18", "jose": "5.2.3", "openid-client": "5.6.4" } -} \ No newline at end of file +} diff --git a/sdk/js/src/vector/index.ts b/sdk/js/src/vector/index.ts index 9958cd1f8..eb6deb5d5 100644 --- a/sdk/js/src/vector/index.ts +++ b/sdk/js/src/vector/index.ts @@ -1,11 +1,13 @@ -import { - LambdaClient, - InvokeCommand, - InvokeCommandOutput, -} from "@aws-sdk/client-lambda"; +import { AwsClient } from "aws4fetch"; import { Resource } from "../resource.js"; -const lambda = new LambdaClient(); +const client = new AwsClient({ + accessKeyId: process.env.AWS_ACCESS_KEY_ID!, + secretAccessKey: process.env.AWS_SECRET_ACCESS_KEY!, + sessionToken: process.env.AWS_SESSION_TOKEN!, + region: process.env.AWS_REGION!, +}); +const endpoint = `https://lambda.${process.env.AWS_REGION}.amazonaws.com/2015-03-31`; export interface PutEvent { /** @@ -254,54 +256,109 @@ export function VectorClient< ? never : key : never]: Resource[key]; - } + }, >(name: T): VectorClientResponse { return { put: async (event: PutEvent) => { - const ret = await lambda.send( - new InvokeCommand({ - // @ts-expect-error - FunctionName: Resource[name].putFunction, - Payload: JSON.stringify(event), - }) + await invokeFunction( + // @ts-expect-error + Resource[name].putFunction, + JSON.stringify(event), + "Failed to store into the vector db" ); - - parsePayload(ret, "Failed to store into the vector db"); }, query: async (event: QueryEvent) => { - const ret = await lambda.send( - new InvokeCommand({ - // @ts-expect-error - FunctionName: Resource[name].queryFunction, - Payload: JSON.stringify(event), - }) + return await invokeFunction( + // @ts-expect-error + Resource[name].queryFunction, + JSON.stringify(event), + "Failed to query the vector db" ); - return parsePayload(ret, "Failed to query the vector db"); }, remove: async (event: RemoveEvent) => { - const ret = await lambda.send( - new InvokeCommand({ - // @ts-expect-error - FunctionName: Resource[name].removeFunction, - Payload: JSON.stringify(event), - }) + await invokeFunction( + // @ts-expect-error + Resource[name].removeFunction, + JSON.stringify(event), + "Failed to remove from the vector db" ); - parsePayload(ret, "Failed to remove from the vector db"); }, }; } -function parsePayload(output: InvokeCommandOutput, message: string): T { - const payload = JSON.parse(Buffer.from(output.Payload!).toString()); +async function invokeFunction( + functionName: string, + body: string, + errorMessage: string, + attempts = 0 +): Promise { + try { + const response = await client.fetch( + `${endpoint}/functions/${functionName}/invocations`, + { + method: "POST", + headers: { Accept: "application/json" }, + body, + } + ); - // Set cause to the payload so that it can be logged in CloudWatch - if (output.FunctionError) { - const e = new Error(message); - e.cause = payload; - throw e; - } + // success + if (response.status === 200 || response.status === 201) { + if (response.headers.get("content-length") === "0") return undefined as T; + const text = await response.text(); + try { + return JSON.parse(text); + } catch (e) { + throw new Error(`Failed to parse JSON response: ${text}`); + } + } - return payload; + // error + const error = new Error(); + const text = await response.text(); + try { + const json = JSON.parse(text); + error.name = json.Error?.Code; + error.message = json.Error?.Message ?? json.message ?? text; + } catch (e) { + error.message = text; + } + error.name = error.name ?? response.headers.get("x-amzn-ErrorType"); + // @ts-expect-error + error.requestID = response.headers.get("x-amzn-RequestId"); + // @ts-expect-error + error.statusCode = response.status; + throw error; + } catch (e: any) { + let isRetryable = false; + + // AWS throttling errors => retry + if ( + [ + "ThrottlingException", + "Throttling", + "TooManyRequestsException", + "OperationAbortedException", + "TimeoutError", + "NetworkingError", + ].includes(e.name) + ) { + isRetryable = true; + } + + if (!isRetryable) throw e; + + // retry + await new Promise((resolve) => + setTimeout(resolve, 1.5 ** attempts * 100 * Math.random()) + ); + return await invokeFunction( + functionName, + body, + errorMessage, + attempts + 1 + ); + } }