Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP][release/2.5] refactor condition to use miopen for batchnorm #1787

Draft
wants to merge 3 commits into
base: release/2.5
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 47 additions & 17 deletions aten/src/ATen/native/Normalization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -485,13 +485,36 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cpu_template(
return std::make_tuple(grad_input, grad_weight, grad_bias);
}

template<class... Args>
bool checkAllTypeEq(const c10::ScalarType dtype, const Args&... args)
{
return ((dtype == ((Tensor)args).scalar_type()) || ...);
}
template<class... Args>
bool checkAnyTypeEq(const Tensor& t, const Args&... args)
{
return ((t.scalar_type() == ((c10::ScalarType)args)) || ...);
}
static bool PYTORCH_MIOPEN_EXTRA_LOGGING = c10::utils::check_env("PYTORCH_MIOPEN_EXTRA_LOGGING").value_or(false);
BatchNormBackend _select_batch_norm_backend(
const Tensor& input, const Tensor& weight, const Tensor& bias, const Tensor& running_mean,
const Tensor& running_var, bool training, double eps) {

auto& ctx = at::globalContext();
bool cudnn_enabled = ctx.userEnabledCuDNN();
if (PYTORCH_MIOPEN_EXTRA_LOGGING)
std::cout
<< "PYTORCH_MIOPEN_EXTRA_LOGGING: ********************* _select_batch_norm_backend"
<< " cudnn_enabled=" << cudnn_enabled
<< " input.dtype=" << input.scalar_type() << ":" << input.suggest_memory_format()
<< " weight.dtype=" << (weight.defined()?"+":"-") << weight.scalar_type()
<< " bias.dtype=" << (bias.defined()?"+":"-") << bias.scalar_type()
<< " running_mean.dtype=" << (running_mean.defined()?"+":"-") << running_mean.scalar_type()
<< " running_var.dtype=" << (running_mean.defined()?"+":"-") << running_mean.scalar_type()
<< " training=" << training
<< " input.dim=" << input.dim()

<< std::endl;
if (
input.is_cuda()
&& input.scalar_type() != at::kBFloat16 && weight.scalar_type() != at::kBFloat16
Expand All @@ -512,28 +535,34 @@ BatchNormBackend _select_batch_norm_backend(
}

if (
input.is_cuda()
cudnn_enabled
&& detail::getCUDAHooks().compiledWithMIOpen()
&& input.is_cuda()
&& (input.dim() >= 3)
&& input.dim() <= MIOPEN_DIM_MAX
&& input.scalar_type() != at::kDouble
&& (weight.scalar_type() != at::kHalf)
&& (weight.scalar_type() != at::kBFloat16)
&& weight.defined() && bias.defined()
&& ((running_mean.defined() && running_var.defined())
|| (!running_mean.defined() && !running_var.defined() && training))
&& (input.dim() >= 3)
&& detail::getCUDAHooks().compiledWithMIOpen()
&& cudnn_enabled
&& input.suggest_memory_format() != MemoryFormat::ChannelsLast
&& input.suggest_memory_format() != MemoryFormat::ChannelsLast3d
&& (
(running_mean.defined() && running_var.defined())
|| (!running_mean.defined() && !running_var.defined() && training)
)
&& (
checkAllTypeEq(at::kFloat, input, weight, bias) && input.suggest_memory_format() == MemoryFormat::Contiguous // fp32 ocl
// || checkAllTypeEq(at::kHalf, input, weight, bias) && input.suggest_memory_format() == MemoryFormat::Contiguous // fp16 ocl
// || checkAllTypeEq(at::kBFloat16, input, weight, bias) && input.suggest_memory_format() == MemoryFormat::Contiguous// bf16 ocl
|| checkAnyTypeEq(input, at::kHalf, at::kBFloat16) && checkAllTypeEq(at::kFloat, weight, bias) // && input.suggest_memory_format() == MemoryFormat::Contiguous // mixed
)
// && input.scalar_type() != at::kDouble
// && (weight.scalar_type() != at::kHalf)
// && (weight.scalar_type() != at::kBFloat16)
// && input.suggest_memory_format() != MemoryFormat::ChannelsLast
// && input.suggest_memory_format() != MemoryFormat::ChannelsLast3d
) {
return BatchNormBackend::Miopen;
}

return BatchNormBackend::Native;
}

bool PYTORCH_MIOPEN_EXTRA_LOGGING = c10::utils::check_env("PYTORCH_MIOPEN_EXTRA_LOGGING").value_or(false);

// _batch_norm_impl_index(_backward) are used in the JIT be able to keep the run-time selection
// of backends, while enabling it to keep the information about the used backend, so that it can
// use its corresponding backward implementation.
Expand All @@ -546,7 +575,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t> _batch_norm_impl_index(
if (PYTORCH_MIOPEN_EXTRA_LOGGING)
std :: cout
<< "PYTORCH_MIOPEN_EXTRA_LOGGING: ********************* _batch_norm_impl_index"
<< " input=" << input.scalar_type()
<< " input=" << input.scalar_type() << ":" << input.suggest_memory_format()
<< " weight=" << (weight_opt.has_value() ? weight_opt.value().scalar_type() : at::ScalarType::Undefined)
<< " bias=" << (bias_opt.has_value() ? bias_opt.value().scalar_type() : at::ScalarType::Undefined)
<< " running_mean=" << (running_mean_opt.has_value() ? running_mean_opt.value().scalar_type() : at::ScalarType::Undefined)
Expand Down Expand Up @@ -623,7 +652,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t> _batch_norm_impl_index(
<< " cudnn_enabled=" << cudnn_enabled
<< " dim=" << input.dim()
<< " memory_format=" << input.suggest_memory_format()
<< " input.dtype=" << input.scalar_type()
<< " input.dtype=" << input.scalar_type() << ":" << input.suggest_memory_format()
<< " weight.dtype=" << (weight.defined()?"+":"-") << weight.scalar_type()
<< " bias.dtype=" << (bias.defined()?"+":"-") << bias.scalar_type()
<< " running_mean.dtype=" << (running_mean.defined()?"+":"-") << running_mean.scalar_type()
Expand All @@ -636,7 +665,8 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t> _batch_norm_impl_index(
std::cout << "PYTORCH_MIOPEN_EXTRA_LOGGING: ********************* _batch_norm_impl_index (calling miopen_batch_norm)" << std::endl;
return std::tuple_cat(
at::miopen_batch_norm(
input.contiguous(), weight.contiguous(), bias.contiguous(),
input.contiguous(input.suggest_memory_format()),
weight.contiguous(), bias.contiguous(),
running_mean.defined() ? running_mean.contiguous() : running_mean,
running_var.defined() ? running_var.contiguous() : running_var,
training, momentum, eps),
Expand Down Expand Up @@ -711,7 +741,7 @@ Tensor batch_norm(
if (PYTORCH_MIOPEN_EXTRA_LOGGING)
std :: cout
<< "PYTORCH_MIOPEN_EXTRA_LOGGING: ********************* batch_norm"
<< " input=" << input.scalar_type()
<< " input=" << input.scalar_type() << ":" << input.suggest_memory_format()
<< " weight=" << (weight_opt.has_value() ? weight_opt.value().scalar_type() : at::ScalarType::Undefined)
<< " bias=" << (bias_opt.has_value() ? bias_opt.value().scalar_type() : at::ScalarType::Undefined)
<< " running_mean=" << (running_mean_opt.has_value() ? running_mean_opt.value().scalar_type() : at::ScalarType::Undefined)
Expand Down
13 changes: 9 additions & 4 deletions aten/src/ATen/native/miopen/BatchNorm_miopen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm(
mode = miopenBNSpatial;
}

auto output_t = at::empty(input->sizes(), input->options());
auto output_t = at::empty(input->sizes(), input->options(), input->suggest_memory_format());
TensorArg output{ output_t, "output", 0 };

auto handle = getMiopenHandle();
Expand Down Expand Up @@ -177,8 +177,10 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm_backward(
const Tensor& save_var_t =
c10::value_or_else(save_var_t_opt, [] { return Tensor(); });

auto grad_output_contig =
grad_output_t.contiguous(input_t.suggest_memory_format());
TensorArg input{ input_t, "input", 1 },
grad_output{ grad_output_t, "grad_output", 2 },
grad_output{ grad_output_contig, "grad_output", 2 },
weight{ weight_t, "weight", 3 },
save_mean{ save_mean_t, "save_mean", 4 },
save_var{ save_var_t, "save_var", 5 };
Expand All @@ -193,7 +195,9 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm_backward(
}
checkAllSameType(c, {input, grad_output});
checkAllSameType(c, {weight, save_mean, save_var});
checkAllContiguous(c, {input, grad_output, save_mean, save_var});
checkAllContiguous(c, {save_mean, save_var});
TORCH_CHECK(input->is_contiguous(input->suggest_memory_format()));
TORCH_CHECK(grad_output->is_contiguous(input->suggest_memory_format()));
checkDimRange(c, input, 2, 6 /* exclusive */);
checkSameSize(c, input, grad_output);
auto num_features = input->size(1);
Expand All @@ -208,7 +212,8 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm_backward(
mode = miopenBNSpatial;
}

auto grad_input_t = at::empty(input->sizes(), input->options());
auto grad_input_t = at::empty(
input->sizes(), input->options(), input->suggest_memory_format());
auto grad_weight_t = at::empty(weight->sizes(), weight->options());
auto grad_bias_t = at::empty(weight->sizes(), weight->options());

Expand Down
109 changes: 109 additions & 0 deletions test/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@
import warnings
import pickle
import re
import os
from copy import deepcopy
from itertools import product
from functools import partial
from collections import OrderedDict
from unittest import SkipTest
import traceback

import torch
from torch import inf, nan
Expand Down Expand Up @@ -8198,6 +8200,113 @@ def test_affine_3d_rotateRandom(self, device):

self.assertEqual(scipy_ary, gridsample_ary.reshape_as(scipy_ary))

def wrap_assert(self, fn):
try:
fn()
except Exception as e:
print(">>>>>>>>> Failed with exception: ")
traceback.print_exception(e)
return False
return True

def batchnorm2d_miopen(self, dtype, memory_format, mixed=False, use_cpu=False):
def run_test(input, grad_output, mixed, use_cpu=False, range_min=0.0, range_max=1.0):
c = input.size(1)
mod = nn.BatchNorm2d(c).cuda()
if not mixed:
mod = mod.to(dtype=input.dtype)
mod.weight.data.uniform_(range_min, range_max)
mod.bias.data.uniform_(range_min, range_max)
if use_cpu:
ref_mod = nn.BatchNorm2d(c)
ref_input = input.cpu().detach().clone(memory_format=torch.preserve_format).requires_grad_(True)
ref_grad = grad_output.cpu().detach().clone(memory_format=torch.preserve_format)
else:
ref_mod = nn.BatchNorm2d(c).cuda()
ref_input = input.detach().clone(memory_format=torch.preserve_format).requires_grad_(True)
ref_grad = grad_output.detach().clone(memory_format=torch.preserve_format)

if not mixed:
ref_mod = ref_mod.to(dtype=input.dtype)
ref_mod.load_state_dict(mod.state_dict())
out = mod(input)
out.backward(grad_output)
with torch.backends.cudnn.flags(enabled=False): # force to use native nhwc batchnorm
ref_out = ref_mod(ref_input)
ref_out.backward(ref_grad)

success = self.wrap_assert(lambda: self.assertTrue(out.is_contiguous(memory_format=memory_format)))
success = success and self.wrap_assert(lambda: self.assertTrue(ref_out.is_contiguous(memory_format=memory_format)))
success = success and self.wrap_assert(lambda: self.assertEqual(out, ref_out))
success = success and self.wrap_assert(lambda: self.assertEqual(mod.weight.grad, ref_mod.weight.grad))
success = success and self.wrap_assert(lambda: self.assertEqual(mod.bias.grad, ref_mod.bias.grad))
success = success and self.wrap_assert(lambda: self.assertEqual(mod.running_mean, ref_mod.running_mean))
success = success and self.wrap_assert(lambda: self.assertEqual(mod.running_var, ref_mod.running_var))
success = success and self.wrap_assert(lambda: self.assertEqual(input.grad, ref_input.grad))
self.assertTrue(success)

range_min = -2.0
range_max = 2.0
size = (4, 8, 2, 2)
# input = torch.randint(1, 10, size=size, dtype=dtype, device="cuda")
input = torch.FloatTensor(size=size).uniform_(range_min, range_max).to(dtype=dtype, device="cuda").contiguous(memory_format=memory_format).detach().requires_grad_()
# input = input.contiguous(memory_format=memory_format).detach().requires_grad_()
# grad = torch.randint(1, 10, size=size, dtype=dtype, device="cuda")
grad = torch.FloatTensor(size=size).uniform_(range_min, range_max).to(dtype=dtype, device="cuda").contiguous(memory_format=memory_format).detach()
run_test(input, grad, mixed=mixed, use_cpu=use_cpu)
# see #42588, grad is channels_last contiguous, but grad.suggest_memory_format (rightly) return "contiguous"
# not channels_last
input = torch.randint(1, 10, (2, 8, 8, 1), dtype=dtype, device="cuda")
input = input.contiguous(memory_format=memory_format).detach().requires_grad_()
grad = torch.randint(1, 10, (2, 8, 8, 1), dtype=dtype, device="cuda")
grad = grad.permute(0, 2, 1, 3)
run_test(input, grad, mixed=mixed, use_cpu=use_cpu, range_min=range_min, range_max=range_max)


@onlyCUDA
@dtypes(torch.float)
def test_batchnorm_nhwc_miopen(self, dtype):
# TODO: Remove PYTORCH_MIOPEN_SUGGEST_NHWC once ROCm officially supports NHWC in MIOpen
PYTORCH_MIOPEN_SUGGEST_NHWC = "PYTORCH_MIOPEN_SUGGEST_NHWC"
prev_val = os.getenv(PYTORCH_MIOPEN_SUGGEST_NHWC)
try:
os.environ[PYTORCH_MIOPEN_SUGGEST_NHWC] = "1"
self.batchnorm2d_miopen(dtype, torch.channels_last)
finally:
if prev_val is None:
del os.environ[PYTORCH_MIOPEN_SUGGEST_NHWC]
else:
os.environ[PYTORCH_MIOPEN_SUGGEST_NHWC] = prev_val

@onlyCUDA
@dtypes(torch.half, torch.bfloat16)
def test_batchnorm_nhwc_miopen_mixed(self, dtype):
# TODO: Remove PYTORCH_MIOPEN_SUGGEST_NHWC once ROCm officially supports NHWC in MIOpen
PYTORCH_MIOPEN_SUGGEST_NHWC = "PYTORCH_MIOPEN_SUGGEST_NHWC"
prev_val = os.getenv(PYTORCH_MIOPEN_SUGGEST_NHWC)
try:
os.environ[PYTORCH_MIOPEN_SUGGEST_NHWC] = "1"
self.batchnorm2d_miopen(dtype, torch.channels_last, mixed=True)
finally:
if prev_val is None:
del os.environ[PYTORCH_MIOPEN_SUGGEST_NHWC]
else:
os.environ[PYTORCH_MIOPEN_SUGGEST_NHWC] = prev_val

@onlyCUDA
@dtypes(torch.half, torch.bfloat16, torch.float)
def test_batchnorm_nchw_miopen(self, dtype):
self.batchnorm2d_miopen(dtype, torch.contiguous_format)

@onlyCUDA
@dtypes(torch.half, torch.bfloat16)
def test_batchnorm_nchw_miopen_mixed(self, dtype):
self.batchnorm2d_miopen(dtype, torch.contiguous_format, mixed=True)

@onlyCUDA
@dtypes(torch.half, torch.bfloat16)
def test_batchnorm_nchw_miopen_mixed_vs_cpu(self, dtype):
self.batchnorm2d_miopen(dtype, torch.contiguous_format, mixed=True, use_cpu=True)

@onlyCUDA
@dtypes(torch.float, torch.half)
Expand Down