From 4540a8d2563bd8573d402285a14f0695ba3b2853 Mon Sep 17 00:00:00 2001 From: Fabian Knorr Date: Mon, 25 Nov 2024 17:45:27 +0100 Subject: [PATCH] Move scheduler / runtime testspy definitions to separate files --- CMakeLists.txt | 9 +- ci/find-unformatted-files.sh | 2 +- include/runtime.h | 100 ++++-------- include/runtime_impl.h | 128 --------------- include/scheduler.h | 18 +-- src/runtime.cc | 261 +++++++++++++++++++++++------- src/scheduler.cc | 24 +-- src/testspy/runtime_testspy.h | 55 +++++++ src/testspy/runtime_testspy.inl | 84 ++++++++++ src/testspy/scheduler_testspy.h | 46 ++++++ src/testspy/scheduler_testspy.inl | 27 ++++ test/CMakeLists.txt | 1 + test/dag_benchmarks.cc | 11 +- test/runtime_tests.cc | 9 +- test/test_main.cc | 4 +- test/test_utils.cc | 6 +- test/test_utils.h | 81 +--------- 17 files changed, 486 insertions(+), 380 deletions(-) delete mode 100644 include/runtime_impl.h create mode 100644 src/testspy/runtime_testspy.h create mode 100644 src/testspy/runtime_testspy.inl create mode 100644 src/testspy/scheduler_testspy.h create mode 100644 src/testspy/scheduler_testspy.inl diff --git a/CMakeLists.txt b/CMakeLists.txt index 04376d98e..38b62180e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -207,7 +207,10 @@ set(CELERITY_FEATURE_LOCAL_ACCESSOR ON) set(CELERITY_FEATURE_UNNAMED_KERNELS ON) # Add header files to library so they show up in IDEs -file(GLOB_RECURSE ALL_HEADERS "${CMAKE_CURRENT_SOURCE_DIR}/include/*.h") +file(GLOB_RECURSE ALL_INCLUDES + "${CMAKE_CURRENT_SOURCE_DIR}/include/*.h" + "${CMAKE_CURRENT_SOURCE_DIR}/src/*.h" + "${CMAKE_CURRENT_SOURCE_DIR}/src/*.inl") if(CMAKE_GENERATOR STREQUAL "Ninja") # Force colored warnings in Ninja's output, if the compiler has -fdiagnostics-color support. 
@@ -304,14 +307,14 @@ elseif(CELERITY_SYCL_IMPL STREQUAL "SimSYCL") endif() configure_file(include/version.h.in include/version.h @ONLY) -list(APPEND ALL_HEADERS "${CMAKE_CURRENT_BINARY_DIR}/include/version.h") +list(APPEND ALL_INCLUDES "${CMAKE_CURRENT_BINARY_DIR}/include/version.h") list(APPEND PUBLIC_HEADERS "${CMAKE_CURRENT_BINARY_DIR}/include/version.h") add_library( celerity_runtime STATIC ${SOURCES} - ${ALL_HEADERS} + ${ALL_INCLUDES} ) set_property(TARGET celerity_runtime PROPERTY CXX_STANDARD "${CELERITY_CXX_STANDARD}") diff --git a/ci/find-unformatted-files.sh b/ci/find-unformatted-files.sh index c9ed957b5..1450be530 100755 --- a/ci/find-unformatted-files.sh +++ b/ci/find-unformatted-files.sh @@ -11,7 +11,7 @@ if [[ ! -x "$(which clang-format)" ]]; then exit 1 fi -SOURCES=$(find examples include src test \( -name "*.h" -o -name "*.cc" \) ! -name "stb*") +SOURCES=$(find examples include src test \( -name "*.h" -o -name "*.cc" -o -name "*.inl" \) ! -name "stb*") for s in $SOURCES; do # Since clang-format does not provide an option to check whether formatting is required, diff --git a/include/runtime.h b/include/runtime.h index a2f5aa88d..cffdbfbe3 100644 --- a/include/runtime.h +++ b/include/runtime.h @@ -24,7 +24,7 @@ namespace detail { */ static void init(int* argc, char** argv[], const devices_or_selector& user_devices_or_selector = auto_select_devices{}); - static bool has_instance() { return s_instance != nullptr; } + static bool has_instance() { return s_instance.m_impl != nullptr; } static void shutdown(); @@ -34,100 +34,58 @@ namespace detail { runtime(runtime&&) = delete; runtime& operator=(const runtime&) = delete; runtime& operator=(runtime&&) = delete; + ~runtime() = default; - virtual ~runtime() = default; + task_id submit(raw_command_group&& cg); - virtual task_id submit(raw_command_group&& cg) = 0; + task_id fence(buffer_access access, std::unique_ptr fence_promise); - virtual task_id fence(buffer_access access, std::unique_ptr fence_promise) 
= 0; + task_id fence(host_object_effect effect, std::unique_ptr fence_promise); - virtual task_id fence(host_object_effect effect, std::unique_ptr fence_promise) = 0; + task_id sync(detail::epoch_action action); - virtual task_id sync(detail::epoch_action action) = 0; + void create_queue(); - virtual void create_queue() = 0; + void destroy_queue(); - virtual void destroy_queue() = 0; + allocation_id create_user_allocation(void* ptr); - virtual allocation_id create_user_allocation(void* ptr) = 0; + buffer_id create_buffer(const range<3>& range, size_t elem_size, size_t elem_align, allocation_id user_aid); - virtual buffer_id create_buffer(const range<3>& range, size_t elem_size, size_t elem_align, allocation_id user_aid) = 0; + void set_buffer_debug_name(buffer_id bid, const std::string& debug_name); - virtual void set_buffer_debug_name(buffer_id bid, const std::string& debug_name) = 0; + void destroy_buffer(buffer_id bid); - virtual void destroy_buffer(buffer_id bid) = 0; + host_object_id create_host_object(std::unique_ptr instance /* optional */); - virtual host_object_id create_host_object(std::unique_ptr instance /* optional */) = 0; + void destroy_host_object(host_object_id hoid); - virtual void destroy_host_object(host_object_id hoid) = 0; + reduction_id create_reduction(std::unique_ptr reducer); - virtual reduction_id create_reduction(std::unique_ptr reducer) = 0; + bool is_dry_run() const; - virtual bool is_dry_run() const = 0; + void set_scheduler_lookahead(experimental::lookahead lookahead); - virtual void set_scheduler_lookahead(experimental::lookahead lookahead) = 0; + void flush_scheduler(); - virtual void flush_scheduler() = 0; + private: + class impl; - protected: - inline static bool s_mpi_initialized = false; - inline static bool s_mpi_finalized = false; + static bool s_mpi_initialized; + static bool s_mpi_finalized; + + static bool s_test_mode; + static bool s_test_active; + static bool s_test_runtime_was_instantiated; static void 
mpi_initialize_once(int* argc, char*** argv); static void mpi_finalize_once(); - static std::unique_ptr s_instance; - - runtime() = default; + static runtime s_instance; - // ------------------------------------------ TESTING UTILS ------------------------------------------ - // We have to jump through some hoops to be able to re-initialize the runtime for unit testing. - // MPI does not like being initialized more than once per process, so we have to skip that part for - // re-initialization. - // --------------------------------------------------------------------------------------------------- + std::unique_ptr m_impl; - public: - // Switches to test mode, where MPI will be initialized through test_case_enter() instead of runtime::runtime(). Called on Catch2 startup. - static void test_mode_enter() { - assert(!s_mpi_initialized); - s_test_mode = true; - } - - // Finalizes MPI if it was ever initialized in test mode. Called on Catch2 shutdown. - static void test_mode_exit() { - assert(s_test_mode && !s_test_active && !s_mpi_finalized); - if(s_mpi_initialized) mpi_finalize_once(); - } - - // Initializes MPI for tests, if it was not initialized before - static void test_require_mpi() { - assert(s_test_mode && !s_test_active); - if(!s_mpi_initialized) mpi_initialize_once(nullptr, nullptr); - } - - // Allows the runtime to be transitively instantiated in tests. Called from runtime_fixture. - static void test_case_enter() { - assert(s_test_mode && !s_test_active && s_mpi_initialized && s_instance == nullptr); - s_test_active = true; - s_test_runtime_was_instantiated = false; - } - - static bool test_runtime_was_instantiated() { - assert(s_test_mode); - return s_test_runtime_was_instantiated; - } - - // Deletes the runtime instance, which happens only in tests. Called from runtime_fixture. 
- static void test_case_exit() { - assert(s_test_mode && s_test_active); - s_instance.reset(); // for when the test case explicitly initialized the runtime but did not successfully construct a queue / buffer / ... - s_test_active = false; - } - - protected: - inline static bool s_test_mode = false; - inline static bool s_test_active = false; - inline static bool s_test_runtime_was_instantiated = false; + runtime() = default; }; /// Returns the combined command graph of all nodes on node 0, an empty string on other nodes diff --git a/include/runtime_impl.h b/include/runtime_impl.h deleted file mode 100644 index e76a2bdbb..000000000 --- a/include/runtime_impl.h +++ /dev/null @@ -1,128 +0,0 @@ -#pragma once - -#include "runtime.h" - -#include "affinity.h" -#include "cgf.h" -#include "device.h" -#include "executor.h" -#include "host_object.h" -#include "instruction_graph_generator.h" -#include "reduction.h" -#include "scheduler.h" -#include "task.h" -#include "task_manager.h" -#include "types.h" - -#include -#include - - -namespace celerity::detail { - -class config; - -class runtime_impl final : public runtime, private task_manager::delegate, private scheduler::delegate, private executor::delegate { - public: - runtime_impl(int* argc, char** argv[], const devices_or_selector& user_devices_or_selector); - - runtime_impl(const runtime_impl&) = delete; - runtime_impl(runtime_impl&&) = delete; - runtime_impl& operator=(const runtime_impl&) = delete; - runtime_impl& operator=(runtime_impl&&) = delete; - - ~runtime_impl() override; - - task_id submit(raw_command_group&& cg) override; - - task_id fence(buffer_access access, std::unique_ptr fence_promise) override; - - task_id fence(host_object_effect effect, std::unique_ptr fence_promise) override; - - task_id sync(detail::epoch_action action) override; - - void create_queue() override; - - void destroy_queue() override; - - allocation_id create_user_allocation(void* ptr) override; - - buffer_id create_buffer(const range<3>& 
range, size_t elem_size, size_t elem_align, allocation_id user_aid) override; - - void set_buffer_debug_name(buffer_id bid, const std::string& debug_name) override; - - void destroy_buffer(buffer_id bid) override; - - host_object_id create_host_object(std::unique_ptr instance /* optional */) override; - - void destroy_host_object(host_object_id hoid) override; - - reduction_id create_reduction(std::unique_ptr reducer) override; - - bool is_dry_run() const override; - - void set_scheduler_lookahead(experimental::lookahead lookahead) override; - - void flush_scheduler() override; - - private: - friend struct runtime_testspy; - - // `runtime` is not thread safe except for its delegate implementations, so we store the id of the thread where it was instantiated (the application - // thread) in order to throw if the user attempts to issue a runtime operation from any other thread. One case where this may happen unintentionally - // is capturing a buffer into a host-task by value, where this capture is the last reference to the buffer: The runtime would attempt to destroy itself - // from a thread that it also needs to await, which would at least cause a deadlock. This variable is immutable, so reading it from a different thread - // for the purpose of the check is safe. 
- std::thread::id m_application_thread; - - std::unique_ptr m_cfg; - size_t m_num_nodes = 0; - node_id m_local_nid = 0; - size_t m_num_local_devices = 0; - - // track all instances of celerity::queue, celerity::buffer and celerity::host_object to sanity-check runtime destruction - size_t m_num_live_queues = 0; - std::unordered_set m_live_buffers; - std::unordered_set m_live_host_objects; - - buffer_id m_next_buffer_id = 0; - raw_allocation_id m_next_user_allocation_id = 1; - host_object_id m_next_host_object_id = 0; - reduction_id m_next_reduction_id = no_reduction_id + 1; - - task_graph m_tdag; - std::unique_ptr m_task_mngr; - std::unique_ptr m_schdlr; - std::unique_ptr m_exec; - - std::optional m_latest_horizon_reached; // only accessed by executor thread - std::atomic m_latest_epoch_reached; // task_id, but cast to size_t to work with std::atomic - task_id m_last_epoch_pruned_before = 0; - - std::unique_ptr m_task_recorder; // accessed by task manager (application thread) - std::unique_ptr m_command_recorder; // accessed only by scheduler thread (until shutdown) - std::unique_ptr m_instruction_recorder; // accessed only by scheduler thread (until shutdown) - - std::unique_ptr m_thread_pinner; // thread safe, manages lifetime of thread pinning machinery - - /// Panic when not called from m_application_thread (see that variable for more info on the matter). Since there are thread-safe and non thread-safe - /// member functions, we call this check at the beginning of all the non-safe ones. - void require_call_from_application_thread() const; - - void maybe_prune_task_graph(); - - // task_manager::delegate - void task_created(const task* tsk) override; - - // scheduler::delegate - void flush(std::vector instructions, std::vector pilot) override; - - // executor::delegate - void horizon_reached(task_id horizon_tid) override; - void epoch_reached(task_id epoch_tid) override; - - /// True when no buffers, host objects or queues are live that keep the runtime alive. 
- bool is_unreferenced() const; -}; - -} // namespace celerity::detail diff --git a/include/scheduler.h b/include/scheduler.h index 107a20f4d..48d293b3a 100644 --- a/include/scheduler.h +++ b/include/scheduler.h @@ -1,27 +1,17 @@ #pragma once -#include "command_graph.h" #include "command_graph_generator.h" #include "instruction_graph_generator.h" #include "ranges.h" #include "types.h" #include -#include #include #include namespace celerity::detail::scheduler_detail { -/// executed inside scheduler thread, making it safe to access scheduler members -struct test_state { - const command_graph* cdag = nullptr; - const instruction_graph* idag = nullptr; - experimental::lookahead lookahead = experimental::lookahead::automatic; -}; -using test_inspector = std::function; - struct scheduler_impl; } // namespace celerity::detail::scheduler_detail @@ -77,15 +67,9 @@ class scheduler { void flush_commands(); private: - struct test_threadless_tag {}; + scheduler() = default; // used by scheduler_testspy std::unique_ptr m_impl; - - // used in scheduler_testspy - scheduler(test_threadless_tag, size_t num_nodes, node_id local_node_id, const system_info& system_info, scheduler::delegate* delegate, - command_recorder* crec, instruction_recorder* irec, const policy_set& policy = {}); - void test_scheduling_loop(); - void test_inspect(scheduler_detail::test_inspector inspector); }; } // namespace celerity::detail diff --git a/src/runtime.cc b/src/runtime.cc index 43378c615..6f429156d 100644 --- a/src/runtime.cc +++ b/src/runtime.cc @@ -1,4 +1,4 @@ -#include "runtime_impl.h" +#include "runtime.h" #include "affinity.h" #include "backend/sycl_backend.h" @@ -21,6 +21,7 @@ #include "system_info.h" #include "task.h" #include "task_manager.h" +#include "testspy/runtime_testspy.h" #include "tracy.h" #include "types.h" #include "utils.h" @@ -84,40 +85,108 @@ namespace detail { std::promise m_promise; }; - std::unique_ptr runtime::s_instance = nullptr; + class runtime::impl final : public 
runtime, private task_manager::delegate, private scheduler::delegate, private executor::delegate { + public: + impl(int* argc, char** argv[], const devices_or_selector& user_devices_or_selector); - void runtime::mpi_initialize_once(int* argc, char*** argv) { -#if CELERITY_ENABLE_MPI - CELERITY_DETAIL_TRACY_ZONE_SCOPED_V("mpi::init", LightSkyBlue, "MPI_Init"); - assert(!s_mpi_initialized); - int provided; - MPI_Init_thread(argc, argv, MPI_THREAD_MULTIPLE, &provided); - assert(provided == MPI_THREAD_MULTIPLE); -#endif // CELERITY_ENABLE_MPI - s_mpi_initialized = true; - } + impl(const runtime::impl&) = delete; + impl(runtime::impl&&) = delete; + impl& operator=(const runtime::impl&) = delete; + impl& operator=(runtime::impl&&) = delete; - void runtime::mpi_finalize_once() { -#if CELERITY_ENABLE_MPI - CELERITY_DETAIL_TRACY_ZONE_SCOPED_V("mpi::finalize", LightSkyBlue, "MPI_Finalize"); - assert(s_mpi_initialized && !s_mpi_finalized && (!s_test_mode || !s_instance)); - MPI_Finalize(); -#endif // CELERITY_ENABLE_MPI - s_mpi_finalized = true; - } + ~impl(); - void runtime::init(int* argc, char** argv[], const devices_or_selector& user_devices_or_selector) { - assert(!s_instance); - s_instance = std::make_unique(argc, argv, user_devices_or_selector); - if(!s_test_mode) { atexit(shutdown); } - } + task_id submit(raw_command_group&& cg); - runtime& runtime::get_instance() { - if(s_instance == nullptr) { throw std::runtime_error("Runtime has not been initialized"); } - return *s_instance; - } + task_id fence(buffer_access access, std::unique_ptr fence_promise); + + task_id fence(host_object_effect effect, std::unique_ptr fence_promise); + + task_id sync(detail::epoch_action action); + + void create_queue(); + + void destroy_queue(); + + allocation_id create_user_allocation(void* ptr); + + buffer_id create_buffer(const range<3>& range, size_t elem_size, size_t elem_align, allocation_id user_aid); + + void set_buffer_debug_name(buffer_id bid, const std::string& debug_name); + + 
void destroy_buffer(buffer_id bid); + + host_object_id create_host_object(std::unique_ptr instance /* optional */); + + void destroy_host_object(host_object_id hoid); + + reduction_id create_reduction(std::unique_ptr reducer); + + bool is_dry_run() const; + + void set_scheduler_lookahead(experimental::lookahead lookahead); + + void flush_scheduler(); + + private: + friend struct runtime_testspy; + + // `runtime` is not thread safe except for its delegate implementations, so we store the id of the thread where it was instantiated (the application + // thread) in order to throw if the user attempts to issue a runtime operation from any other thread. One case where this may happen unintentionally + // is capturing a buffer into a host-task by value, where this capture is the last reference to the buffer: The runtime would attempt to destroy itself + // from a thread that it also needs to await, which would at least cause a deadlock. This variable is immutable, so reading it from a different thread + // for the purpose of the check is safe. 
+ std::thread::id m_application_thread; + + std::unique_ptr m_cfg; + size_t m_num_nodes = 0; + node_id m_local_nid = 0; + size_t m_num_local_devices = 0; + + // track all instances of celerity::queue, celerity::buffer and celerity::host_object to sanity-check runtime destruction + size_t m_num_live_queues = 0; + std::unordered_set m_live_buffers; + std::unordered_set m_live_host_objects; + + buffer_id m_next_buffer_id = 0; + raw_allocation_id m_next_user_allocation_id = 1; + host_object_id m_next_host_object_id = 0; + reduction_id m_next_reduction_id = no_reduction_id + 1; + + task_graph m_tdag; + std::unique_ptr m_task_mngr; + std::unique_ptr m_schdlr; + std::unique_ptr m_exec; + + std::optional m_latest_horizon_reached; // only accessed by executor thread + std::atomic m_latest_epoch_reached; // task_id, but cast to size_t to work with std::atomic + task_id m_last_epoch_pruned_before = 0; + + std::unique_ptr m_task_recorder; // accessed by task manager (application thread) + std::unique_ptr m_command_recorder; // accessed only by scheduler thread (until shutdown) + std::unique_ptr m_instruction_recorder; // accessed only by scheduler thread (until shutdown) + + std::unique_ptr m_thread_pinner; // thread safe, manages lifetime of thread pinning machinery + + /// Panic when not called from m_application_thread (see that variable for more info on the matter). Since there are thread-safe and non thread-safe + /// member functions, we call this check at the beginning of all the non-safe ones. 
+ void require_call_from_application_thread() const; + + void maybe_prune_task_graph(); + + // task_manager::delegate + void task_created(const task* tsk) override; + + // scheduler::delegate + void flush(std::vector instructions, std::vector pilot) override; - void runtime::shutdown() { s_instance.reset(); } + // executor::delegate + void horizon_reached(task_id horizon_tid) override; + void epoch_reached(task_id epoch_tid) override; + + /// True when no buffers, host objects or queues are live that keep the runtime alive. + bool is_unreferenced() const; + }; static auto get_pid() { #ifdef _MSC_VER @@ -206,7 +275,7 @@ namespace detail { #endif // CELERITY_ENABLE_MPI } - runtime_impl::runtime_impl(int* argc, char** argv[], const devices_or_selector& user_devices_or_selector) { + runtime::impl::impl(int* argc, char** argv[], const devices_or_selector& user_devices_or_selector) { m_application_thread = std::this_thread::get_id(); m_cfg = std::make_unique(argc, argv); @@ -334,16 +403,16 @@ namespace detail { m_num_local_devices = system.devices.size(); } - void runtime_impl::require_call_from_application_thread() const { + void runtime::impl::require_call_from_application_thread() const { if(std::this_thread::get_id() != m_application_thread) { utils::panic("Celerity runtime, queue, handler, buffer and host_object types must only be constructed, used, and destroyed from the " "application thread. Make sure that you did not accidentally capture one of these types in a host_task."); } } - runtime_impl::~runtime_impl() { + runtime::impl::~impl() { // LCOV_EXCL_START - if(!is_unreferenced()) { + if(m_num_live_queues != 0 || !m_live_buffers.empty() || !m_live_host_objects.empty()) { // this call might originate from static destruction - we cannot assume spdlog to still be around utils::panic("Detected an attempt to destroy runtime while at least one queue, buffer or host_object was still alive. 
This likely means " "that one of these objects was leaked, or at least its lifetime extended beyond the scope of main(). This is undefined."); @@ -403,25 +472,25 @@ namespace detail { if(!s_test_mode) { mpi_finalize_once(); } } - task_id runtime_impl::submit(raw_command_group&& cg) { + task_id runtime::impl::submit(raw_command_group&& cg) { require_call_from_application_thread(); maybe_prune_task_graph(); return m_task_mngr->generate_command_group_task(std::move(cg)); } - task_id runtime_impl::fence(buffer_access access, std::unique_ptr fence_promise) { + task_id runtime::impl::fence(buffer_access access, std::unique_ptr fence_promise) { require_call_from_application_thread(); maybe_prune_task_graph(); return m_task_mngr->generate_fence_task(std::move(access), std::move(fence_promise)); } - task_id runtime_impl::fence(host_object_effect effect, std::unique_ptr fence_promise) { + task_id runtime::impl::fence(host_object_effect effect, std::unique_ptr fence_promise) { require_call_from_application_thread(); maybe_prune_task_graph(); return m_task_mngr->generate_fence_task(effect, std::move(fence_promise)); } - task_id runtime_impl::sync(epoch_action action) { + task_id runtime::impl::sync(epoch_action action) { require_call_from_application_thread(); maybe_prune_task_graph(); @@ -432,7 +501,7 @@ namespace detail { return epoch; } - void runtime_impl::maybe_prune_task_graph() { + void runtime::impl::maybe_prune_task_graph() { require_call_from_application_thread(); const auto current_epoch = m_latest_epoch_reached.load(std::memory_order_relaxed); @@ -478,14 +547,14 @@ namespace detail { // task_manager::delegate - void runtime_impl::task_created(const task* tsk) { + void runtime::impl::task_created(const task* tsk) { assert(m_schdlr != nullptr); m_schdlr->notify_task_created(tsk); } // scheduler::delegate - void runtime_impl::flush(std::vector instructions, std::vector pilots) { + void runtime::impl::flush(std::vector instructions, std::vector pilots) { // thread-safe 
assert(m_exec != nullptr); m_exec->submit(std::move(instructions), std::move(pilots)); @@ -493,7 +562,7 @@ namespace detail { // executor::delegate - void runtime_impl::horizon_reached(const task_id horizon_tid) { + void runtime::impl::horizon_reached(const task_id horizon_tid) { assert(!m_latest_horizon_reached || *m_latest_horizon_reached < horizon_tid); assert(m_latest_epoch_reached.load(std::memory_order::relaxed) < horizon_tid); // relaxed: written only by this thread @@ -504,7 +573,7 @@ namespace detail { m_latest_horizon_reached = horizon_tid; } - void runtime_impl::epoch_reached(const task_id epoch_tid) { + void runtime::impl::epoch_reached(const task_id epoch_tid) { // m_latest_horizon_reached does not need synchronization (see definition), all other accesses are implicitly synchronized. assert(!m_latest_horizon_reached || *m_latest_horizon_reached < epoch_tid); assert(epoch_tid == 0 || m_latest_epoch_reached.load(std::memory_order_relaxed) < epoch_tid); @@ -514,28 +583,28 @@ namespace detail { m_latest_horizon_reached = std::nullopt; // Any non-applied horizon is now behind the epoch and will therefore never become an epoch itself } - void runtime_impl::create_queue() { + void runtime::impl::create_queue() { require_call_from_application_thread(); ++m_num_live_queues; } - void runtime_impl::destroy_queue() { + void runtime::impl::destroy_queue() { require_call_from_application_thread(); assert(m_num_live_queues > 0); --m_num_live_queues; } - bool runtime_impl::is_dry_run() const { return m_cfg->is_dry_run(); } + bool runtime::impl::is_dry_run() const { return m_cfg->is_dry_run(); } - allocation_id runtime_impl::create_user_allocation(void* const ptr) { + allocation_id runtime::impl::create_user_allocation(void* const ptr) { require_call_from_application_thread(); const auto aid = allocation_id(user_memory_id, m_next_user_allocation_id++); m_exec->track_user_allocation(aid, ptr); return aid; } - buffer_id runtime_impl::create_buffer(const range<3>& range, 
const size_t elem_size, const size_t elem_align, const allocation_id user_aid) { + buffer_id runtime::impl::create_buffer(const range<3>& range, const size_t elem_size, const size_t elem_align, const allocation_id user_aid) { require_call_from_application_thread(); const auto bid = m_next_buffer_id++; @@ -545,7 +614,7 @@ namespace detail { return bid; } - void runtime_impl::set_buffer_debug_name(const buffer_id bid, const std::string& debug_name) { + void runtime::impl::set_buffer_debug_name(const buffer_id bid, const std::string& debug_name) { require_call_from_application_thread(); assert(utils::contains(m_live_buffers, bid)); @@ -553,7 +622,7 @@ namespace detail { m_schdlr->notify_buffer_debug_name_changed(bid, debug_name); } - void runtime_impl::destroy_buffer(const buffer_id bid) { + void runtime::impl::destroy_buffer(const buffer_id bid) { require_call_from_application_thread(); assert(utils::contains(m_live_buffers, bid)); @@ -562,7 +631,7 @@ namespace detail { m_live_buffers.erase(bid); } - host_object_id runtime_impl::create_host_object(std::unique_ptr instance) { + host_object_id runtime::impl::create_host_object(std::unique_ptr instance) { require_call_from_application_thread(); const auto hoid = m_next_host_object_id++; @@ -574,7 +643,7 @@ namespace detail { return hoid; } - void runtime_impl::destroy_host_object(const host_object_id hoid) { + void runtime::impl::destroy_host_object(const host_object_id hoid) { require_call_from_application_thread(); assert(utils::contains(m_live_host_objects, hoid)); @@ -583,7 +652,7 @@ namespace detail { m_live_host_objects.erase(hoid); } - reduction_id runtime_impl::create_reduction(std::unique_ptr reducer) { + reduction_id runtime::impl::create_reduction(std::unique_ptr reducer) { require_call_from_application_thread(); const auto rid = m_next_reduction_id++; @@ -591,17 +660,97 @@ namespace detail { return rid; } - void runtime_impl::set_scheduler_lookahead(const experimental::lookahead lookahead) { + void 
runtime::impl::set_scheduler_lookahead(const experimental::lookahead lookahead) { require_call_from_application_thread(); m_schdlr->set_lookahead(lookahead); } - void runtime_impl::flush_scheduler() { + void runtime::impl::flush_scheduler() { require_call_from_application_thread(); m_schdlr->flush_commands(); } - bool runtime_impl::is_unreferenced() const { return m_num_live_queues == 0 && m_live_buffers.empty() && m_live_host_objects.empty(); } + bool runtime::s_mpi_initialized = false; + bool runtime::s_mpi_finalized = false; + + runtime runtime::s_instance; // definition of static member + + void runtime::mpi_initialize_once(int* argc, char*** argv) { +#if CELERITY_ENABLE_MPI + CELERITY_DETAIL_TRACY_ZONE_SCOPED_V("mpi::init", LightSkyBlue, "MPI_Init"); + assert(!s_mpi_initialized); + int provided = -1; + MPI_Init_thread(argc, argv, MPI_THREAD_MULTIPLE, &provided); + assert(provided == MPI_THREAD_MULTIPLE); +#endif // CELERITY_ENABLE_MPI + s_mpi_initialized = true; + } + + void runtime::mpi_finalize_once() { +#if CELERITY_ENABLE_MPI + CELERITY_DETAIL_TRACY_ZONE_SCOPED_V("mpi::finalize", LightSkyBlue, "MPI_Finalize"); + assert(s_mpi_initialized && !s_mpi_finalized && (!s_test_mode || !has_instance())); + MPI_Finalize(); +#endif // CELERITY_ENABLE_MPI + s_mpi_finalized = true; + } + + void runtime::init(int* argc, char** argv[], const devices_or_selector& user_devices_or_selector) { + assert(!has_instance()); + s_instance.m_impl = std::make_unique(argc, argv, user_devices_or_selector); + if(!s_test_mode) { atexit(shutdown); } + } + + runtime& runtime::get_instance() { + if(!has_instance()) { throw std::runtime_error("Runtime has not been initialized"); } + return s_instance; + } + + void runtime::shutdown() { s_instance.m_impl.reset(); } + + task_id runtime::submit(raw_command_group&& cg) { return m_impl->submit(std::move(cg)); } + + task_id runtime::fence(buffer_access access, std::unique_ptr fence_promise) { + return m_impl->fence(std::move(access), 
std::move(fence_promise)); + } + + task_id runtime::fence(host_object_effect effect, std::unique_ptr fence_promise) { return m_impl->fence(effect, std::move(fence_promise)); } + + task_id runtime::sync(detail::epoch_action action) { return m_impl->sync(action); } + + void runtime::create_queue() { m_impl->create_queue(); } + + void runtime::destroy_queue() { m_impl->destroy_queue(); } + + allocation_id runtime::create_user_allocation(void* const ptr) { return m_impl->create_user_allocation(ptr); } + + buffer_id runtime::create_buffer(const range<3>& range, const size_t elem_size, const size_t elem_align, const allocation_id user_aid) { + return m_impl->create_buffer(range, elem_size, elem_align, user_aid); + } + + void runtime::set_buffer_debug_name(const buffer_id bid, const std::string& debug_name) { m_impl->set_buffer_debug_name(bid, debug_name); } + + void runtime::destroy_buffer(const buffer_id bid) { m_impl->destroy_buffer(bid); } + + host_object_id runtime::create_host_object(std::unique_ptr instance) { return m_impl->create_host_object(std::move(instance)); } + + void runtime::destroy_host_object(const host_object_id hoid) { m_impl->destroy_host_object(hoid); } + + reduction_id runtime::create_reduction(std::unique_ptr reducer) { return m_impl->create_reduction(std::move(reducer)); } + + bool runtime::is_dry_run() const { return m_impl->is_dry_run(); } + + void runtime::set_scheduler_lookahead(const experimental::lookahead lookahead) { m_impl->set_scheduler_lookahead(lookahead); } + + void runtime::flush_scheduler() { m_impl->flush_scheduler(); } + + bool runtime::s_test_mode = false; + bool runtime::s_test_active = false; + bool runtime::s_test_runtime_was_instantiated = false; } // namespace detail } // namespace celerity + + +#define CELERITY_DETAIL_TAIL_INCLUDE +#include "testspy/runtime_testspy.inl" diff --git a/src/scheduler.cc b/src/scheduler.cc index c5358ef9c..de8f3d7a4 100644 --- a/src/scheduler.cc +++ b/src/scheduler.cc @@ -9,6 +9,7 @@ #include 
"print_utils_internal.h" #include "ranges.h" #include "recorders.h" +#include "testspy/scheduler_testspy.h" #include "tracy.h" #include "types.h" #include "utils.h" @@ -64,13 +65,10 @@ struct event_set_lookahead { experimental::lookahead lookahead; }; struct event_flush_commands {}; -struct test_event_inspect { - scheduler_detail::test_inspector inspect; -}; /// An event passed from task_manager or runtime through the public scheduler interface. using task_event = std::variant; + event_host_object_destroyed, event_epoch_reached, event_set_lookahead, event_flush_commands, scheduler_testspy::event_inspect>; class task_queue { public: @@ -278,8 +276,8 @@ void scheduler_impl::process_task_queue_event(const task_event& evt) { [&](const event_flush_commands& e) { // command_queue.push(e); }, - [&](const test_event_inspect& e) { // - e.inspect({.cdag = &cdag, .idag = &idag, .lookahead = lookahead}); + [&](const scheduler_testspy::event_inspect& e) { // + e.inspector({.cdag = &cdag, .idag = &idag, .lookahead = lookahead}); }); } @@ -365,16 +363,8 @@ void scheduler::set_lookahead(const experimental::lookahead lookahead) { m_impl- void scheduler::flush_commands() { m_impl->task_queue.push(event_flush_commands{}); } -// LCOV_EXCL_START - this is test instrumentation used only in benchmarks and not covered in unit tests - -scheduler::scheduler(const test_threadless_tag /* tag */, const size_t num_nodes, const node_id local_node_id, const system_info& system, - delegate* const delegate, command_recorder* const crec, instruction_recorder* const irec, const policy_set& policy) - : m_impl(std::make_unique(false /* start_thread */, num_nodes, local_node_id, system, delegate, crec, irec, policy)) {} - -void scheduler::test_scheduling_loop() { m_impl->scheduling_loop(); } - -// LCOV_EXCL_STOP +} // namespace celerity::detail -void scheduler::test_inspect(scheduler_detail::test_inspector inspector) { m_impl->task_queue.push(test_event_inspect{std::move(inspector)}); } -} // namespace 
celerity::detail +#define CELERITY_DETAIL_TAIL_INCLUDE +#include "testspy/scheduler_testspy.inl" diff --git a/src/testspy/runtime_testspy.h b/src/testspy/runtime_testspy.h new file mode 100644 index 000000000..693e28269 --- /dev/null +++ b/src/testspy/runtime_testspy.h @@ -0,0 +1,55 @@ +#pragma once + +#include "runtime.h" +#include "types.h" + +#include +#include + + +namespace celerity::detail { + +class executor; +class scheduler; +class task_graph; +class task_manager; + +struct runtime_testspy { + static node_id get_local_nid(const runtime& rt); + static size_t get_num_nodes(const runtime& rt); + static size_t get_num_local_devices(const runtime& rt); + + static task_graph& get_task_graph(runtime& rt); + static task_manager& get_task_manager(runtime& rt); + static scheduler& get_schdlr(runtime& rt); + static executor& get_exec(runtime& rt); + + static task_id get_latest_epoch_reached(const runtime& rt); + + static std::string print_task_graph(runtime& rt); + static std::string print_command_graph(const node_id local_nid, runtime& rt); + static std::string print_instruction_graph(runtime& rt); + + // We have to jump through some hoops to be able to re-initialize the runtime for unit testing. + // MPI does not like being initialized more than once per process, so we have to skip that part for + // re-initialization. + + /// Switches to test mode, where MPI will be initialized through test_require_mpi() instead of runtime::runtime(). Called on Catch2 startup. + static void test_mode_enter(); + + /// Finalizes MPI if it was ever initialized in test mode. Called on Catch2 shutdown. + static void test_mode_exit(); + + /// Initializes MPI for tests, if it was not initialized before + static void test_require_mpi(); + + /// Allows the runtime to be transitively instantiated in tests. Called from runtime_fixture. + static void test_case_enter(); + + static bool test_runtime_was_instantiated(); + + /// Deletes the runtime instance, which happens only in tests.
Called from runtime_fixture. + static void test_case_exit(); +}; + +} // namespace celerity::detail diff --git a/src/testspy/runtime_testspy.inl b/src/testspy/runtime_testspy.inl new file mode 100644 index 000000000..720965285 --- /dev/null +++ b/src/testspy/runtime_testspy.inl @@ -0,0 +1,84 @@ +#pragma once + +#include "print_graph.h" +#include "runtime.h" +#include "runtime_testspy.h" +#include "scheduler_testspy.h" +#include "types.h" + +#include + +// This file is tail-included by runtime.cc; also make it parsable as a standalone file for clangd and clang-tidy +#ifndef CELERITY_DETAIL_TAIL_INCLUDE +#include "../runtime.cc" // NOLINT(bugprone-suspicious-include) +#endif + + +namespace celerity::detail { + +node_id runtime_testspy::get_local_nid(const runtime& rt) { return rt.m_impl->m_local_nid; } + +size_t runtime_testspy::get_num_nodes(const runtime& rt) { return rt.m_impl->m_num_nodes; } + +size_t runtime_testspy::get_num_local_devices(const runtime& rt) { return rt.m_impl->m_num_local_devices; } + +task_graph& runtime_testspy::get_task_graph(runtime& rt) { return rt.m_impl->m_tdag; } + +task_manager& runtime_testspy::get_task_manager(runtime& rt) { return *rt.m_impl->m_task_mngr; } + +scheduler& runtime_testspy::get_schdlr(runtime& rt) { return *rt.m_impl->m_schdlr; } + +executor& runtime_testspy::get_exec(runtime& rt) { return *rt.m_impl->m_exec; } + +task_id runtime_testspy::get_latest_epoch_reached(const runtime& rt) { return rt.m_impl->m_latest_epoch_reached.load(std::memory_order_relaxed); } + +std::string runtime_testspy::print_task_graph(runtime& rt) { + return detail::print_task_graph(*rt.m_impl->m_task_recorder); // task recorder is mutated by task manager (application / test thread) +} + +std::string runtime_testspy::print_command_graph(const node_id local_nid, runtime& rt) { + // command_recorder is mutated by scheduler thread + return scheduler_testspy::inspect_thread( + get_schdlr(rt), [&](const auto&) { return 
detail::print_command_graph(local_nid, *rt.m_impl->m_command_recorder); }); +} + +std::string runtime_testspy::print_instruction_graph(runtime& rt) { + // instruction recorder is mutated by scheduler thread + return scheduler_testspy::inspect_thread(get_schdlr(rt), [&](const auto&) { + return detail::print_instruction_graph(*rt.m_impl->m_instruction_recorder, *rt.m_impl->m_command_recorder, *rt.m_impl->m_task_recorder); + }); +} + +void runtime_testspy::test_mode_enter() { + assert(!runtime::s_mpi_initialized); + runtime::s_test_mode = true; +} + +void runtime_testspy::test_mode_exit() { + assert(runtime::s_test_mode && !runtime::s_test_active && !runtime::s_mpi_finalized); + if(runtime::s_mpi_initialized) { runtime::mpi_finalize_once(); } +} + +void runtime_testspy::test_require_mpi() { + assert(runtime::s_test_mode && !runtime::s_test_active); + if(!runtime::s_mpi_initialized) { runtime::mpi_initialize_once(nullptr, nullptr); } +} + +void runtime_testspy::test_case_enter() { + assert(runtime::s_test_mode && !runtime::s_test_active && runtime::s_mpi_initialized && !runtime::has_instance()); + runtime::s_test_active = true; + runtime::s_test_runtime_was_instantiated = false; +} + +bool runtime_testspy::test_runtime_was_instantiated() { + assert(runtime::s_test_mode); + return runtime::s_test_runtime_was_instantiated; +} + +void runtime_testspy::test_case_exit() { + assert(runtime::s_test_mode && runtime::s_test_active); + runtime::shutdown(); + runtime::s_test_active = false; +} + +} // namespace celerity::detail diff --git a/src/testspy/scheduler_testspy.h b/src/testspy/scheduler_testspy.h new file mode 100644 index 000000000..19d8fbc21 --- /dev/null +++ b/src/testspy/scheduler_testspy.h @@ -0,0 +1,46 @@ +#pragma once + +#include "scheduler.h" + +#include +#include +#include + + +namespace celerity::detail { + +struct scheduler_testspy { + struct scheduler_state { + const command_graph* cdag = nullptr; + const instruction_graph* idag = nullptr; + 
experimental::lookahead lookahead = experimental::lookahead::automatic; + }; + + struct event_inspect { + /// executed inside scheduler thread, making it safe to access scheduler members + std::function inspector; + }; + + static scheduler make_threadless_scheduler(size_t num_nodes, node_id local_node_id, const system_info& system_info, scheduler::delegate* delegate, + command_recorder* crec, instruction_recorder* irec, const scheduler::policy_set& policy = {}); + + static void run_scheduling_loop(scheduler& schdlr); + + static void begin_inspect_thread(scheduler& schdlr, event_inspect inspector); + + template + static auto inspect_thread(scheduler& schdlr, F&& f) { + using return_t = std::invoke_result_t; + std::promise channel; + begin_inspect_thread(schdlr, event_inspect{[&](const scheduler_state& state) { + if constexpr(std::is_void_v) { + f(state), channel.set_value(); + } else { + channel.set_value(f(state)); + } + }}); + return channel.get_future().get(); + } +}; + +} // namespace celerity::detail diff --git a/src/testspy/scheduler_testspy.inl b/src/testspy/scheduler_testspy.inl new file mode 100644 index 000000000..ee8471cec --- /dev/null +++ b/src/testspy/scheduler_testspy.inl @@ -0,0 +1,27 @@ +#pragma once + +#include "scheduler.h" +#include "scheduler_testspy.h" + +#include + +// This file is tail-included by scheduler.cc; also make it parsable as a standalone file for clangd and clang-tidy +#ifndef CELERITY_DETAIL_TAIL_INCLUDE +#include "../scheduler.cc" // NOLINT(bugprone-suspicious-include) +#endif + +namespace celerity::detail { + +scheduler scheduler_testspy::make_threadless_scheduler(size_t num_nodes, node_id local_node_id, const system_info& system_info, scheduler::delegate* delegate, + command_recorder* crec, instruction_recorder* irec, const scheduler::policy_set& policy) // +{ + scheduler schdlr; // default-constructible by testspy, keeps m_impl == nullptr + schdlr.m_impl = std::make_unique(false /* start_thread */, num_nodes, local_node_id, 
system_info, delegate, crec, irec, policy); + return schdlr; +} + +void scheduler_testspy::run_scheduling_loop(scheduler& schdlr) { schdlr.m_impl->scheduling_loop(); } + +void scheduler_testspy::begin_inspect_thread(scheduler& schdlr, event_inspect inspect) { schdlr.m_impl->task_queue.push(std::move(inspect)); } + +} // namespace celerity::detail diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 2794dfc11..9d788506d 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -6,6 +6,7 @@ include(ParseAndAddCatchTests) # Function for setting all relevant test parameters function(set_test_target_parameters TARGET SOURCE) + target_include_directories(${TARGET} PRIVATE "${CMAKE_SOURCE_DIR}/src") target_link_libraries(${TARGET} PUBLIC Catch2::Catch2) set_property(TARGET ${TARGET} PROPERTY CXX_STANDARD ${CELERITY_CXX_STANDARD}) set_property(TARGET ${TARGET} PROPERTY FOLDER "tests") diff --git a/test/dag_benchmarks.cc b/test/dag_benchmarks.cc index 7e1fd8733..566b27ad9 100644 --- a/test/dag_benchmarks.cc +++ b/test/dag_benchmarks.cc @@ -269,16 +269,17 @@ struct scheduler_benchmark_context : private task_manager::delegate { // NOLINT( task_graph tdag; task_manager tm{num_nodes, tdag, nullptr, this, benchmark_task_manager_policy}; restartable_scheduler_thread* thread; - scheduler_testspy::threadless_scheduler schdlr; + scheduler schdlr; test_utils::mock_buffer_factory mbf; explicit scheduler_benchmark_context(restartable_scheduler_thread& thrd, const size_t num_nodes, const size_t num_devices_per_node) - : num_nodes(num_nodes), thread(&thrd), - schdlr(num_nodes, 0 /* local_nid */, test_utils::make_system_info(num_devices_per_node, true /* supports d2d copies */), nullptr /* delegate */, - nullptr /* crec */, nullptr /* irec */), + : num_nodes(num_nodes), thread(&thrd), // + schdlr(scheduler_testspy::make_threadless_scheduler(num_nodes, 0 /* local_nid */, + test_utils::make_system_info(num_devices_per_node, true /* supports d2d copies */), nullptr /* delegate 
*/, nullptr /* crec */, + nullptr /* irec */)), mbf(tm, schdlr) // { - thread->start([this] { schdlr.scheduling_loop(); }); + thread->start([this] { scheduler_testspy::run_scheduling_loop(schdlr); }); tm.generate_epoch_task(epoch_action::init); } diff --git a/test/runtime_tests.cc b/test/runtime_tests.cc index eee398ac7..8bf52515c 100644 --- a/test/runtime_tests.cc +++ b/test/runtime_tests.cc @@ -704,9 +704,12 @@ namespace detail { cgh.host_task(range<1>{num_nodes * 2}, [=](partition<1>) { (void)acc; }); }); + const auto live_command_count = scheduler_testspy::inspect_thread(runtime_testspy::get_schdlr(rt), // + [](const scheduler_testspy::scheduler_state& state) { return graph_testspy::get_live_node_count(*state.cdag); }); + // intial epoch + master-node task + push + host task + 1 horizon // (dry runs currently always simulate node 0, hence the master-node task) - CHECK(scheduler_testspy::get_live_command_count(runtime_testspy::get_schdlr(rt)) == 5); + CHECK(live_command_count == 5); } TEST_CASE_METHOD(test_utils::runtime_fixture, "dry run proceeds on fences", "[dryrun]") { @@ -1050,7 +1053,9 @@ namespace detail { env::scoped_test_environment ste("CELERITY_LOOKAHEAD", str); runtime::init(nullptr, nullptr); auto& schdlr = runtime_testspy::get_schdlr(detail::runtime::get_instance()); - CHECK(scheduler_testspy::get_lookahead(schdlr) == lookahead); + const auto actual_lookahead = + scheduler_testspy::inspect_thread(schdlr, [](const scheduler_testspy::scheduler_state& state) { return state.lookahead; }); + CHECK(actual_lookahead == lookahead); } TEST_CASE_METHOD(test_utils::runtime_fixture, "lookahead ensures that a single allocation is used for a growing access pattern", "[runtime][lookahead]") { diff --git a/test/test_main.cc b/test/test_main.cc index 40a6d68a5..dc2df1490 100644 --- a/test/test_main.cc +++ b/test/test_main.cc @@ -20,8 +20,8 @@ int main(int argc, char* argv[]) { // allow unit tests to catch and recover from panics 
celerity::detail::utils::set_panic_solution(celerity::detail::utils::panic_solution::throw_logic_error); - celerity::detail::runtime::test_mode_enter(); + celerity::detail::runtime_testspy::test_mode_enter(); return_code = session.run(); - celerity::detail::runtime::test_mode_exit(); + celerity::detail::runtime_testspy::test_mode_exit(); return return_code; } diff --git a/test/test_utils.cc b/test/test_utils.cc index f2de5696f..37057ffaf 100644 --- a/test/test_utils.cc +++ b/test/test_utils.cc @@ -279,15 +279,15 @@ detail::system_info make_system_info(const size_t num_devices, const bool suppor } runtime_fixture::runtime_fixture() { - detail::runtime::test_case_enter(); + detail::runtime_testspy::test_case_enter(); allow_higher_level_log_messages(spdlog::level::warn, test_utils_detail::expected_runtime_init_warnings_regex); allow_higher_level_log_messages(spdlog::level::warn, test_utils_detail::expected_device_enumeration_warnings_regex); allow_higher_level_log_messages(spdlog::level::warn, test_utils_detail::expected_backend_fallback_warnings_regex); } runtime_fixture::~runtime_fixture() { - if(!detail::runtime::test_runtime_was_instantiated()) { WARN("Test specified a runtime_fixture, but did not end up instantiating the runtime"); } - detail::runtime::test_case_exit(); + if(!detail::runtime_testspy::test_runtime_was_instantiated()) { WARN("Test specified a runtime_fixture, but did not end up instantiating the runtime"); } + detail::runtime_testspy::test_case_exit(); } void allow_backend_fallback_warnings() { allow_higher_level_log_messages(spdlog::level::warn, test_utils_detail::expected_backend_fallback_warnings_regex); } diff --git a/test/test_utils.h b/test/test_utils.h index c5b69a59c..202772e4e 100644 --- a/test/test_utils.h +++ b/test/test_utils.h @@ -1,9 +1,8 @@ #pragma once -#include #include #include -#include +#include #include #ifdef _WIN32 @@ -17,6 +16,7 @@ #include #include +#include "affinity.h" #include "async_event.h" #include 
"backend/sycl_backend.h" #include "command_graph.h" @@ -27,11 +27,11 @@ #include "print_utils_internal.h" #include "range_mapper.h" #include "region_map.h" -#include "runtime.h" -#include "runtime_impl.h" #include "scheduler.h" #include "system_info.h" #include "task_manager.h" +#include "testspy/runtime_testspy.h" +#include "testspy/scheduler_testspy.h" #include "types.h" // To avoid having to come up with tons of unique kernel names, we simply use the CPP counter. @@ -87,75 +87,6 @@ namespace detail { } }; - struct scheduler_testspy { - using test_state = scheduler_detail::test_state; - - class threadless_scheduler : public scheduler { - public: - threadless_scheduler(const auto&... params) : scheduler(test_threadless_tag(), params...) {} - void scheduling_loop() { test_scheduling_loop(); } - }; - - template - static auto inspect_thread(scheduler& schdlr, F&& f) { - using return_t = std::invoke_result_t; - std::promise channel; - schdlr.test_inspect([&](const scheduler_detail::test_state& state) { - if constexpr(std::is_void_v) { - f(state), channel.set_value(); - } else { - channel.set_value(f(state)); - } - }); - return channel.get_future().get(); - } - - static size_t get_live_command_count(scheduler& schdlr) { - return inspect_thread(schdlr, [](const test_state& state) { return graph_testspy::get_live_node_count(*state.cdag); }); - } - - static size_t get_live_instruction_count(scheduler& schdlr) { - return inspect_thread(schdlr, [](const test_state& state) { return graph_testspy::get_live_node_count(*state.idag); }); - } - - static experimental::lookahead get_lookahead(scheduler& schdlr) { - return inspect_thread(schdlr, [](const test_state& state) { return state.lookahead; }); - } - }; - - struct runtime_testspy { - static const runtime_impl& impl(const runtime& rt) { return dynamic_cast(rt); } - static runtime_impl& impl(runtime& rt) { return dynamic_cast(rt); } - - static node_id get_local_nid(const runtime& rt) { return impl(rt).m_local_nid; } - static 
size_t get_num_nodes(const runtime& rt) { return impl(rt).m_num_nodes; } - static size_t get_num_local_devices(const runtime& rt) { return impl(rt).m_num_local_devices; } - - static task_graph& get_task_graph(runtime& rt) { return impl(rt).m_tdag; } - static task_manager& get_task_manager(runtime& rt) { return *impl(rt).m_task_mngr; } - static scheduler& get_schdlr(runtime& rt) { return *impl(rt).m_schdlr; } - static executor& get_exec(runtime& rt) { return *impl(rt).m_exec; } - - static task_id get_latest_epoch_reached(const runtime& rt) { return impl(rt).m_latest_epoch_reached.load(std::memory_order_relaxed); } - - static std::string print_task_graph(runtime& rt) { - return detail::print_task_graph(*impl(rt).m_task_recorder); // task recorder is mutated by task manager (application / test thread) - } - - static std::string print_command_graph(const node_id local_nid, runtime& rt) { - // command_recorder is mutated by scheduler thread - return scheduler_testspy::inspect_thread( - get_schdlr(rt), [&](const auto&) { return detail::print_command_graph(local_nid, *impl(rt).m_command_recorder); }); - } - - static std::string print_instruction_graph(runtime& rt) { - // instruction recorder is mutated by scheduler thread - return scheduler_testspy::inspect_thread(get_schdlr(rt), [&](const auto&) { - return detail::print_instruction_graph(*impl(rt).m_instruction_recorder, *impl(rt).m_command_recorder, *impl(rt).m_task_recorder); - }); - } - }; - struct task_manager_testspy { inline static constexpr task_id initial_epoch_task = task_manager::initial_epoch_task; @@ -189,7 +120,7 @@ namespace test_utils { .use_backend_device_submission_threads = false, }; m_thread_pinner.emplace(cfg); - name_and_pin_and_order_this_thread(detail::named_threads::thread_type::application); + detail::thread_pinning::pin_this_thread(detail::named_threads::thread_type::application); } std::optional m_thread_pinner; @@ -455,7 +386,7 @@ namespace test_utils { // This fixture (or a subclass) must be 
used by all tests that transitively use MPI. class mpi_fixture { public: - mpi_fixture() { detail::runtime::test_require_mpi(); } + mpi_fixture() { detail::runtime_testspy::test_require_mpi(); } mpi_fixture(const mpi_fixture&) = delete; mpi_fixture(mpi_fixture&&) = delete; mpi_fixture& operator=(const mpi_fixture&) = delete;