diff --git a/src/madness/world/worldgop.cc b/src/madness/world/worldgop.cc index 019c0aade12..c5c1f6b0c99 100644 --- a/src/madness/world/worldgop.cc +++ b/src/madness/world/worldgop.cc @@ -47,12 +47,13 @@ namespace madness { /// constant over two traversals. We are then we are sure /// that all tasks and AM are processed and there no AM in /// flight. + /// \post `this->world_.taskq.size()==0` void WorldGopInterface::fence_impl(std::function epilogue, bool pause_during_epilogue, bool debug) { PROFILE_MEMBER_FUNC(WorldGopInterface); MADNESS_CHECK(not forbid_fence_); - unsigned long nsent_prev=0, nrecv_prev=1; // invalid initial condition + unsigned long nsent_prev=0, nrecv_prev=1; SafeMPI::Request req0, req1; ProcessID parent, child0, child1; world_.mpi.binary_tree_info(0, parent, child0, child1); @@ -60,96 +61,138 @@ namespace madness { Tag bcast_tag = world_.mpi.unique_tag(); int npass = 0; - //double start = wall_time(); - - if (debug) - madness::print(world_.rank(), ": WORLD.GOP.FENCE: entering fence loop, gfence_tag=", gfence_tag, " bcast_tag=", bcast_tag); - - while (1) { - uint64_t sum0[2]={0,0}, sum1[2]={0,0}, sum[2]; - if (child0 != -1) req0 = world_.mpi.Irecv((void*) &sum0, sizeof(sum0), MPI_BYTE, child0, gfence_tag); - if (child1 != -1) req1 = world_.mpi.Irecv((void*) &sum1, sizeof(sum1), MPI_BYTE, child1, gfence_tag); + // fence ensures that all ranks agree that all sent AMs (nsent) have been + // processed (nrecv) and that no tasks (ntask) are running. We ensure this by + // observing the global sums of these local observables, and then ensuring + // that the termination conditions have been met twice and + // over two rounds of observations no messages have been met. + // N.B. Epilogue and deferred cleanup can also generate messages, so + // need to do another round of global synchronization after these + // actions ... hence the lambda + + auto termdet = [&]() { + if (debug) + madness::print( + world_.rank(), + ": WORLD.GOP.FENCE: entering termdet, gfence_tag=", + gfence_tag, " bcast_tag=", bcast_tag); + + while (1) { + uint64_t sum0[2] = {0, 0}, sum1[2] = {0, 0}, sum[2]; + if (child0 != -1) + req0 = world_.mpi.Irecv((void *)&sum0, sizeof(sum0), MPI_BYTE, + child0, gfence_tag); + if (child1 != -1) + req1 = world_.mpi.Irecv((void *)&sum1, sizeof(sum1), MPI_BYTE, + child1, gfence_tag); world_.taskq.fence(); - if (child0 != -1) World::await(req0); - if (child1 != -1) World::await(req1); + if (child0 != -1) + World::await(req0); + if (child1 != -1) + World::await(req1); if (debug && (child0 != -1 || child1 != -1)) - madness::print(world_.rank(), ": WORLD.GOP.FENCE: npass=", npass, " received messages from children={", child0, ",", child1, "} gfence_tag=", gfence_tag); + madness::print(world_.rank(), + ": WORLD.GOP.FENCE: npass=", npass, + " received messages from children={", child0, + ",", child1, "} gfence_tag=", gfence_tag); bool finished; uint64_t ntask1, nsent1, nrecv1, ntask2, nsent2, nrecv2; do { - world_.taskq.fence(); + world_.taskq.fence(); - // Since the number of outstanding tasks and number of AM sent/recv - // don't share a critical section read each twice and ensure they - // are unchanged to ensure that are consistent ... they don't have - // to be current. + // Since the number of outstanding tasks and number of AM sent/recv + // don't share a critical section there is no good way to obtain + // their "current" values (i.e. their values at the same clock), + // so read each twice and ensure they are unchanged to ensure + // that are consistent ... - ntask1 = world_.taskq.size(); - nsent1 = world_.am.nsent; - nrecv1 = world_.am.nrecv; + nsent1 = world_.am.nsent; // # of sent AM + nrecv1 = world_.am.nrecv; // # of processed incoming AM + ntask1 = world_.taskq.size(); // current # of tasks; N.B. this was zero after the fence above but may be non-zero now + // processing each incoming AMs may bump this up, so read it AFTER nrecv (albeit task completion will drop this again) - __asm__ __volatile__ (" " : : : "memory"); + __asm__ __volatile__(" " : : : "memory"); - ntask2 = world_.taskq.size(); - nsent2 = world_.am.nsent; - nrecv2 = world_.am.nrecv; + nsent2 = world_.am.nsent; + nrecv2 = world_.am.nrecv; + ntask2 = world_.taskq.size(); - __asm__ __volatile__ (" " : : : "memory"); + __asm__ __volatile__(" " : : : "memory"); - finished = (ntask2==0) && (ntask1==0) && (nsent1==nsent2) && (nrecv1==nrecv2); - } - while (!finished); + finished = (ntask2 == 0) && (ntask1 == 0) && + (nsent1 == nsent2) && (nrecv1 == nrecv2); + } while (!finished); - sum[0] = sum0[0] + sum1[0] + nsent2; // Must use values read above + sum[0] = + sum0[0] + sum1[0] + nsent2; sum[1] = sum0[1] + sum1[1] + nrecv2; if (parent != -1) { - req0 = world_.mpi.Isend(&sum, sizeof(sum), MPI_BYTE, parent, gfence_tag); - if (debug) - madness::print(world_.rank(), ": WORLD.GOP.FENCE: npass=", npass, " sent message to parent=", parent, " gfence_tag=", gfence_tag); - World::await(req0); - if (debug) - madness::print(world_.rank(), ": WORLD.GOP.FENCE: npass=", npass, " parent=", parent, ", confirmed receipt"); + req0 = world_.mpi.Isend(&sum, sizeof(sum), MPI_BYTE, parent, + gfence_tag); + if (debug) + madness::print(world_.rank(), + ": WORLD.GOP.FENCE: npass=", npass, + " sent message to parent=", parent, + " gfence_tag=", gfence_tag); + World::await(req0); + if (debug) + madness::print(world_.rank(), + ": WORLD.GOP.FENCE: npass=", npass, + " parent=", parent, ", confirmed receipt"); } - // While we are probably idle free unused communication buffers - //world_.am.free_managed_buffers(); - - //bool dowork = (npass==0) || (ThreadPool::size()==0); + // bool dowork = (npass==0) || (ThreadPool::size()==0); bool dowork = true; broadcast(&sum, sizeof(sum), 0, dowork, bcast_tag); ++npass; if (debug) - madness::print(world_.rank(), ": WORLD.GOP.FENCE: npass=", npass, " sum0=", sum[0], " nsent_prev=", nsent_prev, " sum1=", sum[1], " nrecv_prev=", nrecv_prev); + madness::print(world_.rank(), + ": WORLD.GOP.FENCE: npass=", npass, + " sum0=", sum[0], " nsent_prev=", nsent_prev, + " sum1=", sum[1], " nrecv_prev=", nrecv_prev); - if (sum[0]==sum[1] && sum[0]==nsent_prev && sum[1]==nrecv_prev) { + if (sum[0] == sum[1] && sum[0] == nsent_prev && + sum[1] == nrecv_prev) { if (debug) - madness::print(world_.rank(), ": WORLD.GOP.FENCE: npass=", npass, " exiting fence loop"); + madness::print(world_.rank(), + ": WORLD.GOP.FENCE: npass=", npass, + " exiting fence loop"); break; } -// if (wall_time() - start > 1200.0) { -// std::cout << rank() << " FENCE " << nsent2 << " " -// << nsent_prev << " " << nrecv2 << " " << nrecv_prev -// << " " << sum[0] << " " << sum[1] << " " << npass -// << " " << taskq.size() << std::endl; -// std::cout.flush(); -// //myusleep(1000); -// MADNESS_ASSERT(0); -// } + // if (wall_time() - start > 1200.0) { + // std::cout << rank() << " FENCE " << nsent2 << " " + // << nsent_prev << " " << nrecv2 << " " << nrecv_prev + // << " " << sum[0] << " " << sum[1] << " " << npass + // << " " << taskq.size() << std::endl; + // std::cout.flush(); + // //myusleep(1000); + // MADNESS_ASSERT(0); + // } nsent_prev = sum[0]; nrecv_prev = sum[1]; + }; + }; // termdet + + termdet(); - }; // execute post-fence actions MADNESS_ASSERT(pause_during_epilogue == false); epilogue(); world_.am.free_managed_buffers(); // free up communication buffers deferred_->do_cleanup(); + + // repeat termination detection in case epilogue or cleanup produced tasks + termdet(); + + // ensure postcondition + world_.taskq.fence(); + #ifdef MADNESS_HAS_GOOGLE_PERF_TCMALLOC MallocExtension::instance()->ReleaseFreeMemory(); // print("clearing memory"); diff --git a/src/madness/world/worldgop.h b/src/madness/world/worldgop.h index e61a24fc880..2d7a16f21f8 100644 --- a/src/madness/world/worldgop.h +++ b/src/madness/world/worldgop.h @@ -706,14 +706,14 @@ namespace madness { /// Synchronizes all processes in communicator AND globally ensures no pending AM or tasks - /// \internal Runs Dykstra-like termination algorithm on binary tree by - /// locally ensuring ntask=0 and all am sent and processed, - /// and then participating in a global sum of nsent and nrecv. - /// Then globally checks that nsent=nrecv and that both are - /// constant over two traversals. We are then sure + /// \internal Runs Dykstra-like termination algorithm on binary tree + /// which stops when global sum of # of tasks in queue (`ntask`) is + /// zero and global sum of the # of sent/received AMs (`nsent`/`nrecv`) + /// are equal and unchanged over two traversals. We are then sure /// that all tasks and AM are processed and there no AM in /// flight. /// \param[in] debug set to true to print progress statistics using madness::print(); the default is false. + /// \post `this->gop.taskq.size()==0` void fence(bool debug = false); /// Executes an action on single (this) thread after ensuring all other work is done