Add init code #2

Merged
merged 19 commits on Sep 21, 2024
21 changes: 20 additions & 1 deletion README.md
@@ -8,11 +8,30 @@ install uv

```bash
curl -LsSf https://astral.sh/uv/install.sh | sh
uv sync
uv sync --extra all
```

run your code using

```bash
uv run ...
```

## quick check

To check that everything is working, you can run

```bash
ZERO_BAND_LOG_LEVEL=DEBUG torchrun --nproc_per_node=2 src/zeroband/train.py @configs/debug.toml
```

## run tests

You need a machine with at least two GPUs to run the full test suite.

Some tests must be run from the root directory.

```bash
uv run pytest
```

12 changes: 12 additions & 0 deletions configs/150M/3090.toml
@@ -0,0 +1,12 @@
name_model = "150M"
project = "debug_150m_zero_band"

[train]
micro_bs = 16 # change this based on the GPU
sharding_strategy = "NO_SHARD"

[optim]
batch_size = 512
warmup_steps = 1000
total_steps = 88_000
lr = 4e-4
12 changes: 12 additions & 0 deletions configs/150M/A40.toml
@@ -0,0 +1,12 @@
name_model = "150M"
project = "debug_150m_zero_band"

[train]
micro_bs = 32 # change this based on the GPU
sharding_strategy = "NO_SHARD"

[optim]
batch_size = 512
warmup_steps = 1000
total_steps = 88_000
lr = 4e-4
12 changes: 12 additions & 0 deletions configs/150M/H100.toml
@@ -0,0 +1,12 @@
name_model = "150M"
project = "debug_150m_zero_band"

[train]
micro_bs = 64 # change this based on the GPU
sharding_strategy = "NO_SHARD"

[optim]
batch_size = 512
warmup_steps = 1000
total_steps = 88_000
lr = 4e-4
12 changes: 12 additions & 0 deletions configs/1B/H100.toml
@@ -0,0 +1,12 @@
name_model = "1B"
project = "debug_1B_zero_band"

[train]
micro_bs = 16
sharding_strategy = "SHARD_GRAD_OP"

[optim]
batch_size = 512
warmup_steps = 1000
total_steps = 88_000
lr = 4e-4
12 changes: 12 additions & 0 deletions configs/7B/H100.toml
@@ -0,0 +1,12 @@
name_model = "7B"
project = "debug_7B_zero_band"

[train]
micro_bs = 6
sharding_strategy = "SHARD_GRAD_OP"

[optim]
batch_size = 3840
warmup_steps = 1000
total_steps = 88_000
lr = 6e-4
13 changes: 13 additions & 0 deletions configs/debug.toml
@@ -0,0 +1,13 @@
name_model = "debugmodel"
project = "debug"

[train]
micro_bs = 8

[optim]
batch_size = 16
warmup_steps = 10
total_steps = 5000

[data]
fake_data = true
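
The quick-check command in the README above points `src/zeroband/train.py` at this file, presumably through pydantic_config's `@configs/debug.toml` argument syntax. As a rough illustration of how these fields nest (this is not the project's actual parsing code, and the defaults filled in for `sharding_strategy` and `lr` are assumptions since `debug.toml` omits them), plain pydantic plus `tomllib` could read it like this:

```python
# Sketch only: shows how the TOML sections map onto nested config models.
# The sharding_strategy and lr defaults below are assumptions, not values from this PR.
import tomllib
from pydantic import BaseModel


class TrainConfig(BaseModel):
    micro_bs: int
    sharding_strategy: str = "FULL_SHARD"  # assumed default; debug.toml omits it


class OptimConfig(BaseModel):
    batch_size: int
    warmup_steps: int
    total_steps: int
    lr: float = 4e-4  # assumed default; debug.toml omits it


class DataConfig(BaseModel):
    fake_data: bool = False


class Config(BaseModel):
    name_model: str
    project: str
    train: TrainConfig
    optim: OptimConfig
    data: DataConfig = DataConfig()


with open("configs/debug.toml", "rb") as f:
    cfg = Config(**tomllib.load(f))

print(cfg.name_model, cfg.optim.batch_size)  # debugmodel 16
```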
7 changes: 6 additions & 1 deletion pyproject.toml
@@ -10,9 +10,14 @@ dependencies = [
"setuptools",
"transformers>=4.44.2",
"datasets>=3.0.0",
"pydantic_config @ git+https://github.com/samsja/[email protected]"
"pydantic_config @ git+https://github.com/samsja/pydantic_config.git@e529c9c",
"einops"
]

[project.optional-dependencies]
all = [
"wandb",
]

[build-system]
requires = ["hatchling"]
2 changes: 0 additions & 2 deletions src/zeroband/__init__.py
@@ -1,2 +0,0 @@
def hello() -> str:
    return "Hello from zeroband!"
87 changes: 87 additions & 0 deletions src/zeroband/data.py
@@ -0,0 +1,87 @@
from functools import partial
from typing import Any, Generator

import torch
from torch.utils.data import DataLoader
from torch.utils.data import IterableDataset

from datasets import load_dataset
from datasets.distributed import split_dataset_by_node

TEST_VOCAB_SIZE = 1024

# TODO(sami): make sure the init of the model is the same on all ranks


class FakeTokenizedDataset(IterableDataset):
    """A dummy dataset that yields random token sequences of length seq_len drawn from a vocabulary of size vocab_size."""

    def __init__(self, seq_len: int, vocab_size: int):
        self.seq_len = seq_len
        self.vocab_size = vocab_size
        assert vocab_size > 3, "Vocab size must be greater than 3"

    def __iter__(self) -> Generator[dict[str, Any], Any, None]:
        while True:
            input_ids = torch.randint(3, self.vocab_size, (self.seq_len,)).tolist()
            yield {"input_ids": input_ids}


def collate_causal_mask(max_seq_length: int = -1, pad_id: int = 0, ignore_index: int = -100) -> callable:
    """Build a collate function for causal LM batches. Sequences shorter than max_seq_length are padded with pad_id."""
    return partial(_collate_fn_causal_mask, max_seq_length=max_seq_length, pad_id=pad_id, ignore_index=ignore_index)


def _collate_fn_causal_mask(
    samples: list[dict[str, torch.LongTensor]], max_seq_length: int = -1, pad_id: int = 0, ignore_index: int = -100
) -> dict[str, torch.LongTensor]:
    """Collate function for causal LM batches. Sequences shorter than max_seq_length are padded with pad_id.
    input_ids and labels both end up with length max_seq_length.
    """

    assert samples[0].keys() == {"input_ids"}

    batched = {"input_ids": [], "labels": []}

    if max_seq_length > 0:
        max_seq_length += 1  # pad/truncate to one extra token so the shift below yields the intended sequence length

    for sample in samples:
        input_ids = torch.Tensor(sample["input_ids"]).long()

        if len(input_ids) < max_seq_length:
            input_ids = torch.cat([input_ids, torch.full((max_seq_length - len(input_ids),), pad_id)])
        elif len(input_ids) > max_seq_length:
            input_ids = input_ids[:max_seq_length]

        batched["input_ids"].append(input_ids[1:])
        batched["labels"].append(input_ids[:-1])

    return {"input_ids": torch.stack(batched["input_ids"], dim=0), "labels": torch.stack(batched["labels"], dim=0)}


def get_dataloader(
    tokenizer, world_size: int, rank: int, seq_length: int, batch_size: int, num_workers: int, fake_data: bool
) -> DataLoader:
    if fake_data:
        train_dataset = FakeTokenizedDataset(seq_length, TEST_VOCAB_SIZE)
    else:
        ds = load_dataset("allenai/c4", "en", streaming=True)

        def tokenize_function(data):
            outputs = tokenizer(data["text"], truncation=True, max_length=seq_length, padding="max_length")
            return outputs

        tokenized_datasets = ds.map(
            tokenize_function, batched=True, remove_columns=["text", "timestamp", "url", "attention_mask"]
        )["train"]
        train_dataset = split_dataset_by_node(tokenized_datasets, world_size=world_size, rank=rank)

    data_collator = collate_causal_mask(max_seq_length=seq_length, pad_id=tokenizer.pad_token_id, ignore_index=-100)

    return DataLoader(
        train_dataset,
        collate_fn=data_collator,
        batch_size=batch_size,
        num_workers=num_workers,
    )
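
A minimal usage sketch for the fake-data path (not part of the PR; the sequence length and batch size below are arbitrary). Note that `get_dataloader` reads `tokenizer.pad_token_id` even when `fake_data=True`, so the stub only needs that attribute:

```python
from types import SimpleNamespace

from zeroband.data import get_dataloader

# Stand-in tokenizer: only pad_token_id is accessed on the fake-data path.
tokenizer_stub = SimpleNamespace(pad_token_id=0)

dataloader = get_dataloader(
    tokenizer=tokenizer_stub,
    world_size=1,
    rank=0,
    seq_length=128,   # illustrative values, not from the PR
    batch_size=8,
    num_workers=0,
    fake_data=True,
)

batch = next(iter(dataloader))
print(batch["input_ids"].shape, batch["labels"].shape)  # torch.Size([8, 128]) for both
```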
Empty file added src/zeroband/models/__init__.py
75 changes: 75 additions & 0 deletions src/zeroband/models/llama/__init__.py
@@ -0,0 +1,75 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Llama 2 is licensed under the LLAMA 2 Community License,
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

from zeroband.models.llama.model import ModelArgs, Transformer

__all__ = ["Transformer"]

llama2_configs = {
    "debugmodel": ModelArgs(dim=256, n_layers=2, n_heads=8),
    "150M": ModelArgs(dim=1024, n_layers=12, n_heads=16),  # todo(sami): double check this
    "271M": ModelArgs(dim=1024, n_layers=16, n_heads=8),
    "1B": ModelArgs(dim=2048, n_layers=18, n_heads=16),
    "7B": ModelArgs(dim=4096, n_layers=32, n_heads=32),
    "13B": ModelArgs(dim=5120, n_layers=40, n_heads=40),
    "26B": ModelArgs(dim=5120, n_layers=80, n_heads=40),
    "70B": ModelArgs(
        dim=8192,
        n_layers=80,
        n_heads=64,
        n_kv_heads=8,
        ffn_dim_multiplier=1.3,
        multiple_of=4096,
    ),
}

llama3_configs = {
    "debugmodel": ModelArgs(dim=256, n_layers=8, n_heads=16, rope_theta=500000),
    "8B": ModelArgs(
        dim=4096,
        n_layers=32,
        n_heads=32,
        n_kv_heads=8,
        ffn_dim_multiplier=1.3,
        multiple_of=1024,
        rope_theta=500000,
    ),
    "70B": ModelArgs(
        dim=8192,
        n_layers=80,
        n_heads=64,
        n_kv_heads=8,
        ffn_dim_multiplier=1.3,
        multiple_of=4096,
        rope_theta=500000,
    ),
    "405B": ModelArgs(
        dim=16384,
        n_layers=126,
        n_heads=128,
        n_kv_heads=8,
        ffn_dim_multiplier=1.2,
        multiple_of=4096,
        rope_theta=500000,
    ),
}


def get_model(name_model: str, type_model: str, vocab_size: int) -> tuple[Transformer, ModelArgs]:
    """get the transformer model"""

    if type_model == "llama2":
        config = llama2_configs[name_model]
    elif type_model == "llama3":
        config = llama3_configs[name_model]
    else:
        raise ValueError(f"Model type {type_model} not supported")

    config.vocab_size = vocab_size
    return Transformer(config), config
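
A short sketch (not part of the PR) of how `get_model` might be called; the `vocab_size` value is an arbitrary placeholder, since in training it would come from the tokenizer:

```python
from zeroband.models.llama import get_model

# Build the small debug model from the llama2 config table.
model, config = get_model(name_model="debugmodel", type_model="llama2", vocab_size=1024)
print(config.dim, config.n_layers, config.n_heads)  # 256 2 8
```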