Skip to content

Commit

Permalink
Update Normalization.cpp
Browse files Browse the repository at this point in the history
  • Loading branch information
xytintel authored Aug 26, 2024
1 parent e73a8a9 commit df48a3e
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion src/ATen/native/xpu/Normalization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,10 @@ Tensor& renorm_impl(
reduce_dims.erase(reduce_dims.begin() + dim);

auto dtype = self.scalar_type();
auto acc_type = at::toAccumulateType(dtype, c10::DeviceType::XPU);

// This is a device-independent accumulate type, and we follow PyTorch's design.
auto acc_type = at::toAccumulateType(dtype, true);

Tensor norm;
if (acc_type != dtype) {
norm = at::linalg_vector_norm(
Expand Down

0 comments on commit df48a3e

Please sign in to comment.