diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs index 4a061c3977..17d64b1010 100644 --- a/candle-core/src/cpu_backend.rs +++ b/candle-core/src/cpu_backend.rs @@ -1186,12 +1186,6 @@ impl<'a> Map2 for ConvTranspose2D<'a> { const OP: &'static str = "conv_transpose2d"; fn f(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result> { let p = self.0; - if p.dilation != 1 { - crate::bail!( - "dilation {} is not supported for conv-transpose2d", - p.dilation - ) - } let inp = &inp[inp_l.start_offset()..]; let (inp_s0, inp_s1, inp_s2, inp_s3) = crate::shape::dims4(inp_l.stride())?; let k = &k[k_l.start_offset()..]; @@ -1235,8 +1229,8 @@ impl<'a> Map2 for ConvTranspose2D<'a> { for b_idx in 0..p.b_size { for inp_y in 0..p.i_h { for inp_x in 0..p.i_w { - let out_x = inp_x * p.stride + k_x; - let out_y = inp_y * p.stride + k_y; + let out_x = inp_x * p.stride + k_x * p.dilation; + let out_y = inp_y * p.stride + k_y * p.dilation; if out_x < p.padding || out_y < p.padding { continue; } diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs index 14a77b5201..663f231929 100644 --- a/candle-core/src/cuda_backend.rs +++ b/candle-core/src/cuda_backend.rs @@ -1046,12 +1046,6 @@ impl<'a> Map2 for ConvTranspose2D<'a> { // Kernel shape: (c_in_k, c_out, h_k, w_k) // Input shape: (b_size, c_in, h_in, w_in) let p = &self.0; - if p.dilation != 1 { - crate::bail!( - "dilation {} is not supported for conv-transpose2d", - p.dilation - ) - } let (out_w, out_h) = (p.out_w(), p.out_h()); let dst_el = p.c_out * out_w * out_h * p.b_size; let inp = &inp.slice(inp_l.start_offset()..); diff --git a/candle-core/tests/conv_tests.rs b/candle-core/tests/conv_tests.rs index 8196a27efc..937ddf6761 100644 --- a/candle-core/tests/conv_tests.rs +++ b/candle-core/tests/conv_tests.rs @@ -85,6 +85,10 @@ print(res) res = torch.nn.functional.conv2d(t, w, dilation=2) print(res.shape) print(res[0]) + +res = torch.nn.functional.conv_transpose2d(t, w_t, dilation=2) +print(res.shape) +print(res) */ fn conv2d(dev: &Device) -> Result<()> { let t = Tensor::new( @@ -158,6 +162,37 @@ fn conv2d(dev: &Device) -> Result<()> { test_utils::to_vec1_round(&res.flatten_all()?, 4)?, [2.45, -2.3504], ); + + // Transpose and dilations. + let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 2)?; + assert_eq!(res.dims(), [1, 2, 9, 9]); + assert_eq!( + test_utils::to_vec3_round(&res.i(0)?, 4)?, + [ + [ + [-1.9918, 3.1652, -0.6778, -4.3442, 4.4351, 0.6652, -3.0124, -0.6031, 2.9277], + [2.7036, -1.7156, -0.3969, 1.0516, 1.6381, -2.8886, -0.205, 2.4682, -1.0499], + [-0.9459, 3.1631, 3.707, -4.8369, -8.5166, -1.4496, -2.7559, -3.2698, 1.4376], + [-0.2157, 3.7786, -2.0252, -4.2633, 3.6731, -1.5142, 5.9391, -0.2622, -0.141], + [-6.8121, -3.1744, 1.5945, 3.0637, -9.6088, 1.4446, 2.9489, -3.0082, -7.3822], + [0.2371, 3.3303, 0.3861, 2.2646, -4.6784, 4.1235, -0.0109, 0.3176, -0.03], + [-2.5339, -2.9564, -3.4518, -4.4594, -9.1873, -1.9709, -0.4676, 0.51, -3.5024], + [4.007, 0.3067, -2.2954, 1.1105, -0.1992, 1.6372, -2.9268, 0.2807, -1.2787], + [5.307, 1.1317, 1.3518, 0.9049, 3.8116, -0.4075, -0.8874, -0.2241, -0.9579] + ], + [ + [1.089, -0.6483, 0.0726, -0.4752, -1.3283, 1.7103, 1.0703, 0.1076, -0.9211], + [-0.8629, 0.1376, 0.3202, 2.0955, 0.9696, 2.8988, -1.0012, 1.5049, -0.1278], + [1.9286, -1.5255, -2.9563, 2.4589, 3.3611, -0.6951, 0.3525, -1.7724, -5.9861], + [1.1226, 2.1561, 3.6417, 4.7546, -0.692, 4.4126, -5.1902, 6.0805, 2.3185], + [1.0111, 0.3604, 0.6432, -3.6605, 7.9517, -9.2955, -5.2988, -3.7803, -2.0642], + [3.3172, -1.7967, -3.6576, -2.0942, 1.3158, 0.112, -1.7405, 2.9167, 0.7957], + [5.1001, 1.8995, -1.8639, 1.1262, 9.9629, 2.683, -3.6319, -1.1607, 0.5856], + [-4.8445, -0.5642, 4.2317, 0.0856, 1.2267, -0.5712, 1.736, 1.0997, 0.6908], + [-5.5423, -1.1831, -1.2176, 0.0843, 0.0446, -0.7545, -2.4798, -0.0827, 1.0171] + ] + ] + ); Ok(()) } diff --git a/candle-kernels/src/conv.cu b/candle-kernels/src/conv.cu index 91f4c7b298..ba2fa1adbf 100644 --- a/candle-kernels/src/conv.cu +++ b/candle-kernels/src/conv.cu @@ -155,15 +155,15 @@ __device__ void conv_transpose2d( const size_t src_idx0 = b_idx * src_s[0]; A d = 0; for (int k_x = 0; k_x < (int)w_k; ++k_x) { - // let out_x = inp_x * p.stride + k_x - p.padding; - int inp_x_stride = (int)(out_x + padding) - k_x; + // let out_x = inp_x * p.stride + k_x * p.dilation - p.padding; + int inp_x_stride = (int)(out_x + padding) - k_x * dilation; if (inp_x_stride < 0 || inp_x_stride % stride) { continue; } int inp_x = inp_x_stride / stride; if (inp_x >= w_in) continue; for (int k_y = 0; k_y < (int)h_k; ++k_y) { - int inp_y_stride = (int)(out_y + padding) - k_y; + int inp_y_stride = (int)(out_y + padding) - k_y * dilation; if (inp_y_stride < 0 || inp_y_stride % stride) { continue; }