Is fwd_kvcache compatible with torch.compile in 2.7.2.post1? #1386
Comments
Sure, would love to see a PR fixing this.
@ani300 in case you know how to fix this.
By the way: there is a slight speed regression for inference with kvcache between 2.5.9.post1 and 2.6.1.
Can you send a short script to reproduce the speed regression? E.g. with this input, 2.5.9.post1 gets XXX seconds and 2.6.1 gets YYY seconds.
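[Editor's note: a minimal sketch of the kind of repro script being asked for here, not taken from the thread. The shapes, dtype, and iteration count are illustrative assumptions; it times the public `flash_attn_with_kvcache` wrapper, which calls `fwd_kvcache` underneath.]

```python
import time

import torch
from flash_attn import flash_attn_with_kvcache

batch, nheads, headdim = 8, 32, 128
cache_len = 4096
device, dtype = "cuda", torch.float16

# Single-token decode step against a half-filled KV cache (illustrative shapes).
q = torch.randn(batch, 1, nheads, headdim, device=device, dtype=dtype)
k_cache = torch.randn(batch, cache_len, nheads, headdim, device=device, dtype=dtype)
v_cache = torch.randn(batch, cache_len, nheads, headdim, device=device, dtype=dtype)
cache_seqlens = torch.full((batch,), cache_len // 2, dtype=torch.int32, device=device)

# Warmup, then time a fixed number of decode calls.
for _ in range(10):
    flash_attn_with_kvcache(q, k_cache, v_cache, cache_seqlens=cache_seqlens)
torch.cuda.synchronize()

t0 = time.perf_counter()
for _ in range(10_000):
    flash_attn_with_kvcache(q, k_cache, v_cache, cache_seqlens=cache_seqlens)
torch.cuda.synchronize()
print(f"{time.perf_counter() - t0:.4f} sec")  # run once per installed flash-attn version
```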
@vince62s I probably forgot to add the torch.compile() wrapping for this function when I did the rest. I can probably take a stab at it later in the week, as I'm wrapped up with a work deadline until Wednesday.
I am lazy so I did not recompile 2.6.1, which takes too long to compile, but 2.6.1 and 2.7.2.post1 are similar in speed.
With 2.5.9.post1: 28.4653 sec. That's 3%, but my real-world use case says 4% (maybe because I also use rotary cos/sin).
Getting this warning and then many subsequent recompiles because I'm using dynamic shapes (and dynamic=True in torch.compile):
```
/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py:725: UserWarning: Graph break due to unsupported builtin flash_attn_2_cuda.PyCapsule.fwd_kvcache. This function is either a Python builtin (e.g. _warnings.warn) or a third-party C/C++ Python extension (perhaps created with pybind). If it is a Python builtin, please file an issue on GitHub so the PyTorch team can add support for it and see the next case for a workaround. If it is a third-party C/C++ Python extension, please either wrap it into a PyTorch-understood custom operator (see https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html for more details) or, if it is traceable, use torch.compiler.allow_in_graph.
```
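[Editor's note: a minimal sketch of the custom-op wrapping this warning points at, not the library's actual fix. It assumes the public `flash_attn_with_kvcache` wrapper and a read-only decode step (no new k/v appended, so the caches aren't mutated); the op name `mylib::fa2_kvcache` and all shapes are illustrative.]

```python
import torch
from flash_attn import flash_attn_with_kvcache

# Registering the call as a custom op lets dynamo keep it in the graph
# instead of graph-breaking on the pybind builtin fwd_kvcache.
@torch.library.custom_op("mylib::fa2_kvcache", mutates_args=())
def fa2_kvcache(q: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor,
                cache_seqlens: torch.Tensor) -> torch.Tensor:
    return flash_attn_with_kvcache(q, k_cache, v_cache, cache_seqlens=cache_seqlens)

@fa2_kvcache.register_fake
def _(q, k_cache, v_cache, cache_seqlens):
    # Shape/dtype propagation so dynamo can trace without running the kernel:
    # the decode output has the same shape as q.
    return torch.empty_like(q)

def decode_step(q, k_cache, v_cache, cache_seqlens):
    return fa2_kvcache(q, k_cache, v_cache, cache_seqlens)

if __name__ == "__main__":
    b, h, d, lc = 2, 8, 64, 256
    q = torch.randn(b, 1, h, d, device="cuda", dtype=torch.float16)
    k_cache = torch.randn(b, lc, h, d, device="cuda", dtype=torch.float16)
    v_cache = torch.randn(b, lc, h, d, device="cuda", dtype=torch.float16)
    seqlens = torch.full((b,), lc, dtype=torch.int32, device="cuda")
    # Compiles without the graph break, since dynamo now sees a custom op.
    out = torch.compile(decode_step, dynamic=True)(q, k_cache, v_cache, seqlens)
```

The warning also names `torch.compiler.allow_in_graph` as a lighter-weight alternative when the function is traceable.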