From ef27ebce57c7371a62f7629d64cfae8f233cf435 Mon Sep 17 00:00:00 2001
From: Charlie Ruan <53290280+CharlieFRuan@users.noreply.github.com>
Date: Mon, 12 Aug 2024 02:22:59 -0400
Subject: [PATCH] [Embeddings][OpenAI] Support embeddings via engine.embeddings.create()
---
 examples/README.md                      |   1 +
 examples/embeddings/README.md           |  14 ++
 examples/embeddings/package.json        |  21 ++
 examples/embeddings/src/embeddings.html |  23 ++
 examples/embeddings/src/embeddings.ts   | 147 ++++++++++++
 src/config.ts                           |  55 ++++-
 src/embedding.ts                        | 290 ++++++++++++++++++++++++
 src/engine.ts                           | 107 +++++++--
 src/error.ts                            |  62 ++++-
 src/message.ts                          |  13 ++
 src/openai_api_protocols/embedding.ts   | 195 ++++++++++++++++
 src/openai_api_protocols/index.ts       |   8 +
 src/support.ts                          |  11 +-
 src/types.ts                            |  19 +-
 src/web_worker.ts                       |  42 +++-
 tests/openai_chat_completion.test.ts    |   2 +-
 tests/openai_embeddings.test.ts         | 133 +++++++++++
 17 files changed, 1112 insertions(+), 31 deletions(-)
 create mode 100644 examples/embeddings/README.md
 create mode 100644 examples/embeddings/package.json
 create mode 100644 examples/embeddings/src/embeddings.html
 create mode 100644 examples/embeddings/src/embeddings.ts
 create mode 100644 src/embedding.ts
 create mode 100644 src/openai_api_protocols/embedding.ts
 create mode 100644 tests/openai_embeddings.test.ts

diff --git a/examples/README.md b/examples/README.md
index d0ffefd9..ee3b16f6 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -24,6 +24,7 @@ Note that all examples below run in-browser and use WebGPU as a backend.
 - [next-simple-chat](next-simple-chat): a mininum and complete chat bot app with [Next.js](https://nextjs.org/).
 - [multi-round-chat](multi-round-chat): while APIs are functional, we internally optimize so that multi round chat usage can reuse KV cache
 - [text-completion](text-completion): demonstrates API `engine.completions.create()`, which is pure text completion with no conversation, as opposed to `engine.chat.completions.create()`
+- [embeddings](embeddings): demonstrates API `engine.embeddings.create()` and its integration with the `EmbeddingsInterface` and `MemoryVectorStore` of Langchain.js

 #### Advanced OpenAI API Capabilities

diff --git a/examples/embeddings/README.md b/examples/embeddings/README.md
new file mode 100644
index 00000000..7450aad8
--- /dev/null
+++ b/examples/embeddings/README.md
@@ -0,0 +1,14 @@
+# WebLLM Embeddings Example
+
+This folder provides a minimal demo that shows the WebLLM embeddings API in a webapp setting.
+To try it out, run the following commands under this folder:
+
+```bash
+npm install
+npm start
+```
+
+Note: if you would like to hack on the WebLLM core package, you can change the
+web-llm dependency to `"file:../.."` and follow the build-from-source
+instructions in the project to build WebLLM locally. This option is only
+recommended if you intend to modify the WebLLM core package.
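For orientation before the example sources below, this is roughly what the new API looks like from user code — a minimal, hypothetical sketch based on the `engine.embeddings.create()` call and the prebuilt embedding model records added in this patch (error handling omitted; the model ID is one of the `-b4` records registered in `prebuiltAppConfig`):

```ts
import * as webllm from "@mlc-ai/web-llm";

// Load one of the embedding model records added to prebuiltAppConfig by this patch.
const engine = await webllm.CreateMLCEngine("snowflake-arctic-embed-m-q0f32-MLC-b4");

// OpenAI-style call; input may be a string, string[], number[], or number[][].
const reply = await engine.embeddings.create({
  input: ["[CLS] The Data Cloud! [SEP]", "[CLS] Mexico City of Course! [SEP]"],
});

console.log(reply.data[0].embedding); // number[], one entry in reply.data per input
console.log(reply.usage); // prompt_tokens, total_tokens, extra.prefill_tokens_per_s
```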
diff --git a/examples/embeddings/package.json b/examples/embeddings/package.json
new file mode 100644
index 00000000..b363dc21
--- /dev/null
+++ b/examples/embeddings/package.json
@@ -0,0 +1,21 @@
+{
+  "name": "embeddings-example",
+  "version": "0.1.0",
+  "private": true,
+  "scripts": {
+    "start": "parcel src/embeddings.html --port 8885",
+    "build": "parcel build src/embeddings.html --dist-dir lib"
+  },
+  "devDependencies": {
+    "buffer": "^5.7.1",
+    "parcel": "^2.8.3",
+    "process": "^0.11.10",
+    "tslib": "^2.3.1",
+    "typescript": "^4.9.5",
+    "url": "^0.11.3"
+  },
+  "dependencies": {
+    "@mlc-ai/web-llm": "file:../..",
+    "langchain": "0.2.15"
+  }
+}
diff --git a/examples/embeddings/src/embeddings.html b/examples/embeddings/src/embeddings.html
new file mode 100644
index 00000000..484ee7c3
--- /dev/null
+++ b/examples/embeddings/src/embeddings.html
@@ -0,0 +1,23 @@
+<!DOCTYPE html>
+<html>
+  <script>
+    webLLMGlobal = {};
+  </script>
+  <body>
+    <h2>WebLLM Test Page</h2>
+    Open console to see output
+    <br />
+    <br />
+    <label id="init-label"> </label>
+
+    <h3>Prompt</h3>
+    <label id="prompt-label"> </label>
+
+    <h3>Response</h3>
+    <label id="generate-label"> </label>
+    <br />
+    <label id="stats-label"> </label>
+
+    <script type="module" src="./embeddings.ts"></script>
+  </body>
+</html>
diff --git a/examples/embeddings/src/embeddings.ts b/examples/embeddings/src/embeddings.ts
new file mode 100644
index 00000000..b8ff521c
--- /dev/null
+++ b/examples/embeddings/src/embeddings.ts
@@ -0,0 +1,147 @@
+import * as webllm from "@mlc-ai/web-llm";
+import { MemoryVectorStore } from "langchain/vectorstores/memory";
+import type { EmbeddingsInterface } from "@langchain/core/embeddings";
+import type { Document } from "@langchain/core/documents";
+
+function setLabel(id: string, text: string) {
+  const label = document.getElementById(id);
+  if (label == null) {
+    throw Error("Cannot find label " + id);
+  }
+  label.innerText = text;
+}
+
+const initProgressCallback = (report: webllm.InitProgressReport) => {
+  setLabel("init-label", report.text);
+};
+
+// For integration with Langchain
+class WebLLMEmbeddings implements EmbeddingsInterface {
+  engine: webllm.MLCEngineInterface;
+  constructor(engine: webllm.MLCEngineInterface) {
+    this.engine = engine;
+  }
+
+  async _embed(texts: string[]): Promise<number[][]> {
+    const reply = await this.engine.embeddings.create({ input: texts });
+    const result: number[][] = [];
+    for (let i = 0; i < texts.length; i++) {
+      result.push(reply.data[i].embedding);
+    }
+    return result;
+  }
+
+  async embedQuery(document: string): Promise<number[]> {
+    return this._embed([document]).then((embeddings) => embeddings[0]);
+  }
+
+  async embedDocuments(documents: string[]): Promise<number[][]> {
+    return this._embed(documents);
+  }
+}
+
+// Prepare inputs
+const documents_og = ["The Data Cloud!", "Mexico City of Course!"];
+const queries_og = ["what is snowflake?", "Where can I get the best tacos?"];
+const documents: string[] = [];
+const queries: string[] = [];
+const query_prefix =
+  "Represent this sentence for searching relevant passages: ";
+// Process according to Snowflake model
+documents_og.forEach(function (item, index) {
+  documents[index] = `[CLS] ${item} [SEP]`;
+});
+queries_og.forEach(function (item, index) {
+  queries[index] = `[CLS] ${query_prefix}${item} [SEP]`;
+});
+console.log("Formatted documents: ", documents);
+console.log("Formatted queries: ", queries);
+
+// Using webllm's API
+async function webllmAPI() {
+  // b4 means the max batch size is compiled as 4. That is, the model can process 4 inputs in a
+  // batch. If given more than 4, the model will forward multiple times. The larger the max batch
+  // size, the more memory it consumes.
+ // const selectedModel = "snowflake-arctic-embed-m-q0f32-MLC-b32"; + const selectedModel = "snowflake-arctic-embed-m-q0f32-MLC-b4"; + const engine: webllm.MLCEngineInterface = await webllm.CreateMLCEngine( + selectedModel, + { + initProgressCallback: initProgressCallback, + logLevel: "INFO", // specify the log level + }, + ); + + const docReply = await engine.embeddings.create({ input: documents }); + console.log(docReply); + console.log(docReply.usage); + + const queryReply = await engine.embeddings.create({ input: queries }); + console.log(queryReply); + console.log(queryReply.usage); + + // Calculate similarity (we use langchain here, but any method works) + const vectorStore = await MemoryVectorStore.fromExistingIndex( + new WebLLMEmbeddings(engine), + ); + // See score + for (let i = 0; i < queries_og.length; i++) { + console.log(`Similarity with: ${queries_og[i]}`); + for (let j = 0; j < documents_og.length; j++) { + const similarity = vectorStore.similarity( + queryReply.data[i].embedding, + docReply.data[j].embedding, + ); + console.log(`${documents_og[j]}: ${similarity}`); + } + } +} + +// Alternatively, integrating with Langchain's API +async function langchainAPI() { + // b4 means the max batch size is compiled as 4. That is, the model can process 4 inputs in a + // batch. If given more than 4, the model will forward multiple times. The larger the max batch + // size, the more memory it consumes. + // const selectedModel = "snowflake-arctic-embed-m-q0f32-MLC-b32"; + const selectedModel = "snowflake-arctic-embed-m-q0f32-MLC-b4"; + const engine: webllm.MLCEngineInterface = await webllm.CreateMLCEngine( + selectedModel, + { + initProgressCallback: initProgressCallback, + logLevel: "INFO", // specify the log level + }, + ); + + const vectorStore = await MemoryVectorStore.fromExistingIndex( + new WebLLMEmbeddings(engine), + ); + const document0: Document = { + pageContent: documents[0], + metadata: {}, + }; + const document1: Document = { + pageContent: documents[1], + metadata: {}, + }; + await vectorStore.addDocuments([document0, document1]); + + const similaritySearchResults0 = await vectorStore.similaritySearch( + queries[0], + 1, + ); + for (const doc of similaritySearchResults0) { + console.log(`* ${doc.pageContent}`); + } + + const similaritySearchResults1 = await vectorStore.similaritySearch( + queries[1], + 1, + ); + for (const doc of similaritySearchResults1) { + console.log(`* ${doc.pageContent}`); + } +} + +// Select one to run +webllmAPI(); +// langchainAPI(); diff --git a/src/config.ts b/src/config.ts index 5aae7693..6c81febf 100644 --- a/src/config.ts +++ b/src/config.ts @@ -68,10 +68,10 @@ export interface TokenizerInfo { * Only these fields affect the conversation in runtime. * i.e. The third part in https://llm.mlc.ai/docs/get_started/mlc_chat_config.html. * - * This is initialized in `ChatModule.reload()` with the model's `mlc-chat-config.json`. + * This is initialized in `MLCEngine.reload()` with the model's `mlc-chat-config.json`. */ export interface ChatConfig { - // First three fields affect the entire conversation, i.e. used in `ChatModule.reload()` + // First three fields affect the entire conversation, i.e. used in `MLCEngine.reload()` tokenizer_files: Array; tokenizer_info?: TokenizerInfo; token_table_postproc_method?: string; // TODO: backward compatibility, remove soon @@ -122,7 +122,7 @@ export interface MLCEngineConfig { * We also support additional fields not present in `mlc-chat-config.json` due to OpenAI-like APIs. 
* * Note that all values are optional. If unspecified, we use whatever values in `ChatConfig` - * initialized during `ChatModule.reload()`. + * initialized during `MLCEngine.reload()`. */ export interface GenerationConfig { // Only used in MLC @@ -226,6 +226,11 @@ export function postInitAndCheckGenerationConfigValues( } } +export enum ModelType { + "LLM", + "embedding", +} + /** * Information for a model. * @param model: the huggingface link to download the model weights, accepting four formats: @@ -241,6 +246,7 @@ export function postInitAndCheckGenerationConfigValues( * @param low_resource_required: whether the model can run on limited devices (e.g. Android phone). * @param buffer_size_required_bytes: required `maxStorageBufferBindingSize`, different for each device. * @param required_features: feature needed to run this model (e.g. shader-f16). + * @param model_type: the intended usecase for the model, if unspecified, default to LLM. */ export interface ModelRecord { model: string; @@ -251,6 +257,7 @@ export interface ModelRecord { low_resource_required?: boolean; buffer_size_required_bytes?: number; required_features?: Array; + model_type?: ModelType; } /** @@ -1514,5 +1521,47 @@ export const prebuiltAppConfig: AppConfig = { context_window_size: 1024, }, }, + // Embedding models + // -b means max_batch_size this model allows. The smaller it is, the less memory the model consumes. + { + model: "https://huggingface.co/mlc-ai/snowflake-arctic-embed-m-q0f32-MLC", + model_id: "snowflake-arctic-embed-m-q0f32-MLC-b32", + model_lib: + modelLibURLPrefix + + modelVersion + + "/snowflake-arctic-embed-m-q0f32-ctx512_cs512_batch32-webgpu.wasm", + vram_required_MB: 1407.51, + model_type: ModelType.embedding, + }, + { + model: "https://huggingface.co/mlc-ai/snowflake-arctic-embed-m-q0f32-MLC", + model_id: "snowflake-arctic-embed-m-q0f32-MLC-b4", + model_lib: + modelLibURLPrefix + + modelVersion + + "/snowflake-arctic-embed-m-q0f32-ctx512_cs512_batch4-webgpu.wasm", + vram_required_MB: 539.4, + model_type: ModelType.embedding, + }, + { + model: "https://huggingface.co/mlc-ai/snowflake-arctic-embed-s-q0f32-MLC", + model_id: "snowflake-arctic-embed-s-q0f32-MLC-b32", + model_lib: + modelLibURLPrefix + + modelVersion + + "/snowflake-arctic-embed-s-q0f32-ctx512_cs512_batch32-webgpu.wasm", + vram_required_MB: 1022.82, + model_type: ModelType.embedding, + }, + { + model: "https://huggingface.co/mlc-ai/snowflake-arctic-embed-s-q0f32-MLC", + model_id: "snowflake-arctic-embed-s-q0f32-MLC-b4", + model_lib: + modelLibURLPrefix + + modelVersion + + "/snowflake-arctic-embed-s-q0f32-ctx512_cs512_batch4-webgpu.wasm", + vram_required_MB: 238.71, + model_type: ModelType.embedding, + }, ], }; diff --git a/src/embedding.ts b/src/embedding.ts new file mode 100644 index 00000000..c620fb26 --- /dev/null +++ b/src/embedding.ts @@ -0,0 +1,290 @@ +import * as tvmjs from "tvmjs"; +import log from "loglevel"; +import { Tokenizer } from "@mlc-ai/web-tokenizers"; +import { ChatConfig } from "./config"; +import { + EmbeddingChunkingUnsupportedError, + EmbeddingExceedContextWindowSizeError, + EmbeddingInputEmptyError, + EmbeddingSlidingWindowError, + MinValueError, +} from "./error"; + +export class EmbeddingPipeline { + private config: ChatConfig; + private tokenizer: Tokenizer; + + // TVM functions + private tvm: tvmjs.Instance; + private device: tvmjs.DLDevice; + private vm: tvmjs.VirtualMachine; + private prefill: tvmjs.PackedFunc; + private params: tvmjs.TVMObject; + + // metadata + private contextWindowSize = -1; + private 
prefillChunkSize = -1; + private maxBatchSize = -1; + + // performance + private curRoundEmbedTotalTokens = 0; // excludes padded tokens for batching + private curRoundEmbedTotalTime = 0; + + constructor(tvm: tvmjs.Instance, tokenizer: Tokenizer, config: ChatConfig) { + // 0. Setting attributes + this.tvm = tvm; + this.tokenizer = tokenizer; + this.config = config; + this.device = this.tvm.webgpu(); + + // 1. Create VM and get the core functions + tvm.beginScope(); + this.vm = this.tvm.detachFromCurrentScope( + this.tvm.createVirtualMachine(this.device), + ); + this.prefill = this.tvm.detachFromCurrentScope( + this.vm.getFunction("prefill"), + ); + + // 2. Get json stored in the vm's metadata function + const fgetMetadata = this.vm.getFunction("_metadata"); + const ret_value = fgetMetadata(); + const metadataStr = this.tvm.detachFromCurrentScope(ret_value).toString(); + const metadata = JSON.parse(metadataStr); + + // 3. Load parameters by name + const paramNames: string[] = []; + metadata.params.forEach((param: any) => { + paramNames.push(param.name); + }); + this.params = this.tvm.detachFromCurrentScope( + this.tvm.getParamsFromCacheByName(paramNames), + ); + + // 4. Read in compilation configurations from metadata + // We use context window size max batch size to check validity of the model + // We assume prefillChunkSize is the same as contextWindowSize for embedding model for now + this.maxBatchSize = metadata.max_batch_size; + this.contextWindowSize = this.config.context_window_size; + this.prefillChunkSize = metadata.prefill_chunk_size; + log.info("Using maxBatchSize: ", this.maxBatchSize); + log.info("Using contextWindowSize: ", this.contextWindowSize); + log.info("Using prefillChunkSize: ", this.prefillChunkSize); + + if (this.config.sliding_window_size !== -1) { + throw new EmbeddingSlidingWindowError(this.config.sliding_window_size); + } + if (this.maxBatchSize <= 0) { + throw new MinValueError("maxBatchSize", 0); + } + if (this.contextWindowSize <= 0) { + throw new MinValueError("contextWindowSize", 0); + } + if (this.prefillChunkSize <= 0) { + throw new MinValueError("prefillChunkSize", 0); + } + if (this.prefillChunkSize !== this.contextWindowSize) { + throw new EmbeddingChunkingUnsupportedError( + this.contextWindowSize, + this.prefillChunkSize, + ); + } + tvm.endScope(); + } + + async embedStep( + input: string | Array | Array | Array>, + ): Promise>> { + // 0. Reset performance metrics + this.curRoundEmbedTotalTokens = 0; + this.curRoundEmbedTotalTime = 0; + let totalNumTokens = 0; + const embedStart = performance.now(); + let tokenizedInputs: Array> = []; + const tempInputs: Array = []; + // 1. Convert all possible input types to Array>, tokenize if not already + // Cannot use input.every to match type, which leads to TS compilation error + // https://github.com/microsoft/TypeScript/issues/33591 + if (input.length === 0) { + throw new EmbeddingInputEmptyError(); + } + if (typeof input === "string") { + // string + tokenizedInputs = [Array.from(this.tokenizer.encode(input))]; + } else { + for (let i = 0; i < input.length; i++) { + const curInput = input[i]; + if (Array.isArray(curInput)) { + // Array> + tokenizedInputs.push(curInput); + } else if (typeof curInput === "string") { + // Array + tokenizedInputs.push(Array.from(this.tokenizer.encode(curInput))); + } else { + // Array + tempInputs.push(curInput); + } + } + } + if (tempInputs.length > 0) { + tokenizedInputs.push(tempInputs); + } + + // 2. 
Check each input is not larger than the context window size + // TODO: tokenizer.encode seems to implicitly truncates to contextWindowSize, confirm behavior + // and decide whether to warn user + for (let i = 0; i < tokenizedInputs.length; i++) { + const curInputSize = tokenizedInputs[i].length; + totalNumTokens += curInputSize; + if (curInputSize > this.contextWindowSize) { + throw new EmbeddingExceedContextWindowSizeError( + this.contextWindowSize, + curInputSize, + ); + } + } + if (tokenizedInputs.length === 0) { + throw new Error("InternalError: batch size is zero."); + } + + // 3. Forward each batch + const batchSize = tokenizedInputs.length; + const result: Array> = []; + for (let begin = 0; begin < batchSize; begin += this.maxBatchSize) { + this.tvm.beginScope(); + // 3.1 Get current batch + const end = Math.min(batchSize, begin + this.maxBatchSize); + const curBatch: Array> = tokenizedInputs.slice(begin, end); + const curBatchSize = curBatch.length; + // 3.2 Max input size of current batch + let maxInputSize = 0; + for (let i = 0; i < curBatchSize; i++) { + const curInputSize = curBatch[i].length; + if (curInputSize > maxInputSize) { + maxInputSize = curInputSize; + } + } + // 3.3 Create inputs and attention mask + // Padded with zeros and flattened, of size curBatchSize * maxInputSize + const curBatchPaddedFlatten: Array = []; + // 1 for non-pad, 0 otherwise, also of size curBatchSize * maxInputSize + const curAttnMask: Array = []; + const flattenedInputSize = curBatchSize * maxInputSize; + for (let i = 0; i < curBatchSize; i++) { + const padding = Array(maxInputSize - curBatch[i].length).fill(0); + const ones = Array(curBatch[i].length).fill(1); + curBatchPaddedFlatten.push(...curBatch[i]); + curAttnMask.push(...ones); + curBatchPaddedFlatten.push(...padding); + curAttnMask.push(...padding); + } + if ( + curBatchPaddedFlatten.length !== flattenedInputSize || + curAttnMask.length !== flattenedInputSize + ) { + throw new Error( + `InternalError: Expect input array to be ${flattenedInputSize}, ` + + `but got ${curBatchPaddedFlatten.length}`, + ); + } + // 3.4 Convert inputs and attention mask to tvm ndarray on GPU, of shape (curBatchSize, maxInputSize) + let inputNDArray = this.tvm.empty( + [flattenedInputSize], + "int32", + this.device, + ); + inputNDArray.copyFrom(curBatchPaddedFlatten); + inputNDArray = inputNDArray.view([curBatchSize, maxInputSize]); + let maskNDArray = this.tvm.empty( + [flattenedInputSize], + "int32", + this.device, + ); + maskNDArray.copyFrom(curAttnMask); + maskNDArray = maskNDArray.view([curBatchSize, maxInputSize]); + + // 3.5 Actual forwarding on GPU, logits of shape (curBatchSize, maxInputSize, hidden_size) + const logitsCurBatchOnGPU: tvmjs.NDArray = this.prefill( + inputNDArray, + maskNDArray, + this.params, + ); + await this.device.sync(); + + // 3.6 Copy logits to CPU, flatten to curBatchSize * maxInputSize * hidden_size + const hidden_size = logitsCurBatchOnGPU.shape[2]; + let logitsCurBatchOnCPU: tvmjs.NDArray = this.tvm.empty( + logitsCurBatchOnGPU.shape, + logitsCurBatchOnGPU.dtype, + this.tvm.cpu(), + ); + logitsCurBatchOnCPU.copyFrom(logitsCurBatchOnGPU); + logitsCurBatchOnCPU = logitsCurBatchOnCPU.view([ + curBatchSize * maxInputSize * hidden_size, + ]); + await this.device.sync(); + const logitsCurBatchOnCPUArray: Float32Array = ( + logitsCurBatchOnCPU.toArray() + ); + + // 3.7 Update final result. For each sentence, get [0,:], i.e. only the first token's output + // That is, we are doing result.push(logits[:,0,:]) here. 
+ // TODO: check if all models only use [0,:]. If it is snowflake-specific, need to specify + // this in mlc-chat-config.json + for (let i = 0; i < curBatchSize; i++) { + const b = i * maxInputSize * hidden_size; + const e = b + hidden_size; + result.push(Array.from(logitsCurBatchOnCPUArray.slice(b, e))); + } + this.tvm.endScope(); + } + if (result.length !== batchSize) { + throw new Error(` + InternalError: expect result.length to be ${batchSize}, but got ${result.length}`); + } + const embedEnd = performance.now(); + this.curRoundEmbedTotalTokens = totalNumTokens; + this.curRoundEmbedTotalTime = (embedEnd - embedStart) / 1e3; + + return result; + } + + dispose() { + this.params.dispose(); + this.prefill.dispose(); + this.vm.dispose(); + this.tvm.dispose(); + this.tokenizer.dispose(); + } + + /** + * Synchronize the device. + */ + async sync(): Promise { + // Is it equivalent to this.tvm.sync()? + await this.device.sync(); + } + + // Performance APIs below + + /** + * Get the time it took the last `embedStep()` in seconds. + */ + getCurRoundEmbedTotalTime(): number { + return this.curRoundEmbedTotalTime; + } + + /** + * Get the number of tokens embedded in the last `embedStep()`. This excludes the padded tokens. + */ + getCurRoundEmbedTotalTokens(): number { + return this.curRoundEmbedTotalTokens; + } + + /** + * @returns Prefill tokens per second, starting from the last prefill performed. + */ + getCurRoundEmbedTokensPerSec(): number { + return this.curRoundEmbedTotalTokens / this.curRoundEmbedTotalTime; + } +} diff --git a/src/engine.ts b/src/engine.ts index fceb7232..342cc3cf 100644 --- a/src/engine.ts +++ b/src/engine.ts @@ -10,6 +10,7 @@ import { Role, MLCEngineConfig, DefaultLogLevel, + ModelType, } from "./config"; import { LLMChatPipeline } from "./llm_chat"; import { @@ -31,6 +32,9 @@ import { CompletionCreateParams, Completion, CompletionChoice, + EmbeddingCreateParams, + CreateEmbeddingResponse, + Embedding, } from "./openai_api_protocols/index"; import * as API from "./openai_api_protocols/index"; import { @@ -45,19 +49,24 @@ import { getConversation, getConversationFromChatCompletionRequest, } from "./conversation"; -import { cleanModelUrl, getToolCallFromOutputMessage } from "./support"; import { - ChatModuleNotInitializedError, + cleanModelUrl, + findModelRecord, + getToolCallFromOutputMessage, +} from "./support"; +import { + EngineNotLoadedError, ConfigurationNotInitializedError, DeviceLostError, + EmbeddingUnsupportedModelError, FeatureSupportError, MissingModelWasmError, - ModelNotFoundError, ModelNotLoadedError, ShaderF16SupportError, WebGPUNotAvailableError, } from "./error"; import { asyncLoadTokenizer } from "./cache_util"; +import { EmbeddingPipeline } from "./embedding"; /** * Creates `MLCEngine`, and loads `modelId` onto WebGPU. 
@@ -93,12 +102,15 @@ export class MLCEngine implements MLCEngineInterface { public chat: API.Chat; /** For completions.create() */ public completions: API.Completions; + /** For embeddings.create() */ + public embeddings: API.Embeddings; private currentModelId?: string = undefined; // Model current loaded, undefined if nothing is loaded private logger: (msg: string) => void = log.info; private logitProcessorRegistry?: Map; private logitProcessor?: LogitProcessor; private pipeline?: LLMChatPipeline; + private embeddingPipeline?: EmbeddingPipeline; private initProgressCallback?: InitProgressCallback; private interruptSignal = false; private deviceLostIsError = true; // whether device.lost is due to actual error or model reload @@ -114,6 +126,7 @@ export class MLCEngine implements MLCEngineInterface { this.chat = new API.Chat(this); this.completions = new API.Completions(this); + this.embeddings = new API.Embeddings(this); } //----------------------- @@ -174,15 +187,7 @@ export class MLCEngine implements MLCEngineInterface { this.logitProcessor = this.logitProcessorRegistry?.get(modelId); const tstart = performance.now(); - const findModelRecord = () => { - const matchedItem = this.appConfig?.model_list.find( - (item) => item.model_id == modelId, - ); - if (matchedItem !== undefined) return matchedItem; - throw new ModelNotFoundError(modelId); - }; - - const modelRecord = findModelRecord(); + const modelRecord = findModelRecord(modelId, this.appConfig); const baseUrl = typeof document !== "undefined" ? document.URL @@ -305,12 +310,20 @@ export class MLCEngine implements MLCEngineInterface { cacheType, this.reloadController?.signal, ); - this.pipeline = new LLMChatPipeline( - tvm, - tokenizer, - this.config, - this.logitProcessor, - ); + if (modelRecord.model_type === ModelType.embedding) { + this.embeddingPipeline = new EmbeddingPipeline( + tvm, + tokenizer, + this.config, + ); + } else { + this.pipeline = new LLMChatPipeline( + tvm, + tokenizer, + this.config, + this.logitProcessor, + ); + } await this.pipeline?.asyncLoadWebGPUPipelines(); const tend = performance.now(); @@ -337,9 +350,12 @@ export class MLCEngine implements MLCEngineInterface { async unload() { this.deviceLostIsError = false; // so that unload() does not trigger device.lost error this.pipeline?.dispose(); + this.embeddingPipeline?.dispose(); // Wait until device is actually destroyed so we can safely set deviceLostIsError back to true await this.pipeline?.sync(); + await this.embeddingPipeline?.sync(); this.pipeline = undefined; + this.embeddingPipeline = undefined; this.currentModelId = undefined; this.deviceLostIsError = true; if (this.reloadController) { @@ -880,6 +896,52 @@ export class MLCEngine implements MLCEngineInterface { return response; } + async embedding( + request: EmbeddingCreateParams, + ): Promise { + // 0. Preprocess inputs + if (!this.currentModelId) { + throw new ModelNotLoadedError(); + } + if ( + findModelRecord(this.currentModelId, this.appConfig).model_type !== + ModelType.embedding + ) { + throw new EmbeddingUnsupportedModelError(this.currentModelId); + } + API.postInitAndCheckFieldsEmbedding(request, this.currentModelId); + + // 1. Call EmbeddingPipeline to get embeddings + const embedResult: Array> = + await this.getEmbeddingPipeline().embedStep(request.input); + + // 2. 
Prepare response + const batchSize = embedResult.length; + const data: Array = []; + for (let i = 0; i < batchSize; i++) { + const curEmbedding: Embedding = { + embedding: embedResult[i], + index: i, + object: "embedding", + }; + data.push(curEmbedding); + } + return { + data: data, + model: this.currentModelId, + object: "list", + usage: { + prompt_tokens: + this.getEmbeddingPipeline().getCurRoundEmbedTotalTokens(), + total_tokens: this.getEmbeddingPipeline().getCurRoundEmbedTotalTokens(), + extra: { + prefill_tokens_per_s: + this.getEmbeddingPipeline().getCurRoundEmbedTokensPerSec(), + }, + }, + }; + } + //----------------------------- // 4. WebGPU info-querying helpers //----------------------------- @@ -927,11 +989,18 @@ export class MLCEngine implements MLCEngineInterface { //---------------------------------------------- private getPipeline(): LLMChatPipeline { if (this.pipeline === undefined) { - throw new ChatModuleNotInitializedError(); + throw new EngineNotLoadedError(); } return this.pipeline; } + private getEmbeddingPipeline(): EmbeddingPipeline { + if (this.embeddingPipeline === undefined) { + throw new EngineNotLoadedError(); + } + return this.embeddingPipeline; + } + async forwardTokensAndSample( inputIds: Array, isPrefill: boolean, diff --git a/src/error.ts b/src/error.ts index af104544..ef672892 100644 --- a/src/error.ts +++ b/src/error.ts @@ -254,12 +254,12 @@ export class UnsupportedToolTypeError extends Error { this.name = "UnsupportedToolTypeError"; } } -export class ChatModuleNotInitializedError extends Error { +export class EngineNotLoadedError extends Error { constructor() { super( - "Chat module not yet initialized. Ensure you initialize the chat module by calling `chat.reload()` first.", + "Engine not yet loaded with model. Ensure you initialize the chat module by calling `engine.reload()` first.", ); - this.name = "ChatModuleNotInitializedError"; + this.name = "EngineNotLoadedError"; } } export class UnsupportedTokenizerFilesError extends Error { @@ -423,3 +423,59 @@ export class TextCompletionConversationError extends Error { this.name = "TextCompletionConversationError"; } } + +export class EmbeddingUnsupportedEncodingFormatError extends Error { + constructor() { + super("Embedding in base64 format is currently not supported."); + this.name = "EmbeddingUnsupportedEncodingFormatError"; + } +} + +export class EmbeddingUnsupportedModelError extends Error { + constructor(currentModel: string) { + super( + `Trying to run embeddings.create() with ${currentModel}, which does not have ` + + `ModelRecord.model_type === ModelType.embedding in the model record. ` + + `Either make sure an embedding model is loaded, or specify the model type in ModelRecord.`, + ); + this.name = "EmbeddingUnsupportedModelError"; + } +} + +export class EmbeddingSlidingWindowError extends Error { + constructor(sliding_window_size: number) { + super( + `Embedding should not use sliding window. However, ` + + `sliding_window_size=${sliding_window_size} is specified in the chat config.`, + ); + this.name = "EmbeddingSlidingWindowError"; + } +} + +export class EmbeddingChunkingUnsupportedError extends Error { + constructor(contextWindowSize: number, prefillChunkSize: number) { + super( + `Embedding currently does not support chunking. Make sure ` + + `contextWindowSize === prefillChunkSize. 
Got contextWindowSize=${contextWindowSize}, ` + + `prefillChunkSize=${prefillChunkSize} instead.`, + ); + this.name = "EmbeddingChunkingUnsupportedError"; + } +} + +export class EmbeddingExceedContextWindowSizeError extends Error { + constructor(contextWindowSize: number, receivedSize: number) { + super( + `The embedding model you are using only supports up to ${contextWindowSize} context size.` + + `However, an input in the batch has size ${receivedSize}.`, + ); + this.name = "EmbeddingExceedContextWindowSizeError"; + } +} + +export class EmbeddingInputEmptyError extends Error { + constructor() { + super("Embedding input cannot be empty string or empty token array."); + this.name = "EmbeddingInputEmptyError"; + } +} diff --git a/src/message.ts b/src/message.ts index 3b63e04f..214ef2f4 100644 --- a/src/message.ts +++ b/src/message.ts @@ -8,6 +8,8 @@ import { CompletionCreateParamsNonStreaming, CompletionCreateParamsStreaming, Completion, + EmbeddingCreateParams, + CreateEmbeddingResponse, } from "./openai_api_protocols/index"; /** @@ -25,6 +27,7 @@ type RequestKind = | "forwardTokensAndSample" | "chatCompletionNonStreaming" | "completionNonStreaming" + | "embedding" | "getMessage" | "chatCompletionStreamInit" | "completionStreamInit" @@ -93,6 +96,14 @@ export interface CompletionStreamInitParams { modelId: string; chatOpts: ChatOptions; } +export interface EmbeddingParams { + request: EmbeddingCreateParams; + // The model and chatOpts that the frontend engine expects the backend to be loaded with. + // If not loaded due to service worker unexpectedly killed, handler will call reload(). + // TODO(webllm-team): should add appConfig here as well. + modelId: string; + chatOpts: ChatOptions; +} export interface CustomRequestParams { requestName: string; @@ -108,6 +119,7 @@ export type MessageContent = | ChatCompletionStreamInitParams | CompletionNonStreamingParams | CompletionStreamInitParams + | EmbeddingParams | CustomRequestParams | InitProgressReport | LogLevel @@ -116,6 +128,7 @@ export type MessageContent = | number | ChatCompletion | ChatCompletionChunk + | CreateEmbeddingResponse | Completion | AppConfig | void; diff --git a/src/openai_api_protocols/embedding.ts b/src/openai_api_protocols/embedding.ts new file mode 100644 index 00000000..8edf20af --- /dev/null +++ b/src/openai_api_protocols/embedding.ts @@ -0,0 +1,195 @@ +/** + * The input to OpenAI API, directly adopted from openai-node with small tweaks: + * https://github.com/openai/openai-node/blob/master/src/resources/embeddings.ts + * + * Copyright 2024 OpenAI + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { + EmbeddingInputEmptyError, + EmbeddingUnsupportedEncodingFormatError, + UnsupportedFieldsError, +} from "../error"; +import { MLCEngineInterface } from "../types"; + +export class Embeddings { + private engine: MLCEngineInterface; + + constructor(engine: MLCEngineInterface) { + this.engine = engine; + } + + /** + * Creates an embedding vector representing the input text. 
+ */ + create(request: EmbeddingCreateParams): Promise { + return this.engine.embedding(request); + } +} + +export interface CreateEmbeddingResponse { + /** + * The list of embeddings generated by the model. + */ + data: Array; + + /** + * The name of the model used to generate the embedding. + */ + model: string; + + /** + * The object type, which is always "list". + */ + object: "list"; + + /** + * The usage information for the request. + */ + usage: CreateEmbeddingResponse.Usage; +} + +/* eslint-disable @typescript-eslint/no-namespace */ +export namespace CreateEmbeddingResponse { + /** + * The usage information for the request. + */ + export interface Usage { + /** + * The number of tokens used by the prompt. + */ + prompt_tokens: number; + + /** + * The total number of tokens used by the request. + */ + total_tokens: number; + + /** + * Fields specific to WebLLM, not present in OpenAI. + */ + extra: { + /** + * Number of tokens per second for prefilling. + */ + prefill_tokens_per_s: number; + }; + } +} + +/** + * Represents an embedding vector returned by embedding endpoint. + */ +export interface Embedding { + /** + * The embedding vector, which is a list of floats. The length of vector depends on + * the model. + */ + embedding: Array; + + /** + * The index of the embedding in the list of embeddings. + */ + index: number; + + /** + * The object type, which is always "embedding". + */ + object: "embedding"; +} + +export interface EmbeddingCreateParams { + /** + * Input text to embed, encoded as a string or array of tokens. To embed multiple + * inputs in a single request, pass an array of strings or array of token arrays. + * The input must not exceed the max input tokens for the model, and cannot be an empty string. + * If the batch size is too large, multiple forward of the will take place. + */ + input: string | Array | Array | Array>; + + /** + * The format to return the embeddings in. + * + * @note Currently only support `float`. + */ + encoding_format?: "float" | "base64"; + + /** + * ID of the model to use. + * + * @note Not supported. Instead, call `CreateMLCEngine(model)` or `engine.reload(model)`. + */ + model?: string; + + // TODO: can support matryoshka embedding models in future, hence allow `dimensions` for those. + /** + * The number of dimensions the resulting output embeddings should have. + * + * @note Not supported. + */ + dimensions?: number; + + /** + * A unique identifier representing your end-user, which can help OpenAI to monitor + * and detect abuse. + * + * @note Not supported. + */ + user?: string; +} + +export const EmbeddingCreateParamsUnsupportedFields: Array = [ + "model", + "dimensions", + "user", +]; + +export function postInitAndCheckFields( + request: EmbeddingCreateParams, + currentModelId: string, +): void { + // 1. Check unsupported fields in request + const unsupported: Array = []; + EmbeddingCreateParamsUnsupportedFields.forEach((field) => { + if (field in request) { + unsupported.push(field); + } + }); + if (unsupported.length > 0) { + throw new UnsupportedFieldsError(unsupported, "EmbeddingCreateParams"); + } + + // 2. Unsupported format + if (request.encoding_format == "base64") { + throw new EmbeddingUnsupportedEncodingFormatError(); + } + + // 3. 
Invalid input + const input = request.input; + if (typeof input === "string") { + if (input === "") throw new EmbeddingInputEmptyError(); + } else { + // input instanceof Array + if (input.length === 0) { + // Array + throw new EmbeddingInputEmptyError(); + } + for (let i = 0; i < input.length; i++) { + const curInput = input[i]; + if (typeof curInput !== "number") { + // Array, Array> + if (curInput.length === 0) throw new EmbeddingInputEmptyError(); + } + } + } +} diff --git a/src/openai_api_protocols/index.ts b/src/openai_api_protocols/index.ts index f0d1fb4f..c91f751d 100644 --- a/src/openai_api_protocols/index.ts +++ b/src/openai_api_protocols/index.ts @@ -58,3 +58,11 @@ export { CompletionChoice, postInitAndCheckFields as postInitAndCheckFieldsCompletion, } from "./completion"; + +export { + Embeddings, + Embedding, + EmbeddingCreateParams, + CreateEmbeddingResponse, + postInitAndCheckFields as postInitAndCheckFieldsEmbedding, +} from "./embedding"; diff --git a/src/support.ts b/src/support.ts index 13e334a0..95b20584 100644 --- a/src/support.ts +++ b/src/support.ts @@ -1,11 +1,12 @@ /** Util methods. */ import { Tokenizer } from "@mlc-ai/web-tokenizers"; -import { MessagePlaceholders } from "./config"; +import { AppConfig, MessagePlaceholders } from "./config"; import { ChatCompletionChunk, ChatCompletionMessageToolCall, } from "./openai_api_protocols/index"; import { + ModelNotFoundError, ToolCallOutputInvalidTypeError, ToolCallOutputMissingFieldsError, ToolCallOutputParseError, @@ -197,3 +198,11 @@ export function getToolCallFromOutputMessage( return tool_calls_result; } } + +export function findModelRecord(modelId: string, appConfig: AppConfig) { + const matchedItem = appConfig.model_list.find( + (item) => item.model_id == modelId, + ); + if (matchedItem !== undefined) return matchedItem; + throw new ModelNotFoundError(modelId); +} diff --git a/src/types.ts b/src/types.ts index baaf7412..91cc185b 100644 --- a/src/types.ts +++ b/src/types.ts @@ -11,6 +11,8 @@ import { CompletionCreateParamsBase, CompletionCreateParamsStreaming, CompletionCreateParamsNonStreaming, + EmbeddingCreateParams, + CreateEmbeddingResponse, } from "./openai_api_protocols/index"; import * as API from "./openai_api_protocols/index"; @@ -57,7 +59,7 @@ export interface LogitProcessor { processSampledToken: (token: number) => void; /** - * Called when in `ChatModule.resetChat()`. Can clear internal states. + * Called when in `MLCEngine.resetChat()`. Can clear internal states. */ resetState: () => void; } @@ -76,6 +78,11 @@ export interface MLCEngineInterface { */ completions: API.Completions; + /** + * An object that exposes embeddings APIs. + */ + embeddings: API.Embeddings; + /** * Set an initialization progress callback function * which reports the progress of model loading. @@ -173,6 +180,16 @@ export interface MLCEngineInterface { request: CompletionCreateParams, ): Promise | Completion>; + /** + * OpenAI-style API. Creates an embedding vector representing the input text. + * Use `engine.embeddings.create()` to invoke this API. + * + * @param request An OpenAI-style Embeddings request. + * + * @note For more, see https://platform.openai.com/docs/api-reference/embeddings/create + */ + embedding(request: EmbeddingCreateParams): Promise; + /** * @returns A text summarizing the runtime stats. 
* @note This is an async function diff --git a/src/web_worker.ts b/src/web_worker.ts index b89bcaf6..36626cad 100644 --- a/src/web_worker.ts +++ b/src/web_worker.ts @@ -24,6 +24,8 @@ import { CompletionCreateParamsStreaming, CompletionCreateParamsBase, CompletionCreateParams, + CreateEmbeddingResponse, + EmbeddingCreateParams, } from "./openai_api_protocols/index"; import * as API from "./openai_api_protocols/index"; import { @@ -38,6 +40,7 @@ import { WorkerResponse, WorkerRequest, CompletionNonStreamingParams, + EmbeddingParams, CompletionStreamInitParams, } from "./message"; import log from "loglevel"; @@ -187,7 +190,7 @@ export class WebWorkerMLCEngineHandler { }); return; } - // For engine.chat.completions() + // For engine.chat.completions.create() case "chatCompletionNonStreaming": { // Directly return the ChatCompletion response this.handleTask(msg.uuid, async () => { @@ -212,7 +215,7 @@ export class WebWorkerMLCEngineHandler { }); return; } - // engine.completions() + // For engine.completions.create() case "completionNonStreaming": { // Directly return the ChatCompletion response this.handleTask(msg.uuid, async () => { @@ -237,7 +240,7 @@ export class WebWorkerMLCEngineHandler { }); return; } - // Shared by engine.chat.completions() and engine.completions() + // Shared by engine.chat.completions.create() and engine.completions.create() case "completionStreamNextChunk": { // Note: ChatCompletion and Completion share the same chunk generator. // For any subsequent request, we return whatever `next()` yields @@ -254,6 +257,18 @@ export class WebWorkerMLCEngineHandler { }); return; } + // For engine.embeddings.create() + case "embedding": { + // Directly return the Embeddings response + this.handleTask(msg.uuid, async () => { + const params = msg.content as EmbeddingParams; + await this.reloadIfUnmatched(params.modelId, params.chatOpts); + const res = await this.engine.embedding(params.request); + onComplete?.(res); + return res; + }); + return; + } case "runtimeStatsText": { this.handleTask(msg.uuid, async () => { const res = await this.engine.runtimeStatsText(); @@ -406,6 +421,8 @@ export class WebWorkerMLCEngine implements MLCEngineInterface { public chat: API.Chat; /** For completions.create() */ public completions: API.Completions; + /** For embeddings.create() */ + public embeddings: API.Embeddings; /** * The modelId and chatOpts that the frontend expects the backend engine is currently loaded @@ -445,6 +462,7 @@ export class WebWorkerMLCEngine implements MLCEngineInterface { this.chat = new API.Chat(this); this.completions = new API.Completions(this); + this.embeddings = new API.Embeddings(this); } setInitProgressCallback(initProgressCallback?: InitProgressCallback) { @@ -741,6 +759,24 @@ export class WebWorkerMLCEngine implements MLCEngineInterface { return await this.getPromise(msg); } + async embedding( + request: EmbeddingCreateParams, + ): Promise { + if (this.modelId === undefined) { + throw new WorkerEngineModelNotLoadedError(this.constructor.name); + } + const msg: WorkerRequest = { + kind: "embedding", + uuid: crypto.randomUUID(), + content: { + request: request, + modelId: this.modelId, + chatOpts: this.chatOpts, + }, + }; + return await this.getPromise(msg); + } + onmessage(event: any) { let msg: WorkerResponse; if (event instanceof MessageEvent) { diff --git a/tests/openai_chat_completion.test.ts b/tests/openai_chat_completion.test.ts index 4b6cd395..96cd8d98 100644 --- a/tests/openai_chat_completion.test.ts +++ b/tests/openai_chat_completion.test.ts @@ -136,7 
+136,7 @@ describe("Check chat completion unsupported requests", () => { }); describe("Supported requests", () => { - test("Supproted chat completion request", () => { + test("Supported chat completion request", () => { const request: ChatCompletionRequest = { messages: [ { role: "system", content: "You are a helpful assistant." }, diff --git a/tests/openai_embeddings.test.ts b/tests/openai_embeddings.test.ts new file mode 100644 index 00000000..dd704ad4 --- /dev/null +++ b/tests/openai_embeddings.test.ts @@ -0,0 +1,133 @@ +import { + EmbeddingInputEmptyError, + EmbeddingUnsupportedEncodingFormatError, +} from "../src/error"; +import { + EmbeddingCreateParams, + postInitAndCheckFields, +} from "../src/openai_api_protocols/embedding"; +import { describe, expect, test } from "@jest/globals"; + +describe("Check embeddings supported requests", () => { + test("Supported embedding request float", () => { + const request: EmbeddingCreateParams = { + input: ["Hello", "Hi"], + encoding_format: "float", + }; + postInitAndCheckFields(request, "snowflake-arctic-embed-m-q0f32-MLC"); + }); + + test("Supported embedding request, unspecified format", () => { + const request: EmbeddingCreateParams = { + input: ["Hello", "Hi"], + }; + postInitAndCheckFields(request, "snowflake-arctic-embed-m-q0f32-MLC"); + }); + + test("Supported embedding request, single string", () => { + const request: EmbeddingCreateParams = { + input: "Hello", + }; + postInitAndCheckFields(request, "snowflake-arctic-embed-m-q0f32-MLC"); + }); + + test("Supported embedding request, single token array", () => { + const request: EmbeddingCreateParams = { + input: [0, 1], + }; + postInitAndCheckFields(request, "snowflake-arctic-embed-m-q0f32-MLC"); + }); + + test("Supported embedding request, array of token arrays", () => { + const request: EmbeddingCreateParams = { + input: [ + [0, 1], + [0, 1], + ], + }; + postInitAndCheckFields(request, "snowflake-arctic-embed-m-q0f32-MLC"); + }); +}); + +describe("Invalid embedding input", () => { + test("Empty string", () => { + expect(() => { + const request: EmbeddingCreateParams = { + input: "", + }; + postInitAndCheckFields(request, "snowflake-arctic-embed-m-q0f32-MLC"); + }).toThrow(new EmbeddingInputEmptyError()); + }); + + test("Contains empty string", () => { + expect(() => { + const request: EmbeddingCreateParams = { + input: ["Hi", "hello", ""], + }; + postInitAndCheckFields(request, "snowflake-arctic-embed-m-q0f32-MLC"); + }).toThrow(new EmbeddingInputEmptyError()); + }); + + test("Empty token array", () => { + expect(() => { + const request: EmbeddingCreateParams = { + input: [], + }; + postInitAndCheckFields(request, "snowflake-arctic-embed-m-q0f32-MLC"); + }).toThrow(new EmbeddingInputEmptyError()); + }); + + test("Contains empty token array", () => { + expect(() => { + const request: EmbeddingCreateParams = { + input: [[1, 2], [3], [], [4]], + }; + postInitAndCheckFields(request, "snowflake-arctic-embed-m-q0f32-MLC"); + }).toThrow(new EmbeddingInputEmptyError()); + }); +}); + +describe("Check embeddings unsupported requests", () => { + test("base64 encoding_format", () => { + expect(() => { + const request: EmbeddingCreateParams = { + input: ["Hello", "Hi"], + encoding_format: "base64", + }; + postInitAndCheckFields(request, "snowflake-arctic-embed-m-q0f32-MLC"); + }).toThrow(new EmbeddingUnsupportedEncodingFormatError()); + }); + + test("model", () => { + expect(() => { + const request: EmbeddingCreateParams = { + input: ["Hello", "Hi"], + encoding_format: "float", + model: 
"snowflake-arctic-embed-m-q0f32-MLC", + }; + postInitAndCheckFields(request, "snowflake-arctic-embed-m-q0f32-MLC"); + }).toThrow("The following fields in"); + }); + + test("user", () => { + expect(() => { + const request: EmbeddingCreateParams = { + input: ["Hello", "Hi"], + encoding_format: "float", + user: "Bob", + }; + postInitAndCheckFields(request, "snowflake-arctic-embed-m-q0f32-MLC"); + }).toThrow("The following fields in"); + }); + + test("dimensions", () => { + expect(() => { + const request: EmbeddingCreateParams = { + input: ["Hello", "Hi"], + encoding_format: "float", + dimensions: 2048, + }; + postInitAndCheckFields(request, "snowflake-arctic-embed-m-q0f32-MLC"); + }).toThrow("The following fields in"); + }); +});