diff --git a/backend-comparison/benches/matmul.rs b/backend-comparison/benches/matmul.rs
index 59dffaad8a..6970408a58 100644
--- a/backend-comparison/benches/matmul.rs
+++ b/backend-comparison/benches/matmul.rs
@@ -47,9 +47,10 @@ fn bench(
     token: Option<&str>,
 ) {
     let benchmarks = [
-        // (3, 4096, 4096, 4096),
+        (2, 4096, 4096, 4096),
         (32, 2048, 2048, 2048),
-        // (2, 4096, 4096, 512),
+        (256, 1024, 1024, 1024),
+        (1024, 256, 256, 256),
     ]
     .into_iter()
     .map(|(b, m, n, k)| {
diff --git a/crates/burn-autodiff/src/tests/select.rs b/crates/burn-autodiff/src/tests/select.rs
index 26a7824115..04f3bd071a 100644
--- a/crates/burn-autodiff/src/tests/select.rs
+++ b/crates/burn-autodiff/src/tests/select.rs
@@ -17,7 +17,7 @@ mod tests {
         let tensor_2 = tensor_1.clone().matmul(tensor_1.clone().transpose());
         let tensor_3 = tensor_1.clone().select(0, indices);
         let tensor_4 = tensor_2.matmul(tensor_3);
-        // panic!("Tensor 4 {}", tensor_4);
+
         let grads = tensor_4.backward();
 
         let grad_1 = tensor_1.grad(&grads).unwrap();
diff --git a/crates/burn-jit/src/fusion/base.rs b/crates/burn-jit/src/fusion/base.rs
index 77a3a81a96..a5f154faa3 100644
--- a/crates/burn-jit/src/fusion/base.rs
+++ b/crates/burn-jit/src/fusion/base.rs
@@ -32,7 +32,7 @@ pub enum JitOptimization {
 pub enum JitOptimizationState {
     /// Element wise state.
     ElementWise(ElemwiseOptimizationState),
-    /// Matrix multiplication optimizatio state.
+    /// Matrix multiplication optimization state.
     Matmul(MatmulOptimizationState),
 }
 
diff --git a/crates/burn-jit/src/fusion/matmul/args.rs b/crates/burn-jit/src/fusion/matmul/args.rs
index d400f332bf..8aa8067632 100644
--- a/crates/burn-jit/src/fusion/matmul/args.rs
+++ b/crates/burn-jit/src/fusion/matmul/args.rs
@@ -6,52 +6,8 @@ use crate::fusion::on_write::{
     kernel::fuse_on_write,
 };
 
-pub struct FusedMatmulState {
-    inputs: *const GlobalArgs,
-    outputs: *mut GlobalArgs,
-    config: ElemwiseConfig,
-    lhs: Arg,
-    rhs: Arg,
-    out: Arg,
-}
-
-#[cube]
-impl FusedMatmulState {
-    pub fn new(
-        inputs: &FusedMatmulInput,
-        outputs: &mut GlobalArgs,
-        #[comptime] config: &ElemwiseConfig,
-    ) -> FusedMatmulState {
-        FusedMatmulState {
-            inputs: &inputs.global,
-            outputs,
-            config: comptime![config.clone()],
-            lhs: comptime![inputs.lhs],
-            rhs: comptime![inputs.rhs],
-            out: comptime![inputs.out],
-        }
-    }
-}
-
 #[derive(Clone)]
-pub struct FusedMatmulStateExpand {
-    inputs: GlobalArgsExpand,
-    outputs: GlobalArgsExpand,
-    config: ElemwiseConfig,
-    lhs: Arg,
-    rhs: Arg,
-    out: Arg,
-}
-
-impl CubeType for FusedMatmulState {
-    type ExpandType = FusedMatmulStateExpand;
-}
-
-impl Init for FusedMatmulStateExpand {
-    fn init(self, _context: &mut CubeContext) -> Self {
-        self
-    }
-}
+pub struct FusedMatmulArgs;
 
 #[derive(CubeLaunch)]
 pub struct FusedMatmulInput {
@@ -66,9 +22,6 @@ pub struct FusedMatmulInput {
     out: Arg,
 }
 
-#[derive(Clone)]
-pub struct FusedMatmulArgs;
-
 #[cube]
 impl MatmulArgs for FusedMatmulArgs {
     type Output = GlobalArgs;
@@ -248,3 +201,49 @@ impl MatmulArgs for FusedMatmulArgs {
         }
     }
 }
+
+pub struct FusedMatmulState {
+    inputs: *const GlobalArgs,
+    outputs: *mut GlobalArgs,
+    config: ElemwiseConfig,
+    lhs: Arg,
+    rhs: Arg,
+    out: Arg,
+}
+#[cube]
+impl FusedMatmulState {
+    pub fn new(
+        inputs: &FusedMatmulInput,
+        outputs: &mut GlobalArgs,
+        #[comptime] config: &ElemwiseConfig,
+    ) -> FusedMatmulState {
+        FusedMatmulState {
+            inputs: &inputs.global,
+            outputs,
+            config: comptime![config.clone()],
+            lhs: comptime![inputs.lhs],
+            rhs: comptime![inputs.rhs],
+            out: comptime![inputs.out],
+        }
+    }
+}
+
+#[derive(Clone)]
+pub struct FusedMatmulStateExpand {
+    inputs: GlobalArgsExpand,
+    outputs: GlobalArgsExpand,
+    config: ElemwiseConfig,
+    lhs: Arg,
+    rhs: Arg,
+    out: Arg,
+}
+
+impl CubeType for FusedMatmulState {
+    type ExpandType = FusedMatmulStateExpand;
+}
+
+impl Init for FusedMatmulStateExpand {
+    fn init(self, _context: &mut CubeContext) -> Self {
+        self
+    }
+}
diff --git a/crates/burn-jit/src/fusion/matmul/optimization.rs b/crates/burn-jit/src/fusion/matmul/optimization.rs
index b810f80c6c..66c372564a 100644
--- a/crates/burn-jit/src/fusion/matmul/optimization.rs
+++ b/crates/burn-jit/src/fusion/matmul/optimization.rs
@@ -83,7 +83,7 @@ impl<R: JitRuntime> MatmulOptimization<R> {
     fn execute_fused<BT: BoolElement>(
         &mut self,
         context: &mut Context<'_, JitFusionHandle<R>>,
-    ) -> Result<(), MatmulLaunchError> {
+    ) -> Result<(), FusedMatmulError> {
        self.trace
            .run::<R, BT, FusedMatmul>(&self.client, &self.device, context, &self.matmul)
    }
@@ -141,8 +141,20 @@ pub struct FusedMatmul {
     op: BinaryOperationDescription,
 }
 
+#[derive(Debug)]
+pub enum FusedMatmulError {
+    LaunchError(MatmulLaunchError),
+    InvalidInput,
+}
+
+impl From<MatmulLaunchError> for FusedMatmulError {
+    fn from(value: MatmulLaunchError) -> Self {
+        Self::LaunchError(value)
+    }
+}
+
 impl<R: JitRuntime> TraceRunner<R> for FusedMatmul {
-    type Error = MatmulLaunchError;
+    type Error = FusedMatmulError;
 
     fn run<'a>(
         &'a self,
@@ -150,7 +162,7 @@ impl<R: JitRuntime> TraceRunner<R> for FusedMatmul {
         inputs: GlobalArgsLaunch<'a, R>,
         outputs: GlobalArgsLaunch<'a, R>,
         config: &'a ElemwiseConfig,
-    ) -> Result<(), MatmulLaunchError> {
+    ) -> Result<(), FusedMatmulError> {
         match self.out.precision() {
             ElemwisePrecision::F32 => self.matmul_fused::<R, f32>(client, inputs, outputs, config),
             ElemwisePrecision::F16 => self.matmul_fused::<R, f16>(client, inputs, outputs, config),
@@ -169,7 +181,7 @@ impl FusedMatmul {
         inputs: GlobalArgsLaunch<'a, R>,
         outputs: GlobalArgsLaunch<'a, R>,
         config: &'a ElemwiseConfig,
-    ) -> Result<(), MatmulLaunchError> {
+    ) -> Result<(), FusedMatmulError> {
         let lhs_shape = inputs.shape(&self.lhs);
         let rhs_shape = inputs.shape(&self.rhs);
 
@@ -189,9 +201,7 @@ impl FusedMatmul {
         let (rhs_make_contiguous, rhs_transposed) = check_layout(rhs_strides);
 
         if lhs_make_contiguous || rhs_make_contiguous {
-            return Err(MatmulLaunchError::Unavailable(
-                MatmulAvailabilityError::PlaneDimUnknown,
-            ));
+            return Err(FusedMatmulError::InvalidInput);
         }
 
         let rank = lhs_shape.len();
@@ -209,9 +219,7 @@ impl FusedMatmul {
         };
 
         if out_line_size == 1 && (lhs_line_size > 1 || rhs_line_size > 1) {
-            return Err(MatmulLaunchError::Unavailable(
-                MatmulAvailabilityError::PlaneDimUnknown,
-            ));
+            return Err(FusedMatmulError::InvalidInput);
         }
 
         let problem = MatmulProblem {
@@ -261,7 +269,9 @@ impl FusedMatmul {
             None => Err(MatmulLaunchError::Unavailable(
                 MatmulAvailabilityError::PlaneDimUnknown,
             )),
-        }
+        }?;
+
+        Ok(())
     }
 }
 
diff --git a/crates/burn-jit/src/fusion/on_write/ir.rs b/crates/burn-jit/src/fusion/on_write/ir.rs
index 68e5cd9968..a01bec2e53 100644
--- a/crates/burn-jit/src/fusion/on_write/ir.rs
+++ b/crates/burn-jit/src/fusion/on_write/ir.rs
@@ -115,22 +115,37 @@ pub struct GlobalArgs {
 }
 
 impl<'a, R: Runtime> GlobalArgsLaunch<'a, R> {
+    /// Get the shape of the given [argument](Arg).
+    ///
+    /// # Panics
+    ///
+    /// If the argument doesn't have a handle.
     pub fn shape(&self, arg: &Arg) -> &[usize] {
-        match self.handle(arg) {
+        match self.resolve_arg(arg) {
             TensorArg::Handle { handle, .. } => handle.shape,
             TensorArg::Alias { .. } => panic!("Unsupported yet"),
         }
     }
 
+    /// Get the strides of the given [argument](Arg).
+    ///
+    /// # Panics
+    ///
+    /// If the argument doesn't have a handle.
     pub fn strides(&self, arg: &Arg) -> &[usize] {
-        match self.handle(arg) {
+        match self.resolve_arg(arg) {
             TensorArg::Handle { handle, .. } => handle.strides,
             TensorArg::Alias { .. } => panic!("Unsupported yet"),
         }
     }
 
+    /// Get the line size of the given [argument](Arg).
+    ///
+    /// # Panics
+    ///
+    /// If the argument doesn't have a handle.
     pub fn line_size(&self, arg: &Arg) -> u8 {
-        match self.handle(arg) {
+        match self.resolve_arg(arg) {
             TensorArg::Handle {
                 vectorization_factor,
                 ..
@@ -139,19 +154,40 @@ impl<'a, R: Runtime> GlobalArgsLaunch<'a, R> {
         }
     }
 
-    pub fn handle(&self, arg: &Arg) -> &TensorArg<'_, R> {
+    /// Resolve the [argument](Arg) to a [tensor argument](TensorArg).
+    ///
+    /// # Panics
+    ///
+    /// If the argument isn't a global input or output tensor.
+    pub fn resolve_arg(&self, arg: &Arg) -> &TensorArg<'_, R> {
         match arg {
             Arg::Input(pos, precision, _) => match precision {
                 ElemwisePrecision::F32 => &self.t_f32.values[*pos as usize],
                 ElemwisePrecision::F16 => &self.t_f16.values[*pos as usize],
                 ElemwisePrecision::BF16 => &self.t_bf16.values[*pos as usize],
-                _ => panic!(),
+                ElemwisePrecision::I64 => &self.t_i64.values[*pos as usize],
+                ElemwisePrecision::I32 => &self.t_i32.values[*pos as usize],
+                ElemwisePrecision::I16 => &self.t_i16.values[*pos as usize],
+                ElemwisePrecision::I8 => &self.t_i8.values[*pos as usize],
+                ElemwisePrecision::U64 => &self.t_u64.values[*pos as usize],
+                ElemwisePrecision::U32 => &self.t_u32.values[*pos as usize],
+                ElemwisePrecision::U16 => &self.t_u16.values[*pos as usize],
+                ElemwisePrecision::U8 => &self.t_u8.values[*pos as usize],
+                ElemwisePrecision::Bool => panic!("Unsupported yet"),
             },
             Arg::Output(pos, precision, _) => match precision {
                 ElemwisePrecision::F32 => &self.t_f32.values[*pos as usize],
                 ElemwisePrecision::F16 => &self.t_f16.values[*pos as usize],
                 ElemwisePrecision::BF16 => &self.t_bf16.values[*pos as usize],
-                _ => panic!(),
+                ElemwisePrecision::I64 => &self.t_i64.values[*pos as usize],
+                ElemwisePrecision::I32 => &self.t_i32.values[*pos as usize],
+                ElemwisePrecision::I16 => &self.t_i16.values[*pos as usize],
+                ElemwisePrecision::I8 => &self.t_i8.values[*pos as usize],
+                ElemwisePrecision::U64 => &self.t_u64.values[*pos as usize],
+                ElemwisePrecision::U32 => &self.t_u32.values[*pos as usize],
+                ElemwisePrecision::U16 => &self.t_u16.values[*pos as usize],
+                ElemwisePrecision::U8 => &self.t_u8.values[*pos as usize],
+                ElemwisePrecision::Bool => panic!("Unsupported yet"),
             },
             _ => panic!("Only input & output can have a shape"),
         }
diff --git a/crates/burn-jit/src/kernel/conv/conv2d/gemm/spec.rs b/crates/burn-jit/src/kernel/conv/conv2d/gemm/spec.rs
index ca0ca83572..ad2069c505 100644
--- a/crates/burn-jit/src/kernel/conv/conv2d/gemm/spec.rs
+++ b/crates/burn-jit/src/kernel/conv/conv2d/gemm/spec.rs
@@ -1,8 +1,7 @@
 use cubecl::prelude::Numeric;
 use std::marker::PhantomData;
 
-/// Matrix multiplication spec definiting each element types used in the computation as well as
-/// how the arguments are passed to the kernel.
+/// Implicit convolution spec defining the element types used in the computation.
 pub trait ConvSpec: Send + Sync + Clone + 'static {
     /// The plane size used by this kernel.
     const PLANE_DIM: u32;
@@ -15,7 +14,7 @@ pub trait ConvSpec: Send + Sync + Clone + 'static {
     type EA: Numeric;
 }
 
-/// Specification for a simple standard matmul using global tensor as inputs.
+/// Specification for a single conv using global tensors as inputs.
 #[derive(Clone)]
 pub struct SingleConvSpec {
     _eg: PhantomData,
diff --git a/crates/burn-tensor/src/lib.rs b/crates/burn-tensor/src/lib.rs
index 426f520818..d3cb280e90 100644
--- a/crates/burn-tensor/src/lib.rs
+++ b/crates/burn-tensor/src/lib.rs
@@ -67,10 +67,6 @@ mod cube_wgpu {
                 WgpuDevice::VirtualGpu(index) => DeviceId::new(2, *index as u32),
                 WgpuDevice::Cpu => DeviceId::new(3, 0),
                 WgpuDevice::BestAvailable | WgpuDevice::DefaultDevice => DeviceId::new(4, 0),
-                // For an existing device, use the 64 bit wgpu device ID as the burn DeviceID.
-                // We're only storing 32 bits, so wrap the the 64 bit value to 32 bits. This
-                // might collide - but a 1 in 4 billion chance seems ok given there's only a few
-                // devices in flight at any time.
                 WgpuDevice::Existing(id) => DeviceId::new(5, *id),
             }
         }
diff --git a/crates/burn-tensor/src/repr/handle.rs b/crates/burn-tensor/src/repr/handle.rs
index e8b7253db2..3da083c0e6 100644
--- a/crates/burn-tensor/src/repr/handle.rs
+++ b/crates/burn-tensor/src/repr/handle.rs
@@ -56,7 +56,6 @@ impl HandleContainer {
 
     /// Register a handle for the given [tensor id](TensorId).
     pub fn register_handle(&mut self, id: TensorId, handle: H) {
-        // println!("Register handle {id:?}");
         self.handles.insert(id, Handle::Existing(handle));
     }
 
@@ -68,7 +67,6 @@ impl HandleContainer {
     /// Make sure the status corresponds to the operation you want to execute the handle on,
     /// otherwise you might remove a tensor handle that will be required in the future.
     pub fn get_handle(&mut self, id: &TensorId, status: &TensorStatus) -> H {
-        // println!("Get handle {id:?}");
         let (id, handle) = self
             .handles
             .remove_entry(id)
@@ -148,7 +146,6 @@ impl HandleContainer {
     where
         B: ReprBackend,
     {
-        // println!("Register float tensor {id:?}");
         let handle = B::float_tensor_handle(tensor);
         self.handles.insert(*id, Handle::Existing(handle));
     }