Skip to content

Commit

Permalink
discojs: add & simplify some types
Browse files Browse the repository at this point in the history
  • Loading branch information
tharvik committed Aug 9, 2024
1 parent b468a37 commit 9e9ca35
Show file tree
Hide file tree
Showing 17 changed files with 72 additions and 91 deletions.
7 changes: 6 additions & 1 deletion cli/src/benchmark_gpt.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { parse } from "ts-command-line-args";
import type * as tf from "@tensorflow/tfjs"

import type { Task } from '@epfml/discojs'
import { fetchTasks, data, models, async_iterator, defaultTasks } from "@epfml/discojs";
Expand Down Expand Up @@ -74,7 +75,11 @@ async function main(args: Required<CLIArguments>): Promise<void> {
task.trainingInformation.batchSize = batchSize
task.trainingInformation.maxSequenceLength = contextLength
const dataset = await loadWikitextData(task)
const preprocessedDataset = dataset.train.preprocess().batch().dataset
const preprocessedDataset = dataset.train.preprocess().batch()
.dataset as tf.data.Dataset<{
xs: tf.Tensor2D;
ys: tf.Tensor3D;
}>;

// Init and train the model
const model = new models.GPT(config)
Expand Down
8 changes: 4 additions & 4 deletions discojs-web/src/memory/memory.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@
import { Map } from 'immutable'
import * as tf from '@tensorflow/tfjs'

import type { Path, Model, ModelInfo, ModelSource } from '@epfml/discojs'
import type { Model, ModelInfo, ModelSource } from '@epfml/discojs'
import { Memory, models } from '@epfml/discojs'

export class IndexedDB extends Memory {
override getModelMemoryPath (source: ModelSource): Path {
override getModelMemoryPath (source: ModelSource): string {
if (typeof source === 'string') {
return source
}
Expand Down Expand Up @@ -110,7 +110,7 @@ export class IndexedDB extends Memory {
* Creates a saved copy of the working model corresponding to the source.
* @param source the source
*/
async saveWorkingModel (source: ModelSource): Promise<Path> {
async saveWorkingModel (source: ModelSource): Promise<string> {
const src: ModelInfo = this.getModelInfo(source)
if (src.type !== 'working') {
throw new Error('expected working type model')
Expand All @@ -123,7 +123,7 @@ export class IndexedDB extends Memory {
return dst
}

override async saveModel (source: ModelSource, model: Model): Promise<Path> {
override async saveModel (source: ModelSource, model: Model): Promise<string> {
const src: ModelInfo = this.getModelInfo(source)
if (src.type !== 'saved') {
throw new Error('expected saved type model')
Expand Down
9 changes: 5 additions & 4 deletions discojs/src/dataset/data/preprocessing/text_preprocessing.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ interface TokenizedEntry extends tf.TensorContainerObject {
*/
const leftPadding: PreprocessingFunction = {
type: TextPreprocessing.LeftPadding,
apply: async (x: Promise<tf.TensorContainer>, task: Task): Promise<tf.TensorContainer> => {
apply: async (x: Promise<tf.TensorContainer>, task: Task): Promise<{ xs: tf.Tensor1D, ys: tf.Tensor2D }> => {
if (x === undefined || !Array.isArray(x) || x.length == 0 || typeof(x[0] !== 'number')) {
new Error("The leftPadding preprocessing expects a non empty 1D array of number")
}
Expand All @@ -44,7 +44,7 @@ const leftPadding: PreprocessingFunction = {
const maxLength = task.trainingInformation.maxSequenceLength ?? tokenizer.model_max_length as number
const maxLengthPlusLabel = maxLength + 1

let fixedLengthTokens = tf.tensor(tokens, undefined, 'int32') // cast tokens from float to int for gpt-tfjs
let fixedLengthTokens = tf.tensor1d(tokens, 'int32') // cast tokens from float to int for gpt-tfjs
if (fixedLengthTokens.size > maxLengthPlusLabel) { // Should never happen because tokenization truncates inputs
throw Error("There are more tokens than expected after tokenization and truncation")
} else if (fixedLengthTokens.size < maxLengthPlusLabel) { // Pad inputs to fixed length
Expand All @@ -54,7 +54,8 @@ const leftPadding: PreprocessingFunction = {
// if tokens.size == maxLengthPlusLabel we can leave it as it is

// ys is a one-hot encoding of the next token (i.e. xs shifted by one)
const ys = tf.oneHot(fixedLengthTokens.slice([1]), tokenizer.model.vocab.length + 1)
// cast because oneHot isn't size-typing its return value
const ys = tf.oneHot(fixedLengthTokens.slice([1]), tokenizer.model.vocab.length + 1) as tf.Tensor2D
// remove the extra token now that ys is created
const xs = fixedLengthTokens.slice([0], maxLength)
return { xs, ys }
Expand All @@ -71,7 +72,7 @@ interface TokenizerOutput {
*/
const tokenize: PreprocessingFunction = {
type: TextPreprocessing.Tokenize,
apply: async (x: Promise<tf.TensorContainer>, task: Task): Promise<tf.TensorContainer> => {
apply: async (x: Promise<tf.TensorContainer>, task: Task): Promise<{ tokens: number[] }> => {
if (typeof x !== 'string') {
new Error("The tokenize preprocessing expects a string as input")
}
Expand Down
3 changes: 1 addition & 2 deletions discojs/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ export * as aggregator from './aggregator/index.js'

export { WeightsContainer, aggregation } from './weights/index.js'
export { Logger, ConsoleLogger } from './logging/index.js'
export { Memory, type ModelInfo, type Path, type ModelSource, Empty as EmptyMemory } from './memory/index.js'
export { Memory, type ModelInfo, type ModelSource, Empty as EmptyMemory } from './memory/index.js'
export { Disco, RoundLogs } from './training/index.js'
export { Validator } from './validation/index.js'

Expand All @@ -18,5 +18,4 @@ export * as models from './models/index.js'
export * from './task/index.js'
export * as defaultTasks from './default_tasks/index.js'

export * from './types.js'
export * as async_iterator from "./utils/async_iterator.js"
15 changes: 5 additions & 10 deletions discojs/src/memory/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,6 @@

import type { Model, TaskID } from '../index.js'

/**
* Model path which uniquely identifies a model in memory.
*/
export type Path = string

/**
* Type of models stored in memory. Stored models can either be a model currently
* being trained ("working model") or a regular model saved in memory ("saved model").
Expand All @@ -34,10 +29,10 @@ export interface ModelInfo {

/**
* A model source uniquely identifies a model stored in memory.
* It can be in the form of either a model info object or a Path string
* It can be in the form of either a model info object or an ID
* (one-to-one mapping between the two)
*/
export type ModelSource = ModelInfo | Path
export type ModelSource = ModelInfo | string

/**
* Represents a model memory system, providing functions to fetch, save, delete and update models.
Expand Down Expand Up @@ -86,7 +81,7 @@ export abstract class Memory {
* @param source The model source
* @returns The saved model's path
*/
abstract saveWorkingModel (source: ModelSource): Promise<Path | undefined>
abstract saveWorkingModel (source: ModelSource): Promise<string | undefined>

/**
* Saves the newly provided model to the given model source.
Expand All @@ -95,7 +90,7 @@ export abstract class Memory {
* @param model The new model
* @returns The saved model's path
*/
abstract saveModel (source: ModelSource, model: Model): Promise<Path | undefined>
abstract saveModel (source: ModelSource, model: Model): Promise<string | undefined>

/**
* Moves the model identified by the model source to a file system. This is platform-dependent.
Expand All @@ -118,7 +113,7 @@ export abstract class Memory {
* @param source The model source
* @returns The model path
*/
abstract getModelMemoryPath (source: ModelSource): Path | undefined
abstract getModelMemoryPath (source: ModelSource): string | undefined

/**
* Computes the model information corresponding to the given model source, be it a path or model information.
Expand Down
4 changes: 2 additions & 2 deletions discojs/src/memory/empty.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import type { Model } from '../index.js'

import type { ModelInfo, Path } from './base.js'
import type { ModelInfo } from './base.js'
import { Memory } from './base.js'

/**
Expand Down Expand Up @@ -44,7 +44,7 @@ export class Empty extends Memory {
return Promise.reject(new Error('empty'))
}

getModelMemoryPath (): Path {
getModelMemoryPath (): string {
throw new Error('empty')
}

Expand Down
2 changes: 1 addition & 1 deletion discojs/src/memory/index.ts
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
export { Empty } from './empty.js'
export { Memory, type ModelInfo, type Path, type ModelSource } from './base.js'
export { Memory, type ModelInfo, type ModelSource } from './base.js'
2 changes: 1 addition & 1 deletion discojs/src/models/gpt/gpt.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ describe('gpt-tfjs', function() {
const ys = tf.oneHot(tokens.slice(1), tokenizer.model.vocab.length + 1)
const xs = tf.tensor(tokens.slice(0, config.blockSize), undefined, 'int32')
return {xs, ys}
}).repeat().batch(64)
}).repeat().batch(64) as tf.data.Dataset<{ xs: tf.Tensor2D, ys: tf.Tensor3D }>

const model = new GPT(config)
for (let i = 0; i < 5; i++)
Expand Down
4 changes: 2 additions & 2 deletions discojs/src/models/gpt/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ export class GPT extends Model {
* @param tracker
*/
override async *train(
trainingData: Dataset,
validationData?: Dataset,
trainingData: tf.data.Dataset<{ xs: tf.Tensor2D, ys: tf.Tensor3D }>,
validationData?: tf.data.Dataset<{ xs: tf.Tensor2D, ys: tf.Tensor3D }>,
): AsyncGenerator<BatchLogs, EpochLogs> {
this.model.compile();

Expand Down
11 changes: 0 additions & 11 deletions discojs/src/types.ts

This file was deleted.

12 changes: 6 additions & 6 deletions discojs/src/validation/validator.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { List } from 'immutable'
import * as tf from '@tensorflow/tfjs'

import type { data, Model, Task, Logger, client as clients, Memory, ModelSource, Features } from '../index.js'
import type { data, Model, Task, Logger, client as clients, Memory, ModelSource } from '../index.js'

export class Validator {
private size = 0
Expand Down Expand Up @@ -40,7 +40,7 @@ export class Validator {

// test assumes data comes with labels while predict doesn't
async *test(data: data.Data):
AsyncGenerator<Array<{ groundTruth: number, pred: number, features: Features }>, void> {
AsyncGenerator<Array<{ groundTruth: number, pred: number, features: number[] }>, void> {
const batchSize = this.task.trainingInformation?.batchSize
if (batchSize === undefined) {
throw new TypeError('Batch size is undefined')
Expand All @@ -60,7 +60,7 @@ export class Validator {
this.rollingAccuracy = hits / this.size
tf.dispose([xs, ys, yPredTensor])

yield (List(ysLabel).zip(List(pred), List(currentFeatures)) as List<[number, number, Features]>)
yield (List(ysLabel).zip(List(pred), List(currentFeatures)))
.map(([gt, p, f]) => ({ groundTruth: gt, pred: p, features: f }))
.toArray()

Expand All @@ -71,7 +71,7 @@ export class Validator {
this.logger.success(`Visited ${this.visitedSamples} samples`)
}

async *inference (data: data.Data): AsyncGenerator<Array<{ features: Features, pred: number }>, void> {
async *inference (data: data.Data): AsyncGenerator<Array<{ features: number[], pred: number }>, void> {
const batchSize = this.task.trainingInformation?.batchSize
if (batchSize === undefined) {
throw new TypeError('Batch size is undefined')
Expand All @@ -82,7 +82,7 @@ export class Validator {
let next = await iterator.next()

while (next.done !== true) {
let xs: tf.Tensor
let xs: tf.Tensor2D
if (next.value instanceof tf.Tensor) {
xs = next.value as tf.Tensor2D
} else {
Expand All @@ -99,7 +99,7 @@ export class Validator {
}
tf.dispose([xs, yPredTensor])

yield List(currentFeatures as number[]).zip(List(pred))
yield List(currentFeatures).zip(List(pred))
.map(([f, p]) => ({ features: f, pred: p }))
.toArray()

Expand Down
4 changes: 2 additions & 2 deletions server/src/tasks.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import fs from 'node:fs/promises'
import tf from '@tensorflow/tfjs'
import '@tensorflow/tfjs-node'

import { Task, Path, Digest, TaskProvider, isTask } from '@epfml/discojs'
import { Task, Digest, TaskProvider, isTask } from '@epfml/discojs'
import { Model, models, serialization } from '@epfml/discojs'

const debug = createDebug("server:tasks");
Expand Down Expand Up @@ -63,7 +63,7 @@ export class TasksAndModels {
return model
}

private async checkDigest (digest: Digest, modelPath: Path): Promise<void> {
private async checkDigest (digest: Digest, modelPath: string): Promise<void> {
const hash = createHash(digest.algorithm)
const modelConfigRaw = await fs.readFile(`${modelPath}/model.json`)

Expand Down
33 changes: 14 additions & 19 deletions webapp/src/components/sidebar/ModelLibrary.vue
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@
</span>
<div v-if="memoryStore.useIndexedDB" class="space-y-4">
<button
v-for="[path, metadata] in memoryStore.models"
:key="path"
v-for="[id, metadata] in memoryStore.models"
:key="id"
class="flex items-center justify-between px-4 py-2 space-x-4 outline outline-1 outline-slate-300 rounded-md transition-colors duration-200 text-slate-600 hover:text-slate-800 hover:outline-slate-800 focus:outline-none focus:ring-1 focus:ring-slate-800"
>
<div class="cursor-pointer w-2/3" @click="openTesting(path)">
<div class="cursor-pointer w-2/3" @click="openTesting(id)">
<span>
{{ metadata.name.slice(0, 16) }}
<span
Expand All @@ -48,7 +48,7 @@
<ModelButton
event="delete-model"
hover="Delete"
@delete-model="deleteModel(path)"
@delete-model="deleteModel(id)"
>
<Bin2Icon />
</ModelButton>
Expand All @@ -57,7 +57,7 @@
<ModelButton
event="download-model"
hover="Download"
@download-model="downloadModel(path)"
@download-model="memory.downloadModel(id)"
>
<Download2Icon />
</ModelButton>
Expand All @@ -66,7 +66,7 @@
<ModelButton
event="load-model"
hover="Load for next training"
@load-model="loadModel(path)"
@load-model="loadModel(id)"
>
<LoadIcon />
</ModelButton>
Expand Down Expand Up @@ -126,7 +126,6 @@ import { List } from "immutable";
import { onMounted, ref } from "vue";
import { useRouter } from "vue-router";
import type { Path } from "@epfml/discojs";
import { EmptyMemory } from "@epfml/discojs";
import { IndexedDB } from "@epfml/discojs-web";
Expand Down Expand Up @@ -172,19 +171,19 @@ function switchToEvaluate(): void {
router.push({ path: "/evaluate" });
emit("close-panel");
}
function openTesting(path: Path) {
validationStore.setModel(path);
function openTesting(modelID: string) {
validationStore.setModel(modelID);
router.push({ path: "/evaluate" });
}
async function loadModel(path: Path) {
const modelInfo = memory.getModelInfo(path)
async function loadModel(modelID: string) {
const modelInfo = memory.getModelInfo(modelID)
if (modelInfo === undefined) {
throw new Error('not such model')
}
if (modelInfo.type !== 'working') {
try {
await memory.loadModel(path)
await memory.loadModel(modelID)
await memoryStore.initModels()
} catch (e) {
let msg = 'unable to load model'
Expand All @@ -199,14 +198,10 @@ function openTesting(path: Path) {
}
}
async function downloadModel(path: Path) {
await memory.downloadModel(path);
}
async function deleteModel(path: Path): Promise<void> {
async function deleteModel(modelID: string): Promise<void> {
try {
await memoryStore.deleteModel(path);
await memory.deleteModel(path);
await memoryStore.deleteModel(modelID);
await memory.deleteModel(modelID);
toaster.success("Successfully deleted the model");
} catch (e) {
let msg = "unable to delete model";
Expand Down
Loading

0 comments on commit 9e9ca35

Please sign in to comment.