Skip to content

Commit

Permalink
discojs/validation: simplify
Browse files Browse the repository at this point in the history
  • Loading branch information
tharvik committed Aug 22, 2024
1 parent da17ab5 commit 58f916c
Show file tree
Hide file tree
Showing 7 changed files with 601 additions and 601 deletions.
4 changes: 2 additions & 2 deletions discojs/src/dataset/data/image_data.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ export class ImageData extends Data {
}

let shape
if ('xs' in sample && 'ys' in sample) {
shape = (sample as { xs: tf.Tensor, ys: number[] }).xs.shape
if ('xs' in sample) {
shape = (sample as { xs: tf.Tensor }).xs.shape
} else {
shape = (sample as tf.Tensor3D).shape
}
Expand Down
254 changes: 135 additions & 119 deletions discojs/src/validation/validator.ts
Original file line number Diff line number Diff line change
@@ -1,149 +1,165 @@
import { List } from 'immutable'
import * as tf from '@tensorflow/tfjs'
import * as tf from "@tensorflow/tfjs";

import type {
Model,
Task,
Logger,
client as clients,
data,
Memory,
ModelSource,
TypedDataset,
TypedLabeledDataset,
} from "../index.js";
import { datasetToData, labeledDatasetToData } from "../dataset/data/helpers.js";
import {
datasetToData,
labeledDatasetToData,
} from "../dataset/data/helpers.js";

/** Adapt an arbitrary async iterable into a tfjs `Dataset` yielding the same items. */
function intoTFDataset<T extends tf.TensorContainer>(
  iter: AsyncIterable<T>,
): tf.data.Dataset<T> {
  async function* forward(): AsyncGenerator<T> {
    for await (const item of iter) yield item;
  }
  // @ts-expect-error generator
  return tf.data.generator(forward);
}

export class Validator {
private size = 0
private _confusionMatrix: number[][] | undefined
private rollingAccuracy = 0
readonly #model: Model;

constructor (
constructor(
public readonly task: Task,
public readonly logger: Logger,
private readonly memory: Memory,
private readonly source?: ModelSource,
private readonly client?: clients.Client
model: Model,
) {
if (source === undefined && client === undefined) {
throw new Error('To initialize a Validator, either or both a source and client need to be specified')
}
this.#model = model;
}

private async getLabel(ys: tf.Tensor): Promise<Float32Array | Int32Array | Uint8Array> {
// Binary classification
if (ys.shape[1] == 1) {
const threshold = tf.scalar(0.5)
const binaryTensor = ys.greaterEqual(threshold)
const binaryArray = await binaryTensor.data()
tf.dispose([binaryTensor, threshold])
return binaryArray
// Multi-class classification
} else {
const yIdxTensor = ys.argMax(-1)
const yIdx = await yIdxTensor.data()
tf.dispose([yIdxTensor])
return yIdx
/** infer every line of the dataset and check that it is as labeled */
async *test(dataset: TypedLabeledDataset): AsyncGenerator<boolean> {
const preprocessed = (
await labeledDatasetToData(this.task, dataset)
).preprocess();
const batched = preprocessed.batch().dataset;

const iterator = await tf.data
.zip<[tf.Tensor1D | tf.Tensor2D, number]>([
preprocessed.dataset.map((t) => {
if (
typeof t !== "object" ||
!("ys" in t) ||
!(t.ys instanceof tf.Tensor) ||
!(t.ys.rank === 1 || t.ys.rank === 2)
)
throw new Error("unexpected preprocessed dataset");
if ("xs" in t) tf.dispose(t.xs);
return t.ys;
}),
intoTFDataset(this.#inferOnBatchedData(batched)),
])
.iterator();
for (
let iter = await iterator.next();
iter.done !== true;
iter = await iterator.next()
) {
const zipped = iter.value;

const label = await getLabel(zipped[0]);
tf.dispose(zipped[0]);
const infered = zipped[1];

yield label === infered;
}
// Multi-label classification is not supported
}

// test assumes data comes with labels while predict doesn't
async *test(dataset: data.Data | TypedLabeledDataset):
AsyncGenerator<Array<{ groundTruth: number, pred: number, features: number[] }>, void> {
if (Array.isArray(dataset))
dataset = await labeledDatasetToData(this.task, dataset)
const batched = dataset
.preprocess()
.dataset.batch(this.task.trainingInformation.batchSize);

const model = await this.getModel()
let hits = 0
const iterator = await batched.iterator()
let next = await iterator.next()
while (next.done !== true) {
const { xs, ys } = next.value as { xs: tf.Tensor2D, ys: tf.Tensor3D }
const ysLabel = await this.getLabel(ys)
const yPredTensor = await model.predict(xs)
const pred = await this.getLabel(yPredTensor)
const currentFeatures = await xs.array()
this.size += ysLabel.length
hits += List(pred).zip(List(ysLabel)).filter(([p, y]) => p === y).size
this.rollingAccuracy = hits / this.size
tf.dispose([xs, ys, yPredTensor])

yield (List(ysLabel).zip(List(pred), List(currentFeatures)))
.map(([gt, p, f]) => ({ groundTruth: gt, pred: p, features: f }))
.toArray()

next = await iterator.next()
}

this.logger.success(`Obtained validation accuracy of ${this.accuracy}`)
this.logger.success(`Visited ${this.visitedSamples} samples`)
/** use the model to predict every line of the dataset */
async *infer(dataset: TypedDataset): AsyncGenerator<number, void> {
const data = await datasetToData(this.task, dataset);

const batched = data.preprocess().batch().dataset;

yield* this.#inferOnBatchedData(batched);
}

async *inference (dataset: data.Data | TypedDataset): AsyncGenerator<Array<{ features: number[], pred: number }>, void> {
if (Array.isArray(dataset))
dataset = await datasetToData(this.task, dataset)
const batched = dataset
.preprocess()
.dataset.batch(this.task.trainingInformation.batchSize);

const model = await this.getModel()
const iterator = await batched.iterator()
let next = await iterator.next()

while (next.done !== true) {
let xs: tf.Tensor2D
if (next.value instanceof tf.Tensor) {
xs = next.value as tf.Tensor2D
} else {
const tensors = (next.value as { xs: tf.Tensor2D, ys: tf.Tensor2D })
xs = tensors['xs']
tf.dispose([tensors['ys']])
async *#inferOnBatchedData(
batched: tf.data.Dataset<tf.TensorContainer>,
): AsyncGenerator<number, void> {
const iterator = await batched.iterator();
for (
let iter = await iterator.next();
iter.done !== true;
iter = await iterator.next()
) {
const row = iter.value;
if (
typeof row !== "object" ||
!("xs" in row) ||
!(row.xs instanceof tf.Tensor)
)
throw new Error("unexpected shape of dataset");

const prediction = await this.#model.predict(row.xs);
tf.dispose(row);
let predictions: number[];
switch (prediction.rank) {
case 2:
case 3:
predictions = await getLabels(
// cast as rank was just checked
prediction as tf.Tensor2D | tf.Tensor3D,
);
prediction.dispose();
break;
default:
throw new Error("unexpected batched prediction shape");
}
const currentFeatures = await xs.array()
const yPredTensor = await model.predict(xs)
const pred = await this.getLabel(yPredTensor)
this.size += pred.length
if (!Array.isArray(currentFeatures)) {
throw new TypeError('Data format is incorrect')
}
tf.dispose([xs, yPredTensor])

yield List(currentFeatures).zip(List(pred))
.map(([f, p]) => ({ features: f, pred: p }))
.toArray()

next = await iterator.next()
prediction.dispose();

for (const prediction of predictions) yield prediction;
}

this.logger.success(`Visited ${this.visitedSamples} samples`)
}
}

async getModel (): Promise<Model> {
if (this.source !== undefined && await this.memory.contains(this.source)) {
return await this.memory.getModel(this.source)
}
/** Extract one predicted label per sample of a batched prediction tensor (delegates to `getLabel`). */
async function getLabels(ys: tf.Tensor2D | tf.Tensor3D): Promise<number[]> {
// cast as unstack drop a dimension and tfjs doesn't type correctly
return Promise.all(
tf.unstack(ys).map((y) => {
// NOTE(review): disposing `y` immediately after the (unawaited) call is only
// safe if `getLabel` reads the tensor synchronously before its first await —
// appears to hold for the current `getLabel`; confirm if it changes.
const ret = getLabel(y as tf.Tensor1D | tf.Tensor2D);
y.dispose();
return ret;
}),
);
}

if (this.client !== undefined) {
return await this.client.getLatestModel()
}
async function getLabel(ys: tf.Tensor1D | tf.Tensor2D): Promise<number> {
switch (ys.rank) {
case 1: {
if (ys.shape[0] == 1) {
// Binary classification
const threshold = tf.scalar(0.5);
const binaryTensor = ys.greaterEqual(threshold);

throw new Error('Could not load the model')
}
const binaryArray = await binaryTensor.data();
tf.dispose([binaryTensor, threshold]);

get accuracy (): number {
return this.rollingAccuracy
}
return binaryArray[0];
}

get visitedSamples (): number {
return this.size
}
// Multi-class classification
const indexTensor = ys.argMax();

get confusionMatrix (): number[][] | undefined {
return this._confusionMatrix
const indexArray = await indexTensor.data();
tf.dispose([indexTensor]);

return indexArray[0];

// Multi-label classification is not supported
}
case 2: {
// it's LLM, we only extract the next token
const firstToken = tf.tidy(() => ys.gather([0]).squeeze().argMax());
const raw = await firstToken.data();
firstToken.dispose();

return raw[0];
}
default:
throw new Error("unexpected tensor rank");
}
}
Loading

0 comments on commit 58f916c

Please sign in to comment.