Skip to content

Commit

Permalink
feat: support fwmark for tuic outbound
Browse files Browse the repository at this point in the history
  • Loading branch information
Itsusinn committed Mar 30, 2024
1 parent 3bad303 commit e926959
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 62 deletions.
2 changes: 2 additions & 0 deletions clash_lib/src/config/internal/proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,8 @@ pub struct OutboundTuic {
pub gc_lifetime: Option<u64>,
pub send_window: Option<u64>,
pub receive_window: Option<u64>,
/// fwmark
pub mark: Option<u32>,
}

#[derive(serde::Serialize, serde::Deserialize, Debug, Clone)]
Expand Down
6 changes: 5 additions & 1 deletion clash_lib/src/proxy/converters/tuic.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use std::time::Duration;
use std::{
sync::{atomic::AtomicU32, Arc},
time::Duration,
};

use quinn::VarInt;

Expand Down Expand Up @@ -54,6 +57,7 @@ impl TryFrom<&OutboundTuic> for AnyOutboundHandler {
send_window: s.send_window.unwrap_or(8 * 1024 * 1024 * 2),
receive_window: VarInt::from_u64(s.receive_window.unwrap_or(8 * 1024 * 1024))
.unwrap_or(VarInt::MAX),
mark: Arc::new(AtomicU32::new(s.mark.unwrap_or(0))),
})
}
}
30 changes: 22 additions & 8 deletions clash_lib/src/proxy/tuic/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use anyhow::Result;
use axum::async_trait;
use quinn::{EndpointConfig, TokioRuntime};
use std::net::SocketAddr;
use std::sync::atomic::AtomicU32;
use std::{
net::{Ipv4Addr, Ipv6Addr, UdpSocket},
sync::{
Expand Down Expand Up @@ -68,6 +69,7 @@ pub struct HandlerOptions {
pub gc_lifetime: Duration,
pub send_window: u64,
pub receive_window: VarInt,
pub mark: Arc<AtomicU32>,

/// not used
pub ip: Option<String>,
Expand Down Expand Up @@ -168,6 +170,7 @@ impl Handler {
socket,
Arc::new(TokioRuntime),
)?;

endpoint.set_default_client_config(quinn_config);
let endpoint = TuicEndpoint {
ep: endpoint,
Expand All @@ -179,6 +182,7 @@ impl Handler {
heartbeat: opts.heartbeat_interval,
gc_interval: opts.gc_interval,
gc_lifetime: opts.gc_lifetime,
mark: opts.mark.clone(),
};
Ok(Arc::new(Self {
opts,
Expand All @@ -187,17 +191,27 @@ impl Handler {
next_assoc_id: AtomicU16::new(0),
}))
}
async fn get_conn(&self) -> Result<Arc<TuicConnection>> {
async fn get_conn(
&self,
resolver: &ThreadSafeDNSResolver,
mark: Option<u32>,
) -> Result<Arc<TuicConnection>> {
let mark = mark.unwrap_or(self.opts.mark.load(Ordering::Relaxed));
let mut rebind = false;
// if mark not match the one current used, then rebind
if mark != self.opts.mark.swap(mark, Ordering::Relaxed) {
rebind = true;
}
let fut = async {
let mut guard = self.conn.lock().await;
if guard.is_none() {
// init
*guard = Some(self.ep.connect().await?);
*guard = Some(self.ep.connect(resolver, rebind).await?);
}
let conn = guard.take().unwrap();
let conn = if conn.check_open().is_err() {
let conn = if conn.check_open().is_err() || rebind {
// reconnect
self.ep.connect().await?
self.ep.connect(resolver, rebind).await?
} else {
conn
};
Expand All @@ -210,9 +224,9 @@ impl Handler {
async fn do_connect_stream(
&self,
sess: &Session,
_resolver: ThreadSafeDNSResolver,
resolver: ThreadSafeDNSResolver,
) -> Result<BoxedChainedStream> {
let conn = self.get_conn().await?;
let conn = self.get_conn(&resolver, sess.packet_mark).await?;
let dest = sess.destination.clone().into_tuic();
let tuic_tcp = conn.connect_tcp(dest).await?.compat();

Expand All @@ -224,9 +238,9 @@ impl Handler {
async fn do_connect_datagram(
&self,
sess: &Session,
_resolver: ThreadSafeDNSResolver,
resolver: ThreadSafeDNSResolver,
) -> Result<BoxedChainedDatagram> {
let conn = self.get_conn().await?;
let conn = self.get_conn(&resolver, sess.packet_mark).await?;

let assos_id = self.next_assoc_id.fetch_add(1, Ordering::Relaxed);
let quic_udp = TuicDatagramOutbound::new(assos_id, conn, sess.source.into());
Expand Down
118 changes: 65 additions & 53 deletions clash_lib/src/proxy/tuic/types.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
use crate::app::dns::ThreadSafeDNSResolver;
use crate::proxy::utils::StdSocketExt;
use crate::session::SocksAddr as ClashSocksAddr;
use anyhow::Result;
use quinn::Connection as QuinnConnection;
use quinn::{Endpoint as QuinnEndpoint, ZeroRttAccepted};
use register_count::Counter;
use std::collections::HashMap;
use std::sync::atomic::Ordering;
use std::{
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket},
str::FromStr,
Expand All @@ -26,63 +29,71 @@ pub struct TuicEndpoint {
pub heartbeat: Duration,
pub gc_interval: Duration,
pub gc_lifetime: Duration,
pub mark: Arc<AtomicU32>,
}
impl TuicEndpoint {
pub async fn connect(&self) -> Result<Arc<TuicConnection>> {
let mut last_err = None;
pub async fn connect(&self, resolver: &ThreadSafeDNSResolver, rebind: bool) -> Result<Arc<TuicConnection>> {
let remote_addr = self.server.resolve(resolver).await?;
let connect_to = async {
let match_ipv4 = remote_addr.is_ipv4()
&& self
.ep
.local_addr()
.map_or(false, |local_addr| local_addr.is_ipv4());
let match_ipv6 = remote_addr.is_ipv6()
&& self
.ep
.local_addr()
.map_or(false, |local_addr| local_addr.is_ipv6());

for addr in self.server.resolve().await? {
let connect_to = async {
let match_ipv4 =
addr.is_ipv4() && self.ep.local_addr().map_or(false, |addr| addr.is_ipv4());
let match_ipv6 =
addr.is_ipv6() && self.ep.local_addr().map_or(false, |addr| addr.is_ipv6());

if !match_ipv4 && !match_ipv6 {
let bind_addr = if addr.is_ipv4() {
SocketAddr::from((Ipv4Addr::UNSPECIFIED, 0))
} else {
SocketAddr::from((Ipv6Addr::UNSPECIFIED, 0))
};

self.ep
.rebind(UdpSocket::bind(bind_addr).map_err(|err| {
anyhow!("failed to create endpoint UDP socket {}", err)
})?)
.map_err(|err| anyhow!("failed to rebind endpoint UDP socket {}", err))?;
}

tracing::trace!("Connect to {} {}", addr, self.server.server_name());
let conn = self.ep.connect(addr, self.server.server_name())?;
let (conn, zero_rtt_accepted) = if self.zero_rtt_handshake {
match conn.into_0rtt() {
Ok((conn, zero_rtt_accepted)) => (conn, Some(zero_rtt_accepted)),
Err(conn) => (conn.await?, None),
}
// if client and server don't match each other or forced to rebind, then rebind local socket
if (!match_ipv4 && !match_ipv6) || rebind {
let bind_addr = if remote_addr.is_ipv4() {
SocketAddr::from((Ipv4Addr::UNSPECIFIED, 0))
} else {
(conn.await?, None)
SocketAddr::from((Ipv6Addr::UNSPECIFIED, 0))
};
let socket = UdpSocket::bind(bind_addr)
.map_err(|err| anyhow!("failed to bind local socket: {}", err))?;
let mark = self.mark.load(Ordering::Relaxed);
// ignore mark == 0, just for convenient
if mark != 0 {
socket.set_mark(mark)?;
}
self.ep
.rebind(socket)
.map_err(|err| anyhow!("failed to rebind endpoint UDP socket {}", err))?;
}

Ok((conn, zero_rtt_accepted))
tracing::trace!("Connect to {} {}", remote_addr, self.server.server_name());
let conn = self.ep.connect(remote_addr, self.server.server_name())?;
let (conn, zero_rtt_accepted) = if self.zero_rtt_handshake {
match conn.into_0rtt() {
Ok((conn, zero_rtt_accepted)) => (conn, Some(zero_rtt_accepted)),
Err(conn) => (conn.await?, None),
}
} else {
(conn.await?, None)
};

match connect_to.await {
Ok((conn, zero_rtt_accepted)) => {
return Ok(TuicConnection::new(
conn,
zero_rtt_accepted,
self.udp_relay_mode,
self.uuid,
self.password.clone(),
self.heartbeat,
self.gc_interval,
self.gc_lifetime,
));
}
Err(err) => last_err = Some(err),
Ok((conn, zero_rtt_accepted))
};

match connect_to.await {
Ok((conn, zero_rtt_accepted)) => {
return Ok(TuicConnection::new(
conn,
zero_rtt_accepted,
self.udp_relay_mode,
self.uuid,
self.password.clone(),
self.heartbeat,
self.gc_interval,
self.gc_lifetime,
));
}
Err(err) => Err(err),
}
Err(last_err.unwrap_or(anyhow!("dns resolve")))
}
}

Expand Down Expand Up @@ -194,15 +205,16 @@ impl ServerAddr {
pub fn server_name(&self) -> &str {
&self.domain
}
// TODO change to clash dns?
pub async fn resolve(&self) -> Result<impl Iterator<Item = SocketAddr>> {

pub async fn resolve(&self, resolver: &ThreadSafeDNSResolver) -> Result<SocketAddr> {
if let Some(ip) = self.ip {
Ok(vec![SocketAddr::from((ip, self.port))].into_iter())
Ok(SocketAddr::from((ip, self.port)))
} else {
Ok(tokio::net::lookup_host((self.domain.as_str(), self.port))
let ip = resolver
.resolve(self.domain.as_str(), false)
.await?
.collect::<Vec<_>>()
.into_iter())
.ok_or(anyhow!("Resolve failed: unknown hostname"))?;
Ok(SocketAddr::from((ip, self.port)))
}
}
}
Expand Down

0 comments on commit e926959

Please sign in to comment.