
Commit b0782f0

mingfeima authored and facebook-github-bot committed
add BFloat16 support for bernoulli and Dropout on CPU (pytorch#56372)
Summary: Pull Request resolved: pytorch#56372
Test Plan: Imported from OSS
Reviewed By: heitorschueroff
Differential Revision: D28836792
Pulled By: VitalyFedyunin
fbshipit-source-id: ede951d172a59276e11383fd767778ab959b5a6b
1 parent 7299565 commit b0782f0
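
With this change, bernoulli_ and nn.Dropout accept bfloat16 tensors on CPU, paths that previously dispatched only on CUDA. A minimal usage sketch, assuming a PyTorch build that includes this commit:

import torch
import torch.nn as nn

# Scalar-probability Bernoulli fill on a bfloat16 CPU tensor;
# before this commit the CPU dispatch did not cover BFloat16.
x = torch.empty(4, 4, dtype=torch.bfloat16)
x.bernoulli_(p=0.3)  # in-place fill with 0.0 / 1.0 draws

# Dropout on a bfloat16 CPU input, as exercised by the updated test_nn.py.
drop = nn.Dropout(p=0.5)
y = drop(torch.randn(4, 4).bfloat16())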

File tree

4 files changed: +6 −5 lines changed


aten/src/ATen/native/cpu/DistributionTemplates.h

+3 −3
@@ -308,7 +308,7 @@ struct ExponentialKernel {
 
 template<typename RNG>
 void bernoulli_kernel(Tensor& self, const Tensor& p_, RNG generator) {
-  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Bool, self.scalar_type(), "bernoulli_tensor_cpu_self_", [&] {
+  AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Bool, at::ScalarType::BFloat16, self.scalar_type(), "bernoulli_tensor_cpu_self_", [&] {
     // See Note [Acquire lock when using random generators]
     std::lock_guard<std::mutex> lock(generator->mutex_);
     using self_t = scalar_t;
@@ -325,7 +325,7 @@ void bernoulli_kernel(Tensor& self, const Tensor& p_, RNG generator) {
       return static_cast<self_t>(bernoulli(generator));
     });
   } else {
-    AT_DISPATCH_FLOATING_TYPES(p_.scalar_type(), "bernoulli_tensor_cpu_p_", [&] {
+    AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::BFloat16, p_.scalar_type(), "bernoulli_tensor_cpu_p_", [&] {
       using p_t = scalar_t;
       cpu_serial_kernel(iter, [&](const p_t p_val) -> self_t {
         at::bernoulli_distribution<float> bernoulli(p_val);
@@ -338,7 +338,7 @@ void bernoulli_kernel(Tensor& self, const Tensor& p_, RNG generator) {
 
 template<typename RNG>
 void bernoulli_kernel(Tensor& self, double p, RNG generator) {
-  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Bool, self.scalar_type(), "bernoulli_scalar_cpu_", [&] {
+  AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Bool, at::ScalarType::BFloat16, self.scalar_type(), "bernoulli_scalar_cpu_", [&] {
     // See Note [Acquire lock when using random generators]
     std::lock_guard<std::mutex> lock(generator->mutex_);
     auto iter = TensorIterator::borrowing_nullary_op(self);
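
The three hunks above widen the dispatch macros so that both the output tensor (AT_DISPATCH_ALL_TYPES_AND2 adding BFloat16 alongside Bool) and the probability tensor p_ (AT_DISPATCH_FLOATING_TYPES_AND) can be bfloat16; sampling itself still happens in float via at::bernoulli_distribution<float>. A sketch of the tensor-probability path, assuming a build with this change:

import torch

# Per-element probabilities supplied as a bfloat16 tensor;
# each output element is drawn as Bernoulli(p[i]).
p = torch.rand(8, dtype=torch.bfloat16)
out = torch.empty(8, dtype=torch.bfloat16).bernoulli_(p)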

aten/src/ATen/native/cpu/UnaryOpsKernel.cpp

+1 −1
@@ -488,7 +488,7 @@ void bernoulli_scalar_kernel(Tensor &self, double p, c10::optional<Generator> gen) {
   int64_t n = self.numel();
   bool contig = self.is_contiguous();
 
-  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Bool, self.scalar_type(), "bernoulli_scalar_cpu_", [&] {
+  AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Bool, at::ScalarType::BFloat16, self.scalar_type(), "bernoulli_scalar_cpu_", [&] {
     at::Tensor tmp_int_tensor;
     if (std::is_same<scalar_t, int>::value && contig) {
       tmp_int_tensor = self;
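
This is the vectorized scalar-probability kernel: the visible branch samples directly into the tensor when it is contiguous and int, and otherwise (including the new BFloat16 case) appears to go through a temporary int tensor before converting. A quick Python-level check of the behavior, assuming a build with this change:

import torch

# Exercises the scalar-p kernel above on CPU.
x = torch.empty(1000, dtype=torch.bfloat16)
x.bernoulli_(0.25)
assert bool(x.eq(0).logical_or(x.eq(1)).all())  # draws are exactly 0 or 1
print(x.float().mean())  # sample mean should be near p = 0.25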

test/test_nn.py

+1 −1
@@ -12984,7 +12984,7 @@ def test_Dropout(self, device):
 
         self._test_dropout_stride_mean_preserve(nn.Dropout, device)
 
-        if self.device_type == 'cuda':
+        if self.device_type == 'cuda' or self.device_type == 'cpu':
             input = input.bfloat16()
             self._test_dropout(nn.Dropout, device, input)
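
test_Dropout previously covered bfloat16 inputs only on CUDA; the change extends the same check to CPU. A standalone approximation of what the test verifies, assuming the standard inverted-dropout scaling by 1/(1-p):

import torch
import torch.nn as nn

x = torch.ones(1000, dtype=torch.bfloat16)  # CPU tensor
m = nn.Dropout(p=0.5).train()
y = m(x)
# Kept elements are scaled by 1/(1-p) = 2 and dropped ones are 0,
# so the mean stays near 1 in expectation.
print(y.float().mean())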

test/test_torch.py

+1
@@ -4324,6 +4324,7 @@ def test_repeat_interleave(self, device):
         self.assertEqual(a_with_output.size(), torch.Size([3, 2]))
 
     @dtypes(*(torch.testing.get_all_fp_dtypes(include_half=False, include_bfloat16=False)))
+    @dtypesIfCPU(*(torch.testing.get_all_fp_dtypes(include_half=False, include_bfloat16=True)))
     @dtypesIfCUDA(*(torch.testing.get_all_fp_dtypes(include_bfloat16=False)))
     def test_bernoulli_p(self, device, dtype):
         for trivial_p in ([0, 1], [1, 0, 1, 1, 0, 1]):
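
The added @dtypesIfCPU line opts test_bernoulli_p into bfloat16 on CPU while the CUDA dtype list is unchanged. For reference, a sketch of what the helper calls are expected to return; the exact lists are an assumption based on the flag names:

import torch

cpu_dtypes = torch.testing.get_all_fp_dtypes(include_half=False, include_bfloat16=True)
# assumed -> [torch.float32, torch.float64, torch.bfloat16]

cuda_dtypes = torch.testing.get_all_fp_dtypes(include_bfloat16=False)
# assumed -> [torch.float32, torch.float64, torch.float16]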
