
✨[Feature] sampling within the model #3459

Open

jjh42 opened this issue Apr 1, 2025 · 3 comments

Labels
feature request New feature or request

Comments

@jjh42

jjh42 commented Apr 1, 2025

I am exporting a model that uses torch.distributions.Categorical(...).sample() to sample from the logits.

I currently have a (fixed-length) loop within a torch.compile graph that samples from the logits to choose an output and feeds that back in as the next input (a standard auto-regressive model).
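
Roughly, a stripped-down version looks like this (a minimal sketch; the layer names and sizes are placeholders, not my actual model):

```python
import torch
from torch.distributions import Categorical

class ToyAutoregressive(torch.nn.Module):
    # Fixed-length autoregressive loop: sample the next token from the
    # logits and feed it back in as the next input.
    def __init__(self, vocab_size=32, hidden=64, steps=8):
        super().__init__()
        self.embed = torch.nn.Embedding(vocab_size, hidden)
        self.proj = torch.nn.Linear(hidden, vocab_size)
        self.steps = steps

    def forward(self, token):
        outputs = []
        for _ in range(self.steps):
            logits = self.proj(self.embed(token))         # [batch, vocab]
            token = Categorical(logits=logits).sample()   # lowers to aten::multinomial
            outputs.append(token)
        return torch.stack(outputs, dim=1)
```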

I see the examples in this repo of gpt2 etc. all use greedy sampling (i.e. they're not stochastic), and trying to export my model gives an error:

raise UnsupportedOperatorException(
torch_tensorrt.dynamo.conversion._TRTInterpreter.UnsupportedOperatorException: Conversion of function torch._ops.aten.aten::multinomial not currently supported!

Is there any workaround, or is sampling not currently possible in TensorRT? I know you can sample outside the model, but in my case it is much better encapsulated to have the sampling inside the model.

This can be considered a feature request to support multinomial in torch_tensorrt, I guess.

jjh42 added the feature request label on Apr 1, 2025
@narendasan
Collaborator

Can you provide a reproducer of this issue? The simplest way is probably to have the sampling in a PyTorch block, since I'm not sure TRT can handle it. What is odd here is that you are getting past capability partitioning.

You can try doing torch_executed_ops=[torch.ops.aten.multinomial]
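
i.e. something along these lines (untested sketch; the input shape/dtype and the exact spelling of the op name are placeholders you'd adapt to your model):

```python
import torch
import torch_tensorrt

model = ToyAutoregressive().eval().cuda()  # the module with the in-graph sampling

# Keep aten::multinomial running in PyTorch; TensorRT handles the rest of the graph.
trt_model = torch_tensorrt.compile(
    model,
    ir="dynamo",
    inputs=[torch_tensorrt.Input(shape=(1,), dtype=torch.int64)],
    torch_executed_ops={torch.ops.aten.multinomial.default},
)
```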

@jjh42
Author

jjh42 commented Apr 2, 2025

Thanks, will give a repro soon; I think it's really anything with a Categorical.

If you use torch_executed_ops, then you won't be able to run the model with the TensorRT C++ runtime?

@narendasan
Collaborator

You can trace with torch.jit.trace and still use it with the libtorchtrt_runtime.
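
For example, roughly (untested sketch; the input spec and op-name spelling are placeholders):

```python
import torch
import torch_tensorrt

# Trace first, then compile through the TorchScript front end; the saved module
# should only need libtorchtrt_runtime (plus libtorch) to execute from C++.
example = torch.zeros(1, dtype=torch.int64).cuda()
traced = torch.jit.trace(model, example)
trt_ts = torch_tensorrt.compile(
    traced,
    ir="ts",
    inputs=[torch_tensorrt.Input(shape=(1,), dtype=torch.int64)],
    torch_executed_ops=["aten::multinomial"],
)
torch.jit.save(trt_ts, "model_trt.ts")
```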
