From 58f916cff11b7a377b99828bcbc06c2eea3e0e3a Mon Sep 17 00:00:00 2001
From: tharvik
Date: Fri, 16 Aug 2024 23:49:25 +0200
Subject: [PATCH] discojs/validation: simplify

---
 discojs/src/dataset/data/image_data.ts        |   4 +-
 discojs/src/validation/validator.ts           | 254 +++++-----
 server/tests/validator.spec.ts                | 100 ++--
 .../src/components/containers/TableLayout.vue |  30 +-
 .../src/components/testing/PredictSteps.vue   | 327 +++++++------
 webapp/src/components/testing/TestSteps.vue   | 440 ++++++++++--------
 webapp/src/components/testing/Testing.vue     |  47 +-
 7 files changed, 601 insertions(+), 601 deletions(-)

diff --git a/discojs/src/dataset/data/image_data.ts b/discojs/src/dataset/data/image_data.ts
index d6d7a4849..e4c771a02 100644
--- a/discojs/src/dataset/data/image_data.ts
+++ b/discojs/src/dataset/data/image_data.ts
@@ -32,8 +32,8 @@ export class ImageData extends Data {
     }
 
     let shape
-    if ('xs' in sample && 'ys' in sample) {
-      shape = (sample as { xs: tf.Tensor, ys: number[] }).xs.shape
+    if ('xs' in sample) {
+      shape = (sample as { xs: tf.Tensor }).xs.shape
     } else {
       shape = (sample as tf.Tensor3D).shape
     }
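The relaxed check above is what lets unlabeled, prediction-only image samples through validation. As a rough illustration only, here is a standalone sketch of the same idea; the names are hypothetical and not part of the actual `ImageData` class:

```ts
import * as tf from "@tensorflow/tfjs";

// Hypothetical helper mirroring the relaxed check: a sample may carry labels,
// omit them, or be a raw tensor; the shape is always read from `xs` when present.
type Sample = { xs: tf.Tensor; ys?: number[] } | tf.Tensor;

function sampleShape(sample: Sample): number[] {
  if ("xs" in sample) return sample.xs.shape;
  return sample.shape;
}

// Both a labeled training sample and an unlabeled prediction sample pass:
console.log(sampleShape({ xs: tf.zeros([32, 32, 3]), ys: [1] }));
console.log(sampleShape(tf.zeros([32, 32, 3])));
```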
diff --git a/discojs/src/validation/validator.ts b/discojs/src/validation/validator.ts
index fd373e085..437cbf73a 100644
--- a/discojs/src/validation/validator.ts
+++ b/discojs/src/validation/validator.ts
@@ -1,149 +1,165 @@
-import { List } from 'immutable'
-import * as tf from '@tensorflow/tfjs'
+import * as tf from "@tensorflow/tfjs";
+
 import type {
   Model,
   Task,
-  Logger,
-  client as clients,
-  data,
-  Memory,
-  ModelSource,
   TypedDataset,
   TypedLabeledDataset,
 } from "../index.js";
-import { datasetToData, labeledDatasetToData } from "../dataset/data/helpers.js";
+import {
+  datasetToData,
+  labeledDatasetToData,
+} from "../dataset/data/helpers.js";
+
+function intoTFDataset<T extends tf.TensorContainer>(
+  iter: AsyncIterable<T>,
+): tf.data.Dataset<T> {
+  // @ts-expect-error generator
+  return tf.data.generator(async function* () {
+    yield* iter;
+  });
+}
 
 export class Validator {
-  private size = 0
-  private _confusionMatrix: number[][] | undefined
-  private rollingAccuracy = 0
+  readonly #model: Model;
 
-  constructor (
+  constructor(
     public readonly task: Task,
-    public readonly logger: Logger,
-    private readonly memory: Memory,
-    private readonly source?: ModelSource,
-    private readonly client?: clients.Client
+    model: Model,
   ) {
-    if (source === undefined && client === undefined) {
-      throw new Error('To initialize a Validator, either or both a source and client need to be specified')
-    }
+    this.#model = model;
   }
 
-  private async getLabel(ys: tf.Tensor): Promise<Float32Array | Int32Array | Uint8Array> {
-    // Binary classification
-    if (ys.shape[1] == 1) {
-      const threshold = tf.scalar(0.5)
-      const binaryTensor = ys.greaterEqual(threshold)
-      const binaryArray = await binaryTensor.data()
-      tf.dispose([binaryTensor, threshold])
-      return binaryArray
-    // Multi-class classification
-    } else {
-      const yIdxTensor = ys.argMax(-1)
-      const yIdx = await yIdxTensor.data()
-      tf.dispose([yIdxTensor])
-      return yIdx
+  /** infer every line of the dataset and check that it is as labeled */
+  async *test(dataset: TypedLabeledDataset): AsyncGenerator<boolean> {
+    const preprocessed = (
+      await labeledDatasetToData(this.task, dataset)
+    ).preprocess();
+    const batched = preprocessed.batch().dataset;
+
+    const iterator = await tf.data
+      .zip<[tf.Tensor1D | tf.Tensor2D, number]>([
+        preprocessed.dataset.map((t) => {
+          if (
+            typeof t !== "object" ||
+            !("ys" in t) ||
+            !(t.ys instanceof tf.Tensor) ||
+            !(t.ys.rank === 1 || t.ys.rank === 2)
+          )
+            throw new Error("unexpected preprocessed dataset");
+          if ("xs" in t) tf.dispose(t.xs);
+          return t.ys;
+        }),
+        intoTFDataset(this.#inferOnBatchedData(batched)),
+      ])
+      .iterator();
+    for (
+      let iter = await iterator.next();
+      iter.done !== true;
+      iter = await iterator.next()
+    ) {
+      const zipped = iter.value;
+
+      const label = await getLabel(zipped[0]);
+      tf.dispose(zipped[0]);
+      const infered = zipped[1];
+
+      yield label === infered;
     }
-    // Multi-label classification is not supported
   }
 
-  // test assumes data comes with labels while predict doesn't
-  async *test(dataset: data.Data | TypedLabeledDataset):
-    AsyncGenerator<Array<{ groundTruth: number, pred: number, features: number[] }>, void> {
-    if (Array.isArray(dataset))
-      dataset = await labeledDatasetToData(this.task, dataset)
-    const batched = dataset
-      .preprocess()
-      .dataset.batch(this.task.trainingInformation.batchSize);
-
-    const model = await this.getModel()
-    let hits = 0
-    const iterator = await batched.iterator()
-    let next = await iterator.next()
-    while (next.done !== true) {
-      const { xs, ys } = next.value as { xs: tf.Tensor2D, ys: tf.Tensor3D }
-      const ysLabel = await this.getLabel(ys)
-      const yPredTensor = await model.predict(xs)
-      const pred = await this.getLabel(yPredTensor)
-      const currentFeatures = await xs.array()
-      this.size += ysLabel.length
-      hits += List(pred).zip(List(ysLabel)).filter(([p, y]) => p === y).size
-      this.rollingAccuracy = hits / this.size
-      tf.dispose([xs, ys, yPredTensor])
-
-      yield (List(ysLabel).zip(List(pred), List(currentFeatures)))
-        .map(([gt, p, f]) => ({ groundTruth: gt, pred: p, features: f }))
-        .toArray()
-
-      next = await iterator.next()
-    }
-
-    this.logger.success(`Obtained validation accuracy of ${this.accuracy}`)
-    this.logger.success(`Visited ${this.visitedSamples} samples`)
+  /** use the model to predict every line of the dataset */
+  async *infer(dataset: TypedDataset): AsyncGenerator<number> {
+    const data = await datasetToData(this.task, dataset);
+
+    const batched = data.preprocess().batch().dataset;
+
+    yield* this.#inferOnBatchedData(batched);
   }
 
-  async *inference (dataset: data.Data | TypedDataset): AsyncGenerator<Array<{ features: number[], pred: number }>, void> {
-    if (Array.isArray(dataset))
-      dataset = await datasetToData(this.task, dataset)
-    const batched = dataset
-      .preprocess()
-      .dataset.batch(this.task.trainingInformation.batchSize);
-
-    const model = await this.getModel()
-    const iterator = await batched.iterator()
-    let next = await iterator.next()
-
-    while (next.done !== true) {
-      let xs: tf.Tensor2D
-      if (next.value instanceof tf.Tensor) {
-        xs = next.value as tf.Tensor2D
-      } else {
-        const tensors = (next.value as { xs: tf.Tensor2D, ys: tf.Tensor2D })
-        xs = tensors['xs']
-        tf.dispose([tensors['ys']])
+  async *#inferOnBatchedData(
+    batched: tf.data.Dataset<tf.TensorContainer>,
+  ): AsyncGenerator<number> {
+    const iterator = await batched.iterator();
+    for (
+      let iter = await iterator.next();
+      iter.done !== true;
+      iter = await iterator.next()
+    ) {
+      const row = iter.value;
+      if (
+        typeof row !== "object" ||
+        !("xs" in row) ||
+        !(row.xs instanceof tf.Tensor)
+      )
+        throw new Error("unexpected shape of dataset");
+
+      const prediction = await this.#model.predict(row.xs);
+      tf.dispose(row);
+      let predictions: number[];
+      switch (prediction.rank) {
+        case 2:
+        case 3:
+          predictions = await getLabels(
+            // cast as rank was just checked
+            prediction as tf.Tensor2D | tf.Tensor3D,
+          );
+          prediction.dispose();
+          break;
+        default:
+          throw new Error("unexpected batched prediction shape");
       }
-      const currentFeatures = await xs.array()
-      const yPredTensor = await model.predict(xs)
-      const pred = await this.getLabel(yPredTensor)
-      this.size += pred.length
-
-      if (!Array.isArray(currentFeatures)) {
-        throw new TypeError('Data format is incorrect')
-      }
-      tf.dispose([xs, yPredTensor])
-
-      yield List(currentFeatures).zip(List(pred))
-        .map(([f, p]) => ({ features: f, pred: p }))
-        .toArray()
-
-      next = await iterator.next()
+      prediction.dispose();
+
+      for (const prediction of predictions) yield prediction;
    }
-
-    this.logger.success(`Visited ${this.visitedSamples} samples`)
   }
+}
 
-  async getModel (): Promise<Model> {
-    if (this.source !== undefined && await this.memory.contains(this.source)) {
-      return await this.memory.getModel(this.source)
-    }
+async function getLabels(ys: tf.Tensor2D | tf.Tensor3D): Promise<number[]> {
+  // cast as unstack drop a dimension and tfjs doesn't type correctly
+  return Promise.all(
+    tf.unstack(ys).map((y) => {
+      const ret = getLabel(y as tf.Tensor1D | tf.Tensor2D);
+      y.dispose();
+      return ret;
+    }),
+  );
+}
 
-    if (this.client !== undefined) {
-      return await this.client.getLatestModel()
-    }
+async function getLabel(ys: tf.Tensor1D | tf.Tensor2D): Promise<number> {
+  switch (ys.rank) {
+    case 1: {
+      if (ys.shape[0] == 1) {
+        // Binary classification
+        const threshold = tf.scalar(0.5);
+        const binaryTensor = ys.greaterEqual(threshold);
 
-    throw new Error('Could not load the model')
-  }
+        const binaryArray = await binaryTensor.data();
+        tf.dispose([binaryTensor, threshold]);
 
-  get accuracy (): number {
-    return this.rollingAccuracy
-  }
+        return binaryArray[0];
+      }
 
-  get visitedSamples (): number {
-    return this.size
-  }
+      // Multi-class classification
+      const indexTensor = ys.argMax();
 
-  get confusionMatrix (): number[][] | undefined {
-    return this._confusionMatrix
+      const indexArray = await indexTensor.data();
+      tf.dispose([indexTensor]);
+
+      return indexArray[0];
+
+      // Multi-label classification is not supported
+    }
+    case 2: {
+      // it's LLM, we only extract the next token
+      const firstToken = tf.tidy(() => ys.gather([0]).squeeze().argMax());
+      const raw = await firstToken.data();
+      firstToken.dispose();
+
+      return raw[0];
+    }
+    default:
+      throw new Error("unexpected tensor rank");
   }
 }
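To make the reworked surface concrete, here is a rough usage sketch of the simplified API. It mirrors the updated server tests below, so treat anything not visible in this diff as an assumption: the validator now only wraps a `Task` and a `Model`, and `test()` yields one boolean per sample, leaving accuracy bookkeeping to the caller.

```ts
import { Validator, defaultTasks } from "@epfml/discojs";
import { loadCSV } from "@epfml/discojs-node";

async function titanicAccuracy(): Promise<number> {
  const provider = defaultTasks.titanic;
  const validator = new Validator(provider.getTask(), await provider.getModel());

  // test() checks every sample against its label and yields true/false
  const dataset = loadCSV("../datasets/titanic_train.csv");
  let hits = 0;
  let size = 0;
  for await (const correct of validator.test(["tabular", dataset])) {
    if (correct) hits++;
    size++;
  }
  return hits / size;
}
```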
diff --git a/server/tests/validator.spec.ts b/server/tests/validator.spec.ts
index 89366416f..8dcd862cd 100644
--- a/server/tests/validator.spec.ts
+++ b/server/tests/validator.spec.ts
@@ -1,35 +1,12 @@
 import { expect } from "chai";
 import { Repeat } from "immutable";
-import type * as http from "node:http";
-
-import {
-  Validator,
-  ConsoleLogger,
-  EmptyMemory,
-  client as clients,
-  aggregator as aggregators,
-  defaultTasks,
-} from "@epfml/discojs";
-import { loadCSV, loadImagesInDir } from "@epfml/discojs-node";
-
-import { Server } from "../src/index.js";
+import { Validator, defaultTasks } from "@epfml/discojs";
+import { loadCSV, loadImagesInDir } from "@epfml/discojs-node";
 
 describe("validator", function () {
-  this.timeout("10s");
-
-  let server: http.Server;
-  let url: URL;
-  beforeEach(async () => {
-    [server, url] = await Server.of(
-      defaultTasks.simpleFace,
-      defaultTasks.lusCovid,
-      defaultTasks.titanic,
-    ).then((s) => s.serve());
-  });
-  afterEach(() => server?.close());
-
   it("can read and predict randomly on simple_face", async () => {
-    const task = defaultTasks.simpleFace.getTask();
+    const provider = defaultTasks.simpleFace;
 
     const [adult, child] = [
       (await loadImagesInDir("../datasets/simple_face/adult")).zip(
@@ -41,48 +18,44 @@
     ];
     const dataset = adult.chain(child);
 
-    // Init a validator instance
-    const meanAggregator = aggregators.getAggregator(task, { scheme: "local" });
-    const client = new clients.Local(url, task, meanAggregator);
     const validator = new Validator(
-      task,
-      new ConsoleLogger(),
-      new EmptyMemory(),
-      undefined,
-      client,
+      provider.getTask(),
+      await provider.getModel(),
     );
 
-    for await (const _ of validator.test(["image", dataset]));
+    let hits = 0;
+    let size = 0;
+    for await (const correct of validator.test(["image", dataset])) {
+      if (correct) hits++;
+      size++;
+    }
 
-    expect(validator.visitedSamples).to.equal(await dataset.size());
-    expect(validator.accuracy).to.be.greaterThan(0.3);
-  }).timeout("5s");
+    expect(hits / size).to.be.greaterThan(0.3);
+  });
 
   it("can read and predict randomly on titanic", async () => {
-    const task = defaultTasks.titanic.getTask();
+    const provider = defaultTasks.titanic;
 
     const dataset = loadCSV("../datasets/titanic_train.csv");
 
-    const meanAggregator = aggregators.getAggregator(task, { scheme: "local" });
-    const client = new clients.Local(url, task, meanAggregator);
     const validator = new Validator(
-      task,
-      new ConsoleLogger(),
-      new EmptyMemory(),
-      undefined,
-      client,
+      provider.getTask(),
+      await provider.getModel(),
     );
 
-    for await (const _ of validator.test(["tabular", dataset]));
+    let hits = 0;
+    let size = 0;
+    for await (const correct of validator.test(["tabular", dataset])) {
+      if (correct) hits++;
+      size++;
+    }
 
-    expect(validator.visitedSamples).to.equal(await dataset.size());
-    expect(validator.accuracy).to.be.greaterThan(0.3);
-  }).timeout("1s");
+    expect(hits / size).to.be.greaterThan(0.3);
+  });
 
   it("can read and predict randomly on lus_covid", async () => {
-    const task = defaultTasks.lusCovid.getTask();
+    const provider = defaultTasks.lusCovid;
 
-    // Load the data
     const [positive, negative] = [
       (await loadImagesInDir("../datasets/lus_covid/COVID+")).zip(
         Repeat("COVID-Positive"),
@@ -93,21 +66,18 @@
     ];
     const dataset = positive.chain(negative);
 
-    // Initialize a validator instance
-    const meanAggregator = aggregators.getAggregator(task, { scheme: "local" });
-    const client = new clients.Local(url, task, meanAggregator);
     const validator = new Validator(
-      task,
-      new ConsoleLogger(),
-      new EmptyMemory(),
-      undefined,
-      client,
+      provider.getTask(),
+      await provider.getModel(),
     );
 
-    // Assert random initialization metrics
-    for await (const _ of validator.test(["image", dataset]));
+    let hits = 0;
+    let size = 0;
+    for await (const correct of validator.test(["image", dataset])) {
+      if (correct) hits++;
+      size++;
+    }
 
-    expect(validator.visitedSamples).to.equal(await dataset.size());
-    expect(validator.accuracy).to.be.greaterThan(0.3);
-  }).timeout("1s");
+    expect(hits / size).to.be.greaterThan(0.3);
+  });
 });
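The updated spec only exercises `test()`. For the unlabeled path, a hedged sketch of how `infer()` could be consumed, assuming the validator and dataset are built the same way as in the tests above; the helper name is hypothetical:

```ts
import type { Validator } from "@epfml/discojs";

// Collect the raw predicted labels yielded by the new infer() generator.
// The dataset argument is whatever typed tuple the task expects, e.g. ["tabular", dataset].
async function collectPredictions(
  validator: Validator,
  dataset: Parameters<Validator["infer"]>[0],
): Promise<number[]> {
  const out: number[] = [];
  for await (const label of validator.infer(dataset)) out.push(label);
  return out;
}
```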
diff --git a/webapp/src/components/containers/TableLayout.vue b/webapp/src/components/containers/TableLayout.vue
index b6dcb700e..19001de18 100644
--- a/webapp/src/components/containers/TableLayout.vue
+++ b/webapp/src/components/containers/TableLayout.vue
@@ -12,18 +12,9 @@
[Vue template hunk: the markup did not survive in this excerpt; the cell rendering around "{{ value }}" is simplified.]
@@ -31,14 +22,11 @@
[Vue template hunk: the markup did not survive in this excerpt.]
diff --git a/webapp/src/components/testing/PredictSteps.vue b/webapp/src/components/testing/PredictSteps.vue
index aa3ff2c32..a4e1aeb7f 100644
--- a/webapp/src/components/testing/PredictSteps.vue
+++ b/webapp/src/components/testing/PredictSteps.vue
@@ -39,12 +39,12 @@
[Vue template hunk: the markup did not survive in this excerpt. The remaining text shows the block around "By clicking the button below, you will be able to predict using the selected model with chosen dataset of yours." being restructured, with the "predict" button line touched.]
@@ -62,69 +62,56 @@
[Vue template hunk: the markup did not survive in this excerpt. The remaining text shows the counter next to "samples visited" changing from "{{ visitedSamples }}" to "{{ predictions?.results.size ?? 0 }}", and the "download as csv" block being reworked.]
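The surviving template fragments reference `predictions?.results.size`. Purely as an assumption about how the script side could back that counter (these names are illustrative, not taken from the actual component), a minimal sketch wiring `Validator.test()` into such state:

```ts
import { List } from "immutable";
import { ref } from "vue";
import type { Validator } from "@epfml/discojs";

// Hypothetical reactive state: one boolean per visited sample.
const predictions = ref<{ results: List<boolean> } | undefined>();

async function runTest(
  validator: Validator,
  dataset: Parameters<Validator["test"]>[0],
): Promise<void> {
  let results = List<boolean>();
  for await (const correct of validator.test(dataset)) {
    results = results.push(correct);
    predictions.value = { results }; // the template then shows predictions?.results.size
  }
}
```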