diff --git a/src/ATen/native/xpu/Normalization.cpp b/src/ATen/native/xpu/Normalization.cpp index 29236e553..3bc170da6 100644 --- a/src/ATen/native/xpu/Normalization.cpp +++ b/src/ATen/native/xpu/Normalization.cpp @@ -47,7 +47,10 @@ Tensor& renorm_impl( reduce_dims.erase(reduce_dims.begin() + dim); auto dtype = self.scalar_type(); - auto acc_type = at::toAccumulateType(dtype, c10::DeviceType::XPU); + + // Use the device-independent accumulate type here, following PyTorch's design. + auto acc_type = at::toAccumulateType(dtype, true); + Tensor norm; if (acc_type != dtype) { norm = at::linalg_vector_norm(