Fix problems with batch norm on LibTorch backend (#1226)
* Fix problems with batch norm

* Add comment
nathanielsimard authored and syl20bnr committed Feb 1, 2024
1 parent 0552b8b commit 017f190
Showing 5 changed files with 22 additions and 6 deletions.
5 changes: 3 additions & 2 deletions burn-core/src/nn/norm/batch.rs
@@ -106,6 +106,7 @@ impl<const D: usize, B: Backend> BatchNorm<B, D> {
     }
 
     fn forward_train<const DI: usize>(&self, input: Tensor<B, DI>) -> Tensor<B, DI> {
+        let device = input.device();
         let dims = input.dims();
         let batch_size = dims[0];
         let channels = dims[1];
@@ -134,8 +135,8 @@ impl<const D: usize, B: Backend> BatchNorm<B, D> {
             .mean_dim(1)
             .reshape(shape_unsqueeze);
 
-        let running_mean = self.running_mean.value_sync().to_device(&mean.device());
-        let running_var = self.running_var.value_sync().to_device(&var.device());
+        let running_mean = self.running_mean.value_sync().to_device(&device);
+        let running_var = self.running_var.value_sync().to_device(&device);
 
         let running_mean = running_mean.mul_scalar(1.0 - self.momentum).add(
             mean.clone()
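For reference, the quantity being updated here is the standard exponential moving average of the batch statistics. The hunk is cut off right after `mean.clone()`, so the momentum-weighted term below is inferred from the usual batch norm formulation rather than visible in the diff:

    \mu_{run} \leftarrow (1 - m)\,\mu_{run} + m\,\mu_{batch}
    \sigma^2_{run} \leftarrow (1 - m)\,\sigma^2_{run} + m\,\sigma^2_{batch}

with m = `self.momentum`. The fix does not touch this arithmetic; it only ensures that `running_mean` and `running_var` are moved to the device captured once from `input` at the top of `forward_train`, rather than to devices read back from the freshly computed `mean` and `var` tensors.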
17 changes: 16 additions & 1 deletion burn-tch/src/ops/base.rs
@@ -1,14 +1,29 @@
 use burn_tensor::Shape;
 use tch::Scalar;
 
-use crate::{TchShape, TchTensor};
+use crate::{LibTorchDevice, TchShape, TchTensor};
 use std::{marker::PhantomData, ops::Range};
 
 pub struct TchOps<E: tch::kind::Element + Copy + Default> {
     e: PhantomData<E>,
 }
 
 impl<E: tch::kind::Element + Copy + Default> TchOps<E> {
+    pub fn to_device<const D: usize>(
+        tensor: TchTensor<E, D>,
+        device: &LibTorchDevice,
+    ) -> TchTensor<E, D> {
+        let device = (*device).into();
+
+        // We have to manually check if the device is the same, since when it's the case, we need to keep
+        // the same storage reference and not create a new one.
+        if tensor.tensor.device() == device {
+            return tensor;
+        }
+
+        TchTensor::new(tensor.tensor.to(device))
+    }
+
     pub fn reshape<const D1: usize, const D2: usize>(
         tensor: TchTensor<E, D1>,
         shape: Shape<D2>,
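The early return is the substance of this helper. Every call site below previously re-wrapped `tensor.tensor.to((*device).into())` in `TchTensor::new`, producing a fresh tensor reference even when the tensor was already on the target device; as the in-diff comment notes, that case must instead keep the existing storage reference. A minimal standalone sketch of the same early-return pattern, written against the `tch` crate directly rather than burn's wrappers (`to_device_checked` is a hypothetical name used for illustration only):

    use tch::{Device, Tensor};

    // Mirrors the check added in TchOps::to_device: if the tensor already
    // lives on the target device, hand back the same tensor (and thus the
    // same storage) instead of materializing a new one.
    fn to_device_checked(t: Tensor, device: Device) -> Tensor {
        if t.device() == device {
            return t; // same tensor, no move performed
        }
        t.to(device)
    }

    fn main() {
        let t = Tensor::from_slice(&[1.0f32, 2.0, 3.0]); // created on the CPU
        // Already on the CPU, so the helper returns the original tensor.
        let same = to_device_checked(t, Device::Cpu);
        same.print();
    }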
2 changes: 1 addition & 1 deletion burn-tch/src/ops/bool_tensor.rs
@@ -35,7 +35,7 @@ impl<E: TchElement> BoolTensorOps<Self> for LibTorch<E> {
         tensor: TchTensor<bool, D>,
         device: &LibTorchDevice,
     ) -> TchTensor<bool, D> {
-        TchTensor::new(tensor.tensor.to((*device).into()))
+        TchOps::to_device(tensor, device)
     }
 
     fn bool_reshape<const D1: usize, const D2: usize>(
2 changes: 1 addition & 1 deletion burn-tch/src/ops/int_tensor.rs
@@ -38,7 +38,7 @@ impl<E: TchElement> IntTensorOps<Self> for LibTorch<E> {
         tensor: TchTensor<i64, D>,
         device: &LibTorchDevice,
     ) -> TchTensor<i64, D> {
-        TchTensor::new(tensor.tensor.to((*device).into()))
+        TchOps::to_device(tensor, device)
     }
 
     fn int_reshape<const D1: usize, const D2: usize>(
2 changes: 1 addition & 1 deletion burn-tch/src/ops/tensor.rs
@@ -102,7 +102,7 @@ impl<E: TchElement> FloatTensorOps<Self> for LibTorch<E> {
         tensor: TchTensor<E, D>,
         device: &LibTorchDevice,
     ) -> TchTensor<E, D> {
-        TchTensor::new(tensor.tensor.to((*device).into()))
+        TchOps::to_device(tensor, device)
     }
 
     fn float_empty<const D: usize>(
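With these three call sites updated, the bool, int, and float `*_to_device` implementations all funnel through the shared `TchOps::to_device`, so the same-device short circuit applies uniformly to every tensor kind instead of each ops file re-implementing the unconditional move.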
