diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 232dac6a..047cc2b0 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -10,7 +10,8 @@ jobs: test: strategy: matrix: - os: [ubuntu-22.04, macos-13] + # os: [ubuntu-22.04, macos-13] linux not supported yet + os: [macos-13] runs-on: ${{ matrix.os }} steps: diff --git a/Cargo.Bazel.lock b/Cargo.Bazel.lock index 0d39be91..e9fded9f 100644 --- a/Cargo.Bazel.lock +++ b/Cargo.Bazel.lock @@ -1,5 +1,5 @@ { - "checksum": "c276c2878d903a25eb90b45923963b54d7fb82e6a7001b774fcbda205a49198a", + "checksum": "4c7d42924b2314ce069523e5915170a57d8303ede8bae1fcc0f5cbcfcb256e1a", "crates": { "addr2line 0.20.0": { "name": "addr2line", @@ -14904,6 +14904,7 @@ "common": [ "codec", "default", + "io", "net", "tracing" ], diff --git a/clash_lib/Cargo.toml b/clash_lib/Cargo.toml index e29370f2..983a9f7c 100644 --- a/clash_lib/Cargo.toml +++ b/clash_lib/Cargo.toml @@ -8,7 +8,7 @@ default = ["shadowsocks"] [dependencies] tokio = { version = "1", features = ["full"] } -tokio-util = { version = "0.7", features = ["net", "codec"] } +tokio-util = { version = "0.7", features = ["net", "codec", "io"] } tokio-rustls = "0.23.4" thiserror = "1.0" async-trait = "0.1" diff --git a/clash_lib/src/app/dispatcher.rs b/clash_lib/src/app/dispatcher.rs index f0ed7c96..698ada57 100644 --- a/clash_lib/src/app/dispatcher.rs +++ b/clash_lib/src/app/dispatcher.rs @@ -69,13 +69,16 @@ impl Dispatcher { sess, up, down ); } - Err(err) => { - warn!("connection {} closed with error {}", sess, err); - lhs.shutdown() - .await - .map(|x| debug!("local shutdown: {:?}", x)) - .ok(); - } + Err(err) => match err.kind() { + std::io::ErrorKind::UnexpectedEof + | std::io::ErrorKind::ConnectionReset + | std::io::ErrorKind::BrokenPipe => { + debug!("connection {} closed with error {}", sess, err); + } + _ => { + warn!("connection {} closed with error {}", sess, err); + } + }, } } Err(err) => { diff --git a/clash_lib/src/proxy/vmess/vmess_impl/aead.rs b/clash_lib/src/proxy/vmess/vmess_impl/aead.rs deleted file mode 100644 index c676346a..00000000 --- a/clash_lib/src/proxy/vmess/vmess_impl/aead.rs +++ /dev/null @@ -1,237 +0,0 @@ -use std::pin::Pin; - -use aes_gcm::Aes128Gcm; -use bytes::{BufMut, Bytes, BytesMut}; -use chacha20poly1305::ChaCha20Poly1305; -use futures::{pin_mut, ready, Future}; -use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite}; - -use super::MAX_CHUNK_SIZE; - -use crate::common::crypto::AeadCipherHelper; - -use super::CHUNK_SIZE; - -pub enum VmessSecurity { - Aes128Gcm(Aes128Gcm), - ChaCha20Poly1305(ChaCha20Poly1305), -} - -impl VmessSecurity { - #[inline(always)] - pub fn overhead_len(&self) -> usize { - 16 - } - #[inline(always)] - pub fn nonce_len(&self) -> usize { - 12 - } -} - -pub(crate) struct AeadReader { - buf: BytesMut, - pos: usize, - security: VmessSecurity, - nonce: [u8; 32], - iv: Bytes, - count: u16, - size_holder: [u8; 2], -} - -impl AeadReader { - pub fn new(iv: &[u8], security: VmessSecurity) -> Self { - Self { - buf: BytesMut::new(), - pos: 0, - security, - nonce: [0u8; 32], - iv: Bytes::copy_from_slice(iv), - count: 0, - size_holder: [0; 2], - } - } - - pub fn poll_read( - self: std::pin::Pin<&mut Self>, - inner: &mut R, - cx: &mut std::task::Context<'_>, - buf: &mut tokio::io::ReadBuf<'_>, - ) -> std::task::Poll> - where - R: AsyncRead + Unpin, - { - let Self { - buf: inner_buf, - pos, - size_holder, - count, - nonce, - iv, - security, - .. - } = self.get_mut(); - - if !inner_buf.is_empty() { - let n = std::cmp::min(buf.remaining(), inner_buf.len() - *pos); - buf.put_slice(&inner_buf[*pos..*pos + n]); - *pos += n; - - if *pos == inner_buf.len() { - inner_buf.clear(); - *pos = 0; - } - - return std::task::Poll::Ready(Ok(())); - } else { - assert!(*pos == 0, "chunk reader bad state"); - - let fut = inner.read_exact(&mut size_holder[..]); - pin_mut!(fut); - ready!(fut.poll(cx))?; - - let size = u16::from_be_bytes(*size_holder) as usize; - if size > MAX_CHUNK_SIZE { - return std::task::Poll::Ready(Err(std::io::Error::new( - std::io::ErrorKind::InvalidData, - format!( - "chunk size too large. max: {}, got: {}", - MAX_CHUNK_SIZE, size - ), - ))); - } - - inner_buf.resize(size, 0); - let fut = inner.read_exact(&mut inner_buf[..]); - pin_mut!(fut); - ready!(fut.poll(cx))?; - - nonce[..2].copy_from_slice(&count.to_be_bytes()); - nonce[2..12].copy_from_slice(&iv[2..12]); - *count += 1; - - let nonce = &nonce[..security.nonce_len()]; - match security { - VmessSecurity::Aes128Gcm(cipher) => { - let dec = - cipher.decrypt_in_place_with_slice(nonce.into(), &[], &mut inner_buf[..]); - if dec.is_err() { - return std::task::Poll::Ready(Err(std::io::Error::new( - std::io::ErrorKind::InvalidData, - dec.unwrap_err().to_string(), - ))); - } - } - VmessSecurity::ChaCha20Poly1305(cipher) => { - let dec = - cipher.decrypt_in_place_with_slice(nonce.into(), &[], &mut inner_buf[..]); - if dec.is_err() { - return std::task::Poll::Ready(Err(std::io::Error::new( - std::io::ErrorKind::InvalidData, - dec.unwrap_err().to_string(), - ))); - } - } - } - - let real_len = size - security.overhead_len(); - inner_buf.truncate(real_len); - - let n: usize = std::cmp::min(buf.remaining(), inner_buf.len()); - buf.put_slice(&inner_buf[..n]); - *pos += n; - - if *pos == inner_buf.len() { - inner_buf.clear(); - *pos = 0; - } - - return std::task::Poll::Ready(Ok(())); - } - } -} - -pub(crate) struct AeadWriter { - buf: BytesMut, - security: VmessSecurity, - nonce: [u8; 32], - iv: Bytes, - count: u16, -} - -impl AeadWriter { - pub fn new(iv: &[u8], security: VmessSecurity) -> Self { - Self { - buf: BytesMut::new(), - security, - nonce: [0u8; 32], - iv: Bytes::copy_from_slice(iv), - count: 0, - } - } - - pub fn poll_write( - self: Pin<&mut Self>, - inner: &mut W, - cx: &mut std::task::Context<'_>, - buf: &[u8], - ) -> std::task::Poll> - where - W: AsyncWrite + Unpin, - { - let mut remaining = buf.len(); - let mut sent = 0; - - let Self { - buf: inner_buf, - security, - nonce, - count, - iv, - .. - } = self.get_mut(); - - let mut pin = Pin::new(inner); - - while remaining > 0 { - let payload_size = std::cmp::min(remaining, CHUNK_SIZE - security.overhead_len()); - - inner_buf.reserve(2 + payload_size + security.overhead_len()); - inner_buf.put_u16((payload_size + security.overhead_len()) as u16); - inner_buf.put_slice(&buf[sent..sent + payload_size]); - inner_buf.extend_from_slice(vec![0u8; security.overhead_len()].as_ref()); - - nonce[..2].copy_from_slice(&count.to_be_bytes()); - nonce[2..12].copy_from_slice(&iv[2..12]); - - *count += 1; - - let nonce_len = security.nonce_len(); - match security { - VmessSecurity::Aes128Gcm(cipher) => { - cipher.encrypt_in_place_with_slice( - nonce[..nonce_len].into(), - &[], - &mut inner_buf[2..], - ); - } - VmessSecurity::ChaCha20Poly1305(cipher) => { - cipher.encrypt_in_place_with_slice( - nonce[..nonce_len].into(), - &[], - &mut inner_buf[2..], - ); - } - } - - ready!(pin - .as_mut() - .poll_write(cx, &inner_buf[..2 + payload_size + security.overhead_len()]))?; - inner_buf.clear(); - - sent += payload_size; - remaining -= payload_size; - } - - std::task::Poll::Ready(Ok(buf.len())) - } -} diff --git a/clash_lib/src/proxy/vmess/vmess_impl/cipher.rs b/clash_lib/src/proxy/vmess/vmess_impl/cipher.rs new file mode 100644 index 00000000..0f2776da --- /dev/null +++ b/clash_lib/src/proxy/vmess/vmess_impl/cipher.rs @@ -0,0 +1,97 @@ +use aes_gcm::Aes128Gcm; +use bytes::Bytes; +use chacha20poly1305::ChaCha20Poly1305; + +use crate::common::crypto::AeadCipherHelper; + +pub enum VmessSecurity { + Aes128Gcm(Aes128Gcm), + ChaCha20Poly1305(ChaCha20Poly1305), +} + +impl VmessSecurity { + #[inline(always)] + pub fn overhead_len(&self) -> usize { + 16 + } + #[inline(always)] + pub fn nonce_len(&self) -> usize { + 12 + } +} + +pub(crate) struct AeadCipher { + pub security: VmessSecurity, + nonce: [u8; 32], + iv: Bytes, + count: u16, +} + +impl AeadCipher { + pub fn new(iv: &[u8], security: VmessSecurity) -> Self { + Self { + security, + nonce: [0u8; 32], + iv: Bytes::copy_from_slice(iv), + count: 0, + } + } + + pub fn decrypt_inplace(&mut self, buf: &mut [u8]) -> std::io::Result<()> { + let mut nonce = self.nonce; + let security = &self.security; + let iv = &self.iv; + let count = &mut self.count; + + nonce[..2].copy_from_slice(&count.to_be_bytes()); + nonce[2..12].copy_from_slice(&iv[2..12]); + *count += 1; + + let nonce = &nonce[..security.nonce_len()]; + match security { + VmessSecurity::Aes128Gcm(cipher) => { + let dec = cipher.decrypt_in_place_with_slice(nonce.into(), &[], &mut buf[..]); + if dec.is_err() { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + dec.unwrap_err().to_string(), + )); + } + } + VmessSecurity::ChaCha20Poly1305(cipher) => { + let dec = cipher.decrypt_in_place_with_slice(nonce.into(), &[], &mut buf[..]); + if dec.is_err() { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + dec.unwrap_err().to_string(), + )); + } + } + } + + Ok(()) + } + + pub fn encrypt_inplace(&mut self, buf: &mut [u8]) -> std::io::Result<()> { + let mut nonce = self.nonce; + let security = &self.security; + let iv = &self.iv; + let count = &mut self.count; + + nonce[..2].copy_from_slice(&count.to_be_bytes()); + nonce[2..12].copy_from_slice(&iv[2..12]); + *count += 1; + + let nonce = &nonce[..security.nonce_len()]; + match security { + VmessSecurity::Aes128Gcm(cipher) => { + cipher.encrypt_in_place_with_slice(nonce.into(), &[], &mut buf[..]); + } + VmessSecurity::ChaCha20Poly1305(cipher) => { + cipher.encrypt_in_place_with_slice(nonce.into(), &[], &mut buf[..]); + } + } + + Ok(()) + } +} diff --git a/clash_lib/src/proxy/vmess/vmess_impl/mod.rs b/clash_lib/src/proxy/vmess/vmess_impl/mod.rs index a3d16812..38d7b508 100644 --- a/clash_lib/src/proxy/vmess/vmess_impl/mod.rs +++ b/clash_lib/src/proxy/vmess/vmess_impl/mod.rs @@ -1,12 +1,11 @@ -mod aead; mod chunk; +mod cipher; mod client; mod header; //pub mod http; mod datagram; mod kdf; mod stream; -mod tls; mod user; pub(crate) const VERSION: u8 = 1; diff --git a/clash_lib/src/proxy/vmess/vmess_impl/stream.rs b/clash_lib/src/proxy/vmess/vmess_impl/stream.rs index 1e13298c..187f4726 100644 --- a/clash_lib/src/proxy/vmess/vmess_impl/stream.rs +++ b/clash_lib/src/proxy/vmess/vmess_impl/stream.rs @@ -1,10 +1,10 @@ -use std::{fmt::Debug, pin::Pin, task::Poll, time::SystemTime}; +use std::{fmt::Debug, mem::MaybeUninit, pin::Pin, task::Poll, time::SystemTime}; use aes_gcm::Aes128Gcm; use bytes::{BufMut, BytesMut}; use chacha20poly1305::ChaCha20Poly1305; use futures::{pin_mut, ready, Future}; -use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf}; use tracing::debug; use crate::{ @@ -13,12 +13,12 @@ use crate::{ errors::map_io_error, utils, }, + proxy::vmess::vmess_impl::MAX_CHUNK_SIZE, session::SocksAddr, }; use super::{ - aead::{AeadReader, AeadWriter, VmessSecurity}, - chunk::{ChunkReader, ChunkWriter}, + cipher::{AeadCipher, VmessSecurity}, header, kdf::{ self, KDF_SALT_CONST_AEAD_RESP_HEADER_LEN_IV, KDF_SALT_CONST_AEAD_RESP_HEADER_LEN_KEY, @@ -29,22 +29,10 @@ use super::{ SECURITY_CHACHA20_POLY1305, SECURITY_NONE, VERSION, }; -pub(crate) enum VmessReader { - None(ChunkReader), - Aes128Gcm(AeadReader), - ChaCha20Poly1305(AeadReader), -} - -pub(crate) enum VmessWriter { - None(ChunkWriter), - Aes128Gcm(AeadWriter), - ChaCha20Poly1305(AeadWriter), -} - pub struct VmessStream { stream: S, - reader: VmessReader, - writer: VmessWriter, + aead_read_cipher: Option, + aead_write_cipher: Option, dst: SocksAddr, id: ID, req_body_iv: Vec, @@ -55,7 +43,13 @@ pub struct VmessStream { security: u8, is_aead: bool, is_udp: bool, - handshake_done: bool, + + read_state: ReadState, + read_pos: usize, + read_buf: BytesMut, + + write_state: WriteState, + write_buf: BytesMut, } impl Debug for VmessStream { @@ -68,9 +62,76 @@ impl Debug for VmessStream { } } +enum ReadState { + AeadWaitingHeaderSize, + AeadWaitingHeader(usize), + StreamWaitingLength, + StreamWaitingData(usize), + StreamFlushingData(usize), +} + +enum WriteState { + BuildingData, + FlushingData(usize, (usize, usize)), +} + +pub trait ReadExt { + fn poll_read_exact( + &mut self, + cx: &mut std::task::Context, + size: usize, + ) -> Poll>; + fn get_data(&self) -> &[u8]; +} + +impl ReadExt for VmessStream { + // Read exactly `size` bytes into `read_buf`, starting from position 0. + fn poll_read_exact( + &mut self, + cx: &mut std::task::Context, + size: usize, + ) -> Poll> { + self.read_buf.reserve(size); + unsafe { self.read_buf.set_len(size) } + debug!( + "poll read exact: {}, read_pos: {}, buf: {}", + size, + self.read_pos, + self.read_buf.len() + ); + loop { + if self.read_pos < size { + let dst = unsafe { + &mut *((&mut self.read_buf[self.read_pos..size]) as *mut _ + as *mut [MaybeUninit]) + }; + let mut buf = ReadBuf::uninit(dst); + let ptr = buf.filled().as_ptr(); + ready!(Pin::new(&mut self.stream).poll_read(cx, &mut buf))?; + assert_eq!(ptr, buf.filled().as_ptr()); + if buf.filled().is_empty() { + return Poll::Ready(Err(std::io::Error::new( + std::io::ErrorKind::UnexpectedEof, + "unexpected eof", + ))); + } + self.read_pos += buf.filled().len(); + } else { + assert!(self.read_pos == size); + self.read_pos = 0; + return Poll::Ready(Ok(())); + } + } + } + + fn get_data(&self) -> &[u8] { + self.read_buf.as_ref() + } +} + impl VmessStream where - S: AsyncRead + AsyncWrite + Unpin, + S: AsyncRead + AsyncWrite + Unpin + Send + Sync, { pub(crate) async fn new( stream: S, @@ -98,22 +159,16 @@ where ) }; - let (reader, writer) = match security { - &SECURITY_NONE => ( - VmessReader::None(ChunkReader::new()), - VmessWriter::None(ChunkWriter::new()), - ), + let (aead_read_cipher, aead_write_cipher) = match security { + &SECURITY_NONE => (None, None), &SECURITY_AES_128_GCM => { let write_cipher = VmessSecurity::Aes128Gcm(Aes128Gcm::new_with_slice(&req_body_key)); - let writer = AeadWriter::new(&req_body_iv, write_cipher); + let write_cipher = AeadCipher::new(&req_body_iv, write_cipher); let reader_cipher = VmessSecurity::Aes128Gcm(Aes128Gcm::new_with_slice(&resp_body_key)); - let reader = AeadReader::new(&resp_body_iv, reader_cipher); - ( - VmessReader::Aes128Gcm(reader), - VmessWriter::Aes128Gcm(writer), - ) + let read_cipher = AeadCipher::new(&resp_body_iv, reader_cipher); + (Some(read_cipher), Some(write_cipher)) } &SECURITY_CHACHA20_POLY1305 => { let mut key = [0u8; 32]; @@ -123,7 +178,7 @@ where key[16..].copy_from_slice(&tmp); let write_cipher = VmessSecurity::ChaCha20Poly1305(ChaCha20Poly1305::new_with_slice(&key)); - let writer = AeadWriter::new(&req_body_iv, write_cipher); + let write_cipher = AeadCipher::new(&req_body_iv, write_cipher); let tmp = utils::md5(&req_body_key); key.copy_from_slice(&tmp); @@ -131,12 +186,9 @@ where key[16..].copy_from_slice(&tmp); let reader_cipher = VmessSecurity::ChaCha20Poly1305(ChaCha20Poly1305::new_with_slice(&key)); - let reader = AeadReader::new(&resp_body_iv, reader_cipher); + let read_cipher = AeadCipher::new(&resp_body_iv, reader_cipher); - ( - VmessReader::ChaCha20Poly1305(reader), - VmessWriter::ChaCha20Poly1305(writer), - ) + (Some(read_cipher), Some(write_cipher)) } _ => { return Err(std::io::Error::new( @@ -148,8 +200,8 @@ where let mut stream = Self { stream, - reader, - writer, + aead_read_cipher, + aead_write_cipher, dst: dst.to_owned(), id: id.to_owned(), req_body_iv, @@ -160,7 +212,13 @@ where security: *security, is_aead, is_udp, - handshake_done: false, + + read_state: ReadState::AeadWaitingHeaderSize, + read_pos: 0, + read_buf: BytesMut::new(), + + write_state: WriteState::BuildingData, + write_buf: BytesMut::new(), }; stream.send_handshake_request().await?; @@ -171,114 +229,7 @@ where impl VmessStream where - S: AsyncRead + Unpin, -{ - async fn recv_handshake_response(&mut self) -> std::io::Result<()> { - let Self { - ref mut stream, - ref is_aead, - ref resp_body_key, - ref resp_body_iv, - ref resp_v, - .. - } = self; - - debug!("recv handshake response"); - let mut buf = Vec::new(); - - if !is_aead { - buf.resize(4, 0); - stream.read_exact(buf.as_mut()).await?; - crypto::aes_cfb_decrypt(resp_body_key, resp_body_iv, &mut buf).map_err(map_io_error)?; - } else { - let aead_response_header_length_encryption_key = - &kdf::vmess_kdf_1_one_shot(resp_body_key, KDF_SALT_CONST_AEAD_RESP_HEADER_LEN_KEY) - [..16]; - let aead_response_header_length_encryption_iv = - &kdf::vmess_kdf_1_one_shot(resp_body_iv, KDF_SALT_CONST_AEAD_RESP_HEADER_LEN_IV) - [..12]; - - debug!("recv handshake response header length"); - let mut hdr_len_buf = [0u8; 18]; - stream.read_exact(&mut hdr_len_buf).await?; - debug!( - "recv handshake response header length: {:?}", - hdr_len_buf.as_slice() - ); - - let decrypted_response_header_len = crypto::aes_gcm_open( - aead_response_header_length_encryption_key, - aead_response_header_length_encryption_iv, - hdr_len_buf.as_slice(), - None, - ) - .map_err(map_io_error)?; - - if decrypted_response_header_len.len() < 2 { - return Err(std::io::Error::new( - std::io::ErrorKind::InvalidData, - "invalid response header length", - )) - .into(); - } - - debug!( - "recv handshake response header length: {:?}", - decrypted_response_header_len - ); - - let decrypted_header_len = - u16::from_be_bytes(decrypted_response_header_len[..2].try_into().unwrap()); - let aead_response_header_payload_encryption_key = &kdf::vmess_kdf_1_one_shot( - resp_body_key, - KDF_SALT_CONST_AEAD_RESP_HEADER_PAYLOAD_KEY, - )[..16]; - let aead_response_header_payload_encryption_iv = &kdf::vmess_kdf_1_one_shot( - resp_body_iv, - KDF_SALT_CONST_AEAD_RESP_HEADER_PAYLOAD_IV, - )[..12]; - - debug!("recv handshake response header"); - let mut hdr_buff = vec![0; decrypted_header_len as usize + 16]; - stream.read_exact(&mut hdr_buff).await?; - - buf = crypto::aes_gcm_open( - &aead_response_header_payload_encryption_key, - &aead_response_header_payload_encryption_iv, - hdr_buff.as_slice(), - None, - ) - .map_err(map_io_error)?; - - if buf.len() < 4 { - return Err(std::io::Error::new( - std::io::ErrorKind::InvalidData, - "invalid response", - )); - } - } - - if buf[0] != *resp_v { - return Err(std::io::Error::new( - std::io::ErrorKind::InvalidData, - "invalid response", - )); - } - - if buf[2] != 0 { - return Err(std::io::Error::new( - std::io::ErrorKind::InvalidData, - "invalid response", - )); - } - - Ok(()) - } -} - -impl VmessStream -where - S: AsyncWrite + Unpin, + S: AsyncWrite + Unpin + Send + Sync, { async fn send_handshake_request(&mut self) -> std::io::Result<()> { let Self { @@ -375,60 +326,255 @@ where impl AsyncRead for VmessStream where - S: AsyncRead + Unpin, + S: AsyncRead + Unpin + Send + Sync, { fn poll_read( - self: std::pin::Pin<&mut Self>, + mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, buf: &mut tokio::io::ReadBuf<'_>, ) -> std::task::Poll> { debug!("poll read with aead"); - let this = self.get_mut(); - - if !this.handshake_done { - debug!("doing handshake"); - let fut = this.recv_handshake_response(); - pin_mut!(fut); - ready!(fut.poll(cx))?; + loop { + match self.read_state { + ReadState::AeadWaitingHeaderSize => { + debug!("recv handshake response header"); + let this = &mut *self; + let resp_body_key = this.resp_body_key.clone(); // TODO: get rid of clone + let resp_body_iv = this.resp_body_iv.clone(); + let resp_v = this.resp_v; + + if !this.is_aead { + ready!(this.poll_read_exact(cx, 4))?; + let mut buf = this.read_buf.split().freeze().to_vec(); + crypto::aes_cfb_decrypt(&resp_body_key, &resp_body_iv, &mut buf) + .map_err(map_io_error)?; + if buf[0] != resp_v { + return Poll::Ready(Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "invalid response", + ))); + } + + if buf[2] != 0 { + return Poll::Ready(Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "invalid response", + ))); + } + + this.read_state = ReadState::StreamWaitingLength; + } else { + debug!("recv handshake response header length"); + ready!(this.poll_read_exact(cx, 18))?; + + let aead_response_header_length_encryption_key = &kdf::vmess_kdf_1_one_shot( + &resp_body_key, + KDF_SALT_CONST_AEAD_RESP_HEADER_LEN_KEY, + )[..16]; + let aead_response_header_length_encryption_iv = &kdf::vmess_kdf_1_one_shot( + &resp_body_iv, + KDF_SALT_CONST_AEAD_RESP_HEADER_LEN_IV, + )[..12]; + + let decrypted_response_header_len = crypto::aes_gcm_open( + aead_response_header_length_encryption_key, + aead_response_header_length_encryption_iv, + this.read_buf.split().as_ref(), + None, + ) + .map_err(map_io_error)?; + + if decrypted_response_header_len.len() < 2 { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "invalid response header length", + )) + .into(); + } + + this.read_state = ReadState::AeadWaitingHeader(u16::from_be_bytes( + decrypted_response_header_len[..2].try_into().unwrap(), + ) + as usize); + } + } + + ReadState::AeadWaitingHeader(header_size) => { + debug!("recv handshake header body: {}", header_size); + + let this = &mut *self; + ready!(this.poll_read_exact(cx, header_size + 16))?; + + let resp_body_key = this.resp_body_key.clone(); + let resp_body_iv = this.resp_body_iv.clone(); + + let aead_response_header_payload_encryption_key = &kdf::vmess_kdf_1_one_shot( + &resp_body_key, + KDF_SALT_CONST_AEAD_RESP_HEADER_PAYLOAD_KEY, + )[..16]; + let aead_response_header_payload_encryption_iv = &kdf::vmess_kdf_1_one_shot( + &resp_body_iv, + KDF_SALT_CONST_AEAD_RESP_HEADER_PAYLOAD_IV, + )[..12]; + + let buf = crypto::aes_gcm_open( + &aead_response_header_payload_encryption_key, + &aead_response_header_payload_encryption_iv, + this.read_buf.split().as_ref(), + None, + ) + .map_err(map_io_error)?; + + if buf.len() < 4 { + return Poll::Ready(Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "invalid response - header too short", + ))); + } + + if buf[0] != this.resp_v { + return Poll::Ready(Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "invalid response - version mismatch", + ))); + } + + if buf[2] != 0 { + return Poll::Ready(Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "invalid response - dynamic port not supported", + ))); + } + + this.read_state = ReadState::StreamWaitingLength; + } + + ReadState::StreamWaitingLength => { + debug!("recv stream length"); + let this = &mut *self; + ready!(this.poll_read_exact(cx, 2))?; + let len = u16::from_be_bytes(this.read_buf.split().as_ref().try_into().unwrap()) + as usize; + + if len > MAX_CHUNK_SIZE { + return Poll::Ready(Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "invalid response - chunk size too large", + ))); + } + + this.read_state = ReadState::StreamWaitingData(len); + } + + ReadState::StreamWaitingData(size) => { + debug!("recv stream data: {}", size); + let this = &mut *self; + ready!(this.poll_read_exact(cx, size))?; + + if let Some(ref mut cipher) = this.aead_read_cipher { + cipher.decrypt_inplace(&mut this.read_buf)?; + let data_len = size - cipher.security.overhead_len(); + this.read_buf.truncate(data_len); + this.read_state = ReadState::StreamFlushingData(data_len); + } else { + this.read_state = ReadState::StreamFlushingData(size); + } + } + + ReadState::StreamFlushingData(size) => { + debug!("flush stream data: {}", size); + let to_read = std::cmp::min(buf.remaining(), size); + let payload = self.read_buf.split_to(to_read); + buf.put_slice(&payload); + if to_read < size { + // there're unread data, continues in next poll + self.read_state = ReadState::StreamFlushingData(size - to_read); + } else { + // all data consumed, ready to read next chunk + self.read_state = ReadState::StreamWaitingLength; + } + + return Poll::Ready(Ok(())); + } + } } - - this.handshake_done = true; - debug!("handshake done"); - - let stream = &mut this.stream; - let reader = &mut this.reader; - - return match reader { - VmessReader::None(r) => Pin::new(r).poll_read(stream, cx, buf), - VmessReader::Aes128Gcm(r) => Pin::new(r).poll_read(stream, cx, buf), - VmessReader::ChaCha20Poly1305(r) => Pin::new(r).poll_read(stream, cx, buf), - }; } } impl AsyncWrite for VmessStream where - S: AsyncWrite + Unpin, + S: AsyncWrite + Unpin + Send + Sync, { fn poll_write( - self: Pin<&mut Self>, + mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, buf: &[u8], ) -> Poll> { - let Self { - ref mut stream, - ref mut writer, - .. - } = self.get_mut(); - - debug!("poll write with aead"); - - return match writer { - VmessWriter::None(w) => Pin::new(w).poll_write(stream, cx, buf), - VmessWriter::Aes128Gcm(w) => Pin::new(w).poll_write(stream, cx, buf), - VmessWriter::ChaCha20Poly1305(w) => Pin::new(w).poll_write(stream, cx, buf), - }; + loop { + match self.write_state { + WriteState::BuildingData => { + let this = &mut *self; + let mut overhead_len = 0; + if let Some(ref mut cipher) = this.aead_write_cipher { + overhead_len = cipher.security.overhead_len(); + } + + let max_payload_size = MAX_CHUNK_SIZE - overhead_len; + let consume_len = std::cmp::min(buf.len(), max_payload_size); + let payload_len = consume_len + overhead_len; + + let size_bytes = 2; + this.write_buf.reserve(size_bytes + payload_len); + this.write_buf.put_u16(payload_len as u16); + + let mut piece2 = this.write_buf.split_off(size_bytes); + + piece2.put_slice(&buf[..consume_len]); + if let Some(ref mut cipher) = this.aead_write_cipher { + piece2 + .extend_from_slice(vec![0u8; cipher.security.overhead_len()].as_ref()); + cipher.encrypt_inplace(&mut piece2)?; + } + + this.write_buf.unsplit(piece2); + + // ready to write data + self.write_state = + WriteState::FlushingData(consume_len, (this.write_buf.len(), 0)); + } + + // consumed is the consumed plaintext length we're going to return to caller. + // total is total length of the ciphertext data chunk we're going to write to remote. + // written is the number of ciphertext bytes were written. + WriteState::FlushingData(consumed, (total, written)) => { + let this = &mut *self; + + // There would be trouble if the caller change the buf upon pending, but I + // believe that's not a usual use case. + let nw = ready!(tokio_util::io::poll_write_buf( + Pin::new(&mut this.stream), + cx, + &mut this.write_buf + ))?; + if nw == 0 { + return Err(std::io::Error::new( + std::io::ErrorKind::WriteZero, + "failed to write whole data", + )) + .into(); + } + + if written + nw >= total { + // data chunk written, go to next chunk + this.write_state = WriteState::BuildingData; + return Poll::Ready(Ok(consumed)); + } + + this.write_state = WriteState::FlushingData(consumed, (total, written + nw)); + } + } + } } fn poll_flush(