
[draft] torchdata integration #1929

Open
wants to merge 12 commits into main

Conversation


@andrewkho andrewkho commented Oct 30, 2024

Note! This requires torchdata nightly to be installed to work correctly.

Test multi-dataset training command:
tune run --nnodes=1 --nproc_per_node=8 lora_finetune_distributed_td --config llama3_2_vision/11B_lora_td dataloader.pin_memory=True use_torchdata=True dataloader.parallel_method=thread max_steps_per_epoch=1000 compile=True dataloader.num_workers=4

Test multi-dataset command with dataloader_only mode:
tune run --nnodes=1 --nproc_per_node=8 lora_finetune_distributed_td --config llama3_2_vision/11B_lora_td dataloader.pin_memory=True use_torchdata=True dataloader.parallel_method=thread max_steps_per_epoch=1000 profile_mode=dataloader_only compile=True dataloader.num_workers=4

Benchmarking on 8xA100

for 200 steps, full model training, batch_size: 4, gradient_accumulation_steps: 1
TWFB: sum of Time Waiting For Batch (max across all ranks) divided by sum of step times
Single-dataset trials were run with OCRVQA.
The multi-dataset trials were run with ocrvqa, docvqa, dvqa, and tabmwp, with equal weighting for all datasets.

Trial                                    TWFB %     Sum DL times    Sum step times
0  Single dataset (Baseline)             0.06279    14.596          232.454
1  Single dataset (TorchData threads)    0.00726    1.64746         226.780
2  Multi dataset (threads)               0.00514    2.4945          485.388
3  Multi dataset (Streaming)             0.01712    8.5445          499.0716

The multi-dataset runs are much slower; my guess is that one of the datasets (dvqa?) requires more padding than ocrvqa.

Launch commands:
0: tune run --nnodes=1 --nproc_per_node=8 lora_finetune_distributed_td --config llama3_2_vision/11B_lora_td use_torchdata=false max_steps_per_epoch=200 compile=True
1: (same as 2 but with 3 of the datasets removed in the config)
2: tune run --nnodes=1 --nproc_per_node=8 lora_finetune_distributed_td --config llama3_2_vision/11B_lora_td use_torchdata=true max_steps_per_epoch=200 compile=True
3: tune run --nnodes=1 --nproc_per_node=8 lora_finetune_distributed_td --config llama3_2_vision/11B_lora_td use_torchdata=true max_steps_per_epoch=200 compile=True dataloader.streaming=true

Please have a look at the code setup and how this composability can help with streaming datasets and multi-dataset mixing. You can think of this as a set of approaches to replace torch.utils.data.DataLoader while introducing more flexible parallelism schemes: e.g. instead of just one multiprocess worker setting, you could do multi-threading, pipeline parallelism, etc. It also enables more powerful composability IMO.
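
For a flavor of what the composition looks like, here is a rough sketch only, using the node names from this PR's recipe; sampler, ds, and collate_fn stand in for the recipe's own objects, and exact import paths and signatures may vary across torchdata nightlies:

# Rough sketch of the map-style pipeline in this PR; not a drop-in snippet.
from functools import partial

from torchdata.nodes import Batcher, IterableWrapper, ParallelMapper, PinMemory, Prefetcher

# Per-sample work runs in threads (or processes) instead of DataLoader workers
_Mapper = partial(ParallelMapper, num_workers=4, method="thread", in_order=True)

node = IterableWrapper(sampler)                    # indices from a DistributedSampler
node = _Mapper(node, map_fn=ds._data.__getitem__)  # fetch raw samples
node = _Mapper(node, map_fn=ds._prepare_sample)    # tokenize / apply transforms
node = Batcher(node, batch_size=4)
node = _Mapper(node, map_fn=collate_fn)            # pad and collate
node = PinMemory(node)
node = Prefetcher(node, 2)

dl_iter = iter(node)
batch = next(dl_iter)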

I have done some single-device benchmarking on a machine with an A100 40GB, both with and without the model; performance is on par with or better than the standard DataLoader.

TODO: fill in below

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

Please link to any issues this PR addresses.

Changelog

What are the changes made in this PR?
*

Test plan

Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
  • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

UX

If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example
and a tutorial example

  • I did not change any public API
  • I have added an example to docs or docstrings


pytorch-bot bot commented Oct 30, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1929

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

❌ 2 New Failures, 4 Cancelled Jobs

As of commit b1b2ab6 with merge base 4fb2464:

NEW FAILURES - The following jobs have failed:

  • Recipe Tests / recipe_test (3.11) (gh)
    ERROR: THESE PACKAGES DO NOT MATCH THE HASHES FROM THE REQUIREMENTS FILE. If you have updated the package versions, please update the hashes. Otherwise, examine the package contents carefully; someone may have tampered with them.
  • Unit Test / unit_tests (3.9) (gh)
    ERROR: Could not install packages due to an OSError: HTTPSConnectionPool(host='files.pythonhosted.org', port=443): Max retries exceeded with url: /packages/ed/a5/33cf000137545a08b0a3a6ea76c8ccbd87917f78bb5d737f9f56f3b11ef6/datasets-3.1.0-py3-none-any.whl.metadata (Caused by ResponseError('too many 502 error responses'))

CANCELLED JOBS - The following jobs were cancelled. Please retry:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed label Oct 30, 2024
Contributor

@RdoubleA RdoubleA left a comment

Thanks for this prototype, this is helpful to see! The way I'd imagine we'd expose the torchdata dataloader is through a builder function with a few knobs exposed, with reasonable defaults:

def build_dataloader(ds, num_workers, pin_memory, prefetch, in_memory, parallel_method, ...):
	# Put together all the nodes here

# In config
dataloader:
  num_workers:
  ...
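
Fleshing that out a bit (just a sketch to make the shape concrete; the knob names and defaults are illustrative, not a proposed final API):

# Illustrative sketch of the builder above; knob names and defaults are placeholders.
from functools import partial

from torchdata.nodes import Batcher, IterableWrapper, Mapper, ParallelMapper, PinMemory, Prefetcher


def build_dataloader(
    ds,
    collate_fn,
    *,
    batch_size: int = 4,
    num_workers: int = 0,
    parallel_method: str = "thread",
    pin_memory: bool = False,
    prefetch_factor: int = 2,
):
    # Fall back to an inline Mapper when no parallelism is requested
    _Mapper = (
        Mapper
        if num_workers == 0
        else partial(ParallelMapper, num_workers=num_workers, method=parallel_method, in_order=True)
    )
    # Wrap an iterable/streaming dataset; for map-style data, wrap a sampler
    # and map ds._data.__getitem__ first, as in the recipe in this PR.
    node = IterableWrapper(ds)
    node = _Mapper(node, map_fn=ds._prepare_sample)
    node = Batcher(node, batch_size=batch_size)
    node = _Mapper(node, map_fn=collate_fn)
    if pin_memory:
        node = PinMemory(node)
    return Prefetcher(node, prefetch_factor)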

For a power user, what might they want to tune to optimize performance for their hardware and model setup?

It's also not clear to me how some media transforms/decoding might get optimized; is that just handled by the torchdata nodes automatically?

node = IterableWrapper(sampler)
node = _Mapper(node, map_fn=ds._data.__getitem__)
# Cut here for Streaming/Iterable dataset instead =====
node = _Mapper(node, map_fn=ds._prepare_sample)
Contributor

is this where the transform would get parallelized?

Author

Yes, see

_Mapper = partial(
    ParallelMapper,
    num_workers=num_workers,
    method=parallel_method,
    in_order=True,
)

)
# Map style set up =======
node = IterableWrapper(sampler)
node = _Mapper(node, map_fn=ds._data.__getitem__)
Contributor

So does this mean we can keep our own Dataset abstractions? What if we went with entirely Iterable datasets?

Author

@andrewkho andrewkho Oct 30, 2024

Technically yes, but this was just a way to get it working quickly; we can also refactor this into a separate function. For an IterableDataset, you could wrap it in IterableWrapper and then call everything underneath here, i.e.

node = IterableWrapper(my_iterable_dataset)
node = _Mapper(node, map_fn=ds._prepare_sample)
...

batch = next(dl_iter)
dl_dt = time.perf_counter() - dl_t0
idx += 1
except StopIteration:
Contributor

Why not create a context manager that handles this?

Author

No good reason, hacky code is hacky :) Just wanted to see some rough numbers

# Map style set up =======
node = IterableWrapper(sampler)
node = _Mapper(node, map_fn=ds._data.__getitem__)
# Cut here for Streaming/Iterable dataset instead =====
Contributor

Have we tried this out on an HF dataset with streaming = True yet? (I assume it won't work out of the box yet?)

Author

see update :D almost out-of-the-box

Contributor

why do we need to wrap with IterableWrapper if the underlying dataset is an IterableDataset? Also, in our case the underlying dataset would be an HF dataset class

Author

The wrapper makes it conform to BaseNode's API so we can hold a pointer to the iterator and unify state management. Subclasses need to define .iterator() instead of .__iter__().
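
For reference, a rough sketch of the idea (this assumes BaseNode drives iteration through an .iterator() hook, per the description above; the real torchdata.nodes API has more state-management machinery than shown):

# Rough illustration only: adapting an arbitrary iterable (e.g. an HF IterableDataset)
# to the BaseNode-style API. The real torchdata.nodes.IterableWrapper also unifies
# checkpoint/state handling, which is omitted here.
from typing import Iterable, Iterator

from torchdata.nodes import BaseNode


class SimpleIterableWrapper(BaseNode):
    def __init__(self, source: Iterable):
        super().__init__()
        self.source = source

    def iterator(self) -> Iterator:
        # BaseNode holds on to the iterator returned here instead of callers
        # invoking __iter__ directly, so state can be tracked in one place.
        return iter(self.source)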

Contributor

@ebsmothers ebsmothers left a comment

Thanks for creating this PR! I like the IterableWrapper/Batcher/etc. APIs; they look pretty clean (and I agree with @RdoubleA's suggestion about exposing this in builders/configs). How will this work when training on multiple datasets? Will we just apply the same set of APIs to each sampler individually?

@andrewkho
Author

Thanks for the comments y'all, I updated this with a streaming example. Test with:

tune run lora_finetune_single_device --config llama3_2_vision/11B_lora_single_device num_workers=2 pin_memory=False use_torchdata=True parallel_method=thread max_steps_per_epoch=50 dataset.streaming=True

@andrewkho
Author

@RdoubleA

For a power user, what might they want to tune to optimize performance for their hardware and model setup?

Good question. At the minimum it'd be some global worker setting, but one option that may be worth supporting is allowing power users to define their pipelines entirely in config; not sure if you think this is a bad idea. E.g. by default, use a builder with a few knobs, but also allow the entire dataloader definition to be composable, somewhat similar to how you enable users to pass a list of datasets right now. Something similar could be done for mixing. We'd need to be thoughtful about what the useful small atomic units would be, and figure out syntax sugar.

It's also not clear to me how some media transforms/decoding might get optimized, is that just handled by the torchdata nodes automatically?

In terms of optimizing, one thing we're looking at is tf.data's autotune approach, which automatically updates prefetch buffers and worker counts. This is not implemented yet, but I'm hoping to do something along these lines, which will help with the too-many-tunable-parameters problem.

@@ -235,3 +243,73 @@ def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
if packed:
raise ValueError("Multimodal datasets don't support packing yet.")
return ds


def the_cauldron_dataset_torchdata(
Contributor

How much of this could be pushed into the SFT class so that we can just reuse it for any new datasets?

Author

A lot of it is probably re-usable. I think it should be its own class or at least a builder, maybe something like class SFTDatasetNode in _sft.py, or something less terrible sounding.

Author

Actually, 90% of this builder func could probably live in _sft.py since it doesn't have anything to do with The Cauldron.

@@ -0,0 +1,1119 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
Author

@andrewkho andrewkho Nov 12, 2024

Copied and modified from lora_finetune_distributed

@@ -0,0 +1,131 @@
# Config for multi-device LoRA finetuning in lora_finetune_distributed_td.py
Author

Copied and modified from 11B_lora

@@ -117,7 +128,17 @@ def __getitem__(self, index: int) -> Dict[str, Any]:
sample = self._data[index]
return self._prepare_sample(sample)

def _prepare_sample(self, sample: Mapping[str, Any]) -> Dict[str, Any]:

class PrepareSample:
Author

Refactoring this into its own callable class that can be used by both the current torch.utils.data.Dataset path and the torchdata.nodes path.
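
Roughly this shape (a hypothetical sketch; the transform names here are illustrative, and the real class may carry more state):

# Hypothetical sketch of the PrepareSample refactor described above: the per-sample
# logic becomes a standalone callable usable both from Dataset.__getitem__ and as a
# map_fn for a torchdata Mapper/ParallelMapper node.
from typing import Any, Callable, Dict, Mapping


class PrepareSample:
    def __init__(
        self,
        message_transform: Callable[[Mapping[str, Any]], Any],
        model_transform: Callable[[Any], Dict[str, Any]],
    ):
        self._message_transform = message_transform
        self._model_transform = model_transform

    def __call__(self, sample: Mapping[str, Any]) -> Dict[str, Any]:
        transformed = self._message_transform(sample)
        return self._model_transform(transformed)

Both the existing __getitem__ path and the torchdata path could then call the same object, e.g. _Mapper(node, map_fn=prepare_sample).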

Contributor

@felipemello1 felipemello1 left a comment

hey Andrew, thank you so much for this PR! This is such a nice feature to have.

I did a first pass. I understand that it's still a draft, but I thought I'd make comments now to save you some time when the PR is closer to being ready.

Comment on lines 116 to 117
rank=int(os.environ.get("RANK", 0)),
world_size=int(os.environ.get("WORLD_SIZE", 1)),
Contributor

should we use our utility instead?

def get_world_size_and_rank() -> Tuple[int, int]:

Author

Needs to be removed

global_streaming=streaming,
global_num_workers=num_workers,
),
prefetch_factor=8,
Contributor

@felipemello1 felipemello1 Nov 13, 2024

would it make sense to add it to setup_data as a default or expose it in the config? Or is this parameter not commonly changed?

Author

Updated PR: this is now exposed

Author

Wait, why isn't this showing the new version? Let me make sure I've pushed


return ds

def _setup_data_td(
Contributor

@felipemello1 felipemello1 Nov 13, 2024

This would require a bit of refactoring, and it's probably not the main point of this PR, but it would be nice to have _setup_dataset and _setup_dataloader as two different methods. It would be easier to read and maintain.

Author

I've split this into setting up individual datasets (_setup_one_dataset) and the global dataloader/mixer setup.


# Instantiate collate_fn
if "left_pad_sequence" in collate_fn:
raise RuntimeError("left_pad_sequence collator is only for inference.")
Contributor

I understand that this was already in the setup_data, but we usually try to fail fast and catch errors like this in the init.

Comment on lines 113 to 114
if load_dataset_kwargs.get("streaming", False):
self._data = split_dataset_by_node(
Contributor

I think that if we can decouple dataloader from dataset, it will be easier to maintain/work with it. For example, can we do something like:

MyDatasetDistributed = split_dataset_by_node(MyDataset)

def split_dataset_by_node(...):
	assert hasattr(MyDataset, self._data)

or maybe SFTDataset can have getter

Author

This needs to be removed actually

Comment on lines 690 to 696
if len(cfg_datasets) == 1:
node = next(iter(datasets.values())) # TODO: multi dataset
else:
node = MultiDatasetWeightedSampler(
source_nodes=datasets,
weights=weights,
)
Contributor

@felipemello1 felipemello1 Nov 13, 2024

n00b question: I understand that this is the torchdata API, but I wonder if the 'len(cfg_datasets) == 1' check should be inside MultiDatasetWeightedSampler. I.e., do we need this if check?


log.info("TorchData nodes are initialized")

return node
Contributor

@felipemello1 felipemello1 Nov 13, 2024

'node' feels a bit weird, since we do: self._dataloader = self._setup_dataloader(...)

Should we rename it to dataloader?


# TODO: add multi-dataset mixer
if num_workers == 0:
_Mapper = Mapper
Contributor

I am a bit torn about whether getting the mapper, sampler, pin memory, etc. should be a utility shared across all recipes, or whether it should be exposed. No strong opinion, just thinking out loud.

if pin_memory:
node = PinMemory(node)
if num_workers > 0:
node = Prefetcher(node, 2)
Contributor

is this Prefetcher(node, 2) different from the previous prefetch_factor=8?

Comment on lines 645 to 647
All data related setup happens here. Currently this recipe only supports
Map-style Datasets which fit into memory and an option for random shuffling.
Samplers, iterable datasets, and streaming datasets are not supported.
Contributor

i guess this needs to be updated

prefetch_factor: 2
seed: null

multi_datasets:
Contributor

personally this is too verbose to parse for me, and even in the recipe there are just too many nested dictionaries. Ideally, I would like to achieve this type of UI in the config for datasets:

datasets:
  - _component_: torchtune.datasets...
    weight: 1.0
    subset: ...
  - _component_: torchtune.datasets...
    weight: 1.0
    ...

or something similar, so all I have to do is specify the dataset I want and the weight. As it is, I have multi_datasets -> datasets -> dataset just to specify the dataset builder. Maybe this is overly idealistic, but other libraries such as Axolotl are able to do this.

I am aware there are a few challenges to having this:

  1. MultiDatasetSampler requires passing in dictionaries for datasets and weights
  2. weight is not a valid argument for instantiating dataset components

I'm wondering if there's a way we can do this all for the user in a builder. For example:

datasets: ListConfig = cfg.datasets
weights, dataset_nodes = {}, {}
for k, cfg_dataset in enumerate(datasets):
    weights[k] = cfg_dataset.pop("weight")
    dataset_nodes[k] = Prefetcher(config.instantiate(cfg_dataset), prefetch_factor)
dataloader = get_multi_datasets(dataset_nodes, weights, cfg_dataloader)

stop_criterion imo should be moved to the dataloader config

)
weights, datasets = {}, {}
cfg_datasets = cfg_multi_datasets.datasets
for k, cfg_and_weight in cfg_datasets.items():
Contributor

all of this logic needs to be in the builder. We do not want to expose torchdata internals in each recipe. I am totally okay with creating a new file in torchtune/data that contains functions that set up torchdata dataloaders and nodes.

Another ideal I'm curious whether we can achieve: can we unify the UX for multi datasets and single datasets? I.e., if we had a get_dataloader method, you could pass in a single dataset or multiple datasets and the call in the recipe would be the same regardless of what the user specifies in the config.
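
For example, something like this might work (sketch only; get_dataloader is an illustrative name and its import paths are assumed, while MultiDatasetWeightedSampler is the mixing node used elsewhere in this PR):

# Sketch only: a single entry point for both single- and multi-dataset configs.
from typing import Any, Dict

from torchdata.nodes import Batcher, PinMemory, Prefetcher

# Assumed import: MultiDatasetWeightedSampler is the weighted mixing node used in
# this PR; its eventual import path is a guess here.
from torchtune.data import MultiDatasetWeightedSampler


def get_dataloader(dataset_nodes: Dict[str, Any], weights: Dict[str, float], cfg_dataloader):
    if len(dataset_nodes) == 1:
        node = next(iter(dataset_nodes.values()))
    else:
        node = MultiDatasetWeightedSampler(source_nodes=dataset_nodes, weights=weights)
    # Dataloader-level settings stay global, so recipes never touch node internals
    node = Batcher(node, batch_size=cfg_dataloader.batch_size)
    if cfg_dataloader.get("pin_memory", False):
        node = PinMemory(node)
    return Prefetcher(node, cfg_dataloader.get("prefetch_factor", 2))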

),
)

return node
Contributor

this returns a node, not a dataloader, right? Are users still able to access the underlying Hugging Face data?



@requires_torchdata
def SFTDatasetNode( # noqa[N802]
Contributor

seems we're moving away from the class-based dataset abstraction toward a function that returns the node configured with user parameters.

Curious if it would be better UX to just abandon the SFTDataset class (after the full migration) and keep Transform classes for each kind of dataset (Instruct, Chat, MM, SFT, Preference), which are passed into a generic node builder.

dataset = dataset.shuffle(seed=seed)
node = IterableWrapper(dataset)
else:
sampler = DistributedSampler(
Contributor

I'm confused: which of these nodes are specific to a single dataset vs. global for the dataloader?

streaming: bool = False,
shuffle: bool = False,
seed: int = 0,
num_workers: int = 0,
Contributor

are specific params such as seed, shuffle, num_workers, etc. for each individual dataset a valid use case? My understanding was that you can specify these globally at the dataloader level.


# Get global settings
shuffle = cfg_dataloader.shuffle
parallel_method = cfg_dataloader.get("parallel_method", "thread")
Contributor

in general, I'm quite confused about which parameters belong in the "dataset" abstraction and which belong in the "dataloader" abstraction. As it is, it seems you are using them in both. I would prefer to make this distinction very clear, unless I am missing something that needs to be configured per dataset.

Labels
CLA Signed: This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.

6 participants