Skip to content

Commit

Permalink
Migrate MessageHandler callback to SwarmCallback
Browse files Browse the repository at this point in the history
  • Loading branch information
Ma233 committed Nov 1, 2023
1 parent b475b55 commit f7f2141
Show file tree
Hide file tree
Showing 9 changed files with 76 additions and 229 deletions.
169 changes: 39 additions & 130 deletions core/src/message/handlers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,11 @@ use std::sync::Arc;
use async_recursion::async_recursion;
use async_trait::async_trait;

use super::CustomMessage;
use super::Message;
use super::MessagePayload;
use crate::dht::vnode::VirtualNode;
use crate::dht::Did;
use crate::dht::PeerRing;
use crate::error::Error;
use crate::error::Result;
use crate::message::ConnectNodeReport;
use crate::message::ConnectNodeSend;
Expand All @@ -37,44 +35,6 @@ pub mod storage;
/// Operator and Handler for Subring
pub mod subring;

/// Trait of message callback.
#[cfg_attr(feature = "wasm", async_trait(?Send))]
#[cfg_attr(not(feature = "wasm"), async_trait)]
pub trait MessageCallback {
/// Message handler for custom message
async fn custom_message(
&self,
ctx: &MessagePayload,
msg: &CustomMessage,
) -> Vec<MessageHandlerEvent>;
/// Message handler for builtin message
async fn builtin_message(&self, ctx: &MessagePayload) -> Vec<MessageHandlerEvent>;
}

/// Trait of message validator.
#[cfg_attr(feature = "wasm", async_trait(?Send))]
#[cfg_attr(not(feature = "wasm"), async_trait)]
pub trait MessageValidator {
/// Externality validator
async fn validate(&self, ctx: &MessagePayload) -> Option<String>;
}

/// Boxed Callback, for non-wasm, it should be Sized, Send and Sync.
#[cfg(not(feature = "wasm"))]
pub type CallbackFn = Box<dyn MessageCallback + Send + Sync>;

/// Boxed Callback
#[cfg(feature = "wasm")]
pub type CallbackFn = Box<dyn MessageCallback>;

/// Boxed Validator
#[cfg(not(feature = "wasm"))]
pub type ValidatorFn = Box<dyn MessageValidator + Send + Sync>;

/// Boxed Validator, for non-wasm, it should be Sized, Send and Sync.
#[cfg(feature = "wasm")]
pub type ValidatorFn = Box<dyn MessageValidator>;

type NextHop = Did;

/// MessageHandlerEvent that will be handled by Swarm.
Expand Down Expand Up @@ -124,10 +84,6 @@ pub enum MessageHandlerEvent {
#[derive(Clone)]
pub struct MessageHandler {
dht: Arc<PeerRing>,
/// CallbackFn implement `customMessage` and `builtin_message`.
callback: Arc<Option<CallbackFn>>,
/// A specific validator implement ValidatorFn.
validator: Arc<Option<ValidatorFn>>,
}

/// Generic trait for handle message ,inspired by Actor-Model.
Expand All @@ -140,54 +96,8 @@ pub trait HandleMsg<T> {

impl MessageHandler {
/// Create a new MessageHandler Instance.
pub fn new(
dht: Arc<PeerRing>,
callback: Option<CallbackFn>,
validator: Option<ValidatorFn>,
) -> Self {
Self {
dht,
callback: Arc::new(callback),
validator: Arc::new(validator),
}
}

/// Invoke callback, which will be call after builtin handler.
async fn invoke_callback(
&self,
payload: &MessagePayload,
message: &Message,
) -> Vec<MessageHandlerEvent> {
if let Some(ref cb) = *self.callback {
match message {
Message::CustomMessage(ref msg) => {
if self.dht.did == payload.transaction.destination {
tracing::debug!(
"INVOKE CUSTOM MESSAGE CALLBACK {}",
&payload.transaction.tx_id
);
return cb.custom_message(payload, msg).await;
}
}
_ => return cb.builtin_message(payload).await,
};
} else if let Message::CustomMessage(ref msg) = message {
if self.dht.did == payload.transaction.destination {
tracing::warn!("No callback registered, skip invoke_callback of {:?}", msg);
}
}
vec![]
}

/// Validate message.
async fn validate(&self, payload: &MessagePayload) -> Result<()> {
if let Some(ref v) = *self.validator {
v.validate(payload)
.await
.map(|info| Err(Error::InvalidMessage(info)))
.unwrap_or(Ok(()))?;
};
Ok(())
pub fn new(dht: Arc<PeerRing>) -> Self {
Self { dht }
}

/// Handle builtin message.
Expand All @@ -197,7 +107,6 @@ impl MessageHandler {
&self,
payload: &MessagePayload,
) -> Result<Vec<MessageHandlerEvent>> {
self.validate(payload).await?;
let message: Message = payload.transaction.data()?;

#[cfg(test)]
Expand All @@ -210,7 +119,7 @@ impl MessageHandler {
&message
);

let mut events = match &message {
let events = match &message {
Message::JoinDHT(ref msg) => self.handle(payload, msg).await,
Message::LeaveDHT(ref msg) => self.handle(payload, msg).await,
Message::ConnectNodeSend(ref msg) => self.handle(payload, msg).await,
Expand All @@ -228,10 +137,6 @@ impl MessageHandler {
Message::QueryForTopoInfoReport(ref msg) => self.handle(payload, msg).await,
}?;

tracing::debug!("INVOKE CALLBACK {}", &payload.transaction.tx_id);

events.extend(self.invoke_callback(payload, &message).await);

tracing::debug!("FINISH HANDLE MESSAGE {}", &payload.transaction.tx_id);
Ok(events)
}
Expand All @@ -249,14 +154,13 @@ pub mod tests {
use crate::ecc::SecretKey;
use crate::message::MessageVerificationExt;
use crate::message::PayloadSender;
use crate::swarm::callback::SwarmCallback;
use crate::swarm::Swarm;
use crate::tests::default::prepare_node_with_callback;
use crate::tests::default::prepare_node;
use crate::tests::manually_establish_connection;

#[derive(Clone)]
struct MessageCallbackInstance {
#[allow(clippy::type_complexity)]
handler_messages: Arc<Mutex<Vec<(Did, Vec<u8>)>>>,
struct SwarmCallbackInstance {
handler_messages: Mutex<Vec<(Did, Vec<u8>)>>,
}

#[tokio::test]
Expand All @@ -265,37 +169,42 @@ pub mod tests {
let key2 = SecretKey::random();

#[async_trait]
impl MessageCallback for MessageCallbackInstance {
async fn custom_message(
impl SwarmCallback for SwarmCallbackInstance {
async fn on_payload(
&self,
ctx: &MessagePayload,
msg: &CustomMessage,
) -> Vec<MessageHandlerEvent> {
self.handler_messages
.lock()
.await
.push((ctx.signer(), msg.0.clone()));
println!("{:?}, {:?}, {:?}", ctx, ctx.signer(), msg);
vec![]
}
payload: &MessagePayload,
) -> std::result::Result<(), Box<dyn std::error::Error>> {
let msg: Message = payload.transaction.data().map_err(Box::new)?;

match msg {
Message::CustomMessage(ref msg) => {
self.handler_messages
.lock()
.await
.push((payload.transaction.signer(), msg.0.clone()));
println!("{:?}, {:?}, {:?}", payload, payload.signer(), msg);
}
_ => {
println!("{:?}, {:?}", payload, payload.signer());
}
}

async fn builtin_message(&self, ctx: &MessagePayload) -> Vec<MessageHandlerEvent> {
println!("{:?}, {:?}", ctx, ctx.signer());
vec![]
Ok(())
}
}

let msg_callback1 = MessageCallbackInstance {
handler_messages: Arc::new(Mutex::new(vec![])),
};
let msg_callback2 = MessageCallbackInstance {
handler_messages: Arc::new(Mutex::new(vec![])),
};
let cb1: CallbackFn = Box::new(msg_callback1.clone());
let cb2: CallbackFn = Box::new(msg_callback2.clone());
let cb1 = Arc::new(SwarmCallbackInstance {
handler_messages: Mutex::new(vec![]),
});
let cb2 = Arc::new(SwarmCallbackInstance {
handler_messages: Mutex::new(vec![]),
});

let (node1, _path1) = prepare_node(key1).await;
let (node2, _path2) = prepare_node(key2).await;

let (node1, _path1) = prepare_node_with_callback(key1, Some(cb1)).await;
let (node2, _path2) = prepare_node_with_callback(key2, Some(cb2)).await;
node1.set_callback(cb1.clone()).unwrap();
node2.set_callback(cb2.clone()).unwrap();

manually_establish_connection(&node1, &node2).await;

Expand Down Expand Up @@ -346,12 +255,12 @@ pub mod tests {

sleep(Duration::from_secs(5)).await;

assert_eq!(msg_callback1.handler_messages.lock().await.as_slice(), &[
assert_eq!(cb1.handler_messages.lock().await.as_slice(), &[
(node2.did(), "Hello world 2 to 1 - 1".as_bytes().to_vec()),
(node2.did(), "Hello world 2 to 1 - 2".as_bytes().to_vec())
]);

assert_eq!(msg_callback2.handler_messages.lock().await.as_slice(), &[
assert_eq!(cb2.handler_messages.lock().await.as_slice(), &[
(node1.did(), "Hello world 1 to 2 - 1".as_bytes().to_vec()),
(node1.did(), "Hello world 1 to 2 - 2".as_bytes().to_vec()),
(node1.did(), "Hello world 1 to 2 - 3".as_bytes().to_vec())
Expand Down
3 changes: 0 additions & 3 deletions core/src/message/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,9 @@ pub mod handlers;
pub use handlers::storage::ChordStorageInterface;
pub use handlers::storage::ChordStorageInterfaceCacheChecker;
pub use handlers::subring::SubringInterface;
pub use handlers::CallbackFn;
pub use handlers::HandleMsg;
pub use handlers::MessageCallback;
pub use handlers::MessageHandler;
pub use handlers::MessageHandlerEvent;
pub use handlers::ValidatorFn;

mod protocols;
pub use protocols::MessageRelay;
Expand Down
21 changes: 1 addition & 20 deletions core/src/swarm/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@ use std::sync::RwLock;

use crate::channels::Channel;
use crate::dht::PeerRing;
use crate::message::CallbackFn;
use crate::message::MessageHandler;
use crate::message::ValidatorFn;
use crate::session::SessionSk;
use crate::storage::PersistenceStorage;
use crate::swarm::callback::SharedSwarmCallback;
Expand All @@ -31,8 +29,6 @@ pub struct SwarmBuilder {
session_sk: SessionSk,
session_ttl: Option<usize>,
measure: Option<MeasureImpl>,
message_callback: Option<CallbackFn>,
message_validator: Option<ValidatorFn>,
callback: Option<SharedSwarmCallback>,
}

Expand All @@ -47,8 +43,6 @@ impl SwarmBuilder {
session_sk,
session_ttl: None,
measure: None,
message_callback: None,
message_validator: None,
callback: None,
}
}
Expand Down Expand Up @@ -78,18 +72,6 @@ impl SwarmBuilder {
self
}

/// Bind message callback function for Swarm.
pub fn message_callback(mut self, callback: CallbackFn) -> Self {
self.message_callback = Some(callback);
self
}

/// Bind message vilidator function implementation for Swarm.
pub fn message_validator(mut self, validator: ValidatorFn) -> Self {
self.message_validator = Some(validator);
self
}

/// Bind callback for Swarm.
pub fn callback(mut self, callback: SharedSwarmCallback) -> Self {
self.callback = Some(callback);
Expand All @@ -106,8 +88,7 @@ impl SwarmBuilder {
self.dht_storage,
));

let message_handler =
MessageHandler::new(dht.clone(), self.message_callback, self.message_validator);
let message_handler = MessageHandler::new(dht.clone());

let transport_event_channel = Channel::new();
let transport = Box::new(Transport::new(&self.ice_servers, self.external_address));
Expand Down
19 changes: 2 additions & 17 deletions core/src/tests/default/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ use crate::dht::Did;
use crate::dht::PeerRing;
use crate::ecc::SecretKey;
use crate::error::Result;
use crate::message::CallbackFn;
use crate::session::SessionSk;
use crate::storage::PersistenceStorage;
use crate::swarm::Swarm;
Expand All @@ -13,36 +12,22 @@ use crate::swarm::SwarmBuilder;
mod test_message_handler;
mod test_stabilization;

pub async fn prepare_node_with_callback(
key: SecretKey,
message_callback: Option<CallbackFn>,
) -> (Arc<Swarm>, String) {
pub async fn prepare_node(key: SecretKey) -> (Arc<Swarm>, String) {
let stun = "stun://stun.l.google.com:19302";
let path = PersistenceStorage::random_path("./tmp");
let storage = PersistenceStorage::new_with_path(path.as_str())
.await
.unwrap();

let session_sk = SessionSk::new_with_seckey(&key).unwrap();

let mut swarm_builder = SwarmBuilder::new(stun, storage, session_sk);

if let Some(callback) = message_callback {
swarm_builder = swarm_builder.message_callback(callback);
}

let swarm = Arc::new(swarm_builder.build());
let swarm = Arc::new(SwarmBuilder::new(stun, storage, session_sk).build());

println!("key: {:?}", key.to_string());
println!("did: {:?}", swarm.did());

(swarm, path)
}

pub async fn prepare_node(key: SecretKey) -> (Arc<Swarm>, String) {
prepare_node_with_callback(key, None).await
}

pub async fn gen_pure_dht(did: Did) -> Result<PeerRing> {
let db_path = PersistenceStorage::random_path("./tmp");
let db = PersistenceStorage::new_with_path(db_path.as_str()).await?;
Expand Down
Loading

0 comments on commit f7f2141

Please sign in to comment.