From c59397292490defabd1c305cb9ae27cb316f9a94 Mon Sep 17 00:00:00 2001 From: Juniper Tyree Date: Sun, 26 May 2024 19:53:33 +0000 Subject: [PATCH] Early progress on MPI result folding --- Cargo.lock | 24 +++++ necsim/partitioning/mpi/Cargo.toml | 1 + .../partitioning/mpi/src/partition/utils.rs | 90 +++++++++++++++++++ 3 files changed, 115 insertions(+) diff --git a/Cargo.lock b/Cargo.lock index 73ddd3aa5..d08ccb3e3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -379,6 +379,12 @@ version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "98cc8fbded0c607b7ba9dd60cd98df59af97e84d24e49c8557331cfc26d301ce" +[[package]] +name = "cobs" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67ba02a97a2bd10f4b59b25c7973101c79642302776489e030cd13cdab09ed15" + [[package]] name = "colorchoice" version = "1.0.1" @@ -641,6 +647,12 @@ version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3dca9240753cf90908d7e4aac30f630662b02aebaa1b58a3cadabdb23385b58b" +[[package]] +name = "embedded-io" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ef1a6892d9eef45c8fa6b9e0086428a2cca8491aca8f787c534a3d6d0bcb3ced" + [[package]] name = "equivalent" version = "1.0.1" @@ -1165,6 +1177,7 @@ dependencies = [ "necsim-core-bond", "necsim-impls-std", "necsim-partitioning-core", + "postcard", "serde", "serde_derive_state", "serde_state", @@ -1368,6 +1381,17 @@ dependencies = [ "array-init-cursor", ] +[[package]] +name = "postcard" +version = "1.0.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a55c51ee6c0db07e68448e336cf8ea4131a620edefebf9893e759b2d793420f8" +dependencies = [ + "cobs", + "embedded-io", + "serde", +] + [[package]] name = "ppv-lite86" version = "0.2.17" diff --git a/necsim/partitioning/mpi/Cargo.toml b/necsim/partitioning/mpi/Cargo.toml index fa2c13c2f..99fb836c5 100644 --- a/necsim/partitioning/mpi/Cargo.toml +++ b/necsim/partitioning/mpi/Cargo.toml @@ -21,6 +21,7 @@ serde = "1.0" serde_state = "0.4" serde_derive_state = "0.4" humantime-serde = "1.1" +postcard = { version = "1.0", default-features = false, features = ["use-std"] } [build-dependencies] build-probe-mpi = "0.1" diff --git a/necsim/partitioning/mpi/src/partition/utils.rs b/necsim/partitioning/mpi/src/partition/utils.rs index ef1400166..762e77e38 100644 --- a/necsim/partitioning/mpi/src/partition/utils.rs +++ b/necsim/partitioning/mpi/src/partition/utils.rs @@ -1,4 +1,5 @@ use std::{ + marker::PhantomData, mem::{offset_of, MaybeUninit}, os::raw::{c_int, c_void}, }; @@ -151,3 +152,92 @@ unsafe impl Equivalence for MpiMigratingLineage { ) } } + +pub fn reduce_partitioning_data< + T: serde::Serialize + serde::de::DeserializeOwned, + F: 'static + Copy + Fn(T, T) -> T, +>( + world: &SimpleCommunicator, + data: T, + fold: F, +) -> T { + let local_ser = postcard::to_stdvec(&data).expect("MPI data failed to serialize"); + let mut global_ser = Vec::with_capacity(local_ser.len()); + + let operation = + unsafe { UnsafeUserOperation::commutative(unsafe_reduce_partitioning_data_op::) }; + + world.all_reduce_into(local_ser.as_slice(), &mut global_ser, &operation); + + postcard::from_bytes(&global_ser).expect("MPI data failed to deserialize") +} + +#[cfg(not(all(msmpi, target_arch = "x86")))] +unsafe extern "C" fn unsafe_reduce_partitioning_data_op< + T: serde::Serialize + serde::de::DeserializeOwned, + F: 'static + Copy + Fn(T, T) -> T, +>( + invec: *mut c_void, + inoutvec: *mut c_void, + len: *mut c_int, + datatype: *mut MPI_Datatype, +) { + unsafe_reduce_partitioning_data_op_inner::(invec, inoutvec, len, datatype); +} + +#[cfg(all(msmpi, target_arch = "x86"))] +unsafe extern "stdcall" fn unsafe_reduce_partitioning_data_op< + T: serde::Serialize + serde::de::DeserializeOwned, + F: 'static + Copy + Fn(T, T) -> T, +>( + invec: *mut c_void, + inoutvec: *mut c_void, + len: *mut c_int, + datatype: *mut MPI_Datatype, +) { + unsafe_reduce_partitioning_data_op_inner::(invec, inoutvec, len, datatype); +} + +#[inline] +unsafe fn unsafe_reduce_partitioning_data_op_inner< + T: serde::Serialize + serde::de::DeserializeOwned, + F: 'static + Copy + Fn(T, T) -> T, +>( + invec: *mut c_void, + inoutvec: *mut c_void, + len: *mut c_int, + datatype: *mut MPI_Datatype, +) { + debug_assert!(*len == 1); + debug_assert!(*datatype == mpi::raw::AsRaw::as_raw(&TimeRank::equivalent_datatype())); + + reduce_partitioning_data_op_inner::(&*invec.cast(), &mut *inoutvec.cast()); +} + +#[inline] +fn reduce_partitioning_data_op_inner< + T: serde::Serialize + serde::de::DeserializeOwned, + F: 'static + Copy + Fn(T, T) -> T, +>( + local_ser: &[u8], + global_ser: &mut Vec, +) { + union Magic T> { + func: F, + unit: (), + marker: PhantomData, + } + + let local_de: T = postcard::from_bytes(local_ser).expect("MPI data failed to deserialize"); + let global_de: T = postcard::from_bytes(global_ser).expect("MPI data failed to deserialize"); + + const { assert!(std::mem::size_of::() == 0) }; + const { assert!(std::mem::align_of::() == 1) }; + let func: F = unsafe { Magic { unit: () }.func }; + + let folded = func(local_de, global_de); + + global_ser.clear(); + + postcard::to_io(&folded, global_ser).expect("MPI data failed to serialize"); +}