Implement MPI partitioning result reduction
juntyr committed May 27, 2024
1 parent c593972 commit 9ac90c4
Showing 3 changed files with 61 additions and 95 deletions.
64 changes: 60 additions & 4 deletions necsim/partitioning/mpi/src/lib.rs
@@ -8,9 +8,11 @@ use std::{fmt, mem::ManuallyDrop, num::NonZeroU32, time::Duration};
 use anyhow::Context;
 use humantime_serde::re::humantime::format_duration;
 use mpi::{
+    datatype::PartitionMut,
     environment::Universe,
     topology::{Communicator, Rank, SimpleCommunicator},
-    Tag,
+    traits::CommunicatorCollectives,
+    Count, Tag,
 };
 use serde::{ser::SerializeStruct, Deserialize, Serialize, Serializer};
 use serde_derive_state::DeserializeState;
@@ -182,8 +184,7 @@ impl Partitioning for MpiPartitioning {
         event_log: Self::Auxiliary,
         args: A,
         inner: for<'p> fn(Self::LocalPartition<'p, R>, A) -> Q,
-        // TODO: use fold to return the same result in all partitions, then deprecate
-        _fold: fn(Q, Q) -> Q,
+        fold: fn(Q, Q) -> Q,
     ) -> anyhow::Result<Q> {
         let Some(event_log) = event_log else {
             anyhow::bail!(MpiLocalPartitionError::MissingEventLog)
@@ -239,7 +240,9 @@
                 )))
             };
 
-            Ok(inner(local_partition, args))
+            let local_result = inner(local_partition, args);
+
+            reduce_partitioning_data(&self.world, local_result, fold)
         })
     }
 }
@@ -285,3 +288,56 @@ fn deserialize_state_mpi_world<'de, D: Deserializer<'de>>(
         ))),
     }
 }
+
+fn reduce_partitioning_data<T: serde::Serialize + serde::de::DeserializeOwned>(
+    world: &SimpleCommunicator,
+    data: T,
+    fold: fn(T, T) -> T,
+) -> anyhow::Result<T> {
+    let local_ser = postcard::to_stdvec(&data).context("MPI data failed to serialize")?;
+    std::mem::drop(data);
+
+    #[allow(clippy::cast_sign_loss)]
+    let mut counts = vec![0 as Count; world.size() as usize];
+    world.all_gather_into(&(Count::try_from(local_ser.len()).unwrap()), &mut counts);
+
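+    // Exclusive prefix sum of the per-rank byte counts: each rank's
+    // displacement into the gathered buffer of serialized results.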
+    let offsets = counts
+        .iter()
+        .scan(0 as Count, |acc, &x| {
+            let tmp = *acc;
+            *acc = (*acc).checked_add(x).unwrap();
+            Some(tmp)
+        })
+        .collect::<Vec<_>>();
+
+    #[allow(clippy::cast_sign_loss)]
+    let mut all_sers = vec![0_u8; counts.iter().copied().sum::<Count>() as usize];
+    world.all_gather_varcount_into(
+        local_ser.as_slice(),
+        &mut PartitionMut::new(all_sers.as_mut_slice(), counts.as_slice(), offsets),
+    );
+
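+    // Walk the gathered buffer rank by rank, deserializing each slice and
+    // folding the values in rank order; a deserialization error aborts early.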
+    let folded: Option<T> = counts
+        .iter()
+        .scan(0_usize, |acc, &x| {
+            let pre = *acc;
+            #[allow(clippy::cast_sign_loss)]
+            {
+                *acc += x as usize;
+            }
+            let post = *acc;
+
+            let de: anyhow::Result<T> = postcard::from_bytes(&all_sers[pre..post])
+                .context("MPI data failed to deserialize");
+
+            Some(de)
+        })
+        .try_fold(None, |acc, x| match (acc, x) {
+            (_, Err(err)) => Err(err),
+            (Some(acc), Ok(x)) => Ok(Some(fold(acc, x))),
+            (None, Ok(x)) => Ok(Some(x)),
+        })?;
+    let folded = folded.expect("at least one MPI partitioning result");
+
+    Ok(folded)
+}
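
For context, here is a minimal sketch of how this gather-then-fold reduction could be driven. It is a hypothetical standalone example, not part of the commit; it assumes the mpi and anyhow crates used above, and that reduce_partitioning_data were callable from outside this module:

use mpi::traits::Communicator;

fn main() -> anyhow::Result<()> {
    // Each rank contributes one value; after the reduction, every rank
    // holds the same folded result, so all partitions can return it.
    let universe = mpi::initialize().expect("failed to initialize MPI");
    let world = universe.world();

    let local = u64::try_from(world.rank()).unwrap();
    let total = reduce_partitioning_data(&world, local, |a, b| a + b)?;

    // With n ranks, every rank now sees 0 + 1 + ... + (n - 1).
    assert_eq!(total, (0..u64::try_from(world.size()).unwrap()).sum());
    Ok(())
}

Unlike the UnsafeUserOperation-based reduction removed below, this version gathers each rank's serialized bytes with a variable-count all-gather, so postcard encodings whose lengths differ across ranks are handled without running serde inside a raw MPI callback.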
90 changes: 0 additions & 90 deletions necsim/partitioning/mpi/src/partition/utils.rs
@@ -1,5 +1,4 @@
 use std::{
-    marker::PhantomData,
     mem::{offset_of, MaybeUninit},
     os::raw::{c_int, c_void},
 };
@@ -152,92 +151,3 @@ unsafe impl Equivalence for MpiMigratingLineage {
         )
     }
 }
-
-pub fn reduce_partitioning_data<
-    T: serde::Serialize + serde::de::DeserializeOwned,
-    F: 'static + Copy + Fn(T, T) -> T,
->(
-    world: &SimpleCommunicator,
-    data: T,
-    fold: F,
-) -> T {
-    let local_ser = postcard::to_stdvec(&data).expect("MPI data failed to serialize");
-    let mut global_ser = Vec::with_capacity(local_ser.len());
-
-    let operation =
-        unsafe { UnsafeUserOperation::commutative(unsafe_reduce_partitioning_data_op::<T, F>) };
-
-    world.all_reduce_into(local_ser.as_slice(), &mut global_ser, &operation);
-
-    postcard::from_bytes(&global_ser).expect("MPI data failed to deserialize")
-}
-
-#[cfg(not(all(msmpi, target_arch = "x86")))]
-unsafe extern "C" fn unsafe_reduce_partitioning_data_op<
-    T: serde::Serialize + serde::de::DeserializeOwned,
-    F: 'static + Copy + Fn(T, T) -> T,
->(
-    invec: *mut c_void,
-    inoutvec: *mut c_void,
-    len: *mut c_int,
-    datatype: *mut MPI_Datatype,
-) {
-    unsafe_reduce_partitioning_data_op_inner::<T, F>(invec, inoutvec, len, datatype);
-}
-
-#[cfg(all(msmpi, target_arch = "x86"))]
-unsafe extern "stdcall" fn unsafe_reduce_partitioning_data_op<
-    T: serde::Serialize + serde::de::DeserializeOwned,
-    F: 'static + Copy + Fn(T, T) -> T,
->(
-    invec: *mut c_void,
-    inoutvec: *mut c_void,
-    len: *mut c_int,
-    datatype: *mut MPI_Datatype,
-) {
-    unsafe_reduce_partitioning_data_op_inner::<T, F>(invec, inoutvec, len, datatype);
-}
-
-#[inline]
-unsafe fn unsafe_reduce_partitioning_data_op_inner<
-    T: serde::Serialize + serde::de::DeserializeOwned,
-    F: 'static + Copy + Fn(T, T) -> T,
->(
-    invec: *mut c_void,
-    inoutvec: *mut c_void,
-    len: *mut c_int,
-    datatype: *mut MPI_Datatype,
-) {
-    debug_assert!(*len == 1);
-    debug_assert!(*datatype == mpi::raw::AsRaw::as_raw(&TimeRank::equivalent_datatype()));
-
-    reduce_partitioning_data_op_inner::<T, F>(&*invec.cast(), &mut *inoutvec.cast());
-}
-
-#[inline]
-fn reduce_partitioning_data_op_inner<
-    T: serde::Serialize + serde::de::DeserializeOwned,
-    F: 'static + Copy + Fn(T, T) -> T,
->(
-    local_ser: &[u8],
-    global_ser: &mut Vec<u8>,
-) {
-    union Magic<T, F: 'static + Copy + Fn(T, T) -> T> {
-        func: F,
-        unit: (),
-        marker: PhantomData<T>,
-    }
-
-    let local_de: T = postcard::from_bytes(local_ser).expect("MPI data failed to deserialize");
-    let global_de: T = postcard::from_bytes(global_ser).expect("MPI data failed to deserialize");
-
-    const { assert!(std::mem::size_of::<F>() == 0) };
-    const { assert!(std::mem::align_of::<F>() == 1) };
-    let func: F = unsafe { Magic { unit: () }.func };
-
-    let folded = func(local_de, global_de);
-
-    global_ser.clear();
-
-    postcard::to_io(&folded, global_ser).expect("MPI data failed to serialize");
-}
2 changes: 1 addition & 1 deletion necsim/partitioning/threads/src/lib.rs
@@ -260,7 +260,7 @@ impl Partitioning for ThreadsPartitioning {
                     None => result,
                 });
             }
-            folded_result.expect("at least one thread partitioning result")
+            folded_result.expect("at least one threads partitioning result")
         });
 
         Ok(result)
