Skip to content

Commit e75fb64

Browse files
committed
rename functor and use fetch_max
1 parent 549aaab commit e75fb64

File tree

4 files changed

+27
-108
lines changed

4 files changed

+27
-108
lines changed

src/ATen/native/xpu/Distributions.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -277,8 +277,7 @@ Tensor& XPUNativeFunctions::multinomial_out(
277277
return result;
278278
}
279279

280-
at::native::xpu::multinomial_with_replacement_kernel(
281-
result, self, n_sample, gen);
280+
at::native::xpu::multinomial_kernel(result, self, n_sample, gen);
282281
return result;
283282
}
284283

src/ATen/native/xpu/sycl/Atomics.h

Lines changed: 0 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -115,92 +115,6 @@ struct AtomicIntegerImpl<T, 8> {
115115
address, val, [](DTYPE a, DTYPE b) { return OP; }); \
116116
}
117117

118-
template <typename T, size_t n>
119-
struct AtomicIntegerImplLocal;
120-
121-
template <typename T>
122-
struct AtomicIntegerImplLocal<T, 1> {
123-
template <typename func_t>
124-
inline void operator()(T* address, T val, const func_t& func) {
125-
size_t offset = (size_t)address & 3;
126-
uint32_t* address_as_ui = (uint32_t*)((char*)address - offset);
127-
uint32_t assumed = *address_as_ui;
128-
uint32_t shift = offset * 8;
129-
uint32_t newval;
130-
uint32_t newval_byte;
131-
sycl_atomic_ref_rlx_wg_local_t<uint32_t> target(*address_as_ui);
132-
133-
do {
134-
newval = assumed;
135-
newval_byte = (newval >> shift) & 0xff;
136-
// preserve size in initial cast. Casting directly to uint32_t pads
137-
// negative signed values with 1's (e.g. signed -1 = unsigned ~0).
138-
newval = static_cast<uint8_t>(func(val, static_cast<T>(newval_byte)));
139-
newval = (assumed & ~(0x000000ff << shift)) | (newval << shift);
140-
} while (!target.compare_exchange_strong(assumed, newval));
141-
}
142-
};
143-
144-
template <typename T>
145-
struct AtomicIntegerImplLocal<T, 2> {
146-
template <typename func_t>
147-
inline void operator()(T* address, T val, const func_t& func) {
148-
size_t offset = (size_t)address & 2;
149-
uint32_t* address_as_ui = (uint32_t*)((char*)address - offset);
150-
bool is_32_align = offset;
151-
uint32_t assumed = *address_as_ui;
152-
uint32_t newval;
153-
uint32_t newval_bytes;
154-
sycl_atomic_ref_rlx_wg_local_t<uint32_t> target(*address_as_ui);
155-
156-
do {
157-
newval = assumed;
158-
newval_bytes = is_32_align ? newval >> 16 : newval & 0xffff;
159-
// preserve size in initial cast. Casting directly to uint32_t pads
160-
// negative signed values with 1's (e.g. signed -1 = unsigned ~0).
161-
newval = static_cast<uint16_t>(func(val, static_cast<T>(newval_bytes)));
162-
newval = is_32_align ? (assumed & 0xffff) | (newval << 16)
163-
: (assumed & 0xffff0000) | newval;
164-
} while (!target.compare_exchange_strong(assumed, newval));
165-
}
166-
};
167-
168-
template <typename T>
169-
struct AtomicIntegerImplLocal<T, 4> {
170-
template <typename func_t>
171-
inline void operator()(T* address, T val, const func_t& func) {
172-
uint32_t* address_as_ui = (uint32_t*)(address);
173-
uint32_t assumed = *address_as_ui;
174-
uint32_t newval;
175-
sycl_atomic_ref_rlx_wg_local_t<uint32_t> target(*address_as_ui);
176-
177-
do {
178-
newval = static_cast<uint32_t>(func(val, static_cast<T>(assumed)));
179-
} while (!target.compare_exchange_strong(assumed, newval));
180-
}
181-
};
182-
183-
template <typename T>
184-
struct AtomicIntegerImplLocal<T, 8> {
185-
template <typename func_t>
186-
inline void operator()(T* address, T val, const func_t& func) {
187-
unsigned long long* address_as_ull = (unsigned long long*)(address);
188-
unsigned long long assumed = *address_as_ull;
189-
unsigned long long newval;
190-
sycl_atomic_ref_rlx_wg_local_t<unsigned long long> target(*address_as_ull);
191-
192-
do {
193-
newval = static_cast<uint64_t>(func(val, static_cast<T>(assumed)));
194-
} while (!target.compare_exchange_strong(assumed, newval));
195-
}
196-
};
197-
198-
#define SYCL_ATOMIC_INTEGER_LOCAL(NAME, OP, DTYPE) \
199-
static inline void atomic##NAME( \
200-
const sycl_local_ptr<DTYPE>& address, DTYPE val) { \
201-
AtomicIntegerImplLocal<DTYPE, sizeof(DTYPE)>()( \
202-
address, val, [](DTYPE a, DTYPE b) { return OP; }); \
203-
}
204118

205119
template <typename T>
206120
struct AtomicFPImpl;
@@ -361,11 +275,6 @@ SYCL_ATOMIC_INTEGER(Max, safe_max<int16_t>(a, b), int16_t)
361275
SYCL_ATOMIC_INTEGER(Max, safe_max<int32_t>(a, b), int32_t)
362276
SYCL_ATOMIC_INTEGER(Max, safe_max<int64_t>(a, b), int64_t)
363277

364-
SYCL_ATOMIC_INTEGER_LOCAL(Max, safe_max<uint8_t>(a, b), uint8_t)
365-
SYCL_ATOMIC_INTEGER_LOCAL(Max, safe_max<int8_t>(a, b), int8_t)
366-
SYCL_ATOMIC_INTEGER_LOCAL(Max, safe_max<int16_t>(a, b), int16_t)
367-
SYCL_ATOMIC_INTEGER_LOCAL(Max, safe_max<int32_t>(a, b), int32_t)
368-
SYCL_ATOMIC_INTEGER_LOCAL(Max, safe_max<int64_t>(a, b), int64_t)
369278

370279
SYCL_ATOMIC_FP(Max, safe_max<float>(a, b), float)
371280
SYCL_ATOMIC_FP(Max, safe_max<double>(a, b), double)

src/ATen/native/xpu/sycl/MultinomialKernel.cpp

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
#pragma clang diagnostic push
2+
#pragma GCC diagnostic push
3+
// Avoid SYCL compiler return-type error
4+
#pragma clang diagnostic ignored "-Wreturn-type"
5+
#pragma GCC diagnostic ignored "-Wreturn-type"
6+
17
#include <ATen/AccumulateType.h>
28
#include <ATen/core/Tensor.h>
39
#include <ATen/native/xpu/sycl/Atomics.h>
@@ -83,8 +89,9 @@ inline void renormRows(Tensor& t) {
8389
TORCH_CHECK(t.dim() == 2);
8490
int64_t rows = t.size(0);
8591
int64_t cols = t.size(1);
86-
87-
int group_size = syclMaxWorkItemsPerEU();
92+
int subgroup_size = syclMaxSubGroupSize();
93+
int group_size =
94+
std::min(int(syclMaxWorkItemsPerEU()), subgroup_size* subgroup_size);
8895
int num_groups = (rows + group_size - 1) / group_size;
8996
int hw_max_groups = syclMaxWorkItemsPerTile() / group_size;
9097
num_groups = num_groups > hw_max_groups ? hw_max_groups : num_groups;
@@ -232,7 +239,7 @@ struct MultinomialWithReplacementKernelImplFunctor {
232239
scalar_t* normDist_ptr;
233240
};
234241
template <typename scalar_t, typename accscalar_t>
235-
struct sampleMultinomialOnceFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
242+
struct SampleMultinomialOnceFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
236243
void operator()(sycl::nd_item<1> item) const {
237244
accscalar_t* smem = reinterpret_cast<accscalar_t*>(
238245
smem_.template get_multi_ptr<sycl::access::decorated::no>().get());
@@ -328,13 +335,15 @@ struct sampleMultinomialOnceFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
328335
if (inBucket) {
329336
// We're done; we have the sample
330337
// Torch indices are 1-based
331-
atomicMax(
332-
sycl_local_ptr<int>(
333-
foundPos_
334-
.template get_multi_ptr<sycl::access::decorated::no>()
335-
.get()),
336-
cat);
337-
// foundPos_[0] = 1;
338+
339+
sycl::atomic_ref<
340+
int,
341+
sycl_mem_odr_rlx,
342+
sycl_mem_scp_wg,
343+
sycl_local_space>
344+
target(foundPos_[0]);
345+
target.fetch_max(cat, sycl_mem_odr_acq_rel, sycl_mem_scp_wg);
346+
338347
found_[0] = true;
339348
}
340349

@@ -375,7 +384,7 @@ struct sampleMultinomialOnceFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
375384
foundPos_ = sycl_local_acc_t<int>(1, cgh);
376385
}
377386

378-
sampleMultinomialOnceFunctor(
387+
SampleMultinomialOnceFunctor(
379388
int64_t* dest,
380389
int64_t distributions,
381390
int categories,
@@ -407,7 +416,7 @@ struct sampleMultinomialOnceFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
407416
sycl_local_acc_t<int> foundPos_;
408417
};
409418

410-
void multinomial_with_replacement_kernel(
419+
void multinomial_kernel(
411420
Tensor& result,
412421
const Tensor& self,
413422
const int64_t n_sample,
@@ -446,7 +455,7 @@ void multinomial_with_replacement_kernel(
446455
at::native::uniform_(sampled, 0.0, 1.0, generator);
447456
int group_size = requiredThreads;
448457
int group_range = numDist;
449-
auto kfn = sampleMultinomialOnceFunctor<scalar_t, accscalar_t>(
458+
auto kfn = SampleMultinomialOnceFunctor<scalar_t, accscalar_t>(
450459
result.mutable_data_ptr<int64_t>(),
451460
numDist,
452461
numCategories,
@@ -530,4 +539,6 @@ void multinomial_with_replacement_kernel(
530539
}
531540
}
532541

533-
} // namespace at::native::xpu
542+
} // namespace at::native::xpu
543+
#pragma GCC diagnostic pop
544+
#pragma clang diagnostic pop

src/ATen/native/xpu/sycl/MultinomialKernel.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
namespace at::native::xpu {
55

6-
void multinomial_with_replacement_kernel(
6+
void multinomial_kernel(
77
Tensor& result,
88
const Tensor& self,
99
const int64_t n_sample,

0 commit comments

Comments
 (0)