Skip to content

Commit

Permalink
Merge pull request huggingface#1479 from huggingface/upsample_metal
Browse files Browse the repository at this point in the history
Adding upsample_nearest_2d.
  • Loading branch information
Narsil authored Dec 25, 2023
2 parents 1505d85 + 13a5d15 commit eae3a20
Show file tree
Hide file tree
Showing 3 changed files with 137 additions and 2 deletions.
35 changes: 33 additions & 2 deletions candle-core/src/metal_backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -959,8 +959,39 @@ impl BackendStorage for MetalStorage {
crate::bail!("upsample_nearest1d metal")
}

fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self> {
crate::bail!("upsample_nearest2d metal")
fn upsample_nearest2d(&self, inp_l: &Layout, out_w: usize, out_h: usize) -> Result<Self> {
// let inp = &inp.slice(inp_l.start_offset()..);
let shape = inp_l.shape();
let dims = shape.dims();
let strides = inp_l.stride();
if dims.len() != 4 {
crate::bail!("unexpected input shape for upsample {dims:?}")
}
let name = match self.dtype {
DType::F32 => "upsample_nearest2d_f32",
dtype => crate::bail!("Not implemented {dtype:?} for upsample_nearest2d, metal"),
};

let dst_el = out_w * out_h * dims[0] * dims[1];
let buffer = self
.device
.new_buffer(dst_el, self.dtype, "upsample_nearest2d")?;
let command_buffer = self.device.command_buffer()?;
candle_metal_kernels::call_upsample_nearest_2d(
&self.device.device,
&command_buffer,
&self.device.kernels,
name,
dims,
strides,
out_w,
out_h,
&self.buffer,
inp_l.start_offset() * self.dtype.size_in_bytes(),
&buffer,
)
.map_err(MetalError::from)?;
Ok(Self::new(buffer, self.device.clone(), self.dtype))
}

fn gather(&self, src_l: &Layout, ids: &Self, ids_l: &Layout, dim: usize) -> Result<Self> {
Expand Down
60 changes: 60 additions & 0 deletions candle-metal-kernels/src/conv.metal
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,47 @@ METAL_FUNC void im2col1d(
}
}

template <typename T>
METAL_FUNC void upsample_nearest2d(
constant size_t &w_out,
constant size_t &h_out,
constant float &w_scale,
constant float &h_scale,
constant size_t *src_dims,
constant size_t *src_s,
device const T *src,
device T *dst,
uint tid [[ thread_position_in_grid ]]
) {
// src: (b_size, c_in, w_in, h_in)

const size_t c = src_dims[1];
const size_t w_in = src_dims[2];
const size_t h_in = src_dims[3];

if (tid >= src_dims[0] * c * w_out * h_out) {
return;
}

// TODO: Improve this.
const size_t b_idx = tid / (w_out * h_out * c);
const size_t c_idx = (tid / (w_out * h_out)) % c;
const size_t dst_w = (tid / h_out) % w_out;
const size_t dst_h = tid % h_out;

size_t src_w = static_cast<size_t>(dst_w * w_scale);
size_t src_h = static_cast<size_t>(dst_h * h_scale);
if (src_w >= w_in) {
src_w = w_in - 1;
}
if (src_h >= h_in) {
src_h = h_in - 1;
}

const size_t src_i = b_idx * src_s[0] + c_idx * src_s[1] + src_w * src_s[2] + src_h * src_s[3];
dst[tid] = src[src_i];
}

#define IM2COL_OP(T, FN_NAME) \
kernel void FN_NAME( \
constant size_t &dst_numel, \
Expand Down Expand Up @@ -143,6 +184,21 @@ kernel void FN_NAME( \
) { \
im2col1d<T>(dst_numel, l_out, l_k, stride, padding, dilation, src_dims, src_strides, src, dst, tid); \
} \

#define UPSAMPLE_NEAREST2D_OP(TYPENAME, FN_NAME) \
kernel void FN_NAME( \
constant size_t &w_out, \
constant size_t &h_out, \
constant float &w_scale, \
constant float &h_scale, \
constant size_t *dims, \
constant size_t *strides, \
device const TYPENAME *src, \
device TYPENAME *dst, \
uint tid [[ thread_position_in_grid ]] \
) { \
upsample_nearest2d<TYPENAME>(w_out, h_out, w_scale, h_scale, dims, strides, src, dst, tid); \
} \

IM2COL_OP(float, im2col_f32)
IM2COL_OP(uint8_t, im2col_u8)
Expand All @@ -151,3 +207,7 @@ IM2COL_OP(uint32_t, im2col_u32)
IM2COL1D_OP(float, im2col1d_f32)
IM2COL1D_OP(uint8_t, im2col1d_u8)
IM2COL1D_OP(uint32_t, im2col1d_u32)

UPSAMPLE_NEAREST2D_OP(float, upsample_nearest2d_f32)
UPSAMPLE_NEAREST2D_OP(uint8_t, upsample_nearest2d_u8)
UPSAMPLE_NEAREST2D_OP(uint32_t, upsample_nearest2d_u32)
44 changes: 44 additions & 0 deletions candle-metal-kernels/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1518,6 +1518,50 @@ pub fn call_im2col_strided(
Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_upsample_nearest_2d(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
name: &'static str,
shape: &[usize],
strides: &[usize],
out_w: usize,
out_h: usize,
input: &Buffer,
input_offset: usize,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Conv, name)?;
let dst_el = out_w * out_h * shape[0] * shape[1];
let scale_w = shape[2] as f32 / out_w as f32;
let scale_h = shape[3] as f32 / out_h as f32;
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
(
out_w,
out_h,
scale_w,
scale_h,
shape,
strides,
(input, input_offset),
output
)
);
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();

Ok(())
}

fn divide(m: usize, b: usize) -> NSUInteger {
((m + b - 1) / b) as NSUInteger
}
Expand Down

0 comments on commit eae3a20

Please sign in to comment.