From 69d1bac2267b4deb20e0abbc4b9e7bc2b4c9978e Mon Sep 17 00:00:00 2001 From: Guy Nir Date: Sun, 27 Oct 2024 11:35:14 +0200 Subject: [PATCH] fix: update StreamHandler to use new StreamHashMap interface --- Cargo.lock | 6 +- crates/papyrus_node/src/run.rs | 7 +- crates/papyrus_protobuf/Cargo.toml | 1 + crates/papyrus_protobuf/src/consensus.rs | 48 +++++++- .../src/converters/consensus.rs | 33 +++-- .../src/converters/consensus_test.rs | 5 +- .../src/converters/test_instances.rs | 35 +++++- .../src/proto/p2p/proto/consensus.proto | 2 +- .../sequencing/papyrus_consensus/Cargo.toml | 1 + .../papyrus_consensus/src/stream_handler.rs | 116 +++++++++++++----- .../src/stream_handler_test.rs | 116 +++++++++++++----- .../src/papyrus_consensus_context.rs | 9 +- .../src/papyrus_consensus_context_test.rs | 6 +- .../src/sequencer_consensus_context.rs | 7 +- .../src/sequencer_consensus_context_test.rs | 3 +- .../src/consensus_manager.rs | 4 +- .../src/flow_test_setup.rs | 5 +- .../src/integration_test_setup.rs | 5 +- .../starknet_integration_tests/src/utils.rs | 7 +- .../tests/end_to_end_flow_test.rs | 5 +- 20 files changed, 315 insertions(+), 106 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 7a9e185d4dd..fba33c59b22 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1646,9 +1646,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "bytes" -version = "1.7.2" +version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "428d9aa8fbc0670b7b8d6030a7fadd0f86151cae55e4dbbece15f3780a3dfaf3" +checksum = "325918d6fe32f23b19878fe4b34794ae41fc19ddbe53b10571a4874d44ffd39b" dependencies = [ "serde", ] @@ -7382,6 +7382,7 @@ dependencies = [ "papyrus_protobuf", "papyrus_storage", "papyrus_test_utils", + "prost", "serde", "starknet-types-core", "starknet_api", @@ -7630,6 +7631,7 @@ dependencies = [ name = "papyrus_protobuf" version = "0.0.0" dependencies = [ + "bytes", "indexmap 2.6.0", "lazy_static", "papyrus_common", diff --git a/crates/papyrus_node/src/run.rs b/crates/papyrus_node/src/run.rs index 1bfdeae7563..bb4fe17edac 100644 --- a/crates/papyrus_node/src/run.rs +++ b/crates/papyrus_node/src/run.rs @@ -23,7 +23,7 @@ use papyrus_network::{network_manager, NetworkConfig}; use papyrus_p2p_sync::client::{P2PSyncClient, P2PSyncClientChannels}; use papyrus_p2p_sync::server::{P2PSyncServer, P2PSyncServerChannels}; use papyrus_p2p_sync::{Protocol, BUFFER_SIZE}; -use papyrus_protobuf::consensus::{ProposalPart, StreamMessage}; +use papyrus_protobuf::consensus::{HeightAndRound, ProposalPart, StreamMessage}; #[cfg(feature = "rpc")] use papyrus_rpc::run_server; use papyrus_storage::storage_metrics::update_storage_metrics; @@ -192,8 +192,9 @@ fn spawn_consensus( let network_channels = network_manager .register_broadcast_topic(Topic::new(config.network_topic.clone()), BUFFER_SIZE)?; - let proposal_network_channels: BroadcastTopicChannels> = - network_manager.register_broadcast_topic(Topic::new(NETWORK_TOPIC), BUFFER_SIZE)?; + let proposal_network_channels: BroadcastTopicChannels< + StreamMessage, + > = network_manager.register_broadcast_topic(Topic::new(NETWORK_TOPIC), BUFFER_SIZE)?; let BroadcastTopicChannels { broadcasted_messages_receiver: inbound_network_receiver, broadcast_topic_client: outbound_network_sender, diff --git a/crates/papyrus_protobuf/Cargo.toml b/crates/papyrus_protobuf/Cargo.toml index 5d638b27568..569d0be3997 100644 --- a/crates/papyrus_protobuf/Cargo.toml +++ b/crates/papyrus_protobuf/Cargo.toml @@ -9,6 +9,7 @@ license-file.workspace = true testing = ["papyrus_test_utils", "rand", "rand_chacha"] [dependencies] +bytes.workspace = true indexmap.workspace = true lazy_static.workspace = true primitive-types.workspace = true diff --git a/crates/papyrus_protobuf/src/consensus.rs b/crates/papyrus_protobuf/src/consensus.rs index f9c7d51f49a..a4c6c54a77d 100644 --- a/crates/papyrus_protobuf/src/consensus.rs +++ b/crates/papyrus_protobuf/src/consensus.rs @@ -1,3 +1,7 @@ +use std::fmt::Display; + +use bytes::{Buf, BufMut}; +use prost::DecodeError; use starknet_api::block::{BlockHash, BlockNumber}; use starknet_api::core::ContractAddress; use starknet_api::transaction::Transaction; @@ -26,10 +30,13 @@ pub enum StreamMessageBody { Fin, } -#[derive(Clone, Debug, Eq, Hash, PartialEq)] -pub struct StreamMessage> + TryFrom, Error = ProtobufConversionError>> { +#[derive(Debug, Clone, Hash, Eq, PartialEq)] +pub struct StreamMessage< + T: Into> + TryFrom, Error = ProtobufConversionError>, + StreamId: Into> + Clone, +> { pub message: StreamMessageBody, - pub stream_id: u64, + pub stream_id: StreamId, pub message_id: u64, } @@ -108,7 +115,7 @@ impl From for ProposalPart { } } -impl std::fmt::Display for StreamMessage +impl> + Clone + Display> std::fmt::Display for StreamMessage where T: Clone + Into> + TryFrom, Error = ProtobufConversionError>, { @@ -131,3 +138,36 @@ where } } } + +/// HeighAndRound is a tuple struct used as the StreamId for consensus and context. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct HeightAndRound(pub u64, pub u32); + +impl TryFrom> for HeightAndRound { + type Error = ProtobufConversionError; + + fn try_from(value: Vec) -> Result { + if value.len() != 12 { + return Err(ProtobufConversionError::DecodeError(DecodeError::new("Invalid length"))); + } + let mut bytes = value.as_slice(); + let height = bytes.get_u64(); + let round = bytes.get_u32(); + Ok(HeightAndRound(height, round)) + } +} + +impl From for Vec { + fn from(value: HeightAndRound) -> Vec { + let mut bytes = Vec::with_capacity(12); + bytes.put_u64(value.0); + bytes.put_u32(value.1); + bytes + } +} + +impl std::fmt::Display for HeightAndRound { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "(height: {}, round: {})", self.0, self.1) + } +} diff --git a/crates/papyrus_protobuf/src/converters/consensus.rs b/crates/papyrus_protobuf/src/converters/consensus.rs index 1dad9045b85..ab08374690b 100644 --- a/crates/papyrus_protobuf/src/converters/consensus.rs +++ b/crates/papyrus_protobuf/src/converters/consensus.rs @@ -1,6 +1,7 @@ #[cfg(test)] #[path = "consensus_test.rs"] mod consensus_test; + use std::convert::{TryFrom, TryInto}; use prost::Message; @@ -79,8 +80,10 @@ impl From for protobuf::Vote { auto_impl_into_and_try_from_vec_u8!(Vote, protobuf::Vote); -impl> + TryFrom, Error = ProtobufConversionError>> - TryFrom for StreamMessage +impl< + T: Into> + TryFrom, Error = ProtobufConversionError>, + StreamId: Into> + TryFrom, Error = ProtobufConversionError> + Clone, +> TryFrom for StreamMessage { type Error = ProtobufConversionError; @@ -101,16 +104,18 @@ impl> + TryFrom, Error = ProtobufConversionError>> StreamMessageBody::Fin } }, - stream_id: value.stream_id, + stream_id: value.stream_id.try_into()?, message_id: value.message_id, }) } } -impl> + TryFrom, Error = ProtobufConversionError>> From> - for protobuf::StreamMessage +impl< + T: Into> + TryFrom, Error = ProtobufConversionError>, + StreamId: Into> + TryFrom, Error = ProtobufConversionError> + Clone, +> From> for protobuf::StreamMessage { - fn from(value: StreamMessage) -> Self { + fn from(value: StreamMessage) -> Self { Self { message: match value { StreamMessage { @@ -122,7 +127,7 @@ impl> + TryFrom, Error = ProtobufConversionError>> From< Some(protobuf::stream_message::Message::Fin(protobuf::Fin {})) } }, - stream_id: value.stream_id, + stream_id: value.stream_id.into(), message_id: value.message_id, } } @@ -131,17 +136,21 @@ impl> + TryFrom, Error = ProtobufConversionError>> From< // Can't use auto_impl_into_and_try_from_vec_u8!(StreamMessage, protobuf::StreamMessage); // because it doesn't seem to work with generics. // TODO(guyn): consider expanding the macro to support generics -impl> + TryFrom, Error = ProtobufConversionError>> From> - for Vec +impl< + T: Into> + TryFrom, Error = ProtobufConversionError>, + StreamId: Into> + TryFrom, Error = ProtobufConversionError> + Clone, +> From> for Vec { - fn from(value: StreamMessage) -> Self { + fn from(value: StreamMessage) -> Self { let protobuf_value = ::from(value); protobuf_value.encode_to_vec() } } -impl> + TryFrom, Error = ProtobufConversionError>> TryFrom> - for StreamMessage +impl< + T: Into> + TryFrom, Error = ProtobufConversionError>, + StreamId: Into> + TryFrom, Error = ProtobufConversionError> + Clone, +> TryFrom> for StreamMessage { type Error = ProtobufConversionError; fn try_from(value: Vec) -> Result { diff --git a/crates/papyrus_protobuf/src/converters/consensus_test.rs b/crates/papyrus_protobuf/src/converters/consensus_test.rs index d2c465706da..52fea9b2169 100644 --- a/crates/papyrus_protobuf/src/converters/consensus_test.rs +++ b/crates/papyrus_protobuf/src/converters/consensus_test.rs @@ -20,6 +20,7 @@ use crate::consensus::{ TransactionBatch, Vote, }; +use crate::converters::test_instances::TestStreamId; // If all the fields of `AllResources` are 0 upon serialization, // then the deserialized value will be interpreted as the `L1Gas` variant. @@ -50,7 +51,7 @@ fn convert_stream_message_to_vec_u8_and_back() { let mut rng = get_rng(); // Test that we can convert a StreamMessage with a ProposalPart message to bytes and back. - let mut stream_message: StreamMessage = + let mut stream_message: StreamMessage = StreamMessage::get_test_instance(&mut rng); if let StreamMessageBody::Content(ProposalPart::Transactions(proposal)) = @@ -128,7 +129,7 @@ fn convert_proposal_part_to_vec_u8_and_back() { #[test] fn stream_message_display() { let mut rng = get_rng(); - let stream_id = 42; + let stream_id = TestStreamId(42); let message_id = 127; let proposal = ProposalPart::get_test_instance(&mut rng); let proposal_bytes: Vec = proposal.clone().into(); diff --git a/crates/papyrus_protobuf/src/converters/test_instances.rs b/crates/papyrus_protobuf/src/converters/test_instances.rs index be55534a758..3ba24ea4416 100644 --- a/crates/papyrus_protobuf/src/converters/test_instances.rs +++ b/crates/papyrus_protobuf/src/converters/test_instances.rs @@ -1,9 +1,13 @@ +use std::fmt::Display; + use papyrus_test_utils::{auto_impl_get_test_instance, get_number_of_variants, GetTestInstance}; +use prost::DecodeError; use rand::Rng; use starknet_api::block::{BlockHash, BlockNumber}; use starknet_api::core::ContractAddress; use starknet_api::transaction::Transaction; +use super::ProtobufConversionError; use crate::consensus::{ ProposalFin, ProposalInit, @@ -47,9 +51,36 @@ auto_impl_get_test_instance! { } +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct TestStreamId(pub u64); + +impl From for Vec { + fn from(value: TestStreamId) -> Self { + value.0.to_be_bytes().to_vec() + } +} + +impl TryFrom> for TestStreamId { + type Error = ProtobufConversionError; + fn try_from(bytes: Vec) -> Result { + if bytes.len() != 8 { + return Err(ProtobufConversionError::DecodeError(DecodeError::new("Invalid length"))); + }; + let mut array = [0; 8]; + array.copy_from_slice(&bytes); + Ok(TestStreamId(u64::from_be_bytes(array))) + } +} + +impl Display for TestStreamId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "TestStreamId({})", self.0) + } +} + // The auto_impl_get_test_instance macro does not work for StreamMessage because it has // a generic type. TODO(guyn): try to make the macro work with generic types. -impl GetTestInstance for StreamMessage { +impl GetTestInstance for StreamMessage { fn get_test_instance(rng: &mut rand_chacha::ChaCha8Rng) -> Self { let message = if rng.gen_bool(0.5) { StreamMessageBody::Content(ProposalPart::Transactions(TransactionBatch { @@ -58,6 +89,6 @@ impl GetTestInstance for StreamMessage { } else { StreamMessageBody::Fin }; - Self { message, stream_id: 12, message_id: 47 } + Self { message, stream_id: TestStreamId(12), message_id: 47 } } } diff --git a/crates/papyrus_protobuf/src/proto/p2p/proto/consensus.proto b/crates/papyrus_protobuf/src/proto/p2p/proto/consensus.proto index 155a9957553..765798744c3 100644 --- a/crates/papyrus_protobuf/src/proto/p2p/proto/consensus.proto +++ b/crates/papyrus_protobuf/src/proto/p2p/proto/consensus.proto @@ -24,7 +24,7 @@ message StreamMessage { bytes content = 1; Fin fin = 2; } - uint64 stream_id = 3; + bytes stream_id = 3; uint64 message_id = 4; } diff --git a/crates/sequencing/papyrus_consensus/Cargo.toml b/crates/sequencing/papyrus_consensus/Cargo.toml index fb372e2dd44..d31172b915c 100644 --- a/crates/sequencing/papyrus_consensus/Cargo.toml +++ b/crates/sequencing/papyrus_consensus/Cargo.toml @@ -20,6 +20,7 @@ papyrus_config.workspace = true papyrus_network.workspace = true papyrus_network_types.workspace = true papyrus_protobuf.workspace = true +prost.workspace = true serde = { workspace = true, features = ["derive"] } starknet-types-core.workspace = true starknet_api.workspace = true diff --git a/crates/sequencing/papyrus_consensus/src/stream_handler.rs b/crates/sequencing/papyrus_consensus/src/stream_handler.rs index b794ce87dd7..317fd0c40c7 100644 --- a/crates/sequencing/papyrus_consensus/src/stream_handler.rs +++ b/crates/sequencing/papyrus_consensus/src/stream_handler.rs @@ -3,6 +3,8 @@ use std::cmp::Ordering; use std::collections::hash_map::Entry::{Occupied, Vacant}; use std::collections::HashMap; +use std::fmt::{Debug, Display}; +use std::hash::Hash; use futures::channel::mpsc; use futures::StreamExt; @@ -22,9 +24,7 @@ use tracing::{instrument, warn}; mod stream_handler_test; type PeerId = OpaquePeerId; -type StreamId = u64; type MessageId = u64; -type StreamKey = (PeerId, StreamId); const CHANNEL_BUFFER_LENGTH: usize = 100; @@ -34,7 +34,15 @@ const CHANNEL_BUFFER_LENGTH: usize = 100; // (2) fin message is received and all messages are sent. #[derive(Debug)] struct StreamData< - T: Clone + Into> + TryFrom, Error = ProtobufConversionError> + 'static, + T: Clone + Into> + TryFrom, Error = ProtobufConversionError>, + StreamId: Into> + + TryFrom, Error = ProtobufConversionError> + + Eq + + Hash + + Clone + + Unpin + + Display + + Debug, > { next_message_id: MessageId, // Last message ID. If None, it means we have not yet gotten to it. @@ -44,10 +52,21 @@ struct StreamData< receiver: Option>, sender: mpsc::Sender, // A buffer for messages that were received out of order. - message_buffer: HashMap>, + message_buffer: HashMap>, } -impl> + TryFrom, Error = ProtobufConversionError>> StreamData { +impl< + T: Clone + Into> + TryFrom, Error = ProtobufConversionError>, + StreamId: Into> + + TryFrom, Error = ProtobufConversionError> + + Eq + + Hash + + Clone + + Unpin + + Display + + Debug, +> StreamData +{ fn new() -> Self { let (sender, receiver) = mpsc::channel(CHANNEL_BUFFER_LENGTH); StreamData { @@ -65,17 +84,25 @@ impl> + TryFrom, Error = ProtobufConversionError /// - Buffering inbound messages and reporting them to the application in order. /// - Sending outbound messages to the network, wrapped in StreamMessage. pub struct StreamHandler< - T: Clone + Into> + TryFrom, Error = ProtobufConversionError> + 'static, + T: Clone + Into> + TryFrom, Error = ProtobufConversionError>, + StreamId: Into> + + TryFrom, Error = ProtobufConversionError> + + Eq + + Hash + + Clone + + Unpin + + Display + + Debug, > { // For each stream ID from the network, send the application a Receiver // that will receive the messages in order. This allows sending such Receivers. inbound_channel_sender: mpsc::Sender>, // This receives messages from the network. - inbound_receiver: BroadcastTopicServer>, + inbound_receiver: BroadcastTopicServer>, // A map from (peer_id, stream_id) to a struct that contains all the information // about the stream. This includes both the message buffer and some metadata // (like the latest message ID). - inbound_stream_data: HashMap>, + inbound_stream_data: HashMap<(PeerId, StreamId), StreamData>, // Whenever application wants to start a new stream, it must send out a // (stream_id, Receiver) pair. Each receiver gets messages that should // be sent out to the network. @@ -83,20 +110,30 @@ pub struct StreamHandler< // A map where the abovementioned Receivers are stored. outbound_stream_receivers: StreamHashMap>, // A network sender that allows sending StreamMessages to peers. - outbound_sender: BroadcastTopicClient>, + outbound_sender: BroadcastTopicClient>, // For each stream, keep track of the message_id of the last message sent. outbound_stream_number: HashMap, } -impl> + TryFrom, Error = ProtobufConversionError>> - StreamHandler +impl< + T: Clone + Send + Into> + TryFrom, Error = ProtobufConversionError>, + StreamId: Into> + + TryFrom, Error = ProtobufConversionError> + + Eq + + Hash + + Clone + + Unpin + + Send + + Display + + Debug, +> StreamHandler { /// Create a new StreamHandler. pub fn new( inbound_channel_sender: mpsc::Sender>, - inbound_receiver: BroadcastTopicServer>, + inbound_receiver: BroadcastTopicServer>, outbound_channel_receiver: mpsc::Receiver<(StreamId, mpsc::Receiver)>, - outbound_sender: BroadcastTopicClient>, + outbound_sender: BroadcastTopicClient>, ) -> Self { Self { inbound_channel_sender, @@ -113,13 +150,17 @@ impl> + TryFrom, Error = ProtobufConversi /// Gets network input/output channels and returns application input/output channels. #[allow(clippy::type_complexity)] pub fn get_channels( - inbound_network_receiver: BroadcastTopicServer>, - outbound_network_sender: BroadcastTopicClient>, + inbound_network_receiver: BroadcastTopicServer>, + outbound_network_sender: BroadcastTopicClient>, ) -> ( mpsc::Sender<(StreamId, mpsc::Receiver)>, mpsc::Receiver>, tokio::task::JoinHandle<()>, - ) { + ) + where + T: 'static, + StreamId: 'static, + { // The inbound messages come into StreamHandler via inbound_network_receiver. // The application gets the messages from inbound_internal_receiver // (the StreamHandler keeps the inbound_internal_sender to pass the messages). @@ -136,7 +177,7 @@ impl> + TryFrom, Error = ProtobufConversi mpsc::Receiver<(StreamId, mpsc::Receiver)>, ) = mpsc::channel(CHANNEL_BUFFER_LENGTH); - let mut stream_handler = StreamHandler::::new( + let mut stream_handler = StreamHandler::::new( inbound_internal_sender, // Sender>, inbound_network_receiver, // BroadcastTopicServer>, outbound_internal_receiver, // Receiver<(StreamId, Receiver)>, @@ -186,8 +227,11 @@ impl> + TryFrom, Error = ProtobufConversi } } - // Returns true if the receiver for this stream is dropped. - fn inbound_send(&mut self, data: &mut StreamData, message: StreamMessage) -> bool { + fn inbound_send( + &mut self, + data: &mut StreamData, + message: StreamMessage, + ) -> bool { // TODO(guyn): reconsider the "expect" here. let sender = &mut data.sender; if let StreamMessageBody::Content(content) = message.message { @@ -231,20 +275,22 @@ impl> + TryFrom, Error = ProtobufConversi async fn broadcast(&mut self, stream_id: StreamId, message: T) { let message = StreamMessage { message: StreamMessageBody::Content(message), - stream_id, + stream_id: stream_id.clone(), message_id: *self.outbound_stream_number.get(&stream_id).unwrap_or(&0), }; // TODO(guyn): reconsider the "expect" here. self.outbound_sender.broadcast_message(message).await.expect("Send should succeed"); - self.outbound_stream_number - .insert(stream_id, self.outbound_stream_number.get(&stream_id).unwrap_or(&0) + 1); + self.outbound_stream_number.insert( + stream_id.clone(), + self.outbound_stream_number.get(&stream_id).unwrap_or(&0) + 1, + ); } // Send a fin message to the network. async fn broadcast_fin(&mut self, stream_id: StreamId) { let message = StreamMessage { message: StreamMessageBody::Fin, - stream_id, + stream_id: stream_id.clone(), message_id: *self.outbound_stream_number.get(&stream_id).unwrap_or(&0), }; self.outbound_sender.broadcast_message(message).await.expect("Send should succeed"); @@ -255,7 +301,10 @@ impl> + TryFrom, Error = ProtobufConversi #[instrument(skip_all, level = "warn")] fn handle_message( &mut self, - message: (Result, ProtobufConversionError>, BroadcastedMessageMetadata), + message: ( + Result, ProtobufConversionError>, + BroadcastedMessageMetadata, + ), ) { let (message, metadata) = message; let message = match message { @@ -267,7 +316,7 @@ impl> + TryFrom, Error = ProtobufConversi }; let peer_id = metadata.originator_id.clone(); - let stream_id = message.stream_id; + let stream_id = message.stream_id.clone(); let key = (peer_id, stream_id); let data = match self.inbound_stream_data.entry(key.clone()) { @@ -288,12 +337,12 @@ impl> + TryFrom, Error = ProtobufConversi /// should be dropped. fn handle_message_inner( &mut self, - message: StreamMessage, + message: StreamMessage, metadata: BroadcastedMessageMetadata, - mut data: StreamData, - ) -> Option> { + mut data: StreamData, + ) -> Option> { let peer_id = metadata.originator_id; - let stream_id = message.stream_id; + let stream_id = message.stream_id.clone(); let key = (peer_id, stream_id); let message_id = message.message_id; @@ -366,7 +415,11 @@ impl> + TryFrom, Error = ProtobufConversi } // Store an inbound message in the buffer. - fn store(data: &mut StreamData, key: StreamKey, message: StreamMessage) { + fn store( + data: &mut StreamData, + key: (PeerId, StreamId), + message: StreamMessage, + ) { let message_id = message.message_id; match data.message_buffer.entry(message_id) { @@ -385,8 +438,7 @@ impl> + TryFrom, Error = ProtobufConversi // Tries to drain as many messages as possible from the buffer (in order), // DOES NOT guarantee that the buffer will be empty after calling this function. - // Returns true if the receiver for this stream is dropped. - fn process_buffer(&mut self, data: &mut StreamData) -> bool { + fn process_buffer(&mut self, data: &mut StreamData) -> bool { while let Some(message) = data.message_buffer.remove(&data.next_message_id) { if self.inbound_send(data, message) { return true; diff --git a/crates/sequencing/papyrus_consensus/src/stream_handler_test.rs b/crates/sequencing/papyrus_consensus/src/stream_handler_test.rs index 2f00140462c..544cca07e3f 100644 --- a/crates/sequencing/papyrus_consensus/src/stream_handler_test.rs +++ b/crates/sequencing/papyrus_consensus/src/stream_handler_test.rs @@ -1,3 +1,4 @@ +use std::fmt::Display; use std::time::Duration; use futures::channel::mpsc; @@ -11,13 +12,54 @@ use papyrus_network::network_manager::test_utils::{ use papyrus_network::network_manager::BroadcastTopicChannels; use papyrus_network_types::network_types::BroadcastedMessageMetadata; use papyrus_protobuf::consensus::{StreamMessage, StreamMessageBody}; +use papyrus_protobuf::converters::ProtobufConversionError; use papyrus_test_utils::{get_rng, GetTestInstance}; +use prost::DecodeError; -use super::{MessageId, StreamHandler, StreamId}; +use super::{MessageId, StreamHandler}; const TIMEOUT: Duration = Duration::from_millis(100); const CHANNEL_SIZE: usize = 100; +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +struct TestStreamId(u64); + +impl From for Vec { + fn from(value: TestStreamId) -> Self { + value.0.to_be_bytes().to_vec() + } +} + +impl TryFrom> for TestStreamId { + type Error = ProtobufConversionError; + fn try_from(bytes: Vec) -> Result { + if bytes.len() != 8 { + return Err(ProtobufConversionError::DecodeError(DecodeError::new("Invalid length"))); + } + let mut array = [0; 8]; + array.copy_from_slice(&bytes); + Ok(TestStreamId(u64::from_be_bytes(array))) + } +} + +impl PartialOrd for TestStreamId { + fn partial_cmp(&self, other: &Self) -> Option { + self.0.partial_cmp(&other.0) + } +} + +impl Ord for TestStreamId { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.0.cmp(&other.0) + } +} + +impl Display for TestStreamId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "TestStreamId({})", self.0) + } +} + #[cfg(test)] mod tests { use papyrus_protobuf::consensus::{ProposalInit, ProposalPart}; @@ -25,10 +67,10 @@ mod tests { use super::*; fn make_test_message( - stream_id: StreamId, + stream_id: TestStreamId, message_id: MessageId, fin: bool, - ) -> StreamMessage { + ) -> StreamMessage { let content = match fin { true => StreamMessageBody::Fin, false => StreamMessageBody::Content(ProposalPart::Init(ProposalInit::default())), @@ -49,21 +91,24 @@ mod tests { } async fn send( - sender: &mut MockBroadcastedMessagesSender>, + sender: &mut MockBroadcastedMessagesSender>, metadata: &BroadcastedMessageMetadata, - msg: StreamMessage, + msg: StreamMessage, ) { sender.send((msg, metadata.clone())).await.unwrap(); } #[allow(clippy::type_complexity)] fn setup_test() -> ( - StreamHandler, - MockBroadcastedMessagesSender>, + StreamHandler, + MockBroadcastedMessagesSender>, mpsc::Receiver>, BroadcastedMessageMetadata, - mpsc::Sender<(StreamId, mpsc::Receiver)>, - futures::stream::Map>, fn(Vec) -> StreamMessage>, + mpsc::Sender<(TestStreamId, mpsc::Receiver)>, + futures::stream::Map< + mpsc::Receiver>, + fn(Vec) -> StreamMessage, + >, ) { // The outbound_sender is the network connector for broadcasting messages. // The network_broadcast_receiver is used to catch those messages in the test. @@ -80,7 +125,7 @@ mod tests { // The receiver goes into StreamHandler, sender is used by the test (as mock Consensus). // Note that each new channel comes in a tuple with (stream_id, receiver). let (outbound_channel_sender, outbound_channel_receiver) = - mpsc::channel::<(StreamId, mpsc::Receiver)>(CHANNEL_SIZE); + mpsc::channel::<(TestStreamId, mpsc::Receiver)>(CHANNEL_SIZE); // The network_sender_to_inbound is the sender of the mock network, that is used by the // test to send messages into the StreamHandler (from the mock network). @@ -126,7 +171,7 @@ mod tests { let (mut stream_handler, mut network_sender, mut inbound_channel_receiver, metadata, _, _) = setup_test(); - let stream_id = 127; + let stream_id = TestStreamId(127); for i in 0..10 { let message = make_test_message(stream_id, i, i == 9); send(&mut network_sender, &metadata, message).await; @@ -158,7 +203,7 @@ mod tests { _, ) = setup_test(); let peer_id = inbound_metadata.originator_id.clone(); - let stream_id = 127; + let stream_id = TestStreamId(127); for i in 0..5 { let message = make_test_message(stream_id, 5 - i, i == 0); @@ -233,9 +278,9 @@ mod tests { ) = setup_test(); let peer_id = inbound_metadata.originator_id.clone(); - let stream_id1 = 127; // Send all messages in order (except the first one). - let stream_id2 = 10; // Send in reverse order (except the first one). - let stream_id3 = 1; // Send in two batches, without the first one, don't send fin. + let stream_id1 = TestStreamId(127); // Send all messages in order (except the first one). + let stream_id2 = TestStreamId(10); // Send in reverse order (except the first one). + let stream_id3 = TestStreamId(1); // Send in two batches, without the first one, don't send fin. for i in 1..10 { let message = make_test_message(stream_id1, i, i == 9); @@ -264,8 +309,14 @@ mod tests { }); let mut stream_handler = join_handle.await.expect("Task should succeed"); - let values = [(peer_id.clone(), 1), (peer_id.clone(), 10), (peer_id.clone(), 127)]; - assert!(stream_handler.inbound_stream_data.keys().all(|item| values.contains(item))); + let values = [ + (peer_id.clone(), TestStreamId(1)), + (peer_id.clone(), TestStreamId(10)), + (peer_id.clone(), TestStreamId(127)), + ]; + assert!( + stream_handler.inbound_stream_data.keys().to_owned().all(|item| values.contains(item)) + ); // We have all message from 1 to 9 buffered. assert!(do_vecs_match_unordered( @@ -320,8 +371,10 @@ mod tests { let mut stream_handler = join_handle.await.expect("Task should succeed"); // stream_id1 should be gone - let values = [(peer_id.clone(), 1), (peer_id.clone(), 10)]; - assert!(stream_handler.inbound_stream_data.keys().all(|item| values.contains(item))); + let values = [(peer_id.clone(), TestStreamId(1)), (peer_id.clone(), TestStreamId(10))]; + assert!( + stream_handler.inbound_stream_data.keys().to_owned().all(|item| values.contains(item)) + ); // Send the last message on stream_id2: send(&mut network_sender, &inbound_metadata, make_test_message(stream_id2, 0, false)).await; @@ -344,8 +397,10 @@ mod tests { let mut stream_handler = join_handle.await.expect("Task should succeed"); // Stream_id2 should also be gone. - let values = [(peer_id.clone(), 1)]; - assert!(stream_handler.inbound_stream_data.keys().all(|item| values.contains(item))); + let values = [(peer_id.clone(), TestStreamId(1))]; + assert!( + stream_handler.inbound_stream_data.keys().to_owned().all(|item| values.contains(item)) + ); // Send the last message on stream_id3: send(&mut network_sender, &inbound_metadata, make_test_message(stream_id3, 0, false)).await; @@ -366,8 +421,10 @@ mod tests { } // Stream_id3 should still be there, because we didn't send a fin. - let values = [(peer_id.clone(), 1)]; - assert!(stream_handler.inbound_stream_data.keys().all(|item| values.contains(item))); + let values = [(peer_id.clone(), TestStreamId(1))]; + assert!( + stream_handler.inbound_stream_data.keys().to_owned().all(|item| values.contains(item)) + ); // But the buffer should be empty, as we've successfully drained it all. assert!( @@ -380,7 +437,7 @@ mod tests { let (mut stream_handler, mut network_sender, mut inbound_channel_receiver, metadata, _, _) = setup_test(); - let stream_id = 127; + let stream_id = TestStreamId(127); // Send two messages, no Fin. for i in 0..2 { let message = make_test_message(stream_id, i, false); @@ -442,8 +499,8 @@ mod tests { mut broadcasted_messages_receiver, ) = setup_test(); - let stream_id1: StreamId = 42; - let stream_id2: StreamId = 127; + let stream_id1 = TestStreamId(42); + let stream_id2 = TestStreamId(127); // Start a new stream by sending the (stream_id, receiver). let (mut sender1, receiver1) = mpsc::channel(CHANNEL_SIZE); @@ -470,7 +527,7 @@ mod tests { // Check that internally, stream_handler holds this receiver. assert_eq!( - stream_handler.outbound_stream_receivers.keys().collect::>(), + stream_handler.outbound_stream_receivers.keys().collect::>(), vec![&stream_id1] ); // Check that the number of messages sent on this stream is 1. @@ -520,7 +577,8 @@ mod tests { assert_eq!(broadcasted_message.message, StreamMessageBody::Content(message3)); assert_eq!(broadcasted_message.stream_id, stream_id2); assert_eq!(broadcasted_message.message_id, 0); - let mut vec1 = stream_handler.outbound_stream_receivers.keys().collect::>(); + let mut vec1 = + stream_handler.outbound_stream_receivers.keys().collect::>(); vec1.sort(); let mut vec2 = vec![&stream_id1, &stream_id2]; vec2.sort(); @@ -544,7 +602,7 @@ mod tests { // Check that the information about this stream is gone. assert_eq!( - stream_handler.outbound_stream_receivers.keys().collect::>(), + stream_handler.outbound_stream_receivers.keys().collect::>(), vec![&stream_id2] ); } diff --git a/crates/sequencing/papyrus_consensus_orchestrator/src/papyrus_consensus_context.rs b/crates/sequencing/papyrus_consensus_orchestrator/src/papyrus_consensus_context.rs index b7cab998750..199438c24ad 100644 --- a/crates/sequencing/papyrus_consensus_orchestrator/src/papyrus_consensus_context.rs +++ b/crates/sequencing/papyrus_consensus_orchestrator/src/papyrus_consensus_context.rs @@ -24,6 +24,7 @@ use papyrus_consensus::types::{ }; use papyrus_network::network_manager::{BroadcastTopicClient, BroadcastTopicClientTrait}; use papyrus_protobuf::consensus::{ + HeightAndRound, ProposalFin, ProposalInit, ProposalPart, @@ -47,7 +48,7 @@ const CHANNEL_SIZE: usize = 100; pub struct PapyrusConsensusContext { storage_reader: StorageReader, network_broadcast_client: BroadcastTopicClient, - network_proposal_sender: mpsc::Sender<(u64, mpsc::Receiver)>, + network_proposal_sender: mpsc::Sender<(HeightAndRound, mpsc::Receiver)>, validators: Vec, sync_broadcast_sender: Option>, // Proposal building/validating returns immediately, leaving the actual processing to a spawned @@ -60,7 +61,7 @@ impl PapyrusConsensusContext { pub fn new( storage_reader: StorageReader, network_broadcast_client: BroadcastTopicClient, - network_proposal_sender: mpsc::Sender<(u64, mpsc::Receiver)>, + network_proposal_sender: mpsc::Sender<(HeightAndRound, mpsc::Receiver)>, num_validators: u64, sync_broadcast_sender: Option>, ) -> Self { @@ -117,7 +118,7 @@ impl ConsensusContext for PapyrusConsensusContext { .block_hash; let (mut proposal_sender, proposal_receiver) = mpsc::channel(CHANNEL_SIZE); - let stream_id = height.0; + let stream_id = HeightAndRound(proposal_init.height.0, proposal_init.round); network_proposal_sender .send((stream_id, proposal_receiver)) .await @@ -248,7 +249,7 @@ impl ConsensusContext for PapyrusConsensusContext { .unwrap_or_else(|| panic!("No proposal found for height {height} and id {id}")) .clone(); - let stream_id = height.0; + let stream_id = HeightAndRound(height.0, init.round); let (mut proposal_sender, proposal_receiver) = mpsc::channel(CHANNEL_SIZE); self.network_proposal_sender .send((stream_id, proposal_receiver)) diff --git a/crates/sequencing/papyrus_consensus_orchestrator/src/papyrus_consensus_context_test.rs b/crates/sequencing/papyrus_consensus_orchestrator/src/papyrus_consensus_context_test.rs index a22d92b0edc..0bdb8ccc20d 100644 --- a/crates/sequencing/papyrus_consensus_orchestrator/src/papyrus_consensus_context_test.rs +++ b/crates/sequencing/papyrus_consensus_orchestrator/src/papyrus_consensus_context_test.rs @@ -11,6 +11,7 @@ use papyrus_network::network_manager::test_utils::{ }; use papyrus_network::network_manager::BroadcastTopicChannels; use papyrus_protobuf::consensus::{ + HeightAndRound, ProposalFin, ProposalInit, ProposalPart, @@ -128,8 +129,9 @@ fn test_setup() .unwrap(); let network_channels = mock_register_broadcast_topic().unwrap(); - let network_proposal_channels: TestSubscriberChannels> = - mock_register_broadcast_topic().unwrap(); + let network_proposal_channels: TestSubscriberChannels< + StreamMessage, + > = mock_register_broadcast_topic().unwrap(); let BroadcastTopicChannels { broadcasted_messages_receiver: inbound_network_receiver, broadcast_topic_client: outbound_network_sender, diff --git a/crates/sequencing/papyrus_consensus_orchestrator/src/sequencer_consensus_context.rs b/crates/sequencing/papyrus_consensus_orchestrator/src/sequencer_consensus_context.rs index f9e07c10fdc..c75eec7ec52 100644 --- a/crates/sequencing/papyrus_consensus_orchestrator/src/sequencer_consensus_context.rs +++ b/crates/sequencing/papyrus_consensus_orchestrator/src/sequencer_consensus_context.rs @@ -21,6 +21,7 @@ use papyrus_consensus::types::{ }; use papyrus_network::network_manager::{BroadcastTopicClient, BroadcastTopicClientTrait}; use papyrus_protobuf::consensus::{ + HeightAndRound, ProposalFin, ProposalInit, ProposalPart, @@ -116,7 +117,7 @@ pub struct SequencerConsensusContext { // Stores proposals for future rounds until the round is reached. queued_proposals: BTreeMap)>, - outbound_proposal_sender: mpsc::Sender<(u64, mpsc::Receiver)>, + outbound_proposal_sender: mpsc::Sender<(HeightAndRound, mpsc::Receiver)>, // Used to broadcast votes to other consensus nodes. vote_broadcast_client: BroadcastTopicClient, // Used to convert Transaction to ExecutableTransaction. @@ -127,7 +128,7 @@ pub struct SequencerConsensusContext { impl SequencerConsensusContext { pub fn new( batcher: Arc, - outbound_proposal_sender: mpsc::Sender<(u64, mpsc::Receiver)>, + outbound_proposal_sender: mpsc::Sender<(HeightAndRound, mpsc::Receiver)>, vote_broadcast_client: BroadcastTopicClient, num_validators: u64, chain_id: ChainId, @@ -209,7 +210,7 @@ impl ConsensusContext for SequencerConsensusContext { .expect("Failed to initiate proposal build"); debug!("Broadcasting proposal init: {proposal_init:?}"); let (mut proposal_sender, proposal_receiver) = mpsc::channel(CHANNEL_SIZE); - let stream_id = proposal_init.height.0; + let stream_id = HeightAndRound(proposal_init.height.0, proposal_init.round); self.outbound_proposal_sender .send((stream_id, proposal_receiver)) .await diff --git a/crates/sequencing/papyrus_consensus_orchestrator/src/sequencer_consensus_context_test.rs b/crates/sequencing/papyrus_consensus_orchestrator/src/sequencer_consensus_context_test.rs index b72fbe3900b..6603bdabda8 100644 --- a/crates/sequencing/papyrus_consensus_orchestrator/src/sequencer_consensus_context_test.rs +++ b/crates/sequencing/papyrus_consensus_orchestrator/src/sequencer_consensus_context_test.rs @@ -14,6 +14,7 @@ use papyrus_network::network_manager::test_utils::{ }; use papyrus_network::network_manager::BroadcastTopicChannels; use papyrus_protobuf::consensus::{ + HeightAndRound, ProposalFin, ProposalInit, ProposalPart, @@ -68,7 +69,7 @@ fn generate_invoke_tx(nonce: u8) -> Transaction { // Structs which aren't utilized but should not be dropped. struct NetworkDependencies { _vote_network: BroadcastNetworkMock, - _new_proposal_network: BroadcastNetworkMock>, + _new_proposal_network: BroadcastNetworkMock>, } fn setup( diff --git a/crates/starknet_consensus_manager/src/consensus_manager.rs b/crates/starknet_consensus_manager/src/consensus_manager.rs index b5fcf4bdec6..34520eb46f9 100644 --- a/crates/starknet_consensus_manager/src/consensus_manager.rs +++ b/crates/starknet_consensus_manager/src/consensus_manager.rs @@ -8,7 +8,7 @@ use papyrus_consensus_orchestrator::cende::CendeAmbassador; use papyrus_consensus_orchestrator::sequencer_consensus_context::SequencerConsensusContext; use papyrus_network::gossipsub_impl::Topic; use papyrus_network::network_manager::{BroadcastTopicChannels, NetworkManager}; -use papyrus_protobuf::consensus::{ProposalPart, StreamMessage, Vote}; +use papyrus_protobuf::consensus::{HeightAndRound, ProposalPart, StreamMessage, Vote}; use starknet_api::block::BlockNumber; use starknet_batcher_types::communication::SharedBatcherClient; use starknet_sequencer_infra::component_definitions::ComponentStarter; @@ -44,7 +44,7 @@ impl ConsensusManager { NetworkManager::new(self.config.consensus_config.network_config.clone(), None); let proposals_broadcast_channels = network_manager - .register_broadcast_topic::>( + .register_broadcast_topic::>( Topic::new(CONSENSUS_PROPOSALS_TOPIC), BROADCAST_BUFFER_SIZE, ) diff --git a/crates/starknet_integration_tests/src/flow_test_setup.rs b/crates/starknet_integration_tests/src/flow_test_setup.rs index e3564b551a0..19f3dc7f571 100644 --- a/crates/starknet_integration_tests/src/flow_test_setup.rs +++ b/crates/starknet_integration_tests/src/flow_test_setup.rs @@ -3,7 +3,7 @@ use std::net::SocketAddr; use blockifier::context::ChainInfo; use mempool_test_utils::starknet_api_test_utils::{Contract, MultiAccountTransactionGenerator}; use papyrus_network::network_manager::BroadcastTopicChannels; -use papyrus_protobuf::consensus::{ProposalPart, StreamMessage}; +use papyrus_protobuf::consensus::{HeightAndRound, ProposalPart, StreamMessage}; use starknet_api::rpc_transaction::RpcTransaction; use starknet_api::transaction::TransactionHash; use starknet_consensus_manager::config::ConsensusManagerConfig; @@ -37,7 +37,8 @@ pub struct FlowTestSetup { pub sequencer_1: FlowSequencerSetup, // Channels for consensus proposals, used for asserting the right transactions are proposed. - pub consensus_proposals_channels: BroadcastTopicChannels>, + pub consensus_proposals_channels: + BroadcastTopicChannels>, } impl FlowTestSetup { diff --git a/crates/starknet_integration_tests/src/integration_test_setup.rs b/crates/starknet_integration_tests/src/integration_test_setup.rs index b0378fc111a..06264be46ac 100644 --- a/crates/starknet_integration_tests/src/integration_test_setup.rs +++ b/crates/starknet_integration_tests/src/integration_test_setup.rs @@ -4,7 +4,7 @@ use std::path::PathBuf; use blockifier::context::ChainInfo; use mempool_test_utils::starknet_api_test_utils::{Contract, MultiAccountTransactionGenerator}; use papyrus_network::network_manager::BroadcastTopicChannels; -use papyrus_protobuf::consensus::{ProposalPart, StreamMessage}; +use papyrus_protobuf::consensus::{HeightAndRound, ProposalPart, StreamMessage}; use papyrus_storage::{StorageConfig, StorageReader}; use starknet_api::rpc_transaction::RpcTransaction; use starknet_api::transaction::TransactionHash; @@ -36,7 +36,8 @@ pub struct IntegrationTestSetup { // TODO: To validate test results instead of reading storage - delete this and use monitoring // or use this. // Channels for consensus proposals, used for validating test results. - pub consensus_proposals_channels: BroadcastTopicChannels>, + pub consensus_proposals_channels: + BroadcastTopicChannels>, } impl IntegrationTestSetup { diff --git a/crates/starknet_integration_tests/src/utils.rs b/crates/starknet_integration_tests/src/utils.rs index f05a6793e3b..2b337e1f975 100644 --- a/crates/starknet_integration_tests/src/utils.rs +++ b/crates/starknet_integration_tests/src/utils.rs @@ -13,7 +13,7 @@ use papyrus_network::network_manager::test_utils::{ create_network_configs_connected_to_broadcast_channels, }; use papyrus_network::network_manager::BroadcastTopicChannels; -use papyrus_protobuf::consensus::{ProposalPart, StreamMessage}; +use papyrus_protobuf::consensus::{HeightAndRound, ProposalPart, StreamMessage}; use papyrus_storage::StorageConfig; use starknet_api::block::BlockNumber; use starknet_api::core::ChainId; @@ -96,7 +96,10 @@ pub async fn create_config( pub fn create_consensus_manager_configs_and_channels( n_managers: usize, available_ports: &mut AvailablePorts, -) -> (Vec, BroadcastTopicChannels>) { +) -> ( + Vec, + BroadcastTopicChannels>, +) { let (network_configs, broadcast_channels) = create_network_configs_connected_to_broadcast_channels( n_managers, diff --git a/crates/starknet_integration_tests/tests/end_to_end_flow_test.rs b/crates/starknet_integration_tests/tests/end_to_end_flow_test.rs index 6b350485c34..96476636a96 100644 --- a/crates/starknet_integration_tests/tests/end_to_end_flow_test.rs +++ b/crates/starknet_integration_tests/tests/end_to_end_flow_test.rs @@ -5,6 +5,7 @@ use mempool_test_utils::starknet_api_test_utils::MultiAccountTransactionGenerato use papyrus_consensus::types::ValidatorId; use papyrus_network::network_manager::BroadcastTopicChannels; use papyrus_protobuf::consensus::{ + HeightAndRound, ProposalFin, ProposalInit, ProposalPart, @@ -108,7 +109,9 @@ async fn wait_for_sequencer_node(sequencer: &FlowSequencerSetup) { } async fn listen_to_broadcasted_messages( - consensus_proposals_channels: &mut BroadcastTopicChannels>, + consensus_proposals_channels: &mut BroadcastTopicChannels< + StreamMessage, + >, expected_batched_tx_hashes: &[TransactionHash], expected_height: BlockNumber, expected_content_id: Felt,