@@ -15,8 +15,7 @@ using experimental::suspend_never;
15
15
#endif
16
16
#include < cstdlib>
17
17
#include < limits>
18
- #include < semaphore>
19
- #include < unordered_map>
18
+ #include < mutex>
20
19
21
20
namespace coop
22
21
{
@@ -56,6 +55,9 @@ struct promise_base_t
56
55
// resume point, which immediately following the suspend point.
57
56
std::coroutine_handle<> continuation = nullptr ;
58
57
58
+ std::mutex mutex;
59
+ bool flag = false ;
60
+
59
61
// Do not suspend immediately on entry of a coroutine
60
62
std::suspend_never initial_suspend () const noexcept
61
63
{
@@ -80,21 +82,84 @@ struct final_awaiter_t
80
82
{
81
83
}
82
84
83
- std::coroutine_handle<> await_suspend (std::coroutine_handle<P> coroutine) const noexcept
85
+ std::coroutine_handle<>
86
+ await_suspend (std::coroutine_handle<P> coroutine) const noexcept
84
87
{
85
88
// Check if this coroutine is being finalized from the
86
89
// middle of a "continuation" coroutine and hop back there to
87
90
// continue execution while *this* coroutine is suspended.
88
91
89
- auto continuation = coroutine.promise ().continuation ;
90
- if (continuation)
92
+ if constexpr (P::joinable_v)
93
+ {
94
+ // Joinable tasks are never awaited and so cannot have a
95
+ // continuation by definition
96
+ return std::noop_coroutine ();
97
+ }
98
+ else
91
99
{
92
- return continuation;
100
+ COOP_LOG (" Final await for coroutine %p on thread %zu\n " ,
101
+ coroutine.address (),
102
+ detail::thread_id ());
103
+ std::scoped_lock lock{coroutine.promise ().mutex };
104
+ if (coroutine.promise ().flag )
105
+ {
106
+ // We're not the first to reach here
107
+ auto continuation = coroutine.promise ().continuation ;
108
+ if (continuation)
109
+ {
110
+ COOP_LOG (" Resuming continuation %p on %p on thread %zu\n " ,
111
+ continuation.address (),
112
+ coroutine.address (),
113
+ detail::thread_id ());
114
+ return continuation;
115
+ }
116
+ else
117
+ {
118
+ COOP_LOG (
119
+ " Coroutine %p on thread %zu missing continuation\n " ,
120
+ coroutine.address (),
121
+ detail::thread_id ());
122
+ }
123
+ }
124
+ coroutine.promise ().flag = true ;
125
+ return std::noop_coroutine ();
93
126
}
94
- return std::noop_coroutine ();
95
127
}
96
128
};
97
129
130
+ namespace detail
131
+ {
132
+ // Helper function for awaiting on a task. The next resume point is
133
+ // installed as a continuation of the task being awaited.
134
+ template <typename P>
135
+ std::coroutine_handle<>
136
+ await_suspend (std::coroutine_handle<P> base, std::coroutine_handle<> next)
137
+ {
138
+ if constexpr (P::joinable_v)
139
+ {
140
+ // Joinable tasks are never awaited and so cannot have a
141
+ // continuation by definition
142
+ return std::noop_coroutine ();
143
+ }
144
+ else
145
+ {
146
+ std::scoped_lock lock{base.promise ().mutex };
147
+ if (!base.promise ().flag )
148
+ {
149
+ // We're the first to reach here
150
+ base.promise ().flag = true ;
151
+ COOP_LOG (" Installing continuation %p for %p on thread %zu\n " ,
152
+ next.address (),
153
+ base.address (),
154
+ detail::thread_id ());
155
+ base.promise ().continuation = next;
156
+ return std::noop_coroutine ();
157
+ }
158
+ return base;
159
+ }
160
+ }
161
+ } // namespace detail
162
+
98
163
template <bool Joinable>
99
164
class task_base_t
100
165
{
@@ -108,6 +173,8 @@ template <>
108
173
class task_base_t <true >
109
174
{
110
175
public:
176
+ constexpr static bool joinable_v = true ;
177
+
111
178
void join ()
112
179
{
113
180
join_sem_->acquire ();
@@ -133,6 +200,8 @@ class task_t final : public task_base_t<Joinable>
133
200
public:
134
201
struct promise_type : public task_base_t <Joinable>::promise_type
135
202
{
203
+ constexpr static bool joinable_v = Joinable;
204
+
136
205
T data;
137
206
138
207
static void * operator new (size_t size)
@@ -190,7 +259,6 @@ class task_t final : public task_base_t<Joinable>
190
259
task_t (std::coroutine_handle<promise_type> coroutine) noexcept
191
260
: coroutine_{coroutine}
192
261
{
193
- COOP_LOG (" task %p born\n " , coroutine_.address ());
194
262
}
195
263
task_t (task_t const &) = delete ;
196
264
task_t & operator =(task_t const &) = delete ;
@@ -238,7 +306,6 @@ class task_t final : public task_base_t<Joinable>
238
306
{
239
307
if (coroutine_)
240
308
{
241
- COOP_LOG (" task %p dying\n " , coroutine_.address ());
242
309
coroutine_.destroy ();
243
310
}
244
311
}
@@ -275,10 +342,9 @@ class task_t final : public task_base_t<Joinable>
275
342
276
343
// When suspending from a coroutine *within* a task's coroutine, save the
277
344
// resume point (to be resumed when the inner coroutine finalizes)
278
- void
279
- await_suspend (std::coroutine_handle<> coroutine) noexcept
345
+ std::coroutine_handle<> await_suspend (std::coroutine_handle<> coroutine) noexcept
280
346
{
281
- coroutine_. promise (). continuation = coroutine;
347
+ return detail::await_suspend (coroutine_, coroutine) ;
282
348
}
283
349
284
350
private:
@@ -293,6 +359,8 @@ class task_t<void, Joinable, C> : public task_base_t<Joinable>
293
359
public:
294
360
struct promise_type : public task_base_t <Joinable>::promise_type
295
361
{
362
+ constexpr static bool joinable_v = Joinable;
363
+
296
364
static void * operator new (size_t size)
297
365
{
298
366
return C::alloc (size);
@@ -334,7 +402,6 @@ class task_t<void, Joinable, C> : public task_base_t<Joinable>
334
402
task_t (std::coroutine_handle<promise_type> coroutine) noexcept
335
403
: coroutine_{coroutine}
336
404
{
337
- COOP_LOG (" task %p born\n " , coroutine_.address ());
338
405
}
339
406
task_t (task_t const &) = delete ;
340
407
task_t & operator =(task_t const &) = delete ;
@@ -382,7 +449,6 @@ class task_t<void, Joinable, C> : public task_base_t<Joinable>
382
449
{
383
450
if (coroutine_)
384
451
{
385
- COOP_LOG (" task %p dying\n " , coroutine_.address ());
386
452
coroutine_.destroy ();
387
453
}
388
454
}
@@ -404,10 +470,9 @@ class task_t<void, Joinable, C> : public task_base_t<Joinable>
404
470
405
471
// When suspending from a coroutine *within* this task's coroutine, save the
406
472
// resume point (to be resumed when the inner coroutine finalizes)
407
- void
408
- await_suspend (std::coroutine_handle<> coroutine) noexcept
473
+ std::coroutine_handle<> await_suspend (std::coroutine_handle<> coroutine) noexcept
409
474
{
410
- coroutine_. promise (). continuation = coroutine;
475
+ return detail::await_suspend (coroutine_, coroutine) ;
411
476
}
412
477
413
478
private:
@@ -452,9 +517,6 @@ inline auto suspend(S& scheduler = S::instance(),
452
517
}
453
518
};
454
519
455
- COOP_LOG (" Suspending coroutine from thread %zu\n " ,
456
- std::hash<std::thread::id>{}(std::this_thread::get_id ()));
457
-
458
520
return awaiter_t {scheduler, cpu_mask, priority, source_location};
459
521
}
460
522
0 commit comments