
OnnxExporterError: Unsupported: ONNX export of operator GridSample with 5D volumetric input. #79

Open
Cyril9227 opened this issue Feb 14, 2024 · 10 comments


@Cyril9227

Hi everyone,

Thanks for the awesome work. I've been trying to export the PyTorch model to ONNX for inference with torch.onnx.export, but it yields this error: OnnxExporterError: Unsupported: ONNX export of operator GridSample with 5D volumetric input.

Unfortunately, it seems 5D grid_sample is still unsupported by onnx/torch. Is there any alternative available? Or any advice to make the model work with ONNX?
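
For reference, a minimal sketch that reproduces the error outside of TAPIR (assuming a PyTorch version where the ONNX GridSample export only covers 4D inputs):

import torch
import torch.nn.functional as F

class GridSample5D(torch.nn.Module):
    def forward(self, x, grid):
        # x: [N, C, D, H, W]; grid: [N, D_out, H_out, W_out, 3] in [-1, 1]
        return F.grid_sample(x, grid, align_corners=False)

x = torch.randn(1, 8, 4, 16, 16)
grid = torch.rand(1, 4, 16, 16, 3) * 2 - 1
# Fails with OnnxExporterError: Unsupported: ONNX export of operator
# GridSample with 5D volumetric input
torch.onnx.export(GridSample5D(), (x, grid), "grid_sample_5d.onnx", opset_version=16)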

Thanks

@SergeySandler commented Feb 16, 2024

@Cyril9227, torch.onnx.export() fails for me too. It seems the cause is described in pytorch/pytorch#100790, which will be addressed through pytorch/pytorch#114801 (ONNX opset 20 support).
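
If/when that lands, the PyTorch-side fix should presumably reduce to re-exporting with the newer opset (a sketch only; model and example_inputs are placeholders for the TAPIR module and its inputs, not actual code from this repo):

import torch

# Hypothetical once the exporter supports opset 20 (pytorch/pytorch#114801)
torch.onnx.export(model, example_inputs, "tapir.onnx", opset_version=20)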

In the meantime I was trying to convert to ONNX through Haiku (JAX) -> TensorFlow -> ONNX, using https://dm-haiku.readthedocs.io/en/latest/notebooks/jax2tf.html as a tutorial for Haiku -> TF:

import haiku as hk
import numpy as np
import sonnet as snt
import tensorflow as tf
from jax.experimental import jax2tf  # needed by JaxModule below; missing in my original snippet
from tapnet import tapir_model

# Load the released causal TAPIR checkpoint (params + state)
checkpoint_path = 'tapnet/checkpoints/causal_tapir_checkpoint.npy'
ckpt_state = np.load(checkpoint_path, allow_pickle=True).item()
params, state = ckpt_state['params'], ckpt_state['state']
params_vars = tf.nest.map_structure(tf.Variable, params)

def build_online_model_init(frames, query_points):
  """Initialize query features for the query points."""
  model = tapir_model.TAPIR(use_causal_conv=True, bilinear_interp_with_depthwise_conv=False) 

  feature_grids = model.get_feature_grids(frames, is_training=False)
  query_features = model.get_query_features(
      frames,
      is_training=False,
      query_points=query_points,
      feature_grids=feature_grids,
  )
  return query_features

init_tf = hk.transform(build_online_model_init) 

class JaxModule(snt.Module):
  def __init__(self, params, apply_fn, name=None):
    super().__init__(name=name)
    self._params = params   
    self._apply = jax2tf.convert(lambda p, x: apply_fn(p, None, x), enable_xla=False)
    self._apply = tf.autograph.experimental.do_not_convert(self._apply)

  def __call__(self, inputs):
    return self._apply(self._params, inputs)

net = JaxModule(params_vars, init_tf.apply)

# frames: [num_frames, height, width, 3], query_points: [num_points, 3] where 3 for the tuple (t, y, x)
@tf.function(autograph=False, input_signature=[{"frames" : tf.TensorSpec(shape=(32, 256, 256, 3), dtype=tf.float32), 
                                                "query_points": tf.TensorSpec(shape=(20,3), dtype=tf.float32)}]) 
def forward(x):
  return net(x)

to_save = tf.Module()
to_save.forward = forward
to_save.params = list(net.variables)
tf.saved_model.save(to_save, "TapirInit")  

but it fails with TypeError: build_online_model_init() missing 1 required positional argument: 'query_points'. The same happens with build_online_model_predict(). Maybe the input_signature is incorrect in tf.function(), but I cannot figure out how to fix it.
Have you tried the TF path?

Since tf2onnx only supports ONNX opsets up to 18, the TF SavedModel to ONNX conversion is likely to hit the same problem as PyTorch :(
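
For reference, the SavedModel -> ONNX step would then look something like this (a sketch; TapirInit is the directory saved above, the output name is arbitrary, and because of the opset 18 cap the GridSample issue would remain):

python -m tf2onnx.convert --saved-model TapirInit --output tapir_init.onnx --opset 18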

@saikiran321 commented Feb 17, 2024

@Cyril9227 I have posted a solution here: https://github.com/pytorch/pytorch/issues/100790. See if that works for you.

@SergeySandler

@saikiran321, with the solution you posted, torch.onnx.export no longer produces the unsupported-operator error related to opset 20 support.
Instead, it fails with ValueError: only one element tensors can be converted to Python scalars.
A Dockerfile and Python code to reproduce the result are in the attached zip file, torch2onnx.zip.
Do you know what could be the cause of this error? Thank you.

@cdoersch (Collaborator) commented Feb 21, 2024

I'm no expert on ONNX, but if the problem is a 5D gather operation, then I suspect the source of the problem is extracting query features. It's possible to rewrite the vmap using a 4D gather; it wastes computation, but the waste is probably small relative to the rest of the model. Try setting parallelize_query_extraction to True when constructing the TAPIR model; it should produce exactly the same result given the same checkpoint, but hopefully it will avoid the 5D gather.

As a bit of explanation: when extracting the query feature, we get a [t, y, x] coordinate and use bilinear interpolation to extract a feature at that location. The parallelize_query_extraction version instead extracts the feature at [y, x] from every frame (using a vmapped 4D gather), and then multiplies the resulting tensor by a one-hot t vector to discard every query feature except the one on frame t.

Of course, this is only implemented in the JAX version; you'd have to re-implement the same algorithm in the torch model to export from torch.
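
For anyone attempting that torch re-implementation, here is a minimal sketch of the trick (hypothetical function name and shapes, not the actual TAPIR code): sample the same [y, x] location from every frame with a single 4D grid_sample, then keep only frame t via a one-hot product.

import torch
import torch.nn.functional as F

def extract_query_features(feature_grid, query_points):
    # feature_grid: [T, C, H, W]; query_points: [Q, 3] as (t, y, x) in pixels
    T, C, H, W = feature_grid.shape
    Q = query_points.shape[0]
    t, y, x = query_points.unbind(dim=-1)
    # Normalize (y, x) to [-1, 1] as grid_sample expects (align_corners=True convention)
    gx = x / (W - 1) * 2 - 1
    gy = y / (H - 1) * 2 - 1
    # One grid of Q points, broadcast across all T frames: [T, 1, Q, 2]
    grid = torch.stack([gx, gy], dim=-1).view(1, 1, Q, 2).expand(T, 1, Q, 2)
    sampled = F.grid_sample(feature_grid, grid, align_corners=True)  # [T, C, 1, Q]
    sampled = sampled.squeeze(2)  # [T, C, Q]
    # Keep only frame t for each query via a one-hot product over time
    one_hot = F.one_hot(t.long(), num_classes=T).to(sampled.dtype)  # [Q, T]
    return torch.einsum('qt,tcq->qc', one_hot, sampled)  # [Q, C]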

@zmtttt commented May 28, 2024

Hi! I exported an opset 16 ONNX model, used onnx_graphsurgeon to directly change the declared opset to 20, then ran trtexec --onnx=xx.onnx --saveEngine=xx.engine, and hit the same problem: Error Code 3: API Usage Error (Parameter check failed at: optimizer/api/network.cpp::addGridSample::1474, condition: input.getDimensions().nbDims == 4) @saikiran321 @SergeySandler @Cyril9227 @yotam
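
For context: changing the declared opset only rewrites the version stamp, not the graph, so TensorRT still sees the same GridSample node, and its addGridSample check (as the error says) requires a 4-D input. A minimal sketch of that kind of edit with the onnx package (hypothetical file names), just to show what it does and does not change:

import onnx

model = onnx.load("tapir_opset16.onnx")
# Only the declared opset version changes; every node in the graph is untouched,
# so a GridSample that TensorRT rejects stays rejected.
for opset in model.opset_import:
    if opset.domain in ("", "ai.onnx"):
        opset.version = 20
onnx.save(model, "tapir_opset20.onnx")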

@larrygoyeau

> I exported an opset 16 ONNX model, used onnx_graphsurgeon to directly change the declared opset to 20, then ran trtexec and hit the same problem: Error Code 3: API Usage Error (Parameter check failed at: optimizer/api/network.cpp::addGridSample::1474, condition: input.getDimensions().nbDims == 4)

Hi! Same error, did you succeed in solving this?

@ibaiGorordo commented Aug 10, 2024

I modified the torch model for the case of t=1 and reduced all the 5D ops to 4D, among other changes: https://github.com/ibaiGorordo/Tapir-Pytorch-Inference

I also added a script to export the model, but it is very slow when running in onnxruntime compared to PyTorch (RTX 4080): ~700 ms without refinement and ~20 s with 4 iterations (1000 points, 640x640).
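
The core of that reduction, roughly (a hedged sketch, not the actual repo code): with a single frame, the depth/time dimension is 1, so a 5D grid_sample collapses to a 4D one after squeezing, provided the t coordinate already targets that single slice.

import torch
import torch.nn.functional as F

def grid_sample_t1(x5d, grid5d):
    # x5d: [N, C, 1, H, W]; grid5d: [N, 1, H_out, W_out, 3] with last dim (x, y, z)
    x4d = x5d.squeeze(2)                 # [N, C, H, W]
    grid4d = grid5d.squeeze(1)[..., :2]  # drop the z/t coordinate (constant for t=1)
    return F.grid_sample(x4d, grid4d, align_corners=False)  # [N, C, H_out, W_out]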

@SergeySandler

@ibaiGorordo,

> it is very slow when running in onnxruntime compared to PyTorch (RTX 4080)

Do you have the code for inference with ONNX? Do you use the CUDA Execution Provider or the CPU Execution Provider with ONNX?

@ibaiGorordo

> Do you have the code for inference with ONNX? Do you use the CUDA Execution Provider or the CPU Execution Provider with ONNX?

I added an inference-time calculation to the onnx_export.py script.

CPU is faster than CUDA (screenshots: tapir_onnx_cpu, tapir_onnx_cuda).

The slow part seems to be the convolutions in the PIPs mixer block.
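
In case it helps narrow that down, onnxruntime can dump per-node timings. A sketch (assumes fixed input dims and float32 inputs; adjust for the real model):

import numpy as np
import onnxruntime as ort

opts = ort.SessionOptions()
opts.enable_profiling = True  # writes a JSON trace with per-node durations
sess = ort.InferenceSession("tapir.onnx", opts, providers=["CUDAExecutionProvider"])

# Build dummy inputs from the model's declared shapes (dynamic dims -> 1)
feed = {}
for inp in sess.get_inputs():
    shape = [d if isinstance(d, int) else 1 for d in inp.shape]
    feed[inp.name] = np.zeros(shape, dtype=np.float32)

sess.run(None, feed)
print(sess.end_profiling())  # path of the profiling JSON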

@SergeySandler commented Aug 13, 2024

@ibaiGorordo, I have reproduced tapir.onnx, and it is three times slower than PyTorch on the CUDA device.
My results on Windows: PyTorch inference takes around 0.1 s on CUDA and 3 s on CPU; ONNX takes 0.3 s with DmlExecutionProvider and 3 s with CPUExecutionProvider.

There are a couple of hints for Windows that might be useful, especially if your ONNX results are worse than on CPU:

  1. Do not forget to add device_id: your_card_ID (0 in my case) in
     predictor = onnxruntime.InferenceSession(f'{output_dir}/tapir.onnx', providers=['DmlExecutionProvider'], provider_options=[{'device_id': 0}])
     otherwise it might use the integrated Intel graphics instead of the NVIDIA card.
  2. Without pip install onnxruntime-directml, DmlExecutionProvider is not available on Windows.
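
A quick way to check which providers are actually available before creating the session (sketch):

import onnxruntime
# 'DmlExecutionProvider' appears here only after installing onnxruntime-directml
print(onnxruntime.get_available_providers())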
