Skip to content

Commit

Permalink
fix: update StreamHandler to use new StreamHashMap interface
Browse files Browse the repository at this point in the history
  • Loading branch information
guy-starkware committed Nov 3, 2024
1 parent b5920a1 commit 506ed06
Show file tree
Hide file tree
Showing 7 changed files with 287 additions and 60 deletions.
11 changes: 8 additions & 3 deletions crates/papyrus_protobuf/src/consensus.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::fmt::Display;

use futures::channel::{mpsc, oneshot};
use starknet_api::block::{BlockHash, BlockNumber};
use starknet_api::core::ContractAddress;
Expand Down Expand Up @@ -54,9 +56,12 @@ pub enum StreamMessageBody<T> {
}

#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub struct StreamMessage<T: Into<Vec<u8>> + TryFrom<Vec<u8>, Error = ProtobufConversionError>> {
pub struct StreamMessage<
T: Into<Vec<u8>> + TryFrom<Vec<u8>, Error = ProtobufConversionError>,
StreamId: Into<Vec<u8>> + Clone,
> {
pub message: StreamMessageBody<T>,
pub stream_id: u64,
pub stream_id: StreamId,
pub message_id: u64,
}

Expand Down Expand Up @@ -99,7 +104,7 @@ pub enum ProposalPart {
Fin(ProposalFin),
}

impl<T> std::fmt::Display for StreamMessage<T>
impl<T, StreamId: Into<Vec<u8>> + Clone + Display> std::fmt::Display for StreamMessage<T, StreamId>
where
T: Clone + Into<Vec<u8>> + TryFrom<Vec<u8>, Error = ProtobufConversionError>,
{
Expand Down
33 changes: 21 additions & 12 deletions crates/papyrus_protobuf/src/converters/consensus.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#[cfg(test)]
#[path = "consensus_test.rs"]
mod consensus_test;

use std::convert::{TryFrom, TryInto};

use prost::Message;
Expand Down Expand Up @@ -125,8 +126,10 @@ impl From<Vote> for protobuf::Vote {

auto_impl_into_and_try_from_vec_u8!(Vote, protobuf::Vote);

impl<T: Into<Vec<u8>> + TryFrom<Vec<u8>, Error = ProtobufConversionError>>
TryFrom<protobuf::StreamMessage> for StreamMessage<T>
impl<
T: Into<Vec<u8>> + TryFrom<Vec<u8>, Error = ProtobufConversionError>,
StreamId: Into<Vec<u8>> + From<Vec<u8>> + Clone,
> TryFrom<protobuf::StreamMessage> for StreamMessage<T, StreamId>
{
type Error = ProtobufConversionError;

Expand All @@ -147,16 +150,18 @@ impl<T: Into<Vec<u8>> + TryFrom<Vec<u8>, Error = ProtobufConversionError>>
StreamMessageBody::Fin
}
},
stream_id: value.stream_id,
stream_id: value.stream_id.into(),
message_id: value.message_id,
})
}
}

impl<T: Into<Vec<u8>> + TryFrom<Vec<u8>, Error = ProtobufConversionError>> From<StreamMessage<T>>
for protobuf::StreamMessage
impl<
T: Into<Vec<u8>> + TryFrom<Vec<u8>, Error = ProtobufConversionError>,
StreamId: Into<Vec<u8>> + From<Vec<u8>> + Clone,
> From<StreamMessage<T, StreamId>> for protobuf::StreamMessage
{
fn from(value: StreamMessage<T>) -> Self {
fn from(value: StreamMessage<T, StreamId>) -> Self {
Self {
message: match value {
StreamMessage {
Expand All @@ -168,7 +173,7 @@ impl<T: Into<Vec<u8>> + TryFrom<Vec<u8>, Error = ProtobufConversionError>> From<
Some(protobuf::stream_message::Message::Fin(protobuf::Fin {}))
}
},
stream_id: value.stream_id,
stream_id: value.stream_id.into(),
message_id: value.message_id,
}
}
Expand All @@ -177,17 +182,21 @@ impl<T: Into<Vec<u8>> + TryFrom<Vec<u8>, Error = ProtobufConversionError>> From<
// Can't use auto_impl_into_and_try_from_vec_u8!(StreamMessage, protobuf::StreamMessage);
// because it doesn't seem to work with generics.
// TODO(guyn): consider expanding the macro to support generics
impl<T: Into<Vec<u8>> + TryFrom<Vec<u8>, Error = ProtobufConversionError>> From<StreamMessage<T>>
for Vec<u8>
impl<
T: Into<Vec<u8>> + TryFrom<Vec<u8>, Error = ProtobufConversionError>,
StreamId: Into<Vec<u8>> + From<Vec<u8>> + Clone,
> From<StreamMessage<T, StreamId>> for Vec<u8>
{
fn from(value: StreamMessage<T>) -> Self {
fn from(value: StreamMessage<T, StreamId>) -> Self {
let protobuf_value = <protobuf::StreamMessage>::from(value);
protobuf_value.encode_to_vec()
}
}

impl<T: Into<Vec<u8>> + TryFrom<Vec<u8>, Error = ProtobufConversionError>> TryFrom<Vec<u8>>
for StreamMessage<T>
impl<
T: Into<Vec<u8>> + TryFrom<Vec<u8>, Error = ProtobufConversionError>,
StreamId: Into<Vec<u8>> + From<Vec<u8>> + Clone,
> TryFrom<Vec<u8>> for StreamMessage<T, StreamId>
{
type Error = ProtobufConversionError;
fn try_from(value: Vec<u8>) -> Result<Self, Self::Error> {
Expand Down
5 changes: 3 additions & 2 deletions crates/papyrus_protobuf/src/converters/consensus_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ use crate::consensus::{
TransactionBatch,
Vote,
};
use crate::converters::test_instances::StreamId;

// If all the fields of `AllResources` are 0 upon serialization,
// then the deserialized value will be interpreted as the `L1Gas` variant.
Expand Down Expand Up @@ -52,7 +53,7 @@ fn convert_stream_message_to_vec_u8_and_back() {
let mut rng = get_rng();

// Test that we can convert a StreamMessage with a ConsensusMessage message to bytes and back.
let mut stream_message: StreamMessage<ConsensusMessage> =
let mut stream_message: StreamMessage<ConsensusMessage, StreamId> =
StreamMessage::get_test_instance(&mut rng);

if let StreamMessageBody::Content(ConsensusMessage::Proposal(proposal)) =
Expand Down Expand Up @@ -159,7 +160,7 @@ fn convert_proposal_part_to_vec_u8_and_back() {
#[test]
fn stream_message_display() {
let mut rng = get_rng();
let stream_id = 42;
let stream_id = StreamId(42);
let message_id = 127;
let proposal = Proposal::get_test_instance(&mut rng);
let proposal_bytes: Vec<u8> = proposal.clone().into();
Expand Down
29 changes: 27 additions & 2 deletions crates/papyrus_protobuf/src/converters/test_instances.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::fmt::Display;

use papyrus_test_utils::{auto_impl_get_test_instance, get_number_of_variants, GetTestInstance};
use rand::Rng;
use starknet_api::block::BlockHash;
Expand Down Expand Up @@ -61,15 +63,38 @@ auto_impl_get_test_instance! {

}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct StreamId(pub u64);

impl Into<Vec<u8>> for StreamId {
fn into(self) -> Vec<u8> {
self.0.to_be_bytes().to_vec()
}
}

impl From<Vec<u8>> for StreamId {
fn from(bytes: Vec<u8>) -> Self {
let mut array = [0; 8];
array.copy_from_slice(&bytes);
StreamId(u64::from_be_bytes(array))
}
}

impl Display for StreamId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "StreamId({})", self.0)
}
}

// The auto_impl_get_test_instance macro does not work for StreamMessage because it has
// a generic type. TODO(guyn): try to make the macro work with generic types.
impl GetTestInstance for StreamMessage<ConsensusMessage> {
impl GetTestInstance for StreamMessage<ConsensusMessage, StreamId> {
fn get_test_instance(rng: &mut rand_chacha::ChaCha8Rng) -> Self {
let message = if rng.gen_bool(0.5) {
StreamMessageBody::Content(ConsensusMessage::Proposal(Proposal::get_test_instance(rng)))
} else {
StreamMessageBody::Fin
};
Self { message, stream_id: 12, message_id: 47 }
Self { message, stream_id: StreamId(12), message_id: 47 }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ message StreamMessage {
bytes content = 1;
Fin fin = 2;
}
uint64 stream_id = 3;
bytes stream_id = 3;
uint64 message_id = 4;
}

Expand Down
63 changes: 42 additions & 21 deletions crates/sequencing/papyrus_consensus/src/stream_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ use std::cmp::Ordering;
use std::collections::btree_map::Entry as BTreeEntry;
use std::collections::hash_map::Entry as HashMapEntry;
use std::collections::{BTreeMap, HashMap};
use std::fmt::{Debug, Display};
use std::hash::Hash;

use futures::channel::mpsc;
use futures::StreamExt;
Expand All @@ -23,24 +25,31 @@ use tracing::{instrument, warn};
mod stream_handler_test;

type PeerId = OpaquePeerId;
type StreamId = u64;
// type StreamId = u64;
type MessageId = u64;
type StreamKey = (PeerId, StreamId);
// type StreamKey = (PeerId, StreamId);

const CHANNEL_BUFFER_LENGTH: usize = 100;

#[derive(Debug, Clone)]
struct StreamData<T: Clone + Into<Vec<u8>> + TryFrom<Vec<u8>, Error = ProtobufConversionError>> {
struct StreamData<
T: Clone + Into<Vec<u8>> + TryFrom<Vec<u8>, Error = ProtobufConversionError>,
StreamId: Into<Vec<u8>> + From<Vec<u8>> + Eq + Hash + Clone + Unpin + Display + Debug,
> {
next_message_id: MessageId,
// Last message ID. If None, it means we have not yet gotten to it.
fin_message_id: Option<MessageId>,
max_message_id_received: MessageId,
sender: mpsc::Sender<T>,
// A buffer for messages that were received out of order.
message_buffer: BTreeMap<MessageId, StreamMessage<T>>,
message_buffer: BTreeMap<MessageId, StreamMessage<T, StreamId>>,
}

impl<T: Clone + Into<Vec<u8>> + TryFrom<Vec<u8>, Error = ProtobufConversionError>> StreamData<T> {
impl<
T: Clone + Into<Vec<u8>> + TryFrom<Vec<u8>, Error = ProtobufConversionError>,
StreamId: Into<Vec<u8>> + From<Vec<u8>> + Eq + Hash + Clone + Unpin + Display + Debug,
> StreamData<T, StreamId>
{
fn new(sender: mpsc::Sender<T>) -> Self {
StreamData {
next_message_id: 0,
Expand All @@ -57,37 +66,40 @@ impl<T: Clone + Into<Vec<u8>> + TryFrom<Vec<u8>, Error = ProtobufConversionError
/// - Sending outbound messages to the network, wrapped in StreamMessage.
pub struct StreamHandler<
T: Clone + Into<Vec<u8>> + TryFrom<Vec<u8>, Error = ProtobufConversionError>,
StreamId: Into<Vec<u8>> + From<Vec<u8>> + Eq + Hash + Clone + Unpin + Display + Debug,
> {
// For each stream ID from the network, send the application a Receiver
// that will receive the messages in order. This allows sending such Receivers.
inbound_channel_sender: mpsc::Sender<mpsc::Receiver<T>>,
// This receives messages from the network.
inbound_receiver: BroadcastTopicServer<StreamMessage<T>>,
inbound_receiver: BroadcastTopicServer<StreamMessage<T, StreamId>>,
// A map from (peer_id, stream_id) to a struct that contains all the information
// about the stream. This includes both the message buffer and some metadata
// (like the latest message ID).
inbound_stream_data: HashMap<StreamKey, StreamData<T>>,
inbound_stream_data: HashMap<(PeerId, StreamId), StreamData<T, StreamId>>,
// Whenever application wants to start a new stream, it must send out a
// (stream_id, Receiver) pair. Each receiver gets messages that should
// be sent out to the network.
outbound_channel_receiver: mpsc::Receiver<(StreamId, mpsc::Receiver<T>)>,
// A map where the abovementioned Receivers are stored.
outbound_stream_receivers: StreamHashMap<StreamId, mpsc::Receiver<T>>,
// A network sender that allows sending StreamMessages to peers.
outbound_sender: BroadcastTopicClient<StreamMessage<T>>,
outbound_sender: BroadcastTopicClient<StreamMessage<T, StreamId>>,
// For each stream, keep track of the message_id of the last message sent.
outbound_stream_number: HashMap<StreamId, MessageId>,
}

impl<T: Clone + Send + Into<Vec<u8>> + TryFrom<Vec<u8>, Error = ProtobufConversionError>>
StreamHandler<T>
impl<
T: Clone + Send + Into<Vec<u8>> + TryFrom<Vec<u8>, Error = ProtobufConversionError>,
StreamId: Into<Vec<u8>> + From<Vec<u8>> + Eq + Hash + Clone + Unpin + Send + Display + Debug,
> StreamHandler<T, StreamId>
{
/// Create a new StreamHandler.
pub fn new(
inbound_channel_sender: mpsc::Sender<mpsc::Receiver<T>>,
inbound_receiver: BroadcastTopicServer<StreamMessage<T>>,
inbound_receiver: BroadcastTopicServer<StreamMessage<T, StreamId>>,
outbound_channel_receiver: mpsc::Receiver<(StreamId, mpsc::Receiver<T>)>,
outbound_sender: BroadcastTopicClient<StreamMessage<T>>,
outbound_sender: BroadcastTopicClient<StreamMessage<T, StreamId>>,
) -> Self {
Self {
inbound_channel_sender,
Expand Down Expand Up @@ -137,7 +149,7 @@ impl<T: Clone + Send + Into<Vec<u8>> + TryFrom<Vec<u8>, Error = ProtobufConversi
}
}

fn inbound_send(data: &mut StreamData<T>, message: StreamMessage<T>) {
fn inbound_send(data: &mut StreamData<T, StreamId>, message: StreamMessage<T, StreamId>) {
// TODO(guyn): reconsider the "expect" here.
let sender = &mut data.sender;
if let StreamMessageBody::Content(content) = message.message {
Expand All @@ -150,20 +162,22 @@ impl<T: Clone + Send + Into<Vec<u8>> + TryFrom<Vec<u8>, Error = ProtobufConversi
async fn broadcast(&mut self, stream_id: StreamId, message: T) {
let message = StreamMessage {
message: StreamMessageBody::Content(message),
stream_id,
stream_id: stream_id.clone(),
message_id: *self.outbound_stream_number.get(&stream_id).unwrap_or(&0),
};
// TODO(guyn): reconsider the "expect" here.
self.outbound_sender.broadcast_message(message).await.expect("Send should succeed");
self.outbound_stream_number
.insert(stream_id, self.outbound_stream_number.get(&stream_id).unwrap_or(&0) + 1);
self.outbound_stream_number.insert(
stream_id.clone(),
self.outbound_stream_number.get(&stream_id).unwrap_or(&0) + 1,
);
}

// Send a fin message to the network.
async fn broadcast_fin(&mut self, stream_id: StreamId) {
let message = StreamMessage {
message: StreamMessageBody::Fin,
stream_id,
stream_id: stream_id.clone(),
message_id: *self.outbound_stream_number.get(&stream_id).unwrap_or(&0),
};
self.outbound_sender.broadcast_message(message).await.expect("Send should succeed");
Expand All @@ -174,7 +188,10 @@ impl<T: Clone + Send + Into<Vec<u8>> + TryFrom<Vec<u8>, Error = ProtobufConversi
#[instrument(skip_all, level = "warn")]
fn handle_message(
&mut self,
message: (Result<StreamMessage<T>, ProtobufConversionError>, BroadcastedMessageMetadata),
message: (
Result<StreamMessage<T, StreamId>, ProtobufConversionError>,
BroadcastedMessageMetadata,
),
) {
let (message, metadata) = message;
let message = match message {
Expand All @@ -185,7 +202,7 @@ impl<T: Clone + Send + Into<Vec<u8>> + TryFrom<Vec<u8>, Error = ProtobufConversi
}
};
let peer_id = metadata.originator_id;
let stream_id = message.stream_id;
let stream_id = message.stream_id.clone();
let key = (peer_id, stream_id);
let message_id = message.message_id;

Expand Down Expand Up @@ -263,7 +280,11 @@ impl<T: Clone + Send + Into<Vec<u8>> + TryFrom<Vec<u8>, Error = ProtobufConversi
}

// Store an inbound message in the buffer.
fn store(data: &mut StreamData<T>, key: StreamKey, message: StreamMessage<T>) {
fn store(
data: &mut StreamData<T, StreamId>,
key: (PeerId, StreamId),
message: StreamMessage<T, StreamId>,
) {
let message_id = message.message_id;

match data.message_buffer.entry(message_id) {
Expand All @@ -282,7 +303,7 @@ impl<T: Clone + Send + Into<Vec<u8>> + TryFrom<Vec<u8>, Error = ProtobufConversi

// Tries to drain as many messages as possible from the buffer (in order),
// DOES NOT guarantee that the buffer will be empty after calling this function.
fn process_buffer(data: &mut StreamData<T>) {
fn process_buffer(data: &mut StreamData<T, StreamId>) {
while let Some(message) = data.message_buffer.remove(&data.next_message_id) {
Self::inbound_send(data, message);
}
Expand Down
Loading

0 comments on commit 506ed06

Please sign in to comment.