
Generated code for test ONNX model containing a constant int tensor doesn't compile #2625

Open
jameshiew opened this issue Dec 17, 2024 · 0 comments
Labels
bug Something isn't working onnx

Comments

jameshiew commented Dec 17, 2024

Describe the bug
I'm trying to add new ONNX models to the onnx-tests crate that contain constant int tensors. When trying to load them for testing, the generated code doesn't compile.

error[E0308]: mismatched types
  --> /opt/workspace/caches/cargo/debug/build/onnx-tests-a831f24a8189216e/out/model/constant_tensor_i32.rs:37:65
   |
37 |           let constant1: burn::module::Param<Tensor<B, 2, Int>> = burn::nn::Initializer::Zeros
   |  ________________________--------------------------------------___^
   | |                        |
   | |                        expected due to this
38 | |             .init([2, 2], device)
39 | |             .set_require_grad(false);
   | |____________________________________^ expected `Param<Tensor<B, 2, Int>>`, found `Param<Tensor<_, _>>`
   |
   = note: expected struct `Param<Tensor<B, burn::tensor::Int, 2>>`
              found struct `Param<Tensor<_, burn::tensor::Float, _>>`

To Reproduce
In this branch, I generated a couple of ONNX models containing constant int tensors (constant_tensor_i32.onnx and constant_tensor_i64.onnx) from constant_tensor.py. I can run cargo -p burn-import over them fine, but loading them for the ONNX tests produces the above error at compile time.

The models are currently commented out in that branch so that they aren't loaded during ONNX tests. Uncommenting these lines and then running the tests (cargo nextest run --manifest-path crates/burn-import/onnx-tests/Cargo.toml) gives the error above.

I'm fairly sure this also affects constant bool tensors, but I don't have an easy way to reproduce that.

Expected behavior
There shouldn't be a compile error - the generated code should compile.

Screenshots
In Netron the models have the expected constant tensors.


Desktop (please complete the following information):
OS: macOS 15.2
Version: based off 8a89293; the ConstantNode functionality in the linked test branch shouldn't differ from that commit

Additional context
The generated code for constant tensors looks something like:

        let constant1: burn::module::Param<Tensor<B, 2, Int>> = burn::nn::Initializer::Zeros
            .init([2, 2], device)
            .set_require_grad(false);

I think .set_require_grad may be coercing the tensor to Float. I've been trying to implement conversion of the ONNX OneHot op, which ends up producing a constant int tensor [0, 1] in the generated model. I was able to get things compiling by overriding the ConstantNode codegen depending on whether the tensor was Int/Float/Bool (see https://github.com/jameshiew/burn/pull/1/files#diff-53910fea1e19653a4ee950445bb349d29c1ecacaabdf33e4f230a2f423b1e54eR124-R134), so that the generated code looks like:

let constant2: burn::module::Param<Tensor<B, 1, Int>> = burn::nn::zeros_int([2], device)

should be equivalent to

let constant2: burn::module::Param<Tensor<B, 1, Int>> = Tensor::<B, 1, Int>::zeros([2], device)
@laggui laggui added bug Something isn't working onnx labels Dec 17, 2024