Allow easier use of custom datasets (meta-llama#178)
chauhang authored Sep 8, 2023
2 parents df77625 + 26b9b7d commit d172d88
Showing 9 changed files with 215 additions and 10 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -101,7 +101,7 @@ If you want to dive right into single or multi GPU fine-tuning, run the examples
All the parameters in the examples and recipes below need to be further tuned to achieve the desired results based on the model, method, data, and task at hand.

**Note:**
* To change the dataset in the commands below pass the `dataset` arg. Current options for dataset are `grammar_dataset`, `alpaca_dataset`, and `samsum_dataset`. A description of the datasets and how to add custom datasets can be found in [Dataset.md](./docs/Dataset.md). For `grammar_dataset` and `alpaca_dataset`, please make sure you use the suggested instructions from [here](./docs/single_gpu.md#how-to-run-with-different-datasets) to set them up.
* To change the dataset in the commands below, pass the `dataset` arg. Current options for integrated datasets are `grammar_dataset`, `alpaca_dataset`, and `samsum_dataset`. A description of how to use your own dataset and how to add custom datasets can be found in [Dataset.md](./docs/Dataset.md#using-custom-datasets). For `grammar_dataset` and `alpaca_dataset`, please make sure you use the suggested instructions from [here](./docs/single_gpu.md#how-to-run-with-different-datasets) to set them up.

* The default dataset and other LoRA configs have been set to `samsum_dataset`.

33 changes: 29 additions & 4 deletions docs/Dataset.md
@@ -6,10 +6,35 @@ The provided fine tuning script allows you to select between three datasets by p
* [alpaca_dataset](https://github.com/tatsu-lab/stanford_alpaca) provides 52K instruction-response pairs as generated by `text-davinci-003`.
* [samsum_dataset](https://huggingface.co/datasets/samsum) contains about 16k messenger-like conversations with summaries.

## Adding custom datasets

The list of available datasets can easily be extended with custom datasets by following these instructions.

## Using custom datasets

The list of datasets available in llama-recipes is meant to give users a quick start for training their Llama model.
There are two ways to use a custom dataset.
The first is to provide a function that returns the dataset in a .py file, which can then be passed to the command line tool.
This does not involve changing the source code of llama-recipes.
The second way is aimed at contributions that extend llama-recipes, as it involves changing the source code.

### Training on custom data
To supply a custom dataset, you need to provide a single .py file that contains a function with the following signature:
```python
def get_custom_dataset(dataset_config, tokenizer, split: str):
```
For an example of `get_custom_dataset`, you can look at the provided datasets in llama_recipes.datasets or [examples/custom_dataset.py](examples/custom_dataset.py).
The `dataset_config` in the above signature will be an instance of llama_recipes.configs.datasets.custom_dataset with the modifications made through the command line.
The split signals whether to return the training or validation dataset.
The default function name is `get_custom_dataset` but this can be changed as described below.
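For illustration, a minimal custom dataset file could look like the following sketch. This is not part of the repository: the file name `my_custom_data.py`, the local JSON files, and their `question`/`answer` fields are hypothetical, so adapt the loading and formatting to your own data.
```python
# my_custom_data.py -- minimal sketch of a custom dataset file.
import datasets


def get_custom_dataset(dataset_config, tokenizer, split: str):
    # split is the training or validation split name requested by llama-recipes.
    # Hypothetical local JSON files; adapt paths and field names to your data.
    data_files = {"train": "train.json", "validation": "validation.json"}
    dataset = datasets.load_dataset("json", data_files=data_files, split=split)

    def to_features(sample):
        # Build a prompt from the raw fields and tokenize it.
        text = f"Question:\n{sample['question']}\n---\nAnswer:\n{sample['answer']}{tokenizer.eos_token}"
        return tokenizer(text)

    return dataset.map(to_features, remove_columns=list(dataset.features))
```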

To start a training run with the custom dataset, we need to set the `--dataset` as well as the `--custom_dataset.file` parameter.
```
python -m llama_recipes.finetuning --dataset "custom_dataset" --custom_dataset.file "examples/custom_dataset.py" [TRAINING PARAMETERS]
```
To change the function name that is used in the .py file, you can append the name after a `:` like this:
```
python -m llama_recipes.finetuning --dataset "custom_dataset" --custom_dataset.file "examples/custom_dataset.py:get_foo" [TRAINING PARAMETERS]
```
This will call the function `get_foo` instead of `get_custom_dataset` when retrieving the dataset.
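The file may therefore define several loader functions, with the suffix selecting which one to call. A short sketch (the name `get_foo` is only illustrative):
```python
# Sketch: one file exposing two loaders; ":get_foo" on --custom_dataset.file
# selects the second one instead of the default get_custom_dataset.
import datasets


def get_custom_dataset(dataset_config, tokenizer, split: str):
    # Default loader, used when no function name is appended.
    dataset = datasets.load_dataset("samsum", split=split)
    return dataset.map(
        lambda s: tokenizer(s["dialogue"] + "\n---\n" + s["summary"] + tokenizer.eos_token),
        remove_columns=list(dataset.features),
    )


def get_foo(dataset_config, tokenizer, split: str):
    # Alternative loader selected via ":get_foo"; here it simply reuses the
    # default one, but it could prepare the data in an entirely different way.
    return get_custom_dataset(dataset_config, tokenizer, split)
```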

### Adding new dataset
Each dataset has a corresponding configuration (dataclass) in [configs/datasets.py](../src/llama_recipes/configs/datasets.py) which contains the dataset name, training/validation split names, as well as optional parameters like datafiles etc.

Additionally, there is a preprocessing function for each dataset in the [datasets](../src/llama_recipes/datasets) folder.
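As a rough sketch of what such an extension could look like (the names `my_dataset` and `get_my_dataset` and the JSON files are hypothetical, not part of the repository):
```python
from dataclasses import dataclass

import datasets


@dataclass
class my_dataset:
    # Hypothetical config entry, mirroring the dataclasses in configs/datasets.py.
    dataset: str = "my_dataset"
    train_split: str = "train"
    test_split: str = "validation"


def get_my_dataset(dataset_config, tokenizer, split):
    # Hypothetical preprocessing function, mirroring the helpers in the datasets folder.
    # Assumes local JSON files with a "text" field; adapt to your data.
    data_files = {"train": "train.json", "validation": "validation.json"}
    raw = datasets.load_dataset("json", data_files=data_files, split=split)
    return raw.map(
        lambda sample: tokenizer(sample["text"] + tokenizer.eos_token),
        remove_columns=list(raw.features),
    )
```
The new name would then also need to be registered in `DATASET_PREPROC` in [dataset_utils.py](../src/llama_recipes/utils/dataset_utils.py) (see the changes to that file below) so that `--dataset my_dataset` resolves to this preprocessing function.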
6 changes: 5 additions & 1 deletion examples/README.md
@@ -31,4 +31,8 @@ For more in depth information on inference including inference safety checks and

**Note** The [sensitive topics safety checker](../src/llama_recipes/inference/safety_utils.py) utilizes AuditNLG which is an optional dependency. Please refer to the installation section of the main [README.md](../README.md#install-with-optional-dependencies) for details.

**Note** The **vLLM** example requires additional dependencies. Please refer to the installation section of the main [README.md](../README.md#install-with-optional-dependencies) for details.

## Train on custom dataset
To show how to train a model on a custom dataset, we provide an example that generates a custom dataset in [custom_dataset.py](./custom_dataset.py).
The usage of the custom dataset is further described in the datasets [README](../docs/Dataset.md#training-on-custom-data).
33 changes: 33 additions & 0 deletions examples/custom_dataset.py
@@ -0,0 +1,33 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

# For dataset details visit: https://huggingface.co/datasets/samsum

import datasets

from llama_recipes.datasets.utils import Concatenator

def get_custom_dataset(dataset_config, tokenizer, split):
    # Load the requested split of the samsum dataset from the Hugging Face hub.
    dataset = datasets.load_dataset("samsum", split=split)

    prompt = (
        f"Summarize this dialog:\n{{dialog}}\n---\nSummary:\n{{summary}}{{eos_token}}"
    )

    def apply_prompt_template(sample):
        return {
            "text": prompt.format(
                dialog=sample["dialogue"],
                summary=sample["summary"],
                eos_token=tokenizer.eos_token,
            )
        }

    # Fill the prompt template from the raw fields and drop the original columns.
    dataset = dataset.map(apply_prompt_template, remove_columns=list(dataset.features))

    # Tokenize the formatted text and pack the samples into fixed-size chunks.
    dataset = dataset.map(
        lambda sample: tokenizer(sample["text"]),
        batched=True,
        remove_columns=list(dataset.features),
    ).map(Concatenator(), batched=True)
    return dataset
10 changes: 9 additions & 1 deletion src/llama_recipes/configs/datasets.py
@@ -25,4 +25,12 @@ class alpaca_dataset:
    dataset: str = "alpaca_dataset"
    train_split: str = "train"
    test_split: str = "val"
    data_path: str = "src/llama_recipes/datasets/alpaca_data.json"


@dataclass
class custom_dataset:
    dataset: str = "custom_dataset"
    file: str = "examples/custom_dataset.py"
    train_split: str = "train"
    test_split: str = "validation"
8 changes: 5 additions & 3 deletions src/llama_recipes/utils/config_utils.py
@@ -42,7 +42,8 @@ def generate_peft_config(train_config, kwargs):

    assert train_config.peft_method in names, f"Peft config not found: {train_config.peft_method}"

    config = configs[names.index(train_config.peft_method)]
    config = configs[names.index(train_config.peft_method)]()

    update_config(config, **kwargs)
    params = {k.name: getattr(config, k.name) for k in fields(config)}
    peft_config = peft_configs[names.index(train_config.peft_method)](**params)
@@ -52,10 +53,11 @@ def generate_peft_config(train_config, kwargs):

def generate_dataset_config(train_config, kwargs):
    names = tuple(DATASET_PREPROC.keys())

    assert train_config.dataset in names, f"Unknown dataset: {train_config.dataset}"

    dataset_config = {k:v for k, v in inspect.getmembers(datasets)}[train_config.dataset]
    dataset_config = {k:v for k, v in inspect.getmembers(datasets)}[train_config.dataset]()

    update_config(dataset_config, **kwargs)

    return dataset_config
38 changes: 38 additions & 0 deletions src/llama_recipes/utils/dataset_utils.py
@@ -1,7 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import importlib
from functools import partial
from pathlib import Path

import torch

Expand All @@ -12,10 +14,46 @@
)


def load_module_from_py_file(py_file: str) -> object:
    """
    This method loads a module from a py file which is not in the Python path
    """
    module_name = Path(py_file).name
    loader = importlib.machinery.SourceFileLoader(module_name, py_file)
    spec = importlib.util.spec_from_loader(module_name, loader)
    module = importlib.util.module_from_spec(spec)

    loader.exec_module(module)

    return module


def get_custom_dataset(dataset_config, tokenizer, split: str):
    # An optional ":function_name" suffix selects which function in the file to call.
    if ":" in dataset_config.file:
        module_path, func_name = dataset_config.file.split(":")
    else:
        module_path, func_name = dataset_config.file, "get_custom_dataset"

    if not module_path.endswith(".py"):
        raise ValueError(f"Dataset file {module_path} is not a .py file.")

    module_path = Path(module_path)
    if not module_path.is_file():
        raise FileNotFoundError(f"Dataset py file {module_path.as_posix()} does not exist or is not a file.")

    # Import the user-provided file and call the selected function to build the dataset.
    module = load_module_from_py_file(module_path.as_posix())
    try:
        return getattr(module, func_name)(dataset_config, tokenizer, split)
    except AttributeError as e:
        print(f"It seems like the given method name ({func_name}) is not present in the dataset .py file ({module_path.as_posix()}).")
        raise e


DATASET_PREPROC = {
    "alpaca_dataset": partial(get_alpaca_dataset, max_words=224),
    "grammar_dataset": get_grammar_dataset,
    "samsum_dataset": get_samsum_dataset,
    "custom_dataset": get_custom_dataset,
}


58 changes: 58 additions & 0 deletions tests/datasets/test_custom_dataset.py
@@ -0,0 +1,58 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import pytest
from unittest.mock import patch


@patch('llama_recipes.finetuning.train')
@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
@patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
@patch('llama_recipes.finetuning.optim.AdamW')
@patch('llama_recipes.finetuning.StepLR')
def test_custom_dataset(step_lr, optimizer, tokenizer, get_model, train, mocker):
    from llama_recipes.finetuning import main

    tokenizer.return_value = mocker.MagicMock(side_effect=lambda x: {"input_ids":[len(x)*[0,]], "attention_mask": [len(x)*[0,]]})

    kwargs = {
        "dataset": "custom_dataset",
        "custom_dataset.file": "examples/custom_dataset.py",
        "batch_size_training": 1,
        "use_peft": False,
    }

    main(**kwargs)

    assert train.call_count == 1

    args, kwargs = train.call_args
    train_dataloader = args[1]
    eval_dataloader = args[2]

    VAL_SAMPLES = 818
    TRAIN_SAMPLES = 14732
    CONCAT_SIZE = 2048

    assert len(train_dataloader) == TRAIN_SAMPLES // CONCAT_SIZE
    assert len(eval_dataloader) == VAL_SAMPLES


@patch('llama_recipes.finetuning.train')
@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
@patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
@patch('llama_recipes.finetuning.optim.AdamW')
@patch('llama_recipes.finetuning.StepLR')
def test_unknown_dataset_error(step_lr, optimizer, tokenizer, get_model, train, mocker):
    from llama_recipes.finetuning import main

    tokenizer.return_value = mocker.MagicMock(side_effect=lambda x: {"input_ids":[len(x)*[0,]], "attention_mask": [len(x)*[0,]]})

    kwargs = {
        "dataset": "custom_dataset",
        "custom_dataset.file": "examples/custom_dataset.py:get_unknown_dataset",
        "batch_size_training": 1,
        "use_peft": False,
    }
    with pytest.raises(AttributeError):
        main(**kwargs)
37 changes: 37 additions & 0 deletions tests/datasets/test_samsum_datasets.py
@@ -0,0 +1,37 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

from unittest.mock import patch


@patch('llama_recipes.finetuning.train')
@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
@patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
@patch('llama_recipes.finetuning.optim.AdamW')
@patch('llama_recipes.finetuning.StepLR')
def test_samsum_dataset(step_lr, optimizer, tokenizer, get_model, train, mocker):
    from llama_recipes.finetuning import main

    tokenizer.return_value = mocker.MagicMock(side_effect=lambda x: {"input_ids":[len(x)*[0,]], "attention_mask": [len(x)*[0,]]})

    kwargs = {
        "batch_size_training": 1,
        "use_peft": False,
        "dataset": "samsum_dataset",
    }

    main(**kwargs)

    assert train.call_count == 1

    args, kwargs = train.call_args
    train_dataloader = args[1]
    eval_dataloader = args[2]

    VAL_SAMPLES = 818
    TRAIN_SAMPLES = 14732
    CONCAT_SIZE = 2048
    assert len(train_dataloader) == TRAIN_SAMPLES // CONCAT_SIZE
    assert len(eval_dataloader) == VAL_SAMPLES
