Add aten::masked_scatter_ (#652)
Add aten::masked_scatter_.

---------

Co-authored-by: Feng Yuan <[email protected]>
xytintel and fengyuan14 authored Jul 30, 2024
1 parent: 8226837 · commit: 67116b3
Showing 8 changed files with 136 additions and 10 deletions.
30 changes: 30 additions & 0 deletions src/ATen/native/xpu/Indexing.cpp
@@ -45,6 +45,36 @@ Tensor XPUNativeFunctions::index_select(
return index_select_out(self, dim, index, out);
}

Tensor& XPUNativeFunctions::masked_scatter_(
Tensor& self,
const Tensor& mask,
const Tensor& source) {
at::assert_no_internal_overlap(self);
TORCH_CHECK(
self.scalar_type() == source.scalar_type(),
"masked_scatter_: expected self and source to have same dtypes but got ",
self.scalar_type(),
" and ",
source.scalar_type());
TORCH_CHECK(
mask.dtype() == ScalarType::Bool,
"masked_scatter_ only supports boolean masks, "
"but got mask with dtype ",
mask.dtype());

c10::MaybeOwned<Tensor> b_mask =
expand_inplace(self, mask, "masked_scatter_");

if (self.numel() == 0) {
return self;
}

auto maskPrefixSum = at::empty(self.sizes(), mask.options().dtype(kLong));
native::xpu::masked_scatter_kernel(self, *b_mask, maskPrefixSum, source);

return self;
}

static Tensor& masked_select_out_impl(
Tensor& result,
const Tensor& self,
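The wrapper above checks that `self` and `source` share a dtype, requires a boolean mask, broadcasts the mask against `self` via `expand_inplace`, returns early for empty tensors, and otherwise dispatches to the SYCL kernel with a freshly allocated `maskPrefixSum` buffer. A minimal sketch of the operator's semantics using the public `torch` API (illustration only, not backend-specific code):

```python
import torch

dest = torch.zeros(2, 4)
mask = torch.tensor([True, False, True, False])  # broadcasts against dest, as expand_inplace allows
src = torch.arange(1.0, 5.0)                     # must share dest's dtype

dest.masked_scatter_(mask, src)
# Masked positions are filled with consecutive elements of src:
# tensor([[1., 0., 2., 0.],
#         [3., 0., 4., 0.]])
```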
1 change: 0 additions & 1 deletion src/ATen/native/xpu/XPUFallback.template
@@ -224,7 +224,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
"log_normal_",
"logspace.out",
"lu_unpack.out",
"masked_scatter_",
"max_pool3d_with_indices",
"max_pool3d_with_indices_backward",
"max_unpool2d",
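Dropping `masked_scatter_` from the fallback list means the operator no longer routes to the CPU fallback on XPU devices; the native registration above handles it instead. A smoke test along these lines should exercise the new path, assuming a PyTorch build that exposes the `torch.xpu` device module:

```python
import torch

if torch.xpu.is_available():
    dest = torch.zeros(5, device="xpu")
    mask = torch.tensor([True, False, True, False, True], device="xpu")
    src = torch.arange(1.0, 4.0, device="xpu")
    dest.masked_scatter_(mask, src)  # now dispatches to the native XPU kernel
    print(dest.cpu())                # expected: tensor([1., 0., 2., 0., 3.])
```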
90 changes: 90 additions & 0 deletions src/ATen/native/xpu/sycl/Indexing.cpp
@@ -597,6 +597,7 @@ void index_put_deterministic_kernel(

if (expandedValue.numel() < num_indices * nElemBefore * sliceSize) {
auto expanded_size = at::DimVector(expandedValue.sizes());

auto size1 = expandedValue.sizes();
auto size2 = linearIndex.sizes();
if (are_expandable(size1, size2)) {
@@ -667,6 +668,95 @@ void index_put_deterministic_kernel(
}
}

template <typename scalar_t>
struct MaskedScatterElementwiseFunctor {
scalar_t operator()(
const scalar_t a,
const bool mask,
const int64_t maskPrefixSum) const {
if (mask) {
return source_ptr_[maskPrefixSum];
}
return a;
}
MaskedScatterElementwiseFunctor(const scalar_t* source_ptr)
: source_ptr_(source_ptr) {}

private:
const scalar_t* source_ptr_;
};

struct MaskedScatterSizeCheckFunctor {
void operator()(sycl::nd_item<1> item) const {
const auto totalElements = *mask_exclusive_sum_ + *mask_;
SYCL_KERNEL_ASSERT(totalElements <= srcSize_);
}
MaskedScatterSizeCheckFunctor(
const int64_t* const mask_exclusive_sum,
const bool* const mask,
const int64_t srcSize)
: mask_exclusive_sum_(mask_exclusive_sum),
mask_(mask),
srcSize_(srcSize) {}

private:
const int64_t* const mask_exclusive_sum_;
const bool* const mask_;
const int64_t srcSize_;
};

void masked_scatter_kernel(
const TensorBase& self,
const TensorBase& mask,
const TensorBase& maskPrefixSum,
const TensorBase& source) {
const auto srcSize = source.numel();
const auto mask_cont = mask.contiguous();
const auto mask_numel = mask.numel();

// Use a prefix sum to determine the output locations of the masked elements
auto maskPrefixSum_data = maskPrefixSum.mutable_data_ptr<int64_t>();
auto mask_data = mask_cont.const_data_ptr<bool>();

pstl::exclusive_scan(
mask_data,
mask_data + mask_numel,
maskPrefixSum_data,
static_cast<int64_t>(0));

// Asynchronously check that the number of `1` elements present in the mask
// is <= the number of elements available in `src`.
auto caller = MaskedScatterSizeCheckFunctor(
&maskPrefixSum_data[mask_numel - 1], &mask_data[mask_numel - 1], srcSize);
sycl_kernel_submit((size_t)1, (size_t)1, getCurrentSYCLQueue(), caller);

// We are getting elements from `src` based on an offset from
// `maskPrefixSum`, so that should be made contiguous too
auto source_contig = source.contiguous();

auto iter = TensorIteratorConfig()
.set_check_mem_overlap(false)
.check_all_same_dtype(false)
.resize_outputs(false)
.add_output(self)
.add_input(self)
.add_const_input(mask_cont)
.add_input(maskPrefixSum)
.build();

AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
ScalarType::Bool,
ScalarType::BFloat16,
ScalarType::Half,
self.scalar_type(),
"masked_scatter_",
[&]() {
auto source_ptr = source_contig.const_data_ptr<scalar_t>();
gpu_kernel(iter, MaskedScatterElementwiseFunctor<scalar_t>(source_ptr));
});
}

} // namespace at::native::xpu

#pragma GCC diagnostic pop
#pragma clang diagnostic pop
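The kernel's strategy: exclusive-scan the contiguous mask so that each masked position knows its offset into the contiguous `source` buffer, assert on-device (via `MaskedScatterSizeCheckFunctor`) that the final prefix value plus the final mask element does not exceed `source.numel()`, then run an elementwise pass that reads `source_ptr_[maskPrefixSum]` wherever the mask is set. The same arithmetic in plain Python, purely as an illustration:

```python
import itertools

mask = [True, False, True, True, False, True]
source = [10, 20, 30, 40]

# Exclusive prefix sum: prefix[i] == number of set mask elements before index i.
prefix = list(itertools.accumulate([0] + [int(m) for m in mask[:-1]]))

# The bound the size-check functor asserts on-device.
assert prefix[-1] + int(mask[-1]) <= len(source)

# The per-element lookup the elementwise functor performs.
dest = [source[prefix[i]] if m else 0 for i, m in enumerate(mask)]
print(dest)  # [10, 0, 20, 30, 0, 40]
```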
6 changes: 6 additions & 0 deletions src/ATen/native/xpu/sycl/IndexingKernels.h
@@ -47,4 +47,10 @@ void index_put_deterministic_kernel(
bool accumulate,
bool unsafe);

void masked_scatter_kernel(
const TensorBase& self,
const TensorBase& mask,
const TensorBase& maskPrefixSum,
const TensorBase& source);

} // namespace at::native::xpu
13 changes: 6 additions & 7 deletions src/ATen/native/xpu/sycl/pstl/PSTLFunctions.h
@@ -1,11 +1,10 @@
#pragma once

#include <ATen/ceil_div.h>
#include <ATen/record_function.h>

#include <ATen/native/xpu/sycl/MemoryAccess.h>
#include <ATen/native/xpu/sycl/MemoryAccessUtils.h>
#include <ATen/native/xpu/sycl/SortingKernels.h>
#include <ATen/record_function.h>
#include <comm/SYCLContext.h>
#include <comm/SYCLHelpers.h>
#include <comm/TensorOptions.h>
@@ -23,10 +22,10 @@ struct KSScanKernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
// initialize local_input
auto cur_init = init_;
if (scan_type == 1) {
local_scan_[local_id] = first_[local_id];
local_scan_[local_id] = c10::load(&first_[local_id]);
} else {
if (local_id > 0)
local_scan_[local_id] = first_[local_id - 1];
local_scan_[local_id] = c10::load(&first_[local_id - 1]);
else
local_scan_[local_id] = 0;
}
@@ -72,17 +71,17 @@ struct KSScanWithCarrierKernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
auto cur_init = (group_id == 0 ? init_ : 0);
if (global_id < N_) {
if (scan_type == 1) {
local_scan_[local_id] = first_[global_id];
local_scan_[local_id] = c10::load(&first_[global_id]);
} else {
if (local_id > 0)
local_scan_[local_id] = first_[global_id - 1];
local_scan_[local_id] = c10::load(&first_[global_id - 1]);
else
local_scan_[local_id] = 0;
}
if (local_id == 0)
local_scan_[local_id] += cur_init;
if (local_id == wgroup_size_ - 1) {
carry_ptr_[group_id] = first_[global_id];
carry_ptr_[group_id] = c10::load(&first_[global_id]);
}
}
item_id.barrier(sycl_local_fence);
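The `c10::load` calls only change how these scan kernels read their inputs (notably `bool` buffers); the scan arithmetic itself is untouched. For orientation, the two branches correspond to the usual inclusive (`scan_type == 1`) and init-seeded, shifted-by-one exclusive scan, roughly:

```python
import itertools

data = [1, 0, 1, 1]
init = 0

inclusive = list(itertools.accumulate(data, initial=init))[1:]   # scan_type == 1 branch
exclusive = list(itertools.accumulate(data, initial=init))[:-1]  # shifted branch
print(inclusive)  # [1, 1, 2, 3]
print(exclusive)  # [0, 1, 1, 2]
```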
4 changes: 2 additions & 2 deletions test/xpu/test_torch_xpu.py
@@ -3995,11 +3995,11 @@ def test_masked_scatter(self, device, dtype):
dest_ones.masked_scatter_(mask, src_ones)
self.assertEqual(dest_ones, dest_ones_expected, atol=0, rtol=0)

# Bound checking in CUDA is done inside a kernel
# Bound checking in GPU is done inside a kernel
# in order to avoid synchronization, but this means
# we can not clear the failures. So there is no way
# to test it then recover.
if self.device_type != 'cuda' or self.device_type != 'xpu':
if self.device_type != 'cuda' and self.device_type != 'xpu':
# make src smaller. this should fail
src = torch.zeros(num_copy - 1, dtype=dt, device=device)
with self.assertRaises(RuntimeError):
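The guard change matters because, for any single `device_type`, at least one of the two inequalities holds, so the old `or` form was always true and the bound-check block was never skipped on CUDA or XPU. A quick truth-table check:

```python
for device_type in ("cuda", "xpu", "cpu"):
    old = device_type != "cuda" or device_type != "xpu"    # always True
    new = device_type != "cuda" and device_type != "xpu"   # True only for non-GPU devices
    print(device_type, old, new)
# cuda True False
# xpu True False
# cpu True True
```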
1 change: 1 addition & 0 deletions test/xpu/xpu_test_utils.py
@@ -80,6 +80,7 @@
"index_fill",
"index_put",
"index_select",
"masked_scatter",
"masked_select",
"isin",
"isnan",
1 change: 1 addition & 0 deletions yaml/xpu_functions.yaml
@@ -105,6 +105,7 @@ supported:
- isnan.out
- masked_fill_.Tensor
- masked_fill_.Scalar
- masked_scatter_
- index_add.out
- index_add_
- index_add
