-
Notifications
You must be signed in to change notification settings - Fork 26
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added llm-example/ containing an example for running disco on the wik…
…itext task. Provided downloading dataset scripts for wikitext and tiny-shakespeare (may be supported later) and documentation in the README
- Loading branch information
1 parent
a0be7af
commit 6c00b3c
Showing
13 changed files
with
1,064 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
{ | ||
"extends": "standard-with-typescript", | ||
"parserOptions": { | ||
"project": "./tsconfig.eslint.json" | ||
}, | ||
"env": { | ||
"node": true, | ||
"es6": true, | ||
"mocha": true | ||
}, | ||
"rules": { | ||
// TODO need strict | ||
"@typescript-eslint/strict-boolean-expressions": "off" | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,166 @@ | ||
# ================================= | ||
|
||
### MOVE THE LLM-EXAMPLE/ FOLDER TO THE ROOT PROJECT FOLDER IN ORDER FOR THIS TO WORK | ||
|
||
# ================================= | ||
|
||
# Prerequisites | ||
|
||
- nvm: https://github.com/nvm-sh/nvm#installing-and-updating | ||
- Bun: https://github.com/oven-sh/bun | ||
|
||
# Installation | ||
|
||
### Required | ||
|
||
The following is always needed | ||
|
||
```sh | ||
# install dependencies in discojs-core | ||
cd discojs/discojs-core/ | ||
bun install | ||
|
||
# install dependencies in discojs-node | ||
cd ../discojs-node/ | ||
bun install | ||
|
||
# install dependencies in server | ||
cd ../../server/ | ||
bun install | ||
|
||
# install dependencies in experiment and download + preprocess dataset | ||
cd ../experiment | ||
bun install | ||
|
||
# Dataset installation, feel free to install only one dataset or both, you can later choose which one to preprocess / train on | ||
./install-wikitext.sh # Installs wikitext-103 dataset | ||
./install-shakespeare.sh # Installs tiny-shakespeare dataset | ||
bun preprocess.ts [wikitext-103 | tiny-shakespeare] | ||
``` | ||
|
||
### Running on Node | ||
|
||
```sh | ||
cd experiment/ | ||
bun main.ts [wikitext-103 | tiny-shakespeare] | ||
``` | ||
|
||
### Running on a browser | ||
|
||
```sh | ||
# install dependencies in discojs-web | ||
cd ../discojs/discojs-web/ | ||
bun install | ||
|
||
# install dependencies for the browser server and run it | ||
cd ../../browser/server | ||
bun install | ||
bun run dev | ||
|
||
# [in a separate terminal] install dependencies for the browser server and run it | ||
cd ../client | ||
nvm use 18 # Node version 18.x or later is required for NextJS | ||
bun install | ||
bun run dev | ||
|
||
# Navigate to http://localhost:3000 on your browser of choice and click on "train" | ||
# If you would like to use WebGPU then firefox won't work, please run the following command to run chrome with WebGPU enabled | ||
# (I advise to run this command in a separate terminal tab as well because you will have logs even in detach mode) | ||
google-chrome --enable-unsafe-webgpu --enable-features=Vulkan,UseSkiaRenderer & | ||
# Or from the browser/client/ directory | ||
./chrome-webgpu.sh # equivalent to command above | ||
``` | ||
|
||
# Running tests | ||
|
||
To run tests, you first need to follow the "Required" section of the Installation instructions. | ||
|
||
### Testing on Node | ||
|
||
```sh | ||
# Follow the instructions "# Running on node" before proceeding | ||
cd discojs/discojs-node/ | ||
bun --bun test text-loader.spec.ts | ||
``` | ||
|
||
### Testing on a "simulated" browser | ||
|
||
```sh | ||
# Since the following will test the web version, | ||
# the websocket server needs to be running. | ||
# Follow the 2 first steps of the installation instructions "# Running on a browser" | ||
# before proceeding | ||
cd browser/server/ | ||
bun --bun socket.ts | ||
|
||
# In a new terminal tab | ||
cd discojs/discojs-web/ | ||
bun --bun test text_loader.spec.ts | ||
``` | ||
|
||
# Benchmarks | ||
|
||
## Text Loader | ||
|
||
Benchmarking of the text loader is done via iterating 1000 times over the dataset and taking the average time is ms. The vocabulary size is set to 50257. We vary the batch and block sizes and report the results here. | ||
Tests run on an AMD Ryzen 6 7600 CPU. | ||
|
||
### Node | ||
|
||
```py | ||
# (batch, block) = time / iter | ||
- (4, 64) = 1.481 ms | ||
- (4, 128) = 2.564 ms | ||
- (4, 256) = 2.213 ms | ||
- (4, 512) = 3.284 ms | ||
- (16, 64) = 1.912 ms | ||
- (16, 128) = 3.323 ms | ||
- (16, 256) = 6.499 ms | ||
- (16, 512) = 12.131 ms | ||
- (32, 64) = 3.299 ms | ||
- (32, 128) = 6.579 ms | ||
- (32, 256) = 12.325 ms | ||
- (32, 512) = 23.752 ms | ||
``` | ||
|
||
### Web (simulated) | ||
|
||
```py | ||
# (batch, block) = time / iter | ||
- (4, 64) = 1.617 ms | ||
- (4, 128) = 2.725 ms | ||
- (4, 256) = 2.162 ms | ||
- (4, 512) = 3.603 ms | ||
- (16, 64) = 2.120 ms | ||
- (16, 128) = 3.751 ms | ||
- (16, 256) = 6.796 ms | ||
- (16, 512) = 12.837 ms | ||
- (32, 64) = 3.598 ms | ||
- (32, 128) = 6.883 ms | ||
- (32, 256) = 12.718 ms | ||
- (32, 512) = 25.475 ms | ||
``` | ||
|
||
### Web (actual browser) | ||
|
||
## Training on GPT | ||
|
||
TODO: put wandb url | ||
|
||
### Node | ||
|
||
### Web (actual browser) | ||
|
||
# TODO | ||
|
||
1. Benchmark all | ||
2. Try new dataset | ||
3. Try new model | ||
4. Investigate if node v18 is required everywhere now | ||
|
||
# Future work | ||
|
||
1. Disco support for various backends (for WebGPU especially) using `tf.setBackend`, and benchmark on them | ||
2. Support for dedicated tfjs model, which allows custom training loop, e.g. `GPTModel extends Model`. This is partially implemented but not fully (issues in Trainer / TrainerBuilder?) | ||
3. Refactor Task, add generic types | ||
4. QloRA in disco core or at least for GPT |
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
import { Config, Models as Model } from './tfjs-types' | ||
|
||
export const configModels = { | ||
gpt2: { | ||
nLayer: 12, | ||
nHead: 12, | ||
nEmbd: 768, | ||
vocabSize: 50257, | ||
blockSize: 1024, | ||
}, | ||
'gpt2-medium': { | ||
nLayer: 24, | ||
nHead: 16, | ||
nEmbd: 1024, | ||
vocabSize: 50257, | ||
blockSize: 1024, | ||
}, | ||
'gpt2-large': { | ||
nLayer: 36, | ||
nHead: 20, | ||
nEmbd: 1280, | ||
vocabSize: 50257, | ||
blockSize: 1024, | ||
}, | ||
'gpt2-xl': { | ||
nLayer: 48, | ||
nHead: 25, | ||
nEmbd: 1600, | ||
vocabSize: 50257, | ||
blockSize: 1024, | ||
}, | ||
'gpt-mini': { nLayer: 6, nHead: 6, nEmbd: 192 }, | ||
'gpt-micro': { nLayer: 4, nHead: 4, nEmbd: 128 }, | ||
'gpt-nano': { nLayer: 3, nHead: 3, nEmbd: 48 }, | ||
} as const | ||
|
||
const modelType: Model = 'gpt-nano' | ||
const model = configModels[modelType] | ||
const dataset = 'wikitext-103' | ||
const batchSize = 8 | ||
const blockSize = 128 // = sequence length | ||
const lr = 0.001 | ||
const maxIter = 10 | ||
|
||
const baseConfig = { | ||
debug: false, | ||
verbose: false, | ||
|
||
modelType, | ||
...model, | ||
|
||
dataset, | ||
batchSize, | ||
blockSize, | ||
lr, | ||
maxIter, | ||
shuffle: NaN, | ||
weightDecay: false, // If set, wasm backend won't work because of the custom AdamW optimizer | ||
optimizer: 'adamw', | ||
gradClip: 1, | ||
scheduler: null, | ||
embdDrop: 0.2, | ||
bias: true, | ||
numWorkers: 0, | ||
vocabSize: 50257, | ||
|
||
wandbProject: 'disco-gpt-benchmark', | ||
|
||
evalFreq: 25, | ||
evalSeqPrefix: 'none', | ||
maxEvalBatches: 24, | ||
} as const | ||
|
||
const config: Config = { | ||
...baseConfig, | ||
residDrop: baseConfig.embdDrop, | ||
wandbName: `${modelType}_${dataset}_bs=${batchSize}_seq=${blockSize}_lr=${lr}_iter=${maxIter}`, | ||
} as const | ||
|
||
export default config |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
import { readdir } from 'fs/promises' | ||
import path from 'path' | ||
import { dataset, Task, node } from '@epfml/discojs-node' | ||
import { TOKENIZED_FILE_EXTENSION } from './preprocess' | ||
|
||
async function getDatasetSource( | ||
root: string, | ||
splits: (keyof dataset.loader.TextSource)[] | ||
): Promise<dataset.loader.TextSource> { | ||
console.log('Preprocessed dataset located at:', root) | ||
const files = await readdir(root) | ||
return Object.fromEntries( | ||
splits.map((split) => { | ||
const splitFiles = files.filter( | ||
(f) => f.endsWith(TOKENIZED_FILE_EXTENSION) && f.includes(split) | ||
) | ||
|
||
console.log( | ||
'Found', | ||
splitFiles.length, | ||
'files in dataset for the', | ||
split, | ||
'split.' | ||
) | ||
|
||
const splitFilesPath = splitFiles.map((f) => path.join(root, f)) | ||
return [split, splitFilesPath] | ||
}) | ||
) as dataset.loader.TextSource | ||
} | ||
|
||
export async function loadData( | ||
task: Task, | ||
name: string, | ||
config?: Partial<dataset.loader.TextConfig> | ||
): Promise<dataset.DataSplit> { | ||
// TODO: Make this even more generic so that it works for any dataset / any task | ||
// 1) move getDatasetSource to core so that the web version can use it as well | ||
/* @ts-ignore - for import.meta.dir */ | ||
const root = path.join(import.meta.dir, 'datasets', name) | ||
const source = await getDatasetSource(root, ['train', 'validation']) | ||
return await new node.dataset.loader.NodeTextLoader(task).loadAll( | ||
source, | ||
config | ||
) | ||
} | ||
|
||
import fs from 'fs' | ||
import Rand from 'rand-seed' | ||
|
||
const rand = new Rand('1234') | ||
|
||
function shuffle<T, U>(array: T[], arrayTwo: U[]): void { | ||
for (let i = array.length - 1; i > 0; i--) { | ||
const j = Math.floor(rand.next() * (i + 1)) | ||
const temp = array[i] | ||
array[i] = array[j] | ||
array[j] = temp | ||
|
||
const tempTwo = arrayTwo[i] | ||
arrayTwo[i] = arrayTwo[j] | ||
arrayTwo[j] = tempTwo | ||
} | ||
} | ||
|
||
function filesFromFolder(dir: string, folder: string): string[] { | ||
const f = fs.readdirSync(dir + folder) | ||
return f.map((file) => dir + folder + '/' + file) | ||
} | ||
|
||
export async function loadDataFace(task: Task): Promise<dataset.DataSplit> { | ||
const dir = '../example_training_data/simple_face/' | ||
const youngFolders = ['child'] | ||
const oldFolders = ['adult'] | ||
|
||
const youngFiles = youngFolders.flatMap((folder) => { | ||
return filesFromFolder(dir, folder) | ||
}) | ||
|
||
const oldFiles = oldFolders.flatMap((folder) => { | ||
return filesFromFolder(dir, folder) | ||
}) | ||
|
||
const filesPerFolder = [youngFiles, oldFiles] | ||
|
||
const labels = filesPerFolder.flatMap((files, index) => | ||
Array(files.length).fill(index) | ||
) | ||
const files = filesPerFolder.flat() | ||
|
||
shuffle(files, labels) | ||
|
||
return await new node.dataset.loader.NodeImageLoader(task).loadAll(files, { | ||
labels: labels, | ||
}) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
mkdir -p ./datasets/tinyshakespeare | ||
curl https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt --output ./datasets/tiny-shakespeare/train | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
mkdir -p ./datasets/wikitext | ||
mkdir -p ./datasets/wikitext-103 | ||
curl https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-raw-v1.zip --output ./datasets/wikitext/raw.zip | ||
unzip ./datasets/wikitext/raw.zip -d ./datasets/wikitext-103 | ||
mv ./datasets/wikitext-103/wikitext-103-raw/wiki.train.raw ./datasets/wikitext-103/train | ||
mv ./datasets/wikitext-103/wikitext-103-raw/wiki.test.raw ./datasets/wikitext-103/test | ||
mv ./datasets/wikitext-103/wikitext-103-raw/wiki.valid.raw ./datasets/wikitext-103/validation | ||
rmdir ./datasets/wikitext-103/wikitext-103-raw/ |
Oops, something went wrong.