Skip to content

Commit

Permalink
add tests for ModelManager
Browse files Browse the repository at this point in the history
  • Loading branch information
ngxson committed Dec 1, 2024
1 parent 9f8cdb3 commit 658fafd
Show file tree
Hide file tree
Showing 6 changed files with 110 additions and 20 deletions.
5 changes: 4 additions & 1 deletion src/cache-manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,12 @@ class CacheManager {

async download(url: string, options: DownloadOptions = {}): Promise<void> {
const worker = createWorker(OPFS_UTILS_WORKER_CODE);
let aborted = false;
if (options.signal) {
aborted = options.signal.aborted;
const mSignal = options.signal;
mSignal.addEventListener('abort', () => {
aborted = true;
worker.postMessage({ action: 'download-abort' });
});
delete options.signal;
Expand All @@ -90,7 +93,7 @@ class CacheManager {
url,
filename,
metadataFileName,
options: { headers: options.headers },
options: { headers: options.headers, aborted },
});
worker.onmessage = (e: MessageEvent<any>) => {
if (e.data.ok) {
Expand Down
79 changes: 79 additions & 0 deletions src/model-manager.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import { test, expect } from 'vitest';
import { ModelManager, Model, ModelValidationStatus } from './model-manager';

const TINY_MODEL =
'https://huggingface.co/ggml-org/models/resolve/main/tinyllamas/stories260K.gguf';
const SPLIT_MODEL =
'https://huggingface.co/ngxson/tinyllama_split_test/resolve/main/stories15M-q8_0-00001-of-00003.gguf';

test('parseModelUrl handles single model URL', () => {
const urls = ModelManager.parseModelUrl(TINY_MODEL);
expect(urls).toEqual([TINY_MODEL]);
});

test('parseModelUrl handles array of URLs', () => {
const urls = ModelManager.parseModelUrl(SPLIT_MODEL);
expect(urls.length).toBe(3);
expect(urls[0]).toMatch(/-00001-of-00003\.gguf$/);
expect(urls[1]).toMatch(/-00002-of-00003\.gguf$/);
expect(urls[2]).toMatch(/-00003-of-00003\.gguf$/);
});

test('download split model', async () => {
const manager = new ModelManager();
const model = await manager.downloadModel(SPLIT_MODEL);
expect(model.files.length).toBe(3);
// check names
expect(model.files[0].metadata.originalURL).toMatch(/-00001-of-00003\.gguf$/);
expect(model.files[1].metadata.originalURL).toMatch(/-00002-of-00003\.gguf$/);
expect(model.files[2].metadata.originalURL).toMatch(/-00003-of-00003\.gguf$/);
// check sizes
expect(model.files[0].size).toBe(10517152);
expect(model.files[1].size).toBe(10381216);
expect(model.files[2].size).toBe(5773312);
});

test('download invalid model URL', async () => {
const manager = new ModelManager();
const invalidUrl = 'https://invalid.example.com/model.gguf';
await expect(manager.downloadModel(invalidUrl)).rejects.toThrow();
});

test('clear model manager', async () => {
const manager = new ModelManager();
const model = await manager.downloadModel(TINY_MODEL);
expect(model).toBeDefined();
expect((await manager.getModels()).length).toBeGreaterThan(0);
await manager.clear();
expect((await manager.getModels()).length).toBe(0);
});

test('download with abort signal', async () => {
const manager = new ModelManager();
await manager.clear();
const controller = new AbortController();
const downloadPromise = manager.downloadModel(TINY_MODEL, {
signal: controller.signal,
});
setTimeout(() => controller.abort(), 10);
await downloadPromise.catch(console.error);
await expect(downloadPromise).rejects.toThrow('aborted');
});

test('model validation status for new model', async () => {
const manager = new ModelManager();
const model = new Model(manager, TINY_MODEL);
const status = await model.validate();
expect(status).toBe(ModelValidationStatus.INVALID);
});

test('downloadModel throws on invalid URL', async () => {
const manager = new ModelManager();
await expect(manager.downloadModel('invalid.txt')).rejects.toThrow();
});

test('model size calculation', async () => {
const manager = new ModelManager();
const model = await manager.downloadModel(TINY_MODEL);
expect(model.size).toBe(1185376);
});
15 changes: 14 additions & 1 deletion src/model-manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,9 @@ export class Model {
}
allFiles.push(file);
}
allFiles.sort((a, b) =>
a.metadata.originalURL.localeCompare(b.metadata.originalURL)
);
return allFiles;
}

Expand All @@ -187,7 +190,7 @@ export class ModelManager {
public params: ModelManagerParams;
public logger: WllamaLogger;

constructor(params: ModelManagerParams) {
constructor(params: ModelManagerParams = {}) {
this.cacheManager = params.cacheManager || new CacheManager();
this.params = params;
this.logger = params.logger || console;
Expand Down Expand Up @@ -263,6 +266,9 @@ export class ModelManager {
return model;
}

/**
* Get a model from the cache or download it if it's not available.
*/
async getModelOrDownload(
url: string,
options: DownloadOptions = {}
Expand All @@ -275,4 +281,11 @@ export class ModelManager {
}
return this.downloadModel(url, options);
}

/**
* Remove all models from the cache
*/
async clear(): Promise<void> {
await this.cacheManager.clear();
}
}
11 changes: 4 additions & 7 deletions src/wllama.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -142,11 +142,8 @@ test('cleans up resources', async () => {
await wllama.loadModelFromUrl(TINY_MODEL);
expect(wllama.isModelLoaded()).toBe(true);
await wllama.exit();
let exception = null;
try {
await wllama.tokenize('test');
} catch (e) {
exception = e;
}
expect(exception).toBeDefined();
await expect(wllama.tokenize('test')).rejects.toThrow();

// Double check that the model is really unloaded
expect(wllama.isModelLoaded()).toBe(false);
});
16 changes: 6 additions & 10 deletions src/workers-code/generated.ts

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion src/workers-code/opfs-utils.js
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,9 @@ onmessage = async (e) => {
assertNonNull(filename);
assertNonNull(metadataFileName);
assertNonNull(options);
assertNonNull(options.aborted);
abortController = new AbortController();
if (options.aborted) abortController.abort();
const response = await fetch(url, {
...options,
signal: abortController.signal,
Expand Down Expand Up @@ -134,6 +136,6 @@ onmessage = async (e) => {

throw new Error('OPFS Worker: Invalid action', e.data);
} catch (err) {
return resErr({ err });
return resErr(err);
}
};

0 comments on commit 658fafd

Please sign in to comment.