Skip to content

Commit

Permalink
Early progress on MPI result folding
Browse files Browse the repository at this point in the history
  • Loading branch information
juntyr committed May 26, 2024
1 parent 84d133c commit c593972
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 0 deletions.
24 changes: 24 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions necsim/partitioning/mpi/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ serde = "1.0"
serde_state = "0.4"
serde_derive_state = "0.4"
humantime-serde = "1.1"
postcard = { version = "1.0", default-features = false, features = ["use-std"] }

[build-dependencies]
build-probe-mpi = "0.1"
90 changes: 90 additions & 0 deletions necsim/partitioning/mpi/src/partition/utils.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::{
marker::PhantomData,
mem::{offset_of, MaybeUninit},
os::raw::{c_int, c_void},
};
Expand Down Expand Up @@ -151,3 +152,92 @@ unsafe impl Equivalence for MpiMigratingLineage {
)
}
}

pub fn reduce_partitioning_data<
T: serde::Serialize + serde::de::DeserializeOwned,
F: 'static + Copy + Fn(T, T) -> T,
>(
world: &SimpleCommunicator,
data: T,
fold: F,
) -> T {
let local_ser = postcard::to_stdvec(&data).expect("MPI data failed to serialize");
let mut global_ser = Vec::with_capacity(local_ser.len());

let operation =
unsafe { UnsafeUserOperation::commutative(unsafe_reduce_partitioning_data_op::<T, F>) };

world.all_reduce_into(local_ser.as_slice(), &mut global_ser, &operation);

postcard::from_bytes(&global_ser).expect("MPI data failed to deserialize")
}

#[cfg(not(all(msmpi, target_arch = "x86")))]
unsafe extern "C" fn unsafe_reduce_partitioning_data_op<
T: serde::Serialize + serde::de::DeserializeOwned,
F: 'static + Copy + Fn(T, T) -> T,
>(
invec: *mut c_void,
inoutvec: *mut c_void,
len: *mut c_int,
datatype: *mut MPI_Datatype,
) {
unsafe_reduce_partitioning_data_op_inner::<T, F>(invec, inoutvec, len, datatype);
}

#[cfg(all(msmpi, target_arch = "x86"))]
unsafe extern "stdcall" fn unsafe_reduce_partitioning_data_op<
T: serde::Serialize + serde::de::DeserializeOwned,
F: 'static + Copy + Fn(T, T) -> T,
>(
invec: *mut c_void,
inoutvec: *mut c_void,
len: *mut c_int,
datatype: *mut MPI_Datatype,
) {
unsafe_reduce_partitioning_data_op_inner::<T, F>(invec, inoutvec, len, datatype);
}

#[inline]
unsafe fn unsafe_reduce_partitioning_data_op_inner<
T: serde::Serialize + serde::de::DeserializeOwned,
F: 'static + Copy + Fn(T, T) -> T,
>(
invec: *mut c_void,
inoutvec: *mut c_void,
len: *mut c_int,
datatype: *mut MPI_Datatype,
) {
debug_assert!(*len == 1);
debug_assert!(*datatype == mpi::raw::AsRaw::as_raw(&TimeRank::equivalent_datatype()));

reduce_partitioning_data_op_inner::<T, F>(&*invec.cast(), &mut *inoutvec.cast());
}

#[inline]
fn reduce_partitioning_data_op_inner<
T: serde::Serialize + serde::de::DeserializeOwned,
F: 'static + Copy + Fn(T, T) -> T,
>(
local_ser: &[u8],
global_ser: &mut Vec<u8>,
) {
union Magic<T, F: 'static + Copy + Fn(T, T) -> T> {
func: F,
unit: (),
marker: PhantomData<T>,
}

let local_de: T = postcard::from_bytes(local_ser).expect("MPI data failed to deserialize");
let global_de: T = postcard::from_bytes(global_ser).expect("MPI data failed to deserialize");

const { assert!(std::mem::size_of::<F>() == 0) };
const { assert!(std::mem::align_of::<F>() == 1) };
let func: F = unsafe { Magic { unit: () }.func };

let folded = func(local_de, global_de);

global_ser.clear();

postcard::to_io(&folded, global_ser).expect("MPI data failed to serialize");
}

0 comments on commit c593972

Please sign in to comment.