
Multiple GPU low performance #1734

Open
jetstudio-io opened this issue Oct 1, 2024 · 3 comments

Labels
question Further information is requested

Comments

@jetstudio-io

Hello,
I have an issue with multi-GPU performance.

  • Using the recipe lora_finetune_single_device with the config mini_lora_single_device.yaml on one 6000 Ada, I get ~5 it/s.
  • Using the recipe lora_finetune_distributed with the config mini_lora.yaml on 2 x 6000 Ada, I get 1.5 s/it.

The dataset I am fine-tuning on is HuggingFaceFW/fineweb-edu-score-2.
How can I improve performance on multiple GPUs?
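For reference, the two runs were launched roughly as follows (exact flags approximate):

tune run lora_finetune_single_device --config phi3/mini_lora_single_device

tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config phi3/mini_lora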
@RdoubleA
Contributor

RdoubleA commented Oct 1, 2024

Hi @jetstudio-io, thanks for the question. The it/s or sec/it metric is not a great indicator of performance here. Instead, I would check the logs for tokens per second to do a better comparison. For example:

$ cat /tmp/full-llama3.2-finetune/log_1727815865.txt
Step 1 | loss:2.7667150497436523 lr:2e-05 tokens_per_second_per_gpu:766.5005627443386
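To pull that metric for every step out of the text log, a plain grep works (same log path as above):

$ grep tokens_per_second_per_gpu /tmp/full-llama3.2-finetune/log_1727815865.txt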

Or you can track these metrics over time if you log with WandB (and set log_peak_memory_stats=True in your launch command to also capture memory usage).
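A rough config sketch for that, assuming the standard torchtune recipe config layout (the project name is just a placeholder):

metric_logger:
  _component_: torchtune.training.metric_logging.WandBLogger
  project: phi3-mini-lora   # placeholder project name
log_every_n_steps: 1
log_peak_memory_stats: True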

Many factors can impact raw seconds per iteration, gradient accumulation in particular, and it is not necessarily indicative of how quickly training converges. That being said, there are still other ways to improve performance. You can check our documentation page on the memory/performance features you can enable for some ideas (cc @felipemello1): https://pytorch.org/torchtune/main/tutorials/memory_optimizations.html.

A very direct way to improve throughput is to enable packing in your dataset. If you are using the torchtune dataset builder functions, you can simply pass packed=True in your config or launch command.
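For example, a rough config sketch, assuming the dataset is loaded with torchtune's text_completion_dataset builder (packing also needs tokenizer.max_seq_len to be set):

dataset:
  _component_: torchtune.datasets.text_completion_dataset
  source: HuggingFaceFW/fineweb-edu-score-2
  packed: True
tokenizer:
  max_seq_len: 2048   # required for packing; the rest of the tokenizer config stays as-is

Or, equivalently, pass dataset.packed=True tokenizer.max_seq_len=2048 on the command line.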

@felipemello1
Contributor

felipemello1 commented Oct 1, 2024

As @RdoubleA said, the config has "gradient_accumulation_steps: 16", which means that one step is actually 16 forward/backward passes.

Maybe try the following:

tune run lora_finetune_single_device --config phi3/mini_lora_single_device \
compile=True \
dataset.packed=True \
tokenizer.max_seq_len=2048 \
batch_size=4 \
gradient_accumulation_steps=2 \
enable_activation_checkpointing=False \
log_every_n_steps=1 \
metric_logger._component_=torchtune.training.metric_logging.WandBLogger \
log_peak_memory_stats=True

If you are running out of memory, set enable_activation_checkpointing=True.
Otherwise, increase batch_size.

You can see your memory usage on the Weights & Biases website.

Also, use the torchtune/PyTorch nightlies for maximum performance: https://github.com/pytorch/torchtune#install-nightly-release

@jetstudio-io
Author

Thanks for your advice, I'll try measuring tokens/s.

@joecummings added the question (Further information is requested) label on Oct 2, 2024