Skip to content

Commit

Permalink
js sdk: use aws4fetch in vector client
Browse files Browse the repository at this point in the history
  • Loading branch information
fwang committed Sep 16, 2024
1 parent d0d8b5f commit 9cebce7
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 40 deletions.
3 changes: 1 addition & 2 deletions sdk/js/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,8 @@
},
"optionalDependencies": {},
"dependencies": {
"@aws-sdk/client-lambda": "3.478.0",
"aws4fetch": "^1.0.18",
"jose": "5.2.3",
"openid-client": "5.6.4"
}
}
}
133 changes: 95 additions & 38 deletions sdk/js/src/vector/index.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import {
LambdaClient,
InvokeCommand,
InvokeCommandOutput,
} from "@aws-sdk/client-lambda";
import { AwsClient } from "aws4fetch";
import { Resource } from "../resource.js";

const lambda = new LambdaClient();
const client = new AwsClient({
accessKeyId: process.env.AWS_ACCESS_KEY_ID!,
secretAccessKey: process.env.AWS_SECRET_ACCESS_KEY!,
sessionToken: process.env.AWS_SESSION_TOKEN!,
region: process.env.AWS_REGION!,
});
const endpoint = `https://lambda.${process.env.AWS_REGION}.amazonaws.com/2015-03-31`;

export interface PutEvent {
/**
Expand Down Expand Up @@ -254,54 +256,109 @@ export function VectorClient<
? never
: key
: never]: Resource[key];
}
},
>(name: T): VectorClientResponse {
return {
put: async (event: PutEvent) => {
const ret = await lambda.send(
new InvokeCommand({
// @ts-expect-error
FunctionName: Resource[name].putFunction,
Payload: JSON.stringify(event),
})
await invokeFunction(
// @ts-expect-error
Resource[name].putFunction,
JSON.stringify(event),
"Failed to store into the vector db"
);

parsePayload(ret, "Failed to store into the vector db");
},

query: async (event: QueryEvent) => {
const ret = await lambda.send(
new InvokeCommand({
// @ts-expect-error
FunctionName: Resource[name].queryFunction,
Payload: JSON.stringify(event),
})
return await invokeFunction<QueryResponse>(
// @ts-expect-error
Resource[name].queryFunction,
JSON.stringify(event),
"Failed to query the vector db"
);
return parsePayload<QueryResponse>(ret, "Failed to query the vector db");
},

remove: async (event: RemoveEvent) => {
const ret = await lambda.send(
new InvokeCommand({
// @ts-expect-error
FunctionName: Resource[name].removeFunction,
Payload: JSON.stringify(event),
})
await invokeFunction(
// @ts-expect-error
Resource[name].removeFunction,
JSON.stringify(event),
"Failed to remove from the vector db"
);
parsePayload(ret, "Failed to remove from the vector db");
},
};
}

function parsePayload<T>(output: InvokeCommandOutput, message: string): T {
const payload = JSON.parse(Buffer.from(output.Payload!).toString());
async function invokeFunction<T>(
functionName: string,
body: string,
errorMessage: string,
attempts = 0
): Promise<T> {
try {
const response = await client.fetch(
`${endpoint}/functions/${functionName}/invocations`,
{
method: "POST",
headers: { Accept: "application/json" },
body,
}
);

// Set cause to the payload so that it can be logged in CloudWatch
if (output.FunctionError) {
const e = new Error(message);
e.cause = payload;
throw e;
}
// success
if (response.status === 200 || response.status === 201) {
if (response.headers.get("content-length") === "0") return undefined as T;
const text = await response.text();
try {
return JSON.parse(text);
} catch (e) {
throw new Error(`Failed to parse JSON response: ${text}`);
}
}

return payload;
// error
const error = new Error();
const text = await response.text();
try {
const json = JSON.parse(text);
error.name = json.Error?.Code;
error.message = json.Error?.Message ?? json.message ?? text;
} catch (e) {
error.message = text;
}
error.name = error.name ?? response.headers.get("x-amzn-ErrorType");
// @ts-expect-error
error.requestID = response.headers.get("x-amzn-RequestId");
// @ts-expect-error
error.statusCode = response.status;
throw error;
} catch (e: any) {
let isRetryable = false;

// AWS throttling errors => retry
if (
[
"ThrottlingException",
"Throttling",
"TooManyRequestsException",
"OperationAbortedException",
"TimeoutError",
"NetworkingError",
].includes(e.name)
) {
isRetryable = true;
}

if (!isRetryable) throw e;

// retry
await new Promise((resolve) =>
setTimeout(resolve, 1.5 ** attempts * 100 * Math.random())
);
return await invokeFunction<T>(
functionName,
body,
errorMessage,
attempts + 1
);
}
}

0 comments on commit 9cebce7

Please sign in to comment.