From 3c1fdade899c7b4f51fb9f8aa5be70ecc288b09c Mon Sep 17 00:00:00 2001
From: Peter Nose
Date: Tue, 24 Sep 2024 11:47:26 +0200
Subject: [PATCH 1/2] keymanager/src/client: Remove remote nodes from the RPC client

Removing the list of remote nodes from the client allows the caller to
specify which node the call should be delegated to, without needing to
update the global list first. This will be useful once multiple calls
can be made in parallel.
---
 keymanager/src/churp/handler.rs            |  14 +--
 keymanager/src/client/interface.rs         |  32 ++++--
 keymanager/src/client/mock.rs              |   6 ++
 keymanager/src/client/remote.rs            |  33 +++---
 keymanager/src/runtime/secrets.rs          |   1 -
 keymanager/src/secrets/provider.rs         |  14 ++-
 runtime/src/enclave_rpc/client.rs          | 117 +++++++++------------
 tests/runtimes/simple-keyvalue/src/main.rs |   1 -
 8 files changed, 116 insertions(+), 102 deletions(-)
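Illustration (not part of the diff): with this change the target node
travels with each call instead of living in shared client state behind
set_nodes(). A minimal sketch of the new call pattern, using only names
from this patch (imports and setup elided; the helper name is invented):

    // Before: client.set_nodes(vec![node_id]) mutated the shared list,
    // and the subsequent call implicitly read it. After: the nodes are
    // an explicit per-call argument, so concurrent calls cannot race.
    async fn replicate_from(
        client: &dyn KeyManagerClient,
        generation: u64,
        node_id: PublicKey,
    ) -> Result<Secret, KeyManagerError> {
        // An empty vector keeps the old "any suitable node" behaviour.
        client
            .replicate_master_secret(generation, vec![node_id])
            .await
    }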
diff --git a/keymanager/src/churp/handler.rs b/keymanager/src/churp/handler.rs
index b8e0752a493..e5513a97145 100644
--- a/keymanager/src/churp/handler.rs
+++ b/keymanager/src/churp/handler.rs
@@ -620,11 +620,13 @@ impl Instance {
         }
 
         // Fetch from the remote node.
-        client.set_nodes(vec![node_id]);
-
         if handoff.needs_verification_matrix()? {
             // The remote verification matrix needs to be verified.
-            let vm = block_on(client.churp_verification_matrix(self.churp_id, status.handoff))?;
+            let vm = block_on(client.churp_verification_matrix(
+                self.churp_id,
+                status.handoff,
+                vec![node_id],
+            ))?;
             let checksum = self.checksum_verification_matrix_bytes(&vm, status.handoff);
             let status_checksum = status.checksum.ok_or(Error::InvalidHandoff)?; // Should never happen.
             if checksum != status_checksum {
@@ -640,6 +642,7 @@ impl Instance {
             self.churp_id,
             status.next_handoff,
             self.node_id,
+            vec![node_id],
         ))?;
         let point = scalar_from_bytes(&point).ok_or(Error::PointDecodingFailed)?;
 
@@ -669,11 +672,11 @@ impl Instance {
         }
 
         // Fetch from the remote node.
-        client.set_nodes(vec![node_id]);
         let point = block_on(client.churp_share_distribution_point(
             self.churp_id,
             status.next_handoff,
             self.node_id,
+            vec![node_id],
         ))?;
         let point = scalar_from_bytes(&point).ok_or(Error::PointDecodingFailed)?;
 
@@ -706,11 +709,11 @@ impl Instance {
         }
 
         // Fetch from the remote node.
-        client.set_nodes(vec![node_id]);
         let share = block_on(client.churp_bivariate_share(
             self.churp_id,
             status.next_handoff,
             self.node_id,
+            vec![node_id],
         ))?;
 
         // The remote verification matrix needs to be verified.
@@ -1131,7 +1134,6 @@ impl Instance {
             self.consensus_verifier.clone(),
             self.identity.clone(),
             1, // Not used, doesn't matter.
-            vec![],
         );
 
         Ok(client)
diff --git a/keymanager/src/client/interface.rs b/keymanager/src/client/interface.rs
index 8b287d3790f..33d4cfcb8e4 100644
--- a/keymanager/src/client/interface.rs
+++ b/keymanager/src/client/interface.rs
@@ -69,17 +69,22 @@ pub trait KeyManagerClient: Send + Sync {
     async fn replicate_master_secret(
         &self,
         generation: u64,
+        nodes: Vec<PublicKey>,
     ) -> Result<Secret, KeyManagerError>;
 
     /// Get a copy of the ephemeral secret for replication.
-    async fn replicate_ephemeral_secret(&self, epoch: EpochTime)
-        -> Result<Secret, KeyManagerError>;
+    async fn replicate_ephemeral_secret(
+        &self,
+        epoch: EpochTime,
+        nodes: Vec<PublicKey>,
+    ) -> Result<Secret, KeyManagerError>;
 
     /// Returns the verification matrix for the given handoff.
     async fn churp_verification_matrix(
         &self,
         churp_id: u8,
         epoch: EpochTime,
+        nodes: Vec<PublicKey>,
     ) -> Result<Vec<u8>, KeyManagerError>;
 
     /// Returns a switch point for the share reduction phase
@@ -89,6 +94,7 @@ pub trait KeyManagerClient: Send + Sync {
         churp_id: u8,
         epoch: EpochTime,
         node_id: PublicKey,
+        nodes: Vec<PublicKey>,
     ) -> Result<Vec<u8>, KeyManagerError>;
 
     /// Returns a switch point for the share distribution phase
@@ -98,6 +104,7 @@ pub trait KeyManagerClient: Send + Sync {
         churp_id: u8,
         epoch: EpochTime,
         node_id: PublicKey,
+        nodes: Vec<PublicKey>,
     ) -> Result<Vec<u8>, KeyManagerError>;
 
     /// Returns a bivariate share for the given handoff.
@@ -106,6 +113,7 @@ pub trait KeyManagerClient: Send + Sync {
         churp_id: u8,
         epoch: EpochTime,
         node_id: PublicKey,
+        nodes: Vec<PublicKey>,
     ) -> Result<EncodedVerifiableSecretShare, KeyManagerError>;
 
     /// Returns state key.
@@ -165,23 +173,26 @@ impl<T: ?Sized + KeyManagerClient> KeyManagerClient for Arc<T> {
     async fn replicate_master_secret(
         &self,
         generation: u64,
+        nodes: Vec<PublicKey>,
     ) -> Result<Secret, KeyManagerError> {
-        KeyManagerClient::replicate_master_secret(&**self, generation).await
+        KeyManagerClient::replicate_master_secret(&**self, generation, nodes).await
     }
 
     async fn replicate_ephemeral_secret(
         &self,
         epoch: EpochTime,
+        nodes: Vec<PublicKey>,
     ) -> Result<Secret, KeyManagerError> {
-        KeyManagerClient::replicate_ephemeral_secret(&**self, epoch).await
+        KeyManagerClient::replicate_ephemeral_secret(&**self, epoch, nodes).await
    }
 
     async fn churp_verification_matrix(
         &self,
         churp_id: u8,
         epoch: EpochTime,
+        nodes: Vec<PublicKey>,
     ) -> Result<Vec<u8>, KeyManagerError> {
-        KeyManagerClient::churp_verification_matrix(&**self, churp_id, epoch).await
+        KeyManagerClient::churp_verification_matrix(&**self, churp_id, epoch, nodes).await
     }
 
     async fn churp_share_reduction_point(
@@ -189,8 +200,10 @@ impl<T: ?Sized + KeyManagerClient> KeyManagerClient for Arc<T> {
         churp_id: u8,
         epoch: EpochTime,
         node_id: PublicKey,
+        nodes: Vec<PublicKey>,
     ) -> Result<Vec<u8>, KeyManagerError> {
-        KeyManagerClient::churp_share_reduction_point(&**self, churp_id, epoch, node_id).await
+        KeyManagerClient::churp_share_reduction_point(&**self, churp_id, epoch, node_id, nodes)
+            .await
     }
 
     async fn churp_share_distribution_point(
@@ -198,8 +211,10 @@ impl<T: ?Sized + KeyManagerClient> KeyManagerClient for Arc<T> {
         churp_id: u8,
         epoch: EpochTime,
         node_id: PublicKey,
+        nodes: Vec<PublicKey>,
     ) -> Result<Vec<u8>, KeyManagerError> {
-        KeyManagerClient::churp_share_distribution_point(&**self, churp_id, epoch, node_id).await
+        KeyManagerClient::churp_share_distribution_point(&**self, churp_id, epoch, node_id, nodes)
+            .await
     }
 
     async fn churp_bivariate_share(
@@ -207,8 +222,9 @@ impl<T: ?Sized + KeyManagerClient> KeyManagerClient for Arc<T> {
         churp_id: u8,
         epoch: EpochTime,
         node_id: PublicKey,
+        nodes: Vec<PublicKey>,
     ) -> Result<EncodedVerifiableSecretShare, KeyManagerError> {
-        KeyManagerClient::churp_bivariate_share(&**self, churp_id, epoch, node_id).await
+        KeyManagerClient::churp_bivariate_share(&**self, churp_id, epoch, node_id, nodes).await
     }
 
     async fn churp_state_key(
diff --git a/keymanager/src/client/mock.rs b/keymanager/src/client/mock.rs
index c8c58974d57..becb8fe7ff1 100644
--- a/keymanager/src/client/mock.rs
+++ b/keymanager/src/client/mock.rs
@@ -109,6 +109,7 @@ impl KeyManagerClient for MockClient {
     async fn replicate_master_secret(
         &self,
         _generation: u64,
+        _nodes: Vec<PublicKey>,
     ) -> Result<Secret, KeyManagerError> {
         unimplemented!();
     }
@@ -116,6 +117,7 @@ impl KeyManagerClient for MockClient {
     async fn replicate_ephemeral_secret(
         &self,
         _epoch: EpochTime,
+        _nodes: Vec<PublicKey>,
     ) -> Result<Secret, KeyManagerError> {
         unimplemented!();
     }
@@ -124,6 +126,7 @@ impl KeyManagerClient for MockClient {
         &self,
         _churp_id: u8,
         _epoch: EpochTime,
+        _nodes: Vec<PublicKey>,
     ) -> Result<Vec<u8>, KeyManagerError> {
         unimplemented!();
     }
@@ -133,6 +136,7 @@ impl KeyManagerClient for MockClient {
         _churp_id: u8,
         _epoch: EpochTime,
         _node_id: PublicKey,
+        _nodes: Vec<PublicKey>,
     ) -> Result<Vec<u8>, KeyManagerError> {
         unimplemented!();
     }
@@ -142,6 +146,7 @@ impl KeyManagerClient for MockClient {
         _churp_id: u8,
         _epoch: EpochTime,
         _node_id: PublicKey,
+        _nodes: Vec<PublicKey>,
     ) -> Result<Vec<u8>, KeyManagerError> {
         unimplemented!();
     }
@@ -151,6 +156,7 @@ impl KeyManagerClient for MockClient {
         _churp_id: u8,
         _epoch: EpochTime,
         _node_id: PublicKey,
+        _nodes: Vec<PublicKey>,
     ) -> Result<EncodedVerifiableSecretShare, KeyManagerError> {
         unimplemented!();
     }
diff --git a/keymanager/src/client/remote.rs b/keymanager/src/client/remote.rs
index baebfe37f1a..b13704d42c0 100644
--- a/keymanager/src/client/remote.rs
+++ b/keymanager/src/client/remote.rs
@@ -15,7 +15,7 @@ use rand::{prelude::SliceRandom, rngs::OsRng};
 
 use oasis_core_runtime::{
     common::{
-        crypto::signature::{self, PublicKey},
+        crypto::signature::PublicKey,
         namespace::{Namespace, NAMESPACE_SIZE},
         sgx::{EnclaveIdentity, QuotePolicy},
     },
@@ -123,7 +123,6 @@ impl RemoteClient {
         consensus_verifier: Arc<dyn Verifier>,
         identity: Arc<Identity>,
         keys_cache_sizes: usize,
-        nodes: Vec<signature::PublicKey>,
     ) -> Self {
         Self::new(
             runtime_id,
@@ -136,7 +135,6 @@ impl RemoteClient {
                     .remote_runtime_id(km_runtime_id),
                 protocol,
                 KEY_MANAGER_ENDPOINT,
-                nodes,
             ),
             consensus_verifier,
             keys_cache_sizes,
@@ -156,7 +154,6 @@ impl RemoteClient {
         identity: Arc<Identity>,
         keys_cache_sizes: usize,
         signers: TrustedSigners,
-        nodes: Vec<signature::PublicKey>,
     ) -> Self {
         // When using a non-empty policy signer set we set enclaves to an empty set so until we get
         // a policy we will not accept any enclave identities (as we don't know what they should
@@ -187,7 +184,6 @@ impl RemoteClient {
             consensus_verifier,
             identity,
             keys_cache_sizes,
-            nodes,
         )
     }
 
@@ -225,11 +221,6 @@ impl RemoteClient {
         self.rpc_client.update_quote_policy(policy);
     }
 
-    /// Set allowed key manager nodes.
-    pub fn set_nodes(&self, nodes: Vec<signature::PublicKey>) {
-        self.rpc_client.update_nodes(nodes);
-    }
-
     fn verify_public_key(
         &self,
         key: &SignedPublicKey,
@@ -268,11 +259,6 @@ impl RemoteClient {
         }
 
         // Fetch key share from the current node.
-        self.rpc_client
-            .update_nodes_async(vec![node_id])
-            .await
-            .map_err(|err| KeyManagerError::Other(err.into()))?;
-
         let response = self
             .rpc_client
             .secure_call(
@@ -284,6 +270,7 @@ impl RemoteClient {
                     key_runtime_id: self.runtime_id,
                     key_id,
                 },
+                vec![node_id],
             )
             .await
             .into_result_with_feedback()
@@ -386,6 +373,7 @@ impl KeyManagerClient for RemoteClient {
                     key_pair_id,
                     generation,
                 },
+                vec![],
             )
             .await
             .into_result_with_feedback()
@@ -436,6 +424,7 @@ impl KeyManagerClient for RemoteClient {
                     key_pair_id,
                     generation,
                 },
+                vec![],
             )
             .await
             .into_result_with_feedback()
@@ -484,6 +473,7 @@ impl KeyManagerClient for RemoteClient {
                     key_pair_id,
                     epoch,
                 },
+                vec![],
             )
             .await
             .into_result_with_feedback()
@@ -541,6 +531,7 @@ impl KeyManagerClient for RemoteClient {
                     key_pair_id,
                     epoch,
                 },
+                vec![],
             )
             .await
             .into_result_with_feedback()
@@ -560,6 +551,7 @@ impl KeyManagerClient for RemoteClient {
     async fn replicate_master_secret(
         &self,
         generation: u64,
+        nodes: Vec<PublicKey>,
     ) -> Result<Secret, KeyManagerError> {
         let height = self
             .consensus_verifier
@@ -574,6 +566,7 @@ impl KeyManagerClient for RemoteClient {
                     height: Some(height),
                     generation,
                 },
+                nodes,
             )
             .await
             .into_result_with_feedback()
@@ -588,6 +581,7 @@ impl KeyManagerClient for RemoteClient {
     async fn replicate_ephemeral_secret(
         &self,
         epoch: EpochTime,
+        nodes: Vec<PublicKey>,
     ) -> Result<Secret, KeyManagerError> {
         let height = self
             .consensus_verifier
@@ -602,6 +596,7 @@ impl KeyManagerClient for RemoteClient {
                     height: Some(height),
                     epoch,
                 },
+                nodes,
             )
             .await
             .into_result_with_feedback()
@@ -614,6 +609,7 @@ impl KeyManagerClient for RemoteClient {
         &self,
         churp_id: u8,
         epoch: EpochTime,
+        nodes: Vec<PublicKey>,
     ) -> Result<Vec<u8>, KeyManagerError> {
         self.rpc_client
             .insecure_call(
@@ -624,6 +620,7 @@ impl KeyManagerClient for RemoteClient {
                     epoch,
                     node_id: None,
                 },
+                nodes,
             )
             .await
             .into_result_with_feedback()
@@ -636,6 +633,7 @@ impl KeyManagerClient for RemoteClient {
         churp_id: u8,
         epoch: EpochTime,
         node_id: PublicKey,
+        nodes: Vec<PublicKey>,
     ) -> Result<Vec<u8>, KeyManagerError> {
         self.rpc_client
             .secure_call(
@@ -646,6 +644,7 @@ impl KeyManagerClient for RemoteClient {
                     epoch,
                     node_id: Some(node_id),
                 },
+                nodes,
             )
             .await
             .into_result_with_feedback()
@@ -658,6 +657,7 @@ impl KeyManagerClient for RemoteClient {
         churp_id: u8,
         epoch: EpochTime,
         node_id: PublicKey,
+        nodes: Vec<PublicKey>,
     ) -> Result<Vec<u8>, KeyManagerError> {
         self.rpc_client
             .secure_call(
@@ -668,6 +668,7 @@ impl KeyManagerClient for RemoteClient {
                     epoch,
                     node_id: Some(node_id),
                 },
+                nodes,
             )
             .await
             .into_result_with_feedback()
@@ -680,6 +681,7 @@ impl KeyManagerClient for RemoteClient {
         churp_id: u8,
         epoch: EpochTime,
         node_id: PublicKey,
+        nodes: Vec<PublicKey>,
     ) -> Result<EncodedVerifiableSecretShare, KeyManagerError> {
         self.rpc_client
             .secure_call(
@@ -690,6 +692,7 @@ impl KeyManagerClient for RemoteClient {
                     epoch,
                     node_id: Some(node_id),
                 },
+                nodes,
             )
             .await
             .into_result_with_feedback()
diff --git a/keymanager/src/runtime/secrets.rs b/keymanager/src/runtime/secrets.rs
index b5210c1836c..5fe4772f167 100644
--- a/keymanager/src/runtime/secrets.rs
+++ b/keymanager/src/runtime/secrets.rs
@@ -467,7 +467,6 @@ impl Secrets {
             self.consensus_verifier.clone(),
             self.identity.clone(),
             1, // Not used, doesn't matter.
-            vec![],
         )
     }
 
diff --git a/keymanager/src/secrets/provider.rs b/keymanager/src/secrets/provider.rs
index 3ba228d45f1..4f6aab797c1 100644
--- a/keymanager/src/secrets/provider.rs
+++ b/keymanager/src/secrets/provider.rs
@@ -50,8 +50,11 @@ impl SecretProvider for KeyManagerSecretProvider {
                 inner.last_node = idx;
                 counter += 1;
 
-                inner.client.set_nodes(vec![inner.nodes[idx]]);
-                if let Ok(secret) = block_on(inner.client.replicate_master_secret(generation)) {
+                if let Ok(secret) = block_on(
+                    inner
+                        .client
+                        .replicate_master_secret(generation, vec![inner.nodes[idx]]),
+                ) {
                     return Some(secret);
                 }
             }
@@ -78,8 +81,11 @@ impl SecretProvider for KeyManagerSecretProvider {
                 inner.last_node = idx;
                 counter += 1;
 
-                inner.client.set_nodes(vec![inner.nodes[idx]]);
-                if let Ok(secret) = block_on(inner.client.replicate_ephemeral_secret(epoch)) {
+                if let Ok(secret) = block_on(
+                    inner
+                        .client
+                        .replicate_ephemeral_secret(epoch, vec![inner.nodes[idx]]),
+                ) {
                     return Some(secret);
                 }
             }
diff --git a/runtime/src/enclave_rpc/client.rs b/runtime/src/enclave_rpc/client.rs
index c3bdeb0b4eb..e26b6fe5f08 100644
--- a/runtime/src/enclave_rpc/client.rs
+++ b/runtime/src/enclave_rpc/client.rs
@@ -51,13 +51,13 @@ enum Command {
     Call(
         types::Request,
         types::Kind,
+        Vec<PublicKey>,
         oneshot::Sender<Result<(u64, types::Response), RpcClientError>>,
     ),
     PeerFeedback(u64, types::PeerFeedback, types::Kind),
     UpdateEnclaves(Option<HashSet<EnclaveIdentity>>),
     UpdateQuotePolicy(QuotePolicy),
     UpdateRuntimeID(Option<Namespace>),
-    UpdateNodes(Vec<PublicKey>),
     #[cfg(test)]
     Ping(oneshot::Sender<()>),
 }
@@ -87,8 +87,6 @@ impl MultiplexedSession {
 }
 
 struct Controller {
-    /// Allowed nodes.
-    nodes: Vec<PublicKey>,
     /// Multiplexed session.
     session: MultiplexedSession,
     /// Used transport.
@@ -101,7 +99,9 @@ impl Controller {
     async fn run(mut self) {
         while let Some(cmd) = self.cmdq.recv().await {
             match cmd {
-                Command::Call(request, kind, sender) => self.call(request, kind, sender).await,
+                Command::Call(request, kind, nodes, sender) => {
+                    self.call(request, kind, nodes, sender).await
+                }
                 Command::PeerFeedback(pfid, peer_feedback, kind) => {
                     self.transport.set_peer_feedback(pfid, Some(peer_feedback));
 
@@ -141,9 +141,6 @@ impl Controller {
                     mem::take(&mut self.session.builder).remote_runtime_id(id);
                     self.reset().await;
                 }
-                Command::UpdateNodes(nodes) => {
-                    self.nodes = nodes;
-                }
                 #[cfg(test)]
                 Command::Ping(sender) => {
                     let _ = sender.send(());
@@ -159,6 +156,7 @@ impl Controller {
         &mut self,
         request: types::Request,
         kind: types::Kind,
+        nodes: Vec<PublicKey>,
         sender: oneshot::Sender<Result<(u64, types::Response), RpcClientError>>,
     ) {
         let result = async {
@@ -166,14 +164,14 @@ impl Controller {
                 types::Kind::NoiseSession => {
                     // Attempt to establish a connection. This will not do anything in case the
                     // session has already been established.
-                    self.connect().await?;
+                    self.connect(nodes).await?;
 
                     // Perform the call.
                     self.secure_call_raw(request).await
                 }
                 types::Kind::InsecureQuery => {
                     // Perform the call.
-                    self.insecure_call_raw(request).await
+                    self.insecure_call_raw(request, nodes).await
                 }
                 _ => Err(RpcClientError::UnsupportedRpcKind),
             }
@@ -196,10 +194,10 @@ impl Controller {
         let _ = sender.send(result.map(|rsp| (pfid, rsp)));
     }
 
-    async fn connect(&mut self) -> Result<(), RpcClientError> {
+    async fn connect(&mut self, nodes: Vec<PublicKey>) -> Result<(), RpcClientError> {
         // No need to create a new session if we are connected to one of the nodes.
         if self.session.inner.is_connected()
-            && (self.nodes.is_empty() || self.session.inner.is_connected_to(&self.nodes))
+            && (nodes.is_empty() || self.session.inner.is_connected_to(&nodes))
         {
             return Ok(());
         }
@@ -217,7 +215,7 @@ impl Controller {
 
         let (data, node) = self
            .transport
-            .write_noise_session(session_id, buffer, String::new(), self.nodes.clone())
+            .write_noise_session(session_id, buffer, String::new(), nodes)
             .await
             .map_err(|_| RpcClientError::Transport)?;
 
@@ -286,10 +284,11 @@ impl Controller {
     async fn insecure_call_raw(
         &mut self,
         request: types::Request,
+        nodes: Vec<PublicKey>,
     ) -> Result<types::Response, RpcClientError> {
         let (data, _) = self
             .transport
-            .write_insecure_query(cbor::to_vec(request), self.nodes.clone())
+            .write_insecure_query(cbor::to_vec(request), nodes)
             .await
             .map_err(|_| RpcClientError::Transport)?;
 
@@ -405,17 +404,12 @@ pub struct RpcClient {
 }
 
 impl RpcClient {
-    fn new(
-        transport: Box<dyn Transport>,
-        builder: Builder,
-        nodes: Vec<PublicKey>,
-    ) -> Self {
+    fn new(transport: Box<dyn Transport>, builder: Builder) -> Self {
         // Create the command channel.
         let (tx, rx) = mpsc::channel(CMDQ_BACKLOG);
 
         // Create the controller task and start it.
         let controller = Controller {
-            nodes,
             session: MultiplexedSession::new(builder),
             transport,
             cmdq: rx,
@@ -426,38 +420,47 @@ impl RpcClient {
     }
 
     /// Construct an unconnected RPC client with runtime-internal transport.
-    pub fn new_runtime(
-        builder: Builder,
-        protocol: Arc<Protocol>,
-        endpoint: &str,
-        nodes: Vec<PublicKey>,
-    ) -> Self {
-        Self::new(
-            Box::new(RuntimeTransport::new(protocol, endpoint)),
-            builder,
-            nodes,
-        )
+    pub fn new_runtime(builder: Builder, protocol: Arc<Protocol>, endpoint: &str) -> Self {
+        Self::new(Box::new(RuntimeTransport::new(protocol, endpoint)), builder)
     }
 
     /// Call a remote method using an encrypted and authenticated Noise session.
-    pub async fn secure_call<C, O>(&self, method: &'static str, args: C) -> Response<O>
+    pub async fn secure_call<C, O>(
+        &self,
+        method: &'static str,
+        args: C,
+        nodes: Vec<PublicKey>,
+    ) -> Response<O>
     where
         C: cbor::Encode,
         O: cbor::Decode + Send + 'static,
     {
-        self.call(method, args, types::Kind::NoiseSession).await
+        self.call(method, args, types::Kind::NoiseSession, nodes)
+            .await
     }
 
     /// Call a remote method over an insecure channel where messages are sent in plain text.
-    pub async fn insecure_call<C, O>(&self, method: &'static str, args: C) -> Response<O>
+    pub async fn insecure_call<C, O>(
+        &self,
+        method: &'static str,
+        args: C,
+        nodes: Vec<PublicKey>,
+    ) -> Response<O>
     where
         C: cbor::Encode,
         O: cbor::Decode + Send + 'static,
     {
-        self.call(method, args, types::Kind::InsecureQuery).await
+        self.call(method, args, types::Kind::InsecureQuery, nodes)
+            .await
     }
 
-    async fn call<C, O>(&self, method: &'static str, args: C, kind: types::Kind) -> Response<O>
+    async fn call<C, O>(
+        &self,
+        method: &'static str,
+        args: C,
+        kind: types::Kind,
+        nodes: Vec<PublicKey>,
+    ) -> Response<O>
     where
         C: cbor::Encode,
         O: cbor::Decode + Send + 'static,
@@ -474,9 +477,10 @@ impl RpcClient {
             .max_delay(std::time::Duration::from_millis(250))
             .take(MAX_TRANSPORT_ERROR_RETRIES);
 
-        let result =
-            tokio_retry::Retry::spawn(retry_strategy, || self.execute_call(request.clone(), kind))
-                .await;
+        let result = tokio_retry::Retry::spawn(retry_strategy, || {
+            self.execute_call(request.clone(), kind, nodes.clone())
+        })
+        .await;
 
         let (pfid, inner) = match result {
             Ok((pfid, response)) => match response.body {
@@ -500,10 +504,11 @@ impl RpcClient {
         &self,
         request: types::Request,
         kind: types::Kind,
+        nodes: Vec<PublicKey>,
     ) -> Result<(u64, types::Response), RpcClientError> {
         let (tx, rx) = oneshot::channel();
         self.cmdq
-            .send(Command::Call(request, kind, tx))
+            .send(Command::Call(request, kind, nodes, tx))
             .await
             .map_err(|_| RpcClientError::Dropped)?;
 
@@ -545,28 +550,6 @@ impl RpcClient {
             .unwrap();
     }
 
-    /// Update allowed nodes.
-    ///
-    /// # Panics
-    ///
-    /// This function panics if called within an asynchronous execution context.
-    pub fn update_nodes(&self, nodes: Vec<PublicKey>) {
-        self.cmdq
-            .blocking_send(Command::UpdateNodes(nodes))
-            .unwrap();
-    }
-
-    /// Update allowed nodes.
-    pub async fn update_nodes_async(
-        &self,
-        nodes: Vec<PublicKey>,
-    ) -> Result<(), RpcClientError> {
-        self.cmdq
-            .send(Command::UpdateNodes(nodes))
-            .await
-            .map_err(|_| RpcClientError::Dropped)
-    }
-
     /// Wait for the controller to process all queued messages.
     #[cfg(test)]
     async fn flush_cmd_queue(&self) -> Result<(), RpcClientError> {
@@ -729,13 +712,13 @@ mod test {
         let _guard = rt.enter(); // Ensure Tokio runtime is available.
         let transport = MockTransport::new();
         let builder = session::Builder::default();
-        let client = RpcClient::new(Box::new(transport.clone()), builder, vec![]);
+        let client = RpcClient::new(Box::new(transport.clone()), builder);
 
         // Basic secure call.
         let result: u64 = rt
             .block_on(async {
                 client
-                    .secure_call("test", 42)
+                    .secure_call("test", 42, vec![])
                     .await
                     .into_result_with_feedback()
                     .await
@@ -759,7 +742,7 @@ mod test {
         let result: u64 = rt
             .block_on(async {
                 client
-                    .secure_call("test", 43)
+                    .secure_call("test", 43, vec![])
                     .await
                     .into_result_with_feedback()
                     .await
@@ -785,7 +768,7 @@ mod test {
         let result: u64 = rt
             .block_on(async {
                 client
-                    .secure_call("test", 44)
+                    .secure_call("test", 44, vec![])
                     .await
                     .into_result_with_feedback()
                     .await
@@ -809,7 +792,7 @@ mod test {
         let result: u64 = rt
             .block_on(async {
                 client
-                    .insecure_call("test", 45)
+                    .insecure_call("test", 45, vec![])
                     .await
                     .into_result_with_feedback()
                     .await
@@ -831,7 +814,7 @@ mod test {
         let result: u64 = rt
             .block_on(async {
                 client
-                    .insecure_call("test", 46)
+                    .insecure_call("test", 46, vec![])
                     .await
                     .into_result_with_feedback()
                     .await
diff --git a/tests/runtimes/simple-keyvalue/src/main.rs b/tests/runtimes/simple-keyvalue/src/main.rs
index 7964b39593a..4c46322ba2b 100644
--- a/tests/runtimes/simple-keyvalue/src/main.rs
+++ b/tests/runtimes/simple-keyvalue/src/main.rs
@@ -388,7 +388,6 @@ pub fn main_with_version(version: Version) {
         state.identity.clone(),
         1024,
         trusted_signers(),
-        vec![],
     ));
     let key_manager = km_client.clone();

From fb28b63c4315e5675d9036f508aa3c3a253ade9f Mon Sep 17 00:00:00 2001
From: Peter Nose
Date: Tue, 24 Sep 2024 13:03:43 +0200
Subject: [PATCH 2/2] keymanager/src/client: Fetch churp key shares concurrently

---
 .changelog/5863.internal.md     |  1 +
 keymanager/src/client/remote.rs | 41 ++++++++++++++++++++++-----------
 2 files changed, 28 insertions(+), 14 deletions(-)
 create mode 100644 .changelog/5863.internal.md
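Illustration (not part of the diff): the refill-then-drain shape of the
new fetch loop below, extracted into a standalone sketch. The helper
name, the placeholder node-id type, and the unit error type are invented
for the example; only the FuturesUnordered usage mirrors the patch:

    use futures::{future::BoxFuture, stream::{FuturesUnordered, StreamExt}};

    /// Collect `min_shares` successful responses, keeping that many
    /// requests in flight and drawing a fresh node whenever one fails.
    async fn fetch_min_shares<T>(
        mut committee: Vec<u64>, // placeholder node ids
        min_shares: usize,
        fetch: impl Fn(u64) -> BoxFuture<'static, Result<T, ()>>,
    ) -> Result<Vec<T>, ()> {
        let mut shares = Vec::with_capacity(min_shares);
        let mut futures = FuturesUnordered::new();

        while shares.len() < min_shares {
            // Top up the in-flight set; a failed request frees a slot here.
            while shares.len() + futures.len() < min_shares {
                let node_id = committee.pop().ok_or(())?; // out of nodes
                futures.push(fetch(node_id));
            }
            // Take whichever request finishes first; drop failures.
            if let Some(Ok(share)) = futures.next().await {
                shares.push(share);
            }
        }
        Ok(shares)
    }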
diff --git a/.changelog/5863.internal.md b/.changelog/5863.internal.md
new file mode 100644
index 00000000000..1717e954478
--- /dev/null
+++ b/.changelog/5863.internal.md
@@ -0,0 +1 @@
+keymanager/src/client: Fetch churp key shares concurrently
diff --git a/keymanager/src/client/remote.rs b/keymanager/src/client/remote.rs
index b13704d42c0..a8fbddfc294 100644
--- a/keymanager/src/client/remote.rs
+++ b/keymanager/src/client/remote.rs
@@ -9,6 +9,7 @@ use std::{
 
 use anyhow::anyhow;
 use async_trait::async_trait;
+use futures::stream::{FuturesUnordered, StreamExt};
 use group::GroupEncoding;
 use lru::LruCache;
 use rand::{prelude::SliceRandom, rngs::OsRng};
@@ -248,20 +249,23 @@ impl RemoteClient {
         let mut shares = Vec::with_capacity(min_shares);
 
         // Fetch key shares in random order.
-        // TODO: Optimize by fetching key shares concurrently.
         let mut committee = status.committee;
         committee.shuffle(&mut OsRng);
 
-        for node_id in committee {
-            // Stop fetching when enough shares are received.
-            if shares.len() == min_shares {
-                break;
-            }
+        // Fetch key shares concurrently.
+        let mut futures = FuturesUnordered::new();
+
+        loop {
+            // Continuously add new key share requests until the required
+            // number of key shares is received, ensuring the future queue
+            // remains filled even if some requests fail.
+            while shares.len() + futures.len() < min_shares {
+                let node_id = match committee.pop() {
+                    Some(node_id) => node_id,
+                    None => return Err(KeyManagerError::InsufficientKeyShares),
+                };
 
-            // Fetch key share from the current node.
-            let response = self
-                .rpc_client
-                .secure_call(
+                let future = self.rpc_client.secure_call(
                     METHOD_SGX_POLICY_KEY_SHARE,
                     KeyShareRequest {
                         id: status.id,
@@ -271,10 +275,19 @@ impl RemoteClient {
                         key_runtime_id: self.runtime_id,
                         key_id,
                     },
                     vec![node_id],
-                )
-                .await
-                .into_result_with_feedback()
-                .await;
+                );
+
+                futures.push(future);
+            }
+
+            // Wait for the next future to finish.
+            let response = match futures.next().await {
+                Some(response) => response,
+                None => break,
+            };
+
+            // Send back peer feedback.
+            let response = response.into_result_with_feedback().await;
 
             // Decode the response.
             let encoded_share: EncodedEncryptedPoint = match response {