From 65996fb9218bd2eca85b4b1d997aa50ae3a942f7 Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Sun, 7 Apr 2024 11:53:58 -0700 Subject: [PATCH 1/4] Implement shard communication channels for MPC circuits This change adds the necessary functionality to `Gateway` to be able to communicate between shards. `get_shard_sender` and `get_shard_receiver` will allow circuits to send and request data from other shards, similarly to MPC send/recv. There is a difference in how receives are handled for MPC and shards. The former channels use `UnorderedReceiver` that lets them receive records in any order. Shard receivers return a stream that has a FIFO order. This is a requirement that came from analysing shuffle and other protocols that require cross-shard communication. Each shard does not know in advance how many records it expects to receive from any shard. This also adds a stub for HTTP shard transport, just to make the code compile. Actual HTTP implementation will come later. In terms of next steps, there remains a building block for resharding shares and sharded shuffle implementation. After these two are complete, in memory implementation will be fully functional --- ipa-core/src/app.rs | 31 ++- ipa-core/src/bin/helper.rs | 4 +- .../src/helpers/buffers/ordering_sender.rs | 4 +- ipa-core/src/helpers/gateway/mod.rs | 230 ++++++++++++++---- ipa-core/src/helpers/gateway/receive.rs | 73 +++++- ipa-core/src/helpers/gateway/send.rs | 168 +++++++------ .../src/helpers/gateway/stall_detection.rs | 158 +++++++++--- ipa-core/src/helpers/gateway/transport.rs | 17 +- ipa-core/src/helpers/mod.rs | 45 ++-- ipa-core/src/helpers/prss_protocol.rs | 8 +- .../helpers/transport/in_memory/sharding.rs | 33 +++ ipa-core/src/helpers/transport/mod.rs | 34 ++- ipa-core/src/helpers/transport/receive.rs | 2 +- .../src/helpers/transport/stream/input.rs | 7 +- ipa-core/src/helpers/transport/stream/mod.rs | 2 +- ipa-core/src/net/mod.rs | 2 +- ipa-core/src/net/transport.rs | 44 +++- ipa-core/src/protocol/context/malicious.rs | 40 ++- ipa-core/src/protocol/context/mod.rs | 48 +++- ipa-core/src/protocol/context/semi_honest.rs | 29 ++- ipa-core/src/protocol/ipa_prf/shuffle/base.rs | 4 +- ipa-core/src/query/processor.rs | 21 +- ipa-core/src/test_fixture/app.rs | 23 +- ipa-core/src/test_fixture/mod.rs | 2 +- ipa-core/src/test_fixture/world.rs | 28 ++- 25 files changed, 790 insertions(+), 267 deletions(-) diff --git a/ipa-core/src/app.rs b/ipa-core/src/app.rs index 0ca99287d..eee1e434a 100644 --- a/ipa-core/src/app.rs +++ b/ipa-core/src/app.rs @@ -7,7 +7,7 @@ use crate::{ query::{PrepareQuery, QueryConfig, QueryInput}, routing::{Addr, RouteId}, ApiError, BodyStream, HandlerBox, HandlerRef, HelperIdentity, HelperResponse, - RequestHandler, Transport, TransportImpl, + MpcTransportImpl, RequestHandler, ShardTransportImpl, Transport, }, hpke::{KeyPair, KeyRegistry}, protocol::QueryId, @@ -32,7 +32,8 @@ struct Inner { /// on top of atomics and all fun stuff associated with it. I don't see an easy way to avoid that /// if we want to keep the implementation leak-free, but one may be aware if this shows up on /// the flamegraph - transport: TransportImpl, + mpc_transport: MpcTransportImpl, + shard_transport: ShardTransportImpl, } impl Setup { @@ -55,10 +56,15 @@ impl Setup { } /// Instantiate [`HelperApp`] by connecting it to the provided transport implementation - pub fn connect(self, transport: TransportImpl) -> HelperApp { + pub fn connect( + self, + mpc_transport: MpcTransportImpl, + shard_transport: ShardTransportImpl, + ) -> HelperApp { let app = Arc::new(Inner { query_processor: self.query_processor, - transport, + mpc_transport, + shard_transport, }); self.handler.set_handler( Arc::downgrade(&app) as Weak> @@ -80,7 +86,10 @@ impl HelperApp { Ok(self .inner .query_processor - .new_query(Transport::clone_ref(&self.inner.transport), query_config) + .new_query( + Transport::clone_ref(&self.inner.mpc_transport), + query_config, + ) .await? .query_id) } @@ -90,10 +99,11 @@ impl HelperApp { /// ## Errors /// Propagates errors from the helper. pub fn execute_query(&self, input: QueryInput) -> Result<(), ApiError> { - let transport = ::clone(&self.inner.transport); + let mpc_transport = Transport::clone_ref(&self.inner.mpc_transport); + let shard_transport = Transport::clone_ref(&self.inner.shard_transport); self.inner .query_processor - .receive_inputs(transport, input)?; + .receive_inputs(mpc_transport, shard_transport, input)?; Ok(()) } @@ -145,18 +155,19 @@ impl RequestHandler for Inner { RouteId::ReceiveQuery => { let req = req.into::()?; HelperResponse::from( - qp.new_query(Transport::clone_ref(&self.transport), req) + qp.new_query(Transport::clone_ref(&self.mpc_transport), req) .await?, ) } RouteId::PrepareQuery => { let req = req.into::()?; - HelperResponse::from(qp.prepare(&self.transport, req)?) + HelperResponse::from(qp.prepare(&self.mpc_transport, req)?) } RouteId::QueryInput => { let query_id = ext_query_id(&req)?; HelperResponse::from(qp.receive_inputs( - Transport::clone_ref(&self.transport), + Transport::clone_ref(&self.mpc_transport), + Transport::clone_ref(&self.shard_transport), QueryInput { query_id, input_stream: data, diff --git a/ipa-core/src/bin/helper.rs b/ipa-core/src/bin/helper.rs index 9ac13f670..81cdf214c 100644 --- a/ipa-core/src/bin/helper.rs +++ b/ipa-core/src/bin/helper.rs @@ -15,7 +15,7 @@ use ipa_core::{ config::{hpke_registry, HpkeServerConfig, NetworkConfig, ServerConfig, TlsConfig}, error::BoxError, helpers::HelperIdentity, - net::{ClientIdentity, HttpTransport, MpcHelperClient}, + net::{ClientIdentity, HttpShardTransport, HttpTransport, MpcHelperClient}, AppSetup, }; use tracing::{error, info}; @@ -158,7 +158,7 @@ async fn server(args: ServerArgs) -> Result<(), BoxError> { Some(handler), ); - let _app = setup.connect(transport.clone()); + let _app = setup.connect(transport.clone(), HttpShardTransport); let listener = args.server_socket_fd .map(|fd| { diff --git a/ipa-core/src/helpers/buffers/ordering_sender.rs b/ipa-core/src/helpers/buffers/ordering_sender.rs index b2a9e9ec7..531508227 100644 --- a/ipa-core/src/helpers/buffers/ordering_sender.rs +++ b/ipa-core/src/helpers/buffers/ordering_sender.rs @@ -524,7 +524,7 @@ mod test { use super::OrderingSender; use crate::{ ff::{Fp31, Fp32BitPrime, Gf20Bit, Gf9Bit, Serializable, U128Conversions}, - helpers::Message, + helpers::MpcMessage, rand::thread_rng, sync::Arc, test_executor::run, @@ -622,7 +622,7 @@ mod test { >; // Given a message, returns a closure that sends the message and increments an associated record index. - fn send_fn(m: M) -> BoxedSendFn { + fn send_fn(m: M) -> BoxedSendFn { Box::new(|s: &OrderingSender, i: &mut usize| { let fut = s.send(*i, m).boxed(); *i += 1; diff --git a/ipa-core/src/helpers/gateway/mod.rs b/ipa-core/src/helpers/gateway/mod.rs index d973d6c6f..740d4fefd 100644 --- a/ipa-core/src/helpers/gateway/mod.rs +++ b/ipa-core/src/helpers/gateway/mod.rs @@ -6,23 +6,26 @@ mod transport; use std::num::NonZeroUsize; -pub(super) use receive::ReceivingEnd; +pub(super) use receive::{MpcReceivingEnd, ShardReceivingEnd}; pub(super) use send::SendingEnd; -#[cfg(all(test, feature = "shuttle"))] -use shuttle::future as tokio; #[cfg(feature = "stall-detection")] pub(super) use stall_detection::InstrumentedGateway; +pub use transport::RoleResolvingTransport; use crate::{ helpers::{ buffers::UnorderedReceiver, gateway::{ - receive::GatewayReceivers, send::GatewaySenders, transport::RoleResolvingTransport, + receive::{GatewayReceivers, ShardReceiveStream, UR}, + send::GatewaySenders, + transport::Transports, }, - transport::routing::RouteId, - HelperChannelId, LogErrors, Message, Role, RoleAssignment, TotalRecords, Transport, + HelperChannelId, LogErrors, Message, MpcMessage, RecordsStream, Role, RoleAssignment, + ShardChannelId, TotalRecords, Transport, }, protocol::QueryId, + sharding::ShardIndex, + sync::{Arc, Mutex}, }; /// Alias for the currently configured transport. @@ -30,17 +33,25 @@ use crate::{ /// To avoid proliferation of type parameters, most code references this concrete type alias, rather /// than a type parameter `T: Transport`. #[cfg(feature = "in-memory-infra")] -pub type TransportImpl = super::transport::InMemoryTransport; +type TransportImpl = super::transport::InMemoryTransport; +#[cfg(feature = "in-memory-infra")] +pub type MpcTransportImpl = TransportImpl; +#[cfg(feature = "in-memory-infra")] +pub type ShardTransportImpl = TransportImpl; #[cfg(feature = "real-world-infra")] -pub type TransportImpl = crate::sync::Arc; +type TransportImpl = crate::sync::Arc; +#[cfg(feature = "real-world-infra")] +pub type MpcTransportImpl = TransportImpl; +#[cfg(feature = "real-world-infra")] +pub type ShardTransportImpl = crate::net::HttpShardTransport; -pub type TransportError = ::Error; +pub type MpcTransportError = ::Error; /// Gateway into IPA Network infrastructure. It allows helpers send and receive messages. pub struct Gateway { config: GatewayConfig, - transport: RoleResolvingTransport, + transports: Transports, query_id: QueryId, #[cfg(feature = "stall-detection")] inner: crate::sync::Arc, @@ -50,8 +61,10 @@ pub struct Gateway { #[derive(Default)] pub struct State { - senders: GatewaySenders, - receivers: GatewayReceivers, + mpc_senders: GatewaySenders, + mpc_receivers: GatewayReceivers, + shard_senders: GatewaySenders, + shard_receivers: GatewayReceivers, } #[derive(Clone, Copy, Debug)] @@ -73,15 +86,19 @@ impl Gateway { query_id: QueryId, config: GatewayConfig, roles: RoleAssignment, - transport: TransportImpl, + mpc_transport: MpcTransportImpl, + shard_transport: ShardTransportImpl, ) -> Self { #[allow(clippy::useless_conversion)] // not useless in stall-detection build Self { query_id, config, - transport: RoleResolvingTransport { - roles, - inner: transport, + transports: Transports { + mpc: RoleResolvingTransport { + roles, + inner: mpc_transport, + }, + shard: shard_transport, }, inner: State::default().into(), } @@ -89,7 +106,7 @@ impl Gateway { #[must_use] pub fn role(&self) -> Role { - self.transport.identity() + self.transports.mpc.identity() } #[must_use] @@ -97,52 +114,75 @@ impl Gateway { &self.config } + /// Returns a sender suitable for sending data between MPC helpers. The data must be approved + /// for sending by implementing [`MpcMessage`] trait. + /// + /// Do not remove the test below, it verifies that we don't allow raw sharings to be sent + /// between MPC helpers without using secure reveal. + /// + /// ```compile_fail + /// use ipa_core::helpers::Gateway; + /// use ipa_core::secret_sharing::replicated::semi_honest::AdditiveShare; + /// use ipa_core::ff::Fp32BitPrime; + /// + /// let gateway: Gateway = todo!(); + /// let mpc_channel = gateway.get_mpc_sender::>(todo!(), todo!()); + /// ``` /// /// ## Panics /// If there is a failure connecting via HTTP #[must_use] - pub fn get_sender( + pub fn get_mpc_sender( &self, channel_id: &HelperChannelId, total_records: TotalRecords, - ) -> send::SendingEnd { - let (tx, maybe_stream) = self.inner.senders.get_or_create::( + ) -> send::SendingEnd { + let transport = &self.transports.mpc; + let channel = self.inner.mpc_senders.get::( channel_id, + transport, self.config.active_work(), + self.query_id, total_records, ); - if let Some(stream) = maybe_stream { - tokio::spawn({ - let channel_id = channel_id.clone(); - let transport = self.transport.clone(); - let query_id = self.query_id; - async move { - // TODO(651): In the HTTP case we probably need more robust error handling here. - transport - .send( - channel_id.peer, - (RouteId::Records, query_id, channel_id.gate), - stream, - ) - .await - .expect("{channel_id:?} receiving end should be accepted by transport"); - } - }); - } - send::SendingEnd::new(tx, self.role(), channel_id) + send::SendingEnd::new(channel, transport.identity()) + } + + /// Returns a sender for shard-to-shard traffic. This sender is more relaxed compared to one + /// returned by [`Self::get_mpc_sender`] as it allows anything that can be serialized into bytes + /// to be sent out. MPC sender needs to be more careful about it and not to allow sending sensitive + /// information to be accidentally revealed. + /// An example of such sensitive data could be secret sharings - it is perfectly fine to send them + /// between shards as they are known to each helper anyway. Sending them across MPC helper boundary + /// could lead to information reveal. + pub fn get_shard_sender( + &self, + channel_id: &ShardChannelId, + total_records: TotalRecords, + ) -> send::SendingEnd { + let transport = &self.transports.shard; + let channel = self.inner.shard_senders.get::( + channel_id, + transport, + self.config.active_work(), + self.query_id, + total_records, + ); + + send::SendingEnd::new(channel, transport.identity()) } #[must_use] - pub fn get_receiver( + pub fn get_mpc_receiver( &self, channel_id: &HelperChannelId, - ) -> receive::ReceivingEnd { - receive::ReceivingEnd::new( + ) -> receive::MpcReceivingEnd { + receive::MpcReceivingEnd::new( channel_id.clone(), - self.inner.receivers.get_or_create(channel_id, || { + self.inner.mpc_receivers.get_or_create(channel_id, || { UnorderedReceiver::new( - Box::pin(LogErrors::new(self.transport.receive( + Box::pin(LogErrors::new(self.transports.mpc.receive( channel_id.peer, (self.query_id, channel_id.gate.clone()), ))), @@ -151,6 +191,33 @@ impl Gateway { }), ) } + + /// Requests a stream of records to be received from the given shard. In contrast with + /// [`Self::get_mpc_receiver`] stream, items in this stream are available in FIFO order only. + pub fn get_shard_receiver( + &self, + channel_id: &ShardChannelId, + ) -> receive::ShardReceivingEnd { + let mut called_before = true; + let rx = self.inner.shard_receivers.get_or_create(channel_id, || { + called_before = false; + ShardReceiveStream(Arc::new(Mutex::new( + self.transports + .shard + .receive(channel_id.peer, (self.query_id, channel_id.gate.clone())), + ))) + }); + + assert!( + !called_before, + "Shard receiver {channel_id:?} can only be created once" + ); + + receive::ShardReceivingEnd { + channel_id: channel_id.clone(), + rx: RecordsStream::new(rx), + } + } } impl Default for GatewayConfig { @@ -192,13 +259,19 @@ impl GatewayConfig { mod tests { use std::iter::{repeat, zip}; - use futures_util::future::{join, try_join, try_join_all}; + use futures::{ + future::{join, try_join, try_join_all}, + stream::StreamExt, + }; use crate::{ - ff::{Fp31, Fp32BitPrime, Gf2, U128Conversions}, - helpers::{Direction, GatewayConfig, Message, Role, SendingEnd}, + ff::{boolean_array::BA3, Fp31, Fp32BitPrime, Gf2, U128Conversions}, + helpers::{Direction, GatewayConfig, MpcMessage, Role, SendingEnd}, protocol::{context::Context, RecordId}, - test_fixture::{Runner, TestWorld, TestWorldConfig}, + secret_sharing::replicated::semi_honest::AdditiveShare, + sharding::ShardConfiguration, + test_executor::run, + test_fixture::{Reconstruct, Runner, TestWorld, TestWorldConfig, WithShards}, }; /// Verifies that [`Gateway`] send buffer capacity is adjusted to the message size. @@ -208,7 +281,7 @@ mod tests { /// Gateway must be able to deal with it. #[tokio::test] async fn can_handle_heterogeneous_channels() { - async fn send(channel: &SendingEnd, i: usize) { + async fn send(channel: &SendingEnd, i: usize) { channel .send(i.into(), V::truncate_from(u128::try_from(i).unwrap())) .await @@ -380,6 +453,65 @@ mod tests { let _world = unsafe { Box::from_raw(world_ptr) }; } + #[test] + fn shards() { + run(|| async move { + let world = TestWorld::>::with_shards(TestWorldConfig::default()); + shard_comms_test(&world).await; + }); + } + + #[test] + #[should_panic( + expected = "Shard receiver channel[ShardIndex(1),\"protocol/iter0\"] can only be created once" + )] + fn shards_receive_twice() { + run(|| async move { + let world = TestWorld::>::with_shards(TestWorldConfig::default()); + world + .semi_honest(Vec::<()>::new().into_iter(), |ctx, _| async move { + let peer = ctx.peer_shards().next().unwrap(); + let recv1 = ctx.shard_recv_channel::(peer); + let recv2 = ctx.shard_recv_channel::(peer); + drop(recv1); + drop(recv2); + }) + .await; + }); + } + + async fn shard_comms_test(test_world: &TestWorld>) { + let input = vec![BA3::truncate_from(0_u32), BA3::truncate_from(1_u32)]; + + let r = test_world + .semi_honest(input.clone().into_iter(), |ctx, input| async move { + let ctx = ctx.set_total_records(input.len()); + // Swap shares between shards, works only for 2 shards. + let peer = ctx.peer_shards().next().unwrap(); + for (record_id, item) in input.into_iter().enumerate() { + ctx.shard_send_channel(peer) + .send(record_id.into(), item) + .await + .unwrap(); + } + + let mut r = Vec::>::new(); + let mut recv_channel = ctx.shard_recv_channel(peer); + while let Some(v) = recv_channel.next().await { + r.push(v.unwrap()); + } + + r + }) + .await + .into_iter() + .flat_map(|v| v.reconstruct()) + .collect::>(); + + let reverse_input = input.into_iter().rev().collect::>(); + assert_eq!(reverse_input, r); + } + fn make_world() -> (&'static TestWorld, *mut TestWorld) { let world = Box::leak(Box::::default()); let world_ptr = world as *mut _; diff --git a/ipa-core/src/helpers/gateway/receive.rs b/ipa-core/src/helpers/gateway/receive.rs index a98166e9f..ad2ccbf61 100644 --- a/ipa-core/src/helpers/gateway/receive.rs +++ b/ipa-core/src/helpers/gateway/receive.rs @@ -1,36 +1,61 @@ -use std::marker::PhantomData; +use std::{ + marker::PhantomData, + pin::Pin, + task::{Context, Poll}, +}; use bytes::Bytes; use dashmap::{mapref::entry::Entry, DashMap}; +use futures::Stream; +use pin_project::pin_project; use crate::{ error::BoxError, helpers::{ - buffers::UnorderedReceiver, gateway::transport::RoleResolvingTransport, Error, - HelperChannelId, LogErrors, Message, Role, Transport, + buffers::UnorderedReceiver, gateway::transport::RoleResolvingTransport, + transport::SingleRecordStream, ChannelId, Error, HelperChannelId, LogErrors, Message, + MpcMessage, Role, ShardChannelId, ShardTransportImpl, Transport, TransportIdentity, }, protocol::RecordId, + sync::{Arc, Mutex}, }; -/// Receiving end of the gateway channel. -pub struct ReceivingEnd { +/// Receiving end of the MPC gateway channel. +/// I tried to make it generic and work for both MPC and Shard connectors, but ran into +/// "implementation of `S` is not general enough" issue on the client side (reveal). It may be another +/// occurrence of [`gat`] issue +/// +/// [`gat`]: https://github.com/rust-lang/rust/issues/100013 +pub struct MpcReceivingEnd { channel_id: HelperChannelId, unordered_rx: UR, _phantom: PhantomData, } +#[pin_project] +pub struct ShardReceivingEnd { + pub(super) channel_id: ShardChannelId, + #[pin] + pub(super) rx: SingleRecordStream, +} + /// Receiving channels, indexed by (role, step). -#[derive(Default)] -pub(super) struct GatewayReceivers { - pub(super) inner: DashMap, +pub(super) struct GatewayReceivers { + pub(super) inner: DashMap, S>, } -pub(super) type UR = UnorderedReceiver< +pub type UR = UnorderedReceiver< LogErrors<::RecordsStream, Bytes, BoxError>, Vec, >; -impl ReceivingEnd { +/// Stream of records received from a peer shard. +#[derive(Clone)] +pub struct ShardReceiveStream( + pub(super) Arc::RecordsStream>>, +); + +impl MpcReceivingEnd { pub(super) fn new(channel_id: HelperChannelId, rx: UR) -> Self { Self { channel_id, @@ -61,8 +86,24 @@ impl ReceivingEnd { } } -impl GatewayReceivers { - pub fn get_or_create UR>(&self, channel_id: &HelperChannelId, ctr: F) -> UR { +impl Stream for ShardReceivingEnd { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().rx.poll_next(cx) + } +} + +impl Default for GatewayReceivers { + fn default() -> Self { + Self { + inner: DashMap::default(), + } + } +} + +impl GatewayReceivers { + pub fn get_or_create S>(&self, channel_id: &ChannelId, ctr: F) -> S { // TODO: raw entry API if it becomes available to avoid cloning the key match self.inner.entry(channel_id.clone()) { Entry::Occupied(entry) => entry.get().clone(), @@ -75,3 +116,11 @@ impl GatewayReceivers { } } } + +impl Stream for ShardReceiveStream { + type Item = <::RecordsStream as Stream>::Item; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(self.0.lock().unwrap()).as_mut().poll_next(cx) + } +} diff --git a/ipa-core/src/helpers/gateway/send.rs b/ipa-core/src/helpers/gateway/send.rs index 473deb486..72e147c61 100644 --- a/ipa-core/src/helpers/gateway/send.rs +++ b/ipa-core/src/helpers/gateway/send.rs @@ -8,11 +8,16 @@ use std::{ use dashmap::{mapref::entry::Entry, DashMap}; use futures::Stream; +#[cfg(all(test, feature = "shuttle"))] +use shuttle::future as tokio; use typenum::Unsigned; use crate::{ - helpers::{buffers::OrderingSender, Error, HelperChannelId, Message, Role, TotalRecords}, - protocol::RecordId, + helpers::{ + buffers::OrderingSender, routing::RouteId, ChannelId, Error, Message, TotalRecords, + Transport, TransportIdentity, + }, + protocol::{QueryId, RecordId}, sync::Arc, telemetry::{ labels::{ROLE, STEP}, @@ -21,31 +26,37 @@ use crate::{ }; /// Sending end of the gateway channel. -pub struct SendingEnd { - sender_role: Role, - channel_id: HelperChannelId, - inner: Arc, +pub struct SendingEnd { + sender_id: I, + inner: Arc>, _phantom: PhantomData, } -/// Sending channels, indexed by (role, step). -#[derive(Default)] -pub(super) struct GatewaySenders { - pub(super) inner: DashMap>, +/// Sending channels, indexed by identity and gate. +pub(super) struct GatewaySenders { + pub(super) inner: DashMap, Arc>>, } -pub(super) struct GatewaySender { - channel_id: HelperChannelId, +pub(super) struct GatewaySender { + channel_id: ChannelId, ordering_tx: OrderingSender, total_records: TotalRecords, } -pub(super) struct GatewaySendStream { - inner: Arc, +struct GatewaySendStream { + inner: Arc>, +} + +impl Default for GatewaySenders { + fn default() -> Self { + Self { + inner: DashMap::default(), + } + } } -impl GatewaySender { - fn new(channel_id: HelperChannelId, tx: OrderingSender, total_records: TotalRecords) -> Self { +impl GatewaySender { + fn new(channel_id: ChannelId, tx: OrderingSender, total_records: TotalRecords) -> Self { Self { channel_id, ordering_tx: tx, @@ -57,7 +68,7 @@ impl GatewaySender { &self, record_id: RecordId, msg: B, - ) -> Result<(), Error> { + ) -> Result<(), Error> { debug_assert!( self.total_records.is_specified(), "total_records cannot be unspecified when sending" @@ -94,15 +105,10 @@ impl GatewaySender { } } -impl SendingEnd { - pub(super) fn new( - sender: Arc, - role: Role, - channel_id: &HelperChannelId, - ) -> Self { +impl SendingEnd { + pub(super) fn new(sender: Arc>, id: I) -> Self { Self { - sender_role: role, - channel_id: channel_id.clone(), + sender_id: id, inner: sender, _phantom: PhantomData, } @@ -117,32 +123,38 @@ impl SendingEnd { /// call. /// /// [`set_total_records`]: crate::protocol::context::Context::set_total_records - #[tracing::instrument(level = "trace", "send", skip_all, fields(i = %record_id, total = %self.inner.total_records, to = ?self.channel_id.peer, gate = ?self.channel_id.gate.as_ref()))] - pub async fn send>(&self, record_id: RecordId, msg: B) -> Result<(), Error> { + #[tracing::instrument(level = "trace", "send", skip_all, fields( + i = %record_id, + total = %self.inner.total_records, + to = ?self.inner.channel_id.peer, + gate = ?self.inner.channel_id.gate.as_ref() + ))] + pub async fn send>(&self, record_id: RecordId, msg: B) -> Result<(), Error> { let r = self.inner.send(record_id, msg).await; metrics::increment_counter!(RECORDS_SENT, - STEP => self.channel_id.gate.as_ref().to_string(), - ROLE => self.sender_role.as_static_str() + STEP => self.inner.channel_id.gate.as_ref().to_string(), + ROLE => self.sender_id.as_str(), ); metrics::counter!(BYTES_SENT, M::Size::U64, - STEP => self.channel_id.gate.as_ref().to_string(), - ROLE => self.sender_role.as_static_str() + STEP => self.inner.channel_id.gate.as_ref().to_string(), + ROLE => self.sender_id.as_str(), ); r } } -impl GatewaySenders { - /// Returns or creates a new communication channel. In case if channel is newly created, - /// returns the receiving end of it as well. It must be send over to the receiver in order for - /// messages to get through. - pub(crate) fn get_or_create( +impl GatewaySenders { + /// Returns a communication channel for the given [`ChannelId`]. If it does not exist, it will + /// be created using the provided [`Transport`] implementation. + pub fn get>( &self, - channel_id: &HelperChannelId, + channel_id: &ChannelId, + transport: &T, capacity: NonZeroUsize, + query_id: QueryId, total_records: TotalRecords, // TODO track children for indeterminate senders - ) -> (Arc, Option) { + ) -> Arc> { assert!( total_records.is_specified(), "unspecified total records for {channel_id:?}" @@ -150,44 +162,64 @@ impl GatewaySenders { // TODO: raw entry API would be nice to have here but it's not exposed yet match self.inner.entry(channel_id.clone()) { - Entry::Occupied(entry) => (Arc::clone(entry.get()), None), + Entry::Occupied(entry) => Arc::clone(entry.get()), Entry::Vacant(entry) => { - // Spare buffer is not required when messages have uniform size and buffer is a - // multiple of that size. - const SPARE: usize = 0; - // a little trick - if number of records is indeterminate, set the capacity to one - // message. Any send will wake the stream reader then, effectively disabling - // buffering. This mode is clearly inefficient, so avoid using this mode. - let write_size = if total_records.is_indeterminate() { - NonZeroUsize::new(M::Size::USIZE).unwrap() - } else { - // capacity is defined in terms of number of elements, while sender wants bytes - // so perform the conversion here - capacity - .checked_mul( - NonZeroUsize::new(M::Size::USIZE) - .expect("Message size should be greater than 0"), - ) - .expect("capacity should not overflow") - }; - - let sender = Arc::new(GatewaySender::new( - channel_id.clone(), - OrderingSender::new(write_size, SPARE), - total_records, - )); + let sender = Self::new_sender::(capacity, channel_id.clone(), total_records); entry.insert(Arc::clone(&sender)); - ( - Arc::clone(&sender), - Some(GatewaySendStream { inner: sender }), - ) + tokio::spawn({ + let ChannelId { peer, gate } = channel_id.clone(); + let transport = transport.clone(); + let stream = GatewaySendStream { + inner: Arc::clone(&sender), + }; + async move { + // TODO(651): In the HTTP case we probably need more robust error handling here. + transport + .send(peer, (RouteId::Records, query_id, gate), stream) + .await + .expect("{channel_id:?} receiving end should be accepted by transport"); + } + }); + + sender } } } + + fn new_sender( + capacity: NonZeroUsize, + channel_id: ChannelId, + total_records: TotalRecords, + ) -> Arc> { + // Spare buffer is not required when messages have uniform size and buffer is a + // multiple of that size. + const SPARE: usize = 0; + // a little trick - if number of records is indeterminate, set the capacity to one + // message. Any send will wake the stream reader then, effectively disabling + // buffering. This mode is clearly inefficient, so avoid using this mode. + let write_size = if total_records.is_indeterminate() { + NonZeroUsize::new(M::Size::USIZE).unwrap() + } else { + // capacity is defined in terms of number of elements, while sender wants bytes + // so perform the conversion here + capacity + .checked_mul( + NonZeroUsize::new(M::Size::USIZE) + .expect("Message size should be greater than 0"), + ) + .expect("capacity should not overflow") + }; + + Arc::new(GatewaySender::new( + channel_id, + OrderingSender::new(write_size, SPARE), + total_records, + )) + } } -impl Stream for GatewaySendStream { +impl Stream for GatewaySendStream { type Item = Vec; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { diff --git a/ipa-core/src/helpers/gateway/stall_detection.rs b/ipa-core/src/helpers/gateway/stall_detection.rs index 654fbb11c..2e1f24267 100644 --- a/ipa-core/src/helpers/gateway/stall_detection.rs +++ b/ipa-core/src/helpers/gateway/stall_detection.rs @@ -73,11 +73,12 @@ mod gateway { use super::{receive, send, AtomicUsize, Debug, Formatter, ObserveState, Observed, Weak}; use crate::{ helpers::{ - gateway::{Gateway, State}, - GatewayConfig, HelperChannelId, Message, ReceivingEnd, Role, RoleAssignment, - SendingEnd, TotalRecords, TransportImpl, + gateway::{Gateway, ShardTransportImpl, State}, + GatewayConfig, HelperChannelId, Message, MpcMessage, MpcReceivingEnd, MpcTransportImpl, + Role, RoleAssignment, SendingEnd, ShardChannelId, ShardReceivingEnd, TotalRecords, }, protocol::QueryId, + sharding::ShardIndex, sync::Arc, }; @@ -105,13 +106,14 @@ mod gateway { query_id: QueryId, config: GatewayConfig, roles: RoleAssignment, - transport: TransportImpl, + mpc_transport: MpcTransportImpl, + shard_transport: ShardTransportImpl, ) -> Self { let version = Arc::new(AtomicUsize::default()); let r = Self::wrap( Arc::downgrade(&version), InstrumentedGateway { - gateway: Gateway::new(query_id, config, roles, transport), + gateway: Gateway::new(query_id, config, roles, mpc_transport, shard_transport), _sn: version, }, ); @@ -147,22 +149,50 @@ mod gateway { } #[must_use] - pub fn get_sender( + pub fn get_mpc_sender( &self, channel_id: &HelperChannelId, total_records: TotalRecords, - ) -> SendingEnd { + ) -> SendingEnd { Observed::wrap( Weak::clone(self.get_sn()), - self.inner().gateway.get_sender(channel_id, total_records), + self.inner() + .gateway + .get_mpc_sender(channel_id, total_records), + ) + } + + pub fn get_shard_sender( + &self, + channel_id: &ShardChannelId, + total_records: TotalRecords, + ) -> SendingEnd { + Observed::wrap( + Weak::clone(self.get_sn()), + self.inner + .gateway + .get_shard_sender(channel_id, total_records), ) } #[must_use] - pub fn get_receiver(&self, channel_id: &HelperChannelId) -> ReceivingEnd { + pub fn get_mpc_receiver( + &self, + channel_id: &HelperChannelId, + ) -> MpcReceivingEnd { Observed::wrap( Weak::clone(self.get_sn()), - self.inner().gateway.get_receiver(channel_id), + self.inner().gateway.get_mpc_receiver(channel_id), + ) + } + + pub fn get_shard_receiver( + &self, + channel_id: &ShardChannelId, + ) -> ShardReceivingEnd { + Observed::wrap( + Weak::clone(self.get_sn()), + self.inner().gateway.get_shard_receiver(channel_id), ) } @@ -175,17 +205,25 @@ mod gateway { } } - pub struct GatewayWaitingTasks { - senders_state: Option, - receivers_state: Option, + pub struct GatewayWaitingTasks { + mpc_send: Option, + mpc_recv: Option, + shard_send: Option, + shard_recv: Option, } - impl Debug for GatewayWaitingTasks { + impl Debug for GatewayWaitingTasks { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - if let Some(senders_state) = &self.senders_state { + if let Some(senders_state) = &self.mpc_send { write!(f, "\n{{{senders_state:?}\n}}")?; } - if let Some(receivers_state) = &self.receivers_state { + if let Some(receivers_state) = &self.mpc_recv { + write!(f, "\n{{{receivers_state:?}\n}}")?; + } + if let Some(senders_state) = &self.shard_send { + write!(f, "\n{{{senders_state:?}\n}}")?; + } + if let Some(receivers_state) = &self.shard_recv { write!(f, "\n{{{receivers_state:?}\n}}")?; } @@ -194,15 +232,27 @@ mod gateway { } impl ObserveState for Weak { - type State = GatewayWaitingTasks; + type State = GatewayWaitingTasks< + send::WaitingTasks, + receive::WaitingTasks, + send::WaitingTasks, + receive::WaitingTasks, + >; fn get_state(&self) -> Option { self.upgrade().and_then(|state| { - match (state.senders.get_state(), state.receivers.get_state()) { - (None, None) => None, - (senders_state, receivers_state) => Some(Self::State { - senders_state, - receivers_state, + match ( + state.mpc_senders.get_state(), + state.mpc_receivers.get_state(), + state.shard_senders.get_state(), + state.shard_receivers.get_state(), + ) { + (None, None, None, None) => None, + (mpc_send, mpc_recv, shard_send, shard_recv) => Some(Self::State { + mpc_send, + mpc_recv, + shard_send, + shard_recv, }), } }) @@ -214,19 +264,27 @@ mod receive { use std::{ collections::BTreeMap, fmt::{Debug, Formatter}, + pin::Pin, + task::{Context, Poll}, }; + use futures::Stream; + use super::{ObserveState, Observed}; use crate::{ helpers::{ error::Error, - gateway::{receive::GatewayReceivers, ReceivingEnd}, - HelperChannelId, Message, Role, + gateway::{ + receive::{GatewayReceivers, ShardReceiveStream, ShardReceivingEnd, UR}, + MpcReceivingEnd, + }, + ChannelId, Message, MpcMessage, Role, TransportIdentity, }, protocol::RecordId, + sharding::ShardIndex, }; - impl Observed> { + impl Observed> { delegate::delegate! { to { self.advance(); self.inner() } { #[inline] @@ -235,9 +293,18 @@ mod receive { } } - pub struct WaitingTasks(BTreeMap>); + impl Stream for Observed> { + type Item = as Stream>::Item; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.advance(); + Pin::new(&mut self.inner).poll_next(cx) + } + } - impl Debug for WaitingTasks { + pub struct WaitingTasks(BTreeMap, Vec>); + + impl Debug for WaitingTasks { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { for (channel, records) in &self.0 { write!( @@ -251,8 +318,8 @@ mod receive { } } - impl ObserveState for GatewayReceivers { - type State = WaitingTasks; + impl ObserveState for GatewayReceivers { + type State = WaitingTasks; fn get_state(&self) -> Option { let mut map = BTreeMap::default(); @@ -266,6 +333,23 @@ mod receive { (!map.is_empty()).then_some(WaitingTasks(map)) } } + + impl ObserveState for GatewayReceivers { + type State = WaitingTasks; + + fn get_state(&self) -> Option { + let mut map = BTreeMap::default(); + for entry in &self.inner { + let channel = entry.key(); + map.insert( + channel.clone(), + vec!["Shard receiver state is not implemented yet".to_string()], + ); + } + + (!map.is_empty()).then_some(WaitingTasks(map)) + } + } } mod send { @@ -280,23 +364,23 @@ mod send { helpers::{ error::Error, gateway::send::{GatewaySender, GatewaySenders}, - HelperChannelId, Message, Role, TotalRecords, + ChannelId, Message, TotalRecords, TransportIdentity, }, protocol::RecordId, }; - impl Observed> { + impl Observed> { delegate::delegate! { to { self.advance(); self.inner() } { #[inline] - pub async fn send>(&self, record_id: RecordId, msg: B) -> Result<(), Error>; + pub async fn send>(&self, record_id: RecordId, msg: B) -> Result<(), Error>; } } } - pub struct WaitingTasks(BTreeMap)>); + pub struct WaitingTasks(BTreeMap, (TotalRecords, Vec)>); - impl Debug for WaitingTasks { + impl Debug for WaitingTasks { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { for (channel, (total, records)) in &self.0 { write!( @@ -310,8 +394,8 @@ mod send { } } - impl ObserveState for GatewaySenders { - type State = WaitingTasks; + impl ObserveState for GatewaySenders { + type State = WaitingTasks; fn get_state(&self) -> Option { let mut state = BTreeMap::new(); @@ -327,7 +411,7 @@ mod send { } } - impl ObserveState for GatewaySender { + impl ObserveState for GatewaySender { type State = Vec; fn get_state(&self) -> Option { diff --git a/ipa-core/src/helpers/gateway/transport.rs b/ipa-core/src/helpers/gateway/transport.rs index 43840ce4a..a014ce0c7 100644 --- a/ipa-core/src/helpers/gateway/transport.rs +++ b/ipa-core/src/helpers/gateway/transport.rs @@ -3,15 +3,16 @@ use futures::Stream; use crate::{ helpers::{ - transport::routing::RouteId, NoResourceIdentifier, QueryIdBinding, Role, RoleAssignment, - RouteParams, StepBinding, Transport, TransportImpl, + transport::routing::RouteId, MpcTransportImpl, NoResourceIdentifier, QueryIdBinding, Role, + RoleAssignment, RouteParams, StepBinding, Transport, }, protocol::{step::Gate, QueryId}, + sharding::ShardIndex, }; #[derive(Debug, thiserror::Error)] #[error("Failed to send to {0:?}: {1:?}")] -pub struct SendToRoleError(Role, ::Error); +pub struct SendToRoleError(Role, ::Error); /// Transport adapter that resolves [`Role`] -> [`HelperIdentity`] mapping. As gateways created /// per query, it is not ambiguous. @@ -20,13 +21,19 @@ pub struct SendToRoleError(Role, ::Error); #[derive(Clone)] pub struct RoleResolvingTransport { pub(super) roles: RoleAssignment, - pub(super) inner: TransportImpl, + pub(super) inner: MpcTransportImpl, +} + +/// Set of transports used inside [`super::Gateway`]. +pub(super) struct Transports, S: Transport> { + pub mpc: M, + pub shard: S, } #[async_trait] impl Transport for RoleResolvingTransport { type Identity = Role; - type RecordsStream = ::RecordsStream; + type RecordsStream = ::RecordsStream; type Error = SendToRoleError; fn identity(&self) -> Role { diff --git a/ipa-core/src/helpers/mod.rs b/ipa-core/src/helpers/mod.rs index e6416b565..2d7ea2183 100644 --- a/ipa-core/src/helpers/mod.rs +++ b/ipa-core/src/helpers/mod.rs @@ -13,6 +13,7 @@ mod gateway; pub(crate) mod prss_protocol; pub mod stream; mod transport; + use std::ops::{Index, IndexMut}; /// to validate that transport can actually send streams of this type @@ -24,14 +25,17 @@ use serde::{Deserialize, Serialize, Serializer}; #[cfg(feature = "stall-detection")] mod gateway_exports { + use crate::helpers::{ gateway, gateway::{stall_detection::Observed, InstrumentedGateway}, }; pub type Gateway = Observed; - pub type SendingEnd = Observed>; - pub type ReceivingEnd = Observed>; + pub type SendingEnd = Observed>; + + pub type MpcReceivingEnd = Observed>; + pub type ShardReceivingEnd = Observed>; } #[cfg(not(feature = "stall-detection"))] @@ -39,15 +43,18 @@ mod gateway_exports { use crate::helpers::gateway; pub type Gateway = gateway::Gateway; - pub type SendingEnd = gateway::SendingEnd; - pub type ReceivingEnd = gateway::ReceivingEnd; + pub type SendingEnd = gateway::SendingEnd; + pub type MpcReceivingEnd = gateway::MpcReceivingEnd; + pub type ShardReceivingEnd = gateway::ShardReceivingEnd; } pub use gateway::GatewayConfig; // TODO: this type should only be available within infra. Right now several infra modules // are exposed at the root level. That makes it impossible to have a proper hierarchy here. -pub use gateway::{TransportError, TransportImpl}; -pub use gateway_exports::{Gateway, ReceivingEnd, SendingEnd}; +pub use gateway::{ + MpcTransportError, MpcTransportImpl, RoleResolvingTransport, ShardTransportImpl, +}; +pub use gateway_exports::{Gateway, MpcReceivingEnd, SendingEnd, ShardReceivingEnd}; pub use prss_protocol::negotiate as negotiate_prss; #[cfg(feature = "web-app")] pub use transport::WrappedAxumBodyStream; @@ -70,6 +77,7 @@ use crate::{ }, protocol::{step::Gate, RecordId}, secret_sharing::Sendable, + sharding::ShardIndex, }; // TODO work with ArrayLength only @@ -81,7 +89,7 @@ pub const MESSAGE_PAYLOAD_SIZE_BYTES: usize = MessagePayloadArrayLen::USIZE; /// represents a helper's role within an MPC protocol, which may be different per protocol. /// `HelperIdentity` will be established at startup and then never change. Components that want to /// resolve this identifier into something (Uri, encryption keys, etc) must consult configuration -#[derive(Copy, Clone, Eq, PartialEq, Hash, Deserialize)] +#[derive(Copy, Clone, Eq, PartialEq, Hash, PartialOrd, Ord, Deserialize)] #[serde(try_from = "usize")] pub struct HelperIdentity { id: u8, @@ -226,7 +234,6 @@ pub enum Role { #[derive(Clone, Debug, Serialize, Deserialize)] #[cfg_attr(test, derive(PartialEq, Eq))] -#[serde(transparent)] pub struct RoleAssignment { helper_roles: [HelperIdentity; 3], } @@ -397,7 +404,7 @@ impl TryFrom<[Role; 3]> for RoleAssignment { /// Combination of helper role and step that uniquely identifies a single channel of communication /// between two helpers. #[derive(Clone, Eq, PartialEq, Hash, Ord, PartialOrd)] -pub struct ChannelId { +pub struct ChannelId { /// Entity we are talking to through this channel. It can be a source or a destination. pub peer: I, // TODO: step could be either reference or owned value. references are convenient to use inside @@ -406,6 +413,7 @@ pub struct ChannelId { } pub type HelperChannelId = ChannelId; +pub type ShardChannelId = ChannelId; impl ChannelId { #[must_use] @@ -420,12 +428,21 @@ impl Debug for ChannelId { } } -/// Trait for messages sent between helpers. Everything needs to be serializable and safe to send. +/// Trait for messages that can be communicated over the network. +pub trait Message: Debug + Send + Serializable + 'static {} + +/// Trait for messages that may be sent between MPC helpers. Sending raw field values may be OK, +/// sending secret shares is most definitely not OK. +/// +/// This trait is not implemented for [`SecretShares`] types and there is a doctest inside [`Gateway`] +/// module that ensures compile errors are generated in this case. /// -/// Infrastructure's `Message` trait corresponds to IPA's `Sendable` trait. -pub trait Message: Debug + Send + Serializable + 'static + Sized {} +/// [`SecretShares`]: crate::secret_sharing::replicated::ReplicatedSecretSharing +/// [`Gateway`]: crate::helpers::gateway::Gateway::get_mpc_sender +pub trait MpcMessage: Message {} -impl Message for V {} +impl MpcMessage for V {} +impl Message for V {} impl Serializable for PublicKey { type Size = typenum::U32; @@ -442,7 +459,7 @@ impl Serializable for PublicKey { } } -impl Message for PublicKey {} +impl MpcMessage for PublicKey {} #[derive(Clone, Copy, Debug)] pub enum TotalRecords { diff --git a/ipa-core/src/helpers/prss_protocol.rs b/ipa-core/src/helpers/prss_protocol.rs index 348a36596..001ee9f24 100644 --- a/ipa-core/src/helpers/prss_protocol.rs +++ b/ipa-core/src/helpers/prss_protocol.rs @@ -37,10 +37,10 @@ pub async fn negotiate( let right_channel = ChannelId::new(gateway.role().peer(Direction::Right), step.clone()); let total_records = TotalRecords::from(1); - let left_sender = gateway.get_sender::(&left_channel, total_records); - let right_sender = gateway.get_sender::(&right_channel, total_records); - let left_receiver = gateway.get_receiver::(&left_channel); - let right_receiver = gateway.get_receiver::(&right_channel); + let left_sender = gateway.get_mpc_sender::(&left_channel, total_records); + let right_sender = gateway.get_mpc_sender::(&right_channel, total_records); + let left_receiver = gateway.get_mpc_receiver::(&left_channel); + let right_receiver = gateway.get_mpc_receiver::(&right_channel); // setup local prss endpoint let ep_setup = prss::Endpoint::prepare(rng); diff --git a/ipa-core/src/helpers/transport/in_memory/sharding.rs b/ipa-core/src/helpers/transport/in_memory/sharding.rs index 23175e375..7457711b4 100644 --- a/ipa-core/src/helpers/transport/in_memory/sharding.rs +++ b/ipa-core/src/helpers/transport/in_memory/sharding.rs @@ -64,6 +64,14 @@ impl InMemoryShardNetwork { Arc::downgrade(&self.shard_network[2][shard_id]), ] } + + pub fn reset(&self) { + for helper in &self.shard_network { + for shard in helper.iter() { + shard.reset(); + } + } + } } #[cfg(all(test, unit_test))] @@ -149,4 +157,29 @@ mod tests { assert!(h3.upgrade().is_none()); }); } + + #[test] + fn reset() { + async fn test_send(network: &InMemoryShardNetwork) { + let (_tx, rx) = mpsc::channel(1); + let src_shard = ShardIndex::FIRST; + let dst_shard = ShardIndex::from(1); + network + .transport(HelperIdentity::ONE, src_shard) + .send( + dst_shard, + (RouteId::Records, QueryId, Gate::default()), + ReceiverStream::new(rx), + ) + .await + .unwrap(); + } + + run(|| async { + let shard_network = InMemoryShardNetwork::with_shards(2); + test_send(&shard_network).await; + shard_network.reset(); + test_send(&shard_network).await; + }); + } } diff --git a/ipa-core/src/helpers/transport/mod.rs b/ipa-core/src/helpers/transport/mod.rs index 202d9fbad..8071085ae 100644 --- a/ipa-core/src/helpers/transport/mod.rs +++ b/ipa-core/src/helpers/transport/mod.rs @@ -1,4 +1,8 @@ -use std::{borrow::Borrow, fmt::Debug, hash::Hash}; +use std::{ + borrow::{Borrow, Cow}, + fmt::Debug, + hash::Hash, +}; use async_trait::async_trait; use futures::Stream; @@ -25,8 +29,8 @@ pub use receive::{LogErrors, ReceiveRecords}; #[cfg(feature = "web-app")] pub use stream::WrappedAxumBodyStream; pub use stream::{ - BodyStream, BytesStream, LengthDelimitedStream, RecordsStream, StreamCollection, StreamKey, - WrappedBoxBodyStream, + BodyStream, BytesStream, LengthDelimitedStream, RecordsStream, SingleRecordStream, + StreamCollection, StreamKey, WrappedBoxBodyStream, }; use crate::{ @@ -36,14 +40,30 @@ use crate::{ /// An identity of a peer that can be communicated with using [`Transport`]. There are currently two /// types of peers - helpers and shards. -pub trait Identity: Copy + Clone + Debug + PartialEq + Eq + Hash + Send + Sync + 'static {} +pub trait Identity: + Copy + Clone + Debug + PartialEq + Eq + PartialOrd + Ord + Hash + Send + Sync + 'static +{ + fn as_str<'a>(&self) -> Cow<'a, str>; +} -impl Identity for ShardIndex {} -impl Identity for HelperIdentity {} +impl Identity for ShardIndex { + fn as_str<'a>(&self) -> Cow<'a, str> { + Cow::Owned(self.to_string()) + } +} +impl Identity for HelperIdentity { + fn as_str<'a>(&self) -> Cow<'a, str> { + Cow::Owned(self.id.to_string()) + } +} /// Role is an identifier of helper peer, only valid within a given query. For every query, there /// exists a static mapping from role to helper identity. -impl Identity for Role {} +impl Identity for Role { + fn as_str<'a>(&self) -> Cow<'a, str> { + Cow::Borrowed(Role::as_static_str(self)) + } +} pub trait ResourceIdentifier: Sized {} pub trait QueryIdBinding: Sized diff --git a/ipa-core/src/helpers/transport/receive.rs b/ipa-core/src/helpers/transport/receive.rs index 15fa7d4d9..557a10fc9 100644 --- a/ipa-core/src/helpers/transport/receive.rs +++ b/ipa-core/src/helpers/transport/receive.rs @@ -89,7 +89,7 @@ impl ReceiveRecords /// Converts this into a stream that yields owned byte chunks. /// /// ## Panics - /// If inner stream yields an [`Err`] chunk. + /// If inner stream yields [`Err`] chunk. pub(crate) fn into_bytes_stream(self) -> impl Stream> { self.inner.map(Result::unwrap).map(Into::into) } diff --git a/ipa-core/src/helpers/transport/stream/input.rs b/ipa-core/src/helpers/transport/stream/input.rs index e091cac28..02b312795 100644 --- a/ipa-core/src/helpers/transport/stream/input.rs +++ b/ipa-core/src/helpers/transport/stream/input.rs @@ -115,12 +115,9 @@ impl BufDeque { self.read_bytes(T::Size::USIZE) .map(|bytes| T::deserialize_infallible(GenericArray::from_slice(&bytes))) } - /// Attempts to deserialize a single instance of `T` from the buffer. - /// Returns `None` if there is insufficient data available /// - /// ## Errors - /// Returns a deserialization error if `T` rejects the bytes from this buffer. + /// Returns `None` if there is insufficient data available, and an error if deserialization fails. fn try_read(&mut self) -> Option> { self.read_bytes(T::Size::USIZE) .map(|bytes| T::deserialize(GenericArray::from_slice(&bytes))) @@ -220,6 +217,8 @@ where phantom_data: PhantomData<(T, M)>, } +pub type SingleRecordStream = RecordsStream; + impl RecordsStream where S: BytesStream, diff --git a/ipa-core/src/helpers/transport/stream/mod.rs b/ipa-core/src/helpers/transport/stream/mod.rs index 053b6033c..17fe29e3a 100644 --- a/ipa-core/src/helpers/transport/stream/mod.rs +++ b/ipa-core/src/helpers/transport/stream/mod.rs @@ -12,7 +12,7 @@ pub use box_body::WrappedBoxBodyStream; use bytes::Bytes; pub use collection::{StreamCollection, StreamKey}; use futures::Stream; -pub use input::{LengthDelimitedStream, RecordsStream}; +pub use input::{LengthDelimitedStream, RecordsStream, SingleRecordStream}; use crate::error::BoxError; diff --git a/ipa-core/src/net/mod.rs b/ipa-core/src/net/mod.rs index 5d5c83e03..734512411 100644 --- a/ipa-core/src/net/mod.rs +++ b/ipa-core/src/net/mod.rs @@ -9,4 +9,4 @@ mod transport; pub use client::{ClientIdentity, MpcHelperClient}; pub use error::Error; pub use server::{MpcHelperServer, TracingSpanMaker}; -pub use transport::HttpTransport; +pub use transport::{HttpShardTransport, HttpTransport}; diff --git a/ipa-core/src/net/transport.rs b/ipa-core/src/net/transport.rs index 79a80bea7..786a4a492 100644 --- a/ipa-core/src/net/transport.rs +++ b/ipa-core/src/net/transport.rs @@ -20,10 +20,12 @@ use crate::{ }, net::{client::MpcHelperClient, error::Error, MpcHelperServer}, protocol::{step::Gate, QueryId}, + sharding::ShardIndex, sync::Arc, }; /// HTTP transport for IPA helper service. +/// TODO: rename to MPC pub struct HttpTransport { identity: HelperIdentity, clients: [MpcHelperClient; 3], @@ -33,6 +35,10 @@ pub struct HttpTransport { handler: Option, } +/// A stub for HTTP transport implementation, suitable for serviing inter-shard traffic +#[derive(Clone, Default)] +pub struct HttpShardTransport; + impl RouteParams for QueryConfig { type Params = String; @@ -224,6 +230,42 @@ impl Transport for Arc { } } +#[async_trait] +impl Transport for HttpShardTransport { + type Identity = ShardIndex; + type RecordsStream = ReceiveRecords; + type Error = (); + + fn identity(&self) -> Self::Identity { + unimplemented!() + } + + async fn send( + &self, + _dest: Self::Identity, + _route: R, + _data: D, + ) -> Result<(), Self::Error> + where + Option: From, + Option: From, + Q: QueryIdBinding, + S: StepBinding, + R: RouteParams, + D: Stream> + Send + 'static, + { + unimplemented!() + } + + fn receive>( + &self, + _from: Self::Identity, + _route: R, + ) -> Self::RecordsStream { + unimplemented!() + } +} + #[cfg(all(test, web_test))] mod tests { use std::{iter::zip, net::TcpListener, task::Poll}; @@ -323,7 +365,7 @@ mod tests { ); server.start_on(Some(socket), ()).await; - setup.connect(transport) + setup.connect(transport, HttpShardTransport) }, ), ) diff --git a/ipa-core/src/protocol/context/malicious.rs b/ipa-core/src/protocol/context/malicious.rs index ef03a0dcf..75900aeca 100644 --- a/ipa-core/src/protocol/context/malicious.rs +++ b/ipa-core/src/protocol/context/malicious.rs @@ -10,7 +10,10 @@ use ipa_macros::Step; use super::{UpgradeContext, UpgradeToMalicious}; use crate::{ error::Error, - helpers::{ChannelId, Gateway, Message, ReceivingEnd, Role, SendingEnd, TotalRecords}, + helpers::{ + ChannelId, Gateway, Message, MpcMessage, MpcReceivingEnd, Role, SendingEnd, + ShardReceivingEnd, TotalRecords, + }, protocol::{ basics::{ mul::malicious::Step::RandomnessForValidation, SecureMul, ShareKnownValue, @@ -32,7 +35,7 @@ use crate::{ ReplicatedSecretSharing, }, seq_join::SeqJoin, - sharding::NotSharded, + sharding::{NotSharded, ShardIndex}, sync::Arc, }; @@ -112,13 +115,21 @@ impl<'a> super::Context for Context<'a> { self.inner.prss_rng() } - fn send_channel(&self, role: Role) -> SendingEnd { + fn send_channel(&self, role: Role) -> SendingEnd { self.inner.send_channel(role) } - fn recv_channel(&self, role: Role) -> ReceivingEnd { + fn shard_send_channel(&self, dest_shard: ShardIndex) -> SendingEnd { + self.inner.shard_send_channel(dest_shard) + } + + fn recv_channel(&self, role: Role) -> MpcReceivingEnd { self.inner.recv_channel(role) } + + fn shard_recv_channel(&self, origin: ShardIndex) -> ShardReceivingEnd { + self.inner.shard_recv_channel(origin) + } } impl<'a> UpgradableContext for Context<'a> { @@ -326,16 +337,29 @@ impl<'a, F: ExtendableField> super::Context for Upgraded<'a, F> { ) } - fn send_channel(&self, role: Role) -> SendingEnd { + fn send_channel(&self, role: Role) -> SendingEnd { + self.inner + .gateway + .get_mpc_sender(&ChannelId::new(role, self.gate.clone()), self.total_records) + } + + fn shard_send_channel(&self, dest_shard: ShardIndex) -> SendingEnd { + self.inner.gateway.get_shard_sender( + &ChannelId::new(dest_shard, self.gate.clone()), + self.total_records, + ) + } + + fn recv_channel(&self, role: Role) -> MpcReceivingEnd { self.inner .gateway - .get_sender(&ChannelId::new(role, self.gate.clone()), self.total_records) + .get_mpc_receiver(&ChannelId::new(role, self.gate.clone())) } - fn recv_channel(&self, role: Role) -> ReceivingEnd { + fn shard_recv_channel(&self, origin: ShardIndex) -> ShardReceivingEnd { self.inner .gateway - .get_receiver(&ChannelId::new(role, self.gate.clone())) + .get_shard_receiver(&ChannelId::new(origin, self.gate.clone())) } } diff --git a/ipa-core/src/protocol/context/mod.rs b/ipa-core/src/protocol/context/mod.rs index 3bf6de8c7..1c0ef05bd 100644 --- a/ipa-core/src/protocol/context/mod.rs +++ b/ipa-core/src/protocol/context/mod.rs @@ -23,7 +23,10 @@ pub type ShardedSemiHonestContext<'a> = semi_honest::Context<'a, Sharded>; use crate::{ error::Error, - helpers::{ChannelId, Gateway, Message, ReceivingEnd, Role, SendingEnd, TotalRecords}, + helpers::{ + ChannelId, Gateway, Message, MpcMessage, MpcReceivingEnd, Role, SendingEnd, + ShardReceivingEnd, TotalRecords, + }, protocol::{ basics::ZeroPositions, prss::Endpoint as PrssEndpoint, @@ -87,8 +90,26 @@ pub trait Context: Clone + Send + Sync + SeqJoin { InstrumentedSequentialSharedRandomness, ); - fn send_channel(&self, role: Role) -> SendingEnd; - fn recv_channel(&self, role: Role) -> ReceivingEnd; + /// Open a communication channel to an MPC peer. This channel can be requested multiple times + /// and this method is safe to use in multi-threaded environments. + fn send_channel(&self, role: Role) -> SendingEnd; + + /// Open a communication channel to another shard within the same MPC helper. Similarly to + /// [`Self::send_channel`], it can be requested more than once for the same channel and from + /// multiple threads, but it should not be required. See [`Self::shard_recv_channel`]. + fn shard_send_channel(&self, dest_shard: ShardIndex) -> SendingEnd; + + /// Requests data to be received from another MPC helper. Receive requests [`MpcReceivingEnd::receive`] + /// can be issued from multiple threads. + fn recv_channel(&self, role: Role) -> MpcReceivingEnd; + + /// Request a stream to be received from a peer shard within the same MPC helper. This method + /// can be called only once per communication channel. + /// + /// ## Panics + /// If called more than once for the same origin and on context instance, narrowed to the same + /// [`Self::gate`]. + fn shard_recv_channel(&self, origin: ShardIndex) -> ShardReceivingEnd; } pub trait UpgradableContext: Context { @@ -252,16 +273,29 @@ impl<'a, B: ShardBinding> Context for Base<'a, B> { ) } - fn send_channel(&self, role: Role) -> SendingEnd { + fn send_channel(&self, role: Role) -> SendingEnd { + self.inner + .gateway + .get_mpc_sender(&ChannelId::new(role, self.gate.clone()), self.total_records) + } + + fn shard_send_channel(&self, dest_shard: ShardIndex) -> SendingEnd { + self.inner.gateway.get_shard_sender( + &ChannelId::new(dest_shard, self.gate.clone()), + self.total_records, + ) + } + + fn recv_channel(&self, role: Role) -> MpcReceivingEnd { self.inner .gateway - .get_sender(&ChannelId::new(role, self.gate.clone()), self.total_records) + .get_mpc_receiver(&ChannelId::new(role, self.gate.clone())) } - fn recv_channel(&self, role: Role) -> ReceivingEnd { + fn shard_recv_channel(&self, origin: ShardIndex) -> ShardReceivingEnd { self.inner .gateway - .get_receiver(&ChannelId::new(role, self.gate.clone())) + .get_shard_receiver(&ChannelId::new(origin, self.gate.clone())) } } diff --git a/ipa-core/src/protocol/context/semi_honest.rs b/ipa-core/src/protocol/context/semi_honest.rs index c742c02e5..b08812968 100644 --- a/ipa-core/src/protocol/context/semi_honest.rs +++ b/ipa-core/src/protocol/context/semi_honest.rs @@ -11,7 +11,10 @@ use ipa_macros::Step; use super::{Context as SuperContext, UpgradeContext, UpgradeToMalicious}; use crate::{ error::Error, - helpers::{Gateway, Message, ReceivingEnd, Role, SendingEnd, TotalRecords}, + helpers::{ + Gateway, Message, MpcMessage, MpcReceivingEnd, Role, SendingEnd, ShardReceivingEnd, + TotalRecords, + }, protocol::{ basics::{ShareKnownValue, ZeroPositions}, context::{ @@ -117,13 +120,21 @@ impl<'a, B: ShardBinding> super::Context for Context<'a, B> { self.inner.prss_rng() } - fn send_channel(&self, role: Role) -> SendingEnd { + fn send_channel(&self, role: Role) -> SendingEnd { self.inner.send_channel(role) } - fn recv_channel(&self, role: Role) -> ReceivingEnd { + fn shard_send_channel(&self, dest_shard: ShardIndex) -> SendingEnd { + self.inner.shard_send_channel(dest_shard) + } + + fn recv_channel(&self, role: Role) -> MpcReceivingEnd { self.inner.recv_channel(role) } + + fn shard_recv_channel(&self, origin: ShardIndex) -> ShardReceivingEnd { + self.inner.shard_recv_channel(origin) + } } impl<'a, B: ShardBinding> UpgradableContext for Context<'a, B> { @@ -201,13 +212,21 @@ impl<'a, B: ShardBinding, F: ExtendableField> super::Context for Upgraded<'a, B, self.inner.prss_rng() } - fn send_channel(&self, role: Role) -> SendingEnd { + fn send_channel(&self, role: Role) -> SendingEnd { self.inner.send_channel(role) } - fn recv_channel(&self, role: Role) -> ReceivingEnd { + fn shard_send_channel(&self, dest_shard: ShardIndex) -> SendingEnd { + self.inner.shard_send_channel(dest_shard) + } + + fn recv_channel(&self, role: Role) -> MpcReceivingEnd { self.inner.recv_channel(role) } + + fn shard_recv_channel(&self, origin: ShardIndex) -> ShardReceivingEnd { + self.inner.shard_recv_channel(origin) + } } impl<'a, B: ShardBinding, F: ExtendableField> SeqJoin for Upgraded<'a, B, F> { diff --git a/ipa-core/src/protocol/ipa_prf/shuffle/base.rs b/ipa-core/src/protocol/ipa_prf/shuffle/base.rs index 7f0b92283..ec752f7ac 100644 --- a/ipa-core/src/protocol/ipa_prf/shuffle/base.rs +++ b/ipa-core/src/protocol/ipa_prf/shuffle/base.rs @@ -6,7 +6,7 @@ use rand::{distributions::Standard, prelude::Distribution, seq::SliceRandom, Rng use crate::{ error::Error, - helpers::{Direction, ReceivingEnd, Role}, + helpers::{Direction, MpcReceivingEnd, Role}, protocol::{context::Context, RecordId}, secret_sharing::{ replicated::{semi_honest::AdditiveShare, ReplicatedSecretSharing}, @@ -335,7 +335,7 @@ where S: SharedValue, { let role = ctx.role().peer(direction); - let receive_channel: ReceivingEnd = ctx + let receive_channel: MpcReceivingEnd = ctx .narrow(step) .set_total_records(batch_size) .recv_channel(role); diff --git a/ipa-core/src/query/processor.rs b/ipa-core/src/query/processor.rs index a003e95ac..00723fcee 100644 --- a/ipa-core/src/query/processor.rs +++ b/ipa-core/src/query/processor.rs @@ -9,7 +9,8 @@ use crate::{ error::Error as ProtocolError, helpers::{ query::{PrepareQuery, QueryConfig, QueryInput}, - Gateway, GatewayConfig, Role, RoleAssignment, Transport, TransportError, TransportImpl, + Gateway, GatewayConfig, MpcTransportError, MpcTransportImpl, Role, RoleAssignment, + ShardTransportImpl, Transport, }, hpke::{KeyPair, KeyRegistry}, protocol::QueryId, @@ -57,7 +58,7 @@ pub enum NewQueryError { #[error(transparent)] State(#[from] StateError), #[error(transparent)] - Transport(#[from] TransportError), + MpcTransport(#[from] MpcTransportError), } #[derive(thiserror::Error, Debug)] @@ -132,7 +133,7 @@ impl Processor { #[allow(clippy::missing_panics_doc)] pub async fn new_query( &self, - transport: TransportImpl, + transport: MpcTransportImpl, req: QueryConfig, ) -> Result { let query_id = QueryId; @@ -158,7 +159,7 @@ impl Processor { transport.send(right, prepare_request.clone(), stream::empty()), ) .await - .map_err(NewQueryError::Transport)?; + .map_err(NewQueryError::MpcTransport)?; handle.set_state(QueryState::AwaitingInputs(query_id, req, roles))?; @@ -176,7 +177,7 @@ impl Processor { /// if query is already running or this helper cannot be a follower in it pub fn prepare( &self, - transport: &TransportImpl, + transport: &MpcTransportImpl, req: PrepareQuery, ) -> Result<(), PrepareQueryError> { let my_role = req.roles.role(transport.identity()); @@ -207,7 +208,8 @@ impl Processor { /// If failed to obtain exclusive access to the query collection. pub fn receive_inputs( &self, - transport: TransportImpl, + mpc_transport: MpcTransportImpl, + shard_transport: ShardTransportImpl, input: QueryInput, ) -> Result<(), QueryInputError> { let mut queries = self.queries.inner.lock().unwrap(); @@ -223,7 +225,8 @@ impl Processor { query_id, GatewayConfig::from(&config), role_assignment, - transport, + mpc_transport, + shard_transport, ); queries.insert( input.query_id, @@ -443,7 +446,7 @@ mod tests { assert!(matches!( p0.new_query(t0, request).await.unwrap_err(), - NewQueryError::Transport(_) + NewQueryError::MpcTransport(_) )); } @@ -465,7 +468,7 @@ mod tests { assert!(matches!( p0.new_query(t0, request).await.unwrap_err(), - NewQueryError::Transport(_) + NewQueryError::MpcTransport(_) )); } diff --git a/ipa-core/src/test_fixture/app.rs b/ipa-core/src/test_fixture/app.rs index 96d09fe59..86ab7d00b 100644 --- a/ipa-core/src/test_fixture/app.rs +++ b/ipa-core/src/test_fixture/app.rs @@ -7,7 +7,7 @@ use crate::{ ff::Serializable, helpers::{ query::{QueryConfig, QueryInput}, - ApiError, InMemoryMpcNetwork, + ApiError, InMemoryMpcNetwork, InMemoryShardNetwork, Transport, }, protocol::QueryId, query::QueryStatus, @@ -49,7 +49,8 @@ where /// [`TestWorld`]: crate::test_fixture::TestWorld pub struct TestApp { drivers: [HelperApp; 3], - network: InMemoryMpcNetwork, + mpc_network: InMemoryMpcNetwork, + shard_network: InMemoryShardNetwork, } fn unzip_tuple_array(input: [(T, U); 3]) -> ([T; 3], [U; 3]) { @@ -61,18 +62,23 @@ impl Default for TestApp { fn default() -> Self { let (setup, handlers) = unzip_tuple_array(array::from_fn(|_| AppSetup::new())); - let network = InMemoryMpcNetwork::new(handlers.map(Some)); - let drivers = network + let mpc_network = InMemoryMpcNetwork::new(handlers.map(Some)); + let shard_network = InMemoryShardNetwork::with_shards(1); + let drivers = mpc_network .transports() .iter() .zip(setup) - .map(|(t, s)| s.connect(Clone::clone(t))) + .map(|(t, s)| s.connect(Clone::clone(t), shard_network.transport(t.identity(), 0))) .collect::>() .try_into() - .map_err(|_| "infallible") + .ok() .unwrap(); - Self { drivers, network } + Self { + drivers, + mpc_network, + shard_network, + } } } @@ -131,7 +137,8 @@ impl TestApp { pub async fn complete_query(&self, query_id: QueryId) -> Result<[Vec; 3], ApiError> { let results = try_join3_array([0, 1, 2].map(|i| self.drivers[i].complete_query(query_id))).await; - self.network.reset(); + self.mpc_network.reset(); + self.shard_network.reset(); results } diff --git a/ipa-core/src/test_fixture/mod.rs b/ipa-core/src/test_fixture/mod.rs index e54e5eba4..e79bff05d 100644 --- a/ipa-core/src/test_fixture/mod.rs +++ b/ipa-core/src/test_fixture/mod.rs @@ -25,7 +25,7 @@ use rand::{distributions::Standard, prelude::Distribution, rngs::mock::StepRng}; use rand_core::{CryptoRng, RngCore}; pub use sharing::{get_bits, into_bits, Reconstruct, ReconstructArr}; #[cfg(feature = "in-memory-infra")] -pub use world::{Runner, TestExecutionStep, TestWorld, TestWorldConfig}; +pub use world::{Runner, TestExecutionStep, TestWorld, TestWorldConfig, WithShards}; use crate::{ ff::{Field, U128Conversions}, diff --git a/ipa-core/src/test_fixture/world.rs b/ipa-core/src/test_fixture/world.rs index 3cd252fae..16afd6f76 100644 --- a/ipa-core/src/test_fixture/world.rs +++ b/ipa-core/src/test_fixture/world.rs @@ -11,7 +11,7 @@ use tracing::{Instrument, Level, Span}; use crate::{ helpers::{ Gateway, GatewayConfig, HelperIdentity, InMemoryMpcNetwork, InMemoryShardNetwork, - InMemoryTransport, Role, RoleAssignment, + InMemoryTransport, Role, RoleAssignment, Transport, }, protocol::{ context::{ @@ -140,6 +140,7 @@ impl WithShards { /// for a single shard. /// /// It uses Round-robin strategy to distribute [`A`] across [`SHARDS`] + #[must_use] pub fn shard(input: Vec) -> [Vec; SHARDS] { let mut r: [_; SHARDS] = from_fn(|_| Vec::new()); for (i, share) in input.into_iter().enumerate() { @@ -539,14 +540,23 @@ impl ShardWorld { let participants = make_participants(&mut StdRng::seed_from_u64(config.seed + shard_seed)); let network = InMemoryMpcNetwork::default(); - let mut gateways = network.transports().map(|t| { - Gateway::new( - QueryId, - config.gateway_config, - config.role_assignment().clone(), - t, - ) - }); + let mut gateways: [_; 3] = network + .transports() + .iter() + .zip(transports.iter()) + .map(|(mpc, shard)| { + Gateway::new( + QueryId, + config.gateway_config, + config.role_assignment().clone(), + Transport::clone_ref(mpc), + Transport::clone_ref(shard), + ) + }) + .collect::>() + .try_into() + .ok() + .unwrap(); // The name for `g` is too complicated and depends on features enabled #[allow(clippy::redundant_closure_for_method_calls)] From b16da76add20d2ec683dccb16b60a79cf4d1f1f6 Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Sat, 13 Apr 2024 17:07:28 -0400 Subject: [PATCH 2/4] Update ipa-core/src/helpers/gateway/mod.rs Co-authored-by: Andy Leiserson --- ipa-core/src/helpers/gateway/mod.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/ipa-core/src/helpers/gateway/mod.rs b/ipa-core/src/helpers/gateway/mod.rs index 740d4fefd..2abf70858 100644 --- a/ipa-core/src/helpers/gateway/mod.rs +++ b/ipa-core/src/helpers/gateway/mod.rs @@ -40,9 +40,8 @@ pub type MpcTransportImpl = TransportImpl; pub type ShardTransportImpl = TransportImpl; #[cfg(feature = "real-world-infra")] -type TransportImpl = crate::sync::Arc; #[cfg(feature = "real-world-infra")] -pub type MpcTransportImpl = TransportImpl; +pub type MpcTransportImpl = crate::sync::Arc; #[cfg(feature = "real-world-infra")] pub type ShardTransportImpl = crate::net::HttpShardTransport; From 99f708b562d84e87f19c9cfdad7fddf9bd1eae30 Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Sat, 13 Apr 2024 17:07:42 -0400 Subject: [PATCH 3/4] Feedback --- ipa-core/src/helpers/transport/mod.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ipa-core/src/helpers/transport/mod.rs b/ipa-core/src/helpers/transport/mod.rs index 8071085ae..f367ed44c 100644 --- a/ipa-core/src/helpers/transport/mod.rs +++ b/ipa-core/src/helpers/transport/mod.rs @@ -43,16 +43,16 @@ use crate::{ pub trait Identity: Copy + Clone + Debug + PartialEq + Eq + PartialOrd + Ord + Hash + Send + Sync + 'static { - fn as_str<'a>(&self) -> Cow<'a, str>; + fn as_str(&self) -> Cow<'static, str>; } impl Identity for ShardIndex { - fn as_str<'a>(&self) -> Cow<'a, str> { + fn as_str(&self) -> Cow<'static, str> { Cow::Owned(self.to_string()) } } impl Identity for HelperIdentity { - fn as_str<'a>(&self) -> Cow<'a, str> { + fn as_str(&self) -> Cow<'static, str> { Cow::Owned(self.id.to_string()) } } @@ -60,7 +60,7 @@ impl Identity for HelperIdentity { /// Role is an identifier of helper peer, only valid within a given query. For every query, there /// exists a static mapping from role to helper identity. impl Identity for Role { - fn as_str<'a>(&self) -> Cow<'a, str> { + fn as_str(&self) -> Cow<'static, str> { Cow::Borrowed(Role::as_static_str(self)) } } From 7e09115d575c475a8335795f43398157c60c92cf Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Mon, 15 Apr 2024 22:13:18 -0700 Subject: [PATCH 4/4] Add a comment explaining the Mutex inside the receiver stream --- ipa-core/src/helpers/gateway/receive.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/ipa-core/src/helpers/gateway/receive.rs b/ipa-core/src/helpers/gateway/receive.rs index ad2ccbf61..c37efd279 100644 --- a/ipa-core/src/helpers/gateway/receive.rs +++ b/ipa-core/src/helpers/gateway/receive.rs @@ -52,6 +52,10 @@ pub type UR = UnorderedReceiver< /// Stream of records received from a peer shard. #[derive(Clone)] pub struct ShardReceiveStream( + /// Using a mutex here may not be necessary - there is always a single caller that polls it, + /// and there may be an observer from stall detection that wants to know the state of it. + /// There could be a better way to share the state and make sure the owning reference is stored + /// inside the map of receivers. pub(super) Arc::RecordsStream>>, );