Skip to content

Commit

Permalink
feat: add stream recv buffer trait and impls (#2505)
Browse files Browse the repository at this point in the history
  • Loading branch information
camshaft authored Mar 4, 2025
1 parent 4d60027 commit 70e8de8
Show file tree
Hide file tree
Showing 13 changed files with 627 additions and 378 deletions.
12 changes: 11 additions & 1 deletion dc/s2n-quic-dc/src/stream/client/tokio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
// SPDX-License-Identifier: Apache-2.0

use crate::{
event,
event, msg,
path::secret,
stream::{
application::Stream,
endpoint,
environment::tokio::{self as env, Environment},
recv,
socket::Protocol,
},
};
Expand All @@ -33,6 +34,7 @@ where
env,
peer,
env::UdpUnbound(acceptor_addr.into()),
recv_buffer(),
subscriber,
None,
)?;
Expand Down Expand Up @@ -87,6 +89,7 @@ where
peer_addr,
local_port,
},
recv_buffer(),
subscriber,
None,
)?;
Expand Down Expand Up @@ -126,6 +129,7 @@ where
peer_addr,
local_port,
},
recv_buffer(),
subscriber,
None,
)?;
Expand Down Expand Up @@ -153,3 +157,9 @@ where
.await
.map(|_| ())
}

#[inline]
fn recv_buffer() -> recv::shared::RecvBuffer {
// TODO replace this with a parameter once everything is in place
recv::buffer::Local::new(msg::recv::Message::new(9000), None)
}
17 changes: 7 additions & 10 deletions dc/s2n-quic-dc/src/stream/endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

use crate::{
event::{self, api::Subscriber as _, IntoEvent as _},
msg, packet,
packet,
path::secret::{self, map, Map},
random::Random,
stream::{
Expand Down Expand Up @@ -37,6 +37,7 @@ pub fn open_stream<Env, P>(
env: &Env,
entry: map::Peer,
peer: P,
recv_buffer: recv::shared::RecvBuffer,
subscriber: Env::Subscriber,
parameter_override: Option<&dyn Fn(dc::ApplicationParams) -> dc::ApplicationParams>,
) -> Result<application::Builder<Env::Subscriber>>
Expand Down Expand Up @@ -76,8 +77,7 @@ where
crypto,
entry.map(),
parameters,
None,
None,
recv_buffer,
endpoint::Type::Client,
subscriber,
subscriber_ctx,
Expand All @@ -90,8 +90,7 @@ pub fn accept_stream<Env, P>(
env: &Env,
mut peer: P,
packet: &server::InitialPacket,
handshake: Option<server::handshake::Receiver>,
buffer: Option<&mut msg::recv::Message>,
recv_buffer: recv::shared::RecvBuffer,
map: &Map,
subscriber: Env::Subscriber,
subscriber_ctx: <Env::Subscriber as event::Subscriber>::ConnectionContext,
Expand Down Expand Up @@ -134,8 +133,7 @@ where
crypto,
map,
parameters,
handshake,
buffer,
recv_buffer,
endpoint::Type::Server,
subscriber,
subscriber_ctx,
Expand Down Expand Up @@ -164,8 +162,7 @@ fn build_stream<Env, P>(
crypto: secret::map::Bidirectional,
map: &Map,
parameters: dc::ApplicationParams,
handshake: Option<server::handshake::Receiver>,
recv_buffer: Option<&mut msg::recv::Message>,
recv_buffer: recv::shared::RecvBuffer,
endpoint_type: endpoint::Type,
subscriber: Env::Subscriber,
subscriber_ctx: <Env::Subscriber as event::Subscriber>::ConnectionContext,
Expand All @@ -179,7 +176,7 @@ where
let sockets = peer.setup(env)?;

// construct shared reader state
let reader = recv::shared::State::new(stream_id, &parameters, handshake, features, recv_buffer);
let reader = recv::shared::State::new(stream_id, &parameters, features, recv_buffer);

let writer = {
let worker = sockets
Expand Down
1 change: 1 addition & 0 deletions dc/s2n-quic-dc/src/stream/recv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

mod ack;
pub mod application;
pub(crate) mod buffer;
mod error;
mod packet;
mod probes;
Expand Down
28 changes: 13 additions & 15 deletions dc/s2n-quic-dc/src/stream/recv/application.rs
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ where

let shared = &self.shared;
let sockets = &self.sockets;
let transport_features = sockets.read_application().features();
let transport_features = sockets.features();

let mut reader = shared.receiver.application_guard(
self.ack_mode,
Expand Down Expand Up @@ -263,28 +263,25 @@ where
_ => {}
}

let before_len = reader.recv_buffer.payload_len();

let recv = reader.poll_fill_recv_buffer(
cx,
self.sockets.read_application(),
&self.shared.clock,
&self.shared.subscriber,
);

match Self::handle_socket_result(cx, &mut reader.receiver, &mut self.timer, recv) {
Poll::Ready(res) => res?,
// if we've written at least one byte then return that amount
Poll::Pending if out_buf.written_len() > 0 => break,
Poll::Pending => return Poll::Pending,
}
let recv_len =
match Self::handle_socket_result(cx, &mut reader.receiver, &mut self.timer, recv) {
Poll::Ready(res) => res?,
// if we've written at least one byte then return that amount
Poll::Pending if out_buf.written_len() > 0 => break,
Poll::Pending => return Poll::Pending,
};

// clear the forced receive after performing it once
force_recv = false;

let after_len = reader.recv_buffer.payload_len();

if before_len == after_len {
if recv_len == 0 {
if transport_features.is_stream() {
// if we got a 0-length read then the stream was closed - notify the receiver
reader.receiver.on_transport_close();
Expand All @@ -303,8 +300,8 @@ where
cx: &mut Context,
receiver: &mut recv::state::State,
timer: &mut Option<Timer>,
res: Poll<io::Result<()>>,
) -> Poll<io::Result<()>> {
res: Poll<io::Result<usize>>,
) -> Poll<io::Result<usize>> {
if let Poll::Ready(res) = res {
return res.into();
}
Expand All @@ -320,7 +317,8 @@ where
ready!(timer.poll_ready(cx));

// if the timer expired then keep going, even if the recv buffer is empty
Ok(()).into()
// we return `1` to make the caller think that something was written to the buffer
Ok(1).into()
} else {
timer.cancel();
Poll::Pending
Expand Down
2 changes: 1 addition & 1 deletion dc/s2n-quic-dc/src/stream/recv/application/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ where

let remote_addr = shared.read_remote_addr();
// we only need a timer for unreliable transports
let is_reliable = sockets.read_application().features().is_reliable();
let is_reliable = sockets.features().is_reliable();
let timer = if is_reliable {
None
} else {
Expand Down
85 changes: 85 additions & 0 deletions dc/s2n-quic-dc/src/stream/recv/buffer.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

use crate::{
event,
stream::{recv, socket::Socket, TransportFeatures},
};
use core::task::{Context, Poll};
use std::io;

mod dispatch;
mod local;

pub use dispatch::Dispatch;
pub use local::Local;

pub trait Buffer {
fn is_empty(&self) -> bool;

fn poll_fill<S, Pub>(
&mut self,
cx: &mut Context,
socket: &S,
publisher: &mut Pub,
) -> Poll<io::Result<usize>>
where
S: ?Sized + Socket,
Pub: event::ConnectionPublisher;

fn process<R>(
&mut self,
features: TransportFeatures,
router: &mut R,
) -> Result<(), recv::Error>
where
R: Dispatch;
}

#[allow(dead_code)] // TODO remove this once we start using the channel buffer
pub enum Either<A, B> {
A(A),
B(B),
}

impl<A, B> Buffer for Either<A, B>
where
A: Buffer,
B: Buffer,
{
#[inline]
fn is_empty(&self) -> bool {
match self {
Self::A(a) => a.is_empty(),
Self::B(b) => b.is_empty(),
}
}

#[inline]
fn poll_fill<S, Pub>(
&mut self,
cx: &mut Context,
socket: &S,
publisher: &mut Pub,
) -> Poll<io::Result<usize>>
where
S: ?Sized + Socket,
Pub: event::ConnectionPublisher,
{
match self {
Self::A(a) => a.poll_fill(cx, socket, publisher),
Self::B(b) => b.poll_fill(cx, socket, publisher),
}
}

#[inline]
fn process<R>(&mut self, features: TransportFeatures, router: &mut R) -> Result<(), recv::Error>
where
R: Dispatch,
{
match self {
Self::A(a) => a.process(features, router),
Self::B(b) => b.process(features, router),
}
}
}
55 changes: 55 additions & 0 deletions dc/s2n-quic-dc/src/stream/recv/buffer/dispatch.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

use crate::{packet, stream::recv};
use s2n_codec::DecoderBufferMut;
use s2n_quic_core::inet::{ExplicitCongestionNotification, SocketAddress};

pub trait Dispatch {
#[inline(always)]
fn tag_len(&self) -> usize {
16
}

fn on_packet(
&mut self,
remote_addr: &SocketAddress,
ecn: ExplicitCongestionNotification,
packet: packet::Packet,
) -> Result<(), recv::Error>;

#[inline]
fn on_datagram_segment(
&mut self,
remote_addr: &SocketAddress,
ecn: ExplicitCongestionNotification,
segment: &mut [u8],
) -> Result<(), recv::Error> {
let tag_len = self.tag_len();
let segment_len = segment.len();
let mut decoder = DecoderBufferMut::new(segment);

while !decoder.is_empty() {
let packet = match decoder.decode_parameterized(tag_len) {
Ok((packet, remaining)) => {
decoder = remaining;
packet
}
Err(decoder_error) => {
// the packet was likely corrupted so log it and move on to the
// next segment
tracing::warn!(
%decoder_error,
segment_len
);

break;
}
};

self.on_packet(remote_addr, ecn, packet)?;
}

Ok(())
}
}
Loading

0 comments on commit 70e8de8

Please sign in to comment.