Skip to content

Commit

Permalink
feat: adding server for active components
Browse files Browse the repository at this point in the history
commit-id:fa8f56cc
  • Loading branch information
lev-starkware committed Jul 29, 2024
1 parent 93de0bd commit f943d3d
Showing 3 changed files with 257 additions and 7 deletions.
20 changes: 20 additions & 0 deletions crates/mempool_infra/src/component_server/definitions.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use async_trait::async_trait;
use tokio::sync::mpsc::Receiver;
use tracing::{error, info};

use crate::component_definitions::{ComponentRequestAndResponseSender, ComponentRequestHandler};
use crate::component_runner::ComponentStarter;

#[async_trait]
@@ -20,3 +22,21 @@ where
info!("ComponentServer::start() completed.");
true
}

pub async fn request_response_loop<Request, Response, Component>(
rx: &mut Receiver<ComponentRequestAndResponseSender<Request, Response>>,
component: &mut Component,
) where
Component: ComponentRequestHandler<Request, Response> + Send + Sync,
Request: Send + Sync,
Response: Send + Sync,
{
while let Some(request_and_res_tx) = rx.recv().await {
let request = request_and_res_tx.request;
let tx = request_and_res_tx.tx;

let res = component.handle_request(request).await;

tx.send(res).await.expect("Response connection should be open.");
}
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use async_trait::async_trait;
use tokio::sync::mpsc::Receiver;
use tracing::error;

use super::definitions::{start_component, ComponentServerStarter};
use super::definitions::{request_response_loop, start_component, ComponentServerStarter};
use crate::component_definitions::{ComponentRequestAndResponseSender, ComponentRequestHandler};
use crate::component_runner::ComponentStarter;

@@ -137,14 +138,56 @@ where
{
async fn start(&mut self) {
if start_component(&mut self.component).await {
while let Some(request_and_res_tx) = self.rx.recv().await {
let request = request_and_res_tx.request;
let tx = request_and_res_tx.tx;
request_response_loop(&mut self.rx, &mut self.component).await;
}
}
}

pub struct LocalActiveComponentServer<Component, Request, Response>
where
Component: ComponentRequestHandler<Request, Response> + ComponentStarter + Clone + Send + Sync,
Request: Send + Sync,
Response: Send + Sync,
{
component: Component,
rx: Receiver<ComponentRequestAndResponseSender<Request, Response>>,
}

impl<Component, Request, Response> LocalActiveComponentServer<Component, Request, Response>
where
Component: ComponentRequestHandler<Request, Response> + ComponentStarter + Clone + Send + Sync,
Request: Send + Sync,
Response: Send + Sync,
{
pub fn new(
component: Component,
rx: Receiver<ComponentRequestAndResponseSender<Request, Response>>,
) -> Self {
Self { component, rx }
}
}

let res = self.component.handle_request(request).await;
#[async_trait]
impl<Component, Request, Response> ComponentServerStarter
for LocalActiveComponentServer<Component, Request, Response>
where
Component: ComponentRequestHandler<Request, Response> + ComponentStarter + Clone + Send + Sync,
Request: Send + Sync,
Response: Send + Sync,
{
async fn start(&mut self) {
let mut component = self.component.clone();
let component_future = async move { component.start().await };
let request_response_future = request_response_loop(&mut self.rx, &mut self.component);

tx.send(res).await.expect("Response connection should be open.");
tokio::select! {
_res = component_future => {
error!("Component stopped.");
}
}
_res = request_response_future => {
error!("Server stopped.");
}
};
error!("Server ended with unexpected Ok.");
}
}
187 changes: 187 additions & 0 deletions crates/mempool_infra/tests/active_component_server_client_test.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
use std::future::pending;
use std::sync::Arc;

use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use starknet_mempool_infra::component_client::definitions::{ClientError, ClientResult};
use starknet_mempool_infra::component_client::local_component_client::LocalComponentClient;
use starknet_mempool_infra::component_definitions::{
ComponentRequestAndResponseSender, ComponentRequestHandler,
};
use starknet_mempool_infra::component_runner::{ComponentStartError, ComponentStarter};
use starknet_mempool_infra::component_server::definitions::ComponentServerStarter;
use starknet_mempool_infra::component_server::empty_component_server::EmptyServer;
use starknet_mempool_infra::component_server::local_component_server::LocalActiveComponentServer;
use tokio::sync::mpsc::{channel, Sender};
use tokio::sync::{Mutex, Barrier};
use tokio::task;

#[derive(Debug, Clone)]
struct ComponentC {
counter: Arc<Mutex<usize>>,
max_iterations: usize,
barrier: Arc<Barrier>,
}

impl ComponentC {
pub fn new(init_counter_value: usize, max_iterations: usize, barrier: Arc<Barrier>) -> Self {
Self {
counter: Arc::new(Mutex::new(init_counter_value)),
max_iterations,
barrier,
}
}

pub async fn c_get_counter(&self) -> usize {
*self.counter.lock().await
}

pub async fn c_increment_counter(&self) {
*self.counter.lock().await += 1;
}
}

#[async_trait]
impl ComponentStarter for ComponentC {
async fn start(&mut self) -> Result<(), ComponentStartError> {
for _ in 0..self.max_iterations {
self.c_increment_counter().await;
}
let val = self.c_get_counter().await;
assert!(val >= self.max_iterations);
self.barrier.wait().await;

// Mimicking real start function that should not return.
let () = pending().await;
Ok(())
}
}

#[derive(Serialize, Deserialize, Debug)]
pub enum ComponentCRequest {
CIncCounter,
CGetCounter,
}

#[derive(Serialize, Deserialize, Debug)]
pub enum ComponentCResponse {
CIncCounter,
CGetCounter(usize),
}

#[async_trait]
trait ComponentCClientTrait: Send + Sync {
async fn c_inc_counter(&self) -> ClientResult<()>;
async fn c_get_counter(&self) -> ClientResult<usize>;
}

struct ComponentD {
c: Box<dyn ComponentCClientTrait>,
max_iterations: usize,
barrier: Arc<Barrier>,
}

impl ComponentD {
pub fn new(c: Box<dyn ComponentCClientTrait>, max_iterations: usize, barrier: Arc<Barrier>) -> Self {
Self { c, max_iterations, barrier }
}

pub async fn d_increment_counter(&self) {
self.c.c_inc_counter().await.unwrap()
}

pub async fn d_get_counter(&self) -> usize {
self.c.c_get_counter().await.unwrap()
}
}

#[async_trait]
impl ComponentStarter for ComponentD {
async fn start(&mut self) -> Result<(), ComponentStartError> {
for _ in 0..self.max_iterations {
self.d_increment_counter().await;
}
let val = self.d_get_counter().await;
assert!(val >= self.max_iterations);
self.barrier.wait().await;

// Mimicking real start function that should not return.
let () = pending().await;
Ok(())
}
}

#[async_trait]
impl ComponentCClientTrait for LocalComponentClient<ComponentCRequest, ComponentCResponse> {
async fn c_inc_counter(&self) -> ClientResult<()> {
let res = self.send(ComponentCRequest::CIncCounter).await;
match res {
ComponentCResponse::CIncCounter => Ok(()),
_ => Err(ClientError::UnexpectedResponse),
}
}

async fn c_get_counter(&self) -> ClientResult<usize> {
let res = self.send(ComponentCRequest::CGetCounter).await;
match res {
ComponentCResponse::CGetCounter(counter) => Ok(counter),
_ => Err(ClientError::UnexpectedResponse),
}
}
}

#[async_trait]
impl ComponentRequestHandler<ComponentCRequest, ComponentCResponse> for ComponentC {
async fn handle_request(&mut self, request: ComponentCRequest) -> ComponentCResponse {
match request {
ComponentCRequest::CGetCounter => {
ComponentCResponse::CGetCounter(self.c_get_counter().await)
}
ComponentCRequest::CIncCounter => {
self.c_increment_counter().await;
ComponentCResponse::CIncCounter
}
}
}
}

async fn wait_and_verify_response(
tx_c: Sender<ComponentRequestAndResponseSender<ComponentCRequest, ComponentCResponse>>,
expected_counter_value: usize,
barrier: Arc<Barrier>,
) {
let c_client = LocalComponentClient::new(tx_c);

barrier.wait().await;
assert_eq!(c_client.c_get_counter().await.unwrap(), expected_counter_value);
}

#[tokio::test]
async fn test_setup_c_d() {
let init_counter_value: usize = 0;
let max_iterations: usize = 1024;
let expected_counter_value = max_iterations * 2;

let (tx_c, rx_c) =
channel::<ComponentRequestAndResponseSender<ComponentCRequest, ComponentCResponse>>(32);

let c_client = LocalComponentClient::new(tx_c.clone());

let barrier = Arc::new(Barrier::new(3));
let component_c = ComponentC::new(init_counter_value, max_iterations, barrier.clone());
let component_d = ComponentD::new(Box::new(c_client), max_iterations, barrier.clone());

let mut component_c_server = LocalActiveComponentServer::new(component_c, rx_c);
let mut component_d_server = EmptyServer::new(component_d);

task::spawn(async move {
component_c_server.start().await;
});

task::spawn(async move {
component_d_server.start().await;
});

// Wait for the components to finish incrementing of the ComponentC::counter and verify it.
wait_and_verify_response(tx_c.clone(), expected_counter_value, barrier).await;
}

0 comments on commit f943d3d

Please sign in to comment.