Skip to content

Commit

Permalink
Backprop for conv2d. (huggingface#638)
Browse files Browse the repository at this point in the history
* Start adding backprop for conv2d.

* Backprop for conv2d.

* Bugfix + start adding a conv2d test.

* Conv2d backprop testing.

* More conv fixes.
  • Loading branch information
LaurentMazare committed Aug 28, 2023
1 parent 9137c63 commit b292047
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 13 deletions.
25 changes: 24 additions & 1 deletion candle-core/src/backprop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,30 @@ impl Tensor {
*f_sum_grad = f_sum_grad.add(&f_grad)?;
}
Op::Conv1D { .. } => Err(Error::BackwardNotSupported { op: "conv1d" })?,
Op::Conv2D { .. } => Err(Error::BackwardNotSupported { op: "conv2d" })?,
Op::Conv2D {
arg,
kernel,
padding,
stride,
} => {
// The output height for conv_transpose2d is:
// (i_h - 1) * stride - 2 * padding + dilation * (k_h - 1) + out_padding + 1
let grad_h = grad.dim(2)?;
let k_h = kernel.dim(2)?;
let out_size = (grad_h - 1) * stride + (k_h - 1) + 1 - 2 * padding;
let out_padding = arg.dim(2)? - out_size;
let grad_arg =
grad.conv_transpose2d(kernel, *padding, out_padding, *stride)?;
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&grad_arg)?;

let grad_kernel = arg
.transpose(0, 1)?
.conv2d(&grad.transpose(0, 1)?, *padding, *stride, 1)?
.transpose(0, 1)?;
let sum_grad = grads.or_insert(kernel)?;
*sum_grad = sum_grad.add(&grad_kernel)?;
}
Op::ConvTranspose2D { .. } => Err(Error::BackwardNotSupported {
op: "conv-transpose2d",
})?,
Expand Down
12 changes: 4 additions & 8 deletions candle-core/src/conv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,18 +71,14 @@ pub struct ParamsConvTranspose2D {
impl ParamsConvTranspose2D {
pub(crate) fn out_h(&self) -> usize {
let dilation = 1;
(self.i_h - 1) * self.stride - 2 * self.padding
+ dilation * (self.k_h - 1)
+ self.output_padding
+ 1
(self.i_h - 1) * self.stride + dilation * (self.k_h - 1) + self.output_padding + 1
- 2 * self.padding
}

pub(crate) fn out_w(&self) -> usize {
let dilation = 1;
(self.i_w - 1) * self.stride - 2 * self.padding
+ dilation * (self.k_w - 1)
+ self.output_padding
+ 1
(self.i_w - 1) * self.stride + dilation * (self.k_w - 1) + self.output_padding + 1
- 2 * self.padding
}

pub(crate) fn out_dims(&self) -> Vec<usize> {
Expand Down
10 changes: 6 additions & 4 deletions candle-core/src/cpu_backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1204,7 +1204,7 @@ impl<'a> Map2 for ConvTranspose2D<'a> {
let inp_x = out_x * p.stride as i32 - p.padding as i32;
let inp_y = out_y * p.stride as i32 - p.padding as i32;
for k_y in 0..p.k_h as i32 {
for k_x in 0..p.k_h as i32 {
for k_x in 0..p.k_w as i32 {
let k_index = k_y as usize * k_s2 + k_x as usize * k_s3;
let inp_y = inp_y + k_y;
let inp_x = inp_x + k_x;
Expand All @@ -1215,9 +1215,11 @@ impl<'a> Map2 for ConvTranspose2D<'a> {
let inp_y = inp_y as usize;
if inp_x < p.i_w && inp_y < p.i_h {
let inp_index = b_idx * inp_s0 + inp_y * inp_s2 + inp_x * inp_s3;
let dst_index = b_idx * dst_s0 + inp_y * dst_s2 + inp_x * dst_s3;
for c_out in 0..k_s0 {
for c_in in 0..k_s1 {
let dst_index = b_idx * dst_s0
+ out_y as usize * dst_s2
+ out_x as usize * dst_s3;
for c_out in 0..p.c_out {
for c_in in 0..p.c_in {
let k_index = k_index + c_out * k_s1 + c_in * k_s0;
let dst_index = dst_index + c_out * dst_s1;
let inp_index = inp_index + c_in * inp_s1;
Expand Down
72 changes: 72 additions & 0 deletions candle-core/tests/conv_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,78 @@ fn conv2d_non_square(dev: &Device) -> Result<()> {
Ok(())
}

#[test]
fn conv2d_grad() -> Result<()> {
use candle_core::Var;
let dev = &Device::Cpu;
let t = Var::from_slice(
&[
0.4056f32, -0.8689, -0.0773, -1.5630, -2.8012, -1.5059, 0.3972, 1.0852, 0.4997, 3.0616,
1.6541, 0.0964, -0.8338, -1.6523, -0.8323, -0.1699, 0.0823, 0.3526, 0.6843, 0.2395,
1.2279, -0.9287, -1.7030, 0.1370, 0.6047, 0.3770, -0.6266, 0.3529, 2.2013, -0.6836,
0.2477, 1.3127, -0.2260, 0.2622, -1.2974, -0.8140, -0.8404, -0.3490, 0.0130, 1.3123,
1.7569, -0.3956, -1.8255, 0.1727, -0.3538, 2.6941, 1.0529, 0.4219, -0.2071, 1.1586,
0.4717, 0.3865, -0.5690, -0.5010, -0.1310, 0.7796, 0.6630, -0.2021, 2.6090, 0.2049,
0.6466, -0.5042, -0.0603, -1.6538, -1.2429, 1.8357, 1.6052, -1.3844, 0.3323, -1.3712,
0.9634, -0.4799, -0.6451, -0.0840, -1.4247, 0.5512, -0.1747, -0.5509, -0.3742, 0.3790,
-0.4431, -0.4720, -0.7890, 0.2620, 0.7875, 0.5377, -0.6779, -0.8088, 1.9098, 1.2006,
-0.8000, -0.4983, 1.5480, 0.8265, -0.1025, 0.5138, 0.5748, 0.3821, -0.4607, 0.0085,
],
(1, 4, 5, 5),
dev,
)?;
let w = Var::from_slice(
&[
-0.9325f32, 0.6451, -0.8537, 0.2378, 0.8764, -0.1832, 0.2987, -0.6488, -0.2273,
-2.4184, -0.1192, -0.4821, -0.5079, -0.5766, -2.4729, 1.6734, 0.4558, 0.2851, 1.1514,
-0.9013, 1.0662, -0.1817, -0.0259, 0.1709, 0.5367, 0.7513, 0.8086, -2.2586, -0.5027,
0.9141, -1.3086, -1.3343, -1.5669, -0.1657, 0.7958, 0.1432, 0.3896, -0.4501, 0.1667,
0.0714, -0.0952, 1.2970, -0.1674, -0.3178, 1.0677, 0.3060, 0.7080, 0.1914, 1.1679,
-0.3602, 1.9265, -1.8626, -0.5112, -0.0982, 0.2621, 0.6565, 0.5908, 1.0089, -0.1646,
1.8032, -0.6286, 0.2016, -0.3370, 1.2555, 0.8009, -0.6488, -0.4652, -1.5685, 1.5860,
0.5583, 0.4623, 0.6026,
],
(2, 4, 3, 3),
dev,
)?;
let res = t.conv2d(&w, 0, 1, 1)?;
let loss = res.sqr()?.sum_all()?;
assert_eq!(test_utils::to_vec0_round(&loss, 2)?, 741.12f32);
let grads = loss.backward()?;
let grad_t = grads.get(&t).unwrap();
let grad_w = grads.get(&w).unwrap();
assert_eq!(grad_t.dims(), [1, 4, 5, 5]);
assert_eq!(grad_w.dims(), [2, 4, 3, 3]);
assert_eq!(
test_utils::to_vec1_round(&grad_t.flatten_all()?, 4)?,
// THIS IS WRONG AT THE MOMENT
[
1.7442, -10.1747, -9.9426, 0.0, 0.0, -1.7046, -21.2248, 30.8435, 0.0, 0.0, -18.713,
-1.0547, -7.8746, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 169.3047,
46.0812, 40.6937, 0.0, 0.0, -85.8156, 4.537, 53.2871, 0.0, 0.0, -59.632, -35.9725,
-7.1689, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 48.823, 8.9794,
42.3011, 0.0, 0.0, -58.9268, 32.907, -50.6863, 0.0, 0.0, -0.9706, -3.9175, -4.2594,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 72.8229, 25.8492, 8.9871,
0.0, 0.0, -136.2584, 40.1739, 88.9583, 0.0, 0.0, -53.465, -40.7102, -24.9406, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0
]
);
assert_eq!(
test_utils::to_vec1_round(&grad_w.flatten_all()?, 4)?,
[
-28.9232, -22.8833, -141.2296, 73.3462, 61.074, 47.8125, -20.0013, -73.7086, -41.8217,
-13.5919, 21.501, 28.7179, 28.5683, -46.8486, -90.1874, 143.6107, 16.6764, 7.4259,
18.8794, -90.8122, -20.2865, 54.7909, 82.6287, 22.943, 77.8084, -16.3928, -13.1977,
9.3442, -40.3869, -26.6153, 5.3344, -60.9081, 9.0869, -59.368, 7.081, 58.6391, 5.5476,
20.5152, 2.4985, -17.2466, -6.802, 22.2146, 30.1511, -7.5179, -37.4588, 5.6654,
22.5832, 9.0316, 47.0547, 17.6123, 37.3121, -98.1295, -14.6141, -4.7958, -6.3597,
44.6949, 23.3418, 8.3728, -13.52, 80.0522, -34.2403, -16.3648, -12.3139, 1.9195,
-33.6244, -14.102, -49.2305, -7.3853, 11.4995, -9.9826, 9.6588, 29.6042
]
);
Ok(())
}

test_device!(conv1d, conv1d_cpu, conv1d_gpu);
test_device!(conv1d_small, conv1d_small_cpu, conv1d_small_gpu);
test_device!(conv2d, conv2d_cpu, conv2d_gpu);
Expand Down

0 comments on commit b292047

Please sign in to comment.