diff --git a/crates/burn-autodiff/src/backend.rs b/crates/burn-autodiff/src/backend.rs index e1f41e77a2..78bfe301f5 100644 --- a/crates/burn-autodiff/src/backend.rs +++ b/crates/burn-autodiff/src/backend.rs @@ -30,6 +30,7 @@ impl Backend for Autodiff { type IntElem = B::IntElem; type BoolTensorPrimitive = B::BoolTensorPrimitive; + type BoolElem = B::BoolElem; type QuantizedTensorPrimitive = B::QuantizedTensorPrimitive; type QuantizedEncoding = B::QuantizedEncoding; diff --git a/crates/burn-autodiff/src/tests/mod.rs b/crates/burn-autodiff/src/tests/mod.rs index 1980550157..438adfb98e 100644 --- a/crates/burn-autodiff/src/tests/mod.rs +++ b/crates/burn-autodiff/src/tests/mod.rs @@ -90,18 +90,18 @@ macro_rules! testgen_all { pub type FloatType = ::FloatElem; pub type IntType = ::IntElem; - pub type BoolType = ::BoolTensorPrimitive; + pub type BoolType = ::BoolElem; ::paste::paste! { $(mod [<$float _ty>] { pub use super::*; - pub type TestBackend = TestBackend2<$float, IntType>; + pub type TestBackend = TestBackend2<$float, IntType, BoolType>; pub type TestAutodiffBackend = burn_autodiff::Autodiff; pub type TestAutodiffTensor = burn_tensor::Tensor; - pub type TestTensor = TestTensor2<$float, IntType, D>; - pub type TestTensorInt = TestTensorInt2<$float, IntType, D>; - pub type TestTensorBool = TestTensorBool2<$float, IntType, D>; + pub type TestTensor = TestTensor2<$float, IntType, BoolType, D>; + pub type TestTensorInt = TestTensorInt2<$float, IntType, BoolType, D>; + pub type TestTensorBool = TestTensorBool2<$float, IntType, BoolType, D>; type FloatType = $float; diff --git a/crates/burn-candle/src/backend.rs b/crates/burn-candle/src/backend.rs index e03b26474c..1ad606b910 100644 --- a/crates/burn-candle/src/backend.rs +++ b/crates/burn-candle/src/backend.rs @@ -168,6 +168,7 @@ impl Backend for Candle { type IntElem = I; type BoolTensorPrimitive = CandleTensor; + type BoolElem = u32; type QuantizedTensorPrimitive = CandleQTensor; type QuantizedEncoding = u8; diff --git a/crates/burn-cuda/src/lib.rs b/crates/burn-cuda/src/lib.rs index 086d00bab7..030c2d9ff1 100644 --- a/crates/burn-cuda/src/lib.rs +++ b/crates/burn-cuda/src/lib.rs @@ -7,10 +7,10 @@ pub use cubecl::cuda::CudaDevice; use cubecl::cuda::CudaRuntime; #[cfg(not(feature = "fusion"))] -pub type Cuda = JitBackend; +pub type Cuda = JitBackend; #[cfg(feature = "fusion")] -pub type Cuda = burn_fusion::Fusion>; +pub type Cuda = burn_fusion::Fusion>; #[cfg(test)] mod tests { @@ -19,5 +19,5 @@ mod tests { pub type TestRuntime = cubecl::cuda::CudaRuntime; pub use half::{bf16, f16}; - burn_jit::testgen_all!([f16, bf16, f32], [i8, i16, i32, i64]); + burn_jit::testgen_all!([f16, bf16, f32], [i8, i16, i32, i64], [u8, u32]); } diff --git a/crates/burn-fusion/src/backend.rs b/crates/burn-fusion/src/backend.rs index aa72ba7dbc..aa308ad9a7 100644 --- a/crates/burn-fusion/src/backend.rs +++ b/crates/burn-fusion/src/backend.rs @@ -5,7 +5,7 @@ use burn_tensor::{ backend::{Backend, DeviceOps}, ops::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor}, repr::{OperationDescription, QuantizedKind, ReprBackend, TensorHandle}, - Device, + Device, Element, }; use serde::{de::DeserializeOwned, Serialize}; use std::marker::PhantomData; @@ -35,6 +35,8 @@ impl Backend for Fusion { type BoolTensorPrimitive = FusionTensor; + type BoolElem = B::BoolElem; + type QuantizedTensorPrimitive = QFusionTensor; type QuantizedEncoding = B::QuantizedEncoding; @@ -142,6 +144,8 @@ pub trait FusionRuntime: Send + Sync + Sized + core::fmt::Debug { type FusionDevice: DeviceOps; 
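The hunks above thread a dedicated boolean storage element through the backend definitions: `Backend` gains a `BoolElem` associated type, Candle pins it to `u32`, and the JIT/CUDA backends take it as an extra generic parameter (with the test macro now exercising both `u8` and `u32`). As a rough, standalone sketch of the pattern — the trait, types, and helper below are illustrative stand-ins, not Burn's actual API — exposing the bool element as an associated type lets generic code reason about bool storage without knowing the concrete backend:

    // Minimal model of a backend that advertises its bool storage element.
    trait MiniBackend {
        /// Element used to encode booleans in device memory (e.g. u32 or u8).
        type BoolElem: Copy + Default + 'static;
    }

    struct CandleLike;
    impl MiniBackend for CandleLike {
        // Mirrors `type BoolElem = u32;` chosen for the Candle backend above.
        type BoolElem = u32;
    }

    struct JitLike<BT>(core::marker::PhantomData<BT>);
    impl<BT: Copy + Default + 'static> MiniBackend for JitLike<BT> {
        // Mirrors the new bool-element generic parameter on the JIT backend.
        type BoolElem = BT;
    }

    /// Bytes needed to store a bool tensor of `len` elements on backend `B`.
    fn bool_buffer_size<B: MiniBackend>(len: usize) -> usize {
        len * core::mem::size_of::<B::BoolElem>()
    }

    fn main() {
        assert_eq!(bool_buffer_size::<CandleLike>(8), 32); // u32 bools: 4 bytes each
        assert_eq!(bool_buffer_size::<JitLike<u8>>(8), 8); // u8 bools: 1 byte each
    }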
/// The client to interact with the runtime. type FusionClient: FusionClient; + /// The type that represents booleans on the backend. + type BoolRepr: Element; /// The list of optimizations that will be used to optimize the computational graph. fn optimizations( diff --git a/crates/burn-hip/src/lib.rs b/crates/burn-hip/src/lib.rs index 89b91243c6..fc8f704e74 100644 --- a/crates/burn-hip/src/lib.rs +++ b/crates/burn-hip/src/lib.rs @@ -12,11 +12,11 @@ use cubecl::hip::HipRuntime; #[cfg(target_os = "linux")] #[cfg(not(feature = "fusion"))] -pub type Hip = JitBackend; +pub type Hip = JitBackend; #[cfg(target_os = "linux")] #[cfg(feature = "fusion")] -pub type Hip = burn_fusion::Fusion>; +pub type Hip = burn_fusion::Fusion>; // TODO: Hang the computer when AMD isn't available. // diff --git a/crates/burn-jit/src/backend.rs b/crates/burn-jit/src/backend.rs index 23629c4a9e..b455d859a0 100644 --- a/crates/burn-jit/src/backend.rs +++ b/crates/burn-jit/src/backend.rs @@ -1,4 +1,5 @@ use crate::{ + element::BoolElement, tensor::{JitTensor, QJitTensor}, FloatElement, IntElement, JitRuntime, }; @@ -18,24 +19,27 @@ pub(crate) static SEED: Mutex> = Mutex::new(None); /// Generic tensor backend that can be compiled just-in-time to any shader runtime #[derive(new)] -pub struct JitBackend { +pub struct JitBackend { _runtime: PhantomData, _float_elem: PhantomData, _int_elem: PhantomData, + _bool_elem: PhantomData, } -impl Backend for JitBackend +impl Backend for JitBackend where R: JitRuntime, R::Server: ComputeServer, R::Device: burn_tensor::backend::DeviceOps, F: FloatElement, I: IntElement, + BT: BoolElement, { type Device = R::Device; type FloatElem = F; type IntElem = I; + type BoolElem = BT; type FloatTensorPrimitive = JitTensor; type IntTensorPrimitive = JitTensor; @@ -63,19 +67,25 @@ where } } -impl core::fmt::Debug for JitBackend { +impl core::fmt::Debug + for JitBackend +{ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.write_fmt(format_args!("JitBackend {{ runtime: {}}}", R::name())) } } -impl Clone for JitBackend { +impl Clone + for JitBackend +{ fn clone(&self) -> Self { Self::new() } } -impl Default for JitBackend { +impl Default + for JitBackend +{ fn default() -> Self { Self::new() } @@ -90,7 +100,9 @@ where } #[cfg(not(feature = "fusion"))] -impl ReprBackend for JitBackend { +impl ReprBackend + for JitBackend +{ type Handle = HandleKind; fn float_tensor(handle: TensorHandle) -> FloatTensor { diff --git a/crates/burn-jit/src/element.rs b/crates/burn-jit/src/element.rs index 939b2fb24e..f0e15352cf 100644 --- a/crates/burn-jit/src/element.rs +++ b/crates/burn-jit/src/element.rs @@ -13,6 +13,27 @@ pub trait FloatElement: JitElement + Float {} /// The int element type for the jit backend. pub trait IntElement: JitElement + Int {} +/// The element type for booleans for the jit backend. +pub trait BoolElement: JitElement + Int { + /// The true value for the boolean element. + fn true_val() -> Self { + Self::from_int(1) + } + + /// The false value for the boolean element. + fn false_val() -> Self { + Self::from_int(0) + } + + /// New bool element from Rust bool. 
+ fn new_bool(val: bool) -> Self { + match val { + true => Self::true_val(), + false => Self::false_val(), + } + } +} + impl JitElement for u64 {} impl JitElement for u32 {} impl JitElement for u16 {} @@ -36,3 +57,6 @@ impl IntElement for i64 {} impl IntElement for i32 {} impl IntElement for i16 {} impl IntElement for i8 {} + +impl BoolElement for u8 {} +impl BoolElement for u32 {} diff --git a/crates/burn-jit/src/fusion/base.rs b/crates/burn-jit/src/fusion/base.rs index 7968626e89..4572f580b5 100644 --- a/crates/burn-jit/src/fusion/base.rs +++ b/crates/burn-jit/src/fusion/base.rs @@ -1,6 +1,6 @@ use super::elemwise::optimization::{ElemwiseOptimization, ElemwiseOptimizationState}; -use crate::fusion::elemwise::builder::ElementWiseBuilder; use crate::tensor::{JitQuantizationParameters, QJitTensor}; +use crate::{element::BoolElement, fusion::elemwise::builder::ElementWiseBuilder}; use crate::{kernel, tensor::JitTensor, FloatElement, IntElement, JitBackend, JitRuntime}; use burn_fusion::{client::MutexFusionClient, FusionBackend, FusionRuntime}; use burn_tensor::quantization::QuantizationScheme; @@ -30,13 +30,14 @@ pub enum JitOptimizationState { ElementWise(ElemwiseOptimizationState), } -impl burn_fusion::Optimization> for JitOptimization +impl burn_fusion::Optimization> for JitOptimization where R: JitRuntime, + BT: BoolElement, { fn execute(&mut self, context: &mut burn_fusion::stream::Context<'_, JitFusionHandle>) { match self { - Self::ElementWise2(op) => op.execute(context), + Self::ElementWise2(op) => op.execute::(context), } } @@ -61,7 +62,9 @@ where } } -impl ReprBackend for JitBackend { +impl ReprBackend + for JitBackend +{ type Handle = JitFusionHandle; fn float_tensor(handle: TensorHandle) -> burn_tensor::ops::FloatTensor { @@ -122,30 +125,37 @@ impl ReprBackend for JitBackend FusionRuntime for FusionJitRuntime { +impl FusionRuntime for FusionJitRuntime { type OptimizationState = JitOptimizationState; type Optimization = JitOptimization; type FusionHandle = JitFusionHandle; type FusionDevice = R::JitDevice; type FusionClient = MutexFusionClient; + type BoolRepr = BT; fn optimizations( device: R::Device, ) -> Vec>> { - vec![Box::new(ElementWiseBuilder::::new(device.clone()))] + vec![Box::new(ElementWiseBuilder::::new( + device.clone(), + BT::as_elem().into(), + ))] } } /// Fusion runtime for JIT runtimes. 
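The `BoolElement` trait introduced above gives every admissible storage type (`u8`, `u32`) a uniform encoding of true/false, and later hunks rely on it (for example, `flip` passes `BT::new_bool(...)` as a scalar kernel argument). A small standalone sketch of the same default-method pattern, using `MiniBool` and `encode_mask` as hypothetical stand-ins rather than the real Burn trait:

    trait MiniBool: Copy + PartialEq + core::fmt::Debug {
        fn from_int(value: i64) -> Self;

        /// Encoding of `true` (kernels treat any non-zero value as true).
        fn true_val() -> Self {
            Self::from_int(1)
        }

        /// Encoding of `false`.
        fn false_val() -> Self {
            Self::from_int(0)
        }

        /// Convert a Rust `bool` into the storage encoding.
        fn new_bool(value: bool) -> Self {
            if value { Self::true_val() } else { Self::false_val() }
        }
    }

    impl MiniBool for u8 {
        fn from_int(value: i64) -> Self {
            value as u8
        }
    }

    impl MiniBool for u32 {
        fn from_int(value: i64) -> Self {
            value as u32
        }
    }

    /// Encode a host-side boolean mask with whichever element the backend uses.
    fn encode_mask<B: MiniBool>(mask: &[bool]) -> Vec<B> {
        mask.iter().map(|&b| B::new_bool(b)).collect()
    }

    fn main() {
        assert_eq!(encode_mask::<u8>(&[true, false]), vec![1u8, 0]);
        assert_eq!(encode_mask::<u32>(&[false, true]), vec![0u32, 1]);
    }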
#[derive(Debug)] -pub struct FusionJitRuntime { +pub struct FusionJitRuntime { _b: PhantomData, + _bool: PhantomData, } -impl FusionBackend for JitBackend { - type FusionRuntime = FusionJitRuntime; +impl FusionBackend + for JitBackend +{ + type FusionRuntime = FusionJitRuntime; - type FullPrecisionBackend = JitBackend; + type FullPrecisionBackend = JitBackend; fn cast_float( tensor: burn_tensor::ops::FloatTensor, diff --git a/crates/burn-jit/src/fusion/elemwise/builder.rs b/crates/burn-jit/src/fusion/elemwise/builder.rs index 6766e3000a..e37196bc2a 100644 --- a/crates/burn-jit/src/fusion/elemwise/builder.rs +++ b/crates/burn-jit/src/fusion/elemwise/builder.rs @@ -1,7 +1,10 @@ use burn_fusion::OptimizationBuilder; use crate::{ - fusion::{on_write::builder::FuseOnWriteBuilder, JitOptimization}, + fusion::{ + on_write::{builder::FuseOnWriteBuilder, ir::ElemwisePrecision}, + JitOptimization, + }, JitRuntime, }; @@ -14,13 +17,13 @@ pub(crate) struct ElementWiseBuilder { } impl ElementWiseBuilder { - pub fn new(device: R::Device) -> Self { + pub fn new(device: R::Device, bool_precision: ElemwisePrecision) -> Self { let client = R::client(&device); let props = client.properties(); let max_bindings = props.hardware_properties().max_bindings; Self { - builder: FuseOnWriteBuilder::new(max_bindings), + builder: FuseOnWriteBuilder::new(max_bindings, bool_precision), device, } } diff --git a/crates/burn-jit/src/fusion/elemwise/optimization.rs b/crates/burn-jit/src/fusion/elemwise/optimization.rs index f5f3000926..d3e8e35b50 100644 --- a/crates/burn-jit/src/fusion/elemwise/optimization.rs +++ b/crates/burn-jit/src/fusion/elemwise/optimization.rs @@ -1,4 +1,4 @@ -use crate::fusion::on_write::kernel::fuse_on_write; +use crate::{fusion::on_write::kernel::fuse_on_write, BoolElement}; use crate::{fusion::JitFusionHandle, JitRuntime}; use burn_fusion::stream::Context; use burn_tensor::repr::TensorDescription; @@ -28,9 +28,9 @@ pub struct ElemwiseOptimizationState { impl ElemwiseOptimization { /// Execute the optimization. - pub fn execute(&mut self, context: &mut Context<'_, JitFusionHandle>) { + pub fn execute(&mut self, context: &mut Context<'_, JitFusionHandle>) { self.trace - .run::(&self.client, &self.device, context) + .run::(&self.client, &self.device, context) } /// Number of element wise operations fused. 
diff --git a/crates/burn-jit/src/fusion/on_write/builder.rs b/crates/burn-jit/src/fusion/on_write/builder.rs index 1bd167af90..287b656274 100644 --- a/crates/burn-jit/src/fusion/on_write/builder.rs +++ b/crates/burn-jit/src/fusion/on_write/builder.rs @@ -1,5 +1,5 @@ use super::{ - ir::{Arg, BinaryElemwiseArgs, ElemwiseOp, UnaryElemwiseArgs}, + ir::{Arg, BinaryElemwiseArgs, ElemwiseOp, ElemwisePrecision, UnaryElemwiseArgs}, trace::FuseOnWriteTrace, trace_builder::FuseOnWriteTraceBuilder, }; @@ -30,9 +30,9 @@ struct TryFuseBuilder { } impl TryFuseBuilder { - fn new(max_bindings: u32) -> Self { + fn new(max_bindings: u32, bool_precision: ElemwisePrecision) -> Self { Self { - builder: FuseOnWriteTraceBuilder::new(), + builder: FuseOnWriteTraceBuilder::new(bool_precision), max_bindings, added_ops: false, } @@ -118,7 +118,7 @@ impl OptimizationBuilder for FuseOnWriteBuilder { fn reset(&mut self) { self.num_ops = 0; self.status = OptimizationStatus::Open; - self.builder = TryFuseBuilder::new(self.max_bindings); + self.builder = TryFuseBuilder::new(self.max_bindings, self.builder.builder.bool_precision); self.current_output_shape.clear(); } @@ -137,9 +137,9 @@ impl OptimizationBuilder for FuseOnWriteBuilder { } impl FuseOnWriteBuilder { - pub fn new(max_bindings: u32) -> Self { + pub fn new(max_bindings: u32, bool_precision: ElemwisePrecision) -> Self { Self { - builder: TryFuseBuilder::new(max_bindings), + builder: TryFuseBuilder::new(max_bindings, bool_precision), num_ops: 0, max_bindings, current_output_shape: Vec::new(), diff --git a/crates/burn-jit/src/fusion/on_write/trace.rs b/crates/burn-jit/src/fusion/on_write/trace.rs index 591cc9c347..d9ec09aea8 100644 --- a/crates/burn-jit/src/fusion/on_write/trace.rs +++ b/crates/burn-jit/src/fusion/on_write/trace.rs @@ -1,6 +1,6 @@ use crate::{ fusion::{on_write::ir::LayoutInfo, strides_dyn_rank, JitFusionHandle}, - JitRuntime, + BoolElement, JitRuntime, }; use super::ir::{Arg, ElemwiseConfig, ElemwiseOp, ElemwisePrecision, GlobalArgsLaunch}; @@ -90,16 +90,17 @@ struct PotentialInplace<'a> { impl FuseOnWriteTrace { /// Run a trace with the given [runner](TraceRunner). 
- pub fn run>( + pub fn run>( &self, client: &ComputeClient, device: &R::Device, context: &mut Context<'_, JitFusionHandle>, ) { - let analysis = self.analyse::(client, device, context); + let analysis = self.analyse::(client, device, context); let inputs = self.register_inputs(context, &analysis.handle_inputs, analysis.vectorization); - let outputs = self.register_outputs(&analysis.handle_outputs, analysis.vectorization); + let outputs = + self.register_outputs::<_, BT>(&analysis.handle_outputs, analysis.vectorization); let mut ops = Sequence::new(); for op in analysis.reads.into_values() { @@ -126,7 +127,7 @@ impl FuseOnWriteTrace { Runner::run(client, inputs, outputs, config) } - fn analyse<'a, 'c, R: JitRuntime, Runner: TraceRunner>( + fn analyse<'a, 'c, R: JitRuntime, BT: BoolElement, Runner: TraceRunner>( &'a self, client: &ComputeClient, device: &R::Device, @@ -146,7 +147,7 @@ impl FuseOnWriteTrace { }; self.analyse_inputs(context, &mut analysis); - self.analyse_outputs(client, device, context, &mut analysis); + self.analyse_outputs::<_, BT>(client, device, context, &mut analysis); analysis.vectorization = Runner::vectorization( analysis.handle_inputs.iter().map(|item| &item.handle), @@ -189,7 +190,7 @@ impl FuseOnWriteTrace { } } - fn analyse_outputs<'a, 'c, R: JitRuntime>( + fn analyse_outputs<'a, 'c, R: JitRuntime, BT: BoolElement>( &'a self, client: &ComputeClient, device: &R::Device, @@ -273,9 +274,9 @@ impl FuseOnWriteTrace { } } - // We encode bool tensors as u32. + // We encode bool tensors as `B`. let dtype = match tensor_global.dtype { - DType::Bool => DType::U32, + DType::Bool => BT::dtype(), _ => tensor_global.dtype, }; let size = tensor_global.shape.iter().product::() * Elem::from(dtype).size(); @@ -406,7 +407,7 @@ impl FuseOnWriteTrace { inputs } - fn register_outputs<'s, R: JitRuntime>( + fn register_outputs<'s, R: JitRuntime, BT: BoolElement>( &self, handle_outputs: &'s [HandleOutput<'_, R>], vectorization: u8, @@ -473,8 +474,11 @@ impl FuseOnWriteTrace { ElemwisePrecision::U32 => outputs.t_u32.push(arg), ElemwisePrecision::U16 => outputs.t_u16.push(arg), ElemwisePrecision::U8 => outputs.t_u8.push(arg), - // Bools are encoded as u32. - ElemwisePrecision::Bool => outputs.t_u32.push(arg), + ElemwisePrecision::Bool => match BT::dtype() { + DType::U32 => outputs.t_u32.push(arg), + DType::U8 => outputs.t_u8.push(arg), + _ => todo!(), + }, }; } } diff --git a/crates/burn-jit/src/fusion/on_write/trace_builder.rs b/crates/burn-jit/src/fusion/on_write/trace_builder.rs index 06e8d24e15..5cb427814d 100644 --- a/crates/burn-jit/src/fusion/on_write/trace_builder.rs +++ b/crates/burn-jit/src/fusion/on_write/trace_builder.rs @@ -16,10 +16,11 @@ pub struct FuseOnWriteTraceBuilder { scalars: BTreeMap, ops: Vec, reads: BTreeMap, + pub bool_precision: ElemwisePrecision, } impl FuseOnWriteTraceBuilder { - pub fn new() -> Self { + pub fn new(bool_precision: ElemwisePrecision) -> Self { Self { locals: Locals::default(), outputs: RegisteredTensors::default(), @@ -27,6 +28,7 @@ impl FuseOnWriteTraceBuilder { scalars: BTreeMap::default(), ops: Vec::new(), reads: BTreeMap::new(), + bool_precision, } } @@ -49,9 +51,9 @@ impl FuseOnWriteTraceBuilder { pub fn input(&mut self, tensor: &TensorDescription) -> Arg { let precision = tensor.dtype.into(); - // Bool tensors are encoded as u32. + // Bool tensors are encoded as bool_precision. 
let precision_input = match precision { - ElemwisePrecision::Bool => ElemwisePrecision::U32, + ElemwisePrecision::Bool => self.bool_precision, _ => precision, }; @@ -82,9 +84,9 @@ impl FuseOnWriteTraceBuilder { pub fn output(&mut self, tensor: &TensorDescription) -> Arg { let precision = tensor.dtype.into(); - // Bool tensors are encoded as u32. + // Bool tensors are encoded as bool_precision. let precision_output = match precision { - ElemwisePrecision::Bool => ElemwisePrecision::U32, + ElemwisePrecision::Bool => self.bool_precision, _ => precision, }; @@ -103,9 +105,9 @@ impl FuseOnWriteTraceBuilder { pub fn scalar(&mut self, _: &E, dtype: DType) -> Arg { let precision = dtype.into(); - // Bool scalars are encoded as u32. + // Bool scalars are encoded as bool_precision. let precision = match precision { - ElemwisePrecision::Bool => ElemwisePrecision::U32, + ElemwisePrecision::Bool => self.bool_precision, _ => precision, }; let new_index = self.scalars.get(&precision).copied().unwrap_or(0); @@ -154,9 +156,9 @@ impl FuseOnWriteTraceBuilder { let mark = |var: &Arg, list: &mut Vec<(TensorId, ElemwisePrecision)>| { if let Arg::Local(index, precision) = var { if let Some(tensor_id) = self.locals.find_tensor_id(*precision, *index) { - // Input and outputs tensors are using u32 for booleans. + // Input and outputs tensors are using bool_precision for booleans. let precision = match precision { - ElemwisePrecision::Bool => ElemwisePrecision::U32, + ElemwisePrecision::Bool => self.bool_precision, _ => *precision, }; diff --git a/crates/burn-jit/src/kernel/cast/bool_cast.rs b/crates/burn-jit/src/kernel/cast/bool_cast.rs index 07a915ee1f..74e55888e1 100644 --- a/crates/burn-jit/src/kernel/cast/bool_cast.rs +++ b/crates/burn-jit/src/kernel/cast/bool_cast.rs @@ -1,9 +1,9 @@ -use crate::{tensor::JitTensor, JitElement, JitRuntime}; +use crate::{tensor::JitTensor, BoolElement, JitElement, JitRuntime}; use cubecl::{calculate_cube_count_elemwise, prelude::*, CubeDim}; #[cube(launch)] -fn bool_cast_kernel(input: &Tensor, output: &mut Tensor) { - if input[ABSOLUTE_POS] >= 1 { +fn bool_cast_kernel(input: &Tensor, output: &mut Tensor) { + if input[ABSOLUTE_POS] >= B::from_int(1) { output[ABSOLUTE_POS] = T::from_int(1); } else { output[ABSOLUTE_POS] = T::from_int(0); @@ -12,11 +12,13 @@ fn bool_cast_kernel(input: &Tensor, output: &mut Tensor) { /// Cast a bool tensor to the given element type. /// -/// This alternative to cast is necessary because bool are represented as u32 +/// This alternative to cast is necessary because bool are represented as u32 or u8 /// where any non-zero value means true. Depending how it was created /// it may hold an uncanny bit combination. Naively casting it would not /// necessarily yield 0 or 1. 
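Because a stored "true" may be any non-zero bit pattern, the cast has to test each element rather than reinterpret its bits. A host-side sketch of the per-element normalization the kernel performs, written as plain Rust over a `u8` buffer purely for illustration:

    /// Any non-zero stored value becomes 1, zero stays 0, mirroring the
    /// non-zero test in the bool cast kernel above.
    fn normalize_bools(stored: &[u8]) -> Vec<u8> {
        stored.iter().map(|&v| u8::from(v != 0)).collect()
    }

    fn main() {
        // 0xFF is a valid "true" in storage, but naively casting it to another
        // element type would not yield 0 or 1.
        let raw = [0u8, 1, 0xFF, 0];
        assert_eq!(normalize_bools(&raw), vec![0, 1, 1, 0]);
    }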
-pub fn bool_cast(tensor: JitTensor) -> JitTensor { +pub fn bool_cast( + tensor: JitTensor, +) -> JitTensor { let num_elems = tensor.shape.num_elements(); let buffer = tensor.client.empty(num_elems * core::mem::size_of::()); let output = JitTensor::new_contiguous( @@ -30,11 +32,11 @@ pub fn bool_cast(tensor: JitTensor) -> JitTens let cube_dim = CubeDim::default(); let cube_count = calculate_cube_count_elemwise(num_elems, cube_dim); - bool_cast_kernel::launch::( + bool_cast_kernel::launch::( &tensor.client, cube_count, cube_dim, - tensor.as_tensor_arg::(1), + tensor.as_tensor_arg::(1), output.as_tensor_arg::(1), ); diff --git a/crates/burn-jit/src/kernel/comparison.rs b/crates/burn-jit/src/kernel/comparison.rs index 420a74d81b..007d4200d9 100644 --- a/crates/burn-jit/src/kernel/comparison.rs +++ b/crates/burn-jit/src/kernel/comparison.rs @@ -1,5 +1,7 @@ -use crate::{element::JitElement, ops::numeric::empty_device, tensor::JitTensor, JitRuntime}; -use burn_tensor::{DType, Shape}; +use crate::{ + element::JitElement, ops::numeric::empty_device, tensor::JitTensor, BoolElement, JitRuntime, +}; +use burn_tensor::Shape; use cubecl::{ calculate_cube_count_elemwise, linalg::tensor::index_offset_with_layout, prelude::*, tensor_vectorization_factor, @@ -55,10 +57,10 @@ impl ComparisonOp for LowerOp { } #[cube(launch)] -pub(crate) fn kernel_scalar_cmp>( +pub(crate) fn kernel_scalar_cmp>( input: &Tensor>, scalar: C, - output: &mut Tensor>, + output: &mut Tensor>, ) { let offset_output = ABSOLUTE_POS; @@ -70,10 +72,10 @@ pub(crate) fn kernel_scalar_cmp>( } #[cube(launch)] -pub(crate) fn kernel_cmp>( +pub(crate) fn kernel_cmp>( lhs: &Tensor>, rhs: &Tensor>, - out: &mut Tensor>, + out: &mut Tensor>, #[comptime] rank: Option, #[comptime] to_contiguous_lhs: bool, #[comptime] to_contiguous_rhs: bool, @@ -87,7 +89,7 @@ pub(crate) fn kernel_cmp>( } if to_contiguous_lhs { - offset_lhs = index_offset_with_layout::( + offset_lhs = index_offset_with_layout::( lhs, out, offset_out, @@ -98,7 +100,7 @@ pub(crate) fn kernel_cmp>( } if to_contiguous_rhs { - offset_rhs = index_offset_with_layout::( + offset_rhs = index_offset_with_layout::( rhs, out, offset_out, @@ -111,7 +113,7 @@ pub(crate) fn kernel_cmp>( out[offset_out] = Line::cast_from(O::execute(lhs[offset_lhs], rhs[offset_rhs])); } -pub(crate) fn launch_cmp>( +pub(crate) fn launch_cmp>( lhs: JitTensor, rhs: JitTensor, ) -> JitTensor { @@ -141,9 +143,9 @@ pub(crate) fn launch_cmp>( let cube_count = calculate_cube_count_elemwise(num_elems / vectorization_factor as usize, cube_dim); - let same_tensor_type = core::any::TypeId::of::() == core::any::TypeId::of::(); + let same_tensor_type = core::any::TypeId::of::() == core::any::TypeId::of::(); if same_tensor_type && lhs.can_mut_broadcast(&rhs) { - kernel_cmp::launch::( + kernel_cmp::launch::( &client, cube_count, cube_dim, @@ -161,10 +163,10 @@ pub(crate) fn launch_cmp>( lhs.shape, lhs.device, lhs.strides, - DType::U32, + BT::dtype(), ) } else if same_tensor_type && rhs.can_mut_broadcast(&lhs) { - kernel_cmp::launch::( + kernel_cmp::launch::( &client, cube_count, CubeDim::default(), @@ -182,20 +184,20 @@ pub(crate) fn launch_cmp>( rhs.shape, rhs.device, rhs.strides, - DType::U32, + BT::dtype(), ) } else { - let output = empty_device::(lhs.client.clone(), lhs.device.clone(), shape_out); + let output = empty_device::(lhs.client.clone(), lhs.device.clone(), shape_out); let to_contiguous_lhs = lhs.strides != output.strides || lhs.shape != output.shape; let to_contiguous_rhs = rhs.strides != output.strides || rhs.shape != 
output.shape; - kernel_cmp::launch::( + kernel_cmp::launch::( &client, cube_count, CubeDim::default(), lhs.as_tensor_arg::(vectorization_factor), rhs.as_tensor_arg::(vectorization_factor), - output.as_tensor_arg::(vectorization_factor), + output.as_tensor_arg::(vectorization_factor), None, to_contiguous_lhs, to_contiguous_rhs, @@ -205,7 +207,12 @@ pub(crate) fn launch_cmp>( } } -pub(crate) fn launch_scalar_cmp>( +pub(crate) fn launch_scalar_cmp< + R: JitRuntime, + E: JitElement, + BT: BoolElement, + O: ComparisonOp, +>( mut tensor: JitTensor, scalar: E, ) -> JitTensor { @@ -224,9 +231,9 @@ pub(crate) fn launch_scalar_cmp let cube_count = calculate_cube_count_elemwise(num_elems / vectorization_factor as usize, cube_dim); - let same_tensor_type = core::any::TypeId::of::() == core::any::TypeId::of::(); + let same_tensor_type = core::any::TypeId::of::() == core::any::TypeId::of::(); if same_tensor_type && tensor.can_mut() { - kernel_scalar_cmp::launch::( + kernel_scalar_cmp::launch::( &client, cube_count, cube_dim, @@ -241,70 +248,94 @@ pub(crate) fn launch_scalar_cmp tensor.shape, tensor.device, tensor.strides, - DType::U32, + BT::dtype(), ) } else { - let output = empty_device::( + let output = empty_device::( tensor.client.clone(), tensor.device.clone(), tensor.shape.clone(), ); - kernel_scalar_cmp::launch::( + kernel_scalar_cmp::launch::( &client, cube_count, CubeDim::default(), tensor.as_tensor_arg::(vectorization_factor), ScalarArg::new(scalar), - output.as_tensor_arg::(vectorization_factor), + output.as_tensor_arg::(vectorization_factor), ); output } } -pub fn equal(lhs: JitTensor, rhs: JitTensor) -> JitTensor { - launch_cmp::(lhs, rhs) +pub fn equal( + lhs: JitTensor, + rhs: JitTensor, +) -> JitTensor { + launch_cmp::(lhs, rhs) } -pub fn greater(lhs: JitTensor, rhs: JitTensor) -> JitTensor { - launch_cmp::(lhs, rhs) +pub fn greater( + lhs: JitTensor, + rhs: JitTensor, +) -> JitTensor { + launch_cmp::(lhs, rhs) } -pub fn greater_equal( +pub fn greater_equal( lhs: JitTensor, rhs: JitTensor, ) -> JitTensor { - launch_cmp::(lhs, rhs) + launch_cmp::(lhs, rhs) } -pub fn lower(lhs: JitTensor, rhs: JitTensor) -> JitTensor { - launch_cmp::(lhs, rhs) +pub fn lower( + lhs: JitTensor, + rhs: JitTensor, +) -> JitTensor { + launch_cmp::(lhs, rhs) } -pub fn lower_equal( +pub fn lower_equal( lhs: JitTensor, rhs: JitTensor, ) -> JitTensor { - launch_cmp::(lhs, rhs) + launch_cmp::(lhs, rhs) } -pub fn equal_elem(lhs: JitTensor, rhs: E) -> JitTensor { - launch_scalar_cmp::(lhs, rhs) +pub fn equal_elem( + lhs: JitTensor, + rhs: E, +) -> JitTensor { + launch_scalar_cmp::(lhs, rhs) } -pub fn greater_elem(lhs: JitTensor, rhs: E) -> JitTensor { - launch_scalar_cmp::(lhs, rhs) +pub fn greater_elem( + lhs: JitTensor, + rhs: E, +) -> JitTensor { + launch_scalar_cmp::(lhs, rhs) } -pub fn lower_elem(lhs: JitTensor, rhs: E) -> JitTensor { - launch_scalar_cmp::(lhs, rhs) +pub fn lower_elem( + lhs: JitTensor, + rhs: E, +) -> JitTensor { + launch_scalar_cmp::(lhs, rhs) } -pub fn greater_equal_elem(lhs: JitTensor, rhs: E) -> JitTensor { - launch_scalar_cmp::(lhs, rhs) +pub fn greater_equal_elem( + lhs: JitTensor, + rhs: E, +) -> JitTensor { + launch_scalar_cmp::(lhs, rhs) } -pub fn lower_equal_elem(lhs: JitTensor, rhs: E) -> JitTensor { - launch_scalar_cmp::(lhs, rhs) +pub fn lower_equal_elem( + lhs: JitTensor, + rhs: E, +) -> JitTensor { + launch_scalar_cmp::(lhs, rhs) } diff --git a/crates/burn-jit/src/kernel/conv/conv2d/base.rs b/crates/burn-jit/src/kernel/conv/conv2d/base.rs index 1796389157..9f07d36c55 100644 --- 
a/crates/burn-jit/src/kernel/conv/conv2d/base.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/base.rs @@ -69,7 +69,7 @@ impl Default for ConvTranspose2dStrategy { /// * `options` - The options to use for the convolution /// * `strategy` - The convolution algorithm to use. Autotune will pick the fastest available option. /// -pub fn conv2d( +pub fn conv2d( input: JitTensor, weight: JitTensor, bias: Option>, @@ -77,13 +77,11 @@ pub fn conv2d( strategy: Conv2dStrategy, ) -> JitTensor { match strategy { - Conv2dStrategy::Direct => conv2d_direct::(input, weight, bias, options), + Conv2dStrategy::Direct => conv2d_direct::(input, weight, bias, options), #[cfg(feature = "autotune")] - Conv2dStrategy::Autotune => conv2d_autotune::(input, weight, bias, options), - Conv2dStrategy::Gemm => conv2d_im2col::(input, weight, bias, options), - Conv2dStrategy::ImplicitGemm => { - conv2d_implicit_gemm::(input, weight, bias, options) - } + Conv2dStrategy::Autotune => conv2d_autotune::(input, weight, bias, options), + Conv2dStrategy::Gemm => conv2d_im2col::(input, weight, bias, options), + Conv2dStrategy::ImplicitGemm => conv2d_implicit_gemm::(input, weight, bias, options), } } @@ -104,14 +102,14 @@ pub fn conv_transpose2d( ) -> JitTensor { match strategy { ConvTranspose2dStrategy::Direct => { - conv_transpose2d_direct::(input, weight, bias, options) + conv_transpose2d_direct::(input, weight, bias, options) } #[cfg(feature = "autotune")] ConvTranspose2dStrategy::Autotune => { - conv_transpose2d_autotune::(input, weight, bias, options) + conv_transpose2d_autotune::(input, weight, bias, options) } ConvTranspose2dStrategy::Gemm => { - conv_transpose2d_col2im::(input, weight, bias, options) + conv_transpose2d_col2im::(input, weight, bias, options) } } } diff --git a/crates/burn-jit/src/kernel/conv/conv2d/col2im.rs b/crates/burn-jit/src/kernel/conv/conv2d/col2im.rs index 846aa3d8dd..0659561805 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/col2im.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/col2im.rs @@ -1,14 +1,18 @@ use burn_tensor::{ - ops::{conv::calculate_conv_transpose_output_size, ConvTransposeOptions, FloatTensorOps as _}, + ops::{conv::calculate_conv_transpose_output_size, ConvTransposeOptions}, Shape, }; use cubecl::{calculate_cube_count_elemwise, prelude::*}; use crate::{ - kernel::into_contiguous, + kernel::{ + into_contiguous, + matmul::{matmul, MatmulStrategy}, + slice, + }, ops::{numeric::empty_device, reshape, swap_dims}, tensor::JitTensor, - FloatElement, IntElement, JitBackend, JitRuntime, + FloatElement, JitElement, JitRuntime, }; use super::batches_per_run; @@ -20,7 +24,7 @@ use super::batches_per_run; /// * `bias` - The bias added to each channel /// * `options` - The options to use for the convolution /// -pub fn conv_transpose2d_col2im( +pub fn conv_transpose2d_col2im( input: JitTensor, weight: JitTensor, bias: Option>, @@ -77,12 +81,12 @@ pub fn conv_transpose2d_col2im( let input_shape_run = Shape::new([batches_per_run, input_channels, input_h, input_w]); for run in 0..runs { - let input = JitBackend::::float_narrow(input.clone(), 0, run, 1); + let input = index::(input.clone(), run); let input = reshape(input, input_shape_run.clone()); let im_shape = Shape::new([batches_per_run, im_channels, im_h, im_w]); - let image_slice = JitBackend::::float_narrow(image.clone(), 0, run, 1); + let image_slice = index::(image.clone(), run); let image_slice = reshape(image_slice, im_shape); - execute::( + execute::( input, weight.clone(), bias.clone(), @@ -96,7 +100,7 @@ pub fn 
conv_transpose2d_col2im( } else { let im_shape = Shape::new([batches_per_run, im_channels, im_h, im_w]); let image = empty_device::(input.client.clone(), input.device.clone(), im_shape); - execute::( + execute::( input, weight, bias, @@ -109,8 +113,21 @@ pub fn conv_transpose2d_col2im( } } +pub(crate) fn index(tensor: JitTensor, i: usize) -> JitTensor { + #[allow(clippy::single_range_in_vec_init)] + let mut indices = vec![i..i + 1]; + for dim in tensor.shape.dims[1..].iter() { + indices.push(0..*dim); + } + let new_shape = Shape { + dims: tensor.shape.dims[1..].to_vec(), + }; + let tensor = slice::(tensor, &indices); + reshape(tensor, new_shape) +} + #[allow(clippy::too_many_arguments)] -fn execute( +fn execute( input: JitTensor, weight: JitTensor, bias: Option>, @@ -128,7 +145,7 @@ fn execute( let input_shape = Shape::new([groups, input_ch_per_group, col_shape_1]); let input = reshape(input, input_shape); - let columns = JitBackend::::float_matmul(weight, input); + let columns = matmul::(weight, input, MatmulStrategy::default()); let columns = reshape(columns, Shape::new([col_shape_0 * groups, col_shape_1])); col2im::( diff --git a/crates/burn-jit/src/kernel/conv/conv2d/direct.rs b/crates/burn-jit/src/kernel/conv/conv2d/direct.rs index 9a65b6ae51..d5154ecc4b 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/direct.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/direct.rs @@ -11,7 +11,7 @@ use crate::{ reshape, }, tensor::JitTensor, - FloatElement, IntElement, JitRuntime, + FloatElement, JitRuntime, }; #[derive(CubeLaunch)] @@ -120,8 +120,7 @@ fn direct_conv2d_kernel( /// * `bias` - The bias added to each channel /// * `options` - The options to use for the convolution /// -#[allow(clippy::extra_unused_type_parameters)] -pub fn conv2d_direct( +pub fn conv2d_direct( input: JitTensor, weight: JitTensor, bias: Option>, diff --git a/crates/burn-jit/src/kernel/conv/conv2d/im2col.rs b/crates/burn-jit/src/kernel/conv/conv2d/im2col.rs index 88125f0463..abcb8488fb 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/im2col.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/im2col.rs @@ -1,14 +1,16 @@ use burn_tensor::{ - ops::{conv::calculate_conv_output_size, ConvOptions, FloatTensorOps as _}, + ops::{conv::calculate_conv_output_size, ConvOptions}, Shape, }; use cubecl::{calculate_cube_count_elemwise, linalg::matmul, prelude::*}; use crate::{ - kernel::into_contiguous, + kernel::{ + conv::index, into_contiguous, launch_binop, matmul::matmul, matmul::MatmulStrategy, AddOp, + }, ops::{numeric::empty_device, reshape, swap_dims}, tensor::JitTensor, - FloatElement, IntElement, JitBackend, JitRuntime, + FloatElement, JitRuntime, }; #[derive(CubeLaunch)] @@ -178,7 +180,7 @@ fn im2col( /// * `bias` - The bias added to each channel /// * `options` - The options to use for the convolution /// -pub fn conv2d_im2col( +pub fn conv2d_im2col( input: JitTensor, weight: JitTensor, bias: Option>, @@ -206,7 +208,7 @@ pub fn conv2d_im2col( if kernel_h == 1 && kernel_w == 1 && in_height == out_h && in_width == out_w { // Special case for 1x1 kernels (sometimes used to scale the image by a set of weights) - return execute_1x1_kernel::(input, weight, bias, options); + return execute_1x1_kernel::(input, weight, bias, options); } let batches_per_run = batches_per_run(batch_size, out_h, out_w) @@ -221,9 +223,9 @@ pub fn conv2d_im2col( let input = reshape(input, in_shape); let in_shape_run = Shape::new([batches_per_run, in_channels, in_height, in_width]); for run in 0..runs { - let input = 
JitBackend::::float_narrow(input.clone(), 0, run, 1); + let input = index::(input.clone(), run); let input = reshape(input, in_shape_run.clone()); - let out_slice = JitBackend::::float_narrow(out.clone(), 0, run, 1); + let out_slice = index::(out.clone(), run); let out_slice = reshape(out_slice, matmul_shape.clone()); execute::( input, @@ -245,12 +247,12 @@ pub fn conv2d_im2col( if let Some(bias) = bias { let bias = reshape(bias, Shape::new([1, out_channels, 1, 1])); - out = JitBackend::::float_add(out, bias) + out = launch_binop::(out, bias) } out } -fn execute_1x1_kernel( +fn execute_1x1_kernel( input: JitTensor, weight: JitTensor, bias: Option>, @@ -266,12 +268,12 @@ fn execute_1x1_kernel( let weight = reshape(weight, Shape::new([groups, out_c_per_grp, in_c_per_grp])); let in_shape = Shape::new([groups, in_c_per_grp, batch_size * height * width]); let input = reshape(input, in_shape); - let out = JitBackend::::float_matmul(weight, input); + let out = matmul::(weight, input, MatmulStrategy::default()); let mut out = reshape(out, Shape::new([out_channels, batch_size, height, width])); if let Some(bias) = bias { let bias = reshape(bias, Shape::new([out_channels, 1, 1, 1])); - out = JitBackend::::float_add(out, bias) + out = launch_binop::(out, bias) } swap_dims(out, 0, 1) diff --git a/crates/burn-jit/src/kernel/conv/conv2d/implicit_gemm.rs b/crates/burn-jit/src/kernel/conv/conv2d/implicit_gemm.rs index 49a639ef43..6771f2c5e2 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/implicit_gemm.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/implicit_gemm.rs @@ -18,7 +18,7 @@ use crate::{ permute, }, tensor::JitTensor, - FloatElement, IntElement, JitRuntime, + FloatElement, JitRuntime, }; use super::nchw_to_nhwc; @@ -30,8 +30,7 @@ use super::nchw_to_nhwc; /// * `bias` - The bias added to each channel /// * `options` - The options to use for the convolution /// -#[allow(clippy::extra_unused_type_parameters)] -pub fn conv2d_implicit_gemm( +pub fn conv2d_implicit_gemm( input: JitTensor, weight: JitTensor, bias: Option>, diff --git a/crates/burn-jit/src/kernel/conv/conv2d/transpose_direct.rs b/crates/burn-jit/src/kernel/conv/conv2d/transpose_direct.rs index 1062241d75..6a97ab8759 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/transpose_direct.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/transpose_direct.rs @@ -8,7 +8,7 @@ use crate::{ reshape, }, tensor::JitTensor, - IntElement, JitRuntime, + JitRuntime, }; use burn_tensor::{ops::ConvTransposeOptions, Shape}; @@ -121,8 +121,7 @@ fn conv_transpose2d_direct_kernel( /// * `bias` - The bias added to each channel /// * `options` - The options to use for the convolution /// -#[allow(clippy::extra_unused_type_parameters)] -pub fn conv_transpose2d_direct( +pub fn conv_transpose2d_direct( input: JitTensor, weight: JitTensor, bias: Option>, diff --git a/crates/burn-jit/src/kernel/conv/conv2d/tune/conv2d.rs b/crates/burn-jit/src/kernel/conv/conv2d/tune/conv2d.rs index 05ec7fd960..4a8122a478 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/tune/conv2d.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/tune/conv2d.rs @@ -16,13 +16,13 @@ use crate::{ prng::random_uniform, }, tensor::JitTensor, - FloatElement, IntElement, JitAutotuneKey, JitRuntime, JitTuneId, + FloatElement, JitAutotuneKey, JitRuntime, JitTuneId, }; use super::Conv2dAutotuneKey; /// Executes autotune on conv2d operations -pub fn conv2d_autotune( +pub fn conv2d_autotune( input: JitTensor, weights: JitTensor, bias: Option>, @@ -35,9 +35,7 @@ pub fn conv2d_autotune( TUNER.execute( 
&JitTuneId::new::(&input.device), &client, - Box::new(Conv2dOperations::::new( - input, weights, bias, options, - )), + Box::new(Conv2dOperations::::new(input, weights, bias, options)), ) } @@ -46,7 +44,7 @@ pub fn conv2d_autotune( create_key = create_key::, should_run = should_run )] -pub fn conv2d_operations( +pub fn conv2d_operations( key: JitAutotuneKey, input: JitTensor, weights: JitTensor, @@ -74,8 +72,8 @@ pub fn conv2d_operations( tune_with!(input, weights, bias, options) } -fn should_run( - op: &Conv2dOperations, +fn should_run( + op: &Conv2dOperations, key: &JitAutotuneKey, index: usize, ) -> bool { diff --git a/crates/burn-jit/src/kernel/conv/conv2d/tune/conv_transpose2d.rs b/crates/burn-jit/src/kernel/conv/conv2d/tune/conv_transpose2d.rs index 3a8c1d04f2..c2d546151a 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/tune/conv_transpose2d.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/tune/conv_transpose2d.rs @@ -10,13 +10,13 @@ use crate::{ prng::random_uniform, }, tensor::JitTensor, - FloatElement, IntElement, JitAutotuneKey, JitRuntime, JitTuneId, + FloatElement, JitAutotuneKey, JitRuntime, JitTuneId, }; use super::ConvTranspose2dAutotuneKey; /// Executes autotune on conv2d operations -pub fn conv_transpose2d_autotune( +pub fn conv_transpose2d_autotune( input: JitTensor, weights: JitTensor, bias: Option>, @@ -29,14 +29,14 @@ pub fn conv_transpose2d_autotune( TUNER.execute( &JitTuneId::new::(&input.device), &client, - Box::new(ConvTranspose2dOperations::::new( + Box::new(ConvTranspose2dOperations::::new( input, weights, bias, options, )), ) } #[tune(operations(conv_transpose2d_direct, conv_transpose2d_col2im), create_key = create_key::, should_run = should_run)] -pub fn conv_transpose2d_operations( +pub fn conv_transpose2d_operations( key: JitAutotuneKey, input: JitTensor, weights: JitTensor, @@ -95,8 +95,8 @@ fn create_key( )) } -fn should_run( - _op: &ConvTranspose2dOperations, +fn should_run( + _op: &ConvTranspose2dOperations, key: &JitAutotuneKey, index: usize, ) -> bool { diff --git a/crates/burn-jit/src/kernel/conv/deform_conv2d.rs b/crates/burn-jit/src/kernel/conv/deform_conv2d.rs index b005a2384c..438850fe72 100644 --- a/crates/burn-jit/src/kernel/conv/deform_conv2d.rs +++ b/crates/burn-jit/src/kernel/conv/deform_conv2d.rs @@ -1,18 +1,22 @@ use cubecl::{calculate_cube_count_elemwise, prelude::*}; use burn_tensor::{ - ops::{conv::calculate_conv_output_size, DeformConvOptions, FloatTensorOps as _}, + ops::{conv::calculate_conv_output_size, DeformConvOptions}, Shape, }; use crate::{ - kernel::into_contiguous, + kernel::{ + into_contiguous, launch_binop, + matmul::{matmul, MatmulStrategy}, + AddOp, + }, ops::{ numeric::{ones_device, zeros_device}, reshape, swap_dims, }, tensor::JitTensor, - FloatElement, IntElement, JitBackend, JitRuntime, + FloatElement, JitRuntime, }; #[derive(CubeLaunch)] @@ -251,7 +255,7 @@ pub(crate) fn deform_im2col( output } -pub(crate) fn deform_conv2d( +pub(crate) fn deform_conv2d( input: JitTensor, offset: JitTensor, weight: JitTensor, @@ -294,24 +298,15 @@ pub(crate) fn deform_conv2d( let weight = reshape(weight, Shape::new([groups, out_c_per_group, col_size_0])); let columns = reshape(columns, Shape::new([groups, col_size_0, col_size_1])); - let out = JitBackend::::float_matmul(weight, columns); + let out = matmul::(weight, columns, MatmulStrategy::default()); let out = reshape(out, Shape::new([out_channels, batch_size, out_h, out_w])); let out = swap_dims(out, 0, 1); if let Some(bias) = bias { let bias = reshape(bias, Shape::new([1, 
out_channels, 1, 1])); - JitBackend::::float_add(out, bias) + launch_binop::(out, bias) } else { out } } - -pub(crate) fn index( - tensor: JitTensor, - index: usize, -) -> JitTensor { - let [_, shape_0, shape_1] = tensor.shape.dims(); - let tensor = JitBackend::::float_narrow(tensor, 0, index, 1); - reshape(tensor, Shape::new([shape_0, shape_1])) -} diff --git a/crates/burn-jit/src/kernel/conv/deform_conv_transpose2d.rs b/crates/burn-jit/src/kernel/conv/deform_conv_transpose2d.rs index 4022a0bbe2..907b5ef344 100644 --- a/crates/burn-jit/src/kernel/conv/deform_conv_transpose2d.rs +++ b/crates/burn-jit/src/kernel/conv/deform_conv_transpose2d.rs @@ -5,7 +5,12 @@ use burn_tensor::{ use cubecl::{calculate_cube_count_elemwise, cube, prelude::*, CubeDim, CubeLaunch}; use crate::{ - kernel::{cast, into_contiguous}, + element::BoolElement, + kernel::{ + cast, into_contiguous, + matmul::{matmul, MatmulStrategy}, + slice_assign, + }, ops::{ numeric::{empty_device, ones_device, zeros_device}, reshape, swap_dims, @@ -18,7 +23,12 @@ use super::{bilinear_interpolate, deform_im2col, index}; /// Calculate the [deformable 2D convolution](crate::ops::ModuleOps::deform_conv2d) backward pass using convolutions. #[allow(clippy::single_range_in_vec_init)] -pub(crate) fn deform_conv2d_backward( +pub(crate) fn deform_conv2d_backward< + R: JitRuntime, + E: FloatElement, + I: IntElement, + BT: BoolElement, +>( input: JitTensor, offset: JitTensor, weight: JitTensor, @@ -26,14 +36,14 @@ pub(crate) fn deform_conv2d_backward>, out_grad: JitTensor, options: DeformConvOptions<2>, -) -> DeformConv2dBackward> { +) -> DeformConv2dBackward> { let [_, _, out_h, out_w] = out_grad.shape.dims(); let [_, _, kernel_h, kernel_w] = weight.shape.dims(); let gradient_bias = bias.map(|bias| { - let grad = JitBackend::::float_sum_dim(out_grad.clone(), 0); - let grad = JitBackend::::float_sum_dim(grad, 2); - let grad = JitBackend::::float_sum_dim(grad, 3); + let grad = JitBackend::::float_sum_dim(out_grad.clone(), 0); + let grad = JitBackend::::float_sum_dim(grad, 2); + let grad = JitBackend::::float_sum_dim(grad, 3); reshape(grad, bias.shape) }); @@ -42,7 +52,7 @@ pub(crate) fn deform_conv2d_backward( + let (input_gradient, offset_gradient, mask_gradient) = backward_gradient_inputs::( input.clone(), weight.clone(), offset.clone(), @@ -52,7 +62,7 @@ pub(crate) fn deform_conv2d_backward( + let weight_grad = compute_weight_grad::( input, offset, mask, @@ -71,7 +81,7 @@ pub(crate) fn deform_conv2d_backward( +fn compute_weight_grad( input: JitTensor, offset: JitTensor, mask: Option>, @@ -98,9 +108,9 @@ fn compute_weight_grad( let columns = reshape(columns, Shape::new([groups, col_size_0, col_size_1])); let columns = swap_dims(columns, 1, 2); - let grad_weight = JitBackend::::float_matmul(out_grad, columns); + let grad_weight = matmul::(out_grad, columns, MatmulStrategy::default()); - JitBackend::::float_reshape( + reshape( grad_weight, Shape::new([out_channels, in_c_per_group, kernel_h, kernel_w]), ) @@ -108,7 +118,7 @@ fn compute_weight_grad( type InputGradients = (JitTensor, JitTensor, Option>); -fn backward_gradient_inputs( +fn backward_gradient_inputs( image: JitTensor, weight: JitTensor, offset: JitTensor, @@ -138,11 +148,11 @@ fn backward_gradient_inputs( let out_grad = reshape(out_grad, out_grad_shape); for group in 0..groups { - let weight = swap_dims(index::(weight.clone(), group), 0, 1); - let out_grad = index::(out_grad.clone(), group); - let values = JitBackend::::float_matmul(weight, out_grad); + let weight = 
swap_dims(index::(weight.clone(), group), 0, 1); + let out_grad = index::(out_grad.clone(), group); + let values = matmul::(weight, out_grad, MatmulStrategy::default()); let values = reshape(values, Shape::new([1, col_shape_0, col_shape_1])); - columns = JitBackend::::float_slice_assign( + columns = slice_assign::( columns, &[group..group + 1, 0..col_shape_0, 0..col_shape_1], values, diff --git a/crates/burn-jit/src/kernel/index/flip.rs b/crates/burn-jit/src/kernel/index/flip.rs index e35cac8b2c..583e0346d3 100644 --- a/crates/burn-jit/src/kernel/index/flip.rs +++ b/crates/burn-jit/src/kernel/index/flip.rs @@ -1,4 +1,6 @@ -use crate::{element::JitElement, ops::numeric::empty_device, tensor::JitTensor, JitRuntime}; +use crate::{ + element::JitElement, ops::numeric::empty_device, tensor::JitTensor, BoolElement, JitRuntime, +}; use cubecl::{calculate_cube_count_elemwise, prelude::*}; #[cube(launch_unchecked)] @@ -31,7 +33,7 @@ fn flip_kernel( output[ABSOLUTE_POS] = input[offset_input]; } -pub(crate) fn flip( +pub(crate) fn flip( tensor: JitTensor, indices: &[usize], ) -> JitTensor { @@ -40,26 +42,26 @@ pub(crate) fn flip( tensor.device.clone(), tensor.shape.clone(), ); - flip_on_output::(tensor, output, indices) + flip_on_output::(tensor, output, indices) } -pub(crate) fn flip_on_output( +pub(crate) fn flip_on_output( tensor: JitTensor, output: JitTensor, indices: &[usize], ) -> JitTensor { let ndims = tensor.shape.num_dims(); - let mut indices_sequence = SequenceArg::<'_, R, u32>::new(); + let mut indices_sequence = SequenceArg::<'_, R, BT>::new(); for i in 0..ndims { - indices_sequence.push(ScalarArg::new(indices.contains(&i) as u32)); + indices_sequence.push(ScalarArg::new(BT::new_bool(indices.contains(&i)))); } let cube_dim = CubeDim::default(); let cube_count = calculate_cube_count_elemwise(output.shape.num_elements(), cube_dim); unsafe { - flip_kernel::launch_unchecked::( + flip_kernel::launch_unchecked::( &tensor.client, cube_count, cube_dim, diff --git a/crates/burn-jit/src/kernel/mask/base.rs b/crates/burn-jit/src/kernel/mask/base.rs index 2140972326..d37c6e05bb 100644 --- a/crates/burn-jit/src/kernel/mask/base.rs +++ b/crates/burn-jit/src/kernel/mask/base.rs @@ -1,8 +1,8 @@ use super::{mask_where::MaskWhereStrategy, MaskFillStrategy}; -use crate::{element::JitElement, tensor::JitTensor, JitRuntime}; +use crate::{element::JitElement, tensor::JitTensor, BoolElement, JitRuntime}; /// Execute the mask fill kernel. -pub(crate) fn mask_fill_auto( +pub(crate) fn mask_fill_auto( tensor: JitTensor, mask: JitTensor, value: E, @@ -13,11 +13,11 @@ pub(crate) fn mask_fill_auto( MaskFillStrategy::Readonly }; - super::mask_fill(tensor, mask, value, strategy) + super::mask_fill::(tensor, mask, value, strategy) } /// Execute the mask where kernel. 
-pub(crate) fn mask_where_auto( +pub(crate) fn mask_where_auto( tensor: JitTensor, mask: JitTensor, value: JitTensor, @@ -30,5 +30,5 @@ pub(crate) fn mask_where_auto( MaskWhereStrategy::Readonly }; - super::mask_where::(tensor, mask, value, strategy) + super::mask_where::(tensor, mask, value, strategy) } diff --git a/crates/burn-jit/src/kernel/mask/mask_fill.rs b/crates/burn-jit/src/kernel/mask/mask_fill.rs index e8b3f814d9..386e7a5039 100644 --- a/crates/burn-jit/src/kernel/mask/mask_fill.rs +++ b/crates/burn-jit/src/kernel/mask/mask_fill.rs @@ -1,11 +1,16 @@ use cubecl::{calculate_cube_count_elemwise, linalg::tensor::index_offset_with_layout, prelude::*}; -use crate::{element::JitElement, ops::numeric::empty_device, tensor::JitTensor, JitRuntime}; +use crate::{ + element::JitElement, + ops::{max_vectorization, numeric::empty_device}, + tensor::JitTensor, + BoolElement, JitRuntime, +}; #[cube(launch)] -fn mask_fill_readonly_kernel( +fn mask_fill_readonly_kernel( input: &Tensor>, - mask: &Tensor>, + mask: &Tensor>, output: &mut Tensor>, value: T, #[comptime] rank: u32, @@ -17,17 +22,15 @@ fn mask_fill_readonly_kernel( let index_input = index_offset_with_layout(input, output, ABSOLUTE_POS, 0, rank, true); let index_mask = index_offset_with_layout(mask, output, ABSOLUTE_POS, 0, rank, true); - if mask[index_mask] >= Line::new(1) { - output[ABSOLUTE_POS] = Line::new(value); - } else { - output[ABSOLUTE_POS] = input[index_input]; - } + let mask = Line::cast_from(mask[index_mask]); + + output[ABSOLUTE_POS] = select_many(mask, Line::new(value), input[index_input]); } #[cube(launch)] -fn mask_fill_inplace_kernel( +fn mask_fill_inplace_kernel( input: &mut Tensor>, - mask: &Tensor>, + mask: &Tensor>, value: T, #[comptime] rank: u32, ) { @@ -36,10 +39,9 @@ fn mask_fill_inplace_kernel( } let index_mask = index_offset_with_layout(mask, input, ABSOLUTE_POS, 0, rank, true); + let mask = Line::cast_from(mask[index_mask]); - if mask[index_mask] >= Line::new(1) { - input[ABSOLUTE_POS] = Line::new(value); - } + input[ABSOLUTE_POS] = select_many(mask, Line::new(value), input[ABSOLUTE_POS]); } #[derive(Clone, Copy, Debug)] @@ -56,19 +58,19 @@ pub enum MaskFillStrategy { } /// Execute the mask fill kernel with the given strategy. 
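The rewritten mask kernels replace the scalar `if mask >= 1` branch with a per-lane `select_many` over vectorized lines, which is what allows them to launch with the maximum supported vectorization instead of a factor of 1. A standalone, host-side model of that per-lane selection (the function and types here are illustrative, not the cubecl API):

    /// For each lane, take `on_true` where the mask is non-zero, otherwise keep
    /// `on_false` — the branchless select the mask kernels now apply per line.
    fn select_lanes(mask: &[u8], on_true: &[f32], on_false: &[f32]) -> Vec<f32> {
        mask.iter()
            .zip(on_true.iter().zip(on_false.iter()))
            .map(|(&m, (&t, &f))| if m != 0 { t } else { f })
            .collect()
    }

    fn main() {
        let mask = [1u8, 0, 1, 0];
        let fill = [9.0f32; 4];
        let input = [1.0f32, 2.0, 3.0, 4.0];
        // mask_fill: write the fill value where the mask is set, keep the input elsewhere.
        assert_eq!(select_lanes(&mask, &fill, &input), vec![9.0, 2.0, 9.0, 4.0]);
    }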
-pub fn mask_fill( +pub fn mask_fill( input: JitTensor, mask: JitTensor, value: E, strategy: MaskFillStrategy, ) -> JitTensor { match strategy { - MaskFillStrategy::Readonly => mask_fill_readonly::(input, mask, value), - MaskFillStrategy::Inplace => mask_fill_inplace::(input, mask, value), + MaskFillStrategy::Readonly => mask_fill_readonly::(input, mask, value), + MaskFillStrategy::Inplace => mask_fill_inplace::(input, mask, value), } } -fn mask_fill_readonly( +fn mask_fill_readonly( input: JitTensor, mask: JitTensor, value: EI, @@ -82,14 +84,15 @@ fn mask_fill_readonly( let cube_dim = CubeDim::default(); let cube_count = calculate_cube_count_elemwise(input.shape.num_elements(), cube_dim); + let vectorization = max_vectorization(&input); - mask_fill_readonly_kernel::launch::( + mask_fill_readonly_kernel::launch::( &input.client, cube_count, cube_dim, - input.as_tensor_arg::(1), - mask.as_tensor_arg::(1), - output.as_tensor_arg::(1), + input.as_tensor_arg::(vectorization), + mask.as_tensor_arg::(vectorization), + output.as_tensor_arg::(vectorization), ScalarArg::new(value), ndims as u32, ); @@ -97,7 +100,7 @@ fn mask_fill_readonly( output } -fn mask_fill_inplace( +fn mask_fill_inplace( input: JitTensor, mask: JitTensor, value: EI, @@ -105,13 +108,14 @@ fn mask_fill_inplace( let ndims = input.shape.num_dims(); let cube_dim = CubeDim::default(); let cube_count = calculate_cube_count_elemwise(input.shape.num_elements(), cube_dim); + let vectorization = max_vectorization(&input); - mask_fill_inplace_kernel::launch::( + mask_fill_inplace_kernel::launch::( &input.client, cube_count, cube_dim, - input.as_tensor_arg::(1), - mask.as_tensor_arg::(1), + input.as_tensor_arg::(vectorization), + mask.as_tensor_arg::(vectorization), ScalarArg::new(value), ndims as u32, ); diff --git a/crates/burn-jit/src/kernel/mask/mask_where.rs b/crates/burn-jit/src/kernel/mask/mask_where.rs index 73c7c8fcf1..5518e9648b 100644 --- a/crates/burn-jit/src/kernel/mask/mask_where.rs +++ b/crates/burn-jit/src/kernel/mask/mask_where.rs @@ -1,11 +1,16 @@ use cubecl::{calculate_cube_count_elemwise, linalg::tensor::index_offset_with_layout, prelude::*}; -use crate::{element::JitElement, ops::numeric::empty_device, tensor::JitTensor, JitRuntime}; +use crate::{ + element::JitElement, + ops::{max_vectorization, numeric::empty_device}, + tensor::JitTensor, + BoolElement, JitRuntime, +}; #[cube(launch)] -fn mask_where_readonly_kernel( +fn mask_where_readonly_kernel( input: &Tensor>, - mask: &Tensor>, + mask: &Tensor>, value: &Tensor>, output: &mut Tensor>, #[comptime] rank: u32, @@ -17,20 +22,17 @@ fn mask_where_readonly_kernel( let index_input = index_offset_with_layout(input, output, ABSOLUTE_POS, 0, rank, true); let index_mask = index_offset_with_layout(mask, output, ABSOLUTE_POS, 0, rank, true); let index_value = index_offset_with_layout(value, output, ABSOLUTE_POS, 0, rank, true); + let mask = Line::cast_from(mask[index_mask]); - if mask[index_mask] >= Line::new(1) { - output[ABSOLUTE_POS] = value[index_value]; - } else { - output[ABSOLUTE_POS] = input[index_input]; - } + output[ABSOLUTE_POS] = select_many(mask, value[index_value], input[index_input]); } #[cube(launch)] -fn mask_where_inplace_kernel( +fn mask_where_inplace_kernel( input: &mut Tensor>, - mask: &Tensor>, + mask: &Tensor>, value: &Tensor>, - reverse: u32, + reverse: B, #[comptime] rank: u32, ) { if ABSOLUTE_POS >= input.len() { @@ -40,9 +42,11 @@ fn mask_where_inplace_kernel( let index_mask = index_offset_with_layout(mask, input, ABSOLUTE_POS, 0, rank, true); let 
index_value = index_offset_with_layout(value, input, ABSOLUTE_POS, 0, rank, true); - if mask[index_mask] != Line::new(reverse) { - input[ABSOLUTE_POS] = value[index_value]; - } + input[ABSOLUTE_POS] = select( + mask[index_mask] != Line::new(reverse), + value[index_value], + input[ABSOLUTE_POS], + ); } #[derive(Clone, Copy, Debug)] @@ -61,20 +65,20 @@ pub enum MaskWhereStrategy { } /// Execute the mask where kernel with the given strategy. -pub fn mask_where( +pub fn mask_where( input: JitTensor, mask: JitTensor, value: JitTensor, strategy: MaskWhereStrategy, ) -> JitTensor { match strategy { - MaskWhereStrategy::Readonly => mask_where_readonly::(input, mask, value), - MaskWhereStrategy::InplaceLhs => mask_where_inplace::(input, mask, value, false), - MaskWhereStrategy::InplaceRhs => mask_where_inplace::(value, mask, input, true), + MaskWhereStrategy::Readonly => mask_where_readonly::(input, mask, value), + MaskWhereStrategy::InplaceLhs => mask_where_inplace::(input, mask, value, false), + MaskWhereStrategy::InplaceRhs => mask_where_inplace::(value, mask, input, true), } } -fn mask_where_readonly( +fn mask_where_readonly( input: JitTensor, mask: JitTensor, value: JitTensor, @@ -88,22 +92,23 @@ fn mask_where_readonly( let cube_dim = CubeDim::default(); let cube_count = calculate_cube_count_elemwise(input.shape.num_elements(), cube_dim); + let vectorization = max_vectorization(&input); - mask_where_readonly_kernel::launch::( + mask_where_readonly_kernel::launch::( &input.client, cube_count, cube_dim, - input.as_tensor_arg::(1), - mask.as_tensor_arg::(1), - value.as_tensor_arg::(1), - output.as_tensor_arg::(1), + input.as_tensor_arg::(vectorization), + mask.as_tensor_arg::(vectorization), + value.as_tensor_arg::(vectorization), + output.as_tensor_arg::(vectorization), ndims as u32, ); output } -fn mask_where_inplace( +fn mask_where_inplace( input: JitTensor, mask: JitTensor, value: JitTensor, @@ -112,15 +117,16 @@ fn mask_where_inplace( let ndims = input.shape.num_dims(); let cube_dim = CubeDim::default(); let cube_count = calculate_cube_count_elemwise(input.shape.num_elements(), cube_dim); + let vectorization = max_vectorization(&input); - mask_where_inplace_kernel::launch::( + mask_where_inplace_kernel::launch::( &input.client, cube_count, cube_dim, - input.as_tensor_arg::(1), - mask.as_tensor_arg::(1), - value.as_tensor_arg::(1), - ScalarArg::new(reverse as u32), + input.as_tensor_arg::(vectorization), + mask.as_tensor_arg::(vectorization), + value.as_tensor_arg::(vectorization), + ScalarArg::new(EM::new_bool(reverse)), ndims as u32, ); diff --git a/crates/burn-jit/src/lib.rs b/crates/burn-jit/src/lib.rs index 77a67df37a..ba953ae0d0 100644 --- a/crates/burn-jit/src/lib.rs +++ b/crates/burn-jit/src/lib.rs @@ -21,7 +21,7 @@ pub mod element; use burn_tensor::backend::{DeviceId, DeviceOps}; use cubecl::{compute::CubeTask, Feature, Runtime}; -pub use element::{FloatElement, IntElement, JitElement}; +pub use element::{BoolElement, FloatElement, IntElement, JitElement}; mod backend; diff --git a/crates/burn-jit/src/ops/activation_ops.rs b/crates/burn-jit/src/ops/activation_ops.rs index 7f6b921d16..eecd6849c8 100644 --- a/crates/burn-jit/src/ops/activation_ops.rs +++ b/crates/burn-jit/src/ops/activation_ops.rs @@ -1,10 +1,11 @@ -use crate::{FloatElement, IntElement, JitBackend, JitRuntime}; +use crate::{element::BoolElement, FloatElement, IntElement, JitBackend, JitRuntime}; use burn_tensor::ops::ActivationOps; -impl ActivationOps for JitBackend +impl ActivationOps for JitBackend where R: 
JitRuntime, F: FloatElement, I: IntElement, + BT: BoolElement, { } diff --git a/crates/burn-jit/src/ops/base.rs b/crates/burn-jit/src/ops/base.rs index 58e3b25c0c..bce600604e 100644 --- a/crates/burn-jit/src/ops/base.rs +++ b/crates/burn-jit/src/ops/base.rs @@ -1,6 +1,6 @@ -use crate::{element::JitElement, kernel, tensor::JitTensor, JitRuntime}; +use crate::{element::JitElement, kernel, tensor::JitTensor, BoolElement, JitRuntime}; use burn_tensor::{Shape, TensorData}; -use cubecl::{tensor_vectorization_factor, CubeElement}; +use cubecl::tensor_vectorization_factor; pub(crate) fn from_data( data: TensorData, @@ -29,11 +29,16 @@ pub fn into_data_sync(tensor: JitTensor) -> Ten TensorData::new(E::from_bytes(&bytes).to_vec(), tensor.shape) } -pub(crate) async fn bool_into_data(tensor: JitTensor) -> TensorData { +pub(crate) async fn bool_into_data( + tensor: JitTensor, +) -> TensorData { let tensor = kernel::into_contiguous(tensor); let bytes = tensor.client.read_one_async(tensor.handle.binding()).await; TensorData::new( - u32::from_bytes(&bytes).iter().map(|i| *i != 0).collect(), + BT::from_bytes(&bytes) + .iter() + .map(|i| *i != BT::false_val()) + .collect(), tensor.shape, ) } diff --git a/crates/burn-jit/src/ops/bool_ops.rs b/crates/burn-jit/src/ops/bool_ops.rs index 036913e88d..017e76f2c4 100644 --- a/crates/burn-jit/src/ops/bool_ops.rs +++ b/crates/burn-jit/src/ops/bool_ops.rs @@ -1,31 +1,32 @@ -use crate::{kernel, FloatElement, IntElement, JitBackend, JitRuntime}; +use crate::{element::BoolElement, kernel, FloatElement, IntElement, JitBackend, JitRuntime}; use burn_tensor::ops::{BoolTensor, Device, FloatTensor, IntTensor}; use burn_tensor::{ops::BoolTensorOps, Shape, TensorData}; use std::ops::Range; use super::{expand, permute}; -impl BoolTensorOps for JitBackend +impl BoolTensorOps for JitBackend where R: JitRuntime, F: FloatElement, I: IntElement, + BT: BoolElement, { fn bool_empty(shape: Shape, device: &Device) -> BoolTensor { - super::empty::(shape, device) + super::empty::(shape, device) } async fn bool_into_data(tensor: BoolTensor) -> TensorData { - super::bool_into_data(tensor).await + super::bool_into_data::(tensor).await } fn bool_from_data(data: TensorData, device: &Device) -> BoolTensor { - let data: TensorData = TensorData::new(data.iter::().collect(), data.shape); - super::from_data::(data, device) + let data: TensorData = TensorData::new(data.iter::().collect(), data.shape); + super::from_data::(data, device) } fn bool_into_int(tensor: BoolTensor) -> IntTensor { - kernel::bool_cast::(tensor) + kernel::bool_cast::(tensor) } fn bool_device(tensor: &BoolTensor) -> Device { @@ -41,7 +42,7 @@ where } fn bool_slice(tensor: BoolTensor, ranges: &[Range]) -> BoolTensor { - kernel::slice::(tensor, ranges) + kernel::slice::(tensor, ranges) } fn bool_slice_assign( @@ -49,19 +50,19 @@ where ranges: &[Range], value: BoolTensor, ) -> BoolTensor { - kernel::slice_assign::(tensor, ranges, value) + kernel::slice_assign::(tensor, ranges, value) } fn bool_equal(lhs: BoolTensor, rhs: BoolTensor) -> BoolTensor { - kernel::equal::(lhs, rhs) + kernel::equal::(lhs, rhs) } fn bool_not(tensor: BoolTensor) -> BoolTensor { - kernel::equal_elem::(tensor, 0) + kernel::equal_elem::(tensor, BT::false_val()) } fn bool_into_float(tensor: BoolTensor) -> FloatTensor { - kernel::bool_cast::(tensor) + kernel::bool_cast::(tensor) } fn bool_swap_dims(mut tensor: BoolTensor, dim1: usize, dim2: usize) -> BoolTensor { @@ -72,7 +73,7 @@ where } fn bool_repeat_dim(tensor: BoolTensor, dim: usize, times: usize) -> 
BoolTensor { - kernel::repeat_dim::(tensor, dim, times) + kernel::repeat_dim::(tensor, dim, times) } fn bool_permute(tensor: BoolTensor, axes: &[usize]) -> BoolTensor { @@ -84,6 +85,6 @@ where } fn bool_flip(tensor: BoolTensor, axes: &[usize]) -> BoolTensor { - kernel::flip::(tensor, axes) + kernel::flip::(tensor, axes) } } diff --git a/crates/burn-jit/src/ops/float_ops.rs b/crates/burn-jit/src/ops/float_ops.rs index 52b013ec0e..f97b1609ff 100644 --- a/crates/burn-jit/src/ops/float_ops.rs +++ b/crates/burn-jit/src/ops/float_ops.rs @@ -1,7 +1,10 @@ use super::{expand, numeric, permute}; -use crate::kernel::matmul::{matmul, MatmulStrategy}; use crate::kernel::prng::{random_bernoulli, random_normal, random_uniform}; use crate::kernel::{self, launch_unary, reduce, unary_op, UnaryOp}; +use crate::{ + element::BoolElement, + kernel::matmul::{matmul, MatmulStrategy}, +}; use crate::{execute_with_dtype, JitBackend}; use crate::{FloatElement, IntElement, JitRuntime}; use burn_tensor::ops::{BoolTensor, Device, FloatElem, FloatTensor, IntTensor}; @@ -11,11 +14,12 @@ use cubecl::prelude::*; use half::{bf16, f16}; use std::ops::Range; -impl FloatTensorOps for JitBackend +impl FloatTensorOps for JitBackend where R: JitRuntime, F: FloatElement, I: IntElement, + BT: BoolElement, { fn float_from_data(data: TensorData, device: &Device) -> FloatTensor { super::from_data::(data, device) @@ -248,7 +252,7 @@ where execute_with_dtype!( float(tensor.dtype, value.dtype), E, - kernel::mask_where_auto::(tensor, mask, value) + kernel::mask_where_auto::(tensor, mask, value) ) } @@ -260,7 +264,7 @@ where execute_with_dtype!( float(tensor.dtype), E, - kernel::mask_fill_auto::(tensor, mask, value.elem()) + kernel::mask_fill_auto::(tensor, mask, value.elem()) ) } @@ -268,7 +272,7 @@ where execute_with_dtype!( float(lhs.dtype, rhs.dtype), E, - kernel::equal::(lhs, rhs) + kernel::equal::(lhs, rhs) ) } @@ -276,7 +280,7 @@ where execute_with_dtype!( float(lhs.dtype), E, - kernel::equal_elem::(lhs, rhs.elem()) + kernel::equal_elem::(lhs, rhs.elem()) ) } @@ -284,7 +288,7 @@ where execute_with_dtype!( float(lhs.dtype, rhs.dtype), E, - kernel::greater::(lhs, rhs) + kernel::greater::(lhs, rhs) ) } @@ -292,7 +296,7 @@ where execute_with_dtype!( float(lhs.dtype), E, - kernel::greater_elem::(lhs, rhs.elem()) + kernel::greater_elem::(lhs, rhs.elem()) ) } @@ -300,7 +304,7 @@ where execute_with_dtype!( float(lhs.dtype, rhs.dtype), E, - kernel::greater_equal::(lhs, rhs) + kernel::greater_equal::(lhs, rhs) ) } @@ -308,7 +312,7 @@ where execute_with_dtype!( float(lhs.dtype), E, - kernel::greater_equal_elem::(lhs, rhs.elem()) + kernel::greater_equal_elem::(lhs, rhs.elem()) ) } @@ -316,7 +320,7 @@ where execute_with_dtype!( float(lhs.dtype, rhs.dtype), E, - kernel::lower::(lhs, rhs) + kernel::lower::(lhs, rhs) ) } @@ -324,7 +328,7 @@ where execute_with_dtype!( float(lhs.dtype), E, - kernel::lower_elem::(lhs, rhs.elem()) + kernel::lower_elem::(lhs, rhs.elem()) ) } @@ -332,7 +336,7 @@ where execute_with_dtype!( float(lhs.dtype, rhs.dtype), E, - kernel::lower_equal::(lhs, rhs) + kernel::lower_equal::(lhs, rhs) ) } @@ -340,7 +344,7 @@ where execute_with_dtype!( float(lhs.dtype), E, - kernel::lower_equal_elem::(lhs, rhs.elem()) + kernel::lower_equal_elem::(lhs, rhs.elem()) ) } @@ -633,7 +637,11 @@ where } fn float_flip(tensor: FloatTensor, axes: &[usize]) -> FloatTensor { - execute_with_dtype!(float(tensor.dtype), E, kernel::flip::(tensor, axes)) + execute_with_dtype!( + float(tensor.dtype), + E, + kernel::flip::(tensor, axes) + ) } fn 
float_cast(tensor: FloatTensor, dtype: FloatDType) -> FloatTensor { diff --git a/crates/burn-jit/src/ops/int_ops.rs b/crates/burn-jit/src/ops/int_ops.rs index cb6603bf80..25bb92521f 100644 --- a/crates/burn-jit/src/ops/int_ops.rs +++ b/crates/burn-jit/src/ops/int_ops.rs @@ -1,6 +1,9 @@ use super::{expand, numeric, permute}; -use crate::kernel::prng::{random_bernoulli, random_normal, random_uniform}; use crate::kernel::{launch_unary, unary_op, UnaryOp}; +use crate::{ + element::BoolElement, + kernel::prng::{random_bernoulli, random_normal, random_uniform}, +}; use crate::{kernel, FloatElement, IntElement, JitBackend, JitRuntime}; use burn_tensor::ops::{BoolTensor, Device, FloatTensor, IntElem, IntTensor}; use burn_tensor::{ops::IntTensorOps, Distribution, ElementConversion, Shape, TensorData}; @@ -8,11 +11,12 @@ use cubecl::frontend::Numeric; use cubecl::prelude::*; use std::ops::Range; -impl IntTensorOps for JitBackend +impl IntTensorOps for JitBackend where R: JitRuntime, F: FloatElement, I: IntElement, + BT: BoolElement, { fn int_empty(shape: Shape, device: &Device) -> IntTensor { super::empty::(shape, device) @@ -55,7 +59,7 @@ where mask: BoolTensor, value: IntTensor, ) -> IntTensor { - kernel::mask_where_auto::(tensor, mask, value) + kernel::mask_where_auto::(tensor, mask, value) } fn int_mask_fill( @@ -63,7 +67,7 @@ where mask: BoolTensor, value: IntElem, ) -> IntTensor { - kernel::mask_fill_auto(tensor, mask, value) + kernel::mask_fill_auto::(tensor, mask, value) } fn int_gather( @@ -101,43 +105,43 @@ where } fn int_equal(lhs: IntTensor, rhs: IntTensor) -> BoolTensor { - kernel::equal::(lhs, rhs) + kernel::equal::(lhs, rhs) } fn int_equal_elem(lhs: IntTensor, rhs: IntElem) -> BoolTensor { - kernel::equal_elem::(lhs, rhs) + kernel::equal_elem::(lhs, rhs) } fn int_greater(lhs: IntTensor, rhs: IntTensor) -> BoolTensor { - kernel::greater::(lhs, rhs) + kernel::greater::(lhs, rhs) } fn int_greater_elem(lhs: IntTensor, rhs: IntElem) -> BoolTensor { - kernel::greater_elem::(lhs, rhs) + kernel::greater_elem::(lhs, rhs) } fn int_greater_equal(lhs: IntTensor, rhs: IntTensor) -> BoolTensor { - kernel::greater_equal::(lhs, rhs) + kernel::greater_equal::(lhs, rhs) } fn int_greater_equal_elem(lhs: IntTensor, rhs: IntElem) -> BoolTensor { - kernel::greater_equal_elem::(lhs, rhs) + kernel::greater_equal_elem::(lhs, rhs) } fn int_lower(lhs: IntTensor, rhs: IntTensor) -> BoolTensor { - kernel::lower::(lhs, rhs) + kernel::lower::(lhs, rhs) } fn int_lower_elem(lhs: IntTensor, rhs: IntElem) -> BoolTensor { - kernel::lower_elem::(lhs, rhs) + kernel::lower_elem::(lhs, rhs) } fn int_lower_equal(lhs: IntTensor, rhs: IntTensor) -> BoolTensor { - kernel::lower_equal::(lhs, rhs) + kernel::lower_equal::(lhs, rhs) } fn int_lower_equal_elem(lhs: IntTensor, rhs: IntElem) -> BoolTensor { - kernel::lower_equal_elem::(lhs, rhs) + kernel::lower_equal_elem::(lhs, rhs) } fn int_add(lhs: IntTensor, rhs: IntTensor) -> IntTensor { @@ -277,6 +281,6 @@ where } fn int_flip(tensor: IntTensor, axes: &[usize]) -> IntTensor { - kernel::flip::(tensor, axes) + kernel::flip::(tensor, axes) } } diff --git a/crates/burn-jit/src/ops/module_ops.rs b/crates/burn-jit/src/ops/module_ops.rs index 5539dfc9f2..b5c96058f9 100644 --- a/crates/burn-jit/src/ops/module_ops.rs +++ b/crates/burn-jit/src/ops/module_ops.rs @@ -1,4 +1,5 @@ use crate::{ + element::BoolElement, kernel::{ self, conv::{Conv2dStrategy, ConvTranspose2dStrategy}, @@ -11,11 +12,12 @@ use burn_tensor::ops::{ }; use burn_tensor::ops::{FloatTensor, IntTensor}; -impl ModuleOps for 
JitBackend +impl ModuleOps for JitBackend where R: JitRuntime, F: FloatElement, I: IntElement, + BT: BoolElement, { fn conv2d( x: FloatTensor, @@ -23,7 +25,7 @@ where bias: Option>, options: ConvOptions<2>, ) -> FloatTensor { - kernel::conv::conv2d::(x, weight, bias, options, Conv2dStrategy::default()) + kernel::conv::conv2d::(x, weight, bias, options, Conv2dStrategy::default()) } fn deform_conv2d( @@ -34,7 +36,7 @@ where bias: Option>, options: DeformConvOptions<2>, ) -> FloatTensor { - kernel::conv::deform_conv2d::(x, offset, weight, mask, bias, options) + kernel::conv::deform_conv2d::(x, offset, weight, mask, bias, options) } fn deform_conv2d_backward( @@ -46,7 +48,7 @@ where output_grad: FloatTensor, options: DeformConvOptions<2>, ) -> DeformConv2dBackward { - kernel::conv::deform_conv2d_backward::( + kernel::conv::deform_conv2d_backward::( x, offset, weight, diff --git a/crates/burn-jit/src/ops/qtensor.rs b/crates/burn-jit/src/ops/qtensor.rs index e5eb4005a6..94b1a6f2ee 100644 --- a/crates/burn-jit/src/ops/qtensor.rs +++ b/crates/burn-jit/src/ops/qtensor.rs @@ -9,6 +9,7 @@ use burn_tensor::{ }; use crate::{ + element::BoolElement, kernel, tensor::{JitQuantizationParameters, JitTensor, QJitTensor}, FloatElement, IntElement, JitBackend, JitRuntime, @@ -27,11 +28,12 @@ fn packed_tensor>( JitTensor::new_contiguous(client, device.clone(), shape.into(), buffer, DType::U32) } -impl QTensorOps for JitBackend +impl QTensorOps for JitBackend where R: JitRuntime, F: FloatElement, I: IntElement, + BT: BoolElement, { fn q_from_data(data: TensorData, device: &Device) -> QuantizedTensor { match data.dtype { diff --git a/crates/burn-jit/src/ops/transaction.rs b/crates/burn-jit/src/ops/transaction.rs index 62477d3ce1..7320186570 100644 --- a/crates/burn-jit/src/ops/transaction.rs +++ b/crates/burn-jit/src/ops/transaction.rs @@ -3,13 +3,14 @@ use burn_tensor::{ DType, TensorData, }; -use crate::{FloatElement, IntElement, JitBackend, JitRuntime}; +use crate::{element::BoolElement, FloatElement, IntElement, JitBackend, JitRuntime}; -impl TransactionOps for JitBackend +impl TransactionOps for JitBackend where R: JitRuntime, F: FloatElement, I: IntElement, + BT: BoolElement, { fn tr_execute( transaction: burn_tensor::ops::TransactionPrimitive, @@ -51,7 +52,7 @@ where client = Some(t.client.clone()); } - kinds.push(Kind::Bool(num_bindings, t.shape.into(), DType::U32)); + kinds.push(Kind::Bool(num_bindings, t.shape.into(), BT::dtype())); num_bindings += 1; bindings.push(t.handle.binding()) }); @@ -64,7 +65,7 @@ where .await .into_iter() .map(Some) - .collect::>(); + .collect::>>(); let mut result = TransactionPrimitiveResult::default(); diff --git a/crates/burn-jit/src/tensor/base.rs b/crates/burn-jit/src/tensor/base.rs index 112260c3bf..3eb44b3e02 100644 --- a/crates/burn-jit/src/tensor/base.rs +++ b/crates/burn-jit/src/tensor/base.rs @@ -162,9 +162,9 @@ macro_rules! 
execute_with_dtype { type $element = i8; $op } - // NOTE: bool and qfloat dtypes are actually represented as u32 + // NOTE: bool and qfloat dtypes are actually represented as u32/u8 // burn_tensor::DType::Bool => { - // type $element = u32; + // type $element = u32/u8; // $op // } // burn_tensor::DType::QFloat(_) => { diff --git a/crates/burn-jit/src/tensor/qtensor.rs b/crates/burn-jit/src/tensor/qtensor.rs index fdf7068e1a..4ef5f77589 100644 --- a/crates/burn-jit/src/tensor/qtensor.rs +++ b/crates/burn-jit/src/tensor/qtensor.rs @@ -6,7 +6,9 @@ use burn_tensor::{ read_sync, DType, TensorData, TensorMetadata, }; -use crate::{ops::into_data, FloatElement, IntElement, JitBackend, JitRuntime}; +use crate::{ + element::BoolElement, ops::into_data, FloatElement, IntElement, JitBackend, JitRuntime, +}; use super::JitTensor; @@ -96,10 +98,11 @@ impl Clone for JitQuantizationParameters { } } -impl - From>> for JitQuantizationParameters +impl + From>> + for JitQuantizationParameters { - fn from(value: QuantizationParametersPrimitive>) -> Self { + fn from(value: QuantizationParametersPrimitive>) -> Self { JitQuantizationParameters { scale: value.scale, offset: value.offset, diff --git a/crates/burn-jit/src/tests/mask_fill.rs b/crates/burn-jit/src/tests/mask_fill.rs index 4542bbe3f1..c768373d13 100644 --- a/crates/burn-jit/src/tests/mask_fill.rs +++ b/crates/burn-jit/src/tests/mask_fill.rs @@ -11,6 +11,7 @@ mod tests { let actual = Tensor::::from_primitive(TensorPrimitive::Float(mask_fill::< _, ::FloatElem, + ::BoolElem, >( tensor.into_primitive().tensor(), mask.into_primitive(), @@ -31,6 +32,7 @@ mod tests { let actual = Tensor::::from_primitive(TensorPrimitive::Float(mask_fill::< _, ::FloatElem, + ::BoolElem, >( tensor.into_primitive().tensor(), mask.into_primitive(), diff --git a/crates/burn-jit/src/tests/mask_where.rs b/crates/burn-jit/src/tests/mask_where.rs index befdb76af6..a14993995c 100644 --- a/crates/burn-jit/src/tests/mask_where.rs +++ b/crates/burn-jit/src/tests/mask_where.rs @@ -23,6 +23,7 @@ mod tests { Tensor::::from_primitive(TensorPrimitive::Float(mask_where::< _, ::FloatElem, + ::BoolElem, >( tensor.into_primitive().tensor(), mask.into_primitive(), @@ -44,6 +45,7 @@ mod tests { Tensor::::from_primitive(TensorPrimitive::Float(mask_where::< _, ::FloatElem, + ::BoolElem, >( tensor.into_primitive().tensor(), mask.into_primitive(), diff --git a/crates/burn-jit/src/tests/mod.rs b/crates/burn-jit/src/tests/mod.rs index b1ee4ce26d..f60edc2a1b 100644 --- a/crates/burn-jit/src/tests/mod.rs +++ b/crates/burn-jit/src/tests/mod.rs @@ -38,12 +38,12 @@ pub use serial_test; #[macro_export] macro_rules! testgen_all { () => { - use burn_tensor::{Float, Int}; - $crate::testgen_all!([Float], [Int]); + use burn_tensor::{Float, Int, Bool}; + $crate::testgen_all!([Float], [Int], [Bool]); }; - ([$($float:ident),*], [$($int:ident),*]) => { + ([$($float:ident),*], [$($int:ident),*], [$($bool:ident),*]) => { mod jit { - burn_jit::testgen_jit!([$($float),*], [$($int),*]); + burn_jit::testgen_jit!([$($float),*], [$($int),*], [$($bool),*]); mod kernel { use super::*; @@ -84,7 +84,7 @@ macro_rules! testgen_all { } } mod jit_fusion { - burn_jit::testgen_jit_fusion!([$($float),*], [$($int),*]); + burn_jit::testgen_jit_fusion!([$($float),*], [$($int),*], [$($bool),*]); } }; } @@ -92,31 +92,31 @@ macro_rules! testgen_all { #[macro_export] macro_rules! 
testgen_jit { () => { - use burn_tensor::{Float, Int}; - $crate::testgen_jit!([Float], [Int]); + use burn_tensor::{Float, Int, Bool}; + $crate::testgen_jit!([Float], [Int], [Bool]); }; - ([$($float:ident),*], [$($int:ident),*]) => { + ([$($float:ident),*], [$($int:ident),*], [$($bool:ident),*]) => { pub use super::*; use burn_jit::tests::{burn_autodiff, burn_ndarray, burn_tensor, serial_test}; - pub type TestBackend = JitBackend; - pub type TestBackend2 = JitBackend; + pub type TestBackend = JitBackend; + pub type TestBackend2 = JitBackend; pub type ReferenceBackend = burn_ndarray::NdArray; pub type TestTensor = burn_tensor::Tensor; - pub type TestTensor2 = burn_tensor::Tensor, D>; + pub type TestTensor2 = burn_tensor::Tensor, D>; pub type TestTensorInt = burn_tensor::Tensor; - pub type TestTensorInt2 = - burn_tensor::Tensor, D, burn_tensor::Int>; + pub type TestTensorInt2 = + burn_tensor::Tensor, D, burn_tensor::Int>; pub type TestTensorBool = burn_tensor::Tensor; - pub type TestTensorBool2 = - burn_tensor::Tensor, D, burn_tensor::Bool>; + pub type TestTensorBool2 = + burn_tensor::Tensor, D, burn_tensor::Bool>; pub type ReferenceTensor = burn_tensor::Tensor; - burn_tensor::testgen_all!([$($float),*], [$($int),*]); + burn_tensor::testgen_all!([$($float),*], [$($int),*], [$($bool),*]); burn_autodiff::testgen_all!([$($float),*]); // Not all ops are implemented for quantization yet, notably missing: @@ -135,28 +135,28 @@ macro_rules! testgen_jit_fusion { use burn_tensor::{Float, Int}; $crate::testgen_jit_fusion!([Float], [Int]); }; - ([$($float:ident),*], [$($int:ident),*]) => { + ([$($float:ident),*], [$($int:ident),*], [$($bool:ident),*]) => { use super::*; use burn_jit::tests::{burn_autodiff, burn_fusion, burn_ndarray, burn_tensor}; - pub type TestBackend = burn_fusion::Fusion>; - pub type TestBackend2 = burn_fusion::Fusion>; + pub type TestBackend = burn_fusion::Fusion>; + pub type TestBackend2 = burn_fusion::Fusion>; pub type ReferenceBackend = burn_ndarray::NdArray; pub type TestTensor = burn_tensor::Tensor; - pub type TestTensor2 = burn_tensor::Tensor, D>; + pub type TestTensor2 = burn_tensor::Tensor, D>; pub type TestTensorInt = burn_tensor::Tensor; - pub type TestTensorInt2 = - burn_tensor::Tensor, D, burn_tensor::Int>; + pub type TestTensorInt2 = + burn_tensor::Tensor, D, burn_tensor::Int>; pub type TestTensorBool = burn_tensor::Tensor; - pub type TestTensorBool2 = - burn_tensor::Tensor, D, burn_tensor::Bool>; + pub type TestTensorBool2 = + burn_tensor::Tensor, D, burn_tensor::Bool>; pub type ReferenceTensor = burn_tensor::Tensor; - burn_tensor::testgen_all!([$($float),*], [$($int),*]); + burn_tensor::testgen_all!([$($float),*], [$($int),*], [$($bool),*]); burn_autodiff::testgen_all!([$($float),*]); // Not all ops are implemented for quantization yet, notably missing: diff --git a/crates/burn-ndarray/src/backend.rs b/crates/burn-ndarray/src/backend.rs index 74957f5f1f..060899b979 100644 --- a/crates/burn-ndarray/src/backend.rs +++ b/crates/burn-ndarray/src/backend.rs @@ -53,6 +53,7 @@ impl Backend for type IntElem = I; type BoolTensorPrimitive = NdArrayTensor; + type BoolElem = bool; type QuantizedTensorPrimitive = NdArrayQTensor; type QuantizedEncoding = Q; diff --git a/crates/burn-remote/src/client/channel.rs b/crates/burn-remote/src/client/channel.rs index 6c431702af..d7102dd97a 100644 --- a/crates/burn-remote/src/client/channel.rs +++ b/crates/burn-remote/src/client/channel.rs @@ -19,6 +19,8 @@ impl RunnerChannel for WsChannel { type IntElem = i32; + type BoolElem = u32; + fn 
name() -> String { "remote".into() } diff --git a/crates/burn-router/src/backend.rs b/crates/burn-router/src/backend.rs index 6fcf80ce3d..a5ada5e5fd 100644 --- a/crates/burn-router/src/backend.rs +++ b/crates/burn-router/src/backend.rs @@ -55,6 +55,8 @@ impl Backend for BackendRouter { type BoolTensorPrimitive = RouterTensor; + type BoolElem = R::BoolElem; + type QuantizedTensorPrimitive = RouterTensor; type QuantizedEncoding = u32; diff --git a/crates/burn-router/src/channel/base.rs b/crates/burn-router/src/channel/base.rs index 876d273f62..887190ecfa 100644 --- a/crates/burn-router/src/channel/base.rs +++ b/crates/burn-router/src/channel/base.rs @@ -18,6 +18,8 @@ pub trait RunnerChannel: Clone + Send + Sync + 'static + Sized { type FloatElem: Element; /// Int element type. type IntElem: Element; + /// Bool element type. + type BoolElem: Element; /// Name of the channel. fn name() -> String; diff --git a/crates/burn-router/src/types.rs b/crates/burn-router/src/types.rs index 3b694d8779..f36e436638 100644 --- a/crates/burn-router/src/types.rs +++ b/crates/burn-router/src/types.rs @@ -206,6 +206,7 @@ macro_rules! impl_multi_backend_types { type FloatElem = $DefaultBackend::FloatElem; type IntElem = $DefaultBackend::IntElem; + type BoolElem = $DefaultBackend::BoolElem; type Client = MultiRunnerClient<$DefaultBackend, $($OtherBackend),+>; diff --git a/crates/burn-tch/src/backend.rs b/crates/burn-tch/src/backend.rs index cef9a8d586..c294ae0025 100644 --- a/crates/burn-tch/src/backend.rs +++ b/crates/burn-tch/src/backend.rs @@ -101,6 +101,7 @@ impl Backend for LibTorch { type IntElem = i64; type BoolTensorPrimitive = TchTensor; + type BoolElem = bool; type QuantizedTensorPrimitive = TchQTensor; type QuantizedEncoding = Q; diff --git a/crates/burn-tensor/src/tensor/backend/base.rs b/crates/burn-tensor/src/tensor/backend/base.rs index f951a9b6f3..973f4ede65 100644 --- a/crates/burn-tensor/src/tensor/backend/base.rs +++ b/crates/burn-tensor/src/tensor/backend/base.rs @@ -83,6 +83,8 @@ pub trait Backend: /// Tensor primitive to be used for all bool operations. type BoolTensorPrimitive: TensorMetadata + 'static; + /// Tensor primitive to be used for all bool operations. + type BoolElem: Element; /// Tensor primitive to be used for all quantized operations. type QuantizedTensorPrimitive: TensorMetadata + QTensorPrimitive + 'static; diff --git a/crates/burn-tensor/src/tests/mod.rs b/crates/burn-tensor/src/tests/mod.rs index 5b15a45591..8aa41ee24d 100644 --- a/crates/burn-tensor/src/tests/mod.rs +++ b/crates/burn-tensor/src/tests/mod.rs @@ -17,28 +17,28 @@ macro_rules! testgen_all { pub type FloatType = ::FloatElem; pub type IntType = ::IntElem; - pub type BoolType = ::BoolTensorPrimitive; + pub type BoolType = ::BoolElem; $crate::testgen_with_float_param!(); $crate::testgen_no_param!(); } }; - ([$($float:ident),*], [$($int:ident),*]) => { + ([$($float:ident),*], [$($int:ident),*], [$($bool:ident),*]) => { pub mod tensor { pub use super::*; pub type FloatType = ::FloatElem; pub type IntType = ::IntElem; - pub type BoolType = ::BoolTensorPrimitive; + pub type BoolType = ::BoolElem; ::paste::paste! 
{ $(mod [<$float _ty>] { pub use super::*; - pub type TestBackend = TestBackend2<$float, IntType>; - pub type TestTensor = TestTensor2<$float, IntType, D>; - pub type TestTensorInt = TestTensorInt2<$float, IntType, D>; - pub type TestTensorBool = TestTensorBool2<$float, IntType, D>; + pub type TestBackend = TestBackend2<$float, IntType, BoolType>; + pub type TestTensor = TestTensor2<$float, IntType, BoolType, D>; + pub type TestTensorInt = TestTensorInt2<$float, IntType, BoolType, D>; + pub type TestTensorBool = TestTensorBool2<$float, IntType, BoolType, D>; pub type FloatType = $float; @@ -47,13 +47,25 @@ macro_rules! testgen_all { $(mod [<$int _ty>] { pub use super::*; - pub type TestBackend = TestBackend2; - pub type TestTensor = TestTensor2; - pub type TestTensorInt = TestTensorInt2; - pub type TestTensorBool = TestTensorBool2; + pub type TestBackend = TestBackend2; + pub type TestTensor = TestTensor2; + pub type TestTensorInt = TestTensorInt2; + pub type TestTensorBool = TestTensorBool2; pub type IntType = $int; + $crate::testgen_with_int_param!(); + })* + $(mod [<$bool _bool_ty>] { + pub use super::*; + + pub type TestBackend = TestBackend2; + pub type TestTensor = TestTensor2; + pub type TestTensorInt = TestTensorInt2; + pub type TestTensorBool = TestTensorBool2; + + pub type BoolType = $bool; + $crate::testgen_with_int_param!(); })* } @@ -307,6 +319,29 @@ macro_rules! testgen_with_int_param { }; } +#[allow(missing_docs)] +#[macro_export] +macro_rules! testgen_with_bool_param { + () => { + burn_tensor::testgen_all_op!(); + burn_tensor::testgen_any_op!(); + burn_tensor::testgen_argwhere_nonzero!(); + burn_tensor::testgen_cast!(); + burn_tensor::testgen_cat!(); + burn_tensor::testgen_expand!(); + burn_tensor::testgen_full!(); + burn_tensor::testgen_map_comparison!(); + burn_tensor::testgen_mask!(); + burn_tensor::testgen_nan!(); + burn_tensor::testgen_repeat_dim!(); + burn_tensor::testgen_repeat!(); + burn_tensor::testgen_reshape!(); + burn_tensor::testgen_stack!(); + burn_tensor::testgen_transpose!(); + burn_tensor::tri_mask!(); + }; +} + #[allow(missing_docs)] #[macro_export] macro_rules! 
testgen_no_param { diff --git a/crates/burn-tensor/src/tests/ops/remainder.rs b/crates/burn-tensor/src/tests/ops/remainder.rs index fa75630fe8..996c71a7b7 100644 --- a/crates/burn-tensor/src/tests/ops/remainder.rs +++ b/crates/burn-tensor/src/tests/ops/remainder.rs @@ -67,7 +67,7 @@ mod tests { fn should_be_zero() { let device = Default::default(); let lhs = Tensor::::from_data(TensorData::from([0.0, 0.0, 0.0]), &device); - let rhs = Tensor::::from_data(TensorData::from([3.5, -2.1, 1e-5]), &device); + let rhs = Tensor::::from_data(TensorData::from([3.5, -2.1, 1e-4]), &device); let output = lhs.remainder(rhs); let expected = TensorData::from([0.0, 0.0, 0.0]); diff --git a/crates/burn-wgpu/src/lib.rs b/crates/burn-wgpu/src/lib.rs index 7c26dcc31b..0751ad9f41 100644 --- a/crates/burn-wgpu/src/lib.rs +++ b/crates/burn-wgpu/src/lib.rs @@ -10,7 +10,7 @@ pub use burn_jit::{ }; pub use burn_jit::{tensor::JitTensor, JitBackend}; -pub use burn_jit::{FloatElement, IntElement}; +pub use burn_jit::{BoolElement, FloatElement, IntElement}; pub use cubecl::flex32; pub use cubecl::ir::CubeDim; pub use cubecl::wgpu::*; @@ -21,8 +21,12 @@ pub type SpirV = cubecl::wgpu::spirv::VkSpirvCompiler; #[cfg(feature = "spirv")] type Compiler = SpirV; +#[cfg(feature = "spirv")] +type Byte = u8; #[cfg(not(feature = "spirv"))] type Compiler = Wgsl; +#[cfg(not(feature = "spirv"))] +type Bool = u32; #[cfg(feature = "fusion")] /// Tensor backend that uses the wgpu crate for executing GPU compute shaders. @@ -56,8 +60,8 @@ type Compiler = Wgsl; /// /// You can disable the `fusion` feature flag to remove that functionality, which might be /// necessary on `wasm` for now. -pub type Wgpu = - burn_fusion::Fusion, F, I>>; +pub type Wgpu = + burn_fusion::Fusion, F, I, B>>; #[cfg(not(feature = "fusion"))] /// Tensor backend that uses the wgpu crate for executing GPU compute shaders. @@ -91,7 +95,8 @@ pub type Wgpu = /// /// You can enable the `fusion` feature flag to add that functionality, which might improve /// performance. 
-pub type Wgpu = JitBackend, F, I>; +pub type Wgpu = + JitBackend, F, I, B>; #[cfg(test)] mod tests { @@ -103,7 +108,7 @@ mod tests { // Don't test `flex32` for now, burn sees it as `f32` but is actually `f16` precision, so it // breaks a lot of tests from precision issues #[cfg(feature = "spirv")] - burn_jit::testgen_all!([f16, f32], [i8, i16, i32, i64]); + burn_jit::testgen_all!([f16, f32], [i8, i16, i32, i64], [u8, u32]); #[cfg(not(feature = "spirv"))] - burn_jit::testgen_all!([f32], [i32]); + burn_jit::testgen_all!([f32], [i32], [u32]); } diff --git a/examples/custom-cubecl-kernel/examples/custom-cubecl-kernel.rs b/examples/custom-cubecl-kernel/examples/custom-cubecl-kernel.rs index 6cb9e6a6ae..de6bfcc7d4 100644 --- a/examples/custom-cubecl-kernel/examples/custom-cubecl-kernel.rs +++ b/examples/custom-cubecl-kernel/examples/custom-cubecl-kernel.rs @@ -71,7 +71,7 @@ fn autodiff(device: &B::Device) { } fn main() { - type MyBackend = burn::backend::wgpu::JitBackend; + type MyBackend = burn::backend::wgpu::JitBackend; type MyAutodiffBackend = burn::backend::Autodiff; let device = Default::default(); inference::(&device); diff --git a/examples/custom-cubecl-kernel/src/backward.rs b/examples/custom-cubecl-kernel/src/backward.rs index 3c66ae8e0e..a894f4e446 100644 --- a/examples/custom-cubecl-kernel/src/backward.rs +++ b/examples/custom-cubecl-kernel/src/backward.rs @@ -10,10 +10,10 @@ use burn::{ }, tensor::{Shape, TensorMetadata}, }; -use burn_jit::{FloatElement, IntElement, JitBackend, JitRuntime}; +use burn_jit::{element::BoolElement, FloatElement, IntElement, JitBackend, JitRuntime}; -impl AutodiffBackend - for Autodiff> +impl AutodiffBackend + for Autodiff> { } diff --git a/examples/custom-cubecl-kernel/src/forward.rs b/examples/custom-cubecl-kernel/src/forward.rs index a8bf17fcd7..0e180e231a 100644 --- a/examples/custom-cubecl-kernel/src/forward.rs +++ b/examples/custom-cubecl-kernel/src/forward.rs @@ -3,12 +3,15 @@ use crate::{kernel::fused_matmul_add_relu_kernel, FloatTensor}; use super::Backend; use burn::tensor::Shape; use burn_jit::{ - kernel::into_contiguous, tensor::JitTensor, FloatElement, IntElement, JitBackend, JitRuntime, + element::BoolElement, kernel::into_contiguous, tensor::JitTensor, FloatElement, IntElement, + JitBackend, JitRuntime, }; use cubecl::{CubeCount, CubeDim}; /// Implement our custom backend trait for the generic `JitBackend`. 
-impl Backend for JitBackend { +impl Backend + for JitBackend +{ fn fused_matmul_add_relu( lhs: FloatTensor, rhs: FloatTensor, diff --git a/examples/custom-wgpu-kernel/examples/custom-wgpu-kernel.rs b/examples/custom-wgpu-kernel/examples/custom-wgpu-kernel.rs index a309ea5716..0c2201080e 100644 --- a/examples/custom-wgpu-kernel/examples/custom-wgpu-kernel.rs +++ b/examples/custom-wgpu-kernel/examples/custom-wgpu-kernel.rs @@ -71,7 +71,7 @@ fn autodiff(device: &B::Device) { } fn main() { - type MyBackend = burn::backend::wgpu::JitBackend; + type MyBackend = burn::backend::wgpu::JitBackend; type MyAutodiffBackend = burn::backend::Autodiff; let device = Default::default(); inference::(&device); diff --git a/examples/custom-wgpu-kernel/src/backward.rs b/examples/custom-wgpu-kernel/src/backward.rs index b9032413bc..eb374d6c10 100644 --- a/examples/custom-wgpu-kernel/src/backward.rs +++ b/examples/custom-wgpu-kernel/src/backward.rs @@ -9,12 +9,15 @@ use burn::{ ops::{broadcast_shape, Backward, Ops, OpsKind}, Autodiff, NodeID, }, - wgpu::{FloatElement, IntElement, JitBackend, WgpuRuntime}, + wgpu::{BoolElement, FloatElement, IntElement, JitBackend, WgpuRuntime}, }, tensor::{Shape, TensorMetadata}, }; -impl AutodiffBackend for Autodiff> {} +impl AutodiffBackend + for Autodiff> +{ +} // Implement our custom backend trait for any backend that also implements our custom backend trait. // diff --git a/examples/custom-wgpu-kernel/src/forward.rs b/examples/custom-wgpu-kernel/src/forward.rs index c8476230d2..e257e13bf0 100644 --- a/examples/custom-wgpu-kernel/src/forward.rs +++ b/examples/custom-wgpu-kernel/src/forward.rs @@ -3,8 +3,8 @@ use crate::FloatTensor; use super::Backend; use burn::{ backend::wgpu::{ - build_info, into_contiguous, kernel_source, FloatElement, IntElement, JitBackend, - JitTensor, KernelSource, SourceKernel, SourceTemplate, WgpuRuntime, + build_info, into_contiguous, kernel_source, BoolElement, FloatElement, IntElement, + JitBackend, JitTensor, KernelSource, SourceKernel, SourceTemplate, WgpuRuntime, }, tensor::Shape, }; @@ -41,7 +41,9 @@ impl KernelSource for FusedMatmulAddRelu { } /// Implement our custom backend trait for the existing backend `WgpuBackend`. -impl Backend for JitBackend { +impl Backend + for JitBackend +{ fn fused_matmul_add_relu( lhs: FloatTensor, rhs: FloatTensor,
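// The mask-where kernels in this diff drop the `if mask[...] != Line::new(reverse)`
// branch and write every element through `select(cond, value, input)` instead,
// which keeps the kernel branch-free and lets it operate on vectorized lines.
// Below is a minimal CPU reference for the same semantics, using illustrative
// names only (this is not burn-jit's API): `reverse = true` corresponds to
// MaskWhereStrategy::InplaceRhs, where `input` and `value` are swapped and the
// mask condition is inverted.
fn mask_where_ref<E: Copy>(input: &mut [E], mask: &[u8], value: &[E], reverse: bool) {
    assert_eq!(input.len(), mask.len());
    assert_eq!(input.len(), value.len());
    for i in 0..input.len() {
        // Equivalent of `select(mask != reverse, value[i], input[i])`:
        // the store happens unconditionally, only the selected source changes.
        let take_value = (mask[i] != 0) != reverse;
        input[i] = if take_value { value[i] } else { input[i] };
    }
}

#[test]
fn mask_where_ref_matches_expected() {
    let mut input = [1.0f32, 2.0, 3.0, 4.0];
    let value = [10.0f32, 20.0, 30.0, 40.0];
    let mask = [1u8, 0, 1, 0];
    mask_where_ref(&mut input, &mask, &value, false);
    assert_eq!(input, [10.0, 2.0, 30.0, 4.0]);
}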
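// The same launches switch from a hard-coded vectorization factor of 1 to
// `max_vectorization(&input)`. Conceptually this asks the runtime for the
// widest supported line width that evenly divides the tensor's contiguous
// inner extent; the function below is a simplified stand-in under that
// assumption, not cubecl's actual implementation.
fn max_vectorization_sketch(inner_len: usize, supported_widths: &[usize]) -> usize {
    supported_widths
        .iter()
        .copied()
        .filter(|w| inner_len % w == 0)
        .max()
        .unwrap_or(1)
}

#[test]
fn picks_widest_dividing_width() {
    assert_eq!(max_vectorization_sketch(12, &[1, 2, 4, 8]), 4);
    assert_eq!(max_vectorization_sketch(7, &[1, 2, 4, 8]), 1);
}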
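// `bool_into_data` now reinterprets the raw device bytes as the backend's bool
// element `BT` and maps `element != BT::false_val()` to a Rust `bool`, instead
// of assuming `u32`. Likewise, `bool_not` becomes `equal_elem(tensor, BT::false_val())`,
// so any non-false encoding counts as true. A minimal sketch of both steps for
// the two representations appearing in this diff (u32 and u8); helper names are
// illustrative, not the crate's API.
fn decode_bools_u32(bytes: &[u8]) -> Vec<bool> {
    bytes
        .chunks_exact(4)
        .map(|c| u32::from_ne_bytes([c[0], c[1], c[2], c[3]]) != 0)
        .collect()
}

fn decode_bools_u8(bytes: &[u8]) -> Vec<bool> {
    bytes.iter().map(|&b| b != 0).collect()
}

// Logical NOT expressed as "equals the false value", mirroring bool_not above.
fn bool_not_ref(elems: &[u32], false_val: u32) -> Vec<bool> {
    elems.iter().map(|&e| e == false_val).collect()
}

#[test]
fn decoding_is_consistent_across_representations() {
    let logical = [true, false, true, false];
    let as_u32: Vec<u8> = logical
        .iter()
        .flat_map(|&b| (b as u32).to_ne_bytes())
        .collect();
    let as_u8: Vec<u8> = logical.iter().map(|&b| b as u8).collect();
    assert_eq!(decode_bools_u32(&as_u32), logical);
    assert_eq!(decode_bools_u8(&as_u8), logical);
    assert_eq!(bool_not_ref(&[1, 0, 1, 0], 0), [false, true, false, true]);
}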
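// Consolidated view of how the bool element is picked for the wgpu backend in
// crates/burn-wgpu/src/lib.rs above: the non-SPIR-V (WGSL) path keeps u32,
// since WGSL has no 8-bit integer type, while the SPIR-V path can use u8.
// Illustrative stand-in with a hypothetical `spirv` feature flag mirroring
// that cfg-gating; the alias name here is not the crate's identifier.
#[cfg(feature = "spirv")]
type DefaultBool = u8;
#[cfg(not(feature = "spirv"))]
type DefaultBool = u32;

fn bytes_per_bool_element() -> usize {
    core::mem::size_of::<DefaultBool>()
}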
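// Finally, the testgen macros gain a third bracketed list for bool element
// types, so the shared tensor and autodiff suites are instantiated once per
// bool representation in addition to each float and int type. The invocations
// from the wgpu tests in this diff, kept here as comments because the macros
// only expand inside the crates' test modules:
//
//     #[cfg(feature = "spirv")]
//     burn_jit::testgen_all!([f16, f32], [i8, i16, i32, i64], [u8, u32]);
//     #[cfg(not(feature = "spirv"))]
//     burn_jit::testgen_all!([f32], [i32], [u32]);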