Skip to content

Commit

Permalink
Add aten::batch_norm_gather_stats (#932)
Browse files Browse the repository at this point in the history
- [x] batch_norm_gather_stats
- [x] batch_norm_gather_stats_with_counts
  • Loading branch information
xytintel authored Oct 15, 2024
1 parent bf4307b commit 76ee14f
Show file tree
Hide file tree
Showing 4 changed files with 311 additions and 1 deletion.
26 changes: 26 additions & 0 deletions src/ATen/native/xpu/BatchNorm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -336,5 +336,31 @@ std::tuple<Tensor, Tensor, Tensor> _new_batch_norm_backward_xpu(
grad_input_mask);
}

// Gathers per-replica batch-norm statistics into global (mean, invstd),
// assuming every replica contributed the same element count `count`.
// Thin dispatch wrapper: all work happens in the SYCL kernel entry point.
std::tuple<Tensor, Tensor> batch_norm_gather_stats_xpu(
    const Tensor& input,
    const Tensor& mean,
    const Tensor& invstd,
    const std::optional<Tensor>& running_mean,
    const std::optional<Tensor>& running_var,
    double momentum,
    double eps,
    int64_t count) {
  auto stats = xpu::batch_norm_gather_stats_kernel(
      input,
      mean,
      invstd,
      running_mean,
      running_var,
      momentum,
      eps,
      count);
  return stats;
}

// Variant of batch_norm_gather_stats that accepts a per-replica `counts`
// tensor, for the case where replicas processed different batch sizes.
// Thin dispatch wrapper: all work happens in the SYCL kernel entry point.
std::tuple<Tensor, Tensor> batch_norm_gather_stats_with_counts_xpu(
    const Tensor& input,
    const Tensor& mean,
    const Tensor& invstd,
    const std::optional<Tensor>& running_mean,
    const std::optional<Tensor>& running_var,
    double momentum,
    double eps,
    const Tensor& counts) {
  auto stats = xpu::batch_norm_gather_stats_with_counts_kernel(
      input,
      mean,
      invstd,
      running_mean,
      running_var,
      momentum,
      eps,
      counts);
  return stats;
}

} // namespace native
} // namespace at
255 changes: 254 additions & 1 deletion src/ATen/native/xpu/sycl/BatchNormKernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include <comm/SYCLContext.h>
#include <comm/XPUMathCompat.h>
#include <comm/xpu_aten.h>
#include <ATen/ops/from_blob.h>

#include <ATen/native/xpu/sycl/BatchNormKernels.h>

Expand Down Expand Up @@ -1170,7 +1171,7 @@ struct BatchNormTransformInputKernelFunctor {
} else {
invstd =
static_cast<stat_accscalar_t>(1) /
device_sqrt(
std::sqrt(
static_cast<stat_accscalar_t>(var_or_invstd_[plane]) + epsilon_);
}

Expand Down Expand Up @@ -4017,6 +4018,258 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_kernel(
return std::make_tuple(grad_input, grad_weight, grad_bias);
}

// Reduces per-replica batch-norm statistics into global statistics.
//
// Inputs (accessors):
//   vec_mean_, vec_invstd_ : (world_size, features) partial stats, one row
//                            per replica.
//   counts_                : per-replica element counts (length world_size).
// Outputs (accessors):
//   mean_, invstd_         : (features,) combined statistics.
//   running_mean_/running_var_ : updated in place with momentum when their
//     data pointer is non-null (presumably a dummy accessor is passed when
//     the optional running stats are absent — see the callers' use of
//     packed_accessor_or_dummy).
//
// Each work-item processes features in a grid-stride loop; there is no
// cross-item communication, since the world_size reduction per feature is
// done serially inside one item.
template <typename scalar_t, typename accscalar_t, typename index_t>
struct BatchNormReduceStatisticsKernelFunctor {
  void operator()(sycl::nd_item<1> item) const {
    int feature_size = vec_mean_.size(1);
    int world_size = vec_mean_.size(0);

    int bid = item.get_group(0);
    int tid = item.get_local_id(0);
    int group_size_x = item.get_local_range(0);

    // Local copies of the accessors (the member accessors are const).
    auto mean = mean_;
    auto invstd = invstd_;
    auto running_mean = running_mean_;
    auto running_var = running_var_;

    // first the reductions each thread does separately
    // (grid-stride loop: global_range need not cover feature_size exactly)
    for (int i = bid * group_size_x + tid; i < feature_size;
         i += item.get_group_range(0) * group_size_x) {
      accscalar_t avg = 0;
      accscalar_t var_n = 0;
      index_t n = 0;
      for (int j = 0; j < world_size; j++) {
        // NOTE(review): counts are read as scalar_t and accumulated into the
        // integral index_t `n` below; fractional counts would be truncated —
        // callers appear to always pass integral counts. TODO confirm.
        scalar_t count = counts_[j];
        accscalar_t m = vec_mean_[j][i];
        // Recover the replica's biased variance*count from its invstd:
        // invstd = 1/sqrt(var + eps)  =>  var = 1/invstd^2 - eps.
        accscalar_t v = accscalar_t(1.0) / (vec_invstd_[j][i]);
        v = (v * v - epsilon_) * count;
        // Pairwise (Chan-style) merge of the running aggregate (avg, var_n, n)
        // with replica j's (m, v, count).
        accscalar_t factor = 1.0 / (n + count);
        var_n += v + (avg - m) * (avg - m) * n * count * factor;
        avg = n * factor * avg + count * factor * m;
        n += count;
      }
      mean[i] = avg;
      // Combined invstd from the biased (population) variance.
      invstd[i] = static_cast<accscalar_t>(1) / std::sqrt(var_n / n + epsilon_);
      if (running_mean.data() != NULL) {
        running_mean[i] = static_cast<scalar_t>(
            (1 - momentum_) * running_mean[i] + momentum_ * avg);
      }
      // Running variance uses the unbiased (n-1) estimator.
      // NOTE(review): n == 1 would divide by zero here — presumably total
      // count is always > 1 in practice; verify against callers.
      accscalar_t unbiasedVar = var_n / (n - 1);
      if (running_var.data() != NULL) {
        running_var[i] = static_cast<scalar_t>(
            (1 - momentum_) * running_var[i] + momentum_ * unbiasedVar);
      }
    }
  }
  BatchNormReduceStatisticsKernelFunctor(
      const GenericPackedTensorAccessor<
          accscalar_t,
          2,
          RestrictPtrTraits,
          index_t> vec_mean,
      const GenericPackedTensorAccessor<
          accscalar_t,
          2,
          RestrictPtrTraits,
          index_t> vec_invstd,
      GenericPackedTensorAccessor<accscalar_t, 1, RestrictPtrTraits, index_t>
          mean,
      GenericPackedTensorAccessor<accscalar_t, 1, RestrictPtrTraits, index_t>
          invstd,
      GenericPackedTensorAccessor<scalar_t, 1, RestrictPtrTraits, index_t>
          running_mean,
      GenericPackedTensorAccessor<scalar_t, 1, RestrictPtrTraits, index_t>
          running_var,
      const accscalar_t epsilon,
      const accscalar_t momentum,
      const GenericPackedTensorAccessor<scalar_t, 1, RestrictPtrTraits, index_t>
          counts)
      : vec_mean_(vec_mean),
        vec_invstd_(vec_invstd),
        mean_(mean),
        invstd_(invstd),
        running_mean_(running_mean),
        running_var_(running_var),
        epsilon_(epsilon),
        momentum_(momentum),
        counts_(counts) {}

 private:
  const GenericPackedTensorAccessor<accscalar_t, 2, RestrictPtrTraits, index_t>
      vec_mean_;
  const GenericPackedTensorAccessor<accscalar_t, 2, RestrictPtrTraits, index_t>
      vec_invstd_;
  GenericPackedTensorAccessor<accscalar_t, 1, RestrictPtrTraits, index_t> mean_;
  GenericPackedTensorAccessor<accscalar_t, 1, RestrictPtrTraits, index_t>
      invstd_;
  GenericPackedTensorAccessor<scalar_t, 1, RestrictPtrTraits, index_t>
      running_mean_;
  GenericPackedTensorAccessor<scalar_t, 1, RestrictPtrTraits, index_t>
      running_var_;
  const accscalar_t epsilon_;
  const accscalar_t momentum_;
  const GenericPackedTensorAccessor<scalar_t, 1, RestrictPtrTraits, index_t>
      counts_;
};

// Launches the reduce-statistics kernel that merges per-replica
// (mean, invstd) rows into one global (save_mean, save_invstd) pair and
// optionally updates running stats in place.
//
// mean_/invstd_ are (world_size, features); running_mean_/running_var_ and
// counts_ may be undefined/empty (packed_accessor_or_dummy then yields a
// null-data accessor the kernel checks before writing).
// Returns (save_mean, save_invstd), each of shape (features,).
template <typename scalar_t, typename accscalar_t, typename index_t>
std::tuple<Tensor, Tensor> batch_norm_gather_stats_kernel_template(
    const Tensor& mean_,
    const Tensor& invstd_,
    const Tensor& running_mean_,
    const Tensor& running_var_,
    double momentum,
    double epsilon,
    const Tensor& counts_) {
  Tensor save_mean_;
  Tensor save_invstd_;

  auto features = mean_.size(1);
  // Accumulate reduced-precision inputs in float for numerical stability.
  auto input_options = mean_.options();
  if (mean_.scalar_type() == at::ScalarType::Half ||
      mean_.scalar_type() == at::ScalarType::BFloat16) {
    input_options = input_options.dtype(ScalarType::Float);
  }
  save_mean_ = at::empty({features}, input_options);
  save_invstd_ = at::empty({features}, input_options);

  auto mean =
      packed_accessor_or_dummy<accscalar_t, 2, RestrictPtrTraits, index_t>(
          mean_, "mean");
  auto invstd =
      packed_accessor_or_dummy<accscalar_t, 2, RestrictPtrTraits, index_t>(
          invstd_, "invstd");
  auto running_mean =
      packed_accessor_or_dummy<scalar_t, 1, RestrictPtrTraits, index_t>(
          running_mean_, "running_mean");
  auto running_var =
      packed_accessor_or_dummy<scalar_t, 1, RestrictPtrTraits, index_t>(
          running_var_, "running_var"); // was mislabeled "running_mean"
  auto counts =
      packed_accessor_or_dummy<scalar_t, 1, RestrictPtrTraits, index_t>(
          counts_, "counts");

  auto save_mean =
      get_packed_accessor<accscalar_t, 1, RestrictPtrTraits, index_t>(
          save_mean_, "save_mean");
  auto save_invstd =
      get_packed_accessor<accscalar_t, 1, RestrictPtrTraits, index_t>(
          save_invstd_, "save_invstd");

  using KernelClass =
      BatchNormReduceStatisticsKernelFunctor<scalar_t, accscalar_t, index_t>;

  // The kernel grid-strides over features, so the global range only needs to
  // be a reasonable multiple of the group size, not an exact cover.
  int group_size_x = get_num_threads<KernelClass>(features);
  sycl::range<1> local_range(group_size_x);
  sycl::range<1> global_range(
      group_size_x * std::max<int>(1, features / group_size_x));

  auto caller = KernelClass(
      mean,
      invstd,
      save_mean,
      save_invstd,
      running_mean,
      running_var,
      epsilon,
      momentum,
      counts);
  sycl_kernel_submit(global_range, local_range, getCurrentSYCLQueue(), caller);

  return std::make_tuple(save_mean_, save_invstd_);
}

// Type-dispatch entry point for gathering distributed batch-norm statistics
// with an explicit per-replica `counts` tensor.
//
// `self` is used only to pick index math width (32- vs 64-bit) and as a
// dtype fallback when running stats are absent; the reduction itself reads
// `mean`/`invstd`/`counts`. Returns (global_mean, global_invstd).
std::tuple<Tensor, Tensor> batch_norm_gather_stats_with_counts_kernel(
    const Tensor& self,
    const Tensor& mean,
    const Tensor& invstd,
    const std::optional<Tensor>& running_mean_opt /* optional */,
    const std::optional<Tensor>& running_var_opt /* optional */,
    double momentum,
    double epsilon,
    const Tensor& counts) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> running_mean_maybe_owned =
      at::borrow_from_optional_tensor(running_mean_opt);
  const Tensor& running_mean = *running_mean_maybe_owned;
  const Tensor& running_var =
      c10::value_or_else(running_var_opt, [] { return Tensor(); });

  // Dispatch on the running-stats dtype when available (it may be lower
  // precision than the accumulated mean/invstd), else on the input's dtype.
  auto scalar_type =
      running_mean.defined() ? running_mean.scalar_type() : self.scalar_type();
  // NOTE(review): the dispatch label below says "update_stats" rather than
  // "gather_stats"; it only affects error messages, but looks like a
  // copy-paste carry-over — confirm against the reference implementation.
  return AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      scalar_type,
      "batch_norm_update_stats_kernel",
      [&] {
        using accscalar_t = acc_type_device<scalar_t, kXPU>;
        if (canUse32BitIndexMath(self)) {
          return batch_norm_gather_stats_kernel_template<
              scalar_t,
              accscalar_t,
              int32_t>(
              mean,
              invstd,
              running_mean,
              running_var,
              momentum,
              epsilon,
              counts);
        } else {
          return batch_norm_gather_stats_kernel_template<
              scalar_t,
              accscalar_t,
              int64_t>(
              mean,
              invstd,
              running_mean,
              running_var,
              momentum,
              epsilon,
              counts);
        }
      });
}

// accepting input(self) here to determine template data types, since
// running_mean/running_var are optional
std::tuple<Tensor, Tensor> batch_norm_gather_stats_kernel(
const Tensor& self,
const Tensor& mean,
const Tensor& invstd,
const std::optional<Tensor>& running_mean_opt,
const std::optional<Tensor>& running_var_opt,
double momentum,
double epsilon,
int64_t count) {
// See [Note: hacky wrapper removal for optional tensor]
c10::MaybeOwned<Tensor> running_mean_maybe_owned =
at::borrow_from_optional_tensor(running_mean_opt);
const Tensor& running_mean = *running_mean_maybe_owned;
const Tensor& running_var =
c10::value_or_else(running_var_opt, [] { return Tensor(); });

std::vector<int64_t> counts(mean.size(0), count);
Tensor counts_ = at::from_blob(
(void*)counts.data(),
{(int64_t)counts.size()},
self.options().dtype(at::kLong).device(at::kCPU));
counts_ =
counts_.to(self.device())
.to(running_mean.defined() ? running_mean.dtype() : self.dtype());
return batch_norm_gather_stats_with_counts_kernel(
self,
mean,
invstd,
running_mean,
running_var,
momentum,
epsilon,
counts_);
}

} // namespace xpu
} // namespace native
} // namespace at
21 changes: 21 additions & 0 deletions src/ATen/native/xpu/sycl/BatchNormKernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,27 @@ TORCH_XPU_API std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_kernel(
double epsilon,
std::array<bool, 3> grad_input_mask);

// Gathers distributed batch-norm statistics with an explicit per-replica
// `counts` tensor; returns (global_mean, global_invstd) and updates the
// optional running stats in place.
TORCH_XPU_API std::tuple<Tensor, Tensor>
batch_norm_gather_stats_with_counts_kernel(
    const Tensor& self,
    const Tensor& mean,
    const Tensor& invstd,
    const std::optional<Tensor>& running_mean_opt /* optional */,
    const std::optional<Tensor>& running_var_opt /* optional */,
    double momentum,
    double epsilon,
    const Tensor& counts);

// Uniform-count variant: every replica contributed `count` elements.
// Returns (global_mean, global_invstd).
TORCH_XPU_API std::tuple<Tensor, Tensor> batch_norm_gather_stats_kernel(
    const Tensor& self,
    const Tensor& mean,
    const Tensor& invstd,
    const std::optional<Tensor>& running_mean_opt,
    const std::optional<Tensor>& running_var_opt,
    double momentum,
    double epsilon,
    int64_t count);

} // namespace xpu
} // namespace native
} // namespace at
10 changes: 10 additions & 0 deletions yaml/native/native_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3346,6 +3346,16 @@
dispatch:
XPU: _new_batch_norm_backward_xpu

# Distributed batch-norm statistics gathering, dispatched to the XPU backend.
- func: batch_norm_gather_stats(Tensor input, Tensor mean, Tensor invstd, Tensor? running_mean, Tensor? running_var, float momentum, float eps, int count) -> (Tensor, Tensor)
  dispatch:
    XPU: batch_norm_gather_stats_xpu
  autogen: batch_norm_gather_stats.out

# Variant taking a per-replica counts tensor (replicas may differ in size).
- func: batch_norm_gather_stats_with_counts(Tensor input, Tensor mean, Tensor invstd, Tensor? running_mean, Tensor? running_var, float momentum, float eps, Tensor counts) -> (Tensor, Tensor)
  dispatch:
    XPU: batch_norm_gather_stats_with_counts_xpu
  autogen: batch_norm_gather_stats_with_counts.out

- func: lerp.Scalar_out(Tensor self, Tensor end, Scalar weight, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
structured: True
Expand Down

0 comments on commit 76ee14f

Please sign in to comment.