Skip to content

Commit

Permalink
Support dilation in conv-transpose2d. (huggingface#671)
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentMazare committed Aug 30, 2023
1 parent 9b25113 commit 3936903
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 17 deletions.
10 changes: 2 additions & 8 deletions candle-core/src/cpu_backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1186,12 +1186,6 @@ impl<'a> Map2 for ConvTranspose2D<'a> {
const OP: &'static str = "conv_transpose2d";
fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
let p = self.0;
if p.dilation != 1 {
crate::bail!(
"dilation {} is not supported for conv-transpose2d",
p.dilation
)
}
let inp = &inp[inp_l.start_offset()..];
let (inp_s0, inp_s1, inp_s2, inp_s3) = crate::shape::dims4(inp_l.stride())?;
let k = &k[k_l.start_offset()..];
Expand Down Expand Up @@ -1235,8 +1229,8 @@ impl<'a> Map2 for ConvTranspose2D<'a> {
for b_idx in 0..p.b_size {
for inp_y in 0..p.i_h {
for inp_x in 0..p.i_w {
let out_x = inp_x * p.stride + k_x;
let out_y = inp_y * p.stride + k_y;
let out_x = inp_x * p.stride + k_x * p.dilation;
let out_y = inp_y * p.stride + k_y * p.dilation;
if out_x < p.padding || out_y < p.padding {
continue;
}
Expand Down
6 changes: 0 additions & 6 deletions candle-core/src/cuda_backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1046,12 +1046,6 @@ impl<'a> Map2 for ConvTranspose2D<'a> {
// Kernel shape: (c_in_k, c_out, h_k, w_k)
// Input shape: (b_size, c_in, h_in, w_in)
let p = &self.0;
if p.dilation != 1 {
crate::bail!(
"dilation {} is not supported for conv-transpose2d",
p.dilation
)
}
let (out_w, out_h) = (p.out_w(), p.out_h());
let dst_el = p.c_out * out_w * out_h * p.b_size;
let inp = &inp.slice(inp_l.start_offset()..);
Expand Down
35 changes: 35 additions & 0 deletions candle-core/tests/conv_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,10 @@ print(res)
res = torch.nn.functional.conv2d(t, w, dilation=2)
print(res.shape)
print(res[0])
res = torch.nn.functional.conv_transpose2d(t, w_t, dilation=2)
print(res.shape)
print(res)
*/
fn conv2d(dev: &Device) -> Result<()> {
let t = Tensor::new(
Expand Down Expand Up @@ -158,6 +162,37 @@ fn conv2d(dev: &Device) -> Result<()> {
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[2.45, -2.3504],
);

// Transpose and dilations.
let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 2)?;
assert_eq!(res.dims(), [1, 2, 9, 9]);
assert_eq!(
test_utils::to_vec3_round(&res.i(0)?, 4)?,
[
[
[-1.9918, 3.1652, -0.6778, -4.3442, 4.4351, 0.6652, -3.0124, -0.6031, 2.9277],
[2.7036, -1.7156, -0.3969, 1.0516, 1.6381, -2.8886, -0.205, 2.4682, -1.0499],
[-0.9459, 3.1631, 3.707, -4.8369, -8.5166, -1.4496, -2.7559, -3.2698, 1.4376],
[-0.2157, 3.7786, -2.0252, -4.2633, 3.6731, -1.5142, 5.9391, -0.2622, -0.141],
[-6.8121, -3.1744, 1.5945, 3.0637, -9.6088, 1.4446, 2.9489, -3.0082, -7.3822],
[0.2371, 3.3303, 0.3861, 2.2646, -4.6784, 4.1235, -0.0109, 0.3176, -0.03],
[-2.5339, -2.9564, -3.4518, -4.4594, -9.1873, -1.9709, -0.4676, 0.51, -3.5024],
[4.007, 0.3067, -2.2954, 1.1105, -0.1992, 1.6372, -2.9268, 0.2807, -1.2787],
[5.307, 1.1317, 1.3518, 0.9049, 3.8116, -0.4075, -0.8874, -0.2241, -0.9579]
],
[
[1.089, -0.6483, 0.0726, -0.4752, -1.3283, 1.7103, 1.0703, 0.1076, -0.9211],
[-0.8629, 0.1376, 0.3202, 2.0955, 0.9696, 2.8988, -1.0012, 1.5049, -0.1278],
[1.9286, -1.5255, -2.9563, 2.4589, 3.3611, -0.6951, 0.3525, -1.7724, -5.9861],
[1.1226, 2.1561, 3.6417, 4.7546, -0.692, 4.4126, -5.1902, 6.0805, 2.3185],
[1.0111, 0.3604, 0.6432, -3.6605, 7.9517, -9.2955, -5.2988, -3.7803, -2.0642],
[3.3172, -1.7967, -3.6576, -2.0942, 1.3158, 0.112, -1.7405, 2.9167, 0.7957],
[5.1001, 1.8995, -1.8639, 1.1262, 9.9629, 2.683, -3.6319, -1.1607, 0.5856],
[-4.8445, -0.5642, 4.2317, 0.0856, 1.2267, -0.5712, 1.736, 1.0997, 0.6908],
[-5.5423, -1.1831, -1.2176, 0.0843, 0.0446, -0.7545, -2.4798, -0.0827, 1.0171]
]
]
);
Ok(())
}

Expand Down
6 changes: 3 additions & 3 deletions candle-kernels/src/conv.cu
Original file line number Diff line number Diff line change
Expand Up @@ -155,15 +155,15 @@ __device__ void conv_transpose2d(
const size_t src_idx0 = b_idx * src_s[0];
A d = 0;
for (int k_x = 0; k_x < (int)w_k; ++k_x) {
// let out_x = inp_x * p.stride + k_x - p.padding;
int inp_x_stride = (int)(out_x + padding) - k_x;
// let out_x = inp_x * p.stride + k_x * p.dilation - p.padding;
int inp_x_stride = (int)(out_x + padding) - k_x * dilation;
if (inp_x_stride < 0 || inp_x_stride % stride) {
continue;
}
int inp_x = inp_x_stride / stride;
if (inp_x >= w_in) continue;
for (int k_y = 0; k_y < (int)h_k; ++k_y) {
int inp_y_stride = (int)(out_y + padding) - k_y;
int inp_y_stride = (int)(out_y + padding) - k_y * dilation;
if (inp_y_stride < 0 || inp_y_stride % stride) {
continue;
}
Expand Down

0 comments on commit 3936903

Please sign in to comment.