diff --git a/src/multistream_select/dialer_select.rs b/src/multistream_select/dialer_select.rs index 0bd9c259..cf9c00f2 100644 --- a/src/multistream_select/dialer_select.rs +++ b/src/multistream_select/dialer_select.rs @@ -25,7 +25,8 @@ use crate::{ error::{self, Error, ParseError}, multistream_select::{ protocol::{ - encode_multistream_message, HeaderLine, Message, MessageIO, Protocol, ProtocolError, + webrtc_encode_multistream_message, HeaderLine, Message, MessageIO, Protocol, + ProtocolError, }, Negotiated, NegotiationError, Version, }, @@ -305,7 +306,7 @@ pub enum HandshakeResult { /// Handshake state. #[derive(Debug)] enum HandshakeState { - /// Wainting to receive any response from remote peer. + /// Waiting to receive any response from remote peer. WaitingResponse, /// Waiting to receive the actual application protocol from remote peer. @@ -314,7 +315,7 @@ enum HandshakeState { /// `multistream-select` dialer handshake state. #[derive(Debug)] -pub struct DialerState { +pub struct WebRtcDialerState { /// Proposed main protocol. protocol: ProtocolName, @@ -325,16 +326,16 @@ pub struct DialerState { state: HandshakeState, } -impl DialerState { +impl WebRtcDialerState { /// Propose protocol to remote peer. /// - /// Return [`DialerState`] which is used to drive forward the negotiation and an encoded + /// Return [`WebRtcDialerState`] which is used to drive forward the negotiation and an encoded /// `multistream-select` message that contains the protocol proposal for the substream. pub fn propose( protocol: ProtocolName, fallback_names: Vec, ) -> crate::Result<(Self, Vec)> { - let message = encode_multistream_message( + let message = webrtc_encode_multistream_message( std::iter::once(protocol.clone()) .chain(fallback_names.clone()) .filter_map(|protocol| Protocol::try_from(protocol.as_ref()).ok()) @@ -353,7 +354,7 @@ impl DialerState { )) } - /// Register response to [`DialerState`]. + /// Register response to [`WebRtcDialerState`]. pub fn register_response( &mut self, payload: Vec, @@ -548,9 +549,7 @@ mod tests { io.close().await.unwrap(); }); - // TODO: Once https://github.com/paritytech/litep2p/pull/62 is merged, this - // should be changed to `is_ok`. - assert!(tokio::time::timeout(Duration::from_secs(10), client).await.is_err()); + assert!(tokio::time::timeout(Duration::from_secs(10), client).await.is_ok()); } #[tokio::test] @@ -757,7 +756,7 @@ mod tests { #[test] fn propose() { let (mut dialer_state, message) = - DialerState::propose(ProtocolName::from("/13371338/proto/1"), vec![]).unwrap(); + WebRtcDialerState::propose(ProtocolName::from("/13371338/proto/1"), vec![]).unwrap(); let message = bytes::BytesMut::from(&message[..]).freeze(); let Message::Protocols(protocols) = Message::decode(message).unwrap() else { @@ -779,7 +778,7 @@ mod tests { #[test] fn propose_with_fallback() { - let (mut dialer_state, message) = DialerState::propose( + let (mut dialer_state, message) = WebRtcDialerState::propose( ProtocolName::from("/13371338/proto/1"), vec![ProtocolName::from("/sup/proto/1")], ) @@ -815,7 +814,7 @@ mod tests { let _ = message.encode(&mut bytes).map_err(|_| Error::InvalidData).unwrap(); let (mut dialer_state, _message) = - DialerState::propose(ProtocolName::from("/13371338/proto/1"), vec![]).unwrap(); + WebRtcDialerState::propose(ProtocolName::from("/13371338/proto/1"), vec![]).unwrap(); match dialer_state.register_response(bytes.freeze().to_vec()) { Err(error::NegotiationError::MultistreamSelectError(NegotiationError::Failed)) => {} @@ -834,7 +833,7 @@ mod tests { let _ = message.encode(&mut bytes).map_err(|_| Error::InvalidData).unwrap(); let (mut dialer_state, _message) = - DialerState::propose(ProtocolName::from("/13371338/proto/1"), vec![]).unwrap(); + WebRtcDialerState::propose(ProtocolName::from("/13371338/proto/1"), vec![]).unwrap(); match dialer_state.register_response(bytes.freeze().to_vec()) { Err(error::NegotiationError::MultistreamSelectError(NegotiationError::Failed)) => {} @@ -844,7 +843,7 @@ mod tests { #[test] fn negotiate_main_protocol() { - let message = encode_multistream_message( + let message = webrtc_encode_multistream_message( vec![Message::Protocol( Protocol::try_from(&b"/13371338/proto/1"[..]).unwrap(), )] @@ -853,7 +852,7 @@ mod tests { .unwrap() .freeze(); - let (mut dialer_state, _message) = DialerState::propose( + let (mut dialer_state, _message) = WebRtcDialerState::propose( ProtocolName::from("/13371338/proto/1"), vec![ProtocolName::from("/sup/proto/1")], ) @@ -862,13 +861,13 @@ mod tests { match dialer_state.register_response(message.to_vec()) { Ok(HandshakeResult::Succeeded(negotiated)) => assert_eq!(negotiated, ProtocolName::from("/13371338/proto/1")), - _ => panic!("invalid event"), + event => panic!("invalid event {event:?}"), } } #[test] fn negotiate_fallback_protocol() { - let message = encode_multistream_message( + let message = webrtc_encode_multistream_message( vec![Message::Protocol( Protocol::try_from(&b"/sup/proto/1"[..]).unwrap(), )] @@ -877,7 +876,7 @@ mod tests { .unwrap() .freeze(); - let (mut dialer_state, _message) = DialerState::propose( + let (mut dialer_state, _message) = WebRtcDialerState::propose( ProtocolName::from("/13371338/proto/1"), vec![ProtocolName::from("/sup/proto/1")], ) diff --git a/src/multistream_select/listener_select.rs b/src/multistream_select/listener_select.rs index ae7d4d4b..75f4bf9a 100644 --- a/src/multistream_select/listener_select.rs +++ b/src/multistream_select/listener_select.rs @@ -26,7 +26,8 @@ use crate::{ error::{self, Error}, multistream_select::{ protocol::{ - encode_multistream_message, HeaderLine, Message, MessageIO, Protocol, ProtocolError, + webrtc_encode_multistream_message, HeaderLine, Message, MessageIO, Protocol, + ProtocolError, }, Negotiated, NegotiationError, }, @@ -324,7 +325,7 @@ where } } -/// Result of [`listener_negotiate()`]. +/// Result of [`webrtc_listener_negotiate()`]. #[derive(Debug)] pub enum ListenerSelectResult { /// Requested protocol is available and substream can be accepted. @@ -348,7 +349,7 @@ pub enum ListenerSelectResult { /// Parse protocols offered by the remote peer and check if any of the offered protocols match /// locally available protocols. If a match is found, return an encoded multistream-select /// response and the negotiated protocol. If parsing fails or no match is found, return an error. -pub fn listener_negotiate<'a>( +pub fn webrtc_listener_negotiate<'a>( supported_protocols: &'a mut impl Iterator, payload: Bytes, ) -> crate::Result { @@ -382,9 +383,9 @@ pub fn listener_negotiate<'a>( if protocol.as_ref() == supported.as_bytes() { return Ok(ListenerSelectResult::Accepted { protocol: supported.clone(), - message: encode_multistream_message(std::iter::once(Message::Protocol( - protocol, - )))?, + message: webrtc_encode_multistream_message(std::iter::once( + Message::Protocol(protocol), + ))?, }); } } @@ -396,7 +397,7 @@ pub fn listener_negotiate<'a>( ); Ok(ListenerSelectResult::Rejected { - message: encode_multistream_message(std::iter::once(Message::NotAvailable))?, + message: webrtc_encode_multistream_message(std::iter::once(Message::NotAvailable))?, }) } @@ -405,7 +406,7 @@ mod tests { use super::*; #[test] - fn listener_negotiate_works() { + fn webrtc_listener_negotiate_works() { let mut local_protocols = vec![ ProtocolName::from("/13371338/proto/1"), ProtocolName::from("/sup/proto/1"), @@ -413,7 +414,7 @@ mod tests { ProtocolName::from("/13371338/proto/3"), ProtocolName::from("/13371338/proto/4"), ]; - let message = encode_multistream_message( + let message = webrtc_encode_multistream_message( vec![ Message::Protocol(Protocol::try_from(&b"/13371338/proto/1"[..]).unwrap()), Message::Protocol(Protocol::try_from(&b"/sup/proto/1"[..]).unwrap()), @@ -423,7 +424,7 @@ mod tests { .unwrap() .freeze(); - match listener_negotiate(&mut local_protocols.iter(), message) { + match webrtc_listener_negotiate(&mut local_protocols.iter(), message) { Err(error) => panic!("error received: {error:?}"), Ok(ListenerSelectResult::Rejected { .. }) => panic!("message rejected"), Ok(ListenerSelectResult::Accepted { protocol, message }) => { @@ -441,14 +442,14 @@ mod tests { ProtocolName::from("/13371338/proto/3"), ProtocolName::from("/13371338/proto/4"), ]; - let message = encode_multistream_message(std::iter::once(Message::Protocols(vec![ + let message = webrtc_encode_multistream_message(std::iter::once(Message::Protocols(vec![ Protocol::try_from(&b"/13371338/proto/1"[..]).unwrap(), Protocol::try_from(&b"/sup/proto/1"[..]).unwrap(), ]))) .unwrap() .freeze(); - match listener_negotiate(&mut local_protocols.iter(), message) { + match webrtc_listener_negotiate(&mut local_protocols.iter(), message) { Err(error) => assert!(std::matches!(error, Error::InvalidData)), _ => panic!("invalid event"), } @@ -469,7 +470,7 @@ mod tests { let message = Message::Header(HeaderLine::V1); let _ = message.encode(&mut bytes).map_err(|_| Error::InvalidData).unwrap(); - match listener_negotiate(&mut local_protocols.iter(), bytes.freeze()) { + match webrtc_listener_negotiate(&mut local_protocols.iter(), bytes.freeze()) { Err(error) => assert!(std::matches!( error, Error::NegotiationError(error::NegotiationError::MultistreamSelectError( @@ -498,7 +499,7 @@ mod tests { ]); let _ = message.encode(&mut bytes).map_err(|_| Error::InvalidData).unwrap(); - match listener_negotiate(&mut local_protocols.iter(), bytes.freeze()) { + match webrtc_listener_negotiate(&mut local_protocols.iter(), bytes.freeze()) { Err(error) => assert!(std::matches!( error, Error::NegotiationError(error::NegotiationError::MultistreamSelectError( @@ -518,7 +519,7 @@ mod tests { ProtocolName::from("/13371338/proto/3"), ProtocolName::from("/13371338/proto/4"), ]; - let message = encode_multistream_message( + let message = webrtc_encode_multistream_message( vec![Message::Protocol( Protocol::try_from(&b"/13371339/proto/1"[..]).unwrap(), )] @@ -527,12 +528,13 @@ mod tests { .unwrap() .freeze(); - match listener_negotiate(&mut local_protocols.iter(), message) { + match webrtc_listener_negotiate(&mut local_protocols.iter(), message) { Err(error) => panic!("error received: {error:?}"), Ok(ListenerSelectResult::Rejected { message }) => { assert_eq!( message, - encode_multistream_message(std::iter::once(Message::NotAvailable)).unwrap() + webrtc_encode_multistream_message(std::iter::once(Message::NotAvailable)) + .unwrap() ); } Ok(ListenerSelectResult::Accepted { protocol, message }) => panic!("message accepted"), diff --git a/src/multistream_select/mod.rs b/src/multistream_select/mod.rs index 86abe026..b28093b3 100644 --- a/src/multistream_select/mod.rs +++ b/src/multistream_select/mod.rs @@ -76,9 +76,10 @@ mod negotiated; mod protocol; pub use crate::multistream_select::{ - dialer_select::{dialer_select_proto, DialerSelectFuture, DialerState, HandshakeResult}, + dialer_select::{dialer_select_proto, DialerSelectFuture, HandshakeResult, WebRtcDialerState}, listener_select::{ - listener_negotiate, listener_select_proto, ListenerSelectFuture, ListenerSelectResult, + listener_select_proto, webrtc_listener_negotiate, ListenerSelectFuture, + ListenerSelectResult, }, negotiated::{Negotiated, NegotiatedComplete, NegotiationError}, protocol::{HeaderLine, Message, Protocol, ProtocolError}, diff --git a/src/multistream_select/negotiated.rs b/src/multistream_select/negotiated.rs index 2a08f29c..ee701e74 100644 --- a/src/multistream_select/negotiated.rs +++ b/src/multistream_select/negotiated.rs @@ -323,15 +323,22 @@ where } fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - // Ensure all data has been flushed and expected negotiation messages - // have been received. - ready!(self.as_mut().poll(cx).map_err(Into::::into)?); + // Ensure all data has been flushed, including optimistic multistream-select messages. ready!(self.as_mut().poll_flush(cx).map_err(Into::::into)?); // Continue with the shutdown of the underlying I/O stream. match self.project().state.project() { StateProj::Completed { io, .. } => io.poll_close(cx), - StateProj::Expecting { io, .. } => io.poll_close(cx), + StateProj::Expecting { io, .. } => { + let close_poll = io.poll_close(cx); + if let Poll::Ready(Ok(())) = close_poll { + tracing::debug!( + target: LOG_TARGET, + "Stream closed. Confirmation from remote for optimstic protocol negotiation still pending." + ); + } + close_poll + } StateProj::Invalid => panic!("Negotiated: Invalid state"), } } diff --git a/src/multistream_select/protocol.rs b/src/multistream_select/protocol.rs index 694745a6..b24b31da 100644 --- a/src/multistream_select/protocol.rs +++ b/src/multistream_select/protocol.rs @@ -201,8 +201,7 @@ impl Message { let mut remaining: &[u8] = &msg; loop { // A well-formed message must be terminated with a newline. - // TODO: don't do this - if remaining == [b'\n'] || remaining.is_empty() { + if remaining == [b'\n'] { break; } else if protocols.len() == MAX_PROTOCOLS { return Err(ProtocolError::TooManyProtocols); @@ -228,7 +227,12 @@ impl Message { } /// Create `multistream-select` message from an iterator of `Message`s. -pub fn encode_multistream_message( +/// +/// # Note +/// +/// This is implementation is not compliant with the multistream-select protocol spec. +/// The only purpose of this was to get the `multistream-select` protocol working with smoldot. +pub fn webrtc_encode_multistream_message( messages: impl IntoIterator, ) -> crate::Result { // encode `/multistream-select/1.0.0` header @@ -245,6 +249,9 @@ pub fn encode_multistream_message( header.append(&mut proto_bytes); } + // For the `Message::Protocols` to be interpreted correctly, it must be followed by a newline. + header.push(b'\n'); + Ok(BytesMut::from(&header[..])) } @@ -468,3 +475,71 @@ impl From for ProtocolError { Self::from(io::Error::new(io::ErrorKind::InvalidData, err.to_string())) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_decode_main_messages() { + // Decode main messages. + let bytes = Bytes::from_static(MSG_MULTISTREAM_1_0); + assert_eq!( + Message::decode(bytes).unwrap(), + Message::Header(HeaderLine::V1) + ); + + let bytes = Bytes::from_static(MSG_PROTOCOL_NA); + assert_eq!(Message::decode(bytes).unwrap(), Message::NotAvailable); + + let bytes = Bytes::from_static(MSG_LS); + assert_eq!(Message::decode(bytes).unwrap(), Message::ListProtocols); + } + + #[test] + fn test_decode_empty_message() { + // Empty message should decode to an IoError, not Header::Protocols. + let bytes = Bytes::from_static(b""); + match Message::decode(bytes).unwrap_err() { + ProtocolError::IoError(io) => assert_eq!(io.kind(), io::ErrorKind::InvalidData), + err => panic!("Unexpected error: {:?}", err), + }; + } + + #[test] + fn test_decode_protocols() { + // Single protocol. + let bytes = Bytes::from_static(b"/protocol-v1\n"); + assert_eq!( + Message::decode(bytes).unwrap(), + Message::Protocol(Protocol::try_from(Bytes::from_static(b"/protocol-v1")).unwrap()) + ); + + // Multiple protocols. + let expected = Message::Protocols(vec![ + Protocol::try_from(Bytes::from_static(b"/protocol-v1")).unwrap(), + Protocol::try_from(Bytes::from_static(b"/protocol-v2")).unwrap(), + ]); + let mut encoded = BytesMut::new(); + expected.encode(&mut encoded).unwrap(); + + // `\r` is the length of the protocol names. + let bytes = Bytes::from_static(b"\r/protocol-v1\n\r/protocol-v2\n\n"); + assert_eq!(encoded, bytes); + + assert_eq!( + Message::decode(bytes).unwrap(), + Message::Protocols(vec![ + Protocol::try_from(Bytes::from_static(b"/protocol-v1")).unwrap(), + Protocol::try_from(Bytes::from_static(b"/protocol-v2")).unwrap(), + ]) + ); + + // Check invalid length. + let bytes = Bytes::from_static(b"\r/v1\n\n"); + assert_eq!( + Message::decode(bytes).unwrap_err(), + ProtocolError::InvalidMessage + ); + } +} diff --git a/src/transport/webrtc/connection.rs b/src/transport/webrtc/connection.rs index 52ceb048..82f4b60b 100644 --- a/src/transport/webrtc/connection.rs +++ b/src/transport/webrtc/connection.rs @@ -20,7 +20,9 @@ use crate::{ error::{Error, ParseError, SubstreamError}, - multistream_select::{listener_negotiate, DialerState, HandshakeResult, ListenerSelectResult}, + multistream_select::{ + webrtc_listener_negotiate, HandshakeResult, ListenerSelectResult, WebRtcDialerState, + }, protocol::{Direction, Permit, ProtocolCommand, ProtocolSet}, substream::Substream, transport::{ @@ -147,7 +149,7 @@ enum ChannelState { context: ChannelContext, /// `multistream-select` dialer state. - dialer_state: DialerState, + dialer_state: WebRtcDialerState, }, /// Channel is open. @@ -260,7 +262,7 @@ impl WebRtcConnection { let fallback_names = std::mem::take(&mut context.fallback_names); let (dialer_state, message) = - DialerState::propose(context.protocol.clone(), fallback_names)?; + WebRtcDialerState::propose(context.protocol.clone(), fallback_names)?; let message = WebRtcMessage::encode(message); self.rtc @@ -317,11 +319,13 @@ impl WebRtcConnection { ); let payload = WebRtcMessage::decode(&data)?.payload.ok_or(Error::InvalidData)?; - let (response, negotiated) = - match listener_negotiate(&mut self.protocol_set.protocols().iter(), payload.into())? { - ListenerSelectResult::Accepted { protocol, message } => (message, Some(protocol)), - ListenerSelectResult::Rejected { message } => (message, None), - }; + let (response, negotiated) = match webrtc_listener_negotiate( + &mut self.protocol_set.protocols().iter(), + payload.into(), + )? { + ListenerSelectResult::Accepted { protocol, message } => (message, Some(protocol)), + ListenerSelectResult::Rejected { message } => (message, None), + }; self.rtc .channel(channel_id) @@ -371,7 +375,7 @@ impl WebRtcConnection { &mut self, channel_id: ChannelId, data: Vec, - mut dialer_state: DialerState, + mut dialer_state: WebRtcDialerState, context: ChannelContext, ) -> Result, SubstreamError> { tracing::trace!(