diff --git a/aten/src/ATen/native/Normalization.cpp b/aten/src/ATen/native/Normalization.cpp index bf29ee68a63db..4e27a5e047c98 100644 --- a/aten/src/ATen/native/Normalization.cpp +++ b/aten/src/ATen/native/Normalization.cpp @@ -485,13 +485,36 @@ std::tuple batch_norm_backward_cpu_template( return std::make_tuple(grad_input, grad_weight, grad_bias); } +template +bool checkAllTypeEq(const c10::ScalarType dtype, const Args&... args) +{ + return ((dtype == ((Tensor)args).scalar_type()) || ...); +} +template +bool checkAnyTypeEq(const Tensor& t, const Args&... args) +{ + return ((t.scalar_type() == ((c10::ScalarType)args)) || ...); +} +static bool PYTORCH_MIOPEN_EXTRA_LOGGING = c10::utils::check_env("PYTORCH_MIOPEN_EXTRA_LOGGING").value_or(false); BatchNormBackend _select_batch_norm_backend( const Tensor& input, const Tensor& weight, const Tensor& bias, const Tensor& running_mean, const Tensor& running_var, bool training, double eps) { auto& ctx = at::globalContext(); bool cudnn_enabled = ctx.userEnabledCuDNN(); + if (PYTORCH_MIOPEN_EXTRA_LOGGING) + std::cout + << "PYTORCH_MIOPEN_EXTRA_LOGGING: ********************* _select_batch_norm_backend" + << " cudnn_enabled=" << cudnn_enabled + << " input.dtype=" << input.scalar_type() << ":" << input.suggest_memory_format() + << " weight.dtype=" << (weight.defined()?"+":"-") << weight.scalar_type() + << " bias.dtype=" << (bias.defined()?"+":"-") << bias.scalar_type() + << " running_mean.dtype=" << (running_mean.defined()?"+":"-") << running_mean.scalar_type() + << " running_var.dtype=" << (running_mean.defined()?"+":"-") << running_mean.scalar_type() + << " training=" << training + << " input.dim=" << input.dim() + << std::endl; if ( input.is_cuda() && input.scalar_type() != at::kBFloat16 && weight.scalar_type() != at::kBFloat16 @@ -512,19 +535,27 @@ BatchNormBackend _select_batch_norm_backend( } if ( - input.is_cuda() + cudnn_enabled + && detail::getCUDAHooks().compiledWithMIOpen() + && input.is_cuda() + && (input.dim() >= 3) && input.dim() <= MIOPEN_DIM_MAX - && input.scalar_type() != at::kDouble - && (weight.scalar_type() != at::kHalf) - && (weight.scalar_type() != at::kBFloat16) && weight.defined() && bias.defined() - && ((running_mean.defined() && running_var.defined()) - || (!running_mean.defined() && !running_var.defined() && training)) - && (input.dim() >= 3) - && detail::getCUDAHooks().compiledWithMIOpen() - && cudnn_enabled - && input.suggest_memory_format() != MemoryFormat::ChannelsLast - && input.suggest_memory_format() != MemoryFormat::ChannelsLast3d + && ( + (running_mean.defined() && running_var.defined()) + || (!running_mean.defined() && !running_var.defined() && training) + ) + && ( + checkAllTypeEq(at::kFloat, input, weight, bias) && input.suggest_memory_format() == MemoryFormat::Contiguous // fp32 ocl + // || checkAllTypeEq(at::kHalf, input, weight, bias) && input.suggest_memory_format() == MemoryFormat::Contiguous // fp16 ocl + // || checkAllTypeEq(at::kBFloat16, input, weight, bias) && input.suggest_memory_format() == MemoryFormat::Contiguous// bf16 ocl + || checkAnyTypeEq(input, at::kHalf, at::kBFloat16) && checkAllTypeEq(at::kFloat, weight, bias) // && input.suggest_memory_format() == MemoryFormat::Contiguous // mixed + ) + // && input.scalar_type() != at::kDouble + // && (weight.scalar_type() != at::kHalf) + // && (weight.scalar_type() != at::kBFloat16) + // && input.suggest_memory_format() != MemoryFormat::ChannelsLast + // && input.suggest_memory_format() != MemoryFormat::ChannelsLast3d ) { return BatchNormBackend::Miopen; } @@ -532,8 +563,6 @@ BatchNormBackend _select_batch_norm_backend( return BatchNormBackend::Native; } -bool PYTORCH_MIOPEN_EXTRA_LOGGING = c10::utils::check_env("PYTORCH_MIOPEN_EXTRA_LOGGING").value_or(false); - // _batch_norm_impl_index(_backward) are used in the JIT be able to keep the run-time selection // of backends, while enabling it to keep the information about the used backend, so that it can // use its corresponding backward implementation. @@ -546,7 +575,7 @@ std::tuple _batch_norm_impl_index( if (PYTORCH_MIOPEN_EXTRA_LOGGING) std :: cout << "PYTORCH_MIOPEN_EXTRA_LOGGING: ********************* _batch_norm_impl_index" - << " input=" << input.scalar_type() + << " input=" << input.scalar_type() << ":" << input.suggest_memory_format() << " weight=" << (weight_opt.has_value() ? weight_opt.value().scalar_type() : at::ScalarType::Undefined) << " bias=" << (bias_opt.has_value() ? bias_opt.value().scalar_type() : at::ScalarType::Undefined) << " running_mean=" << (running_mean_opt.has_value() ? running_mean_opt.value().scalar_type() : at::ScalarType::Undefined) @@ -623,7 +652,7 @@ std::tuple _batch_norm_impl_index( << " cudnn_enabled=" << cudnn_enabled << " dim=" << input.dim() << " memory_format=" << input.suggest_memory_format() - << " input.dtype=" << input.scalar_type() + << " input.dtype=" << input.scalar_type() << ":" << input.suggest_memory_format() << " weight.dtype=" << (weight.defined()?"+":"-") << weight.scalar_type() << " bias.dtype=" << (bias.defined()?"+":"-") << bias.scalar_type() << " running_mean.dtype=" << (running_mean.defined()?"+":"-") << running_mean.scalar_type() @@ -636,7 +665,8 @@ std::tuple _batch_norm_impl_index( std::cout << "PYTORCH_MIOPEN_EXTRA_LOGGING: ********************* _batch_norm_impl_index (calling miopen_batch_norm)" << std::endl; return std::tuple_cat( at::miopen_batch_norm( - input.contiguous(), weight.contiguous(), bias.contiguous(), + input.contiguous(input.suggest_memory_format()), + weight.contiguous(), bias.contiguous(), running_mean.defined() ? running_mean.contiguous() : running_mean, running_var.defined() ? running_var.contiguous() : running_var, training, momentum, eps), @@ -711,7 +741,7 @@ Tensor batch_norm( if (PYTORCH_MIOPEN_EXTRA_LOGGING) std :: cout << "PYTORCH_MIOPEN_EXTRA_LOGGING: ********************* batch_norm" - << " input=" << input.scalar_type() + << " input=" << input.scalar_type() << ":" << input.suggest_memory_format() << " weight=" << (weight_opt.has_value() ? weight_opt.value().scalar_type() : at::ScalarType::Undefined) << " bias=" << (bias_opt.has_value() ? bias_opt.value().scalar_type() : at::ScalarType::Undefined) << " running_mean=" << (running_mean_opt.has_value() ? running_mean_opt.value().scalar_type() : at::ScalarType::Undefined) diff --git a/aten/src/ATen/native/miopen/BatchNorm_miopen.cpp b/aten/src/ATen/native/miopen/BatchNorm_miopen.cpp index b3d8b1dd86a2c..5935c7797c76a 100644 --- a/aten/src/ATen/native/miopen/BatchNorm_miopen.cpp +++ b/aten/src/ATen/native/miopen/BatchNorm_miopen.cpp @@ -100,7 +100,7 @@ std::tuple miopen_batch_norm( mode = miopenBNSpatial; } - auto output_t = at::empty(input->sizes(), input->options()); + auto output_t = at::empty(input->sizes(), input->options(), input->suggest_memory_format()); TensorArg output{ output_t, "output", 0 }; auto handle = getMiopenHandle(); @@ -177,8 +177,10 @@ std::tuple miopen_batch_norm_backward( const Tensor& save_var_t = c10::value_or_else(save_var_t_opt, [] { return Tensor(); }); + auto grad_output_contig = + grad_output_t.contiguous(input_t.suggest_memory_format()); TensorArg input{ input_t, "input", 1 }, - grad_output{ grad_output_t, "grad_output", 2 }, + grad_output{ grad_output_contig, "grad_output", 2 }, weight{ weight_t, "weight", 3 }, save_mean{ save_mean_t, "save_mean", 4 }, save_var{ save_var_t, "save_var", 5 }; @@ -193,7 +195,9 @@ std::tuple miopen_batch_norm_backward( } checkAllSameType(c, {input, grad_output}); checkAllSameType(c, {weight, save_mean, save_var}); - checkAllContiguous(c, {input, grad_output, save_mean, save_var}); + checkAllContiguous(c, {save_mean, save_var}); + TORCH_CHECK(input->is_contiguous(input->suggest_memory_format())); + TORCH_CHECK(grad_output->is_contiguous(input->suggest_memory_format())); checkDimRange(c, input, 2, 6 /* exclusive */); checkSameSize(c, input, grad_output); auto num_features = input->size(1); @@ -208,7 +212,8 @@ std::tuple miopen_batch_norm_backward( mode = miopenBNSpatial; } - auto grad_input_t = at::empty(input->sizes(), input->options()); + auto grad_input_t = at::empty( + input->sizes(), input->options(), input->suggest_memory_format()); auto grad_weight_t = at::empty(weight->sizes(), weight->options()); auto grad_bias_t = at::empty(weight->sizes(), weight->options()); diff --git a/test/test_nn.py b/test/test_nn.py index 03c1281b3b910..1d2d9084e81c8 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -9,11 +9,13 @@ import warnings import pickle import re +import os from copy import deepcopy from itertools import product from functools import partial from collections import OrderedDict from unittest import SkipTest +import traceback import torch from torch import inf, nan @@ -8198,6 +8200,113 @@ def test_affine_3d_rotateRandom(self, device): self.assertEqual(scipy_ary, gridsample_ary.reshape_as(scipy_ary)) + def wrap_assert(self, fn): + try: + fn() + except Exception as e: + print(">>>>>>>>> Failed with exception: ") + traceback.print_exception(e) + return False + return True + + def batchnorm2d_miopen(self, dtype, memory_format, mixed=False, use_cpu=False): + def run_test(input, grad_output, mixed, use_cpu=False, range_min=0.0, range_max=1.0): + c = input.size(1) + mod = nn.BatchNorm2d(c).cuda() + if not mixed: + mod = mod.to(dtype=input.dtype) + mod.weight.data.uniform_(range_min, range_max) + mod.bias.data.uniform_(range_min, range_max) + if use_cpu: + ref_mod = nn.BatchNorm2d(c) + ref_input = input.cpu().detach().clone(memory_format=torch.preserve_format).requires_grad_(True) + ref_grad = grad_output.cpu().detach().clone(memory_format=torch.preserve_format) + else: + ref_mod = nn.BatchNorm2d(c).cuda() + ref_input = input.detach().clone(memory_format=torch.preserve_format).requires_grad_(True) + ref_grad = grad_output.detach().clone(memory_format=torch.preserve_format) + + if not mixed: + ref_mod = ref_mod.to(dtype=input.dtype) + ref_mod.load_state_dict(mod.state_dict()) + out = mod(input) + out.backward(grad_output) + with torch.backends.cudnn.flags(enabled=False): # force to use native nhwc batchnorm + ref_out = ref_mod(ref_input) + ref_out.backward(ref_grad) + + success = self.wrap_assert(lambda: self.assertTrue(out.is_contiguous(memory_format=memory_format))) + success = success and self.wrap_assert(lambda: self.assertTrue(ref_out.is_contiguous(memory_format=memory_format))) + success = success and self.wrap_assert(lambda: self.assertEqual(out, ref_out)) + success = success and self.wrap_assert(lambda: self.assertEqual(mod.weight.grad, ref_mod.weight.grad)) + success = success and self.wrap_assert(lambda: self.assertEqual(mod.bias.grad, ref_mod.bias.grad)) + success = success and self.wrap_assert(lambda: self.assertEqual(mod.running_mean, ref_mod.running_mean)) + success = success and self.wrap_assert(lambda: self.assertEqual(mod.running_var, ref_mod.running_var)) + success = success and self.wrap_assert(lambda: self.assertEqual(input.grad, ref_input.grad)) + self.assertTrue(success) + + range_min = -2.0 + range_max = 2.0 + size = (4, 8, 2, 2) + # input = torch.randint(1, 10, size=size, dtype=dtype, device="cuda") + input = torch.FloatTensor(size=size).uniform_(range_min, range_max).to(dtype=dtype, device="cuda").contiguous(memory_format=memory_format).detach().requires_grad_() + # input = input.contiguous(memory_format=memory_format).detach().requires_grad_() + # grad = torch.randint(1, 10, size=size, dtype=dtype, device="cuda") + grad = torch.FloatTensor(size=size).uniform_(range_min, range_max).to(dtype=dtype, device="cuda").contiguous(memory_format=memory_format).detach() + run_test(input, grad, mixed=mixed, use_cpu=use_cpu) + # see #42588, grad is channels_last contiguous, but grad.suggest_memory_format (rightly) return "contiguous" + # not channels_last + input = torch.randint(1, 10, (2, 8, 8, 1), dtype=dtype, device="cuda") + input = input.contiguous(memory_format=memory_format).detach().requires_grad_() + grad = torch.randint(1, 10, (2, 8, 8, 1), dtype=dtype, device="cuda") + grad = grad.permute(0, 2, 1, 3) + run_test(input, grad, mixed=mixed, use_cpu=use_cpu, range_min=range_min, range_max=range_max) + + + @onlyCUDA + @dtypes(torch.float) + def test_batchnorm_nhwc_miopen(self, dtype): + # TODO: Remove PYTORCH_MIOPEN_SUGGEST_NHWC once ROCm officially supports NHWC in MIOpen + PYTORCH_MIOPEN_SUGGEST_NHWC = "PYTORCH_MIOPEN_SUGGEST_NHWC" + prev_val = os.getenv(PYTORCH_MIOPEN_SUGGEST_NHWC) + try: + os.environ[PYTORCH_MIOPEN_SUGGEST_NHWC] = "1" + self.batchnorm2d_miopen(dtype, torch.channels_last) + finally: + if prev_val is None: + del os.environ[PYTORCH_MIOPEN_SUGGEST_NHWC] + else: + os.environ[PYTORCH_MIOPEN_SUGGEST_NHWC] = prev_val + + @onlyCUDA + @dtypes(torch.half, torch.bfloat16) + def test_batchnorm_nhwc_miopen_mixed(self, dtype): + # TODO: Remove PYTORCH_MIOPEN_SUGGEST_NHWC once ROCm officially supports NHWC in MIOpen + PYTORCH_MIOPEN_SUGGEST_NHWC = "PYTORCH_MIOPEN_SUGGEST_NHWC" + prev_val = os.getenv(PYTORCH_MIOPEN_SUGGEST_NHWC) + try: + os.environ[PYTORCH_MIOPEN_SUGGEST_NHWC] = "1" + self.batchnorm2d_miopen(dtype, torch.channels_last, mixed=True) + finally: + if prev_val is None: + del os.environ[PYTORCH_MIOPEN_SUGGEST_NHWC] + else: + os.environ[PYTORCH_MIOPEN_SUGGEST_NHWC] = prev_val + + @onlyCUDA + @dtypes(torch.half, torch.bfloat16, torch.float) + def test_batchnorm_nchw_miopen(self, dtype): + self.batchnorm2d_miopen(dtype, torch.contiguous_format) + + @onlyCUDA + @dtypes(torch.half, torch.bfloat16) + def test_batchnorm_nchw_miopen_mixed(self, dtype): + self.batchnorm2d_miopen(dtype, torch.contiguous_format, mixed=True) + + @onlyCUDA + @dtypes(torch.half, torch.bfloat16) + def test_batchnorm_nchw_miopen_mixed_vs_cpu(self, dtype): + self.batchnorm2d_miopen(dtype, torch.contiguous_format, mixed=True, use_cpu=True) @onlyCUDA @dtypes(torch.float, torch.half)