
Commit 46ba9f2

Revert "Remove conj kernels for real dtypes (pytorch#80374)"

This reverts commit ad44079. Reverted pytorch#80374 on behalf of https://github.com/atalman because it breaks an internal build:

    UnaryOpsKernel.cpp:208:5: error: unused type alias 'scalar_t' [-Werror,-Wunused-local-typedef]

1 parent: d11d3dd
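For context, here is a minimal sketch (not PyTorch code) of the class of failure cited above. The AT_DISPATCH_CASE_* macros expand to a lambda-local `using scalar_t = ...;` alias, and a case body that never names the alias (for example, one that only forwards to a copy kernel) trips clang's -Wunused-local-typedef, which -Werror promotes to an error:

    // Build with: clang++ -std=c++17 -Wunused-local-typedef -Werror unused_alias.cpp
    // (file name illustrative). This reproduces the warning class, not ATen itself.
    int main() {
      auto real_case = [&] {
        using scalar_t = float;  // stands in for the alias the dispatch macro injects
        // The body forwards to a copy kernel and never names scalar_t, so clang
        // emits: error: unused type alias 'scalar_t' [-Werror,-Wunused-local-typedef]
      };
      real_case();
      return 0;
    }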

File tree: 6 files changed, +35 -53 lines


aten/src/ATen/native/cpu/CopyKernel.cpp (+4 -1)

@@ -13,6 +13,9 @@ namespace native {
 inline namespace CPU_CAPABILITY {
 void neg_kernel(TensorIteratorBase &iter);
 void conj_kernel(TensorIteratorBase &iter);
+} // namespace CPU_CAPABILITY
+
+namespace {

 void float_bfloat16_copy_kernel(TensorIteratorBase &iter, bool requires_neg) {
   auto strides_out = iter.strides(0);
@@ -269,7 +272,7 @@ void copy_kernel(TensorIterator& iter, bool /*non_blocking*/) {
   }
 }

-} // namespace CPU_CAPABILITY
+} // anonymous namespace

 REGISTER_DISPATCH(copy_stub, &copy_kernel);
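The change above is the usual C++ internal-linkage pattern: declarations that other translation units link against stay in the named CPU_CAPABILITY namespace, while file-local kernels move into an anonymous namespace. A minimal sketch of the pattern, with illustrative names:

    namespace lib {                 // declarations with external linkage
    void public_kernel();
    }  // namespace lib

    namespace {                     // anonymous namespace: internal linkage,
    void local_helper() {}          // invisible to other translation units
    }  // anonymous namespace

    void lib::public_kernel() { local_helper(); }

    int main() { lib::public_kernel(); }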

aten/src/ATen/native/cpu/CopyKernel.h (-12)

This file was deleted.

aten/src/ATen/native/cpu/UnaryOpsKernel.cpp (+7 -13)

@@ -13,7 +13,6 @@
 #include <ATen/cpu/vec/vec.h>
 #include <ATen/cpu/vml.h>
 #include <ATen/native/TensorIterator.h>
-#include <ATen/native/cpu/CopyKernel.h>
 #include <ATen/native/cpu/Loops.h>
 #include <ATen/native/cpu/zmath.h>
 #include <ATen/OpMathType.h>
@@ -204,18 +203,13 @@ static void angle_kernel(TensorIteratorBase& iter) {

 // NB: Ignores the negative bit on tensors
 void conj_kernel(TensorIteratorBase& iter) {
-  AT_DISPATCH_SWITCH(iter.common_dtype(), "conj_cpu",
-    AT_DISPATCH_CASE_ALL_TYPES_AND3(kBool, kBFloat16, kHalf, [&] {
-      // conj is a no-op for non-complex types
-      direct_copy_kernel(iter);
-    })
-    AT_DISPATCH_CASE_COMPLEX_TYPES_AND(kComplexHalf, [&] {
-      cpu_kernel_vec(
-          iter,
-          [=](scalar_t a) -> scalar_t { return conj_impl(a); },
-          [=](Vectorized<scalar_t> a) { return a.conj(); });
-    })
-  );
+  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
+      kBool, kBFloat16, kHalf, kComplexHalf, iter.common_dtype(), "conj_cpu", [&]() {
+        cpu_kernel_vec(
+            iter,
+            [=](scalar_t a) -> scalar_t { return conj_impl(a); },
+            [=](Vectorized<scalar_t> a) { return a.conj(); });
+      });
 }

 static void bitwise_not_kernel(TensorIteratorBase& iter) {
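The restored CPU path uses the cpu_kernel_vec two-lambda shape: a scalar fallback plus a SIMD path over at::vec::Vectorized<scalar_t>. A toy sketch of that shape follows; Vec4 is an illustrative stand-in for Vectorized<float>, not an ATen type:

    #include <array>
    #include <cstdio>

    struct Vec4 {
      std::array<float, 4> v;
      Vec4 neg() const { return {{-v[0], -v[1], -v[2], -v[3]}}; }
    };

    template <typename ScalarOp, typename VectorOp>
    void apply(float* data, int n, ScalarOp sop, VectorOp vop) {
      int i = 0;
      for (; i + 4 <= n; i += 4) {               // main loop: SIMD-width chunks
        Vec4 chunk{{data[i], data[i + 1], data[i + 2], data[i + 3]}};
        Vec4 out = vop(chunk);
        for (int j = 0; j < 4; ++j) data[i + j] = out.v[j];
      }
      for (; i < n; ++i) data[i] = sop(data[i]); // tail: scalar fallback
    }

    int main() {
      float x[6] = {1, 2, 3, 4, 5, 6};
      apply(x, 6,
            [](float a) { return -a; },          // scalar lambda
            [](Vec4 a) { return a.neg(); });     // vectorized lambda
      for (float f : x) std::printf("%g ", f);   // prints -1 -2 -3 -4 -5 -6
      std::printf("\n");
    }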

aten/src/ATen/native/cuda/Copy.cu (+2)

@@ -23,6 +23,7 @@ namespace native {
 void neg_kernel_cuda(TensorIteratorBase &iter);
 void conj_kernel_cuda(TensorIteratorBase &iter);

+namespace {
 void direct_copy_kernel_cuda(TensorIteratorBase &iter) {
   ScalarType dtype = iter.dtype(0);
   if (isQIntType(dtype)) {
@@ -42,6 +43,7 @@ void neg_conj_kernel_cuda(TensorIteratorBase &iter) {
     gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) { return -std::conj(x); });
   });
 }
+} // namespace (anonymous)

 using namespace at::cuda;

aten/src/ATen/native/cuda/Copy.h (-10)

This file was deleted.

aten/src/ATen/native/cuda/UnaryComplexKernels.cu (+22 -17)

@@ -1,7 +1,6 @@
 #define TORCH_ASSERT_NO_OPERATORS
 #include <limits>
 #include <ATen/native/UnaryOps.h>
-#include <ATen/native/cuda/Copy.h>
 #include <ATen/native/cuda/Loops.cuh>
 #include <ATen/native/cuda/JitLoops.cuh>
 #include <ATen/Dispatch.h>
@@ -59,10 +58,22 @@ void angle_kernel_cuda(TensorIteratorBase& iter) {
 }
 }

+// We manually overload conj because std::conj does not work for types other than c10::complex.
+template<typename scalar_t>
+__host__ __device__ static inline scalar_t conj_wrapper(scalar_t v) {
+  return v;
+}
+
+template<typename T>
+__host__ __device__ static inline c10::complex<T> conj_wrapper(c10::complex<T> v) {
+  return std::conj(v);
+}
+
 // NB: Ignores the negative bit on tensors
 const char conj_name[] = "conj_kernel";
 void conj_kernel_cuda(TensorIteratorBase& iter) {
-  auto conj_chalf = [&] {
+  auto common_dtype = iter.common_dtype();
+  if (common_dtype == kComplexHalf) {
     using scalar_t = c10::complex<at::Half>;
 #if AT_USE_JITERATOR()
     static const auto conj_string = jiterator_stringify(
@@ -74,23 +85,17 @@ void conj_kernel_cuda(TensorIteratorBase& iter) {
     jitted_gpu_kernel<conj_name, scalar_t, scalar_t, 1>(iter, conj_string);
 #else
     gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t {
-      return std::conj(a);
+      return conj_wrapper(a);
     });
 #endif
-  };
-
-  AT_DISPATCH_SWITCH(iter.common_dtype(), "conj_cuda",
-    AT_DISPATCH_CASE_ALL_TYPES_AND3(kBool, kBFloat16, kHalf, [&] {
-      // Conj is a no-op for non-complex types
-      direct_copy_kernel_cuda(iter);
-    })
-    AT_DISPATCH_CASE_COMPLEX_TYPES([&] {
-      gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t {
-        return std::conj(a);
-      });
-    })
-    AT_DISPATCH_CASE(kComplexHalf, conj_chalf)
-  );
+  } else {
+    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
+        kBool, kBFloat16, kHalf, iter.common_dtype(), "conj_cuda", [&]() {
+          gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t {
+            return conj_wrapper(a);
+          });
+        });
+  }
 }

 REGISTER_DISPATCH(angle_stub, &angle_kernel_cuda);
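The conj_wrapper pair restored above relies on plain overload resolution: the identity template catches real (and integral/bool) dtypes, while the more specialized c10::complex overload wins for complex arguments, so the same lambda compiles for every dispatched dtype. A host-only sketch, with std::complex standing in for c10::complex:

    #include <complex>
    #include <cstdio>

    template <typename scalar_t>
    static inline scalar_t conj_wrapper(scalar_t v) {
      return v;  // conj is the identity for real types
    }

    template <typename T>
    static inline std::complex<T> conj_wrapper(std::complex<T> v) {
      return std::conj(v);  // the more specialized overload wins for complex
    }

    int main() {
      std::printf("%g\n", conj_wrapper(2.5f));               // prints 2.5
      auto z = conj_wrapper(std::complex<float>(1.f, 2.f));  // (1, -2)
      std::printf("(%g, %g)\n", z.real(), z.imag());
    }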
