Skip to content

Commit edeae97

Browse files
committed
Add ability to get and set thread pool size from JIT-land
1 parent 1653c16 commit edeae97

File tree

7 files changed

+63
-16
lines changed

7 files changed

+63
-16
lines changed

src/JITModule.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -529,6 +529,24 @@ void JITModule::reuse_device_allocations(bool b) const {
529529
}
530530
}
531531

532+
// Looks up the "halide_get_num_threads" symbol in this module's exports and,
// if it is present, calls it and returns its result. Returns 1 when the
// module does not export that symbol.
int JITModule::get_num_threads() const {
533+
std::map<std::string, Symbol>::const_iterator f =
534+
exports().find("halide_get_num_threads");
535+
if (f != exports().end()) {
536+
// Treat the exported address as an int() function and invoke it.
return (reinterpret_bits<int (*)()>(f->second.address))();
537+
}
538+
// Symbol not exported by this module: report a single thread.
return 1;
539+
}
540+
541+
// Looks up the "halide_set_num_threads" symbol in this module's exports and,
// if it is present, calls it with n and returns whatever it returns. Returns
// 1 (and does nothing else) when the module does not export that symbol.
int JITModule::set_num_threads(int n) const {
542+
std::map<std::string, Symbol>::const_iterator f =
543+
exports().find("halide_set_num_threads");
544+
if (f != exports().end()) {
545+
// Treat the exported address as an int(int) function and invoke it.
return (reinterpret_bits<int (*)(int)>(f->second.address))(n);
546+
}
547+
// Symbol not exported by this module: n is ignored.
return 1;
548+
}
549+
532550
// Returns true when jit_module->JIT is non-null.
bool JITModule::compiled() const {
533551
return jit_module->JIT != nullptr;
534552
}
@@ -1075,6 +1093,16 @@ void JITSharedRuntime::reuse_device_allocations(bool b) {
10751093
shared_runtimes(MainShared).reuse_device_allocations(b);
10761094
}
10771095

1096+
// Holds shared_runtimes_mutex for the duration of the call and forwards to
// get_num_threads() on the module returned by shared_runtimes(MainShared).
int JITSharedRuntime::get_num_threads() {
1097+
std::lock_guard<std::mutex> lock(shared_runtimes_mutex);
1098+
return shared_runtimes(MainShared).get_num_threads();
1099+
}
1100+
1101+
// Holds shared_runtimes_mutex for the duration of the call and forwards n to
// set_num_threads() on the module returned by shared_runtimes(MainShared),
// returning that call's result.
int JITSharedRuntime::set_num_threads(int n) {
1102+
std::lock_guard<std::mutex> lock(shared_runtimes_mutex);
1103+
return shared_runtimes(MainShared).set_num_threads(n);
1104+
}
1105+
10781106
JITCache::JITCache(Target jit_target,
10791107
std::vector<Argument> arguments,
10801108
std::map<std::string, JITExtern> jit_externs,

src/JITModule.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,12 @@ struct JITModule {
247247
/** See JITSharedRuntime::reuse_device_allocations */
248248
void reuse_device_allocations(bool) const;
249249

250+
/** See JITSharedRuntime::get_num_threads */
251+
int get_num_threads() const;
252+
253+
/** See JITSharedRuntime::set_num_threads */
254+
int set_num_threads(int) const;
255+
250256
/** Return true if compile_module has been called on this module. */
251257
bool compiled() const;
252258
};
@@ -279,6 +285,18 @@ class JITSharedRuntime {
279285
static void reuse_device_allocations(bool);
280286

281287
static void release_all();
288+
289+
/** Get the number of threads in the Halide thread pool. Includes the
290+
* calling thread. Meaningless if custom_do_par_for has been set. */
291+
static int get_num_threads();
292+
293+
/** Set the number of threads to use in the Halide thread pool, inclusive of
294+
* the calling thread. Pass zero to use a reasonable default (typically the
295+
* number of CPUs online). Calling this is meaningless if custom_do_par_for
296+
* has been set. Halide may launch more threads than this if necessary to
297+
* avoid deadlock when using the async scheduling directive. Returns the old
298+
* number. */
299+
static int set_num_threads(int);
282300
};
283301

284302
void *get_symbol_address(const char *s);

src/runtime/HalideRuntime.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -390,7 +390,7 @@ extern struct halide_thread *halide_spawn_thread(void (*f)(void *), void *closur
390390
/** Join a thread. */
391391
extern void halide_join_thread(struct halide_thread *);
392392

393-
/** Set the number of threads used by Halide's thread pool. Returns
393+
/** Get or set the number of threads used by Halide's thread pool. Set returns
394394
* the old number.
395395
*
396396
* n < 0 : error condition
@@ -402,7 +402,10 @@ extern void halide_join_thread(struct halide_thread *);
402402
* of halide_do_par_for(); custom implementations may completely ignore values
403403
* passed to halide_set_num_threads().)
404404
*/
405+
// @{
406+
extern int halide_get_num_threads();
405407
extern int halide_set_num_threads(int n);
408+
// @}
406409

407410
/** Halide calls these functions to allocate and free memory. To
408411
* replace in AOT code, use the halide_set_custom_malloc and

src/runtime/runtime_api.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ extern "C" __attribute__((used)) void *halide_runtime_api_functions[] = {
9696
(void *)&halide_get_cpu_features,
9797
(void *)&halide_get_gpu_device,
9898
(void *)&halide_get_library_symbol,
99+
(void *)&halide_get_num_threads,
99100
(void *)&halide_get_symbol,
100101
(void *)&halide_get_trace_file,
101102
(void *)&halide_hexagon_detach_device_handle,

src/runtime/thread_pool_common.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -693,6 +693,13 @@ WEAK int halide_set_num_threads(int n) {
693693
return old;
694694
}
695695

696+
// Returns the current value of work_queue.desired_threads_working. The value
// is read while holding work_queue.mutex.
WEAK int halide_get_num_threads() {
697+
halide_mutex_lock(&work_queue.mutex);
698+
int n = work_queue.desired_threads_working;
699+
halide_mutex_unlock(&work_queue.mutex);
700+
return n;
701+
}
702+
696703
WEAK void halide_shutdown_thread_pool() {
697704
if (work_queue.initialized) {
698705
// Wake everyone up and tell them the party's over and it's time

test/correctness/atomics.cpp

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1196,15 +1196,10 @@ int main(int argc, char **argv) {
11961196
}
11971197

11981198
Target target = get_jit_target_from_environment();
1199-
// Most of the schedules used in this test are terrible for large
1200-
// thread count machines, due to massive amounts of
1201-
// contention. We'll just set the thread count to 4. Unfortunately
1202-
// there's no JIT api for this yet.
1203-
#ifdef _WIN32
1204-
_putenv_s("HL_NUM_THREADS", "4");
1205-
#else
1206-
setenv("HL_NUM_THREADS", "4", 1);
1207-
#endif
1199+
// Most of the schedules used in this test are terrible for large
1200+
// thread count machines, due to massive amounts of
1201+
// contention. We'll just set the thread count to 4.
1202+
Halide::Internal::JITSharedRuntime::set_num_threads(4);
12081203
test_all<uint8_t>(Backend::CPU);
12091204
test_all<uint8_t>(Backend::CPUVectorize);
12101205
test_all<int8_t>(Backend::CPU);

test/performance/inner_loop_parallel.cpp

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,7 @@ int main(int argc, char **argv) {
2525
for (int t = 2; t <= 64; t *= 2) {
2626
std::ostringstream ss;
2727
ss << "HL_NUM_THREADS=" << t;
28-
std::string str = ss.str();
29-
char buf[32] = {0};
30-
memcpy(buf, str.c_str(), str.size());
31-
putenv(buf);
32-
p.invalidate_cache();
33-
Halide::Internal::JITSharedRuntime::release_all();
28+
Halide::Internal::JITSharedRuntime::set_num_threads(t);
3429

3530
p.compile_jit();
3631
// Start the thread pool without giving any hints as to the

0 commit comments

Comments
 (0)