diff --git a/ipa-core/benches/oneshot/ipa.rs b/ipa-core/benches/oneshot/ipa.rs index cc0ac25bf2..ca6904fe94 100644 --- a/ipa-core/benches/oneshot/ipa.rs +++ b/ipa-core/benches/oneshot/ipa.rs @@ -130,7 +130,7 @@ async fn run(args: Args) -> Result<(), Error> { args.query_size, ) }; - let mut raw_data = EventGenerator::with_config( + let raw_data = EventGenerator::with_config( rng, EventGeneratorConfig { user_count, @@ -143,9 +143,6 @@ async fn run(args: Args) -> Result<(), Error> { ) .take(query_size) .collect::<Vec<_>>(); - // EventGenerator produces events in random order, but IPA requires them to be sorted by - // timestamp. - raw_data.sort_by_key(|e| e.timestamp); let order = CappingOrder::CapMostRecentFirst; diff --git a/ipa-core/src/bin/report_collector.rs b/ipa-core/src/bin/report_collector.rs index d0e7deb13f..496c737167 100644 --- a/ipa-core/src/bin/report_collector.rs +++ b/ipa-core/src/bin/report_collector.rs @@ -155,10 +155,9 @@ fn gen_inputs( let rng = seed .map(StdRng::seed_from_u64) .unwrap_or_else(StdRng::from_entropy); - let mut event_gen = EventGenerator::with_config(rng, args) + let event_gen = EventGenerator::with_config(rng, args) .take(count as usize) .collect::<Vec<_>>(); - event_gen.sort_by_key(|e| e.timestamp); let mut writer: Box<dyn Write> = if let Some(path) = output_file { Box::new(OpenOptions::new().write(true).create_new(true).open(path)?) 
} else { diff --git a/ipa-core/src/protocol/ipa_prf/mod.rs b/ipa-core/src/protocol/ipa_prf/mod.rs index aa5628c708..6616283d87 100644 --- a/ipa-core/src/protocol/ipa_prf/mod.rs +++ b/ipa-core/src/protocol/ipa_prf/mod.rs @@ -313,6 +313,23 @@ pub mod tests { test_executor::run, test_fixture::{ipa::TestRawDataRecord, Reconstruct, Runner, TestWorld}, }; + use rand::{seq::SliceRandom, thread_rng}; + + fn test_input( + timestamp: u64, + user_id: u64, + is_trigger_report: bool, + breakdown_key: u32, + trigger_value: u32, + ) -> TestRawDataRecord { + TestRawDataRecord { + timestamp, + user_id, + is_trigger_report, + breakdown_key, + trigger_value, + } + } #[test] fn semi_honest() { @@ -322,41 +339,11 @@ pub mod tests { let world = TestWorld::default(); let records: Vec<TestRawDataRecord> = vec![ - TestRawDataRecord { - timestamp: 0, - user_id: 12345, - is_trigger_report: false, - breakdown_key: 1, - trigger_value: 0, - }, - TestRawDataRecord { - timestamp: 5, - user_id: 12345, - is_trigger_report: false, - breakdown_key: 2, - trigger_value: 0, - }, - TestRawDataRecord { - timestamp: 10, - user_id: 12345, - is_trigger_report: true, - breakdown_key: 0, - trigger_value: 5, - }, - TestRawDataRecord { - timestamp: 0, - user_id: 68362, - is_trigger_report: false, - breakdown_key: 1, - trigger_value: 0, - }, - TestRawDataRecord { - timestamp: 20, - user_id: 68362, - is_trigger_report: true, - breakdown_key: 0, - trigger_value: 2, - }, + test_input(0, 12345, false, 1, 0), + test_input(5, 12345, false, 2, 0), + test_input(10, 12345, true, 0, 5), + test_input(0, 68362, false, 1, 0), + test_input(20, 68362, true, 0, 2), ]; let mut result: Vec<_> = world @@ -377,4 +364,46 @@ pub mod tests { ); }); } + + // Test that IPA tolerates duplicate timestamps among a user's records. The end-to-end test + // harness does not generate data like this because the attribution result is non-deterministic. + // To make the output deterministic for this case, all of the duplicate timestamp records are + // identical. 
+ #[test] + fn duplicate_timestamps() { + const EXPECTED: &[u128] = &[0, 2, 10, 0, 0, 0, 0, 0]; + + run(|| async { + let world = TestWorld::default(); + + let mut records: Vec<TestRawDataRecord> = vec![ + test_input(0, 12345, false, 1, 0), + test_input(5, 12345, false, 2, 0), + test_input(5, 12345, false, 2, 0), + test_input(10, 12345, true, 0, 5), + test_input(10, 12345, true, 0, 5), + test_input(0, 68362, false, 1, 0), + test_input(20, 68362, true, 0, 2), + ]; + + records.shuffle(&mut thread_rng()); + + let mut result: Vec<_> = world + .semi_honest(records.into_iter(), |ctx, input_rows| async move { + oprf_ipa::<_, BA8, BA3, BA20, BA5, Fp31>(ctx, input_rows, None) + .await + .unwrap() + }) + .await + .reconstruct(); + result.truncate(EXPECTED.len()); + assert_eq!( + result, + EXPECTED + .iter() + .map(|i| Fp31::try_from(*i).unwrap()) + .collect::<Vec<_>>() + ); + }); + } } diff --git a/ipa-core/src/test_fixture/event_gen.rs b/ipa-core/src/test_fixture/event_gen.rs index 33afabf32d..0b22a37b6b 100644 --- a/ipa-core/src/test_fixture/event_gen.rs +++ b/ipa-core/src/test_fixture/event_gen.rs @@ -1,3 +1,10 @@ +use std::{ + collections::HashSet, + num::{NonZeroU32, NonZeroU64}, +}; + +use crate::{rand::Rng, test_fixture::ipa::TestRawDataRecord}; + #[derive(Copy, Clone, Hash, Ord, PartialOrd, Eq, PartialEq)] struct UserId(u64); @@ -26,6 +33,10 @@ impl UserId { pub const FIRST: Self = Self(1); } +// 7 days = 604800 seconds fits in 20 bits +pub type Timestamp = u32; +pub type NonZeroTimestamp = NonZeroU32; + #[derive(Debug, Copy, Clone)] #[cfg_attr(feature = "clap", derive(clap::ValueEnum))] pub enum ReportFilter { @@ -46,8 +57,7 @@ pub struct Config { #[cfg_attr(feature = "clap", arg(long, default_value = "20"))] pub max_breakdown_key: NonZeroU32, #[cfg_attr(feature = "clap", arg(long, hide = true, default_value = "604800"))] - // 7 days < 20 bits - pub max_timestamp: NonZeroU32, + pub max_timestamp: NonZeroTimestamp, #[cfg_attr(feature = "clap", arg(long, default_value = "10"))] pub 
max_events_per_user: NonZeroU32, #[cfg_attr(feature = "clap", arg(long, default_value = "1"))] @@ -89,14 +99,14 @@ impl Config { max_breakdown_key: u32, min_events_per_user: u32, max_events_per_user: u32, - max_timestamp: u32, + max_timestamp: Timestamp, ) -> Self { assert!(min_events_per_user < max_events_per_user); Self { user_count: NonZeroU64::try_from(user_count).unwrap(), max_trigger_value: NonZeroU32::try_from(max_trigger_value).unwrap(), max_breakdown_key: NonZeroU32::try_from(max_breakdown_key).unwrap(), - max_timestamp: NonZeroU32::try_from(max_timestamp).unwrap(), + max_timestamp: NonZeroTimestamp::try_from(max_timestamp).unwrap(), min_events_per_user: NonZeroU32::try_from(min_events_per_user).unwrap(), max_events_per_user: NonZeroU32::try_from(max_events_per_user).unwrap(), report_filter: ReportFilter::All, @@ -111,17 +121,11 @@ impl Config { } } -use std::{ - collections::HashSet, - num::{NonZeroU32, NonZeroU64}, -}; - -use crate::{rand::Rng, test_fixture::ipa::TestRawDataRecord}; - struct UserStats { user_id: UserId, generated: u32, max: u32, + used_timestamps: HashSet<Timestamp>, } impl UserStats { @@ -130,6 +134,7 @@ impl UserStats { user_id, generated: 0, max: max_events, + used_timestamps: HashSet::new(), } } @@ -152,8 +157,7 @@ pub struct EventGenerator { config: Config, rng: R, users: Vec<UserStats>, - // even bit vector takes too long to initialize. Need a sparse structure here - used: HashSet<UserId>, + used_ids: HashSet<UserId>, } impl<R: Rng> EventGenerator<R> { @@ -166,15 +170,21 @@ impl EventGenerator { config, rng, users: vec![], - used: HashSet::new(), + used_ids: HashSet::new(), } } - fn gen_event(&mut self, user_id: UserId) -> TestRawDataRecord { - // Generate a new random timestamp between [0..`max_timestamp`). - // This means the generated events must be sorted by timestamp before being - // fed into the IPA protocols. 
- let current_ts = self.rng.gen_range(0..self.config.max_timestamp.get()); + fn gen_event(&mut self, idx: usize) -> TestRawDataRecord { + let user_id = self.users[idx].user_id; + + // Generate a new random timestamp between [0..`max_timestamp`) and distinct from + // already-used timestamps. + let current_ts = loop { + let ts = self.rng.gen_range(0..self.config.max_timestamp.get()); + if self.users[idx].used_timestamps.insert(ts) { + break ts; + } + }; match self.config.report_filter { ReportFilter::All => { @@ -198,7 +208,7 @@ impl EventGenerator { } } - fn gen_trigger(&mut self, user_id: UserId, timestamp: u32) -> TestRawDataRecord { + fn gen_trigger(&mut self, user_id: UserId, timestamp: Timestamp) -> TestRawDataRecord { let trigger_value = self.rng.gen_range(1..self.config.max_trigger_value.get()); TestRawDataRecord { @@ -210,7 +220,7 @@ impl EventGenerator { } } - fn gen_source(&mut self, user_id: UserId, timestamp: u32) -> TestRawDataRecord { + fn gen_source(&mut self, user_id: UserId, timestamp: Timestamp) -> TestRawDataRecord { let breakdown_key = self.rng.gen_range(0..self.config.max_breakdown_key.get()); TestRawDataRecord { @@ -223,28 +233,28 @@ impl EventGenerator { } fn sample_user(&mut self) -> Option<UserStats> { - if self.used.len() == self.config.user_count() { + if self.used_ids.len() == self.config.user_count() { return None; } - let valid = |user_id| -> bool { !self.used.contains(&user_id) }; - - Some(loop { - let next = UserId::from( + loop { + let user_id = UserId::from( self.rng .gen_range(UserId::FIRST.into()..=self.config.user_count.get()), ); - if valid(next) { - self.used.insert(next); - break UserStats::new( - next, - self.rng.gen_range( - self.config.min_events_per_user.get() - ..=self.config.max_events_per_user.get(), - ), - ); + if self.used_ids.contains(&user_id) { + continue; } - }) + self.used_ids.insert(user_id); + + break Some(UserStats::new( + user_id, + self.rng.gen_range( + self.config.min_events_per_user.get() + 
..=self.config.max_events_per_user.get(), + ), + )); + } } } @@ -266,12 +276,13 @@ impl Iterator for EventGenerator { } let idx = self.rng.gen_range(0..self.users.len()); - let user_id = self.users[idx].user_id; + let event = self.gen_event(idx); + if self.users[idx].add_one() { self.users.swap_remove(idx); } - Some(self.gen_event(user_id)) + Some(event) } } @@ -366,6 +377,12 @@ mod tests { "Found source report with trigger value set" ); } + + assert!( + event.timestamp < u64::from(self.max_timestamp.get()), + "Timestamp should not exceed configured maximum", + ) + } } @@ -392,7 +409,7 @@ mod tests { user_count: NonZeroU64::new(10_000).unwrap(), max_trigger_value: NonZeroU32::new(max_trigger_value).unwrap(), max_breakdown_key: NonZeroU32::new(max_breakdown_key).unwrap(), - max_timestamp: NonZeroU32::new(604_800).unwrap(), + max_timestamp: NonZeroTimestamp::new(604_800).unwrap(), min_events_per_user: NonZeroU32::new(min_events_per_user).unwrap(), max_events_per_user: NonZeroU32::new(max_events_per_user).unwrap(), report_filter, @@ -423,9 +440,7 @@ mod tests { "Generated breakdown key greater than {max_breakdown}" ); - // Basic correctness checks. timestamps are not checked as the order of events - // is not guaranteed. The caller must sort the events by timestamp before - // feeding them into IPA. + // Basic correctness checks. config.is_valid(&event); } } diff --git a/ipa-core/src/test_fixture/ipa.rs b/ipa-core/src/test_fixture/ipa.rs index b5d608edd3..89a08b72c0 100644 --- a/ipa-core/src/test_fixture/ipa.rs +++ b/ipa-core/src/test_fixture/ipa.rs @@ -1,4 +1,6 @@ -use std::{collections::HashMap, num::NonZeroU32, ops::Deref}; +use std::{collections::HashMap, num::NonZeroU32}; + +use rand::{thread_rng, Rng}; use crate::protocol::ipa_prf::prf_sharding::GroupingKey; #[cfg(feature = "in-memory-infra")] @@ -41,13 +43,31 @@ impl GroupingKey for TestRawDataRecord { } } +/// Insert `record` into `user_records`, maintaining timestamp order. 
+/// +/// If there are existing records with the same timestamp, inserts the new record +/// randomly in any position that maintains timestamp order. +fn insert_sorted(user_records: &mut Vec<TestRawDataRecord>, record: TestRawDataRecord) { + let upper = user_records.partition_point(|rec| rec.timestamp <= record.timestamp); + if upper > 0 && user_records[upper - 1].timestamp == record.timestamp { + let lower = user_records[0..upper - 1] + .iter() + .rposition(|rec| rec.timestamp < record.timestamp) + .map_or(0, |lower| lower + 1); + user_records.insert(thread_rng().gen_range(lower..=upper), record); + } else { + user_records.insert(upper, record); + } +} + /// Executes IPA protocol in the clear, that is without any MPC helpers involved in the computation. /// Useful to validate that MPC output makes sense by comparing the breakdowns produced by MPC IPA /// with this function's results. Note that MPC version of IPA may apply DP noise to the aggregates, /// so strict equality may not work. /// -/// This function requires input to be sorted by the timestamp and returns a vector of contributions -/// sorted by the breakdown key. +/// Just like the MPC implementation, if the input contains records with duplicate timestamps, the +/// order those records are considered by the attribution algorithm is undefined, and the output +/// may be non-deterministic. /// /// ## Panics /// Will panic if you run in on Intel 80286 or any other 16 bit hardware. @@ -58,33 +78,20 @@ pub fn ipa_in_the_clear( max_breakdown: u32, order: &CappingOrder, ) -> Vec<u32> { - // build a view that is convenient for attribution. match key -> events sorted by timestamp in reverse + // build a view that is convenient for attribution. match key -> events sorted by timestamp // that is more memory intensive, but should be faster to compute. 
We can always opt-out and // execute IPA in place let mut user_events = HashMap::new(); - let mut last_ts = 0; for row in input { - if cfg!(debug_assertions) { - assert!( - last_ts <= row.timestamp, - "Input is not sorted: last row had timestamp {last_ts} that is greater than \ - {this_ts} timestamp of the current row", - this_ts = row.timestamp - ); - last_ts = row.timestamp; - } - - user_events - .entry(row.user_id) - .or_insert_with(Vec::new) - .push(row); + insert_sorted( + user_events.entry(row.user_id).or_insert_with(Vec::new), + row.clone(), + ); } let mut breakdowns = vec![0u32; usize::try_from(max_breakdown).unwrap()]; for records_per_user in user_events.values() { - // it works because input is sorted and vectors preserve the insertion order - // so records in `rev` are returned in reverse chronological order - let rev_records = records_per_user.iter().rev().map(Deref::deref); + let rev_records = records_per_user.iter().rev(); update_expected_output_for_user( rev_records, &mut breakdowns, @@ -233,3 +240,82 @@ pub async fn test_oprf_ipa( let _ = result.split_off(expected_results.len()); assert_eq!(result, expected_results); } + +#[cfg(all(test, unit_test))] +mod tests { + use super::*; + + fn insert_sorted_test<I: IntoIterator<Item = u64>>(iter: I) -> Vec<TestRawDataRecord> { + fn test_record(timestamp: u64, breakdown_key: u32) -> TestRawDataRecord { + TestRawDataRecord { + timestamp, + user_id: 0, + is_trigger_report: false, + breakdown_key, + trigger_value: 0, + } + } + + let mut expected = Vec::new(); + let mut actual = Vec::new(); + for (i, v) in iter.into_iter().enumerate() { + expected.push(v); + super::insert_sorted(&mut actual, test_record(v, u32::try_from(i).unwrap())); + } + expected.sort(); + assert_eq!(expected, actual.iter().map(|rec| rec.timestamp).collect::<Vec<_>>()); + + actual + } + + #[test] + fn insert_sorted() { + insert_sorted_test([1, 2, 3, 4]); + insert_sorted_test([4, 3, 2, 1]); + insert_sorted_test([2, 3, 1, 4]); + + let mut counts1 = [0, 0, 0]; + let mut counts5 = [0, 0, 0]; + let 
mut counts6 = [0, 0, 0]; + // The three twos (initially in positions 1, 5, and 6), should be placed in positions 2, 3, + // and 4 in the output in random order. After 128 trials, each of these possibilities should + // have occurred at least once. + for _ in 0..128 { + let result = insert_sorted_test([1, 2, 0, 3, 4, 2, 2]); + + let i1 = result.iter().position(|r| r.breakdown_key == 1).unwrap(); + counts1[i1 - 2] += 1; + let i5 = result.iter().position(|r| r.breakdown_key == 5).unwrap(); + counts5[i5 - 2] += 1; + let i6 = result.iter().position(|r| r.breakdown_key == 6).unwrap(); + counts6[i6 - 2] += 1; + } + for i in 0..3 { + assert_ne!(counts1[i], 0); + assert_ne!(counts5[i], 0); + assert_ne!(counts6[i], 0); + } + + let mut counts2 = [0, 0, 0]; + let mut counts5 = [0, 0, 0]; + let mut counts6 = [0, 0, 0]; + // The three zeros (initially in positions 2, 5, and 6), should be placed in positions 0, 1, + // and 2 in the output in random order. After 128 trials, each of these possibilities should + // have occurred at least once. + for _ in 0..128 { + let result = insert_sorted_test([1, 2, 0, 3, 4, 0, 0]); + + let i2 = result.iter().position(|r| r.breakdown_key == 2).unwrap(); + counts2[i2] += 1; + let i5 = result.iter().position(|r| r.breakdown_key == 5).unwrap(); + counts5[i5] += 1; + let i6 = result.iter().position(|r| r.breakdown_key == 6).unwrap(); + counts6[i6] += 1; + } + for i in 0..3 { + assert_ne!(counts2[i], 0); + assert_ne!(counts5[i], 0); + assert_ne!(counts6[i], 0); + } + } +}