Resolve the memory format issue of GroupNorm (#677)
The GroupNorm XPU kernel supports channels-last input, which differs from
CUDA's behavior, so the NCHW contiguity check cannot be performed.
For details, see:
pytorch/pytorch@e9cabef

---------

Co-authored-by: Feng Yuan <[email protected]>
xytintel and fengyuan14 authored Aug 4, 2024
1 parent d866d5f commit f6d0f77
Showing 2 changed files with 1 addition and 7 deletions.
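
As a quick illustration of the behavior this commit enables, here is a minimal sketch, assuming a PyTorch build with XPU support (the shapes, group count, and variable names are illustrative, not taken from this commit):

import torch

# A channels-last input like this used to trip the removed
# TORCH_CHECK(X.is_contiguous(memory_format)) in the XPU build.
x = torch.randn(2, 8, 4, 4, device="xpu").to(memory_format=torch.channels_last)
gn = torch.nn.GroupNorm(num_groups=4, num_channels=8).to("xpu")
y = gn(x)  # succeeds after this change; previously raised a RuntimeError
print(y.shape)  # torch.Size([2, 8, 4, 4])

The XPU kernel handles the channels-last layout itself, so no NCHW restriction needs to be imposed on the input.
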
6 changes: 1 addition & 5 deletions src/ATen/native/xpu/GroupNorm.cpp
@@ -60,10 +60,6 @@ std::tuple<Tensor, Tensor, Tensor> XPUNativeFunctions::native_group_norm(
   // repeated check so expanded weights can call native_group_norm directly but
   // save mean and variance from forward
   check_group_norm_inputs(X, gamma, beta, C, group);
-  auto memory_format = X.device().is_cpu() ? X.suggest_memory_format()
-                                            : at::MemoryFormat::Contiguous;
-
-  TORCH_CHECK(X.is_contiguous(memory_format));
 
   bool mixed_type = at::native::is_mixed_type(X, gamma, beta);
   if (mixed_type) {
@@ -76,7 +72,7 @@ std::tuple<Tensor, Tensor, Tensor> XPUNativeFunctions::native_group_norm(
       c10::nullopt /* layout */,
       c10::nullopt /* device */,
       c10::nullopt /* pin_memory */,
-      memory_format);
+      MemoryFormat::Contiguous);
   const auto dtype = at::native::param_scalar_type(X, mixed_type);
   Tensor mean = at::empty({N, group}, X.options().dtype(dtype));
   Tensor rstd = at::empty({N, group}, X.options().dtype(dtype));
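
For context on why the deleted check fired: for non-CPU inputs, memory_format was pinned to at::MemoryFormat::Contiguous, and a channels-last tensor is not contiguous under that default format. A minimal sketch of the mismatch, runnable on plain CPU PyTorch (no XPU required):

import torch

# NHWC strides fail the default (NCHW) contiguity test that the
# removed TORCH_CHECK asserted for XPU inputs.
x = torch.randn(2, 8, 4, 4).to(memory_format=torch.channels_last)
print(x.is_contiguous())                                   # False
print(x.is_contiguous(memory_format=torch.channels_last))  # True
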
2 changes: 0 additions & 2 deletions test/xpu/run_test_with_skip.py
@@ -1286,8 +1286,6 @@ def launch_test(test_case, skip_list=None, exe_list=None):
"test_rnn_retain_variables_xpu_float64",
"test_transformerencoderlayer_xpu_float64",
"test_variable_sequence_xpu_float64",
# native_group_norm : RuntimeError: Expected X.is_contiguous(memory_format) to be true, but got false. (Could this error message be improved? If so, please report an enhancement request to PyTorch.)
"test_GroupNorm_memory_format_xpu",
# AssertionError: Scalars are not close!
"test_InstanceNorm1d_general_xpu",
"test_InstanceNorm2d_general_xpu",
