
✨[Feature] sampling within the model #3459

Open

@jjh42

Description

I'm exporting a model that uses torch.distributions.Categorical(logits=...).sample() to sample from the logits.

I currently have a fixed-length loop inside a torch.compile graph that samples from the logits to choose an output and feeds it back in as the next input (a standard autoregressive model), roughly like the sketch below.
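For concreteness, a minimal hypothetical sketch of that setup (ARSampler, backbone, and num_steps are illustrative names, not my actual code):

```python
import torch
from torch.distributions import Categorical

class ARSampler(torch.nn.Module):
    # Hypothetical stand-in for the model described above: a fixed-length
    # autoregressive loop that samples each next token from the logits.
    def __init__(self, backbone: torch.nn.Module, num_steps: int):
        super().__init__()
        self.backbone = backbone
        self.num_steps = num_steps

    def forward(self, token: torch.Tensor) -> torch.Tensor:
        outputs = []
        for _ in range(self.num_steps):  # fixed length, so the loop unrolls
            logits = self.backbone(token)
            # Categorical(...).sample() lowers to aten::multinomial,
            # which is the op the TRT converter rejects below.
            token = Categorical(logits=logits).sample()
            outputs.append(token)
        return torch.stack(outputs, dim=-1)
```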

I see that the examples in this repo (GPT-2, etc.) all use greedy decoding (i.e., they are not stochastic), and trying to export my model gives an error:

```
raise UnsupportedOperatorException(
torch_tensorrt.dynamo.conversion._TRTInterpreter.UnsupportedOperatorException: Conversion of function torch._ops.aten.aten::multinomial not currently supported!
```

Is there any workaround, or is sampling not currently possible in TensorRT? I know you can sample outside the model, but in my case it is much better encapsulated to have the sampling inside the model.

This can be considered a feature request to support multinomial in torch_tensorrt, I guess.
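One possible workaround might be the Gumbel-max trick, which replaces multinomial sampling with elementwise ops plus argmax. A minimal sketch, assuming the in-graph noise generation (rand_like) converts; if it doesn't, the uniform noise could instead be passed in as an extra model input:

```python
import torch

def gumbel_max_sample(logits: torch.Tensor) -> torch.Tensor:
    # Sampling from Categorical(logits=logits) is distributionally
    # equivalent to argmax(logits + Gumbel noise), which avoids
    # aten::multinomial entirely.
    u = torch.rand_like(logits)  # Uniform(0, 1); may need to be an input instead
    gumbel = -torch.log(-torch.log(u + 1e-20) + 1e-20)  # Gumbel(0, 1) noise
    return torch.argmax(logits + gumbel, dim=-1)
```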
