Skip to content

Commit b85ca1d

Browse files
committed
[WIP] late-bound stream scheduler algorithm customizations
1 parent de51303 commit b85ca1d

File tree

6 files changed

+221
-26
lines changed

6 files changed

+221
-26
lines changed

examples/nvexec/maxwell/snr.cuh

+6-4
Original file line numberDiff line numberDiff line change
@@ -346,10 +346,12 @@ auto maxwell_eqs_snr(
346346
return ex::just()
347347
| exec::on(
348348
computer,
349-
repeat_n(
350-
n_iterations,
351-
ex::bulk(accessor.cells, update_h(accessor))
352-
| ex::bulk(accessor.cells, update_e(time, dt, accessor))))
349+
ex::bulk(accessor.cells, update_h(accessor))
350+
| ex::bulk(accessor.cells, update_e(time, dt, accessor)))
351+
// repeat_n(
352+
// n_iterations,
353+
// ex::bulk(accessor.cells, update_h(accessor))
354+
// | ex::bulk(accessor.cells, update_e(time, dt, accessor))))
353355
| ex::then(dump_vtk(write_results, accessor));
354356
}
355357

include/nvexec/stream/common.cuh

+23-4
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,8 @@ namespace nvexec {
165165
}
166166
};
167167

168+
struct stream_scheduler;
169+
168170
struct context_state_t {
169171
std::pmr::memory_resource* pinned_resource_{nullptr};
170172
std::pmr::memory_resource* managed_resource_{nullptr};
@@ -195,14 +197,22 @@ namespace nvexec {
195197
void return_stream(cudaStream_t stream) {
196198
stream_pools_->return_stream(stream, priority_);
197199
}
200+
201+
stream_scheduler make_stream_scheduler() const noexcept;
198202
};
199203

200-
struct stream_scheduler;
204+
template <class = stream_scheduler>
205+
struct stream_domain;
201206

202207
struct stream_sender_base {
203208
using is_sender = void;
204209
};
205210

211+
// needed for subsumption purposes
212+
template <class Sender, class Env>
213+
concept _non_stream_sender = //
214+
!derived_from<__decay_t<Sender>, stream_sender_base>;
215+
206216
struct stream_receiver_base : __receiver_base {
207217
constexpr static std::size_t memory_allocation_size = 0;
208218
};
@@ -265,6 +275,10 @@ namespace nvexec {
265275
stream_provider_t* operator()(const Env& env) const noexcept {
266276
return tag_invoke(get_stream_provider_t{}, env);
267277
}
278+
279+
friend constexpr bool tag_invoke(forwarding_query_t, const get_stream_provider_t&) noexcept {
280+
return true;
281+
}
268282
};
269283

270284
template <class... Ts>
@@ -308,7 +322,10 @@ namespace nvexec {
308322
using variant_storage_t = //
309323
__minvoke< __minvoke<
310324
__mfold_right<
311-
__mbind_front_q<stream_storage_impl::variant, ::cuda::std::tuple<set_noop>>,
325+
__mbind_front_q<
326+
stream_storage_impl::variant,
327+
::cuda::std::tuple<set_noop>,
328+
::cuda::std::tuple<set_error_t, cudaError_t>>,
312329
__mbind_front_q<stream_storage_impl::__bind_completions_t, _Sender, _Env>>,
313330
set_value_t,
314331
set_error_t,
@@ -570,7 +587,8 @@ namespace nvexec {
570587

571588
template <__decays_to<cudaError_t> Error>
572589
void propagate_completion_signal(set_error_t, Error&& status) noexcept {
573-
if constexpr (stream_receiver<outer_receiver_t>) {
590+
using Domain = __env_domain_of_t<env_of_t<outer_receiver_t>>;
591+
if constexpr (stream_receiver<outer_receiver_t> || same_as<Domain, stream_domain<>>) {
574592
set_error((outer_receiver_t&&) rcvr_, (cudaError_t&&) status);
575593
} else {
576594
// pass a cudaError_t by value:
@@ -581,7 +599,8 @@ namespace nvexec {
581599

582600
template <class Tag, class... As>
583601
void propagate_completion_signal(Tag, As&&... as) noexcept {
584-
if constexpr (stream_receiver<outer_receiver_t>) {
602+
using Domain = __env_domain_of_t<env_of_t<outer_receiver_t>>;
603+
if constexpr (stream_receiver<outer_receiver_t> || same_as<Domain, stream_domain<>>) {
585604
Tag()((outer_receiver_t&&) rcvr_, (As&&) as...);
586605
} else {
587606
continuation_kernel<outer_receiver_t, As&&...> // by reference

include/nvexec/stream/wrap.cuh

+98
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
/*
2+
* Copyright (c) 2022 NVIDIA Corporation
3+
*
4+
* Licensed under the Apache License Version 2.0 with LLVM Exceptions
5+
* (the "License"); you may not use this file except in compliance with
6+
* the License. You may obtain a copy of the License at
7+
*
8+
* https://llvm.org/LICENSE.txt
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
#pragma once
17+
18+
#include "../../stdexec/execution.hpp"
19+
#include <type_traits>
20+
21+
#include "common.cuh"
22+
23+
namespace nvexec::STDEXEC_STREAM_DETAIL_NS {
24+
namespace _wrap {
25+
template <class SenderId>
26+
struct sender : stream_sender_base {
27+
using is_sender = void;
28+
using Sender = stdexec::__t<SenderId>;
29+
using __t = sender;
30+
using __id = sender;
31+
32+
sender(Sender sndr, context_state_t context_state)
33+
: sndr_(std::move(sndr))
34+
, env_{context_state} {
35+
}
36+
37+
struct environment {
38+
context_state_t context_state_;
39+
40+
template <same_as<environment> Self>
41+
friend auto tag_invoke(get_completion_scheduler_t<set_value_t>, const Self& env) noexcept {
42+
return env.context_state_.make_stream_scheduler();
43+
}
44+
};
45+
46+
// BUGBUG this doesn't handle the case where the sender has a nested
47+
// type alias named completion_signatures.
48+
template <class Self, class Env>
49+
using completions_t =
50+
tag_invoke_result_t<get_completion_signatures_t, __copy_cvref_t<Self, Sender>, Env>;
51+
52+
// test for tag_invocable instead of sender_to because the connect customization
53+
// point would convert the stdexec::just sender back into this nvexec::just sender,
54+
// causing recursion.
55+
template <__decays_to<sender> Self, receiver Receiver>
56+
requires receiver_of<Receiver, completions_t<Self, env_of_t<Receiver>>> &&
57+
tag_invocable<connect_t, __copy_cvref_t<Self, Sender>, Receiver>
58+
friend auto tag_invoke(connect_t, Self&& self, Receiver rcvr) //
59+
noexcept(nothrow_tag_invocable<connect_t, __copy_cvref_t<Self, Sender>, Receiver>)
60+
-> tag_invoke_result_t<connect_t, __copy_cvref_t<Self, Sender>, Receiver> {
61+
return tag_invoke(connect, ((Self&&) self).sndr_, (Receiver&&) rcvr);
62+
}
63+
64+
template <__decays_to<sender> Self, class Env>
65+
friend auto tag_invoke(get_completion_signatures_t, Self&& self, Env&& env) noexcept
66+
-> completions_t<Self, Env> {
67+
return {};
68+
}
69+
70+
template <same_as<sender> Self>
71+
friend const environment& tag_invoke(get_env_t, const Self& self) noexcept {
72+
return self.env_;
73+
}
74+
75+
Sender sndr_;
76+
environment env_;
77+
};
78+
} // namespace _wrap
79+
80+
template <class Env, class Sender>
81+
auto as_stream_sender(Sender sndr, const context_state_t&) -> Sender {
82+
return sndr;
83+
}
84+
85+
template <class Env, class Sender>
86+
requires _non_stream_sender<Sender, Env>
87+
auto as_stream_sender(Sender sndr, const context_state_t& context_state) //
88+
-> _wrap::sender<__id<Sender>> {
89+
return {std::move(sndr), context_state};
90+
}
91+
}
92+
93+
namespace stdexec::__detail {
94+
template <class SenderId>
95+
inline constexpr __mconst<
96+
nvexec::STDEXEC_STREAM_DETAIL_NS::_wrap::sender<__name_of<__t<SenderId>>>>
97+
__name_of_v<nvexec::STDEXEC_STREAM_DETAIL_NS::_wrap::sender<SenderId>>{};
98+
}

include/nvexec/stream_context.cuh

+88-16
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
#include "stream/when_all.cuh"
3737
#include "stream/reduce.cuh"
3838
#include "stream/ensure_started.cuh"
39+
#include "stream/wrap.cuh"
3940

4041
#include "stream/common.cuh"
4142
#include "detail/queue.cuh"
@@ -87,41 +88,108 @@ namespace nvexec {
8788
template <sender Sender>
8889
using ensure_started_th = __t<ensure_started_sender_t<__id<Sender>>>;
8990

90-
// needed for subsumption purposes
91-
template <class Sender, class Env>
92-
concept _non_stream_sender = //
93-
!derived_from<__decay_t<Sender>, stream_sender_base>;
94-
9591
struct stream_scheduler;
9692

97-
template <class = stream_scheduler>
93+
// template <class = stream_scheduler>
94+
// struct stream_domain;
95+
96+
// template <class Tag>
97+
// struct _just_t : Tag {
98+
// static __prop<get_domain_t, stream_domain<>> get_env(auto&&) noexcept {
99+
// return __mkprop
100+
// }
101+
// }
102+
103+
template <class /*= stream_scheduler*/>
98104
struct stream_domain : private __default_domain<context_state_t> {
99105
using __default_domain::__default_domain;
100-
using __default_domain::transform_sender;
106+
//using __default_domain::transform_sender;
101107

102108
// Lazy algorithm customizations require a recursive tree transformation
103109
template <sender_expr Sender, class Env>
104110
requires _non_stream_sender<Sender, Env> // no need to transform it a second time
105111
auto transform_sender(Sender&& sndr, const Env& env) const noexcept {
106-
return stdexec::apply_sender(
112+
//print<__name_of<Sender>>();
113+
auto result = stdexec::apply_sender(
107114
(Sender&&) sndr,
108115
[&]<class Tag, class Data, class... Children>(Tag, Data&& data, Children&&... children) {
109-
return make_sender_expr<Tag, stream_domain>(
110-
(Data&&) data, transform_sender((Children&&) children, env)...);
116+
return //as_stream_sender<Env>(
117+
make_sender_expr<Tag, stream_domain>(
118+
(Data&&) data,
119+
stdexec::transform_sender(*this, (Children&&) children, env)...); //,
120+
//base());
111121
});
122+
//print<__name_of<decltype(result)>>();
123+
return result;
112124
}
113125

114-
// reduce senders get a special transformation
115-
template <sender_expr_for<reduce_t> Sender, class Env>
126+
template <sender_expr_for<schedule_from_t> Sender, class Env>
116127
requires _non_stream_sender<Sender, Env> // no need to transform it a second time
117128
auto transform_sender(Sender&& sndr, const Env& env) const noexcept {
118129
return stdexec::apply_sender(
119130
(Sender&&) sndr,
120131
[&]<class Tag, class Data, class Child>(Tag, Data&& data, Child&& child) {
121-
auto [init, fun] = (Data&&) data;
122-
auto next = transform_sender((Child&&) child, env);
123-
return reduce_sender_t<decltype(next), decltype(init), decltype(fun)>(
124-
std::move(next), init, fun);
132+
auto sched = get_scheduler(env);
133+
auto next = stdexec::transform_sender(*this, (Child&&) child, env);
134+
return stdexec::__t<
135+
schedule_from_sender_t<stream_scheduler, stdexec::__id<decltype(next)>>>{
136+
sched.context_state_, std::move(next)};
137+
});
138+
}
139+
140+
// // reduce senders get a special transformation
141+
// template <sender_expr_for<reduce_t> Sender, class Env>
142+
// requires _non_stream_sender<Sender, Env> // no need to transform it a second time
143+
// auto transform_sender(Sender&& sndr, const Env& env) const noexcept {
144+
// return stdexec::apply_sender(
145+
// (Sender&&) sndr,
146+
// [&]<class Tag, class Data, class Child>(Tag, Data&& data, Child&& child) {
147+
// auto [init, fun] = (Data&&) data;
148+
// auto next = stdexec::transform_sender(*this, (Child&&) child, env);
149+
// return reduce_sender_t<decltype(next), decltype(init), decltype(fun)>(
150+
// std::move(next), init, fun);
151+
// });
152+
// }
153+
154+
// transform senders get a special transformation
155+
template <sender_expr_for<transfer_t> Sender, class Env>
156+
requires _non_stream_sender<Sender, Env> // no need to transform it a second time
157+
auto transform_sender(Sender&& sndr, const Env& env) const noexcept {
158+
return stdexec::apply_sender(
159+
(Sender&&) sndr, [&]<class Data, class Child>(__ignore, Data&& data, Child&& child) {
160+
auto from = get_scheduler(env);
161+
auto to = get_completion_scheduler<set_value_t>(data);
162+
auto next = stdexec::transform_sender(*this, (Child&&) child, env);
163+
auto transfer = __t<transfer_sender_t<decltype(next)>>(
164+
from.context_state_, std::move(next));
165+
return schedule_from(to, std::move(transfer));
166+
});
167+
}
168+
169+
// template <sender_expr_for<just_t, just_error_t, just_stopped_t> Sender, class Env>
170+
// requires _non_stream_sender<Sender, Env> // no need to transform it a second time
171+
// auto transform_sender(Sender&& sndr, const Env& env) const noexcept {
172+
// return stdexec::apply_sender(
173+
// (Sender&&) sndr, [&]<class Tag, class Data>(Tag, Data&& data) {
174+
// return make_sender_expr<Tag, stream_domain>(
175+
// (Data&&) data, get_completion_scheduler<Tag>(data));
176+
// });
177+
// }
178+
// template <sender_expr_for<just_t> Sender, class Env>
179+
// requires _non_stream_sender<Sender, Env> // no need to transform it a second time
180+
// auto transform_sender(Sender&& sndr, const Env&) const noexcept {
181+
// return just_sender<__decay_t<Sender>>{(Sender&&) sndr, base()};
182+
// }
183+
184+
template <sender_expr_for<bulk_t> Sender, class Env>
185+
requires _non_stream_sender<Sender, Env> // no need to transform it a second time
186+
auto transform_sender(Sender&& sndr, const Env& env) const noexcept {
187+
return stdexec::apply_sender(
188+
(Sender&&) sndr, [&]<class Data, class Child>(__ignore, Data&& data, Child&& child) {
189+
auto&& [shape, fun] = (Data&&) data;
190+
auto next = stdexec::transform_sender(*this, (Child&&) child, env);
191+
return bulk_sender_th<decltype(next), decltype(shape), decltype(fun)>{
192+
{}, std::move(next), shape, fun};
125193
});
126194
}
127195

@@ -338,6 +406,10 @@ namespace nvexec {
338406
return {base()};
339407
}
340408

409+
stream_scheduler context_state_t::make_stream_scheduler() const noexcept {
410+
return {*this};
411+
}
412+
341413
template <stream_completing_sender Sender>
342414
void tag_invoke(start_detached_t, Sender&& sndr) noexcept(false) {
343415
_submit::submit_t{}((Sender&&) sndr, _start_detached::detached_receiver_t{});

include/stdexec/__detail/__basic_sender.hpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -265,9 +265,9 @@ namespace stdexec {
265265
concept sender_expr = //
266266
__mvalid<__tag_of, _Sender>;
267267

268-
template <class _Sender, class _Tag>
268+
template <class _Sender, class... _Tags>
269269
concept sender_expr_for = //
270-
sender_expr<_Sender> && same_as<__tag_of<_Sender>, _Tag>;
270+
sender_expr<_Sender> && __one_of<__tag_of<_Sender>, _Tags...>;
271271

272272
// The __name_of utility defined below is used to pretty-print the type names of
273273
// senders in compiler diagnostics.

include/stdexec/execution.hpp

+4
Original file line numberDiff line numberDiff line change
@@ -5181,6 +5181,7 @@ namespace stdexec {
51815181

51825182
inline void run_loop::finish() {
51835183
std::unique_lock __lock{__mutex_};
5184+
//kernel<<<1,0,0>>>();
51845185
__stop_ = true;
51855186
__cv_.notify_all();
51865187
}
@@ -6922,6 +6923,7 @@ namespace stdexec {
69226923
template <same_as<set_value_t> _Tag, class... _As>
69236924
requires constructible_from<std::tuple<_Values...>, _As...>
69246925
friend void tag_invoke(_Tag, __t&& __rcvr, _As&&... __as) noexcept {
6926+
//kernel<<<1,0,0>>>(__as...);
69256927
try {
69266928
__rcvr.__state_->__data_.template emplace<1>((_As&&) __as...);
69276929
__rcvr.__loop_->finish();
@@ -6932,10 +6934,12 @@ namespace stdexec {
69326934

69336935
template <same_as<set_error_t> _Tag, class _Error>
69346936
friend void tag_invoke(_Tag, __t&& __rcvr, _Error __err) noexcept {
6937+
//kernel<<<1,0,0>>>(__err);
69356938
__rcvr.__set_error((_Error&&) __err);
69366939
}
69376940

69386941
friend void tag_invoke(set_stopped_t __d, __t&& __rcvr) noexcept {
6942+
//kernel<<<1,0,0>>>();
69396943
__rcvr.__state_->__data_.template emplace<3>(__d);
69406944
__rcvr.__loop_->finish();
69416945
}

0 commit comments

Comments
 (0)