From 55411371f900cf6381e265cfb0d28d0d5644748f Mon Sep 17 00:00:00 2001 From: tharvik Date: Wed, 28 Aug 2024 11:16:45 +0200 Subject: [PATCH] *: put Memory into webapp --- discojs-web/src/index.ts | 1 - discojs-web/src/memory/index.ts | 1 - discojs-web/src/memory/memory.ts | 193 --------------- discojs/src/index.ts | 1 - discojs/src/memory/base.ts | 134 ---------- discojs/src/memory/empty.ts | 58 ----- discojs/src/memory/index.ts | 2 - discojs/src/training/disco.ts | 22 +- package-lock.json | 38 ++- webapp/cypress/e2e/library.cy.ts | 100 -------- webapp/cypress/e2e/testing.cy.ts | 9 +- webapp/cypress/e2e/training.cy.ts | 2 + webapp/package.json | 1 + webapp/src/assets/svg/CrossIcon.vue | 29 --- webapp/src/assets/svg/FileIcon.vue | 25 -- webapp/src/components/App.vue | 8 - .../components/progress_bars/TestingBar.vue | 4 +- .../src/components/sidebar/ModelLibrary.vue | 232 ------------------ webapp/src/components/sidebar/SideBar.vue | 87 ------- webapp/src/components/testing/Testing.vue | 122 ++++----- .../testing/__tests__/Testing.spec.ts | 97 ++++++++ webapp/src/components/training/Finished.vue | 112 ++++----- .../src/components/training/ModelCaching.vue | 156 ------------ webapp/src/components/training/Trainer.vue | 16 +- webapp/src/components/training/Training.vue | 10 +- webapp/src/main.ts | 3 +- webapp/src/store/__tests__/models.spec.ts | 27 ++ webapp/src/store/memory.ts | 75 ------ webapp/src/store/models.ts | 108 ++++++++ webapp/src/store/validation.ts | 4 +- 30 files changed, 403 insertions(+), 1274 deletions(-) delete mode 100644 discojs-web/src/memory/index.ts delete mode 100644 discojs-web/src/memory/memory.ts delete mode 100644 discojs/src/memory/base.ts delete mode 100644 discojs/src/memory/empty.ts delete mode 100644 discojs/src/memory/index.ts delete mode 100644 webapp/cypress/e2e/library.cy.ts delete mode 100644 webapp/src/assets/svg/CrossIcon.vue delete mode 100644 webapp/src/assets/svg/FileIcon.vue delete mode 100644 webapp/src/components/sidebar/ModelLibrary.vue create mode 100644 webapp/src/components/testing/__tests__/Testing.spec.ts delete mode 100644 webapp/src/components/training/ModelCaching.vue create mode 100644 webapp/src/store/__tests__/models.spec.ts delete mode 100644 webapp/src/store/memory.ts create mode 100644 webapp/src/store/models.ts diff --git a/discojs-web/src/index.ts b/discojs-web/src/index.ts index 71a7bc24c..a8082e1ed 100644 --- a/discojs-web/src/index.ts +++ b/discojs-web/src/index.ts @@ -1,2 +1 @@ export * from "./loaders/index.js"; -export * from "./memory/index.js"; diff --git a/discojs-web/src/memory/index.ts b/discojs-web/src/memory/index.ts deleted file mode 100644 index bad0ecca4..000000000 --- a/discojs-web/src/memory/index.ts +++ /dev/null @@ -1 +0,0 @@ -export { IndexedDB } from './memory.js' diff --git a/discojs-web/src/memory/memory.ts b/discojs-web/src/memory/memory.ts deleted file mode 100644 index 030956c9b..000000000 --- a/discojs-web/src/memory/memory.ts +++ /dev/null @@ -1,193 +0,0 @@ -/** - * Helper functions used to load and save TFJS models from IndexedDB. The - * working model is the model currently being trained for a task. Saved models - * are models that were explicitly saved to IndexedDB. The two working/ and saved/ - * folders are invisible to the user. The user only interacts with the saved/ - * folder via the model library. The working/ folder is only used by the backend. - * The working model is loaded from IndexedDB for training (model.fit) only. - */ -import { Map } from 'immutable' -import createDebug from "debug" -import * as tf from '@tensorflow/tfjs' - -import type { Model, ModelInfo, ModelSource } from '@epfml/discojs' -import { Memory, models } from '@epfml/discojs' - -const debug = createDebug('discojs-web:memory') - -export class IndexedDB extends Memory { - override getModelMemoryPath (source: ModelSource): string { - if (typeof source === 'string') { - return source - } - const version = source.version ?? 0 - return `indexeddb://${source.type}/${source.tensorBackend}/${source.taskID}/${source.name}@${version}` - } - - override getModelInfo (source: ModelSource): ModelInfo { - if (typeof source !== 'string') { - return source - } - const [type, tensorBackend, taskID, fullName] = source.split('/').splice(2) - - if (type !== 'working' && type !== 'saved') { - throw Error("Unknown memory model type") - } - - const [name, versionSuffix] = fullName.split('@') - const version = versionSuffix === undefined ? 0 : Number(versionSuffix) - if (tensorBackend !== 'tfjs' && tensorBackend !== 'gpt') { - throw Error("Unknown tensor backend") - } - return { type, taskID, name, version, tensorBackend } - } - - async getModelMetadata (source: ModelSource): Promise { - const models = await tf.io.listModels() - return models[this.getModelMemoryPath(source)] - } - - async contains (source: ModelSource): Promise { - return await this.getModelMetadata(source) !== undefined - } - - override async getModel(source: ModelSource): Promise { - const layersModel = await tf.loadLayersModel(this.getModelMemoryPath(source)) - - const tensorBackend = this.getModelInfo(source).tensorBackend - switch (tensorBackend) { - case 'tfjs': - return new models.TFJS(layersModel) - case 'gpt': - return new models.GPT(undefined, layersModel) - default: { - const _: never = tensorBackend - throw new Error('should never happen') - } - } - } - - async deleteModel (source: ModelSource): Promise { - await tf.io.removeModel(this.getModelMemoryPath(source)) - } - - async loadModel(source: ModelSource): Promise { - const src = this.getModelInfo(source) - if (src.type === 'working') { - // Model is already loaded - return - } - await tf.io.copyModel( - this.getModelMemoryPath(src), - this.getModelMemoryPath({ ...src, type: 'working', version: 0 }) - ) - } - - /** - * Saves the working model to the source. - * @param source the destination - * @param model the model - */ - override async updateWorkingModel (source: ModelSource, model: Model): Promise { - const src: ModelInfo = this.getModelInfo(source) - if (src.type !== 'working') { - throw new Error('expected working type model') - } - // Enforce version 0 to always keep a single working model at a time - const modelInfo = { ...src, type: 'working' as const, version: 0 } - let includeOptimizer; - if (model instanceof models.TFJS) { - modelInfo['tensorBackend'] = 'tfjs' - includeOptimizer = true - } else if (model instanceof models.GPT) { - modelInfo['tensorBackend'] = 'gpt' - includeOptimizer = false // true raises an error - } else { - debug('unknown working model type %o', model) - throw new Error(`unknown model type while updating working model`) - } - const indexedDBURL = this.getModelMemoryPath(modelInfo) - await model.extract().save(indexedDBURL, { includeOptimizer }) - } - - /** - * Creates a saved copy of the working model corresponding to the source. - * @param source the source - */ - async saveWorkingModel (source: ModelSource): Promise { - const src: ModelInfo = this.getModelInfo(source) - if (src.type !== 'working') { - throw new Error('expected working type model') - } - const dst = this.getModelMemoryPath(await this.duplicateSource({ ...src, type: 'saved' })) - await tf.io.copyModel( - this.getModelMemoryPath({ ...src, type: 'working' }), - dst - ) - return dst - } - - override async saveModel (source: ModelSource, model: Model): Promise { - const src: ModelInfo = this.getModelInfo(source) - if (src.type !== 'saved') { - throw new Error('expected saved type model') - } - - const modelInfo = await this.duplicateSource({ ...src, type: 'saved' }) - let includeOptimizer; - if (model instanceof models.TFJS) { - modelInfo['tensorBackend'] = 'tfjs' - includeOptimizer = true - } else if (model instanceof models.GPT) { - modelInfo['tensorBackend'] = 'gpt' - includeOptimizer = false // true raises an error - } else { - debug('unknown saved model type %o', model) - throw new Error('unknown model type while saving model') - } - const indexedDBURL = this.getModelMemoryPath(modelInfo) - await model.extract().save(indexedDBURL, { includeOptimizer }) - return indexedDBURL - } - - /** - * Downloads the model corresponding to the source. - * @param source the source - */ - async downloadModel (source: ModelSource): Promise { - const src: ModelInfo = this.getModelInfo(source) - await tf.io.copyModel( - this.getModelMemoryPath(source), - `downloads://${src.taskID}_${src.name}` - ) - } - - async latestDuplicate (source: ModelSource): Promise { - if (typeof source !== 'string') { - source = this.getModelMemoryPath({ ...source, version: 0 }) - } - // perform a single memory read - const paths = Map(await tf.io.listModels()) - if (!paths.has(source)) { - return undefined - } - - const latest = Map(paths).keySeq().toList() - .map((p) => this.getModelInfo(p).version).max() - if (latest === undefined) { - return 0 - } - return latest - } - - async duplicateSource (source: ModelSource): Promise { - const latestDuplicate = await this.latestDuplicate(source) - source = this.getModelInfo(source) - - if (latestDuplicate === undefined) { - return source - } - - return { ...source, version: latestDuplicate + 1 } - } -} diff --git a/discojs/src/index.ts b/discojs/src/index.ts index 37c5f08f3..80c27a0b2 100644 --- a/discojs/src/index.ts +++ b/discojs/src/index.ts @@ -10,7 +10,6 @@ export * as aggregator from './aggregator/index.js' export { WeightsContainer, aggregation } from './weights/index.js' export { Logger, ConsoleLogger } from './logging/index.js' -export { Memory, type ModelInfo, type ModelSource, Empty as EmptyMemory } from './memory/index.js' export { Disco, RoundLogs, RoundStatus } from './training/index.js' export { Validator } from './validation/index.js' diff --git a/discojs/src/memory/base.ts b/discojs/src/memory/base.ts deleted file mode 100644 index ae8176628..000000000 --- a/discojs/src/memory/base.ts +++ /dev/null @@ -1,134 +0,0 @@ -// only used browser-side -// TODO: replace IO type - -import type { Model, TaskID } from '../index.js' - -/** - * Type of models stored in memory. Stored models can either be a model currently - * being trained ("working model") or a regular model saved in memory ("saved model"). - * There can only be a single working model for a given task. - */ -type StoredModelType = 'saved' | 'working' - -/** - * Model information which uniquely identifies a model in memory. - */ -export interface ModelInfo { - // The model's type: "working" or "saved" model. - type: StoredModelType - // The model's version, to allow for multiple saved models of a same task without - // causing id conflicts - version?: number - // The model's corresponding task - taskID: TaskID - // The model's name - name: string - // Tensor framework used by the model - tensorBackend: 'gpt'|'tfjs' // onnx in the future -} - -/** - * A model source uniquely identifies a model stored in memory. - * It can be in the form of either a model info object or an ID - * (one-to-one mapping between the two) - */ -export type ModelSource = ModelInfo | string - -/** - * Represents a model memory system, providing functions to fetch, save, delete and update models. - * Stored models can either be a model currently being trained ("working model") or a regular model - * saved in memory ("saved model"). There can only be a single working model for a given task. - */ -export abstract class Memory { - /** - * Fetches the model identified by the given model source. - * @param source The model source - * @returns The model - */ - abstract getModel (source: ModelSource): Promise - - /** - * Removes the model identified by the given model source from memory. - * @param source The model source - * @returns The model - */ - abstract deleteModel (source: ModelSource): Promise - - /** - * Replaces the corresponding working model with the saved model identified by the given model source. - * @param source The model source - */ - abstract loadModel (source: ModelSource): Promise - - /** - * Fetches metadata for the model identified by the given model source. - * If the model does not exist in memory, returns undefined. - * @param source The model source - * @returns The model metadata or undefined - */ - abstract getModelMetadata (source: ModelSource): Promise - - /** - * Replaces the working model identified by the given source with the newly provided model. - * @param source The model source - * @param model The new model - */ - abstract updateWorkingModel (source: ModelSource, model: Model): Promise - - /** - * Creates a saved model copy from the working model identified by the given model source. - * Returns the saved model's path. - * @param source The model source - * @returns The saved model's path - */ - abstract saveWorkingModel (source: ModelSource): Promise - - /** - * Saves the newly provided model to the given model source. - * Returns the saved model's path - * @param source The model source - * @param model The new model - * @returns The saved model's path - */ - abstract saveModel (source: ModelSource, model: Model): Promise - - /** - * Moves the model identified by the model source to a file system. This is platform-dependent. - * @param source The model source - */ - abstract downloadModel (source: ModelSource): Promise - - /** - * Checks whether the model memory contains the model identified by the given source. - * @param source The model source - * @returns True if the memory contains the model, false otherwise - */ - abstract contains (source: ModelSource): Promise - - /** - * Computes the path in memory corresponding to the given model source, be it a path or model information. - * This is used to easily switch between model path and information, which are both unique model identifiers - * with a one-to-one equivalence. Returns undefined instead if no path could be inferred from the given - * model source. - * @param source The model source - * @returns The model path - */ - abstract getModelMemoryPath (source: ModelSource): string | undefined - - /** - * Computes the model information corresponding to the given model source, be it a path or model information. - * This is used to easily switch between model path and information, which are both unique model identifiers - * with a one-to-one equivalence. Returns undefined instead if no unique model information could be inferred - * from the given model source. - * @param source The model source - * @returns The model information - */ - abstract getModelInfo (source: ModelSource): ModelInfo | undefined - - /** - * Computes the lowest version a model source can have without conflicting with model versions currently in memory. - * @param source The model source - * @returns The duplicated model source - */ - abstract duplicateSource (source: ModelSource): Promise -} diff --git a/discojs/src/memory/empty.ts b/discojs/src/memory/empty.ts deleted file mode 100644 index c601d799b..000000000 --- a/discojs/src/memory/empty.ts +++ /dev/null @@ -1,58 +0,0 @@ -import type { Model } from '../index.js' - -import type { ModelInfo } from './base.js' -import { Memory } from './base.js' - -/** - * Represents an empty model memory. - */ -export class Empty extends Memory { - getModelMetadata (): Promise { - return Promise.resolve(undefined) - } - - contains (): Promise { - return Promise.resolve(false) - } - - getModel (): Promise { - return Promise.reject(new Error('empty')) - } - - loadModel (): Promise { - return Promise.reject(new Error('empty')) - } - - updateWorkingModel (): Promise { - // nothing to do - return Promise.resolve() - } - - saveWorkingModel (): Promise { - return Promise.resolve(undefined) - } - - saveModel (): Promise { - return Promise.resolve(undefined) - } - - async deleteModel (): Promise { - // nothing to do - } - - downloadModel (): Promise { - return Promise.reject(new Error('empty')) - } - - getModelMemoryPath (): string { - throw new Error('empty') - } - - getModelInfo (): ModelInfo { - throw new Error('empty') - } - - duplicateSource (): Promise { - return Promise.resolve(undefined) - } -} diff --git a/discojs/src/memory/index.ts b/discojs/src/memory/index.ts deleted file mode 100644 index 942673a1e..000000000 --- a/discojs/src/memory/index.ts +++ /dev/null @@ -1,2 +0,0 @@ -export { Empty } from './empty.js' -export { Memory, type ModelInfo, type ModelSource } from './base.js' diff --git a/discojs/src/training/disco.ts b/discojs/src/training/disco.ts index fc0e17aa0..5dcb86367 100644 --- a/discojs/src/training/disco.ts +++ b/discojs/src/training/disco.ts @@ -4,9 +4,7 @@ import { BatchLogs, ConsoleLogger, EpochLogs, - EmptyMemory, Logger, - Memory, Task, TrainingInformation, } from "../index.js"; @@ -22,7 +20,6 @@ import { labeledDatasetToDataSplit } from "../dataset/data/helpers.js"; interface DiscoConfig { scheme: TrainingInformation["scheme"]; logger: Logger; - memory: Memory; } export type RoundStatus = @@ -33,14 +30,13 @@ export type RoundStatus = /** * Top-level class handling distributed training from a client's perspective. It is meant to be - * a convenient object providing a reduced yet complete API that wraps model training, - * communication with nodes, logs and model memory. + * a convenient object providing a reduced yet complete API that wraps model training and + * communication with nodes. */ export class Disco extends EventEmitter<{'status': RoundStatus}>{ public readonly trainer: Trainer; readonly #client: clients.Client; readonly #logger: Logger; - readonly #memory: Memory; readonly #task: Task; /** @@ -55,10 +51,9 @@ export class Disco extends EventEmitter<{'status': RoundStatus}>{ config: Partial ) { super() - const { scheme, logger, memory } = { + const { scheme, logger } = { scheme: task.trainingInformation.scheme, logger: new ConsoleLogger(), - memory: new EmptyMemory(), ...config, }; @@ -80,7 +75,6 @@ export class Disco extends EventEmitter<{'status': RoundStatus}>{ this.#logger = logger; this.#client = client; - this.#memory = memory; this.#task = task; this.trainer = new Trainer(task, client) // Simply propagate the training status events emitted by the client @@ -171,16 +165,6 @@ export class Disco extends EventEmitter<{'status': RoundStatus}>{ return await returnedRoundLogs; }.bind(this)(); - - await this.#memory.updateWorkingModel( - { - type: "working", - taskID: this.#task.id, - name: this.#task.trainingInformation.modelID, - tensorBackend: this.#task.trainingInformation.tensorBackend, - }, - this.trainer.model, - ); } this.#logger.success("Training finished"); } diff --git a/package-lock.json b/package-lock.json index 5dc37984b..f070a7bde 100644 --- a/package-lock.json +++ b/package-lock.json @@ -148,11 +148,12 @@ } }, "node_modules/@babel/parser": { - "version": "7.25.3", - "resolved": "https://registry.npmjs.org/@babel/parser/-/parser-7.25.3.tgz", - "integrity": "sha512-iLTJKDbJ4hMvFPgQwwsVoxtHyWpKKPBrxkANrSYewDPaPpT5py5yeVkgPIJ7XYXhndxJpaA3PyALSXQ7u8e/Dw==", + "version": "7.25.4", + "resolved": "https://registry.npmjs.org/@babel/parser/-/parser-7.25.4.tgz", + "integrity": "sha512-nq+eWrOgdtu3jG5Os4TQP3x3cLA8hR8TvJNjD8vnPa20WGycimcparWnLK4jJhElTK6SDyuJo1weMKO/5LpmLA==", + "license": "MIT", "dependencies": { - "@babel/types": "^7.25.2" + "@babel/types": "^7.25.4" }, "bin": { "parser": "bin/babel-parser.js" @@ -162,9 +163,10 @@ } }, "node_modules/@babel/types": { - "version": "7.25.2", - "resolved": "https://registry.npmjs.org/@babel/types/-/types-7.25.2.tgz", - "integrity": "sha512-YTnYtra7W9e6/oAZEHj0bJehPRUlLH9/fbpT5LfB0NhQXyALCRkRs3zH9v07IYhkgpqX6Z78FnuccZr/l4Fs4Q==", + "version": "7.25.4", + "resolved": "https://registry.npmjs.org/@babel/types/-/types-7.25.4.tgz", + "integrity": "sha512-zQ1ijeeCXVEh+aNL0RlmkPkG8HUiDcU2pzQQFjtbntgAczRASFzj4H+6+bV+dy1ntKR14I/DypeuRG1uma98iQ==", + "license": "MIT", "dependencies": { "@babel/helper-string-parser": "^7.24.8", "@babel/helper-validator-identifier": "^7.24.7", @@ -7873,9 +7875,9 @@ "license": "BSD-3-Clause" }, "node_modules/ignore": { - "version": "5.3.1", - "resolved": "https://registry.npmjs.org/ignore/-/ignore-5.3.1.tgz", - "integrity": "sha512-5Fytz/IraMjqpwfd34ke28PTVMjZjJG2MPn5t7OE4eUCUNf8BAa7b5WUS9/Qvr6mwOQS7Mk6vdsMno5he+T8Xw==", + "version": "5.3.2", + "resolved": "https://registry.npmjs.org/ignore/-/ignore-5.3.2.tgz", + "integrity": "sha512-hsBTNUqQTDwkWtcdYI2i06Y/nUBEsNEDJKjWdigLvegy8kDuJAS8uRlpkkcQpyEXL0Z/pjDy5HBmMjRCJ2gq+g==", "dev": true, "license": "MIT", "engines": { @@ -10382,9 +10384,9 @@ } }, "node_modules/pkg-types": { - "version": "1.1.3", - "resolved": "https://registry.npmjs.org/pkg-types/-/pkg-types-1.1.3.tgz", - "integrity": "sha512-+JrgthZG6m3ckicaOB74TwQ+tBWsFl3qVQg7mN8ulwSOElJ7gBhKzj2VkCPnZ4NlF6kEquYU+RIYNVAvzd54UA==", + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/pkg-types/-/pkg-types-1.2.0.tgz", + "integrity": "sha512-+ifYuSSqOQ8CqP4MbZA5hDpb97n3E8SVWdJe+Wms9kj745lmd3b7EZJiqvmLwAlmRfjrI7Hi5z3kdBJ93lFNPA==", "dev": true, "license": "MIT", "dependencies": { @@ -14113,6 +14115,7 @@ "d3": "7", "immutable": "4", "pinia": "2", + "pinia-plugin-persistedstate": "3", "vee-validate": "4", "vue": "3", "vue-router": "4", @@ -14348,6 +14351,15 @@ "url": "https://github.com/sponsors/sindresorhus" } }, + "webapp/node_modules/pinia-plugin-persistedstate": { + "version": "3.2.1", + "resolved": "https://registry.npmjs.org/pinia-plugin-persistedstate/-/pinia-plugin-persistedstate-3.2.1.tgz", + "integrity": "sha512-MK++8LRUsGF7r45PjBFES82ISnPzyO6IZx3CH5vyPseFLZCk1g2kgx6l/nW8pEBKxxd4do0P6bJw+mUSZIEZUQ==", + "license": "MIT", + "peerDependencies": { + "pinia": "^2.0.0" + } + }, "webapp/node_modules/signal-exit": { "version": "4.1.0", "resolved": "https://registry.npmjs.org/signal-exit/-/signal-exit-4.1.0.tgz", diff --git a/webapp/cypress/e2e/library.cy.ts b/webapp/cypress/e2e/library.cy.ts deleted file mode 100644 index be99f11e4..000000000 --- a/webapp/cypress/e2e/library.cy.ts +++ /dev/null @@ -1,100 +0,0 @@ -import { expect } from "chai"; -import * as tf from "@tensorflow/tfjs"; - -import { defaultTasks } from "@epfml/discojs"; - -import { setupServerWith } from "../support/e2e.ts"; - -// we can't test via component stubs as it requires IndexedDB - -describe("model library", () => { - /** Ensure that downloaded model is TFJS and trainable */ - function expectDownloadOfTFJSModelIsTrainable(modelName: string): void { - const folder = Cypress.config("downloadsFolder"); - cy.readFile(`${folder}/titanic_titanic-model.json`).then((content) => - cy.intercept( - { hostname: "downloads", pathname: `/${modelName}.json` }, - content, - ), - ); - cy.readFile(`${folder}/${modelName}.weights.bin`).then((content) => - cy.intercept( - { - hostname: "downloads", - pathname: `/${modelName}.weights.bin`, - }, - content, - ), - ); - - cy.wrap({ loadModel: tf.loadLayersModel }) - .invoke("loadModel", `http://downloads/${modelName}.json`) - .then((promise) => promise) - .then((model) => { - const [input, output] = [model.input, model.output]; - if (Array.isArray(input) || Array.isArray(output)) - throw new Error("only support single input & output"); - - return model.fit( - tf.ones(input.shape.map((s) => (s === null ? 1 : s))), - tf.ones(output.shape.map((s) => (s === null ? 1 : s))), - { - epochs: 3, - }, - ); - }) - .should((history) => expect(history.epoch).to.have.lengthOf(3)); - } - - it("allows downloading of server models", () => { - setupServerWith(defaultTasks.titanic); - - cy.visit("/#/evaluate"); - cy.contains("button", "download").click(); - - cy.get("#model-library-btn").click(); - cy.contains("titanic-model").get('button[title="Download"]').click(); - - expectDownloadOfTFJSModelIsTrainable("titanic_titanic-model"); - }); - - it("store trained model", () => { - setupServerWith({ - getTask() { - const task = defaultTasks.titanic.getTask(); - task.trainingInformation.epochs = - task.trainingInformation.roundDuration = 3; - return task; - }, - getModel: defaultTasks.titanic.getModel, - }); - - // Wait for tasks to load - cy.visit("/#/list").contains("button", "participate", { timeout: 5000 }); - - cy.visit("/#/titanic"); - cy.contains("button", "next").click(); - - cy.contains("label", "select CSV").selectFile( - "../datasets/titanic_train.csv", - ); - cy.contains("button", "next").click(); - - cy.contains("button", "locally").click(); - cy.contains("button", "start training").click(); - cy.contains("h6", "epochs") - .next({ timeout: 10_000 }) - .should("have.text", "3 / 3"); - cy.contains("button", "next").click(); - - // TODO do not save by default, only via "save model" button - - // TODO be reactive - cy.visit("/#/evaluate"); // force refresh - - cy.get("#model-library-btn").click(); - cy.contains("titanic-model").get('button[title="Download"]').click(); - - expectDownloadOfTFJSModelIsTrainable("titanic_titanic-model"); - }); -}); diff --git a/webapp/cypress/e2e/testing.cy.ts b/webapp/cypress/e2e/testing.cy.ts index 8710c3c11..0647d0675 100644 --- a/webapp/cypress/e2e/testing.cy.ts +++ b/webapp/cypress/e2e/testing.cy.ts @@ -7,8 +7,7 @@ it("can test titanic", () => { cy.visit("/#/evaluate"); cy.contains("button", "download").click(); - - cy.contains("titanic-model").parents().contains("button", "test").click(); + cy.contains("button", "test").click(); cy.contains("label", "select CSV").selectFile( "../datasets/titanic_train.csv", @@ -29,8 +28,7 @@ it("can test lus_covid", () => { cy.visit("/#/evaluate"); cy.contains("button", "download").click(); - - cy.contains("lus-covid-model").parents().contains("button", "test").click(); + cy.contains("button", "test").click(); cy.task("readdir", "../datasets/lus_covid/COVID+/").then((files) => cy.contains("label", "select images").selectFile(files), @@ -51,8 +49,7 @@ it("can start and stop testing of wikitext", () => { cy.visit("/#/evaluate"); cy.contains("button", "download").click(); - - cy.contains("llm-raw-model").parents().contains("button", "test").click(); + cy.contains("button", "test").click(); cy.contains("label", "select text").selectFile( "../datasets/wikitext/wiki.test.tokens", diff --git a/webapp/cypress/e2e/training.cy.ts b/webapp/cypress/e2e/training.cy.ts index de11abf6b..350c763ed 100644 --- a/webapp/cypress/e2e/training.cy.ts +++ b/webapp/cypress/e2e/training.cy.ts @@ -44,6 +44,8 @@ describe("training page", () => { cy.contains("button", "next").click(); cy.contains("button", "test model").click(); + + cy.contains("Titanic Prediction"); }); it("can start and stop training of lus_covid", () => { diff --git a/webapp/package.json b/webapp/package.json index b8b4fd1bd..45ddc9a71 100644 --- a/webapp/package.json +++ b/webapp/package.json @@ -18,6 +18,7 @@ "d3": "7", "immutable": "4", "pinia": "2", + "pinia-plugin-persistedstate": "3", "vee-validate": "4", "vue": "3", "vue-router": "4", diff --git a/webapp/src/assets/svg/CrossIcon.vue b/webapp/src/assets/svg/CrossIcon.vue deleted file mode 100644 index 344930b85..000000000 --- a/webapp/src/assets/svg/CrossIcon.vue +++ /dev/null @@ -1,29 +0,0 @@ - - diff --git a/webapp/src/assets/svg/FileIcon.vue b/webapp/src/assets/svg/FileIcon.vue deleted file mode 100644 index 705365e3a..000000000 --- a/webapp/src/assets/svg/FileIcon.vue +++ /dev/null @@ -1,25 +0,0 @@ - - diff --git a/webapp/src/components/App.vue b/webapp/src/components/App.vue index be3aecbd4..35b4e4fcd 100644 --- a/webapp/src/components/App.vue +++ b/webapp/src/components/App.vue @@ -51,11 +51,9 @@ diff --git a/webapp/src/components/progress_bars/TestingBar.vue b/webapp/src/components/progress_bars/TestingBar.vue index c8dd16185..73e0e157e 100644 --- a/webapp/src/components/progress_bars/TestingBar.vue +++ b/webapp/src/components/progress_bars/TestingBar.vue @@ -50,7 +50,6 @@ diff --git a/webapp/src/components/sidebar/SideBar.vue b/webapp/src/components/sidebar/SideBar.vue index 252da62ab..ef28645ac 100644 --- a/webapp/src/components/sidebar/SideBar.vue +++ b/webapp/src/components/sidebar/SideBar.vue @@ -59,14 +59,6 @@ > - - - - - - -
- - -
diff --git a/webapp/src/components/training/ModelCaching.vue b/webapp/src/components/training/ModelCaching.vue deleted file mode 100644 index a2e8dba11..000000000 --- a/webapp/src/components/training/ModelCaching.vue +++ /dev/null @@ -1,156 +0,0 @@ - - - - diff --git a/webapp/src/components/training/Trainer.vue b/webapp/src/components/training/Trainer.vue index 9582a1a1d..e323bcd1c 100644 --- a/webapp/src/components/training/Trainer.vue +++ b/webapp/src/components/training/Trainer.vue @@ -1,9 +1,5 @@ @@ -41,6 +46,7 @@ import { useRouter, useRoute } from "vue-router"; import type { Dataset, + Model, Tabular, TaskID, Text, @@ -128,4 +134,6 @@ const dataset = computed(() => { return undefined; }); + +const trainedModel = ref(); diff --git a/webapp/src/main.ts b/webapp/src/main.ts index d5bfd0acd..b3427317a 100644 --- a/webapp/src/main.ts +++ b/webapp/src/main.ts @@ -1,6 +1,7 @@ import createDebug from "debug"; import { createApp } from 'vue' import { createPinia } from 'pinia' +import piniaPersited from "pinia-plugin-persistedstate"; import * as tf from "@tensorflow/tfjs"; import App from '@/components/App.vue' @@ -36,7 +37,7 @@ app.config.errorHandler = (err, instance, info) => { } app - .use(createPinia()) + .use(createPinia().use(piniaPersited)) .use(router) .use(VueTippy, { diff --git a/webapp/src/store/__tests__/models.spec.ts b/webapp/src/store/__tests__/models.spec.ts new file mode 100644 index 000000000..c04347892 --- /dev/null +++ b/webapp/src/store/__tests__/models.spec.ts @@ -0,0 +1,27 @@ +import { createPinia, setActivePinia } from "pinia"; +import piniaPersited from "pinia-plugin-persistedstate"; +import { beforeEach, expect, it } from "vitest"; + +import { models as discoModels } from "@epfml/discojs"; + +import { useModelsStore } from "../models"; +import { createApp } from "vue"; + +const app = createApp({}); +beforeEach(() => { + const pinia = setActivePinia(createPinia().use(piniaPersited)); + app.use(pinia); + setActivePinia(pinia); +}); + +it("persists", async () => { + const models = useModelsStore(); + + const id = await models.add("task id", new discoModels.GPT()); + + models.$persist(); + models.$hydrate(); + + expect(models.infos.size).to.equal(1); + expect(await models.get(id)).to.be.an.instanceof(discoModels.GPT); +}); diff --git a/webapp/src/store/memory.ts b/webapp/src/store/memory.ts deleted file mode 100644 index e05fec31c..000000000 --- a/webapp/src/store/memory.ts +++ /dev/null @@ -1,75 +0,0 @@ -import { defineStore } from 'pinia' -import { ref, shallowRef } from 'vue' - -import { Map } from 'immutable' -import * as tf from '@tensorflow/tfjs' - -import type { ModelInfo } from '@epfml/discojs' - -export interface ModelMetadata extends ModelInfo { - date: string - hours: string - fileSize: number -} - -export const useMemoryStore = defineStore('memory', () => { - const models = shallowRef>(Map()) - const useIndexedDB = ref(true) - - function setIndexedDB (use: boolean) { - useIndexedDB.value = use - } - - function addModel (id: string, metadata: ModelMetadata) { - models.value = models.value.set(id, metadata) - } - - function deleteModel (id: string) { - models.value = models.value.delete(id) - } - - async function initModels () { - const models = await tf.io.listModels() - for (const path in models) { - const [location, _, directory, tensorBackend, task, fullName] = path.split('/') - if (location !== 'indexeddb:') { - continue - } - const [name, version] = fullName.split('@') - const metadata = models[path] - const date = new Date(metadata.dateSaved) - const zeroPad = (number: number) => String(number).padStart(2, '0') - const dateSaved = [ - date.getDate(), - date.getMonth() + 1, - date.getFullYear() - ] - .map(zeroPad) - .join('/') - const hourSaved = [date.getHours(), date.getMinutes()] - .map(zeroPad) - .join('h') - const size = [ - metadata.modelTopologyBytes, - metadata.weightSpecsBytes, - metadata.weightDataBytes, - ].reduce((acc: number, v) => acc + (v === undefined ? 0 : v), 0) - - if (tensorBackend !== 'tfjs' && tensorBackend !== 'gpt') { - throw new Error("Tensor backend unrecognized: " + tensorBackend) - } - addModel(path, { - name, - tensorBackend, - taskID: task, - type: directory !== 'working' ? 'saved' : 'working', - date: dateSaved, - hours: hourSaved, - fileSize: Math.round(size / 1024), - version: version !== undefined ? Number(version) : 0 - }) - } - } - - return { models, useIndexedDB, initModels, addModel, deleteModel, setIndexedDB } -}) diff --git a/webapp/src/store/models.ts b/webapp/src/store/models.ts new file mode 100644 index 000000000..d9e454167 --- /dev/null +++ b/webapp/src/store/models.ts @@ -0,0 +1,108 @@ +import { Map, Range } from "immutable"; +import type { StateTree } from "pinia"; +import { defineStore } from "pinia"; +import { computed, ref, toRaw } from "vue"; + +import type { Model } from "@epfml/discojs"; +import { serialization } from "@epfml/discojs"; + +export type ModelID = number; + +interface Infos { + taskID: string; + dateSaved: Date; +} +type IDToModel = Map; + +type Serialized = Array< + [ModelID, { taskID: string; dateSaved: number; encoded: string }] +>; + +export const useModelsStore = defineStore( + "models", + () => { + const idToModel = ref(Map()); + + const infos = computed(() => + idToModel.value.map(({ taskID, dateSaved, encoded }) => ({ + taskID, + dateSaved, + // approx assuming that `encoded` size far outweigh other fields + // encoding as base16 which has a 50% efficiency + storageSize: encoded.length * 2, + })), + ); + + async function get(id: ModelID): Promise { + const infos = idToModel.value.get(id); + if (infos === undefined) return undefined; + + return await serialization.model.decode(toRaw(infos.encoded)); + } + + async function add(taskID: string, model: Model): Promise { + const dateSaved = new Date(); + const id = dateSaved.getTime(); + + idToModel.value = idToModel.value.set(id, { + taskID, + dateSaved, + encoded: await serialization.model.encode(model), + }); + + return id; + } + + function remove(id: ModelID): void { + idToModel.value = idToModel.value.delete(id); + } + + return { + idToModel, + infos, + get, + add, + remove, + }; + }, + { + persist: { + // only `ref` is `idToModel`, only serializing that + serializer: { + serialize(state: StateTree): string { + return JSON.stringify( + (state.idToModel as IDToModel) + .map(({ taskID, dateSaved, encoded }) => ({ + taskID, + dateSaved: dateSaved.getTime(), + // Uint8Array is very inefficiently encoded in JSON, Window.{atob,btoa} is broken + // using hex encoding (base16) + encoded: [...encoded] + .flatMap((b) => [ + (b >> 4).toString(16), + (b & 0xf).toString(16), + ]) + .join(""), + })) + .toArray() satisfies Serialized, + ); + }, + deserialize(raw: string): { idToModel: IDToModel } { + return { + idToModel: Map((JSON.parse(raw) as Serialized) ?? []).map( + ({ taskID, dateSaved, encoded }) => ({ + taskID, + dateSaved: new Date(dateSaved), + encoded: Uint8Array.from( + Range(0, encoded.length / 2).map((i) => + Number.parseInt(encoded.slice(i * 2, i * 2 + 2), 16), + ), + ), + }), + ), + }; + }, + }, + }, + }, +); diff --git a/webapp/src/store/validation.ts b/webapp/src/store/validation.ts index 172d84165..4d667f7e6 100644 --- a/webapp/src/store/validation.ts +++ b/webapp/src/store/validation.ts @@ -1,10 +1,12 @@ import { defineStore } from "pinia"; import { ref } from "vue"; +import type { ModelID } from "./models"; + export const useValidationStore = defineStore("validation", () => { const step = ref<0 | 1 | 2>(0); const mode = ref<"predict" | "test">(); - const modelID = ref(); + const modelID = ref(); return { step, mode, modelID }; });