From fde1777f8541fb352076cec1f82a3ae530dd2894 Mon Sep 17 00:00:00 2001 From: dstoc <539597+dstoc@users.noreply.github.com> Date: Thu, 26 Sep 2024 15:04:51 +1000 Subject: [PATCH] [Fix] Support fetching images when using worker engine (#574) The previous implementation relied on the HTMLImageElement constructor which is not available in worker contexts. --- examples/vision-model/src/vision_model.ts | 31 +++++++++++++++------- src/support.ts | 32 ++++++++--------------- 2 files changed, 32 insertions(+), 31 deletions(-) diff --git a/examples/vision-model/src/vision_model.ts b/examples/vision-model/src/vision_model.ts index 31629c40..27228c0a 100644 --- a/examples/vision-model/src/vision_model.ts +++ b/examples/vision-model/src/vision_model.ts @@ -9,6 +9,8 @@ function setLabel(id: string, text: string) { label.innerText = text; } +const USE_WEB_WORKER = false; + const proxyUrl = "https://cors-anywhere.herokuapp.com/"; const url_https_street = "https://www.ilankelman.org/stopsigns/australia.jpg"; const url_https_tree = "https://www.ilankelman.org/sunset.jpg"; @@ -23,16 +25,25 @@ async function main() { setLabel("init-label", report.text); }; const selectedModel = "Phi-3.5-vision-instruct-q4f16_1-MLC"; - const engine: webllm.MLCEngineInterface = await webllm.CreateMLCEngine( - selectedModel, - { - initProgressCallback: initProgressCallback, - logLevel: "INFO", // specify the log level - }, - { - context_window_size: 6144, - }, - ); + + const engineConfig: webllm.MLCEngineConfig = { + initProgressCallback: initProgressCallback, + logLevel: "INFO", // specify the log level + }; + const chatOpts = { + context_window_size: 6144, + }; + + const engine: webllm.MLCEngineInterface = USE_WEB_WORKER + ? await webllm.CreateWebWorkerMLCEngine( + new Worker(new URL("./worker.ts", import.meta.url), { + type: "module", + }), + selectedModel, + engineConfig, + chatOpts, + ) + : await webllm.CreateMLCEngine(selectedModel, engineConfig, chatOpts); // 1. 
Single image input (with choices) const messages: webllm.ChatCompletionMessageParam[] = [ diff --git a/src/support.ts b/src/support.ts index 3b342b8f..a2ae0d04 100644 --- a/src/support.ts +++ b/src/support.ts @@ -411,28 +411,18 @@ export const IMAGE_EMBED_SIZE = 1921; /** * Given a url, get the image data. The url can either start with `http` or `data:image`. */ -export function getImageDataFromURL(url: string): Promise { - return new Promise((resolve, reject) => { - // Converts img to any, and later `as CanvasImageSource`, otherwise build complains - const img: any = new Image(); - img.crossOrigin = "anonymous"; // Important for CORS - img.onload = () => { - const canvas = document.createElement("canvas"); - const ctx = canvas.getContext("2d"); - if (!ctx) { - reject(new Error("Could not get 2d context")); - return; - } - canvas.width = img.width; - canvas.height = img.height; - ctx.drawImage(img as CanvasImageSource, 0, 0); +export async function getImageDataFromURL(url: string): Promise { + const response = await fetch(url, { mode: "cors" }); + const img = await createImageBitmap(await response.blob()); + const canvas = new OffscreenCanvas(img.width, img.height); + const ctx = canvas.getContext("2d"); + if (!ctx) { + throw new Error("Could not get 2d context"); + } + ctx.drawImage(img, 0, 0); - const imageData = ctx.getImageData(0, 0, img.width, img.height); - resolve(imageData); - }; - img.onerror = () => reject(new Error("Failed to load image")); - img.src = url; - }); + const imageData = ctx.getImageData(0, 0, img.width, img.height); + return imageData; } /**