diff --git a/crates/scion/Cargo.toml b/crates/scion/Cargo.toml index 7064cba..d4b5d32 100644 --- a/crates/scion/Cargo.toml +++ b/crates/scion/Cargo.toml @@ -5,8 +5,6 @@ edition = "2021" publish = false [dependencies] -serde = { version = "1.0.188", features = ["derive"] } -thiserror = "1.0.48" bytes = "1.4.0" serde = { version = "1.0.188", features = ["derive"] } thiserror = "1.0.48" diff --git a/crates/scion/src/address/host.rs b/crates/scion/src/address/host.rs index a63ac41..f59853d 100644 --- a/crates/scion/src/address/host.rs +++ b/crates/scion/src/address/host.rs @@ -45,6 +45,12 @@ impl HostType { } } +impl From for u8 { + fn from(value: HostType) -> Self { + value as u8 + } +} + /// Trait to be implemented by address types that are supported by SCION /// as valid AS-host addresses. pub trait HostAddress { diff --git a/crates/scion/src/address/service.rs b/crates/scion/src/address/service.rs index b038c84..b918040 100644 --- a/crates/scion/src/address/service.rs +++ b/crates/scion/src/address/service.rs @@ -30,7 +30,7 @@ impl ServiceAddress { #[allow(unused)] /// Special none service address value. - const NONE: Self = Self(0xffff); + pub(crate) const NONE: Self = Self(0xffff); /// Flag bit indicating whether the address includes multicast const MULTICAST_FLAG: u16 = 0x8000; @@ -103,6 +103,12 @@ impl FromStr for ServiceAddress { } } +impl From for u16 { + fn from(value: ServiceAddress) -> Self { + value.0 + } +} + impl Display for ServiceAddress { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self.anycast() { diff --git a/crates/scion/src/lib.rs b/crates/scion/src/lib.rs index b3d6efc..b1b7012 100644 --- a/crates/scion/src/lib.rs +++ b/crates/scion/src/lib.rs @@ -1,3 +1,6 @@ pub mod address; pub mod daemon; pub mod reliable; + +#[cfg(test)] +pub(crate) mod test_utils; diff --git a/crates/scion/src/reliable.rs b/crates/scion/src/reliable.rs index a186593..f3db362 100644 --- a/crates/scion/src/reliable.rs +++ b/crates/scion/src/reliable.rs @@ -1,13 +1,12 @@ mod error; -mod host_address; -mod message; -mod registration; -mod relay_protocol; - pub use error::ReliableRelayError; + +mod relay_protocol; pub use relay_protocol::ReliableRelayProtocol; -const IPV4_OCTETS: usize = 4; -const IPV6_OCTETS: usize = 16; -const LAYER4_PORT_OCTETS: usize = 2; +mod common_header; +mod parser; +mod registration; +mod wire_utils; + const ADDRESS_TYPE_OCTETS: usize = 1; diff --git a/crates/scion/src/reliable/common_header.rs b/crates/scion/src/reliable/common_header.rs new file mode 100644 index 0000000..a5305d4 --- /dev/null +++ b/crates/scion/src/reliable/common_header.rs @@ -0,0 +1,368 @@ +use std::net::{IpAddr, SocketAddr}; + +use bytes::{Buf, BufMut}; +use thiserror::Error; + +use super::{ + wire_utils::{encoded_address_and_port_length, IPV6_OCTETS, LAYER4_PORT_OCTETS}, + ADDRESS_TYPE_OCTETS, +}; +use crate::address::{HostAddress, HostType}; + +#[derive(Error, Debug, Eq, PartialEq)] +pub enum DecodeError { + #[error("received an invalid cookie when decoding header: {0:x}")] + InvalidCookie(u64), + // TODO: Document that type SVC is not supported for this + #[error("invalid or unsupported reliable relay address type: {0}")] + InvalidAddressType(u8), +} + +pub(super) enum DecodedHeader { + Partial(PartialHeader), + Full(CommonHeader), +} + +impl DecodedHeader { + pub fn is_fully_decoded(&self) -> bool { + matches!(self, DecodedHeader::Full(..)) + } +} + +#[derive(Copy, Clone)] +pub(super) struct PartialHeader { + pub host_type: HostType, + pub payload_length: u32, +} + +impl PartialHeader { + /// # Panics + /// + /// Panics if there is not at least CommonHeader::MIN_LENGTH bytes available + /// in the buffer. + fn decode(buffer: &mut impl Buf) -> Result { + assert!( + buffer.remaining() >= CommonHeader::MIN_LENGTH, + "insufficient data" + ); + + let cookie = buffer.get_u64(); + if cookie != CommonHeader::COOKIE { + return Err(DecodeError::InvalidCookie(cookie)); + } + + let host_type = buffer.get_u8(); + let host_type = match HostType::from_byte(host_type) { + None | Some(HostType::Svc) => return Err(DecodeError::InvalidAddressType(host_type)), + Some(address_type) => address_type, + }; + + Ok(Self { + host_type, + payload_length: buffer.get_u32(), + }) + } + + pub fn required_bytes(&self) -> usize { + encoded_address_and_port_length(self.host_type) + } + + /// Attempt to finish decoding of the common header. + /// + /// # Panics + /// + /// Panics if there is not at least self.required_bytes() available in the buffer. + pub fn finish_decoding(self, buffer: &mut impl Buf) -> CommonHeader { + assert!( + buffer.remaining() >= self.required_bytes(), + "insufficient data" + ); + + let PartialHeader { + host_type, + payload_length, + } = self; + + let destination = match host_type { + HostType::None => None, + HostType::Ipv4 | HostType::Ipv6 => { + let ip_address = if host_type == HostType::Ipv4 { + IpAddr::V4(buffer.get_u32().into()) + } else { + IpAddr::V6(buffer.get_u128().into()) + }; + Some(SocketAddr::new(ip_address, buffer.get_u16())) + } + HostType::Svc => unreachable!(), + }; + + CommonHeader { + destination, + payload_length, + } + } +} + +#[derive(Default, Debug, Copy, Clone)] +pub(super) struct CommonHeader { + pub destination: Option, + pub payload_length: u32, +} + +impl CommonHeader { + pub const MIN_LENGTH: usize = + Self::COOKIE_LENGTH + ADDRESS_TYPE_OCTETS + Self::PAYLOAD_SIZE_LENGTH; + pub const MAX_LENGTH: usize = Self::MIN_LENGTH + IPV6_OCTETS + LAYER4_PORT_OCTETS; + + const COOKIE: u64 = 0xde00ad01be02ef03; + const COOKIE_LENGTH: usize = 8; + const PAYLOAD_SIZE_LENGTH: usize = 4; + + pub fn new() -> Self { + Self::default() + } + + #[inline] + pub fn payload_size(&self) -> usize { + self.payload_length + .try_into() + .expect("at least 32-bit architecture") + } + + pub fn encoded_length(&self) -> usize { + Self::COOKIE_LENGTH + + ADDRESS_TYPE_OCTETS + + Self::PAYLOAD_SIZE_LENGTH + + encoded_address_and_port_length(self.destination.host_address_type()) + } + + /// Serialize a common header to the provided buffer. + /// + /// The resulting header is suitable for being written to the network + /// ahead of the payload. + /// + /// # Panic + /// + /// Panics if the the buffer is too small to contain the header. + pub fn encode_to(&self, buffer: &mut impl BufMut) { + let initial_remaining = buffer.remaining_mut(); + + buffer.put_u64(Self::COOKIE); + buffer.put_u8(self.destination.host_address_type().into()); + buffer.put_u32(self.payload_length); + + if let Some(destination) = self.destination.as_ref() { + match destination.ip() { + IpAddr::V4(ipv4) => buffer.put(ipv4.octets().as_slice()), + IpAddr::V6(ipv6) => buffer.put(ipv6.octets().as_slice()), + } + + buffer.put_u16(destination.port()); + } + + let bytes_consumed = initial_remaining - buffer.remaining_mut(); + assert_eq!(bytes_consumed, self.encoded_length()); + } + + /// # Panics + /// + /// Panics if there is not at least CommonHeader::MIN_LENGTH bytes available + /// in the buffer. + pub fn partial_decode(buffer: &mut impl Buf) -> Result { + assert!( + buffer.remaining() >= CommonHeader::MIN_LENGTH, + "insufficient data" + ); + + let partial_header = PartialHeader::decode(buffer)?; + + if buffer.remaining() >= partial_header.required_bytes() { + Ok(DecodedHeader::Full(partial_header.finish_decoding(buffer))) + } else { + Ok(DecodedHeader::Partial(partial_header)) + } + } + + /// # Panics + /// + /// Panics if there is insufficient data in the buffer to decode the entire header. + /// To avoid the panic, ensure there is [`Self::MAX_LENGTH`] bytes available, or + /// use [`Self::partial_decode()`] instead. + pub fn decode(buffer: &mut impl Buf) -> Result { + PartialHeader::decode(buffer).map(|header| header.finish_decoding(buffer)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + mod encode { + use std::str::FromStr; + + use bytes::BytesMut; + + use super::*; + + macro_rules! test_successful_encode { + ($name:ident, $optional_address:expr, $payload_length:expr, $expected_bytes:expr) => { + #[test] + fn $name() { + let address = $optional_address + .map(|addr_str| SocketAddr::from_str(addr_str).expect("valid address")); + + let mut buffer = BytesMut::new(); + CommonHeader { + destination: address, + payload_length: $payload_length, + } + .encode_to(&mut buffer); + + assert_eq!(buffer.as_ref(), $expected_bytes); + } + }; + } + + test_successful_encode!( + no_address_or_data, + None, + 0, + [0xde, 0, 0xad, 1, 0xbe, 2, 0xef, 3, 0, 0, 0, 0, 0] + ); + + test_successful_encode!( + ipv4_no_data, + Some("10.2.3.4:80"), + 0, + [0xde, 0, 0xad, 1, 0xbe, 2, 0xef, 3, 1, 0, 0, 0, 0, 10, 2, 3, 4, 0, 80] + ); + + test_successful_encode!( + ipv6_no_data, + Some("[2001:db8::1]:80"), + 0, + [ + 0xde, 0, 0xad, 1, 0xbe, 2, 0xef, 3, 2, 0, 0, 0, 0, 0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 80 + ] + ); + + test_successful_encode!( + ipv4_big_port_no_data, + Some("10.2.3.4:65534"), + 0, + [0xde, 0, 0xad, 1, 0xbe, 2, 0xef, 3, 1, 0, 0, 0, 0, 10, 2, 3, 4, 0xff, 0xfe] + ); + + test_successful_encode!( + ipv4_good_payload, + Some("127.0.0.1:22"), + 4, + [0xde, 0, 0xad, 1, 0xbe, 2, 0xef, 3, 1, 0, 0, 0, 4, 127, 0, 0, 1, 0, 22] + ); + + test_successful_encode!( + max_payload_length, + Some("127.0.0.2:88"), + u32::MAX, + [0xde, 0, 0xad, 1, 0xbe, 2, 0xef, 3, 1, 0xff, 0xff, 0xff, 0xff, 127, 0, 0, 2, 0, 88] + ); + } + + mod decode { + use bytes::Bytes; + + use super::*; + + macro_rules! test_decode_error { + ($name:ident, $buffer:expr, $expected_error:expr) => { + #[test] + fn $name() { + let mut buffer = Bytes::copy_from_slice($buffer.as_slice()); + assert_eq!( + CommonHeader::decode(&mut buffer).expect_err("expected invalid data"), + $expected_error + ); + } + }; + } + + test_decode_error!( + invalid_cookie, + [0xaa_u8, 0xbb, 0xaa, 0xbb, 0xaa, 0xbb, 0xaa, 0xbb, 0, 0, 0, 0, 0], + DecodeError::InvalidCookie(0xaabbaabbaabbaabb) + ); + + test_decode_error!( + invalid_address_type, + [0xde, 0, 0xad, 1, 0xbe, 2, 0xef, 3, 3, 0, 0, 0, 0], + DecodeError::InvalidAddressType(3) + ); + + macro_rules! test_decode_panic { + ($name:ident, $buffer:expr) => { + #[test] + #[should_panic(expected = "insufficient data")] + fn $name() { + let mut buffer = Bytes::copy_from_slice($buffer.as_slice()); + let _ = CommonHeader::decode(&mut buffer); + } + }; + } + + test_decode_panic!(incomplete_header, [0xaa]); + + test_decode_panic!( + incomplete_header_address, + [0xde, 0, 0xad, 1, 0xbe, 2, 0xef, 3, 1, 0, 0, 0, 0, 10, 2, 3] + ); + + test_decode_panic!( + incomplete_header_port, + [0xde, 0, 0xad, 1, 0xbe, 2, 0xef, 3, 1, 0, 0, 0, 0, 10, 2, 3, 4, 0] + ); + + macro_rules! test_successful_decode { + ($name:ident, $buffer:expr, $expected_header:expr) => { + #[test] + fn $name() { + let mut buffer = Bytes::copy_from_slice($buffer.as_slice()); + let header = CommonHeader::decode(&mut buffer).unwrap(); + + assert_eq!(header.destination, $expected_header.destination); + assert_eq!(header.payload_length, $expected_header.payload_length); + } + }; + } + + test_successful_decode!( + valid_no_address, + [0xde, 0, 0xad, 1, 0xbe, 2, 0xef, 3, 0, 0, 0, 0, 1, 42], + CommonHeader { + destination: None, + payload_length: 1 + } + ); + + test_successful_decode!( + valid_with_ipv4, + [0xde, 0, 0xad, 1, 0xbe, 2, 0xef, 3, 1, 0, 0, 0, 1, 10, 2, 3, 4, 0, 80, 42], + CommonHeader { + destination: Some("10.2.3.4:80".parse().unwrap()), + payload_length: 1, + } + ); + + test_successful_decode!( + valid_with_ipv6, + [ + 0xde, 0, 0xad, 1, 0xbe, 2, 0xef, 3, 2, 0, 0, 0, 1, 0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 80, 42 + ], + CommonHeader { + destination: Some("[2001:db8::1]:80".parse().unwrap()), + payload_length: 1, + } + ); + } +} diff --git a/crates/scion/src/reliable/error.rs b/crates/scion/src/reliable/error.rs index 96fb6b1..cf18690 100644 --- a/crates/scion/src/reliable/error.rs +++ b/crates/scion/src/reliable/error.rs @@ -6,15 +6,6 @@ pub enum ReliableRelayError { DestinationUnspecified, #[error("provided destination port mmust be specified")] DestinationPortUnspecified, - #[error("payload size too large ({0}), should be at most {}", u32::MAX)] - PayloadTooLarge(usize), - #[error("received an invalid cookie when decoding header: {0:x}")] - InvalidCookie(u64), - // TODO: Document that type SVC is not supported for this - #[error("invalid or unsupported reliable relay address type: {0}")] - InvalidAddressType(u8), - #[error("incomplete header, wanted {wanted} more bytes, found {found} bytes")] - IncompleteHeader { wanted: usize, found: usize }, - #[error("invalid payload length, wanted {wanted} more bytes, found {found} bytes")] - InvalidPayloadLength { wanted: usize, found: usize }, + #[error("port mismatch, requested port {requested}, received port {assigned}")] + PortMismatch { requested: u16, assigned: u16 }, } diff --git a/crates/scion/src/reliable/host_address.rs b/crates/scion/src/reliable/host_address.rs deleted file mode 100644 index 8b34864..0000000 --- a/crates/scion/src/reliable/host_address.rs +++ /dev/null @@ -1,53 +0,0 @@ -#![allow(dead_code)] -use std::net::{IpAddr, SocketAddr}; - -#[repr(u8)] -#[derive(Eq, PartialEq, Clone, Copy)] -pub enum HostAddressType { - None = 0, - IPv4, - IPv6, - Svc, -} - -impl HostAddressType { - pub fn from_byte(byte: u8) -> Option { - match byte { - 0 => Some(HostAddressType::None), - 1 => Some(HostAddressType::IPv4), - 2 => Some(HostAddressType::IPv6), - 3 => Some(HostAddressType::Svc), - _ => None, - } - } -} - -pub trait HostAddress { - fn host_address_type(&self) -> HostAddressType; -} - -impl HostAddress for SocketAddr { - fn host_address_type(&self) -> HostAddressType { - self.ip().host_address_type() - } -} - -impl HostAddress for IpAddr { - fn host_address_type(&self) -> HostAddressType { - match self { - IpAddr::V4(_) => HostAddressType::IPv4, - IpAddr::V6(_) => HostAddressType::IPv6, - } - } -} - -impl HostAddress for Option -where - T: HostAddress, -{ - fn host_address_type(&self) -> HostAddressType { - self.as_ref() - .map(HostAddress::host_address_type) - .unwrap_or(HostAddressType::None) - } -} diff --git a/crates/scion/src/reliable/message.rs b/crates/scion/src/reliable/message.rs deleted file mode 100644 index 18f6cc9..0000000 --- a/crates/scion/src/reliable/message.rs +++ /dev/null @@ -1,434 +0,0 @@ -use std::net::{IpAddr, SocketAddr}; - -use bytes::{buf::Chain, Buf, BufMut, Bytes}; - -use super::{ - error::ReliableRelayError, - host_address::HostAddress, - ADDRESS_TYPE_OCTETS, - IPV6_OCTETS, -}; -use crate::reliable::{host_address::HostAddressType, IPV4_OCTETS, LAYER4_PORT_OCTETS}; - -const COOKIE_LENGTH: usize = 8; -const PAYLOAD_SIZE_LENGTH: usize = 4; -const COOKIE: u64 = 0xde00ad01be02ef03; - -#[derive(Debug)] -pub(super) struct RelayMessage { - destination: Option, - payload: Bytes, -} - -impl RelayMessage { - pub fn new( - payload: Bytes, - destination: Option, - ) -> Result { - if let Some(address) = destination.as_ref() { - if address.ip().is_unspecified() { - return Err(ReliableRelayError::DestinationUnspecified); - } - if address.port() == 0 { - return Err(ReliableRelayError::DestinationPortUnspecified); - } - } - - if u32::try_from(payload.len()).is_err() { - return Err(ReliableRelayError::PayloadTooLarge(payload.len())); - } - - Ok(Self { - destination, - payload, - }) - } - - /// Serialize the header of the relay message to the provided buffer, - /// and return the number of bytes written. - /// - /// The resulting header is suitable for being written to the network - /// ahead of the payload. - /// - /// # Panic - /// - /// Panics if the the buffer is too small to contain the header. - pub fn serialize_header_to(&self, buffer: &mut B) -> usize { - buffer.put_u64(COOKIE); - buffer.put_u8(self.destination.host_address_type() as u8); - buffer.put_u32( - self.payload - .len() - .try_into() - .expect("payload length should be at most 2^32-1"), - ); - if self.destination.is_some() { - let destination = self.destination.as_ref().unwrap(); - - match destination.ip() { - IpAddr::V4(ipv4) => buffer.put(ipv4.octets().as_slice()), - IpAddr::V6(ipv6) => buffer.put(ipv6.octets().as_slice()), - } - - buffer.put_u16(destination.port()); - } - - self.serialized_header_length() - } - - /// Allocates a buffer and serializes the relay header to it. - /// - /// The resulting header is suitable for being written to the network - /// ahead of the payload. - pub fn serialize_header(&self) -> Bytes { - let mut buffer = Vec::with_capacity(self.serialized_header_length()); - self.serialize_header_to(&mut buffer); - buffer.into() - } - - /// Consume and serialize the message. - // FIXME(jsmith): Use of Chain payloads of max length may panic on 32-bit systems. - // If the length of the payload is usize::MAX <= u32::MAX, then later calls to Chain::remaining - // panics. See https://docs.rs/bytes/latest/src/bytes/buf/chain.rs.html#137 - pub fn serialize(self) -> Chain { - let header = self.serialize_header(); - header.chain(self.payload) - } - - /// Returns the total length of the relay message header. - pub fn serialized_header_length(&self) -> usize { - COOKIE_LENGTH - + ADDRESS_TYPE_OCTETS - + PAYLOAD_SIZE_LENGTH - + serialized_address_length(self.destination.host_address_type()) - } - - pub fn deserialize(mut buffer: Bytes) -> Result { - use ReliableRelayError::*; - - check_remaining( - &buffer, - COOKIE_LENGTH + ADDRESS_TYPE_OCTETS + PAYLOAD_SIZE_LENGTH, - )?; - - let cookie = buffer.get_u64(); - if cookie != COOKIE { - return Err(InvalidCookie(cookie)); - } - - let address_type = buffer.get_u8(); - let address_type = match HostAddressType::from_byte(address_type) { - None | Some(HostAddressType::Svc) => return Err(InvalidAddressType(address_type)), - Some(address_type) => address_type, - }; - - let payload_length = usize::try_from(buffer.get_u32()).expect("at least u32 usizes"); - - check_remaining(&buffer, serialized_address_length(address_type))?; - let destination = match address_type { - HostAddressType::None => None, - HostAddressType::IPv4 => { - let mut ipv4_octects = [0u8; 4]; - buffer.copy_to_slice(&mut ipv4_octects); - Some(SocketAddr::new(ipv4_octects.into(), buffer.get_u16())) - } - HostAddressType::IPv6 => { - let mut ipv6_octects = [0u8; 16]; - buffer.copy_to_slice(&mut ipv6_octects); - Some(SocketAddr::new(ipv6_octects.into(), buffer.get_u16())) - } - HostAddressType::Svc => unreachable!(), - }; - - if buffer.remaining() != payload_length { - return Err(InvalidPayloadLength { - wanted: payload_length, - found: buffer.remaining(), - }); - } - - Self::new(buffer, destination) - } -} - -fn serialized_address_length(address_type: HostAddressType) -> usize { - match address_type { - HostAddressType::None => 0, - HostAddressType::IPv4 => IPV4_OCTETS + LAYER4_PORT_OCTETS, - HostAddressType::IPv6 => IPV6_OCTETS + LAYER4_PORT_OCTETS, - HostAddressType::Svc => panic!("does not accept SVC host address type"), - } -} - -fn check_remaining(buffer: &impl Buf, required_bytes: usize) -> Result<(), ReliableRelayError> { - if buffer.remaining() < required_bytes { - Err(ReliableRelayError::IncompleteHeader { - wanted: required_bytes, - found: buffer.remaining(), - }) - } else { - Ok(()) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - mod serialize { - use std::str::FromStr; - - use super::*; - - macro_rules! test_successful_serialize { - ($name:ident, $optional_address:expr, $optional_data:expr, $expected_bytes:expr) => { - #[test] - fn $name() { - let address = $optional_address - .map(|addr_str| SocketAddr::from_str(addr_str).expect("valid address")); - - let optional_data: Option<&[u8]> = $optional_data; - let payload: Bytes = - optional_data.map_or(Bytes::new(), |arr| arr.iter().cloned().collect()); - - let message = RelayMessage::new(payload, address).expect("valid message"); - let message_bytes: Vec = message.serialize().into_iter().collect(); - - assert_eq!(message_bytes, $expected_bytes); - } - }; - } - - test_successful_serialize!( - no_address_or_data, - None, - None, - [0xde, 0, 0xad, 1, 0xbe, 2, 0xef, 3, 0, 0, 0, 0, 0] - ); - - test_successful_serialize!( - ipv4_no_data, - Some("10.2.3.4:80"), - None, - [0xde, 0, 0xad, 1, 0xbe, 2, 0xef, 3, 1, 0, 0, 0, 0, 10, 2, 3, 4, 0, 80] - ); - - test_successful_serialize!( - ipv6_no_data, - Some("[2001:db8::1]:80"), - None, - [ - 0xde, 0, 0xad, 1, 0xbe, 2, 0xef, 3, 2, 0, 0, 0, 0, 0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 80 - ] - ); - - test_successful_serialize!( - ipv4_big_port_no_data, - Some("10.2.3.4:65534"), - None, - [0xde, 0, 0xad, 1, 0xbe, 2, 0xef, 3, 1, 0, 0, 0, 0, 10, 2, 3, 4, 0xff, 0xfe] - ); - - test_successful_serialize!( - ipv4_good_payload, - Some("127.0.0.1:22"), - Some(&[10, 5, 6, 7]), - [0xde, 0, 0xad, 1, 0xbe, 2, 0xef, 3, 1, 0, 0, 0, 4, 127, 0, 0, 1, 0, 22, 10, 5, 6, 7] - ); - - #[test] - fn max_data() { - let payload = Bytes::from(vec![42; u32::MAX as usize]); - let expected_header: Vec = vec![ - 0xde, 0, 0xad, 1, 0xbe, 2, 0xef, 3, 1, 0xff, 0xff, 0xff, 0xff, 127, 0, 0, 2, 0, 88, - ]; - - let message = RelayMessage::new(payload, Some("127.0.0.2:88".parse().unwrap())) - .expect("valid message"); - let serialized_message = message.serialize(); - - assert_eq!( - serialized_message.remaining(), - expected_header.len() + (u32::MAX as usize) - ); - - let header: Vec = serialized_message - .into_iter() - .take(expected_header.len()) - .collect(); - - assert_eq!(header, expected_header); - } - - macro_rules! test_invalid_message { - ($name:ident, $optional_address:expr, $optional_data:expr, $expected_error:expr) => { - #[test] - fn $name() { - let address = $optional_address - .map(|addr_str| SocketAddr::from_str(addr_str).expect("valid address")); - - let optional_data: Option<&[u8]> = $optional_data; - let payload: Bytes = - optional_data.map_or(Bytes::new(), |arr| arr.iter().cloned().collect()); - - let error = - RelayMessage::new(payload, address).expect_err("expected invalid message"); - - assert_eq!(error, $expected_error); - } - }; - } - - test_invalid_message!( - unspecified_ipv4, - Some("0.0.0.0:80"), - None, - ReliableRelayError::DestinationUnspecified - ); - - test_invalid_message!( - unspecified_ipv6, - Some("[::0]:443"), - None, - ReliableRelayError::DestinationUnspecified - ); - - test_invalid_message!( - unspecified_port, - Some("10.0.0.1:0"), - None, - ReliableRelayError::DestinationPortUnspecified - ); - - #[test] - fn payload_too_large() { - let payload_length = 1 + (u32::MAX as usize); - let payload = Bytes::from(vec![42; payload_length]); - - let error = RelayMessage::new(payload, Some("127.0.0.2:88".parse().unwrap())) - .expect_err("expected invalid message"); - - assert_eq!(error, ReliableRelayError::PayloadTooLarge(payload_length)); - } - } - - mod deserialize { - use super::*; - use crate::reliable::ReliableRelayError; - - macro_rules! test_deserialize_error { - ($name:ident, $buffer:expr, $expected_error:expr) => { - #[test] - fn $name() { - let buffer = Bytes::copy_from_slice($buffer.as_slice()); - assert_eq!( - RelayMessage::deserialize(buffer).expect_err("expected invalid data"), - $expected_error - ); - } - }; - } - - test_deserialize_error!( - invalid_cookie, - [0xaa_u8, 0xbb, 0xaa, 0xbb, 0xaa, 0xbb, 0xaa, 0xbb, 0, 0, 0, 0, 0], - ReliableRelayError::InvalidCookie(0xaabbaabbaabbaabb) - ); - - test_deserialize_error!( - invalid_address_type, - [0xde, 0, 0xad, 1, 0xbe, 2, 0xef, 3, 3, 0, 0, 0, 0], - ReliableRelayError::InvalidAddressType(3) - ); - - test_deserialize_error!( - incomplete_header, - [0xaa], - ReliableRelayError::IncompleteHeader { - wanted: 13, - found: 1 - } - ); - - test_deserialize_error!( - incomplete_header_address, - [0xde, 0, 0xad, 1, 0xbe, 2, 0xef, 3, 1, 0, 0, 0, 0, 10, 2, 3], - ReliableRelayError::IncompleteHeader { - wanted: 6, - found: 3 - } - ); - - test_deserialize_error!( - incomplete_header_port, - [0xde, 0, 0xad, 1, 0xbe, 2, 0xef, 3, 1, 0, 0, 0, 0, 10, 2, 3, 4, 0], - ReliableRelayError::IncompleteHeader { - wanted: 6, - found: 5 - } - ); - - test_deserialize_error!( - payload_too_long, - [0xde, 0, 0xad, 1, 0xbe, 2, 0xef, 3, 1, 0, 0, 0, 0, 10, 2, 3, 4, 0, 80, 42], - ReliableRelayError::InvalidPayloadLength { - wanted: 0, - found: 1 - } - ); - - test_deserialize_error!( - payload_too_short, - [0xde, 0, 0xad, 1, 0xbe, 2, 0xef, 3, 1, 0, 0, 0, 2, 10, 2, 3, 4, 0, 80, 42], - ReliableRelayError::InvalidPayloadLength { - wanted: 2, - found: 1 - } - ); - - macro_rules! test_successful_deserialize { - ($name:ident, $buffer:expr, $expected_message:expr) => { - #[test] - fn $name() { - let buffer = Bytes::copy_from_slice($buffer.as_slice()); - let message = RelayMessage::deserialize(buffer).unwrap(); - let expected_message: RelayMessage = $expected_message; - - assert_eq!(message.destination, expected_message.destination); - assert_eq!(message.payload, expected_message.payload); - } - }; - } - - test_successful_deserialize!( - valid_no_address, - [0xde, 0, 0xad, 1, 0xbe, 2, 0xef, 3, 0, 0, 0, 0, 1, 42], - RelayMessage { - destination: None, - payload: Bytes::from_static(&[42]) - } - ); - - test_successful_deserialize!( - valid_with_ipv4, - [0xde, 0, 0xad, 1, 0xbe, 2, 0xef, 3, 1, 0, 0, 0, 1, 10, 2, 3, 4, 0, 80, 42], - RelayMessage { - destination: Some("10.2.3.4:80".parse().unwrap()), - payload: Bytes::from_static(&[42]) - } - ); - - test_successful_deserialize!( - valid_with_ipv6, - [ - 0xde, 0, 0xad, 1, 0xbe, 2, 0xef, 3, 2, 0, 0, 0, 1, 0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 80, 42 - ], - RelayMessage { - destination: Some("[2001:db8::1]:80".parse().unwrap()), - payload: Bytes::from_static(&[42]) - } - ); - } -} diff --git a/crates/scion/src/reliable/parser.rs b/crates/scion/src/reliable/parser.rs new file mode 100644 index 0000000..c2aefc6 --- /dev/null +++ b/crates/scion/src/reliable/parser.rs @@ -0,0 +1,133 @@ +use std::{collections::VecDeque, ops::Deref}; + +use bytes::{Buf, Bytes}; + +use super::common_header::{CommonHeader, DecodeError, DecodedHeader}; + +pub(super) struct StreamParser { + // INV: byte objects are always non-empty + byte_queue: VecDeque, + bytes_remaining: usize, + next_header: Option, +} + +impl StreamParser { + pub fn new() -> Self { + Self { + byte_queue: VecDeque::new(), + bytes_remaining: 0, + next_header: None, + } + } + + pub fn append_data(&mut self, data: Bytes) { + if !data.is_empty() { + self.bytes_remaining += data.len(); + self.byte_queue.push_back(data); + } + } + + pub fn next_packet(&mut self) -> Result)>, DecodeError> { + match &self.next_header { + None if self.remaining() >= CommonHeader::MIN_LENGTH => { + self.next_header = Some(CommonHeader::partial_decode(self)?); + + // Recursively try to get the payload if we parsed a full common header + if self.next_header.as_ref().unwrap().is_fully_decoded() { + self.next_packet() + } else { + Ok(None) + } + } + Some(DecodedHeader::Partial(header)) if self.remaining() >= header.required_bytes() => { + self.next_header = Some(DecodedHeader::Full(header.finish_decoding(self))); + self.next_packet() + } + Some(DecodedHeader::Full(header)) if self.remaining() >= header.payload_size() => { + let header = *header; + let payload = self.get_payload(header.payload_size()); + + self.next_header = None; + + Ok(Some((header, payload))) + } + _ => Ok(None), + } + } + + fn get_payload(&mut self, payload_size: usize) -> Vec { + let mut result = vec![]; + + let mut payload_bytes_needed = payload_size; + + while payload_bytes_needed > 0 { + let mut data = self.byte_queue.pop_front().expect("there must be data"); + + if data.len() > payload_bytes_needed { + self.byte_queue + .push_front(data.split_off(payload_bytes_needed)); + } + + assert!(data.len() <= payload_bytes_needed); + + payload_bytes_needed -= data.len(); + result.push(data); + } + + result + } +} + +impl Buf for StreamParser { + fn remaining(&self) -> usize { + self.bytes_remaining + } + + fn chunk(&self) -> &[u8] { + self.byte_queue.front().map_or(&[], |data| data.deref()) + } + + fn advance(&mut self, cnt: usize) { + if cnt == 0 { + return; + } + if cnt > self.bytes_remaining { + panic!( + "cnt > self.remaining() ({} > {})", + cnt, self.bytes_remaining + ); + } + + let mut advance_by = cnt; + while advance_by > 0 { + let mut data = self.byte_queue.pop_front().expect("there must be data"); + + if data.len() > advance_by { + self.byte_queue.push_front(data.split_off(advance_by)); + } + assert!(data.len() <= advance_by); + + advance_by -= data.len(); + } + + self.bytes_remaining -= cnt; + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn has_available_multiple() { + let mut parser = StreamParser::new(); + + parser.append_data(Bytes::from_static(&[0, 1, 2])); + parser.append_data(Bytes::from_static(&[4, 5, 6])); + + let mut buffer = [0u8; 6]; + parser.copy_to_slice(&mut buffer); + + assert_eq!(buffer, [0, 1, 2, 4, 5, 6]); + } +} diff --git a/crates/scion/src/reliable/registration.rs b/crates/scion/src/reliable/registration.rs index ea48de8..77b8ada 100644 --- a/crates/scion/src/reliable/registration.rs +++ b/crates/scion/src/reliable/registration.rs @@ -1,118 +1,165 @@ use std::net::{IpAddr, SocketAddr}; -use bytes::BufMut; +use bytes::{Buf, BufMut}; -pub enum RegistrationEncodeError {} - -use super::host_address::HostAddress; +use super::wire_utils::LAYER4_PORT_OCTETS; use crate::{ - address::IA, + address::{HostAddress, IsdAsn, ServiceAddress}, reliable::{ - host_address::HostAddressType, + common_header::CommonHeader, + wire_utils::{encoded_address_and_port_length, encoded_address_length}, ADDRESS_TYPE_OCTETS, - IPV4_OCTETS, - IPV6_OCTETS, - LAYER4_PORT_OCTETS, }, }; -#[derive(Copy, Clone)] -struct ServiceAddress(u16); - -impl ServiceAddress { - pub const DAEMON: ServiceAddress = ServiceAddress(0x01); - pub const CONTROL: ServiceAddress = ServiceAddress(0x02); - pub const WILDCARD: ServiceAddress = ServiceAddress(0x10); - - // This value is private to discourage its use in non-networking code - const NONE: ServiceAddress = ServiceAddress(0xffff); - - const MULTICAST_FLAG: u16 = 0x8000; +/// A SCION port registration request to the dispatcher. +pub(super) struct RegistrationRequest { + pub isd_asn: IsdAsn, + pub public_address: SocketAddr, + pub bind_address: Option, + pub associated_service: Option, } -impl From for u16 { - fn from(value: ServiceAddress) -> Self { - value.0 +impl RegistrationRequest { + /// Return a new registration request for the specified IsdAsn and public address. + pub fn new(isd_asn: IsdAsn, public_address: SocketAddr) -> Self { + Self { + isd_asn, + public_address, + bind_address: None, + associated_service: None, + } } -} -pub struct RegistrationRequest { - isd_asn: IA, - public_address: SocketAddr, - bind_address: Option, - associated_service: Option, -} + /// Add the provided bind address to the request. + pub fn with_bind_address(mut self, address: SocketAddr) -> Self { + self.bind_address = Some(address); + self + } -impl RegistrationRequest { - pub fn encode_to(&self, buffer: &mut impl BufMut) -> usize { - let initial_length = buffer.remaining_mut(); + /// Add the provided associated service address to the request. + pub fn with_associated_service(mut self, address: ServiceAddress) -> Self { + self.associated_service = Some(address); + self + } - RegistrationEncoder(buffer).encode(self); + pub fn encoded_length(&self) -> usize { + CommonHeader::new().encoded_length() + self.encoded_request_length() + } - initial_length - buffer.remaining_mut() + /// Encode a registration request to the provided buffer. + /// + /// The encoded format includes + /// + /// 13-bytes: Common header with address type NONE + /// 1-byte: Command (bit mask with 0x04=Bind address, 0x02=SCMP enable, 0x01 always set) + /// 1-byte: L4 Proto (IANA number) + /// 8-bytes: ISD-AS + /// 2-bytes: L4 port + /// 1-byte: Address type + /// var-byte: Address + /// + /// along with an optional + /// + /// 2-bytes: L4 bind port \ + /// 1-byte: Address type ) (optional bind address) + /// var-byte: Bind Address / + /// 2-bytes: SVC (optional SVC type) + /// + /// # Panics + /// + /// Panics if there is not enough space in the buffer to encode the request. + pub fn encode_to(&self, buffer: &mut impl BufMut) { + self.encode_common_header(buffer); + self.encode_request(buffer); } -} -type EncodeResult = Result<(), RegistrationEncodeError>; + fn encode_request(&self, buffer: &mut impl BufMut) { + let initial_remaining = buffer.remaining_mut(); -struct RegistrationEncoder<'a, T>(&'a mut T); + const UDP_PROTOCOL_NUMBER: u8 = 17; -impl<'a, T> RegistrationEncoder<'a, T> -where - T: BufMut, -{ - fn new(buffer: &'a mut T) -> Self { - Self(buffer) - } + self.encode_command_flag(buffer); + + buffer.put_u8(UDP_PROTOCOL_NUMBER); + buffer.put_u64(self.isd_asn.as_u64()); - fn encode(&mut self, message: &RegistrationRequest) { - self.encode_command_field(message.bind_address.as_ref()); - self.encode_layer4_protocol(); - self.encode_isd_asn(message.isd_asn); - self.encode_address(&message.public_address); + encode_address(buffer, &self.public_address); - if let Some(bind_address) = message.bind_address.as_ref() { - self.encode_address(bind_address) + if let Some(bind_address) = self.bind_address.as_ref() { + encode_address(buffer, bind_address) + } + if let Some(service_address) = self.associated_service.as_ref() { + buffer.put_u16(u16::from(*service_address)) } - self.encode_associated_service(message.associated_service); + let written = initial_remaining - buffer.remaining_mut(); + assert_eq!(written, self.encoded_request_length()); } - fn encode_command_field(&mut self, bind_address: Option<&SocketAddr>) { - const FLAG_BASE: u8 = 0b001; - const FLAG_SCMP: u8 = 0b010; - const FLAG_BIND: u8 = 0b100; - - let mut command_field = FLAG_BASE | FLAG_SCMP; - if bind_address.is_some() { - command_field |= FLAG_BIND; + #[inline] + fn encode_common_header(&self, buffer: &mut impl BufMut) { + CommonHeader { + destination: None, + payload_length: u32::try_from(self.encoded_request_length()) + .expect("requests are short"), } - - self.0.put_u8(command_field); + .encode_to(buffer); } - fn encode_layer4_protocol(&mut self) { - const UDP_PROTOCOL_NUMBER: u8 = 18; - self.0.put_u8(UDP_PROTOCOL_NUMBER); + #[inline] + fn encoded_request_length(&self) -> usize { + const BASE_LENGTH: usize = 13; + + BASE_LENGTH + + encoded_address_length(self.public_address.host_address_type()) + + if self.bind_address.is_some() { + ADDRESS_TYPE_OCTETS + + encoded_address_and_port_length(self.bind_address.host_address_type()) + } else { + 0 + } + + encoded_address_and_port_length(self.associated_service.host_address_type()) } - fn encode_isd_asn(&mut self, isd_asn: IA) { - self.0.put_u64(isd_asn.as_u64()); + #[inline] + fn encode_command_flag(&self, buffer: &mut impl BufMut) { + const FLAG_BASE: u8 = 0b001; + const FLAG_SCMP: u8 = 0b010; + const FLAG_BIND: u8 = 0b100; + + buffer.put_u8(FLAG_BASE | FLAG_SCMP | self.bind_address.map_or(0u8, |_| FLAG_BIND)); } +} - fn encode_address(&mut self, address: &SocketAddr) { - self.0.put_u16(address.port()); - self.0.put_u8(address.host_address_type() as u8); +pub(super) struct RegistrationResponse { + pub assigned_port: u16, +} - match address.ip() { - IpAddr::V4(ipv4) => self.0.put(ipv4.octets().as_slice()), - IpAddr::V6(ipv6) => self.0.put(ipv6.octets().as_slice()), +impl RegistrationResponse { + pub const ENCODED_LENGTH: usize = LAYER4_PORT_OCTETS; + + /// Decode a registration response from the provided buffer. + /// + /// Returns None if the buffer contains less than 2 bytes. + pub fn decode(buffer: &mut impl Buf) -> Option { + if buffer.remaining() >= Self::ENCODED_LENGTH { + Some(Self { + assigned_port: buffer.get_u16(), + }) + } else { + None } } +} - fn encode_associated_service(&mut self, address: Option) { - self.0 - .put_u16(address.unwrap_or(ServiceAddress::NONE).into()); +fn encode_address(buffer: &mut impl BufMut, address: &SocketAddr) { + buffer.put_u16(address.port()); + buffer.put_u8(address.host_address_type().into()); + + match address.ip() { + IpAddr::V4(ipv4) => buffer.put(ipv4.octets().as_slice()), + IpAddr::V6(ipv6) => buffer.put(ipv6.octets().as_slice()), } } @@ -120,97 +167,69 @@ where mod tests { use super::*; - mod encode_request { - use std::str::FromStr; - + mod encode { use super::*; - #[test] - fn public_ipv4_only() { - let mut backing_array = [0u8; 50]; - let mut buffer = backing_array.as_mut_slice(); + const BUFFER_LENGTH: usize = 50; - let request = RegistrationRequest { - isd_asn: IA::from_str("1-ff00:0:1").unwrap(), - public_address: "10.2.3.4:80".parse().unwrap(), - bind_address: None, - associated_service: None, + macro_rules! test_successful { + ($name:ident, $request:expr, $expected:expr) => { + #[test] + fn $name() { + let mut backing_array = [0u8; BUFFER_LENGTH]; + let mut buffer = backing_array.as_mut_slice(); + + $request.encode_request(&mut buffer); + + let bytes_written = BUFFER_LENGTH - buffer.remaining_mut(); + assert_eq!(backing_array[..bytes_written], $expected); + } }; - request.encode_to(&mut buffer); } - } - // { - // Name: "public IPv4 address only", - // Registration: &Registration{ - // IA: xtest.MustParseIA("1-ff00:0:1"), - // PublicAddress: &net.UDPAddr{IP: net.IP{10, 2, 3, 4}, Port: 80}, - // SVCAddress: addr.SvcNone, - // }, - // ExpectedData: []byte{0x03, 17, 0, 1, 0xff, 0, 0, 0, 0, 0x01, 0, 80, 1, - // 10, 2, 3, 4}, - // }, - - // { - // Name: "nil public address", - // Registration: &Registration{ - // IA: xtest.MustParseIA("1-ff00:0:1"), - // SVCAddress: addr.SvcNone, - // }, - // ExpectedData: []byte{}, - // ExpectedError: ErrNoAddress, - // }, - // { - // Name: "nil public address IP", - // Registration: &Registration{ - // IA: xtest.MustParseIA("1-ff00:0:1"), - // PublicAddress: &net.UDPAddr{Port: 80}, - // SVCAddress: addr.SvcNone, - // }, - // ExpectedData: []byte{}, - // ExpectedError: ErrNoAddress, - // }, - // { - // Name: "public IPv6 address only", - // Registration: &Registration{ - // IA: xtest.MustParseIA("1-ff00:0:1"), - // PublicAddress: &net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 80}, - // SVCAddress: addr.SvcNone, - // }, - // ExpectedData: []byte{0x03, 17, 0, 1, 0xff, 0, 0, 0, 0, 0x01, - // 0, 80, 2, 0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}, - // }, - // { - // Name: "public address with bind", - // Registration: &Registration{ - // IA: xtest.MustParseIA("1-ff00:0:1"), - // PublicAddress: &net.UDPAddr{IP: net.IP{10, 2, 3, 4}, Port: 80}, - // BindAddress: &net.UDPAddr{IP: net.IP{10, 5, 6, 7}, Port: 81}, - // SVCAddress: addr.SvcNone, - // }, - // ExpectedData: []byte{0x07, 17, 0, 1, 0xff, 0, 0, 0, 0, 0x01, - // 0, 80, 1, 10, 2, 3, 4, 0, 81, 1, 10, 5, 6, 7}, - // }, - // { - // Name: "public IPv4 address with SVC", - // Registration: &Registration{ - // IA: xtest.MustParseIA("1-ff00:0:1"), - // PublicAddress: &net.UDPAddr{IP: net.IP{10, 2, 3, 4}, Port: 80}, - // SVCAddress: addr.SvcCS, - // }, - // ExpectedData: []byte{0x03, 17, 0, 1, 0xff, 0, 0, 0, 0, 0x01, 0, - // 80, 1, 10, 2, 3, 4, 0x00, 0x02}, - // }, - // { - // Name: "public address with bind and SVC", - // Registration: &Registration{ - // IA: xtest.MustParseIA("1-ff00:0:1"), - // PublicAddress: &net.UDPAddr{IP: net.IP{10, 2, 3, 4}, Port: 80}, - // BindAddress: &net.UDPAddr{IP: net.IP{10, 5, 6, 7}, Port: 81}, - // SVCAddress: addr.SvcCS, - // }, - // ExpectedData: []byte{0x07, 17, 0, 1, 0xff, 0, 0, 0, 0, 0x01, - // 0, 80, 1, 10, 2, 3, 4, - // 0, 81, 1, 10, 5, 6, 7, 0, 2}, - // }, + use crate::test_utils::parse; + + test_successful!( + public_ipv4_only, + RegistrationRequest::new(parse!("1-ff00:0:1"), parse!("10.2.3.4:80")), + [0x03, 17, 0, 1, 0xff, 0, 0, 0, 0, 0x01, 0, 80, 1, 10, 2, 3, 4] + ); + + test_successful!( + public_ipv6_only, + RegistrationRequest::new(parse!("1-ff00:0:1"), parse!("[2001:db8::1]:80")), + [ + 0x03, 17, 0, 1, 0xff, 0, 0, 0, 0, 0x01, 0, 80, 2, 0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 1 + ] + ); + + test_successful!( + public_with_bind, + RegistrationRequest::new(parse!("1-ff00:0:1"), parse!("10.2.3.4:80")) + .with_bind_address(parse!("10.5.6.7:81")), + [ + 0x07, 17, 0, 1, 0xff, 0, 0, 0, 0, 0x01, 0, 80, 1, 10, 2, 3, 4, 0, 81, 1, 10, 5, 6, + 7 + ] + ); + + test_successful!( + public_ipv4_with_service, + RegistrationRequest::new(parse!("1-ff00:0:1"), parse!("10.2.3.4:80")) + .with_associated_service(ServiceAddress::CONTROL), + [0x03, 17, 0, 1, 0xff, 0, 0, 0, 0, 0x01, 0, 80, 1, 10, 2, 3, 4, 0x00, 0x02] + ); + + test_successful!( + with_bind_and_service, + RegistrationRequest::new(parse!("1-ff00:0:1"), parse!("10.2.3.4:80")) + .with_bind_address(parse!("10.5.6.7:81")) + .with_associated_service(ServiceAddress::CONTROL), + [ + 0x07, 17, 0, 1, 0xff, 0, 0, 0, 0, 0x01, 0, 80, 1, 10, 2, 3, 4, 0, 81, 1, 10, 5, 6, + 7, 0, 2 + ] + ); + } } diff --git a/crates/scion/src/reliable/relay_protocol.rs b/crates/scion/src/reliable/relay_protocol.rs index 2d977dd..6eb212d 100644 --- a/crates/scion/src/reliable/relay_protocol.rs +++ b/crates/scion/src/reliable/relay_protocol.rs @@ -1,43 +1,383 @@ -use std::net::SocketAddr; +use std::{collections::VecDeque, net::SocketAddr}; -use bytes::{buf::Chain, Buf, Bytes}; +use bytes::{BufMut, Bytes, BytesMut}; -use super::{error::ReliableRelayError, message::RelayMessage}; +use super::{ + common_header::{CommonHeader, DecodeError}, + error::ReliableRelayError, + parser::StreamParser, + registration::{RegistrationRequest, RegistrationResponse}, +}; +use crate::address::{IsdAsn, ServiceAddress}; -pub struct Transmit(Chain); +enum State { + Initial, + RegistrationRequested { + request: RegistrationRequest, + is_sent: bool, + }, + Registered { + transmit_queue: VecDeque<(CommonHeader, Bytes)>, + port: u16, + }, + Terminated, +} + +pub struct ReliableRelayProtocol { + state: State, + parser: StreamParser, +} + +impl ReliableRelayProtocol { + const MAX_TRANSMIT_BUFFER_SIZE: usize = 1_048_576; // 1 MiB -impl Buf for Transmit { - fn remaining(&self) -> usize { - self.0.remaining() + pub fn new() -> Self { + Self { + state: State::Initial, + parser: StreamParser::new(), + } } - fn chunk(&self) -> &[u8] { - self.0.chunk() + pub fn register(&mut self, isd_asn: IsdAsn, public_address: SocketAddr) { + self.register_with_dispatcher(RegistrationRequest::new(isd_asn, public_address)) } - fn advance(&mut self, cnt: usize) { - self.0.advance(cnt) + pub fn register_service( + &mut self, + isd_asn: IsdAsn, + public_address: SocketAddr, + associated_service: ServiceAddress, + ) { + self.register_with_dispatcher( + RegistrationRequest::new(isd_asn, public_address) + .with_associated_service(associated_service), + ) + } + + fn register_with_dispatcher(&mut self, request: RegistrationRequest) { + match self.state { + State::Initial => { + self.state = State::RegistrationRequested { + request, + is_sent: false, + } + } + State::RegistrationRequested { .. } => panic!("registration already requested"), + State::Registered { .. } => panic!("already registered with the dispatcher"), + State::Terminated => panic!("protocol has already terminated"), + } } -} -#[derive(Default)] -pub struct ReliableRelayProtocol {} + pub fn poll_transmit(&mut self) -> Option> { + match &mut self.state { + State::RegistrationRequested { + request, + is_sent: is_sent @ false, + } => { + let mut buffer = BytesMut::with_capacity(request.encoded_length()); + request.encode_to(&mut buffer); -impl ReliableRelayProtocol { - pub fn new() -> Self { - ReliableRelayProtocol {} + *is_sent = true; + Some(vec![buffer.freeze()]) + } + + State::Registered { transmit_queue, .. } => { + if transmit_queue.is_empty() { + None + } else { + let buffer_length = std::cmp::min( + CommonHeader::MAX_LENGTH.saturating_mul(transmit_queue.len()), + Self::MAX_TRANSMIT_BUFFER_SIZE, + ); + let mut buffer = BytesMut::with_capacity(buffer_length); + let mut output_bytes = Vec::new(); + + while let Some((header, bytes)) = transmit_queue.pop_front() { + let header_buffer = buffer.split_to(header.encoded_length()); + + output_bytes.push(header_buffer.freeze()); + output_bytes.push(bytes); + + if buffer.remaining_mut() < CommonHeader::MAX_LENGTH { + break; + } + } + + Some(output_bytes) + } + } + + State::Initial | State::Terminated | State::RegistrationRequested { .. } => None, + } } - /// Send an encoded SCION packet over IP to a remote SCION host. - /// - /// This may be, for example, another SCION-enabled host within the same IP network or - /// a SCION border router. - pub fn relay( + pub fn send( + &mut self, scion_packet_data: Bytes, - destination: Option, - ) -> Result { - Ok(Transmit( - RelayMessage::new(scion_packet_data, destination)?.serialize(), - )) + destination: SocketAddr, + ) -> Result<(), SendError> { + match &mut self.state { + State::Initial | State::RegistrationRequested { .. } => Err(SendError::NotRegistered), + State::Terminated => Err(SendError::ProtocolTerminated), + State::Registered { transmit_queue, .. } => { + if destination.ip().is_unspecified() { + Err(SendError::DestinationUnspecified) + } else if destination.port() == 0 { + Err(SendError::DestinationPortUnspecified) + } else { + transmit_queue.push_back(( + CommonHeader { + destination: Some(destination), + payload_length: u32::try_from(scion_packet_data.len()) + .or(Err(SendError::PayloadTooLarge(scion_packet_data.len())))?, + }, + scion_packet_data, + )); + Ok(()) + } + } + } + } + + pub fn handle_incoming(&mut self, data: Bytes) -> Result<(), ReliableRelayError> { + match self.state { + State::Initial | State::RegistrationRequested { is_sent: false, .. } => { + panic!("not yet registered, cannot handle incoming data") + } + State::RegistrationRequested { is_sent: true, .. } => { + self.parser.append_data(data); + self.maybe_complete_registration() + } + State::Registered { .. } => { + self.parser.append_data(data); + Ok(()) + } + State::Terminated => panic!("protocol already terminated"), + } + } + + fn maybe_complete_registration(&mut self) -> Result<(), ReliableRelayError> { + let State::RegistrationRequested { + is_sent: true, + request, + } = &self.state else { + panic!("must only be called while awaiting registration response"); + }; + + if let Some(response) = RegistrationResponse::decode(&mut self.parser) { + let requested_port = request.public_address.port(); + + if requested_port != response.assigned_port && requested_port != 0 { + self.state = State::Terminated; + + Err(ReliableRelayError::PortMismatch { + requested: requested_port, + assigned: response.assigned_port, + }) + } else { + self.state = State::Registered { + transmit_queue: VecDeque::new(), + port: response.assigned_port, + }; + + Ok(()) + } + } else { + Ok(()) // Need more data, do nothing. + } + } + + pub fn port(&self) -> Option { + match self.state { + State::Registered { port, .. } => Some(port), + _ => None, + } + } + + pub fn receive(&mut self) -> Result<(SocketAddr, Vec), ReceiveError> { + match self.state { + State::Initial | State::RegistrationRequested { .. } => { + Err(ReceiveError::NotRegistered) + } + State::Terminated => Err(ReceiveError::ProtocolTerminated), + State::Registered { .. } => match self.parser.next_packet()? { + Some((header, bytes)) => Ok(( + // TODO(jsmith): Determine in which cases we do not receive an address + // TODO(jsmith): Handle when we do not receive an address with an error + header + .destination + .expect("there is always a desintation address"), + bytes, + )), + None => Err(ReceiveError::Blocked), + }, + } + } +} + +#[derive(thiserror::Error, Debug, Eq, PartialEq)] +pub enum ReceiveError { + #[error("currently no packets available, try again later")] + Blocked, + #[error("not yet registered")] + NotRegistered, + #[error("protocol already terminated, receive is no longer possible")] + ProtocolTerminated, + #[error("An error occured in the protocol: {0}")] + ProtocolError(#[from] DecodeError), +} + +#[derive(thiserror::Error, Debug, Eq, PartialEq)] +pub enum SendError { + #[error("not yet registered")] + NotRegistered, + #[error("protocol already terminated, receive is no longer possible")] + ProtocolTerminated, + #[error("provided destination address must be specified, not 0.0.0.0 or ::0")] + DestinationUnspecified, + #[error("provided destination port mmust be specified")] + DestinationPortUnspecified, + #[error("payload size too large ({0}), should be at most {}", u32::MAX)] + PayloadTooLarge(usize), +} + +impl Default for ReliableRelayProtocol { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + mod registration { + use super::*; + use crate::test_utils::parse; + + fn send_registration() -> ReliableRelayProtocol { + let mut relay = ReliableRelayProtocol::new(); + + relay.register(parse!("1-ff00:0:1"), parse!("10.2.3.4:80")); + + let bytes_to_send = relay.poll_transmit().expect("must have bytes to output"); + assert_eq!(bytes_to_send.len(), 1, "expected only registration message"); + assert_eq!( + bytes_to_send.first().unwrap(), + [ + 0xde, 0, 0xad, 1, 0xbe, 2, 0xef, 3, 0, 0, 0, 0, 17, 0x03, 17, 0, 1, 0xff, 0, 0, + 0, 0, 0x01, 0, 80, 1, 10, 2, 3, 4 + ] + .as_slice() + ); + + relay + } + + #[test] + fn success() { + let mut relay = send_registration(); + relay + .handle_incoming(Bytes::from_static(&[0, 80])) + .expect("no error for valid response"); + assert_eq!(relay.port(), Some(80)); + } + + #[test] + fn port_mismatch() { + let mut relay = send_registration(); + let error = relay + .handle_incoming(Bytes::from_static(&[0, 81])) + .expect_err("expected port mismatch error"); + + assert_eq!( + error, + ReliableRelayError::PortMismatch { + requested: 80, + assigned: 81 + } + ); + } + + #[test] + fn incremental_data() { + let mut relay = send_registration(); + + relay + .handle_incoming(Bytes::from_static(&[0])) + .expect("no error for partial response"); + assert_eq!(relay.port(), None); + + relay + .handle_incoming(Bytes::from_static(&[80])) + .expect("no error for valid total response"); + assert_eq!(relay.port(), Some(80)); + } + } + + mod incoming_data { + + use super::*; + use crate::test_utils::parse; + + #[test] + fn full_packet() { + let mut relay = ReliableRelayProtocol { + state: State::Registered { + transmit_queue: VecDeque::new(), + port: 80, + }, + parser: StreamParser::new(), + }; + + let Err(ReceiveError::Blocked) = relay.receive() else { + panic!("expected to be blocked"); + }; + + relay + .handle_incoming(Bytes::from_static(&[ + 0xde, 0, 0xad, 1, 0xbe, 2, 0xef, 3, 1, 0, 0, 0, 5, 10, 2, 3, 4, 0, 80, b'H', + b'E', b'L', b'L', b'O', + ])) + .expect("should not err"); + + let (address, data_bytes) = relay.receive().expect("data to be available"); + assert_eq!(address, parse!("10.2.3.4:80")); + assert_eq!(data_bytes.len(), 1); + assert_eq!(data_bytes[0], b"HELLO".as_slice()); + } + + #[test] + fn partial_packet() { + let mut relay = ReliableRelayProtocol { + state: State::Registered { + transmit_queue: VecDeque::new(), + port: 80, + }, + parser: StreamParser::new(), + }; + + let parts = [ + vec![0xde, 0, 0xad, 1, 0xbe, 2, 0xef, 3, 1, 0], + vec![0, 0, 5, 10, 2, 3, 4], + vec![0, 80, b'H'], + vec![b'E', b'L', b'L', b'O'], + ]; + + for data in parts.into_iter() { + let Err(ReceiveError::Blocked) = relay.receive() else { + panic!("expected to be blocked"); + }; + relay + .handle_incoming(Bytes::from(data)) + .expect("should not err"); + } + + let (address, data_bytes) = relay.receive().expect("data to be available"); + assert_eq!(address, parse!("10.2.3.4:80")); + assert_eq!(data_bytes.len(), 2); + assert_eq!(data_bytes[0], b"H".as_slice()); + assert_eq!(data_bytes[1], b"ELLO".as_slice()); + } } } diff --git a/crates/scion/src/reliable/wire_utils.rs b/crates/scion/src/reliable/wire_utils.rs new file mode 100644 index 0000000..a7c7f68 --- /dev/null +++ b/crates/scion/src/reliable/wire_utils.rs @@ -0,0 +1,25 @@ +use crate::address::HostType; + +pub(super) const IPV4_OCTETS: usize = 4; +pub(super) const IPV6_OCTETS: usize = 16; +pub(super) const LAYER4_PORT_OCTETS: usize = 2; + +pub(super) fn encoded_address_length(host_type: HostType) -> usize { + match host_type { + HostType::Svc => 2, + HostType::Ipv4 => IPV4_OCTETS, + HostType::Ipv6 => IPV6_OCTETS, + HostType::None => 0, + } +} + +pub(super) fn encoded_port_length(host_type: HostType) -> usize { + match host_type { + HostType::None | HostType::Svc => 0, + HostType::Ipv4 | HostType::Ipv6 => LAYER4_PORT_OCTETS, + } +} + +pub(super) fn encoded_address_and_port_length(host_type: HostType) -> usize { + encoded_address_length(host_type) + encoded_port_length(host_type) +} diff --git a/crates/scion/src/test_utils.rs b/crates/scion/src/test_utils.rs new file mode 100644 index 0000000..da049b2 --- /dev/null +++ b/crates/scion/src/test_utils.rs @@ -0,0 +1,7 @@ +macro_rules! parse { + ($string:literal) => { + $string.parse().unwrap() + }; +} + +pub(crate) use parse;