From 6746921cf65441e91ad6f8b383282eda2ea150da Mon Sep 17 00:00:00 2001 From: iHsin Date: Sun, 24 Mar 2024 17:53:22 +0800 Subject: [PATCH] refactor: use Arc --- clash_lib/src/proxy/tuic/handle_stream.rs | 25 ++++++++++++++--------- clash_lib/src/proxy/tuic/handle_task.rs | 6 +++--- clash_lib/src/proxy/tuic/mod.rs | 7 +++---- clash_lib/src/proxy/tuic/types.rs | 8 ++++---- 4 files changed, 25 insertions(+), 21 deletions(-) diff --git a/clash_lib/src/proxy/tuic/handle_stream.rs b/clash_lib/src/proxy/tuic/handle_stream.rs index 3d38ee58b..ee0d5c531 100644 --- a/clash_lib/src/proxy/tuic/handle_stream.rs +++ b/clash_lib/src/proxy/tuic/handle_stream.rs @@ -1,4 +1,5 @@ use std::sync::atomic::Ordering; +use std::sync::Arc; use bytes::Bytes; use quinn::{RecvStream, SendStream, VarInt}; @@ -46,7 +47,7 @@ impl TuicConnection { Ok(self.conn.read_datagram().await?) } - pub async fn handle_uni_stream(self, recv: RecvStream, _reg: Register) { + pub async fn handle_uni_stream(self: Arc, recv: RecvStream, _reg: Register) { tracing::debug!("[relay] incoming unidirectional stream"); let res = match self.inner.accept_uni_stream(recv).await { @@ -66,20 +67,24 @@ impl TuicConnection { } } - pub async fn handle_bi_stream(self, send: SendStream, recv: RecvStream, _reg: Register) { + pub async fn handle_bi_stream( + self: Arc, + send: SendStream, + recv: RecvStream, + _reg: Register, + ) { tracing::debug!("[relay] incoming bidirectional stream"); - let res = match self.inner.accept_bi_stream(send, recv).await { - Err(err) => Err::<(), _>(anyhow!(err)), - _ => unreachable!(), // already filtered in `tuic_quinn` + + let err = match self.inner.accept_bi_stream(send, recv).await { + Err(err) => anyhow!(err), + _ => anyhow!("A client shouldn't receive bi stream"), }; - if let Err(err) = res { - tracing::warn!("[relay] incoming bidirectional stream error: {err}"); - } + tracing::warn!("[relay] incoming bidirectional stream error: {err}"); } - pub async fn handle_datagram(self, dg: Bytes) { + pub async fn handle_datagram(self: Arc, dg: Bytes) { tracing::debug!("[relay] incoming datagram"); let res = match self.inner.accept_datagram(dg) { @@ -91,7 +96,7 @@ impl TuicConnection { } UdpRelayMode::Quic => Err(anyhow!("wrong packet source")), }, - _ => unreachable!(), // already filtered in `tuic_quinn` + _ => Err(anyhow!("Datagram shouldn't receive any data expect UDP packet")), }; if let Err(err) = res { diff --git a/clash_lib/src/proxy/tuic/handle_task.rs b/clash_lib/src/proxy/tuic/handle_task.rs index adb590997..340dc027f 100644 --- a/clash_lib/src/proxy/tuic/handle_task.rs +++ b/clash_lib/src/proxy/tuic/handle_task.rs @@ -1,4 +1,4 @@ -use std::time::Duration; +use std::{sync::Arc, time::Duration}; use bytes::Bytes; use quinn::ZeroRttAccepted; @@ -13,7 +13,7 @@ use crate::session::SocksAddr as ClashSocksAddr; use super::types::{TuicConnection, UdpRelayMode}; impl TuicConnection { - pub async fn tuic_auth(self, zero_rtt_accepted: Option) { + pub async fn tuic_auth(self: Arc, zero_rtt_accepted: Option) { if let Some(zero_rtt_accepted) = zero_rtt_accepted { tracing::debug!("[auth] waiting for connection to be fully established"); zero_rtt_accepted.await; @@ -156,7 +156,7 @@ impl TuicConnection { /// Tasks triggered by timer /// Won't return unless occurs error pub async fn cyclical_tasks( - self, + self: Arc, heartbeat_interval: Duration, gc_interval: Duration, gc_lifetime: Duration, diff --git a/clash_lib/src/proxy/tuic/mod.rs b/clash_lib/src/proxy/tuic/mod.rs index 7389b2adb..03b8f4531 100644 --- a/clash_lib/src/proxy/tuic/mod.rs +++ b/clash_lib/src/proxy/tuic/mod.rs @@ -79,7 +79,7 @@ pub struct HandlerOptions { pub struct Handler { opts: HandlerOptions, ep: TuicEndpoint, - conn: AsyncMutex>, + conn: AsyncMutex>>, next_assoc_id: AtomicU16, } @@ -188,7 +188,7 @@ impl Handler { next_assoc_id: AtomicU16::new(0), })) } - async fn get_conn(&self) -> Result { + async fn get_conn(&self) -> Result> { let fut = async { let mut guard = self.conn.lock().await; if guard.is_none() { @@ -202,7 +202,6 @@ impl Handler { } else { conn }; - // TODO TuicConnection is huge, is it necessary to clone it? If it is, should we use Arc ? *guard = Some(conn.clone()); Ok(conn) }; @@ -248,7 +247,7 @@ struct TuicDatagramOutbound { impl TuicDatagramOutbound { pub fn new( assoc_id: u16, - conn: TuicConnection, + conn: Arc, local_addr: ClashSocksAddr, ) -> AnyOutboundDatagram { // TODO not sure about the size of buffer diff --git a/clash_lib/src/proxy/tuic/types.rs b/clash_lib/src/proxy/tuic/types.rs index d8052771f..5c73b6841 100644 --- a/clash_lib/src/proxy/tuic/types.rs +++ b/clash_lib/src/proxy/tuic/types.rs @@ -28,7 +28,7 @@ pub struct TuicEndpoint { pub gc_lifetime: Duration, } impl TuicEndpoint { - pub async fn connect(&self) -> Result { + pub async fn connect(&self) -> Result> { let mut last_err = None; for addr in self.server.resolve().await? { @@ -121,7 +121,7 @@ impl TuicConnection { heartbeat: Duration, gc_interval: Duration, gc_lifetime: Duration, - ) -> Self { + ) -> Arc { let conn = Self { conn: conn.clone(), inner: InnerConnection::::new(conn), @@ -135,7 +135,7 @@ impl TuicConnection { max_concurrent_bi_streams: Arc::new(AtomicU32::new(32)), udp_sessions: Arc::new(AsyncRwLock::new(HashMap::new())), }; - + let conn = Arc::new(conn); tokio::spawn( conn.clone() .init(zero_rtt_accepted, heartbeat, gc_interval, gc_lifetime), @@ -144,7 +144,7 @@ impl TuicConnection { conn } async fn init( - self, + self: Arc, zero_rtt_accepted: Option, heartbeat: Duration, gc_interval: Duration,