@@ -308,7 +308,7 @@ struct ExponentialKernel {
308
308
309
309
template <typename RNG>
310
310
void bernoulli_kernel (Tensor& self, const Tensor& p_, RNG generator) {
311
- AT_DISPATCH_ALL_TYPES_AND (at::ScalarType::Bool, self.scalar_type (), " bernoulli_tensor_cpu_self_" , [&] {
311
+ AT_DISPATCH_ALL_TYPES_AND2 (at::ScalarType::Bool, at::ScalarType::BFloat16 , self.scalar_type (), " bernoulli_tensor_cpu_self_" , [&] {
312
312
// See Note [Acquire lock when using random generators]
313
313
std::lock_guard<std::mutex> lock (generator->mutex_ );
314
314
using self_t = scalar_t ;
@@ -325,7 +325,7 @@ void bernoulli_kernel(Tensor& self, const Tensor& p_, RNG generator) {
325
325
return static_cast <self_t >(bernoulli (generator));
326
326
});
327
327
} else {
328
- AT_DISPATCH_FLOATING_TYPES ( p_.scalar_type (), " bernoulli_tensor_cpu_p_" , [&] {
328
+ AT_DISPATCH_FLOATING_TYPES_AND (at::ScalarType::BFloat16, p_.scalar_type (), " bernoulli_tensor_cpu_p_" , [&] {
329
329
using p_t = scalar_t ;
330
330
cpu_serial_kernel (iter, [&](const p_t p_val) -> self_t {
331
331
at::bernoulli_distribution<float > bernoulli (p_val);
@@ -338,7 +338,7 @@ void bernoulli_kernel(Tensor& self, const Tensor& p_, RNG generator) {
338
338
339
339
template <typename RNG>
340
340
void bernoulli_kernel (Tensor& self, double p, RNG generator) {
341
- AT_DISPATCH_ALL_TYPES_AND (at::ScalarType::Bool, self.scalar_type (), " bernoulli_scalar_cpu_" , [&] {
341
+ AT_DISPATCH_ALL_TYPES_AND2 (at::ScalarType::Bool, at::ScalarType::BFloat16 , self.scalar_type (), " bernoulli_scalar_cpu_" , [&] {
342
342
// See Note [Acquire lock when using random generators]
343
343
std::lock_guard<std::mutex> lock (generator->mutex_ );
344
344
auto iter = TensorIterator::borrowing_nullary_op (self);
0 commit comments