From cbb807d8f2f555371e3560e4951e206c31dc71db Mon Sep 17 00:00:00 2001
From: Momo Langenstein
Date: Sat, 13 Nov 2021 21:40:50 +0000
Subject: [PATCH] (ml5717) Untested impl of more robust index sampling

---
Notes (review; text between the three-dash line and the first file diff is
commentary and is not part of the change):
  * This copy was recovered from a whitespace-collapsed dump. Line breaks,
    blank context lines and indentation were re-created so that every hunk
    body matches the line counts in its "@@" header and the diffstat below.
    Indentation follows rustfmt defaults and is a best guess.
  * One "+" marker was missing for the blank line after
    `let sample_check_lo = ...;` in `sample_index_u32`; it is restored here
    (the rng.rs sampling hunk needs 94 added lines).
  * The author address on "From:" and the generic parameter lists in the "@@"
    section headings are absent in the dump and are left as found.
  * `sample_index` only provides bodies for `target_pointer_width` "32" and
    "64"; no other pointer width has a branch.
  * The two `unsafe { NonZero*::new_unchecked(..) }` calls in `sample_index`
    pass a value read from a `NonZeroUsize` and cast to the integer type of
    the selected pointer width, so it stays non-zero; neither call carries a
    `// SAFETY:` comment.
  * `sample_index_u32`, `sample_index_u64` and `sample_index_u128` replace the
    f64 multiply-and-floor sampling with a widening-multiply rejection loop;
    the code comments cite rand 0.8.4 `UniformSampler::sample_single` as the
    origin, and the subject line marks the implementation as untested.

 necsim/core/src/cogs/rng.rs                   | 130 +++++++++++++-----
 .../alias/dynamic/indexed/mod.rs              |   5 +-
 .../alias/dynamic/stack/mod.rs                |   5 +-
 3 files changed, 104 insertions(+), 36 deletions(-)

diff --git a/necsim/core/src/cogs/rng.rs b/necsim/core/src/cogs/rng.rs
index 6498c8a1f..1e03f754a 100644
--- a/necsim/core/src/cogs/rng.rs
+++ b/necsim/core/src/cogs/rng.rs
@@ -1,7 +1,7 @@
 use core::{
     convert::AsMut,
     default::Default,
-    num::{NonZeroU128, NonZeroU32, NonZeroUsize},
+    num::{NonZeroU128, NonZeroU32, NonZeroU64, NonZeroUsize},
     ptr::copy_nonoverlapping,
 };
 
@@ -41,6 +41,7 @@ pub trait SeedableRng: RngCore {
         const INC: u64 = 11_634_580_027_462_260_723_u64;
 
         let mut seed = Self::Seed::default();
+
         for chunk in seed.as_mut().chunks_mut(4) {
             // We advance the state first (to get away from the input value,
             // in case it has low Hamming Weight).
@@ -96,51 +97,112 @@ pub trait RngSampler: RngCore {
     #[inline]
     #[debug_ensures(ret < length.get(), "samples U(0, length - 1)")]
     fn sample_index(&mut self, length: NonZeroUsize) -> usize {
-        // attributes on expressions are experimental
-        // see https://github.com/rust-lang/rust/issues/15701
-        #[allow(
-            clippy::cast_precision_loss,
-            clippy::cast_possible_truncation,
-            clippy::cast_sign_loss
-        )]
-        let index =
-            M::floor(self.sample_uniform_closed_open().get() * (length.get() as f64)) as usize;
-        // Safety in case of f64 rounding errors
-        index.min(length.get() - 1)
+        #[cfg(target_pointer_width = "32")]
+        #[allow(clippy::cast_possible_truncation)]
+        {
+            self.sample_index_u32(unsafe { NonZeroU32::new_unchecked(length.get() as u32) })
+                as usize
+        }
+        #[cfg(target_pointer_width = "64")]
+        #[allow(clippy::cast_possible_truncation)]
+        {
+            self.sample_index_u64(unsafe { NonZeroU64::new_unchecked(length.get() as u64) })
+                as usize
+        }
     }
 
     #[must_use]
     #[inline]
     #[debug_ensures(ret < length.get(), "samples U(0, length - 1)")]
     fn sample_index_u32(&mut self, length: NonZeroU32) -> u32 {
-        // attributes on expressions are experimental
-        // see https://github.com/rust-lang/rust/issues/15701
-        #[allow(
-            clippy::cast_precision_loss,
-            clippy::cast_possible_truncation,
-            clippy::cast_sign_loss
-        )]
-        let index =
-            M::floor(self.sample_uniform_closed_open().get() * f64::from(length.get())) as u32;
-        // Safety in case of f64 rounding errors
-        index.min(length.get() - 1)
+        // TODO: Check if delegation to `sample_index_u64` is faster
+
+        // Adapted from:
+        // https://docs.rs/rand/0.8.4/rand/distributions/uniform/trait.UniformSampler.html#method.sample_single
+
+        const LOWER_MASK: u64 = !0 >> 32;
+
+        // Conservative approximation of the acceptance zone
+        let acceptance_zone = (length.get() << length.leading_zeros()).wrapping_sub(1);
+
+        loop {
+            let raw = self.sample_u64();
+
+            let sample_check_lo = (raw & LOWER_MASK) * u64::from(length.get());
+
+            #[allow(clippy::cast_possible_truncation)]
+            if (sample_check_lo as u32) <= acceptance_zone {
+                return (sample_check_lo >> 32) as u32;
+            }
+
+            let sample_check_hi = (raw >> 32) * u64::from(length.get());
+
+            #[allow(clippy::cast_possible_truncation)]
+            if (sample_check_hi as u32) <= acceptance_zone {
+                return (sample_check_hi >> 32) as u32;
+            }
+        }
+    }
+
+    #[must_use]
+    #[inline]
+    #[debug_ensures(ret < length.get(), "samples U(0, length - 1)")]
+    fn sample_index_u64(&mut self, length: NonZeroU64) -> u64 {
+        // Adapted from:
+        // https://docs.rs/rand/0.8.4/rand/distributions/uniform/trait.UniformSampler.html#method.sample_single
+
+        // Conservative approximation of the acceptance zone
+        let acceptance_zone = (length.get() << length.leading_zeros()).wrapping_sub(1);
+
+        loop {
+            let raw = self.sample_u64();
+
+            let sample_check = u128::from(raw) * u128::from(length.get());
+
+            #[allow(clippy::cast_possible_truncation)]
+            if (sample_check as u64) <= acceptance_zone {
+                return (sample_check >> 64) as u64;
+            }
+        }
     }
 
     #[must_use]
     #[inline]
     #[debug_ensures(ret < length.get(), "samples U(0, length - 1)")]
     fn sample_index_u128(&mut self, length: NonZeroU128) -> u128 {
-        // attributes on expressions are experimental
-        // see https://github.com/rust-lang/rust/issues/15701
-        #[allow(
-            clippy::cast_precision_loss,
-            clippy::cast_possible_truncation,
-            clippy::cast_sign_loss
-        )]
-        let index =
-            M::floor(self.sample_uniform_closed_open().get() * (length.get() as f64)) as u128;
-        // Safety in case of f64 rounding errors
-        index.min(length.get() - 1)
+        // Adapted from:
+        // https://docs.rs/rand/0.8.4/rand/distributions/uniform/trait.UniformSampler.html#method.sample_single
+
+        const LOWER_MASK: u128 = !0 >> 64;
+
+        // Conservative approximation of the acceptance zone
+        let acceptance_zone = (length.get() << length.leading_zeros()).wrapping_sub(1);
+
+        loop {
+            let raw_hi = u128::from(self.sample_u64());
+            let raw_lo = u128::from(self.sample_u64());
+
+            // 256-bit multiplication (hi, lo) = (raw_hi, raw_lo) * length
+            let mut low = raw_lo * (length.get() & LOWER_MASK);
+            let mut t = low >> 64;
+            low &= LOWER_MASK;
+            t += raw_hi * (length.get() & LOWER_MASK);
+            low += (t & LOWER_MASK) << 64;
+            let mut high = t >> 64;
+            t = low >> 64;
+            low &= LOWER_MASK;
+            t += (length.get() >> 64) * raw_lo;
+            low += (t & LOWER_MASK) << 64;
+            high += t >> 64;
+            high += raw_hi * (length.get() >> 64);
+
+            let sample = high;
+            let check = low;
+
+            if check <= acceptance_zone {
+                return sample;
+            }
+        }
     }
 
     #[must_use]
diff --git a/necsim/impls/no-std/src/cogs/active_lineage_sampler/alias/dynamic/indexed/mod.rs b/necsim/impls/no-std/src/cogs/active_lineage_sampler/alias/dynamic/indexed/mod.rs
index c978f42a6..bb86dc39c 100644
--- a/necsim/impls/no-std/src/cogs/active_lineage_sampler/alias/dynamic/indexed/mod.rs
+++ b/necsim/impls/no-std/src/cogs/active_lineage_sampler/alias/dynamic/indexed/mod.rs
@@ -1,9 +1,10 @@
 use alloc::{vec, vec::Vec};
 use core::{
     cmp::Ordering,
+    convert::TryFrom,
     fmt,
     hash::Hash,
-    num::{NonZeroU128, NonZeroUsize},
+    num::{NonZeroU128, NonZeroU64, NonZeroUsize},
 };
 
 use fnv::FnvBuildHasher;
@@ -191,6 +192,8 @@ impl DynamicAliasMethodIndexedSampler {
         if let Some(total_weight) = NonZeroU128::new(self.total_weight) {
             let cdf_sample = if let [_group] = &self.groups[..] {
                 0_u128
+            } else if let Ok(total_weight) = NonZeroU64::try_from(total_weight) {
+                u128::from(rng.sample_index_u64(total_weight))
             } else {
                 rng.sample_index_u128(total_weight)
             };
diff --git a/necsim/impls/no-std/src/cogs/active_lineage_sampler/alias/dynamic/stack/mod.rs b/necsim/impls/no-std/src/cogs/active_lineage_sampler/alias/dynamic/stack/mod.rs
index bf52c1a07..01d24b002 100644
--- a/necsim/impls/no-std/src/cogs/active_lineage_sampler/alias/dynamic/stack/mod.rs
+++ b/necsim/impls/no-std/src/cogs/active_lineage_sampler/alias/dynamic/stack/mod.rs
@@ -1,9 +1,10 @@
 use alloc::{vec, vec::Vec};
 use core::{
     cmp::Ordering,
+    convert::TryFrom,
     fmt,
     hash::Hash,
-    num::{NonZeroU128, NonZeroUsize},
+    num::{NonZeroU128, NonZeroU64, NonZeroUsize},
 };
 
 use necsim_core::cogs::{MathsCore, RngCore, RngSampler};
@@ -125,6 +126,8 @@ impl DynamicAliasMethodStackSampler {
         if let Some(total_weight) = NonZeroU128::new(self.total_weight) {
             let cdf_sample = if let [_group] = &self.groups[..] {
                 0_u128
+            } else if let Ok(total_weight) = NonZeroU64::try_from(total_weight) {
+                u128::from(rng.sample_index_u64(total_weight))
             } else {
                 rng.sample_index_u128(total_weight)
             };