-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: adding server for active components
commit-id:fa8f56cc
- Loading branch information
1 parent
93de0bd
commit f943d3d
Showing
3 changed files
with
257 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
187 changes: 187 additions & 0 deletions
187
crates/mempool_infra/tests/active_component_server_client_test.rs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,187 @@ | ||
use std::future::pending; | ||
use std::sync::Arc; | ||
|
||
use async_trait::async_trait; | ||
use serde::{Deserialize, Serialize}; | ||
use starknet_mempool_infra::component_client::definitions::{ClientError, ClientResult}; | ||
use starknet_mempool_infra::component_client::local_component_client::LocalComponentClient; | ||
use starknet_mempool_infra::component_definitions::{ | ||
ComponentRequestAndResponseSender, ComponentRequestHandler, | ||
}; | ||
use starknet_mempool_infra::component_runner::{ComponentStartError, ComponentStarter}; | ||
use starknet_mempool_infra::component_server::definitions::ComponentServerStarter; | ||
use starknet_mempool_infra::component_server::empty_component_server::EmptyServer; | ||
use starknet_mempool_infra::component_server::local_component_server::LocalActiveComponentServer; | ||
use tokio::sync::mpsc::{channel, Sender}; | ||
use tokio::sync::{Mutex, Barrier}; | ||
use tokio::task; | ||
|
||
#[derive(Debug, Clone)] | ||
struct ComponentC { | ||
counter: Arc<Mutex<usize>>, | ||
max_iterations: usize, | ||
barrier: Arc<Barrier>, | ||
} | ||
|
||
impl ComponentC { | ||
pub fn new(init_counter_value: usize, max_iterations: usize, barrier: Arc<Barrier>) -> Self { | ||
Self { | ||
counter: Arc::new(Mutex::new(init_counter_value)), | ||
max_iterations, | ||
barrier, | ||
} | ||
} | ||
|
||
pub async fn c_get_counter(&self) -> usize { | ||
*self.counter.lock().await | ||
} | ||
|
||
pub async fn c_increment_counter(&self) { | ||
*self.counter.lock().await += 1; | ||
} | ||
} | ||
|
||
#[async_trait] | ||
impl ComponentStarter for ComponentC { | ||
async fn start(&mut self) -> Result<(), ComponentStartError> { | ||
for _ in 0..self.max_iterations { | ||
self.c_increment_counter().await; | ||
} | ||
let val = self.c_get_counter().await; | ||
assert!(val >= self.max_iterations); | ||
self.barrier.wait().await; | ||
|
||
// Mimicking real start function that should not return. | ||
let () = pending().await; | ||
Ok(()) | ||
} | ||
} | ||
|
||
#[derive(Serialize, Deserialize, Debug)] | ||
pub enum ComponentCRequest { | ||
CIncCounter, | ||
CGetCounter, | ||
} | ||
|
||
#[derive(Serialize, Deserialize, Debug)] | ||
pub enum ComponentCResponse { | ||
CIncCounter, | ||
CGetCounter(usize), | ||
} | ||
|
||
#[async_trait] | ||
trait ComponentCClientTrait: Send + Sync { | ||
async fn c_inc_counter(&self) -> ClientResult<()>; | ||
async fn c_get_counter(&self) -> ClientResult<usize>; | ||
} | ||
|
||
struct ComponentD { | ||
c: Box<dyn ComponentCClientTrait>, | ||
max_iterations: usize, | ||
barrier: Arc<Barrier>, | ||
} | ||
|
||
impl ComponentD { | ||
pub fn new(c: Box<dyn ComponentCClientTrait>, max_iterations: usize, barrier: Arc<Barrier>) -> Self { | ||
Self { c, max_iterations, barrier } | ||
} | ||
|
||
pub async fn d_increment_counter(&self) { | ||
self.c.c_inc_counter().await.unwrap() | ||
} | ||
|
||
pub async fn d_get_counter(&self) -> usize { | ||
self.c.c_get_counter().await.unwrap() | ||
} | ||
} | ||
|
||
#[async_trait] | ||
impl ComponentStarter for ComponentD { | ||
async fn start(&mut self) -> Result<(), ComponentStartError> { | ||
for _ in 0..self.max_iterations { | ||
self.d_increment_counter().await; | ||
} | ||
let val = self.d_get_counter().await; | ||
assert!(val >= self.max_iterations); | ||
self.barrier.wait().await; | ||
|
||
// Mimicking real start function that should not return. | ||
let () = pending().await; | ||
Ok(()) | ||
} | ||
} | ||
|
||
#[async_trait] | ||
impl ComponentCClientTrait for LocalComponentClient<ComponentCRequest, ComponentCResponse> { | ||
async fn c_inc_counter(&self) -> ClientResult<()> { | ||
let res = self.send(ComponentCRequest::CIncCounter).await; | ||
match res { | ||
ComponentCResponse::CIncCounter => Ok(()), | ||
_ => Err(ClientError::UnexpectedResponse), | ||
} | ||
} | ||
|
||
async fn c_get_counter(&self) -> ClientResult<usize> { | ||
let res = self.send(ComponentCRequest::CGetCounter).await; | ||
match res { | ||
ComponentCResponse::CGetCounter(counter) => Ok(counter), | ||
_ => Err(ClientError::UnexpectedResponse), | ||
} | ||
} | ||
} | ||
|
||
#[async_trait] | ||
impl ComponentRequestHandler<ComponentCRequest, ComponentCResponse> for ComponentC { | ||
async fn handle_request(&mut self, request: ComponentCRequest) -> ComponentCResponse { | ||
match request { | ||
ComponentCRequest::CGetCounter => { | ||
ComponentCResponse::CGetCounter(self.c_get_counter().await) | ||
} | ||
ComponentCRequest::CIncCounter => { | ||
self.c_increment_counter().await; | ||
ComponentCResponse::CIncCounter | ||
} | ||
} | ||
} | ||
} | ||
|
||
async fn wait_and_verify_response( | ||
tx_c: Sender<ComponentRequestAndResponseSender<ComponentCRequest, ComponentCResponse>>, | ||
expected_counter_value: usize, | ||
barrier: Arc<Barrier>, | ||
) { | ||
let c_client = LocalComponentClient::new(tx_c); | ||
|
||
barrier.wait().await; | ||
assert_eq!(c_client.c_get_counter().await.unwrap(), expected_counter_value); | ||
} | ||
|
||
#[tokio::test] | ||
async fn test_setup_c_d() { | ||
let init_counter_value: usize = 0; | ||
let max_iterations: usize = 1024; | ||
let expected_counter_value = max_iterations * 2; | ||
|
||
let (tx_c, rx_c) = | ||
channel::<ComponentRequestAndResponseSender<ComponentCRequest, ComponentCResponse>>(32); | ||
|
||
let c_client = LocalComponentClient::new(tx_c.clone()); | ||
|
||
let barrier = Arc::new(Barrier::new(3)); | ||
let component_c = ComponentC::new(init_counter_value, max_iterations, barrier.clone()); | ||
let component_d = ComponentD::new(Box::new(c_client), max_iterations, barrier.clone()); | ||
|
||
let mut component_c_server = LocalActiveComponentServer::new(component_c, rx_c); | ||
let mut component_d_server = EmptyServer::new(component_d); | ||
|
||
task::spawn(async move { | ||
component_c_server.start().await; | ||
}); | ||
|
||
task::spawn(async move { | ||
component_d_server.start().await; | ||
}); | ||
|
||
// Wait for the components to finish incrementing of the ComponentC::counter and verify it. | ||
wait_and_verify_response(tx_c.clone(), expected_counter_value, barrier).await; | ||
} |