Skip to content

Commit

Permalink
Fix the dilated convolutions. (huggingface#659)
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentMazare authored Aug 29, 2023
1 parent a044907 commit 7122155
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 17 deletions.
6 changes: 3 additions & 3 deletions candle-core/src/cpu_backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1064,7 +1064,7 @@ impl<'a> Map2 for Conv1D<'a> {
let dst_idx = dst_idx + b_idx * p.c_out * l_out;
for dst_l in 0..l_out {
let dst_idx = dst_idx + dst_l;
let src_l = (p.stride * dst_l + offset) * p.dilation;
let src_l = p.stride * dst_l + offset * p.dilation;
if src_l < p.padding || src_l >= p.padding + p.l_in {
continue;
}
Expand Down Expand Up @@ -1141,14 +1141,14 @@ impl<'a> Map2 for Conv2D<'a> {
let dst_idx = dst_idx + b_idx * p.c_out * out_h * out_w;
for dst_h in 0..out_h {
let dst_idx = dst_idx + dst_h * out_w;
let src_h = (p.stride * dst_h + offset_h) * p.dilation;
let src_h = p.stride * dst_h + offset_h * p.dilation;
if src_h < p.padding || src_h >= p.i_h + p.padding {
continue;
}
let src_h = src_h - p.padding;
for dst_w in 0..out_w {
let dst_idx = dst_idx + dst_w;
let src_w = (p.stride * dst_w + offset_w) * p.dilation;
let src_w = p.stride * dst_w + offset_w * p.dilation;
if src_w < p.padding || src_w >= p.i_w + p.padding {
continue;
}
Expand Down
24 changes: 12 additions & 12 deletions candle-core/tests/conv_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -423,24 +423,24 @@ fn conv2d_grad(dev: &Device) -> Result<()> {
test_utils::to_vec3_round(&grad_w.i(0)?, 2)?,
[
[
[28.34, -45.75, 7.32],
[0.72, -35.28, 19.23],
[-28.29, 20.89, -5.18]
[28.34, -7.91, -45.75],
[21.03, 3.86, 29.86],
[0.72, -36.58, -35.28]
],
[
[-16.04, -16.38, 32.12],
[57.5, 25.81, 11.96],
[-18.66, 8.48, -9.92]
[-16.04, 11.53, -16.38],
[29.62, -16.32, -48.35],
[57.5, 28.29, 25.81]
],
[
[2.93, 1.57, -23.76],
[12.74, -26.2, -17.88],
[-14.98, -9.35, 12.2]
[2.93, -19.6, 1.57],
[27.15, 53.88, -24.64],
[12.74, -22.6, -26.2]
],
[
[-0.18, -6.82, 20.79],
[-2.54, 27.11, -10.11],
[-0.41, -3.18, -0.07]
[-0.18, -14.86, -6.82],
[-19.55, -2.72, 45.9],
[-2.54, 36.97, 27.11]
]
]
);
Expand Down
4 changes: 2 additions & 2 deletions candle-kernels/src/conv.cu
Original file line number Diff line number Diff line change
Expand Up @@ -92,13 +92,13 @@ __device__ void conv2d(
const size_t src_idx0 = b_idx * src_s[0];
A d = 0;
for (size_t w_offset = 0; w_offset < w_k; ++w_offset) {
size_t src_w = (stride * dst_w + w_offset) * dilation;
size_t src_w = stride * dst_w + w_offset * dilation;
if (src_w < padding || src_w >= w_in + padding) {
continue;
}
src_w -= padding;
for (size_t h_offset = 0; h_offset < h_k; ++h_offset) {
size_t src_h = (stride * dst_h + h_offset) * dilation;
size_t src_h = stride * dst_h + h_offset * dilation;
if (src_h < padding || src_h >= h_in + padding) {
continue;
}
Expand Down

0 comments on commit 7122155

Please sign in to comment.