From 75292420cba0128749488b434082b7944df283a8 Mon Sep 17 00:00:00 2001 From: Guy Nir Date: Sun, 27 Oct 2024 11:35:14 +0200 Subject: [PATCH] feat(consensus): use generic stream id in StreamHandler --- Cargo.lock | 2 + crates/blockifier/cairo_native | 2 +- crates/papyrus_node/src/run.rs | 7 +- crates/papyrus_protobuf/Cargo.toml | 1 + crates/papyrus_protobuf/src/consensus.rs | 48 +++++- .../src/converters/consensus.rs | 33 ++-- .../src/converters/consensus_test.rs | 5 +- .../src/converters/test_instances.rs | 47 +++++- .../src/proto/p2p/proto/consensus.proto | 2 +- .../sequencing/papyrus_consensus/Cargo.toml | 1 + .../papyrus_consensus/src/stream_handler.rs | 154 ++++++++++++------ .../src/stream_handler_test.rs | 116 +++++++++---- .../src/papyrus_consensus_context.rs | 9 +- .../src/papyrus_consensus_context_test.rs | 6 +- .../src/sequencer_consensus_context.rs | 7 +- .../src/sequencer_consensus_context_test.rs | 3 +- .../src/consensus_manager.rs | 4 +- .../src/flow_test_setup.rs | 10 +- .../tests/end_to_end_flow_test.rs | 5 +- 19 files changed, 339 insertions(+), 123 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 28a802c924..8c1f7b638e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7653,6 +7653,7 @@ dependencies = [ "papyrus_protobuf", "papyrus_storage", "papyrus_test_utils", + "prost", "serde", "starknet-types-core", "starknet_api", @@ -7906,6 +7907,7 @@ dependencies = [ name = "papyrus_protobuf" version = "0.0.0" dependencies = [ + "bytes", "indexmap 2.7.0", "lazy_static", "papyrus_common", diff --git a/crates/blockifier/cairo_native b/crates/blockifier/cairo_native index 185e94bce3..76e83965d3 160000 --- a/crates/blockifier/cairo_native +++ b/crates/blockifier/cairo_native @@ -1 +1 @@ -Subproject commit 185e94bce3e380db988f98d96b4ddcb3e6c044bc +Subproject commit 76e83965d3bf1252eb6c68200a3accd5fd1ec004 diff --git a/crates/papyrus_node/src/run.rs b/crates/papyrus_node/src/run.rs index d78b1e3ad2..243c779fae 100644 --- a/crates/papyrus_node/src/run.rs +++ b/crates/papyrus_node/src/run.rs @@ -23,7 +23,7 @@ use papyrus_network::{network_manager, NetworkConfig}; use papyrus_p2p_sync::client::{P2PSyncClient, P2PSyncClientChannels}; use papyrus_p2p_sync::server::{P2PSyncServer, P2PSyncServerChannels}; use papyrus_p2p_sync::{Protocol, BUFFER_SIZE}; -use papyrus_protobuf::consensus::{ProposalPart, StreamMessage}; +use papyrus_protobuf::consensus::{HeightAndRound, ProposalPart, StreamMessage}; #[cfg(feature = "rpc")] use papyrus_rpc::run_server; use papyrus_storage::storage_metrics::update_storage_metrics; @@ -192,8 +192,9 @@ fn spawn_consensus( let network_channels = network_manager .register_broadcast_topic(Topic::new(config.network_topic.clone()), BUFFER_SIZE)?; - let proposal_network_channels: BroadcastTopicChannels> = - network_manager.register_broadcast_topic(Topic::new(NETWORK_TOPIC), BUFFER_SIZE)?; + let proposal_network_channels: BroadcastTopicChannels< + StreamMessage, + > = network_manager.register_broadcast_topic(Topic::new(NETWORK_TOPIC), BUFFER_SIZE)?; let BroadcastTopicChannels { broadcasted_messages_receiver: inbound_network_receiver, broadcast_topic_client: outbound_network_sender, diff --git a/crates/papyrus_protobuf/Cargo.toml b/crates/papyrus_protobuf/Cargo.toml index 5d638b2756..569d0be399 100644 --- a/crates/papyrus_protobuf/Cargo.toml +++ b/crates/papyrus_protobuf/Cargo.toml @@ -9,6 +9,7 @@ license-file.workspace = true testing = ["papyrus_test_utils", "rand", "rand_chacha"] [dependencies] +bytes.workspace = true indexmap.workspace = true lazy_static.workspace = true primitive-types.workspace = true diff --git a/crates/papyrus_protobuf/src/consensus.rs b/crates/papyrus_protobuf/src/consensus.rs index f9c7d51f49..a4c6c54a77 100644 --- a/crates/papyrus_protobuf/src/consensus.rs +++ b/crates/papyrus_protobuf/src/consensus.rs @@ -1,3 +1,7 @@ +use std::fmt::Display; + +use bytes::{Buf, BufMut}; +use prost::DecodeError; use starknet_api::block::{BlockHash, BlockNumber}; use starknet_api::core::ContractAddress; use starknet_api::transaction::Transaction; @@ -26,10 +30,13 @@ pub enum StreamMessageBody { Fin, } -#[derive(Clone, Debug, Eq, Hash, PartialEq)] -pub struct StreamMessage> + TryFrom, Error = ProtobufConversionError>> { +#[derive(Debug, Clone, Hash, Eq, PartialEq)] +pub struct StreamMessage< + T: Into> + TryFrom, Error = ProtobufConversionError>, + StreamId: Into> + Clone, +> { pub message: StreamMessageBody, - pub stream_id: u64, + pub stream_id: StreamId, pub message_id: u64, } @@ -108,7 +115,7 @@ impl From for ProposalPart { } } -impl std::fmt::Display for StreamMessage +impl> + Clone + Display> std::fmt::Display for StreamMessage where T: Clone + Into> + TryFrom, Error = ProtobufConversionError>, { @@ -131,3 +138,36 @@ where } } } + +/// HeighAndRound is a tuple struct used as the StreamId for consensus and context. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct HeightAndRound(pub u64, pub u32); + +impl TryFrom> for HeightAndRound { + type Error = ProtobufConversionError; + + fn try_from(value: Vec) -> Result { + if value.len() != 12 { + return Err(ProtobufConversionError::DecodeError(DecodeError::new("Invalid length"))); + } + let mut bytes = value.as_slice(); + let height = bytes.get_u64(); + let round = bytes.get_u32(); + Ok(HeightAndRound(height, round)) + } +} + +impl From for Vec { + fn from(value: HeightAndRound) -> Vec { + let mut bytes = Vec::with_capacity(12); + bytes.put_u64(value.0); + bytes.put_u32(value.1); + bytes + } +} + +impl std::fmt::Display for HeightAndRound { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "(height: {}, round: {})", self.0, self.1) + } +} diff --git a/crates/papyrus_protobuf/src/converters/consensus.rs b/crates/papyrus_protobuf/src/converters/consensus.rs index 1dad9045b8..1974d70b3a 100644 --- a/crates/papyrus_protobuf/src/converters/consensus.rs +++ b/crates/papyrus_protobuf/src/converters/consensus.rs @@ -1,6 +1,7 @@ #[cfg(test)] #[path = "consensus_test.rs"] mod consensus_test; + use std::convert::{TryFrom, TryInto}; use prost::Message; @@ -79,8 +80,10 @@ impl From for protobuf::Vote { auto_impl_into_and_try_from_vec_u8!(Vote, protobuf::Vote); -impl> + TryFrom, Error = ProtobufConversionError>> - TryFrom for StreamMessage +impl TryFrom for StreamMessage +where + T: Into> + TryFrom, Error = ProtobufConversionError>, + StreamId: Into> + TryFrom, Error = ProtobufConversionError> + Clone, { type Error = ProtobufConversionError; @@ -101,16 +104,18 @@ impl> + TryFrom, Error = ProtobufConversionError>> StreamMessageBody::Fin } }, - stream_id: value.stream_id, + stream_id: value.stream_id.try_into()?, message_id: value.message_id, }) } } -impl> + TryFrom, Error = ProtobufConversionError>> From> - for protobuf::StreamMessage +impl< + T: Into> + TryFrom, Error = ProtobufConversionError>, + StreamId: Into> + TryFrom, Error = ProtobufConversionError> + Clone, +> From> for protobuf::StreamMessage { - fn from(value: StreamMessage) -> Self { + fn from(value: StreamMessage) -> Self { Self { message: match value { StreamMessage { @@ -122,7 +127,7 @@ impl> + TryFrom, Error = ProtobufConversionError>> From< Some(protobuf::stream_message::Message::Fin(protobuf::Fin {})) } }, - stream_id: value.stream_id, + stream_id: value.stream_id.into(), message_id: value.message_id, } } @@ -131,17 +136,21 @@ impl> + TryFrom, Error = ProtobufConversionError>> From< // Can't use auto_impl_into_and_try_from_vec_u8!(StreamMessage, protobuf::StreamMessage); // because it doesn't seem to work with generics. // TODO(guyn): consider expanding the macro to support generics -impl> + TryFrom, Error = ProtobufConversionError>> From> - for Vec +impl< + T: Into> + TryFrom, Error = ProtobufConversionError>, + StreamId: Into> + TryFrom, Error = ProtobufConversionError> + Clone, +> From> for Vec { - fn from(value: StreamMessage) -> Self { + fn from(value: StreamMessage) -> Self { let protobuf_value = ::from(value); protobuf_value.encode_to_vec() } } -impl> + TryFrom, Error = ProtobufConversionError>> TryFrom> - for StreamMessage +impl< + T: Into> + TryFrom, Error = ProtobufConversionError>, + StreamId: Into> + TryFrom, Error = ProtobufConversionError> + Clone, +> TryFrom> for StreamMessage { type Error = ProtobufConversionError; fn try_from(value: Vec) -> Result { diff --git a/crates/papyrus_protobuf/src/converters/consensus_test.rs b/crates/papyrus_protobuf/src/converters/consensus_test.rs index d2c465706d..52fea9b216 100644 --- a/crates/papyrus_protobuf/src/converters/consensus_test.rs +++ b/crates/papyrus_protobuf/src/converters/consensus_test.rs @@ -20,6 +20,7 @@ use crate::consensus::{ TransactionBatch, Vote, }; +use crate::converters::test_instances::TestStreamId; // If all the fields of `AllResources` are 0 upon serialization, // then the deserialized value will be interpreted as the `L1Gas` variant. @@ -50,7 +51,7 @@ fn convert_stream_message_to_vec_u8_and_back() { let mut rng = get_rng(); // Test that we can convert a StreamMessage with a ProposalPart message to bytes and back. - let mut stream_message: StreamMessage = + let mut stream_message: StreamMessage = StreamMessage::get_test_instance(&mut rng); if let StreamMessageBody::Content(ProposalPart::Transactions(proposal)) = @@ -128,7 +129,7 @@ fn convert_proposal_part_to_vec_u8_and_back() { #[test] fn stream_message_display() { let mut rng = get_rng(); - let stream_id = 42; + let stream_id = TestStreamId(42); let message_id = 127; let proposal = ProposalPart::get_test_instance(&mut rng); let proposal_bytes: Vec = proposal.clone().into(); diff --git a/crates/papyrus_protobuf/src/converters/test_instances.rs b/crates/papyrus_protobuf/src/converters/test_instances.rs index be55534a75..38507c17c0 100644 --- a/crates/papyrus_protobuf/src/converters/test_instances.rs +++ b/crates/papyrus_protobuf/src/converters/test_instances.rs @@ -1,9 +1,13 @@ +use std::fmt::Display; + use papyrus_test_utils::{auto_impl_get_test_instance, get_number_of_variants, GetTestInstance}; +use prost::DecodeError; use rand::Rng; use starknet_api::block::{BlockHash, BlockNumber}; use starknet_api::core::ContractAddress; use starknet_api::transaction::Transaction; +use super::ProtobufConversionError; use crate::consensus::{ ProposalFin, ProposalInit, @@ -47,9 +51,48 @@ auto_impl_get_test_instance! { } +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct TestStreamId(pub u64); + +impl From for Vec { + fn from(value: TestStreamId) -> Self { + value.0.to_be_bytes().to_vec() + } +} + +impl TryFrom> for TestStreamId { + type Error = ProtobufConversionError; + fn try_from(bytes: Vec) -> Result { + if bytes.len() != 8 { + return Err(ProtobufConversionError::DecodeError(DecodeError::new("Invalid length"))); + }; + let mut array = [0; 8]; + array.copy_from_slice(&bytes); + Ok(TestStreamId(u64::from_be_bytes(array))) + } +} + +impl PartialOrd for TestStreamId { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for TestStreamId { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.0.cmp(&other.0) + } +} + +impl Display for TestStreamId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "TestStreamId({})", self.0) + } +} + // The auto_impl_get_test_instance macro does not work for StreamMessage because it has // a generic type. TODO(guyn): try to make the macro work with generic types. -impl GetTestInstance for StreamMessage { +impl GetTestInstance for StreamMessage { fn get_test_instance(rng: &mut rand_chacha::ChaCha8Rng) -> Self { let message = if rng.gen_bool(0.5) { StreamMessageBody::Content(ProposalPart::Transactions(TransactionBatch { @@ -58,6 +101,6 @@ impl GetTestInstance for StreamMessage { } else { StreamMessageBody::Fin }; - Self { message, stream_id: 12, message_id: 47 } + Self { message, stream_id: TestStreamId(12), message_id: 47 } } } diff --git a/crates/papyrus_protobuf/src/proto/p2p/proto/consensus.proto b/crates/papyrus_protobuf/src/proto/p2p/proto/consensus.proto index 155a995755..765798744c 100644 --- a/crates/papyrus_protobuf/src/proto/p2p/proto/consensus.proto +++ b/crates/papyrus_protobuf/src/proto/p2p/proto/consensus.proto @@ -24,7 +24,7 @@ message StreamMessage { bytes content = 1; Fin fin = 2; } - uint64 stream_id = 3; + bytes stream_id = 3; uint64 message_id = 4; } diff --git a/crates/sequencing/papyrus_consensus/Cargo.toml b/crates/sequencing/papyrus_consensus/Cargo.toml index fb372e2dd4..d31172b915 100644 --- a/crates/sequencing/papyrus_consensus/Cargo.toml +++ b/crates/sequencing/papyrus_consensus/Cargo.toml @@ -20,6 +20,7 @@ papyrus_config.workspace = true papyrus_network.workspace = true papyrus_network_types.workspace = true papyrus_protobuf.workspace = true +prost.workspace = true serde = { workspace = true, features = ["derive"] } starknet-types-core.workspace = true starknet_api.workspace = true diff --git a/crates/sequencing/papyrus_consensus/src/stream_handler.rs b/crates/sequencing/papyrus_consensus/src/stream_handler.rs index 30283b91bf..982d303492 100644 --- a/crates/sequencing/papyrus_consensus/src/stream_handler.rs +++ b/crates/sequencing/papyrus_consensus/src/stream_handler.rs @@ -3,6 +3,8 @@ use std::cmp::Ordering; use std::collections::hash_map::Entry::{Occupied, Vacant}; use std::collections::HashMap; +use std::fmt::{Debug, Display}; +use std::hash::Hash; use futures::channel::mpsc; use futures::StreamExt; @@ -22,32 +24,65 @@ use tracing::{instrument, warn}; mod stream_handler_test; type PeerId = OpaquePeerId; -type StreamId = u64; type MessageId = u64; -type StreamKey = (PeerId, StreamId); const CHANNEL_BUFFER_LENGTH: usize = 100; +/// A combination of trait bounds needed for the content of the stream. +pub trait StreamContentTrait: + Clone + Into> + TryFrom, Error = ProtobufConversionError> + Send +{ +} +impl StreamContentTrait for StreamContent where + StreamContent: Clone + Into> + TryFrom, Error = ProtobufConversionError> + Send +{ +} +/// A combination of trait bounds needed for the stream ID. +pub trait StreamIdTrait: + Into> + + TryFrom, Error = ProtobufConversionError> + + Eq + + Hash + + Clone + + Unpin + + Display + + Debug + + Send +{ +} +impl StreamIdTrait for StreamId where + StreamId: Into> + + TryFrom, Error = ProtobufConversionError> + + Eq + + Hash + + Clone + + Unpin + + Display + + Debug + + Send +{ +} + // Use this struct for each inbound stream. // Drop the struct when: // (1) receiver on the other end is dropped, // (2) fin message is received and all messages are sent. #[derive(Debug)] -struct StreamData< - T: Clone + Into> + TryFrom, Error = ProtobufConversionError> + 'static, -> { +struct StreamData { next_message_id: MessageId, // Last message ID. If None, it means we have not yet gotten to it. fin_message_id: Option, max_message_id_received: MessageId, // Keep the receiver until it is time to send it to the application. - receiver: Option>, - sender: mpsc::Sender, + receiver: Option>, + sender: mpsc::Sender, // A buffer for messages that were received out of order. - message_buffer: HashMap>, + message_buffer: HashMap>, } -impl> + TryFrom, Error = ProtobufConversionError>> StreamData { +impl + StreamData +{ fn new() -> Self { let (sender, receiver) = mpsc::channel(CHANNEL_BUFFER_LENGTH); StreamData { @@ -64,39 +99,37 @@ impl> + TryFrom, Error = ProtobufConversionError /// A StreamHandler is responsible for: /// - Buffering inbound messages and reporting them to the application in order. /// - Sending outbound messages to the network, wrapped in StreamMessage. -pub struct StreamHandler< - T: Clone + Into> + TryFrom, Error = ProtobufConversionError> + 'static, -> { +pub struct StreamHandler { // For each stream ID from the network, send the application a Receiver // that will receive the messages in order. This allows sending such Receivers. - inbound_channel_sender: mpsc::Sender>, + inbound_channel_sender: mpsc::Sender>, // This receives messages from the network. - inbound_receiver: BroadcastTopicServer>, + inbound_receiver: BroadcastTopicServer>, // A map from (peer_id, stream_id) to a struct that contains all the information // about the stream. This includes both the message buffer and some metadata // (like the latest message ID). - inbound_stream_data: HashMap>, + inbound_stream_data: HashMap<(PeerId, StreamId), StreamData>, // Whenever application wants to start a new stream, it must send out a // (stream_id, Receiver) pair. Each receiver gets messages that should // be sent out to the network. - outbound_channel_receiver: mpsc::Receiver<(StreamId, mpsc::Receiver)>, + outbound_channel_receiver: mpsc::Receiver<(StreamId, mpsc::Receiver)>, // A map where the abovementioned Receivers are stored. - outbound_stream_receivers: StreamHashMap>, + outbound_stream_receivers: StreamHashMap>, // A network sender that allows sending StreamMessages to peers. - outbound_sender: BroadcastTopicClient>, + outbound_sender: BroadcastTopicClient>, // For each stream, keep track of the message_id of the last message sent. outbound_stream_number: HashMap, } -impl> + TryFrom, Error = ProtobufConversionError>> - StreamHandler +impl + StreamHandler { /// Create a new StreamHandler. pub fn new( - inbound_channel_sender: mpsc::Sender>, - inbound_receiver: BroadcastTopicServer>, - outbound_channel_receiver: mpsc::Receiver<(StreamId, mpsc::Receiver)>, - outbound_sender: BroadcastTopicClient>, + inbound_channel_sender: mpsc::Sender>, + inbound_receiver: BroadcastTopicServer>, + outbound_channel_receiver: mpsc::Receiver<(StreamId, mpsc::Receiver)>, + outbound_sender: BroadcastTopicClient>, ) -> Self { Self { inbound_channel_sender, @@ -113,30 +146,30 @@ impl> + TryFrom, Error = ProtobufConversi /// Gets network input/output channels and returns application input/output channels. #[allow(clippy::type_complexity)] pub fn get_channels( - inbound_network_receiver: BroadcastTopicServer>, - outbound_network_sender: BroadcastTopicClient>, + inbound_network_receiver: BroadcastTopicServer>, + outbound_network_sender: BroadcastTopicClient>, ) -> ( - mpsc::Sender<(StreamId, mpsc::Receiver)>, - mpsc::Receiver>, + mpsc::Sender<(StreamId, mpsc::Receiver)>, + mpsc::Receiver>, tokio::task::JoinHandle<()>, - ) { + ) + where + StreamContent: 'static, + StreamId: 'static, + { // The inbound messages come into StreamHandler via inbound_network_receiver. // The application gets the messages from inbound_internal_receiver // (the StreamHandler keeps the inbound_internal_sender to pass the messages). - let (inbound_internal_sender, inbound_internal_receiver): ( - mpsc::Sender>, - mpsc::Receiver>, - ) = mpsc::channel(CHANNEL_BUFFER_LENGTH); + let (inbound_internal_sender, inbound_internal_receiver) = + mpsc::channel(CHANNEL_BUFFER_LENGTH); // The outbound messages that an application would like to send are: // 1. Sent into outbound_internal_sender as tuples of (StreamId, Receiver) // 2. Ingested by StreamHandler by its outbound_internal_receiver. // 3. Broadcast by the StreamHandler using its outbound_network_sender. - let (outbound_internal_sender, outbound_internal_receiver): ( - mpsc::Sender<(StreamId, mpsc::Receiver)>, - mpsc::Receiver<(StreamId, mpsc::Receiver)>, - ) = mpsc::channel(CHANNEL_BUFFER_LENGTH); + let (outbound_internal_sender, outbound_internal_receiver) = + mpsc::channel(CHANNEL_BUFFER_LENGTH); - let mut stream_handler = StreamHandler::::new( + let mut stream_handler = StreamHandler::::new( inbound_internal_sender, // Sender>, inbound_network_receiver, // BroadcastTopicServer>, outbound_internal_receiver, // Receiver<(StreamId, Receiver)>, @@ -186,8 +219,11 @@ impl> + TryFrom, Error = ProtobufConversi } } - // Returns true if the receiver for this stream is dropped. - fn inbound_send(&mut self, data: &mut StreamData, message: StreamMessage) -> bool { + fn inbound_send( + &mut self, + data: &mut StreamData, + message: StreamMessage, + ) -> bool { // TODO(guyn): reconsider the "expect" here. let sender = &mut data.sender; if let StreamMessageBody::Content(content) = message.message { @@ -229,23 +265,28 @@ impl> + TryFrom, Error = ProtobufConversi } // Send the message to the network. - async fn broadcast(&mut self, stream_id: StreamId, message: T) { + async fn broadcast(&mut self, stream_id: StreamId, message: StreamContent) { + // TODO(guyn): add a random nonce to the outbound stream ID, + // such that even if the client sends the same stream ID, + // (e.g., after a crash) this will be treated as a new stream. let message = StreamMessage { message: StreamMessageBody::Content(message), - stream_id, + stream_id: stream_id.clone(), message_id: *self.outbound_stream_number.get(&stream_id).unwrap_or(&0), }; // TODO(guyn): reconsider the "expect" here. self.outbound_sender.broadcast_message(message).await.expect("Send should succeed"); - self.outbound_stream_number - .insert(stream_id, self.outbound_stream_number.get(&stream_id).unwrap_or(&0) + 1); + self.outbound_stream_number.insert( + stream_id.clone(), + self.outbound_stream_number.get(&stream_id).unwrap_or(&0) + 1, + ); } // Send a fin message to the network. async fn broadcast_fin(&mut self, stream_id: StreamId) { let message = StreamMessage { message: StreamMessageBody::Fin, - stream_id, + stream_id: stream_id.clone(), message_id: *self.outbound_stream_number.get(&stream_id).unwrap_or(&0), }; self.outbound_sender.broadcast_message(message).await.expect("Send should succeed"); @@ -256,7 +297,10 @@ impl> + TryFrom, Error = ProtobufConversi #[instrument(skip_all, level = "warn")] fn handle_message( &mut self, - message: (Result, ProtobufConversionError>, BroadcastedMessageMetadata), + message: ( + Result, ProtobufConversionError>, + BroadcastedMessageMetadata, + ), ) { let (message, metadata) = message; let message = match message { @@ -268,7 +312,7 @@ impl> + TryFrom, Error = ProtobufConversi }; let peer_id = metadata.originator_id.clone(); - let stream_id = message.stream_id; + let stream_id = message.stream_id.clone(); let key = (peer_id, stream_id); let data = match self.inbound_stream_data.entry(key.clone()) { @@ -289,12 +333,12 @@ impl> + TryFrom, Error = ProtobufConversi /// should be dropped. fn handle_message_inner( &mut self, - message: StreamMessage, + message: StreamMessage, metadata: BroadcastedMessageMetadata, - mut data: StreamData, - ) -> Option> { + mut data: StreamData, + ) -> Option> { let peer_id = metadata.originator_id; - let stream_id = message.stream_id; + let stream_id = message.stream_id.clone(); let key = (peer_id, stream_id); let message_id = message.message_id; @@ -367,7 +411,11 @@ impl> + TryFrom, Error = ProtobufConversi } // Store an inbound message in the buffer. - fn store(data: &mut StreamData, key: StreamKey, message: StreamMessage) { + fn store( + data: &mut StreamData, + key: (PeerId, StreamId), + message: StreamMessage, + ) { let message_id = message.message_id; match data.message_buffer.entry(message_id) { @@ -387,7 +435,7 @@ impl> + TryFrom, Error = ProtobufConversi // Tries to drain as many messages as possible from the buffer (in order), // DOES NOT guarantee that the buffer will be empty after calling this function. // Returns true if the receiver for this stream is dropped. - fn process_buffer(&mut self, data: &mut StreamData) -> bool { + fn process_buffer(&mut self, data: &mut StreamData) -> bool { while let Some(message) = data.message_buffer.remove(&data.next_message_id) { if self.inbound_send(data, message) { return true; diff --git a/crates/sequencing/papyrus_consensus/src/stream_handler_test.rs b/crates/sequencing/papyrus_consensus/src/stream_handler_test.rs index 2f00140462..89b4f5c157 100644 --- a/crates/sequencing/papyrus_consensus/src/stream_handler_test.rs +++ b/crates/sequencing/papyrus_consensus/src/stream_handler_test.rs @@ -1,3 +1,4 @@ +use std::fmt::Display; use std::time::Duration; use futures::channel::mpsc; @@ -11,13 +12,54 @@ use papyrus_network::network_manager::test_utils::{ use papyrus_network::network_manager::BroadcastTopicChannels; use papyrus_network_types::network_types::BroadcastedMessageMetadata; use papyrus_protobuf::consensus::{StreamMessage, StreamMessageBody}; +use papyrus_protobuf::converters::ProtobufConversionError; use papyrus_test_utils::{get_rng, GetTestInstance}; +use prost::DecodeError; -use super::{MessageId, StreamHandler, StreamId}; +use super::{MessageId, StreamHandler}; const TIMEOUT: Duration = Duration::from_millis(100); const CHANNEL_SIZE: usize = 100; +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +struct TestStreamId(u64); + +impl From for Vec { + fn from(value: TestStreamId) -> Self { + value.0.to_be_bytes().to_vec() + } +} + +impl TryFrom> for TestStreamId { + type Error = ProtobufConversionError; + fn try_from(bytes: Vec) -> Result { + if bytes.len() != 8 { + return Err(ProtobufConversionError::DecodeError(DecodeError::new("Invalid length"))); + } + let mut array = [0; 8]; + array.copy_from_slice(&bytes); + Ok(TestStreamId(u64::from_be_bytes(array))) + } +} + +impl PartialOrd for TestStreamId { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for TestStreamId { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.0.cmp(&other.0) + } +} + +impl Display for TestStreamId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "TestStreamId({})", self.0) + } +} + #[cfg(test)] mod tests { use papyrus_protobuf::consensus::{ProposalInit, ProposalPart}; @@ -25,10 +67,10 @@ mod tests { use super::*; fn make_test_message( - stream_id: StreamId, + stream_id: TestStreamId, message_id: MessageId, fin: bool, - ) -> StreamMessage { + ) -> StreamMessage { let content = match fin { true => StreamMessageBody::Fin, false => StreamMessageBody::Content(ProposalPart::Init(ProposalInit::default())), @@ -49,21 +91,24 @@ mod tests { } async fn send( - sender: &mut MockBroadcastedMessagesSender>, + sender: &mut MockBroadcastedMessagesSender>, metadata: &BroadcastedMessageMetadata, - msg: StreamMessage, + msg: StreamMessage, ) { sender.send((msg, metadata.clone())).await.unwrap(); } #[allow(clippy::type_complexity)] fn setup_test() -> ( - StreamHandler, - MockBroadcastedMessagesSender>, + StreamHandler, + MockBroadcastedMessagesSender>, mpsc::Receiver>, BroadcastedMessageMetadata, - mpsc::Sender<(StreamId, mpsc::Receiver)>, - futures::stream::Map>, fn(Vec) -> StreamMessage>, + mpsc::Sender<(TestStreamId, mpsc::Receiver)>, + futures::stream::Map< + mpsc::Receiver>, + fn(Vec) -> StreamMessage, + >, ) { // The outbound_sender is the network connector for broadcasting messages. // The network_broadcast_receiver is used to catch those messages in the test. @@ -80,7 +125,7 @@ mod tests { // The receiver goes into StreamHandler, sender is used by the test (as mock Consensus). // Note that each new channel comes in a tuple with (stream_id, receiver). let (outbound_channel_sender, outbound_channel_receiver) = - mpsc::channel::<(StreamId, mpsc::Receiver)>(CHANNEL_SIZE); + mpsc::channel::<(TestStreamId, mpsc::Receiver)>(CHANNEL_SIZE); // The network_sender_to_inbound is the sender of the mock network, that is used by the // test to send messages into the StreamHandler (from the mock network). @@ -126,7 +171,7 @@ mod tests { let (mut stream_handler, mut network_sender, mut inbound_channel_receiver, metadata, _, _) = setup_test(); - let stream_id = 127; + let stream_id = TestStreamId(127); for i in 0..10 { let message = make_test_message(stream_id, i, i == 9); send(&mut network_sender, &metadata, message).await; @@ -158,7 +203,7 @@ mod tests { _, ) = setup_test(); let peer_id = inbound_metadata.originator_id.clone(); - let stream_id = 127; + let stream_id = TestStreamId(127); for i in 0..5 { let message = make_test_message(stream_id, 5 - i, i == 0); @@ -233,9 +278,9 @@ mod tests { ) = setup_test(); let peer_id = inbound_metadata.originator_id.clone(); - let stream_id1 = 127; // Send all messages in order (except the first one). - let stream_id2 = 10; // Send in reverse order (except the first one). - let stream_id3 = 1; // Send in two batches, without the first one, don't send fin. + let stream_id1 = TestStreamId(127); // Send all messages in order (except the first one). + let stream_id2 = TestStreamId(10); // Send in reverse order (except the first one). + let stream_id3 = TestStreamId(1); // Send in two batches, without the first one, don't send fin. for i in 1..10 { let message = make_test_message(stream_id1, i, i == 9); @@ -264,8 +309,14 @@ mod tests { }); let mut stream_handler = join_handle.await.expect("Task should succeed"); - let values = [(peer_id.clone(), 1), (peer_id.clone(), 10), (peer_id.clone(), 127)]; - assert!(stream_handler.inbound_stream_data.keys().all(|item| values.contains(item))); + let values = [ + (peer_id.clone(), TestStreamId(1)), + (peer_id.clone(), TestStreamId(10)), + (peer_id.clone(), TestStreamId(127)), + ]; + assert!( + stream_handler.inbound_stream_data.keys().to_owned().all(|item| values.contains(item)) + ); // We have all message from 1 to 9 buffered. assert!(do_vecs_match_unordered( @@ -320,8 +371,10 @@ mod tests { let mut stream_handler = join_handle.await.expect("Task should succeed"); // stream_id1 should be gone - let values = [(peer_id.clone(), 1), (peer_id.clone(), 10)]; - assert!(stream_handler.inbound_stream_data.keys().all(|item| values.contains(item))); + let values = [(peer_id.clone(), TestStreamId(1)), (peer_id.clone(), TestStreamId(10))]; + assert!( + stream_handler.inbound_stream_data.keys().to_owned().all(|item| values.contains(item)) + ); // Send the last message on stream_id2: send(&mut network_sender, &inbound_metadata, make_test_message(stream_id2, 0, false)).await; @@ -344,8 +397,10 @@ mod tests { let mut stream_handler = join_handle.await.expect("Task should succeed"); // Stream_id2 should also be gone. - let values = [(peer_id.clone(), 1)]; - assert!(stream_handler.inbound_stream_data.keys().all(|item| values.contains(item))); + let values = [(peer_id.clone(), TestStreamId(1))]; + assert!( + stream_handler.inbound_stream_data.keys().to_owned().all(|item| values.contains(item)) + ); // Send the last message on stream_id3: send(&mut network_sender, &inbound_metadata, make_test_message(stream_id3, 0, false)).await; @@ -366,8 +421,10 @@ mod tests { } // Stream_id3 should still be there, because we didn't send a fin. - let values = [(peer_id.clone(), 1)]; - assert!(stream_handler.inbound_stream_data.keys().all(|item| values.contains(item))); + let values = [(peer_id.clone(), TestStreamId(1))]; + assert!( + stream_handler.inbound_stream_data.keys().to_owned().all(|item| values.contains(item)) + ); // But the buffer should be empty, as we've successfully drained it all. assert!( @@ -380,7 +437,7 @@ mod tests { let (mut stream_handler, mut network_sender, mut inbound_channel_receiver, metadata, _, _) = setup_test(); - let stream_id = 127; + let stream_id = TestStreamId(127); // Send two messages, no Fin. for i in 0..2 { let message = make_test_message(stream_id, i, false); @@ -442,8 +499,8 @@ mod tests { mut broadcasted_messages_receiver, ) = setup_test(); - let stream_id1: StreamId = 42; - let stream_id2: StreamId = 127; + let stream_id1 = TestStreamId(42); + let stream_id2 = TestStreamId(127); // Start a new stream by sending the (stream_id, receiver). let (mut sender1, receiver1) = mpsc::channel(CHANNEL_SIZE); @@ -470,7 +527,7 @@ mod tests { // Check that internally, stream_handler holds this receiver. assert_eq!( - stream_handler.outbound_stream_receivers.keys().collect::>(), + stream_handler.outbound_stream_receivers.keys().collect::>(), vec![&stream_id1] ); // Check that the number of messages sent on this stream is 1. @@ -520,7 +577,8 @@ mod tests { assert_eq!(broadcasted_message.message, StreamMessageBody::Content(message3)); assert_eq!(broadcasted_message.stream_id, stream_id2); assert_eq!(broadcasted_message.message_id, 0); - let mut vec1 = stream_handler.outbound_stream_receivers.keys().collect::>(); + let mut vec1 = + stream_handler.outbound_stream_receivers.keys().collect::>(); vec1.sort(); let mut vec2 = vec![&stream_id1, &stream_id2]; vec2.sort(); @@ -544,7 +602,7 @@ mod tests { // Check that the information about this stream is gone. assert_eq!( - stream_handler.outbound_stream_receivers.keys().collect::>(), + stream_handler.outbound_stream_receivers.keys().collect::>(), vec![&stream_id2] ); } diff --git a/crates/sequencing/papyrus_consensus_orchestrator/src/papyrus_consensus_context.rs b/crates/sequencing/papyrus_consensus_orchestrator/src/papyrus_consensus_context.rs index 3bff3f3758..2e0c127b3a 100644 --- a/crates/sequencing/papyrus_consensus_orchestrator/src/papyrus_consensus_context.rs +++ b/crates/sequencing/papyrus_consensus_orchestrator/src/papyrus_consensus_context.rs @@ -24,6 +24,7 @@ use papyrus_consensus::types::{ }; use papyrus_network::network_manager::{BroadcastTopicClient, BroadcastTopicClientTrait}; use papyrus_protobuf::consensus::{ + HeightAndRound, ProposalFin, ProposalInit, ProposalPart, @@ -47,7 +48,7 @@ const CHANNEL_SIZE: usize = 100; pub struct PapyrusConsensusContext { storage_reader: StorageReader, network_broadcast_client: BroadcastTopicClient, - network_proposal_sender: mpsc::Sender<(u64, mpsc::Receiver)>, + network_proposal_sender: mpsc::Sender<(HeightAndRound, mpsc::Receiver)>, validators: Vec, sync_broadcast_sender: Option>, // Proposal building/validating returns immediately, leaving the actual processing to a spawned @@ -60,7 +61,7 @@ impl PapyrusConsensusContext { pub fn new( storage_reader: StorageReader, network_broadcast_client: BroadcastTopicClient, - network_proposal_sender: mpsc::Sender<(u64, mpsc::Receiver)>, + network_proposal_sender: mpsc::Sender<(HeightAndRound, mpsc::Receiver)>, num_validators: u64, sync_broadcast_sender: Option>, ) -> Self { @@ -117,7 +118,7 @@ impl ConsensusContext for PapyrusConsensusContext { .block_hash; let (mut proposal_sender, proposal_receiver) = mpsc::channel(CHANNEL_SIZE); - let stream_id = height.0; + let stream_id = HeightAndRound(proposal_init.height.0, proposal_init.round); network_proposal_sender .send((stream_id, proposal_receiver)) .await @@ -248,7 +249,7 @@ impl ConsensusContext for PapyrusConsensusContext { .unwrap_or_else(|| panic!("No proposal found for height {height} and id {id}")) .clone(); - let stream_id = height.0; + let stream_id = HeightAndRound(height.0, init.round); let (mut proposal_sender, proposal_receiver) = mpsc::channel(CHANNEL_SIZE); self.network_proposal_sender .send((stream_id, proposal_receiver)) diff --git a/crates/sequencing/papyrus_consensus_orchestrator/src/papyrus_consensus_context_test.rs b/crates/sequencing/papyrus_consensus_orchestrator/src/papyrus_consensus_context_test.rs index a22d92b0ed..0bdb8ccc20 100644 --- a/crates/sequencing/papyrus_consensus_orchestrator/src/papyrus_consensus_context_test.rs +++ b/crates/sequencing/papyrus_consensus_orchestrator/src/papyrus_consensus_context_test.rs @@ -11,6 +11,7 @@ use papyrus_network::network_manager::test_utils::{ }; use papyrus_network::network_manager::BroadcastTopicChannels; use papyrus_protobuf::consensus::{ + HeightAndRound, ProposalFin, ProposalInit, ProposalPart, @@ -128,8 +129,9 @@ fn test_setup() .unwrap(); let network_channels = mock_register_broadcast_topic().unwrap(); - let network_proposal_channels: TestSubscriberChannels> = - mock_register_broadcast_topic().unwrap(); + let network_proposal_channels: TestSubscriberChannels< + StreamMessage, + > = mock_register_broadcast_topic().unwrap(); let BroadcastTopicChannels { broadcasted_messages_receiver: inbound_network_receiver, broadcast_topic_client: outbound_network_sender, diff --git a/crates/sequencing/papyrus_consensus_orchestrator/src/sequencer_consensus_context.rs b/crates/sequencing/papyrus_consensus_orchestrator/src/sequencer_consensus_context.rs index 2e07a3b524..eb495ede2e 100644 --- a/crates/sequencing/papyrus_consensus_orchestrator/src/sequencer_consensus_context.rs +++ b/crates/sequencing/papyrus_consensus_orchestrator/src/sequencer_consensus_context.rs @@ -21,6 +21,7 @@ use papyrus_consensus::types::{ }; use papyrus_network::network_manager::{BroadcastTopicClient, BroadcastTopicClientTrait}; use papyrus_protobuf::consensus::{ + HeightAndRound, ProposalFin, ProposalInit, ProposalPart, @@ -130,7 +131,7 @@ pub struct SequencerConsensusContext { // Stores proposals for future rounds until the round is reached. queued_proposals: BTreeMap)>, - outbound_proposal_sender: mpsc::Sender<(u64, mpsc::Receiver)>, + outbound_proposal_sender: mpsc::Sender<(HeightAndRound, mpsc::Receiver)>, // Used to broadcast votes to other consensus nodes. vote_broadcast_client: BroadcastTopicClient, // Used to convert Transaction to ExecutableTransaction. @@ -145,7 +146,7 @@ impl SequencerConsensusContext { pub fn new( state_sync_client: SharedStateSyncClient, batcher: Arc, - outbound_proposal_sender: mpsc::Sender<(u64, mpsc::Receiver)>, + outbound_proposal_sender: mpsc::Sender<(HeightAndRound, mpsc::Receiver)>, vote_broadcast_client: BroadcastTopicClient, num_validators: u64, chain_id: ChainId, @@ -208,7 +209,7 @@ impl ConsensusContext for SequencerConsensusContext { self.proposal_id += 1; assert!(timeout > BUILD_PROPOSAL_MARGIN); let (proposal_sender, proposal_receiver) = mpsc::channel(CHANNEL_SIZE); - let stream_id = proposal_init.height.0; + let stream_id = HeightAndRound(proposal_init.height.0, proposal_init.round); self.outbound_proposal_sender .send((stream_id, proposal_receiver)) .await diff --git a/crates/sequencing/papyrus_consensus_orchestrator/src/sequencer_consensus_context_test.rs b/crates/sequencing/papyrus_consensus_orchestrator/src/sequencer_consensus_context_test.rs index 36b75d4be2..84b7101721 100644 --- a/crates/sequencing/papyrus_consensus_orchestrator/src/sequencer_consensus_context_test.rs +++ b/crates/sequencing/papyrus_consensus_orchestrator/src/sequencer_consensus_context_test.rs @@ -16,6 +16,7 @@ use papyrus_network::network_manager::test_utils::{ }; use papyrus_network::network_manager::BroadcastTopicChannels; use papyrus_protobuf::consensus::{ + HeightAndRound, ProposalFin, ProposalInit, ProposalPart, @@ -71,7 +72,7 @@ fn generate_invoke_tx(nonce: u8) -> Transaction { // Structs which aren't utilized but should not be dropped. struct NetworkDependencies { _vote_network: BroadcastNetworkMock, - _new_proposal_network: BroadcastNetworkMock>, + _new_proposal_network: BroadcastNetworkMock>, } fn setup( diff --git a/crates/starknet_consensus_manager/src/consensus_manager.rs b/crates/starknet_consensus_manager/src/consensus_manager.rs index 6951b1bdc4..2c21e6e230 100644 --- a/crates/starknet_consensus_manager/src/consensus_manager.rs +++ b/crates/starknet_consensus_manager/src/consensus_manager.rs @@ -7,7 +7,7 @@ use papyrus_consensus_orchestrator::cende::CendeAmbassador; use papyrus_consensus_orchestrator::sequencer_consensus_context::SequencerConsensusContext; use papyrus_network::gossipsub_impl::Topic; use papyrus_network::network_manager::{BroadcastTopicChannels, NetworkManager}; -use papyrus_protobuf::consensus::{ProposalPart, StreamMessage, Vote}; +use papyrus_protobuf::consensus::{HeightAndRound, ProposalPart, StreamMessage, Vote}; use starknet_api::block::BlockNumber; use starknet_batcher_types::communication::SharedBatcherClient; use starknet_infra_utils::type_name::short_type_name; @@ -44,7 +44,7 @@ impl ConsensusManager { NetworkManager::new(self.config.consensus_config.network_config.clone(), None); let proposals_broadcast_channels = network_manager - .register_broadcast_topic::>( + .register_broadcast_topic::>( Topic::new(CONSENSUS_PROPOSALS_TOPIC), BROADCAST_BUFFER_SIZE, ) diff --git a/crates/starknet_integration_tests/src/flow_test_setup.rs b/crates/starknet_integration_tests/src/flow_test_setup.rs index 9dc1053f58..152b78d238 100644 --- a/crates/starknet_integration_tests/src/flow_test_setup.rs +++ b/crates/starknet_integration_tests/src/flow_test_setup.rs @@ -7,7 +7,7 @@ use mempool_test_utils::starknet_api_test_utils::{ }; use papyrus_network::network_manager::test_utils::create_network_configs_connected_to_broadcast_channels; use papyrus_network::network_manager::BroadcastTopicChannels; -use papyrus_protobuf::consensus::{ProposalPart, StreamMessage}; +use papyrus_protobuf::consensus::{HeightAndRound, ProposalPart, StreamMessage}; use starknet_api::rpc_transaction::RpcTransaction; use starknet_api::transaction::TransactionHash; use starknet_consensus_manager::config::ConsensusManagerConfig; @@ -44,7 +44,8 @@ pub struct FlowTestSetup { pub sequencer_1: FlowSequencerSetup, // Channels for consensus proposals, used for asserting the right transactions are proposed. - pub consensus_proposals_channels: BroadcastTopicChannels>, + pub consensus_proposals_channels: + BroadcastTopicChannels>, } impl FlowTestSetup { @@ -175,7 +176,10 @@ impl FlowSequencerSetup { pub fn create_consensus_manager_configs_and_channels( ports: Vec, -) -> (Vec, BroadcastTopicChannels>) { +) -> ( + Vec, + BroadcastTopicChannels>, +) { let (network_configs, broadcast_channels) = create_network_configs_connected_to_broadcast_channels( papyrus_network::gossipsub_impl::Topic::new( diff --git a/crates/starknet_integration_tests/tests/end_to_end_flow_test.rs b/crates/starknet_integration_tests/tests/end_to_end_flow_test.rs index b583e1b5d4..9b10a33c97 100644 --- a/crates/starknet_integration_tests/tests/end_to_end_flow_test.rs +++ b/crates/starknet_integration_tests/tests/end_to_end_flow_test.rs @@ -5,6 +5,7 @@ use mempool_test_utils::starknet_api_test_utils::MultiAccountTransactionGenerato use papyrus_consensus::types::ValidatorId; use papyrus_network::network_manager::BroadcastTopicChannels; use papyrus_protobuf::consensus::{ + HeightAndRound, ProposalFin, ProposalInit, ProposalPart, @@ -149,7 +150,9 @@ async fn wait_for_sequencer_node(sequencer: &FlowSequencerSetup) { } async fn listen_to_broadcasted_messages( - consensus_proposals_channels: &mut BroadcastTopicChannels>, + consensus_proposals_channels: &mut BroadcastTopicChannels< + StreamMessage, + >, expected_batched_tx_hashes: &[TransactionHash], expected_height: BlockNumber, expected_content_id: Felt,