Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/master' into lexnv/update-upstre…
Browse files Browse the repository at this point in the history
…am-yamux
  • Loading branch information
lexnv committed Jan 29, 2025
2 parents 5b91511 + b7511c8 commit 103bf91
Show file tree
Hide file tree
Showing 6 changed files with 142 additions and 54 deletions.
37 changes: 18 additions & 19 deletions src/multistream_select/dialer_select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ use crate::{
error::{self, Error, ParseError},
multistream_select::{
protocol::{
encode_multistream_message, HeaderLine, Message, MessageIO, Protocol, ProtocolError,
webrtc_encode_multistream_message, HeaderLine, Message, MessageIO, Protocol,
ProtocolError,
},
Negotiated, NegotiationError, Version,
},
Expand Down Expand Up @@ -305,7 +306,7 @@ pub enum HandshakeResult {
/// Handshake state.
#[derive(Debug)]
enum HandshakeState {
/// Wainting to receive any response from remote peer.
/// Waiting to receive any response from remote peer.
WaitingResponse,

/// Waiting to receive the actual application protocol from remote peer.
Expand All @@ -314,7 +315,7 @@ enum HandshakeState {

/// `multistream-select` dialer handshake state.
#[derive(Debug)]
pub struct DialerState {
pub struct WebRtcDialerState {
/// Proposed main protocol.
protocol: ProtocolName,

Expand All @@ -325,16 +326,16 @@ pub struct DialerState {
state: HandshakeState,
}

impl DialerState {
impl WebRtcDialerState {
/// Propose protocol to remote peer.
///
/// Return [`DialerState`] which is used to drive forward the negotiation and an encoded
/// Return [`WebRtcDialerState`] which is used to drive forward the negotiation and an encoded
/// `multistream-select` message that contains the protocol proposal for the substream.
pub fn propose(
protocol: ProtocolName,
fallback_names: Vec<ProtocolName>,
) -> crate::Result<(Self, Vec<u8>)> {
let message = encode_multistream_message(
let message = webrtc_encode_multistream_message(
std::iter::once(protocol.clone())
.chain(fallback_names.clone())
.filter_map(|protocol| Protocol::try_from(protocol.as_ref()).ok())
Expand All @@ -353,7 +354,7 @@ impl DialerState {
))
}

/// Register response to [`DialerState`].
/// Register response to [`WebRtcDialerState`].
pub fn register_response(
&mut self,
payload: Vec<u8>,
Expand Down Expand Up @@ -548,9 +549,7 @@ mod tests {
io.close().await.unwrap();
});

// TODO: Once https://github.com/paritytech/litep2p/pull/62 is merged, this
// should be changed to `is_ok`.
assert!(tokio::time::timeout(Duration::from_secs(10), client).await.is_err());
assert!(tokio::time::timeout(Duration::from_secs(10), client).await.is_ok());
}

#[tokio::test]
Expand Down Expand Up @@ -757,7 +756,7 @@ mod tests {
#[test]
fn propose() {
let (mut dialer_state, message) =
DialerState::propose(ProtocolName::from("/13371338/proto/1"), vec![]).unwrap();
WebRtcDialerState::propose(ProtocolName::from("/13371338/proto/1"), vec![]).unwrap();
let message = bytes::BytesMut::from(&message[..]).freeze();

let Message::Protocols(protocols) = Message::decode(message).unwrap() else {
Expand All @@ -779,7 +778,7 @@ mod tests {

#[test]
fn propose_with_fallback() {
let (mut dialer_state, message) = DialerState::propose(
let (mut dialer_state, message) = WebRtcDialerState::propose(
ProtocolName::from("/13371338/proto/1"),
vec![ProtocolName::from("/sup/proto/1")],
)
Expand Down Expand Up @@ -815,7 +814,7 @@ mod tests {
let _ = message.encode(&mut bytes).map_err(|_| Error::InvalidData).unwrap();

let (mut dialer_state, _message) =
DialerState::propose(ProtocolName::from("/13371338/proto/1"), vec![]).unwrap();
WebRtcDialerState::propose(ProtocolName::from("/13371338/proto/1"), vec![]).unwrap();

match dialer_state.register_response(bytes.freeze().to_vec()) {
Err(error::NegotiationError::MultistreamSelectError(NegotiationError::Failed)) => {}
Expand All @@ -834,7 +833,7 @@ mod tests {
let _ = message.encode(&mut bytes).map_err(|_| Error::InvalidData).unwrap();

let (mut dialer_state, _message) =
DialerState::propose(ProtocolName::from("/13371338/proto/1"), vec![]).unwrap();
WebRtcDialerState::propose(ProtocolName::from("/13371338/proto/1"), vec![]).unwrap();

match dialer_state.register_response(bytes.freeze().to_vec()) {
Err(error::NegotiationError::MultistreamSelectError(NegotiationError::Failed)) => {}
Expand All @@ -844,7 +843,7 @@ mod tests {

#[test]
fn negotiate_main_protocol() {
let message = encode_multistream_message(
let message = webrtc_encode_multistream_message(
vec![Message::Protocol(
Protocol::try_from(&b"/13371338/proto/1"[..]).unwrap(),
)]
Expand All @@ -853,7 +852,7 @@ mod tests {
.unwrap()
.freeze();

let (mut dialer_state, _message) = DialerState::propose(
let (mut dialer_state, _message) = WebRtcDialerState::propose(
ProtocolName::from("/13371338/proto/1"),
vec![ProtocolName::from("/sup/proto/1")],
)
Expand All @@ -862,13 +861,13 @@ mod tests {
match dialer_state.register_response(message.to_vec()) {
Ok(HandshakeResult::Succeeded(negotiated)) =>
assert_eq!(negotiated, ProtocolName::from("/13371338/proto/1")),
_ => panic!("invalid event"),
event => panic!("invalid event {event:?}"),
}
}

#[test]
fn negotiate_fallback_protocol() {
let message = encode_multistream_message(
let message = webrtc_encode_multistream_message(
vec![Message::Protocol(
Protocol::try_from(&b"/sup/proto/1"[..]).unwrap(),
)]
Expand All @@ -877,7 +876,7 @@ mod tests {
.unwrap()
.freeze();

let (mut dialer_state, _message) = DialerState::propose(
let (mut dialer_state, _message) = WebRtcDialerState::propose(
ProtocolName::from("/13371338/proto/1"),
vec![ProtocolName::from("/sup/proto/1")],
)
Expand Down
36 changes: 19 additions & 17 deletions src/multistream_select/listener_select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ use crate::{
error::{self, Error},
multistream_select::{
protocol::{
encode_multistream_message, HeaderLine, Message, MessageIO, Protocol, ProtocolError,
webrtc_encode_multistream_message, HeaderLine, Message, MessageIO, Protocol,
ProtocolError,
},
Negotiated, NegotiationError,
},
Expand Down Expand Up @@ -324,7 +325,7 @@ where
}
}

/// Result of [`listener_negotiate()`].
/// Result of [`webrtc_listener_negotiate()`].
#[derive(Debug)]
pub enum ListenerSelectResult {
/// Requested protocol is available and substream can be accepted.
Expand All @@ -348,7 +349,7 @@ pub enum ListenerSelectResult {
/// Parse protocols offered by the remote peer and check if any of the offered protocols match
/// locally available protocols. If a match is found, return an encoded multistream-select
/// response and the negotiated protocol. If parsing fails or no match is found, return an error.
pub fn listener_negotiate<'a>(
pub fn webrtc_listener_negotiate<'a>(
supported_protocols: &'a mut impl Iterator<Item = &'a ProtocolName>,
payload: Bytes,
) -> crate::Result<ListenerSelectResult> {
Expand Down Expand Up @@ -382,9 +383,9 @@ pub fn listener_negotiate<'a>(
if protocol.as_ref() == supported.as_bytes() {
return Ok(ListenerSelectResult::Accepted {
protocol: supported.clone(),
message: encode_multistream_message(std::iter::once(Message::Protocol(
protocol,
)))?,
message: webrtc_encode_multistream_message(std::iter::once(
Message::Protocol(protocol),
))?,
});
}
}
Expand All @@ -396,7 +397,7 @@ pub fn listener_negotiate<'a>(
);

Ok(ListenerSelectResult::Rejected {
message: encode_multistream_message(std::iter::once(Message::NotAvailable))?,
message: webrtc_encode_multistream_message(std::iter::once(Message::NotAvailable))?,
})
}

Expand All @@ -405,15 +406,15 @@ mod tests {
use super::*;

#[test]
fn listener_negotiate_works() {
fn webrtc_listener_negotiate_works() {
let mut local_protocols = vec![
ProtocolName::from("/13371338/proto/1"),
ProtocolName::from("/sup/proto/1"),
ProtocolName::from("/13371338/proto/2"),
ProtocolName::from("/13371338/proto/3"),
ProtocolName::from("/13371338/proto/4"),
];
let message = encode_multistream_message(
let message = webrtc_encode_multistream_message(
vec![
Message::Protocol(Protocol::try_from(&b"/13371338/proto/1"[..]).unwrap()),
Message::Protocol(Protocol::try_from(&b"/sup/proto/1"[..]).unwrap()),
Expand All @@ -423,7 +424,7 @@ mod tests {
.unwrap()
.freeze();

match listener_negotiate(&mut local_protocols.iter(), message) {
match webrtc_listener_negotiate(&mut local_protocols.iter(), message) {
Err(error) => panic!("error received: {error:?}"),
Ok(ListenerSelectResult::Rejected { .. }) => panic!("message rejected"),
Ok(ListenerSelectResult::Accepted { protocol, message }) => {
Expand All @@ -441,14 +442,14 @@ mod tests {
ProtocolName::from("/13371338/proto/3"),
ProtocolName::from("/13371338/proto/4"),
];
let message = encode_multistream_message(std::iter::once(Message::Protocols(vec![
let message = webrtc_encode_multistream_message(std::iter::once(Message::Protocols(vec![
Protocol::try_from(&b"/13371338/proto/1"[..]).unwrap(),
Protocol::try_from(&b"/sup/proto/1"[..]).unwrap(),
])))
.unwrap()
.freeze();

match listener_negotiate(&mut local_protocols.iter(), message) {
match webrtc_listener_negotiate(&mut local_protocols.iter(), message) {
Err(error) => assert!(std::matches!(error, Error::InvalidData)),
_ => panic!("invalid event"),
}
Expand All @@ -469,7 +470,7 @@ mod tests {
let message = Message::Header(HeaderLine::V1);
let _ = message.encode(&mut bytes).map_err(|_| Error::InvalidData).unwrap();

match listener_negotiate(&mut local_protocols.iter(), bytes.freeze()) {
match webrtc_listener_negotiate(&mut local_protocols.iter(), bytes.freeze()) {
Err(error) => assert!(std::matches!(
error,
Error::NegotiationError(error::NegotiationError::MultistreamSelectError(
Expand Down Expand Up @@ -498,7 +499,7 @@ mod tests {
]);
let _ = message.encode(&mut bytes).map_err(|_| Error::InvalidData).unwrap();

match listener_negotiate(&mut local_protocols.iter(), bytes.freeze()) {
match webrtc_listener_negotiate(&mut local_protocols.iter(), bytes.freeze()) {
Err(error) => assert!(std::matches!(
error,
Error::NegotiationError(error::NegotiationError::MultistreamSelectError(
Expand All @@ -518,7 +519,7 @@ mod tests {
ProtocolName::from("/13371338/proto/3"),
ProtocolName::from("/13371338/proto/4"),
];
let message = encode_multistream_message(
let message = webrtc_encode_multistream_message(
vec![Message::Protocol(
Protocol::try_from(&b"/13371339/proto/1"[..]).unwrap(),
)]
Expand All @@ -527,12 +528,13 @@ mod tests {
.unwrap()
.freeze();

match listener_negotiate(&mut local_protocols.iter(), message) {
match webrtc_listener_negotiate(&mut local_protocols.iter(), message) {
Err(error) => panic!("error received: {error:?}"),
Ok(ListenerSelectResult::Rejected { message }) => {
assert_eq!(
message,
encode_multistream_message(std::iter::once(Message::NotAvailable)).unwrap()
webrtc_encode_multistream_message(std::iter::once(Message::NotAvailable))
.unwrap()
);
}
Ok(ListenerSelectResult::Accepted { protocol, message }) => panic!("message accepted"),
Expand Down
5 changes: 3 additions & 2 deletions src/multistream_select/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,10 @@ mod negotiated;
mod protocol;

pub use crate::multistream_select::{
dialer_select::{dialer_select_proto, DialerSelectFuture, DialerState, HandshakeResult},
dialer_select::{dialer_select_proto, DialerSelectFuture, HandshakeResult, WebRtcDialerState},
listener_select::{
listener_negotiate, listener_select_proto, ListenerSelectFuture, ListenerSelectResult,
listener_select_proto, webrtc_listener_negotiate, ListenerSelectFuture,
ListenerSelectResult,
},
negotiated::{Negotiated, NegotiatedComplete, NegotiationError},
protocol::{HeaderLine, Message, Protocol, ProtocolError},
Expand Down
15 changes: 11 additions & 4 deletions src/multistream_select/negotiated.rs
Original file line number Diff line number Diff line change
Expand Up @@ -323,15 +323,22 @@ where
}

fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
// Ensure all data has been flushed and expected negotiation messages
// have been received.
ready!(self.as_mut().poll(cx).map_err(Into::<io::Error>::into)?);
// Ensure all data has been flushed, including optimistic multistream-select messages.
ready!(self.as_mut().poll_flush(cx).map_err(Into::<io::Error>::into)?);

// Continue with the shutdown of the underlying I/O stream.
match self.project().state.project() {
StateProj::Completed { io, .. } => io.poll_close(cx),
StateProj::Expecting { io, .. } => io.poll_close(cx),
StateProj::Expecting { io, .. } => {
let close_poll = io.poll_close(cx);
if let Poll::Ready(Ok(())) = close_poll {
tracing::debug!(
target: LOG_TARGET,
"Stream closed. Confirmation from remote for optimstic protocol negotiation still pending."
);
}
close_poll
}
StateProj::Invalid => panic!("Negotiated: Invalid state"),
}
}
Expand Down
Loading

0 comments on commit 103bf91

Please sign in to comment.