🐛 [Bug] Encountered bug when using Torch-TensorRT #3135

Open
Jeevi10 opened this issue Aug 30, 2024 · 4 comments
Labels
bug Something isn't working


@Jeevi10

Jeevi10 commented Aug 30, 2024

Bug Description

I started from the example at https://github.com/pytorch/TensorRT/blob/main/examples/dynamo/mutable_torchtrt_module_example.py and replaced its model with the Hugging Face Whisper model (distil-whisper/distil-large-v3) instead of the diffusion model.

To Reproduce

```python
import numpy as np
import torch
import torch_tensorrt as torch_trt
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
# from datasets import load_dataset
from peft import PeftModel
# import torchvision.models as models

np.random.seed(5)
torch.manual_seed(5)
# Dummy input kept from the original example (image-shaped: 1x3x224x224)
inputs = [torch.rand((1, 3, 224, 224)).to("cuda")]

# %%
# Initialize the Mutable Torch TensorRT Module with settings.
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16

model_id = "distil-whisper/distil-large-v3"

with torch.no_grad():
    settings = {
        "use_python_runtime": True,
        "enabled_precisions": {torch.float16},
        "debug": True,
        "make_refitable": True,
    }

    base_model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_id,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
    )
    base_model.eval().to(device)

    # model = models.resnet18(pretrained=True).eval().to("cuda")
    mutable_module = torch_trt.MutableTorchTensorRTModule(base_model, **settings)

    # Compilation (torch.export, then TensorRT) happens on the first call;
    # this is the call that fails in the traceback below.
    mutable_module(*inputs)
```

Steps to reproduce the behavior:

1. Run the code above. It fails with the following traceback:

```
Traceback (most recent call last):
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 582, in wrapper
    return inner_fn(self, inst)
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1680, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars)
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 830, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/variables/nn_module.py", line 442, in call_function
    return tx.inline_user_function_return(
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 836, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3011, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3139, in inline_call_
    tracer.run()
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
    while self.step():
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 898, in step
    self.exception_handler(e)
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1496, in exception_handler
    raise raised_exception
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 582, in wrapper
    return inner_fn(self, inst)
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1680, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars)
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 830, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 385, in call_function
    return super().call_function(tx, args, kwargs)
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 324, in call_function
    return super().call_function(tx, args, kwargs)
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 108, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 836, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3011, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3139, in inline_call_
    tracer.run()
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
    while self.step():
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 898, in step
    self.exception_handler(e)
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1496, in exception_handler
    raise raised_exception
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 582, in wrapper
    return inner_fn(self, inst)
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1692, in CALL_FUNCTION_KW
    self.call_function(fn, args, kwargs)
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 830, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/variables/nn_module.py", line 442, in call_function
    return tx.inline_user_function_return(
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 836, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3011, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3139, in inline_call_
    tracer.run()
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
    while self.step():
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 898, in step
    self.exception_handler(e)
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1496, in exception_handler
    raise raised_exception
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 582, in wrapper
    return inner_fn(self, inst)
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1680, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars)
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 830, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 385, in call_function
    return super().call_function(tx, args, kwargs)
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 324, in call_function
    return super().call_function(tx, args, kwargs)
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 108, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 836, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3011, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3139, in inline_call_
    tracer.run()
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
    while self.step():
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 898, in step
    self.exception_handler(e)
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1496, in exception_handler
    raise raised_exception
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1388, in RAISE_VARARGS
    self._raise_exception_variable(inst)
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1381, in _raise_exception_variable
    raise exc.ObservedException(f"raised exception {val}")
torch._dynamo.exc.ObservedException: raised exception ExceptionVariable()

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/jalagurajah/Desktop/ROME/My Documents/ASR_2024/torchrttest.py", line 42, in <module>
    mutable_module(*inputs)
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py", line 391, in __call__
    return self.forward(*args, **kwargs)
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py", line 354, in forward
    self.compile()
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py", line 286, in compile
    self.exp_program = torch.export.export(
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/export/__init__.py", line 258, in export
    return _export(
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/export/_trace.py", line 1007, in wrapper
    raise e
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/export/_trace.py", line 980, in wrapper
    ep = fn(*args, **kwargs)
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/export/exported_program.py", line 97, in wrapper
    return fn(*args, **kwargs)
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/export/_trace.py", line 1915, in _export
    export_artifact = export_func(  # type: ignore[operator]
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/export/_trace.py", line 1214, in _strict_export
    return _strict_export_lower_to_aten_ir(
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/export/_trace.py", line 1242, in _strict_export_lower_to_aten_ir
    gm_torch_level = _export_to_torch_ir(
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/export/_trace.py", line 550, in _export_to_torch_ir
    gm_torch_level, _ = torch._dynamo.export(
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 1432, in inner
    result_traced = opt_f(*args, **kwargs)
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 465, in _fn
    return fn(*args, **kwargs)
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 1244, in __call__
    return self._torchdynamo_orig_callable(
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 516, in __call__
    return _compile(
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 908, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 656, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_utils_internal.py", line 87, in wrapper_function
    return function(*args, **kwargs)
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 689, in _compile_inner
    out_code = transform_code_object(code, transform)
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1322, in transform_code_object
    transformations(instructions, code_options)
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 210, in _fn
    return fn(*args, **kwargs)
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 624, in transform
    tracer.run()
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2796, in run
    super().run()
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
    while self.step():
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 898, in step
    self.exception_handler(e)
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1495, in exception_handler
    raise Unsupported("Observed exception")
torch._dynamo.exc.Unsupported: Observed exception

from user code:
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/transformers/models/whisper/modeling_whisper.py", line 1764, in forward
    outputs = self.model(

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
```
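The `from user code` frame points into Whisper's `forward`, while the script feeds it the image-shaped dummy tensor carried over from the original ResNet example. Whisper's forward takes log-mel `input_features` together with decoder token ids, so the input itself is one thing worth ruling out. Below is a minimal sketch of shape-correct dummy inputs. It assumes a config object exposing `num_mel_bins`, `max_source_positions` and `decoder_start_token_id` (the attribute names used by Hugging Face's `WhisperConfig`), and it assumes the encoder expects `2 * max_source_positions` mel frames. `make_whisper_dummy_inputs` is a hypothetical helper, and it has not been tested whether these inputs let export succeed for Whisper.

```python
from types import SimpleNamespace

import torch


def make_whisper_dummy_inputs(config, batch_size=1, dtype=torch.float32, device="cpu"):
    """Build dummy (input_features, decoder_input_ids) for a Whisper-style model.

    Hypothetical helper. Assumes the encoder downsamples mel frames by 2, so it
    expects 2 * config.max_source_positions frames (3000 for 1500 positions).
    """
    num_frames = 2 * config.max_source_positions
    # Random log-mel features: (batch, num_mel_bins, num_frames), cast to dtype
    input_features = torch.randn(
        batch_size, config.num_mel_bins, num_frames, device=device
    ).to(dtype)
    # Decoder prompt containing only the start token: (batch, 1)
    decoder_input_ids = torch.full(
        (batch_size, 1), config.decoder_start_token_id, dtype=torch.long, device=device
    )
    return input_features, decoder_input_ids


if __name__ == "__main__":
    # Illustrative stand-in config (assumed large-v3-style values); with the
    # real model, pass base_model.config instead.
    cfg = SimpleNamespace(num_mel_bins=128, max_source_positions=1500, decoder_start_token_id=50258)
    feats, ids = make_whisper_dummy_inputs(cfg)
    print(tuple(feats.shape), tuple(ids.shape))  # (1, 128, 3000) (1, 1)
```

With the real model, something like `make_whisper_dummy_inputs(base_model.config, dtype=torch_dtype, device=device)` would replace `inputs`. The decoder ids would then be passed as `mutable_module(input_features, decoder_input_ids=decoder_input_ids)`. This assumes MutableTorchTensorRTModule forwards keyword arguments to the model, which the `forward(*args, **kwargs)` frame in the traceback suggests.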

Expected behavior

The model compiles without problems.

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • Torch-TensorRT Version: 12.4
  • PyTorch Version: 2.5.0.dev20240830+cu124
  • CPU Architecture: x86
  • OS (e.g., Linux): Ubuntu
  • How you installed PyTorch (conda, pip, libtorch, source): nightly install
  • Python version: 3.10
  • CUDA version: 12.4
  • GPU models and configuration: A100 80GB
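The list above can also be collected programmatically, so that, for example, the Torch-TensorRT release and the CUDA toolkit version are reported separately. This is a minimal sketch. `collect_versions` is a hypothetical helper that reads each module's `__version__` attribute plus `torch.version.cuda`, and it assumes the usual import names `torch_tensorrt` and `tensorrt`.

```python
import importlib
import platform


def collect_versions(modules=("torch", "torch_tensorrt", "tensorrt", "transformers")):
    """Return version strings for the given modules plus Python, OS and GPU info."""
    info = {"python": platform.python_version(), "os": platform.platform()}
    for name in modules:
        try:
            module = importlib.import_module(name)
            info[name] = str(getattr(module, "__version__", "unknown"))
        except Exception as exc:  # missing package or broken native libraries
            info[name] = f"not available ({type(exc).__name__})"
    try:
        import torch

        # CUDA version PyTorch was built against (None for CPU-only builds)
        info["cuda (torch build)"] = torch.version.cuda
        if torch.cuda.is_available():
            info["gpu"] = torch.cuda.get_device_name(0)
    except ImportError:
        pass
    return info


if __name__ == "__main__":
    for key, value in collect_versions().items():
        print(f"{key}: {value}")
```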
Jeevi10 added the bug (Something isn't working) label on Aug 30, 2024
@narendasan
Collaborator

Have you tried exporting Whisper with torch.export? Does that work properly? It seems like that is the step that is currently failing.

@Jeevi10
Author

Jeevi10 commented Sep 5, 2024

> Have you tried exporting whisper with torch.export? Does that work properly? Seems like right now that is the step that is failing

I was not able to export the Whisper model successfully.

@narendasan
Collaborator

narendasan commented Sep 5, 2024

To use the MutableModule (MutableTorchTensorRTModule), a successful torch.export of the model is a prerequisite step (it'll be done either by you or by us). It might be worth opening an issue on pytorch/pytorch for this.

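Alternatively, you can try torch.compile(..., backend="tensorrt"), which is a bit more flexible: it does not require the whole model to be captured as a single exported graph and can fall back to PyTorch for the parts it cannot convert.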
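The flexibility difference can be illustrated with a toy module that has data-dependent control flow. Strict torch.export refuses to capture it as one graph, while torch.compile splits the program at the graph break and runs the rest in Python. This is a minimal sketch. `BranchyModule` and `compile_model` are illustrative names. backend="tensorrt" assumes torch_tensorrt is installed, since the sketch imports it so the backend name is registered, and that a CUDA GPU is available. backend="eager" gives a CPU-only smoke test.

```python
import torch


class BranchyModule(torch.nn.Module):
    """Toy module whose branch depends on tensor values (a Dynamo graph break)."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = self.linear(x)
        if y.sum() > 0:  # data-dependent control flow
            return y.relu()
        return -y


def compile_model(model: torch.nn.Module, backend: str = "tensorrt") -> torch.nn.Module:
    """Wrap `model` with torch.compile; imports torch_tensorrt only when needed."""
    if backend == "tensorrt":
        import torch_tensorrt  # noqa: F401  (registers the "tensorrt" backend)
    return torch.compile(model, backend=backend)


if __name__ == "__main__":
    model = BranchyModule().eval()
    x = torch.randn(2, 4)
    try:
        torch.export.export(model, (x,))
    except Exception as exc:
        print("torch.export failed:", type(exc).__name__)
    compiled = compile_model(model, backend="eager")  # CPU-only smoke test
    with torch.no_grad():
        print(torch.allclose(compiled(x), model(x)))
```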

@Jeevi10
Author

Jeevi10 commented Sep 5, 2024

```python
import torch
from datasets import load_dataset
from transformers import WhisperProcessor, WhisperForConditionalGeneration

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
hf_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
audio_sample = ds[0]["audio"]

# Log-mel features for one audio sample
input_features = processor(
    audio_sample["array"],
    sampling_rate=audio_sample["sampling_rate"],
    return_tensors="pt",
).input_features

with torch.no_grad():
    print(hf_model.generate(input_features))

# Export, save and reload the model (the export step is where the error reproduces)
exported_model = torch.export.export(hf_model, args=(input_features,))
torch.export.save(exported_model, "model.pt")
pt_model = torch.export.load("model.pt")

with torch.no_grad():
    print(pt_model.module().generate(input_features))
```

I reproduced the error with the code above.
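Even if the export succeeds, `pt_model.module()` returns a graph module (`torch.fx.GraphModule`) that contains only the traced forward computation. Hugging Face's `generate` is therefore not available on it. One workaround is a hand-written greedy decoding loop that needs only a forward callable. This is a minimal sketch. `greedy_decode` and `step_fn` are hypothetical names. `step_fn(input_features, decoder_input_ids)` is assumed to return logits shaped (batch, seq_len, vocab). Driving an exported model this way would also assume it was exported with a dynamic decoder sequence length (for example via `torch.export.Dim`), which has not been verified for Whisper.

```python
from typing import Callable

import torch

StepFn = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]


def greedy_decode(
    step_fn: StepFn,
    input_features: torch.Tensor,
    start_token_id: int,
    eos_token_id: int,
    max_new_tokens: int = 32,
) -> torch.Tensor:
    """Greedy autoregressive decoding driven by a plain forward function.

    Only the logits at the last position are used each step. Simplification:
    finished sequences keep generating until every sequence has emitted EOS.
    """
    batch = input_features.shape[0]
    tokens = torch.full((batch, 1), start_token_id, dtype=torch.long,
                        device=input_features.device)
    with torch.no_grad():
        for _ in range(max_new_tokens):
            logits = step_fn(input_features, tokens)
            next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)  # (batch, 1)
            tokens = torch.cat([tokens, next_token], dim=1)
            if (next_token == eos_token_id).all():
                break
    return tokens


if __name__ == "__main__":
    vocab = 10

    def toy_step(features: torch.Tensor, ids: torch.Tensor) -> torch.Tensor:
        # Toy model: always predicts (last token + 1) mod vocab.
        logits = torch.zeros(ids.shape[0], ids.shape[1], vocab)
        logits[:, -1, :] = torch.nn.functional.one_hot((ids[:, -1] + 1) % vocab, vocab).float()
        return logits

    print(greedy_decode(toy_step, torch.zeros(1, 4), start_token_id=0, eos_token_id=3))
    # tensor([[0, 1, 2, 3]])
```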

Projects
None yet
Development

No branches or pull requests

2 participants