Skip to content

Commit

Permalink
CR changes: Add async method to DisconnectNotifier trait, styling and…
Browse files Browse the repository at this point in the history
… other cleanups
  • Loading branch information
ikolomi committed Sep 3, 2024
1 parent 73ff308 commit 24c19dd
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 53 deletions.
5 changes: 5 additions & 0 deletions redis/src/aio/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use std::net::SocketAddr;
#[cfg(unix)]
use std::path::Path;
use std::pin::Pin;
use std::time::Duration;

/// Enables the async_std compatibility
#[cfg(feature = "async-std-comp")]
Expand Down Expand Up @@ -91,10 +92,14 @@ pub trait ConnectionLike {
}

/// Implements ability to notify about disconnection events
#[async_trait]
pub trait DisconnectNotifier: Send + Sync {
/// Notify about disconnect event
fn notify_disconnect(&mut self);

/// Wait for disconnect event with timeout
async fn wait_for_disconnect_with_timeout(&self, max_wait: &Duration);

/// Intended to be used with Box
fn clone_box(&self) -> Box<dyn DisconnectNotifier>;
}
Expand Down
117 changes: 64 additions & 53 deletions redis/src/cluster_async/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ use backoff_std_async::future::retry;
#[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))]
use backoff_std_async::{Error as BackoffError, ExponentialBackoff};

#[cfg(feature = "tokio-comp")]
use async_trait::async_trait;
#[cfg(feature = "tokio-comp")]
use backoff_tokio::future::retry;
#[cfg(feature = "tokio-comp")]
Expand Down Expand Up @@ -379,20 +381,37 @@ where
#[cfg(feature = "tokio-comp")]
#[derive(Clone)]
struct TokioDisconnectNotifier {
pub disconnect_notifier: Arc<Notify>,
disconnect_notifier: Arc<Notify>,
}

#[cfg(feature = "tokio-comp")]
#[async_trait]
impl DisconnectNotifier for TokioDisconnectNotifier {
fn notify_disconnect(&mut self) {
self.disconnect_notifier.notify_one();
}

async fn wait_for_disconnect_with_timeout(&self, max_wait: &Duration) {
let _ = timeout(*max_wait, async {
self.disconnect_notifier.notified().await;
})
.await;
}

fn clone_box(&self) -> Box<dyn DisconnectNotifier> {
Box::new(self.clone())
}
}

#[cfg(feature = "tokio-comp")]
impl TokioDisconnectNotifier {
fn new() -> TokioDisconnectNotifier {
TokioDisconnectNotifier {
disconnect_notifier: Arc::new(Notify::new()),
}
}
}

type ConnectionMap<C> = connections_container::ConnectionsMap<ConnectionFuture<C>>;
type ConnectionsContainer<C> =
self::connections_container::ConnectionsContainer<ConnectionFuture<C>>;
Expand All @@ -406,8 +425,6 @@ pub(crate) struct InnerCore<C> {
subscriptions_by_address: RwLock<HashMap<String, PubSubSubscriptionInfo>>,
unassigned_subscriptions: RwLock<PubSubSubscriptionInfo>,
glide_connection_options: GlideConnectionOptions,
#[cfg(feature = "tokio-comp")]
tokio_notify: Arc<Notify>,
}

pub(crate) type Core<C> = Arc<InnerCore<C>>;
Expand Down Expand Up @@ -990,27 +1007,24 @@ where
cluster_params: ClusterParams,
push_sender: Option<mpsc::UnboundedSender<PushInfo>>,
) -> RedisResult<Disposable<Self>> {
#[cfg(feature = "tokio-comp")]
let tokio_notify = Arc::new(Notify::new());

let disconnect_notifier = {
#[cfg(feature = "tokio-comp")]
{
Some::<Box<dyn DisconnectNotifier>>(Box::new(TokioDisconnectNotifier {
disconnect_notifier: tokio_notify.clone(),
}))
Some::<Box<dyn DisconnectNotifier>>(Box::new(TokioDisconnectNotifier::new()))
}
#[cfg(not(feature = "tokio-comp"))]
None
};

let glide_connection_options = GlideConnectionOptions {
push_sender,
disconnect_notifier,
};

let connections = Self::create_initial_connections(
initial_nodes,
&cluster_params,
GlideConnectionOptions {
push_sender: push_sender.clone(),
disconnect_notifier: disconnect_notifier.clone(),
},
glide_connection_options.clone(),
)
.await?;

Expand All @@ -1035,12 +1049,7 @@ where
},
),
subscriptions_by_address: RwLock::new(Default::default()),
glide_connection_options: GlideConnectionOptions {
push_sender: push_sender.clone(),
disconnect_notifier: disconnect_notifier.clone(),
},
#[cfg(feature = "tokio-comp")]
tokio_notify,
glide_connection_options,
});
let mut connection = ClusterConnInner {
inner,
Expand Down Expand Up @@ -1227,40 +1236,40 @@ where
// In addition, the validation is done by peeking at the state of the underlying transport w/o overhead of additional commands to server.
async fn validate_all_user_connections(inner: Arc<InnerCore<C>>) {
let mut all_valid_conns = HashMap::new();
let mut all_nodes_with_slots = HashSet::new();
// prep connections and clean out these w/o assigned slots, as we might have established connections to unwanted hosts
{
let mut nodes_to_delete = Vec::new();
let connections_container = inner.conn_lock.read().await;

connections_container
.slot_map
.addresses_for_all_nodes()
.iter()
.for_each(|addr| {
all_nodes_with_slots.insert(String::from(*addr));
});
let mut nodes_to_delete = Vec::new();
let connections_container = inner.conn_lock.read().await;

connections_container
.all_node_connections()
.for_each(|(addr, con)| {
if all_nodes_with_slots.contains(&addr) {
all_valid_conns.insert(addr.clone(), con.clone());
} else {
nodes_to_delete.push(addr.clone());
}
});
let all_nodes_with_slots: HashSet<String> = connections_container
.slot_map
.addresses_for_all_nodes()
.iter()
.map(|addr| String::from(*addr))
.collect();

connections_container
.all_node_connections()
.for_each(|(addr, con)| {
if all_nodes_with_slots.contains(&addr) {
all_valid_conns.insert(addr.clone(), con.clone());
} else {
nodes_to_delete.push(addr.clone());
}
});

for addr in &nodes_to_delete {
connections_container.remove_node(addr);
}
for addr in &nodes_to_delete {
connections_container.remove_node(addr);
}

drop(connections_container);

// identify nodes with closed connection
let mut addrs_to_refresh = Vec::new();
for (addr, con_fut) in &all_valid_conns {
let con = con_fut.clone().await;
// connection object might be present despite the transport being closed
if con.is_closed() {
// transport is closed, need to refresh
addrs_to_refresh.push(addr.clone());
}
}
Expand Down Expand Up @@ -1289,7 +1298,7 @@ where
inner: Arc<InnerCore<C>>,
addresses: Vec<String>,
conn_type: RefreshConnectionType,
try_existing_node: bool,
check_existing_conn: bool,
) {
info!("Started refreshing connections to {:?}", addresses);
let connections_container = inner.conn_lock.read().await;
Expand All @@ -1301,10 +1310,10 @@ where
.fold(
&*connections_container,
|connections_container, address| async move {
let node_option = if try_existing_node {
let node_option = if check_existing_conn {
connections_container.remove_node(&address)
} else {
Option::None
None
};

// override subscriptions for this connection
Expand Down Expand Up @@ -1541,13 +1550,15 @@ where

async fn connections_validation_task(inner: Arc<InnerCore<C>>, interval_duration: Duration) {
loop {
#[cfg(feature = "tokio-comp")]
let _ = timeout(interval_duration, async {
inner.tokio_notify.notified().await;
})
.await;
#[cfg(not(feature = "tokio-comp"))]
let _ = boxed_sleep(interval_duration).await;
if let Some(disconnect_notifier) =
inner.glide_connection_options.disconnect_notifier.clone()
{
disconnect_notifier
.wait_for_disconnect_with_timeout(&interval_duration)
.await;
} else {
let _ = boxed_sleep(interval_duration).await;
}

Self::validate_all_user_connections(inner.clone()).await;
}
Expand Down

0 comments on commit 24c19dd

Please sign in to comment.