-
Notifications
You must be signed in to change notification settings - Fork 69
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add async way to read early data from TLSAcceptor #73
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -96,6 +96,65 @@ where | |
} | ||
} | ||
|
||
#[cfg(feature = "early-data")] | ||
impl<IO> TlsStream<IO> | ||
where | ||
IO: AsyncRead + AsyncWrite + Unpin, | ||
{ | ||
pub fn poll_read_early_data( | ||
self: Pin<&mut Self>, | ||
cx: &mut Context<'_>, | ||
buf: &mut ReadBuf<'_>, | ||
) -> Poll<io::Result<()>> { | ||
use std::io::Read; | ||
|
||
let this = self.get_mut(); | ||
let mut stream = | ||
Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable()); | ||
|
||
match &this.state { | ||
TlsState::Stream | TlsState::WriteShutdown => { | ||
{ | ||
let mut stream = stream.as_mut_pin(); | ||
|
||
while !stream.eof && stream.session.wants_read() { | ||
match stream.read_io(cx) { | ||
Poll::Ready(Ok(0)) => { | ||
break; | ||
} | ||
Poll::Ready(Ok(_)) => (), | ||
Poll::Pending => { | ||
break; | ||
} | ||
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), | ||
} | ||
} | ||
} | ||
Comment on lines
+120
to
+132
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This code looks it's duplicated from somewhere else. Can we abstract over it instead? If not, why not? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. some of the branches update different variables, I don't really know how to extract it |
||
|
||
if let Some(mut early_data) = stream.session.early_data() { | ||
match early_data.read(buf.initialize_unfilled()) { | ||
Ok(n) => { | ||
if n > 0 { | ||
buf.advance(n); | ||
return Poll::Ready(Ok(())); | ||
} | ||
} | ||
Err(err) => return Poll::Ready(Err(err)), | ||
} | ||
} | ||
|
||
if stream.session.is_handshaking() { | ||
return Poll::Pending; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we return pending here? or should I handshake? I'm worried about a hang due to missing wake. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think there would be a missing wake? since handshake is done once the client sends it's finished, and this code depends on read. I'm not fully sure though. but now that I think about it, the test code doesn't cover this. since I do There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If we got |
||
} | ||
|
||
Poll::Ready(Ok(())) | ||
} | ||
TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(())), | ||
s => unreachable!("server TLS can not hit this state: {:?}", s), | ||
} | ||
} | ||
} | ||
|
||
impl<IO> AsyncWrite for TlsStream<IO> | ||
where | ||
IO: AsyncRead + AsyncWrite + Unpin, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,17 +1,17 @@ | ||
#![cfg(feature = "early-data")] | ||
|
||
use std::io::{self, BufReader, Cursor, Read, Write}; | ||
use std::net::{SocketAddr, TcpListener}; | ||
use std::io::{self, BufReader, Cursor}; | ||
use std::net::SocketAddr; | ||
use std::pin::Pin; | ||
use std::sync::Arc; | ||
use std::task::{Context, Poll}; | ||
use std::thread; | ||
|
||
use futures_util::{future::Future, ready}; | ||
use rustls::{self, ClientConfig, RootCertStore, ServerConfig, ServerConnection, Stream}; | ||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt, ReadBuf}; | ||
use tokio::net::TcpStream; | ||
use tokio_rustls::{client::TlsStream, TlsConnector}; | ||
use pin_project_lite::pin_project; | ||
use rustls::{self, ClientConfig, RootCertStore, ServerConfig}; | ||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf}; | ||
use tokio::net::{TcpListener, TcpStream}; | ||
use tokio_rustls::{client, server, TlsAcceptor, TlsConnector}; | ||
|
||
struct Read1<T>(T); | ||
|
||
|
@@ -33,12 +33,32 @@ impl<T: AsyncRead + Unpin> Future for Read1<T> { | |
} | ||
} | ||
|
||
pin_project! { | ||
struct TlsStreamEarlyWrapper<IO> { | ||
#[pin] | ||
inner: server::TlsStream<IO> | ||
} | ||
} | ||
|
||
impl<IO> AsyncRead for TlsStreamEarlyWrapper<IO> | ||
where | ||
IO: AsyncRead + AsyncWrite + Unpin, | ||
{ | ||
fn poll_read( | ||
self: Pin<&mut Self>, | ||
cx: &mut Context<'_>, | ||
buf: &mut ReadBuf<'_>, | ||
) -> Poll<io::Result<()>> { | ||
return self.project().inner.poll_read_early_data(cx, buf); | ||
tahmid-23 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} | ||
} | ||
|
||
async fn send( | ||
config: Arc<ClientConfig>, | ||
addr: SocketAddr, | ||
data: &[u8], | ||
vectored: bool, | ||
) -> io::Result<(TlsStream<TcpStream>, Vec<u8>)> { | ||
) -> io::Result<(client::TlsStream<TcpStream>, Vec<u8>)> { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This change seems unrelated? If so, prefer to avoid it. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I gave a more qualified name because I use server::TlsStream as well. should I still revert it? |
||
let connector = TlsConnector::from(config).early_data(true); | ||
let stream = TcpStream::connect(&addr).await?; | ||
let domain = pki_types::ServerName::try_from("foobar.com").unwrap(); | ||
|
@@ -75,38 +95,33 @@ async fn test_0rtt_impl(vectored: bool) -> io::Result<()> { | |
.unwrap(); | ||
server.max_early_data_size = 8192; | ||
let server = Arc::new(server); | ||
let acceptor = Arc::new(TlsAcceptor::from(server)); | ||
|
||
let listener = TcpListener::bind("127.0.0.1:0")?; | ||
let listener = TcpListener::bind("127.0.0.1:0").await?; | ||
let server_port = listener.local_addr().unwrap().port(); | ||
thread::spawn(move || loop { | ||
let (mut sock, _addr) = listener.accept().unwrap(); | ||
tokio::spawn(async move { | ||
loop { | ||
let (mut sock, _addr) = listener.accept().await.unwrap(); | ||
|
||
let acceptor = acceptor.clone(); | ||
tokio::spawn(async move { | ||
let stream = acceptor.accept(&mut sock).await.unwrap(); | ||
|
||
let server = Arc::clone(&server); | ||
thread::spawn(move || { | ||
let mut conn = ServerConnection::new(server).unwrap(); | ||
conn.complete_io(&mut sock).unwrap(); | ||
let mut buf = Vec::new(); | ||
let mut stream_wrapper = TlsStreamEarlyWrapper { inner: stream }; | ||
stream_wrapper.read_to_end(&mut buf).await.unwrap(); | ||
let mut stream = stream_wrapper.inner; | ||
stream.write_all(b"EARLY:").await.unwrap(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should there be a test that explicitly tests for being able to read actual data off of the early data stream? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't follow. This is already reading the early data? |
||
stream.write_all(&buf).await.unwrap(); | ||
|
||
if let Some(mut early_data) = conn.early_data() { | ||
let mut buf = Vec::new(); | ||
early_data.read_to_end(&mut buf).unwrap(); | ||
let mut stream = Stream::new(&mut conn, &mut sock); | ||
stream.write_all(b"EARLY:").unwrap(); | ||
stream.write_all(&buf).unwrap(); | ||
} | ||
|
||
let mut stream = Stream::new(&mut conn, &mut sock); | ||
stream.write_all(b"LATE:").unwrap(); | ||
loop { | ||
let mut buf = [0; 1024]; | ||
let n = stream.read(&mut buf).unwrap(); | ||
if n == 0 { | ||
conn.send_close_notify(); | ||
conn.complete_io(&mut sock).unwrap(); | ||
break; | ||
} | ||
stream.write_all(&buf[..n]).unwrap(); | ||
} | ||
}); | ||
stream.read_to_end(&mut buf).await.unwrap(); | ||
stream.write_all(b"LATE:").await.unwrap(); | ||
stream.write_all(&buf).await.unwrap(); | ||
|
||
stream.shutdown().await.unwrap(); | ||
}); | ||
} | ||
}); | ||
|
||
let mut chain = BufReader::new(Cursor::new(include_str!("end.chain"))); | ||
|
@@ -125,7 +140,7 @@ async fn test_0rtt_impl(vectored: bool) -> io::Result<()> { | |
|
||
let (io, buf) = send(config.clone(), addr, b"hello", vectored).await?; | ||
assert!(!io.get_ref().1.is_early_data_accepted()); | ||
assert_eq!("LATE:hello", String::from_utf8_lossy(&buf)); | ||
assert_eq!("EARLY:LATE:hello", String::from_utf8_lossy(&buf)); | ||
|
||
let (io, buf) = send(config, addr, b"world!", vectored).await?; | ||
assert!(io.get_ref().1.is_early_data_accepted()); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: would be nice to restructure this code more like this: