Skip to content

Commit

Permalink
[MPS][BE] Migrate polar to use functor (pytorch#147184)
Browse files Browse the repository at this point in the history
Pull Request resolved: pytorch#147184
Approved by: https://github.com/dcci
ghstack dependencies: pytorch#147182, pytorch#147183
  • Loading branch information
malfet authored and pytorchmergebot committed Feb 14, 2025
1 parent 278ffd8 commit 10bc8f2
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 17 deletions.
28 changes: 12 additions & 16 deletions aten/src/ATen/native/mps/kernels/BinaryKernel.metal
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include <c10/metal/special_math.h>
#include <c10/metal/utils.h>
#include <metal_stdlib>
using namespace metal;

Expand Down Expand Up @@ -74,6 +75,15 @@ struct nextafter_functor {
}
};

struct polar_functor {
template <typename U>
using ret_type = c10::metal::vec2type_t<U>;
template <typename T>
inline ret_type<T> operator()(const T a, const T b) {
return ret_type<T>(a * cos(b), a * sin(b));
}
};

// Future BinaryTensorIterator
template <typename T, typename F>
using result_of = decltype(::metal::declval<F>()(
Expand Down Expand Up @@ -153,22 +163,8 @@ REGISTER_BINARY_INDEXING_OP(zeta, bfloat);
#endif

// Complex binary functions
template <typename T>
kernel void polar(
constant void* abs_ [[buffer(0)]],
constant void* angle_ [[buffer(1)]],
device void* out_ [[buffer(2)]],
constant uint3* offsets [[buffer(3)]],
uint tid [[thread_position_in_grid]]) {
device T* out = (device T*)((device uint8_t*)out_ + offsets[tid].x);
constant T* angle = (constant T*)((constant uint8_t*)angle_ + offsets[tid].z);
constant T* abs = (constant T*)((constant uint8_t*)abs_ + offsets[tid].y);
out[0] = abs[0] * cos(angle[0]);
out[1] = abs[0] * sin(angle[0]);
}

REGISTER_BINARY_OP(polar, float);
REGISTER_BINARY_OP(polar, half);
REGISTER_BINARY_INDEXING_OP(polar, float);
REGISTER_BINARY_INDEXING_OP(polar, half);

template <typename T>
kernel void complex_mul(
Expand Down
2 changes: 1 addition & 1 deletion aten/src/ATen/native/mps/operations/BinaryKernel.mm
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ static void zeta_mps_kernel(TensorIteratorBase& iter) {
auto output_as_real = at::view_as_real(output).select(output.dim(), 0);
auto iter = TensorIteratorConfig().add_output(output_as_real).add_input(abs).add_input(angle).build();

mps::binary_mps_impl(iter, "polar", false);
mps::binary_mps_impl(iter, "polar");
return output;
}

Expand Down

0 comments on commit 10bc8f2

Please sign in to comment.