diff --git a/mm2src/mm2_main/src/lp_healthcheck.rs b/mm2src/mm2_main/src/lp_healthcheck.rs index 6a8cf2824a..437a438389 100644 --- a/mm2src/mm2_main/src/lp_healthcheck.rs +++ b/mm2src/mm2_main/src/lp_healthcheck.rs @@ -20,6 +20,7 @@ pub(crate) const PEER_HEALTHCHECK_PREFIX: TopicPrefix = "hcheck"; #[derive(Debug, Deserialize, Serialize)] #[cfg_attr(any(test, target_arch = "wasm32"), derive(PartialEq))] pub(crate) struct HealthcheckMessage { + #[serde(deserialize_with = "deserialize_bytes")] signature: Vec, data: HealthcheckData, } @@ -32,10 +33,9 @@ impl HealthcheckMessage { expires_in_seconds: i64, ) -> Result { let p2p_ctx = P2PContext::fetch_from_mm_arc(ctx); - let sender_peer = p2p_ctx.peer_id().to_string(); + let sender_peer = p2p_ctx.peer_id(); let keypair = p2p_ctx.keypair(); let sender_public_key = keypair.public().encode_protobuf(); - let target_peer = target_peer.to_string(); let data = HealthcheckData { sender_peer, @@ -60,7 +60,7 @@ impl HealthcheckMessage { return false; } - if self.data.target_peer != my_peer_id.to_string() { + if self.data.target_peer != my_peer_id { log::debug!( "`target_peer` doesn't match with our peer address. Our address: '{}', healthcheck `target_peer`: '{}'.", my_peer_id, @@ -75,7 +75,7 @@ impl HealthcheckMessage { return false }; - if self.data.sender_peer != public_key.to_peer_id().to_string() { + if self.data.sender_peer != public_key.to_peer_id() { log::debug!("`sender_peer` and `sender_public_key` doesn't belong each other."); return false; @@ -105,15 +105,18 @@ impl HealthcheckMessage { pub(crate) fn should_reply(&self) -> bool { !self.data.is_a_reply } #[inline] - pub(crate) fn sender_peer(&self) -> &str { &self.data.sender_peer } + pub(crate) fn sender_peer(&self) -> PeerId { self.data.sender_peer } } #[derive(Debug, Deserialize, Serialize)] #[cfg_attr(any(test, target_arch = "wasm32"), derive(PartialEq))] struct HealthcheckData { - sender_peer: String, + #[serde(deserialize_with = "deserialize_peer_id", serialize_with = "serialize_peer_id")] + sender_peer: PeerId, + #[serde(deserialize_with = "deserialize_bytes")] sender_public_key: Vec, - target_peer: String, + #[serde(deserialize_with = "deserialize_peer_id", serialize_with = "serialize_peer_id")] + target_peer: PeerId, expires_at: i64, is_a_reply: bool, } @@ -130,13 +133,98 @@ pub fn peer_healthcheck_topic(peer_id: &PeerId) -> String { #[derive(Deserialize)] pub struct RequestPayload { - peer_id: String, + #[serde(deserialize_with = "deserialize_peer_id")] + peer_id: PeerId, +} + +fn deserialize_peer_id<'de, D>(deserializer: D) -> Result +where + D: serde::Deserializer<'de>, +{ + struct PeerIdVisitor; + + impl<'de> serde::de::Visitor<'de> for PeerIdVisitor { + type Value = PeerId; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a string representation of PeerId") + } + + fn visit_str(self, value: &str) -> Result + where + E: serde::de::Error, + { + if value.len() > 100 { + return Err(serde::de::Error::invalid_length( + value.len(), + &"peer id cannot exceed 100 characters.", + )); + } + + PeerId::from_str(value).map_err(serde::de::Error::custom) + } + + fn visit_string(self, value: String) -> Result + where + E: serde::de::Error, + { + self.visit_str(&value) + } + } + + deserializer.deserialize_str(PeerIdVisitor) +} + +fn serialize_peer_id(peer_id: &PeerId, s: S) -> Result +where + S: serde::Serializer, +{ + s.serialize_str(&peer_id.to_string()) +} + +fn deserialize_bytes<'de, D>(deserializer: D) -> Result, D::Error> +where + D: serde::Deserializer<'de>, +{ + struct ByteVisitor; + + impl<'de> serde::de::Visitor<'de> for ByteVisitor { + type Value = Vec; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a non-empty byte array up to 512 bytes") + } + + fn visit_seq(self, mut seq: A) -> Result, A::Error> + where + A: serde::de::SeqAccess<'de>, + { + let mut buffer = vec![]; + while let Some(byte) = seq.next_element()? { + if buffer.len() >= 512 { + return Err(serde::de::Error::invalid_length( + buffer.len(), + &"longest possible length allowed for this field is 512 bytes (with RSA algorithm).", + )); + } + + buffer.push(byte); + } + + if buffer.is_empty() { + return Err(serde::de::Error::custom("Can't be empty.")); + } + + Ok(buffer) + } + } + + deserializer.deserialize_seq(ByteVisitor) } #[derive(Debug, Display, Serialize, SerializeErrorType)] #[serde(tag = "error_type", content = "error_data")] pub enum HealthcheckRpcError { - InvalidPeerAddress { reason: String }, MessageGenerationFailed { reason: String }, MessageEncodingFailed { reason: String }, } @@ -144,7 +232,6 @@ pub enum HealthcheckRpcError { impl HttpStatusCode for HealthcheckRpcError { fn status_code(&self) -> common::StatusCode { match self { - HealthcheckRpcError::InvalidPeerAddress { .. } => StatusCode::BAD_REQUEST, HealthcheckRpcError::MessageGenerationFailed { .. } | HealthcheckRpcError::MessageEncodingFailed { .. } => { StatusCode::INTERNAL_SERVER_ERROR }, @@ -163,11 +250,9 @@ pub async fn peer_connection_healthcheck_rpc( let address_record_exp = ADDRESS_RECORD_EXPIRATION.get_or_init(|| Duration::from_secs(ctx.healthcheck.config.timeout_secs)); - let target_peer_id = PeerId::from_str(&req.peer_id) - .map_err(|e| HealthcheckRpcError::InvalidPeerAddress { reason: e.to_string() })?; + let target_peer_id = req.peer_id; let p2p_ctx = P2PContext::fetch_from_mm_arc(&ctx); - if target_peer_id == p2p_ctx.peer_id() { // That's us, so return true. return Ok(true); @@ -248,15 +333,15 @@ mod tests { assert!(!message.is_received_message_valid(target_peer)); let mut message = HealthcheckMessage::generate_message(&ctx, target_peer, false, 5).unwrap(); - message.data.sender_peer += "0"; + message.data.sender_peer = message.data.target_peer; assert!(!message.is_received_message_valid(target_peer)); let mut message = HealthcheckMessage::generate_message(&ctx, target_peer, false, 5).unwrap(); - message.data.target_peer += "0"; + message.data.target_peer = message.data.sender_peer; assert!(!message.is_received_message_valid(target_peer)); let message = HealthcheckMessage::generate_message(&ctx, target_peer, false, 5).unwrap(); - assert!(!message.is_received_message_valid(PeerId::from_str(&message.data.sender_peer).unwrap())); + assert!(!message.is_received_message_valid(message.data.sender_peer)); }); cross_test!(test_expired_message, { diff --git a/mm2src/mm2_main/src/lp_network.rs b/mm2src/mm2_main/src/lp_network.rs index d7e53f8c74..1e73bc6b6e 100644 --- a/mm2src/mm2_main/src/lp_network.rs +++ b/mm2src/mm2_main/src/lp_network.rs @@ -37,7 +37,6 @@ use mm2_metrics::{mm_label, mm_timing}; use mm2_net::p2p::P2PContext; use serde::de; use std::net::ToSocketAddrs; -use std::str::FromStr; use crate::lp_healthcheck::{peer_healthcheck_topic, HealthcheckMessage}; use crate::{lp_healthcheck, lp_ordermatch, lp_stats, lp_swap}; @@ -242,7 +241,7 @@ async fn process_p2p_message( bruteforce_shield.clear_expired_entries(); if bruteforce_shield .insert( - sender_peer.clone(), + sender_peer.to_string(), (), Duration::from_millis(ctx.healthcheck.config.blocking_ms_for_per_address), ) @@ -261,14 +260,11 @@ async fn process_p2p_message( if data.should_reply() { // Reply the message so they know we are healthy. - let target_peer_id = try_or_return!( - PeerId::from_str(&sender_peer), - format!("'{sender_peer}' is not a valid address") - ); - let topic = peer_healthcheck_topic(&target_peer_id); + + let topic = peer_healthcheck_topic(&sender_peer); let msg = try_or_return!( - HealthcheckMessage::generate_message(&ctx, target_peer_id, true, 10), + HealthcheckMessage::generate_message(&ctx, sender_peer, true, 10), "Couldn't generate the healthcheck message, this is very unusual!" ); @@ -281,7 +277,7 @@ async fn process_p2p_message( } else { // The requested peer is healthy; signal the response channel. let mut response_handler = ctx.healthcheck.response_handler.lock().await; - if let Some(tx) = response_handler.remove(&sender_peer) { + if let Some(tx) = response_handler.remove(&sender_peer.to_string()) { if tx.send(()).is_err() { log::error!("Result channel isn't present for peer '{sender_peer}'."); };