|
1 |
| -import { |
2 |
| - LambdaClient, |
3 |
| - InvokeCommand, |
4 |
| - InvokeCommandOutput, |
5 |
| -} from "@aws-sdk/client-lambda"; |
| 1 | +import { AwsClient } from "aws4fetch"; |
6 | 2 | import { Resource } from "../resource.js";
|
7 | 3 |
|
8 |
| -const lambda = new LambdaClient(); |
| 4 | +const client = new AwsClient({ |
| 5 | + accessKeyId: process.env.AWS_ACCESS_KEY_ID!, |
| 6 | + secretAccessKey: process.env.AWS_SECRET_ACCESS_KEY!, |
| 7 | + sessionToken: process.env.AWS_SESSION_TOKEN!, |
| 8 | + region: process.env.AWS_REGION!, |
| 9 | +}); |
| 10 | +const endpoint = `https://lambda.${process.env.AWS_REGION}.amazonaws.com/2015-03-31`; |
9 | 11 |
|
10 | 12 | export interface PutEvent {
|
11 | 13 | /**
|
@@ -254,54 +256,109 @@ export function VectorClient<
|
254 | 256 | ? never
|
255 | 257 | : key
|
256 | 258 | : never]: Resource[key];
|
257 |
| - } |
| 259 | + }, |
258 | 260 | >(name: T): VectorClientResponse {
|
259 | 261 | return {
|
260 | 262 | put: async (event: PutEvent) => {
|
261 |
| - const ret = await lambda.send( |
262 |
| - new InvokeCommand({ |
263 |
| - // @ts-expect-error |
264 |
| - FunctionName: Resource[name].putFunction, |
265 |
| - Payload: JSON.stringify(event), |
266 |
| - }) |
| 263 | + await invokeFunction( |
| 264 | + // @ts-expect-error |
| 265 | + Resource[name].putFunction, |
| 266 | + JSON.stringify(event), |
| 267 | + "Failed to store into the vector db" |
267 | 268 | );
|
268 |
| - |
269 |
| - parsePayload(ret, "Failed to store into the vector db"); |
270 | 269 | },
|
271 | 270 |
|
272 | 271 | query: async (event: QueryEvent) => {
|
273 |
| - const ret = await lambda.send( |
274 |
| - new InvokeCommand({ |
275 |
| - // @ts-expect-error |
276 |
| - FunctionName: Resource[name].queryFunction, |
277 |
| - Payload: JSON.stringify(event), |
278 |
| - }) |
| 272 | + return await invokeFunction<QueryResponse>( |
| 273 | + // @ts-expect-error |
| 274 | + Resource[name].queryFunction, |
| 275 | + JSON.stringify(event), |
| 276 | + "Failed to query the vector db" |
279 | 277 | );
|
280 |
| - return parsePayload<QueryResponse>(ret, "Failed to query the vector db"); |
281 | 278 | },
|
282 | 279 |
|
283 | 280 | remove: async (event: RemoveEvent) => {
|
284 |
| - const ret = await lambda.send( |
285 |
| - new InvokeCommand({ |
286 |
| - // @ts-expect-error |
287 |
| - FunctionName: Resource[name].removeFunction, |
288 |
| - Payload: JSON.stringify(event), |
289 |
| - }) |
| 281 | + await invokeFunction( |
| 282 | + // @ts-expect-error |
| 283 | + Resource[name].removeFunction, |
| 284 | + JSON.stringify(event), |
| 285 | + "Failed to remove from the vector db" |
290 | 286 | );
|
291 |
| - parsePayload(ret, "Failed to remove from the vector db"); |
292 | 287 | },
|
293 | 288 | };
|
294 | 289 | }
|
295 | 290 |
|
296 |
| -function parsePayload<T>(output: InvokeCommandOutput, message: string): T { |
297 |
| - const payload = JSON.parse(Buffer.from(output.Payload!).toString()); |
| 291 | +async function invokeFunction<T>( |
| 292 | + functionName: string, |
| 293 | + body: string, |
| 294 | + errorMessage: string, |
| 295 | + attempts = 0 |
| 296 | +): Promise<T> { |
| 297 | + try { |
| 298 | + const response = await client.fetch( |
| 299 | + `${endpoint}/functions/${functionName}/invocations`, |
| 300 | + { |
| 301 | + method: "POST", |
| 302 | + headers: { Accept: "application/json" }, |
| 303 | + body, |
| 304 | + } |
| 305 | + ); |
298 | 306 |
|
299 |
| - // Set cause to the payload so that it can be logged in CloudWatch |
300 |
| - if (output.FunctionError) { |
301 |
| - const e = new Error(message); |
302 |
| - e.cause = payload; |
303 |
| - throw e; |
304 |
| - } |
| 307 | + // success |
| 308 | + if (response.status === 200 || response.status === 201) { |
| 309 | + if (response.headers.get("content-length") === "0") return undefined as T; |
| 310 | + const text = await response.text(); |
| 311 | + try { |
| 312 | + return JSON.parse(text); |
| 313 | + } catch (e) { |
| 314 | + throw new Error(`Failed to parse JSON response: ${text}`); |
| 315 | + } |
| 316 | + } |
305 | 317 |
|
306 |
| - return payload; |
| 318 | + // error |
| 319 | + const error = new Error(); |
| 320 | + const text = await response.text(); |
| 321 | + try { |
| 322 | + const json = JSON.parse(text); |
| 323 | + error.name = json.Error?.Code; |
| 324 | + error.message = json.Error?.Message ?? json.message ?? text; |
| 325 | + } catch (e) { |
| 326 | + error.message = text; |
| 327 | + } |
| 328 | + error.name = error.name ?? response.headers.get("x-amzn-ErrorType"); |
| 329 | + // @ts-expect-error |
| 330 | + error.requestID = response.headers.get("x-amzn-RequestId"); |
| 331 | + // @ts-expect-error |
| 332 | + error.statusCode = response.status; |
| 333 | + throw error; |
| 334 | + } catch (e: any) { |
| 335 | + let isRetryable = false; |
| 336 | + |
| 337 | + // AWS throttling errors => retry |
| 338 | + if ( |
| 339 | + [ |
| 340 | + "ThrottlingException", |
| 341 | + "Throttling", |
| 342 | + "TooManyRequestsException", |
| 343 | + "OperationAbortedException", |
| 344 | + "TimeoutError", |
| 345 | + "NetworkingError", |
| 346 | + ].includes(e.name) |
| 347 | + ) { |
| 348 | + isRetryable = true; |
| 349 | + } |
| 350 | + |
| 351 | + if (!isRetryable) throw e; |
| 352 | + |
| 353 | + // retry |
| 354 | + await new Promise((resolve) => |
| 355 | + setTimeout(resolve, 1.5 ** attempts * 100 * Math.random()) |
| 356 | + ); |
| 357 | + return await invokeFunction<T>( |
| 358 | + functionName, |
| 359 | + body, |
| 360 | + errorMessage, |
| 361 | + attempts + 1 |
| 362 | + ); |
| 363 | + } |
307 | 364 | }
|
0 commit comments