
Commit da0820e

jiayisunx authored and facebook-github-bot committed
add BFloat16 operators on CPU: range, sinh, cosh, frexp, nan_to_num (pytorch#61826)
Summary: Added BFloat16 support for range, sinh, cosh, frexp, and nan_to_num on CPU, and collected benchmark data for these ops in BFloat16 and Float32 using PyTorch's operator_benchmark tool on an Intel(R) Xeon(R) Platinum 8180 CPU @ 2.50GHz (1 core, and 28 cores / 1 socket).

Benchmark data:
- [cosh_sinh_benchmark.txt](https://github.com/pytorch/pytorch/files/6974313/cosh_sinh_benchmark.txt)
- [frexp_benchmark.txt](https://github.com/pytorch/pytorch/files/6974315/frexp_benchmark.txt)
- [nan_to_num_benchmark.txt](https://github.com/pytorch/pytorch/files/6974317/nan_to_num_benchmark.txt)
- [range_benchmark.txt](https://github.com/pytorch/pytorch/files/6974318/range_benchmark.txt)

Pull Request resolved: pytorch#61826
Reviewed By: saketh-are
Differential Revision: D30257259
Pulled By: VitalyFedyunin
fbshipit-source-id: 394cd713e6394050a8c90b2160633beb675d71dd
1 parent a8de0d8 commit da0820e
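
As a usage-level illustration (not part of the commit), the five ops can be exercised on bfloat16 CPU tensors roughly as follows; exact outputs depend on bfloat16 rounding.

```python
# Illustrative usage on CPU after this change; not code from the commit.
import torch

x = torch.tensor([0.5, -1.5, float("inf"), float("nan")], dtype=torch.bfloat16)

torch.sinh(x[:2])                        # bfloat16 in, bfloat16 out
torch.cosh(x[:2])
torch.nan_to_num(x)                      # nan -> 0, +/-inf -> bfloat16 max/lowest
mantissa, exponent = torch.frexp(x[:2])  # mantissa keeps the bfloat16 dtype
torch.range(0, 1, 0.25, dtype=torch.bfloat16)  # torch.range (deprecated in favor of arange) exercises range_cpu_out
```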

File tree: 4 files changed (+16, -6 lines)

aten/src/ATen/native/RangeFactories.cpp  (+2, -2)

@@ -113,7 +113,7 @@ Tensor& logspace_cpu_out(const Scalar& start, const Scalar& end, c10::optional<i
 }

 Tensor& range_cpu_out(const Scalar& start, const Scalar& end, const Scalar& step, Tensor& result) {
-  AT_DISPATCH_ALL_TYPES(result.scalar_type(), "range_cpu", [&]() {
+  AT_DISPATCH_ALL_TYPES_AND(kBFloat16, result.scalar_type(), "range_cpu", [&]() {
     using accscalar_t = at::acc_type<scalar_t, false>;
     auto xstart = start.to<accscalar_t>();
     auto xend = end.to<accscalar_t>();
@@ -133,7 +133,7 @@ Tensor& range_cpu_out(const Scalar& start, const Scalar& end, const Scalar& step
     scalar_t *data_ptr = r.data_ptr<scalar_t>();

     at::parallel_for(0, size, internal::GRAIN_SIZE, [&](int64_t p_begin, int64_t p_end) {
-      scalar_t is = p_begin;
+      accscalar_t is = p_begin;
       for (int64_t i = p_begin; i < p_end; ++i, ++is) {
         data_ptr[i] = xstart + is * xstep;
       }
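
The second hunk is what makes the first one safe for bfloat16: the running index `is` is now kept in the accumulation type (float for bfloat16, via at::acc_type) rather than in scalar_t. A small illustration of the representability issue, using only stock PyTorch (this snippet is not from the commit):

```python
# bfloat16 has an 8-bit significand, so not every integer above 256 is exactly
# representable; accumulating the loop index in bfloat16 would therefore
# corrupt `xstart + is * xstep` for longer ranges.
import torch

print(torch.tensor(257.0, dtype=torch.bfloat16).item())  # 256.0 (rounded to nearest representable value)
print(torch.tensor(257.0).item())                        # 257.0 in float32, the acc_type used for bfloat16
```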

aten/src/ATen/native/cpu/UnaryOpsKernel.cpp  (+4, -4)

@@ -322,7 +322,7 @@ static void sinc_kernel(TensorIteratorBase& iter) {
 }

 static void sinh_kernel(TensorIteratorBase& iter) {
-  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(iter.dtype(), "sinh_cpu", [&]() {
+  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kBFloat16, iter.dtype(), "sinh_cpu", [&]() {
     cpu_kernel_vec(
         iter,
         [=](scalar_t a) -> scalar_t { return std::sinh(a); },
@@ -331,7 +331,7 @@ static void sinh_kernel(TensorIteratorBase& iter) {
 }

 static void cosh_kernel(TensorIteratorBase& iter) {
-  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(iter.dtype(), "cosh_cpu", [&]() {
+  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kBFloat16, iter.dtype(), "cosh_cpu", [&]() {
     cpu_kernel_vec(
         iter,
         [=](scalar_t a) -> scalar_t { return std::cosh(a); },
@@ -407,7 +407,7 @@ static void nan_to_num_kernel(
     c10::optional<double> nan,
     c10::optional<double> pos_inf,
     c10::optional<double> neg_inf) {
-  AT_DISPATCH_FLOATING_TYPES_AND(kHalf, iter.dtype(), "nan_to_num", [&]() {
+  AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "nan_to_num", [&]() {
     scalar_t nan_replacement = static_cast<scalar_t>(nan.value_or(0.));
     scalar_t pos_inf_replacement = pos_inf.has_value()
         ? static_cast<scalar_t>(pos_inf.value())
@@ -586,7 +586,7 @@ static void entr_kernel(TensorIteratorBase& iter) {
 }

 static void frexp_kernel(TensorIteratorBase& iter) {
-  AT_DISPATCH_FLOATING_TYPES_AND(kHalf,
+  AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf,
     // The iter.dtype() here is the dtype of mantissa output.
     // It's a floating point type and must be the same as the input's dtype.
     iter.dtype(),
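
A small check of the comment retained in the frexp hunk, namely that the mantissa output keeps the input's floating dtype. This is an illustrative snippet, not a test from the commit:

```python
# With a bfloat16 input on CPU, frexp should yield a bfloat16 mantissa and an
# integer exponent such that mantissa * 2**exponent reconstructs the input.
import torch

x = torch.tensor([0.5, 3.0, -10.0], dtype=torch.bfloat16)
mantissa, exponent = torch.frexp(x)
print(mantissa.dtype, exponent.dtype)  # torch.bfloat16 and an integer dtype (int32 per the docs)
reconstructed = mantissa.float() * 2.0 ** exponent
print(torch.equal(reconstructed, x.float()))  # expected: True for these exactly representable values
```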

c10/util/BFloat16-math.h  (+6, -0)

@@ -57,6 +57,12 @@ inline c10::BFloat16 sin(c10::BFloat16 a) {
 inline c10::BFloat16 tan(c10::BFloat16 a) {
   return std::tan(float(a));
 }
+inline c10::BFloat16 sinh(c10::BFloat16 a) {
+  return std::sinh(float(a));
+}
+inline c10::BFloat16 cosh(c10::BFloat16 a) {
+  return std::cosh(float(a));
+}
 inline c10::BFloat16 tanh(c10::BFloat16 a) {
   return std::tanh(float(a));
 }
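
The two new overloads follow the header's existing pattern: evaluate in float and let the conversion back to c10::BFloat16 round the result. Below is a Python mimic of that pattern for illustration only; the helper name is hypothetical and not from the codebase.

```python
import torch

def bf16_sinh_reference(a: torch.Tensor) -> torch.Tensor:
    """Upcast to float32, apply the standard routine, round back to bfloat16."""
    assert a.dtype == torch.bfloat16
    return torch.sinh(a.float()).to(torch.bfloat16)

x = torch.tensor([0.25, 1.0, 2.5], dtype=torch.bfloat16)
# Since the scalar path above routes std::sinh through float, one would expect
# the native bfloat16 kernel to agree with this reference.
print(bf16_sinh_reference(x))
print(torch.sinh(x))
```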

torch/testing/_internal/common_methods_invocations.py  (+4, -0)

@@ -6028,6 +6028,7 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs):
     UnaryUfuncInfo('cosh',
                    ref=np_unary_ufunc_integer_promotion_wrapper(np.cosh),
                    dtypes=all_types_and_complex_and(torch.bool),
+                   dtypesIfCPU=all_types_and_complex_and(torch.bool, torch.bfloat16),
                    dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
                    safe_casts_outputs=True,
                    assert_autodiffed=True,
@@ -6413,6 +6414,7 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs):
           op=torch.frexp,
           ref=np.frexp,
           dtypes=floating_types_and(torch.half),
+          dtypesIfCPU=floating_types_and(torch.half, torch.bfloat16),
           # skip testing torch.frexp as it is not supported by ROCm platform yet
           decorators=[skipCUDAIfRocm],
           supports_out=False,
@@ -7432,6 +7434,7 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs):
     UnaryUfuncInfo('sinh',
                    ref=np_unary_ufunc_integer_promotion_wrapper(np.sinh),
                    dtypes=all_types_and_complex_and(torch.bool),
+                   dtypesIfCPU=all_types_and_complex_and(torch.bool, torch.bfloat16),
                    dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
                    safe_casts_outputs=True,
                    assert_autodiffed=True,
@@ -7753,6 +7756,7 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs):
     UnaryUfuncInfo('nan_to_num',
                    ref=np.nan_to_num,
                    dtypes=all_types_and(torch.half, torch.bool),
+                   dtypesIfCPU=all_types_and(torch.half, torch.bool, torch.bfloat16),
                    dtypesIfCUDA=all_types_and(torch.half, torch.bool, torch.bfloat16),
                    supports_forward_ad=True,
                    # Passing numpy_kwargs via sample_kwargs, as numpy does comparison
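
The entries above only declare bfloat16 CPU coverage for the OpInfo-based tests; a rough stand-in for what those dtype tests exercise might look like the following (illustrative, not the actual test machinery):

```python
# Compare each op's bfloat16 CPU result against a NumPy reference computed in
# float32, with a tolerance loose enough for bfloat16's ~3 decimal digits.
import numpy as np
import torch

x = torch.linspace(-3, 3, 17).to(torch.bfloat16)
cases = [
    ("sinh", torch.sinh, np.sinh),
    ("cosh", torch.cosh, np.cosh),
    ("nan_to_num", torch.nan_to_num, np.nan_to_num),
]
for name, torch_fn, np_fn in cases:
    got = torch_fn(x).float()
    ref = torch.from_numpy(np_fn(x.float().numpy()))
    assert torch.allclose(got, ref, rtol=1e-2, atol=1e-2), name
```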
