Skip to content

Commit

Permalink
Add ability to get and set thread pool size from JIT-land
Browse files Browse the repository at this point in the history
  • Loading branch information
abadams committed Oct 28, 2024
1 parent 1653c16 commit edeae97
Show file tree
Hide file tree
Showing 7 changed files with 63 additions and 16 deletions.
28 changes: 28 additions & 0 deletions src/JITModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,24 @@ void JITModule::reuse_device_allocations(bool b) const {
}
}

int JITModule::get_num_threads() const {
std::map<std::string, Symbol>::const_iterator f =
exports().find("halide_get_num_threads");
if (f != exports().end()) {
return (reinterpret_bits<int (*)()>(f->second.address))();
}
return 1;
}

int JITModule::set_num_threads(int n) const {
std::map<std::string, Symbol>::const_iterator f =
exports().find("halide_set_num_threads");
if (f != exports().end()) {
return (reinterpret_bits<int (*)(int)>(f->second.address))(n);
}
return 1;
}

bool JITModule::compiled() const {
return jit_module->JIT != nullptr;
}
Expand Down Expand Up @@ -1075,6 +1093,16 @@ void JITSharedRuntime::reuse_device_allocations(bool b) {
shared_runtimes(MainShared).reuse_device_allocations(b);
}

int JITSharedRuntime::get_num_threads() {
std::lock_guard<std::mutex> lock(shared_runtimes_mutex);
return shared_runtimes(MainShared).get_num_threads();
}

int JITSharedRuntime::set_num_threads(int n) {
std::lock_guard<std::mutex> lock(shared_runtimes_mutex);
return shared_runtimes(MainShared).set_num_threads(n);
}

JITCache::JITCache(Target jit_target,
std::vector<Argument> arguments,
std::map<std::string, JITExtern> jit_externs,
Expand Down
18 changes: 18 additions & 0 deletions src/JITModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,12 @@ struct JITModule {
/** See JITSharedRuntime::reuse_device_allocations */
void reuse_device_allocations(bool) const;

/** See JITSharedRuntime::get_num_threads */
int get_num_threads() const;

/** See JITSharedRuntime::set_num_threads */
int set_num_threads(int) const;

/** Return true if compile_module has been called on this module. */
bool compiled() const;
};
Expand Down Expand Up @@ -279,6 +285,18 @@ class JITSharedRuntime {
static void reuse_device_allocations(bool);

static void release_all();

/** Get the number of threads in the Halide thread pool. Includes the
* calling thread. Meaningless if a custom do par for has been set. */
static int get_num_threads();

/** Set the number of threads to use in the Halide thread pool, inclusive of
* the calling thread. Pass zero to use a reasonable default (typically the
* number of CPUs online). Calling this is meaningless if custom_do_par_for
* has been set. Halide may launch more threads than this if necessary to
* avoid deadlock when using the async scheduling directive. Returns the old
* number. */
static int set_num_threads(int);
};

void *get_symbol_address(const char *s);
Expand Down
5 changes: 4 additions & 1 deletion src/runtime/HalideRuntime.h
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,7 @@ extern struct halide_thread *halide_spawn_thread(void (*f)(void *), void *closur
/** Join a thread. */
extern void halide_join_thread(struct halide_thread *);

/** Set the number of threads used by Halide's thread pool. Returns
/** Get or set the number of threads used by Halide's thread pool. Set returns
* the old number.
*
* n < 0 : error condition
Expand All @@ -402,7 +402,10 @@ extern void halide_join_thread(struct halide_thread *);
* of halide_do_par_for(); custom implementations may completely ignore values
* passed to halide_set_num_threads().)
*/
// @{
extern int halide_get_num_threads();
extern int halide_set_num_threads(int n);
// @}

/** Halide calls these functions to allocate and free memory. To
* replace in AOT code, use the halide_set_custom_malloc and
Expand Down
1 change: 1 addition & 0 deletions src/runtime/runtime_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ extern "C" __attribute__((used)) void *halide_runtime_api_functions[] = {
(void *)&halide_get_cpu_features,
(void *)&halide_get_gpu_device,
(void *)&halide_get_library_symbol,
(void *)&halide_get_num_threads,
(void *)&halide_get_symbol,
(void *)&halide_get_trace_file,
(void *)&halide_hexagon_detach_device_handle,
Expand Down
7 changes: 7 additions & 0 deletions src/runtime/thread_pool_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -693,6 +693,13 @@ WEAK int halide_set_num_threads(int n) {
return old;
}

WEAK int halide_get_num_threads() {
halide_mutex_lock(&work_queue.mutex);
int n = work_queue.desired_threads_working;
halide_mutex_unlock(&work_queue.mutex);
return n;
}

WEAK void halide_shutdown_thread_pool() {
if (work_queue.initialized) {
// Wake everyone up and tell them the party's over and it's time
Expand Down
13 changes: 4 additions & 9 deletions test/correctness/atomics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1196,15 +1196,10 @@ int main(int argc, char **argv) {
}

Target target = get_jit_target_from_environment();
// Most of the schedules used in this test are terrible for large
// thread count machines, due to massive amounts of
// contention. We'll just set the thread count to 4. Unfortunately
// there's no JIT api for this yet.
#ifdef _WIN32
_putenv_s("HL_NUM_THREADS", "4");
#else
setenv("HL_NUM_THREADS", "4", 1);
#endif
// Most of the schedules used in this test are terrible for large
// thread count machines, due to massive amounts of
// contention. We'll just set the thread count to 4.
Halide::Internal::JITSharedModule::set_num_threads(4);
test_all<uint8_t>(Backend::CPU);
test_all<uint8_t>(Backend::CPUVectorize);
test_all<int8_t>(Backend::CPU);
Expand Down
7 changes: 1 addition & 6 deletions test/performance/inner_loop_parallel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,7 @@ int main(int argc, char **argv) {
for (int t = 2; t <= 64; t *= 2) {
std::ostringstream ss;
ss << "HL_NUM_THREADS=" << t;
std::string str = ss.str();
char buf[32] = {0};
memcpy(buf, str.c_str(), str.size());
putenv(buf);
p.invalidate_cache();
Halide::Internal::JITSharedRuntime::release_all();
Halide::Internal::JITSharedRuntime::set_num_threads(t);

p.compile_jit();
// Start the thread pool without giving any hints as to the
Expand Down

0 comments on commit edeae97

Please sign in to comment.