Skip to content

Commit

Permalink
Refactor selection algorithms and LB
Browse files Browse the repository at this point in the history
  • Loading branch information
hippalus committed Aug 19, 2024
1 parent ba60ede commit d4169eb
Show file tree
Hide file tree
Showing 7 changed files with 223 additions and 88 deletions.
2 changes: 1 addition & 1 deletion config/development.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,6 @@ ca_path = ""
upstream_host = "localhost"
upstream_port = 1994
discovery_type = "dns"
discovery_refresh_interval = 300
discovery_refresh_interval = 60
load_balancer_selection = "round_robin"

8 changes: 4 additions & 4 deletions umay/src/app/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::io::Read;
use std::net::SocketAddr;
use std::time::Duration;
use std::{env, fs};
use tracing::{info, warn};
use tracing::{debug, warn};
use webpki::types::ServerName;

const CONFIG_BASE_PATH: &str = "config/";
Expand Down Expand Up @@ -82,8 +82,8 @@ impl ServiceConfig {
&self.discovery_type
}

pub fn discovery_refresh_interval(&self) -> u64 {
self.discovery_refresh_interval
pub fn discovery_refresh_interval(&self) -> Duration {
Duration::from_secs(self.discovery_refresh_interval)
}

pub fn load_balancer_selection(&self) -> &str {
Expand Down Expand Up @@ -142,7 +142,7 @@ impl AppConfig {

AppConfig::set_env_vars(&mut app_config)?;

info!("Configuration loaded successfully {:?}", app_config);
debug!("Configuration loaded successfully {:?}", app_config);
Ok(app_config)
}

Expand Down
21 changes: 10 additions & 11 deletions umay/src/app/server.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use crate::app::config::{AppConfig, ServiceConfig};
use crate::app::metric::Metrics;
use crate::balance::discovery::{DnsDiscovery, LocalDiscovery, ServiceDiscovery};
use crate::balance::{Backends, LoadBalancer, Selector};
use crate::balance::selection::SelectionAlgorithm;
use crate::balance::{selection, Backends, LoadBalancer};
use crate::proxy::ProxyService;
use crate::tls;
use crate::tls::credentials::Store;
Expand Down Expand Up @@ -50,7 +51,7 @@ impl Server {
let listener = bind_listener(service_config.port()).await?;

info!("Listening on 0.0.0.0:{}", service_config.port());
self.start_load_balancer_refresh();
self.start_load_balancer_refresh(service_config.discovery_refresh_interval());

loop {
tokio::select! {
Expand Down Expand Up @@ -78,9 +79,8 @@ impl Server {
Ok(())
}

fn start_load_balancer_refresh(&self) {
fn start_load_balancer_refresh(&self, refresh_interval: Duration) {
let lb = self.proxy_service.load_balancer().clone();
let refresh_interval = Duration::from_secs(30);
tokio::spawn(async move {
lb.start_refresh_task(refresh_interval).await;
});
Expand All @@ -105,6 +105,7 @@ async fn bind_listener(port: u16) -> Result<TcpListener> {
.await
.context(format!("Failed to bind to address: {}", listen_addr))
}

fn initialize_tls_server(store: &Store) -> Result<Arc<tls::server::Server>> {
Ok(Arc::new(tls::server::Server::new(
store.server_name().to_owned(),
Expand Down Expand Up @@ -137,14 +138,12 @@ fn create_discovery(
}
}

fn create_selector(config: &ServiceConfig) -> Result<Selector> {
fn create_selector(config: &ServiceConfig) -> Result<Arc<dyn SelectionAlgorithm + Send + Sync>> {
match config.load_balancer_selection() {
"round_robin" => Ok(Selector::RoundRobin(Arc::new(tokio::sync::Mutex::new(0)))),
"random" => Ok(Selector::Random),
"least_connection" => Ok(Selector::LeastConnection(Arc::new(
tokio::sync::Mutex::new(Vec::new()),
))),
"consistent_hashing" => Ok(Selector::ConsistentHashing),
"random" => Ok(Arc::new(selection::Random)),
"round_robin" => Ok(Arc::new(selection::RoundRobin::default())),
"weighted_round_robin" => Ok(Arc::new(selection::WeightedRoundRobin::default())),
"least_connection" => Ok(Arc::new(selection::LeastConnections::default())),
_ => Err(anyhow!(
"Invalid load balancer selection: {}",
config.load_balancer_selection()
Expand Down
47 changes: 6 additions & 41 deletions umay/src/balance/mod.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
use crate::balance::discovery::ServiceDiscovery;
use crate::balance::selection::SelectionAlgorithm;
use anyhow::Result;
use arc_swap::ArcSwap;
use rand::prelude::IteratorRandom;
use std::collections::BTreeSet;
use std::hash::{DefaultHasher, Hash, Hasher};
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Mutex;
use tracing::error;

pub mod discovery;
pub mod selection;

#[derive(Clone, Hash, PartialEq, PartialOrd, Eq, Ord, Debug)]
pub struct Backend {
Expand Down Expand Up @@ -54,23 +54,15 @@ impl Backends {
}
}

#[derive(Clone)]
pub enum Selector {
Random,
RoundRobin(Arc<Mutex<usize>>),
LeastConnection(Arc<Mutex<Vec<(Backend, usize)>>>),
ConsistentHashing,
}

pub struct LoadBalancer {
selector: Selector,
selection: Arc<dyn SelectionAlgorithm + Send + Sync>,
backends: Arc<Backends>,
}

impl LoadBalancer {
pub fn new(backends: Backends, selector: Selector) -> Self {
pub fn new(backends: Backends, selection: Arc<dyn SelectionAlgorithm>) -> Self {
Self {
selector,
selection,
backends: Arc::new(backends),
}
}
Expand All @@ -81,34 +73,7 @@ impl LoadBalancer {
return None;
}

match &self.selector {
Selector::Random => backends.iter().choose(&mut rand::thread_rng()).cloned(),
Selector::RoundRobin(counter) => {
let mut index = counter.lock().await;
*index = (*index + 1) % backends.len();
backends.iter().nth(*index).cloned()
}
Selector::LeastConnection(connections) => {
let mut conns = connections.lock().await;
conns.sort_by_key(|(_, count)| *count);
conns.first().map(|(backend, _)| backend.clone())
}
Selector::ConsistentHashing => {
if let Some(key) = key {
let hash = {
let mut hasher = DefaultHasher::new();
key.hash(&mut hasher);
hasher.finish()
};
backends
.iter()
.min_by_key(|backend| backend.hash_key() ^ hash)
.cloned()
} else {
None
}
}
}
self.selection.select(&backends).await
}

pub async fn start_refresh_task(self: Arc<Self>, duration: Duration) {
Expand Down
149 changes: 149 additions & 0 deletions umay/src/balance/selection.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
use crate::balance::Backend;
use arc_swap::ArcSwap;
use async_trait::async_trait;
use rand::Rng;
use std::collections::{BTreeMap, BTreeSet};
use std::hash::{DefaultHasher, Hash, Hasher};
use std::net::SocketAddr;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

#[async_trait]
pub trait SelectionAlgorithm: Send + Sync {
async fn select(&self, backends: &Arc<BTreeSet<Backend>>) -> Option<Backend>;
}

pub struct RoundRobin {
index: AtomicUsize,
}

impl Default for RoundRobin {
fn default() -> Self {
Self {
index: AtomicUsize::new(0),
}
}
}
#[async_trait]
impl SelectionAlgorithm for RoundRobin {
async fn select(&self, backends: &Arc<BTreeSet<Backend>>) -> Option<Backend> {
let len = backends.len();
if len == 0 {
return None;
}
let index = self.index.fetch_add(1, Ordering::Relaxed) % len;
backends.iter().nth(index).cloned()
}
}
pub struct WeightedRoundRobin {
index: AtomicUsize,
}

impl Default for WeightedRoundRobin {
fn default() -> Self {
Self {
index: AtomicUsize::new(0),
}
}
}

#[async_trait]
impl SelectionAlgorithm for WeightedRoundRobin {
async fn select(&self, backends: &Arc<BTreeSet<Backend>>) -> Option<Backend> {
let total_weight: usize = backends.iter().map(|b| b.weight).sum();
if total_weight == 0 {
return None;
}
let mut index = self.index.fetch_add(1, Ordering::Relaxed) % total_weight;
for backend in backends.iter() {
if index < backend.weight {
return Some(backend.clone());
}
index -= backend.weight;
}
None
}
}

pub struct LeastConnections {
connections: ArcSwap<BTreeMap<SocketAddr, usize>>,
}

impl Default for LeastConnections {
fn default() -> Self {
Self {
connections: ArcSwap::from_pointee(BTreeMap::new()),
}
}
}

impl LeastConnections {
pub fn increment(&self, addr: &SocketAddr) {
self.connections.rcu(|connections| {
let mut new_connections = connections.as_ref().clone();
*new_connections.entry(*addr).or_insert(0) += 1;
new_connections
});
}

pub fn decrement(&self, addr: &SocketAddr) {
self.connections.rcu(|connections| {
let mut new_connections = connections.as_ref().clone();
if let Some(count) = new_connections.get_mut(addr) {
if *count > 0 {
*count -= 1;
}
}
new_connections
});
}
}

#[async_trait]
impl SelectionAlgorithm for LeastConnections {
async fn select(&self, backends: &Arc<BTreeSet<Backend>>) -> Option<Backend> {
let connections = self.connections.load();
backends
.iter()
.min_by_key(|b| connections.get(&b.addr).unwrap_or(&0))
.cloned()
}
}

#[derive(Default)]
pub struct Random;

#[async_trait]
impl SelectionAlgorithm for Random {
async fn select(&self, backends: &Arc<BTreeSet<Backend>>) -> Option<Backend> {
if backends.is_empty() {
return None;
}
let mut rng = rand::thread_rng();
let index = rng.gen_range(0..backends.len());
backends.iter().nth(index).cloned()
}
}

pub struct ConsistentHashing {
virtual_nodes: usize,
}

impl ConsistentHashing {
pub fn new(virtual_nodes: usize) -> Self {
ConsistentHashing { virtual_nodes }
}

fn hash<T: Hash>(t: &T) -> u64 {
let mut s = DefaultHasher::new();
t.hash(&mut s);
s.finish()
}
}

#[async_trait]
impl SelectionAlgorithm for ConsistentHashing {
async fn select(&self, backends: &Arc<BTreeSet<Backend>>) -> Option<Backend> {
todo!()
}
}
Loading

0 comments on commit d4169eb

Please sign in to comment.