Skip to content

Commit

Permalink
Vec-extraction futures
Browse files Browse the repository at this point in the history
  • Loading branch information
andyleiserson committed Jan 3, 2024
1 parent dcb6a39 commit 5abc016
Show file tree
Hide file tree
Showing 4 changed files with 147 additions and 11 deletions.
125 changes: 125 additions & 0 deletions ipa-core/src/helpers/buffers/buffered_stream.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
use std::{task::{Context, Poll}, pin::Pin, iter::{repeat_with, repeat}};

use futures::{Future, FutureExt, future::Shared, ready};
use tokio::sync::oneshot::{Sender, self, Receiver};
use pin_project::pin_project;

pub fn vec_item_futures<Fut: Future<Output = Vec<T>>, T>(fut: Fut, len: usize) -> impl Iterator<Item = VecItemFuture<Fut, T>> {
let (tx, rx): (Vec<Sender<T>>, Vec<Receiver<T>>) = repeat_with(|| oneshot::channel()).take(len).unzip();
repeat(VecDistributionFuture { fut, len, tx }.shared()).zip(rx).map(|(fut, rx)| VecItemFuture { src: Some(fut), rx })
}

#[pin_project]
pub struct VecDistributionFuture<Fut: Future<Output = Vec<T>>, T> {
#[pin]
fut: Fut,
len: usize,
tx: Vec<Sender<T>>,
}

impl<Fut: Future<Output = Vec<T>>, T> Future for VecDistributionFuture<Fut, T> {
type Output = ();

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let proj = self.project();
let vec = ready!(proj.fut.poll(cx));

assert_eq!(vec.len(), *proj.len, "future used with vec_item_futures did not resolve with the correct length");
for (item, tx) in vec.into_iter().zip(proj.tx.drain(..)) {
assert!(tx.send(item).is_ok());
}

Poll::Ready(())
}
}

#[pin_project]
pub struct VecItemFuture<Fut: Future<Output = Vec<T>>, T> {
#[pin]
src: Option<Shared<VecDistributionFuture<Fut, T>>>,
#[pin]
rx: Receiver<T>,
}

impl<Fut: Future<Output = Vec<T>>, T> Future for VecItemFuture<Fut, T> {
type Output = T;

fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut proj = self.project();
if let Some(src) = proj.src.as_mut().as_pin_mut() {
match src.poll(cx) {
Poll::Ready(()) => *proj.src = None,
Poll::Pending => return Poll::Pending,
}
}
<Receiver<T> as Future>::poll(proj.rx, cx).map(|recv_res| recv_res.unwrap())
}
}

#[cfg(test)]
mod tests {
use std::{ptr::null, task::Waker, future};

use super::*;

fn fake_waker() -> Waker {
use std::task::{RawWaker, RawWakerVTable};
const fn fake_raw_waker() -> RawWaker {
const TABLE: RawWakerVTable =
RawWakerVTable::new(|_| fake_raw_waker(), |_| {}, |_| {}, |_| {});
RawWaker::new(null(), &TABLE)
}
unsafe { Waker::from_raw(fake_raw_waker()) }
}

#[tokio::test]
async fn test_vec_item_futures() {
let mut futs = vec_item_futures(future::ready(vec![1, 2]), 2);
let fut1 = futs.next().unwrap();
let fut2 = futs.next().unwrap();
assert!(futs.next().is_none());

assert_eq!(fut1.await, 1);
assert_eq!(fut2.await, 2);
}

#[tokio::test]
async fn test_vec_item_futures_reverse_order() {
let mut futs = vec_item_futures(future::ready(vec![1, 2]), 2);
let fut1 = futs.next().unwrap();
let fut2 = futs.next().unwrap();
assert!(futs.next().is_none());

assert_eq!(fut2.await, 2);
assert_eq!(fut1.await, 1);
}

#[tokio::test]
#[should_panic(expected = "future used with vec_item_futures did not resolve with the correct length")]
async fn test_vec_item_futures_incorrect_length() {
let mut futs = vec_item_futures(future::ready(vec![1, 2]), 3);
futs.next().unwrap().await;
}

#[tokio::test]
async fn test_vec_item_futures_pending() {
let (_tx, rx) = oneshot::channel::<Vec<i32>>();
let mut futs = vec_item_futures(rx.map(Result::unwrap), 3);
let mut fut = futs.next().unwrap();
let waker = fake_waker();
let mut cx = Context::from_waker(&waker);
assert_eq!(fut.poll_unpin(&mut cx), Poll::Pending);
}

#[tokio::test]
async fn test_vec_item_futures_broadcast() {
// We should be able to poll the second future, even if nobody looks at the first.
let (tx, rx) = oneshot::channel::<Vec<i32>>();
let mut futs = vec_item_futures(rx.map(Result::unwrap), 2);
let _fut1 = futs.next().unwrap();
let fut2 = futs.next().unwrap();
assert!(futs.next().is_none());
tx.send(vec![1, 2]).unwrap();
assert_eq!(fut2.await, 2);
}
}
2 changes: 2 additions & 0 deletions ipa-core/src/helpers/buffers/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
mod buffered_stream;
mod ordering_mpsc;
mod ordering_sender;
mod unordered_receiver;

pub use buffered_stream::vec_item_futures;
pub use ordering_mpsc::{ordering_mpsc, OrderingMpscReceiver, OrderingMpscSender};
pub use ordering_sender::{OrderedStream, OrderingSender};
pub use unordered_receiver::UnorderedReceiver;
1 change: 1 addition & 0 deletions ipa-core/src/helpers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ mod transport;

use std::ops::{Index, IndexMut};

pub use buffers::vec_item_futures;
/// to validate that transport can actually send streams of this type
#[cfg(test)]
pub use buffers::OrderingSender;
Expand Down
30 changes: 19 additions & 11 deletions ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
use std::{
num::NonZeroU32,
num::{NonZeroU32, NonZeroUsize},
ops::{Not, Range},
};

use futures::{
future::{try_join, try_join3},
stream::{iter as stream_iter, unfold},
Stream, StreamExt, TryStreamExt,
FutureExt, Stream, StreamExt, TryStreamExt,
};
use ipa_macros::Step;

Expand All @@ -18,7 +18,7 @@ use crate::{
boolean_array::{BA32, BA7},
ArrayAccess, CustomArray, Expand, Field, PrimeField, Serializable,
},
helpers::Role,
helpers::{Role, vec_item_futures},
protocol::{
basics::{if_else, SecureMul, ShareKnownValue},
boolean::or::or,
Expand Down Expand Up @@ -506,24 +506,32 @@ where
let num_user_rows = rows_for_user.len();
let contexts = ctx_for_row_number[..num_user_rows - 1].to_owned();

evaluate_per_user_attribution_circuit::<_, BK, TV, TS, SS>(
contexts,
RecordId::from(record_id),
rows_for_user,
attribution_window_seconds,
(
num_user_rows - 1,
FutureExt::map(evaluate_per_user_attribution_circuit::<_, BK, TV, TS, SS>(
contexts,
RecordId::from(record_id),
rows_for_user,
attribution_window_seconds,
), Result::unwrap),
)
});

// Execute all of the async futures (sequentially), and flatten the result
let flattened_stream = seq_join(sh_ctx.active_work(), stream_iter(per_user_results))
.flat_map(|x| stream_iter(x.unwrap()));
let flattened_results = per_user_results
.flat_map(|(len, vec_fut)| vec_item_futures(vec_fut, len));

let user_results_stream = seq_join(
sh_ctx.active_work().checked_mul(NonZeroUsize::new(4).unwrap()).unwrap(),
stream_iter(flattened_results),
);

// modulus convert breakdown keys and trigger values
let converted_bks_and_tvs = convert_bits(
prime_field_ctx
.narrow(&Step::ModulusConvertBreakdownKeyBitsAndTriggerValues)
.set_total_records(num_outputs),
flattened_stream,
user_results_stream,
0..(<BK as SharedValue>::BITS + <TV as SharedValue>::BITS),
);

Expand Down

0 comments on commit 5abc016

Please sign in to comment.