diff --git a/readme.MD b/readme.MD index ee58df6..b7b5540 100644 --- a/readme.MD +++ b/readme.MD @@ -146,5 +146,5 @@ the default [model] `gemini-1.5-pro` will be used. - [ ] `completions` -- [ ] `embeddings` +- [x] `embeddings` - [x] `models` diff --git a/src/worker.mjs b/src/worker.mjs index 2886643..8d96517 100644 --- a/src/worker.mjs +++ b/src/worker.mjs @@ -27,6 +27,10 @@ export default { assert(request.method === "POST"); return handleCompletions(await request.json(), apiKey) .catch(errHandler); + case pathname.endsWith("/embeddings"): + assert(request.method === "POST"); + return handleEmbeddings(await request.json(), apiKey) + .catch(errHandler); case pathname.endsWith("/models"): assert(request.method === "GET"); return handleModels(apiKey) @@ -92,6 +96,52 @@ async function handleModels (apiKey) { return new Response(body, { ...response, headers: fixCors(response.headers) }); } +const DEFAULT_EMBEDDINGS_MODEL = "text-embedding-004"; +async function handleEmbeddings (req, apiKey) { + if (typeof req.model !== "string") { + throw new HttpError("model is not specified", 400); + } + if (!Array.isArray(req.input)) { + req.input = [ req.input ]; + } + let model; + if (req.model.startsWith("models/")) { + model = req.model; + } else { + req.model = DEFAULT_EMBEDDINGS_MODEL; + model = "models/" + req.model; + } + const response = await fetch(`${BASE_URL}/${API_VERSION}/${model}:batchEmbedContents`, { + method: "POST", + headers: { + "Content-Type": "application/json", + "x-goog-api-key": apiKey, + "x-goog-api-client": API_CLIENT, + }, + body: JSON.stringify({ + "requests": req.input.map(text => ({ + model, + content: { parts: { text } }, + outputDimensionality: req.dimensions, + })) + }) + }); + let { body } = response; + if (response.ok) { + const { embeddings } = JSON.parse(await response.text()); + body = JSON.stringify({ + object: "list", + data: embeddings.map(({ values }, index) => ({ + object: "embedding", + index, + embedding: values, + })), + model: req.model, + }, null, " "); + } + return new Response(body, { ...response, headers: fixCors(response.headers) }); +} + const DEFAULT_MODEL = "gemini-1.5-pro-latest"; async function handleCompletions (req, apiKey) { const model = req.model?.startsWith("gemini-") ? req.model : DEFAULT_MODEL;