From 9ac90c4cecc57191806ebcec2645e235cc431a98 Mon Sep 17 00:00:00 2001 From: Juniper Tyree Date: Mon, 27 May 2024 03:13:23 +0000 Subject: [PATCH] Implement MPI partitioning result reduction --- necsim/partitioning/mpi/src/lib.rs | 64 ++++++++++++- .../partitioning/mpi/src/partition/utils.rs | 90 ------------------- necsim/partitioning/threads/src/lib.rs | 2 +- 3 files changed, 61 insertions(+), 95 deletions(-) diff --git a/necsim/partitioning/mpi/src/lib.rs b/necsim/partitioning/mpi/src/lib.rs index f1e00c1b3..dba2c1980 100644 --- a/necsim/partitioning/mpi/src/lib.rs +++ b/necsim/partitioning/mpi/src/lib.rs @@ -8,9 +8,11 @@ use std::{fmt, mem::ManuallyDrop, num::NonZeroU32, time::Duration}; use anyhow::Context; use humantime_serde::re::humantime::format_duration; use mpi::{ + datatype::PartitionMut, environment::Universe, topology::{Communicator, Rank, SimpleCommunicator}, - Tag, + traits::CommunicatorCollectives, + Count, Tag, }; use serde::{ser::SerializeStruct, Deserialize, Serialize, Serializer}; use serde_derive_state::DeserializeState; @@ -182,8 +184,7 @@ impl Partitioning for MpiPartitioning { event_log: Self::Auxiliary, args: A, inner: for<'p> fn(Self::LocalPartition<'p, R>, A) -> Q, - // TODO: use fold to return the same result in all partitions, then deprecate - _fold: fn(Q, Q) -> Q, + fold: fn(Q, Q) -> Q, ) -> anyhow::Result { let Some(event_log) = event_log else { anyhow::bail!(MpiLocalPartitionError::MissingEventLog) @@ -239,7 +240,9 @@ impl Partitioning for MpiPartitioning { ))) }; - Ok(inner(local_partition, args)) + let local_result = inner(local_partition, args); + + reduce_partitioning_data(&self.world, local_result, fold) }) } } @@ -285,3 +288,56 @@ fn deserialize_state_mpi_world<'de, D: Deserializer<'de>>( ))), } } + +fn reduce_partitioning_data( + world: &SimpleCommunicator, + data: T, + fold: fn(T, T) -> T, +) -> anyhow::Result { + let local_ser = postcard::to_stdvec(&data).context("MPI data failed to serialize")?; + std::mem::drop(data); + + #[allow(clippy::cast_sign_loss)] + let mut counts = vec![0 as Count; world.size() as usize]; + world.all_gather_into(&(Count::try_from(local_ser.len()).unwrap()), &mut counts); + + let offsets = counts + .iter() + .scan(0 as Count, |acc, &x| { + let tmp = *acc; + *acc = (*acc).checked_add(x).unwrap(); + Some(tmp) + }) + .collect::>(); + + #[allow(clippy::cast_sign_loss)] + let mut all_sers = vec![0_u8; counts.iter().copied().sum::() as usize]; + world.all_gather_varcount_into( + local_ser.as_slice(), + &mut PartitionMut::new(all_sers.as_mut_slice(), counts.as_slice(), offsets), + ); + + let folded: Option = counts + .iter() + .scan(0_usize, |acc, &x| { + let pre = *acc; + #[allow(clippy::cast_sign_loss)] + { + *acc += x as usize; + } + let post = *acc; + + let de: anyhow::Result = postcard::from_bytes(&all_sers[pre..post]) + .context("MPI data failed to deserialize"); + + Some(de) + }) + .try_fold(None, |acc, x| match (acc, x) { + (_, Err(err)) => Err(err), + (Some(acc), Ok(x)) => Ok(Some(fold(acc, x))), + (None, Ok(x)) => Ok(Some(x)), + })?; + let folded = folded.expect("at least one MPI partitioning result"); + + Ok(folded) +} diff --git a/necsim/partitioning/mpi/src/partition/utils.rs b/necsim/partitioning/mpi/src/partition/utils.rs index 762e77e38..ef1400166 100644 --- a/necsim/partitioning/mpi/src/partition/utils.rs +++ b/necsim/partitioning/mpi/src/partition/utils.rs @@ -1,5 +1,4 @@ use std::{ - marker::PhantomData, mem::{offset_of, MaybeUninit}, os::raw::{c_int, c_void}, }; @@ -152,92 +151,3 @@ unsafe impl Equivalence for MpiMigratingLineage { ) } } - -pub fn reduce_partitioning_data< - T: serde::Serialize + serde::de::DeserializeOwned, - F: 'static + Copy + Fn(T, T) -> T, ->( - world: &SimpleCommunicator, - data: T, - fold: F, -) -> T { - let local_ser = postcard::to_stdvec(&data).expect("MPI data failed to serialize"); - let mut global_ser = Vec::with_capacity(local_ser.len()); - - let operation = - unsafe { UnsafeUserOperation::commutative(unsafe_reduce_partitioning_data_op::) }; - - world.all_reduce_into(local_ser.as_slice(), &mut global_ser, &operation); - - postcard::from_bytes(&global_ser).expect("MPI data failed to deserialize") -} - -#[cfg(not(all(msmpi, target_arch = "x86")))] -unsafe extern "C" fn unsafe_reduce_partitioning_data_op< - T: serde::Serialize + serde::de::DeserializeOwned, - F: 'static + Copy + Fn(T, T) -> T, ->( - invec: *mut c_void, - inoutvec: *mut c_void, - len: *mut c_int, - datatype: *mut MPI_Datatype, -) { - unsafe_reduce_partitioning_data_op_inner::(invec, inoutvec, len, datatype); -} - -#[cfg(all(msmpi, target_arch = "x86"))] -unsafe extern "stdcall" fn unsafe_reduce_partitioning_data_op< - T: serde::Serialize + serde::de::DeserializeOwned, - F: 'static + Copy + Fn(T, T) -> T, ->( - invec: *mut c_void, - inoutvec: *mut c_void, - len: *mut c_int, - datatype: *mut MPI_Datatype, -) { - unsafe_reduce_partitioning_data_op_inner::(invec, inoutvec, len, datatype); -} - -#[inline] -unsafe fn unsafe_reduce_partitioning_data_op_inner< - T: serde::Serialize + serde::de::DeserializeOwned, - F: 'static + Copy + Fn(T, T) -> T, ->( - invec: *mut c_void, - inoutvec: *mut c_void, - len: *mut c_int, - datatype: *mut MPI_Datatype, -) { - debug_assert!(*len == 1); - debug_assert!(*datatype == mpi::raw::AsRaw::as_raw(&TimeRank::equivalent_datatype())); - - reduce_partitioning_data_op_inner::(&*invec.cast(), &mut *inoutvec.cast()); -} - -#[inline] -fn reduce_partitioning_data_op_inner< - T: serde::Serialize + serde::de::DeserializeOwned, - F: 'static + Copy + Fn(T, T) -> T, ->( - local_ser: &[u8], - global_ser: &mut Vec, -) { - union Magic T> { - func: F, - unit: (), - marker: PhantomData, - } - - let local_de: T = postcard::from_bytes(local_ser).expect("MPI data failed to deserialize"); - let global_de: T = postcard::from_bytes(global_ser).expect("MPI data failed to deserialize"); - - const { assert!(std::mem::size_of::() == 0) }; - const { assert!(std::mem::align_of::() == 1) }; - let func: F = unsafe { Magic { unit: () }.func }; - - let folded = func(local_de, global_de); - - global_ser.clear(); - - postcard::to_io(&folded, global_ser).expect("MPI data failed to serialize"); -} diff --git a/necsim/partitioning/threads/src/lib.rs b/necsim/partitioning/threads/src/lib.rs index c8b67037e..435655576 100644 --- a/necsim/partitioning/threads/src/lib.rs +++ b/necsim/partitioning/threads/src/lib.rs @@ -260,7 +260,7 @@ impl Partitioning for ThreadsPartitioning { None => result, }); } - folded_result.expect("at least one thread partitioning result") + folded_result.expect("at least one threads partitioning result") }); Ok(result)