From 017f19009dbda628411c38bdd4199cc099356e23 Mon Sep 17 00:00:00 2001 From: Nathaniel Simard Date: Thu, 1 Feb 2024 12:26:20 -0500 Subject: [PATCH] Fix problems with batch norm on LibTorch backend (#1226) * Fix problems with batch norm * Add comment --- burn-core/src/nn/norm/batch.rs | 5 +++-- burn-tch/src/ops/base.rs | 17 ++++++++++++++++- burn-tch/src/ops/bool_tensor.rs | 2 +- burn-tch/src/ops/int_tensor.rs | 2 +- burn-tch/src/ops/tensor.rs | 2 +- 5 files changed, 22 insertions(+), 6 deletions(-) diff --git a/burn-core/src/nn/norm/batch.rs b/burn-core/src/nn/norm/batch.rs index cf4977b338..42239d1e3d 100644 --- a/burn-core/src/nn/norm/batch.rs +++ b/burn-core/src/nn/norm/batch.rs @@ -106,6 +106,7 @@ impl BatchNorm { } fn forward_train(&self, input: Tensor) -> Tensor { + let device = input.device(); let dims = input.dims(); let batch_size = dims[0]; let channels = dims[1]; @@ -134,8 +135,8 @@ impl BatchNorm { .mean_dim(1) .reshape(shape_unsqueeze); - let running_mean = self.running_mean.value_sync().to_device(&mean.device()); - let running_var = self.running_var.value_sync().to_device(&var.device()); + let running_mean = self.running_mean.value_sync().to_device(&device); + let running_var = self.running_var.value_sync().to_device(&device); let running_mean = running_mean.mul_scalar(1.0 - self.momentum).add( mean.clone() diff --git a/burn-tch/src/ops/base.rs b/burn-tch/src/ops/base.rs index 67f8312c41..d22bd3084a 100644 --- a/burn-tch/src/ops/base.rs +++ b/burn-tch/src/ops/base.rs @@ -1,7 +1,7 @@ use burn_tensor::Shape; use tch::Scalar; -use crate::{TchShape, TchTensor}; +use crate::{LibTorchDevice, TchShape, TchTensor}; use std::{marker::PhantomData, ops::Range}; pub struct TchOps { @@ -9,6 +9,21 @@ pub struct TchOps { } impl TchOps { + pub fn to_device( + tensor: TchTensor, + device: &LibTorchDevice, + ) -> TchTensor { + let device = (*device).into(); + + // We have to manually check if the device is the same, since when it's the case, we need to keep + // the same storage reference and not create a new one. + if tensor.tensor.device() == device { + return tensor; + } + + TchTensor::new(tensor.tensor.to(device)) + } + pub fn reshape( tensor: TchTensor, shape: Shape, diff --git a/burn-tch/src/ops/bool_tensor.rs b/burn-tch/src/ops/bool_tensor.rs index d50b642bbf..13f4b9c882 100644 --- a/burn-tch/src/ops/bool_tensor.rs +++ b/burn-tch/src/ops/bool_tensor.rs @@ -35,7 +35,7 @@ impl BoolTensorOps for LibTorch { tensor: TchTensor, device: &LibTorchDevice, ) -> TchTensor { - TchTensor::new(tensor.tensor.to((*device).into())) + TchOps::to_device(tensor, device) } fn bool_reshape( diff --git a/burn-tch/src/ops/int_tensor.rs b/burn-tch/src/ops/int_tensor.rs index b9c9994636..9f251e451c 100644 --- a/burn-tch/src/ops/int_tensor.rs +++ b/burn-tch/src/ops/int_tensor.rs @@ -38,7 +38,7 @@ impl IntTensorOps for LibTorch { tensor: TchTensor, device: &LibTorchDevice, ) -> TchTensor { - TchTensor::new(tensor.tensor.to((*device).into())) + TchOps::to_device(tensor, device) } fn int_reshape( diff --git a/burn-tch/src/ops/tensor.rs b/burn-tch/src/ops/tensor.rs index 88c3aec97a..12b8e60016 100644 --- a/burn-tch/src/ops/tensor.rs +++ b/burn-tch/src/ops/tensor.rs @@ -102,7 +102,7 @@ impl FloatTensorOps for LibTorch { tensor: TchTensor, device: &LibTorchDevice, ) -> TchTensor { - TchTensor::new(tensor.tensor.to((*device).into())) + TchOps::to_device(tensor, device) } fn float_empty(