From b4bc1121a61b46a83b47a8a88a0de957114160ca Mon Sep 17 00:00:00 2001 From: Nestor Qin Date: Sun, 26 May 2024 18:38:49 -0400 Subject: [PATCH 1/2] [Log] Set log level using 'loglevel' package --- package-lock.json | 15 +++++++ package.json | 3 ++ src/config.ts | 13 +++--- src/engine.ts | 14 ++++--- src/extension_service_worker.ts | 14 ++++--- src/llm_chat.ts | 20 ++++----- src/service_worker.ts | 27 ++++++++---- src/types.ts | 2 + src/web_worker.ts | 5 ++- .../src/vram_requirements.ts | 41 +++++++++++-------- 10 files changed, 101 insertions(+), 53 deletions(-) diff --git a/package-lock.json b/package-lock.json index e33d27da..70a66922 100644 --- a/package-lock.json +++ b/package-lock.json @@ -8,6 +8,9 @@ "name": "@mlc-ai/web-llm", "version": "0.2.38", "license": "Apache-2.0", + "dependencies": { + "loglevel": "^1.9.1" + }, "devDependencies": { "@mlc-ai/web-tokenizers": "^0.1.3", "@rollup/plugin-commonjs": "^20.0.0", @@ -6116,6 +6119,18 @@ "integrity": "sha512-0KpjqXRVvrYyCsX1swR/XTK0va6VQkQM6MNo7PqW77ByjAhoARA8EfrP1N4+KlKj8YS0ZUCtRT/YUuhyYDujIQ==", "dev": true }, + "node_modules/loglevel": { + "version": "1.9.1", + "resolved": "https://registry.npmjs.org/loglevel/-/loglevel-1.9.1.tgz", + "integrity": "sha512-hP3I3kCrDIMuRwAwHltphhDM1r8i55H33GgqjXbrisuJhF4kRhW1dNuxsRklp4bXl8DSdLaNLuiL4A/LWRfxvg==", + "engines": { + "node": ">= 0.6.0" + }, + "funding": { + "type": "tidelift", + "url": "https://tidelift.com/funding/github/npm/loglevel" + } + }, "node_modules/lru-cache": { "version": "6.0.0", "resolved": "https://registry.npmjs.org/lru-cache/-/lru-cache-6.0.0.tgz", diff --git a/package.json b/package.json index df4a331e..1184654c 100644 --- a/package.json +++ b/package.json @@ -51,5 +51,8 @@ "tslib": "^2.3.1", "tvmjs": "file:./tvm_home/web", "typescript": "^4.9.5" + }, + "dependencies": { + "loglevel": "^1.9.1" } } diff --git a/src/config.ts b/src/config.ts index 0a5b3fd1..b758baf7 100644 --- a/src/config.ts +++ b/src/config.ts @@ -1,7 +1,7 @@ /* eslint-disable @typescript-eslint/no-non-null-assertion */ - +import log from "loglevel"; import { ResponseFormat } from "./openai_api_protocols"; -import { LogitProcessor, InitProgressCallback } from "./types"; +import { LogitProcessor, InitProgressCallback, LogLevel } from "./types"; /** * Conversation template config @@ -26,6 +26,8 @@ export enum Role { assistant = "assistant", } +export const DefaultLogLevel: LogLevel = "WARN"; + /** * Place holders that can be used in role templates. 
* For example, a role template of @@ -91,6 +93,7 @@ export interface MLCEngineConfig { appConfig?: AppConfig; initProgressCallback?: InitProgressCallback; logitProcessorRegistry?: Map; + logLevel: LogLevel; } /** @@ -167,16 +170,14 @@ export function postInitAndCheckGenerationConfigValues( !_hasValue(config.presence_penalty) ) { config.presence_penalty = 0.0; - console.log( - "Only frequency_penalty is set; we default presence_penaty to 0.", - ); + log.warn("Only frequency_penalty is set; we default presence_penaty to 0."); } if ( _hasValue(config.presence_penalty) && !_hasValue(config.frequency_penalty) ) { config.frequency_penalty = 0.0; - console.log( + log.warn( "Only presence_penalty is set; we default frequency_penalty to 0.", ); } diff --git a/src/engine.ts b/src/engine.ts index 3c76a721..2b3b3af5 100644 --- a/src/engine.ts +++ b/src/engine.ts @@ -1,4 +1,5 @@ import * as tvmjs from "tvmjs"; +import log from "loglevel"; import { Tokenizer } from "@mlc-ai/web-tokenizers"; import * as API from "./openai_api_protocols/apis"; import { @@ -10,6 +11,7 @@ import { postInitAndCheckGenerationConfigValues, Role, MLCEngineConfig, + DefaultLogLevel, } from "./config"; import { LLMChatPipeline } from "./llm_chat"; import { @@ -30,6 +32,7 @@ import { MLCEngineInterface, GenerateProgressCallback, LogitProcessor, + LogLevel, } from "./types"; import { Conversation, @@ -60,6 +63,7 @@ export async function CreateMLCEngine( modelId: string, engineConfig?: MLCEngineConfig, ): Promise { + log.setLevel(engineConfig?.logLevel || DefaultLogLevel); const engine = new MLCEngine(); engine.setInitProgressCallback(engineConfig?.initProgressCallback); engine.setLogitProcessorRegistry(engineConfig?.logitProcessorRegistry); @@ -76,7 +80,7 @@ export class MLCEngine implements MLCEngineInterface { public chat: API.Chat; private currentModelId?: string = undefined; // Model current loaded, undefined if nothing is loaded - private logger: (msg: string) => void = console.log; + private logger: (msg: string) => void = log.info; private logitProcessorRegistry?: Map; private logitProcessor?: LogitProcessor; private pipeline?: LLMChatPipeline; @@ -238,7 +242,7 @@ export class MLCEngine implements MLCEngineInterface { let deviceLostInReload = false; gpuDetectOutput.device.lost.then((info: any) => { if (this.deviceLostIsError) { - console.error( + log.error( `Device was lost during reload. This can happen due to insufficient memory or other GPU constraints. Detailed error: ${info}. Please try to reload WebLLM with a less resource-intensive model.`, ); this.unload(); @@ -291,7 +295,7 @@ export class MLCEngine implements MLCEngineInterface { streamInterval = 1, genConfig?: GenerationConfig, ): Promise { - console.log( + log.warn( "WARNING: `generate()` will soon be deprecated. " + "Please use `engine.chat.completions.create()` instead. 
" + "For multi-round chatting, see `examples/multi-round-chat` on how to use " + @@ -579,7 +583,7 @@ export class MLCEngine implements MLCEngineInterface { gpuDetectOutput.device.limits.maxStorageBufferBindingSize; const defaultMaxStorageBufferBindingSize = 1 << 30; // 1GB if (maxStorageBufferBindingSize < defaultMaxStorageBufferBindingSize) { - console.log( + log.warn( `WARNING: the current maxStorageBufferBindingSize ` + `(${computeMB(maxStorageBufferBindingSize)}) ` + `may only work for a limited number of models, e.g.: \n` + @@ -792,7 +796,7 @@ export class MLCEngine implements MLCEngineInterface { this.resetChat(); this.getPipeline().setConversation(newConv); } else { - console.log("Multiround chatting, reuse KVCache."); + log.info("Multiround chatting, reuse KVCache."); } // 2. Treat the last message as the usual input diff --git a/src/extension_service_worker.ts b/src/extension_service_worker.ts index 3e6aae71..451677e2 100644 --- a/src/extension_service_worker.ts +++ b/src/extension_service_worker.ts @@ -1,7 +1,8 @@ import * as tvmjs from "tvmjs"; +import log from "loglevel"; import { AppConfig, ChatOptions, MLCEngineConfig } from "./config"; import { ReloadParams, WorkerRequest } from "./message"; -import { MLCEngineInterface } from "./types"; +import { LogLevel, MLCEngineInterface } from "./types"; import { ChatWorker, MLCEngineWorkerHandler, @@ -88,7 +89,7 @@ export class ServiceWorkerMLCEngineHandler extends MLCEngineWorkerHandler { areChatOptionsEqual(this.chatOpts, params.chatOpts) && areAppConfigsEqual(this.appConfig, params.appConfig) ) { - console.log("Already loaded the model. Skip loading"); + log.info("Already loaded the model. Skip loading"); const gpuDetectOutput = await tvmjs.detectGPUDevice(); if (gpuDetectOutput == undefined) { throw Error("Cannot find WebGPU in the environment"); @@ -139,7 +140,10 @@ export async function CreateServiceWorkerMLCEngine( engineConfig?: MLCEngineConfig, keepAliveMs = 10000, ): Promise { - const serviceWorkerMLCEngine = new ServiceWorkerMLCEngine(keepAliveMs); + const serviceWorkerMLCEngine = new ServiceWorkerMLCEngine( + keepAliveMs, + engineConfig?.logLevel, + ); serviceWorkerMLCEngine.setInitProgressCallback( engineConfig?.initProgressCallback, ); @@ -188,10 +192,10 @@ class PortAdapter implements ChatWorker { export class ServiceWorkerMLCEngine extends WebWorkerMLCEngine { port: chrome.runtime.Port; - constructor(keepAliveMs = 10000) { + constructor(keepAliveMs = 10000, logLevel: LogLevel = "WARN") { const port = chrome.runtime.connect({ name: "web_llm_service_worker" }); const chatWorker = new PortAdapter(port); - super(chatWorker); + super(chatWorker, logLevel); this.port = port; setInterval(() => { this.keepAlive(); diff --git a/src/llm_chat.ts b/src/llm_chat.ts index 945e1b5b..129d7e92 100644 --- a/src/llm_chat.ts +++ b/src/llm_chat.ts @@ -1,6 +1,7 @@ /* eslint-disable @typescript-eslint/no-non-null-assertion */ /* eslint-disable no-prototype-builtins */ import * as tvmjs from "tvmjs"; +import log from "loglevel"; import { Tokenizer } from "@mlc-ai/web-tokenizers"; import { ChatConfig, GenerationConfig, Role } from "./config"; import { getConversation, Conversation } from "./conversation"; @@ -72,9 +73,6 @@ export class LLMChatPipeline { private curRoundDecodingTotalTokens = 0; private curRoundPrefillTotalTokens = 0; - // logger - private logger = console.log; - // LogitProcessor private logitProcessor?: LogitProcessor = undefined; @@ -154,7 +152,7 @@ export class LLMChatPipeline { // 4. 
Read in compilation configurations from metadata this.prefillChunkSize = metadata.prefill_chunk_size; - this.logger("Using prefillChunkSize: ", this.prefillChunkSize); + log.info("Using prefillChunkSize: ", this.prefillChunkSize); if (this.prefillChunkSize <= 0) { throw Error("Prefill chunk size needs to be positive."); } @@ -164,14 +162,14 @@ export class LLMChatPipeline { metadata.sliding_window_size != -1 ) { this.slidingWindowSize = metadata.sliding_window_size; - this.logger("Using slidingWindowSize: ", this.slidingWindowSize); + log.info("Using slidingWindowSize: ", this.slidingWindowSize); // Parse attention sink size if ( metadata.hasOwnProperty("attention_sink_size") && metadata.attention_sink_size >= 0 ) { this.attentionSinkSize = metadata.attention_sink_size; - this.logger("Using attentionSinkSize: ", this.attentionSinkSize); + log.info("Using attentionSinkSize: ", this.attentionSinkSize); } else { throw Error( "Need to specify non-negative attention_sink_size if using sliding window. " + @@ -184,7 +182,7 @@ export class LLMChatPipeline { metadata.context_window_size != -1 ) { this.maxWindowLength = metadata.context_window_size; - this.logger("Using maxWindowLength: ", this.maxWindowLength); + log.info("Using maxWindowLength: ", this.maxWindowLength); } else { throw Error( "Need to specify either sliding window size or max window size.", @@ -905,7 +903,7 @@ export class LLMChatPipeline { } // need shift window and re-encode - this.logger("need shift window"); + log.info("need shift window"); this.filledKVCacheLength = 0; this.resetKVCache(); @@ -1056,8 +1054,8 @@ export class LLMChatPipeline { `decoding-time=${((decodingEnd - decodingStart) / 1000).toFixed(4)} sec`; // simply log tokens for eyeballing. - console.log("Logits:"); - console.log(logitsOnCPU.toArray()); - console.log(msg); + log.info("Logits:"); + log.info(logitsOnCPU.toArray()); + log.info(msg); } } diff --git a/src/service_worker.ts b/src/service_worker.ts index 0cd2707d..8eea3117 100644 --- a/src/service_worker.ts +++ b/src/service_worker.ts @@ -1,7 +1,8 @@ import * as tvmjs from "tvmjs"; +import log from "loglevel"; import { AppConfig, ChatOptions, MLCEngineConfig } from "./config"; import { ReloadParams, WorkerRequest, WorkerResponse } from "./message"; -import { MLCEngineInterface, InitProgressReport } from "./types"; +import { MLCEngineInterface, InitProgressReport, LogLevel } from "./types"; import { MLCEngineWorkerHandler, WebWorkerMLCEngine, @@ -90,7 +91,7 @@ export class ServiceWorkerMLCEngineHandler extends MLCEngineWorkerHandler { onError?: () => void, ): void { const msg = event.data as WorkerRequest; - console.debug( + log.trace( `ServiceWorker message: [${msg.kind}] ${JSON.stringify(msg.content)}`, ); @@ -114,7 +115,7 @@ export class ServiceWorkerMLCEngineHandler extends MLCEngineWorkerHandler { areChatOptionsEqual(this.chatOpts, params.chatOpts) && areAppConfigsEqual(this.appConfig, params.appConfig) ) { - console.log("Already loaded the model. Skip loading"); + log.info("Already loaded the model. 
Skip loading"); const gpuDetectOutput = await tvmjs.detectGPUDevice(); if (gpuDetectOutput == undefined) { throw Error("Cannot find WebGPU in the environment"); @@ -205,7 +206,11 @@ export async function CreateServiceWorkerMLCEngine( "Please refresh the page to retry initializing the service worker.", ); } - const serviceWorkerMLCEngine = new ServiceWorkerMLCEngine(serviceWorker); + const serviceWorkerMLCEngine = new ServiceWorkerMLCEngine( + serviceWorker, + undefined, + engineConfig?.logLevel, + ); serviceWorkerMLCEngine.setInitProgressCallback( engineConfig?.initProgressCallback, ); @@ -223,18 +228,22 @@ export async function CreateServiceWorkerMLCEngine( export class ServiceWorkerMLCEngine extends WebWorkerMLCEngine { missedHeatbeat = 0; - constructor(worker: IServiceWorker, keepAliveMs = 10000) { + constructor( + worker: IServiceWorker, + keepAliveMs = 10000, + logLevel: LogLevel = "WARN", + ) { if (!("serviceWorker" in navigator)) { throw new Error("Service worker API is not available"); } - super(new ServiceWorker(worker)); + super(new ServiceWorker(worker), logLevel); const onmessage = this.onmessage.bind(this); (navigator.serviceWorker as ServiceWorkerContainer).addEventListener( "message", (event: MessageEvent) => { const msg = event.data; - console.debug( + log.trace( `MLC client message: [${msg.kind}] ${JSON.stringify(msg.content)}`, ); try { @@ -246,7 +255,7 @@ export class ServiceWorkerMLCEngine extends WebWorkerMLCEngine { } catch (err: any) { // This is expected to throw if user has multiple windows open if (!err.message.startsWith("return from a unknown uuid")) { - console.error("CreateWebServiceWorkerMLCEngine.onmessage", err); + log.error("CreateWebServiceWorkerMLCEngine.onmessage", err); } } }, @@ -255,7 +264,7 @@ export class ServiceWorkerMLCEngine extends WebWorkerMLCEngine { setInterval(() => { this.worker.postMessage({ kind: "keepAlive", uuid: crypto.randomUUID() }); this.missedHeatbeat += 1; - console.debug("missedHeatbeat", this.missedHeatbeat); + log.trace("missedHeatbeat", this.missedHeatbeat); }, keepAliveMs); } diff --git a/src/types.ts b/src/types.ts index e2dfcec1..d9cca24c 100644 --- a/src/types.ts +++ b/src/types.ts @@ -195,3 +195,5 @@ export interface MLCEngineInterface { isPrefill: boolean, ): Promise; } + +export type LogLevel = "TRACE" | "DEBUG" | "INFO" | "WARN" | "ERROR" | "SILENT"; diff --git a/src/web_worker.ts b/src/web_worker.ts index cc1c95b7..67fecf24 100644 --- a/src/web_worker.ts +++ b/src/web_worker.ts @@ -9,6 +9,7 @@ import { GenerateProgressCallback, InitProgressCallback, InitProgressReport, + LogLevel, } from "./types"; import { ChatCompletionRequest, @@ -31,6 +32,7 @@ import { WorkerResponse, WorkerRequest, } from "./message"; +import log from "loglevel"; export interface PostMessageHandler { postMessage: (message: any) => void; @@ -348,7 +350,8 @@ export class WebWorkerMLCEngine implements MLCEngineInterface { >(); private pendingPromise = new Map void>(); - constructor(worker: ChatWorker) { + constructor(worker: ChatWorker, logLevel: LogLevel = "WARN") { + log.setLevel(logLevel); this.worker = worker; worker.onmessage = (event: any) => { this.onmessage.bind(this)(event); diff --git a/utils/vram_requirements/src/vram_requirements.ts b/utils/vram_requirements/src/vram_requirements.ts index 2e5c2a8c..14988fcc 100644 --- a/utils/vram_requirements/src/vram_requirements.ts +++ b/utils/vram_requirements/src/vram_requirements.ts @@ -1,6 +1,7 @@ import ModelRecord from "@mlc-ai/web-llm"; -import appConfig from "./app-config"; // Modify this to 
inspect vram requirement for models of choice +import appConfig from "./app-config"; // Modify this to inspect vram requirement for models of choice import * as tvmjs from "tvmjs"; +import log from "loglevel"; function setLabel(id: string, text: string) { const label = document.getElementById(id); @@ -14,16 +15,16 @@ interface AppConfig { model_list: Array; } -let dtypeBytesMap = new Map([ +const dtypeBytesMap = new Map([ ["uint32", 4], ["uint16", 2], ["float32", 4], - ["float16", 4] + ["float16", 4], ]); async function main() { - let config: AppConfig = appConfig; - let report: string = ""; + const config: AppConfig = appConfig; + let report = ""; for (let i = 0; i < config.model_list.length; ++i) { // 1. Read each model record const modelRecord: ModelRecord = config.model_list[i]; @@ -36,7 +37,7 @@ async function main() { const tvm = await tvmjs.instantiate( new Uint8Array(wasmSource), tvmjs.createPolyfillWASI(), - console.log + log.info, ); const gpuDetectOutput = await tvmjs.detectGPUDevice(); if (gpuDetectOutput == undefined) { @@ -45,14 +46,17 @@ async function main() { tvm.initWebGPU(gpuDetectOutput.device); tvm.beginScope(); const vm = tvm.detachFromCurrentScope( - tvm.createVirtualMachine(tvm.webgpu()) + tvm.createVirtualMachine(tvm.webgpu()), ); // 4. Get metadata from the vm let fgetMetadata: any; try { fgetMetadata = vm.getFunction("_metadata"); } catch (err) { - console.error("The wasm needs to have function `_metadata` to inspect vram requirement.", err); + log.error( + "The wasm needs to have function `_metadata` to inspect vram requirement.", + err, + ); } const ret_value = fgetMetadata(); const metadataStr = tvm.detachFromCurrentScope(ret_value).toString(); @@ -65,33 +69,38 @@ async function main() { // Possible to have shape -1 signifying a dynamic shape -- we disregard them const dtypeBytes = dtypeBytesMap.get(param.dtype); if (dtypeBytes === undefined) { - throw Error("Cannot find size of " + param.dtype + ", add it to `dtypeBytesMap`.") + throw Error( + "Cannot find size of " + + param.dtype + + ", add it to `dtypeBytesMap`.", + ); } const numParams = param.shape.reduce((a: number, b: number) => a * b); paramBytes += numParams * dtypeBytes; } else { - console.log(`${model_id}'s ${param.name} has dynamic shape; excluded from vRAM calculation.`) + log.info( + `${model_id}'s ${param.name} has dynamic shape; excluded from vRAM calculation.`, + ); } }); // 5.2. Get maximum bytes needed for temporary buffer across all functions - let maxTempFuncBytes: number = 0; + let maxTempFuncBytes = 0; Object.entries(metadata.memory_usage).forEach(([funcName, funcBytes]) => { if (typeof funcBytes !== "number") { - throw Error("`memory_usage` expects entry `funcName: funcBytes`.") + throw Error("`memory_usage` expects entry `funcName: funcBytes`."); } maxTempFuncBytes = Math.max(maxTempFuncBytes, funcBytes); - }) + }); // 5.3. Get kv cache bytes const kv_cache_bytes: number = metadata.kv_cache_bytes; // 5.4. Get total vRAM needed const totalBytes = paramBytes + maxTempFuncBytes + kv_cache_bytes; // 6. Report vRAM Requirement - report += ( + report += `totalBytes: ${(totalBytes / 1024 / 1024).toFixed(2)} MB\n` + `paramBytes: ${(paramBytes / 1024 / 1024).toFixed(2)} MB\n` + `maxTempFuncBytes: ${(maxTempFuncBytes / 1024 / 1024).toFixed(2)} MB\n` + - `kv_cache_bytes: ${(kv_cache_bytes / 1024 / 1024).toFixed(2)} MB\n\n` - ); + `kv_cache_bytes: ${(kv_cache_bytes / 1024 / 1024).toFixed(2)} MB\n\n`; // 7. 
Dispose everything tvm.endScope(); vm.dispose(); From 016563e5a82de6fd89240ab022eca3aaebfec551 Mon Sep 17 00:00:00 2001 From: Nestor Qin Date: Mon, 27 May 2024 17:16:42 -0400 Subject: [PATCH 2/2] [Log] Add setLogLevel to Engine Interface --- src/engine.ts | 11 ++++++++++- src/extension_service_worker.ts | 12 ++++++------ src/index.ts | 1 + src/service_worker.ts | 17 ++++++----------- src/types.ts | 17 ++++++++++++++++- src/web_worker.ts | 7 +++++-- 6 files changed, 44 insertions(+), 21 deletions(-) diff --git a/src/engine.ts b/src/engine.ts index 2b3b3af5..65823ea7 100644 --- a/src/engine.ts +++ b/src/engine.ts @@ -63,8 +63,8 @@ export async function CreateMLCEngine( modelId: string, engineConfig?: MLCEngineConfig, ): Promise { - log.setLevel(engineConfig?.logLevel || DefaultLogLevel); const engine = new MLCEngine(); + engine.setLogLevel(engineConfig?.logLevel || DefaultLogLevel); engine.setInitProgressCallback(engineConfig?.initProgressCallback); engine.setLogitProcessorRegistry(engineConfig?.logitProcessorRegistry); await engine.reload(modelId, engineConfig?.chatOpts, engineConfig?.appConfig); @@ -640,6 +640,15 @@ export class MLCEngine implements MLCEngineInterface { return this.getPipeline().getMessage(); } + /** + * Set MLCEngine logging output level + * + * @param logLevel The new log level + */ + setLogLevel(logLevel: LogLevel) { + log.setLevel(logLevel); + } + /** * Get a new Conversation object based on the chat completion request. * diff --git a/src/extension_service_worker.ts b/src/extension_service_worker.ts index 451677e2..1c575253 100644 --- a/src/extension_service_worker.ts +++ b/src/extension_service_worker.ts @@ -140,10 +140,10 @@ export async function CreateServiceWorkerMLCEngine( engineConfig?: MLCEngineConfig, keepAliveMs = 10000, ): Promise { - const serviceWorkerMLCEngine = new ServiceWorkerMLCEngine( - keepAliveMs, - engineConfig?.logLevel, - ); + const serviceWorkerMLCEngine = new ServiceWorkerMLCEngine(keepAliveMs); + if (engineConfig?.logLevel) { + serviceWorkerMLCEngine.setLogLevel(engineConfig.logLevel); + } serviceWorkerMLCEngine.setInitProgressCallback( engineConfig?.initProgressCallback, ); @@ -192,10 +192,10 @@ class PortAdapter implements ChatWorker { export class ServiceWorkerMLCEngine extends WebWorkerMLCEngine { port: chrome.runtime.Port; - constructor(keepAliveMs = 10000, logLevel: LogLevel = "WARN") { + constructor(keepAliveMs = 10000) { const port = chrome.runtime.connect({ name: "web_llm_service_worker" }); const chatWorker = new PortAdapter(port); - super(chatWorker, logLevel); + super(chatWorker); this.port = port; setInterval(() => { this.keepAlive(); diff --git a/src/index.ts b/src/index.ts index 9f0387bd..0b6672f1 100644 --- a/src/index.ts +++ b/src/index.ts @@ -14,6 +14,7 @@ export { InitProgressReport, MLCEngineInterface, LogitProcessor, + LogLevel, } from "./types"; export { MLCEngine, CreateMLCEngine } from "./engine"; diff --git a/src/service_worker.ts b/src/service_worker.ts index 8eea3117..eba3e564 100644 --- a/src/service_worker.ts +++ b/src/service_worker.ts @@ -206,11 +206,10 @@ export async function CreateServiceWorkerMLCEngine( "Please refresh the page to retry initializing the service worker.", ); } - const serviceWorkerMLCEngine = new ServiceWorkerMLCEngine( - serviceWorker, - undefined, - engineConfig?.logLevel, - ); + const serviceWorkerMLCEngine = new ServiceWorkerMLCEngine(serviceWorker); + if (engineConfig?.logLevel) { + serviceWorkerMLCEngine.setLogLevel(engineConfig.logLevel); + } 
serviceWorkerMLCEngine.setInitProgressCallback( engineConfig?.initProgressCallback, ); @@ -228,15 +227,11 @@ export async function CreateServiceWorkerMLCEngine( export class ServiceWorkerMLCEngine extends WebWorkerMLCEngine { missedHeatbeat = 0; - constructor( - worker: IServiceWorker, - keepAliveMs = 10000, - logLevel: LogLevel = "WARN", - ) { + constructor(worker: IServiceWorker, keepAliveMs = 10000) { if (!("serviceWorker" in navigator)) { throw new Error("Service worker API is not available"); } - super(new ServiceWorker(worker), logLevel); + super(new ServiceWorker(worker)); const onmessage = this.onmessage.bind(this); (navigator.serviceWorker as ServiceWorkerContainer).addEventListener( diff --git a/src/types.ts b/src/types.ts index d9cca24c..910c08e7 100644 --- a/src/types.ts +++ b/src/types.ts @@ -194,6 +194,21 @@ export interface MLCEngineInterface { inputIds: Array, isPrefill: boolean, ): Promise; + + /** + * Set MLCEngine logging output level + * + * @param logLevel The new log level + */ + setLogLevel(logLevel: LogLevel): void; } -export type LogLevel = "TRACE" | "DEBUG" | "INFO" | "WARN" | "ERROR" | "SILENT"; +export const LOG_LEVELS = { + TRACE: 0, + DEBUG: 1, + INFO: 2, + WARN: 3, + ERROR: 4, + SILENT: 5, +}; +export type LogLevel = keyof typeof LOG_LEVELS; diff --git a/src/web_worker.ts b/src/web_worker.ts index 67fecf24..2ff79f95 100644 --- a/src/web_worker.ts +++ b/src/web_worker.ts @@ -350,8 +350,7 @@ export class WebWorkerMLCEngine implements MLCEngineInterface { >(); private pendingPromise = new Map void>(); - constructor(worker: ChatWorker, logLevel: LogLevel = "WARN") { - log.setLevel(logLevel); + constructor(worker: ChatWorker) { this.worker = worker; worker.onmessage = (event: any) => { this.onmessage.bind(this)(event); @@ -627,4 +626,8 @@ export class WebWorkerMLCEngine implements MLCEngineInterface { } } } + + setLogLevel(logLevel: LogLevel) { + log.setLevel(logLevel); + } }
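
Below is a minimal usage sketch of the logging API these two patches introduce, assuming the exports shown in `src/index.ts` (`CreateMLCEngine`, `LogLevel`, `MLCEngineInterface`) and the `logLevel` field added to `MLCEngineConfig`; the model ID and the progress-callback body are placeholders for illustration and are not part of the patch.

```typescript
import { CreateMLCEngine, LogLevel, MLCEngineInterface } from "@mlc-ai/web-llm";

// Hypothetical model ID, used only for illustration; substitute any model
// listed in your AppConfig's model_list.
const modelId = "Llama-3-8B-Instruct-q4f32_1-MLC";

async function initEngine(): Promise<MLCEngineInterface> {
  // logLevel is part of MLCEngineConfig after patch 1/2; CreateMLCEngine
  // forwards it to loglevel via engine.setLogLevel() (patch 2/2), falling
  // back to DefaultLogLevel ("WARN") when it is not provided.
  const engine = await CreateMLCEngine(modelId, {
    logLevel: "INFO",
    initProgressCallback: (report) => console.log(report.text),
  });

  // The level can also be changed later through the method added to
  // MLCEngineInterface in patch 2/2.
  engine.setLogLevel("ERROR");
  return engine;
}
```

The worker-based entry points follow the same pattern: `CreateServiceWorkerMLCEngine` (both the extension and page variants) reads `engineConfig?.logLevel` and applies it through `setLogLevel` on the created engine, so a single `logLevel` field in the config should control verbosity regardless of which engine flavor is used.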