From c4d0defef6919c803538deeaae7b885c293058b3 Mon Sep 17 00:00:00 2001 From: Itay Tsabary Date: Thu, 19 Sep 2024 15:50:47 +0300 Subject: [PATCH] refactor(mempool_infra): change remote server to send messages to local server commit-id:7e0d8402 --- crates/batcher/src/communication.rs | 13 +--- crates/consensus_manager/src/communication.rs | 14 +--- crates/mempool/src/communication.rs | 16 +---- .../src/component_client/definitions.rs | 17 ++++- .../local_component_client.rs | 9 +-- .../remote_component_server.rs | 69 +++++++------------ .../remote_component_client_server_test.rs | 39 +++++++---- 7 files changed, 74 insertions(+), 103 deletions(-) diff --git a/crates/batcher/src/communication.rs b/crates/batcher/src/communication.rs index 308255980f..d1784672f0 100644 --- a/crates/batcher/src/communication.rs +++ b/crates/batcher/src/communication.rs @@ -1,5 +1,3 @@ -use std::net::IpAddr; - use async_trait::async_trait; use starknet_batcher_types::communication::{ BatcherRequest, @@ -7,13 +5,12 @@ use starknet_batcher_types::communication::{ BatcherResponse, }; use starknet_mempool_infra::component_definitions::ComponentRequestHandler; -use starknet_mempool_infra::component_server::{LocalComponentServer, RemoteComponentServer}; +use starknet_mempool_infra::component_server::LocalComponentServer; use tokio::sync::mpsc::Receiver; use crate::batcher::Batcher; pub type LocalBatcherServer = LocalComponentServer; -pub type RemoteBatcherServer = RemoteComponentServer; pub fn create_local_batcher_server( batcher: Batcher, @@ -22,14 +19,6 @@ pub fn create_local_batcher_server( LocalComponentServer::new(batcher, rx_batcher) } -pub fn create_remote_batcher_server( - batcher: Batcher, - ip_address: IpAddr, - port: u16, -) -> RemoteBatcherServer { - RemoteComponentServer::new(batcher, ip_address, port) -} - #[async_trait] impl ComponentRequestHandler for Batcher { async fn handle_request(&mut self, request: BatcherRequest) -> BatcherResponse { diff --git a/crates/consensus_manager/src/communication.rs b/crates/consensus_manager/src/communication.rs index 7dfff057e4..79e31ec9b6 100644 --- a/crates/consensus_manager/src/communication.rs +++ b/crates/consensus_manager/src/communication.rs @@ -1,5 +1,3 @@ -use std::net::IpAddr; - use async_trait::async_trait; use starknet_consensus_manager_types::communication::{ ConsensusManagerRequest, @@ -7,15 +5,13 @@ use starknet_consensus_manager_types::communication::{ ConsensusManagerResponse, }; use starknet_mempool_infra::component_definitions::ComponentRequestHandler; -use starknet_mempool_infra::component_server::{LocalActiveComponentServer, RemoteComponentServer}; +use starknet_mempool_infra::component_server::LocalActiveComponentServer; use tokio::sync::mpsc::Receiver; use crate::consensus_manager::ConsensusManager; pub type LocalConsensusManagerServer = LocalActiveComponentServer; -pub type RemoteConsensusManagerServer = - RemoteComponentServer; pub fn create_local_consensus_manager_server( consensus_manager: ConsensusManager, @@ -24,14 +20,6 @@ pub fn create_local_consensus_manager_server( LocalActiveComponentServer::new(consensus_manager, rx_consensus_manager) } -pub fn create_remote_consensus_manager_server( - consensus_manager: ConsensusManager, - ip_address: IpAddr, - port: u16, -) -> RemoteConsensusManagerServer { - RemoteComponentServer::new(consensus_manager, ip_address, port) -} - #[async_trait] impl ComponentRequestHandler for ConsensusManager diff --git a/crates/mempool/src/communication.rs b/crates/mempool/src/communication.rs index 946bfa377f..8afdf21bac 100644 --- a/crates/mempool/src/communication.rs +++ b/crates/mempool/src/communication.rs @@ -1,10 +1,8 @@ -use std::net::IpAddr; - use async_trait::async_trait; use starknet_api::executable_transaction::Transaction; use starknet_mempool_infra::component_definitions::ComponentRequestHandler; use starknet_mempool_infra::component_runner::ComponentStarter; -use starknet_mempool_infra::component_server::{LocalComponentServer, RemoteComponentServer}; +use starknet_mempool_infra::component_server::LocalComponentServer; use starknet_mempool_types::communication::{ MempoolRequest, MempoolRequestAndResponseSender, @@ -19,9 +17,6 @@ use crate::mempool::Mempool; pub type MempoolServer = LocalComponentServer; -pub type RemoteMempoolServer = - RemoteComponentServer; - pub fn create_mempool_server( mempool: Mempool, rx_mempool: Receiver, @@ -30,15 +25,6 @@ pub fn create_mempool_server( LocalComponentServer::new(communication_wrapper, rx_mempool) } -pub fn create_remote_mempool_server( - mempool: Mempool, - ip_address: IpAddr, - port: u16, -) -> RemoteMempoolServer { - let communication_wrapper = MempoolCommunicationWrapper::new(mempool); - RemoteComponentServer::new(communication_wrapper, ip_address, port) -} - /// Wraps the mempool to enable inbound async communication from other components. pub struct MempoolCommunicationWrapper { mempool: Mempool, diff --git a/crates/mempool_infra/src/component_client/definitions.rs b/crates/mempool_infra/src/component_client/definitions.rs index d1c4fe53be..61ad1c0194 100644 --- a/crates/mempool_infra/src/component_client/definitions.rs +++ b/crates/mempool_infra/src/component_client/definitions.rs @@ -2,8 +2,9 @@ use std::sync::Arc; use hyper::StatusCode; use thiserror::Error; +use tokio::sync::mpsc::{channel, Sender}; -use crate::component_definitions::ServerError; +use crate::component_definitions::{ComponentRequestAndResponseSender, ServerError}; #[derive(Clone, Debug, Error)] pub enum ClientError { @@ -20,3 +21,17 @@ pub enum ClientError { } pub type ClientResult = Result; + +pub async fn send_locally( + tx: Sender>, + request: Request, +) -> Response +where + Request: Send + Sync, + Response: Send + Sync, +{ + let (res_tx, mut res_rx) = channel::(1); + let request_and_res_tx = ComponentRequestAndResponseSender { request, tx: res_tx }; + tx.send(request_and_res_tx).await.expect("Outbound connection should be open."); + res_rx.recv().await.expect("Inbound connection should be open.") +} diff --git a/crates/mempool_infra/src/component_client/local_component_client.rs b/crates/mempool_infra/src/component_client/local_component_client.rs index ce561bc3ac..0c77d9e4e0 100644 --- a/crates/mempool_infra/src/component_client/local_component_client.rs +++ b/crates/mempool_infra/src/component_client/local_component_client.rs @@ -1,5 +1,6 @@ -use tokio::sync::mpsc::{channel, Sender}; +use tokio::sync::mpsc::Sender; +use crate::component_client::send_locally; use crate::component_definitions::ComponentRequestAndResponseSender; /// The `LocalComponentClient` struct is a generic client for sending component requests and @@ -71,11 +72,7 @@ where // TODO(Tsabary, 1/5/2024): Consider implementation for messages without expected responses. pub async fn send(&self, request: Request) -> Response { - let (res_tx, mut res_rx) = channel::(1); - let request_and_res_tx = ComponentRequestAndResponseSender { request, tx: res_tx }; - self.tx.send(request_and_res_tx).await.expect("Outbound connection should be open."); - - res_rx.recv().await.expect("Inbound connection should be open.") + send_locally(self.tx.clone(), request).await } } diff --git a/crates/mempool_infra/src/component_server/remote_component_server.rs b/crates/mempool_infra/src/component_server/remote_component_server.rs index 70dc099d6e..88c09f7f57 100644 --- a/crates/mempool_infra/src/component_server/remote_component_server.rs +++ b/crates/mempool_infra/src/component_server/remote_component_server.rs @@ -1,6 +1,4 @@ -use std::marker::PhantomData; use std::net::{IpAddr, SocketAddr}; -use std::sync::Arc; use async_trait::async_trait; use bincode::{deserialize, serialize}; @@ -10,11 +8,12 @@ use hyper::service::{make_service_fn, service_fn}; use hyper::{Body, Request as HyperRequest, Response as HyperResponse, Server, StatusCode}; use serde::de::DeserializeOwned; use serde::Serialize; -use tokio::sync::Mutex; +use tokio::sync::mpsc::Sender; use super::definitions::ComponentServerStarter; +use crate::component_client::send_locally; use crate::component_definitions::{ - ComponentRequestHandler, + ComponentRequestAndResponseSender, ServerError, APPLICATION_OCTET_STREAM, }; @@ -84,17 +83,15 @@ use crate::component_definitions::{ /// /// #[tokio::main] /// async fn main() { -/// // Instantiate the component. -/// let component = MyComponent {}; +/// // Instantiate a channel to communicate with component. +/// let (tx, _rx) = tokio::sync::mpsc::channel(32); /// /// // Set the ip address and port of the server's socket. /// let ip_address = std::net::IpAddr::V6(std::net::Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)); /// let port: u16 = 8080; /// /// // Instantiate the server. -/// let mut server = RemoteComponentServer::::new( -/// component, ip_address, port, -/// ); +/// let mut server = RemoteComponentServer::::new(tx, ip_address, port); /// /// // Start the server in a new task. /// task::spawn(async move { @@ -102,49 +99,41 @@ use crate::component_definitions::{ /// }); /// } /// ``` -pub struct RemoteComponentServer +pub struct RemoteComponentServer where - Component: ComponentRequestHandler + Send + 'static, - Request: DeserializeOwned + Send + 'static, - Response: Serialize + 'static, + Request: DeserializeOwned + Send + Sync + 'static, + Response: Serialize + Send + Sync + 'static, { socket: SocketAddr, - component: Arc>, - _req: PhantomData, - _res: PhantomData, + tx: Sender>, } -impl RemoteComponentServer +impl RemoteComponentServer where - Component: ComponentRequestHandler + Send + 'static, - Request: DeserializeOwned + Send + 'static, - Response: Serialize + 'static, + Request: DeserializeOwned + Send + Sync + 'static, + Response: Serialize + Send + Sync + 'static, { - pub fn new(component: Component, ip_address: IpAddr, port: u16) -> Self { - Self { - component: Arc::new(Mutex::new(component)), - socket: SocketAddr::new(ip_address, port), - _req: PhantomData, - _res: PhantomData, - } + pub fn new( + tx: Sender>, + ip_address: IpAddr, + port: u16, + ) -> Self { + Self { tx, socket: SocketAddr::new(ip_address, port) } } async fn handler( http_request: HyperRequest, - component: Arc>, + tx: Sender>, ) -> Result, hyper::Error> { let body_bytes = to_bytes(http_request.into_body()).await?; let http_response = match deserialize(&body_bytes) { - Ok(component_request) => { - // Acquire the lock for component computation, release afterwards. - let component_response = - { component.lock().await.handle_request(component_request).await }; + Ok(request) => { + let response = send_locally(tx, request).await; HyperResponse::builder() .status(StatusCode::OK) .header(CONTENT_TYPE, APPLICATION_OCTET_STREAM) .body(Body::from( - serialize(&component_response) - .expect("Response serialization should succeed"), + serialize(&response).expect("Response serialization should succeed"), )) } Err(error) => { @@ -161,21 +150,15 @@ where } #[async_trait] -impl ComponentServerStarter - for RemoteComponentServer +impl ComponentServerStarter for RemoteComponentServer where - Component: ComponentRequestHandler + Send + 'static, Request: DeserializeOwned + Send + Sync + 'static, Response: Serialize + Send + Sync + 'static, { async fn start(&mut self) { let make_svc = make_service_fn(|_conn| { - let component = Arc::clone(&self.component); - async { - Ok::<_, hyper::Error>(service_fn(move |req| { - Self::handler(req, Arc::clone(&component)) - })) - } + let tx = self.tx.clone(); + async { Ok::<_, hyper::Error>(service_fn(move |req| Self::handler(req, tx.clone()))) } }); Server::bind(&self.socket.clone()).serve(make_svc).await.unwrap(); diff --git a/crates/mempool_infra/tests/remote_component_client_server_test.rs b/crates/mempool_infra/tests/remote_component_client_server_test.rs index eea574cdf2..c1758079f9 100644 --- a/crates/mempool_infra/tests/remote_component_client_server_test.rs +++ b/crates/mempool_infra/tests/remote_component_client_server_test.rs @@ -24,11 +24,17 @@ use rstest::rstest; use serde::Serialize; use starknet_mempool_infra::component_client::{ClientError, ClientResult, RemoteComponentClient}; use starknet_mempool_infra::component_definitions::{ + ComponentRequestAndResponseSender, ComponentRequestHandler, ServerError, APPLICATION_OCTET_STREAM, }; -use starknet_mempool_infra::component_server::{ComponentServerStarter, RemoteComponentServer}; +use starknet_mempool_infra::component_server::{ + ComponentServerStarter, + LocalComponentServer, + RemoteComponentServer, +}; +use tokio::sync::mpsc::channel; use tokio::sync::Mutex; use tokio::task; @@ -162,23 +168,30 @@ async fn setup_for_tests(setup_value: ValueB, a_port: u16, b_port: u16) { let component_a = ComponentA::new(Box::new(b_client)); let component_b = ComponentB::new(setup_value, Box::new(a_client.clone())); - let mut component_a_server = RemoteComponentServer::< - ComponentA, - ComponentARequest, - ComponentAResponse, - >::new(component_a, LOCAL_IP, a_port); - let mut component_b_server = RemoteComponentServer::< - ComponentB, - ComponentBRequest, - ComponentBResponse, - >::new(component_b, LOCAL_IP, b_port); + let (tx_a, rx_a) = + channel::>(32); + let (tx_b, rx_b) = + channel::>(32); + + let mut component_a_local_server = LocalComponentServer::new(component_a, rx_a); + let mut component_b_local_server = LocalComponentServer::new(component_b, rx_b); + + let mut component_a_remote_server = RemoteComponentServer::new(tx_a, LOCAL_IP, a_port); + let mut component_b_remote_server = RemoteComponentServer::new(tx_b, LOCAL_IP, b_port); + + task::spawn(async move { + component_a_local_server.start().await; + }); + task::spawn(async move { + component_b_local_server.start().await; + }); task::spawn(async move { - component_a_server.start().await; + component_a_remote_server.start().await; }); task::spawn(async move { - component_b_server.start().await; + component_b_remote_server.start().await; }); // Todo(uriel): Get rid of this