diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs
index 27b2824fc7..1813f27681 100644
--- a/candle-core/src/metal_backend.rs
+++ b/candle-core/src/metal_backend.rs
@@ -782,12 +782,72 @@ impl BackendStorage for MetalStorage {
     fn conv1d(
         &self,
-        _l: &Layout,
-        _kernel: &Self,
-        _kernel_l: &Layout,
-        _params: &ParamsConv1D,
+        layout: &Layout,
+        kernel: &Self,
+        kernel_l: &Layout,
+        params: &ParamsConv1D,
     ) -> Result<Self> {
-        crate::bail!("conv1d metal")
+        let device = self.device().clone();
+        let shape = layout.shape();
+        let dims = shape.dims();
+        let strides = layout.stride();
+
+        let stride = params.stride;
+        let dilation = params.dilation;
+        let padding = params.padding;
+        let k_size = params.k_size;
+        let l_out = (dims[2] + 2 * padding - dilation * (k_size - 1) - 1) / stride + 1;
+        let dst_el = dims[0] * l_out * dims[1] * k_size;
+        let dst = self
+            .device
+            .new_buffer(dst_el, self.dtype, "conv1d_im2col")?;
+        let command_buffer = self.device.command_buffer()?;
+        let name = match self.dtype {
+            DType::F32 => "im2col1d_f32",
+            dtype => crate::bail!("conv1d metal {dtype:?} not implemented"),
+        };
+        candle_metal_kernels::call_im2col1d_strided(
+            &self.device.device,
+            &command_buffer,
+            &self.device.kernels,
+            name,
+            layout.shape().dims(),
+            strides,
+            (k_size, stride, padding, dilation),
+            &self.buffer,
+            layout.start_offset() * self.dtype.size_in_bytes(),
+            &dst,
+        )
+        .map_err(MetalError::from)?;
+        let col = Self {
+            buffer: dst,
+            device,
+            dtype: self.dtype,
+        };
+        let l_out = params.l_out();
+        let b = params.b_size;
+        let n = params.c_out;
+        let k = params.k_size * params.c_in;
+        let m = l_out;
+        let col_l = Layout::contiguous((b, m, k));
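+        // The im2col buffer is read back as a (b, m, k) matrix and contracted with
+        // the kernel viewed as (b, k, n), so the whole convolution becomes a single
+        // batched matmul with m = l_out, n = c_out and k = c_in * k_size.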
+        let res = if kernel_l.is_contiguous() {
+            let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
+                .transpose(1, 2)?
+                .broadcast_as((b, k, n))?;
+            col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
+        } else {
+            // Make the kernel contiguous if not already the case.
+            let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
+            kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
+            // The contiguous copy starts at offset 0, so matmul against kernel_c
+            // with a plain contiguous layout rather than the original offset.
+            let kernel_l = Layout::contiguous((1, n, k))
+                .transpose(1, 2)?
+                .broadcast_as((b, k, n))?;
+            col.matmul(&kernel_c, (b, m, n, k), &col_l, &kernel_l)?
+        };
+        let res_l = Layout::contiguous((b, l_out, n)).transpose(1, 2)?;
+        let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
+        res.copy_strided_src(&mut res_t, 0, &res_l)?;
+        Ok(res_t)
     }
 
     fn conv_transpose1d(
@@ -802,12 +862,79 @@ impl BackendStorage for MetalStorage {
     fn conv2d(
         &self,
-        _l: &Layout,
-        _kernel: &Self,
-        _kernel_l: &Layout,
-        _params: &ParamsConv2D,
+        layout: &Layout,
+        kernel: &Self,
+        kernel_l: &Layout,
+        params: &ParamsConv2D,
     ) -> Result<Self> {
-        crate::bail!("conv2d metal")
+        let device = self.device().clone();
+        let shape = layout.shape();
+        let dims = shape.dims();
+
+        let stride = params.stride;
+        let dilation = params.dilation;
+        let padding = params.padding;
+        let h_k = params.k_h;
+        let w_k = params.k_w;
+        let h = dims[2];
+        let w = dims[3];
+        let h_out = (h + 2 * padding - dilation * (h_k - 1) - 1) / stride + 1;
+        let w_out = (w + 2 * padding - dilation * (w_k - 1) - 1) / stride + 1;
+        let dst_el = dims[0] * h_out * w_out * dims[1] * h_k * w_k;
+
+        let dst = self
+            .device
+            .new_buffer(dst_el, self.dtype, "conv2d_im2col")?;
+        let command_buffer = self.device.command_buffer()?;
+        let name = match self.dtype {
+            DType::F32 => "im2col_f32",
+            dtype => crate::bail!("conv2d metal {dtype:?} not implemented"),
+        };
+        candle_metal_kernels::call_im2col_strided(
+            &self.device.device,
+            &command_buffer,
+            &self.device.kernels,
+            name,
+            layout.shape().dims(),
+            layout.stride(),
+            (h_k, w_k, stride, padding, dilation),
+            &self.buffer,
+            layout.start_offset() * self.dtype.size_in_bytes(),
+            &dst,
+        )
+        .map_err(MetalError::from)?;
+        let col = Self {
+            buffer: dst,
+            device,
+            dtype: self.dtype,
+        };
+        let h_out = params.out_h();
+        let w_out = params.out_w();
+        let b = params.b_size;
+        let n = params.c_out;
+        let k = params.k_h * params.k_w * params.c_in;
+        let m = h_out * w_out;
+        let col_l = Layout::contiguous((b, m, k));
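+        // Same contraction as conv1d, now with m = h_out * w_out and
+        // k = c_in * k_h * k_w.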
+        let res = if kernel_l.is_contiguous() {
+            let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
+                .transpose(1, 2)?
+                .broadcast_as((b, k, n))?;
+            col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
+        } else {
+            // Make the kernel contiguous if not already the case.
+            let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
+            kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
+            // As in conv1d, the fresh copy lives at offset 0.
+            let kernel_l = Layout::contiguous((1, n, k))
+                .transpose(1, 2)?
+                .broadcast_as((b, k, n))?;
+            col.matmul(&kernel_c, (b, m, n, k), &col_l, &kernel_l)?
+        };
+        let res_l = Layout::contiguous((b, h_out, w_out, n))
+            .transpose(1, 2)?
+            .transpose(1, 3)?;
+        let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
+        res.copy_strided_src(&mut res_t, 0, &res_l)?;
+        Ok(res_t)
     }
 
     fn conv_transpose2d(
diff --git a/candle-metal-kernels/src/conv.metal b/candle-metal-kernels/src/conv.metal
new file mode 100644
index 0000000000..49141771ce
--- /dev/null
+++ b/candle-metal-kernels/src/conv.metal
@@ -0,0 +1,153 @@
+#include <metal_stdlib>
+
+using namespace metal;
+
+template <typename T>
+METAL_FUNC void im2col(
+    constant size_t &dst_numel,
+    constant size_t &h_out,
+    constant size_t &w_out,
+    constant size_t &h_k,
+    constant size_t &w_k,
+    constant size_t &stride,
+    constant size_t &padding,
+    constant size_t &dilation,
+    constant size_t *src_dims,
+    constant size_t *src_strides,
+    device const T *src,
+    device T *dst,
+    uint tid [[ thread_position_in_grid ]]
+) {
+    // dst: (b_size, h_out, w_out, c_in, h_k, w_k)
+    // src: (b_size, c_in, h_in, w_in)
+    if (tid >= dst_numel) {
+        return;
+    }
+    const size_t b_in = src_dims[0];
+    const size_t c_in = src_dims[1];
+    const size_t h_in = src_dims[2];
+    const size_t w_in = src_dims[3];
+
+    const size_t dst_s4 = w_k;
+    const size_t dst_s3 = h_k * dst_s4;
+    const size_t dst_s2 = c_in * dst_s3;
+    const size_t dst_s1 = w_out * dst_s2;
+    const size_t dst_s0 = h_out * dst_s1;
+
+    size_t tmp_tid = tid;
+    const size_t b_idx = tmp_tid / dst_s0;
+    tmp_tid -= b_idx * dst_s0;
+    const size_t h_idx = tmp_tid / dst_s1;
+    tmp_tid -= h_idx * dst_s1;
+    const size_t w_idx = tmp_tid / dst_s2;
+    tmp_tid -= w_idx * dst_s2;
+    const size_t c_idx = tmp_tid / dst_s3;
+    tmp_tid -= c_idx * dst_s3;
+    const size_t h_k_idx = tmp_tid / dst_s4;
+    tmp_tid -= h_k_idx * dst_s4;
+    const size_t w_k_idx = tmp_tid;
+    size_t src_h_idx = h_idx * stride + h_k_idx * dilation;
+    size_t src_w_idx = w_idx * stride + w_k_idx * dilation;
+    if (src_h_idx < padding || src_h_idx >= h_in + padding) {
+        dst[tid] = static_cast<T>(0);
+    }
+    else if (src_w_idx < padding || src_w_idx >= w_in + padding) {
+        dst[tid] = static_cast<T>(0);
+    }
+    else {
+        src_h_idx -= padding;
+        src_w_idx -= padding;
+        const size_t src_i =
+            b_idx * src_strides[0] +
+            c_idx * src_strides[1] +
+            src_h_idx * src_strides[2] +
+            src_w_idx * src_strides[3];
+        dst[tid] = src[src_i];
+    }
+}
+
+template <typename T>
+METAL_FUNC void im2col1d(
+    constant size_t &dst_numel,
+    constant size_t &l_out,
+    constant size_t &l_k,
+    constant size_t &stride,
+    constant size_t &padding,
+    constant size_t &dilation,
+    constant size_t *src_dims,
+    constant size_t *src_strides,
+    device const T *src,
+    device T *dst,
+    uint tid [[ thread_position_in_grid ]]
+) {
+    // dst: (b_size, l_out, c_in, l_k)
+    // src: (b_size, c_in, l_in)
+    if (tid >= dst_numel) {
+        return;
+    }
+    const size_t b_in = src_dims[0];
+    const size_t c_in = src_dims[1];
+    const size_t l_in = src_dims[2];
+
+    const size_t dst_s2 = l_k;
+    const size_t dst_s1 = c_in * dst_s2;
+    const size_t dst_s0 = l_out * dst_s1;
+
+    size_t tmp_dst_i = tid;
+    const size_t b_idx = tmp_dst_i / dst_s0;
+    tmp_dst_i -= b_idx * dst_s0;
+    const size_t l_idx = tmp_dst_i / dst_s1;
+    tmp_dst_i -= l_idx * dst_s1;
+    const size_t c_idx = tmp_dst_i / dst_s2;
+    tmp_dst_i -= c_idx * dst_s2;
+    const size_t l_k_idx = tmp_dst_i;
+    size_t src_l_idx = l_idx * stride + l_k_idx * dilation;
+    if (src_l_idx < padding || src_l_idx >= l_in + padding) {
+        dst[tid] = static_cast<T>(0);
+    }
+    else {
+        src_l_idx -= padding;
+        const size_t src_i = b_idx * src_strides[0] + c_idx * src_strides[1] + src_l_idx * src_strides[2];
+        dst[tid] = src[src_i];
+    }
+}
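+
+// Metal does not allow a kernel entry point to be a template, so the macros
+// below stamp out one exported kernel per element type, each forwarding to
+// the templated implementations above.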
+
+#define IM2COL_OP(T, FN_NAME) \
+kernel void FN_NAME( \
+    constant size_t &dst_numel, \
+    constant size_t &h_out, \
+    constant size_t &w_out, \
+    constant size_t &h_k, \
+    constant size_t &w_k, \
+    constant size_t &stride, \
+    constant size_t &padding, \
+    constant size_t &dilation, \
+    constant size_t *src_dims, \
+    constant size_t *src_strides, \
+    device const T *src, \
+    device T *dst, \
+    uint tid [[ thread_position_in_grid ]] \
+) { \
+    im2col<T>(dst_numel, h_out, w_out, h_k, w_k, stride, padding, dilation, src_dims, src_strides, src, dst, tid); \
+} \
+
+#define IM2COL1D_OP(T, FN_NAME) \
+kernel void FN_NAME( \
+    constant size_t &dst_numel, \
+    constant size_t &l_out, \
+    constant size_t &l_k, \
+    constant size_t &stride, \
+    constant size_t &padding, \
+    constant size_t &dilation, \
+    constant size_t *src_dims, \
+    constant size_t *src_strides, \
+    device const T *src, \
+    device T *dst, \
+    uint tid [[ thread_position_in_grid ]] \
+) { \
+    im2col1d<T>(dst_numel, l_out, l_k, stride, padding, dilation, src_dims, src_strides, src, dst, tid); \
+} \
+
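+// Only the f32 variants are currently dispatched from candle-core (see the
+// dtype match in metal_backend.rs); the u8/u32 instantiations are provided so
+// further dtypes can be wired up without touching this file.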
+IM2COL_OP(float, im2col_f32)
+IM2COL_OP(uint8_t, im2col_u8)
+IM2COL_OP(uint32_t, im2col_u32)
+
+IM2COL1D_OP(float, im2col1d_f32)
+IM2COL1D_OP(uint8_t, im2col1d_u8)
+IM2COL1D_OP(uint32_t, im2col1d_u32)
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index 0418c96cca..d126aa429b 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -13,6 +13,7 @@ const BINARY: &str = include_str!("binary.metal");
 const TERNARY: &str = include_str!("ternary.metal");
 const CAST: &str = include_str!("cast.metal");
 const REDUCE: &str = include_str!("reduce.metal");
+const CONV: &str = include_str!("conv.metal");
 const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
 
 /// Most kernels apply similarly across the tensors
@@ -115,6 +116,7 @@ pub enum Source {
     Cast,
     Reduce,
     Mfa,
+    Conv,
 }
 
 macro_rules! ops{
@@ -225,6 +227,7 @@ impl Kernels {
             Source::Indexing => INDEXING,
             Source::Cast => CAST,
             Source::Reduce => REDUCE,
+            Source::Conv => CONV,
             Source::Mfa => panic!("Invalid lib"),
         }
     }
@@ -1298,7 +1301,7 @@ pub fn call_gemm(
     let fused_activation = false;
     let fused_bias = false;
     let (m_simd, n_simd, k_simd, m_splits, n_splits) = if m == 1 {
-        let m_simd = 16;
+        let m_simd = 8;
         let n_simd = 8;
         let k_simd = 64;
         let m_splits = 1;
@@ -1307,7 +1310,7 @@ pub fn call_gemm(
     } else {
         let m_simd = 40;
         let n_simd = 40;
-        let k_simd = 8;
+        let k_simd = 32;
         let m_splits = 1;
         let n_splits = 1;
         (m_simd, n_simd, k_simd, m_splits, n_splits)
@@ -1418,6 +1421,103 @@ pub fn call_gemm(
     Ok(())
 }
 
+#[allow(clippy::too_many_arguments)]
+pub fn call_im2col1d_strided(
+    device: &Device,
+    command_buffer: &CommandBufferRef,
+    kernels: &Kernels,
+    name: &'static str,
+    shape: &[usize],
+    strides: &[usize],
+    (k_size, stride, padding, dilation): (usize, usize, usize, usize),
+    input: &Buffer,
+    input_offset: usize,
+    output: &Buffer,
+) -> Result<(), MetalKernelError> {
+    let pipeline = kernels.load_pipeline(device, Source::Conv, name)?;
+    let l_out = (shape[2] + 2 * padding - dilation * (k_size - 1) - 1) / stride + 1;
+    let dst_el = shape[0] * l_out * shape[1] * k_size;
+
+    let encoder = command_buffer.new_compute_command_encoder();
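+    // One thread per im2col output element; the kernels bounds-check against
+    // dst_numel, so rounding up to whole threadgroups is safe. The same scheme
+    // is used in call_im2col_strided below.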
+    let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
+    encoder.wait_for_fence(&kernels.fence);
+    encoder.set_compute_pipeline_state(&pipeline);
+    set_params!(
+        encoder,
+        (
+            dst_el,
+            l_out,
+            k_size,
+            stride,
+            padding,
+            dilation,
+            shape,
+            strides,
+            (input, input_offset),
+            output
+        )
+    );
+    encoder.use_resource(input, metal::MTLResourceUsage::Read);
+    encoder.use_resource(output, metal::MTLResourceUsage::Write);
+    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+    encoder.update_fence(&kernels.fence);
+    encoder.end_encoding();
+
+    Ok(())
+}
+
+#[allow(clippy::too_many_arguments)]
+pub fn call_im2col_strided(
+    device: &Device,
+    command_buffer: &CommandBufferRef,
+    kernels: &Kernels,
+    name: &'static str,
+    shape: &[usize],
+    strides: &[usize],
+    (h_k, w_k, stride, padding, dilation): (usize, usize, usize, usize, usize),
+    input: &Buffer,
+    input_offset: usize,
+    output: &Buffer,
+) -> Result<(), MetalKernelError> {
+    let pipeline = kernels.load_pipeline(device, Source::Conv, name)?;
+
+    let h = shape[2];
+    let w = shape[3];
+    let h_out = (h + 2 * padding - dilation * (h_k - 1) - 1) / stride + 1;
+    let w_out = (w + 2 * padding - dilation * (w_k - 1) - 1) / stride + 1;
+
+    let dst_el = shape[0] * h_out * w_out * shape[1] * h_k * w_k;
+
+    let encoder = command_buffer.new_compute_command_encoder();
+    let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
+    encoder.wait_for_fence(&kernels.fence);
+    encoder.set_compute_pipeline_state(&pipeline);
+    set_params!(
+        encoder,
+        (
+            dst_el,
+            h_out,
+            w_out,
+            h_k,
+            w_k,
+            stride,
+            padding,
+            dilation,
+            shape,
+            strides,
+            (input, input_offset),
+            output
+        )
+    );
+    encoder.use_resource(input, metal::MTLResourceUsage::Read);
+    encoder.use_resource(output, metal::MTLResourceUsage::Write);
+    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+    encoder.update_fence(&kernels.fence);
+    encoder.end_encoding();
+
+    Ok(())
+}
+
 fn divide(m: usize, b: usize) -> NSUInteger {
     ((m + b - 1) / b) as NSUInteger
 }
diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs
index 1b3153b19f..c955abca21 100644
--- a/candle-metal-kernels/src/tests.rs
+++ b/candle-metal-kernels/src/tests.rs
@@ -1,6 +1,6 @@
 use super::*;
 use half::{bf16, f16};
-use metal::{CompileOptions, Device, MTLResourceOptions, MTLSize, NSUInteger};
+use metal::{Device, MTLResourceOptions};
 
 fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
     let ptr = buffer.contents() as *const T;
@@ -485,73 +485,6 @@ fn run_index_select(
     read_to_vec(&dst_buffer, dst_el)
 }
 
-#[test]
-fn index_add() {
-    let device = Device::system_default().expect("no device found");
-
-    let options = CompileOptions::new();
-    let library = device.new_library_with_source(INDEXING, &options).unwrap();
-
-    let left = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
-    let right = [1.0f32; 15];
-    let index = [0u32, 4, 2];
-    let ids_dim_size = index.len() as u32;
-    let dst_dim_size: u32 = 15;
-    let left_size: u32 = 3;
-    let right_size: u32 = 3;
-
-    let function = library.get_function("ia_u32_f32", None).unwrap();
-    let pipeline = device
-        .new_compute_pipeline_state_with_function(&function)
-        .unwrap();
-
-    let command_queue = device.new_command_queue();
-    let command_buffer = command_queue.new_command_buffer();
-    let encoder = command_buffer.new_compute_command_encoder();
-
-    encoder.set_compute_pipeline_state(&pipeline);
-
-    let index_buffer = new_buffer(&device, &index);
-    let inputs_buffer = new_buffer(&device, &left);
-    let outputs_buffer = new_buffer(&device, &right);
-
-    set_params!(
-        encoder,
-        (
-            &index_buffer,
-            &inputs_buffer,
-            &outputs_buffer,
-            ids_dim_size,
-            left_size,
-            dst_dim_size,
-            right_size
-        )
-    );
-
-    let grid_size = MTLSize {
-        width: right.len() as NSUInteger,
-        height: 1,
-        depth: 1,
-    };
-
-    let thread_group_size = MTLSize {
-        width: pipeline.max_total_threads_per_threadgroup(),
-        height: 1,
-        depth: 1,
-    };
-
-    encoder.dispatch_thread_groups(grid_size, thread_group_size);
-    encoder.end_encoding();
-    command_buffer.commit();
-    command_buffer.wait_until_completed();
-
-    let expected = vec![
-        2.0, 3.0, 4.0, 1.0, 1.0, 1.0, 8.0, 9.0, 10.0, 1.0, 1.0, 1.0, 5.0, 6.0, 7.0,
-    ];
-    let result: Vec<f32> = read_to_vec(&outputs_buffer, right.len());
-    assert_eq!(result, expected);
-}
-
 #[test]
 fn cos_f16() {
     let v: Vec<f16> = [1.0f32, 2.0, 3.0]
diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal
index 553bc50677..04fa37a98d 100644
--- a/candle-metal-kernels/src/unary.metal
+++ b/candle-metal-kernels/src/unary.metal
@@ -64,12 +64,12 @@ kernel void FN_NAME( \
     constant size_t &dim, \
     device const TYPENAME *input, \
     device TYPENAME *output, \
-    uint thread_position_in_grid [[ thread_position_in_grid ]] \
+    uint tid [[ thread_position_in_grid ]] \
 ) { \
-    if (thread_position_in_grid >= dim) { \
+    if (tid >= dim) { \
         return; \
     } \
-    output[thread_position_in_grid] = TYPENAME(FN(float(input[thread_position_in_grid]))); \
+    output[tid] = TYPENAME(FN(float(input[tid]))); \
 }\
 kernel void FN_NAME_STRIDED( \
     constant size_t &dim, \
@@ -78,12 +78,12 @@ kernel void FN_NAME_STRIDED( \
     constant size_t &num_dims, \
     constant size_t *dims, \
     constant size_t *strides, \
     device const TYPENAME *input, \
     device TYPENAME *output, \
-    uint thread_position_in_grid [[ thread_position_in_grid ]] \
+    uint tid [[ thread_position_in_grid ]] \
 ) { \
-    if (thread_position_in_grid >= dim) { \
+    if (tid >= dim) { \
         return; \
     } \
-    output[thread_position_in_grid] = TYPENAME(FN(float(input[get_strided_index(thread_position_in_grid, num_dims, dims, strides)]))); \
+    output[tid] = TYPENAME(FN(float(input[get_strided_index(tid, num_dims, dims, strides)]))); \
 }
 
 #define UNARY_OP(NAME) \
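
With this patch applied, f32 convolutions on the Metal backend route through the im2col path instead of bailing. A minimal sketch of how the new path gets exercised, assuming candle's public `Device::new_metal`, `Tensor::randn` and `Tensor::conv2d` APIs (none of which are touched by this diff):

    use candle_core::{Device, Tensor};

    fn main() -> candle_core::Result<()> {
        // Metal device 0; conv1d/conv2d previously returned "conv2d metal" errors here.
        let dev = Device::new_metal(0)?;
        // (batch, c_in, h, w) input against a (c_out, c_in, k_h, k_w) kernel, f32 only for now.
        let t = Tensor::randn(0f32, 1f32, (1, 3, 16, 16), &dev)?;
        let k = Tensor::randn(0f32, 1f32, (4, 3, 3, 3), &dev)?;
        // padding=1, stride=1, dilation=1, groups=1: h_out = (16 + 2 - 2 - 1) / 1 + 1 = 16.
        let res = t.conv2d(&k, 1, 1, 1, 1)?;
        assert_eq!(res.dims(), &[1, 4, 16, 16]);
        Ok(())
    }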