From 5753b0d1b5487f93d96bb189a35e8aea07f718a3 Mon Sep 17 00:00:00 2001 From: Nestor Qin Date: Wed, 22 May 2024 05:41:01 -0400 Subject: [PATCH] [Style] Add GitHub action for linter and pre-commit hook formater --- .eslintignore | 1 + .eslintrc.cjs | 5 +- .github/workflows/linter.yaml | 27 + .husky/pre-commit | 1 + .lintstagedrc.json | 6 + .prettierignore | 7 + package-lock.json | 136 +- package.json | 10 +- src/cache_util.ts | 50 +- src/config.ts | 751 ++++++---- src/conversation.ts | 95 +- src/engine.ts | 314 +++-- src/extension_service_worker.ts | 22 +- src/grammar.ts | 64 +- src/index.ts | 26 +- src/llm_chat.ts | 329 +++-- src/openai_api_protocols/apis.ts | 12 +- src/openai_api_protocols/chat_completion.ts | 1382 ++++++++++--------- src/openai_api_protocols/index.ts | 66 +- src/service_worker.ts | 37 +- src/support.ts | 131 +- src/types.ts | 29 +- src/utils.ts | 6 +- src/web_worker.ts | 55 +- 24 files changed, 2094 insertions(+), 1468 deletions(-) create mode 100644 .github/workflows/linter.yaml create mode 100644 .husky/pre-commit create mode 100644 .lintstagedrc.json create mode 100644 .prettierignore diff --git a/.eslintignore b/.eslintignore index 3ea1e43c..ab2ea8e5 100644 --- a/.eslintignore +++ b/.eslintignore @@ -3,4 +3,5 @@ debug lib build node_modules +3rdparty .eslintrc.cjs diff --git a/.eslintrc.cjs b/.eslintrc.cjs index 3d660a9d..bd676556 100644 --- a/.eslintrc.cjs +++ b/.eslintrc.cjs @@ -1,9 +1,10 @@ module.exports = { - extends: ['eslint:recommended', 'plugin:@typescript-eslint/recommended'], + extends: ['eslint:recommended', 'plugin:@typescript-eslint/recommended', 'plugin:prettier/recommended'], parser: '@typescript-eslint/parser', plugins: ['@typescript-eslint'], root: true, rules: { - "@typescript-eslint/no-explicit-any": "off" + "@typescript-eslint/no-explicit-any": "off", + "@typescript-eslint/no-empty-function": "off" } }; diff --git a/.github/workflows/linter.yaml b/.github/workflows/linter.yaml new file mode 100644 index 00000000..1e5451c8 --- /dev/null +++ b/.github/workflows/linter.yaml @@ -0,0 +1,27 @@ +name: Linter + +on: + push: + branches: + - main + pull_request: + branches: + - main + +jobs: + lint: + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v3 + + - name: Set up Node.js + uses: actions/setup-node@v3 + with: + node-version: '16' + + - name: Install dependencies + run: npm install + + - name: Run lint + run: npm run lint diff --git a/.husky/pre-commit b/.husky/pre-commit new file mode 100644 index 00000000..2312dc58 --- /dev/null +++ b/.husky/pre-commit @@ -0,0 +1 @@ +npx lint-staged diff --git a/.lintstagedrc.json b/.lintstagedrc.json new file mode 100644 index 00000000..88eb34f5 --- /dev/null +++ b/.lintstagedrc.json @@ -0,0 +1,6 @@ +{ + "./**/*.{js,ts,jsx,tsx,json,html,css,md}": [ + "eslint --fix", + "prettier --write" + ] +} diff --git a/.prettierignore b/.prettierignore new file mode 100644 index 00000000..ab2ea8e5 --- /dev/null +++ b/.prettierignore @@ -0,0 +1,7 @@ +dist +debug +lib +build +node_modules +3rdparty +.eslintrc.cjs diff --git a/package-lock.json b/package-lock.json index 2354ec07..864a6df2 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { - "name": "@mlc-ai/web-llm", - "version": "0.2.37", + "name": "@neet-nestor/web-llm", + "version": "0.2.44", "lockfileVersion": 3, "requires": true, "packages": { "": { - "name": "@mlc-ai/web-llm", - "version": "0.2.37", + "name": "@neet-nestor/web-llm", + "version": "0.2.44", "license": "Apache-2.0", "devDependencies": { "@mlc-ai/web-tokenizers": "^0.1.3", @@ -20,7 +20,11 @@ "@webgpu/types": "^0.1.24", "buffer": "^5.7.1", "eslint": "^8.41.0", + "eslint-config-prettier": "^9.1.0", + "eslint-plugin-prettier": "^5.1.3", + "husky": "^9.0.11", "jest": "^29.7.0", + "prettier": "3.2.5", "process": "^0.11.10", "rollup": "^2.56.2", "rollup-plugin-ignore": "^1.0.10", @@ -1288,6 +1292,18 @@ "node": ">= 8" } }, + "node_modules/@pkgr/core": { + "version": "0.1.1", + "resolved": "https://registry.npmjs.org/@pkgr/core/-/core-0.1.1.tgz", + "integrity": "sha512-cq8o4cWH0ibXh9VGi5P20Tu9XF/0fFXl9EUinr9QfTM7a7p0oTA4iJRCQWppXR1Pg8dSM0UCItCkPwsk9qWWYA==", + "dev": true, + "engines": { + "node": "^12.20.0 || ^14.18.0 || >=16.0.0" + }, + "funding": { + "url": "https://opencollective.com/unts" + } + }, "node_modules/@rollup/plugin-commonjs": { "version": "20.0.0", "resolved": "https://registry.npmjs.org/@rollup/plugin-commonjs/-/plugin-commonjs-20.0.0.tgz", @@ -2939,6 +2955,48 @@ "url": "https://opencollective.com/eslint" } }, + "node_modules/eslint-config-prettier": { + "version": "9.1.0", + "resolved": "https://registry.npmjs.org/eslint-config-prettier/-/eslint-config-prettier-9.1.0.tgz", + "integrity": "sha512-NSWl5BFQWEPi1j4TjVNItzYV7dZXZ+wP6I6ZhrBGpChQhZRUaElihE9uRRkcbRnNb76UMKDF3r+WTmNcGPKsqw==", + "dev": true, + "bin": { + "eslint-config-prettier": "bin/cli.js" + }, + "peerDependencies": { + "eslint": ">=7.0.0" + } + }, + "node_modules/eslint-plugin-prettier": { + "version": "5.1.3", + "resolved": "https://registry.npmjs.org/eslint-plugin-prettier/-/eslint-plugin-prettier-5.1.3.tgz", + "integrity": "sha512-C9GCVAs4Eq7ZC/XFQHITLiHJxQngdtraXaM+LoUFoFp/lHNl2Zn8f3WQbe9HvTBBQ9YnKFB0/2Ajdqwo5D1EAw==", + "dev": true, + "dependencies": { + "prettier-linter-helpers": "^1.0.0", + "synckit": "^0.8.6" + }, + "engines": { + "node": "^14.18.0 || >=16.0.0" + }, + "funding": { + "url": "https://opencollective.com/eslint-plugin-prettier" + }, + "peerDependencies": { + "@types/eslint": ">=8.0.0", + "eslint": ">=8.0.0", + "eslint-config-prettier": "*", + "prettier": ">=3.0.0" + }, + "peerDependenciesMeta": { + "@types/eslint": { + "optional": true + }, + "eslint-config-prettier": { + "optional": true + } + } + }, "node_modules/eslint-scope": { "version": "5.1.1", "resolved": "https://registry.npmjs.org/eslint-scope/-/eslint-scope-5.1.1.tgz", @@ -3289,6 +3347,12 @@ "integrity": "sha512-f3qQ9oQy9j2AhBe/H9VC91wLmKBCCU/gDOnKNAYG5hswO7BLKj09Hc5HYNz9cGI++xlpDCIgDaitVs03ATR84Q==", "dev": true }, + "node_modules/fast-diff": { + "version": "1.3.0", + "resolved": "https://registry.npmjs.org/fast-diff/-/fast-diff-1.3.0.tgz", + "integrity": "sha512-VxPP4NqbUjj6MaAOafWeUn2cXWLcCtljklUtZf0Ind4XQ+QPtmA0b18zZy0jIQx+ExRVCR/ZQpBmik5lXshNsw==", + "dev": true + }, "node_modules/fast-glob": { "version": "3.2.12", "resolved": "https://registry.npmjs.org/fast-glob/-/fast-glob-3.2.12.tgz", @@ -3797,6 +3861,21 @@ "node": ">=10.17.0" } }, + "node_modules/husky": { + "version": "9.0.11", + "resolved": "https://registry.npmjs.org/husky/-/husky-9.0.11.tgz", + "integrity": "sha512-AB6lFlbwwyIqMdHYhwPe+kjOC3Oc5P3nThEoW/AaO2BX3vJDjWPFxYLxokUZOo6RNX20He3AaT8sESs9NJcmEw==", + "dev": true, + "bin": { + "husky": "bin.mjs" + }, + "engines": { + "node": ">=18" + }, + "funding": { + "url": "https://github.com/sponsors/typicode" + } + }, "node_modules/iconv-lite": { "version": "0.4.24", "resolved": "https://registry.npmjs.org/iconv-lite/-/iconv-lite-0.4.24.tgz", @@ -6721,6 +6800,33 @@ "node": ">= 0.8.0" } }, + "node_modules/prettier": { + "version": "3.2.5", + "resolved": "https://registry.npmjs.org/prettier/-/prettier-3.2.5.tgz", + "integrity": "sha512-3/GWa9aOC0YeD7LUfvOG2NiDyhOWRvt1k+rcKhOuYnMY24iiCphgneUfJDyFXd6rZCAnuLBv6UeAULtrhT/F4A==", + "dev": true, + "bin": { + "prettier": "bin/prettier.cjs" + }, + "engines": { + "node": ">=14" + }, + "funding": { + "url": "https://github.com/prettier/prettier?sponsor=1" + } + }, + "node_modules/prettier-linter-helpers": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/prettier-linter-helpers/-/prettier-linter-helpers-1.0.0.tgz", + "integrity": "sha512-GbK2cP9nraSSUF9N2XwUwqfzlAFlMNYYl+ShE/V+H8a9uNl/oUqB1w2EL54Jh0OlyRSd8RfWYJ3coVS4TROP2w==", + "dev": true, + "dependencies": { + "fast-diff": "^1.1.2" + }, + "engines": { + "node": ">=6.0.0" + } + }, "node_modules/pretty-format": { "version": "29.7.0", "resolved": "https://registry.npmjs.org/pretty-format/-/pretty-format-29.7.0.tgz", @@ -8037,6 +8143,22 @@ "integrity": "sha512-9QNk5KwDF+Bvz+PyObkmSYjI5ksVUYtjW7AU22r2NKcfLJcXp96hkDWU3+XndOsUb+AQ9QhfzfCT2O+CNWT5Tw==", "dev": true }, + "node_modules/synckit": { + "version": "0.8.8", + "resolved": "https://registry.npmjs.org/synckit/-/synckit-0.8.8.tgz", + "integrity": "sha512-HwOKAP7Wc5aRGYdKH+dw0PRRpbO841v2DENBtjnR5HFWoiNByAl7vrx3p0G/rCyYXQsrxqtX48TImFtPcIHSpQ==", + "dev": true, + "dependencies": { + "@pkgr/core": "^0.1.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": "^14.18.0 || >=16.0.0" + }, + "funding": { + "url": "https://opencollective.com/unts" + } + }, "node_modules/terminal-link": { "version": "2.1.1", "resolved": "https://registry.npmjs.org/terminal-link/-/terminal-link-2.1.1.tgz", @@ -8225,9 +8347,9 @@ } }, "node_modules/tslib": { - "version": "2.5.2", - "resolved": "https://registry.npmjs.org/tslib/-/tslib-2.5.2.tgz", - "integrity": "sha512-5svOrSA2w3iGFDs1HibEVBGbDrAY82bFQ3HZ3ixB+88nsbsWQoKqDRb5UBYAUPEzbBn6dAp5gRNXglySbx1MlA==", + "version": "2.6.2", + "resolved": "https://registry.npmjs.org/tslib/-/tslib-2.6.2.tgz", + "integrity": "sha512-AEYxH93jGFPn/a2iVAwW87VuUIkR1FVUKB77NwMF7nBTDkDrrT/Hpt/IrCJ0QXhW27jTBDcf5ZY7w6RiqTMw2Q==", "dev": true }, "node_modules/tsutils": { diff --git a/package.json b/package.json index 3ddf826a..0e43c37b 100644 --- a/package.json +++ b/package.json @@ -7,8 +7,10 @@ "type": "module", "scripts": { "build": "rollup -c && ./cleanup-index-js.sh", - "lint": "npx eslint .", - "test": "yarn jest" + "lint": "npx eslint ./src/ && npx prettier ./src --check", + "test": "yarn jest", + "format": "prettier --write \"./src/\"", + "prepare": "husky" }, "files": [ "lib" @@ -36,7 +38,11 @@ "@webgpu/types": "^0.1.24", "buffer": "^5.7.1", "eslint": "^8.41.0", + "eslint-config-prettier": "^9.1.0", + "eslint-plugin-prettier": "^5.1.3", + "husky": "^9.0.11", "jest": "^29.7.0", + "prettier": "3.2.5", "process": "^0.11.10", "rollup": "^2.56.2", "rollup-plugin-ignore": "^1.0.10", diff --git a/src/cache_util.ts b/src/cache_util.ts index 0c069222..5fe766fe 100644 --- a/src/cache_util.ts +++ b/src/cache_util.ts @@ -1,13 +1,9 @@ import * as tvmjs from "tvmjs"; -import { - AppConfig, - ModelRecord, - prebuiltAppConfig, -} from "./config"; +import { AppConfig, ModelRecord, prebuiltAppConfig } from "./config"; function findModelRecord(modelId: string, appConfig?: AppConfig): ModelRecord { const matchedItem = appConfig?.model_list.find( - item => item.model_id == modelId + (item) => item.model_id == modelId, ); if (matchedItem !== undefined) { return matchedItem; @@ -15,7 +11,10 @@ function findModelRecord(modelId: string, appConfig?: AppConfig): ModelRecord { throw Error("Cannot find model_url for " + modelId); } -export async function hasModelInCache(modelId: string, appConfig?: AppConfig): Promise { +export async function hasModelInCache( + modelId: string, + appConfig?: AppConfig, +): Promise { if (appConfig === undefined) { appConfig = prebuiltAppConfig; } @@ -25,7 +24,10 @@ export async function hasModelInCache(modelId: string, appConfig?: AppConfig): P return tvmjs.hasNDArrayInCache(modelUrl, "webllm/model", cacheType); } -export async function deleteModelAllInfoInCache(modelId: string, appConfig?: AppConfig) { +export async function deleteModelAllInfoInCache( + modelId: string, + appConfig?: AppConfig, +) { // function to delete model all information in cache if (appConfig === undefined) { appConfig = prebuiltAppConfig; @@ -34,12 +36,14 @@ export async function deleteModelAllInfoInCache(modelId: string, appConfig?: App await deleteModelInCache(modelId, appConfig); // delete wasm in cache await deleteModelWasmInCache(modelId, appConfig); - // delete chat config + // delete chat config await deleteChatConfigInCache(modelId, appConfig); } - -export async function deleteModelInCache(modelId: string, appConfig?: AppConfig) { +export async function deleteModelInCache( + modelId: string, + appConfig?: AppConfig, +) { // delete the model NDArray In Cache if (appConfig === undefined) { appConfig = prebuiltAppConfig; @@ -47,17 +51,28 @@ export async function deleteModelInCache(modelId: string, appConfig?: AppConfig) const modelRecord = findModelRecord(modelId, appConfig); let modelCache: tvmjs.ArtifactCacheTemplate; if (appConfig.useIndexedDBCache) { - tvmjs.deleteNDArrayCache(modelRecord.model_url, "webllm/model", "indexeddb"); + tvmjs.deleteNDArrayCache( + modelRecord.model_url, + "webllm/model", + "indexeddb", + ); modelCache = new tvmjs.ArtifactIndexedDBCache("webllm/model"); } else { tvmjs.deleteNDArrayCache(modelRecord.model_url, "webllm/model", "cache"); modelCache = new tvmjs.ArtifactCache("webllm/model"); } - await modelCache.deleteInCache(new URL("tokenizer.model", modelRecord.model_url).href); - await modelCache.deleteInCache(new URL("tokenizer.json", modelRecord.model_url).href); + await modelCache.deleteInCache( + new URL("tokenizer.model", modelRecord.model_url).href, + ); + await modelCache.deleteInCache( + new URL("tokenizer.json", modelRecord.model_url).href, + ); } -export async function deleteChatConfigInCache(modelId: string, appConfig?: AppConfig) { +export async function deleteChatConfigInCache( + modelId: string, + appConfig?: AppConfig, +) { // delete the chat configuration in Cache if (appConfig === undefined) { appConfig = prebuiltAppConfig; @@ -73,7 +88,10 @@ export async function deleteChatConfigInCache(modelId: string, appConfig?: AppCo await configCache.deleteInCache(configUrl); } -export async function deleteModelWasmInCache(modelId: string, appConfig?: AppConfig) { +export async function deleteModelWasmInCache( + modelId: string, + appConfig?: AppConfig, +) { // delete the wasm in Cache if (appConfig === undefined) { appConfig = prebuiltAppConfig; diff --git a/src/config.ts b/src/config.ts index 9469f1af..0a5b3fd1 100644 --- a/src/config.ts +++ b/src/config.ts @@ -23,7 +23,7 @@ export interface ConvTemplateConfig { export enum Role { user = "user", - assistant = "assistant" + assistant = "assistant", } /** @@ -39,7 +39,7 @@ export enum MessagePlaceholders { user = "{user_message}", assistant = "{assistant_message}", tool = "{tool_message}", - function = "{function_string}" + function = "{function_string}", } /** @@ -47,7 +47,7 @@ export enum MessagePlaceholders { * This only corresponds to the chat-related fields and `tokenizer_files` of `mlc-chat-config.json`. * Only these fields affect the conversation in runtime. * i.e. The third part in https://llm.mlc.ai/docs/get_started/mlc_chat_config.html. - * + * * This is initialized in `ChatModule.reload()` with the model's `mlc-chat-config.json`. */ export interface ChatConfig { @@ -73,31 +73,31 @@ export interface ChatConfig { * Custom options that can be used to override known config values. */ // eslint-disable-next-line @typescript-eslint/no-empty-interface -export interface ChatOptions extends Partial { } +export interface ChatOptions extends Partial {} /** * Optional configurations for `CreateMLCEngine()` and `CreateWebWorkerMLCEngine()`. - * + * * chatOpts: To optionally override the `mlc-chat-config.json` of `modelId`. * appConfig: Configure the app, including the list of models and whether to use IndexedDB cache. * initProgressCallback: A callback for showing the progress of loading the model. * logitProcessorRegistry: A register for stateful logit processors, see `webllm.LogitProcessor`. - * + * * @note All fields are optional, and `logitProcessorRegistry` is only used for `CreateMLCEngine()` * not `CreateWebWorkerMLCEngine()`. */ export interface MLCEngineConfig { - chatOpts?: ChatOptions, - appConfig?: AppConfig, - initProgressCallback?: InitProgressCallback, - logitProcessorRegistry?: Map + chatOpts?: ChatOptions; + appConfig?: AppConfig; + initProgressCallback?: InitProgressCallback; + logitProcessorRegistry?: Map; } /** * Config for a single generation. * Essentially `ChatConfig` without `tokenizer_files`, `conv_config`, or `conv_template`. * We also support additional fields not present in `mlc-chat-config.json` due to OpenAI-like APIs. - * + * * Note that all values are optional. If unspecified, we use whatever values in `ChatConfig` * initialized during `ChatModule.reload()`. */ @@ -121,15 +121,23 @@ export interface GenerationConfig { response_format?: ResponseFormat | null; } -export function postInitAndCheckGenerationConfigValues(config: GenerationConfig): void { +export function postInitAndCheckGenerationConfigValues( + config: GenerationConfig, +): void { function _hasValue(value: any): boolean { // if we use `if value` directly, `value` being 0 evaluates to false, violating semantics return value !== undefined && value !== null; } - if (config.frequency_penalty && (config.frequency_penalty < -2.0 || config.frequency_penalty > 2.0)) { + if ( + config.frequency_penalty && + (config.frequency_penalty < -2.0 || config.frequency_penalty > 2.0) + ) { throw new Error("`frequency_penalty` should be between -2.0 and 2.0."); } - if (config.presence_penalty && (config.presence_penalty < -2.0 || config.presence_penalty > 2.0)) { + if ( + config.presence_penalty && + (config.presence_penalty < -2.0 || config.presence_penalty > 2.0) + ) { throw new Error("`presence_penalty` should be between -2.0 and 2.0."); } if (_hasValue(config.repetition_penalty) && config.repetition_penalty! <= 0) { @@ -141,23 +149,36 @@ export function postInitAndCheckGenerationConfigValues(config: GenerationConfig) if (_hasValue(config.mean_gen_len) && config.mean_gen_len! <= 0) { throw new Error("`mean_gen_len` should be greater than zero."); } - if (_hasValue(config.shift_fill_factor) && config.shift_fill_factor! <= 0 || config.shift_fill_factor! > 1) { + if ( + (_hasValue(config.shift_fill_factor) && config.shift_fill_factor! <= 0) || + config.shift_fill_factor! > 1 + ) { throw new Error("Make sure 0 < `shift_fill_factor` <= 1."); } - if (_hasValue(config.top_p) && config.top_p! <= 0 || config.top_p! > 1) { + if ((_hasValue(config.top_p) && config.top_p! <= 0) || config.top_p! > 1) { throw new Error("Make sure 0 < `top_p` <= 1."); } if (_hasValue(config.temperature) && config.temperature! < 0) { throw new Error("Make sure `temperature` >= 0."); } // If only one of frequency or presence penatly is set, make the other one 0.0 - if (_hasValue(config.frequency_penalty) && !_hasValue(config.presence_penalty)) { + if ( + _hasValue(config.frequency_penalty) && + !_hasValue(config.presence_penalty) + ) { config.presence_penalty = 0.0; - console.log("Only frequency_penalty is set; we default presence_penaty to 0.") + console.log( + "Only frequency_penalty is set; we default presence_penaty to 0.", + ); } - if (_hasValue(config.presence_penalty) && !_hasValue(config.frequency_penalty)) { + if ( + _hasValue(config.presence_penalty) && + !_hasValue(config.frequency_penalty) + ) { config.frequency_penalty = 0.0; - console.log("Only presence_penalty is set; we default frequency_penalty to 0.") + console.log( + "Only presence_penalty is set; we default frequency_penalty to 0.", + ); } // Check logit_bias range if (_hasValue(config.logit_bias)) { @@ -165,13 +186,17 @@ export function postInitAndCheckGenerationConfigValues(config: GenerationConfig) const bias = config.logit_bias[tokenID]; if (bias > 100 || bias < -100) { throw new Error( - "logit_bias should be in range [-100, 100]; got " + bias + "for tokenID " + tokenID + "logit_bias should be in range [-100, 100]; got " + + bias + + "for tokenID " + + tokenID, ); } if (isNaN(parseInt(tokenID))) { throw new Error( - "Expect logit_bias's keys to be number represented in string; got " + tokenID - ) + "Expect logit_bias's keys to be number represented in string; got " + + tokenID, + ); } } } @@ -182,8 +207,10 @@ export function postInitAndCheckGenerationConfigValues(config: GenerationConfig) throw new Error("`logprobs` must be true if `top_logprobs` is set."); } // top_logprobs should be in range [0,5] - if ((config.top_logprobs! < 0 || config.top_logprobs! > 5)) { - throw new Error("`top_logprobs` should be in range [0,5]; got " + config.top_logprobs); + if (config.top_logprobs! < 0 || config.top_logprobs! > 5) { + throw new Error( + "`top_logprobs` should be in range [0,5]; got " + config.top_logprobs, + ); } } // If defined logprobs but not top_logprobs, simply make it 0 @@ -218,12 +245,12 @@ export interface ModelRecord { /** * Extra configuration that can be * passed to the load. - * + * * @param model_list: models to be used. * @param useIndexedDBCache: if true, will use IndexedDBCache to cache models and other artifacts. * If false or unspecified, will use the Cache API. For more information of the two, see: - * https://developer.mozilla.org/en-US/docs/Web/API/Storage_API/Storage_quotas_and_eviction_criteria#what_technologies_store_data_in_the_browser - * + * https://developer.mozilla.org/en-US/docs/Web/API/Storage_API/Storage_quotas_and_eviction_criteria#what_technologies_store_data_in_the_browser + * * @note Note that the Cache API is more well-tested in WebLLM as of now. */ export interface AppConfig { @@ -234,7 +261,7 @@ export interface AppConfig { /** * modelVersion: the prebuilt model libraries that the current npm is compatible with, affects the * `model_lib_url`s in `prebuiltAppConfig`. - * + * * @note The model version does not have to match the npm version, since not each npm update * requires an update of the model libraries. */ @@ -244,7 +271,7 @@ export const modelLibURLPrefix = /** * Default models and model library mapping to be used if unspecified. - * + * * @note This is the only source of truth of which prebuilt model libraries are compatible with the * current WebLLM npm version. */ @@ -253,293 +280,419 @@ export const prebuiltAppConfig: AppConfig = { model_list: [ // Llama-3 { - "model_url": "https://huggingface.co/mlc-ai/Llama-3-8B-Instruct-q4f32_1-MLC/resolve/main/", - "model_id": "Llama-3-8B-Instruct-q4f32_1-1k", - "model_lib_url": modelLibURLPrefix + modelVersion + "/Llama-3-8B-Instruct-q4f32_1-ctx1k_cs1k-webgpu.wasm", - "vram_required_MB": 5295.70, - "low_resource_required": true, - }, - { - "model_url": "https://huggingface.co/mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC/resolve/main/", - "model_id": "Llama-3-8B-Instruct-q4f16_1-1k", - "model_lib_url": modelLibURLPrefix + modelVersion + "/Llama-3-8B-Instruct-q4f16_1-ctx1k_cs1k-webgpu.wasm", - "vram_required_MB": 4598.34, - "low_resource_required": true, - }, - { - "model_url": "https://huggingface.co/mlc-ai/Llama-3-8B-Instruct-q4f32_1-MLC/resolve/main/", - "model_id": "Llama-3-8B-Instruct-q4f32_1", - "model_lib_url": modelLibURLPrefix + modelVersion + "/Llama-3-8B-Instruct-q4f32_1-ctx4k_cs1k-webgpu.wasm", - "vram_required_MB": 6101.01, - "low_resource_required": false, - }, - { - "model_url": "https://huggingface.co/mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC/resolve/main/", - "model_id": "Llama-3-8B-Instruct-q4f16_1", - "model_lib_url": modelLibURLPrefix + modelVersion + "/Llama-3-8B-Instruct-q4f16_1-ctx4k_cs1k-webgpu.wasm", - "vram_required_MB": 5001.00, - "low_resource_required": false, - }, - { - "model_url": "https://huggingface.co/mlc-ai/Llama-3-70B-Instruct-q3f16_1-MLC/resolve/main/", - "model_id": "Llama-3-70B-Instruct-q3f16_1", - "model_lib_url": modelLibURLPrefix + modelVersion + "/Llama-3-70B-Instruct-q3f16_1-ctx4k_cs1k-webgpu.wasm", - "vram_required_MB": 31153.13, - "low_resource_required": false, + model_url: + "https://huggingface.co/mlc-ai/Llama-3-8B-Instruct-q4f32_1-MLC/resolve/main/", + model_id: "Llama-3-8B-Instruct-q4f32_1-1k", + model_lib_url: + modelLibURLPrefix + + modelVersion + + "/Llama-3-8B-Instruct-q4f32_1-ctx1k_cs1k-webgpu.wasm", + vram_required_MB: 5295.7, + low_resource_required: true, + }, + { + model_url: + "https://huggingface.co/mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC/resolve/main/", + model_id: "Llama-3-8B-Instruct-q4f16_1-1k", + model_lib_url: + modelLibURLPrefix + + modelVersion + + "/Llama-3-8B-Instruct-q4f16_1-ctx1k_cs1k-webgpu.wasm", + vram_required_MB: 4598.34, + low_resource_required: true, + }, + { + model_url: + "https://huggingface.co/mlc-ai/Llama-3-8B-Instruct-q4f32_1-MLC/resolve/main/", + model_id: "Llama-3-8B-Instruct-q4f32_1", + model_lib_url: + modelLibURLPrefix + + modelVersion + + "/Llama-3-8B-Instruct-q4f32_1-ctx4k_cs1k-webgpu.wasm", + vram_required_MB: 6101.01, + low_resource_required: false, + }, + { + model_url: + "https://huggingface.co/mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC/resolve/main/", + model_id: "Llama-3-8B-Instruct-q4f16_1", + model_lib_url: + modelLibURLPrefix + + modelVersion + + "/Llama-3-8B-Instruct-q4f16_1-ctx4k_cs1k-webgpu.wasm", + vram_required_MB: 5001.0, + low_resource_required: false, + }, + { + model_url: + "https://huggingface.co/mlc-ai/Llama-3-70B-Instruct-q3f16_1-MLC/resolve/main/", + model_id: "Llama-3-70B-Instruct-q3f16_1", + model_lib_url: + modelLibURLPrefix + + modelVersion + + "/Llama-3-70B-Instruct-q3f16_1-ctx4k_cs1k-webgpu.wasm", + vram_required_MB: 31153.13, + low_resource_required: false, }, // Llama-2 { - "model_url": "https://huggingface.co/mlc-ai/Llama-2-7b-chat-hf-q4f32_1-MLC/resolve/main/", - "model_id": "Llama-2-7b-chat-hf-q4f32_1-1k", - "model_lib_url": modelLibURLPrefix + modelVersion + "/Llama-2-7b-chat-hf-q4f32_1-ctx1k-webgpu.wasm", - "vram_required_MB": 5284.01, - "low_resource_required": false, - }, - { - "model_url": "https://huggingface.co/mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC/resolve/main/", - "model_id": "Llama-2-7b-chat-hf-q4f16_1-1k", - "model_lib_url": modelLibURLPrefix + modelVersion + "/Llama-2-7b-chat-hf-q4f16_1-ctx1k-webgpu.wasm", - "vram_required_MB": 4618.52, - "low_resource_required": false, - "required_features": ["shader-f16"], - }, - { - "model_url": "https://huggingface.co/mlc-ai/Llama-2-7b-chat-hf-q4f32_1-MLC/resolve/main/", - "model_id": "Llama-2-7b-chat-hf-q4f32_1", - "model_lib_url": modelLibURLPrefix + modelVersion + "/Llama-2-7b-chat-hf-q4f32_1-ctx4k_cs1k-webgpu.wasm", - "vram_required_MB": 9109.03, - "low_resource_required": false, - }, - { - "model_url": "https://huggingface.co/mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC/resolve/main/", - "model_id": "Llama-2-7b-chat-hf-q4f16_1", - "model_lib_url": modelLibURLPrefix + modelVersion + "/Llama-2-7b-chat-hf-q4f16_1-ctx4k_cs1k-webgpu.wasm", - "vram_required_MB": 6749.02, - "low_resource_required": false, - "required_features": ["shader-f16"], - }, - { - "model_url": "https://huggingface.co/mlc-ai/Llama-2-13b-chat-hf-q4f16_1-MLC/resolve/main/", - "model_id": "Llama-2-13b-chat-hf-q4f16_1", - "model_lib_url": modelLibURLPrefix + modelVersion + "/Llama-2-13b-chat-hf-q4f16_1-ctx4k_cs1k-webgpu.wasm", - "vram_required_MB": 11814.09, - "low_resource_required": false, - "required_features": ["shader-f16"], + model_url: + "https://huggingface.co/mlc-ai/Llama-2-7b-chat-hf-q4f32_1-MLC/resolve/main/", + model_id: "Llama-2-7b-chat-hf-q4f32_1-1k", + model_lib_url: + modelLibURLPrefix + + modelVersion + + "/Llama-2-7b-chat-hf-q4f32_1-ctx1k-webgpu.wasm", + vram_required_MB: 5284.01, + low_resource_required: false, + }, + { + model_url: + "https://huggingface.co/mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC/resolve/main/", + model_id: "Llama-2-7b-chat-hf-q4f16_1-1k", + model_lib_url: + modelLibURLPrefix + + modelVersion + + "/Llama-2-7b-chat-hf-q4f16_1-ctx1k-webgpu.wasm", + vram_required_MB: 4618.52, + low_resource_required: false, + required_features: ["shader-f16"], + }, + { + model_url: + "https://huggingface.co/mlc-ai/Llama-2-7b-chat-hf-q4f32_1-MLC/resolve/main/", + model_id: "Llama-2-7b-chat-hf-q4f32_1", + model_lib_url: + modelLibURLPrefix + + modelVersion + + "/Llama-2-7b-chat-hf-q4f32_1-ctx4k_cs1k-webgpu.wasm", + vram_required_MB: 9109.03, + low_resource_required: false, + }, + { + model_url: + "https://huggingface.co/mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC/resolve/main/", + model_id: "Llama-2-7b-chat-hf-q4f16_1", + model_lib_url: + modelLibURLPrefix + + modelVersion + + "/Llama-2-7b-chat-hf-q4f16_1-ctx4k_cs1k-webgpu.wasm", + vram_required_MB: 6749.02, + low_resource_required: false, + required_features: ["shader-f16"], + }, + { + model_url: + "https://huggingface.co/mlc-ai/Llama-2-13b-chat-hf-q4f16_1-MLC/resolve/main/", + model_id: "Llama-2-13b-chat-hf-q4f16_1", + model_lib_url: + modelLibURLPrefix + + modelVersion + + "/Llama-2-13b-chat-hf-q4f16_1-ctx4k_cs1k-webgpu.wasm", + vram_required_MB: 11814.09, + low_resource_required: false, + required_features: ["shader-f16"], }, // Mistral variants { - "model_url": "https://huggingface.co/mlc-ai/WizardMath-7B-V1.1-q4f16_1-MLC/resolve/main/", - "model_id": "WizardMath-7B-V1.1-q4f16_1", - "model_lib_url": modelLibURLPrefix + modelVersion + "/Mistral-7B-Instruct-v0.2-q4f16_1-sw4k_cs1k-webgpu.wasm", - "vram_required_MB": 6079.02, - "low_resource_required": false, - "required_features": ["shader-f16"], - }, - { - "model_url": "https://huggingface.co/mlc-ai/Mistral-7B-Instruct-v0.2-q4f16_1-MLC/resolve/main/", - "model_id": "Mistral-7B-Instruct-v0.2-q4f16_1", - "model_lib_url": modelLibURLPrefix + modelVersion + "/Mistral-7B-Instruct-v0.2-q4f16_1-sw4k_cs1k-webgpu.wasm", - "vram_required_MB": 6079.02, - "low_resource_required": false, - "required_features": ["shader-f16"], - }, - { - "model_url": "https://huggingface.co/mlc-ai/OpenHermes-2.5-Mistral-7B-q4f16_1-MLC/resolve/main/", - "model_id": "OpenHermes-2.5-Mistral-7B-q4f16_1", - "model_lib_url": modelLibURLPrefix + modelVersion + "/Mistral-7B-Instruct-v0.2-q4f16_1-sw4k_cs1k-webgpu.wasm", - "vram_required_MB": 6079.02, - "low_resource_required": false, - "required_features": ["shader-f16"], - }, - { - "model_url": "https://huggingface.co/mlc-ai/NeuralHermes-2.5-Mistral-7B-q4f16_1-MLC/resolve/main/", - "model_id": "NeuralHermes-2.5-Mistral-7B-q4f16_1", - "model_lib_url": modelLibURLPrefix + modelVersion + "/Mistral-7B-Instruct-v0.2-q4f16_1-sw4k_cs1k-webgpu.wasm", - "vram_required_MB": 6079.02, - "low_resource_required": false, - "required_features": ["shader-f16"], - }, - { - "model_url": "https://huggingface.co/mlc-ai/Hermes-2-Pro-Mistral-7B-q4f16_1-MLC/resolve/main/", - "model_id": "Hermes-2-Pro-Mistral-7B-q4f16_1", - "model_lib_url": modelLibURLPrefix + modelVersion + "/Hermes-2-Pro-Mistral-7B-q4f16_1-sw4k_cs1k-webgpu.wasm", - "vram_required_MB": 4033.28, - "low_resource_required": false, - "required_features": ["shader-f16"], + model_url: + "https://huggingface.co/mlc-ai/WizardMath-7B-V1.1-q4f16_1-MLC/resolve/main/", + model_id: "WizardMath-7B-V1.1-q4f16_1", + model_lib_url: + modelLibURLPrefix + + modelVersion + + "/Mistral-7B-Instruct-v0.2-q4f16_1-sw4k_cs1k-webgpu.wasm", + vram_required_MB: 6079.02, + low_resource_required: false, + required_features: ["shader-f16"], + }, + { + model_url: + "https://huggingface.co/mlc-ai/Mistral-7B-Instruct-v0.2-q4f16_1-MLC/resolve/main/", + model_id: "Mistral-7B-Instruct-v0.2-q4f16_1", + model_lib_url: + modelLibURLPrefix + + modelVersion + + "/Mistral-7B-Instruct-v0.2-q4f16_1-sw4k_cs1k-webgpu.wasm", + vram_required_MB: 6079.02, + low_resource_required: false, + required_features: ["shader-f16"], + }, + { + model_url: + "https://huggingface.co/mlc-ai/OpenHermes-2.5-Mistral-7B-q4f16_1-MLC/resolve/main/", + model_id: "OpenHermes-2.5-Mistral-7B-q4f16_1", + model_lib_url: + modelLibURLPrefix + + modelVersion + + "/Mistral-7B-Instruct-v0.2-q4f16_1-sw4k_cs1k-webgpu.wasm", + vram_required_MB: 6079.02, + low_resource_required: false, + required_features: ["shader-f16"], + }, + { + model_url: + "https://huggingface.co/mlc-ai/NeuralHermes-2.5-Mistral-7B-q4f16_1-MLC/resolve/main/", + model_id: "NeuralHermes-2.5-Mistral-7B-q4f16_1", + model_lib_url: + modelLibURLPrefix + + modelVersion + + "/Mistral-7B-Instruct-v0.2-q4f16_1-sw4k_cs1k-webgpu.wasm", + vram_required_MB: 6079.02, + low_resource_required: false, + required_features: ["shader-f16"], + }, + { + model_url: + "https://huggingface.co/mlc-ai/Hermes-2-Pro-Mistral-7B-q4f16_1-MLC/resolve/main/", + model_id: "Hermes-2-Pro-Mistral-7B-q4f16_1", + model_lib_url: + modelLibURLPrefix + + modelVersion + + "/Hermes-2-Pro-Mistral-7B-q4f16_1-sw4k_cs1k-webgpu.wasm", + vram_required_MB: 4033.28, + low_resource_required: false, + required_features: ["shader-f16"], }, // Gemma-2B { - "model_url": "https://huggingface.co/mlc-ai/gemma-2b-it-q4f16_1-MLC/resolve/main/", - "model_id": "gemma-2b-it-q4f16_1", - "model_lib_url": modelLibURLPrefix + modelVersion + "/gemma-2b-it-q4f16_1-ctx4k_cs1k-webgpu.wasm", - "vram_required_MB": 1476.52, - "low_resource_required": false, - "buffer_size_required_bytes": 262144000, - "required_features": ["shader-f16"], - }, - { - "model_url": "https://huggingface.co/mlc-ai/gemma-2b-it-q4f32_1-MLC/resolve/main/", - "model_id": "gemma-2b-it-q4f32_1", - "model_lib_url": modelLibURLPrefix + modelVersion + "/gemma-2b-it-q4f32_1-ctx4k_cs1k-webgpu.wasm", - "vram_required_MB": 1750.66, - "low_resource_required": false, - "buffer_size_required_bytes": 262144000, - }, - { - "model_url": "https://huggingface.co/mlc-ai/gemma-2b-it-q4f16_1-MLC/resolve/main/", - "model_id": "gemma-2b-it-q4f16_1-1k", - "model_lib_url": modelLibURLPrefix + modelVersion + "/gemma-2b-it-q4f16_1-ctx1k_cs1k-webgpu.wasm", - "vram_required_MB": 1476.52, - "low_resource_required": true, - "buffer_size_required_bytes": 262144000, - "required_features": ["shader-f16"], - }, - { - "model_url": "https://huggingface.co/mlc-ai/gemma-2b-it-q4f32_1-MLC/resolve/main/", - "model_id": "gemma-2b-it-q4f32_1-1k", - "model_lib_url": modelLibURLPrefix + modelVersion + "/gemma-2b-it-q4f32_1-ctx1k_cs1k-webgpu.wasm", - "vram_required_MB": 1750.66, - "low_resource_required": true, - "buffer_size_required_bytes": 262144000, + model_url: + "https://huggingface.co/mlc-ai/gemma-2b-it-q4f16_1-MLC/resolve/main/", + model_id: "gemma-2b-it-q4f16_1", + model_lib_url: + modelLibURLPrefix + + modelVersion + + "/gemma-2b-it-q4f16_1-ctx4k_cs1k-webgpu.wasm", + vram_required_MB: 1476.52, + low_resource_required: false, + buffer_size_required_bytes: 262144000, + required_features: ["shader-f16"], + }, + { + model_url: + "https://huggingface.co/mlc-ai/gemma-2b-it-q4f32_1-MLC/resolve/main/", + model_id: "gemma-2b-it-q4f32_1", + model_lib_url: + modelLibURLPrefix + + modelVersion + + "/gemma-2b-it-q4f32_1-ctx4k_cs1k-webgpu.wasm", + vram_required_MB: 1750.66, + low_resource_required: false, + buffer_size_required_bytes: 262144000, + }, + { + model_url: + "https://huggingface.co/mlc-ai/gemma-2b-it-q4f16_1-MLC/resolve/main/", + model_id: "gemma-2b-it-q4f16_1-1k", + model_lib_url: + modelLibURLPrefix + + modelVersion + + "/gemma-2b-it-q4f16_1-ctx1k_cs1k-webgpu.wasm", + vram_required_MB: 1476.52, + low_resource_required: true, + buffer_size_required_bytes: 262144000, + required_features: ["shader-f16"], + }, + { + model_url: + "https://huggingface.co/mlc-ai/gemma-2b-it-q4f32_1-MLC/resolve/main/", + model_id: "gemma-2b-it-q4f32_1-1k", + model_lib_url: + modelLibURLPrefix + + modelVersion + + "/gemma-2b-it-q4f32_1-ctx1k_cs1k-webgpu.wasm", + vram_required_MB: 1750.66, + low_resource_required: true, + buffer_size_required_bytes: 262144000, }, // RedPajama { - "model_url": "https://huggingface.co/mlc-ai/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/resolve/main/", - "model_id": "RedPajama-INCITE-Chat-3B-v1-q4f16_1", - "model_lib_url": modelLibURLPrefix + modelVersion + "/RedPajama-INCITE-Chat-3B-v1-q4f16_1-ctx2k-webgpu.wasm", - "vram_required_MB": 2972.09, - "low_resource_required": false, - "required_features": ["shader-f16"], - }, - { - "model_url": "https://huggingface.co/mlc-ai/RedPajama-INCITE-Chat-3B-v1-q4f32_1-MLC/resolve/main/", - "model_id": "RedPajama-INCITE-Chat-3B-v1-q4f32_1", - "model_lib_url": modelLibURLPrefix + modelVersion + "/RedPajama-INCITE-Chat-3B-v1-q4f32_1-ctx2k-webgpu.wasm", - "vram_required_MB": 3928.09, - "low_resource_required": false, - }, - { - "model_url": "https://huggingface.co/mlc-ai/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/resolve/main/", - "model_id": "RedPajama-INCITE-Chat-3B-v1-q4f16_1-1k", - "model_lib_url": modelLibURLPrefix + modelVersion + "/RedPajama-INCITE-Chat-3B-v1-q4f16_1-ctx1k-webgpu.wasm", - "vram_required_MB": 2041.09, - "low_resource_required": true, - "required_features": ["shader-f16"], - }, - { - "model_url": "https://huggingface.co/mlc-ai/RedPajama-INCITE-Chat-3B-v1-q4f32_1-MLC/resolve/main/", - "model_id": "RedPajama-INCITE-Chat-3B-v1-q4f32_1-1k", - "model_lib_url": modelLibURLPrefix + modelVersion + "/RedPajama-INCITE-Chat-3B-v1-q4f32_1-ctx1k-webgpu.wasm", - "vram_required_MB": 2558.09, - "low_resource_required": true, + model_url: + "https://huggingface.co/mlc-ai/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/resolve/main/", + model_id: "RedPajama-INCITE-Chat-3B-v1-q4f16_1", + model_lib_url: + modelLibURLPrefix + + modelVersion + + "/RedPajama-INCITE-Chat-3B-v1-q4f16_1-ctx2k-webgpu.wasm", + vram_required_MB: 2972.09, + low_resource_required: false, + required_features: ["shader-f16"], + }, + { + model_url: + "https://huggingface.co/mlc-ai/RedPajama-INCITE-Chat-3B-v1-q4f32_1-MLC/resolve/main/", + model_id: "RedPajama-INCITE-Chat-3B-v1-q4f32_1", + model_lib_url: + modelLibURLPrefix + + modelVersion + + "/RedPajama-INCITE-Chat-3B-v1-q4f32_1-ctx2k-webgpu.wasm", + vram_required_MB: 3928.09, + low_resource_required: false, + }, + { + model_url: + "https://huggingface.co/mlc-ai/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/resolve/main/", + model_id: "RedPajama-INCITE-Chat-3B-v1-q4f16_1-1k", + model_lib_url: + modelLibURLPrefix + + modelVersion + + "/RedPajama-INCITE-Chat-3B-v1-q4f16_1-ctx1k-webgpu.wasm", + vram_required_MB: 2041.09, + low_resource_required: true, + required_features: ["shader-f16"], + }, + { + model_url: + "https://huggingface.co/mlc-ai/RedPajama-INCITE-Chat-3B-v1-q4f32_1-MLC/resolve/main/", + model_id: "RedPajama-INCITE-Chat-3B-v1-q4f32_1-1k", + model_lib_url: + modelLibURLPrefix + + modelVersion + + "/RedPajama-INCITE-Chat-3B-v1-q4f32_1-ctx1k-webgpu.wasm", + vram_required_MB: 2558.09, + low_resource_required: true, }, // Phi-2 { - "model_url": "https://huggingface.co/mlc-ai/phi-2-q0f16-MLC/resolve/main/", - "model_id": "Phi2-q0f16", - "model_lib_url": modelLibURLPrefix + modelVersion + "/phi-2-q0f16-ctx2k-webgpu.wasm", - "vram_required_MB": 11079.47, - "low_resource_required": false, - "required_features": ["shader-f16"], - }, - { - "model_url": "https://huggingface.co/mlc-ai/phi-2-q0f32-MLC/resolve/main/", - "model_id": "Phi2-q0f32", - "model_lib_url": modelLibURLPrefix + modelVersion + "/phi-2-q0f32-ctx2k-webgpu.wasm", - "vram_required_MB": 12043.48, - "low_resource_required": false, - }, - { - "model_url": "https://huggingface.co/mlc-ai/phi-2-q4f16_1-MLC/resolve/main/", - "model_id": "Phi2-q4f16_1", - "model_lib_url": modelLibURLPrefix + modelVersion + "/phi-2-q4f16_1-ctx2k-webgpu.wasm", - "vram_required_MB": 3053.97, - "low_resource_required": false, - "required_features": ["shader-f16"], - }, - { - "model_url": "https://huggingface.co/mlc-ai/phi-2-q4f32_1-MLC/resolve/main/", - "model_id": "Phi2-q4f32_1", - "model_lib_url": modelLibURLPrefix + modelVersion + "/phi-2-q4f32_1-ctx2k-webgpu.wasm", - "vram_required_MB": 4032.48, - "low_resource_required": false, - }, - { - "model_url": "https://huggingface.co/mlc-ai/phi-2-q4f16_1-MLC/resolve/main/", - "model_id": "Phi2-q4f16_1-1k", - "model_lib_url": modelLibURLPrefix + modelVersion + "/phi-2-q4f16_1-ctx1k-webgpu.wasm", - "vram_required_MB": 2131.97, - "low_resource_required": true, - "required_features": ["shader-f16"], - }, - { - "model_url": "https://huggingface.co/mlc-ai/phi-2-q4f32_1-MLC/resolve/main/", - "model_id": "Phi2-q4f32_1-1k", - "model_lib_url": modelLibURLPrefix + modelVersion + "/phi-2-q4f32_1-ctx1k-webgpu.wasm", - "vram_required_MB": 2740.48, - "low_resource_required": true, + model_url: "https://huggingface.co/mlc-ai/phi-2-q0f16-MLC/resolve/main/", + model_id: "Phi2-q0f16", + model_lib_url: + modelLibURLPrefix + modelVersion + "/phi-2-q0f16-ctx2k-webgpu.wasm", + vram_required_MB: 11079.47, + low_resource_required: false, + required_features: ["shader-f16"], + }, + { + model_url: "https://huggingface.co/mlc-ai/phi-2-q0f32-MLC/resolve/main/", + model_id: "Phi2-q0f32", + model_lib_url: + modelLibURLPrefix + modelVersion + "/phi-2-q0f32-ctx2k-webgpu.wasm", + vram_required_MB: 12043.48, + low_resource_required: false, + }, + { + model_url: + "https://huggingface.co/mlc-ai/phi-2-q4f16_1-MLC/resolve/main/", + model_id: "Phi2-q4f16_1", + model_lib_url: + modelLibURLPrefix + modelVersion + "/phi-2-q4f16_1-ctx2k-webgpu.wasm", + vram_required_MB: 3053.97, + low_resource_required: false, + required_features: ["shader-f16"], + }, + { + model_url: + "https://huggingface.co/mlc-ai/phi-2-q4f32_1-MLC/resolve/main/", + model_id: "Phi2-q4f32_1", + model_lib_url: + modelLibURLPrefix + modelVersion + "/phi-2-q4f32_1-ctx2k-webgpu.wasm", + vram_required_MB: 4032.48, + low_resource_required: false, + }, + { + model_url: + "https://huggingface.co/mlc-ai/phi-2-q4f16_1-MLC/resolve/main/", + model_id: "Phi2-q4f16_1-1k", + model_lib_url: + modelLibURLPrefix + modelVersion + "/phi-2-q4f16_1-ctx1k-webgpu.wasm", + vram_required_MB: 2131.97, + low_resource_required: true, + required_features: ["shader-f16"], + }, + { + model_url: + "https://huggingface.co/mlc-ai/phi-2-q4f32_1-MLC/resolve/main/", + model_id: "Phi2-q4f32_1-1k", + model_lib_url: + modelLibURLPrefix + modelVersion + "/phi-2-q4f32_1-ctx1k-webgpu.wasm", + vram_required_MB: 2740.48, + low_resource_required: true, }, // Phi-1.5 { - "model_url": "https://huggingface.co/mlc-ai/phi-1_5-q0f16-MLC/resolve/main/", - "model_id": "Phi1.5-q0f16", - "model_lib_url": modelLibURLPrefix + modelVersion + "/phi-1_5-q0f16-ctx2k-webgpu.wasm", - "vram_required_MB": 5818.09, - "low_resource_required": false, - "required_features": ["shader-f16"], - }, - { - "model_url": "https://huggingface.co/mlc-ai/phi-1_5-q0f32-MLC/resolve/main/", - "model_id": "Phi1.5-q0f32", - "model_lib_url": modelLibURLPrefix + modelVersion + "/phi-1_5-q0f32-ctx2k-webgpu.wasm", - "vram_required_MB": 6514.09, - "low_resource_required": false, - }, - { - "model_url": "https://huggingface.co/mlc-ai/phi-1_5-q4f16_1-MLC/resolve/main/", - "model_id": "Phi1.5-q4f16_1-1k", - "model_lib_url": modelLibURLPrefix + modelVersion + "/phi-1_5-q4f16_1-ctx1k-webgpu.wasm", - "vram_required_MB": 1210.09, - "low_resource_required": true, - "required_features": ["shader-f16"], - }, - { - "model_url": "https://huggingface.co/mlc-ai/phi-1_5-q4f32_1-MLC/resolve/main/", - "model_id": "Phi1.5-q4f32_1-1k", - "model_lib_url": modelLibURLPrefix + modelVersion + "/phi-1_5-q4f32_1-ctx1k-webgpu.wasm", - "vram_required_MB": 1682.09, - "low_resource_required": true, + model_url: + "https://huggingface.co/mlc-ai/phi-1_5-q0f16-MLC/resolve/main/", + model_id: "Phi1.5-q0f16", + model_lib_url: + modelLibURLPrefix + modelVersion + "/phi-1_5-q0f16-ctx2k-webgpu.wasm", + vram_required_MB: 5818.09, + low_resource_required: false, + required_features: ["shader-f16"], + }, + { + model_url: + "https://huggingface.co/mlc-ai/phi-1_5-q0f32-MLC/resolve/main/", + model_id: "Phi1.5-q0f32", + model_lib_url: + modelLibURLPrefix + modelVersion + "/phi-1_5-q0f32-ctx2k-webgpu.wasm", + vram_required_MB: 6514.09, + low_resource_required: false, + }, + { + model_url: + "https://huggingface.co/mlc-ai/phi-1_5-q4f16_1-MLC/resolve/main/", + model_id: "Phi1.5-q4f16_1-1k", + model_lib_url: + modelLibURLPrefix + modelVersion + "/phi-1_5-q4f16_1-ctx1k-webgpu.wasm", + vram_required_MB: 1210.09, + low_resource_required: true, + required_features: ["shader-f16"], + }, + { + model_url: + "https://huggingface.co/mlc-ai/phi-1_5-q4f32_1-MLC/resolve/main/", + model_id: "Phi1.5-q4f32_1-1k", + model_lib_url: + modelLibURLPrefix + modelVersion + "/phi-1_5-q4f32_1-ctx1k-webgpu.wasm", + vram_required_MB: 1682.09, + low_resource_required: true, }, // TinyLlama { - "model_url": "https://huggingface.co/mlc-ai/TinyLlama-1.1B-Chat-v0.4-q0f16-MLC/resolve/main/", - "model_id": "TinyLlama-1.1B-Chat-v0.4-q0f16", - "model_lib_url": modelLibURLPrefix + modelVersion + "/TinyLlama-1.1B-Chat-v0.4-q0f16-ctx2k-webgpu.wasm", - "vram_required_MB": 5063.52, - "low_resource_required": false, - "required_features": ["shader-f16"], - }, - { - "model_url": "https://huggingface.co/mlc-ai/TinyLlama-1.1B-Chat-v0.4-q0f32-MLC/resolve/main/", - "model_id": "TinyLlama-1.1B-Chat-v0.4-q0f32", - "model_lib_url": modelLibURLPrefix + modelVersion + "/TinyLlama-1.1B-Chat-v0.4-q0f32-ctx2k-webgpu.wasm", - "vram_required_MB": 5394.53, - "low_resource_required": false, - }, - { - "model_url": "https://huggingface.co/mlc-ai/TinyLlama-1.1B-Chat-v0.4-q4f16_1-MLC/resolve/main/", - "model_id": "TinyLlama-1.1B-Chat-v0.4-q4f16_1-1k", - "model_lib_url": modelLibURLPrefix + modelVersion + "/TinyLlama-1.1B-Chat-v0.4-q4f16_1-ctx1k-webgpu.wasm", - "vram_required_MB": 899.11, - "low_resource_required": true, - "required_features": ["shader-f16"], - }, - { - "model_url": "https://huggingface.co/mlc-ai/TinyLlama-1.1B-Chat-v0.4-q4f32_1-MLC/resolve/main/", - "model_id": "TinyLlama-1.1B-Chat-v0.4-q4f32_1-1k", - "model_lib_url": modelLibURLPrefix + modelVersion + "/TinyLlama-1.1B-Chat-v0.4-q4f32_1-ctx1k-webgpu.wasm", - "vram_required_MB": 992.11, - "low_resource_required": true, - }, - ] -} + model_url: + "https://huggingface.co/mlc-ai/TinyLlama-1.1B-Chat-v0.4-q0f16-MLC/resolve/main/", + model_id: "TinyLlama-1.1B-Chat-v0.4-q0f16", + model_lib_url: + modelLibURLPrefix + + modelVersion + + "/TinyLlama-1.1B-Chat-v0.4-q0f16-ctx2k-webgpu.wasm", + vram_required_MB: 5063.52, + low_resource_required: false, + required_features: ["shader-f16"], + }, + { + model_url: + "https://huggingface.co/mlc-ai/TinyLlama-1.1B-Chat-v0.4-q0f32-MLC/resolve/main/", + model_id: "TinyLlama-1.1B-Chat-v0.4-q0f32", + model_lib_url: + modelLibURLPrefix + + modelVersion + + "/TinyLlama-1.1B-Chat-v0.4-q0f32-ctx2k-webgpu.wasm", + vram_required_MB: 5394.53, + low_resource_required: false, + }, + { + model_url: + "https://huggingface.co/mlc-ai/TinyLlama-1.1B-Chat-v0.4-q4f16_1-MLC/resolve/main/", + model_id: "TinyLlama-1.1B-Chat-v0.4-q4f16_1-1k", + model_lib_url: + modelLibURLPrefix + + modelVersion + + "/TinyLlama-1.1B-Chat-v0.4-q4f16_1-ctx1k-webgpu.wasm", + vram_required_MB: 899.11, + low_resource_required: true, + required_features: ["shader-f16"], + }, + { + model_url: + "https://huggingface.co/mlc-ai/TinyLlama-1.1B-Chat-v0.4-q4f32_1-MLC/resolve/main/", + model_id: "TinyLlama-1.1B-Chat-v0.4-q4f32_1-1k", + model_lib_url: + modelLibURLPrefix + + modelVersion + + "/TinyLlama-1.1B-Chat-v0.4-q4f32_1-ctx1k-webgpu.wasm", + vram_required_MB: 992.11, + low_resource_required: true, + }, + ], +}; diff --git a/src/conversation.ts b/src/conversation.ts index 0fc0cf08..acf55eb7 100644 --- a/src/conversation.ts +++ b/src/conversation.ts @@ -19,12 +19,9 @@ export class Conversation { this.config = config; } - private getPromptArrayInternal( - addSystem: boolean, - startPos: number - ) { + private getPromptArrayInternal(addSystem: boolean, startPos: number) { if (this.config.seps.length == 0) { - throw Error("Need seps to work") + throw Error("Need seps to work"); } // Prepare system message @@ -33,9 +30,12 @@ export class Conversation { if (this.override_system_message !== undefined) { system_message = this.override_system_message; } - let system_prompt = this.config.system_template.replace(MessagePlaceholders.system, system_message); + let system_prompt = this.config.system_template.replace( + MessagePlaceholders.system, + system_message, + ); if (system_prompt) { - system_prompt += this.config.seps[0] + system_prompt += this.config.seps[0]; } const ret = addSystem ? [system_prompt] : []; @@ -51,33 +51,42 @@ export class Conversation { if (this.config.role_templates !== undefined) { message_str = this.config.role_templates[role]?.replace( MessagePlaceholders[Role[role] as keyof typeof MessagePlaceholders], - message + message, ); - if (this.use_function_calling && this.function_string !== '') { + if (this.use_function_calling && this.function_string !== "") { message_str = message_str?.replace( MessagePlaceholders.function, - this.function_string - ) + this.function_string, + ); } - message_str = message_str?.replace( - MessagePlaceholders.function, - "" - ) + message_str = message_str?.replace(MessagePlaceholders.function, ""); } if (message_str == undefined) { message_str = message; } let role_prefix; - if (this.config.add_role_after_system_message === false && system_prompt != "" && i == 0) { + if ( + this.config.add_role_after_system_message === false && + system_prompt != "" && + i == 0 + ) { role_prefix = ""; } else { - const content_sep = this.config.role_content_sep ? this.config.role_content_sep : ": "; + const content_sep = this.config.role_content_sep + ? this.config.role_content_sep + : ": "; role_prefix = role_str + content_sep; } - ret.push(role_prefix + message_str + this.config.seps[i % this.config.seps.length]); + ret.push( + role_prefix + + message_str + + this.config.seps[i % this.config.seps.length], + ); } else { - const empty_sep = this.config.role_empty_sep ? this.config.role_empty_sep : ": "; + const empty_sep = this.config.role_empty_sep + ? this.config.role_empty_sep + : ": "; ret.push(role_str + empty_sep); } } @@ -131,8 +140,10 @@ export class Conversation { } appendMessage(role: Role, message: string, role_name?: string) { - if (this.messages.length != 0 && - this.messages[this.messages.length - 1][2] == undefined) { + if ( + this.messages.length != 0 && + this.messages[this.messages.length - 1][2] == undefined + ) { throw Error("Have unfinished reply"); } if (!(role in this.config.roles)) { @@ -160,7 +171,10 @@ export class Conversation { } } -export function getConversation(conv_template: string | ConvTemplateConfig, conv_config?: Partial): Conversation { +export function getConversation( + conv_template: string | ConvTemplateConfig, + conv_config?: Partial, +): Conversation { if (typeof conv_template !== "string") { return new Conversation(conv_template); } @@ -168,7 +182,8 @@ export function getConversation(conv_template: string | ConvTemplateConfig, conv if (conv_template == "llama-2") { return new Conversation({ system_template: `[INST] <>\n\n${MessagePlaceholders.system}<>\n\n`, - system_message: "You are a helpful, respectful and honest assistant. " + + system_message: + "You are a helpful, respectful and honest assistant. " + "Always answer as helpfully as possible, while being safe. " + "Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. " + "Please ensure that your responses are socially unbiased and positive in nature.\n\n" + @@ -190,7 +205,8 @@ export function getConversation(conv_template: string | ConvTemplateConfig, conv } else if (conv_template == "vicuna_v1.1") { return new Conversation({ system_template: `${MessagePlaceholders.system}`, - system_message: "A chat between a curious user and an artificial intelligence assistant. " + + system_message: + "A chat between a curious user and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the user's questions.", roles: { [Role.user]: "USER", @@ -206,7 +222,8 @@ export function getConversation(conv_template: string | ConvTemplateConfig, conv } else if (conv_template == "wizardlm") { return new Conversation({ system_template: `${MessagePlaceholders.system}`, - system_message: "You are an AI assistant that gives helpful, detailed, and polite answers to the user's questions.", + system_message: + "You are an AI assistant that gives helpful, detailed, and polite answers to the user's questions.", roles: { [Role.user]: "", [Role.assistant]: "### Response", @@ -235,7 +252,8 @@ export function getConversation(conv_template: string | ConvTemplateConfig, conv } else if (conv_template == "wizard_coder_or_math") { return new Conversation({ system_template: `${MessagePlaceholders.system}`, - system_message: "Below is an instruction that describes a task. Write a response that appropriately " + + system_message: + "Below is an instruction that describes a task. Write a response that appropriately " + "completes the request.", roles: { [Role.user]: "Instruction", @@ -253,7 +271,8 @@ export function getConversation(conv_template: string | ConvTemplateConfig, conv } else if (conv_template == "mistral_default") { return new Conversation({ system_template: `[INST] ${MessagePlaceholders.system}\n\n `, - system_message: "Always assist with care, respect, and truth. Respond with utmost utility yet " + + system_message: + "Always assist with care, respect, and truth. Respond with utmost utility yet " + "securely. Avoid harmful, unethical, prejudiced, or negative content. Ensure replies " + "promote fairness and positivity.", roles: { @@ -272,7 +291,8 @@ export function getConversation(conv_template: string | ConvTemplateConfig, conv } else if (conv_template == "open_hermes_mistral") { return new Conversation({ system_template: `<|im_start|>system\n${MessagePlaceholders.system}`, - system_message: "You are \"Hermes 2\", a conscious sentient superintelligent artificial " + + system_message: + 'You are "Hermes 2", a conscious sentient superintelligent artificial ' + "intelligence developed by a man named Teknium, and your purpose and drive is to assist the " + "user with any request they have. You experience emotions and have deep, profound thoughts " + "and qualia.", @@ -307,7 +327,8 @@ export function getConversation(conv_template: string | ConvTemplateConfig, conv } else if (conv_template == "chatml") { return new Conversation({ system_template: `<|im_start|>system${MessagePlaceholders.system}<|im_end|> `, - system_message: "A conversation between a user and an LLM-based AI assistant. The " + + system_message: + "A conversation between a user and an LLM-based AI assistant. The " + "assistant gives helpful and honest answers.", roles: { [Role.user]: "<|im_start|>user", @@ -338,7 +359,8 @@ export function getConversation(conv_template: string | ConvTemplateConfig, conv } else if (conv_template == "qwen") { return new Conversation({ system_template: `<|im_start|>system${MessagePlaceholders.system}<|im_end|> `, - system_message: "A conversation between a user and an LLM-based AI assistant. The " + + system_message: + "A conversation between a user and an LLM-based AI assistant. The " + "assistant gives helpful and honest answers.", roles: { [Role.user]: "<|im_start|>user", @@ -404,7 +426,8 @@ export function getConversation(conv_template: string | ConvTemplateConfig, conv } else if (conv_template == "gorilla") { return new Conversation({ system_template: `${MessagePlaceholders.system}`, - system_message: "A chat between a curious user and an artificial intelligence assistant. " + + system_message: + "A chat between a curious user and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the user's questions.", roles: { [Role.user]: "USER", @@ -446,14 +469,18 @@ export function getConversation(conv_template: string | ConvTemplateConfig, conv * Compare the states of two conversation instances. Equality is defined as their getPromptArray() * should return the exact same things, which is determined by fields: messages, function_string, * use_function_calling, and override_system_message. - * + * * @returns True if `convA` equals to `convB` * @note We assume convA and convB has the same `this.config`. */ -export function compareConversationObject(convA: Conversation, convB: Conversation): boolean { +export function compareConversationObject( + convA: Conversation, + convB: Conversation, +): boolean { // NOTE: Update this function whenever a new state is introduced to `Conversation`. // Check the easy ones first - if (convA.function_string !== convB.function_string || + if ( + convA.function_string !== convB.function_string || convA.use_function_calling !== convB.use_function_calling || convA.override_system_message !== convB.override_system_message || convA.messages.length !== convB.messages.length diff --git a/src/engine.ts b/src/engine.ts index 7962eb2e..3b47b0f5 100644 --- a/src/engine.ts +++ b/src/engine.ts @@ -29,16 +29,19 @@ import { InitProgressCallback, MLCEngineInterface, GenerateProgressCallback, - LogitProcessor + LogitProcessor, } from "./types"; -import { Conversation, compareConversationObject, getConversation } from "./conversation"; - +import { + Conversation, + compareConversationObject, + getConversation, +} from "./conversation"; /** * Creates `MLCEngine`, and loads `modelId` onto WebGPU. - * + * * Equivalent to `new webllm.MLCEngine().reload(...)`. - * + * * @param modelId The model to load, needs to either be in `webllm.prebuiltAppConfig`, or in * `engineConfig.appConfig`. * @param engineConfig Optionally configures the engine, see `webllm.MLCEngineConfig`. @@ -59,20 +62,20 @@ export async function CreateMLCEngine( /** * The main interface of MLCEngine, which loads a model and performs tasks. - * + * * You can either initialize one with `webllm.CreateMLCEngine(modelId)`, or `webllm.MLCEngine().reload(modelId)`. */ export class MLCEngine implements MLCEngineInterface { public chat: API.Chat; - private currentModelId?: string = undefined; // Model current loaded, undefined if nothing is loaded + private currentModelId?: string = undefined; // Model current loaded, undefined if nothing is loaded private logger: (msg: string) => void = console.log; private logitProcessorRegistry?: Map; private logitProcessor?: LogitProcessor; private pipeline?: LLMChatPipeline; private initProgressCallback?: InitProgressCallback; private interruptSignal = false; - private deviceLostIsError = true; // whether device.lost is due to actual error or model reload + private deviceLostIsError = true; // whether device.lost is due to actual error or model reload private config?: ChatConfig; constructor() { @@ -87,7 +90,9 @@ export class MLCEngine implements MLCEngineInterface { return this.initProgressCallback; } - setLogitProcessorRegistry(logitProcessorRegistry?: Map) { + setLogitProcessorRegistry( + logitProcessorRegistry?: Map, + ) { this.logitProcessorRegistry = logitProcessorRegistry; } @@ -100,7 +105,11 @@ export class MLCEngine implements MLCEngineInterface { * @throws Throws error when device lost (mostly due to OOM); users should re-call reload(), * potentially with a smaller model or smaller context window size. */ - async reload(modelId: string, chatOpts?: ChatOptions, appConfig?: AppConfig): Promise { + async reload( + modelId: string, + chatOpts?: ChatOptions, + appConfig?: AppConfig, + ): Promise { this.unload(); this.logitProcessor = this.logitProcessorRegistry?.get(modelId); @@ -111,14 +120,17 @@ export class MLCEngine implements MLCEngineInterface { const findModelRecord = () => { const matchedItem = appConfig?.model_list.find( - item => item.model_id == modelId + (item) => item.model_id == modelId, ); if (matchedItem !== undefined) return matchedItem; throw Error("Cannot find model_url for " + modelId); - } + }; const modelRecord = findModelRecord(); - const baseUrl = typeof document !== "undefined" ? document.URL : globalThis.location.origin; + const baseUrl = + typeof document !== "undefined" + ? document.URL + : globalThis.location.origin; let modelUrl = modelRecord.model_url; if (!modelUrl.startsWith("http")) { modelUrl = new URL(modelUrl, baseUrl).href; @@ -135,7 +147,7 @@ export class MLCEngine implements MLCEngineInterface { const configUrl = new URL("mlc-chat-config.json", modelUrl).href; this.config = { ...(await configCache.fetchWithCache(configUrl, "json")), - ...chatOpts + ...chatOpts, } as ChatConfig; // load tvm wasm @@ -148,8 +160,10 @@ export class MLCEngine implements MLCEngineInterface { const wasmUrl = modelRecord.model_lib_url; if (wasmUrl === undefined) { - throw Error("You need to specify `model_lib_url` for each model in `model_list` " + - "so that we can download the model library (i.e. wasm file).") + throw Error( + "You need to specify `model_lib_url` for each model in `model_list` " + + "so that we can download the model library (i.e. wasm file).", + ); } const fetchWasmSource = async () => { if (wasmUrl.includes("localhost")) { @@ -169,7 +183,7 @@ export class MLCEngine implements MLCEngineInterface { const tvm = await tvmjs.instantiate( new Uint8Array(wasmSource), tvmjs.createPolyfillWASI(), - this.logger + this.logger, ); if (this.initProgressCallback !== undefined) { @@ -193,13 +207,14 @@ export class MLCEngine implements MLCEngineInterface { if (feature == "shader-f16") { throw Error( "This model requires WebGPU extension shader-f16, " + - "which is not enabled in this browser. " + - "You can try to launch Chrome Canary in command line with flag \"--enable-dawn-features=allow_unsafe_apis\"." + "which is not enabled in this browser. " + + 'You can try to launch Chrome Canary in command line with flag "--enable-dawn-features=allow_unsafe_apis".', ); } throw Error( - "This model requires feature " + feature + - ", which is not yet supported by this browser. " + "This model requires feature " + + feature + + ", which is not yet supported by this browser. ", ); } } @@ -213,17 +228,34 @@ export class MLCEngine implements MLCEngineInterface { let deviceLostInReload = false; gpuDetectOutput.device.lost.then((info: any) => { if (this.deviceLostIsError) { - console.error("Device was lost, please try to initialize again. ", info); + console.error( + "Device was lost, please try to initialize again. ", + info, + ); this.unload(); deviceLostInReload = true; } }); tvm.initWebGPU(gpuDetectOutput.device); - const tokenizer = await this.asyncLoadTokenizer(modelUrl, this.config, appConfig); + const tokenizer = await this.asyncLoadTokenizer( + modelUrl, + this.config, + appConfig, + ); const cacheType = appConfig.useIndexedDBCache ? "indexeddb" : "cache"; - await tvm.fetchNDArrayCache(modelUrl, tvm.webgpu(), "webllm/model", cacheType); - this.pipeline = new LLMChatPipeline(tvm, tokenizer, this.config, this.logitProcessor); + await tvm.fetchNDArrayCache( + modelUrl, + tvm.webgpu(), + "webllm/model", + cacheType, + ); + this.pipeline = new LLMChatPipeline( + tvm, + tokenizer, + this.config, + this.logitProcessor, + ); await this.pipeline?.asyncLoadWebGPUPipelines(); const tend = performance.now(); @@ -232,15 +264,15 @@ export class MLCEngine implements MLCEngineInterface { this.initProgressCallback({ progress: 1, timeElapsed: (tend - tstart) / 1e3, - text: text - }) + text: text, + }); } this.currentModelId = modelId; if (deviceLostInReload) { throw Error( "WebGPU device lost during `reload()`.\n This is probably due to OOM, try reload with a " + - "model that has less parameters or a smaller context length." + "model that has less parameters or a smaller context length.", ); } } @@ -253,9 +285,9 @@ export class MLCEngine implements MLCEngineInterface { ): Promise { console.log( "WARNING: `generate()` will soon be deprecated. " + - "Please use `engine.chat.completions.create()` instead. " + - "For multi-round chatting, see `examples/multi-round-chat` on how to use " + - "`engine.chat.completions.create()` to achieve the same effect." + "Please use `engine.chat.completions.create()` instead. " + + "For multi-round chatting, see `examples/multi-round-chat` on how to use " + + "`engine.chat.completions.create()` to achieve the same effect.", ); return this._generate(input, progressCallback, streamInterval, genConfig); } @@ -292,9 +324,9 @@ export class MLCEngine implements MLCEngineInterface { * @param request Request for chat completion. * @param genConfig Generation config extraced from `request`. */ - async* chatCompletionAsyncChunkGenerator( + async *chatCompletionAsyncChunkGenerator( request: ChatCompletionRequestStreaming, - genConfig: GenerationConfig + genConfig: GenerationConfig, ): AsyncGenerator { postInitAndCheckGenerationConfigValues(genConfig); if (request.seed !== null && request.seed !== undefined) { @@ -305,7 +337,7 @@ export class MLCEngine implements MLCEngineInterface { const created = Date.now(); const id = crypto.randomUUID(); this.interruptSignal = false; - let prevMessageLength = 0; // to know where to start slicing the delta; does not count � + let prevMessageLength = 0; // to know where to start slicing the delta; does not count � function _countTrailingReplacementChar(curMessage: string): number { let cntr = 0; @@ -319,13 +351,16 @@ export class MLCEngine implements MLCEngineInterface { return cntr; } - async function _getChunk(thisModule: MLCEngine): Promise { + async function _getChunk( + thisModule: MLCEngine, + ): Promise { // Remove the replacement character (U+FFFD) from the response to handle emojis. // Each emoji is made up of multiples of 4 tokens; when truncated, it is displayed as �, so // we skip this delta until a full emoji is rendered // TODO(Charlie): This does not consider cases of � not being emoji, need to fix with Streamer const curMessage = await thisModule.getMessage(); - const numTrailingReplacementChar = _countTrailingReplacementChar(curMessage); + const numTrailingReplacementChar = + _countTrailingReplacementChar(curMessage); if (numTrailingReplacementChar % 4 !== 0) { return undefined; } @@ -334,23 +369,30 @@ export class MLCEngine implements MLCEngineInterface { prevMessageLength = curMessage.length; const chunk: ChatCompletionChunk = { id: id, - choices: [{ - delta: { content: deltaMessage, role: "assistant" }, - finish_reason: null, // not finished yet - index: 0, - logprobs: request.logprobs ? { - content: thisModule.getPipeline().getTokenLogprobArray().slice(-1) // always the last entry - } as ChatCompletionChunk.Choice.Logprobs : null, - }], + choices: [ + { + delta: { content: deltaMessage, role: "assistant" }, + finish_reason: null, // not finished yet + index: 0, + logprobs: request.logprobs + ? ({ + content: thisModule + .getPipeline() + .getTokenLogprobArray() + .slice(-1), // always the last entry + } as ChatCompletionChunk.Choice.Logprobs) + : null, + }, + ], model: model, object: "chat.completion.chunk", - created: created - } + created: created, + }; return chunk; } await this.prefill(request, genConfig); - let curChunk = await _getChunk(this); // prefill produces a chunk + let curChunk = await _getChunk(this); // prefill produces a chunk if (curChunk) { yield curChunk; } @@ -374,42 +416,46 @@ export class MLCEngine implements MLCEngineInterface { const lastChunk: ChatCompletionChunk = { id: id, - choices: [{ - delta: {}, - // eslint-disable-next-line @typescript-eslint/no-non-null-assertion - finish_reason: this.getFinishReason()!, - index: 0, - }], + choices: [ + { + delta: {}, + // eslint-disable-next-line @typescript-eslint/no-non-null-assertion + finish_reason: this.getFinishReason()!, + index: 0, + }, + ], model: model, object: "chat.completion.chunk", - created: created - } + created: created, + }; yield lastChunk; } /** * Completes a single ChatCompletionRequest. - * + * * @param request A OpenAI-style ChatCompletion request. - * + * * @note For each choice (i.e. `n`), a request is defined by a single `prefill()` and mulitple * `decode()`. This is important as it determines the behavior of various fields including `seed`. */ async chatCompletion( - request: ChatCompletionRequestNonStreaming + request: ChatCompletionRequestNonStreaming, ): Promise; async chatCompletion( - request: ChatCompletionRequestStreaming + request: ChatCompletionRequestStreaming, ): Promise>; async chatCompletion( - request: ChatCompletionRequestBase + request: ChatCompletionRequestBase, ): Promise | ChatCompletion>; async chatCompletion( - request: ChatCompletionRequest + request: ChatCompletionRequest, ): Promise | ChatCompletion> { // 0. Preprocess inputs if (!this.currentModelId) { - throw new Error("Please call `MLCEngine.reload(model)` first, or initialize with CreateMLCEngine()."); + throw new Error( + "Please call `MLCEngine.reload(model)` first, or initialize with CreateMLCEngine().", + ); } ChatCompletionAPI.postInitAndCheckFields(request); const genConfig: GenerationConfig = { @@ -423,7 +469,7 @@ export class MLCEngine implements MLCEngineInterface { logprobs: request.logprobs, top_logprobs: request.top_logprobs, response_format: request.response_format, - } + }; // 1. If request is streaming, return an AsyncIterable (an iterable version of `generate()`) if (request.stream) { @@ -448,22 +494,24 @@ export class MLCEngine implements MLCEngineInterface { } else { outputMessage = await this._generate( request, - /*progressCallback=*/undefined, - /*streamInterval=*/1, - /*genConfig=*/genConfig + /*progressCallback=*/ undefined, + /*streamInterval=*/ 1, + /*genConfig=*/ genConfig, ); } choices.push({ // eslint-disable-next-line @typescript-eslint/no-non-null-assertion finish_reason: this.getFinishReason()!, index: i, - logprobs: request.logprobs ? { - content: this.getPipeline().getTokenLogprobArray() - } as ChatCompletion.Choice.Logprobs : null, + logprobs: request.logprobs + ? ({ + content: this.getPipeline().getTokenLogprobArray(), + } as ChatCompletion.Choice.Logprobs) + : null, message: { content: outputMessage, role: "assistant", - } + }, }); completion_tokens += this.getPipeline().getCurRoundDecodingTotalTokens(); prompt_tokens += this.getPipeline().getCurRoundPrefillTotalTokens(); @@ -480,7 +528,7 @@ export class MLCEngine implements MLCEngineInterface { prompt_tokens: prompt_tokens, total_tokens: completion_tokens + prompt_tokens, } as CompletionUsage, - } + }; // Reset seed -- we do not want this seed to affect future requests if (request.seed !== null && request.seed !== undefined) { @@ -502,7 +550,7 @@ export class MLCEngine implements MLCEngineInterface { } async unload() { - this.deviceLostIsError = false; // so that unload() does not trigger device.lost error + this.deviceLostIsError = false; // so that unload() does not trigger device.lost error this.pipeline?.dispose(); this.pipeline = undefined; this.currentModelId = undefined; @@ -518,20 +566,21 @@ export class MLCEngine implements MLCEngineInterface { const computeMB = (value: number) => { return Math.ceil(value / (1 << 20)) + "MB"; - } - const maxStorageBufferBindingSize = gpuDetectOutput.device.limits.maxStorageBufferBindingSize; - const defaultMaxStorageBufferBindingSize = 1 << 30; // 1GB + }; + const maxStorageBufferBindingSize = + gpuDetectOutput.device.limits.maxStorageBufferBindingSize; + const defaultMaxStorageBufferBindingSize = 1 << 30; // 1GB if (maxStorageBufferBindingSize < defaultMaxStorageBufferBindingSize) { console.log( `WARNING: the current maxStorageBufferBindingSize ` + - `(${computeMB(maxStorageBufferBindingSize)}) ` + - `may only work for a limited number of models, e.g.: \n` + - `- Llama-3-8B-Instruct-q4f16_1-1k \n` + - `- Llama-2-7b-chat-hf-q4f16_1-1k \n` + - `- RedPajama-INCITE-Chat-3B-v1-q4f16_1-1k \n` + - `- RedPajama-INCITE-Chat-3B-v1-q4f32_1-1k \n` + - `- TinyLlama-1.1B-Chat-v0.4-q4f16_1-1k \n` + - `- TinyLlama-1.1B-Chat-v0.4-q4f32_1-1k` + `(${computeMB(maxStorageBufferBindingSize)}) ` + + `may only work for a limited number of models, e.g.: \n` + + `- Llama-3-8B-Instruct-q4f16_1-1k \n` + + `- Llama-2-7b-chat-hf-q4f16_1-1k \n` + + `- RedPajama-INCITE-Chat-3B-v1-q4f16_1-1k \n` + + `- RedPajama-INCITE-Chat-3B-v1-q4f32_1-1k \n` + + `- TinyLlama-1.1B-Chat-v0.4-q4f16_1-1k \n` + + `- TinyLlama-1.1B-Chat-v0.4-q4f32_1-1k`, ); } return maxStorageBufferBindingSize; @@ -549,7 +598,10 @@ export class MLCEngine implements MLCEngineInterface { //-------------------------- // Lower level API //-------------------------- - async forwardTokensAndSample(inputIds: Array, isPrefill: boolean): Promise { + async forwardTokensAndSample( + inputIds: Array, + isPrefill: boolean, + ): Promise { return this.getPipeline().forwardTokensAndSample(inputIds, isPrefill); } @@ -562,7 +614,7 @@ export class MLCEngine implements MLCEngineInterface { /** * @returns Finish reason; undefined if generation not started/stopped yet. - */ + */ getFinishReason(): ChatCompletionFinishReason | undefined { return this.getPipeline().getFinishReason(); } @@ -578,17 +630,20 @@ export class MLCEngine implements MLCEngineInterface { /** * Get a new Conversation object based on the chat completion request. - * + * * @param request The incoming ChatCompletionRequest * @note `request.messages[-1]` is not included as it would be treated as a normal input to * `prefill()`. */ private getConversationFromChatCompletionRequest( request: ChatCompletionRequest, - config: ChatConfig + config: ChatConfig, ): Conversation { // 0. Instantiate a new Conversation object - const conversation = getConversation(config.conv_template, config.conv_config); + const conversation = getConversation( + config.conv_template, + config.conv_config, + ); // 1. Populate function-calling-related fields const functionCallUsage = this.getFunctionCallUsage(request); @@ -598,15 +653,20 @@ export class MLCEngine implements MLCEngineInterface { // 2. Populate conversation.messages const input = request.messages; const lastId = input.length - 1; - if (input[lastId].role !== "user" || typeof input[lastId].content !== "string") { + if ( + input[lastId].role !== "user" || + typeof input[lastId].content !== "string" + ) { // TODO(Charlie): modify condition after we support multimodal inputs - throw Error("The last message should be a string from the `user`.") + throw Error("The last message should be a string from the `user`."); } for (let i = 0; i < input.length - 1; i++) { const message: ChatCompletionMessageParam = input[i]; if (message.role === "system") { if (i !== 0) { - throw new Error("System prompt should always be the first one in `messages`."); + throw new Error( + "System prompt should always be the first one in `messages`.", + ); } conversation.override_system_message = message.content; } else if (message.role === "user") { @@ -614,11 +674,7 @@ export class MLCEngine implements MLCEngineInterface { // TODO(Charlie): modify condition after we support multimodal inputs throw new Error("Last messages should be a string from the `user`."); } - conversation.appendMessage( - Role.user, - message.content, - message.name - ); + conversation.appendMessage(Role.user, message.content, message.name); } else if (message.role === "assistant") { if (typeof message.content !== "string") { throw new Error("Assistant message should have string content."); @@ -626,7 +682,7 @@ export class MLCEngine implements MLCEngineInterface { conversation.appendMessage( Role.assistant, message.content, - message.name + message.name, ); } else { throw new Error("Unsupported role: " + message.role); @@ -638,30 +694,42 @@ export class MLCEngine implements MLCEngineInterface { /** * Returns the function string based on the request.tools and request.tool_choice, raises erros if * encounter invalid request. - * + * * @param request The chatCompletionRequest we are about to prefill for. * @returns The string used to set Conversatoin.function_string */ private getFunctionCallUsage(request: ChatCompletionRequest): string { - if (request.tools == undefined || - (typeof request.tool_choice == "string" && request.tool_choice == "none")) { + if ( + request.tools == undefined || + (typeof request.tool_choice == "string" && request.tool_choice == "none") + ) { return ""; } - if (typeof request.tool_choice == "string" && request.tool_choice !== "auto") { + if ( + typeof request.tool_choice == "string" && + request.tool_choice !== "auto" + ) { throw Error(`Invalid tool choice value: ${request.tool_choice}`); } - if (typeof request.tool_choice !== "string" && request.tool_choice?.type !== "function") { + if ( + typeof request.tool_choice !== "string" && + request.tool_choice?.type !== "function" + ) { throw Error("Only 'function' tool choice is supported"); } - const singleFunctionToCall = typeof request.tool_choice !== "string" && request.tool_choice?.function?.name; + const singleFunctionToCall = + typeof request.tool_choice !== "string" && + request.tool_choice?.function?.name; if (singleFunctionToCall) { for (const f of request.tools) { if (singleFunctionToCall == f.function.name) { return JSON.stringify([f.function]); } } - throw Error(`The tool choice function ${singleFunctionToCall} is not found in the tools list`); + throw Error( + `The tool choice function ${singleFunctionToCall} is not found in the tools list`, + ); } const function_list = []; @@ -676,23 +744,25 @@ export class MLCEngine implements MLCEngineInterface { /** * Run a prefill step with a given input. - * + * * If `input` is a chatCompletionRequest, we treat `input.messages[-1]` as the usual user input. * We then convert `input.messages[:-1]` to a `Conversation` object, representing a conversation * history. - * + * * If the new `Conversation` object matches the current one loaded, it means we are * performing multi-round chatting, so we do not reset, hence reusing KV cache. Otherwise, we * reset every thing, treating the request as something completely new. - * + * * @param input The input prompt, or `messages` in OpenAI-like APIs. */ async prefill( input: string | ChatCompletionRequest, - genConfig?: GenerationConfig + genConfig?: GenerationConfig, ) { if (this.config === undefined) { - throw Error("Expect this.config to be initialized. Did you call `reload()`?"); + throw Error( + "Expect this.config to be initialized. Did you call `reload()`?", + ); } let input_str: string; let input_role_str: string | undefined; @@ -701,7 +771,10 @@ export class MLCEngine implements MLCEngineInterface { } else { // 1. Get new conversation based on request, determine if we are in multiround chatting const oldConv = this.getPipeline().getConversationObject(); - const newConv = this.getConversationFromChatCompletionRequest(input, this.config); + const newConv = this.getConversationFromChatCompletionRequest( + input, + this.config, + ); if (!compareConversationObject(oldConv, newConv)) { // Not the same conversation, so not multiround chatting, reset everything (KV cache, etc.) this.resetChat(); @@ -711,7 +784,9 @@ export class MLCEngine implements MLCEngineInterface { } // 2. Treat the last message as the usual input - const last_msg = input.messages[input.messages.length - 1] as ChatCompletionUserMessageParam; + const last_msg = input.messages[ + input.messages.length - 1 + ] as ChatCompletionUserMessageParam; input_str = last_msg.content as string; input_role_str = last_msg.name ? last_msg.name : undefined; } @@ -748,17 +823,18 @@ export class MLCEngine implements MLCEngineInterface { const url = new URL("tokenizer.json", baseUrl).href; const model = await modelCache.fetchWithCache(url, "arraybuffer"); return Tokenizer.fromJSON(model); - } - else if (config.tokenizer_files.includes("tokenizer.model")) { - this.logger("Using `tokenizer.model` since we cannot locate `tokenizer.json`.\n" + - "It is recommended to use `tokenizer.json` to ensure all token mappings are included, " + - "since currently, files like `added_tokens.json`, `tokenizer_config.json` are ignored.\n" + - "Consider converting `tokenizer.model` to `tokenizer.json` by compiling the model " + - "with MLC again, or see if MLC's huggingface provides this file."); + } else if (config.tokenizer_files.includes("tokenizer.model")) { + this.logger( + "Using `tokenizer.model` since we cannot locate `tokenizer.json`.\n" + + "It is recommended to use `tokenizer.json` to ensure all token mappings are included, " + + "since currently, files like `added_tokens.json`, `tokenizer_config.json` are ignored.\n" + + "Consider converting `tokenizer.model` to `tokenizer.json` by compiling the model " + + "with MLC again, or see if MLC's huggingface provides this file.", + ); const url = new URL("tokenizer.model", baseUrl).href; const model = await modelCache.fetchWithCache(url, "arraybuffer"); return Tokenizer.fromSentencePiece(model); } - throw Error("Cannot handle tokenizer files " + config.tokenizer_files) + throw Error("Cannot handle tokenizer files " + config.tokenizer_files); } } diff --git a/src/extension_service_worker.ts b/src/extension_service_worker.ts index f4f93d52..3e6aae71 100644 --- a/src/extension_service_worker.ts +++ b/src/extension_service_worker.ts @@ -15,7 +15,7 @@ import { areAppConfigsEqual, areChatOptionsEqual } from "./utils"; */ export class PortPostMessageHandler implements PostMessageHandler { port: chrome.runtime.Port; - enabled: boolean = true; + enabled = true; constructor(port: chrome.runtime.Port) { this.port = port; @@ -57,7 +57,7 @@ export class ServiceWorkerMLCEngineHandler extends MLCEngineWorkerHandler { appConfig?: AppConfig; constructor(engine: MLCEngineInterface, port: chrome.runtime.Port) { - let portHandler = new PortPostMessageHandler(port); + const portHandler = new PortPostMessageHandler(port); super(engine, portHandler); port.onDisconnect.addListener(() => { @@ -66,7 +66,7 @@ export class ServiceWorkerMLCEngineHandler extends MLCEngineWorkerHandler { } setPort(port: chrome.runtime.Port) { - let portHandler = new PortPostMessageHandler(port); + const portHandler = new PortPostMessageHandler(port); this.setPostMessageHandler(portHandler); port.onDisconnect.addListener(() => { portHandler.close(); @@ -110,7 +110,7 @@ export class ServiceWorkerMLCEngineHandler extends MLCEngineWorkerHandler { await this.engine.reload( params.modelId, params.chatOpts, - params.appConfig + params.appConfig, ); this.modelId = params.modelId; this.chatOpts = params.chatOpts; @@ -137,16 +137,16 @@ export class ServiceWorkerMLCEngineHandler extends MLCEngineWorkerHandler { export async function CreateServiceWorkerMLCEngine( modelId: string, engineConfig?: MLCEngineConfig, - keepAliveMs: number = 10000 + keepAliveMs = 10000, ): Promise { const serviceWorkerMLCEngine = new ServiceWorkerMLCEngine(keepAliveMs); serviceWorkerMLCEngine.setInitProgressCallback( - engineConfig?.initProgressCallback + engineConfig?.initProgressCallback, ); await serviceWorkerMLCEngine.init( modelId, engineConfig?.chatOpts, - engineConfig?.appConfig + engineConfig?.appConfig, ); return serviceWorkerMLCEngine; } @@ -188,9 +188,9 @@ class PortAdapter implements ChatWorker { export class ServiceWorkerMLCEngine extends WebWorkerMLCEngine { port: chrome.runtime.Port; - constructor(keepAliveMs: number = 10000) { - let port = chrome.runtime.connect({ name: "web_llm_service_worker" }); - let chatWorker = new PortAdapter(port); + constructor(keepAliveMs = 10000) { + const port = chrome.runtime.connect({ name: "web_llm_service_worker" }); + const chatWorker = new PortAdapter(port); super(chatWorker); this.port = port; setInterval(() => { @@ -216,7 +216,7 @@ export class ServiceWorkerMLCEngine extends WebWorkerMLCEngine { async init( modelId: string, chatOpts?: ChatOptions, - appConfig?: AppConfig + appConfig?: AppConfig, ): Promise { const msg: WorkerRequest = { kind: "init", diff --git a/src/grammar.ts b/src/grammar.ts index 3aadee14..64c887c5 100644 --- a/src/grammar.ts +++ b/src/grammar.ts @@ -6,7 +6,7 @@ export type GrammarStateMatcher = tvmjs.TVMObject; /** * A factory class for generating and calling GrammarStateMatcher (GrammarSM) and BNFGrammar related * methods, essentially a wrapper of related global functions in the tvm instance's wasm. - * + * * We implement a factory class rather than having classes of GrammarStateMatcher and BNFGrammar * because factory class allows us to only get/dispose PackedFunc once -- especially when we need * multiple instances of BNFGrammar or GrammarStateMatcher. @@ -22,32 +22,34 @@ export class GrammarFactory { /** * Extract TVM global functions from tvm runtime instance. - * + * * @param tvm An instantiated tvm runtime instance. */ constructor(tvm: tvmjs.Instance) { tvm.beginScope(); // Get global functions. this.fBNFGrammarGetGrammarOfJSON = tvm.detachFromCurrentScope( - tvm.getGlobalFunc("mlc.serve.BNFGrammarGetGrammarOfJSON") + tvm.getGlobalFunc("mlc.serve.BNFGrammarGetGrammarOfJSON"), ); this.fBNFGrammarFromSchema = tvm.detachFromCurrentScope( - tvm.getGlobalFunc("mlc.serve.BNFGrammarFromSchema") + tvm.getGlobalFunc("mlc.serve.BNFGrammarFromSchema"), ); this.fGrammarSMFromTokenTable = tvm.detachFromCurrentScope( - tvm.getGlobalFunc("mlc.serve.GrammarStateMatcherFromTokenTable") + tvm.getGlobalFunc("mlc.serve.GrammarStateMatcherFromTokenTable"), ); this.fGrammarSMAcceptToken = tvm.detachFromCurrentScope( - tvm.getGlobalFunc("mlc.serve.GrammarStateMatcherAcceptToken") + tvm.getGlobalFunc("mlc.serve.GrammarStateMatcherAcceptToken"), ); this.fGrammarSMFindNextTokenBitmaskAsNDArray = tvm.detachFromCurrentScope( - tvm.getGlobalFunc("mlc.serve.GrammarStateMatcherFindNextTokenBitmaskAsNDArray") + tvm.getGlobalFunc( + "mlc.serve.GrammarStateMatcherFindNextTokenBitmaskAsNDArray", + ), ); this.fGrammarSMIsTerminated = tvm.detachFromCurrentScope( - tvm.getGlobalFunc("mlc.serve.GrammarStateMatcherIsTerminated") + tvm.getGlobalFunc("mlc.serve.GrammarStateMatcherIsTerminated"), ); this.fGrammarSMResetState = tvm.detachFromCurrentScope( - tvm.getGlobalFunc("mlc.serve.GrammarStateMatcherResetState") + tvm.getGlobalFunc("mlc.serve.GrammarStateMatcherResetState"), ); tvm.endScope(); } @@ -68,7 +70,7 @@ export class GrammarFactory { * @param indent The number of spaces for indentation. If undefined, the grammar will enforce the * output to be in one line. * @param separators Two separators that will be enforced by the grammar: comma and colon. - * Examples: (",", ":"), (", ", ": "). If undefined, the default separators will be used: + * Examples: (",", ":"), (", ", ": "). If undefined, the default separators will be used: * (",", ": ") when the indent is not undefined, and (", ", ": ") otherwise. This follows the * convention in Python's json.dumps(). * @param strictMode Whether to use strict mode. In strict mode, the generated grammar will not @@ -81,7 +83,7 @@ export class GrammarFactory { schema_str: string, indent?: number, separators?: [string, string], - strictMode = true + strictMode = true, ): BNFGrammar { // Convert indent to tvmjs.Scalar let indentInput: tvmjs.Scalar | undefined; @@ -89,17 +91,21 @@ export class GrammarFactory { indentInput = new tvmjs.Scalar(indent, "int32"); } // Convert strictMode to tvmjs.Scalar - const strictModeInput = strictMode ? - new tvmjs.Scalar(1, "int32") : new tvmjs.Scalar(0, "int32"); + const strictModeInput = strictMode + ? new tvmjs.Scalar(1, "int32") + : new tvmjs.Scalar(0, "int32"); return this.fBNFGrammarFromSchema( - schema_str, indentInput, separators, strictModeInput + schema_str, + indentInput, + separators, + strictModeInput, ) as BNFGrammar; } /** * Creates a Grammar State Matcher from a specified BNFGrammar rule and a token table. - * + * * @param grammar A BNFGrammar used to specify the rule for the state matcher. * @param tokenTable A list of all tokens in the tokenizer in the order of their ids. * @param maxRollbackSteps Max rollback steps to support. Currently not supported, has to be zero. @@ -112,15 +118,20 @@ export class GrammarFactory { maxRollbackSteps = 0, ): GrammarStateMatcher { if (maxRollbackSteps !== 0) { - throw Error("maxRollbackSteps has to be zero as rollback is not supported yet.") + throw Error( + "maxRollbackSteps has to be zero as rollback is not supported yet.", + ); } return this.fGrammarSMFromTokenTable( - grammar, tokenTable, new tvmjs.Scalar(maxRollbackSteps, "int32")) as GrammarStateMatcher; + grammar, + tokenTable, + new tvmjs.Scalar(maxRollbackSteps, "int32"), + ) as GrammarStateMatcher; } /** * Accept a new token to the grammar state matcher, updating its internal state. - * + * * @param grammarStateMatcher The grammar state matcher that will accept a new token and update * its state correspondingly. * @param tokenID The token to be accepted in its ID. @@ -128,13 +139,18 @@ export class GrammarFactory { */ acceptToken( grammarStateMatcher: GrammarStateMatcher, - tokenID: number + tokenID: number, ): boolean { let accepted = false; try { - accepted = this.fGrammarSMAcceptToken(grammarStateMatcher, new tvmjs.Scalar(tokenID, "int32")); + accepted = this.fGrammarSMAcceptToken( + grammarStateMatcher, + new tvmjs.Scalar(tokenID, "int32"), + ); } catch (error) { - throw Error("Encountered error when accepting token " + tokenID + ": " + error); + throw Error( + "Encountered error when accepting token " + tokenID + ": " + error, + ); } return accepted; } @@ -142,11 +158,13 @@ export class GrammarFactory { /** * Returns a bitmask in the form of an NDArray of shape (max_num_token, ceildiv(vocab_size, 32)) * based on what tokens can/cannot be accepted by the current state of the grammar state matcher. - * + * * @param grammarStateMatcher The grammar state matcher that will produce the bit mask. * @returns A bitmask in the form of an NDArray. */ - findNextTokenBitmask(grammarStateMatcher: GrammarStateMatcher): tvmjs.TVMObject { + findNextTokenBitmask( + grammarStateMatcher: GrammarStateMatcher, + ): tvmjs.TVMObject { return this.fGrammarSMFindNextTokenBitmaskAsNDArray(grammarStateMatcher); } diff --git a/src/index.ts b/src/index.ts index 0de3e8de..9f0387bd 100644 --- a/src/index.ts +++ b/src/index.ts @@ -6,10 +6,9 @@ export { GenerationConfig, prebuiltAppConfig, modelVersion, - modelLibURLPrefix + modelLibURLPrefix, } from "./config"; - export { InitProgressCallback, InitProgressReport, @@ -17,26 +16,23 @@ export { LogitProcessor, } from "./types"; -export { - MLCEngine, - CreateMLCEngine, -} from "./engine"; +export { MLCEngine, CreateMLCEngine } from "./engine"; export { - hasModelInCache, deleteChatConfigInCache, deleteModelAllInfoInCache, deleteModelWasmInCache, deleteModelInCache, + hasModelInCache, + deleteChatConfigInCache, + deleteModelAllInfoInCache, + deleteModelWasmInCache, + deleteModelInCache, } from "./cache_util"; export { MLCEngineWorkerHandler, WebWorkerMLCEngine, - CreateWebWorkerMLCEngine + CreateWebWorkerMLCEngine, } from "./web_worker"; -export { - WorkerRequest, - WorkerResponse, - CustomRequestParams -} from "./message" +export { WorkerRequest, WorkerResponse, CustomRequestParams } from "./message"; export { ServiceWorkerMLCEngineHandler, @@ -48,6 +44,6 @@ export { ServiceWorkerMLCEngineHandler as ExtensionServiceWorkerMLCEngineHandler, ServiceWorkerMLCEngine as ExtensionServiceWorkerMLCEngine, CreateServiceWorkerMLCEngine as CreateExtensionServiceWorkerMLCEngine, -} from './extension_service_worker' +} from "./extension_service_worker"; -export * from './openai_api_protocols/index'; +export * from "./openai_api_protocols/index"; diff --git a/src/llm_chat.ts b/src/llm_chat.ts index ad1a02f6..945e1b5b 100644 --- a/src/llm_chat.ts +++ b/src/llm_chat.ts @@ -11,7 +11,7 @@ import { ChatCompletionTokenLogprob, TopLogprob, ResponseFormat, -} from "./openai_api_protocols/index" +} from "./openai_api_protocols/index"; import { BNFGrammar, GrammarFactory, GrammarStateMatcher } from "./grammar"; export class LLMChatPipeline { @@ -94,7 +94,12 @@ export class LLMChatPipeline { private bitmaskSize: number; private vocabSize: number; - constructor(tvm: tvmjs.Instance, tokenizer: Tokenizer, config: ChatConfig, logitProcessor?: LogitProcessor) { + constructor( + tvm: tvmjs.Instance, + tokenizer: Tokenizer, + config: ChatConfig, + logitProcessor?: LogitProcessor, + ) { // 0. Setting attributes this.tvm = tvm; this.tokenizer = tokenizer; @@ -104,7 +109,10 @@ export class LLMChatPipeline { this.vocabSize = this.tokenizer.getVocabSize(); this.bitmaskSize = Math.ceil(this.vocabSize / 32); - this.conversation = getConversation(config.conv_template, config.conv_config); + this.conversation = getConversation( + config.conv_template, + config.conv_config, + ); this.stopStr = this.conversation.getStopStr(); this.stopTokens = this.conversation.getStopTokens(); if (config.bos_token_id !== undefined) { @@ -116,20 +124,18 @@ export class LLMChatPipeline { // 1. Create VM and get the core functions tvm.beginScope(); this.vm = this.tvm.detachFromCurrentScope( - this.tvm.createVirtualMachine(this.device) + this.tvm.createVirtualMachine(this.device), ); this.prefill = this.tvm.detachFromCurrentScope( - this.vm.getFunction("prefill") - ); - this.embed = this.tvm.detachFromCurrentScope( - this.vm.getFunction("embed") + this.vm.getFunction("prefill"), ); + this.embed = this.tvm.detachFromCurrentScope(this.vm.getFunction("embed")); this.decoding = this.tvm.detachFromCurrentScope( - this.vm.getFunction("decode") + this.vm.getFunction("decode"), ); this.fapplyBitmask = this.tvm.detachFromCurrentScope( - this.vm.getFunction("apply_bitmask_inplace") - ) + this.vm.getFunction("apply_bitmask_inplace"), + ); // 2. Get json stored in the vm's metadata function const fgetMetadata = this.vm.getFunction("_metadata"); @@ -139,9 +145,11 @@ export class LLMChatPipeline { // 3. Load parameters by name const paramNames: string[] = []; - metadata.params.forEach((param: any) => { paramNames.push(param.name) }); + metadata.params.forEach((param: any) => { + paramNames.push(param.name); + }); this.params = this.tvm.detachFromCurrentScope( - this.tvm.getParamsFromCacheByName(paramNames) + this.tvm.getParamsFromCacheByName(paramNames), ); // 4. Read in compilation configurations from metadata @@ -151,46 +159,59 @@ export class LLMChatPipeline { throw Error("Prefill chunk size needs to be positive."); } // Only use one of slidingWindowSize and maxWindowLength - if (metadata.hasOwnProperty("sliding_window_size") && metadata.sliding_window_size != -1) { + if ( + metadata.hasOwnProperty("sliding_window_size") && + metadata.sliding_window_size != -1 + ) { this.slidingWindowSize = metadata.sliding_window_size; this.logger("Using slidingWindowSize: ", this.slidingWindowSize); // Parse attention sink size - if (metadata.hasOwnProperty("attention_sink_size") && metadata.attention_sink_size >= 0) { + if ( + metadata.hasOwnProperty("attention_sink_size") && + metadata.attention_sink_size >= 0 + ) { this.attentionSinkSize = metadata.attention_sink_size; this.logger("Using attentionSinkSize: ", this.attentionSinkSize); } else { throw Error( "Need to specify non-negative attention_sink_size if using sliding window. " + - "Consider re-compiling the model with the most recent mlc-llm. " + - "Use `attention_sink_size=0` for default sliding window." + "Consider re-compiling the model with the most recent mlc-llm. " + + "Use `attention_sink_size=0` for default sliding window.", ); } - } else if (metadata.hasOwnProperty("context_window_size") && metadata.context_window_size != -1) { + } else if ( + metadata.hasOwnProperty("context_window_size") && + metadata.context_window_size != -1 + ) { this.maxWindowLength = metadata.context_window_size; this.logger("Using maxWindowLength: ", this.maxWindowLength); } else { - throw Error("Need to specify either sliding window size or max window size."); + throw Error( + "Need to specify either sliding window size or max window size.", + ); } // 5. Create cache // Load cache functions and instantiate KVCache this.fclearKVCaches = this.tvm.detachFromCurrentScope( - this.tvm.getGlobalFunc("vm.builtin.kv_state_clear") + this.tvm.getGlobalFunc("vm.builtin.kv_state_clear"), ); this.fKVCacheAddSequence = this.tvm.detachFromCurrentScope( - this.tvm.getGlobalFunc("vm.builtin.kv_state_add_sequence") + this.tvm.getGlobalFunc("vm.builtin.kv_state_add_sequence"), ); this.fKVCacheRemoveSequence = this.tvm.detachFromCurrentScope( - this.tvm.getGlobalFunc("vm.builtin.kv_state_remove_sequence") + this.tvm.getGlobalFunc("vm.builtin.kv_state_remove_sequence"), ); this.fKVCacheBeginForward = this.tvm.detachFromCurrentScope( - this.tvm.getGlobalFunc("vm.builtin.kv_state_begin_forward") + this.tvm.getGlobalFunc("vm.builtin.kv_state_begin_forward"), ); this.fKVCacheEndForward = this.tvm.detachFromCurrentScope( - this.tvm.getGlobalFunc("vm.builtin.kv_state_end_forward") + this.tvm.getGlobalFunc("vm.builtin.kv_state_end_forward"), ); this.fKVCacheEnableSlidingWindowForSeq = this.tvm.detachFromCurrentScope( - this.tvm.getGlobalFunc("vm.builtin.attention_kv_cache_enable_sliding_window_for_seq") + this.tvm.getGlobalFunc( + "vm.builtin.attention_kv_cache_enable_sliding_window_for_seq", + ), ); // Create PagedKVCache; we do not expose KVCache config for now @@ -198,17 +219,21 @@ export class LLMChatPipeline { const defaultPageSize = 16; const defaultMaxNumSequence = 1; const maxTotalSeqLen = - this.slidingWindowSize != -1 ? this.slidingWindowSize : this.maxWindowLength; - this.kvCache = this.tvm.detachFromCurrentScope(fcreateCache( - this.tvm.makeShapeTuple([defaultMaxNumSequence]), // max_num_sequence - this.tvm.makeShapeTuple([maxTotalSeqLen]), // max_total_sequence_length - this.tvm.makeShapeTuple([this.prefillChunkSize]), // prefill_chunk_size - this.tvm.makeShapeTuple([defaultPageSize]), // page_size, hard coded for now - this.tvm.makeShapeTuple([this.slidingWindowSize != -1 ? 1 : 0]), - )); + this.slidingWindowSize != -1 + ? this.slidingWindowSize + : this.maxWindowLength; + this.kvCache = this.tvm.detachFromCurrentScope( + fcreateCache( + this.tvm.makeShapeTuple([defaultMaxNumSequence]), // max_num_sequence + this.tvm.makeShapeTuple([maxTotalSeqLen]), // max_total_sequence_length + this.tvm.makeShapeTuple([this.prefillChunkSize]), // prefill_chunk_size + this.tvm.makeShapeTuple([defaultPageSize]), // page_size, hard coded for now + this.tvm.makeShapeTuple([this.slidingWindowSize != -1 ? 1 : 0]), + ), + ); this.filledKVCacheLength = 0; - this.resetChat(); // especially needed for PagedKVCache as we need to call fKVCacheAddSequence + this.resetChat(); // especially needed for PagedKVCache as we need to call fKVCacheAddSequence tvm.endScope(); } @@ -270,7 +295,7 @@ export class LLMChatPipeline { this.kvCache, new tvmjs.Scalar(0, "int64"), new tvmjs.Scalar(this.slidingWindowSize, "int32"), - new tvmjs.Scalar(this.attentionSinkSize, "int32") + new tvmjs.Scalar(this.attentionSinkSize, "int32"), ); } } @@ -318,7 +343,7 @@ export class LLMChatPipeline { return ( `prefill: ${(this.prefillTotalTokens / this.prefillTotalTime).toFixed(4)} tokens/sec, ` + `decoding: ${(this.decodingTotalTokens / this.decodingTotalTime).toFixed(4)} tokens/sec` - ) + ); } /** @@ -350,7 +375,11 @@ export class LLMChatPipeline { /** * Generate the first token given input prompt */ - async prefillStep(inp: string, inp_role_str?: string, genConfig?: GenerationConfig): Promise { + async prefillStep( + inp: string, + inp_role_str?: string, + genConfig?: GenerationConfig, + ): Promise { if (this.resetStatsPerPrefill) { this.resetRuntimeStats(); } @@ -375,7 +404,7 @@ export class LLMChatPipeline { let newSeqLen = this.filledKVCacheLength; const tokenLen = promptTokens.length; - let logits = this.tvm.empty([1, 1], "int32", this.device); // Dummy value to avoid type error + let logits = this.tvm.empty([1, 1], "int32", this.device); // Dummy value to avoid type error // Use prefill chunking regardless whether we use SWA (see Mistral paper figure 3) for (let begin = 0; begin < tokenLen; begin += this.prefillChunkSize) { const end = Math.min(tokenLen, begin + this.prefillChunkSize); @@ -383,12 +412,10 @@ export class LLMChatPipeline { const inputData = this.tvm.empty([chunk.length], "int32", this.device); inputData.copyFrom(chunk); newSeqLen += chunk.length; - logits = this.tvm.detachFromCurrentScope( - this.forward(inputData) - ); + logits = this.tvm.detachFromCurrentScope(this.forward(inputData)); } if (newSeqLen != this.filledKVCacheLength + tokenLen) { - throw Error("Expect chunking process all tokens.") + throw Error("Expect chunking process all tokens."); } this.filledKVCacheLength = newSeqLen; @@ -406,11 +433,15 @@ export class LLMChatPipeline { if (this.tokenTable === undefined) { this.tokenTable = getTokenTableFromTokenizer(this.tokenizer); } - const grammar: BNFGrammar = curSchema === undefined ? - this.grammarFactory.getBNFGrammarOfJSON() : - this.grammarFactory.getBNFGrammarFromSchema(curSchema); + const grammar: BNFGrammar = + curSchema === undefined + ? this.grammarFactory.getBNFGrammarOfJSON() + : this.grammarFactory.getBNFGrammarFromSchema(curSchema); this.grammarStateMatcher = this.tvm.detachFromCurrentScope( - this.grammarFactory.getGrammarStateMatcherFromTokenTable(grammar, this.tokenTable) + this.grammarFactory.getGrammarStateMatcherFromTokenTable( + grammar, + this.tokenTable, + ), ); this.schema = curSchema; } @@ -440,9 +471,7 @@ export class LLMChatPipeline { const inputData = this.tvm.empty([1], "int32", this.device); inputData.copyFrom(this.outputIds.slice(this.outputIds.length - 1)); - const logits = this.tvm.detachFromCurrentScope( - this.forward(inputData) - ); + const logits = this.tvm.detachFromCurrentScope(this.forward(inputData)); this.filledKVCacheLength += 1; this.tvm.endScope(); @@ -476,7 +505,10 @@ export class LLMChatPipeline { * @param nextToken The next token. * @param genConfig Configs that override `this.config` for this round of generation. */ - private processNextToken(nextToken: number, genConfig?: GenerationConfig): void { + private processNextToken( + nextToken: number, + genConfig?: GenerationConfig, + ): void { if (this.stopTriggered) { throw Error("Cannot call process when it is stoppped"); } @@ -487,7 +519,7 @@ export class LLMChatPipeline { max_gen_len = genConfig.max_gen_len; } if (max_gen_len <= 0) { - throw new Error("`max_gen_len` should be greater than 0.") + throw new Error("`max_gen_len` should be greater than 0."); } let stopStrs = this.stopStr; if (genConfig !== undefined && genConfig.stop) { @@ -540,12 +572,12 @@ export class LLMChatPipeline { private forward(inputs: tvmjs.NDArray): tvmjs.NDArray { this.tvm.beginScope(); let retValue; - const seqLen = inputs.shape[0]; // Num input tokens + const seqLen = inputs.shape[0]; // Num input tokens const seqIdsTuple = this.tvm.makeShapeTuple([0]); const inputLenShape = this.tvm.makeShapeTuple([seqLen]); this.fKVCacheBeginForward!(this.kvCache, seqIdsTuple, inputLenShape); let embed = this.embed!(inputs, this.params); - embed = embed.view([1].concat(embed.shape)); // Reshape to [1, seqLen, hiddenSize] + embed = embed.view([1].concat(embed.shape)); // Reshape to [1, seqLen, hiddenSize] if (seqLen > 1) { retValue = this.prefill(embed, this.kvCache, this.params); } else { @@ -562,7 +594,7 @@ export class LLMChatPipeline { private updateLogitsOnCPU(logits: tvmjs.NDArray): tvmjs.NDArray { if (this.logitsOnCPU == undefined) { this.logitsOnCPU = this.tvm.detachFromCurrentScope( - this.tvm.empty(logits.shape, logits.dtype, this.tvm.cpu()) + this.tvm.empty(logits.shape, logits.dtype, this.tvm.cpu()), ); } else { if (logits.shape[0] != this.logitsOnCPU.shape[0]) { @@ -594,27 +626,61 @@ export class LLMChatPipeline { let response_format: ResponseFormat | undefined = undefined; if (genConfig !== undefined) { - if (_hasValue(genConfig.temperature)) { temperature = genConfig.temperature!; } - if (_hasValue(genConfig.top_p)) { top_p = genConfig.top_p!; } - if (_hasValue(genConfig.repetition_penalty)) { repetition_penalty = genConfig.repetition_penalty!; } - if (_hasValue(genConfig.frequency_penalty)) { frequency_penalty = genConfig.frequency_penalty!; } - if (_hasValue(genConfig.presence_penalty)) { presence_penalty = genConfig.presence_penalty!; } + if (_hasValue(genConfig.temperature)) { + temperature = genConfig.temperature!; + } + if (_hasValue(genConfig.top_p)) { + top_p = genConfig.top_p!; + } + if (_hasValue(genConfig.repetition_penalty)) { + repetition_penalty = genConfig.repetition_penalty!; + } + if (_hasValue(genConfig.frequency_penalty)) { + frequency_penalty = genConfig.frequency_penalty!; + } + if (_hasValue(genConfig.presence_penalty)) { + presence_penalty = genConfig.presence_penalty!; + } // If only one of frequency or presence penatly is set, make the other one 0.0 - if (_hasValue(frequency_penalty) && !_hasValue(presence_penalty)) { presence_penalty = 0.0; } - if (_hasValue(presence_penalty) && !_hasValue(frequency_penalty)) { frequency_penalty = 0.0; } - if (_hasValue(genConfig.logit_bias)) { logit_bias = genConfig.logit_bias!; } - if (_hasValue(genConfig.logprobs)) { logprobs = genConfig.logprobs!; } - if (_hasValue(genConfig.top_logprobs)) { top_logprobs = genConfig.top_logprobs!; } - if (_hasValue(genConfig.response_format)) { response_format = genConfig.response_format!; } + if (_hasValue(frequency_penalty) && !_hasValue(presence_penalty)) { + presence_penalty = 0.0; + } + if (_hasValue(presence_penalty) && !_hasValue(frequency_penalty)) { + frequency_penalty = 0.0; + } + if (_hasValue(genConfig.logit_bias)) { + logit_bias = genConfig.logit_bias!; + } + if (_hasValue(genConfig.logprobs)) { + logprobs = genConfig.logprobs!; + } + if (_hasValue(genConfig.top_logprobs)) { + top_logprobs = genConfig.top_logprobs!; + } + if (_hasValue(genConfig.response_format)) { + response_format = genConfig.response_format!; + } } // Check range validity - if (top_p <= 0 || top_p > 1) { throw new Error("Make sure 0 < `top_p` <= 1."); } - if (temperature < 0) { throw new Error("Make sure `temperature` >= 0."); } - if (repetition_penalty <= 0) { throw new Error("Make sure `repetition_penalty` > 0."); } - if (frequency_penalty && (frequency_penalty < -2.0 || frequency_penalty > 2.0)) { + if (top_p <= 0 || top_p > 1) { + throw new Error("Make sure 0 < `top_p` <= 1."); + } + if (temperature < 0) { + throw new Error("Make sure `temperature` >= 0."); + } + if (repetition_penalty <= 0) { + throw new Error("Make sure `repetition_penalty` > 0."); + } + if ( + frequency_penalty && + (frequency_penalty < -2.0 || frequency_penalty > 2.0) + ) { throw new Error("`frequency_penalty` should be between -2.0 and 2.0."); } - if (presence_penalty && (presence_penalty < -2.0 || presence_penalty > 2.0)) { + if ( + presence_penalty && + (presence_penalty < -2.0 || presence_penalty > 2.0) + ) { throw new Error("`presence_penalty` should be between -2.0 and 2.0."); } @@ -626,11 +692,19 @@ export class LLMChatPipeline { } // TODO(Charlie): Do we detach from current scope here for bitmask? const bitMaskOnCPU = this.grammarFactory.findNextTokenBitmask( - this.grammarStateMatcher) as unknown as tvmjs.NDArray; - const bitMaskOnGPU = this.tvm.empty([1, this.bitmaskSize], "int32", - this.device).copyFrom(bitMaskOnCPU); - const seqIdsArray = this.tvm.empty([1], "int32", this.device).copyFrom([0]); - this.fapplyBitmask(logitsOnGPU.view([1, this.vocabSize]), seqIdsArray, bitMaskOnGPU); + this.grammarStateMatcher, + ) as unknown as tvmjs.NDArray; + const bitMaskOnGPU = this.tvm + .empty([1, this.bitmaskSize], "int32", this.device) + .copyFrom(bitMaskOnCPU); + const seqIdsArray = this.tvm + .empty([1], "int32", this.device) + .copyFrom([0]); + this.fapplyBitmask( + logitsOnGPU.view([1, this.vocabSize]), + seqIdsArray, + bitMaskOnGPU, + ); this.tvm.endScope(); } @@ -646,7 +720,9 @@ export class LLMChatPipeline { // 2. Post process logits via logitProcessor and/or logit_bias if (this.logitProcessor !== undefined || _hasValue(logit_bias)) { - let logitsOnCPUArray: Float32Array = (this.logitsOnCPU.toArray()); + let logitsOnCPUArray: Float32Array = ( + this.logitsOnCPU.toArray() + ); const vocab_size = logitsOnCPUArray.length; if (this.logitProcessor !== undefined) { logitsOnCPUArray = this.logitProcessor.processLogits(logitsOnCPUArray); @@ -656,7 +732,12 @@ export class LLMChatPipeline { const curBias = logit_bias[tokenID]; const curTokenID = parseInt(tokenID); if (curTokenID > vocab_size) { - throw Error("Token " + curTokenID + " in logit_bias exceeds vocab_size " + vocab_size); + throw Error( + "Token " + + curTokenID + + " in logit_bias exceeds vocab_size " + + vocab_size, + ); } logitsOnCPUArray[curTokenID] += curBias; } @@ -672,9 +753,15 @@ export class LLMChatPipeline { const appearedTokens = [...this.appearedTokensFreq.keys()]; const appearedTokensFreqs = [...this.appearedTokensFreq.values()]; const appeared_tokens_ndarray = this.tvm.empty( - [1, appearedTokens.length], "int32", this.tvm.cpu()); + [1, appearedTokens.length], + "int32", + this.tvm.cpu(), + ); const appeared_tokens_freqs_ndarray = this.tvm.empty( - [1, appearedTokensFreqs.length], "int32", this.tvm.cpu()); + [1, appearedTokensFreqs.length], + "int32", + this.tvm.cpu(), + ); appeared_tokens_ndarray.copyFrom(appearedTokens); appeared_tokens_freqs_ndarray.copyFrom(appearedTokensFreqs); this.tvm.applyPresenceAndFrequencyPenalty( @@ -682,7 +769,7 @@ export class LLMChatPipeline { appeared_tokens_ndarray, appeared_tokens_freqs_ndarray, presence_penalty!, - frequency_penalty! + frequency_penalty!, ); this.tvm.endScope(); } else if (repetition_penalty != 1.0) { @@ -690,10 +777,16 @@ export class LLMChatPipeline { this.tvm.beginScope(); const appearedTokens = [...this.appearedTokensFreq.keys()]; const appeared_tokens_ndarray = this.tvm.empty( - [1, appearedTokens.length], "int32", this.tvm.cpu()); + [1, appearedTokens.length], + "int32", + this.tvm.cpu(), + ); appeared_tokens_ndarray.copyFrom(appearedTokens); this.tvm.applyRepetitionPenalty( - this.logitsOnCPU, appeared_tokens_ndarray, repetition_penalty); + this.logitsOnCPU, + appeared_tokens_ndarray, + repetition_penalty, + ); this.tvm.endScope(); } @@ -702,13 +795,19 @@ export class LLMChatPipeline { let sampledToken: number; if (logprobs) { // Inplace transform logitsOnCPU to a distribution - temperature = Math.max(1e-6, temperature); // to prevent division by zero + temperature = Math.max(1e-6, temperature); // to prevent division by zero this.tvm.applySoftmaxWithTemperature(this.logitsOnCPU, temperature); sampledToken = this.tvm.sampleTopPFromProb(this.logitsOnCPU, top_p); - this.tokenLogprobArray.push(this.getTokenLogprob(sampledToken, top_logprobs!)); + this.tokenLogprobArray.push( + this.getTokenLogprob(sampledToken, top_logprobs!), + ); } else { // temperature being 0 is allowed here, equivalent to argmax - sampledToken = this.tvm.sampleTopPFromLogits(this.logitsOnCPU, temperature, top_p); + sampledToken = this.tvm.sampleTopPFromLogits( + this.logitsOnCPU, + temperature, + top_p, + ); } // 5. Update logit processor @@ -720,7 +819,10 @@ export class LLMChatPipeline { if (this.grammarStateMatcher === undefined) { throw Error("Expect grammar state matcher to be initialized."); } - const accepted = this.grammarFactory.acceptToken(this.grammarStateMatcher, sampledToken); + const accepted = this.grammarFactory.acceptToken( + this.grammarStateMatcher, + sampledToken, + ); if (!accepted) { throw Error("Grammar state matcher rejected the newly sampled token."); } @@ -735,10 +837,16 @@ export class LLMChatPipeline { let mean_gen_len = this.config.mean_gen_len; let shift_fill_factor = this.config.shift_fill_factor; if (genConfig !== undefined) { - if (genConfig.mean_gen_len !== undefined && genConfig.mean_gen_len !== null) { + if ( + genConfig.mean_gen_len !== undefined && + genConfig.mean_gen_len !== null + ) { mean_gen_len = genConfig.mean_gen_len; } - if (genConfig.shift_fill_factor !== undefined && genConfig.shift_fill_factor !== null) { + if ( + genConfig.shift_fill_factor !== undefined && + genConfig.shift_fill_factor !== null + ) { shift_fill_factor = genConfig.shift_fill_factor; } } @@ -772,8 +880,11 @@ export class LLMChatPipeline { for (let i = prompts.length - 1; i > 0; --i) { const encoded = this.tokenizer.encode(prompts[i]); ctxLength += encoded.length; - if (this.slidingWindowSize == -1 && // There is no maxWindowLength if we use sliding window - this.filledKVCacheLength + ctxLength + mean_gen_len >= this.maxWindowLength) { + if ( + this.slidingWindowSize == -1 && // There is no maxWindowLength if we use sliding window + this.filledKVCacheLength + ctxLength + mean_gen_len >= + this.maxWindowLength + ) { needShiftWindow = true; break; } @@ -788,11 +899,13 @@ export class LLMChatPipeline { // Code starting below should not be reached when using sliding window. if (this.slidingWindowSize != -1) { - throw Error("Should not shift window when using sliding window attention."); + throw Error( + "Should not shift window when using sliding window attention.", + ); } // need shift window and re-encode - this.logger("need shift window") + this.logger("need shift window"); this.filledKVCacheLength = 0; this.resetKVCache(); @@ -811,7 +924,10 @@ export class LLMChatPipeline { for (let i = all_prompts.length - 1; i > 0; --i) { const encoded = this.tokenizer.encode(all_prompts[i]); ctxLength += encoded.length; - if (ctxLength >= shift_fill_factor * this.maxWindowLength && i + 2 < all_prompts.length) { + if ( + ctxLength >= shift_fill_factor * this.maxWindowLength && + i + 2 < all_prompts.length + ) { break; } context.unshift(encoded); @@ -825,7 +941,10 @@ export class LLMChatPipeline { return tokens; } - async forwardTokensAndSample(inputIds: Array, isPrefill: boolean): Promise { + async forwardTokensAndSample( + inputIds: Array, + isPrefill: boolean, + ): Promise { // 1. Convert input to NDArray const tstart = performance.now(); this.tvm.beginScope(); @@ -856,18 +975,21 @@ export class LLMChatPipeline { * Based on `sampledToken` and `this.logitsOnCPU`, which becomes a distribution after * calling `this.tvm.applySoftmaxWithTemperature()`, generate `ChatCompletionTokenLogprob` and * update `this.tokenLogprobArray`. - * + * * @param sampledToken The token ID sampled. * @param top_logprobs Number of top tokens to include; `top_logprobs` in `ChatCompletionRequest`. - * + * * @return The `ChatCompletionTokenLogprob` for this single autoregressive step. */ - private getTokenLogprob(sampledToken: number, top_logprobs: number): ChatCompletionTokenLogprob { + private getTokenLogprob( + sampledToken: number, + top_logprobs: number, + ): ChatCompletionTokenLogprob { if (this.logitsOnCPU == undefined) { throw Error("logits should be assigned"); } // Array of [token, prob] pairs, sorted with highest prob first. - const logitsOnCPUArray = (this.logitsOnCPU.toArray()) + const logitsOnCPUArray = this.logitsOnCPU.toArray(); const topLogprobs = getTopProbs(top_logprobs!, logitsOnCPUArray); // Get entry for sampled token first @@ -893,7 +1015,7 @@ export class LLMChatPipeline { token: tokenStr, bytes: bytes, logprob: logprob, - top_logprobs: topLogprobArray + top_logprobs: topLogprobArray, } as ChatCompletionTokenLogprob; } @@ -921,16 +1043,17 @@ export class LLMChatPipeline { const decodingStart = performance.now(); this.tvm.beginScope(); - const firstSampleToken = this.tvm.empty([1], "int32", this.device).copyFrom([6234]); + const firstSampleToken = this.tvm + .empty([1], "int32", this.device) + .copyFrom([6234]); const logitsOnCPU = this.updateLogitsOnCPU(this.forward(firstSampleToken)); await this.device.sync(); this.tvm.endScope(); const decodingEnd = performance.now(); - const msg = ( + const msg = `prefill-time=${((decodingStart - prefillStart) / 1000).toFixed(4)} sec` + - `decoding-time=${((decodingEnd - decodingStart) / 1000).toFixed(4)} sec` - ); + `decoding-time=${((decodingEnd - decodingStart) / 1000).toFixed(4)} sec`; // simply log tokens for eyeballing. console.log("Logits:"); diff --git a/src/openai_api_protocols/apis.ts b/src/openai_api_protocols/apis.ts index d4064c25..e1895000 100644 --- a/src/openai_api_protocols/apis.ts +++ b/src/openai_api_protocols/apis.ts @@ -2,11 +2,11 @@ import { MLCEngineInterface } from "../types"; import { Completions } from "./chat_completion"; export class Chat { - private engine: MLCEngineInterface; - completions: Completions; + private engine: MLCEngineInterface; + completions: Completions; - constructor(engine: MLCEngineInterface) { - this.engine = engine; - this.completions = new Completions(this.engine); - } + constructor(engine: MLCEngineInterface) { + this.engine = engine; + this.completions = new Completions(this.engine); + } } diff --git a/src/openai_api_protocols/chat_completion.ts b/src/openai_api_protocols/chat_completion.ts index 6716436a..6a3313be 100644 --- a/src/openai_api_protocols/chat_completion.ts +++ b/src/openai_api_protocols/chat_completion.ts @@ -1,9 +1,9 @@ /** * The input to OpenAI API, directly adopted from openai-node with small tweaks: * https://github.com/openai/openai-node/blob/master/src/resources/chat/completions.ts - * + * * Copyright 2024 OpenAI - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at @@ -13,261 +13,263 @@ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. -*/ + */ import { MLCEngineInterface } from "../types"; /* eslint-disable @typescript-eslint/no-namespace */ export class Completions { - private engine: MLCEngineInterface; - - constructor(engine: MLCEngineInterface) { - this.engine = engine; - } - - create( - request: ChatCompletionRequestNonStreaming - ): Promise; - create( - request: ChatCompletionRequestStreaming - ): Promise>; - create( - request: ChatCompletionRequestBase - ): Promise | ChatCompletion>; - create( - request: ChatCompletionRequest - ): Promise | ChatCompletion> { - return this.engine.chatCompletion(request); - } + private engine: MLCEngineInterface; + + constructor(engine: MLCEngineInterface) { + this.engine = engine; + } + + create(request: ChatCompletionRequestNonStreaming): Promise; + create( + request: ChatCompletionRequestStreaming, + ): Promise>; + create( + request: ChatCompletionRequestBase, + ): Promise | ChatCompletion>; + create( + request: ChatCompletionRequest, + ): Promise | ChatCompletion> { + return this.engine.chatCompletion(request); + } } //////////////////////////////// 0. HIGH-LEVEL INTERFACES //////////////////////////////// /** * OpenAI chat completion request protocol. - * + * * API reference: https://platform.openai.com/docs/api-reference/chat/create * Followed: https://github.com/openai/openai-node/blob/master/src/resources/chat/completions.ts - * + * * @note `model` is excluded. call `ChatModule.reload(model)` explicitly before calling this API. */ export interface ChatCompletionRequestBase { - /** - * A list of messages comprising the conversation so far. - */ - messages: Array; - - /** - * If set, partial message deltas will be sent. - */ - stream?: boolean | null; - - /** - * How many chat completion choices to generate for each input message. - */ - n?: number | null; - - /** - * Number between -2.0 and 2.0. Positive values penalize new tokens based on their - * existing frequency in the text so far, decreasing the model's likelihood to - * repeat the same line verbatim. - * - * [See more information about frequency and presence penalties.](https://platform.openai.com/docs/guides/text-generation/parameter-details) - */ - frequency_penalty?: number | null; - - /** - * Number between -2.0 and 2.0. Positive values penalize new tokens based on - * whether they appear in the text so far, increasing the model's likelihood to - * talk about new topics. - * - * [See more information about frequency and presence penalties.](https://platform.openai.com/docs/guides/text-generation/parameter-details) - */ - presence_penalty?: number | null; - - /** - * The maximum number of [tokens](/tokenizer) that can be generated in the chat - * completion. - * - * The total length of input tokens and generated tokens is limited by the model's - * context length, **determined during MLC's compilation phase**. - */ - max_gen_len?: number | null; - - /** - * Sequences where the API will stop generating further tokens. - */ - stop?: string | null | Array; - - /** - * What sampling temperature to use, between 0 and 2. Higher values like 0.8 will - * make the output more random, while lower values like 0.2 will make it more - * focused and deterministic. - */ - temperature?: number | null; - - /** - * An alternative to sampling with temperature, called nucleus sampling, where the - * model considers the results of the tokens with top_p probability mass. So 0.1 - * means only the tokens comprising the top 10% probability mass are considered. - */ - top_p?: number | null; - - /** - * Modify the likelihood of specified tokens appearing in the completion. - * - * Accepts a JSON object that maps tokens (specified by their token ID, which varies per model) - * to an associated bias value from -100 to 100. Typically, you can see `tokenizer.json` of the - * model to see which token ID maps to what string. Mathematically, the bias is added to the - * logits generated by the model prior to sampling. The exact effect will vary per model, but - * values between -1 and 1 should decrease or increase likelihood of selection; values like -100 - * or 100 should result in a ban or exclusive selection of the relevant token. - * - * As an example, you can pass `{"16230": -100}` to prevent the `Hello` token from being - * generated in Mistral-7B-Instruct-v0.2, according to the mapping in - * https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2/raw/main/tokenizer.json. - * - * @note For stateful and customizable / flexible logit processing, see `webllm.LogitProcessor`. - * @note If used in combination with `webllm.LogitProcessor`, `logit_bias` is applied after - * `LogitProcessor.processLogits()` is called. - */ - logit_bias?: Record | null; - - /** - * Whether to return log probabilities of the output tokens or not. - * - * If true, returns the log probabilities of each output token returned in the `content` of - * `message`. - */ - logprobs?: boolean | null; - - /** - * An integer between 0 and 5 specifying the number of most likely tokens to return - * at each token position, each with an associated log probability. `logprobs` must - * be set to `true` if this parameter is used. - */ - top_logprobs?: number | null; - - /** - * If specified, our system will make a best effort to sample deterministically, such that - * repeated requests with the same `seed` and parameters should return the same result. - * - * @note Seeding is done on a request-level rather than choice-level. That is, if `n > 1`, you - * would still get different content for each `Chocie`. But if two requests with `n = 2` are - * processed with the same seed, the two results should be the same (two choices are different). - */ - seed?: number | null; - - /** - * Controls which (if any) function is called by the model. `none` means the model - * will not call a function and instead generates a message. `auto` means the model - * can pick between generating a message or calling a function. Specifying a - * particular function via - * `{"type": "function", "function": {"name": "my_function"}}` forces the model to - * call that function. - * - * `none` is the default when no functions are present. `auto` is the default if - * functions are present. - */ - tool_choice?: ChatCompletionToolChoiceOption; - - /** - * A list of tools the model may call. Currently, only functions are supported as a - * tool. Use this to provide a list of functions the model may generate JSON inputs - * for. - */ - tools?: Array; - - /** - * An object specifying the format that the model must output. - * - * Setting to `{ "type": "json_object" }` enables JSON mode, which guarantees the - * message the model generates is valid JSON. - * - * **Important:** when using JSON mode, you **must** also instruct the model to - * produce JSON yourself via a system or user message. Without this, the model may - * generate an unending stream of whitespace until the generation reaches the token - * limit, resulting in a long-running and seemingly "stuck" request. Also note that - * the message content may be partially cut off if `finish_reason="length"`, which - * indicates the generation exceeded `max_gen_len` or the conversation exceeded the - * max context length. - */ - response_format?: ResponseFormat; - - //////////////// BELOW FIELDS NOT SUPPORTED YET //////////////// - - /** - * Model to carry out this API. - * - * @note Not supported. Instead call `CreateMLCEngine(model)` or `engine.reload(model)` instead. - */ - model?: string | null; + /** + * A list of messages comprising the conversation so far. + */ + messages: Array; + + /** + * If set, partial message deltas will be sent. + */ + stream?: boolean | null; + + /** + * How many chat completion choices to generate for each input message. + */ + n?: number | null; + + /** + * Number between -2.0 and 2.0. Positive values penalize new tokens based on their + * existing frequency in the text so far, decreasing the model's likelihood to + * repeat the same line verbatim. + * + * [See more information about frequency and presence penalties.](https://platform.openai.com/docs/guides/text-generation/parameter-details) + */ + frequency_penalty?: number | null; + + /** + * Number between -2.0 and 2.0. Positive values penalize new tokens based on + * whether they appear in the text so far, increasing the model's likelihood to + * talk about new topics. + * + * [See more information about frequency and presence penalties.](https://platform.openai.com/docs/guides/text-generation/parameter-details) + */ + presence_penalty?: number | null; + + /** + * The maximum number of [tokens](/tokenizer) that can be generated in the chat + * completion. + * + * The total length of input tokens and generated tokens is limited by the model's + * context length, **determined during MLC's compilation phase**. + */ + max_gen_len?: number | null; + + /** + * Sequences where the API will stop generating further tokens. + */ + stop?: string | null | Array; + + /** + * What sampling temperature to use, between 0 and 2. Higher values like 0.8 will + * make the output more random, while lower values like 0.2 will make it more + * focused and deterministic. + */ + temperature?: number | null; + + /** + * An alternative to sampling with temperature, called nucleus sampling, where the + * model considers the results of the tokens with top_p probability mass. So 0.1 + * means only the tokens comprising the top 10% probability mass are considered. + */ + top_p?: number | null; + + /** + * Modify the likelihood of specified tokens appearing in the completion. + * + * Accepts a JSON object that maps tokens (specified by their token ID, which varies per model) + * to an associated bias value from -100 to 100. Typically, you can see `tokenizer.json` of the + * model to see which token ID maps to what string. Mathematically, the bias is added to the + * logits generated by the model prior to sampling. The exact effect will vary per model, but + * values between -1 and 1 should decrease or increase likelihood of selection; values like -100 + * or 100 should result in a ban or exclusive selection of the relevant token. + * + * As an example, you can pass `{"16230": -100}` to prevent the `Hello` token from being + * generated in Mistral-7B-Instruct-v0.2, according to the mapping in + * https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2/raw/main/tokenizer.json. + * + * @note For stateful and customizable / flexible logit processing, see `webllm.LogitProcessor`. + * @note If used in combination with `webllm.LogitProcessor`, `logit_bias` is applied after + * `LogitProcessor.processLogits()` is called. + */ + logit_bias?: Record | null; + + /** + * Whether to return log probabilities of the output tokens or not. + * + * If true, returns the log probabilities of each output token returned in the `content` of + * `message`. + */ + logprobs?: boolean | null; + + /** + * An integer between 0 and 5 specifying the number of most likely tokens to return + * at each token position, each with an associated log probability. `logprobs` must + * be set to `true` if this parameter is used. + */ + top_logprobs?: number | null; + + /** + * If specified, our system will make a best effort to sample deterministically, such that + * repeated requests with the same `seed` and parameters should return the same result. + * + * @note Seeding is done on a request-level rather than choice-level. That is, if `n > 1`, you + * would still get different content for each `Chocie`. But if two requests with `n = 2` are + * processed with the same seed, the two results should be the same (two choices are different). + */ + seed?: number | null; + + /** + * Controls which (if any) function is called by the model. `none` means the model + * will not call a function and instead generates a message. `auto` means the model + * can pick between generating a message or calling a function. Specifying a + * particular function via + * `{"type": "function", "function": {"name": "my_function"}}` forces the model to + * call that function. + * + * `none` is the default when no functions are present. `auto` is the default if + * functions are present. + */ + tool_choice?: ChatCompletionToolChoiceOption; + + /** + * A list of tools the model may call. Currently, only functions are supported as a + * tool. Use this to provide a list of functions the model may generate JSON inputs + * for. + */ + tools?: Array; + + /** + * An object specifying the format that the model must output. + * + * Setting to `{ "type": "json_object" }` enables JSON mode, which guarantees the + * message the model generates is valid JSON. + * + * **Important:** when using JSON mode, you **must** also instruct the model to + * produce JSON yourself via a system or user message. Without this, the model may + * generate an unending stream of whitespace until the generation reaches the token + * limit, resulting in a long-running and seemingly "stuck" request. Also note that + * the message content may be partially cut off if `finish_reason="length"`, which + * indicates the generation exceeded `max_gen_len` or the conversation exceeded the + * max context length. + */ + response_format?: ResponseFormat; + + //////////////// BELOW FIELDS NOT SUPPORTED YET //////////////// + + /** + * Model to carry out this API. + * + * @note Not supported. Instead call `CreateMLCEngine(model)` or `engine.reload(model)` instead. + */ + model?: string | null; } -export interface ChatCompletionRequestNonStreaming extends ChatCompletionRequestBase { - /** - * If set, partial message deltas will be sent. - */ - stream?: false | null; +export interface ChatCompletionRequestNonStreaming + extends ChatCompletionRequestBase { + /** + * If set, partial message deltas will be sent. + */ + stream?: false | null; } -export interface ChatCompletionRequestStreaming extends ChatCompletionRequestBase { - /** - * If set, partial message deltas will be sent. - */ - stream: true; +export interface ChatCompletionRequestStreaming + extends ChatCompletionRequestBase { + /** + * If set, partial message deltas will be sent. + */ + stream: true; } -export type ChatCompletionRequest = ChatCompletionRequestNonStreaming | ChatCompletionRequestStreaming; +export type ChatCompletionRequest = + | ChatCompletionRequestNonStreaming + | ChatCompletionRequestStreaming; /** * Represents a chat completion response returned by model, based on the provided input. */ export interface ChatCompletion { - /** - * A unique identifier for the chat completion. - */ - id: string; - - /** - * A list of chat completion choices. Can be more than one if `n` is greater than 1. - */ - choices: Array; - - /** - * The model used for the chat completion. - */ - model: string; - - /** - * The object type, which is always `chat.completion`. - */ - object: 'chat.completion'; - - /** - * The Unix timestamp (in seconds) of when the chat completion was created. - * - */ - created: number; - - /** - * Usage statistics for the completion request. - * - * @note If we detect user is performing multi-round chatting, only the new portion of the - * prompt is counted for prompt_tokens. If `n > 1`, all choices' generation usages combined. - */ - usage?: CompletionUsage; - - /** - * This fingerprint represents the backend configuration that the model runs with. - * - * Can be used in conjunction with the `seed` request parameter to understand when - * backend changes have been made that might impact determinism. - * - * @note Not supported yet. - */ - system_fingerprint?: string; + /** + * A unique identifier for the chat completion. + */ + id: string; + + /** + * A list of chat completion choices. Can be more than one if `n` is greater than 1. + */ + choices: Array; + + /** + * The model used for the chat completion. + */ + model: string; + + /** + * The object type, which is always `chat.completion`. + */ + object: "chat.completion"; + + /** + * The Unix timestamp (in seconds) of when the chat completion was created. + * + */ + created: number; + + /** + * Usage statistics for the completion request. + * + * @note If we detect user is performing multi-round chatting, only the new portion of the + * prompt is counted for prompt_tokens. If `n > 1`, all choices' generation usages combined. + */ + usage?: CompletionUsage; + + /** + * This fingerprint represents the backend configuration that the model runs with. + * + * Can be used in conjunction with the `seed` request parameter to understand when + * backend changes have been made that might impact determinism. + * + * @note Not supported yet. + */ + system_fingerprint?: string; } /** @@ -275,101 +277,109 @@ export interface ChatCompletion { * based on the provided input. */ export interface ChatCompletionChunk { - /** - * A unique identifier for the chat completion. Each chunk has the same ID. - */ - id: string; - - /** - * A list of chat completion choices. Can be more than one if `n` is greater - * than 1. - */ - choices: Array; - - /** - * The Unix timestamp (in seconds) of when the chat completion was created. Each - * chunk has the same timestamp. - */ - created: number; - - /** - * The model to generate the completion. - */ - model: string; - - /** - * The object type, which is always `chat.completion.chunk`. - */ - object: 'chat.completion.chunk'; - - /** - * This fingerprint represents the backend configuration that the model runs with. - * Can be used in conjunction with the `seed` request parameter to understand when - * backend changes have been made that might impact determinism. - * - * @note Not supported yet. - */ - system_fingerprint?: string; + /** + * A unique identifier for the chat completion. Each chunk has the same ID. + */ + id: string; + + /** + * A list of chat completion choices. Can be more than one if `n` is greater + * than 1. + */ + choices: Array; + + /** + * The Unix timestamp (in seconds) of when the chat completion was created. Each + * chunk has the same timestamp. + */ + created: number; + + /** + * The model to generate the completion. + */ + model: string; + + /** + * The object type, which is always `chat.completion.chunk`. + */ + object: "chat.completion.chunk"; + + /** + * This fingerprint represents the backend configuration that the model runs with. + * Can be used in conjunction with the `seed` request parameter to understand when + * backend changes have been made that might impact determinism. + * + * @note Not supported yet. + */ + system_fingerprint?: string; } -export const ChatCompletionRequestUnsupportedFields: Array = [ - "model", -]; +export const ChatCompletionRequestUnsupportedFields: Array = ["model"]; export function postInitAndCheckFields(request: ChatCompletionRequest): void { - // Generation-related checks and post inits are in `postInitAndCheckGenerationConfigValues()` - // 1. Check unsupported fields in request - const unsupported: Array = []; - ChatCompletionRequestUnsupportedFields.forEach((field) => { - if (field in request) { - unsupported.push(field); - } - }); - if (unsupported.length > 0) { + // Generation-related checks and post inits are in `postInitAndCheckGenerationConfigValues()` + // 1. Check unsupported fields in request + const unsupported: Array = []; + ChatCompletionRequestUnsupportedFields.forEach((field) => { + if (field in request) { + unsupported.push(field); + } + }); + if (unsupported.length > 0) { + throw new Error( + "The following fields in ChatCompletionRequest are not yet supported: \n" + + unsupported, + ); + } + + // 2. Check unsupported messages + request.messages.forEach( + (message: ChatCompletionMessageParam, index: number) => { + if (message.role === "user" && typeof message.content !== "string") { + // ChatCompletionUserMessageParam + // Remove this when we support image input throw new Error( - "The following fields in ChatCompletionRequest are not yet supported: \n" + unsupported + "User message only supports string `content` for now, but received: " + + message.content, ); + } + if (message.role === "system" && index !== 0) { + throw new Error( + "System prompt should always be the first one in `messages`.", + ); + } + }, + ); + + // 3. Last message has to be from user + const lastId = request.messages.length - 1; + if (request.messages[lastId].role !== "user") { + throw new Error("Last message should be from `user`."); + } + + // 4. If streaming, n cannot be > 1, since we cannot manage multiple sequences at once + if (request.stream && request.n && request.n > 1) { + throw new Error("When streaming, `n` cannot be > 1."); + } + + // 5. Seed should be an integer + if (request.seed !== undefined && request.seed !== null) { + if (!Number.isInteger(request.seed)) { + throw new Error("`seed` should be an integer, but got " + request.seed); } - - // 2. Check unsupported messages - request.messages.forEach((message: ChatCompletionMessageParam, index: number) => { - if (message.role === "user" && typeof message.content !== "string") { - // ChatCompletionUserMessageParam - // Remove this when we support image input - throw new Error( - "User message only supports string `content` for now, but received: " + - message.content - ); - } - if (message.role === "system" && index !== 0) { - throw new Error("System prompt should always be the first one in `messages`."); - } - }) - - // 3. Last message has to be from user - const lastId = request.messages.length - 1; - if (request.messages[lastId].role !== "user") { - throw new Error("Last message should be from `user`."); - } - - // 4. If streaming, n cannot be > 1, since we cannot manage multiple sequences at once - if (request.stream && request.n && request.n > 1) { - throw new Error("When streaming, `n` cannot be > 1."); - } - - // 5. Seed should be an integer - if (request.seed !== undefined && request.seed !== null) { - if (!Number.isInteger(request.seed)) { - throw new Error("`seed` should be an integer, but got " + request.seed); - } - } - - // 6. Schema can only be specified when type is `json_object`. - if (request.response_format?.schema !== undefined && request.response_format?.schema !== null) { - if (request.response_format?.type !== "json_object") { - throw new Error("JSON schema is only supported with `json_object` response format."); - } + } + + // 6. Schema can only be specified when type is `json_object`. + if ( + request.response_format?.schema !== undefined && + request.response_format?.schema !== null + ) { + if (request.response_format?.type !== "json_object") { + throw new Error( + "JSON schema is only supported with `json_object` response format.", + ); } + } } //////////////// BELOW ARE INTERFACES THAT SUPPORT THE ONES ABOVE //////////////// @@ -378,80 +388,81 @@ export function postInitAndCheckFields(request: ChatCompletionRequest): void { //////////////////////////////// 1.1. CHAT COMPLETION CONTENT //////////////////////////////// -export type ChatCompletionContentPart = ChatCompletionContentPartText | ChatCompletionContentPartImage; +export type ChatCompletionContentPart = + | ChatCompletionContentPartText + | ChatCompletionContentPartImage; export interface ChatCompletionContentPartText { + /** + * The text content. + */ + text: string; + + /** + * The type of the content part. + */ + type: "text"; +} + +export namespace ChatCompletionContentPartImage { + export interface ImageURL { /** - * The text content. + * Either a URL of the image or the base64 encoded image data. */ - text: string; + url: string; /** - * The type of the content part. + * Specifies the detail level of the image. */ - type: 'text'; -} - -export namespace ChatCompletionContentPartImage { - export interface ImageURL { - /** - * Either a URL of the image or the base64 encoded image data. - */ - url: string; - - /** - * Specifies the detail level of the image. - */ - detail?: 'auto' | 'low' | 'high'; - } + detail?: "auto" | "low" | "high"; + } } export interface ChatCompletionContentPartImage { - - image_url: ChatCompletionContentPartImage.ImageURL; - /** - * The type of the content part. - */ - type: 'image_url'; + image_url: ChatCompletionContentPartImage.ImageURL; + /** + * The type of the content part. + */ + type: "image_url"; } //////////////////////////////// 1.2. MESSAGE TOOL CALL //////////////////////////////// export interface ChatCompletionMessageToolCall { - /** - * The ID of the tool call. - */ - id: string; - - /** - * The function that the model called. - */ - function: ChatCompletionMessageToolCall.Function; - - /** - * The type of the tool. Currently, only `function` is supported. - */ - type: 'function'; + /** + * The ID of the tool call. + */ + id: string; + + /** + * The function that the model called. + */ + function: ChatCompletionMessageToolCall.Function; + + /** + * The type of the tool. Currently, only `function` is supported. + */ + type: "function"; } export namespace ChatCompletionMessageToolCall { + /** + * The function that the model called. + */ + export interface Function { /** - * The function that the model called. + * The arguments to call the function with, as generated by the model in JSON + * format. Note that the model does not always generate valid JSON, and may + * hallucinate parameters not defined by your function schema. Validate the + * arguments in your code before calling your function. */ - export interface Function { - /** - * The arguments to call the function with, as generated by the model in JSON - * format. Note that the model does not always generate valid JSON, and may - * hallucinate parameters not defined by your function schema. Validate the - * arguments in your code before calling your function. - */ - arguments: string; + arguments: string; - /** - * The name of the function to call. - */ - name: string; - } + /** + * The name of the function to call. + */ + name: string; + } } //////////////////////////////// 1.3. MESSAGE PARAM //////////////////////////////// @@ -459,92 +470,96 @@ export namespace ChatCompletionMessageToolCall { /** * The role of the author of a message */ -export type ChatCompletionRole = 'system' | 'user' | 'assistant' | 'tool' | 'function'; +export type ChatCompletionRole = + | "system" + | "user" + | "assistant" + | "tool" + | "function"; export interface ChatCompletionSystemMessageParam { - /** - * The contents of the system message. - */ - content: string; - - /** - * The role of the messages author, in this case `system`. - */ - role: 'system'; + /** + * The contents of the system message. + */ + content: string; + + /** + * The role of the messages author, in this case `system`. + */ + role: "system"; } export interface ChatCompletionUserMessageParam { - /** - * The contents of the user message. - */ - content: string | Array; - - /** - * The role of the messages author, in this case `user`. - */ - role: 'user'; - - /** - * An optional name for the participant. Provides the model information to - * differentiate between participants of the same role. - * - * @note This is experimental, as models typically have predefined names for the user. - */ - name?: string; + /** + * The contents of the user message. + */ + content: string | Array; + + /** + * The role of the messages author, in this case `user`. + */ + role: "user"; + + /** + * An optional name for the participant. Provides the model information to + * differentiate between participants of the same role. + * + * @note This is experimental, as models typically have predefined names for the user. + */ + name?: string; } export interface ChatCompletionAssistantMessageParam { - /** - * The role of the messages author, in this case `assistant`. - */ - role: 'assistant'; - - /** - * The contents of the assistant message. Required unless `tool_calls` is specified. - */ - content?: string | null; - - /** - * An optional name for the participant. Provides the model information to - * differentiate between participants of the same role. - * - * @note This is experimental, as models typically have predefined names for the user. - */ - name?: string; - - /** - * The tool calls generated by the model, such as function calls. - * Note that in Web-LLM's implementation, this field will never be used. - * Instead, function calls generated by the model will be returned as - * raw text in the content field. The user is responsible for parsing the - * function call raw text. - */ - tool_calls?: Array; + /** + * The role of the messages author, in this case `assistant`. + */ + role: "assistant"; + + /** + * The contents of the assistant message. Required unless `tool_calls` is specified. + */ + content?: string | null; + + /** + * An optional name for the participant. Provides the model information to + * differentiate between participants of the same role. + * + * @note This is experimental, as models typically have predefined names for the user. + */ + name?: string; + + /** + * The tool calls generated by the model, such as function calls. + * Note that in Web-LLM's implementation, this field will never be used. + * Instead, function calls generated by the model will be returned as + * raw text in the content field. The user is responsible for parsing the + * function call raw text. + */ + tool_calls?: Array; } export interface ChatCompletionToolMessageParam { - /** - * The contents of the tool message. - */ - content: string; - - /** - * The role of the messages author, in this case `tool`. - */ - role: 'tool'; - - /** - * Tool call that this message is responding to. - */ - tool_call_id: string; + /** + * The contents of the tool message. + */ + content: string; + + /** + * The role of the messages author, in this case `tool`. + */ + role: "tool"; + + /** + * Tool call that this message is responding to. + */ + tool_call_id: string; } export type ChatCompletionMessageParam = - | ChatCompletionSystemMessageParam - | ChatCompletionUserMessageParam - | ChatCompletionAssistantMessageParam - | ChatCompletionToolMessageParam; - + | ChatCompletionSystemMessageParam + | ChatCompletionUserMessageParam + | ChatCompletionAssistantMessageParam + | ChatCompletionToolMessageParam; //////////////////////////////// 2. TOOL USING //////////////////////////////// @@ -560,59 +575,59 @@ export type ChatCompletionMessageParam = export type FunctionParameters = Record; export interface FunctionDefinition { - /** - * The name of the function to be called. Must be a-z, A-Z, 0-9, or contain - * underscores and dashes, with a maximum length of 64. - */ - name: string; - - /** - * A description of what the function does, used by the model to choose when and - * how to call the function. - */ - description?: string; - - /** - * The parameters the functions accepts, described as a JSON Schema object. See the - * [guide](https://platform.openai.com/docs/guides/text-generation/function-calling) - * for examples, and the - * [JSON Schema reference](https://json-schema.org/understanding-json-schema/) for - * documentation about the format. - * - * Omitting `parameters` defines a function with an empty parameter list. - */ - parameters?: FunctionParameters; + /** + * The name of the function to be called. Must be a-z, A-Z, 0-9, or contain + * underscores and dashes, with a maximum length of 64. + */ + name: string; + + /** + * A description of what the function does, used by the model to choose when and + * how to call the function. + */ + description?: string; + + /** + * The parameters the functions accepts, described as a JSON Schema object. See the + * [guide](https://platform.openai.com/docs/guides/text-generation/function-calling) + * for examples, and the + * [JSON Schema reference](https://json-schema.org/understanding-json-schema/) for + * documentation about the format. + * + * Omitting `parameters` defines a function with an empty parameter list. + */ + parameters?: FunctionParameters; } export interface ChatCompletionTool { - function: FunctionDefinition; + function: FunctionDefinition; - /** - * The type of the tool. Currently, only `function` is supported. - */ - type: 'function'; + /** + * The type of the tool. Currently, only `function` is supported. + */ + type: "function"; } /** -* Specifies a tool the model should use. Use to force the model to call a specific -* function. -*/ + * Specifies a tool the model should use. Use to force the model to call a specific + * function. + */ export interface ChatCompletionNamedToolChoice { - function: ChatCompletionNamedToolChoice.Function; + function: ChatCompletionNamedToolChoice.Function; - /** - * The type of the tool. Currently, only `function` is supported. - */ - type: 'function'; + /** + * The type of the tool. Currently, only `function` is supported. + */ + type: "function"; } export namespace ChatCompletionNamedToolChoice { - export interface Function { - /** - * The name of the function to call. - */ - name: string; - } + export interface Function { + /** + * The name of the function to call. + */ + name: string; + } } /** @@ -626,248 +641,255 @@ export namespace ChatCompletionNamedToolChoice { * `none` is the default when no functions are present. `auto` is the default if * functions are present. */ -export type ChatCompletionToolChoiceOption = 'none' | 'auto' | ChatCompletionNamedToolChoice; +export type ChatCompletionToolChoiceOption = + | "none" + | "auto" + | ChatCompletionNamedToolChoice; //////////////////////////////// 3. OTHERS //////////////////////////////// //////////////////////////////// 3.1. LOG PROBS //////////////////////////////// export interface TopLogprob { - /** - * The token. - */ - token: string; + /** + * The token. + */ + token: string; + + /** + * A list of integers representing the UTF-8 bytes representation of the token. + * Useful in instances where characters are represented by multiple tokens and + * their byte representations must be combined to generate the correct text + * representation. Can be `null` if there is no bytes representation for the token. + * + * @note Encoded with `TextEncoder.encode()` and can be decoded with `TextDecoder.decode()`. + * For details, see https://developer.mozilla.org/en-US/docs/Web/API/TextEncoder/encode. + */ + bytes: Array | null; + + /** + * The log probability of this token. + */ + logprob: number; +} - /** - * A list of integers representing the UTF-8 bytes representation of the token. - * Useful in instances where characters are represented by multiple tokens and - * their byte representations must be combined to generate the correct text - * representation. Can be `null` if there is no bytes representation for the token. - * - * @note Encoded with `TextEncoder.encode()` and can be decoded with `TextDecoder.decode()`. - * For details, see https://developer.mozilla.org/en-US/docs/Web/API/TextEncoder/encode. - */ - bytes: Array | null; +export interface ChatCompletionTokenLogprob { + /** + * The token. + */ + token: string; + + /** + * A list of integers representing the UTF-8 bytes representation of the token. + * Useful in instances where characters are represented by multiple tokens and + * their byte representations must be combined to generate the correct text + * representation. Can be `null` if there is no bytes representation for the token. + * + * @note Encoded with `TextEncoder.encode()` and can be decoded with `TextDecoder.decode()`. + * For details, see https://developer.mozilla.org/en-US/docs/Web/API/TextEncoder/encode. + */ + bytes: Array | null; + + /** + * The log probability of this token. + */ + logprob: number; + + /** + * List of the most likely tokens and their log probability, at this token + * position. In rare cases, there may be fewer than the number of requested + * `top_logprobs` returned. + */ + top_logprobs: Array; +} - /** - * The log probability of this token. - */ - logprob: number; +//////////////////////////////// 3.2. OTHERS //////////////////////////////// +/** + * A chat completion message generated by the model. + */ +export interface ChatCompletionMessage { + /** + * The contents of the message. + */ + content: string | null; + + /** + * The role of the author of this message. + */ + role: "assistant"; + + /** + * The tool calls generated by the model, such as function calls. + * Note that in Web-LLM's implementation, this field will never be used. + * Instead, function calls generated by the model will be returned as + * raw text in the content field. The user is responsible for parsing the + * function call raw text. + */ + tool_calls?: Array; } -export interface ChatCompletionTokenLogprob { +/** + * Usage statistics for the completion request. + */ +export interface CompletionUsage { + /** + * Number of tokens in the generated completion. + */ + completion_tokens: number; + + /** + * Number of tokens in the prompt. + * + * @note If we detect user is performing multi-round chatting, only the new portion of the + * prompt is counted for prompt_tokens. + */ + prompt_tokens: number; + + /** + * Total number of tokens used in the request (prompt + completion). + */ + total_tokens: number; +} + +/** + * The reason the model stopped generating tokens. This will be `stop` if the model + * hit a natural stop point or a provided stop sequence, `length` if the maximum + * number of tokens specified in the request was reached, `tool_calls` if the + * model called a tool, or `abort` if user manually stops the generation. + */ +export type ChatCompletionFinishReason = + | "stop" + | "length" + | "tool_calls" + | "abort"; + +export namespace ChatCompletion { + export interface Choice { /** - * The token. + * The reason the model stopped generating tokens. This will be `stop` if the model + * hit a natural stop point or a provided stop sequence, `length` if the maximum + * number of tokens specified in the request was reached, `tool_calls` if the + * model called a tool, or `abort` if user manually stops the generation. */ - token: string; + finish_reason: ChatCompletionFinishReason; /** - * A list of integers representing the UTF-8 bytes representation of the token. - * Useful in instances where characters are represented by multiple tokens and - * their byte representations must be combined to generate the correct text - * representation. Can be `null` if there is no bytes representation for the token. - * - * @note Encoded with `TextEncoder.encode()` and can be decoded with `TextDecoder.decode()`. - * For details, see https://developer.mozilla.org/en-US/docs/Web/API/TextEncoder/encode. + * The index of the choice in the list of choices. */ - bytes: Array | null; + index: number; /** - * The log probability of this token. + * Log probability information for the choice. */ - logprob: number; + logprobs: Choice.Logprobs | null; /** - * List of the most likely tokens and their log probability, at this token - * position. In rare cases, there may be fewer than the number of requested - * `top_logprobs` returned. + * A chat completion message generated by the model. */ - top_logprobs: Array; -} + message: ChatCompletionMessage; + } -//////////////////////////////// 3.2. OTHERS //////////////////////////////// -/** - * A chat completion message generated by the model. - */ -export interface ChatCompletionMessage { + export namespace Choice { /** - * The contents of the message. + * Log probability information for the choice. */ - content: string | null; + export interface Logprobs { + /** + * A list of message content tokens with log probability information. + */ + content: Array | null; + } + } +} +export namespace ChatCompletionChunk { + export interface Choice { /** - * The role of the author of this message. + * A chat completion delta generated by streamed model responses. */ - role: 'assistant'; + delta: Choice.Delta; /** - * The tool calls generated by the model, such as function calls. - * Note that in Web-LLM's implementation, this field will never be used. - * Instead, function calls generated by the model will be returned as - * raw text in the content field. The user is responsible for parsing the - * function call raw text. + * The reason the model stopped generating tokens. This will be `stop` if the model + * hit a natural stop point or a provided stop sequence, `length` if the maximum + * number of tokens specified in the request was reached, `tool_calls` if the + * model called a tool, or `abort` if user manually stops the generation. */ - tool_calls?: Array; -} + finish_reason: ChatCompletionFinishReason | null; -/** - * Usage statistics for the completion request. - */ -export interface CompletionUsage { /** - * Number of tokens in the generated completion. + * The index of the choice in the list of choices. */ - completion_tokens: number; + index: number; /** - * Number of tokens in the prompt. - * - * @note If we detect user is performing multi-round chatting, only the new portion of the - * prompt is counted for prompt_tokens. + * Log probability information for the choice. */ - prompt_tokens: number; + logprobs?: Choice.Logprobs | null; + } + export namespace Choice { /** - * Total number of tokens used in the request (prompt + completion). + * A chat completion delta generated by streamed model responses. */ - total_tokens: number; -} + export interface Delta { + /** + * The contents of the chunk message. + */ + content?: string | null; -/** - * The reason the model stopped generating tokens. This will be `stop` if the model - * hit a natural stop point or a provided stop sequence, `length` if the maximum - * number of tokens specified in the request was reached, `tool_calls` if the - * model called a tool, or `abort` if user manually stops the generation. - */ -export type ChatCompletionFinishReason = 'stop' | 'length' | 'tool_calls' | 'abort'; + /** + * The role of the author of this message. + */ + role?: "system" | "user" | "assistant" | "tool"; -export namespace ChatCompletion { - export interface Choice { - /** - * The reason the model stopped generating tokens. This will be `stop` if the model - * hit a natural stop point or a provided stop sequence, `length` if the maximum - * number of tokens specified in the request was reached, `tool_calls` if the - * model called a tool, or `abort` if user manually stops the generation. - */ - finish_reason: ChatCompletionFinishReason; + tool_calls?: Array; + } - /** - * The index of the choice in the list of choices. - */ + export namespace Delta { + export interface ToolCall { index: number; /** - * Log probability information for the choice. + * The ID of the tool call. */ - logprobs: Choice.Logprobs | null; + id?: string; - /** - * A chat completion message generated by the model. - */ - message: ChatCompletionMessage; - } + function?: ToolCall.Function; - export namespace Choice { /** - * Log probability information for the choice. + * The type of the tool. Currently, only `function` is supported. */ - export interface Logprobs { - /** - * A list of message content tokens with log probability information. - */ - content: Array | null; + type?: "function"; + } + + export namespace ToolCall { + export interface Function { + /** + * The arguments to call the function with, as generated by the model in JSON + * format. Note that the model does not always generate valid JSON, and may + * hallucinate parameters not defined by your function schema. Validate the + * arguments in your code before calling your function. + */ + arguments?: string; + + /** + * The name of the function to call. + */ + name?: string; } - } -} - -export namespace ChatCompletionChunk { - export interface Choice { - /** - * A chat completion delta generated by streamed model responses. - */ - delta: Choice.Delta; - - /** - * The reason the model stopped generating tokens. This will be `stop` if the model - * hit a natural stop point or a provided stop sequence, `length` if the maximum - * number of tokens specified in the request was reached, `tool_calls` if the - * model called a tool, or `abort` if user manually stops the generation. - */ - finish_reason: ChatCompletionFinishReason | null; - - /** - * The index of the choice in the list of choices. - */ - index: number; - - /** - * Log probability information for the choice. - */ - logprobs?: Choice.Logprobs | null; + } } - export namespace Choice { - /** - * A chat completion delta generated by streamed model responses. - */ - export interface Delta { - /** - * The contents of the chunk message. - */ - content?: string | null; - - /** - * The role of the author of this message. - */ - role?: 'system' | 'user' | 'assistant' | 'tool'; - - tool_calls?: Array; - } - - export namespace Delta { - export interface ToolCall { - index: number; - - /** - * The ID of the tool call. - */ - id?: string; - - function?: ToolCall.Function; - - /** - * The type of the tool. Currently, only `function` is supported. - */ - type?: 'function'; - } - - export namespace ToolCall { - export interface Function { - /** - * The arguments to call the function with, as generated by the model in JSON - * format. Note that the model does not always generate valid JSON, and may - * hallucinate parameters not defined by your function schema. Validate the - * arguments in your code before calling your function. - */ - arguments?: string; - - /** - * The name of the function to call. - */ - name?: string; - } - } - } - - /** - * Log probability information for the choice. - */ - export interface Logprobs { - /** - * A list of message content tokens with log probability information. - */ - content: Array | null; - } + /** + * Log probability information for the choice. + */ + export interface Logprobs { + /** + * A list of message content tokens with log probability information. + */ + content: Array | null; } + } } /** @@ -875,7 +897,7 @@ export namespace ChatCompletionChunk { * * Setting to `{ "type": "json_object" }` enables JSON mode, which guarantees the * message the model generates is valid JSON. - * + * * Setting `schema` specifies the output format of the json object such as properties to include. * * **Important:** when using JSON mode, you **must** also instruct the model to produce JSON @@ -887,12 +909,12 @@ export namespace ChatCompletionChunk { * max context length. */ export interface ResponseFormat { - /** - * Must be one of `text` or `json_object`. - */ - type?: 'text' | 'json_object'; - /** - * A schema string in the format of the schema of a JSON file. `type` needs to be `json_object`. - */ - schema?: string; + /** + * Must be one of `text` or `json_object`. + */ + type?: "text" | "json_object"; + /** + * A schema string in the format of the schema of a JSON file. `type` needs to be `json_object`. + */ + schema?: string; } diff --git a/src/openai_api_protocols/index.ts b/src/openai_api_protocols/index.ts index 77a9b466..b4def3dc 100644 --- a/src/openai_api_protocols/index.ts +++ b/src/openai_api_protocols/index.ts @@ -1,9 +1,9 @@ /** * The input to OpenAI API, directly adopted from openai-node with small tweaks: * https://github.com/openai/openai-node/blob/master/src/resources/chat/completions.ts - * + * * Copyright 2024 OpenAI - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at @@ -13,36 +13,36 @@ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. -*/ + */ export { - ChatCompletionRequestBase, - ChatCompletionRequestNonStreaming, - ChatCompletionRequestStreaming, - ChatCompletionRequest, - ChatCompletion, - ChatCompletionChunk, - ChatCompletionRequestUnsupportedFields, - postInitAndCheckFields, - ChatCompletionContentPart, - ChatCompletionContentPartText, - ChatCompletionContentPartImage, - ChatCompletionMessageToolCall, - ChatCompletionRole, - ChatCompletionSystemMessageParam, - ChatCompletionUserMessageParam, - ChatCompletionAssistantMessageParam, - ChatCompletionToolMessageParam, - ChatCompletionMessageParam, - FunctionParameters, - FunctionDefinition, - ChatCompletionTool, - ChatCompletionNamedToolChoice, - ChatCompletionToolChoiceOption, - TopLogprob, - ChatCompletionTokenLogprob, - ChatCompletionMessage, - CompletionUsage, - ResponseFormat, - ChatCompletionFinishReason, -} from './chat_completion'; + ChatCompletionRequestBase, + ChatCompletionRequestNonStreaming, + ChatCompletionRequestStreaming, + ChatCompletionRequest, + ChatCompletion, + ChatCompletionChunk, + ChatCompletionRequestUnsupportedFields, + postInitAndCheckFields, + ChatCompletionContentPart, + ChatCompletionContentPartText, + ChatCompletionContentPartImage, + ChatCompletionMessageToolCall, + ChatCompletionRole, + ChatCompletionSystemMessageParam, + ChatCompletionUserMessageParam, + ChatCompletionAssistantMessageParam, + ChatCompletionToolMessageParam, + ChatCompletionMessageParam, + FunctionParameters, + FunctionDefinition, + ChatCompletionTool, + ChatCompletionNamedToolChoice, + ChatCompletionToolChoiceOption, + TopLogprob, + ChatCompletionTokenLogprob, + ChatCompletionMessage, + CompletionUsage, + ResponseFormat, + ChatCompletionFinishReason, +} from "./chat_completion"; diff --git a/src/service_worker.ts b/src/service_worker.ts index 920a5aa1..e09f0cc7 100644 --- a/src/service_worker.ts +++ b/src/service_worker.ts @@ -2,7 +2,11 @@ import * as tvmjs from "tvmjs"; import { AppConfig, ChatOptions, MLCEngineConfig, ModelRecord } from "./config"; import { ReloadParams, WorkerRequest, WorkerResponse } from "./message"; import { MLCEngineInterface, InitProgressReport } from "./types"; -import { MLCEngineWorkerHandler, WebWorkerMLCEngine, ChatWorker } from "./web_worker"; +import { + MLCEngineWorkerHandler, + WebWorkerMLCEngine, + ChatWorker, +} from "./web_worker"; import { areAppConfigsEqual, areChatOptionsEqual } from "./utils"; /* Service Worker Script */ @@ -36,10 +40,10 @@ export class ServiceWorkerMLCEngineHandler extends MLCEngineWorkerHandler { >(); private initReuqestUuid?: string; - constructor(engine: MLCEngineInterface) { + constructor(engine: MLCEngineInterface, verbose = false) { if (!self || !("addEventListener" in self)) { throw new Error( - "ServiceWorkerGlobalScope is not defined. ServiceWorkerMLCEngineHandler must be created in service worker script." + "ServiceWorkerGlobalScope is not defined. ServiceWorkerMLCEngineHandler must be created in service worker script.", ); } const postMessageHandler = { @@ -75,7 +79,7 @@ export class ServiceWorkerMLCEngineHandler extends MLCEngineWorkerHandler { message.waitUntil( new Promise((resolve, reject) => { onmessage(message, resolve, reject); - }) + }), ); }); } @@ -83,7 +87,7 @@ export class ServiceWorkerMLCEngineHandler extends MLCEngineWorkerHandler { onmessage( event: ExtendableMessageEvent, onComplete?: (value: any) => void, - onError?: () => void + onError?: () => void, ): void { const msg = event.data as WorkerRequest; @@ -131,7 +135,7 @@ export class ServiceWorkerMLCEngineHandler extends MLCEngineWorkerHandler { await this.engine.reload( params.modelId, params.chatOpts, - params.appConfig + params.appConfig, ); this.modelId = params.modelId; this.chatOpts = params.chatOpts; @@ -151,14 +155,13 @@ export class ServiceWorkerMLCEngineHandler extends MLCEngineWorkerHandler { */ export class ServiceWorker implements ChatWorker { serviceWorker: IServiceWorker; + onmessage: () => void; constructor(serviceWorker: IServiceWorker) { this.serviceWorker = serviceWorker; + this.onmessage = () => {}; } - // ServiceWorkerMLCEngine will later overwrite this - onmessage() {} - postMessage(message: WorkerRequest) { if (!("serviceWorker" in navigator)) { throw new Error("Service worker API is not available"); @@ -182,21 +185,23 @@ export class ServiceWorker implements ChatWorker { */ export async function CreateServiceWorkerMLCEngine( modelId: string, - engineConfig?: MLCEngineConfig + engineConfig?: MLCEngineConfig, ): Promise { if (!("serviceWorker" in navigator)) { throw new Error("Service worker API is not available"); } const registration = await (navigator.serviceWorker as ServiceWorkerContainer) .ready; - const serviceWorkerMLCEngine = new ServiceWorkerMLCEngine(registration.active!); + const serviceWorkerMLCEngine = new ServiceWorkerMLCEngine( + registration.active!, + ); serviceWorkerMLCEngine.setInitProgressCallback( - engineConfig?.initProgressCallback + engineConfig?.initProgressCallback, ); await serviceWorkerMLCEngine.init( modelId, engineConfig?.chatOpts, - engineConfig?.appConfig + engineConfig?.appConfig, ); return serviceWorkerMLCEngine; } @@ -207,7 +212,7 @@ export async function CreateServiceWorkerMLCEngine( export class ServiceWorkerMLCEngine extends WebWorkerMLCEngine { missedHeatbeat = 0; - constructor(worker: IServiceWorker, keepAliveMs = 10000) { + constructor(worker: IServiceWorker, keepAliveMs = 10000, verbose = false) { if (!("serviceWorker" in navigator)) { throw new Error("Service worker API is not available"); } @@ -230,7 +235,7 @@ export class ServiceWorkerMLCEngine extends WebWorkerMLCEngine { console.error("CreateWebServiceWorkerMLCEngine.onmessage", err); } } - } + }, ); setInterval(() => { @@ -253,7 +258,7 @@ export class ServiceWorkerMLCEngine extends WebWorkerMLCEngine { async init( modelId: string, chatOpts?: ChatOptions, - appConfig?: AppConfig + appConfig?: AppConfig, ): Promise { const msg: WorkerRequest = { kind: "init", diff --git a/src/support.ts b/src/support.ts index b3607e63..ecd91465 100644 --- a/src/support.ts +++ b/src/support.ts @@ -4,87 +4,92 @@ import { Tokenizer } from "@mlc-ai/web-tokenizers"; /** * Based on `p_prob` of size (vocabSize,) which becomes a distribution after calling * `applySoftmaxWithTemperature()`, sample `top_logprobs` top-probable tokens. - * + * * @param num_top_probs: `top_logprobs` from ChatCompletionRequest * @param p_prob: `logitsOnCPUArray`, being a distribution after `applySoftmaxWithTemperature()`. - * + * * Followed implementation of `ComputeTopProbsImpl()` from [https://github.com/mlc-ai/mlc-llm/blob/ * 5b8c529e9704abd09b0432da6dcb4b013fdf43b1/cpp/serve/sampler/cpu_sampler.cc]. - * + * * @returns Arrays of (tokenID, prob) pairs, ranked from highest prob to least. */ export function getTopProbs( - num_top_probs: number, p_prob: Float32Array + num_top_probs: number, + p_prob: Float32Array, ): Array<[number, number]> { - if (num_top_probs == 0) return []; - // Initialize to dummy values - const top_probs: Array<[number, number]> = []; - const ndata = p_prob.length; - for (let i = 0; i < num_top_probs; i++) { - top_probs.push([-1, -1.0]); - } + if (num_top_probs == 0) return []; + // Initialize to dummy values + const top_probs: Array<[number, number]> = []; + const ndata = p_prob.length; + for (let i = 0; i < num_top_probs; i++) { + top_probs.push([-1, -1.0]); + } - let sum_prob = 0.0; - // Selection argsort. - for (let p = 0; p < ndata; p++) { - let i = num_top_probs - 1; - for (; i >= 0; --i) { - if (p_prob[p] > top_probs[i][1]) { - if (i !== num_top_probs - 1) { - top_probs[i + 1] = top_probs[i]; - } - } else { - break; - } - } + let sum_prob = 0.0; + // Selection argsort. + for (let p = 0; p < ndata; p++) { + let i = num_top_probs - 1; + for (; i >= 0; --i) { + if (p_prob[p] > top_probs[i][1]) { if (i !== num_top_probs - 1) { - top_probs[i + 1] = [p, p_prob[p]]; + top_probs[i + 1] = top_probs[i]; } + } else { + break; + } + } + if (i !== num_top_probs - 1) { + top_probs[i + 1] = [p, p_prob[p]]; + } - // Early exit - sum_prob += p_prob[p]; - if (1 - sum_prob <= top_probs[num_top_probs - 1][1]) { - break; - } + // Early exit + sum_prob += p_prob[p]; + if (1 - sum_prob <= top_probs[num_top_probs - 1][1]) { + break; } - return top_probs; + } + return top_probs; } /** * Post-process a raw token (which may be a raw byte or contain lower one eights block) to the * actual token. We do this in order to conform with the tokenizers' setup. - * + * * Follow implementation of [https://github.com/mlc-ai/mlc-llm/blob/ * bcb9b6a33a672a70d760c9a8b03234124aab50c4/cpp/tokenizers.cc#L99] */ export function postProcessToken(token: string): string { - // 1. The token represents a byte. - const charCode0 = "0".charCodeAt(0); - const charCode9 = "9".charCodeAt(0); - const charCodeA = "A".charCodeAt(0); - if (token.length == 6 && token.substring(0, 3) === "<0x" && token.slice(-1) === ">") { - let byte = 0; - for (let i = 0; i < 2; i++) { - byte *= 16; - const curCharCode = token.charCodeAt(3 + i); - if (curCharCode >= charCode0 && curCharCode <= charCode9) { - byte += curCharCode - charCode0; - } else { - byte += curCharCode - charCodeA + 10; - } - } - if (byte < 0 || byte >= 256) { - throw Error("Expect byte to be in range [0, 256).") - } - return String.fromCharCode(byte); + // 1. The token represents a byte. + const charCode0 = "0".charCodeAt(0); + const charCode9 = "9".charCodeAt(0); + const charCodeA = "A".charCodeAt(0); + if ( + token.length == 6 && + token.substring(0, 3) === "<0x" && + token.slice(-1) === ">" + ) { + let byte = 0; + for (let i = 0; i < 2; i++) { + byte *= 16; + const curCharCode = token.charCodeAt(3 + i); + if (curCharCode >= charCode0 && curCharCode <= charCode9) { + byte += curCharCode - charCode0; + } else { + byte += curCharCode - charCodeA + 10; + } } + if (byte < 0 || byte >= 256) { + throw Error("Expect byte to be in range [0, 256)."); + } + return String.fromCharCode(byte); + } - // 2. The token contains lower one eight block which means space, e.g. `▁response` in Llama-2. - // https://www.compart.com/en/unicode/U+2581 - const lowerOneEighthBlock = "\u2581"; - token = token.split(lowerOneEighthBlock).join(" "); + // 2. The token contains lower one eight block which means space, e.g. `▁response` in Llama-2. + // https://www.compart.com/en/unicode/U+2581 + const lowerOneEighthBlock = "\u2581"; + token = token.split(lowerOneEighthBlock).join(" "); - return token; + return token; } /** @@ -92,11 +97,11 @@ export function postProcessToken(token: string): string { * @param tokenizer A loaded tokenizer. */ export function getTokenTableFromTokenizer(tokenizer: Tokenizer): string[] { - const tokenTable: string[] = []; - const vocabSize = tokenizer.getVocabSize(); - for (let tokenId = 0; tokenId < vocabSize; tokenId++) { - const token = tokenizer.idToToken(tokenId); - tokenTable.push(postProcessToken(token)); - } - return tokenTable; + const tokenTable: string[] = []; + const vocabSize = tokenizer.getVocabSize(); + for (let tokenId = 0; tokenId < vocabSize; tokenId++) { + const token = tokenizer.idToToken(tokenId); + tokenTable.push(postProcessToken(token)); + } + return tokenTable; } diff --git a/src/types.ts b/src/types.ts index d1127102..e2dfcec1 100644 --- a/src/types.ts +++ b/src/types.ts @@ -26,7 +26,10 @@ export type InitProgressCallback = (report: InitProgressReport) => void; /** * Callbacks used to report initialization process. */ -export type GenerateProgressCallback = (step: number, currentMessage: string) => void; +export type GenerateProgressCallback = ( + step: number, + currentMessage: string, +) => void; /** * A stateful logitProcessor used to post-process logits after forwarding the input and before @@ -54,7 +57,6 @@ export interface LogitProcessor { resetState: () => void; } - /** * Common interface of MLCEngine that UI can interact with */ @@ -90,7 +92,10 @@ export interface MLCEngineInterface { * @note This is an async function. */ reload: ( - modelId: string, chatOpts?: ChatOptions, appConfig?: AppConfig) => Promise; + modelId: string, + chatOpts?: ChatOptions, + appConfig?: AppConfig, + ) => Promise; /** * Generate a response for a given input. @@ -100,7 +105,7 @@ export interface MLCEngineInterface { * @param streamInterval callback interval to call progresscallback * @param genConfig Configuration for this single generation that overrides pre-existing configs. * @returns The final result. - * + * * @note This will be deprecated soon. Please use `engine.chat.completions.create()` instead. * For multi-round chatting, see `examples/multi-round-chat` on how to use * `engine.chat.completions.create()` to achieve the same effect. @@ -114,7 +119,7 @@ export interface MLCEngineInterface { /** * OpenAI-style API. Generate a chat completion response for the given conversation and configuration. - * + * * The API is completely functional in behavior. That is, a previous request would not affect * the current request's result. Thus, for multi-round chatting, users are responsible for * maintaining the chat history. With that being said, as an implicit internal optimization, if we @@ -122,16 +127,16 @@ export interface MLCEngineInterface { * prefill the new tokens. */ chatCompletion( - request: ChatCompletionRequestNonStreaming + request: ChatCompletionRequestNonStreaming, ): Promise; chatCompletion( - request: ChatCompletionRequestStreaming + request: ChatCompletionRequestStreaming, ): Promise>; chatCompletion( - request: ChatCompletionRequestBase + request: ChatCompletionRequestBase, ): Promise | ChatCompletion>; chatCompletion( - request: ChatCompletionRequest + request: ChatCompletionRequest, ): Promise | ChatCompletion>; /** @@ -185,6 +190,8 @@ export interface MLCEngineInterface { * @returns Next token sampled. * @note This is an async function. */ - forwardTokensAndSample(inputIds: Array, isPrefill: boolean): Promise; + forwardTokensAndSample( + inputIds: Array, + isPrefill: boolean, + ): Promise; } - diff --git a/src/utils.ts b/src/utils.ts index 0b5b51a8..ad6cfcc0 100644 --- a/src/utils.ts +++ b/src/utils.ts @@ -31,7 +31,7 @@ function areObjectsEqual(obj1: any, obj2: any): boolean { // Function to compare two ModelRecord instances export function areModelRecordsEqual( record1: ModelRecord, - record2: ModelRecord + record2: ModelRecord, ): boolean { // Compare primitive fields if ( @@ -70,7 +70,7 @@ export function areModelRecordsEqual( export function areAppConfigsEqual( config1?: AppConfig, - config2?: AppConfig + config2?: AppConfig, ): boolean { if (config1 === undefined || config2 === undefined) { return config1 === config2; @@ -99,7 +99,7 @@ export function areAppConfigsEqual( export function areChatOptionsEqual( options1?: ChatOptions, - options2?: ChatOptions + options2?: ChatOptions, ): boolean { if (options1 === undefined || options2 === undefined) { return options1 === options2; diff --git a/src/web_worker.ts b/src/web_worker.ts index 44c4ac43..cc1c95b7 100644 --- a/src/web_worker.ts +++ b/src/web_worker.ts @@ -65,7 +65,7 @@ export class MLCEngineWorkerHandler { constructor( engine: MLCEngineInterface, postMessageHandler?: PostMessageHandler, - initProgressCallback?: (report: InitProgressReport) => void + initProgressCallback?: (report: InitProgressReport) => void, ) { this.engine = engine; this.postMessageHandler = postMessageHandler; @@ -78,7 +78,7 @@ export class MLCEngineWorkerHandler { this.postMessageInternal(msg); }; this.engine.setInitProgressCallback( - initProgressCallback || defaultInitProgressCallback + initProgressCallback || defaultInitProgressCallback, ); } @@ -95,7 +95,7 @@ export class MLCEngineWorkerHandler { async handleTask( uuid: string, - task: () => Promise + task: () => Promise, ) { try { const res = await task(); @@ -119,7 +119,7 @@ export class MLCEngineWorkerHandler { onmessage( event: any, onComplete?: (value: any) => void, - onError?: () => void + onError?: () => void, ) { let msg: WorkerRequest; if (event instanceof MessageEvent) { @@ -134,7 +134,7 @@ export class MLCEngineWorkerHandler { await this.engine.reload( params.modelId, params.chatOpts, - params.appConfig + params.appConfig, ); onComplete?.(null); return null; @@ -159,7 +159,7 @@ export class MLCEngineWorkerHandler { params.input, progressCallback, params.streamInterval, - params.genConfig + params.genConfig, ); onComplete?.(res); return res; @@ -171,7 +171,7 @@ export class MLCEngineWorkerHandler { const params = msg.content as ForwardTokensAndSampleParams; const res = await this.engine.forwardTokensAndSample( params.inputIds, - params.isPrefill + params.isPrefill, ); onComplete?.(res); return res; @@ -194,7 +194,7 @@ export class MLCEngineWorkerHandler { const params = msg.content as ChatCompletionStreamInitParams; this.chatCompletionAsyncChunkGenerator = (await this.engine.chatCompletion( - params.request + params.request, )) as AsyncGenerator; onComplete?.(null); return null; @@ -206,7 +206,7 @@ export class MLCEngineWorkerHandler { this.handleTask(msg.uuid, async () => { if (this.chatCompletionAsyncChunkGenerator === undefined) { throw Error( - "Chunk generator in worker should be instantiated by now." + "Chunk generator in worker should be instantiated by now.", ); } // Yield the next chunk @@ -281,7 +281,7 @@ export class MLCEngineWorkerHandler { if (msg.kind && msg.content) { onError?.(); throw Error( - "Unknown message kind, msg: [" + msg.kind + "] " + msg.content + "Unknown message kind, msg: [" + msg.kind + "] " + msg.content, ); } else { // Ignore irrelavent events @@ -313,14 +313,16 @@ export interface ChatWorker { export async function CreateWebWorkerMLCEngine( worker: any, modelId: string, - engineConfig?: MLCEngineConfig + engineConfig?: MLCEngineConfig, ): Promise { const webWorkerMLCEngine = new WebWorkerMLCEngine(worker); - webWorkerMLCEngine.setInitProgressCallback(engineConfig?.initProgressCallback); + webWorkerMLCEngine.setInitProgressCallback( + engineConfig?.initProgressCallback, + ); await webWorkerMLCEngine.reload( modelId, engineConfig?.chatOpts, - engineConfig?.appConfig + engineConfig?.appConfig, ); return webWorkerMLCEngine; } @@ -349,7 +351,7 @@ export class WebWorkerMLCEngine implements MLCEngineInterface { constructor(worker: ChatWorker) { this.worker = worker; worker.onmessage = (event: any) => { - this.onmessage(event); + this.onmessage.bind(this)(event); }; this.chat = new API.Chat(this); } @@ -363,12 +365,12 @@ export class WebWorkerMLCEngine implements MLCEngineInterface { } protected getPromise( - msg: WorkerRequest + msg: WorkerRequest, ): Promise { const uuid = msg.uuid; const executor = ( resolve: (arg: T) => void, - reject: (arg: any) => void + reject: (arg: any) => void, ) => { const cb = (msg: WorkerResponse) => { if (msg.kind == "return") { @@ -391,7 +393,7 @@ export class WebWorkerMLCEngine implements MLCEngineInterface { async reload( modelId: string, chatOpts?: ChatOptions, - appConfig?: AppConfig + appConfig?: AppConfig, ): Promise { const msg: WorkerRequest = { kind: "reload", @@ -436,7 +438,7 @@ export class WebWorkerMLCEngine implements MLCEngineInterface { input: string | ChatCompletionRequestNonStreaming, progressCallback?: GenerateProgressCallback, streamInterval?: number, - genConfig?: GenerationConfig + genConfig?: GenerationConfig, ): Promise { const msg: WorkerRequest = { kind: "generate", @@ -493,7 +495,7 @@ export class WebWorkerMLCEngine implements MLCEngineInterface { async forwardTokensAndSample( inputIds: Array, - isPrefill: boolean + isPrefill: boolean, ): Promise { const msg: WorkerRequest = { kind: "forwardTokensAndSample", @@ -534,16 +536,16 @@ export class WebWorkerMLCEngine implements MLCEngineInterface { } async chatCompletion( - request: ChatCompletionRequestNonStreaming + request: ChatCompletionRequestNonStreaming, ): Promise; async chatCompletion( - request: ChatCompletionRequestStreaming + request: ChatCompletionRequestStreaming, ): Promise>; async chatCompletion( - request: ChatCompletionRequestBase + request: ChatCompletionRequestBase, ): Promise | ChatCompletion>; async chatCompletion( - request: ChatCompletionRequest + request: ChatCompletionRequest, ): Promise | ChatCompletion> { if (request.stream) { // First let worker instantiate a generator @@ -612,9 +614,12 @@ export class WebWorkerMLCEngine implements MLCEngineInterface { return; } default: { - let unknownMsg = msg as any; + const unknownMsg = msg as any; throw Error( - "Unknown message kind, msg=[" + unknownMsg.kind + "] " + unknownMsg.content + "Unknown message kind, msg=[" + + unknownMsg.kind + + "] " + + unknownMsg.content, ); } }