_foreach_norm: Align with PyTorch operator semantics on allocation scheme of return tensors (#709)

PyTorch requires _foreach_norm to return a separate tensor for each input, each with its own storage. The existing XPU implementation followed an out-of-date allocation scheme that shared one storage among all returned tensors; the latest PyTorch unit tests no longer allow that behavior (a simplified sketch of the two schemes follows the test list below).
Related test cases:
- test_dispatch_meta_outplace__foreach_norm_xpu_bfloat16
- test_dispatch_meta_outplace__foreach_norm_xpu_float
- test_dispatch_symbolic_meta_outplace__foreach_norm_xpu_bfloat16
- test_dispatch_symbolic_meta_outplace__foreach_norm_xpu_float
- test_dispatch_symbolic_meta_outplace_all_strides__foreach_norm_xpu_float32
- test_meta_outplace__foreach_norm_xpu_bfloat16
- test_meta_outplace__foreach_norm_xpu_float
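
For context, here is a minimal ATen-level sketch of the two allocation schemes (hypothetical helper names, not code from this commit): the old scheme carved every result out of one flat buffer, so all returned tensors aliased a single storage, while the new scheme allocates an independent 0-dim tensor per result.

```cpp
#include <ATen/ATen.h>
#include <vector>

// Old scheme (sketch): one flat buffer; each "result" is a view into it,
// so all returned tensors share the same storage.
std::vector<at::Tensor> norms_shared_storage(
    int64_t ntensors, const at::TensorOptions& opts) {
  auto flat = at::empty({ntensors}, opts);
  std::vector<at::Tensor> out;
  for (int64_t i = 0; i < ntensors; ++i) {
    out.push_back(flat[i]); // aliases flat's storage; this aliasing is what the tests reject
  }
  return out;
}

// New scheme (sketch): every result is an independently allocated 0-dim tensor.
std::vector<at::Tensor> norms_separate_storage(
    int64_t ntensors, const at::TensorOptions& opts) {
  std::vector<at::Tensor> out;
  out.reserve(ntensors);
  for (int64_t i = 0; i < ntensors; ++i) {
    out.push_back(at::empty({}, opts)); // own storage per result
  }
  return out;
}
```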

---------

Co-authored-by: Feng Yuan <[email protected]>
chunhuanMeng and fengyuan14 authored Aug 19, 2024
1 parent 5a47831 commit 7eb5219
Showing 1 changed file with 46 additions and 10 deletions.
56 changes: 46 additions & 10 deletions src/ATen/native/xpu/sycl/ForeachReduceKernels.cpp
@@ -106,13 +106,13 @@ struct lpnormChunkReduceKernelFunctor {
auto sum_val = sycl::reduce_over_group(
item_id.get_group(), val, sycl::plus<opmath_t>());
if (lid == 0) {
- ret_per_tensor_[group_id] =
+ *(ret_per_tensor_[group_id]) =
norm_type == NormType::L1 ? sum_val : std::sqrt((opmath_t)sum_val);
}
}
lpnormChunkReduceKernelFunctor(
const opmath_t* output_per_tensor,
- out_t* ret_per_tensor,
+ out_t** ret_per_tensor,
int max_chunks_per_tensor,
int wg_size)
: output_per_tensor_(output_per_tensor),
@@ -122,15 +122,15 @@ struct lpnormChunkReduceKernelFunctor {

private:
const opmath_t* output_per_tensor_;
- out_t* ret_per_tensor_;
+ out_t** ret_per_tensor_;
int max_chunks_per_tensor_;
int wg_size_;
};

template <typename out_t, NormType norm_type, typename out_opmath_t>
void launch_lpnorm_chunk_reduce_kernel(
const out_opmath_t* output_per_tensor,
- out_t* ret_per_tensor,
+ out_t** ret_per_tensor,
int wg_size,
int max_chunks_per_tensor,
int n_tensor) {
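
The two hunks above switch the reduce functor and its launcher from `out_t*` to `out_t**`. A hedged host-side sketch (hypothetical helper, not from this file) of what that indirection means: the reducer is handed one pointer per result tensor and writes through it, rather than indexing into a single shared output buffer.

```cpp
#include <cmath>

// With out_t* the per-tensor results lived contiguously in one allocation
// (out[t] = ...). With out_t** each table entry points into a distinct
// tensor's storage, so writing *(out[t]) keeps every result in its own tensor.
template <typename out_t, typename opmath_t>
void write_norms(out_t** ret_per_tensor, const opmath_t* sums, int n_tensor, bool is_l1) {
  for (int t = 0; t < n_tensor; ++t) {
    *(ret_per_tensor[t]) = is_l1
        ? static_cast<out_t>(sums[t])
        : static_cast<out_t>(std::sqrt(sums[t]));
  }
}
```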
@@ -194,8 +194,22 @@ std::vector<Tensor> foreach_norm_kernel(
dtype.has_value() ? dtype.value() : tensors[0].scalar_type();
const auto options = tensors[0].options();
auto output_per_tensor_option = options.dtype(toOpMathType(output_dtype));
- auto res_option = options.dtype(output_dtype);
- auto ret_per_tensor = at::empty({ntensors}, res_option);
+ std::vector<at::Tensor> ret_per_tensor;
+ ret_per_tensor.reserve(ntensors);
+ const auto res_option = options.dtype(output_dtype);
+ for (int i = 0; i < ntensors; i++) {
+   ret_per_tensor.push_back(at::empty({}, res_option));
+ }
+ auto& q = getCurrentSYCLQueue();
+ auto addressStorage =
+     at::empty({(int)(sizeof(void*) * ntensors)}, options.dtype(at::kByte));
+ auto metaAddress = static_cast<void**>(addressStorage.mutable_data_ptr());
+ void** tensor_list_addresses = nullptr;
+
+ auto tensor_list_addresses_dptr =
+     at::xpu::HostAlloc(sizeof(void*) * ntensors);
+ tensor_list_addresses = (void**)tensor_list_addresses_dptr.get();

auto tensor_lists = std::vector<std::vector<Tensor>>{tensors.vec()};

int64_t wg_size;
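
Because each result now owns its storage, the kernel needs a device-visible table of the output pointers. A condensed, comment-annotated restatement of the two staging buffers set up above (same calls as in the diff; surrounding code and error handling omitted):

```cpp
// Device-side byte tensor sized to hold one void* per result tensor; the
// reduce kernel reads the pointer table through metaAddress.
auto addressStorage =
    at::empty({(int)(sizeof(void*) * ntensors)}, options.dtype(at::kByte));
auto metaAddress = static_cast<void**>(addressStorage.mutable_data_ptr());

// Pinned host buffer in which the pointer table is assembled before being
// copied to the device asynchronously.
auto tensor_list_addresses_dptr = at::xpu::HostAlloc(sizeof(void*) * ntensors);
void** tensor_list_addresses = (void**)tensor_list_addresses_dptr.get();
```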
@@ -228,14 +228,25 @@
LpNormFunctor<scalar_t, NormType::L1, out_opmath_t>(),
output_per_tensor.mutable_data_ptr<out_opmath_t>(),
max_chunks_per_tensor);
+ for (int i = 0; i < ntensors; i++) {
+   tensor_list_addresses[i] =
+       ret_per_tensor[i].mutable_data_ptr<out_t>();
+ }
+ q.memcpy(
+     (void*)metaAddress,
+     (void*)tensor_list_addresses,
+     sizeof(void*) * ntensors);

// sum final val for all chunks
+ at::xpu::CachingHostAllocator_recordEvent(
+     (void*)tensor_list_addresses,
+     tensor_list_addresses_dptr.get_context(),
+     at::xpu::getCurrentXPUStream());
launch_lpnorm_chunk_reduce_kernel<
out_t,
NormType::L1,
out_opmath_t>(
output_per_tensor.mutable_data_ptr<out_opmath_t>(),
- ret_per_tensor.mutable_data_ptr<out_t>(),
+ (out_t**)(metaAddress),
wg_size,
max_chunks_per_tensor,
ntensors);
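
The L1 branch above (mirrored by the L2 branch below) follows a fixed ordering; a condensed sketch using the same calls as the diff, shown only to annotate the steps:

```cpp
// 1. Assemble the pointer table in the pinned host buffer.
for (int i = 0; i < ntensors; i++) {
  tensor_list_addresses[i] = ret_per_tensor[i].mutable_data_ptr<out_t>();
}
// 2. Asynchronously copy the table into the device-side byte tensor.
q.memcpy((void*)metaAddress, (void*)tensor_list_addresses, sizeof(void*) * ntensors);
// 3. Record an event on the current stream so the caching host allocator does
//    not recycle the pinned buffer before the pending memcpy has completed.
at::xpu::CachingHostAllocator_recordEvent(
    (void*)tensor_list_addresses,
    tensor_list_addresses_dptr.get_context(),
    at::xpu::getCurrentXPUStream());
// 4. Launch the chunk reduction, which dereferences each tensor's table entry.
launch_lpnorm_chunk_reduce_kernel<out_t, NormType::L1, out_opmath_t>(
    output_per_tensor.mutable_data_ptr<out_opmath_t>(),
    (out_t**)(metaAddress),
    wg_size,
    max_chunks_per_tensor,
    ntensors);
```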
@@ -267,13 +292,25 @@
LpNormFunctor<scalar_t, NormType::L2, out_opmath_t>(),
output_per_tensor.mutable_data_ptr<out_opmath_t>(),
max_chunks_per_tensor);
+ for (int i = 0; i < ntensors; i++) {
+   tensor_list_addresses[i] =
+       ret_per_tensor[i].mutable_data_ptr<out_t>();
+ }
+ q.memcpy(
+     (void*)metaAddress,
+     (void*)tensor_list_addresses,
+     sizeof(void*) * ntensors);

+ at::xpu::CachingHostAllocator_recordEvent(
+     (void*)tensor_list_addresses,
+     tensor_list_addresses_dptr.get_context(),
+     at::xpu::getCurrentXPUStream());
launch_lpnorm_chunk_reduce_kernel<
out_t,
NormType::L2,
out_opmath_t>(
output_per_tensor.mutable_data_ptr<out_opmath_t>(),
- ret_per_tensor.mutable_data_ptr<out_t>(),
+ (out_t**)(metaAddress),
wg_size,
max_chunks_per_tensor,
ntensors);
@@ -288,7 +325,6 @@
for (const auto& i : c10::irange(ntensors)) {
result.emplace_back(ret_per_tensor[i]);
}

return result;
}
