From 1455aae865ebbaa802ed0890b9141f9bd34cd943 Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Fri, 13 Dec 2024 17:10:26 -0800 Subject: [PATCH] Clear the FieldCache and any comms buffers when changing communicators; this fixes UB that caused non-reproducible hangs when testing split grid. Also adds comm_barrier_global(), a global barrier regardless of the present communicator scope --- include/comm_quda.h | 11 +++++++++++ lib/communicator_stack.cpp | 9 +++++++-- 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/include/comm_quda.h b/include/comm_quda.h index 8f37bcd032..8496efab5a 100644 --- a/include/comm_quda.h +++ b/include/comm_quda.h @@ -432,7 +432,18 @@ namespace quda */ void comm_broadcast(void *data, size_t nbytes, int root = 0); + /** + @brief Multi-process barrier that applies to the present + communicator + */ void comm_barrier(void); + + /** + @brief Multi-process barrier that is global regardless of the + present communicator + */ + void comm_barrier_global(void); + void comm_abort(int status); void comm_abort_(int status); diff --git a/lib/communicator_stack.cpp b/lib/communicator_stack.cpp index adfd332f2b..ae36f5b2d0 100644 --- a/lib/communicator_stack.cpp +++ b/lib/communicator_stack.cpp @@ -60,14 +60,17 @@ namespace quda // used to store the size of the tunecache at the point of splitting static size_t tune_cache_size = 0; + // destroy any message handles associate with the prior communicator + LatticeField::freeGhostBuffer(); + ColorSpinorField::freeGhostBuffer(); + FieldTmp::destroy(); + auto search = communicator_stack.find(split_key); if (search == communicator_stack.end()) { communicator_stack.emplace(std::piecewise_construct, std::forward_as_tuple(split_key), std::forward_as_tuple(get_default_communicator(), split_key.data())); } - LatticeField::freeGhostBuffer(); // Destroy the (IPC) Comm buffers with the old communicator. - auto split_key_old = current_key; current_key = split_key; @@ -362,6 +365,8 @@ namespace quda void comm_barrier(void) { get_current_communicator().comm_barrier(); } + void comm_barrier_global(void) { get_default_communicator().comm_barrier(); } + void comm_abort_(int status) { Communicator::comm_abort_(status); }; int commDim(int dim) { return get_current_communicator().commDim(dim); }