diff --git a/crates/burn-fusion/src/ops/boolean.rs b/crates/burn-fusion/src/ops/boolean.rs index c641209046..4da93e0dfe 100644 --- a/crates/burn-fusion/src/ops/boolean.rs +++ b/crates/burn-fusion/src/ops/boolean.rs @@ -564,7 +564,7 @@ impl BoolTensorOps for Fusion { let stream = tensor.stream; let mut shape = tensor.shape.clone(); - shape[dim] = times; + shape[dim] *= times; let out = tensor.client.tensor_uninitialized(shape); let desc = RepeatOperationDescription { diff --git a/crates/burn-fusion/src/ops/float.rs b/crates/burn-fusion/src/ops/float.rs index e282f11abd..a24f7045b5 100644 --- a/crates/burn-fusion/src/ops/float.rs +++ b/crates/burn-fusion/src/ops/float.rs @@ -1620,7 +1620,7 @@ impl FloatTensorOps for Fusion { let stream = tensor.stream; let mut shape = tensor.shape.clone(); - shape[dim] = times; + shape[dim] *= times; let out = tensor.client.tensor_uninitialized(shape); let desc = RepeatOperationDescription { diff --git a/crates/burn-fusion/src/ops/int.rs b/crates/burn-fusion/src/ops/int.rs index 2e691707bf..c990d3f25c 100644 --- a/crates/burn-fusion/src/ops/int.rs +++ b/crates/burn-fusion/src/ops/int.rs @@ -1665,7 +1665,7 @@ impl IntTensorOps for Fusion { let stream = tensor.stream; let mut shape = tensor.shape.clone(); - shape[dim] = times; + shape[dim] *= times; let out = tensor.client.tensor_uninitialized(shape); let desc = RepeatOperationDescription { diff --git a/crates/burn-jit/src/kernel/index/repeat.rs b/crates/burn-jit/src/kernel/index/repeat.rs index 3c4dc0a52d..047578287f 100644 --- a/crates/burn-jit/src/kernel/index/repeat.rs +++ b/crates/burn-jit/src/kernel/index/repeat.rs @@ -38,19 +38,21 @@ impl RepeatComputeShader { let stride_input = scope.create_local(Elem::UInt); let stride_output = scope.create_local(Elem::UInt); - let shape_output = scope.create_local(Elem::UInt); + let shape = scope.create_local(Elem::UInt); for i in 0..self.rank { + gpu!(scope, stride_input = stride(input, i)); + gpu!(scope, stride_output = stride(output, i)); if i != self.dim { - gpu!(scope, stride_input = stride(input, i)); - gpu!(scope, stride_output = stride(output, i)); - gpu!(scope, shape_output = shape(output, i)); - - gpu!(scope, offset_local = id / stride_output); - gpu!(scope, offset_local = offset_local % shape_output); - gpu!(scope, offset_local = offset_local * stride_input); - gpu!(scope, offset_input += offset_local); + gpu!(scope, shape = shape(output, i)); + } else { + gpu!(scope, shape = shape(input, i)); } + + gpu!(scope, offset_local = id / stride_output); + gpu!(scope, offset_local = offset_local % shape); + gpu!(scope, offset_local = offset_local * stride_input); + gpu!(scope, offset_input += offset_local); } let result = scope.create_local(input.item()); @@ -108,12 +110,9 @@ pub(crate) fn repeat( times: usize, ) -> JitTensor { let mut shape = input.shape.clone(); - if shape.dims[dim] != 1 { - panic!("Can only repeat dimension with dim=1"); - } // Create output handle - shape.dims[dim] = times; + shape.dims[dim] *= times; let num_elems_output = shape.num_elements(); let handle = input .client diff --git a/crates/burn-tensor/src/tensor/api/base.rs b/crates/burn-tensor/src/tensor/api/base.rs index 2e8cfc588e..8fc5060047 100644 --- a/crates/burn-tensor/src/tensor/api/base.rs +++ b/crates/burn-tensor/src/tensor/api/base.rs @@ -564,10 +564,6 @@ where } /// Repeat the tensor along the given dimension. - /// - /// # Panics - /// - /// If the selected dimension more than one item. pub fn repeat(self, dim: usize, times: usize) -> Self { Self::new(K::repeat(self.primitive, dim, times)) } diff --git a/crates/burn-tensor/src/tensor/ops/bool_tensor.rs b/crates/burn-tensor/src/tensor/ops/bool_tensor.rs index cd4e988181..b302d3e44c 100644 --- a/crates/burn-tensor/src/tensor/ops/bool_tensor.rs +++ b/crates/burn-tensor/src/tensor/ops/bool_tensor.rs @@ -1,4 +1,7 @@ -use super::{cat::cat_with_slice_assign, BoolTensor, Device, FloatTensor, IntTensor}; +use super::{ + cat::cat_with_slice_assign, repeat::repeat_with_slice_assign, BoolTensor, Device, FloatTensor, + IntTensor, +}; use crate::{ backend::Backend, chunk, narrow, tensor::Shape, Bool, Data, ElementConversion, Tensor, }; @@ -174,28 +177,12 @@ pub trait BoolTensorOps { dim: usize, times: usize, ) -> BoolTensor { - let mut shape = Self::bool_shape(&tensor); - if shape.dims[dim] != 1 { - panic!("Can only repeat dimension with dim=1"); - } - shape.dims[dim] = times; - - let mut i = 0; - let ranges_select_all = [0; D].map(|_| { - let start = 0; - let end = shape.dims[i]; - i += 1; - start..end - }); - - let mut tensor_output = Self::bool_empty(shape, &Self::bool_device(&tensor)); - for i in 0..times { - let mut ranges = ranges_select_all.clone(); - ranges[dim] = i..i + 1; - tensor_output = Self::bool_slice_assign(tensor_output, ranges, tensor.clone()); - } - - tensor_output + repeat_with_slice_assign::( + Tensor::::from_primitive(tensor), + dim, + times, + ) + .into_primitive() } /// Concatenates the tensors along the given dimension. diff --git a/crates/burn-tensor/src/tensor/ops/int_tensor.rs b/crates/burn-tensor/src/tensor/ops/int_tensor.rs index bbbba372a8..8f3fc56d51 100644 --- a/crates/burn-tensor/src/tensor/ops/int_tensor.rs +++ b/crates/burn-tensor/src/tensor/ops/int_tensor.rs @@ -1,4 +1,5 @@ use super::cat::cat_with_slice_assign; +use super::repeat::repeat_with_slice_assign; use super::{BoolTensor, Device, FloatTensor, IntElem, IntTensor}; use crate::Tensor; use crate::{backend::Backend, tensor::Shape, Data, Distribution, ElementConversion, Int}; @@ -270,28 +271,12 @@ pub trait IntTensorOps { dim: usize, times: usize, ) -> IntTensor { - let mut shape = Self::int_shape(&tensor); - if shape.dims[dim] != 1 { - panic!("Can only repeat dimension with dim=1"); - } - shape.dims[dim] = times; - - let mut i = 0; - let indices_select_all = [0; D].map(|_| { - let start = 0; - let end = shape.dims[i]; - i += 1; - start..end - }); - - let mut tensor_output = Self::int_empty(shape, &Self::int_device(&tensor)); - for i in 0..times { - let mut indices = indices_select_all.clone(); - indices[dim] = i..i + 1; - tensor_output = Self::int_slice_assign(tensor_output, indices, tensor.clone()); - } - - tensor_output + repeat_with_slice_assign::( + Tensor::::from_primitive(tensor), + dim, + times, + ) + .into_primitive() } /// Concatenates the given tensors along the given dimension. diff --git a/crates/burn-tensor/src/tensor/ops/modules/cat.rs b/crates/burn-tensor/src/tensor/ops/modules/cat.rs index ad0627927a..0dbc526ddc 100644 --- a/crates/burn-tensor/src/tensor/ops/modules/cat.rs +++ b/crates/burn-tensor/src/tensor/ops/modules/cat.rs @@ -19,10 +19,8 @@ pub(crate) fn cat_with_slice_assign let mut i = 0; let indices_select_all = [0; D].map(|_| { - let start = 0; - let end = shape.dims[i]; i += 1; - start..end + 0..shape.dims[i - 1] }); let mut output_index = 0; diff --git a/crates/burn-tensor/src/tensor/ops/modules/mod.rs b/crates/burn-tensor/src/tensor/ops/modules/mod.rs index 695683a122..33b9576753 100644 --- a/crates/burn-tensor/src/tensor/ops/modules/mod.rs +++ b/crates/burn-tensor/src/tensor/ops/modules/mod.rs @@ -3,6 +3,8 @@ pub mod conv; /// Module with cat operation pub(crate) mod cat; +/// Module with repeat operation +pub(crate) mod repeat; /// Module with unfold operations. pub(crate) mod unfold; diff --git a/crates/burn-tensor/src/tensor/ops/modules/repeat.rs b/crates/burn-tensor/src/tensor/ops/modules/repeat.rs new file mode 100644 index 0000000000..1837cff938 --- /dev/null +++ b/crates/burn-tensor/src/tensor/ops/modules/repeat.rs @@ -0,0 +1,36 @@ +use crate::{backend::Backend, BasicOps, Tensor, TensorKind}; + +pub(crate) fn repeat_with_slice_assign< + B: Backend, + const D: usize, + K: TensorKind + BasicOps, +>( + tensor: Tensor, + dim: usize, + times: usize, +) -> Tensor { + let mut shape = tensor.shape(); + let device = tensor.device(); + + let original_dim_length = shape.dims[dim]; + shape.dims[dim] *= times; + + let mut tensor_output = Tensor::empty(shape.clone(), &device); + + let mut i = 0; + let indices_select_all = [0; D].map(|_| { + i += 1; + 0..shape.dims[i - 1] + }); + + let mut output_index = 0; + for _ in 0..times { + let mut indices = indices_select_all.clone(); + indices[dim] = output_index..output_index + original_dim_length; + output_index += original_dim_length; + + tensor_output = tensor_output.slice_assign(indices, tensor.clone()); + } + + tensor_output +} diff --git a/crates/burn-tensor/src/tensor/ops/tensor.rs b/crates/burn-tensor/src/tensor/ops/tensor.rs index 292cc525c0..27b28b3070 100644 --- a/crates/burn-tensor/src/tensor/ops/tensor.rs +++ b/crates/burn-tensor/src/tensor/ops/tensor.rs @@ -1,4 +1,5 @@ use super::cat::cat_with_slice_assign; +use super::repeat::repeat_with_slice_assign; use super::{BoolTensor, Device, FloatElem, FloatTensor, FullPrecisionBackend, IntElem, IntTensor}; use crate::backend::BackendBridge; use crate::Tensor; @@ -193,28 +194,8 @@ pub trait FloatTensorOps { dim: usize, times: usize, ) -> FloatTensor { - let mut shape = B::float_shape(&tensor); - if shape.dims[dim] != 1 { - panic!("Can only repeat dimension with dim=1"); - } - shape.dims[dim] = times; - - let mut i = 0; - let indices_select_all = [0; D].map(|_| { - let start = 0; - let end = shape.dims[i]; - i += 1; - start..end - }); - - let mut tensor_output = B::float_empty(shape, &B::float_device(&tensor)); - for i in 0..times { - let mut indices = indices_select_all.clone(); - indices[dim] = i..i + 1; - tensor_output = B::float_slice_assign(tensor_output, indices, tensor.clone()); - } - - tensor_output + repeat_with_slice_assign::(Tensor::::from_primitive(tensor), dim, times) + .into_primitive() } /// Adds two tensors together. diff --git a/crates/burn-tensor/src/tests/ops/repeat.rs b/crates/burn-tensor/src/tests/ops/repeat.rs index 46c159c885..710a29bc92 100644 --- a/crates/burn-tensor/src/tests/ops/repeat.rs +++ b/crates/burn-tensor/src/tests/ops/repeat.rs @@ -45,4 +45,66 @@ mod tests { let data_expected = Data::from([[0, 1, 2], [0, 1, 2], [0, 1, 2], [0, 1, 2]]); assert_eq!(data_expected, data_actual); } + + #[test] + fn should_support_float_repeat_on_dims_larger_than_1() { + let data = Data::from([ + [[1.0, 2.0], [3.0, 4.0]], + [[5.0, 6.0], [7.0, 8.0]], + [[9.0, 10.0], [11.0, 12.0]], + [[13.0, 14.0], [15.0, 16.0]], + ]); + let tensor = Tensor::::from_data(data, &Default::default()); + + let data_actual = tensor.repeat(2, 2).into_data(); + + let data_expected = Data::from([ + [[1.0, 2.0, 1.0, 2.0], [3.0, 4.0, 3.0, 4.0]], + [[5.0, 6.0, 5.0, 6.0], [7.0, 8.0, 7.0, 8.0]], + [[9.0, 10.0, 9.0, 10.0], [11.0, 12.0, 11.0, 12.0]], + [[13.0, 14.0, 13.0, 14.0], [15.0, 16.0, 15.0, 16.0]], + ]); + + assert_eq!(data_expected, data_actual); + } + + #[test] + fn should_support_int_repeat_on_dims_larger_than_1() { + let data = Data::from([ + [[1, 2], [3, 4]], + [[5, 6], [7, 8]], + [[9, 10], [11, 12]], + [[13, 14], [15, 16]], + ]); + let tensor = Tensor::::from_data(data, &Default::default()); + + let data_actual = tensor.repeat(2, 3).into_data(); + + let data_expected = Data::from([ + [[1, 2, 1, 2, 1, 2], [3, 4, 3, 4, 3, 4]], + [[5, 6, 5, 6, 5, 6], [7, 8, 7, 8, 7, 8]], + [[9, 10, 9, 10, 9, 10], [11, 12, 11, 12, 11, 12]], + [[13, 14, 13, 14, 13, 14], [15, 16, 15, 16, 15, 16]], + ]); + + assert_eq!(data_expected, data_actual); + } + + #[test] + fn should_support_bool_repeat_on_dims_larger_than_1() { + let data = Data::from([ + [[false, true], [true, false]], + [[true, true], [false, false]], + ]); + let tensor = Tensor::::from_data(data, &Default::default()); + + let data_actual = tensor.repeat(1, 2).into_data(); + + let data_expected = Data::from([ + [[false, true], [true, false], [false, true], [true, false]], + [[true, true], [false, false], [true, true], [false, false]], + ]); + + assert_eq!(data_expected, data_actual); + } }