[Engine] Allow manually aborting reload, fix unexpected deviceLostError (#525)

### Manually aborting reload
This PR updates the engine `reload()` and `unload()` methods to allow
users to abort an uncompleted `reload()` by either of the following
(see the usage sketch after this list):

- calling `unload()` at any time before the `reload()` completes
- calling `reload()` again before the previous `reload()` completes
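
A minimal sketch of both patterns from the caller's side (error handling is
omitted; the second model id is a placeholder, not part of this change):

```ts
import * as webllm from "@mlc-ai/web-llm";

const engine = new webllm.MLCEngine();

// Start loading a model without awaiting the returned promise.
let pending = engine.reload("Llama-3.1-8B-Instruct-q4f32_1-MLC");

// Pattern 1: abort by calling unload() before the reload completes.
// The pending reload() then resolves (after logging a warning) rather than rejecting.
await engine.unload();
await pending;

// Pattern 2: abort by requesting another model before the previous reload completes.
// "ANOTHER_MODEL_ID" is a placeholder for any model id known to the engine's app config.
pending = engine.reload("Llama-3.1-8B-Instruct-q4f32_1-MLC");
await engine.reload("ANOTHER_MODEL_ID"); // aborts the pending reload, then loads this model
await pending;
```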

### Note on unload() and unexpected device lost error

Previously, a device lost error was reported even when we intentionally
switched models (i.e., by calling `reload()`). This happened because
`unload()` set `deviceLostIsError` back to `true` immediately after calling
`this.pipeline.dispose()`, which destroys the WebGPU device internally.
However, WebGPU operates asynchronously, so the device may not have finished
being destroyed by the time `dispose()` returns. This PR fixes the issue by
introducing `LLMChatPipeline.sync()` and making `unload()` wait until the
device is actually destroyed.
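
To make the race concrete, here is a small self-contained mock (the names
`dispose`, `onDeviceLost`, and `destroyed` are illustrative stand-ins, not
WebLLM APIs); the real change adds `LLMChatPipeline.sync()` and awaits it
inside `unload()`, as shown in the `src/engine.ts` diff below:

```ts
// Self-contained mock (not WebLLM code) illustrating the race that sync() fixes.
let deviceLostIsError = true;
let resolveDestroyed!: () => void;
const destroyed = new Promise<void>((resolve) => (resolveDestroyed = resolve));

function onDeviceLost() {
  if (deviceLostIsError) console.error("Unexpected device lost (the old bug)");
  else console.log("Device lost during intentional unload — ignored");
}

function dispose() {
  // Like WebGPU device destruction, completion happens asynchronously.
  setTimeout(() => {
    resolveDestroyed();
    onDeviceLost();
  }, 10);
}

// Old ordering (spurious error): restore the flag before destruction finishes.
//   deviceLostIsError = false; dispose(); deviceLostIsError = true;

// New ordering: wait — as LLMChatPipeline.sync() does — before restoring the flag.
deviceLostIsError = false;
dispose();
await destroyed; // stands in for `await this.pipeline.sync()`
deviceLostIsError = true;
```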

---------

Co-authored-by: Charlie Ruan <[email protected]>
Neet-Nestor and CharlieFRuan authored Aug 8, 2024
1 parent 7690707 commit ddac6d1
Showing 6 changed files with 139 additions and 4 deletions.
13 changes: 13 additions & 0 deletions examples/abort-reload/README.md
@@ -0,0 +1,13 @@
# WebLLM Get Started App

This folder provides a demo for cancelling model fetching after calling `engine.reload()`.

```bash
npm install
npm start
```

Note: if you would like to hack on the WebLLM core package, you can change the
web-llm dependency to `"file:../.."` and follow the build-from-source
instructions in the project to build WebLLM locally.
20 changes: 20 additions & 0 deletions examples/abort-reload/package.json
@@ -0,0 +1,20 @@
{
"name": "get-started",
"version": "0.1.0",
"private": true,
"scripts": {
"start": "parcel src/get_started.html --port 8887",
"build": "parcel build src/get_started.html --dist-dir lib"
},
"devDependencies": {
"buffer": "^5.7.1",
"parcel": "^2.8.3",
"process": "^0.11.10",
"tslib": "^2.3.1",
"typescript": "^4.9.5",
"url": "^0.11.3"
},
"dependencies": {
"@mlc-ai/web-llm": "file:../../lib"
}
}
23 changes: 23 additions & 0 deletions examples/abort-reload/src/get_started.html
@@ -0,0 +1,23 @@
<!doctype html>
<html>
<script>
webLLMGlobal = {};
</script>
<body>
<h2>WebLLM Test Page</h2>
Open console to see output
<br />
<br />
<label id="init-label"> </label>

<h3>Prompt</h3>
<label id="prompt-label"> </label>

<h3>Response</h3>
<label id="generate-label"> </label>
<br />
<label id="stats-label"> </label>

<script type="module" src="./get_started.js"></script>
</body>
</html>
32 changes: 32 additions & 0 deletions examples/abort-reload/src/get_started.js
@@ -0,0 +1,32 @@
import * as webllm from "@mlc-ai/web-llm";
import { error } from "loglevel";

let engine;

function setLabel(id, text) {
const label = document.getElementById(id);
if (label == null) {
throw Error("Cannot find label " + id);
}
label.innerText = text;
}

async function main() {
const initProgressCallback = (report) => {
console.log(report.text);
setLabel("init-label", report.text);
};
// Option 1: If we do not specify appConfig, we use `prebuiltAppConfig` defined in `config.ts`
const selectedModel = "Llama-3.1-8B-Instruct-q4f32_1-MLC";
engine = new webllm.MLCEngine({
initProgressCallback,
});
engine.reload(selectedModel);
}
main();
setTimeout(() => {
console.log("calling unload");
engine.unload().catch((err) => {
console.log(err);
});
}, 5000);
47 changes: 43 additions & 4 deletions src/engine.ts
@@ -103,6 +103,7 @@ export class MLCEngine implements MLCEngineInterface {
private initProgressCallback?: InitProgressCallback;
private interruptSignal = false;
private deviceLostIsError = true; // whether device.lost is due to actual error or model reload
private reloadController: AbortController | undefined;
private config?: ChatConfig;
private appConfig: AppConfig;

@@ -143,7 +144,25 @@
*/
async reload(modelId: string, chatOpts?: ChatOptions): Promise<void> {
await this.unload();
this.reloadController = new AbortController();

try {
await this.reloadInternal(modelId, chatOpts);
} catch (error) {
if (error instanceof DOMException && error.name === "AbortError") {
log.warn("Reload() is aborted.", error.message);
return;
}
throw error;
} finally {
this.reloadController = undefined;
}
}

private async reloadInternal(
modelId: string,
chatOpts?: ChatOptions,
): Promise<void> {
this.logitProcessor = this.logitProcessorRegistry?.get(modelId);
const tstart = performance.now();

@@ -175,7 +194,11 @@
// load config
const configUrl = new URL("mlc-chat-config.json", modelUrl).href;
this.config = {
...(await configCache.fetchWithCache(configUrl, "json")),
...(await configCache.fetchWithCache(
configUrl,
"json",
this.reloadController?.signal,
)),
...modelRecord.overrides,
...chatOpts,
} as ChatConfig;
@@ -201,8 +224,11 @@
// rely on the normal caching strategy
return (await fetch(new URL(wasmUrl, baseUrl).href)).arrayBuffer();
} else {
// use cache
return await wasmCache.fetchWithCache(wasmUrl, "arraybuffer");
return await wasmCache.fetchWithCache(
wasmUrl,
"arraybuffer",
this.reloadController?.signal,
);
}
};
const wasmSource = await fetchWasmSource();
@@ -248,7 +274,8 @@
gpuDetectOutput.device.lost.then((info: any) => {
if (this.deviceLostIsError) {
log.error(
`Device was lost during reload. This can happen due to insufficient memory or other GPU constraints. Detailed error: ${info}. Please try to reload WebLLM with a less resource-intensive model.`,
`Device was lost. This can happen due to insufficient memory or other GPU constraints. ` +
`Detailed error: ${info}. Please try to reload WebLLM with a less resource-intensive model.`,
);
this.unload();
deviceLostInReload = true;
@@ -267,6 +294,7 @@
tvm.webgpu(),
"webllm/model",
cacheType,
this.reloadController?.signal,
);
this.pipeline = new LLMChatPipeline(
tvm,
@@ -646,12 +674,23 @@
this.pipeline?.resetChat(keepStats);
}

/**
* Unloads the currently loaded model and destroy the webgpu device. Waits
* until the webgpu device finishes all submitted work and destroys itself.
* @note This is an asynchronous function.
*/
async unload() {
this.deviceLostIsError = false; // so that unload() does not trigger device.lost error
this.pipeline?.dispose();
// Wait until device is actually destroyed so we can safely set deviceLostIsError back to true
await this.pipeline?.sync();
this.pipeline = undefined;
this.currentModelId = undefined;
this.deviceLostIsError = true;
if (this.reloadController) {
this.reloadController.abort("Engine.unload() is called.");
this.reloadController = undefined;
}
}

async getMaxStorageBufferBindingSize(): Promise<number> {
8 changes: 8 additions & 0 deletions src/llm_chat.ts
@@ -1040,6 +1040,14 @@ export class LLMChatPipeline {
} as ChatCompletionTokenLogprob;
}

/**
* Synchronize the device.
*/
async sync(): Promise<void> {
// Is it equivalent to this.tvm.sync()?
await this.device.sync();
}

async evaluate() {
// run a canonical evaluation of the flow
this.resetKVCache();
Expand Down
