Describe the bug
I'm trying to add new ONNX models to the onnx-tests crate that contain constant int tensors. When trying to load them for testing, the generated code doesn't compile.
error[E0308]: mismatched types
--> /opt/workspace/caches/cargo/debug/build/onnx-tests-a831f24a8189216e/out/model/constant_tensor_i32.rs:37:65
|
37 | let constant1: burn::module::Param<Tensor<B, 2, Int>> = burn::nn::Initializer::Zeros
| ________________________--------------------------------------___^
| | |
| | expected due to this
38 | | .init([2, 2], device)
39 | | .set_require_grad(false);
| |____________________________________^ expected `Param<Tensor<B, 2, Int>>`, found `Param<Tensor<_, _>>`
|
= note: expected struct `Param<Tensor<B, burn::tensor::Int, 2>>`
found struct `Param<Tensor<_, burn::tensor::Float, _>>`
To Reproduce
In this branch, I generated a couple of ONNX models with constant int tensors in them (constant_tensor_i32.onnx and constant_tensor_i64.onnx) from constant_tensor.py. I can run cargo -p burn-import over them fine, but trying to load them for the ONNX tests results in the above error at compile time.
The models are currently commented out in that branch so that they aren't loaded during the ONNX tests. Uncommenting those lines and then running the tests (cargo nextest run --manifest-path crates/burn-import/onnx-tests/Cargo.toml) gives the error.
I'm pretty sure this affects const bool tensors also but don't have an easy way to reproduce.
Expected behavior
There shouldn't be a compile error - the generated code should compile.
Screenshots
In Netron the models have the expected constant tensors.
Desktop (please complete the following information):
OS: macOS 15.2
Version: based off of 8a89293; the ConstantNode functionality in the linked test branch shouldn't differ from that commit
Additional context
The generated code for constant tensors looks something like:
let constant1: burn::module::Param<Tensor<B, 2, Int>> = burn::nn::Initializer::Zeros.init([2, 2], device).set_require_grad(false);
I think .set_require_grad may be coercing to a Float tensor? I've been trying to implement conversion of the ONNX OneHot op, and there ends up being a constant int tensor [0, 1] in the generated model. I was able to get things compiling by overriding the ConstantNode codegen depending on whether the tensor was Int/Float/Bool here - https://github.com/jameshiew/burn/pull/1/files#diff-53910fea1e19653a4ee950445bb349d29c1ecacaabdf33e4f230a2f423b1e54eR124-R134 - so that the generated code looks like:
let constant2: burn::module::Param<Tensor<B, 1, Int>> = burn::nn::zeros_int([2], device);
should be equivalent to
let constant2: burn::module::Param<Tensor<B, 1, Int>> = Tensor::<B, 1, Int>::zeros([2], device);
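To illustrate the coercion hypothesis without pulling in burn, here is a minimal standalone Rust sketch. All types here are simplified stand-ins, not burn's actual definitions; the point is only that a kind parameter defaulting to Float pins an initializer that never names a kind, while a kind-explicit constructor avoids the mismatch:

```rust
use std::marker::PhantomData;

// Stand-ins for burn's tensor kinds (hypothetical simplification).
struct Float;
struct Int;

// Like burn's `Tensor<B, const D: usize, K = Float>`: the kind
// parameter defaults to Float when it isn't written out.
struct Tensor<const D: usize, K = Float>(PhantomData<K>);

struct Param<T>(T);

// Analogue of `Initializer::Zeros.init(...)`: the signature never
// mentions a kind, so it can only produce the defaulted Float tensor.
fn init_zeros<const D: usize>(_shape: [usize; D]) -> Param<Tensor<D>> {
    Param(Tensor(PhantomData))
}

// Analogue of the proposed fix: a constructor that carries the Int
// kind in its own return type.
fn zeros_int<const D: usize>(_shape: [usize; D]) -> Param<Tensor<D, Int>> {
    Param(Tensor(PhantomData))
}

fn main() {
    // Fine: the annotation agrees with the defaulted Float kind.
    let _constant_f: Param<Tensor<2, Float>> = init_zeros([2, 2]);

    // This is the analogue of the generated code in the issue and
    // fails with E0308 ("expected `Param<Tensor<2, Int>>`, found
    // `Param<Tensor<2, Float>>`") if uncommented:
    // let _constant_i: Param<Tensor<2, Int>> = init_zeros([2, 2]);

    // Compiles: the constructor itself fixes the Int kind, so no
    // coercion from the Float default is needed.
    let _constant_i: Param<Tensor<2, Int>> = zeros_int([2, 2]);
}
```

This matches the fix in the linked branch: emitting a kind-specific constructor in the codegen rather than relying on an initializer whose return type is always the Float default.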