Skip to content

Commit

Permalink
chore(mempool_infra): wrap serde operations with generic custom struc…
Browse files Browse the repository at this point in the history
…t towards binary serde

commit-id:325d0501
  • Loading branch information
Itay-Tsabary-Starkware committed Sep 21, 2024
1 parent 171eece commit 5f99898
Show file tree
Hide file tree
Showing 10 changed files with 180 additions and 67 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

28 changes: 14 additions & 14 deletions crates/gateway_types/src/communication.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,17 +65,17 @@ impl GatewayClient for LocalGatewayClientImpl {
}
}

#[async_trait]
impl GatewayClient for RemoteGatewayClientImpl {
#[instrument(skip(self))]
async fn add_tx(&self, gateway_input: GatewayInput) -> GatewayClientResult<TransactionHash> {
let request = GatewayRequest::AddTransaction(gateway_input);
let response = self.send(request).await?;
match response {
GatewayResponse::AddTransaction(Ok(response)) => Ok(response),
GatewayResponse::AddTransaction(Err(response)) => {
Err(GatewayClientError::GatewayError(response))
}
}
}
}
// #[async_trait]
// impl GatewayClient for RemoteGatewayClientImpl {
// #[instrument(skip(self))]
// async fn add_tx(&self, gateway_input: GatewayInput) -> GatewayClientResult<TransactionHash> {
// let request = GatewayRequest::AddTransaction(gateway_input);
// let response = self.send(request).await?;
// match response {
// GatewayResponse::AddTransaction(Ok(response)) => Ok(response),
// GatewayResponse::AddTransaction(Err(response)) => {
// Err(GatewayClientError::GatewayError(response))
// }
// }
// }
// }
4 changes: 3 additions & 1 deletion crates/mempool_infra/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@ hyper = { workspace = true, features = ["client", "http2", "server", "tcp"] }
papyrus_config.workspace = true
rstest.workspace = true
serde = { workspace = true, features = ["derive"] }
starknet-types-core.workspace = true
thiserror.workspace = true
tokio = { workspace = true, features = ["macros", "rt-multi-thread"] }
tracing.workspace = true
tracing-subscriber = { workspace = true, features = ["env-filter"] }
tracing.workspace = true
validator.workspace = true


[dev-dependencies]
assert_matches.workspace = true
pretty_assertions.workspace = true
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,14 @@ use std::marker::PhantomData;
use std::net::IpAddr;
use std::sync::Arc;

use bincode::{deserialize, serialize};
use hyper::body::to_bytes;
use hyper::header::CONTENT_TYPE;
use hyper::{Body, Client, Request as HyperRequest, Response as HyperResponse, StatusCode, Uri};
use serde::de::DeserializeOwned;
use serde::Serialize;

use super::definitions::{ClientError, ClientResult};
use crate::component_definitions::APPLICATION_OCTET_STREAM;
use crate::component_definitions::{BincodeSerializable, SerdeWrapper, APPLICATION_OCTET_STREAM};

/// The `RemoteComponentClient` struct is a generic client for sending component requests and
/// receiving responses asynchronously through HTTP connection.
Expand All @@ -35,12 +34,12 @@ use crate::component_definitions::APPLICATION_OCTET_STREAM;
/// use crate::starknet_mempool_infra::component_client::RemoteComponentClient;
///
/// // Define your request and response types
/// #[derive(Serialize)]
/// #[derive(Serialize, Deserialize, Debug, Clone)]
/// struct MyRequest {
/// pub content: String,
/// }
///
/// #[derive(Deserialize)]
/// #[derive(Serialize, Deserialize, Debug)]
/// struct MyResponse {
/// content: String,
/// }
Expand Down Expand Up @@ -79,8 +78,8 @@ where

impl<Request, Response> RemoteComponentClient<Request, Response>
where
Request: Serialize,
Response: DeserializeOwned,
Request: Serialize + DeserializeOwned + std::fmt::Debug + Clone,
Response: Serialize + DeserializeOwned + std::fmt::Debug,
{
pub fn new(ip_address: IpAddr, port: u16, max_retries: usize) -> Self {
let uri = match ip_address {
Expand All @@ -98,23 +97,25 @@ where
// Construct and request, and send it up to 'max_retries' times. Return if received a
// successful response.
for _ in 0..self.max_retries {
let http_request = self.construct_http_request(&component_request);
let http_request = self.construct_http_request(component_request.clone());
let res = self.try_send(http_request).await;
if res.is_ok() {
return res;
}
}
// Construct and send the request, return the received respone regardless whether it
// Construct and send the request, return the received response regardless whether it
// successful or not.
let http_request = self.construct_http_request(&component_request);
let http_request = self.construct_http_request(component_request);
self.try_send(http_request).await
}

fn construct_http_request(&self, component_request: &Request) -> HyperRequest<Body> {
fn construct_http_request(&self, component_request: Request) -> HyperRequest<Body> {
HyperRequest::post(self.uri.clone())
.header(CONTENT_TYPE, APPLICATION_OCTET_STREAM)
.body(Body::from(
serialize(component_request).expect("Request serialization should succeed"),
SerdeWrapper { data: component_request }
.to_bincode()
.expect("Request serialization should succeed"),
))
.expect("Request building should succeed")
}
Expand All @@ -138,12 +139,15 @@ where

async fn get_response_body<Response>(response: HyperResponse<Body>) -> Result<Response, ClientError>
where
Response: DeserializeOwned,
Response: Serialize + DeserializeOwned + std::fmt::Debug,
{
let body_bytes = to_bytes(response.into_body())
.await
.map_err(|e| ClientError::ResponseParsingFailure(Arc::new(e)))?;
deserialize(&body_bytes).map_err(|e| ClientError::ResponseDeserializationFailure(Arc::new(e)))

SerdeWrapper::<Response>::from_bincode(&body_bytes)
.map_err(|e| ClientError::ResponseDeserializationFailure(Arc::new(e)))
.map(|open| open.data)
}

// Can't derive because derive forces the generics to also be `Clone`, which we prefer not to do
Expand Down
35 changes: 35 additions & 0 deletions crates/mempool_infra/src/component_definitions.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,24 @@
use std::collections::BTreeMap;
use std::fmt::Debug;
use std::net::IpAddr;

use async_trait::async_trait;
use bincode::{deserialize, serialize};
use papyrus_config::dumping::{ser_param, SerializeConfig};
use papyrus_config::{ParamPath, ParamPrivacyInput, SerializedParam};
use serde::{Deserialize, Serialize};
use thiserror::Error;
use tokio::sync::mpsc::{Receiver, Sender};
use tracing::error;
use validator::Validate;

const DEFAULT_CHANNEL_BUFFER_SIZE: usize = 32;
const DEFAULT_RETRIES: usize = 3;

#[cfg(test)]
#[path = "component_definitions_test.rs"]
pub mod component_definitions_test;

#[async_trait]
pub trait ComponentRequestHandler<Request, Response> {
async fn handle_request(&mut self, request: Request) -> Response;
Expand Down Expand Up @@ -114,3 +121,31 @@ impl Default for RemoteComponentCommunicationConfig {
Self { ip: "0.0.0.0".parse().unwrap(), port: 8080, retries: DEFAULT_RETRIES }
}
}

// Generic wrapper struct
#[derive(Serialize, Deserialize, std::fmt::Debug)]
pub struct SerdeWrapper<T> {
pub data: T,
}

// Trait to define our serialization and deserialization behavior
pub trait BincodeSerializable: Sized {
fn to_bincode(&self) -> Result<Vec<u8>, bincode::Error>;
fn from_bincode(bytes: &[u8]) -> Result<Self, bincode::Error>;
}

// Implement the trait for our wrapper
impl<T: Serialize + for<'de> Deserialize<'de>> BincodeSerializable for SerdeWrapper<T>
where
T: std::fmt::Debug,
{
// #[instrument(err, ret)]
fn to_bincode(&self) -> Result<Vec<u8>, bincode::Error> {
serialize(self)
}

// #[instrument(err, ret)]
fn from_bincode(bytes: &[u8]) -> Result<Self, bincode::Error> {
deserialize(bytes)
}
}
46 changes: 46 additions & 0 deletions crates/mempool_infra/src/component_definitions_test.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
use serde::{Deserialize, Serialize};
use starknet_types_core::felt::Felt;

use crate::component_definitions::{BincodeSerializable, SerdeWrapper};
use crate::trace_util::configure_tracing;

#[test]
fn test_serde_native_type() {
let data: u32 = 8;

let encoded =
SerdeWrapper { data }.to_bincode().expect("Server error serialization should succeed");
let decoded = SerdeWrapper::<u32>::from_bincode(&encoded).unwrap();

assert_eq!(data, decoded.data);
}

#[test]
fn test_serde_struct_type() {
#[derive(Serialize, Deserialize, std::fmt::Debug, Clone, std::cmp::PartialEq, Copy)]
struct TestStruct {
a: u32,
b: u32,
}

let data: TestStruct = TestStruct { a: 17, b: 8 };

let encoded =
SerdeWrapper { data }.to_bincode().expect("Server error serialization should succeed");
let decoded = SerdeWrapper::<TestStruct>::from_bincode(&encoded).unwrap();

assert_eq!(data, decoded.data);
}

#[test]
fn test_serde_felt() {
configure_tracing();

let data: Felt = Felt::ONE;

let encoded =
SerdeWrapper { data }.to_bincode().expect("Server error serialization should succeed");
let decoded = SerdeWrapper::<Felt>::from_bincode(&encoded).unwrap();

assert_eq!(data, decoded.data);
}
Original file line number Diff line number Diff line change
@@ -1,19 +1,22 @@
use std::net::{IpAddr, SocketAddr};
use std::sync::Arc;

use async_trait::async_trait;
use bincode::{deserialize, serialize};
use hyper::body::to_bytes;
use hyper::header::CONTENT_TYPE;
use hyper::service::{make_service_fn, service_fn};
use hyper::{Body, Request as HyperRequest, Response as HyperResponse, Server, StatusCode};
use serde::de::DeserializeOwned;
use serde::Serialize;
use tokio::sync::mpsc::Sender;
use tracing::instrument;

use super::definitions::ComponentServerStarter;
use crate::component_client::send_locally;
use crate::component_client::{send_locally, ClientError};
use crate::component_definitions::{
BincodeSerializable,
ComponentRequestAndResponseSender,
SerdeWrapper,
ServerError,
APPLICATION_OCTET_STREAM,
};
Expand Down Expand Up @@ -63,14 +66,14 @@ use crate::component_definitions::{
/// }
///
/// // Define your request and response types
/// #[derive(Deserialize)]
/// #[derive(Serialize, Deserialize, Debug)]
/// struct MyRequest {
/// pub content: String,
/// }
///
/// #[derive(Serialize)]
/// #[derive(Serialize, Deserialize, Debug)]
/// struct MyResponse {
/// pub content: String,
/// content: String,
/// }
///
/// // Define your request processing logic
Expand Down Expand Up @@ -101,17 +104,17 @@ use crate::component_definitions::{
/// ```
pub struct RemoteComponentServer<Request, Response>
where
Request: DeserializeOwned + Send + Sync + 'static,
Response: Serialize + Send + Sync + 'static,
Request: Serialize + DeserializeOwned + Send + Sync + 'static,
Response: Serialize + DeserializeOwned + Send + Sync + 'static,
{
socket: SocketAddr,
tx: Sender<ComponentRequestAndResponseSender<Request, Response>>,
}

impl<Request, Response> RemoteComponentServer<Request, Response>
where
Request: DeserializeOwned + Send + Sync + 'static,
Response: Serialize + Send + Sync + 'static,
Request: Serialize + DeserializeOwned + std::fmt::Debug + Send + Sync + 'static,
Response: Serialize + DeserializeOwned + std::fmt::Debug + Send + Sync + 'static,
{
pub fn new(
tx: Sender<ComponentRequestAndResponseSender<Request, Response>>,
Expand All @@ -121,25 +124,34 @@ where
Self { tx, socket: SocketAddr::new(ip_address, port) }
}

async fn handler(
#[instrument(ret, err)]
async fn remote_component_server_handler(
http_request: HyperRequest<Body>,
tx: Sender<ComponentRequestAndResponseSender<Request, Response>>,
) -> Result<HyperResponse<Body>, hyper::Error> {
let body_bytes = to_bytes(http_request.into_body()).await?;
let http_response = match deserialize(&body_bytes) {

let http_response = match SerdeWrapper::<Request>::from_bincode(&body_bytes)
.map_err(|e| ClientError::ResponseDeserializationFailure(Arc::new(e)))
.map(|open| open.data)
{
Ok(request) => {
let response = send_locally(tx, request).await;
HyperResponse::builder()
.status(StatusCode::OK)
.header(CONTENT_TYPE, APPLICATION_OCTET_STREAM)
.body(Body::from(
serialize(&response).expect("Response serialization should succeed"),
SerdeWrapper { data: response }
.to_bincode()
.expect("Response serialization should succeed"),
))
}
Err(error) => {
let server_error = ServerError::RequestDeserializationFailure(error.to_string());
HyperResponse::builder().status(StatusCode::BAD_REQUEST).body(Body::from(
serialize(&server_error).expect("Server error serialization should succeed"),
SerdeWrapper { data: server_error }
.to_bincode()
.expect("Server error serialization should succeed"),
))
}
}
Expand All @@ -152,13 +164,17 @@ where
#[async_trait]
impl<Request, Response> ComponentServerStarter for RemoteComponentServer<Request, Response>
where
Request: DeserializeOwned + Send + Sync + 'static,
Response: Serialize + Send + Sync + 'static,
Request: Serialize + DeserializeOwned + Send + Sync + std::fmt::Debug + 'static,
Response: Serialize + DeserializeOwned + Send + Sync + std::fmt::Debug + 'static,
{
async fn start(&mut self) {
let make_svc = make_service_fn(|_conn| {
let tx = self.tx.clone();
async { Ok::<_, hyper::Error>(service_fn(move |req| Self::handler(req, tx.clone()))) }
async {
Ok::<_, hyper::Error>(service_fn(move |req| {
Self::remote_component_server_handler(req, tx.clone())
}))
}
});

Server::bind(&self.socket.clone()).serve(make_svc).await.unwrap();
Expand Down
Loading

0 comments on commit 5f99898

Please sign in to comment.