From 05a736e30418586cac44c9ce4327463d53dce3dd Mon Sep 17 00:00:00 2001 From: neonphog Date: Wed, 17 Apr 2024 18:07:21 -0600 Subject: [PATCH 01/33] checkpoint --- Cargo.lock | 1 + rust/sbd-client/src/lib.rs | 15 +- rust/sbd-server/Cargo.toml | 2 + rust/sbd-server/src/bin/sbd-serverd.rs | 8 +- rust/sbd-server/src/config.rs | 19 -- rust/sbd-server/src/cslot.rs | 276 ++++++++++++++++++ rust/sbd-server/src/ip_deny.rs | 24 ++ rust/sbd-server/src/ip_rate.rs | 147 ++++++++++ rust/sbd-server/src/lib.rs | 341 +++++++++++++---------- rust/sbd-server/src/ws/ws_tungstenite.rs | 2 +- rust/sbd-server/tests/suite.rs | 7 +- 11 files changed, 665 insertions(+), 177 deletions(-) create mode 100644 rust/sbd-server/src/cslot.rs create mode 100644 rust/sbd-server/src/ip_deny.rs create mode 100644 rust/sbd-server/src/ip_rate.rs diff --git a/Cargo.lock b/Cargo.lock index 1b03031..fe9c30f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1018,6 +1018,7 @@ dependencies = [ "rcgen", "sbd-client", "serde_json", + "slab", "tempfile", "tokio", "tokio-tungstenite", diff --git a/rust/sbd-client/src/lib.rs b/rust/sbd-client/src/lib.rs index 68708e4..70fae21 100644 --- a/rust/sbd-client/src/lib.rs +++ b/rust/sbd-client/src/lib.rs @@ -4,6 +4,7 @@ use std::io::{Error, Result}; use std::sync::Arc; +/// defined by the sbd spec const MAX_MSG_SIZE: usize = 16000; #[cfg(feature = "raw_client")] @@ -169,27 +170,21 @@ impl SbdClient { .await?; let handshake = recv.recv().await?; - if handshake.len() != 4 + 4 + 4 + 32 { + if handshake.len() != 4 + 4 + 32 { return Err(Error::other("invalid handshake")); } - let limit_msg = i32::from_be_bytes([ + let limit_rate = i32::from_be_bytes([ handshake[4], handshake[5], handshake[6], handshake[7], ]); - let limit_rate = i32::from_be_bytes([ - handshake[8], - handshake[9], - handshake[10], - handshake[11], - ]); - println!("msg: {limit_msg}, rate: {limit_rate}"); + println!("rate: {limit_rate}"); let mut nonce = [0; 32]; - nonce.copy_from_slice(&handshake[12..]); + nonce.copy_from_slice(&handshake[8..]); let sig = crypto.sign(&nonce); diff --git a/rust/sbd-server/Cargo.toml b/rust/sbd-server/Cargo.toml index 5609578..a612c75 100644 --- a/rust/sbd-server/Cargo.toml +++ b/rust/sbd-server/Cargo.toml @@ -10,6 +10,7 @@ bytes = "1.6.0" clap = { version = "4.5.4", features = [ "color", "derive", "wrap_help" ] } ed25519-dalek = { version = "2.1.1", default-features = false } rand = "0.8.5" +slab = "0.4.9" tokio = { version = "1.37.0", features = [ "full" ] } # feature tungstenite @@ -28,6 +29,7 @@ rcgen = "0.13.1" sbd-client = { version = "0.0.1-alpha", path = "../sbd-client" } serde_json = "1.0.116" tempfile = "3.10.1" +tokio = { version = "1.37.0", features = [ "test-util" ] } [features] default = [ "tungstenite" ] diff --git a/rust/sbd-server/src/bin/sbd-serverd.rs b/rust/sbd-server/src/bin/sbd-serverd.rs index 3f7d82a..2c79e2a 100644 --- a/rust/sbd-server/src/bin/sbd-serverd.rs +++ b/rust/sbd-server/src/bin/sbd-serverd.rs @@ -4,7 +4,11 @@ use std::sync::Arc; #[tokio::main(flavor = "multi_thread")] async fn main() { let config = ::parse(); - println!("#sbd-serverd# {config:?}"); - let _server = SbdServer::new(Arc::new(config)).await.unwrap(); + println!("#sbd-serverd#note# {config:?}"); + let server = SbdServer::new(Arc::new(config)).await.unwrap(); + for addr in server.bind_addrs() { + println!("#sbd-serverd#listening# {addr:?}"); + } + println!("#sbd-serverd#ready#"); std::future::pending::<()>().await; } diff --git a/rust/sbd-server/src/config.rs b/rust/sbd-server/src/config.rs index a74ad77..2f60eaa 100644 --- a/rust/sbd-server/src/config.rs +++ b/rust/sbd-server/src/config.rs @@ -1,8 +1,6 @@ const DEF_IP_DENY_DIR: &str = "."; const DEF_IP_DENY_S: i32 = 600; -const DEF_LIMIT_FROM_IP: i32 = 4; const DEF_LIMIT_CLIENTS: i32 = 32768; -const DEF_LIMIT_MESSAGE_BYTES: i32 = 16000; const DEF_LIMIT_IP_BYTE_NANOS: i32 = 8000; const DEF_LIMIT_IP_BYTE_BURST: i32 = 32768; @@ -81,35 +79,20 @@ pub struct Config { #[arg(long)] pub bind_prometheus: Option, - /// Limit connection count from a single IP. - #[arg(long, default_value_t = DEF_LIMIT_FROM_IP)] - pub limit_from_ip: i32, - /// Limit client connections. #[arg(long, default_value_t = DEF_LIMIT_CLIENTS)] pub limit_clients: i32, - /// Limit the size of individual messages in bytes. - /// The default is 384 bytes short of 16KiB to account for overhead. - #[arg(long, default_value_t = DEF_LIMIT_MESSAGE_BYTES)] - pub limit_message_bytes: i32, - /// How often in nanoseconds 1 byte is allowed to be sent from an IP. /// The default value of 8000 results in ~1 mbps being allowed. /// If the default of 32768 connections were all sending this amount /// at the same time, the server would need a ~33.6 gbps connection. - /// Note, this limit is sent to clients as the limit for an individual - /// connection. The limit on the server will be multiplied by the value - /// of `limit_from_ip`. #[arg(long, default_value_t = DEF_LIMIT_IP_BYTE_NANOS)] pub limit_ip_byte_nanos: i32, /// Allow IPs to burst by this byte count. /// If the max message size is 16K, this value must be at least 16K. /// The default value provides 2 * 16K for an additional buffer. - /// Note, this limit is not sent to clients but is the limit for an - /// individual connection. The limit on the server will be multiplied - /// by the value of `limit_from_ip`. #[arg(long, default_value_t = DEF_LIMIT_IP_BYTE_BURST)] pub limit_ip_byte_burst: i32, } @@ -129,9 +112,7 @@ impl Default for Config { back_allow_ip: Vec::new(), back_open: Vec::new(), bind_prometheus: None, - limit_from_ip: DEF_LIMIT_FROM_IP, limit_clients: DEF_LIMIT_CLIENTS, - limit_message_bytes: DEF_LIMIT_MESSAGE_BYTES, limit_ip_byte_nanos: DEF_LIMIT_IP_BYTE_NANOS, limit_ip_byte_burst: DEF_LIMIT_IP_BYTE_BURST, } diff --git a/rust/sbd-server/src/cslot.rs b/rust/sbd-server/src/cslot.rs new file mode 100644 index 0000000..ddfe49e --- /dev/null +++ b/rust/sbd-server/src/cslot.rs @@ -0,0 +1,276 @@ +//! Attempt to pre-allocate as much as possible, including our tokio tasks. +//! Ideally this would include a frame buffer that we could fill on ws +//! recv and use ase a reference for ws send, but alas, fastwebsockets +//! doesn't seem up to the task. tungstenite will willy-nilly allocate +//! buffers for us, but at least we should only be dealing with one at a +//! time per connection. + +use super::*; +use std::sync::{Arc, Mutex, Weak}; +use std::collections::HashMap; + +static U: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(1); + +enum TaskMsg { + NewWs { + uniq: u64, + index: usize, + ws: Arc>, + ip: Arc, + pk: PubKey, + }, + Close, +} + +struct SlotEntry { + send: tokio::sync::mpsc::UnboundedSender, +} + +struct SlabEntry { + uniq: u64, + weak_ws: Weak>, +} + +struct CSlotInner { + max_count: usize, + slots: Vec, + slab: slab::Slab, + pk_to_index: HashMap, + ip_to_index: HashMap, Vec>, + task_list: Vec>, +} + +impl Drop for CSlotInner { + fn drop(&mut self) { + for task in self.task_list.iter() { + task.abort(); + } + } +} + +struct WeakCSlot(Weak>); + +impl WeakCSlot { + pub fn upgrade(&self) -> Option { + self.0.upgrade().map(CSlot) + } +} + +pub struct CSlot(Arc>); + +impl CSlot { + pub fn new( + count: usize, + ip_deny: Arc, + ip_rate: Arc, + ) -> Self { + Self(Arc::new_cyclic(|this| { + let mut slots = Vec::with_capacity(count); + let mut task_list = Vec::with_capacity(count); + for _ in 0..count { + let (send, recv) = tokio::sync::mpsc::unbounded_channel(); + slots.push(SlotEntry { + send, + }); + tokio::task::spawn(top_task( + ip_deny.clone(), + ip_rate.clone(), + WeakCSlot(this.clone()), + recv, + )); + } + Mutex::new(CSlotInner { + max_count: count, + slots, + slab: slab::Slab::with_capacity(count), + pk_to_index: HashMap::with_capacity(count), + ip_to_index: HashMap::with_capacity(count), + task_list, + }) + })) + } + + pub fn remove(&self, uniq: u64, index: usize) { + let mut lock = self.0.lock().unwrap(); + + match lock.slab.get(index) { + None => return, + Some(s) => { + if s.uniq != uniq { + return; + } + } + } + + let _ = lock.slots.get(index).unwrap().send.send(TaskMsg::Close); + lock.slab.remove(index); + lock.pk_to_index.retain(|_, i| *i != index); + lock.ip_to_index.retain(|_, v| { + v.retain(|i| *i != index); + !v.is_empty() + }); + } + + pub fn insert( + &self, + ip: Arc, + pk: PubKey, + ws: Arc> + ) -> Result { + let mut lock = self.0.lock().unwrap(); + + if lock.slab.len() >= lock.max_count { + return Err(Error::other("too many connections")); + } + + let weak_ws = Arc::downgrade(&ws); + + let uniq = U.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + + let index = lock.slab.insert(SlabEntry { + uniq, + weak_ws, + }); + + lock.pk_to_index.insert(pk.clone(), index); + + // TODO - should we block more than Vec::with_capacity(count) + // connections from the same IP so we avoid allocating + // here? Or set this to the max connection count value? + + lock + .ip_to_index + .entry(ip.clone()) + .or_insert_with(|| Vec::with_capacity(1024)) + .push(index); + + // TODO - send rate updates to all clients on this ip + + let send = lock.slots.get(index).unwrap().send.clone(); + if send.send(TaskMsg::NewWs { uniq, index, ws, ip, pk }).is_err() { + return Err(Error::other("closed")); + } + + Ok(index) + } + + pub async fn send( + &self, + pk: &PubKey, + payload: Payload<'_>, + ) -> Result<()> { + let (uniq, index, ws) = { + // XXX - DO NOT AWAIT IN THIS BLOCK + let lock = self.0.lock().unwrap(); + + let index = match lock.pk_to_index.get(&pk) { + None => return Err(Error::other("no such peer")), + Some(index) => *index, + }; + + let slab = lock.slab.get(index).unwrap(); + let uniq = slab.uniq; + let ws = match slab.weak_ws.upgrade() { + None => return Err(Error::other("no such peer")), + Some(ws) => ws, + }; + + (uniq, index, ws) + }; + + match ws.send(payload).await { + Err(err) => { + self.remove(uniq, index); + Err(err) + } + Ok(_) => Ok(()), + } + } +} + +async fn top_task( + ip_deny: Arc, + ip_rate: Arc, + weak: WeakCSlot, + mut recv: tokio::sync::mpsc::UnboundedReceiver, +) { + while let Some(task_msg) = recv.recv().await { + match task_msg { + TaskMsg::NewWs { uniq, index, ws, ip, pk } => { + tokio::select! { + task_msg = recv.recv() => { + match task_msg { + None => break, + Some(TaskMsg::Close) => (), + _ => unreachable!(), + } + }, + _ = ws_task( + &ip_deny, + &ip_rate, + &weak, + index, + ws, + ip, + pk, + ) => (), + } + if let Some(cslot) = weak.upgrade() { + cslot.remove(uniq, index); + } + } + _ => unreachable!(), + } + } +} + +async fn ws_task( + ip_deny: &ip_deny::IpDeny, + ip_rate: &ip_rate::IpRate, + weak: &WeakCSlot, + index: usize, + ws: Arc>, + ip: Arc, + pk: PubKey, +) { + while let Ok(mut payload) = ws.recv().await { + if !ip_rate.is_ok(*ip, payload.len()) { + ip_deny.block(*ip).await.unwrap(); + break; + } + + if payload.len() < 32 { + break; + } + + const KEEPALIVE: &[u8; 32] = &[0; 32]; + + let dest = { + let payload = payload.to_mut(); + + if &payload[..32] == KEEPALIVE { + // TODO - keepalive + continue; + } + + if &payload[..32] == &pk.0[..] { + // no self-sends + break; + } + + let mut dest = [0; 32]; + dest.copy_from_slice(&payload[..32]); + let dest = PubKey(Arc::new(dest)); + + payload[..32].copy_from_slice(&pk.0[..]); + + dest + }; + + if let Some(cslot) = weak.upgrade() { + let _ = cslot.send(&dest, payload).await; + } else { + break; + } + } +} diff --git a/rust/sbd-server/src/ip_deny.rs b/rust/sbd-server/src/ip_deny.rs new file mode 100644 index 0000000..8ff17e1 --- /dev/null +++ b/rust/sbd-server/src/ip_deny.rs @@ -0,0 +1,24 @@ +//! CURRENTLY A STUB!! + +use crate::*; + +pub struct IpDeny; + +impl IpDeny { + /// Construct a new filesystem-based ip deny list. + pub fn new(_config: Arc) -> Self { + Self + } + + /// Check if a given ip is blocked. + pub async fn is_blocked(&self, _ip: std::net::Ipv6Addr) -> Result { + // THIS IS A STUB!! + Ok(false) + } + + /// Block a given ip. + pub async fn block(&self, _ip: std::net::Ipv6Addr) -> Result<()> { + // THIS IS A STUB!! + Ok(()) + } +} diff --git a/rust/sbd-server/src/ip_rate.rs b/rust/sbd-server/src/ip_rate.rs new file mode 100644 index 0000000..0af3ccf --- /dev/null +++ b/rust/sbd-server/src/ip_rate.rs @@ -0,0 +1,147 @@ +use std::collections::HashMap; +use std::sync::{Arc, Mutex}; + +type Map = HashMap; + +#[derive(Clone)] +pub struct IpRate { + origin: tokio::time::Instant, + map: Arc>, + limit: u64, + burst: u64, +} + +impl IpRate { + pub fn new(limit: u64, burst: u64) -> Self { + Self { + origin: tokio::time::Instant::now(), + map: Arc::new(Mutex::new(HashMap::new())), + limit, + burst, + } + } + + /// Prune entries that have tracked backwards 10s or more. + /// The 10s just prevents hashtable thrashing if a connection + /// is using significantly less than its rate limit. + /// This is why the keepalive interval is 5 seconds and + /// connections are closed after 10 seconds. + pub fn prune(&self) { + let now = self.origin.elapsed().as_nanos() as u64; + self.map.lock().unwrap().retain(|_, cur| { + if now <= *cur { + true + } else { + // examples using seconds: + // now:100,cur:120 100-120=-20<10 true=keep + // now:100,cur:100 100-100=0<10 true=keep + // now:100,cur:80 100-80=20<10 false=prune + now - *cur < 10_000_000_000 + } + }); + } + + /// Return true if we are not over the rate limit. + pub fn is_ok(&self, ip: std::net::Ipv6Addr, bytes: usize) -> bool { + // multiply by our rate allowed per byte + let rate_add = bytes as u64 * self.limit; + + // get now + let now = self.origin.elapsed().as_nanos() as u64; + + // lock the map mutex + let mut lock = self.map.lock().unwrap(); + + // get the entry (default to now) + let e = lock.entry(ip).or_insert(now); + + // if we've already used time greater than now use that, + // otherwise consider we're starting from scratch + let cur = std::cmp::max(*e, now) + rate_add; + + // update the map with the current limit + *e = cur; + + // subtract now back out to see if we're greater than our burst + cur - now <= self.burst + } +} + +#[cfg(test)] +mod tests { + use super::*; + + const ADDR1: std::net::Ipv6Addr = + std::net::Ipv6Addr::new(1, 1, 1, 1, 1, 1, 1, 1); + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn check_one_to_one() { + let rate = IpRate::new(1, 1); + + for _ in 0..10 { + // should always be ok when advancing with time + tokio::time::advance(std::time::Duration::from_nanos(1)).await; + assert!(rate.is_ok(ADDR1, 1)); + } + + // but one more without a time advance fails + assert!(!rate.is_ok(ADDR1, 1)); + + tokio::time::advance(std::time::Duration::from_nanos(1)).await; + + // make sure prune doesn't prune it yet + rate.prune(); + assert_eq!(1, rate.map.lock().unwrap().len()); + + tokio::time::advance(std::time::Duration::from_secs(10)).await; + + // make sure prune doesn't prune it yet + rate.prune(); + assert_eq!(1, rate.map.lock().unwrap().len()); + + // but one more should do it + tokio::time::advance(std::time::Duration::from_nanos(1)).await; + rate.prune(); + assert_eq!(0, rate.map.lock().unwrap().len()); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn check_burst() { + let rate = IpRate::new(1, 5); + + for _ in 0..5 { + assert!(rate.is_ok(ADDR1, 1)); + } + + assert!(!rate.is_ok(ADDR1, 1)); + + tokio::time::advance(std::time::Duration::from_nanos(2)).await; + assert!(rate.is_ok(ADDR1, 1)); + + tokio::time::advance(std::time::Duration::from_secs(10)).await; + tokio::time::advance(std::time::Duration::from_nanos(4)).await; + + rate.prune(); + assert_eq!(1, rate.map.lock().unwrap().len()); + + tokio::time::advance(std::time::Duration::from_nanos(1)).await; + + rate.prune(); + assert_eq!(0, rate.map.lock().unwrap().len()); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn check_limit_mult() { + let rate = IpRate::new(3, 13); + + assert!(rate.is_ok(ADDR1, 2)); + assert!(rate.is_ok(ADDR1, 2)); + assert!(!rate.is_ok(ADDR1, 2)); + + tokio::time::advance(std::time::Duration::from_secs(10)).await; + + assert!(rate.is_ok(ADDR1, 2)); + assert!(rate.is_ok(ADDR1, 2)); + assert!(!rate.is_ok(ADDR1, 2)); + } +} diff --git a/rust/sbd-server/src/lib.rs b/rust/sbd-server/src/lib.rs index d1610c2..d8cddc0 100644 --- a/rust/sbd-server/src/lib.rs +++ b/rust/sbd-server/src/lib.rs @@ -1,6 +1,9 @@ //! Sbd server library. #![deny(missing_docs)] +/// defined by the sbd spec +const MAX_MSG_BYTES: i32 = 16000; + use std::io::{Error, Result}; use std::sync::{Arc, Mutex, Weak}; @@ -10,6 +13,11 @@ pub use config::*; mod maybe_tls; use maybe_tls::*; +mod ip_deny; +mod ip_rate; + +mod cslot; + /// Websocket backend abstraction. pub mod ws { /// Payload. @@ -121,6 +129,27 @@ impl ClientMap { self.0.insert(pub_key, client_info); } + pub fn remove_ws( + &mut self, + pub_key: &PubKey, + subj_ws: &Arc>, + ) { + let should_drop = + if let Some(ClientInfo::Local { ws, .. }) = self.0.get(pub_key) { + if Arc::ptr_eq(subj_ws, ws) { + true + } else { + false + } + } else { + false + }; + + if should_drop { + self.0.remove(pub_key); + } + } + pub fn get_ws( &mut self, pub_key: &PubKey, @@ -134,10 +163,9 @@ impl ClientMap { /// SbdServer. pub struct SbdServer { - config: Arc, task_list: Vec>, bind_addrs: Vec, - client_map: Arc>, + _client_map: Arc>, } impl Drop for SbdServer { @@ -150,102 +178,192 @@ impl Drop for SbdServer { async fn check_accept_connection( config: Arc, + ip_deny: Arc, + ip_rate: Arc, tcp: MaybeTlsStream, addr: std::net::SocketAddr, weak_client_map: Weak>, ) { - const PROTO_VER: &[u8; 4] = b"sbd0"; - let limit_msg = config.limit_message_bytes.to_be_bytes(); - let limit_rate = config.limit_ip_byte_nanos.to_be_bytes(); - let raw_ip = match addr.ip() { std::net::IpAddr::V4(ip) => ip.to_ipv6_mapped(), std::net::IpAddr::V6(ip) => ip, }; + drop(addr); - // TODO if config.trusted_ip_header.is_none() do the ip check BEFORE upgrade + let mut calc_ip = raw_ip; - let (ws, pub_key, ip) = ws::WebSocket::upgrade(config, tcp).await.unwrap(); + let use_trusted_ip = config.trusted_ip_header.is_some(); - let ip: std::net::Ipv6Addr = if let Some(ip) = ip { ip } else { raw_ip }; + let (pub_key, client_info) = + match tokio::time::timeout(std::time::Duration::from_secs(10), async { + const PROTO_VER: &[u8; 4] = b"sbd0"; + let limit_rate = config.limit_ip_byte_nanos.to_be_bytes(); - // TODO if config.trusted_ip_header.is_some() do the ip check AFTER upgrade + if !use_trusted_ip { + // Do this check BEFORE handshake to avoid extra + // server process when capable. + // If we *are* behind a reverse proxy, we assume + // some amount of DDoS mitigation is happening there + // and thus we can accept a little more process overhead + if ip_deny.is_blocked(raw_ip).await.unwrap() { + return Err(Error::other("ip blocked")); + } - use rand::Rng; - let mut nonce = [0xdb; 32]; - rand::thread_rng().fill(&mut nonce[..]); + // Also precheck our rate limit, using up one byte + if !ip_rate.is_ok(raw_ip, 1) { + ip_deny.block(raw_ip).await.unwrap(); + return Err(Error::other("ip rate limited")); + } + } - let mut msg = Vec::with_capacity(4 + 4 + 4 + 32); - msg.extend_from_slice(&PROTO_VER[..]); - msg.extend_from_slice(&limit_msg[..]); - msg.extend_from_slice(&limit_rate[..]); - msg.extend_from_slice(&nonce[..]); + // TODO TLS upgrade - ws.send(Payload::Vec(msg)).await.unwrap(); + let (ws, pub_key, ip) = + ws::WebSocket::upgrade(config, tcp).await.unwrap(); - let sig = ws.recv().await.unwrap(); - if sig.len() != 64 { - return; - } - let mut sig_sized = [0; 64]; - sig_sized.copy_from_slice(sig.as_ref()); - if !pub_key.verify(&sig_sized, &nonce) { - return; - } + if let Some(ip) = ip { + calc_ip = ip; + } - let ws = Arc::new(ws); - let ws2 = ws.clone(); - let weak_client_map2 = weak_client_map.clone(); - let pub_key2 = pub_key.clone(); - let read_task = tokio::task::spawn(async move { - while let Ok(mut payload) = ws2.recv().await { - // TODO - rate limiting + if use_trusted_ip { + // if using a trusted ip, check block here. + // see note above before the handshakes. + if ip_deny.is_blocked(calc_ip).await.unwrap() { + return Err(Error::other("ip blocked")); + } - if payload.len() < 32 { - break; + // Also precheck our rate limit, using up one byte + if !ip_rate.is_ok(calc_ip, 1) { + ip_deny.block(calc_ip).await.unwrap(); + return Err(Error::other("ip rate limited")); + } } - const KEEPALIVE: &[u8; 32] = &[0; 32]; + use rand::Rng; + let mut nonce = [0xdb; 32]; + rand::thread_rng().fill(&mut nonce[..]); - let dest = { - let payload = payload.to_mut(); + let mut msg = Vec::with_capacity(4 + 4 + 32); + msg.extend_from_slice(&PROTO_VER[..]); + msg.extend_from_slice(&limit_rate[..]); + msg.extend_from_slice(&nonce[..]); - if &payload[..32] == KEEPALIVE { - // TODO - keepalive - continue; - } + ws.send(Payload::Vec(msg)).await.unwrap(); - let mut dest = [0; 32]; - dest.copy_from_slice(&payload[..32]); - let dest = PubKey(Arc::new(dest)); + let sig = ws.recv().await.unwrap(); - payload[..32].copy_from_slice(&pub_key2.0[..]); + // use up 64 bytes of rate + if !ip_rate.is_ok(calc_ip, 64) { + ip_deny.block(calc_ip).await.unwrap(); + return Err(Error::other("ip rate limited")); + } - dest - }; + if sig.len() != 64 { + return Err(Error::other("invalid sig len")); + } + let mut sig_sized = [0; 64]; + sig_sized.copy_from_slice(sig.as_ref()); + if !pub_key.verify(&sig_sized, &nonce) { + return Err(Error::other("invalid sig")); + } - if let Some(client_map) = weak_client_map2.upgrade() { - let ws = client_map.lock().unwrap().get_ws(&dest); - if let Some(ws) = ws { - if ws.send(payload).await.is_err() { - break; + let ws = Arc::new(ws); + + struct DoDrop { + pub_key: PubKey, + ws: Weak>, + client_map: Weak>, + } + + impl Drop for DoDrop { + fn drop(&mut self) { + if let Some(client_map) = self.client_map.upgrade() { + if let Some(ws) = self.ws.upgrade() { + client_map + .lock() + .unwrap() + .remove_ws(&self.pub_key, &ws); + } } } - } else { - break; } - } - // TODO delete us from the client map - }); + let do_drop = DoDrop { + pub_key: pub_key.clone(), + ws: Arc::downgrade(&ws), + client_map: weak_client_map.clone(), + }; + + let ws2 = ws.clone(); + let weak_client_map2 = weak_client_map.clone(); + let pub_key2 = pub_key.clone(); + let read_task = tokio::task::spawn(async move { + let _do_drop = do_drop; + + while let Ok(mut payload) = ws2.recv().await { + if !ip_rate.is_ok(calc_ip, payload.len()) { + ip_deny.block(calc_ip).await.unwrap(); + break; + } + + if payload.len() < 32 { + break; + } + + const KEEPALIVE: &[u8; 32] = &[0; 32]; + + let dest = { + let payload = payload.to_mut(); + + if &payload[..32] == KEEPALIVE { + // TODO - keepalive + continue; + } + + if &payload[..32] == &pub_key2.0[..] { + // no self-sends + break; + } + + let mut dest = [0; 32]; + dest.copy_from_slice(&payload[..32]); + let dest = PubKey(Arc::new(dest)); + + payload[..32].copy_from_slice(&pub_key2.0[..]); + + dest + }; + + if let Some(client_map) = weak_client_map2.upgrade() { + let ws = client_map.lock().unwrap().get_ws(&dest); + if let Some(ws) = ws { + if ws.send(payload).await.is_err() { + break; + } + } + } else { + break; + } + } + }); + + Ok(( + pub_key, + ClientInfo::Local { + ws, + ip: calc_ip, + read_task, + }, + )) + }) + .await + { + Ok(Ok(r)) => r, + _ => return, + }; if let Some(client_map) = weak_client_map.upgrade() { - client_map - .lock() - .unwrap() - .insert(pub_key, ClientInfo::Local { ws, ip, read_task }); - } else { - read_task.abort(); + client_map.lock().unwrap().insert(pub_key, client_info); } } @@ -256,6 +374,14 @@ impl SbdServer { let mut bind_addrs = Vec::new(); let client_map = Arc::new(Mutex::new(ClientMap::default())); + let ip_deny = Arc::new(ip_deny::IpDeny::new(config.clone())); + + let ip_rate = Arc::new(ip_rate::IpRate::new( + config.limit_ip_byte_nanos as u64, + config.limit_ip_byte_nanos as u64 + * config.limit_ip_byte_burst as u64, + )); + let weak_client_map = Arc::downgrade(&client_map); for bind in config.bind.iter() { let a: std::net::SocketAddr = bind.parse().map_err(Error::other)?; @@ -265,6 +391,8 @@ impl SbdServer { let config = config.clone(); let weak_client_map = weak_client_map.clone(); + let ip_deny = ip_deny.clone(); + let ip_rate = ip_rate.clone(); task_list.push(tokio::task::spawn(async move { loop { if let Ok((tcp, addr)) = tcp.accept().await { @@ -272,6 +400,8 @@ impl SbdServer { // just let this task die on its own time tokio::task::spawn(check_accept_connection( config.clone(), + ip_deny.clone(), + ip_rate.clone(), tcp, addr, weak_client_map.clone(), @@ -282,10 +412,9 @@ impl SbdServer { } Ok(Self { - config, task_list, bind_addrs, - client_map, + _client_map: client_map, }) } @@ -294,79 +423,3 @@ impl SbdServer { self.bind_addrs.as_slice() } } - -#[cfg(test)] -mod tests { - use super::*; - - #[tokio::test(flavor = "multi_thread")] - async fn sanity() { - let tmp = tempfile::tempdir().unwrap(); - let tmp_dir = tmp.path().to_owned(); - let rcgen::CertifiedKey { cert, key_pair } = - rcgen::generate_simple_self_signed(vec!["localhost".to_string()]) - .unwrap(); - let mut cert_path = tmp_dir.clone(); - cert_path.push("cert.pem"); - tokio::fs::write(&cert_path, cert.pem()).await.unwrap(); - let mut key_path = tmp_dir.clone(); - key_path.push("key.pem"); - tokio::fs::write(&key_path, key_pair.serialize_pem()) - .await - .unwrap(); - - let mut config = Config::default(); - config.cert_pem_file = Some(cert_path); - config.priv_key_pem_file = Some(key_path); - config.bind.push("127.0.0.1:0".into()); - println!("{config:?}"); - - let server = SbdServer::new(Arc::new(config)).await.unwrap(); - - let addr = server.bind_addrs()[0].clone(); - - println!("addr: {addr:?}"); - - let (client1, url1, pk1, mut rcv1) = - sbd_client::SbdClient::connect_config( - &format!("ws://{addr}"), - &sbd_client::DefaultCrypto::default(), - sbd_client::SbdClientConfig { - allow_plain_text: true, - ..Default::default() - }, - ) - .await - .unwrap(); - - println!("client url1: {url1}"); - - let (client2, url2, pk2, mut rcv2) = - sbd_client::SbdClient::connect_config( - &format!("ws://{addr}"), - &sbd_client::DefaultCrypto::default(), - sbd_client::SbdClientConfig { - allow_plain_text: true, - ..Default::default() - }, - ) - .await - .unwrap(); - - println!("client url2: {url2}"); - - client1.send(&pk2, b"hello").await.unwrap(); - - let res_data = rcv2.recv().await.unwrap(); - - assert_eq!(&pk1.0, res_data.pub_key_ref()); - assert_eq!(&b"hello"[..], res_data.message()); - - client2.send(&pk1, b"world").await.unwrap(); - - let res_data = rcv1.recv().await.unwrap(); - - assert_eq!(&pk2.0, res_data.pub_key_ref()); - assert_eq!(&b"world"[..], res_data.message()); - } -} diff --git a/rust/sbd-server/src/ws/ws_tungstenite.rs b/rust/sbd-server/src/ws/ws_tungstenite.rs index fefc073..5a47b29 100644 --- a/rust/sbd-server/src/ws/ws_tungstenite.rs +++ b/rust/sbd-server/src/ws/ws_tungstenite.rs @@ -33,7 +33,7 @@ where }; let mut trusted_ip = None; let mut ws = WebSocketConfig::default(); - ws.max_message_size = Some(config.limit_message_bytes as usize); + ws.max_message_size = Some(MAX_MSG_BYTES as usize); struct Cb(tokio::sync::oneshot::Sender); impl server::Callback for Cb { fn on_request( diff --git a/rust/sbd-server/tests/suite.rs b/rust/sbd-server/tests/suite.rs index b7887ef..ee77980 100644 --- a/rust/sbd-server/tests/suite.rs +++ b/rust/sbd-server/tests/suite.rs @@ -18,5 +18,10 @@ fn suite() { .unwrap(); println!("RUNNING the test suite {:?}", suite.path()); - assert!(suite.command().arg(server.path()).status().unwrap().success()); + assert!(suite + .command() + .arg(server.path()) + .status() + .unwrap() + .success()); } From 8497c1454d44814a82455536f21b8663d6f60196 Mon Sep 17 00:00:00 2001 From: neonphog Date: Thu, 18 Apr 2024 14:29:06 -0600 Subject: [PATCH 02/33] checkpoint --- rust/sbd-client/src/lib.rs | 130 +++++-- .../src/bin/sbd-server-test-suite-bin.rs | 5 +- rust/sbd-server-test-suite/src/it.rs | 23 +- rust/sbd-server-test-suite/src/it/it_1.rs | 5 +- rust/sbd-server-test-suite/src/lib.rs | 6 +- rust/sbd-server/src/cmd.rs | 85 +++++ rust/sbd-server/src/cslot.rs | 278 ++++++++++----- rust/sbd-server/src/ip_deny.rs | 7 +- rust/sbd-server/src/ip_rate.rs | 101 ++++-- rust/sbd-server/src/lib.rs | 329 +++++------------- 10 files changed, 559 insertions(+), 410 deletions(-) create mode 100644 rust/sbd-server/src/cmd.rs diff --git a/rust/sbd-client/src/lib.rs b/rust/sbd-client/src/lib.rs index 70fae21..dc75a8a 100644 --- a/rust/sbd-client/src/lib.rs +++ b/rust/sbd-client/src/lib.rs @@ -20,7 +20,7 @@ pub trait Crypto { fn pub_key(&self) -> &[u8; 32]; /// Sign the nonce. - fn sign(&self, nonce: &[u8; 32]) -> [u8; 64]; + fn sign(&self, nonce: &[u8]) -> [u8; 64]; } #[cfg(feature = "crypto")] @@ -42,9 +42,9 @@ mod default_crypto { &self.0 } - fn sign(&self, nonce: &[u8; 32]) -> [u8; 64] { + fn sign(&self, nonce: &[u8]) -> [u8; 64] { use ed25519_dalek::Signer; - self.1.sign(&nonce[..]).to_bytes() + self.1.sign(nonce).to_bytes() } } } @@ -64,6 +64,22 @@ impl std::fmt::Debug for PubKey { } } +const CMD_FLAG: &[u8; 28] = &[0; 28]; + +enum MsgType<'t> { + Msg { + #[allow(dead_code)] + pub_key: &'t [u8], + #[allow(dead_code)] + message: &'t [u8], + }, + LimitByteNanos(i32), + LimitIdleMillis(i32), + AuthReq(&'t [u8]), + Ready, + Unknown, +} + /// A message received from a remote. /// This is just a single buffer. The first 32 bytes are the public key /// of the sender. Any remaining bytes are the message. The buffer @@ -85,6 +101,47 @@ impl Msg { pub fn message(&self) -> &[u8] { &self.0[32..] } + + // -- private -- // + + fn parse(&self) -> Result> { + if self.0.len() < 32 { + return Err(Error::other("invalid message length")); + } + if &self.0[..28] == CMD_FLAG { + match &self.0[28..32] { + b"lbrt" => { + if self.0.len() != 32 + 4 { + return Err(Error::other("invalid lbrt length")); + } + Ok(MsgType::LimitByteNanos(i32::from_be_bytes( + self.0[32..].try_into().unwrap(), + ))) + } + b"lidl" => { + if self.0.len() != 32 + 4 { + return Err(Error::other("invalid lidl length")); + } + Ok(MsgType::LimitIdleMillis(i32::from_be_bytes( + self.0[32..].try_into().unwrap(), + ))) + } + b"areq" => { + if self.0.len() != 32 + 32 { + return Err(Error::other("invalid areq length")); + } + Ok(MsgType::AuthReq(&self.0[32..])) + } + b"srdy" => Ok(MsgType::Ready), + _ => Ok(MsgType::Unknown), + } + } else { + Ok(MsgType::Msg { + pub_key: &self.0[..32], + message: &self.0[32..], + }) + } + } } /// Handle to receive data from the sbd connection. @@ -169,40 +226,51 @@ impl SbdClient { .connect() .await?; - let handshake = recv.recv().await?; - if handshake.len() != 4 + 4 + 32 { - return Err(Error::other("invalid handshake")); - } - - let limit_rate = i32::from_be_bytes([ - handshake[4], - handshake[5], - handshake[6], - handshake[7], - ]); - - println!("rate: {limit_rate}"); - - let mut nonce = [0; 32]; - nonce.copy_from_slice(&handshake[8..]); + let mut limit_byte_nanos = 8000; - let sig = crypto.sign(&nonce); + loop { + match Msg(recv.recv().await?).parse()? { + MsgType::Msg { .. } => { + return Err(Error::other("invalid handshake")) + } + MsgType::LimitByteNanos(l) => limit_byte_nanos = l, + MsgType::LimitIdleMillis(_) => (), + MsgType::AuthReq(nonce) => { + let sig = crypto.sign(&nonce); + let mut auth_res = Vec::with_capacity(32 + 64); + auth_res.extend_from_slice(CMD_FLAG); + auth_res.extend_from_slice(b"ares"); + auth_res.extend_from_slice(&sig); + send.send(auth_res).await?; + } + MsgType::Ready => break, + MsgType::Unknown => (), + } + } - send.send(sig.to_vec()).await?; + println!("limit_byte_nanos: {limit_byte_nanos}"); let (recv_send, recv_recv) = tokio::sync::mpsc::channel(4); let read_task = tokio::task::spawn(async move { while let Ok(data) = recv.recv().await { - if data.len() < 32 { - break; - } - - if &data[..32] == &[0; 32] { - break; - } + let data = Msg(data); - if recv_send.send(Msg(data)).await.is_err() { - break; + match match data.parse() { + Ok(data) => data, + Err(_) => break, + } { + MsgType::Msg { .. } => { + if recv_send.send(data).await.is_err() { + break; + } + } + MsgType::LimitByteNanos(rate) => { + eprintln!("UPDATED RATE {rate}"); + } + MsgType::LimitIdleMillis(_) => todo!(), + MsgType::AuthReq(_) => break, + MsgType::Ready => (), + MsgType::Unknown => (), } } @@ -214,7 +282,7 @@ impl SbdClient { buf: std::collections::VecDeque::new(), out_buffer_size: config.out_buffer_size, origin: tokio::time::Instant::now(), - limit_rate: (limit_rate as f64 * 0.9) as u64, + limit_rate: (limit_byte_nanos as f64 * 0.9) as u64, next_send_at: 0, }; let send_buf = Arc::new(tokio::sync::Mutex::new(send_buf)); diff --git a/rust/sbd-server-test-suite/src/bin/sbd-server-test-suite-bin.rs b/rust/sbd-server-test-suite/src/bin/sbd-server-test-suite-bin.rs index 7fc7b78..a42beeb 100644 --- a/rust/sbd-server-test-suite/src/bin/sbd-server-test-suite-bin.rs +++ b/rust/sbd-server-test-suite/src/bin/sbd-server-test-suite-bin.rs @@ -2,7 +2,10 @@ async fn main() { let mut args = std::env::args_os(); args.next().unwrap(); - let result = sbd_server_test_suite::run(args.next().expect("Expected Sbd Server Suite Runner")).await; + let result = sbd_server_test_suite::run( + args.next().expect("Expected Sbd Server Suite Runner"), + ) + .await; println!("{result:#?}"); if !result.failed.is_empty() { panic!("TEST FAILED"); diff --git a/rust/sbd-server-test-suite/src/it.rs b/rust/sbd-server-test-suite/src/it.rs index 2d29dd1..8f825d7 100644 --- a/rust/sbd-server-test-suite/src/it.rs +++ b/rust/sbd-server-test-suite/src/it.rs @@ -1,5 +1,5 @@ use std::future::Future; -use std::io::{Result, Error}; +use std::io::{Error, Result}; use crate::Report; @@ -30,14 +30,27 @@ impl<'h> TestHelper<'h> { } /// expect a condition to be true - pub fn expect(&mut self, file: &'static str, line: u32, cond: bool, note: &'static str) { + pub fn expect( + &mut self, + file: &'static str, + line: u32, + cond: bool, + note: &'static str, + ) { if !cond { self.err_list.push(format!("{file}:{line}: failed: {note}")); } } /// connect a client - pub async fn connect_client(&self) -> Result<(sbd_client::SbdClient, String, sbd_client::PubKey, sbd_client::MsgRecv)> { + pub async fn connect_client( + &self, + ) -> Result<( + sbd_client::SbdClient, + String, + sbd_client::PubKey, + sbd_client::MsgRecv, + )> { for addr in self.addr_list.iter() { if let Ok(client) = sbd_client::SbdClient::connect_config( &format!("ws://{addr}"), @@ -46,7 +59,9 @@ impl<'h> TestHelper<'h> { allow_plain_text: true, ..Default::default() }, - ).await { + ) + .await + { return Ok(client); } } diff --git a/rust/sbd-server-test-suite/src/it/it_1.rs b/rust/sbd-server-test-suite/src/it/it_1.rs index e6ddba1..9d06bbc 100644 --- a/rust/sbd-server-test-suite/src/it/it_1.rs +++ b/rust/sbd-server-test-suite/src/it/it_1.rs @@ -13,10 +13,7 @@ impl It for It1 { helper.connect_client(), )?; - tokio::try_join!( - c1.send(&p2, b"hello"), - c2.send(&p1, b"world"), - )?; + tokio::try_join!(c1.send(&p2, b"hello"), c2.send(&p1, b"world"),)?; let (result1, result2) = tokio::try_join!( async { r1.recv().await.ok_or(Error::other("closed")) }, diff --git a/rust/sbd-server-test-suite/src/lib.rs b/rust/sbd-server-test-suite/src/lib.rs index 4720e8a..3af5bca 100644 --- a/rust/sbd-server-test-suite/src/lib.rs +++ b/rust/sbd-server-test-suite/src/lib.rs @@ -62,8 +62,7 @@ struct Server { impl Server { pub async fn spawn>(cmd: S) -> Result { let mut cmd = tokio::process::Command::new(cmd); - cmd - .kill_on_drop(true) + cmd.kill_on_drop(true) .stdin(std::process::Stdio::piped()) .stdout(std::process::Stdio::piped()); @@ -71,7 +70,8 @@ impl Server { let mut child = cmd.spawn()?; let stdin = child.stdin.take().unwrap(); - let mut stdout = tokio::io::BufReader::new(child.stdout.take().unwrap()).lines(); + let mut stdout = + tokio::io::BufReader::new(child.stdout.take().unwrap()).lines(); if let Some(line) = stdout.next_line().await? { if line != "CMD:READY" { diff --git a/rust/sbd-server/src/cmd.rs b/rust/sbd-server/src/cmd.rs new file mode 100644 index 0000000..9f5daa3 --- /dev/null +++ b/rust/sbd-server/src/cmd.rs @@ -0,0 +1,85 @@ +use crate::*; + +const F_KEEPALIVE: &[u8] = b"keep"; +const F_LIMIT_BYTE_NANOS: &[u8] = b"lbrt"; +const F_LIMIT_IDLE_MILLIS: &[u8] = b"lidl"; +const F_AUTH_REQ: &[u8] = b"areq"; +const F_AUTH_RES: &[u8] = b"ares"; +//const F_READY: &[u8] = b"srdy"; + +/// Sbd commands. This enum only includes the types that clients send. +/// The class contains only methods for generating commands that can +/// be sent to the client. +pub enum SbdCmd<'c> { + Message(Payload<'c>), + Keepalive, + //LimitByteNanos(i32), + //LimitIdleMillis(i32), + //AuthReq([u8; 32]), + AuthRes([u8; 64]), + //Ready, + Unknown, +} + +const CMD_FLAG: &[u8; 28] = &[0; 28]; + +impl<'c> SbdCmd<'c> { + pub fn parse(payload: Payload<'c>) -> Result { + if payload.as_ref().len() < 32 { + return Err(Error::other("invalid payload length")); + } + if &payload.as_ref()[..28] == CMD_FLAG { + // only include the messages that clients should send + // mark everything else as Unknown + match &payload.as_ref()[28..32] { + F_KEEPALIVE => Ok(SbdCmd::Keepalive), + F_AUTH_RES => { + if payload.as_ref().len() != 32 + 64 { + return Err(Error::other("invalid auth res length")); + } + let mut sig = [0; 64]; + sig.copy_from_slice(&payload.as_ref()[32..]); + Ok(SbdCmd::AuthRes(sig)) + } + _ => Ok(SbdCmd::Unknown), + } + } else { + Ok(SbdCmd::Message(payload)) + } + } +} + +impl SbdCmd<'_> { + pub fn limit_byte_nanos(limit_byte_nanos: i32) -> Payload<'static> { + let mut out = Vec::with_capacity(32 + 4); + let n = limit_byte_nanos.to_be_bytes(); + out.extend_from_slice(CMD_FLAG); + out.extend_from_slice(F_LIMIT_BYTE_NANOS); + out.extend_from_slice(&n[..]); + Payload::Vec(out) + } + + pub fn limit_idle_millis(limit_idle_millis: i32) -> Payload<'static> { + let mut out = Vec::with_capacity(32 + 4); + let n = limit_idle_millis.to_be_bytes(); + out.extend_from_slice(CMD_FLAG); + out.extend_from_slice(F_LIMIT_IDLE_MILLIS); + out.extend_from_slice(&n[..]); + Payload::Vec(out) + } + + pub fn auth_req(nonce: &[u8; 32]) -> Payload<'static> { + let mut out = Vec::with_capacity(32 + 32); + out.extend_from_slice(CMD_FLAG); + out.extend_from_slice(F_AUTH_REQ); + out.extend_from_slice(&nonce[..]); + Payload::Vec(out) + } + + pub fn ready() -> Payload<'static> { + Payload::Slice(&[ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, b's', b'r', b'd', b'y', + ]) + } +} diff --git a/rust/sbd-server/src/cslot.rs b/rust/sbd-server/src/cslot.rs index ddfe49e..51af430 100644 --- a/rust/sbd-server/src/cslot.rs +++ b/rust/sbd-server/src/cslot.rs @@ -6,8 +6,8 @@ //! time per connection. use super::*; -use std::sync::{Arc, Mutex, Weak}; use std::collections::HashMap; +use std::sync::{Arc, Mutex, Weak}; static U: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(1); @@ -28,6 +28,7 @@ struct SlotEntry { struct SlabEntry { uniq: u64, + handshake_complete: bool, weak_ws: Weak>, } @@ -48,7 +49,8 @@ impl Drop for CSlotInner { } } -struct WeakCSlot(Weak>); +#[derive(Clone)] +pub struct WeakCSlot(Weak>); impl WeakCSlot { pub fn upgrade(&self) -> Option { @@ -59,25 +61,20 @@ impl WeakCSlot { pub struct CSlot(Arc>); impl CSlot { - pub fn new( - count: usize, - ip_deny: Arc, - ip_rate: Arc, - ) -> Self { + pub fn new(config: Arc, ip_rate: Arc) -> Self { + let count = config.limit_clients as usize; Self(Arc::new_cyclic(|this| { let mut slots = Vec::with_capacity(count); let mut task_list = Vec::with_capacity(count); for _ in 0..count { let (send, recv) = tokio::sync::mpsc::unbounded_channel(); - slots.push(SlotEntry { - send, - }); - tokio::task::spawn(top_task( - ip_deny.clone(), + slots.push(SlotEntry { send }); + task_list.push(tokio::task::spawn(top_task( + config.clone(), ip_rate.clone(), WeakCSlot(this.clone()), recv, - )); + ))); } Mutex::new(CSlotInner { max_count: count, @@ -90,7 +87,11 @@ impl CSlot { })) } - pub fn remove(&self, uniq: u64, index: usize) { + pub fn weak(&self) -> WeakCSlot { + WeakCSlot(Arc::downgrade(&self.0)) + } + + fn remove(&self, uniq: u64, index: usize) { let mut lock = self.0.lock().unwrap(); match lock.slab.get(index) { @@ -111,16 +112,16 @@ impl CSlot { }); } - pub fn insert( + fn insert_and_get_rate_send_list( &self, ip: Arc, pk: PubKey, - ws: Arc> - ) -> Result { + ws: Arc>, + ) -> Option>)>> { let mut lock = self.0.lock().unwrap(); if lock.slab.len() >= lock.max_count { - return Err(Error::other("too many connections")); + return None; } let weak_ws = Arc::downgrade(&ws); @@ -130,54 +131,120 @@ impl CSlot { let index = lock.slab.insert(SlabEntry { uniq, weak_ws, + handshake_complete: false, }); lock.pk_to_index.insert(pk.clone(), index); - // TODO - should we block more than Vec::with_capacity(count) - // connections from the same IP so we avoid allocating - // here? Or set this to the max connection count value? + let rate_send_list = { + let list = { + // WARN - allocation here! + // Also, do we want to limit the max connections from same ip? - lock - .ip_to_index - .entry(ip.clone()) - .or_insert_with(|| Vec::with_capacity(1024)) - .push(index); + let e = lock + .ip_to_index + .entry(ip.clone()) + .or_insert_with(|| Vec::with_capacity(1024)); - // TODO - send rate updates to all clients on this ip + e.push(index); + + e.clone() + }; + + let mut rate_send_list = Vec::with_capacity(list.len()); + + for index in list.iter() { + if let Some(slab) = lock.slab.get(*index) { + rate_send_list.push(( + slab.uniq, + *index, + slab.weak_ws.clone(), + )); + } + } + + rate_send_list + }; let send = lock.slots.get(index).unwrap().send.clone(); - if send.send(TaskMsg::NewWs { uniq, index, ws, ip, pk }).is_err() { - return Err(Error::other("closed")); + let _ = send.send(TaskMsg::NewWs { + uniq, + index, + ws, + ip, + pk, + }); + + Some(rate_send_list) + } + + pub async fn insert( + &self, + ip: Arc, + pk: PubKey, + ws: Arc>, + limit_ip_byte_nanos: i32, + ) { + let rate_send_list = self.insert_and_get_rate_send_list(ip, pk, ws); + + if let Some(rate_send_list) = rate_send_list { + let mut rate = limit_ip_byte_nanos / rate_send_list.len() as i32; + if rate < 1 { + rate = 1; + } + + for (uniq, index, weak_ws) in rate_send_list { + if let Some(ws) = weak_ws.upgrade() { + if ws + .send(cmd::SbdCmd::limit_byte_nanos(rate)) + .await + .is_err() + { + self.remove(uniq, index); + } + } + } } + } - Ok(index) + fn mark_ready(&self, uniq: u64, index: usize) { + let mut lock = self.0.lock().unwrap(); + if let Some(slab) = lock.slab.get_mut(index) { + if slab.uniq == uniq { + slab.handshake_complete = true; + } + } } - pub async fn send( + fn get_sender( &self, pk: &PubKey, - payload: Payload<'_>, - ) -> Result<()> { - let (uniq, index, ws) = { - // XXX - DO NOT AWAIT IN THIS BLOCK - let lock = self.0.lock().unwrap(); - - let index = match lock.pk_to_index.get(&pk) { - None => return Err(Error::other("no such peer")), - Some(index) => *index, - }; + ) -> Result<(u64, usize, Arc>)> { + let lock = self.0.lock().unwrap(); - let slab = lock.slab.get(index).unwrap(); - let uniq = slab.uniq; - let ws = match slab.weak_ws.upgrade() { - None => return Err(Error::other("no such peer")), - Some(ws) => ws, - }; + let index = match lock.pk_to_index.get(&pk) { + None => return Err(Error::other("no such peer")), + Some(index) => *index, + }; - (uniq, index, ws) + let slab = lock.slab.get(index).unwrap(); + + if !slab.handshake_complete { + return Err(Error::other("no such peer")); + } + + let uniq = slab.uniq; + let ws = match slab.weak_ws.upgrade() { + None => return Err(Error::other("no such peer")), + Some(ws) => ws, }; + Ok((uniq, index, ws)) + } + + async fn send(&self, pk: &PubKey, payload: Payload<'_>) -> Result<()> { + let (uniq, index, ws) = self.get_sender(pk)?; + match ws.send(payload).await { Err(err) => { self.remove(uniq, index); @@ -189,88 +256,131 @@ impl CSlot { } async fn top_task( - ip_deny: Arc, + config: Arc, ip_rate: Arc, weak: WeakCSlot, mut recv: tokio::sync::mpsc::UnboundedReceiver, ) { while let Some(task_msg) = recv.recv().await { match task_msg { - TaskMsg::NewWs { uniq, index, ws, ip, pk } => { + TaskMsg::NewWs { + uniq, + index, + ws, + ip, + pk, + } => { tokio::select! { task_msg = recv.recv() => { match task_msg { - None => break, Some(TaskMsg::Close) => (), - _ => unreachable!(), + _ => break, } }, _ = ws_task( - &ip_deny, + &config, &ip_rate, &weak, - index, ws, ip, pk, + uniq, + index, ) => (), } if let Some(cslot) = weak.upgrade() { cslot.remove(uniq, index); } } - _ => unreachable!(), + _ => (), } } } async fn ws_task( - ip_deny: &ip_deny::IpDeny, + config: &Arc, ip_rate: &ip_rate::IpRate, - weak: &WeakCSlot, - index: usize, + weak_cslot: &WeakCSlot, ws: Arc>, ip: Arc, pk: PubKey, + uniq: u64, + index: usize, ) { - while let Ok(mut payload) = ws.recv().await { - if !ip_rate.is_ok(*ip, payload.len()) { - ip_deny.block(*ip).await.unwrap(); - break; + if tokio::time::timeout(std::time::Duration::from_secs(10), async { + use rand::Rng; + let mut nonce = [0xdb; 32]; + rand::thread_rng().fill(&mut nonce[..]); + + ws.send(cmd::SbdCmd::auth_req(&nonce)).await?; + + let auth_res = ws.recv().await?; + + if !ip_rate.is_ok(&ip, auth_res.as_ref().len()).await { + return Err(Error::other("ip rate limited")); } - if payload.len() < 32 { - break; + if let cmd::SbdCmd::AuthRes(sig) = cmd::SbdCmd::parse(auth_res)? { + if !pk.verify(&sig, &nonce) { + return Err(Error::other("invalid sig")); + } + } else { + return Err(Error::other("invalid auth response")); } - const KEEPALIVE: &[u8; 32] = &[0; 32]; + // NOTE: the byte_nanos limit is sent during the cslot insert - let dest = { - let payload = payload.to_mut(); + // TODO: ws.send(cmd::SbdCmd::limit_idle_millis(config.?).await?; - if &payload[..32] == KEEPALIVE { - // TODO - keepalive - continue; - } + if let Some(cslot) = weak_cslot.upgrade() { + cslot.mark_ready(uniq, index); + } else { + return Err(Error::other("closed")); + } - if &payload[..32] == &pk.0[..] { - // no self-sends - break; - } + ws.send(cmd::SbdCmd::ready()).await?; - let mut dest = [0; 32]; - dest.copy_from_slice(&payload[..32]); - let dest = PubKey(Arc::new(dest)); + Ok(()) + }) + .await + .is_err() + { + return; + } - payload[..32].copy_from_slice(&pk.0[..]); + while let Ok(payload) = ws.recv().await { + if !ip_rate.is_ok(&ip, payload.len()).await { + break; + } - dest + let cmd = match cmd::SbdCmd::parse(payload) { + Err(_) => break, + Ok(cmd) => cmd, }; - if let Some(cslot) = weak.upgrade() { - let _ = cslot.send(&dest, payload).await; - } else { - break; + match cmd { + cmd::SbdCmd::Keepalive => (), + cmd::SbdCmd::AuthRes(_) => break, + cmd::SbdCmd::Unknown => (), + cmd::SbdCmd::Message(mut payload) => { + let dest = { + let payload = payload.to_mut(); + + let mut dest = [0; 32]; + dest.copy_from_slice(&payload[..32]); + let dest = PubKey(Arc::new(dest)); + + payload[..32].copy_from_slice(&pk.0[..]); + + dest + }; + + if let Some(cslot) = weak_cslot.upgrade() { + let _ = cslot.send(&dest, payload).await; + } else { + break; + } + } } } } diff --git a/rust/sbd-server/src/ip_deny.rs b/rust/sbd-server/src/ip_deny.rs index 8ff17e1..7763216 100644 --- a/rust/sbd-server/src/ip_deny.rs +++ b/rust/sbd-server/src/ip_deny.rs @@ -11,14 +11,13 @@ impl IpDeny { } /// Check if a given ip is blocked. - pub async fn is_blocked(&self, _ip: std::net::Ipv6Addr) -> Result { + pub async fn is_blocked(&self, _ip: &Arc) -> bool { // THIS IS A STUB!! - Ok(false) + false } /// Block a given ip. - pub async fn block(&self, _ip: std::net::Ipv6Addr) -> Result<()> { + pub async fn block(&self, _ip: &Arc) { // THIS IS A STUB!! - Ok(()) } } diff --git a/rust/sbd-server/src/ip_rate.rs b/rust/sbd-server/src/ip_rate.rs index 0af3ccf..df7d851 100644 --- a/rust/sbd-server/src/ip_rate.rs +++ b/rust/sbd-server/src/ip_rate.rs @@ -1,23 +1,26 @@ use std::collections::HashMap; use std::sync::{Arc, Mutex}; -type Map = HashMap; +type Map = HashMap, u64>; -#[derive(Clone)] pub struct IpRate { origin: tokio::time::Instant, map: Arc>, limit: u64, burst: u64, + ip_deny: crate::ip_deny::IpDeny, } impl IpRate { - pub fn new(limit: u64, burst: u64) -> Self { + /// Construct a new IpRate limit instance. + pub fn new(config: Arc) -> Self { Self { origin: tokio::time::Instant::now(), map: Arc::new(Mutex::new(HashMap::new())), - limit, - burst, + limit: config.limit_ip_byte_nanos as u64, + burst: config.limit_ip_byte_burst as u64 + * config.limit_ip_byte_nanos as u64, + ip_deny: crate::ip_deny::IpDeny::new(config), } } @@ -41,29 +44,46 @@ impl IpRate { }); } + /// Return true if this ip is blocked. + pub async fn is_blocked(&self, ip: &Arc) -> bool { + self.ip_deny.is_blocked(ip).await + } + /// Return true if we are not over the rate limit. - pub fn is_ok(&self, ip: std::net::Ipv6Addr, bytes: usize) -> bool { + pub async fn is_ok( + &self, + ip: &Arc, + bytes: usize, + ) -> bool { // multiply by our rate allowed per byte let rate_add = bytes as u64 * self.limit; // get now let now = self.origin.elapsed().as_nanos() as u64; - // lock the map mutex - let mut lock = self.map.lock().unwrap(); + let is_ok = { + // lock the map mutex + let mut lock = self.map.lock().unwrap(); - // get the entry (default to now) - let e = lock.entry(ip).or_insert(now); + // get the entry (default to now) + let e = lock.entry(ip.clone()).or_insert(now); - // if we've already used time greater than now use that, - // otherwise consider we're starting from scratch - let cur = std::cmp::max(*e, now) + rate_add; + // if we've already used time greater than now use that, + // otherwise consider we're starting from scratch + let cur = std::cmp::max(*e, now) + rate_add; - // update the map with the current limit - *e = cur; + // update the map with the current limit + *e = cur; - // subtract now back out to see if we're greater than our burst - cur - now <= self.burst + // subtract now back out to see if we're greater than our burst + cur - now <= self.burst + }; + + if !is_ok { + self.ip_deny.block(ip).await; + } + + is_ok } } @@ -71,21 +91,32 @@ impl IpRate { mod tests { use super::*; - const ADDR1: std::net::Ipv6Addr = - std::net::Ipv6Addr::new(1, 1, 1, 1, 1, 1, 1, 1); + fn test_new(limit: u64, burst: u64) -> IpRate { + IpRate { + origin: tokio::time::Instant::now(), + map: Arc::new(Mutex::new(HashMap::new())), + limit, + burst, + ip_deny: crate::ip_deny::IpDeny::new(Arc::new( + crate::Config::default(), + )), + } + } #[tokio::test(flavor = "current_thread", start_paused = true)] async fn check_one_to_one() { - let rate = IpRate::new(1, 1); + let addr1 = Arc::new(std::net::Ipv6Addr::new(1, 1, 1, 1, 1, 1, 1, 1)); + + let rate = test_new(1, 1); for _ in 0..10 { // should always be ok when advancing with time tokio::time::advance(std::time::Duration::from_nanos(1)).await; - assert!(rate.is_ok(ADDR1, 1)); + assert!(rate.is_ok(&addr1, 1).await); } // but one more without a time advance fails - assert!(!rate.is_ok(ADDR1, 1)); + assert!(!rate.is_ok(&addr1, 1).await); tokio::time::advance(std::time::Duration::from_nanos(1)).await; @@ -107,16 +138,18 @@ mod tests { #[tokio::test(flavor = "current_thread", start_paused = true)] async fn check_burst() { - let rate = IpRate::new(1, 5); + let addr1 = Arc::new(std::net::Ipv6Addr::new(1, 1, 1, 1, 1, 1, 1, 1)); + + let rate = test_new(1, 5); for _ in 0..5 { - assert!(rate.is_ok(ADDR1, 1)); + assert!(rate.is_ok(&addr1, 1).await); } - assert!(!rate.is_ok(ADDR1, 1)); + assert!(!rate.is_ok(&addr1, 1).await); tokio::time::advance(std::time::Duration::from_nanos(2)).await; - assert!(rate.is_ok(ADDR1, 1)); + assert!(rate.is_ok(&addr1, 1).await); tokio::time::advance(std::time::Duration::from_secs(10)).await; tokio::time::advance(std::time::Duration::from_nanos(4)).await; @@ -132,16 +165,18 @@ mod tests { #[tokio::test(flavor = "current_thread", start_paused = true)] async fn check_limit_mult() { - let rate = IpRate::new(3, 13); + let addr1 = Arc::new(std::net::Ipv6Addr::new(1, 1, 1, 1, 1, 1, 1, 1)); + + let rate = test_new(3, 13); - assert!(rate.is_ok(ADDR1, 2)); - assert!(rate.is_ok(ADDR1, 2)); - assert!(!rate.is_ok(ADDR1, 2)); + assert!(rate.is_ok(&addr1, 2).await); + assert!(rate.is_ok(&addr1, 2).await); + assert!(!rate.is_ok(&addr1, 2).await); tokio::time::advance(std::time::Duration::from_secs(10)).await; - assert!(rate.is_ok(ADDR1, 2)); - assert!(rate.is_ok(ADDR1, 2)); - assert!(!rate.is_ok(ADDR1, 2)); + assert!(rate.is_ok(&addr1, 2).await); + assert!(rate.is_ok(&addr1, 2).await); + assert!(!rate.is_ok(&addr1, 2).await); } } diff --git a/rust/sbd-server/src/lib.rs b/rust/sbd-server/src/lib.rs index d8cddc0..039564d 100644 --- a/rust/sbd-server/src/lib.rs +++ b/rust/sbd-server/src/lib.rs @@ -5,7 +5,7 @@ const MAX_MSG_BYTES: i32 = 16000; use std::io::{Error, Result}; -use std::sync::{Arc, Mutex, Weak}; +use std::sync::Arc; mod config; pub use config::*; @@ -18,6 +18,8 @@ mod ip_rate; mod cslot; +mod cmd; + /// Websocket backend abstraction. pub mod ws { /// Payload. @@ -98,74 +100,11 @@ impl PubKey { } } -enum ClientInfo { - Local { - ws: Arc>, - ip: std::net::Ipv6Addr, - read_task: tokio::task::JoinHandle<()>, - }, // TODO - remote (back channel) clients -} - -impl Drop for ClientInfo { - fn drop(&mut self) { - match self { - Self::Local { read_task, .. } => { - read_task.abort(); - } - } - } -} - -struct ClientMap(std::collections::HashMap); - -impl Default for ClientMap { - fn default() -> Self { - Self(std::collections::HashMap::new()) - } -} - -impl ClientMap { - pub fn insert(&mut self, pub_key: PubKey, client_info: ClientInfo) { - self.0.insert(pub_key, client_info); - } - - pub fn remove_ws( - &mut self, - pub_key: &PubKey, - subj_ws: &Arc>, - ) { - let should_drop = - if let Some(ClientInfo::Local { ws, .. }) = self.0.get(pub_key) { - if Arc::ptr_eq(subj_ws, ws) { - true - } else { - false - } - } else { - false - }; - - if should_drop { - self.0.remove(pub_key); - } - } - - pub fn get_ws( - &mut self, - pub_key: &PubKey, - ) -> Option>> { - match self.0.get(pub_key) { - Some(ClientInfo::Local { ws, .. }) => Some(ws.clone()), - _ => None, - } - } -} - /// SbdServer. pub struct SbdServer { task_list: Vec>, bind_addrs: Vec, - _client_map: Arc>, + _cslot: cslot::CSlot, } impl Drop for SbdServer { @@ -177,194 +116,70 @@ impl Drop for SbdServer { } async fn check_accept_connection( + _connect_permit: tokio::sync::OwnedSemaphorePermit, config: Arc, - ip_deny: Arc, ip_rate: Arc, - tcp: MaybeTlsStream, + tcp: tokio::net::TcpStream, addr: std::net::SocketAddr, - weak_client_map: Weak>, + weak_cslot: cslot::WeakCSlot, ) { - let raw_ip = match addr.ip() { + let raw_ip = Arc::new(match addr.ip() { std::net::IpAddr::V4(ip) => ip.to_ipv6_mapped(), std::net::IpAddr::V6(ip) => ip, - }; - drop(addr); + }); - let mut calc_ip = raw_ip; + let mut calc_ip = raw_ip.clone(); let use_trusted_ip = config.trusted_ip_header.is_some(); - let (pub_key, client_info) = - match tokio::time::timeout(std::time::Duration::from_secs(10), async { - const PROTO_VER: &[u8; 4] = b"sbd0"; - let limit_rate = config.limit_ip_byte_nanos.to_be_bytes(); - - if !use_trusted_ip { - // Do this check BEFORE handshake to avoid extra - // server process when capable. - // If we *are* behind a reverse proxy, we assume - // some amount of DDoS mitigation is happening there - // and thus we can accept a little more process overhead - if ip_deny.is_blocked(raw_ip).await.unwrap() { - return Err(Error::other("ip blocked")); - } - - // Also precheck our rate limit, using up one byte - if !ip_rate.is_ok(raw_ip, 1) { - ip_deny.block(raw_ip).await.unwrap(); - return Err(Error::other("ip rate limited")); - } - } - - // TODO TLS upgrade - - let (ws, pub_key, ip) = - ws::WebSocket::upgrade(config, tcp).await.unwrap(); - - if let Some(ip) = ip { - calc_ip = ip; + let _ = tokio::time::timeout(std::time::Duration::from_secs(10), async { + if !use_trusted_ip { + // Do this check BEFORE handshake to avoid extra + // server process when capable. + // If we *are* behind a reverse proxy, we assume + // some amount of DDoS mitigation is happening there + // and thus we can accept a little more process overhead + if ip_rate.is_blocked(&raw_ip).await { + return; } - if use_trusted_ip { - // if using a trusted ip, check block here. - // see note above before the handshakes. - if ip_deny.is_blocked(calc_ip).await.unwrap() { - return Err(Error::other("ip blocked")); - } - - // Also precheck our rate limit, using up one byte - if !ip_rate.is_ok(calc_ip, 1) { - ip_deny.block(calc_ip).await.unwrap(); - return Err(Error::other("ip rate limited")); - } + // Also precheck our rate limit, using up one byte + if !ip_rate.is_ok(&raw_ip, 1).await { + return; } + } - use rand::Rng; - let mut nonce = [0xdb; 32]; - rand::thread_rng().fill(&mut nonce[..]); - - let mut msg = Vec::with_capacity(4 + 4 + 32); - msg.extend_from_slice(&PROTO_VER[..]); - msg.extend_from_slice(&limit_rate[..]); - msg.extend_from_slice(&nonce[..]); - - ws.send(Payload::Vec(msg)).await.unwrap(); - - let sig = ws.recv().await.unwrap(); - - // use up 64 bytes of rate - if !ip_rate.is_ok(calc_ip, 64) { - ip_deny.block(calc_ip).await.unwrap(); - return Err(Error::other("ip rate limited")); - } + // TODO TLS upgrade + let tcp = MaybeTlsStream::Tcp(tcp); - if sig.len() != 64 { - return Err(Error::other("invalid sig len")); - } - let mut sig_sized = [0; 64]; - sig_sized.copy_from_slice(sig.as_ref()); - if !pub_key.verify(&sig_sized, &nonce) { - return Err(Error::other("invalid sig")); - } + let (ws, pub_key, ip) = + ws::WebSocket::upgrade(config.clone(), tcp).await.unwrap(); + let ws = Arc::new(ws); - let ws = Arc::new(ws); + if let Some(ip) = ip { + calc_ip = Arc::new(ip); + } - struct DoDrop { - pub_key: PubKey, - ws: Weak>, - client_map: Weak>, + if use_trusted_ip { + // if using a trusted ip, check block here. + // see note above before the handshakes. + if ip_rate.is_blocked(&calc_ip).await { + return; } - impl Drop for DoDrop { - fn drop(&mut self) { - if let Some(client_map) = self.client_map.upgrade() { - if let Some(ws) = self.ws.upgrade() { - client_map - .lock() - .unwrap() - .remove_ws(&self.pub_key, &ws); - } - } - } + // Also precheck our rate limit, using up one byte + if !ip_rate.is_ok(&calc_ip, 1).await { + return; } + } - let do_drop = DoDrop { - pub_key: pub_key.clone(), - ws: Arc::downgrade(&ws), - client_map: weak_client_map.clone(), - }; - - let ws2 = ws.clone(); - let weak_client_map2 = weak_client_map.clone(); - let pub_key2 = pub_key.clone(); - let read_task = tokio::task::spawn(async move { - let _do_drop = do_drop; - - while let Ok(mut payload) = ws2.recv().await { - if !ip_rate.is_ok(calc_ip, payload.len()) { - ip_deny.block(calc_ip).await.unwrap(); - break; - } - - if payload.len() < 32 { - break; - } - - const KEEPALIVE: &[u8; 32] = &[0; 32]; - - let dest = { - let payload = payload.to_mut(); - - if &payload[..32] == KEEPALIVE { - // TODO - keepalive - continue; - } - - if &payload[..32] == &pub_key2.0[..] { - // no self-sends - break; - } - - let mut dest = [0; 32]; - dest.copy_from_slice(&payload[..32]); - let dest = PubKey(Arc::new(dest)); - - payload[..32].copy_from_slice(&pub_key2.0[..]); - - dest - }; - - if let Some(client_map) = weak_client_map2.upgrade() { - let ws = client_map.lock().unwrap().get_ws(&dest); - if let Some(ws) = ws { - if ws.send(payload).await.is_err() { - break; - } - } - } else { - break; - } - } - }); - - Ok(( - pub_key, - ClientInfo::Local { - ws, - ip: calc_ip, - read_task, - }, - )) - }) - .await - { - Ok(Ok(r)) => r, - _ => return, - }; - - if let Some(client_map) = weak_client_map.upgrade() { - client_map.lock().unwrap().insert(pub_key, client_info); - } + if let Some(cslot) = weak_cslot.upgrade() { + cslot + .insert(calc_ip, pub_key, ws, config.limit_ip_byte_nanos) + .await; + } + }) + .await; } impl SbdServer { @@ -372,39 +187,61 @@ impl SbdServer { pub async fn new(config: Arc) -> Result { let mut task_list = Vec::new(); let mut bind_addrs = Vec::new(); - let client_map = Arc::new(Mutex::new(ClientMap::default())); - let ip_deny = Arc::new(ip_deny::IpDeny::new(config.clone())); + let ip_rate = Arc::new(ip_rate::IpRate::new(config.clone())); + + { + let ip_rate = Arc::downgrade(&ip_rate); + task_list.push(tokio::task::spawn(async move { + loop { + tokio::time::sleep(std::time::Duration::from_secs(5)).await; + if let Some(ip_rate) = ip_rate.upgrade() { + ip_rate.prune(); + } else { + break; + } + } + })); + } + + let cslot = cslot::CSlot::new(config.clone(), ip_rate.clone()); - let ip_rate = Arc::new(ip_rate::IpRate::new( - config.limit_ip_byte_nanos as u64, - config.limit_ip_byte_nanos as u64 - * config.limit_ip_byte_burst as u64, - )); + // limit the number of connections that can be "connecting" at a time. + // MAYBE make this configurable. + // read this as a prioritization of existing connections over incoming + let connect_limit = Arc::new(tokio::sync::Semaphore::new(1024)); - let weak_client_map = Arc::downgrade(&client_map); + let weak_cslot = cslot.weak(); for bind in config.bind.iter() { let a: std::net::SocketAddr = bind.parse().map_err(Error::other)?; let tcp = tokio::net::TcpListener::bind(a).await?; bind_addrs.push(tcp.local_addr()?); + let connect_limit = connect_limit.clone(); let config = config.clone(); - let weak_client_map = weak_client_map.clone(); - let ip_deny = ip_deny.clone(); + let weak_cslot = weak_cslot.clone(); let ip_rate = ip_rate.clone(); task_list.push(tokio::task::spawn(async move { loop { if let Ok((tcp, addr)) = tcp.accept().await { - let tcp = MaybeTlsStream::Tcp(tcp); + // Drop connections as fast as possible + // if we are overloaded on accepting connections. + let connect_permit = + match connect_limit.clone().try_acquire_owned() { + Ok(permit) => permit, + _ => continue, + }; + // just let this task die on its own time + // MAYBE preallocate these tasks like cslot tokio::task::spawn(check_accept_connection( + connect_permit, config.clone(), - ip_deny.clone(), ip_rate.clone(), tcp, addr, - weak_client_map.clone(), + weak_cslot.clone(), )); } } @@ -414,7 +251,7 @@ impl SbdServer { Ok(Self { task_list, bind_addrs, - _client_map: client_map, + _cslot: cslot, }) } From a7e5a103b5ba51647cc63a7fab13bc79dad5cd08 Mon Sep 17 00:00:00 2001 From: neonphog Date: Thu, 18 Apr 2024 16:03:07 -0600 Subject: [PATCH 03/33] checkpoint --- Makefile | 16 +++++ rust/sbd-client/src/lib.rs | 56 ++++++++++----- rust/sbd-client/src/raw_client.rs | 16 +++-- rust/sbd-client/src/send_buf.rs | 40 +++++++++-- rust/sbd-server-test-suite/src/it/it_1.rs | 4 +- rust/sbd-server-test-suite/src/lib.rs | 22 +----- rust/sbd-server/src/config.rs | 21 +++++- rust/sbd-server/src/cslot.rs | 86 +++++++++++++---------- rust/sbd-server/src/lib.rs | 4 +- rust/sbd-server/src/ws/ws_tungstenite.rs | 6 +- 10 files changed, 177 insertions(+), 94 deletions(-) create mode 100644 Makefile diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..5836337 --- /dev/null +++ b/Makefile @@ -0,0 +1,16 @@ +# sdb Makefile + +.PHONY: all test static + +SHELL = /usr/bin/env sh -eu + +all: test + +test: static + cargo build --all-targets + RUST_BACKTRACE=1 cargo test + +static: + cargo fmt -- --check + cargo clippy -- -Dwarnings + @if [ "${CI}x" != "x" ]; then git diff --exit-code; fi diff --git a/rust/sbd-client/src/lib.rs b/rust/sbd-client/src/lib.rs index dc75a8a..c4788c7 100644 --- a/rust/sbd-client/src/lib.rs +++ b/rust/sbd-client/src/lib.rs @@ -227,6 +227,7 @@ impl SbdClient { .await?; let mut limit_byte_nanos = 8000; + let mut limit_idle_millis = 10_000; loop { match Msg(recv.recv().await?).parse()? { @@ -234,9 +235,9 @@ impl SbdClient { return Err(Error::other("invalid handshake")) } MsgType::LimitByteNanos(l) => limit_byte_nanos = l, - MsgType::LimitIdleMillis(_) => (), + MsgType::LimitIdleMillis(l) => limit_idle_millis = l, MsgType::AuthReq(nonce) => { - let sig = crypto.sign(&nonce); + let sig = crypto.sign(nonce); let mut auth_res = Vec::with_capacity(32 + 64); auth_res.extend_from_slice(CMD_FLAG); auth_res.extend_from_slice(b"ares"); @@ -248,8 +249,17 @@ impl SbdClient { } } - println!("limit_byte_nanos: {limit_byte_nanos}"); + let send_buf = send_buf::SendBuf { + ws: send, + buf: std::collections::VecDeque::new(), + out_buffer_size: config.out_buffer_size, + origin: tokio::time::Instant::now(), + limit_rate: (limit_byte_nanos as f64 * 0.9) as u64, + next_send_at: 0, + }; + let send_buf = Arc::new(tokio::sync::Mutex::new(send_buf)); + let send_buf2 = send_buf.clone(); let (recv_send, recv_recv) = tokio::sync::mpsc::channel(4); let read_task = tokio::task::spawn(async move { while let Ok(data) = recv.recv().await { @@ -265,9 +275,12 @@ impl SbdClient { } } MsgType::LimitByteNanos(rate) => { - eprintln!("UPDATED RATE {rate}"); + send_buf2 + .lock() + .await + .new_rate_limit((rate as f64 * 0.9) as u64); } - MsgType::LimitIdleMillis(_) => todo!(), + MsgType::LimitIdleMillis(_) => break, MsgType::AuthReq(_) => break, MsgType::Ready => (), MsgType::Unknown => (), @@ -277,26 +290,37 @@ impl SbdClient { // TODO - shutdown }); - let send_buf = send_buf::SendBuf { - ws: send, - buf: std::collections::VecDeque::new(), - out_buffer_size: config.out_buffer_size, - origin: tokio::time::Instant::now(), - limit_rate: (limit_byte_nanos as f64 * 0.9) as u64, - next_send_at: 0, - }; - let send_buf = Arc::new(tokio::sync::Mutex::new(send_buf)); - let send_buf2 = send_buf.clone(); let write_task = tokio::task::spawn(async move { + let mut last_send = tokio::time::Instant::now(); loop { if let Some(dur) = send_buf2.lock().await.next_step_dur() { tokio::time::sleep(dur).await; } match send_buf2.lock().await.write_next_queued().await { Err(_) => break, - Ok(true) => (), + Ok(true) => { + last_send = tokio::time::Instant::now(); + } Ok(false) => { + if last_send.elapsed().as_millis() as u64 + > limit_idle_millis as u64 / 2 + { + let mut data = Vec::with_capacity(32); + data.extend_from_slice(CMD_FLAG); + data.extend_from_slice(b"keep"); + if send_buf2 + .lock() + .await + .ws + .send(data) + .await + .is_err() + { + break; + } + last_send = tokio::time::Instant::now(); + } tokio::time::sleep(std::time::Duration::from_millis( 10, )) diff --git a/rust/sbd-client/src/raw_client.rs b/rust/sbd-client/src/raw_client.rs index 5e99e7c..dfcc5a8 100644 --- a/rust/sbd-client/src/raw_client.rs +++ b/rust/sbd-client/src/raw_client.rs @@ -15,6 +15,7 @@ pub struct WsRawConnect { /// Setting this to `true` allows `ws://` scheme. pub allow_plain_text: bool, + #[allow(unused_variables)] /// Setting this to `true` disables certificate verification on `wss://` /// scheme. WARNING: this is a dangerous configuration and should not /// be used outside of testing (i.e. self-signed tls certificates). @@ -28,7 +29,7 @@ impl WsRawConnect { full_url, max_message_size, allow_plain_text, - danger_disable_certificate_check, + .. } = self; let scheme_ws = full_url.starts_with("ws://"); @@ -48,7 +49,7 @@ impl WsRawConnect { Some(host) => host.to_string(), None => return Err(Error::other("invalid url")), }; - let port = request.uri().port_u16().unwrap_or_else(|| { + let port = request.uri().port_u16().unwrap_or({ if scheme_ws { 80 } else { @@ -73,10 +74,11 @@ impl WsRawConnect { tokio_tungstenite::MaybeTlsStream::Rustls(tls) }; - let mut config = - tokio_tungstenite::tungstenite::protocol::WebSocketConfig::default( - ); - config.max_message_size = Some(max_message_size); + let config = + tokio_tungstenite::tungstenite::protocol::WebSocketConfig { + max_message_size: Some(max_message_size), + ..Default::default() + }; let (ws, _res) = tokio_tungstenite::client_async_with_config( request, @@ -174,7 +176,7 @@ fn priv_system_tls() -> Arc { for cert in rustls_native_certs::load_native_certs() .expect("failed to load system tls certs") { - roots.add(cert.into()).expect("faild to add cert to root"); + roots.add(cert).expect("faild to add cert to root"); } Arc::new( diff --git a/rust/sbd-client/src/send_buf.rs b/rust/sbd-client/src/send_buf.rs index c38c6fe..5ee4842 100644 --- a/rust/sbd-client/src/send_buf.rs +++ b/rust/sbd-client/src/send_buf.rs @@ -12,6 +12,22 @@ pub struct SendBuf { } impl SendBuf { + /// We received a new rate limit from the server, update our records. + pub fn new_rate_limit(&mut self, limit: u64) { + if limit < self.limit_rate { + // rate limit updates are sent on a best effort, + // and there are network timing conditions to worry about. + // assume we accidentally sent a message while the new limit + // was in effect, and accout for that in a brute-force manner. + + let now = self.origin.elapsed().as_nanos() as u64; + + self.next_send_at = std::cmp::max(now, self.next_send_at) + + (MAX_MSG_SIZE as u64 * self.limit_rate); + } + self.limit_rate = limit; + } + /// If we need to wait before taking the next step, this /// returns how long. pub fn next_step_dur(&self) -> Option { @@ -36,10 +52,12 @@ impl SendBuf { if let Some((_, data)) = self.buf.pop_front() { let now = self.origin.elapsed().as_nanos() as u64; - let next_send_at = - self.next_send_at + (data.len() as u64 * self.limit_rate); - self.next_send_at = std::cmp::max(now, next_send_at); + + self.next_send_at = std::cmp::max(now, self.next_send_at) + + (data.len() as u64 * self.limit_rate); + self.ws.send(data).await?; + Ok(true) } else { Ok(false) @@ -48,7 +66,7 @@ impl SendBuf { /// If our buffer is over our buffer size, do the work to get it under. /// Then queue up data to be sent out. - /// Note, you'll need to explicitly call `process_next_step()` or + /// Note, you'll need to explicitly call `write_next_queued()` or /// make additional sends in order to get this queued data actually sent. pub async fn send(&mut self, pk: &PubKey, mut data: &[u8]) -> Result<()> { while !self.space_free() { @@ -58,6 +76,8 @@ impl SendBuf { self.write_next_queued().await?; } + self.check_set_prebuffer(); + // first try to put into existing blocks for (qpk, qdata) in self.buf.iter_mut() { if qpk == pk && qdata.len() < MAX_MSG_SIZE { @@ -78,7 +98,7 @@ impl SendBuf { let amt = std::cmp::min(data.len(), MAX_MSG_SIZE - init.len()); init.extend_from_slice(&data[..amt]); data = &data[amt..]; - self.buf.push_back((pk.clone(), init)); + self.buf.push_back((*pk, init)); } Ok(()) @@ -90,6 +110,16 @@ impl SendBuf { self.buf.iter().map(|(_, d)| d.len()).sum() } + /// If we have an empty out buffer, set some rate-limit as a hack + /// for waiting a little bit to see if more sends come in and can + /// be aggregated + fn check_set_prebuffer(&mut self) { + if self.buf.is_empty() { + let hack = self.origin.elapsed().as_nanos() as u64 + 10_000_000; // 10 millis in nanos + self.next_send_at = std::cmp::max(hack, self.next_send_at) + } + } + /// Returns `true` if our total buffer size < out_buffer_size fn space_free(&self) -> bool { self.len() < self.out_buffer_size diff --git a/rust/sbd-server-test-suite/src/it/it_1.rs b/rust/sbd-server-test-suite/src/it/it_1.rs index 9d06bbc..6aa57ab 100644 --- a/rust/sbd-server-test-suite/src/it/it_1.rs +++ b/rust/sbd-server-test-suite/src/it/it_1.rs @@ -20,9 +20,9 @@ impl It for It1 { async { r2.recv().await.ok_or(Error::other("closed")) }, )?; - expect!(helper, result1.pub_key_ref() == &p2.0, "r1 recv from p2"); + expect!(helper, result1.pub_key_ref() == p2.0, "r1 recv from p2"); expect!(helper, result1.message() == b"world", "r1 got 'world'"); - expect!(helper, result2.pub_key_ref() == &p1.0, "r2 recv from p1"); + expect!(helper, result2.pub_key_ref() == p1.0, "r2 recv from p1"); expect!(helper, result2.message() == b"hello", "r2 got 'hello'"); Ok(()) diff --git a/rust/sbd-server-test-suite/src/lib.rs b/rust/sbd-server-test-suite/src/lib.rs index 3af5bca..972871e 100644 --- a/rust/sbd-server-test-suite/src/lib.rs +++ b/rust/sbd-server-test-suite/src/lib.rs @@ -1,4 +1,7 @@ #![deny(missing_docs)] +// uhhh... clippy... +#![allow(clippy::manual_async_fn)] + //! Test suite for sbd server compliance. //! //! The command supplied to the run function must: @@ -31,26 +34,7 @@ pub async fn run>(cmd: S) -> Report { println!("GOT RUNNING ADDRS: {addrs:?}"); - let config = sbd_client::SbdClientConfig { - allow_plain_text: true, - ..Default::default() - }; - it::exec_all(&addrs).await - - /* - for addr in addrs { - if let Ok(client) = sbd_client::SbdClient::connect_config( - &format!("ws://{addr}"), - &sbd_client::DefaultCrypto::default(), - config, - ).await { - println!("client connect success"); - } - } - - todo!() - */ } struct Server { diff --git a/rust/sbd-server/src/config.rs b/rust/sbd-server/src/config.rs index 2f60eaa..3f8e672 100644 --- a/rust/sbd-server/src/config.rs +++ b/rust/sbd-server/src/config.rs @@ -2,7 +2,8 @@ const DEF_IP_DENY_DIR: &str = "."; const DEF_IP_DENY_S: i32 = 600; const DEF_LIMIT_CLIENTS: i32 = 32768; const DEF_LIMIT_IP_BYTE_NANOS: i32 = 8000; -const DEF_LIMIT_IP_BYTE_BURST: i32 = 32768; +const DEF_LIMIT_IP_BYTE_BURST: i32 = 16 * 16 * 1024; +const DEF_LIMIT_IDLE_MILLIS: i32 = 10_000; /// Configure and execute an SBD server. #[derive(clap::Parser, Debug)] @@ -87,14 +88,23 @@ pub struct Config { /// The default value of 8000 results in ~1 mbps being allowed. /// If the default of 32768 connections were all sending this amount /// at the same time, the server would need a ~33.6 gbps connection. + /// This value divided by the count of connections from an ip will + /// be sent down to the client for individual rate limit. #[arg(long, default_value_t = DEF_LIMIT_IP_BYTE_NANOS)] pub limit_ip_byte_nanos: i32, /// Allow IPs to burst by this byte count. /// If the max message size is 16K, this value must be at least 16K. - /// The default value provides 2 * 16K for an additional buffer. + /// The default value provides 16 * 16K to allow for multiple connections + /// from a single ip address. #[arg(long, default_value_t = DEF_LIMIT_IP_BYTE_BURST)] pub limit_ip_byte_burst: i32, + + /// How long in milliseconds connections can remain idle before being + /// closed. Clients must send either a message or a keepalive before + /// this time expires to keep the connection alive. + #[arg(long, default_value_t = DEF_LIMIT_IDLE_MILLIS)] + pub limit_idle_millis: i32, } impl Default for Config { @@ -115,10 +125,17 @@ impl Default for Config { limit_clients: DEF_LIMIT_CLIENTS, limit_ip_byte_nanos: DEF_LIMIT_IP_BYTE_NANOS, limit_ip_byte_burst: DEF_LIMIT_IP_BYTE_BURST, + limit_idle_millis: DEF_LIMIT_IDLE_MILLIS, } } } +impl Config { + pub(crate) fn idle_dur(&self) -> std::time::Duration { + std::time::Duration::from_millis(self.limit_idle_millis as u64) + } +} + fn get_styles() -> clap::builder::Styles { clap::builder::Styles::styled() .usage( diff --git a/rust/sbd-server/src/cslot.rs b/rust/sbd-server/src/cslot.rs index 51af430..bde6de2 100644 --- a/rust/sbd-server/src/cslot.rs +++ b/rust/sbd-server/src/cslot.rs @@ -112,6 +112,8 @@ impl CSlot { }); } + // oi clippy, this is super straight forward... + #[allow(clippy::type_complexity)] fn insert_and_get_rate_send_list( &self, ip: Arc, @@ -222,7 +224,7 @@ impl CSlot { ) -> Result<(u64, usize, Arc>)> { let lock = self.0.lock().unwrap(); - let index = match lock.pk_to_index.get(&pk) { + let index = match lock.pk_to_index.get(pk) { None => return Err(Error::other("no such peer")), Some(index) => *index, }; @@ -261,42 +263,45 @@ async fn top_task( weak: WeakCSlot, mut recv: tokio::sync::mpsc::UnboundedReceiver, ) { - while let Some(task_msg) = recv.recv().await { - match task_msg { - TaskMsg::NewWs { - uniq, - index, - ws, - ip, - pk, - } => { - tokio::select! { - task_msg = recv.recv() => { - match task_msg { - Some(TaskMsg::Close) => (), - _ => break, - } - }, - _ = ws_task( - &config, - &ip_rate, - &weak, - ws, - ip, - pk, - uniq, - index, - ) => (), - } - if let Some(cslot) = weak.upgrade() { - cslot.remove(uniq, index); - } + let mut item = recv.recv().await; + loop { + let uitem = match item { + None => break, + Some(uitem) => uitem, + }; + + item = if let TaskMsg::NewWs { + uniq, + index, + ws, + ip, + pk, + } = uitem + { + let i = tokio::select! { + i = recv.recv() => i, + _ = ws_task( + &config, + &ip_rate, + &weak, + ws, + ip, + pk, + uniq, + index, + ) => recv.recv().await, + }; + if let Some(cslot) = weak.upgrade() { + cslot.remove(uniq, index); } - _ => (), - } + i + } else { + recv.recv().await + }; } } +#[allow(clippy::too_many_arguments)] async fn ws_task( config: &Arc, ip_rate: &ip_rate::IpRate, @@ -307,7 +312,7 @@ async fn ws_task( uniq: u64, index: usize, ) { - if tokio::time::timeout(std::time::Duration::from_secs(10), async { + let auth_res = tokio::time::timeout(config.idle_dur(), async { use rand::Rng; let mut nonce = [0xdb; 32]; rand::thread_rng().fill(&mut nonce[..]); @@ -330,7 +335,8 @@ async fn ws_task( // NOTE: the byte_nanos limit is sent during the cslot insert - // TODO: ws.send(cmd::SbdCmd::limit_idle_millis(config.?).await?; + ws.send(cmd::SbdCmd::limit_idle_millis(config.limit_idle_millis)) + .await?; if let Some(cslot) = weak_cslot.upgrade() { cslot.mark_ready(uniq, index); @@ -342,13 +348,15 @@ async fn ws_task( Ok(()) }) - .await - .is_err() - { + .await; + + if auth_res.is_err() { return; } - while let Ok(payload) = ws.recv().await { + while let Ok(Ok(payload)) = + tokio::time::timeout(config.idle_dur(), ws.recv()).await + { if !ip_rate.is_ok(&ip, payload.len()).await { break; } diff --git a/rust/sbd-server/src/lib.rs b/rust/sbd-server/src/lib.rs index 039564d..e74be65 100644 --- a/rust/sbd-server/src/lib.rs +++ b/rust/sbd-server/src/lib.rs @@ -91,7 +91,7 @@ impl PubKey { /// Verify a signature with this pub key. pub fn verify(&self, sig: &[u8; 64], data: &[u8]) -> bool { use ed25519_dalek::Verifier; - if let Ok(k) = ed25519_dalek::VerifyingKey::from_bytes(&*self.0) { + if let Ok(k) = ed25519_dalek::VerifyingKey::from_bytes(&self.0) { k.verify(data, &ed25519_dalek::Signature::from_bytes(sig)) .is_ok() } else { @@ -132,7 +132,7 @@ async fn check_accept_connection( let use_trusted_ip = config.trusted_ip_header.is_some(); - let _ = tokio::time::timeout(std::time::Duration::from_secs(10), async { + let _ = tokio::time::timeout(config.idle_dur(), async { if !use_trusted_ip { // Do this check BEFORE handshake to avoid extra // server process when capable. diff --git a/rust/sbd-server/src/ws/ws_tungstenite.rs b/rust/sbd-server/src/ws/ws_tungstenite.rs index 5a47b29..897606a 100644 --- a/rust/sbd-server/src/ws/ws_tungstenite.rs +++ b/rust/sbd-server/src/ws/ws_tungstenite.rs @@ -32,8 +32,10 @@ where handshake::server, protocol::WebSocketConfig, }; let mut trusted_ip = None; - let mut ws = WebSocketConfig::default(); - ws.max_message_size = Some(MAX_MSG_BYTES as usize); + let ws = WebSocketConfig { + max_message_size: Some(MAX_MSG_BYTES as usize), + ..Default::default() + }; struct Cb(tokio::sync::oneshot::Sender); impl server::Callback for Cb { fn on_request( From 2946ab29514dd5f903f9173dd63a92a59b2cf703 Mon Sep 17 00:00:00 2001 From: neonphog Date: Thu, 18 Apr 2024 16:07:25 -0600 Subject: [PATCH 04/33] ci --- .github/workflows/static.yml | 35 ++++++++++++++++++++++++++++++ .github/workflows/test.yml | 42 ++++++++++++++++++++++++++++++++++++ Makefile | 2 +- 3 files changed, 78 insertions(+), 1 deletion(-) create mode 100644 .github/workflows/static.yml create mode 100644 .github/workflows/test.yml diff --git a/.github/workflows/static.yml b/.github/workflows/static.yml new file mode 100644 index 0000000..99918f7 --- /dev/null +++ b/.github/workflows/static.yml @@ -0,0 +1,35 @@ +name: Static Analysis +on: + push: + branches: + - main + pull_request: + branches: + - main +jobs: + static-analysis: + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ + ubuntu-latest, + ] + toolchain: [ + stable, + 1.75.0 + ] + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Rust Toolchain + run: | + rustup toolchain install ${{ matrix.toolchain }} --profile minimal --no-self-update + rustup default ${{ matrix.toolchain }} + + - name: Rust Cache + uses: Swatinem/rust-cache@v2 + + - name: Make Static + run: make static diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000..739eeb2 --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,42 @@ +name: Test +on: + push: + branches: + - main + pull_request: + branches: + - main +jobs: + test: + name: Test + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ + ubuntu-latest, + macos-latest, + windows-latest, + ] + toolchain: [ + stable, + ] + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Rust Toolchain + run: | + rustup toolchain install ${{ matrix.toolchain }} --profile minimal --no-self-update + rustup default ${{ matrix.toolchain }} + + - name: Rust Cache + uses: Swatinem/rust-cache@v2 + + - name: Cargo Build + run: cargo build --all-targets + + - name: Cargo Test + env: + RUST_BACKTRACE: 1 + run: cargo test -- --nocapture diff --git a/Makefile b/Makefile index 5836337..99c6362 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -# sdb Makefile +# sbd Makefile .PHONY: all test static From 1ff92810fbc02e7d3fdd44dfed4b6896ea67e5c2 Mon Sep 17 00:00:00 2001 From: neonphog Date: Thu, 18 Apr 2024 16:10:47 -0600 Subject: [PATCH 05/33] ci --- .github/workflows/static.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/static.yml b/.github/workflows/static.yml index 99918f7..7c1c69b 100644 --- a/.github/workflows/static.yml +++ b/.github/workflows/static.yml @@ -27,6 +27,8 @@ jobs: run: | rustup toolchain install ${{ matrix.toolchain }} --profile minimal --no-self-update rustup default ${{ matrix.toolchain }} + rustup component add rustfmt --toolchain ${{ matrix.toolchain }} + rustup component add clippy --toolchain ${{ matrix.toolchain }} - name: Rust Cache uses: Swatinem/rust-cache@v2 From 356655c6396166b14646317d3b0bd07ff9ec4da8 Mon Sep 17 00:00:00 2001 From: neonphog Date: Thu, 18 Apr 2024 16:24:52 -0600 Subject: [PATCH 06/33] rename --- Cargo.lock | 19 +++++++++---------- Cargo.toml | 5 ++++- .../Cargo.toml | 4 ++-- .../src/bin/sbd-o-bahn-server-tester-bin.rs} | 2 +- .../src/it.rs | 0 .../src/it/it_1.rs | 0 .../src/lib.rs | 0 rust/sbd-server/Cargo.toml | 1 - rust/sbd-server/tests/suite.rs | 4 ++-- 9 files changed, 18 insertions(+), 17 deletions(-) rename rust/{sbd-server-test-suite => sbd-o-bahn-server-tester}/Cargo.toml (60%) rename rust/{sbd-server-test-suite/src/bin/sbd-server-test-suite-bin.rs => sbd-o-bahn-server-tester/src/bin/sbd-o-bahn-server-tester-bin.rs} (86%) rename rust/{sbd-server-test-suite => sbd-o-bahn-server-tester}/src/it.rs (100%) rename rust/{sbd-server-test-suite => sbd-o-bahn-server-tester}/src/it/it_1.rs (100%) rename rust/{sbd-server-test-suite => sbd-o-bahn-server-tester}/src/lib.rs (100%) diff --git a/Cargo.lock b/Cargo.lock index fe9c30f..42cc3d0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -999,6 +999,15 @@ dependencies = [ "webpki-roots", ] +[[package]] +name = "sbd-o-bahn-server-tester" +version = "0.0.1-alpha" +dependencies = [ + "sbd-client", + "serde_json", + "tokio", +] + [[package]] name = "sbd-server" version = "0.0.1-alpha" @@ -1016,7 +1025,6 @@ dependencies = [ "hyper-util", "rand", "rcgen", - "sbd-client", "serde_json", "slab", "tempfile", @@ -1024,15 +1032,6 @@ dependencies = [ "tokio-tungstenite", ] -[[package]] -name = "sbd-server-test-suite" -version = "0.0.1-alpha" -dependencies = [ - "sbd-client", - "serde_json", - "tokio", -] - [[package]] name = "schannel" version = "0.1.23" diff --git a/Cargo.toml b/Cargo.toml index aba380d..6f8a30d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,8 +2,11 @@ members = [ "rust/sbd-client", "rust/sbd-server", - "rust/sbd-server-test-suite", + "rust/sbd-o-bahn-server-tester", ] resolver = "2" [workspace.dependencies] +sbd-client = { version = "0.0.1-alpha", path = "rust/sbd-client" } +sbd-server = { version = "0.0.1-alpha", path = "rust/sbd-server" } +sbd-o-bahn-server-tester = { version = "0.0.1-alpha", path = "rust/sbd-o-bahn-server-tester" } diff --git a/rust/sbd-server-test-suite/Cargo.toml b/rust/sbd-o-bahn-server-tester/Cargo.toml similarity index 60% rename from rust/sbd-server-test-suite/Cargo.toml rename to rust/sbd-o-bahn-server-tester/Cargo.toml index 84e552a..5ac40a0 100644 --- a/rust/sbd-server-test-suite/Cargo.toml +++ b/rust/sbd-o-bahn-server-tester/Cargo.toml @@ -1,9 +1,9 @@ [package] -name = "sbd-server-test-suite" +name = "sbd-o-bahn-server-tester" version = "0.0.1-alpha" edition = "2021" [dependencies] -sbd-client = { version = "0.0.1-alpha", path = "../sbd-client" } +sbd-client = { workspace = true } serde_json = "1.0.116" tokio = { version = "1.37.0", features = [ "full" ] } diff --git a/rust/sbd-server-test-suite/src/bin/sbd-server-test-suite-bin.rs b/rust/sbd-o-bahn-server-tester/src/bin/sbd-o-bahn-server-tester-bin.rs similarity index 86% rename from rust/sbd-server-test-suite/src/bin/sbd-server-test-suite-bin.rs rename to rust/sbd-o-bahn-server-tester/src/bin/sbd-o-bahn-server-tester-bin.rs index a42beeb..d0264ff 100644 --- a/rust/sbd-server-test-suite/src/bin/sbd-server-test-suite-bin.rs +++ b/rust/sbd-o-bahn-server-tester/src/bin/sbd-o-bahn-server-tester-bin.rs @@ -2,7 +2,7 @@ async fn main() { let mut args = std::env::args_os(); args.next().unwrap(); - let result = sbd_server_test_suite::run( + let result = sbd_o_bahn_server_tester::run( args.next().expect("Expected Sbd Server Suite Runner"), ) .await; diff --git a/rust/sbd-server-test-suite/src/it.rs b/rust/sbd-o-bahn-server-tester/src/it.rs similarity index 100% rename from rust/sbd-server-test-suite/src/it.rs rename to rust/sbd-o-bahn-server-tester/src/it.rs diff --git a/rust/sbd-server-test-suite/src/it/it_1.rs b/rust/sbd-o-bahn-server-tester/src/it/it_1.rs similarity index 100% rename from rust/sbd-server-test-suite/src/it/it_1.rs rename to rust/sbd-o-bahn-server-tester/src/it/it_1.rs diff --git a/rust/sbd-server-test-suite/src/lib.rs b/rust/sbd-o-bahn-server-tester/src/lib.rs similarity index 100% rename from rust/sbd-server-test-suite/src/lib.rs rename to rust/sbd-o-bahn-server-tester/src/lib.rs diff --git a/rust/sbd-server/Cargo.toml b/rust/sbd-server/Cargo.toml index a612c75..81350e9 100644 --- a/rust/sbd-server/Cargo.toml +++ b/rust/sbd-server/Cargo.toml @@ -26,7 +26,6 @@ hyper = { version = "1.2.0", features = ["http1", "server"], optional = true } [dev-dependencies] escargot = { version = "0.5.10", features = [ "print" ] } rcgen = "0.13.1" -sbd-client = { version = "0.0.1-alpha", path = "../sbd-client" } serde_json = "1.0.116" tempfile = "3.10.1" tokio = { version = "1.37.0", features = [ "test-util" ] } diff --git a/rust/sbd-server/tests/suite.rs b/rust/sbd-server/tests/suite.rs index ee77980..a65706c 100644 --- a/rust/sbd-server/tests/suite.rs +++ b/rust/sbd-server/tests/suite.rs @@ -10,8 +10,8 @@ fn suite() { println!("BUILDING sbd-server-test-suite IN RELEASE MODE"); let suite = escargot::CargoBuild::new() - .bin("sbd-server-test-suite-bin") - .manifest_path("../sbd-server-test-suite/Cargo.toml") + .bin("sbd-o-bahn-server-tester-bin") + .manifest_path("../sbd-o-bahn-server-tester/Cargo.toml") .release() .current_target() .run() From 4cfd9cc7190e631f835d02726ddc83c054f73eef Mon Sep 17 00:00:00 2001 From: neonphog Date: Fri, 19 Apr 2024 13:25:40 -0600 Subject: [PATCH 07/33] ratelimit --- Cargo.lock | 2 + rust/sbd-client/Cargo.toml | 1 + rust/sbd-client/src/lib.rs | 70 +++------- rust/sbd-client/src/raw_client.rs | 55 +++++++- rust/sbd-client/src/send_buf.rs | 71 ++++++++-- .../sbd-client/tests/reasonable-rate-limit.rs | 109 +++++++++++++++ rust/sbd-server/Cargo.toml | 1 + rust/sbd-server/src/config.rs | 26 ++-- rust/sbd-server/src/cslot.rs | 27 ++-- rust/sbd-server/src/ip_rate.rs | 4 +- rust/sbd-server/src/lib.rs | 8 +- rust/sbd-server/src/ws/ws_tungstenite.rs | 10 +- rust/sbd-server/tests/rate_limit_enforced.rs | 125 ++++++++++++++++++ 13 files changed, 420 insertions(+), 89 deletions(-) create mode 100644 rust/sbd-client/tests/reasonable-rate-limit.rs create mode 100644 rust/sbd-server/tests/rate_limit_enforced.rs diff --git a/Cargo.lock b/Cargo.lock index 42cc3d0..3a59285 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -993,6 +993,7 @@ dependencies = [ "rand", "rustls", "rustls-native-certs", + "sbd-server", "tokio", "tokio-rustls", "tokio-tungstenite", @@ -1025,6 +1026,7 @@ dependencies = [ "hyper-util", "rand", "rcgen", + "sbd-client", "serde_json", "slab", "tempfile", diff --git a/rust/sbd-client/Cargo.toml b/rust/sbd-client/Cargo.toml index c9c18b3..f086af7 100644 --- a/rust/sbd-client/Cargo.toml +++ b/rust/sbd-client/Cargo.toml @@ -22,6 +22,7 @@ webpki-roots = "0.26.1" [dev-dependencies] tokio = { version = "1.37.0", features = [ "full" ] } +sbd-server = { workspace = true } [features] default = [ "crypto" ] diff --git a/rust/sbd-client/src/lib.rs b/rust/sbd-client/src/lib.rs index c4788c7..d8c60af 100644 --- a/rust/sbd-client/src/lib.rs +++ b/rust/sbd-client/src/lib.rs @@ -226,37 +226,20 @@ impl SbdClient { .connect() .await?; - let mut limit_byte_nanos = 8000; - let mut limit_idle_millis = 10_000; - - loop { - match Msg(recv.recv().await?).parse()? { - MsgType::Msg { .. } => { - return Err(Error::other("invalid handshake")) - } - MsgType::LimitByteNanos(l) => limit_byte_nanos = l, - MsgType::LimitIdleMillis(l) => limit_idle_millis = l, - MsgType::AuthReq(nonce) => { - let sig = crypto.sign(nonce); - let mut auth_res = Vec::with_capacity(32 + 64); - auth_res.extend_from_slice(CMD_FLAG); - auth_res.extend_from_slice(b"ares"); - auth_res.extend_from_slice(&sig); - send.send(auth_res).await?; - } - MsgType::Ready => break, - MsgType::Unknown => (), - } - } - - let send_buf = send_buf::SendBuf { - ws: send, - buf: std::collections::VecDeque::new(), - out_buffer_size: config.out_buffer_size, - origin: tokio::time::Instant::now(), - limit_rate: (limit_byte_nanos as f64 * 0.9) as u64, - next_send_at: 0, - }; + let raw_client::Handshake { + limit_byte_nanos, + limit_idle_millis, + bytes_sent, + } = raw_client::Handshake::handshake(&mut send, &mut recv, crypto) + .await?; + + let send_buf = send_buf::SendBuf::new( + send, + config.out_buffer_size, + (limit_byte_nanos as f64 * 1.1) as u64, + std::time::Duration::from_millis((limit_idle_millis / 2) as u64), + bytes_sent, + ); let send_buf = Arc::new(tokio::sync::Mutex::new(send_buf)); let send_buf2 = send_buf.clone(); @@ -278,7 +261,7 @@ impl SbdClient { send_buf2 .lock() .await - .new_rate_limit((rate as f64 * 0.9) as u64); + .new_rate_limit((rate as f64 * 1.1) as u64); } MsgType::LimitIdleMillis(_) => break, MsgType::AuthReq(_) => break, @@ -292,35 +275,14 @@ impl SbdClient { let send_buf2 = send_buf.clone(); let write_task = tokio::task::spawn(async move { - let mut last_send = tokio::time::Instant::now(); loop { if let Some(dur) = send_buf2.lock().await.next_step_dur() { tokio::time::sleep(dur).await; } match send_buf2.lock().await.write_next_queued().await { Err(_) => break, - Ok(true) => { - last_send = tokio::time::Instant::now(); - } + Ok(true) => (), Ok(false) => { - if last_send.elapsed().as_millis() as u64 - > limit_idle_millis as u64 / 2 - { - let mut data = Vec::with_capacity(32); - data.extend_from_slice(CMD_FLAG); - data.extend_from_slice(b"keep"); - if send_buf2 - .lock() - .await - .ws - .send(data) - .await - .is_err() - { - break; - } - last_send = tokio::time::Instant::now(); - } tokio::time::sleep(std::time::Duration::from_millis( 10, )) diff --git a/rust/sbd-client/src/raw_client.rs b/rust/sbd-client/src/raw_client.rs index dfcc5a8..2c58a5b 100644 --- a/rust/sbd-client/src/raw_client.rs +++ b/rust/sbd-client/src/raw_client.rs @@ -1,7 +1,6 @@ //! `feature = "raw_client"` Raw websocket interaction types. -use std::io::{Error, Result}; -use std::sync::Arc; +use crate::*; /// Connection info for creating a raw websocket connection. pub struct WsRawConnect { @@ -148,6 +147,58 @@ impl WsRawRecv { } } +/// Process the standard sbd handshake from the client side. +pub struct Handshake { + /// limit_byte_nanos. + pub limit_byte_nanos: i32, + + /// limit_idle_millis. + pub limit_idle_millis: i32, + + /// bytes sent. + pub bytes_sent: usize, +} + +impl Handshake { + /// Process the standard sbd handshake from the client side. + pub async fn handshake( + send: &mut WsRawSend, + recv: &mut WsRawRecv, + crypto: &C, + ) -> Result { + let mut limit_byte_nanos = 8000; + let mut limit_idle_millis = 10_000; + let mut bytes_sent = 0; + + loop { + match Msg(recv.recv().await?).parse()? { + MsgType::Msg { .. } => { + return Err(Error::other("invalid handshake")) + } + MsgType::LimitByteNanos(l) => limit_byte_nanos = l, + MsgType::LimitIdleMillis(l) => limit_idle_millis = l, + MsgType::AuthReq(nonce) => { + let sig = crypto.sign(nonce); + let mut auth_res = Vec::with_capacity(32 + 64); + auth_res.extend_from_slice(CMD_FLAG); + auth_res.extend_from_slice(b"ares"); + auth_res.extend_from_slice(&sig); + send.send(auth_res).await?; + bytes_sent += 32 + 64; + } + MsgType::Ready => break, + MsgType::Unknown => (), + } + } + + Ok(Self { + limit_byte_nanos, + limit_idle_millis, + bytes_sent, + }) + } +} + fn priv_system_tls() -> Arc { let mut roots = rustls::RootCertStore::empty(); diff --git a/rust/sbd-client/src/send_buf.rs b/rust/sbd-client/src/send_buf.rs index 5ee4842..bafea90 100644 --- a/rust/sbd-client/src/send_buf.rs +++ b/rust/sbd-client/src/send_buf.rs @@ -8,10 +8,39 @@ pub struct SendBuf { pub out_buffer_size: usize, pub origin: tokio::time::Instant, pub limit_rate: u64, + pub idle_keepalive_nanos: u64, pub next_send_at: u64, + pub last_send: u64, } impl SendBuf { + /// construct a new send buf + pub fn new( + ws: raw_client::WsRawSend, + out_buffer_size: usize, + limit_rate: u64, + idle_keepalive: std::time::Duration, + pre_sent_bytes: usize, + ) -> Self { + let mut this = Self { + ws, + buf: VecDeque::default(), + out_buffer_size, + origin: tokio::time::Instant::now(), + limit_rate, + idle_keepalive_nanos: idle_keepalive.as_nanos() as u64, + next_send_at: 0, + last_send: 0, + }; + + let now = this.origin.elapsed().as_nanos() as u64; + + this.next_send_at = std::cmp::max(now, this.next_send_at) + + (pre_sent_bytes as u64 * this.limit_rate); + + this + } + /// We received a new rate limit from the server, update our records. pub fn new_rate_limit(&mut self, limit: u64) { if limit < self.limit_rate { @@ -32,8 +61,18 @@ impl SendBuf { /// returns how long. pub fn next_step_dur(&self) -> Option { let now = self.origin.elapsed().as_nanos() as u64; + + if now - self.last_send >= self.idle_keepalive_nanos { + // we need a keepalive now, don't wait + return None; + } + if now < self.next_send_at { - Some(std::time::Duration::from_nanos(self.next_send_at - now)) + let need_keepalive_at = + self.idle_keepalive_nanos - (now - self.last_send); + let nanos = + std::cmp::min(need_keepalive_at, self.next_send_at - now); + Some(std::time::Duration::from_nanos(nanos)) } else { None } @@ -44,19 +83,23 @@ impl SendBuf { /// out the next queued block on the low-level websocket. /// Returns true if it did something, false if it did not. pub async fn write_next_queued(&mut self) -> Result { - // check the dur again, just to avoid race conditions - // sending too much data at once + let now = self.origin.elapsed().as_nanos() as u64; + + // first check if we need to keepalive + if now - self.last_send >= self.idle_keepalive_nanos { + let mut data = Vec::with_capacity(32); + data.extend_from_slice(CMD_FLAG); + data.extend_from_slice(b"keep"); + self.raw_send(now, data).await?; + return Ok(true); + } + if self.next_step_dur().is_some() { return Ok(false); } if let Some((_, data)) = self.buf.pop_front() { - let now = self.origin.elapsed().as_nanos() as u64; - - self.next_send_at = std::cmp::max(now, self.next_send_at) - + (data.len() as u64 * self.limit_rate); - - self.ws.send(data).await?; + self.raw_send(now, data).await?; Ok(true) } else { @@ -106,6 +149,16 @@ impl SendBuf { // -- private -- // + async fn raw_send(&mut self, now: u64, data: Vec) -> Result<()> { + self.next_send_at = std::cmp::max(now, self.next_send_at) + + (data.len() as u64 * self.limit_rate); + + self.ws.send(data).await?; + self.last_send = now; + + Ok(()) + } + fn len(&self) -> usize { self.buf.iter().map(|(_, d)| d.len()).sum() } diff --git a/rust/sbd-client/tests/reasonable-rate-limit.rs b/rust/sbd-client/tests/reasonable-rate-limit.rs new file mode 100644 index 0000000..d9eef65 --- /dev/null +++ b/rust/sbd-client/tests/reasonable-rate-limit.rs @@ -0,0 +1,109 @@ +use sbd_client::*; +use sbd_server::*; +use std::sync::Arc; + +async fn get_client( + addrs: &[std::net::SocketAddr], +) -> (SbdClient, String, sbd_client::PubKey, MsgRecv) { + for addr in addrs { + if let Ok(r) = SbdClient::connect_config( + &format!("ws://{addr}"), + &DefaultCrypto::default(), + SbdClientConfig { + allow_plain_text: true, + ..Default::default() + }, + ) + .await + { + return r; + } + } + panic!() +} + +#[tokio::test(flavor = "multi_thread")] +async fn reasonable_rate_limit() { + let config = Arc::new(Config { + bind: vec!["127.0.0.1:0".to_string(), "[::1]:0".to_string()], + limit_clients: 10, + limit_ip_kbps: 20, + limit_ip_byte_burst: 100000, + ..Default::default() + }); + + let server = SbdServer::new(config).await.unwrap(); + + let (mut c1, _, p1, mut r1) = get_client(server.bind_addrs()).await; + let (mut c2, _, p2, mut r2) = get_client(server.bind_addrs()).await; + + let (rate1, rate2) = + run(10, &mut c1, &p1, &mut r1, &mut c2, &p2, &mut r2).await; + + println!("got {rate1} bps and {rate2} bps"); + + // 20 kbps divided between 2 connections + // we should be in the range of 10000 bps + assert!(rate1 / 10000.0 > 0.5); + assert!(rate1 / 10000.0 < 1.5); + assert!(rate2 / 10000.0 > 0.5); + assert!(rate2 / 10000.0 < 1.5); +} + +const MSG: &[u8; 100] = &[0xdb; 100]; + +async fn run( + iters: usize, + c1: &mut SbdClient, + p1: &sbd_client::PubKey, + r1: &mut MsgRecv, + c2: &mut SbdClient, + p2: &sbd_client::PubKey, + r2: &mut MsgRecv, +) -> (f64, f64) { + let start = tokio::time::Instant::now(); + let mut rate1 = 0.0; + let mut rate2 = 0.0; + tokio::join!( + async { + for _ in 0..iters { + c1.send(&p2, MSG).await.unwrap(); + println!("c1 sent"); + } + }, + async { + for _ in 0..iters { + c2.send(&p1, MSG).await.unwrap(); + println!("c2 sent"); + } + }, + async { + let mut tot = 0; + loop { + let r = r1.recv().await.unwrap(); + assert_eq!(r.pub_key_ref(), &p2.0); + tot += r.message().len(); + println!("r1 got {} bytes", tot); + rate1 += (32 + r.message().len()) as f64; + if tot >= 100 * iters { + break; + } + } + }, + async { + let mut tot = 0; + loop { + let r = r2.recv().await.unwrap(); + assert_eq!(r.pub_key_ref(), &p1.0); + tot += r.message().len(); + println!("r2 got {} bytes", tot); + rate2 += (32 + r.message().len()) as f64; + if tot >= 100 * iters { + break; + } + } + }, + ); + let elapsed = start.elapsed().as_secs_f64(); + (rate1 / elapsed * 8.0, rate2 / elapsed * 8.0) +} diff --git a/rust/sbd-server/Cargo.toml b/rust/sbd-server/Cargo.toml index 81350e9..9bd6c69 100644 --- a/rust/sbd-server/Cargo.toml +++ b/rust/sbd-server/Cargo.toml @@ -26,6 +26,7 @@ hyper = { version = "1.2.0", features = ["http1", "server"], optional = true } [dev-dependencies] escargot = { version = "0.5.10", features = [ "print" ] } rcgen = "0.13.1" +sbd-client = { workspace = true, features = [ "raw_client" ] } serde_json = "1.0.116" tempfile = "3.10.1" tokio = { version = "1.37.0", features = [ "test-util" ] } diff --git a/rust/sbd-server/src/config.rs b/rust/sbd-server/src/config.rs index 3f8e672..03cb9df 100644 --- a/rust/sbd-server/src/config.rs +++ b/rust/sbd-server/src/config.rs @@ -1,7 +1,7 @@ const DEF_IP_DENY_DIR: &str = "."; const DEF_IP_DENY_S: i32 = 600; const DEF_LIMIT_CLIENTS: i32 = 32768; -const DEF_LIMIT_IP_BYTE_NANOS: i32 = 8000; +const DEF_LIMIT_IP_KBPS: i32 = 1000; const DEF_LIMIT_IP_BYTE_BURST: i32 = 16 * 16 * 1024; const DEF_LIMIT_IDLE_MILLIS: i32 = 10_000; @@ -84,19 +84,19 @@ pub struct Config { #[arg(long, default_value_t = DEF_LIMIT_CLIENTS)] pub limit_clients: i32, - /// How often in nanoseconds 1 byte is allowed to be sent from an IP. - /// The default value of 8000 results in ~1 mbps being allowed. + /// Rate limit connections to this kilobits per second. + /// The default value of 1000 obviously limits connections to 1 mbps. /// If the default of 32768 connections were all sending this amount - /// at the same time, the server would need a ~33.6 gbps connection. - /// This value divided by the count of connections from an ip will - /// be sent down to the client for individual rate limit. - #[arg(long, default_value_t = DEF_LIMIT_IP_BYTE_NANOS)] - pub limit_ip_byte_nanos: i32, + /// at the same time, the server would need a ~33 gbps connection. + /// The rate limit passed to clients will be divided by the number + /// of open connections for a given ip address. + #[arg(long, default_value_t = DEF_LIMIT_IP_KBPS)] + pub limit_ip_kbps: i32, /// Allow IPs to burst by this byte count. /// If the max message size is 16K, this value must be at least 16K. /// The default value provides 16 * 16K to allow for multiple connections - /// from a single ip address. + /// from a single ip address sending full messages at the same time. #[arg(long, default_value_t = DEF_LIMIT_IP_BYTE_BURST)] pub limit_ip_byte_burst: i32, @@ -123,7 +123,7 @@ impl Default for Config { back_open: Vec::new(), bind_prometheus: None, limit_clients: DEF_LIMIT_CLIENTS, - limit_ip_byte_nanos: DEF_LIMIT_IP_BYTE_NANOS, + limit_ip_kbps: DEF_LIMIT_IP_KBPS, limit_ip_byte_burst: DEF_LIMIT_IP_BYTE_BURST, limit_idle_millis: DEF_LIMIT_IDLE_MILLIS, } @@ -134,6 +134,12 @@ impl Config { pub(crate) fn idle_dur(&self) -> std::time::Duration { std::time::Duration::from_millis(self.limit_idle_millis as u64) } + + /// convert kbps into the nanosecond weight of each byte + /// (easier to rate limit with this value) + pub(crate) fn limit_ip_byte_nanos(&self) -> i32 { + 8_000_000 / self.limit_ip_kbps + } } fn get_styles() -> clap::builder::Styles { diff --git a/rust/sbd-server/src/cslot.rs b/rust/sbd-server/src/cslot.rs index bde6de2..4ed10ae 100644 --- a/rust/sbd-server/src/cslot.rs +++ b/rust/sbd-server/src/cslot.rs @@ -190,10 +190,12 @@ impl CSlot { let rate_send_list = self.insert_and_get_rate_send_list(ip, pk, ws); if let Some(rate_send_list) = rate_send_list { - let mut rate = limit_ip_byte_nanos / rate_send_list.len() as i32; - if rate < 1 { - rate = 1; + let mut rate = + limit_ip_byte_nanos as u64 * rate_send_list.len() as u64; + if rate > i32::MAX as u64 { + rate = i32::MAX as u64; } + let rate = rate as i32; for (uniq, index, weak_ws) in rate_send_list { if let Some(ws) = weak_ws.upgrade() { @@ -278,23 +280,30 @@ async fn top_task( pk, } = uitem { - let i = tokio::select! { - i = recv.recv() => i, + let next_i = tokio::select! { + i = recv.recv() => Some(i), _ = ws_task( &config, &ip_rate, &weak, - ws, + &ws, ip, pk, uniq, index, - ) => recv.recv().await, + ) => None, }; + + ws.close().await; + drop(ws); if let Some(cslot) = weak.upgrade() { cslot.remove(uniq, index); } - i + + match next_i { + Some(i) => i, + None => recv.recv().await, + } } else { recv.recv().await }; @@ -306,7 +315,7 @@ async fn ws_task( config: &Arc, ip_rate: &ip_rate::IpRate, weak_cslot: &WeakCSlot, - ws: Arc>, + ws: &Arc>, ip: Arc, pk: PubKey, uniq: u64, diff --git a/rust/sbd-server/src/ip_rate.rs b/rust/sbd-server/src/ip_rate.rs index df7d851..640635e 100644 --- a/rust/sbd-server/src/ip_rate.rs +++ b/rust/sbd-server/src/ip_rate.rs @@ -17,9 +17,9 @@ impl IpRate { Self { origin: tokio::time::Instant::now(), map: Arc::new(Mutex::new(HashMap::new())), - limit: config.limit_ip_byte_nanos as u64, + limit: config.limit_ip_byte_nanos() as u64, burst: config.limit_ip_byte_burst as u64 - * config.limit_ip_byte_nanos as u64, + * config.limit_ip_byte_nanos() as u64, ip_deny: crate::ip_deny::IpDeny::new(config), } } diff --git a/rust/sbd-server/src/lib.rs b/rust/sbd-server/src/lib.rs index e74be65..2a73aa6 100644 --- a/rust/sbd-server/src/lib.rs +++ b/rust/sbd-server/src/lib.rs @@ -153,7 +153,11 @@ async fn check_accept_connection( let tcp = MaybeTlsStream::Tcp(tcp); let (ws, pub_key, ip) = - ws::WebSocket::upgrade(config.clone(), tcp).await.unwrap(); + match ws::WebSocket::upgrade(config.clone(), tcp).await { + Ok(r) => r, + Err(_) => return, + }; + let ws = Arc::new(ws); if let Some(ip) = ip { @@ -175,7 +179,7 @@ async fn check_accept_connection( if let Some(cslot) = weak_cslot.upgrade() { cslot - .insert(calc_ip, pub_key, ws, config.limit_ip_byte_nanos) + .insert(calc_ip, pub_key, ws, config.limit_ip_byte_nanos()) .await; } }) diff --git a/rust/sbd-server/src/ws/ws_tungstenite.rs b/rust/sbd-server/src/ws/ws_tungstenite.rs index 897606a..0c85199 100644 --- a/rust/sbd-server/src/ws/ws_tungstenite.rs +++ b/rust/sbd-server/src/ws/ws_tungstenite.rs @@ -127,9 +127,10 @@ where /// Send to the websocket. pub async fn send(&self, payload: Payload<'_>) -> Result<()> { - let mut write = self.write.lock().await; use futures::sink::SinkExt; use tokio_tungstenite::tungstenite::protocol::Message; + + let mut write = self.write.lock().await; let v = match payload { Payload::Slice(s) => s.to_vec(), Payload::SliceMut(s) => s.to_vec(), @@ -140,4 +141,11 @@ where write.flush().await.map_err(Error::other)?; Ok(()) } + + /// Close the websocket. + pub async fn close(&self) { + use futures::sink::SinkExt; + + let _ = self.write.lock().await.close().await; + } } diff --git a/rust/sbd-server/tests/rate_limit_enforced.rs b/rust/sbd-server/tests/rate_limit_enforced.rs new file mode 100644 index 0000000..7b08531 --- /dev/null +++ b/rust/sbd-server/tests/rate_limit_enforced.rs @@ -0,0 +1,125 @@ +use sbd_client::raw_client::*; +use sbd_client::*; +use sbd_server::*; +use std::sync::Arc; + +async fn get_client( + pk: &[u8; 32], + addrs: &[std::net::SocketAddr], +) -> (WsRawSend, WsRawRecv) { + use base64::Engine; + + for addr in addrs { + if let Ok(r) = (WsRawConnect { + full_url: format!( + "ws://{addr}/{}", + base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(pk) + ), + max_message_size: 100, + allow_plain_text: true, + danger_disable_certificate_check: false, + }) + .connect() + .await + { + return r; + } + } + panic!() +} + +#[tokio::test(flavor = "multi_thread")] +async fn rate_limit_enforced() { + let config = Arc::new(Config { + bind: vec!["127.0.0.1:0".to_string(), "[::1]:0".to_string()], + limit_clients: 10, + limit_ip_kbps: 1, + limit_ip_byte_burst: 1000, + ..Default::default() + }); + + let server = SbdServer::new(config).await.unwrap(); + + let c1 = DefaultCrypto::default(); + let (mut s1, mut r1) = get_client(c1.pub_key(), server.bind_addrs()).await; + let c2 = DefaultCrypto::default(); + let (mut s2, mut r2) = get_client(c2.pub_key(), server.bind_addrs()).await; + + Handshake::handshake(&mut s1, &mut r1, &c1).await.unwrap(); + Handshake::handshake(&mut s2, &mut r2, &c2).await.unwrap(); + + let mut msg = Vec::with_capacity(32 + 5); + msg.extend_from_slice(c2.pub_key()); + msg.extend_from_slice(b"hello"); + + let start_send_fast_s = Arc::new(tokio::sync::Barrier::new(2)); + let start_send_fast_r = start_send_fast_s.clone(); + + let mut send_slow_complete = false; + let mut send_fast_complete = false; + + tokio::select! { + _ = async { + loop { + if r1.recv().await.is_err() { + eprintln!("R1 RECV ERR"); + break; + } + } + } => (), + _ = async { + // should be able to send on the order of millis + for _ in 0..10 { + tokio::time::sleep(std::time::Duration::from_millis(1)).await; + if s1.send(msg.clone()).await.is_err() { + eprintln!("S1 SLOW SEND ERR"); + break; + } + } + + start_send_fast_s.wait().await; + + // but should get dropped if we start spamming + for _ in 0..100 { + if s1.send(msg.clone()).await.is_err() { + eprintln!("S1 FAST SEND ERR"); + break; + } + } + + // the receive side is what triggers this to exit + std::future::pending::<()>().await; + } => (), + _ = async { + for _ in 0..10 { + let r = match r2.recv().await { + Ok(r) => r, + Err(_) => { + eprintln!("R2 SLOW RECV ERR"); + break; + } + }; + assert_eq!(32 + 5, r.len()); + } + + send_slow_complete = true; + start_send_fast_r.wait().await; + + for _ in 0..100 { + let r = match r2.recv().await { + Ok(r) => r, + Err(_) => { + eprintln!("R2 FAST RECV ERR"); + break; + } + }; + assert_eq!(32 + 5, r.len()); + } + + send_fast_complete = true; + } => (), + } + + assert!(send_slow_complete); + assert!(!send_fast_complete); +} From a10639cc92ea7d84653c538cb7c60dd091160c47 Mon Sep 17 00:00:00 2001 From: David Braden Date: Fri, 19 Apr 2024 14:24:59 -0600 Subject: [PATCH 08/33] Create spec.md --- spec.md | 123 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 123 insertions(+) create mode 100644 spec.md diff --git a/spec.md b/spec.md new file mode 100644 index 0000000..baf9265 --- /dev/null +++ b/spec.md @@ -0,0 +1,123 @@ +# SBD Spec + +This SBD spec defines the protocol for SBD servers and clients to communicate. +SBD is a simple websocket-based message relay protocol. + +## 1. Websocket Stack + +The SBD protocol is built upon websockets. + +### 1.1. Websocket Configuration + +#### 1.1.1. Message and Frame Size + +The maximum SBD message size (including 32 byte header) is 16000 bytes. + +SBD clients and servers MAY set the max message size in the websocket library to 16000 to help enforce this. + +The maximum frame size MUST be set larger than 16000 so that sbd messages always fit in a single websocket frame. + +## 2. Cryptography + +### 2.1. TLS + +Generally over the WAN SBD servers SHOULD be available over TLS (wss://), and on a LAN without (ws://). + +### 2.2. Ed25519 + +Clients will be identified by ed25519 public key. Client sessions will be validated by ed25519 signature. + +## 3. Protocol + +### 3.1. Connect Path + +Clients MUST specify exactly 1 http path item on the websocket connection url. +This item must be the base64url encoded public key that this client will be identified by. +This public key MUST be unique to this new connection. + +### 3.2. Messages + +#### 3.2.1. Header + +SBD messages always contain first a 32 byte header. Messages less than 32 bytes are invalid. + +If the header starts with 28 zero bytes, the message is a "command". Otherwise, the message is a "forward". + +If the header is a "command" the next four literal ascii bytes are interpreted as the command type: + +- `lbrt` - limit byte nanos - 4 byte i32be limit - server sent +- `lidl` - limit idle millis - 4 byte i32be limit - server sent +- `areq` - authentication request - 32 byte nonce - server sent +- `ares` - authentication response - 64 byte signature - client sent +- `srdy` - server ready - no additional data - server sent +- `keep` - keepalive - no additional data - client sent + +If the header is a "forward" type, the 32 byte header is interpreted as the public key to forward the message to. +The remaining bytes in the message are the data to forward. + +#### 3.2.2. Forward + +When a client sends a "forward" message to the server, the first 32 bytes represent the +peer the message should be forwarded to. + +When a server sends a "forward" message to the client the message should be forwarded to +the first 32 bytes will be altered to represent the peer from which the message originated. + +#### 3.2.3. Flow + +- The server MUST send `areq` with a random nonce once for every new opened connection. + The server MAY send any limit messages before or after this `areq`, but it MUST come before the `srdy`. +- The client MUST respond with a signature over the nonce by the private key associated with the public key + sent in the url path segment websocket request +- If the signature is valid the server MUST send the `srdy` message. +- After receiving `srdy` the client MAY begin sending "forward" type messages. +- At any point, the server MAY send an updated `lbrt` message. +- At any point after `srdy`, the server MAY send "forward" type messages to the client. +- At any point after `srdy`, the client MAY send a `keep` message. + +```mermaid +sequenceDiagram + +S->>C: lbrt +S->>C: lidl +S->>C: areq +C->>S: ares +S->>C: srdy +C->>S: "forward" +S->>C: "forward" +S->>C: lbrt +C->>S: keep +``` + +#### 3.2.4. `lidl` and Keepalive + +The `lidl` "Limit Idle Millis" message is the millisecond count the server will keep a connection around +without having seen any messages sent from that client. If a client does not have any forward messages +to send within this time period and wishes to keep the connection open they SHOULD send a `keep` message +to maintain the connection. Note this keep message will be counted against rate limiting. + +#### 3.2.5. `lbrt` and rate limiting + +The `lbrt` "Limit Byte Nanos" message indicates the nanoseconds of rate limiting used up by a single byte sent. +This is intuitively backwards of a "limit" because higher values indicate you need to send data more slowly, +but this direction is easier to work with in code. That is, if `lbrt` is 1, you can send one byte every nanosecond. +A more reasonable `librt` value of 8000 means you can send 1 byte every 8000 nanoseconds. Or more reasonably, +if you send a 16000 byte message, you should wait 8000 * 16000 nanoseconds before sending the next message. + +A server MUST provide a "burst" grace window to account for message size. + +A server MAY track rate limiting by some metric other than individual connection. IP address, for example. +Then, if additional connections are established from the same other metric, all connections could be notified +of needing to send data more slowly. + +#### 3.2.6. Extensibility + +In order to make this protocol extensible without versioning, clients and servers MUST ignore unknown command types. +(With the exception that servers should still count the raw bytes in rate limiting.) + +### 3.3. Violations + +If a server receives an invalid message from a client it MUST immediately drop the connection with no closing frame. + +If a server receives a message that violates the rate limit, the connection MUST similarly be dropped with no closing frame. +The server MAY also block connections (perhaps by IP address) for an unspecified amount of time. From 69bccf919b481e924cb9fb46b3287ec06de5a5c2 Mon Sep 17 00:00:00 2001 From: David Braden Date: Fri, 19 Apr 2024 14:34:16 -0600 Subject: [PATCH 09/33] Create README.md --- README.md | 6 ++++++ 1 file changed, 6 insertions(+) create mode 100644 README.md diff --git a/README.md b/README.md new file mode 100644 index 0000000..6bfeaaa --- /dev/null +++ b/README.md @@ -0,0 +1,6 @@ +Simple websocket-based message relay servers and clients. + +- [Rust Reference Server](rust/sbd-server) +- [Rust Reference Client](rust/sbd-client) +- [Autobahn-Style Server Test Suite](rust/sbd-o-bahn-server-tester) +- [Protocol Spec](spec.md) From 7e6161b76f49ef78902c566eaed3ade7f0db1553 Mon Sep 17 00:00:00 2001 From: neonphog Date: Fri, 19 Apr 2024 16:04:47 -0600 Subject: [PATCH 10/33] close and benchmark --- Cargo.lock | 395 ++++++++++++++++++++++++++++++ Cargo.toml | 6 +- rust/sbd-bench/Cargo.toml | 17 ++ rust/sbd-bench/benches/thru.rs | 23 ++ rust/sbd-bench/src/lib.rs | 136 ++++++++++ rust/sbd-client/src/lib.rs | 9 +- rust/sbd-client/src/raw_client.rs | 6 + rust/sbd-client/src/send_buf.rs | 5 + rust/sbd-server/src/config.rs | 6 + rust/sbd-server/src/cslot.rs | 18 +- rust/sbd-server/src/ip_rate.rs | 7 + rust/sbd-server/src/lib.rs | 4 +- 12 files changed, 618 insertions(+), 14 deletions(-) create mode 100644 rust/sbd-bench/Cargo.toml create mode 100644 rust/sbd-bench/benches/thru.rs create mode 100644 rust/sbd-bench/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index 3a59285..4847a1f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -17,6 +17,21 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" +[[package]] +name = "aho-corasick" +version = "1.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916" +dependencies = [ + "memchr", +] + +[[package]] +name = "anes" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" + [[package]] name = "anstream" version = "0.6.13" @@ -125,6 +140,12 @@ dependencies = [ "generic-array", ] +[[package]] +name = "bumpalo" +version = "3.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" + [[package]] name = "byteorder" version = "1.5.0" @@ -137,6 +158,12 @@ version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "514de17de45fdb8dc022b1a7975556c53c86f9f0aa5f534b98977b171857c2c9" +[[package]] +name = "cast" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" + [[package]] name = "cc" version = "1.0.92" @@ -149,6 +176,33 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "ciborium" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42e69ffd6f0917f5c029256a24d0161db17cea3997d185db0d35926308770f0e" +dependencies = [ + "ciborium-io", + "ciborium-ll", + "serde", +] + +[[package]] +name = "ciborium-io" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05afea1e0a06c9be33d539b876f1ce3692f4afea2cb41f740e7743225ed1c757" + +[[package]] +name = "ciborium-ll" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57663b653d948a338bfb3eeba9bb2fd5fcfaecb9e199e87e1eda4d9e8b240fd9" +dependencies = [ + "ciborium-io", + "half", +] + [[package]] name = "clap" version = "4.5.4" @@ -227,6 +281,75 @@ dependencies = [ "libc", ] +[[package]] +name = "criterion" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2b12d017a929603d80db1831cd3a24082f8137ce19c69e6447f54f5fc8d692f" +dependencies = [ + "anes", + "cast", + "ciborium", + "clap", + "criterion-plot", + "futures", + "is-terminal", + "itertools", + "num-traits", + "once_cell", + "oorandom", + "plotters", + "rayon", + "regex", + "serde", + "serde_derive", + "serde_json", + "tinytemplate", + "tokio", + "walkdir", +] + +[[package]] +name = "criterion-plot" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1" +dependencies = [ + "cast", + "itertools", +] + +[[package]] +name = "crossbeam-deque" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d" +dependencies = [ + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "248e3bacc7dc6baa3b21e405ee045c3047101a49145e7e9eca583ab4c2ca5345" + +[[package]] +name = "crunchy" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" + [[package]] name = "crypto-common" version = "0.1.6" @@ -325,6 +448,12 @@ dependencies = [ "zeroize", ] +[[package]] +name = "either" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a47c1c47d2f5964e29c61246e81db715514cd532db6b5116a25ea3c03d6780a2" + [[package]] name = "errno" version = "0.3.8" @@ -510,6 +639,16 @@ version = "0.28.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253" +[[package]] +name = "half" +version = "2.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6dd08c532ae367adf81c312a4580bc67f1d0fe8bc9c460520283f4c0ff277888" +dependencies = [ + "cfg-if", + "crunchy", +] + [[package]] name = "heck" version = "0.5.0" @@ -614,12 +753,41 @@ dependencies = [ "unicode-normalization", ] +[[package]] +name = "is-terminal" +version = "0.4.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f23ff5ef2b80d608d61efee834934d862cd92461afc0560dedf493e4c033738b" +dependencies = [ + "hermit-abi", + "libc", + "windows-sys 0.52.0", +] + +[[package]] +name = "itertools" +version = "0.10.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b0fd2260e829bddf4cb6ea802289de2f86d6a7a690192fbe91b3f46e0f2c8473" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "1.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" +[[package]] +name = "js-sys" +version = "0.3.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29c15563dc2726973df627357ce0c9ddddbea194836909d655df6a75d2cf296d" +dependencies = [ + "wasm-bindgen", +] + [[package]] name = "libc" version = "0.2.153" @@ -680,6 +848,15 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" +[[package]] +name = "num-traits" +version = "0.2.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da0df0e5185db44f69b44f26786fe401b6c293d1907744beaa7fa62b2e5a517a" +dependencies = [ + "autocfg", +] + [[package]] name = "num_cpus" version = "1.16.0" @@ -705,6 +882,12 @@ version = "1.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" +[[package]] +name = "oorandom" +version = "11.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ab1bc2a289d34bd04a330323ac98a1b4bc82c9d9fcb1e66b63caa84da26b575" + [[package]] name = "openssl-probe" version = "0.1.5" @@ -798,6 +981,34 @@ version = "3.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "db23d408679286588f4d4644f965003d056e3dd5abcaaa938116871d7ce2fee7" +[[package]] +name = "plotters" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2c224ba00d7cadd4d5c660deaf2098e5e80e07846537c51f9cfa4be50c1fd45" +dependencies = [ + "num-traits", + "plotters-backend", + "plotters-svg", + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "plotters-backend" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e76628b4d3a7581389a35d5b6e2139607ad7c75b17aed325f210aa91f4a9609" + +[[package]] +name = "plotters-svg" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38f6d39893cca0701371e3c27294f09797214b86f1fb951b89ade8ec04e2abab" +dependencies = [ + "plotters-backend", +] + [[package]] name = "powerfmt" version = "0.2.0" @@ -858,6 +1069,26 @@ dependencies = [ "getrandom", ] +[[package]] +name = "rayon" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa" +dependencies = [ + "either", + "rayon-core", +] + +[[package]] +name = "rayon-core" +version = "1.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2" +dependencies = [ + "crossbeam-deque", + "crossbeam-utils", +] + [[package]] name = "rcgen" version = "0.13.1" @@ -880,6 +1111,35 @@ dependencies = [ "bitflags 1.3.2", ] +[[package]] +name = "regex" +version = "1.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c117dbdfde9c8308975b6a18d71f3f385c89461f7b3fb054288ecf2a2058ba4c" +dependencies = [ + "aho-corasick", + "memchr", + "regex-automata", + "regex-syntax", +] + +[[package]] +name = "regex-automata" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "86b83b8b9847f9bf95ef68afb0b8e6cdb80f498442f5179a29fad448fcc1eaea" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", +] + +[[package]] +name = "regex-syntax" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "adad44e29e4c806119491a7f06f03de4d1af22c3a680dd47f1e6e179439d1f56" + [[package]] name = "ring" version = "0.17.8" @@ -983,6 +1243,26 @@ version = "1.0.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e86697c916019a8588c99b5fac3cead74ec0b4b819707a682fd4d23fa0ce1ba1" +[[package]] +name = "same-file" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" +dependencies = [ + "winapi-util", +] + +[[package]] +name = "sbd-bench" +version = "0.0.1-alpha" +dependencies = [ + "base64 0.22.0", + "criterion", + "sbd-client", + "sbd-server", + "tokio", +] + [[package]] name = "sbd-client" version = "0.0.1-alpha" @@ -1280,6 +1560,16 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ef927ca75afb808a4d64dd374f00a2adf8d0fcff8e7b184af886c3c87ec4a3f3" +[[package]] +name = "tinytemplate" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc" +dependencies = [ + "serde", + "serde_json", +] + [[package]] name = "tinyvec" version = "1.6.0" @@ -1440,6 +1730,16 @@ version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" +[[package]] +name = "walkdir" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b" +dependencies = [ + "same-file", + "winapi-util", +] + [[package]] name = "want" version = "0.3.1" @@ -1455,6 +1755,70 @@ version = "0.11.0+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" +[[package]] +name = "wasm-bindgen" +version = "0.2.92" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4be2531df63900aeb2bca0daaaddec08491ee64ceecbee5076636a3b026795a8" +dependencies = [ + "cfg-if", + "wasm-bindgen-macro", +] + +[[package]] +name = "wasm-bindgen-backend" +version = "0.2.92" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "614d787b966d3989fa7bb98a654e369c762374fd3213d212cfc0251257e747da" +dependencies = [ + "bumpalo", + "log", + "once_cell", + "proc-macro2", + "quote", + "syn", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-macro" +version = "0.2.92" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1f8823de937b71b9460c0c34e25f3da88250760bec0ebac694b49997550d726" +dependencies = [ + "quote", + "wasm-bindgen-macro-support", +] + +[[package]] +name = "wasm-bindgen-macro-support" +version = "0.2.92" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "wasm-bindgen-backend", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-shared" +version = "0.2.92" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af190c94f2773fdb3729c55b007a722abb5384da03bc0986df4c289bf5567e96" + +[[package]] +name = "web-sys" +version = "0.3.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77afa9a11836342370f4817622a2f0f418b134426d91a82dfb48f532d2ec13ef" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + [[package]] name = "webpki-roots" version = "0.26.1" @@ -1464,6 +1828,37 @@ dependencies = [ "rustls-pki-types", ] +[[package]] +name = "winapi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" +dependencies = [ + "winapi-i686-pc-windows-gnu", + "winapi-x86_64-pc-windows-gnu", +] + +[[package]] +name = "winapi-i686-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" + +[[package]] +name = "winapi-util" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f29e6f9198ba0d26b4c9f07dbe6f9ed633e1f3d5b8b414090084349e46a52596" +dependencies = [ + "winapi", +] + +[[package]] +name = "winapi-x86_64-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" + [[package]] name = "windows-sys" version = "0.48.0" diff --git a/Cargo.toml b/Cargo.toml index 6f8a30d..281e6ea 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,12 +1,14 @@ [workspace] members = [ + "rust/sbd-bench", "rust/sbd-client", - "rust/sbd-server", "rust/sbd-o-bahn-server-tester", + "rust/sbd-server", ] resolver = "2" [workspace.dependencies] +sbd-bench = { version = "0.0.1-alpha", path = "rust/sbd-bench" } sbd-client = { version = "0.0.1-alpha", path = "rust/sbd-client" } -sbd-server = { version = "0.0.1-alpha", path = "rust/sbd-server" } sbd-o-bahn-server-tester = { version = "0.0.1-alpha", path = "rust/sbd-o-bahn-server-tester" } +sbd-server = { version = "0.0.1-alpha", path = "rust/sbd-server" } diff --git a/rust/sbd-bench/Cargo.toml b/rust/sbd-bench/Cargo.toml new file mode 100644 index 0000000..6dba90f --- /dev/null +++ b/rust/sbd-bench/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "sbd-bench" +version = "0.0.1-alpha" +edition = "2021" + +[dependencies] +base64 = "0.22.0" +sbd-server = { workspace = true } +sbd-client = { workspace = true, features = [ "raw_client" ] } +tokio = { version = "1.37.0", features = [ "full" ] } + +[dev-dependencies] +criterion = { version = "0.5.1", features = [ "async_tokio" ] } + +[[bench]] +name = "thru" +harness = false diff --git a/rust/sbd-bench/benches/thru.rs b/rust/sbd-bench/benches/thru.rs new file mode 100644 index 0000000..65dc0f9 --- /dev/null +++ b/rust/sbd-bench/benches/thru.rs @@ -0,0 +1,23 @@ +use criterion::{criterion_group, criterion_main, Criterion}; +use sbd_bench::ThruBenchmark; +use std::sync::Arc; +use tokio::sync::Mutex; + +fn criterion_benchmark(c: &mut Criterion) { + let rt = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .unwrap(); + + let test = Arc::new(Mutex::new(rt.block_on(ThruBenchmark::new()))); + let test = &test; + + c.bench_function("thru", |b| { + b.to_async(&rt).iter(|| async move { + test.lock().await.iter().await; + }); + }); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/rust/sbd-bench/src/lib.rs b/rust/sbd-bench/src/lib.rs new file mode 100644 index 0000000..d41dd97 --- /dev/null +++ b/rust/sbd-bench/src/lib.rs @@ -0,0 +1,136 @@ +use sbd_client::raw_client::*; +use sbd_client::*; +use sbd_server::*; +use std::sync::Arc; + +pub struct ThruBenchmark { + _server: SbdServer, + c1: DefaultCrypto, + s1: WsRawSend, + r1: WsRawRecv, + c2: DefaultCrypto, + s2: WsRawSend, + r2: WsRawRecv, + v1: Option>, + v2: Option>, +} + +impl ThruBenchmark { + pub async fn new() -> Self { + let config = Arc::new(Config { + bind: vec!["127.0.0.1:0".to_string(), "[::1]:0".to_string()], + limit_clients: 100, + disable_rate_limiting: true, + ..Default::default() + }); + + let server = SbdServer::new(config).await.unwrap(); + + let c1 = DefaultCrypto::default(); + let (mut s1, mut r1) = c(c1.pub_key(), server.bind_addrs()).await; + + let c2 = DefaultCrypto::default(); + let (mut s2, mut r2) = c(c2.pub_key(), server.bind_addrs()).await; + + Handshake::handshake(&mut s1, &mut r1, &c1).await.unwrap(); + Handshake::handshake(&mut s2, &mut r2, &c2).await.unwrap(); + + Self { + _server: server, + c1, + s1, + r1, + c2, + s2, + r2, + v1: None, + v2: None, + } + } + + pub async fn iter(&mut self) { + let Self { + c1, + s1, + r1, + c2, + s2, + r2, + v1, + v2, + .. + } = self; + + let mut b1 = v1.take().unwrap_or_else(|| vec![0xdb; 1000]); + let mut b2 = v2.take().unwrap_or_else(|| vec![0xca; 1000]); + + tokio::join!( + async { + b1[0..32].copy_from_slice(c2.pub_key()); + s1.send(b1).await.unwrap(); + }, + async { + b2[0..32].copy_from_slice(c1.pub_key()); + s2.send(b2).await.unwrap(); + }, + async { + let b2 = r1.recv().await.unwrap(); + assert_eq!(1000, b2.len()); + *v2 = Some(b2); + }, + async { + let b1 = r2.recv().await.unwrap(); + assert_eq!(1000, b1.len()); + *v1 = Some(b1); + }, + ); + } +} + +async fn c( + pk: &[u8; 32], + addrs: &[std::net::SocketAddr], +) -> (WsRawSend, WsRawRecv) { + use base64::Engine; + + for addr in addrs { + if let Ok(r) = (WsRawConnect { + full_url: format!( + "ws://{addr}/{}", + base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(pk) + ), + max_message_size: 1000, + allow_plain_text: true, + danger_disable_certificate_check: false, + }) + .connect() + .await + { + return r; + } + } + panic!() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test(flavor = "multi_thread")] + async fn thru_bench_test() { + let mut b = ThruBenchmark::new().await; + + // warmup + for _ in 0..10 { + b.iter().await; + } + + let start = tokio::time::Instant::now(); + for _ in 0..100 { + b.iter().await; + } + let elapsed = start.elapsed(); + + println!("{} nanos per iter", elapsed.as_nanos() / 100); + } +} diff --git a/rust/sbd-client/src/lib.rs b/rust/sbd-client/src/lib.rs index d8c60af..07c10d9 100644 --- a/rust/sbd-client/src/lib.rs +++ b/rust/sbd-client/src/lib.rs @@ -270,7 +270,7 @@ impl SbdClient { } } - // TODO - shutdown + send_buf2.lock().await.close().await; }); let send_buf2 = send_buf.clone(); @@ -291,7 +291,7 @@ impl SbdClient { } } - // TODO - shutdown + send_buf2.lock().await.close().await; }); let this = Self { @@ -308,6 +308,11 @@ impl SbdClient { )) } + /// Close the connection. + pub async fn close(&self) { + self.send_buf.lock().await.close().await; + } + /// Send a message to a peer. pub async fn send(&self, peer: &PubKey, data: &[u8]) -> Result<()> { self.send_buf.lock().await.send(peer, data).await diff --git a/rust/sbd-client/src/raw_client.rs b/rust/sbd-client/src/raw_client.rs index 2c58a5b..77e611c 100644 --- a/rust/sbd-client/src/raw_client.rs +++ b/rust/sbd-client/src/raw_client.rs @@ -117,6 +117,12 @@ impl WsRawSend { self.send.flush().await.map_err(Error::other)?; Ok(()) } + + /// Close the connection. + pub async fn close(&mut self) { + use futures::sink::SinkExt; + let _ = self.send.close().await; + } } /// The receive half of the websocket connection. diff --git a/rust/sbd-client/src/send_buf.rs b/rust/sbd-client/src/send_buf.rs index bafea90..3729a4e 100644 --- a/rust/sbd-client/src/send_buf.rs +++ b/rust/sbd-client/src/send_buf.rs @@ -41,6 +41,11 @@ impl SendBuf { this } + /// Close the connection. + pub async fn close(&mut self) { + self.ws.close().await; + } + /// We received a new rate limit from the server, update our records. pub fn new_rate_limit(&mut self, limit: u64) { if limit < self.limit_rate { diff --git a/rust/sbd-server/src/config.rs b/rust/sbd-server/src/config.rs index 03cb9df..f8c0621 100644 --- a/rust/sbd-server/src/config.rs +++ b/rust/sbd-server/src/config.rs @@ -84,6 +84,11 @@ pub struct Config { #[arg(long, default_value_t = DEF_LIMIT_CLIENTS)] pub limit_clients: i32, + /// If set, rate-limiting will be disabled on the server, + /// and clients will be informed they have an 8gbps rate limit. + #[arg(long)] + pub disable_rate_limiting: bool, + /// Rate limit connections to this kilobits per second. /// The default value of 1000 obviously limits connections to 1 mbps. /// If the default of 32768 connections were all sending this amount @@ -123,6 +128,7 @@ impl Default for Config { back_open: Vec::new(), bind_prometheus: None, limit_clients: DEF_LIMIT_CLIENTS, + disable_rate_limiting: false, limit_ip_kbps: DEF_LIMIT_IP_KBPS, limit_ip_byte_burst: DEF_LIMIT_IP_BYTE_BURST, limit_idle_millis: DEF_LIMIT_IDLE_MILLIS, diff --git a/rust/sbd-server/src/cslot.rs b/rust/sbd-server/src/cslot.rs index 4ed10ae..214405a 100644 --- a/rust/sbd-server/src/cslot.rs +++ b/rust/sbd-server/src/cslot.rs @@ -182,20 +182,24 @@ impl CSlot { pub async fn insert( &self, + config: &Config, ip: Arc, pk: PubKey, ws: Arc>, - limit_ip_byte_nanos: i32, ) { let rate_send_list = self.insert_and_get_rate_send_list(ip, pk, ws); if let Some(rate_send_list) = rate_send_list { - let mut rate = - limit_ip_byte_nanos as u64 * rate_send_list.len() as u64; - if rate > i32::MAX as u64 { - rate = i32::MAX as u64; - } - let rate = rate as i32; + let rate = if config.disable_rate_limiting { + 1 + } else { + let mut rate = config.limit_ip_byte_nanos() as u64 + * rate_send_list.len() as u64; + if rate > i32::MAX as u64 { + rate = i32::MAX as u64; + } + rate as i32 + }; for (uniq, index, weak_ws) in rate_send_list { if let Some(ws) = weak_ws.upgrade() { diff --git a/rust/sbd-server/src/ip_rate.rs b/rust/sbd-server/src/ip_rate.rs index 640635e..8f4dfe2 100644 --- a/rust/sbd-server/src/ip_rate.rs +++ b/rust/sbd-server/src/ip_rate.rs @@ -6,6 +6,7 @@ type Map = HashMap, u64>; pub struct IpRate { origin: tokio::time::Instant, map: Arc>, + disabled: bool, limit: u64, burst: u64, ip_deny: crate::ip_deny::IpDeny, @@ -17,6 +18,7 @@ impl IpRate { Self { origin: tokio::time::Instant::now(), map: Arc::new(Mutex::new(HashMap::new())), + disabled: config.disable_rate_limiting, limit: config.limit_ip_byte_nanos() as u64, burst: config.limit_ip_byte_burst as u64 * config.limit_ip_byte_nanos() as u64, @@ -55,6 +57,10 @@ impl IpRate { ip: &Arc, bytes: usize, ) -> bool { + if self.disabled { + return true; + } + // multiply by our rate allowed per byte let rate_add = bytes as u64 * self.limit; @@ -95,6 +101,7 @@ mod tests { IpRate { origin: tokio::time::Instant::now(), map: Arc::new(Mutex::new(HashMap::new())), + disabled: false, limit, burst, ip_deny: crate::ip_deny::IpDeny::new(Arc::new( diff --git a/rust/sbd-server/src/lib.rs b/rust/sbd-server/src/lib.rs index 2a73aa6..e652aed 100644 --- a/rust/sbd-server/src/lib.rs +++ b/rust/sbd-server/src/lib.rs @@ -178,9 +178,7 @@ async fn check_accept_connection( } if let Some(cslot) = weak_cslot.upgrade() { - cslot - .insert(calc_ip, pub_key, ws, config.limit_ip_byte_nanos()) - .await; + cslot.insert(&config, calc_ip, pub_key, ws).await; } }) .await; From 2ad552fde4761db3b3a1bd12963599ed2756eb53 Mon Sep 17 00:00:00 2001 From: neonphog Date: Fri, 19 Apr 2024 21:46:23 -0600 Subject: [PATCH 11/33] client tester --- Cargo.lock | 17 ++ Cargo.toml | 1 + README.md | 1 + rust/sbd-client/Cargo.toml | 2 + .../examples/client-o-bahn-runner.rs | 103 +++++++++ rust/sbd-client/tests/suite.rs | 27 +++ rust/sbd-o-bahn-client-tester/Cargo.toml | 9 + .../src/bin/sbd-o-bahn-client-tester-bin.rs | 13 ++ rust/sbd-o-bahn-client-tester/src/it.rs | 110 ++++++++++ rust/sbd-o-bahn-client-tester/src/it/it_1.rs | 42 ++++ rust/sbd-o-bahn-client-tester/src/lib.rs | 197 ++++++++++++++++++ ...uite-runner.rs => server-o-bahn-runner.rs} | 0 rust/sbd-server/tests/suite.rs | 2 +- 13 files changed, 523 insertions(+), 1 deletion(-) create mode 100644 rust/sbd-client/examples/client-o-bahn-runner.rs create mode 100644 rust/sbd-client/tests/suite.rs create mode 100644 rust/sbd-o-bahn-client-tester/Cargo.toml create mode 100644 rust/sbd-o-bahn-client-tester/src/bin/sbd-o-bahn-client-tester-bin.rs create mode 100644 rust/sbd-o-bahn-client-tester/src/it.rs create mode 100644 rust/sbd-o-bahn-client-tester/src/it/it_1.rs create mode 100644 rust/sbd-o-bahn-client-tester/src/lib.rs rename rust/sbd-server/examples/{test-suite-runner.rs => server-o-bahn-runner.rs} (100%) diff --git a/Cargo.lock b/Cargo.lock index 4847a1f..8cf7fac 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -661,6 +661,12 @@ version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024" +[[package]] +name = "hex" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" + [[package]] name = "http" version = "1.1.0" @@ -1269,7 +1275,9 @@ version = "0.0.1-alpha" dependencies = [ "base64 0.22.0", "ed25519-dalek", + "escargot", "futures", + "hex", "rand", "rustls", "rustls-native-certs", @@ -1280,6 +1288,15 @@ dependencies = [ "webpki-roots", ] +[[package]] +name = "sbd-o-bahn-client-tester" +version = "0.0.1-alpha" +dependencies = [ + "hex", + "sbd-server", + "tokio", +] + [[package]] name = "sbd-o-bahn-server-tester" version = "0.0.1-alpha" diff --git a/Cargo.toml b/Cargo.toml index 281e6ea..b900dbb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,6 +3,7 @@ members = [ "rust/sbd-bench", "rust/sbd-client", "rust/sbd-o-bahn-server-tester", + "rust/sbd-o-bahn-client-tester", "rust/sbd-server", ] resolver = "2" diff --git a/README.md b/README.md index 6bfeaaa..d50f4c5 100644 --- a/README.md +++ b/README.md @@ -3,4 +3,5 @@ Simple websocket-based message relay servers and clients. - [Rust Reference Server](rust/sbd-server) - [Rust Reference Client](rust/sbd-client) - [Autobahn-Style Server Test Suite](rust/sbd-o-bahn-server-tester) +- [Autobahn-Style Client Test Suite](rust/sbd-o-bahn-client-tester) - [Protocol Spec](spec.md) diff --git a/rust/sbd-client/Cargo.toml b/rust/sbd-client/Cargo.toml index f086af7..c0b3f93 100644 --- a/rust/sbd-client/Cargo.toml +++ b/rust/sbd-client/Cargo.toml @@ -21,6 +21,8 @@ rand = { version = "0.8.5", optional = true } webpki-roots = "0.26.1" [dev-dependencies] +escargot = { version = "0.5.10", features = [ "print" ] } +hex = "0.4.3" tokio = { version = "1.37.0", features = [ "full" ] } sbd-server = { workspace = true } diff --git a/rust/sbd-client/examples/client-o-bahn-runner.rs b/rust/sbd-client/examples/client-o-bahn-runner.rs new file mode 100644 index 0000000..191b199 --- /dev/null +++ b/rust/sbd-client/examples/client-o-bahn-runner.rs @@ -0,0 +1,103 @@ +use sbd_client::*; +use std::collections::{HashMap, VecDeque}; + +enum ConCmd { + Close, + Send(PubKey, Vec), +} + +#[tokio::main(flavor = "multi_thread")] +async fn main() { + println!("CMD/READY"); + + let mut lines = tokio::io::AsyncBufReadExt::lines( + tokio::io::BufReader::new(tokio::io::stdin()), + ); + + let mut con_map = HashMap::new(); + + while let Ok(Some(line)) = lines.next_line().await { + let mut parts = line.split("/").collect::>(); + if parts.pop_front().unwrap() != "CMD" { + panic!(); + } + match parts.pop_front().unwrap() { + "CONNECT" => { + let id: usize = parts.pop_front().unwrap().parse().unwrap(); + let (s, r) = tokio::sync::mpsc::unbounded_channel(); + con_map.insert(id, s); + tokio::task::spawn(spawn_con( + id, + r, + parts.into_iter().map(|s| s.to_string()).collect(), + )); + } + "SEND" => { + let id: usize = parts.pop_front().unwrap().parse().unwrap(); + let pk = hex::decode(parts.pop_front().unwrap()).unwrap(); + let msg = hex::decode(parts.pop_front().unwrap()).unwrap(); + if let Some(s) = con_map.get(&id) { + let _ = s.send(ConCmd::Send( + PubKey(pk.try_into().unwrap()), + msg, + )); + } + } + "CLOSE" => { + let id: usize = parts.pop_front().unwrap().parse().unwrap(); + if let Some(s) = con_map.get(&id) { + let _ = s.send(ConCmd::Close); + } + } + oth => panic!("unhandled: {oth}"), + } + } +} + +async fn connect(addrs: &[String]) -> (SbdClient, PubKey, MsgRecv) { + for addr in addrs { + if let Ok(c) = SbdClient::connect_config( + &format!("ws://{addr}"), + &DefaultCrypto::default(), + SbdClientConfig { + allow_plain_text: true, + ..Default::default() + }, + ) + .await + { + return (c.0, c.2, c.3); + } + } + panic!() +} + +async fn spawn_con( + id: usize, + mut r: tokio::sync::mpsc::UnboundedReceiver, + addrs: Vec, +) { + let (cli, pk, mut rcv) = connect(addrs.as_slice()).await; + tokio::task::spawn(async move { + while let Some(data) = rcv.recv().await { + println!( + "CMD/RECV/{id}/{}/{}", + hex::encode(data.pub_key_ref()), + hex::encode(data.message()), + ); + } + println!("CMD/CLOSE/{id}"); + }); + println!("CMD/CONNECT/{id}/{}", hex::encode(&pk.0)); + while let Some(cmd) = r.recv().await { + match cmd { + ConCmd::Close => break, + ConCmd::Send(dest, msg) => { + if cli.send(&dest, &msg).await.is_err() { + break; + } + } + } + } + cli.close().await; +} diff --git a/rust/sbd-client/tests/suite.rs b/rust/sbd-client/tests/suite.rs new file mode 100644 index 0000000..44e067c --- /dev/null +++ b/rust/sbd-client/tests/suite.rs @@ -0,0 +1,27 @@ +#[test] +fn suite() { + println!("BUILDING example test-suite-runner IN RELEASE MODE"); + let server = escargot::CargoBuild::new() + .example("client-o-bahn-runner") + .release() + .current_target() + .run() + .unwrap(); + + println!("BUILDING sbd-server-test-suite IN RELEASE MODE"); + let suite = escargot::CargoBuild::new() + .bin("sbd-o-bahn-client-tester-bin") + .manifest_path("../sbd-o-bahn-client-tester/Cargo.toml") + .release() + .current_target() + .run() + .unwrap(); + + println!("RUNNING the test suite {:?}", suite.path()); + assert!(suite + .command() + .arg(server.path()) + .status() + .unwrap() + .success()); +} diff --git a/rust/sbd-o-bahn-client-tester/Cargo.toml b/rust/sbd-o-bahn-client-tester/Cargo.toml new file mode 100644 index 0000000..3a44ac3 --- /dev/null +++ b/rust/sbd-o-bahn-client-tester/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "sbd-o-bahn-client-tester" +version = "0.0.1-alpha" +edition = "2021" + +[dependencies] +hex = "0.4.3" +sbd-server = { workspace = true } +tokio = { version = "1.37.0", features = [ "full" ] } diff --git a/rust/sbd-o-bahn-client-tester/src/bin/sbd-o-bahn-client-tester-bin.rs b/rust/sbd-o-bahn-client-tester/src/bin/sbd-o-bahn-client-tester-bin.rs new file mode 100644 index 0000000..5a03bad --- /dev/null +++ b/rust/sbd-o-bahn-client-tester/src/bin/sbd-o-bahn-client-tester-bin.rs @@ -0,0 +1,13 @@ +#[tokio::main(flavor = "multi_thread")] +async fn main() { + let mut args = std::env::args_os(); + args.next().unwrap(); + let result = sbd_o_bahn_client_tester::run( + args.next().expect("Expected Sbd Client Suite Runner"), + ) + .await; + println!("{result:#?}"); + if !result.failed.is_empty() { + panic!("TEST FAILED"); + } +} diff --git a/rust/sbd-o-bahn-client-tester/src/it.rs b/rust/sbd-o-bahn-client-tester/src/it.rs new file mode 100644 index 0000000..879d045 --- /dev/null +++ b/rust/sbd-o-bahn-client-tester/src/it.rs @@ -0,0 +1,110 @@ +use std::future::Future; +use std::io::{Error, Result}; + +use crate::{Client, Report}; + +macro_rules! expect { + ($h:ident, $cond:expr, $note:literal) => { + $h.expect(file!(), line!(), $cond, $note) + }; +} + +pub struct Conn<'h> { + client: &'h Client, + id: u64, + pk: [u8; 32], + r: tokio::sync::Mutex< + tokio::sync::mpsc::UnboundedReceiver<([u8; 32], Vec)>, + >, +} + +impl Conn<'_> { + pub fn pub_key(&self) -> &[u8; 32] { + &self.pk + } + + pub async fn recv(&self) -> Option<([u8; 32], Vec)> { + self.r.lock().await.recv().await + } + + pub async fn send(&self, pk: &[u8; 32], msg: &[u8]) { + self.client.send(self.id, pk, msg).await; + } +} + +/// Utilities for helping with the test. +pub struct TestHelper<'h> { + client: &'h Client, + err_list: Vec, + report: Report, +} + +impl<'h> TestHelper<'h> { + fn new(client: &'h Client) -> Self { + Self { + client, + err_list: Vec::new(), + report: Report::default(), + } + } + + fn into_report(self) -> Report { + self.report + } + + /// expect a condition to be true + pub fn expect( + &mut self, + file: &'static str, + line: u32, + cond: bool, + note: &'static str, + ) { + if !cond { + self.err_list.push(format!("{file}:{line}: failed: {note}")); + } + } + + pub async fn connect(&self, addrs: &[std::net::SocketAddr]) -> Conn<'h> { + let (id, pk, r) = self.client.connect(addrs).await; + Conn { + client: self.client, + id, + pk, + r: tokio::sync::Mutex::new(r), + } + } +} + +/// Test definition. +pub trait It { + const NAME: &'static str; + + fn exec(helper: &mut TestHelper) -> impl Future>; +} + +pub mod it_1; + +/// Execute the full test suite. +pub async fn exec_all(client: &mut Client) -> Report { + let mut helper = TestHelper::new(client); + + exec_one::(&mut helper).await; + + helper.into_report() +} + +async fn exec_one<'h, T: It>(helper: &mut TestHelper<'h>) { + helper.err_list.clear(); + match T::exec(helper).await { + Ok(_) => { + helper.report.passed.push(T::NAME.to_string()); + } + Err(err) => { + helper.err_list.push(err.to_string()); + let err = format!("{:?}", helper.err_list); + helper.err_list.clear(); + helper.report.failed.push((T::NAME.to_string(), err)); + } + } +} diff --git a/rust/sbd-o-bahn-client-tester/src/it/it_1.rs b/rust/sbd-o-bahn-client-tester/src/it/it_1.rs new file mode 100644 index 0000000..0002b84 --- /dev/null +++ b/rust/sbd-o-bahn-client-tester/src/it/it_1.rs @@ -0,0 +1,42 @@ +use super::*; +use sbd_server::*; +use std::sync::Arc; + +/// test 1 +pub struct It1; + +impl It for It1 { + const NAME: &'static str = "sanity"; + + fn exec(helper: &mut TestHelper) -> impl Future> { + async { + let server = SbdServer::new(Arc::new(Config { + bind: vec!["127.0.0.1:0".to_string(), "[::1]:0".to_string()], + limit_clients: 100, + ..Default::default() + })) + .await?; + + let (c1, c2) = tokio::join!( + helper.connect(server.bind_addrs()), + helper.connect(server.bind_addrs()), + ); + + tokio::join!( + c1.send(c2.pub_key(), b"hello"), + c2.send(c1.pub_key(), b"world"), + ); + + let (r1, r2) = tokio::join!(c1.recv(), c2.recv()); + let r1 = r1.ok_or(Error::other("closed"))?; + let r2 = r2.ok_or(Error::other("closed"))?; + + expect!(helper, &r1.0 == c2.pub_key(), "recv from r2"); + expect!(helper, r1.1 == b"world", "recv from r2"); + expect!(helper, &r2.0 == c1.pub_key(), "recv from r2"); + expect!(helper, r1.1 == b"hello", "recv from r2"); + + Ok(()) + } + } +} diff --git a/rust/sbd-o-bahn-client-tester/src/lib.rs b/rust/sbd-o-bahn-client-tester/src/lib.rs new file mode 100644 index 0000000..c3b9582 --- /dev/null +++ b/rust/sbd-o-bahn-client-tester/src/lib.rs @@ -0,0 +1,197 @@ +#![deny(missing_docs)] +// uhhh... clippy... +#![allow(clippy::manual_async_fn)] + +//! Test suite for sbd client compliance. +//! +//! The command supplied to the run function must: +//! - Print on stdout: `CMD/READY` when it is ready to receive commands. +//! - Listen on stdin for: +//! - `CMD/CONNECT/id/` where id is a numeric identifier, +//! addr-list is a slash separated list of addresses (ip:port) +//! e.g. `CMD/CONNECT/42/127.0.0.1:44556/[::1]:44557`. +//! - `CMD/SEND/id//` where msg-hex is hex encoded bytes to +//! be sent to the hex encoded pubkey. +//! - `CMD/CLOSE/id/` close the connection and print a response close +//! - Write to stdout: +//! - `CMD/CONNECT/` where pubkey is a hex encoded +//! pubkey of the client that was connected. +//! - `CMD/RECV/id//` where msg-hex is hex encoded bytes +//! received from the remote hex encoded pubkey peer. +//! - `CMD/CLOSE/id/` if the connection closes (including when +//! the listen close command was called). + +use std::collections::HashMap; +use std::io::Result; +use tokio::io::AsyncBufReadExt; + +mod it; + +/// Results of the test suite run. +#[derive(Debug, Default)] +pub struct Report { + /// The names of tests that pass. + pub passed: Vec, + + /// Failed tests: (Name, Notes). + pub failed: Vec<(String, String)>, +} + +/// Run the test suite. +pub async fn run>(cmd: S) -> Report { + let mut client = Client::spawn(cmd).await.unwrap(); + + it::exec_all(&mut client).await +} + +enum ReadTask { + Client( + u64, + #[allow(clippy::type_complexity)] + tokio::sync::oneshot::Sender<( + [u8; 32], + tokio::sync::mpsc::UnboundedReceiver<([u8; 32], Vec)>, + )>, + ), + Line(String), +} + +struct Client { + _child: tokio::process::Child, + stdin: tokio::sync::Mutex, + sender: tokio::sync::mpsc::UnboundedSender, +} + +impl Client { + pub async fn spawn>(cmd: S) -> Result { + let mut cmd = tokio::process::Command::new(cmd); + cmd.kill_on_drop(true) + .stdin(std::process::Stdio::piped()) + .stdout(std::process::Stdio::piped()); + + println!("RUNNING {cmd:?}"); + let mut child = cmd.spawn()?; + let stdin = child.stdin.take().unwrap(); + + let mut stdout = + tokio::io::BufReader::new(child.stdout.take().unwrap()).lines(); + + if let Some(line) = stdout.next_line().await? { + if line != "CMD/READY" { + panic!("unexpected: {line}"); + } + } else { + panic!("no stdout"); + } + + let (s, mut r) = tokio::sync::mpsc::unbounded_channel(); + + { + let s = s.clone(); + tokio::task::spawn(async move { + while let Ok(Some(line)) = stdout.next_line().await { + if s.send(ReadTask::Line(line)).is_err() { + break; + } + } + }); + } + + tokio::task::spawn(async move { + let mut pre_connect_map = HashMap::new(); + let mut con_map = HashMap::new(); + while let Some(r) = r.recv().await { + match r { + ReadTask::Client(id, s) => { + pre_connect_map.insert(id, s); + } + ReadTask::Line(line) => { + let parts = line.split('/').collect::>(); + if parts[0] != "CMD" { + panic!(); + } + match parts[1] { + "CONNECT" => { + let id: u64 = parts[2].parse().unwrap(); + let pk = hex::decode(parts[3]).unwrap(); + if let Some(s) = pre_connect_map.remove(&id) { + let (ms, mr) = + tokio::sync::mpsc::unbounded_channel(); + con_map.insert(id, ms); + let _ = + s.send((pk.try_into().unwrap(), mr)); + } + } + "RECV" => { + let id: u64 = parts[2].parse().unwrap(); + let pk = hex::decode(parts[3]).unwrap(); + let msg = hex::decode(parts[4]).unwrap(); + if let Some(s) = con_map.get(&id) { + let _ = + s.send((pk.try_into().unwrap(), msg)); + } + } + "CLOSE" => { + let id: u64 = parts[2].parse().unwrap(); + con_map.remove(&id); + } + oth => panic!("unhandled: {oth}"), + } + } + } + } + }); + + println!("GOT CMD/READY"); + + Ok(Self { + _child: child, + stdin: tokio::sync::Mutex::new(stdin), + sender: s, + }) + } + + pub async fn connect( + &self, + addrs: &[std::net::SocketAddr], + ) -> ( + u64, + [u8; 32], + tokio::sync::mpsc::UnboundedReceiver<([u8; 32], Vec)>, + ) { + use tokio::io::AsyncWriteExt; + static ID: std::sync::atomic::AtomicU64 = + std::sync::atomic::AtomicU64::new(1); + let id = ID.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + let mut msg = format!("CMD/CONNECT/{id}"); + for addr in addrs { + msg.push_str(&format!("/{addr}")); + } + msg.push('\n'); + + let (s, r) = tokio::sync::oneshot::channel(); + self.sender.send(ReadTask::Client(id, s)).unwrap(); + + { + let mut lock = self.stdin.lock().await; + lock.write_all(&msg.into_bytes()).await.unwrap(); + lock.flush().await.unwrap(); + } + + let (pk, r) = r.await.unwrap(); + (id, pk, r) + } + + pub async fn send(&self, id: u64, pk: &[u8], msg: &[u8]) { + use tokio::io::AsyncWriteExt; + let msg = format!( + "CMD/SEND/{id}/{}/{}\n", + hex::encode(pk), + hex::encode(msg), + ) + .into_bytes(); + let mut lock = self.stdin.lock().await; + lock.write_all(&msg).await.unwrap(); + lock.flush().await.unwrap(); + } +} diff --git a/rust/sbd-server/examples/test-suite-runner.rs b/rust/sbd-server/examples/server-o-bahn-runner.rs similarity index 100% rename from rust/sbd-server/examples/test-suite-runner.rs rename to rust/sbd-server/examples/server-o-bahn-runner.rs diff --git a/rust/sbd-server/tests/suite.rs b/rust/sbd-server/tests/suite.rs index a65706c..a9704cd 100644 --- a/rust/sbd-server/tests/suite.rs +++ b/rust/sbd-server/tests/suite.rs @@ -2,7 +2,7 @@ fn suite() { println!("BUILDING example test-suite-runner IN RELEASE MODE"); let server = escargot::CargoBuild::new() - .example("test-suite-runner") + .example("server-o-bahn-runner") .release() .current_target() .run() From 5626cf174eb521f74184dff925638a01b2af4c9c Mon Sep 17 00:00:00 2001 From: neonphog Date: Sat, 20 Apr 2024 11:17:42 -0600 Subject: [PATCH 12/33] warmup on rate-limit test --- rust/sbd-client/tests/reasonable-rate-limit.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/rust/sbd-client/tests/reasonable-rate-limit.rs b/rust/sbd-client/tests/reasonable-rate-limit.rs index d9eef65..627cebd 100644 --- a/rust/sbd-client/tests/reasonable-rate-limit.rs +++ b/rust/sbd-client/tests/reasonable-rate-limit.rs @@ -37,6 +37,9 @@ async fn reasonable_rate_limit() { let (mut c1, _, p1, mut r1) = get_client(server.bind_addrs()).await; let (mut c2, _, p2, mut r2) = get_client(server.bind_addrs()).await; + //warmup + run(2, &mut c1, &p1, &mut r1, &mut c2, &p2, &mut r2).await; + let (rate1, rate2) = run(10, &mut c1, &p1, &mut r1, &mut c2, &p2, &mut r2).await; From b27ee2c455ced9bf58dbdaca1ab73f8abe3a3436 Mon Sep 17 00:00:00 2001 From: neonphog Date: Sat, 20 Apr 2024 11:50:56 -0600 Subject: [PATCH 13/33] client turnover benchmark --- rust/sbd-bench/Cargo.toml | 4 + rust/sbd-bench/benches/c_turnover.rs | 23 ++++++ rust/sbd-bench/src/c_turnover.rs | 76 +++++++++++++++++ rust/sbd-bench/src/lib.rs | 117 ++------------------------- rust/sbd-bench/src/thru.rs | 112 +++++++++++++++++++++++++ rust/sbd-server/src/cslot.rs | 53 +++++++----- 6 files changed, 254 insertions(+), 131 deletions(-) create mode 100644 rust/sbd-bench/benches/c_turnover.rs create mode 100644 rust/sbd-bench/src/c_turnover.rs create mode 100644 rust/sbd-bench/src/thru.rs diff --git a/rust/sbd-bench/Cargo.toml b/rust/sbd-bench/Cargo.toml index 6dba90f..2e4d166 100644 --- a/rust/sbd-bench/Cargo.toml +++ b/rust/sbd-bench/Cargo.toml @@ -15,3 +15,7 @@ criterion = { version = "0.5.1", features = [ "async_tokio" ] } [[bench]] name = "thru" harness = false + +[[bench]] +name = "c_turnover" +harness = false diff --git a/rust/sbd-bench/benches/c_turnover.rs b/rust/sbd-bench/benches/c_turnover.rs new file mode 100644 index 0000000..2b281f9 --- /dev/null +++ b/rust/sbd-bench/benches/c_turnover.rs @@ -0,0 +1,23 @@ +use criterion::{criterion_group, criterion_main, Criterion}; +use sbd_bench::CTurnoverBenchmark; +use std::sync::Arc; +use tokio::sync::Mutex; + +fn criterion_benchmark(c: &mut Criterion) { + let rt = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .unwrap(); + + let test = Arc::new(Mutex::new(rt.block_on(CTurnoverBenchmark::new()))); + let test = &test; + + c.bench_function("c_turnover", |b| { + b.to_async(&rt).iter(|| async move { + test.lock().await.iter().await; + }); + }); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/rust/sbd-bench/src/c_turnover.rs b/rust/sbd-bench/src/c_turnover.rs new file mode 100644 index 0000000..213005e --- /dev/null +++ b/rust/sbd-bench/src/c_turnover.rs @@ -0,0 +1,76 @@ +use super::*; +use std::collections::VecDeque; + +pub struct CTurnoverBenchmark { + server: SbdServer, + house: VecDeque<(WsRawSend, WsRawRecv)>, +} + +impl CTurnoverBenchmark { + pub async fn new() -> Self { + let config = Arc::new(Config { + bind: vec!["127.0.0.1:0".to_string(), "[::1]:0".to_string()], + limit_clients: 4, + ..Default::default() + }); + + let server = SbdServer::new(config).await.unwrap(); + + let mut this = Self { + server, + house: VecDeque::new(), + }; + + // make sure we have a full house even before the warmup + this.iter().await; + + this + } + + pub async fn iter(&mut self) { + // ensure full house + while self.try_connect().await.is_ok() {} + + // drop one + if let Some((mut s, r)) = self.house.pop_front() { + s.close().await; + drop(s); + drop(r); + } + + // this next one should succeed + self.try_connect().await.unwrap(); + } + + async fn try_connect(&mut self) -> std::io::Result<()> { + let c = DefaultCrypto::default(); + let (mut s, mut r) = + raw_connect(c.pub_key(), self.server.bind_addrs()).await?; + Handshake::handshake(&mut s, &mut r, &c).await?; + self.house.push_back((s, r)); + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test(flavor = "multi_thread")] + async fn c_turnover_bench_test() { + let mut b = CTurnoverBenchmark::new().await; + + // warmup + for _ in 0..10 { + b.iter().await; + } + + let start = tokio::time::Instant::now(); + for _ in 0..100 { + b.iter().await; + } + let elapsed = start.elapsed(); + + println!("{} nanos per iter", elapsed.as_nanos() / 100); + } +} diff --git a/rust/sbd-bench/src/lib.rs b/rust/sbd-bench/src/lib.rs index d41dd97..56152a6 100644 --- a/rust/sbd-bench/src/lib.rs +++ b/rust/sbd-bench/src/lib.rs @@ -3,94 +3,10 @@ use sbd_client::*; use sbd_server::*; use std::sync::Arc; -pub struct ThruBenchmark { - _server: SbdServer, - c1: DefaultCrypto, - s1: WsRawSend, - r1: WsRawRecv, - c2: DefaultCrypto, - s2: WsRawSend, - r2: WsRawRecv, - v1: Option>, - v2: Option>, -} - -impl ThruBenchmark { - pub async fn new() -> Self { - let config = Arc::new(Config { - bind: vec!["127.0.0.1:0".to_string(), "[::1]:0".to_string()], - limit_clients: 100, - disable_rate_limiting: true, - ..Default::default() - }); - - let server = SbdServer::new(config).await.unwrap(); - - let c1 = DefaultCrypto::default(); - let (mut s1, mut r1) = c(c1.pub_key(), server.bind_addrs()).await; - - let c2 = DefaultCrypto::default(); - let (mut s2, mut r2) = c(c2.pub_key(), server.bind_addrs()).await; - - Handshake::handshake(&mut s1, &mut r1, &c1).await.unwrap(); - Handshake::handshake(&mut s2, &mut r2, &c2).await.unwrap(); - - Self { - _server: server, - c1, - s1, - r1, - c2, - s2, - r2, - v1: None, - v2: None, - } - } - - pub async fn iter(&mut self) { - let Self { - c1, - s1, - r1, - c2, - s2, - r2, - v1, - v2, - .. - } = self; - - let mut b1 = v1.take().unwrap_or_else(|| vec![0xdb; 1000]); - let mut b2 = v2.take().unwrap_or_else(|| vec![0xca; 1000]); - - tokio::join!( - async { - b1[0..32].copy_from_slice(c2.pub_key()); - s1.send(b1).await.unwrap(); - }, - async { - b2[0..32].copy_from_slice(c1.pub_key()); - s2.send(b2).await.unwrap(); - }, - async { - let b2 = r1.recv().await.unwrap(); - assert_eq!(1000, b2.len()); - *v2 = Some(b2); - }, - async { - let b1 = r2.recv().await.unwrap(); - assert_eq!(1000, b1.len()); - *v1 = Some(b1); - }, - ); - } -} - -async fn c( +async fn raw_connect( pk: &[u8; 32], addrs: &[std::net::SocketAddr], -) -> (WsRawSend, WsRawRecv) { +) -> std::io::Result<(WsRawSend, WsRawRecv)> { use base64::Engine; for addr in addrs { @@ -106,31 +22,14 @@ async fn c( .connect() .await { - return r; + return Ok(r); } } - panic!() + Err(std::io::Error::other("could not connect")) } -#[cfg(test)] -mod tests { - use super::*; - - #[tokio::test(flavor = "multi_thread")] - async fn thru_bench_test() { - let mut b = ThruBenchmark::new().await; +mod thru; +pub use thru::*; - // warmup - for _ in 0..10 { - b.iter().await; - } - - let start = tokio::time::Instant::now(); - for _ in 0..100 { - b.iter().await; - } - let elapsed = start.elapsed(); - - println!("{} nanos per iter", elapsed.as_nanos() / 100); - } -} +mod c_turnover; +pub use c_turnover::*; diff --git a/rust/sbd-bench/src/thru.rs b/rust/sbd-bench/src/thru.rs new file mode 100644 index 0000000..1b9d997 --- /dev/null +++ b/rust/sbd-bench/src/thru.rs @@ -0,0 +1,112 @@ +use super::*; + +pub struct ThruBenchmark { + _server: SbdServer, + c1: DefaultCrypto, + s1: WsRawSend, + r1: WsRawRecv, + c2: DefaultCrypto, + s2: WsRawSend, + r2: WsRawRecv, + v1: Option>, + v2: Option>, +} + +impl ThruBenchmark { + pub async fn new() -> Self { + let config = Arc::new(Config { + bind: vec!["127.0.0.1:0".to_string(), "[::1]:0".to_string()], + limit_clients: 100, + disable_rate_limiting: true, + ..Default::default() + }); + + let server = SbdServer::new(config).await.unwrap(); + + let c1 = DefaultCrypto::default(); + let (mut s1, mut r1) = raw_connect(c1.pub_key(), server.bind_addrs()) + .await + .unwrap(); + + let c2 = DefaultCrypto::default(); + let (mut s2, mut r2) = raw_connect(c2.pub_key(), server.bind_addrs()) + .await + .unwrap(); + + Handshake::handshake(&mut s1, &mut r1, &c1).await.unwrap(); + Handshake::handshake(&mut s2, &mut r2, &c2).await.unwrap(); + + Self { + _server: server, + c1, + s1, + r1, + c2, + s2, + r2, + v1: None, + v2: None, + } + } + + pub async fn iter(&mut self) { + let Self { + c1, + s1, + r1, + c2, + s2, + r2, + v1, + v2, + .. + } = self; + + let mut b1 = v1.take().unwrap_or_else(|| vec![0xdb; 1000]); + let mut b2 = v2.take().unwrap_or_else(|| vec![0xca; 1000]); + + tokio::join!( + async { + b1[0..32].copy_from_slice(c2.pub_key()); + s1.send(b1).await.unwrap(); + }, + async { + b2[0..32].copy_from_slice(c1.pub_key()); + s2.send(b2).await.unwrap(); + }, + async { + let b2 = r1.recv().await.unwrap(); + assert_eq!(1000, b2.len()); + *v2 = Some(b2); + }, + async { + let b1 = r2.recv().await.unwrap(); + assert_eq!(1000, b1.len()); + *v1 = Some(b1); + }, + ); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test(flavor = "multi_thread")] + async fn thru_bench_test() { + let mut b = ThruBenchmark::new().await; + + // warmup + for _ in 0..10 { + b.iter().await; + } + + let start = tokio::time::Instant::now(); + for _ in 0..100 { + b.iter().await; + } + let elapsed = start.elapsed(); + + println!("{} nanos per iter", elapsed.as_nanos() / 100); + } +} diff --git a/rust/sbd-server/src/cslot.rs b/rust/sbd-server/src/cslot.rs index 214405a..9bacea6 100644 --- a/rust/sbd-server/src/cslot.rs +++ b/rust/sbd-server/src/cslot.rs @@ -119,11 +119,14 @@ impl CSlot { ip: Arc, pk: PubKey, ws: Arc>, - ) -> Option>)>> { + ) -> std::result::Result< + Vec<(u64, usize, Weak>)>, + Arc>, + > { let mut lock = self.0.lock().unwrap(); if lock.slab.len() >= lock.max_count { - return None; + return Err(ws); } let weak_ws = Arc::downgrade(&ws); @@ -177,7 +180,7 @@ impl CSlot { pk, }); - Some(rate_send_list) + Ok(rate_send_list) } pub async fn insert( @@ -189,29 +192,35 @@ impl CSlot { ) { let rate_send_list = self.insert_and_get_rate_send_list(ip, pk, ws); - if let Some(rate_send_list) = rate_send_list { - let rate = if config.disable_rate_limiting { - 1 - } else { - let mut rate = config.limit_ip_byte_nanos() as u64 - * rate_send_list.len() as u64; - if rate > i32::MAX as u64 { - rate = i32::MAX as u64; - } - rate as i32 - }; + match rate_send_list { + Ok(rate_send_list) => { + let rate = if config.disable_rate_limiting { + 1 + } else { + let mut rate = config.limit_ip_byte_nanos() as u64 + * rate_send_list.len() as u64; + if rate > i32::MAX as u64 { + rate = i32::MAX as u64; + } + rate as i32 + }; - for (uniq, index, weak_ws) in rate_send_list { - if let Some(ws) = weak_ws.upgrade() { - if ws - .send(cmd::SbdCmd::limit_byte_nanos(rate)) - .await - .is_err() - { - self.remove(uniq, index); + for (uniq, index, weak_ws) in rate_send_list { + if let Some(ws) = weak_ws.upgrade() { + if ws + .send(cmd::SbdCmd::limit_byte_nanos(rate)) + .await + .is_err() + { + self.remove(uniq, index); + } } } } + Err(ws) => { + ws.close().await; + drop(ws); + } } } From 6b5ad36d4a3c6ec54762348933dc7f9b46156053 Mon Sep 17 00:00:00 2001 From: neonphog Date: Sat, 20 Apr 2024 12:14:20 -0600 Subject: [PATCH 14/33] workspace deps --- Cargo.lock | 55 +------------------ Cargo.toml | 26 +++++++++ rust/sbd-bench/Cargo.toml | 6 +- rust/sbd-client/Cargo.toml | 26 ++++----- rust/sbd-o-bahn-client-tester/Cargo.toml | 4 +- rust/sbd-o-bahn-server-tester/Cargo.toml | 3 +- rust/sbd-o-bahn-server-tester/src/lib.rs | 11 ++-- rust/sbd-server/Cargo.toml | 37 ++++++------- .../examples/server-o-bahn-runner.rs | 16 +++--- 9 files changed, 77 insertions(+), 107 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 8cf7fac..d9e5e11 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -113,12 +113,6 @@ version = "0.22.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9475866fec1451be56a3c2400fd081ff546538961565ccb5b7142cbd22bc7a51" -[[package]] -name = "base64ct" -version = "1.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b" - [[package]] name = "bitflags" version = "1.3.2" @@ -250,12 +244,6 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7" -[[package]] -name = "const-oid" -version = "0.9.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c2459377285ad874054d797f3ccebf984978aa39129f6eafde5cdc8315b612f8" - [[package]] name = "core-foundation" version = "0.9.4" @@ -374,7 +362,6 @@ dependencies = [ "platforms", "rustc_version", "subtle", - "zeroize", ] [[package]] @@ -394,16 +381,6 @@ version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7e962a19be5cfc3f3bf6dd8f61eb50107f356ad6270fbb3ed41476571db78be5" -[[package]] -name = "der" -version = "0.7.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f55bf8e7b65898637379c1b74eb1551107c8294ed26d855ceb9fd1a09cfc9bc0" -dependencies = [ - "const-oid", - "zeroize", -] - [[package]] name = "deranged" version = "0.3.11" @@ -429,7 +406,6 @@ version = "2.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "115531babc129696a58c64a4fef0a8bf9e9698629fb97e9e40767d235cfbcd53" dependencies = [ - "pkcs8", "signature", ] @@ -442,10 +418,8 @@ dependencies = [ "curve25519-dalek", "ed25519", "rand_core", - "serde", "sha2", "subtle", - "zeroize", ] [[package]] @@ -971,16 +945,6 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" -[[package]] -name = "pkcs8" -version = "0.10.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f950b2377845cebe5cf8b5165cb3cc1a5e0fa5cfa3e1f7f55707d8fd82e0a7b7" -dependencies = [ - "der", - "spki", -] - [[package]] name = "platforms" version = "3.4.0" @@ -1191,9 +1155,9 @@ dependencies = [ [[package]] name = "rustls" -version = "0.22.3" +version = "0.22.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "99008d7ad0bbbea527ec27bddbc0e432c5b87d8175178cee68d2eec9c4a1813c" +checksum = "bf4ef73721ac7bcd79b2b315da7779d8fc09718c6b3d2d1b2d94850eb8c18432" dependencies = [ "log", "ring", @@ -1302,7 +1266,6 @@ name = "sbd-o-bahn-server-tester" version = "0.0.1-alpha" dependencies = [ "sbd-client", - "serde_json", "tokio", ] @@ -1324,7 +1287,6 @@ dependencies = [ "rand", "rcgen", "sbd-client", - "serde_json", "slab", "tempfile", "tokio", @@ -1442,9 +1404,6 @@ name = "signature" version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "77549399552de45a898a580c1b41d445bf730df867cc44e6c0233bbc4b8329de" -dependencies = [ - "rand_core", -] [[package]] name = "simdutf8" @@ -1483,16 +1442,6 @@ version = "0.9.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" -[[package]] -name = "spki" -version = "0.7.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d91ed6c858b01f942cd56b37a94b3e0a1798290327d1236e4d9cf4eaca44d29d" -dependencies = [ - "base64ct", - "der", -] - [[package]] name = "strsim" version = "0.11.1" diff --git a/Cargo.toml b/Cargo.toml index b900dbb..1945c40 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,7 +9,33 @@ members = [ resolver = "2" [workspace.dependencies] +# workspace member deps sbd-bench = { version = "0.0.1-alpha", path = "rust/sbd-bench" } sbd-client = { version = "0.0.1-alpha", path = "rust/sbd-client" } +sbd-o-bahn-client-tester = { version = "0.0.1-alpha", path = "rust/sbd-o-bahn-client-tester" } sbd-o-bahn-server-tester = { version = "0.0.1-alpha", path = "rust/sbd-o-bahn-server-tester" } sbd-server = { version = "0.0.1-alpha", path = "rust/sbd-server" } +# crate deps +anstyle = "1.0.6" +base64 = "0.22.0" +bytes = "1.6.0" +clap = "4.5.4" +criterion = "0.5.1" +ed25519-dalek = { version = "2.1.1", default-features = false } +escargot = "0.5.10" +fastwebsockets = "0.7.1" +futures = "0.3.30" +hex = "0.4.3" +http-body-util = "0.1.0" +hyper = "1.2.0" +hyper-util = "0.1.3" +rand = "0.8.5" +rcgen = "0.13.1" +rustls = "0.22.4" +rustls-native-certs = "0.7.0" +slab = "0.4.9" +tempfile = "3.10.1" +tokio = { version = "1.37.0", default-features = false } +tokio-rustls = "0.25.0" +tokio-tungstenite = { version = "0.21.0", default-features = false } +webpki-roots = "0.26.1" diff --git a/rust/sbd-bench/Cargo.toml b/rust/sbd-bench/Cargo.toml index 2e4d166..d1f4f50 100644 --- a/rust/sbd-bench/Cargo.toml +++ b/rust/sbd-bench/Cargo.toml @@ -4,13 +4,13 @@ version = "0.0.1-alpha" edition = "2021" [dependencies] -base64 = "0.22.0" +base64 = { workspace = true } sbd-server = { workspace = true } sbd-client = { workspace = true, features = [ "raw_client" ] } -tokio = { version = "1.37.0", features = [ "full" ] } +tokio = { workspace = true, features = [ "full" ] } [dev-dependencies] -criterion = { version = "0.5.1", features = [ "async_tokio" ] } +criterion = { workspace = true, features = [ "async_tokio" ] } [[bench]] name = "thru" diff --git a/rust/sbd-client/Cargo.toml b/rust/sbd-client/Cargo.toml index c0b3f93..38fc168 100644 --- a/rust/sbd-client/Cargo.toml +++ b/rust/sbd-client/Cargo.toml @@ -4,26 +4,26 @@ version = "0.0.1-alpha" edition = "2021" [dependencies] -base64 = "0.22.0" -futures = "0.3.30" -rustls = "0.22.3" -rustls-native-certs = "0.7.0" -tokio = { version = "1.37.0", default-features = false, features = [ "io-util", "net", "sync", "time", "rt" ] } -tokio-rustls = "0.25.0" -tokio-tungstenite = { version = "0.21.0", default-features = false, features = [ "connect", "__rustls-tls" ] } +base64 = { workspace = true } +futures = { workspace = true } +rustls = { workspace = true } +rustls-native-certs = { workspace = true } +tokio = { workspace = true, default-features = false, features = [ "io-util", "net", "sync", "time", "rt" ] } +tokio-rustls = { workspace = true } +tokio-tungstenite = { workspace = true, default-features = false, features = [ "connect", "__rustls-tls" ] } # optional -ed25519-dalek = { version = "2.1.1", features = [ "rand_core" ], optional = true } -rand = { version = "0.8.5", optional = true } +ed25519-dalek = { workspace = true, features = [ "rand_core" ], optional = true } +rand = { workspace = true, optional = true } [target.'cfg(not(any(target_os = "windows", target_os = "linux", target_os = "macos")))'.dependencies] -webpki-roots = "0.26.1" +webpki-roots = { workspace = true } [dev-dependencies] -escargot = { version = "0.5.10", features = [ "print" ] } -hex = "0.4.3" -tokio = { version = "1.37.0", features = [ "full" ] } +escargot = { workspace = true, features = [ "print" ] } +hex = { workspace = true } +tokio = { workspace = true, features = [ "full" ] } sbd-server = { workspace = true } [features] diff --git a/rust/sbd-o-bahn-client-tester/Cargo.toml b/rust/sbd-o-bahn-client-tester/Cargo.toml index 3a44ac3..68584dd 100644 --- a/rust/sbd-o-bahn-client-tester/Cargo.toml +++ b/rust/sbd-o-bahn-client-tester/Cargo.toml @@ -4,6 +4,6 @@ version = "0.0.1-alpha" edition = "2021" [dependencies] -hex = "0.4.3" +hex = { workspace = true } sbd-server = { workspace = true } -tokio = { version = "1.37.0", features = [ "full" ] } +tokio = { workspace = true, features = [ "full" ] } diff --git a/rust/sbd-o-bahn-server-tester/Cargo.toml b/rust/sbd-o-bahn-server-tester/Cargo.toml index 5ac40a0..b241ba1 100644 --- a/rust/sbd-o-bahn-server-tester/Cargo.toml +++ b/rust/sbd-o-bahn-server-tester/Cargo.toml @@ -5,5 +5,4 @@ edition = "2021" [dependencies] sbd-client = { workspace = true } -serde_json = "1.0.116" -tokio = { version = "1.37.0", features = [ "full" ] } +tokio = { workspace = true, features = [ "full" ] } diff --git a/rust/sbd-o-bahn-server-tester/src/lib.rs b/rust/sbd-o-bahn-server-tester/src/lib.rs index 972871e..c769ae5 100644 --- a/rust/sbd-o-bahn-server-tester/src/lib.rs +++ b/rust/sbd-o-bahn-server-tester/src/lib.rs @@ -58,14 +58,14 @@ impl Server { tokio::io::BufReader::new(child.stdout.take().unwrap()).lines(); if let Some(line) = stdout.next_line().await? { - if line != "CMD:READY" { + if line != "CMD/READY" { panic!("unexpected: {line}"); } } else { panic!("no stdout"); } - println!("GOT CMD:READY"); + println!("GOT CMD/READY"); Ok(Self { _child: child, @@ -76,13 +76,12 @@ impl Server { pub async fn start(&mut self) -> Vec { use tokio::io::AsyncWriteExt; - self.stdin.write_all(b"CMD:START\n").await.unwrap(); + self.stdin.write_all(b"CMD/START\n").await.unwrap(); self.stdin.flush().await.unwrap(); let line = self.stdout.next_line().await.unwrap().unwrap(); - if !line.starts_with("CMD:START:") { + if !line.starts_with("CMD/START/") { panic!("unexpected: {line}"); } - let line = line.into_bytes(); - serde_json::from_slice(&line[10..]).unwrap() + line.split('/').skip(2).map(|s| s.to_string()).collect() } } diff --git a/rust/sbd-server/Cargo.toml b/rust/sbd-server/Cargo.toml index 9bd6c69..fa5cb19 100644 --- a/rust/sbd-server/Cargo.toml +++ b/rust/sbd-server/Cargo.toml @@ -4,32 +4,31 @@ version = "0.0.1-alpha" edition = "2021" [dependencies] -anstyle = "1.0.6" -base64 = "0.22.0" -bytes = "1.6.0" -clap = { version = "4.5.4", features = [ "color", "derive", "wrap_help" ] } -ed25519-dalek = { version = "2.1.1", default-features = false } -rand = "0.8.5" -slab = "0.4.9" -tokio = { version = "1.37.0", features = [ "full" ] } +anstyle = { workspace = true } +base64 = { workspace = true } +bytes = { workspace = true } +clap = { workspace = true, features = [ "color", "derive", "wrap_help" ] } +ed25519-dalek = { workspace = true, default-features = false } +rand = { workspace = true } +slab = { workspace = true } +tokio = { workspace = true, features = [ "full" ] } # feature tungstenite -futures = { version = "0.3.30", optional = true } -tokio-tungstenite = { version = "0.21.0", default-features = false, features = [ "handshake" ], optional = true } +futures = { workspace = true, optional = true } +tokio-tungstenite = { workspace = true, default-features = false, features = [ "handshake" ], optional = true } # feature fastwebsockets -fastwebsockets = { version = "0.7.1", features = [ "upgrade" ], optional = true } -http-body-util = { version = "0.1.0", optional = true } -hyper-util = { version = "0.1.3", features = [ "tokio" ], optional = true } -hyper = { version = "1.2.0", features = ["http1", "server"], optional = true } +fastwebsockets = { workspace = true, features = [ "upgrade" ], optional = true } +http-body-util = { workspace = true, optional = true } +hyper-util = { workspace = true, features = [ "tokio" ], optional = true } +hyper = { workspace = true, features = ["http1", "server"], optional = true } [dev-dependencies] -escargot = { version = "0.5.10", features = [ "print" ] } -rcgen = "0.13.1" +escargot = { workspace = true, features = [ "print" ] } +rcgen = { workspace = true } sbd-client = { workspace = true, features = [ "raw_client" ] } -serde_json = "1.0.116" -tempfile = "3.10.1" -tokio = { version = "1.37.0", features = [ "test-util" ] } +tempfile = { workspace = true } +tokio = { workspace = true, features = [ "test-util" ] } [features] default = [ "tungstenite" ] diff --git a/rust/sbd-server/examples/server-o-bahn-runner.rs b/rust/sbd-server/examples/server-o-bahn-runner.rs index c9c36d9..c3d7d2b 100644 --- a/rust/sbd-server/examples/server-o-bahn-runner.rs +++ b/rust/sbd-server/examples/server-o-bahn-runner.rs @@ -2,7 +2,7 @@ use std::sync::Arc; #[tokio::main(flavor = "multi_thread")] async fn main() { - println!("CMD:READY"); + println!("CMD/READY"); let mut lines = tokio::io::AsyncBufReadExt::lines( tokio::io::BufReader::new(tokio::io::stdin()), @@ -12,7 +12,7 @@ async fn main() { while let Ok(Some(line)) = lines.next_line().await { match line.as_str() { - "CMD:START" => { + "CMD/START" => { drop(server); let mut config = sbd_server::Config::default(); config.bind.push("127.0.0.1:0".to_string()); @@ -20,13 +20,11 @@ async fn main() { server = Some( sbd_server::SbdServer::new(Arc::new(config)).await.unwrap(), ); - println!( - "CMD:START:{}", - serde_json::to_string( - server.as_ref().unwrap().bind_addrs() - ) - .unwrap() - ); + let mut out = "CMD/START".to_string(); + for addr in server.as_ref().unwrap().bind_addrs() { + out.push_str(&format!("/{addr}")); + } + println!("{out}"); } oth => panic!("error, unexpected: {oth}"), } From 80a82eabffcf7417edbd5d7d9e5f12dae0531b75 Mon Sep 17 00:00:00 2001 From: neonphog Date: Sun, 21 Apr 2024 12:07:18 -0600 Subject: [PATCH 15/33] tls --- Cargo.lock | 3 + Cargo.toml | 1 + rust/sbd-client/src/lib.rs | 2 +- rust/sbd-client/src/raw_client.rs | 75 +++++++++++++++++++++--- rust/sbd-server/Cargo.toml | 3 + rust/sbd-server/src/lib.rs | 94 +++++++++++++++++++++++++++++-- rust/sbd-server/src/maybe_tls.rs | 48 ++++++++++++++++ spec.md | 6 +- 8 files changed, 216 insertions(+), 16 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index d9e5e11..bcc1e1c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1286,10 +1286,13 @@ dependencies = [ "hyper-util", "rand", "rcgen", + "rustls", + "rustls-pemfile", "sbd-client", "slab", "tempfile", "tokio", + "tokio-rustls", "tokio-tungstenite", ] diff --git a/Cargo.toml b/Cargo.toml index 1945c40..05120cb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,6 +33,7 @@ rand = "0.8.5" rcgen = "0.13.1" rustls = "0.22.4" rustls-native-certs = "0.7.0" +rustls-pemfile = "2.1.2" slab = "0.4.9" tempfile = "3.10.1" tokio = { version = "1.37.0", default-features = false } diff --git a/rust/sbd-client/src/lib.rs b/rust/sbd-client/src/lib.rs index 07c10d9..b379e11 100644 --- a/rust/sbd-client/src/lib.rs +++ b/rust/sbd-client/src/lib.rs @@ -5,7 +5,7 @@ use std::io::{Error, Result}; use std::sync::Arc; /// defined by the sbd spec -const MAX_MSG_SIZE: usize = 16000; +const MAX_MSG_SIZE: usize = 20_000; #[cfg(feature = "raw_client")] pub mod raw_client; diff --git a/rust/sbd-client/src/raw_client.rs b/rust/sbd-client/src/raw_client.rs index 77e611c..ebe597e 100644 --- a/rust/sbd-client/src/raw_client.rs +++ b/rust/sbd-client/src/raw_client.rs @@ -28,7 +28,7 @@ impl WsRawConnect { full_url, max_message_size, allow_plain_text, - .. + danger_disable_certificate_check, } = self; let scheme_ws = full_url.starts_with("ws://"); @@ -62,7 +62,7 @@ impl WsRawConnect { let maybe_tls = if scheme_ws { tokio_tungstenite::MaybeTlsStream::Plain(tcp) } else { - let tls = priv_system_tls(); + let tls = priv_system_tls(danger_disable_certificate_check); let name = host .try_into() .unwrap_or_else(|_| "sbd".try_into().unwrap()); @@ -205,7 +205,9 @@ impl Handshake { } } -fn priv_system_tls() -> Arc { +fn priv_system_tls( + danger_disable_certificate_check: bool, +) -> Arc { let mut roots = rustls::RootCertStore::empty(); #[cfg(not(any( @@ -236,9 +238,66 @@ fn priv_system_tls() -> Arc { roots.add(cert).expect("faild to add cert to root"); } - Arc::new( - rustls::ClientConfig::builder() - .with_root_certificates(roots) - .with_no_client_auth(), - ) + if danger_disable_certificate_check { + let v = rustls::client::WebPkiServerVerifier::builder(Arc::new(roots)) + .build() + .unwrap(); + + Arc::new( + rustls::ClientConfig::builder() + .dangerous() + .with_custom_certificate_verifier(Arc::new(V(v))) + .with_no_client_auth(), + ) + } else { + Arc::new( + rustls::ClientConfig::builder() + .with_root_certificates(roots) + .with_no_client_auth(), + ) + } +} + +#[derive(Debug)] +struct V(Arc); + +impl rustls::client::danger::ServerCertVerifier for V { + fn verify_server_cert( + &self, + _end_entity: &rustls::pki_types::CertificateDer<'_>, + _intermediates: &[rustls::pki_types::CertificateDer<'_>], + _server_name: &rustls::pki_types::ServerName<'_>, + _ocsp_response: &[u8], + _now: rustls::pki_types::UnixTime, + ) -> std::result::Result< + rustls::client::danger::ServerCertVerified, + rustls::Error, + > { + Ok(rustls::client::danger::ServerCertVerified::assertion()) + } + fn verify_tls12_signature( + &self, + _message: &[u8], + _cert: &rustls::pki_types::CertificateDer<'_>, + _dss: &rustls::DigitallySignedStruct, + ) -> std::result::Result< + rustls::client::danger::HandshakeSignatureValid, + rustls::Error, + > { + Ok(rustls::client::danger::HandshakeSignatureValid::assertion()) + } + fn verify_tls13_signature( + &self, + _message: &[u8], + _cert: &rustls::pki_types::CertificateDer<'_>, + _dss: &rustls::DigitallySignedStruct, + ) -> std::result::Result< + rustls::client::danger::HandshakeSignatureValid, + rustls::Error, + > { + Ok(rustls::client::danger::HandshakeSignatureValid::assertion()) + } + fn supported_verify_schemes(&self) -> Vec { + self.0.supported_verify_schemes() + } } diff --git a/rust/sbd-server/Cargo.toml b/rust/sbd-server/Cargo.toml index fa5cb19..6dce02a 100644 --- a/rust/sbd-server/Cargo.toml +++ b/rust/sbd-server/Cargo.toml @@ -10,8 +10,11 @@ bytes = { workspace = true } clap = { workspace = true, features = [ "color", "derive", "wrap_help" ] } ed25519-dalek = { workspace = true, default-features = false } rand = { workspace = true } +rustls = { workspace = true } +rustls-pemfile = { workspace = true } slab = { workspace = true } tokio = { workspace = true, features = [ "full" ] } +tokio-rustls = { workspace = true } # feature tungstenite futures = { workspace = true, optional = true } diff --git a/rust/sbd-server/src/lib.rs b/rust/sbd-server/src/lib.rs index e652aed..26ef9d6 100644 --- a/rust/sbd-server/src/lib.rs +++ b/rust/sbd-server/src/lib.rs @@ -2,7 +2,7 @@ #![deny(missing_docs)] /// defined by the sbd spec -const MAX_MSG_BYTES: i32 = 16000; +const MAX_MSG_BYTES: i32 = 20_000; use std::io::{Error, Result}; use std::sync::Arc; @@ -149,11 +149,19 @@ async fn check_accept_connection( } } - // TODO TLS upgrade - let tcp = MaybeTlsStream::Tcp(tcp); + let socket = if let (Some(cert), Some(pk)) = + (&config.cert_pem_file, &config.priv_key_pem_file) + { + match MaybeTlsStream::tls(cert, pk, tcp).await { + Err(_) => return, + Ok(tls) => tls, + } + } else { + MaybeTlsStream::Tcp(tcp) + }; let (ws, pub_key, ip) = - match ws::WebSocket::upgrade(config.clone(), tcp).await { + match ws::WebSocket::upgrade(config.clone(), socket).await { Ok(r) => r, Err(_) => return, }; @@ -262,3 +270,81 @@ impl SbdServer { self.bind_addrs.as_slice() } } + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test(flavor = "multi_thread")] + async fn tls_sanity() { + let tmp = tempfile::tempdir().unwrap(); + let tmp_dir = tmp.path().to_owned(); + let rcgen::CertifiedKey { cert, key_pair } = + rcgen::generate_simple_self_signed(vec!["localhost".to_string()]) + .unwrap(); + let mut cert_path = tmp_dir.clone(); + cert_path.push("cert.pem"); + tokio::fs::write(&cert_path, cert.pem()).await.unwrap(); + let mut key_path = tmp_dir.clone(); + key_path.push("key.pem"); + tokio::fs::write(&key_path, key_pair.serialize_pem()) + .await + .unwrap(); + + let mut config = Config::default(); + config.cert_pem_file = Some(cert_path); + config.priv_key_pem_file = Some(key_path); + config.bind.push("127.0.0.1:0".into()); + println!("{config:?}"); + + let server = SbdServer::new(Arc::new(config)).await.unwrap(); + + let addr = server.bind_addrs()[0].clone(); + + println!("addr: {addr:?}"); + + let (client1, url1, pk1, mut rcv1) = + sbd_client::SbdClient::connect_config( + &format!("wss://{addr}"), + &sbd_client::DefaultCrypto::default(), + sbd_client::SbdClientConfig { + allow_plain_text: true, + danger_disable_certificate_check: true, + ..Default::default() + }, + ) + .await + .unwrap(); + + println!("client url1: {url1}"); + + let (client2, url2, pk2, mut rcv2) = + sbd_client::SbdClient::connect_config( + &format!("wss://{addr}"), + &sbd_client::DefaultCrypto::default(), + sbd_client::SbdClientConfig { + allow_plain_text: true, + danger_disable_certificate_check: true, + ..Default::default() + }, + ) + .await + .unwrap(); + + println!("client url2: {url2}"); + + client1.send(&pk2, b"hello").await.unwrap(); + + let res_data = rcv2.recv().await.unwrap(); + + assert_eq!(&pk1.0, res_data.pub_key_ref()); + assert_eq!(&b"hello"[..], res_data.message()); + + client2.send(&pk1, b"world").await.unwrap(); + + let res_data = rcv1.recv().await.unwrap(); + + assert_eq!(&pk2.0, res_data.pub_key_ref()); + assert_eq!(&b"world"[..], res_data.message()); + } +} diff --git a/rust/sbd-server/src/maybe_tls.rs b/rust/sbd-server/src/maybe_tls.rs index bc3ae9d..bb0b33a 100644 --- a/rust/sbd-server/src/maybe_tls.rs +++ b/rust/sbd-server/src/maybe_tls.rs @@ -13,6 +13,50 @@ use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; pub enum MaybeTlsStream { /// Tcp. Tcp(tokio::net::TcpStream), + + /// Tls. + Tls(tokio_rustls::server::TlsStream), +} + +impl MaybeTlsStream { + pub async fn tls( + cert: &std::path::Path, + pk: &std::path::Path, + tcp: tokio::net::TcpStream, + ) -> std::io::Result { + use rustls_pemfile::Item::*; + + let cert = tokio::fs::read(cert).await?; + let pk = tokio::fs::read(pk).await?; + + let cert = match rustls_pemfile::read_one_from_slice(&cert) { + Ok(Some((X509Certificate(cert), _))) => cert, + _ => return Err(std::io::Error::other("error reading tls cert")), + }; + let pk = match rustls_pemfile::read_one_from_slice(&pk) { + Ok(Some((Pkcs1Key(pk), _))) => { + rustls::pki_types::PrivateKeyDer::Pkcs1(pk) + } + Ok(Some((Sec1Key(pk), _))) => { + rustls::pki_types::PrivateKeyDer::Sec1(pk) + } + Ok(Some((Pkcs8Key(pk), _))) => { + rustls::pki_types::PrivateKeyDer::Pkcs8(pk) + } + _ => return Err(std::io::Error::other("error reading priv key")), + }; + + let c = std::sync::Arc::new( + rustls::server::ServerConfig::builder() + .with_no_client_auth() + .with_single_cert(vec![cert], pk) + .map_err(std::io::Error::other)?, + ); + + let tls = tokio_rustls::TlsAcceptor::from(c).accept(tcp).await?; + + Ok(Self::Tls(tls)) + } } impl AsyncRead for MaybeTlsStream { @@ -23,6 +67,7 @@ impl AsyncRead for MaybeTlsStream { ) -> Poll> { match self.get_mut() { MaybeTlsStream::Tcp(ref mut s) => Pin::new(s).poll_read(cx, buf), + MaybeTlsStream::Tls(ref mut s) => Pin::new(s).poll_read(cx, buf), } } } @@ -35,6 +80,7 @@ impl AsyncWrite for MaybeTlsStream { ) -> Poll> { match self.get_mut() { MaybeTlsStream::Tcp(ref mut s) => Pin::new(s).poll_write(cx, buf), + MaybeTlsStream::Tls(ref mut s) => Pin::new(s).poll_write(cx, buf), } } @@ -44,6 +90,7 @@ impl AsyncWrite for MaybeTlsStream { ) -> Poll> { match self.get_mut() { MaybeTlsStream::Tcp(ref mut s) => Pin::new(s).poll_flush(cx), + MaybeTlsStream::Tls(ref mut s) => Pin::new(s).poll_flush(cx), } } @@ -53,6 +100,7 @@ impl AsyncWrite for MaybeTlsStream { ) -> Poll> { match self.get_mut() { MaybeTlsStream::Tcp(ref mut s) => Pin::new(s).poll_shutdown(cx), + MaybeTlsStream::Tls(ref mut s) => Pin::new(s).poll_shutdown(cx), } } } diff --git a/spec.md b/spec.md index baf9265..fd5f6a8 100644 --- a/spec.md +++ b/spec.md @@ -11,11 +11,11 @@ The SBD protocol is built upon websockets. #### 1.1.1. Message and Frame Size -The maximum SBD message size (including 32 byte header) is 16000 bytes. +The maximum SBD message size (including 32 byte header) is 20000 bytes. -SBD clients and servers MAY set the max message size in the websocket library to 16000 to help enforce this. +SBD clients and servers MAY set the max message size in the websocket library to 20000 to help enforce this. -The maximum frame size MUST be set larger than 16000 so that sbd messages always fit in a single websocket frame. +The maximum frame size MUST be set larger than 20000 so that sbd messages always fit in a single websocket frame. ## 2. Cryptography From 9351563f4ce1702f759ed4af78ace05bdd1249a6 Mon Sep 17 00:00:00 2001 From: neonphog Date: Mon, 22 Apr 2024 10:06:27 -0600 Subject: [PATCH 16/33] tweaks --- Cargo.toml | 3 ++ rust/sbd-client/src/send_buf.rs | 52 ++++++---------------- rust/sbd-server/Cargo.toml | 2 + rust/sbd-server/src/config.rs | 16 +++++++ rust/sbd-server/src/lib.rs | 17 +++++-- rust/sbd-server/src/maybe_tls.rs | 76 ++++++++++++++++++++++++-------- 6 files changed, 104 insertions(+), 62 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 05120cb..4bd1639 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,6 +8,9 @@ members = [ ] resolver = "2" +[profile.release] +panic = "abort" + [workspace.dependencies] # workspace member deps sbd-bench = { version = "0.0.1-alpha", path = "rust/sbd-bench" } diff --git a/rust/sbd-client/src/send_buf.rs b/rust/sbd-client/src/send_buf.rs index 3729a4e..c4b547f 100644 --- a/rust/sbd-client/src/send_buf.rs +++ b/rust/sbd-client/src/send_buf.rs @@ -4,7 +4,7 @@ use std::collections::VecDeque; pub struct SendBuf { pub ws: raw_client::WsRawSend, - pub buf: VecDeque<(PubKey, Vec)>, + pub buf: VecDeque>, pub out_buffer_size: usize, pub origin: tokio::time::Instant, pub limit_rate: u64, @@ -103,8 +103,8 @@ impl SendBuf { return Ok(false); } - if let Some((_, data)) = self.buf.pop_front() { - self.raw_send(now, data).await?; + if let Some(buf) = self.buf.pop_front() { + self.raw_send(now, buf).await?; Ok(true) } else { @@ -116,7 +116,11 @@ impl SendBuf { /// Then queue up data to be sent out. /// Note, you'll need to explicitly call `write_next_queued()` or /// make additional sends in order to get this queued data actually sent. - pub async fn send(&mut self, pk: &PubKey, mut data: &[u8]) -> Result<()> { + pub async fn send(&mut self, pk: &PubKey, data: &[u8]) -> Result<()> { + if data.len() > MAX_MSG_SIZE - 32 { + return Err(Error::other("message too large")); + } + while !self.space_free() { if let Some(dur) = self.next_step_dur() { tokio::time::sleep(dur).await; @@ -124,30 +128,10 @@ impl SendBuf { self.write_next_queued().await?; } - self.check_set_prebuffer(); - - // first try to put into existing blocks - for (qpk, qdata) in self.buf.iter_mut() { - if qpk == pk && qdata.len() < MAX_MSG_SIZE { - let amt = std::cmp::min(data.len(), MAX_MSG_SIZE - qdata.len()); - qdata.extend_from_slice(&data[..amt]); - data = &data[amt..]; - if data.is_empty() { - return Ok(()); - } - } - } - - // next, fill out new entries - while !data.is_empty() { - let mut init = Vec::with_capacity(MAX_MSG_SIZE); - init.extend_from_slice(&pk.0[..]); - - let amt = std::cmp::min(data.len(), MAX_MSG_SIZE - init.len()); - init.extend_from_slice(&data[..amt]); - data = &data[amt..]; - self.buf.push_back((*pk, init)); - } + let mut buf = Vec::with_capacity(32 + data.len()); + buf.extend_from_slice(&pk.0[..]); + buf.extend_from_slice(data); + self.buf.push_back(buf); Ok(()) } @@ -165,17 +149,7 @@ impl SendBuf { } fn len(&self) -> usize { - self.buf.iter().map(|(_, d)| d.len()).sum() - } - - /// If we have an empty out buffer, set some rate-limit as a hack - /// for waiting a little bit to see if more sends come in and can - /// be aggregated - fn check_set_prebuffer(&mut self) { - if self.buf.is_empty() { - let hack = self.origin.elapsed().as_nanos() as u64 + 10_000_000; // 10 millis in nanos - self.next_send_at = std::cmp::max(hack, self.next_send_at) - } + self.buf.iter().map(|b| b.len()).sum() } /// Returns `true` if our total buffer size < out_buffer_size diff --git a/rust/sbd-server/Cargo.toml b/rust/sbd-server/Cargo.toml index 6dce02a..9304c68 100644 --- a/rust/sbd-server/Cargo.toml +++ b/rust/sbd-server/Cargo.toml @@ -44,3 +44,5 @@ fastwebsockets = [ "dep:hyper-util", "dep:hyper", ] + +unstable = [] diff --git a/rust/sbd-server/src/config.rs b/rust/sbd-server/src/config.rs index f8c0621..4633552 100644 --- a/rust/sbd-server/src/config.rs +++ b/rust/sbd-server/src/config.rs @@ -1,4 +1,6 @@ +#[cfg(feature = "unstable")] const DEF_IP_DENY_DIR: &str = "."; +#[cfg(feature = "unstable")] const DEF_IP_DENY_S: i32 = 600; const DEF_LIMIT_CLIENTS: i32 = 32768; const DEF_LIMIT_IP_KBPS: i32 = 1000; @@ -31,6 +33,7 @@ pub struct Config { #[arg(long, verbatim_doc_comment)] pub bind: Vec, + #[cfg(feature = "unstable")] /// Watch this directory, and reload TLS certificates 10s after any /// files change within it. Must be an exact match to the parent directory /// of both `--cert-pem-file` and `--priv-key-pem-file`. @@ -42,23 +45,27 @@ pub struct Config { #[arg(long)] pub trusted_ip_header: Option, + #[cfg(feature = "unstable")] /// The directory in which to store the blocked ip addresses. /// Note v4 addresses will be mapped to v6 addresses per /// . #[arg(long, default_value = DEF_IP_DENY_DIR)] pub ip_deny_dir: std::path::PathBuf, + #[cfg(feature = "unstable")] /// How long to block ip addresses in seconds. Set to zero to block /// forever (or until the file is manually deleted). #[arg(long, default_value_t = DEF_IP_DENY_S)] pub ip_deny_s: i32, + #[cfg(feature = "unstable")] /// Bind to this backchannel interface and port. /// Can be specified more than once. /// Note, this should be a local only or virtual private interface. #[arg(long)] pub back_bind: Vec, + #[cfg(feature = "unstable")] /// Allow incoming backchannel connections only /// from the following explicit addresses. Note, this is expecting direct /// connections, not through a proxy, so only the raw TCP address will @@ -68,6 +75,7 @@ pub struct Config { #[arg(long)] pub back_allow_ip: Vec, + #[cfg(feature = "unstable")] /// Try to establish outgoing backchannel connections /// to the following ip+port addresses. /// Can be specified more than once. @@ -75,6 +83,7 @@ pub struct Config { #[arg(long)] pub back_open: Vec, + #[cfg(feature = "unstable")] /// Bind to this interface and port to provide prometheus metrics. /// Note, this should be a local only or virtual private interface. #[arg(long)] @@ -119,13 +128,20 @@ impl Default for Config { cert_pem_file: None, priv_key_pem_file: None, bind: Vec::new(), + #[cfg(feature = "unstable")] watch_reload_tls_dir: None, trusted_ip_header: None, + #[cfg(feature = "unstable")] ip_deny_dir: std::path::PathBuf::from(DEF_IP_DENY_DIR), + #[cfg(feature = "unstable")] ip_deny_s: DEF_IP_DENY_S, + #[cfg(feature = "unstable")] back_bind: Vec::new(), + #[cfg(feature = "unstable")] back_allow_ip: Vec::new(), + #[cfg(feature = "unstable")] back_open: Vec::new(), + #[cfg(feature = "unstable")] bind_prometheus: None, limit_clients: DEF_LIMIT_CLIENTS, disable_rate_limiting: false, diff --git a/rust/sbd-server/src/lib.rs b/rust/sbd-server/src/lib.rs index 26ef9d6..8b0abb3 100644 --- a/rust/sbd-server/src/lib.rs +++ b/rust/sbd-server/src/lib.rs @@ -118,6 +118,7 @@ impl Drop for SbdServer { async fn check_accept_connection( _connect_permit: tokio::sync::OwnedSemaphorePermit, config: Arc, + tls_config: Option>, ip_rate: Arc, tcp: tokio::net::TcpStream, addr: std::net::SocketAddr, @@ -149,10 +150,8 @@ async fn check_accept_connection( } } - let socket = if let (Some(cert), Some(pk)) = - (&config.cert_pem_file, &config.priv_key_pem_file) - { - match MaybeTlsStream::tls(cert, pk, tcp).await { + let socket = if let Some(tls_config) = &tls_config { + match MaybeTlsStream::tls(tls_config, tcp).await { Err(_) => return, Ok(tls) => tls, } @@ -195,6 +194,14 @@ async fn check_accept_connection( impl SbdServer { /// Construct a new running sbd server with the provided config. pub async fn new(config: Arc) -> Result { + let tls_config = if let (Some(cert), Some(pk)) = + (&config.cert_pem_file, &config.priv_key_pem_file) + { + Some(Arc::new(maybe_tls::TlsConfig::new(cert, pk).await?)) + } else { + None + }; + let mut task_list = Vec::new(); let mut bind_addrs = Vec::new(); @@ -228,6 +235,7 @@ impl SbdServer { let tcp = tokio::net::TcpListener::bind(a).await?; bind_addrs.push(tcp.local_addr()?); + let tls_config = tls_config.clone(); let connect_limit = connect_limit.clone(); let config = config.clone(); let weak_cslot = weak_cslot.clone(); @@ -248,6 +256,7 @@ impl SbdServer { tokio::task::spawn(check_accept_connection( connect_permit, config.clone(), + tls_config.clone(), ip_rate.clone(), tcp, addr, diff --git a/rust/sbd-server/src/maybe_tls.rs b/rust/sbd-server/src/maybe_tls.rs index bb0b33a..264f709 100644 --- a/rust/sbd-server/src/maybe_tls.rs +++ b/rust/sbd-server/src/maybe_tls.rs @@ -1,29 +1,47 @@ //! taken and altered from tokio_tungstenite -use std::{ - pin::Pin, - task::{Context, Poll}, -}; +use std::pin::Pin; +use std::sync::{Arc, Mutex}; +use std::task::{Context, Poll}; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; -/// A stream that might be protected with TLS. -#[non_exhaustive] -#[derive(Debug)] -pub enum MaybeTlsStream { - /// Tcp. - Tcp(tokio::net::TcpStream), - - /// Tls. - Tls(tokio_rustls::server::TlsStream), +pub struct TlsConfig { + cert: std::path::PathBuf, + pk: std::path::PathBuf, + config: Arc>>, } -impl MaybeTlsStream { - pub async fn tls( +impl TlsConfig { + pub async fn new( cert: &std::path::Path, pk: &std::path::Path, - tcp: tokio::net::TcpStream, ) -> std::io::Result { + let cert = cert.to_owned(); + let pk = pk.to_owned(); + let config = Self::load(&cert, &pk).await?; + Ok(Self { + cert, + pk, + config: Arc::new(Mutex::new(config)), + }) + } + + pub fn config(&self) -> Arc { + self.config.lock().unwrap().clone() + } + + #[allow(dead_code)] // watch reload tls + pub async fn reload(&self) -> std::io::Result<()> { + let new_config = Self::load(&self.cert, &self.pk).await?; + *self.config.lock().unwrap() = new_config; + Ok(()) + } + + async fn load( + cert: &std::path::Path, + pk: &std::path::Path, + ) -> std::io::Result> { use rustls_pemfile::Item::*; let cert = tokio::fs::read(cert).await?; @@ -46,14 +64,34 @@ impl MaybeTlsStream { _ => return Err(std::io::Error::other("error reading priv key")), }; - let c = std::sync::Arc::new( + Ok(std::sync::Arc::new( rustls::server::ServerConfig::builder() .with_no_client_auth() .with_single_cert(vec![cert], pk) .map_err(std::io::Error::other)?, - ); + )) + } +} + +/// A stream that might be protected with TLS. +#[non_exhaustive] +#[derive(Debug)] +pub enum MaybeTlsStream { + /// Tcp. + Tcp(tokio::net::TcpStream), + + /// Tls. + Tls(tokio_rustls::server::TlsStream), +} + +impl MaybeTlsStream { + pub async fn tls( + tls_config: &TlsConfig, + tcp: tokio::net::TcpStream, + ) -> std::io::Result { + let config = tls_config.config(); - let tls = tokio_rustls::TlsAcceptor::from(c).accept(tcp).await?; + let tls = tokio_rustls::TlsAcceptor::from(config).accept(tcp).await?; Ok(Self::Tls(tls)) } From 161dba7763e75961c144d5375c2822414e1512d9 Mon Sep 17 00:00:00 2001 From: David Braden Date: Mon, 22 Apr 2024 13:46:20 -0600 Subject: [PATCH 17/33] Apply suggestions from code review Co-authored-by: ThetaSinner --- rust/sbd-client/src/lib.rs | 2 +- rust/sbd-client/src/raw_client.rs | 2 +- spec.md | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/rust/sbd-client/src/lib.rs b/rust/sbd-client/src/lib.rs index b379e11..1258bc8 100644 --- a/rust/sbd-client/src/lib.rs +++ b/rust/sbd-client/src/lib.rs @@ -82,7 +82,7 @@ enum MsgType<'t> { /// A message received from a remote. /// This is just a single buffer. The first 32 bytes are the public key -/// of the sender. Any remaining bytes are the message. The buffer +/// of the sender, or 28 `0`s followed by a 4 byte command. Any remaining bytes are the message. The buffer /// contained in this type is guaranteed to be at least 32 bytes long. pub struct Msg(pub Vec); diff --git a/rust/sbd-client/src/raw_client.rs b/rust/sbd-client/src/raw_client.rs index ebe597e..d57fba5 100644 --- a/rust/sbd-client/src/raw_client.rs +++ b/rust/sbd-client/src/raw_client.rs @@ -235,7 +235,7 @@ fn priv_system_tls( for cert in rustls_native_certs::load_native_certs() .expect("failed to load system tls certs") { - roots.add(cert).expect("faild to add cert to root"); + roots.add(cert).expect("failed to add cert to root"); } if danger_disable_certificate_check { diff --git a/spec.md b/spec.md index fd5f6a8..255812a 100644 --- a/spec.md +++ b/spec.md @@ -65,7 +65,7 @@ the first 32 bytes will be altered to represent the peer from which the message #### 3.2.3. Flow -- The server MUST send `areq` with a random nonce once for every new opened connection. +- The server MUST send `areq` with a random nonce once for every newly opened connection. The server MAY send any limit messages before or after this `areq`, but it MUST come before the `srdy`. - The client MUST respond with a signature over the nonce by the private key associated with the public key sent in the url path segment websocket request From 216d1c660599e3ade4100db86cf2a6c4061f3373 Mon Sep 17 00:00:00 2001 From: neonphog Date: Mon, 22 Apr 2024 13:55:03 -0600 Subject: [PATCH 18/33] address code review comments --- rust/sbd-bench/src/c_turnover.rs | 4 +- rust/sbd-bench/src/thru.rs | 90 +++++++++++++++++--------------- spec.md | 4 +- 3 files changed, 53 insertions(+), 45 deletions(-) diff --git a/rust/sbd-bench/src/c_turnover.rs b/rust/sbd-bench/src/c_turnover.rs index 213005e..76334f7 100644 --- a/rust/sbd-bench/src/c_turnover.rs +++ b/rust/sbd-bench/src/c_turnover.rs @@ -66,11 +66,11 @@ mod tests { } let start = tokio::time::Instant::now(); - for _ in 0..100 { + for _ in 0..10 { b.iter().await; } let elapsed = start.elapsed(); - println!("{} nanos per iter", elapsed.as_nanos() / 100); + println!("{} nanos per iter", elapsed.as_nanos() / 10); } } diff --git a/rust/sbd-bench/src/thru.rs b/rust/sbd-bench/src/thru.rs index 1b9d997..ab51fae 100644 --- a/rust/sbd-bench/src/thru.rs +++ b/rust/sbd-bench/src/thru.rs @@ -2,14 +2,14 @@ use super::*; pub struct ThruBenchmark { _server: SbdServer, - c1: DefaultCrypto, - s1: WsRawSend, - r1: WsRawRecv, - c2: DefaultCrypto, - s2: WsRawSend, - r2: WsRawRecv, - v1: Option>, - v2: Option>, + crypto1: DefaultCrypto, + send1: WsRawSend, + recv1: WsRawRecv, + crypto2: DefaultCrypto, + send2: WsRawSend, + recv2: WsRawRecv, + message1: Option>, + message2: Option>, } impl ThruBenchmark { @@ -23,66 +23,72 @@ impl ThruBenchmark { let server = SbdServer::new(config).await.unwrap(); - let c1 = DefaultCrypto::default(); - let (mut s1, mut r1) = raw_connect(c1.pub_key(), server.bind_addrs()) + let crypto1 = DefaultCrypto::default(); + let (mut send1, mut recv1) = + raw_connect(crypto1.pub_key(), server.bind_addrs()) + .await + .unwrap(); + + let crypto2 = DefaultCrypto::default(); + let (mut send2, mut recv2) = + raw_connect(crypto2.pub_key(), server.bind_addrs()) + .await + .unwrap(); + + Handshake::handshake(&mut send1, &mut recv1, &crypto1) .await .unwrap(); - - let c2 = DefaultCrypto::default(); - let (mut s2, mut r2) = raw_connect(c2.pub_key(), server.bind_addrs()) + Handshake::handshake(&mut send2, &mut recv2, &crypto2) .await .unwrap(); - Handshake::handshake(&mut s1, &mut r1, &c1).await.unwrap(); - Handshake::handshake(&mut s2, &mut r2, &c2).await.unwrap(); - Self { _server: server, - c1, - s1, - r1, - c2, - s2, - r2, - v1: None, - v2: None, + crypto1, + send1, + recv1, + crypto2, + send2, + recv2, + message1: None, + message2: None, } } pub async fn iter(&mut self) { let Self { - c1, - s1, - r1, - c2, - s2, - r2, - v1, - v2, + crypto1, + send1, + recv1, + crypto2, + send2, + recv2, + message1, + message2, .. } = self; - let mut b1 = v1.take().unwrap_or_else(|| vec![0xdb; 1000]); - let mut b2 = v2.take().unwrap_or_else(|| vec![0xca; 1000]); + let mut b1 = message1.take().unwrap_or_else(|| vec![0xdb; 1000]); + let mut b2 = message2.take().unwrap_or_else(|| vec![0xca; 1000]); tokio::join!( async { - b1[0..32].copy_from_slice(c2.pub_key()); - s1.send(b1).await.unwrap(); + b1[0..32].copy_from_slice(crypto2.pub_key()); + send1.send(b1).await.unwrap(); }, async { - b2[0..32].copy_from_slice(c1.pub_key()); - s2.send(b2).await.unwrap(); + b2[0..32].copy_from_slice(crypto1.pub_key()); + send2.send(b2).await.unwrap(); }, async { - let b2 = r1.recv().await.unwrap(); + let b2 = recv1.recv().await.unwrap(); assert_eq!(1000, b2.len()); - *v2 = Some(b2); + *message2 = Some(b2); }, async { - let b1 = r2.recv().await.unwrap(); + let b1 = recv2.recv().await.unwrap(); assert_eq!(1000, b1.len()); - *v1 = Some(b1); + *message1 = Some(b1); }, ); } diff --git a/spec.md b/spec.md index 255812a..6868834 100644 --- a/spec.md +++ b/spec.md @@ -33,7 +33,7 @@ Clients will be identified by ed25519 public key. Client sessions will be valida Clients MUST specify exactly 1 http path item on the websocket connection url. This item must be the base64url encoded public key that this client will be identified by. -This public key MUST be unique to this new connection. +This public key SHOULD be unique to this new connection. ### 3.2. Messages @@ -65,6 +65,8 @@ the first 32 bytes will be altered to represent the peer from which the message #### 3.2.3. Flow +- If the server is in an overload state, it MAY drop incoming tcp connections + immediately with no response even before doing any TLS handshaking. - The server MUST send `areq` with a random nonce once for every newly opened connection. The server MAY send any limit messages before or after this `areq`, but it MUST come before the `srdy`. - The client MUST respond with a signature over the nonce by the private key associated with the public key From 481b229095775af05d9ba5cac143c0ce710444f5 Mon Sep 17 00:00:00 2001 From: neonphog Date: Tue, 23 Apr 2024 10:24:08 -0600 Subject: [PATCH 19/33] address code review comments --- rust/sbd-client/src/lib.rs | 4 ++-- rust/sbd-client/src/raw_client.rs | 2 +- rust/sbd-client/src/send_buf.rs | 2 +- rust/sbd-server/src/cmd.rs | 10 +++++----- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/rust/sbd-client/src/lib.rs b/rust/sbd-client/src/lib.rs index 1258bc8..b00dabe 100644 --- a/rust/sbd-client/src/lib.rs +++ b/rust/sbd-client/src/lib.rs @@ -64,7 +64,7 @@ impl std::fmt::Debug for PubKey { } } -const CMD_FLAG: &[u8; 28] = &[0; 28]; +const CMD_PREFIX: &[u8; 28] = &[0; 28]; enum MsgType<'t> { Msg { @@ -108,7 +108,7 @@ impl Msg { if self.0.len() < 32 { return Err(Error::other("invalid message length")); } - if &self.0[..28] == CMD_FLAG { + if &self.0[..28] == CMD_PREFIX { match &self.0[28..32] { b"lbrt" => { if self.0.len() != 32 + 4 { diff --git a/rust/sbd-client/src/raw_client.rs b/rust/sbd-client/src/raw_client.rs index d57fba5..912dcd8 100644 --- a/rust/sbd-client/src/raw_client.rs +++ b/rust/sbd-client/src/raw_client.rs @@ -186,7 +186,7 @@ impl Handshake { MsgType::AuthReq(nonce) => { let sig = crypto.sign(nonce); let mut auth_res = Vec::with_capacity(32 + 64); - auth_res.extend_from_slice(CMD_FLAG); + auth_res.extend_from_slice(CMD_PREFIX); auth_res.extend_from_slice(b"ares"); auth_res.extend_from_slice(&sig); send.send(auth_res).await?; diff --git a/rust/sbd-client/src/send_buf.rs b/rust/sbd-client/src/send_buf.rs index c4b547f..b00077f 100644 --- a/rust/sbd-client/src/send_buf.rs +++ b/rust/sbd-client/src/send_buf.rs @@ -93,7 +93,7 @@ impl SendBuf { // first check if we need to keepalive if now - self.last_send >= self.idle_keepalive_nanos { let mut data = Vec::with_capacity(32); - data.extend_from_slice(CMD_FLAG); + data.extend_from_slice(CMD_PREFIX); data.extend_from_slice(b"keep"); self.raw_send(now, data).await?; return Ok(true); diff --git a/rust/sbd-server/src/cmd.rs b/rust/sbd-server/src/cmd.rs index 9f5daa3..9f7158f 100644 --- a/rust/sbd-server/src/cmd.rs +++ b/rust/sbd-server/src/cmd.rs @@ -21,14 +21,14 @@ pub enum SbdCmd<'c> { Unknown, } -const CMD_FLAG: &[u8; 28] = &[0; 28]; +const CMD_PREFIX: &[u8; 28] = &[0; 28]; impl<'c> SbdCmd<'c> { pub fn parse(payload: Payload<'c>) -> Result { if payload.as_ref().len() < 32 { return Err(Error::other("invalid payload length")); } - if &payload.as_ref()[..28] == CMD_FLAG { + if &payload.as_ref()[..28] == CMD_PREFIX { // only include the messages that clients should send // mark everything else as Unknown match &payload.as_ref()[28..32] { @@ -53,7 +53,7 @@ impl SbdCmd<'_> { pub fn limit_byte_nanos(limit_byte_nanos: i32) -> Payload<'static> { let mut out = Vec::with_capacity(32 + 4); let n = limit_byte_nanos.to_be_bytes(); - out.extend_from_slice(CMD_FLAG); + out.extend_from_slice(CMD_PREFIX); out.extend_from_slice(F_LIMIT_BYTE_NANOS); out.extend_from_slice(&n[..]); Payload::Vec(out) @@ -62,7 +62,7 @@ impl SbdCmd<'_> { pub fn limit_idle_millis(limit_idle_millis: i32) -> Payload<'static> { let mut out = Vec::with_capacity(32 + 4); let n = limit_idle_millis.to_be_bytes(); - out.extend_from_slice(CMD_FLAG); + out.extend_from_slice(CMD_PREFIX); out.extend_from_slice(F_LIMIT_IDLE_MILLIS); out.extend_from_slice(&n[..]); Payload::Vec(out) @@ -70,7 +70,7 @@ impl SbdCmd<'_> { pub fn auth_req(nonce: &[u8; 32]) -> Payload<'static> { let mut out = Vec::with_capacity(32 + 32); - out.extend_from_slice(CMD_FLAG); + out.extend_from_slice(CMD_PREFIX); out.extend_from_slice(F_AUTH_REQ); out.extend_from_slice(&nonce[..]); Payload::Vec(out) From 5c74530f38da9d2b1f95d6cd7a4d46ca220917d5 Mon Sep 17 00:00:00 2001 From: neonphog Date: Tue, 23 Apr 2024 10:34:44 -0600 Subject: [PATCH 20/33] address code review comment --- spec.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/spec.md b/spec.md index 6868834..c342c83 100644 --- a/spec.md +++ b/spec.md @@ -112,6 +112,8 @@ A server MAY track rate limiting by some metric other than individual connection Then, if additional connections are established from the same other metric, all connections could be notified of needing to send data more slowly. +A client MAY wish to honor a slightly increased rate (e.g. lbrt * 1.1) to account for clock skew or network backlogs being dumped all at once. + #### 3.2.6. Extensibility In order to make this protocol extensible without versioning, clients and servers MUST ignore unknown command types. From 4cd0be194ed008120c7db66abdc83f068fd4d0b8 Mon Sep 17 00:00:00 2001 From: neonphog Date: Tue, 23 Apr 2024 10:47:07 -0600 Subject: [PATCH 21/33] address code review comments --- rust/sbd-client/src/raw_client.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/rust/sbd-client/src/raw_client.rs b/rust/sbd-client/src/raw_client.rs index 912dcd8..31e08a1 100644 --- a/rust/sbd-client/src/raw_client.rs +++ b/rust/sbd-client/src/raw_client.rs @@ -31,8 +31,10 @@ impl WsRawConnect { danger_disable_certificate_check, } = self; - let scheme_ws = full_url.starts_with("ws://"); - let scheme_wss = full_url.starts_with("wss://"); + let request = tokio_tungstenite::tungstenite::client::IntoClientRequest::into_client_request(full_url).map_err(Error::other)?; + + let scheme_ws = request.uri().scheme_str() == Some("ws"); + let scheme_wss = request.uri().scheme_str() == Some("wss"); if !scheme_ws && !scheme_wss { return Err(Error::other("scheme must be ws:// or wss://")); @@ -42,8 +44,6 @@ impl WsRawConnect { return Err(Error::other("plain text scheme not allowed")); } - let request = tokio_tungstenite::tungstenite::client::IntoClientRequest::into_client_request(full_url).map_err(Error::other)?; - let host = match request.uri().host() { Some(host) => host.to_string(), None => return Err(Error::other("invalid url")), From 7a1887d063470dc173bdd20902d764865356d5e3 Mon Sep 17 00:00:00 2001 From: neonphog Date: Tue, 23 Apr 2024 11:07:11 -0600 Subject: [PATCH 22/33] address code review comments --- rust/sbd-server/src/cmd.rs | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/rust/sbd-server/src/cmd.rs b/rust/sbd-server/src/cmd.rs index 9f7158f..379ed02 100644 --- a/rust/sbd-server/src/cmd.rs +++ b/rust/sbd-server/src/cmd.rs @@ -5,19 +5,23 @@ const F_LIMIT_BYTE_NANOS: &[u8] = b"lbrt"; const F_LIMIT_IDLE_MILLIS: &[u8] = b"lidl"; const F_AUTH_REQ: &[u8] = b"areq"; const F_AUTH_RES: &[u8] = b"ares"; -//const F_READY: &[u8] = b"srdy"; -/// Sbd commands. This enum only includes the types that clients send. -/// The class contains only methods for generating commands that can -/// be sent to the client. +/// Sbd commands. +/// Enum variants represent only the types that clients can send to the server: +/// - not-cmd Message(payload) +/// - `keep` Keepalive +/// - `ares` AuthRes(signature) +/// - other-cmd Unknown +/// Member functions represent only the types that the server can send to the +/// clients: +/// - `lbrt` limit_byte_nanos(i32) +/// - `lidl` limit_idle_millis(i32) +/// - `areq` auth_req(nonce) +/// - `srdy` ready() pub enum SbdCmd<'c> { Message(Payload<'c>), Keepalive, - //LimitByteNanos(i32), - //LimitIdleMillis(i32), - //AuthReq([u8; 32]), AuthRes([u8; 64]), - //Ready, Unknown, } From 185662835bf8586805dfb5f134c5f0a7aed3bed3 Mon Sep 17 00:00:00 2001 From: neonphog Date: Tue, 23 Apr 2024 11:10:16 -0600 Subject: [PATCH 23/33] address code review comment --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index d50f4c5..1672fd7 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,7 @@ Simple websocket-based message relay servers and clients. +SBD doesn't stand for anything. Imagine it however you like. Secure By Design... STUN By Default... Silent But Deadly... + - [Rust Reference Server](rust/sbd-server) - [Rust Reference Client](rust/sbd-client) - [Autobahn-Style Server Test Suite](rust/sbd-o-bahn-server-tester) From 8f9b50ede4cc68916fe383fd0a0e76296f0c1178 Mon Sep 17 00:00:00 2001 From: neonphog Date: Tue, 23 Apr 2024 11:20:25 -0600 Subject: [PATCH 24/33] address code review commenst --- rust/sbd-bench/src/c_turnover.rs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/rust/sbd-bench/src/c_turnover.rs b/rust/sbd-bench/src/c_turnover.rs index 76334f7..a359b59 100644 --- a/rust/sbd-bench/src/c_turnover.rs +++ b/rust/sbd-bench/src/c_turnover.rs @@ -1,6 +1,11 @@ use super::*; use std::collections::VecDeque; +/// This benchmark sets up a server that only accepts 4 clients. +/// Each iteration it: +/// - tries to connect new clients until it gets a connect error. +/// - drops one of the old clients so there is exactly 1 space free. +/// - connects 1 new client and panics if it errors connecting. pub struct CTurnoverBenchmark { server: SbdServer, house: VecDeque<(WsRawSend, WsRawRecv)>, From 58e4c1121c64ac29bd23918c160b47a8f3a6dedf Mon Sep 17 00:00:00 2001 From: neonphog Date: Tue, 23 Apr 2024 11:36:19 -0600 Subject: [PATCH 25/33] address code review comments --- rust/sbd-client/src/lib.rs | 53 ++++++++++++++++++++----------- rust/sbd-client/src/raw_client.rs | 4 +-- rust/sbd-client/src/send_buf.rs | 6 ++-- 3 files changed, 39 insertions(+), 24 deletions(-) diff --git a/rust/sbd-client/src/lib.rs b/rust/sbd-client/src/lib.rs index b00dabe..bd344a3 100644 --- a/rust/sbd-client/src/lib.rs +++ b/rust/sbd-client/src/lib.rs @@ -7,6 +7,18 @@ use std::sync::Arc; /// defined by the sbd spec const MAX_MSG_SIZE: usize = 20_000; +/// defined by ed25519 spec +const PK_SIZE: usize = 32; + +/// defined by ed25519 spec +const SIG_SIZE: usize = 64; + +/// sbd spec defines headers to be the same size as ed25519 pub keys +const HDR_SIZE: usize = PK_SIZE; + +/// defined by sbd spec +const NONCE_SIZE: usize = 32; + #[cfg(feature = "raw_client")] pub mod raw_client; #[cfg(not(feature = "raw_client"))] @@ -17,16 +29,18 @@ mod send_buf; /// Crypto to use. Note, the pair should be fresh for each new connection. pub trait Crypto { /// The pubkey. - fn pub_key(&self) -> &[u8; 32]; + fn pub_key(&self) -> &[u8; PK_SIZE]; /// Sign the nonce. - fn sign(&self, nonce: &[u8]) -> [u8; 64]; + fn sign(&self, nonce: &[u8]) -> [u8; SIG_SIZE]; } #[cfg(feature = "crypto")] mod default_crypto { + use super::*; + /// Default signer. Use a fresh one for every new connection. - pub struct DefaultCrypto([u8; 32], ed25519_dalek::SigningKey); + pub struct DefaultCrypto([u8; PK_SIZE], ed25519_dalek::SigningKey); impl Default for DefaultCrypto { fn default() -> Self { @@ -38,11 +52,11 @@ mod default_crypto { } impl super::Crypto for DefaultCrypto { - fn pub_key(&self) -> &[u8; 32] { + fn pub_key(&self) -> &[u8; PK_SIZE] { &self.0 } - fn sign(&self, nonce: &[u8]) -> [u8; 64] { + fn sign(&self, nonce: &[u8]) -> [u8; SIG_SIZE] { use ed25519_dalek::Signer; self.1.sign(nonce).to_bytes() } @@ -53,7 +67,7 @@ pub use default_crypto::*; /// Public key. #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub struct PubKey(pub [u8; 32]); +pub struct PubKey(pub [u8; PK_SIZE]); impl std::fmt::Debug for PubKey { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -64,6 +78,7 @@ impl std::fmt::Debug for PubKey { } } +/// defined by sbd spec const CMD_PREFIX: &[u8; 28] = &[0; 28]; enum MsgType<'t> { @@ -89,56 +104,56 @@ pub struct Msg(pub Vec); impl Msg { /// Get a reference to the slice containing the pubkey data. pub fn pub_key_ref(&self) -> &[u8] { - &self.0[..32] + &self.0[..PK_SIZE] } /// Extract a pubkey from the message. pub fn pub_key(&self) -> PubKey { - PubKey(self.0[..32].try_into().unwrap()) + PubKey(self.0[..PK_SIZE].try_into().unwrap()) } /// Get a reference to the slice containing the message data. pub fn message(&self) -> &[u8] { - &self.0[32..] + &self.0[PK_SIZE..] } // -- private -- // fn parse(&self) -> Result> { - if self.0.len() < 32 { + if self.0.len() < PK_SIZE { return Err(Error::other("invalid message length")); } if &self.0[..28] == CMD_PREFIX { - match &self.0[28..32] { + match &self.0[28..HDR_SIZE] { b"lbrt" => { - if self.0.len() != 32 + 4 { + if self.0.len() != HDR_SIZE + 4 { return Err(Error::other("invalid lbrt length")); } Ok(MsgType::LimitByteNanos(i32::from_be_bytes( - self.0[32..].try_into().unwrap(), + self.0[PK_SIZE..].try_into().unwrap(), ))) } b"lidl" => { - if self.0.len() != 32 + 4 { + if self.0.len() != HDR_SIZE + 4 { return Err(Error::other("invalid lidl length")); } Ok(MsgType::LimitIdleMillis(i32::from_be_bytes( - self.0[32..].try_into().unwrap(), + self.0[HDR_SIZE..].try_into().unwrap(), ))) } b"areq" => { - if self.0.len() != 32 + 32 { + if self.0.len() != HDR_SIZE + NONCE_SIZE { return Err(Error::other("invalid areq length")); } - Ok(MsgType::AuthReq(&self.0[32..])) + Ok(MsgType::AuthReq(&self.0[HDR_SIZE..])) } b"srdy" => Ok(MsgType::Ready), _ => Ok(MsgType::Unknown), } } else { Ok(MsgType::Msg { - pub_key: &self.0[..32], - message: &self.0[32..], + pub_key: &self.0[..PK_SIZE], + message: &self.0[PK_SIZE..], }) } } diff --git a/rust/sbd-client/src/raw_client.rs b/rust/sbd-client/src/raw_client.rs index 31e08a1..85611ea 100644 --- a/rust/sbd-client/src/raw_client.rs +++ b/rust/sbd-client/src/raw_client.rs @@ -185,12 +185,12 @@ impl Handshake { MsgType::LimitIdleMillis(l) => limit_idle_millis = l, MsgType::AuthReq(nonce) => { let sig = crypto.sign(nonce); - let mut auth_res = Vec::with_capacity(32 + 64); + let mut auth_res = Vec::with_capacity(HDR_SIZE + SIG_SIZE); auth_res.extend_from_slice(CMD_PREFIX); auth_res.extend_from_slice(b"ares"); auth_res.extend_from_slice(&sig); send.send(auth_res).await?; - bytes_sent += 32 + 64; + bytes_sent += HDR_SIZE + SIG_SIZE; } MsgType::Ready => break, MsgType::Unknown => (), diff --git a/rust/sbd-client/src/send_buf.rs b/rust/sbd-client/src/send_buf.rs index b00077f..abdb175 100644 --- a/rust/sbd-client/src/send_buf.rs +++ b/rust/sbd-client/src/send_buf.rs @@ -92,7 +92,7 @@ impl SendBuf { // first check if we need to keepalive if now - self.last_send >= self.idle_keepalive_nanos { - let mut data = Vec::with_capacity(32); + let mut data = Vec::with_capacity(HDR_SIZE); data.extend_from_slice(CMD_PREFIX); data.extend_from_slice(b"keep"); self.raw_send(now, data).await?; @@ -117,7 +117,7 @@ impl SendBuf { /// Note, you'll need to explicitly call `write_next_queued()` or /// make additional sends in order to get this queued data actually sent. pub async fn send(&mut self, pk: &PubKey, data: &[u8]) -> Result<()> { - if data.len() > MAX_MSG_SIZE - 32 { + if data.len() > MAX_MSG_SIZE - PK_SIZE { return Err(Error::other("message too large")); } @@ -128,7 +128,7 @@ impl SendBuf { self.write_next_queued().await?; } - let mut buf = Vec::with_capacity(32 + data.len()); + let mut buf = Vec::with_capacity(PK_SIZE + data.len()); buf.extend_from_slice(&pk.0[..]); buf.extend_from_slice(data); self.buf.push_back(buf); From 52b890744c1f0bf2894108b5abddccd11e514b13 Mon Sep 17 00:00:00 2001 From: neonphog Date: Tue, 23 Apr 2024 11:41:18 -0600 Subject: [PATCH 26/33] code review comment --- rust/sbd-client/src/lib.rs | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/rust/sbd-client/src/lib.rs b/rust/sbd-client/src/lib.rs index bd344a3..5456566 100644 --- a/rust/sbd-client/src/lib.rs +++ b/rust/sbd-client/src/lib.rs @@ -19,6 +19,11 @@ const HDR_SIZE: usize = PK_SIZE; /// defined by sbd spec const NONCE_SIZE: usize = 32; +const F_LIMIT_BYTE_NANOS: &[u8] = b"lbrt"; +const F_LIMIT_IDLE_MILLIS: &[u8] = b"lidl"; +const F_AUTH_REQ: &[u8] = b"areq"; +const F_READY: &[u8] = b"srdy"; + #[cfg(feature = "raw_client")] pub mod raw_client; #[cfg(not(feature = "raw_client"))] @@ -125,7 +130,7 @@ impl Msg { } if &self.0[..28] == CMD_PREFIX { match &self.0[28..HDR_SIZE] { - b"lbrt" => { + F_LIMIT_BYTE_NANOS => { if self.0.len() != HDR_SIZE + 4 { return Err(Error::other("invalid lbrt length")); } @@ -133,7 +138,7 @@ impl Msg { self.0[PK_SIZE..].try_into().unwrap(), ))) } - b"lidl" => { + F_LIMIT_IDLE_MILLIS => { if self.0.len() != HDR_SIZE + 4 { return Err(Error::other("invalid lidl length")); } @@ -141,13 +146,13 @@ impl Msg { self.0[HDR_SIZE..].try_into().unwrap(), ))) } - b"areq" => { + F_AUTH_REQ => { if self.0.len() != HDR_SIZE + NONCE_SIZE { return Err(Error::other("invalid areq length")); } Ok(MsgType::AuthReq(&self.0[HDR_SIZE..])) } - b"srdy" => Ok(MsgType::Ready), + F_READY => Ok(MsgType::Ready), _ => Ok(MsgType::Unknown), } } else { From 5c8561c43e0030ee5827c26c1c17fac1fa27d3af Mon Sep 17 00:00:00 2001 From: David Braden Date: Tue, 23 Apr 2024 11:47:05 -0600 Subject: [PATCH 27/33] Update rust/sbd-client/src/send_buf.rs Co-authored-by: Stefan Junker <1181362+steveej@users.noreply.github.com> --- rust/sbd-client/src/send_buf.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/rust/sbd-client/src/send_buf.rs b/rust/sbd-client/src/send_buf.rs index abdb175..edf62cb 100644 --- a/rust/sbd-client/src/send_buf.rs +++ b/rust/sbd-client/src/send_buf.rs @@ -73,10 +73,10 @@ impl SendBuf { } if now < self.next_send_at { - let need_keepalive_at = + let need_keepalive_in = self.idle_keepalive_nanos - (now - self.last_send); let nanos = - std::cmp::min(need_keepalive_at, self.next_send_at - now); + std::cmp::min(need_keepalive_in, self.next_send_at - now); Some(std::time::Duration::from_nanos(nanos)) } else { None From f1bad2fa0a69fa1623f15de298ea91d43a354d52 Mon Sep 17 00:00:00 2001 From: neonphog Date: Tue, 23 Apr 2024 11:51:13 -0600 Subject: [PATCH 28/33] code review comment --- rust/sbd-client/src/send_buf.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/rust/sbd-client/src/send_buf.rs b/rust/sbd-client/src/send_buf.rs index edf62cb..11c4e2c 100644 --- a/rust/sbd-client/src/send_buf.rs +++ b/rust/sbd-client/src/send_buf.rs @@ -35,8 +35,7 @@ impl SendBuf { let now = this.origin.elapsed().as_nanos() as u64; - this.next_send_at = std::cmp::max(now, this.next_send_at) - + (pre_sent_bytes as u64 * this.limit_rate); + this.next_send_at = now + (pre_sent_bytes as u64 * this.limit_rate); this } From fade996672d496cede71fa835d7130eb94986eba Mon Sep 17 00:00:00 2001 From: neonphog Date: Wed, 24 Apr 2024 09:52:10 -0600 Subject: [PATCH 29/33] address code review comments --- rust/sbd-client/src/lib.rs | 15 +++++++++++---- rust/sbd-server/src/cmd.rs | 2 +- rust/sbd-server/src/lib.rs | 5 +++++ 3 files changed, 17 insertions(+), 5 deletions(-) diff --git a/rust/sbd-client/src/lib.rs b/rust/sbd-client/src/lib.rs index 5456566..a9a5765 100644 --- a/rust/sbd-client/src/lib.rs +++ b/rust/sbd-client/src/lib.rs @@ -49,10 +49,17 @@ mod default_crypto { impl Default for DefaultCrypto { fn default() -> Self { - let k = - ed25519_dalek::SigningKey::generate(&mut rand::thread_rng()); - let pk = k.verifying_key().to_bytes(); - Self(pk, k) + loop { + let k = ed25519_dalek::SigningKey::generate( + &mut rand::thread_rng(), + ); + let pk = k.verifying_key().to_bytes(); + if &pk[..28] == CMD_PREFIX { + continue; + } else { + return Self(pk, k); + } + } } } diff --git a/rust/sbd-server/src/cmd.rs b/rust/sbd-server/src/cmd.rs index 379ed02..6108f66 100644 --- a/rust/sbd-server/src/cmd.rs +++ b/rust/sbd-server/src/cmd.rs @@ -25,7 +25,7 @@ pub enum SbdCmd<'c> { Unknown, } -const CMD_PREFIX: &[u8; 28] = &[0; 28]; +pub(crate) const CMD_PREFIX: &[u8; 28] = &[0; 28]; impl<'c> SbdCmd<'c> { pub fn parse(payload: Payload<'c>) -> Result { diff --git a/rust/sbd-server/src/lib.rs b/rust/sbd-server/src/lib.rs index 8b0abb3..edc6cb9 100644 --- a/rust/sbd-server/src/lib.rs +++ b/rust/sbd-server/src/lib.rs @@ -165,6 +165,11 @@ async fn check_accept_connection( Err(_) => return, }; + // illegal pub key + if &pub_key.0[..28] == cmd::CMD_PREFIX { + return; + } + let ws = Arc::new(ws); if let Some(ip) = ip { From 60595d024cd59a0e451d8db04a0b1248ca551a86 Mon Sep 17 00:00:00 2001 From: neonphog Date: Wed, 24 Apr 2024 13:28:01 -0600 Subject: [PATCH 30/33] arc in pubkey --- rust/sbd-client/examples/client-o-bahn-runner.rs | 4 ++-- rust/sbd-client/src/lib.rs | 16 ++++++++++++---- rust/sbd-client/tests/reasonable-rate-limit.rs | 4 ++-- rust/sbd-o-bahn-server-tester/src/it/it_1.rs | 12 ++++++++++-- rust/sbd-server/src/lib.rs | 4 ++-- 5 files changed, 28 insertions(+), 12 deletions(-) diff --git a/rust/sbd-client/examples/client-o-bahn-runner.rs b/rust/sbd-client/examples/client-o-bahn-runner.rs index 191b199..3c337de 100644 --- a/rust/sbd-client/examples/client-o-bahn-runner.rs +++ b/rust/sbd-client/examples/client-o-bahn-runner.rs @@ -38,7 +38,7 @@ async fn main() { let msg = hex::decode(parts.pop_front().unwrap()).unwrap(); if let Some(s) = con_map.get(&id) { let _ = s.send(ConCmd::Send( - PubKey(pk.try_into().unwrap()), + PubKey(std::sync::Arc::new(pk.try_into().unwrap())), msg, )); } @@ -88,7 +88,7 @@ async fn spawn_con( } println!("CMD/CLOSE/{id}"); }); - println!("CMD/CONNECT/{id}/{}", hex::encode(&pk.0)); + println!("CMD/CONNECT/{id}/{}", hex::encode(&pk[..])); while let Some(cmd) = r.recv().await { match cmd { ConCmd::Close => break, diff --git a/rust/sbd-client/src/lib.rs b/rust/sbd-client/src/lib.rs index a9a5765..6d1b5ba 100644 --- a/rust/sbd-client/src/lib.rs +++ b/rust/sbd-client/src/lib.rs @@ -78,8 +78,16 @@ mod default_crypto { pub use default_crypto::*; /// Public key. -#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub struct PubKey(pub [u8; PK_SIZE]); +#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct PubKey(pub Arc<[u8; PK_SIZE]>); + +impl std::ops::Deref for PubKey { + type Target = [u8; 32]; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} impl std::fmt::Debug for PubKey { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -121,7 +129,7 @@ impl Msg { /// Extract a pubkey from the message. pub fn pub_key(&self) -> PubKey { - PubKey(self.0[..PK_SIZE].try_into().unwrap()) + PubKey(Arc::new(self.0[..PK_SIZE].try_into().unwrap())) } /// Get a reference to the slice containing the message data. @@ -330,7 +338,7 @@ impl SbdClient { Ok(( this, full_url, - PubKey(*crypto.pub_key()), + PubKey(Arc::new(*crypto.pub_key())), MsgRecv(recv_recv), )) } diff --git a/rust/sbd-client/tests/reasonable-rate-limit.rs b/rust/sbd-client/tests/reasonable-rate-limit.rs index 627cebd..4a793fc 100644 --- a/rust/sbd-client/tests/reasonable-rate-limit.rs +++ b/rust/sbd-client/tests/reasonable-rate-limit.rs @@ -84,7 +84,7 @@ async fn run( let mut tot = 0; loop { let r = r1.recv().await.unwrap(); - assert_eq!(r.pub_key_ref(), &p2.0); + assert_eq!(r.pub_key_ref(), &p2[..]); tot += r.message().len(); println!("r1 got {} bytes", tot); rate1 += (32 + r.message().len()) as f64; @@ -97,7 +97,7 @@ async fn run( let mut tot = 0; loop { let r = r2.recv().await.unwrap(); - assert_eq!(r.pub_key_ref(), &p1.0); + assert_eq!(r.pub_key_ref(), &p1[..]); tot += r.message().len(); println!("r2 got {} bytes", tot); rate2 += (32 + r.message().len()) as f64; diff --git a/rust/sbd-o-bahn-server-tester/src/it/it_1.rs b/rust/sbd-o-bahn-server-tester/src/it/it_1.rs index 6aa57ab..666d8aa 100644 --- a/rust/sbd-o-bahn-server-tester/src/it/it_1.rs +++ b/rust/sbd-o-bahn-server-tester/src/it/it_1.rs @@ -20,9 +20,17 @@ impl It for It1 { async { r2.recv().await.ok_or(Error::other("closed")) }, )?; - expect!(helper, result1.pub_key_ref() == p2.0, "r1 recv from p2"); + expect!( + helper, + result1.pub_key_ref() == &p2[..], + "r1 recv from p2" + ); expect!(helper, result1.message() == b"world", "r1 got 'world'"); - expect!(helper, result2.pub_key_ref() == p1.0, "r2 recv from p1"); + expect!( + helper, + result2.pub_key_ref() == &p1[..], + "r2 recv from p1" + ); expect!(helper, result2.message() == b"hello", "r2 got 'hello'"); Ok(()) diff --git a/rust/sbd-server/src/lib.rs b/rust/sbd-server/src/lib.rs index edc6cb9..0d740ae 100644 --- a/rust/sbd-server/src/lib.rs +++ b/rust/sbd-server/src/lib.rs @@ -351,14 +351,14 @@ mod tests { let res_data = rcv2.recv().await.unwrap(); - assert_eq!(&pk1.0, res_data.pub_key_ref()); + assert_eq!(&pk1[..], res_data.pub_key_ref()); assert_eq!(&b"hello"[..], res_data.message()); client2.send(&pk1, b"world").await.unwrap(); let res_data = rcv1.recv().await.unwrap(); - assert_eq!(&pk2.0, res_data.pub_key_ref()); + assert_eq!(&pk2[..], res_data.pub_key_ref()); assert_eq!(&b"world"[..], res_data.message()); } } From 1f164a2d6549317d59b8c2ca91126fdcdbec0def Mon Sep 17 00:00:00 2001 From: neonphog Date: Thu, 2 May 2024 16:28:10 -0600 Subject: [PATCH 31/33] small api tweak and test --- .../examples/client-o-bahn-runner.rs | 3 +- rust/sbd-client/src/lib.rs | 30 ++++++--- rust/sbd-client/src/test.rs | 26 ++++++++ .../sbd-client/tests/reasonable-rate-limit.rs | 10 +-- rust/sbd-o-bahn-server-tester/src/it.rs | 7 +-- rust/sbd-o-bahn-server-tester/src/it/it_1.rs | 7 ++- rust/sbd-server/src/lib.rs | 62 +++++++++---------- 7 files changed, 89 insertions(+), 56 deletions(-) create mode 100644 rust/sbd-client/src/test.rs diff --git a/rust/sbd-client/examples/client-o-bahn-runner.rs b/rust/sbd-client/examples/client-o-bahn-runner.rs index 3c337de..3d61ce8 100644 --- a/rust/sbd-client/examples/client-o-bahn-runner.rs +++ b/rust/sbd-client/examples/client-o-bahn-runner.rs @@ -66,7 +66,8 @@ async fn connect(addrs: &[String]) -> (SbdClient, PubKey, MsgRecv) { ) .await { - return (c.0, c.2, c.3); + let pk = c.0.pub_key().clone(); + return (c.0, pk, c.1); } } panic!() diff --git a/rust/sbd-client/src/lib.rs b/rust/sbd-client/src/lib.rs index 6d1b5ba..c453f75 100644 --- a/rust/sbd-client/src/lib.rs +++ b/rust/sbd-client/src/lib.rs @@ -217,6 +217,8 @@ impl Default for SbdClientConfig { /// SbdClient represents a single connection to a single sbd server /// through which we can communicate with any number of peers on that server. pub struct SbdClient { + url: String, + pub_key: PubKey, send_buf: Arc>, read_task: tokio::task::JoinHandle<()>, write_task: tokio::task::JoinHandle<()>, @@ -234,7 +236,7 @@ impl SbdClient { pub async fn connect( url: &str, crypto: &C, - ) -> Result<(Self, String, PubKey, MsgRecv)> { + ) -> Result<(Self, MsgRecv)> { Self::connect_config(url, crypto, SbdClientConfig::default()).await } @@ -243,7 +245,7 @@ impl SbdClient { url: &str, crypto: &C, config: SbdClientConfig, - ) -> Result<(Self, String, PubKey, MsgRecv)> { + ) -> Result<(Self, MsgRecv)> { use base64::Engine; let full_url = format!( "{url}/{}", @@ -329,18 +331,27 @@ impl SbdClient { send_buf2.lock().await.close().await; }); + let pub_key = PubKey(Arc::new(*crypto.pub_key())); + let this = Self { + url: full_url, + pub_key, send_buf, read_task, write_task, }; - Ok(( - this, - full_url, - PubKey(Arc::new(*crypto.pub_key())), - MsgRecv(recv_recv), - )) + Ok((this, MsgRecv(recv_recv))) + } + + /// The full url of this client. + pub fn url(&self) -> &str { + &self.url + } + + /// The pub key of this client. + pub fn pub_key(&self) -> &PubKey { + &self.pub_key } /// Close the connection. @@ -353,3 +364,6 @@ impl SbdClient { self.send_buf.lock().await.send(peer, data).await } } + +#[cfg(test)] +mod test; diff --git a/rust/sbd-client/src/test.rs b/rust/sbd-client/src/test.rs new file mode 100644 index 0000000..08e359c --- /dev/null +++ b/rust/sbd-client/src/test.rs @@ -0,0 +1,26 @@ +use crate::*; + +#[tokio::test] +async fn drop_sender() { + let config = Arc::new(sbd_server::Config { + bind: vec!["127.0.0.1:0".to_string(), "[::1]:0".to_string()], + ..Default::default() + }); + + let server = sbd_server::SbdServer::new(config).await.unwrap(); + + let (s, mut r) = SbdClient::connect_config( + &format!("ws://{}", server.bind_addrs().get(0).unwrap()), + &DefaultCrypto::default(), + SbdClientConfig { + allow_plain_text: true, + ..Default::default() + }, + ) + .await + .unwrap(); + + drop(s); + + assert!(r.recv().await.is_none()); +} diff --git a/rust/sbd-client/tests/reasonable-rate-limit.rs b/rust/sbd-client/tests/reasonable-rate-limit.rs index 4a793fc..9d3e46a 100644 --- a/rust/sbd-client/tests/reasonable-rate-limit.rs +++ b/rust/sbd-client/tests/reasonable-rate-limit.rs @@ -2,9 +2,7 @@ use sbd_client::*; use sbd_server::*; use std::sync::Arc; -async fn get_client( - addrs: &[std::net::SocketAddr], -) -> (SbdClient, String, sbd_client::PubKey, MsgRecv) { +async fn get_client(addrs: &[std::net::SocketAddr]) -> (SbdClient, MsgRecv) { for addr in addrs { if let Ok(r) = SbdClient::connect_config( &format!("ws://{addr}"), @@ -34,8 +32,10 @@ async fn reasonable_rate_limit() { let server = SbdServer::new(config).await.unwrap(); - let (mut c1, _, p1, mut r1) = get_client(server.bind_addrs()).await; - let (mut c2, _, p2, mut r2) = get_client(server.bind_addrs()).await; + let (mut c1, mut r1) = get_client(server.bind_addrs()).await; + let p1 = c1.pub_key().clone(); + let (mut c2, mut r2) = get_client(server.bind_addrs()).await; + let p2 = c2.pub_key().clone(); //warmup run(2, &mut c1, &p1, &mut r1, &mut c2, &p2, &mut r2).await; diff --git a/rust/sbd-o-bahn-server-tester/src/it.rs b/rust/sbd-o-bahn-server-tester/src/it.rs index 8f825d7..e618ff6 100644 --- a/rust/sbd-o-bahn-server-tester/src/it.rs +++ b/rust/sbd-o-bahn-server-tester/src/it.rs @@ -45,12 +45,7 @@ impl<'h> TestHelper<'h> { /// connect a client pub async fn connect_client( &self, - ) -> Result<( - sbd_client::SbdClient, - String, - sbd_client::PubKey, - sbd_client::MsgRecv, - )> { + ) -> Result<(sbd_client::SbdClient, sbd_client::MsgRecv)> { for addr in self.addr_list.iter() { if let Ok(client) = sbd_client::SbdClient::connect_config( &format!("ws://{addr}"), diff --git a/rust/sbd-o-bahn-server-tester/src/it/it_1.rs b/rust/sbd-o-bahn-server-tester/src/it/it_1.rs index 666d8aa..ad6937e 100644 --- a/rust/sbd-o-bahn-server-tester/src/it/it_1.rs +++ b/rust/sbd-o-bahn-server-tester/src/it/it_1.rs @@ -8,12 +8,15 @@ impl It for It1 { fn exec(helper: &mut TestHelper) -> impl Future> { async { - let ((c1, _u1, p1, mut r1), (c2, _u2, p2, mut r2)) = tokio::try_join!( + let ((c1, mut r1), (c2, mut r2)) = tokio::try_join!( helper.connect_client(), helper.connect_client(), )?; - tokio::try_join!(c1.send(&p2, b"hello"), c2.send(&p1, b"world"),)?; + let p1 = c1.pub_key().clone(); + let p2 = c2.pub_key().clone(); + + tokio::try_join!(c1.send(&p2, b"hello"), c2.send(&p1, b"world"))?; let (result1, result2) = tokio::try_join!( async { r1.recv().await.ok_or(Error::other("closed")) }, diff --git a/rust/sbd-server/src/lib.rs b/rust/sbd-server/src/lib.rs index 0d740ae..55d74af 100644 --- a/rust/sbd-server/src/lib.rs +++ b/rust/sbd-server/src/lib.rs @@ -317,48 +317,42 @@ mod tests { println!("addr: {addr:?}"); - let (client1, url1, pk1, mut rcv1) = - sbd_client::SbdClient::connect_config( - &format!("wss://{addr}"), - &sbd_client::DefaultCrypto::default(), - sbd_client::SbdClientConfig { - allow_plain_text: true, - danger_disable_certificate_check: true, - ..Default::default() - }, - ) - .await - .unwrap(); - - println!("client url1: {url1}"); - - let (client2, url2, pk2, mut rcv2) = - sbd_client::SbdClient::connect_config( - &format!("wss://{addr}"), - &sbd_client::DefaultCrypto::default(), - sbd_client::SbdClientConfig { - allow_plain_text: true, - danger_disable_certificate_check: true, - ..Default::default() - }, - ) - .await - .unwrap(); - - println!("client url2: {url2}"); - - client1.send(&pk2, b"hello").await.unwrap(); + let (client1, mut rcv1) = sbd_client::SbdClient::connect_config( + &format!("wss://{addr}"), + &sbd_client::DefaultCrypto::default(), + sbd_client::SbdClientConfig { + allow_plain_text: true, + danger_disable_certificate_check: true, + ..Default::default() + }, + ) + .await + .unwrap(); + + let (client2, mut rcv2) = sbd_client::SbdClient::connect_config( + &format!("wss://{addr}"), + &sbd_client::DefaultCrypto::default(), + sbd_client::SbdClientConfig { + allow_plain_text: true, + danger_disable_certificate_check: true, + ..Default::default() + }, + ) + .await + .unwrap(); + + client1.send(client2.pub_key(), b"hello").await.unwrap(); let res_data = rcv2.recv().await.unwrap(); - assert_eq!(&pk1[..], res_data.pub_key_ref()); + assert_eq!(&client1.pub_key()[..], res_data.pub_key_ref()); assert_eq!(&b"hello"[..], res_data.message()); - client2.send(&pk1, b"world").await.unwrap(); + client2.send(client1.pub_key(), b"world").await.unwrap(); let res_data = rcv1.recv().await.unwrap(); - assert_eq!(&pk2[..], res_data.pub_key_ref()); + assert_eq!(&client2.pub_key()[..], res_data.pub_key_ref()); assert_eq!(&b"world"[..], res_data.message()); } } From 93e0888eb3806f73697e9630fd07ac8465853fc7 Mon Sep 17 00:00:00 2001 From: David Braden Date: Tue, 7 May 2024 13:11:00 -0600 Subject: [PATCH 32/33] Apply suggestions from code review Co-authored-by: Stefan Junker <1181362+steveej@users.noreply.github.com> --- rust/sbd-server/src/ip_deny.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/rust/sbd-server/src/ip_deny.rs b/rust/sbd-server/src/ip_deny.rs index 7763216..1905bdd 100644 --- a/rust/sbd-server/src/ip_deny.rs +++ b/rust/sbd-server/src/ip_deny.rs @@ -12,12 +12,12 @@ impl IpDeny { /// Check if a given ip is blocked. pub async fn is_blocked(&self, _ip: &Arc) -> bool { - // THIS IS A STUB!! + // TODO: THIS IS A STUB!! false } /// Block a given ip. pub async fn block(&self, _ip: &Arc) { - // THIS IS A STUB!! + // TODO: THIS IS A STUB!! } } From 4dc47e9e3f13615b7b48e1b1c0cd670ba6065142 Mon Sep 17 00:00:00 2001 From: neonphog Date: Tue, 7 May 2024 13:17:40 -0600 Subject: [PATCH 33/33] address code review comments --- rust/sbd-client/src/lib.rs | 6 +++--- rust/sbd-server/src/cmd.rs | 33 +++++++++++++++++++++++---------- 2 files changed, 26 insertions(+), 13 deletions(-) diff --git a/rust/sbd-client/src/lib.rs b/rust/sbd-client/src/lib.rs index c453f75..32711fe 100644 --- a/rust/sbd-client/src/lib.rs +++ b/rust/sbd-client/src/lib.rs @@ -19,6 +19,9 @@ const HDR_SIZE: usize = PK_SIZE; /// defined by sbd spec const NONCE_SIZE: usize = 32; +/// defined by sbd spec +const CMD_PREFIX: &[u8; 28] = &[0; 28]; + const F_LIMIT_BYTE_NANOS: &[u8] = b"lbrt"; const F_LIMIT_IDLE_MILLIS: &[u8] = b"lidl"; const F_AUTH_REQ: &[u8] = b"areq"; @@ -98,9 +101,6 @@ impl std::fmt::Debug for PubKey { } } -/// defined by sbd spec -const CMD_PREFIX: &[u8; 28] = &[0; 28]; - enum MsgType<'t> { Msg { #[allow(dead_code)] diff --git a/rust/sbd-server/src/cmd.rs b/rust/sbd-server/src/cmd.rs index 6108f66..51dbfcb 100644 --- a/rust/sbd-server/src/cmd.rs +++ b/rust/sbd-server/src/cmd.rs @@ -6,6 +6,21 @@ const F_LIMIT_IDLE_MILLIS: &[u8] = b"lidl"; const F_AUTH_REQ: &[u8] = b"areq"; const F_AUTH_RES: &[u8] = b"ares"; +/// defined by ed25519 spec +const PK_SIZE: usize = 32; + +/// defined by ed25519 spec +const SIG_SIZE: usize = 64; + +/// sbd spec defines headers to be the same size as ed25519 pub keys +const HDR_SIZE: usize = PK_SIZE; + +/// defined by sbd spec +const NONCE_SIZE: usize = 32; + +/// defined by sbd spec +pub(crate) const CMD_PREFIX: &[u8; 28] = &[0; 28]; + /// Sbd commands. /// Enum variants represent only the types that clients can send to the server: /// - not-cmd Message(payload) @@ -21,15 +36,13 @@ const F_AUTH_RES: &[u8] = b"ares"; pub enum SbdCmd<'c> { Message(Payload<'c>), Keepalive, - AuthRes([u8; 64]), + AuthRes([u8; SIG_SIZE]), Unknown, } -pub(crate) const CMD_PREFIX: &[u8; 28] = &[0; 28]; - impl<'c> SbdCmd<'c> { pub fn parse(payload: Payload<'c>) -> Result { - if payload.as_ref().len() < 32 { + if payload.as_ref().len() < HDR_SIZE { return Err(Error::other("invalid payload length")); } if &payload.as_ref()[..28] == CMD_PREFIX { @@ -38,11 +51,11 @@ impl<'c> SbdCmd<'c> { match &payload.as_ref()[28..32] { F_KEEPALIVE => Ok(SbdCmd::Keepalive), F_AUTH_RES => { - if payload.as_ref().len() != 32 + 64 { + if payload.as_ref().len() != HDR_SIZE + SIG_SIZE { return Err(Error::other("invalid auth res length")); } - let mut sig = [0; 64]; - sig.copy_from_slice(&payload.as_ref()[32..]); + let mut sig = [0; SIG_SIZE]; + sig.copy_from_slice(&payload.as_ref()[HDR_SIZE..]); Ok(SbdCmd::AuthRes(sig)) } _ => Ok(SbdCmd::Unknown), @@ -55,7 +68,7 @@ impl<'c> SbdCmd<'c> { impl SbdCmd<'_> { pub fn limit_byte_nanos(limit_byte_nanos: i32) -> Payload<'static> { - let mut out = Vec::with_capacity(32 + 4); + let mut out = Vec::with_capacity(HDR_SIZE + 4); let n = limit_byte_nanos.to_be_bytes(); out.extend_from_slice(CMD_PREFIX); out.extend_from_slice(F_LIMIT_BYTE_NANOS); @@ -64,7 +77,7 @@ impl SbdCmd<'_> { } pub fn limit_idle_millis(limit_idle_millis: i32) -> Payload<'static> { - let mut out = Vec::with_capacity(32 + 4); + let mut out = Vec::with_capacity(HDR_SIZE + 4); let n = limit_idle_millis.to_be_bytes(); out.extend_from_slice(CMD_PREFIX); out.extend_from_slice(F_LIMIT_IDLE_MILLIS); @@ -73,7 +86,7 @@ impl SbdCmd<'_> { } pub fn auth_req(nonce: &[u8; 32]) -> Payload<'static> { - let mut out = Vec::with_capacity(32 + 32); + let mut out = Vec::with_capacity(HDR_SIZE + NONCE_SIZE); out.extend_from_slice(CMD_PREFIX); out.extend_from_slice(F_AUTH_REQ); out.extend_from_slice(&nonce[..]);