@@ -209,6 +209,8 @@ def to_numpy(t):
     if isinstance(t, torch.Tensor):
         if t.dtype is torch.bfloat16:
             return t.detach().cpu().to(torch.float32).numpy()
+        if t.dtype is torch.chalf:
+            return t.detach().cpu().to(torch.cfloat).numpy()
         return t.detach().cpu().numpy()
     elif isinstance(t, torch.dtype):
         return torch_to_numpy_dtype_dict[t]
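Why the new branch is needed (a minimal sketch, assuming a build where complex32 tensor creation works): NumPy has no 16-bit complex dtype, so .numpy() cannot be called directly on a chalf tensor; upcasting to cfloat first yields a complex64 array.

import torch

t = torch.tensor([1 + 2j, 3 - 4j], dtype=torch.chalf)
# t.numpy() would raise here: NumPy has no complex32 counterpart.
a = t.detach().cpu().to(torch.cfloat).numpy()
print(a.dtype)  # complex64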
@@ -9592,12 +9594,18 @@ def generate_std_var_kwargs(t: torch.Tensor, **kwargs):
                     # NumPy has no builtin reference for the alpha kwarg, but it is easy enough to emulate
                     ref=lambda input, other, *, alpha=1: np.add(input, other) if alpha == 1 \
                         else np.add(input, np.multiply(alpha, other)),
-                    dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
+                    dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16,
+                                                     torch.float16, torch.chalf),
                     assert_autodiffed=True,
                     sample_inputs_func=sample_inputs_add_sub,
                     supports_fwgrad_bwgrad=True,
                     supports_forward_ad=True,
                     supports_two_python_scalars=True,
+                    decorators=(
+                        DecorateInfo(
+                            toleranceOverride({torch.chalf: tol(atol=1e-2, rtol=0)}),
+                            'TestBinaryUfuncs', 'test_reference_numerics'),
+                    ),
                     skips=(
                         # boolean alpha not handled properly
                         DecorateInfo(unittest.expectedFailure,
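A rough illustration of why chalf gets its own toleranceOverride above (a sketch, assuming add already works for complex32 on the device under test): each chalf component is a float16 with about three decimal digits of precision, so results can drift from a complex64 reference by around 1e-3, well inside atol=1e-2.

import torch

a = torch.randn(1000, dtype=torch.cfloat)
b = torch.randn(1000, dtype=torch.cfloat)
ref = a + b  # complex64 reference result
out = (a.to(torch.chalf) + b.to(torch.chalf)).to(torch.cfloat)
# Maximum elementwise error is typically on the order of 1e-3.
print((out - ref).abs().max())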
@@ -9629,7 +9637,7 @@ def generate_std_var_kwargs(t: torch.Tensor, **kwargs):
                     # NumPy has no builtin reference for the alpha kwarg, but it is easy enough to emulate
                     ref=lambda input, other, *, alpha=1: np.subtract(input, np.multiply(alpha, other)),
                     aliases=('subtract',),
-                    dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16),
+                    dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16, torch.chalf),
                     assert_autodiffed=True,
                     supports_forward_ad=True,
                     supports_fwgrad_bwgrad=True,
@@ -9639,6 +9647,15 @@ def generate_std_var_kwargs(t: torch.Tensor, **kwargs):
                         DecorateInfo(
                             toleranceOverride({torch.float16: tol(atol=1e-2, rtol=0)}),
                             'TestBinaryUfuncs', 'test_reference_numerics'),
+                        DecorateInfo(
+                            toleranceOverride({torch.chalf: tol(atol=1e-2, rtol=0)}),
+                            'TestCommon', 'test_complex_half_reference_testing', device_type='cpu'),
+                        DecorateInfo(
+                            toleranceOverride({torch.chalf: tol(atol=5e-3, rtol=0)}),
+                            'TestDecomp', 'test_comprehensive', device_type='cpu'),
+                        DecorateInfo(
+                            toleranceOverride({torch.chalf: tol(atol=5e-3, rtol=0)}),
+                            'TestDecomp', 'test_quick', device_type='cpu'),
                     ),
                     skips=(
                         DecorateInfo(unittest.skip("Skipped!"),
@@ -10241,7 +10258,7 @@ def generate_std_var_kwargs(t: torch.Tensor, **kwargs):
           supports_out=False,
           ),
    OpInfo('resolve_neg',
-          dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+          dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
           sample_inputs_func=sample_inputs_view_as_real,
           supports_forward_ad=True,
           supports_fwgrad_bwgrad=True,
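For background on the op touched here: resolve_neg materializes the lazy negation bit that view operations such as conj().imag set instead of copying data. A small usage sketch:

import torch

x = torch.randn(3, dtype=torch.cfloat)
y = x.conj().imag    # sets the negation bit lazily, no copy
print(y.is_neg())    # True
z = y.resolve_neg()  # materializes the negation into memory
print(z.is_neg())    # False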
@@ -13812,13 +13829,29 @@ def generate_std_var_kwargs(t: torch.Tensor, **kwargs):
    UnaryUfuncInfo('neg',
                   aliases=('negative', ),
                   ref=np.negative,
-                  dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
+                  dtypes=all_types_and_complex_and(torch.half, torch.bfloat16, torch.chalf),
                   error_inputs_func=error_inputs_neg,
                   supports_forward_ad=True,
                   supports_fwgrad_bwgrad=True,
                   supports_sparse=True,
                   supports_sparse_csr=True,
-                  assert_autodiffed=True,),
+                  assert_autodiffed=True,
+                  skips=(
+                      # RuntimeError: "nonzero_count_cpu" not implemented for 'ComplexHalf'
+                      DecorateInfo(unittest.expectedFailure, 'TestSparseCSR', 'test_sparse_csr_consistency',
+                                   dtypes=(torch.chalf,)),
+                      # RuntimeError: "nonzero_count_cpu" not implemented for 'ComplexHalf'
+                      DecorateInfo(unittest.expectedFailure, 'TestSparseCSR', 'test_sparse_csr_unary_inplace',
+                                   dtypes=(torch.chalf,)),
+                      # RuntimeError: "nonzero_count_cpu" not implemented for 'ComplexHalf'
+                      DecorateInfo(unittest.expectedFailure, 'TestSparseCSR', 'test_sparse_csr_unary_out',
+                                   dtypes=(torch.chalf,)),
+                      # RuntimeError: "add_out_op2_sparse_csr" not implemented for 'ComplexHalf'
+                      DecorateInfo(unittest.expectedFailure, 'TestSparseCSR',
+                                   'test_zero_to_zero_correspondence_unary',
+                                   dtypes=(torch.chalf,)),
+                  )),
    OpInfo('dist',
           op=torch.dist,
           dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16),
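The sparse CSR skips above should be reproducible along these lines (an assumption on my part: the exact call that trips the missing kernel may vary by build):

import torch

t = torch.tensor([[0, 1 + 1j], [0, 0]], dtype=torch.chalf)
try:
    t.to_sparse_csr()
except RuntimeError as e:
    print(e)  # "nonzero_count_cpu" not implemented for 'ComplexHalf'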
@@ -18104,6 +18137,18 @@ def __init__(
    ElementwiseUnaryPythonRefInfo(
        "_refs.neg",
        torch_opinfo_name="neg",
+       skips=(
+           # On CPU
+           # RuntimeError: unsupported Storage type: torch.complex32
+           # https://github.com/pytorch/pytorch/issues/73502
+           # On CUDA
+           # RuntimeError: "index_select_cuda" not implemented for 'ComplexHalf'
+           DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_reference_consistency',
+                        dtypes=(torch.chalf,)),
+           # Same reason as `test_python_reference_consistency`
+           DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_reference_meta_functions',
+                        dtypes=(torch.chalf,)),
+       )
    ),
    ElementwiseUnaryPythonRefInfo(
        "_refs.reciprocal",
@@ -18321,6 +18366,7 @@ def __init__(
        {
            torch.bfloat16: tol(atol=1, rtol=0),
            torch.float16: tol(atol=1e-2, rtol=0),
+           torch.chalf: tol(atol=1e-2, rtol=0),
        }
    ),
    "TestCommon",