Commit 420b49c

kshitij12345 authored and pytorchmergebot committed

[complex32] add, sub, neg (pytorch#77179)

Ref: pytorch#74537
Pull Request resolved: pytorch#77179
Approved by: https://github.com/anjali411

1 parent f348b1b
5 files changed: +77 / -9 lines

aten/src/ATen/native/cpu/UnaryOpsKernel.cpp

Lines changed: 1 addition & 1 deletion
@@ -257,7 +257,7 @@ void reciprocal_kernel(TensorIteratorBase& iter) {
 
 // NB: Ignores the negative bit on tensors
 void neg_kernel(TensorIteratorBase& iter) {
-  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, iter.dtype(), "neg_cpu", [&]() {
+  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kComplexHalf, kBFloat16, kHalf, iter.dtype(), "neg_cpu", [&]() {
     cpu_kernel_vec(
       iter,
       [=](scalar_t a) -> scalar_t { return -a; },
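
Note: with kComplexHalf added to this dispatch, torch.neg runs natively on complex32 tensors on CPU. A minimal usage sketch, not part of the commit and assuming a build that includes it:

import torch

a = torch.randn(4, dtype=torch.cfloat)        # complex64 reference values
out = torch.neg(a.to(torch.chalf))            # now dispatches for ComplexHalf on CPU
assert out.dtype == torch.chalf
# Compare against the complex64 result; the tolerance covers chalf rounding.
torch.testing.assert_close(out.to(torch.cfloat), -a, atol=1e-2, rtol=0)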

aten/src/ATen/native/cuda/UnarySignKernels.cu

Lines changed: 2 additions & 2 deletions
@@ -35,15 +35,15 @@ void neg_kernel_cuda(TensorIteratorBase& iter) {
         return -a;
       }
   ); // neg_string
-  AT_DISPATCH_COMPLEX_TYPES(dtype, "neg_cuda", [&]() {
+  AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, dtype, "neg_cuda", [&]() {
     jitted_gpu_kernel<
       /*name=*/ neg_name,
       /*return_dtype=*/ scalar_t,
       /*common_dtype=*/ scalar_t,
       /*arity=*/ 1>(iter, neg_string);
   });
 #else
-  AT_DISPATCH_COMPLEX_TYPES(dtype, "neg_cuda", [&]() {
+  AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, dtype, "neg_cuda", [&]() {
     gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
       return -a;
     });
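
Note: both CUDA branches (the jiterator path and the plain gpu_kernel fallback) receive the same ComplexHalf coverage, so the CPU sketch above carries over to GPU. A hedged variant, assuming a CUDA build with this commit:

import torch

if torch.cuda.is_available():
    a = torch.randn(4, dtype=torch.cfloat, device='cuda')
    out = torch.neg(a.to(torch.chalf))        # ComplexHalf handled by neg_cuda as well
    torch.testing.assert_close(out.to(torch.cfloat), -a, atol=1e-2, rtol=0)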

aten/src/ATen/native/native_functions.yaml

Lines changed: 1 addition & 1 deletion
@@ -469,7 +469,7 @@
   structured: True
   structured_inherits: TensorIteratorBase
   ufunc_inner_loop:
-    Generic: add (AllAndComplex, BFloat16, Half)
+    Generic: add (AllAndComplex, BFloat16, Half, ComplexHalf)
     ScalarOnly: add (Bool)
   dispatch:
     SparseCPU: add_out_sparse_cpu
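
Note: listing ComplexHalf in the generic ufunc inner loop is what lets the structured add kernels accept complex32 inputs, including the alpha scalar (torch.sub appears to reuse the same kernel, which would explain why no separate yaml change is needed). A hedged sketch of the resulting behavior:

import torch

x = torch.randn(4, dtype=torch.cfloat)
y = torch.randn(4, dtype=torch.cfloat)
out = torch.add(x.to(torch.chalf), y.to(torch.chalf), alpha=2)
ref = x + 2 * y                               # complex64 reference
torch.testing.assert_close(out.to(torch.cfloat), ref, atol=1e-2, rtol=1e-2)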

c10/util/Half.h

Lines changed: 22 additions & 0 deletions
@@ -426,6 +426,28 @@ struct alignas(4) complex<Half> {
   constexpr C10_HOST_DEVICE Half imag() const {
     return imag_;
   }
+
+  complex<Half>& operator+=(const complex<Half>& other) {
+    real_ = static_cast<float>(real_) + static_cast<float>(other.real_);
+    imag_ = static_cast<float>(imag_) + static_cast<float>(other.imag_);
+    return *this;
+  }
+
+  complex<Half>& operator-=(const complex<Half>& other) {
+    real_ = static_cast<float>(real_) - static_cast<float>(other.real_);
+    imag_ = static_cast<float>(imag_) - static_cast<float>(other.imag_);
+    return *this;
+  }
+
+  complex<Half>& operator*=(const complex<Half>& other) {
+    auto a = static_cast<float>(real_);
+    auto b = static_cast<float>(imag_);
+    auto c = static_cast<float>(other.real());
+    auto d = static_cast<float>(other.imag());
+    real_ = a * c - b * d;
+    imag_ = a * d + b * c;
+    return *this;
+  }
 };
 
 // In some versions of MSVC, there will be a compiler error when building.
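
Note: the new compound-assignment operators widen each Half component to float, do the arithmetic in float, and round back to Half only on the final store, so the intermediate products of the complex multiply keep full precision. A small Python sketch of the same pattern (a hypothetical helper, not the C++ code itself):

import numpy as np

def complex_half_mul(x: complex, y: complex) -> complex:
    # Mirror complex<Half>::operator*=: compute in float32, round components to float16 at the end.
    a, b = np.float32(x.real), np.float32(x.imag)
    c, d = np.float32(y.real), np.float32(y.imag)
    real = np.float16(a * c - b * d)          # (a+bi)(c+di) = (ac-bd) + (ad+bc)i
    imag = np.float16(a * d + b * c)
    return complex(real, imag)

print(complex_half_mul(1.5 + 2.0j, 0.25 - 1.0j))   # (2.375-1j)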

torch/testing/_internal/common_methods_invocations.py

Lines changed: 51 additions & 5 deletions
@@ -209,6 +209,8 @@ def to_numpy(t):
     if isinstance(t, torch.Tensor):
         if t.dtype is torch.bfloat16:
             return t.detach().cpu().to(torch.float32).numpy()
+        if t.dtype is torch.chalf:
+            return t.detach().cpu().to(torch.cfloat).numpy()
         return t.detach().cpu().numpy()
     elif isinstance(t, torch.dtype):
         return torch_to_numpy_dtype_dict[t]
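
Note: NumPy has no complex32 dtype, so the test helper upcasts chalf to complex64 before calling .numpy(), mirroring the existing bfloat16 branch. The same pattern outside the test suite, as a hedged sketch:

import torch

t = torch.randn(3, dtype=torch.cfloat).to(torch.chalf)
ref = t.to(torch.cfloat).numpy()              # upcast first; chalf has no direct NumPy counterpart
print(ref.dtype)                              # complex64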
@@ -9592,12 +9594,18 @@ def generate_std_var_kwargs(t: torch.Tensor, **kwargs):
            # NumPy has no builtin reference for the alpha kwarg, but it is easy enough to emulate
            ref=lambda input, other, *, alpha=1: np.add(input, other) if alpha == 1 \
                else np.add(input, np.multiply(alpha, other)),
-           dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
+           dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16,
+                                            torch.float16, torch.chalf),
            assert_autodiffed=True,
            sample_inputs_func=sample_inputs_add_sub,
            supports_fwgrad_bwgrad=True,
            supports_forward_ad=True,
            supports_two_python_scalars=True,
+           decorators=(
+               DecorateInfo(
+                   toleranceOverride({torch.chalf: tol(atol=1e-2, rtol=0)}),
+                   'TestBinaryUfuncs', 'test_reference_numerics'),
+           ),
            skips=(
                # boolean alpha not handled properly
                DecorateInfo(unittest.expectedFailure,
@@ -9629,7 +9637,7 @@ def generate_std_var_kwargs(t: torch.Tensor, **kwargs):
            # NumPy has no builtin reference for the alpha kwarg, but it is easy enough to emulate
            ref=lambda input, other, *, alpha=1: np.subtract(input, np.multiply(alpha, other)),
            aliases=('subtract',),
-           dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16),
+           dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16, torch.chalf),
            assert_autodiffed=True,
            supports_forward_ad=True,
            supports_fwgrad_bwgrad=True,
@@ -9639,6 +9647,15 @@ def generate_std_var_kwargs(t: torch.Tensor, **kwargs):
                DecorateInfo(
                    toleranceOverride({torch.float16: tol(atol=1e-2, rtol=0)}),
                    'TestBinaryUfuncs', 'test_reference_numerics'),
+               DecorateInfo(
+                   toleranceOverride({torch.chalf: tol(atol=1e-2, rtol=0)}),
+                   'TestCommon', 'test_complex_half_reference_testing', device_type='cpu'),
+               DecorateInfo(
+                   toleranceOverride({torch.chalf: tol(atol=5e-3, rtol=0)}),
+                   'TestDecomp', 'test_comprehensive', device_type='cpu'),
+               DecorateInfo(
+                   toleranceOverride({torch.chalf: tol(atol=5e-3, rtol=0)}),
+                   'TestDecomp', 'test_quick', device_type='cpu'),
            ),
            skips=(
                DecorateInfo(unittest.skip("Skipped!"),
@@ -10241,7 +10258,7 @@ def generate_std_var_kwargs(t: torch.Tensor, **kwargs):
            supports_out=False,
            ),
     OpInfo('resolve_neg',
-           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
            sample_inputs_func=sample_inputs_view_as_real,
            supports_forward_ad=True,
            supports_fwgrad_bwgrad=True,
@@ -13812,13 +13829,29 @@ def generate_std_var_kwargs(t: torch.Tensor, **kwargs):
     UnaryUfuncInfo('neg',
                    aliases=('negative', ),
                    ref=np.negative,
-                   dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
+                   dtypes=all_types_and_complex_and(torch.half, torch.bfloat16, torch.chalf),
                    error_inputs_func=error_inputs_neg,
                    supports_forward_ad=True,
                    supports_fwgrad_bwgrad=True,
                    supports_sparse=True,
                    supports_sparse_csr=True,
-                   assert_autodiffed=True,),
+                   assert_autodiffed=True,
+                   skips=(
+                       # RuntimeError: "nonzero_count_cpu" not implemented for 'ComplexHalf'
+                       DecorateInfo(unittest.expectedFailure, 'TestSparseCSR', 'test_sparse_csr_consistency',
+                                    dtypes=(torch.chalf,),),
+                       # RuntimeError: "nonzero_count_cpu" not implemented for 'ComplexHalf'
+                       DecorateInfo(unittest.expectedFailure, 'TestSparseCSR', 'test_sparse_csr_unary_inplace',
+                                    dtypes=(torch.chalf,),),
+                       # RuntimeError: "nonzero_count_cpu" not implemented for 'ComplexHalf'
+                       DecorateInfo(unittest.expectedFailure, 'TestSparseCSR', 'test_sparse_csr_unary_out',
+                                    dtypes=(torch.chalf,),),
+                       # RuntimeError: "add_out_op2_sparse_csr" not implemented for 'ComplexHalf'
+                       DecorateInfo(unittest.expectedFailure, 'TestSparseCSR',
+                                    'test_zero_to_zero_correspondence_unary',
+                                    dtypes=(torch.chalf,),)
+
+                   )),
     OpInfo('dist',
            op=torch.dist,
            dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16),
@@ -18104,6 +18137,18 @@ def __init__(
     ElementwiseUnaryPythonRefInfo(
         "_refs.neg",
         torch_opinfo_name="neg",
+        skips=(
+            # On CPU
+            # RuntimeError: unsupported Storage type: torch.complex32
+            # https://github.com/pytorch/pytorch/issues/73502
+            # On CUDA
+            # RuntimeError: "index_select_cuda" not implemented for 'ComplexHalf'
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_reference_consistency',
+                         dtypes=(torch.chalf,)),
+            # Same reason as `test_python_reference_consistency`
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_reference_meta_functions',
+                         dtypes=(torch.chalf,)),
+        )
     ),
     ElementwiseUnaryPythonRefInfo(
         "_refs.reciprocal",
@@ -18321,6 +18366,7 @@ def __init__(
         {
             torch.bfloat16: tol(atol=1, rtol=0),
             torch.float16: tol(atol=1e-2, rtol=0),
+            torch.chalf: tol(atol=1e-2, rtol=0),
         }
     ),
     "TestCommon",
