forked from huggingface/candle
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request huggingface#1461 from huggingface/metal-conv
Adding the convolutions (1d + 2d) to candle on metal.
- Loading branch information
Showing
5 changed files
with
399 additions
and
86 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,153 @@ | ||
template <typename T> | ||
METAL_FUNC void im2col( | ||
constant size_t &dst_numel, | ||
constant size_t &h_out, | ||
constant size_t &w_out, | ||
constant size_t &h_k, | ||
constant size_t &w_k, | ||
constant size_t &stride, | ||
constant size_t &padding, | ||
constant size_t &dilation, | ||
constant size_t *src_dims, | ||
constant size_t *src_strides, | ||
device const T *src, | ||
device T *dst, | ||
uint tid [[ thread_position_in_grid ]] | ||
) { | ||
// dst: (b_size, h_out, w_out, c_in, h_k, w_k) | ||
// src: (b_size, c_in, h_in, w_in) | ||
if (tid >= dst_numel) { | ||
return; | ||
} | ||
const size_t b_in = src_dims[0]; | ||
const size_t c_in = src_dims[1]; | ||
const size_t h_in = src_dims[2]; | ||
const size_t w_in = src_dims[3]; | ||
|
||
const size_t dst_s4 = w_k; | ||
const size_t dst_s3 = h_k * dst_s4; | ||
const size_t dst_s2 = c_in * dst_s3; | ||
const size_t dst_s1 = w_out * dst_s2; | ||
const size_t dst_s0 = h_out * dst_s1; | ||
|
||
size_t tmp_tid = tid; | ||
const size_t b_idx = tmp_tid / dst_s0; | ||
tmp_tid -= b_idx * dst_s0; | ||
const size_t h_idx = tmp_tid / dst_s1; | ||
tmp_tid -= h_idx * dst_s1; | ||
const size_t w_idx = tmp_tid / dst_s2; | ||
tmp_tid -= w_idx * dst_s2; | ||
const size_t c_idx = tmp_tid / dst_s3; | ||
tmp_tid -= c_idx * dst_s3; | ||
const size_t h_k_idx = tmp_tid / dst_s4; | ||
tmp_tid -= h_k_idx * dst_s4; | ||
const size_t w_k_idx = tmp_tid; | ||
size_t src_h_idx = h_idx * stride + h_k_idx * dilation; | ||
size_t src_w_idx = w_idx * stride + w_k_idx * dilation; | ||
if (src_h_idx < padding || src_h_idx >= h_in + padding) { | ||
dst[tid] = static_cast<T>(0); | ||
} | ||
else if (src_w_idx < padding || src_w_idx >= w_in + padding) { | ||
dst[tid] = static_cast<T>(0); | ||
} | ||
else { | ||
src_h_idx -= padding; | ||
src_w_idx -= padding; | ||
const size_t src_i = | ||
b_idx * src_strides[0] | ||
+ c_idx * src_strides[1] | ||
+ src_h_idx * src_strides[2] | ||
+ src_w_idx * src_strides[3]; | ||
dst[tid] = src[src_i]; | ||
} | ||
} | ||
|
||
template <typename T> | ||
METAL_FUNC void im2col1d( | ||
constant size_t &dst_numel, | ||
constant size_t &l_out, | ||
constant size_t &l_k, | ||
constant size_t &stride, | ||
constant size_t &padding, | ||
constant size_t &dilation, | ||
constant size_t *src_dims, | ||
constant size_t *src_strides, | ||
device const T *src, | ||
device T *dst, | ||
uint tid [[ thread_position_in_grid ]] | ||
) { | ||
// dst: (b_size, l_out, c_in, l_k) | ||
// src: (b_size, c_in, l_in) | ||
if (tid >= dst_numel) { | ||
return; | ||
} | ||
const size_t b_in = src_dims[0]; | ||
const size_t c_in = src_dims[1]; | ||
const size_t l_in = src_dims[2]; | ||
|
||
const size_t dst_s2 = l_k; | ||
const size_t dst_s1 = c_in * dst_s2; | ||
const size_t dst_s0 = l_out * dst_s1; | ||
|
||
size_t tmp_dst_i = tid; | ||
const size_t b_idx = tmp_dst_i / dst_s0; | ||
tmp_dst_i -= b_idx * dst_s0; | ||
const size_t l_idx = tmp_dst_i / dst_s1; | ||
tmp_dst_i -= l_idx * dst_s1; | ||
const size_t c_idx = tmp_dst_i / dst_s2; | ||
tmp_dst_i -= c_idx * dst_s2; | ||
const size_t l_k_idx = tmp_dst_i; | ||
size_t src_l_idx = l_idx * stride + l_k_idx * dilation; | ||
if (src_l_idx < padding || src_l_idx >= l_in + padding) { | ||
dst[tid] = static_cast<T>(0); | ||
} | ||
else { | ||
src_l_idx -= padding; | ||
const size_t src_i = b_idx * src_strides[0] + c_idx * src_strides[1] + src_l_idx * src_strides[2]; | ||
dst[tid] = src[src_i]; | ||
} | ||
} | ||
|
||
#define IM2COL_OP(T, FN_NAME) \ | ||
kernel void FN_NAME( \ | ||
constant size_t &dst_numel, \ | ||
constant size_t &h_out, \ | ||
constant size_t &w_out, \ | ||
constant size_t &h_k, \ | ||
constant size_t &w_k, \ | ||
constant size_t &stride, \ | ||
constant size_t &padding, \ | ||
constant size_t &dilation, \ | ||
constant size_t *src_dims, \ | ||
constant size_t *src_strides, \ | ||
device const T *src, \ | ||
device T *dst, \ | ||
uint tid [[ thread_position_in_grid ]] \ | ||
) { \ | ||
im2col<T>(dst_numel, h_out, w_out, h_k, w_k, stride, padding, dilation, src_dims, src_strides, src, dst, tid); \ | ||
} \ | ||
|
||
#define IM2COL1D_OP(T, FN_NAME) \ | ||
kernel void FN_NAME( \ | ||
constant size_t &dst_numel, \ | ||
constant size_t &l_out, \ | ||
constant size_t &l_k, \ | ||
constant size_t &stride, \ | ||
constant size_t &padding, \ | ||
constant size_t &dilation, \ | ||
constant size_t *src_dims, \ | ||
constant size_t *src_strides, \ | ||
device const T *src, \ | ||
device T *dst, \ | ||
uint tid [[ thread_position_in_grid ]] \ | ||
) { \ | ||
im2col1d<T>(dst_numel, l_out, l_k, stride, padding, dilation, src_dims, src_strides, src, dst, tid); \ | ||
} \ | ||
|
||
IM2COL_OP(float, im2col_f32) | ||
IM2COL_OP(uint8_t, im2col_u8) | ||
IM2COL_OP(uint32_t, im2col_u32) | ||
|
||
IM2COL1D_OP(float, im2col1d_f32) | ||
IM2COL1D_OP(uint8_t, im2col1d_u8) | ||
IM2COL1D_OP(uint32_t, im2col1d_u32) |
Oops, something went wrong.