diff --git a/src/model-manager.test.ts b/src/model-manager.test.ts index 1b009c7..351c615 100644 --- a/src/model-manager.test.ts +++ b/src/model-manager.test.ts @@ -6,12 +6,12 @@ const TINY_MODEL = const SPLIT_MODEL = 'https://huggingface.co/ngxson/tinyllama_split_test/resolve/main/stories15M-q8_0-00001-of-00003.gguf'; -test('parseModelUrl handles single model URL', () => { +test.sequential('parseModelUrl handles single model URL', () => { const urls = ModelManager.parseModelUrl(TINY_MODEL); expect(urls).toEqual([TINY_MODEL]); }); -test('parseModelUrl handles array of URLs', () => { +test.sequential('parseModelUrl handles array of URLs', () => { const urls = ModelManager.parseModelUrl(SPLIT_MODEL); expect(urls.length).toBe(3); expect(urls[0]).toMatch(/-00001-of-00003\.gguf$/); @@ -19,7 +19,7 @@ test('parseModelUrl handles array of URLs', () => { expect(urls[2]).toMatch(/-00003-of-00003\.gguf$/); }); -test('download split model', async () => { +test.sequential('download split model', async () => { const manager = new ModelManager(); const model = await manager.downloadModel(SPLIT_MODEL); expect(model.files.length).toBe(3); @@ -33,22 +33,35 @@ test('download split model', async () => { expect(model.files[2].size).toBe(5773312); }); -test('download invalid model URL', async () => { +test.sequential( + 'interrupt download split model (partial files downloaded)', + async () => { + return; // skip on CI, only run locally with a slow connection + const manager = new ModelManager(); + await manager.clear(); + const controller = new AbortController(); + const downloadPromise = manager.downloadModel(SPLIT_MODEL, { + signal: controller.signal, + progressCallback: ({ loaded, total }) => { + const progress = loaded / total; + if (progress > 0.8) { + controller.abort(); + } + }, + }); + await expect(downloadPromise).rejects.toThrow('aborted'); + expect((await manager.getModels()).length).toBe(0); + expect((await manager.getModels({ includeInvalid: true })).length).toBe(1); + } +); + +test.sequential('download invalid model URL', async () => { const manager = new ModelManager(); const invalidUrl = 'https://invalid.example.com/model.gguf'; await expect(manager.downloadModel(invalidUrl)).rejects.toThrow(); }); -test('clear model manager', async () => { - const manager = new ModelManager(); - const model = await manager.downloadModel(TINY_MODEL); - expect(model).toBeDefined(); - expect((await manager.getModels()).length).toBeGreaterThan(0); - await manager.clear(); - expect((await manager.getModels()).length).toBe(0); -}); - -test('download with abort signal', async () => { +test.sequential('download with abort signal', async () => { const manager = new ModelManager(); await manager.clear(); const controller = new AbortController(); @@ -61,7 +74,7 @@ test('download with abort signal', async () => { expect((await manager.getModels()).length).toBe(0); }); -test('download with progress callback', async () => { +test.sequential('download with progress callback', async () => { const manager = new ModelManager(); await manager.clear(); @@ -83,25 +96,25 @@ test('download with progress callback', async () => { expect(model.size).toBeGreaterThan(0); }); -test('model validation status for new model', async () => { +test.sequential('model validation status for new model', async () => { const manager = new ModelManager(); const model = new Model(manager, TINY_MODEL); const status = await model.validate(); expect(status).toBe(ModelValidationStatus.INVALID); }); -test('downloadModel throws on invalid URL', async () => { +test.sequential('downloadModel throws on invalid URL', async () => { const manager = new ModelManager(); await expect(manager.downloadModel('invalid.txt')).rejects.toThrow(); }); -test('model size calculation', async () => { +test.sequential('model size calculation', async () => { const manager = new ModelManager(); const model = await manager.downloadModel(TINY_MODEL); expect(model.size).toBe(1185376); }); -test('remove model from cache', async () => { +test.sequential('remove model from cache', async () => { const manager = new ModelManager(); await manager.clear(); @@ -125,3 +138,12 @@ test('remove model from cache', async () => { const models = await manager.getModels(); expect(models.find((m) => m.url === TINY_MODEL)).toBeUndefined(); }); + +test.sequential('clear model manager', async () => { + const manager = new ModelManager(); + const model = await manager.downloadModel(TINY_MODEL); + expect(model).toBeDefined(); + expect((await manager.getModels()).length).toBeGreaterThan(0); + await manager.clear(); + expect((await manager.getModels()).length).toBe(0); +}); diff --git a/src/model-manager.ts b/src/model-manager.ts index 75e4985..60bde8c 100644 --- a/src/model-manager.ts +++ b/src/model-manager.ts @@ -118,19 +118,16 @@ export class Model { * - The model is deleted from the cache * - The model files are missing (or the download is interrupted) */ - async validate(): Promise { - // TODO: valid number of shards + validate(): ModelValidationStatus { + const nbShards = ModelManager.parseModelUrl(this.url).length; if (this.size === -1) { return ModelValidationStatus.DELETED; } - if (this.size < 16) { + if (this.size < 16 || this.files.length !== nbShards) { return ModelValidationStatus.INVALID; } for (const file of this.files) { - const metadata = await this.modelManager.cacheManager.getMetadata( - file.name - ); - if (!metadata || metadata.originalSize !== file.size) { + if (!file.metadata || file.metadata.originalSize !== file.size) { return ModelValidationStatus.INVALID; } } @@ -259,9 +256,9 @@ export class ModelManager { /** * Get all models in the cache */ - async getModels(): Promise { + async getModels(opts: { includeInvalid?: boolean } = {}): Promise { const cachedFiles = await this.cacheManager.list(); - const models: Model[] = []; + let models: Model[] = []; for (const file of cachedFiles) { const shards = ModelManager.parseModelUrl(file.metadata.originalURL); const isFirstShard = @@ -270,6 +267,11 @@ export class ModelManager { models.push(new Model(this, file.metadata.originalURL, cachedFiles)); } } + if (!opts.includeInvalid) { + models = models.filter( + (m) => m.validate() === ModelValidationStatus.VALID + ); + } return models; } diff --git a/src/wllama.test.ts b/src/wllama.test.ts index cdd2105..4810458 100644 --- a/src/wllama.test.ts +++ b/src/wllama.test.ts @@ -14,7 +14,7 @@ const SPLIT_MODEL = const EMBD_MODEL = TINY_MODEL; // for better speed -test('loads single model file', async () => { +test.sequential('loads single model file', async () => { const wllama = new Wllama(CONFIG_PATHS); await wllama.loadModelFromUrl(TINY_MODEL, { @@ -33,7 +33,7 @@ test('loads single model file', async () => { await wllama.exit(); }); -test('loads single model file from HF', async () => { +test.sequential('loads single model file from HF', async () => { const wllama = new Wllama(CONFIG_PATHS); await wllama.loadModelFromHF( @@ -49,7 +49,7 @@ test('loads single model file from HF', async () => { await wllama.exit(); }); -test('loads single thread model', async () => { +test.sequential('loads single thread model', async () => { const wllama = new Wllama(CONFIG_PATHS); await wllama.loadModelFromUrl(TINY_MODEL, { @@ -66,7 +66,7 @@ test('loads single thread model', async () => { await wllama.exit(); }); -test('loads model with progress callback', async () => { +test.sequential('loads model with progress callback', async () => { const wllama = new Wllama(CONFIG_PATHS); let progressCalled = false; @@ -88,7 +88,7 @@ test('loads model with progress callback', async () => { await wllama.exit(); }); -test('loads split model files', async () => { +test.sequential('loads split model files', async () => { const wllama = new Wllama(CONFIG_PATHS, { parallelDownloads: 5, }); @@ -101,7 +101,7 @@ test('loads split model files', async () => { await wllama.exit(); }); -test('tokenizes and detokenizes text', async () => { +test.sequential('tokenizes and detokenizes text', async () => { const wllama = new Wllama(CONFIG_PATHS); await wllama.loadModelFromUrl(TINY_MODEL, { @@ -122,7 +122,7 @@ test('tokenizes and detokenizes text', async () => { await wllama.exit(); }); -test('generates completion', async () => { +test.sequential('generates completion', async () => { const wllama = new Wllama(CONFIG_PATHS); await wllama.loadModelFromUrl(TINY_MODEL, { @@ -151,7 +151,7 @@ test('generates completion', async () => { await wllama.exit(); }); -test('gets logits', async () => { +test.sequential('gets logits', async () => { const wllama = new Wllama(CONFIG_PATHS); await wllama.loadModelFromUrl(TINY_MODEL, { @@ -170,7 +170,7 @@ test('gets logits', async () => { await wllama.exit(); }); -test('generates embeddings', async () => { +test.sequential('generates embeddings', async () => { const wllama = new Wllama(CONFIG_PATHS); await wllama.loadModelFromUrl(EMBD_MODEL, { @@ -208,7 +208,7 @@ test('generates embeddings', async () => { await wllama.exit(); }); -test('allowOffline', async () => { +test.sequential('allowOffline', async () => { const wllama = new Wllama(CONFIG_PATHS, { allowOffline: true, }); @@ -224,10 +224,12 @@ test('allowOffline', async () => { } catch (e) { window.fetch = origFetch; throw e; + } finally { + window.fetch = origFetch; } }); -test('cleans up resources', async () => { +test.sequential('cleans up resources', async () => { const wllama = new Wllama(CONFIG_PATHS); await wllama.loadModelFromUrl(TINY_MODEL); expect(wllama.isModelLoaded()).toBe(true);