Fix and rework GPT-TF.js #807

Open
wants to merge 11 commits into base: develop
1 change: 1 addition & 0 deletions cli/package.json
@@ -7,6 +7,7 @@
"watch": "nodemon --ext ts --ignore dist --watch ../discojs-node/dist --watch ../server/dist --watch . --exec npm run",
"start": "npm run build && node dist/cli.js",
"benchmark_gpt": "npm run build && node dist/benchmark_gpt.js",
"train_gpt": "npm run build && node dist/train_gpt.js",
"build": "tsc",
"lint": "npx eslint .",
"test": ": nothing"
53 changes: 17 additions & 36 deletions cli/src/benchmark_gpt.ts
@@ -2,7 +2,7 @@ import { parse } from "ts-command-line-args";
import * as tf from "@tensorflow/tfjs"
import { AutoTokenizer } from "@xenova/transformers";

import { fetchTasks, models, async_iterator, defaultTasks, processing } from "@epfml/discojs";
import { fetchTasks, models, async_iterator, defaultTasks } from "@epfml/discojs";
import { loadModelFromDisk, loadText } from '@epfml/discojs-node'

import { Server } from "server";
@@ -79,43 +79,24 @@ async function main(args: Required<CLIArguments>): Promise<void> {
maxIter: iterationsPerEpoch,
blockSize: contextLength,
lr: 0.0001,
vocabSize: 50258 // default wikitext task uses the gpt2 tokenizer with vocabSize 50258
vocabSize: 50257 // default wikitext task uses the gpt2 tokenizer with vocabSize 50257
}

// Load the dataset after setting the Task batch size and max sequence length
// to make sure the dataset is batched and tokenized correctly
task.trainingInformation.batchSize = batchSize
task.trainingInformation.maxSequenceLength = contextLength
const dataset = loadText('../datasets/wikitext/wiki.train.tokens')

const maxLength = task.trainingInformation.maxSequenceLength ?? (tokenizer.model_max_length as number) + 1
// TODO will be easier when preproccessing is redone
const preprocessedDataset = intoTFGenerator(
dataset
.map((line) =>
processing.tokenizeAndLeftPad(line, tokenizer, maxLength),
)
.batch(batchSize)
.map((batch) =>
tf.tidy(() => ({
xs: tf.tensor2d(
batch.map((tokens) => tokens.slice(0, -1)).toArray(),
),
ys: tf.stack(
batch
.map(
(tokens) =>
tf.oneHot(
tokens.slice(1),
tokenizer.model.vocab.length + 1,
) as tf.Tensor2D,
)
.toArray(),
) as tf.Tensor3D,
})),
),
);

const dataset = loadText(
'../datasets/wikitext/wiki.train.tokens',
tokenizer, config.blockSize, batchSize
)
// TODO will be easier when preprocessing is redone
const preprocessedDataset = intoTFGenerator(dataset).map((tokens: number[]) => {
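// each `tokens` sequence holds blockSize + 1 tokens: the first blockSize are the
// inputs and the sequence shifted by one gives the one-hot next-token labels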
const ys = tf.oneHot(tokens.slice(1), tokenizer.model.vocab.length)
const xs = tf.tensor(tokens.slice(0, config.blockSize), undefined, 'int32')
return {xs, ys}
}).batch(batchSize) as tf.data.Dataset<{ xs: tf.Tensor2D, ys: tf.Tensor3D }>

// Init and train the model
const model = new models.GPT(config)
console.log(`\tmodel type ${modelType} \n\tbatch size ${batchSize} \n\tcontext length ${contextLength}`)
@@ -139,18 +120,18 @@ async function main(args: Required<CLIArguments>): Promise<void> {

// Benchmark parameters
const prompt = 'The game began development in 2010 , carrying over a large portion, The game began development in 2010 , carrying over a large portion, The game began development in 2010 , carrying over a large portion,'
const nbNewTokens = 200
const maxNewTokens = 200
const iterations = 10
console.log("Generating", nbNewTokens, "new tokens")
console.log("Generating", maxNewTokens, "new tokens")

let inferenceTime = 0
for (let i = 0; i < iterations; i++) {
const timeStart = performance.now()
const _ = await model.generate(prompt, tokenizer, nbNewTokens)
const _ = await model.generate(prompt, tokenizer, { maxNewTokens })
inferenceTime += performance.now() - timeStart
}
// Overall average includes tokenization, token sampling and de-tokenization
console.log(`Inference time: ${(inferenceTime/ nbNewTokens / iterations).toFixed(2)} ms/token`)
console.log(`Inference time: ${(inferenceTime/ maxNewTokens / iterations).toFixed(2)} ms/token`)
}
await new Promise((resolve, reject) => {
server.once('close', resolve)
54 changes: 54 additions & 0 deletions cli/src/train_gpt.ts
@@ -0,0 +1,54 @@
import * as tf from "@tensorflow/tfjs-node"
import { AutoTokenizer } from "@xenova/transformers";
import { models } from "@epfml/discojs";
import { loadText } from '@epfml/discojs-node'

function intoTFGenerator<T extends tf.TensorContainer>(
iter: AsyncIterable<T>,
): tf.data.Dataset<T> {
// @ts-expect-error generator
return tf.data.generator(async function* () {
yield* iter;
});
}

async function main(): Promise<void> {

const config: models.GPTConfig = {
modelType: 'gpt-nano',
lr: 0.01,
maxIter: 10,
evaluateEvery: 50,
maxEvalBatches: 10,
blockSize: 16,
vocabSize: 50257,
debug: false
}

const batchSize = 8
const tokenizer = await AutoTokenizer.from_pretrained('Xenova/gpt2')
const textDataset = loadText(
"../datasets/wikitext/wiki.train.tokens",
tokenizer, config.blockSize, batchSize
)

const tokenDataset = intoTFGenerator(textDataset).map((tokens: number[]) => {
const ys = tf.oneHot(tokens.slice(1), tokenizer.model.vocab.length)
const xs = tf.tensor(tokens.slice(0, config.blockSize), undefined, 'int32')
return {xs, ys}
}).batch(batchSize) as tf.data.Dataset<{ xs: tf.Tensor2D, ys: tf.Tensor3D }>

const model = new models.GPT(config)
for (let i = 0; i < 6; i++) {
console.log(`Epoch ${i}`)
for await (const logs of model.train(tokenDataset, undefined)) {
console.log(logs)
}
}

const generation = await model.generate("First", tokenizer, { maxNewTokens: 10, doSample: false, topk: 5, temperature: 0.1 })
console.log(generation)
}

// You can run this example with "npm run train_gpt" from this folder
main().catch(console.error)
78 changes: 75 additions & 3 deletions discojs-node/src/loaders.spec.ts
@@ -2,6 +2,8 @@ import * as fs from "node:fs/promises";
import { withFile } from "tmp-promise";
import { describe, it } from "mocha";
import { expect } from "chai";
import { models } from "@epfml/discojs";
import { AutoTokenizer } from "@xenova/transformers";

import {
loadCSV,
@@ -50,13 +52,83 @@ describe("image directory parser", () => {
});

describe("text parser", () => {

it("parses basic file", async () => {
const tokenizer = await AutoTokenizer.from_pretrained('Xenova/gpt2')
Collaborator: we also test with "Xenova/gpt2", but our code should work with any tokenizer; can you switch it to another one?

await withFile(async ({ path }) => {
await fs.writeFile(path, ["a", "b", "c"].join("\n"));
const text = ["a", "b", "c"].join("\n")
await fs.writeFile(path, text);
// set block size to 4 to get 1 sequence of 4 tokens + 1 label token
const parsed = loadText(path, tokenizer, 4, 1);
expect(await parsed.size()).to.equal(1);
const next = await parsed[Symbol.asyncIterator]().next()
expect(next.done).to.be.false;

const tokens = next.value as number[]
Collaborator (on lines +64 to +67): better to check for the whole dataset content, with Array.fromAsync (as locally redone in discojs/src/dataset/dataset.spec.ts); that's more readable and ensures that it doesn't generate more values than expected.
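
A rough sketch of that check, assuming Array.fromAsync is available in the test runtime (Node 22+ or a polyfill) and reusing the test's parsed and expectedTokens:

const sequences = await Array.fromAsync(parsed);
// comparing the whole content also ensures no extra sequences are generated
expect(sequences).to.deep.equal([expectedTokens]);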

const expectedTokens = models.tokenize(tokenizer, text, {
padding: false,
truncation: false,
return_tensor: false
})
Collaborator (on lines +68 to +72): yeah, better to test for the actual token values. It is not really an issue, but it assumes that models.tokenize will always work as expected. Also, I built test codebases with a bunch of logic testing like this and, while it looks correct, it breaks quite easily.

expect(tokens).to.deep.equal(expectedTokens);
});
});

it("yields the correct block size", async () => {
const tokenizer = await AutoTokenizer.from_pretrained('Xenova/gpt2')
await withFile(async ({ path }) => {
const text = "Lorem ipsum dolor sit amet, consectetur adipiscing elit."
const expectedTokens = models.tokenize(tokenizer, text, {
padding: false,
truncation: false,
return_tensor: false
})
await fs.writeFile(path, text);

const parsed = loadText(path);
// set block size to 4 to get 1 sequence of 4 tokens + 1 label token
// so we expect 5 tokens per read
const blockSize = 4
const parsed = loadText(path, tokenizer, blockSize, 1);
// expect the number of sequences to be the total number of tokens divided by blockSize
// we use floor because the last incomplete sequence is dropped
expect(await parsed.size()).to.equal(Math.floor(expectedTokens.length / blockSize));

let i = 0
for await (const tokens of parsed) {
// each sequence should have length blockSize + 1 (for the label)
expect(tokens).to.deep.equal(expectedTokens.slice(i, i + blockSize + 1));
Collaborator: this tests more than advertised; you can simplify it by only checking the size of each block.
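
A minimal sketch of the simplified check, reusing the test's parsed and blockSize:

for await (const tokens of parsed) {
  // each sequence should hold blockSize input tokens plus one label token
  expect(tokens).to.have.length(blockSize + 1);
}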

// but the window should move by blockSize only
i += blockSize
}
})
});

it("reads multiple chunks", async () => {
const tokenizer = await AutoTokenizer.from_pretrained('Xenova/gpt2')
await withFile(async ({ path }) => {
Collaborator: consider moving the definitions and the actions not needing the file outside of withFile (also for the previous ones). You can actually return a value out of withFile if wanted.
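
A possible shape for that, assuming tmp-promise's withFile resolves to the value returned by its callback (text, tokenizer, expectedTokens and blockSize defined outside, as in the surrounding tests):

const sequences = await withFile(async ({ path }) => {
  await fs.writeFile(path, text);
  return await Array.fromAsync(loadText(path, tokenizer, blockSize, 1));
});
expect(sequences.length).to.equal(Math.floor(expectedTokens.length / blockSize));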

const text = "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Donec sed risus maximus, ultricies ex sed, dictum elit. Curabitur faucibus egestas enim et auctor. Quisque vel dignissim turpis. Curabitur justo tellus, elementum sit amet erat eget, auctor ornare nisi. Nunc tortor odio, ultrices id leo vitae, euismod congue ex. Curabitur arcu leo, sagittis quis felis nec, imperdiet aliquet tellus. Integer a mollis nulla. Quisque pulvinar lectus eget nisi pharetra, non molestie magna ullamcorper. Sed porttitor diam non blandit molestie. Duis tristique arcu ut efficitur efficitur. Fusce et ullamcorper tortor. Pellentesque a accumsan lacus, nec mollis risus. Nunc quis eros a orci ultricies cursus. Maecenas sodales ipsum a magna malesuada efficitur. Maecenas at sapien blandit, egestas nisi eu, mollis elit."
Collaborator: please split it into multiple lines using […].join(" "), that's more readable.

const expectedTokens = models.tokenize(tokenizer, text, {
padding: false,
truncation: false,
return_tensor: false
})
await fs.writeFile(path, text);

expect(await parsed.size()).to.equal(3);
// set block size to 4 to get 1 sequence of 4 tokens + 1 label token
// so we expect 5 tokens per read
const blockSize = 4
const parsed = loadText(path, tokenizer, blockSize, 1, 1); // set the min chunk size allowed to 1 bit
// expect the number of sequences to be the total number of tokens divided by blockSize
// we use floor because the last incomplete sequence is dropped
expect(await parsed.size()).to.equal(Math.floor(expectedTokens.length / blockSize));

let i = 0
for await (const tokens of parsed) {
// each sequence should have length blockSize + 1 (for the label)
expect(tokens).to.deep.equal(expectedTokens.slice(i, i + blockSize + 1));
// but the window should move by blockSize only
i += blockSize
}
Collaborator (on lines +117 to +131): duplicated from previous test. Consider a common function.
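
One possible shape for a shared helper (a sketch; the name is illustrative, not from the PR):

async function expectOverlappingBlocks(
  parsed: AsyncIterable<number[]>, expectedTokens: number[], blockSize: number,
): Promise<void> {
  let i = 0
  for await (const tokens of parsed) {
    // each sequence holds blockSize + 1 tokens while the window advances by blockSize
    expect(tokens).to.deep.equal(expectedTokens.slice(i, i + blockSize + 1));
    i += blockSize
  }
}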

});
});
});
100 changes: 92 additions & 8 deletions discojs-node/src/loaders/text.ts
@@ -1,14 +1,98 @@
import * as fs from "node:fs/promises";
import * as readline from "node:readline/promises";
import createDebug from "debug";
import { createReadStream } from 'node:fs';

import { Dataset, Text } from "@epfml/discojs";
import { PreTrainedTokenizer } from '@xenova/transformers';
import { Dataset, Text, models } from "@epfml/discojs";

export function load(path: string): Dataset<Text> {
const debug = createDebug("discojs-node:loaders:text");

/**
* Returns a Dataset that streams and tokenizes text to yield tokenized sequences
* one at a time. Each sequence has size `blockSize` + 1, where the first `blockSize`
* tokens are the input and the last token is the label. The following sequence
* starts with the last token of the previous sequence (so the previous label is now the
* first input token).
* In other words, the dataset yields sequences of size `blockSize` + 1 but with an overlap
* of 1 token between each sequence.
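* For example, with blockSize = 4, the dataset yields tokens [t0, t1, t2, t3, t4]
* followed by [t4, t5, t6, t7, t8], and so on.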
*
* @param path path to the text file to read
* @param tokenizer the tokenizer to use, should match the model that will be trained
* @param blockSize the context length, the maximum number of tokens of input sequences
* @param batchSize defaults to 1, the number of input sequences (of `blockSize` tokens) in each batch.
Collaborator: seems to be kinda duplicated from blockSize

* The batch size is only used to configure the chunk size of the file stream such that each chunk is
* big enough to contain at least one batch.
* @param minChunkSize defaults to 16KiB, the minimum size of each chunk in bits
* @returns a dataset of tokenized input and label sequences
*/
export function load(path: string, tokenizer: PreTrainedTokenizer,
blockSize: number, batchSize: number = 1, minChunkSize = 16384): Dataset<Text> {
return new Dataset(async function* () {
const input = (await fs.open(path)).createReadStream({ encoding: "utf8" });
if (batchSize < 1 || blockSize < 1 || minChunkSize < 1)
throw new Error("batchSize, blockSize and minChunkSize must be positive integers");
// we want each chunk to be at least as big as the block size (each chunk corresponds to a block)
// (or even bigger than batch size * block size so that each chunk corresponds to a batch)
const chunkTokenSize = batchSize * (blockSize + 1) // + 1 for the next word label ys
// We budget 8 * 8 bits = 8 bytes per expected token to ensure we have enough tokens
// For reference, the GPT-2 tokenizer encodes 3 to 4 bytes per token on average
const chunkBitSize = Math.max(minChunkSize, chunkTokenSize * 8 * 8);
debug("Setting the chunk size to %o bits", chunkBitSize)
// Create a stream to read the text file chunk by chunk
const stream = createReadStream(path, {
encoding: "utf8",
highWaterMark: chunkBitSize
});

// `readline` is a bit overkill but seems standard
// https://nodejs.org/api/readline.html#example-read-file-stream-line-by-line
yield* readline.createInterface({ input, crlfDelay: Infinity });
// iterate over the chunks
let endOfPreviousChunk = ""
let iteration = 0
for await (const chunk of stream) {
if (typeof chunk !== 'string') throw new Error('Expected file stream to yield string')
debug("Reading chunk of size %o", chunk.length)
// tokenize the whole chunk at once
// Concatenate with potential leftovers from the previous chunk
const tokens = models.tokenize(tokenizer, endOfPreviousChunk + chunk, {
padding: false,
truncation: false,
return_tensor: false,
})
if (tokens.length < blockSize + 1) {
// throw if it happens on the 1st iteration
if (iteration === 0)
throw new Error(`the chunk (${tokens.length} tokens) is too small ` +
`to get a sequence of blockSize + 1 (${blockSize + 1}) tokens. ` +
`Either the text file or the chunk size (${chunkBitSize} bits) is too small.`);
// if this isn't the first iteration we simply skip
// as we expect the last chunk to be potentially smaller than the block size
debug("chunk smaller than block size, loading next chunk")
continue
}
debug("batch per chunk: %o", tokens.length / (batchSize * blockSize))
let currentPosition = 0;
// yield one block of tokens at a time
while (currentPosition + blockSize + 1 <= tokens.length) {
yield tokens.slice(currentPosition, currentPosition + blockSize + 1);
currentPosition += blockSize; // don't add + 1: consecutive sequences overlap by one token (the label becomes the next input's first token)
}
// keep the last tokens for the next chunk
// if this was the last one the remaining tokens are discarded
if (currentPosition < tokens.length) {
// We actually need to decode the tokens to get the leftover text
// instead of simply keeping the remaining tokens.
// this is because the tokens may be different once prepended to the next chunk
// e.g. if the remaining text is ". A" and the next chunk starts with "nother"
// the tokenization will be different than if we simply concatenate the remaining tokens
endOfPreviousChunk = tokenizer.decode(
tokens.slice(currentPosition),
{ skip_special_tokens: true }
)
debug("End of chunk, remaining text: '%s'", endOfPreviousChunk)
} else {
// Note that a difference between tokenizing each chunk separately
// vs concatenating and then tokenizing can still occur when there is no
// remaining text to carry over. We consider this difference negligible
endOfPreviousChunk = "";
}
Collaborator (on lines +53 to +94): duplicated in discojs-web/loaders/text. That hints to me that it shouldn't happen in the loader but be applied after.
The issue at hand is that lines were output by the previous version. I think we can change it to output characters (single-letter strings) instead. That would also drop the blockSize, batchSize & minChunkSize arguments, which aren't really relevant for reading text (separation of concerns and all that).

In the newly merged processing PR (#781), it is much simpler to combine such a transformation; I think something like

loadText($path).batch($blockSize).map((block) => tokenize(block, $tokenizer))

with tokenize updated to accept a block/List<string> instead, and maybe drop the padding (but what would be the behavior at the end of the file?)
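
Spelled out, that suggestion might look something like the following (a sketch of a hypothetical API: loadText yielding single characters and tokenize accepting a block of strings, neither of which exists in this PR):

const tokenized = loadText(path)                  // Dataset of single-character strings
  .batch(blockSize)                               // Dataset of blocks of blockSize characters
  .map((block) => tokenize(block, tokenizer))     // Dataset of token sequences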

iteration++;
}
});
}