From ef27ebce57c7371a62f7629d64cfae8f233cf435 Mon Sep 17 00:00:00 2001
From: Charlie Ruan <53290280+CharlieFRuan@users.noreply.github.com>
Date: Mon, 12 Aug 2024 02:22:59 -0400
Subject: [PATCH] [Embeddings][OpenAI] Support embeddings via
engine.embeddings.create()
---
examples/README.md | 1 +
examples/embeddings/README.md | 14 ++
examples/embeddings/package.json | 21 ++
examples/embeddings/src/embeddings.html | 23 ++
examples/embeddings/src/embeddings.ts | 147 ++++++++++++
src/config.ts | 55 ++++-
src/embedding.ts | 290 ++++++++++++++++++++++++
src/engine.ts | 107 +++++++--
src/error.ts | 62 ++++-
src/message.ts | 13 ++
src/openai_api_protocols/embedding.ts | 195 ++++++++++++++++
src/openai_api_protocols/index.ts | 8 +
src/support.ts | 11 +-
src/types.ts | 19 +-
src/web_worker.ts | 42 +++-
tests/openai_chat_completion.test.ts | 2 +-
tests/openai_embeddings.test.ts | 133 +++++++++++
17 files changed, 1112 insertions(+), 31 deletions(-)
create mode 100644 examples/embeddings/README.md
create mode 100644 examples/embeddings/package.json
create mode 100644 examples/embeddings/src/embeddings.html
create mode 100644 examples/embeddings/src/embeddings.ts
create mode 100644 src/embedding.ts
create mode 100644 src/openai_api_protocols/embedding.ts
create mode 100644 tests/openai_embeddings.test.ts
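For reference, a minimal usage sketch of the new API surface introduced by this patch (the model id comes from the prebuilt records added to `src/config.ts` below; treat this as an illustration rather than the canonical example — see `examples/embeddings` for the full demo):

```ts
import * as webllm from "@mlc-ai/web-llm";

async function demo() {
  // Any ModelRecord with model_type === ModelType.embedding works here.
  const engine = await webllm.CreateMLCEngine(
    "snowflake-arctic-embed-m-q0f32-MLC-b4",
  );

  // OpenAI-style call; input may be a string, string[], number[], or number[][].
  const reply = await engine.embeddings.create({
    input: ["[CLS] The Data Cloud! [SEP]", "[CLS] Mexico City of Course! [SEP]"],
  });

  console.log(reply.data[0].embedding.length); // vector length = model hidden size
  console.log(reply.usage); // prompt_tokens, total_tokens, extra.prefill_tokens_per_s
}

demo();
```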
diff --git a/examples/README.md b/examples/README.md
index d0ffefd9..ee3b16f6 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -24,6 +24,7 @@ Note that all examples below run in-browser and use WebGPU as a backend.
- [next-simple-chat](next-simple-chat): a minimal and complete chat bot app with [Next.js](https://nextjs.org/).
- [multi-round-chat](multi-round-chat): while APIs are functional, we internally optimize so that multi round chat usage can reuse KV cache
- [text-completion](text-completion): demonstrates API `engine.completions.create()`, which is pure text completion with no conversation, as opposed to `engine.chat.completions.create()`
+- [embeddings](embeddings): demonstrates API `engine.embeddings.create()`, and integration with `EmbeddingsInterface` and `MemoryVectorStore` of Langchain.js
#### Advanced OpenAI API Capabilities
diff --git a/examples/embeddings/README.md b/examples/embeddings/README.md
new file mode 100644
index 00000000..7450aad8
--- /dev/null
+++ b/examples/embeddings/README.md
@@ -0,0 +1,14 @@
+# WebLLM Embeddings Example
+
+This folder provides a minimal demo that shows the WebLLM embeddings API in a webapp setting.
+To try it out, run the following steps under this folder:
+
+```bash
+npm install
+npm start
+```
+
+Note: if you would like to hack on the WebLLM core package, you can change the web-llm
+dependency to `"file:../.."` and follow the build-from-source instructions in the project
+to build WebLLM locally. This option is only recommended if you intend to modify the
+WebLLM core package.
diff --git a/examples/embeddings/package.json b/examples/embeddings/package.json
new file mode 100644
index 00000000..b363dc21
--- /dev/null
+++ b/examples/embeddings/package.json
@@ -0,0 +1,21 @@
+{
+ "name": "embeddings-example",
+ "version": "0.1.0",
+ "private": true,
+ "scripts": {
+ "start": "parcel src/embeddings.html --port 8885",
+ "build": "parcel build src/embeddings.html --dist-dir lib"
+ },
+ "devDependencies": {
+ "buffer": "^5.7.1",
+ "parcel": "^2.8.3",
+ "process": "^0.11.10",
+ "tslib": "^2.3.1",
+ "typescript": "^4.9.5",
+ "url": "^0.11.3"
+ },
+ "dependencies": {
+ "@mlc-ai/web-llm": "file:../..",
+ "langchain": "0.2.15"
+ }
+}
diff --git a/examples/embeddings/src/embeddings.html b/examples/embeddings/src/embeddings.html
new file mode 100644
index 00000000..484ee7c3
--- /dev/null
+++ b/examples/embeddings/src/embeddings.html
@@ -0,0 +1,23 @@
+<!doctype html>
+<html>
+  <head>
+    <title>WebLLM Test Page</title>
+  </head>
+
+  <body>
+    <h2>WebLLM Test Page</h2>
+    Open console to see output
+    <br />
+    <br />
+    <label id="init-label"> </label>
+
+    <h3>Prompt</h3>
+    <label id="prompt-label"> </label>
+
+    <h3>Response</h3>
+    <label id="generate-label"> </label>
+
+    <script type="module" src="./embeddings.ts"></script>
+  </body>
+</html>
diff --git a/examples/embeddings/src/embeddings.ts b/examples/embeddings/src/embeddings.ts
new file mode 100644
index 00000000..b8ff521c
--- /dev/null
+++ b/examples/embeddings/src/embeddings.ts
@@ -0,0 +1,147 @@
+import * as webllm from "@mlc-ai/web-llm";
+import { MemoryVectorStore } from "langchain/vectorstores/memory";
+import type { EmbeddingsInterface } from "@langchain/core/embeddings";
+import type { Document } from "@langchain/core/documents";
+
+function setLabel(id: string, text: string) {
+ const label = document.getElementById(id);
+ if (label == null) {
+ throw Error("Cannot find label " + id);
+ }
+ label.innerText = text;
+}
+
+const initProgressCallback = (report: webllm.InitProgressReport) => {
+ setLabel("init-label", report.text);
+};
+
+// For integration with Langchain
+class WebLLMEmbeddings implements EmbeddingsInterface {
+ engine: webllm.MLCEngineInterface;
+ constructor(engine: webllm.MLCEngineInterface) {
+ this.engine = engine;
+ }
+
+ async _embed(texts: string[]): Promise<number[][]> {
+ const reply = await this.engine.embeddings.create({ input: texts });
+ const result: number[][] = [];
+ for (let i = 0; i < texts.length; i++) {
+ result.push(reply.data[i].embedding);
+ }
+ return result;
+ }
+
+ async embedQuery(document: string): Promise<number[]> {
+ return this._embed([document]).then((embeddings) => embeddings[0]);
+ }
+
+ async embedDocuments(documents: string[]): Promise<number[][]> {
+ return this._embed(documents);
+ }
+}
+
+// Prepare inputs
+const documents_og = ["The Data Cloud!", "Mexico City of Course!"];
+const queries_og = ["what is snowflake?", "Where can I get the best tacos?"];
+const documents: string[] = [];
+const queries: string[] = [];
+const query_prefix =
+ "Represent this sentence for searching relevant passages: ";
+// Process according to Snowflake model
+documents_og.forEach(function (item, index) {
+ documents[index] = `[CLS] ${item} [SEP]`;
+});
+queries_og.forEach(function (item, index) {
+ queries[index] = `[CLS] ${query_prefix}${item} [SEP]`;
+});
+console.log("Formatted documents: ", documents);
+console.log("Formatted queries: ", queries);
+
+// Using webllm's API
+async function webllmAPI() {
+ // b4 means the max batch size is compiled as 4. That is, the model can process 4 inputs in a
+ // batch. If given more than 4, the model will forward multiple times. The larger the max batch
+ // size, the more memory it consumes.
+ // const selectedModel = "snowflake-arctic-embed-m-q0f32-MLC-b32";
+ const selectedModel = "snowflake-arctic-embed-m-q0f32-MLC-b4";
+ const engine: webllm.MLCEngineInterface = await webllm.CreateMLCEngine(
+ selectedModel,
+ {
+ initProgressCallback: initProgressCallback,
+ logLevel: "INFO", // specify the log level
+ },
+ );
+
+ const docReply = await engine.embeddings.create({ input: documents });
+ console.log(docReply);
+ console.log(docReply.usage);
+
+ const queryReply = await engine.embeddings.create({ input: queries });
+ console.log(queryReply);
+ console.log(queryReply.usage);
+
+ // Calculate similarity (we use langchain here, but any method works)
+ const vectorStore = await MemoryVectorStore.fromExistingIndex(
+ new WebLLMEmbeddings(engine),
+ );
+ // See score
+ for (let i = 0; i < queries_og.length; i++) {
+ console.log(`Similarity with: ${queries_og[i]}`);
+ for (let j = 0; j < documents_og.length; j++) {
+ const similarity = vectorStore.similarity(
+ queryReply.data[i].embedding,
+ docReply.data[j].embedding,
+ );
+ console.log(`${documents_og[j]}: ${similarity}`);
+ }
+ }
+}
+
+// Alternatively, integrating with Langchain's API
+async function langchainAPI() {
+ // b4 means the max batch size is compiled as 4. That is, the model can process 4 inputs in a
+ // batch. If given more than 4, the model will forward multiple times. The larger the max batch
+ // size, the more memory it consumes.
+ // const selectedModel = "snowflake-arctic-embed-m-q0f32-MLC-b32";
+ const selectedModel = "snowflake-arctic-embed-m-q0f32-MLC-b4";
+ const engine: webllm.MLCEngineInterface = await webllm.CreateMLCEngine(
+ selectedModel,
+ {
+ initProgressCallback: initProgressCallback,
+ logLevel: "INFO", // specify the log level
+ },
+ );
+
+ const vectorStore = await MemoryVectorStore.fromExistingIndex(
+ new WebLLMEmbeddings(engine),
+ );
+ const document0: Document = {
+ pageContent: documents[0],
+ metadata: {},
+ };
+ const document1: Document = {
+ pageContent: documents[1],
+ metadata: {},
+ };
+ await vectorStore.addDocuments([document0, document1]);
+
+ const similaritySearchResults0 = await vectorStore.similaritySearch(
+ queries[0],
+ 1,
+ );
+ for (const doc of similaritySearchResults0) {
+ console.log(`* ${doc.pageContent}`);
+ }
+
+ const similaritySearchResults1 = await vectorStore.similaritySearch(
+ queries[1],
+ 1,
+ );
+ for (const doc of similaritySearchResults1) {
+ console.log(`* ${doc.pageContent}`);
+ }
+}
+
+// Select one to run
+webllmAPI();
+// langchainAPI();
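The example above computes scores through Langchain's `MemoryVectorStore.similarity()`, but as the comment notes, any method works. A dependency-free sketch of the same cosine-similarity computation, for readers who only want the raw embeddings:

```ts
// Plain cosine similarity between two embedding vectors returned by
// engine.embeddings.create(); equivalent to the vectorStore.similarity() call above.
function cosineSimilarity(a: number[], b: number[]): number {
  let dot = 0;
  let normA = 0;
  let normB = 0;
  for (let i = 0; i < a.length; i++) {
    dot += a[i] * b[i];
    normA += a[i] * a[i];
    normB += b[i] * b[i];
  }
  return dot / (Math.sqrt(normA) * Math.sqrt(normB));
}

// e.g. cosineSimilarity(queryReply.data[i].embedding, docReply.data[j].embedding)
```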
diff --git a/src/config.ts b/src/config.ts
index 5aae7693..6c81febf 100644
--- a/src/config.ts
+++ b/src/config.ts
@@ -68,10 +68,10 @@ export interface TokenizerInfo {
* Only these fields affect the conversation in runtime.
* i.e. The third part in https://llm.mlc.ai/docs/get_started/mlc_chat_config.html.
*
- * This is initialized in `ChatModule.reload()` with the model's `mlc-chat-config.json`.
+ * This is initialized in `MLCEngine.reload()` with the model's `mlc-chat-config.json`.
*/
export interface ChatConfig {
- // First three fields affect the entire conversation, i.e. used in `ChatModule.reload()`
+ // First three fields affect the entire conversation, i.e. used in `MLCEngine.reload()`
tokenizer_files: Array<string>;
tokenizer_info?: TokenizerInfo;
token_table_postproc_method?: string; // TODO: backward compatibility, remove soon
@@ -122,7 +122,7 @@ export interface MLCEngineConfig {
* We also support additional fields not present in `mlc-chat-config.json` due to OpenAI-like APIs.
*
* Note that all values are optional. If unspecified, we use whatever values in `ChatConfig`
- * initialized during `ChatModule.reload()`.
+ * initialized during `MLCEngine.reload()`.
*/
export interface GenerationConfig {
// Only used in MLC
@@ -226,6 +226,11 @@ export function postInitAndCheckGenerationConfigValues(
}
}
+export enum ModelType {
+ "LLM",
+ "embedding",
+}
+
/**
* Information for a model.
* @param model: the huggingface link to download the model weights, accepting four formats:
@@ -241,6 +246,7 @@ export function postInitAndCheckGenerationConfigValues(
* @param low_resource_required: whether the model can run on limited devices (e.g. Android phone).
* @param buffer_size_required_bytes: required `maxStorageBufferBindingSize`, different for each device.
* @param required_features: feature needed to run this model (e.g. shader-f16).
+ * @param model_type: the intended use case for the model; if unspecified, defaults to LLM.
*/
export interface ModelRecord {
model: string;
@@ -251,6 +257,7 @@ export interface ModelRecord {
low_resource_required?: boolean;
buffer_size_required_bytes?: number;
required_features?: Array<string>;
+ model_type?: ModelType;
}
/**
@@ -1514,5 +1521,47 @@ export const prebuiltAppConfig: AppConfig = {
context_window_size: 1024,
},
},
+ // Embedding models
+ // The -b suffix denotes the max_batch_size this model allows. The smaller it is, the less memory the model consumes.
+ {
+ model: "https://huggingface.co/mlc-ai/snowflake-arctic-embed-m-q0f32-MLC",
+ model_id: "snowflake-arctic-embed-m-q0f32-MLC-b32",
+ model_lib:
+ modelLibURLPrefix +
+ modelVersion +
+ "/snowflake-arctic-embed-m-q0f32-ctx512_cs512_batch32-webgpu.wasm",
+ vram_required_MB: 1407.51,
+ model_type: ModelType.embedding,
+ },
+ {
+ model: "https://huggingface.co/mlc-ai/snowflake-arctic-embed-m-q0f32-MLC",
+ model_id: "snowflake-arctic-embed-m-q0f32-MLC-b4",
+ model_lib:
+ modelLibURLPrefix +
+ modelVersion +
+ "/snowflake-arctic-embed-m-q0f32-ctx512_cs512_batch4-webgpu.wasm",
+ vram_required_MB: 539.4,
+ model_type: ModelType.embedding,
+ },
+ {
+ model: "https://huggingface.co/mlc-ai/snowflake-arctic-embed-s-q0f32-MLC",
+ model_id: "snowflake-arctic-embed-s-q0f32-MLC-b32",
+ model_lib:
+ modelLibURLPrefix +
+ modelVersion +
+ "/snowflake-arctic-embed-s-q0f32-ctx512_cs512_batch32-webgpu.wasm",
+ vram_required_MB: 1022.82,
+ model_type: ModelType.embedding,
+ },
+ {
+ model: "https://huggingface.co/mlc-ai/snowflake-arctic-embed-s-q0f32-MLC",
+ model_id: "snowflake-arctic-embed-s-q0f32-MLC-b4",
+ model_lib:
+ modelLibURLPrefix +
+ modelVersion +
+ "/snowflake-arctic-embed-s-q0f32-ctx512_cs512_batch4-webgpu.wasm",
+ vram_required_MB: 238.71,
+ model_type: ModelType.embedding,
+ },
],
};
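Beyond the prebuilt records above, an application can register its own embedding model by setting `model_type: ModelType.embedding` on a `ModelRecord`. A sketch under the assumption that `ModelType`, `AppConfig`, and `prebuiltAppConfig` are exported from the package entry point; the URLs are placeholders, not real artifacts:

```ts
import * as webllm from "@mlc-ai/web-llm";

// Hypothetical custom AppConfig; replace the placeholder URLs with real ones.
const appConfig: webllm.AppConfig = {
  model_list: [
    ...webllm.prebuiltAppConfig.model_list,
    {
      model: "https://huggingface.co/my-org/my-embed-model-q0f32-MLC", // placeholder
      model_id: "my-embed-model-q0f32-MLC-b4",
      model_lib: "https://example.com/my-embed-model-batch4-webgpu.wasm", // placeholder
      model_type: webllm.ModelType.embedding, // routes loading to EmbeddingPipeline
    },
  ],
};

// const engine = await webllm.CreateMLCEngine("my-embed-model-q0f32-MLC-b4", { appConfig });
```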
diff --git a/src/embedding.ts b/src/embedding.ts
new file mode 100644
index 00000000..c620fb26
--- /dev/null
+++ b/src/embedding.ts
@@ -0,0 +1,290 @@
+import * as tvmjs from "tvmjs";
+import log from "loglevel";
+import { Tokenizer } from "@mlc-ai/web-tokenizers";
+import { ChatConfig } from "./config";
+import {
+ EmbeddingChunkingUnsupportedError,
+ EmbeddingExceedContextWindowSizeError,
+ EmbeddingInputEmptyError,
+ EmbeddingSlidingWindowError,
+ MinValueError,
+} from "./error";
+
+export class EmbeddingPipeline {
+ private config: ChatConfig;
+ private tokenizer: Tokenizer;
+
+ // TVM functions
+ private tvm: tvmjs.Instance;
+ private device: tvmjs.DLDevice;
+ private vm: tvmjs.VirtualMachine;
+ private prefill: tvmjs.PackedFunc;
+ private params: tvmjs.TVMObject;
+
+ // metadata
+ private contextWindowSize = -1;
+ private prefillChunkSize = -1;
+ private maxBatchSize = -1;
+
+ // performance
+ private curRoundEmbedTotalTokens = 0; // excludes padded tokens for batching
+ private curRoundEmbedTotalTime = 0;
+
+ constructor(tvm: tvmjs.Instance, tokenizer: Tokenizer, config: ChatConfig) {
+ // 0. Setting attributes
+ this.tvm = tvm;
+ this.tokenizer = tokenizer;
+ this.config = config;
+ this.device = this.tvm.webgpu();
+
+ // 1. Create VM and get the core functions
+ tvm.beginScope();
+ this.vm = this.tvm.detachFromCurrentScope(
+ this.tvm.createVirtualMachine(this.device),
+ );
+ this.prefill = this.tvm.detachFromCurrentScope(
+ this.vm.getFunction("prefill"),
+ );
+
+ // 2. Get json stored in the vm's metadata function
+ const fgetMetadata = this.vm.getFunction("_metadata");
+ const ret_value = fgetMetadata();
+ const metadataStr = this.tvm.detachFromCurrentScope(ret_value).toString();
+ const metadata = JSON.parse(metadataStr);
+
+ // 3. Load parameters by name
+ const paramNames: string[] = [];
+ metadata.params.forEach((param: any) => {
+ paramNames.push(param.name);
+ });
+ this.params = this.tvm.detachFromCurrentScope(
+ this.tvm.getParamsFromCacheByName(paramNames),
+ );
+
+ // 4. Read in compilation configurations from metadata
+ // We use context window size and max batch size to check the validity of the model
+ // We assume prefillChunkSize is the same as contextWindowSize for embedding models for now
+ this.maxBatchSize = metadata.max_batch_size;
+ this.contextWindowSize = this.config.context_window_size;
+ this.prefillChunkSize = metadata.prefill_chunk_size;
+ log.info("Using maxBatchSize: ", this.maxBatchSize);
+ log.info("Using contextWindowSize: ", this.contextWindowSize);
+ log.info("Using prefillChunkSize: ", this.prefillChunkSize);
+
+ if (this.config.sliding_window_size !== -1) {
+ throw new EmbeddingSlidingWindowError(this.config.sliding_window_size);
+ }
+ if (this.maxBatchSize <= 0) {
+ throw new MinValueError("maxBatchSize", 0);
+ }
+ if (this.contextWindowSize <= 0) {
+ throw new MinValueError("contextWindowSize", 0);
+ }
+ if (this.prefillChunkSize <= 0) {
+ throw new MinValueError("prefillChunkSize", 0);
+ }
+ if (this.prefillChunkSize !== this.contextWindowSize) {
+ throw new EmbeddingChunkingUnsupportedError(
+ this.contextWindowSize,
+ this.prefillChunkSize,
+ );
+ }
+ tvm.endScope();
+ }
+
+ async embedStep(
+ input: string | Array<string> | Array<number> | Array<Array<number>>,
+ ): Promise<Array<Array<number>>> {
+ // 0. Reset performance metrics
+ this.curRoundEmbedTotalTokens = 0;
+ this.curRoundEmbedTotalTime = 0;
+ let totalNumTokens = 0;
+ const embedStart = performance.now();
+ let tokenizedInputs: Array<Array<number>> = [];
+ const tempInputs: Array<number> = [];
+ // 1. Convert all possible input types to Array<Array<number>>, tokenize if not already
+ // Cannot use input.every to match type, which leads to TS compilation error
+ // https://github.com/microsoft/TypeScript/issues/33591
+ if (input.length === 0) {
+ throw new EmbeddingInputEmptyError();
+ }
+ if (typeof input === "string") {
+ // string
+ tokenizedInputs = [Array.from(this.tokenizer.encode(input))];
+ } else {
+ for (let i = 0; i < input.length; i++) {
+ const curInput = input[i];
+ if (Array.isArray(curInput)) {
+ // Array<Array<number>>
+ tokenizedInputs.push(curInput);
+ } else if (typeof curInput === "string") {
+ // Array<string>
+ tokenizedInputs.push(Array.from(this.tokenizer.encode(curInput)));
+ } else {
+ // Array<number>
+ tempInputs.push(curInput);
+ }
+ }
+ }
+ if (tempInputs.length > 0) {
+ tokenizedInputs.push(tempInputs);
+ }
+
+ // 2. Check each input is not larger than the context window size
+ // TODO: tokenizer.encode seems to implicitly truncate to contextWindowSize; confirm behavior
+ // and decide whether to warn the user
+ for (let i = 0; i < tokenizedInputs.length; i++) {
+ const curInputSize = tokenizedInputs[i].length;
+ totalNumTokens += curInputSize;
+ if (curInputSize > this.contextWindowSize) {
+ throw new EmbeddingExceedContextWindowSizeError(
+ this.contextWindowSize,
+ curInputSize,
+ );
+ }
+ }
+ if (tokenizedInputs.length === 0) {
+ throw new Error("InternalError: batch size is zero.");
+ }
+
+ // 3. Forward each batch
+ const batchSize = tokenizedInputs.length;
+ const result: Array<Array<number>> = [];
+ for (let begin = 0; begin < batchSize; begin += this.maxBatchSize) {
+ this.tvm.beginScope();
+ // 3.1 Get current batch
+ const end = Math.min(batchSize, begin + this.maxBatchSize);
+ const curBatch: Array<Array<number>> = tokenizedInputs.slice(begin, end);
+ const curBatchSize = curBatch.length;
+ // 3.2 Max input size of current batch
+ let maxInputSize = 0;
+ for (let i = 0; i < curBatchSize; i++) {
+ const curInputSize = curBatch[i].length;
+ if (curInputSize > maxInputSize) {
+ maxInputSize = curInputSize;
+ }
+ }
+ // 3.3 Create inputs and attention mask
+ // Padded with zeros and flattened, of size curBatchSize * maxInputSize
+ const curBatchPaddedFlatten: Array<number> = [];
+ // 1 for non-pad, 0 otherwise, also of size curBatchSize * maxInputSize
+ const curAttnMask: Array<number> = [];
+ const flattenedInputSize = curBatchSize * maxInputSize;
+ for (let i = 0; i < curBatchSize; i++) {
+ const padding = Array(maxInputSize - curBatch[i].length).fill(0);
+ const ones = Array(curBatch[i].length).fill(1);
+ curBatchPaddedFlatten.push(...curBatch[i]);
+ curAttnMask.push(...ones);
+ curBatchPaddedFlatten.push(...padding);
+ curAttnMask.push(...padding);
+ }
+ if (
+ curBatchPaddedFlatten.length !== flattenedInputSize ||
+ curAttnMask.length !== flattenedInputSize
+ ) {
+ throw new Error(
+ `InternalError: Expect input array to be ${flattenedInputSize}, ` +
+ `but got ${curBatchPaddedFlatten.length}`,
+ );
+ }
+ // 3.4 Convert inputs and attention mask to tvm ndarray on GPU, of shape (curBatchSize, maxInputSize)
+ let inputNDArray = this.tvm.empty(
+ [flattenedInputSize],
+ "int32",
+ this.device,
+ );
+ inputNDArray.copyFrom(curBatchPaddedFlatten);
+ inputNDArray = inputNDArray.view([curBatchSize, maxInputSize]);
+ let maskNDArray = this.tvm.empty(
+ [flattenedInputSize],
+ "int32",
+ this.device,
+ );
+ maskNDArray.copyFrom(curAttnMask);
+ maskNDArray = maskNDArray.view([curBatchSize, maxInputSize]);
+
+ // 3.5 Actual forwarding on GPU, logits of shape (curBatchSize, maxInputSize, hidden_size)
+ const logitsCurBatchOnGPU: tvmjs.NDArray = this.prefill(
+ inputNDArray,
+ maskNDArray,
+ this.params,
+ );
+ await this.device.sync();
+
+ // 3.6 Copy logits to CPU, flatten to curBatchSize * maxInputSize * hidden_size
+ const hidden_size = logitsCurBatchOnGPU.shape[2];
+ let logitsCurBatchOnCPU: tvmjs.NDArray = this.tvm.empty(
+ logitsCurBatchOnGPU.shape,
+ logitsCurBatchOnGPU.dtype,
+ this.tvm.cpu(),
+ );
+ logitsCurBatchOnCPU.copyFrom(logitsCurBatchOnGPU);
+ logitsCurBatchOnCPU = logitsCurBatchOnCPU.view([
+ curBatchSize * maxInputSize * hidden_size,
+ ]);
+ await this.device.sync();
+ const logitsCurBatchOnCPUArray: Float32Array = <Float32Array>(
+ logitsCurBatchOnCPU.toArray()
+ );
+
+ // 3.7 Update final result. For each sentence, get [0,:], i.e. only the first token's output
+ // That is, we are doing result.push(logits[:,0,:]) here.
+ // TODO: check if all models only use [0,:]. If it is snowflake-specific, need to specify
+ // this in mlc-chat-config.json
+ for (let i = 0; i < curBatchSize; i++) {
+ const b = i * maxInputSize * hidden_size;
+ const e = b + hidden_size;
+ result.push(Array.from(logitsCurBatchOnCPUArray.slice(b, e)));
+ }
+ this.tvm.endScope();
+ }
+ if (result.length !== batchSize) {
+ throw new Error(`
+ InternalError: expect result.length to be ${batchSize}, but got ${result.length}`);
+ }
+ const embedEnd = performance.now();
+ this.curRoundEmbedTotalTokens = totalNumTokens;
+ this.curRoundEmbedTotalTime = (embedEnd - embedStart) / 1e3;
+
+ return result;
+ }
+
+ dispose() {
+ this.params.dispose();
+ this.prefill.dispose();
+ this.vm.dispose();
+ this.tvm.dispose();
+ this.tokenizer.dispose();
+ }
+
+ /**
+ * Synchronize the device.
+ */
+ async sync(): Promise<void> {
+ // Is it equivalent to this.tvm.sync()?
+ await this.device.sync();
+ }
+
+ // Performance APIs below
+
+ /**
+ * Get the time it took the last `embedStep()` in seconds.
+ */
+ getCurRoundEmbedTotalTime(): number {
+ return this.curRoundEmbedTotalTime;
+ }
+
+ /**
+ * Get the number of tokens embedded in the last `embedStep()`. This excludes the padded tokens.
+ */
+ getCurRoundEmbedTotalTokens(): number {
+ return this.curRoundEmbedTotalTokens;
+ }
+
+ /**
+ * @returns Prefill tokens per second, starting from the last prefill performed.
+ */
+ getCurRoundEmbedTokensPerSec(): number {
+ return this.curRoundEmbedTotalTokens / this.curRoundEmbedTotalTime;
+ }
+}
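To make step 3.3 of `embedStep()` concrete, here is a small standalone sketch (independent of TVM) of the right-padding and 0/1 attention-mask construction it performs before forwarding a batch; `padBatch` is an illustrative helper, not part of the patch:

```ts
// Right-pad every tokenized input in a batch to the longest length with 0s and
// build a flattened attention mask (1 for real tokens, 0 for padding), mirroring
// how embedStep() prepares curBatchPaddedFlatten and curAttnMask.
function padBatch(batch: number[][]): { inputs: number[]; mask: number[]; width: number } {
  const width = Math.max(...batch.map((seq) => seq.length));
  const inputs: number[] = [];
  const mask: number[] = [];
  for (const seq of batch) {
    const padLen = width - seq.length;
    inputs.push(...seq, ...Array(padLen).fill(0));
    mask.push(...Array(seq.length).fill(1), ...Array(padLen).fill(0));
  }
  return { inputs, mask, width }; // inputs.length === mask.length === batch.length * width
}

// padBatch([[101, 7592, 102], [101, 102]])
// => inputs: [101, 7592, 102, 101, 102, 0], mask: [1, 1, 1, 1, 1, 0], width: 3
```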
diff --git a/src/engine.ts b/src/engine.ts
index fceb7232..342cc3cf 100644
--- a/src/engine.ts
+++ b/src/engine.ts
@@ -10,6 +10,7 @@ import {
Role,
MLCEngineConfig,
DefaultLogLevel,
+ ModelType,
} from "./config";
import { LLMChatPipeline } from "./llm_chat";
import {
@@ -31,6 +32,9 @@ import {
CompletionCreateParams,
Completion,
CompletionChoice,
+ EmbeddingCreateParams,
+ CreateEmbeddingResponse,
+ Embedding,
} from "./openai_api_protocols/index";
import * as API from "./openai_api_protocols/index";
import {
@@ -45,19 +49,24 @@ import {
getConversation,
getConversationFromChatCompletionRequest,
} from "./conversation";
-import { cleanModelUrl, getToolCallFromOutputMessage } from "./support";
import {
- ChatModuleNotInitializedError,
+ cleanModelUrl,
+ findModelRecord,
+ getToolCallFromOutputMessage,
+} from "./support";
+import {
+ EngineNotLoadedError,
ConfigurationNotInitializedError,
DeviceLostError,
+ EmbeddingUnsupportedModelError,
FeatureSupportError,
MissingModelWasmError,
- ModelNotFoundError,
ModelNotLoadedError,
ShaderF16SupportError,
WebGPUNotAvailableError,
} from "./error";
import { asyncLoadTokenizer } from "./cache_util";
+import { EmbeddingPipeline } from "./embedding";
/**
* Creates `MLCEngine`, and loads `modelId` onto WebGPU.
@@ -93,12 +102,15 @@ export class MLCEngine implements MLCEngineInterface {
public chat: API.Chat;
/** For completions.create() */
public completions: API.Completions;
+ /** For embeddings.create() */
+ public embeddings: API.Embeddings;
private currentModelId?: string = undefined; // Model current loaded, undefined if nothing is loaded
private logger: (msg: string) => void = log.info;
private logitProcessorRegistry?: Map<string, LogitProcessor>;
private logitProcessor?: LogitProcessor;
private pipeline?: LLMChatPipeline;
+ private embeddingPipeline?: EmbeddingPipeline;
private initProgressCallback?: InitProgressCallback;
private interruptSignal = false;
private deviceLostIsError = true; // whether device.lost is due to actual error or model reload
@@ -114,6 +126,7 @@ export class MLCEngine implements MLCEngineInterface {
this.chat = new API.Chat(this);
this.completions = new API.Completions(this);
+ this.embeddings = new API.Embeddings(this);
}
//-----------------------
@@ -174,15 +187,7 @@ export class MLCEngine implements MLCEngineInterface {
this.logitProcessor = this.logitProcessorRegistry?.get(modelId);
const tstart = performance.now();
- const findModelRecord = () => {
- const matchedItem = this.appConfig?.model_list.find(
- (item) => item.model_id == modelId,
- );
- if (matchedItem !== undefined) return matchedItem;
- throw new ModelNotFoundError(modelId);
- };
-
- const modelRecord = findModelRecord();
+ const modelRecord = findModelRecord(modelId, this.appConfig);
const baseUrl =
typeof document !== "undefined"
? document.URL
@@ -305,12 +310,20 @@ export class MLCEngine implements MLCEngineInterface {
cacheType,
this.reloadController?.signal,
);
- this.pipeline = new LLMChatPipeline(
- tvm,
- tokenizer,
- this.config,
- this.logitProcessor,
- );
+ if (modelRecord.model_type === ModelType.embedding) {
+ this.embeddingPipeline = new EmbeddingPipeline(
+ tvm,
+ tokenizer,
+ this.config,
+ );
+ } else {
+ this.pipeline = new LLMChatPipeline(
+ tvm,
+ tokenizer,
+ this.config,
+ this.logitProcessor,
+ );
+ }
await this.pipeline?.asyncLoadWebGPUPipelines();
const tend = performance.now();
@@ -337,9 +350,12 @@ export class MLCEngine implements MLCEngineInterface {
async unload() {
this.deviceLostIsError = false; // so that unload() does not trigger device.lost error
this.pipeline?.dispose();
+ this.embeddingPipeline?.dispose();
// Wait until device is actually destroyed so we can safely set deviceLostIsError back to true
await this.pipeline?.sync();
+ await this.embeddingPipeline?.sync();
this.pipeline = undefined;
+ this.embeddingPipeline = undefined;
this.currentModelId = undefined;
this.deviceLostIsError = true;
if (this.reloadController) {
@@ -880,6 +896,52 @@ export class MLCEngine implements MLCEngineInterface {
return response;
}
+ async embedding(
+ request: EmbeddingCreateParams,
+ ): Promise<CreateEmbeddingResponse> {
+ // 0. Preprocess inputs
+ if (!this.currentModelId) {
+ throw new ModelNotLoadedError();
+ }
+ if (
+ findModelRecord(this.currentModelId, this.appConfig).model_type !==
+ ModelType.embedding
+ ) {
+ throw new EmbeddingUnsupportedModelError(this.currentModelId);
+ }
+ API.postInitAndCheckFieldsEmbedding(request, this.currentModelId);
+
+ // 1. Call EmbeddingPipeline to get embeddings
+ const embedResult: Array<Array<number>> =
+ await this.getEmbeddingPipeline().embedStep(request.input);
+
+ // 2. Prepare response
+ const batchSize = embedResult.length;
+ const data: Array<Embedding> = [];
+ for (let i = 0; i < batchSize; i++) {
+ const curEmbedding: Embedding = {
+ embedding: embedResult[i],
+ index: i,
+ object: "embedding",
+ };
+ data.push(curEmbedding);
+ }
+ return {
+ data: data,
+ model: this.currentModelId,
+ object: "list",
+ usage: {
+ prompt_tokens:
+ this.getEmbeddingPipeline().getCurRoundEmbedTotalTokens(),
+ total_tokens: this.getEmbeddingPipeline().getCurRoundEmbedTotalTokens(),
+ extra: {
+ prefill_tokens_per_s:
+ this.getEmbeddingPipeline().getCurRoundEmbedTokensPerSec(),
+ },
+ },
+ };
+ }
+
//-----------------------------
// 4. WebGPU info-querying helpers
//-----------------------------
@@ -927,11 +989,18 @@ export class MLCEngine implements MLCEngineInterface {
//----------------------------------------------
private getPipeline(): LLMChatPipeline {
if (this.pipeline === undefined) {
- throw new ChatModuleNotInitializedError();
+ throw new EngineNotLoadedError();
}
return this.pipeline;
}
+ private getEmbeddingPipeline(): EmbeddingPipeline {
+ if (this.embeddingPipeline === undefined) {
+ throw new EngineNotLoadedError();
+ }
+ return this.embeddingPipeline;
+ }
+
async forwardTokensAndSample(
inputIds: Array<number>,
isPrefill: boolean,
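The response assembled by `engine.embedding()` above follows OpenAI's `CreateEmbeddingResponse`, plus a WebLLM-specific `usage.extra` block. A short fragment showing how it might be consumed (the `engine` here is assumed to already have an embedding model loaded, inside an async context):

```ts
const reply = await engine.embeddings.create({ input: ["Hello", "Hi"] });
for (const item of reply.data) {
  // item.object === "embedding"; item.index is the position within the input batch
  console.log(item.index, item.embedding.slice(0, 4));
}
console.log(reply.usage.prompt_tokens, reply.usage.total_tokens);
console.log(reply.usage.extra.prefill_tokens_per_s); // WebLLM-only performance field
```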
diff --git a/src/error.ts b/src/error.ts
index af104544..ef672892 100644
--- a/src/error.ts
+++ b/src/error.ts
@@ -254,12 +254,12 @@ export class UnsupportedToolTypeError extends Error {
this.name = "UnsupportedToolTypeError";
}
}
-export class ChatModuleNotInitializedError extends Error {
+export class EngineNotLoadedError extends Error {
constructor() {
super(
- "Chat module not yet initialized. Ensure you initialize the chat module by calling `chat.reload()` first.",
+ "Engine not yet loaded with a model. Ensure you load a model by calling `engine.reload()` first.",
);
- this.name = "ChatModuleNotInitializedError";
+ this.name = "EngineNotLoadedError";
}
}
export class UnsupportedTokenizerFilesError extends Error {
@@ -423,3 +423,59 @@ export class TextCompletionConversationError extends Error {
this.name = "TextCompletionConversationError";
}
}
+
+export class EmbeddingUnsupportedEncodingFormatError extends Error {
+ constructor() {
+ super("Embedding in base64 format is currently not supported.");
+ this.name = "EmbeddingUnsupportedEncodingFormatError";
+ }
+}
+
+export class EmbeddingUnsupportedModelError extends Error {
+ constructor(currentModel: string) {
+ super(
+ `Trying to run embeddings.create() with ${currentModel}, which does not have ` +
+ `ModelRecord.model_type === ModelType.embedding in the model record. ` +
+ `Either make sure an embedding model is loaded, or specify the model type in ModelRecord.`,
+ );
+ this.name = "EmbeddingUnsupportedModelError";
+ }
+}
+
+export class EmbeddingSlidingWindowError extends Error {
+ constructor(sliding_window_size: number) {
+ super(
+ `Embedding should not use sliding window. However, ` +
+ `sliding_window_size=${sliding_window_size} is specified in the chat config.`,
+ );
+ this.name = "EmbeddingSlidingWindowError";
+ }
+}
+
+export class EmbeddingChunkingUnsupportedError extends Error {
+ constructor(contextWindowSize: number, prefillChunkSize: number) {
+ super(
+ `Embedding currently does not support chunking. Make sure ` +
+ `contextWindowSize === prefillChunkSize. Got contextWindowSize=${contextWindowSize}, ` +
+ `prefillChunkSize=${prefillChunkSize} instead.`,
+ );
+ this.name = "EmbeddingChunkingUnsupportedError";
+ }
+}
+
+export class EmbeddingExceedContextWindowSizeError extends Error {
+ constructor(contextWindowSize: number, receivedSize: number) {
+ super(
+ `The embedding model you are using only supports up to ${contextWindowSize} context size. ` +
+ `However, an input in the batch has size ${receivedSize}.`,
+ );
+ this.name = "EmbeddingExceedContextWindowSizeError";
+ }
+}
+
+export class EmbeddingInputEmptyError extends Error {
+ constructor() {
+ super("Embedding input cannot be empty string or empty token array.");
+ this.name = "EmbeddingInputEmptyError";
+ }
+}
diff --git a/src/message.ts b/src/message.ts
index 3b63e04f..214ef2f4 100644
--- a/src/message.ts
+++ b/src/message.ts
@@ -8,6 +8,8 @@ import {
CompletionCreateParamsNonStreaming,
CompletionCreateParamsStreaming,
Completion,
+ EmbeddingCreateParams,
+ CreateEmbeddingResponse,
} from "./openai_api_protocols/index";
/**
@@ -25,6 +27,7 @@ type RequestKind =
| "forwardTokensAndSample"
| "chatCompletionNonStreaming"
| "completionNonStreaming"
+ | "embedding"
| "getMessage"
| "chatCompletionStreamInit"
| "completionStreamInit"
@@ -93,6 +96,14 @@ export interface CompletionStreamInitParams {
modelId: string;
chatOpts: ChatOptions;
}
+export interface EmbeddingParams {
+ request: EmbeddingCreateParams;
+ // The model and chatOpts that the frontend engine expects the backend to be loaded with.
+ // If not loaded (e.g. the service worker was unexpectedly killed), the handler will call reload().
+ // TODO(webllm-team): should add appConfig here as well.
+ modelId: string;
+ chatOpts: ChatOptions;
+}
export interface CustomRequestParams {
requestName: string;
@@ -108,6 +119,7 @@ export type MessageContent =
| ChatCompletionStreamInitParams
| CompletionNonStreamingParams
| CompletionStreamInitParams
+ | EmbeddingParams
| CustomRequestParams
| InitProgressReport
| LogLevel
@@ -116,6 +128,7 @@ export type MessageContent =
| number
| ChatCompletion
| ChatCompletionChunk
+ | CreateEmbeddingResponse
| Completion
| AppConfig
| void;
diff --git a/src/openai_api_protocols/embedding.ts b/src/openai_api_protocols/embedding.ts
new file mode 100644
index 00000000..8edf20af
--- /dev/null
+++ b/src/openai_api_protocols/embedding.ts
@@ -0,0 +1,195 @@
+/**
+ * The input to OpenAI API, directly adopted from openai-node with small tweaks:
+ * https://github.com/openai/openai-node/blob/master/src/resources/embeddings.ts
+ *
+ * Copyright 2024 OpenAI
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import {
+ EmbeddingInputEmptyError,
+ EmbeddingUnsupportedEncodingFormatError,
+ UnsupportedFieldsError,
+} from "../error";
+import { MLCEngineInterface } from "../types";
+
+export class Embeddings {
+ private engine: MLCEngineInterface;
+
+ constructor(engine: MLCEngineInterface) {
+ this.engine = engine;
+ }
+
+ /**
+ * Creates an embedding vector representing the input text.
+ */
+ create(request: EmbeddingCreateParams): Promise<CreateEmbeddingResponse> {
+ return this.engine.embedding(request);
+ }
+}
+
+export interface CreateEmbeddingResponse {
+ /**
+ * The list of embeddings generated by the model.
+ */
+ data: Array<Embedding>;
+
+ /**
+ * The name of the model used to generate the embedding.
+ */
+ model: string;
+
+ /**
+ * The object type, which is always "list".
+ */
+ object: "list";
+
+ /**
+ * The usage information for the request.
+ */
+ usage: CreateEmbeddingResponse.Usage;
+}
+
+/* eslint-disable @typescript-eslint/no-namespace */
+export namespace CreateEmbeddingResponse {
+ /**
+ * The usage information for the request.
+ */
+ export interface Usage {
+ /**
+ * The number of tokens used by the prompt.
+ */
+ prompt_tokens: number;
+
+ /**
+ * The total number of tokens used by the request.
+ */
+ total_tokens: number;
+
+ /**
+ * Fields specific to WebLLM, not present in OpenAI.
+ */
+ extra: {
+ /**
+ * Number of tokens per second for prefilling.
+ */
+ prefill_tokens_per_s: number;
+ };
+ }
+}
+
+/**
+ * Represents an embedding vector returned by embedding endpoint.
+ */
+export interface Embedding {
+ /**
+ * The embedding vector, which is a list of floats. The length of the vector depends on
+ * the model.
+ */
+ embedding: Array<number>;
+
+ /**
+ * The index of the embedding in the list of embeddings.
+ */
+ index: number;
+
+ /**
+ * The object type, which is always "embedding".
+ */
+ object: "embedding";
+}
+
+export interface EmbeddingCreateParams {
+ /**
+ * Input text to embed, encoded as a string or array of tokens. To embed multiple
+ * inputs in a single request, pass an array of strings or array of token arrays.
+ * The input must not exceed the max input tokens for the model, and cannot be an empty string.
+ * If the batch size is too large, multiple forward passes of the model will take place.
+ */
+ input: string | Array<string> | Array<number> | Array<Array<number>>;
+
+ /**
+ * The format to return the embeddings in.
+ *
+ * @note Currently only support `float`.
+ */
+ encoding_format?: "float" | "base64";
+
+ /**
+ * ID of the model to use.
+ *
+ * @note Not supported. Instead, call `CreateMLCEngine(model)` or `engine.reload(model)`.
+ */
+ model?: string;
+
+ // TODO: can support matryoshka embedding models in future, hence allow `dimensions` for those.
+ /**
+ * The number of dimensions the resulting output embeddings should have.
+ *
+ * @note Not supported.
+ */
+ dimensions?: number;
+
+ /**
+ * A unique identifier representing your end-user, which can help OpenAI to monitor
+ * and detect abuse.
+ *
+ * @note Not supported.
+ */
+ user?: string;
+}
+
+export const EmbeddingCreateParamsUnsupportedFields: Array<string> = [
+ "model",
+ "dimensions",
+ "user",
+];
+
+export function postInitAndCheckFields(
+ request: EmbeddingCreateParams,
+ currentModelId: string,
+): void {
+ // 1. Check unsupported fields in request
+ const unsupported: Array<string> = [];
+ EmbeddingCreateParamsUnsupportedFields.forEach((field) => {
+ if (field in request) {
+ unsupported.push(field);
+ }
+ });
+ if (unsupported.length > 0) {
+ throw new UnsupportedFieldsError(unsupported, "EmbeddingCreateParams");
+ }
+
+ // 2. Unsupported format
+ if (request.encoding_format == "base64") {
+ throw new EmbeddingUnsupportedEncodingFormatError();
+ }
+
+ // 3. Invalid input
+ const input = request.input;
+ if (typeof input === "string") {
+ if (input === "") throw new EmbeddingInputEmptyError();
+ } else {
+ // input instanceof Array
+ if (input.length === 0) {
+ // Array<string>, Array<number>, Array<Array<number>>
+ throw new EmbeddingInputEmptyError();
+ }
+ for (let i = 0; i < input.length; i++) {
+ const curInput = input[i];
+ if (typeof curInput !== "number") {
+ // Array<string>, Array<Array<number>>
+ if (curInput.length === 0) throw new EmbeddingInputEmptyError();
+ }
+ }
+ }
+}
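For clarity, the four input shapes accepted by `EmbeddingCreateParams.input` are illustrated below as fragments inside an async context (the token ids are arbitrary examples, not real vocabulary ids):

```ts
// Single string
await engine.embeddings.create({ input: "What is WebGPU?" });
// Array of strings
await engine.embeddings.create({ input: ["What is WebGPU?", "What is WebAssembly?"] });
// Single pre-tokenized input (token ids are illustrative only)
await engine.embeddings.create({ input: [101, 2054, 2003, 102] });
// Batch of pre-tokenized inputs
await engine.embeddings.create({
  input: [
    [101, 2054, 2003, 102],
    [101, 7592, 102],
  ],
});
```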
diff --git a/src/openai_api_protocols/index.ts b/src/openai_api_protocols/index.ts
index f0d1fb4f..c91f751d 100644
--- a/src/openai_api_protocols/index.ts
+++ b/src/openai_api_protocols/index.ts
@@ -58,3 +58,11 @@ export {
CompletionChoice,
postInitAndCheckFields as postInitAndCheckFieldsCompletion,
} from "./completion";
+
+export {
+ Embeddings,
+ Embedding,
+ EmbeddingCreateParams,
+ CreateEmbeddingResponse,
+ postInitAndCheckFields as postInitAndCheckFieldsEmbedding,
+} from "./embedding";
diff --git a/src/support.ts b/src/support.ts
index 13e334a0..95b20584 100644
--- a/src/support.ts
+++ b/src/support.ts
@@ -1,11 +1,12 @@
/** Util methods. */
import { Tokenizer } from "@mlc-ai/web-tokenizers";
-import { MessagePlaceholders } from "./config";
+import { AppConfig, MessagePlaceholders } from "./config";
import {
ChatCompletionChunk,
ChatCompletionMessageToolCall,
} from "./openai_api_protocols/index";
import {
+ ModelNotFoundError,
ToolCallOutputInvalidTypeError,
ToolCallOutputMissingFieldsError,
ToolCallOutputParseError,
@@ -197,3 +198,11 @@ export function getToolCallFromOutputMessage(
return tool_calls_result;
}
}
+
+export function findModelRecord(modelId: string, appConfig: AppConfig) {
+ const matchedItem = appConfig.model_list.find(
+ (item) => item.model_id == modelId,
+ );
+ if (matchedItem !== undefined) return matchedItem;
+ throw new ModelNotFoundError(modelId);
+}
diff --git a/src/types.ts b/src/types.ts
index baaf7412..91cc185b 100644
--- a/src/types.ts
+++ b/src/types.ts
@@ -11,6 +11,8 @@ import {
CompletionCreateParamsBase,
CompletionCreateParamsStreaming,
CompletionCreateParamsNonStreaming,
+ EmbeddingCreateParams,
+ CreateEmbeddingResponse,
} from "./openai_api_protocols/index";
import * as API from "./openai_api_protocols/index";
@@ -57,7 +59,7 @@ export interface LogitProcessor {
processSampledToken: (token: number) => void;
/**
- * Called when in `ChatModule.resetChat()`. Can clear internal states.
+ * Called when in `MLCEngine.resetChat()`. Can clear internal states.
*/
resetState: () => void;
}
@@ -76,6 +78,11 @@ export interface MLCEngineInterface {
*/
completions: API.Completions;
+ /**
+ * An object that exposes embeddings APIs.
+ */
+ embeddings: API.Embeddings;
+
/**
* Set an initialization progress callback function
* which reports the progress of model loading.
@@ -173,6 +180,16 @@ export interface MLCEngineInterface {
request: CompletionCreateParams,
): Promise<AsyncIterable<Completion> | Completion>;
+ /**
+ * OpenAI-style API. Creates an embedding vector representing the input text.
+ * Use `engine.embeddings.create()` to invoke this API.
+ *
+ * @param request An OpenAI-style Embeddings request.
+ *
+ * @note For more, see https://platform.openai.com/docs/api-reference/embeddings/create
+ */
+ embedding(request: EmbeddingCreateParams): Promise<CreateEmbeddingResponse>;
+
/**
* @returns A text summarizing the runtime stats.
* @note This is an async function
diff --git a/src/web_worker.ts b/src/web_worker.ts
index b89bcaf6..36626cad 100644
--- a/src/web_worker.ts
+++ b/src/web_worker.ts
@@ -24,6 +24,8 @@ import {
CompletionCreateParamsStreaming,
CompletionCreateParamsBase,
CompletionCreateParams,
+ CreateEmbeddingResponse,
+ EmbeddingCreateParams,
} from "./openai_api_protocols/index";
import * as API from "./openai_api_protocols/index";
import {
@@ -38,6 +40,7 @@ import {
WorkerResponse,
WorkerRequest,
CompletionNonStreamingParams,
+ EmbeddingParams,
CompletionStreamInitParams,
} from "./message";
import log from "loglevel";
@@ -187,7 +190,7 @@ export class WebWorkerMLCEngineHandler {
});
return;
}
- // For engine.chat.completions()
+ // For engine.chat.completions.create()
case "chatCompletionNonStreaming": {
// Directly return the ChatCompletion response
this.handleTask(msg.uuid, async () => {
@@ -212,7 +215,7 @@ export class WebWorkerMLCEngineHandler {
});
return;
}
- // engine.completions()
+ // For engine.completions.create()
case "completionNonStreaming": {
// Directly return the ChatCompletion response
this.handleTask(msg.uuid, async () => {
@@ -237,7 +240,7 @@ export class WebWorkerMLCEngineHandler {
});
return;
}
- // Shared by engine.chat.completions() and engine.completions()
+ // Shared by engine.chat.completions.create() and engine.completions.create()
case "completionStreamNextChunk": {
// Note: ChatCompletion and Completion share the same chunk generator.
// For any subsequent request, we return whatever `next()` yields
@@ -254,6 +257,18 @@ export class WebWorkerMLCEngineHandler {
});
return;
}
+ // For engine.embeddings.create()
+ case "embedding": {
+ // Directly return the Embeddings response
+ this.handleTask(msg.uuid, async () => {
+ const params = msg.content as EmbeddingParams;
+ await this.reloadIfUnmatched(params.modelId, params.chatOpts);
+ const res = await this.engine.embedding(params.request);
+ onComplete?.(res);
+ return res;
+ });
+ return;
+ }
case "runtimeStatsText": {
this.handleTask(msg.uuid, async () => {
const res = await this.engine.runtimeStatsText();
@@ -406,6 +421,8 @@ export class WebWorkerMLCEngine implements MLCEngineInterface {
public chat: API.Chat;
/** For completions.create() */
public completions: API.Completions;
+ /** For embeddings.create() */
+ public embeddings: API.Embeddings;
/**
* The modelId and chatOpts that the frontend expects the backend engine is currently loaded
@@ -445,6 +462,7 @@ export class WebWorkerMLCEngine implements MLCEngineInterface {
this.chat = new API.Chat(this);
this.completions = new API.Completions(this);
+ this.embeddings = new API.Embeddings(this);
}
setInitProgressCallback(initProgressCallback?: InitProgressCallback) {
@@ -741,6 +759,24 @@ export class WebWorkerMLCEngine implements MLCEngineInterface {
return await this.getPromise<Completion>(msg);
}
+ async embedding(
+ request: EmbeddingCreateParams,
+ ): Promise<CreateEmbeddingResponse> {
+ if (this.modelId === undefined) {
+ throw new WorkerEngineModelNotLoadedError(this.constructor.name);
+ }
+ const msg: WorkerRequest = {
+ kind: "embedding",
+ uuid: crypto.randomUUID(),
+ content: {
+ request: request,
+ modelId: this.modelId,
+ chatOpts: this.chatOpts,
+ },
+ };
+ return await this.getPromise<CreateEmbeddingResponse>(msg);
+ }
+
onmessage(event: any) {
let msg: WorkerResponse;
if (event instanceof MessageEvent) {
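Because `WebWorkerMLCEngine` now implements `embedding()` and the handler routes the new `"embedding"` message, the same API works through a worker. A sketch assuming the usual WebLLM worker setup and a bundler that understands `new URL(..., import.meta.url)`:

```ts
// worker.ts -- standard WebLLM worker entry point
import { WebWorkerMLCEngineHandler } from "@mlc-ai/web-llm";
const handler = new WebWorkerMLCEngineHandler();
self.onmessage = (msg: MessageEvent) => handler.onmessage(msg);
```

```ts
// main.ts -- module script on the main thread
import * as webllm from "@mlc-ai/web-llm";

const engine = await webllm.CreateWebWorkerMLCEngine(
  new Worker(new URL("./worker.ts", import.meta.url), { type: "module" }),
  "snowflake-arctic-embed-m-q0f32-MLC-b4",
);
const reply = await engine.embeddings.create({ input: "[CLS] Hello world [SEP]" });
console.log(reply.data[0].embedding);
```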
diff --git a/tests/openai_chat_completion.test.ts b/tests/openai_chat_completion.test.ts
index 4b6cd395..96cd8d98 100644
--- a/tests/openai_chat_completion.test.ts
+++ b/tests/openai_chat_completion.test.ts
@@ -136,7 +136,7 @@ describe("Check chat completion unsupported requests", () => {
});
describe("Supported requests", () => {
- test("Supproted chat completion request", () => {
+ test("Supported chat completion request", () => {
const request: ChatCompletionRequest = {
messages: [
{ role: "system", content: "You are a helpful assistant." },
diff --git a/tests/openai_embeddings.test.ts b/tests/openai_embeddings.test.ts
new file mode 100644
index 00000000..dd704ad4
--- /dev/null
+++ b/tests/openai_embeddings.test.ts
@@ -0,0 +1,133 @@
+import {
+ EmbeddingInputEmptyError,
+ EmbeddingUnsupportedEncodingFormatError,
+} from "../src/error";
+import {
+ EmbeddingCreateParams,
+ postInitAndCheckFields,
+} from "../src/openai_api_protocols/embedding";
+import { describe, expect, test } from "@jest/globals";
+
+describe("Check embeddings supported requests", () => {
+ test("Supported embedding request float", () => {
+ const request: EmbeddingCreateParams = {
+ input: ["Hello", "Hi"],
+ encoding_format: "float",
+ };
+ postInitAndCheckFields(request, "snowflake-arctic-embed-m-q0f32-MLC");
+ });
+
+ test("Supported embedding request, unspecified format", () => {
+ const request: EmbeddingCreateParams = {
+ input: ["Hello", "Hi"],
+ };
+ postInitAndCheckFields(request, "snowflake-arctic-embed-m-q0f32-MLC");
+ });
+
+ test("Supported embedding request, single string", () => {
+ const request: EmbeddingCreateParams = {
+ input: "Hello",
+ };
+ postInitAndCheckFields(request, "snowflake-arctic-embed-m-q0f32-MLC");
+ });
+
+ test("Supported embedding request, single token array", () => {
+ const request: EmbeddingCreateParams = {
+ input: [0, 1],
+ };
+ postInitAndCheckFields(request, "snowflake-arctic-embed-m-q0f32-MLC");
+ });
+
+ test("Supported embedding request, array of token arrays", () => {
+ const request: EmbeddingCreateParams = {
+ input: [
+ [0, 1],
+ [0, 1],
+ ],
+ };
+ postInitAndCheckFields(request, "snowflake-arctic-embed-m-q0f32-MLC");
+ });
+});
+
+describe("Invalid embedding input", () => {
+ test("Empty string", () => {
+ expect(() => {
+ const request: EmbeddingCreateParams = {
+ input: "",
+ };
+ postInitAndCheckFields(request, "snowflake-arctic-embed-m-q0f32-MLC");
+ }).toThrow(new EmbeddingInputEmptyError());
+ });
+
+ test("Contains empty string", () => {
+ expect(() => {
+ const request: EmbeddingCreateParams = {
+ input: ["Hi", "hello", ""],
+ };
+ postInitAndCheckFields(request, "snowflake-arctic-embed-m-q0f32-MLC");
+ }).toThrow(new EmbeddingInputEmptyError());
+ });
+
+ test("Empty token array", () => {
+ expect(() => {
+ const request: EmbeddingCreateParams = {
+ input: [],
+ };
+ postInitAndCheckFields(request, "snowflake-arctic-embed-m-q0f32-MLC");
+ }).toThrow(new EmbeddingInputEmptyError());
+ });
+
+ test("Contains empty token array", () => {
+ expect(() => {
+ const request: EmbeddingCreateParams = {
+ input: [[1, 2], [3], [], [4]],
+ };
+ postInitAndCheckFields(request, "snowflake-arctic-embed-m-q0f32-MLC");
+ }).toThrow(new EmbeddingInputEmptyError());
+ });
+});
+
+describe("Check embeddings unsupported requests", () => {
+ test("base64 encoding_format", () => {
+ expect(() => {
+ const request: EmbeddingCreateParams = {
+ input: ["Hello", "Hi"],
+ encoding_format: "base64",
+ };
+ postInitAndCheckFields(request, "snowflake-arctic-embed-m-q0f32-MLC");
+ }).toThrow(new EmbeddingUnsupportedEncodingFormatError());
+ });
+
+ test("model", () => {
+ expect(() => {
+ const request: EmbeddingCreateParams = {
+ input: ["Hello", "Hi"],
+ encoding_format: "float",
+ model: "snowflake-arctic-embed-m-q0f32-MLC",
+ };
+ postInitAndCheckFields(request, "snowflake-arctic-embed-m-q0f32-MLC");
+ }).toThrow("The following fields in");
+ });
+
+ test("user", () => {
+ expect(() => {
+ const request: EmbeddingCreateParams = {
+ input: ["Hello", "Hi"],
+ encoding_format: "float",
+ user: "Bob",
+ };
+ postInitAndCheckFields(request, "snowflake-arctic-embed-m-q0f32-MLC");
+ }).toThrow("The following fields in");
+ });
+
+ test("dimensions", () => {
+ expect(() => {
+ const request: EmbeddingCreateParams = {
+ input: ["Hello", "Hi"],
+ encoding_format: "float",
+ dimensions: 2048,
+ };
+ postInitAndCheckFields(request, "snowflake-arctic-embed-m-q0f32-MLC");
+ }).toThrow("The following fields in");
+ });
+});