Add aten::polar and its variants (#606)
Co-authored-by: yucai <[email protected]>
Co-authored-by: Feng Yuan <[email protected]>
3 people authored Jul 30, 2024
1 parent e210c5c commit 36dfe23
Showing 8 changed files with 47 additions and 3 deletions.
15 changes: 15 additions & 0 deletions src/ATen/native/xpu/TensorFactories.cpp
@@ -151,6 +151,21 @@ Tensor& XPUNativeFunctions::complex_out(
  return result;
}

Tensor& XPUNativeFunctions::polar_out(
    const Tensor& abs,
    const Tensor& angle,
    Tensor& result) {
  complex_check_dtype(result, abs, angle);
  auto iter = TensorIteratorConfig()
                  .add_output(result)
                  .add_const_input(abs)
                  .add_const_input(angle)
                  .check_all_same_dtype(false)
                  .build();
  native::xpu::polar_kernel(iter);
  return result;
}

Tensor& XPUNativeFunctions::randperm_out(
    int64_t n,
    c10::optional<Generator> generator,
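Note: `polar_out` builds a binary TensorIterator over `abs` and `angle` and hands it to the XPU kernel. A minimal Python sketch of the semantics being registered, shown on CPU via the public `torch.polar` / `torch.complex` entry points; running it on an `xpu` device assumes a build that includes this change:

import torch

# torch.polar builds a complex tensor from polar coordinates:
# abs * cos(angle) + i * abs * sin(angle), elementwise.
abs_ = torch.tensor([1.0, 2.0])
angle = torch.tensor([0.0, 3.1415926 / 2])

z = torch.polar(abs_, angle)
ref = torch.complex(abs_ * torch.cos(angle), abs_ * torch.sin(angle))
assert torch.allclose(z, ref)

# The out= variant is what polar_out above wires up; on an XPU build the
# tensors would live on device="xpu" instead of the CPU used here.
out = torch.empty_like(z)
torch.polar(abs_, angle, out=out)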
1 change: 0 additions & 1 deletion src/ATen/native/xpu/XPUFallback.template
@@ -240,7 +240,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
"ormqr",
"_pdist_backward",
"_pdist_forward",
"polar.out",
"_prelu_kernel",
"_prelu_kernel_backward",
"prod",
14 changes: 14 additions & 0 deletions src/ATen/native/xpu/sycl/ComplexKernels.cpp
@@ -21,4 +21,18 @@ void complex_kernel(TensorIterator& iter) {
  });
}

template <typename scalar_t>
struct PolarFunctor {
  c10::complex<scalar_t> operator()(scalar_t a, scalar_t b) const {
    return c10::complex<scalar_t>(a * std::cos(b), a * std::sin(b));
  }
};

void polar_kernel(TensorIterator& iter) {
  AT_DISPATCH_FLOATING_TYPES(iter.input_dtype(0), "polar_xpu", [&]() {
    PolarFunctor<scalar_t> f;
    gpu_kernel(iter, f);
  });
}

} // namespace at::native::xpu
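Note: `gpu_kernel` applies `PolarFunctor` elementwise over the TensorIterator, and `AT_DISPATCH_FLOATING_TYPES` instantiates it for float and double. A scalar Python analogue of the per-element math, for illustration only:

import math

def polar_ref(a, b):
    # Same per-element math as PolarFunctor: a * cos(b) + i * a * sin(b).
    return complex(a * math.cos(b), a * math.sin(b))

# gpu_kernel maps the functor over every element pair produced by the
# TensorIterator; this comprehension is the scalar equivalent.
abs_vals = [1.0, 2.0, 0.5]
angles = [0.0, math.pi / 2, math.pi]
result = [polar_ref(a, b) for a, b in zip(abs_vals, angles)]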
2 changes: 2 additions & 0 deletions src/ATen/native/xpu/sycl/ComplexKernels.h
@@ -6,4 +6,6 @@ namespace at::native::xpu {

void complex_kernel(TensorIterator& iter);

void polar_kernel(TensorIterator& iter);

} // namespace at::native::xpu
4 changes: 4 additions & 0 deletions test/xpu/extended/run_test_with_skip.py
@@ -154,6 +154,10 @@
# Greatest relative difference: 0.00396728515625 at index (610,) (up to 0.001 allowed)
"test_compare_cpu_hypot_xpu_bfloat16",

# RuntimeError: Expected both inputs to be Half, Float or Double tensors but got BFloat16 and BFloat16.
# Polar's backward is calculated using complex(), which does not support bfloat16. CUDA fails with same error.
"test_compare_cpu_polar_xpu_bfloat16",

# Regressions due to PyTorch uplift (Numeric difference in float and bfloat)
# https://github.com/intel/torch-xpu-ops/issues/549
# Example fail log
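Note: the bfloat16 skip above records that polar's autograd formula goes through complex(), which rejects that dtype. A repro sketch, assuming current PyTorch raises the quoted error for bfloat16 inputs:

import torch

abs_ = torch.tensor([1.0, 2.0], dtype=torch.bfloat16, requires_grad=True)
angle = torch.tensor([0.1, 0.2], dtype=torch.bfloat16, requires_grad=True)

try:
    torch.polar(abs_, angle).abs().sum().backward()
except RuntimeError as err:
    # Expected to match the message quoted in the skip, e.g.
    # "Expected both inputs to be Half, Float or Double tensors but got BFloat16 and BFloat16."
    print(err)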
12 changes: 10 additions & 2 deletions test/xpu/run_test_with_skip.py
@@ -782,6 +782,10 @@ def launch_test(test_case, skip_list=None, exe_list=None):
# torch.complex32 - "sinh_cpu" not implemented for 'ComplexHalf'
"test_dtypes_cosh_xpu",

# RuntimeError: Expected both inputs to be Half, Float or Double tensors but got BFloat16 and BFloat16.
# Polar's backward is calculated using complex(), which does not support bfloat16. CUDA fails with same error.
"test_dtypes_polar_xpu",

# implemented aten::histogram to align MPS operators coverage, CUDA doesn't support
# but test_dtypes infrastructure leverage CUDA supported datatypes
"test_dtypes_histogram_xpu",
@@ -3016,8 +3020,12 @@ def launch_test(test_case, skip_list=None, exe_list=None):
res += launch_test("nn/test_load_state_dict_xpu.py")

# test_module_hooks

res += launch_test("nn/test_module_hooks_xpu.py")
skip_list = (
    # TypeError: TestStateDictHooks.test_register_state_dict_post_hook() missing 1 required positional argument: 'private'
    # https://github.com/intel/torch-xpu-ops/issues/658
    "test_register_state_dict_post_hook",
)
res += launch_test("nn/test_module_hooks_xpu.py", skip_list)

# test_parametrization

1 change: 1 addition & 0 deletions test/xpu/xpu_test_utils.py
@@ -208,6 +208,7 @@
"unique",
"multinomial",
"lerp",
"polar",
"frac",
"aminmax",
"argmin",
1 change: 1 addition & 0 deletions yaml/xpu_functions.yaml
@@ -268,6 +268,7 @@ supported:
- eye.m_out
- _efficientzerotensor
- complex.out
- polar.out
- clone
- fill_.Scalar
- fill_.Tensor
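Note: with polar.out listed under supported, the op dispatches to the new native kernel rather than falling back to CPU. A usage sketch, assuming a PyTorch build with torch-xpu-ops where torch.xpu.is_available() reports True:

import torch

if torch.xpu.is_available():
    abs_ = torch.rand(4, device="xpu")
    angle = torch.rand(4, device="xpu")
    z = torch.polar(abs_, angle)  # served by the XPU kernel, no CPU fallback
    print(z.dtype, z.device)      # e.g. torch.complex64 xpu:0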
