diff --git a/.gitignore b/.gitignore index cc2dd72c..e452ae3f 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,4 @@ Cargo.lock .DS_Store /*.yaml /generated +*.rd.yml diff --git a/examples/tproxy/setup.sh b/examples/tproxy/setup.sh index 74086406..e989010e 100755 --- a/examples/tproxy/setup.sh +++ b/examples/tproxy/setup.sh @@ -12,6 +12,7 @@ RD_DISABLE_IPV6="${RD_DISABLE_IPV6:=0}" RD_ENABLE_SELF="${RD_ENABLE_SELF:=0}" RD_EXCLUDE_IP="$RD_EXCLUDE_IP" RD_EXCLUDE_MAC="$RD_EXCLUDE_MAC" +RD_HIJACK_DNS="${RD_HIJACK_DNS:=1}" if [ "$(id -u)" != "0" ]; then echo "This script must be run as root" 1>&2 @@ -27,7 +28,7 @@ if [ "$RD_DISABLE_IPV6" != "1" ]; then ip -6 rule add fwmark $RD_FW_MARK table $RD_TABLE ip6tables -t mangle -N RD_OUTPUT - if [ "$RD_ENABLE_SELF" == "1" ]; then + if [ "$RD_ENABLE_SELF" = "1" ]; then ip6tables -t mangle -A RD_OUTPUT -d ::1/128 -j RETURN ip6tables -t mangle -A RD_OUTPUT -d fc00::/7 -j RETURN ip6tables -t mangle -A RD_OUTPUT -d fe80::/10 -j RETURN @@ -49,7 +50,9 @@ if [ "$RD_DISABLE_IPV6" != "1" ]; then ip6tables -t mangle -A RD_PREROUTING -m mac --mac-source $i -j RETURN 2>/dev/null || true done ip6tables -t mangle -A RD_PREROUTING -m mark --mark $RD_MARK -j RETURN - ip6tables -t mangle -A RD_PREROUTING -p udp --dport 53 -j TPROXY --on-port $RD_PORT6 --tproxy-mark $RD_FW_MARK + if [ "$RD_HIJACK_DNS" = "1" ]; then + ip6tables -t mangle -A RD_PREROUTING -p udp --dport 53 -j TPROXY --on-port $RD_PORT6 --tproxy-mark $RD_FW_MARK + fi ip6tables -t mangle -A RD_PREROUTING -d ::1/128 -j RETURN ip6tables -t mangle -A RD_PREROUTING -d fc00::/7 -j RETURN ip6tables -t mangle -A RD_PREROUTING -d fe80::/10 -j RETURN @@ -61,7 +64,7 @@ if [ "$RD_DISABLE_IPV6" != "1" ]; then fi iptables -t mangle -N RD_OUTPUT -if [ "$RD_ENABLE_SELF" == "1" ]; then +if [ "$RD_ENABLE_SELF" = "1" ]; then iptables -t mangle -A RD_OUTPUT -d 0/8 -j RETURN iptables -t mangle -A RD_OUTPUT -d 127/8 -j RETURN iptables -t mangle -A RD_OUTPUT -d 10/8 -j RETURN @@ -88,7 +91,9 @@ for i in $(echo $RD_EXCLUDE_MAC | tr "," "\n"); do iptables -t mangle -A RD_PREROUTING -m mac --mac-source $i -j RETURN 2>/dev/null || true done iptables -t mangle -A RD_PREROUTING -m mark --mark $RD_MARK -j RETURN -iptables -t mangle -A RD_PREROUTING -p udp --dport 53 -j TPROXY --on-port $RD_PORT --tproxy-mark $RD_FW_MARK +if [ "$RD_HIJACK_DNS" = "1" ]; then + iptables -t mangle -A RD_PREROUTING -p udp --dport 53 -j TPROXY --on-port $RD_PORT --tproxy-mark $RD_FW_MARK +fi iptables -t mangle -A RD_PREROUTING -d 0/8 -j RETURN iptables -t mangle -A RD_PREROUTING -d 127/8 -j RETURN iptables -t mangle -A RD_PREROUTING -d 10/8 -j RETURN diff --git a/protocol/raw/src/net.rs b/protocol/raw/src/net.rs index f529f55d..d7cc2b99 100644 --- a/protocol/raw/src/net.rs +++ b/protocol/raw/src/net.rs @@ -104,7 +104,7 @@ impl INet for RawNet { _ctx: &mut Context, addr: &Address, ) -> Result { - let udp = UdpSocketWrap(self.net.udp_bind(addr.to_socket_addr()?).await?); + let udp = UdpSocketWrap::new(self.net.udp_bind(addr.to_socket_addr()?).await?); Ok(udp.into_dyn()) } } diff --git a/protocol/raw/src/server.rs b/protocol/raw/src/server.rs index 5d3a010b..73ed2783 100644 --- a/protocol/raw/src/server.rs +++ b/protocol/raw/src/server.rs @@ -1,16 +1,24 @@ use std::{ + io, net::{SocketAddr, SocketAddrV4}, + pin::Pin, str::FromStr, - time::Duration, + task, }; -use futures::{future::ready, StreamExt}; -use lru_time_cache::LruCache; +use crate::{ + device, + gateway::{GatewayInterface, MapTable}, +}; +use futures::{ready, Sink, Stream, StreamExt}; use rd_interface::{ - async_trait, constant::UDP_BUFFER_SIZE, error::map_other, prelude::*, registry::ServerFactory, - Address, Context, Error, IServer, IntoAddress, Net, Result, + async_trait, error::map_other, prelude::*, registry::ServerFactory, Bytes, Context, Error, + IServer, IntoAddress, Net, Result, +}; +use rd_std::util::{ + connect_tcp, + forward_udp::{self, RawUdpSource}, }; -use rd_std::util::connect_tcp; use smoltcp::{ phy::{Checksum, ChecksumCapabilities, Medium}, wire::{ @@ -18,21 +26,12 @@ use smoltcp::{ Ipv4Address, Ipv4Packet, Ipv4Repr, UdpPacket, UdpRepr, }, }; -use tokio::{ - select, spawn, - sync::mpsc::{unbounded_channel, UnboundedSender as Sender}, - time::timeout, -}; +use tokio::{select, spawn}; use tokio_smoltcp::{ device::{FutureDevice, Interface, Packet}, BufferSize, NetConfig, RawSocket, TcpListener, }; -use crate::{ - device, - gateway::{GatewayInterface, MapTable}, -}; - #[rd_config] #[derive(Clone, Copy)] pub enum Layer { @@ -145,7 +144,9 @@ impl RawServer { }; let device = GatewayInterface::new( - dev.filter(move |p: &Packet| ready(filter_packet(p, ethernet_addr, ip_addr, layer))), + dev.filter(move |p: &Packet| { + std::future::ready(filter_packet(p, ethernet_addr, ip_addr, layer)) + }), lru_size, SocketAddrV4::new(addr.into(), 20000), layer, @@ -216,61 +217,107 @@ impl RawServer { } } async fn serve_udp(&self, raw: RawSocket) -> Result<()> { - let (send_raw, mut send_rx) = unbounded_channel::<(SocketAddr, SocketAddr, Vec)>(); + let source = Source::new(raw); - let mut buf = [0u8; UDP_BUFFER_SIZE]; - let mut nat = LruCache::::with_expiry_duration_and_capacity( - Duration::from_secs(30), - 128, - ); - let net = self.net.clone(); - - let recv = async { - loop { - let size = raw.recv(&mut buf).await?; - let (src, dst, payload) = match parse_udp(&buf[..size]) { - Ok(v) => v, - _ => break, - }; - - let udp = nat - .entry(src) - .or_insert_with(|| UdpTunnel::new(net.clone(), src, send_raw.clone())); - if let Err(e) = udp.send_to(payload, dst).await { - tracing::error!("Udp send_to {:?}", e); - nat.remove(&src); - } - } + forward_udp::forward_udp(source, self.net.clone()).await?; - Ok(()) as Result<()> - }; + Ok(()) + } +} - let send = async { - while let Some((src, dst, payload)) = send_rx.recv().await { - if let Some(ip_packet) = pack_udp(src, dst, &payload) { - if let Err(e) = raw.send(&ip_packet).await { - tracing::error!( - "Raw send error: {:?}, dropping udp size: {}", - e, - ip_packet.len() - ); - } - } else { - tracing::debug!("Unsupported src/dst"); - } - } - Ok(()) as Result<()> +struct Source { + raw: RawSocket, + recv_buf: Box<[u8]>, + send_buf: Option>, +} + +impl Source { + pub fn new(raw: RawSocket) -> Source { + Source { + raw, + recv_buf: Box::new([0u8; 65536]), + send_buf: None, + } + } +} + +impl Stream for Source { + type Item = io::Result; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + ) -> task::Poll> { + let Source { raw, recv_buf, .. } = &mut *self; + + let (from, to, data) = loop { + let size = ready!(raw.poll_recv(cx, recv_buf))?; + + match parse_udp(&recv_buf[..size]) { + Ok(v) => break v, + _ => {} + }; }; - select! { - r = send => r?, - r = recv => r?, + let data = Bytes::copy_from_slice(data); + + Some(Ok(forward_udp::UdpPacket { from, to, data })).into() + } +} + +impl Sink for Source { + type Error = io::Error; + + fn poll_ready( + self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + ) -> task::Poll> { + if self.send_buf.is_some() { + return self.poll_flush(cx); } + Ok(()).into() + } + + fn start_send( + mut self: Pin<&mut Self>, + forward_udp::UdpPacket { from, to, data }: forward_udp::UdpPacket, + ) -> Result<(), Self::Error> { + if let Some(ip_packet) = pack_udp(from, to, &data) { + self.send_buf = Some(ip_packet); + } else { + tracing::debug!("Unsupported src/dst"); + } Ok(()) } + + fn poll_flush( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + ) -> task::Poll> { + let Source { raw, send_buf, .. } = &mut *self; + + match send_buf { + Some(buf) => { + ready!(raw.poll_send(cx, buf))?; + *send_buf = None; + } + None => {} + } + + Ok(()).into() + } + + fn poll_close( + self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + ) -> task::Poll> { + self.poll_flush(cx) + } } +impl RawUdpSource for Source {} + /// buf is a ip packet fn parse_udp(buf: &[u8]) -> smoltcp::Result<(SocketAddr, SocketAddr, &[u8])> { let ipv4 = Ipv4Packet::new_checked(buf)?; @@ -354,62 +401,3 @@ impl ServerFactory for RawServer { RawServer::new(net, config) } } - -struct UdpTunnel { - tx: Sender<(SocketAddr, Vec)>, -} - -impl UdpTunnel { - fn new( - net: Net, - src: SocketAddr, - send_raw: Sender<(SocketAddr, SocketAddr, Vec)>, - ) -> UdpTunnel { - let (tx, mut rx) = unbounded_channel::<(SocketAddr, Vec)>(); - tokio::spawn(async move { - let udp = timeout( - Duration::from_secs(5), - net.udp_bind( - &mut Context::from_socketaddr(src), - &Address::any_addr_port(&src), - ), - ) - .await - .map_err(map_other)??; - - let send = async { - while let Some((addr, packet)) = rx.recv().await { - udp.send_to(&packet, addr.into()).await?; - } - Ok(()) - }; - let recv = async { - let mut buf = [0u8; UDP_BUFFER_SIZE]; - loop { - let (size, addr) = udp.recv_from(&mut buf).await?; - - if send_raw.send((addr, src, buf[..size].to_vec())).is_err() { - break; - } - } - tracing::trace!("send_raw return error"); - Ok(()) - }; - - let r: Result<()> = select! { - r = send => r, - r = recv => r, - }; - - r - }); - UdpTunnel { tx } - } - /// return false if the send queue is full - async fn send_to(&self, buf: &[u8], addr: SocketAddr) -> Result<()> { - match self.tx.send((addr, buf.to_vec())) { - Ok(_) => Ok(()), - Err(_) => Err(Error::Other("Other side closed".into())), - } - } -} diff --git a/protocol/raw/src/wrap.rs b/protocol/raw/src/wrap.rs index fea1caca..0dc71b2d 100644 --- a/protocol/raw/src/wrap.rs +++ b/protocol/raw/src/wrap.rs @@ -1,8 +1,9 @@ -use std::net::SocketAddr; +use std::{io, net::SocketAddr, pin::Pin, task}; +use futures::{ready, Sink, Stream}; use rd_interface::{ - async_trait, impl_async_read_write, Address, ITcpListener, ITcpStream, IUdpSocket, IntoDyn, - Result, + async_trait, constant::UDP_BUFFER_SIZE, impl_async_read_write, Bytes, BytesMut, ITcpListener, + ITcpStream, IUdpSocket, IntoDyn, Result, }; use tokio::sync::Mutex; use tokio_smoltcp::{TcpListener, TcpSocket, UdpSocket}; @@ -35,19 +36,93 @@ impl ITcpListener for TcpListenerWrap { } } -pub struct UdpSocketWrap(pub(crate) UdpSocket); +pub struct UdpSocketWrap { + inner: UdpSocket, + recv_buf: Box<[u8]>, + send_buf: Option<(Bytes, SocketAddr)>, +} -#[async_trait] -impl IUdpSocket for UdpSocketWrap { - async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, SocketAddr)> { - Ok(self.0.recv_from(buf).await?) +impl UdpSocketWrap { + pub(crate) fn new(inner: UdpSocket) -> Self { + Self { + inner, + recv_buf: vec![0; UDP_BUFFER_SIZE].into_boxed_slice(), + send_buf: None, + } + } +} + +impl Stream for UdpSocketWrap { + type Item = io::Result<(BytesMut, SocketAddr)>; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + ) -> task::Poll> { + let UdpSocketWrap { + inner, + recv_buf: buf, + .. + } = &mut *self; + let (size, from) = ready!(inner.poll_recv_from(cx, buf))?; + let buf = BytesMut::from(&buf[..size]); + Some(Ok((buf, from))).into() + } +} + +impl Sink<(Bytes, SocketAddr)> for UdpSocketWrap { + type Error = io::Error; + + fn poll_ready( + self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + ) -> task::Poll> { + if self.send_buf.is_some() { + ready!(self.poll_flush(cx))?; + } + Ok(()).into() + } + + fn start_send(mut self: Pin<&mut Self>, item: (Bytes, SocketAddr)) -> Result<(), Self::Error> { + self.send_buf = Some(item); + + Ok(()) + } + + fn poll_flush( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + ) -> task::Poll> { + let UdpSocketWrap { + inner, + send_buf: buf, + .. + } = &mut *self; + if let Some((buf, to)) = buf { + let size = ready!(inner.poll_send_to(cx, buf, *to))?; + if size != buf.len() { + return Err(io::Error::new( + io::ErrorKind::Other, + "failed to send all bytes", + )) + .into(); + } + } + Ok(()).into() } - async fn send_to(&self, buf: &[u8], addr: Address) -> Result { - Ok(self.0.send_to(buf, addr.to_socket_addr()?).await?) + fn poll_close( + self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + ) -> task::Poll> { + ready!(self.poll_flush(cx))?; + Ok(()).into() } +} +#[async_trait] +impl IUdpSocket for UdpSocketWrap { async fn local_addr(&self) -> Result { - Ok(self.0.local_addr()?) + Ok(self.inner.local_addr()?) } } diff --git a/protocol/ss/Cargo.toml b/protocol/ss/Cargo.toml index c3049a00..0f8a767b 100644 --- a/protocol/ss/Cargo.toml +++ b/protocol/ss/Cargo.toml @@ -8,11 +8,15 @@ edition = "2018" rd-interface = { path = "../../rabbit-digger/rd-interface/", version = "0.4" } rd-std = { path = "../../rabbit-digger/rd-std/", version = "0.1" } # rd-interface = "0.3" -shadowsocks = { version = "1.10.3", default-features = false, features = [ "stream-cipher", "aead-cipher-extra" ] } +shadowsocks = { version = "1.12.3", default-features = false, features = [ + "stream-cipher", + "aead-cipher-extra", +] } serde = "1.0" bytes = "1.0" tracing = "0.1.26" byte_string = "1.0" serde_json = "1.0" tokio = { version = "1.5.0", features = ["rt"] } -socks5-protocol = "0.3" +socks5-protocol = "0.3.5" +futures = "0.3" diff --git a/protocol/ss/src/client.rs b/protocol/ss/src/client.rs index 6fea6d0e..cba9a155 100644 --- a/protocol/ss/src/client.rs +++ b/protocol/ss/src/client.rs @@ -1,3 +1,5 @@ +use std::io; + use super::wrapper::{Cipher, WrapAddress, WrapSSTcp, WrapSSUdp}; use rd_interface::{ async_trait, prelude::*, registry::NetRef, Address, INet, IntoDyn, Net, Result, TcpStream, @@ -8,7 +10,6 @@ use shadowsocks::{ context::{Context, SharedContext}, ProxyClientStream, }; -use tokio::sync::OnceCell; #[rd_config] #[derive(Debug, Clone)] @@ -25,7 +26,7 @@ pub struct SSNetConfig { } pub struct SSNet { - context: OnceCell, + context: SharedContext, cfg: ServerConfig, addr: Address, udp: bool, @@ -35,7 +36,7 @@ pub struct SSNet { impl SSNet { pub fn new(config: SSNetConfig) -> SSNet { SSNet { - context: OnceCell::new(), + context: Context::new_shared(ServerType::Local), addr: config.server.clone(), cfg: ServerConfig::new( (config.server.host(), config.server.port()), @@ -46,12 +47,6 @@ impl SSNet { net: (*config.net).clone(), } } - pub async fn context(&self) -> SharedContext { - self.context - .get_or_init(|| async { Context::new_shared(ServerType::Local) }) - .await - .clone() - } } #[async_trait] @@ -63,7 +58,7 @@ impl INet for SSNet { ) -> Result { let stream = self.net.tcp_connect(ctx, &self.addr).await?; let client = ProxyClientStream::from_stream( - self.context().await, + self.context.clone(), stream, &self.cfg, WrapAddress(addr.clone()), @@ -75,8 +70,18 @@ impl INet for SSNet { if !self.udp { return Err(NOT_ENABLED); } + let server_addr = self + .net + .lookup_host(&self.addr) + .await? + .into_iter() + .next() + .ok_or(io::Error::new( + io::ErrorKind::AddrNotAvailable, + "Failed to lookup domain", + ))?; let socket = self.net.udp_bind(ctx, &addr.to_any_addr_port()?).await?; - let udp = WrapSSUdp::new(self.context().await, socket, &self.cfg); + let udp = WrapSSUdp::new(socket, &self.cfg, server_addr); Ok(udp.into_dyn()) } } diff --git a/protocol/ss/src/udp.rs b/protocol/ss/src/udp.rs index a0bf9fca..6f7a4547 100644 --- a/protocol/ss/src/udp.rs +++ b/protocol/ss/src/udp.rs @@ -25,42 +25,45 @@ use std::io::{self, Cursor, ErrorKind}; use byte_string::ByteStr; use bytes::{BufMut, BytesMut}; -use shadowsocks::{ - context::Context, - crypto::v1::{random_iv_or_salt, Cipher, CipherCategory, CipherKind}, - relay::socks5::Address, -}; +use shadowsocks::crypto::v1::{random_iv_or_salt, Cipher, CipherCategory, CipherKind}; +use socks5_protocol::{sync::FromIO, Address}; + +#[must_use] +fn write_to_buf<'a>(addr: &Address, buf: &'a mut BytesMut) -> io::Result<()> { + let mut writer = buf.writer(); + addr.write_to(&mut writer).map_err(|e| e.to_io_err())?; + + Ok(()) +} /// Encrypt payload into ShadowSocks UDP encrypted packet pub fn encrypt_payload( - context: &Context, method: CipherKind, key: &[u8], addr: &Address, payload: &[u8], dst: &mut BytesMut, -) { - match method.category() { +) -> io::Result<()> { + Ok(match method.category() { CipherCategory::None => { - dst.reserve(addr.serialized_len() + payload.len()); - addr.write_to_buf(dst); + dst.reserve(addr.serialized_len().map_err(|e| e.to_io_err())? + payload.len()); + write_to_buf(addr, dst)?; dst.put_slice(payload); } - CipherCategory::Stream => encrypt_payload_stream(context, method, key, addr, payload, dst), - CipherCategory::Aead => encrypt_payload_aead(context, method, key, addr, payload, dst), - } + CipherCategory::Stream => encrypt_payload_stream(method, key, addr, payload, dst)?, + CipherCategory::Aead => encrypt_payload_aead(method, key, addr, payload, dst)?, + }) } fn encrypt_payload_stream( - context: &Context, method: CipherKind, key: &[u8], addr: &Address, payload: &[u8], dst: &mut BytesMut, -) { +) -> io::Result<()> { let iv_len = method.iv_len(); - let addr_len = addr.serialized_len(); + let addr_len = addr.serialized_len().map_err(|e| e.to_io_err())?; // Packet = IV + ADDRESS + PAYLOAD dst.reserve(iv_len + addr_len + payload.len()); @@ -70,36 +73,30 @@ fn encrypt_payload_stream( let iv = &mut dst[..iv_len]; if iv_len > 0 { - loop { - random_iv_or_salt(iv); - if !context.check_nonce_and_set(iv) { - break; - } - } + random_iv_or_salt(iv); tracing::trace!("UDP packet generated stream iv {:?}", ByteStr::new(iv)); - } else { - context.check_nonce_and_set(iv); } let mut cipher = Cipher::new(method, key, &iv); - addr.write_to_buf(dst); + write_to_buf(addr, dst)?; dst.put_slice(payload); let m = &mut dst[iv_len..]; cipher.encrypt_packet(m); + + Ok(()) } fn encrypt_payload_aead( - context: &Context, method: CipherKind, key: &[u8], addr: &Address, payload: &[u8], dst: &mut BytesMut, -) { +) -> io::Result<()> { let salt_len = method.salt_len(); - let addr_len = addr.serialized_len(); + let addr_len = addr.serialized_len().map_err(|e| e.to_io_err())?; // Packet = IV + ADDRESS + PAYLOAD + TAG dst.reserve(salt_len + addr_len + payload.len() + method.tag_len()); @@ -109,21 +106,14 @@ fn encrypt_payload_aead( let salt = &mut dst[..salt_len]; if salt_len > 0 { - loop { - random_iv_or_salt(salt); - if !context.check_nonce_and_set(salt) { - break; - } - } + random_iv_or_salt(salt); tracing::trace!("UDP packet generated aead salt {:?}", ByteStr::new(salt)); - } else { - context.check_nonce_and_set(salt); } let mut cipher = Cipher::new(method, key, salt); - addr.write_to_buf(dst); + write_to_buf(addr, dst)?; dst.put_slice(payload); unsafe { @@ -132,11 +122,12 @@ fn encrypt_payload_aead( let m = &mut dst[salt_len..]; cipher.encrypt_packet(m); + + Ok(()) } /// Decrypt payload from ShadowSocks UDP encrypted packet -pub async fn decrypt_payload( - context: &Context, +pub fn decrypt_payload( method: CipherKind, key: &[u8], payload: &mut [u8], @@ -144,7 +135,7 @@ pub async fn decrypt_payload( match method.category() { CipherCategory::None => { let mut cur = Cursor::new(payload); - match Address::read_from(&mut cur).await { + match Address::read_from(&mut cur) { Ok(address) => { let pos = cur.position() as usize; let payload = cur.into_inner(); @@ -158,13 +149,12 @@ pub async fn decrypt_payload( } } } - CipherCategory::Stream => decrypt_payload_stream(context, method, key, payload).await, - CipherCategory::Aead => decrypt_payload_aead(context, method, key, payload).await, + CipherCategory::Stream => decrypt_payload_stream(method, key, payload), + CipherCategory::Aead => decrypt_payload_aead(method, key, payload), } } -async fn decrypt_payload_stream( - context: &Context, +fn decrypt_payload_stream( method: CipherKind, key: &[u8], payload: &mut [u8], @@ -178,17 +168,13 @@ async fn decrypt_payload_stream( } let (iv, data) = payload.split_at_mut(iv_len); - if context.check_nonce_and_set(iv) { - tracing::debug!("detected repeated iv {:?}", ByteStr::new(iv)); - return Err(io::Error::new(io::ErrorKind::Other, "detected repeated iv")); - } tracing::trace!("UDP packet got stream IV {:?}", ByteStr::new(iv)); let mut cipher = Cipher::new(method, key, iv); assert!(cipher.decrypt_packet(data)); - let (dn, addr) = parse_packet(data).await?; + let (dn, addr) = parse_packet(data)?; let data_start_idx = iv_len + dn; let data_length = payload.len() - data_start_idx; @@ -197,8 +183,7 @@ async fn decrypt_payload_stream( Ok((data_length, addr)) } -async fn decrypt_payload_aead( - context: &Context, +fn decrypt_payload_aead( method: CipherKind, key: &[u8], payload: &mut [u8], @@ -211,13 +196,6 @@ async fn decrypt_payload_aead( } let (salt, data) = payload.split_at_mut(salt_len); - if context.check_nonce_and_set(salt) { - tracing::debug!("detected repeated salt {:?}", ByteStr::new(salt)); - return Err(io::Error::new( - io::ErrorKind::Other, - "detected repeated salt", - )); - } tracing::trace!("UDP packet got AEAD salt {:?}", ByteStr::new(salt)); @@ -239,7 +217,7 @@ async fn decrypt_payload_aead( let data_len = data.len() - tag_len; let data = &mut data[..data_len]; - let (dn, addr) = parse_packet(data).await?; + let (dn, addr) = parse_packet(data)?; let data_length = data_len - dn; let data_start_idx = salt_len + dn; @@ -250,9 +228,9 @@ async fn decrypt_payload_aead( Ok((data_length, addr)) } -async fn parse_packet(buf: &[u8]) -> io::Result<(usize, Address)> { +fn parse_packet(buf: &[u8]) -> io::Result<(usize, Address)> { let mut cur = Cursor::new(buf); - match Address::read_from(&mut cur).await { + match Address::read_from(&mut cur) { Ok(address) => { let pos = cur.position() as usize; Ok((pos, address)) diff --git a/protocol/ss/src/wrapper.rs b/protocol/ss/src/wrapper.rs index c45e1419..c9712550 100644 --- a/protocol/ss/src/wrapper.rs +++ b/protocol/ss/src/wrapper.rs @@ -1,15 +1,17 @@ use crate::udp::{decrypt_payload, encrypt_payload}; -use bytes::BytesMut; +use bytes::{Bytes, BytesMut}; +use futures::{ready, Sink, SinkExt, StreamExt}; use rd_interface::{ async_trait, impl_async_read_write, prelude::*, Address as RDAddress, AsyncRead, AsyncWrite, - ITcpStream, IUdpSocket, ReadBuf, TcpStream, UdpSocket, NOT_IMPLEMENTED, + ITcpStream, IUdpSocket, ReadBuf, Stream, TcpStream, UdpSocket, NOT_IMPLEMENTED, }; use shadowsocks::{ context::SharedContext, crypto::v1::CipherKind, relay::{socks5::Address as SSAddress, tcprelay::crypto_io}, - ProxyClientStream, ServerAddr, ServerConfig, + ProxyClientStream, ServerConfig, }; +use socks5_protocol::Address as S5Addr; use std::{io, net::SocketAddr, pin::Pin, str::FromStr, task}; pub struct WrapAddress(pub RDAddress); @@ -149,24 +151,18 @@ impl ITcpStream for WrapSSTcp { } pub struct WrapSSUdp { - context: SharedContext, socket: UdpSocket, method: CipherKind, key: Box<[u8]>, - server_addr: RDAddress, + server_addr: SocketAddr, } impl WrapSSUdp { - pub fn new(context: SharedContext, socket: UdpSocket, svr_cfg: &ServerConfig) -> Self { + pub fn new(socket: UdpSocket, svr_cfg: &ServerConfig, server_addr: SocketAddr) -> Self { let key = svr_cfg.key().to_vec().into_boxed_slice(); let method = svr_cfg.method(); - let server_addr = match svr_cfg.addr().clone() { - ServerAddr::DomainName(d, p) => RDAddress::Domain(d, p), - ServerAddr::SocketAddr(s) => RDAddress::SocketAddr(s), - }; WrapSSUdp { - context, socket, method, key, @@ -175,58 +171,73 @@ impl WrapSSUdp { } } -#[async_trait] -impl IUdpSocket for WrapSSUdp { - async fn local_addr(&self) -> rd_interface::Result { - Err(NOT_IMPLEMENTED) - } +impl Stream for WrapSSUdp { + type Item = io::Result<(BytesMut, SocketAddr)>; - async fn recv_from(&self, recv_buf: &mut [u8]) -> rd_interface::Result<(usize, SocketAddr)> { + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + ) -> task::Poll> { // Waiting for response from server SERVER -> CLIENT - let (recv_n, _target_addr) = self.socket.recv_from(recv_buf).await?; - let (n, addr) = decrypt_payload( - &self.context, - self.method, - &self.key, - &mut recv_buf[..recv_n], - ) - .await?; + let (mut recv_buf, _target_addr) = match ready!(self.socket.poll_next_unpin(cx)) { + Some(r) => r?, + None => return task::Poll::Ready(None), + }; + let (n, addr) = decrypt_payload(self.method, &self.key, &mut recv_buf[..])?; - Ok(( - n, + Some(Ok(( + recv_buf.split_to(n), match addr { - SSAddress::DomainNameAddress(_, _) => unreachable!("Udp recv_from domain name"), - SSAddress::SocketAddress(s) => s, + S5Addr::Domain(_, _) => unreachable!("Udp recv_from domain name"), + S5Addr::SocketAddr(s) => s, }, - )) + ))) + .into() + } +} + +impl Sink<(Bytes, SocketAddr)> for WrapSSUdp { + type Error = io::Error; + + fn poll_ready( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + ) -> task::Poll> { + self.socket.poll_ready_unpin(cx) } - async fn send_to(&self, payload: &[u8], target: RDAddress) -> rd_interface::Result { + fn start_send( + mut self: Pin<&mut Self>, + (payload, target): (Bytes, SocketAddr), + ) -> Result<(), Self::Error> { let mut send_buf = BytesMut::new(); - let addr: SSAddress = WrapAddress(target).into(); - encrypt_payload( - &self.context, - self.method, - &self.key, - &addr, - payload, - &mut send_buf, - ); + let addr: S5Addr = target.into(); + encrypt_payload(self.method, &self.key, &addr, &payload, &mut send_buf)?; - let send_len = self - .socket - .send_to(&send_buf, self.server_addr.clone()) - .await?; + let server_addr = self.server_addr; + self.socket + .start_send_unpin((send_buf.freeze(), server_addr)) + } - if send_buf.len() != send_len { - tracing::warn!( - "UDP server client send {} bytes, but actually sent {} bytes", - send_buf.len(), - send_len - ); - } + fn poll_flush( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + ) -> task::Poll> { + self.socket.poll_flush_unpin(cx) + } + + fn poll_close( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + ) -> task::Poll> { + self.socket.poll_close_unpin(cx) + } +} - Ok(send_len) +#[async_trait] +impl IUdpSocket for WrapSSUdp { + async fn local_addr(&self) -> rd_interface::Result { + Err(NOT_IMPLEMENTED) } } diff --git a/protocol/trojan/Cargo.toml b/protocol/trojan/Cargo.toml index 5974be04..769e188c 100644 --- a/protocol/trojan/Cargo.toml +++ b/protocol/trojan/Cargo.toml @@ -14,15 +14,23 @@ socks5-protocol = "0.3.4" futures = "0.3" tokio = "1.0" tokio-tungstenite = "0.15.0" +tokio-util = { version = "0.6.6", features = ["codec", "net"] } +bytes = "1.1.0" -tokio-rustls = { version = "0.22.0", features = ["dangerous_configuration"], optional = true } +tokio-rustls = { version = "0.22.0", features = [ + "dangerous_configuration", +], optional = true } webpki-roots = { version = "0.21.1", optional = true } -openssl-crate = { package = "openssl", version = "0.10", features = ["vendored"], optional = true } +openssl-crate = { package = "openssl", version = "0.10", features = [ + "vendored", +], optional = true } tokio-openssl = { version = "0.6.1", optional = true } tokio-native-tls = { version = "0.3.0", optional = true } -native-tls-crate = { package = "native-tls", version = "0.2", features = ["vendored"], optional = true } +native-tls-crate = { package = "native-tls", version = "0.2", features = [ + "vendored", +], optional = true } [features] default = ["native-tls"] diff --git a/protocol/trojan/src/client/udp.rs b/protocol/trojan/src/client/udp.rs index 6d2d07de..4e2bb93f 100644 --- a/protocol/trojan/src/client/udp.rs +++ b/protocol/trojan/src/client/udp.rs @@ -1,90 +1,112 @@ -use std::{ - io::{Cursor, Write}, - mem::take, - net::SocketAddr, - sync::RwLock, -}; +use std::{io, net::SocketAddr}; -use super::ra2sa; use crate::stream::IOStream; -use rd_interface::{async_trait, Address as RDAddress, IUdpSocket, Result, NOT_IMPLEMENTED}; -use socks5_protocol::{sync::FromIO, Address as S5Addr}; -use tokio::{ - io::{self, split, AsyncReadExt, ReadHalf, WriteHalf}, - sync::Mutex, +use bytes::{Buf, BufMut}; +use rd_interface::{ + async_trait, impl_stream_sink, Bytes, BytesMut, IUdpSocket, Result, NOT_IMPLEMENTED, }; +use socks5_protocol::{sync::FromIO, Address as S5Addr}; +use tokio_util::codec::{Decoder, Encoder, Framed}; -pub(super) struct TrojanUdp { - read: Mutex>>, - write: Mutex>>, - head: RwLock>, -} +// limited by 2-bytes header +const UDP_MAX_SIZE: usize = 65535; -impl TrojanUdp { - pub fn new(stream: Box, head: Vec) -> Self { - let (read, write) = split(stream); - Self { - read: Mutex::new(read), - write: Mutex::new(write), - head: RwLock::new(head), - } - } +struct UdpCodec { + head: Vec, } -#[async_trait] -impl IUdpSocket for TrojanUdp { - async fn local_addr(&self) -> rd_interface::Result { - Err(NOT_IMPLEMENTED) - } +impl Encoder<(Bytes, SocketAddr)> for UdpCodec { + type Error = io::Error; + + fn encode(&mut self, item: (Bytes, SocketAddr), dst: &mut BytesMut) -> Result<(), Self::Error> { + if item.0.len() > UDP_MAX_SIZE { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + format!("Frame of length {} is too large.", item.0.len()), + )); + } - async fn recv_from(&self, recv_buf: &mut [u8]) -> rd_interface::Result<(usize, SocketAddr)> { - let mut read = self.read.lock().await; - let address = S5Addr::read(&mut *read) - .await - .map_err(|e| e.to_io_err())? - .to_socket_addr() - .map_err(|e| e.to_io_err())?; - let length = read.read_u16().await? as usize; - let _crlf = read.read_u16().await?; - - let to_read = length.min(recv_buf.len()); - let rest = length - to_read; - read.read_exact(&mut recv_buf[..to_read]).await?; - if rest > 0 { - read.read_exact(&mut vec![0u8; rest]).await?; + let addr = S5Addr::from(item.1); + dst.reserve( + self.head.len() + addr.serialized_len().map_err(|e| e.to_io_err())? + item.0.len(), + ); + if self.head.len() > 0 { + dst.extend_from_slice(&self.head); + self.head = Vec::new(); } + let mut writer = dst.writer(); - Ok((to_read, address)) + addr.write_to(&mut writer).map_err(|e| e.to_io_err())?; + let dst = writer.into_inner(); + + dst.put_u16(item.0.len() as u16); + dst.extend_from_slice(&[0x0D, 0x0A]); + dst.extend_from_slice(&item.0); + + Ok(()) } +} - async fn send_to(&self, payload: &[u8], target: RDAddress) -> Result { - if payload.len() > 65535 { - return Err(io::Error::from(io::ErrorKind::InvalidData).into()); +fn copy_2(b: &[u8]) -> [u8; 2] { + let mut buf = [0u8; 2]; + buf.copy_from_slice(&b); + buf +} + +impl Decoder for UdpCodec { + type Item = (BytesMut, SocketAddr); + type Error = io::Error; + + fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { + if src.len() < 2 { + return Ok(None); } - let addr = ra2sa(target); - let buf = if self.head.read().unwrap().len() > 0 { - let mut head = self.head.write().unwrap(); - if head.len() > 0 { - take(&mut *head) - } else { - Vec::new() - } - } else { - Vec::new() + let head = copy_2(&src[0..2]); + let addr_size = match head[0] { + 1 => 7, + 3 => 1 + head[1] as usize + 2, + 4 => 19, + _ => return Err(io::ErrorKind::InvalidData.into()), }; + if src.len() < addr_size + 4 { + return Ok(None); + } + let length = u16::from_be_bytes(copy_2(&src[addr_size..addr_size + 2])) as usize; + if src.len() < addr_size + 4 + length { + return Ok(None); + } - let pos = buf.len() as u64; - let mut writer = Cursor::new(buf); - writer.set_position(pos); + let mut reader = src.reader(); + let address = S5Addr::read_from(&mut reader).map_err(|e| e.to_io_err())?; + let src = reader.into_inner(); - addr.write_to(&mut writer).map_err(|e| e.to_io_err())?; - writer.write_all(&u16::to_be_bytes(payload.len() as u16))?; - writer.write_all(b"\r\n")?; - writer.write_all(payload)?; - let buf = writer.into_inner(); + // Length and CrLf + src.get_u16(); + src.get_u16(); + + Ok(Some(( + src.split_to(length as usize), + address.to_socket_addr().map_err(|e| e.to_io_err())?, + ))) + } +} + +pub(super) struct TrojanUdp { + framed: Framed, UdpCodec>, +} + +impl TrojanUdp { + pub fn new(stream: Box, head: Vec) -> Self { + let framed = Framed::new(stream, UdpCodec { head }); + Self { framed } + } +} - io::AsyncWriteExt::write_all(&mut *self.write.lock().await, &buf).await?; +impl_stream_sink!(TrojanUdp, framed); - Ok(payload.len()) +#[async_trait] +impl IUdpSocket for TrojanUdp { + async fn local_addr(&self) -> rd_interface::Result { + Err(NOT_IMPLEMENTED) } } diff --git a/rabbit-digger b/rabbit-digger index a043dc26..918af189 160000 --- a/rabbit-digger +++ b/rabbit-digger @@ -1 +1 @@ -Subproject commit a043dc26b5f8ac264cb0d523ad5551e99f103a35 +Subproject commit 918af189989c4ed557e2a63d3191ad350f0e7dbf diff --git a/src/main.rs b/src/main.rs index 8529b8ce..5d9338cf 100644 --- a/src/main.rs +++ b/src/main.rs @@ -137,7 +137,7 @@ async fn main(args: Args) -> Result<()> { if std::env::var_os("RUST_LOG").is_none() { std::env::set_var( "RUST_LOG", - "rabbit_digger=debug,rabbit_digger_pro=debug,rd_std=debug,raw=debug", + "rabbit_digger=debug,rabbit_digger_pro=debug,rd_std=debug,raw=debug,ss=debug", ) } let tr = tracing_subscriber::registry();