From 8c2ab8a553c55a995a0c669170890df12362f93f Mon Sep 17 00:00:00 2001 From: Itay Tsabary Date: Thu, 19 Sep 2024 16:59:20 +0300 Subject: [PATCH] chore(mempool_infra): wrap serde operations with generic custom struct towards binary serde commit-id:325d0501 --- Cargo.lock | 1 + crates/batcher_types/src/batcher_types.rs | 8 ++-- crates/batcher_types/src/communication.rs | 4 +- .../src/communication.rs | 5 +- .../src/consensus_manager_types.rs | 10 ++-- crates/gateway_types/src/communication.rs | 28 +++++------ crates/mempool_infra/Cargo.toml | 2 + .../remote_component_client.rs | 28 ++++++----- .../src/component_definitions.rs | 5 +- .../remote_component_server.rs | 46 ++++++++++++------- crates/mempool_infra/tests/common/mod.rs | 18 ++++---- .../local_component_client_server_test.rs | 5 +- .../remote_component_client_server_test.rs | 37 ++++++++------- crates/mempool_types/src/communication.rs | 6 +-- 14 files changed, 117 insertions(+), 86 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 199073a7a6..f70cbdc4d5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -9886,6 +9886,7 @@ dependencies = [ "pretty_assertions", "rstest", "serde", + "starknet-types-core", "thiserror", "tokio", "tracing", diff --git a/crates/batcher_types/src/batcher_types.rs b/crates/batcher_types/src/batcher_types.rs index 46ed01af10..10b24e01ff 100644 --- a/crates/batcher_types/src/batcher_types.rs +++ b/crates/batcher_types/src/batcher_types.rs @@ -3,19 +3,19 @@ use serde::{Deserialize, Serialize}; use crate::errors::BatcherError; // TODO(Tsabary/Yael/Dafna): Populate the data structure used to invoke the batcher. -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, Clone)] pub struct BatcherFnOneInput {} // TODO(Tsabary/Yael/Dafna): Populate the data structure used to invoke the batcher. -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, Clone)] pub struct BatcherFnTwoInput {} // TODO(Tsabary/Yael/Dafna): Replace with the actual return type of the batcher function. -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, Clone)] pub struct BatcherFnOneReturnValue {} // TODO(Tsabary/Yael/Dafna): Replace with the actual return type of the batcher function. -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, Clone)] pub struct BatcherFnTwoReturnValue {} pub type BatcherResult = Result; diff --git a/crates/batcher_types/src/communication.rs b/crates/batcher_types/src/communication.rs index b70ba698be..c4411e38e6 100644 --- a/crates/batcher_types/src/communication.rs +++ b/crates/batcher_types/src/communication.rs @@ -45,13 +45,13 @@ pub trait BatcherClient: Send + Sync { ) -> BatcherClientResult; } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, Clone)] pub enum BatcherRequest { BatcherFnOne(BatcherFnOneInput), BatcherFnTwo(BatcherFnTwoInput), } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, Clone)] pub enum BatcherResponse { BatcherFnOne(BatcherResult), BatcherFnTwo(BatcherResult), diff --git a/crates/consensus_manager_types/src/communication.rs b/crates/consensus_manager_types/src/communication.rs index e008df2e97..665f017f94 100644 --- a/crates/consensus_manager_types/src/communication.rs +++ b/crates/consensus_manager_types/src/communication.rs @@ -1,3 +1,4 @@ +use std::fmt::Debug; use std::sync::Arc; use async_trait::async_trait; @@ -47,13 +48,13 @@ pub trait ConsensusManagerClient: Send + Sync { ) -> ConsensusManagerClientResult; } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, Clone)] pub enum ConsensusManagerRequest { ConsensusManagerFnOne(ConsensusManagerFnOneInput), ConsensusManagerFnTwo(ConsensusManagerFnTwoInput), } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, Clone)] pub enum ConsensusManagerResponse { ConsensusManagerFnOne(ConsensusManagerResult), ConsensusManagerFnTwo(ConsensusManagerResult), diff --git a/crates/consensus_manager_types/src/consensus_manager_types.rs b/crates/consensus_manager_types/src/consensus_manager_types.rs index a1eb9b3eb0..26dffac1bc 100644 --- a/crates/consensus_manager_types/src/consensus_manager_types.rs +++ b/crates/consensus_manager_types/src/consensus_manager_types.rs @@ -1,21 +1,23 @@ +use std::fmt::Debug; + use serde::{Deserialize, Serialize}; use crate::errors::ConsensusManagerError; // TODO(Tsabary/Matan): Populate the data structure used to invoke the consensus manager. -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, Clone)] pub struct ConsensusManagerFnOneInput {} // TODO(Tsabary/Matan): Populate the data structure used to invoke the consensus manager. -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, Clone)] pub struct ConsensusManagerFnTwoInput {} // TODO(Tsabary/Matan): Replace with the actual return type of the consensus manager function. -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, Clone)] pub struct ConsensusManagerFnOneReturnValue {} // TODO(Tsabary/Matan): Replace with the actual return type of the consensus manager function. -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, Clone)] pub struct ConsensusManagerFnTwoReturnValue {} pub type ConsensusManagerResult = Result; diff --git a/crates/gateway_types/src/communication.rs b/crates/gateway_types/src/communication.rs index 8150bbf5bf..0f1e3fa1ff 100644 --- a/crates/gateway_types/src/communication.rs +++ b/crates/gateway_types/src/communication.rs @@ -65,17 +65,17 @@ impl GatewayClient for LocalGatewayClientImpl { } } -#[async_trait] -impl GatewayClient for RemoteGatewayClientImpl { - #[instrument(skip(self))] - async fn add_tx(&self, gateway_input: GatewayInput) -> GatewayClientResult { - let request = GatewayRequest::AddTransaction(gateway_input); - let response = self.send(request).await?; - match response { - GatewayResponse::AddTransaction(Ok(response)) => Ok(response), - GatewayResponse::AddTransaction(Err(response)) => { - Err(GatewayClientError::GatewayError(response)) - } - } - } -} +// #[async_trait] +// impl GatewayClient for RemoteGatewayClientImpl { +// #[instrument(skip(self))] +// async fn add_tx(&self, gateway_input: GatewayInput) -> GatewayClientResult { +// let request = GatewayRequest::AddTransaction(gateway_input); +// let response = self.send(request).await?; +// match response { +// GatewayResponse::AddTransaction(Ok(response)) => Ok(response), +// GatewayResponse::AddTransaction(Err(response)) => { +// Err(GatewayClientError::GatewayError(response)) +// } +// } +// } +// } diff --git a/crates/mempool_infra/Cargo.toml b/crates/mempool_infra/Cargo.toml index 3bdb5b3bd3..a083e450f3 100644 --- a/crates/mempool_infra/Cargo.toml +++ b/crates/mempool_infra/Cargo.toml @@ -18,12 +18,14 @@ hyper = { workspace = true, features = ["client", "http2", "server", "tcp"] } papyrus_config.workspace = true rstest.workspace = true serde = { workspace = true, features = ["derive"] } +starknet-types-core.workspace = true thiserror.workspace = true tokio = { workspace = true, features = ["macros", "rt-multi-thread"] } tracing.workspace = true tracing-subscriber = { workspace = true, features = ["env-filter"] } validator.workspace = true + [dev-dependencies] assert_matches.workspace = true pretty_assertions.workspace = true diff --git a/crates/mempool_infra/src/component_client/remote_component_client.rs b/crates/mempool_infra/src/component_client/remote_component_client.rs index c7cc5f8121..3d2a59cb24 100644 --- a/crates/mempool_infra/src/component_client/remote_component_client.rs +++ b/crates/mempool_infra/src/component_client/remote_component_client.rs @@ -2,7 +2,6 @@ use std::marker::PhantomData; use std::net::IpAddr; use std::sync::Arc; -use bincode::{deserialize, serialize}; use hyper::body::to_bytes; use hyper::header::CONTENT_TYPE; use hyper::{Body, Client, Request as HyperRequest, Response as HyperResponse, StatusCode, Uri}; @@ -11,6 +10,7 @@ use serde::Serialize; use super::definitions::{ClientError, ClientResult}; use crate::component_definitions::APPLICATION_OCTET_STREAM; +use crate::serde_utils::{BincodeSerializable, SerdeWrapper}; /// The `RemoteComponentClient` struct is a generic client for sending component requests and /// receiving responses asynchronously through HTTP connection. @@ -35,12 +35,12 @@ use crate::component_definitions::APPLICATION_OCTET_STREAM; /// use crate::starknet_mempool_infra::component_client::RemoteComponentClient; /// /// // Define your request and response types -/// #[derive(Serialize)] +/// #[derive(Serialize, Deserialize, Debug, Clone)] /// struct MyRequest { /// pub content: String, /// } /// -/// #[derive(Deserialize)] +/// #[derive(Serialize, Deserialize, Debug)] /// struct MyResponse { /// content: String, /// } @@ -79,8 +79,8 @@ where impl RemoteComponentClient where - Request: Serialize, - Response: DeserializeOwned, + Request: Serialize + DeserializeOwned + std::fmt::Debug + Clone, + Response: Serialize + DeserializeOwned + std::fmt::Debug, { pub fn new(ip_address: IpAddr, port: u16, max_retries: usize) -> Self { let uri = match ip_address { @@ -98,23 +98,25 @@ where // Construct and request, and send it up to 'max_retries' times. Return if received a // successful response. for _ in 0..self.max_retries { - let http_request = self.construct_http_request(&component_request); + let http_request = self.construct_http_request(component_request.clone()); let res = self.try_send(http_request).await; if res.is_ok() { return res; } } - // Construct and send the request, return the received respone regardless whether it + // Construct and send the request, return the received response regardless whether it // successful or not. - let http_request = self.construct_http_request(&component_request); + let http_request = self.construct_http_request(component_request); self.try_send(http_request).await } - fn construct_http_request(&self, component_request: &Request) -> HyperRequest { + fn construct_http_request(&self, component_request: Request) -> HyperRequest { HyperRequest::post(self.uri.clone()) .header(CONTENT_TYPE, APPLICATION_OCTET_STREAM) .body(Body::from( - serialize(component_request).expect("Request serialization should succeed"), + SerdeWrapper::new(component_request) + .to_bincode() + .expect("Request serialization should succeed"), )) .expect("Request building should succeed") } @@ -138,12 +140,14 @@ where async fn get_response_body(response: HyperResponse) -> Result where - Response: DeserializeOwned, + Response: Serialize + DeserializeOwned + std::fmt::Debug, { let body_bytes = to_bytes(response.into_body()) .await .map_err(|e| ClientError::ResponseParsingFailure(Arc::new(e)))?; - deserialize(&body_bytes).map_err(|e| ClientError::ResponseDeserializationFailure(Arc::new(e))) + + SerdeWrapper::::from_bincode(&body_bytes) + .map_err(|e| ClientError::ResponseDeserializationFailure(Arc::new(e))) } // Can't derive because derive forces the generics to also be `Clone`, which we prefer not to do diff --git a/crates/mempool_infra/src/component_definitions.rs b/crates/mempool_infra/src/component_definitions.rs index 5cd742a85f..6f0d0ee031 100644 --- a/crates/mempool_infra/src/component_definitions.rs +++ b/crates/mempool_infra/src/component_definitions.rs @@ -1,4 +1,5 @@ use std::collections::BTreeMap; +use std::fmt::Debug; use std::net::IpAddr; use async_trait::async_trait; @@ -7,8 +8,10 @@ use papyrus_config::{ParamPath, ParamPrivacyInput, SerializedParam}; use serde::{Deserialize, Serialize}; use thiserror::Error; use tokio::sync::mpsc::{Receiver, Sender}; +use tracing::error; use validator::Validate; +pub const APPLICATION_OCTET_STREAM: &str = "application/octet-stream"; const DEFAULT_CHANNEL_BUFFER_SIZE: usize = 32; const DEFAULT_RETRIES: usize = 3; @@ -45,8 +48,6 @@ where pub tx: Sender, } -pub const APPLICATION_OCTET_STREAM: &str = "application/octet-stream"; - #[derive(Debug, Error, Deserialize, Serialize, Clone)] pub enum ServerError { #[error("Could not deserialize client request: {0}")] diff --git a/crates/mempool_infra/src/component_server/remote_component_server.rs b/crates/mempool_infra/src/component_server/remote_component_server.rs index 88c09f7f57..2a1ce7a781 100644 --- a/crates/mempool_infra/src/component_server/remote_component_server.rs +++ b/crates/mempool_infra/src/component_server/remote_component_server.rs @@ -1,7 +1,7 @@ use std::net::{IpAddr, SocketAddr}; +use std::sync::Arc; use async_trait::async_trait; -use bincode::{deserialize, serialize}; use hyper::body::to_bytes; use hyper::header::CONTENT_TYPE; use hyper::service::{make_service_fn, service_fn}; @@ -9,14 +9,16 @@ use hyper::{Body, Request as HyperRequest, Response as HyperResponse, Server, St use serde::de::DeserializeOwned; use serde::Serialize; use tokio::sync::mpsc::Sender; +use tracing::instrument; use super::definitions::ComponentServerStarter; -use crate::component_client::send_locally; +use crate::component_client::{send_locally, ClientError}; use crate::component_definitions::{ ComponentRequestAndResponseSender, ServerError, APPLICATION_OCTET_STREAM, }; +use crate::serde_utils::{BincodeSerializable, SerdeWrapper}; /// The `RemoteComponentServer` struct is a generic server that handles requests and responses for a /// specified component. It receives requests, processes them using the provided component, and @@ -63,14 +65,14 @@ use crate::component_definitions::{ /// } /// /// // Define your request and response types -/// #[derive(Deserialize)] +/// #[derive(Serialize, Deserialize, Debug)] /// struct MyRequest { /// pub content: String, /// } /// -/// #[derive(Serialize)] +/// #[derive(Serialize, Deserialize, Debug)] /// struct MyResponse { -/// pub content: String, +/// content: String, /// } /// /// // Define your request processing logic @@ -101,8 +103,8 @@ use crate::component_definitions::{ /// ``` pub struct RemoteComponentServer where - Request: DeserializeOwned + Send + Sync + 'static, - Response: Serialize + Send + Sync + 'static, + Request: Serialize + DeserializeOwned + Send + Sync + 'static, + Response: Serialize + DeserializeOwned + Send + Sync + 'static, { socket: SocketAddr, tx: Sender>, @@ -110,8 +112,8 @@ where impl RemoteComponentServer where - Request: DeserializeOwned + Send + Sync + 'static, - Response: Serialize + Send + Sync + 'static, + Request: Serialize + DeserializeOwned + std::fmt::Debug + Send + Sync + 'static, + Response: Serialize + DeserializeOwned + std::fmt::Debug + Send + Sync + 'static, { pub fn new( tx: Sender>, @@ -121,25 +123,33 @@ where Self { tx, socket: SocketAddr::new(ip_address, port) } } - async fn handler( + #[instrument(ret, err)] + async fn remote_component_server_handler( http_request: HyperRequest, tx: Sender>, ) -> Result, hyper::Error> { let body_bytes = to_bytes(http_request.into_body()).await?; - let http_response = match deserialize(&body_bytes) { + + let http_response = match SerdeWrapper::::from_bincode(&body_bytes) + .map_err(|e| ClientError::ResponseDeserializationFailure(Arc::new(e))) + { Ok(request) => { let response = send_locally(tx, request).await; HyperResponse::builder() .status(StatusCode::OK) .header(CONTENT_TYPE, APPLICATION_OCTET_STREAM) .body(Body::from( - serialize(&response).expect("Response serialization should succeed"), + SerdeWrapper::new(response) + .to_bincode() + .expect("Response serialization should succeed"), )) } Err(error) => { let server_error = ServerError::RequestDeserializationFailure(error.to_string()); HyperResponse::builder().status(StatusCode::BAD_REQUEST).body(Body::from( - serialize(&server_error).expect("Server error serialization should succeed"), + SerdeWrapper::new(server_error) + .to_bincode() + .expect("Server error serialization should succeed"), )) } } @@ -152,13 +162,17 @@ where #[async_trait] impl ComponentServerStarter for RemoteComponentServer where - Request: DeserializeOwned + Send + Sync + 'static, - Response: Serialize + Send + Sync + 'static, + Request: Serialize + DeserializeOwned + Send + Sync + std::fmt::Debug + 'static, + Response: Serialize + DeserializeOwned + Send + Sync + std::fmt::Debug + 'static, { async fn start(&mut self) { let make_svc = make_service_fn(|_conn| { let tx = self.tx.clone(); - async { Ok::<_, hyper::Error>(service_fn(move |req| Self::handler(req, tx.clone()))) } + async { + Ok::<_, hyper::Error>(service_fn(move |req| { + Self::remote_component_server_handler(req, tx.clone()) + })) + } }); Server::bind(&self.socket.clone()).serve(make_svc).await.unwrap(); diff --git a/crates/mempool_infra/tests/common/mod.rs b/crates/mempool_infra/tests/common/mod.rs index c5177cc39f..a21ed7df20 100644 --- a/crates/mempool_infra/tests/common/mod.rs +++ b/crates/mempool_infra/tests/common/mod.rs @@ -2,30 +2,31 @@ use async_trait::async_trait; use serde::{Deserialize, Serialize}; use starknet_mempool_infra::component_client::ClientResult; use starknet_mempool_infra::component_runner::ComponentStarter; +use starknet_types_core::felt::Felt; -pub(crate) type ValueA = u32; -pub(crate) type ValueB = u8; +pub(crate) type ValueA = Felt; +pub(crate) type ValueB = Felt; pub(crate) type ResultA = ClientResult; pub(crate) type ResultB = ClientResult; -#[derive(Serialize, Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug, Clone)] pub enum ComponentARequest { AGetValue, } -#[derive(Serialize, Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug, Clone)] pub enum ComponentAResponse { AGetValue(ValueA), } -#[derive(Serialize, Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug, Clone)] pub enum ComponentBRequest { BGetValue, BSetValue(ValueB), } -#[derive(Serialize, Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug, Clone)] pub enum ComponentBResponse { BGetValue(ValueB), BSetValue, @@ -52,8 +53,7 @@ impl ComponentA { } pub async fn a_get_value(&self) -> ValueA { - let b_value = self.b.b_get_value().await.unwrap(); - b_value.into() + self.b.b_get_value().await.unwrap() } } @@ -92,7 +92,7 @@ pub(crate) async fn test_a_b_functionality( let new_expected_value: ValueA = expected_value + 1; // Check that setting a new value to component B succeeds. - assert!(b_client.b_set_value(new_expected_value.try_into().unwrap()).await.is_ok()); + assert!(b_client.b_set_value(new_expected_value).await.is_ok()); // Check the new value in component B through client A. assert_eq!(a_client.a_get_value().await.unwrap(), new_expected_value); } diff --git a/crates/mempool_infra/tests/local_component_client_server_test.rs b/crates/mempool_infra/tests/local_component_client_server_test.rs index 09ff93f0f7..8a4204cb64 100644 --- a/crates/mempool_infra/tests/local_component_client_server_test.rs +++ b/crates/mempool_infra/tests/local_component_client_server_test.rs @@ -17,6 +17,7 @@ use starknet_mempool_infra::component_definitions::{ ComponentRequestHandler, }; use starknet_mempool_infra::component_server::{ComponentServerStarter, LocalComponentServer}; +use starknet_types_core::felt::Felt; use tokio::sync::mpsc::channel; use tokio::task; @@ -81,8 +82,8 @@ impl ComponentRequestHandler for Componen #[tokio::test] async fn test_setup() { - let setup_value: ValueB = 30; - let expected_value: ValueA = setup_value.into(); + let setup_value: ValueB = Felt::from(30); + let expected_value: ValueA = setup_value; let (tx_a, rx_a) = channel::>(32); diff --git a/crates/mempool_infra/tests/remote_component_client_server_test.rs b/crates/mempool_infra/tests/remote_component_client_server_test.rs index c1758079f9..f926038dc6 100644 --- a/crates/mempool_infra/tests/remote_component_client_server_test.rs +++ b/crates/mempool_infra/tests/remote_component_client_server_test.rs @@ -1,10 +1,10 @@ mod common; +use std::fmt::Debug; use std::net::{IpAddr, Ipv6Addr, SocketAddr}; use std::sync::Arc; use async_trait::async_trait; -use bincode::{deserialize, serialize}; use common::{ ComponentAClientTrait, ComponentARequest, @@ -21,6 +21,7 @@ use hyper::header::CONTENT_TYPE; use hyper::service::{make_service_fn, service_fn}; use hyper::{Body, Client, Request, Response, Server, StatusCode, Uri}; use rstest::rstest; +use serde::de::DeserializeOwned; use serde::Serialize; use starknet_mempool_infra::component_client::{ClientError, ClientResult, RemoteComponentClient}; use starknet_mempool_infra::component_definitions::{ @@ -34,6 +35,8 @@ use starknet_mempool_infra::component_server::{ LocalComponentServer, RemoteComponentServer, }; +use starknet_mempool_infra::serde_utils::{BincodeSerializable, SerdeWrapper}; +use starknet_types_core::felt::Felt; use tokio::sync::mpsc::channel; use tokio::sync::Mutex; use tokio::task; @@ -59,7 +62,7 @@ const ARBITRARY_DATA: &str = "arbitrary data"; const DESERIALIZE_REQ_ERROR_MESSAGE: &str = "Could not deserialize client request"; // ClientError::ResponseDeserializationFailure error message. const DESERIALIZE_RES_ERROR_MESSAGE: &str = "Could not deserialize server response"; -const VALID_VALUE_A: ValueA = 1; +const VALID_VALUE_A: ValueA = Felt::ONE; #[async_trait] impl ComponentAClientTrait for RemoteComponentClient { @@ -133,16 +136,19 @@ fn assert_error_contains_keywords(error: String, expected_error_contained_keywor async fn create_client_and_faulty_server(port: u16, body: T) -> ComponentAClient where - T: Serialize + Send + Sync + 'static + Clone, + T: Serialize + DeserializeOwned + Debug + Send + Sync + 'static + Clone, { task::spawn(async move { - async fn handler( + async fn handler( _http_request: Request, body: T, - ) -> Result, hyper::Error> { + ) -> Result, hyper::Error> + where + T: Serialize + DeserializeOwned + Debug + Send + Sync + Clone, + { Ok(Response::builder() .status(StatusCode::BAD_REQUEST) - .body(Body::from(serialize(&body).unwrap())) + .body(Body::from(SerdeWrapper::new(body).to_bincode().unwrap())) .unwrap()) } @@ -200,18 +206,18 @@ async fn setup_for_tests(setup_value: ValueB, a_port: u16, b_port: u16) { #[tokio::test] async fn test_proper_setup() { - let setup_value: ValueB = 90; + let setup_value: ValueB = Felt::from(90); setup_for_tests(setup_value, A_PORT_TEST_SETUP, B_PORT_TEST_SETUP).await; let a_client = ComponentAClient::new(LOCAL_IP, A_PORT_TEST_SETUP, MAX_RETRIES); let b_client = ComponentBClient::new(LOCAL_IP, B_PORT_TEST_SETUP, MAX_RETRIES); - test_a_b_functionality(a_client, b_client, setup_value.into()).await; + test_a_b_functionality(a_client, b_client, setup_value).await; } #[tokio::test] async fn test_faulty_client_setup() { // Todo(uriel): Find a better way to pass expected value to the setup // 123 is some arbitrary value, we don't check it anyway. - setup_for_tests(123, A_PORT_FAULTY_CLIENT, B_PORT_FAULTY_CLIENT).await; + setup_for_tests(Felt::from(123), A_PORT_FAULTY_CLIENT, B_PORT_FAULTY_CLIENT).await; struct FaultyAClient; @@ -223,13 +229,12 @@ async fn test_faulty_client_setup() { format!("http://[{}]:{}/", LOCAL_IP, A_PORT_FAULTY_CLIENT).parse().unwrap(); let http_request = Request::post(uri) .header(CONTENT_TYPE, APPLICATION_OCTET_STREAM) - .body(Body::from(serialize(&component_request).unwrap())) + .body(Body::from(SerdeWrapper::new(component_request).to_bincode().unwrap())) .unwrap(); let http_response = Client::new().request(http_request).await.unwrap(); let status_code = http_response.status(); let body_bytes = to_bytes(http_response.into_body()).await.unwrap(); - let response: ServerError = deserialize(&body_bytes).unwrap(); - + let response = SerdeWrapper::::from_bincode(&body_bytes).unwrap(); Err(ClientError::ResponseError(status_code, response)) } } @@ -256,7 +261,7 @@ async fn test_unconnected_server() { &[StatusCode::BAD_REQUEST.as_str(),DESERIALIZE_REQ_ERROR_MESSAGE, MOCK_SERVER_ERROR], )] #[case::response_deserialization_failure( - create_client_and_faulty_server(FAULTY_SERVER_RES_DESER_PORT,ARBITRARY_DATA).await, + create_client_and_faulty_server(FAULTY_SERVER_RES_DESER_PORT,ARBITRARY_DATA.to_string()).await, &[DESERIALIZE_RES_ERROR_MESSAGE], )] #[tokio::test] @@ -281,12 +286,12 @@ async fn test_retry_request() { let ret = if *should_send_ok { Response::builder() .status(StatusCode::OK) - .body(Body::from(serialize(&body).unwrap())) + .body(Body::from(SerdeWrapper::new(body).to_bincode().unwrap())) .unwrap() } else { Response::builder() .status(StatusCode::IM_A_TEAPOT) - .body(Body::from(serialize(&body).unwrap())) + .body(Body::from(SerdeWrapper::new(body).to_bincode().unwrap())) .unwrap() }; *should_send_ok = !*should_send_ok; @@ -316,6 +321,6 @@ async fn test_retry_request() { // The current server state is 'false', hence the first and only attempt returns an error. let a_client_no_retry = ComponentAClient::new(LOCAL_IP, RETRY_REQ_PORT, 0); - let expected_error_contained_keywords = [DESERIALIZE_RES_ERROR_MESSAGE]; + let expected_error_contained_keywords = [StatusCode::IM_A_TEAPOT.as_str()]; verify_error(a_client_no_retry.clone(), &expected_error_contained_keywords).await; } diff --git a/crates/mempool_types/src/communication.rs b/crates/mempool_types/src/communication.rs index 9f06d884b1..07b23d45d3 100644 --- a/crates/mempool_types/src/communication.rs +++ b/crates/mempool_types/src/communication.rs @@ -26,7 +26,7 @@ pub type MempoolRequestAndResponseSender = ComponentRequestAndResponseSender; pub type SharedMempoolClient = Arc; -#[derive(Debug, PartialEq, Serialize, Deserialize)] +#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)] pub struct MempoolWrapperInput { pub mempool_input: MempoolInput, pub message_metadata: Option, @@ -43,13 +43,13 @@ pub trait MempoolClient: Send + Sync { async fn get_txs(&self, n_txs: usize) -> MempoolClientResult>; } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, Clone)] pub enum MempoolRequest { AddTransaction(MempoolWrapperInput), GetTransactions(usize), } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, Clone)] pub enum MempoolResponse { AddTransaction(MempoolResult<()>), GetTransactions(MempoolResult>),