From 24e0581ab6860347fbf78d39f1ee37e3c008fc2f Mon Sep 17 00:00:00 2001 From: Guy Nir Date: Sun, 27 Oct 2024 11:35:14 +0200 Subject: [PATCH] fix: update StreamHandler to use new StreamHashMap interface --- crates/papyrus_network/src/lib.rs | 2 +- crates/papyrus_protobuf/src/consensus.rs | 11 +- .../src/converters/consensus.rs | 33 +- .../src/converters/consensus_test.rs | 5 +- .../src/converters/test_instances.rs | 29 +- .../src/proto/p2p/proto/consensus.proto | 2 +- .../papyrus_consensus/src/stream_handler.rs | 155 ++++++-- .../src/stream_handler_test.rs | 368 ++++++++++++++---- 8 files changed, 484 insertions(+), 121 deletions(-) diff --git a/crates/papyrus_network/src/lib.rs b/crates/papyrus_network/src/lib.rs index 5c0b33a7899..ae72a9612ae 100644 --- a/crates/papyrus_network/src/lib.rs +++ b/crates/papyrus_network/src/lib.rs @@ -13,7 +13,7 @@ mod peer_manager; mod sqmr; #[cfg(test)] mod test_utils; -mod utils; +pub mod utils; use std::collections::BTreeMap; use std::time::Duration; diff --git a/crates/papyrus_protobuf/src/consensus.rs b/crates/papyrus_protobuf/src/consensus.rs index 7dbd661ab9d..bce2a6e8755 100644 --- a/crates/papyrus_protobuf/src/consensus.rs +++ b/crates/papyrus_protobuf/src/consensus.rs @@ -1,3 +1,5 @@ +use std::fmt::Display; + use futures::channel::{mpsc, oneshot}; use starknet_api::block::{BlockHash, BlockNumber}; use starknet_api::core::ContractAddress; @@ -54,9 +56,12 @@ pub enum StreamMessageBody { } #[derive(Debug, Clone, Hash, Eq, PartialEq)] -pub struct StreamMessage> + TryFrom, Error = ProtobufConversionError>> { +pub struct StreamMessage< + T: Into> + TryFrom, Error = ProtobufConversionError>, + StreamId: Into> + Clone, +> { pub message: StreamMessageBody, - pub stream_id: u64, + pub stream_id: StreamId, pub message_id: u64, } @@ -99,7 +104,7 @@ pub enum ProposalPart { Fin(ProposalFin), } -impl std::fmt::Display for StreamMessage +impl> + Clone + Display> std::fmt::Display for StreamMessage where T: Clone + Into> + TryFrom, Error = ProtobufConversionError>, { diff --git a/crates/papyrus_protobuf/src/converters/consensus.rs b/crates/papyrus_protobuf/src/converters/consensus.rs index 32463367c5f..5eb37f120e6 100644 --- a/crates/papyrus_protobuf/src/converters/consensus.rs +++ b/crates/papyrus_protobuf/src/converters/consensus.rs @@ -1,6 +1,7 @@ #[cfg(test)] #[path = "consensus_test.rs"] mod consensus_test; + use std::convert::{TryFrom, TryInto}; use prost::Message; @@ -125,8 +126,10 @@ impl From for protobuf::Vote { auto_impl_into_and_try_from_vec_u8!(Vote, protobuf::Vote); -impl> + TryFrom, Error = ProtobufConversionError>> - TryFrom for StreamMessage +impl< + T: Into> + TryFrom, Error = ProtobufConversionError>, + StreamId: Into> + From> + Clone, +> TryFrom for StreamMessage { type Error = ProtobufConversionError; @@ -147,16 +150,18 @@ impl> + TryFrom, Error = ProtobufConversionError>> StreamMessageBody::Fin } }, - stream_id: value.stream_id, + stream_id: value.stream_id.into(), message_id: value.message_id, }) } } -impl> + TryFrom, Error = ProtobufConversionError>> From> - for protobuf::StreamMessage +impl< + T: Into> + TryFrom, Error = ProtobufConversionError>, + StreamId: Into> + From> + Clone, +> From> for protobuf::StreamMessage { - fn from(value: StreamMessage) -> Self { + fn from(value: StreamMessage) -> Self { Self { message: match value { StreamMessage { @@ -168,7 +173,7 @@ impl> + TryFrom, Error = ProtobufConversionError>> From< Some(protobuf::stream_message::Message::Fin(protobuf::Fin {})) } }, - stream_id: value.stream_id, + stream_id: value.stream_id.into(), message_id: value.message_id, } } @@ -177,17 +182,21 @@ impl> + TryFrom, Error = ProtobufConversionError>> From< // Can't use auto_impl_into_and_try_from_vec_u8!(StreamMessage, protobuf::StreamMessage); // because it doesn't seem to work with generics. // TODO(guyn): consider expanding the macro to support generics -impl> + TryFrom, Error = ProtobufConversionError>> From> - for Vec +impl< + T: Into> + TryFrom, Error = ProtobufConversionError>, + StreamId: Into> + From> + Clone, +> From> for Vec { - fn from(value: StreamMessage) -> Self { + fn from(value: StreamMessage) -> Self { let protobuf_value = ::from(value); protobuf_value.encode_to_vec() } } -impl> + TryFrom, Error = ProtobufConversionError>> TryFrom> - for StreamMessage +impl< + T: Into> + TryFrom, Error = ProtobufConversionError>, + StreamId: Into> + From> + Clone, +> TryFrom> for StreamMessage { type Error = ProtobufConversionError; fn try_from(value: Vec) -> Result { diff --git a/crates/papyrus_protobuf/src/converters/consensus_test.rs b/crates/papyrus_protobuf/src/converters/consensus_test.rs index af6c070c2f6..d018a7df6b0 100644 --- a/crates/papyrus_protobuf/src/converters/consensus_test.rs +++ b/crates/papyrus_protobuf/src/converters/consensus_test.rs @@ -22,6 +22,7 @@ use crate::consensus::{ TransactionBatch, Vote, }; +use crate::converters::test_instances::StreamId; // If all the fields of `AllResources` are 0 upon serialization, // then the deserialized value will be interpreted as the `L1Gas` variant. @@ -52,7 +53,7 @@ fn convert_stream_message_to_vec_u8_and_back() { let mut rng = get_rng(); // Test that we can convert a StreamMessage with a ConsensusMessage message to bytes and back. - let mut stream_message: StreamMessage = + let mut stream_message: StreamMessage = StreamMessage::get_test_instance(&mut rng); if let StreamMessageBody::Content(ConsensusMessage::Proposal(proposal)) = @@ -159,7 +160,7 @@ fn convert_proposal_part_to_vec_u8_and_back() { #[test] fn stream_message_display() { let mut rng = get_rng(); - let stream_id = 42; + let stream_id = StreamId(42); let message_id = 127; let proposal = Proposal::get_test_instance(&mut rng); let proposal_bytes: Vec = proposal.clone().into(); diff --git a/crates/papyrus_protobuf/src/converters/test_instances.rs b/crates/papyrus_protobuf/src/converters/test_instances.rs index 57a787320d0..81821c6be2f 100644 --- a/crates/papyrus_protobuf/src/converters/test_instances.rs +++ b/crates/papyrus_protobuf/src/converters/test_instances.rs @@ -1,3 +1,5 @@ +use std::fmt::Display; + use papyrus_test_utils::{auto_impl_get_test_instance, get_number_of_variants, GetTestInstance}; use rand::Rng; use starknet_api::block::BlockHash; @@ -61,15 +63,38 @@ auto_impl_get_test_instance! { } +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct StreamId(pub u64); + +impl Into> for StreamId { + fn into(self) -> Vec { + self.0.to_be_bytes().to_vec() + } +} + +impl From> for StreamId { + fn from(bytes: Vec) -> Self { + let mut array = [0; 8]; + array.copy_from_slice(&bytes); + StreamId(u64::from_be_bytes(array)) + } +} + +impl Display for StreamId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "StreamId({})", self.0) + } +} + // The auto_impl_get_test_instance macro does not work for StreamMessage because it has // a generic type. TODO(guyn): try to make the macro work with generic types. -impl GetTestInstance for StreamMessage { +impl GetTestInstance for StreamMessage { fn get_test_instance(rng: &mut rand_chacha::ChaCha8Rng) -> Self { let message = if rng.gen_bool(0.5) { StreamMessageBody::Content(ConsensusMessage::Proposal(Proposal::get_test_instance(rng))) } else { StreamMessageBody::Fin }; - Self { message, stream_id: 12, message_id: 47 } + Self { message, stream_id: StreamId(12), message_id: 47 } } } diff --git a/crates/papyrus_protobuf/src/proto/p2p/proto/consensus.proto b/crates/papyrus_protobuf/src/proto/p2p/proto/consensus.proto index 8c31067f97e..e198fc8094d 100644 --- a/crates/papyrus_protobuf/src/proto/p2p/proto/consensus.proto +++ b/crates/papyrus_protobuf/src/proto/p2p/proto/consensus.proto @@ -41,7 +41,7 @@ message StreamMessage { bytes content = 1; Fin fin = 2; } - uint64 stream_id = 3; + bytes stream_id = 3; uint64 message_id = 4; } diff --git a/crates/sequencing/papyrus_consensus/src/stream_handler.rs b/crates/sequencing/papyrus_consensus/src/stream_handler.rs index f2cee108dc8..776c68c51d8 100644 --- a/crates/sequencing/papyrus_consensus/src/stream_handler.rs +++ b/crates/sequencing/papyrus_consensus/src/stream_handler.rs @@ -1,12 +1,20 @@ //! Stream handler, see StreamManager struct. + use std::cmp::Ordering; use std::collections::btree_map::Entry as BTreeEntry; use std::collections::hash_map::Entry as HashMapEntry; use std::collections::{BTreeMap, HashMap}; +use std::fmt::{Debug, Display}; +use std::hash::Hash; use futures::channel::mpsc; use futures::StreamExt; -use papyrus_network::network_manager::BroadcastTopicServer; +use papyrus_network::network_manager::{ + BroadcastTopicClient, + BroadcastTopicClientTrait, + BroadcastTopicServer, +}; +use papyrus_network::utils::StreamHashMap; use papyrus_network_types::network_types::{BroadcastedMessageMetadata, OpaquePeerId}; use papyrus_protobuf::consensus::{StreamMessage, StreamMessageBody}; use papyrus_protobuf::converters::ProtobufConversionError; @@ -17,25 +25,31 @@ use tracing::{instrument, warn}; mod stream_handler_test; type PeerId = OpaquePeerId; +// type StreamId = u64; type MessageId = u64; -type StreamKey = (PeerId, u64); +// type StreamKey = (PeerId, StreamId); const CHANNEL_BUFFER_LENGTH: usize = 100; #[derive(Debug, Clone)] -struct StreamData> + TryFrom, Error = ProtobufConversionError>> { +struct StreamData< + T: Clone + Into> + TryFrom, Error = ProtobufConversionError>, + StreamId: Into> + From> + Eq + Hash + Clone + Unpin + Display + Debug, +> { next_message_id: MessageId, - // The message_id of the message that is marked as "fin" (the last message), - // if None, it means we have not yet gotten to it. + // Last message ID. If None, it means we have not yet gotten to it. fin_message_id: Option, max_message_id_received: MessageId, - // The sender that corresponds to the receiver that was sent out for this stream. sender: mpsc::Sender, // A buffer for messages that were received out of order. - message_buffer: BTreeMap>, + message_buffer: BTreeMap>, } -impl> + TryFrom, Error = ProtobufConversionError>> StreamData { +impl< + T: Clone + Into> + TryFrom, Error = ProtobufConversionError>, + StreamId: Into> + From> + Eq + Hash + Clone + Unpin + Display + Debug, +> StreamData +{ fn new(sender: mpsc::Sender) -> Self { StreamData { next_message_id: 0, @@ -47,41 +61,88 @@ impl> + TryFrom, Error = ProtobufConversionError } } -/// A StreamHandler is responsible for buffering and sending messages in order. +/// A StreamHandler is responsible for: +/// - Buffering inbound messages and reporting them to the application in order. +/// - Sending outbound messages to the network, wrapped in StreamMessage. pub struct StreamHandler< T: Clone + Into> + TryFrom, Error = ProtobufConversionError>, + StreamId: Into> + From> + Eq + Hash + Clone + Unpin + Display + Debug, > { - // An end of a channel used to send out receivers, one for each stream. + // For each stream ID from the network, send the application a Receiver + // that will receive the messages in order. This allows sending such Receivers. inbound_channel_sender: mpsc::Sender>, - // An end of a channel used to receive messages. - inbound_receiver: BroadcastTopicServer>, - // A map from stream_id to a struct that contains all the information about the stream. - // This includes both the message buffer and some metadata (like the latest message_id). - inbound_stream_data: HashMap>, - // TODO(guyn): perhaps make input_stream_data and output_stream_data? + // This receives messages from the network. + inbound_receiver: BroadcastTopicServer>, + // A map from (peer_id, stream_id) to a struct that contains all the information + // about the stream. This includes both the message buffer and some metadata + + // (like the latest message ID). + inbound_stream_data: HashMap<(PeerId, StreamId), StreamData>, + // Whenever application wants to start a new stream, it must send out a + // (stream_id, Receiver) pair. Each receiver gets messages that should + // be sent out to the network. + outbound_channel_receiver: mpsc::Receiver<(StreamId, mpsc::Receiver)>, + // A map where the abovementioned Receivers are stored. + outbound_stream_receivers: StreamHashMap>, + // A network sender that allows sending StreamMessages to peers. + outbound_sender: BroadcastTopicClient>, + // For each stream, keep track of the message_id of the last message sent. + outbound_stream_number: HashMap, } -impl> + TryFrom, Error = ProtobufConversionError>> - StreamHandler +impl< + T: Clone + Send + Into> + TryFrom, Error = ProtobufConversionError>, + StreamId: Into> + From> + Eq + Hash + Clone + Unpin + Send + Display + Debug, +> StreamHandler { /// Create a new StreamHandler. pub fn new( inbound_channel_sender: mpsc::Sender>, - inbound_receiver: BroadcastTopicServer>, + inbound_receiver: BroadcastTopicServer>, + outbound_channel_receiver: mpsc::Receiver<(StreamId, mpsc::Receiver)>, + outbound_sender: BroadcastTopicClient>, ) -> Self { - StreamHandler { + Self { inbound_channel_sender, inbound_receiver, inbound_stream_data: HashMap::new(), + outbound_channel_receiver, + outbound_sender, + outbound_stream_receivers: StreamHashMap::new(HashMap::new()), + outbound_stream_number: HashMap::new(), } } - /// Listen for messages on the receiver channel, buffering them if necessary. - /// Guarantees that messages are sent in order. + /// Listen for messages coming from the network and from the application. + /// - Outbound messages are wrapped as StreamMessage and sent to the network directly. + /// - Inbound messages are stripped of StreamMessage and buffered until they can be sent in the + /// correct order to the application. + #[instrument(skip_all)] pub async fn run(&mut self) { loop { - // TODO(guyn): this select is here to allow us to add the outbound flow. tokio::select!( + // Go over the channel receiver to see if there is a new channel. + Some((stream_id, receiver)) = self.outbound_channel_receiver.next() => { + self.outbound_stream_receivers.insert(stream_id, receiver); + } + // Go over all existing outbound receivers to see if there are any messages. + output = self.outbound_stream_receivers.next() => { + match output { + Some((key, Some(message))) => { + self.broadcast(key, message).await; + } + Some((key, None)) => { + self.broadcast_fin(key).await; + } + None => { + warn!( + "StreamHashMap should not be closed! \ + Usually only the individual channels are closed. " + ) + } + } + } + // Check if there is an inbound message from the network. Some(message) = self.inbound_receiver.next() => { self.handle_message(message); } @@ -89,7 +150,7 @@ impl> + TryFrom, Error = ProtobufConversionError } } - fn inbound_send(data: &mut StreamData, message: StreamMessage) { + fn inbound_send(data: &mut StreamData, message: StreamMessage) { // TODO(guyn): reconsider the "expect" here. let sender = &mut data.sender; if let StreamMessageBody::Content(content) = message.message { @@ -98,10 +159,40 @@ impl> + TryFrom, Error = ProtobufConversionError } } + // Send the message to the network. + async fn broadcast(&mut self, stream_id: StreamId, message: T) { + let message = StreamMessage { + message: StreamMessageBody::Content(message), + stream_id: stream_id.clone(), + message_id: *self.outbound_stream_number.get(&stream_id).unwrap_or(&0), + }; + // TODO(guyn): reconsider the "expect" here. + self.outbound_sender.broadcast_message(message).await.expect("Send should succeed"); + self.outbound_stream_number.insert( + stream_id.clone(), + self.outbound_stream_number.get(&stream_id).unwrap_or(&0) + 1, + ); + } + + // Send a fin message to the network. + async fn broadcast_fin(&mut self, stream_id: StreamId) { + let message = StreamMessage { + message: StreamMessageBody::Fin, + stream_id: stream_id.clone(), + message_id: *self.outbound_stream_number.get(&stream_id).unwrap_or(&0), + }; + self.outbound_sender.broadcast_message(message).await.expect("Send should succeed"); + self.outbound_stream_number.remove(&stream_id); + } + + // Handle a message that was received from the network. #[instrument(skip_all, level = "warn")] fn handle_message( &mut self, - message: (Result, ProtobufConversionError>, BroadcastedMessageMetadata), + message: ( + Result, ProtobufConversionError>, + BroadcastedMessageMetadata, + ), ) { let (message, metadata) = message; let message = match message { @@ -112,7 +203,7 @@ impl> + TryFrom, Error = ProtobufConversionError } }; let peer_id = metadata.originator_id; - let stream_id = message.stream_id; + let stream_id = message.stream_id.clone(); let key = (peer_id, stream_id); let message_id = message.message_id; @@ -134,7 +225,7 @@ impl> + TryFrom, Error = ProtobufConversionError data.max_message_id_received = message_id; } - // Check for Fin type message + // Check for Fin type message. match message.message { StreamMessageBody::Content(_) => {} StreamMessageBody::Fin => { @@ -167,7 +258,6 @@ impl> + TryFrom, Error = ProtobufConversionError match message_id.cmp(&data.next_message_id) { Ordering::Equal => { Self::inbound_send(data, message); - Self::process_buffer(data); if data.message_buffer.is_empty() && data.fin_message_id.is_some() { @@ -190,7 +280,12 @@ impl> + TryFrom, Error = ProtobufConversionError } } - fn store(data: &mut StreamData, key: StreamKey, message: StreamMessage) { + // Store an inbound message in the buffer. + fn store( + data: &mut StreamData, + key: (PeerId, StreamId), + message: StreamMessage, + ) { let message_id = message.message_id; match data.message_buffer.entry(message_id) { @@ -209,7 +304,7 @@ impl> + TryFrom, Error = ProtobufConversionError // Tries to drain as many messages as possible from the buffer (in order), // DOES NOT guarantee that the buffer will be empty after calling this function. - fn process_buffer(data: &mut StreamData) { + fn process_buffer(data: &mut StreamData) { while let Some(message) = data.message_buffer.remove(&data.next_message_id) { Self::inbound_send(data, message); } diff --git a/crates/sequencing/papyrus_consensus/src/stream_handler_test.rs b/crates/sequencing/papyrus_consensus/src/stream_handler_test.rs index 962ead1e9eb..1d27a081ed8 100644 --- a/crates/sequencing/papyrus_consensus/src/stream_handler_test.rs +++ b/crates/sequencing/papyrus_consensus/src/stream_handler_test.rs @@ -1,3 +1,4 @@ +use std::fmt::Display; use std::time::Duration; use futures::channel::mpsc; @@ -9,28 +10,64 @@ use papyrus_network::network_manager::test_utils::{ TestSubscriberChannels, }; use papyrus_network::network_manager::BroadcastTopicChannels; +use papyrus_network_types::network_types::BroadcastedMessageMetadata; use papyrus_protobuf::consensus::{ConsensusMessage, Proposal, StreamMessage, StreamMessageBody}; use papyrus_test_utils::{get_rng, GetTestInstance}; -use super::StreamHandler; +use super::{MessageId, StreamHandler}; -#[cfg(test)] -mod tests { +const TIMEOUT: Duration = Duration::from_millis(100); +const CHANNEL_SIZE: usize = 100; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +struct StreamId(u64); + +impl Into> for StreamId { + fn into(self) -> Vec { + self.0.to_be_bytes().to_vec() + } +} - use papyrus_network_types::network_types::BroadcastedMessageMetadata; +impl From> for StreamId { + fn from(bytes: Vec) -> Self { + let mut array = [0; 8]; + array.copy_from_slice(&bytes); + StreamId(u64::from_be_bytes(array)) + } +} + +impl PartialOrd for StreamId { + fn partial_cmp(&self, other: &Self) -> Option { + self.0.partial_cmp(&other.0) + } +} +impl Ord for StreamId { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.0.cmp(&other.0) + } +} + +impl Display for StreamId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "StreamId({})", self.0) + } +} + +#[cfg(test)] +mod tests { use super::*; fn make_test_message( stream_id: u64, - message_id: u64, + message_id: MessageId, fin: bool, - ) -> StreamMessage { + ) -> StreamMessage { let content = match fin { true => StreamMessageBody::Fin, false => StreamMessageBody::Content(ConsensusMessage::Proposal(Proposal::default())), }; - StreamMessage { message: content, stream_id, message_id } + StreamMessage { message: content, stream_id: StreamId(stream_id), message_id } } // Check if two vectors are the same: @@ -40,39 +77,85 @@ mod tests { } async fn send( - sender: &mut MockBroadcastedMessagesSender>, + sender: &mut MockBroadcastedMessagesSender>, metadata: &BroadcastedMessageMetadata, - msg: StreamMessage, + msg: StreamMessage, ) { sender.send((msg, metadata.clone())).await.unwrap(); } #[allow(clippy::type_complexity)] fn setup_test() -> ( - StreamHandler, - MockBroadcastedMessagesSender>, + StreamHandler, + MockBroadcastedMessagesSender>, mpsc::Receiver>, BroadcastedMessageMetadata, + mpsc::Sender<(StreamId, mpsc::Receiver)>, + futures::stream::Map< + mpsc::Receiver>, + fn(Vec) -> StreamMessage, + >, ) { + // The outbound_sender is the network connector for broadcasting messages. + // The network_broadcast_receiver is used to catch those messages in the test. + let TestSubscriberChannels { mock_network: mock_broadcast_network, subscriber_channels } = + mock_register_broadcast_topic().unwrap(); + let BroadcastTopicChannels { + broadcasted_messages_receiver: _, + broadcast_topic_client: outbound_sender, + } = subscriber_channels; + + let network_broadcast_receiver = mock_broadcast_network.messages_to_broadcast_receiver; + + // This is used to feed receivers of messages to StreamHandler for broadcasting. + // The receiver goes into StreamHandler, sender is used by the test (as mock Consensus). + // Note that each new channel comes in a tuple with (stream_id, receiver). + let (outbound_channel_sender, outbound_channel_receiver) = + mpsc::channel::<(StreamId, mpsc::Receiver)>(CHANNEL_SIZE); + + // The network_sender_to_inbound is the sender of the mock network, that is used by the + // test to send messages into the StreamHandler (from the mock network). let TestSubscriberChannels { mock_network, subscriber_channels } = mock_register_broadcast_topic().unwrap(); - let network_sender = mock_network.broadcasted_messages_sender; - let BroadcastTopicChannels { broadcasted_messages_receiver, broadcast_topic_client: _ } = - subscriber_channels; + let network_sender_to_inbound = mock_network.broadcasted_messages_sender; + + // The inbound_receiver is given to StreamHandler to inbound to mock network messages. + let BroadcastTopicChannels { + broadcasted_messages_receiver: inbound_receiver, + broadcast_topic_client: _, + } = subscriber_channels; + + // The inbound_channel_sender is given to StreamHandler so it can output new channels for + // each stream. The inbound_channel_receiver is given to the "mock consensus" that + // gets new channels and inbounds to them. + let (inbound_channel_sender, inbound_channel_receiver) = + mpsc::channel::>(CHANNEL_SIZE); // TODO(guyn): We should also give the broadcast_topic_client to the StreamHandler - let (tx_output, rx_output) = mpsc::channel::>(100); - let handler = StreamHandler::new(tx_output, broadcasted_messages_receiver); + // This will allow reporting to the network things like bad peers. + let handler = StreamHandler::new( + inbound_channel_sender, + inbound_receiver, + outbound_channel_receiver, + outbound_sender, + ); - let broadcasted_message_metadata = - BroadcastedMessageMetadata::get_test_instance(&mut get_rng()); + let inbound_metadata = BroadcastedMessageMetadata::get_test_instance(&mut get_rng()); - (handler, network_sender, rx_output, broadcasted_message_metadata) + ( + handler, + network_sender_to_inbound, + inbound_channel_receiver, + inbound_metadata, + outbound_channel_sender, + network_broadcast_receiver, + ) } #[tokio::test] - async fn stream_handler_in_order() { - let (mut stream_handler, mut network_sender, mut rx_output, metadata) = setup_test(); + async fn inbound_in_order() { + let (mut stream_handler, mut network_sender, mut inbound_channel_receiver, metadata, _, _) = + setup_test(); let stream_id = 127; for i in 0..10 { @@ -81,12 +164,12 @@ mod tests { } let join_handle = tokio::spawn(async move { - let _ = tokio::time::timeout(Duration::from_millis(100), stream_handler.run()).await; + let _ = tokio::time::timeout(TIMEOUT, stream_handler.run()).await; }); join_handle.await.expect("Task should succeed"); - let mut receiver = rx_output.next().await.unwrap(); + let mut receiver = inbound_channel_receiver.next().await.unwrap(); for _ in 0..9 { // message number 9 is Fin, so it will not be sent! let _ = receiver.next().await.unwrap(); @@ -96,33 +179,44 @@ mod tests { } #[tokio::test] - async fn stream_handler_in_reverse() { - let (mut stream_handler, mut network_sender, mut rx_output, metadata) = setup_test(); - let peer_id = metadata.originator_id.clone(); + async fn inbound_in_reverse() { + let ( + mut stream_handler, + mut network_sender, + mut inbound_channel_receiver, + inbound_metadata, + _, + _, + ) = setup_test(); + let peer_id = inbound_metadata.originator_id.clone(); let stream_id = 127; for i in 0..5 { let message = make_test_message(stream_id, 5 - i, i == 0); - send(&mut network_sender, &metadata, message).await; + send(&mut network_sender, &inbound_metadata, message).await; } + + // Run the loop for a short duration to process the message. let join_handle = tokio::spawn(async move { - let _ = tokio::time::timeout(Duration::from_millis(100), stream_handler.run()).await; + let _ = tokio::time::timeout(TIMEOUT, stream_handler.run()).await; stream_handler }); let mut stream_handler = join_handle.await.expect("Task should succeed"); // Get the receiver for the stream. - let mut receiver = rx_output.next().await.unwrap(); + let mut receiver = inbound_channel_receiver.next().await.unwrap(); // Check that the channel is empty (no messages were sent yet). assert!(receiver.try_next().is_err()); assert_eq!(stream_handler.inbound_stream_data.len(), 1); assert_eq!( - stream_handler.inbound_stream_data[&(peer_id.clone(), stream_id)].message_buffer.len(), + stream_handler.inbound_stream_data[&(peer_id.clone(), StreamId(stream_id))] + .message_buffer + .len(), 5 ); let range: Vec = (1..6).collect(); - let keys: Vec = stream_handler.inbound_stream_data[&(peer_id, stream_id)] + let keys: Vec = stream_handler.inbound_stream_data[&(peer_id, StreamId(stream_id))] .clone() .message_buffer .into_keys() @@ -130,9 +224,11 @@ mod tests { assert!(do_vecs_match(&keys, &range)); // Now send the last message: - send(&mut network_sender, &metadata, make_test_message(stream_id, 0, false)).await; + send(&mut network_sender, &inbound_metadata, make_test_message(stream_id, 0, false)).await; + + // Run the loop for a short duration to process the message. let join_handle = tokio::spawn(async move { - let _ = tokio::time::timeout(Duration::from_millis(100), stream_handler.run()).await; + let _ = tokio::time::timeout(TIMEOUT, stream_handler.run()).await; stream_handler }); @@ -148,9 +244,16 @@ mod tests { } #[tokio::test] - async fn stream_handler_multiple_streams() { - let (mut stream_handler, mut network_sender, mut rx_output, metadata) = setup_test(); - let peer_id = metadata.originator_id.clone(); + async fn inbound_multiple_streams() { + let ( + mut stream_handler, + mut network_sender, + mut inbound_channel_receiver, + inbound_metadata, + _, + _, + ) = setup_test(); + let peer_id = inbound_metadata.originator_id.clone(); let stream_id1 = 127; // Send all messages in order (except the first one). let stream_id2 = 10; // Send in reverse order (except the first one). @@ -158,30 +261,36 @@ mod tests { for i in 1..10 { let message = make_test_message(stream_id1, i, i == 9); - send(&mut network_sender, &metadata, message).await; + send(&mut network_sender, &inbound_metadata, message).await; } for i in 0..5 { let message = make_test_message(stream_id2, 5 - i, i == 0); - send(&mut network_sender, &metadata, message).await; + send(&mut network_sender, &inbound_metadata, message).await; } for i in 5..10 { let message = make_test_message(stream_id3, i, false); - send(&mut network_sender, &metadata, message).await; + send(&mut network_sender, &inbound_metadata, message).await; } + for i in 1..5 { let message = make_test_message(stream_id3, i, false); - send(&mut network_sender, &metadata, message).await; + send(&mut network_sender, &inbound_metadata, message).await; } + // Run the loop for a short duration to process the message. let join_handle = tokio::spawn(async move { - let _ = tokio::time::timeout(Duration::from_millis(100), stream_handler.run()).await; + let _ = tokio::time::timeout(TIMEOUT, stream_handler.run()).await; stream_handler }); let mut stream_handler = join_handle.await.expect("Task should succeed"); - let values = vec![(peer_id.clone(), 1), (peer_id.clone(), 10), (peer_id.clone(), 127)]; + let values = [ + (peer_id.clone(), StreamId(1)), + (peer_id.clone(), StreamId(10)), + (peer_id.clone(), StreamId(127)), + ]; assert!( stream_handler .inbound_stream_data @@ -192,7 +301,7 @@ mod tests { // We have all message from 1 to 9 buffered. assert!(do_vecs_match( - &stream_handler.inbound_stream_data[&(peer_id.clone(), stream_id1)] + &stream_handler.inbound_stream_data[&(peer_id.clone(), StreamId(stream_id1))] .message_buffer .clone() .into_keys() @@ -202,7 +311,7 @@ mod tests { // We have all message from 1 to 5 buffered. assert!(do_vecs_match( - &stream_handler.inbound_stream_data[&(peer_id.clone(), stream_id2)] + &stream_handler.inbound_stream_data[&(peer_id.clone(), StreamId(stream_id2))] .message_buffer .clone() .into_keys() @@ -212,7 +321,7 @@ mod tests { // We have all message from 1 to 5 buffered. assert!(do_vecs_match( - &stream_handler.inbound_stream_data[&(peer_id.clone(), stream_id3)] + &stream_handler.inbound_stream_data[&(peer_id.clone(), StreamId(stream_id3))] .message_buffer .clone() .into_keys() @@ -221,43 +330,41 @@ mod tests { )); // Get the receiver for the first stream. - let mut receiver1 = rx_output.next().await.unwrap(); + let mut receiver1 = inbound_channel_receiver.next().await.unwrap(); // Check that the channel is empty (no messages were sent yet). assert!(receiver1.try_next().is_err()); // Get the receiver for the second stream. - let mut receiver2 = rx_output.next().await.unwrap(); + let mut receiver2 = inbound_channel_receiver.next().await.unwrap(); // Check that the channel is empty (no messages were sent yet). assert!(receiver2.try_next().is_err()); // Get the receiver for the third stream. - let mut receiver3 = rx_output.next().await.unwrap(); + let mut receiver3 = inbound_channel_receiver.next().await.unwrap(); // Check that the channel is empty (no messages were sent yet). assert!(receiver3.try_next().is_err()); // Send the last message on stream_id1: - send(&mut network_sender, &metadata, make_test_message(stream_id1, 0, false)).await; + send(&mut network_sender, &inbound_metadata, make_test_message(stream_id1, 0, false)).await; + + // Run the loop for a short duration to process the message. let join_handle = tokio::spawn(async move { - let _ = tokio::time::timeout(Duration::from_millis(100), stream_handler.run()).await; + let _ = tokio::time::timeout(TIMEOUT, stream_handler.run()).await; stream_handler }); - let mut stream_handler = join_handle.await.expect("Task should succeed"); - // Should be able to read all the messages for stream_id1. for _ in 0..9 { // message number 9 is Fin, so it will not be sent! let _ = receiver1.next().await.unwrap(); } - - // Check that the receiver was closed: - assert!(matches!(receiver1.try_next(), Ok(None))); + let mut stream_handler = join_handle.await.expect("Task should succeed"); // stream_id1 should be gone - let values = [(peer_id.clone(), 1), (peer_id.clone(), 10)]; + let values = [(peer_id.clone(), StreamId(1)), (peer_id.clone(), StreamId(10))]; assert!( stream_handler .inbound_stream_data @@ -267,25 +374,24 @@ mod tests { ); // Send the last message on stream_id2: - send(&mut network_sender, &metadata, make_test_message(stream_id2, 0, false)).await; + send(&mut network_sender, &inbound_metadata, make_test_message(stream_id2, 0, false)).await; + + // Run the loop for a short duration to process the message. let join_handle = tokio::spawn(async move { - let _ = tokio::time::timeout(Duration::from_millis(100), stream_handler.run()).await; + let _ = tokio::time::timeout(TIMEOUT, stream_handler.run()).await; stream_handler }); - let mut stream_handler = join_handle.await.expect("Task should succeed"); - // Should be able to read all the messages for stream_id2. for _ in 0..5 { // message number 5 is Fin, so it will not be sent! let _ = receiver2.next().await.unwrap(); } - // Check that the receiver was closed: - assert!(matches!(receiver2.try_next(), Ok(None))); + let mut stream_handler = join_handle.await.expect("Task should succeed"); // Stream_id2 should also be gone. - let values = [(peer_id.clone(), 1)]; + let values = [(peer_id.clone(), StreamId(1))]; assert!( stream_handler .inbound_stream_data @@ -295,10 +401,11 @@ mod tests { ); // Send the last message on stream_id3: - send(&mut network_sender, &metadata, make_test_message(stream_id3, 0, false)).await; + send(&mut network_sender, &inbound_metadata, make_test_message(stream_id3, 0, false)).await; + // Run the loop for a short duration to process the message. let join_handle = tokio::spawn(async move { - let _ = tokio::time::timeout(Duration::from_millis(100), stream_handler.run()).await; + let _ = tokio::time::timeout(TIMEOUT, stream_handler.run()).await; stream_handler }); @@ -308,11 +415,8 @@ mod tests { let _ = receiver3.next().await.unwrap(); } - // In this case the receiver is not closed, because we didn't send a fin. - assert!(receiver3.try_next().is_err()); - // Stream_id3 should still be there, because we didn't send a fin. - let values = [(peer_id.clone(), 1)]; + let values = [(peer_id.clone(), StreamId(1))]; assert!( stream_handler .inbound_stream_data @@ -323,7 +427,131 @@ mod tests { // But the buffer should be empty, as we've successfully drained it all. assert!( - stream_handler.inbound_stream_data[&(peer_id, stream_id3)].message_buffer.is_empty() + stream_handler.inbound_stream_data[&(peer_id, StreamId(stream_id3))] + .message_buffer + .is_empty() + ); + } + + // This test does two things: + // 1. Opens two outbound channels and checks that messages get correctly sent on both. + // 2. Closes the first channel and checks that Fin is sent and that the relevant structures + // inside the stream handler are cleaned up. + #[tokio::test] + async fn outbound_multiple_streams() { + let ( + mut stream_handler, + _, + _, + _, + mut broadcast_channel_sender, + mut broadcasted_messages_receiver, + ) = setup_test(); + + let stream_id1 = StreamId(42); + let stream_id2 = StreamId(127); + + // Start a new stream by sending the (stream_id, receiver). + let (mut sender1, receiver1) = mpsc::channel(CHANNEL_SIZE); + broadcast_channel_sender.send((stream_id1, receiver1)).await.unwrap(); + + // Send a message on the stream. + let message1 = ConsensusMessage::Proposal(Proposal::default()); + sender1.send(message1.clone()).await.unwrap(); + + // Run the loop for a short duration to process the message. + let join_handle = tokio::spawn(async move { + let _ = tokio::time::timeout(TIMEOUT, stream_handler.run()).await; + stream_handler + }); + + // Wait for an incoming message. + let broadcasted_message = broadcasted_messages_receiver.next().await.unwrap(); + let mut stream_handler = join_handle.await.expect("Task should succeed"); + + // Check that message was broadcasted. + assert_eq!(broadcasted_message.message, StreamMessageBody::Content(message1)); + assert_eq!(broadcasted_message.stream_id, stream_id1); + assert_eq!(broadcasted_message.message_id, 0); + + // Check that internally, stream_handler holds this receiver. + assert_eq!( + stream_handler.outbound_stream_receivers.keys().collect::>(), + vec![&stream_id1] + ); + // Check that the number of messages sent on this stream is 1. + assert_eq!(stream_handler.outbound_stream_number[&stream_id1], 1); + + // Send another message on the same stream. + let message2 = ConsensusMessage::Proposal(Proposal::default()); + sender1.send(message2.clone()).await.unwrap(); + + // Run the loop for a short duration to process the message. + let join_handle = tokio::spawn(async move { + let _ = tokio::time::timeout(TIMEOUT, stream_handler.run()).await; + stream_handler + }); + + // Wait for an incoming message. + let broadcasted_message = broadcasted_messages_receiver.next().await.unwrap(); + + let mut stream_handler = join_handle.await.expect("Task should succeed"); + + // Check that message was broadcasted. + assert_eq!(broadcasted_message.message, StreamMessageBody::Content(message2)); + assert_eq!(broadcasted_message.stream_id, stream_id1); + assert_eq!(broadcasted_message.message_id, 1); + assert_eq!(stream_handler.outbound_stream_number[&stream_id1], 2); + + // Start a new stream by sending the (stream_id, receiver). + let (mut sender2, receiver2) = mpsc::channel(CHANNEL_SIZE); + broadcast_channel_sender.send((stream_id2, receiver2)).await.unwrap(); + + // Send a message on the stream. + let message3 = ConsensusMessage::Proposal(Proposal::default()); + sender2.send(message3.clone()).await.unwrap(); + + // Run the loop for a short duration to process the message. + let join_handle = tokio::spawn(async move { + let _ = tokio::time::timeout(TIMEOUT, stream_handler.run()).await; + stream_handler + }); + + // Wait for an incoming message. + let broadcasted_message = broadcasted_messages_receiver.next().await.unwrap(); + + let mut stream_handler = join_handle.await.expect("Task should succeed"); + + // Check that message was broadcasted. + assert_eq!(broadcasted_message.message, StreamMessageBody::Content(message3)); + assert_eq!(broadcasted_message.stream_id, stream_id2); + assert_eq!(broadcasted_message.message_id, 0); + let mut vec1 = stream_handler.outbound_stream_receivers.keys().collect::>(); + vec1.sort(); + let mut vec2 = vec![&stream_id1, &stream_id2]; + vec2.sort(); + do_vecs_match(&vec1, &vec2); + assert_eq!(stream_handler.outbound_stream_number[&stream_id2], 1); + + // Close the first channel. + sender1.close_channel(); + + // Run the loop for a short duration to process that the channel was closed. + let join_handle = tokio::spawn(async move { + let _ = tokio::time::timeout(TIMEOUT, stream_handler.run()).await; + stream_handler + }); + + // Check that we got a fin message. + let broadcasted_message = broadcasted_messages_receiver.next().await.unwrap(); + assert_eq!(broadcasted_message.message, StreamMessageBody::Fin); + + let stream_handler = join_handle.await.expect("Task should succeed"); + + // Check that the information about this stream is gone. + assert_eq!( + stream_handler.outbound_stream_receivers.keys().collect::>(), + vec![&stream_id2] ); } }