
🐛 [Bug] nn.BatchNorm3d is not compiling (+ can't disable fast_partitioner) #3168

Open
orioninthesky98 opened this issue Sep 20, 2024 · 2 comments
Labels: bug

orioninthesky98 commented Sep 20, 2024

Bug Description

I can't compile this model; the error seems to be caused by nn.BatchNorm3d.

To Reproduce

Steps to reproduce the behavior:

  1. Init the model after importing the classes in this gist: https://gist.github.com/orioninthesky98/8a2012f555b7bd4ce50398ff2a1c9291 (a rough sketch of the relevant module structure is included after step 2 below)
conv_block = ConvBlock(
    in_channels=1,
    out_channels=16,
    kernel_size=[1, 1, 3],
    pool_ksize=[1, 1, 2],
)
conv_block = conv_block.to("cuda")
conv_block.eval()
  2. Compile it with this command:
import torch
import torch_tensorrt as trt  # assuming trt is an alias for torch_tensorrt here

batch_size = 128
network_input_shape = [1, 1, 1, 32]

placeholder_batch = torch.rand((batch_size,) + tuple(network_input_shape))
placeholder_batch = placeholder_batch.to("cuda")

compiled_model = trt.compile(
    conv_block,
    inputs=[placeholder_batch],
    enabled_precisions={torch.float32},
    optimization_level=5,  # max is 5, compilation takes longer but gives the best speedup
    debug=True,  # very verbose, only turn on if needed
    use_fast_partitioner=True,  # can't disable; setting this to False results in an error when exporting
    dynamic=False,
    disable_tf32=True,  # reduce precision errors at the expense of small slowdown
)
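
For context, here is a rough, hypothetical sketch of the part of the gist's ConvBlock that matters for this bug (names, activation, and pooling details are my assumptions, not the exact gist code): a Conv3d followed by nn.BatchNorm3d, applied to a 5D (N, C, D, H, W) input.

import torch
import torch.nn as nn

class ConvBlockSketch(nn.Module):
    # Hypothetical stand-in for the gist's ConvBlock: Conv3d -> BatchNorm3d -> ReLU -> MaxPool3d
    def __init__(self, in_channels, out_channels, kernel_size, pool_ksize):
        super().__init__()
        padding = [k // 2 for k in kernel_size]  # keep spatial dims for the [1, 1, 3] kernel
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, padding=padding)
        self.norm = nn.BatchNorm3d(out_channels)  # the layer whose conversion fails
        self.act = nn.ReLU()
        self.pool = nn.MaxPool3d(pool_ksize)

    def forward(self, x):
        h = self.conv(x)
        h = self.norm(h)  # lowered to aten._native_batch_norm_legit_no_training during export
        h = self.act(h)
        return self.pool(h)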

The compile call in step 2 produces the error below (see here for the full trace: https://gist.github.com/orioninthesky98/9e51a9e83232aa3cac64ce68fe0e512b):

DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node /cc2_norm/_native_batch_norm_legit_no_training_1 (kind: aten._native_batch_norm_legit_no_training.default, args: ('[CONVOLUTION]-[aten_ops.convolution.default]-[/cc2_conv/convolution_1]_output <tensorrt.ITensor [shape=(81), dtype=DataType.FLOAT]>', '<torch.Tensor as np.ndarray [shape=(16,), dtype=float32]>', '<torch.Tensor as np.ndarray [shape=(16,), dtype=float32]>', '<torch.Tensor as np.ndarray [shape=(16,), dtype=float32]>', '<torch.Tensor as np.ndarray [shape=(16,), dtype=float32]>', 0.1, 1e-05))
ERROR:torch_tensorrt [TensorRT Conversion Context]:ITensor::getDimensions: Error Code 4: Internal Error (Output shape can not be computed for node [CONVOLUTION]-[aten_ops.convolution.default]-[/cc2_conv/convolution_1].)

File ~/.conda/lib/python3.10/site-packages/torch_tensorrt/dynamo/conversion/aten_ops_converters.py:111, in aten_ops_batch_norm_legit_no_training(ctx, target, args, kwargs, name)
     99 @dynamo_tensorrt_converter(
    100     torch.ops.aten._native_batch_norm_legit_no_training.default,
    101     capability_validator=one_user_validator,
   (...)
    109     name: str,
    110 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
--> 111     return impl.normalization.batch_norm(
    112         ctx,
    113         target,
    114         SourceIR.ATEN,
    115         name,
    116         input=args[0],
    117         weight=args[1],
    118         bias=args[2],
    119         running_mean=args[3],
    120         running_var=args[4],
    121         training=False,
    122         momentum=args[5],
    123         eps=args[6],
    124         cudnn_enabled=False,
    125         return_mean_rstd=(
    126             target == torch.ops.aten._native_batch_norm_legit_no_training.default
    127         ),
    128     )

File ~/.conda/lib/python3.10/site-packages/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py:65, in batch_norm(ctx, target, source_ir, name, input, weight, bias, running_mean, running_var, training, momentum, eps, cudnn_enabled, return_mean_rstd)
     63 # For BatchNorm1d, reshape 1d to 2d
     64 output_shape = input.shape
---> 65 if len(input.shape) < 4:
     66     assert (
     67         len(get_dynamic_dims(input.shape)) <= 1
     68     ), "BatchNorm1D with more than one dynamic dims is not currently supported."
     69     new_shape = (
     70         (input.shape[0], input.shape[1], 1, 1)
     71         if len(input.shape) == 2
     72         else (input.shape[0], input.shape[1], input.shape[2], 1)
     73     )

ValueError: __len__() should return >= 0

While executing %_native_batch_norm_legit_no_training_1 : [num_users=1] = call_function[target=torch.ops.aten._native_batch_norm_legit_no_training.default](args = (%convolution_1, %cc2_norm_weight, %cc2_norm_bias, %cc2_norm_running_mean, %cc2_norm_running_var, 0.1, 1e-05), kwargs = {_itensor_to_tensor_meta: {<tensorrt_bindings.tensorrt.ITensor object at 0x7f982e80c6b0>: ((128, 1, 1, 1, 32), torch.float32, False, (32, 32, 32, 32, 1), torch.contiguous_format, False, {}), <tensorrt_bindings.tensorrt.ITensor object at 0x7f982e8175f0>: ((128, 1, 1, 1, 32), torch.float32, False, (32, 32, 32, 32, 1), torch.contiguous_format, False, {}), <tensorrt_bindings.tensorrt.ITensor object at 0x7f982e528eb0>: ((128, 16, 1, 1, 32), torch.float32, False, (512, 32, 32, 32, 1), torch.contiguous_format, False, {}), <tensorrt_bindings.tensorrt.ITensor object at 0x7f982e938c70>: ((128, 16, 1, 1, 32), torch.float32, False, (512, 32, 32, 32, 1), torch.contiguous_format, False, {}), <tensorrt_bindings.tensorrt.ITensor object at 0x7f982e7ce770>: ((128, 16, 1, 1, 32), torch.float32, False, (512, 32, 32, 32, 1), torch.contiguous_format, False, {}), <tensorrt_bindings.tensorrt.ITensor object at 0x7f982e96ef30>: ((128, 16, 1, 1, 32), torch.float32, False, (512, 32, 32, 32, 1), torch.contiguous_format, False, {})}})
Original traceback:
    from loguru import logger
    return forward_call(*args, **kwargs)
    h = self.norm(h)
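
My best guess at the mechanics of the ValueError (just an illustration, not verified against the Torch-TensorRT sources): TensorRT reports that it cannot compute the convolution's output shape, so the ITensor's shape apparently comes back with a negative number of dimensions, and calling len() on an object whose __len__ returns a negative value is exactly what raises this error in Python:

# Hypothetical illustration only, not Torch-TensorRT code.
class FakeDims:
    def __len__(self):
        return -1  # stand-in for a shape whose dimension count could not be computed

try:
    len(FakeDims())
except ValueError as e:
    print(e)  # prints: __len__() should return >= 0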

Expected behavior

The model should compile successfully.

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • Torch-TensorRT Version (e.g. 1.0.0): 2.4.0
  • PyTorch Version (e.g. 1.0): '2.4.1+cu121'
  • CPU Architecture: x86_64
  • OS (e.g., Linux): Linux, "Ubuntu 22.04.4 LTS"
  • How you installed PyTorch (conda, pip, libtorch, source): conda + pip
  • Build command you used (if compiling from source): N/A
  • Are you using local sources or building from archives: N/A
  • Python version: 3.10
  • CUDA version: 12.5
  • GPU models and configuration: 1 x H100
  • Any other relevant information: N/A

Additional context

Update:
I have found a temporary hotfix: replacing BatchNorm3d with BatchNorm1d. This works for some of our use cases, but we still have many models that need the full 3D BatchNorm, so a fix would still be greatly appreciated.

  1. Replace this line (https://gist.github.com/orioninthesky98/8a2012f555b7bd4ce50398ff2a1c9291#file-model-py-L74) with: self.norm = (nn.BatchNorm1d(in_channels) if self.is_full_preact else nn.BatchNorm1d(out_channels))
  2. Replace these lines (https://gist.github.com/orioninthesky98/8a2012f555b7bd4ce50398ff2a1c9291#file-model-py-L108-L109) with the following (a standalone sanity check of the reshaping is included after this snippet):
                if self.norm_type == "sn":
                    pass
                elif self.norm_type == "bn":
                    # hotfix for tensorrt compile: squeeze extra 2 dims: (N,C,D,H,W) -> (N,C,H,W) -> (N,C,W)
                    h = h.squeeze(2).squeeze(2)
                    h = self.norm(h)
                    h = h.unsqueeze(2).unsqueeze(2)
                else:
                    h = self.norm(h)
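
A standalone sanity check of the reshaping used in the hotfix above (a sketch using the shapes from this repro; BatchNorm1d expects (N, C, L), so the two singleton dims are squeezed out and then restored):

import torch
import torch.nn as nn

h = torch.rand(128, 16, 1, 1, 32)    # conv output shape from the trace above
bn1d = nn.BatchNorm1d(16).eval()
out = bn1d(h.squeeze(2).squeeze(2))  # (128, 16, 1, 1, 32) -> (128, 16, 32)
out = out.unsqueeze(2).unsqueeze(2)  # back to (128, 16, 1, 1, 32)
assert out.shape == h.shape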

log: https://gist.github.com/orioninthesky98/96612bfd59e35344182de44d9a303aa7

Related bug:
If I try to set use_fast_partitioner=False, the model actually compiles fine, but I get an error at the very end and the script crashes (trace: https://gist.github.com/orioninthesky98/a784c361ebbdfa9000564b3f8a1ac1c0). Somebody already filed this bug: #3157.

orioninthesky98 added the bug label on Sep 20, 2024
@orioninthesky98 (Author) commented:

any updates on this?

lanluo-nvidia self-assigned this on Sep 30, 2024
@lanluo-nvidia (Collaborator) commented:

@orioninthesky98
I have tried the example on the current latest main and on our upcoming 2.5.0 release; both work as expected, so I believe the BatchNorm3d bug has been fixed.

As for the use_fast_partitioner=False bug (#3157): a PR has been raised and will be merged into main and the 2.5.0 release.
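
For anyone double-checking on a newer build, a quick numerical sanity check once compilation succeeds (a sketch reusing conv_block, placeholder_batch, and compiled_model from the repro above):

with torch.no_grad():
    ref = conv_block(placeholder_batch)
    out = compiled_model(placeholder_batch)
print(torch.max(torch.abs(ref - out)).item())  # expected to be ~0 with fp32 and TF32 disabled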
