Commit
replace checker task according to the number of servers
- optimize memory consumption
zonyitoo committed Oct 6, 2021
1 parent 62d1445 commit eff3853
Showing 2 changed files with 96 additions and 82 deletions.
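The change in a nutshell: the whole `PingBalancerContext` now sits behind an `ArcSwap`, and reloading the configuration builds a fresh context, aborts the old checker task, and installs a new one matched to the new server list (a `future::pending()` no-op when probing is unnecessary). Below is a minimal sketch of that swap-state-and-restart-worker pattern, assuming the `arc_swap`, `spin`, `futures`, and `tokio` crates; the names `State`, `Shared`, `reload`, and `worker_task` are illustrative, not from the repository:

use std::sync::Arc;

use arc_swap::ArcSwap;
use spin::Mutex as SpinMutex;
use tokio::task::JoinHandle;

struct State {
    servers: Vec<String>, // stand-in for Vec<Arc<ServerIdent>>
}

struct Shared {
    state: ArcSwap<State>,
    worker: SpinMutex<JoinHandle<()>>,
}

// Background checker: park forever when there is nothing to probe,
// instead of waking up on an interval for no reason.
async fn worker_task(state: Arc<State>) {
    if state.servers.len() <= 1 {
        futures::future::pending::<()>().await;
    }
    // ... periodic latency checks would run here ...
}

impl Shared {
    fn reload(&self, servers: Vec<String>) {
        let new_state = Arc::new(State { servers });
        let handle = tokio::spawn(worker_task(new_state.clone()));

        // The critical section is tiny and never awaits, so a spin lock is enough.
        let mut guard = self.worker.lock();
        guard.abort(); // stop the checker built for the old server list
        *guard = handle;
        drop(guard);

        // Publish the new state; concurrent readers see old or new, never a mix.
        self.state.store(new_state);
    }
}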
2 changes: 1 addition & 1 deletion crates/shadowsocks-service/src/local/dns/server.rs
@@ -328,7 +328,7 @@ fn should_forward_by_query(context: &ServiceContext, balancer: &PingBalancer, qu
     //
     // This happens normally because VPN or TUN device receives DNS queries from local servers' plugins
     // https://github.com/shadowsocks/shadowsocks-android/issues/2722
-    for server in balancer.servers().as_ref() {
+    for server in balancer.servers() {
         let svr_cfg = server.server_config();
         if let ServerAddr::DomainName(ref dn, ..) = svr_cfg.addr() {
             // Convert domain name to `Name`
176 changes: 95 additions & 81 deletions crates/shadowsocks-service/src/local/loadbalancing/ping_balancer.rs
@@ -25,6 +25,7 @@ use shadowsocks::{
     },
     ServerConfig,
 };
+use spin::Mutex as SpinMutex;
 use tokio::{
     io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
     task::JoinHandle,
@@ -146,7 +147,7 @@ impl PingBalancerBuilder {
         let (best_tcp_idx, best_udp_idx) = PingBalancerBuilder::find_best_idx(&self.servers, self.mode);
 
         let balancer_context = PingBalancerContext {
-            servers: ArcSwap::new(Arc::new(self.servers)),
+            servers: self.servers,
             best_tcp_idx: AtomicUsize::new(best_tcp_idx),
             best_udp_idx: AtomicUsize::new(best_udp_idx),
             context: self.context,
@@ -166,15 +167,15 @@ impl PingBalancerBuilder {

         PingBalancer {
             inner: Arc::new(PingBalancerInner {
-                context: shared_context,
-                checker_abortable,
+                context: ArcSwap::new(shared_context),
+                checker_abortable: SpinMutex::new(checker_abortable),
             }),
         }
     }
 }
 
 struct PingBalancerContext {
-    servers: ArcSwap<Vec<Arc<ServerIdent>>>,
+    servers: Vec<Arc<ServerIdent>>,
     best_tcp_idx: AtomicUsize,
     best_udp_idx: AtomicUsize,
     context: Arc<ServiceContext>,
@@ -185,62 +186,17 @@ struct PingBalancerContext {

 impl PingBalancerContext {
     fn best_tcp_server(&self) -> Arc<ServerIdent> {
-        let servers = self.servers.load();
-
-        // A guard if reset_servers is running
-        let mut best_tcp_idx = self.best_tcp_idx.load(Ordering::Relaxed);
-        if best_tcp_idx >= servers.len() {
-            best_tcp_idx = 0;
-        }
-
-        servers[best_tcp_idx].clone()
+        self.servers[self.best_tcp_idx.load(Ordering::Relaxed)].clone()
     }
 
     fn best_udp_server(&self) -> Arc<ServerIdent> {
-        let servers = self.servers.load();
-
-        // A guard if reset_servers is running
-        let mut best_udp_idx = self.best_udp_idx.load(Ordering::Relaxed);
-        if best_udp_idx >= servers.len() {
-            best_udp_idx = 0;
-        }
-
-        servers[best_udp_idx].clone()
+        self.servers[self.best_udp_idx.load(Ordering::Relaxed)].clone()
     }
 }
 
 impl PingBalancerContext {
-    pub async fn reset_servers(&self, servers: Vec<ServerConfig>) {
-        let max_server_rtt = self.max_server_rtt;
-        let servers = servers
-            .into_iter()
-            .map(|s| Arc::new(ServerIdent::new(s, max_server_rtt)))
-            .collect();
-        self.reset_servers_ident(servers).await
-    }
-
-    pub async fn reset_servers_ident(&self, servers: Vec<Arc<ServerIdent>>) {
-        assert!(!servers.is_empty(), "servers shouldn't be empty");
-
-        // Restore best_tcp_idx and best_udp_idx
-        let (best_tcp_idx, best_udp_idx) = PingBalancerBuilder::find_best_idx(&servers, self.mode);
-
-        // Create a new Arc servers
-        let new_servers = Arc::new(servers);
-
-        // Replace into context and then run the checker task once
-        self.servers.store(new_servers);
-        self.best_tcp_idx.store(best_tcp_idx, Ordering::Release);
-        self.best_udp_idx.store(best_udp_idx, Ordering::Release);
-
-        self.check_once(true).await;
-    }
-
     async fn init_score(&self) {
-        assert!(
-            !self.servers.load().is_empty(),
-            "check PingBalancer without any servers"
-        );
+        assert!(!self.servers.is_empty(), "check PingBalancer without any servers");
 
         self.check_once(true).await;
     }
@@ -253,11 +209,11 @@ impl PingBalancerContext {
         svr_cfg.mode().enable_udp() && svr_cfg.weight().udp_weight() > 0.0
     }
 
-    fn probing_required(&self, servers: &[Arc<ServerIdent>]) -> bool {
+    fn probing_required(&self) -> bool {
         let mut tcp_count = 0;
         let mut udp_count = 0;
 
-        for server in servers.iter() {
+        for server in self.servers.iter() {
             let svr_cfg = server.server_config();
             if self.mode.enable_tcp() && PingBalancerContext::check_server_tcp_enabled(svr_cfg) {
                 tcp_count += 1;
@@ -271,20 +227,23 @@ impl PingBalancerContext {
     }
 
     async fn checker_task(self: Arc<Self>) {
-        assert!(
-            !self.servers.load().is_empty(),
-            "check PingBalancer without any servers"
-        );
+        assert!(!self.servers.is_empty(), "check PingBalancer without any servers");
 
-        self.checker_task_real().await
+        if !self.probing_required() {
+            self.checker_task_dummy().await
+        } else {
+            self.checker_task_real().await
+        }
     }
 
+    /// Dummy task that does nothing if there is only one server in the balancer
+    async fn checker_task_dummy(self: Arc<Self>) {
+        future::pending().await
+    }
+
     /// Check each server's score and update the best server's index
     async fn check_once(&self, first_run: bool) {
-        let servers = self.servers.load();
-        if !self.probing_required(&*servers) {
-            return;
-        }
+        let servers = &self.servers;
 
         let mut vfut_tcp = Vec::with_capacity(servers.len());
         let mut vfut_udp = Vec::with_capacity(servers.len());
@@ -414,13 +373,13 @@ impl PingBalancerContext {
 }
 
 struct PingBalancerInner {
-    context: Arc<PingBalancerContext>,
-    checker_abortable: JoinHandle<()>,
+    context: ArcSwap<PingBalancerContext>,
+    checker_abortable: SpinMutex<JoinHandle<()>>,
 }
 
 impl Drop for PingBalancerInner {
     fn drop(&mut self) {
-        self.checker_abortable.abort();
+        self.checker_abortable.lock().abort();
         trace!("ping balancer stopped");
     }
 }
@@ -434,42 +393,82 @@ pub struct PingBalancer {
 impl PingBalancer {
     /// Get service context
     pub fn context(&self) -> Arc<ServiceContext> {
-        self.inner.context.context.clone()
-    }
-
-    /// Get reference of the service context
-    pub fn context_ref(&self) -> &ServiceContext {
-        self.inner.context.context.as_ref()
+        let context = self.inner.context.load();
+        context.context.clone()
     }
 
     /// Pick the best TCP server
     pub fn best_tcp_server(&self) -> Arc<ServerIdent> {
-        self.inner.context.best_tcp_server()
+        let context = self.inner.context.load();
+        context.best_tcp_server()
     }
 
     /// Pick the best UDP server
     pub fn best_udp_server(&self) -> Arc<ServerIdent> {
-        self.inner.context.best_udp_server()
+        let context = self.inner.context.load();
+        context.best_udp_server()
    }
 
     /// Get the server list
-    pub fn servers(&self) -> Arc<Vec<Arc<ServerIdent>>> {
-        self.inner.context.servers.load().clone()
+    pub fn servers<'a>(&'a self) -> PingServerIter<'a> {
+        let context = self.inner.context.load();
+        let servers: &Vec<Arc<ServerIdent>> = unsafe { &*(&context.servers as *const _) };
+        PingServerIter {
+            context: context.clone(),
+            iter: servers.iter(),
+        }
     }
 
     /// Reset servers in load balancer. Designed for auto-reloading configuration file.
-    #[inline]
     pub async fn reset_servers(&self, servers: Vec<ServerConfig>) {
-        self.inner.context.reset_servers(servers).await
+        let old_context = self.inner.context.load();
+
+        let servers = servers
+            .into_iter()
+            .map(|s| Arc::new(ServerIdent::new(s, old_context.max_server_rtt)))
+            .collect::<Vec<Arc<ServerIdent>>>();
+        let (best_tcp_idx, best_udp_idx) = PingBalancerBuilder::find_best_idx(&servers, old_context.mode);
+
+        // Build a new context with the new server list; the old checker task is stopped below.
+        let context = PingBalancerContext {
+            servers,
+            best_tcp_idx: AtomicUsize::new(best_tcp_idx),
+            best_udp_idx: AtomicUsize::new(best_udp_idx),
+            context: old_context.context.clone(),
+            mode: old_context.mode,
+            max_server_rtt: old_context.max_server_rtt,
+            check_interval: old_context.check_interval,
+        };
+
+        context.init_score().await;
+
+        let shared_context = Arc::new(context);
+
+        let checker_abortable = {
+            let shared_context = shared_context.clone();
+            tokio::spawn(async move { shared_context.checker_task().await })
+        };
+
+        {
+            // Stop the previous task and replace it with the new one
+            let mut abortable = self.inner.checker_abortable.lock();
+            abortable.abort();
+            *abortable = checker_abortable;
+        }
+
+        // Replace with the new context
+        self.inner.context.store(shared_context);
     }
 }
 
 impl Debug for PingBalancer {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        let context = self.inner.context.load();
+
         f.debug_struct("PingBalancer")
-            .field("servers", &self.inner.context.servers)
-            .field("best_tcp_idx", &self.inner.context.best_tcp_idx.load(Ordering::Relaxed))
-            .field("best_udp_idx", &self.inner.context.best_udp_idx.load(Ordering::Relaxed))
+            .field("servers", &context.servers)
+            .field("best_tcp_idx", &context.best_tcp_idx.load(Ordering::Relaxed))
+            .field("best_udp_idx", &context.best_udp_idx.load(Ordering::Relaxed))
             .finish()
     }
 }
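How a caller might drive the new `reset_servers` — a hedged sketch; the reload handler itself is assumed, not part of this commit:

// Hypothetical config-reload hook.
async fn on_config_reloaded(balancer: &PingBalancer, new_servers: Vec<ServerConfig>) {
    // reset_servers scores the new list once (init_score) before publishing it,
    // so readers never observe an unprobed server set.
    balancer.reset_servers(new_servers).await;
}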
@@ -706,3 +705,18 @@ impl Display for ServerConfigFormatter<'_> {
         }
     }
 }
+
+/// Server Iterator
+pub struct PingServerIter<'a> {
+    #[allow(dead_code)]
+    context: Arc<PingBalancerContext>,
+    iter: std::slice::Iter<'a, Arc<ServerIdent>>,
+}
+
+impl<'a> Iterator for PingServerIter<'a> {
+    type Item = &'a ServerIdent;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        self.iter.next().map(AsRef::as_ref)
+    }
+}
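Because `servers()` now returns a `PingServerIter` instead of an `Arc<Vec<Arc<ServerIdent>>>`, call sites iterate it directly, which is exactly the one-line change in dns/server.rs above. A minimal usage sketch, assuming an already-built `balancer: PingBalancer`:

for server in balancer.servers() {
    // Items are plain `&ServerIdent` references; the iterator's Arc clone of
    // the balancer context keeps the backing Vec alive while they're borrowed.
    let svr_cfg = server.server_config();
    println!("server: {}", svr_cfg.addr());
}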
