Skip to content

Commit

Permalink
chore(mempool_infra): wrap serde operations with generic custom struc…
Browse files Browse the repository at this point in the history
…t towards binary serde

commit-id:325d0501
  • Loading branch information
Itay-Tsabary-Starkware committed Sep 24, 2024
1 parent 5c7105f commit d4d9698
Show file tree
Hide file tree
Showing 12 changed files with 105 additions and 82 deletions.
4 changes: 2 additions & 2 deletions crates/batcher_types/src/communication.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ pub trait BatcherClient: Send + Sync {
async fn decision_reached(&self, input: DecisionReachedInput) -> BatcherClientResult<()>;
}

#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub enum BatcherRequest {
BuildProposal(BuildProposalInput),
GetProposalContent(GetProposalContentInput),
Expand All @@ -78,7 +78,7 @@ pub enum BatcherRequest {
DecisionReached(DecisionReachedInput),
}

#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub enum BatcherResponse {
BuildProposal(BatcherResult<()>),
GetProposalContent(BatcherResult<GetProposalContentResponse>),
Expand Down
5 changes: 3 additions & 2 deletions crates/consensus_manager_types/src/communication.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::fmt::Debug;
use std::sync::Arc;

use async_trait::async_trait;
Expand Down Expand Up @@ -47,13 +48,13 @@ pub trait ConsensusManagerClient: Send + Sync {
) -> ConsensusManagerClientResult<ConsensusManagerFnTwoReturnValue>;
}

#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub enum ConsensusManagerRequest {
ConsensusManagerFnOne(ConsensusManagerFnOneInput),
ConsensusManagerFnTwo(ConsensusManagerFnTwoInput),
}

#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub enum ConsensusManagerResponse {
ConsensusManagerFnOne(ConsensusManagerResult<ConsensusManagerFnOneReturnValue>),
ConsensusManagerFnTwo(ConsensusManagerResult<ConsensusManagerFnTwoReturnValue>),
Expand Down
10 changes: 6 additions & 4 deletions crates/consensus_manager_types/src/consensus_manager_types.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::fmt::Debug;

use derive_more::Display;
use serde::{Deserialize, Serialize};

Expand All @@ -10,19 +12,19 @@ use crate::errors::ConsensusManagerError;
pub struct ProposalId(pub u64);

// TODO(Tsabary/Matan): Populate the data structure used to invoke the consensus manager.
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ConsensusManagerFnOneInput {}

// TODO(Tsabary/Matan): Populate the data structure used to invoke the consensus manager.
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ConsensusManagerFnTwoInput {}

// TODO(Tsabary/Matan): Replace with the actual return type of the consensus manager function.
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ConsensusManagerFnOneReturnValue {}

// TODO(Tsabary/Matan): Replace with the actual return type of the consensus manager function.
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ConsensusManagerFnTwoReturnValue {}

pub type ConsensusManagerResult<T> = Result<T, ConsensusManagerError>;
28 changes: 14 additions & 14 deletions crates/gateway_types/src/communication.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,17 +65,17 @@ impl GatewayClient for LocalGatewayClientImpl {
}
}

#[async_trait]
impl GatewayClient for RemoteGatewayClientImpl {
#[instrument(skip(self))]
async fn add_tx(&self, gateway_input: GatewayInput) -> GatewayClientResult<TransactionHash> {
let request = GatewayRequest::AddTransaction(gateway_input);
let response = self.send(request).await?;
match response {
GatewayResponse::AddTransaction(Ok(response)) => Ok(response),
GatewayResponse::AddTransaction(Err(response)) => {
Err(GatewayClientError::GatewayError(response))
}
}
}
}
// #[async_trait]
// impl GatewayClient for RemoteGatewayClientImpl {
// #[instrument(skip(self))]
// async fn add_tx(&self, gateway_input: GatewayInput) -> GatewayClientResult<TransactionHash> {
// let request = GatewayRequest::AddTransaction(gateway_input);
// let response = self.send(request).await?;
// match response {
// GatewayResponse::AddTransaction(Ok(response)) => Ok(response),
// GatewayResponse::AddTransaction(Err(response)) => {
// Err(GatewayClientError::GatewayError(response))
// }
// }
// }
// }
1 change: 1 addition & 0 deletions crates/mempool_infra/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ tracing.workspace = true
tracing-subscriber = { workspace = true, features = ["env-filter"] }
validator.workspace = true


[dev-dependencies]
assert_matches.workspace = true
pretty_assertions.workspace = true
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use std::marker::PhantomData;
use std::net::IpAddr;
use std::sync::Arc;

use bincode::{deserialize, serialize};
use hyper::body::to_bytes;
use hyper::header::CONTENT_TYPE;
use hyper::{Body, Client, Request as HyperRequest, Response as HyperResponse, StatusCode, Uri};
Expand All @@ -11,6 +10,7 @@ use serde::Serialize;

use super::definitions::{ClientError, ClientResult};
use crate::component_definitions::APPLICATION_OCTET_STREAM;
use crate::serde_utils::BincodeSerdeWrapper;

/// The `RemoteComponentClient` struct is a generic client for sending component requests and
/// receiving responses asynchronously through HTTP connection.
Expand All @@ -35,12 +35,12 @@ use crate::component_definitions::APPLICATION_OCTET_STREAM;
/// use crate::starknet_mempool_infra::component_client::RemoteComponentClient;
///
/// // Define your request and response types
/// #[derive(Serialize)]
/// #[derive(Serialize, Deserialize, Debug, Clone)]
/// struct MyRequest {
/// pub content: String,
/// }
///
/// #[derive(Deserialize)]
/// #[derive(Serialize, Deserialize, Debug)]
/// struct MyResponse {
/// content: String,
/// }
Expand Down Expand Up @@ -79,8 +79,8 @@ where

impl<Request, Response> RemoteComponentClient<Request, Response>
where
Request: Serialize,
Response: DeserializeOwned,
Request: Serialize + DeserializeOwned + std::fmt::Debug + Clone,
Response: Serialize + DeserializeOwned + std::fmt::Debug,
{
pub fn new(ip_address: IpAddr, port: u16, max_retries: usize) -> Self {
let uri = match ip_address {
Expand All @@ -98,23 +98,25 @@ where
// Construct and request, and send it up to 'max_retries' times. Return if received a
// successful response.
for _ in 0..self.max_retries {
let http_request = self.construct_http_request(&component_request);
let http_request = self.construct_http_request(component_request.clone());
let res = self.try_send(http_request).await;
if res.is_ok() {
return res;
}
}
// Construct and send the request, return the received respone regardless whether it
// Construct and send the request, return the received response regardless whether it
// successful or not.
let http_request = self.construct_http_request(&component_request);
let http_request = self.construct_http_request(component_request);
self.try_send(http_request).await
}

fn construct_http_request(&self, component_request: &Request) -> HyperRequest<Body> {
fn construct_http_request(&self, component_request: Request) -> HyperRequest<Body> {
HyperRequest::post(self.uri.clone())
.header(CONTENT_TYPE, APPLICATION_OCTET_STREAM)
.body(Body::from(
serialize(component_request).expect("Request serialization should succeed"),
BincodeSerdeWrapper::new(component_request)
.to_bincode()
.expect("Request serialization should succeed"),
))
.expect("Request building should succeed")
}
Expand All @@ -138,12 +140,14 @@ where

async fn get_response_body<Response>(response: HyperResponse<Body>) -> Result<Response, ClientError>
where
Response: DeserializeOwned,
Response: Serialize + DeserializeOwned + std::fmt::Debug,
{
let body_bytes = to_bytes(response.into_body())
.await
.map_err(|e| ClientError::ResponseParsingFailure(Arc::new(e)))?;
deserialize(&body_bytes).map_err(|e| ClientError::ResponseDeserializationFailure(Arc::new(e)))

BincodeSerdeWrapper::<Response>::from_bincode(&body_bytes)
.map_err(|e| ClientError::ResponseDeserializationFailure(Arc::new(e)))
}

// Can't derive because derive forces the generics to also be `Clone`, which we prefer not to do
Expand Down
5 changes: 3 additions & 2 deletions crates/mempool_infra/src/component_definitions.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::collections::BTreeMap;
use std::fmt::Debug;
use std::net::IpAddr;

use async_trait::async_trait;
Expand All @@ -7,8 +8,10 @@ use papyrus_config::{ParamPath, ParamPrivacyInput, SerializedParam};
use serde::{Deserialize, Serialize};
use thiserror::Error;
use tokio::sync::mpsc::{Receiver, Sender};
use tracing::error;
use validator::Validate;

pub const APPLICATION_OCTET_STREAM: &str = "application/octet-stream";
const DEFAULT_CHANNEL_BUFFER_SIZE: usize = 32;
const DEFAULT_RETRIES: usize = 3;

Expand Down Expand Up @@ -45,8 +48,6 @@ where
pub tx: Sender<Response>,
}

pub const APPLICATION_OCTET_STREAM: &str = "application/octet-stream";

#[derive(Debug, Error, Deserialize, Serialize, Clone)]
pub enum ServerError {
#[error("Could not deserialize client request: {0}")]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::net::{IpAddr, SocketAddr};
use std::sync::Arc;

use async_trait::async_trait;
use bincode::{deserialize, serialize};
use hyper::body::to_bytes;
use hyper::header::CONTENT_TYPE;
use hyper::service::{make_service_fn, service_fn};
Expand All @@ -10,8 +10,9 @@ use serde::de::DeserializeOwned;
use serde::Serialize;

use super::definitions::ComponentServerStarter;
use crate::component_client::LocalComponentClient;
use crate::component_client::{ClientError, LocalComponentClient};
use crate::component_definitions::{ServerError, APPLICATION_OCTET_STREAM};
use crate::serde_utils::BincodeSerdeWrapper;

/// The `RemoteComponentServer` struct is a generic server that handles requests and responses for a
/// specified component. It receives requests, processes them using the provided component, and
Expand Down Expand Up @@ -59,14 +60,14 @@ use crate::component_definitions::{ServerError, APPLICATION_OCTET_STREAM};
/// }
///
/// // Define your request and response types
/// #[derive(Deserialize)]
/// #[derive(Serialize, Deserialize, Debug)]
/// struct MyRequest {
/// pub content: String,
/// }
///
/// #[derive(Serialize)]
/// #[derive(Serialize, Deserialize, Debug)]
/// struct MyResponse {
/// pub content: String,
/// content: String,
/// }
///
/// // Define your request processing logic
Expand Down Expand Up @@ -99,17 +100,17 @@ use crate::component_definitions::{ServerError, APPLICATION_OCTET_STREAM};
/// ```
pub struct RemoteComponentServer<Request, Response>
where
Request: DeserializeOwned + Send + Sync + 'static,
Response: Serialize + Send + Sync + 'static,
Request: Serialize + DeserializeOwned + Send + Sync + 'static,
Response: Serialize + DeserializeOwned + Send + Sync + 'static,
{
socket: SocketAddr,
local_client: LocalComponentClient<Request, Response>,
}

impl<Request, Response> RemoteComponentServer<Request, Response>
where
Request: DeserializeOwned + Send + Sync + 'static,
Response: Serialize + Send + Sync + 'static,
Request: Serialize + DeserializeOwned + std::fmt::Debug + Send + Sync + 'static,
Response: Serialize + DeserializeOwned + std::fmt::Debug + Send + Sync + 'static,
{
pub fn new(
local_client: LocalComponentClient<Request, Response>,
Expand All @@ -119,25 +120,32 @@ where
Self { local_client, socket: SocketAddr::new(ip_address, port) }
}

async fn handler(
async fn remote_component_server_handler(
http_request: HyperRequest<Body>,
local_client: LocalComponentClient<Request, Response>,
) -> Result<HyperResponse<Body>, hyper::Error> {
let body_bytes = to_bytes(http_request.into_body()).await?;
let http_response = match deserialize(&body_bytes) {

let http_response = match BincodeSerdeWrapper::<Request>::from_bincode(&body_bytes)
.map_err(|e| ClientError::ResponseDeserializationFailure(Arc::new(e)))
{
Ok(request) => {
let response = local_client.send(request).await;
HyperResponse::builder()
.status(StatusCode::OK)
.header(CONTENT_TYPE, APPLICATION_OCTET_STREAM)
.body(Body::from(
serialize(&response).expect("Response serialization should succeed"),
BincodeSerdeWrapper::new(response)
.to_bincode()
.expect("Response serialization should succeed"),
))
}
Err(error) => {
let server_error = ServerError::RequestDeserializationFailure(error.to_string());
HyperResponse::builder().status(StatusCode::BAD_REQUEST).body(Body::from(
serialize(&server_error).expect("Server error serialization should succeed"),
BincodeSerdeWrapper::new(server_error)
.to_bincode()
.expect("Server error serialization should succeed"),
))
}
}
Expand All @@ -150,15 +158,15 @@ where
#[async_trait]
impl<Request, Response> ComponentServerStarter for RemoteComponentServer<Request, Response>
where
Request: DeserializeOwned + Send + Sync + 'static,
Response: Serialize + Send + Sync + 'static,
Request: Serialize + DeserializeOwned + Send + Sync + std::fmt::Debug + 'static,
Response: Serialize + DeserializeOwned + Send + Sync + std::fmt::Debug + 'static,
{
async fn start(&mut self) {
let make_svc = make_service_fn(|_conn| {
let local_client = self.local_client.clone();
async {
Ok::<_, hyper::Error>(service_fn(move |req| {
Self::handler(req, local_client.clone())
Self::remote_component_server_handler(req, local_client.clone())
}))
}
});
Expand Down
18 changes: 9 additions & 9 deletions crates/mempool_infra/tests/common/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,31 @@ use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use starknet_mempool_infra::component_client::ClientResult;
use starknet_mempool_infra::component_runner::ComponentStarter;
use starknet_types_core::felt::Felt;

pub(crate) type ValueA = u32;
pub(crate) type ValueB = u8;
pub(crate) type ValueA = Felt;
pub(crate) type ValueB = Felt;

pub(crate) type ResultA = ClientResult<ValueA>;
pub(crate) type ResultB = ClientResult<ValueB>;

#[derive(Serialize, Deserialize, Debug)]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum ComponentARequest {
AGetValue,
}

#[derive(Serialize, Deserialize, Debug)]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum ComponentAResponse {
AGetValue(ValueA),
}

#[derive(Serialize, Deserialize, Debug)]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum ComponentBRequest {
BGetValue,
BSetValue(ValueB),
}

#[derive(Serialize, Deserialize, Debug)]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum ComponentBResponse {
BGetValue(ValueB),
BSetValue,
Expand All @@ -52,8 +53,7 @@ impl ComponentA {
}

pub async fn a_get_value(&self) -> ValueA {
let b_value = self.b.b_get_value().await.unwrap();
b_value.into()
self.b.b_get_value().await.unwrap()
}
}

Expand Down Expand Up @@ -92,7 +92,7 @@ pub(crate) async fn test_a_b_functionality(

let new_expected_value: ValueA = expected_value + 1;
// Check that setting a new value to component B succeeds.
assert!(b_client.b_set_value(new_expected_value.try_into().unwrap()).await.is_ok());
assert!(b_client.b_set_value(new_expected_value).await.is_ok());
// Check the new value in component B through client A.
assert_eq!(a_client.a_get_value().await.unwrap(), new_expected_value);
}
Loading

0 comments on commit d4d9698

Please sign in to comment.