diff --git a/examples/abort-reload/README.md b/examples/abort-reload/README.md
new file mode 100644
index 00000000..6365ec36
--- /dev/null
+++ b/examples/abort-reload/README.md
@@ -0,0 +1,13 @@
+# WebLLM Abort Reload Example
+
+This example demonstrates how to cancel an in-progress model fetch by calling `engine.unload()` while `engine.reload()` is still running.
+
+```bash
+npm install
+npm start
+```
+
+Note: if you would like to hack the WebLLM core package, you can change the
+`@mlc-ai/web-llm` dependency to `"file:../.."` and follow the build-from-source
+instructions in the project to build WebLLM locally. This option is only
+recommended if you intend to modify the core package.
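+
+For reference, the demo boils down to the following sketch (a condensed version of
+what `src/get_started.js` does; the model ID is just one example from
+`prebuiltAppConfig`, and any prebuilt model works):
+
+```js
+import * as webllm from "@mlc-ai/web-llm";
+
+const engine = new webllm.MLCEngine({
+  initProgressCallback: (report) => console.log(report.text),
+});
+
+// Start fetching and loading the model; not awaited so it can be cancelled below.
+engine.reload("Llama-3.1-8B-Instruct-q4f32_1-MLC");
+
+// Calling unload() while reload() is still in flight aborts the remaining fetches.
+setTimeout(() => {
+  engine.unload().catch((err) => console.log(err));
+}, 5000);
+```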
diff --git a/examples/abort-reload/package.json b/examples/abort-reload/package.json
new file mode 100644
index 00000000..12c603a9
--- /dev/null
+++ b/examples/abort-reload/package.json
@@ -0,0 +1,20 @@
+{
+ "name": "abort-reload",
+ "version": "0.1.0",
+ "private": true,
+ "scripts": {
+ "start": "parcel src/get_started.html --port 8887",
+ "build": "parcel build src/get_started.html --dist-dir lib"
+ },
+ "devDependencies": {
+ "buffer": "^5.7.1",
+ "parcel": "^2.8.3",
+ "process": "^0.11.10",
+ "tslib": "^2.3.1",
+ "typescript": "^4.9.5",
+ "url": "^0.11.3"
+ },
+ "dependencies": {
+ "@mlc-ai/web-llm": "file:../../lib"
+ }
+}
diff --git a/examples/abort-reload/src/get_started.html b/examples/abort-reload/src/get_started.html
new file mode 100644
index 00000000..1ce32e28
--- /dev/null
+++ b/examples/abort-reload/src/get_started.html
@@ -0,0 +1,23 @@
+<!doctype html>
+<html>
+  <head>
+    <meta charset="UTF-8" />
+  </head>
+  <body>
+    <h2>WebLLM Test Page</h2>
+    Open console to see output
+    <br />
+    <br />
+    <label id="init-label"> </label>
+
+    <h3>Prompt</h3>
+    <label id="prompt-label"> </label>
+
+    <h3>Response</h3>
+    <label id="generate-label"> </label>
+    <br />
+    <label id="stats-label"> </label>
+
+    <script type="module" src="./get_started.js"></script>
+  </body>
+</html>
diff --git a/examples/abort-reload/src/get_started.js b/examples/abort-reload/src/get_started.js
new file mode 100644
index 00000000..cbd7f24a
--- /dev/null
+++ b/examples/abort-reload/src/get_started.js
@@ -0,0 +1,32 @@
+import * as webllm from "@mlc-ai/web-llm";
+
+let engine;
+
+function setLabel(id, text) {
+ const label = document.getElementById(id);
+ if (label == null) {
+ throw Error("Cannot find label " + id);
+ }
+ label.innerText = text;
+}
+
+async function main() {
+ const initProgressCallback = (report) => {
+ console.log(report.text);
+ setLabel("init-label", report.text);
+ };
+ // If we do not specify appConfig, we use `prebuiltAppConfig` defined in `config.ts`
+ const selectedModel = "Llama-3.1-8B-Instruct-q4f32_1-MLC";
+ engine = new webllm.MLCEngine({
+ initProgressCallback,
+ });
+ // Intentionally not awaited: the unload() call below aborts this reload
+ // while the model is still being fetched.
+ engine.reload(selectedModel);
+}
+main();
+// After 5 seconds, abort the ongoing reload by unloading the engine.
+setTimeout(() => {
+ console.log("calling unload");
+ engine.unload().catch((err) => {
+ console.log(err);
+ });
+}, 5000);
diff --git a/src/engine.ts b/src/engine.ts
index 8eb99520..e4cbb483 100644
--- a/src/engine.ts
+++ b/src/engine.ts
@@ -103,6 +103,7 @@ export class MLCEngine implements MLCEngineInterface {
private initProgressCallback?: InitProgressCallback;
private interruptSignal = false;
private deviceLostIsError = true; // whether device.lost is due to actual error or model reload
+ private reloadController: AbortController | undefined;
private config?: ChatConfig;
private appConfig: AppConfig;
@@ -143,7 +144,25 @@ export class MLCEngine implements MLCEngineInterface {
*/
 async reload(modelId: string, chatOpts?: ChatOptions): Promise<void> {
await this.unload();
+ this.reloadController = new AbortController();
+ try {
+ await this.reloadInternal(modelId, chatOpts);
+ } catch (error) {
+ if (error instanceof DOMException && error.name === "AbortError") {
+ log.warn("Reload() is aborted.", error.message);
+ return;
+ }
+ throw error;
+ } finally {
+ this.reloadController = undefined;
+ }
+ }
+
+ private async reloadInternal(
+ modelId: string,
+ chatOpts?: ChatOptions,
+ ): Promise<void> {
this.logitProcessor = this.logitProcessorRegistry?.get(modelId);
const tstart = performance.now();
@@ -175,7 +194,11 @@ export class MLCEngine implements MLCEngineInterface {
// load config
const configUrl = new URL("mlc-chat-config.json", modelUrl).href;
this.config = {
- ...(await configCache.fetchWithCache(configUrl, "json")),
+ ...(await configCache.fetchWithCache(
+ configUrl,
+ "json",
+ this.reloadController?.signal,
+ )),
...modelRecord.overrides,
...chatOpts,
} as ChatConfig;
@@ -201,8 +224,11 @@ export class MLCEngine implements MLCEngineInterface {
// rely on the normal caching strategy
return (await fetch(new URL(wasmUrl, baseUrl).href)).arrayBuffer();
} else {
- // use cache
- return await wasmCache.fetchWithCache(wasmUrl, "arraybuffer");
+ return await wasmCache.fetchWithCache(
+ wasmUrl,
+ "arraybuffer",
+ this.reloadController?.signal,
+ );
}
};
const wasmSource = await fetchWasmSource();
@@ -248,7 +274,8 @@ export class MLCEngine implements MLCEngineInterface {
gpuDetectOutput.device.lost.then((info: any) => {
if (this.deviceLostIsError) {
log.error(
- `Device was lost during reload. This can happen due to insufficient memory or other GPU constraints. Detailed error: ${info}. Please try to reload WebLLM with a less resource-intensive model.`,
+ `Device was lost. This can happen due to insufficient memory or other GPU constraints. ` +
+ `Detailed error: ${info}. Please try to reload WebLLM with a less resource-intensive model.`,
);
this.unload();
deviceLostInReload = true;
@@ -267,6 +294,7 @@ export class MLCEngine implements MLCEngineInterface {
tvm.webgpu(),
"webllm/model",
cacheType,
+ this.reloadController?.signal,
);
this.pipeline = new LLMChatPipeline(
tvm,
@@ -646,12 +674,23 @@ export class MLCEngine implements MLCEngineInterface {
this.pipeline?.resetChat(keepStats);
}
+ /**
+ * Unloads the currently loaded model and destroys the WebGPU device. Waits
+ * until the WebGPU device finishes all submitted work and destroys itself.
+ * @note This is an asynchronous function.
+ */
async unload() {
this.deviceLostIsError = false; // so that unload() does not trigger device.lost error
this.pipeline?.dispose();
+ // Wait until device is actually destroyed so we can safely set deviceLostIsError back to true
+ await this.pipeline?.sync();
this.pipeline = undefined;
this.currentModelId = undefined;
this.deviceLostIsError = true;
+ if (this.reloadController) {
+ this.reloadController.abort("Engine.unload() is called.");
+ this.reloadController = undefined;
+ }
}
 async getMaxStorageBufferBindingSize(): Promise<number> {
diff --git a/src/llm_chat.ts b/src/llm_chat.ts
index af6803e9..9a97ac92 100644
--- a/src/llm_chat.ts
+++ b/src/llm_chat.ts
@@ -1040,6 +1040,14 @@ export class LLMChatPipeline {
} as ChatCompletionTokenLogprob;
}
+ /**
+ * Synchronize the device.
+ */
+ async sync(): Promise<void> {
+ // Is it equivalent to this.tvm.sync()?
+ await this.device.sync();
+ }
+
async evaluate() {
// run a canonical evaluation of the flow
this.resetKVCache();