From c464a9372cf7e1f6271c0b7993026e0795d6eb97 Mon Sep 17 00:00:00 2001 From: tiruka Date: Sun, 1 Dec 2024 22:34:02 +0900 Subject: [PATCH 01/20] add one hot with axis and values function --- crates/burn-tensor/src/tensor/api/int.rs | 59 ++++++++++++++++++++++++ 1 file changed, 59 insertions(+) diff --git a/crates/burn-tensor/src/tensor/api/int.rs b/crates/burn-tensor/src/tensor/api/int.rs index 08bdab0fe7..6aff95ecd9 100644 --- a/crates/burn-tensor/src/tensor/api/int.rs +++ b/crates/burn-tensor/src/tensor/api/int.rs @@ -129,4 +129,63 @@ where ) -> Tensor { cartesian_grid::(shape, device) } + + /// Create a one-hot encoded tensor with configurable `on_value`, `off_value`, and `axis`. + /// + /// # Arguments + /// + /// * `depth` - The number of classes for one-hot encoding. + /// * `on_value` - The value to use for the "on" positions. + /// * `off_value` - The value to use for the "off" positions. + /// * `axis` - The axis along which to perform one-hot encoding. + /// # Example + /// ```rust + /// use burn_tensor::backend::Backend; + /// use burn_tensor::{Int, Tensor}; + /// fn example() { + /// let device = B::Device::default(); + /// let expected: Tensor = Tensor::from_ints([[5, 0, 0], [0, 0, 5], [0, 5, 0], [0, 0, 0]], &device); + /// let indices: Tensor = Tensor::from_ints([[0, 2], [1, -1]], &device); + /// // One-hot encoding + /// let result = indices.one_hot_with_axis(3, 5, 0, -1); + /// assert_eq!(expected.to_data(), result.to_data()); + /// } + /// ``` + pub fn one_hot_with_axis( + self, + depth: usize, + on_value: i32, + off_value: i32, + axis: isize, + ) -> Tensor { + let indices_shape = self.dims(); + let rank = indices_shape.len(); + let actual_axis = if axis < 0 { + (rank as isize + axis + 1) as usize + } else { + axis as usize + }; + + assert!( + actual_axis <= rank, + "Axis {} out of bounds for tensor with rank {}", + actual_axis, + rank + ); + + // Create the shape for the result tensor + let mut result_shape = indices_shape.to_vec(); + result_shape.insert(actual_axis, depth); + + // Prepare the on_value and off_value tensors + let device = &self.device(); + let on_tensor = Tensor::full(result_shape.clone(), on_value, device); + let off_tensor = Tensor::full(result_shape.clone(), off_value, device); + + // Broadcast indices to the appropriate shape for scattering + let indices_expanded = self.unsqueeze_dim(actual_axis); + + // Create a zero tensor and scatter on_value + off_tensor.scatter(actual_axis, indices_expanded, on_tensor) + } } From e09e4107d033a45d98ab2057a8ba964080665819 Mon Sep 17 00:00:00 2001 From: tiruka Date: Wed, 11 Dec 2024 22:24:09 +0900 Subject: [PATCH 02/20] update one hot multidimentional function --- crates/burn-tensor/src/tensor/api/int.rs | 65 +++++++++++---------- crates/burn-tensor/src/tests/ops/one_hot.rs | 12 ++++ 2 files changed, 47 insertions(+), 30 deletions(-) diff --git a/crates/burn-tensor/src/tensor/api/int.rs b/crates/burn-tensor/src/tensor/api/int.rs index 6aff95ecd9..b51a3e04ff 100644 --- a/crates/burn-tensor/src/tensor/api/int.rs +++ b/crates/burn-tensor/src/tensor/api/int.rs @@ -144,48 +144,53 @@ where /// use burn_tensor::{Int, Tensor}; /// fn example() { /// let device = B::Device::default(); - /// let expected: Tensor = Tensor::from_ints([[5, 0, 0], [0, 0, 5], [0, 5, 0], [0, 0, 0]], &device); + /// let expected: Tensor = Tensor::from_ints([[5, 0, 0], [0, 0, 5], [0, 5, 0], [0, 0, 5]], &device); /// let indices: Tensor = Tensor::from_ints([[0, 2], [1, -1]], &device); /// // One-hot encoding - /// let result = indices.one_hot_with_axis(3, 5, 0, -1); + /// let result = indices.one_hot_with_axis_and_values(3, 5, 0, -1); /// assert_eq!(expected.to_data(), result.to_data()); /// } /// ``` - pub fn one_hot_with_axis( + pub fn one_hot_with_axis_and_values( self, depth: usize, - on_value: i32, - off_value: i32, - axis: isize, + on_value: i64, + off_value: i64, + axis: i64, ) -> Tensor { - let indices_shape = self.dims(); - let rank = indices_shape.len(); - let actual_axis = if axis < 0 { - (rank as isize + axis + 1) as usize - } else { - axis as usize - }; - assert!( - actual_axis <= rank, - "Axis {} out of bounds for tensor with rank {}", - actual_axis, - rank + let mut shape = self.shape().dims::().to_vec(); + let rank = self.dims().len(); + let axis = if axis < 0 { + axis + rank as i64 + 1 // Convert negative axis to positive index + } else { + axis + }; + if axis < 0 || axis > rank as i64 { + panic!("Axis out of range. Accepted range is [-r-1, r] where r = rank(indices)."); + } + shape.insert(axis as usize, depth); + let condition1 = self.clone().greater_elem(-1 * depth as i64).int(); + let condition2 = self.clone().lower_elem(depth as i64).int(); + let valid_mask = condition1.mul(condition2).bool().bool_not(); + let adjusted_indices = self + .clone() + .mask_fill(self.clone().lower_elem(0), depth as i64) + .add( + self + .clone() + .mask_fill(self.clone().greater_elem(0), 0), ); - // Create the shape for the result tensor - let mut result_shape = indices_shape.to_vec(); - result_shape.insert(actual_axis, depth); - - // Prepare the on_value and off_value tensors - let device = &self.device(); - let on_tensor = Tensor::full(result_shape.clone(), on_value, device); - let off_tensor = Tensor::full(result_shape.clone(), off_value, device); + let valid_indices = adjusted_indices.mask_fill(valid_mask, off_value); + let indices_unsqueezed = valid_indices.unsqueeze_dim(axis as usize); - // Broadcast indices to the appropriate shape for scattering - let indices_expanded = self.unsqueeze_dim(actual_axis); + let output= Tensor::full(shape.clone(), off_value, &self.device()); + let scatter_on_values = Tensor::full(indices_unsqueezed.shape(), on_value, &self.device()); + let scatter_off_values = Tensor::full(indices_unsqueezed.shape(), -off_value, &self.device()); - // Create a zero tensor and scatter on_value - off_tensor.scatter(actual_axis, indices_expanded, on_tensor) + output + .scatter(axis as usize, indices_unsqueezed.clone(), scatter_on_values) + .scatter(axis as usize, indices_unsqueezed, scatter_off_values) } } diff --git a/crates/burn-tensor/src/tests/ops/one_hot.rs b/crates/burn-tensor/src/tests/ops/one_hot.rs index 310399119f..fb4976cfb2 100644 --- a/crates/burn-tensor/src/tests/ops/one_hot.rs +++ b/crates/burn-tensor/src/tests/ops/one_hot.rs @@ -71,4 +71,16 @@ mod tests { let index_tensor = TestTensorInt::<1>::arange(0..3, &device); let one_hot_tensor = index_tensor.one_hot(1); } + + #[test] + #[should_panic] + fn int_one_hot_with() { + let device = Default::default(); + let index_tensor = TestTensorInt::<1>::arange(0..3, &device); + let expected = TestTensorInt::eye(5, &device).into_data(); + + let one_hot_tensor = index_tensor.one_hot_with_axis_and_values(3, 5, 0, -1); + + one_hot_tensor.into_data().assert_eq(&expected, false); + } } From 47d952901edf6f464535af14def2e5636fc75061 Mon Sep 17 00:00:00 2001 From: tiruka Date: Fri, 13 Dec 2024 15:14:35 +0900 Subject: [PATCH 03/20] implementing on numeric.rs --- crates/burn-tensor/src/tensor/api/numeric.rs | 60 ++++++++++++++++++++ 1 file changed, 60 insertions(+) diff --git a/crates/burn-tensor/src/tensor/api/numeric.rs b/crates/burn-tensor/src/tensor/api/numeric.rs index 1f04aeaac8..9b35e12af3 100644 --- a/crates/burn-tensor/src/tensor/api/numeric.rs +++ b/crates/burn-tensor/src/tensor/api/numeric.rs @@ -2030,7 +2030,67 @@ where // Assign the original tensor data to the appropriate slice of the padded tensor padded_tensor.slice_assign(ranges, self) } + /// Create a one-hot encoded tensor with configurable `on_value`, `off_value`, and `axis`. + /// + /// # Arguments + /// + /// * `depth` - The number of classes for one-hot encoding. + /// * `on_value` - The value to use for the "on" positions. + /// * `off_value` - The value to use for the "off" positions. + /// * `axis` - The axis along which to perform one-hot encoding. + /// # Example + /// ```rust + /// use burn_tensor::backend::Backend; + /// use burn_tensor::{Int, Tensor}; + /// fn example() { + /// let device = B::Device::default(); + /// let expected: Tensor = Tensor::from_ints([[5, 0, 0], [0, 0, 5], [0, 5, 0], [0, 0, 5]], &device); + /// let indices: Tensor = Tensor::from_ints([[0, 2], [1, -1]], &device); + /// // One-hot encoding + /// let result = indices.one_hot_with_axis_and_values(3, 5, 0, -1); + /// assert_eq!(expected.to_data(), result.to_data()); + /// } + /// ``` + pub fn one_hot_with_axis_and_values2( + self, + depth: usize, + on_value: K::Elem, + off_value: K::Elem, + axis: i64, + ) -> Tensor + { + let mut shape = self.shape().dims().to_vec(); + let rank = self.dims().len(); + let axis = if axis < 0 { + axis + rank as i64 + 1 // Convert negative axis to positive index + } else { + axis + }; + if axis < 0 || axis > rank as i64 { + panic!("Axis out of range. Accepted range is [-r-1, r] where r = rank(indices)."); + } + shape.insert(axis as usize, depth); + let condition1 = self.clone().greater_elem(-1 * depth as i64).int(); + let condition2 = self.clone().lower_elem(depth as i64).int(); + let valid_mask = condition1.mul(condition2).bool().bool_not(); + let adjusted_indices = self + .clone() + .mask_fill(self.clone().lower_elem(0), depth as i64) + .add( + self + .clone() + .mask_fill(self.clone().greater_elem(0), 0), + ); + + let valid_indices = adjusted_indices.mask_fill(valid_mask, off_value); + let indices_unsqueezed = valid_indices.unsqueeze_dim(axis as usize); + + let output= Tensor::full(shape.clone(), off_value, &self.device()); + let scatter_on_values = Tensor::full(indices_unsqueezed.shape(), on_value, &device) + - Tensor::full(indices_unsqueezed.shape(), off_value, &self.device()); + output.scatter(axis as usize, indices_unsqueezed.clone(), scatter_on_values); + } /// Returns a new tensor with boolean elements indicating whether each element of the input is NaN. /// /// # Returns From 07f069f8e2b7e6e109fda7649b517948bb56130f Mon Sep 17 00:00:00 2001 From: tiruka Date: Fri, 13 Dec 2024 23:11:31 +0900 Subject: [PATCH 04/20] update one hot method in numeric --- crates/burn-tensor/src/tensor/api/numeric.rs | 73 ++++++++++---------- 1 file changed, 36 insertions(+), 37 deletions(-) diff --git a/crates/burn-tensor/src/tensor/api/numeric.rs b/crates/burn-tensor/src/tensor/api/numeric.rs index 9b35e12af3..52ec09cadc 100644 --- a/crates/burn-tensor/src/tensor/api/numeric.rs +++ b/crates/burn-tensor/src/tensor/api/numeric.rs @@ -2041,14 +2041,13 @@ where /// # Example /// ```rust /// use burn_tensor::backend::Backend; - /// use burn_tensor::{Int, Tensor}; - /// fn example() { + /// use burn_tensor::{Int, Tensor, Float}; + /// fn example>>() { /// let device = B::Device::default(); - /// let expected: Tensor = Tensor::from_ints([[5, 0, 0], [0, 0, 5], [0, 5, 0], [0, 0, 5]], &device); - /// let indices: Tensor = Tensor::from_ints([[0, 2], [1, -1]], &device); + /// let indices: Tensor = Tensor::from_floats([[0., 2.], [1., -1.]], &device); /// // One-hot encoding - /// let result = indices.one_hot_with_axis_and_values(3, 5, 0, -1); - /// assert_eq!(expected.to_data(), result.to_data()); + /// let tensor = indices.one_hot_with_axis_and_values2(3, 5.0.into(), 0.0.into(), -1); + /// println!("{tensor}"); /// } /// ``` pub fn one_hot_with_axis_and_values2( @@ -2059,37 +2058,37 @@ where axis: i64, ) -> Tensor { - let mut shape = self.shape().dims().to_vec(); - let rank = self.dims().len(); - let axis = if axis < 0 { - axis + rank as i64 + 1 // Convert negative axis to positive index - } else { - axis - }; - if axis < 0 || axis > rank as i64 { - panic!("Axis out of range. Accepted range is [-r-1, r] where r = rank(indices)."); - } - shape.insert(axis as usize, depth); - let condition1 = self.clone().greater_elem(-1 * depth as i64).int(); - let condition2 = self.clone().lower_elem(depth as i64).int(); - let valid_mask = condition1.mul(condition2).bool().bool_not(); - let adjusted_indices = self - .clone() - .mask_fill(self.clone().lower_elem(0), depth as i64) - .add( - self - .clone() - .mask_fill(self.clone().greater_elem(0), 0), - ); - - let valid_indices = adjusted_indices.mask_fill(valid_mask, off_value); - let indices_unsqueezed = valid_indices.unsqueeze_dim(axis as usize); - - let output= Tensor::full(shape.clone(), off_value, &self.device()); - let scatter_on_values = Tensor::full(indices_unsqueezed.shape(), on_value, &device) - - Tensor::full(indices_unsqueezed.shape(), off_value, &self.device()); - output.scatter(axis as usize, indices_unsqueezed.clone(), scatter_on_values); - + let mut shape = self.shape().dims::().to_vec(); + let rank = self.dims().len(); + let axis = if axis < 0 { + axis + rank as i64 + 1 // Convert negative axis to positive index + } else { + axis + }; + if axis < 0 || axis > rank as i64 { + panic!("Axis out of range. Accepted range is [-r-1, r] where r = rank(indices)."); + } + let device = self.device(); + let indices: Tensor = Tensor::from_data(self.to_data().convert::(), &device); + shape.insert(axis as usize, depth); + let condition1 = indices.clone().greater_elem(-1 * depth as i64).int(); + let condition2 = indices.clone().lower_elem(depth as i64).int(); + let valid_mask = condition1.mul(condition2).bool().bool_not(); + let adjusted_indices = indices + .clone() + .mask_fill(self.clone().lower_elem(0), depth as i64) + .add( + indices + .clone() + .mask_fill(self.clone().greater_elem(0), 0), + ); + + let valid_indices = adjusted_indices.mask_fill(valid_mask, off_value); + let indices_unsqueezed = valid_indices.unsqueeze_dim(axis as usize); + let output= Tensor::full(shape.clone(), off_value, &device); + let scatter_on_values = Tensor::full(indices_unsqueezed.shape(), on_value, &device) + - Tensor::full(indices_unsqueezed.shape(), off_value, &self.device()); + output.scatter(axis as usize, indices_unsqueezed, scatter_on_values) } /// Returns a new tensor with boolean elements indicating whether each element of the input is NaN. /// From f34e9bdec5dc28c36b02fcffd3d485ce0833e48e Mon Sep 17 00:00:00 2001 From: tiruka Date: Sat, 14 Dec 2024 20:27:42 +0900 Subject: [PATCH 05/20] update one hot function to deal with additional dims add one hot test --- crates/burn-tensor/src/tensor/api/numeric.rs | 6 ++-- crates/burn-tensor/src/tests/ops/one_hot.rs | 33 ++++++++++++++++++-- 2 files changed, 34 insertions(+), 5 deletions(-) diff --git a/crates/burn-tensor/src/tensor/api/numeric.rs b/crates/burn-tensor/src/tensor/api/numeric.rs index 52ec09cadc..9065c53184 100644 --- a/crates/burn-tensor/src/tensor/api/numeric.rs +++ b/crates/burn-tensor/src/tensor/api/numeric.rs @@ -2050,13 +2050,13 @@ where /// println!("{tensor}"); /// } /// ``` - pub fn one_hot_with_axis_and_values2( + pub fn one_hot_with_axis_and_values2( self, depth: usize, on_value: K::Elem, off_value: K::Elem, axis: i64, - ) -> Tensor + ) -> Tensor { let mut shape = self.shape().dims::().to_vec(); let rank = self.dims().len(); @@ -2084,7 +2084,7 @@ where ); let valid_indices = adjusted_indices.mask_fill(valid_mask, off_value); - let indices_unsqueezed = valid_indices.unsqueeze_dim(axis as usize); + let indices_unsqueezed: Tensor = valid_indices.unsqueeze_dim(axis as usize); let output= Tensor::full(shape.clone(), off_value, &device); let scatter_on_values = Tensor::full(indices_unsqueezed.shape(), on_value, &device) - Tensor::full(indices_unsqueezed.shape(), off_value, &self.device()); diff --git a/crates/burn-tensor/src/tests/ops/one_hot.rs b/crates/burn-tensor/src/tests/ops/one_hot.rs index fb4976cfb2..ff6b488823 100644 --- a/crates/burn-tensor/src/tests/ops/one_hot.rs +++ b/crates/burn-tensor/src/tests/ops/one_hot.rs @@ -1,8 +1,13 @@ #[burn_tensor_testgen::testgen(one_hot)] mod tests { use super::*; - use burn_tensor::{Int, TensorData}; - + use burn_tensor::{ + Int, TensorData, + as_type, + backend::Backend, + tests::{Float as _, Int as _}, + Numeric, Shape, Tensor, + }; #[test] fn float_should_support_one_hot() { let device = Default::default(); @@ -83,4 +88,28 @@ mod tests { one_hot_tensor.into_data().assert_eq(&expected, false); } + + #[test] + fn one_hot_with_axis_and_values_test_1() { + let tensor = TestTensor::<2>::from([[0, 2], [1, -1]]); + let expected = TensorData::from(as_type!(FloatType: [[[5.0, 0.0, 0.0], [0.0, 0.0, 5.0]], [[0.0, 5.0, 0.0], [0.0, 0.0, 5.0]]])); + + let one_hot_tensor: Tensor = tensor.one_hot_with_axis_and_values2(3, FloatType::new(5.0), FloatType::new(0.0), -1); + + one_hot_tensor.into_data().assert_eq(&expected, true); + } + + #[test] + fn one_hot_with_axis_and_values_test_2() { + let tensor = TestTensor::<1>::from([0.0, -7.0, -8.0]); + let expected = TensorData::from(as_type!(FloatType:[ + [3.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + [1.0, 1.0, 1.0, 3.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + [1.0, 1.0, 3.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + ])); + + let one_hot_tensor: Tensor = tensor.one_hot_with_axis_and_values2(10, FloatType::new(3.0), FloatType::new(1.0), 1); + + one_hot_tensor.into_data().assert_eq(&expected, true); + } } From 3e665676afbab1f430f17ece6b6601c3e87d8abf Mon Sep 17 00:00:00 2001 From: tiruka Date: Sat, 14 Dec 2024 21:29:23 +0900 Subject: [PATCH 06/20] added tests for one hot --- crates/burn-tensor/src/tensor/api/int.rs | 66 +------------------- crates/burn-tensor/src/tensor/api/numeric.rs | 4 +- crates/burn-tensor/src/tests/ops/one_hot.rs | 48 +++++++++----- 3 files changed, 36 insertions(+), 82 deletions(-) diff --git a/crates/burn-tensor/src/tensor/api/int.rs b/crates/burn-tensor/src/tensor/api/int.rs index b51a3e04ff..c7b521721d 100644 --- a/crates/burn-tensor/src/tensor/api/int.rs +++ b/crates/burn-tensor/src/tensor/api/int.rs @@ -128,69 +128,5 @@ where device: &B::Device, ) -> Tensor { cartesian_grid::(shape, device) - } - - /// Create a one-hot encoded tensor with configurable `on_value`, `off_value`, and `axis`. - /// - /// # Arguments - /// - /// * `depth` - The number of classes for one-hot encoding. - /// * `on_value` - The value to use for the "on" positions. - /// * `off_value` - The value to use for the "off" positions. - /// * `axis` - The axis along which to perform one-hot encoding. - /// # Example - /// ```rust - /// use burn_tensor::backend::Backend; - /// use burn_tensor::{Int, Tensor}; - /// fn example() { - /// let device = B::Device::default(); - /// let expected: Tensor = Tensor::from_ints([[5, 0, 0], [0, 0, 5], [0, 5, 0], [0, 0, 5]], &device); - /// let indices: Tensor = Tensor::from_ints([[0, 2], [1, -1]], &device); - /// // One-hot encoding - /// let result = indices.one_hot_with_axis_and_values(3, 5, 0, -1); - /// assert_eq!(expected.to_data(), result.to_data()); - /// } - /// ``` - pub fn one_hot_with_axis_and_values( - self, - depth: usize, - on_value: i64, - off_value: i64, - axis: i64, - ) -> Tensor { - - let mut shape = self.shape().dims::().to_vec(); - let rank = self.dims().len(); - let axis = if axis < 0 { - axis + rank as i64 + 1 // Convert negative axis to positive index - } else { - axis - }; - if axis < 0 || axis > rank as i64 { - panic!("Axis out of range. Accepted range is [-r-1, r] where r = rank(indices)."); - } - shape.insert(axis as usize, depth); - let condition1 = self.clone().greater_elem(-1 * depth as i64).int(); - let condition2 = self.clone().lower_elem(depth as i64).int(); - let valid_mask = condition1.mul(condition2).bool().bool_not(); - let adjusted_indices = self - .clone() - .mask_fill(self.clone().lower_elem(0), depth as i64) - .add( - self - .clone() - .mask_fill(self.clone().greater_elem(0), 0), - ); - - let valid_indices = adjusted_indices.mask_fill(valid_mask, off_value); - let indices_unsqueezed = valid_indices.unsqueeze_dim(axis as usize); - - let output= Tensor::full(shape.clone(), off_value, &self.device()); - let scatter_on_values = Tensor::full(indices_unsqueezed.shape(), on_value, &self.device()); - let scatter_off_values = Tensor::full(indices_unsqueezed.shape(), -off_value, &self.device()); - - output - .scatter(axis as usize, indices_unsqueezed.clone(), scatter_on_values) - .scatter(axis as usize, indices_unsqueezed, scatter_off_values) - } + } } diff --git a/crates/burn-tensor/src/tensor/api/numeric.rs b/crates/burn-tensor/src/tensor/api/numeric.rs index 9065c53184..7d403c8d03 100644 --- a/crates/burn-tensor/src/tensor/api/numeric.rs +++ b/crates/burn-tensor/src/tensor/api/numeric.rs @@ -2046,11 +2046,11 @@ where /// let device = B::Device::default(); /// let indices: Tensor = Tensor::from_floats([[0., 2.], [1., -1.]], &device); /// // One-hot encoding - /// let tensor = indices.one_hot_with_axis_and_values2(3, 5.0.into(), 0.0.into(), -1); + /// let tensor = indices.one_hot_with_axis_and_values(3, 5.0.into(), 0.0.into(), -1); /// println!("{tensor}"); /// } /// ``` - pub fn one_hot_with_axis_and_values2( + pub fn one_hot_with_axis_and_values( self, depth: usize, on_value: K::Elem, diff --git a/crates/burn-tensor/src/tests/ops/one_hot.rs b/crates/burn-tensor/src/tests/ops/one_hot.rs index ff6b488823..e2509f34b0 100644 --- a/crates/burn-tensor/src/tests/ops/one_hot.rs +++ b/crates/burn-tensor/src/tests/ops/one_hot.rs @@ -2,11 +2,11 @@ mod tests { use super::*; use burn_tensor::{ - Int, TensorData, + Int, Float, TensorData, as_type, backend::Backend, tests::{Float as _, Int as _}, - Numeric, Shape, Tensor, + Numeric, Shape, Tensor }; #[test] fn float_should_support_one_hot() { @@ -78,29 +78,47 @@ mod tests { } #[test] - #[should_panic] - fn int_one_hot_with() { - let device = Default::default(); - let index_tensor = TestTensorInt::<1>::arange(0..3, &device); - let expected = TestTensorInt::eye(5, &device).into_data(); - - let one_hot_tensor = index_tensor.one_hot_with_axis_and_values(3, 5, 0, -1); + fn one_hot_with_axis_and_values_when_having_positive_axis_and_indices() { + let tensor = TestTensorInt::<2>::from([[1, 9], [2 ,4]]); + let expected = TensorData::from(as_type!(IntType:[ + [[1, 1], + [3, 1], + [1, 1], + [1, 1], + [1, 1], + [1, 1], + [1, 1], + [1, 1], + [1, 1], + [1, 3]], + [[1, 1], + [1, 1], + [3, 1], + [1, 1], + [1, 3], + [1, 1], + [1, 1], + [1, 1], + [1, 1], + [1, 1]]])); - one_hot_tensor.into_data().assert_eq(&expected, false); + let one_hot_tensor: Tensor = tensor.one_hot_with_axis_and_values(10, 3, 1, 1); + println!("{:?}", one_hot_tensor); + one_hot_tensor.into_data().assert_eq(&expected, true); } #[test] - fn one_hot_with_axis_and_values_test_1() { - let tensor = TestTensor::<2>::from([[0, 2], [1, -1]]); + fn one_hot_with_axis_and_values_when_having_negative_axis_and_indices() { + let tensor = TestTensor::<2>::from([[0.0, 2.0], [1.0, -1.0]]); let expected = TensorData::from(as_type!(FloatType: [[[5.0, 0.0, 0.0], [0.0, 0.0, 5.0]], [[0.0, 5.0, 0.0], [0.0, 0.0, 5.0]]])); - let one_hot_tensor: Tensor = tensor.one_hot_with_axis_and_values2(3, FloatType::new(5.0), FloatType::new(0.0), -1); + let one_hot_tensor: Tensor = tensor.one_hot_with_axis_and_values(3, FloatType::new(5.0), FloatType::new(0.0), -1); one_hot_tensor.into_data().assert_eq(&expected, true); } #[test] - fn one_hot_with_axis_and_values_test_2() { + fn one_hot_with_axis_and_values_when_having_negative_indices() { let tensor = TestTensor::<1>::from([0.0, -7.0, -8.0]); let expected = TensorData::from(as_type!(FloatType:[ [3.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], @@ -108,7 +126,7 @@ mod tests { [1.0, 1.0, 3.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], ])); - let one_hot_tensor: Tensor = tensor.one_hot_with_axis_and_values2(10, FloatType::new(3.0), FloatType::new(1.0), 1); + let one_hot_tensor: Tensor = tensor.one_hot_with_axis_and_values(10, FloatType::new(3.0), FloatType::new(1.0), 1); one_hot_tensor.into_data().assert_eq(&expected, true); } From f16980066cff8781900a310486576f50d3cb3f9d Mon Sep 17 00:00:00 2001 From: tiruka Date: Sat, 14 Dec 2024 23:35:08 +0900 Subject: [PATCH 07/20] modify function name modify format add tests --- crates/burn-tensor/src/tensor/api/int.rs | 2 +- crates/burn-tensor/src/tensor/api/numeric.rs | 24 ++++++------- crates/burn-tensor/src/tests/ops/one_hot.rs | 36 +++++++++++++------- 3 files changed, 37 insertions(+), 25 deletions(-) diff --git a/crates/burn-tensor/src/tensor/api/int.rs b/crates/burn-tensor/src/tensor/api/int.rs index c7b521721d..08bdab0fe7 100644 --- a/crates/burn-tensor/src/tensor/api/int.rs +++ b/crates/burn-tensor/src/tensor/api/int.rs @@ -128,5 +128,5 @@ where device: &B::Device, ) -> Tensor { cartesian_grid::(shape, device) - } + } } diff --git a/crates/burn-tensor/src/tensor/api/numeric.rs b/crates/burn-tensor/src/tensor/api/numeric.rs index 7d403c8d03..95558ac714 100644 --- a/crates/burn-tensor/src/tensor/api/numeric.rs +++ b/crates/burn-tensor/src/tensor/api/numeric.rs @@ -2046,18 +2046,21 @@ where /// let device = B::Device::default(); /// let indices: Tensor = Tensor::from_floats([[0., 2.], [1., -1.]], &device); /// // One-hot encoding - /// let tensor = indices.one_hot_with_axis_and_values(3, 5.0.into(), 0.0.into(), -1); + /// let tensor = indices.one_hot_plus(3, 5.0.into(), 0.0.into(), -1); /// println!("{tensor}"); + /// // [[[5.0, 0.0, 0.0], + /// // [0.0, 0.0, 5.0]], + /// // [[0.0, 5.0, 0.0], + /// // [0.0, 0.0, 5.0]]] /// } /// ``` - pub fn one_hot_with_axis_and_values( + pub fn one_hot_plus( self, depth: usize, on_value: K::Elem, off_value: K::Elem, axis: i64, - ) -> Tensor - { + ) -> Tensor { let mut shape = self.shape().dims::().to_vec(); let rank = self.dims().len(); let axis = if axis < 0 { @@ -2069,7 +2072,8 @@ where panic!("Axis out of range. Accepted range is [-r-1, r] where r = rank(indices)."); } let device = self.device(); - let indices: Tensor = Tensor::from_data(self.to_data().convert::(), &device); + let indices: Tensor = + Tensor::from_data(self.to_data().convert::(), &device); shape.insert(axis as usize, depth); let condition1 = indices.clone().greater_elem(-1 * depth as i64).int(); let condition2 = indices.clone().lower_elem(depth as i64).int(); @@ -2077,17 +2081,13 @@ where let adjusted_indices = indices .clone() .mask_fill(self.clone().lower_elem(0), depth as i64) - .add( - indices - .clone() - .mask_fill(self.clone().greater_elem(0), 0), - ); + .add(indices.clone().mask_fill(self.clone().greater_elem(0), 0)); let valid_indices = adjusted_indices.mask_fill(valid_mask, off_value); let indices_unsqueezed: Tensor = valid_indices.unsqueeze_dim(axis as usize); - let output= Tensor::full(shape.clone(), off_value, &device); + let output = Tensor::full(shape.clone(), off_value, &device); let scatter_on_values = Tensor::full(indices_unsqueezed.shape(), on_value, &device) - - Tensor::full(indices_unsqueezed.shape(), off_value, &self.device()); + - Tensor::full(indices_unsqueezed.shape(), off_value, &self.device()); output.scatter(axis as usize, indices_unsqueezed, scatter_on_values) } /// Returns a new tensor with boolean elements indicating whether each element of the input is NaN. diff --git a/crates/burn-tensor/src/tests/ops/one_hot.rs b/crates/burn-tensor/src/tests/ops/one_hot.rs index e2509f34b0..debd25f6b2 100644 --- a/crates/burn-tensor/src/tests/ops/one_hot.rs +++ b/crates/burn-tensor/src/tests/ops/one_hot.rs @@ -2,11 +2,10 @@ mod tests { use super::*; use burn_tensor::{ - Int, Float, TensorData, as_type, backend::Backend, tests::{Float as _, Int as _}, - Numeric, Shape, Tensor + Float, Int, Numeric, Shape, Tensor, TensorData, }; #[test] fn float_should_support_one_hot() { @@ -78,8 +77,8 @@ mod tests { } #[test] - fn one_hot_with_axis_and_values_when_having_positive_axis_and_indices() { - let tensor = TestTensorInt::<2>::from([[1, 9], [2 ,4]]); + fn one_hot_plus_with_positive_axis_and_indices() { + let tensor = TestTensorInt::<2>::from([[1, 9], [2, 4]]); let expected = TensorData::from(as_type!(IntType:[ [[1, 1], [3, 1], @@ -101,24 +100,27 @@ mod tests { [1, 1], [1, 1], [1, 1]]])); - - let one_hot_tensor: Tensor = tensor.one_hot_with_axis_and_values(10, 3, 1, 1); - println!("{:?}", one_hot_tensor); + + let one_hot_tensor: Tensor = tensor.one_hot_plus(10, 3, 1, 1); + one_hot_tensor.into_data().assert_eq(&expected, true); } #[test] - fn one_hot_with_axis_and_values_when_having_negative_axis_and_indices() { + fn one_hot_plus_with_negative_axis_and_indices() { let tensor = TestTensor::<2>::from([[0.0, 2.0], [1.0, -1.0]]); - let expected = TensorData::from(as_type!(FloatType: [[[5.0, 0.0, 0.0], [0.0, 0.0, 5.0]], [[0.0, 5.0, 0.0], [0.0, 0.0, 5.0]]])); + let expected = TensorData::from( + as_type!(FloatType: [[[5.0, 0.0, 0.0], [0.0, 0.0, 5.0]], [[0.0, 5.0, 0.0], [0.0, 0.0, 5.0]]]), + ); - let one_hot_tensor: Tensor = tensor.one_hot_with_axis_and_values(3, FloatType::new(5.0), FloatType::new(0.0), -1); + let one_hot_tensor: Tensor = + tensor.one_hot_plus(3, FloatType::new(5.0), FloatType::new(0.0), -1); one_hot_tensor.into_data().assert_eq(&expected, true); } #[test] - fn one_hot_with_axis_and_values_when_having_negative_indices() { + fn one_hot_plus_with_negative_indices() { let tensor = TestTensor::<1>::from([0.0, -7.0, -8.0]); let expected = TensorData::from(as_type!(FloatType:[ [3.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], @@ -126,8 +128,18 @@ mod tests { [1.0, 1.0, 3.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], ])); - let one_hot_tensor: Tensor = tensor.one_hot_with_axis_and_values(10, FloatType::new(3.0), FloatType::new(1.0), 1); + let one_hot_tensor: Tensor = + tensor.one_hot_plus(10, FloatType::new(3.0), FloatType::new(1.0), 1); one_hot_tensor.into_data().assert_eq(&expected, true); } + + #[should_panic] + #[test] + fn one_hot_plus_should_panic_when_axis_out_range_of_rank() { + let tensor = TestTensor::<2>::from([[0.0, 2.0], [1.0, -1.0]]); + + let one_hot_tensor: Tensor = + tensor.one_hot_plus(3, FloatType::new(5.0), FloatType::new(0.0), 3); + } } From 6304674a3c7ccb439e4c7c5328ba2ee222d91e31 Mon Sep 17 00:00:00 2001 From: tiruka Date: Sun, 15 Dec 2024 00:09:28 +0900 Subject: [PATCH 08/20] modify to respond to difference between Tensor type and values type --- crates/burn-tensor/src/tensor/api/numeric.rs | 8 ++++---- crates/burn-tensor/src/tests/ops/one_hot.rs | 6 +++--- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/crates/burn-tensor/src/tensor/api/numeric.rs b/crates/burn-tensor/src/tensor/api/numeric.rs index 95558ac714..997c338a27 100644 --- a/crates/burn-tensor/src/tensor/api/numeric.rs +++ b/crates/burn-tensor/src/tensor/api/numeric.rs @@ -2054,13 +2054,13 @@ where /// // [0.0, 0.0, 5.0]]] /// } /// ``` - pub fn one_hot_plus( + pub fn one_hot_plus, const D2: usize>( self, depth: usize, - on_value: K::Elem, - off_value: K::Elem, + on_value: K2::Elem, + off_value: K2::Elem, axis: i64, - ) -> Tensor { + ) -> Tensor { let mut shape = self.shape().dims::().to_vec(); let rank = self.dims().len(); let axis = if axis < 0 { diff --git a/crates/burn-tensor/src/tests/ops/one_hot.rs b/crates/burn-tensor/src/tests/ops/one_hot.rs index debd25f6b2..d679dce416 100644 --- a/crates/burn-tensor/src/tests/ops/one_hot.rs +++ b/crates/burn-tensor/src/tests/ops/one_hot.rs @@ -108,12 +108,12 @@ mod tests { #[test] fn one_hot_plus_with_negative_axis_and_indices() { - let tensor = TestTensor::<2>::from([[0.0, 2.0], [1.0, -1.0]]); + let tensor = TestTensorInt::<2>::from([[0, 2], [1, -1]]); let expected = TensorData::from( as_type!(FloatType: [[[5.0, 0.0, 0.0], [0.0, 0.0, 5.0]], [[0.0, 5.0, 0.0], [0.0, 0.0, 5.0]]]), ); - let one_hot_tensor: Tensor = + let one_hot_tensor: Tensor = tensor.one_hot_plus(3, FloatType::new(5.0), FloatType::new(0.0), -1); one_hot_tensor.into_data().assert_eq(&expected, true); @@ -129,7 +129,7 @@ mod tests { ])); let one_hot_tensor: Tensor = - tensor.one_hot_plus(10, FloatType::new(3.0), FloatType::new(1.0), 1); + tensor.one_hot_plus(10, 3.0, 1.0, 1); one_hot_tensor.into_data().assert_eq(&expected, true); } From b31f1e0bdfd5f809e9b001159a96c7aa03c0c6fe Mon Sep 17 00:00:00 2001 From: tiruka Date: Sun, 15 Dec 2024 00:50:00 +0900 Subject: [PATCH 09/20] fix clippy point out and doc test --- crates/burn-tensor/src/tensor/api/numeric.rs | 6 +++--- crates/burn-tensor/src/tests/ops/one_hot.rs | 3 +-- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/crates/burn-tensor/src/tensor/api/numeric.rs b/crates/burn-tensor/src/tensor/api/numeric.rs index 997c338a27..7a0b8efb7f 100644 --- a/crates/burn-tensor/src/tensor/api/numeric.rs +++ b/crates/burn-tensor/src/tensor/api/numeric.rs @@ -2041,12 +2041,12 @@ where /// # Example /// ```rust /// use burn_tensor::backend::Backend; - /// use burn_tensor::{Int, Tensor, Float}; + /// use burn_tensor::{Tensor, Float}; /// fn example>>() { /// let device = B::Device::default(); /// let indices: Tensor = Tensor::from_floats([[0., 2.], [1., -1.]], &device); /// // One-hot encoding - /// let tensor = indices.one_hot_plus(3, 5.0.into(), 0.0.into(), -1); + /// let tensor:Tensor = indices.one_hot_plus(3, 5.0.into(), 0.0.into(), -1); /// println!("{tensor}"); /// // [[[5.0, 0.0, 0.0], /// // [0.0, 0.0, 5.0]], @@ -2075,7 +2075,7 @@ where let indices: Tensor = Tensor::from_data(self.to_data().convert::(), &device); shape.insert(axis as usize, depth); - let condition1 = indices.clone().greater_elem(-1 * depth as i64).int(); + let condition1 = indices.clone().greater_elem(-(depth as i64)).int(); let condition2 = indices.clone().lower_elem(depth as i64).int(); let valid_mask = condition1.mul(condition2).bool().bool_not(); let adjusted_indices = indices diff --git a/crates/burn-tensor/src/tests/ops/one_hot.rs b/crates/burn-tensor/src/tests/ops/one_hot.rs index d679dce416..27a70cbe8b 100644 --- a/crates/burn-tensor/src/tests/ops/one_hot.rs +++ b/crates/burn-tensor/src/tests/ops/one_hot.rs @@ -128,8 +128,7 @@ mod tests { [1.0, 1.0, 3.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], ])); - let one_hot_tensor: Tensor = - tensor.one_hot_plus(10, 3.0, 1.0, 1); + let one_hot_tensor: Tensor = tensor.one_hot_plus(10, 3.0, 1.0, 1); one_hot_tensor.into_data().assert_eq(&expected, true); } From 628f584d57b69190340197906f2db5072ae123f8 Mon Sep 17 00:00:00 2001 From: tiruka Date: Sun, 15 Dec 2024 16:05:22 +0900 Subject: [PATCH 10/20] do refactoring modify comments --- crates/burn-tensor/src/tensor/api/numeric.rs | 56 +++++++++++++++----- 1 file changed, 43 insertions(+), 13 deletions(-) diff --git a/crates/burn-tensor/src/tensor/api/numeric.rs b/crates/burn-tensor/src/tensor/api/numeric.rs index 7a0b8efb7f..6d398d8f98 100644 --- a/crates/burn-tensor/src/tensor/api/numeric.rs +++ b/crates/burn-tensor/src/tensor/api/numeric.rs @@ -2030,14 +2030,19 @@ where // Assign the original tensor data to the appropriate slice of the padded tensor padded_tensor.slice_assign(ranges, self) } - /// Create a one-hot encoded tensor with configurable `on_value`, `off_value`, and `axis`. + /// Create a one-hot encoded tensor with configurable `depth`, `on_value`, `off_value`, and `axis` including high-ranked tensors. /// /// # Arguments /// - /// * `depth` - The number of classes for one-hot encoding. - /// * `on_value` - The value to use for the "on" positions. - /// * `off_value` - The value to use for the "off" positions. - /// * `axis` - The axis along which to perform one-hot encoding. + /// * `depth`: The number of classes for the one-hot encoding, which defines the size of the one-hot dimension. + /// * `on_value`: The value to assign for active positions (corresponding to indices). + /// * `off_value`: The value to assign for inactive positions. + /// * `axis`: The axis along which the one-hot dimension is added. Supports negative indexing. + /// + /// # Returns + /// + /// A tensor with one additional dimension for the one-hot encoding, where active positions are filled with `on_value` and others with `off_value`. + /// /// # Example /// ```rust /// use burn_tensor::backend::Backend; @@ -2061,35 +2066,60 @@ where off_value: K2::Elem, axis: i64, ) -> Tensor { + // Initialize shape from the current tensor dimensions and prepare for modification let mut shape = self.shape().dims::().to_vec(); + let device = self.device(); let rank = self.dims().len(); + + // Adjust negative axis to a positive index let axis = if axis < 0 { - axis + rank as i64 + 1 // Convert negative axis to positive index + axis + rank as i64 + 1 } else { axis }; + + // Ensure axis is within valid range if axis < 0 || axis > rank as i64 { panic!("Axis out of range. Accepted range is [-r-1, r] where r = rank(indices)."); } - let device = self.device(); + + // Convert the input tensor to integer indices let indices: Tensor = Tensor::from_data(self.to_data().convert::(), &device); + + // Insert the new dimension for the one-hot representation shape.insert(axis as usize, depth); - let condition1 = indices.clone().greater_elem(-(depth as i64)).int(); - let condition2 = indices.clone().lower_elem(depth as i64).int(); - let valid_mask = condition1.mul(condition2).bool().bool_not(); + + // Create masks for valid index range [-depth, depth-1] + let above_minimum = indices.clone().greater_elem(-(depth as i64)).int(); + let below_maximum = indices.clone().lower_elem(depth as i64).int(); + + // Combine conditions to identify invalid indices + let invalid_mask = above_minimum.mul(below_maximum).bool().bool_not(); + + // Adjust indices to valid range and handle invalid indices let adjusted_indices = indices .clone() - .mask_fill(self.clone().lower_elem(0), depth as i64) - .add(indices.clone().mask_fill(self.clone().greater_elem(0), 0)); + .mask_fill(self.clone().lower_elem(0), depth as i64) // Handle negative indices + .add(indices.clone().mask_fill(self.clone().greater_elem(0), 0)); // Handle positive indices + + // Replace invalid indices with the off_value + let valid_indices = adjusted_indices.mask_fill(invalid_mask, off_value); - let valid_indices = adjusted_indices.mask_fill(valid_mask, off_value); + // Unsqueeze the indices tensor along the specified axis let indices_unsqueezed: Tensor = valid_indices.unsqueeze_dim(axis as usize); + + // Initialize the output tensor with the off_value let output = Tensor::full(shape.clone(), off_value, &device); + + // Prepare scatter tensor for on_value and off_value adjustments let scatter_on_values = Tensor::full(indices_unsqueezed.shape(), on_value, &device) - Tensor::full(indices_unsqueezed.shape(), off_value, &self.device()); + + // Scatter on_value at the appropriate indices to create the one-hot representation output.scatter(axis as usize, indices_unsqueezed, scatter_on_values) } + /// Returns a new tensor with boolean elements indicating whether each element of the input is NaN. /// /// # Returns From 422954bd00b1f91be8abfdb44af0a7dcea5b77ac Mon Sep 17 00:00:00 2001 From: tiruka Date: Sun, 15 Dec 2024 20:40:25 +0900 Subject: [PATCH 11/20] update burn book to publish one hot plus method --- burn-book/src/building-blocks/tensor.md | 1 + 1 file changed, 1 insertion(+) diff --git a/burn-book/src/building-blocks/tensor.md b/burn-book/src/building-blocks/tensor.md index fb429ffd0f..6ef80a81d0 100644 --- a/burn-book/src/building-blocks/tensor.md +++ b/burn-book/src/building-blocks/tensor.md @@ -157,6 +157,7 @@ Those operations are available for all tensor kinds: `Int`, `Float`, and `Bool`. | `tensor.narrow(dim, start, length)` | `tensor.narrow(dim, start, length)` | | `tensor.not_equal(other)` | `x != y` | | `tensor.permute(axes)` | `tensor.permute(axes)` | +| `tensor.one_hot_plus(depth, on_value, off_value, axis)` | N/A | | `tensor.movedim(src, dst)` | `tensor.movedim(src, dst)` | | `tensor.repeat_dim(dim, times)` | `tensor.repeat(*[times if i == dim else 1 for i in range(tensor.dim())])` | | `tensor.repeat(sizes)` | `tensor.repeat(sizes)` | From 8c8ba36cee7e7bad53dfdeadbdddb4ffd839995c Mon Sep 17 00:00:00 2001 From: tiruka Date: Sun, 22 Dec 2024 20:02:37 +0900 Subject: [PATCH 12/20] modify one_hot_plus to one_hot_fill and args names --- crates/burn-tensor/src/tensor/api/numeric.rs | 21 ++++++++++---------- crates/burn-tensor/src/tests/ops/one_hot.rs | 16 +++++++-------- 2 files changed, 19 insertions(+), 18 deletions(-) diff --git a/crates/burn-tensor/src/tensor/api/numeric.rs b/crates/burn-tensor/src/tensor/api/numeric.rs index 6d398d8f98..471b74c0f7 100644 --- a/crates/burn-tensor/src/tensor/api/numeric.rs +++ b/crates/burn-tensor/src/tensor/api/numeric.rs @@ -2030,11 +2030,12 @@ where // Assign the original tensor data to the appropriate slice of the padded tensor padded_tensor.slice_assign(ranges, self) } - /// Create a one-hot encoded tensor with configurable `depth`, `on_value`, `off_value`, and `axis` including high-ranked tensors. + + /// Create a one-hot encoded tensor with configurable `num_classes`, `on_value`, `off_value`, and `axis` including high-ranked tensors. /// /// # Arguments /// - /// * `depth`: The number of classes for the one-hot encoding, which defines the size of the one-hot dimension. + /// * `num_classes`: The number of classes for the one-hot encoding, which defines the size of the one-hot dimension. /// * `on_value`: The value to assign for active positions (corresponding to indices). /// * `off_value`: The value to assign for inactive positions. /// * `axis`: The axis along which the one-hot dimension is added. Supports negative indexing. @@ -2051,7 +2052,7 @@ where /// let device = B::Device::default(); /// let indices: Tensor = Tensor::from_floats([[0., 2.], [1., -1.]], &device); /// // One-hot encoding - /// let tensor:Tensor = indices.one_hot_plus(3, 5.0.into(), 0.0.into(), -1); + /// let tensor:Tensor = indices.one_hot_fill(3, 5.0.into(), 0.0.into(), -1); /// println!("{tensor}"); /// // [[[5.0, 0.0, 0.0], /// // [0.0, 0.0, 5.0]], @@ -2059,9 +2060,9 @@ where /// // [0.0, 0.0, 5.0]]] /// } /// ``` - pub fn one_hot_plus, const D2: usize>( + pub fn one_hot_fill, const D2: usize>( self, - depth: usize, + num_classes: usize, on_value: K2::Elem, off_value: K2::Elem, axis: i64, @@ -2088,11 +2089,11 @@ where Tensor::from_data(self.to_data().convert::(), &device); // Insert the new dimension for the one-hot representation - shape.insert(axis as usize, depth); + shape.insert(axis as usize, num_classes); - // Create masks for valid index range [-depth, depth-1] - let above_minimum = indices.clone().greater_elem(-(depth as i64)).int(); - let below_maximum = indices.clone().lower_elem(depth as i64).int(); + // Create masks for valid index range [-num_classes, num_classes-1] + let above_minimum = indices.clone().greater_elem(-(num_classes as i64)).int(); + let below_maximum = indices.clone().lower_elem(num_classes as i64).int(); // Combine conditions to identify invalid indices let invalid_mask = above_minimum.mul(below_maximum).bool().bool_not(); @@ -2100,7 +2101,7 @@ where // Adjust indices to valid range and handle invalid indices let adjusted_indices = indices .clone() - .mask_fill(self.clone().lower_elem(0), depth as i64) // Handle negative indices + .mask_fill(self.clone().lower_elem(0), num_classes as i64) // Handle negative indices .add(indices.clone().mask_fill(self.clone().greater_elem(0), 0)); // Handle positive indices // Replace invalid indices with the off_value diff --git a/crates/burn-tensor/src/tests/ops/one_hot.rs b/crates/burn-tensor/src/tests/ops/one_hot.rs index 27a70cbe8b..1ae64aed08 100644 --- a/crates/burn-tensor/src/tests/ops/one_hot.rs +++ b/crates/burn-tensor/src/tests/ops/one_hot.rs @@ -77,7 +77,7 @@ mod tests { } #[test] - fn one_hot_plus_with_positive_axis_and_indices() { + fn one_hot_fill_with_positive_axis_and_indices() { let tensor = TestTensorInt::<2>::from([[1, 9], [2, 4]]); let expected = TensorData::from(as_type!(IntType:[ [[1, 1], @@ -101,26 +101,26 @@ mod tests { [1, 1], [1, 1]]])); - let one_hot_tensor: Tensor = tensor.one_hot_plus(10, 3, 1, 1); + let one_hot_tensor: Tensor = tensor.one_hot_fill(10, 3, 1, 1); one_hot_tensor.into_data().assert_eq(&expected, true); } #[test] - fn one_hot_plus_with_negative_axis_and_indices() { + fn one_hot_fill_with_negative_axis_and_indices() { let tensor = TestTensorInt::<2>::from([[0, 2], [1, -1]]); let expected = TensorData::from( as_type!(FloatType: [[[5.0, 0.0, 0.0], [0.0, 0.0, 5.0]], [[0.0, 5.0, 0.0], [0.0, 0.0, 5.0]]]), ); let one_hot_tensor: Tensor = - tensor.one_hot_plus(3, FloatType::new(5.0), FloatType::new(0.0), -1); + tensor.one_hot_fill(3, FloatType::new(5.0), FloatType::new(0.0), -1); one_hot_tensor.into_data().assert_eq(&expected, true); } #[test] - fn one_hot_plus_with_negative_indices() { + fn one_hot_fill_with_negative_indices() { let tensor = TestTensor::<1>::from([0.0, -7.0, -8.0]); let expected = TensorData::from(as_type!(FloatType:[ [3.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], @@ -128,17 +128,17 @@ mod tests { [1.0, 1.0, 3.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], ])); - let one_hot_tensor: Tensor = tensor.one_hot_plus(10, 3.0, 1.0, 1); + let one_hot_tensor: Tensor = tensor.one_hot_fill(10, 3.0, 1.0, 1); one_hot_tensor.into_data().assert_eq(&expected, true); } #[should_panic] #[test] - fn one_hot_plus_should_panic_when_axis_out_range_of_rank() { + fn one_hot_fill_should_panic_when_axis_out_range_of_rank() { let tensor = TestTensor::<2>::from([[0.0, 2.0], [1.0, -1.0]]); let one_hot_tensor: Tensor = - tensor.one_hot_plus(3, FloatType::new(5.0), FloatType::new(0.0), 3); + tensor.one_hot_fill(3, FloatType::new(5.0), FloatType::new(0.0), 3); } } From ae5164755cccde48d11396cdd9bc4c4aa7d38bc9 Mon Sep 17 00:00:00 2001 From: tiruka Date: Tue, 24 Dec 2024 21:13:18 +0900 Subject: [PATCH 13/20] modify one_hot function in int impl and float impl modify one_hot tests --- crates/burn-tensor/src/tensor/api/check.rs | 20 +--- crates/burn-tensor/src/tensor/api/float.rs | 36 +++--- crates/burn-tensor/src/tensor/api/int.rs | 57 +++++----- crates/burn-tensor/src/tensor/api/numeric.rs | 12 +- crates/burn-tensor/src/tests/ops/one_hot.rs | 114 ++++++++----------- 5 files changed, 105 insertions(+), 134 deletions(-) diff --git a/crates/burn-tensor/src/tensor/api/check.rs b/crates/burn-tensor/src/tensor/api/check.rs index d4ab13faf4..8834e09f27 100644 --- a/crates/burn-tensor/src/tensor/api/check.rs +++ b/crates/burn-tensor/src/tensor/api/check.rs @@ -1,4 +1,4 @@ -use crate::{backend::Backend, BasicOps, Int, Shape, Tensor}; +use crate::{backend::Backend, BasicOps, Numeric, Shape, Tensor}; use alloc::format; use alloc::string::{String, ToString}; use alloc::vec; @@ -447,22 +447,8 @@ impl TensorCheck { check } - pub(crate) fn one_hot_index(index: usize, num_classes: usize) -> Self { - let mut check = Self::Ok; - if index >= num_classes { - check = check.register( - "One Hot", - TensorError::new(format!( - "Can't create a one hot tensor with index ({index}) greater or equal to the number of classes ({num_classes})", - )), - ); - } - - check - } - - pub(crate) fn one_hot_tensor( - index_tensor: Tensor, + pub(crate) fn one_hot_tensor>( + index_tensor: Tensor, num_classes: usize, ) -> Self { let mut check = Self::Ok; diff --git a/crates/burn-tensor/src/tensor/api/float.rs b/crates/burn-tensor/src/tensor/api/float.rs index a6f59f6e88..a40bbef827 100644 --- a/crates/burn-tensor/src/tensor/api/float.rs +++ b/crates/burn-tensor/src/tensor/api/float.rs @@ -1,11 +1,9 @@ -use alloc::vec::Vec; -use core::convert::TryInto; - use crate::check::TensorCheck; +use crate::ops::FloatElem; use crate::quantization::{QuantizationParameters, QuantizationScheme}; use crate::tensor::backend::Backend; use crate::tensor::stats; -use crate::tensor::{Distribution, Shape, TensorData}; +use crate::tensor::{Distribution, TensorData}; use crate::Tensor; use crate::{check, FloatDType}; use crate::{Int, TensorPrimitive}; @@ -182,25 +180,25 @@ where /// use burn_tensor::backend::Backend; /// use burn_tensor::Tensor; /// - /// fn example() { + /// fn example() where ::FloatElem: From{ /// let device = Default::default(); - /// let one_hot = Tensor::::one_hot(2, 10, &device); + /// let indices: Tensor = Tensor::from_ints([0.0, 1.0, 2.0, 3.0], &device); + /// let one_hot: Tensor = indices.one_hot(4); /// println!("{}", one_hot.to_data()); - /// // [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] + /// // [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]] /// } /// ``` - pub fn one_hot(index: usize, num_classes: usize, device: &B::Device) -> Self { - check!(TensorCheck::one_hot_index(index, num_classes)); - - let mut dims = [1; D]; - dims[D - 1] = num_classes; - let shape = Shape::new(dims); - let ranges: Vec<_> = shape.dims.iter().map(|dim| 0..*dim).collect(); - let tensor = Tensor::zeros(shape, device); - let mut ranges: [core::ops::Range; D] = ranges.try_into().unwrap(); - ranges[D - 1] = index..index + 1; - - tensor.slice_assign(ranges, Tensor::ones(Shape::new([1; D]), device)) + pub fn one_hot(self, num_classes: usize) -> Tensor + where + FloatElem: From, + { + check!(TensorCheck::one_hot_tensor(self.clone(), num_classes)); + self.one_hot_fill( + num_classes, + B::FloatElem::from(1.0), + B::FloatElem::from(0.0), + -1, + ) } /// Applies the matrix multiplication operation. diff --git a/crates/burn-tensor/src/tensor/api/int.rs b/crates/burn-tensor/src/tensor/api/int.rs index 08bdab0fe7..67aaad8bd5 100644 --- a/crates/burn-tensor/src/tensor/api/int.rs +++ b/crates/burn-tensor/src/tensor/api/int.rs @@ -1,5 +1,6 @@ use crate::check; use crate::check::TensorCheck; +use crate::ops::IntElem; use crate::{ backend::Backend, cartesian_grid, Float, Int, Shape, Tensor, TensorData, TensorPrimitive, }; @@ -29,34 +30,6 @@ where pub fn arange_step(range: Range, step: usize, device: &B::Device) -> Self { Tensor::new(B::int_arange_step(range, step, device)) } - - /// Create a one hot tensor from an index tensor. - /// - /// # Arguments - /// - /// * `num_classes` - The number of classes to use in encoding. - /// - /// # Example - /// - /// ```rust - /// use burn_tensor::backend::Backend; - /// use burn_tensor::{Tensor, Int}; - /// - /// fn example() { - /// let device = B::Device::default(); - /// let indices: Tensor = Tensor::from_ints([0, 1, 2, 3], &device); - /// let one_hot = indices.one_hot(4); - /// println!("{}", one_hot.to_data()); - /// // [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]] - /// } - /// ``` - pub fn one_hot(self, num_classes: usize) -> Tensor { - check!(TensorCheck::one_hot_tensor(self.clone(), num_classes)); - let [num_samples] = self.dims(); - let indices = self.unsqueeze_dim(1); - let values = indices.ones_like(); - Tensor::zeros([num_samples, num_classes], &indices.device()).scatter(1, indices, values) - } } impl Tensor @@ -129,4 +102,32 @@ where ) -> Tensor { cartesian_grid::(shape, device) } + + /// Create a one hot tensor from an index tensor. + /// + /// # Arguments + /// + /// * `num_classes` - The number of classes to use in encoding. + /// + /// # Example + /// + /// ```rust + /// use burn_tensor::backend::Backend; + /// use burn_tensor::{Tensor, Int}; + /// + /// fn example() where ::IntElem: From{ + /// let device = B::Device::default(); + /// let indices: Tensor = Tensor::from_ints([0, 1, 2, 3], &device); + /// let one_hot: Tensor = indices.one_hot(4); + /// println!("{}", one_hot.to_data()); + /// // [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]] + /// } + /// ``` + pub fn one_hot(self, num_classes: usize) -> Tensor + where + IntElem: From, + { + check!(TensorCheck::one_hot_tensor(self.clone(), num_classes)); + self.one_hot_fill(num_classes, B::IntElem::from(1), B::IntElem::from(0), -1) + } } diff --git a/crates/burn-tensor/src/tensor/api/numeric.rs b/crates/burn-tensor/src/tensor/api/numeric.rs index 471b74c0f7..6cdececf3c 100644 --- a/crates/burn-tensor/src/tensor/api/numeric.rs +++ b/crates/burn-tensor/src/tensor/api/numeric.rs @@ -2092,11 +2092,17 @@ where shape.insert(axis as usize, num_classes); // Create masks for valid index range [-num_classes, num_classes-1] - let above_minimum = indices.clone().greater_elem(-(num_classes as i64)).int(); - let below_maximum = indices.clone().lower_elem(num_classes as i64).int(); + let lower_bound = indices + .clone() + .lower_equal_elem(-(num_classes as i64)) + .int(); + let upper_bound = indices + .clone() + .greater_equal_elem(num_classes as i64 - 1) + .int(); // Combine conditions to identify invalid indices - let invalid_mask = above_minimum.mul(below_maximum).bool().bool_not(); + let invalid_mask = lower_bound.mul(upper_bound).bool(); // Adjust indices to valid range and handle invalid indices let adjusted_indices = indices diff --git a/crates/burn-tensor/src/tests/ops/one_hot.rs b/crates/burn-tensor/src/tests/ops/one_hot.rs index 1ae64aed08..d1333bb1a5 100644 --- a/crates/burn-tensor/src/tests/ops/one_hot.rs +++ b/crates/burn-tensor/src/tests/ops/one_hot.rs @@ -7,99 +7,62 @@ mod tests { tests::{Float as _, Int as _}, Float, Int, Numeric, Shape, Tensor, TensorData, }; + #[test] fn float_should_support_one_hot() { - let device = Default::default(); - - let tensor = TestTensor::<1>::one_hot(0, 5, &device); - let expected = TensorData::from([1., 0., 0., 0., 0.]); - tensor.into_data().assert_eq(&expected, false); - - let tensor = TestTensor::<1>::one_hot(1, 5, &device); - let expected = TensorData::from([0., 1., 0., 0., 0.]); - tensor.into_data().assert_eq(&expected, false); - - let tensor = TestTensor::<1>::one_hot(4, 5, &device); - let expected = TensorData::from([0., 0., 0., 0., 1.]); - tensor.into_data().assert_eq(&expected, false); - - let tensor = TestTensor::<1>::one_hot(1, 2, &device); - let expected = TensorData::from([0., 1.]); - tensor.into_data().assert_eq(&expected, false); + let tensor = TestTensor::<1>::from([0.0, 1.0, 4.0]); + let one_hot_tensor: Tensor = tensor.one_hot(5); + let expected = TensorData::from([ + [1.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0], + ]); + one_hot_tensor.into_data().assert_eq(&expected, false); } #[test] #[should_panic] fn float_one_hot_should_panic_when_index_exceeds_number_of_classes() { - let device = Default::default(); - let tensor = TestTensor::<1>::one_hot(1, 1, &device); + let tensor = TestTensor::<1>::from([5.0]); + let result: Tensor = tensor.one_hot(5); } #[test] #[should_panic] fn float_one_hot_should_panic_when_number_of_classes_is_zero() { - let device = Default::default(); - let tensor = TestTensor::<1>::one_hot(0, 0, &device); + let tensor = TestTensor::<1>::from([0.0]); + let result: Tensor = tensor.one_hot(0); } #[test] fn int_should_support_one_hot() { - let device = Default::default(); - - let index_tensor = TestTensorInt::<1>::arange(0..5, &device); - let one_hot_tensor = index_tensor.one_hot(5); - let expected = TestTensorInt::eye(5, &device).into_data(); + let tensor = TestTensorInt::<1>::from([0, 1, 4]); + let one_hot_tensor: Tensor = tensor.one_hot(5); + let expected = TensorData::from([[1, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 0, 0, 1]]); one_hot_tensor.into_data().assert_eq(&expected, false); } #[test] #[should_panic] fn int_one_hot_should_panic_when_index_exceeds_number_of_classes() { - let device = Default::default(); - let index_tensor = TestTensorInt::<1>::arange(0..6, &device); - let one_hot_tensor = index_tensor.one_hot(5); + let tensor = TestTensorInt::<1>::from([5]); + let result: Tensor = tensor.one_hot(5); } #[test] #[should_panic] fn int_one_hot_should_panic_when_number_of_classes_is_zero() { - let device = Default::default(); - let index_tensor = TestTensorInt::<1>::arange(0..3, &device); - let one_hot_tensor = index_tensor.one_hot(0); - } - - #[test] - #[should_panic] - fn int_one_hot_should_panic_when_number_of_classes_is_1() { - let device = Default::default(); - let index_tensor = TestTensorInt::<1>::arange(0..3, &device); - let one_hot_tensor = index_tensor.one_hot(1); + let tensor = TestTensorInt::<1>::from([2]); + let result: Tensor = tensor.one_hot(0); } #[test] fn one_hot_fill_with_positive_axis_and_indices() { let tensor = TestTensorInt::<2>::from([[1, 9], [2, 4]]); - let expected = TensorData::from(as_type!(IntType:[ - [[1, 1], - [3, 1], - [1, 1], - [1, 1], - [1, 1], - [1, 1], - [1, 1], - [1, 1], - [1, 1], - [1, 3]], - [[1, 1], - [1, 1], - [3, 1], - [1, 1], - [1, 3], - [1, 1], - [1, 1], - [1, 1], - [1, 1], - [1, 1]]])); + let expected = TensorData::from(as_type!(IntType: [ + [[1, 1], [3, 1], [1, 1], [1, 1], [1, 1], [1, 1], [1, 1], [1, 1], [1, 1], [1, 3]], + [[1, 1], [1, 1], [3, 1], [1, 1], [1, 3], [1, 1], [1, 1], [1, 1], [1, 1], [1, 1]] + ])); let one_hot_tensor: Tensor = tensor.one_hot_fill(10, 3, 1, 1); @@ -109,9 +72,10 @@ mod tests { #[test] fn one_hot_fill_with_negative_axis_and_indices() { let tensor = TestTensorInt::<2>::from([[0, 2], [1, -1]]); - let expected = TensorData::from( - as_type!(FloatType: [[[5.0, 0.0, 0.0], [0.0, 0.0, 5.0]], [[0.0, 5.0, 0.0], [0.0, 0.0, 5.0]]]), - ); + let expected = TensorData::from(as_type!(FloatType: [ + [[5.0, 0.0, 0.0], [0.0, 0.0, 5.0]], + [[0.0, 5.0, 0.0], [0.0, 0.0, 5.0]] + ])); let one_hot_tensor: Tensor = tensor.one_hot_fill(3, FloatType::new(5.0), FloatType::new(0.0), -1); @@ -122,10 +86,10 @@ mod tests { #[test] fn one_hot_fill_with_negative_indices() { let tensor = TestTensor::<1>::from([0.0, -7.0, -8.0]); - let expected = TensorData::from(as_type!(FloatType:[ + let expected = TensorData::from(as_type!(FloatType: [ [3.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 3.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], - [1.0, 1.0, 3.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + [1.0, 1.0, 3.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] ])); let one_hot_tensor: Tensor = tensor.one_hot_fill(10, 3.0, 1.0, 1); @@ -139,6 +103,22 @@ mod tests { let tensor = TestTensor::<2>::from([[0.0, 2.0], [1.0, -1.0]]); let one_hot_tensor: Tensor = - tensor.one_hot_fill(3, FloatType::new(5.0), FloatType::new(0.0), 3); + tensor.one_hot_fill(2, FloatType::new(5.0), FloatType::new(0.0), 3); + } + + #[test] + fn one_hot_fill_should_panic_when_index_exceeds_number_of_classes() { + let tensor = TestTensor::<2>::from([[0.0, 2.0], [1.0, 1.0]]); + + let one_hot_tensor: Tensor = + tensor.one_hot_fill(1, FloatType::new(5.0), FloatType::new(0.0), -1); + } + + #[test] + fn one_hot_fill_should_panic_when_number_of_classes_is_zero() { + let tensor = TestTensor::<2>::from([[1.0, 2.0], [1.0, 1.0]]); + + let one_hot_tensor: Tensor = + tensor.one_hot_fill(0, FloatType::new(5.0), FloatType::new(0.0), -1); } } From 1f5007842ce9d5ebf3fbc47dadfb27d96396d905 Mon Sep 17 00:00:00 2001 From: tiruka Date: Tue, 24 Dec 2024 22:47:41 +0900 Subject: [PATCH 14/20] modify numeric to clear logic --- crates/burn-tensor/src/tensor/api/numeric.rs | 26 ++++---------------- crates/burn-tensor/src/tests/ops/one_hot.rs | 16 ------------ 2 files changed, 5 insertions(+), 37 deletions(-) diff --git a/crates/burn-tensor/src/tensor/api/numeric.rs b/crates/burn-tensor/src/tensor/api/numeric.rs index 6cdececf3c..6abbfc7385 100644 --- a/crates/burn-tensor/src/tensor/api/numeric.rs +++ b/crates/burn-tensor/src/tensor/api/numeric.rs @@ -2083,38 +2083,22 @@ where if axis < 0 || axis > rank as i64 { panic!("Axis out of range. Accepted range is [-r-1, r] where r = rank(indices)."); } - // Convert the input tensor to integer indices let indices: Tensor = Tensor::from_data(self.to_data().convert::(), &device); - // Insert the new dimension for the one-hot representation shape.insert(axis as usize, num_classes); - - // Create masks for valid index range [-num_classes, num_classes-1] - let lower_bound = indices - .clone() - .lower_equal_elem(-(num_classes as i64)) - .int(); - let upper_bound = indices - .clone() - .greater_equal_elem(num_classes as i64 - 1) - .int(); - - // Combine conditions to identify invalid indices - let invalid_mask = lower_bound.mul(upper_bound).bool(); - // Adjust indices to valid range and handle invalid indices let adjusted_indices = indices .clone() .mask_fill(self.clone().lower_elem(0), num_classes as i64) // Handle negative indices .add(indices.clone().mask_fill(self.clone().greater_elem(0), 0)); // Handle positive indices - - // Replace invalid indices with the off_value - let valid_indices = adjusted_indices.mask_fill(invalid_mask, off_value); - + check!(TensorCheck::one_hot_tensor( + adjusted_indices.clone(), + num_classes + )); // Unsqueeze the indices tensor along the specified axis - let indices_unsqueezed: Tensor = valid_indices.unsqueeze_dim(axis as usize); + let indices_unsqueezed: Tensor = adjusted_indices.unsqueeze_dim(axis as usize); // Initialize the output tensor with the off_value let output = Tensor::full(shape.clone(), off_value, &device); diff --git a/crates/burn-tensor/src/tests/ops/one_hot.rs b/crates/burn-tensor/src/tests/ops/one_hot.rs index d1333bb1a5..0ce6973b33 100644 --- a/crates/burn-tensor/src/tests/ops/one_hot.rs +++ b/crates/burn-tensor/src/tests/ops/one_hot.rs @@ -105,20 +105,4 @@ mod tests { let one_hot_tensor: Tensor = tensor.one_hot_fill(2, FloatType::new(5.0), FloatType::new(0.0), 3); } - - #[test] - fn one_hot_fill_should_panic_when_index_exceeds_number_of_classes() { - let tensor = TestTensor::<2>::from([[0.0, 2.0], [1.0, 1.0]]); - - let one_hot_tensor: Tensor = - tensor.one_hot_fill(1, FloatType::new(5.0), FloatType::new(0.0), -1); - } - - #[test] - fn one_hot_fill_should_panic_when_number_of_classes_is_zero() { - let tensor = TestTensor::<2>::from([[1.0, 2.0], [1.0, 1.0]]); - - let one_hot_tensor: Tensor = - tensor.one_hot_fill(0, FloatType::new(5.0), FloatType::new(0.0), -1); - } } From 1909afd230e015df2b6f59eeecf355a51135a88a Mon Sep 17 00:00:00 2001 From: tiruka Date: Wed, 25 Dec 2024 09:44:21 +0900 Subject: [PATCH 15/20] modify miscs due to validation, linnter and formatter --- crates/burn-tensor/src/tensor/api/float.rs | 17 ++++------------- crates/burn-tensor/src/tensor/api/int.rs | 10 +++------- crates/burn-tensor/src/tensor/api/numeric.rs | 4 ++-- crates/burn-tensor/src/tests/ops/one_hot.rs | 8 +++----- 4 files changed, 12 insertions(+), 27 deletions(-) diff --git a/crates/burn-tensor/src/tensor/api/float.rs b/crates/burn-tensor/src/tensor/api/float.rs index a40bbef827..cacff1e8c0 100644 --- a/crates/burn-tensor/src/tensor/api/float.rs +++ b/crates/burn-tensor/src/tensor/api/float.rs @@ -1,5 +1,4 @@ use crate::check::TensorCheck; -use crate::ops::FloatElem; use crate::quantization::{QuantizationParameters, QuantizationScheme}; use crate::tensor::backend::Backend; use crate::tensor::stats; @@ -180,25 +179,17 @@ where /// use burn_tensor::backend::Backend; /// use burn_tensor::Tensor; /// - /// fn example() where ::FloatElem: From{ + /// fn example(){ /// let device = Default::default(); - /// let indices: Tensor = Tensor::from_ints([0.0, 1.0, 2.0, 3.0], &device); + /// let indices: Tensor = Tensor::from_floats([0.0, 1.0, 2.0, 3.0], &device); /// let one_hot: Tensor = indices.one_hot(4); /// println!("{}", one_hot.to_data()); /// // [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]] /// } /// ``` - pub fn one_hot(self, num_classes: usize) -> Tensor - where - FloatElem: From, - { + pub fn one_hot(self, num_classes: usize) -> Tensor { check!(TensorCheck::one_hot_tensor(self.clone(), num_classes)); - self.one_hot_fill( - num_classes, - B::FloatElem::from(1.0), - B::FloatElem::from(0.0), - -1, - ) + self.one_hot_fill(num_classes, 1.0, 0.0, -1) } /// Applies the matrix multiplication operation. diff --git a/crates/burn-tensor/src/tensor/api/int.rs b/crates/burn-tensor/src/tensor/api/int.rs index 67aaad8bd5..4bd6179f13 100644 --- a/crates/burn-tensor/src/tensor/api/int.rs +++ b/crates/burn-tensor/src/tensor/api/int.rs @@ -1,6 +1,5 @@ use crate::check; use crate::check::TensorCheck; -use crate::ops::IntElem; use crate::{ backend::Backend, cartesian_grid, Float, Int, Shape, Tensor, TensorData, TensorPrimitive, }; @@ -115,7 +114,7 @@ where /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Int}; /// - /// fn example() where ::IntElem: From{ + /// fn example(){ /// let device = B::Device::default(); /// let indices: Tensor = Tensor::from_ints([0, 1, 2, 3], &device); /// let one_hot: Tensor = indices.one_hot(4); @@ -123,11 +122,8 @@ where /// // [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]] /// } /// ``` - pub fn one_hot(self, num_classes: usize) -> Tensor - where - IntElem: From, - { + pub fn one_hot(self, num_classes: usize) -> Tensor { check!(TensorCheck::one_hot_tensor(self.clone(), num_classes)); - self.one_hot_fill(num_classes, B::IntElem::from(1), B::IntElem::from(0), -1) + self.one_hot_fill(num_classes, 1.0, 0.0, -1) } } diff --git a/crates/burn-tensor/src/tensor/api/numeric.rs b/crates/burn-tensor/src/tensor/api/numeric.rs index 6abbfc7385..a687ad431e 100644 --- a/crates/burn-tensor/src/tensor/api/numeric.rs +++ b/crates/burn-tensor/src/tensor/api/numeric.rs @@ -2063,8 +2063,8 @@ where pub fn one_hot_fill, const D2: usize>( self, num_classes: usize, - on_value: K2::Elem, - off_value: K2::Elem, + on_value: f32, + off_value: f32, axis: i64, ) -> Tensor { // Initialize shape from the current tensor dimensions and prepare for modification diff --git a/crates/burn-tensor/src/tests/ops/one_hot.rs b/crates/burn-tensor/src/tests/ops/one_hot.rs index 0ce6973b33..37a8b49e4d 100644 --- a/crates/burn-tensor/src/tests/ops/one_hot.rs +++ b/crates/burn-tensor/src/tests/ops/one_hot.rs @@ -64,7 +64,7 @@ mod tests { [[1, 1], [1, 1], [3, 1], [1, 1], [1, 3], [1, 1], [1, 1], [1, 1], [1, 1], [1, 1]] ])); - let one_hot_tensor: Tensor = tensor.one_hot_fill(10, 3, 1, 1); + let one_hot_tensor: Tensor = tensor.one_hot_fill(10, 3.0, 1.0, 1); one_hot_tensor.into_data().assert_eq(&expected, true); } @@ -77,8 +77,7 @@ mod tests { [[0.0, 5.0, 0.0], [0.0, 0.0, 5.0]] ])); - let one_hot_tensor: Tensor = - tensor.one_hot_fill(3, FloatType::new(5.0), FloatType::new(0.0), -1); + let one_hot_tensor: Tensor = tensor.one_hot_fill(3, 5.0, 0.0, -1); one_hot_tensor.into_data().assert_eq(&expected, true); } @@ -102,7 +101,6 @@ mod tests { fn one_hot_fill_should_panic_when_axis_out_range_of_rank() { let tensor = TestTensor::<2>::from([[0.0, 2.0], [1.0, -1.0]]); - let one_hot_tensor: Tensor = - tensor.one_hot_fill(2, FloatType::new(5.0), FloatType::new(0.0), 3); + let one_hot_tensor: Tensor = tensor.one_hot_fill(2, 5.0, 0.0, 3); } } From ac060d13fbac51f40f7476c72bb19d0eb10195da Mon Sep 17 00:00:00 2001 From: tiruka Date: Thu, 26 Dec 2024 11:39:08 +0900 Subject: [PATCH 16/20] modify documents for tensor api --- burn-book/src/building-blocks/tensor.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/burn-book/src/building-blocks/tensor.md b/burn-book/src/building-blocks/tensor.md index 6ef80a81d0..7108bd9d82 100644 --- a/burn-book/src/building-blocks/tensor.md +++ b/burn-book/src/building-blocks/tensor.md @@ -157,7 +157,7 @@ Those operations are available for all tensor kinds: `Int`, `Float`, and `Bool`. | `tensor.narrow(dim, start, length)` | `tensor.narrow(dim, start, length)` | | `tensor.not_equal(other)` | `x != y` | | `tensor.permute(axes)` | `tensor.permute(axes)` | -| `tensor.one_hot_plus(depth, on_value, off_value, axis)` | N/A | +| `tensor.one_hot_fill(depth, on_value, off_value, axis)` | N/A | | `tensor.movedim(src, dst)` | `tensor.movedim(src, dst)` | | `tensor.repeat_dim(dim, times)` | `tensor.repeat(*[times if i == dim else 1 for i in range(tensor.dim())])` | | `tensor.repeat(sizes)` | `tensor.repeat(sizes)` | @@ -259,7 +259,7 @@ Those operations are only available for `Float` tensors. | Burn API | PyTorch Equivalent | | --------------------------------------------- | ---------------------------------- | -| `Tensor::one_hot(index, num_classes, device)` | N/A | +| `tensor.one_hot(num_classes)` | `torch.nn.functional.one_hot` | | `tensor.cast(dtype)` | `tensor.to(dtype)` | | `tensor.ceil()` | `tensor.ceil()` | | `tensor.cos()` | `tensor.cos()` | @@ -297,7 +297,7 @@ Those operations are only available for `Int` tensors. | `tensor.from_ints(ints)` | N/A | | `tensor.int_random(shape, distribution, device)` | N/A | | `tensor.cartesian_grid(shape, device)` | N/A | -| `tensor.one_hot(num_classes)` | N/A | +| `tensor.one_hot(num_classes)` | `torch.nn.functional.one_hot` | ### Bool Operations From efb19a086dda248777f008788747108f64088673 Mon Sep 17 00:00:00 2001 From: tiruka Date: Wed, 8 Jan 2025 20:47:19 +0900 Subject: [PATCH 17/20] modify codes to follow review comments --- burn-book/src/building-blocks/tensor.md | 5 +-- crates/burn-tensor/src/tensor/api/check.rs | 14 +++++++ crates/burn-tensor/src/tensor/api/float.rs | 43 ++++++++++++++++---- crates/burn-tensor/src/tensor/api/int.rs | 27 ------------ crates/burn-tensor/src/tensor/api/numeric.rs | 20 +++++++++ crates/burn-tensor/src/tests/ops/one_hot.rs | 15 +++++-- 6 files changed, 82 insertions(+), 42 deletions(-) diff --git a/burn-book/src/building-blocks/tensor.md b/burn-book/src/building-blocks/tensor.md index 7108bd9d82..8a7c01bbc9 100644 --- a/burn-book/src/building-blocks/tensor.md +++ b/burn-book/src/building-blocks/tensor.md @@ -157,7 +157,6 @@ Those operations are available for all tensor kinds: `Int`, `Float`, and `Bool`. | `tensor.narrow(dim, start, length)` | `tensor.narrow(dim, start, length)` | | `tensor.not_equal(other)` | `x != y` | | `tensor.permute(axes)` | `tensor.permute(axes)` | -| `tensor.one_hot_fill(depth, on_value, off_value, axis)` | N/A | | `tensor.movedim(src, dst)` | `tensor.movedim(src, dst)` | | `tensor.repeat_dim(dim, times)` | `tensor.repeat(*[times if i == dim else 1 for i in range(tensor.dim())])` | | `tensor.repeat(sizes)` | `tensor.repeat(sizes)` | @@ -229,6 +228,8 @@ Those operations are available for numeric tensor kinds: `Float` and `Int`. | `tensor.neg()` or `-tensor` | `-tensor` | | `tensor.not_equal_elem(scalar)` | `tensor.ne(scalar)` | | `tensor.ones_like()` | `torch.ones_like(tensor)` | +| `tensor.one_hot(num_classes)` | `torch.nn.functional.one_hot` | +| `tensor.one_hot_fill(num_classes, on_value, off_value, axis)` | N/A | | `tensor.pad(pads, value)` | `torch.nn.functional.pad(input, pad, value)` | | `tensor.powf(other)` or `tensor.powi(intother)` | `tensor.pow(other)` | | `tensor.powf_scalar(scalar)` or `tensor.powi_scalar(intscalar)` | `tensor.pow(scalar)` | @@ -259,7 +260,6 @@ Those operations are only available for `Float` tensors. | Burn API | PyTorch Equivalent | | --------------------------------------------- | ---------------------------------- | -| `tensor.one_hot(num_classes)` | `torch.nn.functional.one_hot` | | `tensor.cast(dtype)` | `tensor.to(dtype)` | | `tensor.ceil()` | `tensor.ceil()` | | `tensor.cos()` | `tensor.cos()` | @@ -297,7 +297,6 @@ Those operations are only available for `Int` tensors. | `tensor.from_ints(ints)` | N/A | | `tensor.int_random(shape, distribution, device)` | N/A | | `tensor.cartesian_grid(shape, device)` | N/A | -| `tensor.one_hot(num_classes)` | `torch.nn.functional.one_hot` | ### Bool Operations diff --git a/crates/burn-tensor/src/tensor/api/check.rs b/crates/burn-tensor/src/tensor/api/check.rs index 8834e09f27..c56cc9f12d 100644 --- a/crates/burn-tensor/src/tensor/api/check.rs +++ b/crates/burn-tensor/src/tensor/api/check.rs @@ -447,6 +447,20 @@ impl TensorCheck { check } + pub(crate) fn one_hot_index(index: usize, num_classes: usize) -> Self { + let mut check = Self::Ok; + if index >= num_classes { + check = check.register( + "One Hot", + TensorError::new(format!( + "Can't create a one hot tensor with index ({index}) greater or equal to the number of classes ({num_classes})", + )), + ); + } + + check + } + pub(crate) fn one_hot_tensor>( index_tensor: Tensor, num_classes: usize, diff --git a/crates/burn-tensor/src/tensor/api/float.rs b/crates/burn-tensor/src/tensor/api/float.rs index cacff1e8c0..c5f337b99e 100644 --- a/crates/burn-tensor/src/tensor/api/float.rs +++ b/crates/burn-tensor/src/tensor/api/float.rs @@ -2,10 +2,12 @@ use crate::check::TensorCheck; use crate::quantization::{QuantizationParameters, QuantizationScheme}; use crate::tensor::backend::Backend; use crate::tensor::stats; -use crate::tensor::{Distribution, TensorData}; +use crate::tensor::{Distribution, Shape, TensorData}; use crate::Tensor; use crate::{check, FloatDType}; use crate::{Int, TensorPrimitive}; +use alloc::vec::Vec; +use core::convert::TryInto; impl Tensor where @@ -170,7 +172,6 @@ where &self.device(), ))) } - /// Create a one hot tensor. /// /// # Example @@ -179,17 +180,41 @@ where /// use burn_tensor::backend::Backend; /// use burn_tensor::Tensor; /// - /// fn example(){ + /// fn example() { /// let device = Default::default(); - /// let indices: Tensor = Tensor::from_floats([0.0, 1.0, 2.0, 3.0], &device); - /// let one_hot: Tensor = indices.one_hot(4); + /// let one_hot = Tensor::::_one_hot(2, 10, &device); /// println!("{}", one_hot.to_data()); - /// // [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]] + /// // [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] /// } /// ``` - pub fn one_hot(self, num_classes: usize) -> Tensor { - check!(TensorCheck::one_hot_tensor(self.clone(), num_classes)); - self.one_hot_fill(num_classes, 1.0, 0.0, -1) + /// This is equivalent to the code below using the new `one_hot` function. + /// ```rust + /// use burn_tensor::backend::Backend; + /// use burn_tensor::Tensor; + /// + /// fn example(){ + /// let device = Default::default(); + /// let indices: Tensor = Tensor::from_floats([2.0], &device); + /// let one_hot: Tensor = indices.one_hot::<2>(10).flatten::<1>(0, 1); + /// println!("{}", one_hot.to_data()); + /// // [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] + /// } + #[deprecated( + since = "0.16.0", + note = "`Tensor::one_hot(...)` will be removed in the future, please use the new `tensor.one_hot(...)` method instead" + )] + pub fn _one_hot(index: usize, num_classes: usize, device: &B::Device) -> Self { + check!(TensorCheck::one_hot_index(index, num_classes)); + + let mut dims = [1; D]; + dims[D - 1] = num_classes; + let shape = Shape::new(dims); + let ranges: Vec<_> = shape.dims.iter().map(|dim| 0..*dim).collect(); + let tensor = Tensor::zeros(shape, device); + let mut ranges: [core::ops::Range; D] = ranges.try_into().unwrap(); + ranges[D - 1] = index..index + 1; + + tensor.slice_assign(ranges, Tensor::ones(Shape::new([1; D]), device)) } /// Applies the matrix multiplication operation. diff --git a/crates/burn-tensor/src/tensor/api/int.rs b/crates/burn-tensor/src/tensor/api/int.rs index 4bd6179f13..e882a107c7 100644 --- a/crates/burn-tensor/src/tensor/api/int.rs +++ b/crates/burn-tensor/src/tensor/api/int.rs @@ -1,5 +1,3 @@ -use crate::check; -use crate::check::TensorCheck; use crate::{ backend::Backend, cartesian_grid, Float, Int, Shape, Tensor, TensorData, TensorPrimitive, }; @@ -101,29 +99,4 @@ where ) -> Tensor { cartesian_grid::(shape, device) } - - /// Create a one hot tensor from an index tensor. - /// - /// # Arguments - /// - /// * `num_classes` - The number of classes to use in encoding. - /// - /// # Example - /// - /// ```rust - /// use burn_tensor::backend::Backend; - /// use burn_tensor::{Tensor, Int}; - /// - /// fn example(){ - /// let device = B::Device::default(); - /// let indices: Tensor = Tensor::from_ints([0, 1, 2, 3], &device); - /// let one_hot: Tensor = indices.one_hot(4); - /// println!("{}", one_hot.to_data()); - /// // [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]] - /// } - /// ``` - pub fn one_hot(self, num_classes: usize) -> Tensor { - check!(TensorCheck::one_hot_tensor(self.clone(), num_classes)); - self.one_hot_fill(num_classes, 1.0, 0.0, -1) - } } diff --git a/crates/burn-tensor/src/tensor/api/numeric.rs b/crates/burn-tensor/src/tensor/api/numeric.rs index a687ad431e..2965c4733d 100644 --- a/crates/burn-tensor/src/tensor/api/numeric.rs +++ b/crates/burn-tensor/src/tensor/api/numeric.rs @@ -2030,6 +2030,26 @@ where // Assign the original tensor data to the appropriate slice of the padded tensor padded_tensor.slice_assign(ranges, self) } + /// Create a one hot tensor. + /// + /// # Example + /// + /// ```rust + /// use burn_tensor::backend::Backend; + /// use burn_tensor::Tensor; + /// + /// fn example(){ + /// let device = Default::default(); + /// let indices: Tensor = Tensor::from_floats([0.0, 1.0, 2.0, 3.0], &device); + /// let one_hot: Tensor = indices.one_hot(4); + /// println!("{}", one_hot.to_data()); + /// // [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]] + /// } + /// ``` + pub fn one_hot(self, num_classes: usize) -> Tensor { + check!(TensorCheck::one_hot_tensor(self.clone(), num_classes)); + self.one_hot_fill(num_classes, 1.0, 0.0, -1) + } /// Create a one-hot encoded tensor with configurable `num_classes`, `on_value`, `off_value`, and `axis` including high-ranked tensors. /// diff --git a/crates/burn-tensor/src/tests/ops/one_hot.rs b/crates/burn-tensor/src/tests/ops/one_hot.rs index 37a8b49e4d..b1f5634eac 100644 --- a/crates/burn-tensor/src/tests/ops/one_hot.rs +++ b/crates/burn-tensor/src/tests/ops/one_hot.rs @@ -20,6 +20,15 @@ mod tests { one_hot_tensor.into_data().assert_eq(&expected, false); } + #[test] + fn float_should_support_one_hot_index() { + let tensor = TestTensor::<1>::from([2.0]); + let one_hot_tensor: Tensor = + tensor.one_hot::<2>(10).flatten::<1>(0, 1); + let expected = TensorData::from([0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]); + one_hot_tensor.into_data().assert_eq(&expected, false); + } + #[test] #[should_panic] fn float_one_hot_should_panic_when_index_exceeds_number_of_classes() { @@ -37,7 +46,7 @@ mod tests { #[test] fn int_should_support_one_hot() { let tensor = TestTensorInt::<1>::from([0, 1, 4]); - let one_hot_tensor: Tensor = tensor.one_hot(5); + let one_hot_tensor: Tensor = tensor.one_hot(5).int(); let expected = TensorData::from([[1, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 0, 0, 1]]); one_hot_tensor.into_data().assert_eq(&expected, false); } @@ -46,14 +55,14 @@ mod tests { #[should_panic] fn int_one_hot_should_panic_when_index_exceeds_number_of_classes() { let tensor = TestTensorInt::<1>::from([5]); - let result: Tensor = tensor.one_hot(5); + let result: Tensor = tensor.one_hot(5).int(); } #[test] #[should_panic] fn int_one_hot_should_panic_when_number_of_classes_is_zero() { let tensor = TestTensorInt::<1>::from([2]); - let result: Tensor = tensor.one_hot(0); + let result: Tensor = tensor.one_hot(0).int(); } #[test] From 6aa9648e35728fd25573e713f4409310b8c83b2b Mon Sep 17 00:00:00 2001 From: tiruka Date: Wed, 15 Jan 2025 21:13:41 +0900 Subject: [PATCH 18/20] modify codes to follow reviews --- crates/burn-tensor/src/tensor/api/check.rs | 25 +++++----- crates/burn-tensor/src/tensor/api/float.rs | 48 +------------------- crates/burn-tensor/src/tensor/api/numeric.rs | 15 +++--- crates/burn-tensor/src/tests/ops/one_hot.rs | 19 ++++---- 4 files changed, 27 insertions(+), 80 deletions(-) diff --git a/crates/burn-tensor/src/tensor/api/check.rs b/crates/burn-tensor/src/tensor/api/check.rs index c56cc9f12d..60cc841b5b 100644 --- a/crates/burn-tensor/src/tensor/api/check.rs +++ b/crates/burn-tensor/src/tensor/api/check.rs @@ -447,20 +447,6 @@ impl TensorCheck { check } - pub(crate) fn one_hot_index(index: usize, num_classes: usize) -> Self { - let mut check = Self::Ok; - if index >= num_classes { - check = check.register( - "One Hot", - TensorError::new(format!( - "Can't create a one hot tensor with index ({index}) greater or equal to the number of classes ({num_classes})", - )), - ); - } - - check - } - pub(crate) fn one_hot_tensor>( index_tensor: Tensor, num_classes: usize, @@ -487,6 +473,17 @@ impl TensorCheck { check } + pub(crate) fn one_hot_tensor_rank() -> Self { + let mut check = Self::Ok; + if D + 1 != D2 { + check = check.register( + "One Hot", + TensorError::new("Tensor of rank one greater than input tensor 'indices', i.e. rank(output) = rank(indices) + 1") + ); + } + check + } + pub(crate) fn swap_dims(dim1: usize, dim2: usize) -> Self { let mut check = Self::Ok; diff --git a/crates/burn-tensor/src/tensor/api/float.rs b/crates/burn-tensor/src/tensor/api/float.rs index c5f337b99e..b50d0d0596 100644 --- a/crates/burn-tensor/src/tensor/api/float.rs +++ b/crates/burn-tensor/src/tensor/api/float.rs @@ -2,12 +2,10 @@ use crate::check::TensorCheck; use crate::quantization::{QuantizationParameters, QuantizationScheme}; use crate::tensor::backend::Backend; use crate::tensor::stats; -use crate::tensor::{Distribution, Shape, TensorData}; +use crate::tensor::{Distribution, TensorData}; use crate::Tensor; use crate::{check, FloatDType}; use crate::{Int, TensorPrimitive}; -use alloc::vec::Vec; -use core::convert::TryInto; impl Tensor where @@ -172,50 +170,6 @@ where &self.device(), ))) } - /// Create a one hot tensor. - /// - /// # Example - /// - /// ```rust - /// use burn_tensor::backend::Backend; - /// use burn_tensor::Tensor; - /// - /// fn example() { - /// let device = Default::default(); - /// let one_hot = Tensor::::_one_hot(2, 10, &device); - /// println!("{}", one_hot.to_data()); - /// // [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] - /// } - /// ``` - /// This is equivalent to the code below using the new `one_hot` function. - /// ```rust - /// use burn_tensor::backend::Backend; - /// use burn_tensor::Tensor; - /// - /// fn example(){ - /// let device = Default::default(); - /// let indices: Tensor = Tensor::from_floats([2.0], &device); - /// let one_hot: Tensor = indices.one_hot::<2>(10).flatten::<1>(0, 1); - /// println!("{}", one_hot.to_data()); - /// // [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] - /// } - #[deprecated( - since = "0.16.0", - note = "`Tensor::one_hot(...)` will be removed in the future, please use the new `tensor.one_hot(...)` method instead" - )] - pub fn _one_hot(index: usize, num_classes: usize, device: &B::Device) -> Self { - check!(TensorCheck::one_hot_index(index, num_classes)); - - let mut dims = [1; D]; - dims[D - 1] = num_classes; - let shape = Shape::new(dims); - let ranges: Vec<_> = shape.dims.iter().map(|dim| 0..*dim).collect(); - let tensor = Tensor::zeros(shape, device); - let mut ranges: [core::ops::Range; D] = ranges.try_into().unwrap(); - ranges[D - 1] = index..index + 1; - - tensor.slice_assign(ranges, Tensor::ones(Shape::new([1; D]), device)) - } /// Applies the matrix multiplication operation. /// diff --git a/crates/burn-tensor/src/tensor/api/numeric.rs b/crates/burn-tensor/src/tensor/api/numeric.rs index 2965c4733d..29b2335e39 100644 --- a/crates/burn-tensor/src/tensor/api/numeric.rs +++ b/crates/burn-tensor/src/tensor/api/numeric.rs @@ -2041,12 +2041,12 @@ where /// fn example(){ /// let device = Default::default(); /// let indices: Tensor = Tensor::from_floats([0.0, 1.0, 2.0, 3.0], &device); - /// let one_hot: Tensor = indices.one_hot(4); + /// let one_hot: Tensor = indices.one_hot(4); /// println!("{}", one_hot.to_data()); /// // [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]] /// } /// ``` - pub fn one_hot(self, num_classes: usize) -> Tensor { + pub fn one_hot(self, num_classes: usize) -> Tensor { check!(TensorCheck::one_hot_tensor(self.clone(), num_classes)); self.one_hot_fill(num_classes, 1.0, 0.0, -1) } @@ -2080,13 +2080,14 @@ where /// // [0.0, 0.0, 5.0]]] /// } /// ``` - pub fn one_hot_fill, const D2: usize>( + pub fn one_hot_fill( self, num_classes: usize, on_value: f32, off_value: f32, axis: i64, - ) -> Tensor { + ) -> Tensor { + check!(TensorCheck::one_hot_tensor_rank::()); // Initialize shape from the current tensor dimensions and prepare for modification let mut shape = self.shape().dims::().to_vec(); let device = self.device(); @@ -2113,11 +2114,7 @@ where .clone() .mask_fill(self.clone().lower_elem(0), num_classes as i64) // Handle negative indices .add(indices.clone().mask_fill(self.clone().greater_elem(0), 0)); // Handle positive indices - check!(TensorCheck::one_hot_tensor( - adjusted_indices.clone(), - num_classes - )); - // Unsqueeze the indices tensor along the specified axis + // Unsqueeze the indices tensor along the specified axis let indices_unsqueezed: Tensor = adjusted_indices.unsqueeze_dim(axis as usize); // Initialize the output tensor with the off_value diff --git a/crates/burn-tensor/src/tests/ops/one_hot.rs b/crates/burn-tensor/src/tests/ops/one_hot.rs index b1f5634eac..34a8ece7ee 100644 --- a/crates/burn-tensor/src/tests/ops/one_hot.rs +++ b/crates/burn-tensor/src/tests/ops/one_hot.rs @@ -23,9 +23,8 @@ mod tests { #[test] fn float_should_support_one_hot_index() { let tensor = TestTensor::<1>::from([2.0]); - let one_hot_tensor: Tensor = - tensor.one_hot::<2>(10).flatten::<1>(0, 1); - let expected = TensorData::from([0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]); + let one_hot_tensor: Tensor = tensor.one_hot::<2>(10); + let expected = TensorData::from([[0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]); one_hot_tensor.into_data().assert_eq(&expected, false); } @@ -46,7 +45,7 @@ mod tests { #[test] fn int_should_support_one_hot() { let tensor = TestTensorInt::<1>::from([0, 1, 4]); - let one_hot_tensor: Tensor = tensor.one_hot(5).int(); + let one_hot_tensor: Tensor = tensor.one_hot(5); let expected = TensorData::from([[1, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 0, 0, 1]]); one_hot_tensor.into_data().assert_eq(&expected, false); } @@ -55,14 +54,14 @@ mod tests { #[should_panic] fn int_one_hot_should_panic_when_index_exceeds_number_of_classes() { let tensor = TestTensorInt::<1>::from([5]); - let result: Tensor = tensor.one_hot(5).int(); + let result: Tensor = tensor.one_hot(5); } #[test] #[should_panic] fn int_one_hot_should_panic_when_number_of_classes_is_zero() { let tensor = TestTensorInt::<1>::from([2]); - let result: Tensor = tensor.one_hot(0).int(); + let result: Tensor = tensor.one_hot(0); } #[test] @@ -81,12 +80,12 @@ mod tests { #[test] fn one_hot_fill_with_negative_axis_and_indices() { let tensor = TestTensorInt::<2>::from([[0, 2], [1, -1]]); - let expected = TensorData::from(as_type!(FloatType: [ - [[5.0, 0.0, 0.0], [0.0, 0.0, 5.0]], - [[0.0, 5.0, 0.0], [0.0, 0.0, 5.0]] + let expected = TensorData::from(as_type!(IntType: [ + [[5, 0, 0], [0, 0, 5]], + [[0, 5, 0], [0, 0, 5]] ])); - let one_hot_tensor: Tensor = tensor.one_hot_fill(3, 5.0, 0.0, -1); + let one_hot_tensor: Tensor = tensor.one_hot_fill(3, 5.0, 0.0, -1); one_hot_tensor.into_data().assert_eq(&expected, true); } From 20a3a31d1c155606dd096695bb5d77cba262d8d8 Mon Sep 17 00:00:00 2001 From: tiruka Date: Wed, 15 Jan 2025 21:26:51 +0900 Subject: [PATCH 19/20] modify tests to follow reviews comments --- crates/burn-tensor/src/tests/ops/one_hot.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/crates/burn-tensor/src/tests/ops/one_hot.rs b/crates/burn-tensor/src/tests/ops/one_hot.rs index 34a8ece7ee..24e8f24b38 100644 --- a/crates/burn-tensor/src/tests/ops/one_hot.rs +++ b/crates/burn-tensor/src/tests/ops/one_hot.rs @@ -79,13 +79,13 @@ mod tests { #[test] fn one_hot_fill_with_negative_axis_and_indices() { - let tensor = TestTensorInt::<2>::from([[0, 2], [1, -1]]); - let expected = TensorData::from(as_type!(IntType: [ - [[5, 0, 0], [0, 0, 5]], - [[0, 5, 0], [0, 0, 5]] + let tensor = TestTensor::<2>::from([[0, 2], [1, -1]]); + let expected = TensorData::from(as_type!(FloatType: [ + [[5.0, 0.0, 0.0], [0.0, 0.0, 5.0]], + [[0.0, 5.0, 0.0], [0.0, 0.0, 5.0]] ])); - let one_hot_tensor: Tensor = tensor.one_hot_fill(3, 5.0, 0.0, -1); + let one_hot_tensor: Tensor = tensor.one_hot_fill(3, 5.0, 0.0, -1); one_hot_tensor.into_data().assert_eq(&expected, true); } From 4bffb3070724f7841d6fc07b34b9e94c0fd7c91c Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Wed, 15 Jan 2025 10:53:25 -0500 Subject: [PATCH 20/20] Improve check message --- crates/burn-tensor/src/tensor/api/check.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/crates/burn-tensor/src/tensor/api/check.rs b/crates/burn-tensor/src/tensor/api/check.rs index 60cc841b5b..8a6fb2ad78 100644 --- a/crates/burn-tensor/src/tensor/api/check.rs +++ b/crates/burn-tensor/src/tensor/api/check.rs @@ -478,7 +478,10 @@ impl TensorCheck { if D + 1 != D2 { check = check.register( "One Hot", - TensorError::new("Tensor of rank one greater than input tensor 'indices', i.e. rank(output) = rank(indices) + 1") + TensorError::new( + "The one-hot tensor rank must correspond to the rank of the tensor + 1", + ) + .details(format!("Expected D2={}, got {D2}", D + 1)), ); } check