From 2c9af0f500c04383aa7220ab2c9220a608f75cbf Mon Sep 17 00:00:00 2001 From: Nestor Qin Date: Mon, 29 Jul 2024 08:17:55 -0400 Subject: [PATCH] [Runtime] Allow aborting fetchNDArray through AbortSignal (#17208) [Runtime] Allow aborting fetchNDArray --- web/src/artifact_cache.ts | 11 ++++++----- web/src/runtime.ts | 13 +++++++++---- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/web/src/artifact_cache.ts b/web/src/artifact_cache.ts index f833df1be523..9690ed3320b9 100644 --- a/web/src/artifact_cache.ts +++ b/web/src/artifact_cache.ts @@ -58,13 +58,14 @@ export interface ArtifactCacheTemplate { * * @param url: The url to the data to be cached. * @param storetype: Only applies to `ArtifactIndexedDBCache`. Since `indexedDB` stores the actual + * @param signal: An optional AbortSignal to abort data retrival * data rather than a request, we specify `storagetype`. There are two options: * 1. "json": IndexedDB stores `fetch(url).json()` * 2. "arraybuffer": IndexedDB stores `fetch(url).arrayBuffer()` * * @note This is an async function. */ - addToCache(url: string, storetype?: string): Promise; + addToCache(url: string, storetype?: string, signal?: AbortSignal): Promise; /** * check if cache has all keys in Cache @@ -126,8 +127,8 @@ export class ArtifactCache implements ArtifactCacheTemplate { } // eslint-disable-next-line @typescript-eslint/no-unused-vars - async addToCache(url: string, storetype?: string) { - const request = new Request(url); + async addToCache(url: string, storetype?: string, signal?: AbortSignal) { + const request = new Request(url, signal ? { signal } : undefined); if (this.cache === undefined) { this.cache = await caches.open(this.scope); } @@ -282,7 +283,7 @@ export class ArtifactIndexedDBCache implements ArtifactCacheTemplate { }); } - async addToCache(url: string, storetype?: string): Promise { + async addToCache(url: string, storetype?: string, signal?: AbortSignal): Promise { await this.initDB(); // await the initDB process // If already cached, nothing to do const isInDB = await this.isUrlInDB(url); @@ -290,7 +291,7 @@ export class ArtifactIndexedDBCache implements ArtifactCacheTemplate { return; } try { - const response = await fetch(url); + const response = await fetch(url, signal ? { signal } : undefined); if (!response.ok) { throw new Error('Network response was not ok'); } diff --git a/web/src/runtime.ts b/web/src/runtime.ts index fd7bcc6ab23b..d71c98e7d1bc 100644 --- a/web/src/runtime.ts +++ b/web/src/runtime.ts @@ -1444,13 +1444,15 @@ export class Instance implements Disposable { * @param device The device to be fetched to. * @param cacheScope The scope identifier of the cache * @param cacheType The type of the cache: "cache" or "indexedDB" + * @param signal An optional AbortSignal to abort the fetch * @returns The meta data */ async fetchNDArrayCache( ndarrayCacheUrl: string, device: DLDevice, cacheScope = "tvmjs", - cacheType = "cache" + cacheType = "cache", + signal?: AbortSignal, ): Promise { let artifactCache: ArtifactCacheTemplate; if (cacheType === undefined || cacheType.toLowerCase() === "cache") { @@ -1465,7 +1467,8 @@ export class Instance implements Disposable { const list = await artifactCache.fetchWithCache(jsonUrl, "json"); await this.fetchNDArrayCacheInternal( ndarrayCacheUrl, - list["records"] as Array, device, artifactCache); + list["records"] as Array, device, artifactCache, + signal); this.cacheMetadata = { ...this.cacheMetadata, ...(list["metadata"] as Record) }; } @@ -1477,12 +1480,14 @@ export class Instance implements Disposable { * @param list The list of array data. * @param device The device to store the data to. * @param artifactCache The artifact cache + * @param signal An optional AbortSignal to abort the fetch */ private async fetchNDArrayCacheInternal( ndarrayCacheUrl: string, list: Array, device: DLDevice, - artifactCache: ArtifactCacheTemplate + artifactCache: ArtifactCacheTemplate, + signal?: AbortSignal, ) { const perf = compact.getPerformance(); const tstart = perf.now(); @@ -1537,7 +1542,7 @@ export class Instance implements Disposable { const shard = list[i]; const dataUrl = new URL(shard.dataPath, ndarrayCacheUrl).href; try { - await artifactCache.addToCache(dataUrl, "arraybuffer"); + await artifactCache.addToCache(dataUrl, "arraybuffer", signal); } catch (err) { this.env.logger("Error: Cannot fetch " + dataUrl + " err= " + err); throw err;