[draft] torchdata integration #1929
Thanks for this prototype, this is helpful to see! The way I'd imagine we'd expose the torchdata dataloader would be through a builder function with a few knobs exposed and reasonable defaults:
def build_dataloader(ds, num_workers, pin_memory, prefetch, in_memory, parallel_method, ...):
    # Put together all the nodes here

# In config
dataloader:
  num_workers:
  ...
For a power user, what might they want to tune to optimize performance for their hardware and model setup?
It's also not clear to me how some media transforms/decoding might get optimized, is that just handled by the torchdata nodes automatically?
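To make this concrete, a minimal sketch of such a builder, reusing only the torchdata.nodes constructors quoted elsewhere in this PR (IterableWrapper, Mapper, ParallelMapper, Batcher, PinMemory, Prefetcher); the function name, signature, and defaults are illustrative, not a proposed final API:

from functools import partial

from torchdata.nodes import (
    Batcher,
    IterableWrapper,
    Mapper,
    ParallelMapper,
    PinMemory,
    Prefetcher,
)

def build_dataloader(
    ds,                               # iterable of raw samples (e.g. a streaming dataset)
    transform,                        # per-sample transform, e.g. SFTDataset._prepare_sample
    collate_fn,
    batch_size: int = 4,
    num_workers: int = 0,
    pin_memory: bool = False,
    prefetch_factor: int = 2,
    parallel_method: str = "thread",  # or "process"
):
    # Run the per-sample transform in parallel only when workers are requested
    if num_workers > 0:
        _Mapper = partial(
            ParallelMapper,
            num_workers=num_workers,
            method=parallel_method,
            in_order=True,
        )
    else:
        _Mapper = Mapper

    node = IterableWrapper(ds)              # conform the dataset to the node API
    node = _Mapper(node, map_fn=transform)  # tokenize / apply the model transform
    node = Batcher(node, batch_size=batch_size, drop_last=True)
    node = Mapper(node, map_fn=collate_fn)  # pad and stack into tensors
    if pin_memory:
        node = PinMemory(node)
    if num_workers > 0:
        node = Prefetcher(node, prefetch_factor)
    return node

The config section would then only need the handful of knobs in that signature (num_workers, pin_memory, prefetch_factor, parallel_method).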
node = IterableWrapper(sampler)
node = _Mapper(node, map_fn=ds._data.__getitem__)
# Cut here for Streaming/Iterable dataset instead =====
node = _Mapper(node, map_fn=ds._prepare_sample)
is this where the transform would get parallelized?
Yes, see
_Mapper = partial(
    ParallelMapper,
    num_workers=num_workers,
    method=parallel_method,
    in_order=True,
)
)
# Map style set up =======
node = IterableWrapper(sampler)
node = _Mapper(node, map_fn=ds._data.__getitem__)
So does this mean we can keep our own Dataset abstractions? What if we went with entirely Iterable datasets?
Technically yes, but this was just a way to get it working quickly; we could also refactor this into a separate function. For an IterableDataset, you could wrap it in IterableWrapper and then call everything underneath here, i.e.
node = IterableWrapper(my_iterable_dataset)
node = _Mapper(node, map_fn=ds._prepare_sample)
...
batch = next(dl_iter)
dl_dt = time.perf_counter() - dl_t0
idx += 1
except StopIteration:
Why not create a context manager that handles this?
No good reason, hacky code is hacky :) Just wanted to see some rough numbers
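For what it's worth, a hedged sketch of what such a timing context manager could look like (a hypothetical helper, not code from this PR):

import time
from contextlib import contextmanager

@contextmanager
def track_time(acc: dict, key: str):
    # Accumulate wall-clock time spent inside the block into acc[key]
    t0 = time.perf_counter()
    try:
        yield
    finally:
        acc[key] = acc.get(key, 0.0) + (time.perf_counter() - t0)

# Hypothetical usage in the profiling loop, replacing the manual bookkeeping:
# with track_time(timings, "time_waiting_for_batch"):
#     batch = next(dl_iter)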
# Map style set up =======
node = IterableWrapper(sampler)
node = _Mapper(node, map_fn=ds._data.__getitem__)
# Cut here for Streaming/Iterable dataset instead =====
Have we tried this out on an HF dataset with streaming = True yet? (I assume it won't work out of the box yet?)
see update :D almost out-of-the-box
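For reference, the streaming path amounts to roughly the following sketch, assuming the Hugging Face datasets APIs plus the node names used in this PR (the dataset path/subset match the benchmarks below; rank, world_size, _Mapper, and prepare_sample are placeholders):

from datasets import load_dataset
from datasets.distributed import split_dataset_by_node
from torchdata.nodes import IterableWrapper

# Stream samples instead of materializing the full dataset locally
hf_ds = load_dataset("HuggingFaceM4/the_cauldron", "ocrvqa", split="train", streaming=True)

# Shard the stream across ranks before handing it to torchdata
hf_ds = split_dataset_by_node(hf_ds, rank=rank, world_size=world_size)

node = IterableWrapper(hf_ds)
node = _Mapper(node, map_fn=prepare_sample)  # same per-sample transform as the map-style path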
why do we need to wrap with IterableWrapper if the underlying dataset is an IterableDataset? Also, in our case, the underlying dataset would be an HF dataset class
The wrapper makes it conform to BaseNode's API so we can hold a pointer to the iterator, as well as unify the state management. Subclasses need to define .iterator() instead of .__iter__().
Thanks for creating this PR! I like the IterableWrapper/Batcher/etc APIs, they look pretty clean (and agree with @RdoubleA's suggestion about exposing this in builders/configs). How will this work when training on multiple datasets? Will we just apply the same set of APIs to each sampler individually?
Thanks for the comments y'all, I updated this with a streaming example. Test with:
Good question. At the minimum it'd be some global worker setting, but one option that may be worth supporting is allowing power users to define their entire pipeline in config; not sure if you think this is a bad idea. E.g. by default, use a builder with a few knobs, but also allow the entire dataloader definition to be composable, somewhat similar to how you let users pass a list of datasets right now. Something similar could be done for mixing. We'd need to be thoughtful about what the useful small atomic units would be, and figure out the syntax sugar.
In terms of optimizing, one thing we're looking at is tf.data's autotune approach, which automatically updates prefetch buffers and worker counts. This is not implemented yet, but we're hoping to do something along these lines, which should help with the too-many-tunable-parameters problem.
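One hedged sketch of that escape hatch: keep a builder with a few knobs as the default, but if the config fully specifies a pipeline, instantiate it as-is. config.instantiate is torchtune's existing mechanism; the function and key names here are illustrative only.

from torchtune import config

def setup_dataloader(cfg_dataloader, default_builder):
    # Power-user path: the config describes the whole node pipeline itself
    if "_component_" in cfg_dataloader:
        return config.instantiate(cfg_dataloader)
    # Default path: a builder with a few knobs and reasonable defaults
    return default_builder(
        num_workers=cfg_dataloader.get("num_workers", 0),
        pin_memory=cfg_dataloader.get("pin_memory", False),
        prefetch_factor=cfg_dataloader.get("prefetch_factor", 2),
        parallel_method=cfg_dataloader.get("parallel_method", "thread"),
    )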
@@ -235,3 +243,73 @@ def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
if packed:
    raise ValueError("Multimodal datasets don't support packing yet.")
return ds


def the_cauldron_dataset_torchdata(
How much of this could be pushed into the SFT class so that we can just reuse it for any new datasets?
A lot of it is probably reusable. I think it should be its own class, or at least a builder, maybe something like _sft.py: class SFTDatasetNode, or something less terrible sounding
Actually 90% of this builder func could probably live in _sft.py as it doesn't have anything to do with the cauldron
@@ -0,0 +1,1119 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
Copied and modified from lora_finetune_distributed
@@ -0,0 +1,131 @@
# Config for multi-device LoRA finetuning in lora_finetune_distributed_td.py
Copied and modified from 11B_lora
@@ -117,7 +128,17 @@ def __getitem__(self, index: int) -> Dict[str, Any]:
    sample = self._data[index]
    return self._prepare_sample(sample)

def _prepare_sample(self, sample: Mapping[str, Any]) -> Dict[str, Any]:


class PrepareSample:
Refactoring this into its own callable class that can be used by both the current torch.utils.data.Dataset and the torchdata.nodes path
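A rough sketch of that refactor; the two transforms mirror what SFTDataset applies in _prepare_sample today, but the exact attribute names are assumptions:

class PrepareSample:
    # Standalone per-sample transform, so the same logic can back both the
    # map-style SFTDataset.__getitem__ and a torchdata.nodes Mapper/ParallelMapper.
    def __init__(self, message_transform, model_transform):
        self._message_transform = message_transform
        self._model_transform = model_transform

    def __call__(self, sample):
        transformed = self._message_transform(sample)
        return self._model_transform(transformed)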
hey Andrew, thank you so much for this PR! This is such a nice feature to have.
I did a first pass. I understand that it's still a draft, but I thought I'd make comments now so it can save you some time when the PR is closer to being ready.
torchtune/datasets/_sft.py (outdated diff)
rank=int(os.environ.get("RANK", 0)),
world_size=int(os.environ.get("WORLD_SIZE", 1)),
should we use our utility instead?
torchtune/torchtune/training/_distributed.py (line 150 in 1eb7785):

def get_world_size_and_rank() -> Tuple[int, int]:
Needs to be removed
    global_streaming=streaming,
    global_num_workers=num_workers,
),
prefetch_factor=8,
would it make sense to add it to setup_data as a default or exposed in the config? Or is this parameter not commonly changed?
Updated PR: this is now exposed
Wait, why isn't this showing the new version? Let me make sure I've pushed
return ds


def _setup_data_td(
This would require a bit of refactoring, and it's probably not the main point of this PR, but it would be nice to have _setup_dataset and _setup_dataloader as two separate methods. It would be easier to read and maintain.
I've split this into setting up individual datasets (_setup_one_dataset) and the global dataloader/mixer setup.
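In skeleton form, that split could look something like this (method names follow the comment above, bodies elided):

def _setup_one_dataset(self, cfg_dataset):
    # Per-dataset chain: load/shard the data and apply the per-sample transform
    ...

def _setup_data_td(self, cfg_dataloader, cfg_datasets):
    # Global setup: mix the per-dataset nodes, then batch/collate/pin/prefetch
    nodes = {name: self._setup_one_dataset(c) for name, c in cfg_datasets.items()}
    ...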
# Instantiate collate_fn
if "left_pad_sequence" in collate_fn:
    raise RuntimeError("left_pad_sequence collator is only for inference.")
I understand that this was already in _setup_data, but we usually try to fail fast and catch errors like this in the __init__.
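A minimal sketch of the fail-fast version, with the check hoisted into the recipe's __init__ (the class shell and config key are assumptions for illustration):

from omegaconf import DictConfig

class LoRAFinetuneRecipeDistributed:  # illustrative recipe shell, not the real class body
    def __init__(self, cfg: DictConfig) -> None:
        # Fail fast: reject inference-only collators before any expensive setup runs
        if "left_pad_sequence" in cfg.get("collate_fn", ""):
            raise RuntimeError("left_pad_sequence collator is only for inference.")
        self.cfg = cfg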
torchtune/datasets/_sft.py (outdated diff)
if load_dataset_kwargs.get("streaming", False):
    self._data = split_dataset_by_node(
I think that if we can decouple the dataloader from the dataset, it will be easier to maintain and work with. For example, can we do something like:

MyDatasetDistributed = split_dataset_by_node(MyDataset)

def split_dataset_by_node(my_dataset, ...):
    assert hasattr(my_dataset, "_data")

or maybe SFTDataset can have a getter
This needs to be removed actually
if len(cfg_datasets) == 1:
    node = next(iter(datasets.values()))  # TODO: multi dataset
else:
    node = MultiDatasetWeightedSampler(
        source_nodes=datasets,
        weights=weights,
    )
n00b question: I understand that this is the torchdata API, but I wonder if the len(cfg_datasets) == 1 check should be inside MultiDatasetWeightedSampler. I.e., do we need this if check here?
log.info("TorchData nodes are initialized")

return node
'node' feels a bit weird, since we do self._dataloader = self._setup_dataloader(...). Should we rename it to dataloader?
# TODO: add multi-dataset mixer
if num_workers == 0:
    _Mapper = Mapper
I am a bit torn about whether getting the mapper, sampler, pin memory, etc. should be a utility shared across all recipes, or whether it should be exposed. No strong opinion, just thinking out loud.
if pin_memory:
    node = PinMemory(node)
if num_workers > 0:
    node = Prefetcher(node, 2)
is this Prefetcher(node, 2) different from the earlier prefetch_factor=8?
All data related setup happens here. Currently this recipe only supports
Map-style Datasets which fit into memory and an option for random shuffling.
Samplers, iterable datasets, and streaming datasets are not supported.
i guess this needs to be updated
prefetch_factor: 2
seed: null

multi_datasets:
personally this is too verbose to parse for me, and even in the recipe there are just too many nested dictionaries. Ideally, I would like to achieve this type of UI in the config for datasets:
datasets:
  - _component_: torchtune.datasets...
    weight: 1.0
    subset: ...
  - _component_: torchtune.datasets...
    weight: 1.0
    ...
or something similar, so all I have to do is specify the dataset I want and its weight. As it is, I have multi_datasets -> datasets -> dataset just to specify the dataset builder. Maybe this is idealistic, but other libraries such as Axolotl are able to do this.
I am aware there's a few challenges to having this:
- MultiDatasetSampler requires passing in dictionaries for datasets and weights
- weight is not a valid argument for instantiating dataset components
I'm wondering if there's a way we can do this all for the user in a builder. For example:
datasets: ListConfig  # from the config
weights, dataset_nodes = {}, {}
for k, cfg_dataset in enumerate(datasets):
    weights[k] = cfg_dataset.pop("weight")
    dataset_nodes[k] = Prefetcher(config.instantiate(cfg_dataset), prefetch_factor)
dataloader = get_multi_datasets(dataset_nodes, weights, cfg_dataloader)
stop_criterion imo should be moved to the dataloader config
)
weights, datasets = {}, {}
cfg_datasets = cfg_multi_datasets.datasets
for k, cfg_and_weight in cfg_datasets.items():
all of this logic needs to be in the builder; we do not want to expose torchdata internals in each recipe. I am totally okay with creating a new file in torchtune/data that contains functions that set up torchdata dataloaders and nodes.
Another ideal I'm curious if we can achieve: can we unify the UX for multi datasets and single datasets? I.e., if we had a get_dataloader method, you could pass in a single dataset or multiple datasets and the call in the recipe would be the same regardless of what the user specifies in the config.
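A hedged sketch of what that unified entry point could look like, reusing the node names from this PR (the function name and the elided tail of the pipeline are illustrative):

from typing import Dict

from torchdata.nodes import BaseNode

def get_dataloader(datasets: Dict[str, BaseNode], weights: Dict[str, float], cfg_dataloader) -> BaseNode:
    # A single dataset is just the trivial case of a one-entry dict
    if len(datasets) == 1:
        node = next(iter(datasets.values()))
    else:
        # MultiDatasetWeightedSampler is the mixing node added in this PR
        node = MultiDatasetWeightedSampler(source_nodes=datasets, weights=weights)
    # Shared tail: batch, collate, pin memory, prefetch (as in the snippets above)
    ...
    return node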
    ),
)

return node
this returns a node, not a dataloader, right? Are users still able to access the underlying Hugging Face data?
@requires_torchdata
def SFTDatasetNode(  # noqa[N802]
seems we're moving away from the class-based dataset abstraction and toward a function that returns the node configured with user parameters.
Curious if it would be better UX to just abandon the SFTDataset class (after the full migration) and keep Transform classes for each kind of dataset (Instruct, Chat, MM, SFT, Preference), which are passed into a generic node builder
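Sketching that direction with made-up names (nothing here is an agreed-upon API): each dataset flavor keeps only its Transform, and one generic builder assembles the nodes.

# Hypothetical post-migration shape: no SFTDataset class, just a Transform plus
# a generic node builder (SFTTransform and build_dataset_node are placeholders).
transform = SFTTransform(message_transform=message_transform, model_transform=model_transform)
node = build_dataset_node(
    source="HuggingFaceM4/the_cauldron",
    subset="ocrvqa",
    transform=transform,
    streaming=True,
    shuffle=True,
    seed=0,
)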
    dataset = dataset.shuffle(seed=seed)
    node = IterableWrapper(dataset)
else:
    sampler = DistributedSampler(
I'm confused, which of these nodes are specific to a single dataset vs global for the dataloader?
streaming: bool = False,
shuffle: bool = False,
seed: int = 0,
num_workers: int = 0,
are specific params such as seed, shuffle, num_workers, etc. for each individual dataset a valid use case? My understanding was that you'd specify these globally at the dataloader level
# Get global settings
shuffle = cfg_dataloader.shuffle
parallel_method = cfg_dataloader.get("parallel_method", "thread")
in general, I'm quite confused about which parameters belong in the "dataset" abstraction and which belong in the "dataloader" abstraction. As it is, it seems you are using them in both. I would prefer to make this distinction very clear, unless I'm missing something you may need to configure per dataset
Note! This requires torchdata nightly to be installed to work correctly.
Test multi-dataset training command:
tune run --nnodes=1 --nproc_per_node=8 lora_finetune_distributed_td --config llama3_2_vision/11B_lora_td dataloader.pin_memory=True use_torchdata=True dataloader.parallel_method=thread max_steps_per_epoch=1000 compile=True dataloader.num_workers=4
Test multi-dataset command with dataloader_only mode:
tune run --nnodes=1 --nproc_per_node=8 lora_finetune_distributed_td --config llama3_2_vision/11B_lora_td dataloader.pin_memory=True use_torchdata=True dataloader.parallel_method=thread max_steps_per_epoch=1000 profile_mode=dataloader_only compile=True dataloader.num_workers=4
Benchmarking on 8xA100
for 200 steps, full model training, batch_size: 4, gradient_accumulation_steps: 1
TWFB: Sum of Time Waiting For Batch (max across all ranks) divided by sum of step times
Single Datasets are run with OCRVQA
Multi-Dataset was run with:
ocrvqa, docvqa, dvqa, tabmwp with equal weighting for all datasets.
Multi-dataset runs much slower, guessing because one of the datasets (dvqa?) requires more padding than ocrvqa.
Launch commands:
0: tune run --nnodes=1 --nproc_per_node=8 lora_finetune_distributed_td --config llama3_2_vision/11B_lora_td use_torchdata=false max_steps_per_epoch=200 compile=True
1: (same as 2 but with 3 of the datasets removed in the config)
2: tune run --nnodes=1 --nproc_per_node=8 lora_finetune_distributed_td --config llama3_2_vision/11B_lora_td use_torchdata=true max_steps_per_epoch=200 compile=True
3: tune run --nnodes=1 --nproc_per_node=8 lora_finetune_distributed_td --config llama3_2_vision/11B_lora_td use_torchdata=true max_steps_per_epoch=200 compile=True dataloader.streaming=true
Please have a look at the code setup, and how this composability can help with streaming datasets and multi-dataset mixing. You can think of this as an approach to replace torch.utils.data.DataLoader while introducing more flexible parallelism schemes: e.g., instead of just one multiprocess worker setting, you could use multi-threading, pipeline parallelism, etc. It also enables more powerful composability IMO.
I have done some single-device benchmarking on a machine with an A100 40GB, both with and without the model; performance is on par with or better than the standard DataLoader.
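To illustrate the flexible-parallelism point, a hedged sketch mixing schemes per stage (the stage split, worker counts, and helper functions are made up; the node constructors are the ones used in this PR):

from torchdata.nodes import IterableWrapper, ParallelMapper, Prefetcher

# Decode media in threads (mostly I/O and GIL-releasing work), run the heavier
# sample preparation in separate processes, then prefetch ahead of the trainer.
node = IterableWrapper(raw_samples)  # raw_samples: any iterable of samples
node = ParallelMapper(node, map_fn=decode_image, num_workers=8, method="thread", in_order=True)
node = ParallelMapper(node, map_fn=prepare_sample, num_workers=4, method="process", in_order=True)
node = Prefetcher(node, prefetch_factor=4)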
TODO: fill in below
Context
What is the purpose of this PR? Is it to
Please link to any issues this PR addresses.
Changelog
What are the changes made in this PR?
*
Test plan
Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.
pre-commit install
pytest tests
pytest tests -m integration_test
UX
If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example
and a tutorial example