Skip to content

Commit

Permalink
refactor(mempool_infra): change remote server to send messages to loc…
Browse files Browse the repository at this point in the history
…al server

commit-id:7e0d8402
  • Loading branch information
Itay-Tsabary-Starkware committed Sep 22, 2024
1 parent 96501ce commit c4d0def
Show file tree
Hide file tree
Showing 7 changed files with 74 additions and 103 deletions.
13 changes: 1 addition & 12 deletions crates/batcher/src/communication.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,16 @@
use std::net::IpAddr;

use async_trait::async_trait;
use starknet_batcher_types::communication::{
BatcherRequest,
BatcherRequestAndResponseSender,
BatcherResponse,
};
use starknet_mempool_infra::component_definitions::ComponentRequestHandler;
use starknet_mempool_infra::component_server::{LocalComponentServer, RemoteComponentServer};
use starknet_mempool_infra::component_server::LocalComponentServer;
use tokio::sync::mpsc::Receiver;

use crate::batcher::Batcher;

pub type LocalBatcherServer = LocalComponentServer<Batcher, BatcherRequest, BatcherResponse>;
pub type RemoteBatcherServer = RemoteComponentServer<Batcher, BatcherRequest, BatcherResponse>;

pub fn create_local_batcher_server(
batcher: Batcher,
Expand All @@ -22,14 +19,6 @@ pub fn create_local_batcher_server(
LocalComponentServer::new(batcher, rx_batcher)
}

pub fn create_remote_batcher_server(
batcher: Batcher,
ip_address: IpAddr,
port: u16,
) -> RemoteBatcherServer {
RemoteComponentServer::new(batcher, ip_address, port)
}

#[async_trait]
impl ComponentRequestHandler<BatcherRequest, BatcherResponse> for Batcher {
async fn handle_request(&mut self, request: BatcherRequest) -> BatcherResponse {
Expand Down
14 changes: 1 addition & 13 deletions crates/consensus_manager/src/communication.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,17 @@
use std::net::IpAddr;

use async_trait::async_trait;
use starknet_consensus_manager_types::communication::{
ConsensusManagerRequest,
ConsensusManagerRequestAndResponseSender,
ConsensusManagerResponse,
};
use starknet_mempool_infra::component_definitions::ComponentRequestHandler;
use starknet_mempool_infra::component_server::{LocalActiveComponentServer, RemoteComponentServer};
use starknet_mempool_infra::component_server::LocalActiveComponentServer;
use tokio::sync::mpsc::Receiver;

use crate::consensus_manager::ConsensusManager;

pub type LocalConsensusManagerServer =
LocalActiveComponentServer<ConsensusManager, ConsensusManagerRequest, ConsensusManagerResponse>;
pub type RemoteConsensusManagerServer =
RemoteComponentServer<ConsensusManager, ConsensusManagerRequest, ConsensusManagerResponse>;

pub fn create_local_consensus_manager_server(
consensus_manager: ConsensusManager,
Expand All @@ -24,14 +20,6 @@ pub fn create_local_consensus_manager_server(
LocalActiveComponentServer::new(consensus_manager, rx_consensus_manager)
}

pub fn create_remote_consensus_manager_server(
consensus_manager: ConsensusManager,
ip_address: IpAddr,
port: u16,
) -> RemoteConsensusManagerServer {
RemoteComponentServer::new(consensus_manager, ip_address, port)
}

#[async_trait]
impl ComponentRequestHandler<ConsensusManagerRequest, ConsensusManagerResponse>
for ConsensusManager
Expand Down
16 changes: 1 addition & 15 deletions crates/mempool/src/communication.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
use std::net::IpAddr;

use async_trait::async_trait;
use starknet_api::executable_transaction::Transaction;
use starknet_mempool_infra::component_definitions::ComponentRequestHandler;
use starknet_mempool_infra::component_runner::ComponentStarter;
use starknet_mempool_infra::component_server::{LocalComponentServer, RemoteComponentServer};
use starknet_mempool_infra::component_server::LocalComponentServer;
use starknet_mempool_types::communication::{
MempoolRequest,
MempoolRequestAndResponseSender,
Expand All @@ -19,9 +17,6 @@ use crate::mempool::Mempool;
pub type MempoolServer =
LocalComponentServer<MempoolCommunicationWrapper, MempoolRequest, MempoolResponse>;

pub type RemoteMempoolServer =
RemoteComponentServer<MempoolCommunicationWrapper, MempoolRequest, MempoolResponse>;

pub fn create_mempool_server(
mempool: Mempool,
rx_mempool: Receiver<MempoolRequestAndResponseSender>,
Expand All @@ -30,15 +25,6 @@ pub fn create_mempool_server(
LocalComponentServer::new(communication_wrapper, rx_mempool)
}

pub fn create_remote_mempool_server(
mempool: Mempool,
ip_address: IpAddr,
port: u16,
) -> RemoteMempoolServer {
let communication_wrapper = MempoolCommunicationWrapper::new(mempool);
RemoteComponentServer::new(communication_wrapper, ip_address, port)
}

/// Wraps the mempool to enable inbound async communication from other components.
pub struct MempoolCommunicationWrapper {
mempool: Mempool,
Expand Down
17 changes: 16 additions & 1 deletion crates/mempool_infra/src/component_client/definitions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@ use std::sync::Arc;

use hyper::StatusCode;
use thiserror::Error;
use tokio::sync::mpsc::{channel, Sender};

use crate::component_definitions::ServerError;
use crate::component_definitions::{ComponentRequestAndResponseSender, ServerError};

#[derive(Clone, Debug, Error)]
pub enum ClientError {
Expand All @@ -20,3 +21,17 @@ pub enum ClientError {
}

pub type ClientResult<T> = Result<T, ClientError>;

pub async fn send_locally<Request, Response>(
tx: Sender<ComponentRequestAndResponseSender<Request, Response>>,
request: Request,
) -> Response
where
Request: Send + Sync,
Response: Send + Sync,
{
let (res_tx, mut res_rx) = channel::<Response>(1);
let request_and_res_tx = ComponentRequestAndResponseSender { request, tx: res_tx };
tx.send(request_and_res_tx).await.expect("Outbound connection should be open.");
res_rx.recv().await.expect("Inbound connection should be open.")
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use tokio::sync::mpsc::{channel, Sender};
use tokio::sync::mpsc::Sender;

use crate::component_client::send_locally;
use crate::component_definitions::ComponentRequestAndResponseSender;

/// The `LocalComponentClient` struct is a generic client for sending component requests and
Expand Down Expand Up @@ -71,11 +72,7 @@ where
// TODO(Tsabary, 1/5/2024): Consider implementation for messages without expected responses.

pub async fn send(&self, request: Request) -> Response {
let (res_tx, mut res_rx) = channel::<Response>(1);
let request_and_res_tx = ComponentRequestAndResponseSender { request, tx: res_tx };
self.tx.send(request_and_res_tx).await.expect("Outbound connection should be open.");

res_rx.recv().await.expect("Inbound connection should be open.")
send_locally(self.tx.clone(), request).await
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
use std::marker::PhantomData;
use std::net::{IpAddr, SocketAddr};
use std::sync::Arc;

use async_trait::async_trait;
use bincode::{deserialize, serialize};
Expand All @@ -10,11 +8,12 @@ use hyper::service::{make_service_fn, service_fn};
use hyper::{Body, Request as HyperRequest, Response as HyperResponse, Server, StatusCode};
use serde::de::DeserializeOwned;
use serde::Serialize;
use tokio::sync::Mutex;
use tokio::sync::mpsc::Sender;

use super::definitions::ComponentServerStarter;
use crate::component_client::send_locally;
use crate::component_definitions::{
ComponentRequestHandler,
ComponentRequestAndResponseSender,
ServerError,
APPLICATION_OCTET_STREAM,
};
Expand Down Expand Up @@ -84,67 +83,57 @@ use crate::component_definitions::{
///
/// #[tokio::main]
/// async fn main() {
/// // Instantiate the component.
/// let component = MyComponent {};
/// // Instantiate a channel to communicate with component.
/// let (tx, _rx) = tokio::sync::mpsc::channel(32);
///
/// // Set the ip address and port of the server's socket.
/// let ip_address = std::net::IpAddr::V6(std::net::Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1));
/// let port: u16 = 8080;
///
/// // Instantiate the server.
/// let mut server = RemoteComponentServer::<MyComponent, MyRequest, MyResponse>::new(
/// component, ip_address, port,
/// );
/// let mut server = RemoteComponentServer::<MyRequest, MyResponse>::new(tx, ip_address, port);
///
/// // Start the server in a new task.
/// task::spawn(async move {
/// server.start().await;
/// });
/// }
/// ```
pub struct RemoteComponentServer<Component, Request, Response>
pub struct RemoteComponentServer<Request, Response>
where
Component: ComponentRequestHandler<Request, Response> + Send + 'static,
Request: DeserializeOwned + Send + 'static,
Response: Serialize + 'static,
Request: DeserializeOwned + Send + Sync + 'static,
Response: Serialize + Send + Sync + 'static,
{
socket: SocketAddr,
component: Arc<Mutex<Component>>,
_req: PhantomData<Request>,
_res: PhantomData<Response>,
tx: Sender<ComponentRequestAndResponseSender<Request, Response>>,
}

impl<Component, Request, Response> RemoteComponentServer<Component, Request, Response>
impl<Request, Response> RemoteComponentServer<Request, Response>
where
Component: ComponentRequestHandler<Request, Response> + Send + 'static,
Request: DeserializeOwned + Send + 'static,
Response: Serialize + 'static,
Request: DeserializeOwned + Send + Sync + 'static,
Response: Serialize + Send + Sync + 'static,
{
pub fn new(component: Component, ip_address: IpAddr, port: u16) -> Self {
Self {
component: Arc::new(Mutex::new(component)),
socket: SocketAddr::new(ip_address, port),
_req: PhantomData,
_res: PhantomData,
}
pub fn new(
tx: Sender<ComponentRequestAndResponseSender<Request, Response>>,
ip_address: IpAddr,
port: u16,
) -> Self {
Self { tx, socket: SocketAddr::new(ip_address, port) }
}

async fn handler(
http_request: HyperRequest<Body>,
component: Arc<Mutex<Component>>,
tx: Sender<ComponentRequestAndResponseSender<Request, Response>>,
) -> Result<HyperResponse<Body>, hyper::Error> {
let body_bytes = to_bytes(http_request.into_body()).await?;
let http_response = match deserialize(&body_bytes) {
Ok(component_request) => {
// Acquire the lock for component computation, release afterwards.
let component_response =
{ component.lock().await.handle_request(component_request).await };
Ok(request) => {
let response = send_locally(tx, request).await;
HyperResponse::builder()
.status(StatusCode::OK)
.header(CONTENT_TYPE, APPLICATION_OCTET_STREAM)
.body(Body::from(
serialize(&component_response)
.expect("Response serialization should succeed"),
serialize(&response).expect("Response serialization should succeed"),
))
}
Err(error) => {
Expand All @@ -161,21 +150,15 @@ where
}

#[async_trait]
impl<Component, Request, Response> ComponentServerStarter
for RemoteComponentServer<Component, Request, Response>
impl<Request, Response> ComponentServerStarter for RemoteComponentServer<Request, Response>
where
Component: ComponentRequestHandler<Request, Response> + Send + 'static,
Request: DeserializeOwned + Send + Sync + 'static,
Response: Serialize + Send + Sync + 'static,
{
async fn start(&mut self) {
let make_svc = make_service_fn(|_conn| {
let component = Arc::clone(&self.component);
async {
Ok::<_, hyper::Error>(service_fn(move |req| {
Self::handler(req, Arc::clone(&component))
}))
}
let tx = self.tx.clone();
async { Ok::<_, hyper::Error>(service_fn(move |req| Self::handler(req, tx.clone()))) }
});

Server::bind(&self.socket.clone()).serve(make_svc).await.unwrap();
Expand Down
39 changes: 26 additions & 13 deletions crates/mempool_infra/tests/remote_component_client_server_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,17 @@ use rstest::rstest;
use serde::Serialize;
use starknet_mempool_infra::component_client::{ClientError, ClientResult, RemoteComponentClient};
use starknet_mempool_infra::component_definitions::{
ComponentRequestAndResponseSender,
ComponentRequestHandler,
ServerError,
APPLICATION_OCTET_STREAM,
};
use starknet_mempool_infra::component_server::{ComponentServerStarter, RemoteComponentServer};
use starknet_mempool_infra::component_server::{
ComponentServerStarter,
LocalComponentServer,
RemoteComponentServer,
};
use tokio::sync::mpsc::channel;
use tokio::sync::Mutex;
use tokio::task;

Expand Down Expand Up @@ -162,23 +168,30 @@ async fn setup_for_tests(setup_value: ValueB, a_port: u16, b_port: u16) {
let component_a = ComponentA::new(Box::new(b_client));
let component_b = ComponentB::new(setup_value, Box::new(a_client.clone()));

let mut component_a_server = RemoteComponentServer::<
ComponentA,
ComponentARequest,
ComponentAResponse,
>::new(component_a, LOCAL_IP, a_port);
let mut component_b_server = RemoteComponentServer::<
ComponentB,
ComponentBRequest,
ComponentBResponse,
>::new(component_b, LOCAL_IP, b_port);
let (tx_a, rx_a) =
channel::<ComponentRequestAndResponseSender<ComponentARequest, ComponentAResponse>>(32);
let (tx_b, rx_b) =
channel::<ComponentRequestAndResponseSender<ComponentBRequest, ComponentBResponse>>(32);

let mut component_a_local_server = LocalComponentServer::new(component_a, rx_a);
let mut component_b_local_server = LocalComponentServer::new(component_b, rx_b);

let mut component_a_remote_server = RemoteComponentServer::new(tx_a, LOCAL_IP, a_port);
let mut component_b_remote_server = RemoteComponentServer::new(tx_b, LOCAL_IP, b_port);

task::spawn(async move {
component_a_local_server.start().await;
});
task::spawn(async move {
component_b_local_server.start().await;
});

task::spawn(async move {
component_a_server.start().await;
component_a_remote_server.start().await;
});

task::spawn(async move {
component_b_server.start().await;
component_b_remote_server.start().await;
});

// Todo(uriel): Get rid of this
Expand Down

0 comments on commit c4d0def

Please sign in to comment.