Skip to content

Commit

Permalink
Merge pull request huggingface#1461 from huggingface/metal-conv
Browse files Browse the repository at this point in the history
Adding the convolutions (1d + 2d) to candle on metal.
  • Loading branch information
Narsil authored Dec 25, 2023
2 parents 7135791 + 95e18ef commit 1505d85
Show file tree
Hide file tree
Showing 5 changed files with 399 additions and 86 deletions.
147 changes: 137 additions & 10 deletions candle-core/src/metal_backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -782,12 +782,72 @@ impl BackendStorage for MetalStorage {

/// 1D convolution on Metal, implemented as im2col followed by a batched matmul.
///
/// The input described by `layout` (indexed as `(batch, c_in, l_in)`) is first
/// unfolded by the `im2col1d` kernel into a column buffer laid out as
/// `(b, l_out, c_in * k_size)`. That buffer is then multiplied by the kernel,
/// viewed as a `(c_in * k_size, c_out)` matrix broadcast over the batch, and the
/// `(b, l_out, c_out)` product is copied into a `(b, c_out, l_out)` result.
///
/// # Errors
///
/// Fails with `conv1d metal {dtype:?} not implemented` for any dtype other than
/// `F32`, and propagates buffer allocation, kernel dispatch, matmul and copy errors.
fn conv1d(
    &self,
    layout: &Layout,
    kernel: &Self,
    kernel_l: &Layout,
    params: &ParamsConv1D,
) -> Result<Self> {
    let device = self.device().clone();
    let dims = layout.shape().dims();
    let strides = layout.stride();

    let stride = params.stride;
    let dilation = params.dilation;
    let padding = params.padding;
    let k_size = params.k_size;
    // Output length derived from the actual input layout; used to size the column buffer.
    let l_out = (dims[2] + 2 * padding - dilation * (k_size - 1) - 1) / stride + 1;
    let dst_el = dims[0] * l_out * dims[1] * k_size;
    let dst = self
        .device
        .new_buffer(dst_el, self.dtype, "conv1d_im2col")?;
    let command_buffer = self.device.command_buffer()?;
    let name = match self.dtype {
        DType::F32 => "im2col1d_f32",
        dtype => crate::bail!("conv1d metal {dtype:?} not implemented"),
    };
    candle_metal_kernels::call_im2col1d_strided(
        &self.device.device,
        &command_buffer,
        &self.device.kernels,
        name,
        dims,
        strides,
        (k_size, stride, padding, dilation),
        &self.buffer,
        // The kernel takes the source offset in bytes, the layout stores it in elements.
        layout.start_offset() * self.dtype.size_in_bytes(),
        &dst,
    )
    .map_err(MetalError::from)?;
    let col = Self {
        buffer: dst,
        device,
        dtype: self.dtype,
    };

    // Matmul dimensions: (b, m, k) x (b, k, n) -> (b, m, n).
    let l_out = params.l_out();
    let b = params.b_size;
    let n = params.c_out;
    let k = params.k_size * params.c_in;
    let m = l_out;
    let col_l = Layout::contiguous((b, m, k));
    let res = if kernel_l.is_contiguous() {
        let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
            .transpose(1, 2)?
            .broadcast_as((b, k, n))?;
        col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
    } else {
        // Make the kernel contiguous if not already the case.
        let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
        kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
        // The copy was written at destination offset 0, so the contiguous view of
        // `kernel_c` must not carry over the offset of the original strided layout.
        let kernel_l = Layout::contiguous((1, n, k))
            .transpose(1, 2)?
            .broadcast_as((b, k, n))?;
        // Multiply by the contiguous copy, not by the original strided storage.
        col.matmul(&kernel_c, (b, m, n, k), &col_l, &kernel_l)?
    };
    // View the (b, l_out, n) product as (b, n, l_out) and copy it out through that view.
    let res_l = Layout::contiguous((b, l_out, n)).transpose(1, 2)?;
    let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
    res.copy_strided_src(&mut res_t, 0, &res_l)?;
    Ok(res_t)
}

fn conv_transpose1d(
Expand All @@ -802,12 +862,79 @@ impl BackendStorage for MetalStorage {

/// 2D convolution on Metal, implemented as im2col followed by a batched matmul.
///
/// The input described by `layout` (indexed as `(batch, c_in, h, w)`) is first
/// unfolded by the `im2col` kernel into a column buffer laid out as
/// `(b, h_out * w_out, c_in * k_h * k_w)`. That buffer is then multiplied by the
/// kernel, viewed as a `(c_in * k_h * k_w, c_out)` matrix broadcast over the batch,
/// and the `(b, h_out, w_out, c_out)` product is copied into a
/// `(b, c_out, h_out, w_out)` result.
///
/// # Errors
///
/// Fails with `conv2d metal {dtype:?} not implemented` for any dtype other than
/// `F32`, and propagates buffer allocation, kernel dispatch, matmul and copy errors.
fn conv2d(
    &self,
    layout: &Layout,
    kernel: &Self,
    kernel_l: &Layout,
    params: &ParamsConv2D,
) -> Result<Self> {
    let device = self.device().clone();
    let dims = layout.shape().dims();

    let stride = params.stride;
    let dilation = params.dilation;
    let padding = params.padding;
    let h_k = params.k_h;
    let w_k = params.k_w;
    let h = dims[2];
    let w = dims[3];
    // Output extent derived from the actual input layout; used to size the column buffer.
    let h_out = (h + 2 * padding - dilation * (h_k - 1) - 1) / stride + 1;
    let w_out = (w + 2 * padding - dilation * (w_k - 1) - 1) / stride + 1;
    let dst_el = dims[0] * h_out * w_out * dims[1] * h_k * w_k;

    let dst = self
        .device
        .new_buffer(dst_el, self.dtype, "conv2d_im2col")?;
    let command_buffer = self.device.command_buffer()?;
    let name = match self.dtype {
        DType::F32 => "im2col_f32",
        dtype => crate::bail!("conv2d metal {dtype:?} not implemented"),
    };
    candle_metal_kernels::call_im2col_strided(
        &self.device.device,
        &command_buffer,
        &self.device.kernels,
        name,
        dims,
        layout.stride(),
        (h_k, w_k, stride, padding, dilation),
        &self.buffer,
        // The kernel takes the source offset in bytes, the layout stores it in elements.
        layout.start_offset() * self.dtype.size_in_bytes(),
        &dst,
    )
    .map_err(MetalError::from)?;
    let col = Self {
        buffer: dst,
        device,
        dtype: self.dtype,
    };

    // Matmul dimensions: (b, m, k) x (b, k, n) -> (b, m, n).
    let h_out = params.out_h();
    let w_out = params.out_w();
    let b = params.b_size;
    let n = params.c_out;
    let k = params.k_h * params.k_w * params.c_in;
    let m = h_out * w_out;
    let col_l = Layout::contiguous((b, m, k));
    let res = if kernel_l.is_contiguous() {
        let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
            .transpose(1, 2)?
            .broadcast_as((b, k, n))?;
        col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
    } else {
        // Make the kernel contiguous if not already the case.
        let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
        kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
        // The copy was written at destination offset 0, so the contiguous view of
        // `kernel_c` must not carry over the offset of the original strided layout.
        let kernel_l = Layout::contiguous((1, n, k))
            .transpose(1, 2)?
            .broadcast_as((b, k, n))?;
        // Multiply by the contiguous copy, not by the original strided storage.
        col.matmul(&kernel_c, (b, m, n, k), &col_l, &kernel_l)?
    };
    // View the (b, h_out, w_out, n) product as (b, n, h_out, w_out) and copy it out
    // through that view.
    let res_l = Layout::contiguous((b, h_out, w_out, n))
        .transpose(1, 2)?
        .transpose(1, 3)?;
    let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
    res.copy_strided_src(&mut res_t, 0, &res_l)?;
    Ok(res_t)
}

fn conv_transpose2d(
Expand Down
153 changes: 153 additions & 0 deletions candle-metal-kernels/src/conv.metal
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
// Unfolds a strided 4D source into im2col columns for a 2D convolution.
//
// Each thread writes exactly one element of `dst`, identified by `tid`:
//   dst: (b_size, h_out, w_out, c_in, h_k, w_k), written contiguously
//   src: (b_size, c_in, h_in, w_in), read through `src_strides`
// Destination positions that fall into the padding region are set to zero.
// `stride`, `padding` and `dilation` are shared by both spatial dimensions.
template <typename T>
METAL_FUNC void im2col(
    constant size_t &dst_numel,
    constant size_t &h_out,
    constant size_t &w_out,
    constant size_t &h_k,
    constant size_t &w_k,
    constant size_t &stride,
    constant size_t &padding,
    constant size_t &dilation,
    constant size_t *src_dims,
    constant size_t *src_strides,
    device const T *src,
    device T *dst,
    uint tid [[ thread_position_in_grid ]]
) {
    // Threads past the end of the destination have nothing to write.
    if (tid >= dst_numel) {
        return;
    }
    const size_t c_in = src_dims[1];
    const size_t h_in = src_dims[2];
    const size_t w_in = src_dims[3];

    // Contiguous strides of dst for every dimension but the innermost (w_k, stride 1).
    const size_t dst_s4 = w_k;
    const size_t dst_s3 = h_k * dst_s4;
    const size_t dst_s2 = c_in * dst_s3;
    const size_t dst_s1 = w_out * dst_s2;
    const size_t dst_s0 = h_out * dst_s1;

    // Decompose the flat thread id into (b, h, w, c, h_k, w_k) coordinates.
    size_t tmp_tid = tid;
    const size_t b_idx = tmp_tid / dst_s0;
    tmp_tid -= b_idx * dst_s0;
    const size_t h_idx = tmp_tid / dst_s1;
    tmp_tid -= h_idx * dst_s1;
    const size_t w_idx = tmp_tid / dst_s2;
    tmp_tid -= w_idx * dst_s2;
    const size_t c_idx = tmp_tid / dst_s3;
    tmp_tid -= c_idx * dst_s3;
    const size_t h_k_idx = tmp_tid / dst_s4;
    tmp_tid -= h_k_idx * dst_s4;
    const size_t w_k_idx = tmp_tid;

    // Coordinates in the padded source; the padding is subtracted only after the
    // bounds check so the unsigned arithmetic never underflows.
    size_t src_h_idx = h_idx * stride + h_k_idx * dilation;
    size_t src_w_idx = w_idx * stride + w_k_idx * dilation;
    const bool h_in_padding = src_h_idx < padding || src_h_idx >= h_in + padding;
    const bool w_in_padding = src_w_idx < padding || src_w_idx >= w_in + padding;
    if (h_in_padding || w_in_padding) {
        dst[tid] = static_cast<T>(0);
        return;
    }
    src_h_idx -= padding;
    src_w_idx -= padding;
    const size_t src_i =
        b_idx * src_strides[0]
        + c_idx * src_strides[1]
        + src_h_idx * src_strides[2]
        + src_w_idx * src_strides[3];
    dst[tid] = src[src_i];
}

// Unfolds a strided 3D source into im2col columns for a 1D convolution.
//
// Each thread writes exactly one element of `dst`, identified by `tid`:
//   dst: (b_size, l_out, c_in, l_k), written contiguously
//   src: (b_size, c_in, l_in), read through `src_strides`
// Destination positions that fall into the padding region are set to zero.
template <typename T>
METAL_FUNC void im2col1d(
    constant size_t &dst_numel,
    constant size_t &l_out,
    constant size_t &l_k,
    constant size_t &stride,
    constant size_t &padding,
    constant size_t &dilation,
    constant size_t *src_dims,
    constant size_t *src_strides,
    device const T *src,
    device T *dst,
    uint tid [[ thread_position_in_grid ]]
) {
    // Threads past the end of the destination have nothing to write.
    if (tid >= dst_numel) {
        return;
    }
    const size_t c_in = src_dims[1];
    const size_t l_in = src_dims[2];

    // Contiguous strides of dst for every dimension but the innermost (l_k, stride 1).
    const size_t dst_s2 = l_k;
    const size_t dst_s1 = c_in * dst_s2;
    const size_t dst_s0 = l_out * dst_s1;

    // Decompose the flat thread id into (b, l, c, l_k) coordinates.
    size_t tmp_dst_i = tid;
    const size_t b_idx = tmp_dst_i / dst_s0;
    tmp_dst_i -= b_idx * dst_s0;
    const size_t l_idx = tmp_dst_i / dst_s1;
    tmp_dst_i -= l_idx * dst_s1;
    const size_t c_idx = tmp_dst_i / dst_s2;
    tmp_dst_i -= c_idx * dst_s2;
    const size_t l_k_idx = tmp_dst_i;

    // Coordinate in the padded source; the padding is subtracted only after the
    // bounds check so the unsigned arithmetic never underflows.
    size_t src_l_idx = l_idx * stride + l_k_idx * dilation;
    if (src_l_idx < padding || src_l_idx >= l_in + padding) {
        dst[tid] = static_cast<T>(0);
        return;
    }
    src_l_idx -= padding;
    const size_t src_i =
        b_idx * src_strides[0]
        + c_idx * src_strides[1]
        + src_l_idx * src_strides[2];
    dst[tid] = src[src_i];
}

// Declares a `kernel` entry point named FN_NAME for element type T that forwards
// all of its arguments, unchanged and in the same order, to im2col<T>.
// NOTE(review): the parameters carry no explicit binding attributes, so the host
// presumably binds them by position; keep the order in sync with the caller.
#define IM2COL_OP(T, FN_NAME) \
kernel void FN_NAME( \
constant size_t &dst_numel, \
constant size_t &h_out, \
constant size_t &w_out, \
constant size_t &h_k, \
constant size_t &w_k, \
constant size_t &stride, \
constant size_t &padding, \
constant size_t &dilation, \
constant size_t *src_dims, \
constant size_t *src_strides, \
device const T *src, \
device T *dst, \
uint tid [[ thread_position_in_grid ]] \
) { \
im2col<T>(dst_numel, h_out, w_out, h_k, w_k, stride, padding, dilation, src_dims, src_strides, src, dst, tid); \
} \

// Declares a `kernel` entry point named FN_NAME for element type T that forwards
// all of its arguments, unchanged and in the same order, to im2col1d<T>.
// NOTE(review): the parameters carry no explicit binding attributes, so the host
// presumably binds them by position; keep the order in sync with the caller.
#define IM2COL1D_OP(T, FN_NAME) \
kernel void FN_NAME( \
constant size_t &dst_numel, \
constant size_t &l_out, \
constant size_t &l_k, \
constant size_t &stride, \
constant size_t &padding, \
constant size_t &dilation, \
constant size_t *src_dims, \
constant size_t *src_strides, \
device const T *src, \
device T *dst, \
uint tid [[ thread_position_in_grid ]] \
) { \
im2col1d<T>(dst_numel, l_out, l_k, stride, padding, dilation, src_dims, src_strides, src, dst, tid); \
} \

// 2D im2col kernels, one per element type. The second argument is the kernel
// function name; the Rust conv2d in this commit selects "im2col_f32" by name.
IM2COL_OP(float, im2col_f32)
IM2COL_OP(uint8_t, im2col_u8)
IM2COL_OP(uint32_t, im2col_u32)

// 1D im2col kernels, one per element type. The Rust conv1d in this commit
// selects "im2col1d_f32" by name.
IM2COL1D_OP(float, im2col1d_f32)
IM2COL1D_OP(uint8_t, im2col1d_u8)
IM2COL1D_OP(uint32_t, im2col1d_u32)
Loading

0 comments on commit 1505d85

Please sign in to comment.