Commit e6bc698

fix model validation, add tests
ngxson committed Dec 1, 2024
1 parent b78f2df commit e6bc698
Showing 3 changed files with 65 additions and 39 deletions.
60 changes: 41 additions & 19 deletions src/model-manager.test.ts
@@ -6,20 +6,20 @@ const TINY_MODEL =
const SPLIT_MODEL =
'https://huggingface.co/ngxson/tinyllama_split_test/resolve/main/stories15M-q8_0-00001-of-00003.gguf';

test('parseModelUrl handles single model URL', () => {
test.sequential('parseModelUrl handles single model URL', () => {
const urls = ModelManager.parseModelUrl(TINY_MODEL);
expect(urls).toEqual([TINY_MODEL]);
});

test('parseModelUrl handles array of URLs', () => {
test.sequential('parseModelUrl handles array of URLs', () => {
const urls = ModelManager.parseModelUrl(SPLIT_MODEL);
expect(urls.length).toBe(3);
expect(urls[0]).toMatch(/-00001-of-00003\.gguf$/);
expect(urls[1]).toMatch(/-00002-of-00003\.gguf$/);
expect(urls[2]).toMatch(/-00003-of-00003\.gguf$/);
});

test('download split model', async () => {
test.sequential('download split model', async () => {
const manager = new ModelManager();
const model = await manager.downloadModel(SPLIT_MODEL);
expect(model.files.length).toBe(3);
@@ -33,22 +33,35 @@ test('download split model', async () => {
expect(model.files[2].size).toBe(5773312);
});

test('download invalid model URL', async () => {
test.sequential(
'interrupt download split model (partial files downloaded)',
async () => {
return; // skip on CI, only run locally with a slow connection
const manager = new ModelManager();
await manager.clear();
const controller = new AbortController();
const downloadPromise = manager.downloadModel(SPLIT_MODEL, {
signal: controller.signal,
progressCallback: ({ loaded, total }) => {
const progress = loaded / total;
if (progress > 0.8) {
controller.abort();
}
},
});
await expect(downloadPromise).rejects.toThrow('aborted');
expect((await manager.getModels()).length).toBe(0);
expect((await manager.getModels({ includeInvalid: true })).length).toBe(1);
}
);

test.sequential('download invalid model URL', async () => {
const manager = new ModelManager();
const invalidUrl = 'https://invalid.example.com/model.gguf';
await expect(manager.downloadModel(invalidUrl)).rejects.toThrow();
});

test('clear model manager', async () => {
const manager = new ModelManager();
const model = await manager.downloadModel(TINY_MODEL);
expect(model).toBeDefined();
expect((await manager.getModels()).length).toBeGreaterThan(0);
await manager.clear();
expect((await manager.getModels()).length).toBe(0);
});

test('download with abort signal', async () => {
test.sequential('download with abort signal', async () => {
const manager = new ModelManager();
await manager.clear();
const controller = new AbortController();
@@ -61,7 +74,7 @@ test('download with abort signal', async () => {
expect((await manager.getModels()).length).toBe(0);
});

test('download with progress callback', async () => {
test.sequential('download with progress callback', async () => {
const manager = new ModelManager();
await manager.clear();

@@ -83,25 +96,25 @@ test('download with progress callback', async () => {
expect(model.size).toBeGreaterThan(0);
});

test('model validation status for new model', async () => {
test.sequential('model validation status for new model', async () => {
const manager = new ModelManager();
const model = new Model(manager, TINY_MODEL);
const status = await model.validate();
expect(status).toBe(ModelValidationStatus.INVALID);
});

test('downloadModel throws on invalid URL', async () => {
test.sequential('downloadModel throws on invalid URL', async () => {
const manager = new ModelManager();
await expect(manager.downloadModel('invalid.txt')).rejects.toThrow();
});

test('model size calculation', async () => {
test.sequential('model size calculation', async () => {
const manager = new ModelManager();
const model = await manager.downloadModel(TINY_MODEL);
expect(model.size).toBe(1185376);
});

test('remove model from cache', async () => {
test.sequential('remove model from cache', async () => {
const manager = new ModelManager();
await manager.clear();

@@ -125,3 +138,12 @@ test('remove model from cache', async () => {
const models = await manager.getModels();
expect(models.find((m) => m.url === TINY_MODEL)).toBeUndefined();
});

test.sequential('clear model manager', async () => {
const manager = new ModelManager();
const model = await manager.downloadModel(TINY_MODEL);
expect(model).toBeDefined();
expect((await manager.getModels()).length).toBeGreaterThan(0);
await manager.clear();
expect((await manager.getModels()).length).toBe(0);
});
20 changes: 11 additions & 9 deletions src/model-manager.ts
@@ -118,19 +118,16 @@ export class Model {
* - The model is deleted from the cache
* - The model files are missing (or the download is interrupted)
*/
async validate(): Promise<ModelValidationStatus> {
// TODO: valid number of shards
validate(): ModelValidationStatus {
const nbShards = ModelManager.parseModelUrl(this.url).length;
if (this.size === -1) {
return ModelValidationStatus.DELETED;
}
if (this.size < 16) {
if (this.size < 16 || this.files.length !== nbShards) {
return ModelValidationStatus.INVALID;
}
for (const file of this.files) {
const metadata = await this.modelManager.cacheManager.getMetadata(
file.name
);
if (!metadata || metadata.originalSize !== file.size) {
if (!file.metadata || file.metadata.originalSize !== file.size) {
return ModelValidationStatus.INVALID;
}
}
@@ -259,9 +256,9 @@ export class ModelManager {
/**
* Get all models in the cache
*/
async getModels(): Promise<Model[]> {
async getModels(opts: { includeInvalid?: boolean } = {}): Promise<Model[]> {
const cachedFiles = await this.cacheManager.list();
const models: Model[] = [];
let models: Model[] = [];
for (const file of cachedFiles) {
const shards = ModelManager.parseModelUrl(file.metadata.originalURL);
const isFirstShard =
@@ -270,6 +267,11 @@
models.push(new Model(this, file.metadata.originalURL, cachedFiles));
}
}
if (!opts.includeInvalid) {
models = models.filter(
(m) => m.validate() === ModelValidationStatus.VALID
);
}
return models;
}
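
For context, a minimal usage sketch of the API touched by this commit: the now-synchronous validate() and the new includeInvalid option on getModels(). The package import path, model URL, and logging below are illustrative assumptions, not part of the diff.

import { ModelManager, ModelValidationStatus } from '@wllama/wllama'; // assumed package entry point

async function inspectCache() {
  const manager = new ModelManager();

  // Hypothetical single-file GGUF URL, for illustration only
  await manager.downloadModel('https://example.com/stories15M-q8_0.gguf');

  // getModels() now filters out invalid entries by default;
  // pass { includeInvalid: true } to also list partially downloaded models.
  const valid = await manager.getModels();
  const all = await manager.getModels({ includeInvalid: true });

  for (const model of all) {
    // validate() is now synchronous and also checks that the number of
    // cached files matches the number of shards parsed from the URL.
    if (model.validate() !== ModelValidationStatus.VALID) {
      console.log('incomplete or invalid download:', model.url);
    }
  }
  console.log(`${valid.length}/${all.length} cached models are valid`);
}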

24 changes: 13 additions & 11 deletions src/wllama.test.ts
@@ -14,7 +14,7 @@ const SPLIT_MODEL =

const EMBD_MODEL = TINY_MODEL; // for better speed

test('loads single model file', async () => {
test.sequential('loads single model file', async () => {
const wllama = new Wllama(CONFIG_PATHS);

await wllama.loadModelFromUrl(TINY_MODEL, {
@@ -33,7 +33,7 @@ test('loads single model file', async () => {
await wllama.exit();
});

test('loads single model file from HF', async () => {
test.sequential('loads single model file from HF', async () => {
const wllama = new Wllama(CONFIG_PATHS);

await wllama.loadModelFromHF(
@@ -49,7 +49,7 @@ test('loads single model file from HF', async () => {
await wllama.exit();
});

test('loads single thread model', async () => {
test.sequential('loads single thread model', async () => {
const wllama = new Wllama(CONFIG_PATHS);

await wllama.loadModelFromUrl(TINY_MODEL, {
@@ -66,7 +66,7 @@ test('loads single thread model', async () => {
await wllama.exit();
});

test('loads model with progress callback', async () => {
test.sequential('loads model with progress callback', async () => {
const wllama = new Wllama(CONFIG_PATHS);

let progressCalled = false;
@@ -88,7 +88,7 @@ test('loads model with progress callback', async () => {
await wllama.exit();
});

test('loads split model files', async () => {
test.sequential('loads split model files', async () => {
const wllama = new Wllama(CONFIG_PATHS, {
parallelDownloads: 5,
});
@@ -101,7 +101,7 @@ test('loads split model files', async () => {
await wllama.exit();
});

test('tokenizes and detokenizes text', async () => {
test.sequential('tokenizes and detokenizes text', async () => {
const wllama = new Wllama(CONFIG_PATHS);

await wllama.loadModelFromUrl(TINY_MODEL, {
@@ -122,7 +122,7 @@ test('tokenizes and detokenizes text', async () => {
await wllama.exit();
});

test('generates completion', async () => {
test.sequential('generates completion', async () => {
const wllama = new Wllama(CONFIG_PATHS);

await wllama.loadModelFromUrl(TINY_MODEL, {
@@ -151,7 +151,7 @@ test('generates completion', async () => {
await wllama.exit();
});

test('gets logits', async () => {
test.sequential('gets logits', async () => {
const wllama = new Wllama(CONFIG_PATHS);

await wllama.loadModelFromUrl(TINY_MODEL, {
@@ -170,7 +170,7 @@ test('gets logits', async () => {
await wllama.exit();
});

test('generates embeddings', async () => {
test.sequential('generates embeddings', async () => {
const wllama = new Wllama(CONFIG_PATHS);

await wllama.loadModelFromUrl(EMBD_MODEL, {
@@ -208,7 +208,7 @@ test('generates embeddings', async () => {
await wllama.exit();
});

test('allowOffline', async () => {
test.sequential('allowOffline', async () => {
const wllama = new Wllama(CONFIG_PATHS, {
allowOffline: true,
});
@@ -224,10 +224,12 @@ test('allowOffline', async () => {
} catch (e) {
window.fetch = origFetch;
throw e;
} finally {
window.fetch = origFetch;
}
});

test('cleans up resources', async () => {
test.sequential('cleans up resources', async () => {
const wllama = new Wllama(CONFIG_PATHS);
await wllama.loadModelFromUrl(TINY_MODEL);
expect(wllama.isModelLoaded()).toBe(true);
