Skip to content

Commit

Permalink
make safer ser and deser functions for bytes and peer addresses
Browse files Browse the repository at this point in the history
Signed-off-by: onur-ozkan <[email protected]>
  • Loading branch information
onur-ozkan committed Sep 18, 2024
1 parent 7fccb15 commit 1777559
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 25 deletions.
117 changes: 101 additions & 16 deletions mm2src/mm2_main/src/lp_healthcheck.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ pub(crate) const PEER_HEALTHCHECK_PREFIX: TopicPrefix = "hcheck";
#[derive(Debug, Deserialize, Serialize)]
#[cfg_attr(any(test, target_arch = "wasm32"), derive(PartialEq))]
pub(crate) struct HealthcheckMessage {
#[serde(deserialize_with = "deserialize_bytes")]
signature: Vec<u8>,
data: HealthcheckData,
}
Expand All @@ -32,10 +33,9 @@ impl HealthcheckMessage {
expires_in_seconds: i64,
) -> Result<Self, String> {
let p2p_ctx = P2PContext::fetch_from_mm_arc(ctx);
let sender_peer = p2p_ctx.peer_id().to_string();
let sender_peer = p2p_ctx.peer_id();
let keypair = p2p_ctx.keypair();
let sender_public_key = keypair.public().encode_protobuf();
let target_peer = target_peer.to_string();

let data = HealthcheckData {
sender_peer,
Expand All @@ -60,7 +60,7 @@ impl HealthcheckMessage {
return false;
}

if self.data.target_peer != my_peer_id.to_string() {
if self.data.target_peer != my_peer_id {
log::debug!(
"`target_peer` doesn't match with our peer address. Our address: '{}', healthcheck `target_peer`: '{}'.",
my_peer_id,
Expand All @@ -75,7 +75,7 @@ impl HealthcheckMessage {
return false
};

if self.data.sender_peer != public_key.to_peer_id().to_string() {
if self.data.sender_peer != public_key.to_peer_id() {
log::debug!("`sender_peer` and `sender_public_key` doesn't belong each other.");

return false;
Expand Down Expand Up @@ -105,15 +105,18 @@ impl HealthcheckMessage {
pub(crate) fn should_reply(&self) -> bool { !self.data.is_a_reply }

#[inline]
pub(crate) fn sender_peer(&self) -> &str { &self.data.sender_peer }
pub(crate) fn sender_peer(&self) -> PeerId { self.data.sender_peer }
}

#[derive(Debug, Deserialize, Serialize)]
#[cfg_attr(any(test, target_arch = "wasm32"), derive(PartialEq))]
struct HealthcheckData {
sender_peer: String,
#[serde(deserialize_with = "deserialize_peer_id", serialize_with = "serialize_peer_id")]
sender_peer: PeerId,
#[serde(deserialize_with = "deserialize_bytes")]
sender_public_key: Vec<u8>,
target_peer: String,
#[serde(deserialize_with = "deserialize_peer_id", serialize_with = "serialize_peer_id")]
target_peer: PeerId,
expires_at: i64,
is_a_reply: bool,
}
Expand All @@ -130,21 +133,105 @@ pub fn peer_healthcheck_topic(peer_id: &PeerId) -> String {

#[derive(Deserialize)]
pub struct RequestPayload {
peer_id: String,
#[serde(deserialize_with = "deserialize_peer_id")]
peer_id: PeerId,
}

fn deserialize_peer_id<'de, D>(deserializer: D) -> Result<PeerId, D::Error>
where
D: serde::Deserializer<'de>,
{
struct PeerIdVisitor;

impl<'de> serde::de::Visitor<'de> for PeerIdVisitor {
type Value = PeerId;

fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("a string representation of peer id.")
}

fn visit_str<E>(self, value: &str) -> Result<PeerId, E>
where
E: serde::de::Error,
{
if value.len() > 100 {
return Err(serde::de::Error::invalid_length(
value.len(),
&"peer id cannot exceed 100 characters.",
));
}

PeerId::from_str(value).map_err(serde::de::Error::custom)
}

fn visit_string<E>(self, value: String) -> Result<PeerId, E>
where
E: serde::de::Error,
{
self.visit_str(&value)
}
}

deserializer.deserialize_str(PeerIdVisitor)
}

fn serialize_peer_id<S>(peer_id: &PeerId, s: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
s.serialize_str(&peer_id.to_string())
}

fn deserialize_bytes<'de, D>(deserializer: D) -> Result<Vec<u8>, D::Error>
where
D: serde::Deserializer<'de>,
{
struct ByteVisitor;

impl<'de> serde::de::Visitor<'de> for ByteVisitor {
type Value = Vec<u8>;

fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("a non-empty byte array up to 512 bytes")
}

fn visit_seq<A>(self, mut seq: A) -> Result<Vec<u8>, A::Error>
where
A: serde::de::SeqAccess<'de>,
{
let mut buffer = vec![];
while let Some(byte) = seq.next_element()? {
if buffer.len() >= 512 {
return Err(serde::de::Error::invalid_length(
buffer.len(),
&"longest possible length allowed for this field is 512 bytes (with RSA algorithm).",
));
}

buffer.push(byte);
}

if buffer.is_empty() {
return Err(serde::de::Error::custom("Can't be empty."));
}

Ok(buffer)
}
}

deserializer.deserialize_seq(ByteVisitor)
}

#[derive(Debug, Display, Serialize, SerializeErrorType)]
#[serde(tag = "error_type", content = "error_data")]
pub enum HealthcheckRpcError {
InvalidPeerAddress { reason: String },
MessageGenerationFailed { reason: String },
MessageEncodingFailed { reason: String },
}

impl HttpStatusCode for HealthcheckRpcError {
fn status_code(&self) -> common::StatusCode {
match self {
HealthcheckRpcError::InvalidPeerAddress { .. } => StatusCode::BAD_REQUEST,
HealthcheckRpcError::MessageGenerationFailed { .. } | HealthcheckRpcError::MessageEncodingFailed { .. } => {
StatusCode::INTERNAL_SERVER_ERROR
},
Expand All @@ -163,11 +250,9 @@ pub async fn peer_connection_healthcheck_rpc(
let address_record_exp =
ADDRESS_RECORD_EXPIRATION.get_or_init(|| Duration::from_secs(ctx.healthcheck.config.timeout_secs));

let target_peer_id = PeerId::from_str(&req.peer_id)
.map_err(|e| HealthcheckRpcError::InvalidPeerAddress { reason: e.to_string() })?;
let target_peer_id = req.peer_id;

let p2p_ctx = P2PContext::fetch_from_mm_arc(&ctx);

if target_peer_id == p2p_ctx.peer_id() {
// That's us, so return true.
return Ok(true);
Expand Down Expand Up @@ -248,15 +333,15 @@ mod tests {
assert!(!message.is_received_message_valid(target_peer));

let mut message = HealthcheckMessage::generate_message(&ctx, target_peer, false, 5).unwrap();
message.data.sender_peer += "0";
message.data.sender_peer = message.data.target_peer;
assert!(!message.is_received_message_valid(target_peer));

let mut message = HealthcheckMessage::generate_message(&ctx, target_peer, false, 5).unwrap();
message.data.target_peer += "0";
message.data.target_peer = message.data.sender_peer;
assert!(!message.is_received_message_valid(target_peer));

let message = HealthcheckMessage::generate_message(&ctx, target_peer, false, 5).unwrap();
assert!(!message.is_received_message_valid(PeerId::from_str(&message.data.sender_peer).unwrap()));
assert!(!message.is_received_message_valid(message.data.sender_peer));
});

cross_test!(test_expired_message, {
Expand Down
14 changes: 5 additions & 9 deletions mm2src/mm2_main/src/lp_network.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ use mm2_metrics::{mm_label, mm_timing};
use mm2_net::p2p::P2PContext;
use serde::de;
use std::net::ToSocketAddrs;
use std::str::FromStr;

use crate::lp_healthcheck::{peer_healthcheck_topic, HealthcheckMessage};
use crate::{lp_healthcheck, lp_ordermatch, lp_stats, lp_swap};
Expand Down Expand Up @@ -242,7 +241,7 @@ async fn process_p2p_message(
bruteforce_shield.clear_expired_entries();
if bruteforce_shield
.insert(
sender_peer.clone(),
sender_peer.to_string(),
(),
Duration::from_millis(ctx.healthcheck.config.blocking_ms_for_per_address),
)
Expand All @@ -261,14 +260,11 @@ async fn process_p2p_message(

if data.should_reply() {
// Reply the message so they know we are healthy.
let target_peer_id = try_or_return!(
PeerId::from_str(&sender_peer),
format!("'{sender_peer}' is not a valid address")
);
let topic = peer_healthcheck_topic(&target_peer_id);

let topic = peer_healthcheck_topic(&sender_peer);

let msg = try_or_return!(
HealthcheckMessage::generate_message(&ctx, target_peer_id, true, 10),
HealthcheckMessage::generate_message(&ctx, sender_peer, true, 10),
"Couldn't generate the healthcheck message, this is very unusual!"
);

Expand All @@ -281,7 +277,7 @@ async fn process_p2p_message(
} else {
// The requested peer is healthy; signal the response channel.
let mut response_handler = ctx.healthcheck.response_handler.lock().await;
if let Some(tx) = response_handler.remove(&sender_peer) {
if let Some(tx) = response_handler.remove(&sender_peer.to_string()) {
if tx.send(()).is_err() {
log::error!("Result channel isn't present for peer '{sender_peer}'.");
};
Expand Down

0 comments on commit 1777559

Please sign in to comment.