diff --git a/crates/burn-tensor/src/tests/ops/cast.rs b/crates/burn-tensor/src/tests/ops/cast.rs index bc45518c38..3f760d6766 100644 --- a/crates/burn-tensor/src/tests/ops/cast.rs +++ b/crates/burn-tensor/src/tests/ops/cast.rs @@ -46,6 +46,7 @@ mod tests { let output = tensor.cast(DType::F32); assert_eq!(output.dtype(), DType::F32); - output.into_data().assert_approx_eq(&data, 5); + // Use precision 2 for parametrized tests in f16 and bf16 + output.into_data().assert_approx_eq(&data, 2); } }