diff --git a/discojs/discojs-core/src/validation/validator.ts b/discojs/discojs-core/src/validation/validator.ts index 1d5181ede..d6e51e871 100644 --- a/discojs/discojs-core/src/validation/validator.ts +++ b/discojs/discojs-core/src/validation/validator.ts @@ -41,36 +41,36 @@ export class Validator { let hits = 0 await data.preprocess().dataset.batch(batchSize) - .forEachAsync((e) => { - if (typeof e === 'object' && 'xs' in e && 'ys' in e) { - const xs = e.xs as tf.Tensor + .forEachAsync((e) => { + if (typeof e === 'object' && 'xs' in e && 'ys' in e) { + const xs = e.xs as tf.Tensor - const ys = this.getLabel(e.ys as tf.Tensor) - const pred = this.getLabel(model.predict(xs, { batchSize }) as tf.Tensor) + const ys = this.getLabel(e.ys as tf.Tensor) + const pred = this.getLabel(model.predict(xs, { batchSize }) as tf.Tensor) - const currentFeatures = xs.arraySync() + const currentFeatures = xs.arraySync() - if (Array.isArray(currentFeatures)) { - features = features.concat(currentFeatures) - } else { - throw new TypeError('features array is not correct') - } + if (Array.isArray(currentFeatures)) { + features = features.concat(currentFeatures) + } else { + throw new TypeError('features array is not correct') + } - groundTruth.push(...Array.from(ys)) - predictions.push(...Array.from(pred)) + groundTruth.push(...Array.from(ys)) + predictions.push(...Array.from(pred)) - this.size += xs.shape[0] + this.size += xs.shape[0] - hits += List(pred).zip(List(ys)).filter(([p, y]) => p === y).size + hits += List(pred).zip(List(ys)).filter(([p, y]) => p === y).size - // TODO: Confusion Matrix stats + // TODO: Confusion Matrix stats - const currentAccuracy = hits / this.size - this.graphInformant.updateAccuracy(currentAccuracy) - } else { - throw new Error('missing feature/label in dataset') - } - }) + const currentAccuracy = hits / this.size + this.graphInformant.updateAccuracy(currentAccuracy) + } else { + throw new Error('missing feature/label in dataset') + } + }) this.logger.success(`Obtained validation accuracy of ${this.accuracy}`) this.logger.success(`Visited ${this.visitedSamples} samples`) @@ -105,20 +105,20 @@ export class Validator { await data.preprocess().dataset.batch(batchSize) .forEachAsync(e => { - const xs = e as tf.Tensor - const currentFeatures = xs.arraySync() - - if (Array.isArray(currentFeatures)) { - features = features.concat(currentFeatures) - } else { - throw new TypeError('features array is not correct') - } + const xs = e as tf.Tensor + const currentFeatures = xs.arraySync() - const pred = this.getLabel(model.predict(xs, { batchSize }) as tf.Tensor) - predictions.push(...Array.from(pred)) + if (Array.isArray(currentFeatures)) { + features = features.concat(currentFeatures) + } else { + throw new TypeError('features array is not correct') + } + + const pred = this.getLabel(model.predict(xs, { batchSize }) as tf.Tensor) + predictions.push(...Array.from(pred)) }) - return (List(features).zip(List(predictions)) as List<[Features, number]>) + return List(features).zip(List(predictions)) .map(([f, p]) => ({ features: f, pred: p })) .toArray() }