From 60ee8e6b355d5973ccfa6b5fb11c0bf09665d0bd Mon Sep 17 00:00:00 2001 From: mcarthur Date: Thu, 11 Jul 2024 09:12:50 +0000 Subject: [PATCH 01/38] sparse --- Cargo.lock | 19 + crates/burn-core/Cargo.toml | 2 + crates/burn-core/src/backend.rs | 3 + crates/burn-sparse/Cargo.toml | 43 + crates/burn-sparse/src/backend/alias.rs | 4 + crates/burn-sparse/src/backend/api.rs | 33 + crates/burn-sparse/src/backend/kind.rs | 61 + crates/burn-sparse/src/backend/mod.rs | 8 + .../burn-sparse/src/backend/sparse_backend.rs | 110 ++ crates/burn-sparse/src/decorator/backend.rs | 42 + crates/burn-sparse/src/decorator/mod.rs | 10 + crates/burn-sparse/src/decorator/ops.rs | 1143 +++++++++++++++++ .../src/decorator/precision_bridge.rs | 37 + .../src/decorator/representation.rs | 21 + .../burn-sparse/src/decorator/sparse_coo.rs | 273 ++++ .../burn-sparse/src/decorator/sparse_csr.rs | 88 ++ crates/burn-sparse/src/lib.rs | 2 + crates/burn/Cargo.toml | 1 + 18 files changed, 1900 insertions(+) create mode 100644 crates/burn-sparse/Cargo.toml create mode 100644 crates/burn-sparse/src/backend/alias.rs create mode 100644 crates/burn-sparse/src/backend/api.rs create mode 100644 crates/burn-sparse/src/backend/kind.rs create mode 100644 crates/burn-sparse/src/backend/mod.rs create mode 100644 crates/burn-sparse/src/backend/sparse_backend.rs create mode 100644 crates/burn-sparse/src/decorator/backend.rs create mode 100644 crates/burn-sparse/src/decorator/mod.rs create mode 100644 crates/burn-sparse/src/decorator/ops.rs create mode 100644 crates/burn-sparse/src/decorator/precision_bridge.rs create mode 100644 crates/burn-sparse/src/decorator/representation.rs create mode 100644 crates/burn-sparse/src/decorator/sparse_coo.rs create mode 100644 crates/burn-sparse/src/decorator/sparse_csr.rs create mode 100644 crates/burn-sparse/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index 6d452da6ec..b7da2bef64 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -508,6 +508,7 @@ dependencies = [ "burn-dataset", "burn-derive", "burn-ndarray", + "burn-sparse", "burn-tch", "burn-tensor", "burn-wgpu", @@ -701,6 +702,24 @@ dependencies = [ "serde", ] +[[package]] +name = "burn-sparse" +version = "0.14.0" +dependencies = [ + "burn-common", + "burn-tensor", + "derive-new", + "half", + "hashbrown 0.14.5", + "num-traits", + "proc-macro2", + "quote", + "rand", + "rand_distr", + "serde", + "syn 2.0.69", +] + [[package]] name = "burn-tch" version = "0.14.0" diff --git a/crates/burn-core/Cargo.toml b/crates/burn-core/Cargo.toml index a33fab2994..aa8429b944 100644 --- a/crates/burn-core/Cargo.toml +++ b/crates/burn-core/Cargo.toml @@ -68,6 +68,7 @@ vision = ["burn-dataset?/vision", "burn-common/network"] # Backend autodiff = ["burn-autodiff"] fusion = ["burn-wgpu?/fusion"] +sparse = ["burn-sparse"] ## Backend features metal = ["burn-candle?/metal"] @@ -111,6 +112,7 @@ burn-wgpu = { path = "../burn-wgpu", version = "0.14.0", optional = true, defaul burn-autodiff = { path = "../burn-autodiff", version = "0.14.0", optional = true } burn-tch = { path = "../burn-tch", version = "0.14.0", optional = true } burn-candle = { path = "../burn-candle", version = "0.14.0", optional = true } +burn-sparse = { path = "../burn-sparse", version = "0.14.0", optional = true } derive-new = { workspace = true } log = { workspace = true, optional = true } diff --git a/crates/burn-core/src/backend.rs b/crates/burn-core/src/backend.rs index 34fbdf3370..a1a5813491 100644 --- a/crates/burn-core/src/backend.rs +++ b/crates/burn-core/src/backend.rs @@ -27,3 +27,6 @@ pub use burn_tch 
as libtorch; #[cfg(feature = "tch")] pub use burn_tch::LibTorch; + +#[cfg(feature = "sparse")] +pub use burn_sparse as sparse; diff --git a/crates/burn-sparse/Cargo.toml b/crates/burn-sparse/Cargo.toml new file mode 100644 index 0000000000..4b60b96254 --- /dev/null +++ b/crates/burn-sparse/Cargo.toml @@ -0,0 +1,43 @@ +[package] +authors = [] +categories = ["science", "no-std", "embedded", "wasm"] +description = "Sparse tensor crate that offers a default sparse backend wrapper around burn backends." +edition.workspace = true +keywords = ["deep-learning", "machine-learning", "tensor", "sparse"] +license.workspace = true +name = "burn-sparse" +readme.workspace = true +repository = "https://github.com/tracel-ai/burn/tree/main/burn-sparse" +version.workspace = true + +[features] +default = ["std"] +doc = ["default"] +experimental-named-tensor = [] +std = ["rand/std", "half/std", "num-traits/std"] +wasm-sync = [] + +[dependencies] +burn-common = { path = "../burn-common", version = "0.14.0", default-features = false } +burn-tensor = { path = "../burn-tensor", version = "0.14.0" } + +proc-macro2 = { workspace = true } +quote = { workspace = true } +syn = { workspace = true } +derive-new = { workspace = true } +half = { workspace = true } +num-traits = { workspace = true } +rand = { workspace = true } +rand_distr = { workspace = true } # use instead of statrs because it supports no_std + +# The same implementation of HashMap in std but with no_std support (only needs alloc crate) +hashbrown = { workspace = true } # no_std compatible + +# Serialization +serde = { workspace = true } + +[dev-dependencies] +rand = { workspace = true, features = ["std", "std_rng"] } # Default enables std + +[package.metadata.docs.rs] +features = ["doc"] diff --git a/crates/burn-sparse/src/backend/alias.rs b/crates/burn-sparse/src/backend/alias.rs new file mode 100644 index 0000000000..dca8d059ce --- /dev/null +++ b/crates/burn-sparse/src/backend/alias.rs @@ -0,0 +1,4 @@ +use crate::backend::SparseBackend; + +/// Sparse tensor primitive type used by the backend. +pub type SparseTensor = ::SparseTensorPrimitive; diff --git a/crates/burn-sparse/src/backend/api.rs b/crates/burn-sparse/src/backend/api.rs new file mode 100644 index 0000000000..60e2023af3 --- /dev/null +++ b/crates/burn-sparse/src/backend/api.rs @@ -0,0 +1,33 @@ +use crate::backend::{Sparse, SparseBackend}; +use burn_tensor::{Int, Tensor, TensorPrimitive}; + +pub trait SparseTensor +where + B: SparseBackend, +{ + fn dense_int(self) -> Tensor; + fn spmm(self, rhs: Tensor) -> Tensor; + fn dense(self) -> Tensor; +} + +impl SparseTensor for Tensor +where + B: SparseBackend, +{ + fn dense(self) -> Tensor { + Tensor::new(TensorPrimitive::Float(B::sparse_to_dense( + self.into_primitive(), + ))) + } + + fn dense_int(self) -> Tensor { + self.dense().int() + } + + fn spmm(self, rhs: Tensor) -> Tensor { + Tensor::new(TensorPrimitive::Float(B::sparse_spmm( + self.into_primitive(), + rhs.into_primitive().tensor(), + ))) + } +} diff --git a/crates/burn-sparse/src/backend/kind.rs b/crates/burn-sparse/src/backend/kind.rs new file mode 100644 index 0000000000..f0e54ac8ec --- /dev/null +++ b/crates/burn-sparse/src/backend/kind.rs @@ -0,0 +1,61 @@ +use std::{future::Future, ops::Range}; + +use crate::backend::SparseBackend; +use burn_tensor::{backend::Backend, BasicOps, Shape, TensorData, TensorKind}; + +/// A type-level representation of the kind of a sparse (float) tensor. 
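+///
+/// Implementing [`TensorKind`] maps this marker onto `B::SparseTensorPrimitive`,
+/// so sparse tensors flow through the generic `Tensor<B, D, Sparse<B>>` API the
+/// same way `Float`, `Int` and `Bool` tensors do.
+///
+/// A minimal usage sketch (illustrative only; assumes some `B: SparseBackend`
+/// plus a `TensorData` value and a device already in scope):
+/// ```ignore
+/// let sparse = Tensor::<B, 2, Sparse<B>>::from_data(data, &device);
+/// let dense: Tensor<B, 2> = sparse.dense(); // densify through the sparse API
+/// ```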
+#[derive(Clone, Debug)] +pub struct Sparse; + +impl TensorKind for Sparse { + type Primitive = B::SparseTensorPrimitive; + fn name() -> &'static str { + "Sparse" + } +} + +impl BasicOps for Sparse { + type Elem = B::FloatElem; + + fn into_data_async( + tensor: Self::Primitive, + ) -> impl Future + Send { + B::sparse_into_data(tensor) + } + + fn device(tensor: &Self::Primitive) -> ::Device { + B::sparse_device(tensor) + } + + fn to_device( + tensor: Self::Primitive, + device: &::Device, + ) -> Self::Primitive { + B::sparse_to_device(tensor, device) + } + + fn from_data( + data: TensorData, + device: &::Device, + ) -> Self::Primitive { + B::sparse_from_data(data, device) + } + + fn shape(tensor: &Self::Primitive) -> Shape { + B::sparse_shape(tensor) + } + + fn empty( + shape: Shape, + device: &::Device, + ) -> Self::Primitive { + B::sparse_empty(shape, device) + } + + fn slice( + tensor: Self::Primitive, + ranges: [Range; D2], + ) -> Self::Primitive { + B::sparse_slice(tensor, ranges) + } +} diff --git a/crates/burn-sparse/src/backend/mod.rs b/crates/burn-sparse/src/backend/mod.rs new file mode 100644 index 0000000000..20c24353ed --- /dev/null +++ b/crates/burn-sparse/src/backend/mod.rs @@ -0,0 +1,8 @@ +mod alias; +mod api; +mod kind; +mod sparse_backend; + +pub use alias::*; +pub use kind::*; +pub use sparse_backend::*; diff --git a/crates/burn-sparse/src/backend/sparse_backend.rs b/crates/burn-sparse/src/backend/sparse_backend.rs new file mode 100644 index 0000000000..09020dfcd7 --- /dev/null +++ b/crates/burn-sparse/src/backend/sparse_backend.rs @@ -0,0 +1,110 @@ +use crate::backend::SparseTensor; +use burn_tensor::{backend::Backend, Device, Shape, TensorData}; +use core::{future::Future, ops::Range}; + +pub trait SparseBackend: Backend { + type SparseTensorPrimitive: Clone + Send + 'static + core::fmt::Debug; + + fn sparse_empty( + shape: Shape, + device: &Device, + ) -> SparseTensor; + + fn sparse_to_sparse( + dense: Self::FloatTensorPrimitive, + ) -> Self::SparseTensorPrimitive; + + fn sparse_to_dense( + sparse: Self::SparseTensorPrimitive, + ) -> Self::FloatTensorPrimitive; + + fn sparse_spmm( + lhs: Self::SparseTensorPrimitive, + rhs: Self::FloatTensorPrimitive, + ) -> Self::FloatTensorPrimitive; + + fn sparse_sddmm( + lhs: Self::SparseTensorPrimitive, + rhs: Self::FloatTensorPrimitive, + ) -> Self::SparseTensorPrimitive; + + /// Gets the element at the given indices. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// * `indices` - The indices. + /// + /// # Returns + /// + /// The elements at the given indices. + fn sparse_slice( + tensor: SparseTensor, + indices: [Range; D2], + ) -> SparseTensor; + + /// Gets the device of the tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// + /// # Returns + /// + /// The device of the tensor. + fn sparse_device(tensor: &SparseTensor) -> Device; + + /// Moves the tensor to the given device. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// * `device` - The device to move the tensor to. + /// + /// # Returns + /// + /// The tensor on the given device. + fn sparse_to_device( + tensor: SparseTensor, + device: &Device, + ) -> SparseTensor; + + /// Gets the shape of the tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// + /// # Returns + /// + /// The shape of the tensor. + fn sparse_shape(tensor: &SparseTensor) -> Shape; + + /// Converts the tensor to a data structure. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. 
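+    ///   The tensor is consumed; read-back is asynchronous, so the result is
+    ///   returned as a future resolving to the tensor's data.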
+ /// + /// # Returns + /// + /// The data structure with the tensor's data. + fn sparse_into_data( + tensor: SparseTensor, + ) -> impl Future + Send; + + /// Creates a tensor from the data structure. + /// + /// # Arguments + /// + /// * `data` - The data structure. + /// * `device` - The device to create the tensor on. + /// + /// # Returns + /// + /// The tensor with the data. + fn sparse_from_data( + data: TensorData, + device: &Device, + ) -> SparseTensor; +} diff --git a/crates/burn-sparse/src/decorator/backend.rs b/crates/burn-sparse/src/decorator/backend.rs new file mode 100644 index 0000000000..7a0fbf0002 --- /dev/null +++ b/crates/burn-sparse/src/decorator/backend.rs @@ -0,0 +1,42 @@ +use crate::decorator::FullPrecisionBridge; +use crate::decorator::SparseRepresentation; +use burn_tensor::backend::Backend; +use core::marker::PhantomData; +use derive_new::new; + +/// Tensor backend that extends existing backends with sparse tensor support. +/// This backend abstracts over all backends, and so lacks the performance of a direct implementation. +/// Backends implementing SparseDecorator should be used directly where possible. +#[derive(new, Clone, Copy, Default, Debug)] +pub struct SparseDecorator { + _p: PhantomData, + _r: PhantomData, +} + +impl Backend for SparseDecorator { + type Device = B::Device; + + type FullPrecisionBridge = FullPrecisionBridge; + + type FloatTensorPrimitive = B::FloatTensorPrimitive; + + type FloatElem = B::FloatElem; + + type IntTensorPrimitive = B::IntTensorPrimitive; + + type IntElem = B::IntElem; + + type BoolTensorPrimitive = B::BoolTensorPrimitive; + + type QuantizedTensorPrimitive = B::QuantizedTensorPrimitive; + + fn name() -> String { + format!("SparseDecorator<{}>", B::name()) + } + + fn seed(seed: u64) { + B::seed(seed) + } +} + +impl SparseDecorator {} diff --git a/crates/burn-sparse/src/decorator/mod.rs b/crates/burn-sparse/src/decorator/mod.rs new file mode 100644 index 0000000000..ba2b258c05 --- /dev/null +++ b/crates/burn-sparse/src/decorator/mod.rs @@ -0,0 +1,10 @@ +mod backend; +mod ops; +mod precision_bridge; +mod representation; +mod sparse_coo; +mod sparse_csr; + +pub use backend::*; +pub use precision_bridge::*; +pub use representation::*; diff --git a/crates/burn-sparse/src/decorator/ops.rs b/crates/burn-sparse/src/decorator/ops.rs new file mode 100644 index 0000000000..1ec03ce3cd --- /dev/null +++ b/crates/burn-sparse/src/decorator/ops.rs @@ -0,0 +1,1143 @@ +use crate::decorator::SparseDecorator; +use crate::decorator::SparseRepresentation; +use burn_tensor::{ + backend::Backend, + ops::{ + ActivationOps, BoolTensor, BoolTensorOps, ConvOptions, ConvTransposeOptions, FloatTensor, + FloatTensorOps, IntElem, IntTensor, IntTensorOps, InterpolateOptions, MaxPool2dBackward, + MaxPool2dWithIndices, ModuleOps, QTensorOps, + }, + Device, Distribution, Shape, TensorData, +}; +use core::ops::Range; + +impl FloatTensorOps> for SparseDecorator +where + B: Backend, + R: SparseRepresentation, +{ + fn float_random( + shape: burn_tensor::Shape, + distribution: burn_tensor::Distribution, + device: &burn_tensor::Device, + ) -> burn_tensor::ops::FloatTensor { + B::float_random(shape, distribution, device) + } + + fn float_shape( + tensor: &burn_tensor::ops::FloatTensor, + ) -> burn_tensor::Shape { + B::float_shape(tensor) + } + + fn float_device( + tensor: &burn_tensor::ops::FloatTensor, + ) -> burn_tensor::Device { + B::float_device(tensor) + } + + fn float_to_device( + tensor: burn_tensor::ops::FloatTensor, + device: &burn_tensor::Device, + ) -> 
burn_tensor::ops::FloatTensor { + B::float_to_device(tensor, device) + } + + fn float_into_int( + tensor: burn_tensor::ops::FloatTensor, + ) -> burn_tensor::ops::IntTensor { + B::float_into_int(tensor) + } + + fn float_empty( + shape: burn_tensor::Shape, + device: &burn_tensor::Device, + ) -> burn_tensor::ops::FloatTensor { + B::float_empty(shape, device) + } + + fn float_add( + lhs: burn_tensor::ops::FloatTensor, + rhs: burn_tensor::ops::FloatTensor, + ) -> burn_tensor::ops::FloatTensor { + B::float_add(lhs, rhs) + } + + fn float_add_scalar( + lhs: burn_tensor::ops::FloatTensor, + rhs: burn_tensor::ops::FloatElem, + ) -> burn_tensor::ops::FloatTensor { + B::float_add_scalar(lhs, rhs) + } + + fn float_sub( + lhs: burn_tensor::ops::FloatTensor, + rhs: burn_tensor::ops::FloatTensor, + ) -> burn_tensor::ops::FloatTensor { + B::float_sub(lhs, rhs) + } + + fn float_sub_scalar( + lhs: burn_tensor::ops::FloatTensor, + rhs: burn_tensor::ops::FloatElem, + ) -> burn_tensor::ops::FloatTensor { + B::float_sub_scalar(lhs, rhs) + } + + fn float_mul( + lhs: burn_tensor::ops::FloatTensor, + rhs: burn_tensor::ops::FloatTensor, + ) -> burn_tensor::ops::FloatTensor { + B::float_mul(lhs, rhs) + } + + fn float_mul_scalar( + lhs: burn_tensor::ops::FloatTensor, + rhs: burn_tensor::ops::FloatElem, + ) -> burn_tensor::ops::FloatTensor { + B::float_mul_scalar(lhs, rhs) + } + + fn float_div( + lhs: burn_tensor::ops::FloatTensor, + rhs: burn_tensor::ops::FloatTensor, + ) -> burn_tensor::ops::FloatTensor { + B::float_div(lhs, rhs) + } + + fn float_div_scalar( + lhs: burn_tensor::ops::FloatTensor, + rhs: burn_tensor::ops::FloatElem, + ) -> burn_tensor::ops::FloatTensor { + B::float_div_scalar(lhs, rhs) + } + + fn float_remainder_scalar( + lhs: burn_tensor::ops::FloatTensor, + rhs: burn_tensor::ops::FloatElem, + ) -> burn_tensor::ops::FloatTensor { + B::float_remainder_scalar(lhs, rhs) + } + + fn float_matmul( + lhs: burn_tensor::ops::FloatTensor, + rhs: burn_tensor::ops::FloatTensor, + ) -> burn_tensor::ops::FloatTensor { + B::float_matmul(lhs, rhs) + } + + fn float_recip( + tensor: burn_tensor::ops::FloatTensor, + ) -> burn_tensor::ops::FloatTensor { + B::float_recip(tensor) + } + + fn float_swap_dims( + tensor: burn_tensor::ops::FloatTensor, + dim1: usize, + dim2: usize, + ) -> burn_tensor::ops::FloatTensor { + B::float_swap_dims(tensor, dim1, dim2) + } + + fn float_permute( + tensor: burn_tensor::ops::FloatTensor, + axes: [usize; D], + ) -> burn_tensor::ops::FloatTensor { + B::float_permute(tensor, axes) + } + + fn float_flip( + tensor: burn_tensor::ops::FloatTensor, + axes: &[usize], + ) -> burn_tensor::ops::FloatTensor { + B::float_flip(tensor, axes) + } + + fn float_reshape( + tensor: burn_tensor::ops::FloatTensor, + shape: burn_tensor::Shape, + ) -> burn_tensor::ops::FloatTensor { + B::float_reshape(tensor, shape) + } + + fn float_gather( + dim: usize, + tensor: burn_tensor::ops::FloatTensor, + indices: burn_tensor::ops::IntTensor, + ) -> burn_tensor::ops::FloatTensor { + B::float_gather(dim, tensor, indices) + } + + fn float_scatter( + dim: usize, + tensor: burn_tensor::ops::FloatTensor, + indices: burn_tensor::ops::IntTensor, + value: burn_tensor::ops::FloatTensor, + ) -> burn_tensor::ops::FloatTensor { + B::float_scatter(dim, tensor, indices, value) + } + + fn float_select( + tensor: burn_tensor::ops::FloatTensor, + dim: usize, + indices: burn_tensor::ops::IntTensor, + ) -> burn_tensor::ops::FloatTensor { + B::float_select(tensor, dim, indices) + } + + fn float_select_assign( + tensor: 
burn_tensor::ops::FloatTensor, + dim: usize, + indices: burn_tensor::ops::IntTensor, + value: burn_tensor::ops::FloatTensor, + ) -> burn_tensor::ops::FloatTensor { + B::float_select_assign(tensor, dim, indices, value) + } + + fn float_slice( + tensor: burn_tensor::ops::FloatTensor, + ranges: [core::ops::Range; D2], + ) -> burn_tensor::ops::FloatTensor { + B::float_slice(tensor, ranges) + } + + fn float_slice_assign( + tensor: burn_tensor::ops::FloatTensor, + ranges: [core::ops::Range; D2], + value: burn_tensor::ops::FloatTensor, + ) -> burn_tensor::ops::FloatTensor { + B::float_slice_assign(tensor, ranges, value) + } + + fn float_mask_where( + tensor: burn_tensor::ops::FloatTensor, + mask: burn_tensor::ops::BoolTensor, + value: burn_tensor::ops::FloatTensor, + ) -> burn_tensor::ops::FloatTensor { + B::float_mask_where(tensor, mask, value) + } + + fn float_mask_fill( + tensor: burn_tensor::ops::FloatTensor, + mask: burn_tensor::ops::BoolTensor, + value: burn_tensor::ops::FloatElem, + ) -> burn_tensor::ops::FloatTensor { + B::float_mask_fill(tensor, mask, value) + } + + fn float_equal( + lhs: burn_tensor::ops::FloatTensor, + rhs: burn_tensor::ops::FloatTensor, + ) -> burn_tensor::ops::BoolTensor { + B::float_equal(lhs, rhs) + } + + fn float_equal_elem( + lhs: burn_tensor::ops::FloatTensor, + rhs: burn_tensor::ops::FloatElem, + ) -> burn_tensor::ops::BoolTensor { + B::float_equal_elem(lhs, rhs) + } + + fn float_greater( + lhs: burn_tensor::ops::FloatTensor, + rhs: burn_tensor::ops::FloatTensor, + ) -> burn_tensor::ops::BoolTensor { + B::float_greater(lhs, rhs) + } + + fn float_greater_elem( + lhs: burn_tensor::ops::FloatTensor, + rhs: burn_tensor::ops::FloatElem, + ) -> burn_tensor::ops::BoolTensor { + B::float_greater_elem(lhs, rhs) + } + + fn float_greater_equal( + lhs: burn_tensor::ops::FloatTensor, + rhs: burn_tensor::ops::FloatTensor, + ) -> burn_tensor::ops::BoolTensor { + B::float_greater_equal(lhs, rhs) + } + + fn float_greater_equal_elem( + lhs: burn_tensor::ops::FloatTensor, + rhs: burn_tensor::ops::FloatElem, + ) -> burn_tensor::ops::BoolTensor { + B::float_greater_equal_elem(lhs, rhs) + } + + fn float_lower( + lhs: burn_tensor::ops::FloatTensor, + rhs: burn_tensor::ops::FloatTensor, + ) -> burn_tensor::ops::BoolTensor { + B::float_lower(lhs, rhs) + } + + fn float_lower_elem( + lhs: burn_tensor::ops::FloatTensor, + rhs: burn_tensor::ops::FloatElem, + ) -> burn_tensor::ops::BoolTensor { + B::float_lower_elem(lhs, rhs) + } + + fn float_lower_equal( + lhs: burn_tensor::ops::FloatTensor, + rhs: burn_tensor::ops::FloatTensor, + ) -> burn_tensor::ops::BoolTensor { + B::float_lower_equal(lhs, rhs) + } + + fn float_lower_equal_elem( + lhs: burn_tensor::ops::FloatTensor, + rhs: burn_tensor::ops::FloatElem, + ) -> burn_tensor::ops::BoolTensor { + B::float_lower_equal_elem(lhs, rhs) + } + + fn float_sum( + tensor: burn_tensor::ops::FloatTensor, + ) -> burn_tensor::ops::FloatTensor { + B::float_sum(tensor) + } + + fn float_sum_dim( + tensor: burn_tensor::ops::FloatTensor, + dim: usize, + ) -> burn_tensor::ops::FloatTensor { + B::float_sum_dim(tensor, dim) + } + + fn float_mean_dim( + tensor: burn_tensor::ops::FloatTensor, + dim: usize, + ) -> burn_tensor::ops::FloatTensor { + B::float_mean_dim(tensor, dim) + } + + fn float_exp( + tensor: burn_tensor::ops::FloatTensor, + ) -> burn_tensor::ops::FloatTensor { + B::float_exp(tensor) + } + + fn float_log( + tensor: burn_tensor::ops::FloatTensor, + ) -> burn_tensor::ops::FloatTensor { + B::float_log(tensor) + } + + fn float_log1p( + tensor: 
burn_tensor::ops::FloatTensor, + ) -> burn_tensor::ops::FloatTensor { + B::float_log1p(tensor) + } + + fn float_powf( + lhs: burn_tensor::ops::FloatTensor, + rhs: burn_tensor::ops::FloatTensor, + ) -> burn_tensor::ops::FloatTensor { + B::float_powf(lhs, rhs) + } + + fn float_powf_scalar( + tensor: burn_tensor::ops::FloatTensor, + value: f32, + ) -> burn_tensor::ops::FloatTensor { + B::float_powf_scalar(tensor, value) + } + + fn float_sqrt( + tensor: burn_tensor::ops::FloatTensor, + ) -> burn_tensor::ops::FloatTensor { + B::float_sqrt(tensor) + } + + fn float_abs( + tensor: burn_tensor::ops::FloatTensor, + ) -> burn_tensor::ops::FloatTensor { + B::float_abs(tensor) + } + + fn float_cos( + tensor: burn_tensor::ops::FloatTensor, + ) -> burn_tensor::ops::FloatTensor { + B::float_cos(tensor) + } + + fn float_sin( + tensor: burn_tensor::ops::FloatTensor, + ) -> burn_tensor::ops::FloatTensor { + B::float_sin(tensor) + } + + fn float_tanh( + tensor: burn_tensor::ops::FloatTensor, + ) -> burn_tensor::ops::FloatTensor { + B::float_tanh(tensor) + } + + fn float_erf( + tensor: burn_tensor::ops::FloatTensor, + ) -> burn_tensor::ops::FloatTensor { + B::float_erf(tensor) + } + + fn float_argmax( + tensor: burn_tensor::ops::FloatTensor, + dim: usize, + ) -> burn_tensor::ops::IntTensor { + B::float_argmax(tensor, dim) + } + + fn float_argmin( + tensor: burn_tensor::ops::FloatTensor, + dim: usize, + ) -> burn_tensor::ops::IntTensor { + B::float_argmin(tensor, dim) + } + + fn float_expand( + tensor: burn_tensor::ops::FloatTensor, + shape: burn_tensor::Shape, + ) -> burn_tensor::ops::FloatTensor { + B::float_expand(tensor, shape) + } + + fn float_into_data( + tensor: FloatTensor, D>, + ) -> impl std::future::Future + Send { + B::float_into_data(tensor) + } + + fn float_from_data( + data: TensorData, + device: &Device>, + ) -> FloatTensor, D> { + B::float_from_data(data, device) + } +} + +impl BoolTensorOps> for SparseDecorator +where + B: Backend, + R: SparseRepresentation, +{ + fn bool_empty( + shape: burn_tensor::Shape, + device: &burn_tensor::Device>, + ) -> burn_tensor::ops::BoolTensor, D> { + B::bool_empty(shape, device) + } + + fn bool_shape( + tensor: &burn_tensor::ops::BoolTensor, D>, + ) -> burn_tensor::Shape { + B::bool_shape(tensor) + } + + fn bool_into_int( + tensor: burn_tensor::ops::BoolTensor, D>, + ) -> burn_tensor::ops::IntTensor, D> { + B::bool_into_int(tensor) + } + + fn bool_into_float( + tensor: burn_tensor::ops::BoolTensor, D>, + ) -> burn_tensor::ops::FloatTensor, D> { + B::bool_into_float(tensor) + } + + fn bool_device( + tensor: &burn_tensor::ops::BoolTensor, D>, + ) -> burn_tensor::Device> { + B::bool_device(tensor) + } + + fn bool_to_device( + tensor: burn_tensor::ops::BoolTensor, D>, + device: &burn_tensor::Device>, + ) -> burn_tensor::ops::BoolTensor, D> { + B::bool_to_device(tensor, device) + } + + fn bool_reshape( + tensor: burn_tensor::ops::BoolTensor, D1>, + shape: burn_tensor::Shape, + ) -> burn_tensor::ops::BoolTensor, D2> { + B::bool_reshape(tensor, shape) + } + + fn bool_slice( + tensor: burn_tensor::ops::BoolTensor, D1>, + ranges: [core::ops::Range; D2], + ) -> burn_tensor::ops::BoolTensor, D1> { + B::bool_slice(tensor, ranges) + } + + fn bool_slice_assign( + tensor: burn_tensor::ops::BoolTensor, D1>, + ranges: [core::ops::Range; D2], + value: burn_tensor::ops::BoolTensor, D1>, + ) -> burn_tensor::ops::BoolTensor, D1> { + B::bool_slice_assign(tensor, ranges, value) + } + + fn bool_equal( + lhs: burn_tensor::ops::BoolTensor, D>, + rhs: burn_tensor::ops::BoolTensor, D>, + ) 
-> burn_tensor::ops::BoolTensor, D> { + B::bool_equal(lhs, rhs) + } + + fn bool_not( + tensor: burn_tensor::ops::BoolTensor, D>, + ) -> burn_tensor::ops::BoolTensor, D> { + B::bool_not(tensor) + } + + fn bool_swap_dims( + tensor: burn_tensor::ops::BoolTensor, D>, + dim1: usize, + dim2: usize, + ) -> burn_tensor::ops::BoolTensor, D> { + B::bool_swap_dims(tensor, dim1, dim2) + } + + fn bool_permute( + tensor: burn_tensor::ops::BoolTensor, D>, + axes: [usize; D], + ) -> burn_tensor::ops::BoolTensor, D> { + B::bool_permute(tensor, axes) + } + + fn bool_flip( + tensor: burn_tensor::ops::BoolTensor, D>, + axes: &[usize], + ) -> burn_tensor::ops::BoolTensor, D> { + B::bool_flip(tensor, axes) + } + + fn bool_expand( + tensor: burn_tensor::ops::BoolTensor, D1>, + shape: burn_tensor::Shape, + ) -> burn_tensor::ops::BoolTensor, D2> { + B::bool_expand(tensor, shape) + } + + fn bool_into_data( + tensor: BoolTensor, D>, + ) -> impl std::future::Future + Send { + B::bool_into_data(tensor) + } + + fn bool_from_data( + data: TensorData, + device: &Device>, + ) -> BoolTensor, D> { + B::bool_from_data(data, device) + } +} + +impl IntTensorOps> for SparseDecorator +where + B: Backend, + R: SparseRepresentation, +{ + fn int_empty( + shape: Shape, + device: &Device>, + ) -> IntTensor, D> { + B::int_empty(shape, device) + } + + fn int_shape(tensor: &IntTensor, D>) -> Shape { + B::int_shape(tensor) + } + + fn int_device( + tensor: &IntTensor, D>, + ) -> Device> { + B::int_device(tensor) + } + + fn int_to_device( + tensor: IntTensor, D>, + device: &Device>, + ) -> IntTensor, D> { + B::int_to_device(tensor, device) + } + + fn int_reshape( + tensor: IntTensor, D1>, + shape: Shape, + ) -> IntTensor, D2> { + B::int_reshape(tensor, shape) + } + + fn int_slice( + tensor: IntTensor, D1>, + indices: [Range; D2], + ) -> IntTensor, D1> { + B::int_slice(tensor, indices) + } + + fn int_slice_assign( + tensor: IntTensor, D1>, + indices: [Range; D2], + value: IntTensor, D1>, + ) -> IntTensor, D1> { + B::int_slice_assign(tensor, indices, value) + } + + fn int_into_float( + tensor: IntTensor, D>, + ) -> FloatTensor, D> { + B::int_into_float(tensor) + } + + fn int_mask_where( + tensor: IntTensor, D>, + mask: BoolTensor, D>, + source: IntTensor, D>, + ) -> IntTensor, D> { + B::int_mask_where(tensor, mask, source) + } + + fn int_mask_fill( + tensor: IntTensor, D>, + mask: BoolTensor, D>, + value: IntElem>, + ) -> IntTensor, D> { + B::int_mask_fill(tensor, mask, value) + } + + fn int_gather( + dim: usize, + tensor: IntTensor, D>, + indices: IntTensor, D>, + ) -> IntTensor, D> { + B::int_gather(dim, tensor, indices) + } + + fn int_scatter( + dim: usize, + tensor: IntTensor, D>, + indices: IntTensor, D>, + value: IntTensor, D>, + ) -> IntTensor, D> { + B::int_scatter(dim, tensor, indices, value) + } + + fn int_select( + tensor: IntTensor, D>, + dim: usize, + indices: IntTensor, 1>, + ) -> IntTensor, D> { + B::int_select(tensor, dim, indices) + } + + fn int_select_assign( + tensor: IntTensor, D>, + dim: usize, + indices: IntTensor, 1>, + value: IntTensor, D>, + ) -> IntTensor, D> { + B::int_select_assign(tensor, dim, indices, value) + } + + fn int_repeat( + tensor: IntTensor, D>, + dim: usize, + times: usize, + ) -> IntTensor, D> { + B::int_repeat(tensor, dim, times) + } + + fn int_cat( + tensors: Vec, D>>, + dim: usize, + ) -> IntTensor, D> { + B::int_cat(tensors, dim) + } + + fn int_equal( + lhs: IntTensor, D>, + rhs: IntTensor, D>, + ) -> BoolTensor, D> { + B::int_equal(lhs, rhs) + } + + fn int_equal_elem( + lhs: IntTensor, D>, + 
rhs: IntElem>, + ) -> BoolTensor, D> { + B::int_equal_elem(lhs, rhs) + } + + fn int_greater( + lhs: IntTensor, D>, + rhs: IntTensor, D>, + ) -> BoolTensor, D> { + B::int_greater(lhs, rhs) + } + + fn int_greater_elem( + lhs: IntTensor, D>, + rhs: IntElem>, + ) -> BoolTensor, D> { + B::int_greater_elem(lhs, rhs) + } + + fn int_greater_equal( + lhs: IntTensor, D>, + rhs: IntTensor, D>, + ) -> BoolTensor, D> { + B::int_greater_equal(lhs, rhs) + } + + fn int_greater_equal_elem( + lhs: IntTensor, D>, + rhs: IntElem>, + ) -> BoolTensor, D> { + B::int_greater_equal_elem(lhs, rhs) + } + + fn int_lower( + lhs: IntTensor, D>, + rhs: IntTensor, D>, + ) -> BoolTensor, D> { + B::int_lower(lhs, rhs) + } + + fn int_lower_elem( + lhs: IntTensor, D>, + rhs: IntElem>, + ) -> BoolTensor, D> { + B::int_lower_elem(lhs, rhs) + } + + fn int_lower_equal( + lhs: IntTensor, D>, + rhs: IntTensor, D>, + ) -> BoolTensor, D> { + B::int_lower_equal(lhs, rhs) + } + + fn int_lower_equal_elem( + lhs: IntTensor, D>, + rhs: IntElem>, + ) -> BoolTensor, D> { + B::int_lower_equal_elem(lhs, rhs) + } + + fn int_sub( + lhs: IntTensor, D>, + rhs: IntTensor, D>, + ) -> IntTensor, D> { + B::int_sub(lhs, rhs) + } + + fn int_sub_scalar( + lhs: IntTensor, D>, + rhs: IntElem>, + ) -> IntTensor, D> { + B::int_sub_scalar(lhs, rhs) + } + + fn int_mul( + lhs: IntTensor, D>, + rhs: IntTensor, D>, + ) -> IntTensor, D> { + B::int_mul(lhs, rhs) + } + + fn int_mul_scalar( + lhs: IntTensor, D>, + rhs: IntElem>, + ) -> IntTensor, D> { + B::int_mul_scalar(lhs, rhs) + } + + fn int_div( + lhs: IntTensor, D>, + rhs: IntTensor, D>, + ) -> IntTensor, D> { + B::int_div(lhs, rhs) + } + + fn int_div_scalar( + lhs: IntTensor, D>, + rhs: IntElem>, + ) -> IntTensor, D> { + B::int_div_scalar(lhs, rhs) + } + + fn int_remainder_scalar( + lhs: IntTensor, D>, + rhs: IntElem>, + ) -> IntTensor, D> { + B::int_remainder_scalar(lhs, rhs) + } + + fn int_zeros( + shape: Shape, + device: &Device>, + ) -> IntTensor, D> { + B::int_zeros(shape, device) + } + + fn int_ones( + shape: Shape, + device: &Device>, + ) -> IntTensor, D> { + B::int_ones(shape, device) + } + + fn int_sum( + tensor: IntTensor, D>, + ) -> IntTensor, 1> { + B::int_sum(tensor) + } + + fn int_sum_dim( + tensor: IntTensor, D>, + dim: usize, + ) -> IntTensor, D> { + B::int_sum_dim(tensor, dim) + } + + fn int_prod( + tensor: IntTensor, D>, + ) -> IntTensor, 1> { + B::int_prod(tensor) + } + + fn int_prod_dim( + tensor: IntTensor, D>, + dim: usize, + ) -> IntTensor, D> { + B::int_prod_dim(tensor, dim) + } + + fn int_mean_dim( + tensor: IntTensor, D>, + dim: usize, + ) -> IntTensor, D> { + B::int_mean_dim(tensor, dim) + } + + fn int_argmax( + tensor: IntTensor, D>, + dim: usize, + ) -> IntTensor, D> { + B::int_argmax(tensor, dim) + } + + fn int_argmin( + tensor: IntTensor, D>, + dim: usize, + ) -> IntTensor, D> { + B::int_argmin(tensor, dim) + } + + fn int_max_dim( + tensor: IntTensor, D>, + dim: usize, + ) -> IntTensor, D> { + B::int_max_dim(tensor, dim) + } + + fn int_max_dim_with_indices( + tensor: IntTensor, D>, + dim: usize, + ) -> ( + IntTensor, D>, + IntTensor, D>, + ) { + B::int_max_dim_with_indices(tensor, dim) + } + + fn int_min_dim( + tensor: IntTensor, D>, + dim: usize, + ) -> IntTensor, D> { + B::int_min_dim(tensor, dim) + } + + fn int_min_dim_with_indices( + tensor: IntTensor, D>, + dim: usize, + ) -> ( + IntTensor, D>, + IntTensor, D>, + ) { + B::int_min_dim_with_indices(tensor, dim) + } + + fn int_abs( + tensor: IntTensor, D>, + ) -> IntTensor, D> { + B::int_abs(tensor) + } + + fn int_transpose( 
+ tensor: IntTensor, D>, + ) -> IntTensor, D> { + B::int_transpose(tensor) + } + + fn int_swap_dims( + tensor: IntTensor, D>, + dim1: usize, + dim2: usize, + ) -> IntTensor, D> { + B::int_swap_dims(tensor, dim1, dim2) + } + + fn int_permute( + tensor: IntTensor, D>, + axes: [usize; D], + ) -> IntTensor, D> { + B::int_permute(tensor, axes) + } + + fn int_flip( + tensor: IntTensor, D>, + axes: &[usize], + ) -> IntTensor, D> { + B::int_flip(tensor, axes) + } + + fn int_narrow( + tensor: IntTensor, D>, + dim: usize, + start: usize, + length: usize, + ) -> IntTensor, D> { + B::int_narrow(tensor, dim, start, length) + } + + fn int_chunk( + tensor: IntTensor, D>, + chunks: usize, + dim: usize, + ) -> Vec, D>> { + B::int_chunk(tensor, chunks, dim) + } + + fn int_random( + shape: Shape, + distribution: Distribution, + device: &Device>, + ) -> IntTensor, D> { + B::int_random(shape, distribution, device) + } + + fn int_add( + lhs: IntTensor, D>, + rhs: IntTensor, D>, + ) -> IntTensor, D> { + B::int_add(lhs, rhs) + } + + fn int_add_scalar( + lhs: IntTensor, D>, + rhs: IntElem>, + ) -> IntTensor, D> { + B::int_add_scalar(lhs, rhs) + } + + fn int_expand( + tensor: IntTensor, D1>, + shape: Shape, + ) -> IntTensor, D2> { + B::int_expand(tensor, shape) + } + + fn int_into_data( + tensor: IntTensor, D>, + ) -> impl std::future::Future + Send { + B::int_into_data(tensor) + } + + fn int_from_data( + data: TensorData, + device: &Device>, + ) -> IntTensor, D> { + B::int_from_data(data, device) + } +} + +impl QTensorOps> for SparseDecorator +where + B: Backend, + R: SparseRepresentation, +{ + fn quantize( + tensor: FloatTensor, + strategy: &burn_tensor::QuantizationStrategy, + ) -> burn_tensor::ops::QuantizedTensor { + B::quantize(tensor, strategy) + } + + fn dequantize( + tensor: burn_tensor::ops::QuantizedTensor, + strategy: &burn_tensor::QuantizationStrategy, + ) -> FloatTensor { + B::dequantize(tensor, strategy) + } + + fn q_shape(tensor: &burn_tensor::ops::QuantizedTensor) -> Shape { + B::q_shape(tensor) + } + + fn q_device(tensor: &burn_tensor::ops::QuantizedTensor) -> Device { + B::q_device(tensor) + } +} + +impl ModuleOps> for SparseDecorator +where + B: Backend, + R: SparseRepresentation, +{ + fn conv2d( + x: FloatTensor, 4>, + weight: FloatTensor, 4>, + bias: Option, 1>>, + options: ConvOptions<2>, + ) -> FloatTensor, 4> { + B::conv2d(x, weight, bias, options) + } + + fn conv_transpose2d( + x: FloatTensor, 4>, + weight: FloatTensor, 4>, + bias: Option, 1>>, + options: ConvTransposeOptions<2>, + ) -> FloatTensor, 4> { + B::conv_transpose2d(x, weight, bias, options) + } + + fn avg_pool2d( + x: FloatTensor, 4>, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + count_include_pad: bool, + ) -> FloatTensor, 4> { + B::avg_pool2d(x, kernel_size, stride, padding, count_include_pad) + } + + fn avg_pool2d_backward( + x: FloatTensor, 4>, + grad: FloatTensor, 4>, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + count_include_pad: bool, + ) -> FloatTensor, 4> { + B::avg_pool2d_backward(x, grad, kernel_size, stride, padding, count_include_pad) + } + + fn max_pool2d( + x: FloatTensor, 4>, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + dilation: [usize; 2], + ) -> FloatTensor, 4> { + B::max_pool2d(x, kernel_size, stride, padding, dilation) + } + + fn max_pool2d_with_indices( + x: FloatTensor, 4>, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + dilation: [usize; 2], + ) -> MaxPool2dWithIndices> { + let 
MaxPool2dWithIndices { output, indices } = + B::max_pool2d_with_indices(x, kernel_size, stride, padding, dilation); + MaxPool2dWithIndices { output, indices } + } + + fn max_pool2d_with_indices_backward( + x: FloatTensor, 4>, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + dilation: [usize; 2], + output_grad: FloatTensor, 4>, + indices: IntTensor, 4>, + ) -> MaxPool2dBackward> { + let MaxPool2dBackward { x_grad } = B::max_pool2d_with_indices_backward( + x, + kernel_size, + stride, + padding, + dilation, + output_grad, + indices, + ); + MaxPool2dBackward { x_grad } + } + + fn adaptive_avg_pool2d( + x: FloatTensor, 4>, + output_size: [usize; 2], + ) -> FloatTensor, 4> { + B::adaptive_avg_pool2d(x, output_size) + } + + fn adaptive_avg_pool2d_backward( + x: FloatTensor, 4>, + grad: FloatTensor, 4>, + ) -> FloatTensor, 4> { + B::adaptive_avg_pool2d_backward(x, grad) + } + + fn interpolate( + x: FloatTensor, 4>, + output_size: [usize; 2], + options: InterpolateOptions, + ) -> FloatTensor, 4> { + B::interpolate(x, output_size, options) + } + + fn interpolate_backward( + x: FloatTensor, 4>, + grad: FloatTensor, 4>, + output_size: [usize; 2], + options: InterpolateOptions, + ) -> FloatTensor, 4> { + B::interpolate_backward(x, grad, output_size, options) + } + + fn conv3d( + x: FloatTensor, + weight: FloatTensor, + bias: Option>, + options: ConvOptions<3>, + ) -> FloatTensor { + B::conv3d(x, weight, bias, options) + } + + fn conv_transpose3d( + x: FloatTensor, + weight: FloatTensor, + bias: Option>, + options: ConvTransposeOptions<3>, + ) -> FloatTensor { + B::conv_transpose3d(x, weight, bias, options) + } +} + +impl ActivationOps> for SparseDecorator +where + B: Backend, + R: SparseRepresentation, +{ +} diff --git a/crates/burn-sparse/src/decorator/precision_bridge.rs b/crates/burn-sparse/src/decorator/precision_bridge.rs new file mode 100644 index 0000000000..2f5a78cdc9 --- /dev/null +++ b/crates/burn-sparse/src/decorator/precision_bridge.rs @@ -0,0 +1,37 @@ +use core::marker::PhantomData; + +use burn_tensor::{ + backend::{Backend, BackendBridge}, + ops::FloatTensor, +}; + +use crate::decorator::SparseDecorator; +use crate::decorator::SparseRepresentation; + +#[derive(Debug)] +pub struct FullPrecisionBridge { + _p: PhantomData, +} + +impl BackendBridge> for FullPrecisionBridge +where + B: Backend, + R: SparseRepresentation, + Bridge: BackendBridge + 'static, +{ + type Target = SparseDecorator; + + fn into_target( + tensor: FloatTensor, D>, + device: Option>, + ) -> burn_tensor::ops::FloatTensor { + Bridge::into_target(tensor, device) + } + + fn from_target( + tensor: burn_tensor::ops::FloatTensor, + device: Option>>, + ) -> burn_tensor::ops::FloatTensor, D> { + Bridge::from_target(tensor, device) + } +} diff --git a/crates/burn-sparse/src/decorator/representation.rs b/crates/burn-sparse/src/decorator/representation.rs new file mode 100644 index 0000000000..81d3ce96c6 --- /dev/null +++ b/crates/burn-sparse/src/decorator/representation.rs @@ -0,0 +1,21 @@ +#[derive(Debug, Default, Clone)] +pub struct SparseCSR; + +#[derive(Debug, Default, Clone)] +pub struct SparseCOO; + +pub trait SparseRepresentation: Clone + Default + Send + Sync + 'static + core::fmt::Debug { + fn name() -> String; +} + +impl SparseRepresentation for SparseCOO { + fn name() -> String { + "SparseCOO".to_owned() + } +} + +impl SparseRepresentation for SparseCSR { + fn name() -> String { + "SparseCSR".to_owned() + } +} diff --git a/crates/burn-sparse/src/decorator/sparse_coo.rs 
b/crates/burn-sparse/src/decorator/sparse_coo.rs new file mode 100644 index 0000000000..1d9d262918 --- /dev/null +++ b/crates/burn-sparse/src/decorator/sparse_coo.rs @@ -0,0 +1,273 @@ +use crate::decorator::SparseCOO; +use crate::decorator::SparseDecorator; +use burn_tensor::{ + backend::Backend, ops::SparseTensor, sparse_backend::SparseBackend, ElementConversion, Float, + Int, Shape, Tensor, TensorData, TensorPrimitive, +}; + +#[derive(Clone, Debug)] +pub struct SparseCOOTensor { + pub coordinates: Tensor, + pub values: Tensor, + pub shape: Shape, +} + +impl SparseBackend for SparseDecorator +where + B: Backend, +{ + type SparseTensorPrimitive = SparseCOOTensor; + + fn sparse_to_sparse( + dense: Self::FloatTensorPrimitive, + ) -> Self::SparseTensorPrimitive { + let dense: Tensor = Tensor::from_primitive(TensorPrimitive::Float(dense)); + + let shape = dense.shape(); + + let significant = dense.clone().not_equal_elem(0.0); + + let coordinates = significant + .clone() + .nonzero() + .into_iter() + .map(|tensor| { + let length = tensor.shape().dims[0]; + let shape = Shape::new([1, length]); + tensor.reshape(shape) + }) + .collect(); + + let coordinates = Tensor::cat(coordinates, 0); + + let dense = dense.flatten(0, D - 1); + + let dims = significant.dims(); + let values = dense.gather( + 0, + significant + .flatten::<1>(0, dims.len() - 1) + .nonzero() + .remove(0), + ); + + Self::SparseTensorPrimitive { + coordinates, + values, + shape, + } + } + + fn sparse_to_dense( + sparse: Self::SparseTensorPrimitive, + ) -> Self::FloatTensorPrimitive { + let SparseCOOTensor { + coordinates, + values, + shape, + } = sparse; + + let num_nonzero = coordinates.shape().dims[1]; + let device = coordinates.device(); + + let dense: Tensor = Tensor::zeros(Shape::new([shape.num_elements()]), &device); + + let mut strides_data = [[1]; D]; + for i in (0..D - 1).rev() { + strides_data[i] = [strides_data[i + 1][0] * shape.dims[i + 1] as i64]; + } + + let strides_data: TensorData = TensorData::from(strides_data); + + let strides: Tensor = Tensor::from_data(strides_data, &device); + + let coordinates = strides.mul(coordinates).sum_dim(0).flatten(0, 1); + + let dense = dense.select_assign(0, coordinates, values); + + dense.reshape(shape).into_primitive().tensor() + } + + fn sparse_spmm( + lhs: Self::SparseTensorPrimitive, + rhs: Self::FloatTensorPrimitive, + ) -> Self::FloatTensorPrimitive { + let SparseCOOTensor { + coordinates, + values, + shape, + } = lhs; + + let rhs: Tensor = Tensor::from_primitive(TensorPrimitive::Float(rhs)); + let rhs_shape = rhs.shape(); + let device = coordinates.device(); + let nnz = coordinates.shape().dims[1]; + + // Ensure they are of the correct shape to multiply + if shape.dims[D - 1] != rhs_shape.dims[D - 2] { + panic!("Invalid shape for matrix multiplication"); + } + + // Ensure batches are the same + if D > 2 && rhs_shape.dims[0..D - 2] != shape.dims[0..D - 2] { + panic!("Batches must be of the same shape"); + } + + let mut out_shape = shape.clone(); + out_shape.dims[D - 1] = rhs_shape.dims[D - 1]; + + // Compute strides for the dense tensor to match the flattened shape + let mut strides_data = [1; D]; + for i in (0..D - 1).rev() { + strides_data[i] = strides_data[i + 1] * shape.dims[i + 1] as i32; + } + let strides: Tensor = + Tensor::::from_ints(strides_data, &device).unsqueeze_dim(1); + + let column_index = coordinates.clone().slice([D - 1..D, 0..nnz]); + + // the indices into the flat row vector at which the containing matrix starts + let matrix_starts: Tensor = if D > 2 { + 
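+            // Multiply the leading (batch) coordinates by their strides to get a
+            // flat element offset, then divide by the row length so the offset is
+            // expressed in whole rows of the flattened 2D view.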
coordinates + .clone() + .slice([0..D - 2, 0..nnz]) + .mul(strides.clone().slice([0..D - 2])) + .div_scalar((shape.dims[D - 1]) as i32) + .sum_dim(0) + } else { + Tensor::::zeros(column_index.shape(), &device) + }; + + let row_index = coordinates.slice([D - 2..D - 1, 0..nnz]); + + let gather_index = matrix_starts.clone() + column_index; + let scatter_index = matrix_starts + row_index; + + let gather_index = gather_index.transpose().repeat(1, rhs_shape.dims[D - 1]); + let scatter_index = scatter_index.transpose().repeat(1, rhs_shape.dims[D - 1]); + let values = values.unsqueeze_dim(1).repeat(1, rhs_shape.dims[D - 1]); + + // Flatten the rhs similarly into 2 dimensions + let rhs: Tensor = rhs.reshape([-1, rhs_shape.dims[D - 1] as i32]); + + // Do the matmul using gather/scatter + let output: Tensor = + Tensor::zeros([out_shape.dims[0], rhs.shape().dims[1]], &device); + let gathered = rhs.gather(0, gather_index); + + let multiplied = gathered.mul(values); + + let scattered = output.scatter(0, scatter_index, multiplied); + + scattered.reshape(out_shape).into_primitive().tensor() + } + + fn sparse_sddmm( + lhs: Self::SparseTensorPrimitive, + rhs: Self::FloatTensorPrimitive, + ) -> Self::SparseTensorPrimitive { + todo!() + } + + fn sparse_device(tensor: &SparseTensor) -> burn_tensor::Device { + tensor.values.device() + } + + fn sparse_to_device( + tensor: burn_tensor::ops::SparseTensor, + device: &burn_tensor::Device, + ) -> burn_tensor::ops::SparseTensor { + SparseCOOTensor { + coordinates: tensor.coordinates.to_device(device), + values: tensor.values.to_device(device), + shape: tensor.shape, + } + } + + fn sparse_shape( + tensor: &Self::SparseTensorPrimitive, + ) -> burn_tensor::Shape { + tensor.shape.clone() + } + + fn sparse_empty( + shape: burn_tensor::Shape, + device: &burn_tensor::Device, + ) -> burn_tensor::ops::SparseTensor { + SparseCOOTensor { + coordinates: Tensor::from_primitive(B::int_empty( + burn_tensor::Shape::new([0, 0]), + &device, + )), + values: Tensor::from_primitive(TensorPrimitive::Float(B::float_empty( + burn_tensor::Shape::new([0]), + &device, + ))), + shape, + } + } + + fn sparse_slice( + tensor: Self::SparseTensorPrimitive, + indices: [std::ops::Range; D2], + ) -> burn_tensor::ops::SparseTensor { + let SparseCOOTensor { + coordinates, + values, + shape, + } = tensor; + + let device = coordinates.device(); + let number_nonzero = coordinates.shape().dims[1]; + + let mut mask: Tensor = Tensor::ones(Shape::new([number_nonzero]), &device); + + for (dim, bound) in indices.iter().enumerate() { + let coords = coordinates.clone().slice([dim..dim + 1, 0..number_nonzero]); + let coords = coords.reshape(Shape::new([number_nonzero])); + + let mask_lower = coords + .clone() + .lower_elem(B::IntElem::from_elem(bound.end)) + .int(); + + let mask_upper = coords + .clone() + .greater_equal_elem(B::IntElem::from_elem(bound.start)) + .int(); + + mask = mask.mul(mask_lower).mul(mask_upper); + } + + let nonzero = mask.not_equal_elem(B::IntElem::from_elem(0)).nonzero(); + + let indices_dim1 = nonzero + .get(0) + .cloned() + .expect("Expected dimension to exist"); + + let coordinates = coordinates.select(1, indices_dim1.clone()); + let values = values.select(0, indices_dim1); + + SparseCOOTensor { + coordinates, + values, + shape, + } + } + + fn sparse_from_data( + data: TensorData, + device: &burn_tensor::Device, + ) -> burn_tensor::ops::SparseTensor { + let dense = B::float_from_data(data, &device); + Self::sparse_to_sparse(dense) + } + + fn sparse_into_data( + tensor: 
burn_tensor::ops::SparseTensor, + ) -> impl std::future::Future + Send { + // TODO this could be way better + B::float_into_data(Self::sparse_to_dense(tensor)) + } +} diff --git a/crates/burn-sparse/src/decorator/sparse_csr.rs b/crates/burn-sparse/src/decorator/sparse_csr.rs new file mode 100644 index 0000000000..605d87169c --- /dev/null +++ b/crates/burn-sparse/src/decorator/sparse_csr.rs @@ -0,0 +1,88 @@ +use crate::decorator::SparseCSR; +use crate::decorator::SparseDecorator; +use burn_tensor::{backend::Backend, sparse_backend::SparseBackend}; +use core::marker::PhantomData; + +#[derive(Debug, Default, Clone)] +pub struct SparseCSRTensor { + _b: PhantomData, +} + +impl SparseBackend for SparseDecorator +where + B: Backend, +{ + type SparseTensorPrimitive = SparseCSRTensor; + + fn sparse_empty( + shape: burn_tensor::Shape, + device: &burn_tensor::Device, + ) -> burn_tensor::ops::SparseTensor { + todo!() + } + + fn sparse_to_sparse( + dense: Self::FloatTensorPrimitive, + ) -> Self::SparseTensorPrimitive { + todo!() + } + + fn sparse_to_dense( + sparse: Self::SparseTensorPrimitive, + ) -> Self::FloatTensorPrimitive { + todo!() + } + + fn sparse_spmm( + lhs: Self::SparseTensorPrimitive, + rhs: Self::FloatTensorPrimitive, + ) -> Self::FloatTensorPrimitive { + todo!() + } + + fn sparse_sddmm( + lhs: Self::SparseTensorPrimitive, + rhs: Self::FloatTensorPrimitive, + ) -> Self::SparseTensorPrimitive { + todo!() + } + + fn sparse_slice( + tensor: burn_tensor::ops::SparseTensor, + indices: [std::ops::Range; D2], + ) -> burn_tensor::ops::SparseTensor { + todo!() + } + + fn sparse_device( + tensor: &burn_tensor::ops::SparseTensor, + ) -> burn_tensor::Device { + todo!() + } + + fn sparse_to_device( + tensor: burn_tensor::ops::SparseTensor, + device: &burn_tensor::Device, + ) -> burn_tensor::ops::SparseTensor { + todo!() + } + + fn sparse_shape( + tensor: &burn_tensor::ops::SparseTensor, + ) -> burn_tensor::Shape { + todo!() + } + + fn sparse_into_data( + tensor: burn_tensor::ops::SparseTensor, + ) -> impl std::future::Future + Send { + async { todo!() } + } + + fn sparse_from_data( + data: burn_tensor::TensorData, + device: &burn_tensor::Device, + ) -> burn_tensor::ops::SparseTensor { + todo!() + } +} diff --git a/crates/burn-sparse/src/lib.rs b/crates/burn-sparse/src/lib.rs new file mode 100644 index 0000000000..9dfe4c9dce --- /dev/null +++ b/crates/burn-sparse/src/lib.rs @@ -0,0 +1,2 @@ +pub mod backend; +pub mod decorator; diff --git a/crates/burn/Cargo.toml b/crates/burn/Cargo.toml index 4352559e8f..12d303ad74 100644 --- a/crates/burn/Cargo.toml +++ b/crates/burn/Cargo.toml @@ -36,6 +36,7 @@ vision = ["burn-core/vision"] # Backends autodiff = ["burn-core/autodiff"] fusion = ["burn-core/fusion"] +sparse = ["burn-core/sparse"] ## Backend features candle-cuda = ["burn-core/candle-cuda"] From 24daafd610dc8da142f1a8eb4db2e941fdf30635 Mon Sep 17 00:00:00 2001 From: mcarthur Date: Thu, 11 Jul 2024 09:20:24 +0000 Subject: [PATCH 02/38] Fixed errors from moving sparse --- crates/burn-sparse/src/backend/kind.rs | 94 +++++++++++++++++++ .../burn-sparse/src/decorator/sparse_coo.rs | 17 ++-- .../burn-sparse/src/decorator/sparse_csr.rs | 26 +++-- 3 files changed, 115 insertions(+), 22 deletions(-) diff --git a/crates/burn-sparse/src/backend/kind.rs b/crates/burn-sparse/src/backend/kind.rs index f0e54ac8ec..1352030c13 100644 --- a/crates/burn-sparse/src/backend/kind.rs +++ b/crates/burn-sparse/src/backend/kind.rs @@ -58,4 +58,98 @@ impl BasicOps for Sparse { ) -> Self::Primitive { B::sparse_slice(tensor, 
ranges) } + + fn reshape( + tensor: Self::Primitive, + shape: Shape, + ) -> Self::Primitive { + todo!() + } + + fn transpose(tensor: Self::Primitive) -> Self::Primitive { + todo!() + } + + fn swap_dims( + tensor: Self::Primitive, + dim1: usize, + dim2: usize, + ) -> Self::Primitive { + todo!() + } + + fn permute(tensor: Self::Primitive, axes: [usize; D]) -> Self::Primitive { + todo!() + } + + fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive { + todo!() + } + + fn slice_assign( + tensor: Self::Primitive, + ranges: [Range; D2], + value: Self::Primitive, + ) -> Self::Primitive { + todo!() + } + + fn repeat( + tensor: Self::Primitive, + dim: usize, + times: usize, + ) -> Self::Primitive { + todo!() + } + + fn cat(vectors: Vec>, dim: usize) -> Self::Primitive { + todo!() + } + + fn equal( + lhs: Self::Primitive, + rhs: Self::Primitive, + ) -> burn_tensor::Tensor { + todo!() + } + + fn not_equal( + lhs: Self::Primitive, + rhs: Self::Primitive, + ) -> burn_tensor::Tensor { + todo!() + } + + fn any( + tensor: Self::Primitive, + ) -> burn_tensor::Tensor { + todo!() + } + + fn any_dim( + tensor: Self::Primitive, + dim: usize, + ) -> burn_tensor::Tensor { + todo!() + } + + fn all( + tensor: Self::Primitive, + ) -> burn_tensor::Tensor { + todo!() + } + + fn all_dim( + tensor: Self::Primitive, + dim: usize, + ) -> burn_tensor::Tensor { + todo!() + } + + fn expand( + tensor: Self::Primitive, + shape: Shape, + ) -> Self::Primitive { + todo!() + } } diff --git a/crates/burn-sparse/src/decorator/sparse_coo.rs b/crates/burn-sparse/src/decorator/sparse_coo.rs index 1d9d262918..7b46e6487b 100644 --- a/crates/burn-sparse/src/decorator/sparse_coo.rs +++ b/crates/burn-sparse/src/decorator/sparse_coo.rs @@ -1,8 +1,9 @@ +use crate::backend::SparseBackend; +use crate::backend::SparseTensor; use crate::decorator::SparseCOO; use crate::decorator::SparseDecorator; use burn_tensor::{ - backend::Backend, ops::SparseTensor, sparse_backend::SparseBackend, ElementConversion, Float, - Int, Shape, Tensor, TensorData, TensorPrimitive, + backend::Backend, ElementConversion, Float, Int, Shape, Tensor, TensorData, TensorPrimitive, }; #[derive(Clone, Debug)] @@ -174,9 +175,9 @@ where } fn sparse_to_device( - tensor: burn_tensor::ops::SparseTensor, + tensor: SparseTensor, device: &burn_tensor::Device, - ) -> burn_tensor::ops::SparseTensor { + ) -> SparseTensor { SparseCOOTensor { coordinates: tensor.coordinates.to_device(device), values: tensor.values.to_device(device), @@ -193,7 +194,7 @@ where fn sparse_empty( shape: burn_tensor::Shape, device: &burn_tensor::Device, - ) -> burn_tensor::ops::SparseTensor { + ) -> SparseTensor { SparseCOOTensor { coordinates: Tensor::from_primitive(B::int_empty( burn_tensor::Shape::new([0, 0]), @@ -210,7 +211,7 @@ where fn sparse_slice( tensor: Self::SparseTensorPrimitive, indices: [std::ops::Range; D2], - ) -> burn_tensor::ops::SparseTensor { + ) -> SparseTensor { let SparseCOOTensor { coordinates, values, @@ -259,13 +260,13 @@ where fn sparse_from_data( data: TensorData, device: &burn_tensor::Device, - ) -> burn_tensor::ops::SparseTensor { + ) -> SparseTensor { let dense = B::float_from_data(data, &device); Self::sparse_to_sparse(dense) } fn sparse_into_data( - tensor: burn_tensor::ops::SparseTensor, + tensor: SparseTensor, ) -> impl std::future::Future + Send { // TODO this could be way better B::float_into_data(Self::sparse_to_dense(tensor)) diff --git a/crates/burn-sparse/src/decorator/sparse_csr.rs b/crates/burn-sparse/src/decorator/sparse_csr.rs index 
605d87169c..4cd7fc2038 100644 --- a/crates/burn-sparse/src/decorator/sparse_csr.rs +++ b/crates/burn-sparse/src/decorator/sparse_csr.rs @@ -1,6 +1,8 @@ +use crate::backend::SparseBackend; +use crate::backend::SparseTensor; use crate::decorator::SparseCSR; use crate::decorator::SparseDecorator; -use burn_tensor::{backend::Backend, sparse_backend::SparseBackend}; +use burn_tensor::backend::Backend; use core::marker::PhantomData; #[derive(Debug, Default, Clone)] @@ -17,7 +19,7 @@ where fn sparse_empty( shape: burn_tensor::Shape, device: &burn_tensor::Device, - ) -> burn_tensor::ops::SparseTensor { + ) -> SparseTensor { todo!() } @@ -48,33 +50,29 @@ where } fn sparse_slice( - tensor: burn_tensor::ops::SparseTensor, + tensor: SparseTensor, indices: [std::ops::Range; D2], - ) -> burn_tensor::ops::SparseTensor { + ) -> SparseTensor { todo!() } - fn sparse_device( - tensor: &burn_tensor::ops::SparseTensor, - ) -> burn_tensor::Device { + fn sparse_device(tensor: &SparseTensor) -> burn_tensor::Device { todo!() } fn sparse_to_device( - tensor: burn_tensor::ops::SparseTensor, + tensor: SparseTensor, device: &burn_tensor::Device, - ) -> burn_tensor::ops::SparseTensor { + ) -> SparseTensor { todo!() } - fn sparse_shape( - tensor: &burn_tensor::ops::SparseTensor, - ) -> burn_tensor::Shape { + fn sparse_shape(tensor: &SparseTensor) -> burn_tensor::Shape { todo!() } fn sparse_into_data( - tensor: burn_tensor::ops::SparseTensor, + tensor: SparseTensor, ) -> impl std::future::Future + Send { async { todo!() } } @@ -82,7 +80,7 @@ where fn sparse_from_data( data: burn_tensor::TensorData, device: &burn_tensor::Device, - ) -> burn_tensor::ops::SparseTensor { + ) -> SparseTensor { todo!() } } From 790ed5f0ea13e09a8d076799666526a2471ad3ec Mon Sep 17 00:00:00 2001 From: mcarthur Date: Thu, 11 Jul 2024 09:53:08 +0000 Subject: [PATCH 03/38] some better imports and fixes --- crates/burn-core/src/backend.rs | 2 +- crates/burn-core/src/tensor.rs | 5 +++++ crates/burn-sparse/src/backend/api.rs | 20 ++++++++++++++++++-- crates/burn-sparse/src/backend/mod.rs | 1 + 4 files changed, 25 insertions(+), 3 deletions(-) diff --git a/crates/burn-core/src/backend.rs b/crates/burn-core/src/backend.rs index a1a5813491..1c05f79617 100644 --- a/crates/burn-core/src/backend.rs +++ b/crates/burn-core/src/backend.rs @@ -29,4 +29,4 @@ pub use burn_tch as libtorch; pub use burn_tch::LibTorch; #[cfg(feature = "sparse")] -pub use burn_sparse as sparse; +pub use burn_sparse::decorator as sparse; diff --git a/crates/burn-core/src/tensor.rs b/crates/burn-core/src/tensor.rs index 074606bb14..ecc858ebbe 100644 --- a/crates/burn-core/src/tensor.rs +++ b/crates/burn-core/src/tensor.rs @@ -1 +1,6 @@ pub use burn_tensor::*; + +#[cfg(feature = "sparse")] +pub mod sparse { + pub use burn_sparse::backend::*; +} diff --git a/crates/burn-sparse/src/backend/api.rs b/crates/burn-sparse/src/backend/api.rs index 60e2023af3..146f4bbce4 100644 --- a/crates/burn-sparse/src/backend/api.rs +++ b/crates/burn-sparse/src/backend/api.rs @@ -1,7 +1,14 @@ use crate::backend::{Sparse, SparseBackend}; use burn_tensor::{Int, Tensor, TensorPrimitive}; -pub trait SparseTensor +pub trait ToSparse +where + B: SparseBackend, +{ + fn into_sparse(self) -> Tensor; +} + +pub trait SparseTensorApi where B: SparseBackend, { @@ -10,7 +17,16 @@ where fn dense(self) -> Tensor; } -impl SparseTensor for Tensor +impl ToSparse for Tensor +where + B: SparseBackend, +{ + fn into_sparse(self) -> Tensor { + Tensor::new(B::sparse_to_sparse(self.into_primitive().tensor())) + } +} + +impl 
SparseTensorApi for Tensor where B: SparseBackend, { diff --git a/crates/burn-sparse/src/backend/mod.rs b/crates/burn-sparse/src/backend/mod.rs index 20c24353ed..741143a850 100644 --- a/crates/burn-sparse/src/backend/mod.rs +++ b/crates/burn-sparse/src/backend/mod.rs @@ -4,5 +4,6 @@ mod kind; mod sparse_backend; pub use alias::*; +pub use api::*; pub use kind::*; pub use sparse_backend::*; From 6371a0a0cc76dd6d8d6c6098c77ed858fd3250c5 Mon Sep 17 00:00:00 2001 From: mcarthur Date: Thu, 11 Jul 2024 11:18:48 +0000 Subject: [PATCH 04/38] Full sparse backend trait, lots unfinished --- crates/burn-sparse/src/backend/kind.rs | 32 +++--- .../burn-sparse/src/backend/sparse_backend.rs | 71 +++++++++++- .../burn-sparse/src/decorator/sparse_coo.rs | 103 ++++++++++++++++++ .../burn-sparse/src/decorator/sparse_csr.rs | 103 ++++++++++++++++++ 4 files changed, 292 insertions(+), 17 deletions(-) diff --git a/crates/burn-sparse/src/backend/kind.rs b/crates/burn-sparse/src/backend/kind.rs index 1352030c13..67b1aba619 100644 --- a/crates/burn-sparse/src/backend/kind.rs +++ b/crates/burn-sparse/src/backend/kind.rs @@ -1,7 +1,7 @@ use std::{future::Future, ops::Range}; use crate::backend::SparseBackend; -use burn_tensor::{backend::Backend, BasicOps, Shape, TensorData, TensorKind}; +use burn_tensor::{backend::Backend, BasicOps, Shape, Tensor, TensorData, TensorKind}; /// A type-level representation of the kind of a sparse (float) tensor. #[derive(Clone, Debug)] @@ -63,11 +63,11 @@ impl BasicOps for Sparse { tensor: Self::Primitive, shape: Shape, ) -> Self::Primitive { - todo!() + B::sparse_reshape(tensor, shape) } fn transpose(tensor: Self::Primitive) -> Self::Primitive { - todo!() + B::sparse_transpose(tensor) } fn swap_dims( @@ -75,15 +75,15 @@ impl BasicOps for Sparse { dim1: usize, dim2: usize, ) -> Self::Primitive { - todo!() + B::sparse_swap_dims(tensor, dim1, dim2) } fn permute(tensor: Self::Primitive, axes: [usize; D]) -> Self::Primitive { - todo!() + B::sparse_permute(tensor, &axes) } fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive { - todo!() + B::sparse_flip(tensor, &axes) } fn slice_assign( @@ -91,7 +91,7 @@ impl BasicOps for Sparse { ranges: [Range; D2], value: Self::Primitive, ) -> Self::Primitive { - todo!() + B::sparse_slice_assign(tensor, ranges, value) } fn repeat( @@ -99,57 +99,57 @@ impl BasicOps for Sparse { dim: usize, times: usize, ) -> Self::Primitive { - todo!() + B::sparse_repeat(tensor, dim, times) } fn cat(vectors: Vec>, dim: usize) -> Self::Primitive { - todo!() + B::sparse_cat(vectors, dim) } fn equal( lhs: Self::Primitive, rhs: Self::Primitive, ) -> burn_tensor::Tensor { - todo!() + Tensor::new(B::sparse_equal(lhs, rhs)) } fn not_equal( lhs: Self::Primitive, rhs: Self::Primitive, ) -> burn_tensor::Tensor { - todo!() + Tensor::new(B::sparse_not_equal(lhs, rhs)) } fn any( tensor: Self::Primitive, ) -> burn_tensor::Tensor { - todo!() + Tensor::new(B::sparse_any(tensor)) } fn any_dim( tensor: Self::Primitive, dim: usize, ) -> burn_tensor::Tensor { - todo!() + Tensor::new(B::sparse_any_dim(tensor, dim)) } fn all( tensor: Self::Primitive, ) -> burn_tensor::Tensor { - todo!() + Tensor::new(B::sparse_all(tensor)) } fn all_dim( tensor: Self::Primitive, dim: usize, ) -> burn_tensor::Tensor { - todo!() + Tensor::new(B::sparse_all_dim(tensor, dim)) } fn expand( tensor: Self::Primitive, shape: Shape, ) -> Self::Primitive { - todo!() + B::sparse_expand(tensor, shape) } } diff --git a/crates/burn-sparse/src/backend/sparse_backend.rs 
b/crates/burn-sparse/src/backend/sparse_backend.rs index 09020dfcd7..056de9b2de 100644 --- a/crates/burn-sparse/src/backend/sparse_backend.rs +++ b/crates/burn-sparse/src/backend/sparse_backend.rs @@ -1,5 +1,5 @@ use crate::backend::SparseTensor; -use burn_tensor::{backend::Backend, Device, Shape, TensorData}; +use burn_tensor::{backend::Backend, ops::BoolTensor, Device, Shape, TensorData}; use core::{future::Future, ops::Range}; pub trait SparseBackend: Backend { @@ -107,4 +107,73 @@ pub trait SparseBackend: Backend { data: TensorData, device: &Device, ) -> SparseTensor; + + fn sparse_reshape( + tensor: SparseTensor, + shape: Shape, + ) -> SparseTensor; + + fn sparse_transpose(tensor: SparseTensor) -> SparseTensor; + + fn sparse_swap_dims( + tensor: SparseTensor, + dim1: usize, + dim2: usize, + ) -> SparseTensor; + + fn sparse_permute( + tensor: SparseTensor, + axes: &[usize], + ) -> SparseTensor; + + fn sparse_flip( + tensor: SparseTensor, + axes: &[usize], + ) -> SparseTensor; + + fn sparse_slice_assign( + tensor: SparseTensor, + ranges: [Range; D2], + value: SparseTensor, + ) -> SparseTensor; + + fn sparse_repeat( + tensor: SparseTensor, + dim: usize, + times: usize, + ) -> SparseTensor; + + fn sparse_cat( + tensors: Vec>, + dim: usize, + ) -> SparseTensor; + + fn sparse_equal( + lhs: SparseTensor, + rhs: SparseTensor, + ) -> BoolTensor; + + fn sparse_not_equal( + lhs: SparseTensor, + rhs: SparseTensor, + ) -> BoolTensor; + + fn sparse_any(tensor: SparseTensor) -> BoolTensor; + + fn sparse_any_dim( + tensor: SparseTensor, + dim: usize, + ) -> BoolTensor; + + fn sparse_all(tensor: SparseTensor) -> BoolTensor; + + fn sparse_all_dim( + tensor: SparseTensor, + dim: usize, + ) -> BoolTensor; + + fn sparse_expand( + tensor: SparseTensor, + shape: Shape, + ) -> SparseTensor; } diff --git a/crates/burn-sparse/src/decorator/sparse_coo.rs b/crates/burn-sparse/src/decorator/sparse_coo.rs index 7b46e6487b..d547399937 100644 --- a/crates/burn-sparse/src/decorator/sparse_coo.rs +++ b/crates/burn-sparse/src/decorator/sparse_coo.rs @@ -271,4 +271,107 @@ where // TODO this could be way better B::float_into_data(Self::sparse_to_dense(tensor)) } + + fn sparse_reshape( + tensor: SparseTensor, + shape: Shape, + ) -> SparseTensor { + todo!() + } + + fn sparse_transpose(tensor: SparseTensor) -> SparseTensor { + todo!() + } + + fn sparse_swap_dims( + tensor: SparseTensor, + dim1: usize, + dim2: usize, + ) -> SparseTensor { + todo!() + } + + fn sparse_permute( + tensor: SparseTensor, + axes: &[usize], + ) -> SparseTensor { + todo!() + } + + fn sparse_flip( + tensor: SparseTensor, + axes: &[usize], + ) -> SparseTensor { + todo!() + } + + fn sparse_slice_assign( + tensor: SparseTensor, + ranges: [std::ops::Range; D2], + value: SparseTensor, + ) -> SparseTensor { + todo!() + } + + fn sparse_repeat( + tensor: SparseTensor, + dim: usize, + times: usize, + ) -> SparseTensor { + todo!() + } + + fn sparse_cat( + tensors: Vec>, + dim: usize, + ) -> SparseTensor { + todo!() + } + + fn sparse_equal( + lhs: SparseTensor, + rhs: SparseTensor, + ) -> burn_tensor::ops::BoolTensor { + todo!() + } + + fn sparse_not_equal( + lhs: SparseTensor, + rhs: SparseTensor, + ) -> burn_tensor::ops::BoolTensor { + todo!() + } + + fn sparse_any( + tensor: SparseTensor, + ) -> burn_tensor::ops::BoolTensor { + todo!() + } + + fn sparse_any_dim( + tensor: SparseTensor, + dim: usize, + ) -> burn_tensor::ops::BoolTensor { + todo!() + } + + fn sparse_all( + tensor: SparseTensor, + ) -> burn_tensor::ops::BoolTensor { + todo!() + } + + fn 
sparse_all_dim( + tensor: SparseTensor, + dim: usize, + ) -> burn_tensor::ops::BoolTensor { + todo!() + } + + fn sparse_expand( + tensor: SparseTensor, + shape: Shape, + ) -> SparseTensor { + todo!() + } } diff --git a/crates/burn-sparse/src/decorator/sparse_csr.rs b/crates/burn-sparse/src/decorator/sparse_csr.rs index 4cd7fc2038..c0888eb46a 100644 --- a/crates/burn-sparse/src/decorator/sparse_csr.rs +++ b/crates/burn-sparse/src/decorator/sparse_csr.rs @@ -83,4 +83,107 @@ where ) -> SparseTensor { todo!() } + + fn sparse_reshape( + tensor: SparseTensor, + shape: burn_tensor::Shape, + ) -> SparseTensor { + todo!() + } + + fn sparse_transpose(tensor: SparseTensor) -> SparseTensor { + todo!() + } + + fn sparse_swap_dims( + tensor: SparseTensor, + dim1: usize, + dim2: usize, + ) -> SparseTensor { + todo!() + } + + fn sparse_permute( + tensor: SparseTensor, + axes: &[usize], + ) -> SparseTensor { + todo!() + } + + fn sparse_flip( + tensor: SparseTensor, + axes: &[usize], + ) -> SparseTensor { + todo!() + } + + fn sparse_slice_assign( + tensor: SparseTensor, + ranges: [std::ops::Range; D2], + value: SparseTensor, + ) -> SparseTensor { + todo!() + } + + fn sparse_repeat( + tensor: SparseTensor, + dim: usize, + times: usize, + ) -> SparseTensor { + todo!() + } + + fn sparse_cat( + tensors: Vec>, + dim: usize, + ) -> SparseTensor { + todo!() + } + + fn sparse_equal( + lhs: SparseTensor, + rhs: SparseTensor, + ) -> burn_tensor::ops::BoolTensor { + todo!() + } + + fn sparse_not_equal( + lhs: SparseTensor, + rhs: SparseTensor, + ) -> burn_tensor::ops::BoolTensor { + todo!() + } + + fn sparse_any( + tensor: SparseTensor, + ) -> burn_tensor::ops::BoolTensor { + todo!() + } + + fn sparse_any_dim( + tensor: SparseTensor, + dim: usize, + ) -> burn_tensor::ops::BoolTensor { + todo!() + } + + fn sparse_all( + tensor: SparseTensor, + ) -> burn_tensor::ops::BoolTensor { + todo!() + } + + fn sparse_all_dim( + tensor: SparseTensor, + dim: usize, + ) -> burn_tensor::ops::BoolTensor { + todo!() + } + + fn sparse_expand( + tensor: SparseTensor, + shape: burn_tensor::Shape, + ) -> SparseTensor { + todo!() + } } From 6f5486437db3fe4915da0e2f016857461c458ea7 Mon Sep 17 00:00:00 2001 From: mcarthur Date: Thu, 11 Jul 2024 12:29:58 +0000 Subject: [PATCH 05/38] sparse_reshape op --- .../burn-sparse/src/decorator/sparse_coo.rs | 40 ++++++++++++++++++- 1 file changed, 38 insertions(+), 2 deletions(-) diff --git a/crates/burn-sparse/src/decorator/sparse_coo.rs b/crates/burn-sparse/src/decorator/sparse_coo.rs index d547399937..65dbbf949d 100644 --- a/crates/burn-sparse/src/decorator/sparse_coo.rs +++ b/crates/burn-sparse/src/decorator/sparse_coo.rs @@ -2,6 +2,7 @@ use crate::backend::SparseBackend; use crate::backend::SparseTensor; use crate::decorator::SparseCOO; use crate::decorator::SparseDecorator; +use burn_tensor::ops::FloatTensor; use burn_tensor::{ backend::Backend, ElementConversion, Float, Int, Shape, Tensor, TensorData, TensorPrimitive, }; @@ -274,9 +275,44 @@ where fn sparse_reshape( tensor: SparseTensor, - shape: Shape, + out_shape: Shape, ) -> SparseTensor { - todo!() + let SparseCOOTensor { + coordinates, + values, + shape, + } = tensor; + + let device = coordinates.device(); + + // Flatten the coordinates: + let mut strides_data = [[1]; D1]; + for i in (0..D1 - 1).rev() { + strides_data[i] = [strides_data[i + 1][0] * shape.dims[i + 1] as i64]; + } + let strides_data: TensorData = TensorData::from(strides_data); + let strides: Tensor = Tensor::from_data(strides_data, &device); + let flat_coordinates: Tensor = 
strides.mul(coordinates).sum_dim(0).flatten(0, 1); + + // Convert the flattened coordinates to the new shape + let mut remaining_flat_coordinates = flat_coordinates.clone(); + let mut new_coordinates = Vec::with_capacity(D2); + + for &dim_size in out_shape.dims.iter().rev() { + let size = dim_size as i64; + let new_coord = remaining_flat_coordinates.clone().remainder_scalar(size); + new_coordinates.push(new_coord.clone()); + remaining_flat_coordinates = remaining_flat_coordinates.div_scalar(size); + } + + new_coordinates.reverse(); + let new_coordinates = Tensor::stack(new_coordinates, 0); + + SparseCOOTensor { + coordinates: new_coordinates, + values, + shape: out_shape, + } } fn sparse_transpose(tensor: SparseTensor) -> SparseTensor { From 741d6ccdbe6d75ee1ce9825394ab279f00b0a6c5 Mon Sep 17 00:00:00 2001 From: mcarthur Date: Sun, 14 Jul 2024 03:53:31 +0000 Subject: [PATCH 06/38] permute and transpose working --- .../burn-sparse/src/decorator/sparse_coo.rs | 25 ++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/crates/burn-sparse/src/decorator/sparse_coo.rs b/crates/burn-sparse/src/decorator/sparse_coo.rs index 65dbbf949d..27ce2e4318 100644 --- a/crates/burn-sparse/src/decorator/sparse_coo.rs +++ b/crates/burn-sparse/src/decorator/sparse_coo.rs @@ -269,7 +269,6 @@ where fn sparse_into_data( tensor: SparseTensor, ) -> impl std::future::Future + Send { - // TODO this could be way better B::float_into_data(Self::sparse_to_dense(tensor)) } @@ -316,7 +315,10 @@ where } fn sparse_transpose(tensor: SparseTensor) -> SparseTensor { - todo!() + let d = tensor.shape.dims.len(); + let mut axes: Vec = (0..d).collect(); + axes.swap(d - 1, d - 2); + Self::sparse_permute(tensor, &axes) } fn sparse_swap_dims( @@ -331,7 +333,24 @@ where tensor: SparseTensor, axes: &[usize], ) -> SparseTensor { - todo!() + let SparseCOOTensor { + coordinates, + values, + mut shape, + } = tensor; + + for (i, &j) in (0..D).zip(axes).filter(|(i, j)| i < j) { + shape.dims.swap(i, j); + } + + let axes = Tensor::from(axes); + let coordinates = coordinates.select(0, axes); + + SparseCOOTensor { + coordinates, + values, + shape, + } } fn sparse_flip( From 08143fb5d3194836f73bf32e450bbf7e71f2247b Mon Sep 17 00:00:00 2001 From: mcarthur Date: Sun, 14 Jul 2024 04:13:41 +0000 Subject: [PATCH 07/38] swap dims --- crates/burn-sparse/src/decorator/sparse_coo.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/crates/burn-sparse/src/decorator/sparse_coo.rs b/crates/burn-sparse/src/decorator/sparse_coo.rs index 27ce2e4318..ba51c619aa 100644 --- a/crates/burn-sparse/src/decorator/sparse_coo.rs +++ b/crates/burn-sparse/src/decorator/sparse_coo.rs @@ -326,7 +326,10 @@ where dim1: usize, dim2: usize, ) -> SparseTensor { - todo!() + let d = tensor.shape.dims.len(); + let mut axes: Vec = (0..d).collect(); + axes.swap(dim1, dim2); + Self::sparse_permute(tensor, &axes) } fn sparse_permute( From 19de62123de8222c972dd875cdf27ee0f96f5f62 Mon Sep 17 00:00:00 2001 From: mcarthur Date: Sun, 14 Jul 2024 05:08:22 +0000 Subject: [PATCH 08/38] sparse flip --- .../burn-sparse/src/decorator/sparse_coo.rs | 74 ++++++++++++++++++- 1 file changed, 71 insertions(+), 3 deletions(-) diff --git a/crates/burn-sparse/src/decorator/sparse_coo.rs b/crates/burn-sparse/src/decorator/sparse_coo.rs index ba51c619aa..41e8d4f7a3 100644 --- a/crates/burn-sparse/src/decorator/sparse_coo.rs +++ b/crates/burn-sparse/src/decorator/sparse_coo.rs @@ -2,9 +2,9 @@ use crate::backend::SparseBackend; use crate::backend::SparseTensor; 
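A COO reshape never touches `values`; it only relinearizes `coordinates` through the row-major flat index, which is what the stride arithmetic above does with tensors. A minimal plain-integer sketch of the same idea, where `flatten_index` and `unflatten_index` are hypothetical helpers and not part of burn:

fn flatten_index(coord: &[usize], shape: &[usize]) -> usize {
    // Row-major: the stride of dim i is the product of shape[i + 1..].
    coord
        .iter()
        .zip(shape)
        .rev()
        .fold((0, 1), |(flat, stride), (&c, &s)| (flat + c * stride, stride * s))
        .0
}

fn unflatten_index(mut flat: usize, shape: &[usize]) -> Vec<usize> {
    let mut coord = vec![0; shape.len()];
    for (i, &s) in shape.iter().enumerate().rev() {
        coord[i] = flat % s; // coordinate along dim i
        flat /= s;
    }
    coord
}

// e.g. the entry at [1, 2] of a 3x4 tensor has flat index 6,
// and unflatten_index(6, &[2, 6]) == [1, 0] after a reshape to [2, 6].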
use crate::decorator::SparseCOO; use crate::decorator::SparseDecorator; -use burn_tensor::ops::FloatTensor; use burn_tensor::{ - backend::Backend, ElementConversion, Float, Int, Shape, Tensor, TensorData, TensorPrimitive, + backend::Backend, Bool, ElementConversion, Float, Int, Shape, Tensor, TensorData, + TensorPrimitive, }; #[derive(Clone, Debug)] @@ -360,7 +360,40 @@ where tensor: SparseTensor, axes: &[usize], ) -> SparseTensor { - todo!() + let SparseCOOTensor { + coordinates, + values, + shape, + } = tensor; + + let nnz = coordinates.shape().dims[1]; + let device = &coordinates.device(); + + let mut mask = [0; D]; + for &axis in axes { + mask[axis] = 1; + } + let mask: Tensor = Tensor::<_, 1, _>::from_ints(mask, device) + .unsqueeze_dim(1) + .repeat(1, nnz) + .bool(); + + let flipped: Tensor = Tensor::<_, 1, _>::from_ints(shape.dims, device) + .unsqueeze_dim(1) + .repeat(1, nnz) + .sub(coordinates.clone()) + .sub_scalar(1); + + // println!("mask: {}", mask); + // println!("flipped: {}", flipped); + + let coordinates = coordinates.mask_where(mask, flipped); + + SparseCOOTensor { + coordinates, + values, + shape, + } } fn sparse_slice_assign( @@ -368,6 +401,11 @@ where ranges: [std::ops::Range; D2], value: SparseTensor, ) -> SparseTensor { + let SparseCOOTensor { + coordinates, + values, + shape, + } = tensor; todo!() } @@ -376,6 +414,11 @@ where dim: usize, times: usize, ) -> SparseTensor { + let SparseCOOTensor { + coordinates, + values, + shape, + } = tensor; todo!() } @@ -403,6 +446,11 @@ where fn sparse_any( tensor: SparseTensor, ) -> burn_tensor::ops::BoolTensor { + let SparseCOOTensor { + coordinates, + values, + shape, + } = tensor; todo!() } @@ -410,12 +458,22 @@ where tensor: SparseTensor, dim: usize, ) -> burn_tensor::ops::BoolTensor { + let SparseCOOTensor { + coordinates, + values, + shape, + } = tensor; todo!() } fn sparse_all( tensor: SparseTensor, ) -> burn_tensor::ops::BoolTensor { + let SparseCOOTensor { + coordinates, + values, + shape, + } = tensor; todo!() } @@ -423,6 +481,11 @@ where tensor: SparseTensor, dim: usize, ) -> burn_tensor::ops::BoolTensor { + let SparseCOOTensor { + coordinates, + values, + shape, + } = tensor; todo!() } @@ -430,6 +493,11 @@ where tensor: SparseTensor, shape: Shape, ) -> SparseTensor { + let SparseCOOTensor { + coordinates, + values, + shape, + } = tensor; todo!() } } From c56c25d6efb32adc42350b72c1b23a69c56201a9 Mon Sep 17 00:00:00 2001 From: mcarthur Date: Sun, 14 Jul 2024 08:14:54 +0000 Subject: [PATCH 09/38] any, all, any_dim, all_dim --- .../burn-sparse/src/decorator/sparse_coo.rs | 23 +++++-------------- 1 file changed, 6 insertions(+), 17 deletions(-) diff --git a/crates/burn-sparse/src/decorator/sparse_coo.rs b/crates/burn-sparse/src/decorator/sparse_coo.rs index 41e8d4f7a3..89b5afc0b9 100644 --- a/crates/burn-sparse/src/decorator/sparse_coo.rs +++ b/crates/burn-sparse/src/decorator/sparse_coo.rs @@ -384,9 +384,6 @@ where .sub(coordinates.clone()) .sub_scalar(1); - // println!("mask: {}", mask); - // println!("flipped: {}", flipped); - let coordinates = coordinates.mask_where(mask, flipped); SparseCOOTensor { @@ -451,19 +448,15 @@ where values, shape, } = tensor; - todo!() + let any = coordinates.shape().dims[1] > 0; + Tensor::::from([any]).into_primitive() } fn sparse_any_dim( tensor: SparseTensor, dim: usize, ) -> burn_tensor::ops::BoolTensor { - let SparseCOOTensor { - coordinates, - values, - shape, - } = tensor; - todo!() + panic!("any_dim is unsupported for the SparseCOO Decorator due to performance issues, convert to 
dense first so the cost is explicit at the call site");
     }
 
     fn sparse_all(
         tensor: SparseTensor,
     ) -> burn_tensor::ops::BoolTensor {
         let SparseCOOTensor {
             coordinates,
             values,
             shape,
         } = tensor;
-        todo!()
+        let all = shape.num_elements() == coordinates.shape().dims[1];
+        Tensor::<B, 1, Bool>::from([all]).into_primitive()
     }
 
     fn sparse_all_dim(
         tensor: SparseTensor,
         dim: usize,
     ) -> burn_tensor::ops::BoolTensor {
-        let SparseCOOTensor {
-            coordinates,
-            values,
-            shape,
-        } = tensor;
-        todo!()
+        panic!("all_dim is unsupported for the SparseCOO Decorator due to performance issues, convert to dense first so the cost is explicit at the call site");
     }
 
     fn sparse_expand(

From e75181dbf334bca561f7422e2c3af9603248cd52 Mon Sep 17 00:00:00 2001
From: mcarthur
Date: Sun, 14 Jul 2024 09:47:26 +0000
Subject: [PATCH 10/38] repeat

---
 .../burn-sparse/src/decorator/sparse_coo.rs   | 29 +++++++++++++++++--
 1 file changed, 27 insertions(+), 2 deletions(-)

diff --git a/crates/burn-sparse/src/decorator/sparse_coo.rs b/crates/burn-sparse/src/decorator/sparse_coo.rs
index 89b5afc0b9..4b62ff29b5 100644
--- a/crates/burn-sparse/src/decorator/sparse_coo.rs
+++ b/crates/burn-sparse/src/decorator/sparse_coo.rs
@@ -414,9 +414,34 @@ where
         let SparseCOOTensor {
             coordinates,
             values,
-            shape,
+            mut shape,
         } = tensor;
-        todo!()
+
+        let device = coordinates.device();
+        let nnz = coordinates.shape().dims[1];
+
+        let values = values.repeat(0, times);
+
+        let coordinates_mask: Tensor = Tensor::zeros(coordinates.shape(), &device);
+        let ones: Tensor = Tensor::ones(Shape::new([1, nnz]), &device);
+        let coordinates_mask = coordinates_mask.slice_assign([dim..dim + 1, 0..nnz], ones);
+        let coordinates = Tensor::cat(
+            (0..times)
+                .map(|n| {
+                    coordinates.clone()
+                        + coordinates_mask.clone() * (n as i32) * (shape.dims[dim] as i32)
+                })
+                .collect::<Vec<_>>(),
+            1,
+        );
+
+        shape.dims[dim] *= times;
+
+        SparseCOOTensor {
+            coordinates,
+            values,
+            shape,
+        }
     }
 
     fn sparse_cat(

From cfe706bf72eec058df75f2e5a49f9382f65dbbb1 Mon Sep 17 00:00:00 2001
From: mcarthur
Date: Tue, 16 Jul 2024 11:28:45 +0000
Subject: [PATCH 11/38] coalesce, somewhat broken

---
 crates/burn-sparse/src/backend/api.rs         | 21 ++
 .../burn-sparse/src/backend/sparse_backend.rs |  8 +
 .../burn-sparse/src/decorator/sparse_coo.rs   | 268 +++++++++++++++---
 .../burn-sparse/src/decorator/sparse_csr.rs   | 14 +
 4 files changed, 278 insertions(+), 33 deletions(-)

diff --git a/crates/burn-sparse/src/backend/api.rs b/crates/burn-sparse/src/backend/api.rs
index 146f4bbce4..b4ad362f19 100644
--- a/crates/burn-sparse/src/backend/api.rs
+++ b/crates/burn-sparse/src/backend/api.rs
@@ -1,6 +1,10 @@
 use crate::backend::{Sparse, SparseBackend};
 use burn_tensor::{Int, Tensor, TensorPrimitive};
 
+pub enum CoalesceReduction {
+    Sum,
+}
+
 pub trait ToSparse
 where
     B: SparseBackend,
@@ -15,6 +19,9 @@ where
     fn dense_int(self) -> Tensor;
     fn spmm(self, rhs: Tensor) -> Tensor;
     fn dense(self) -> Tensor;
+    fn coalesce(self, reduce: CoalesceReduction) -> Tensor;
+    fn number_nonzero(self) -> usize;
+    fn density(self) -> usize;
 }
 
 impl ToSparse for Tensor
@@ -46,4 +53,18 @@ where
             rhs.into_primitive().tensor(),
         )))
     }
+
+    fn coalesce(self, reduction: CoalesceReduction) -> Tensor {
+        match reduction {
+            CoalesceReduction::Sum => Tensor::new(B::sparse_coalesce_sum(self.into_primitive())),
+        }
+    }
+
+    fn number_nonzero(self) -> usize {
+        B::sparse_nonzero(self.into_primitive())
+    }
+
+    fn density(self) -> usize {
+        B::sparse_density(self.into_primitive())
+    }
 }
diff --git a/crates/burn-sparse/src/backend/sparse_backend.rs b/crates/burn-sparse/src/backend/sparse_backend.rs
index 056de9b2de..c89e40e29b 100644
---
a/crates/burn-sparse/src/backend/sparse_backend.rs +++ b/crates/burn-sparse/src/backend/sparse_backend.rs @@ -28,6 +28,14 @@ pub trait SparseBackend: Backend { rhs: Self::FloatTensorPrimitive, ) -> Self::SparseTensorPrimitive; + fn sparse_coalesce_sum( + tensor: Self::SparseTensorPrimitive, + ) -> Self::SparseTensorPrimitive; + + fn sparse_nonzero(tensor: Self::SparseTensorPrimitive) -> usize; + + fn sparse_density(sparse: Self::SparseTensorPrimitive) -> usize; + /// Gets the element at the given indices. /// /// # Arguments diff --git a/crates/burn-sparse/src/decorator/sparse_coo.rs b/crates/burn-sparse/src/decorator/sparse_coo.rs index 4b62ff29b5..58b59f73c6 100644 --- a/crates/burn-sparse/src/decorator/sparse_coo.rs +++ b/crates/burn-sparse/src/decorator/sparse_coo.rs @@ -2,16 +2,20 @@ use crate::backend::SparseBackend; use crate::backend::SparseTensor; use crate::decorator::SparseCOO; use crate::decorator::SparseDecorator; +use burn_tensor::ops::IntTensorOps; +use burn_tensor::Device; use burn_tensor::{ backend::Backend, Bool, ElementConversion, Float, Int, Shape, Tensor, TensorData, TensorPrimitive, }; +use half::vec; #[derive(Clone, Debug)] pub struct SparseCOOTensor { - pub coordinates: Tensor, - pub values: Tensor, + pub coordinates: Option>, + pub values: Option>, pub shape: Shape, + pub device: Device, } impl SparseBackend for SparseDecorator @@ -26,8 +30,12 @@ where let dense: Tensor = Tensor::from_primitive(TensorPrimitive::Float(dense)); let shape = dense.shape(); + let device = dense.device(); let significant = dense.clone().not_equal_elem(0.0); + if !significant.clone().any().into_scalar() { + return Self::sparse_empty(dense.shape(), &device); + }; let coordinates = significant .clone() @@ -53,10 +61,14 @@ where .remove(0), ); + let coordinates = Some(coordinates); + let values = Some(values); + Self::SparseTensorPrimitive { coordinates, values, shape, + device, } } @@ -67,10 +79,16 @@ where coordinates, values, shape, + device, } = sparse; + let (Some(coordinates), Some(values)) = (coordinates, values) else { + return Tensor::::zeros(shape, &device) + .into_primitive() + .tensor(); + }; + let num_nonzero = coordinates.shape().dims[1]; - let device = coordinates.device(); let dense: Tensor = Tensor::zeros(Shape::new([shape.num_elements()]), &device); @@ -98,11 +116,21 @@ where coordinates, values, shape, + device, } = lhs; let rhs: Tensor = Tensor::from_primitive(TensorPrimitive::Float(rhs)); let rhs_shape = rhs.shape(); - let device = coordinates.device(); + let mut out_shape = shape.clone(); + out_shape.dims[D - 1] = rhs_shape.dims[D - 1]; + + let (Some(coordinates), Some(values)) = (coordinates, values) else { + // All zeros, exit early + return Tensor::::zeros(out_shape, &device) + .into_primitive() + .tensor(); + }; + let nnz = coordinates.shape().dims[1]; // Ensure they are of the correct shape to multiply @@ -115,9 +143,6 @@ where panic!("Batches must be of the same shape"); } - let mut out_shape = shape.clone(); - out_shape.dims[D - 1] = rhs_shape.dims[D - 1]; - // Compute strides for the dense tensor to match the flattened shape let mut strides_data = [1; D]; for i in (0..D - 1).rev() { @@ -172,7 +197,7 @@ where } fn sparse_device(tensor: &SparseTensor) -> burn_tensor::Device { - tensor.values.device() + tensor.device.clone() } fn sparse_to_device( @@ -180,9 +205,10 @@ where device: &burn_tensor::Device, ) -> SparseTensor { SparseCOOTensor { - coordinates: tensor.coordinates.to_device(device), - values: tensor.values.to_device(device), + coordinates: 
tensor.coordinates.map(|t| t.to_device(device)), + values: tensor.values.map(|t| t.to_device(device)), shape: tensor.shape, + device: device.clone(), } } @@ -197,29 +223,38 @@ where device: &burn_tensor::Device, ) -> SparseTensor { SparseCOOTensor { - coordinates: Tensor::from_primitive(B::int_empty( - burn_tensor::Shape::new([0, 0]), - &device, - )), - values: Tensor::from_primitive(TensorPrimitive::Float(B::float_empty( - burn_tensor::Shape::new([0]), - &device, - ))), + coordinates: None, + values: None, shape, + device: device.clone(), } } fn sparse_slice( tensor: Self::SparseTensorPrimitive, - indices: [std::ops::Range; D2], + indices: [core::ops::Range; D2], ) -> SparseTensor { let SparseCOOTensor { coordinates, values, shape, + device, } = tensor; - let device = coordinates.device(); + let indices: [core::ops::Range; D1] = + Vec::from(indices).try_into().expect("D1 should equal D2"); + let out_shape = Shape::new(indices.clone().map(|r| r.end)); + + let (Some(coordinates), Some(values)) = (coordinates, values) else { + // All zeros, exit early + return SparseCOOTensor { + coordinates: None, + values: None, + shape: out_shape, + device, + }; + }; + let number_nonzero = coordinates.shape().dims[1]; let mut mask: Tensor = Tensor::ones(Shape::new([number_nonzero]), &device); @@ -251,10 +286,14 @@ where let coordinates = coordinates.select(1, indices_dim1.clone()); let values = values.select(0, indices_dim1); + let coordinates = Some(coordinates); + let values = Some(values); + SparseCOOTensor { coordinates, values, shape, + device, } } @@ -268,7 +307,7 @@ where fn sparse_into_data( tensor: SparseTensor, - ) -> impl std::future::Future + Send { + ) -> impl core::future::Future + Send { B::float_into_data(Self::sparse_to_dense(tensor)) } @@ -280,9 +319,18 @@ where coordinates, values, shape, + device, } = tensor; - let device = coordinates.device(); + let (Some(coordinates), Some(values)) = (coordinates, values) else { + // All zeros, exit early + return SparseCOOTensor { + coordinates: None, + values: None, + shape: out_shape, + device, + }; + }; // Flatten the coordinates: let mut strides_data = [[1]; D1]; @@ -305,12 +353,16 @@ where } new_coordinates.reverse(); - let new_coordinates = Tensor::stack(new_coordinates, 0); + let coordinates = Tensor::stack(new_coordinates, 0); + + let coordinates = Some(coordinates); + let values = Some(values); SparseCOOTensor { - coordinates: new_coordinates, + coordinates, values, shape: out_shape, + device, } } @@ -340,6 +392,7 @@ where coordinates, values, mut shape, + device, } = tensor; for (i, &j) in (0..D).zip(axes).filter(|(i, j)| i < j) { @@ -347,12 +400,13 @@ where } let axes = Tensor::from(axes); - let coordinates = coordinates.select(0, axes); + let coordinates = coordinates.map(|coordinates| coordinates.select(0, axes)); SparseCOOTensor { coordinates, values, shape, + device, } } @@ -364,21 +418,31 @@ where coordinates, values, shape, + device, } = tensor; + let (Some(coordinates), Some(values)) = (coordinates, values) else { + // All zeros, exit early + return SparseCOOTensor { + coordinates: None, + values: None, + shape, + device, + }; + }; + let nnz = coordinates.shape().dims[1]; - let device = &coordinates.device(); let mut mask = [0; D]; for &axis in axes { mask[axis] = 1; } - let mask: Tensor = Tensor::<_, 1, _>::from_ints(mask, device) + let mask: Tensor = Tensor::<_, 1, _>::from_ints(mask, &device) .unsqueeze_dim(1) .repeat(1, nnz) .bool(); - let flipped: Tensor = Tensor::<_, 1, _>::from_ints(shape.dims, device) + let flipped: Tensor 
= Tensor::<_, 1, _>::from_ints(shape.dims, &device) .unsqueeze_dim(1) .repeat(1, nnz) .sub(coordinates.clone()) @@ -386,22 +450,27 @@ where let coordinates = coordinates.mask_where(mask, flipped); + let coordinates = Some(coordinates); + let values = Some(values); + SparseCOOTensor { coordinates, values, shape, + device, } } fn sparse_slice_assign( tensor: SparseTensor, - ranges: [std::ops::Range; D2], + ranges: [core::ops::Range; D2], value: SparseTensor, ) -> SparseTensor { let SparseCOOTensor { coordinates, values, shape, + device, } = tensor; todo!() } @@ -415,8 +484,22 @@ where coordinates, values, mut shape, + device, } = tensor; + let mut out_shape = shape.clone(); + out_shape.dims[dim] *= times; + + let (Some(coordinates), Some(values)) = (coordinates, values) else { + // All zeros, exit early + return SparseCOOTensor { + coordinates: None, + values: None, + shape, + device, + }; + }; + let device = coordinates.device(); let nnz = coordinates.shape().dims[1]; @@ -435,12 +518,14 @@ where 1, ); - shape.dims[dim] *= times; + let coordinates = Some(coordinates); + let values = Some(values); SparseCOOTensor { coordinates, values, - shape, + shape: out_shape, + device, } } @@ -472,8 +557,9 @@ where coordinates, values, shape, + device, } = tensor; - let any = coordinates.shape().dims[1] > 0; + let any = !matches!(coordinates, None); Tensor::::from([any]).into_primitive() } @@ -491,8 +577,12 @@ where coordinates, values, shape, + device, } = tensor; - let all = shape.num_elements() == coordinates.shape().dims[1]; + let all = match coordinates { + Some(coordinates) => shape.num_elements() == coordinates.shape().dims[1], + None => false, + }; Tensor::::from([all]).into_primitive() } @@ -511,7 +601,119 @@ where coordinates, values, shape, + device, } = tensor; todo!() } + + fn sparse_coalesce_sum( + tensor: Self::SparseTensorPrimitive, + ) -> Self::SparseTensorPrimitive { + if tensor.coordinates.as_ref().map(|c| c.shape().dims[1] <= 1) == Some(true) { + return tensor; + } + let original_shape = tensor.shape.clone(); + let SparseCOOTensor { + coordinates, + values, + shape, + device, + } = Self::sparse_reshape(tensor, Shape::new([original_shape.num_elements()])); + let (Some(coordinates), Some(values)) = (coordinates, values) else { + // All zeros, exit early + return Self::sparse_reshape( + SparseCOOTensor { + coordinates: None, + values: None, + shape, + device, + }, + original_shape, + ); + }; + + let nnz = coordinates.shape().dims[1]; + if nnz <= 1 { + // impossible to be uncoalesced + return SparseCOOTensor { + coordinates: Some(coordinates), + values: Some(values), + shape: original_shape, + device, + }; + } + + let (coordinates, indices) = coordinates.sort_with_indices(1); + let values = values.select(0, indices.squeeze(0)); + let range = Tensor::::arange(0..nnz as i64, &device).unsqueeze::<2>(); + + // Get the diff of coordinates, diff[i] = coordinates[i]-coordinates[i-1] + let left_slice = coordinates.clone().slice([0..1, 0..nnz - 1]); + let right_slice = coordinates.clone().slice([0..1, 1..nnz]); + let diff = right_slice - left_slice; + let ones = Tensor::::ones(Shape::new([1, 1]), &device); + let diff = Tensor::cat(vec![ones, diff], 1); + + // TODO this all would be way cleaner with cumsum/max, but that is waiting on a pull request as of writing + // this is technically O(nnz) but only in super rare and likely constructed cases + // lots of inspiration could be taken from pytorch_scatter for better implementations + let unique_mask = diff.not_equal_elem(0); + let unique_indices = 
unique_mask.clone().nonzero().remove(1); + let steps = Tensor::cat( + vec![unique_indices.clone(), Tensor::from_data([nnz], &device)], + 0, + ); + let unique = steps.shape().dims[0]; + let steps = steps + .clone() + .slice([1..unique]) + .sub(steps.slice([0..unique - 1])) + .max() + .sub_scalar(1) + .into_scalar() + .elem::(); + + let mut scatter_indices = range.mul(unique_mask.int()); + + for _ in 0..steps { + scatter_indices = scatter_indices + .clone() + .slice([0..1, 1..nnz]) + .max_pair(scatter_indices.slice([0..1, 0..nnz - 1])); + scatter_indices = Tensor::cat( + vec![Tensor::zeros(Shape::new([1, 1]), &device), scatter_indices], + 1, + ); + } + + // Scatter/Gather everything into place + let zeroed = Tensor::::zeros(Shape::new([nnz]), &device); + let values = zeroed.scatter(0, scatter_indices.squeeze(0), values); + let values = values.gather(0, unique_indices.clone()); + let coordinates = coordinates.gather(1, unique_indices.unsqueeze::<2>()); + + let coordinates = Some(coordinates); + let values = Some(values); + + // reshape back into the original shape and send it! + let out = SparseCOOTensor { + coordinates, + values, + shape, + device, + }; + + Self::sparse_reshape(out, original_shape) + } + + fn sparse_nonzero(tensor: Self::SparseTensorPrimitive) -> usize { + match tensor.coordinates { + Some(coordinates) => coordinates.shape().dims[1], + None => 0, + } + } + + fn sparse_density(sparse: Self::SparseTensorPrimitive) -> usize { + todo!() + } } diff --git a/crates/burn-sparse/src/decorator/sparse_csr.rs b/crates/burn-sparse/src/decorator/sparse_csr.rs index c0888eb46a..f03c33a0c7 100644 --- a/crates/burn-sparse/src/decorator/sparse_csr.rs +++ b/crates/burn-sparse/src/decorator/sparse_csr.rs @@ -186,4 +186,18 @@ where ) -> SparseTensor { todo!() } + + fn sparse_coalesce_sum( + tensor: Self::SparseTensorPrimitive, + ) -> Self::SparseTensorPrimitive { + todo!() + } + + fn sparse_nonzero(tensor: Self::SparseTensorPrimitive) -> usize { + todo!() + } + + fn sparse_density(sparse: Self::SparseTensorPrimitive) -> usize { + todo!() + } } From 57f6dbb94277e309f4e20720e35bc30e8246e11e Mon Sep 17 00:00:00 2001 From: mcarthur Date: Wed, 17 Jul 2024 02:03:12 +0000 Subject: [PATCH 12/38] fixed coalesce --- crates/burn-sparse/src/decorator/sparse_coo.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/burn-sparse/src/decorator/sparse_coo.rs b/crates/burn-sparse/src/decorator/sparse_coo.rs index 58b59f73c6..366b8f6b50 100644 --- a/crates/burn-sparse/src/decorator/sparse_coo.rs +++ b/crates/burn-sparse/src/decorator/sparse_coo.rs @@ -669,7 +669,7 @@ where .slice([1..unique]) .sub(steps.slice([0..unique - 1])) .max() - .sub_scalar(1) + // .sub_scalar(1) .into_scalar() .elem::(); From 9a8020801faf1961e4715a72ae9a3c98080f3fc5 Mon Sep 17 00:00:00 2001 From: mcarthur Date: Wed, 17 Jul 2024 02:29:46 +0000 Subject: [PATCH 13/38] fixed slice --- crates/burn-sparse/src/decorator/sparse_coo.rs | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/crates/burn-sparse/src/decorator/sparse_coo.rs b/crates/burn-sparse/src/decorator/sparse_coo.rs index 366b8f6b50..e4a51cc6fb 100644 --- a/crates/burn-sparse/src/decorator/sparse_coo.rs +++ b/crates/burn-sparse/src/decorator/sparse_coo.rs @@ -241,8 +241,9 @@ where device, } = tensor; - let indices: [core::ops::Range; D1] = - Vec::from(indices).try_into().expect("D1 should equal D2"); + let mut indices = Vec::from(indices); + indices.extend(shape.dims[indices.len()..D1].iter().map(|&l| 0..l)); + let indices: 
[core::ops::Range; D1] = indices.try_into().expect("D2 must be <= D1"); let out_shape = Shape::new(indices.clone().map(|r| r.end)); let (Some(coordinates), Some(values)) = (coordinates, values) else { @@ -292,7 +293,7 @@ where SparseCOOTensor { coordinates, values, - shape, + shape: out_shape, device, } } From 2e0abd28525d59ea1a613ec83a75e188a9068739 Mon Sep 17 00:00:00 2001 From: mcarthur Date: Wed, 17 Jul 2024 03:12:21 +0000 Subject: [PATCH 14/38] sparse density --- crates/burn-sparse/src/backend/api.rs | 4 ++-- crates/burn-sparse/src/backend/sparse_backend.rs | 2 +- crates/burn-sparse/src/decorator/sparse_coo.rs | 9 +++++++-- crates/burn-sparse/src/decorator/sparse_csr.rs | 2 +- 4 files changed, 11 insertions(+), 6 deletions(-) diff --git a/crates/burn-sparse/src/backend/api.rs b/crates/burn-sparse/src/backend/api.rs index b4ad362f19..8714f70166 100644 --- a/crates/burn-sparse/src/backend/api.rs +++ b/crates/burn-sparse/src/backend/api.rs @@ -21,7 +21,7 @@ where fn dense(self) -> Tensor; fn coalesce(self, reduce: CoalesceReduction) -> Tensor; fn number_nonzero(self) -> usize; - fn density(self) -> usize; + fn density(self) -> f32; } impl ToSparse for Tensor @@ -64,7 +64,7 @@ where B::sparse_nonzero(self.into_primitive()) } - fn density(self) -> usize { + fn density(self) -> f32 { B::sparse_density(self.into_primitive()) } } diff --git a/crates/burn-sparse/src/backend/sparse_backend.rs b/crates/burn-sparse/src/backend/sparse_backend.rs index c89e40e29b..e5d0b27e7a 100644 --- a/crates/burn-sparse/src/backend/sparse_backend.rs +++ b/crates/burn-sparse/src/backend/sparse_backend.rs @@ -34,7 +34,7 @@ pub trait SparseBackend: Backend { fn sparse_nonzero(tensor: Self::SparseTensorPrimitive) -> usize; - fn sparse_density(sparse: Self::SparseTensorPrimitive) -> usize; + fn sparse_density(sparse: Self::SparseTensorPrimitive) -> f32; /// Gets the element at the given indices. 
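With `density` now returning `f32`, the invariant is nnz divided by the number of elements. A hedged usage sketch of the extension traits from earlier in this series (assuming the `burn_sparse::backend` re-exports shown above and a backend `B: SparseBackend`; `check_density` is a hypothetical example function, not part of the patch):

use burn_sparse::backend::{SparseBackend, SparseTensorApi, ToSparse};
use burn_tensor::Tensor;

fn check_density<B: SparseBackend>(dense: Tensor<B, 2>) {
    let numel = dense.shape().num_elements() as f32;
    let sparse = dense.into_sparse();
    let nnz = sparse.clone().number_nonzero() as f32;
    // density == nnz / numel; an all-zero tensor reports 0.0.
    assert!((sparse.density() - nnz / numel).abs() < 1e-6);
}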
/// diff --git a/crates/burn-sparse/src/decorator/sparse_coo.rs b/crates/burn-sparse/src/decorator/sparse_coo.rs index e4a51cc6fb..eed32a8280 100644 --- a/crates/burn-sparse/src/decorator/sparse_coo.rs +++ b/crates/burn-sparse/src/decorator/sparse_coo.rs @@ -714,7 +714,12 @@ where } } - fn sparse_density(sparse: Self::SparseTensorPrimitive) -> usize { - todo!() + fn sparse_density(sparse: Self::SparseTensorPrimitive) -> f32 { + match sparse.coordinates { + Some(coordinates) => { + coordinates.shape().dims[1] as f32 / sparse.shape.num_elements() as f32 + } + None => 0.0, + } } } diff --git a/crates/burn-sparse/src/decorator/sparse_csr.rs b/crates/burn-sparse/src/decorator/sparse_csr.rs index f03c33a0c7..a445eb06e3 100644 --- a/crates/burn-sparse/src/decorator/sparse_csr.rs +++ b/crates/burn-sparse/src/decorator/sparse_csr.rs @@ -197,7 +197,7 @@ where todo!() } - fn sparse_density(sparse: Self::SparseTensorPrimitive) -> usize { + fn sparse_density(sparse: Self::SparseTensorPrimitive) -> f32 { todo!() } } From 60d8f67bb6f0b347ffe84d8591780ddb54505d00 Mon Sep 17 00:00:00 2001 From: mcarthur Date: Wed, 17 Jul 2024 04:30:58 +0000 Subject: [PATCH 15/38] numeric for sparse tensors, and add --- crates/burn-sparse/src/backend/kind.rs | 362 +++++++++++++++++- .../burn-sparse/src/backend/sparse_backend.rs | 186 ++++++++- .../burn-sparse/src/decorator/sparse_coo.rs | 125 ++++++ .../burn-sparse/src/decorator/sparse_csr.rs | 85 ++++ 4 files changed, 756 insertions(+), 2 deletions(-) diff --git a/crates/burn-sparse/src/backend/kind.rs b/crates/burn-sparse/src/backend/kind.rs index 67b1aba619..b3ddea1756 100644 --- a/crates/burn-sparse/src/backend/kind.rs +++ b/crates/burn-sparse/src/backend/kind.rs @@ -1,7 +1,7 @@ use std::{future::Future, ops::Range}; use crate::backend::SparseBackend; -use burn_tensor::{backend::Backend, BasicOps, Shape, Tensor, TensorData, TensorKind}; +use burn_tensor::{backend::Backend, BasicOps, Numeric, Shape, Tensor, TensorData, TensorKind}; /// A type-level representation of the kind of a sparse (float) tensor. 
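A `Numeric` impl for `Sparse` can only delegate the zero-preserving operations: for f(x) = k * x we have f(0) = 0, so the op may rewrite the stored values and leave the coordinates and nnz alone, whereas f(x) = x + k would turn every implicit zero into k and densify the result. A plain-slice sketch of the zero-preserving case (hypothetical helper, not burn API):

// mul_scalar on a COO tensor: coordinates and nnz are unchanged, only values move.
fn mul_scalar_values(values: &mut [f32], k: f32) {
    for v in values.iter_mut() {
        *v *= k;
    }
}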
#[derive(Clone, Debug)] @@ -153,3 +153,363 @@ impl BasicOps for Sparse { B::sparse_expand(tensor, shape) } } + +impl Numeric for Sparse { + fn add(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive { + B::sparse_add(lhs, rhs) + } + + fn add_scalar( + lhs: Self::Primitive, + rhs: E, + ) -> Self::Primitive { + B::sparse_add_scalar(lhs, rhs.elem()) + } + + fn sub(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive { + todo!() + } + + fn sub_scalar( + lhs: Self::Primitive, + rhs: E, + ) -> Self::Primitive { + todo!() + } + + fn div(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive { + todo!() + } + + fn div_scalar( + lhs: Self::Primitive, + rhs: E, + ) -> Self::Primitive { + todo!() + } + + fn remainder_scalar( + lhs: Self::Primitive, + rhs: E, + ) -> Self::Primitive { + todo!() + } + + fn mul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive { + todo!() + } + + fn mul_scalar( + lhs: Self::Primitive, + rhs: E, + ) -> Self::Primitive { + todo!() + } + + fn neg(tensor: Self::Primitive) -> Self::Primitive { + todo!() + } + + fn sign(tensor: Self::Primitive) -> Self::Primitive { + todo!() + } + + fn zeros( + shape: Shape, + device: &::Device, + ) -> Self::Primitive { + todo!() + } + + fn ones( + shape: Shape, + device: &::Device, + ) -> Self::Primitive { + todo!() + } + + fn full( + shape: Shape, + fill_value: E, + device: &::Device, + ) -> Self::Primitive { + todo!() + } + + fn sum(tensor: Self::Primitive) -> Self::Primitive<1> { + todo!() + } + + fn sum_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive { + todo!() + } + + fn prod(tensor: Self::Primitive) -> Self::Primitive<1> { + todo!() + } + + fn prod_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive { + todo!() + } + + fn mean(tensor: Self::Primitive) -> Self::Primitive<1> { + todo!() + } + + fn mean_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive { + todo!() + } + + fn equal_elem( + lhs: Self::Primitive, + rhs: Self::Elem, + ) -> Tensor { + todo!() + } + + fn not_equal_elem( + lhs: Self::Primitive, + rhs: Self::Elem, + ) -> Tensor { + todo!() + } + + fn greater( + lhs: Self::Primitive, + rhs: Self::Primitive, + ) -> Tensor { + todo!() + } + + fn greater_elem( + lhs: Self::Primitive, + rhs: Self::Elem, + ) -> Tensor { + todo!() + } + + fn greater_equal( + lhs: Self::Primitive, + rhs: Self::Primitive, + ) -> Tensor { + todo!() + } + + fn greater_equal_elem( + lhs: Self::Primitive, + rhs: Self::Elem, + ) -> Tensor { + todo!() + } + + fn lower( + lhs: Self::Primitive, + rhs: Self::Primitive, + ) -> Tensor { + todo!() + } + + fn lower_elem( + lhs: Self::Primitive, + rhs: Self::Elem, + ) -> Tensor { + todo!() + } + + fn lower_equal( + lhs: Self::Primitive, + rhs: Self::Primitive, + ) -> Tensor { + todo!() + } + + fn lower_equal_elem( + lhs: Self::Primitive, + rhs: Self::Elem, + ) -> Tensor { + todo!() + } + + fn mask_where( + tensor: Self::Primitive, + mask: Tensor, + source: Self::Primitive, + ) -> Self::Primitive { + todo!() + } + + fn mask_fill( + tensor: Self::Primitive, + mask: Tensor, + value: Self::Elem, + ) -> Self::Primitive { + todo!() + } + + fn gather( + dim: usize, + tensor: Self::Primitive, + indices: Tensor, + ) -> Self::Primitive { + todo!() + } + + fn scatter( + dim: usize, + tensor: Self::Primitive, + indices: Tensor, + values: Self::Primitive, + ) -> Self::Primitive { + todo!() + } + + fn select( + tensor: Self::Primitive, + dim: usize, + indices: Tensor, + ) -> Self::Primitive { + todo!() + } + + fn select_assign( + tensor: Self::Primitive, + dim: 
usize, + indices: Tensor, + values: Self::Primitive, + ) -> Self::Primitive { + todo!() + } + + fn argmax( + tensor: Self::Primitive, + dim: usize, + ) -> ::IntTensorPrimitive { + todo!() + } + + fn argmin( + tensor: Self::Primitive, + dim: usize, + ) -> ::IntTensorPrimitive { + todo!() + } + + fn max(tensor: Self::Primitive) -> Self::Primitive<1> { + todo!() + } + + fn max_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive { + todo!() + } + + fn max_dim_with_indices( + tensor: Self::Primitive, + dim: usize, + ) -> (Self::Primitive, ::IntTensorPrimitive) { + todo!() + } + + fn min(tensor: Self::Primitive) -> Self::Primitive<1> { + todo!() + } + + fn min_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive { + todo!() + } + + fn min_dim_with_indices( + tensor: Self::Primitive, + dim: usize, + ) -> (Self::Primitive, ::IntTensorPrimitive) { + todo!() + } + + fn clamp( + tensor: Self::Primitive, + min: Self::Elem, + max: Self::Elem, + ) -> Self::Primitive { + todo!() + } + + fn clamp_min( + tensor: Self::Primitive, + min: Self::Elem, + ) -> Self::Primitive { + todo!() + } + + fn clamp_max( + tensor: Self::Primitive, + max: Self::Elem, + ) -> Self::Primitive { + todo!() + } + + fn abs(tensor: Self::Primitive) -> Self::Primitive { + todo!() + } + + fn powf( + lhs: Self::Primitive, + rhs: Self::Primitive, + ) -> Self::Primitive { + todo!() + } + + fn powi( + lhs: Self::Primitive, + rhs: Self::Primitive, + ) -> Self::Primitive { + todo!() + } + + fn powf_scalar( + lhs: Self::Primitive, + rhs: E, + ) -> Self::Primitive { + todo!() + } + + fn powi_scalar( + lhs: Self::Primitive, + rhs: E, + ) -> Self::Primitive { + todo!() + } + + fn random( + shape: Shape, + distribution: burn_tensor::Distribution, + device: &::Device, + ) -> Self::Primitive { + todo!() + } + + fn sort( + tensor: Self::Primitive, + dim: usize, + descending: bool, + ) -> Self::Primitive { + todo!() + } + + fn sort_with_indices( + tensor: Self::Primitive, + dim: usize, + descending: bool, + ) -> ( + Self::Primitive, + >::Primitive, + ) { + todo!() + } + + fn argsort( + tensor: Self::Primitive, + dim: usize, + descending: bool, + ) -> >::Primitive { + todo!() + } +} diff --git a/crates/burn-sparse/src/backend/sparse_backend.rs b/crates/burn-sparse/src/backend/sparse_backend.rs index e5d0b27e7a..533c0c27fc 100644 --- a/crates/burn-sparse/src/backend/sparse_backend.rs +++ b/crates/burn-sparse/src/backend/sparse_backend.rs @@ -1,5 +1,9 @@ use crate::backend::SparseTensor; -use burn_tensor::{backend::Backend, ops::BoolTensor, Device, Shape, TensorData}; +use burn_tensor::{ + backend::Backend, + ops::{BoolTensor, FloatElem, FloatTensor}, + Device, Shape, TensorData, +}; use core::{future::Future, ops::Range}; pub trait SparseBackend: Backend { @@ -184,4 +188,184 @@ pub trait SparseBackend: Backend { tensor: SparseTensor, shape: Shape, ) -> SparseTensor; + + /// Adds two sparse tensors together. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side tensor. + /// + /// # Returns + /// + /// The result of adding the two tensors together. + fn sparse_add( + lhs: SparseTensor, + rhs: SparseTensor, + ) -> SparseTensor; + + /// Adds a sparse and dense tensor together. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side tensor. + /// + /// # Returns + /// + /// The result of adding the two tensors together. 
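Sparse plus dense produces a dense result, and conceptually it is a scatter-add of the stored values into the dense buffer at the flattened coordinates. A minimal sketch over plain slices (hypothetical helper, independent of the backend primitives), assuming `flat_coords` came from row-major flattening:

fn add_coo_into_dense(flat_coords: &[usize], values: &[f32], dense: &mut [f32]) {
    for (&i, &v) in flat_coords.iter().zip(values) {
        dense[i] += v; // duplicate coordinates accumulate, matching an additive scatter
    }
}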
+ fn sparse_add_dense( + lhs: SparseTensor, + rhs: FloatTensor, + ) -> SparseTensor; + + /// Adds a scalar to a tensor. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side scalar. + /// + /// # Returns + /// + /// The result of adding the scalar to the tensor. + fn sparse_add_scalar( + lhs: SparseTensor, + rhs: FloatElem, + ) -> SparseTensor; + + /// Subtracts two tensors. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side tensor. + /// + /// # Returns + /// + /// The result of subtracting the two tensors. + fn sparse_sub( + lhs: SparseTensor, + rhs: SparseTensor, + ) -> SparseTensor; + + /// Subtracts a dense from a sparse tensor. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor (sparse). + /// * `rhs` - The right hand side tensor (dense). + /// + /// # Returns + /// + /// The result of subtracting the two tensors. + fn sparse_sub_dense( + lhs: SparseTensor, + rhs: FloatTensor, + ) -> SparseTensor; + + /// Subtracts a scalar from a tensor. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side scalar. + /// + /// # Returns + /// + /// The result of subtracting the scalar from the tensor. + fn sparse_sub_scalar( + lhs: SparseTensor, + rhs: FloatElem, + ) -> SparseTensor; + + /// Multiplies two sparse tensors together. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side tensor. + /// + /// # Returns + /// + /// The result of multiplying the two tensors together. + fn sparse_mul( + lhs: SparseTensor, + rhs: SparseTensor, + ) -> SparseTensor; + + /// Multiplies a sparse and dense tensor together. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side tensor. + /// + /// # Returns + /// + /// The result of multiplying the two tensors together. + fn sparse_mul_dense( + lhs: SparseTensor, + rhs: FloatTensor, + ) -> SparseTensor; + + /// Multiplies a scalar to a tensor. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side scalar. + /// + /// # Returns + /// + /// The result of multiplying the scalar with the tensor. + fn sparse_mul_scalar( + lhs: SparseTensor, + rhs: FloatElem, + ) -> SparseTensor; + + /// Divides two sparse tensors. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side tensor. + /// + /// # Returns + /// + /// The result of dividing the two tensors. + fn sparse_div( + lhs: SparseTensor, + rhs: SparseTensor, + ) -> SparseTensor; + + /// Divides a sparse and dense tensor. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side tensor. + /// + /// # Returns + /// + /// The result of dividing the two tensors. + fn sparse_div_dense( + lhs: SparseTensor, + rhs: FloatTensor, + ) -> SparseTensor; + + /// Divides a tensor by a scalar. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side scalar. + /// + /// # Returns + /// + /// The result of dividing the tensor by the scalar. 
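Among these, subtraction needs no kernel of its own: negation is zero-preserving (it only flips the stored values), so `a - b` reduces to `a + (-1 * b)` in both the sparse-sparse and sparse-dense cases. A plain-slice sketch (hypothetical helper):

// Negating a COO tensor rescales the stored values; coordinates are untouched,
// so subtraction can reuse the addition kernels with a negated right-hand side.
fn negate_values(values: &mut [f32]) {
    for v in values.iter_mut() {
        *v = -*v;
    }
}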
+ fn sparse_div_scalar( + lhs: SparseTensor, + rhs: FloatElem, + ) -> SparseTensor; } diff --git a/crates/burn-sparse/src/decorator/sparse_coo.rs b/crates/burn-sparse/src/decorator/sparse_coo.rs index eed32a8280..6393afea6f 100644 --- a/crates/burn-sparse/src/decorator/sparse_coo.rs +++ b/crates/burn-sparse/src/decorator/sparse_coo.rs @@ -2,6 +2,7 @@ use crate::backend::SparseBackend; use crate::backend::SparseTensor; use crate::decorator::SparseCOO; use crate::decorator::SparseDecorator; +use burn_tensor::ops::FloatElem; use burn_tensor::ops::IntTensorOps; use burn_tensor::Device; use burn_tensor::{ @@ -722,4 +723,128 @@ where None => 0.0, } } + + fn sparse_add( + lhs: SparseTensor, + rhs: SparseTensor, + ) -> SparseTensor { + let SparseCOOTensor { + coordinates: lhs_coordinates, + values: lhs_values, + shape: lhs_shape, + device: lhs_device, + } = lhs; + let (Some(lhs_coordinates), Some(lhs_values)) = (lhs_coordinates, lhs_values) else { + return rhs; + }; + + let SparseCOOTensor { + coordinates: rhs_coordinates, + values: rhs_values, + shape: rhs_shape, + device: rhs_device, + } = rhs; + let (Some(rhs_coordinates), Some(rhs_values)) = (rhs_coordinates, rhs_values) else { + return SparseCOOTensor { + coordinates: Some(lhs_coordinates), + values: Some(lhs_values), + shape: lhs_shape, + device: lhs_device, + }; + }; + + assert_eq!(lhs_shape, rhs_shape); + assert_eq!(lhs_device, rhs_device); + + let coordinates = Some(Tensor::cat(vec![lhs_coordinates, rhs_coordinates], 1)); + let values = Some(Tensor::cat(vec![lhs_values, rhs_values], 0)); + let shape = lhs_shape; + let device = lhs_device; + + let result = SparseCOOTensor { + coordinates, + values, + shape, + device, + }; + + Self::sparse_coalesce_sum(result) + } + + fn sparse_add_scalar( + lhs: SparseTensor, + rhs: FloatElem, + ) -> SparseTensor { + todo!() + } + + fn sparse_add_dense( + lhs: SparseTensor, + rhs: burn_tensor::ops::FloatTensor, + ) -> SparseTensor { + todo!() + } + + fn sparse_sub( + lhs: SparseTensor, + rhs: SparseTensor, + ) -> SparseTensor { + todo!() + } + + fn sparse_sub_dense( + lhs: SparseTensor, + rhs: burn_tensor::ops::FloatTensor, + ) -> SparseTensor { + todo!() + } + + fn sparse_sub_scalar( + lhs: SparseTensor, + rhs: FloatElem, + ) -> SparseTensor { + todo!() + } + + fn sparse_mul( + lhs: SparseTensor, + rhs: SparseTensor, + ) -> SparseTensor { + todo!() + } + + fn sparse_mul_dense( + lhs: SparseTensor, + rhs: burn_tensor::ops::FloatTensor, + ) -> SparseTensor { + todo!() + } + + fn sparse_mul_scalar( + lhs: SparseTensor, + rhs: FloatElem, + ) -> SparseTensor { + todo!() + } + + fn sparse_div( + lhs: SparseTensor, + rhs: SparseTensor, + ) -> SparseTensor { + todo!() + } + + fn sparse_div_dense( + lhs: SparseTensor, + rhs: burn_tensor::ops::FloatTensor, + ) -> SparseTensor { + todo!() + } + + fn sparse_div_scalar( + lhs: SparseTensor, + rhs: FloatElem, + ) -> SparseTensor { + todo!() + } } diff --git a/crates/burn-sparse/src/decorator/sparse_csr.rs b/crates/burn-sparse/src/decorator/sparse_csr.rs index a445eb06e3..66a44e6125 100644 --- a/crates/burn-sparse/src/decorator/sparse_csr.rs +++ b/crates/burn-sparse/src/decorator/sparse_csr.rs @@ -3,6 +3,7 @@ use crate::backend::SparseTensor; use crate::decorator::SparseCSR; use crate::decorator::SparseDecorator; use burn_tensor::backend::Backend; +use burn_tensor::ops::FloatElem; use core::marker::PhantomData; #[derive(Debug, Default, Clone)] @@ -200,4 +201,88 @@ where fn sparse_density(sparse: Self::SparseTensorPrimitive) -> f32 { todo!() } + + fn sparse_add( + lhs: 
SparseTensor, + rhs: SparseTensor, + ) -> SparseTensor { + todo!() + } + + fn sparse_add_scalar( + lhs: SparseTensor, + rhs: FloatElem, + ) -> SparseTensor { + todo!() + } + + fn sparse_add_dense( + lhs: SparseTensor, + rhs: burn_tensor::ops::FloatTensor, + ) -> SparseTensor { + todo!() + } + + fn sparse_sub( + lhs: SparseTensor, + rhs: SparseTensor, + ) -> SparseTensor { + todo!() + } + + fn sparse_sub_dense( + lhs: SparseTensor, + rhs: burn_tensor::ops::FloatTensor, + ) -> SparseTensor { + todo!() + } + + fn sparse_sub_scalar( + lhs: SparseTensor, + rhs: FloatElem, + ) -> SparseTensor { + todo!() + } + + fn sparse_mul( + lhs: SparseTensor, + rhs: SparseTensor, + ) -> SparseTensor { + todo!() + } + + fn sparse_mul_dense( + lhs: SparseTensor, + rhs: burn_tensor::ops::FloatTensor, + ) -> SparseTensor { + todo!() + } + + fn sparse_mul_scalar( + lhs: SparseTensor, + rhs: FloatElem, + ) -> SparseTensor { + todo!() + } + + fn sparse_div( + lhs: SparseTensor, + rhs: SparseTensor, + ) -> SparseTensor { + todo!() + } + + fn sparse_div_dense( + lhs: SparseTensor, + rhs: burn_tensor::ops::FloatTensor, + ) -> SparseTensor { + todo!() + } + + fn sparse_div_scalar( + lhs: SparseTensor, + rhs: FloatElem, + ) -> SparseTensor { + todo!() + } } From 1d9f85649f366dfe92c8b284f28bcd89c58bd8b4 Mon Sep 17 00:00:00 2001 From: mcarthur Date: Sat, 20 Jul 2024 04:12:54 +0000 Subject: [PATCH 16/38] add, sub, mul, div and some refactors --- crates/burn-sparse/src/backend/api.rs | 16 + crates/burn-sparse/src/backend/kind.rs | 18 +- .../burn-sparse/src/backend/sparse_backend.rs | 8 +- .../burn-sparse/src/decorator/sparse_coo.rs | 283 +++++++++++------- .../burn-sparse/src/decorator/sparse_csr.rs | 9 +- 5 files changed, 202 insertions(+), 132 deletions(-) diff --git a/crates/burn-sparse/src/backend/api.rs b/crates/burn-sparse/src/backend/api.rs index 8714f70166..da7ab4159f 100644 --- a/crates/burn-sparse/src/backend/api.rs +++ b/crates/burn-sparse/src/backend/api.rs @@ -22,6 +22,8 @@ where fn coalesce(self, reduce: CoalesceReduction) -> Tensor; fn number_nonzero(self) -> usize; fn density(self) -> f32; + fn add_dense(self, rhs: Tensor) -> Tensor; + fn mul_dense(self, rhs: Tensor) -> Tensor; } impl ToSparse for Tensor @@ -67,4 +69,18 @@ where fn density(self) -> f32 { B::sparse_density(self.into_primitive()) } + + fn add_dense(self, rhs: Tensor) -> Tensor { + Tensor::new(TensorPrimitive::Float(B::sparse_add_dense( + self.into_primitive(), + rhs.into_primitive().tensor(), + ))) + } + + fn mul_dense(self, rhs: Tensor) -> Tensor { + Tensor::new(TensorPrimitive::Float(B::sparse_mul_dense( + self.into_primitive(), + rhs.into_primitive().tensor(), + ))) + } } diff --git a/crates/burn-sparse/src/backend/kind.rs b/crates/burn-sparse/src/backend/kind.rs index b3ddea1756..f0a78d4279 100644 --- a/crates/burn-sparse/src/backend/kind.rs +++ b/crates/burn-sparse/src/backend/kind.rs @@ -167,25 +167,25 @@ impl Numeric for Sparse { } fn sub(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive { - todo!() + B::sparse_sub(lhs, rhs) } fn sub_scalar( lhs: Self::Primitive, rhs: E, ) -> Self::Primitive { - todo!() + B::sparse_sub_scalar(lhs, rhs.elem()) } fn div(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive { - todo!() + B::sparse_div(lhs, rhs) } fn div_scalar( lhs: Self::Primitive, rhs: E, ) -> Self::Primitive { - todo!() + B::sparse_div_scalar(lhs, rhs.elem()) } fn remainder_scalar( @@ -196,14 +196,14 @@ impl Numeric for Sparse { } fn mul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive { - todo!() + 
B::sparse_mul(lhs, rhs) } fn mul_scalar( lhs: Self::Primitive, rhs: E, ) -> Self::Primitive { - todo!() + B::sparse_mul_scalar(lhs, rhs.elem()) } fn neg(tensor: Self::Primitive) -> Self::Primitive { @@ -218,14 +218,14 @@ impl Numeric for Sparse { shape: Shape, device: &::Device, ) -> Self::Primitive { - todo!() + B::sparse_empty(shape, device) } fn ones( shape: Shape, device: &::Device, ) -> Self::Primitive { - todo!() + B::sparse_to_sparse(B::float_ones(shape, device)) } fn full( @@ -233,7 +233,7 @@ impl Numeric for Sparse { fill_value: E, device: &::Device, ) -> Self::Primitive { - todo!() + B::sparse_to_sparse(B::float_full(shape, fill_value.elem(), device)) } fn sum(tensor: Self::Primitive) -> Self::Primitive<1> { diff --git a/crates/burn-sparse/src/backend/sparse_backend.rs b/crates/burn-sparse/src/backend/sparse_backend.rs index 533c0c27fc..3de056a038 100644 --- a/crates/burn-sparse/src/backend/sparse_backend.rs +++ b/crates/burn-sparse/src/backend/sparse_backend.rs @@ -217,7 +217,7 @@ pub trait SparseBackend: Backend { fn sparse_add_dense( lhs: SparseTensor, rhs: FloatTensor, - ) -> SparseTensor; + ) -> FloatTensor; /// Adds a scalar to a tensor. /// @@ -262,7 +262,7 @@ pub trait SparseBackend: Backend { fn sparse_sub_dense( lhs: SparseTensor, rhs: FloatTensor, - ) -> SparseTensor; + ) -> FloatTensor; /// Subtracts a scalar from a tensor. /// @@ -307,7 +307,7 @@ pub trait SparseBackend: Backend { fn sparse_mul_dense( lhs: SparseTensor, rhs: FloatTensor, - ) -> SparseTensor; + ) -> FloatTensor; /// Multiplies a scalar to a tensor. /// @@ -352,7 +352,7 @@ pub trait SparseBackend: Backend { fn sparse_div_dense( lhs: SparseTensor, rhs: FloatTensor, - ) -> SparseTensor; + ) -> FloatTensor; /// Divides a tensor by a scalar. /// diff --git a/crates/burn-sparse/src/decorator/sparse_coo.rs b/crates/burn-sparse/src/decorator/sparse_coo.rs index 6393afea6f..25f5b536a4 100644 --- a/crates/burn-sparse/src/decorator/sparse_coo.rs +++ b/crates/burn-sparse/src/decorator/sparse_coo.rs @@ -1,8 +1,12 @@ +use std::ops::Mul; + use crate::backend::SparseBackend; use crate::backend::SparseTensor; use crate::decorator::SparseCOO; use crate::decorator::SparseDecorator; use burn_tensor::ops::FloatElem; +use burn_tensor::ops::FloatTensor; +use burn_tensor::ops::FloatTensorOps; use burn_tensor::ops::IntTensorOps; use burn_tensor::Device; use burn_tensor::{ @@ -19,6 +23,43 @@ pub struct SparseCOOTensor { pub device: Device, } +fn flatten_coordinates( + coordinates: Tensor, + shape: Shape, + device: &Device, +) -> Tensor { + let mut strides_data = [[1]; D]; + for i in (0..D - 1).rev() { + strides_data[i] = [strides_data[i + 1][0] * shape.dims[i + 1] as i64]; + } + let strides_data: TensorData = TensorData::from(strides_data); + let strides: Tensor = Tensor::from_data(strides_data, device); + let flat_coordinates: Tensor = strides.mul(coordinates).sum_dim(0).flatten(0, 1); + + flat_coordinates.unsqueeze_dim(0) +} + +fn unflatten_coordinates( + flat_coordinates: Tensor, + new_shape: Shape, +) -> Tensor { + let flat_coordinates = flat_coordinates.squeeze::<1>(0); + let mut remaining_flat_coordinates = flat_coordinates.clone(); + let mut new_coordinates = Vec::with_capacity(D); + + for &dim_size in new_shape.dims.iter().rev() { + let size = dim_size as i64; + let new_coord = remaining_flat_coordinates.clone().remainder_scalar(size); + new_coordinates.push(new_coord.clone()); + remaining_flat_coordinates = remaining_flat_coordinates.div_scalar(size); + } + + new_coordinates.reverse(); + let reshaped_coordinates = 
Tensor::stack(new_coordinates, 0); + + reshaped_coordinates +} + impl SparseBackend for SparseDecorator where B: Backend, @@ -89,22 +130,9 @@ where .tensor(); }; - let num_nonzero = coordinates.shape().dims[1]; - let dense: Tensor = Tensor::zeros(Shape::new([shape.num_elements()]), &device); - - let mut strides_data = [[1]; D]; - for i in (0..D - 1).rev() { - strides_data[i] = [strides_data[i + 1][0] * shape.dims[i + 1] as i64]; - } - - let strides_data: TensorData = TensorData::from(strides_data); - - let strides: Tensor = Tensor::from_data(strides_data, &device); - - let coordinates = strides.mul(coordinates).sum_dim(0).flatten(0, 1); - - let dense = dense.select_assign(0, coordinates, values); + let flat_coordinates = flatten_coordinates(coordinates, shape.clone(), &device).squeeze(0); + let dense = dense.select_assign(0, flat_coordinates, values); dense.reshape(shape).into_primitive().tensor() } @@ -235,27 +263,27 @@ where tensor: Self::SparseTensorPrimitive, indices: [core::ops::Range; D2], ) -> SparseTensor { - let SparseCOOTensor { - coordinates, - values, - shape, - device, - } = tensor; - let mut indices = Vec::from(indices); - indices.extend(shape.dims[indices.len()..D1].iter().map(|&l| 0..l)); + indices.extend(tensor.shape.dims[indices.len()..D1].iter().map(|&l| 0..l)); let indices: [core::ops::Range; D1] = indices.try_into().expect("D2 must be <= D1"); let out_shape = Shape::new(indices.clone().map(|r| r.end)); - let (Some(coordinates), Some(values)) = (coordinates, values) else { - // All zeros, exit early + if tensor.coordinates.is_none() && tensor.values.is_none() { return SparseCOOTensor { coordinates: None, values: None, shape: out_shape, - device, + device: tensor.device, }; - }; + } + + let coordinates = tensor + .coordinates + .expect("Mismatch between coordinates and values"); + let values = tensor + .values + .expect("Mismatch between coordinates and values"); + let device = tensor.device; let number_nonzero = coordinates.shape().dims[1]; @@ -314,55 +342,36 @@ where } fn sparse_reshape( - tensor: SparseTensor, + tensor: SparseCOOTensor, out_shape: Shape, - ) -> SparseTensor { - let SparseCOOTensor { - coordinates, - values, - shape, - device, - } = tensor; - - let (Some(coordinates), Some(values)) = (coordinates, values) else { - // All zeros, exit early + ) -> SparseCOOTensor { + if tensor.coordinates.is_none() && tensor.values.is_none() { return SparseCOOTensor { coordinates: None, values: None, shape: out_shape, - device, + device: tensor.device, }; - }; - - // Flatten the coordinates: - let mut strides_data = [[1]; D1]; - for i in (0..D1 - 1).rev() { - strides_data[i] = [strides_data[i + 1][0] * shape.dims[i + 1] as i64]; - } - let strides_data: TensorData = TensorData::from(strides_data); - let strides: Tensor = Tensor::from_data(strides_data, &device); - let flat_coordinates: Tensor = strides.mul(coordinates).sum_dim(0).flatten(0, 1); - - // Convert the flattened coordinates to the new shape - let mut remaining_flat_coordinates = flat_coordinates.clone(); - let mut new_coordinates = Vec::with_capacity(D2); - - for &dim_size in out_shape.dims.iter().rev() { - let size = dim_size as i64; - let new_coord = remaining_flat_coordinates.clone().remainder_scalar(size); - new_coordinates.push(new_coord.clone()); - remaining_flat_coordinates = remaining_flat_coordinates.div_scalar(size); } - new_coordinates.reverse(); - let coordinates = Tensor::stack(new_coordinates, 0); + let coordinates = tensor + .coordinates + .expect("Mismatch between coordinates and values"); + let 
values = tensor + .values + .expect("Mismatch between coordinates and values"); + let shape = tensor.shape; + let device = tensor.device; - let coordinates = Some(coordinates); - let values = Some(values); + // Flatten the coordinates + let flat_coordinates = flatten_coordinates(coordinates, shape, &device); + + // Unflatten the coordinates to the new shape + let new_coordinates = unflatten_coordinates(flat_coordinates, out_shape.clone()); SparseCOOTensor { - coordinates, - values, + coordinates: Some(new_coordinates), + values: Some(values), shape: out_shape, device, } @@ -615,36 +624,28 @@ where return tensor; } let original_shape = tensor.shape.clone(); - let SparseCOOTensor { - coordinates, - values, - shape, - device, - } = Self::sparse_reshape(tensor, Shape::new([original_shape.num_elements()])); - let (Some(coordinates), Some(values)) = (coordinates, values) else { - // All zeros, exit early - return Self::sparse_reshape( - SparseCOOTensor { - coordinates: None, - values: None, - shape, - device, - }, - original_shape, - ); - }; - let nnz = coordinates.shape().dims[1]; - if nnz <= 1 { - // impossible to be uncoalesced + if tensor.coordinates.is_none() && tensor.values.is_none() { return SparseCOOTensor { - coordinates: Some(coordinates), - values: Some(values), + coordinates: None, + values: None, shape: original_shape, - device, + device: tensor.device, }; } + let coordinates = tensor + .coordinates + .expect("Mismatch between coordinates and values"); + let values = tensor + .values + .expect("Mismatch between coordinates and values"); + let device = tensor.device; + let nnz = coordinates.shape().dims[1]; + + let coordinates = flatten_coordinates(coordinates, original_shape.clone(), &device); + let flat_shape = Shape::new([original_shape.num_elements()]); + let (coordinates, indices) = coordinates.sort_with_indices(1); let values = values.select(0, indices.squeeze(0)); let range = Tensor::::arange(0..nnz as i64, &device).unsqueeze::<2>(); @@ -693,19 +694,18 @@ where let values = zeroed.scatter(0, scatter_indices.squeeze(0), values); let values = values.gather(0, unique_indices.clone()); let coordinates = coordinates.gather(1, unique_indices.unsqueeze::<2>()); + let coordinates = unflatten_coordinates(coordinates, original_shape.clone()); let coordinates = Some(coordinates); let values = Some(values); // reshape back into the original shape and send it! 
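Aside: the coalesce path above relies on the same row-major index arithmetic that `flatten_coordinates` and `unflatten_coordinates` implement on whole coordinate tensors. A minimal standalone sketch of that arithmetic (hypothetical helper names, not code from this patch):

    // Row-major flattening: the stride of dim i is the product of the
    // trailing dim sizes, so for shape [d0, d1, d2]:
    //   flat = c0 * d1 * d2 + c1 * d2 + c2
    fn flatten_index(coord: &[usize], shape: &[usize]) -> usize {
        coord.iter().zip(shape).fold(0, |flat, (&c, &d)| flat * d + c)
    }

    // Unflattening is repeated divmod by the dim sizes, back to front.
    fn unflatten_index(mut flat: usize, shape: &[usize]) -> Vec<usize> {
        let mut coord = vec![0; shape.len()];
        for (i, &d) in shape.iter().enumerate().rev() {
            coord[i] = flat % d;
            flat /= d;
        }
        coord
    }

For shape [2, 3], coordinate [1, 2] flattens to 1 * 3 + 2 = 5 and unflattens back to [1, 2]; sorting these flat keys is what brings duplicate coordinates next to each other before the scatter.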
- let out = SparseCOOTensor { + SparseCOOTensor { coordinates, values, - shape, + shape: original_shape, device, - }; - - Self::sparse_reshape(out, original_shape) + } } fn sparse_nonzero(tensor: Self::SparseTensorPrimitive) -> usize { @@ -772,79 +772,132 @@ where } fn sparse_add_scalar( - lhs: SparseTensor, - rhs: FloatElem, + _: SparseTensor, + _: FloatElem, ) -> SparseTensor { - todo!() + panic!("Cannot add scalar to sparse, only zero preserving operations are permitted"); } fn sparse_add_dense( lhs: SparseTensor, rhs: burn_tensor::ops::FloatTensor, - ) -> SparseTensor { - todo!() + ) -> FloatTensor { + if lhs.shape != B::float_shape(&rhs) { + panic!("lhs and rhs must have the same shape for sparse_add_dense"); + } + + if lhs.coordinates.is_none() && lhs.values.is_none() { + return rhs; + } + + let coordinates = lhs + .coordinates + .expect("Mismatch between coordinates and values"); + let values = lhs.values.expect("Mismatch between coordinates and values"); + let device = lhs.device; + let shape = lhs.shape; + + let coordinates = flatten_coordinates(coordinates, shape.clone(), &device); + let dense = B::float_reshape(rhs, Shape::new([shape.num_elements()])); + + let dense = B::float_scatter( + 0, + dense, + coordinates.squeeze(0).into_primitive(), + values.into_primitive().tensor(), + ); + + B::float_reshape(dense, shape) } fn sparse_sub( lhs: SparseTensor, rhs: SparseTensor, ) -> SparseTensor { - todo!() + Self::sparse_add( + lhs, + Self::sparse_mul_scalar(rhs, FloatElem::::from_elem(-1.0)), + ) } fn sparse_sub_dense( lhs: SparseTensor, rhs: burn_tensor::ops::FloatTensor, - ) -> SparseTensor { - todo!() + ) -> FloatTensor { + Self::sparse_add_dense( + lhs, + B::float_mul_scalar(rhs, FloatElem::::from_elem(-1.0)), + ) } fn sparse_sub_scalar( lhs: SparseTensor, rhs: FloatElem, ) -> SparseTensor { - todo!() + panic!("Cannot subtract scalar from sparse, only zero preserving operations are permitted"); } fn sparse_mul( lhs: SparseTensor, rhs: SparseTensor, ) -> SparseTensor { - todo!() + panic!("sparse_mul is unsupported until scatter supports multiplication based reduction"); } fn sparse_mul_dense( lhs: SparseTensor, rhs: burn_tensor::ops::FloatTensor, - ) -> SparseTensor { - todo!() + ) -> FloatTensor { + if lhs.shape != B::float_shape(&rhs) { + panic!("lhs and rhs must have the same shape for sparse_mul_dense"); + } + + if lhs.coordinates.is_none() && lhs.values.is_none() { + return Self::float_zeros(lhs.shape, &lhs.device); + } + + // TODO: this could be optimized a little if/when scatter gets other reduction strategies + let lhs = Self::sparse_to_dense(lhs); + Self::float_mul(lhs, rhs) } fn sparse_mul_scalar( - lhs: SparseTensor, + mut lhs: SparseTensor, rhs: FloatElem, ) -> SparseTensor { - todo!() + lhs.values = lhs.values.map(|values| values.mul_scalar(rhs)); + lhs } fn sparse_div( - lhs: SparseTensor, - rhs: SparseTensor, + _: SparseTensor, + _: SparseTensor, ) -> SparseTensor { - todo!() + panic!("sparse_div is unsupported until scatter supports multiplication based reduction"); } fn sparse_div_dense( lhs: SparseTensor, rhs: burn_tensor::ops::FloatTensor, - ) -> SparseTensor { - todo!() + ) -> FloatTensor { + if lhs.shape != B::float_shape(&rhs) { + panic!("lhs and rhs must have the same shape for sparse_div_dense"); + } + + if lhs.coordinates.is_none() && lhs.values.is_none() { + return Self::float_zeros(lhs.shape, &lhs.device); + } + + // TODO: this could be optimized a little if/when scatter gets other reduction strategies + let lhs = Self::sparse_to_dense(lhs); + 
Self::float_div(lhs, rhs) } fn sparse_div_scalar( - lhs: SparseTensor, + mut lhs: SparseTensor, rhs: FloatElem, ) -> SparseTensor { - todo!() + lhs.values = lhs.values.map(|values| values.div_scalar(rhs)); + lhs } } diff --git a/crates/burn-sparse/src/decorator/sparse_csr.rs b/crates/burn-sparse/src/decorator/sparse_csr.rs index 66a44e6125..c5d9561e92 100644 --- a/crates/burn-sparse/src/decorator/sparse_csr.rs +++ b/crates/burn-sparse/src/decorator/sparse_csr.rs @@ -4,6 +4,7 @@ use crate::decorator::SparseCSR; use crate::decorator::SparseDecorator; use burn_tensor::backend::Backend; use burn_tensor::ops::FloatElem; +use burn_tensor::ops::FloatTensor; use core::marker::PhantomData; #[derive(Debug, Default, Clone)] @@ -219,7 +220,7 @@ where fn sparse_add_dense( lhs: SparseTensor, rhs: burn_tensor::ops::FloatTensor, - ) -> SparseTensor { + ) -> FloatTensor { todo!() } @@ -233,7 +234,7 @@ where fn sparse_sub_dense( lhs: SparseTensor, rhs: burn_tensor::ops::FloatTensor, - ) -> SparseTensor { + ) -> FloatTensor { todo!() } @@ -254,7 +255,7 @@ where fn sparse_mul_dense( lhs: SparseTensor, rhs: burn_tensor::ops::FloatTensor, - ) -> SparseTensor { + ) -> FloatTensor { todo!() } @@ -275,7 +276,7 @@ where fn sparse_div_dense( lhs: SparseTensor, rhs: burn_tensor::ops::FloatTensor, - ) -> SparseTensor { + ) -> FloatTensor { todo!() } From 4782a3b392205317757fad4be2c29199404cf38e Mon Sep 17 00:00:00 2001 From: mcarthur Date: Sat, 20 Jul 2024 10:16:05 +0000 Subject: [PATCH 17/38] slice_assign --- .../burn-sparse/src/decorator/sparse_coo.rs | 51 ++++++++++++++----- 1 file changed, 39 insertions(+), 12 deletions(-) diff --git a/crates/burn-sparse/src/decorator/sparse_coo.rs b/crates/burn-sparse/src/decorator/sparse_coo.rs index 25f5b536a4..0a29a0f774 100644 --- a/crates/burn-sparse/src/decorator/sparse_coo.rs +++ b/crates/burn-sparse/src/decorator/sparse_coo.rs @@ -306,7 +306,18 @@ where mask = mask.mul(mask_lower).mul(mask_upper); } - let nonzero = mask.not_equal_elem(B::IntElem::from_elem(0)).nonzero(); + let nonzero = mask.not_equal_elem(B::IntElem::from_elem(0)); + if !nonzero.clone().any().into_scalar() { + // no existing values were in the slice, so return an empty tensor + return SparseCOOTensor { + coordinates: None, + values: None, + shape: out_shape, + device, + }; + } + + let nonzero = nonzero.nonzero(); let indices_dim1 = nonzero .get(0) @@ -475,15 +486,31 @@ where fn sparse_slice_assign( tensor: SparseTensor, ranges: [core::ops::Range; D2], - value: SparseTensor, + mut value: SparseTensor, ) -> SparseTensor { - let SparseCOOTensor { - coordinates, - values, - shape, - device, - } = tensor; - todo!() + let value_nnz = value + .coordinates + .as_ref() + .map(|coords| coords.shape().dims[1]) + .unwrap_or(0); + + let mut ranges = Vec::from(ranges); + ranges.extend(tensor.shape.dims[ranges.len()..D1].iter().map(|&l| 0..l)); + let ranges: [core::ops::Range; D1] = ranges.try_into().expect("D2 must be <= D1"); + + let shape = tensor.shape.clone(); + let sliced = Self::sparse_reshape( + Self::sparse_slice(tensor.clone(), ranges.clone()), + shape.clone(), + ); + let tensor = Self::sparse_sub(tensor, sliced); + let offset = Tensor::::from_ints(ranges.map(|r| r.start), &tensor.device); + let offset = offset.unsqueeze_dim::<2>(1).repeat(1, value_nnz); + + value.shape = shape; + value.coordinates = value.coordinates.map(|coords| coords + offset); + + Self::sparse_add(tensor, value) } fn sparse_repeat( @@ -494,7 +521,7 @@ where let SparseCOOTensor { coordinates, values, - mut shape, + shape, device, } = 
tensor; @@ -551,14 +578,14 @@ where lhs: SparseTensor, rhs: SparseTensor, ) -> burn_tensor::ops::BoolTensor { - todo!() + panic!("elementwise equal is unsupported for SparseCOO until scatter supports other reduction methods"); } fn sparse_not_equal( lhs: SparseTensor, rhs: SparseTensor, ) -> burn_tensor::ops::BoolTensor { - todo!() + panic!("elementwise not_equal is unsupported for SparseCOO until scatter supports other reduction methods"); } fn sparse_any( From f898d79de6cab3b0df732a9f661ce57e71b85d40 Mon Sep 17 00:00:00 2001 From: mcarthur Date: Sun, 21 Jul 2024 10:41:53 +0000 Subject: [PATCH 18/38] sddmm + more numerics (sign, abs, etc) --- crates/burn-sparse/src/backend/api.rs | 9 + crates/burn-sparse/src/backend/kind.rs | 86 ++--- .../burn-sparse/src/backend/sparse_backend.rs | 162 ++++++++- .../burn-sparse/src/decorator/sparse_coo.rs | 330 +++++++++++++++++- .../burn-sparse/src/decorator/sparse_csr.rs | 231 +++++++++++- 5 files changed, 753 insertions(+), 65 deletions(-) diff --git a/crates/burn-sparse/src/backend/api.rs b/crates/burn-sparse/src/backend/api.rs index da7ab4159f..5d064b3ec9 100644 --- a/crates/burn-sparse/src/backend/api.rs +++ b/crates/burn-sparse/src/backend/api.rs @@ -16,6 +16,7 @@ pub trait SparseTensorApi where B: SparseBackend, { + fn sddmm(self, lhs: Tensor, rhs: Tensor) -> Self; fn dense_int(self) -> Tensor; fn spmm(self, rhs: Tensor) -> Tensor; fn dense(self) -> Tensor; @@ -56,6 +57,14 @@ where ))) } + fn sddmm(self, lhs: Tensor, rhs: Tensor) -> Self { + Tensor::new(B::sparse_sddmm( + lhs.into_primitive().tensor(), + rhs.into_primitive().tensor(), + self.into_primitive(), + )) + } + fn coalesce(self, reduction: CoalesceReduction) -> Tensor { match reduction { CoalesceReduction::Sum => Tensor::new(B::sparse_coalesce_sum(self.into_primitive())), diff --git a/crates/burn-sparse/src/backend/kind.rs b/crates/burn-sparse/src/backend/kind.rs index f0a78d4279..78e31363ae 100644 --- a/crates/burn-sparse/src/backend/kind.rs +++ b/crates/burn-sparse/src/backend/kind.rs @@ -192,7 +192,7 @@ impl Numeric for Sparse { lhs: Self::Primitive, rhs: E, ) -> Self::Primitive { - todo!() + B::sparse_remainder_scalar(lhs, rhs.elem()) } fn mul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive { @@ -207,11 +207,11 @@ impl Numeric for Sparse { } fn neg(tensor: Self::Primitive) -> Self::Primitive { - todo!() + B::sparse_neg(tensor) } fn sign(tensor: Self::Primitive) -> Self::Primitive { - todo!() + B::sparse_sign(tensor) } fn zeros( @@ -237,97 +237,97 @@ impl Numeric for Sparse { } fn sum(tensor: Self::Primitive) -> Self::Primitive<1> { - todo!() + B::sparse_sum(tensor) } fn sum_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive { - todo!() + B::sparse_sum_dim(tensor, dim) } fn prod(tensor: Self::Primitive) -> Self::Primitive<1> { - todo!() + B::sparse_prod(tensor) } fn prod_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive { - todo!() + B::sparse_prod_dim(tensor, dim) } fn mean(tensor: Self::Primitive) -> Self::Primitive<1> { - todo!() + B::sparse_mean(tensor) } fn mean_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive { - todo!() + B::sparse_mean_dim(tensor, dim) } fn equal_elem( lhs: Self::Primitive, rhs: Self::Elem, ) -> Tensor { - todo!() + Tensor::new(B::sparse_equal_elem(lhs, rhs)) } fn not_equal_elem( lhs: Self::Primitive, rhs: Self::Elem, ) -> Tensor { - todo!() + Tensor::new(B::sparse_not_equal_elem(lhs, rhs)) } fn greater( lhs: Self::Primitive, rhs: Self::Primitive, ) -> Tensor { - todo!() + Tensor::new(B::sparse_greater(lhs, rhs)) } 
fn greater_elem( lhs: Self::Primitive, rhs: Self::Elem, ) -> Tensor { - todo!() + Tensor::new(B::sparse_greater_elem(lhs, rhs)) } fn greater_equal( lhs: Self::Primitive, rhs: Self::Primitive, ) -> Tensor { - todo!() + Tensor::new(B::sparse_greater_equal(lhs, rhs)) } fn greater_equal_elem( lhs: Self::Primitive, rhs: Self::Elem, ) -> Tensor { - todo!() + Tensor::new(B::sparse_greater_equal_elem(lhs, rhs)) } fn lower( lhs: Self::Primitive, rhs: Self::Primitive, ) -> Tensor { - todo!() + Tensor::new(B::sparse_lower(lhs, rhs)) } fn lower_elem( lhs: Self::Primitive, rhs: Self::Elem, ) -> Tensor { - todo!() + Tensor::new(B::sparse_lower_elem(lhs, rhs)) } fn lower_equal( lhs: Self::Primitive, rhs: Self::Primitive, ) -> Tensor { - todo!() + Tensor::new(B::sparse_lower_equal(lhs, rhs)) } fn lower_equal_elem( lhs: Self::Primitive, rhs: Self::Elem, ) -> Tensor { - todo!() + Tensor::new(B::sparse_lower_equal_elem(lhs, rhs)) } fn mask_where( @@ -335,7 +335,7 @@ impl Numeric for Sparse { mask: Tensor, source: Self::Primitive, ) -> Self::Primitive { - todo!() + panic!("masking of sparse tensors is unsupported") } fn mask_fill( @@ -343,7 +343,7 @@ impl Numeric for Sparse { mask: Tensor, value: Self::Elem, ) -> Self::Primitive { - todo!() + panic!("masking of sparse tensors is unsupported") } fn gather( @@ -351,7 +351,7 @@ impl Numeric for Sparse { tensor: Self::Primitive, indices: Tensor, ) -> Self::Primitive { - todo!() + B::sparse_gather(dim, tensor, indices.into_primitive()) } fn scatter( @@ -360,7 +360,7 @@ impl Numeric for Sparse { indices: Tensor, values: Self::Primitive, ) -> Self::Primitive { - todo!() + B::sparse_scatter(dim, tensor, indices.into_primitive(), values) } fn select( @@ -368,7 +368,7 @@ impl Numeric for Sparse { dim: usize, indices: Tensor, ) -> Self::Primitive { - todo!() + B::sparse_select(tensor, dim, indices.into_primitive()) } fn select_assign( @@ -377,29 +377,29 @@ impl Numeric for Sparse { indices: Tensor, values: Self::Primitive, ) -> Self::Primitive { - todo!() + B::sparse_select_assign(tensor, dim, indices.into_primitive(), values) } fn argmax( tensor: Self::Primitive, dim: usize, ) -> ::IntTensorPrimitive { - todo!() + panic!("Argmax is unsupported for sparse tensors"); } fn argmin( tensor: Self::Primitive, dim: usize, ) -> ::IntTensorPrimitive { - todo!() + panic!("Argmin is unsupported for sparse tensors"); } fn max(tensor: Self::Primitive) -> Self::Primitive<1> { - todo!() + B::sparse_max(tensor) } fn max_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive { - todo!() + B::sparse_max_dim(tensor, dim) } fn max_dim_with_indices( @@ -410,11 +410,11 @@ impl Numeric for Sparse { } fn min(tensor: Self::Primitive) -> Self::Primitive<1> { - todo!() + B::sparse_min(tensor) } fn min_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive { - todo!() + B::sparse_min_dim(tensor, dim) } fn min_dim_with_indices( @@ -429,53 +429,53 @@ impl Numeric for Sparse { min: Self::Elem, max: Self::Elem, ) -> Self::Primitive { - todo!() + B::sparse_clamp(tensor, min, max) } fn clamp_min( tensor: Self::Primitive, min: Self::Elem, ) -> Self::Primitive { - todo!() + B::sparse_clamp_min(tensor, min) } fn clamp_max( tensor: Self::Primitive, max: Self::Elem, ) -> Self::Primitive { - todo!() + B::sparse_clamp_max(tensor, max) } fn abs(tensor: Self::Primitive) -> Self::Primitive { - todo!() + B::sparse_abs(tensor) } fn powf( lhs: Self::Primitive, rhs: Self::Primitive, ) -> Self::Primitive { - todo!() + B::sparse_powf(lhs, rhs) } fn powi( lhs: Self::Primitive, rhs: Self::Primitive, ) -> 
Self::Primitive { - todo!() + B::sparse_powi(lhs, rhs) } fn powf_scalar( lhs: Self::Primitive, rhs: E, ) -> Self::Primitive { - todo!() + B::sparse_powf_scalar(lhs, rhs.elem()) } fn powi_scalar( lhs: Self::Primitive, rhs: E, ) -> Self::Primitive { - todo!() + B::sparse_powi_scalar(lhs, rhs.elem()) } fn random( @@ -483,7 +483,7 @@ impl Numeric for Sparse { distribution: burn_tensor::Distribution, device: &::Device, ) -> Self::Primitive { - todo!() + panic!("Random is unsupported for sparse tensors") } fn sort( @@ -491,7 +491,7 @@ impl Numeric for Sparse { dim: usize, descending: bool, ) -> Self::Primitive { - todo!() + panic!("Sorting is unsupported for sparse tensors") } fn sort_with_indices( @@ -502,7 +502,7 @@ impl Numeric for Sparse { Self::Primitive, >::Primitive, ) { - todo!() + panic!("Sorting is unsupported for sparse tensors") } fn argsort( @@ -510,6 +510,6 @@ impl Numeric for Sparse { dim: usize, descending: bool, ) -> >::Primitive { - todo!() + panic!("Sorting is unsupported for sparse tensors") } } diff --git a/crates/burn-sparse/src/backend/sparse_backend.rs b/crates/burn-sparse/src/backend/sparse_backend.rs index 3de056a038..c74fa2a4fd 100644 --- a/crates/burn-sparse/src/backend/sparse_backend.rs +++ b/crates/burn-sparse/src/backend/sparse_backend.rs @@ -1,7 +1,7 @@ use crate::backend::SparseTensor; use burn_tensor::{ backend::Backend, - ops::{BoolTensor, FloatElem, FloatTensor}, + ops::{BoolTensor, FloatElem, FloatTensor, IntTensor}, Device, Shape, TensorData, }; use core::{future::Future, ops::Range}; @@ -28,8 +28,9 @@ pub trait SparseBackend: Backend { ) -> Self::FloatTensorPrimitive; fn sparse_sddmm( - lhs: Self::SparseTensorPrimitive, + lhs: Self::FloatTensorPrimitive, rhs: Self::FloatTensorPrimitive, + sparse: Self::SparseTensorPrimitive, ) -> Self::SparseTensorPrimitive; fn sparse_coalesce_sum( @@ -368,4 +369,161 @@ pub trait SparseBackend: Backend { lhs: SparseTensor, rhs: FloatElem, ) -> SparseTensor; + + fn sparse_max(tensor: SparseTensor) -> SparseTensor; + + fn sparse_max_dim( + tensor: SparseTensor, + dim: usize, + ) -> SparseTensor; + + fn sparse_min(tensor: SparseTensor) -> SparseTensor; + + fn sparse_min_dim( + tensor: SparseTensor, + dim: usize, + ) -> SparseTensor; + + fn sparse_greater( + lhs: SparseTensor, + rhs: SparseTensor, + ) -> BoolTensor; + + fn sparse_greater_elem( + lhs: SparseTensor, + rhs: FloatElem, + ) -> BoolTensor; + + fn sparse_greater_equal( + lhs: SparseTensor, + rhs: SparseTensor, + ) -> BoolTensor; + + fn sparse_greater_equal_elem( + lhs: SparseTensor, + rhs: FloatElem, + ) -> BoolTensor; + + fn sparse_lower( + lhs: SparseTensor, + rhs: SparseTensor, + ) -> BoolTensor; + + fn sparse_lower_elem( + lhs: SparseTensor, + rhs: FloatElem, + ) -> BoolTensor; + + fn sparse_lower_equal( + lhs: SparseTensor, + rhs: SparseTensor, + ) -> BoolTensor; + + fn sparse_lower_equal_elem( + lhs: SparseTensor, + rhs: FloatElem, + ) -> BoolTensor; + + fn sparse_abs(tensor: SparseTensor) -> SparseTensor; + fn sparse_sign(tensor: SparseTensor) -> SparseTensor; + + fn sparse_powf( + lhs: SparseTensor, + rhs: SparseTensor, + ) -> SparseTensor; + + fn sparse_powi( + lhs: SparseTensor, + rhs: SparseTensor, + ) -> SparseTensor; + + fn sparse_powf_scalar( + lhs: SparseTensor, + rhs: FloatElem, + ) -> SparseTensor; + + fn sparse_powi_scalar( + lhs: SparseTensor, + rhs: FloatElem, + ) -> SparseTensor; + + fn sparse_clamp( + tensor: SparseTensor, + min: FloatElem, + max: FloatElem, + ) -> SparseTensor; + + fn sparse_clamp_min( + tensor: SparseTensor, + min: FloatElem, + 
) -> SparseTensor; + + fn sparse_clamp_max( + tensor: SparseTensor, + max: FloatElem, + ) -> SparseTensor; + + fn sparse_select( + tensor: SparseTensor, + dim: usize, + indices: IntTensor, + ) -> SparseTensor; + + fn sparse_select_assign( + tensor: SparseTensor, + dim: usize, + indices: IntTensor, + values: SparseTensor, + ) -> SparseTensor; + + fn sparse_gather( + dim: usize, + tensor: SparseTensor, + indices: IntTensor, + ) -> SparseTensor; + + fn sparse_scatter( + dim: usize, + tensor: SparseTensor, + indices: IntTensor, + values: SparseTensor, + ) -> SparseTensor; + + fn sparse_sum(tensor: SparseTensor) -> SparseTensor; + + fn sparse_sum_dim( + tensor: SparseTensor, + dim: usize, + ) -> SparseTensor; + + fn sparse_prod(tensor: SparseTensor) -> SparseTensor; + + fn sparse_prod_dim( + tensor: SparseTensor, + dim: usize, + ) -> SparseTensor; + + fn sparse_mean(tensor: SparseTensor) -> SparseTensor; + + fn sparse_mean_dim( + tensor: SparseTensor, + dim: usize, + ) -> SparseTensor; + + fn sparse_equal_elem( + lhs: SparseTensor, + rhs: FloatElem, + ) -> BoolTensor; + + fn sparse_not_equal_elem( + lhs: SparseTensor, + rhs: FloatElem, + ) -> BoolTensor; + + fn sparse_remainder_scalar( + lhs: SparseTensor, + rhs: FloatElem, + ) -> SparseTensor; + + fn sparse_neg(tensor: SparseTensor) -> SparseTensor; } diff --git a/crates/burn-sparse/src/decorator/sparse_coo.rs b/crates/burn-sparse/src/decorator/sparse_coo.rs index 0a29a0f774..ea47f56006 100644 --- a/crates/burn-sparse/src/decorator/sparse_coo.rs +++ b/crates/burn-sparse/src/decorator/sparse_coo.rs @@ -1,5 +1,3 @@ -use std::ops::Mul; - use crate::backend::SparseBackend; use crate::backend::SparseTensor; use crate::decorator::SparseCOO; @@ -13,7 +11,6 @@ use burn_tensor::{ backend::Backend, Bool, ElementConversion, Float, Int, Shape, Tensor, TensorData, TensorPrimitive, }; -use half::vec; #[derive(Clone, Debug)] pub struct SparseCOOTensor { @@ -23,14 +20,20 @@ pub struct SparseCOOTensor { pub device: Device, } -fn flatten_coordinates( +fn flatten_coordinates( coordinates: Tensor, shape: Shape, device: &Device, ) -> Tensor { let mut strides_data = [[1]; D]; - for i in (0..D - 1).rev() { - strides_data[i] = [strides_data[i + 1][0] * shape.dims[i + 1] as i64]; + for i in (0..D).rev() { + if D - 1 - i == S { + strides_data[i] = [1]; + } else if D - 1 - i < S { + strides_data[i] = [0]; + } else { + strides_data[i] = [strides_data[i + 1][0] * shape.dims[i + 1] as i64]; + } } let strides_data: TensorData = TensorData::from(strides_data); let strides: Tensor = Tensor::from_data(strides_data, device); @@ -131,12 +134,77 @@ where }; let dense: Tensor = Tensor::zeros(Shape::new([shape.num_elements()]), &device); - let flat_coordinates = flatten_coordinates(coordinates, shape.clone(), &device).squeeze(0); + let flat_coordinates = + flatten_coordinates::(coordinates, shape.clone(), &device).squeeze(0); let dense = dense.select_assign(0, flat_coordinates, values); dense.reshape(shape).into_primitive().tensor() } + fn sparse_sddmm( + lhs: Self::FloatTensorPrimitive, + rhs: Self::FloatTensorPrimitive, + sparse: Self::SparseTensorPrimitive, + ) -> Self::SparseTensorPrimitive { + if sparse.coordinates.is_none() || sparse.values.is_none() { + return sparse; + } + + // Flatten the lhs and rhs into a tensor of rows and cols respectively + let lhs = Tensor::::new(burn_tensor::TensorPrimitive::Float(lhs)); + let rhs = Tensor::::new(burn_tensor::TensorPrimitive::Float(rhs)).transpose(); + let lhs_dims = lhs.shape().dims; + let rhs_dims = rhs.shape().dims; + + if 
lhs_dims[D - 1] != rhs_dims[D - 1] + || lhs_dims[D - 2] != sparse.shape.dims[D - 2] + || rhs_dims[D - 2] != sparse.shape.dims[D - 1] + { + panic!("invalid dimensions for sddmm: lhs and rhs must have compatible shapes for matmul, and sparse must match the output shape of matmul between lhs and rhs."); + } + + let lhs = lhs.reshape([-1, lhs_dims[D - 1] as i32]); + let rhs = rhs.reshape([-1, rhs_dims[D - 1] as i32]); + + // Flatten the sparse tensor coordinates into row and column indices + let device = sparse.device.clone(); + let mut shape = sparse.shape.clone(); + let lhs_coordinates = sparse + .coordinates + .clone() + .expect("Expected non-empty sparse tensor"); + + // swap the last two dims so it's column-first + let swizzle = Tensor::::arange(0..D as i64, &device) + .slice_assign( + [D - 2..D], + Tensor::::from_ints([D - 1, D - 2], &device), + ) + .unsqueeze_dim(1) + .repeat(1, lhs_coordinates.shape().dims[1]); + let rhs_coordinates = lhs_coordinates.clone().gather(0, swizzle); + + let row_indices = flatten_coordinates::(lhs_coordinates, shape.clone(), &device); + + shape.dims.swap(D - 1, D - 2); + let col_indices = flatten_coordinates::(rhs_coordinates, shape.clone(), &device); + + let row_indices = row_indices.transpose().repeat(1, lhs_dims[D - 1]); + let col_indices = col_indices.transpose().repeat(1, rhs_dims[D - 1]); + + let lhs = lhs.gather(0, row_indices); + let rhs = rhs.gather(0, col_indices); + + let dotted = lhs.mul(rhs).sum_dim(1).squeeze(1); + + SparseCOOTensor { + coordinates: sparse.coordinates, + values: Some(dotted), + shape: sparse.shape, + device, + } + } + fn sparse_spmm( lhs: Self::SparseTensorPrimitive, rhs: Self::FloatTensorPrimitive, @@ -218,13 +286,6 @@ where scattered.reshape(out_shape).into_primitive().tensor() } - fn sparse_sddmm( - lhs: Self::SparseTensorPrimitive, - rhs: Self::FloatTensorPrimitive, - ) -> Self::SparseTensorPrimitive { - todo!() - } - fn sparse_device(tensor: &SparseTensor) -> burn_tensor::Device { tensor.device.clone() } @@ -375,7 +436,7 @@ where let device = tensor.device; // Flatten the coordinates - let flat_coordinates = flatten_coordinates(coordinates, shape, &device); + let flat_coordinates = flatten_coordinates::(coordinates, shape, &device); // Unflatten the coordinates to the new shape let new_coordinates = unflatten_coordinates(flat_coordinates, out_shape.clone()); @@ -670,7 +731,8 @@ where let device = tensor.device; let nnz = coordinates.shape().dims[1]; - let coordinates = flatten_coordinates(coordinates, original_shape.clone(), &device); + let coordinates = + flatten_coordinates::(coordinates, original_shape.clone(), &device); let flat_shape = Shape::new([original_shape.num_elements()]); let (coordinates, indices) = coordinates.sort_with_indices(1); @@ -824,7 +886,7 @@ where let device = lhs.device; let shape = lhs.shape; - let coordinates = flatten_coordinates(coordinates, shape.clone(), &device); + let coordinates = flatten_coordinates::(coordinates, shape.clone(), &device); let dense = B::float_reshape(rhs, Shape::new([shape.num_elements()])); let dense = B::float_scatter( @@ -883,7 +945,7 @@ where return Self::float_zeros(lhs.shape, &lhs.device); } - // TODO: this could be optimized a little if/when scatter gets other reduction strategies + // TODO: this could potentially be optimized if/when scatter gets other reduction strategies let lhs = Self::sparse_to_dense(lhs); Self::float_mul(lhs, rhs) } @@ -915,7 +977,7 @@ where return Self::float_zeros(lhs.shape, &lhs.device); } - // TODO: this could be optimized a little if/when scatter gets other 
reduction strategies + // TODO: this could potentially be optimized if/when scatter gets other reduction strategies let lhs = Self::sparse_to_dense(lhs); Self::float_div(lhs, rhs) } @@ -927,4 +989,234 @@ where lhs.values = lhs.values.map(|values| values.div_scalar(rhs)); lhs } + + fn sparse_max(tensor: SparseTensor) -> SparseTensor { + todo!() + } + + fn sparse_max_dim( + tensor: SparseTensor, + dim: usize, + ) -> SparseTensor { + todo!() + } + + fn sparse_min(tensor: SparseTensor) -> SparseTensor { + todo!() + } + + fn sparse_min_dim( + tensor: SparseTensor, + dim: usize, + ) -> SparseTensor { + todo!() + } + + fn sparse_greater( + lhs: SparseTensor, + rhs: SparseTensor, + ) -> burn_tensor::ops::BoolTensor { + todo!() + } + + fn sparse_greater_elem( + lhs: SparseTensor, + rhs: FloatElem, + ) -> burn_tensor::ops::BoolTensor { + todo!() + } + + fn sparse_greater_equal( + lhs: SparseTensor, + rhs: SparseTensor, + ) -> burn_tensor::ops::BoolTensor { + todo!() + } + + fn sparse_greater_equal_elem( + lhs: SparseTensor, + rhs: FloatElem, + ) -> burn_tensor::ops::BoolTensor { + todo!() + } + + fn sparse_lower( + lhs: SparseTensor, + rhs: SparseTensor, + ) -> burn_tensor::ops::BoolTensor { + todo!() + } + + fn sparse_lower_elem( + lhs: SparseTensor, + rhs: FloatElem, + ) -> burn_tensor::ops::BoolTensor { + todo!() + } + + fn sparse_lower_equal( + lhs: SparseTensor, + rhs: SparseTensor, + ) -> burn_tensor::ops::BoolTensor { + todo!() + } + + fn sparse_lower_equal_elem( + lhs: SparseTensor, + rhs: FloatElem, + ) -> burn_tensor::ops::BoolTensor { + todo!() + } + + fn sparse_abs(mut tensor: SparseTensor) -> SparseTensor { + tensor.values = tensor.values.map(|values| values.abs()); + tensor + } + + fn sparse_powf( + lhs: SparseTensor, + rhs: SparseTensor, + ) -> SparseTensor { + todo!() + } + + fn sparse_powi( + lhs: SparseTensor, + rhs: SparseTensor, + ) -> SparseTensor { + todo!() + } + + fn sparse_powf_scalar( + lhs: SparseTensor, + rhs: FloatElem, + ) -> SparseTensor { + todo!() + } + + fn sparse_powi_scalar( + lhs: SparseTensor, + rhs: FloatElem, + ) -> SparseTensor { + todo!() + } + + fn sparse_clamp( + tensor: SparseTensor, + min: FloatElem, + max: FloatElem, + ) -> SparseTensor { + todo!() + } + + fn sparse_clamp_min( + tensor: SparseTensor, + min: FloatElem, + ) -> SparseTensor { + todo!() + } + + fn sparse_clamp_max( + tensor: SparseTensor, + max: FloatElem, + ) -> SparseTensor { + todo!() + } + + fn sparse_select( + tensor: SparseTensor, + dim: usize, + indices: burn_tensor::ops::IntTensor, + ) -> SparseTensor { + todo!() + } + + fn sparse_select_assign( + tensor: SparseTensor, + dim: usize, + indices: burn_tensor::ops::IntTensor, + values: SparseTensor, + ) -> SparseTensor { + todo!() + } + + fn sparse_gather( + dim: usize, + tensor: SparseTensor, + indices: burn_tensor::ops::IntTensor, + ) -> SparseTensor { + todo!() + } + + fn sparse_scatter( + dim: usize, + tensor: SparseTensor, + indices: burn_tensor::ops::IntTensor, + values: SparseTensor, + ) -> SparseTensor { + todo!() + } + + fn sparse_sum(tensor: SparseTensor) -> SparseTensor { + todo!() + } + + fn sparse_sum_dim( + tensor: SparseTensor, + dim: usize, + ) -> SparseTensor { + todo!() + } + + fn sparse_prod(tensor: SparseTensor) -> SparseTensor { + todo!() + } + + fn sparse_prod_dim( + tensor: SparseTensor, + dim: usize, + ) -> SparseTensor { + todo!() + } + + fn sparse_mean(tensor: SparseTensor) -> SparseTensor { + todo!() + } + + fn sparse_mean_dim( + tensor: SparseTensor, + dim: usize, + ) -> SparseTensor { + todo!() + } + + fn 
sparse_equal_elem( + lhs: SparseTensor, + rhs: FloatElem, + ) -> burn_tensor::ops::BoolTensor { + todo!() + } + + fn sparse_not_equal_elem( + lhs: SparseTensor, + rhs: FloatElem, + ) -> burn_tensor::ops::BoolTensor { + todo!() + } + + fn sparse_remainder_scalar( + lhs: SparseTensor, + rhs: FloatElem, + ) -> SparseTensor { + todo!() + } + + fn sparse_neg(tensor: SparseTensor) -> SparseTensor { + todo!() + } + + fn sparse_sign(mut tensor: SparseTensor) -> SparseTensor { + tensor.values = tensor.values.map(|values| values.sign()); + tensor + } } diff --git a/crates/burn-sparse/src/decorator/sparse_csr.rs b/crates/burn-sparse/src/decorator/sparse_csr.rs index c5d9561e92..9671efcc1a 100644 --- a/crates/burn-sparse/src/decorator/sparse_csr.rs +++ b/crates/burn-sparse/src/decorator/sparse_csr.rs @@ -45,8 +45,9 @@ where } fn sparse_sddmm( - lhs: Self::SparseTensorPrimitive, + lhs: Self::FloatTensorPrimitive, rhs: Self::FloatTensorPrimitive, + sparse: Self::SparseTensorPrimitive, ) -> Self::SparseTensorPrimitive { todo!() } @@ -286,4 +287,232 @@ where ) -> SparseTensor { todo!() } + + fn sparse_max(tensor: SparseTensor) -> SparseTensor { + todo!() + } + + fn sparse_max_dim( + tensor: SparseTensor, + dim: usize, + ) -> SparseTensor { + todo!() + } + + fn sparse_min(tensor: SparseTensor) -> SparseTensor { + todo!() + } + + fn sparse_min_dim( + tensor: SparseTensor, + dim: usize, + ) -> SparseTensor { + todo!() + } + + fn sparse_greater( + lhs: SparseTensor, + rhs: SparseTensor, + ) -> burn_tensor::ops::BoolTensor { + todo!() + } + + fn sparse_greater_elem( + lhs: SparseTensor, + rhs: FloatElem, + ) -> burn_tensor::ops::BoolTensor { + todo!() + } + + fn sparse_greater_equal( + lhs: SparseTensor, + rhs: SparseTensor, + ) -> burn_tensor::ops::BoolTensor { + todo!() + } + + fn sparse_greater_equal_elem( + lhs: SparseTensor, + rhs: FloatElem, + ) -> burn_tensor::ops::BoolTensor { + todo!() + } + + fn sparse_lower( + lhs: SparseTensor, + rhs: SparseTensor, + ) -> burn_tensor::ops::BoolTensor { + todo!() + } + + fn sparse_lower_elem( + lhs: SparseTensor, + rhs: FloatElem, + ) -> burn_tensor::ops::BoolTensor { + todo!() + } + + fn sparse_lower_equal( + lhs: SparseTensor, + rhs: SparseTensor, + ) -> burn_tensor::ops::BoolTensor { + todo!() + } + + fn sparse_lower_equal_elem( + lhs: SparseTensor, + rhs: FloatElem, + ) -> burn_tensor::ops::BoolTensor { + todo!() + } + + fn sparse_abs(tensor: SparseTensor) -> SparseTensor { + todo!() + } + + fn sparse_powf( + lhs: SparseTensor, + rhs: SparseTensor, + ) -> SparseTensor { + todo!() + } + + fn sparse_powi( + lhs: SparseTensor, + rhs: SparseTensor, + ) -> SparseTensor { + todo!() + } + + fn sparse_powf_scalar( + lhs: SparseTensor, + rhs: FloatElem, + ) -> SparseTensor { + todo!() + } + + fn sparse_powi_scalar( + lhs: SparseTensor, + rhs: FloatElem, + ) -> SparseTensor { + todo!() + } + + fn sparse_clamp( + tensor: SparseTensor, + min: FloatElem, + max: FloatElem, + ) -> SparseTensor { + todo!() + } + + fn sparse_clamp_min( + tensor: SparseTensor, + min: FloatElem, + ) -> SparseTensor { + todo!() + } + + fn sparse_clamp_max( + tensor: SparseTensor, + max: FloatElem, + ) -> SparseTensor { + todo!() + } + + fn sparse_select( + tensor: SparseTensor, + dim: usize, + indices: burn_tensor::ops::IntTensor, + ) -> SparseTensor { + todo!() + } + + fn sparse_select_assign( + tensor: SparseTensor, + dim: usize, + indices: burn_tensor::ops::IntTensor, + values: SparseTensor, + ) -> SparseTensor { + todo!() + } + + fn sparse_gather( + dim: usize, + tensor: SparseTensor, + 
indices: burn_tensor::ops::IntTensor, + ) -> SparseTensor { + todo!() + } + + fn sparse_scatter( + dim: usize, + tensor: SparseTensor, + indices: burn_tensor::ops::IntTensor, + values: SparseTensor, + ) -> SparseTensor { + todo!() + } + + fn sparse_sum(tensor: SparseTensor) -> SparseTensor { + todo!() + } + + fn sparse_sum_dim( + tensor: SparseTensor, + dim: usize, + ) -> SparseTensor { + todo!() + } + + fn sparse_prod(tensor: SparseTensor) -> SparseTensor { + todo!() + } + + fn sparse_prod_dim( + tensor: SparseTensor, + dim: usize, + ) -> SparseTensor { + todo!() + } + + fn sparse_mean(tensor: SparseTensor) -> SparseTensor { + todo!() + } + + fn sparse_mean_dim( + tensor: SparseTensor, + dim: usize, + ) -> SparseTensor { + todo!() + } + + fn sparse_equal_elem( + lhs: SparseTensor, + rhs: FloatElem, + ) -> burn_tensor::ops::BoolTensor { + todo!() + } + + fn sparse_not_equal_elem( + lhs: SparseTensor, + rhs: FloatElem, + ) -> burn_tensor::ops::BoolTensor { + todo!() + } + + fn sparse_remainder_scalar( + lhs: SparseTensor, + rhs: FloatElem, + ) -> SparseTensor { + todo!() + } + + fn sparse_neg(tensor: SparseTensor) -> SparseTensor { + todo!() + } + + fn sparse_sign(tensor: SparseTensor) -> SparseTensor { + todo!() + } } From 80ab5d83bd6164b98ae9fa25367676986b5bd66d Mon Sep 17 00:00:00 2001 From: mcarthur Date: Sun, 28 Jul 2024 05:46:41 +0000 Subject: [PATCH 19/38] made unimplemented functions panic --- .../burn-sparse/src/backend/sparse_backend.rs | 4 + .../burn-sparse/src/decorator/sparse_coo.rs | 171 ++++++++++++++---- .../burn-sparse/src/decorator/sparse_csr.rs | 6 + 3 files changed, 143 insertions(+), 38 deletions(-) diff --git a/crates/burn-sparse/src/backend/sparse_backend.rs b/crates/burn-sparse/src/backend/sparse_backend.rs index c74fa2a4fd..44db7b8e2f 100644 --- a/crates/burn-sparse/src/backend/sparse_backend.rs +++ b/crates/burn-sparse/src/backend/sparse_backend.rs @@ -37,6 +37,10 @@ pub trait SparseBackend: Backend { tensor: Self::SparseTensorPrimitive, ) -> Self::SparseTensorPrimitive; + fn sparse_remove_zeros( + tensor: Self::SparseTensorPrimitive, + ) -> Self::SparseTensorPrimitive; + fn sparse_nonzero(tensor: Self::SparseTensorPrimitive) -> usize; fn sparse_density(sparse: Self::SparseTensorPrimitive) -> f32; diff --git a/crates/burn-sparse/src/decorator/sparse_coo.rs b/crates/burn-sparse/src/decorator/sparse_coo.rs index ea47f56006..9e132b3683 100644 --- a/crates/burn-sparse/src/decorator/sparse_coo.rs +++ b/crates/burn-sparse/src/decorator/sparse_coo.rs @@ -632,6 +632,7 @@ where tensors: Vec>, dim: usize, ) -> SparseTensor { + let mut offset = 0; todo!() } @@ -991,81 +992,89 @@ where } fn sparse_max(tensor: SparseTensor) -> SparseTensor { - todo!() + panic!("max is unsupported for SparseCOO until scatter supports other reduction methods"); } fn sparse_max_dim( tensor: SparseTensor, dim: usize, ) -> SparseTensor { - todo!() + panic!( + "max_dim is unsupported for SparseCOO until scatter supports other reduction methods" + ); } fn sparse_min(tensor: SparseTensor) -> SparseTensor { - todo!() + panic!("min is unsupported for SparseCOO until scatter supports other reduction methods"); } fn sparse_min_dim( tensor: SparseTensor, dim: usize, ) -> SparseTensor { - todo!() + panic!( + "min_dim is unsupported for SparseCOO until scatter supports other reduction methods" + ); } fn sparse_greater( lhs: SparseTensor, rhs: SparseTensor, ) -> burn_tensor::ops::BoolTensor { - todo!() + panic!("sparse_greater is not supported for SparseCOO as it outputs a dense tensor"); } fn 
sparse_greater_elem( lhs: SparseTensor, rhs: FloatElem, ) -> burn_tensor::ops::BoolTensor { - todo!() + panic!("sparse_greater_elem is not supported for SparseCOO as it outputs a dense tensor"); } fn sparse_greater_equal( lhs: SparseTensor, rhs: SparseTensor, ) -> burn_tensor::ops::BoolTensor { - todo!() + panic!("sparse_greater_equal is not supported for SparseCOO as it outputs a dense tensor"); } fn sparse_greater_equal_elem( lhs: SparseTensor, rhs: FloatElem, ) -> burn_tensor::ops::BoolTensor { - todo!() + panic!( + "sparse_greater_equal_elem is not supported for SparseCOO as it outputs a dense tensor" + ); } fn sparse_lower( lhs: SparseTensor, rhs: SparseTensor, ) -> burn_tensor::ops::BoolTensor { - todo!() + panic!("sparse_lower is not supported for SparseCOO as it outputs a dense tensor"); } fn sparse_lower_elem( lhs: SparseTensor, rhs: FloatElem, ) -> burn_tensor::ops::BoolTensor { - todo!() + panic!("sparse_lower_elem is not supported for SparseCOO as it outputs a dense tensor"); } fn sparse_lower_equal( lhs: SparseTensor, rhs: SparseTensor, ) -> burn_tensor::ops::BoolTensor { - todo!() + panic!("sparse_lower_equal is not supported for SparseCOO as it outputs a dense tensor"); } fn sparse_lower_equal_elem( lhs: SparseTensor, rhs: FloatElem, ) -> burn_tensor::ops::BoolTensor { - todo!() + panic!( + "sparse_lower_equal_elem is not supported for SparseCOO as it outputs a dense tensor" + ); } fn sparse_abs(mut tensor: SparseTensor) -> SparseTensor { @@ -1077,58 +1086,96 @@ where lhs: SparseTensor, rhs: SparseTensor, ) -> SparseTensor { - todo!() + panic!("sparse_powf is unsupported for SparseCOO until scatter supports other reduction methods"); } fn sparse_powi( lhs: SparseTensor, rhs: SparseTensor, ) -> SparseTensor { - todo!() + panic!("sparse_powi is unsupported for SparseCOO until scatter supports other reduction methods"); } fn sparse_powf_scalar( - lhs: SparseTensor, + mut lhs: SparseTensor, rhs: FloatElem, ) -> SparseTensor { - todo!() + lhs.values = lhs.values.map(|values| values.powf_scalar(rhs)); + lhs } fn sparse_powi_scalar( - lhs: SparseTensor, + mut lhs: SparseTensor, rhs: FloatElem, ) -> SparseTensor { - todo!() + lhs.values = lhs.values.map(|values| values.powi_scalar(rhs)); + lhs } fn sparse_clamp( - tensor: SparseTensor, + mut tensor: SparseTensor, min: FloatElem, max: FloatElem, ) -> SparseTensor { - todo!() + tensor.values = tensor.values.map(|values| values.clamp(min, max)); + tensor } fn sparse_clamp_min( - tensor: SparseTensor, + mut tensor: SparseTensor, min: FloatElem, ) -> SparseTensor { - todo!() + tensor.values = tensor.values.map(|values| values.clamp_min(min)); + tensor } fn sparse_clamp_max( - tensor: SparseTensor, + mut tensor: SparseTensor, max: FloatElem, ) -> SparseTensor { - todo!() + tensor.values = tensor.values.map(|values| values.clamp_max(max)); + tensor } fn sparse_select( - tensor: SparseTensor, + mut tensor: SparseTensor, dim: usize, indices: burn_tensor::ops::IntTensor, ) -> SparseTensor { - todo!() + if tensor.coordinates.is_none() && tensor.values.is_none() { + return tensor; + } + + let coordinates = tensor + .coordinates + .expect("Mismatch between coordinates and values"); + let values = tensor + .values + .expect("Mismatch between coordinates and values"); + let device = tensor.device; + let mut shape = tensor.shape; + let indices = Tensor::::new(indices); + + let nnz = coordinates.shape().dims[1]; + let dim_coords = coordinates + .clone() + .slice([dim..dim + 1, 0..nnz]) + .squeeze::<1>(0); + let indices = indices.select(0, 
dim_coords); + let indices_len = indices.shape().num_elements(); + let coordinates = coordinates.slice_assign( + [dim..dim + 1, 0..nnz], + indices.unsqueeze::<2>().repeat(1, D), + ); + + shape.dims[dim] = indices_len; + + SparseCOOTensor { + coordinates: Some(coordinates), + values: Some(values), + shape, + device, + } } fn sparse_select_assign( @@ -1158,65 +1205,113 @@ where } fn sparse_sum(tensor: SparseTensor) -> SparseTensor { - todo!() + tensor + .values + .map(|values| Self::sparse_to_sparse(values.sum().into_primitive().tensor())) + .unwrap_or(Self::sparse_empty(Shape::new([1]), &tensor.device)) } fn sparse_sum_dim( tensor: SparseTensor, dim: usize, ) -> SparseTensor { - todo!() + panic!("sparse_sum_dim unsupported for SparseCOO"); } fn sparse_prod(tensor: SparseTensor) -> SparseTensor { - todo!() + if tensor.coordinates.is_none() && tensor.values.is_none() { + return Self::sparse_empty(Shape::new([1]), &tensor.device); + } + + let coordinates = tensor + .coordinates + .expect("Mismatch between coordinates and values"); + let values = tensor + .values + .expect("Mismatch between coordinates and values"); + let device = tensor.device; + let shape = tensor.shape; + + if shape.num_elements() != coordinates.dims()[1] { + // at least one implicit zero, so the product over all elements is zero + Self::sparse_empty(Shape::new([1]), &device) + } else { + Self::sparse_to_sparse(values.prod().into_primitive().tensor()) + } } fn sparse_prod_dim( tensor: SparseTensor, dim: usize, ) -> SparseTensor { - todo!() + panic!("sparse_prod_dim is not supported for SparseCOO until scatter supports product reduction") } fn sparse_mean(tensor: SparseTensor) -> SparseTensor { - todo!() + let num_elems = tensor.shape.num_elements(); + tensor + .values + .map(|values| { + Self::sparse_to_sparse((values.sum() / num_elems as f32).into_primitive().tensor()) + }) + .unwrap_or(Self::sparse_empty(Shape::new([1]), &tensor.device)) } fn sparse_mean_dim( tensor: SparseTensor, dim: usize, ) -> SparseTensor { - todo!() + panic!("mean_dim is not supported for SparseCOO until scatter supports mean reduction"); } fn sparse_equal_elem( lhs: SparseTensor, rhs: FloatElem, ) -> burn_tensor::ops::BoolTensor { - todo!() + panic!("sparse_equal_elem is not supported for SparseCOO as it outputs a dense tensor"); } fn sparse_not_equal_elem( lhs: SparseTensor, rhs: FloatElem, ) -> burn_tensor::ops::BoolTensor { - todo!() + panic!("sparse_not_equal_elem is not supported for SparseCOO as it outputs a dense tensor"); } fn sparse_remainder_scalar( - lhs: SparseTensor, + mut lhs: SparseTensor, rhs: FloatElem, ) -> SparseTensor { - todo!() + lhs.values = lhs.values.map(|v| v.remainder_scalar(rhs)); + lhs } - fn sparse_neg(tensor: SparseTensor) -> SparseTensor { - todo!() + fn sparse_neg(mut tensor: SparseTensor) -> SparseTensor { + tensor.values = tensor.values.map(|v| v.neg()); + tensor } fn sparse_sign(mut tensor: SparseTensor) -> SparseTensor { tensor.values = tensor.values.map(|values| values.sign()); tensor } + + fn sparse_remove_zeros( + tensor: Self::SparseTensorPrimitive, + ) -> Self::SparseTensorPrimitive { + if tensor.coordinates.is_none() && tensor.values.is_none() { + return tensor; + } + + let coordinates = tensor + .coordinates + .expect("Mismatch between coordinates and values"); + let values = tensor + .values + .expect("Mismatch between coordinates and values"); + let device = tensor.device; + let mut shape = tensor.shape; + + // let zeros = tensor.values.map(|values| values.equal_elem(0).nonzero()); + todo!() + } } diff --git a/crates/burn-sparse/src/decorator/sparse_csr.rs 
b/crates/burn-sparse/src/decorator/sparse_csr.rs index 9671efcc1a..4823b81774 100644 --- a/crates/burn-sparse/src/decorator/sparse_csr.rs +++ b/crates/burn-sparse/src/decorator/sparse_csr.rs @@ -515,4 +515,10 @@ where fn sparse_sign(tensor: SparseTensor) -> SparseTensor { todo!() } + + fn sparse_remove_zeros( + tensor: Self::SparseTensorPrimitive, + ) -> Self::SparseTensorPrimitive { + todo!() + } } From f6c0ff82046fb65393f33c172f8f6dbee8587d60 Mon Sep 17 00:00:00 2001 From: mcarthur Date: Sun, 11 Aug 2024 10:03:26 +0000 Subject: [PATCH 20/38] style fixes --- crates/burn-sparse/src/backend/kind.rs | 54 ++-- .../burn-sparse/src/decorator/sparse_coo.rs | 170 +++++----- .../burn-sparse/src/decorator/sparse_csr.rs | 296 +++++++++--------- 3 files changed, 258 insertions(+), 262 deletions(-) diff --git a/crates/burn-sparse/src/backend/kind.rs b/crates/burn-sparse/src/backend/kind.rs index 78e31363ae..7efd5770f5 100644 --- a/crates/burn-sparse/src/backend/kind.rs +++ b/crates/burn-sparse/src/backend/kind.rs @@ -83,7 +83,7 @@ impl BasicOps for Sparse { } fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive { - B::sparse_flip(tensor, &axes) + B::sparse_flip(tensor, axes) } fn slice_assign( @@ -331,17 +331,17 @@ impl Numeric for Sparse { } fn mask_where( - tensor: Self::Primitive, - mask: Tensor, - source: Self::Primitive, + _tensor: Self::Primitive, + _mask: Tensor, + _source: Self::Primitive, ) -> Self::Primitive { panic!("masking of sparse tensors is unsupported") } fn mask_fill( - tensor: Self::Primitive, - mask: Tensor, - value: Self::Elem, + _tensor: Self::Primitive, + _mask: Tensor, + _value: Self::Elem, ) -> Self::Primitive { panic!("masking of sparse tensors is unsupported") } @@ -381,15 +381,15 @@ impl Numeric for Sparse { } fn argmax( - tensor: Self::Primitive, - dim: usize, + _tensor: Self::Primitive, + _dim: usize, ) -> ::IntTensorPrimitive { panic!("Argmax is unsupported for sparse tensors"); } fn argmin( - tensor: Self::Primitive, - dim: usize, + _tensor: Self::Primitive, + _dim: usize, ) -> ::IntTensorPrimitive { panic!("Argmin is unsupported for sparse tensors"); } @@ -403,8 +403,8 @@ impl Numeric for Sparse { } fn max_dim_with_indices( - tensor: Self::Primitive, - dim: usize, + _tensor: Self::Primitive, + _dim: usize, ) -> (Self::Primitive, ::IntTensorPrimitive) { todo!() } @@ -418,8 +418,8 @@ impl Numeric for Sparse { } fn min_dim_with_indices( - tensor: Self::Primitive, - dim: usize, + _tensor: Self::Primitive, + _dim: usize, ) -> (Self::Primitive, ::IntTensorPrimitive) { todo!() } @@ -479,25 +479,25 @@ impl Numeric for Sparse { } fn random( - shape: Shape, - distribution: burn_tensor::Distribution, - device: &::Device, + _shape: Shape, + _distribution: burn_tensor::Distribution, + _device: &::Device, ) -> Self::Primitive { panic!("Random is unsupported for sparse tensors") } fn sort( - tensor: Self::Primitive, - dim: usize, - descending: bool, + _tensor: Self::Primitive, + _dim: usize, + _descending: bool, ) -> Self::Primitive { panic!("Sorting is unsupported for sparse tensors") } fn sort_with_indices( - tensor: Self::Primitive, - dim: usize, - descending: bool, + _tensor: Self::Primitive, + _dim: usize, + _descending: bool, ) -> ( Self::Primitive, >::Primitive, @@ -506,9 +506,9 @@ impl Numeric for Sparse { } fn argsort( - tensor: Self::Primitive, - dim: usize, - descending: bool, + _tensor: Self::Primitive, + _dim: usize, + _descending: bool, ) -> >::Primitive { panic!("Sorting is unsupported for sparse tensors") } diff --git 
a/crates/burn-sparse/src/decorator/sparse_coo.rs b/crates/burn-sparse/src/decorator/sparse_coo.rs index 9e132b3683..bb5056fd30 100644 --- a/crates/burn-sparse/src/decorator/sparse_coo.rs +++ b/crates/burn-sparse/src/decorator/sparse_coo.rs @@ -5,7 +5,7 @@ use crate::decorator::SparseDecorator; use burn_tensor::ops::FloatElem; use burn_tensor::ops::FloatTensor; use burn_tensor::ops::FloatTensorOps; -use burn_tensor::ops::IntTensorOps; + use burn_tensor::Device; use burn_tensor::{ backend::Backend, Bool, ElementConversion, Float, Int, Shape, Tensor, TensorData, @@ -58,9 +58,8 @@ fn unflatten_coordinates( } new_coordinates.reverse(); - let reshaped_coordinates = Tensor::stack(new_coordinates, 0); - reshaped_coordinates + Tensor::stack(new_coordinates, 0) } impl SparseBackend for SparseDecorator @@ -381,7 +380,7 @@ where let nonzero = nonzero.nonzero(); let indices_dim1 = nonzero - .get(0) + .first() .cloned() .expect("Expected dimension to exist"); @@ -403,7 +402,7 @@ where data: TensorData, device: &burn_tensor::Device, ) -> SparseTensor { - let dense = B::float_from_data(data, &device); + let dense = B::float_from_data(data, device); Self::sparse_to_sparse(dense) } @@ -629,23 +628,23 @@ where } fn sparse_cat( - tensors: Vec>, - dim: usize, + _tensors: Vec>, + _dim: usize, ) -> SparseTensor { - let mut offset = 0; + let _offset = 0; todo!() } fn sparse_equal( - lhs: SparseTensor, - rhs: SparseTensor, + _lhs: SparseTensor, + _rhs: SparseTensor, ) -> burn_tensor::ops::BoolTensor { panic!("elementwise equal is unsupported for SparseCOO until scatter supports other reduction methods"); } fn sparse_not_equal( - lhs: SparseTensor, - rhs: SparseTensor, + _lhs: SparseTensor, + _rhs: SparseTensor, ) -> burn_tensor::ops::BoolTensor { panic!("elementwise not_equal is unsupported for SparseCOO until scatter supports other reduction methods"); } @@ -655,17 +654,17 @@ where ) -> burn_tensor::ops::BoolTensor { let SparseCOOTensor { coordinates, - values, - shape, - device, + values: _, + shape: _, + device: _, } = tensor; - let any = !matches!(coordinates, None); + let any = coordinates.is_some(); Tensor::::from([any]).into_primitive() } fn sparse_any_dim( - tensor: SparseTensor, - dim: usize, + _tensor: SparseTensor, + _dim: usize, ) -> burn_tensor::ops::BoolTensor { panic!("any_dim is unsupported for the SparseCOO Decorator due to performance issues, convert to dense explicitly to ensure you understand"); } @@ -675,9 +674,9 @@ where ) -> burn_tensor::ops::BoolTensor { let SparseCOOTensor { coordinates, - values, + values: _, shape, - device, + device: _, } = tensor; let all = match coordinates { Some(coordinates) => shape.num_elements() == coordinates.shape().dims[1], @@ -687,21 +686,21 @@ where } fn sparse_all_dim( - tensor: SparseTensor, - dim: usize, + _tensor: SparseTensor, + _dim: usize, ) -> burn_tensor::ops::BoolTensor { panic!("all_dim is unsupported for the SparseCOO Decorator due to performance issues, convert to dense explicitly to ensure you understand"); } fn sparse_expand( tensor: SparseTensor, - shape: Shape, + _shape: Shape, ) -> SparseTensor { let SparseCOOTensor { - coordinates, - values, - shape, - device, + coordinates: _, + values: _, + shape: _, + device: _, } = tensor; todo!() } @@ -734,7 +733,7 @@ where let coordinates = flatten_coordinates::(coordinates, original_shape.clone(), &device); - let flat_shape = Shape::new([original_shape.num_elements()]); + let _flat_shape = Shape::new([original_shape.num_elements()]); let (coordinates, indices) = coordinates.sort_with_indices(1); let 
values = values.select(0, indices.squeeze(0)); @@ -748,8 +747,7 @@ where let diff = Tensor::cat(vec![ones, diff], 1); // TODO this all would be way cleaner with cumsum/max, but that is waiting on a pull request as of writing - // this is technically O(nnz) but only in super rare and likely constructed cases - // lots of inspiration could be taken from pytorch_scatter for better implementations + // inspiration could be taken from pytorch_scatter for better implementations let unique_mask = diff.not_equal_elem(0); let unique_indices = unique_mask.clone().nonzero().remove(1); let steps = Tensor::cat( @@ -921,15 +919,15 @@ where } fn sparse_sub_scalar( - lhs: SparseTensor, - rhs: FloatElem, + _lhs: SparseTensor, + _rhs: FloatElem, ) -> SparseTensor { panic!("Cannot subtract scalar from sparse, only zero preserving operations are permitted"); } fn sparse_mul( - lhs: SparseTensor, - rhs: SparseTensor, + _lhs: SparseTensor, + _rhs: SparseTensor, ) -> SparseTensor { panic!("sparse_mul is unsupported until scatter supports multiplication based reduction"); } @@ -991,26 +989,26 @@ where lhs } - fn sparse_max(tensor: SparseTensor) -> SparseTensor { + fn sparse_max(_tensor: SparseTensor) -> SparseTensor { panic!("max is unsupported for SparseCOO until scatter supports other reduction methods"); } fn sparse_max_dim( - tensor: SparseTensor, - dim: usize, + _tensor: SparseTensor, + _dim: usize, ) -> SparseTensor { panic!( "max_dim is unsupported for SparseCOO until scatter supports other reduction methods" ); } - fn sparse_min(tensor: SparseTensor) -> SparseTensor { + fn sparse_min(_tensor: SparseTensor) -> SparseTensor { panic!("min is unsupported for SparseCOO until scatter supports other reduction methods"); } fn sparse_min_dim( - tensor: SparseTensor, - dim: usize, + _tensor: SparseTensor, + _dim: usize, ) -> SparseTensor { panic!( "min_dim is unsupported for SparseCOO until scatter supports other reduction methods" ); } fn sparse_greater( - lhs: SparseTensor, - rhs: SparseTensor, + _lhs: SparseTensor, + _rhs: SparseTensor, ) -> burn_tensor::ops::BoolTensor { panic!("sparse_greater is not supported for SparseCOO as it outputs a dense tensor"); } fn sparse_greater_elem( - lhs: SparseTensor, - rhs: FloatElem, + _lhs: SparseTensor, + _rhs: FloatElem, ) -> burn_tensor::ops::BoolTensor { panic!("sparse_greater_elem is not supported for SparseCOO as it outputs a dense tensor"); } fn sparse_greater_equal( - lhs: SparseTensor, - rhs: SparseTensor, + _lhs: SparseTensor, + _rhs: SparseTensor, ) -> burn_tensor::ops::BoolTensor { panic!("sparse_greater_equal is not supported for SparseCOO as it outputs a dense tensor"); } fn sparse_greater_equal_elem( - lhs: SparseTensor, - rhs: FloatElem, + _lhs: SparseTensor, + _rhs: FloatElem, ) -> burn_tensor::ops::BoolTensor { panic!( "sparse_greater_equal_elem is not supported for SparseCOO as it outputs a dense tensor" ); } fn sparse_lower( - lhs: SparseTensor, - rhs: SparseTensor, + _lhs: SparseTensor, + _rhs: SparseTensor, ) -> burn_tensor::ops::BoolTensor { panic!("sparse_lower is not supported for SparseCOO as it outputs a dense tensor"); } fn sparse_lower_elem( - lhs: SparseTensor, - rhs: FloatElem, + _lhs: SparseTensor, + _rhs: FloatElem, ) -> burn_tensor::ops::BoolTensor { panic!("sparse_lower_elem is not supported for SparseCOO as it outputs a dense tensor"); } fn sparse_lower_equal( - lhs: SparseTensor, - rhs: SparseTensor, + _lhs: SparseTensor, + _rhs: SparseTensor, ) -> burn_tensor::ops::BoolTensor { 
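// Note on the comparison panics in this impl: none of these operators are
// zero-preserving. Every unstored position of a sparse tensor is an
// implicit 0.0, so a comparison such as `lower_equal` would be true at
// almost every position and the natural output is dense. A caller that
// really wants the mask can densify explicitly, so the cost stays visible
// at the call site; a hedged sketch using this crate's `dense()` helper:
//
//     let mask = sparse_tensor.dense().lower_equal_elem(0.0);
//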
panic!("sparse_lower_equal is not supported for SparseCOO as it outputs a dense tensor"); } fn sparse_lower_equal_elem( - lhs: SparseTensor, - rhs: FloatElem, + _lhs: SparseTensor, + _rhs: FloatElem, ) -> burn_tensor::ops::BoolTensor { panic!( "sparse_lower_equal_elem is not supported for SparseCOO as it outputs a dense tensor" @@ -1083,15 +1081,15 @@ where } fn sparse_powf( - lhs: SparseTensor, - rhs: SparseTensor, + _lhs: SparseTensor, + _rhs: SparseTensor, ) -> SparseTensor { panic!("sparse_powf is unsupported for SparseCOO until scatter supports other reduction methods"); } fn sparse_powi( - lhs: SparseTensor, - rhs: SparseTensor, + _lhs: SparseTensor, + _rhs: SparseTensor, ) -> SparseTensor { panic!("sparse_powi is unsupported for SparseCOO until scatter supports other reduction methods"); } @@ -1138,7 +1136,7 @@ where } fn sparse_select( - mut tensor: SparseTensor, + tensor: SparseTensor, dim: usize, indices: burn_tensor::ops::IntTensor, ) -> SparseTensor { @@ -1179,27 +1177,27 @@ where } fn sparse_select_assign( - tensor: SparseTensor, - dim: usize, - indices: burn_tensor::ops::IntTensor, - values: SparseTensor, + _tensor: SparseTensor, + _dim: usize, + _indices: burn_tensor::ops::IntTensor, + _values: SparseTensor, ) -> SparseTensor { todo!() } fn sparse_gather( - dim: usize, - tensor: SparseTensor, - indices: burn_tensor::ops::IntTensor, + _dim: usize, + _tensor: SparseTensor, + _indices: burn_tensor::ops::IntTensor, ) -> SparseTensor { todo!() } fn sparse_scatter( - dim: usize, - tensor: SparseTensor, - indices: burn_tensor::ops::IntTensor, - values: SparseTensor, + _dim: usize, + _tensor: SparseTensor, + _indices: burn_tensor::ops::IntTensor, + _values: SparseTensor, ) -> SparseTensor { todo!() } @@ -1212,8 +1210,8 @@ where } fn sparse_sum_dim( - tensor: SparseTensor, - dim: usize, + _tensor: SparseTensor, + _dim: usize, ) -> SparseTensor { panic!("sparse_sum_dim unsupported for SparseCOO"); } @@ -1240,8 +1238,8 @@ where } fn sparse_prod_dim( - tensor: SparseTensor, - dim: usize, + _tensor: SparseTensor, + _dim: usize, ) -> SparseTensor { panic!("sparse_prod_dim is not supported for SparseCOO until scatter supports product reduction") } @@ -1257,22 +1255,22 @@ where } fn sparse_mean_dim( - tensor: SparseTensor, - dim: usize, + _tensor: SparseTensor, + _dim: usize, ) -> SparseTensor { panic!("mean_dim is not supported for SparseCOO until scatter supports mean reduction"); } fn sparse_equal_elem( - lhs: SparseTensor, - rhs: FloatElem, + _lhs: SparseTensor, + _rhs: FloatElem, ) -> burn_tensor::ops::BoolTensor { panic!("sparse_equal_elem is not supported for SparseCOO as it outputs a dense tensor"); } fn sparse_not_equal_elem( - lhs: SparseTensor, - rhs: FloatElem, + _lhs: SparseTensor, + _rhs: FloatElem, ) -> burn_tensor::ops::BoolTensor { panic!("sparse_not_equal_elem is not supported for SparseCOO as it outputs a dense tensor"); } @@ -1302,14 +1300,14 @@ where return tensor; } - let coordinates = tensor + let _coordinates = tensor .coordinates .expect("Mismatch between coordinates and values"); - let values = tensor + let _values = tensor .values .expect("Mismatch between coordinates and values"); - let device = tensor.device; - let mut shape = tensor.shape; + let _device = tensor.device; + let _shape = tensor.shape; // let zeros = tensor.values.map(|values| values.equal_elem(0).nonzero()); todo!() diff --git a/crates/burn-sparse/src/decorator/sparse_csr.rs b/crates/burn-sparse/src/decorator/sparse_csr.rs index 4823b81774..04826fc10f 100644 --- 
a/crates/burn-sparse/src/decorator/sparse_csr.rs +++ b/crates/burn-sparse/src/decorator/sparse_csr.rs @@ -19,505 +19,503 @@ where type SparseTensorPrimitive = SparseCSRTensor; fn sparse_empty( - shape: burn_tensor::Shape, - device: &burn_tensor::Device, + _shape: burn_tensor::Shape, + _device: &burn_tensor::Device, ) -> SparseTensor { todo!() } fn sparse_to_sparse( - dense: Self::FloatTensorPrimitive, + _dense: Self::FloatTensorPrimitive, ) -> Self::SparseTensorPrimitive { todo!() } fn sparse_to_dense( - sparse: Self::SparseTensorPrimitive, + _sparse: Self::SparseTensorPrimitive, ) -> Self::FloatTensorPrimitive { todo!() } fn sparse_spmm( - lhs: Self::SparseTensorPrimitive, - rhs: Self::FloatTensorPrimitive, + _lhs: Self::SparseTensorPrimitive, + _rhs: Self::FloatTensorPrimitive, ) -> Self::FloatTensorPrimitive { todo!() } fn sparse_sddmm( - lhs: Self::FloatTensorPrimitive, - rhs: Self::FloatTensorPrimitive, - sparse: Self::SparseTensorPrimitive, + _lhs: Self::FloatTensorPrimitive, + _rhs: Self::FloatTensorPrimitive, + _sparse: Self::SparseTensorPrimitive, ) -> Self::SparseTensorPrimitive { todo!() } fn sparse_slice( - tensor: SparseTensor, - indices: [std::ops::Range; D2], + _tensor: SparseTensor, + _indices: [std::ops::Range; D2], ) -> SparseTensor { todo!() } - fn sparse_device(tensor: &SparseTensor) -> burn_tensor::Device { + fn sparse_device(_tensor: &SparseTensor) -> burn_tensor::Device { todo!() } fn sparse_to_device( - tensor: SparseTensor, - device: &burn_tensor::Device, + _tensor: SparseTensor, + _device: &burn_tensor::Device, ) -> SparseTensor { todo!() } - fn sparse_shape(tensor: &SparseTensor) -> burn_tensor::Shape { + fn sparse_shape(_tensor: &SparseTensor) -> burn_tensor::Shape { todo!() } - fn sparse_into_data( - tensor: SparseTensor, - ) -> impl std::future::Future + Send { - async { todo!() } - } + async fn sparse_into_data( + _tensor: SparseTensor, + ) -> burn_tensor::TensorData { todo!() } fn sparse_from_data( - data: burn_tensor::TensorData, - device: &burn_tensor::Device, + _data: burn_tensor::TensorData, + _device: &burn_tensor::Device, ) -> SparseTensor { todo!() } fn sparse_reshape( - tensor: SparseTensor, - shape: burn_tensor::Shape, + _tensor: SparseTensor, + _shape: burn_tensor::Shape, ) -> SparseTensor { todo!() } - fn sparse_transpose(tensor: SparseTensor) -> SparseTensor { + fn sparse_transpose(_tensor: SparseTensor) -> SparseTensor { todo!() } fn sparse_swap_dims( - tensor: SparseTensor, - dim1: usize, - dim2: usize, + _tensor: SparseTensor, + _dim1: usize, + _dim2: usize, ) -> SparseTensor { todo!() } fn sparse_permute( - tensor: SparseTensor, - axes: &[usize], + _tensor: SparseTensor, + _axes: &[usize], ) -> SparseTensor { todo!() } fn sparse_flip( - tensor: SparseTensor, - axes: &[usize], + _tensor: SparseTensor, + _axes: &[usize], ) -> SparseTensor { todo!() } fn sparse_slice_assign( - tensor: SparseTensor, - ranges: [std::ops::Range; D2], - value: SparseTensor, + _tensor: SparseTensor, + _ranges: [std::ops::Range; D2], + _value: SparseTensor, ) -> SparseTensor { todo!() } fn sparse_repeat( - tensor: SparseTensor, - dim: usize, - times: usize, + _tensor: SparseTensor, + _dim: usize, + _times: usize, ) -> SparseTensor { todo!() } fn sparse_cat( - tensors: Vec>, - dim: usize, + _tensors: Vec>, + _dim: usize, ) -> SparseTensor { todo!() } fn sparse_equal( - lhs: SparseTensor, - rhs: SparseTensor, + _lhs: SparseTensor, + _rhs: SparseTensor, ) -> burn_tensor::ops::BoolTensor { todo!() } fn sparse_not_equal( - lhs: SparseTensor, - rhs: SparseTensor, + _lhs: 
SparseTensor, + _rhs: SparseTensor, ) -> burn_tensor::ops::BoolTensor { todo!() } fn sparse_any( - tensor: SparseTensor, + _tensor: SparseTensor, ) -> burn_tensor::ops::BoolTensor { todo!() } fn sparse_any_dim( - tensor: SparseTensor, - dim: usize, + _tensor: SparseTensor, + _dim: usize, ) -> burn_tensor::ops::BoolTensor { todo!() } fn sparse_all( - tensor: SparseTensor, + _tensor: SparseTensor, ) -> burn_tensor::ops::BoolTensor { todo!() } fn sparse_all_dim( - tensor: SparseTensor, - dim: usize, + _tensor: SparseTensor, + _dim: usize, ) -> burn_tensor::ops::BoolTensor { todo!() } fn sparse_expand( - tensor: SparseTensor, - shape: burn_tensor::Shape, + _tensor: SparseTensor, + _shape: burn_tensor::Shape, ) -> SparseTensor { todo!() } fn sparse_coalesce_sum( - tensor: Self::SparseTensorPrimitive, + _tensor: Self::SparseTensorPrimitive, ) -> Self::SparseTensorPrimitive { todo!() } - fn sparse_nonzero(tensor: Self::SparseTensorPrimitive) -> usize { + fn sparse_nonzero(_tensor: Self::SparseTensorPrimitive) -> usize { todo!() } - fn sparse_density(sparse: Self::SparseTensorPrimitive) -> f32 { + fn sparse_density(_sparse: Self::SparseTensorPrimitive) -> f32 { todo!() } fn sparse_add( - lhs: SparseTensor, - rhs: SparseTensor, + _lhs: SparseTensor, + _rhs: SparseTensor, ) -> SparseTensor { todo!() } fn sparse_add_scalar( - lhs: SparseTensor, - rhs: FloatElem, + _lhs: SparseTensor, + _rhs: FloatElem, ) -> SparseTensor { todo!() } fn sparse_add_dense( - lhs: SparseTensor, - rhs: burn_tensor::ops::FloatTensor, + _lhs: SparseTensor, + _rhs: burn_tensor::ops::FloatTensor, ) -> FloatTensor { todo!() } fn sparse_sub( - lhs: SparseTensor, - rhs: SparseTensor, + _lhs: SparseTensor, + _rhs: SparseTensor, ) -> SparseTensor { todo!() } fn sparse_sub_dense( - lhs: SparseTensor, - rhs: burn_tensor::ops::FloatTensor, + _lhs: SparseTensor, + _rhs: burn_tensor::ops::FloatTensor, ) -> FloatTensor { todo!() } fn sparse_sub_scalar( - lhs: SparseTensor, - rhs: FloatElem, + _lhs: SparseTensor, + _rhs: FloatElem, ) -> SparseTensor { todo!() } fn sparse_mul( - lhs: SparseTensor, - rhs: SparseTensor, + _lhs: SparseTensor, + _rhs: SparseTensor, ) -> SparseTensor { todo!() } fn sparse_mul_dense( - lhs: SparseTensor, - rhs: burn_tensor::ops::FloatTensor, + _lhs: SparseTensor, + _rhs: burn_tensor::ops::FloatTensor, ) -> FloatTensor { todo!() } fn sparse_mul_scalar( - lhs: SparseTensor, - rhs: FloatElem, + _lhs: SparseTensor, + _rhs: FloatElem, ) -> SparseTensor { todo!() } fn sparse_div( - lhs: SparseTensor, - rhs: SparseTensor, + _lhs: SparseTensor, + _rhs: SparseTensor, ) -> SparseTensor { todo!() } fn sparse_div_dense( - lhs: SparseTensor, - rhs: burn_tensor::ops::FloatTensor, + _lhs: SparseTensor, + _rhs: burn_tensor::ops::FloatTensor, ) -> FloatTensor { todo!() } fn sparse_div_scalar( - lhs: SparseTensor, - rhs: FloatElem, + _lhs: SparseTensor, + _rhs: FloatElem, ) -> SparseTensor { todo!() } - fn sparse_max(tensor: SparseTensor) -> SparseTensor { + fn sparse_max(_tensor: SparseTensor) -> SparseTensor { todo!() } fn sparse_max_dim( - tensor: SparseTensor, - dim: usize, + _tensor: SparseTensor, + _dim: usize, ) -> SparseTensor { todo!() } - fn sparse_min(tensor: SparseTensor) -> SparseTensor { + fn sparse_min(_tensor: SparseTensor) -> SparseTensor { todo!() } fn sparse_min_dim( - tensor: SparseTensor, - dim: usize, + _tensor: SparseTensor, + _dim: usize, ) -> SparseTensor { todo!() } fn sparse_greater( - lhs: SparseTensor, - rhs: SparseTensor, + _lhs: SparseTensor, + _rhs: SparseTensor, ) -> burn_tensor::ops::BoolTensor { 
todo!() } fn sparse_greater_elem( - lhs: SparseTensor, - rhs: FloatElem, + _lhs: SparseTensor, + _rhs: FloatElem, ) -> burn_tensor::ops::BoolTensor { todo!() } fn sparse_greater_equal( - lhs: SparseTensor, - rhs: SparseTensor, + _lhs: SparseTensor, + _rhs: SparseTensor, ) -> burn_tensor::ops::BoolTensor { todo!() } fn sparse_greater_equal_elem( - lhs: SparseTensor, - rhs: FloatElem, + _lhs: SparseTensor, + _rhs: FloatElem, ) -> burn_tensor::ops::BoolTensor { todo!() } fn sparse_lower( - lhs: SparseTensor, - rhs: SparseTensor, + _lhs: SparseTensor, + _rhs: SparseTensor, ) -> burn_tensor::ops::BoolTensor { todo!() } fn sparse_lower_elem( - lhs: SparseTensor, - rhs: FloatElem, + _lhs: SparseTensor, + _rhs: FloatElem, ) -> burn_tensor::ops::BoolTensor { todo!() } fn sparse_lower_equal( - lhs: SparseTensor, - rhs: SparseTensor, + _lhs: SparseTensor, + _rhs: SparseTensor, ) -> burn_tensor::ops::BoolTensor { todo!() } fn sparse_lower_equal_elem( - lhs: SparseTensor, - rhs: FloatElem, + _lhs: SparseTensor, + _rhs: FloatElem, ) -> burn_tensor::ops::BoolTensor { todo!() } - fn sparse_abs(tensor: SparseTensor) -> SparseTensor { + fn sparse_abs(_tensor: SparseTensor) -> SparseTensor { todo!() } fn sparse_powf( - lhs: SparseTensor, - rhs: SparseTensor, + _lhs: SparseTensor, + _rhs: SparseTensor, ) -> SparseTensor { todo!() } fn sparse_powi( - lhs: SparseTensor, - rhs: SparseTensor, + _lhs: SparseTensor, + _rhs: SparseTensor, ) -> SparseTensor { todo!() } fn sparse_powf_scalar( - lhs: SparseTensor, - rhs: FloatElem, + _lhs: SparseTensor, + _rhs: FloatElem, ) -> SparseTensor { todo!() } fn sparse_powi_scalar( - lhs: SparseTensor, - rhs: FloatElem, + _lhs: SparseTensor, + _rhs: FloatElem, ) -> SparseTensor { todo!() } fn sparse_clamp( - tensor: SparseTensor, - min: FloatElem, - max: FloatElem, + _tensor: SparseTensor, + _min: FloatElem, + _max: FloatElem, ) -> SparseTensor { todo!() } fn sparse_clamp_min( - tensor: SparseTensor, - min: FloatElem, + _tensor: SparseTensor, + _min: FloatElem, ) -> SparseTensor { todo!() } fn sparse_clamp_max( - tensor: SparseTensor, - max: FloatElem, + _tensor: SparseTensor, + _max: FloatElem, ) -> SparseTensor { todo!() } fn sparse_select( - tensor: SparseTensor, - dim: usize, - indices: burn_tensor::ops::IntTensor, + _tensor: SparseTensor, + _dim: usize, + _indices: burn_tensor::ops::IntTensor, ) -> SparseTensor { todo!() } fn sparse_select_assign( - tensor: SparseTensor, - dim: usize, - indices: burn_tensor::ops::IntTensor, - values: SparseTensor, + _tensor: SparseTensor, + _dim: usize, + _indices: burn_tensor::ops::IntTensor, + _values: SparseTensor, ) -> SparseTensor { todo!() } fn sparse_gather( - dim: usize, - tensor: SparseTensor, - indices: burn_tensor::ops::IntTensor, + _dim: usize, + _tensor: SparseTensor, + _indices: burn_tensor::ops::IntTensor, ) -> SparseTensor { todo!() } fn sparse_scatter( - dim: usize, - tensor: SparseTensor, - indices: burn_tensor::ops::IntTensor, - values: SparseTensor, + _dim: usize, + _tensor: SparseTensor, + _indices: burn_tensor::ops::IntTensor, + _values: SparseTensor, ) -> SparseTensor { todo!() } - fn sparse_sum(tensor: SparseTensor) -> SparseTensor { + fn sparse_sum(_tensor: SparseTensor) -> SparseTensor { todo!() } fn sparse_sum_dim( - tensor: SparseTensor, - dim: usize, + _tensor: SparseTensor, + _dim: usize, ) -> SparseTensor { todo!() } - fn sparse_prod(tensor: SparseTensor) -> SparseTensor { + fn sparse_prod(_tensor: SparseTensor) -> SparseTensor { todo!() } fn sparse_prod_dim( - tensor: SparseTensor, - dim: usize, + _tensor: 
SparseTensor, + _dim: usize, ) -> SparseTensor { todo!() } - fn sparse_mean(tensor: SparseTensor) -> SparseTensor { + fn sparse_mean(_tensor: SparseTensor) -> SparseTensor { todo!() } fn sparse_mean_dim( - tensor: SparseTensor, - dim: usize, + _tensor: SparseTensor, + _dim: usize, ) -> SparseTensor { todo!() } fn sparse_equal_elem( - lhs: SparseTensor, - rhs: FloatElem, + _lhs: SparseTensor, + _rhs: FloatElem, ) -> burn_tensor::ops::BoolTensor { todo!() } fn sparse_not_equal_elem( - lhs: SparseTensor, - rhs: FloatElem, + _lhs: SparseTensor, + _rhs: FloatElem, ) -> burn_tensor::ops::BoolTensor { todo!() } fn sparse_remainder_scalar( - lhs: SparseTensor, - rhs: FloatElem, + _lhs: SparseTensor, + _rhs: FloatElem, ) -> SparseTensor { todo!() } - fn sparse_neg(tensor: SparseTensor) -> SparseTensor { + fn sparse_neg(_tensor: SparseTensor) -> SparseTensor { todo!() } - fn sparse_sign(tensor: SparseTensor) -> SparseTensor { + fn sparse_sign(_tensor: SparseTensor) -> SparseTensor { todo!() } fn sparse_remove_zeros( - tensor: Self::SparseTensorPrimitive, + _tensor: Self::SparseTensorPrimitive, ) -> Self::SparseTensorPrimitive { todo!() } From b63ea7ac6c974077dad64c3aacde296960b00e1b Mon Sep 17 00:00:00 2001 From: mcarthur Date: Sun, 11 Aug 2024 14:34:56 +0000 Subject: [PATCH 21/38] Refactor of tensor API in progress --- Cargo.lock | 2 +- crates/burn-sparse/src/backend/kind.rs | 10 +- crates/burn-sparse/src/decorator/ops.rs | 56 +- crates/burn-sparse/src/lib.rs | 4 +- crates/burn-tensor/src/tensor/api/base.rs | 9 +- crates/burn-tensor/src/tensor/api/kind.rs | 28 +- crates/burn-tensor/src/tensor/api/mod.rs | 4 + crates/burn-tensor/src/tensor/api/repr.rs | 20 + crates/burn-tensor/src/tensor/api/sparse.rs | 152 +++++ crates/burn-tensor/src/tensor/ops/mod.rs | 2 + .../src/tensor/ops/sparse_tensor.rs | 568 ++++++++++++++++++ 11 files changed, 821 insertions(+), 34 deletions(-) create mode 100644 crates/burn-tensor/src/tensor/api/repr.rs create mode 100644 crates/burn-tensor/src/tensor/api/sparse.rs create mode 100644 crates/burn-tensor/src/tensor/ops/sparse_tensor.rs diff --git a/Cargo.lock b/Cargo.lock index af85780c14..c4227fa262 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -703,7 +703,7 @@ dependencies = [ "rand", "rand_distr", "serde", - "syn 2.0.69", + "syn 2.0.72", ] [[package]] diff --git a/crates/burn-sparse/src/backend/kind.rs b/crates/burn-sparse/src/backend/kind.rs index 7efd5770f5..805b7d0dbd 100644 --- a/crates/burn-sparse/src/backend/kind.rs +++ b/crates/burn-sparse/src/backend/kind.rs @@ -3,18 +3,20 @@ use std::{future::Future, ops::Range}; use crate::backend::SparseBackend; use burn_tensor::{backend::Backend, BasicOps, Numeric, Shape, Tensor, TensorData, TensorKind}; +pub trait SparseRepr {} + /// A type-level representation of the kind of a sparse (float) tensor. 
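///
/// # Example
///
/// A minimal sketch, assuming a backend `B` implementing `SparseBackend`; the
/// exact generic parameters here are assumptions, not spelled out by the
/// surrounding hunk:
///
/// ```ignore
/// // A 2D sparse float tensor; its primitive is the backend's sparse primitive.
/// let tensor: Tensor<B, 2, Sparse> = Tensor::from_primitive(primitive);
/// ```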
#[derive(Clone, Debug)] -pub struct Sparse; +pub struct Sparse; -impl TensorKind for Sparse { +impl TensorKind for Sparse { type Primitive = B::SparseTensorPrimitive; fn name() -> &'static str { "Sparse" } } -impl BasicOps for Sparse { +impl BasicOps for Sparse { type Elem = B::FloatElem; fn into_data_async( @@ -94,7 +96,7 @@ impl BasicOps for Sparse { B::sparse_slice_assign(tensor, ranges, value) } - fn repeat( + fn repeat_dim( tensor: Self::Primitive, dim: usize, times: usize, diff --git a/crates/burn-sparse/src/decorator/ops.rs b/crates/burn-sparse/src/decorator/ops.rs index 1ec03ce3cd..7994ef94c0 100644 --- a/crates/burn-sparse/src/decorator/ops.rs +++ b/crates/burn-sparse/src/decorator/ops.rs @@ -647,14 +647,6 @@ where B::int_select_assign(tensor, dim, indices, value) } - fn int_repeat( - tensor: IntTensor, D>, - dim: usize, - times: usize, - ) -> IntTensor, D> { - B::int_repeat(tensor, dim, times) - } - fn int_cat( tensors: Vec, D>>, dim: usize, @@ -975,20 +967,6 @@ where B: Backend, R: SparseRepresentation, { - fn quantize( - tensor: FloatTensor, - strategy: &burn_tensor::QuantizationStrategy, - ) -> burn_tensor::ops::QuantizedTensor { - B::quantize(tensor, strategy) - } - - fn dequantize( - tensor: burn_tensor::ops::QuantizedTensor, - strategy: &burn_tensor::QuantizationStrategy, - ) -> FloatTensor { - B::dequantize(tensor, strategy) - } - fn q_shape(tensor: &burn_tensor::ops::QuantizedTensor) -> Shape { B::q_shape(tensor) } @@ -996,6 +974,40 @@ where fn q_device(tensor: &burn_tensor::ops::QuantizedTensor) -> Device { B::q_device(tensor) } + + fn q_from_data( + data: TensorData, + device: &Device>, + ) -> burn_tensor::ops::QuantizedTensor, D> { + B::q_from_data(data, device) + } + + fn q_reshape( + tensor: burn_tensor::ops::QuantizedTensor, D1>, + shape: Shape, + ) -> burn_tensor::ops::QuantizedTensor, D2> { + B::q_reshape(tensor, shape) + } + + fn q_into_data( + tensor: burn_tensor::ops::QuantizedTensor, D>, + ) -> impl std::future::Future + Send { + B::q_into_data(tensor) + } + + fn quantize( + tensor: FloatTensor, D>, + scheme: &burn_tensor::quantization::QuantizationScheme, + qparams: burn_tensor::quantization::QuantizationParametersPrimitive>, + ) -> burn_tensor::ops::QuantizedTensor, D> { + B::quantize(tensor, scheme, qparams) + } + + fn dequantize( + tensor: burn_tensor::ops::QuantizedTensor, D>, + ) -> FloatTensor, D> { + B::dequantize(tensor) + } } impl ModuleOps> for SparseDecorator diff --git a/crates/burn-sparse/src/lib.rs b/crates/burn-sparse/src/lib.rs index 9dfe4c9dce..b5ed4799d3 100644 --- a/crates/burn-sparse/src/lib.rs +++ b/crates/burn-sparse/src/lib.rs @@ -1,2 +1,2 @@ -pub mod backend; -pub mod decorator; +// pub mod backend; +// pub mod decorator; diff --git a/crates/burn-tensor/src/tensor/api/base.rs b/crates/burn-tensor/src/tensor/api/base.rs index 1201fcda27..e6ab89ce9f 100644 --- a/crates/burn-tensor/src/tensor/api/base.rs +++ b/crates/burn-tensor/src/tensor/api/base.rs @@ -18,14 +18,15 @@ use crate::check::TensorCheck; use crate::tensor::api::chunk::chunk; use crate::tensor::api::narrow::narrow; use crate::{backend::Backend, check, Bool, Float, Int, Shape, TensorData, TensorKind}; -use crate::{DType, Element, TensorPrimitive}; +use crate::{DType, Dense, Element, TensorPrimitive, TensorRepr}; /// A tensor with a given backend, shape and data type. 
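///
/// An illustrative sketch of the parameter added below, assuming `Dense`
/// remains the default representation (as the `BasicOps<B, R = Dense>` bound
/// later in this patch suggests): existing spellings such as `Tensor<B, 2>`
/// keep working, while a sparse tensor names its representation explicitly.
///
/// ```ignore
/// // Hypothetical: `COO` stands in for some type implementing `SparseRepr<B>`.
/// type SparseFloat2<B, COO> = Tensor<B, 2, Float, Sparse<COO, B>>;
/// ```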
#[derive(new, Clone, Debug)] -pub struct Tensor +pub struct Tensor where B: Backend, - K: TensorKind, + K: TensorKind, + R: TensorRepr, { pub(crate) primitive: K::Primitive, } @@ -1278,7 +1279,7 @@ impl core::ops::BitXor for Tensor { /// # Warnings /// /// This is an internal trait, use the public API provided by [tensor struct](Tensor). -pub trait BasicOps: TensorKind { +pub trait BasicOps = Dense>: TensorKind { /// The type of the tensor elements. type Elem: Element; diff --git a/crates/burn-tensor/src/tensor/api/kind.rs b/crates/burn-tensor/src/tensor/api/kind.rs index 7afe1d2c36..4e50b3f1e5 100644 --- a/crates/burn-tensor/src/tensor/api/kind.rs +++ b/crates/burn-tensor/src/tensor/api/kind.rs @@ -1,4 +1,6 @@ use crate::backend::Backend; +use crate::{Dense, Sparse, SparseRepr, TensorRepr}; +use core::marker::PhantomData; /// A type-level representation of the kind of a float tensor #[derive(Clone, Debug)] @@ -32,7 +34,7 @@ impl TensorPrimitive { } /// A type-level representation of the kind of a tensor. -pub trait TensorKind: Clone + core::fmt::Debug { +pub trait TensorKind = Dense>: Clone + core::fmt::Debug { /// The primitive type of the tensor. type Primitive: Clone + core::fmt::Debug + Send; @@ -47,6 +49,14 @@ impl TensorKind for Float { } } +impl> TensorKind> for Float { + type Primitive = R::FloatPrimitive; + + fn name() -> &'static str { + "SparseFloat" + } +} + impl TensorKind for Int { type Primitive = B::IntTensorPrimitive; fn name() -> &'static str { @@ -54,9 +64,25 @@ impl TensorKind for Int { } } +impl> TensorKind> for Int { + type Primitive = R::IntPrimitive; + + fn name() -> &'static str { + "SparseInt" + } +} + impl TensorKind for Bool { type Primitive = B::BoolTensorPrimitive; fn name() -> &'static str { "Bool" } } + +impl> TensorKind> for Bool { + type Primitive = R::BoolPrimitive; + + fn name() -> &'static str { + "SparseBool" + } +} diff --git a/crates/burn-tensor/src/tensor/api/mod.rs b/crates/burn-tensor/src/tensor/api/mod.rs index 60272d80bd..aaef17e0aa 100644 --- a/crates/burn-tensor/src/tensor/api/mod.rs +++ b/crates/burn-tensor/src/tensor/api/mod.rs @@ -11,7 +11,9 @@ mod int; mod kind; mod narrow; mod numeric; +mod repr; mod sort; +mod sparse; pub use argwhere::argwhere_data; pub use autodiff::*; @@ -21,4 +23,6 @@ pub use chunk::chunk; pub use kind::*; pub use narrow::narrow; pub use numeric::*; +pub use repr::*; pub use sort::{argsort, sort, sort_with_indices}; +pub use sparse::*; diff --git a/crates/burn-tensor/src/tensor/api/repr.rs b/crates/burn-tensor/src/tensor/api/repr.rs new file mode 100644 index 0000000000..88d2282da7 --- /dev/null +++ b/crates/burn-tensor/src/tensor/api/repr.rs @@ -0,0 +1,20 @@ +use crate::{backend::Backend, ops::SparseTensorOps}; +use core::marker::PhantomData; + +pub trait TensorRepr: Clone + core::fmt::Debug {} + +pub trait SparseRepr: Clone + core::fmt::Debug + SparseTensorOps { + type FloatPrimitive: Clone + core::fmt::Debug + Send; + type IntPrimitive: Clone + core::fmt::Debug + Send; + type BoolPrimitive: Clone + core::fmt::Debug + Send; +} + +#[derive(Clone, Debug)] +pub struct Dense; + +#[derive(Clone, Debug)] +pub struct Sparse, B: Backend>(PhantomData<(R, B)>); + +impl TensorRepr for Dense {} + +impl> TensorRepr for Sparse {} diff --git a/crates/burn-tensor/src/tensor/api/sparse.rs b/crates/burn-tensor/src/tensor/api/sparse.rs new file mode 100644 index 0000000000..c8a988f941 --- /dev/null +++ b/crates/burn-tensor/src/tensor/api/sparse.rs @@ -0,0 +1,152 @@ +use crate::{ + backend::Backend, check::TensorCheck, BasicOps, 
Bool, DType, Device, Float, Int, Shape, Sparse, + SparseRepr, Tensor, TensorData, TensorKind, TensorPrimitive, +}; +use core::{future::Future, ops::Range}; + +use crate::check; + +type Primitive = >>::Primitive; + +impl> BasicOps> for Float +where + Float: TensorKind>, +{ + type Elem = B::FloatElem; + + fn empty(shape: Shape, device: &B::Device) -> Primitive { + R::float_empty(shape, device) + } + + fn shape(tensor: &Primitive) -> Shape { + R::float_shape(tensor) + } + + fn reshape( + tensor: Primitive, + shape: Shape, + ) -> Primitive { + R::float_reshape(tensor, shape) + } + + fn transpose(tensor: Primitive) -> Primitive { + R::float_transpose(tensor) + } + + fn swap_dims( + tensor: Primitive, + dim1: usize, + dim2: usize, + ) -> Primitive { + check!(TensorCheck::swap_dims::(dim1, dim2)); + R::float_swap_dims(tensor, dim1, dim2) + } + + fn slice( + tensor: Primitive, + ranges: [Range; D2], + ) -> Primitive { + R::float_slice(tensor, ranges) + } + + fn slice_assign( + tensor: Primitive, + ranges: [Range; D2], + value: Primitive, + ) -> Primitive { + R::float_slice_assign(tensor, ranges, value) + } + + fn device(tensor: &Primitive) -> ::Device { + R::float_device(tensor) + } + + fn to_device( + tensor: Primitive, + device: &::Device, + ) -> Primitive { + R::float_to_device(tensor, device) + } + + fn from_data(data: TensorData, device: &B::Device) -> Primitive { + R::float_from_data(data, device) + } + + fn repeat_dim( + tensor: Primitive, + dim: usize, + times: usize, + ) -> Primitive { + R::float_repeat_dim(tensor, dim, times) + } + + fn cat( + vectors: Vec>, + dim: usize, + ) -> Primitive { + R::float_cat(vectors.into_iter().map(|tensor| tensor).collect(), dim) + } + + fn equal( + lhs: Primitive, + rhs: Primitive, + ) -> Tensor { + Tensor::new(R::float_equal(lhs, rhs)) + } + + fn not_equal( + lhs: Primitive, + rhs: Primitive, + ) -> Tensor { + Tensor::new(R::float_not_equal(lhs, rhs)) + } + + fn any(tensor: Primitive) -> Tensor { + Tensor::new(R::float_any(tensor)) + } + + fn any_dim(tensor: Primitive, dim: usize) -> Tensor { + Tensor::new(R::float_any_dim(tensor, dim)) + } + + fn all(tensor: Primitive) -> Tensor { + Tensor::new(R::float_all(tensor)) + } + + fn all_dim(tensor: Primitive, dim: usize) -> Tensor { + Tensor::new(R::float_all_dim(tensor, dim)) + } + + fn permute( + tensor: Primitive, + axes: [usize; D], + ) -> Primitive { + R::float_permute(tensor.tensor(), &axes) + } + + fn expand( + tensor: Primitive, + shape: Shape, + ) -> Primitive { + R::float_expand(tensor, shape) + } + + fn flip( + tensor: Primitive, + axes: &[usize], + ) -> Primitive { + R::float_flip(tensor, axes) + } + + // fn into_data_async( + // tensor: Self::Primitive, + // ) -> impl Future + Send { + // todo!() + // } + + fn into_data_async( + tensor: Primitive, + ) -> impl Future + Send { + R::float_into_data(tensor) + } +} diff --git a/crates/burn-tensor/src/tensor/ops/mod.rs b/crates/burn-tensor/src/tensor/ops/mod.rs index 1cce562586..af59004a81 100644 --- a/crates/burn-tensor/src/tensor/ops/mod.rs +++ b/crates/burn-tensor/src/tensor/ops/mod.rs @@ -4,6 +4,7 @@ mod bool_tensor; mod int_tensor; mod modules; mod qtensor; +mod sparse_tensor; mod tensor; pub use activation::*; @@ -12,4 +13,5 @@ pub use bool_tensor::*; pub use int_tensor::*; pub use modules::*; pub use qtensor::*; +pub use sparse_tensor::*; pub use tensor::*; diff --git a/crates/burn-tensor/src/tensor/ops/sparse_tensor.rs b/crates/burn-tensor/src/tensor/ops/sparse_tensor.rs new file mode 100644 index 0000000000..2251cf5104 --- /dev/null +++ 
b/crates/burn-tensor/src/tensor/ops/sparse_tensor.rs @@ -0,0 +1,568 @@ +use crate::{ + backend::Backend, BasicOps, Device, Shape, Sparse, SparseRepr, TensorData, TensorPrimitive, +}; +use core::{future::Future, ops::Range}; + +use super::{BoolTensor, FloatElem, FloatTensor, IntTensor, QuantizedTensor}; + +type SparseFloatPrimitive +where + R: SparseRepr, += R::FloatPrimitive; + +type DenseFloatPrimitive +where + B: Backend, + R: SparseRepr, += B::FloatTensorPrimitive; + +pub trait SparseFloatOps, B: Backend> { + fn float_to_sparse( + dense: B::FloatTensorPrimitive, + ) -> SparseFloatPrimitive; + + fn float_empty( + shape: Shape, + device: &Device, + ) -> SparseFloatPrimitive; + + fn float_to_dense( + sparse: SparseFloatPrimitive, + ) -> B::FloatTensorPrimitive; + + fn float_spmm( + lhs: SparseFloatPrimitive, + rhs: B::FloatTensorPrimitive, + ) -> B::FloatTensorPrimitive; + + fn float_sddmm( + lhs: B::FloatTensorPrimitive, + rhs: B::FloatTensorPrimitive, + sparse: SparseFloatPrimitive, + ) -> SparseFloatPrimitive; + + fn float_coalesce_sum( + tensor: SparseFloatPrimitive, + ) -> SparseFloatPrimitive; + + fn float_remove_zeros( + tensor: SparseFloatPrimitive, + ) -> SparseFloatPrimitive; + + fn float_nonzero(tensor: SparseFloatPrimitive) -> usize; + + fn float_density(sparse: SparseFloatPrimitive) -> f32; + + /// Gets the element at the given indices. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// * `indices` - The indices. + /// + /// # Returns + /// + /// The elements at the given indices. + fn float_slice( + tensor: SparseFloatPrimitive, + indices: [Range; D2], + ) -> SparseFloatPrimitive; + + /// Gets the device of the tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// + /// # Returns + /// + /// The device of the tensor. + fn float_device(tensor: &SparseFloatPrimitive) -> Device; + + /// Moves the tensor to the given device. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// * `device` - The device to move the tensor to. + /// + /// # Returns + /// + /// The tensor on the given device. + fn float_to_device( + tensor: SparseFloatPrimitive, + device: &Device, + ) -> SparseFloatPrimitive; + + /// Gets the shape of the tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// + /// # Returns + /// + /// The shape of the tensor. + fn float_shape(tensor: &SparseFloatPrimitive) -> Shape; + + /// Converts the tensor to a data structure. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// + /// # Returns + /// + /// The data structure with the tensor's data. + fn float_into_data( + tensor: SparseFloatPrimitive, + ) -> impl Future + Send; + + /// Creates a tensor from the data structure. + /// + /// # Arguments + /// + /// * `data` - The data structure. + /// * `device` - The device to create the tensor on. + /// + /// # Returns + /// + /// The tensor with the data. 
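+    ///
+    /// # Example
+    ///
+    /// A hedged sketch only; `CooRepr` is a hypothetical implementor of
+    /// `SparseRepr<B>` and is not defined by this patch:
+    ///
+    /// ```ignore
+    /// let data = TensorData::from([[1.0, 0.0], [0.0, 2.0]]);
+    /// // Build a 2D sparse primitive on `device` from the dense data.
+    /// let sparse = CooRepr::float_from_data(data, &device);
+    /// ```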
+ fn float_from_data( + data: TensorData, + device: &Device, + ) -> SparseFloatPrimitive; + + fn float_reshape( + tensor: SparseFloatPrimitive, + shape: Shape, + ) -> SparseFloatPrimitive; + + fn float_transpose( + tensor: SparseFloatPrimitive, + ) -> SparseFloatPrimitive; + + fn float_swap_dims( + tensor: SparseFloatPrimitive, + dim1: usize, + dim2: usize, + ) -> SparseFloatPrimitive; + + fn float_permute( + tensor: SparseFloatPrimitive, + axes: &[usize], + ) -> SparseFloatPrimitive; + + fn float_flip( + tensor: SparseFloatPrimitive, + axes: &[usize], + ) -> SparseFloatPrimitive; + + fn float_slice_assign( + tensor: SparseFloatPrimitive, + ranges: [Range; D2], + value: SparseFloatPrimitive, + ) -> SparseFloatPrimitive; + + fn float_repeat( + tensor: SparseFloatPrimitive, + dim: usize, + times: usize, + ) -> SparseFloatPrimitive; + + fn float_cat( + tensors: Vec>, + dim: usize, + ) -> SparseFloatPrimitive; + + fn float_equal( + lhs: SparseFloatPrimitive, + rhs: SparseFloatPrimitive, + ) -> BoolTensor; + + fn float_not_equal( + lhs: SparseFloatPrimitive, + rhs: SparseFloatPrimitive, + ) -> BoolTensor; + + fn float_any(tensor: SparseFloatPrimitive) -> BoolTensor; + + fn float_any_dim( + tensor: SparseFloatPrimitive, + dim: usize, + ) -> BoolTensor; + + fn float_all(tensor: SparseFloatPrimitive) -> BoolTensor; + + fn float_all_dim( + tensor: SparseFloatPrimitive, + dim: usize, + ) -> BoolTensor; + + fn float_expand( + tensor: SparseFloatPrimitive, + shape: Shape, + ) -> SparseFloatPrimitive; + + /// Adds two sparse tensors together. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side tensor. + /// + /// # Returns + /// + /// The result of adding the two tensors together. + fn float_add( + lhs: SparseFloatPrimitive, + rhs: SparseFloatPrimitive, + ) -> SparseFloatPrimitive; + + /// Adds a sparse and dense tensor together. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side tensor. + /// + /// # Returns + /// + /// The result of adding the two tensors together. + fn float_add_dense( + lhs: SparseFloatPrimitive, + rhs: FloatTensor, + ) -> FloatTensor; + + /// Adds a scalar to a tensor. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side scalar. + /// + /// # Returns + /// + /// The result of adding the scalar to the tensor. + fn float_add_scalar( + lhs: SparseFloatPrimitive, + rhs: FloatElem, + ) -> SparseFloatPrimitive; + + /// Subtracts two tensors. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side tensor. + /// + /// # Returns + /// + /// The result of subtracting the two tensors. + fn float_sub( + lhs: SparseFloatPrimitive, + rhs: SparseFloatPrimitive, + ) -> SparseFloatPrimitive; + + /// Subtracts a dense from a sparse tensor. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor (sparse). + /// * `rhs` - The right hand side tensor (dense). + /// + /// # Returns + /// + /// The result of subtracting the two tensors. + fn float_sub_dense( + lhs: SparseFloatPrimitive, + rhs: FloatTensor, + ) -> FloatTensor; + + /// Subtracts a scalar from a tensor. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side scalar. + /// + /// # Returns + /// + /// The result of subtracting the scalar from the tensor. 
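+    ///
+    /// Note that subtracting a non-zero scalar is not zero preserving: every
+    /// implicit zero would become `-rhs`, densifying the tensor. A
+    /// representation may therefore refuse it, as the COO decorator does with
+    /// "only zero preserving operations are permitted".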
+    fn float_sub_scalar<const D: usize>(
+        lhs: SparseFloatPrimitive<R, B, D>,
+        rhs: FloatElem<B>,
+    ) -> SparseFloatPrimitive<R, B, D>;
+
+    /// Multiplies two sparse tensors together.
+    ///
+    /// # Arguments
+    ///
+    /// * `lhs` - The left hand side tensor.
+    /// * `rhs` - The right hand side tensor.
+    ///
+    /// # Returns
+    ///
+    /// The result of multiplying the two tensors together.
+    fn float_mul<const D: usize>(
+        lhs: SparseFloatPrimitive<R, B, D>,
+        rhs: SparseFloatPrimitive<R, B, D>,
+    ) -> SparseFloatPrimitive<R, B, D>;
+
+    /// Multiplies a sparse and dense tensor together.
+    ///
+    /// # Arguments
+    ///
+    /// * `lhs` - The left hand side tensor.
+    /// * `rhs` - The right hand side tensor.
+    ///
+    /// # Returns
+    ///
+    /// The result of multiplying the two tensors together.
+    fn float_mul_dense<const D: usize>(
+        lhs: SparseFloatPrimitive<R, B, D>,
+        rhs: FloatTensor<B, D>,
+    ) -> FloatTensor<B, D>;
+
+    /// Multiplies a tensor by a scalar.
+    ///
+    /// # Arguments
+    ///
+    /// * `lhs` - The left hand side tensor.
+    /// * `rhs` - The right hand side scalar.
+    ///
+    /// # Returns
+    ///
+    /// The result of multiplying the scalar with the tensor.
+    fn float_mul_scalar<const D: usize>(
+        lhs: SparseFloatPrimitive<R, B, D>,
+        rhs: FloatElem<B>,
+    ) -> SparseFloatPrimitive<R, B, D>;
+
+    /// Divides two sparse tensors.
+    ///
+    /// # Arguments
+    ///
+    /// * `lhs` - The left hand side tensor.
+    /// * `rhs` - The right hand side tensor.
+    ///
+    /// # Returns
+    ///
+    /// The result of dividing the two tensors.
+    fn float_div<const D: usize>(
+        lhs: SparseFloatPrimitive<R, B, D>,
+        rhs: SparseFloatPrimitive<R, B, D>,
+    ) -> SparseFloatPrimitive<R, B, D>;
+
+    /// Divides a sparse tensor by a dense tensor.
+    ///
+    /// # Arguments
+    ///
+    /// * `lhs` - The left hand side tensor.
+    /// * `rhs` - The right hand side tensor.
+    ///
+    /// # Returns
+    ///
+    /// The result of dividing the two tensors.
+    fn float_div_dense<const D: usize>(
+        lhs: SparseFloatPrimitive<R, B, D>,
+        rhs: FloatTensor<B, D>,
+    ) -> FloatTensor<B, D>;
+
+    /// Divides a tensor by a scalar.
+    ///
+    /// # Arguments
+    ///
+    /// * `lhs` - The left hand side tensor.
+    /// * `rhs` - The right hand side scalar.
+    ///
+    /// # Returns
+    ///
+    /// The result of dividing the tensor by the scalar.
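+    ///
+    /// Unlike scalar addition or subtraction, dividing by a non-zero scalar
+    /// is zero preserving (`0 / rhs == 0`), so it can be implemented by
+    /// mapping only the explicitly stored values.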
+ fn float_div_scalar( + lhs: SparseFloatPrimitive, + rhs: FloatElem, + ) -> SparseFloatPrimitive; + + fn float_max( + tensor: SparseFloatPrimitive, + ) -> SparseFloatPrimitive; + + fn float_max_dim( + tensor: SparseFloatPrimitive, + dim: usize, + ) -> SparseFloatPrimitive; + + fn float_min( + tensor: SparseFloatPrimitive, + ) -> SparseFloatPrimitive; + + fn float_min_dim( + tensor: SparseFloatPrimitive, + dim: usize, + ) -> SparseFloatPrimitive; + + fn float_greater( + lhs: SparseFloatPrimitive, + rhs: SparseFloatPrimitive, + ) -> BoolTensor; + + fn float_greater_elem( + lhs: SparseFloatPrimitive, + rhs: FloatElem, + ) -> BoolTensor; + + fn float_greater_equal( + lhs: SparseFloatPrimitive, + rhs: SparseFloatPrimitive, + ) -> BoolTensor; + + fn float_greater_equal_elem( + lhs: SparseFloatPrimitive, + rhs: FloatElem, + ) -> BoolTensor; + + fn float_lower( + lhs: SparseFloatPrimitive, + rhs: SparseFloatPrimitive, + ) -> BoolTensor; + + fn float_lower_elem( + lhs: SparseFloatPrimitive, + rhs: FloatElem, + ) -> BoolTensor; + + fn float_lower_equal( + lhs: SparseFloatPrimitive, + rhs: SparseFloatPrimitive, + ) -> BoolTensor; + + fn float_lower_equal_elem( + lhs: SparseFloatPrimitive, + rhs: FloatElem, + ) -> BoolTensor; + + fn float_abs( + tensor: SparseFloatPrimitive, + ) -> SparseFloatPrimitive; + fn float_sign( + tensor: SparseFloatPrimitive, + ) -> SparseFloatPrimitive; + + fn float_powf( + lhs: SparseFloatPrimitive, + rhs: SparseFloatPrimitive, + ) -> SparseFloatPrimitive; + + fn float_powi( + lhs: SparseFloatPrimitive, + rhs: SparseFloatPrimitive, + ) -> SparseFloatPrimitive; + + fn float_powf_scalar( + lhs: SparseFloatPrimitive, + rhs: FloatElem, + ) -> SparseFloatPrimitive; + + fn float_powi_scalar( + lhs: SparseFloatPrimitive, + rhs: FloatElem, + ) -> SparseFloatPrimitive; + + fn float_clamp( + tensor: SparseFloatPrimitive, + min: FloatElem, + max: FloatElem, + ) -> SparseFloatPrimitive; + + fn float_clamp_min( + tensor: SparseFloatPrimitive, + min: FloatElem, + ) -> SparseFloatPrimitive; + + fn float_clamp_max( + tensor: SparseFloatPrimitive, + max: FloatElem, + ) -> SparseFloatPrimitive; + + fn float_select( + tensor: SparseFloatPrimitive, + dim: usize, + indices: IntTensor, + ) -> SparseFloatPrimitive; + + fn float_select_assign( + tensor: SparseFloatPrimitive, + dim: usize, + indices: IntTensor, + values: SparseFloatPrimitive, + ) -> SparseFloatPrimitive; + + fn float_gather( + dim: usize, + tensor: SparseFloatPrimitive, + indices: IntTensor, + ) -> SparseFloatPrimitive; + + fn float_scatter( + dim: usize, + tensor: SparseFloatPrimitive, + indices: IntTensor, + values: SparseFloatPrimitive, + ) -> SparseFloatPrimitive; + + fn float_sum( + tensor: SparseFloatPrimitive, + ) -> SparseFloatPrimitive; + + fn float_sum_dim( + tensor: SparseFloatPrimitive, + dim: usize, + ) -> SparseFloatPrimitive; + + fn float_prod( + tensor: SparseFloatPrimitive, + ) -> SparseFloatPrimitive; + + fn float_prod_dim( + tensor: SparseFloatPrimitive, + dim: usize, + ) -> SparseFloatPrimitive; + + fn float_mean( + tensor: SparseFloatPrimitive, + ) -> SparseFloatPrimitive; + + fn float_mean_dim( + tensor: SparseFloatPrimitive, + dim: usize, + ) -> SparseFloatPrimitive; + + fn float_equal_elem( + lhs: SparseFloatPrimitive, + rhs: FloatElem, + ) -> BoolTensor; + + fn float_not_equal_elem( + lhs: SparseFloatPrimitive, + rhs: FloatElem, + ) -> BoolTensor; + + fn float_remainder_scalar( + lhs: SparseFloatPrimitive, + rhs: FloatElem, + ) -> SparseFloatPrimitive; + + fn float_neg( + tensor: SparseFloatPrimitive, + ) 
-> SparseFloatPrimitive; +} + +pub trait SparseIntOps, B: Backend> {} + +pub trait SparseBoolOps, B: Backend> {} + +pub trait SparseTensorOps, B: Backend>: + SparseFloatOps + SparseIntOps + SparseBoolOps +{ +} From 1d0d36634652679786990f4f19062b8ec327f166 Mon Sep 17 00:00:00 2001 From: mcarthur Date: Wed, 14 Aug 2024 08:14:14 +0000 Subject: [PATCH 22/38] New sparse tensor API, seems really good --- crates/burn-sparse/src/backend/mod.rs | 16 +- crates/burn-sparse/src/decorator/mod.rs | 18 +- crates/burn-sparse/src/lib.rs | 4 +- crates/burn-tensor/src/tensor/api/base.rs | 19 +- crates/burn-tensor/src/tensor/api/kind.rs | 17 +- crates/burn-tensor/src/tensor/api/repr.rs | 25 +- crates/burn-tensor/src/tensor/api/sparse.rs | 145 +++---- .../src/tensor/ops/sparse_tensor.rs | 386 ++++++++---------- 8 files changed, 305 insertions(+), 325 deletions(-) diff --git a/crates/burn-sparse/src/backend/mod.rs b/crates/burn-sparse/src/backend/mod.rs index 741143a850..3b88b7f0ea 100644 --- a/crates/burn-sparse/src/backend/mod.rs +++ b/crates/burn-sparse/src/backend/mod.rs @@ -1,9 +1,9 @@ -mod alias; -mod api; -mod kind; -mod sparse_backend; +// mod alias; +// mod api; +// mod kind; +// mod sparse_backend; -pub use alias::*; -pub use api::*; -pub use kind::*; -pub use sparse_backend::*; +// pub use alias::*; +// pub use api::*; +// pub use kind::*; +// pub use sparse_backend::*; diff --git a/crates/burn-sparse/src/decorator/mod.rs b/crates/burn-sparse/src/decorator/mod.rs index ba2b258c05..dc49680a84 100644 --- a/crates/burn-sparse/src/decorator/mod.rs +++ b/crates/burn-sparse/src/decorator/mod.rs @@ -1,10 +1,10 @@ -mod backend; -mod ops; -mod precision_bridge; -mod representation; -mod sparse_coo; -mod sparse_csr; +// mod backend; +// mod ops; +// mod precision_bridge; +// mod representation; +// mod sparse_coo; +// mod sparse_csr; -pub use backend::*; -pub use precision_bridge::*; -pub use representation::*; +// pub use backend::*; +// pub use precision_bridge::*; +// pub use representation::*; diff --git a/crates/burn-sparse/src/lib.rs b/crates/burn-sparse/src/lib.rs index b5ed4799d3..9dfe4c9dce 100644 --- a/crates/burn-sparse/src/lib.rs +++ b/crates/burn-sparse/src/lib.rs @@ -1,2 +1,2 @@ -// pub mod backend; -// pub mod decorator; +pub mod backend; +pub mod decorator; diff --git a/crates/burn-tensor/src/tensor/api/base.rs b/crates/burn-tensor/src/tensor/api/base.rs index e6ab89ce9f..67a4be82c8 100644 --- a/crates/burn-tensor/src/tensor/api/base.rs +++ b/crates/burn-tensor/src/tensor/api/base.rs @@ -18,7 +18,7 @@ use crate::check::TensorCheck; use crate::tensor::api::chunk::chunk; use crate::tensor::api::narrow::narrow; use crate::{backend::Backend, check, Bool, Float, Int, Shape, TensorData, TensorKind}; -use crate::{DType, Dense, Element, TensorPrimitive, TensorRepr}; +use crate::{DType, Dense, Element, SparseRepr, TensorPrimitive, TensorRepr}; /// A tensor with a given backend, shape and data type. #[derive(new, Clone, Debug)] @@ -1279,7 +1279,10 @@ impl core::ops::BitXor for Tensor { /// # Warnings /// /// This is an internal trait, use the public API provided by [tensor struct](Tensor). -pub trait BasicOps = Dense>: TensorKind { +pub trait BasicOps = Dense>: TensorKind +where + Bool: TensorKind, +{ /// The type of the tensor elements. type Elem: Element; @@ -1606,7 +1609,7 @@ pub trait BasicOps = Dense>: TensorKind { fn equal( lhs: Self::Primitive, rhs: Self::Primitive, - ) -> Tensor; + ) -> Tensor; /// Applies element-wise non-equality comparison between the given tensors. 
/// @@ -1630,7 +1633,7 @@ pub trait BasicOps = Dense>: TensorKind { fn not_equal( lhs: Self::Primitive, rhs: Self::Primitive, - ) -> Tensor; + ) -> Tensor; /// Returns the name of the element type. fn elem_type_name() -> &'static str { @@ -1653,7 +1656,7 @@ pub trait BasicOps = Dense>: TensorKind { /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. Users should prefer the [Tensor::any](Tensor::any) function /// which is more high-level and designed for public use. - fn any(tensor: Self::Primitive) -> Tensor; + fn any(tensor: Self::Primitive) -> Tensor; /// Tests if any element in the tensor evaluates to True along a given dimension dim. /// @@ -1673,7 +1676,7 @@ pub trait BasicOps = Dense>: TensorKind { /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. Users should prefer the [Tensor::any_dim](Tensor::any_dim) function, /// which is more high-level and designed for public use. - fn any_dim(tensor: Self::Primitive, dim: usize) -> Tensor; + fn any_dim(tensor: Self::Primitive, dim: usize) -> Tensor; /// Tests if all elements in the `tensor` evaluate to True. /// @@ -1691,7 +1694,7 @@ pub trait BasicOps = Dense>: TensorKind { /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. Users should prefer the [Tensor::all](Tensor::all) function, /// which is more high-level and designed for public use. - fn all(tensor: Self::Primitive) -> Tensor; + fn all(tensor: Self::Primitive) -> Tensor; /// Tests if all elements in the `tensor` evaluate to True along a given dimension `dim`. /// @@ -1710,7 +1713,7 @@ pub trait BasicOps = Dense>: TensorKind { /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. Users should prefer the [Tensor::all_dim](Tensor::all_dim) function, /// which is more high-level and designed for public use. - fn all_dim(tensor: Self::Primitive, dim: usize) -> Tensor; + fn all_dim(tensor: Self::Primitive, dim: usize) -> Tensor; /// Broadcasts the given tensor to the specified shape. /// diff --git a/crates/burn-tensor/src/tensor/api/kind.rs b/crates/burn-tensor/src/tensor/api/kind.rs index 4e50b3f1e5..c7a8a3c042 100644 --- a/crates/burn-tensor/src/tensor/api/kind.rs +++ b/crates/burn-tensor/src/tensor/api/kind.rs @@ -40,6 +40,11 @@ pub trait TensorKind = Dense>: Clone + core::fmt::D /// The name of the tensor kind. fn name() -> &'static str; + + /// The representation of the tensor kind. 
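+    ///
+    /// For the dense case this mirrors `TensorRepr::name` in `repr.rs` below,
+    /// returning `"Dense"`; sparse kinds defer to the representation's own
+    /// `name` implementation.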
+ fn representation() -> &'static str { + R::name() + } } impl TensorKind for Float { @@ -50,10 +55,10 @@ impl TensorKind for Float { } impl> TensorKind> for Float { - type Primitive = R::FloatPrimitive; + type Primitive = R::FloatTensorPrimitive; fn name() -> &'static str { - "SparseFloat" + >::name() } } @@ -65,10 +70,10 @@ impl TensorKind for Int { } impl> TensorKind> for Int { - type Primitive = R::IntPrimitive; + type Primitive = R::IntTensorPrimitive; fn name() -> &'static str { - "SparseInt" + >::name() } } @@ -80,9 +85,9 @@ impl TensorKind for Bool { } impl> TensorKind> for Bool { - type Primitive = R::BoolPrimitive; + type Primitive = R::BoolTensorPrimitive; fn name() -> &'static str { - "SparseBool" + >::name() } } diff --git a/crates/burn-tensor/src/tensor/api/repr.rs b/crates/burn-tensor/src/tensor/api/repr.rs index 88d2282da7..8d7b4831ce 100644 --- a/crates/burn-tensor/src/tensor/api/repr.rs +++ b/crates/burn-tensor/src/tensor/api/repr.rs @@ -1,12 +1,15 @@ -use crate::{backend::Backend, ops::SparseTensorOps}; +use crate::{backend::Backend, ops::SparseTensorOps, TensorKind}; use core::marker::PhantomData; -pub trait TensorRepr: Clone + core::fmt::Debug {} +pub trait TensorRepr: Clone + core::fmt::Debug { + fn name() -> &'static str; +} pub trait SparseRepr: Clone + core::fmt::Debug + SparseTensorOps { - type FloatPrimitive: Clone + core::fmt::Debug + Send; - type IntPrimitive: Clone + core::fmt::Debug + Send; - type BoolPrimitive: Clone + core::fmt::Debug + Send; + type FloatTensorPrimitive: Clone + core::fmt::Debug + Send; + type IntTensorPrimitive: Clone + core::fmt::Debug + Send; + type BoolTensorPrimitive: Clone + core::fmt::Debug + Send; + fn name() -> &'static str; } #[derive(Clone, Debug)] @@ -15,6 +18,14 @@ pub struct Dense; #[derive(Clone, Debug)] pub struct Sparse, B: Backend>(PhantomData<(R, B)>); -impl TensorRepr for Dense {} +impl TensorRepr for Dense { + fn name() -> &'static str { + "Dense" + } +} -impl> TensorRepr for Sparse {} +impl, B: Backend> TensorRepr for Sparse { + fn name() -> &'static str { + R::name() + } +} diff --git a/crates/burn-tensor/src/tensor/api/sparse.rs b/crates/burn-tensor/src/tensor/api/sparse.rs index c8a988f941..5856c56af8 100644 --- a/crates/burn-tensor/src/tensor/api/sparse.rs +++ b/crates/burn-tensor/src/tensor/api/sparse.rs @@ -1,152 +1,143 @@ use crate::{ - backend::Backend, check::TensorCheck, BasicOps, Bool, DType, Device, Float, Int, Shape, Sparse, - SparseRepr, Tensor, TensorData, TensorKind, TensorPrimitive, + backend::Backend, check::TensorCheck, BasicOps, Bool, DType, Device, Element, Float, Int, + Shape, Sparse, SparseRepr, Tensor, TensorData, TensorKind, TensorPrimitive, TensorRepr, }; use core::{future::Future, ops::Range}; use crate::check; -type Primitive = >>::Primitive; - -impl> BasicOps> for Float -where - Float: TensorKind>, -{ +impl> BasicOps> for Float { type Elem = B::FloatElem; - fn empty(shape: Shape, device: &B::Device) -> Primitive { + fn empty( + shape: Shape, + device: &::Device, + ) -> R::FloatTensorPrimitive { R::float_empty(shape, device) } - fn shape(tensor: &Primitive) -> Shape { + fn shape(tensor: &Self::Primitive) -> Shape { R::float_shape(tensor) } fn reshape( - tensor: Primitive, + tensor: Self::Primitive, shape: Shape, - ) -> Primitive { + ) -> Self::Primitive { R::float_reshape(tensor, shape) } - fn transpose(tensor: Primitive) -> Primitive { + fn transpose(tensor: Self::Primitive) -> Self::Primitive { R::float_transpose(tensor) } fn swap_dims( - tensor: Primitive, + tensor: Self::Primitive, dim1: 
usize, dim2: usize, - ) -> Primitive { - check!(TensorCheck::swap_dims::(dim1, dim2)); + ) -> Self::Primitive { R::float_swap_dims(tensor, dim1, dim2) } + fn permute(tensor: Self::Primitive, axes: [usize; D]) -> Self::Primitive { + R::float_permute(tensor, &axes) + } + + fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive { + R::float_flip(tensor, axes) + } + fn slice( - tensor: Primitive, - ranges: [Range; D2], - ) -> Primitive { - R::float_slice(tensor, ranges) + tensor: Self::Primitive, + range: [Range; D2], + ) -> Self::Primitive { + R::float_slice(tensor, range) } fn slice_assign( - tensor: Primitive, + tensor: Self::Primitive, ranges: [Range; D2], - value: Primitive, - ) -> Primitive { + value: Self::Primitive, + ) -> Self::Primitive { R::float_slice_assign(tensor, ranges, value) } - fn device(tensor: &Primitive) -> ::Device { + fn device(tensor: &Self::Primitive) -> ::Device { R::float_device(tensor) } fn to_device( - tensor: Primitive, + tensor: Self::Primitive, device: &::Device, - ) -> Primitive { + ) -> Self::Primitive { R::float_to_device(tensor, device) } - fn from_data(data: TensorData, device: &B::Device) -> Primitive { + fn into_data_async( + tensor: Self::Primitive, + ) -> impl Future + Send { + R::float_into_data(tensor) + } + + fn from_data( + data: TensorData, + device: &::Device, + ) -> Self::Primitive { R::float_from_data(data, device) } fn repeat_dim( - tensor: Primitive, + tensor: Self::Primitive, dim: usize, times: usize, - ) -> Primitive { + ) -> Self::Primitive { R::float_repeat_dim(tensor, dim, times) } - fn cat( - vectors: Vec>, - dim: usize, - ) -> Primitive { - R::float_cat(vectors.into_iter().map(|tensor| tensor).collect(), dim) + fn cat(vectors: Vec>, dim: usize) -> Self::Primitive { + R::float_cat(vectors, dim) + } + + fn expand( + tensor: Self::Primitive, + shape: Shape, + ) -> Self::Primitive { + R::float_expand(tensor, shape) } fn equal( - lhs: Primitive, - rhs: Primitive, - ) -> Tensor { + lhs: Self::Primitive, + rhs: Self::Primitive, + ) -> Tensor> { Tensor::new(R::float_equal(lhs, rhs)) } fn not_equal( - lhs: Primitive, - rhs: Primitive, - ) -> Tensor { + lhs: Self::Primitive, + rhs: Self::Primitive, + ) -> Tensor> { Tensor::new(R::float_not_equal(lhs, rhs)) } - fn any(tensor: Primitive) -> Tensor { + fn any(tensor: Self::Primitive) -> Tensor> { Tensor::new(R::float_any(tensor)) } - fn any_dim(tensor: Primitive, dim: usize) -> Tensor { + fn any_dim( + tensor: Self::Primitive, + dim: usize, + ) -> Tensor> { Tensor::new(R::float_any_dim(tensor, dim)) } - fn all(tensor: Primitive) -> Tensor { + fn all(tensor: Self::Primitive) -> Tensor> { Tensor::new(R::float_all(tensor)) } - fn all_dim(tensor: Primitive, dim: usize) -> Tensor { + fn all_dim( + tensor: Self::Primitive, + dim: usize, + ) -> Tensor> { Tensor::new(R::float_all_dim(tensor, dim)) } - - fn permute( - tensor: Primitive, - axes: [usize; D], - ) -> Primitive { - R::float_permute(tensor.tensor(), &axes) - } - - fn expand( - tensor: Primitive, - shape: Shape, - ) -> Primitive { - R::float_expand(tensor, shape) - } - - fn flip( - tensor: Primitive, - axes: &[usize], - ) -> Primitive { - R::float_flip(tensor, axes) - } - - // fn into_data_async( - // tensor: Self::Primitive, - // ) -> impl Future + Send { - // todo!() - // } - - fn into_data_async( - tensor: Primitive, - ) -> impl Future + Send { - R::float_into_data(tensor) - } } diff --git a/crates/burn-tensor/src/tensor/ops/sparse_tensor.rs b/crates/burn-tensor/src/tensor/ops/sparse_tensor.rs index 2251cf5104..9f084ca261 100644 
--- a/crates/burn-tensor/src/tensor/ops/sparse_tensor.rs +++ b/crates/burn-tensor/src/tensor/ops/sparse_tensor.rs @@ -1,57 +1,45 @@ -use crate::{ - backend::Backend, BasicOps, Device, Shape, Sparse, SparseRepr, TensorData, TensorPrimitive, -}; -use core::{future::Future, ops::Range}; - use super::{BoolTensor, FloatElem, FloatTensor, IntTensor, QuantizedTensor}; +use crate::{backend::Backend, Device, Float, Shape, SparseRepr, TensorData, TensorKind}; +use core::{future::Future, ops::Range}; -type SparseFloatPrimitive -where - R: SparseRepr, -= R::FloatPrimitive; - -type DenseFloatPrimitive -where - B: Backend, - R: SparseRepr, -= B::FloatTensorPrimitive; +pub trait SparseTensorOps, B: Backend>: SparseFloatOps {} pub trait SparseFloatOps, B: Backend> { fn float_to_sparse( dense: B::FloatTensorPrimitive, - ) -> SparseFloatPrimitive; + ) -> R::FloatTensorPrimitive; fn float_empty( shape: Shape, - device: &Device, - ) -> SparseFloatPrimitive; + device: &Device, + ) -> R::FloatTensorPrimitive; fn float_to_dense( - sparse: SparseFloatPrimitive, - ) -> B::FloatTensorPrimitive; + sparse: R::FloatTensorPrimitive, + ) -> R::FloatTensorPrimitive; fn float_spmm( - lhs: SparseFloatPrimitive, + lhs: R::FloatTensorPrimitive, rhs: B::FloatTensorPrimitive, ) -> B::FloatTensorPrimitive; fn float_sddmm( lhs: B::FloatTensorPrimitive, rhs: B::FloatTensorPrimitive, - sparse: SparseFloatPrimitive, - ) -> SparseFloatPrimitive; + sparse: R::FloatTensorPrimitive, + ) -> R::FloatTensorPrimitive; fn float_coalesce_sum( - tensor: SparseFloatPrimitive, - ) -> SparseFloatPrimitive; + tensor: R::FloatTensorPrimitive, + ) -> R::FloatTensorPrimitive; fn float_remove_zeros( - tensor: SparseFloatPrimitive, - ) -> SparseFloatPrimitive; + tensor: R::FloatTensorPrimitive, + ) -> R::FloatTensorPrimitive; - fn float_nonzero(tensor: SparseFloatPrimitive) -> usize; + fn float_nonzero(tensor: R::FloatTensorPrimitive) -> usize; - fn float_density(sparse: SparseFloatPrimitive) -> f32; + fn float_density(sparse: R::FloatTensorPrimitive) -> f32; /// Gets the element at the given indices. /// @@ -64,9 +52,9 @@ pub trait SparseFloatOps, B: Backend> { /// /// The elements at the given indices. fn float_slice( - tensor: SparseFloatPrimitive, + tensor: R::FloatTensorPrimitive, indices: [Range; D2], - ) -> SparseFloatPrimitive; + ) -> R::FloatTensorPrimitive; /// Gets the device of the tensor. /// @@ -77,7 +65,7 @@ pub trait SparseFloatOps, B: Backend> { /// # Returns /// /// The device of the tensor. - fn float_device(tensor: &SparseFloatPrimitive) -> Device; + fn float_device(tensor: &R::FloatTensorPrimitive) -> Device; /// Moves the tensor to the given device. /// @@ -90,9 +78,9 @@ pub trait SparseFloatOps, B: Backend> { /// /// The tensor on the given device. fn float_to_device( - tensor: SparseFloatPrimitive, - device: &Device, - ) -> SparseFloatPrimitive; + tensor: R::FloatTensorPrimitive, + device: &Device, + ) -> R::FloatTensorPrimitive; /// Gets the shape of the tensor. /// @@ -103,7 +91,7 @@ pub trait SparseFloatOps, B: Backend> { /// # Returns /// /// The shape of the tensor. - fn float_shape(tensor: &SparseFloatPrimitive) -> Shape; + fn float_shape(tensor: &R::FloatTensorPrimitive) -> Shape; /// Converts the tensor to a data structure. /// @@ -115,7 +103,7 @@ pub trait SparseFloatOps, B: Backend> { /// /// The data structure with the tensor's data. fn float_into_data( - tensor: SparseFloatPrimitive, + tensor: R::FloatTensorPrimitive, ) -> impl Future + Send; /// Creates a tensor from the data structure. 
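    ///
    /// This is the counterpart of `float_into_data` above; together the two
    /// let a sparse primitive round-trip through a backend-agnostic
    /// `TensorData`.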
@@ -130,79 +118,79 @@ pub trait SparseFloatOps, B: Backend> { /// The tensor with the data. fn float_from_data( data: TensorData, - device: &Device, - ) -> SparseFloatPrimitive; + device: &Device, + ) -> R::FloatTensorPrimitive; fn float_reshape( - tensor: SparseFloatPrimitive, + tensor: R::FloatTensorPrimitive, shape: Shape, - ) -> SparseFloatPrimitive; + ) -> R::FloatTensorPrimitive; fn float_transpose( - tensor: SparseFloatPrimitive, - ) -> SparseFloatPrimitive; + tensor: R::FloatTensorPrimitive, + ) -> R::FloatTensorPrimitive; fn float_swap_dims( - tensor: SparseFloatPrimitive, + tensor: R::FloatTensorPrimitive, dim1: usize, dim2: usize, - ) -> SparseFloatPrimitive; + ) -> R::FloatTensorPrimitive; fn float_permute( - tensor: SparseFloatPrimitive, + tensor: R::FloatTensorPrimitive, axes: &[usize], - ) -> SparseFloatPrimitive; + ) -> R::FloatTensorPrimitive; fn float_flip( - tensor: SparseFloatPrimitive, + tensor: R::FloatTensorPrimitive, axes: &[usize], - ) -> SparseFloatPrimitive; + ) -> R::FloatTensorPrimitive; fn float_slice_assign( - tensor: SparseFloatPrimitive, + tensor: R::FloatTensorPrimitive, ranges: [Range; D2], - value: SparseFloatPrimitive, - ) -> SparseFloatPrimitive; + value: R::FloatTensorPrimitive, + ) -> R::FloatTensorPrimitive; - fn float_repeat( - tensor: SparseFloatPrimitive, + fn float_repeat_dim( + tensor: R::FloatTensorPrimitive, dim: usize, times: usize, - ) -> SparseFloatPrimitive; + ) -> R::FloatTensorPrimitive; fn float_cat( - tensors: Vec>, + tensors: Vec>, dim: usize, - ) -> SparseFloatPrimitive; + ) -> R::FloatTensorPrimitive; fn float_equal( - lhs: SparseFloatPrimitive, - rhs: SparseFloatPrimitive, - ) -> BoolTensor; + lhs: R::FloatTensorPrimitive, + rhs: R::FloatTensorPrimitive, + ) -> R::BoolTensorPrimitive; fn float_not_equal( - lhs: SparseFloatPrimitive, - rhs: SparseFloatPrimitive, - ) -> BoolTensor; + lhs: R::FloatTensorPrimitive, + rhs: R::FloatTensorPrimitive, + ) -> R::BoolTensorPrimitive; - fn float_any(tensor: SparseFloatPrimitive) -> BoolTensor; + fn float_any(tensor: R::FloatTensorPrimitive) -> R::BoolTensorPrimitive<1>; fn float_any_dim( - tensor: SparseFloatPrimitive, + tensor: R::FloatTensorPrimitive, dim: usize, - ) -> BoolTensor; + ) -> R::BoolTensorPrimitive; - fn float_all(tensor: SparseFloatPrimitive) -> BoolTensor; + fn float_all(tensor: R::FloatTensorPrimitive) -> R::BoolTensorPrimitive<1>; fn float_all_dim( - tensor: SparseFloatPrimitive, + tensor: R::FloatTensorPrimitive, dim: usize, - ) -> BoolTensor; + ) -> R::BoolTensorPrimitive; fn float_expand( - tensor: SparseFloatPrimitive, + tensor: R::FloatTensorPrimitive, shape: Shape, - ) -> SparseFloatPrimitive; + ) -> R::FloatTensorPrimitive; /// Adds two sparse tensors together. /// @@ -215,9 +203,9 @@ pub trait SparseFloatOps, B: Backend> { /// /// The result of adding the two tensors together. fn float_add( - lhs: SparseFloatPrimitive, - rhs: SparseFloatPrimitive, - ) -> SparseFloatPrimitive; + lhs: R::FloatTensorPrimitive, + rhs: R::FloatTensorPrimitive, + ) -> R::FloatTensorPrimitive; /// Adds a sparse and dense tensor together. /// @@ -230,7 +218,7 @@ pub trait SparseFloatOps, B: Backend> { /// /// The result of adding the two tensors together. fn float_add_dense( - lhs: SparseFloatPrimitive, + lhs: R::FloatTensorPrimitive, rhs: FloatTensor, ) -> FloatTensor; @@ -245,9 +233,9 @@ pub trait SparseFloatOps, B: Backend> { /// /// The result of adding the scalar to the tensor. 
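    ///
    /// As with scalar subtraction, adding a non-zero scalar is not zero
    /// preserving (every implicit zero would become `rhs`), so a sparse
    /// representation may legitimately refuse this operation.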
fn float_add_scalar( - lhs: SparseFloatPrimitive, - rhs: FloatElem, - ) -> SparseFloatPrimitive; + lhs: R::FloatTensorPrimitive, + rhs: FloatElem, + ) -> R::FloatTensorPrimitive; /// Subtracts two tensors. /// @@ -260,9 +248,9 @@ pub trait SparseFloatOps, B: Backend> { /// /// The result of subtracting the two tensors. fn float_sub( - lhs: SparseFloatPrimitive, - rhs: SparseFloatPrimitive, - ) -> SparseFloatPrimitive; + lhs: R::FloatTensorPrimitive, + rhs: R::FloatTensorPrimitive, + ) -> R::FloatTensorPrimitive; /// Subtracts a dense from a sparse tensor. /// @@ -275,7 +263,7 @@ pub trait SparseFloatOps, B: Backend> { /// /// The result of subtracting the two tensors. fn float_sub_dense( - lhs: SparseFloatPrimitive, + lhs: R::FloatTensorPrimitive, rhs: FloatTensor, ) -> FloatTensor; @@ -290,9 +278,9 @@ pub trait SparseFloatOps, B: Backend> { /// /// The result of subtracting the scalar from the tensor. fn float_sub_scalar( - lhs: SparseFloatPrimitive, - rhs: FloatElem, - ) -> SparseFloatPrimitive; + lhs: R::FloatTensorPrimitive, + rhs: FloatElem, + ) -> R::FloatTensorPrimitive; /// Multiplies two sparse tensors together. /// @@ -305,9 +293,9 @@ pub trait SparseFloatOps, B: Backend> { /// /// The result of multiplying the two tensors together. fn float_mul( - lhs: SparseFloatPrimitive, - rhs: SparseFloatPrimitive, - ) -> SparseFloatPrimitive; + lhs: R::FloatTensorPrimitive, + rhs: R::FloatTensorPrimitive, + ) -> R::FloatTensorPrimitive; /// Multiplies a sparse and dense tensor together. /// @@ -320,7 +308,7 @@ pub trait SparseFloatOps, B: Backend> { /// /// The result of multiplying the two tensors together. fn float_mul_dense( - lhs: SparseFloatPrimitive, + lhs: R::FloatTensorPrimitive, rhs: FloatTensor, ) -> FloatTensor; @@ -335,9 +323,9 @@ pub trait SparseFloatOps, B: Backend> { /// /// The result of multiplying the scalar with the tensor. fn float_mul_scalar( - lhs: SparseFloatPrimitive, - rhs: FloatElem, - ) -> SparseFloatPrimitive; + lhs: R::FloatTensorPrimitive, + rhs: FloatElem, + ) -> R::FloatTensorPrimitive; /// Divides two sparse tensors. /// @@ -350,9 +338,9 @@ pub trait SparseFloatOps, B: Backend> { /// /// The result of dividing the two tensors. fn float_div( - lhs: SparseFloatPrimitive, - rhs: SparseFloatPrimitive, - ) -> SparseFloatPrimitive; + lhs: R::FloatTensorPrimitive, + rhs: R::FloatTensorPrimitive, + ) -> R::FloatTensorPrimitive; /// Divides a sparse and dense tensor. /// @@ -365,7 +353,7 @@ pub trait SparseFloatOps, B: Backend> { /// /// The result of dividing the two tensors. fn float_div_dense( - lhs: SparseFloatPrimitive, + lhs: R::FloatTensorPrimitive, rhs: FloatTensor, ) -> FloatTensor; @@ -380,189 +368,171 @@ pub trait SparseFloatOps, B: Backend> { /// /// The result of dividing the tensor by the scalar. 
fn float_div_scalar( - lhs: SparseFloatPrimitive, - rhs: FloatElem, - ) -> SparseFloatPrimitive; + lhs: R::FloatTensorPrimitive, + rhs: FloatElem, + ) -> R::FloatTensorPrimitive; - fn float_max( - tensor: SparseFloatPrimitive, - ) -> SparseFloatPrimitive; + fn float_max(tensor: R::FloatTensorPrimitive) -> R::FloatTensorPrimitive; fn float_max_dim( - tensor: SparseFloatPrimitive, + tensor: R::FloatTensorPrimitive, dim: usize, - ) -> SparseFloatPrimitive; + ) -> R::FloatTensorPrimitive; - fn float_min( - tensor: SparseFloatPrimitive, - ) -> SparseFloatPrimitive; + fn float_min(tensor: R::FloatTensorPrimitive) -> R::FloatTensorPrimitive; fn float_min_dim( - tensor: SparseFloatPrimitive, + tensor: R::FloatTensorPrimitive, dim: usize, - ) -> SparseFloatPrimitive; + ) -> R::FloatTensorPrimitive; fn float_greater( - lhs: SparseFloatPrimitive, - rhs: SparseFloatPrimitive, - ) -> BoolTensor; + lhs: R::FloatTensorPrimitive, + rhs: R::FloatTensorPrimitive, + ) -> R::BoolTensorPrimitive; fn float_greater_elem( - lhs: SparseFloatPrimitive, - rhs: FloatElem, - ) -> BoolTensor; + lhs: R::FloatTensorPrimitive, + rhs: FloatElem, + ) -> R::BoolTensorPrimitive; fn float_greater_equal( - lhs: SparseFloatPrimitive, - rhs: SparseFloatPrimitive, - ) -> BoolTensor; + lhs: R::FloatTensorPrimitive, + rhs: R::FloatTensorPrimitive, + ) -> R::BoolTensorPrimitive; fn float_greater_equal_elem( - lhs: SparseFloatPrimitive, - rhs: FloatElem, - ) -> BoolTensor; + lhs: R::FloatTensorPrimitive, + rhs: FloatElem, + ) -> R::BoolTensorPrimitive; fn float_lower( - lhs: SparseFloatPrimitive, - rhs: SparseFloatPrimitive, - ) -> BoolTensor; + lhs: R::FloatTensorPrimitive, + rhs: R::FloatTensorPrimitive, + ) -> R::BoolTensorPrimitive; fn float_lower_elem( - lhs: SparseFloatPrimitive, - rhs: FloatElem, - ) -> BoolTensor; + lhs: R::FloatTensorPrimitive, + rhs: FloatElem, + ) -> R::BoolTensorPrimitive; fn float_lower_equal( - lhs: SparseFloatPrimitive, - rhs: SparseFloatPrimitive, - ) -> BoolTensor; + lhs: R::FloatTensorPrimitive, + rhs: R::FloatTensorPrimitive, + ) -> R::BoolTensorPrimitive; fn float_lower_equal_elem( - lhs: SparseFloatPrimitive, - rhs: FloatElem, - ) -> BoolTensor; + lhs: R::FloatTensorPrimitive, + rhs: FloatElem, + ) -> R::BoolTensorPrimitive; - fn float_abs( - tensor: SparseFloatPrimitive, - ) -> SparseFloatPrimitive; - fn float_sign( - tensor: SparseFloatPrimitive, - ) -> SparseFloatPrimitive; + fn float_abs(tensor: R::FloatTensorPrimitive) -> R::FloatTensorPrimitive; + fn float_sign(tensor: R::FloatTensorPrimitive) + -> R::FloatTensorPrimitive; fn float_powf( - lhs: SparseFloatPrimitive, - rhs: SparseFloatPrimitive, - ) -> SparseFloatPrimitive; + lhs: R::FloatTensorPrimitive, + rhs: R::FloatTensorPrimitive, + ) -> R::FloatTensorPrimitive; fn float_powi( - lhs: SparseFloatPrimitive, - rhs: SparseFloatPrimitive, - ) -> SparseFloatPrimitive; + lhs: R::FloatTensorPrimitive, + rhs: R::FloatTensorPrimitive, + ) -> R::FloatTensorPrimitive; fn float_powf_scalar( - lhs: SparseFloatPrimitive, - rhs: FloatElem, - ) -> SparseFloatPrimitive; + lhs: R::FloatTensorPrimitive, + rhs: FloatElem, + ) -> R::FloatTensorPrimitive; fn float_powi_scalar( - lhs: SparseFloatPrimitive, - rhs: FloatElem, - ) -> SparseFloatPrimitive; + lhs: R::FloatTensorPrimitive, + rhs: FloatElem, + ) -> R::FloatTensorPrimitive; fn float_clamp( - tensor: SparseFloatPrimitive, - min: FloatElem, - max: FloatElem, - ) -> SparseFloatPrimitive; + tensor: R::FloatTensorPrimitive, + min: FloatElem, + max: FloatElem, + ) -> R::FloatTensorPrimitive; fn 
float_clamp_min( - tensor: SparseFloatPrimitive, - min: FloatElem, - ) -> SparseFloatPrimitive; + tensor: R::FloatTensorPrimitive, + min: FloatElem, + ) -> R::FloatTensorPrimitive; fn float_clamp_max( - tensor: SparseFloatPrimitive, - max: FloatElem, - ) -> SparseFloatPrimitive; + tensor: R::FloatTensorPrimitive, + max: FloatElem, + ) -> R::FloatTensorPrimitive; fn float_select( - tensor: SparseFloatPrimitive, + tensor: R::FloatTensorPrimitive, dim: usize, - indices: IntTensor, - ) -> SparseFloatPrimitive; + indices: IntTensor, + ) -> R::FloatTensorPrimitive; fn float_select_assign( - tensor: SparseFloatPrimitive, + tensor: R::FloatTensorPrimitive, dim: usize, - indices: IntTensor, - values: SparseFloatPrimitive, - ) -> SparseFloatPrimitive; + indices: IntTensor, + values: R::FloatTensorPrimitive, + ) -> R::FloatTensorPrimitive; fn float_gather( dim: usize, - tensor: SparseFloatPrimitive, - indices: IntTensor, - ) -> SparseFloatPrimitive; + tensor: R::FloatTensorPrimitive, + indices: IntTensor, + ) -> R::FloatTensorPrimitive; fn float_scatter( dim: usize, - tensor: SparseFloatPrimitive, - indices: IntTensor, - values: SparseFloatPrimitive, - ) -> SparseFloatPrimitive; + tensor: R::FloatTensorPrimitive, + indices: IntTensor, + values: R::FloatTensorPrimitive, + ) -> R::FloatTensorPrimitive; - fn float_sum( - tensor: SparseFloatPrimitive, - ) -> SparseFloatPrimitive; + fn float_sum(tensor: R::FloatTensorPrimitive) -> R::FloatTensorPrimitive; fn float_sum_dim( - tensor: SparseFloatPrimitive, + tensor: R::FloatTensorPrimitive, dim: usize, - ) -> SparseFloatPrimitive; + ) -> R::FloatTensorPrimitive; - fn float_prod( - tensor: SparseFloatPrimitive, - ) -> SparseFloatPrimitive; + fn float_prod(tensor: R::FloatTensorPrimitive) + -> R::FloatTensorPrimitive; fn float_prod_dim( - tensor: SparseFloatPrimitive, + tensor: R::FloatTensorPrimitive, dim: usize, - ) -> SparseFloatPrimitive; + ) -> R::FloatTensorPrimitive; - fn float_mean( - tensor: SparseFloatPrimitive, - ) -> SparseFloatPrimitive; + fn float_mean(tensor: R::FloatTensorPrimitive) + -> R::FloatTensorPrimitive; fn float_mean_dim( - tensor: SparseFloatPrimitive, + tensor: R::FloatTensorPrimitive, dim: usize, - ) -> SparseFloatPrimitive; + ) -> R::FloatTensorPrimitive; fn float_equal_elem( - lhs: SparseFloatPrimitive, - rhs: FloatElem, - ) -> BoolTensor; + lhs: R::FloatTensorPrimitive, + rhs: FloatElem, + ) -> R::BoolTensorPrimitive; fn float_not_equal_elem( - lhs: SparseFloatPrimitive, - rhs: FloatElem, - ) -> BoolTensor; + lhs: R::FloatTensorPrimitive, + rhs: FloatElem, + ) -> R::BoolTensorPrimitive; fn float_remainder_scalar( - lhs: SparseFloatPrimitive, - rhs: FloatElem, - ) -> SparseFloatPrimitive; + lhs: R::FloatTensorPrimitive, + rhs: FloatElem, + ) -> R::FloatTensorPrimitive; - fn float_neg( - tensor: SparseFloatPrimitive, - ) -> SparseFloatPrimitive; + fn float_neg(tensor: R::FloatTensorPrimitive) -> R::FloatTensorPrimitive; } pub trait SparseIntOps, B: Backend> {} -pub trait SparseBoolOps, B: Backend> {} - -pub trait SparseTensorOps, B: Backend>: - SparseFloatOps + SparseIntOps + SparseBoolOps -{ -} +// pub trait SparseBoolOps, B: Backend> {} From b98ecfca2423ff90256366e14c8459969c5bef11 Mon Sep 17 00:00:00 2001 From: mcarthur Date: Wed, 14 Aug 2024 11:51:00 +0000 Subject: [PATCH 23/38] Changing up primitives for blanket impl --- crates/burn-tensor/src/tensor/api/base.rs | 120 +++++--- crates/burn-tensor/src/tensor/api/chunk.rs | 9 +- crates/burn-tensor/src/tensor/api/kind.rs | 84 ++++-- crates/burn-tensor/src/tensor/api/repr.rs | 24 
+- crates/burn-tensor/src/tensor/api/sparse.rs | 272 ++++++++++++++++++ .../src/tensor/ops/sparse_tensor.rs | 200 ++++++++++++- 6 files changed, 624 insertions(+), 85 deletions(-) diff --git a/crates/burn-tensor/src/tensor/api/base.rs b/crates/burn-tensor/src/tensor/api/base.rs index 67a4be82c8..e7851f8517 100644 --- a/crates/burn-tensor/src/tensor/api/base.rs +++ b/crates/burn-tensor/src/tensor/api/base.rs @@ -18,7 +18,7 @@ use crate::check::TensorCheck; use crate::tensor::api::chunk::chunk; use crate::tensor::api::narrow::narrow; use crate::{backend::Backend, check, Bool, Float, Int, Shape, TensorData, TensorKind}; -use crate::{DType, Dense, Element, SparseRepr, TensorPrimitive, TensorRepr}; +use crate::{ChangeRepr, DType, Dense, Element, SparseRepr, TensorPrimitive, TensorRepr}; /// A tensor with a given backend, shape and data type. #[derive(new, Clone, Debug)] @@ -42,10 +42,27 @@ where } } -impl Tensor +impl Tensor where B: Backend, - K: BasicOps, + R: TensorRepr, + K: TensorKind, +{ + fn change_repr>(self) -> Tensor + where + K: TensorKind, + R: ChangeRepr, + { + R::change_repr(self) + } +} + +impl Tensor +where + B: Backend, + K: BasicOps, + R: TensorRepr, + Bool: TensorKind, { /// Converts the tensor into a primitive tensor. pub fn into_primitive(self) -> K::Primitive { @@ -107,7 +124,7 @@ where /// println!("{:?}", reshaped_tensor.shape()); /// } /// ``` - pub fn reshape>(self, shape: S) -> Tensor { + pub fn reshape>(self, shape: S) -> Tensor { // Convert reshape args to shape let shape = shape.into_shape(&self); Tensor::new(K::reshape::(self.primitive, shape)) @@ -122,7 +139,7 @@ where /// # Returns /// /// The transposed tensor. - pub fn transpose(self) -> Tensor { + pub fn transpose(self) -> Tensor { Tensor::new(K::transpose(self.primitive)) } @@ -137,7 +154,7 @@ where /// # Returns /// /// The tensor with the dimensions swapped. - pub fn swap_dims(self, dim1: usize, dim2: usize) -> Tensor { + pub fn swap_dims(self, dim1: usize, dim2: usize) -> Tensor { Tensor::new(K::swap_dims(self.primitive, dim1, dim2)) } @@ -153,7 +170,7 @@ where /// # Returns /// /// The tensor with the dimensions permuted. - pub fn permute(self, axes: [isize; D]) -> Tensor { + pub fn permute(self, axes: [isize; D]) -> Tensor { // Convert the axes to usize and handle negative values without using vector let mut transformed_axes: [usize; D] = [0; D]; for (i, &x) in axes.iter().enumerate() { @@ -193,7 +210,7 @@ where /// The tensor with the dimensions moved. // This is a semantic sugar for `permute`. It is used widely enough, so we define a separate Op // for it - pub fn movedim(self, src: S1, dst: S2) -> Tensor { + pub fn movedim(self, src: S1, dst: S2) -> Tensor { let source_dims = src.into_dim_vec::(); let destination_dims = dst.into_dim_vec::(); @@ -234,7 +251,7 @@ where /// # Returns /// /// The tensor with the axes flipped. 
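// The representation parameter `R` now threads through every signature above.
// A toy sketch (backend parameter omitted) of why the `R = Dense` default on
// the tensor type keeps existing call sites compiling unchanged:
use core::marker::PhantomData;

#[derive(Clone, Debug)]
struct Dense;
#[derive(Clone, Debug)]
struct SparseCoo; // hypothetical alternative representation

struct Tensor<const D: usize, K, R = Dense>(PhantomData<(K, R)>);

struct Float;

fn main() {
    let _old_style: Tensor<2, Float> = Tensor(PhantomData); // R defaults to Dense
    let _opt_in: Tensor<2, Float, SparseCoo> = Tensor(PhantomData);
}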
- pub fn flip(self, axes: [isize; N]) -> Tensor { + pub fn flip(self, axes: [isize; N]) -> Tensor { // Convert the axes to usize and handle negative values without using vector let mut transformed_axes: [usize; N] = [0; N]; for (i, &x) in axes.iter().enumerate() { @@ -288,7 +305,7 @@ where /// } /// /// ``` - pub fn flatten(self, start_dim: usize, end_dim: usize) -> Tensor { + pub fn flatten(self, start_dim: usize, end_dim: usize) -> Tensor { check!(TensorCheck::flatten::(start_dim, end_dim)); let current_dims = self.shape().dims; @@ -339,7 +356,7 @@ where /// println!("{:?}", squeezed_tensor.shape()); /// } /// ``` - pub fn squeeze(self, dim: usize) -> Tensor { + pub fn squeeze(self, dim: usize) -> Tensor { check!(TensorCheck::squeeze::(dim, &self.shape().dims)); let current_dims = self.shape().dims; @@ -388,7 +405,7 @@ where /// println!("{:?}", squeezed_tensor.shape()); /// } /// ``` - pub fn squeeze_dims(self, dims: &[isize]) -> Tensor { + pub fn squeeze_dims(self, dims: &[isize]) -> Tensor { let current_dims = self.shape().dims; let mut dim_indices: Vec; @@ -458,7 +475,7 @@ where /// // Shape { dims: [1, 1, 3, 3] } /// } /// ``` - pub fn unsqueeze(self) -> Tensor { + pub fn unsqueeze(self) -> Tensor { check!(TensorCheck::unsqueeze::()); let mut dims = [1; D2]; @@ -487,7 +504,7 @@ where /// // Shape { dims: [3, 1, 3] } /// } /// ``` - pub fn unsqueeze_dim(self, dim: usize) -> Tensor { + pub fn unsqueeze_dim(self, dim: usize) -> Tensor { check!(TensorCheck::unsqueeze_dim::<{ D }>(dim)); let mut dims = [1; D2]; @@ -524,7 +541,7 @@ where /// // Shape { dims: [1, 3, 4, 5, 1, 1] } /// } /// ``` - pub fn unsqueeze_dims(self, axes: &[isize]) -> Tensor { + pub fn unsqueeze_dims(self, axes: &[isize]) -> Tensor { let mut new_dims = [1; D2]; let old_dims = self.shape().dims; //for checking if the dimension is in the acceptable range @@ -636,7 +653,7 @@ where /// This function uses the `RangesArg` trait for flexible range specification. The trait /// handles the conversion of various range formats and applies clamping and negative /// index handling internally. - pub fn slice>(self, ranges: R) -> Self { + pub fn slice>(self, ranges: RA) -> Self { let ranges = ranges.into_ranges(self.shape()); check!(TensorCheck::slice(&self.shape(), &ranges)); @@ -745,7 +762,7 @@ where /// # Panics /// /// If the two tensors don't have the same shape. - pub fn equal(self, other: Self) -> Tensor { + pub fn equal(self, other: Self) -> Tensor { check!(TensorCheck::binary_ops_ew("Equal", &self, &other)); K::equal(self.primitive, other.primitive) } @@ -755,7 +772,7 @@ where /// # Panics /// /// If the two tensors don't have the same shape. - pub fn not_equal(self, other: Self) -> Tensor { + pub fn not_equal(self, other: Self) -> Tensor { check!(TensorCheck::binary_ops_ew("NotEqual", &self, &other)); K::not_equal(self.primitive, other.primitive) } @@ -780,7 +797,10 @@ where /// /// If all tensors don't have the same shape. /// Given dimension is not with range of 0..D2 - pub fn stack(tensors: Vec>, dim: usize) -> Tensor { + pub fn stack( + tensors: Vec>, + dim: usize, + ) -> Tensor { check!(TensorCheck::stack(&tensors, dim)); let tensors = tensors.into_iter().map(|t| t.unsqueeze_dim(dim)).collect(); Tensor::::cat(tensors, dim) @@ -795,7 +815,7 @@ where /// # Returns /// /// A tensor iterator. 
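// `stack` above is literally "unsqueeze every tensor at `dim`, then `cat`
// along `dim`". A toy shape-level check of that identity:
fn unsqueeze(shape: &[usize], dim: usize) -> Vec<usize> {
    let mut out = shape.to_vec();
    out.insert(dim, 1); // the new axis has size 1
    out
}

fn cat(shapes: &[Vec<usize>], dim: usize) -> Vec<usize> {
    let mut out = shapes[0].clone();
    out[dim] = shapes.iter().map(|s| s[dim]).sum(); // only `dim` grows
    out
}

fn main() {
    // Stacking three [2, 4] tensors along dim 0 yields shape [3, 2, 4].
    let unsqueezed: Vec<_> = (0..3).map(|_| unsqueeze(&[2, 4], 0)).collect();
    assert_eq!(cat(&unsqueezed, 0), vec![3, 2, 4]);
}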
- pub fn iter_dim(self, dim: usize) -> DimIter { + pub fn iter_dim(self, dim: usize) -> DimIter { check!(TensorCheck::dim_ops::("iter_dim", dim)); DimIter::new(self, dim) } @@ -846,7 +866,7 @@ where /// /// A boolean tensor `Tensor` containing a single element, True if any element in the input tensor /// evaluates to True, False otherwise. - pub fn any(self) -> Tensor { + pub fn any(self) -> Tensor { K::any(self.primitive) } @@ -862,7 +882,7 @@ where /// A boolean tensor `Tensor` with the same size as input `tensor`, except in the `dim` axis /// where the size is 1. The elem in the `dim` axis is True if any element along this dim in the input /// evaluates to True, False otherwise. - pub fn any_dim(self, dim: usize) -> Tensor { + pub fn any_dim(self, dim: usize) -> Tensor { K::any_dim(self.primitive, dim) } @@ -876,7 +896,7 @@ where /// /// A boolean tensor `Tensor` with a single element, True if all elements in the input tensor /// evaluate to True, False otherwise. - pub fn all(self) -> Tensor { + pub fn all(self) -> Tensor { K::all(self.primitive) } @@ -892,7 +912,7 @@ where /// A boolean tensor `Tensor` with the same size as input `tensor`, except in the `dim` axis /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input /// evaluates to True, False otherwise. - pub fn all_dim(self, dim: usize) -> Tensor { + pub fn all_dim(self, dim: usize) -> Tensor { K::all_dim(self.primitive, dim) } @@ -936,25 +956,27 @@ where /// # Returns /// /// A new tensor with the given shape. - pub fn expand>(self, shape: S) -> Tensor { + pub fn expand>(self, shape: S) -> Tensor { let shape = shape.into_shape(&self.shape()); check!(TensorCheck::expand("expand", &self.shape(), &shape,)); - Tensor::::new(K::expand(self.primitive, shape)) + Tensor::::new(K::expand(self.primitive, shape)) } } /// Iterator given by (Tensor::iter_dim). -pub struct DimIter +pub struct DimIter where B: Backend, - K: BasicOps, + K: BasicOps, + R: TensorRepr, + Bool: TensorKind, { start: usize, end: usize, dim: usize, ranges: [Range; D], - tensor: Tensor, + tensor: Tensor, } impl> Iterator for DimIter { @@ -1279,10 +1301,7 @@ impl core::ops::BitXor for Tensor { /// # Warnings /// /// This is an internal trait, use the public API provided by [tensor struct](Tensor). -pub trait BasicOps = Dense>: TensorKind -where - Bool: TensorKind, -{ +pub trait BasicOps = Dense>: TensorKind { /// The type of the tensor elements. type Elem: Element; @@ -2264,27 +2283,35 @@ impl RangesArg for [(i64, i64); D2] { /// Trait used for reshape arguments. pub trait ReshapeArgs { /// Converts to a shape. 
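// `ReshapeArgs` is a conversion trait so `reshape` can accept a `Shape`, a
// `[usize; N]`, or an `[i32; N]` (see the impls below). A stripped-down
// version of the same pattern, with validation elided:
trait IntoShapeArg {
    fn into_shape_arg(self) -> Vec<usize>;
}

impl<const N: usize> IntoShapeArg for [usize; N] {
    fn into_shape_arg(self) -> Vec<usize> {
        self.to_vec()
    }
}

impl<const N: usize> IntoShapeArg for [i32; N] {
    fn into_shape_arg(self) -> Vec<usize> {
        // The real impl validates the entries first (TensorCheck::reshape_args_i32).
        self.iter().map(|&d| d as usize).collect()
    }
}

fn reshape<S: IntoShapeArg>(shape: S) -> Vec<usize> {
    shape.into_shape_arg()
}

fn main() {
    assert_eq!(reshape([2usize, 3]), vec![2, 3]);
    assert_eq!(reshape([2i32, 3]), vec![2, 3]);
}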
- fn into_shape>( + fn into_shape, R: TensorRepr>( self, - tensor: &Tensor, - ) -> Shape; + tensor: &Tensor, + ) -> Shape + where + Bool: TensorKind; } impl ReshapeArgs for Shape { - fn into_shape>( + fn into_shape, R: TensorRepr>( self, - tensor: &Tensor, - ) -> Shape { + tensor: &Tensor, + ) -> Shape + where + Bool: TensorKind, + { check!(TensorCheck::reshape_args_usize(&tensor.shape(), &self)); self } } impl ReshapeArgs for [usize; D2] { - fn into_shape>( + fn into_shape, R: TensorRepr>( self, - tensor: &Tensor, - ) -> Shape { + tensor: &Tensor, + ) -> Shape + where + Bool: TensorKind, + { let shape = Shape::from(self); check!(TensorCheck::reshape_args_usize(&tensor.shape(), &shape)); @@ -2294,10 +2321,13 @@ impl ReshapeArgs for [usize; D2] { } impl ReshapeArgs for [i32; D2] { - fn into_shape>( + fn into_shape, R: TensorRepr>( self, - tensor: &Tensor, - ) -> Shape { + tensor: &Tensor, + ) -> Shape + where + Bool: TensorKind, + { // Validate the reshape arguments check!(TensorCheck::reshape_args_i32(&self)); diff --git a/crates/burn-tensor/src/tensor/api/chunk.rs b/crates/burn-tensor/src/tensor/api/chunk.rs index e247485c4a..d86f9b5214 100644 --- a/crates/burn-tensor/src/tensor/api/chunk.rs +++ b/crates/burn-tensor/src/tensor/api/chunk.rs @@ -1,5 +1,5 @@ use super::narrow::narrow; -use crate::{backend::Backend, BasicOps, TensorKind}; +use crate::{backend::Backend, BasicOps, Bool, TensorKind, TensorRepr}; use alloc::vec::Vec; /// Split the tensor along the given dimension into chunks. @@ -20,11 +20,14 @@ use alloc::vec::Vec; /// Ideally, it is supposed to be implemented by the backend and the backend implementation will be resolved /// by static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. 
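// A worked check of the chunking arithmetic implemented below: with size = 10
// and chunks = 3, the non-divisible branch takes chunk_size = 10 / 3 + 1 = 4,
// emits chunks - 1 = 2 chunks of width 4, and the remainder 10 % 4 = 2 becomes
// the final chunk, giving widths [4, 4, 2].
fn chunk_widths(size: usize, chunks: usize) -> Vec<usize> {
    if size < chunks {
        return vec![1; size]; // one chunk per element; fewer chunks than requested
    }
    if size % chunks == 0 {
        vec![size / chunks; chunks]
    } else {
        let chunk_size = size / chunks + 1;
        let mut widths = vec![chunk_size; chunks - 1];
        widths.push(size % chunk_size); // whatever is left over
        widths
    }
}

fn main() {
    assert_eq!(chunk_widths(10, 3), vec![4, 4, 2]);
    assert_eq!(chunk_widths(9, 3), vec![3, 3, 3]);
    assert_eq!(chunk_widths(2, 5), vec![1, 1]);
}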
-pub fn chunk + BasicOps>( +pub fn chunk + BasicOps, R: TensorRepr>( tensor: K::Primitive, chunks: usize, dim: usize, -) -> Vec> { +) -> Vec> +where + Bool: TensorKind, +{ let size = K::shape(&tensor).dims[dim]; if size < chunks { return (0..size) diff --git a/crates/burn-tensor/src/tensor/api/kind.rs b/crates/burn-tensor/src/tensor/api/kind.rs index c7a8a3c042..e47f7ee667 100644 --- a/crates/burn-tensor/src/tensor/api/kind.rs +++ b/crates/burn-tensor/src/tensor/api/kind.rs @@ -47,47 +47,71 @@ pub trait TensorKind = Dense>: Clone + core::fmt::D } } -impl TensorKind for Float { - type Primitive = TensorPrimitive; - fn name() -> &'static str { - "Float" - } -} - -impl> TensorKind> for Float { - type Primitive = R::FloatTensorPrimitive; - - fn name() -> &'static str { - >::name() - } -} +// impl TensorKind for Float { +// type Primitive = TensorPrimitive; +// fn name() -> &'static str { +// "Float" +// } +// } + +// impl> TensorKind> for Float { +// type Primitive = R::FloatTensorPrimitive; + +// fn name() -> &'static str { +// >::name() +// } +// } + +// impl TensorKind for Int { +// type Primitive = B::IntTensorPrimitive; +// fn name() -> &'static str { +// "Int" +// } +// } + +// impl> TensorKind> for Int { +// type Primitive = R::IntTensorPrimitive; + +// fn name() -> &'static str { +// >::name() +// } +// } + +// impl TensorKind for Bool { +// type Primitive = B::BoolTensorPrimitive; +// fn name() -> &'static str { +// "Bool" +// } +// } + +// impl> TensorKind> for Bool { +// type Primitive = R::BoolTensorPrimitive; + +// fn name() -> &'static str { +// >::name() +// } +// } + +impl> TensorKind for Bool { + type Primitive = R::Primitive; -impl TensorKind for Int { - type Primitive = B::IntTensorPrimitive; fn name() -> &'static str { - "Int" + "Bool" } } -impl> TensorKind> for Int { - type Primitive = R::IntTensorPrimitive; +impl> TensorKind for Float { + type Primitive = R::Primitive; fn name() -> &'static str { - >::name() - } -} - -impl TensorKind for Bool { - type Primitive = B::BoolTensorPrimitive; - fn name() -> &'static str { - "Bool" + "Float" } } -impl> TensorKind> for Bool { - type Primitive = R::BoolTensorPrimitive; +impl> TensorKind for Int { + type Primitive = R::Primitive; fn name() -> &'static str { - >::name() + "Int" } } diff --git a/crates/burn-tensor/src/tensor/api/repr.rs b/crates/burn-tensor/src/tensor/api/repr.rs index 8d7b4831ce..804c5a76a3 100644 --- a/crates/burn-tensor/src/tensor/api/repr.rs +++ b/crates/burn-tensor/src/tensor/api/repr.rs @@ -1,14 +1,26 @@ -use crate::{backend::Backend, ops::SparseTensorOps, TensorKind}; +use crate::{backend::Backend, ops::SparseTensorOps, Bool, Float, Int, Tensor, TensorKind}; use core::marker::PhantomData; pub trait TensorRepr: Clone + core::fmt::Debug { + type Primitive, const D: usize>; + fn name() -> &'static str; } +pub trait ChangeRepr>: TensorRepr { + fn change_repr, K2: TensorKind>( + lhs: Tensor, + ) -> Tensor; +} + pub trait SparseRepr: Clone + core::fmt::Debug + SparseTensorOps { - type FloatTensorPrimitive: Clone + core::fmt::Debug + Send; - type IntTensorPrimitive: Clone + core::fmt::Debug + Send; - type BoolTensorPrimitive: Clone + core::fmt::Debug + Send; + type Primitive, const D: usize>: Clone + core::fmt::Debug + Send; + // type FloatTensorPrimitive: Clone + core::fmt::Debug + Send = + // Self::Primitive; + // type IntTensorPrimitive: Clone + core::fmt::Debug + Send = + // Self::Primitive; + // type BoolTensorPrimitive: Clone + core::fmt::Debug + Send = + Self::Primitive; fn name() -> &'static str; } @@ -19,12 
+31,16 @@ pub struct Dense; pub struct Sparse, B: Backend>(PhantomData<(R, B)>); impl TensorRepr for Dense { + type Primitive, const D: usize> = K::Primitive; + fn name() -> &'static str { "Dense" } } impl, B: Backend> TensorRepr for Sparse { + type Primitive, const D: usize> = R::Primitive; + fn name() -> &'static str { R::name() } diff --git a/crates/burn-tensor/src/tensor/api/sparse.rs b/crates/burn-tensor/src/tensor/api/sparse.rs index 5856c56af8..96dd69d8c6 100644 --- a/crates/burn-tensor/src/tensor/api/sparse.rs +++ b/crates/burn-tensor/src/tensor/api/sparse.rs @@ -141,3 +141,275 @@ impl> BasicOps> for Float { Tensor::new(R::float_all_dim(tensor, dim)) } } + +impl> BasicOps> for Bool { + type Elem = bool; + + fn empty( + shape: Shape, + device: &::Device, + ) -> Self::Primitive { + R::bool_empty(shape, device) + } + + fn shape(tensor: &Self::Primitive) -> Shape { + R::bool_shape(tensor) + } + + fn reshape( + tensor: Self::Primitive, + shape: Shape, + ) -> Self::Primitive { + R::bool_reshape(tensor, shape) + } + + fn transpose(tensor: Self::Primitive) -> Self::Primitive { + R::bool_transpose(tensor) + } + + fn swap_dims( + tensor: Self::Primitive, + dim1: usize, + dim2: usize, + ) -> Self::Primitive { + R::bool_swap_dims(tensor, dim1, dim2) + } + + fn permute(tensor: Self::Primitive, axes: [usize; D]) -> Self::Primitive { + R::bool_permute(tensor, &axes) + } + + fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive { + R::bool_flip(tensor, axes) + } + + fn slice( + tensor: Self::Primitive, + range: [Range; D2], + ) -> Self::Primitive { + R::bool_slice(tensor, range) + } + + fn slice_assign( + tensor: Self::Primitive, + ranges: [Range; D2], + value: Self::Primitive, + ) -> Self::Primitive { + R::bool_slice_assign(tensor, ranges, value) + } + + fn device(tensor: &Self::Primitive) -> ::Device { + R::bool_device(tensor) + } + + fn to_device( + tensor: Self::Primitive, + device: &::Device, + ) -> Self::Primitive { + R::bool_to_device(tensor, device) + } + + fn into_data_async( + tensor: Self::Primitive, + ) -> impl Future + Send { + R::bool_into_data(tensor) + } + + fn from_data( + data: TensorData, + device: &::Device, + ) -> Self::Primitive { + R::bool_from_data(data, device) + } + + fn repeat_dim( + tensor: Self::Primitive, + dim: usize, + times: usize, + ) -> Self::Primitive { + R::bool_repeat_dim(tensor, dim, times) + } + + fn cat(vectors: Vec>, dim: usize) -> Self::Primitive { + R::bool_cat(vectors, dim) + } + + fn equal( + lhs: Self::Primitive, + rhs: Self::Primitive, + ) -> Tensor> { + Tensor::new(R::bool_equal(lhs, rhs)) + } + + fn not_equal( + lhs: Self::Primitive, + rhs: Self::Primitive, + ) -> Tensor> { + Tensor::new(R::bool_not_equal(lhs, rhs)) + } + + fn any(tensor: Self::Primitive) -> Tensor> { + Tensor::new(R::bool_any(tensor)) + } + + fn any_dim( + tensor: Self::Primitive, + dim: usize, + ) -> Tensor> { + Tensor::new(R::bool_any_dim(tensor, dim)) + } + + fn all(tensor: Self::Primitive) -> Tensor> { + Tensor::new(R::bool_all(tensor)) + } + + fn all_dim( + tensor: Self::Primitive, + dim: usize, + ) -> Tensor> { + Tensor::new(R::bool_all_dim(tensor, dim)) + } + + fn expand( + tensor: Self::Primitive, + shape: Shape, + ) -> Self::Primitive { + R::bool_expand(tensor, shape) + } +} + +impl> BasicOps> for Int { + type Elem = i32; + + fn empty( + shape: Shape, + device: &::Device, + ) -> Self::Primitive { + R::int_empty(shape, device) + } + + fn shape(tensor: &Self::Primitive) -> Shape { + R::int_shape(tensor) + } + + fn reshape( + tensor: Self::Primitive, + shape: 
Shape, + ) -> Self::Primitive { + R::int_reshape(tensor, shape) + } + + fn transpose(tensor: Self::Primitive) -> Self::Primitive { + R::int_transpose(tensor) + } + + fn swap_dims( + tensor: Self::Primitive, + dim1: usize, + dim2: usize, + ) -> Self::Primitive { + R::int_swap_dims(tensor, dim1, dim2) + } + + fn permute(tensor: Self::Primitive, axes: [usize; D]) -> Self::Primitive { + R::int_permute(tensor, &axes) + } + + fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive { + R::int_flip(tensor, axes) + } + + fn slice( + tensor: Self::Primitive, + range: [Range; D2], + ) -> Self::Primitive { + R::int_slice(tensor, range) + } + + fn slice_assign( + tensor: Self::Primitive, + ranges: [Range; D2], + value: Self::Primitive, + ) -> Self::Primitive { + R::int_slice_assign(tensor, ranges, value) + } + + fn device(tensor: &Self::Primitive) -> ::Device { + R::int_device(tensor) + } + + fn to_device( + tensor: Self::Primitive, + device: &::Device, + ) -> Self::Primitive { + R::int_to_device(tensor, device) + } + + fn into_data_async( + tensor: Self::Primitive, + ) -> impl Future + Send { + R::int_into_data(tensor) + } + + fn from_data( + data: TensorData, + device: &::Device, + ) -> Self::Primitive { + R::int_from_data(data, device) + } + + fn repeat_dim( + tensor: Self::Primitive, + dim: usize, + times: usize, + ) -> Self::Primitive { + R::int_repeat_dim(tensor, dim, times) + } + + fn cat(vectors: Vec>, dim: usize) -> Self::Primitive { + R::int_cat(vectors, dim) + } + + fn equal( + lhs: Self::Primitive, + rhs: Self::Primitive, + ) -> Tensor> { + Tensor::new(R::int_equal(lhs, rhs)) + } + + fn not_equal( + lhs: Self::Primitive, + rhs: Self::Primitive, + ) -> Tensor> { + Tensor::new(R::int_not_equal(lhs, rhs)) + } + + fn any(tensor: Self::Primitive) -> Tensor> { + Tensor::new(R::int_any(tensor)) + } + + fn any_dim( + tensor: Self::Primitive, + dim: usize, + ) -> Tensor> { + Tensor::new(R::int_any_dim(tensor, dim)) + } + + fn all(tensor: Self::Primitive) -> Tensor> { + Tensor::new(R::int_all(tensor)) + } + + fn all_dim( + tensor: Self::Primitive, + dim: usize, + ) -> Tensor> { + Tensor::new(R::int_all_dim(tensor, dim)) + } + + fn expand( + tensor: Self::Primitive, + shape: Shape, + ) -> Self::Primitive { + R::int_expand(tensor, shape) + } +} diff --git a/crates/burn-tensor/src/tensor/ops/sparse_tensor.rs b/crates/burn-tensor/src/tensor/ops/sparse_tensor.rs index 9f084ca261..e566b5a141 100644 --- a/crates/burn-tensor/src/tensor/ops/sparse_tensor.rs +++ b/crates/burn-tensor/src/tensor/ops/sparse_tensor.rs @@ -2,7 +2,10 @@ use super::{BoolTensor, FloatElem, FloatTensor, IntTensor, QuantizedTensor}; use crate::{backend::Backend, Device, Float, Shape, SparseRepr, TensorData, TensorKind}; use core::{future::Future, ops::Range}; -pub trait SparseTensorOps, B: Backend>: SparseFloatOps {} +pub trait SparseTensorOps, B: Backend>: + SparseFloatOps + SparseBoolOps + SparseIntOps +{ +} pub trait SparseFloatOps, B: Backend> { fn float_to_sparse( @@ -533,6 +536,197 @@ pub trait SparseFloatOps, B: Backend> { fn float_neg(tensor: R::FloatTensorPrimitive) -> R::FloatTensorPrimitive; } -pub trait SparseIntOps, B: Backend> {} +pub trait SparseBoolOps, B: Backend> { + fn bool_empty(shape: Shape, device: &Device) + -> R::BoolTensorPrimitive; + + fn bool_shape(tensor: &R::BoolTensorPrimitive) -> Shape; + + fn bool_reshape( + tensor: R::BoolTensorPrimitive, + shape: Shape, + ) -> R::BoolTensorPrimitive; + + fn bool_transpose( + tensor: R::BoolTensorPrimitive, + ) -> R::BoolTensorPrimitive; + + fn 
bool_swap_dims( + tensor: R::BoolTensorPrimitive, + dim1: usize, + dim2: usize, + ) -> R::BoolTensorPrimitive; + + fn bool_permute( + tensor: R::BoolTensorPrimitive, + axes: &[usize], + ) -> R::BoolTensorPrimitive; + + fn bool_flip( + tensor: R::BoolTensorPrimitive, + axes: &[usize], + ) -> R::BoolTensorPrimitive; + + fn bool_slice( + tensor: R::BoolTensorPrimitive, + indices: [Range; D2], + ) -> R::BoolTensorPrimitive; + + fn bool_slice_assign( + tensor: R::BoolTensorPrimitive, + ranges: [Range; D2], + value: R::BoolTensorPrimitive, + ) -> R::BoolTensorPrimitive; + + fn bool_device(tensor: &R::BoolTensorPrimitive) -> Device; + + fn bool_to_device( + tensor: R::BoolTensorPrimitive, + device: &Device, + ) -> R::BoolTensorPrimitive; + + fn bool_into_data( + tensor: R::BoolTensorPrimitive, + ) -> impl Future + Send; + + fn bool_from_data( + data: TensorData, + device: &Device, + ) -> R::BoolTensorPrimitive; + + fn bool_repeat_dim( + tensor: R::BoolTensorPrimitive, + dim: usize, + times: usize, + ) -> R::BoolTensorPrimitive; + + fn bool_cat( + tensors: Vec>, + dim: usize, + ) -> R::BoolTensorPrimitive; + + fn bool_equal( + lhs: R::BoolTensorPrimitive, + rhs: R::BoolTensorPrimitive, + ) -> R::BoolTensorPrimitive; + + fn bool_not_equal( + lhs: R::BoolTensorPrimitive, + rhs: R::BoolTensorPrimitive, + ) -> R::BoolTensorPrimitive; + + fn bool_any(tensor: R::BoolTensorPrimitive) -> R::BoolTensorPrimitive<1>; -// pub trait SparseBoolOps, B: Backend> {} + fn bool_any_dim( + tensor: R::BoolTensorPrimitive, + dim: usize, + ) -> R::BoolTensorPrimitive; + + fn bool_all(tensor: R::BoolTensorPrimitive) -> R::BoolTensorPrimitive<1>; + + fn bool_all_dim( + tensor: R::BoolTensorPrimitive, + dim: usize, + ) -> R::BoolTensorPrimitive; + + fn bool_expand( + tensor: R::BoolTensorPrimitive, + shape: Shape, + ) -> R::BoolTensorPrimitive; +} + +pub trait SparseIntOps, B: Backend> { + fn int_empty(shape: Shape, device: &Device) -> R::IntTensorPrimitive; + + fn int_shape(tensor: &R::IntTensorPrimitive) -> Shape; + + fn int_reshape( + tensor: R::IntTensorPrimitive, + shape: Shape, + ) -> R::IntTensorPrimitive; + + fn int_transpose(tensor: R::IntTensorPrimitive) -> R::IntTensorPrimitive; + + fn int_swap_dims( + tensor: R::IntTensorPrimitive, + dim1: usize, + dim2: usize, + ) -> R::IntTensorPrimitive; + + fn int_permute( + tensor: R::IntTensorPrimitive, + axes: &[usize], + ) -> R::IntTensorPrimitive; + + fn int_flip( + tensor: R::IntTensorPrimitive, + axes: &[usize], + ) -> R::IntTensorPrimitive; + + fn int_slice( + tensor: R::IntTensorPrimitive, + indices: [Range; D2], + ) -> R::IntTensorPrimitive; + + fn int_slice_assign( + tensor: R::IntTensorPrimitive, + ranges: [Range; D2], + value: R::IntTensorPrimitive, + ) -> R::IntTensorPrimitive; + + fn int_device(tensor: &R::IntTensorPrimitive) -> Device; + + fn int_to_device( + tensor: R::IntTensorPrimitive, + device: &Device, + ) -> R::IntTensorPrimitive; + + fn int_into_data( + tensor: R::IntTensorPrimitive, + ) -> impl Future + Send; + + fn int_from_data( + data: TensorData, + device: &Device, + ) -> R::IntTensorPrimitive; + + fn int_repeat_dim( + tensor: R::IntTensorPrimitive, + dim: usize, + times: usize, + ) -> R::IntTensorPrimitive; + + fn int_cat( + tensors: Vec>, + dim: usize, + ) -> R::IntTensorPrimitive; + + fn int_equal( + lhs: R::IntTensorPrimitive, + rhs: R::IntTensorPrimitive, + ) -> R::BoolTensorPrimitive; + + fn int_not_equal( + lhs: R::IntTensorPrimitive, + rhs: R::IntTensorPrimitive, + ) -> R::BoolTensorPrimitive; + + fn int_any(tensor: 
R::IntTensorPrimitive) -> R::BoolTensorPrimitive<1>; + + fn int_any_dim( + tensor: R::IntTensorPrimitive, + dim: usize, + ) -> R::BoolTensorPrimitive; + + fn int_all(tensor: R::IntTensorPrimitive) -> R::BoolTensorPrimitive<1>; + + fn int_all_dim( + tensor: R::IntTensorPrimitive, + dim: usize, + ) -> R::BoolTensorPrimitive; + + fn int_expand( + tensor: R::IntTensorPrimitive, + shape: Shape, + ) -> R::IntTensorPrimitive; +} From 6ef4b4dae84d6ed05384f8af8e3327e1caeb3850 Mon Sep 17 00:00:00 2001 From: mcarthur Date: Wed, 14 Aug 2024 12:29:45 +0000 Subject: [PATCH 24/38] Seemingly everything but tensorchecks working --- crates/burn-tensor/src/tensor/api/base.rs | 18 +- crates/burn-tensor/src/tensor/api/chunk.rs | 15 +- crates/burn-tensor/src/tensor/api/kind.rs | 7 +- crates/burn-tensor/src/tensor/api/narrow.rs | 9 +- crates/burn-tensor/src/tensor/api/repr.rs | 6 +- crates/burn-tensor/src/tensor/api/sparse.rs | 2 +- .../burn-tensor/src/tensor/ops/bool_tensor.rs | 8 +- .../burn-tensor/src/tensor/ops/int_tensor.rs | 6 +- .../src/tensor/ops/sparse_tensor.rs | 469 +++++++++--------- crates/burn-tensor/src/tensor/ops/tensor.rs | 6 +- 10 files changed, 269 insertions(+), 277 deletions(-) diff --git a/crates/burn-tensor/src/tensor/api/base.rs b/crates/burn-tensor/src/tensor/api/base.rs index e7851f8517..2c90cc7f69 100644 --- a/crates/burn-tensor/src/tensor/api/base.rs +++ b/crates/burn-tensor/src/tensor/api/base.rs @@ -803,7 +803,7 @@ where ) -> Tensor { check!(TensorCheck::stack(&tensors, dim)); let tensors = tensors.into_iter().map(|t| t.unsqueeze_dim(dim)).collect(); - Tensor::::cat(tensors, dim) + Tensor::::cat(tensors, dim) } /// Iterate over slices of tensors alongside a given dimension. @@ -817,7 +817,7 @@ where /// A tensor iterator. pub fn iter_dim(self, dim: usize) -> DimIter { check!(TensorCheck::dim_ops::("iter_dim", dim)); - DimIter::new(self, dim) + DimIter::::new(self, dim) } /// Returns a new tensor with the given dimension narrowed to the given range. @@ -833,7 +833,7 @@ where pub fn narrow(self, dim: usize, start: usize, length: usize) -> Self { check!(TensorCheck::dim_ops::("narrow", dim)); check!(TensorCheck::narrow(&self, dim, start, length)); - Self::new(narrow::(self.primitive, dim, start, length)) + Self::new(narrow::(self.primitive, dim, start, length)) } /// Attempts to split the tensor along the given dimension into chunks. @@ -850,7 +850,7 @@ where /// A vector of tensors. 
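// The sparse_tensor.rs hunks in this patch rewrite the sparse op traits
// against a single generic associated type, `R::Primitive<Kind, D>`, instead
// of three per-kind associated types. A self-contained sketch of that
// unification, with toy stand-ins for the real kinds and primitives:
trait Kind: Clone + core::fmt::Debug + Send {}

#[derive(Clone, Debug)]
struct Float;
#[derive(Clone, Debug)]
struct Int;
#[derive(Clone, Debug)]
struct Bool;
impl Kind for Float {}
impl Kind for Int {}
impl Kind for Bool {}

trait Repr {
    // One GAT keyed by kind replaces FloatTensorPrimitive,
    // IntTensorPrimitive and BoolTensorPrimitive.
    type Primitive<K: Kind, const D: usize>: Clone + core::fmt::Debug + Send;
}

#[derive(Clone, Debug)]
struct CooTensor<K: Kind, const D: usize>(core::marker::PhantomData<K>);

struct MyRepr;
impl Repr for MyRepr {
    // A representation now implements a single item instead of three.
    type Primitive<K: Kind, const D: usize> = CooTensor<K, D>;
}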
pub fn chunk(self, chunks: usize, dim: usize) -> Vec { check!(TensorCheck::dim_ops::("chunk", dim)); - chunk::(self.primitive, chunks, dim) + chunk::(self.primitive, chunks, dim) .into_iter() .map(|v| Self::new(v)) .collect() @@ -979,8 +979,10 @@ where tensor: Tensor, } -impl> Iterator for DimIter { - type Item = Tensor; +impl, R: TensorRepr> Iterator + for DimIter +{ + type Item = Tensor; fn next(&mut self) -> Option { if self.start >= self.end { @@ -1013,8 +1015,8 @@ impl> DoubleEndedIterator for DimIter } } -impl> DimIter { - fn new(tensor: Tensor, dim: usize) -> Self { +impl, R: TensorRepr> DimIter { + fn new(tensor: Tensor, dim: usize) -> Self { let dims = tensor.dims(); let ranges = dims .iter() diff --git a/crates/burn-tensor/src/tensor/api/chunk.rs b/crates/burn-tensor/src/tensor/api/chunk.rs index d86f9b5214..bcd8df4d1b 100644 --- a/crates/burn-tensor/src/tensor/api/chunk.rs +++ b/crates/burn-tensor/src/tensor/api/chunk.rs @@ -1,5 +1,5 @@ use super::narrow::narrow; -use crate::{backend::Backend, BasicOps, Bool, TensorKind, TensorRepr}; +use crate::{backend::Backend, BasicOps, Bool, Dense, TensorKind, TensorRepr}; use alloc::vec::Vec; /// Split the tensor along the given dimension into chunks. @@ -24,14 +24,11 @@ pub fn chunk + BasicOps, R tensor: K::Primitive, chunks: usize, dim: usize, -) -> Vec> -where - Bool: TensorKind, -{ +) -> Vec> { let size = K::shape(&tensor).dims[dim]; if size < chunks { return (0..size) - .map(|i| narrow::(tensor.clone(), dim, i, 1)) + .map(|i| narrow::(tensor.clone(), dim, i, 1)) .collect(); } @@ -40,7 +37,7 @@ where if size % chunks == 0 { let chunk_size = size / chunks; for _ in 0..chunks { - tensors.push(narrow::( + tensors.push(narrow::( tensor.clone(), dim, sum_chunk_size, @@ -51,7 +48,7 @@ where } else { let chunk_size = (size / chunks) + 1; // assumes not divisible for _ in 0..chunks - 1 { - tensors.push(narrow::( + tensors.push(narrow::( tensor.clone(), dim, sum_chunk_size, @@ -60,7 +57,7 @@ where sum_chunk_size += chunk_size; } let remainder = size % chunk_size; - tensors.push(narrow::( + tensors.push(narrow::( tensor.clone(), dim, sum_chunk_size, diff --git a/crates/burn-tensor/src/tensor/api/kind.rs b/crates/burn-tensor/src/tensor/api/kind.rs index e47f7ee667..3ab4c3f695 100644 --- a/crates/burn-tensor/src/tensor/api/kind.rs +++ b/crates/burn-tensor/src/tensor/api/kind.rs @@ -38,6 +38,9 @@ pub trait TensorKind = Dense>: Clone + core::fmt::D /// The primitive type of the tensor. type Primitive: Clone + core::fmt::Debug + Send; + /// The primitive type of the tensor when dense. + type DensePrimitive: Clone + core::fmt::Debug + Send; + /// The name of the tensor kind. 
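// With `DensePrimitive` on the kind, the `Dense` representation can be a pure
// forwarder (see the repr.rs hunk below, where the dense `Primitive` is set to
// `K::DensePrimitive`). A toy sketch of that forwarding:
trait Kind {
    type DensePrimitive<const D: usize>: Clone + Send;
}

trait Repr {
    type Primitive<K: Kind, const D: usize>: Clone + Send;
}

struct Dense;
impl Repr for Dense {
    // Dense re-exports whatever primitive the kind already defines.
    type Primitive<K: Kind, const D: usize> = K::DensePrimitive<D>;
}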
fn name() -> &'static str; @@ -56,7 +59,6 @@ pub trait TensorKind = Dense>: Clone + core::fmt::D // impl> TensorKind> for Float { // type Primitive = R::FloatTensorPrimitive; - // fn name() -> &'static str { // >::name() // } @@ -93,6 +95,7 @@ pub trait TensorKind = Dense>: Clone + core::fmt::D // } impl> TensorKind for Bool { + type DensePrimitive = B::BoolTensorPrimitive; type Primitive = R::Primitive; fn name() -> &'static str { @@ -101,6 +104,7 @@ impl> TensorKind for Bool { } impl> TensorKind for Float { + type DensePrimitive = TensorPrimitive; type Primitive = R::Primitive; fn name() -> &'static str { @@ -109,6 +113,7 @@ impl> TensorKind for Float { } impl> TensorKind for Int { + type DensePrimitive = B::IntTensorPrimitive; type Primitive = R::Primitive; fn name() -> &'static str { diff --git a/crates/burn-tensor/src/tensor/api/narrow.rs b/crates/burn-tensor/src/tensor/api/narrow.rs index 88290bd388..e01d2c953c 100644 --- a/crates/burn-tensor/src/tensor/api/narrow.rs +++ b/crates/burn-tensor/src/tensor/api/narrow.rs @@ -1,4 +1,4 @@ -use crate::{backend::Backend, BasicOps, TensorKind}; +use crate::{backend::Backend, BasicOps, Dense, TensorKind, TensorRepr}; use alloc::vec::Vec; /// Returns a new tensor with the given dimension narrowed to the given range. @@ -17,7 +17,12 @@ use alloc::vec::Vec; /// # Returns /// /// A new tensor with the given dimension narrowed to the given range. -pub fn narrow + BasicOps>( +pub fn narrow< + B: Backend, + const D: usize, + K: TensorKind + BasicOps, + R: TensorRepr, +>( tensor: K::Primitive, dim: usize, start: usize, diff --git a/crates/burn-tensor/src/tensor/api/repr.rs b/crates/burn-tensor/src/tensor/api/repr.rs index 804c5a76a3..9371b19502 100644 --- a/crates/burn-tensor/src/tensor/api/repr.rs +++ b/crates/burn-tensor/src/tensor/api/repr.rs @@ -2,7 +2,7 @@ use crate::{backend::Backend, ops::SparseTensorOps, Bool, Float, Int, Tensor, Te use core::marker::PhantomData; pub trait TensorRepr: Clone + core::fmt::Debug { - type Primitive, const D: usize>; + type Primitive, const D: usize>: Clone + core::fmt::Debug + Send; fn name() -> &'static str; } @@ -20,7 +20,7 @@ pub trait SparseRepr: Clone + core::fmt::Debug + SparseTensorOps: Clone + core::fmt::Debug + Send = // Self::Primitive; // type BoolTensorPrimitive: Clone + core::fmt::Debug + Send = - Self::Primitive; + // Self::Primitive; fn name() -> &'static str; } @@ -31,7 +31,7 @@ pub struct Dense; pub struct Sparse, B: Backend>(PhantomData<(R, B)>); impl TensorRepr for Dense { - type Primitive, const D: usize> = K::Primitive; + type Primitive, const D: usize> = K::DensePrimitive; fn name() -> &'static str { "Dense" diff --git a/crates/burn-tensor/src/tensor/api/sparse.rs b/crates/burn-tensor/src/tensor/api/sparse.rs index 96dd69d8c6..20754848b9 100644 --- a/crates/burn-tensor/src/tensor/api/sparse.rs +++ b/crates/burn-tensor/src/tensor/api/sparse.rs @@ -12,7 +12,7 @@ impl> BasicOps> for Float { fn empty( shape: Shape, device: &::Device, - ) -> R::FloatTensorPrimitive { + ) -> R::Primitive { R::float_empty(shape, device) } diff --git a/crates/burn-tensor/src/tensor/ops/bool_tensor.rs b/crates/burn-tensor/src/tensor/ops/bool_tensor.rs index b1718ad5c0..81560696f0 100644 --- a/crates/burn-tensor/src/tensor/ops/bool_tensor.rs +++ b/crates/burn-tensor/src/tensor/ops/bool_tensor.rs @@ -3,8 +3,8 @@ use super::{ FloatTensor, IntTensor, }; use crate::{ - argwhere_data, backend::Backend, chunk, narrow, tensor::Shape, Bool, ElementConversion, Tensor, - TensorData, + argwhere_data, backend::Backend, chunk, 
narrow, tensor::Shape, Bool, Dense, ElementConversion, + Tensor, TensorData, }; use alloc::vec::Vec; use core::{future::Future, ops::Range}; @@ -306,7 +306,7 @@ pub trait BoolTensorOps { start: usize, length: usize, ) -> BoolTensor { - narrow::(tensor, dim, start, length) + narrow::(tensor, dim, start, length) } /// Split the tensor along the given dimension into chunks. @@ -325,7 +325,7 @@ pub trait BoolTensorOps { chunks: usize, dim: usize, ) -> Vec> { - chunk::(tensor, chunks, dim) + chunk::(tensor, chunks, dim) } /// Tests if any element in the boolean `tensor` evaluates to True. diff --git a/crates/burn-tensor/src/tensor/ops/int_tensor.rs b/crates/burn-tensor/src/tensor/ops/int_tensor.rs index 28a2eb803b..2aa4d21dad 100644 --- a/crates/burn-tensor/src/tensor/ops/int_tensor.rs +++ b/crates/burn-tensor/src/tensor/ops/int_tensor.rs @@ -3,7 +3,7 @@ use super::repeat_dim::repeat_with_slice_assign; use super::{BoolTensor, Device, FloatTensor, IntElem, IntTensor}; use crate::cast::ToElement; use crate::{backend::Backend, tensor::Shape, Distribution, ElementConversion, Int, TensorData}; -use crate::{cartesian_grid, Tensor}; +use crate::{cartesian_grid, Dense, Tensor}; use crate::{tensor::api::chunk, tensor::api::narrow}; use alloc::vec::Vec; use core::future::Future; @@ -1011,7 +1011,7 @@ pub trait IntTensorOps { start: usize, length: usize, ) -> IntTensor { - narrow::(tensor, dim, start, length) + narrow::(tensor, dim, start, length) } /// Generates a cartesian grid for the given tensor shape on the specified device. @@ -1060,7 +1060,7 @@ pub trait IntTensorOps { chunks: usize, dim: usize, ) -> Vec> { - chunk::(tensor, chunks, dim) + chunk::(tensor, chunks, dim) } /// Creates a new int tensor with random values. diff --git a/crates/burn-tensor/src/tensor/ops/sparse_tensor.rs b/crates/burn-tensor/src/tensor/ops/sparse_tensor.rs index e566b5a141..3b69c4d97b 100644 --- a/crates/burn-tensor/src/tensor/ops/sparse_tensor.rs +++ b/crates/burn-tensor/src/tensor/ops/sparse_tensor.rs @@ -1,5 +1,7 @@ use super::{BoolTensor, FloatElem, FloatTensor, IntTensor, QuantizedTensor}; -use crate::{backend::Backend, Device, Float, Shape, SparseRepr, TensorData, TensorKind}; +use crate::{ + backend::Backend, Bool, Device, Float, Int, Shape, SparseRepr, TensorData, TensorKind, +}; use core::{future::Future, ops::Range}; pub trait SparseTensorOps, B: Backend>: @@ -8,41 +10,33 @@ pub trait SparseTensorOps, B: Backend>: } pub trait SparseFloatOps, B: Backend> { - fn float_to_sparse( - dense: B::FloatTensorPrimitive, - ) -> R::FloatTensorPrimitive; + fn float_to_sparse(dense: B::FloatTensorPrimitive) + -> R::Primitive; - fn float_empty( - shape: Shape, - device: &Device, - ) -> R::FloatTensorPrimitive; + fn float_empty(shape: Shape, device: &Device) -> R::Primitive; - fn float_to_dense( - sparse: R::FloatTensorPrimitive, - ) -> R::FloatTensorPrimitive; + fn float_to_dense(sparse: R::Primitive) -> R::Primitive; fn float_spmm( - lhs: R::FloatTensorPrimitive, + lhs: R::Primitive, rhs: B::FloatTensorPrimitive, ) -> B::FloatTensorPrimitive; fn float_sddmm( lhs: B::FloatTensorPrimitive, rhs: B::FloatTensorPrimitive, - sparse: R::FloatTensorPrimitive, - ) -> R::FloatTensorPrimitive; + sparse: R::Primitive, + ) -> R::Primitive; - fn float_coalesce_sum( - tensor: R::FloatTensorPrimitive, - ) -> R::FloatTensorPrimitive; + fn float_coalesce_sum(tensor: R::Primitive) + -> R::Primitive; - fn float_remove_zeros( - tensor: R::FloatTensorPrimitive, - ) -> R::FloatTensorPrimitive; + fn float_remove_zeros(tensor: R::Primitive) + -> 
R::Primitive; - fn float_nonzero(tensor: R::FloatTensorPrimitive) -> usize; + fn float_nonzero(tensor: R::Primitive) -> usize; - fn float_density(sparse: R::FloatTensorPrimitive) -> f32; + fn float_density(sparse: R::Primitive) -> f32; /// Gets the element at the given indices. /// @@ -55,9 +49,9 @@ pub trait SparseFloatOps, B: Backend> { /// /// The elements at the given indices. fn float_slice( - tensor: R::FloatTensorPrimitive, + tensor: R::Primitive, indices: [Range; D2], - ) -> R::FloatTensorPrimitive; + ) -> R::Primitive; /// Gets the device of the tensor. /// @@ -68,7 +62,7 @@ pub trait SparseFloatOps, B: Backend> { /// # Returns /// /// The device of the tensor. - fn float_device(tensor: &R::FloatTensorPrimitive) -> Device; + fn float_device(tensor: &R::Primitive) -> Device; /// Moves the tensor to the given device. /// @@ -81,9 +75,9 @@ pub trait SparseFloatOps, B: Backend> { /// /// The tensor on the given device. fn float_to_device( - tensor: R::FloatTensorPrimitive, + tensor: R::Primitive, device: &Device, - ) -> R::FloatTensorPrimitive; + ) -> R::Primitive; /// Gets the shape of the tensor. /// @@ -94,7 +88,7 @@ pub trait SparseFloatOps, B: Backend> { /// # Returns /// /// The shape of the tensor. - fn float_shape(tensor: &R::FloatTensorPrimitive) -> Shape; + fn float_shape(tensor: &R::Primitive) -> Shape; /// Converts the tensor to a data structure. /// @@ -106,7 +100,7 @@ pub trait SparseFloatOps, B: Backend> { /// /// The data structure with the tensor's data. fn float_into_data( - tensor: R::FloatTensorPrimitive, + tensor: R::Primitive, ) -> impl Future + Send; /// Creates a tensor from the data structure. @@ -122,78 +116,76 @@ pub trait SparseFloatOps, B: Backend> { fn float_from_data( data: TensorData, device: &Device, - ) -> R::FloatTensorPrimitive; + ) -> R::Primitive; fn float_reshape( - tensor: R::FloatTensorPrimitive, + tensor: R::Primitive, shape: Shape, - ) -> R::FloatTensorPrimitive; + ) -> R::Primitive; - fn float_transpose( - tensor: R::FloatTensorPrimitive, - ) -> R::FloatTensorPrimitive; + fn float_transpose(tensor: R::Primitive) -> R::Primitive; fn float_swap_dims( - tensor: R::FloatTensorPrimitive, + tensor: R::Primitive, dim1: usize, dim2: usize, - ) -> R::FloatTensorPrimitive; + ) -> R::Primitive; fn float_permute( - tensor: R::FloatTensorPrimitive, + tensor: R::Primitive, axes: &[usize], - ) -> R::FloatTensorPrimitive; + ) -> R::Primitive; fn float_flip( - tensor: R::FloatTensorPrimitive, + tensor: R::Primitive, axes: &[usize], - ) -> R::FloatTensorPrimitive; + ) -> R::Primitive; fn float_slice_assign( - tensor: R::FloatTensorPrimitive, + tensor: R::Primitive, ranges: [Range; D2], - value: R::FloatTensorPrimitive, - ) -> R::FloatTensorPrimitive; + value: R::Primitive, + ) -> R::Primitive; fn float_repeat_dim( - tensor: R::FloatTensorPrimitive, + tensor: R::Primitive, dim: usize, times: usize, - ) -> R::FloatTensorPrimitive; + ) -> R::Primitive; fn float_cat( - tensors: Vec>, + tensors: Vec>, dim: usize, - ) -> R::FloatTensorPrimitive; + ) -> R::Primitive; fn float_equal( - lhs: R::FloatTensorPrimitive, - rhs: R::FloatTensorPrimitive, - ) -> R::BoolTensorPrimitive; + lhs: R::Primitive, + rhs: R::Primitive, + ) -> R::Primitive; fn float_not_equal( - lhs: R::FloatTensorPrimitive, - rhs: R::FloatTensorPrimitive, - ) -> R::BoolTensorPrimitive; + lhs: R::Primitive, + rhs: R::Primitive, + ) -> R::Primitive; - fn float_any(tensor: R::FloatTensorPrimitive) -> R::BoolTensorPrimitive<1>; + fn float_any(tensor: R::Primitive) -> R::Primitive; fn float_any_dim( - 
tensor: R::FloatTensorPrimitive, + tensor: R::Primitive, dim: usize, - ) -> R::BoolTensorPrimitive; + ) -> R::Primitive; - fn float_all(tensor: R::FloatTensorPrimitive) -> R::BoolTensorPrimitive<1>; + fn float_all(tensor: R::Primitive) -> R::Primitive; fn float_all_dim( - tensor: R::FloatTensorPrimitive, + tensor: R::Primitive, dim: usize, - ) -> R::BoolTensorPrimitive; + ) -> R::Primitive; fn float_expand( - tensor: R::FloatTensorPrimitive, + tensor: R::Primitive, shape: Shape, - ) -> R::FloatTensorPrimitive; + ) -> R::Primitive; /// Adds two sparse tensors together. /// @@ -206,9 +198,9 @@ pub trait SparseFloatOps, B: Backend> { /// /// The result of adding the two tensors together. fn float_add( - lhs: R::FloatTensorPrimitive, - rhs: R::FloatTensorPrimitive, - ) -> R::FloatTensorPrimitive; + lhs: R::Primitive, + rhs: R::Primitive, + ) -> R::Primitive; /// Adds a sparse and dense tensor together. /// @@ -221,7 +213,7 @@ pub trait SparseFloatOps, B: Backend> { /// /// The result of adding the two tensors together. fn float_add_dense( - lhs: R::FloatTensorPrimitive, + lhs: R::Primitive, rhs: FloatTensor, ) -> FloatTensor; @@ -236,9 +228,9 @@ pub trait SparseFloatOps, B: Backend> { /// /// The result of adding the scalar to the tensor. fn float_add_scalar( - lhs: R::FloatTensorPrimitive, + lhs: R::Primitive, rhs: FloatElem, - ) -> R::FloatTensorPrimitive; + ) -> R::Primitive; /// Subtracts two tensors. /// @@ -251,9 +243,9 @@ pub trait SparseFloatOps, B: Backend> { /// /// The result of subtracting the two tensors. fn float_sub( - lhs: R::FloatTensorPrimitive, - rhs: R::FloatTensorPrimitive, - ) -> R::FloatTensorPrimitive; + lhs: R::Primitive, + rhs: R::Primitive, + ) -> R::Primitive; /// Subtracts a dense from a sparse tensor. /// @@ -266,7 +258,7 @@ pub trait SparseFloatOps, B: Backend> { /// /// The result of subtracting the two tensors. fn float_sub_dense( - lhs: R::FloatTensorPrimitive, + lhs: R::Primitive, rhs: FloatTensor, ) -> FloatTensor; @@ -281,9 +273,9 @@ pub trait SparseFloatOps, B: Backend> { /// /// The result of subtracting the scalar from the tensor. fn float_sub_scalar( - lhs: R::FloatTensorPrimitive, + lhs: R::Primitive, rhs: FloatElem, - ) -> R::FloatTensorPrimitive; + ) -> R::Primitive; /// Multiplies two sparse tensors together. /// @@ -296,9 +288,9 @@ pub trait SparseFloatOps, B: Backend> { /// /// The result of multiplying the two tensors together. fn float_mul( - lhs: R::FloatTensorPrimitive, - rhs: R::FloatTensorPrimitive, - ) -> R::FloatTensorPrimitive; + lhs: R::Primitive, + rhs: R::Primitive, + ) -> R::Primitive; /// Multiplies a sparse and dense tensor together. /// @@ -311,7 +303,7 @@ pub trait SparseFloatOps, B: Backend> { /// /// The result of multiplying the two tensors together. fn float_mul_dense( - lhs: R::FloatTensorPrimitive, + lhs: R::Primitive, rhs: FloatTensor, ) -> FloatTensor; @@ -326,9 +318,9 @@ pub trait SparseFloatOps, B: Backend> { /// /// The result of multiplying the scalar with the tensor. fn float_mul_scalar( - lhs: R::FloatTensorPrimitive, + lhs: R::Primitive, rhs: FloatElem, - ) -> R::FloatTensorPrimitive; + ) -> R::Primitive; /// Divides two sparse tensors. /// @@ -341,9 +333,9 @@ pub trait SparseFloatOps, B: Backend> { /// /// The result of dividing the two tensors. fn float_div( - lhs: R::FloatTensorPrimitive, - rhs: R::FloatTensorPrimitive, - ) -> R::FloatTensorPrimitive; + lhs: R::Primitive, + rhs: R::Primitive, + ) -> R::Primitive; /// Divides a sparse and dense tensor. 
/// @@ -356,7 +348,7 @@ pub trait SparseFloatOps, B: Backend> { /// /// The result of dividing the two tensors. fn float_div_dense( - lhs: R::FloatTensorPrimitive, + lhs: R::Primitive, rhs: FloatTensor, ) -> FloatTensor; @@ -371,362 +363,353 @@ pub trait SparseFloatOps, B: Backend> { /// /// The result of dividing the tensor by the scalar. fn float_div_scalar( - lhs: R::FloatTensorPrimitive, + lhs: R::Primitive, rhs: FloatElem, - ) -> R::FloatTensorPrimitive; + ) -> R::Primitive; - fn float_max(tensor: R::FloatTensorPrimitive) -> R::FloatTensorPrimitive; + fn float_max(tensor: R::Primitive) -> R::Primitive; fn float_max_dim( - tensor: R::FloatTensorPrimitive, + tensor: R::Primitive, dim: usize, - ) -> R::FloatTensorPrimitive; + ) -> R::Primitive; - fn float_min(tensor: R::FloatTensorPrimitive) -> R::FloatTensorPrimitive; + fn float_min(tensor: R::Primitive) -> R::Primitive; fn float_min_dim( - tensor: R::FloatTensorPrimitive, + tensor: R::Primitive, dim: usize, - ) -> R::FloatTensorPrimitive; + ) -> R::Primitive; fn float_greater( - lhs: R::FloatTensorPrimitive, - rhs: R::FloatTensorPrimitive, - ) -> R::BoolTensorPrimitive; + lhs: R::Primitive, + rhs: R::Primitive, + ) -> R::Primitive; fn float_greater_elem( - lhs: R::FloatTensorPrimitive, + lhs: R::Primitive, rhs: FloatElem, - ) -> R::BoolTensorPrimitive; + ) -> R::Primitive; fn float_greater_equal( - lhs: R::FloatTensorPrimitive, - rhs: R::FloatTensorPrimitive, - ) -> R::BoolTensorPrimitive; + lhs: R::Primitive, + rhs: R::Primitive, + ) -> R::Primitive; fn float_greater_equal_elem( - lhs: R::FloatTensorPrimitive, + lhs: R::Primitive, rhs: FloatElem, - ) -> R::BoolTensorPrimitive; + ) -> R::Primitive; fn float_lower( - lhs: R::FloatTensorPrimitive, - rhs: R::FloatTensorPrimitive, - ) -> R::BoolTensorPrimitive; + lhs: R::Primitive, + rhs: R::Primitive, + ) -> R::Primitive; fn float_lower_elem( - lhs: R::FloatTensorPrimitive, + lhs: R::Primitive, rhs: FloatElem, - ) -> R::BoolTensorPrimitive; + ) -> R::Primitive; fn float_lower_equal( - lhs: R::FloatTensorPrimitive, - rhs: R::FloatTensorPrimitive, - ) -> R::BoolTensorPrimitive; + lhs: R::Primitive, + rhs: R::Primitive, + ) -> R::Primitive; fn float_lower_equal_elem( - lhs: R::FloatTensorPrimitive, + lhs: R::Primitive, rhs: FloatElem, - ) -> R::BoolTensorPrimitive; + ) -> R::Primitive; - fn float_abs(tensor: R::FloatTensorPrimitive) -> R::FloatTensorPrimitive; - fn float_sign(tensor: R::FloatTensorPrimitive) - -> R::FloatTensorPrimitive; + fn float_abs(tensor: R::Primitive) -> R::Primitive; + fn float_sign(tensor: R::Primitive) -> R::Primitive; fn float_powf( - lhs: R::FloatTensorPrimitive, - rhs: R::FloatTensorPrimitive, - ) -> R::FloatTensorPrimitive; + lhs: R::Primitive, + rhs: R::Primitive, + ) -> R::Primitive; fn float_powi( - lhs: R::FloatTensorPrimitive, - rhs: R::FloatTensorPrimitive, - ) -> R::FloatTensorPrimitive; + lhs: R::Primitive, + rhs: R::Primitive, + ) -> R::Primitive; fn float_powf_scalar( - lhs: R::FloatTensorPrimitive, + lhs: R::Primitive, rhs: FloatElem, - ) -> R::FloatTensorPrimitive; + ) -> R::Primitive; fn float_powi_scalar( - lhs: R::FloatTensorPrimitive, + lhs: R::Primitive, rhs: FloatElem, - ) -> R::FloatTensorPrimitive; + ) -> R::Primitive; fn float_clamp( - tensor: R::FloatTensorPrimitive, + tensor: R::Primitive, min: FloatElem, max: FloatElem, - ) -> R::FloatTensorPrimitive; + ) -> R::Primitive; fn float_clamp_min( - tensor: R::FloatTensorPrimitive, + tensor: R::Primitive, min: FloatElem, - ) -> R::FloatTensorPrimitive; + ) -> R::Primitive; fn float_clamp_max( - 
tensor: R::FloatTensorPrimitive, + tensor: R::Primitive, max: FloatElem, - ) -> R::FloatTensorPrimitive; + ) -> R::Primitive; fn float_select( - tensor: R::FloatTensorPrimitive, + tensor: R::Primitive, dim: usize, indices: IntTensor, - ) -> R::FloatTensorPrimitive; + ) -> R::Primitive; fn float_select_assign( - tensor: R::FloatTensorPrimitive, + tensor: R::Primitive, dim: usize, indices: IntTensor, - values: R::FloatTensorPrimitive, - ) -> R::FloatTensorPrimitive; + values: R::Primitive, + ) -> R::Primitive; fn float_gather( dim: usize, - tensor: R::FloatTensorPrimitive, + tensor: R::Primitive, indices: IntTensor, - ) -> R::FloatTensorPrimitive; + ) -> R::Primitive; fn float_scatter( dim: usize, - tensor: R::FloatTensorPrimitive, + tensor: R::Primitive, indices: IntTensor, - values: R::FloatTensorPrimitive, - ) -> R::FloatTensorPrimitive; + values: R::Primitive, + ) -> R::Primitive; - fn float_sum(tensor: R::FloatTensorPrimitive) -> R::FloatTensorPrimitive; + fn float_sum(tensor: R::Primitive) -> R::Primitive; fn float_sum_dim( - tensor: R::FloatTensorPrimitive, + tensor: R::Primitive, dim: usize, - ) -> R::FloatTensorPrimitive; + ) -> R::Primitive; - fn float_prod(tensor: R::FloatTensorPrimitive) - -> R::FloatTensorPrimitive; + fn float_prod(tensor: R::Primitive) -> R::Primitive; fn float_prod_dim( - tensor: R::FloatTensorPrimitive, + tensor: R::Primitive, dim: usize, - ) -> R::FloatTensorPrimitive; + ) -> R::Primitive; - fn float_mean(tensor: R::FloatTensorPrimitive) - -> R::FloatTensorPrimitive; + fn float_mean(tensor: R::Primitive) -> R::Primitive; fn float_mean_dim( - tensor: R::FloatTensorPrimitive, + tensor: R::Primitive, dim: usize, - ) -> R::FloatTensorPrimitive; + ) -> R::Primitive; fn float_equal_elem( - lhs: R::FloatTensorPrimitive, + lhs: R::Primitive, rhs: FloatElem, - ) -> R::BoolTensorPrimitive; + ) -> R::Primitive; fn float_not_equal_elem( - lhs: R::FloatTensorPrimitive, + lhs: R::Primitive, rhs: FloatElem, - ) -> R::BoolTensorPrimitive; + ) -> R::Primitive; fn float_remainder_scalar( - lhs: R::FloatTensorPrimitive, + lhs: R::Primitive, rhs: FloatElem, - ) -> R::FloatTensorPrimitive; + ) -> R::Primitive; - fn float_neg(tensor: R::FloatTensorPrimitive) -> R::FloatTensorPrimitive; + fn float_neg(tensor: R::Primitive) -> R::Primitive; } pub trait SparseBoolOps, B: Backend> { - fn bool_empty(shape: Shape, device: &Device) - -> R::BoolTensorPrimitive; + fn bool_empty(shape: Shape, device: &Device) -> R::Primitive; - fn bool_shape(tensor: &R::BoolTensorPrimitive) -> Shape; + fn bool_shape(tensor: &R::Primitive) -> Shape; fn bool_reshape( - tensor: R::BoolTensorPrimitive, + tensor: R::Primitive, shape: Shape, - ) -> R::BoolTensorPrimitive; + ) -> R::Primitive; - fn bool_transpose( - tensor: R::BoolTensorPrimitive, - ) -> R::BoolTensorPrimitive; + fn bool_transpose(tensor: R::Primitive) -> R::Primitive; fn bool_swap_dims( - tensor: R::BoolTensorPrimitive, + tensor: R::Primitive, dim1: usize, dim2: usize, - ) -> R::BoolTensorPrimitive; + ) -> R::Primitive; fn bool_permute( - tensor: R::BoolTensorPrimitive, + tensor: R::Primitive, axes: &[usize], - ) -> R::BoolTensorPrimitive; + ) -> R::Primitive; fn bool_flip( - tensor: R::BoolTensorPrimitive, + tensor: R::Primitive, axes: &[usize], - ) -> R::BoolTensorPrimitive; + ) -> R::Primitive; fn bool_slice( - tensor: R::BoolTensorPrimitive, + tensor: R::Primitive, indices: [Range; D2], - ) -> R::BoolTensorPrimitive; + ) -> R::Primitive; fn bool_slice_assign( - tensor: R::BoolTensorPrimitive, + tensor: R::Primitive, ranges: [Range; D2], - 
value: R::BoolTensorPrimitive, - ) -> R::BoolTensorPrimitive; + value: R::Primitive, + ) -> R::Primitive; - fn bool_device(tensor: &R::BoolTensorPrimitive) -> Device; + fn bool_device(tensor: &R::Primitive) -> Device; fn bool_to_device( - tensor: R::BoolTensorPrimitive, + tensor: R::Primitive, device: &Device, - ) -> R::BoolTensorPrimitive; + ) -> R::Primitive; fn bool_into_data( - tensor: R::BoolTensorPrimitive, + tensor: R::Primitive, ) -> impl Future + Send; fn bool_from_data( data: TensorData, device: &Device, - ) -> R::BoolTensorPrimitive; + ) -> R::Primitive; fn bool_repeat_dim( - tensor: R::BoolTensorPrimitive, + tensor: R::Primitive, dim: usize, times: usize, - ) -> R::BoolTensorPrimitive; + ) -> R::Primitive; fn bool_cat( - tensors: Vec>, + tensors: Vec>, dim: usize, - ) -> R::BoolTensorPrimitive; + ) -> R::Primitive; fn bool_equal( - lhs: R::BoolTensorPrimitive, - rhs: R::BoolTensorPrimitive, - ) -> R::BoolTensorPrimitive; + lhs: R::Primitive, + rhs: R::Primitive, + ) -> R::Primitive; fn bool_not_equal( - lhs: R::BoolTensorPrimitive, - rhs: R::BoolTensorPrimitive, - ) -> R::BoolTensorPrimitive; + lhs: R::Primitive, + rhs: R::Primitive, + ) -> R::Primitive; - fn bool_any(tensor: R::BoolTensorPrimitive) -> R::BoolTensorPrimitive<1>; + fn bool_any(tensor: R::Primitive) -> R::Primitive; fn bool_any_dim( - tensor: R::BoolTensorPrimitive, + tensor: R::Primitive, dim: usize, - ) -> R::BoolTensorPrimitive; + ) -> R::Primitive; - fn bool_all(tensor: R::BoolTensorPrimitive) -> R::BoolTensorPrimitive<1>; + fn bool_all(tensor: R::Primitive) -> R::Primitive; fn bool_all_dim( - tensor: R::BoolTensorPrimitive, + tensor: R::Primitive, dim: usize, - ) -> R::BoolTensorPrimitive; + ) -> R::Primitive; fn bool_expand( - tensor: R::BoolTensorPrimitive, + tensor: R::Primitive, shape: Shape, - ) -> R::BoolTensorPrimitive; + ) -> R::Primitive; } pub trait SparseIntOps, B: Backend> { - fn int_empty(shape: Shape, device: &Device) -> R::IntTensorPrimitive; + fn int_empty(shape: Shape, device: &Device) -> R::Primitive; - fn int_shape(tensor: &R::IntTensorPrimitive) -> Shape; + fn int_shape(tensor: &R::Primitive) -> Shape; fn int_reshape( - tensor: R::IntTensorPrimitive, + tensor: R::Primitive, shape: Shape, - ) -> R::IntTensorPrimitive; + ) -> R::Primitive; - fn int_transpose(tensor: R::IntTensorPrimitive) -> R::IntTensorPrimitive; + fn int_transpose(tensor: R::Primitive) -> R::Primitive; fn int_swap_dims( - tensor: R::IntTensorPrimitive, + tensor: R::Primitive, dim1: usize, dim2: usize, - ) -> R::IntTensorPrimitive; + ) -> R::Primitive; fn int_permute( - tensor: R::IntTensorPrimitive, + tensor: R::Primitive, axes: &[usize], - ) -> R::IntTensorPrimitive; + ) -> R::Primitive; fn int_flip( - tensor: R::IntTensorPrimitive, + tensor: R::Primitive, axes: &[usize], - ) -> R::IntTensorPrimitive; + ) -> R::Primitive; fn int_slice( - tensor: R::IntTensorPrimitive, + tensor: R::Primitive, indices: [Range; D2], - ) -> R::IntTensorPrimitive; + ) -> R::Primitive; fn int_slice_assign( - tensor: R::IntTensorPrimitive, + tensor: R::Primitive, ranges: [Range; D2], - value: R::IntTensorPrimitive, - ) -> R::IntTensorPrimitive; + value: R::Primitive, + ) -> R::Primitive; - fn int_device(tensor: &R::IntTensorPrimitive) -> Device; + fn int_device(tensor: &R::Primitive) -> Device; fn int_to_device( - tensor: R::IntTensorPrimitive, + tensor: R::Primitive, device: &Device, - ) -> R::IntTensorPrimitive; + ) -> R::Primitive; fn int_into_data( - tensor: R::IntTensorPrimitive, + tensor: R::Primitive, ) -> impl Future + Send; - fn 
int_from_data( - data: TensorData, - device: &Device, - ) -> R::IntTensorPrimitive; + fn int_from_data(data: TensorData, device: &Device) -> R::Primitive; fn int_repeat_dim( - tensor: R::IntTensorPrimitive, + tensor: R::Primitive, dim: usize, times: usize, - ) -> R::IntTensorPrimitive; + ) -> R::Primitive; fn int_cat( - tensors: Vec>, + tensors: Vec>, dim: usize, - ) -> R::IntTensorPrimitive; + ) -> R::Primitive; fn int_equal( - lhs: R::IntTensorPrimitive, - rhs: R::IntTensorPrimitive, - ) -> R::BoolTensorPrimitive; + lhs: R::Primitive, + rhs: R::Primitive, + ) -> R::Primitive; fn int_not_equal( - lhs: R::IntTensorPrimitive, - rhs: R::IntTensorPrimitive, - ) -> R::BoolTensorPrimitive; + lhs: R::Primitive, + rhs: R::Primitive, + ) -> R::Primitive; - fn int_any(tensor: R::IntTensorPrimitive) -> R::BoolTensorPrimitive<1>; + fn int_any(tensor: R::Primitive) -> R::Primitive; fn int_any_dim( - tensor: R::IntTensorPrimitive, + tensor: R::Primitive, dim: usize, - ) -> R::BoolTensorPrimitive; + ) -> R::Primitive; - fn int_all(tensor: R::IntTensorPrimitive) -> R::BoolTensorPrimitive<1>; + fn int_all(tensor: R::Primitive) -> R::Primitive; fn int_all_dim( - tensor: R::IntTensorPrimitive, + tensor: R::Primitive, dim: usize, - ) -> R::BoolTensorPrimitive; + ) -> R::Primitive; fn int_expand( - tensor: R::IntTensorPrimitive, + tensor: R::Primitive, shape: Shape, - ) -> R::IntTensorPrimitive; + ) -> R::Primitive; } diff --git a/crates/burn-tensor/src/tensor/ops/tensor.rs b/crates/burn-tensor/src/tensor/ops/tensor.rs index 0edd5c8ee4..e906760743 100644 --- a/crates/burn-tensor/src/tensor/ops/tensor.rs +++ b/crates/burn-tensor/src/tensor/ops/tensor.rs @@ -5,7 +5,7 @@ use crate::backend::BackendBridge; use crate::tensor::cast::ToElement; use crate::{backend::Backend, tensor::Shape, Distribution, ElementConversion, Float, TensorData}; use crate::{tensor::api::chunk, tensor::api::narrow}; -use crate::{Tensor, TensorPrimitive}; +use crate::{Dense, Tensor, TensorPrimitive}; use alloc::vec::Vec; use core::future::Future; use core::ops::Range; @@ -1251,7 +1251,7 @@ pub trait FloatTensorOps { start: usize, length: usize, ) -> FloatTensor { - narrow::(TensorPrimitive::Float(tensor), dim, start, length).tensor() + narrow::(TensorPrimitive::Float(tensor), dim, start, length).tensor() } /// Split the tensor along the given dimension into chunks. 
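
// A minimal, self-contained sketch of the associated-type change made by the
// sparse_tensor.rs hunks above: the separate FloatTensorPrimitive,
// IntTensorPrimitive and BoolTensorPrimitive associated types collapse into
// one kind-indexed `Primitive`, so each op is declared once per tensor kind
// instead of once per primitive type. All names below (`Kind`,
// `SparseReprSketch`, `CooSketch`, `CooTensor`) are illustrative stand-ins,
// not the burn-tensor API.

use core::marker::PhantomData;

#[derive(Clone, Debug)]
pub struct Float;
#[derive(Clone, Debug)]
pub struct Int;
#[derive(Clone, Debug)]
pub struct Bool;

// Supertraits mirror the bounds the trait places on its primitive type.
pub trait Kind: Clone + core::fmt::Debug + Send {}
impl Kind for Float {}
impl Kind for Int {}
impl Kind for Bool {}

// One generic associated type replaces the three per-kind primitives, so
// every sparse op can be written once against `R::Primitive<K, D>`.
pub trait SparseReprSketch {
    type Primitive<K: Kind, const D: usize>: Clone + core::fmt::Debug + Send;
}

// A COO-flavoured representation: the same storage struct backs all kinds.
#[derive(Clone, Debug)]
pub struct CooSketch;

#[derive(Clone, Debug)]
pub struct CooTensor<K: Kind, const D: usize> {
    _kind: PhantomData<K>,
    // coordinates and values are omitted in this sketch
}

impl SparseReprSketch for CooSketch {
    type Primitive<K: Kind, const D: usize> = CooTensor<K, D>;
}

fn main() {
    // The same primitive type now serves every kind and rank.
    let _f: <CooSketch as SparseReprSketch>::Primitive<Float, 2> =
        CooTensor { _kind: PhantomData };
    let _i: <CooSketch as SparseReprSketch>::Primitive<Int, 2> =
        CooTensor { _kind: PhantomData };
    let _b: <CooSketch as SparseReprSketch>::Primitive<Bool, 3> =
        CooTensor { _kind: PhantomData };
}

// Declaring ops against `R::Primitive<K, D>` is what lets a single decorator
// type implement the float, int and bool op traits over one storage struct.
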
@@ -1270,7 +1270,7 @@ pub trait FloatTensorOps { chunks: usize, dim: usize, ) -> Vec> { - chunk::(TensorPrimitive::Float(tensor), chunks, dim) + chunk::(TensorPrimitive::Float(tensor), chunks, dim) .into_iter() .map(|t| t.tensor()) .collect() From 5d53f7a5e5271750b9504757b857913b456fea56 Mon Sep 17 00:00:00 2001 From: mcarthur Date: Mon, 19 Aug 2024 10:32:27 +0000 Subject: [PATCH 25/38] Reintroduced the COO decorator to burn-sparse --- crates/burn-sparse/src/decorator/coo.rs | 33 ++ crates/burn-sparse/src/decorator/coo_bool.rs | 161 ++++++ crates/burn-sparse/src/decorator/coo_float.rs | 536 ++++++++++++++++++ crates/burn-sparse/src/decorator/coo_int.rs | 157 +++++ crates/burn-sparse/src/decorator/mod.rs | 4 + 5 files changed, 891 insertions(+) create mode 100644 crates/burn-sparse/src/decorator/coo.rs create mode 100644 crates/burn-sparse/src/decorator/coo_bool.rs create mode 100644 crates/burn-sparse/src/decorator/coo_float.rs create mode 100644 crates/burn-sparse/src/decorator/coo_int.rs diff --git a/crates/burn-sparse/src/decorator/coo.rs b/crates/burn-sparse/src/decorator/coo.rs new file mode 100644 index 0000000000..04efda71ab --- /dev/null +++ b/crates/burn-sparse/src/decorator/coo.rs @@ -0,0 +1,33 @@ +use burn_tensor::backend::Backend; +use burn_tensor::ops::SparseBoolOps; +use burn_tensor::ops::SparseTensorOps; +use burn_tensor::Dense; +use burn_tensor::Device; +use burn_tensor::Float; +use burn_tensor::Int; +use burn_tensor::Shape; +use burn_tensor::Sparse; +use burn_tensor::SparseRepr; +use burn_tensor::Tensor; +use burn_tensor::TensorKind; + +#[derive(Clone, Debug)] +pub struct COO; + +#[derive(Clone, Debug)] +pub struct SparseCOOTensor, const D: usize> { + pub coordinates: Option>, + pub values: Option>, + pub shape: Shape, + pub device: Device, +} + +impl SparseRepr for COO { + type Primitive, const D: usize> = SparseCOOTensor; + + fn name() -> &'static str { + "SparseCOO" + } +} + +impl SparseTensorOps for COO {} diff --git a/crates/burn-sparse/src/decorator/coo_bool.rs b/crates/burn-sparse/src/decorator/coo_bool.rs new file mode 100644 index 0000000000..f8b8bd3b76 --- /dev/null +++ b/crates/burn-sparse/src/decorator/coo_bool.rs @@ -0,0 +1,161 @@ +use burn_tensor::{ + backend::Backend, + ops::{SparseBoolOps, SparseTensorOps}, + SparseRepr, +}; + +use super::coo::COO; +type R = COO; + +impl SparseBoolOps for R { + fn bool_empty( + shape: burn_tensor::Shape, + device: &burn_tensor::Device, + ) -> >::Primitive { + todo!() + } + + fn bool_shape( + tensor: &>::Primitive, + ) -> burn_tensor::Shape { + todo!() + } + + fn bool_reshape( + tensor: >::Primitive, + shape: burn_tensor::Shape, + ) -> >::Primitive { + todo!() + } + + fn bool_transpose( + tensor: >::Primitive, + ) -> >::Primitive { + todo!() + } + + fn bool_swap_dims( + tensor: >::Primitive, + dim1: usize, + dim2: usize, + ) -> >::Primitive { + todo!() + } + + fn bool_permute( + tensor: >::Primitive, + axes: &[usize], + ) -> >::Primitive { + todo!() + } + + fn bool_flip( + tensor: >::Primitive, + axes: &[usize], + ) -> >::Primitive { + todo!() + } + + fn bool_slice( + tensor: >::Primitive, + indices: [std::ops::Range; D2], + ) -> >::Primitive { + todo!() + } + + fn bool_slice_assign( + tensor: >::Primitive, + ranges: [std::ops::Range; D2], + value: >::Primitive, + ) -> >::Primitive { + todo!() + } + + fn bool_device( + tensor: &>::Primitive, + ) -> burn_tensor::Device { + todo!() + } + + fn bool_to_device( + tensor: >::Primitive, + device: &burn_tensor::Device, + ) -> >::Primitive { + todo!() + } + + fn bool_into_data( + 
tensor: >::Primitive, + ) -> impl std::future::Future + Send { + async { todo!() } + } + + fn bool_from_data( + data: burn_tensor::TensorData, + device: &burn_tensor::Device, + ) -> >::Primitive { + todo!() + } + + fn bool_repeat_dim( + tensor: >::Primitive, + dim: usize, + times: usize, + ) -> >::Primitive { + todo!() + } + + fn bool_cat( + tensors: Vec<>::Primitive>, + dim: usize, + ) -> >::Primitive { + todo!() + } + + fn bool_equal( + lhs: >::Primitive, + rhs: >::Primitive, + ) -> >::Primitive { + todo!() + } + + fn bool_not_equal( + lhs: >::Primitive, + rhs: >::Primitive, + ) -> >::Primitive { + todo!() + } + + fn bool_any( + tensor: >::Primitive, + ) -> >::Primitive { + todo!() + } + + fn bool_any_dim( + tensor: >::Primitive, + dim: usize, + ) -> >::Primitive { + todo!() + } + + fn bool_all( + tensor: >::Primitive, + ) -> >::Primitive { + todo!() + } + + fn bool_all_dim( + tensor: >::Primitive, + dim: usize, + ) -> >::Primitive { + todo!() + } + + fn bool_expand( + tensor: >::Primitive, + shape: burn_tensor::Shape, + ) -> >::Primitive { + todo!() + } +} diff --git a/crates/burn-sparse/src/decorator/coo_float.rs b/crates/burn-sparse/src/decorator/coo_float.rs new file mode 100644 index 0000000000..6d75eb4ed1 --- /dev/null +++ b/crates/burn-sparse/src/decorator/coo_float.rs @@ -0,0 +1,536 @@ +use burn_tensor::{backend::Backend, ops::SparseFloatOps, SparseRepr}; + +use super::coo::COO; +type R = COO; + +impl SparseFloatOps for R { + fn float_to_sparse( + dense: ::FloatTensorPrimitive, + ) -> >::Primitive { + todo!() + } + + fn float_empty( + shape: burn_tensor::Shape, + device: &burn_tensor::Device, + ) -> >::Primitive { + todo!() + } + + fn float_to_dense( + sparse: >::Primitive, + ) -> >::Primitive { + todo!() + } + + fn float_spmm( + lhs: >::Primitive, + rhs: ::FloatTensorPrimitive, + ) -> ::FloatTensorPrimitive { + todo!() + } + + fn float_sddmm( + lhs: ::FloatTensorPrimitive, + rhs: ::FloatTensorPrimitive, + sparse: >::Primitive, + ) -> >::Primitive { + todo!() + } + + fn float_coalesce_sum( + tensor: >::Primitive, + ) -> >::Primitive { + todo!() + } + + fn float_remove_zeros( + tensor: >::Primitive, + ) -> >::Primitive { + todo!() + } + + fn float_nonzero( + tensor: >::Primitive, + ) -> usize { + todo!() + } + + fn float_density( + sparse: >::Primitive, + ) -> f32 { + todo!() + } + + fn float_slice( + tensor: >::Primitive, + indices: [std::ops::Range; D2], + ) -> >::Primitive { + todo!() + } + + fn float_device( + tensor: &>::Primitive, + ) -> burn_tensor::Device { + todo!() + } + + fn float_to_device( + tensor: >::Primitive, + device: &burn_tensor::Device, + ) -> >::Primitive { + todo!() + } + + fn float_shape( + tensor: &>::Primitive, + ) -> burn_tensor::Shape { + todo!() + } + + fn float_into_data( + tensor: >::Primitive, + ) -> impl std::future::Future + Send { + async { todo!() } + } + + fn float_from_data( + data: burn_tensor::TensorData, + device: &burn_tensor::Device, + ) -> >::Primitive { + todo!() + } + + fn float_reshape( + tensor: >::Primitive, + shape: burn_tensor::Shape, + ) -> >::Primitive { + todo!() + } + + fn float_transpose( + tensor: >::Primitive, + ) -> >::Primitive { + todo!() + } + + fn float_swap_dims( + tensor: >::Primitive, + dim1: usize, + dim2: usize, + ) -> >::Primitive { + todo!() + } + + fn float_permute( + tensor: >::Primitive, + axes: &[usize], + ) -> >::Primitive { + todo!() + } + + fn float_flip( + tensor: >::Primitive, + axes: &[usize], + ) -> >::Primitive { + todo!() + } + + fn float_slice_assign( + tensor: >::Primitive, + ranges: 
[std::ops::Range; D2], + value: >::Primitive, + ) -> >::Primitive { + todo!() + } + + fn float_repeat_dim( + tensor: >::Primitive, + dim: usize, + times: usize, + ) -> >::Primitive { + todo!() + } + + fn float_cat( + tensors: Vec<>::Primitive>, + dim: usize, + ) -> >::Primitive { + todo!() + } + + fn float_equal( + lhs: >::Primitive, + rhs: >::Primitive, + ) -> >::Primitive { + todo!() + } + + fn float_not_equal( + lhs: >::Primitive, + rhs: >::Primitive, + ) -> >::Primitive { + todo!() + } + + fn float_any( + tensor: >::Primitive, + ) -> >::Primitive { + todo!() + } + + fn float_any_dim( + tensor: >::Primitive, + dim: usize, + ) -> >::Primitive { + todo!() + } + + fn float_all( + tensor: >::Primitive, + ) -> >::Primitive { + todo!() + } + + fn float_all_dim( + tensor: >::Primitive, + dim: usize, + ) -> >::Primitive { + todo!() + } + + fn float_expand( + tensor: >::Primitive, + shape: burn_tensor::Shape, + ) -> >::Primitive { + todo!() + } + + fn float_add( + lhs: >::Primitive, + rhs: >::Primitive, + ) -> >::Primitive { + todo!() + } + + fn float_add_dense( + lhs: >::Primitive, + rhs: burn_tensor::ops::FloatTensor, + ) -> burn_tensor::ops::FloatTensor { + todo!() + } + + fn float_add_scalar( + lhs: >::Primitive, + rhs: burn_tensor::ops::FloatElem, + ) -> >::Primitive { + todo!() + } + + fn float_sub( + lhs: >::Primitive, + rhs: >::Primitive, + ) -> >::Primitive { + todo!() + } + + fn float_sub_dense( + lhs: >::Primitive, + rhs: burn_tensor::ops::FloatTensor, + ) -> burn_tensor::ops::FloatTensor { + todo!() + } + + fn float_sub_scalar( + lhs: >::Primitive, + rhs: burn_tensor::ops::FloatElem, + ) -> >::Primitive { + todo!() + } + + fn float_mul( + lhs: >::Primitive, + rhs: >::Primitive, + ) -> >::Primitive { + todo!() + } + + fn float_mul_dense( + lhs: >::Primitive, + rhs: burn_tensor::ops::FloatTensor, + ) -> burn_tensor::ops::FloatTensor { + todo!() + } + + fn float_mul_scalar( + lhs: >::Primitive, + rhs: burn_tensor::ops::FloatElem, + ) -> >::Primitive { + todo!() + } + + fn float_div( + lhs: >::Primitive, + rhs: >::Primitive, + ) -> >::Primitive { + todo!() + } + + fn float_div_dense( + lhs: >::Primitive, + rhs: burn_tensor::ops::FloatTensor, + ) -> burn_tensor::ops::FloatTensor { + todo!() + } + + fn float_div_scalar( + lhs: >::Primitive, + rhs: burn_tensor::ops::FloatElem, + ) -> >::Primitive { + todo!() + } + + fn float_max( + tensor: >::Primitive, + ) -> >::Primitive { + todo!() + } + + fn float_max_dim( + tensor: >::Primitive, + dim: usize, + ) -> >::Primitive { + todo!() + } + + fn float_min( + tensor: >::Primitive, + ) -> >::Primitive { + todo!() + } + + fn float_min_dim( + tensor: >::Primitive, + dim: usize, + ) -> >::Primitive { + todo!() + } + + fn float_greater( + lhs: >::Primitive, + rhs: >::Primitive, + ) -> >::Primitive { + todo!() + } + + fn float_greater_elem( + lhs: >::Primitive, + rhs: burn_tensor::ops::FloatElem, + ) -> >::Primitive { + todo!() + } + + fn float_greater_equal( + lhs: >::Primitive, + rhs: >::Primitive, + ) -> >::Primitive { + todo!() + } + + fn float_greater_equal_elem( + lhs: >::Primitive, + rhs: burn_tensor::ops::FloatElem, + ) -> >::Primitive { + todo!() + } + + fn float_lower( + lhs: >::Primitive, + rhs: >::Primitive, + ) -> >::Primitive { + todo!() + } + + fn float_lower_elem( + lhs: >::Primitive, + rhs: burn_tensor::ops::FloatElem, + ) -> >::Primitive { + todo!() + } + + fn float_lower_equal( + lhs: >::Primitive, + rhs: >::Primitive, + ) -> >::Primitive { + todo!() + } + + fn float_lower_equal_elem( + lhs: >::Primitive, + rhs: 
burn_tensor::ops::FloatElem, + ) -> >::Primitive { + todo!() + } + + fn float_abs( + tensor: >::Primitive, + ) -> >::Primitive { + todo!() + } + + fn float_sign( + tensor: >::Primitive, + ) -> >::Primitive { + todo!() + } + + fn float_powf( + lhs: >::Primitive, + rhs: >::Primitive, + ) -> >::Primitive { + todo!() + } + + fn float_powi( + lhs: >::Primitive, + rhs: >::Primitive, + ) -> >::Primitive { + todo!() + } + + fn float_powf_scalar( + lhs: >::Primitive, + rhs: burn_tensor::ops::FloatElem, + ) -> >::Primitive { + todo!() + } + + fn float_powi_scalar( + lhs: >::Primitive, + rhs: burn_tensor::ops::FloatElem, + ) -> >::Primitive { + todo!() + } + + fn float_clamp( + tensor: >::Primitive, + min: burn_tensor::ops::FloatElem, + max: burn_tensor::ops::FloatElem, + ) -> >::Primitive { + todo!() + } + + fn float_clamp_min( + tensor: >::Primitive, + min: burn_tensor::ops::FloatElem, + ) -> >::Primitive { + todo!() + } + + fn float_clamp_max( + tensor: >::Primitive, + max: burn_tensor::ops::FloatElem, + ) -> >::Primitive { + todo!() + } + + fn float_select( + tensor: >::Primitive, + dim: usize, + indices: burn_tensor::ops::IntTensor, + ) -> >::Primitive { + todo!() + } + + fn float_select_assign( + tensor: >::Primitive, + dim: usize, + indices: burn_tensor::ops::IntTensor, + values: >::Primitive, + ) -> >::Primitive { + todo!() + } + + fn float_gather( + dim: usize, + tensor: >::Primitive, + indices: burn_tensor::ops::IntTensor, + ) -> >::Primitive { + todo!() + } + + fn float_scatter( + dim: usize, + tensor: >::Primitive, + indices: burn_tensor::ops::IntTensor, + values: >::Primitive, + ) -> >::Primitive { + todo!() + } + + fn float_sum( + tensor: >::Primitive, + ) -> >::Primitive { + todo!() + } + + fn float_sum_dim( + tensor: >::Primitive, + dim: usize, + ) -> >::Primitive { + todo!() + } + + fn float_prod( + tensor: >::Primitive, + ) -> >::Primitive { + todo!() + } + + fn float_prod_dim( + tensor: >::Primitive, + dim: usize, + ) -> >::Primitive { + todo!() + } + + fn float_mean( + tensor: >::Primitive, + ) -> >::Primitive { + todo!() + } + + fn float_mean_dim( + tensor: >::Primitive, + dim: usize, + ) -> >::Primitive { + todo!() + } + + fn float_equal_elem( + lhs: >::Primitive, + rhs: burn_tensor::ops::FloatElem, + ) -> >::Primitive { + todo!() + } + + fn float_not_equal_elem( + lhs: >::Primitive, + rhs: burn_tensor::ops::FloatElem, + ) -> >::Primitive { + todo!() + } + + fn float_remainder_scalar( + lhs: >::Primitive, + rhs: burn_tensor::ops::FloatElem, + ) -> >::Primitive { + todo!() + } + + fn float_neg( + tensor: >::Primitive, + ) -> >::Primitive { + todo!() + } +} diff --git a/crates/burn-sparse/src/decorator/coo_int.rs b/crates/burn-sparse/src/decorator/coo_int.rs new file mode 100644 index 0000000000..746ca7c53e --- /dev/null +++ b/crates/burn-sparse/src/decorator/coo_int.rs @@ -0,0 +1,157 @@ +use burn_tensor::{backend::Backend, ops::SparseIntOps, SparseRepr}; + +use super::coo::COO; +type R = COO; + +impl SparseIntOps for R { + fn int_empty( + shape: burn_tensor::Shape, + device: &burn_tensor::Device, + ) -> >::Primitive { + todo!() + } + + fn int_shape( + tensor: &>::Primitive, + ) -> burn_tensor::Shape { + todo!() + } + + fn int_reshape( + tensor: >::Primitive, + shape: burn_tensor::Shape, + ) -> >::Primitive { + todo!() + } + + fn int_transpose( + tensor: >::Primitive, + ) -> >::Primitive { + todo!() + } + + fn int_swap_dims( + tensor: >::Primitive, + dim1: usize, + dim2: usize, + ) -> >::Primitive { + todo!() + } + + fn int_permute( + tensor: >::Primitive, + axes: &[usize], + ) -> 
>::Primitive { + todo!() + } + + fn int_flip( + tensor: >::Primitive, + axes: &[usize], + ) -> >::Primitive { + todo!() + } + + fn int_slice( + tensor: >::Primitive, + indices: [std::ops::Range; D2], + ) -> >::Primitive { + todo!() + } + + fn int_slice_assign( + tensor: >::Primitive, + ranges: [std::ops::Range; D2], + value: >::Primitive, + ) -> >::Primitive { + todo!() + } + + fn int_device( + tensor: &>::Primitive, + ) -> burn_tensor::Device { + todo!() + } + + fn int_to_device( + tensor: >::Primitive, + device: &burn_tensor::Device, + ) -> >::Primitive { + todo!() + } + + fn int_into_data( + tensor: >::Primitive, + ) -> impl std::future::Future + Send { + async { todo!() } + } + + fn int_from_data( + data: burn_tensor::TensorData, + device: &burn_tensor::Device, + ) -> >::Primitive { + todo!() + } + + fn int_repeat_dim( + tensor: >::Primitive, + dim: usize, + times: usize, + ) -> >::Primitive { + todo!() + } + + fn int_cat( + tensors: Vec<>::Primitive>, + dim: usize, + ) -> >::Primitive { + todo!() + } + + fn int_equal( + lhs: >::Primitive, + rhs: >::Primitive, + ) -> >::Primitive { + todo!() + } + + fn int_not_equal( + lhs: >::Primitive, + rhs: >::Primitive, + ) -> >::Primitive { + todo!() + } + + fn int_any( + tensor: >::Primitive, + ) -> >::Primitive { + todo!() + } + + fn int_any_dim( + tensor: >::Primitive, + dim: usize, + ) -> >::Primitive { + todo!() + } + + fn int_all( + tensor: >::Primitive, + ) -> >::Primitive { + todo!() + } + + fn int_all_dim( + tensor: >::Primitive, + dim: usize, + ) -> >::Primitive { + todo!() + } + + fn int_expand( + tensor: >::Primitive, + shape: burn_tensor::Shape, + ) -> >::Primitive { + todo!() + } +} diff --git a/crates/burn-sparse/src/decorator/mod.rs b/crates/burn-sparse/src/decorator/mod.rs index dc49680a84..2343c84944 100644 --- a/crates/burn-sparse/src/decorator/mod.rs +++ b/crates/burn-sparse/src/decorator/mod.rs @@ -4,6 +4,10 @@ // mod representation; // mod sparse_coo; // mod sparse_csr; +mod coo; +mod coo_bool; +mod coo_float; +mod coo_int; // pub use backend::*; // pub use precision_bridge::*; From 6ebe15b7fdd990989b9a9fb1c22027b670064b66 Mon Sep 17 00:00:00 2001 From: mcarthur Date: Mon, 19 Aug 2024 13:13:34 +0000 Subject: [PATCH 26/38] transferred accross most basic ops for float tensor --- crates/burn-sparse/src/decorator/coo.rs | 43 ++ crates/burn-sparse/src/decorator/coo_float.rs | 641 +++++++++++++----- crates/burn-tensor/src/tensor/api/base.rs | 10 +- crates/burn-tensor/src/tensor/api/repr.rs | 7 +- crates/burn-tensor/src/tensor/api/sparse.rs | 6 +- .../src/tensor/ops/sparse_tensor.rs | 151 +---- 6 files changed, 529 insertions(+), 329 deletions(-) diff --git a/crates/burn-sparse/src/decorator/coo.rs b/crates/burn-sparse/src/decorator/coo.rs index 04efda71ab..ecbd482e71 100644 --- a/crates/burn-sparse/src/decorator/coo.rs +++ b/crates/burn-sparse/src/decorator/coo.rs @@ -9,6 +9,7 @@ use burn_tensor::Shape; use burn_tensor::Sparse; use burn_tensor::SparseRepr; use burn_tensor::Tensor; +use burn_tensor::TensorData; use burn_tensor::TensorKind; #[derive(Clone, Debug)] @@ -31,3 +32,45 @@ impl SparseRepr for COO { } impl SparseTensorOps for COO {} + +pub(crate) fn flatten_coordinates( + coordinates: Tensor, + shape: Shape, + device: &Device, +) -> Tensor { + let mut strides_data = [[1]; D]; + for i in (0..D).rev() { + if D - 1 - i == S { + strides_data[i] = [1]; + } else if D - 1 - i < S { + strides_data[i] = [0]; + } else { + strides_data[i] = [strides_data[i + 1][0] * shape.dims[i + 1] as i64]; + } + } + let strides_data: TensorData = 
TensorData::from(strides_data); + let strides: Tensor = Tensor::from_data(strides_data, device); + let flat_coordinates: Tensor = strides.mul(coordinates).sum_dim(0).flatten(0, 1); + + flat_coordinates.unsqueeze_dim(0) +} + +pub(crate) fn unflatten_coordinates( + flat_coordinates: Tensor, + new_shape: Shape, +) -> Tensor { + let flat_coordinates = flat_coordinates.squeeze::<1>(0); + let mut remaining_flat_coordinates = flat_coordinates.clone(); + let mut new_coordinates = Vec::with_capacity(D); + + for &dim_size in new_shape.dims.iter().rev() { + let size = dim_size as i64; + let new_coord = remaining_flat_coordinates.clone().remainder_scalar(size); + new_coordinates.push(new_coord.clone()); + remaining_flat_coordinates = remaining_flat_coordinates.div_scalar(size); + } + + new_coordinates.reverse(); + + Tensor::stack(new_coordinates, 0) +} diff --git a/crates/burn-sparse/src/decorator/coo_float.rs b/crates/burn-sparse/src/decorator/coo_float.rs index 6d75eb4ed1..d364381ee1 100644 --- a/crates/burn-sparse/src/decorator/coo_float.rs +++ b/crates/burn-sparse/src/decorator/coo_float.rs @@ -1,13 +1,59 @@ -use burn_tensor::{backend::Backend, ops::SparseFloatOps, SparseRepr}; +use burn_tensor::cast::ToElement; +use burn_tensor::ops::FloatElem; +use burn_tensor::{backend::Backend, ops::SparseFloatOps, SparseRepr, Tensor}; +use burn_tensor::{Bool, ElementConversion, Float, Shape, TensorData, TensorPrimitive}; +use burn_tensor::{Device, Int}; -use super::coo::COO; +use super::coo::{flatten_coordinates, unflatten_coordinates, SparseCOOTensor, COO}; type R = COO; impl SparseFloatOps for R { fn float_to_sparse( dense: ::FloatTensorPrimitive, ) -> >::Primitive { - todo!() + let dense: Tensor = Tensor::from_primitive(TensorPrimitive::Float(dense)); + + let shape = dense.shape(); + let device = dense.device(); + + let significant = dense.clone().not_equal_elem(0.0); + if !significant.clone().any().into_scalar() { + return Self::float_empty(dense.shape(), &device); + }; + + let coordinates = significant + .clone() + .nonzero() + .into_iter() + .map(|tensor| { + let length = tensor.shape().dims[0]; + let shape = Shape::new([1, length]); + tensor.reshape(shape) + }) + .collect(); + + let coordinates = Tensor::cat(coordinates, 0); + + let dense = dense.flatten(0, D - 1); + + let dims = significant.dims(); + let values = dense.gather( + 0, + significant + .flatten::<1>(0, dims.len() - 1) + .nonzero() + .remove(0), + ); + + let coordinates = Some(coordinates); + let values = Some(values); + + SparseCOOTensor { + coordinates, + values, + shape, + device, + } } fn float_empty( @@ -19,15 +65,111 @@ impl SparseFloatOps for R { fn float_to_dense( sparse: >::Primitive, - ) -> >::Primitive { - todo!() + ) -> B::FloatTensorPrimitive { + let SparseCOOTensor { + coordinates, + values, + shape, + device, + } = sparse; + + let (Some(coordinates), Some(values)) = (coordinates, values) else { + return Tensor::::zeros(shape, &device) + .into_primitive() + .tensor(); + }; + + let dense: Tensor = Tensor::zeros(Shape::new([shape.num_elements()]), &device); + let flat_coordinates = + flatten_coordinates::(coordinates, shape.clone(), &device).squeeze(0); + let dense = dense.select_assign(0, flat_coordinates, values); + + dense.reshape(shape).into_primitive().tensor() } fn float_spmm( lhs: >::Primitive, rhs: ::FloatTensorPrimitive, ) -> ::FloatTensorPrimitive { - todo!() + let SparseCOOTensor { + coordinates, + values, + shape, + device, + } = lhs; + + let rhs: Tensor = Tensor::from_primitive(TensorPrimitive::Float(rhs)); + let 
rhs_shape = rhs.shape(); + let mut out_shape = shape.clone(); + out_shape.dims[D - 1] = rhs_shape.dims[D - 1]; + + let (Some(coordinates), Some(values)) = (coordinates, values) else { + // All zeros, exit early + return Tensor::::zeros(out_shape, &device) + .into_primitive() + .tensor(); + }; + + let nnz = coordinates.shape().dims[1]; + + // Ensure they are of the correct shape to multiply + if shape.dims[D - 1] != rhs_shape.dims[D - 2] { + panic!("Invalid shape for matrix multiplication"); + } + + // Ensure batches are the same + if D > 2 && rhs_shape.dims[0..D - 2] != shape.dims[0..D - 2] { + panic!("Batches must be of the same shape"); + } + + // Compute strides for the dense tensor to match the flattened shape + let mut strides_data = [1; D]; + for i in (0..D - 1).rev() { + strides_data[i] = strides_data[i + 1] * shape.dims[i + 1] as i32; + } + let strides: Tensor = + Tensor::::from_ints(strides_data, &device).unsqueeze_dim(1); + + let column_index = coordinates.clone().slice([D - 1..D, 0..nnz]); + + // the indices into the flat row vector at which the containing matrix starts + let matrix_starts: Tensor = if D > 2 { + coordinates + .clone() + .slice([0..D - 2, 0..nnz]) + .mul(strides.clone().slice([0..D - 2])) + .div_scalar((shape.dims[D - 1]) as i32) + .sum_dim(0) + } else { + Tensor::::zeros(column_index.shape(), &device) + }; + + let row_index = coordinates.slice([D - 2..D - 1, 0..nnz]); + + let gather_index = matrix_starts.clone() + column_index; + let scatter_index = matrix_starts + row_index; + + let gather_index = gather_index + .transpose() + .repeat_dim(1, rhs_shape.dims[D - 1]); + let scatter_index = scatter_index + .transpose() + .repeat_dim(1, rhs_shape.dims[D - 1]); + let values = values.unsqueeze_dim(1).repeat_dim(1, rhs_shape.dims[D - 1]); + + // Flatten the rhs similarly into 2 dimensions + let rhs: Tensor = rhs.reshape([-1, rhs_shape.dims[D - 1] as i32]); + + // Do the matmul using gather/scatter + let output: Tensor = + Tensor::zeros([out_shape.dims[0], rhs.shape().dims[1]], &device); + let gathered = rhs.gather(0, gather_index); + + let multiplied = gathered.mul(values); + + let scattered = output.scatter(0, scatter_index, multiplied); + + scattered.reshape(out_shape).into_primitive().tensor() } fn float_sddmm( @@ -35,31 +177,194 @@ impl SparseFloatOps for R { rhs: ::FloatTensorPrimitive, sparse: >::Primitive, ) -> >::Primitive { - todo!() + if sparse.coordinates.is_none() || sparse.values.is_none() { + return sparse; + } + + // Flatten the lhs and rhs into a tensor of rows and cols respectively + let lhs = Tensor::::new(burn_tensor::TensorPrimitive::Float(lhs)); + let rhs = Tensor::::new(burn_tensor::TensorPrimitive::Float(rhs)).transpose(); + let lhs_dims = lhs.shape().dims; + let rhs_dims = rhs.shape().dims; + + if lhs_dims[D - 1] != rhs_dims[D - 1] + || lhs_dims[D - 2] != sparse.shape.dims[D - 2] + || rhs_dims[D - 2] != sparse.shape.dims[D - 1] + { + panic!("invalid dimensions for sddmm. 
lhs and rhs must have compatible shapes for matmul, and sparse must have the correct shape for output of matmul between lhs and rhs."); + } + + let lhs = lhs.reshape([-1, lhs_dims[D - 1] as i32]); + let rhs = rhs.reshape([-1, rhs_dims[D - 1] as i32]); + + // Flatten the sparse tensor into + let device = sparse.device.clone(); + let mut shape = sparse.shape.clone(); + let lhs_coordinates = sparse + .coordinates + .clone() + .expect("Expected non-empty sparse tensor"); + + // swap the last two dims so its column-first + let swizzle = Tensor::::arange(0..D as i64, &device) + .slice_assign( + [D - 2..D], + Tensor::::from_ints([D - 1, D - 2], &device), + ) + .unsqueeze_dim(1) + .repeat_dim(1, lhs_coordinates.shape().dims[1]); + let rhs_coordinates = lhs_coordinates.clone().gather(0, swizzle); + + let row_indices = flatten_coordinates::(lhs_coordinates, shape.clone(), &device); + + shape.dims.swap(D - 1, D - 2); + let col_indices = flatten_coordinates::(rhs_coordinates, shape.clone(), &device); + + let row_indices = row_indices.transpose().repeat_dim(1, lhs_dims[D - 1]); + let col_indices = col_indices.transpose().repeat_dim(1, rhs_dims[D - 1]); + + let lhs = lhs.gather(0, row_indices); + let rhs = rhs.gather(0, col_indices); + + let dotted = lhs.mul(rhs).sum_dim(1).squeeze(1); + + SparseCOOTensor { + coordinates: sparse.coordinates, + values: Some(dotted), + shape: sparse.shape, + device, + } } fn float_coalesce_sum( tensor: >::Primitive, ) -> >::Primitive { - todo!() + if tensor.coordinates.as_ref().map(|c| c.shape().dims[1] <= 1) == Some(true) { + return tensor; + } + let original_shape = tensor.shape.clone(); + + if tensor.coordinates.is_none() && tensor.values.is_none() { + return SparseCOOTensor { + coordinates: None, + values: None, + shape: original_shape, + device: tensor.device, + }; + } + + let coordinates = tensor + .coordinates + .expect("Mismatch between coordinates and values"); + let values = tensor + .values + .expect("Mismatch between coordinates and values"); + let device = tensor.device; + let nnz = coordinates.shape().dims[1]; + + let coordinates = + flatten_coordinates::(coordinates, original_shape.clone(), &device); + let _flat_shape = Shape::new([original_shape.num_elements()]); + + let (coordinates, indices) = coordinates.sort_with_indices(1); + let values = values.select(0, indices.squeeze(0)); + let range = Tensor::::arange(0..nnz as i64, &device).unsqueeze::<2>(); + + // Get the diff of coordinates, diff[i] = coordinates[i]-coordinates[i-1] + let left_slice = coordinates.clone().slice([0..1, 0..nnz - 1]); + let right_slice = coordinates.clone().slice([0..1, 1..nnz]); + let diff = right_slice - left_slice; + let ones = Tensor::::ones(Shape::new([1, 1]), &device); + let diff = Tensor::cat(vec![ones, diff], 1); + + // TODO this all would be way cleaner with cumsum/max, but that is waiting on a pull request as of writing + // inspiration could be taken from pytorch_scatter for better implementations + let unique_mask = diff.not_equal_elem(0); + let unique_indices = unique_mask.clone().nonzero().remove(1); + let steps = Tensor::cat( + vec![unique_indices.clone(), Tensor::from_data([nnz], &device)], + 0, + ); + let unique = steps.shape().dims[0]; + let steps = steps + .clone() + .slice([1..unique]) + .sub(steps.slice([0..unique - 1])) + .max() + // .sub_scalar(1) + .into_scalar() + .elem::(); + + let mut scatter_indices = range.mul(unique_mask.int()); + + for _ in 0..steps { + scatter_indices = scatter_indices + .clone() + .slice([0..1, 1..nnz]) + 
.max_pair(scatter_indices.slice([0..1, 0..nnz - 1])); + scatter_indices = Tensor::cat( + vec![Tensor::zeros(Shape::new([1, 1]), &device), scatter_indices], + 1, + ); + } + + // Scatter/Gather everything into place + let zeroed = Tensor::::zeros(Shape::new([nnz]), &device); + let values = zeroed.scatter(0, scatter_indices.squeeze(0), values); + let values = values.gather(0, unique_indices.clone()); + let coordinates = coordinates.gather(1, unique_indices.unsqueeze::<2>()); + let coordinates = unflatten_coordinates(coordinates, original_shape.clone()); + + let coordinates = Some(coordinates); + let values = Some(values); + + // reshape back into the original shape and send it! + SparseCOOTensor { + coordinates, + values, + shape: original_shape, + device, + } } fn float_remove_zeros( tensor: >::Primitive, ) -> >::Primitive { + if tensor.coordinates.is_none() && tensor.values.is_none() { + return tensor; + } + + let _coordinates = tensor + .coordinates + .expect("Mismatch between coordinates and values"); + let _values = tensor + .values + .expect("Mismatch between coordinates and values"); + let _device = tensor.device; + let _shape = tensor.shape; + + // let zeros = tensor.values.map(|values| values.equal_elem(0).nonzero()); todo!() } - fn float_nonzero( + fn float_number_nonzero( tensor: >::Primitive, ) -> usize { - todo!() + match tensor.coordinates { + Some(coordinates) => coordinates.shape().dims[1], + None => 0, + } } fn float_density( sparse: >::Primitive, ) -> f32 { - todo!() + match sparse.coordinates { + Some(coordinates) => { + coordinates.shape().dims[1] as f32 / sparse.shape.num_elements() as f32 + } + None => 0.0, + } } fn float_slice( @@ -72,20 +377,25 @@ impl SparseFloatOps for R { fn float_device( tensor: &>::Primitive, ) -> burn_tensor::Device { - todo!() + tensor.device.clone() } fn float_to_device( tensor: >::Primitive, device: &burn_tensor::Device, ) -> >::Primitive { - todo!() + SparseCOOTensor { + coordinates: tensor.coordinates.map(|t| t.to_device(device)), + values: tensor.values.map(|t| t.to_device(device)), + shape: tensor.shape, + device: device.clone(), + } } fn float_shape( tensor: &>::Primitive, ) -> burn_tensor::Shape { - todo!() + tensor.shape.clone() } fn float_into_data( @@ -210,228 +520,197 @@ impl SparseFloatOps for R { lhs: >::Primitive, rhs: >::Primitive, ) -> >::Primitive { - todo!() - } - - fn float_add_dense( - lhs: >::Primitive, - rhs: burn_tensor::ops::FloatTensor, - ) -> burn_tensor::ops::FloatTensor { - todo!() - } - - fn float_add_scalar( - lhs: >::Primitive, - rhs: burn_tensor::ops::FloatElem, - ) -> >::Primitive { - todo!() + let SparseCOOTensor { + coordinates: lhs_coordinates, + values: lhs_values, + shape: lhs_shape, + device: lhs_device, + } = lhs; + let (Some(lhs_coordinates), Some(lhs_values)) = (lhs_coordinates, lhs_values) else { + return rhs; + }; + + let SparseCOOTensor { + coordinates: rhs_coordinates, + values: rhs_values, + shape: rhs_shape, + device: rhs_device, + } = rhs; + let (Some(rhs_coordinates), Some(rhs_values)) = (rhs_coordinates, rhs_values) else { + return SparseCOOTensor { + coordinates: Some(lhs_coordinates), + values: Some(lhs_values), + shape: lhs_shape, + device: lhs_device, + }; + }; + + assert_eq!(lhs_shape, rhs_shape); + assert_eq!(lhs_device, rhs_device); + + let coordinates = Some(Tensor::cat(vec![lhs_coordinates, rhs_coordinates], 1)); + let values = Some(Tensor::cat(vec![lhs_values, rhs_values], 0)); + let shape = lhs_shape; + let device = lhs_device; + + let result = SparseCOOTensor { + coordinates, + 
values, + shape, + device, + }; + + Self::float_coalesce_sum(result) } fn float_sub( lhs: >::Primitive, rhs: >::Primitive, ) -> >::Primitive { - todo!() - } - - fn float_sub_dense( - lhs: >::Primitive, - rhs: burn_tensor::ops::FloatTensor, - ) -> burn_tensor::ops::FloatTensor { - todo!() - } - - fn float_sub_scalar( - lhs: >::Primitive, - rhs: burn_tensor::ops::FloatElem, - ) -> >::Primitive { - todo!() + Self::float_add( + lhs, + Self::float_mul_scalar(rhs, FloatElem::::from_elem(-1.0)), + ) } fn float_mul( lhs: >::Primitive, rhs: >::Primitive, ) -> >::Primitive { - todo!() - } - - fn float_mul_dense( - lhs: >::Primitive, - rhs: burn_tensor::ops::FloatTensor, - ) -> burn_tensor::ops::FloatTensor { - todo!() + panic!("float_mul is unsupported until scatter supports multiplication based reduction"); } fn float_mul_scalar( - lhs: >::Primitive, + mut lhs: >::Primitive, rhs: burn_tensor::ops::FloatElem, ) -> >::Primitive { - todo!() + lhs.values = lhs.values.map(|values| values.mul_scalar(rhs)); + lhs } fn float_div( lhs: >::Primitive, rhs: >::Primitive, ) -> >::Primitive { - todo!() - } - - fn float_div_dense( - lhs: >::Primitive, - rhs: burn_tensor::ops::FloatTensor, - ) -> burn_tensor::ops::FloatTensor { - todo!() + panic!("float_div is unsupported until scatter supports multiplication based reduction"); } fn float_div_scalar( - lhs: >::Primitive, + mut lhs: >::Primitive, rhs: burn_tensor::ops::FloatElem, ) -> >::Primitive { - todo!() + lhs.values = lhs.values.map(|values| values.div_scalar(rhs)); + lhs } fn float_max( tensor: >::Primitive, ) -> >::Primitive { - todo!() + panic!("max is unsupported for COO until scatter supports max reduction"); } fn float_max_dim( tensor: >::Primitive, dim: usize, ) -> >::Primitive { - todo!() + panic!("max_dim is unsupported for COO until scatter supports max reduction"); } fn float_min( tensor: >::Primitive, ) -> >::Primitive { - todo!() + panic!("min is unsupported for COO until scatter supports min reduction"); } fn float_min_dim( tensor: >::Primitive, dim: usize, ) -> >::Primitive { - todo!() - } - - fn float_greater( - lhs: >::Primitive, - rhs: >::Primitive, - ) -> >::Primitive { - todo!() - } - - fn float_greater_elem( - lhs: >::Primitive, - rhs: burn_tensor::ops::FloatElem, - ) -> >::Primitive { - todo!() - } - - fn float_greater_equal( - lhs: >::Primitive, - rhs: >::Primitive, - ) -> >::Primitive { - todo!() - } - - fn float_greater_equal_elem( - lhs: >::Primitive, - rhs: burn_tensor::ops::FloatElem, - ) -> >::Primitive { - todo!() - } - - fn float_lower( - lhs: >::Primitive, - rhs: >::Primitive, - ) -> >::Primitive { - todo!() - } - - fn float_lower_elem( - lhs: >::Primitive, - rhs: burn_tensor::ops::FloatElem, - ) -> >::Primitive { - todo!() - } - - fn float_lower_equal( - lhs: >::Primitive, - rhs: >::Primitive, - ) -> >::Primitive { - todo!() - } - - fn float_lower_equal_elem( - lhs: >::Primitive, - rhs: burn_tensor::ops::FloatElem, - ) -> >::Primitive { - todo!() + panic!("min_dim is unsupported for COO until scatter supports min reduction"); } fn float_abs( - tensor: >::Primitive, + mut tensor: >::Primitive, ) -> >::Primitive { - todo!() + tensor.values = tensor.values.map(|values| values.abs()); + tensor } fn float_sign( - tensor: >::Primitive, + mut tensor: >::Primitive, ) -> >::Primitive { - todo!() + tensor.values = tensor.values.map(|values| values.sign()); + tensor } fn float_powf( lhs: >::Primitive, rhs: >::Primitive, ) -> >::Primitive { - todo!() + panic!("float_powf is unsupported for COO until scatter supports other reduction 
methods"); } fn float_powi( lhs: >::Primitive, rhs: >::Primitive, ) -> >::Primitive { - todo!() + panic!("float_powi is unsupported for COO until scatter supports other reduction methods"); } fn float_powf_scalar( - lhs: >::Primitive, + mut lhs: >::Primitive, rhs: burn_tensor::ops::FloatElem, ) -> >::Primitive { - todo!() + lhs.values = lhs.values.map(|values| values.powf_scalar(rhs)); + lhs } fn float_powi_scalar( - lhs: >::Primitive, + mut lhs: >::Primitive, rhs: burn_tensor::ops::FloatElem, ) -> >::Primitive { - todo!() + lhs.values = lhs.values.map(|values| values.powi_scalar(rhs)); + lhs } fn float_clamp( - tensor: >::Primitive, + mut tensor: >::Primitive, min: burn_tensor::ops::FloatElem, max: burn_tensor::ops::FloatElem, ) -> >::Primitive { - todo!() + tensor.values = tensor.values.map(|values| values.clamp(min, max)); + if min.to_f64() == 0f64 || max.to_f64() == 0f64 { + // Clamp can zero elements if a boundary is zero + Self::float_remove_zeros(tensor) + } else { + tensor + } } fn float_clamp_min( - tensor: >::Primitive, + mut tensor: >::Primitive, min: burn_tensor::ops::FloatElem, ) -> >::Primitive { - todo!() + tensor.values = tensor.values.map(|values| values.clamp_min(min)); + if min.to_f64() == 0f64 { + // min can zero elements if boundary is 0 + Self::float_remove_zeros(tensor) + } else { + tensor + } } fn float_clamp_max( - tensor: >::Primitive, + mut tensor: >::Primitive, max: burn_tensor::ops::FloatElem, ) -> >::Primitive { - todo!() + tensor.values = tensor.values.map(|values| values.clamp_max(max)); + if max.to_f64() == 0f64 { + // max can zero elements if boundary is 0 + Self::float_remove_zeros(tensor) + } else { + tensor + } } fn float_select( @@ -439,7 +718,40 @@ impl SparseFloatOps for R { dim: usize, indices: burn_tensor::ops::IntTensor, ) -> >::Primitive { - todo!() + if tensor.coordinates.is_none() && tensor.values.is_none() { + return tensor; + } + + let coordinates = tensor + .coordinates + .expect("Mismatch between coordinates and values"); + let values = tensor + .values + .expect("Mismatch between coordinates and values"); + let device = tensor.device; + let mut shape = tensor.shape; + let indices = Tensor::::new(indices); + + let nnz = coordinates.shape().dims[1]; + let dim_coords = coordinates + .clone() + .slice([dim..dim + 1, 0..nnz]) + .squeeze::<1>(0); + let indices = indices.select(0, dim_coords); + let indices_len = indices.shape().num_elements(); + let coordinates = coordinates.slice_assign( + [dim..dim + 1, 0..nnz], + indices.unsqueeze::<2>().repeat_dim(1, D), + ); + + shape.dims[dim] = indices_len; + + SparseCOOTensor { + coordinates: Some(coordinates), + values: Some(values), + shape, + device, + } } fn float_select_assign( @@ -470,21 +782,18 @@ impl SparseFloatOps for R { fn float_sum( tensor: >::Primitive, - ) -> >::Primitive { - todo!() + ) -> >::Primitive { + tensor + .values + .map(|values| Self::float_to_sparse(values.sum().into_primitive().tensor())) + .unwrap_or(Self::float_empty(Shape::new([1]), &tensor.device)) } fn float_sum_dim( tensor: >::Primitive, dim: usize, ) -> >::Primitive { - todo!() - } - - fn float_prod( - tensor: >::Primitive, - ) -> >::Primitive { - todo!() + panic!("float_sum_dim unsupported for COO"); } fn float_prod_dim( @@ -496,8 +805,12 @@ impl SparseFloatOps for R { fn float_mean( tensor: >::Primitive, - ) -> >::Primitive { - todo!() + ) -> >::Primitive { + let num_elems = tensor.shape.num_elements(); + Self::float_div_scalar( + Self::float_sum(tensor), + ElementConversion::elem(num_elems as f32), + ) } fn 
float_mean_dim( @@ -507,30 +820,18 @@ impl SparseFloatOps for R { todo!() } - fn float_equal_elem( - lhs: >::Primitive, - rhs: burn_tensor::ops::FloatElem, - ) -> >::Primitive { - todo!() - } - - fn float_not_equal_elem( - lhs: >::Primitive, - rhs: burn_tensor::ops::FloatElem, - ) -> >::Primitive { - todo!() - } - fn float_remainder_scalar( - lhs: >::Primitive, + mut lhs: >::Primitive, rhs: burn_tensor::ops::FloatElem, ) -> >::Primitive { - todo!() + lhs.values = lhs.values.map(|values| values.remainder_scalar(rhs)); + lhs } fn float_neg( - tensor: >::Primitive, + mut tensor: >::Primitive, ) -> >::Primitive { - todo!() + tensor.values = tensor.values.map(|values| values.neg()); + tensor } } diff --git a/crates/burn-tensor/src/tensor/api/base.rs b/crates/burn-tensor/src/tensor/api/base.rs index 2c90cc7f69..7cbe31e5ca 100644 --- a/crates/burn-tensor/src/tensor/api/base.rs +++ b/crates/burn-tensor/src/tensor/api/base.rs @@ -763,7 +763,7 @@ where /// /// If the two tensors don't have the same shape. pub fn equal(self, other: Self) -> Tensor { - check!(TensorCheck::binary_ops_ew("Equal", &self, &other)); + // check!(TensorCheck::binary_ops_ew("Equal", &self, &other)); K::equal(self.primitive, other.primitive) } @@ -773,7 +773,7 @@ where /// /// If the two tensors don't have the same shape. pub fn not_equal(self, other: Self) -> Tensor { - check!(TensorCheck::binary_ops_ew("NotEqual", &self, &other)); + // check!(TensorCheck::binary_ops_ew("NotEqual", &self, &other)); K::not_equal(self.primitive, other.primitive) } @@ -783,7 +783,7 @@ where /// /// If all tensors don't have the same shape. pub fn cat(tensors: Vec, dim: usize) -> Self { - check!(TensorCheck::cat(&tensors, dim)); + // check!(TensorCheck::cat(&tensors, dim)); Self::new(K::cat( tensors.into_iter().map(|vector| vector.primitive).collect(), @@ -801,7 +801,7 @@ where tensors: Vec>, dim: usize, ) -> Tensor { - check!(TensorCheck::stack(&tensors, dim)); + // check!(TensorCheck::stack(&tensors, dim)); let tensors = tensors.into_iter().map(|t| t.unsqueeze_dim(dim)).collect(); Tensor::::cat(tensors, dim) } @@ -832,7 +832,7 @@ where /// A new tensor with the given dimension narrowed to the given range. 
pub fn narrow(self, dim: usize, start: usize, length: usize) -> Self { check!(TensorCheck::dim_ops::("narrow", dim)); - check!(TensorCheck::narrow(&self, dim, start, length)); + // check!(TensorCheck::narrow(&self, dim, start, length)); Self::new(narrow::(self.primitive, dim, start, length)) } diff --git a/crates/burn-tensor/src/tensor/api/repr.rs b/crates/burn-tensor/src/tensor/api/repr.rs index 9371b19502..492afd3f91 100644 --- a/crates/burn-tensor/src/tensor/api/repr.rs +++ b/crates/burn-tensor/src/tensor/api/repr.rs @@ -15,12 +15,7 @@ pub trait ChangeRepr>: TensorRepr { pub trait SparseRepr: Clone + core::fmt::Debug + SparseTensorOps { type Primitive, const D: usize>: Clone + core::fmt::Debug + Send; - // type FloatTensorPrimitive: Clone + core::fmt::Debug + Send = - // Self::Primitive; - // type IntTensorPrimitive: Clone + core::fmt::Debug + Send = - // Self::Primitive; - // type BoolTensorPrimitive: Clone + core::fmt::Debug + Send = - // Self::Primitive; + fn name() -> &'static str; } diff --git a/crates/burn-tensor/src/tensor/api/sparse.rs b/crates/burn-tensor/src/tensor/api/sparse.rs index 20754848b9..53b12694fc 100644 --- a/crates/burn-tensor/src/tensor/api/sparse.rs +++ b/crates/burn-tensor/src/tensor/api/sparse.rs @@ -238,14 +238,14 @@ impl> BasicOps> for Bool { lhs: Self::Primitive, rhs: Self::Primitive, ) -> Tensor> { - Tensor::new(R::bool_equal(lhs, rhs)) + panic!("Non-zero preserving operations are not supported for sparse tensors"); } fn not_equal( lhs: Self::Primitive, rhs: Self::Primitive, ) -> Tensor> { - Tensor::new(R::bool_not_equal(lhs, rhs)) + panic!("Non-zero preserving operations are not supported for sparse tensors"); } fn any(tensor: Self::Primitive) -> Tensor> { @@ -374,6 +374,7 @@ impl> BasicOps> for Int { lhs: Self::Primitive, rhs: Self::Primitive, ) -> Tensor> { + panic!("Non-zero preserving operations are not supported for sparse tensors"); Tensor::new(R::int_equal(lhs, rhs)) } @@ -381,6 +382,7 @@ impl> BasicOps> for Int { lhs: Self::Primitive, rhs: Self::Primitive, ) -> Tensor> { + panic!("Non-zero preserving operations are not supported for sparse tensors"); Tensor::new(R::int_not_equal(lhs, rhs)) } diff --git a/crates/burn-tensor/src/tensor/ops/sparse_tensor.rs b/crates/burn-tensor/src/tensor/ops/sparse_tensor.rs index 3b69c4d97b..26c5845b28 100644 --- a/crates/burn-tensor/src/tensor/ops/sparse_tensor.rs +++ b/crates/burn-tensor/src/tensor/ops/sparse_tensor.rs @@ -15,7 +15,8 @@ pub trait SparseFloatOps, B: Backend> { fn float_empty(shape: Shape, device: &Device) -> R::Primitive; - fn float_to_dense(sparse: R::Primitive) -> R::Primitive; + fn float_to_dense(sparse: R::Primitive) + -> B::FloatTensorPrimitive; fn float_spmm( lhs: R::Primitive, @@ -34,7 +35,7 @@ pub trait SparseFloatOps, B: Backend> { fn float_remove_zeros(tensor: R::Primitive) -> R::Primitive; - fn float_nonzero(tensor: R::Primitive) -> usize; + fn float_number_nonzero(tensor: R::Primitive) -> usize; fn float_density(sparse: R::Primitive) -> f32; @@ -202,36 +203,6 @@ pub trait SparseFloatOps, B: Backend> { rhs: R::Primitive, ) -> R::Primitive; - /// Adds a sparse and dense tensor together. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side tensor. - /// - /// # Returns - /// - /// The result of adding the two tensors together. - fn float_add_dense( - lhs: R::Primitive, - rhs: FloatTensor, - ) -> FloatTensor; - - /// Adds a scalar to a tensor. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. 
- /// * `rhs` - The right hand side scalar. - /// - /// # Returns - /// - /// The result of adding the scalar to the tensor. - fn float_add_scalar( - lhs: R::Primitive, - rhs: FloatElem, - ) -> R::Primitive; - /// Subtracts two tensors. /// /// # Arguments @@ -247,36 +218,6 @@ pub trait SparseFloatOps, B: Backend> { rhs: R::Primitive, ) -> R::Primitive; - /// Subtracts a dense from a sparse tensor. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor (sparse). - /// * `rhs` - The right hand side tensor (dense). - /// - /// # Returns - /// - /// The result of subtracting the two tensors. - fn float_sub_dense( - lhs: R::Primitive, - rhs: FloatTensor, - ) -> FloatTensor; - - /// Subtracts a scalar from a tensor. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side scalar. - /// - /// # Returns - /// - /// The result of subtracting the scalar from the tensor. - fn float_sub_scalar( - lhs: R::Primitive, - rhs: FloatElem, - ) -> R::Primitive; - /// Multiplies two sparse tensors together. /// /// # Arguments @@ -292,21 +233,6 @@ pub trait SparseFloatOps, B: Backend> { rhs: R::Primitive, ) -> R::Primitive; - /// Multiplies a sparse and dense tensor together. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side tensor. - /// - /// # Returns - /// - /// The result of multiplying the two tensors together. - fn float_mul_dense( - lhs: R::Primitive, - rhs: FloatTensor, - ) -> FloatTensor; - /// Multiplies a scalar to a tensor. /// /// # Arguments @@ -337,21 +263,6 @@ pub trait SparseFloatOps, B: Backend> { rhs: R::Primitive, ) -> R::Primitive; - /// Divides a sparse and dense tensor. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side tensor. - /// - /// # Returns - /// - /// The result of dividing the two tensors. - fn float_div_dense( - lhs: R::Primitive, - rhs: FloatTensor, - ) -> FloatTensor; - /// Divides a tensor by a scalar. 
/// /// # Arguments @@ -381,46 +292,6 @@ pub trait SparseFloatOps, B: Backend> { dim: usize, ) -> R::Primitive; - fn float_greater( - lhs: R::Primitive, - rhs: R::Primitive, - ) -> R::Primitive; - - fn float_greater_elem( - lhs: R::Primitive, - rhs: FloatElem, - ) -> R::Primitive; - - fn float_greater_equal( - lhs: R::Primitive, - rhs: R::Primitive, - ) -> R::Primitive; - - fn float_greater_equal_elem( - lhs: R::Primitive, - rhs: FloatElem, - ) -> R::Primitive; - - fn float_lower( - lhs: R::Primitive, - rhs: R::Primitive, - ) -> R::Primitive; - - fn float_lower_elem( - lhs: R::Primitive, - rhs: FloatElem, - ) -> R::Primitive; - - fn float_lower_equal( - lhs: R::Primitive, - rhs: R::Primitive, - ) -> R::Primitive; - - fn float_lower_equal_elem( - lhs: R::Primitive, - rhs: FloatElem, - ) -> R::Primitive; - fn float_abs(tensor: R::Primitive) -> R::Primitive; fn float_sign(tensor: R::Primitive) -> R::Primitive; @@ -486,37 +357,25 @@ pub trait SparseFloatOps, B: Backend> { values: R::Primitive, ) -> R::Primitive; - fn float_sum(tensor: R::Primitive) -> R::Primitive; + fn float_sum(tensor: R::Primitive) -> R::Primitive; fn float_sum_dim( tensor: R::Primitive, dim: usize, ) -> R::Primitive; - fn float_prod(tensor: R::Primitive) -> R::Primitive; - fn float_prod_dim( tensor: R::Primitive, dim: usize, ) -> R::Primitive; - fn float_mean(tensor: R::Primitive) -> R::Primitive; + fn float_mean(tensor: R::Primitive) -> R::Primitive; fn float_mean_dim( tensor: R::Primitive, dim: usize, ) -> R::Primitive; - fn float_equal_elem( - lhs: R::Primitive, - rhs: FloatElem, - ) -> R::Primitive; - - fn float_not_equal_elem( - lhs: R::Primitive, - rhs: FloatElem, - ) -> R::Primitive; - fn float_remainder_scalar( lhs: R::Primitive, rhs: FloatElem, From 937480ecfe6fcfb72fdf713111e2ab7639bad6aa Mon Sep 17 00:00:00 2001 From: mcarthur Date: Mon, 19 Aug 2024 13:36:32 +0000 Subject: [PATCH 27/38] most functions transferred --- crates/burn-sparse/src/decorator/coo_bool.rs | 6 + crates/burn-sparse/src/decorator/coo_float.rs | 235 ++++++++++++++++-- .../src/tensor/ops/sparse_tensor.rs | 2 + 3 files changed, 220 insertions(+), 23 deletions(-) diff --git a/crates/burn-sparse/src/decorator/coo_bool.rs b/crates/burn-sparse/src/decorator/coo_bool.rs index f8b8bd3b76..93c7561100 100644 --- a/crates/burn-sparse/src/decorator/coo_bool.rs +++ b/crates/burn-sparse/src/decorator/coo_bool.rs @@ -8,6 +8,12 @@ use super::coo::COO; type R = COO; impl SparseBoolOps for R { + fn bool_to_sparse( + dense: ::BoolTensorPrimitive, + ) -> >::Primitive { + todo!() + } + fn bool_empty( shape: burn_tensor::Shape, device: &burn_tensor::Device, diff --git a/crates/burn-sparse/src/decorator/coo_float.rs b/crates/burn-sparse/src/decorator/coo_float.rs index d364381ee1..f9c6958c88 100644 --- a/crates/burn-sparse/src/decorator/coo_float.rs +++ b/crates/burn-sparse/src/decorator/coo_float.rs @@ -1,7 +1,7 @@ use burn_tensor::cast::ToElement; -use burn_tensor::ops::FloatElem; +use burn_tensor::ops::{FloatElem, SparseBoolOps}; use burn_tensor::{backend::Backend, ops::SparseFloatOps, SparseRepr, Tensor}; -use burn_tensor::{Bool, ElementConversion, Float, Shape, TensorData, TensorPrimitive}; +use burn_tensor::{Bool, ElementConversion, Float, Shape, Sparse, TensorData, TensorPrimitive}; use burn_tensor::{Device, Int}; use super::coo::{flatten_coordinates, unflatten_coordinates, SparseCOOTensor, COO}; @@ -60,7 +60,12 @@ impl SparseFloatOps for R { shape: burn_tensor::Shape, device: &burn_tensor::Device, ) -> >::Primitive { - todo!() + SparseCOOTensor { + 
coordinates: None, + values: None, + shape, + device: device.clone(), + } } fn float_to_dense( @@ -334,16 +339,15 @@ impl SparseFloatOps for R { return tensor; } - let _coordinates = tensor + let coordinates = tensor .coordinates .expect("Mismatch between coordinates and values"); - let _values = tensor + let values = tensor .values .expect("Mismatch between coordinates and values"); - let _device = tensor.device; - let _shape = tensor.shape; + let device = tensor.device; + let shape = tensor.shape; - // let zeros = tensor.values.map(|values| values.equal_elem(0).nonzero()); todo!() } @@ -413,15 +417,47 @@ impl SparseFloatOps for R { fn float_reshape( tensor: >::Primitive, - shape: burn_tensor::Shape, + out_shape: burn_tensor::Shape, ) -> >::Primitive { - todo!() + if tensor.coordinates.is_none() && tensor.values.is_none() { + return SparseCOOTensor { + coordinates: None, + values: None, + shape: out_shape, + device: tensor.device, + }; + } + + let coordinates = tensor + .coordinates + .expect("Mismatch between coordinates and values"); + let values = tensor + .values + .expect("Mismatch between coordinates and values"); + let shape = tensor.shape; + let device = tensor.device; + + // Flatten the coordinates + let flat_coordinates = flatten_coordinates::(coordinates, shape, &device); + + // Unflatten the coordinates to the new shape + let new_coordinates = unflatten_coordinates(flat_coordinates, out_shape.clone()); + + SparseCOOTensor { + coordinates: Some(new_coordinates), + values: Some(values), + shape: out_shape, + device, + } } fn float_transpose( tensor: >::Primitive, ) -> >::Primitive { - todo!() + let d = tensor.shape.dims.len(); + let mut axes: Vec = (0..d).collect(); + axes.swap(d - 1, d - 2); + Self::float_permute(tensor, &axes) } fn float_swap_dims( @@ -429,29 +465,117 @@ impl SparseFloatOps for R { dim1: usize, dim2: usize, ) -> >::Primitive { - todo!() + let d = tensor.shape.dims.len(); + let mut axes: Vec = (0..d).collect(); + axes.swap(dim1, dim2); + Self::float_permute(tensor, &axes) } fn float_permute( tensor: >::Primitive, axes: &[usize], ) -> >::Primitive { - todo!() + let SparseCOOTensor { + coordinates, + values, + mut shape, + device, + } = tensor; + + for (i, &j) in (0..D).zip(axes).filter(|(i, j)| i < j) { + shape.dims.swap(i, j); + } + + let axes = Tensor::from(axes); + let coordinates = coordinates.map(|coordinates| coordinates.select(0, axes)); + + SparseCOOTensor { + coordinates, + values, + shape, + device, + } } fn float_flip( tensor: >::Primitive, axes: &[usize], ) -> >::Primitive { - todo!() + let SparseCOOTensor { + coordinates, + values, + shape, + device, + } = tensor; + + let (Some(coordinates), Some(values)) = (coordinates, values) else { + // All zeros, exit early + return SparseCOOTensor { + coordinates: None, + values: None, + shape, + device, + }; + }; + + let nnz = coordinates.shape().dims[1]; + + let mut mask = [0; D]; + for &axis in axes { + mask[axis] = 1; + } + let mask: Tensor = Tensor::<_, 1, _>::from_ints(mask, &device) + .unsqueeze_dim(1) + .repeat_dim(1, nnz) + .bool(); + + let flipped: Tensor = Tensor::<_, 1, _>::from_ints(shape.dims, &device) + .unsqueeze_dim(1) + .repeat_dim(1, nnz) + .sub(coordinates.clone()) + .sub_scalar(1); + + let coordinates = coordinates.mask_where(mask, flipped); + + let coordinates = Some(coordinates); + let values = Some(values); + + SparseCOOTensor { + coordinates, + values, + shape, + device, + } } fn float_slice_assign( tensor: >::Primitive, ranges: [std::ops::Range; D2], - value: >::Primitive, + mut 
value: >::Primitive, ) -> >::Primitive { - todo!() + let value_nnz = value + .coordinates + .as_ref() + .map(|coords| coords.shape().dims[1]) + .unwrap_or(0); + + let mut ranges = Vec::from(ranges); + ranges.extend(tensor.shape.dims[ranges.len()..D1].iter().map(|&l| 0..l)); + let ranges: [core::ops::Range; D1] = ranges.try_into().expect("D2 must be <= D1"); + + let shape = tensor.shape.clone(); + let sliced = Self::float_reshape( + Self::float_slice(tensor.clone(), ranges.clone()), + shape.clone(), + ); + let tensor = Self::float_sub(tensor, sliced); + let offset = Tensor::::from_ints(ranges.map(|r| r.start), &tensor.device); + let offset = offset.unsqueeze_dim::<2>(1).repeat_dim(1, value_nnz); + + value.shape = shape; + value.coordinates = value.coordinates.map(|coords| coords + offset); + + Self::float_add(tensor, value) } fn float_repeat_dim( @@ -459,7 +583,53 @@ impl SparseFloatOps for R { dim: usize, times: usize, ) -> >::Primitive { - todo!() + let SparseCOOTensor { + coordinates, + values, + shape, + device, + } = tensor; + + let mut out_shape = shape.clone(); + out_shape.dims[dim] *= times; + + let (Some(coordinates), Some(values)) = (coordinates, values) else { + // All zeros, exit early + return SparseCOOTensor { + coordinates: None, + values: None, + shape, + device, + }; + }; + + let device = coordinates.device(); + let nnz = coordinates.shape().dims[1]; + + let values = values.repeat_dim(0, times); + + let coordinates_mask: Tensor = Tensor::zeros(coordinates.shape(), &device); + let ones: Tensor = Tensor::ones(Shape::new([1, nnz]), &device); + let coordinates_mask = coordinates_mask.slice_assign([dim..dim + 1, 0..nnz], ones); + let coordinates = Tensor::cat( + (0..times) + .map(|n| { + coordinates.clone() + + coordinates_mask.clone() * (n as i32) * (shape.dims[dim] as i32) + }) + .collect::>(), + 1, + ); + + let coordinates = Some(coordinates); + let values = Some(values); + + SparseCOOTensor { + coordinates, + values, + shape: out_shape, + device, + } } fn float_cat( @@ -486,27 +656,46 @@ impl SparseFloatOps for R { fn float_any( tensor: >::Primitive, ) -> >::Primitive { - todo!() + let SparseCOOTensor { + coordinates, + values: _, + shape: _, + device: _, + } = tensor; + let any = coordinates.is_some(); + let bool = Tensor::::from([any]).into_primitive(); + >::bool_to_sparse(bool) } fn float_any_dim( tensor: >::Primitive, dim: usize, ) -> >::Primitive { - todo!() + panic!("any_dim is unsupported for COO until scatter supports any-based reduction"); } fn float_all( tensor: >::Primitive, ) -> >::Primitive { - todo!() + let SparseCOOTensor { + coordinates, + values: _, + shape, + device: _, + } = tensor; + let all = match coordinates { + Some(coordinates) => shape.num_elements() == coordinates.shape().dims[1], + None => false, + }; + let bool = Tensor::::from([all]).into_primitive(); + >::bool_to_sparse(bool) } fn float_all_dim( tensor: >::Primitive, dim: usize, ) -> >::Primitive { - todo!() + panic!("all_dim is unsupported for COO until scatter supports all-based reduction"); } fn float_expand( @@ -800,7 +989,7 @@ impl SparseFloatOps for R { tensor: >::Primitive, dim: usize, ) -> >::Primitive { - todo!() + panic!("float_prod_dim is not supported for COO until scatter supports product reduction") } fn float_mean( @@ -817,7 +1006,7 @@ impl SparseFloatOps for R { tensor: >::Primitive, dim: usize, ) -> >::Primitive { - todo!() + panic!("float_mean_dim is not supported for COO until scatter supports mean reduction"); } fn float_remainder_scalar( diff --git 
a/crates/burn-tensor/src/tensor/ops/sparse_tensor.rs b/crates/burn-tensor/src/tensor/ops/sparse_tensor.rs index 26c5845b28..97436cb6ce 100644 --- a/crates/burn-tensor/src/tensor/ops/sparse_tensor.rs +++ b/crates/burn-tensor/src/tensor/ops/sparse_tensor.rs @@ -385,6 +385,8 @@ pub trait SparseFloatOps, B: Backend> { } pub trait SparseBoolOps, B: Backend> { + fn bool_to_sparse(dense: B::BoolTensorPrimitive) -> R::Primitive; + fn bool_empty(shape: Shape, device: &Device) -> R::Primitive; fn bool_shape(tensor: &R::Primitive) -> Shape; From bac85ac1d691260dff60928c53333552555cc13b Mon Sep 17 00:00:00 2001 From: mcarthur Date: Tue, 20 Aug 2024 05:36:20 +0000 Subject: [PATCH 28/38] Added use to mod --- crates/burn-sparse/src/decorator/mod.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/crates/burn-sparse/src/decorator/mod.rs b/crates/burn-sparse/src/decorator/mod.rs index 2343c84944..8f2af97c05 100644 --- a/crates/burn-sparse/src/decorator/mod.rs +++ b/crates/burn-sparse/src/decorator/mod.rs @@ -12,3 +12,7 @@ mod coo_int; // pub use backend::*; // pub use precision_bridge::*; // pub use representation::*; +pub use coo::*; +pub use coo_bool::*; +pub use coo_float::*; +pub use coo_int::*; From 1dac1e8013f4204f9251006bf49f3f594eedf0da Mon Sep 17 00:00:00 2001 From: McArthur-Alford Date: Wed, 21 Aug 2024 13:23:10 +1000 Subject: [PATCH 29/38] Some more functions, a little broken --- crates/burn-tensor/src/tensor/api/base.rs | 5 ++- crates/burn-tensor/src/tensor/api/check.rs | 36 +++++++++++++++++ crates/burn-tensor/src/tensor/api/mod.rs | 3 ++ .../src/tensor/api/sparse_float.rs | 40 +++++++++++++++++++ .../src/tensor/api/sparse_numeric.rs | 8 ++++ .../src/tensor/ops/sparse_tensor.rs | 4 +- 6 files changed, 92 insertions(+), 4 deletions(-) create mode 100644 crates/burn-tensor/src/tensor/api/sparse_float.rs create mode 100644 crates/burn-tensor/src/tensor/api/sparse_numeric.rs diff --git a/crates/burn-tensor/src/tensor/api/base.rs b/crates/burn-tensor/src/tensor/api/base.rs index 7cbe31e5ca..ba19195651 100644 --- a/crates/burn-tensor/src/tensor/api/base.rs +++ b/crates/burn-tensor/src/tensor/api/base.rs @@ -31,10 +31,11 @@ where pub(crate) primitive: K::Primitive, } -impl From for Tensor +impl From for Tensor where B: Backend, - K: BasicOps, + R: TensorRepr, + K: BasicOps, T: Into, { fn from(value: T) -> Self { diff --git a/crates/burn-tensor/src/tensor/api/check.rs b/crates/burn-tensor/src/tensor/api/check.rs index 754fab4c6c..0f68981f7a 100644 --- a/crates/burn-tensor/src/tensor/api/check.rs +++ b/crates/burn-tensor/src/tensor/api/check.rs @@ -1,4 +1,5 @@ use crate::{backend::Backend, BasicOps, Shape, Tensor}; +use crate::{Dense, Float, Sparse, SparseRepr, TensorRepr}; use alloc::format; use alloc::string::{String, ToString}; use alloc::vec; @@ -548,6 +549,41 @@ impl TensorCheck { check } + pub(crate) fn spmm, const D: usize>( + lhs: &Tensor>, + rhs: &Tensor, + ) -> Self { + let mut check = Self::Ok; + + check = check.binary_ops_device("Matmul", &lhs.device(), &rhs.device()); + + if D < 2 { + return check; + } + + let shape_lhs = lhs.shape(); + let shape_rhs = rhs.shape(); + + let dim_lhs = shape_lhs.dims[D - 1]; + let dim_rhs = shape_rhs.dims[D - 2]; + + if dim_lhs != dim_rhs { + check = check.register( + "Matmul", + TensorError::new(format!( + "The inner dimension of matmul should be the same, but got {dim_lhs} and \ + {dim_rhs}." 
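+                    // e.g. (hypothetical shapes) an lhs of [.., 3, 4] requires an rhs of [.., 4, n];
+                    // an rhs of [.., 3, n] would fail this inner-dimension check.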
+                ))
+                .details(format!(
+                    "Lhs shape {:?}, rhs shape {:?}.",
+                    shape_lhs.dims, shape_rhs.dims
+                )),
+            );
+        }
+
+        check
+    }
+
     pub(crate) fn stack<B: Backend, const D: usize, K: BasicOps<B>>(
         tensors: &[Tensor<B, D, K>],
         dim: usize,
diff --git a/crates/burn-tensor/src/tensor/api/mod.rs b/crates/burn-tensor/src/tensor/api/mod.rs
index aaef17e0aa..d38622c053 100644
--- a/crates/burn-tensor/src/tensor/api/mod.rs
+++ b/crates/burn-tensor/src/tensor/api/mod.rs
@@ -14,6 +14,8 @@ mod numeric;
 mod repr;
 mod sort;
 mod sparse;
+mod sparse_float;
+mod sparse_numeric;
 
 pub use argwhere::argwhere_data;
 pub use autodiff::*;
@@ -26,3 +28,4 @@ pub use numeric::*;
 pub use repr::*;
 pub use sort::{argsort, sort, sort_with_indices};
 pub use sparse::*;
+pub use sparse_numeric::*;
diff --git a/crates/burn-tensor/src/tensor/api/sparse_float.rs b/crates/burn-tensor/src/tensor/api/sparse_float.rs
new file mode 100644
index 0000000000..ad6143b673
--- /dev/null
+++ b/crates/burn-tensor/src/tensor/api/sparse_float.rs
@@ -0,0 +1,40 @@
+use crate::check;
+use crate::{
+    backend::Backend, check::TensorCheck, Dense, Float, Sparse, SparseRepr, Tensor, TensorKind,
+};
+
+impl<B, const D: usize, R> Tensor<B, D, Float, Sparse<B, R>>
+where
+    B: Backend,
+    R: SparseRepr<B>,
+{
+    /// Executes an operation on the tensor and modifies its value.
+    ///
+    /// # Notes
+    ///
+    /// This won't necessarily reuse the same tensor data/buffer, but it should if there is
+    /// no other reference pointing to the same tensor.
+    ///
+    /// Wrapping operations with `inplace` is not an optimization; it's mainly there if you
+    /// want to mutate a tensor by using owned operations. A plausible usage would be to
+    /// update the weights of a mutable model reference.
+    pub fn inplace<F: FnOnce(Self) -> Self>(&mut self, func: F) {
+        let mut tensor_owned = Tensor::empty([0; D], &self.device());
+        core::mem::swap(&mut tensor_owned, self);
+
+        let mut tensor_new = func(tensor_owned);
+        core::mem::swap(&mut tensor_new, self);
+    }
+
+    /// Applies the matrix multiplication operation.
+    ///
+    /// `C = AB`
+    ///
+    /// # Panics
+    ///
+    /// If the two tensors don't have a compatible shape.
+    pub fn spmm(self, rhs: Tensor<B, D, Float>) -> Self {
+        check!(TensorCheck::spmm(&self, &rhs));
+        Self::new(R::float_spmm(self.primitive, rhs.primitive))
+    }
+}
diff --git a/crates/burn-tensor/src/tensor/api/sparse_numeric.rs b/crates/burn-tensor/src/tensor/api/sparse_numeric.rs
new file mode 100644
index 0000000000..7fd4db61e2
--- /dev/null
+++ b/crates/burn-tensor/src/tensor/api/sparse_numeric.rs
@@ -0,0 +1,8 @@
+use crate::{backend::Backend, BasicOps, SparseRepr};
+
+/// Trait that lists all operations that can be applied to all numerical sparse tensors.
+///
+/// # Warnings
+///
+/// This is an internal trait; use the public API provided by [tensor struct](Tensor).
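+///
+/// A rough sketch of the intended use as a generic bound (illustrative only; the
+/// exact bounds may shift as the sparse API settles):
+///
+/// ```ignore
+/// fn nnz_upper_bound<B, const D: usize, R, K>(t: &Tensor<B, D, K, Sparse<B, R>>) -> usize
+/// where
+///     B: Backend,
+///     R: SparseRepr<B>,
+///     K: SparseNumeric<R, B>,
+/// {
+///     t.shape().num_elements() // a dense upper bound on the number of nonzeros
+/// }
+/// ```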
+pub trait SparseNumeric, B: Backend>: BasicOps {} diff --git a/crates/burn-tensor/src/tensor/ops/sparse_tensor.rs b/crates/burn-tensor/src/tensor/ops/sparse_tensor.rs index 97436cb6ce..6af44b3147 100644 --- a/crates/burn-tensor/src/tensor/ops/sparse_tensor.rs +++ b/crates/burn-tensor/src/tensor/ops/sparse_tensor.rs @@ -1,6 +1,6 @@ use super::{BoolTensor, FloatElem, FloatTensor, IntTensor, QuantizedTensor}; use crate::{ - backend::Backend, Bool, Device, Float, Int, Shape, SparseRepr, TensorData, TensorKind, + backend::Backend, Bool, Device, Float, Int, Shape, Sparse, SparseRepr, TensorData, TensorKind, }; use core::{future::Future, ops::Range}; @@ -20,7 +20,7 @@ pub trait SparseFloatOps, B: Backend> { fn float_spmm( lhs: R::Primitive, - rhs: B::FloatTensorPrimitive, + rhs: >>::DensePrimitive, ) -> B::FloatTensorPrimitive; fn float_sddmm( From 6150e8ac8e7c3a56aabfe3ef84b7719ba92a9962 Mon Sep 17 00:00:00 2001 From: McArthur-Alford Date: Thu, 22 Aug 2024 17:17:02 +1000 Subject: [PATCH 30/38] A huge overhaul, much nicer types and much less confusing, achieves the same result --- crates/burn-sparse/src/decorator/coo.rs | 8 +- crates/burn-sparse/src/decorator/coo_bool.rs | 94 ++-- crates/burn-sparse/src/decorator/coo_float.rs | 265 +++++----- crates/burn-sparse/src/decorator/coo_int.rs | 90 ++-- crates/burn-tensor/src/tensor/api/base.rs | 274 ++++++----- crates/burn-tensor/src/tensor/api/check.rs | 70 +-- crates/burn-tensor/src/tensor/api/chunk.rs | 22 +- crates/burn-tensor/src/tensor/api/kind.rs | 80 +--- crates/burn-tensor/src/tensor/api/mod.rs | 2 + crates/burn-tensor/src/tensor/api/narrow.rs | 16 +- crates/burn-tensor/src/tensor/api/repr.rs | 46 +- crates/burn-tensor/src/tensor/api/sparse.rs | 404 +++++++++------- .../src/tensor/api/sparse_float.rs | 20 +- .../src/tensor/api/sparse_numeric.rs | 4 +- crates/burn-tensor/src/tensor/api/storage.rs | 30 ++ .../src/tensor/ops/sparse_tensor.rs | 451 ++++++++++-------- 16 files changed, 997 insertions(+), 879 deletions(-) create mode 100644 crates/burn-tensor/src/tensor/api/storage.rs diff --git a/crates/burn-sparse/src/decorator/coo.rs b/crates/burn-sparse/src/decorator/coo.rs index ecbd482e71..b195a323b7 100644 --- a/crates/burn-sparse/src/decorator/coo.rs +++ b/crates/burn-sparse/src/decorator/coo.rs @@ -7,7 +7,7 @@ use burn_tensor::Float; use burn_tensor::Int; use burn_tensor::Shape; use burn_tensor::Sparse; -use burn_tensor::SparseRepr; +use burn_tensor::SparseStorage; use burn_tensor::Tensor; use burn_tensor::TensorData; use burn_tensor::TensorKind; @@ -16,15 +16,15 @@ use burn_tensor::TensorKind; pub struct COO; #[derive(Clone, Debug)] -pub struct SparseCOOTensor, const D: usize> { +pub struct SparseCOOTensor, const D: usize> { pub coordinates: Option>, pub values: Option>, pub shape: Shape, pub device: Device, } -impl SparseRepr for COO { - type Primitive, const D: usize> = SparseCOOTensor; +impl SparseStorage for COO { + type SparsePrimitive, const D: usize> = SparseCOOTensor; fn name() -> &'static str { "SparseCOO" diff --git a/crates/burn-sparse/src/decorator/coo_bool.rs b/crates/burn-sparse/src/decorator/coo_bool.rs index 93c7561100..519d1b0124 100644 --- a/crates/burn-sparse/src/decorator/coo_bool.rs +++ b/crates/burn-sparse/src/decorator/coo_bool.rs @@ -1,97 +1,95 @@ +use super::coo::COO; use burn_tensor::{ backend::Backend, ops::{SparseBoolOps, SparseTensorOps}, - SparseRepr, + SparseStorage, }; -use super::coo::COO; -type R = COO; - -impl SparseBoolOps for R { +impl SparseBoolOps for COO { fn bool_to_sparse( dense: 
::BoolTensorPrimitive, - ) -> >::Primitive { + ) -> >::SparsePrimitive { todo!() } fn bool_empty( shape: burn_tensor::Shape, device: &burn_tensor::Device, - ) -> >::Primitive { + ) -> >::SparsePrimitive { todo!() } fn bool_shape( - tensor: &>::Primitive, + tensor: &>::SparsePrimitive, ) -> burn_tensor::Shape { todo!() } fn bool_reshape( - tensor: >::Primitive, + tensor: >::SparsePrimitive, shape: burn_tensor::Shape, - ) -> >::Primitive { + ) -> >::SparsePrimitive { todo!() } fn bool_transpose( - tensor: >::Primitive, - ) -> >::Primitive { + tensor: >::SparsePrimitive, + ) -> >::SparsePrimitive { todo!() } fn bool_swap_dims( - tensor: >::Primitive, + tensor: >::SparsePrimitive, dim1: usize, dim2: usize, - ) -> >::Primitive { + ) -> >::SparsePrimitive { todo!() } fn bool_permute( - tensor: >::Primitive, + tensor: >::SparsePrimitive, axes: &[usize], - ) -> >::Primitive { + ) -> >::SparsePrimitive { todo!() } fn bool_flip( - tensor: >::Primitive, + tensor: >::SparsePrimitive, axes: &[usize], - ) -> >::Primitive { + ) -> >::SparsePrimitive { todo!() } fn bool_slice( - tensor: >::Primitive, + tensor: >::SparsePrimitive, indices: [std::ops::Range; D2], - ) -> >::Primitive { + ) -> >::SparsePrimitive { todo!() } fn bool_slice_assign( - tensor: >::Primitive, + tensor: >::SparsePrimitive, ranges: [std::ops::Range; D2], - value: >::Primitive, - ) -> >::Primitive { + value: >::SparsePrimitive, + ) -> >::SparsePrimitive { todo!() } fn bool_device( - tensor: &>::Primitive, + tensor: &>::SparsePrimitive, ) -> burn_tensor::Device { todo!() } fn bool_to_device( - tensor: >::Primitive, + tensor: >::SparsePrimitive, device: &burn_tensor::Device, - ) -> >::Primitive { + ) -> >::SparsePrimitive { todo!() } fn bool_into_data( - tensor: >::Primitive, + tensor: >::SparsePrimitive, ) -> impl std::future::Future + Send { async { todo!() } } @@ -99,69 +97,69 @@ impl SparseBoolOps for R { fn bool_from_data( data: burn_tensor::TensorData, device: &burn_tensor::Device, - ) -> >::Primitive { + ) -> >::SparsePrimitive { todo!() } fn bool_repeat_dim( - tensor: >::Primitive, + tensor: >::SparsePrimitive, dim: usize, times: usize, - ) -> >::Primitive { + ) -> >::SparsePrimitive { todo!() } fn bool_cat( - tensors: Vec<>::Primitive>, + tensors: Vec<>::SparsePrimitive>, dim: usize, - ) -> >::Primitive { + ) -> >::SparsePrimitive { todo!() } fn bool_equal( - lhs: >::Primitive, - rhs: >::Primitive, - ) -> >::Primitive { + lhs: >::SparsePrimitive, + rhs: >::SparsePrimitive, + ) -> >::SparsePrimitive { todo!() } fn bool_not_equal( - lhs: >::Primitive, - rhs: >::Primitive, - ) -> >::Primitive { + lhs: >::SparsePrimitive, + rhs: >::SparsePrimitive, + ) -> >::SparsePrimitive { todo!() } fn bool_any( - tensor: >::Primitive, - ) -> >::Primitive { + tensor: >::SparsePrimitive, + ) -> >::SparsePrimitive { todo!() } fn bool_any_dim( - tensor: >::Primitive, + tensor: >::SparsePrimitive, dim: usize, - ) -> >::Primitive { + ) -> >::SparsePrimitive { todo!() } fn bool_all( - tensor: >::Primitive, - ) -> >::Primitive { + tensor: >::SparsePrimitive, + ) -> >::SparsePrimitive { todo!() } fn bool_all_dim( - tensor: >::Primitive, + tensor: >::SparsePrimitive, dim: usize, - ) -> >::Primitive { + ) -> >::SparsePrimitive { todo!() } fn bool_expand( - tensor: >::Primitive, + tensor: >::SparsePrimitive, shape: burn_tensor::Shape, - ) -> >::Primitive { + ) -> >::SparsePrimitive { todo!() } } diff --git a/crates/burn-sparse/src/decorator/coo_float.rs b/crates/burn-sparse/src/decorator/coo_float.rs index f9c6958c88..ba6e8f9447 100644 --- 
a/crates/burn-sparse/src/decorator/coo_float.rs +++ b/crates/burn-sparse/src/decorator/coo_float.rs @@ -1,16 +1,17 @@ +use super::coo::{flatten_coordinates, unflatten_coordinates, SparseCOOTensor, COO}; use burn_tensor::cast::ToElement; use burn_tensor::ops::{FloatElem, SparseBoolOps}; -use burn_tensor::{backend::Backend, ops::SparseFloatOps, SparseRepr, Tensor}; -use burn_tensor::{Bool, ElementConversion, Float, Shape, Sparse, TensorData, TensorPrimitive}; +use burn_tensor::{backend::Backend, ops::SparseFloatOps, Tensor}; +use burn_tensor::{ + Bool, ElementConversion, Float, Shape, Sparse, SparseStorage, TensorData, TensorKind, + TensorPrimitive, +}; use burn_tensor::{Device, Int}; -use super::coo::{flatten_coordinates, unflatten_coordinates, SparseCOOTensor, COO}; -type R = COO; - -impl SparseFloatOps for R { +impl SparseFloatOps for COO { fn float_to_sparse( dense: ::FloatTensorPrimitive, - ) -> >::Primitive { + ) -> >::SparsePrimitive { let dense: Tensor = Tensor::from_primitive(TensorPrimitive::Float(dense)); let shape = dense.shape(); @@ -59,7 +60,7 @@ impl SparseFloatOps for R { fn float_empty( shape: burn_tensor::Shape, device: &burn_tensor::Device, - ) -> >::Primitive { + ) -> >::SparsePrimitive { SparseCOOTensor { coordinates: None, values: None, @@ -69,7 +70,7 @@ impl SparseFloatOps for R { } fn float_to_dense( - sparse: >::Primitive, + sparse: >::SparsePrimitive, ) -> B::FloatTensorPrimitive { let SparseCOOTensor { coordinates, @@ -93,8 +94,8 @@ impl SparseFloatOps for R { } fn float_spmm( - lhs: >::Primitive, - rhs: ::FloatTensorPrimitive, + lhs: >::SparsePrimitive, + rhs: >::Primitive, ) -> ::FloatTensorPrimitive { let SparseCOOTensor { coordinates, @@ -103,7 +104,7 @@ impl SparseFloatOps for R { device, } = lhs; - let rhs: Tensor = Tensor::from_primitive(TensorPrimitive::Float(rhs)); + let rhs: Tensor = Tensor::from_primitive(rhs); let rhs_shape = rhs.shape(); let mut out_shape = shape.clone(); out_shape.dims[D - 1] = rhs_shape.dims[D - 1]; @@ -170,6 +171,8 @@ impl SparseFloatOps for R { Tensor::zeros([out_shape.dims[0], rhs.shape().dims[1]], &device); let gathered = rhs.gather(0, gather_index); + println!("{}", gathered); + println!("{}", values); let multiplied = gathered.mul(values); let scattered = output.scatter(0, scatter_index, multiplied); @@ -180,8 +183,8 @@ impl SparseFloatOps for R { fn float_sddmm( lhs: ::FloatTensorPrimitive, rhs: ::FloatTensorPrimitive, - sparse: >::Primitive, - ) -> >::Primitive { + sparse: >::SparsePrimitive, + ) -> >::SparsePrimitive { if sparse.coordinates.is_none() || sparse.values.is_none() { return sparse; } @@ -242,8 +245,8 @@ impl SparseFloatOps for R { } fn float_coalesce_sum( - tensor: >::Primitive, - ) -> >::Primitive { + tensor: >::SparsePrimitive, + ) -> >::SparsePrimitive { if tensor.coordinates.as_ref().map(|c| c.shape().dims[1] <= 1) == Some(true) { return tensor; } @@ -333,8 +336,8 @@ impl SparseFloatOps for R { } fn float_remove_zeros( - tensor: >::Primitive, - ) -> >::Primitive { + tensor: >::SparsePrimitive, + ) -> >::SparsePrimitive { if tensor.coordinates.is_none() && tensor.values.is_none() { return tensor; } @@ -352,7 +355,7 @@ impl SparseFloatOps for R { } fn float_number_nonzero( - tensor: >::Primitive, + tensor: >::SparsePrimitive, ) -> usize { match tensor.coordinates { Some(coordinates) => coordinates.shape().dims[1], @@ -361,7 +364,7 @@ impl SparseFloatOps for R { } fn float_density( - sparse: >::Primitive, + sparse: >::SparsePrimitive, ) -> f32 { match sparse.coordinates { Some(coordinates) => { @@ -372,22 
+375,22 @@ impl SparseFloatOps for R { } fn float_slice( - tensor: >::Primitive, + tensor: >::SparsePrimitive, indices: [std::ops::Range; D2], - ) -> >::Primitive { + ) -> >::SparsePrimitive { todo!() } fn float_device( - tensor: &>::Primitive, + tensor: &>::SparsePrimitive, ) -> burn_tensor::Device { tensor.device.clone() } fn float_to_device( - tensor: >::Primitive, + tensor: >::SparsePrimitive, device: &burn_tensor::Device, - ) -> >::Primitive { + ) -> >::SparsePrimitive { SparseCOOTensor { coordinates: tensor.coordinates.map(|t| t.to_device(device)), values: tensor.values.map(|t| t.to_device(device)), @@ -397,13 +400,13 @@ impl SparseFloatOps for R { } fn float_shape( - tensor: &>::Primitive, + tensor: &>::SparsePrimitive, ) -> burn_tensor::Shape { tensor.shape.clone() } fn float_into_data( - tensor: >::Primitive, + tensor: >::SparsePrimitive, ) -> impl std::future::Future + Send { async { todo!() } } @@ -411,14 +414,14 @@ impl SparseFloatOps for R { fn float_from_data( data: burn_tensor::TensorData, device: &burn_tensor::Device, - ) -> >::Primitive { + ) -> >::SparsePrimitive { todo!() } fn float_reshape( - tensor: >::Primitive, + tensor: >::SparsePrimitive, out_shape: burn_tensor::Shape, - ) -> >::Primitive { + ) -> >::SparsePrimitive { if tensor.coordinates.is_none() && tensor.values.is_none() { return SparseCOOTensor { coordinates: None, @@ -452,8 +455,8 @@ impl SparseFloatOps for R { } fn float_transpose( - tensor: >::Primitive, - ) -> >::Primitive { + tensor: >::SparsePrimitive, + ) -> >::SparsePrimitive { let d = tensor.shape.dims.len(); let mut axes: Vec = (0..d).collect(); axes.swap(d - 1, d - 2); @@ -461,10 +464,10 @@ impl SparseFloatOps for R { } fn float_swap_dims( - tensor: >::Primitive, + tensor: >::SparsePrimitive, dim1: usize, dim2: usize, - ) -> >::Primitive { + ) -> >::SparsePrimitive { let d = tensor.shape.dims.len(); let mut axes: Vec = (0..d).collect(); axes.swap(dim1, dim2); @@ -472,9 +475,9 @@ impl SparseFloatOps for R { } fn float_permute( - tensor: >::Primitive, + tensor: >::SparsePrimitive, axes: &[usize], - ) -> >::Primitive { + ) -> >::SparsePrimitive { let SparseCOOTensor { coordinates, values, @@ -498,9 +501,9 @@ impl SparseFloatOps for R { } fn float_flip( - tensor: >::Primitive, + tensor: >::SparsePrimitive, axes: &[usize], - ) -> >::Primitive { + ) -> >::SparsePrimitive { let SparseCOOTensor { coordinates, values, @@ -549,10 +552,10 @@ impl SparseFloatOps for R { } fn float_slice_assign( - tensor: >::Primitive, + tensor: >::SparsePrimitive, ranges: [std::ops::Range; D2], - mut value: >::Primitive, - ) -> >::Primitive { + mut value: >::SparsePrimitive, + ) -> >::SparsePrimitive { let value_nnz = value .coordinates .as_ref() @@ -579,10 +582,10 @@ impl SparseFloatOps for R { } fn float_repeat_dim( - tensor: >::Primitive, + tensor: >::SparsePrimitive, dim: usize, times: usize, - ) -> >::Primitive { + ) -> >::SparsePrimitive { let SparseCOOTensor { coordinates, values, @@ -633,29 +636,29 @@ impl SparseFloatOps for R { } fn float_cat( - tensors: Vec<>::Primitive>, + tensors: Vec<>::SparsePrimitive>, dim: usize, - ) -> >::Primitive { + ) -> >::SparsePrimitive { todo!() } fn float_equal( - lhs: >::Primitive, - rhs: >::Primitive, - ) -> >::Primitive { + lhs: >::SparsePrimitive, + rhs: >::SparsePrimitive, + ) -> >::SparsePrimitive { todo!() } fn float_not_equal( - lhs: >::Primitive, - rhs: >::Primitive, - ) -> >::Primitive { + lhs: >::SparsePrimitive, + rhs: >::SparsePrimitive, + ) -> >::SparsePrimitive { todo!() } fn float_any( - tensor: >::Primitive, - ) -> 
>::Primitive { + tensor: >::SparsePrimitive, + ) -> >::SparsePrimitive { let SparseCOOTensor { coordinates, values: _, @@ -664,19 +667,19 @@ impl SparseFloatOps for R { } = tensor; let any = coordinates.is_some(); let bool = Tensor::::from([any]).into_primitive(); - >::bool_to_sparse(bool) + >::bool_to_sparse(bool) } fn float_any_dim( - tensor: >::Primitive, + tensor: >::SparsePrimitive, dim: usize, - ) -> >::Primitive { + ) -> >::SparsePrimitive { panic!("any_dim is unsupported for COO until scatter supports any-based reduction"); } fn float_all( - tensor: >::Primitive, - ) -> >::Primitive { + tensor: >::SparsePrimitive, + ) -> >::SparsePrimitive { let SparseCOOTensor { coordinates, values: _, @@ -688,27 +691,27 @@ impl SparseFloatOps for R { None => false, }; let bool = Tensor::::from([all]).into_primitive(); - >::bool_to_sparse(bool) + >::bool_to_sparse(bool) } fn float_all_dim( - tensor: >::Primitive, + tensor: >::SparsePrimitive, dim: usize, - ) -> >::Primitive { + ) -> >::SparsePrimitive { panic!("all_dim is unsupported for COO until scatter supports all-based reduction"); } fn float_expand( - tensor: >::Primitive, + tensor: >::SparsePrimitive, shape: burn_tensor::Shape, - ) -> >::Primitive { + ) -> >::SparsePrimitive { todo!() } fn float_add( - lhs: >::Primitive, - rhs: >::Primitive, - ) -> >::Primitive { + lhs: >::SparsePrimitive, + rhs: >::SparsePrimitive, + ) -> >::SparsePrimitive { let SparseCOOTensor { coordinates: lhs_coordinates, values: lhs_values, @@ -753,9 +756,9 @@ impl SparseFloatOps for R { } fn float_sub( - lhs: >::Primitive, - rhs: >::Primitive, - ) -> >::Primitive { + lhs: >::SparsePrimitive, + rhs: >::SparsePrimitive, + ) -> >::SparsePrimitive { Self::float_add( lhs, Self::float_mul_scalar(rhs, FloatElem::::from_elem(-1.0)), @@ -763,110 +766,110 @@ impl SparseFloatOps for R { } fn float_mul( - lhs: >::Primitive, - rhs: >::Primitive, - ) -> >::Primitive { + lhs: >::SparsePrimitive, + rhs: >::SparsePrimitive, + ) -> >::SparsePrimitive { panic!("float_mul is unsupported until scatter supports multiplication based reduction"); } fn float_mul_scalar( - mut lhs: >::Primitive, + mut lhs: >::SparsePrimitive, rhs: burn_tensor::ops::FloatElem, - ) -> >::Primitive { + ) -> >::SparsePrimitive { lhs.values = lhs.values.map(|values| values.mul_scalar(rhs)); lhs } fn float_div( - lhs: >::Primitive, - rhs: >::Primitive, - ) -> >::Primitive { + lhs: >::SparsePrimitive, + rhs: >::SparsePrimitive, + ) -> >::SparsePrimitive { panic!("float_div is unsupported until scatter supports multiplication based reduction"); } fn float_div_scalar( - mut lhs: >::Primitive, + mut lhs: >::SparsePrimitive, rhs: burn_tensor::ops::FloatElem, - ) -> >::Primitive { + ) -> >::SparsePrimitive { lhs.values = lhs.values.map(|values| values.div_scalar(rhs)); lhs } fn float_max( - tensor: >::Primitive, - ) -> >::Primitive { + tensor: >::SparsePrimitive, + ) -> >::SparsePrimitive { panic!("max is unsupported for COO until scatter supports max reduction"); } fn float_max_dim( - tensor: >::Primitive, + tensor: >::SparsePrimitive, dim: usize, - ) -> >::Primitive { + ) -> >::SparsePrimitive { panic!("max_dim is unsupported for COO until scatter supports max reduction"); } fn float_min( - tensor: >::Primitive, - ) -> >::Primitive { + tensor: >::SparsePrimitive, + ) -> >::SparsePrimitive { panic!("min is unsupported for COO until scatter supports min reduction"); } fn float_min_dim( - tensor: >::Primitive, + tensor: >::SparsePrimitive, dim: usize, - ) -> >::Primitive { + ) -> >::SparsePrimitive { panic!("min_dim is 
unsupported for COO until scatter supports min reduction"); } fn float_abs( - mut tensor: >::Primitive, - ) -> >::Primitive { + mut tensor: >::SparsePrimitive, + ) -> >::SparsePrimitive { tensor.values = tensor.values.map(|values| values.abs()); tensor } fn float_sign( - mut tensor: >::Primitive, - ) -> >::Primitive { + mut tensor: >::SparsePrimitive, + ) -> >::SparsePrimitive { tensor.values = tensor.values.map(|values| values.sign()); tensor } fn float_powf( - lhs: >::Primitive, - rhs: >::Primitive, - ) -> >::Primitive { + lhs: >::SparsePrimitive, + rhs: >::SparsePrimitive, + ) -> >::SparsePrimitive { panic!("float_powf is unsupported for COO until scatter supports other reduction methods"); } fn float_powi( - lhs: >::Primitive, - rhs: >::Primitive, - ) -> >::Primitive { + lhs: >::SparsePrimitive, + rhs: >::SparsePrimitive, + ) -> >::SparsePrimitive { panic!("float_powi is unsupported for COO until scatter supports other reduction methods"); } fn float_powf_scalar( - mut lhs: >::Primitive, + mut lhs: >::SparsePrimitive, rhs: burn_tensor::ops::FloatElem, - ) -> >::Primitive { + ) -> >::SparsePrimitive { lhs.values = lhs.values.map(|values| values.powf_scalar(rhs)); lhs } fn float_powi_scalar( - mut lhs: >::Primitive, + mut lhs: >::SparsePrimitive, rhs: burn_tensor::ops::FloatElem, - ) -> >::Primitive { + ) -> >::SparsePrimitive { lhs.values = lhs.values.map(|values| values.powi_scalar(rhs)); lhs } fn float_clamp( - mut tensor: >::Primitive, + mut tensor: >::SparsePrimitive, min: burn_tensor::ops::FloatElem, max: burn_tensor::ops::FloatElem, - ) -> >::Primitive { + ) -> >::SparsePrimitive { tensor.values = tensor.values.map(|values| values.clamp(min, max)); if min.to_f64() == 0f64 || max.to_f64() == 0f64 { // Clamp can zero elements if a boundary is zero @@ -877,9 +880,9 @@ impl SparseFloatOps for R { } fn float_clamp_min( - mut tensor: >::Primitive, + mut tensor: >::SparsePrimitive, min: burn_tensor::ops::FloatElem, - ) -> >::Primitive { + ) -> >::SparsePrimitive { tensor.values = tensor.values.map(|values| values.clamp_min(min)); if min.to_f64() == 0f64 { // min can zero elements if boundary is 0 @@ -890,9 +893,9 @@ impl SparseFloatOps for R { } fn float_clamp_max( - mut tensor: >::Primitive, + mut tensor: >::SparsePrimitive, max: burn_tensor::ops::FloatElem, - ) -> >::Primitive { + ) -> >::SparsePrimitive { tensor.values = tensor.values.map(|values| values.clamp_max(max)); if max.to_f64() == 0f64 { // max can zero elements if boundary is 0 @@ -903,10 +906,10 @@ impl SparseFloatOps for R { } fn float_select( - tensor: >::Primitive, + tensor: >::SparsePrimitive, dim: usize, indices: burn_tensor::ops::IntTensor, - ) -> >::Primitive { + ) -> >::SparsePrimitive { if tensor.coordinates.is_none() && tensor.values.is_none() { return tensor; } @@ -944,34 +947,34 @@ impl SparseFloatOps for R { } fn float_select_assign( - tensor: >::Primitive, + tensor: >::SparsePrimitive, dim: usize, indices: burn_tensor::ops::IntTensor, - values: >::Primitive, - ) -> >::Primitive { + values: >::SparsePrimitive, + ) -> >::SparsePrimitive { todo!() } fn float_gather( dim: usize, - tensor: >::Primitive, + tensor: >::SparsePrimitive, indices: burn_tensor::ops::IntTensor, - ) -> >::Primitive { + ) -> >::SparsePrimitive { todo!() } fn float_scatter( dim: usize, - tensor: >::Primitive, + tensor: >::SparsePrimitive, indices: burn_tensor::ops::IntTensor, - values: >::Primitive, - ) -> >::Primitive { + values: >::SparsePrimitive, + ) -> >::SparsePrimitive { todo!() } fn float_sum( - tensor: >::Primitive, - ) -> 
>::Primitive { + tensor: >::SparsePrimitive, + ) -> >::SparsePrimitive { tensor .values .map(|values| Self::float_to_sparse(values.sum().into_primitive().tensor())) @@ -979,22 +982,22 @@ impl SparseFloatOps for R { } fn float_sum_dim( - tensor: >::Primitive, + tensor: >::SparsePrimitive, dim: usize, - ) -> >::Primitive { + ) -> >::SparsePrimitive { panic!("float_sum_dim unsupported for COO"); } fn float_prod_dim( - tensor: >::Primitive, + tensor: >::SparsePrimitive, dim: usize, - ) -> >::Primitive { + ) -> >::SparsePrimitive { panic!("float_prod_dim is not supported for COO until scatter supports product reduction") } fn float_mean( - tensor: >::Primitive, - ) -> >::Primitive { + tensor: >::SparsePrimitive, + ) -> >::SparsePrimitive { let num_elems = tensor.shape.num_elements(); Self::float_div_scalar( Self::float_sum(tensor), @@ -1003,23 +1006,23 @@ impl SparseFloatOps for R { } fn float_mean_dim( - tensor: >::Primitive, + tensor: >::SparsePrimitive, dim: usize, - ) -> >::Primitive { + ) -> >::SparsePrimitive { panic!("float_mean_dim is not supported for COO until scatter supports mean reduction"); } fn float_remainder_scalar( - mut lhs: >::Primitive, + mut lhs: >::SparsePrimitive, rhs: burn_tensor::ops::FloatElem, - ) -> >::Primitive { + ) -> >::SparsePrimitive { lhs.values = lhs.values.map(|values| values.remainder_scalar(rhs)); lhs } fn float_neg( - mut tensor: >::Primitive, - ) -> >::Primitive { + mut tensor: >::SparsePrimitive, + ) -> >::SparsePrimitive { tensor.values = tensor.values.map(|values| values.neg()); tensor } diff --git a/crates/burn-sparse/src/decorator/coo_int.rs b/crates/burn-sparse/src/decorator/coo_int.rs index 746ca7c53e..126e439fee 100644 --- a/crates/burn-sparse/src/decorator/coo_int.rs +++ b/crates/burn-sparse/src/decorator/coo_int.rs @@ -1,87 +1,85 @@ -use burn_tensor::{backend::Backend, ops::SparseIntOps, SparseRepr}; - use super::coo::COO; -type R = COO; +use burn_tensor::{backend::Backend, ops::SparseIntOps, SparseStorage}; -impl SparseIntOps for R { +impl SparseIntOps for COO { fn int_empty( shape: burn_tensor::Shape, device: &burn_tensor::Device, - ) -> >::Primitive { + ) -> >::SparsePrimitive { todo!() } fn int_shape( - tensor: &>::Primitive, + tensor: &>::SparsePrimitive, ) -> burn_tensor::Shape { todo!() } fn int_reshape( - tensor: >::Primitive, + tensor: >::SparsePrimitive, shape: burn_tensor::Shape, - ) -> >::Primitive { + ) -> >::SparsePrimitive { todo!() } fn int_transpose( - tensor: >::Primitive, - ) -> >::Primitive { + tensor: >::SparsePrimitive, + ) -> >::SparsePrimitive { todo!() } fn int_swap_dims( - tensor: >::Primitive, + tensor: >::SparsePrimitive, dim1: usize, dim2: usize, - ) -> >::Primitive { + ) -> >::SparsePrimitive { todo!() } fn int_permute( - tensor: >::Primitive, + tensor: >::SparsePrimitive, axes: &[usize], - ) -> >::Primitive { + ) -> >::SparsePrimitive { todo!() } fn int_flip( - tensor: >::Primitive, + tensor: >::SparsePrimitive, axes: &[usize], - ) -> >::Primitive { + ) -> >::SparsePrimitive { todo!() } fn int_slice( - tensor: >::Primitive, + tensor: >::SparsePrimitive, indices: [std::ops::Range; D2], - ) -> >::Primitive { + ) -> >::SparsePrimitive { todo!() } fn int_slice_assign( - tensor: >::Primitive, + tensor: >::SparsePrimitive, ranges: [std::ops::Range; D2], - value: >::Primitive, - ) -> >::Primitive { + value: >::SparsePrimitive, + ) -> >::SparsePrimitive { todo!() } fn int_device( - tensor: &>::Primitive, + tensor: &>::SparsePrimitive, ) -> burn_tensor::Device { todo!() } fn int_to_device( - tensor: >::Primitive, + 
tensor: >::SparsePrimitive, device: &burn_tensor::Device, - ) -> >::Primitive { + ) -> >::SparsePrimitive { todo!() } fn int_into_data( - tensor: >::Primitive, + tensor: >::SparsePrimitive, ) -> impl std::future::Future + Send { async { todo!() } } @@ -89,69 +87,69 @@ impl SparseIntOps for R { fn int_from_data( data: burn_tensor::TensorData, device: &burn_tensor::Device, - ) -> >::Primitive { + ) -> >::SparsePrimitive { todo!() } fn int_repeat_dim( - tensor: >::Primitive, + tensor: >::SparsePrimitive, dim: usize, times: usize, - ) -> >::Primitive { + ) -> >::SparsePrimitive { todo!() } fn int_cat( - tensors: Vec<>::Primitive>, + tensors: Vec<>::SparsePrimitive>, dim: usize, - ) -> >::Primitive { + ) -> >::SparsePrimitive { todo!() } fn int_equal( - lhs: >::Primitive, - rhs: >::Primitive, - ) -> >::Primitive { + lhs: >::SparsePrimitive, + rhs: >::SparsePrimitive, + ) -> >::SparsePrimitive { todo!() } fn int_not_equal( - lhs: >::Primitive, - rhs: >::Primitive, - ) -> >::Primitive { + lhs: >::SparsePrimitive, + rhs: >::SparsePrimitive, + ) -> >::SparsePrimitive { todo!() } fn int_any( - tensor: >::Primitive, - ) -> >::Primitive { + tensor: >::SparsePrimitive, + ) -> >::SparsePrimitive { todo!() } fn int_any_dim( - tensor: >::Primitive, + tensor: >::SparsePrimitive, dim: usize, - ) -> >::Primitive { + ) -> >::SparsePrimitive { todo!() } fn int_all( - tensor: >::Primitive, - ) -> >::Primitive { + tensor: >::SparsePrimitive, + ) -> >::SparsePrimitive { todo!() } fn int_all_dim( - tensor: >::Primitive, + tensor: >::SparsePrimitive, dim: usize, - ) -> >::Primitive { + ) -> >::SparsePrimitive { todo!() } fn int_expand( - tensor: >::Primitive, + tensor: >::SparsePrimitive, shape: burn_tensor::Shape, - ) -> >::Primitive { + ) -> >::SparsePrimitive { todo!() } } diff --git a/crates/burn-tensor/src/tensor/api/base.rs b/crates/burn-tensor/src/tensor/api/base.rs index ba19195651..2f544ebdff 100644 --- a/crates/burn-tensor/src/tensor/api/base.rs +++ b/crates/burn-tensor/src/tensor/api/base.rs @@ -18,60 +18,64 @@ use crate::check::TensorCheck; use crate::tensor::api::chunk::chunk; use crate::tensor::api::narrow::narrow; use crate::{backend::Backend, check, Bool, Float, Int, Shape, TensorData, TensorKind}; -use crate::{ChangeRepr, DType, Dense, Element, SparseRepr, TensorPrimitive, TensorRepr}; +use crate::{ + DType, Dense, Element, ReprPrimitive, TensorPrimitive, TensorRepr, TensorReprT, TensorStorage, +}; /// A tensor with a given backend, shape and data type. 
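/// The storage parameter `SR` selects the layout (dense by default, or a sparse
/// representation), while `TensorRepr` resolves the concrete primitive type for
/// each kind/storage pair.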
#[derive(new, Clone, Debug)] -pub struct Tensor +pub struct Tensor where B: Backend, - K: TensorKind, - R: TensorRepr, + K: TensorKind, + SR: TensorStorage, + TensorRepr: TensorReprT + TensorReprT, { - pub(crate) primitive: K::Primitive, + pub(crate) primitive: >::Primitive, } -impl From for Tensor +impl From for Tensor where B: Backend, - R: TensorRepr, - K: BasicOps, + K: BasicOps, + SR: TensorStorage, T: Into, + TensorRepr: TensorReprT + TensorReprT, { fn from(value: T) -> Self { Tensor::from_data(value.into(), &Default::default()) } } -impl Tensor +// impl Tensor +// where +// B: Backend, +// R: TensorRepr, +// K: TensorKind, +// { +// fn change_repr>(self) -> Tensor +// where +// K: TensorKind, +// R: ChangeRepr, +// { +// R::change_repr(self) +// } +// } + +impl Tensor where B: Backend, - R: TensorRepr, - K: TensorKind, -{ - fn change_repr>(self) -> Tensor - where - K: TensorKind, - R: ChangeRepr, - { - R::change_repr(self) - } -} - -impl Tensor -where - B: Backend, - K: BasicOps, - R: TensorRepr, - Bool: TensorKind, + K: BasicOps, + SR: TensorStorage, + TensorRepr: TensorReprT + TensorReprT, { /// Converts the tensor into a primitive tensor. - pub fn into_primitive(self) -> K::Primitive { + pub fn into_primitive(self) -> ReprPrimitive { self.primitive } /// Converts from a primitive tensor into a tensor. - pub fn from_primitive(tensor: K::Primitive) -> Self { + pub fn from_primitive(tensor: ReprPrimitive) -> Self { Self::new(tensor) } @@ -125,7 +129,7 @@ where /// println!("{:?}", reshaped_tensor.shape()); /// } /// ``` - pub fn reshape>(self, shape: S) -> Tensor { + pub fn reshape>(self, shape: S) -> Tensor { // Convert reshape args to shape let shape = shape.into_shape(&self); Tensor::new(K::reshape::(self.primitive, shape)) @@ -140,7 +144,7 @@ where /// # Returns /// /// The transposed tensor. - pub fn transpose(self) -> Tensor { + pub fn transpose(self) -> Tensor { Tensor::new(K::transpose(self.primitive)) } @@ -155,7 +159,7 @@ where /// # Returns /// /// The tensor with the dimensions swapped. - pub fn swap_dims(self, dim1: usize, dim2: usize) -> Tensor { + pub fn swap_dims(self, dim1: usize, dim2: usize) -> Tensor { Tensor::new(K::swap_dims(self.primitive, dim1, dim2)) } @@ -171,7 +175,7 @@ where /// # Returns /// /// The tensor with the dimensions permuted. - pub fn permute(self, axes: [isize; D]) -> Tensor { + pub fn permute(self, axes: [isize; D]) -> Tensor { // Convert the axes to usize and handle negative values without using vector let mut transformed_axes: [usize; D] = [0; D]; for (i, &x) in axes.iter().enumerate() { @@ -211,7 +215,11 @@ where /// The tensor with the dimensions moved. // This is a semantic sugar for `permute`. It is used widely enough, so we define a separate Op // for it - pub fn movedim(self, src: S1, dst: S2) -> Tensor { + pub fn movedim( + self, + src: S1, + dst: S2, + ) -> Tensor { let source_dims = src.into_dim_vec::(); let destination_dims = dst.into_dim_vec::(); @@ -252,7 +260,7 @@ where /// # Returns /// /// The tensor with the axes flipped. 
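///
/// For instance (hypothetical values), flipping `[[1, 2], [3, 4]]` along axis 0
/// yields `[[3, 4], [1, 2]]`.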
- pub fn flip(self, axes: [isize; N]) -> Tensor { + pub fn flip(self, axes: [isize; N]) -> Tensor { // Convert the axes to usize and handle negative values without using vector let mut transformed_axes: [usize; N] = [0; N]; for (i, &x) in axes.iter().enumerate() { @@ -306,7 +314,11 @@ where /// } /// /// ``` - pub fn flatten(self, start_dim: usize, end_dim: usize) -> Tensor { + pub fn flatten( + self, + start_dim: usize, + end_dim: usize, + ) -> Tensor { check!(TensorCheck::flatten::(start_dim, end_dim)); let current_dims = self.shape().dims; @@ -357,7 +369,7 @@ where /// println!("{:?}", squeezed_tensor.shape()); /// } /// ``` - pub fn squeeze(self, dim: usize) -> Tensor { + pub fn squeeze(self, dim: usize) -> Tensor { check!(TensorCheck::squeeze::(dim, &self.shape().dims)); let current_dims = self.shape().dims; @@ -406,7 +418,7 @@ where /// println!("{:?}", squeezed_tensor.shape()); /// } /// ``` - pub fn squeeze_dims(self, dims: &[isize]) -> Tensor { + pub fn squeeze_dims(self, dims: &[isize]) -> Tensor { let current_dims = self.shape().dims; let mut dim_indices: Vec; @@ -476,7 +488,7 @@ where /// // Shape { dims: [1, 1, 3, 3] } /// } /// ``` - pub fn unsqueeze(self) -> Tensor { + pub fn unsqueeze(self) -> Tensor { check!(TensorCheck::unsqueeze::()); let mut dims = [1; D2]; @@ -505,7 +517,7 @@ where /// // Shape { dims: [3, 1, 3] } /// } /// ``` - pub fn unsqueeze_dim(self, dim: usize) -> Tensor { + pub fn unsqueeze_dim(self, dim: usize) -> Tensor { check!(TensorCheck::unsqueeze_dim::<{ D }>(dim)); let mut dims = [1; D2]; @@ -542,7 +554,7 @@ where /// // Shape { dims: [1, 3, 4, 5, 1, 1] } /// } /// ``` - pub fn unsqueeze_dims(self, axes: &[isize]) -> Tensor { + pub fn unsqueeze_dims(self, axes: &[isize]) -> Tensor { let mut new_dims = [1; D2]; let old_dims = self.shape().dims; //for checking if the dimension is in the acceptable range @@ -763,7 +775,7 @@ where /// # Panics /// /// If the two tensors don't have the same shape. - pub fn equal(self, other: Self) -> Tensor { + pub fn equal(self, other: Self) -> Tensor { // check!(TensorCheck::binary_ops_ew("Equal", &self, &other)); K::equal(self.primitive, other.primitive) } @@ -773,7 +785,7 @@ where /// # Panics /// /// If the two tensors don't have the same shape. - pub fn not_equal(self, other: Self) -> Tensor { + pub fn not_equal(self, other: Self) -> Tensor { // check!(TensorCheck::binary_ops_ew("NotEqual", &self, &other)); K::not_equal(self.primitive, other.primitive) } @@ -799,12 +811,12 @@ where /// If all tensors don't have the same shape. /// Given dimension is not with range of 0..D2 pub fn stack( - tensors: Vec>, + tensors: Vec>, dim: usize, - ) -> Tensor { + ) -> Tensor { // check!(TensorCheck::stack(&tensors, dim)); let tensors = tensors.into_iter().map(|t| t.unsqueeze_dim(dim)).collect(); - Tensor::::cat(tensors, dim) + Tensor::::cat(tensors, dim) } /// Iterate over slices of tensors alongside a given dimension. @@ -816,9 +828,9 @@ where /// # Returns /// /// A tensor iterator. - pub fn iter_dim(self, dim: usize) -> DimIter { + pub fn iter_dim(self, dim: usize) -> DimIter { check!(TensorCheck::dim_ops::("iter_dim", dim)); - DimIter::::new(self, dim) + DimIter::::new(self, dim) } /// Returns a new tensor with the given dimension narrowed to the given range. 
@@ -834,7 +846,7 @@ where pub fn narrow(self, dim: usize, start: usize, length: usize) -> Self { check!(TensorCheck::dim_ops::("narrow", dim)); // check!(TensorCheck::narrow(&self, dim, start, length)); - Self::new(narrow::(self.primitive, dim, start, length)) + Self::new(narrow::(self.primitive, dim, start, length)) } /// Attempts to split the tensor along the given dimension into chunks. @@ -851,7 +863,7 @@ where /// A vector of tensors. pub fn chunk(self, chunks: usize, dim: usize) -> Vec { check!(TensorCheck::dim_ops::("chunk", dim)); - chunk::(self.primitive, chunks, dim) + chunk::(self.primitive, chunks, dim) .into_iter() .map(|v| Self::new(v)) .collect() @@ -867,7 +879,7 @@ where /// /// A boolean tensor `Tensor` containing a single element, True if any element in the input tensor /// evaluates to True, False otherwise. - pub fn any(self) -> Tensor { + pub fn any(self) -> Tensor { K::any(self.primitive) } @@ -883,7 +895,7 @@ where /// A boolean tensor `Tensor` with the same size as input `tensor`, except in the `dim` axis /// where the size is 1. The elem in the `dim` axis is True if any element along this dim in the input /// evaluates to True, False otherwise. - pub fn any_dim(self, dim: usize) -> Tensor { + pub fn any_dim(self, dim: usize) -> Tensor { K::any_dim(self.primitive, dim) } @@ -897,7 +909,7 @@ where /// /// A boolean tensor `Tensor` with a single element, True if all elements in the input tensor /// evaluate to True, False otherwise. - pub fn all(self) -> Tensor { + pub fn all(self) -> Tensor { K::all(self.primitive) } @@ -913,7 +925,7 @@ where /// A boolean tensor `Tensor` with the same size as input `tensor`, except in the `dim` axis /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input /// evaluates to True, False otherwise. - pub fn all_dim(self, dim: usize) -> Tensor { + pub fn all_dim(self, dim: usize) -> Tensor { K::all_dim(self.primitive, dim) } @@ -957,33 +969,39 @@ where /// # Returns /// /// A new tensor with the given shape. - pub fn expand>(self, shape: S) -> Tensor { + pub fn expand>( + self, + shape: S, + ) -> Tensor { let shape = shape.into_shape(&self.shape()); check!(TensorCheck::expand("expand", &self.shape(), &shape,)); - Tensor::::new(K::expand(self.primitive, shape)) + Tensor::::new(K::expand(self.primitive, shape)) } } /// Iterator given by (Tensor::iter_dim). -pub struct DimIter +pub struct DimIter where B: Backend, - K: BasicOps, - R: TensorRepr, - Bool: TensorKind, + K: BasicOps, + SR: TensorStorage, + Bool: TensorKind, + TensorRepr: TensorReprT + TensorReprT, { start: usize, end: usize, dim: usize, ranges: [Range; D], - tensor: Tensor, + tensor: Tensor, } -impl, R: TensorRepr> Iterator - for DimIter +impl, SR: TensorStorage> Iterator + for DimIter +where + TensorRepr: TensorReprT + TensorReprT, { - type Item = Tensor; + type Item = Tensor; fn next(&mut self) -> Option { if self.start >= self.end { @@ -1016,8 +1034,11 @@ impl> DoubleEndedIterator for DimIter } } -impl, R: TensorRepr> DimIter { - fn new(tensor: Tensor, dim: usize) -> Self { +impl, SR: TensorStorage> DimIter +where + TensorRepr: TensorReprT + TensorReprT, +{ + fn new(tensor: Tensor, dim: usize) -> Self { let dims = tensor.dims(); let ranges = dims .iter() @@ -1304,7 +1325,11 @@ impl core::ops::BitXor for Tensor { /// # Warnings /// /// This is an internal trait, use the public API provided by [tensor struct](Tensor). 
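///
/// Generic code is expected to bound both the kind and the storage, e.g.
/// (an illustrative sketch; the `TensorRepr` bounds mirror those on [`Tensor`]):
///
/// ```ignore
/// fn first_dim<B: Backend, const D: usize, K: BasicOps<B, SR>, SR: TensorStorage<B>>(
///     tensor: &Tensor<B, D, K, SR>,
/// ) -> usize {
///     tensor.dims()[0]
/// }
/// ```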
-pub trait BasicOps = Dense>: TensorKind { +pub trait BasicOps = Dense>: TensorKind +where + TensorRepr: TensorReprT, + TensorRepr: TensorReprT, +{ /// The type of the tensor elements. type Elem: Element; @@ -1327,7 +1352,7 @@ pub trait BasicOps = Dense>: TensorKind { /// /// For creating empty tensors, users should prefer the [Tensor::empty](Tensor::empty) function, /// which is more high-level and designed for public use. - fn empty(shape: Shape, device: &B::Device) -> Self::Primitive; + fn empty(shape: Shape, device: &B::Device) -> ReprPrimitive; /// Returns the shape of the tensor. /// @@ -1347,7 +1372,7 @@ pub trait BasicOps = Dense>: TensorKind { /// /// For getting the shape of a tensor, users should prefer the [Tensor::shape](Tensor::shape) function, /// which is more high-level and designed for public use. - fn shape(tensor: &Self::Primitive) -> Shape; + fn shape(tensor: &ReprPrimitive) -> Shape; /// Reshapes the tensor. /// @@ -1369,9 +1394,9 @@ pub trait BasicOps = Dense>: TensorKind { /// For reshaping a tensor, users should prefer the [Tensor::reshape](Tensor::reshape) function, /// which is more high-level and designed for public use. fn reshape( - tensor: Self::Primitive, + tensor: ReprPrimitive, shape: Shape, - ) -> Self::Primitive; + ) -> ReprPrimitive; /// Transposes a tensor. /// @@ -1382,7 +1407,9 @@ pub trait BasicOps = Dense>: TensorKind { /// # Returns /// /// The transposed tensor. - fn transpose(tensor: Self::Primitive) -> Self::Primitive; + fn transpose( + tensor: ReprPrimitive, + ) -> ReprPrimitive; /// Swaps two dimensions of a tensor. /// @@ -1396,10 +1423,10 @@ pub trait BasicOps = Dense>: TensorKind { /// /// The tensor with the dimensions swapped. fn swap_dims( - tensor: Self::Primitive, + tensor: ReprPrimitive, dim1: usize, dim2: usize, - ) -> Self::Primitive; + ) -> ReprPrimitive; /// Permutes the dimensions of a tensor. /// @@ -1411,7 +1438,10 @@ pub trait BasicOps = Dense>: TensorKind { /// # Returns /// /// The tensor with the dimensions permuted. - fn permute(tensor: Self::Primitive, axes: [usize; D]) -> Self::Primitive; + fn permute( + tensor: ReprPrimitive, + axes: [usize; D], + ) -> ReprPrimitive; /// Flips the tensor along the given axes. /// @@ -1423,7 +1453,10 @@ pub trait BasicOps = Dense>: TensorKind { /// # Returns /// /// The tensor with the axes flipped. - fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive; + fn flip( + tensor: ReprPrimitive, + axes: &[usize], + ) -> ReprPrimitive; /// Select tensor elements corresponding for the given ranges. /// @@ -1445,9 +1478,9 @@ pub trait BasicOps = Dense>: TensorKind { /// For selecting elements of a tensor, users should prefer the [Tensor::slice](Tensor::slice) function, /// which is more high-level and designed for public use. fn slice( - tensor: Self::Primitive, + tensor: ReprPrimitive, range: [Range; D2], - ) -> Self::Primitive; + ) -> ReprPrimitive; /// Assigns the given value to the tensor elements corresponding for the given ranges. /// @@ -1470,10 +1503,10 @@ pub trait BasicOps = Dense>: TensorKind { /// For assigning values to elements of a tensor, users should prefer the [Tensor::slice_assign](Tensor::slice_assign) function, /// which is more high-level and designed for public use. fn slice_assign( - tensor: Self::Primitive, + tensor: ReprPrimitive, ranges: [Range; D2], - value: Self::Primitive, - ) -> Self::Primitive; + value: ReprPrimitive, + ) -> ReprPrimitive; /// Returns the device on which the tensor is allocated. 
/// @@ -1493,7 +1526,7 @@ pub trait BasicOps = Dense>: TensorKind { /// /// For getting the device of a tensor, users should prefer the [Tensor::device](Tensor::device) function, /// which is more high-level and designed for public use. - fn device(tensor: &Self::Primitive) -> B::Device; + fn device(tensor: &ReprPrimitive) -> B::Device; /// Moves the tensor to the given device. /// @@ -1515,9 +1548,9 @@ pub trait BasicOps = Dense>: TensorKind { /// For moving a tensor to a device, users should prefer the [Tensor::to_device](Tensor::to_device) function, /// which is more high-level and designed for public use. fn to_device( - tensor: Self::Primitive, + tensor: ReprPrimitive, device: &B::Device, - ) -> Self::Primitive; + ) -> ReprPrimitive; /// Extracts the data from the tensor asynchronously. /// @@ -1538,7 +1571,7 @@ pub trait BasicOps = Dense>: TensorKind { /// For extracting the data of a tensor, users should prefer the [Tensor::into_data](Tensor::into_data) function, /// which is more high-level and designed for public use. fn into_data_async( - tensor: Self::Primitive, + tensor: ReprPrimitive, ) -> impl Future + Send; /// Creates a tensor from the given data. @@ -1560,7 +1593,10 @@ pub trait BasicOps = Dense>: TensorKind { /// /// For creating a tensor from data, users should prefer the [Tensor::from_data](Tensor::from_data) function, /// which is more high-level and designed for public use. - fn from_data(data: TensorData, device: &B::Device) -> Self::Primitive; + fn from_data( + data: TensorData, + device: &B::Device, + ) -> ReprPrimitive; /// Repeat the tensor along the given dimension. /// @@ -1583,10 +1619,10 @@ pub trait BasicOps = Dense>: TensorKind { /// For repeating a tensor, users should prefer the [Tensor::repeat_dim](Tensor::repeat_dim) function, /// which is more high-level and designed for public use. fn repeat_dim( - tensor: Self::Primitive, + tensor: ReprPrimitive, dim: usize, times: usize, - ) -> Self::Primitive; + ) -> ReprPrimitive; /// Concatenates the given tensors along the given dimension. /// @@ -1607,7 +1643,10 @@ pub trait BasicOps = Dense>: TensorKind { /// /// For concatenating tensors, users should prefer the [Tensor::cat](Tensor::cat) function, /// which is more high-level and designed for public use. - fn cat(vectors: Vec>, dim: usize) -> Self::Primitive; + fn cat( + vectors: Vec>, + dim: usize, + ) -> ReprPrimitive; /// Equates the given tensors. /// @@ -1629,9 +1668,9 @@ pub trait BasicOps = Dense>: TensorKind { /// For equating tensors, users should prefer the [Tensor::equal](Tensor::equal) function, /// which is more high-level and designed for public use. fn equal( - lhs: Self::Primitive, - rhs: Self::Primitive, - ) -> Tensor; + lhs: ReprPrimitive, + rhs: ReprPrimitive, + ) -> Tensor; /// Applies element-wise non-equality comparison between the given tensors. /// @@ -1653,9 +1692,9 @@ pub trait BasicOps = Dense>: TensorKind { /// For non-equality comparison of tensors, users should prefer the [Tensor::not_equal](Tensor::not_equal) /// function, which is more high-level and designed for public use. fn not_equal( - lhs: Self::Primitive, - rhs: Self::Primitive, - ) -> Tensor; + lhs: ReprPrimitive, + rhs: ReprPrimitive, + ) -> Tensor; /// Returns the name of the element type. fn elem_type_name() -> &'static str { @@ -1678,7 +1717,7 @@ pub trait BasicOps = Dense>: TensorKind { /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. 
Users should prefer the [Tensor::any](Tensor::any) function /// which is more high-level and designed for public use. - fn any(tensor: Self::Primitive) -> Tensor; + fn any(tensor: ReprPrimitive) -> Tensor; /// Tests if any element in the tensor evaluates to True along a given dimension dim. /// @@ -1698,7 +1737,10 @@ pub trait BasicOps = Dense>: TensorKind { /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. Users should prefer the [Tensor::any_dim](Tensor::any_dim) function, /// which is more high-level and designed for public use. - fn any_dim(tensor: Self::Primitive, dim: usize) -> Tensor; + fn any_dim( + tensor: ReprPrimitive, + dim: usize, + ) -> Tensor; /// Tests if all elements in the `tensor` evaluate to True. /// @@ -1716,7 +1758,7 @@ pub trait BasicOps = Dense>: TensorKind { /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. Users should prefer the [Tensor::all](Tensor::all) function, /// which is more high-level and designed for public use. - fn all(tensor: Self::Primitive) -> Tensor; + fn all(tensor: ReprPrimitive) -> Tensor; /// Tests if all elements in the `tensor` evaluate to True along a given dimension `dim`. /// @@ -1735,7 +1777,10 @@ pub trait BasicOps = Dense>: TensorKind { /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. Users should prefer the [Tensor::all_dim](Tensor::all_dim) function, /// which is more high-level and designed for public use. - fn all_dim(tensor: Self::Primitive, dim: usize) -> Tensor; + fn all_dim( + tensor: ReprPrimitive, + dim: usize, + ) -> Tensor; /// Broadcasts the given tensor to the specified shape. /// @@ -1748,9 +1793,9 @@ pub trait BasicOps = Dense>: TensorKind { /// /// The broadcasted tensor. fn expand( - tensor: Self::Primitive, + tensor: ReprPrimitive, shape: Shape, - ) -> Self::Primitive; + ) -> ReprPrimitive; } impl BasicOps for Float { @@ -2286,21 +2331,22 @@ impl RangesArg for [(i64, i64); D2] { /// Trait used for reshape arguments. pub trait ReshapeArgs { /// Converts to a shape. 
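///
/// e.g. `tensor.reshape([2, -1])` goes through the `[i32; D2]` impl below, where a
/// `-1` entry asks for that dimension to be inferred, while `Shape` and
/// `[usize; D2]` arguments must be fully specified.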
- fn into_shape, R: TensorRepr>( + fn into_shape, SR: TensorStorage>( self, - tensor: &Tensor, + tensor: &Tensor, ) -> Shape where - Bool: TensorKind; + TensorRepr: TensorReprT + TensorReprT; } impl ReshapeArgs for Shape { - fn into_shape, R: TensorRepr>( + fn into_shape, SR: TensorStorage>( self, - tensor: &Tensor, + tensor: &Tensor, ) -> Shape where - Bool: TensorKind, + Bool: TensorKind, + TensorRepr: TensorReprT + TensorReprT, { check!(TensorCheck::reshape_args_usize(&tensor.shape(), &self)); @@ -2308,12 +2354,13 @@ impl ReshapeArgs for Shape { } } impl ReshapeArgs for [usize; D2] { - fn into_shape, R: TensorRepr>( + fn into_shape, SR: TensorStorage>( self, - tensor: &Tensor, + tensor: &Tensor, ) -> Shape where - Bool: TensorKind, + Bool: TensorKind, + TensorRepr: TensorReprT + TensorReprT, { let shape = Shape::from(self); @@ -2324,12 +2371,13 @@ impl ReshapeArgs for [usize; D2] { } impl ReshapeArgs for [i32; D2] { - fn into_shape, R: TensorRepr>( + fn into_shape, SR: TensorStorage>( self, - tensor: &Tensor, + tensor: &Tensor, ) -> Shape where - Bool: TensorKind, + TensorRepr: TensorReprT + TensorReprT, + Bool: TensorKind, { // Validate the reshape arguments check!(TensorCheck::reshape_args_i32(&self)); diff --git a/crates/burn-tensor/src/tensor/api/check.rs b/crates/burn-tensor/src/tensor/api/check.rs index 0f68981f7a..a23b2c908d 100644 --- a/crates/burn-tensor/src/tensor/api/check.rs +++ b/crates/burn-tensor/src/tensor/api/check.rs @@ -1,5 +1,5 @@ use crate::{backend::Backend, BasicOps, Shape, Tensor}; -use crate::{Dense, Float, Sparse, SparseRepr, TensorRepr}; +use crate::{Dense, Float, Sparse, TensorRepr}; use alloc::format; use alloc::string::{String, ToString}; use alloc::vec; @@ -549,40 +549,40 @@ impl TensorCheck { check } - pub(crate) fn spmm, const D: usize>( - lhs: &Tensor>, - rhs: &Tensor, - ) -> Self { - let mut check = Self::Ok; - - check = check.binary_ops_device("Matmul", &lhs.device(), &rhs.device()); - - if D < 2 { - return check; - } - - let shape_lhs = lhs.shape(); - let shape_rhs = rhs.shape(); - - let dim_lhs = shape_lhs.dims[D - 1]; - let dim_rhs = shape_rhs.dims[D - 2]; - - if dim_lhs != dim_rhs { - check = check.register( - "Matmul", - TensorError::new(format!( - "The inner dimension of matmul should be the same, but got {dim_lhs} and \ - {dim_rhs}." - )) - .details(format!( - "Lhs shape {:?}, rhs shape {:?}.", - shape_lhs.dims, shape_rhs.dims - )), - ); - } - - check - } + // pub(crate) fn spmm, const D: usize>( + // lhs: &Tensor>, + // rhs: &Tensor, + // ) -> Self { + // let mut check = Self::Ok; + + // check = check.binary_ops_device("Matmul", &lhs.device(), &rhs.device()); + + // if D < 2 { + // return check; + // } + + // let shape_lhs = lhs.shape(); + // let shape_rhs = rhs.shape(); + + // let dim_lhs = shape_lhs.dims[D - 1]; + // let dim_rhs = shape_rhs.dims[D - 2]; + + // if dim_lhs != dim_rhs { + // check = check.register( + // "Matmul", + // TensorError::new(format!( + // "The inner dimension of matmul should be the same, but got {dim_lhs} and \ + // {dim_rhs}." 
+ // )) + // .details(format!( + // "Lhs shape {:?}, rhs shape {:?}.", + // shape_lhs.dims, shape_rhs.dims + // )), + // ); + // } + + // check + // } pub(crate) fn stack>( tensors: &[Tensor], diff --git a/crates/burn-tensor/src/tensor/api/chunk.rs b/crates/burn-tensor/src/tensor/api/chunk.rs index bcd8df4d1b..bded0a5cb1 100644 --- a/crates/burn-tensor/src/tensor/api/chunk.rs +++ b/crates/burn-tensor/src/tensor/api/chunk.rs @@ -1,5 +1,8 @@ use super::narrow::narrow; -use crate::{backend::Backend, BasicOps, Bool, Dense, TensorKind, TensorRepr}; +use crate::{ + backend::Backend, BasicOps, Bool, Dense, ReprPrimitive, TensorKind, TensorRepr, TensorReprT, + TensorStorage, +}; use alloc::vec::Vec; /// Split the tensor along the given dimension into chunks. @@ -20,15 +23,18 @@ use alloc::vec::Vec; /// Ideally, it is supposed to be implemented by the backend and the backend implementation will be resolved /// by static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. -pub fn chunk + BasicOps, R: TensorRepr>( - tensor: K::Primitive, +pub fn chunk + BasicOps, SR: TensorStorage>( + tensor: ReprPrimitive, chunks: usize, dim: usize, -) -> Vec> { +) -> Vec> +where + TensorRepr: TensorReprT + TensorReprT, +{ let size = K::shape(&tensor).dims[dim]; if size < chunks { return (0..size) - .map(|i| narrow::(tensor.clone(), dim, i, 1)) + .map(|i| narrow::(tensor.clone(), dim, i, 1)) .collect(); } @@ -37,7 +43,7 @@ pub fn chunk + BasicOps, R if size % chunks == 0 { let chunk_size = size / chunks; for _ in 0..chunks { - tensors.push(narrow::( + tensors.push(narrow::( tensor.clone(), dim, sum_chunk_size, @@ -48,7 +54,7 @@ pub fn chunk + BasicOps, R } else { let chunk_size = (size / chunks) + 1; // assumes not divisible for _ in 0..chunks - 1 { - tensors.push(narrow::( + tensors.push(narrow::( tensor.clone(), dim, sum_chunk_size, @@ -57,7 +63,7 @@ pub fn chunk + BasicOps, R sum_chunk_size += chunk_size; } let remainder = size % chunk_size; - tensors.push(narrow::( + tensors.push(narrow::( tensor.clone(), dim, sum_chunk_size, diff --git a/crates/burn-tensor/src/tensor/api/kind.rs b/crates/burn-tensor/src/tensor/api/kind.rs index 3ab4c3f695..697a1a7ae0 100644 --- a/crates/burn-tensor/src/tensor/api/kind.rs +++ b/crates/burn-tensor/src/tensor/api/kind.rs @@ -1,5 +1,5 @@ use crate::backend::Backend; -use crate::{Dense, Sparse, SparseRepr, TensorRepr}; +use crate::{Dense, Sparse, TensorRepr}; use core::marker::PhantomData; /// A type-level representation of the kind of a float tensor @@ -34,89 +34,31 @@ impl TensorPrimitive { } /// A type-level representation of the kind of a tensor. -pub trait TensorKind = Dense>: Clone + core::fmt::Debug { +pub trait TensorKind: Clone + core::fmt::Debug { /// The primitive type of the tensor. type Primitive: Clone + core::fmt::Debug + Send; - /// The primitive type of the tensor when dense. - type DensePrimitive: Clone + core::fmt::Debug + Send; - /// The name of the tensor kind. fn name() -> &'static str; - - /// The representation of the tensor kind. 
- fn representation() -> &'static str { - R::name() - } } -// impl TensorKind for Float { -// type Primitive = TensorPrimitive; -// fn name() -> &'static str { -// "Float" -// } -// } - -// impl> TensorKind> for Float { -// type Primitive = R::FloatTensorPrimitive; -// fn name() -> &'static str { -// >::name() -// } -// } - -// impl TensorKind for Int { -// type Primitive = B::IntTensorPrimitive; -// fn name() -> &'static str { -// "Int" -// } -// } - -// impl> TensorKind> for Int { -// type Primitive = R::IntTensorPrimitive; - -// fn name() -> &'static str { -// >::name() -// } -// } - -// impl TensorKind for Bool { -// type Primitive = B::BoolTensorPrimitive; -// fn name() -> &'static str { -// "Bool" -// } -// } - -// impl> TensorKind> for Bool { -// type Primitive = R::BoolTensorPrimitive; - -// fn name() -> &'static str { -// >::name() -// } -// } - -impl> TensorKind for Bool { - type DensePrimitive = B::BoolTensorPrimitive; - type Primitive = R::Primitive; - +impl TensorKind for Float { + type Primitive = TensorPrimitive; fn name() -> &'static str { - "Bool" + "Float" } } -impl> TensorKind for Float { - type DensePrimitive = TensorPrimitive; - type Primitive = R::Primitive; - +impl TensorKind for Int { + type Primitive = B::IntTensorPrimitive; fn name() -> &'static str { - "Float" + "Int" } } -impl> TensorKind for Int { - type DensePrimitive = B::IntTensorPrimitive; - type Primitive = R::Primitive; - +impl TensorKind for Bool { + type Primitive = B::BoolTensorPrimitive; fn name() -> &'static str { - "Int" + "Bool" } } diff --git a/crates/burn-tensor/src/tensor/api/mod.rs b/crates/burn-tensor/src/tensor/api/mod.rs index d38622c053..e4f9f8b872 100644 --- a/crates/burn-tensor/src/tensor/api/mod.rs +++ b/crates/burn-tensor/src/tensor/api/mod.rs @@ -16,6 +16,7 @@ mod sort; mod sparse; mod sparse_float; mod sparse_numeric; +mod storage; pub use argwhere::argwhere_data; pub use autodiff::*; @@ -29,3 +30,4 @@ pub use repr::*; pub use sort::{argsort, sort, sort_with_indices}; pub use sparse::*; pub use sparse_numeric::*; +pub use storage::*; diff --git a/crates/burn-tensor/src/tensor/api/narrow.rs b/crates/burn-tensor/src/tensor/api/narrow.rs index e01d2c953c..a68987e583 100644 --- a/crates/burn-tensor/src/tensor/api/narrow.rs +++ b/crates/burn-tensor/src/tensor/api/narrow.rs @@ -1,4 +1,7 @@ -use crate::{backend::Backend, BasicOps, Dense, TensorKind, TensorRepr}; +use crate::{ + backend::Backend, BasicOps, Bool, Dense, ReprPrimitive, TensorKind, TensorRepr, TensorReprT, + TensorStorage, +}; use alloc::vec::Vec; /// Returns a new tensor with the given dimension narrowed to the given range. 
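The kind.rs hunk above collapses TensorKind back to one associated Primitive per kind, resolved purely at the type level. A self-contained toy showing the shape of that pattern (MiniBackend, MiniKind, etc. are illustrative stand-ins, not Burn's definitions):

use core::fmt::Debug;

trait MiniBackend {
    type FloatPrim: Clone + Debug + Send;
    type IntPrim: Clone + Debug + Send;
}

// Zero-sized kind markers select a primitive and a name statically.
trait MiniKind<B: MiniBackend>: Clone + Debug {
    type Primitive: Clone + Debug + Send;
    fn name() -> &'static str;
}

#[derive(Clone, Debug)]
struct MiniFloat;

#[derive(Clone, Debug)]
struct MiniInt;

impl<B: MiniBackend> MiniKind<B> for MiniFloat {
    type Primitive = B::FloatPrim;
    fn name() -> &'static str {
        "Float"
    }
}

impl<B: MiniBackend> MiniKind<B> for MiniInt {
    type Primitive = B::IntPrim;
    fn name() -> &'static str {
        "Int"
    }
}

struct Cpu;

impl MiniBackend for Cpu {
    type FloatPrim = Vec<f32>;
    type IntPrim = Vec<i64>;
}

fn main() {
    // The marker resolves both the primitive type and the name with no
    // runtime dispatch.
    let _prim: <MiniFloat as MiniKind<Cpu>>::Primitive = vec![1.0, 2.0];
    assert_eq!(<MiniInt as MiniKind<Cpu>>::name(), "Int");
}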
@@ -20,14 +23,17 @@ use alloc::vec::Vec; pub fn narrow< B: Backend, const D: usize, - K: TensorKind + BasicOps, - R: TensorRepr, + K: TensorKind + BasicOps, + SR: TensorStorage, >( - tensor: K::Primitive, + tensor: ReprPrimitive, dim: usize, start: usize, length: usize, -) -> K::Primitive { +) -> ReprPrimitive +where + TensorRepr: TensorReprT + TensorReprT, +{ let shape = K::shape(&tensor); let ranges: Vec<_> = (0..D) diff --git a/crates/burn-tensor/src/tensor/api/repr.rs b/crates/burn-tensor/src/tensor/api/repr.rs index 492afd3f91..7b6be778bb 100644 --- a/crates/burn-tensor/src/tensor/api/repr.rs +++ b/crates/burn-tensor/src/tensor/api/repr.rs @@ -1,42 +1,20 @@ -use crate::{backend::Backend, ops::SparseTensorOps, Bool, Float, Int, Tensor, TensorKind}; -use core::marker::PhantomData; +use crate::{backend::Backend, Dense, Float, Sparse, SparseStorage, TensorKind, TensorStorage}; -pub trait TensorRepr: Clone + core::fmt::Debug { - type Primitive, const D: usize>: Clone + core::fmt::Debug + Send; +pub type ReprPrimitive = + >::Primitive; - fn name() -> &'static str; +pub trait TensorReprT, S: TensorStorage> { + type Primitive: Clone + core::fmt::Debug + Send; } -pub trait ChangeRepr>: TensorRepr { - fn change_repr, K2: TensorKind>( - lhs: Tensor, - ) -> Tensor; -} - -pub trait SparseRepr: Clone + core::fmt::Debug + SparseTensorOps { - type Primitive, const D: usize>: Clone + core::fmt::Debug + Send; - - fn name() -> &'static str; -} +pub struct TensorRepr; -#[derive(Clone, Debug)] -pub struct Dense; - -#[derive(Clone, Debug)] -pub struct Sparse, B: Backend>(PhantomData<(R, B)>); - -impl TensorRepr for Dense { - type Primitive, const D: usize> = K::DensePrimitive; - - fn name() -> &'static str { - "Dense" - } +impl> TensorReprT for TensorRepr { + type Primitive = K::Primitive; } -impl, B: Backend> TensorRepr for Sparse { - type Primitive, const D: usize> = R::Primitive; - - fn name() -> &'static str { - R::name() - } +impl, SR: SparseStorage> TensorReprT> + for TensorRepr +{ + type Primitive = SR::SparsePrimitive; } diff --git a/crates/burn-tensor/src/tensor/api/sparse.rs b/crates/burn-tensor/src/tensor/api/sparse.rs index 53b12694fc..2386cfedd2 100644 --- a/crates/burn-tensor/src/tensor/api/sparse.rs +++ b/crates/burn-tensor/src/tensor/api/sparse.rs @@ -1,417 +1,469 @@ use crate::{ backend::Backend, check::TensorCheck, BasicOps, Bool, DType, Device, Element, Float, Int, - Shape, Sparse, SparseRepr, Tensor, TensorData, TensorKind, TensorPrimitive, TensorRepr, + ReprPrimitive, Shape, Sparse, SparseStorage, Tensor, TensorData, TensorKind, TensorPrimitive, + TensorStorage, }; use core::{future::Future, ops::Range}; use crate::check; -impl> BasicOps> for Float { +impl> BasicOps> for Float { type Elem = B::FloatElem; fn empty( shape: Shape, device: &::Device, - ) -> R::Primitive { - R::float_empty(shape, device) + ) -> SR::SparsePrimitive { + SR::float_empty(shape, device) } - fn shape(tensor: &Self::Primitive) -> Shape { - R::float_shape(tensor) + fn shape(tensor: &ReprPrimitive, D>) -> Shape { + SR::float_shape(tensor) } fn reshape( - tensor: Self::Primitive, + tensor: ReprPrimitive, D1>, shape: Shape, - ) -> Self::Primitive { - R::float_reshape(tensor, shape) + ) -> ReprPrimitive, D2> { + SR::float_reshape(tensor, shape) } - fn transpose(tensor: Self::Primitive) -> Self::Primitive { - R::float_transpose(tensor) + fn transpose( + tensor: ReprPrimitive, D>, + ) -> ReprPrimitive, D> { + SR::float_transpose(tensor) } fn swap_dims( - tensor: Self::Primitive, + tensor: ReprPrimitive, D>, dim1: usize, 
dim2: usize, - ) -> Self::Primitive { - R::float_swap_dims(tensor, dim1, dim2) + ) -> ReprPrimitive, D> { + SR::float_swap_dims(tensor, dim1, dim2) } - fn permute(tensor: Self::Primitive, axes: [usize; D]) -> Self::Primitive { - R::float_permute(tensor, &axes) + fn permute( + tensor: ReprPrimitive, D>, + axes: [usize; D], + ) -> ReprPrimitive, D> { + SR::float_permute(tensor, &axes) } - fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive { - R::float_flip(tensor, axes) + fn flip( + tensor: ReprPrimitive, D>, + axes: &[usize], + ) -> ReprPrimitive, D> { + SR::float_flip(tensor, axes) } fn slice( - tensor: Self::Primitive, + tensor: ReprPrimitive, D1>, range: [Range; D2], - ) -> Self::Primitive { - R::float_slice(tensor, range) + ) -> ReprPrimitive, D1> { + SR::float_slice(tensor, range) } fn slice_assign( - tensor: Self::Primitive, + tensor: ReprPrimitive, D1>, ranges: [Range; D2], - value: Self::Primitive, - ) -> Self::Primitive { - R::float_slice_assign(tensor, ranges, value) + value: ReprPrimitive, D1>, + ) -> ReprPrimitive, D1> { + SR::float_slice_assign(tensor, ranges, value) } - fn device(tensor: &Self::Primitive) -> ::Device { - R::float_device(tensor) + fn device( + tensor: &ReprPrimitive, D>, + ) -> ::Device { + SR::float_device(tensor) } fn to_device( - tensor: Self::Primitive, + tensor: ReprPrimitive, D>, device: &::Device, - ) -> Self::Primitive { - R::float_to_device(tensor, device) + ) -> ReprPrimitive, D> { + SR::float_to_device(tensor, device) } fn into_data_async( - tensor: Self::Primitive, + tensor: ReprPrimitive, D>, ) -> impl Future + Send { - R::float_into_data(tensor) + SR::float_into_data(tensor) } fn from_data( data: TensorData, device: &::Device, - ) -> Self::Primitive { - R::float_from_data(data, device) + ) -> ReprPrimitive, D> { + SR::float_from_data(data, device) } fn repeat_dim( - tensor: Self::Primitive, + tensor: ReprPrimitive, D>, dim: usize, times: usize, - ) -> Self::Primitive { - R::float_repeat_dim(tensor, dim, times) + ) -> ReprPrimitive, D> { + SR::float_repeat_dim(tensor, dim, times) } - fn cat(vectors: Vec>, dim: usize) -> Self::Primitive { - R::float_cat(vectors, dim) + fn cat( + vectors: Vec, D>>, + dim: usize, + ) -> ReprPrimitive, D> { + SR::float_cat(vectors, dim) } fn expand( - tensor: Self::Primitive, + tensor: ReprPrimitive, D1>, shape: Shape, - ) -> Self::Primitive { - R::float_expand(tensor, shape) + ) -> ReprPrimitive, D2> { + SR::float_expand(tensor, shape) } fn equal( - lhs: Self::Primitive, - rhs: Self::Primitive, - ) -> Tensor> { - Tensor::new(R::float_equal(lhs, rhs)) + lhs: ReprPrimitive, D>, + rhs: ReprPrimitive, D>, + ) -> Tensor> { + Tensor::new(SR::float_equal(lhs, rhs)) } fn not_equal( - lhs: Self::Primitive, - rhs: Self::Primitive, - ) -> Tensor> { - Tensor::new(R::float_not_equal(lhs, rhs)) + lhs: ReprPrimitive, D>, + rhs: ReprPrimitive, D>, + ) -> Tensor> { + Tensor::new(SR::float_not_equal(lhs, rhs)) } - fn any(tensor: Self::Primitive) -> Tensor> { - Tensor::new(R::float_any(tensor)) + fn any( + tensor: ReprPrimitive, D>, + ) -> Tensor> { + Tensor::new(SR::float_any(tensor)) } fn any_dim( - tensor: Self::Primitive, + tensor: ReprPrimitive, D>, dim: usize, - ) -> Tensor> { - Tensor::new(R::float_any_dim(tensor, dim)) + ) -> Tensor> { + Tensor::new(SR::float_any_dim(tensor, dim)) } - fn all(tensor: Self::Primitive) -> Tensor> { - Tensor::new(R::float_all(tensor)) + fn all( + tensor: ReprPrimitive, D>, + ) -> Tensor> { + Tensor::new(SR::float_all(tensor)) } fn all_dim( - tensor: Self::Primitive, + tensor: 
ReprPrimitive, D>, dim: usize, - ) -> Tensor> { - Tensor::new(R::float_all_dim(tensor, dim)) + ) -> Tensor> { + Tensor::new(SR::float_all_dim(tensor, dim)) } } -impl> BasicOps> for Bool { +impl> BasicOps> for Bool { type Elem = bool; fn empty( shape: Shape, device: &::Device, - ) -> Self::Primitive { - R::bool_empty(shape, device) + ) -> ReprPrimitive, D> { + SR::bool_empty(shape, device) } - fn shape(tensor: &Self::Primitive) -> Shape { - R::bool_shape(tensor) + fn shape(tensor: &ReprPrimitive, D>) -> Shape { + SR::bool_shape(tensor) } fn reshape( - tensor: Self::Primitive, + tensor: ReprPrimitive, D1>, shape: Shape, - ) -> Self::Primitive { - R::bool_reshape(tensor, shape) + ) -> ReprPrimitive, D2> { + SR::bool_reshape(tensor, shape) } - fn transpose(tensor: Self::Primitive) -> Self::Primitive { - R::bool_transpose(tensor) + fn transpose( + tensor: ReprPrimitive, D>, + ) -> ReprPrimitive, D> { + SR::bool_transpose(tensor) } fn swap_dims( - tensor: Self::Primitive, + tensor: ReprPrimitive, D>, dim1: usize, dim2: usize, - ) -> Self::Primitive { - R::bool_swap_dims(tensor, dim1, dim2) + ) -> ReprPrimitive, D> { + SR::bool_swap_dims(tensor, dim1, dim2) } - fn permute(tensor: Self::Primitive, axes: [usize; D]) -> Self::Primitive { - R::bool_permute(tensor, &axes) + fn permute( + tensor: ReprPrimitive, D>, + axes: [usize; D], + ) -> ReprPrimitive, D> { + SR::bool_permute(tensor, &axes) } - fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive { - R::bool_flip(tensor, axes) + fn flip( + tensor: ReprPrimitive, D>, + axes: &[usize], + ) -> ReprPrimitive, D> { + SR::bool_flip(tensor, axes) } fn slice( - tensor: Self::Primitive, + tensor: ReprPrimitive, D1>, range: [Range; D2], - ) -> Self::Primitive { - R::bool_slice(tensor, range) + ) -> ReprPrimitive, D1> { + SR::bool_slice(tensor, range) } fn slice_assign( - tensor: Self::Primitive, + tensor: ReprPrimitive, D1>, ranges: [Range; D2], - value: Self::Primitive, - ) -> Self::Primitive { - R::bool_slice_assign(tensor, ranges, value) + value: ReprPrimitive, D1>, + ) -> ReprPrimitive, D1> { + SR::bool_slice_assign(tensor, ranges, value) } - fn device(tensor: &Self::Primitive) -> ::Device { - R::bool_device(tensor) + fn device( + tensor: &ReprPrimitive, D>, + ) -> ::Device { + SR::bool_device(tensor) } fn to_device( - tensor: Self::Primitive, + tensor: ReprPrimitive, D>, device: &::Device, - ) -> Self::Primitive { - R::bool_to_device(tensor, device) + ) -> ReprPrimitive, D> { + SR::bool_to_device(tensor, device) } fn into_data_async( - tensor: Self::Primitive, + tensor: ReprPrimitive, D>, ) -> impl Future + Send { - R::bool_into_data(tensor) + SR::bool_into_data(tensor) } fn from_data( data: TensorData, device: &::Device, - ) -> Self::Primitive { - R::bool_from_data(data, device) + ) -> ReprPrimitive, D> { + SR::bool_from_data(data, device) } fn repeat_dim( - tensor: Self::Primitive, + tensor: ReprPrimitive, D>, dim: usize, times: usize, - ) -> Self::Primitive { - R::bool_repeat_dim(tensor, dim, times) + ) -> ReprPrimitive, D> { + SR::bool_repeat_dim(tensor, dim, times) } - fn cat(vectors: Vec>, dim: usize) -> Self::Primitive { - R::bool_cat(vectors, dim) + fn cat( + vectors: Vec, D>>, + dim: usize, + ) -> ReprPrimitive, D> { + SR::bool_cat(vectors, dim) } fn equal( - lhs: Self::Primitive, - rhs: Self::Primitive, - ) -> Tensor> { + lhs: ReprPrimitive, D>, + rhs: ReprPrimitive, D>, + ) -> Tensor> { panic!("Non-zero preserving operations are not supported for sparse tensors"); } fn not_equal( - lhs: Self::Primitive, - rhs: Self::Primitive, 
- ) -> Tensor> { + lhs: ReprPrimitive, D>, + rhs: ReprPrimitive, D>, + ) -> Tensor> { panic!("Non-zero preserving operations are not supported for sparse tensors"); } - fn any(tensor: Self::Primitive) -> Tensor> { - Tensor::new(R::bool_any(tensor)) + fn any( + tensor: ReprPrimitive, D>, + ) -> Tensor> { + Tensor::new(SR::bool_any(tensor)) } fn any_dim( - tensor: Self::Primitive, + tensor: ReprPrimitive, D>, dim: usize, - ) -> Tensor> { - Tensor::new(R::bool_any_dim(tensor, dim)) + ) -> Tensor> { + Tensor::new(SR::bool_any_dim(tensor, dim)) } - fn all(tensor: Self::Primitive) -> Tensor> { - Tensor::new(R::bool_all(tensor)) + fn all( + tensor: ReprPrimitive, D>, + ) -> Tensor> { + Tensor::new(SR::bool_all(tensor)) } fn all_dim( - tensor: Self::Primitive, + tensor: ReprPrimitive, D>, dim: usize, - ) -> Tensor> { - Tensor::new(R::bool_all_dim(tensor, dim)) + ) -> Tensor> { + Tensor::new(SR::bool_all_dim(tensor, dim)) } fn expand( - tensor: Self::Primitive, + tensor: ReprPrimitive, D1>, shape: Shape, - ) -> Self::Primitive { - R::bool_expand(tensor, shape) + ) -> ReprPrimitive, D2> { + SR::bool_expand(tensor, shape) } } -impl> BasicOps> for Int { +impl> BasicOps> for Int { type Elem = i32; fn empty( shape: Shape, device: &::Device, - ) -> Self::Primitive { - R::int_empty(shape, device) + ) -> ReprPrimitive, D> { + SR::int_empty(shape, device) } - fn shape(tensor: &Self::Primitive) -> Shape { - R::int_shape(tensor) + fn shape(tensor: &ReprPrimitive, D>) -> Shape { + SR::int_shape(tensor) } fn reshape( - tensor: Self::Primitive, + tensor: ReprPrimitive, D1>, shape: Shape, - ) -> Self::Primitive { - R::int_reshape(tensor, shape) + ) -> ReprPrimitive, D2> { + SR::int_reshape(tensor, shape) } - fn transpose(tensor: Self::Primitive) -> Self::Primitive { - R::int_transpose(tensor) + fn transpose( + tensor: ReprPrimitive, D>, + ) -> ReprPrimitive, D> { + SR::int_transpose(tensor) } fn swap_dims( - tensor: Self::Primitive, + tensor: ReprPrimitive, D>, dim1: usize, dim2: usize, - ) -> Self::Primitive { - R::int_swap_dims(tensor, dim1, dim2) + ) -> ReprPrimitive, D> { + SR::int_swap_dims(tensor, dim1, dim2) } - fn permute(tensor: Self::Primitive, axes: [usize; D]) -> Self::Primitive { - R::int_permute(tensor, &axes) + fn permute( + tensor: ReprPrimitive, D>, + axes: [usize; D], + ) -> ReprPrimitive, D> { + SR::int_permute(tensor, &axes) } - fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive { - R::int_flip(tensor, axes) + fn flip( + tensor: ReprPrimitive, D>, + axes: &[usize], + ) -> ReprPrimitive, D> { + SR::int_flip(tensor, axes) } fn slice( - tensor: Self::Primitive, + tensor: ReprPrimitive, D1>, range: [Range; D2], - ) -> Self::Primitive { - R::int_slice(tensor, range) + ) -> ReprPrimitive, D1> { + SR::int_slice(tensor, range) } fn slice_assign( - tensor: Self::Primitive, + tensor: ReprPrimitive, D1>, ranges: [Range; D2], - value: Self::Primitive, - ) -> Self::Primitive { - R::int_slice_assign(tensor, ranges, value) + value: ReprPrimitive, D1>, + ) -> ReprPrimitive, D1> { + SR::int_slice_assign(tensor, ranges, value) } - fn device(tensor: &Self::Primitive) -> ::Device { - R::int_device(tensor) + fn device( + tensor: &ReprPrimitive, D>, + ) -> ::Device { + SR::int_device(tensor) } fn to_device( - tensor: Self::Primitive, + tensor: ReprPrimitive, D>, device: &::Device, - ) -> Self::Primitive { - R::int_to_device(tensor, device) + ) -> ReprPrimitive, D> { + SR::int_to_device(tensor, device) } fn into_data_async( - tensor: Self::Primitive, + tensor: ReprPrimitive, D>, ) -> impl Future + 
Send { - R::int_into_data(tensor) + SR::int_into_data(tensor) } fn from_data( data: TensorData, device: &::Device, - ) -> Self::Primitive { - R::int_from_data(data, device) + ) -> ReprPrimitive, D> { + SR::int_from_data(data, device) } fn repeat_dim( - tensor: Self::Primitive, + tensor: ReprPrimitive, D>, dim: usize, times: usize, - ) -> Self::Primitive { - R::int_repeat_dim(tensor, dim, times) + ) -> ReprPrimitive, D> { + SR::int_repeat_dim(tensor, dim, times) } - fn cat(vectors: Vec>, dim: usize) -> Self::Primitive { - R::int_cat(vectors, dim) + fn cat( + vectors: Vec, D>>, + dim: usize, + ) -> ReprPrimitive, D> { + SR::int_cat(vectors, dim) } fn equal( - lhs: Self::Primitive, - rhs: Self::Primitive, - ) -> Tensor> { + lhs: ReprPrimitive, D>, + rhs: ReprPrimitive, D>, + ) -> Tensor> { panic!("Non-zero preserving operations are not supported for sparse tensors"); - Tensor::new(R::int_equal(lhs, rhs)) + Tensor::new(SR::int_equal(lhs, rhs)) } fn not_equal( - lhs: Self::Primitive, - rhs: Self::Primitive, - ) -> Tensor> { + lhs: ReprPrimitive, D>, + rhs: ReprPrimitive, D>, + ) -> Tensor> { panic!("Non-zero preserving operations are not supported for sparse tensors"); - Tensor::new(R::int_not_equal(lhs, rhs)) + Tensor::new(SR::int_not_equal(lhs, rhs)) } - fn any(tensor: Self::Primitive) -> Tensor> { - Tensor::new(R::int_any(tensor)) + fn any( + tensor: ReprPrimitive, D>, + ) -> Tensor> { + Tensor::new(SR::int_any(tensor)) } fn any_dim( - tensor: Self::Primitive, + tensor: ReprPrimitive, D>, dim: usize, - ) -> Tensor> { - Tensor::new(R::int_any_dim(tensor, dim)) + ) -> Tensor> { + Tensor::new(SR::int_any_dim(tensor, dim)) } - fn all(tensor: Self::Primitive) -> Tensor> { - Tensor::new(R::int_all(tensor)) + fn all( + tensor: ReprPrimitive, D>, + ) -> Tensor> { + Tensor::new(SR::int_all(tensor)) } fn all_dim( - tensor: Self::Primitive, + tensor: ReprPrimitive, D>, dim: usize, - ) -> Tensor> { - Tensor::new(R::int_all_dim(tensor, dim)) + ) -> Tensor> { + Tensor::new(SR::int_all_dim(tensor, dim)) } fn expand( - tensor: Self::Primitive, + tensor: ReprPrimitive, D1>, shape: Shape, - ) -> Self::Primitive { - R::int_expand(tensor, shape) + ) -> ReprPrimitive, D2> { + SR::int_expand(tensor, shape) } } diff --git a/crates/burn-tensor/src/tensor/api/sparse_float.rs b/crates/burn-tensor/src/tensor/api/sparse_float.rs index ad6143b673..8e51d274e9 100644 --- a/crates/burn-tensor/src/tensor/api/sparse_float.rs +++ b/crates/burn-tensor/src/tensor/api/sparse_float.rs @@ -1,12 +1,11 @@ -use crate::check; -use crate::{ - backend::Backend, check::TensorCheck, Dense, Float, Sparse, SparseRepr, Tensor, TensorKind, -}; +use crate::{backend::Backend, check::TensorCheck, Dense, Float, Sparse, Tensor, TensorKind}; +use crate::{check, Bool, SparseStorage, TensorPrimitive, TensorRepr, TensorReprT}; -impl Tensor> +impl Tensor> where B: Backend, - R: SparseRepr, + SR: SparseStorage, + TensorRepr: TensorReprT> + TensorReprT>, { /// Executes an operation on the tensor and modifies its value. /// @@ -33,8 +32,11 @@ where /// # Panics /// /// If the two tensors dont' have a compatible shape. 
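Reference semantics for spmm, useful since the shape check is commented out in the hunk below: a sparse (m x k) lhs times a dense (k x n) rhs yields a dense (m x n) result, and the inner dimensions must agree. A toy COO sketch of that contract (illustrative only, not the crate's float_spmm):

// Sparse-times-dense matmul over a toy COO layout (rows, cols, vals).
fn spmm_coo(
    rows: &[usize],
    cols: &[usize],
    vals: &[f32],
    lhs_shape: (usize, usize), // (m, k)
    rhs: &[f32],               // dense, row-major, k x n
    n: usize,
) -> Vec<f32> {
    let (m, k) = lhs_shape;
    assert_eq!(rhs.len(), k * n, "inner dimensions of matmul must agree");
    let mut out = vec![0.0f32; m * n];
    // Each stored nonzero (r, c, v) contributes v * rhs[c, :] to out[r, :].
    for ((&r, &c), &v) in rows.iter().zip(cols).zip(vals) {
        for j in 0..n {
            out[r * n + j] += v * rhs[c * n + j];
        }
    }
    out
}

fn main() {
    // Sparse [[1, 0], [0, 2]] times the 2x2 identity.
    let out = spmm_coo(&[0, 1], &[0, 1], &[1.0, 2.0], (2, 2), &[1.0, 0.0, 0.0, 1.0], 2);
    assert_eq!(out, vec![1.0, 0.0, 0.0, 2.0]);
}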
- pub fn spmm(self, rhs: Tensor) -> Self { - check!(TensorCheck::spmm(&self, &rhs)); - Self::new(R::float_spmm(self.primitive, rhs.primitive)) + pub fn spmm(self, rhs: Tensor) -> Tensor { + // check!(TensorCheck::spmm(&self, &rhs)); + Tensor::::new(TensorPrimitive::Float(SR::float_spmm( + self.primitive, + rhs.primitive, + ))) } } diff --git a/crates/burn-tensor/src/tensor/api/sparse_numeric.rs b/crates/burn-tensor/src/tensor/api/sparse_numeric.rs index 7fd4db61e2..9d084f148f 100644 --- a/crates/burn-tensor/src/tensor/api/sparse_numeric.rs +++ b/crates/burn-tensor/src/tensor/api/sparse_numeric.rs @@ -1,8 +1,8 @@ -use crate::{backend::Backend, BasicOps, SparseRepr}; +use crate::{backend::Backend, BasicOps}; /// Trait that list all operations that can be applied on all numerical sparse tensors. /// /// # Warnings /// /// This is an internal trait, use the public API provided by [tensor struct](Tensor). -pub trait SparseNumeric, B: Backend>: BasicOps {} +pub trait SparseNumeric: BasicOps {} diff --git a/crates/burn-tensor/src/tensor/api/storage.rs b/crates/burn-tensor/src/tensor/api/storage.rs new file mode 100644 index 0000000000..ec54cfc434 --- /dev/null +++ b/crates/burn-tensor/src/tensor/api/storage.rs @@ -0,0 +1,30 @@ +use crate::{backend::Backend, ops::SparseTensorOps, Bool, Float, Int, Tensor, TensorKind}; +use core::marker::PhantomData; + +pub trait TensorStorage: Clone + core::fmt::Debug { + fn name() -> &'static str; +} + +pub trait SparseStorage: Clone + core::fmt::Debug + SparseTensorOps { + type SparsePrimitive, const D: usize>: Clone + core::fmt::Debug + Send; + + fn name() -> &'static str; +} + +#[derive(Clone, Debug)] +pub struct Dense; + +#[derive(Clone, Debug)] +pub struct Sparse>(PhantomData<(B, SR)>); + +impl TensorStorage for Dense { + fn name() -> &'static str { + "Dense" + } +} + +impl> TensorStorage for Sparse { + fn name() -> &'static str { + SR::name() + } +} diff --git a/crates/burn-tensor/src/tensor/ops/sparse_tensor.rs b/crates/burn-tensor/src/tensor/ops/sparse_tensor.rs index 6af44b3147..d07d8415d6 100644 --- a/crates/burn-tensor/src/tensor/ops/sparse_tensor.rs +++ b/crates/burn-tensor/src/tensor/ops/sparse_tensor.rs @@ -1,43 +1,55 @@ use super::{BoolTensor, FloatElem, FloatTensor, IntTensor, QuantizedTensor}; use crate::{ - backend::Backend, Bool, Device, Float, Int, Shape, Sparse, SparseRepr, TensorData, TensorKind, + backend::Backend, Bool, Device, Float, Int, ReprPrimitive, Shape, Sparse, SparseStorage, + TensorData, TensorKind, TensorRepr, TensorReprT, }; use core::{future::Future, ops::Range}; -pub trait SparseTensorOps, B: Backend>: - SparseFloatOps + SparseBoolOps + SparseIntOps +pub trait SparseTensorOps, B: Backend>: + SparseFloatOps + SparseBoolOps + SparseIntOps { } -pub trait SparseFloatOps, B: Backend> { - fn float_to_sparse(dense: B::FloatTensorPrimitive) - -> R::Primitive; +pub trait SparseFloatOps, B: Backend> +where + TensorRepr: TensorReprT>, + TensorRepr: TensorReprT>, +{ + fn float_to_sparse( + dense: B::FloatTensorPrimitive, + ) -> SR::SparsePrimitive; - fn float_empty(shape: Shape, device: &Device) -> R::Primitive; + fn float_empty( + shape: Shape, + device: &Device, + ) -> SR::SparsePrimitive; - fn float_to_dense(sparse: R::Primitive) - -> B::FloatTensorPrimitive; + fn float_to_dense( + sparse: SR::SparsePrimitive, + ) -> B::FloatTensorPrimitive; fn float_spmm( - lhs: R::Primitive, - rhs: >>::DensePrimitive, + lhs: ReprPrimitive, D>, + rhs: >::Primitive, ) -> B::FloatTensorPrimitive; fn float_sddmm( lhs: B::FloatTensorPrimitive, rhs: 
B::FloatTensorPrimitive, - sparse: R::Primitive, - ) -> R::Primitive; + sparse: SR::SparsePrimitive, + ) -> SR::SparsePrimitive; - fn float_coalesce_sum(tensor: R::Primitive) - -> R::Primitive; + fn float_coalesce_sum( + tensor: SR::SparsePrimitive, + ) -> SR::SparsePrimitive; - fn float_remove_zeros(tensor: R::Primitive) - -> R::Primitive; + fn float_remove_zeros( + tensor: SR::SparsePrimitive, + ) -> SR::SparsePrimitive; - fn float_number_nonzero(tensor: R::Primitive) -> usize; + fn float_number_nonzero(tensor: SR::SparsePrimitive) -> usize; - fn float_density(sparse: R::Primitive) -> f32; + fn float_density(sparse: SR::SparsePrimitive) -> f32; /// Gets the element at the given indices. /// @@ -50,9 +62,9 @@ pub trait SparseFloatOps, B: Backend> { /// /// The elements at the given indices. fn float_slice( - tensor: R::Primitive, + tensor: SR::SparsePrimitive, indices: [Range; D2], - ) -> R::Primitive; + ) -> SR::SparsePrimitive; /// Gets the device of the tensor. /// @@ -63,7 +75,7 @@ pub trait SparseFloatOps, B: Backend> { /// # Returns /// /// The device of the tensor. - fn float_device(tensor: &R::Primitive) -> Device; + fn float_device(tensor: &SR::SparsePrimitive) -> Device; /// Moves the tensor to the given device. /// @@ -76,9 +88,9 @@ pub trait SparseFloatOps, B: Backend> { /// /// The tensor on the given device. fn float_to_device( - tensor: R::Primitive, + tensor: SR::SparsePrimitive, device: &Device, - ) -> R::Primitive; + ) -> SR::SparsePrimitive; /// Gets the shape of the tensor. /// @@ -89,7 +101,7 @@ pub trait SparseFloatOps, B: Backend> { /// # Returns /// /// The shape of the tensor. - fn float_shape(tensor: &R::Primitive) -> Shape; + fn float_shape(tensor: &SR::SparsePrimitive) -> Shape; /// Converts the tensor to a data structure. /// @@ -101,7 +113,7 @@ pub trait SparseFloatOps, B: Backend> { /// /// The data structure with the tensor's data. fn float_into_data( - tensor: R::Primitive, + tensor: SR::SparsePrimitive, ) -> impl Future + Send; /// Creates a tensor from the data structure. 
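The float_into_data/float_from_data pair above round-trips a sparse primitive through a plain data structure. A toy version of that round trip, with a flat Vec<f32> standing in for TensorData and a hand-rolled COO struct standing in for the sparse primitive (assumptions for illustration, not the crate's code):

#[derive(Clone, Debug, PartialEq)]
struct Coo2 {
    shape: (usize, usize),
    coords: Vec<(usize, usize)>,
    values: Vec<f32>,
}

// Materialize the sparse tensor into dense row-major data.
fn into_data(t: &Coo2) -> Vec<f32> {
    let (rows, cols) = t.shape;
    let mut out = vec![0.0; rows * cols];
    for (&(i, j), &v) in t.coords.iter().zip(&t.values) {
        out[i * cols + j] = v;
    }
    out
}

// Rebuild the sparse tensor by keeping only explicit nonzeros.
fn from_data(data: &[f32], shape: (usize, usize)) -> Coo2 {
    let (_, cols) = shape;
    let mut coords = Vec::new();
    let mut values = Vec::new();
    for (idx, &v) in data.iter().enumerate() {
        if v != 0.0 {
            coords.push((idx / cols, idx % cols));
            values.push(v);
        }
    }
    Coo2 { shape, coords, values }
}

fn main() {
    let t = Coo2 { shape: (2, 2), coords: vec![(0, 1)], values: vec![3.0] };
    assert_eq!(from_data(&into_data(&t), t.shape), t);
}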
@@ -117,76 +129,82 @@ pub trait SparseFloatOps, B: Backend> { fn float_from_data( data: TensorData, device: &Device, - ) -> R::Primitive; + ) -> SR::SparsePrimitive; fn float_reshape( - tensor: R::Primitive, + tensor: SR::SparsePrimitive, shape: Shape, - ) -> R::Primitive; + ) -> SR::SparsePrimitive; - fn float_transpose(tensor: R::Primitive) -> R::Primitive; + fn float_transpose( + tensor: SR::SparsePrimitive, + ) -> SR::SparsePrimitive; fn float_swap_dims( - tensor: R::Primitive, + tensor: SR::SparsePrimitive, dim1: usize, dim2: usize, - ) -> R::Primitive; + ) -> SR::SparsePrimitive; fn float_permute( - tensor: R::Primitive, + tensor: SR::SparsePrimitive, axes: &[usize], - ) -> R::Primitive; + ) -> SR::SparsePrimitive; fn float_flip( - tensor: R::Primitive, + tensor: SR::SparsePrimitive, axes: &[usize], - ) -> R::Primitive; + ) -> SR::SparsePrimitive; fn float_slice_assign( - tensor: R::Primitive, + tensor: SR::SparsePrimitive, ranges: [Range; D2], - value: R::Primitive, - ) -> R::Primitive; + value: SR::SparsePrimitive, + ) -> SR::SparsePrimitive; fn float_repeat_dim( - tensor: R::Primitive, + tensor: SR::SparsePrimitive, dim: usize, times: usize, - ) -> R::Primitive; + ) -> SR::SparsePrimitive; fn float_cat( - tensors: Vec>, + tensors: Vec>, dim: usize, - ) -> R::Primitive; + ) -> SR::SparsePrimitive; fn float_equal( - lhs: R::Primitive, - rhs: R::Primitive, - ) -> R::Primitive; + lhs: SR::SparsePrimitive, + rhs: SR::SparsePrimitive, + ) -> SR::SparsePrimitive; fn float_not_equal( - lhs: R::Primitive, - rhs: R::Primitive, - ) -> R::Primitive; + lhs: SR::SparsePrimitive, + rhs: SR::SparsePrimitive, + ) -> SR::SparsePrimitive; - fn float_any(tensor: R::Primitive) -> R::Primitive; + fn float_any( + tensor: SR::SparsePrimitive, + ) -> SR::SparsePrimitive; fn float_any_dim( - tensor: R::Primitive, + tensor: SR::SparsePrimitive, dim: usize, - ) -> R::Primitive; + ) -> SR::SparsePrimitive; - fn float_all(tensor: R::Primitive) -> R::Primitive; + fn float_all( + tensor: SR::SparsePrimitive, + ) -> SR::SparsePrimitive; fn float_all_dim( - tensor: R::Primitive, + tensor: SR::SparsePrimitive, dim: usize, - ) -> R::Primitive; + ) -> SR::SparsePrimitive; fn float_expand( - tensor: R::Primitive, + tensor: SR::SparsePrimitive, shape: Shape, - ) -> R::Primitive; + ) -> SR::SparsePrimitive; /// Adds two sparse tensors together. /// @@ -199,9 +217,9 @@ pub trait SparseFloatOps, B: Backend> { /// /// The result of adding the two tensors together. fn float_add( - lhs: R::Primitive, - rhs: R::Primitive, - ) -> R::Primitive; + lhs: SR::SparsePrimitive, + rhs: SR::SparsePrimitive, + ) -> SR::SparsePrimitive; /// Subtracts two tensors. /// @@ -214,9 +232,9 @@ pub trait SparseFloatOps, B: Backend> { /// /// The result of subtracting the two tensors. fn float_sub( - lhs: R::Primitive, - rhs: R::Primitive, - ) -> R::Primitive; + lhs: SR::SparsePrimitive, + rhs: SR::SparsePrimitive, + ) -> SR::SparsePrimitive; /// Multiplies two sparse tensors together. /// @@ -229,9 +247,9 @@ pub trait SparseFloatOps, B: Backend> { /// /// The result of multiplying the two tensors together. fn float_mul( - lhs: R::Primitive, - rhs: R::Primitive, - ) -> R::Primitive; + lhs: SR::SparsePrimitive, + rhs: SR::SparsePrimitive, + ) -> SR::SparsePrimitive; /// Multiplies a scalar to a tensor. /// @@ -244,9 +262,9 @@ pub trait SparseFloatOps, B: Backend> { /// /// The result of multiplying the scalar with the tensor. 
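A note on the binary ops above: float_add and float_sub produce results over the union of the two operands' nonzero sets, while float_mul only needs the intersection, which makes multiplication the sparsity-friendly case. A toy sorted-COO merge showing the intersection (illustrative only):

// Elementwise multiply of two sparse vectors in (flat index, value) form,
// both sorted by index; only coordinates stored in both can be nonzero.
fn mul_sorted_coo(a: &[(usize, f32)], b: &[(usize, f32)]) -> Vec<(usize, f32)> {
    let (mut i, mut j) = (0, 0);
    let mut out = Vec::new();
    while i < a.len() && j < b.len() {
        match a[i].0.cmp(&b[j].0) {
            std::cmp::Ordering::Less => i += 1,
            std::cmp::Ordering::Greater => j += 1,
            std::cmp::Ordering::Equal => {
                out.push((a[i].0, a[i].1 * b[j].1));
                i += 1;
                j += 1;
            }
        }
    }
    out
}

fn main() {
    let a = [(0, 2.0), (3, 4.0)];
    let b = [(3, 0.5), (5, 1.0)];
    // Only index 3 is present in both operands.
    assert_eq!(mul_sorted_coo(&a, &b), vec![(3, 2.0)]);
}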
fn float_mul_scalar( - lhs: R::Primitive, + lhs: SR::SparsePrimitive, rhs: FloatElem, - ) -> R::Primitive; + ) -> SR::SparsePrimitive; /// Divides two sparse tensors. /// @@ -259,9 +277,9 @@ pub trait SparseFloatOps, B: Backend> { /// /// The result of dividing the two tensors. fn float_div( - lhs: R::Primitive, - rhs: R::Primitive, - ) -> R::Primitive; + lhs: SR::SparsePrimitive, + rhs: SR::SparsePrimitive, + ) -> SR::SparsePrimitive; /// Divides a tensor by a scalar. /// @@ -274,303 +292,338 @@ pub trait SparseFloatOps, B: Backend> { /// /// The result of dividing the tensor by the scalar. fn float_div_scalar( - lhs: R::Primitive, + lhs: SR::SparsePrimitive, rhs: FloatElem, - ) -> R::Primitive; + ) -> SR::SparsePrimitive; - fn float_max(tensor: R::Primitive) -> R::Primitive; + fn float_max( + tensor: SR::SparsePrimitive, + ) -> SR::SparsePrimitive; fn float_max_dim( - tensor: R::Primitive, + tensor: SR::SparsePrimitive, dim: usize, - ) -> R::Primitive; + ) -> SR::SparsePrimitive; - fn float_min(tensor: R::Primitive) -> R::Primitive; + fn float_min( + tensor: SR::SparsePrimitive, + ) -> SR::SparsePrimitive; fn float_min_dim( - tensor: R::Primitive, + tensor: SR::SparsePrimitive, dim: usize, - ) -> R::Primitive; + ) -> SR::SparsePrimitive; - fn float_abs(tensor: R::Primitive) -> R::Primitive; - fn float_sign(tensor: R::Primitive) -> R::Primitive; + fn float_abs( + tensor: SR::SparsePrimitive, + ) -> SR::SparsePrimitive; + fn float_sign( + tensor: SR::SparsePrimitive, + ) -> SR::SparsePrimitive; fn float_powf( - lhs: R::Primitive, - rhs: R::Primitive, - ) -> R::Primitive; + lhs: SR::SparsePrimitive, + rhs: SR::SparsePrimitive, + ) -> SR::SparsePrimitive; fn float_powi( - lhs: R::Primitive, - rhs: R::Primitive, - ) -> R::Primitive; + lhs: SR::SparsePrimitive, + rhs: SR::SparsePrimitive, + ) -> SR::SparsePrimitive; fn float_powf_scalar( - lhs: R::Primitive, + lhs: SR::SparsePrimitive, rhs: FloatElem, - ) -> R::Primitive; + ) -> SR::SparsePrimitive; fn float_powi_scalar( - lhs: R::Primitive, + lhs: SR::SparsePrimitive, rhs: FloatElem, - ) -> R::Primitive; + ) -> SR::SparsePrimitive; fn float_clamp( - tensor: R::Primitive, + tensor: SR::SparsePrimitive, min: FloatElem, max: FloatElem, - ) -> R::Primitive; + ) -> SR::SparsePrimitive; fn float_clamp_min( - tensor: R::Primitive, + tensor: SR::SparsePrimitive, min: FloatElem, - ) -> R::Primitive; + ) -> SR::SparsePrimitive; fn float_clamp_max( - tensor: R::Primitive, + tensor: SR::SparsePrimitive, max: FloatElem, - ) -> R::Primitive; + ) -> SR::SparsePrimitive; fn float_select( - tensor: R::Primitive, + tensor: SR::SparsePrimitive, dim: usize, indices: IntTensor, - ) -> R::Primitive; + ) -> SR::SparsePrimitive; fn float_select_assign( - tensor: R::Primitive, + tensor: SR::SparsePrimitive, dim: usize, indices: IntTensor, - values: R::Primitive, - ) -> R::Primitive; + values: SR::SparsePrimitive, + ) -> SR::SparsePrimitive; fn float_gather( dim: usize, - tensor: R::Primitive, + tensor: SR::SparsePrimitive, indices: IntTensor, - ) -> R::Primitive; + ) -> SR::SparsePrimitive; fn float_scatter( dim: usize, - tensor: R::Primitive, + tensor: SR::SparsePrimitive, indices: IntTensor, - values: R::Primitive, - ) -> R::Primitive; + values: SR::SparsePrimitive, + ) -> SR::SparsePrimitive; - fn float_sum(tensor: R::Primitive) -> R::Primitive; + fn float_sum( + tensor: SR::SparsePrimitive, + ) -> SR::SparsePrimitive; fn float_sum_dim( - tensor: R::Primitive, + tensor: SR::SparsePrimitive, dim: usize, - ) -> R::Primitive; + ) -> SR::SparsePrimitive; fn 
float_prod_dim( - tensor: R::Primitive, + tensor: SR::SparsePrimitive, dim: usize, - ) -> R::Primitive; + ) -> SR::SparsePrimitive; - fn float_mean(tensor: R::Primitive) -> R::Primitive; + fn float_mean( + tensor: SR::SparsePrimitive, + ) -> SR::SparsePrimitive; fn float_mean_dim( - tensor: R::Primitive, + tensor: SR::SparsePrimitive, dim: usize, - ) -> R::Primitive; + ) -> SR::SparsePrimitive; fn float_remainder_scalar( - lhs: R::Primitive, + lhs: SR::SparsePrimitive, rhs: FloatElem, - ) -> R::Primitive; + ) -> SR::SparsePrimitive; - fn float_neg(tensor: R::Primitive) -> R::Primitive; + fn float_neg( + tensor: SR::SparsePrimitive, + ) -> SR::SparsePrimitive; } -pub trait SparseBoolOps, B: Backend> { - fn bool_to_sparse(dense: B::BoolTensorPrimitive) -> R::Primitive; +pub trait SparseBoolOps, B: Backend> { + fn bool_to_sparse( + dense: B::BoolTensorPrimitive, + ) -> SR::SparsePrimitive; - fn bool_empty(shape: Shape, device: &Device) -> R::Primitive; + fn bool_empty( + shape: Shape, + device: &Device, + ) -> SR::SparsePrimitive; - fn bool_shape(tensor: &R::Primitive) -> Shape; + fn bool_shape(tensor: &SR::SparsePrimitive) -> Shape; fn bool_reshape( - tensor: R::Primitive, + tensor: SR::SparsePrimitive, shape: Shape, - ) -> R::Primitive; + ) -> SR::SparsePrimitive; - fn bool_transpose(tensor: R::Primitive) -> R::Primitive; + fn bool_transpose( + tensor: SR::SparsePrimitive, + ) -> SR::SparsePrimitive; fn bool_swap_dims( - tensor: R::Primitive, + tensor: SR::SparsePrimitive, dim1: usize, dim2: usize, - ) -> R::Primitive; + ) -> SR::SparsePrimitive; fn bool_permute( - tensor: R::Primitive, + tensor: SR::SparsePrimitive, axes: &[usize], - ) -> R::Primitive; + ) -> SR::SparsePrimitive; fn bool_flip( - tensor: R::Primitive, + tensor: SR::SparsePrimitive, axes: &[usize], - ) -> R::Primitive; + ) -> SR::SparsePrimitive; fn bool_slice( - tensor: R::Primitive, + tensor: SR::SparsePrimitive, indices: [Range; D2], - ) -> R::Primitive; + ) -> SR::SparsePrimitive; fn bool_slice_assign( - tensor: R::Primitive, + tensor: SR::SparsePrimitive, ranges: [Range; D2], - value: R::Primitive, - ) -> R::Primitive; + value: SR::SparsePrimitive, + ) -> SR::SparsePrimitive; - fn bool_device(tensor: &R::Primitive) -> Device; + fn bool_device(tensor: &SR::SparsePrimitive) -> Device; fn bool_to_device( - tensor: R::Primitive, + tensor: SR::SparsePrimitive, device: &Device, - ) -> R::Primitive; + ) -> SR::SparsePrimitive; fn bool_into_data( - tensor: R::Primitive, + tensor: SR::SparsePrimitive, ) -> impl Future + Send; fn bool_from_data( data: TensorData, device: &Device, - ) -> R::Primitive; + ) -> SR::SparsePrimitive; fn bool_repeat_dim( - tensor: R::Primitive, + tensor: SR::SparsePrimitive, dim: usize, times: usize, - ) -> R::Primitive; + ) -> SR::SparsePrimitive; fn bool_cat( - tensors: Vec>, + tensors: Vec>, dim: usize, - ) -> R::Primitive; + ) -> SR::SparsePrimitive; fn bool_equal( - lhs: R::Primitive, - rhs: R::Primitive, - ) -> R::Primitive; + lhs: SR::SparsePrimitive, + rhs: SR::SparsePrimitive, + ) -> SR::SparsePrimitive; fn bool_not_equal( - lhs: R::Primitive, - rhs: R::Primitive, - ) -> R::Primitive; + lhs: SR::SparsePrimitive, + rhs: SR::SparsePrimitive, + ) -> SR::SparsePrimitive; - fn bool_any(tensor: R::Primitive) -> R::Primitive; + fn bool_any( + tensor: SR::SparsePrimitive, + ) -> SR::SparsePrimitive; fn bool_any_dim( - tensor: R::Primitive, + tensor: SR::SparsePrimitive, dim: usize, - ) -> R::Primitive; + ) -> SR::SparsePrimitive; - fn bool_all(tensor: R::Primitive) -> R::Primitive; + fn bool_all( + 
tensor: SR::SparsePrimitive, + ) -> SR::SparsePrimitive; fn bool_all_dim( - tensor: R::Primitive, + tensor: SR::SparsePrimitive, dim: usize, - ) -> R::Primitive; + ) -> SR::SparsePrimitive; fn bool_expand( - tensor: R::Primitive, + tensor: SR::SparsePrimitive, shape: Shape, - ) -> R::Primitive; + ) -> SR::SparsePrimitive; } -pub trait SparseIntOps, B: Backend> { - fn int_empty(shape: Shape, device: &Device) -> R::Primitive; +pub trait SparseIntOps, B: Backend> { + fn int_empty( + shape: Shape, + device: &Device, + ) -> SR::SparsePrimitive; - fn int_shape(tensor: &R::Primitive) -> Shape; + fn int_shape(tensor: &SR::SparsePrimitive) -> Shape; fn int_reshape( - tensor: R::Primitive, + tensor: SR::SparsePrimitive, shape: Shape, - ) -> R::Primitive; + ) -> SR::SparsePrimitive; - fn int_transpose(tensor: R::Primitive) -> R::Primitive; + fn int_transpose( + tensor: SR::SparsePrimitive, + ) -> SR::SparsePrimitive; fn int_swap_dims( - tensor: R::Primitive, + tensor: SR::SparsePrimitive, dim1: usize, dim2: usize, - ) -> R::Primitive; + ) -> SR::SparsePrimitive; fn int_permute( - tensor: R::Primitive, + tensor: SR::SparsePrimitive, axes: &[usize], - ) -> R::Primitive; + ) -> SR::SparsePrimitive; fn int_flip( - tensor: R::Primitive, + tensor: SR::SparsePrimitive, axes: &[usize], - ) -> R::Primitive; + ) -> SR::SparsePrimitive; fn int_slice( - tensor: R::Primitive, + tensor: SR::SparsePrimitive, indices: [Range; D2], - ) -> R::Primitive; + ) -> SR::SparsePrimitive; fn int_slice_assign( - tensor: R::Primitive, + tensor: SR::SparsePrimitive, ranges: [Range; D2], - value: R::Primitive, - ) -> R::Primitive; + value: SR::SparsePrimitive, + ) -> SR::SparsePrimitive; - fn int_device(tensor: &R::Primitive) -> Device; + fn int_device(tensor: &SR::SparsePrimitive) -> Device; fn int_to_device( - tensor: R::Primitive, + tensor: SR::SparsePrimitive, device: &Device, - ) -> R::Primitive; + ) -> SR::SparsePrimitive; fn int_into_data( - tensor: R::Primitive, + tensor: SR::SparsePrimitive, ) -> impl Future + Send; - fn int_from_data(data: TensorData, device: &Device) -> R::Primitive; + fn int_from_data( + data: TensorData, + device: &Device, + ) -> SR::SparsePrimitive; fn int_repeat_dim( - tensor: R::Primitive, + tensor: SR::SparsePrimitive, dim: usize, times: usize, - ) -> R::Primitive; + ) -> SR::SparsePrimitive; fn int_cat( - tensors: Vec>, + tensors: Vec>, dim: usize, - ) -> R::Primitive; + ) -> SR::SparsePrimitive; fn int_equal( - lhs: R::Primitive, - rhs: R::Primitive, - ) -> R::Primitive; + lhs: SR::SparsePrimitive, + rhs: SR::SparsePrimitive, + ) -> SR::SparsePrimitive; fn int_not_equal( - lhs: R::Primitive, - rhs: R::Primitive, - ) -> R::Primitive; + lhs: SR::SparsePrimitive, + rhs: SR::SparsePrimitive, + ) -> SR::SparsePrimitive; - fn int_any(tensor: R::Primitive) -> R::Primitive; + fn int_any(tensor: SR::SparsePrimitive) + -> SR::SparsePrimitive; fn int_any_dim( - tensor: R::Primitive, + tensor: SR::SparsePrimitive, dim: usize, - ) -> R::Primitive; + ) -> SR::SparsePrimitive; - fn int_all(tensor: R::Primitive) -> R::Primitive; + fn int_all(tensor: SR::SparsePrimitive) + -> SR::SparsePrimitive; fn int_all_dim( - tensor: R::Primitive, + tensor: SR::SparsePrimitive, dim: usize, - ) -> R::Primitive; + ) -> SR::SparsePrimitive; fn int_expand( - tensor: R::Primitive, + tensor: SR::SparsePrimitive, shape: Shape, - ) -> R::Primitive; + ) -> SR::SparsePrimitive; } From 40d2afdb59eab8526bc32d854d01febbffe8249e Mon Sep 17 00:00:00 2001 From: McArthur-Alford Date: Thu, 22 Aug 2024 17:46:40 +1000 Subject: [PATCH 31/38] 
Cleanup of types --- crates/burn-tensor/src/tensor/api/base.rs | 43 +++++++++++-------- crates/burn-tensor/src/tensor/api/chunk.rs | 6 +-- crates/burn-tensor/src/tensor/api/narrow.rs | 6 +-- crates/burn-tensor/src/tensor/api/repr.rs | 13 ++---- .../src/tensor/api/sparse_float.rs | 4 +- .../src/tensor/ops/sparse_tensor.rs | 7 +-- 6 files changed, 40 insertions(+), 39 deletions(-) diff --git a/crates/burn-tensor/src/tensor/api/base.rs b/crates/burn-tensor/src/tensor/api/base.rs index 2f544ebdff..c57257c59a 100644 --- a/crates/burn-tensor/src/tensor/api/base.rs +++ b/crates/burn-tensor/src/tensor/api/base.rs @@ -18,9 +18,7 @@ use crate::check::TensorCheck; use crate::tensor::api::chunk::chunk; use crate::tensor::api::narrow::narrow; use crate::{backend::Backend, check, Bool, Float, Int, Shape, TensorData, TensorKind}; -use crate::{ - DType, Dense, Element, ReprPrimitive, TensorPrimitive, TensorRepr, TensorReprT, TensorStorage, -}; +use crate::{DType, Dense, Element, ReprPrimitive, TensorPrimitive, TensorRepr, TensorStorage}; /// A tensor with a given backend, shape and data type. #[derive(new, Clone, Debug)] @@ -29,9 +27,9 @@ where B: Backend, K: TensorKind, SR: TensorStorage, - TensorRepr: TensorReprT + TensorReprT, + (B, K, SR): TensorRepr, { - pub(crate) primitive: >::Primitive, + pub(crate) primitive: <(B, K, SR) as TensorRepr>::Primitive, } impl From for Tensor @@ -40,7 +38,8 @@ where K: BasicOps, SR: TensorStorage, T: Into, - TensorRepr: TensorReprT + TensorReprT, + (B, K, SR): TensorRepr, + (B, Bool, SR): TensorRepr, { fn from(value: T) -> Self { Tensor::from_data(value.into(), &Default::default()) @@ -53,7 +52,7 @@ where // R: TensorRepr, // K: TensorKind, // { -// fn change_repr>(self) -> Tensor +// fn change_repr, // R: ChangeRepr, @@ -67,7 +66,8 @@ where B: Backend, K: BasicOps, SR: TensorStorage, - TensorRepr: TensorReprT + TensorReprT, + (B, K, SR): TensorRepr, + (B, Bool, SR): TensorRepr, { /// Converts the tensor into a primitive tensor. pub fn into_primitive(self) -> ReprPrimitive { @@ -987,7 +987,8 @@ where K: BasicOps, SR: TensorStorage, Bool: TensorKind, - TensorRepr: TensorReprT + TensorReprT, + (B, K, SR): TensorRepr, + (B, Bool, SR): TensorRepr, { start: usize, end: usize, @@ -999,7 +1000,8 @@ where impl, SR: TensorStorage> Iterator for DimIter where - TensorRepr: TensorReprT + TensorReprT, + (B, K, SR): TensorRepr, + (B, Bool, SR): TensorRepr, { type Item = Tensor; @@ -1036,7 +1038,8 @@ impl> DoubleEndedIterator for DimIter impl, SR: TensorStorage> DimIter where - TensorRepr: TensorReprT + TensorReprT, + (B, K, SR): TensorRepr, + (B, Bool, SR): TensorRepr, { fn new(tensor: Tensor, dim: usize) -> Self { let dims = tensor.dims(); @@ -1327,8 +1330,8 @@ impl core::ops::BitXor for Tensor { /// This is an internal trait, use the public API provided by [tensor struct](Tensor). pub trait BasicOps = Dense>: TensorKind where - TensorRepr: TensorReprT, - TensorRepr: TensorReprT, + (B, Self, SR): TensorRepr, + (B, Bool, SR): TensorRepr, { /// The type of the tensor elements. 
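The cleanup above drops the TensorRepr helper struct in favour of implementing the trait directly on (B, K, SR) tuples, so a bound like (B, K, SR): TensorRepr reads as a type-level lookup. A self-contained sketch of the tuple-impl trick (toy types, not the crate's):

use core::fmt::Debug;

trait Lookup {
    type Primitive: Clone + Debug + Send;
}

struct Cpu;
struct FloatK;
struct DenseS;
struct SparseS;

// Each (backend, kind, storage) combination picks its own primitive.
impl Lookup for (Cpu, FloatK, DenseS) {
    type Primitive = Vec<f32>;
}

impl Lookup for (Cpu, FloatK, SparseS) {
    type Primitive = (Vec<usize>, Vec<f32>); // toy COO
}

// Generic code states the relation as a bound and never names the primitive.
fn primitive_name<T: Lookup>() -> &'static str {
    core::any::type_name::<T::Primitive>()
}

fn main() {
    println!("{}", primitive_name::<(Cpu, FloatK, DenseS)>());
    println!("{}", primitive_name::<(Cpu, FloatK, SparseS)>());
}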
type Elem: Element; @@ -2336,7 +2339,8 @@ pub trait ReshapeArgs { tensor: &Tensor, ) -> Shape where - TensorRepr: TensorReprT + TensorReprT; + (B, K, SR): TensorRepr, + (B, Bool, SR): TensorRepr; } impl ReshapeArgs for Shape { @@ -2345,8 +2349,8 @@ impl ReshapeArgs for Shape { tensor: &Tensor, ) -> Shape where - Bool: TensorKind, - TensorRepr: TensorReprT + TensorReprT, + (B, K, SR): TensorRepr, + (B, Bool, SR): TensorRepr, { check!(TensorCheck::reshape_args_usize(&tensor.shape(), &self)); @@ -2360,7 +2364,8 @@ impl ReshapeArgs for [usize; D2] { ) -> Shape where Bool: TensorKind, - TensorRepr: TensorReprT + TensorReprT, + (B, K, SR): TensorRepr, + (B, Bool, SR): TensorRepr, { let shape = Shape::from(self); @@ -2376,8 +2381,8 @@ impl ReshapeArgs for [i32; D2] { tensor: &Tensor, ) -> Shape where - TensorRepr: TensorReprT + TensorReprT, - Bool: TensorKind, + (B, K, SR): TensorRepr, + (B, Bool, SR): TensorRepr, { // Validate the reshape arguments check!(TensorCheck::reshape_args_i32(&self)); diff --git a/crates/burn-tensor/src/tensor/api/chunk.rs b/crates/burn-tensor/src/tensor/api/chunk.rs index bded0a5cb1..07a266dd8d 100644 --- a/crates/burn-tensor/src/tensor/api/chunk.rs +++ b/crates/burn-tensor/src/tensor/api/chunk.rs @@ -1,7 +1,6 @@ use super::narrow::narrow; use crate::{ - backend::Backend, BasicOps, Bool, Dense, ReprPrimitive, TensorKind, TensorRepr, TensorReprT, - TensorStorage, + backend::Backend, BasicOps, Bool, Dense, ReprPrimitive, TensorKind, TensorRepr, TensorStorage, }; use alloc::vec::Vec; @@ -29,7 +28,8 @@ pub fn chunk + BasicOps, SR: dim: usize, ) -> Vec> where - TensorRepr: TensorReprT + TensorReprT, + (B, K, SR): TensorRepr, + (B, Bool, SR): TensorRepr, { let size = K::shape(&tensor).dims[dim]; if size < chunks { diff --git a/crates/burn-tensor/src/tensor/api/narrow.rs b/crates/burn-tensor/src/tensor/api/narrow.rs index a68987e583..be7299e385 100644 --- a/crates/burn-tensor/src/tensor/api/narrow.rs +++ b/crates/burn-tensor/src/tensor/api/narrow.rs @@ -1,6 +1,5 @@ use crate::{ - backend::Backend, BasicOps, Bool, Dense, ReprPrimitive, TensorKind, TensorRepr, TensorReprT, - TensorStorage, + backend::Backend, BasicOps, Bool, Dense, ReprPrimitive, TensorKind, TensorRepr, TensorStorage, }; use alloc::vec::Vec; @@ -32,7 +31,8 @@ pub fn narrow< length: usize, ) -> ReprPrimitive where - TensorRepr: TensorReprT + TensorReprT, + (B, K, SR): TensorRepr, + (B, Bool, SR): TensorRepr, { let shape = K::shape(&tensor); diff --git a/crates/burn-tensor/src/tensor/api/repr.rs b/crates/burn-tensor/src/tensor/api/repr.rs index 7b6be778bb..dfc91d3eb1 100644 --- a/crates/burn-tensor/src/tensor/api/repr.rs +++ b/crates/burn-tensor/src/tensor/api/repr.rs @@ -1,20 +1,15 @@ use crate::{backend::Backend, Dense, Float, Sparse, SparseStorage, TensorKind, TensorStorage}; -pub type ReprPrimitive = - >::Primitive; +pub type ReprPrimitive = <(B, K, S) as TensorRepr>::Primitive; -pub trait TensorReprT, S: TensorStorage> { +pub trait TensorRepr { type Primitive: Clone + core::fmt::Debug + Send; } -pub struct TensorRepr; - -impl> TensorReprT for TensorRepr { +impl> TensorRepr for (B, K, Dense) { type Primitive = K::Primitive; } -impl, SR: SparseStorage> TensorReprT> - for TensorRepr -{ +impl, SR: SparseStorage> TensorRepr for (B, K, Sparse) { type Primitive = SR::SparsePrimitive; } diff --git a/crates/burn-tensor/src/tensor/api/sparse_float.rs b/crates/burn-tensor/src/tensor/api/sparse_float.rs index 8e51d274e9..b675c5d86a 100644 --- a/crates/burn-tensor/src/tensor/api/sparse_float.rs +++ 
b/crates/burn-tensor/src/tensor/api/sparse_float.rs @@ -1,11 +1,11 @@ use crate::{backend::Backend, check::TensorCheck, Dense, Float, Sparse, Tensor, TensorKind}; -use crate::{check, Bool, SparseStorage, TensorPrimitive, TensorRepr, TensorReprT}; +use crate::{check, Bool, SparseStorage, TensorPrimitive, TensorRepr}; impl Tensor> where B: Backend, SR: SparseStorage, - TensorRepr: TensorReprT> + TensorReprT>, + (B, Float, Sparse): TensorRepr, { /// Executes an operation on the tensor and modifies its value. /// diff --git a/crates/burn-tensor/src/tensor/ops/sparse_tensor.rs b/crates/burn-tensor/src/tensor/ops/sparse_tensor.rs index d07d8415d6..a2c8d41101 100644 --- a/crates/burn-tensor/src/tensor/ops/sparse_tensor.rs +++ b/crates/burn-tensor/src/tensor/ops/sparse_tensor.rs @@ -1,7 +1,8 @@ use super::{BoolTensor, FloatElem, FloatTensor, IntTensor, QuantizedTensor}; +use crate::TensorRepr; use crate::{ backend::Backend, Bool, Device, Float, Int, ReprPrimitive, Shape, Sparse, SparseStorage, - TensorData, TensorKind, TensorRepr, TensorReprT, + TensorData, TensorKind, }; use core::{future::Future, ops::Range}; @@ -12,8 +13,8 @@ pub trait SparseTensorOps, B: Backend>: pub trait SparseFloatOps, B: Backend> where - TensorRepr: TensorReprT>, - TensorRepr: TensorReprT>, + (B, Float, Sparse): TensorRepr, + (B, Bool, Sparse): TensorRepr, { fn float_to_sparse( dense: B::FloatTensorPrimitive, From 1c06aab5a2dfea80ad07214961579fee7bfcecac Mon Sep 17 00:00:00 2001 From: McArthur-Alford Date: Sat, 24 Aug 2024 16:45:49 +1000 Subject: [PATCH 32/38] BasicSparseOps & into/from sparse --- crates/burn-tensor/src/tensor/api/mod.rs | 2 + crates/burn-tensor/src/tensor/api/repr.rs | 4 +- crates/burn-tensor/src/tensor/api/sparse.rs | 36 ++- .../src/tensor/api/sparse_tensor.rs | 30 +++ .../src/tensor/ops/sparse_tensor.rs | 210 +++++++++--------- 5 files changed, 175 insertions(+), 107 deletions(-) create mode 100644 crates/burn-tensor/src/tensor/api/sparse_tensor.rs diff --git a/crates/burn-tensor/src/tensor/api/mod.rs b/crates/burn-tensor/src/tensor/api/mod.rs index e4f9f8b872..4e4f94bb1b 100644 --- a/crates/burn-tensor/src/tensor/api/mod.rs +++ b/crates/burn-tensor/src/tensor/api/mod.rs @@ -16,6 +16,7 @@ mod sort; mod sparse; mod sparse_float; mod sparse_numeric; +mod sparse_tensor; mod storage; pub use argwhere::argwhere_data; @@ -30,4 +31,5 @@ pub use repr::*; pub use sort::{argsort, sort, sort_with_indices}; pub use sparse::*; pub use sparse_numeric::*; +pub use sparse_tensor::*; pub use storage::*; diff --git a/crates/burn-tensor/src/tensor/api/repr.rs b/crates/burn-tensor/src/tensor/api/repr.rs index dfc91d3eb1..68dfef6b34 100644 --- a/crates/burn-tensor/src/tensor/api/repr.rs +++ b/crates/burn-tensor/src/tensor/api/repr.rs @@ -1,4 +1,6 @@ -use crate::{backend::Backend, Dense, Float, Sparse, SparseStorage, TensorKind, TensorStorage}; +use crate::{ + backend::Backend, Dense, Float, Sparse, SparseStorage, Tensor, TensorKind, TensorStorage, +}; pub type ReprPrimitive = <(B, K, S) as TensorRepr>::Primitive; diff --git a/crates/burn-tensor/src/tensor/api/sparse.rs b/crates/burn-tensor/src/tensor/api/sparse.rs index 2386cfedd2..75df0a51f2 100644 --- a/crates/burn-tensor/src/tensor/api/sparse.rs +++ b/crates/burn-tensor/src/tensor/api/sparse.rs @@ -1,12 +1,42 @@ use crate::{ - backend::Backend, check::TensorCheck, BasicOps, Bool, DType, Device, Element, Float, Int, - ReprPrimitive, Shape, Sparse, SparseStorage, Tensor, TensorData, TensorKind, TensorPrimitive, - TensorStorage, + backend::Backend, check::TensorCheck, 
BasicOps, Bool, DType, Dense, Device, Element, Float, + Int, ReprPrimitive, Shape, Sparse, SparseStorage, Tensor, TensorData, TensorKind, + TensorPrimitive, TensorRepr, TensorStorage, }; use core::{future::Future, ops::Range}; use crate::check; +pub trait BasicSparseOps, SR: SparseStorage> +where + (B, K, Sparse): TensorRepr, +{ + fn into_dense( + tensor: ReprPrimitive, D>, + ) -> ReprPrimitive; + + fn into_sparse( + tensor: ReprPrimitive, + ) -> ReprPrimitive, D>; +} + +impl> BasicSparseOps for SR +where + (B, Float, Sparse): TensorRepr, +{ + fn into_dense( + tensor: ReprPrimitive, D>, + ) -> ReprPrimitive { + TensorPrimitive::Float(SR::float_to_dense(tensor)) + } + + fn into_sparse( + tensor: ReprPrimitive, + ) -> ReprPrimitive, D> { + SR::float_to_sparse(tensor.tensor()) + } +} + impl> BasicOps> for Float { type Elem = B::FloatElem; diff --git a/crates/burn-tensor/src/tensor/api/sparse_tensor.rs b/crates/burn-tensor/src/tensor/api/sparse_tensor.rs new file mode 100644 index 0000000000..65fd160e68 --- /dev/null +++ b/crates/burn-tensor/src/tensor/api/sparse_tensor.rs @@ -0,0 +1,30 @@ +use crate::{backend::Backend, check::TensorCheck, Dense, Float, Sparse, Tensor, TensorKind}; +use crate::{check, BasicOps, BasicSparseOps, Bool, SparseStorage, TensorPrimitive, TensorRepr}; + +impl Tensor +where + B: Backend, + K: TensorKind, +{ + pub fn into_sparse + BasicSparseOps>( + self, + ) -> Tensor> + where + K: BasicOps>, + (B, K, Sparse): TensorRepr, + { + Tensor::>::from_primitive(SR::into_sparse(self.primitive)) + } +} + +impl Tensor> +where + B: Backend, + K: TensorKind + BasicOps> + BasicOps, + SR: SparseStorage + BasicSparseOps, + (B, K, Sparse): TensorRepr, +{ + pub fn into_dense(self) -> Tensor { + Tensor::::from_primitive(SR::into_dense(self.primitive)) + } +} diff --git a/crates/burn-tensor/src/tensor/ops/sparse_tensor.rs b/crates/burn-tensor/src/tensor/ops/sparse_tensor.rs index a2c8d41101..c3f70e2034 100644 --- a/crates/burn-tensor/src/tensor/ops/sparse_tensor.rs +++ b/crates/burn-tensor/src/tensor/ops/sparse_tensor.rs @@ -18,15 +18,15 @@ where { fn float_to_sparse( dense: B::FloatTensorPrimitive, - ) -> SR::SparsePrimitive; + ) -> ReprPrimitive, D>; fn float_empty( shape: Shape, device: &Device, - ) -> SR::SparsePrimitive; + ) -> ReprPrimitive, D>; fn float_to_dense( - sparse: SR::SparsePrimitive, + sparse: ReprPrimitive, D>, ) -> B::FloatTensorPrimitive; fn float_spmm( @@ -37,20 +37,22 @@ where fn float_sddmm( lhs: B::FloatTensorPrimitive, rhs: B::FloatTensorPrimitive, - sparse: SR::SparsePrimitive, - ) -> SR::SparsePrimitive; + sparse: ReprPrimitive, D>, + ) -> ReprPrimitive, D>; fn float_coalesce_sum( - tensor: SR::SparsePrimitive, - ) -> SR::SparsePrimitive; + tensor: ReprPrimitive, D>, + ) -> ReprPrimitive, D>; fn float_remove_zeros( - tensor: SR::SparsePrimitive, - ) -> SR::SparsePrimitive; + tensor: ReprPrimitive, D>, + ) -> ReprPrimitive, D>; - fn float_number_nonzero(tensor: SR::SparsePrimitive) -> usize; + fn float_number_nonzero( + tensor: ReprPrimitive, D>, + ) -> usize; - fn float_density(sparse: SR::SparsePrimitive) -> f32; + fn float_density(sparse: ReprPrimitive, D>) -> f32; /// Gets the element at the given indices. /// @@ -76,7 +78,9 @@ where /// # Returns /// /// The device of the tensor. - fn float_device(tensor: &SR::SparsePrimitive) -> Device; + fn float_device( + tensor: &ReprPrimitive, D>, + ) -> Device; /// Moves the tensor to the given device. /// @@ -89,9 +93,9 @@ where /// /// The tensor on the given device. 
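Patch 32's BasicSparseOps reduces dense/sparse conversion to one trait with two directions, which is what the Tensor::into_sparse/into_dense methods above lean on. A toy mirror of that design with Vec-based stand-ins for both primitives (assumed shapes, not the crate's API):

trait SparseConvert {
    type Dense;
    type Sparse;
    fn into_sparse(dense: Self::Dense) -> Self::Sparse;
    fn into_dense(sparse: Self::Sparse) -> Self::Dense;
}

struct Coo;

impl SparseConvert for Coo {
    type Dense = Vec<f32>;
    type Sparse = Vec<(usize, f32)>; // (index, value) pairs

    fn into_sparse(dense: Self::Dense) -> Self::Sparse {
        // Keep only explicit nonzeros.
        dense
            .into_iter()
            .enumerate()
            .filter(|&(_, v)| v != 0.0)
            .collect()
    }

    fn into_dense(sparse: Self::Sparse) -> Self::Dense {
        // Toy simplification: the length is inferred from the last stored
        // index, so trailing implicit zeros are dropped.
        let len = sparse.iter().map(|&(i, _)| i + 1).max().unwrap_or(0);
        let mut out = vec![0.0; len];
        for (i, v) in sparse {
            out[i] = v;
        }
        out
    }
}

fn main() {
    let dense = vec![0.0, 5.0, 0.0, 2.0];
    let sparse = Coo::into_sparse(dense.clone());
    assert_eq!(sparse, vec![(1, 5.0), (3, 2.0)]);
    assert_eq!(Coo::into_dense(sparse), dense);
}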
fn float_to_device( - tensor: SR::SparsePrimitive, + tensor: ReprPrimitive, D>, device: &Device, - ) -> SR::SparsePrimitive; + ) -> ReprPrimitive, D>; /// Gets the shape of the tensor. /// @@ -102,7 +106,7 @@ where /// # Returns /// /// The shape of the tensor. - fn float_shape(tensor: &SR::SparsePrimitive) -> Shape; + fn float_shape(tensor: &ReprPrimitive, D>) -> Shape; /// Converts the tensor to a data structure. /// @@ -114,7 +118,7 @@ where /// /// The data structure with the tensor's data. fn float_into_data( - tensor: SR::SparsePrimitive, + tensor: ReprPrimitive, D>, ) -> impl Future + Send; /// Creates a tensor from the data structure. @@ -130,7 +134,7 @@ where fn float_from_data( data: TensorData, device: &Device, - ) -> SR::SparsePrimitive; + ) -> ReprPrimitive, D>; fn float_reshape( tensor: SR::SparsePrimitive, @@ -138,24 +142,24 @@ where ) -> SR::SparsePrimitive; fn float_transpose( - tensor: SR::SparsePrimitive, - ) -> SR::SparsePrimitive; + tensor: ReprPrimitive, D>, + ) -> ReprPrimitive, D>; fn float_swap_dims( - tensor: SR::SparsePrimitive, + tensor: ReprPrimitive, D>, dim1: usize, dim2: usize, - ) -> SR::SparsePrimitive; + ) -> ReprPrimitive, D>; fn float_permute( - tensor: SR::SparsePrimitive, + tensor: ReprPrimitive, D>, axes: &[usize], - ) -> SR::SparsePrimitive; + ) -> ReprPrimitive, D>; fn float_flip( - tensor: SR::SparsePrimitive, + tensor: ReprPrimitive, D>, axes: &[usize], - ) -> SR::SparsePrimitive; + ) -> ReprPrimitive, D>; fn float_slice_assign( tensor: SR::SparsePrimitive, @@ -164,41 +168,41 @@ where ) -> SR::SparsePrimitive; fn float_repeat_dim( - tensor: SR::SparsePrimitive, + tensor: ReprPrimitive, D>, dim: usize, times: usize, - ) -> SR::SparsePrimitive; + ) -> ReprPrimitive, D>; fn float_cat( - tensors: Vec>, + tensors: Vec, D>>, dim: usize, - ) -> SR::SparsePrimitive; + ) -> ReprPrimitive, D>; fn float_equal( - lhs: SR::SparsePrimitive, - rhs: SR::SparsePrimitive, + lhs: ReprPrimitive, D>, + rhs: ReprPrimitive, D>, ) -> SR::SparsePrimitive; fn float_not_equal( - lhs: SR::SparsePrimitive, - rhs: SR::SparsePrimitive, + lhs: ReprPrimitive, D>, + rhs: ReprPrimitive, D>, ) -> SR::SparsePrimitive; fn float_any( - tensor: SR::SparsePrimitive, + tensor: ReprPrimitive, D>, ) -> SR::SparsePrimitive; fn float_any_dim( - tensor: SR::SparsePrimitive, + tensor: ReprPrimitive, D>, dim: usize, ) -> SR::SparsePrimitive; fn float_all( - tensor: SR::SparsePrimitive, + tensor: ReprPrimitive, D>, ) -> SR::SparsePrimitive; fn float_all_dim( - tensor: SR::SparsePrimitive, + tensor: ReprPrimitive, D>, dim: usize, ) -> SR::SparsePrimitive; @@ -218,9 +222,9 @@ where /// /// The result of adding the two tensors together. fn float_add( - lhs: SR::SparsePrimitive, - rhs: SR::SparsePrimitive, - ) -> SR::SparsePrimitive; + lhs: ReprPrimitive, D>, + rhs: ReprPrimitive, D>, + ) -> ReprPrimitive, D>; /// Subtracts two tensors. /// @@ -233,9 +237,9 @@ where /// /// The result of subtracting the two tensors. fn float_sub( - lhs: SR::SparsePrimitive, - rhs: SR::SparsePrimitive, - ) -> SR::SparsePrimitive; + lhs: ReprPrimitive, D>, + rhs: ReprPrimitive, D>, + ) -> ReprPrimitive, D>; /// Multiplies two sparse tensors together. /// @@ -248,9 +252,9 @@ where /// /// The result of multiplying the two tensors together. fn float_mul( - lhs: SR::SparsePrimitive, - rhs: SR::SparsePrimitive, - ) -> SR::SparsePrimitive; + lhs: ReprPrimitive, D>, + rhs: ReprPrimitive, D>, + ) -> ReprPrimitive, D>; /// Multiplies a scalar to a tensor. 
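The scalar ops kept in this trait are the zero-preserving ones: scaling never creates a stored value where an implicit zero was (and s == 0 only removes values), whereas adding a scalar would densify the whole tensor. A toy sketch of that invariant (illustrative only):

// Scalar multiply over a toy (index, value) sparse vector.
fn mul_scalar(sparse: &mut Vec<(usize, f32)>, s: f32) {
    if s == 0.0 {
        // Every entry becomes an implicit zero, so nothing is stored.
        sparse.clear();
        return;
    }
    for (_, v) in sparse.iter_mut() {
        *v *= s; // the set of stored coordinates is unchanged
    }
}

fn main() {
    let mut t = vec![(2, 1.5), (7, -4.0)];
    mul_scalar(&mut t, 2.0);
    assert_eq!(t, vec![(2, 3.0), (7, -8.0)]);
}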
/// @@ -263,9 +267,9 @@ where /// /// The result of multiplying the scalar with the tensor. fn float_mul_scalar( - lhs: SR::SparsePrimitive, + lhs: ReprPrimitive, D>, rhs: FloatElem, - ) -> SR::SparsePrimitive; + ) -> ReprPrimitive, D>; /// Divides two sparse tensors. /// @@ -278,9 +282,9 @@ where /// /// The result of dividing the two tensors. fn float_div( - lhs: SR::SparsePrimitive, - rhs: SR::SparsePrimitive, - ) -> SR::SparsePrimitive; + lhs: ReprPrimitive, D>, + rhs: ReprPrimitive, D>, + ) -> ReprPrimitive, D>; /// Divides a tensor by a scalar. /// @@ -293,128 +297,128 @@ where /// /// The result of dividing the tensor by the scalar. fn float_div_scalar( - lhs: SR::SparsePrimitive, + lhs: ReprPrimitive, D>, rhs: FloatElem, - ) -> SR::SparsePrimitive; + ) -> ReprPrimitive, D>; fn float_max( - tensor: SR::SparsePrimitive, - ) -> SR::SparsePrimitive; + tensor: ReprPrimitive, D>, + ) -> ReprPrimitive, D>; fn float_max_dim( - tensor: SR::SparsePrimitive, + tensor: ReprPrimitive, D>, dim: usize, - ) -> SR::SparsePrimitive; + ) -> ReprPrimitive, D>; fn float_min( - tensor: SR::SparsePrimitive, - ) -> SR::SparsePrimitive; + tensor: ReprPrimitive, D>, + ) -> ReprPrimitive, D>; fn float_min_dim( - tensor: SR::SparsePrimitive, + tensor: ReprPrimitive, D>, dim: usize, - ) -> SR::SparsePrimitive; + ) -> ReprPrimitive, D>; fn float_abs( - tensor: SR::SparsePrimitive, - ) -> SR::SparsePrimitive; + tensor: ReprPrimitive, D>, + ) -> ReprPrimitive, D>; fn float_sign( - tensor: SR::SparsePrimitive, - ) -> SR::SparsePrimitive; + tensor: ReprPrimitive, D>, + ) -> ReprPrimitive, D>; fn float_powf( - lhs: SR::SparsePrimitive, - rhs: SR::SparsePrimitive, - ) -> SR::SparsePrimitive; + lhs: ReprPrimitive, D>, + rhs: ReprPrimitive, D>, + ) -> ReprPrimitive, D>; fn float_powi( - lhs: SR::SparsePrimitive, - rhs: SR::SparsePrimitive, - ) -> SR::SparsePrimitive; + lhs: ReprPrimitive, D>, + rhs: ReprPrimitive, D>, + ) -> ReprPrimitive, D>; fn float_powf_scalar( - lhs: SR::SparsePrimitive, + lhs: ReprPrimitive, D>, rhs: FloatElem, - ) -> SR::SparsePrimitive; + ) -> ReprPrimitive, D>; fn float_powi_scalar( - lhs: SR::SparsePrimitive, + lhs: ReprPrimitive, D>, rhs: FloatElem, - ) -> SR::SparsePrimitive; + ) -> ReprPrimitive, D>; fn float_clamp( - tensor: SR::SparsePrimitive, + tensor: ReprPrimitive, D>, min: FloatElem, max: FloatElem, - ) -> SR::SparsePrimitive; + ) -> ReprPrimitive, D>; fn float_clamp_min( - tensor: SR::SparsePrimitive, + tensor: ReprPrimitive, D>, min: FloatElem, - ) -> SR::SparsePrimitive; + ) -> ReprPrimitive, D>; fn float_clamp_max( - tensor: SR::SparsePrimitive, + tensor: ReprPrimitive, D>, max: FloatElem, - ) -> SR::SparsePrimitive; + ) -> ReprPrimitive, D>; fn float_select( - tensor: SR::SparsePrimitive, + tensor: ReprPrimitive, D>, dim: usize, indices: IntTensor, - ) -> SR::SparsePrimitive; + ) -> ReprPrimitive, D>; fn float_select_assign( - tensor: SR::SparsePrimitive, + tensor: ReprPrimitive, D>, dim: usize, indices: IntTensor, - values: SR::SparsePrimitive, - ) -> SR::SparsePrimitive; + values: ReprPrimitive, D>, + ) -> ReprPrimitive, D>; fn float_gather( dim: usize, - tensor: SR::SparsePrimitive, + tensor: ReprPrimitive, D>, indices: IntTensor, - ) -> SR::SparsePrimitive; + ) -> ReprPrimitive, D>; fn float_scatter( dim: usize, - tensor: SR::SparsePrimitive, + tensor: ReprPrimitive, D>, indices: IntTensor, - values: SR::SparsePrimitive, - ) -> SR::SparsePrimitive; + values: ReprPrimitive, D>, + ) -> ReprPrimitive, D>; fn float_sum( - tensor: SR::SparsePrimitive, + tensor: ReprPrimitive, 
D>, ) -> SR::SparsePrimitive; fn float_sum_dim( - tensor: SR::SparsePrimitive, + tensor: ReprPrimitive, D>, dim: usize, - ) -> SR::SparsePrimitive; + ) -> ReprPrimitive, D>; fn float_prod_dim( - tensor: SR::SparsePrimitive, + tensor: ReprPrimitive, D>, dim: usize, - ) -> SR::SparsePrimitive; + ) -> ReprPrimitive, D>; fn float_mean( - tensor: SR::SparsePrimitive, + tensor: ReprPrimitive, D>, ) -> SR::SparsePrimitive; fn float_mean_dim( - tensor: SR::SparsePrimitive, + tensor: ReprPrimitive, D>, dim: usize, - ) -> SR::SparsePrimitive; + ) -> ReprPrimitive, D>; fn float_remainder_scalar( - lhs: SR::SparsePrimitive, + lhs: ReprPrimitive, D>, rhs: FloatElem, - ) -> SR::SparsePrimitive; + ) -> ReprPrimitive, D>; fn float_neg( - tensor: SR::SparsePrimitive, - ) -> SR::SparsePrimitive; + tensor: ReprPrimitive, D>, + ) -> ReprPrimitive, D>; } pub trait SparseBoolOps, B: Backend> { From 4535e37c82bc902c7c6f953011abb6191cb35ccb Mon Sep 17 00:00:00 2001 From: mcarthur Date: Sun, 25 Aug 2024 09:36:13 +0000 Subject: [PATCH 33/38] Big cleanup of burn-sparse --- crates/burn-sparse/{src => old}/backend/alias.rs | 0 crates/burn-sparse/{src => old}/backend/api.rs | 0 crates/burn-sparse/{src => old}/backend/kind.rs | 0 crates/burn-sparse/{src => old}/backend/mod.rs | 0 .../{src => old}/backend/sparse_backend.rs | 0 crates/burn-sparse/{src => old}/decorator/backend.rs | 0 crates/burn-sparse/{src => old}/decorator/mod.rs | 0 crates/burn-sparse/{src => old}/decorator/ops.rs | 0 .../{src => old}/decorator/precision_bridge.rs | 0 .../{src => old}/decorator/representation.rs | 0 .../burn-sparse/{src => old}/decorator/sparse_coo.rs | 0 .../burn-sparse/{src => old}/decorator/sparse_csr.rs | 0 crates/burn-sparse/src/{decorator => }/coo.rs | 0 crates/burn-sparse/src/{decorator => }/coo_bool.rs | 0 crates/burn-sparse/src/{decorator => }/coo_float.rs | 2 -- crates/burn-sparse/src/{decorator => }/coo_int.rs | 0 crates/burn-sparse/src/lib.rs | 11 +++++++++-- 17 files changed, 9 insertions(+), 4 deletions(-) rename crates/burn-sparse/{src => old}/backend/alias.rs (100%) rename crates/burn-sparse/{src => old}/backend/api.rs (100%) rename crates/burn-sparse/{src => old}/backend/kind.rs (100%) rename crates/burn-sparse/{src => old}/backend/mod.rs (100%) rename crates/burn-sparse/{src => old}/backend/sparse_backend.rs (100%) rename crates/burn-sparse/{src => old}/decorator/backend.rs (100%) rename crates/burn-sparse/{src => old}/decorator/mod.rs (100%) rename crates/burn-sparse/{src => old}/decorator/ops.rs (100%) rename crates/burn-sparse/{src => old}/decorator/precision_bridge.rs (100%) rename crates/burn-sparse/{src => old}/decorator/representation.rs (100%) rename crates/burn-sparse/{src => old}/decorator/sparse_coo.rs (100%) rename crates/burn-sparse/{src => old}/decorator/sparse_csr.rs (100%) rename crates/burn-sparse/src/{decorator => }/coo.rs (100%) rename crates/burn-sparse/src/{decorator => }/coo_bool.rs (100%) rename crates/burn-sparse/src/{decorator => }/coo_float.rs (99%) rename crates/burn-sparse/src/{decorator => }/coo_int.rs (100%) diff --git a/crates/burn-sparse/src/backend/alias.rs b/crates/burn-sparse/old/backend/alias.rs similarity index 100% rename from crates/burn-sparse/src/backend/alias.rs rename to crates/burn-sparse/old/backend/alias.rs diff --git a/crates/burn-sparse/src/backend/api.rs b/crates/burn-sparse/old/backend/api.rs similarity index 100% rename from crates/burn-sparse/src/backend/api.rs rename to crates/burn-sparse/old/backend/api.rs diff --git a/crates/burn-sparse/src/backend/kind.rs 
b/crates/burn-sparse/old/backend/kind.rs similarity index 100% rename from crates/burn-sparse/src/backend/kind.rs rename to crates/burn-sparse/old/backend/kind.rs diff --git a/crates/burn-sparse/src/backend/mod.rs b/crates/burn-sparse/old/backend/mod.rs similarity index 100% rename from crates/burn-sparse/src/backend/mod.rs rename to crates/burn-sparse/old/backend/mod.rs diff --git a/crates/burn-sparse/src/backend/sparse_backend.rs b/crates/burn-sparse/old/backend/sparse_backend.rs similarity index 100% rename from crates/burn-sparse/src/backend/sparse_backend.rs rename to crates/burn-sparse/old/backend/sparse_backend.rs diff --git a/crates/burn-sparse/src/decorator/backend.rs b/crates/burn-sparse/old/decorator/backend.rs similarity index 100% rename from crates/burn-sparse/src/decorator/backend.rs rename to crates/burn-sparse/old/decorator/backend.rs diff --git a/crates/burn-sparse/src/decorator/mod.rs b/crates/burn-sparse/old/decorator/mod.rs similarity index 100% rename from crates/burn-sparse/src/decorator/mod.rs rename to crates/burn-sparse/old/decorator/mod.rs diff --git a/crates/burn-sparse/src/decorator/ops.rs b/crates/burn-sparse/old/decorator/ops.rs similarity index 100% rename from crates/burn-sparse/src/decorator/ops.rs rename to crates/burn-sparse/old/decorator/ops.rs diff --git a/crates/burn-sparse/src/decorator/precision_bridge.rs b/crates/burn-sparse/old/decorator/precision_bridge.rs similarity index 100% rename from crates/burn-sparse/src/decorator/precision_bridge.rs rename to crates/burn-sparse/old/decorator/precision_bridge.rs diff --git a/crates/burn-sparse/src/decorator/representation.rs b/crates/burn-sparse/old/decorator/representation.rs similarity index 100% rename from crates/burn-sparse/src/decorator/representation.rs rename to crates/burn-sparse/old/decorator/representation.rs diff --git a/crates/burn-sparse/src/decorator/sparse_coo.rs b/crates/burn-sparse/old/decorator/sparse_coo.rs similarity index 100% rename from crates/burn-sparse/src/decorator/sparse_coo.rs rename to crates/burn-sparse/old/decorator/sparse_coo.rs diff --git a/crates/burn-sparse/src/decorator/sparse_csr.rs b/crates/burn-sparse/old/decorator/sparse_csr.rs similarity index 100% rename from crates/burn-sparse/src/decorator/sparse_csr.rs rename to crates/burn-sparse/old/decorator/sparse_csr.rs diff --git a/crates/burn-sparse/src/decorator/coo.rs b/crates/burn-sparse/src/coo.rs similarity index 100% rename from crates/burn-sparse/src/decorator/coo.rs rename to crates/burn-sparse/src/coo.rs diff --git a/crates/burn-sparse/src/decorator/coo_bool.rs b/crates/burn-sparse/src/coo_bool.rs similarity index 100% rename from crates/burn-sparse/src/decorator/coo_bool.rs rename to crates/burn-sparse/src/coo_bool.rs diff --git a/crates/burn-sparse/src/decorator/coo_float.rs b/crates/burn-sparse/src/coo_float.rs similarity index 99% rename from crates/burn-sparse/src/decorator/coo_float.rs rename to crates/burn-sparse/src/coo_float.rs index ba6e8f9447..8149cd8f04 100644 --- a/crates/burn-sparse/src/decorator/coo_float.rs +++ b/crates/burn-sparse/src/coo_float.rs @@ -171,8 +171,6 @@ impl SparseFloatOps for COO { Tensor::zeros([out_shape.dims[0], rhs.shape().dims[1]], &device); let gathered = rhs.gather(0, gather_index); - println!("{}", gathered); - println!("{}", values); let multiplied = gathered.mul(values); let scattered = output.scatter(0, scatter_index, multiplied); diff --git a/crates/burn-sparse/src/decorator/coo_int.rs b/crates/burn-sparse/src/coo_int.rs similarity index 100% rename from 
crates/burn-sparse/src/decorator/coo_int.rs rename to crates/burn-sparse/src/coo_int.rs diff --git a/crates/burn-sparse/src/lib.rs b/crates/burn-sparse/src/lib.rs index 9dfe4c9dce..63a079f460 100644 --- a/crates/burn-sparse/src/lib.rs +++ b/crates/burn-sparse/src/lib.rs @@ -1,2 +1,9 @@ -pub mod backend; -pub mod decorator; +mod coo; +mod coo_bool; +mod coo_float; +mod coo_int; + +pub use coo::*; +pub use coo_bool::*; +pub use coo_float::*; +pub use coo_int::*; From ae8ab6886ea111ab2c11ea929d4f0fd64d090ac4 Mon Sep 17 00:00:00 2001 From: mcarthur Date: Sun, 25 Aug 2024 09:58:30 +0000 Subject: [PATCH 34/38] Removed old --- crates/burn-sparse/old/backend/alias.rs | 4 - crates/burn-sparse/old/backend/api.rs | 95 -- crates/burn-sparse/old/backend/kind.rs | 517 ------- crates/burn-sparse/old/backend/mod.rs | 9 - .../burn-sparse/old/backend/sparse_backend.rs | 533 ------- crates/burn-sparse/old/decorator/backend.rs | 42 - crates/burn-sparse/old/decorator/mod.rs | 18 - crates/burn-sparse/old/decorator/ops.rs | 1155 --------------- .../old/decorator/precision_bridge.rs | 37 - .../old/decorator/representation.rs | 21 - .../burn-sparse/old/decorator/sparse_coo.rs | 1315 ----------------- .../burn-sparse/old/decorator/sparse_csr.rs | 522 ------- crates/burn-sparse/src/coo_bool.rs | 13 - crates/burn-sparse/src/coo_float.rs | 13 - crates/burn-sparse/src/coo_int.rs | 13 - 15 files changed, 4307 deletions(-) delete mode 100644 crates/burn-sparse/old/backend/alias.rs delete mode 100644 crates/burn-sparse/old/backend/api.rs delete mode 100644 crates/burn-sparse/old/backend/kind.rs delete mode 100644 crates/burn-sparse/old/backend/mod.rs delete mode 100644 crates/burn-sparse/old/backend/sparse_backend.rs delete mode 100644 crates/burn-sparse/old/decorator/backend.rs delete mode 100644 crates/burn-sparse/old/decorator/mod.rs delete mode 100644 crates/burn-sparse/old/decorator/ops.rs delete mode 100644 crates/burn-sparse/old/decorator/precision_bridge.rs delete mode 100644 crates/burn-sparse/old/decorator/representation.rs delete mode 100644 crates/burn-sparse/old/decorator/sparse_coo.rs delete mode 100644 crates/burn-sparse/old/decorator/sparse_csr.rs diff --git a/crates/burn-sparse/old/backend/alias.rs b/crates/burn-sparse/old/backend/alias.rs deleted file mode 100644 index dca8d059ce..0000000000 --- a/crates/burn-sparse/old/backend/alias.rs +++ /dev/null @@ -1,4 +0,0 @@ -use crate::backend::SparseBackend; - -/// Sparse tensor primitive type used by the backend. 
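The alias deleted just below is an associated-type shorthand: SparseBackend fixes one concrete storage type per backend, and the alias abbreviates the fully qualified form so signatures can name it tersely. The same pattern as a self-contained sketch; the exact parameter lists are an assumption (the original alias was also generic over the tensor rank):

use core::fmt::Debug;

/// A backend that commits to one concrete sparse storage type.
trait SparseBackend {
    type SparseTensorPrimitive: Clone + Send + Debug + 'static;
}

/// Shorthand so signatures can write SparseTensor<B> instead of the
/// fully qualified associated type.
type SparseTensor<B> = <B as SparseBackend>::SparseTensorPrimitive;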
-pub type SparseTensor = ::SparseTensorPrimitive; diff --git a/crates/burn-sparse/old/backend/api.rs b/crates/burn-sparse/old/backend/api.rs deleted file mode 100644 index 5d064b3ec9..0000000000 --- a/crates/burn-sparse/old/backend/api.rs +++ /dev/null @@ -1,95 +0,0 @@ -use crate::backend::{Sparse, SparseBackend}; -use burn_tensor::{Int, Tensor, TensorPrimitive}; - -pub enum CoalesceReduction { - Sum, -} - -pub trait ToSparse -where - B: SparseBackend, -{ - fn into_sparse(self) -> Tensor; -} - -pub trait SparseTensorApi -where - B: SparseBackend, -{ - fn sddmm(self, lhs: Tensor, rhs: Tensor) -> Self; - fn dense_int(self) -> Tensor; - fn spmm(self, rhs: Tensor) -> Tensor; - fn dense(self) -> Tensor; - fn coalesce(self, reduce: CoalesceReduction) -> Tensor; - fn number_nonzero(self) -> usize; - fn density(self) -> f32; - fn add_dense(self, rhs: Tensor) -> Tensor; - fn mul_dense(self, rhs: Tensor) -> Tensor; -} - -impl ToSparse for Tensor -where - B: SparseBackend, -{ - fn into_sparse(self) -> Tensor { - Tensor::new(B::sparse_to_sparse(self.into_primitive().tensor())) - } -} - -impl SparseTensorApi for Tensor -where - B: SparseBackend, -{ - fn dense(self) -> Tensor { - Tensor::new(TensorPrimitive::Float(B::sparse_to_dense( - self.into_primitive(), - ))) - } - - fn dense_int(self) -> Tensor { - self.dense().int() - } - - fn spmm(self, rhs: Tensor) -> Tensor { - Tensor::new(TensorPrimitive::Float(B::sparse_spmm( - self.into_primitive(), - rhs.into_primitive().tensor(), - ))) - } - - fn sddmm(self, lhs: Tensor, rhs: Tensor) -> Self { - Tensor::new(B::sparse_sddmm( - lhs.into_primitive().tensor(), - rhs.into_primitive().tensor(), - self.into_primitive(), - )) - } - - fn coalesce(self, reduction: CoalesceReduction) -> Tensor { - match reduction { - CoalesceReduction::Sum => Tensor::new(B::sparse_coalesce_sum(self.into_primitive())), - } - } - - fn number_nonzero(self) -> usize { - B::sparse_nonzero(self.into_primitive()) - } - - fn density(self) -> f32 { - B::sparse_density(self.into_primitive()) - } - - fn add_dense(self, rhs: Tensor) -> Tensor { - Tensor::new(TensorPrimitive::Float(B::sparse_add_dense( - self.into_primitive(), - rhs.into_primitive().tensor(), - ))) - } - - fn mul_dense(self, rhs: Tensor) -> Tensor { - Tensor::new(TensorPrimitive::Float(B::sparse_mul_dense( - self.into_primitive(), - rhs.into_primitive().tensor(), - ))) - } -} diff --git a/crates/burn-sparse/old/backend/kind.rs b/crates/burn-sparse/old/backend/kind.rs deleted file mode 100644 index 805b7d0dbd..0000000000 --- a/crates/burn-sparse/old/backend/kind.rs +++ /dev/null @@ -1,517 +0,0 @@ -use std::{future::Future, ops::Range}; - -use crate::backend::SparseBackend; -use burn_tensor::{backend::Backend, BasicOps, Numeric, Shape, Tensor, TensorData, TensorKind}; - -pub trait SparseRepr {} - -/// A type-level representation of the kind of a sparse (float) tensor. 
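The spmm exposed by the SparseTensorApi removed above is implemented in this series as gather-multiply-scatter (the coo_float.rs hunk earlier shows the same steps with gather and scatter tensor ops). The scheme in plain Rust, as a sketch with illustrative names: for each stored (i, k, v) of the sparse lhs, gather row k of the dense rhs, scale it by v, and scatter-add it into output row i.

/// Sparse (COO) x dense matrix product via gather-multiply-scatter.
fn spmm(
    coords: &[[usize; 2]], // (row, col) of each stored lhs value
    values: &[f32],
    rhs: &[Vec<f32>],      // dense, shape [k, n]
    out_rows: usize,       // row count of the sparse lhs
) -> Vec<Vec<f32>> {
    let n = rhs.first().map_or(0, |r| r.len());
    let mut out = vec![vec![0.0f32; n]; out_rows];
    for (&[i, k], &v) in coords.iter().zip(values) {
        for j in 0..n {
            out[i][j] += v * rhs[k][j]; // scatter-add of the gathered, scaled row
        }
    }
    out
}

A useful side effect of the accumulating scatter: uncoalesced duplicates in the lhs are harmless, since repeated coordinates simply sum into the same output row.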
-#[derive(Clone, Debug)] -pub struct Sparse; - -impl TensorKind for Sparse { - type Primitive = B::SparseTensorPrimitive; - fn name() -> &'static str { - "Sparse" - } -} - -impl BasicOps for Sparse { - type Elem = B::FloatElem; - - fn into_data_async( - tensor: Self::Primitive, - ) -> impl Future + Send { - B::sparse_into_data(tensor) - } - - fn device(tensor: &Self::Primitive) -> ::Device { - B::sparse_device(tensor) - } - - fn to_device( - tensor: Self::Primitive, - device: &::Device, - ) -> Self::Primitive { - B::sparse_to_device(tensor, device) - } - - fn from_data( - data: TensorData, - device: &::Device, - ) -> Self::Primitive { - B::sparse_from_data(data, device) - } - - fn shape(tensor: &Self::Primitive) -> Shape { - B::sparse_shape(tensor) - } - - fn empty( - shape: Shape, - device: &::Device, - ) -> Self::Primitive { - B::sparse_empty(shape, device) - } - - fn slice( - tensor: Self::Primitive, - ranges: [Range; D2], - ) -> Self::Primitive { - B::sparse_slice(tensor, ranges) - } - - fn reshape( - tensor: Self::Primitive, - shape: Shape, - ) -> Self::Primitive { - B::sparse_reshape(tensor, shape) - } - - fn transpose(tensor: Self::Primitive) -> Self::Primitive { - B::sparse_transpose(tensor) - } - - fn swap_dims( - tensor: Self::Primitive, - dim1: usize, - dim2: usize, - ) -> Self::Primitive { - B::sparse_swap_dims(tensor, dim1, dim2) - } - - fn permute(tensor: Self::Primitive, axes: [usize; D]) -> Self::Primitive { - B::sparse_permute(tensor, &axes) - } - - fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive { - B::sparse_flip(tensor, axes) - } - - fn slice_assign( - tensor: Self::Primitive, - ranges: [Range; D2], - value: Self::Primitive, - ) -> Self::Primitive { - B::sparse_slice_assign(tensor, ranges, value) - } - - fn repeat_dim( - tensor: Self::Primitive, - dim: usize, - times: usize, - ) -> Self::Primitive { - B::sparse_repeat(tensor, dim, times) - } - - fn cat(vectors: Vec>, dim: usize) -> Self::Primitive { - B::sparse_cat(vectors, dim) - } - - fn equal( - lhs: Self::Primitive, - rhs: Self::Primitive, - ) -> burn_tensor::Tensor { - Tensor::new(B::sparse_equal(lhs, rhs)) - } - - fn not_equal( - lhs: Self::Primitive, - rhs: Self::Primitive, - ) -> burn_tensor::Tensor { - Tensor::new(B::sparse_not_equal(lhs, rhs)) - } - - fn any( - tensor: Self::Primitive, - ) -> burn_tensor::Tensor { - Tensor::new(B::sparse_any(tensor)) - } - - fn any_dim( - tensor: Self::Primitive, - dim: usize, - ) -> burn_tensor::Tensor { - Tensor::new(B::sparse_any_dim(tensor, dim)) - } - - fn all( - tensor: Self::Primitive, - ) -> burn_tensor::Tensor { - Tensor::new(B::sparse_all(tensor)) - } - - fn all_dim( - tensor: Self::Primitive, - dim: usize, - ) -> burn_tensor::Tensor { - Tensor::new(B::sparse_all_dim(tensor, dim)) - } - - fn expand( - tensor: Self::Primitive, - shape: Shape, - ) -> Self::Primitive { - B::sparse_expand(tensor, shape) - } -} - -impl Numeric for Sparse { - fn add(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive { - B::sparse_add(lhs, rhs) - } - - fn add_scalar( - lhs: Self::Primitive, - rhs: E, - ) -> Self::Primitive { - B::sparse_add_scalar(lhs, rhs.elem()) - } - - fn sub(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive { - B::sparse_sub(lhs, rhs) - } - - fn sub_scalar( - lhs: Self::Primitive, - rhs: E, - ) -> Self::Primitive { - B::sparse_sub_scalar(lhs, rhs.elem()) - } - - fn div(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive { - B::sparse_div(lhs, rhs) - } - - fn div_scalar( - lhs: Self::Primitive, - rhs: E, - ) -> 
Self::Primitive { - B::sparse_div_scalar(lhs, rhs.elem()) - } - - fn remainder_scalar( - lhs: Self::Primitive, - rhs: E, - ) -> Self::Primitive { - B::sparse_remainder_scalar(lhs, rhs.elem()) - } - - fn mul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive { - B::sparse_mul(lhs, rhs) - } - - fn mul_scalar( - lhs: Self::Primitive, - rhs: E, - ) -> Self::Primitive { - B::sparse_mul_scalar(lhs, rhs.elem()) - } - - fn neg(tensor: Self::Primitive) -> Self::Primitive { - B::sparse_neg(tensor) - } - - fn sign(tensor: Self::Primitive) -> Self::Primitive { - B::sparse_sign(tensor) - } - - fn zeros( - shape: Shape, - device: &::Device, - ) -> Self::Primitive { - B::sparse_empty(shape, device) - } - - fn ones( - shape: Shape, - device: &::Device, - ) -> Self::Primitive { - B::sparse_to_sparse(B::float_ones(shape, device)) - } - - fn full( - shape: Shape, - fill_value: E, - device: &::Device, - ) -> Self::Primitive { - B::sparse_to_sparse(B::float_full(shape, fill_value.elem(), device)) - } - - fn sum(tensor: Self::Primitive) -> Self::Primitive<1> { - B::sparse_sum(tensor) - } - - fn sum_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive { - B::sparse_sum_dim(tensor, dim) - } - - fn prod(tensor: Self::Primitive) -> Self::Primitive<1> { - B::sparse_prod(tensor) - } - - fn prod_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive { - B::sparse_prod_dim(tensor, dim) - } - - fn mean(tensor: Self::Primitive) -> Self::Primitive<1> { - B::sparse_mean(tensor) - } - - fn mean_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive { - B::sparse_mean_dim(tensor, dim) - } - - fn equal_elem( - lhs: Self::Primitive, - rhs: Self::Elem, - ) -> Tensor { - Tensor::new(B::sparse_equal_elem(lhs, rhs)) - } - - fn not_equal_elem( - lhs: Self::Primitive, - rhs: Self::Elem, - ) -> Tensor { - Tensor::new(B::sparse_not_equal_elem(lhs, rhs)) - } - - fn greater( - lhs: Self::Primitive, - rhs: Self::Primitive, - ) -> Tensor { - Tensor::new(B::sparse_greater(lhs, rhs)) - } - - fn greater_elem( - lhs: Self::Primitive, - rhs: Self::Elem, - ) -> Tensor { - Tensor::new(B::sparse_greater_elem(lhs, rhs)) - } - - fn greater_equal( - lhs: Self::Primitive, - rhs: Self::Primitive, - ) -> Tensor { - Tensor::new(B::sparse_greater_equal(lhs, rhs)) - } - - fn greater_equal_elem( - lhs: Self::Primitive, - rhs: Self::Elem, - ) -> Tensor { - Tensor::new(B::sparse_greater_equal_elem(lhs, rhs)) - } - - fn lower( - lhs: Self::Primitive, - rhs: Self::Primitive, - ) -> Tensor { - Tensor::new(B::sparse_lower(lhs, rhs)) - } - - fn lower_elem( - lhs: Self::Primitive, - rhs: Self::Elem, - ) -> Tensor { - Tensor::new(B::sparse_lower_elem(lhs, rhs)) - } - - fn lower_equal( - lhs: Self::Primitive, - rhs: Self::Primitive, - ) -> Tensor { - Tensor::new(B::sparse_lower_equal(lhs, rhs)) - } - - fn lower_equal_elem( - lhs: Self::Primitive, - rhs: Self::Elem, - ) -> Tensor { - Tensor::new(B::sparse_lower_equal_elem(lhs, rhs)) - } - - fn mask_where( - _tensor: Self::Primitive, - _mask: Tensor, - _source: Self::Primitive, - ) -> Self::Primitive { - panic!("masking of sparse tensors is unsupported") - } - - fn mask_fill( - _tensor: Self::Primitive, - _mask: Tensor, - _value: Self::Elem, - ) -> Self::Primitive { - panic!("masking of sparse tensors is unsupported") - } - - fn gather( - dim: usize, - tensor: Self::Primitive, - indices: Tensor, - ) -> Self::Primitive { - B::sparse_gather(dim, tensor, indices.into_primitive()) - } - - fn scatter( - dim: usize, - tensor: Self::Primitive, - indices: Tensor, - values: Self::Primitive, - ) -> 
Self::Primitive { - B::sparse_scatter(dim, tensor, indices.into_primitive(), values) - } - - fn select( - tensor: Self::Primitive, - dim: usize, - indices: Tensor, - ) -> Self::Primitive { - B::sparse_select(tensor, dim, indices.into_primitive()) - } - - fn select_assign( - tensor: Self::Primitive, - dim: usize, - indices: Tensor, - values: Self::Primitive, - ) -> Self::Primitive { - B::sparse_select_assign(tensor, dim, indices.into_primitive(), values) - } - - fn argmax( - _tensor: Self::Primitive, - _dim: usize, - ) -> ::IntTensorPrimitive { - panic!("Argmax is unsupported for sparse tensors"); - } - - fn argmin( - _tensor: Self::Primitive, - _dim: usize, - ) -> ::IntTensorPrimitive { - panic!("Argmin is unsupported for sparse tensors"); - } - - fn max(tensor: Self::Primitive) -> Self::Primitive<1> { - B::sparse_max(tensor) - } - - fn max_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive { - B::sparse_max_dim(tensor, dim) - } - - fn max_dim_with_indices( - _tensor: Self::Primitive, - _dim: usize, - ) -> (Self::Primitive, ::IntTensorPrimitive) { - todo!() - } - - fn min(tensor: Self::Primitive) -> Self::Primitive<1> { - B::sparse_min(tensor) - } - - fn min_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive { - B::sparse_min_dim(tensor, dim) - } - - fn min_dim_with_indices( - _tensor: Self::Primitive, - _dim: usize, - ) -> (Self::Primitive, ::IntTensorPrimitive) { - todo!() - } - - fn clamp( - tensor: Self::Primitive, - min: Self::Elem, - max: Self::Elem, - ) -> Self::Primitive { - B::sparse_clamp(tensor, min, max) - } - - fn clamp_min( - tensor: Self::Primitive, - min: Self::Elem, - ) -> Self::Primitive { - B::sparse_clamp_min(tensor, min) - } - - fn clamp_max( - tensor: Self::Primitive, - max: Self::Elem, - ) -> Self::Primitive { - B::sparse_clamp_max(tensor, max) - } - - fn abs(tensor: Self::Primitive) -> Self::Primitive { - B::sparse_abs(tensor) - } - - fn powf( - lhs: Self::Primitive, - rhs: Self::Primitive, - ) -> Self::Primitive { - B::sparse_powf(lhs, rhs) - } - - fn powi( - lhs: Self::Primitive, - rhs: Self::Primitive, - ) -> Self::Primitive { - B::sparse_powi(lhs, rhs) - } - - fn powf_scalar( - lhs: Self::Primitive, - rhs: E, - ) -> Self::Primitive { - B::sparse_powf_scalar(lhs, rhs.elem()) - } - - fn powi_scalar( - lhs: Self::Primitive, - rhs: E, - ) -> Self::Primitive { - B::sparse_powi_scalar(lhs, rhs.elem()) - } - - fn random( - _shape: Shape, - _distribution: burn_tensor::Distribution, - _device: &::Device, - ) -> Self::Primitive { - panic!("Random is unsupported for sparse tensors") - } - - fn sort( - _tensor: Self::Primitive, - _dim: usize, - _descending: bool, - ) -> Self::Primitive { - panic!("Sorting is unsupported for sparse tensors") - } - - fn sort_with_indices( - _tensor: Self::Primitive, - _dim: usize, - _descending: bool, - ) -> ( - Self::Primitive, - >::Primitive, - ) { - panic!("Sorting is unsupported for sparse tensors") - } - - fn argsort( - _tensor: Self::Primitive, - _dim: usize, - _descending: bool, - ) -> >::Primitive { - panic!("Sorting is unsupported for sparse tensors") - } -} diff --git a/crates/burn-sparse/old/backend/mod.rs b/crates/burn-sparse/old/backend/mod.rs deleted file mode 100644 index 3b88b7f0ea..0000000000 --- a/crates/burn-sparse/old/backend/mod.rs +++ /dev/null @@ -1,9 +0,0 @@ -// mod alias; -// mod api; -// mod kind; -// mod sparse_backend; - -// pub use alias::*; -// pub use api::*; -// pub use kind::*; -// pub use sparse_backend::*; diff --git a/crates/burn-sparse/old/backend/sparse_backend.rs 
b/crates/burn-sparse/old/backend/sparse_backend.rs deleted file mode 100644 index 44db7b8e2f..0000000000 --- a/crates/burn-sparse/old/backend/sparse_backend.rs +++ /dev/null @@ -1,533 +0,0 @@ -use crate::backend::SparseTensor; -use burn_tensor::{ - backend::Backend, - ops::{BoolTensor, FloatElem, FloatTensor, IntTensor}, - Device, Shape, TensorData, -}; -use core::{future::Future, ops::Range}; - -pub trait SparseBackend: Backend { - type SparseTensorPrimitive: Clone + Send + 'static + core::fmt::Debug; - - fn sparse_empty( - shape: Shape, - device: &Device, - ) -> SparseTensor; - - fn sparse_to_sparse( - dense: Self::FloatTensorPrimitive, - ) -> Self::SparseTensorPrimitive; - - fn sparse_to_dense( - sparse: Self::SparseTensorPrimitive, - ) -> Self::FloatTensorPrimitive; - - fn sparse_spmm( - lhs: Self::SparseTensorPrimitive, - rhs: Self::FloatTensorPrimitive, - ) -> Self::FloatTensorPrimitive; - - fn sparse_sddmm( - lhs: Self::FloatTensorPrimitive, - rhs: Self::FloatTensorPrimitive, - sparse: Self::SparseTensorPrimitive, - ) -> Self::SparseTensorPrimitive; - - fn sparse_coalesce_sum( - tensor: Self::SparseTensorPrimitive, - ) -> Self::SparseTensorPrimitive; - - fn sparse_remove_zeros( - tensor: Self::SparseTensorPrimitive, - ) -> Self::SparseTensorPrimitive; - - fn sparse_nonzero(tensor: Self::SparseTensorPrimitive) -> usize; - - fn sparse_density(sparse: Self::SparseTensorPrimitive) -> f32; - - /// Gets the element at the given indices. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// * `indices` - The indices. - /// - /// # Returns - /// - /// The elements at the given indices. - fn sparse_slice( - tensor: SparseTensor, - indices: [Range; D2], - ) -> SparseTensor; - - /// Gets the device of the tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// - /// # Returns - /// - /// The device of the tensor. - fn sparse_device(tensor: &SparseTensor) -> Device; - - /// Moves the tensor to the given device. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// * `device` - The device to move the tensor to. - /// - /// # Returns - /// - /// The tensor on the given device. - fn sparse_to_device( - tensor: SparseTensor, - device: &Device, - ) -> SparseTensor; - - /// Gets the shape of the tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// - /// # Returns - /// - /// The shape of the tensor. - fn sparse_shape(tensor: &SparseTensor) -> Shape; - - /// Converts the tensor to a data structure. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// - /// # Returns - /// - /// The data structure with the tensor's data. - fn sparse_into_data( - tensor: SparseTensor, - ) -> impl Future + Send; - - /// Creates a tensor from the data structure. - /// - /// # Arguments - /// - /// * `data` - The data structure. - /// * `device` - The device to create the tensor on. - /// - /// # Returns - /// - /// The tensor with the data. 
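sparse_coalesce_sum and sparse_remove_zeros, declared above in this trait, are the two normalization passes a COO layout needs: merge duplicate coordinates by summing, and drop entries that are explicitly zero. A map-based sketch of both passes together (a sort-and-scan over the coordinate list is the usual higher-performance variant):

use std::collections::BTreeMap;

/// Merge duplicate coordinates by summing their values, then drop any
/// entries that cancelled out to exactly zero.
fn coalesce_sum(entries: &[([usize; 2], f32)]) -> Vec<([usize; 2], f32)> {
    let mut merged: BTreeMap<[usize; 2], f32> = BTreeMap::new();
    for &(coord, v) in entries {
        *merged.entry(coord).or_insert(0.0) += v;
    }
    merged
        .into_iter()
        .filter(|&(_, v)| v != 0.0) // the remove_zeros pass
        .collect() // BTreeMap order leaves the result coordinate-sorted
}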
- fn sparse_from_data( - data: TensorData, - device: &Device, - ) -> SparseTensor; - - fn sparse_reshape( - tensor: SparseTensor, - shape: Shape, - ) -> SparseTensor; - - fn sparse_transpose(tensor: SparseTensor) -> SparseTensor; - - fn sparse_swap_dims( - tensor: SparseTensor, - dim1: usize, - dim2: usize, - ) -> SparseTensor; - - fn sparse_permute( - tensor: SparseTensor, - axes: &[usize], - ) -> SparseTensor; - - fn sparse_flip( - tensor: SparseTensor, - axes: &[usize], - ) -> SparseTensor; - - fn sparse_slice_assign( - tensor: SparseTensor, - ranges: [Range; D2], - value: SparseTensor, - ) -> SparseTensor; - - fn sparse_repeat( - tensor: SparseTensor, - dim: usize, - times: usize, - ) -> SparseTensor; - - fn sparse_cat( - tensors: Vec>, - dim: usize, - ) -> SparseTensor; - - fn sparse_equal( - lhs: SparseTensor, - rhs: SparseTensor, - ) -> BoolTensor; - - fn sparse_not_equal( - lhs: SparseTensor, - rhs: SparseTensor, - ) -> BoolTensor; - - fn sparse_any(tensor: SparseTensor) -> BoolTensor; - - fn sparse_any_dim( - tensor: SparseTensor, - dim: usize, - ) -> BoolTensor; - - fn sparse_all(tensor: SparseTensor) -> BoolTensor; - - fn sparse_all_dim( - tensor: SparseTensor, - dim: usize, - ) -> BoolTensor; - - fn sparse_expand( - tensor: SparseTensor, - shape: Shape, - ) -> SparseTensor; - - /// Adds two sparse tensors together. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side tensor. - /// - /// # Returns - /// - /// The result of adding the two tensors together. - fn sparse_add( - lhs: SparseTensor, - rhs: SparseTensor, - ) -> SparseTensor; - - /// Adds a sparse and dense tensor together. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side tensor. - /// - /// # Returns - /// - /// The result of adding the two tensors together. - fn sparse_add_dense( - lhs: SparseTensor, - rhs: FloatTensor, - ) -> FloatTensor; - - /// Adds a scalar to a tensor. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side scalar. - /// - /// # Returns - /// - /// The result of adding the scalar to the tensor. - fn sparse_add_scalar( - lhs: SparseTensor, - rhs: FloatElem, - ) -> SparseTensor; - - /// Subtracts two tensors. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side tensor. - /// - /// # Returns - /// - /// The result of subtracting the two tensors. - fn sparse_sub( - lhs: SparseTensor, - rhs: SparseTensor, - ) -> SparseTensor; - - /// Subtracts a dense from a sparse tensor. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor (sparse). - /// * `rhs` - The right hand side tensor (dense). - /// - /// # Returns - /// - /// The result of subtracting the two tensors. - fn sparse_sub_dense( - lhs: SparseTensor, - rhs: FloatTensor, - ) -> FloatTensor; - - /// Subtracts a scalar from a tensor. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side scalar. - /// - /// # Returns - /// - /// The result of subtracting the scalar from the tensor. - fn sparse_sub_scalar( - lhs: SparseTensor, - rhs: FloatElem, - ) -> SparseTensor; - - /// Multiplies two sparse tensors together. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side tensor. - /// - /// # Returns - /// - /// The result of multiplying the two tensors together. 
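For the sparse_add_dense / sparse_sub_dense pair documented above, the dense operand fixes the output layout, which is why these return a dense tensor rather than a sparse one: adding a generally nonzero tensor destroys sparsity. The addition then reduces to scatter-adding each stored value into a copy of the dense operand; a sketch with illustrative names:

/// sparse + dense: copy the dense operand, scatter-add the stored values.
fn add_dense(
    coords: &[[usize; 2]],
    values: &[f32],
    dense: &[Vec<f32>],
) -> Vec<Vec<f32>> {
    let mut out = dense.to_vec();
    for (&[i, j], &v) in coords.iter().zip(values) {
        out[i][j] += v;
    }
    // sub_dense is the same walk after negating the dense copy,
    // since sparse - dense == (-dense) + sparse.
    out
}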
- fn sparse_mul( - lhs: SparseTensor, - rhs: SparseTensor, - ) -> SparseTensor; - - /// Multiplies a sparse and dense tensor together. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side tensor. - /// - /// # Returns - /// - /// The result of multiplying the two tensors together. - fn sparse_mul_dense( - lhs: SparseTensor, - rhs: FloatTensor, - ) -> FloatTensor; - - /// Multiplies a scalar to a tensor. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side scalar. - /// - /// # Returns - /// - /// The result of multiplying the scalar with the tensor. - fn sparse_mul_scalar( - lhs: SparseTensor, - rhs: FloatElem, - ) -> SparseTensor; - - /// Divides two sparse tensors. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side tensor. - /// - /// # Returns - /// - /// The result of dividing the two tensors. - fn sparse_div( - lhs: SparseTensor, - rhs: SparseTensor, - ) -> SparseTensor; - - /// Divides a sparse and dense tensor. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side tensor. - /// - /// # Returns - /// - /// The result of dividing the two tensors. - fn sparse_div_dense( - lhs: SparseTensor, - rhs: FloatTensor, - ) -> FloatTensor; - - /// Divides a tensor by a scalar. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side scalar. - /// - /// # Returns - /// - /// The result of dividing the tensor by the scalar. - fn sparse_div_scalar( - lhs: SparseTensor, - rhs: FloatElem, - ) -> SparseTensor; - - fn sparse_max(tensor: SparseTensor) -> SparseTensor; - - fn sparse_max_dim( - tensor: SparseTensor, - dim: usize, - ) -> SparseTensor; - - fn sparse_min(tensor: SparseTensor) -> SparseTensor; - - fn sparse_min_dim( - tensor: SparseTensor, - dim: usize, - ) -> SparseTensor; - - fn sparse_greater( - lhs: SparseTensor, - rhs: SparseTensor, - ) -> BoolTensor; - - fn sparse_greater_elem( - lhs: SparseTensor, - rhs: FloatElem, - ) -> BoolTensor; - - fn sparse_greater_equal( - lhs: SparseTensor, - rhs: SparseTensor, - ) -> BoolTensor; - - fn sparse_greater_equal_elem( - lhs: SparseTensor, - rhs: FloatElem, - ) -> BoolTensor; - - fn sparse_lower( - lhs: SparseTensor, - rhs: SparseTensor, - ) -> BoolTensor; - - fn sparse_lower_elem( - lhs: SparseTensor, - rhs: FloatElem, - ) -> BoolTensor; - - fn sparse_lower_equal( - lhs: SparseTensor, - rhs: SparseTensor, - ) -> BoolTensor; - - fn sparse_lower_equal_elem( - lhs: SparseTensor, - rhs: FloatElem, - ) -> BoolTensor; - - fn sparse_abs(tensor: SparseTensor) -> SparseTensor; - fn sparse_sign(tensor: SparseTensor) -> SparseTensor; - - fn sparse_powf( - lhs: SparseTensor, - rhs: SparseTensor, - ) -> SparseTensor; - - fn sparse_powi( - lhs: SparseTensor, - rhs: SparseTensor, - ) -> SparseTensor; - - fn sparse_powf_scalar( - lhs: SparseTensor, - rhs: FloatElem, - ) -> SparseTensor; - - fn sparse_powi_scalar( - lhs: SparseTensor, - rhs: FloatElem, - ) -> SparseTensor; - - fn sparse_clamp( - tensor: SparseTensor, - min: FloatElem, - max: FloatElem, - ) -> SparseTensor; - - fn sparse_clamp_min( - tensor: SparseTensor, - min: FloatElem, - ) -> SparseTensor; - - fn sparse_clamp_max( - tensor: SparseTensor, - max: FloatElem, - ) -> SparseTensor; - - fn sparse_select( - tensor: SparseTensor, - dim: usize, - indices: IntTensor, - ) -> SparseTensor; - - fn sparse_select_assign( - tensor: 
SparseTensor, - dim: usize, - indices: IntTensor, - values: SparseTensor, - ) -> SparseTensor; - - fn sparse_gather( - dim: usize, - tensor: SparseTensor, - indices: IntTensor, - ) -> SparseTensor; - - fn sparse_scatter( - dim: usize, - tensor: SparseTensor, - indices: IntTensor, - values: SparseTensor, - ) -> SparseTensor; - - fn sparse_sum(tensor: SparseTensor) -> SparseTensor; - - fn sparse_sum_dim( - tensor: SparseTensor, - dim: usize, - ) -> SparseTensor; - - fn sparse_prod(tensor: SparseTensor) -> SparseTensor; - - fn sparse_prod_dim( - tensor: SparseTensor, - dim: usize, - ) -> SparseTensor; - - fn sparse_mean(tensor: SparseTensor) -> SparseTensor; - - fn sparse_mean_dim( - tensor: SparseTensor, - dim: usize, - ) -> SparseTensor; - - fn sparse_equal_elem( - lhs: SparseTensor, - rhs: FloatElem, - ) -> BoolTensor; - - fn sparse_not_equal_elem( - lhs: SparseTensor, - rhs: FloatElem, - ) -> BoolTensor; - - fn sparse_remainder_scalar( - lhs: SparseTensor, - rhs: FloatElem, - ) -> SparseTensor; - - fn sparse_neg(tensor: SparseTensor) -> SparseTensor; -} diff --git a/crates/burn-sparse/old/decorator/backend.rs b/crates/burn-sparse/old/decorator/backend.rs deleted file mode 100644 index 7a0fbf0002..0000000000 --- a/crates/burn-sparse/old/decorator/backend.rs +++ /dev/null @@ -1,42 +0,0 @@ -use crate::decorator::FullPrecisionBridge; -use crate::decorator::SparseRepresentation; -use burn_tensor::backend::Backend; -use core::marker::PhantomData; -use derive_new::new; - -/// Tensor backend that extends existing backends with sparse tensor support. -/// This backend abstracts over all backends, and so lacks the performance of a direct implementation. -/// Backends implementing SparseDecorator should be used directly where possible. -#[derive(new, Clone, Copy, Default, Debug)] -pub struct SparseDecorator { - _p: PhantomData, - _r: PhantomData, -} - -impl Backend for SparseDecorator { - type Device = B::Device; - - type FullPrecisionBridge = FullPrecisionBridge; - - type FloatTensorPrimitive = B::FloatTensorPrimitive; - - type FloatElem = B::FloatElem; - - type IntTensorPrimitive = B::IntTensorPrimitive; - - type IntElem = B::IntElem; - - type BoolTensorPrimitive = B::BoolTensorPrimitive; - - type QuantizedTensorPrimitive = B::QuantizedTensorPrimitive; - - fn name() -> String { - format!("SparseDecorator<{}>", B::name()) - } - - fn seed(seed: u64) { - B::seed(seed) - } -} - -impl SparseDecorator {} diff --git a/crates/burn-sparse/old/decorator/mod.rs b/crates/burn-sparse/old/decorator/mod.rs deleted file mode 100644 index 8f2af97c05..0000000000 --- a/crates/burn-sparse/old/decorator/mod.rs +++ /dev/null @@ -1,18 +0,0 @@ -// mod backend; -// mod ops; -// mod precision_bridge; -// mod representation; -// mod sparse_coo; -// mod sparse_csr; -mod coo; -mod coo_bool; -mod coo_float; -mod coo_int; - -// pub use backend::*; -// pub use precision_bridge::*; -// pub use representation::*; -pub use coo::*; -pub use coo_bool::*; -pub use coo_float::*; -pub use coo_int::*; diff --git a/crates/burn-sparse/old/decorator/ops.rs b/crates/burn-sparse/old/decorator/ops.rs deleted file mode 100644 index 7994ef94c0..0000000000 --- a/crates/burn-sparse/old/decorator/ops.rs +++ /dev/null @@ -1,1155 +0,0 @@ -use crate::decorator::SparseDecorator; -use crate::decorator::SparseRepresentation; -use burn_tensor::{ - backend::Backend, - ops::{ - ActivationOps, BoolTensor, BoolTensorOps, ConvOptions, ConvTransposeOptions, FloatTensor, - FloatTensorOps, IntElem, IntTensor, IntTensorOps, InterpolateOptions, MaxPool2dBackward, - 
MaxPool2dWithIndices, ModuleOps, QTensorOps, - }, - Device, Distribution, Shape, TensorData, -}; -use core::ops::Range; - -impl FloatTensorOps> for SparseDecorator -where - B: Backend, - R: SparseRepresentation, -{ - fn float_random( - shape: burn_tensor::Shape, - distribution: burn_tensor::Distribution, - device: &burn_tensor::Device, - ) -> burn_tensor::ops::FloatTensor { - B::float_random(shape, distribution, device) - } - - fn float_shape( - tensor: &burn_tensor::ops::FloatTensor, - ) -> burn_tensor::Shape { - B::float_shape(tensor) - } - - fn float_device( - tensor: &burn_tensor::ops::FloatTensor, - ) -> burn_tensor::Device { - B::float_device(tensor) - } - - fn float_to_device( - tensor: burn_tensor::ops::FloatTensor, - device: &burn_tensor::Device, - ) -> burn_tensor::ops::FloatTensor { - B::float_to_device(tensor, device) - } - - fn float_into_int( - tensor: burn_tensor::ops::FloatTensor, - ) -> burn_tensor::ops::IntTensor { - B::float_into_int(tensor) - } - - fn float_empty( - shape: burn_tensor::Shape, - device: &burn_tensor::Device, - ) -> burn_tensor::ops::FloatTensor { - B::float_empty(shape, device) - } - - fn float_add( - lhs: burn_tensor::ops::FloatTensor, - rhs: burn_tensor::ops::FloatTensor, - ) -> burn_tensor::ops::FloatTensor { - B::float_add(lhs, rhs) - } - - fn float_add_scalar( - lhs: burn_tensor::ops::FloatTensor, - rhs: burn_tensor::ops::FloatElem, - ) -> burn_tensor::ops::FloatTensor { - B::float_add_scalar(lhs, rhs) - } - - fn float_sub( - lhs: burn_tensor::ops::FloatTensor, - rhs: burn_tensor::ops::FloatTensor, - ) -> burn_tensor::ops::FloatTensor { - B::float_sub(lhs, rhs) - } - - fn float_sub_scalar( - lhs: burn_tensor::ops::FloatTensor, - rhs: burn_tensor::ops::FloatElem, - ) -> burn_tensor::ops::FloatTensor { - B::float_sub_scalar(lhs, rhs) - } - - fn float_mul( - lhs: burn_tensor::ops::FloatTensor, - rhs: burn_tensor::ops::FloatTensor, - ) -> burn_tensor::ops::FloatTensor { - B::float_mul(lhs, rhs) - } - - fn float_mul_scalar( - lhs: burn_tensor::ops::FloatTensor, - rhs: burn_tensor::ops::FloatElem, - ) -> burn_tensor::ops::FloatTensor { - B::float_mul_scalar(lhs, rhs) - } - - fn float_div( - lhs: burn_tensor::ops::FloatTensor, - rhs: burn_tensor::ops::FloatTensor, - ) -> burn_tensor::ops::FloatTensor { - B::float_div(lhs, rhs) - } - - fn float_div_scalar( - lhs: burn_tensor::ops::FloatTensor, - rhs: burn_tensor::ops::FloatElem, - ) -> burn_tensor::ops::FloatTensor { - B::float_div_scalar(lhs, rhs) - } - - fn float_remainder_scalar( - lhs: burn_tensor::ops::FloatTensor, - rhs: burn_tensor::ops::FloatElem, - ) -> burn_tensor::ops::FloatTensor { - B::float_remainder_scalar(lhs, rhs) - } - - fn float_matmul( - lhs: burn_tensor::ops::FloatTensor, - rhs: burn_tensor::ops::FloatTensor, - ) -> burn_tensor::ops::FloatTensor { - B::float_matmul(lhs, rhs) - } - - fn float_recip( - tensor: burn_tensor::ops::FloatTensor, - ) -> burn_tensor::ops::FloatTensor { - B::float_recip(tensor) - } - - fn float_swap_dims( - tensor: burn_tensor::ops::FloatTensor, - dim1: usize, - dim2: usize, - ) -> burn_tensor::ops::FloatTensor { - B::float_swap_dims(tensor, dim1, dim2) - } - - fn float_permute( - tensor: burn_tensor::ops::FloatTensor, - axes: [usize; D], - ) -> burn_tensor::ops::FloatTensor { - B::float_permute(tensor, axes) - } - - fn float_flip( - tensor: burn_tensor::ops::FloatTensor, - axes: &[usize], - ) -> burn_tensor::ops::FloatTensor { - B::float_flip(tensor, axes) - } - - fn float_reshape( - tensor: burn_tensor::ops::FloatTensor, - shape: burn_tensor::Shape, - ) -> 
burn_tensor::ops::FloatTensor { - B::float_reshape(tensor, shape) - } - - fn float_gather( - dim: usize, - tensor: burn_tensor::ops::FloatTensor, - indices: burn_tensor::ops::IntTensor, - ) -> burn_tensor::ops::FloatTensor { - B::float_gather(dim, tensor, indices) - } - - fn float_scatter( - dim: usize, - tensor: burn_tensor::ops::FloatTensor, - indices: burn_tensor::ops::IntTensor, - value: burn_tensor::ops::FloatTensor, - ) -> burn_tensor::ops::FloatTensor { - B::float_scatter(dim, tensor, indices, value) - } - - fn float_select( - tensor: burn_tensor::ops::FloatTensor, - dim: usize, - indices: burn_tensor::ops::IntTensor, - ) -> burn_tensor::ops::FloatTensor { - B::float_select(tensor, dim, indices) - } - - fn float_select_assign( - tensor: burn_tensor::ops::FloatTensor, - dim: usize, - indices: burn_tensor::ops::IntTensor, - value: burn_tensor::ops::FloatTensor, - ) -> burn_tensor::ops::FloatTensor { - B::float_select_assign(tensor, dim, indices, value) - } - - fn float_slice( - tensor: burn_tensor::ops::FloatTensor, - ranges: [core::ops::Range; D2], - ) -> burn_tensor::ops::FloatTensor { - B::float_slice(tensor, ranges) - } - - fn float_slice_assign( - tensor: burn_tensor::ops::FloatTensor, - ranges: [core::ops::Range; D2], - value: burn_tensor::ops::FloatTensor, - ) -> burn_tensor::ops::FloatTensor { - B::float_slice_assign(tensor, ranges, value) - } - - fn float_mask_where( - tensor: burn_tensor::ops::FloatTensor, - mask: burn_tensor::ops::BoolTensor, - value: burn_tensor::ops::FloatTensor, - ) -> burn_tensor::ops::FloatTensor { - B::float_mask_where(tensor, mask, value) - } - - fn float_mask_fill( - tensor: burn_tensor::ops::FloatTensor, - mask: burn_tensor::ops::BoolTensor, - value: burn_tensor::ops::FloatElem, - ) -> burn_tensor::ops::FloatTensor { - B::float_mask_fill(tensor, mask, value) - } - - fn float_equal( - lhs: burn_tensor::ops::FloatTensor, - rhs: burn_tensor::ops::FloatTensor, - ) -> burn_tensor::ops::BoolTensor { - B::float_equal(lhs, rhs) - } - - fn float_equal_elem( - lhs: burn_tensor::ops::FloatTensor, - rhs: burn_tensor::ops::FloatElem, - ) -> burn_tensor::ops::BoolTensor { - B::float_equal_elem(lhs, rhs) - } - - fn float_greater( - lhs: burn_tensor::ops::FloatTensor, - rhs: burn_tensor::ops::FloatTensor, - ) -> burn_tensor::ops::BoolTensor { - B::float_greater(lhs, rhs) - } - - fn float_greater_elem( - lhs: burn_tensor::ops::FloatTensor, - rhs: burn_tensor::ops::FloatElem, - ) -> burn_tensor::ops::BoolTensor { - B::float_greater_elem(lhs, rhs) - } - - fn float_greater_equal( - lhs: burn_tensor::ops::FloatTensor, - rhs: burn_tensor::ops::FloatTensor, - ) -> burn_tensor::ops::BoolTensor { - B::float_greater_equal(lhs, rhs) - } - - fn float_greater_equal_elem( - lhs: burn_tensor::ops::FloatTensor, - rhs: burn_tensor::ops::FloatElem, - ) -> burn_tensor::ops::BoolTensor { - B::float_greater_equal_elem(lhs, rhs) - } - - fn float_lower( - lhs: burn_tensor::ops::FloatTensor, - rhs: burn_tensor::ops::FloatTensor, - ) -> burn_tensor::ops::BoolTensor { - B::float_lower(lhs, rhs) - } - - fn float_lower_elem( - lhs: burn_tensor::ops::FloatTensor, - rhs: burn_tensor::ops::FloatElem, - ) -> burn_tensor::ops::BoolTensor { - B::float_lower_elem(lhs, rhs) - } - - fn float_lower_equal( - lhs: burn_tensor::ops::FloatTensor, - rhs: burn_tensor::ops::FloatTensor, - ) -> burn_tensor::ops::BoolTensor { - B::float_lower_equal(lhs, rhs) - } - - fn float_lower_equal_elem( - lhs: burn_tensor::ops::FloatTensor, - rhs: burn_tensor::ops::FloatElem, - ) -> burn_tensor::ops::BoolTensor { - 
B::float_lower_equal_elem(lhs, rhs) - } - - fn float_sum( - tensor: burn_tensor::ops::FloatTensor, - ) -> burn_tensor::ops::FloatTensor { - B::float_sum(tensor) - } - - fn float_sum_dim( - tensor: burn_tensor::ops::FloatTensor, - dim: usize, - ) -> burn_tensor::ops::FloatTensor { - B::float_sum_dim(tensor, dim) - } - - fn float_mean_dim( - tensor: burn_tensor::ops::FloatTensor, - dim: usize, - ) -> burn_tensor::ops::FloatTensor { - B::float_mean_dim(tensor, dim) - } - - fn float_exp( - tensor: burn_tensor::ops::FloatTensor, - ) -> burn_tensor::ops::FloatTensor { - B::float_exp(tensor) - } - - fn float_log( - tensor: burn_tensor::ops::FloatTensor, - ) -> burn_tensor::ops::FloatTensor { - B::float_log(tensor) - } - - fn float_log1p( - tensor: burn_tensor::ops::FloatTensor, - ) -> burn_tensor::ops::FloatTensor { - B::float_log1p(tensor) - } - - fn float_powf( - lhs: burn_tensor::ops::FloatTensor, - rhs: burn_tensor::ops::FloatTensor, - ) -> burn_tensor::ops::FloatTensor { - B::float_powf(lhs, rhs) - } - - fn float_powf_scalar( - tensor: burn_tensor::ops::FloatTensor, - value: f32, - ) -> burn_tensor::ops::FloatTensor { - B::float_powf_scalar(tensor, value) - } - - fn float_sqrt( - tensor: burn_tensor::ops::FloatTensor, - ) -> burn_tensor::ops::FloatTensor { - B::float_sqrt(tensor) - } - - fn float_abs( - tensor: burn_tensor::ops::FloatTensor, - ) -> burn_tensor::ops::FloatTensor { - B::float_abs(tensor) - } - - fn float_cos( - tensor: burn_tensor::ops::FloatTensor, - ) -> burn_tensor::ops::FloatTensor { - B::float_cos(tensor) - } - - fn float_sin( - tensor: burn_tensor::ops::FloatTensor, - ) -> burn_tensor::ops::FloatTensor { - B::float_sin(tensor) - } - - fn float_tanh( - tensor: burn_tensor::ops::FloatTensor, - ) -> burn_tensor::ops::FloatTensor { - B::float_tanh(tensor) - } - - fn float_erf( - tensor: burn_tensor::ops::FloatTensor, - ) -> burn_tensor::ops::FloatTensor { - B::float_erf(tensor) - } - - fn float_argmax( - tensor: burn_tensor::ops::FloatTensor, - dim: usize, - ) -> burn_tensor::ops::IntTensor { - B::float_argmax(tensor, dim) - } - - fn float_argmin( - tensor: burn_tensor::ops::FloatTensor, - dim: usize, - ) -> burn_tensor::ops::IntTensor { - B::float_argmin(tensor, dim) - } - - fn float_expand( - tensor: burn_tensor::ops::FloatTensor, - shape: burn_tensor::Shape, - ) -> burn_tensor::ops::FloatTensor { - B::float_expand(tensor, shape) - } - - fn float_into_data( - tensor: FloatTensor, D>, - ) -> impl std::future::Future + Send { - B::float_into_data(tensor) - } - - fn float_from_data( - data: TensorData, - device: &Device>, - ) -> FloatTensor, D> { - B::float_from_data(data, device) - } -} - -impl BoolTensorOps> for SparseDecorator -where - B: Backend, - R: SparseRepresentation, -{ - fn bool_empty( - shape: burn_tensor::Shape, - device: &burn_tensor::Device>, - ) -> burn_tensor::ops::BoolTensor, D> { - B::bool_empty(shape, device) - } - - fn bool_shape( - tensor: &burn_tensor::ops::BoolTensor, D>, - ) -> burn_tensor::Shape { - B::bool_shape(tensor) - } - - fn bool_into_int( - tensor: burn_tensor::ops::BoolTensor, D>, - ) -> burn_tensor::ops::IntTensor, D> { - B::bool_into_int(tensor) - } - - fn bool_into_float( - tensor: burn_tensor::ops::BoolTensor, D>, - ) -> burn_tensor::ops::FloatTensor, D> { - B::bool_into_float(tensor) - } - - fn bool_device( - tensor: &burn_tensor::ops::BoolTensor, D>, - ) -> burn_tensor::Device> { - B::bool_device(tensor) - } - - fn bool_to_device( - tensor: burn_tensor::ops::BoolTensor, D>, - device: &burn_tensor::Device>, - ) -> 
burn_tensor::ops::BoolTensor, D> { - B::bool_to_device(tensor, device) - } - - fn bool_reshape( - tensor: burn_tensor::ops::BoolTensor, D1>, - shape: burn_tensor::Shape, - ) -> burn_tensor::ops::BoolTensor, D2> { - B::bool_reshape(tensor, shape) - } - - fn bool_slice( - tensor: burn_tensor::ops::BoolTensor, D1>, - ranges: [core::ops::Range; D2], - ) -> burn_tensor::ops::BoolTensor, D1> { - B::bool_slice(tensor, ranges) - } - - fn bool_slice_assign( - tensor: burn_tensor::ops::BoolTensor, D1>, - ranges: [core::ops::Range; D2], - value: burn_tensor::ops::BoolTensor, D1>, - ) -> burn_tensor::ops::BoolTensor, D1> { - B::bool_slice_assign(tensor, ranges, value) - } - - fn bool_equal( - lhs: burn_tensor::ops::BoolTensor, D>, - rhs: burn_tensor::ops::BoolTensor, D>, - ) -> burn_tensor::ops::BoolTensor, D> { - B::bool_equal(lhs, rhs) - } - - fn bool_not( - tensor: burn_tensor::ops::BoolTensor, D>, - ) -> burn_tensor::ops::BoolTensor, D> { - B::bool_not(tensor) - } - - fn bool_swap_dims( - tensor: burn_tensor::ops::BoolTensor, D>, - dim1: usize, - dim2: usize, - ) -> burn_tensor::ops::BoolTensor, D> { - B::bool_swap_dims(tensor, dim1, dim2) - } - - fn bool_permute( - tensor: burn_tensor::ops::BoolTensor, D>, - axes: [usize; D], - ) -> burn_tensor::ops::BoolTensor, D> { - B::bool_permute(tensor, axes) - } - - fn bool_flip( - tensor: burn_tensor::ops::BoolTensor, D>, - axes: &[usize], - ) -> burn_tensor::ops::BoolTensor, D> { - B::bool_flip(tensor, axes) - } - - fn bool_expand( - tensor: burn_tensor::ops::BoolTensor, D1>, - shape: burn_tensor::Shape, - ) -> burn_tensor::ops::BoolTensor, D2> { - B::bool_expand(tensor, shape) - } - - fn bool_into_data( - tensor: BoolTensor, D>, - ) -> impl std::future::Future + Send { - B::bool_into_data(tensor) - } - - fn bool_from_data( - data: TensorData, - device: &Device>, - ) -> BoolTensor, D> { - B::bool_from_data(data, device) - } -} - -impl IntTensorOps> for SparseDecorator -where - B: Backend, - R: SparseRepresentation, -{ - fn int_empty( - shape: Shape, - device: &Device>, - ) -> IntTensor, D> { - B::int_empty(shape, device) - } - - fn int_shape(tensor: &IntTensor, D>) -> Shape { - B::int_shape(tensor) - } - - fn int_device( - tensor: &IntTensor, D>, - ) -> Device> { - B::int_device(tensor) - } - - fn int_to_device( - tensor: IntTensor, D>, - device: &Device>, - ) -> IntTensor, D> { - B::int_to_device(tensor, device) - } - - fn int_reshape( - tensor: IntTensor, D1>, - shape: Shape, - ) -> IntTensor, D2> { - B::int_reshape(tensor, shape) - } - - fn int_slice( - tensor: IntTensor, D1>, - indices: [Range; D2], - ) -> IntTensor, D1> { - B::int_slice(tensor, indices) - } - - fn int_slice_assign( - tensor: IntTensor, D1>, - indices: [Range; D2], - value: IntTensor, D1>, - ) -> IntTensor, D1> { - B::int_slice_assign(tensor, indices, value) - } - - fn int_into_float( - tensor: IntTensor, D>, - ) -> FloatTensor, D> { - B::int_into_float(tensor) - } - - fn int_mask_where( - tensor: IntTensor, D>, - mask: BoolTensor, D>, - source: IntTensor, D>, - ) -> IntTensor, D> { - B::int_mask_where(tensor, mask, source) - } - - fn int_mask_fill( - tensor: IntTensor, D>, - mask: BoolTensor, D>, - value: IntElem>, - ) -> IntTensor, D> { - B::int_mask_fill(tensor, mask, value) - } - - fn int_gather( - dim: usize, - tensor: IntTensor, D>, - indices: IntTensor, D>, - ) -> IntTensor, D> { - B::int_gather(dim, tensor, indices) - } - - fn int_scatter( - dim: usize, - tensor: IntTensor, D>, - indices: IntTensor, D>, - value: IntTensor, D>, - ) -> IntTensor, D> { - B::int_scatter(dim, 
tensor, indices, value) - } - - fn int_select( - tensor: IntTensor, D>, - dim: usize, - indices: IntTensor, 1>, - ) -> IntTensor, D> { - B::int_select(tensor, dim, indices) - } - - fn int_select_assign( - tensor: IntTensor, D>, - dim: usize, - indices: IntTensor, 1>, - value: IntTensor, D>, - ) -> IntTensor, D> { - B::int_select_assign(tensor, dim, indices, value) - } - - fn int_cat( - tensors: Vec, D>>, - dim: usize, - ) -> IntTensor, D> { - B::int_cat(tensors, dim) - } - - fn int_equal( - lhs: IntTensor, D>, - rhs: IntTensor, D>, - ) -> BoolTensor, D> { - B::int_equal(lhs, rhs) - } - - fn int_equal_elem( - lhs: IntTensor, D>, - rhs: IntElem>, - ) -> BoolTensor, D> { - B::int_equal_elem(lhs, rhs) - } - - fn int_greater( - lhs: IntTensor, D>, - rhs: IntTensor, D>, - ) -> BoolTensor, D> { - B::int_greater(lhs, rhs) - } - - fn int_greater_elem( - lhs: IntTensor, D>, - rhs: IntElem>, - ) -> BoolTensor, D> { - B::int_greater_elem(lhs, rhs) - } - - fn int_greater_equal( - lhs: IntTensor, D>, - rhs: IntTensor, D>, - ) -> BoolTensor, D> { - B::int_greater_equal(lhs, rhs) - } - - fn int_greater_equal_elem( - lhs: IntTensor, D>, - rhs: IntElem>, - ) -> BoolTensor, D> { - B::int_greater_equal_elem(lhs, rhs) - } - - fn int_lower( - lhs: IntTensor, D>, - rhs: IntTensor, D>, - ) -> BoolTensor, D> { - B::int_lower(lhs, rhs) - } - - fn int_lower_elem( - lhs: IntTensor, D>, - rhs: IntElem>, - ) -> BoolTensor, D> { - B::int_lower_elem(lhs, rhs) - } - - fn int_lower_equal( - lhs: IntTensor, D>, - rhs: IntTensor, D>, - ) -> BoolTensor, D> { - B::int_lower_equal(lhs, rhs) - } - - fn int_lower_equal_elem( - lhs: IntTensor, D>, - rhs: IntElem>, - ) -> BoolTensor, D> { - B::int_lower_equal_elem(lhs, rhs) - } - - fn int_sub( - lhs: IntTensor, D>, - rhs: IntTensor, D>, - ) -> IntTensor, D> { - B::int_sub(lhs, rhs) - } - - fn int_sub_scalar( - lhs: IntTensor, D>, - rhs: IntElem>, - ) -> IntTensor, D> { - B::int_sub_scalar(lhs, rhs) - } - - fn int_mul( - lhs: IntTensor, D>, - rhs: IntTensor, D>, - ) -> IntTensor, D> { - B::int_mul(lhs, rhs) - } - - fn int_mul_scalar( - lhs: IntTensor, D>, - rhs: IntElem>, - ) -> IntTensor, D> { - B::int_mul_scalar(lhs, rhs) - } - - fn int_div( - lhs: IntTensor, D>, - rhs: IntTensor, D>, - ) -> IntTensor, D> { - B::int_div(lhs, rhs) - } - - fn int_div_scalar( - lhs: IntTensor, D>, - rhs: IntElem>, - ) -> IntTensor, D> { - B::int_div_scalar(lhs, rhs) - } - - fn int_remainder_scalar( - lhs: IntTensor, D>, - rhs: IntElem>, - ) -> IntTensor, D> { - B::int_remainder_scalar(lhs, rhs) - } - - fn int_zeros( - shape: Shape, - device: &Device>, - ) -> IntTensor, D> { - B::int_zeros(shape, device) - } - - fn int_ones( - shape: Shape, - device: &Device>, - ) -> IntTensor, D> { - B::int_ones(shape, device) - } - - fn int_sum( - tensor: IntTensor, D>, - ) -> IntTensor, 1> { - B::int_sum(tensor) - } - - fn int_sum_dim( - tensor: IntTensor, D>, - dim: usize, - ) -> IntTensor, D> { - B::int_sum_dim(tensor, dim) - } - - fn int_prod( - tensor: IntTensor, D>, - ) -> IntTensor, 1> { - B::int_prod(tensor) - } - - fn int_prod_dim( - tensor: IntTensor, D>, - dim: usize, - ) -> IntTensor, D> { - B::int_prod_dim(tensor, dim) - } - - fn int_mean_dim( - tensor: IntTensor, D>, - dim: usize, - ) -> IntTensor, D> { - B::int_mean_dim(tensor, dim) - } - - fn int_argmax( - tensor: IntTensor, D>, - dim: usize, - ) -> IntTensor, D> { - B::int_argmax(tensor, dim) - } - - fn int_argmin( - tensor: IntTensor, D>, - dim: usize, - ) -> IntTensor, D> { - B::int_argmin(tensor, dim) - } - - fn int_max_dim( - tensor: IntTensor, 
D>, - dim: usize, - ) -> IntTensor, D> { - B::int_max_dim(tensor, dim) - } - - fn int_max_dim_with_indices( - tensor: IntTensor, D>, - dim: usize, - ) -> ( - IntTensor, D>, - IntTensor, D>, - ) { - B::int_max_dim_with_indices(tensor, dim) - } - - fn int_min_dim( - tensor: IntTensor, D>, - dim: usize, - ) -> IntTensor, D> { - B::int_min_dim(tensor, dim) - } - - fn int_min_dim_with_indices( - tensor: IntTensor, D>, - dim: usize, - ) -> ( - IntTensor, D>, - IntTensor, D>, - ) { - B::int_min_dim_with_indices(tensor, dim) - } - - fn int_abs( - tensor: IntTensor, D>, - ) -> IntTensor, D> { - B::int_abs(tensor) - } - - fn int_transpose( - tensor: IntTensor, D>, - ) -> IntTensor, D> { - B::int_transpose(tensor) - } - - fn int_swap_dims( - tensor: IntTensor, D>, - dim1: usize, - dim2: usize, - ) -> IntTensor, D> { - B::int_swap_dims(tensor, dim1, dim2) - } - - fn int_permute( - tensor: IntTensor, D>, - axes: [usize; D], - ) -> IntTensor, D> { - B::int_permute(tensor, axes) - } - - fn int_flip( - tensor: IntTensor, D>, - axes: &[usize], - ) -> IntTensor, D> { - B::int_flip(tensor, axes) - } - - fn int_narrow( - tensor: IntTensor, D>, - dim: usize, - start: usize, - length: usize, - ) -> IntTensor, D> { - B::int_narrow(tensor, dim, start, length) - } - - fn int_chunk( - tensor: IntTensor, D>, - chunks: usize, - dim: usize, - ) -> Vec, D>> { - B::int_chunk(tensor, chunks, dim) - } - - fn int_random( - shape: Shape, - distribution: Distribution, - device: &Device>, - ) -> IntTensor, D> { - B::int_random(shape, distribution, device) - } - - fn int_add( - lhs: IntTensor, D>, - rhs: IntTensor, D>, - ) -> IntTensor, D> { - B::int_add(lhs, rhs) - } - - fn int_add_scalar( - lhs: IntTensor, D>, - rhs: IntElem>, - ) -> IntTensor, D> { - B::int_add_scalar(lhs, rhs) - } - - fn int_expand( - tensor: IntTensor, D1>, - shape: Shape, - ) -> IntTensor, D2> { - B::int_expand(tensor, shape) - } - - fn int_into_data( - tensor: IntTensor, D>, - ) -> impl std::future::Future + Send { - B::int_into_data(tensor) - } - - fn int_from_data( - data: TensorData, - device: &Device>, - ) -> IntTensor, D> { - B::int_from_data(data, device) - } -} - -impl QTensorOps> for SparseDecorator -where - B: Backend, - R: SparseRepresentation, -{ - fn q_shape(tensor: &burn_tensor::ops::QuantizedTensor) -> Shape { - B::q_shape(tensor) - } - - fn q_device(tensor: &burn_tensor::ops::QuantizedTensor) -> Device { - B::q_device(tensor) - } - - fn q_from_data( - data: TensorData, - device: &Device>, - ) -> burn_tensor::ops::QuantizedTensor, D> { - B::q_from_data(data, device) - } - - fn q_reshape( - tensor: burn_tensor::ops::QuantizedTensor, D1>, - shape: Shape, - ) -> burn_tensor::ops::QuantizedTensor, D2> { - B::q_reshape(tensor, shape) - } - - fn q_into_data( - tensor: burn_tensor::ops::QuantizedTensor, D>, - ) -> impl std::future::Future + Send { - B::q_into_data(tensor) - } - - fn quantize( - tensor: FloatTensor, D>, - scheme: &burn_tensor::quantization::QuantizationScheme, - qparams: burn_tensor::quantization::QuantizationParametersPrimitive>, - ) -> burn_tensor::ops::QuantizedTensor, D> { - B::quantize(tensor, scheme, qparams) - } - - fn dequantize( - tensor: burn_tensor::ops::QuantizedTensor, D>, - ) -> FloatTensor, D> { - B::dequantize(tensor) - } -} - -impl ModuleOps> for SparseDecorator -where - B: Backend, - R: SparseRepresentation, -{ - fn conv2d( - x: FloatTensor, 4>, - weight: FloatTensor, 4>, - bias: Option, 1>>, - options: ConvOptions<2>, - ) -> FloatTensor, 4> { - B::conv2d(x, weight, bias, options) - } - - fn conv_transpose2d( - 
x: FloatTensor, 4>, - weight: FloatTensor, 4>, - bias: Option, 1>>, - options: ConvTransposeOptions<2>, - ) -> FloatTensor, 4> { - B::conv_transpose2d(x, weight, bias, options) - } - - fn avg_pool2d( - x: FloatTensor, 4>, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - count_include_pad: bool, - ) -> FloatTensor, 4> { - B::avg_pool2d(x, kernel_size, stride, padding, count_include_pad) - } - - fn avg_pool2d_backward( - x: FloatTensor, 4>, - grad: FloatTensor, 4>, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - count_include_pad: bool, - ) -> FloatTensor, 4> { - B::avg_pool2d_backward(x, grad, kernel_size, stride, padding, count_include_pad) - } - - fn max_pool2d( - x: FloatTensor, 4>, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - dilation: [usize; 2], - ) -> FloatTensor, 4> { - B::max_pool2d(x, kernel_size, stride, padding, dilation) - } - - fn max_pool2d_with_indices( - x: FloatTensor, 4>, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - dilation: [usize; 2], - ) -> MaxPool2dWithIndices> { - let MaxPool2dWithIndices { output, indices } = - B::max_pool2d_with_indices(x, kernel_size, stride, padding, dilation); - MaxPool2dWithIndices { output, indices } - } - - fn max_pool2d_with_indices_backward( - x: FloatTensor, 4>, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - dilation: [usize; 2], - output_grad: FloatTensor, 4>, - indices: IntTensor, 4>, - ) -> MaxPool2dBackward> { - let MaxPool2dBackward { x_grad } = B::max_pool2d_with_indices_backward( - x, - kernel_size, - stride, - padding, - dilation, - output_grad, - indices, - ); - MaxPool2dBackward { x_grad } - } - - fn adaptive_avg_pool2d( - x: FloatTensor, 4>, - output_size: [usize; 2], - ) -> FloatTensor, 4> { - B::adaptive_avg_pool2d(x, output_size) - } - - fn adaptive_avg_pool2d_backward( - x: FloatTensor, 4>, - grad: FloatTensor, 4>, - ) -> FloatTensor, 4> { - B::adaptive_avg_pool2d_backward(x, grad) - } - - fn interpolate( - x: FloatTensor, 4>, - output_size: [usize; 2], - options: InterpolateOptions, - ) -> FloatTensor, 4> { - B::interpolate(x, output_size, options) - } - - fn interpolate_backward( - x: FloatTensor, 4>, - grad: FloatTensor, 4>, - output_size: [usize; 2], - options: InterpolateOptions, - ) -> FloatTensor, 4> { - B::interpolate_backward(x, grad, output_size, options) - } - - fn conv3d( - x: FloatTensor, - weight: FloatTensor, - bias: Option>, - options: ConvOptions<3>, - ) -> FloatTensor { - B::conv3d(x, weight, bias, options) - } - - fn conv_transpose3d( - x: FloatTensor, - weight: FloatTensor, - bias: Option>, - options: ConvTransposeOptions<3>, - ) -> FloatTensor { - B::conv_transpose3d(x, weight, bias, options) - } -} - -impl ActivationOps> for SparseDecorator -where - B: Backend, - R: SparseRepresentation, -{ -} diff --git a/crates/burn-sparse/old/decorator/precision_bridge.rs b/crates/burn-sparse/old/decorator/precision_bridge.rs deleted file mode 100644 index 2f5a78cdc9..0000000000 --- a/crates/burn-sparse/old/decorator/precision_bridge.rs +++ /dev/null @@ -1,37 +0,0 @@ -use core::marker::PhantomData; - -use burn_tensor::{ - backend::{Backend, BackendBridge}, - ops::FloatTensor, -}; - -use crate::decorator::SparseDecorator; -use crate::decorator::SparseRepresentation; - -#[derive(Debug)] -pub struct FullPrecisionBridge { - _p: PhantomData, -} - -impl BackendBridge> for FullPrecisionBridge -where - B: Backend, - R: SparseRepresentation, - Bridge: BackendBridge + 'static, -{ - type 
Target = SparseDecorator; - - fn into_target( - tensor: FloatTensor, D>, - device: Option>, - ) -> burn_tensor::ops::FloatTensor { - Bridge::into_target(tensor, device) - } - - fn from_target( - tensor: burn_tensor::ops::FloatTensor, - device: Option>>, - ) -> burn_tensor::ops::FloatTensor, D> { - Bridge::from_target(tensor, device) - } -} diff --git a/crates/burn-sparse/old/decorator/representation.rs b/crates/burn-sparse/old/decorator/representation.rs deleted file mode 100644 index 81d3ce96c6..0000000000 --- a/crates/burn-sparse/old/decorator/representation.rs +++ /dev/null @@ -1,21 +0,0 @@ -#[derive(Debug, Default, Clone)] -pub struct SparseCSR; - -#[derive(Debug, Default, Clone)] -pub struct SparseCOO; - -pub trait SparseRepresentation: Clone + Default + Send + Sync + 'static + core::fmt::Debug { - fn name() -> String; -} - -impl SparseRepresentation for SparseCOO { - fn name() -> String { - "SparseCOO".to_owned() - } -} - -impl SparseRepresentation for SparseCSR { - fn name() -> String { - "SparseCSR".to_owned() - } -} diff --git a/crates/burn-sparse/old/decorator/sparse_coo.rs b/crates/burn-sparse/old/decorator/sparse_coo.rs deleted file mode 100644 index bb5056fd30..0000000000 --- a/crates/burn-sparse/old/decorator/sparse_coo.rs +++ /dev/null @@ -1,1315 +0,0 @@ -use crate::backend::SparseBackend; -use crate::backend::SparseTensor; -use crate::decorator::SparseCOO; -use crate::decorator::SparseDecorator; -use burn_tensor::ops::FloatElem; -use burn_tensor::ops::FloatTensor; -use burn_tensor::ops::FloatTensorOps; - -use burn_tensor::Device; -use burn_tensor::{ - backend::Backend, Bool, ElementConversion, Float, Int, Shape, Tensor, TensorData, - TensorPrimitive, -}; - -#[derive(Clone, Debug)] -pub struct SparseCOOTensor { - pub coordinates: Option>, - pub values: Option>, - pub shape: Shape, - pub device: Device, -} - -fn flatten_coordinates( - coordinates: Tensor, - shape: Shape, - device: &Device, -) -> Tensor { - let mut strides_data = [[1]; D]; - for i in (0..D).rev() { - if D - 1 - i == S { - strides_data[i] = [1]; - } else if D - 1 - i < S { - strides_data[i] = [0]; - } else { - strides_data[i] = [strides_data[i + 1][0] * shape.dims[i + 1] as i64]; - } - } - let strides_data: TensorData = TensorData::from(strides_data); - let strides: Tensor = Tensor::from_data(strides_data, device); - let flat_coordinates: Tensor = strides.mul(coordinates).sum_dim(0).flatten(0, 1); - - flat_coordinates.unsqueeze_dim(0) -} - -fn unflatten_coordinates( - flat_coordinates: Tensor, - new_shape: Shape, -) -> Tensor { - let flat_coordinates = flat_coordinates.squeeze::<1>(0); - let mut remaining_flat_coordinates = flat_coordinates.clone(); - let mut new_coordinates = Vec::with_capacity(D); - - for &dim_size in new_shape.dims.iter().rev() { - let size = dim_size as i64; - let new_coord = remaining_flat_coordinates.clone().remainder_scalar(size); - new_coordinates.push(new_coord.clone()); - remaining_flat_coordinates = remaining_flat_coordinates.div_scalar(size); - } - - new_coordinates.reverse(); - - Tensor::stack(new_coordinates, 0) -} - -impl SparseBackend for SparseDecorator -where - B: Backend, -{ - type SparseTensorPrimitive = SparseCOOTensor; - - fn sparse_to_sparse( - dense: Self::FloatTensorPrimitive, - ) -> Self::SparseTensorPrimitive { - let dense: Tensor = Tensor::from_primitive(TensorPrimitive::Float(dense)); - - let shape = dense.shape(); - let device = dense.device(); - - let significant = dense.clone().not_equal_elem(0.0); - if !significant.clone().any().into_scalar() { - return 
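
The two coordinate helpers above do all of their index arithmetic with tensor ops so the work stays on-device, but the underlying math is ordinary row-major striding. A scalar sketch of that arithmetic, with no burn types and shapes assumed row-major:

    // The flat index of a coordinate is its dot product with the strides,
    // where stride[i] is the product of all dimension sizes after i.
    fn strides(shape: &[usize]) -> Vec<usize> {
        let mut strides = vec![1; shape.len()];
        for i in (0..shape.len().saturating_sub(1)).rev() {
            strides[i] = strides[i + 1] * shape[i + 1];
        }
        strides
    }

    fn flatten(coord: &[usize], shape: &[usize]) -> usize {
        coord.iter().zip(strides(shape)).map(|(c, s)| c * s).sum()
    }

    // Mirror of `unflatten_coordinates`: peel off each dimension with
    // remainder/division, starting from the innermost axis.
    fn unflatten(mut flat: usize, shape: &[usize]) -> Vec<usize> {
        let mut coord = vec![0; shape.len()];
        for i in (0..shape.len()).rev() {
            coord[i] = flat % shape[i];
            flat /= shape[i];
        }
        coord
    }

    fn main() {
        let shape = [2, 3, 4];
        let coord = [1, 2, 3];
        let flat = flatten(&coord, &shape); // 1*12 + 2*4 + 3*1 = 23
        assert_eq!(flat, 23);
        assert_eq!(unflatten(flat, &shape), coord);
    }
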
Self::sparse_empty(dense.shape(), &device); - }; - - let coordinates = significant - .clone() - .nonzero() - .into_iter() - .map(|tensor| { - let length = tensor.shape().dims[0]; - let shape = Shape::new([1, length]); - tensor.reshape(shape) - }) - .collect(); - - let coordinates = Tensor::cat(coordinates, 0); - - let dense = dense.flatten(0, D - 1); - - let dims = significant.dims(); - let values = dense.gather( - 0, - significant - .flatten::<1>(0, dims.len() - 1) - .nonzero() - .remove(0), - ); - - let coordinates = Some(coordinates); - let values = Some(values); - - Self::SparseTensorPrimitive { - coordinates, - values, - shape, - device, - } - } - - fn sparse_to_dense( - sparse: Self::SparseTensorPrimitive, - ) -> Self::FloatTensorPrimitive { - let SparseCOOTensor { - coordinates, - values, - shape, - device, - } = sparse; - - let (Some(coordinates), Some(values)) = (coordinates, values) else { - return Tensor::::zeros(shape, &device) - .into_primitive() - .tensor(); - }; - - let dense: Tensor = Tensor::zeros(Shape::new([shape.num_elements()]), &device); - let flat_coordinates = - flatten_coordinates::(coordinates, shape.clone(), &device).squeeze(0); - let dense = dense.select_assign(0, flat_coordinates, values); - - dense.reshape(shape).into_primitive().tensor() - } - - fn sparse_sddmm( - lhs: Self::FloatTensorPrimitive, - rhs: Self::FloatTensorPrimitive, - sparse: Self::SparseTensorPrimitive, - ) -> Self::SparseTensorPrimitive { - if sparse.coordinates.is_none() || sparse.values.is_none() { - return sparse; - } - - // Flatten the lhs and rhs into a tensor of rows and cols respectively - let lhs = Tensor::::new(burn_tensor::TensorPrimitive::Float(lhs)); - let rhs = Tensor::::new(burn_tensor::TensorPrimitive::Float(rhs)).transpose(); - let lhs_dims = lhs.shape().dims; - let rhs_dims = rhs.shape().dims; - - if lhs_dims[D - 1] != rhs_dims[D - 1] - || lhs_dims[D - 2] != sparse.shape.dims[D - 2] - || rhs_dims[D - 2] != sparse.shape.dims[D - 1] - { - panic!("invalid dimensions for sddmm. 
lhs and rhs must have compatible shapes for matmul, and sparse must have the correct shape for output of matmul between lhs and rhs."); - } - - let lhs = lhs.reshape([-1, lhs_dims[D - 1] as i32]); - let rhs = rhs.reshape([-1, rhs_dims[D - 1] as i32]); - - // Flatten the sparse tensor into - let device = sparse.device.clone(); - let mut shape = sparse.shape.clone(); - let lhs_coordinates = sparse - .coordinates - .clone() - .expect("Expected non-empty sparse tensor"); - - // swap the last two dims so its column-first - let swizzle = Tensor::::arange(0..D as i64, &device) - .slice_assign( - [D - 2..D], - Tensor::::from_ints([D - 1, D - 2], &device), - ) - .unsqueeze_dim(1) - .repeat(1, lhs_coordinates.shape().dims[1]); - let rhs_coordinates = lhs_coordinates.clone().gather(0, swizzle); - - let row_indices = flatten_coordinates::(lhs_coordinates, shape.clone(), &device); - - shape.dims.swap(D - 1, D - 2); - let col_indices = flatten_coordinates::(rhs_coordinates, shape.clone(), &device); - - let row_indices = row_indices.transpose().repeat(1, lhs_dims[D - 1]); - let col_indices = col_indices.transpose().repeat(1, rhs_dims[D - 1]); - - let lhs = lhs.gather(0, row_indices); - let rhs = rhs.gather(0, col_indices); - - let dotted = lhs.mul(rhs).sum_dim(1).squeeze(1); - - SparseCOOTensor { - coordinates: sparse.coordinates, - values: Some(dotted), - shape: sparse.shape, - device, - } - } - - fn sparse_spmm( - lhs: Self::SparseTensorPrimitive, - rhs: Self::FloatTensorPrimitive, - ) -> Self::FloatTensorPrimitive { - let SparseCOOTensor { - coordinates, - values, - shape, - device, - } = lhs; - - let rhs: Tensor = Tensor::from_primitive(TensorPrimitive::Float(rhs)); - let rhs_shape = rhs.shape(); - let mut out_shape = shape.clone(); - out_shape.dims[D - 1] = rhs_shape.dims[D - 1]; - - let (Some(coordinates), Some(values)) = (coordinates, values) else { - // All zeros, exit early - return Tensor::::zeros(out_shape, &device) - .into_primitive() - .tensor(); - }; - - let nnz = coordinates.shape().dims[1]; - - // Ensure they are of the correct shape to multiply - if shape.dims[D - 1] != rhs_shape.dims[D - 2] { - panic!("Invalid shape for matrix multiplication"); - } - - // Ensure batches are the same - if D > 2 && rhs_shape.dims[0..D - 2] != shape.dims[0..D - 2] { - panic!("Batches must be of the same shape"); - } - - // Compute strides for the dense tensor to match the flattened shape - let mut strides_data = [1; D]; - for i in (0..D - 1).rev() { - strides_data[i] = strides_data[i + 1] * shape.dims[i + 1] as i32; - } - let strides: Tensor = - Tensor::::from_ints(strides_data, &device).unsqueeze_dim(1); - - let column_index = coordinates.clone().slice([D - 1..D, 0..nnz]); - - // the indices into the flat row vector at which the containing matrix starts - let matrix_starts: Tensor = if D > 2 { - coordinates - .clone() - .slice([0..D - 2, 0..nnz]) - .mul(strides.clone().slice([0..D - 2])) - .div_scalar((shape.dims[D - 1]) as i32) - .sum_dim(0) - } else { - Tensor::::zeros(column_index.shape(), &device) - }; - - let row_index = coordinates.slice([D - 2..D - 1, 0..nnz]); - - let gather_index = matrix_starts.clone() + column_index; - let scatter_index = matrix_starts + row_index; - - let gather_index = gather_index.transpose().repeat(1, rhs_shape.dims[D - 1]); - let scatter_index = scatter_index.transpose().repeat(1, rhs_shape.dims[D - 1]); - let values = values.unsqueeze_dim(1).repeat(1, rhs_shape.dims[D - 1]); - - // Flatten the rhs similarly into 2 dimensions - let rhs: Tensor = rhs.reshape([-1, 
rhs_shape.dims[D - 1] as i32]); - - // Do the matmul using gather/scatter - let output: Tensor = - Tensor::zeros([out_shape.dims[0], rhs.shape().dims[1]], &device); - let gathered = rhs.gather(0, gather_index); - - let multiplied = gathered.mul(values); - - let scattered = output.scatter(0, scatter_index, multiplied); - - scattered.reshape(out_shape).into_primitive().tensor() - } - - fn sparse_device(tensor: &SparseTensor) -> burn_tensor::Device { - tensor.device.clone() - } - - fn sparse_to_device( - tensor: SparseTensor, - device: &burn_tensor::Device, - ) -> SparseTensor { - SparseCOOTensor { - coordinates: tensor.coordinates.map(|t| t.to_device(device)), - values: tensor.values.map(|t| t.to_device(device)), - shape: tensor.shape, - device: device.clone(), - } - } - - fn sparse_shape( - tensor: &Self::SparseTensorPrimitive, - ) -> burn_tensor::Shape { - tensor.shape.clone() - } - - fn sparse_empty( - shape: burn_tensor::Shape, - device: &burn_tensor::Device, - ) -> SparseTensor { - SparseCOOTensor { - coordinates: None, - values: None, - shape, - device: device.clone(), - } - } - - fn sparse_slice( - tensor: Self::SparseTensorPrimitive, - indices: [core::ops::Range; D2], - ) -> SparseTensor { - let mut indices = Vec::from(indices); - indices.extend(tensor.shape.dims[indices.len()..D1].iter().map(|&l| 0..l)); - let indices: [core::ops::Range; D1] = indices.try_into().expect("D2 must be <= D1"); - let out_shape = Shape::new(indices.clone().map(|r| r.end)); - - if tensor.coordinates.is_none() && tensor.values.is_none() { - return SparseCOOTensor { - coordinates: None, - values: None, - shape: out_shape, - device: tensor.device, - }; - } - - let coordinates = tensor - .coordinates - .expect("Mismatch between coordinates and values"); - let values = tensor - .values - .expect("Mismatch between coordinates and values"); - let device = tensor.device; - - let number_nonzero = coordinates.shape().dims[1]; - - let mut mask: Tensor = Tensor::ones(Shape::new([number_nonzero]), &device); - - for (dim, bound) in indices.iter().enumerate() { - let coords = coordinates.clone().slice([dim..dim + 1, 0..number_nonzero]); - let coords = coords.reshape(Shape::new([number_nonzero])); - - let mask_lower = coords - .clone() - .lower_elem(B::IntElem::from_elem(bound.end)) - .int(); - - let mask_upper = coords - .clone() - .greater_equal_elem(B::IntElem::from_elem(bound.start)) - .int(); - - mask = mask.mul(mask_lower).mul(mask_upper); - } - - let nonzero = mask.not_equal_elem(B::IntElem::from_elem(0)); - if !nonzero.clone().any().into_scalar() { - // no existing values were in the slice, so return an empty tensor - return SparseCOOTensor { - coordinates: None, - values: None, - shape: out_shape, - device, - }; - } - - let nonzero = nonzero.nonzero(); - - let indices_dim1 = nonzero - .first() - .cloned() - .expect("Expected dimension to exist"); - - let coordinates = coordinates.select(1, indices_dim1.clone()); - let values = values.select(0, indices_dim1); - - let coordinates = Some(coordinates); - let values = Some(values); - - SparseCOOTensor { - coordinates, - values, - shape: out_shape, - device, - } - } - - fn sparse_from_data( - data: TensorData, - device: &burn_tensor::Device, - ) -> SparseTensor { - let dense = B::float_from_data(data, device); - Self::sparse_to_sparse(dense) - } - - fn sparse_into_data( - tensor: SparseTensor, - ) -> impl core::future::Future + Send { - B::float_into_data(Self::sparse_to_dense(tensor)) - } - - fn sparse_reshape( - tensor: SparseCOOTensor, - out_shape: Shape, - ) -> 
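
The gather/multiply/scatter choreography in `sparse_spmm` can be hard to follow through the batching reshapes. Stripped of batching and devices, it computes the classic COO kernel below; this is a toy scalar version for intuition only, not the shipped implementation.

    // out[r][j] += v * rhs[c][j] for every nonzero (r, c, v): gather a row
    // of rhs, scale it by the nonzero value, scatter-add into the output.
    fn spmm_coo(
        rows: &[usize],
        cols: &[usize],
        vals: &[f32],
        rhs: &[Vec<f32>], // dense, shape [k][n]
        out_rows: usize,
    ) -> Vec<Vec<f32>> {
        let n = rhs[0].len();
        let mut out = vec![vec![0.0; n]; out_rows];
        for ((&r, &c), &v) in rows.iter().zip(cols).zip(vals) {
            for j in 0..n {
                out[r][j] += v * rhs[c][j];
            }
        }
        out
    }

    fn main() {
        // lhs = [[2, 0], [0, 3]] in COO form, rhs = [[1, 2], [3, 4]].
        let rhs = [vec![1.0, 2.0], vec![3.0, 4.0]];
        let out = spmm_coo(&[0, 1], &[0, 1], &[2.0, 3.0], &rhs, 2);
        assert_eq!(out, vec![vec![2.0, 4.0], vec![9.0, 12.0]]);
    }
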
SparseCOOTensor { - if tensor.coordinates.is_none() && tensor.values.is_none() { - return SparseCOOTensor { - coordinates: None, - values: None, - shape: out_shape, - device: tensor.device, - }; - } - - let coordinates = tensor - .coordinates - .expect("Mismatch between coordinates and values"); - let values = tensor - .values - .expect("Mismatch between coordinates and values"); - let shape = tensor.shape; - let device = tensor.device; - - // Flatten the coordinates - let flat_coordinates = flatten_coordinates::(coordinates, shape, &device); - - // Unflatten the coordinates to the new shape - let new_coordinates = unflatten_coordinates(flat_coordinates, out_shape.clone()); - - SparseCOOTensor { - coordinates: Some(new_coordinates), - values: Some(values), - shape: out_shape, - device, - } - } - - fn sparse_transpose(tensor: SparseTensor) -> SparseTensor { - let d = tensor.shape.dims.len(); - let mut axes: Vec = (0..d).collect(); - axes.swap(d - 1, d - 2); - Self::sparse_permute(tensor, &axes) - } - - fn sparse_swap_dims( - tensor: SparseTensor, - dim1: usize, - dim2: usize, - ) -> SparseTensor { - let d = tensor.shape.dims.len(); - let mut axes: Vec = (0..d).collect(); - axes.swap(dim1, dim2); - Self::sparse_permute(tensor, &axes) - } - - fn sparse_permute( - tensor: SparseTensor, - axes: &[usize], - ) -> SparseTensor { - let SparseCOOTensor { - coordinates, - values, - mut shape, - device, - } = tensor; - - for (i, &j) in (0..D).zip(axes).filter(|(i, j)| i < j) { - shape.dims.swap(i, j); - } - - let axes = Tensor::from(axes); - let coordinates = coordinates.map(|coordinates| coordinates.select(0, axes)); - - SparseCOOTensor { - coordinates, - values, - shape, - device, - } - } - - fn sparse_flip( - tensor: SparseTensor, - axes: &[usize], - ) -> SparseTensor { - let SparseCOOTensor { - coordinates, - values, - shape, - device, - } = tensor; - - let (Some(coordinates), Some(values)) = (coordinates, values) else { - // All zeros, exit early - return SparseCOOTensor { - coordinates: None, - values: None, - shape, - device, - }; - }; - - let nnz = coordinates.shape().dims[1]; - - let mut mask = [0; D]; - for &axis in axes { - mask[axis] = 1; - } - let mask: Tensor = Tensor::<_, 1, _>::from_ints(mask, &device) - .unsqueeze_dim(1) - .repeat(1, nnz) - .bool(); - - let flipped: Tensor = Tensor::<_, 1, _>::from_ints(shape.dims, &device) - .unsqueeze_dim(1) - .repeat(1, nnz) - .sub(coordinates.clone()) - .sub_scalar(1); - - let coordinates = coordinates.mask_where(mask, flipped); - - let coordinates = Some(coordinates); - let values = Some(values); - - SparseCOOTensor { - coordinates, - values, - shape, - device, - } - } - - fn sparse_slice_assign( - tensor: SparseTensor, - ranges: [core::ops::Range; D2], - mut value: SparseTensor, - ) -> SparseTensor { - let value_nnz = value - .coordinates - .as_ref() - .map(|coords| coords.shape().dims[1]) - .unwrap_or(0); - - let mut ranges = Vec::from(ranges); - ranges.extend(tensor.shape.dims[ranges.len()..D1].iter().map(|&l| 0..l)); - let ranges: [core::ops::Range; D1] = ranges.try_into().expect("D2 must be <= D1"); - - let shape = tensor.shape.clone(); - let sliced = Self::sparse_reshape( - Self::sparse_slice(tensor.clone(), ranges.clone()), - shape.clone(), - ); - let tensor = Self::sparse_sub(tensor, sliced); - let offset = Tensor::::from_ints(ranges.map(|r| r.start), &tensor.device); - let offset = offset.unsqueeze_dim::<2>(1).repeat(1, value_nnz); - - value.shape = shape; - value.coordinates = value.coordinates.map(|coords| coords + offset); - - 
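
`sparse_slice_assign` works in three steps: subtract the tensor's existing slice so the target region becomes empty, shift the incoming values' coordinates by each range start, then `sparse_add` the result. In a dense 1-D setting that clear-then-add contract collapses to plain assignment, as this small sketch shows (assumed semantics, not burn's API):

    // Zero out the target range, then write the new values shifted to the
    // range's offset; densely the two steps fuse into one assignment.
    fn slice_assign(tensor: &mut [f32], range: std::ops::Range<usize>, value: &[f32]) {
        assert_eq!(range.len(), value.len());
        for (slot, v) in tensor[range].iter_mut().zip(value) {
            *slot = *v;
        }
    }

    fn main() {
        let mut t = [1.0, 2.0, 3.0, 4.0];
        slice_assign(&mut t, 1..3, &[9.0, 9.0]);
        assert_eq!(t, [1.0, 9.0, 9.0, 4.0]);
    }
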
Self::sparse_add(tensor, value) - } - - fn sparse_repeat( - tensor: SparseTensor, - dim: usize, - times: usize, - ) -> SparseTensor { - let SparseCOOTensor { - coordinates, - values, - shape, - device, - } = tensor; - - let mut out_shape = shape.clone(); - out_shape.dims[dim] *= times; - - let (Some(coordinates), Some(values)) = (coordinates, values) else { - // All zeros, exit early - return SparseCOOTensor { - coordinates: None, - values: None, - shape, - device, - }; - }; - - let device = coordinates.device(); - let nnz = coordinates.shape().dims[1]; - - let values = values.repeat(0, times); - - let coordinates_mask: Tensor = Tensor::zeros(coordinates.shape(), &device); - let ones: Tensor = Tensor::ones(Shape::new([1, nnz]), &device); - let coordinates_mask = coordinates_mask.slice_assign([dim..dim + 1, 0..nnz], ones); - let coordinates = Tensor::cat( - (0..times) - .map(|n| { - coordinates.clone() - + coordinates_mask.clone() * (n as i32) * (shape.dims[dim] as i32) - }) - .collect::>(), - 1, - ); - - let coordinates = Some(coordinates); - let values = Some(values); - - SparseCOOTensor { - coordinates, - values, - shape: out_shape, - device, - } - } - - fn sparse_cat( - _tensors: Vec>, - _dim: usize, - ) -> SparseTensor { - let _offset = 0; - todo!() - } - - fn sparse_equal( - _lhs: SparseTensor, - _rhs: SparseTensor, - ) -> burn_tensor::ops::BoolTensor { - panic!("elementwise equal is unsupported for SparseCOO until scatter supports other reduction methods"); - } - - fn sparse_not_equal( - _lhs: SparseTensor, - _rhs: SparseTensor, - ) -> burn_tensor::ops::BoolTensor { - panic!("elementwise not_equal is unsupported for SparseCOO until scatter supports other reduction methods"); - } - - fn sparse_any( - tensor: SparseTensor, - ) -> burn_tensor::ops::BoolTensor { - let SparseCOOTensor { - coordinates, - values: _, - shape: _, - device: _, - } = tensor; - let any = coordinates.is_some(); - Tensor::::from([any]).into_primitive() - } - - fn sparse_any_dim( - _tensor: SparseTensor, - _dim: usize, - ) -> burn_tensor::ops::BoolTensor { - panic!("any_dim is unsupported for the SparseCOO Decorator due to performance issues, convert to dense explicitly to ensure you understand"); - } - - fn sparse_all( - tensor: SparseTensor, - ) -> burn_tensor::ops::BoolTensor { - let SparseCOOTensor { - coordinates, - values: _, - shape, - device: _, - } = tensor; - let all = match coordinates { - Some(coordinates) => shape.num_elements() == coordinates.shape().dims[1], - None => false, - }; - Tensor::::from([all]).into_primitive() - } - - fn sparse_all_dim( - _tensor: SparseTensor, - _dim: usize, - ) -> burn_tensor::ops::BoolTensor { - panic!("all_dim is unsupported for the SparseCOO Decorator due to performance issues, convert to dense explicitly to ensure you understand"); - } - - fn sparse_expand( - tensor: SparseTensor, - _shape: Shape, - ) -> SparseTensor { - let SparseCOOTensor { - coordinates: _, - values: _, - shape: _, - device: _, - } = tensor; - todo!() - } - - fn sparse_coalesce_sum( - tensor: Self::SparseTensorPrimitive, - ) -> Self::SparseTensorPrimitive { - if tensor.coordinates.as_ref().map(|c| c.shape().dims[1] <= 1) == Some(true) { - return tensor; - } - let original_shape = tensor.shape.clone(); - - if tensor.coordinates.is_none() && tensor.values.is_none() { - return SparseCOOTensor { - coordinates: None, - values: None, - shape: original_shape, - device: tensor.device, - }; - } - - let coordinates = tensor - .coordinates - .expect("Mismatch between coordinates and values"); - let values = 
tensor - .values - .expect("Mismatch between coordinates and values"); - let device = tensor.device; - let nnz = coordinates.shape().dims[1]; - - let coordinates = - flatten_coordinates::(coordinates, original_shape.clone(), &device); - let _flat_shape = Shape::new([original_shape.num_elements()]); - - let (coordinates, indices) = coordinates.sort_with_indices(1); - let values = values.select(0, indices.squeeze(0)); - let range = Tensor::::arange(0..nnz as i64, &device).unsqueeze::<2>(); - - // Get the diff of coordinates, diff[i] = coordinates[i]-coordinates[i-1] - let left_slice = coordinates.clone().slice([0..1, 0..nnz - 1]); - let right_slice = coordinates.clone().slice([0..1, 1..nnz]); - let diff = right_slice - left_slice; - let ones = Tensor::::ones(Shape::new([1, 1]), &device); - let diff = Tensor::cat(vec![ones, diff], 1); - - // TODO this all would be way cleaner with cumsum/max, but that is waiting on a pull request as of writing - // inspiration could be taken from pytorch_scatter for better implementations - let unique_mask = diff.not_equal_elem(0); - let unique_indices = unique_mask.clone().nonzero().remove(1); - let steps = Tensor::cat( - vec![unique_indices.clone(), Tensor::from_data([nnz], &device)], - 0, - ); - let unique = steps.shape().dims[0]; - let steps = steps - .clone() - .slice([1..unique]) - .sub(steps.slice([0..unique - 1])) - .max() - // .sub_scalar(1) - .into_scalar() - .elem::(); - - let mut scatter_indices = range.mul(unique_mask.int()); - - for _ in 0..steps { - scatter_indices = scatter_indices - .clone() - .slice([0..1, 1..nnz]) - .max_pair(scatter_indices.slice([0..1, 0..nnz - 1])); - scatter_indices = Tensor::cat( - vec![Tensor::zeros(Shape::new([1, 1]), &device), scatter_indices], - 1, - ); - } - - // Scatter/Gather everything into place - let zeroed = Tensor::::zeros(Shape::new([nnz]), &device); - let values = zeroed.scatter(0, scatter_indices.squeeze(0), values); - let values = values.gather(0, unique_indices.clone()); - let coordinates = coordinates.gather(1, unique_indices.unsqueeze::<2>()); - let coordinates = unflatten_coordinates(coordinates, original_shape.clone()); - - let coordinates = Some(coordinates); - let values = Some(values); - - // reshape back into the original shape and send it! 
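
Everything above is a vectorised way of saying: sort the entries by flattened coordinate, find runs of duplicates, and sum each run. A sequential sketch of that contract (illustrative only; the tensor version expresses it with sort_with_indices, a diff mask, and scatter so the work stays on-device):

    // Merge duplicate coordinates by summing their values.
    fn coalesce_sum(mut entries: Vec<(usize, f32)>) -> Vec<(usize, f32)> {
        entries.sort_by_key(|&(coord, _)| coord);
        let mut out: Vec<(usize, f32)> = Vec::new();
        for (coord, val) in entries {
            if let Some(last) = out.last_mut() {
                if last.0 == coord {
                    last.1 += val; // same coordinate: accumulate
                    continue;
                }
            }
            out.push((coord, val)); // new coordinate: start a fresh run
        }
        out
    }

    fn main() {
        let merged = coalesce_sum(vec![(3, 1.0), (0, 2.0), (3, 4.0)]);
        assert_eq!(merged, vec![(0, 2.0), (3, 5.0)]);
    }
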
- SparseCOOTensor { - coordinates, - values, - shape: original_shape, - device, - } - } - - fn sparse_nonzero(tensor: Self::SparseTensorPrimitive) -> usize { - match tensor.coordinates { - Some(coordinates) => coordinates.shape().dims[1], - None => 0, - } - } - - fn sparse_density(sparse: Self::SparseTensorPrimitive) -> f32 { - match sparse.coordinates { - Some(coordinates) => { - coordinates.shape().dims[1] as f32 / sparse.shape.num_elements() as f32 - } - None => 0.0, - } - } - - fn sparse_add( - lhs: SparseTensor, - rhs: SparseTensor, - ) -> SparseTensor { - let SparseCOOTensor { - coordinates: lhs_coordinates, - values: lhs_values, - shape: lhs_shape, - device: lhs_device, - } = lhs; - let (Some(lhs_coordinates), Some(lhs_values)) = (lhs_coordinates, lhs_values) else { - return rhs; - }; - - let SparseCOOTensor { - coordinates: rhs_coordinates, - values: rhs_values, - shape: rhs_shape, - device: rhs_device, - } = rhs; - let (Some(rhs_coordinates), Some(rhs_values)) = (rhs_coordinates, rhs_values) else { - return SparseCOOTensor { - coordinates: Some(lhs_coordinates), - values: Some(lhs_values), - shape: lhs_shape, - device: lhs_device, - }; - }; - - assert_eq!(lhs_shape, rhs_shape); - assert_eq!(lhs_device, rhs_device); - - let coordinates = Some(Tensor::cat(vec![lhs_coordinates, rhs_coordinates], 1)); - let values = Some(Tensor::cat(vec![lhs_values, rhs_values], 0)); - let shape = lhs_shape; - let device = lhs_device; - - let result = SparseCOOTensor { - coordinates, - values, - shape, - device, - }; - - Self::sparse_coalesce_sum(result) - } - - fn sparse_add_scalar( - _: SparseTensor, - _: FloatElem, - ) -> SparseTensor { - panic!("Cannot add scalar to sparse, only zero preserving operations are permitted"); - } - - fn sparse_add_dense( - lhs: SparseTensor, - rhs: burn_tensor::ops::FloatTensor, - ) -> FloatTensor { - if lhs.shape != B::float_shape(&rhs) { - panic!("lhs and rhs must have the same shape for sparse_add_dense"); - } - - if lhs.coordinates.is_none() && lhs.values.is_none() { - return rhs; - } - - let coordinates = lhs - .coordinates - .expect("Mismatch between coordinates and values"); - let values = lhs.values.expect("Mismatch between coordinates and values"); - let device = lhs.device; - let shape = lhs.shape; - - let coordinates = flatten_coordinates::(coordinates, shape.clone(), &device); - let dense = B::float_reshape(rhs, Shape::new([shape.num_elements()])); - - let dense = B::float_scatter( - 0, - dense, - coordinates.squeeze(0).into_primitive(), - values.into_primitive().tensor(), - ); - - B::float_reshape(dense, shape) - } - - fn sparse_sub( - lhs: SparseTensor, - rhs: SparseTensor, - ) -> SparseTensor { - Self::sparse_add( - lhs, - Self::sparse_mul_scalar(rhs, FloatElem::::from_elem(-1.0)), - ) - } - - fn sparse_sub_dense( - lhs: SparseTensor, - rhs: burn_tensor::ops::FloatTensor, - ) -> FloatTensor { - Self::sparse_add_dense( - lhs, - B::float_mul_scalar(rhs, FloatElem::::from_elem(-1.0)), - ) - } - - fn sparse_sub_scalar( - _lhs: SparseTensor, - _rhs: FloatElem, - ) -> SparseTensor { - panic!("Cannot add scalar to sparse, only zero preserving operations are permitted"); - } - - fn sparse_mul( - _lhs: SparseTensor, - _rhs: SparseTensor, - ) -> SparseTensor { - panic!("sparse_mul is unsupported until scatter supports multiplication based reduction"); - } - - fn sparse_mul_dense( - lhs: SparseTensor, - rhs: burn_tensor::ops::FloatTensor, - ) -> FloatTensor { - if lhs.shape != B::float_shape(&rhs) { - panic!("lhs and rhs must have the same shape for 
sparse_add_dense"); - } - - if lhs.coordinates.is_none() && lhs.values.is_none() { - return Self::float_zeros(lhs.shape, &lhs.device); - } - - // TODO: this could potentially be optimized if/when scatter gets other reduction strategies - let lhs = Self::sparse_to_dense(lhs); - Self::float_mul(lhs, rhs) - } - - fn sparse_mul_scalar( - mut lhs: SparseTensor, - rhs: FloatElem, - ) -> SparseTensor { - lhs.values = lhs.values.map(|values| values.mul_scalar(rhs)); - lhs - } - - fn sparse_div( - _: SparseTensor, - _: SparseTensor, - ) -> SparseTensor { - panic!("sparse_div is unsupported until scatter supports multiplication based reduction"); - } - - fn sparse_div_dense( - lhs: SparseTensor, - rhs: burn_tensor::ops::FloatTensor, - ) -> FloatTensor { - if lhs.shape != B::float_shape(&rhs) { - panic!("lhs and rhs must have the same shape for sparse_add_dense"); - } - - if lhs.coordinates.is_none() && lhs.values.is_none() { - return Self::float_zeros(lhs.shape, &lhs.device); - } - - // TODO: this could potentially be optimized if/when scatter gets other reduction strategies - let lhs = Self::sparse_to_dense(lhs); - Self::float_div(lhs, rhs) - } - - fn sparse_div_scalar( - mut lhs: SparseTensor, - rhs: FloatElem, - ) -> SparseTensor { - lhs.values = lhs.values.map(|values| values.div_scalar(rhs)); - lhs - } - - fn sparse_max(_tensor: SparseTensor) -> SparseTensor { - panic!("max is unsupported for SparseCOO until scatter supports other reduction methods"); - } - - fn sparse_max_dim( - _tensor: SparseTensor, - _dim: usize, - ) -> SparseTensor { - panic!( - "max_dim is unsupported for SparseCOO until scatter supports other reduction methods" - ); - } - - fn sparse_min(_tensor: SparseTensor) -> SparseTensor { - panic!("min is unsupported for SparseCOO until scatter supports other reduction methods"); - } - - fn sparse_min_dim( - _tensor: SparseTensor, - _dim: usize, - ) -> SparseTensor { - panic!( - "min_dim is unsupported for SparseCOO until scatter supports other reduction methods" - ); - } - - fn sparse_greater( - _lhs: SparseTensor, - _rhs: SparseTensor, - ) -> burn_tensor::ops::BoolTensor { - panic!("sparse_greater is not supported for SparseCOO as it outputs a dense tensor"); - } - - fn sparse_greater_elem( - _lhs: SparseTensor, - _rhs: FloatElem, - ) -> burn_tensor::ops::BoolTensor { - panic!("sparse_greater_elem is not supported for SparseCOO as it outputs a dense tensor"); - } - - fn sparse_greater_equal( - _lhs: SparseTensor, - _rhs: SparseTensor, - ) -> burn_tensor::ops::BoolTensor { - panic!("sparse_greater_equal is not supported for SparseCOO as it outputs a dense tensor"); - } - - fn sparse_greater_equal_elem( - _lhs: SparseTensor, - _rhs: FloatElem, - ) -> burn_tensor::ops::BoolTensor { - panic!( - "sparse_greater_equal_elem is not supported for SparseCOO as it outputs a dense tensor" - ); - } - - fn sparse_lower( - _lhs: SparseTensor, - _rhs: SparseTensor, - ) -> burn_tensor::ops::BoolTensor { - panic!("sparse_lower is not supported for SparseCOO as it outputs a dense tensor"); - } - - fn sparse_lower_elem( - _lhs: SparseTensor, - _rhs: FloatElem, - ) -> burn_tensor::ops::BoolTensor { - panic!("sparse_lower_elem is not supported for SparseCOO as it outputs a dense tensor"); - } - - fn sparse_lower_equal( - _lhs: SparseTensor, - _rhs: SparseTensor, - ) -> burn_tensor::ops::BoolTensor { - panic!("sparse_lower_equal is not supported for SparseCOO as it outputs a dense tensor"); - } - - fn sparse_lower_equal_elem( - _lhs: SparseTensor, - _rhs: FloatElem, - ) -> burn_tensor::ops::BoolTensor 
{ - panic!( - "sparse_lower_equal_elem is not supported for SparseCOO as it outputs a dense tensor" - ); - } - - fn sparse_abs(mut tensor: SparseTensor) -> SparseTensor { - tensor.values = tensor.values.map(|values| values.abs()); - tensor - } - - fn sparse_powf( - _lhs: SparseTensor, - _rhs: SparseTensor, - ) -> SparseTensor { - panic!("sparse_powf is unsupported for SparseCOO until scatter supports other reduction methods"); - } - - fn sparse_powi( - _lhs: SparseTensor, - _rhs: SparseTensor, - ) -> SparseTensor { - panic!("sparse_powi is unsupported for SparseCOO until scatter supports other reduction methods"); - } - - fn sparse_powf_scalar( - mut lhs: SparseTensor, - rhs: FloatElem, - ) -> SparseTensor { - lhs.values = lhs.values.map(|values| values.powf_scalar(rhs)); - lhs - } - - fn sparse_powi_scalar( - mut lhs: SparseTensor, - rhs: FloatElem, - ) -> SparseTensor { - lhs.values = lhs.values.map(|values| values.powi_scalar(rhs)); - lhs - } - - fn sparse_clamp( - mut tensor: SparseTensor, - min: FloatElem, - max: FloatElem, - ) -> SparseTensor { - tensor.values = tensor.values.map(|values| values.clamp(min, max)); - tensor - } - - fn sparse_clamp_min( - mut tensor: SparseTensor, - min: FloatElem, - ) -> SparseTensor { - tensor.values = tensor.values.map(|values| values.clamp_min(min)); - tensor - } - - fn sparse_clamp_max( - mut tensor: SparseTensor, - max: FloatElem, - ) -> SparseTensor { - tensor.values = tensor.values.map(|values| values.clamp_max(max)); - tensor - } - - fn sparse_select( - tensor: SparseTensor, - dim: usize, - indices: burn_tensor::ops::IntTensor, - ) -> SparseTensor { - if tensor.coordinates.is_none() && tensor.values.is_none() { - return tensor; - } - - let coordinates = tensor - .coordinates - .expect("Mismatch between coordinates and values"); - let values = tensor - .values - .expect("Mismatch between coordinates and values"); - let device = tensor.device; - let mut shape = tensor.shape; - let indices = Tensor::::new(indices); - - let nnz = coordinates.shape().dims[1]; - let dim_coords = coordinates - .clone() - .slice([dim..dim + 1, 0..nnz]) - .squeeze::<1>(0); - let indices = indices.select(0, dim_coords); - let indices_len = indices.shape().num_elements(); - let coordinates = coordinates.slice_assign( - [dim..dim + 1, 0..nnz], - indices.unsqueeze::<2>().repeat(1, D), - ); - - shape.dims[dim] = indices_len; - - SparseCOOTensor { - coordinates: Some(coordinates), - values: Some(values), - shape, - device, - } - } - - fn sparse_select_assign( - _tensor: SparseTensor, - _dim: usize, - _indices: burn_tensor::ops::IntTensor, - _values: SparseTensor, - ) -> SparseTensor { - todo!() - } - - fn sparse_gather( - _dim: usize, - _tensor: SparseTensor, - _indices: burn_tensor::ops::IntTensor, - ) -> SparseTensor { - todo!() - } - - fn sparse_scatter( - _dim: usize, - _tensor: SparseTensor, - _indices: burn_tensor::ops::IntTensor, - _values: SparseTensor, - ) -> SparseTensor { - todo!() - } - - fn sparse_sum(tensor: SparseTensor) -> SparseTensor { - tensor - .values - .map(|values| Self::sparse_to_sparse(values.sum().into_primitive().tensor())) - .unwrap_or(Self::sparse_empty(Shape::new([1]), &tensor.device)) - } - - fn sparse_sum_dim( - _tensor: SparseTensor, - _dim: usize, - ) -> SparseTensor { - panic!("sparse_sum_dim unsupported for SparseCOO"); - } - - fn sparse_prod(tensor: SparseTensor) -> SparseTensor { - if tensor.coordinates.is_none() && tensor.values.is_none() { - return Self::sparse_empty(Shape::new([1]), &tensor.device); - } - - let coordinates = tensor - 
.coordinates - .expect("Mismatch between coordinates and values"); - let values = tensor - .values - .expect("Mismatch between coordinates and values"); - let device = tensor.device; - let shape = tensor.shape; - - if shape.num_elements() != coordinates.dims()[1] { - Self::sparse_empty(Shape::new([1]), &device) - } else { - Self::sparse_to_sparse(values.sum().into_primitive().tensor()) - } - } - - fn sparse_prod_dim( - _tensor: SparseTensor, - _dim: usize, - ) -> SparseTensor { - panic!("sparse_prod_dim is not supported for SparseCOO until scatter supports product reduction") - } - - fn sparse_mean(tensor: SparseTensor) -> SparseTensor { - tensor - .values - .map(|values| { - let elems = values.shape().num_elements(); - Self::sparse_to_sparse((values.sum() / elems as f32).into_primitive().tensor()) - }) - .unwrap_or(Self::sparse_empty(Shape::new([1]), &tensor.device)) - } - - fn sparse_mean_dim( - _tensor: SparseTensor, - _dim: usize, - ) -> SparseTensor { - panic!("mean_dim is not supported for SparseCOO until scatter supports mean reduction"); - } - - fn sparse_equal_elem( - _lhs: SparseTensor, - _rhs: FloatElem, - ) -> burn_tensor::ops::BoolTensor { - panic!("sparse_equal_elem is not supported for SparseCOO as it outputs a dense tensor"); - } - - fn sparse_not_equal_elem( - _lhs: SparseTensor, - _rhs: FloatElem, - ) -> burn_tensor::ops::BoolTensor { - panic!("sparse_not_equal_elem is not supported for SparseCOO as it outputs a dense tensor"); - } - - fn sparse_remainder_scalar( - mut lhs: SparseTensor, - rhs: FloatElem, - ) -> SparseTensor { - lhs.values = lhs.values.map(|v| v.remainder_scalar(rhs)); - lhs - } - - fn sparse_neg(mut tensor: SparseTensor) -> SparseTensor { - tensor.values = tensor.values.map(|v| v.neg()); - tensor - } - - fn sparse_sign(mut tensor: SparseTensor) -> SparseTensor { - tensor.values = tensor.values.map(|values| values.sign()); - tensor - } - - fn sparse_remove_zeros( - tensor: Self::SparseTensorPrimitive, - ) -> Self::SparseTensorPrimitive { - if tensor.coordinates.is_none() && tensor.values.is_none() { - return tensor; - } - - let _coordinates = tensor - .coordinates - .expect("Mismatch between coordinates and values"); - let _values = tensor - .values - .expect("Mismatch between coordinates and values"); - let _device = tensor.device; - let _shape = tensor.shape; - - // let zeros = tensor.values.map(|values| values.equal_elem(0).nonzero()); - todo!() - } -} diff --git a/crates/burn-sparse/old/decorator/sparse_csr.rs b/crates/burn-sparse/old/decorator/sparse_csr.rs deleted file mode 100644 index 04826fc10f..0000000000 --- a/crates/burn-sparse/old/decorator/sparse_csr.rs +++ /dev/null @@ -1,522 +0,0 @@ -use crate::backend::SparseBackend; -use crate::backend::SparseTensor; -use crate::decorator::SparseCSR; -use crate::decorator::SparseDecorator; -use burn_tensor::backend::Backend; -use burn_tensor::ops::FloatElem; -use burn_tensor::ops::FloatTensor; -use core::marker::PhantomData; - -#[derive(Debug, Default, Clone)] -pub struct SparseCSRTensor { - _b: PhantomData, -} - -impl SparseBackend for SparseDecorator -where - B: Backend, -{ - type SparseTensorPrimitive = SparseCSRTensor; - - fn sparse_empty( - _shape: burn_tensor::Shape, - _device: &burn_tensor::Device, - ) -> SparseTensor { - todo!() - } - - fn sparse_to_sparse( - _dense: Self::FloatTensorPrimitive, - ) -> Self::SparseTensorPrimitive { - todo!() - } - - fn sparse_to_dense( - _sparse: Self::SparseTensorPrimitive, - ) -> Self::FloatTensorPrimitive { - todo!() - } - - fn sparse_spmm( - _lhs: 
Self::SparseTensorPrimitive, - _rhs: Self::FloatTensorPrimitive, - ) -> Self::FloatTensorPrimitive { - todo!() - } - - fn sparse_sddmm( - _lhs: Self::FloatTensorPrimitive, - _rhs: Self::FloatTensorPrimitive, - _sparse: Self::SparseTensorPrimitive, - ) -> Self::SparseTensorPrimitive { - todo!() - } - - fn sparse_slice( - _tensor: SparseTensor, - _indices: [std::ops::Range; D2], - ) -> SparseTensor { - todo!() - } - - fn sparse_device(_tensor: &SparseTensor) -> burn_tensor::Device { - todo!() - } - - fn sparse_to_device( - _tensor: SparseTensor, - _device: &burn_tensor::Device, - ) -> SparseTensor { - todo!() - } - - fn sparse_shape(_tensor: &SparseTensor) -> burn_tensor::Shape { - todo!() - } - - async fn sparse_into_data( - _tensor: SparseTensor, - ) -> burn_tensor::TensorData { todo!() } - - fn sparse_from_data( - _data: burn_tensor::TensorData, - _device: &burn_tensor::Device, - ) -> SparseTensor { - todo!() - } - - fn sparse_reshape( - _tensor: SparseTensor, - _shape: burn_tensor::Shape, - ) -> SparseTensor { - todo!() - } - - fn sparse_transpose(_tensor: SparseTensor) -> SparseTensor { - todo!() - } - - fn sparse_swap_dims( - _tensor: SparseTensor, - _dim1: usize, - _dim2: usize, - ) -> SparseTensor { - todo!() - } - - fn sparse_permute( - _tensor: SparseTensor, - _axes: &[usize], - ) -> SparseTensor { - todo!() - } - - fn sparse_flip( - _tensor: SparseTensor, - _axes: &[usize], - ) -> SparseTensor { - todo!() - } - - fn sparse_slice_assign( - _tensor: SparseTensor, - _ranges: [std::ops::Range; D2], - _value: SparseTensor, - ) -> SparseTensor { - todo!() - } - - fn sparse_repeat( - _tensor: SparseTensor, - _dim: usize, - _times: usize, - ) -> SparseTensor { - todo!() - } - - fn sparse_cat( - _tensors: Vec>, - _dim: usize, - ) -> SparseTensor { - todo!() - } - - fn sparse_equal( - _lhs: SparseTensor, - _rhs: SparseTensor, - ) -> burn_tensor::ops::BoolTensor { - todo!() - } - - fn sparse_not_equal( - _lhs: SparseTensor, - _rhs: SparseTensor, - ) -> burn_tensor::ops::BoolTensor { - todo!() - } - - fn sparse_any( - _tensor: SparseTensor, - ) -> burn_tensor::ops::BoolTensor { - todo!() - } - - fn sparse_any_dim( - _tensor: SparseTensor, - _dim: usize, - ) -> burn_tensor::ops::BoolTensor { - todo!() - } - - fn sparse_all( - _tensor: SparseTensor, - ) -> burn_tensor::ops::BoolTensor { - todo!() - } - - fn sparse_all_dim( - _tensor: SparseTensor, - _dim: usize, - ) -> burn_tensor::ops::BoolTensor { - todo!() - } - - fn sparse_expand( - _tensor: SparseTensor, - _shape: burn_tensor::Shape, - ) -> SparseTensor { - todo!() - } - - fn sparse_coalesce_sum( - _tensor: Self::SparseTensorPrimitive, - ) -> Self::SparseTensorPrimitive { - todo!() - } - - fn sparse_nonzero(_tensor: Self::SparseTensorPrimitive) -> usize { - todo!() - } - - fn sparse_density(_sparse: Self::SparseTensorPrimitive) -> f32 { - todo!() - } - - fn sparse_add( - _lhs: SparseTensor, - _rhs: SparseTensor, - ) -> SparseTensor { - todo!() - } - - fn sparse_add_scalar( - _lhs: SparseTensor, - _rhs: FloatElem, - ) -> SparseTensor { - todo!() - } - - fn sparse_add_dense( - _lhs: SparseTensor, - _rhs: burn_tensor::ops::FloatTensor, - ) -> FloatTensor { - todo!() - } - - fn sparse_sub( - _lhs: SparseTensor, - _rhs: SparseTensor, - ) -> SparseTensor { - todo!() - } - - fn sparse_sub_dense( - _lhs: SparseTensor, - _rhs: burn_tensor::ops::FloatTensor, - ) -> FloatTensor { - todo!() - } - - fn sparse_sub_scalar( - _lhs: SparseTensor, - _rhs: FloatElem, - ) -> SparseTensor { - todo!() - } - - fn sparse_mul( - _lhs: SparseTensor, - _rhs: 
SparseTensor, - ) -> SparseTensor { - todo!() - } - - fn sparse_mul_dense( - _lhs: SparseTensor, - _rhs: burn_tensor::ops::FloatTensor, - ) -> FloatTensor { - todo!() - } - - fn sparse_mul_scalar( - _lhs: SparseTensor, - _rhs: FloatElem, - ) -> SparseTensor { - todo!() - } - - fn sparse_div( - _lhs: SparseTensor, - _rhs: SparseTensor, - ) -> SparseTensor { - todo!() - } - - fn sparse_div_dense( - _lhs: SparseTensor, - _rhs: burn_tensor::ops::FloatTensor, - ) -> FloatTensor { - todo!() - } - - fn sparse_div_scalar( - _lhs: SparseTensor, - _rhs: FloatElem, - ) -> SparseTensor { - todo!() - } - - fn sparse_max(_tensor: SparseTensor) -> SparseTensor { - todo!() - } - - fn sparse_max_dim( - _tensor: SparseTensor, - _dim: usize, - ) -> SparseTensor { - todo!() - } - - fn sparse_min(_tensor: SparseTensor) -> SparseTensor { - todo!() - } - - fn sparse_min_dim( - _tensor: SparseTensor, - _dim: usize, - ) -> SparseTensor { - todo!() - } - - fn sparse_greater( - _lhs: SparseTensor, - _rhs: SparseTensor, - ) -> burn_tensor::ops::BoolTensor { - todo!() - } - - fn sparse_greater_elem( - _lhs: SparseTensor, - _rhs: FloatElem, - ) -> burn_tensor::ops::BoolTensor { - todo!() - } - - fn sparse_greater_equal( - _lhs: SparseTensor, - _rhs: SparseTensor, - ) -> burn_tensor::ops::BoolTensor { - todo!() - } - - fn sparse_greater_equal_elem( - _lhs: SparseTensor, - _rhs: FloatElem, - ) -> burn_tensor::ops::BoolTensor { - todo!() - } - - fn sparse_lower( - _lhs: SparseTensor, - _rhs: SparseTensor, - ) -> burn_tensor::ops::BoolTensor { - todo!() - } - - fn sparse_lower_elem( - _lhs: SparseTensor, - _rhs: FloatElem, - ) -> burn_tensor::ops::BoolTensor { - todo!() - } - - fn sparse_lower_equal( - _lhs: SparseTensor, - _rhs: SparseTensor, - ) -> burn_tensor::ops::BoolTensor { - todo!() - } - - fn sparse_lower_equal_elem( - _lhs: SparseTensor, - _rhs: FloatElem, - ) -> burn_tensor::ops::BoolTensor { - todo!() - } - - fn sparse_abs(_tensor: SparseTensor) -> SparseTensor { - todo!() - } - - fn sparse_powf( - _lhs: SparseTensor, - _rhs: SparseTensor, - ) -> SparseTensor { - todo!() - } - - fn sparse_powi( - _lhs: SparseTensor, - _rhs: SparseTensor, - ) -> SparseTensor { - todo!() - } - - fn sparse_powf_scalar( - _lhs: SparseTensor, - _rhs: FloatElem, - ) -> SparseTensor { - todo!() - } - - fn sparse_powi_scalar( - _lhs: SparseTensor, - _rhs: FloatElem, - ) -> SparseTensor { - todo!() - } - - fn sparse_clamp( - _tensor: SparseTensor, - _min: FloatElem, - _max: FloatElem, - ) -> SparseTensor { - todo!() - } - - fn sparse_clamp_min( - _tensor: SparseTensor, - _min: FloatElem, - ) -> SparseTensor { - todo!() - } - - fn sparse_clamp_max( - _tensor: SparseTensor, - _max: FloatElem, - ) -> SparseTensor { - todo!() - } - - fn sparse_select( - _tensor: SparseTensor, - _dim: usize, - _indices: burn_tensor::ops::IntTensor, - ) -> SparseTensor { - todo!() - } - - fn sparse_select_assign( - _tensor: SparseTensor, - _dim: usize, - _indices: burn_tensor::ops::IntTensor, - _values: SparseTensor, - ) -> SparseTensor { - todo!() - } - - fn sparse_gather( - _dim: usize, - _tensor: SparseTensor, - _indices: burn_tensor::ops::IntTensor, - ) -> SparseTensor { - todo!() - } - - fn sparse_scatter( - _dim: usize, - _tensor: SparseTensor, - _indices: burn_tensor::ops::IntTensor, - _values: SparseTensor, - ) -> SparseTensor { - todo!() - } - - fn sparse_sum(_tensor: SparseTensor) -> SparseTensor { - todo!() - } - - fn sparse_sum_dim( - _tensor: SparseTensor, - _dim: usize, - ) -> SparseTensor { - todo!() - } - - fn sparse_prod(_tensor: 
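
Every method of this CSR decorator is still `todo!()`. For reference while reading the stubs: CSR stores a matrix as row pointers, column indices, and values, with the nonzeros of row r living in the half-open range row_ptrs[r]..row_ptrs[r + 1]. A minimal scalar mat-vec sketch under that standard layout (not burn's API):

    struct Csr {
        row_ptrs: Vec<usize>, // len = rows + 1
        cols: Vec<usize>,
        vals: Vec<f32>,
    }

    impl Csr {
        fn matvec(&self, x: &[f32]) -> Vec<f32> {
            let rows = self.row_ptrs.len() - 1;
            let mut y = vec![0.0; rows];
            for r in 0..rows {
                // Nonzeros of row r: vals[row_ptrs[r]..row_ptrs[r + 1]].
                for i in self.row_ptrs[r]..self.row_ptrs[r + 1] {
                    y[r] += self.vals[i] * x[self.cols[i]];
                }
            }
            y
        }
    }

    fn main() {
        // [[2, 0], [0, 3]] in CSR form.
        let m = Csr { row_ptrs: vec![0, 1, 2], cols: vec![0, 1], vals: vec![2.0, 3.0] };
        assert_eq!(m.matvec(&[1.0, 2.0]), vec![2.0, 6.0]);
    }
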
SparseTensor) -> SparseTensor { - todo!() - } - - fn sparse_prod_dim( - _tensor: SparseTensor, - _dim: usize, - ) -> SparseTensor { - todo!() - } - - fn sparse_mean(_tensor: SparseTensor) -> SparseTensor { - todo!() - } - - fn sparse_mean_dim( - _tensor: SparseTensor, - _dim: usize, - ) -> SparseTensor { - todo!() - } - - fn sparse_equal_elem( - _lhs: SparseTensor, - _rhs: FloatElem, - ) -> burn_tensor::ops::BoolTensor { - todo!() - } - - fn sparse_not_equal_elem( - _lhs: SparseTensor, - _rhs: FloatElem, - ) -> burn_tensor::ops::BoolTensor { - todo!() - } - - fn sparse_remainder_scalar( - _lhs: SparseTensor, - _rhs: FloatElem, - ) -> SparseTensor { - todo!() - } - - fn sparse_neg(_tensor: SparseTensor) -> SparseTensor { - todo!() - } - - fn sparse_sign(_tensor: SparseTensor) -> SparseTensor { - todo!() - } - - fn sparse_remove_zeros( - _tensor: Self::SparseTensorPrimitive, - ) -> Self::SparseTensorPrimitive { - todo!() - } -} diff --git a/crates/burn-sparse/src/coo_bool.rs b/crates/burn-sparse/src/coo_bool.rs index 519d1b0124..4eb392da82 100644 --- a/crates/burn-sparse/src/coo_bool.rs +++ b/crates/burn-sparse/src/coo_bool.rs @@ -88,19 +88,6 @@ impl SparseBoolOps for COO { todo!() } - fn bool_into_data( - tensor: >::SparsePrimitive, - ) -> impl std::future::Future + Send { - async { todo!() } - } - - fn bool_from_data( - data: burn_tensor::TensorData, - device: &burn_tensor::Device, - ) -> >::SparsePrimitive { - todo!() - } - fn bool_repeat_dim( tensor: >::SparsePrimitive, dim: usize, diff --git a/crates/burn-sparse/src/coo_float.rs b/crates/burn-sparse/src/coo_float.rs index 8149cd8f04..6075eb62d8 100644 --- a/crates/burn-sparse/src/coo_float.rs +++ b/crates/burn-sparse/src/coo_float.rs @@ -403,19 +403,6 @@ impl SparseFloatOps for COO { tensor.shape.clone() } - fn float_into_data( - tensor: >::SparsePrimitive, - ) -> impl std::future::Future + Send { - async { todo!() } - } - - fn float_from_data( - data: burn_tensor::TensorData, - device: &burn_tensor::Device, - ) -> >::SparsePrimitive { - todo!() - } - fn float_reshape( tensor: >::SparsePrimitive, out_shape: burn_tensor::Shape, diff --git a/crates/burn-sparse/src/coo_int.rs b/crates/burn-sparse/src/coo_int.rs index 126e439fee..85120b8096 100644 --- a/crates/burn-sparse/src/coo_int.rs +++ b/crates/burn-sparse/src/coo_int.rs @@ -78,19 +78,6 @@ impl SparseIntOps for COO { todo!() } - fn int_into_data( - tensor: >::SparsePrimitive, - ) -> impl std::future::Future + Send { - async { todo!() } - } - - fn int_from_data( - data: burn_tensor::TensorData, - device: &burn_tensor::Device, - ) -> >::SparsePrimitive { - todo!() - } - fn int_repeat_dim( tensor: >::SparsePrimitive, dim: usize, From 7b902525f9ab766b557633ac7915526bd6dafebe Mon Sep 17 00:00:00 2001 From: mcarthur Date: Sun, 25 Aug 2024 09:58:40 +0000 Subject: [PATCH 35/38] Removed unsupported sparse ops --- crates/burn-core/src/backend.rs | 2 +- crates/burn-core/src/tensor.rs | 5 -- crates/burn-tensor/src/tensor/api/sparse.rs | 18 +++++--- .../src/tensor/api/sparse_float.rs | 12 ++++- .../src/tensor/ops/sparse_tensor.rs | 46 ------------------- 5 files changed, 23 insertions(+), 60 deletions(-) diff --git a/crates/burn-core/src/backend.rs b/crates/burn-core/src/backend.rs index b44b37ecf8..bed5a09ddc 100644 --- a/crates/burn-core/src/backend.rs +++ b/crates/burn-core/src/backend.rs @@ -35,4 +35,4 @@ pub use burn_tch as libtorch; pub use burn_tch::LibTorch; #[cfg(feature = "sparse")] -pub use burn_sparse::decorator as sparse; +pub use burn_sparse as sparse; diff --git 
a/crates/burn-core/src/tensor.rs b/crates/burn-core/src/tensor.rs index ecc858ebbe..074606bb14 100644 --- a/crates/burn-core/src/tensor.rs +++ b/crates/burn-core/src/tensor.rs @@ -1,6 +1 @@ pub use burn_tensor::*; - -#[cfg(feature = "sparse")] -pub mod sparse { - pub use burn_sparse::backend::*; -} diff --git a/crates/burn-tensor/src/tensor/api/sparse.rs b/crates/burn-tensor/src/tensor/api/sparse.rs index 75df0a51f2..a6c20fae9a 100644 --- a/crates/burn-tensor/src/tensor/api/sparse.rs +++ b/crates/burn-tensor/src/tensor/api/sparse.rs @@ -117,14 +117,16 @@ impl> BasicOps> for Float { fn into_data_async( tensor: ReprPrimitive, D>, ) -> impl Future + Send { - SR::float_into_data(tensor) + async { + panic!("into_data not supported for sparse tensors, convert to dense first."); + } } fn from_data( data: TensorData, device: &::Device, ) -> ReprPrimitive, D> { - SR::float_from_data(data, device) + panic!("from_data not supported for sparse tensors, convert from dense first."); } fn repeat_dim( @@ -270,14 +272,16 @@ impl> BasicOps> for Bool { fn into_data_async( tensor: ReprPrimitive, D>, ) -> impl Future + Send { - SR::bool_into_data(tensor) + async { + panic!("into_data not supported for sparse tensors, convert to dense first."); + } } fn from_data( data: TensorData, device: &::Device, ) -> ReprPrimitive, D> { - SR::bool_from_data(data, device) + panic!("from_data not supported for sparse tensors, convert from dense first."); } fn repeat_dim( @@ -423,14 +427,16 @@ impl> BasicOps> for Int { fn into_data_async( tensor: ReprPrimitive, D>, ) -> impl Future + Send { - SR::int_into_data(tensor) + async { + panic!("into_data not supported for sparse tensors, convert to dense first."); + } } fn from_data( data: TensorData, device: &::Device, ) -> ReprPrimitive, D> { - SR::int_from_data(data, device) + panic!("from_data not supported for sparse tensors, convert from dense first."); } fn repeat_dim( diff --git a/crates/burn-tensor/src/tensor/api/sparse_float.rs b/crates/burn-tensor/src/tensor/api/sparse_float.rs index b675c5d86a..ee35c5e398 100644 --- a/crates/burn-tensor/src/tensor/api/sparse_float.rs +++ b/crates/burn-tensor/src/tensor/api/sparse_float.rs @@ -35,8 +35,16 @@ where pub fn spmm(self, rhs: Tensor) -> Tensor { // check!(TensorCheck::spmm(&self, &rhs)); Tensor::::new(TensorPrimitive::Float(SR::float_spmm( - self.primitive, - rhs.primitive, + self.into_primitive(), + rhs.into_primitive(), ))) } + + pub fn sddmm(self, lhs: Tensor, rhs: Tensor) -> Self { + Tensor::new(SR::float_sddmm( + lhs.into_primitive().tensor(), + rhs.into_primitive().tensor(), + self.into_primitive(), + )) + } } diff --git a/crates/burn-tensor/src/tensor/ops/sparse_tensor.rs b/crates/burn-tensor/src/tensor/ops/sparse_tensor.rs index c3f70e2034..c6200102a7 100644 --- a/crates/burn-tensor/src/tensor/ops/sparse_tensor.rs +++ b/crates/burn-tensor/src/tensor/ops/sparse_tensor.rs @@ -108,34 +108,6 @@ where /// The shape of the tensor. fn float_shape(tensor: &ReprPrimitive, D>) -> Shape; - /// Converts the tensor to a data structure. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// - /// # Returns - /// - /// The data structure with the tensor's data. fn float_into_data( tensor: ReprPrimitive, D>, ) -> impl Future + Send; - /// Creates a tensor from the data structure. - /// - /// # Arguments - /// - /// * `data` - The data structure. - /// * `device` - The device to create the tensor on. - /// - /// # Returns - /// - /// The tensor with the data.
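
The replacement bodies above lean on the fact that an `async` block that only panics still satisfies `impl Future<Output = T> + Send` for any `T`, because `panic!` diverges and the never type coerces to the required output. A minimal standalone demonstration of why such a body type-checks:

    use std::future::Future;

    // Same shape as the patched `into_data_async` bodies: the block never
    // produces a String, but diverging via panic! unifies with any Output.
    fn into_data_async() -> impl Future<Output = String> + Send {
        async {
            panic!("into_data not supported for sparse tensors, convert to dense first.");
        }
    }

    fn main() {
        // Constructing the future is fine; only polling it would panic.
        let _future = into_data_async();
    }
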
- fn float_from_data( - data: TensorData, - device: &Device, - ) -> ReprPrimitive, D>; - fn float_reshape( tensor: SR::SparsePrimitive, shape: Shape, @@ -476,15 +448,6 @@ pub trait SparseBoolOps, B: Backend> { device: &Device, ) -> SR::SparsePrimitive; - fn bool_into_data( - tensor: SR::SparsePrimitive, - ) -> impl Future + Send; - - fn bool_from_data( - data: TensorData, - device: &Device, - ) -> SR::SparsePrimitive; - fn bool_repeat_dim( tensor: SR::SparsePrimitive, dim: usize, @@ -581,15 +544,6 @@ pub trait SparseIntOps, B: Backend> { device: &Device, ) -> SR::SparsePrimitive; - fn int_into_data( - tensor: SR::SparsePrimitive, - ) -> impl Future + Send; - - fn int_from_data( - data: TensorData, - device: &Device, - ) -> SR::SparsePrimitive; - fn int_repeat_dim( tensor: SR::SparsePrimitive, dim: usize, From e00de1e5c312966349b59463c6f3c81c36106d8a Mon Sep 17 00:00:00 2001 From: McArthur-Alford Date: Fri, 30 Aug 2024 13:04:51 +1000 Subject: [PATCH 36/38] Coordinates OP, plus basicsparse for float/int --- crates/burn-sparse/src/coo_bool.rs | 19 +++++++ crates/burn-sparse/src/coo_float.rs | 10 +++- crates/burn-sparse/src/coo_int.rs | 41 ++++++++++++++ crates/burn-tensor/src/tensor/api/sparse.rs | 56 +++++++++++++++++++ .../burn-tensor/src/tensor/api/sparse_int.rs | 1 + .../src/tensor/api/sparse_tensor.rs | 12 +++- .../src/tensor/ops/sparse_tensor.rs | 29 +++++++++- 7 files changed, 162 insertions(+), 6 deletions(-) create mode 100644 crates/burn-tensor/src/tensor/api/sparse_int.rs diff --git a/crates/burn-sparse/src/coo_bool.rs b/crates/burn-sparse/src/coo_bool.rs index 4eb392da82..151c84d0a3 100644 --- a/crates/burn-sparse/src/coo_bool.rs +++ b/crates/burn-sparse/src/coo_bool.rs @@ -1,9 +1,16 @@ use super::coo::COO; +use crate::SparseCOOTensor; +use crate::{flatten_coordinates, unflatten_coordinates}; +use burn_tensor::Int; +use burn_tensor::ReprPrimitive; +use burn_tensor::Shape; +use burn_tensor::Tensor; use burn_tensor::{ backend::Backend, ops::{SparseBoolOps, SparseTensorOps}, SparseStorage, }; +use burn_tensor::{Bool, Dense}; impl SparseBoolOps for COO { fn bool_to_sparse( @@ -149,4 +156,16 @@ impl SparseBoolOps for COO { ) -> >::SparsePrimitive { todo!() } + + fn bool_coordinates( + mut tensor: >::SparsePrimitive, + ) -> Option> { + tensor.coordinates.map(|c| c.into_primitive()) + } + + fn bool_to_dense( + sparse: >::SparsePrimitive, + ) -> B::BoolTensorPrimitive { + todo!() + } } diff --git a/crates/burn-sparse/src/coo_float.rs b/crates/burn-sparse/src/coo_float.rs index 6075eb62d8..b5da5978fb 100644 --- a/crates/burn-sparse/src/coo_float.rs +++ b/crates/burn-sparse/src/coo_float.rs @@ -3,8 +3,8 @@ use burn_tensor::cast::ToElement; use burn_tensor::ops::{FloatElem, SparseBoolOps}; use burn_tensor::{backend::Backend, ops::SparseFloatOps, Tensor}; use burn_tensor::{ - Bool, ElementConversion, Float, Shape, Sparse, SparseStorage, TensorData, TensorKind, - TensorPrimitive, + Bool, Dense, ElementConversion, Float, ReprPrimitive, Shape, Sparse, SparseStorage, TensorData, + TensorKind, TensorPrimitive, }; use burn_tensor::{Device, Int}; @@ -1011,4 +1011,10 @@ impl SparseFloatOps for COO { tensor.values = tensor.values.map(|values| values.neg()); tensor } + + fn float_coordinates( + mut tensor: >::SparsePrimitive, + ) -> Option> { + tensor.coordinates.map(|c| c.into_primitive()) + } } diff --git a/crates/burn-sparse/src/coo_int.rs b/crates/burn-sparse/src/coo_int.rs index 85120b8096..0a7f87a134 100644 --- a/crates/burn-sparse/src/coo_int.rs +++ b/crates/burn-sparse/src/coo_int.rs @@ -1,4 
+1,11 @@ use super::coo::COO; +use crate::SparseCOOTensor; +use crate::{flatten_coordinates, unflatten_coordinates}; +use burn_tensor::Dense; +use burn_tensor::Int; +use burn_tensor::ReprPrimitive; +use burn_tensor::Shape; +use burn_tensor::Tensor; use burn_tensor::{backend::Backend, ops::SparseIntOps, SparseStorage}; impl SparseIntOps for COO { @@ -139,4 +146,38 @@ impl SparseIntOps for COO { ) -> >::SparsePrimitive { todo!() } + + fn int_coordinates( + mut tensor: >::SparsePrimitive, + ) -> Option> { + tensor.coordinates.map(|c| c.into_primitive()) + } + + fn int_to_dense( + sparse: >::SparsePrimitive, + ) -> B::IntTensorPrimitive { + let SparseCOOTensor { + coordinates, + values, + shape, + device, + } = sparse; + + let (Some(coordinates), Some(values)) = (coordinates, values) else { + return Tensor::::zeros(shape, &device).into_primitive(); + }; + + let dense: Tensor = Tensor::zeros(Shape::new([shape.num_elements()]), &device); + let flat_coordinates = + flatten_coordinates::(coordinates, shape.clone(), &device).squeeze(0); + let dense = dense.select_assign(0, flat_coordinates, values); + + dense.reshape(shape).into_primitive() + } + + fn int_to_sparse( + dense: ::IntTensorPrimitive, + ) -> >::SparsePrimitive { + todo!() + } } diff --git a/crates/burn-tensor/src/tensor/api/sparse.rs b/crates/burn-tensor/src/tensor/api/sparse.rs index a6c20fae9a..1636b60096 100644 --- a/crates/burn-tensor/src/tensor/api/sparse.rs +++ b/crates/burn-tensor/src/tensor/api/sparse.rs @@ -18,6 +18,10 @@ where fn into_sparse( tensor: ReprPrimitive, ) -> ReprPrimitive, D>; + + fn coordinates( + tensor: ReprPrimitive, D>, + ) -> Option>; } impl> BasicSparseOps for SR @@ -35,6 +39,58 @@ where ) -> ReprPrimitive, D> { SR::float_to_sparse(tensor.tensor()) } + + fn coordinates( + tensor: ReprPrimitive, D>, + ) -> Option> { + SR::float_coordinates(tensor) + } +} + +impl> BasicSparseOps for SR +where + (B, Int, Sparse): TensorRepr, +{ + fn into_dense( + tensor: ReprPrimitive, D>, + ) -> ReprPrimitive { + SR::int_to_dense(tensor) + } + + fn into_sparse( + tensor: ReprPrimitive, + ) -> ReprPrimitive, D> { + SR::int_to_sparse(tensor) + } + + fn coordinates( + tensor: ReprPrimitive, D>, + ) -> Option> { + SR::int_coordinates(tensor) + } +} + +impl> BasicSparseOps for SR +where + (B, Bool, Sparse): TensorRepr, +{ + fn into_dense( + tensor: ReprPrimitive, D>, + ) -> ReprPrimitive { + SR::bool_to_dense(tensor) + } + + fn into_sparse( + tensor: ReprPrimitive, + ) -> ReprPrimitive, D> { + SR::bool_to_sparse(tensor) + } + + fn coordinates( + tensor: ReprPrimitive, D>, + ) -> Option> { + SR::bool_coordinates(tensor) + } } impl> BasicOps> for Float { diff --git a/crates/burn-tensor/src/tensor/api/sparse_int.rs b/crates/burn-tensor/src/tensor/api/sparse_int.rs new file mode 100644 index 0000000000..8b13789179 --- /dev/null +++ b/crates/burn-tensor/src/tensor/api/sparse_int.rs @@ -0,0 +1 @@ + diff --git a/crates/burn-tensor/src/tensor/api/sparse_tensor.rs b/crates/burn-tensor/src/tensor/api/sparse_tensor.rs index 65fd160e68..067f9fcb16 100644 --- a/crates/burn-tensor/src/tensor/api/sparse_tensor.rs +++ b/crates/burn-tensor/src/tensor/api/sparse_tensor.rs @@ -1,5 +1,7 @@ use crate::{backend::Backend, check::TensorCheck, Dense, Float, Sparse, Tensor, TensorKind}; -use crate::{check, BasicOps, BasicSparseOps, Bool, SparseStorage, TensorPrimitive, TensorRepr}; +use crate::{ + check, BasicOps, BasicSparseOps, Bool, Int, SparseStorage, TensorPrimitive, TensorRepr, +}; impl Tensor where @@ -25,6 +27,12 @@ where (B, K, Sparse): TensorRepr, 
{ pub fn into_dense(self) -> Tensor { - Tensor::::from_primitive(SR::into_dense(self.primitive)) + Tensor::::from_primitive(SR::into_dense(self.into_primitive())) + } + + pub fn coordinates(self) -> Option> { + Some(Tensor::::from_primitive(SR::coordinates( + self.into_primitive(), + )?)) } } diff --git a/crates/burn-tensor/src/tensor/ops/sparse_tensor.rs b/crates/burn-tensor/src/tensor/ops/sparse_tensor.rs index c6200102a7..b21011070c 100644 --- a/crates/burn-tensor/src/tensor/ops/sparse_tensor.rs +++ b/crates/burn-tensor/src/tensor/ops/sparse_tensor.rs @@ -1,9 +1,9 @@ use super::{BoolTensor, FloatElem, FloatTensor, IntTensor, QuantizedTensor}; -use crate::TensorRepr; use crate::{ backend::Backend, Bool, Device, Float, Int, ReprPrimitive, Shape, Sparse, SparseStorage, TensorData, TensorKind, }; +use crate::{Dense, TensorRepr}; use core::{future::Future, ops::Range}; pub trait SparseTensorOps, B: Backend>: @@ -15,7 +15,12 @@ pub trait SparseFloatOps, B: Backend> where (B, Float, Sparse): TensorRepr, (B, Bool, Sparse): TensorRepr, + (B, Int, Sparse): TensorRepr, { + fn float_coordinates( + sparse: ReprPrimitive, D>, + ) -> Option>; + fn float_to_sparse( dense: B::FloatTensorPrimitive, ) -> ReprPrimitive, D>; @@ -394,9 +399,17 @@ where } pub trait SparseBoolOps, B: Backend> { + fn bool_coordinates( + sparse: ReprPrimitive, D>, + ) -> Option>; + + fn bool_to_dense( + sparse: ReprPrimitive, D>, + ) -> B::BoolTensorPrimitive; + fn bool_to_sparse( dense: B::BoolTensorPrimitive, - ) -> SR::SparsePrimitive; + ) -> ReprPrimitive, D>; fn bool_empty( shape: Shape, @@ -494,6 +507,18 @@ pub trait SparseBoolOps, B: Backend> { } pub trait SparseIntOps, B: Backend> { + fn int_coordinates( + sparse: ReprPrimitive, D>, + ) -> Option>; + + fn int_to_dense( + sparse: ReprPrimitive, D>, + ) -> B::IntTensorPrimitive; + + fn int_to_sparse( + dense: B::IntTensorPrimitive, + ) -> ReprPrimitive, D>; + fn int_empty( shape: Shape, device: &Device, From d8603f3fc83772481ab1620babbe8725b9be585f Mon Sep 17 00:00:00 2001 From: mcarthur Date: Fri, 30 Aug 2024 03:08:02 +0000 Subject: [PATCH 37/38] Removed unsupported ops --- crates/burn-sparse/src/coo_bool.rs | 14 --------- crates/burn-sparse/src/coo_float.rs | 14 --------- crates/burn-sparse/src/coo_int.rs | 14 --------- crates/burn-tensor/src/tensor/api/sparse.rs | 14 ++++----- .../src/tensor/api/sparse_numeric.rs | 28 +++++++++++++++-- .../src/tensor/ops/sparse_tensor.rs | 30 ------------------- 6 files changed, 31 insertions(+), 83 deletions(-) diff --git a/crates/burn-sparse/src/coo_bool.rs b/crates/burn-sparse/src/coo_bool.rs index 4eb392da82..338b9ef0f8 100644 --- a/crates/burn-sparse/src/coo_bool.rs +++ b/crates/burn-sparse/src/coo_bool.rs @@ -103,20 +103,6 @@ impl SparseBoolOps for COO { todo!() } - fn bool_equal( - lhs: >::SparsePrimitive, - rhs: >::SparsePrimitive, - ) -> >::SparsePrimitive { - todo!() - } - - fn bool_not_equal( - lhs: >::SparsePrimitive, - rhs: >::SparsePrimitive, - ) -> >::SparsePrimitive { - todo!() - } - fn bool_any( tensor: >::SparsePrimitive, ) -> >::SparsePrimitive { diff --git a/crates/burn-sparse/src/coo_float.rs b/crates/burn-sparse/src/coo_float.rs index 6075eb62d8..41485dce91 100644 --- a/crates/burn-sparse/src/coo_float.rs +++ b/crates/burn-sparse/src/coo_float.rs @@ -627,20 +627,6 @@ impl SparseFloatOps for COO { todo!() } - fn float_equal( - lhs: >::SparsePrimitive, - rhs: >::SparsePrimitive, - ) -> >::SparsePrimitive { - todo!() - } - - fn float_not_equal( - lhs: >::SparsePrimitive, - rhs: >::SparsePrimitive, - ) -> 
>::SparsePrimitive { - todo!() - } - fn float_any( tensor: >::SparsePrimitive, ) -> >::SparsePrimitive { todo!() } diff --git a/crates/burn-sparse/src/coo_int.rs b/crates/burn-sparse/src/coo_int.rs index 85120b8096..2190a465ab 100644 --- a/crates/burn-sparse/src/coo_int.rs +++ b/crates/burn-sparse/src/coo_int.rs @@ -93,20 +93,6 @@ impl SparseIntOps for COO { todo!() } - fn int_equal( - lhs: >::SparsePrimitive, - rhs: >::SparsePrimitive, - ) -> >::SparsePrimitive { - todo!() - } - - fn int_not_equal( - lhs: >::SparsePrimitive, - rhs: >::SparsePrimitive, - ) -> >::SparsePrimitive { - todo!() - } - fn int_any( tensor: >::SparsePrimitive, ) -> >::SparsePrimitive { todo!() } diff --git a/crates/burn-tensor/src/tensor/api/sparse.rs b/crates/burn-tensor/src/tensor/api/sparse.rs index a6c20fae9a..b24257a661 100644 --- a/crates/burn-tensor/src/tensor/api/sparse.rs +++ b/crates/burn-tensor/src/tensor/api/sparse.rs @@ -155,14 +155,14 @@ impl> BasicOps> for Float { lhs: ReprPrimitive, D>, rhs: ReprPrimitive, D>, ) -> Tensor> { - Tensor::new(SR::float_equal(lhs, rhs)) + panic!("equal is unsupported for sparse tensors as it is not zero-preserving"); } fn not_equal( lhs: ReprPrimitive, D>, rhs: ReprPrimitive, D>, ) -> Tensor> { - Tensor::new(SR::float_not_equal(lhs, rhs)) + panic!("not_equal is unsupported for sparse tensors as it is not zero-preserving"); } fn any( @@ -303,14 +303,14 @@ impl> BasicOps> for Bool { lhs: ReprPrimitive, D>, rhs: ReprPrimitive, D>, ) -> Tensor> { - panic!("Non-zero preserving operations are not supported for sparse tensors"); + panic!("equal is unsupported for sparse tensors as it is not zero-preserving"); } fn not_equal( lhs: ReprPrimitive, D>, rhs: ReprPrimitive, D>, ) -> Tensor> { - panic!("Non-zero preserving operations are not supported for sparse tensors"); + panic!("not_equal is unsupported for sparse tensors as it is not zero-preserving"); } fn any( @@ -458,16 +458,14 @@ impl> BasicOps> for Int { lhs: ReprPrimitive, D>, rhs: ReprPrimitive, D>, ) -> Tensor> { - panic!("Non-zero preserving operations are not supported for sparse tensors"); - Tensor::new(SR::int_equal(lhs, rhs)) + panic!("equal is unsupported for sparse tensors as it is not zero-preserving"); } fn not_equal( lhs: ReprPrimitive, D>, rhs: ReprPrimitive, D>, ) -> Tensor> { - panic!("Non-zero preserving operations are not supported for sparse tensors"); - Tensor::new(SR::int_not_equal(lhs, rhs)) + panic!("not_equal is unsupported for sparse tensors as it is not zero-preserving"); } fn any( diff --git a/crates/burn-tensor/src/tensor/api/sparse_numeric.rs b/crates/burn-tensor/src/tensor/api/sparse_numeric.rs index 9d084f148f..100cd6b38a 100644 --- a/crates/burn-tensor/src/tensor/api/sparse_numeric.rs +++ b/crates/burn-tensor/src/tensor/api/sparse_numeric.rs @@ -1,8 +1,30 @@ -use crate::{backend::Backend, BasicOps}; +use crate::check; -/// Trait that list all operations that can be applied on all numerical sparse tensors. +use crate::{ + backend::Backend, check::TensorCheck, BasicOps, Bool, Element, ElementConversion, Int, Shape, + Sparse, SparseStorage, Tensor, TensorKind, TensorRepr, +}; + +/// Trait that lists all operations that can be applied on all sparse numerical tensors. /// /// # Warnings /// /// This is an internal trait, use the public API provided by [tensor struct](Tensor).
-pub trait SparseNumeric: BasicOps {} +pub trait SparseNumeric: TensorRepr +where + B: Backend, + K: TensorKind + BasicOps, + SR: SparseStorage, + K::Elem: Element, +{ +} + +impl Tensor> +where + B: Backend, + K: TensorKind + BasicOps, + SR: SparseStorage, + (B, K, SR): SparseNumeric, + K::Elem: Element, +{ +} diff --git a/crates/burn-tensor/src/tensor/ops/sparse_tensor.rs b/crates/burn-tensor/src/tensor/ops/sparse_tensor.rs index c6200102a7..d1e6664601 100644 --- a/crates/burn-tensor/src/tensor/ops/sparse_tensor.rs +++ b/crates/burn-tensor/src/tensor/ops/sparse_tensor.rs @@ -150,16 +150,6 @@ where dim: usize, ) -> ReprPrimitive, D>; - fn float_equal( - lhs: ReprPrimitive, D>, - rhs: ReprPrimitive, D>, - ) -> SR::SparsePrimitive; - - fn float_not_equal( - lhs: ReprPrimitive, D>, - rhs: ReprPrimitive, D>, - ) -> SR::SparsePrimitive; - fn float_any( tensor: ReprPrimitive, D>, ) -> SR::SparsePrimitive; @@ -459,16 +449,6 @@ pub trait SparseBoolOps, B: Backend> { dim: usize, ) -> SR::SparsePrimitive; - fn bool_equal( - lhs: SR::SparsePrimitive, - rhs: SR::SparsePrimitive, - ) -> SR::SparsePrimitive; - - fn bool_not_equal( - lhs: SR::SparsePrimitive, - rhs: SR::SparsePrimitive, - ) -> SR::SparsePrimitive; - fn bool_any( tensor: SR::SparsePrimitive, ) -> SR::SparsePrimitive; @@ -555,16 +535,6 @@ pub trait SparseIntOps, B: Backend> { dim: usize, ) -> SR::SparsePrimitive; - fn int_equal( - lhs: SR::SparsePrimitive, - rhs: SR::SparsePrimitive, - ) -> SR::SparsePrimitive; - - fn int_not_equal( - lhs: SR::SparsePrimitive, - rhs: SR::SparsePrimitive, - ) -> SR::SparsePrimitive; - fn int_any(tensor: SR::SparsePrimitive) -> SR::SparsePrimitive; From cedd197de9cabf1f780aff6c143efc560c9477c3 Mon Sep 17 00:00:00 2001 From: mcarthur Date: Wed, 2 Oct 2024 02:41:36 +0000 Subject: [PATCH 38/38] values --- crates/burn-sparse/src/coo_bool.rs | 6 +++++ crates/burn-sparse/src/coo_float.rs | 6 +++++ crates/burn-sparse/src/coo_int.rs | 6 +++++ crates/burn-tensor/src/tensor/api/sparse.rs | 22 +++++++++++++++++++ .../src/tensor/api/sparse_tensor.rs | 6 +++++ .../src/tensor/ops/sparse_tensor.rs | 12 ++++++++++ 6 files changed, 58 insertions(+) diff --git a/crates/burn-sparse/src/coo_bool.rs b/crates/burn-sparse/src/coo_bool.rs index 5230251a77..1f63d837b8 100644 --- a/crates/burn-sparse/src/coo_bool.rs +++ b/crates/burn-sparse/src/coo_bool.rs @@ -154,4 +154,10 @@ impl SparseBoolOps for COO { ) -> B::BoolTensorPrimitive { todo!() } + + fn bool_values( + tensor: ReprPrimitive, D>, + ) -> Option> { + tensor.values.map(|v| v.into_primitive()) + } } diff --git a/crates/burn-sparse/src/coo_float.rs b/crates/burn-sparse/src/coo_float.rs index fcb35af460..fe34779812 100644 --- a/crates/burn-sparse/src/coo_float.rs +++ b/crates/burn-sparse/src/coo_float.rs @@ -1003,4 +1003,10 @@ impl SparseFloatOps for COO { ) -> Option> { tensor.coordinates.map(|c| c.into_primitive()) } + + fn float_values( + mut tensor: >::SparsePrimitive, + ) -> Option> { + tensor.values.map(|v| v.into_primitive()) + } } diff --git a/crates/burn-sparse/src/coo_int.rs b/crates/burn-sparse/src/coo_int.rs index 67d88eb681..67a9df94db 100644 --- a/crates/burn-sparse/src/coo_int.rs +++ b/crates/burn-sparse/src/coo_int.rs @@ -166,4 +166,10 @@ impl SparseIntOps for COO { ) -> >::SparsePrimitive { todo!() } + + fn int_values( + tensor: ReprPrimitive, D>, + ) -> Option> { + tensor.values.map(|v| v.into_primitive()) + } } diff --git a/crates/burn-tensor/src/tensor/api/sparse.rs b/crates/burn-tensor/src/tensor/api/sparse.rs index 05b9f81be9..457c701f80 100644 --- 
a/crates/burn-tensor/src/tensor/api/sparse.rs +++ b/crates/burn-tensor/src/tensor/api/sparse.rs @@ -22,6 +22,10 @@ where fn coordinates( tensor: ReprPrimitive, D>, ) -> Option>; + + fn values( + tensor: ReprPrimitive, D>, + ) -> Option>; } impl> BasicSparseOps for SR @@ -45,6 +49,12 @@ where ) -> Option> { SR::float_coordinates(tensor) } + + fn values( + tensor: ReprPrimitive, D>, + ) -> Option> { + SR::float_values(tensor) + } } impl> BasicSparseOps for SR @@ -68,6 +78,12 @@ where ) -> Option> { SR::int_coordinates(tensor) } + + fn values( + tensor: ReprPrimitive, D>, + ) -> Option> { + SR::int_values(tensor) + } } impl> BasicSparseOps for SR @@ -91,6 +107,12 @@ where ) -> Option> { SR::bool_coordinates(tensor) } + + fn values( + tensor: ReprPrimitive, D>, + ) -> Option> { + SR::bool_values(tensor) + } } impl> BasicOps> for Float { diff --git a/crates/burn-tensor/src/tensor/api/sparse_tensor.rs b/crates/burn-tensor/src/tensor/api/sparse_tensor.rs index 067f9fcb16..bf690f3e94 100644 --- a/crates/burn-tensor/src/tensor/api/sparse_tensor.rs +++ b/crates/burn-tensor/src/tensor/api/sparse_tensor.rs @@ -35,4 +35,10 @@ where self.into_primitive(), )?)) } + + pub fn values(self) -> Option> { + Some(Tensor::::from_primitive(SR::values( + self.into_primitive(), + )?)) + } } diff --git a/crates/burn-tensor/src/tensor/ops/sparse_tensor.rs b/crates/burn-tensor/src/tensor/ops/sparse_tensor.rs index 43c1f31296..ec835f5e54 100644 --- a/crates/burn-tensor/src/tensor/ops/sparse_tensor.rs +++ b/crates/burn-tensor/src/tensor/ops/sparse_tensor.rs @@ -17,6 +17,10 @@ where (B, Bool, Sparse): TensorRepr, (B, Int, Sparse): TensorRepr, { + fn float_values( + tensor: ReprPrimitive, D>, + ) -> Option>; + fn float_coordinates( sparse: ReprPrimitive, D>, ) -> Option>; @@ -389,6 +393,10 @@ where } pub trait SparseBoolOps, B: Backend> { + fn bool_values( + tensor: ReprPrimitive, D>, + ) -> Option>; + fn bool_coordinates( sparse: ReprPrimitive, D>, ) -> Option>; @@ -487,6 +495,10 @@ pub trait SparseBoolOps, B: Backend> { } pub trait SparseIntOps, B: Backend> { + fn int_values( + tensor: ReprPrimitive, D>, + ) -> Option>; + fn int_coordinates( sparse: ReprPrimitive, D>, ) -> Option>;
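
Note on the sddmm method introduced in PATCH 35: sampled dense-dense matrix multiplication evaluates the dense product lhs * rhs only at the stored coordinates of the sparse operand, scaling each sampled entry by the stored value. Below is a minimal standalone sketch of that contract, with plain slices standing in for Burn tensors; all names are illustrative, and the scaling semantics of float_sddmm are assumed to follow the usual SDDMM definition (out = values elementwise-times the sampled product).

// Minimal SDDMM sketch: out[n] = values[n] * dot(lhs[i_n, :], rhs[:, j_n]).
// `lhs` is m x k, `rhs` is k x n (row-major Vec-of-rows); `coords`/`values`
// give the sampling pattern of the sparse operand. Illustrative only.
fn sddmm(
    lhs: &[Vec<f32>],
    rhs: &[Vec<f32>],
    coords: &[(usize, usize)],
    values: &[f32],
) -> Vec<f32> {
    coords
        .iter()
        .zip(values.iter())
        .map(|(&(i, j), &v)| {
            // Dot product of row i of lhs with column j of rhs.
            let dot: f32 = (0..rhs.len()).map(|t| lhs[i][t] * rhs[t][j]).sum();
            v * dot
        })
        .collect()
}

fn main() {
    let lhs = vec![vec![1.0, 2.0], vec![3.0, 4.0]]; // 2 x 2
    let rhs = vec![vec![5.0, 6.0], vec![7.0, 8.0]]; // 2 x 2
    // Sample only the diagonal; stored values scale the sampled products.
    let out = sddmm(&lhs, &rhs, &[(0, 0), (1, 1)], &[1.0, 2.0]);
    assert_eq!(out, vec![19.0, 100.0]);
}

Computing only nnz dot products instead of the full m x n product is the whole point of the op: the cost scales with the sparsity pattern, not the output shape.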
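Note on int_to_dense in PATCH 36: the COO-to-dense lowering flattens each column of the [ndims, nnz] coordinate matrix to a row-major linear index, scatters the values into a zeroed buffer of shape.num_elements() entries, and reshapes. A self-contained sketch of that index arithmetic with plain Vecs follows; the names are illustrative, flatten_coordinates and select_assign perform the tensorized equivalent, and duplicate coordinates accumulate here as in a scatter-add.

// Row-major strides for a shape, e.g. [2, 3, 4] -> [12, 4, 1].
fn strides(shape: &[usize]) -> Vec<usize> {
    let mut s = vec![1usize; shape.len()];
    for i in (0..shape.len().saturating_sub(1)).rev() {
        s[i] = s[i + 1] * shape[i + 1];
    }
    s
}

// `coords[d][n]` is the index along dimension `d` of the n-th stored value,
// mirroring the [ndims, nnz] coordinate matrix of the COO tensor.
fn coo_to_dense(shape: &[usize], coords: &[Vec<usize>], values: &[i64]) -> Vec<i64> {
    let strides = strides(shape);
    let mut dense = vec![0i64; shape.iter().product()];
    for (n, &v) in values.iter().enumerate() {
        // Flatten one coordinate column to a linear index; the scalar
        // analogue of `flatten_coordinates` followed by `select_assign`.
        let flat: usize = (0..shape.len()).map(|d| coords[d][n] * strides[d]).sum();
        dense[flat] += v; // duplicates accumulate, as in a scatter-add
    }
    dense
}

fn main() {
    // A 2x3 tensor holding 7 at (0, 1) and 9 at (1, 2).
    let dense = coo_to_dense(&[2, 3], &[vec![0, 1], vec![1, 2]], &[7, 9]);
    assert_eq!(dense, vec![0, 7, 0, 0, 0, 9]);
}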
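Note on the panics introduced in PATCH 37: an elementwise operation can keep a sparse result only if it is zero-preserving, that is, it maps a pair of implicit zeros to zero, since unstored positions must remain unstored. equal violates this because two implicit zeros compare equal, so its honest result would be true almost everywhere, a dense tensor in all but name. A tiny illustrative check of the rule (not part of the crate):

// Sketch of the rule behind these panics: an elementwise op can produce a
// sparse result only if it maps a pair of implicit zeros to zero.
// Illustrative helper, not part of the crate.
fn is_zero_preserving(op: impl Fn(f32, f32) -> f32) -> bool {
    op(0.0, 0.0) == 0.0
}

fn main() {
    // mul keeps implicit zeros unstored: 0 * 0 == 0.
    assert!(is_zero_preserving(|a, b| a * b));
    // equal would store a value at every unstored position: (0 == 0) is true.
    assert!(!is_zero_preserving(|a, b| (a == b) as u8 as f32));
}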
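Note on the accessors added in PATCH 38: coordinates() and values() return Option because the COO primitive represents an all-zero tensor by storing no buffers at all, as the let-else in int_to_dense already relies on. A simplified stand-in making that contract explicit; this struct is illustrative, not the crate's definition:

// Why the accessors are Option: an all-zero COO tensor stores neither a
// coordinate matrix nor a values buffer, so there is nothing to hand back.
struct CooSketch {
    coordinates: Option<Vec<Vec<usize>>>, // [ndims][nnz]
    values: Option<Vec<f32>>,             // [nnz]
}

fn main() {
    let empty = CooSketch { coordinates: None, values: None };
    assert!(empty.coordinates.is_none() && empty.values.is_none());
}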