-
Notifications
You must be signed in to change notification settings - Fork 26
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
7 changed files
with
601 additions
and
601 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,149 +1,165 @@ | ||
import { List } from 'immutable' | ||
import * as tf from '@tensorflow/tfjs' | ||
import * as tf from "@tensorflow/tfjs"; | ||
|
||
import type { | ||
Model, | ||
Task, | ||
Logger, | ||
client as clients, | ||
data, | ||
Memory, | ||
ModelSource, | ||
TypedDataset, | ||
TypedLabeledDataset, | ||
} from "../index.js"; | ||
import { datasetToData, labeledDatasetToData } from "../dataset/data/helpers.js"; | ||
import { | ||
datasetToData, | ||
labeledDatasetToData, | ||
} from "../dataset/data/helpers.js"; | ||
|
||
function intoTFDataset<T extends tf.TensorContainer>( | ||
iter: AsyncIterable<T>, | ||
): tf.data.Dataset<T> { | ||
// @ts-expect-error generator | ||
return tf.data.generator(async function* () { | ||
yield* iter; | ||
}); | ||
} | ||
|
||
export class Validator { | ||
private size = 0 | ||
private _confusionMatrix: number[][] | undefined | ||
private rollingAccuracy = 0 | ||
readonly #model: Model; | ||
|
||
constructor ( | ||
constructor( | ||
public readonly task: Task, | ||
public readonly logger: Logger, | ||
private readonly memory: Memory, | ||
private readonly source?: ModelSource, | ||
private readonly client?: clients.Client | ||
model: Model, | ||
) { | ||
if (source === undefined && client === undefined) { | ||
throw new Error('To initialize a Validator, either or both a source and client need to be specified') | ||
} | ||
this.#model = model; | ||
} | ||
|
||
private async getLabel(ys: tf.Tensor): Promise<Float32Array | Int32Array | Uint8Array> { | ||
// Binary classification | ||
if (ys.shape[1] == 1) { | ||
const threshold = tf.scalar(0.5) | ||
const binaryTensor = ys.greaterEqual(threshold) | ||
const binaryArray = await binaryTensor.data() | ||
tf.dispose([binaryTensor, threshold]) | ||
return binaryArray | ||
// Multi-class classification | ||
} else { | ||
const yIdxTensor = ys.argMax(-1) | ||
const yIdx = await yIdxTensor.data() | ||
tf.dispose([yIdxTensor]) | ||
return yIdx | ||
/** infer every line of the dataset and check that it is as labeled */ | ||
async *test(dataset: TypedLabeledDataset): AsyncGenerator<boolean> { | ||
const preprocessed = ( | ||
await labeledDatasetToData(this.task, dataset) | ||
).preprocess(); | ||
const batched = preprocessed.batch().dataset; | ||
|
||
const iterator = await tf.data | ||
.zip<[tf.Tensor1D | tf.Tensor2D, number]>([ | ||
preprocessed.dataset.map((t) => { | ||
if ( | ||
typeof t !== "object" || | ||
!("ys" in t) || | ||
!(t.ys instanceof tf.Tensor) || | ||
!(t.ys.rank === 1 || t.ys.rank === 2) | ||
) | ||
throw new Error("unexpected preprocessed dataset"); | ||
if ("xs" in t) tf.dispose(t.xs); | ||
return t.ys; | ||
}), | ||
intoTFDataset(this.#inferOnBatchedData(batched)), | ||
]) | ||
.iterator(); | ||
for ( | ||
let iter = await iterator.next(); | ||
iter.done !== true; | ||
iter = await iterator.next() | ||
) { | ||
const zipped = iter.value; | ||
|
||
const label = await getLabel(zipped[0]); | ||
tf.dispose(zipped[0]); | ||
const infered = zipped[1]; | ||
|
||
yield label === infered; | ||
} | ||
// Multi-label classification is not supported | ||
} | ||
|
||
// test assumes data comes with labels while predict doesn't | ||
async *test(dataset: data.Data | TypedLabeledDataset): | ||
AsyncGenerator<Array<{ groundTruth: number, pred: number, features: number[] }>, void> { | ||
if (Array.isArray(dataset)) | ||
dataset = await labeledDatasetToData(this.task, dataset) | ||
const batched = dataset | ||
.preprocess() | ||
.dataset.batch(this.task.trainingInformation.batchSize); | ||
|
||
const model = await this.getModel() | ||
let hits = 0 | ||
const iterator = await batched.iterator() | ||
let next = await iterator.next() | ||
while (next.done !== true) { | ||
const { xs, ys } = next.value as { xs: tf.Tensor2D, ys: tf.Tensor3D } | ||
const ysLabel = await this.getLabel(ys) | ||
const yPredTensor = await model.predict(xs) | ||
const pred = await this.getLabel(yPredTensor) | ||
const currentFeatures = await xs.array() | ||
this.size += ysLabel.length | ||
hits += List(pred).zip(List(ysLabel)).filter(([p, y]) => p === y).size | ||
this.rollingAccuracy = hits / this.size | ||
tf.dispose([xs, ys, yPredTensor]) | ||
|
||
yield (List(ysLabel).zip(List(pred), List(currentFeatures))) | ||
.map(([gt, p, f]) => ({ groundTruth: gt, pred: p, features: f })) | ||
.toArray() | ||
|
||
next = await iterator.next() | ||
} | ||
|
||
this.logger.success(`Obtained validation accuracy of ${this.accuracy}`) | ||
this.logger.success(`Visited ${this.visitedSamples} samples`) | ||
/** use the model to predict every line of the dataset */ | ||
async *infer(dataset: TypedDataset): AsyncGenerator<number, void> { | ||
const data = await datasetToData(this.task, dataset); | ||
|
||
const batched = data.preprocess().batch().dataset; | ||
|
||
yield* this.#inferOnBatchedData(batched); | ||
} | ||
|
||
async *inference (dataset: data.Data | TypedDataset): AsyncGenerator<Array<{ features: number[], pred: number }>, void> { | ||
if (Array.isArray(dataset)) | ||
dataset = await datasetToData(this.task, dataset) | ||
const batched = dataset | ||
.preprocess() | ||
.dataset.batch(this.task.trainingInformation.batchSize); | ||
|
||
const model = await this.getModel() | ||
const iterator = await batched.iterator() | ||
let next = await iterator.next() | ||
|
||
while (next.done !== true) { | ||
let xs: tf.Tensor2D | ||
if (next.value instanceof tf.Tensor) { | ||
xs = next.value as tf.Tensor2D | ||
} else { | ||
const tensors = (next.value as { xs: tf.Tensor2D, ys: tf.Tensor2D }) | ||
xs = tensors['xs'] | ||
tf.dispose([tensors['ys']]) | ||
async *#inferOnBatchedData( | ||
batched: tf.data.Dataset<tf.TensorContainer>, | ||
): AsyncGenerator<number, void> { | ||
const iterator = await batched.iterator(); | ||
for ( | ||
let iter = await iterator.next(); | ||
iter.done !== true; | ||
iter = await iterator.next() | ||
) { | ||
const row = iter.value; | ||
if ( | ||
typeof row !== "object" || | ||
!("xs" in row) || | ||
!(row.xs instanceof tf.Tensor) | ||
) | ||
throw new Error("unexpected shape of dataset"); | ||
|
||
const prediction = await this.#model.predict(row.xs); | ||
tf.dispose(row); | ||
let predictions: number[]; | ||
switch (prediction.rank) { | ||
case 2: | ||
case 3: | ||
predictions = await getLabels( | ||
// cast as rank was just checked | ||
prediction as tf.Tensor2D | tf.Tensor3D, | ||
); | ||
prediction.dispose(); | ||
break; | ||
default: | ||
throw new Error("unexpected batched prediction shape"); | ||
} | ||
const currentFeatures = await xs.array() | ||
const yPredTensor = await model.predict(xs) | ||
const pred = await this.getLabel(yPredTensor) | ||
this.size += pred.length | ||
if (!Array.isArray(currentFeatures)) { | ||
throw new TypeError('Data format is incorrect') | ||
} | ||
tf.dispose([xs, yPredTensor]) | ||
|
||
yield List(currentFeatures).zip(List(pred)) | ||
.map(([f, p]) => ({ features: f, pred: p })) | ||
.toArray() | ||
|
||
next = await iterator.next() | ||
prediction.dispose(); | ||
|
||
for (const prediction of predictions) yield prediction; | ||
} | ||
|
||
this.logger.success(`Visited ${this.visitedSamples} samples`) | ||
} | ||
} | ||
|
||
async getModel (): Promise<Model> { | ||
if (this.source !== undefined && await this.memory.contains(this.source)) { | ||
return await this.memory.getModel(this.source) | ||
} | ||
async function getLabels(ys: tf.Tensor2D | tf.Tensor3D): Promise<number[]> { | ||
// cast as unstack drop a dimension and tfjs doesn't type correctly | ||
return Promise.all( | ||
tf.unstack(ys).map((y) => { | ||
const ret = getLabel(y as tf.Tensor1D | tf.Tensor2D); | ||
y.dispose(); | ||
return ret; | ||
}), | ||
); | ||
} | ||
|
||
if (this.client !== undefined) { | ||
return await this.client.getLatestModel() | ||
} | ||
async function getLabel(ys: tf.Tensor1D | tf.Tensor2D): Promise<number> { | ||
switch (ys.rank) { | ||
case 1: { | ||
if (ys.shape[0] == 1) { | ||
// Binary classification | ||
const threshold = tf.scalar(0.5); | ||
const binaryTensor = ys.greaterEqual(threshold); | ||
|
||
throw new Error('Could not load the model') | ||
} | ||
const binaryArray = await binaryTensor.data(); | ||
tf.dispose([binaryTensor, threshold]); | ||
|
||
get accuracy (): number { | ||
return this.rollingAccuracy | ||
} | ||
return binaryArray[0]; | ||
} | ||
|
||
get visitedSamples (): number { | ||
return this.size | ||
} | ||
// Multi-class classification | ||
const indexTensor = ys.argMax(); | ||
|
||
get confusionMatrix (): number[][] | undefined { | ||
return this._confusionMatrix | ||
const indexArray = await indexTensor.data(); | ||
tf.dispose([indexTensor]); | ||
|
||
return indexArray[0]; | ||
|
||
// Multi-label classification is not supported | ||
} | ||
case 2: { | ||
// it's LLM, we only extract the next token | ||
const firstToken = tf.tidy(() => ys.gather([0]).squeeze().argMax()); | ||
const raw = await firstToken.data(); | ||
firstToken.dispose(); | ||
|
||
return raw[0]; | ||
} | ||
default: | ||
throw new Error("unexpected tensor rank"); | ||
} | ||
} |
Oops, something went wrong.