Optimize LLaMA for inference #513
Conversation
Benchmarks: this PR gives +44% to inference speed.
Model: Stable Beluga 2 (70B)
main @ a2484b3:
optimize_llama @ f332b0e:
When testing with TinyLlama for some unrelated thing, I caught an error:
Similarly to #500, this PR aims to speed up Llama models by making the following optimizations compared to the original Transformers implementation:
Additionally, this PR introduces a `petals.utils.cuda_graphs.make_inference_graphed_callable` function that converts any inference-mode callable into its CUDA graph version. It is meant as an alternative to `torch.cuda.make_graphed_callables` that does not attempt to build a graph for the backward pass: inference runs under `inference_mode`, so the original function fails there (which is also why the Falcon PR used custom graph tracing).
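For reference, here is a minimal sketch of how such a wrapper can be built on top of `torch.cuda.CUDAGraph`. The exact signature, warm-up count, and static-buffer handling below are illustrative assumptions, not necessarily the PR's implementation:

```python
import torch

def make_inference_graphed_callable(fn, sample_args, num_warmup=3):
    """Capture fn(*sample_args) into a CUDA graph and return a replay wrapper.

    Sketch only: assumes all args/outputs are CUDA tensors with static shapes.
    No backward graph is built, so this works under torch.inference_mode().
    """
    # Warm up on a side stream so capture starts from a clean state
    side_stream = torch.cuda.Stream()
    side_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side_stream), torch.inference_mode():
        for _ in range(num_warmup):
            fn(*sample_args)
    torch.cuda.current_stream().wait_stream(side_stream)

    # Static placeholders: the captured graph always reads these exact buffers
    static_inputs = tuple(arg.clone() for arg in sample_args)
    graph = torch.cuda.CUDAGraph()
    with torch.inference_mode(), torch.cuda.graph(graph):
        static_outputs = fn(*static_inputs)

    def replay(*args):
        # Copy fresh inputs into the captured buffers, replay the recorded
        # kernels, and return the (reused) output buffers
        for buf, arg in zip(static_inputs, args):
            buf.copy_(arg)
        graph.replay()
        return static_outputs

    return replay
```

The key point is that the graph always reads from and writes to the same static buffers, so each call only copies the new inputs in and replays the captured kernels; nothing autograd-related is ever touched, which is what lets this run under `inference_mode` where `torch.cuda.make_graphed_callables` would fail.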