From 10bc8f25b2845d21f90f0ca32b994e451584590e Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Thu, 13 Feb 2025 20:58:02 -0800 Subject: [PATCH] [MPS][BE] Migrate polar to use functor (#147184) Pull Request resolved: https://github.com/pytorch/pytorch/pull/147184 Approved by: https://github.com/dcci ghstack dependencies: #147182, #147183 --- .../native/mps/kernels/BinaryKernel.metal | 28 ++++++++----------- .../native/mps/operations/BinaryKernel.mm | 2 +- 2 files changed, 13 insertions(+), 17 deletions(-) diff --git a/aten/src/ATen/native/mps/kernels/BinaryKernel.metal b/aten/src/ATen/native/mps/kernels/BinaryKernel.metal index 02a9f3ccec73b..e5a34827e7fe3 100644 --- a/aten/src/ATen/native/mps/kernels/BinaryKernel.metal +++ b/aten/src/ATen/native/mps/kernels/BinaryKernel.metal @@ -1,4 +1,5 @@ #include +#include #include using namespace metal; @@ -74,6 +75,15 @@ struct nextafter_functor { } }; +struct polar_functor { + template + using ret_type = c10::metal::vec2type_t; + template + inline ret_type operator()(const T a, const T b) { + return ret_type(a * cos(b), a * sin(b)); + } +}; + // Future BinaryTensorIterator template using result_of = decltype(::metal::declval()( @@ -153,22 +163,8 @@ REGISTER_BINARY_INDEXING_OP(zeta, bfloat); #endif // Complex binary functions -template -kernel void polar( - constant void* abs_ [[buffer(0)]], - constant void* angle_ [[buffer(1)]], - device void* out_ [[buffer(2)]], - constant uint3* offsets [[buffer(3)]], - uint tid [[thread_position_in_grid]]) { - device T* out = (device T*)((device uint8_t*)out_ + offsets[tid].x); - constant T* angle = (constant T*)((constant uint8_t*)angle_ + offsets[tid].z); - constant T* abs = (constant T*)((constant uint8_t*)abs_ + offsets[tid].y); - out[0] = abs[0] * cos(angle[0]); - out[1] = abs[0] * sin(angle[0]); -} - -REGISTER_BINARY_OP(polar, float); -REGISTER_BINARY_OP(polar, half); +REGISTER_BINARY_INDEXING_OP(polar, float); +REGISTER_BINARY_INDEXING_OP(polar, half); template kernel void complex_mul( diff --git a/aten/src/ATen/native/mps/operations/BinaryKernel.mm b/aten/src/ATen/native/mps/operations/BinaryKernel.mm index 2f220c8140975..f190f9bdd9ebf 100644 --- a/aten/src/ATen/native/mps/operations/BinaryKernel.mm +++ b/aten/src/ATen/native/mps/operations/BinaryKernel.mm @@ -141,7 +141,7 @@ static void zeta_mps_kernel(TensorIteratorBase& iter) { auto output_as_real = at::view_as_real(output).select(output.dim(), 0); auto iter = TensorIteratorConfig().add_output(output_as_real).add_input(abs).add_input(angle).build(); - mps::binary_mps_impl(iter, "polar", false); + mps::binary_mps_impl(iter, "polar"); return output; }