Skip to content

Commit 22c2bcb

Browse files
committed
Support local_accessor, group::get_local_range() inside hierarchical PF
1 parent cdd637e commit 22c2bcb

File tree

7 files changed

+137
-93
lines changed

7 files changed

+137
-93
lines changed

include/simsycl/detail/group_operation_impl.hh

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,6 @@ struct concurrent_group {
163163
std::vector<concurrent_nd_item *> concurrent_nd_items;
164164
std::vector<allocation> local_memory_allocations;
165165
group_instance instance;
166-
size_t cur_hier_local_size = 0;
167166
};
168167

169168
template<int Dimensions>

include/simsycl/detail/schedule.hh

Lines changed: 35 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
11
#pragma once
22

3-
#include <cstddef>
4-
#include <cstring>
5-
#include <memory>
6-
#include <vector>
3+
#include "allocation.hh"
74

85
#include "../sycl/device.hh"
96
#include "../sycl/forward.hh"
@@ -14,6 +11,11 @@
1411
#include "../sycl/nd_range.hh"
1512
#include "../sycl/range.hh"
1613

14+
#include <cstddef>
15+
#include <cstring>
16+
#include <memory>
17+
#include <vector>
18+
1719

1820
namespace simsycl::detail {
1921

@@ -85,9 +87,15 @@ template<typename WorkgroupFunctionType>
8587
void sequential_for_work_group(sycl::range<1> num_work_groups, std::optional<sycl::range<1>> work_group_size,
8688
const WorkgroupFunctionType &kernel_func) {
8789
sycl::id<1> group_id;
90+
const auto type
91+
= work_group_size.has_value() ? group_type::hierarchical_explicit_size : group_type::hierarchical_implicit_size;
8892
for(group_id[0] = 0; group_id[0] < num_work_groups[0]; ++group_id[0]) {
8993
concurrent_group impl;
90-
sycl::group<1> group = make_hierarchical_group(make_item(group_id, num_work_groups), work_group_size, &impl);
94+
const auto group_item = make_item(group_id, num_work_groups);
95+
const auto local_item = make_item(sycl::id(0), work_group_size.value_or(sycl::range(1)));
96+
const auto global_item = make_item(
97+
group_id * sycl::id(local_item.get_range()), local_item.get_range() * group_item.get_range(), sycl::id(0));
98+
sycl::group<1> group = make_group(type, local_item, global_item, group_item, &impl);
9199
kernel_func(group);
92100
}
93101
}
@@ -96,11 +104,16 @@ template<typename WorkgroupFunctionType>
96104
void sequential_for_work_group(sycl::range<2> num_work_groups, std::optional<sycl::range<2>> work_group_size,
97105
const WorkgroupFunctionType &kernel_func) {
98106
sycl::id<2> group_id;
107+
const auto type
108+
= work_group_size.has_value() ? group_type::hierarchical_explicit_size : group_type::hierarchical_implicit_size;
99109
for(group_id[0] = 0; group_id[0] < num_work_groups[0]; ++group_id[0]) {
100110
for(group_id[1] = 0; group_id[1] < num_work_groups[1]; ++group_id[1]) {
101111
concurrent_group impl;
102-
sycl::group<2> group
103-
= make_hierarchical_group(make_item(group_id, num_work_groups), work_group_size, &impl);
112+
const auto group_item = make_item(group_id, num_work_groups);
113+
const auto local_item = make_item(sycl::id(0, 0), work_group_size.value_or(sycl::range(1, 1)));
114+
const auto global_item = make_item(group_id * sycl::id(local_item.get_range()),
115+
local_item.get_range() * group_item.get_range(), sycl::id(0, 0));
116+
sycl::group<2> group = make_group(type, local_item, global_item, group_item, &impl);
104117
kernel_func(group);
105118
}
106119
}
@@ -110,12 +123,17 @@ template<typename WorkgroupFunctionType>
110123
void sequential_for_work_group(sycl::range<3> num_work_groups, std::optional<sycl::range<3>> work_group_size,
111124
const WorkgroupFunctionType &kernel_func) {
112125
sycl::id<3> group_id;
126+
const auto type
127+
= work_group_size.has_value() ? group_type::hierarchical_explicit_size : group_type::hierarchical_implicit_size;
113128
for(group_id[0] = 0; group_id[0] < num_work_groups[0]; ++group_id[0]) {
114129
for(group_id[1] = 0; group_id[1] < num_work_groups[1]; ++group_id[1]) {
115130
for(group_id[2] = 0; group_id[2] < num_work_groups[2]; ++group_id[2]) {
116131
concurrent_group impl;
117-
sycl::group<3> group
118-
= make_hierarchical_group(make_item(group_id, num_work_groups), work_group_size, &impl);
132+
const auto group_item = make_item(group_id, num_work_groups);
133+
const auto local_item = make_item(sycl::id(0, 0, 0), work_group_size.value_or(sycl::range(1, 1, 1)));
134+
const auto global_item = make_item(group_id * sycl::id(local_item.get_range()),
135+
local_item.get_range() * group_item.get_range(), sycl::id(0, 0, 0));
136+
sycl::group<3> group = make_group(type, local_item, global_item, group_item, &impl);
119137
kernel_func(group);
120138
}
121139
}
@@ -197,10 +215,16 @@ void parallel_for(const sycl::device &device, sycl::nd_range<Dimensions> executi
197215
std::index_sequence<sizeof...(Rest) - 1>());
198216
}
199217

218+
template<int Dimensions>
219+
[[nodiscard]] std::vector<allocation> prepare_hierarchical_parallel_for(const sycl::device &device,
220+
std::optional<sycl::range<Dimensions>> work_group_size, const std::vector<local_memory_requirement> &local_memory);
221+
200222
template<typename KernelName, int Dimensions, typename WorkgroupFunctionType>
201-
void parallel_for_work_group(sycl::range<Dimensions> num_work_groups,
202-
std::optional<sycl::range<Dimensions>> work_group_size, const WorkgroupFunctionType &kernel_func) {
223+
void parallel_for_work_group(const sycl::device &device, sycl::range<Dimensions> num_work_groups,
224+
std::optional<sycl::range<Dimensions>> work_group_size, const std::vector<local_memory_requirement> &local_memory,
225+
const WorkgroupFunctionType &kernel_func) {
203226
register_kernel_on_static_construction<KernelName, WorkgroupFunctionType>();
227+
const auto local_allocations = prepare_hierarchical_parallel_for(device, work_group_size, local_memory);
204228
sequential_for_work_group(num_work_groups, work_group_size, kernel_func);
205229
}
206230

include/simsycl/sycl/group.hh

Lines changed: 53 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -14,37 +14,39 @@
1414

1515
namespace simsycl::detail {
1616

17+
enum class group_type { nd_range, hierarchical_implicit_size, hierarchical_explicit_size };
18+
1719
template<int Dimensions>
18-
sycl::group<Dimensions> make_group(const sycl::item<Dimensions, false> &local_item,
20+
sycl::group<Dimensions> make_group(const group_type type, const sycl::item<Dimensions, false> &local_item,
1921
const sycl::item<Dimensions, true> &global_item, const sycl::item<Dimensions, false> &group_item,
2022
detail::concurrent_group *impl) {
21-
return sycl::group<Dimensions>(local_item, global_item, group_item, impl);
22-
}
23-
24-
template<int Dimensions>
25-
sycl::group<Dimensions> make_hierarchical_group(const sycl::item<Dimensions, false> &group_item,
26-
const std::optional<sycl::range<Dimensions>> &hier_local_range, detail::concurrent_group *impl) {
27-
return sycl::group<Dimensions>(group_item, hier_local_range, impl);
23+
return sycl::group<Dimensions>(type, local_item, global_item, group_item, impl);
2824
}
2925

3026
template<int Dimensions>
31-
bool is_hierarchical_group(const sycl::group<Dimensions> &g) {
32-
return g.m_hierarchical;
27+
group_type get_group_type(const sycl::group<Dimensions> &g) {
28+
return g.m_type;
3329
}
3430

35-
template<typename G>
31+
template<typename G, int Dimensions>
3632
class hierarchical_group_size_setter {
3733
public:
38-
hierarchical_group_size_setter(G &g, size_t size) : m_g(g) {
39-
m_old_size = get_concurrent_group(m_g).cur_hier_local_size;
40-
get_concurrent_group(m_g).cur_hier_local_size = size;
34+
hierarchical_group_size_setter(G &g, sycl::range<Dimensions> flexible_size)
35+
: m_g(g), m_old_local_item(g.m_local_item), m_old_global_item(g.m_global_item) {
36+
g.m_local_item = simsycl::detail::make_item(sycl::id<Dimensions>(), flexible_size);
37+
g.m_global_item
38+
= simsycl::detail::make_item(sycl::id<Dimensions>(), g.m_group_item.get_range() * flexible_size);
4139
}
4240

43-
~hierarchical_group_size_setter() { get_concurrent_group(m_g).cur_hier_local_size = m_old_size; }
41+
~hierarchical_group_size_setter() {
42+
m_g.m_local_item = m_old_local_item;
43+
m_g.m_global_item = m_old_global_item;
44+
}
4445

4546
private:
4647
G &m_g;
47-
size_t m_old_size;
48+
sycl::item<Dimensions, false> m_old_local_item;
49+
sycl::item<Dimensions, true> m_old_global_item;
4850
};
4951

5052
} // namespace simsycl::detail
@@ -79,53 +81,35 @@ class group {
7981

8082
size_t get_group_id(int dimension) const { return m_group_item.get_id()[dimension]; }
8183

82-
SIMSYCL_DETAIL_DEPRECATED_IN_SYCL range<Dimensions> get_global_range() const {
84+
[[deprecated("non-standard")]] range<Dimensions> get_global_range() const {
8385
SIMSYCL_CHECK(
84-
!m_hierarchical && "get_global_range is not supported for from within a parallel_for_work_item context");
86+
m_global_item.get_range().size() != 0 && "get_global_range called from hierarchical group scope?");
8587
return m_global_item.get_range();
8688
}
8789

88-
size_t get_global_range(int dimension) const {
89-
SIMSYCL_CHECK(
90-
!m_hierarchical && "get_global_range is not supported for from within a parallel_for_work_item context");
90+
[[deprecated("non-standard")]] size_t get_global_range(int dimension) const {
9191
return get_global_range()[dimension];
9292
}
9393

9494
id_type get_local_id() const {
95-
SIMSYCL_CHECK(
96-
!m_hierarchical && "get_local_id is not supported for from within a parallel_for_work_item context");
95+
SIMSYCL_CHECK(m_type == detail::group_type::nd_range
96+
&& "get_local_id is not supported for from within a parallel_for_work_item context");
9797
return m_local_item.get_id();
9898
}
9999

100-
size_t get_local_id(int dimension) const {
101-
SIMSYCL_CHECK(
102-
!m_hierarchical && "get_local_id is not supported for from within a parallel_for_work_item context");
103-
return get_local_id()[dimension];
104-
}
100+
size_t get_local_id(int dimension) const { return get_local_id()[dimension]; }
105101

106102
size_t get_local_linear_id() const {
107-
SIMSYCL_CHECK(
108-
!m_hierarchical && "get_local_linear_id is not supported for from within a parallel_for_work_item context");
103+
SIMSYCL_CHECK(m_type == detail::group_type::nd_range
104+
&& "get_local_linear_id is not supported for from within a parallel_for_work_item context");
109105
return m_local_item.get_linear_id();
110106
}
111107

112-
range_type get_local_range() const {
113-
SIMSYCL_CHECK(
114-
!m_hierarchical && "get_local_range is not supported for from within a parallel_for_work_item context");
115-
return m_local_item.get_range();
116-
}
108+
range_type get_local_range() const { return m_local_item.get_range(); }
117109

118-
size_t get_local_range(int dimension) const {
119-
SIMSYCL_CHECK(
120-
!m_hierarchical && "get_local_range is not supported for from within a parallel_for_work_item context");
121-
return get_local_range()[dimension];
122-
}
110+
size_t get_local_range(int dimension) const { return get_local_range()[dimension]; }
123111

124-
size_t get_local_linear_range() const {
125-
SIMSYCL_CHECK(
126-
!m_hierarchical && "get_local_range is not supported for from within a parallel_for_work_item context");
127-
return get_local_range().size();
128-
}
112+
size_t get_local_linear_range() const { return get_local_range().size(); }
129113

130114
range_type get_group_range() const { return m_group_item.get_range(); }
131115

@@ -144,26 +128,27 @@ class group {
144128
size_t get_group_linear_id() const { return m_group_item.get_linear_id(); }
145129

146130
bool leader() const {
147-
SIMSYCL_CHECK(!m_hierarchical && "leader() is not supported for from within a parallel_for_work_item context");
131+
SIMSYCL_CHECK(m_type == detail::group_type::nd_range
132+
&& "leader() is not supported for from within a parallel_for_work_item context");
148133
return (get_local_linear_id() == 0);
149134
}
150135

151136
template<typename WorkItemFunctionT>
152137
void parallel_for_work_item(WorkItemFunctionT func) const {
153-
SIMSYCL_CHECK(m_hierarchical
138+
SIMSYCL_CHECK(m_type != detail::group_type::nd_range
154139
&& "parallel_for_work_item is only supported for from within a parallel_for_work_item context");
155-
SIMSYCL_CHECK(m_hier_local_range.has_value()
140+
SIMSYCL_CHECK(m_type != detail::group_type::hierarchical_implicit_size
156141
&& "parallel_for_work_item(func) without a range argument is only supported in a parallel_for_work_item "
157142
"context with a set local range");
158-
parallel_for_work_item(m_hier_local_range.value(), func);
143+
parallel_for_work_item(m_local_item.get_range(), func);
159144
}
160145

161146
// All parallel_for_work_item calls within a given parallel_for_work_group execution must have the same dimensions
162147
template<typename WorkItemFunctionT>
163148
void parallel_for_work_item(range<Dimensions> flexible_range, WorkItemFunctionT func) const {
164-
SIMSYCL_CHECK(m_hierarchical
149+
SIMSYCL_CHECK(m_type != detail::group_type::nd_range
165150
&& "parallel_for_work_item is only supported for from within a parallel_for_work_item context");
166-
detail::hierarchical_group_size_setter set(*this, flexible_range.size());
151+
detail::hierarchical_group_size_setter set(*this, flexible_range);
167152
if constexpr(Dimensions == 1) {
168153
for(size_t i = 0; i < flexible_range[0]; ++i) {
169154
const auto global_id = m_group_item.get_id() * flexible_range[0] + i;
@@ -292,33 +277,27 @@ class group {
292277
friend bool operator!=(const group<Dimensions> &lhs, const group<Dimensions> &rhs) { return !(lhs == rhs); }
293278

294279
private:
295-
item<Dimensions, false /* WithOffset */> m_local_item;
296-
item<Dimensions, true /* WithOffset */> m_global_item;
297-
item<Dimensions, false /* WithOffset */> m_group_item;
298-
detail::concurrent_group *m_concurrent_group;
299-
300-
bool m_hierarchical = false;
301-
std::optional<range<Dimensions>> m_hier_local_range;
280+
template<typename G, int D>
281+
friend class detail::hierarchical_group_size_setter;
302282

303-
group(const item<Dimensions, false> &local_item, const item<Dimensions, true> &global_item,
304-
const item<Dimensions, false> &group_item, detail::concurrent_group *impl)
305-
: m_local_item(local_item), m_global_item(global_item), m_group_item(group_item), m_concurrent_group(impl) {}
283+
friend group<Dimensions> detail::make_group<Dimensions>(const detail::group_type type,
284+
const sycl::item<Dimensions, false> &local_item, const sycl::item<Dimensions, true> &global_item,
285+
const sycl::item<Dimensions, false> &group_item, detail::concurrent_group *impl);
306286

307-
group(const item<Dimensions, false> &group_item, const std::optional<range<Dimensions>> &hier_local_range,
308-
detail::concurrent_group *impl)
309-
: m_local_item(group_item), m_global_item(group_item), m_group_item(group_item), m_concurrent_group(impl),
310-
m_hierarchical(true), m_hier_local_range(hier_local_range) {}
311-
312-
friend group<Dimensions> detail::make_group<Dimensions>(const sycl::item<Dimensions, false> &local_item,
313-
const sycl::item<Dimensions, true> &global_item, const sycl::item<Dimensions, false> &group_item,
314-
detail::concurrent_group *impl);
287+
friend detail::group_type detail::get_group_type(const sycl::group<Dimensions> &g);
288+
friend detail::concurrent_group &detail::get_concurrent_group<Dimensions>(const sycl::group<Dimensions> &g);
315289

316-
friend group<Dimensions> detail::make_hierarchical_group<Dimensions>(
317-
const sycl::item<Dimensions, false> &group_item, const std::optional<sycl::range<Dimensions>> &hier_local_range,
318-
detail::concurrent_group *impl);
290+
detail::group_type m_type;
291+
mutable item<Dimensions, false /* WithOffset */> m_local_item; // mutable for hierarchical_group_size_setter
292+
mutable item<Dimensions, true /* WithOffset */> m_global_item; // mutable for hierarchical_group_size_setter
293+
item<Dimensions, false /* WithOffset */> m_group_item;
294+
detail::concurrent_group *m_concurrent_group;
319295

320-
friend bool detail::is_hierarchical_group<Dimensions>(const sycl::group<Dimensions> &g);
321-
friend detail::concurrent_group &detail::get_concurrent_group<Dimensions>(const sycl::group<Dimensions> &g);
296+
group(const detail::group_type type, const item<Dimensions, false> &local_item,
297+
const item<Dimensions, true> &global_item, const item<Dimensions, false> &group_item,
298+
detail::concurrent_group *impl)
299+
: m_type(type), m_local_item(local_item), m_global_item(global_item), m_group_item(group_item),
300+
m_concurrent_group(impl) {}
322301
};
323302

324303
template<int Dimensions>

include/simsycl/sycl/handler.hh

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,13 +94,14 @@ class handler {
9494

9595
template<typename KernelName = simsycl::detail::unnamed_kernel, typename WorkgroupFunctionType, int Dimensions>
9696
void parallel_for_work_group(range<Dimensions> num_work_groups, const WorkgroupFunctionType &kernel_func) {
97-
detail::parallel_for_work_group<KernelName>(num_work_groups, {}, kernel_func);
97+
detail::parallel_for_work_group<KernelName>(m_device, num_work_groups, {}, m_local_memory, kernel_func);
9898
}
9999

100100
template<typename KernelName = simsycl::detail::unnamed_kernel, typename WorkgroupFunctionType, int Dimensions>
101101
void parallel_for_work_group(range<Dimensions> num_work_groups, range<Dimensions> work_group_size,
102102
const WorkgroupFunctionType &kernel_func) {
103-
detail::parallel_for_work_group<KernelName>(num_work_groups, {work_group_size}, kernel_func);
103+
detail::parallel_for_work_group<KernelName>(
104+
m_device, num_work_groups, {work_group_size}, m_local_memory, kernel_func);
104105
}
105106

106107
void single_task(const kernel &kernel_object);

include/simsycl/sycl/private_memory.hh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ class private_memory {
1616
// Construct the storage if it has not yet been constructed
1717
T &operator()(const h_item<Dimensions> &id) {
1818
if(m_data.empty()) {
19-
size_t num_items = simsycl::detail::get_concurrent_group(m_group).cur_hier_local_size;
19+
size_t num_items = m_group.get_local_linear_range();
2020
m_data.resize(num_items);
2121
}
2222
return m_data[id.get_local().get_linear_id()];

0 commit comments

Comments
 (0)