
Optimize LLaMA for inference #513

Merged: 4 commits merged into main on Nov 14, 2023

Conversation

@mryab (Member) commented Sep 18, 2023

Similar to #500, this PR aims to speed up Llama models by making the following optimizations relative to the original Transformers implementation:

  • Position indices are not generated, because rotary embeddings do not need them: the prefix length and the number of newly encoded tokens are sufficient (see the sketch after this list)
  • All operations before and after the attention layer (i.e., RMS normalization and the MLP) are fused into a single CUDA graph
  • Similarly, the rotary positional embedding function is fused into a CUDA graph
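
For illustration, here is a minimal sketch (not the petals code) of the first point, assuming a precomputed rotary cos/sin cache of shape [max_seq_len, head_dim]: the slice for the new tokens is fully determined by the prefix length and the number of tokens being encoded, so no position_ids tensor has to be built.

    import torch

    def rotary_cos_sin(cos_cache: torch.Tensor, sin_cache: torch.Tensor,
                       prefix_length: int, num_new_tokens: int):
        # New tokens occupy positions prefix_length .. prefix_length + num_new_tokens - 1,
        # so a contiguous slice of the cache replaces an index_select over explicit
        # position indices. Function name and cache layout are assumptions for this sketch.
        end = prefix_length + num_new_tokens
        return cos_cache[prefix_length:end], sin_cache[prefix_length:end]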

Additionally, this PR introduces a petals.utils.cuda_graphs.make_inference_graphed_callable function that converts any inference-mode callable into its CUDA-graph version. It is meant as an alternative to torch.cuda.make_graphed_callables that does not attempt to build a graph for the backward pass: inference runs under inference_mode, so the original function fails there (this is also why the Falcon PR used custom graph tracing).
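
As a rough illustration of how such an inference-only wrapper can be built, here is a minimal sketch following the standard CUDA-graph capture recipe (warm-up on a side stream, capture into static buffers, replay with copied-in inputs). It is not the actual petals implementation; the warm-up count and tensor-only argument handling are assumptions.

    import torch

    def make_inference_graphed_callable_sketch(fn, sample_args):
        static_inputs = tuple(arg.clone() for arg in sample_args)

        # Warm up on a side stream so lazy kernel/allocator work is not captured.
        side_stream = torch.cuda.Stream()
        side_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(side_stream), torch.inference_mode():
            for _ in range(3):
                fn(*static_inputs)
        torch.cuda.current_stream().wait_stream(side_stream)

        # Capture a single invocation into a CUDA graph using the static buffers.
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph), torch.inference_mode():
            static_outputs = fn(*static_inputs)

        def replay(*args):
            # Copy fresh inputs into the captured buffers and replay the graph;
            # results come back in the same static output buffers.
            for static_arg, arg in zip(static_inputs, args):
                static_arg.copy_(arg)
            graph.replay()
            return static_outputs

        return replay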

borzunov changed the title from "[WIP] Optimize LLaMa for inference" to "[WIP] Optimize Llama for inference" on Sep 19, 2023
@borzunov (Collaborator) commented Sep 20, 2023

Benchmarks: this PR improves inference speed by +44%

Model: Stable Beluga 2 (70B)
GPU: NVIDIA RTX 6000 Ada Generation

main @ a2484b3:

Sep 20 03:08:50.845 [INFO] Inference throughput: 750.6 tokens/sec per block (1 tokens/batch, NVIDIA RTX 6000 Ada Generation GPU, bfloat16, quantized to nf4)                      
Sep 20 03:09:04.064 [INFO] Forward pass throughput: 48486.8 tokens/sec per block (1024 tokens/batch, NVIDIA RTX 6000 Ada Generation GPU, bfloat16, quantized to nf4)

optimize_llama @ f332b0e:

Sep 20 03:10:13.415 [INFO] Inference throughput: 1078.9 tokens/sec per block (1 tokens/batch, NVIDIA RTX 6000 Ada Generation GPU, bfloat16, quantized to nf4)
Sep 20 03:10:26.583 [INFO] Forward pass throughput: 48003.5 tokens/sec per block (1024 tokens/batch, NVIDIA RTX 6000 Ada Generation GPU, bfloat16, quantized to nf4)

mryab requested a review from borzunov on Sep 20, 2023
mryab marked this pull request as ready for review on Sep 20, 2023
mryab changed the title from "[WIP] Optimize Llama for inference" to "Optimize Llama for inference" on Sep 20, 2023
mryab changed the title from "Optimize Llama for inference" to "Optimize LLaMA for inference" on Oct 8, 2023
justheuristic merged commit 03cbe90 into main on Nov 14, 2023 (11 checks passed)
justheuristic deleted the optimize_llama branch on Nov 14, 2023
@poedator (Collaborator) commented:
While testing with TinyLlama for an unrelated reason, I caught the error "too many indices for tensor of dimension 2".
It happened on this line: cos = cos[:, :, kv_seq_len - q_len :]
https://github.com/bigscience-workshop/petals/pull/513/files#diff-492af4f870c9613ff6b5fce973ddd1d75bf135b30f40a7cb83f455c4f0e72ea6R87
Env: Transformers 4.35.2
Test run for reference: https://github.com/bigscience-workshop/petals/actions/runs/6950529337/job/18910867509?pr=545 (see line 2755)
@mryab ?
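
For context, a hedged, self-contained illustration of what this error is consistent with (the exact Transformers 4.35 cos/sin layout is an assumption here, not something stated in this thread): older versions cached rotary cos/sin as 4-D [1, 1, seq_len, head_dim] tensors, while a 2-D [seq_len, head_dim] return value would make the slice above fail in exactly this way.

    import torch

    kv_seq_len, q_len, head_dim = 8, 1, 64
    cos = torch.ones(kv_seq_len, head_dim)      # assumed 2-D layout from newer Transformers
    try:
        cos[:, :, kv_seq_len - q_len :]         # the PR's slice, written for the 4-D layout
    except IndexError as e:
        print(e)                                # "too many indices for tensor of dimension 2"
    cos = cos[..., kv_seq_len - q_len :, :]     # rank-agnostic slice works for both layouts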
