From 13a5d15ebcfd924b55c4b1f2860abf238117580f Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Mon, 25 Dec 2023 14:25:19 +0100
Subject: [PATCH] Adding upsample_nearest_2d.

---
 candle-core/src/metal_backend.rs    | 46 ++++++++++++++++++++-
 candle-metal-kernels/src/conv.metal | 69 +++++++++++++++++++++++++++++
 candle-metal-kernels/src/lib.rs     | 59 ++++++++++++++++++++++++
 3 files changed, 172 insertions(+), 2 deletions(-)

diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs
index 1813f27681..6d8afab192 100644
--- a/candle-core/src/metal_backend.rs
+++ b/candle-core/src/metal_backend.rs
@@ -959,8 +959,50 @@ impl BackendStorage for MetalStorage {
         crate::bail!("upsample_nearest1d metal")
     }
 
-    fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self> {
-        crate::bail!("upsample_nearest2d metal")
+    /// Nearest-neighbour upsampling of a rank-4 input on the Metal device.
+    ///
+    /// `out_w` and `out_h` are the target sizes of input dimensions 2 and 3; the
+    /// first two dimensions are kept as-is. The input is read through the strides
+    /// and start offset of `inp_l`, and the result is written to a new buffer of
+    /// `out_w * out_h * dims[0] * dims[1]` elements.
+    ///
+    /// Fails if the input is not rank 4 or if no kernel exists for `self.dtype`.
+    fn upsample_nearest2d(&self, inp_l: &Layout, out_w: usize, out_h: usize) -> Result<Self> {
+        let shape = inp_l.shape();
+        let dims = shape.dims();
+        let strides = inp_l.stride();
+        if dims.len() != 4 {
+            crate::bail!("unexpected input shape for upsample {dims:?}")
+        }
+        // One arm per kernel instantiated by UPSAMPLE_NEAREST2D_OP in conv.metal.
+        let name = match self.dtype {
+            DType::F32 => "upsample_nearest2d_f32",
+            DType::U8 => "upsample_nearest2d_u8",
+            DType::U32 => "upsample_nearest2d_u32",
+            dtype => crate::bail!("Not implemented {dtype:?} for upsample_nearest2d, metal"),
+        };
+
+        let dst_el = out_w * out_h * dims[0] * dims[1];
+        let buffer = self
+            .device
+            .new_buffer(dst_el, self.dtype, "upsample_nearest2d")?;
+        let command_buffer = self.device.command_buffer()?;
+        candle_metal_kernels::call_upsample_nearest_2d(
+            &self.device.device,
+            &command_buffer,
+            &self.device.kernels,
+            name,
+            dims,
+            strides,
+            out_w,
+            out_h,
+            &self.buffer,
+            // The layout offset is scaled by the element size before being passed on.
+            inp_l.start_offset() * self.dtype.size_in_bytes(),
+            &buffer,
+        )
+        .map_err(MetalError::from)?;
+        Ok(Self::new(buffer, self.device.clone(), self.dtype))
     }
 
     fn gather(&self, src_l: &Layout, ids: &Self, ids_l: &Layout, dim: usize) -> Result<Self> {
diff --git a/candle-metal-kernels/src/conv.metal b/candle-metal-kernels/src/conv.metal
index 49141771ce..dca531619a 100644
--- a/candle-metal-kernels/src/conv.metal
+++ b/candle-metal-kernels/src/conv.metal
@@ -108,6 +108,55 @@ METAL_FUNC void im2col1d(
   }
 }
 
+// Nearest-neighbour 2D upsampling: one thread per destination element.
+//
+// `tid` is decomposed as an index into a (b, c, w_out, h_out) destination
+// whose last dimension varies fastest. The source is addressed through
+// `src_s`, one stride per source dimension. `w_scale` / `h_scale` turn a
+// destination coordinate into a source coordinate.
+template <typename T>
+METAL_FUNC void upsample_nearest2d(
+    constant size_t &w_out,
+    constant size_t &h_out,
+    constant float &w_scale,
+    constant float &h_scale,
+    constant size_t *src_dims,
+    constant size_t *src_s,
+    device const T *src,
+    device T *dst,
+    uint tid [[ thread_position_in_grid ]]
+) {
+  // src: (b_size, c_in, w_in, h_in)
+
+  const size_t c = src_dims[1];
+  const size_t w_in = src_dims[2];
+  const size_t h_in = src_dims[3];
+
+  // Threads past the last destination element have nothing to write.
+  if (tid >= src_dims[0] * c * w_out * h_out) {
+    return;
+  }
+
+  // TODO: Improve this.
+  const size_t b_idx = tid / (w_out * h_out * c);
+  const size_t c_idx = (tid / (w_out * h_out)) % c;
+  const size_t dst_w = (tid / h_out) % w_out;
+  const size_t dst_h = tid % h_out;
+
+  // Truncate to an integer index, then clamp to the last valid source index.
+  size_t src_w = static_cast<size_t>(dst_w * w_scale);
+  size_t src_h = static_cast<size_t>(dst_h * h_scale);
+  if (src_w >= w_in) {
+    src_w = w_in - 1;
+  }
+  if (src_h >= h_in) {
+    src_h = h_in - 1;
+  }
+
+  const size_t src_i = b_idx * src_s[0] + c_idx * src_s[1] + src_w * src_s[2] + src_h * src_s[3];
+  dst[tid] = src[src_i];
+}
+
 #define IM2COL_OP(T, FN_NAME) \
 kernel void FN_NAME( \
     constant size_t &dst_numel, \
@@ -143,6 +192,22 @@ kernel void FN_NAME( \
 ) { \
   im2col1d<T>(dst_numel, l_out, l_k, stride, padding, dilation, src_dims, src_strides, src, dst, tid); \
 } \
+
+// Declares a kernel entry point FN_NAME running upsample_nearest2d on TYPENAME elements.
+#define UPSAMPLE_NEAREST2D_OP(TYPENAME, FN_NAME) \
+kernel void FN_NAME( \
+    constant size_t &w_out, \
+    constant size_t &h_out, \
+    constant float &w_scale, \
+    constant float &h_scale, \
+    constant size_t *dims, \
+    constant size_t *strides, \
+    device const TYPENAME *src, \
+    device TYPENAME *dst, \
+    uint tid [[ thread_position_in_grid ]] \
+) { \
+  upsample_nearest2d<TYPENAME>(w_out, h_out, w_scale, h_scale, dims, strides, src, dst, tid); \
+} \
 
 IM2COL_OP(float, im2col_f32)
 IM2COL_OP(uint8_t, im2col_u8)
@@ -151,3 +216,7 @@ IM2COL_OP(uint32_t, im2col_u32)
 IM2COL1D_OP(float, im2col1d_f32)
 IM2COL1D_OP(uint8_t, im2col1d_u8)
 IM2COL1D_OP(uint32_t, im2col1d_u32)
+
+UPSAMPLE_NEAREST2D_OP(float, upsample_nearest2d_f32)
+UPSAMPLE_NEAREST2D_OP(uint8_t, upsample_nearest2d_u8)
+UPSAMPLE_NEAREST2D_OP(uint32_t, upsample_nearest2d_u32)
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index d126aa429b..dd97a86d69 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -1518,6 +1518,65 @@ pub fn call_im2col_strided(
     Ok(())
 }
 
+/// Encodes a nearest-neighbour 2D upsampling kernel on `command_buffer`.
+///
+/// `shape` and `strides` describe the rank-4 input starting `input_offset` bytes
+/// into `input`. `out_w` / `out_h` are the target sizes of dimensions 2 and 3, and
+/// `output` receives `out_w * out_h * shape[0] * shape[1]` elements.
+///
+/// # Errors
+///
+/// Returns an error if the pipeline for `name` cannot be loaded.
+///
+/// # Panics
+///
+/// Panics if `shape` has fewer than four entries.
+#[allow(clippy::too_many_arguments)]
+pub fn call_upsample_nearest_2d(
+    device: &Device,
+    command_buffer: &CommandBufferRef,
+    kernels: &Kernels,
+    name: &'static str,
+    shape: &[usize],
+    strides: &[usize],
+    out_w: usize,
+    out_h: usize,
+    input: &Buffer,
+    input_offset: usize,
+    output: &Buffer,
+) -> Result<(), MetalKernelError> {
+    let pipeline = kernels.load_pipeline(device, Source::Conv, name)?;
+    let dst_el = out_w * out_h * shape[0] * shape[1];
+    // Ratio of input size to output size; the kernel multiplies a destination
+    // coordinate by it to obtain the source coordinate.
+    let scale_w = shape[2] as f32 / out_w as f32;
+    let scale_h = shape[3] as f32 / out_h as f32;
+    let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
+    let encoder = command_buffer.new_compute_command_encoder();
+    encoder.wait_for_fence(&kernels.fence);
+    encoder.set_compute_pipeline_state(&pipeline);
+    set_params!(
+        encoder,
+        (
+            out_w,
+            out_h,
+            scale_w,
+            scale_h,
+            shape,
+            strides,
+            (input, input_offset),
+            output
+        )
+    );
+    encoder.use_resource(input, metal::MTLResourceUsage::Read);
+    encoder.use_resource(output, metal::MTLResourceUsage::Write);
+    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+    encoder.update_fence(&kernels.fence);
+    encoder.end_encoding();
+
+    Ok(())
+}
+
 fn divide(m: usize, b: usize) -> NSUInteger {
     ((m + b - 1) / b) as NSUInteger
 }