Skip to content

Commit

Permalink
Fix (#2269)
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanielsimard authored Sep 10, 2024
1 parent d3fbdea commit 58ce502
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 23 deletions.
16 changes: 8 additions & 8 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,8 @@ tch = "0.15.0"
portable-atomic-util = { version = "0.2.2", features = ["alloc"] }

### For the main burn branch. ###
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "1b2eeeabfdd6f111f8bac7d4c4d00357d023e15b" }
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "1b2eeeabfdd6f111f8bac7d4c4d00357d023e15b" }
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "7a86f9a86e376fedb09f096f2b548e501a130883" }
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "7a86f9a86e376fedb09f096f2b548e501a130883" }
### For local development. ###
# cubecl = { path = "../cubecl/crates/cubecl" }
# cubecl-common = { path = "../cubecl/crates/cubecl-common" }
Expand Down
8 changes: 4 additions & 4 deletions crates/burn-autodiff/src/tests/flip.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@ mod tests {
let grad_2 = tensor_2.grad(&grads).unwrap();

grad_1
.to_data()
.assert_eq(&TensorData::from([[[7.2, 12.0], [7.2, 12.0]]]), false); // 1x2x2
grad_2.to_data().assert_eq(
.into_data()
.assert_approx_eq(&TensorData::from([[[7.2, 12.0], [7.2, 12.0]]]), 3); // 1x2x2
grad_2.into_data().assert_approx_eq(
&TensorData::from([[[10.0, 10.0, 10.0], [3.0, 3.0, 3.0]]]),
false,
3,
); // 1x2x3
}
}
8 changes: 4 additions & 4 deletions crates/burn-autodiff/src/tests/permute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@ mod tests {
let grad_2 = tensor_2.grad(&grads).unwrap();

grad_1
.to_data()
.assert_eq(&TensorData::from([[[7.2, 12.0], [7.2, 12.0]]]), false); // 1x2x2
grad_2.to_data().assert_eq(
.into_data()
.assert_approx_eq(&TensorData::from([[[7.2, 12.0], [7.2, 12.0]]]), 3); // 1x2x2
grad_2.into_data().assert_approx_eq(
&TensorData::from([[[3.0, 10.0], [3.0, 10.0], [3.0, 10.0]]]),
false,
3,
); // 1x3x2
}
}
10 changes: 5 additions & 5 deletions crates/burn-jit/src/kernel/index/select_assign.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@ fn select_assign_kernel<F: Numeric, I: Numeric>(
value: &Tensor<F>,
dim: &u32,
) {
let dim2 = *dim;
let dim = *dim;
let mut offset_tensor = 0u32;
let mut offset_value = 0u32;
let mut num_elems = 1u32;

// Calculate offsets and num_elems
for i in 0..tensor.rank() {
if i != dim2 {
if i != dim {
let shape_tensor = tensor.shape(i);

num_elems *= shape_tensor;
Expand All @@ -32,11 +32,11 @@ fn select_assign_kernel<F: Numeric, I: Numeric>(
return;
}

let strides_tensor_dim = tensor.stride(dim2);
let strides_value_dim = value.stride(dim2);
let strides_tensor_dim = tensor.stride(dim);
let strides_value_dim = value.stride(dim);

// Main operation
for i in 0..value.shape(dim2) {
for i in 0..value.shape(dim) {
let index_tensor = u32::cast_from(indices[i]) * strides_tensor_dim + offset_tensor;
let index_value = i * strides_value_dim + offset_value;

Expand Down

0 comments on commit 58ce502

Please sign in to comment.