diff --git a/neqo-bin/src/server/old_https.rs b/neqo-bin/src/server/http09.rs similarity index 95% rename from neqo-bin/src/server/old_https.rs rename to neqo-bin/src/server/http09.rs index 05520e1d3d..64b1e1be19 100644 --- a/neqo-bin/src/server/old_https.rs +++ b/neqo-bin/src/server/http09.rs @@ -17,21 +17,21 @@ use neqo_transport::{ }; use regex::Regex; -use super::{qns_read_response, Args, HttpServer}; +use super::{qns_read_response, Args}; #[derive(Default)] -struct Http09StreamState { +struct HttpStreamState { writable: bool, data_to_send: Option<(Vec, usize)>, } -pub struct Http09Server { +pub struct HttpServer { server: Server, - write_state: HashMap, + write_state: HashMap, read_state: HashMap>, } -impl Http09Server { +impl HttpServer { pub fn new( now: Instant, certs: &[impl AsRef], @@ -92,7 +92,7 @@ impl Http09Server { } else { self.write_state.insert( stream_id, - Http09StreamState { + HttpStreamState { writable: false, data_to_send: Some((resp, 0)), }, @@ -194,7 +194,7 @@ impl Http09Server { } } -impl HttpServer for Http09Server { +impl super::HttpServer for HttpServer { fn process(&mut self, dgram: Option<&Datagram>, now: Instant) -> Output { self.server.process(dgram, now) } @@ -210,7 +210,7 @@ impl HttpServer for Http09Server { match event { ConnectionEvent::NewStream { stream_id } => { self.write_state - .insert(stream_id, Http09StreamState::default()); + .insert(stream_id, HttpStreamState::default()); } ConnectionEvent::RecvStreamReadable { stream_id } => { self.stream_readable(stream_id, &mut acr, args); @@ -258,7 +258,7 @@ impl HttpServer for Http09Server { } } -impl Display for Http09Server { +impl Display for HttpServer { fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { write!(f, "Http 0.9 server ") } diff --git a/neqo-bin/src/server/http3.rs b/neqo-bin/src/server/http3.rs new file mode 100644 index 0000000000..40a733ffb5 --- /dev/null +++ b/neqo-bin/src/server/http3.rs @@ -0,0 +1,249 @@ +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use std::{ + borrow::Cow, + cell::RefCell, + cmp::min, + collections::HashMap, + fmt::{self, Display}, + path::PathBuf, + rc::Rc, + time::Instant, +}; + +use neqo_common::{qdebug, qerror, qwarn, Datagram, Header}; +use neqo_crypto::{generate_ech_keys, random, AntiReplay, Cipher}; +use neqo_http3::{ + Http3OrWebTransportStream, Http3Parameters, Http3Server, Http3ServerEvent, StreamId, +}; +use neqo_transport::{server::ValidateAddress, ConnectionIdGenerator}; + +use super::{qns_read_response, Args}; + +pub struct HttpServer { + server: Http3Server, + /// Progress writing to each stream. + remaining_data: HashMap, + posts: HashMap, +} + +impl HttpServer { + const MESSAGE: &'static [u8] = &[0; 4096]; + + pub fn new( + args: &Args, + anti_replay: AntiReplay, + cid_mgr: Rc>, + ) -> Self { + let server = Http3Server::new( + args.now(), + &[args.key.clone()], + &[args.shared.alpn.clone()], + anti_replay, + cid_mgr, + Http3Parameters::default() + .connection_parameters(args.shared.quic_parameters.get(&args.shared.alpn)) + .max_table_size_encoder(args.shared.max_table_size_encoder) + .max_table_size_decoder(args.shared.max_table_size_decoder) + .max_blocked_streams(args.shared.max_blocked_streams), + None, + ) + .expect("We cannot make a server!"); + Self { + server, + remaining_data: HashMap::new(), + posts: HashMap::new(), + } + } +} + +impl Display for HttpServer { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.server.fmt(f) + } +} + +impl super::HttpServer for HttpServer { + fn process(&mut self, dgram: Option<&Datagram>, now: Instant) -> neqo_http3::Output { + self.server.process(dgram, now) + } + + fn process_events(&mut self, args: &Args, _now: Instant) { + while let Some(event) = self.server.next_event() { + match event { + Http3ServerEvent::Headers { + mut stream, + headers, + fin, + } => { + qdebug!("Headers (request={stream} fin={fin}): {headers:?}"); + + if headers + .iter() + .any(|h| h.name() == ":method" && h.value() == "POST") + { + self.posts.insert(stream, 0); + continue; + } + + let Some(path) = headers.iter().find(|&h| h.name() == ":path") else { + stream + .cancel_fetch(neqo_http3::Error::HttpRequestIncomplete.code()) + .unwrap(); + continue; + }; + + let mut response = if args.shared.qns_test.is_some() { + match qns_read_response(path.value()) { + Ok(data) => ResponseData::from(data), + Err(e) => { + qerror!("Failed to read {}: {e}", path.value()); + stream + .send_headers(&[Header::new(":status", "404")]) + .unwrap(); + stream.stream_close_send().unwrap(); + continue; + } + } + } else if let Ok(count) = + path.value().trim_matches(|p| p == '/').parse::() + { + ResponseData::repeat(Self::MESSAGE, count) + } else { + ResponseData::from(Self::MESSAGE) + }; + + stream + .send_headers(&[ + Header::new(":status", "200"), + Header::new("content-length", response.remaining.to_string()), + ]) + .unwrap(); + response.send(&mut stream); + if response.done() { + stream.stream_close_send().unwrap(); + } else { + self.remaining_data.insert(stream.stream_id(), response); + } + } + Http3ServerEvent::DataWritable { mut stream } => { + if self.posts.get_mut(&stream).is_none() { + if let Some(remaining) = self.remaining_data.get_mut(&stream.stream_id()) { + remaining.send(&mut stream); + if remaining.done() { + self.remaining_data.remove(&stream.stream_id()); + stream.stream_close_send().unwrap(); + } + } + } + } + + Http3ServerEvent::Data { + mut stream, + data, + fin, + } => { + if let Some(received) = self.posts.get_mut(&stream) { + *received += data.len(); + } + if fin { + if let Some(received) = self.posts.remove(&stream) { + let msg = received.to_string().as_bytes().to_vec(); + stream + .send_headers(&[Header::new(":status", "200")]) + .unwrap(); + stream.send_data(&msg).unwrap(); + stream.stream_close_send().unwrap(); + } + } + } + _ => {} + } + } + } + + fn set_qlog_dir(&mut self, dir: Option) { + self.server.set_qlog_dir(dir); + } + + fn validate_address(&mut self, v: ValidateAddress) { + self.server.set_validation(v); + } + + fn set_ciphers(&mut self, ciphers: &[Cipher]) { + self.server.set_ciphers(ciphers); + } + + fn enable_ech(&mut self) -> &[u8] { + let (sk, pk) = generate_ech_keys().expect("should create ECH keys"); + self.server + .enable_ech(random::<1>()[0], "public.example", &sk, &pk) + .unwrap(); + self.server.ech_config() + } + + fn has_events(&self) -> bool { + self.server.has_events() + } +} + +struct ResponseData { + data: Cow<'static, [u8]>, + offset: usize, + remaining: usize, +} + +impl From<&[u8]> for ResponseData { + fn from(data: &[u8]) -> Self { + Self::from(data.to_vec()) + } +} + +impl From> for ResponseData { + fn from(data: Vec) -> Self { + let remaining = data.len(); + Self { + data: Cow::Owned(data), + offset: 0, + remaining, + } + } +} + +impl ResponseData { + fn repeat(buf: &'static [u8], total: usize) -> Self { + Self { + data: Cow::Borrowed(buf), + offset: 0, + remaining: total, + } + } + + fn send(&mut self, stream: &mut Http3OrWebTransportStream) { + while self.remaining > 0 { + let end = min(self.data.len(), self.offset + self.remaining); + let slice = &self.data[self.offset..end]; + match stream.send_data(slice) { + Ok(0) => { + return; + } + Ok(sent) => { + self.remaining -= sent; + self.offset = (self.offset + sent) % self.data.len(); + } + Err(e) => { + qwarn!("Error writing to stream {}: {:?}", stream, e); + return; + } + } + } + } + + fn done(&self) -> bool { + self.remaining == 0 + } +} diff --git a/neqo-bin/src/server/mod.rs b/neqo-bin/src/server/mod.rs index df385119c2..bc874e413d 100644 --- a/neqo-bin/src/server/mod.rs +++ b/neqo-bin/src/server/mod.rs @@ -5,10 +5,7 @@ // except according to those terms. use std::{ - borrow::Cow, cell::RefCell, - cmp::min, - collections::HashMap, fmt::{self, Display}, fs, io, net::{SocketAddr, ToSocketAddrs}, @@ -24,25 +21,20 @@ use futures::{ future::{select, select_all, Either}, FutureExt, }; -use neqo_common::{hex, qdebug, qerror, qinfo, qwarn, Datagram, Header}; +use neqo_common::{hex, qdebug, qerror, qinfo, qwarn, Datagram}; use neqo_crypto::{ constants::{TLS_AES_128_GCM_SHA256, TLS_AES_256_GCM_SHA384, TLS_CHACHA20_POLY1305_SHA256}, - generate_ech_keys, init_db, random, AntiReplay, Cipher, + init_db, AntiReplay, Cipher, }; -use neqo_http3::{ - Http3OrWebTransportStream, Http3Parameters, Http3Server, Http3ServerEvent, StreamId, -}; -use neqo_transport::{ - server::ValidateAddress, ConnectionIdGenerator, Output, RandomConnectionIdGenerator, Version, -}; -use old_https::Http09Server; +use neqo_transport::{server::ValidateAddress, Output, RandomConnectionIdGenerator, Version}; use tokio::time::Sleep; use crate::{udp, SharedArgs}; const ANTI_REPLAY_WINDOW: Duration = Duration::from_secs(10); -mod old_https; +mod http09; +mod http3; #[derive(Debug)] pub enum Error { @@ -200,230 +192,6 @@ trait HttpServer: Display { fn enable_ech(&mut self) -> &[u8]; } -struct ResponseData { - data: Cow<'static, [u8]>, - offset: usize, - remaining: usize, -} - -impl From<&[u8]> for ResponseData { - fn from(data: &[u8]) -> Self { - Self::from(data.to_vec()) - } -} - -impl From> for ResponseData { - fn from(data: Vec) -> Self { - let remaining = data.len(); - Self { - data: Cow::Owned(data), - offset: 0, - remaining, - } - } -} - -impl ResponseData { - fn repeat(buf: &'static [u8], total: usize) -> Self { - Self { - data: Cow::Borrowed(buf), - offset: 0, - remaining: total, - } - } - - fn send(&mut self, stream: &mut Http3OrWebTransportStream) { - while self.remaining > 0 { - let end = min(self.data.len(), self.offset + self.remaining); - let slice = &self.data[self.offset..end]; - match stream.send_data(slice) { - Ok(0) => { - return; - } - Ok(sent) => { - self.remaining -= sent; - self.offset = (self.offset + sent) % self.data.len(); - } - Err(e) => { - qwarn!("Error writing to stream {}: {:?}", stream, e); - return; - } - } - } - } - - fn done(&self) -> bool { - self.remaining == 0 - } -} - -struct SimpleServer { - server: Http3Server, - /// Progress writing to each stream. - remaining_data: HashMap, - posts: HashMap, -} - -impl SimpleServer { - const MESSAGE: &'static [u8] = &[0; 4096]; - - pub fn new( - args: &Args, - anti_replay: AntiReplay, - cid_mgr: Rc>, - ) -> Self { - let server = Http3Server::new( - args.now(), - &[args.key.clone()], - &[args.shared.alpn.clone()], - anti_replay, - cid_mgr, - Http3Parameters::default() - .connection_parameters(args.shared.quic_parameters.get(&args.shared.alpn)) - .max_table_size_encoder(args.shared.max_table_size_encoder) - .max_table_size_decoder(args.shared.max_table_size_decoder) - .max_blocked_streams(args.shared.max_blocked_streams), - None, - ) - .expect("We cannot make a server!"); - Self { - server, - remaining_data: HashMap::new(), - posts: HashMap::new(), - } - } -} - -impl Display for SimpleServer { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - self.server.fmt(f) - } -} - -impl HttpServer for SimpleServer { - fn process(&mut self, dgram: Option<&Datagram>, now: Instant) -> Output { - self.server.process(dgram, now) - } - - fn process_events(&mut self, args: &Args, _now: Instant) { - while let Some(event) = self.server.next_event() { - match event { - Http3ServerEvent::Headers { - mut stream, - headers, - fin, - } => { - qdebug!("Headers (request={stream} fin={fin}): {headers:?}"); - - if headers - .iter() - .any(|h| h.name() == ":method" && h.value() == "POST") - { - self.posts.insert(stream, 0); - continue; - } - - let Some(path) = headers.iter().find(|&h| h.name() == ":path") else { - stream - .cancel_fetch(neqo_http3::Error::HttpRequestIncomplete.code()) - .unwrap(); - continue; - }; - - let mut response = if args.shared.qns_test.is_some() { - match qns_read_response(path.value()) { - Ok(data) => ResponseData::from(data), - Err(e) => { - qerror!("Failed to read {}: {e}", path.value()); - stream - .send_headers(&[Header::new(":status", "404")]) - .unwrap(); - stream.stream_close_send().unwrap(); - continue; - } - } - } else if let Ok(count) = - path.value().trim_matches(|p| p == '/').parse::() - { - ResponseData::repeat(Self::MESSAGE, count) - } else { - ResponseData::from(Self::MESSAGE) - }; - - stream - .send_headers(&[ - Header::new(":status", "200"), - Header::new("content-length", response.remaining.to_string()), - ]) - .unwrap(); - response.send(&mut stream); - if response.done() { - stream.stream_close_send().unwrap(); - } else { - self.remaining_data.insert(stream.stream_id(), response); - } - } - Http3ServerEvent::DataWritable { mut stream } => { - if self.posts.get_mut(&stream).is_none() { - if let Some(remaining) = self.remaining_data.get_mut(&stream.stream_id()) { - remaining.send(&mut stream); - if remaining.done() { - self.remaining_data.remove(&stream.stream_id()); - stream.stream_close_send().unwrap(); - } - } - } - } - - Http3ServerEvent::Data { - mut stream, - data, - fin, - } => { - if let Some(received) = self.posts.get_mut(&stream) { - *received += data.len(); - } - if fin { - if let Some(received) = self.posts.remove(&stream) { - let msg = received.to_string().as_bytes().to_vec(); - stream - .send_headers(&[Header::new(":status", "200")]) - .unwrap(); - stream.send_data(&msg).unwrap(); - stream.stream_close_send().unwrap(); - } - } - } - _ => {} - } - } - } - - fn set_qlog_dir(&mut self, dir: Option) { - self.server.set_qlog_dir(dir); - } - - fn validate_address(&mut self, v: ValidateAddress) { - self.server.set_validation(v); - } - - fn set_ciphers(&mut self, ciphers: &[Cipher]) { - self.server.set_ciphers(ciphers); - } - - fn enable_ech(&mut self) -> &[u8] { - let (sk, pk) = generate_ech_keys().expect("should create ECH keys"); - self.server - .enable_ech(random::<1>()[0], "public.example", &sk, &pk) - .unwrap(); - self.server.ech_config() - } - - fn has_events(&self) -> bool { - self.server.has_events() - } -} - struct ServersRunner { args: Args, server: Box, @@ -466,7 +234,7 @@ impl ServersRunner { let mut svr: Box = if args.shared.use_old_http { Box::new( - Http09Server::new( + http09::HttpServer::new( args.now(), &[args.key.clone()], &[args.shared.alpn.clone()], @@ -477,7 +245,7 @@ impl ServersRunner { .expect("We cannot make a server!"), ) } else { - Box::new(SimpleServer::new(args, anti_replay, cid_mgr)) + Box::new(http3::HttpServer::new(args, anti_replay, cid_mgr)) }; svr.set_ciphers(&args.get_ciphers()); svr.set_qlog_dir(args.shared.qlog_dir.clone());