Skip to content

Commit

Permalink
chore(mempool_infra): wrap serde operations with generic custom struc…
Browse files Browse the repository at this point in the history
…t towards binary serde

commit-id:325d0501
  • Loading branch information
Itay-Tsabary-Starkware committed Sep 22, 2024
1 parent 192301f commit cc2530b
Show file tree
Hide file tree
Showing 14 changed files with 115 additions and 82 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 4 additions & 4 deletions crates/batcher_types/src/batcher_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,19 @@ use serde::{Deserialize, Serialize};
use crate::errors::BatcherError;

// TODO(Tsabary/Yael/Dafna): Populate the data structure used to invoke the batcher.
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct BatcherFnOneInput {}

// TODO(Tsabary/Yael/Dafna): Populate the data structure used to invoke the batcher.
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct BatcherFnTwoInput {}

// TODO(Tsabary/Yael/Dafna): Replace with the actual return type of the batcher function.
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct BatcherFnOneReturnValue {}

// TODO(Tsabary/Yael/Dafna): Replace with the actual return type of the batcher function.
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct BatcherFnTwoReturnValue {}

pub type BatcherResult<T> = Result<T, BatcherError>;
4 changes: 2 additions & 2 deletions crates/batcher_types/src/communication.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,13 @@ pub trait BatcherClient: Send + Sync {
) -> BatcherClientResult<BatcherFnTwoReturnValue>;
}

#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub enum BatcherRequest {
BatcherFnOne(BatcherFnOneInput),
BatcherFnTwo(BatcherFnTwoInput),
}

#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub enum BatcherResponse {
BatcherFnOne(BatcherResult<BatcherFnOneReturnValue>),
BatcherFnTwo(BatcherResult<BatcherFnTwoReturnValue>),
Expand Down
5 changes: 3 additions & 2 deletions crates/consensus_manager_types/src/communication.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::fmt::Debug;
use std::sync::Arc;

use async_trait::async_trait;
Expand Down Expand Up @@ -47,13 +48,13 @@ pub trait ConsensusManagerClient: Send + Sync {
) -> ConsensusManagerClientResult<ConsensusManagerFnTwoReturnValue>;
}

#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub enum ConsensusManagerRequest {
ConsensusManagerFnOne(ConsensusManagerFnOneInput),
ConsensusManagerFnTwo(ConsensusManagerFnTwoInput),
}

#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub enum ConsensusManagerResponse {
ConsensusManagerFnOne(ConsensusManagerResult<ConsensusManagerFnOneReturnValue>),
ConsensusManagerFnTwo(ConsensusManagerResult<ConsensusManagerFnTwoReturnValue>),
Expand Down
10 changes: 6 additions & 4 deletions crates/consensus_manager_types/src/consensus_manager_types.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,23 @@
use std::fmt::Debug;

use serde::{Deserialize, Serialize};

use crate::errors::ConsensusManagerError;

// TODO(Tsabary/Matan): Populate the data structure used to invoke the consensus manager.
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ConsensusManagerFnOneInput {}

// TODO(Tsabary/Matan): Populate the data structure used to invoke the consensus manager.
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ConsensusManagerFnTwoInput {}

// TODO(Tsabary/Matan): Replace with the actual return type of the consensus manager function.
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ConsensusManagerFnOneReturnValue {}

// TODO(Tsabary/Matan): Replace with the actual return type of the consensus manager function.
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ConsensusManagerFnTwoReturnValue {}

pub type ConsensusManagerResult<T> = Result<T, ConsensusManagerError>;
28 changes: 14 additions & 14 deletions crates/gateway_types/src/communication.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,17 +65,17 @@ impl GatewayClient for LocalGatewayClientImpl {
}
}

#[async_trait]
impl GatewayClient for RemoteGatewayClientImpl {
#[instrument(skip(self))]
async fn add_tx(&self, gateway_input: GatewayInput) -> GatewayClientResult<TransactionHash> {
let request = GatewayRequest::AddTransaction(gateway_input);
let response = self.send(request).await?;
match response {
GatewayResponse::AddTransaction(Ok(response)) => Ok(response),
GatewayResponse::AddTransaction(Err(response)) => {
Err(GatewayClientError::GatewayError(response))
}
}
}
}
// #[async_trait]
// impl GatewayClient for RemoteGatewayClientImpl {
// #[instrument(skip(self))]
// async fn add_tx(&self, gateway_input: GatewayInput) -> GatewayClientResult<TransactionHash> {
// let request = GatewayRequest::AddTransaction(gateway_input);
// let response = self.send(request).await?;
// match response {
// GatewayResponse::AddTransaction(Ok(response)) => Ok(response),
// GatewayResponse::AddTransaction(Err(response)) => {
// Err(GatewayClientError::GatewayError(response))
// }
// }
// }
// }
2 changes: 2 additions & 0 deletions crates/mempool_infra/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@ hyper = { workspace = true, features = ["client", "http2", "server", "tcp"] }
papyrus_config.workspace = true
rstest.workspace = true
serde = { workspace = true, features = ["derive"] }
starknet-types-core.workspace = true
thiserror.workspace = true
tokio = { workspace = true, features = ["macros", "rt-multi-thread"] }
tracing.workspace = true
tracing-subscriber = { workspace = true, features = ["env-filter"] }
validator.workspace = true


[dev-dependencies]
assert_matches.workspace = true
pretty_assertions.workspace = true
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use std::marker::PhantomData;
use std::net::IpAddr;
use std::sync::Arc;

use bincode::{deserialize, serialize};
use hyper::body::to_bytes;
use hyper::header::CONTENT_TYPE;
use hyper::{Body, Client, Request as HyperRequest, Response as HyperResponse, StatusCode, Uri};
Expand All @@ -11,6 +10,7 @@ use serde::Serialize;

use super::definitions::{ClientError, ClientResult};
use crate::component_definitions::APPLICATION_OCTET_STREAM;
use crate::serde_utils::{BincodeSerializable, SerdeWrapper};

/// The `RemoteComponentClient` struct is a generic client for sending component requests and
/// receiving responses asynchronously through HTTP connection.
Expand All @@ -35,12 +35,12 @@ use crate::component_definitions::APPLICATION_OCTET_STREAM;
/// use crate::starknet_mempool_infra::component_client::RemoteComponentClient;
///
/// // Define your request and response types
/// #[derive(Serialize)]
/// #[derive(Serialize, Deserialize, Debug, Clone)]
/// struct MyRequest {
/// pub content: String,
/// }
///
/// #[derive(Deserialize)]
/// #[derive(Serialize, Deserialize, Debug)]
/// struct MyResponse {
/// content: String,
/// }
Expand Down Expand Up @@ -79,8 +79,8 @@ where

impl<Request, Response> RemoteComponentClient<Request, Response>
where
Request: Serialize,
Response: DeserializeOwned,
Request: Serialize + DeserializeOwned + std::fmt::Debug + Clone,
Response: Serialize + DeserializeOwned + std::fmt::Debug,
{
pub fn new(ip_address: IpAddr, port: u16, max_retries: usize) -> Self {
let uri = match ip_address {
Expand All @@ -98,23 +98,25 @@ where
// Construct and request, and send it up to 'max_retries' times. Return if received a
// successful response.
for _ in 0..self.max_retries {
let http_request = self.construct_http_request(&component_request);
let http_request = self.construct_http_request(component_request.clone());
let res = self.try_send(http_request).await;
if res.is_ok() {
return res;
}
}
// Construct and send the request, return the received respone regardless whether it
// Construct and send the request, return the received response regardless whether it
// successful or not.
let http_request = self.construct_http_request(&component_request);
let http_request = self.construct_http_request(component_request);
self.try_send(http_request).await
}

fn construct_http_request(&self, component_request: &Request) -> HyperRequest<Body> {
fn construct_http_request(&self, component_request: Request) -> HyperRequest<Body> {
HyperRequest::post(self.uri.clone())
.header(CONTENT_TYPE, APPLICATION_OCTET_STREAM)
.body(Body::from(
serialize(component_request).expect("Request serialization should succeed"),
SerdeWrapper::new(component_request)
.to_bincode()
.expect("Request serialization should succeed"),
))
.expect("Request building should succeed")
}
Expand All @@ -138,12 +140,14 @@ where

async fn get_response_body<Response>(response: HyperResponse<Body>) -> Result<Response, ClientError>
where
Response: DeserializeOwned,
Response: Serialize + DeserializeOwned + std::fmt::Debug,
{
let body_bytes = to_bytes(response.into_body())
.await
.map_err(|e| ClientError::ResponseParsingFailure(Arc::new(e)))?;
deserialize(&body_bytes).map_err(|e| ClientError::ResponseDeserializationFailure(Arc::new(e)))

SerdeWrapper::<Response>::from_bincode(&body_bytes)
.map_err(|e| ClientError::ResponseDeserializationFailure(Arc::new(e)))
}

// Can't derive because derive forces the generics to also be `Clone`, which we prefer not to do
Expand Down
6 changes: 4 additions & 2 deletions crates/mempool_infra/src/component_definitions.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
use std::collections::BTreeMap;
use std::fmt::Debug;
use std::net::IpAddr;

use async_trait::async_trait;
use bincode::{deserialize, serialize};
use papyrus_config::dumping::{ser_param, SerializeConfig};
use papyrus_config::{ParamPath, ParamPrivacyInput, SerializedParam};
use serde::{Deserialize, Serialize};
use thiserror::Error;
use tokio::sync::mpsc::{Receiver, Sender};
use tracing::error;
use validator::Validate;

pub const APPLICATION_OCTET_STREAM: &str = "application/octet-stream";
const DEFAULT_CHANNEL_BUFFER_SIZE: usize = 32;
const DEFAULT_RETRIES: usize = 3;

Expand Down Expand Up @@ -45,8 +49,6 @@ where
pub tx: Sender<Response>,
}

pub const APPLICATION_OCTET_STREAM: &str = "application/octet-stream";

#[derive(Debug, Error, Deserialize, Serialize, Clone)]
pub enum ServerError {
#[error("Could not deserialize client request: {0}")]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,22 +1,24 @@
use std::net::{IpAddr, SocketAddr};
use std::sync::Arc;

use async_trait::async_trait;
use bincode::{deserialize, serialize};
use hyper::body::to_bytes;
use hyper::header::CONTENT_TYPE;
use hyper::service::{make_service_fn, service_fn};
use hyper::{Body, Request as HyperRequest, Response as HyperResponse, Server, StatusCode};
use serde::de::DeserializeOwned;
use serde::Serialize;
use tokio::sync::mpsc::Sender;
use tracing::instrument;

use super::definitions::ComponentServerStarter;
use crate::component_client::send_locally;
use crate::component_client::{send_locally, ClientError};
use crate::component_definitions::{
ComponentRequestAndResponseSender,
ServerError,
APPLICATION_OCTET_STREAM,
};
use crate::serde_utils::{BincodeSerializable, SerdeWrapper};

/// The `RemoteComponentServer` struct is a generic server that handles requests and responses for a
/// specified component. It receives requests, processes them using the provided component, and
Expand Down Expand Up @@ -63,14 +65,14 @@ use crate::component_definitions::{
/// }
///
/// // Define your request and response types
/// #[derive(Deserialize)]
/// #[derive(Serialize, Deserialize, Debug)]
/// struct MyRequest {
/// pub content: String,
/// }
///
/// #[derive(Serialize)]
/// #[derive(Serialize, Deserialize, Debug)]
/// struct MyResponse {
/// pub content: String,
/// content: String,
/// }
///
/// // Define your request processing logic
Expand Down Expand Up @@ -101,17 +103,17 @@ use crate::component_definitions::{
/// ```
pub struct RemoteComponentServer<Request, Response>
where
Request: DeserializeOwned + Send + Sync + 'static,
Response: Serialize + Send + Sync + 'static,
Request: Serialize + DeserializeOwned + Send + Sync + 'static,
Response: Serialize + DeserializeOwned + Send + Sync + 'static,
{
socket: SocketAddr,
tx: Sender<ComponentRequestAndResponseSender<Request, Response>>,
}

impl<Request, Response> RemoteComponentServer<Request, Response>
where
Request: DeserializeOwned + Send + Sync + 'static,
Response: Serialize + Send + Sync + 'static,
Request: Serialize + DeserializeOwned + std::fmt::Debug + Send + Sync + 'static,
Response: Serialize + DeserializeOwned + std::fmt::Debug + Send + Sync + 'static,
{
pub fn new(
tx: Sender<ComponentRequestAndResponseSender<Request, Response>>,
Expand All @@ -121,25 +123,33 @@ where
Self { tx, socket: SocketAddr::new(ip_address, port) }
}

async fn handler(
#[instrument(ret, err)]
async fn remote_component_server_handler(
http_request: HyperRequest<Body>,
tx: Sender<ComponentRequestAndResponseSender<Request, Response>>,
) -> Result<HyperResponse<Body>, hyper::Error> {
let body_bytes = to_bytes(http_request.into_body()).await?;
let http_response = match deserialize(&body_bytes) {

let http_response = match SerdeWrapper::<Request>::from_bincode(&body_bytes)
.map_err(|e| ClientError::ResponseDeserializationFailure(Arc::new(e)))
{
Ok(request) => {
let response = send_locally(tx, request).await;
HyperResponse::builder()
.status(StatusCode::OK)
.header(CONTENT_TYPE, APPLICATION_OCTET_STREAM)
.body(Body::from(
serialize(&response).expect("Response serialization should succeed"),
SerdeWrapper::new(response)
.to_bincode()
.expect("Response serialization should succeed"),
))
}
Err(error) => {
let server_error = ServerError::RequestDeserializationFailure(error.to_string());
HyperResponse::builder().status(StatusCode::BAD_REQUEST).body(Body::from(
serialize(&server_error).expect("Server error serialization should succeed"),
SerdeWrapper::new(server_error)
.to_bincode()
.expect("Server error serialization should succeed"),
))
}
}
Expand All @@ -152,13 +162,17 @@ where
#[async_trait]
impl<Request, Response> ComponentServerStarter for RemoteComponentServer<Request, Response>
where
Request: DeserializeOwned + Send + Sync + 'static,
Response: Serialize + Send + Sync + 'static,
Request: Serialize + DeserializeOwned + Send + Sync + std::fmt::Debug + 'static,
Response: Serialize + DeserializeOwned + Send + Sync + std::fmt::Debug + 'static,
{
async fn start(&mut self) {
let make_svc = make_service_fn(|_conn| {
let tx = self.tx.clone();
async { Ok::<_, hyper::Error>(service_fn(move |req| Self::handler(req, tx.clone()))) }
async {
Ok::<_, hyper::Error>(service_fn(move |req| {
Self::remote_component_server_handler(req, tx.clone())
}))
}
});

Server::bind(&self.socket.clone()).serve(make_svc).await.unwrap();
Expand Down
Loading

0 comments on commit cc2530b

Please sign in to comment.