diff --git a/examples/abort-reload/README.md b/examples/abort-reload/README.md
new file mode 100644
index 00000000..6365ec36
--- /dev/null
+++ b/examples/abort-reload/README.md
@@ -0,0 +1,13 @@
+# WebLLM Abort Reload Example
+
+This folder provides a demo of cancelling an in-progress model fetch by calling `engine.unload()` after `engine.reload()`.
+
+```bash
+npm install
+npm start
+```
+
+Note: if you would like to hack on the WebLLM core package, you can change the
+web-llm dependency to `"file:../.."` and follow the build-from-source
+instructions in the project to build WebLLM locally. This option is only
+recommended if you are modifying the WebLLM core package.
diff --git a/examples/abort-reload/package.json b/examples/abort-reload/package.json
new file mode 100644
index 00000000..12c603a9
--- /dev/null
+++ b/examples/abort-reload/package.json
@@ -0,0 +1,20 @@
+{
+  "name": "get-started",
+  "version": "0.1.0",
+  "private": true,
+  "scripts": {
+    "start": "parcel src/get_started.html --port 8887",
+    "build": "parcel build src/get_started.html --dist-dir lib"
+  },
+  "devDependencies": {
+    "buffer": "^5.7.1",
+    "parcel": "^2.8.3",
+    "process": "^0.11.10",
+    "tslib": "^2.3.1",
+    "typescript": "^4.9.5",
+    "url": "^0.11.3"
+  },
+  "dependencies": {
+    "@mlc-ai/web-llm": "file:../../lib"
+  }
+}
diff --git a/examples/abort-reload/src/get_started.html b/examples/abort-reload/src/get_started.html
new file mode 100644
index 00000000..1ce32e28
--- /dev/null
+++ b/examples/abort-reload/src/get_started.html
@@ -0,0 +1,23 @@
+<!doctype html>
+<html>
+  <script>
+    webLLMGlobal = {};
+  </script>
+  <body>
+    <h2>WebLLM Test Page</h2>
+    Open console to see output
+    <br />
+    <br />
+    <label id="init-label"> </label>
+
+    <h3>Prompt</h3>
+    <label id="prompt-label"> </label>
+
+    <h3>Response</h3>
+    <label id="generate-label"> </label>
+    <br />
+    <label id="stats-label"> </label>
+
+    <script type="module" src="./get_started.js"></script>
+  </body>
+</html>
diff --git a/examples/abort-reload/src/get_started.js b/examples/abort-reload/src/get_started.js
new file mode 100644
index 00000000..cbd7f24a
--- /dev/null
+++ b/examples/abort-reload/src/get_started.js
@@ -0,0 +1,32 @@
+import * as webllm from "@mlc-ai/web-llm";
+import { error } from "loglevel";
+
+let engine;
+
+function setLabel(id, text) {
+  const label = document.getElementById(id);
+  if (label == null) {
+    throw Error("Cannot find label " + id);
+  }
+  label.innerText = text;
+}
+
+async function main() {
+  const initProgressCallback = (report) => {
+    console.log(report.text);
+    setLabel("init-label", report.text);
+  };
+  // If we do not specify appConfig, we use `prebuiltAppConfig` defined in `config.ts`
+  const selectedModel = "Llama-3.1-8B-Instruct-q4f32_1-MLC";
+  engine = new webllm.MLCEngine({
+    initProgressCallback,
+  });
+  engine.reload(selectedModel);
+}
+main();
+setTimeout(() => {
+  console.log("calling unload");
+  engine.unload().catch((err) => {
+    console.log(err);
+  });
+}, 5000);
diff --git a/src/engine.ts b/src/engine.ts
index 8eb99520..e4cbb483 100644
--- a/src/engine.ts
+++ b/src/engine.ts
@@ -103,6 +103,7 @@ export class MLCEngine implements MLCEngineInterface {
   private initProgressCallback?: InitProgressCallback;
   private interruptSignal = false;
   private deviceLostIsError = true; // whether device.lost is due to actual error or model reload
+  private reloadController: AbortController | undefined;
   private config?: ChatConfig;
   private appConfig: AppConfig;
 
@@ -143,7 +144,25 @@
    */
   async reload(modelId: string, chatOpts?: ChatOptions): Promise<void> {
     await this.unload();
+    this.reloadController = new AbortController();
+    try {
+      await this.reloadInternal(modelId, chatOpts);
+    } catch (error) {
+      if (error instanceof DOMException && error.name === "AbortError") {
+        log.warn("Reload() is aborted.", error.message);
+        return;
+      }
+      throw error;
+    } finally {
+      this.reloadController = undefined;
+    }
+  }
+
+  private async reloadInternal(
+    modelId: string,
+    chatOpts?: ChatOptions,
+  ): Promise<void> {
     this.logitProcessor = this.logitProcessorRegistry?.get(modelId);
     const tstart = performance.now();
 
@@ -175,7 +194,11 @@
     // load config
     const configUrl = new URL("mlc-chat-config.json", modelUrl).href;
     this.config = {
-      ...(await configCache.fetchWithCache(configUrl, "json")),
+      ...(await configCache.fetchWithCache(
+        configUrl,
+        "json",
+        this.reloadController?.signal,
+      )),
       ...modelRecord.overrides,
       ...chatOpts,
     } as ChatConfig;
@@ -201,8 +224,11 @@
         // rely on the normal caching strategy
         return (await fetch(new URL(wasmUrl, baseUrl).href)).arrayBuffer();
       } else {
-        // use cache
-        return await wasmCache.fetchWithCache(wasmUrl, "arraybuffer");
+        return await wasmCache.fetchWithCache(
+          wasmUrl,
+          "arraybuffer",
+          this.reloadController?.signal,
+        );
       }
     };
     const wasmSource = await fetchWasmSource();
@@ -248,7 +274,8 @@
     gpuDetectOutput.device.lost.then((info: any) => {
       if (this.deviceLostIsError) {
         log.error(
-          `Device was lost during reload. This can happen due to insufficient memory or other GPU constraints. Detailed error: ${info}. Please try to reload WebLLM with a less resource-intensive model.`,
+          `Device was lost. This can happen due to insufficient memory or other GPU constraints. ` +
+            `Detailed error: ${info}. Please try to reload WebLLM with a less resource-intensive model.`,
         );
         this.unload();
         deviceLostInReload = true;
@@ -267,6 +294,7 @@
       tvm.webgpu(),
       "webllm/model",
       cacheType,
+      this.reloadController?.signal,
     );
     this.pipeline = new LLMChatPipeline(
       tvm,
@@ -646,12 +674,23 @@
     this.pipeline?.resetChat(keepStats);
   }
 
+  /**
+   * Unloads the currently loaded model and destroys the WebGPU device. Waits
+   * until the WebGPU device finishes all submitted work and destroys itself.
+   * @note This is an asynchronous function.
+   */
   async unload() {
     this.deviceLostIsError = false; // so that unload() does not trigger device.lost error
     this.pipeline?.dispose();
+    // Wait until device is actually destroyed so we can safely set deviceLostIsError back to true
+    await this.pipeline?.sync();
     this.pipeline = undefined;
     this.currentModelId = undefined;
     this.deviceLostIsError = true;
+    if (this.reloadController) {
+      this.reloadController.abort("Engine.unload() is called.");
+      this.reloadController = undefined;
+    }
   }
 
   async getMaxStorageBufferBindingSize(): Promise<number> {
diff --git a/src/llm_chat.ts b/src/llm_chat.ts
index af6803e9..9a97ac92 100644
--- a/src/llm_chat.ts
+++ b/src/llm_chat.ts
@@ -1040,6 +1040,14 @@
     } as ChatCompletionTokenLogprob;
   }
 
+  /**
+   * Synchronize the device.
+   */
+  async sync(): Promise<void> {
+    // Is it equivalent to this.tvm.sync()?
+    await this.device.sync();
+  }
+
   async evaluate() {
     // run a canonical evaluation of the flow
     this.resetKVCache();
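
Reviewer note: the catch clause in `reload()` keys on a `DOMException` named `"AbortError"`, which is what an aborted `fetch()` rejects with. The contract assumed by this patch is therefore that `fetchWithCache(url, kind, signal?)` and `tvm.fetchNDArrayCache(..., signal?)` forward the optional `AbortSignal` to the underlying requests. Those cache implementations are not part of this diff, so the sketch below is illustrative only (the helper name and body are not the actual web-llm code); it shows why aborting `reloadController` surfaces inside `reload()` as an `AbortError`.

```ts
// Illustrative sketch, not the actual ArtifactCache implementation.
// Forwarding the AbortSignal to fetch() means an abort cancels the
// network request itself, and the pending promise rejects with a
// DOMException whose name is "AbortError".
async function fetchWithCacheSketch(
  url: string,
  kind: "json" | "arraybuffer",
  signal?: AbortSignal,
): Promise<unknown> {
  const response = await fetch(url, { signal });
  if (!response.ok) {
    throw new Error(`Failed to fetch ${url}: HTTP ${response.status}`);
  }
  return kind === "json" ? response.json() : response.arrayBuffer();
}
```

Because the new argument is passed as `this.reloadController?.signal`, callers that never abort see no change in behavior.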
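A second note on the observable semantics exercised by the example app: `unload()` aborts any in-flight `reload()` via `reloadController`, and `reload()` swallows the resulting `AbortError`, so the pending reload promise resolves (with a warning logged) rather than rejects. A minimal caller-side sketch, assuming the example's model id and the 5-second delay used in `get_started.js` (the function name here is hypothetical):

```ts
import * as webllm from "@mlc-ai/web-llm";

async function abortDuringReload(): Promise<void> {
  const engine = new webllm.MLCEngine({
    initProgressCallback: (report) => console.log(report.text),
  });
  // Start fetching the model but do not await it yet.
  const reloading = engine.reload("Llama-3.1-8B-Instruct-q4f32_1-MLC");
  // Let the download run for a while, then cancel it.
  await new Promise((resolve) => setTimeout(resolve, 5000));
  await engine.unload(); // aborts the in-flight fetches
  await reloading; // resolves: the AbortError is caught inside reload()
}
```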