Skip to content

Commit 75ee5bd

Browse files
committed
Resolve a number of thread-safety/scheduling issues
1 parent 9a5c4db commit 75ee5bd

File tree

6 files changed

+134
-45
lines changed

6 files changed

+134
-45
lines changed

ARCHITECTURE.md

+5-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,11 @@ That is, it doesn't get requeued on the thread pool for later execution.
1919

2020
The concurrent queue used to push work to worker threads is provided by [`moodycamel::ConcurrentQueue`](https://github.com/cameron314/concurrentqueue).
2121
Under the hood, the queue provides multiple-consumer multiple-producer usage, although in this case, only a single producer per queue
22-
exists.
22+
exists. The thread pool worker threads currently do *not* support work stealing, which is a slightly more complicated endeavor
23+
for job schedulers that support task affinity.
24+
25+
The granularity of your jobs shouldn't be too fine - maybe having jobs that are at least 100 us or more is a good idea, or you'll
26+
end up paying disproportionately for scheduling costs.
2327

2428
The Win32 event awaiter works by having a single IO thread which blocks in a single `WaitForMultipleObjects` call. One of the
2529
events it waits on is used to signal the available of more events to wait on. All the other events waited on are user awaited.

include/coop/detail/tracer.hpp

+14
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,19 @@
11
#pragma once
22

3+
#include <cstdint>
4+
#include <thread>
5+
6+
namespace coop
7+
{
8+
namespace detail
9+
{
10+
inline size_t thread_id() noexcept
11+
{
12+
return std::hash<std::thread::id>{}(std::this_thread::get_id());
13+
}
14+
} // namespace detail
15+
} // namespace coop
16+
317
#if defined(COOP_TRACE) && !defined(NDEBUG)
418
# include <cstdio>
519

include/coop/task.hpp

+82-20
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,7 @@ using experimental::suspend_never;
1515
#endif
1616
#include <cstdlib>
1717
#include <limits>
18-
#include <semaphore>
19-
#include <unordered_map>
18+
#include <mutex>
2019

2120
namespace coop
2221
{
@@ -56,6 +55,9 @@ struct promise_base_t
5655
// resume point, which immediately following the suspend point.
5756
std::coroutine_handle<> continuation = nullptr;
5857

58+
std::mutex mutex;
59+
bool flag = false;
60+
5961
// Do not suspend immediately on entry of a coroutine
6062
std::suspend_never initial_suspend() const noexcept
6163
{
@@ -80,21 +82,84 @@ struct final_awaiter_t
8082
{
8183
}
8284

83-
std::coroutine_handle<> await_suspend(std::coroutine_handle<P> coroutine) const noexcept
85+
std::coroutine_handle<>
86+
await_suspend(std::coroutine_handle<P> coroutine) const noexcept
8487
{
8588
// Check if this coroutine is being finalized from the
8689
// middle of a "continuation" coroutine and hop back there to
8790
// continue execution while *this* coroutine is suspended.
8891

89-
auto continuation = coroutine.promise().continuation;
90-
if (continuation)
92+
if constexpr (P::joinable_v)
93+
{
94+
// Joinable tasks are never awaited and so cannot have a
95+
// continuation by definition
96+
return std::noop_coroutine();
97+
}
98+
else
9199
{
92-
return continuation;
100+
COOP_LOG("Final await for coroutine %p on thread %zu\n",
101+
coroutine.address(),
102+
detail::thread_id());
103+
std::scoped_lock lock{coroutine.promise().mutex};
104+
if (coroutine.promise().flag)
105+
{
106+
// We're not the first to reach here
107+
auto continuation = coroutine.promise().continuation;
108+
if (continuation)
109+
{
110+
COOP_LOG("Resuming continuation %p on %p on thread %zu\n",
111+
continuation.address(),
112+
coroutine.address(),
113+
detail::thread_id());
114+
return continuation;
115+
}
116+
else
117+
{
118+
COOP_LOG(
119+
"Coroutine %p on thread %zu missing continuation\n",
120+
coroutine.address(),
121+
detail::thread_id());
122+
}
123+
}
124+
coroutine.promise().flag = true;
125+
return std::noop_coroutine();
93126
}
94-
return std::noop_coroutine();
95127
}
96128
};
97129

130+
namespace detail
131+
{
132+
// Helper function for awaiting on a task. The next resume point is
133+
// installed as a continuation of the task being awaited.
134+
template <typename P>
135+
std::coroutine_handle<>
136+
await_suspend(std::coroutine_handle<P> base, std::coroutine_handle<> next)
137+
{
138+
if constexpr (P::joinable_v)
139+
{
140+
// Joinable tasks are never awaited and so cannot have a
141+
// continuation by definition
142+
return std::noop_coroutine();
143+
}
144+
else
145+
{
146+
std::scoped_lock lock{base.promise().mutex};
147+
if (!base.promise().flag)
148+
{
149+
// We're the first to reach here
150+
base.promise().flag = true;
151+
COOP_LOG("Installing continuation %p for %p on thread %zu\n",
152+
next.address(),
153+
base.address(),
154+
detail::thread_id());
155+
base.promise().continuation = next;
156+
return std::noop_coroutine();
157+
}
158+
return base;
159+
}
160+
}
161+
} // namespace detail
162+
98163
template <bool Joinable>
99164
class task_base_t
100165
{
@@ -108,6 +173,8 @@ template <>
108173
class task_base_t<true>
109174
{
110175
public:
176+
constexpr static bool joinable_v = true;
177+
111178
void join()
112179
{
113180
join_sem_->acquire();
@@ -133,6 +200,8 @@ class task_t final : public task_base_t<Joinable>
133200
public:
134201
struct promise_type : public task_base_t<Joinable>::promise_type
135202
{
203+
constexpr static bool joinable_v = Joinable;
204+
136205
T data;
137206

138207
static void* operator new(size_t size)
@@ -190,7 +259,6 @@ class task_t final : public task_base_t<Joinable>
190259
task_t(std::coroutine_handle<promise_type> coroutine) noexcept
191260
: coroutine_{coroutine}
192261
{
193-
COOP_LOG("task %p born\n", coroutine_.address());
194262
}
195263
task_t(task_t const&) = delete;
196264
task_t& operator=(task_t const&) = delete;
@@ -238,7 +306,6 @@ class task_t final : public task_base_t<Joinable>
238306
{
239307
if (coroutine_)
240308
{
241-
COOP_LOG("task %p dying\n", coroutine_.address());
242309
coroutine_.destroy();
243310
}
244311
}
@@ -275,10 +342,9 @@ class task_t final : public task_base_t<Joinable>
275342

276343
// When suspending from a coroutine *within* a task's coroutine, save the
277344
// resume point (to be resumed when the inner coroutine finalizes)
278-
void
279-
await_suspend(std::coroutine_handle<> coroutine) noexcept
345+
std::coroutine_handle<> await_suspend(std::coroutine_handle<> coroutine) noexcept
280346
{
281-
coroutine_.promise().continuation = coroutine;
347+
return detail::await_suspend(coroutine_, coroutine);
282348
}
283349

284350
private:
@@ -293,6 +359,8 @@ class task_t<void, Joinable, C> : public task_base_t<Joinable>
293359
public:
294360
struct promise_type : public task_base_t<Joinable>::promise_type
295361
{
362+
constexpr static bool joinable_v = Joinable;
363+
296364
static void* operator new(size_t size)
297365
{
298366
return C::alloc(size);
@@ -334,7 +402,6 @@ class task_t<void, Joinable, C> : public task_base_t<Joinable>
334402
task_t(std::coroutine_handle<promise_type> coroutine) noexcept
335403
: coroutine_{coroutine}
336404
{
337-
COOP_LOG("task %p born\n", coroutine_.address());
338405
}
339406
task_t(task_t const&) = delete;
340407
task_t& operator=(task_t const&) = delete;
@@ -382,7 +449,6 @@ class task_t<void, Joinable, C> : public task_base_t<Joinable>
382449
{
383450
if (coroutine_)
384451
{
385-
COOP_LOG("task %p dying\n", coroutine_.address());
386452
coroutine_.destroy();
387453
}
388454
}
@@ -404,10 +470,9 @@ class task_t<void, Joinable, C> : public task_base_t<Joinable>
404470

405471
// When suspending from a coroutine *within* this task's coroutine, save the
406472
// resume point (to be resumed when the inner coroutine finalizes)
407-
void
408-
await_suspend(std::coroutine_handle<> coroutine) noexcept
473+
std::coroutine_handle<> await_suspend(std::coroutine_handle<> coroutine) noexcept
409474
{
410-
coroutine_.promise().continuation = coroutine;
475+
return detail::await_suspend(coroutine_, coroutine);
411476
}
412477

413478
private:
@@ -452,9 +517,6 @@ inline auto suspend(S& scheduler = S::instance(),
452517
}
453518
};
454519

455-
COOP_LOG("Suspending coroutine from thread %zu\n",
456-
std::hash<std::thread::id>{}(std::this_thread::get_id()));
457-
458520
return awaiter_t{scheduler, cpu_mask, priority, source_location};
459521
}
460522

src/scheduler.cpp

+10-1
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,16 @@ void scheduler_t::schedule(std::coroutine_handle<> coroutine,
181181
// discrepancy (Kronecker recurrence sequence)
182182
uint32_t index = static_cast<uint32_t>(update_++ * std::numbers::phi_v<float>)
183183
% std::popcount(cpu_affinity);
184-
queues_[index].enqueue(coroutine, priority, source_location);
184+
185+
// Iteratively unset bits to determine the nth set bit
186+
for (uint32_t i = 0; i != index; ++i)
187+
{
188+
cpu_affinity &= ~(1 << (std::countr_zero(cpu_affinity) + 1));
189+
}
190+
uint32_t queue = std::countr_zero(cpu_affinity);
191+
COOP_LOG("Work queue %i identified\n", queue);
192+
193+
queues_[queue].enqueue(coroutine, priority, source_location);
185194
}
186195

187196
void scheduler_t::schedule(std::coroutine_handle<> coroutine,

src/work_queue.cpp

+9-10
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ work_queue_t::work_queue_t(scheduler_t& scheduler, uint32_t id)
2626
thread_ = std::thread([this] {
2727
#if defined(_WIN32)
2828
SetThreadAffinityMask(
29-
thread_.native_handle(), static_cast<uint32_t>(1ull << id_));
29+
GetCurrentThread(), static_cast<uint32_t>(1ull << id_));
3030
#elif defined(__linux__)
3131
// TODO: Android implementation
3232
pthread_t thread = pthread_self();
@@ -64,12 +64,10 @@ work_queue_t::work_queue_t(scheduler_t& scheduler, uint32_t id)
6464
std::coroutine_handle<> coroutine;
6565
if (queues_[i].try_dequeue(coroutine))
6666
{
67-
COOP_LOG(
68-
"Dequeueing coroutine on CPU %i thread %zu %i\n",
69-
id_,
70-
std::hash<std::thread::id>{}(
71-
std::this_thread::get_id()),
72-
coroutine.done());
67+
COOP_LOG("Dequeueing coroutine %p on thread %zu (%i)\n",
68+
coroutine.address(),
69+
detail::thread_id(),
70+
id_);
7371
did_dequeue = true;
7472
coroutine.resume();
7573
break;
@@ -93,9 +91,10 @@ void work_queue_t::enqueue(std::coroutine_handle<> coroutine,
9391
uint32_t priority,
9492
source_location_t source_location)
9593
{
96-
priority = std::clamp<uint32_t>(priority, 0, COOP_PRIORITY_COUNT);
97-
COOP_LOG("Enqueueing coroutine on CPU %i (%s:%zu)\n",
98-
id_,
94+
priority = std::clamp<uint32_t>(priority, 0, COOP_PRIORITY_COUNT - 1);
95+
COOP_LOG("Enqueueing coroutine %p on thread %zu (%s:%zu)\n",
96+
coroutine.address(),
97+
detail::thread_id(),
9998
source_location.file,
10099
source_location.line);
101100
queues_[priority].enqueue(coroutine);

test/test.cpp

+14-13
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,13 @@
55
#include <coop/task.hpp>
66
#include <thread>
77

8-
coop::task_t<> suspend_test1()
9-
{
10-
COOP_SUSPEND();
11-
}
12-
13-
coop::task_t<void, true> suspend_test2()
8+
coop::task_t<void, true> suspend_time()
149
{
10+
// std::printf("%zu start thread\n", coop::detail::thread_id());
1511
auto t1 = std::chrono::system_clock::now();
1612
COOP_SUSPEND();
1713
auto t2 = std::chrono::system_clock::now();
14+
// std::printf("%zu end thread\n", coop::detail::thread_id());
1815
size_t us
1916
= std::chrono::duration_cast<std::chrono::microseconds>(t2 - t1).count();
2017
std::printf("Duration for suspend test: %zu us\n", us);
@@ -23,7 +20,9 @@ coop::task_t<void, true> suspend_test2()
2320
TEST_CASE("suspend overhead")
2421
{
2522
std::printf("Calling suspend_test2 coroutine\n");
26-
suspend_test2().join();
23+
// auto task = suspend_time();
24+
// task.join();
25+
suspend_time().join();
2726
std::printf("suspend_test2 joined\n");
2827
}
2928

@@ -46,25 +45,27 @@ TEST_CASE("test suspend")
4645
CHECK(id != next);
4746
}
4847

49-
coop::task_t<int> chain1()
48+
coop::task_t<int> chain1(int core)
5049
{
5150
std::printf("chain1 suspending\n");
52-
COOP_SUSPEND();
51+
COOP_SUSPEND4(1 << core);
5352
std::printf("chain1 resumed\n");
5453
co_return 1;
5554
}
5655

5756
coop::task_t<int> chain2()
5857
{
5958
std::printf("chain2\n");
60-
COOP_SUSPEND();
61-
co_return co_await chain1();
59+
COOP_SUSPEND4(1 << 3);
60+
auto t1 = chain1(5);
61+
auto t2 = chain1(6);
62+
co_return co_await t1 + co_await t2;
6263
}
6364

6465
coop::task_t<void, true> chain3(int& result)
6566
{
6667
std::printf("chain3 suspending\n");
67-
co_await coop::suspend(coop::scheduler_t::instance(), 0x2);
68+
COOP_SUSPEND4(1 << 4);
6869
std::printf("chain3 resumed\n");
6970
result = co_await chain2();
7071
}
@@ -76,7 +77,7 @@ TEST_CASE("chained continuation")
7677
std::printf("Joining chained continuation task\n");
7778
task.join();
7879
std::printf("Task chained continuation joined\n");
79-
CHECK(x == 1);
80+
CHECK(x == 2);
8081
}
8182

8283
coop::task_t<> in_flight1()

0 commit comments

Comments
 (0)