Delete TransportCallbacks and use RequestHandler trait instead (private-attribution#992)

* Delete `TransportCallbacks` and use `RequestHandler` trait instead

See private-attribution#987 for motivation. I had to choose between dynamic
dispatch and clunky HTTP interfaces with another generic parameter propagated
through the entire stack. I don't have a conclusive answer as to which way is better;
both have significant downsides.

Problems with the dynamic dispatch (DD) approach proposed in this change:

* It is hard to keep the `RequestHandler` trait object safe: no generics on the `handle` method,
use of `async_trait`, etc. That removes the opportunity for some optimizations, namely
using a trait to pass data down to the handler. It would be better if the HTTP layer could just
pass along the same structs it receives, without the extra conversion that must occur
if dynamic dispatch is used.
* It is not a zero-cost abstraction. To get data back from the handler, we have to use a
single common format. Right now it is JSON, and I doubt we can do better than binary serialization,
which means more work to get the data out.
* `Box<dyn Trait<....` is everywhere now.
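
To make the object-safety constraint above concrete, here is a minimal, synchronous sketch using only the standard library. It is not the `RequestHandler` trait from this change: `Route`, `Handler`, `EchoHandler` and `Server` are hypothetical names, and raw bytes stand in for the intermediate data representation. The point is only that a trait used as `dyn Trait` cannot have generic methods, so typed requests have to be flattened into one concrete format before they reach the handler.

```rust
/// Hypothetical stand-in for a route identifier carried by a request.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Route {
    Status,
    Echo,
}

/// Object-safe handler trait: every method uses concrete types only.
/// A generic method such as `fn handle<R>(&self, req: R)` could not be
/// called through `dyn Handler`, so the payload travels as raw bytes.
pub trait Handler {
    fn handle(&self, route: Route, body: &[u8]) -> Result<Vec<u8>, String>;
}

/// Hypothetical handler implementation used for illustration.
pub struct EchoHandler;

impl Handler for EchoHandler {
    fn handle(&self, route: Route, body: &[u8]) -> Result<Vec<u8>, String> {
        match route {
            // Echo returns the request body unchanged.
            Route::Echo => Ok(body.to_vec()),
            // Status ignores the body and returns a fixed answer.
            Route::Status => Ok(b"ok".to_vec()),
        }
    }
}

/// The network layer only sees a boxed trait object, so it needs no
/// generic parameter, at the cost of a virtual call and a common
/// wire-like format for requests and responses.
pub struct Server {
    handler: Box<dyn Handler>,
}

impl Server {
    pub fn new(handler: Box<dyn Handler>) -> Self {
        Self { handler }
    }

    pub fn dispatch(&self, route: Route, body: &[u8]) -> Result<Vec<u8>, String> {
        self.handler.handle(route, body)
    }
}
```

A caller would construct `Server::new(Box::new(EchoHandler))` and invoke `dispatch` with a route and a byte slice; the caller is then responsible for decoding the returned bytes, which is the extra conversion the bullets above describe.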

The problem with static dispatch (I will link a commit) is that more code requires changes.
It is also not clear whether we can make it a zero-cost abstraction.
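
For contrast, the static dispatch alternative might look roughly like the sketch below. This is an assumption-laden illustration, not the linked commit: `Handle`, `QueryStatusReq`, `Processor`, `HttpServer` and `TransportLayer` are all hypothetical names. It shows how typed requests can reach the handler without conversion, and how the handler type parameter `H` then has to appear on every layer that contains the handler.

```rust
/// Hypothetical generic handler trait: one impl per typed request.
pub trait Handle<R> {
    type Response;
    fn handle(&self, req: R) -> Self::Response;
}

/// Hypothetical typed request produced directly by the HTTP layer.
pub struct QueryStatusReq {
    pub query_id: u32,
}

/// Hypothetical stand-in for the component that processes queries.
pub struct Processor;

impl Handle<QueryStatusReq> for Processor {
    type Response = String;

    fn handle(&self, req: QueryStatusReq) -> String {
        format!("query {} is running", req.query_id)
    }
}

/// The HTTP layer is generic over the handler type...
pub struct HttpServer<H> {
    handler: H,
}

impl<H> HttpServer<H>
where
    H: Handle<QueryStatusReq, Response = String>,
{
    pub fn status(&self, query_id: u32) -> String {
        // No serialization step: the typed request goes straight through.
        self.handler.handle(QueryStatusReq { query_id })
    }
}

/// ...and so is every layer built on top of it. This is the
/// "generic parameter propagated through the entire stack".
pub struct TransportLayer<H> {
    server: HttpServer<H>,
}

impl<H> TransportLayer<H>
where
    H: Handle<QueryStatusReq, Response = String>,
{
    pub fn new(handler: H) -> Self {
        Self {
            server: HttpServer { handler },
        }
    }

    pub fn status(&self, query_id: u32) -> String {
        self.server.status(query_id)
    }
}
```

In this sketch the calls are resolved at compile time, but each new request type adds another bound to every `impl` block that mentions `H`, which is one way the "more code that requires a change" cost can show up.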

It is mentioned in private-attribution#987, but I will reiterate it here: the reason for the intermediate
data representation layer (between HTTP and transport) is to support various delivery channels for IPA,
which could potentially include something like CF workers. It does not seem that we can
rely on our network layer being HTTP in the long term.

* Fix the memory leak inside TestApp

* Fix one FIXME

* Clean up code

* Feedback
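
One leak-avoidance pattern visible in this change is in `Setup::connect` in `ipa-core/src/app.rs`, which registers the handler through `Arc::downgrade` rather than a strong reference. The sketch below illustrates that general idea with the standard library only. `HandlerSlot`, `App` and `connect` are hypothetical names loosely modeled on the `HandlerBox::empty()` and `set_handler` calls in the diff; their real implementations are not shown in this excerpt and may differ.

```rust
use std::sync::{Arc, Mutex, Weak};

/// Hypothetical, simplified handler trait (synchronous for brevity).
pub trait Handler {
    fn handle(&self, req: &str) -> String;
}

/// Hypothetical late-bound slot: created empty, filled in once the app exists.
/// It stores a `Weak` reference so it never keeps the app alive by itself.
#[derive(Default)]
pub struct HandlerSlot {
    inner: Mutex<Option<Weak<dyn Handler>>>,
}

impl HandlerSlot {
    pub fn set(&self, handler: Weak<dyn Handler>) {
        *self.inner.lock().unwrap() = Some(handler);
    }

    /// Returns `None` if no handler was set or the app was already dropped.
    pub fn call(&self, req: &str) -> Option<String> {
        let handler = self.inner.lock().unwrap().as_ref()?.upgrade()?;
        Some(handler.handle(req))
    }
}

/// Hypothetical app that owns a strong reference to the slot
/// (standing in for the transport that holds the handler).
pub struct App {
    name: String,
    pub slot: Arc<HandlerSlot>,
}

impl Handler for App {
    fn handle(&self, req: &str) -> String {
        format!("{}: {}", self.name, req)
    }
}

/// Wires the app and the slot together. The app points at the slot strongly,
/// the slot points back weakly, so there is no `Arc` reference cycle.
pub fn connect(name: &str, slot: Arc<HandlerSlot>) -> Arc<App> {
    let app = Arc::new(App {
        name: name.to_owned(),
        slot: Arc::clone(&slot),
    });
    let weak_app: Weak<App> = Arc::downgrade(&app);
    // Unsizing coercion from `Weak<App>` to `Weak<dyn Handler>`.
    let weak_handler: Weak<dyn Handler> = weak_app;
    slot.set(weak_handler);
    app
}
```

Had the slot stored an `Arc<dyn Handler>` instead, the app and the slot would keep each other alive and neither would ever be dropped. With the weak back-reference, dropping the last `Arc<App>` frees the app, and later calls through the slot simply return `None`.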
akoshelev authored Apr 5, 2024
1 parent 8894929 commit fa79cee
Showing 28 changed files with 1,002 additions and 704 deletions.
184 changes: 107 additions & 77 deletions ipa-core/src/app.rs
@@ -1,106 +1,86 @@
use std::sync::Weak;

use async_trait::async_trait;

use crate::{
helpers::{
query::{QueryConfig, QueryInput},
Transport, TransportCallbacks, TransportImpl,
query::{PrepareQuery, QueryConfig, QueryInput},
routing::{Addr, RouteId},
ApiError, BodyStream, HandlerBox, HandlerRef, HelperIdentity, HelperResponse,
RequestHandler, Transport, TransportImpl,
},
hpke::{KeyPair, KeyRegistry},
protocol::QueryId,
query::{
NewQueryError, QueryCompletionError, QueryInputError, QueryProcessor, QueryStatus,
QueryStatusError,
},
query::{NewQueryError, QueryProcessor, QueryStatus},
sync::Arc,
};

pub struct Setup {
query_processor: Arc<QueryProcessor>,
query_processor: QueryProcessor,
handler: HandlerRef,
}

/// The API layer to interact with a helper.
#[must_use]
pub struct HelperApp {
query_processor: Arc<QueryProcessor>,
inner: Arc<Inner>,
}

struct Inner {
query_processor: QueryProcessor,
/// For HTTP implementation this transport is also behind an [`Arc`] which causes double indirection
/// on top of atomics and all fun stuff associated with it. I don't see an easy way to avoid that
/// if we want to keep the implementation leak-free, but one may be aware if this shows up on
/// the flamegraph
transport: TransportImpl,
}

impl Setup {
#[must_use]
pub fn new() -> (Self, TransportCallbacks<TransportImpl>) {
pub fn new() -> (Self, HandlerRef) {
Self::with_key_registry(KeyRegistry::empty())
}

#[must_use]
pub fn with_key_registry(
key_registry: KeyRegistry<KeyPair>,
) -> (Self, TransportCallbacks<TransportImpl>) {
let query_processor = Arc::new(QueryProcessor::new(key_registry));
pub fn with_key_registry(key_registry: KeyRegistry<KeyPair>) -> (Self, HandlerRef) {
let query_processor = QueryProcessor::new(key_registry);
let handler = HandlerBox::empty();
let this = Self {
query_processor: Arc::clone(&query_processor),
query_processor,
handler: handler.clone(),
};

// TODO: weak reference to query processor to prevent mem leak
(this, Self::callbacks(&query_processor))
(this, handler)
}

/// Instantiate [`HelperApp`] by connecting it to the provided transport implementation
pub fn connect(self, transport: TransportImpl) -> HelperApp {
HelperApp::new(transport, self.query_processor)
}
let app = Arc::new(Inner {
query_processor: self.query_processor,
transport,
});
self.handler.set_handler(
Arc::downgrade(&app) as Weak<dyn RequestHandler<Identity = HelperIdentity>>
);

/// Create callbacks that tie up query processor and transport.
fn callbacks(query_processor: &Arc<QueryProcessor>) -> TransportCallbacks<TransportImpl> {
let rqp = Arc::clone(query_processor);
let pqp = Arc::clone(query_processor);
let iqp = Arc::clone(query_processor);
let sqp = Arc::clone(query_processor);
let cqp = Arc::clone(query_processor);

TransportCallbacks {
receive_query: Box::new(move |transport: TransportImpl, receive_query| {
let processor = Arc::clone(&rqp);
Box::pin(async move {
let r = processor.new_query(transport, receive_query).await?;

Ok(r.query_id)
})
}),
prepare_query: Box::new(move |transport: TransportImpl, prepare_query| {
let processor = Arc::clone(&pqp);
Box::pin(async move { processor.prepare(&transport, prepare_query) })
}),
query_input: Box::new(move |transport: TransportImpl, query_input| {
let processor = Arc::clone(&iqp);
Box::pin(async move { processor.receive_inputs(transport, query_input) })
}),
query_status: Box::new(move |_transport: TransportImpl, query_id| {
let processor = Arc::clone(&sqp);
Box::pin(async move { processor.query_status(query_id) })
}),
complete_query: Box::new(move |_transport: TransportImpl, query_id| {
let processor = Arc::clone(&cqp);
Box::pin(async move { processor.complete(query_id).await })
}),
}
// Handler must be kept inside the app instance. When app is dropped, handler, transport and
// query processor are destroyed.
HelperApp { inner: app }
}
}

impl HelperApp {
pub fn new(transport: TransportImpl, query_processor: Arc<QueryProcessor>) -> Self {
Self {
query_processor,
transport,
}
}

/// Initiates a new query on this helper. In case if query is accepted, the unique [`QueryId`]
/// identifier is returned, otherwise an error indicating what went wrong is reported back.
///
/// ## Errors
/// If query is rejected for any reason.
pub async fn start_query(&self, query_config: QueryConfig) -> Result<QueryId, NewQueryError> {
Ok(self
.inner
.query_processor
.new_query(Transport::clone_ref(&self.transport), query_config)
.new_query(Transport::clone_ref(&self.inner.transport), query_config)
.await?
.query_id)
}
@@ -109,38 +89,88 @@ impl HelperApp {
///
/// ## Errors
/// Propagates errors from the helper.
pub fn execute_query(&self, input: QueryInput) -> Result<(), Error> {
let transport = <TransportImpl as Clone>::clone(&self.transport);
self.query_processor.receive_inputs(transport, input)?;
pub fn execute_query(&self, input: QueryInput) -> Result<(), ApiError> {
let transport = <TransportImpl as Clone>::clone(&self.inner.transport);
self.inner
.query_processor
.receive_inputs(transport, input)?;
Ok(())
}

/// Retrieves the status of a query.
///
/// ## Errors
/// Propagates errors from the helper.
pub fn query_status(&self, query_id: QueryId) -> Result<QueryStatus, Error> {
Ok(self.query_processor.query_status(query_id)?)
pub fn query_status(&self, query_id: QueryId) -> Result<QueryStatus, ApiError> {
Ok(self.inner.query_processor.query_status(query_id)?)
}

/// Waits for a query to complete and returns the result.
///
/// ## Errors
/// Propagates errors from the helper.
pub async fn complete_query(&self, query_id: QueryId) -> Result<Vec<u8>, Error> {
Ok(self.query_processor.complete(query_id).await?.into_bytes())
pub async fn complete_query(&self, query_id: QueryId) -> Result<Vec<u8>, ApiError> {
Ok(self
.inner
.query_processor
.complete(query_id)
.await?
.to_bytes())
}
}

/// Union of error types returned by API operations.
#[derive(thiserror::Error, Debug)]
pub enum Error {
#[error(transparent)]
NewQuery(#[from] NewQueryError),
#[error(transparent)]
QueryInput(#[from] QueryInputError),
#[error(transparent)]
QueryCompletion(#[from] QueryCompletionError),
#[error(transparent)]
QueryStatus(#[from] QueryStatusError),
#[async_trait]
impl RequestHandler for Inner {
type Identity = HelperIdentity;

async fn handle(
&self,
req: Addr<Self::Identity>,
data: BodyStream,
) -> Result<HelperResponse, ApiError> {
fn ext_query_id(req: &Addr<HelperIdentity>) -> Result<QueryId, ApiError> {
req.query_id.ok_or_else(|| {
ApiError::BadRequest("Query input is missing query_id argument".into())
})
}

let qp = &self.query_processor;

Ok(match req.route {
r @ RouteId::Records => {
return Err(ApiError::BadRequest(
format!("{r:?} request must not be handled by query processing flow").into(),
))
}
RouteId::ReceiveQuery => {
let req = req.into::<QueryConfig>()?;
HelperResponse::from(
qp.new_query(Transport::clone_ref(&self.transport), req)
.await?,
)
}
RouteId::PrepareQuery => {
let req = req.into::<PrepareQuery>()?;
HelperResponse::from(qp.prepare(&self.transport, req)?)
}
RouteId::QueryInput => {
let query_id = ext_query_id(&req)?;
HelperResponse::from(qp.receive_inputs(
Transport::clone_ref(&self.transport),
QueryInput {
query_id,
input_stream: data,
},
)?)
}
RouteId::QueryStatus => {
let query_id = ext_query_id(&req)?;
HelperResponse::from(qp.query_status(query_id)?)
}
RouteId::CompleteQuery => {
let query_id = ext_query_id(&req)?;
HelperResponse::from(qp.complete(query_id).await?)
}
})
}
}
4 changes: 2 additions & 2 deletions ipa-core/src/bin/helper.rs
@@ -131,7 +131,7 @@ async fn server(args: ServerArgs) -> Result<(), BoxError> {
});

let key_registry = hpke_registry(mk_encryption.as_ref()).await?;
let (setup, callbacks) = AppSetup::with_key_registry(key_registry);
let (setup, handler) = AppSetup::with_key_registry(key_registry);

let server_config = ServerConfig {
port: args.port,
@@ -155,7 +155,7 @@ async fn server(args: ServerArgs) -> Result<(), BoxError> {
server_config,
network_config,
clients,
callbacks,
Some(handler),
);

let _app = setup.connect(transport.clone());
3 changes: 2 additions & 1 deletion ipa-core/src/helpers/gateway/mod.rs
@@ -19,7 +19,8 @@ use crate::{
gateway::{
receive::GatewayReceivers, send::GatewaySenders, transport::RoleResolvingTransport,
},
HelperChannelId, Message, Role, RoleAssignment, RouteId, TotalRecords, Transport,
transport::routing::RouteId,
HelperChannelId, Message, Role, RoleAssignment, TotalRecords, Transport,
},
protocol::QueryId,
};
4 changes: 2 additions & 2 deletions ipa-core/src/helpers/gateway/transport.rs
@@ -3,8 +3,8 @@ use futures::Stream;

use crate::{
helpers::{
NoResourceIdentifier, QueryIdBinding, Role, RoleAssignment, RouteId, RouteParams,
StepBinding, Transport, TransportImpl,
transport::routing::RouteId, NoResourceIdentifier, QueryIdBinding, Role, RoleAssignment,
RouteParams, StepBinding, Transport, TransportImpl,
},
protocol::{step::Gate, QueryId},
};
8 changes: 4 additions & 4 deletions ipa-core/src/helpers/mod.rs
@@ -52,10 +52,10 @@ pub use prss_protocol::negotiate as negotiate_prss;
#[cfg(feature = "web-app")]
pub use transport::WrappedAxumBodyStream;
pub use transport::{
callbacks::*, query, BodyStream, BytesStream, Identity as TransportIdentity,
LengthDelimitedStream, LogErrors, NoResourceIdentifier, QueryIdBinding, ReceiveRecords,
RecordsStream, RouteId, RouteParams, StepBinding, StreamCollection, StreamKey, Transport,
WrappedBoxBodyStream,
make_owned_handler, query, routing, ApiError, BodyStream, BytesStream, HandlerBox, HandlerRef,
HelperResponse, Identity as TransportIdentity, LengthDelimitedStream, LogErrors, NoQueryId,
NoResourceIdentifier, NoStep, QueryIdBinding, ReceiveRecords, RecordsStream, RequestHandler,
RouteParams, StepBinding, StreamCollection, StreamKey, Transport, WrappedBoxBodyStream,
};
#[cfg(feature = "in-memory-infra")]
pub use transport::{InMemoryMpcNetwork, InMemoryShardNetwork, InMemoryTransport};