diff --git a/src/cache_util.ts b/src/cache_util.ts
index 921cb390..44b70725 100644
--- a/src/cache_util.ts
+++ b/src/cache_util.ts
@@ -1,7 +1,13 @@
 import * as tvmjs from "tvmjs";
-import { AppConfig, ModelRecord, prebuiltAppConfig } from "./config";
+import {
+  AppConfig,
+  ChatConfig,
+  ModelRecord,
+  prebuiltAppConfig,
+} from "./config";
 import { cleanModelUrl } from "./support";
-import { ModelNotFoundError } from "./error";
+import { ModelNotFoundError, UnsupportedTokenizerFilesError } from "./error";
+import { Tokenizer } from "@mlc-ai/web-tokenizers";

 function findModelRecord(modelId: string, appConfig?: AppConfig): ModelRecord {
   const matchedItem = appConfig?.model_list.find(
@@ -101,3 +107,43 @@ export async function deleteModelWasmInCache(
   }
   await wasmCache.deleteInCache(modelRecord.model_lib);
 }
+
+/**
+ * Load the tokenizer specified by `config.tokenizer_files` from `baseUrl`, reusing the WebLLM artifact cache.
+ * @param baseUrl The URL under which the tokenizer files can be found; usually this is `ModelRecord.model`.
+ * @param config A ChatConfig, usually loaded from `mlc-chat-config.json` in `baseUrl`.
+ * @param appConfig An AppConfig, usually `webllm.prebuiltAppConfig` if not defined by user.
+ * @param logger Logging function, console.log by default.
+ * @returns The loaded tokenizer.
+ */
+export async function asyncLoadTokenizer(
+  baseUrl: string,
+  config: ChatConfig,
+  appConfig: AppConfig,
+  logger: (msg: string) => void = console.log,
+): Promise<Tokenizer> {
+  let modelCache: tvmjs.ArtifactCacheTemplate;
+  if (appConfig.useIndexedDBCache) {
+    modelCache = new tvmjs.ArtifactIndexedDBCache("webllm/model");
+  } else {
+    modelCache = new tvmjs.ArtifactCache("webllm/model");
+  }
+
+  if (config.tokenizer_files.includes("tokenizer.json")) {
+    const url = new URL("tokenizer.json", baseUrl).href;
+    const model = await modelCache.fetchWithCache(url, "arraybuffer");
+    return Tokenizer.fromJSON(model);
+  } else if (config.tokenizer_files.includes("tokenizer.model")) {
+    logger(
+      "Using `tokenizer.model` since we cannot locate `tokenizer.json`.\n" +
+        "It is recommended to use `tokenizer.json` to ensure all token mappings are included, " +
+        "since currently, files like `added_tokens.json`, `tokenizer_config.json` are ignored.\n" +
+        "Consider converting `tokenizer.model` to `tokenizer.json` by compiling the model " +
+        "with MLC again, or see if MLC's huggingface provides this file.",
+    );
+    const url = new URL("tokenizer.model", baseUrl).href;
+    const model = await modelCache.fetchWithCache(url, "arraybuffer");
+    return Tokenizer.fromSentencePiece(model);
+  }
+  throw new UnsupportedTokenizerFilesError(config.tokenizer_files);
+}
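// ---------------------------------------------------------------------------
// Editor's note (illustrative, not part of the patch): a minimal sketch of how
// the newly exported asyncLoadTokenizer can be called on its own. It assumes a
// ChatConfig has already been fetched from `mlc-chat-config.json`; the Hugging
// Face URL below is a hypothetical example.
// ---------------------------------------------------------------------------
import { prebuiltAppConfig, ChatConfig } from "./config";
import { asyncLoadTokenizer } from "./cache_util";

async function loadTokenizerSketch(config: ChatConfig) {
  // `baseUrl` normally comes from the matching ModelRecord.model field.
  const baseUrl =
    "https://huggingface.co/mlc-ai/SomeModel-q4f32_1-MLC/resolve/main/";
  // Reuses the same "webllm/model" artifact cache as the engine, so the
  // tokenizer files are fetched over the network at most once.
  const tokenizer = await asyncLoadTokenizer(baseUrl, config, prebuiltAppConfig);
  console.log(tokenizer.encode("Hello, WebLLM!").length);
  return tokenizer;
}
// ---------------------------------------------------------------------------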
diff --git a/src/engine.ts b/src/engine.ts
index a3094ee9..fceb7232 100644
--- a/src/engine.ts
+++ b/src/engine.ts
@@ -1,6 +1,5 @@
 import * as tvmjs from "tvmjs";
 import log from "loglevel";
-import { Tokenizer } from "@mlc-ai/web-tokenizers";
 import {
   ChatConfig,
   ChatOptions,
@@ -46,7 +45,7 @@ import {
   getConversation,
   getConversationFromChatCompletionRequest,
 } from "./conversation";
-import { cleanModelUrl } from "./support";
+import { cleanModelUrl, getToolCallFromOutputMessage } from "./support";
 import {
   ChatModuleNotInitializedError,
   ConfigurationNotInitializedError,
@@ -56,12 +55,9 @@ import {
   ModelNotFoundError,
   ModelNotLoadedError,
   ShaderF16SupportError,
-  ToolCallOutputInvalidTypeError,
-  ToolCallOutputMissingFieldsError,
-  ToolCallOutputParseError,
-  UnsupportedTokenizerFilesError,
   WebGPUNotAvailableError,
 } from "./error";
+import { asyncLoadTokenizer } from "./cache_util";

 /**
  * Creates `MLCEngine`, and loads `modelId` onto WebGPU.
@@ -89,7 +85,8 @@ export async function CreateMLCEngine(
 /**
  * The main interface of MLCEngine, which loads a model and performs tasks.
  *
- * You can either initialize one with `webllm.CreateMLCEngine(modelId)`, or `webllm.MLCEngine().reload(modelId)`.
+ * You can either initialize one with `webllm.CreateMLCEngine(modelId)`, or
+ * `webllm.MLCEngine().reload(modelId)`.
  */
 export class MLCEngine implements MLCEngineInterface {
   /** For chat.completions.create() */
@@ -119,6 +116,10 @@ export class MLCEngine implements MLCEngineInterface {
     this.completions = new API.Completions(this);
   }

+  //-----------------------
+  // 0. Setters and getters
+  //-----------------------
+
   setAppConfig(appConfig: AppConfig) {
     this.appConfig = appConfig;
   }
@@ -137,6 +138,10 @@ export class MLCEngine implements MLCEngineInterface {
     this.logitProcessorRegistry = logitProcessorRegistry;
   }

+  //----------------------------------------
+  // 1. Model/pipeline loading and unloading
+  //----------------------------------------
+
   /**
    * Reload model `modelId`.
    * @param modelId The model to load, needs to either be in `webllm.prebuiltAppConfig`, or in
@@ -286,10 +291,11 @@ export class MLCEngine implements MLCEngineInterface {
     });
     tvm.initWebGPU(gpuDetectOutput.device);

-    const tokenizer = await this.asyncLoadTokenizer(
+    const tokenizer = await asyncLoadTokenizer(
       modelUrl,
       this.config,
       this.appConfig,
+      this.logger,
     );
     const cacheType = this.appConfig.useIndexedDBCache ? "indexeddb" : "cache";
     await tvm.fetchNDArrayCache(
@@ -323,21 +329,29 @@ export class MLCEngine implements MLCEngineInterface {
     }
   }

-  async generate(
-    input: string | ChatCompletionRequestNonStreaming,
-    progressCallback?: GenerateProgressCallback,
-    streamInterval = 1,
-    genConfig?: GenerationConfig,
-  ): Promise<string> {
-    log.warn(
-      "WARNING: `generate()` will soon be deprecated. " +
-        "Please use `engine.chat.completions.create()` instead. " +
-        "For multi-round chatting, see `examples/multi-round-chat` on how to use " +
-        "`engine.chat.completions.create()` to achieve the same effect.",
-    );
-    return this._generate(input, progressCallback, streamInterval, genConfig);
+  /**
+   * Unloads the currently loaded model and destroys the webgpu device. Waits
+   * until the webgpu device finishes all submitted work and destroys itself.
+   * @note This is an asynchronous function.
+   */
+  async unload() {
+    this.deviceLostIsError = false; // so that unload() does not trigger device.lost error
+    this.pipeline?.dispose();
+    // Wait until device is actually destroyed so we can safely set deviceLostIsError back to true
+    await this.pipeline?.sync();
+    this.pipeline = undefined;
+    this.currentModelId = undefined;
+    this.deviceLostIsError = true;
+    if (this.reloadController) {
+      this.reloadController.abort("Engine.unload() is called.");
+      this.reloadController = undefined;
+    }
   }

+  //---------------------------------------------------
+  // 2. Underlying auto-regressive generation functions
+  //---------------------------------------------------
+
   private async _generate(
     input:
       | string
@@ -369,7 +383,7 @@
   }

   /**
-   * Similar to `generate()`; but instead of using callback, we use an async iterable.
+   * Similar to `_generate()`; but instead of using callback, we use an async iterable.
    * @param request Request for chat completion.
    * @param genConfig Generation config extraced from `request`.
    */
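// ---------------------------------------------------------------------------
// Editor's note (illustrative, not part of the patch): how the async-iterable
// path documented above is typically consumed from application code. The model
// id is an example; any entry in `prebuiltAppConfig.model_list` works.
// ---------------------------------------------------------------------------
import { CreateMLCEngine } from "@mlc-ai/web-llm";

async function streamingSketch() {
  const engine = await CreateMLCEngine("Llama-3-8B-Instruct-q4f32_1-MLC");
  const chunks = await engine.chat.completions.create({
    messages: [{ role: "user", content: "Tell me a short story." }],
    stream: true, // resolves to an AsyncIterable<ChatCompletionChunk>
    stream_options: { include_usage: true },
  });
  let reply = "";
  for await (const chunk of chunks) {
    reply += chunk.choices[0]?.delta?.content ?? "";
    if (chunk.usage) {
      console.log(chunk.usage); // the last chunk carries usage when requested
    }
  }
  console.log(reply);
}
// ---------------------------------------------------------------------------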
@@ -512,7 +526,7 @@ export class MLCEngine implements MLCEngineInterface {
         // If stopped due to length or abort, cannot output return tool_calls field
         finish_reason = "tool_calls";
         const outputMessage = await this.getMessage();
-        tool_calls = this.getToolCallFromOutputMessage(
+        tool_calls = getToolCallFromOutputMessage(
           outputMessage,
           /*isStreaming=*/ true,
         ) as Array<ChatCompletionChunk.Choice.Delta.ToolCall>;
@@ -597,6 +611,32 @@ export class MLCEngine implements MLCEngineInterface {
     }
   }

+  async interruptGenerate() {
+    this.interruptSignal = true;
+  }
+
+  //------------------------------
+  // 3. High-level generation APIs
+  //------------------------------
+
+  /**
+   * A legacy E2E generation API. Functionally equivalent to `chatCompletion()`.
+   */
+  async generate(
+    input: string | ChatCompletionRequestNonStreaming,
+    progressCallback?: GenerateProgressCallback,
+    streamInterval = 1,
+    genConfig?: GenerationConfig,
+  ): Promise<string> {
+    log.warn(
+      "WARNING: `generate()` will soon be deprecated. " +
+        "Please use `engine.chat.completions.create()` instead. " +
+        "For multi-round chatting, see `examples/multi-round-chat` on how to use " +
+        "`engine.chat.completions.create()` to achieve the same effect.",
+    );
+    return this._generate(input, progressCallback, streamInterval, genConfig);
+  }
+
   /**
    * Completes a single ChatCompletionRequest.
    *
@@ -635,7 +675,7 @@ export class MLCEngine implements MLCEngineInterface {
       response_format: request.response_format,
     };

-    // 1. If request is streaming, return an AsyncIterable (an iterable version of `generate()`)
+    // 1. If request is streaming, return an AsyncIterable (an iterable version of `_generate()`)
     if (request.stream) {
       return this.asyncGenerate(request, genConfig);
     }
@@ -644,7 +684,7 @@ export class MLCEngine implements MLCEngineInterface {
       this.getPipeline().setSeed(request.seed);
     }

-    // 2. If request is non-streaming, directly reuse `generate()`
+    // 2. If request is non-streaming, directly reuse `_generate()`
     const n = request.n ? request.n : 1;
     const choices: Array<ChatCompletion.Choice> = [];
     let completion_tokens = 0;
@@ -674,7 +714,7 @@ export class MLCEngine implements MLCEngineInterface {
       if (this.getFinishReason()! == "stop" && isFunctionCalling) {
         // If stopped due to length or abort, cannot output return tool_calls field
         finish_reason = "tool_calls";
-        tool_calls = this.getToolCallFromOutputMessage(
+        tool_calls = getToolCallFromOutputMessage(
           outputMessage,
           /*isStreaming=*/ false,
         );
@@ -767,7 +807,7 @@ export class MLCEngine implements MLCEngineInterface {
       top_logprobs: request.top_logprobs,
     };

-    // 1. If request is streaming, return an AsyncIterable (an iterable version of `generate()`)
+    // 1. If request is streaming, return an AsyncIterable (an iterable version of `_generate()`)
     if (request.stream) {
       return this.asyncGenerate(request, genConfig);
     }
@@ -776,7 +816,7 @@ export class MLCEngine implements MLCEngineInterface {
       this.getPipeline().setSeed(request.seed);
     }

-    // 2. If request is non-streaming, directly reuse `generate()`
+    // 2. If request is non-streaming, directly reuse `_generate()`
     const n = request.n ? request.n : 1;
     const choices: Array<CompletionChoice> = [];
     let completion_tokens = 0;
@@ -840,42 +880,9 @@
     return response;
   }

-  async interruptGenerate() {
-    this.interruptSignal = true;
-  }
-
-  async runtimeStatsText(): Promise<string> {
-    log.warn(
-      "WARNING: `runtimeStatsText()` will soon be deprecated. " +
-        "Please use `ChatCompletion.usage` for non-streaming requests, or " +
-        "`ChatCompletionChunk.usage` for streaming requests, enabled by `stream_options`. " +
-        "The only flow that expects to use `runtimeStatsText()` as of now is `forwardTokensAndSample()`.",
-    );
-    return this.getPipeline().runtimeStatsText();
-  }
-
-  async resetChat(keepStats = false) {
-    this.pipeline?.resetChat(keepStats);
-  }
-
-  /**
-   * Unloads the currently loaded model and destroy the webgpu device. Waits
-   * until the webgpu device finishes all submitted work and destroys itself.
-   * @note This is an asynchronous function.
-   */
-  async unload() {
-    this.deviceLostIsError = false; // so that unload() does not trigger device.lost error
-    this.pipeline?.dispose();
-    // Wait until device is actually destroyed so we can safely set deviceLostIsError back to true
-    await this.pipeline?.sync();
-    this.pipeline = undefined;
-    this.currentModelId = undefined;
-    this.deviceLostIsError = true;
-    if (this.reloadController) {
-      this.reloadController.abort("Engine.unload() is called.");
-      this.reloadController = undefined;
-    }
-  }
+  //-----------------------------
+  // 4. WebGPU info-querying helpers
+  //-----------------------------

   async getMaxStorageBufferBindingSize(): Promise<number> {
     // First detect GPU
@@ -915,9 +922,16 @@
     return gpuDetectOutput.adapterInfo.vendor;
   }

-  //--------------------------
-  // Lower level API
-  //--------------------------
+  //----------------------------------------------
+  // 5. Low-level APIs that interact with pipeline
+  //----------------------------------------------
+  private getPipeline(): LLMChatPipeline {
+    if (this.pipeline === undefined) {
+      throw new ChatModuleNotInitializedError();
+    }
+    return this.pipeline;
+  }
+
   async forwardTokensAndSample(
     inputIds: Array<number>,
     isPrefill: boolean,
@@ -957,90 +971,18 @@
     log.setLevel(logLevel);
   }

-  /**
-   * Given a string outputMessage, parse it as a JSON object and return an array of tool calls.
-   *
-   * Expect outputMessage to be a valid JSON string, and expect it to be an array of Function with
-   * fields `arguments` and `name`.
-   */
-  private getToolCallFromOutputMessage(
-    outputMessage: string,
-    isStreaming: false,
-  ): Array<ChatCompletionMessageToolCall>;
-  private getToolCallFromOutputMessage(
-    outputMessage: string,
-    isStreaming: true,
-  ): Array<ChatCompletionChunk.Choice.Delta.ToolCall>;
-  private getToolCallFromOutputMessage(
-    outputMessage: string,
-    isStreaming: boolean,
-  ):
-    | Array<ChatCompletionMessageToolCall>
-    | Array<ChatCompletionChunk.Choice.Delta.ToolCall> {
-    // 1. Parse outputMessage to JSON object
-    let toolCallsObject;
-    try {
-      toolCallsObject = JSON.parse(outputMessage);
-    } catch (err) {
-      throw new ToolCallOutputParseError(outputMessage, err as Error);
-    }
-
-    // 2. Expect to be an array
-    if (!(toolCallsObject instanceof Array)) {
-      throw new ToolCallOutputInvalidTypeError("array");
-    }
-
-    // 3. Parse each tool call and populate tool_calls
-    const numToolCalls = toolCallsObject.length;
-    const tool_calls = [];
-    for (let id = 0; id < numToolCalls; id++) {
-      const curToolCall = toolCallsObject[id];
-      if (
-        curToolCall.name === undefined ||
-        curToolCall.arguments === undefined
-      ) {
-        throw new ToolCallOutputMissingFieldsError(
-          ["name", "arguments"],
-          curToolCall,
-        );
-      }
-      tool_calls.push({
-        name: curToolCall.name,
-        arguments: JSON.stringify(curToolCall.arguments),
-      });
-    }
+  async runtimeStatsText(): Promise<string> {
+    log.warn(
+      "WARNING: `runtimeStatsText()` will soon be deprecated. " +
+        "Please use `ChatCompletion.usage` for non-streaming requests, or " +
+        "`ChatCompletionChunk.usage` for streaming requests, enabled by `stream_options`. " +
+        "The only flow that expects to use `runtimeStatsText()` as of now is `forwardTokensAndSample()`.",
+    );
+    return this.getPipeline().runtimeStatsText();
+  }

-    // 4. Return based on whether it is streaming or not
-    if (isStreaming) {
-      const tool_calls_result: Array<ChatCompletionChunk.Choice.Delta.ToolCall> =
-        [];
-      for (let id = 0; id < numToolCalls; id++) {
-        const curToolCall = tool_calls[id];
-        tool_calls_result.push({
-          index: id,
-          function: {
-            name: curToolCall.name,
-            arguments: curToolCall.arguments,
-          },
-          type: "function",
-        });
-      }
-      return tool_calls_result;
-    } else {
-      const tool_calls_result: Array<ChatCompletionMessageToolCall> = [];
-      for (let id = 0; id < numToolCalls; id++) {
-        const curToolCall = tool_calls[id];
-        tool_calls_result.push({
-          id: id.toString(),
-          function: {
-            name: curToolCall.name,
-            arguments: curToolCall.arguments,
-          },
-          type: "function",
-        });
-      }
-      return tool_calls_result;
-    }
+  async resetChat(keepStats = false) {
+    this.pipeline?.resetChat(keepStats);
   }

   /**
@@ -1121,42 +1063,4 @@
   async decode(genConfig?: GenerationConfig) {
     return this.getPipeline().decodeStep(genConfig);
   }
-
-  private getPipeline(): LLMChatPipeline {
-    if (this.pipeline === undefined) {
-      throw new ChatModuleNotInitializedError();
-    }
-    return this.pipeline;
-  }
-
-  private async asyncLoadTokenizer(
-    baseUrl: string,
-    config: ChatConfig,
-    appConfig: AppConfig,
-  ): Promise<Tokenizer> {
-    let modelCache: tvmjs.ArtifactCacheTemplate;
-    if (appConfig.useIndexedDBCache) {
-      modelCache = new tvmjs.ArtifactIndexedDBCache("webllm/model");
-    } else {
-      modelCache = new tvmjs.ArtifactCache("webllm/model");
-    }
-
-    if (config.tokenizer_files.includes("tokenizer.json")) {
-      const url = new URL("tokenizer.json", baseUrl).href;
-      const model = await modelCache.fetchWithCache(url, "arraybuffer");
-      return Tokenizer.fromJSON(model);
-    } else if (config.tokenizer_files.includes("tokenizer.model")) {
-      this.logger(
-        "Using `tokenizer.model` since we cannot locate `tokenizer.json`.\n" +
-          "It is recommended to use `tokenizer.json` to ensure all token mappings are included, " +
-          "since currently, files like `added_tokens.json`, `tokenizer_config.json` are ignored.\n" +
-          "Consider converting `tokenizer.model` to `tokenizer.json` by compiling the model " +
-          "with MLC again, or see if MLC's huggingface provides this file.",
-      );
-      const url = new URL("tokenizer.model", baseUrl).href;
-      const model = await modelCache.fetchWithCache(url, "arraybuffer");
-      return Tokenizer.fromSentencePiece(model);
-    }
-    throw new UnsupportedTokenizerFilesError(config.tokenizer_files);
-  }
 }
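// ---------------------------------------------------------------------------
// Editor's note (illustrative, not part of the patch): a sketch of the
// non-streaming function-calling flow whose tool_calls handling appears in the
// engine.ts hunks above. The tool schema is made up, and the model id is only
// an example of a function-calling-capable build (e.g. a Hermes-2-Pro model).
// ---------------------------------------------------------------------------
import { CreateMLCEngine } from "@mlc-ai/web-llm";

async function toolCallSketch() {
  const engine = await CreateMLCEngine("Hermes-2-Pro-Llama-3-8B-q4f16_1-MLC");
  const response = await engine.chat.completions.create({
    messages: [{ role: "user", content: "What is the weather in Tokyo?" }],
    tools: [
      {
        type: "function",
        function: {
          name: "get_weather",
          description: "Get the current weather for a city.",
          parameters: {
            type: "object",
            properties: { city: { type: "string" } },
            required: ["city"],
          },
        },
      },
    ],
  });
  const choice = response.choices[0];
  // When the model emits a well-formed tool-call array, finish_reason is
  // "tool_calls" and the parsed calls are attached to the assistant message.
  if (choice.finish_reason === "tool_calls") {
    for (const call of choice.message.tool_calls ?? []) {
      console.log(call.function.name, call.function.arguments);
    }
  }
}
// ---------------------------------------------------------------------------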
diff --git a/src/support.ts b/src/support.ts
index a19ee524..13e334a0 100644
--- a/src/support.ts
+++ b/src/support.ts
@@ -1,6 +1,15 @@
 /** Util methods. */
 import { Tokenizer } from "@mlc-ai/web-tokenizers";
 import { MessagePlaceholders } from "./config";
+import {
+  ChatCompletionChunk,
+  ChatCompletionMessageToolCall,
+} from "./openai_api_protocols/index";
+import {
+  ToolCallOutputInvalidTypeError,
+  ToolCallOutputMissingFieldsError,
+  ToolCallOutputParseError,
+} from "./error";

 /**
  * Based on `p_prob` of size (vocabSize,) which becomes a distribution after calling
@@ -105,3 +114,86 @@ to assist with the user query. Don't make assumptions about what values to plug
 are the available tools: <tools> ${MessagePlaceholders.hermes_tools} </tools>.
 Use the following pydantic model json schema for each tool call you will make: ${officialHermes2FunctionCallSchema}
 For each function call return a json object.`;
+
+/**
+ * Given a string outputMessage, parse it as a JSON object and return an array of tool calls.
+ *
+ * Expect outputMessage to be a valid JSON string, and expect it to be an array of Function with
+ * fields `arguments` and `name`.
+ */
+export function getToolCallFromOutputMessage(
+  outputMessage: string,
+  isStreaming: false,
+): Array<ChatCompletionMessageToolCall>;
+export function getToolCallFromOutputMessage(
+  outputMessage: string,
+  isStreaming: true,
+): Array<ChatCompletionChunk.Choice.Delta.ToolCall>;
+export function getToolCallFromOutputMessage(
+  outputMessage: string,
+  isStreaming: boolean,
+):
+  | Array<ChatCompletionMessageToolCall>
+  | Array<ChatCompletionChunk.Choice.Delta.ToolCall> {
+  // 1. Parse outputMessage to JSON object
+  let toolCallsObject;
+  try {
+    toolCallsObject = JSON.parse(outputMessage);
+  } catch (err) {
+    throw new ToolCallOutputParseError(outputMessage, err as Error);
+  }
+
+  // 2. Expect to be an array
+  if (!(toolCallsObject instanceof Array)) {
+    throw new ToolCallOutputInvalidTypeError("array");
+  }
+
+  // 3. Parse each tool call and populate tool_calls
+  const numToolCalls = toolCallsObject.length;
+  const tool_calls = [];
+  for (let id = 0; id < numToolCalls; id++) {
+    const curToolCall = toolCallsObject[id];
+    if (curToolCall.name === undefined || curToolCall.arguments === undefined) {
+      throw new ToolCallOutputMissingFieldsError(
+        ["name", "arguments"],
+        curToolCall,
+      );
+    }
+    tool_calls.push({
+      name: curToolCall.name,
+      arguments: JSON.stringify(curToolCall.arguments),
+    });
+  }
+
+  // 4. Return based on whether it is streaming or not
+  if (isStreaming) {
+    const tool_calls_result: Array<ChatCompletionChunk.Choice.Delta.ToolCall> =
+      [];
+    for (let id = 0; id < numToolCalls; id++) {
+      const curToolCall = tool_calls[id];
+      tool_calls_result.push({
+        index: id,
+        function: {
+          name: curToolCall.name,
+          arguments: curToolCall.arguments,
+        },
+        type: "function",
+      });
+    }
+    return tool_calls_result;
+  } else {
+    const tool_calls_result: Array<ChatCompletionMessageToolCall> = [];
+    for (let id = 0; id < numToolCalls; id++) {
+      const curToolCall = tool_calls[id];
+      tool_calls_result.push({
+        id: id.toString(),
+        function: {
+          name: curToolCall.name,
+          arguments: curToolCall.arguments,
+        },
+        type: "function",
+      });
+    }
+    return tool_calls_result;
+  }
+}
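// ---------------------------------------------------------------------------
// Editor's note (illustrative, not part of the patch): direct usage of the
// helper moved into support.ts above. The raw model output is made up; inside
// the engine it comes from `getMessage()` after a function-calling request.
// ---------------------------------------------------------------------------
import { getToolCallFromOutputMessage } from "./support";

const rawOutput = '[{"name": "get_weather", "arguments": {"city": "Tokyo"}}]';

// Non-streaming form: returns ChatCompletionMessageToolCall objects with ids.
const calls = getToolCallFromOutputMessage(rawOutput, false);
console.log(calls[0].function.name); // "get_weather"
console.log(calls[0].function.arguments); // '{"city":"Tokyo"}'

// Malformed output (not JSON, not an array, or missing name/arguments) now
// surfaces as ToolCallOutputParseError, ToolCallOutputInvalidTypeError, or
// ToolCallOutputMissingFieldsError rather than failing silently.
// ---------------------------------------------------------------------------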
diff --git a/src/types.ts b/src/types.ts
index 287b5f5f..baaf7412 100644
--- a/src/types.ts
+++ b/src/types.ts
@@ -128,13 +128,18 @@ export interface MLCEngineInterface {
   ) => Promise<string>;

   /**
-   * OpenAI-style API. Generate a chat completion response for the given conversation and configuration.
+   * OpenAI-style API. Generate a chat completion response for the given conversation and
+   * configuration. Use `engine.chat.completions.create()` to invoke this API.
    *
-   * The API is completely functional in behavior. That is, a previous request would not affect
-   * the current request's result. Thus, for multi-round chatting, users are responsible for
+   * @param request An OpenAI-style ChatCompletion request.
+   *
+   * @note The API is completely functional in behavior. That is, a previous request would not
+   * affect the current request's result. Thus, for multi-round chatting, users are responsible for
    * maintaining the chat history. With that being said, as an implicit internal optimization, if we
-   * detect that the user is performing multiround chatting, we will preserve the KV cache and only
+   * detect that the user is performing multi-round chatting, we will preserve the KV cache and only
    * prefill the new tokens.
+   *
+   * @note For more, see https://platform.openai.com/docs/api-reference/chat
    */
  chatCompletion(
    request: ChatCompletionRequestNonStreaming,
  ): Promise<ChatCompletion>;
@@ -149,6 +154,14 @@ export interface MLCEngineInterface {
    request: ChatCompletionRequest,
  ): Promise<AsyncIterable<ChatCompletionChunk> | ChatCompletion>;

+  /**
+   * OpenAI-style API. Completes a CompletionCreateParams, a text completion with no chat template.
+   * Use `engine.completions.create()` to invoke this API.
+   *
+   * @param request An OpenAI-style Completion request.
+   *
+   * @note For more, see https://platform.openai.com/docs/api-reference/completions
+   */
  completion(request: CompletionCreateParamsNonStreaming): Promise<Completion>;
  completion(
    request: CompletionCreateParamsStreaming,