Skip to content

Commit

Permalink
Merge pull request #184 from amazon-contributing/fast_reconnect_cand
Browse files Browse the repository at this point in the history
Introduce a fast reconnect process for async cluster connections.
  • Loading branch information
ikolomi authored Sep 4, 2024
2 parents 2d7200f + 24c19dd commit 426bb99
Show file tree
Hide file tree
Showing 23 changed files with 1,083 additions and 712 deletions.
4 changes: 4 additions & 0 deletions redis-test/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,10 @@ impl AioConnectionLike for MockRedisConnection {
fn get_db(&self) -> i64 {
0
}

fn is_closed(&self) -> bool {
false
}
}

#[cfg(test)]
Expand Down
6 changes: 4 additions & 2 deletions redis/examples/async-await.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
#![allow(unknown_lints, dependency_on_unit_never_type_fallback)]
use redis::AsyncCommands;
use redis::{AsyncCommands, GlideConnectionOptions};

#[tokio::main]
async fn main() -> redis::RedisResult<()> {
let client = redis::Client::open("redis://127.0.0.1/").unwrap();
let mut con = client.get_multiplexed_async_connection(None).await?;
let mut con = client
.get_multiplexed_async_connection(GlideConnectionOptions::default())
.await?;

con.set("key1", b"foo").await?;

Expand Down
10 changes: 9 additions & 1 deletion redis/examples/async-connection-loss.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use std::time::Duration;

use futures::future;
use redis::aio::ConnectionLike;
use redis::GlideConnectionOptions;
use redis::RedisResult;
use tokio::time::interval;

Expand Down Expand Up @@ -80,7 +81,14 @@ async fn main() -> RedisResult<()> {

let client = redis::Client::open("redis://127.0.0.1/").unwrap();
match mode {
Mode::Default => run_multi(client.get_multiplexed_tokio_connection(None).await?).await?,
Mode::Default => {
run_multi(
client
.get_multiplexed_tokio_connection(GlideConnectionOptions::default())
.await?,
)
.await?
}
Mode::Reconnect => run_multi(client.get_connection_manager().await?).await?,
#[allow(deprecated)]
Mode::Deprecated => run_single(client.get_async_connection(None).await?).await?,
Expand Down
7 changes: 5 additions & 2 deletions redis/examples/async-multiplexed.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#![allow(unknown_lints, dependency_on_unit_never_type_fallback)]
use futures::prelude::*;
use redis::{aio::MultiplexedConnection, RedisResult};
use redis::{aio::MultiplexedConnection, GlideConnectionOptions, RedisResult};

async fn test_cmd(con: &MultiplexedConnection, i: i32) -> RedisResult<()> {
let mut con = con.clone();
Expand Down Expand Up @@ -34,7 +34,10 @@ async fn test_cmd(con: &MultiplexedConnection, i: i32) -> RedisResult<()> {
async fn main() {
let client = redis::Client::open("redis://127.0.0.1/").unwrap();

let con = client.get_multiplexed_tokio_connection(None).await.unwrap();
let con = client
.get_multiplexed_tokio_connection(GlideConnectionOptions::default())
.await
.unwrap();

let cmds = (0..100).map(|i| test_cmd(&con, i));
let result = future::try_join_all(cmds).await.unwrap();
Expand Down
6 changes: 4 additions & 2 deletions redis/examples/async-pub-sub.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
#![allow(unknown_lints, dependency_on_unit_never_type_fallback)]
use futures_util::StreamExt as _;
use redis::AsyncCommands;
use redis::{AsyncCommands, GlideConnectionOptions};

#[tokio::main]
async fn main() -> redis::RedisResult<()> {
let client = redis::Client::open("redis://127.0.0.1/").unwrap();
let mut publish_conn = client.get_multiplexed_async_connection(None).await?;
let mut publish_conn = client
.get_multiplexed_async_connection(GlideConnectionOptions::default())
.await?;
let mut pubsub_conn = client.get_async_pubsub().await?;

pubsub_conn.subscribe("wavephone").await?;
Expand Down
6 changes: 4 additions & 2 deletions redis/examples/async-scan.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
#![allow(unknown_lints, dependency_on_unit_never_type_fallback)]
use futures::stream::StreamExt;
use redis::{AsyncCommands, AsyncIter};
use redis::{AsyncCommands, AsyncIter, GlideConnectionOptions};

#[tokio::main]
async fn main() -> redis::RedisResult<()> {
let client = redis::Client::open("redis://127.0.0.1/").unwrap();
let mut con = client.get_multiplexed_async_connection(None).await?;
let mut con = client
.get_multiplexed_async_connection(GlideConnectionOptions::default())
.await?;

con.set("async-key1", b"foo").await?;
con.set("async-key2", b"foo").await?;
Expand Down
5 changes: 5 additions & 0 deletions redis/src/aio/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,11 @@ where
fn get_db(&self) -> i64 {
self.db
}

fn is_closed(&self) -> bool {
// always false for AsyncRead + AsyncWrite (cant do better)
false
}
}

/// Represents a `PubSub` connection.
Expand Down
8 changes: 7 additions & 1 deletion redis/src/aio/connection_manager.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use super::RedisFuture;
use crate::client::GlideConnectionOptions;
use crate::cmd::Cmd;
use crate::push_manager::PushManager;
use crate::types::{RedisError, RedisResult, Value};
Expand Down Expand Up @@ -195,7 +196,7 @@ impl ConnectionManager {
client.get_multiplexed_async_connection_with_timeouts(
response_timeout,
connection_timeout,
None,
GlideConnectionOptions::default(),
)
})
.await
Expand Down Expand Up @@ -301,4 +302,9 @@ impl ConnectionLike for ConnectionManager {
fn get_db(&self) -> i64 {
self.client.connection_info().redis.db
}

fn is_closed(&self) -> bool {
// always return false due to automatic reconnect
false
}
}
23 changes: 23 additions & 0 deletions redis/src/aio/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use std::net::SocketAddr;
#[cfg(unix)]
use std::path::Path;
use std::pin::Pin;
use std::time::Duration;

/// Enables the async_std compatibility
#[cfg(feature = "async-std-comp")]
Expand Down Expand Up @@ -85,6 +86,28 @@ pub trait ConnectionLike {
/// also might be incorrect if the connection like object is not
/// actually connected.
fn get_db(&self) -> i64;

/// Returns the state of the connection
fn is_closed(&self) -> bool;
}

/// Implements ability to notify about disconnection events
#[async_trait]
pub trait DisconnectNotifier: Send + Sync {
/// Notify about disconnect event
fn notify_disconnect(&mut self);

/// Wait for disconnect event with timeout
async fn wait_for_disconnect_with_timeout(&self, max_wait: &Duration);

/// Intended to be used with Box
fn clone_box(&self) -> Box<dyn DisconnectNotifier>;
}

impl Clone for Box<dyn DisconnectNotifier> {
fn clone(&self) -> Box<dyn DisconnectNotifier> {
self.clone_box()
}
}

// Initial setup for every connection.
Expand Down
71 changes: 51 additions & 20 deletions redis/src/aio/multiplexed_connection.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
use super::{ConnectionLike, Runtime};
use crate::aio::setup_connection;
use crate::aio::DisconnectNotifier;
use crate::client::GlideConnectionOptions;
use crate::cmd::Cmd;
#[cfg(any(feature = "tokio-comp", feature = "async-std-comp"))]
use crate::parser::ValueCodec;
use crate::push_manager::PushManager;
use crate::types::{RedisError, RedisFuture, RedisResult, Value};
use crate::{cmd, ConnectionInfo, ProtocolVersion, PushInfo, PushKind};
use crate::{cmd, ConnectionInfo, ProtocolVersion, PushKind};
use ::tokio::{
io::{AsyncRead, AsyncWrite},
sync::{mpsc, oneshot},
Expand All @@ -23,6 +25,7 @@ use std::fmt;
use std::fmt::Debug;
use std::io;
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::task::{self, Poll};
use std::time::Duration;
Expand Down Expand Up @@ -73,19 +76,11 @@ struct PipelineMessage<S> {
/// items being output by the `Stream` (the number is specified at time of sending). With the
/// interface provided by `Pipeline` an easy interface of request to response, hiding the `Stream`
/// and `Sink`.
#[derive(Clone)]
struct Pipeline<SinkItem> {
sender: mpsc::Sender<PipelineMessage<SinkItem>>,

push_manager: Arc<ArcSwap<PushManager>>,
}

impl<SinkItem> Clone for Pipeline<SinkItem> {
fn clone(&self) -> Self {
Pipeline {
sender: self.sender.clone(),
push_manager: self.push_manager.clone(),
}
}
is_stream_closed: Arc<AtomicBool>,
}

impl<SinkItem> Debug for Pipeline<SinkItem>
Expand All @@ -104,14 +99,21 @@ pin_project! {
in_flight: VecDeque<InFlight>,
error: Option<RedisError>,
push_manager: Arc<ArcSwap<PushManager>>,
disconnect_notifier: Option<Box<dyn DisconnectNotifier>>,
is_stream_closed: Arc<AtomicBool>,
}
}

impl<T> PipelineSink<T>
where
T: Stream<Item = RedisResult<Value>> + 'static,
{
fn new<SinkItem>(sink_stream: T, push_manager: Arc<ArcSwap<PushManager>>) -> Self
fn new<SinkItem>(
sink_stream: T,
push_manager: Arc<ArcSwap<PushManager>>,
disconnect_notifier: Option<Box<dyn DisconnectNotifier>>,
is_stream_closed: Arc<AtomicBool>,
) -> Self
where
T: Sink<SinkItem, Error = RedisError> + Stream<Item = RedisResult<Value>> + 'static,
{
Expand All @@ -120,6 +122,8 @@ where
in_flight: VecDeque::new(),
error: None,
push_manager,
disconnect_notifier,
is_stream_closed,
}
}

Expand All @@ -130,7 +134,15 @@ where
Some(result) => result,
// The redis response stream is not going to produce any more items so we `Err`
// to break out of the `forward` combinator and stop handling requests
None => return Poll::Ready(Err(())),
None => {
// this is the right place to notify about the passive TCP disconnect
// In other places we cannot distinguish between the active destruction of MultiplexedConnection and passive disconnect
if let Some(disconnect_notifier) = self.as_mut().project().disconnect_notifier {
disconnect_notifier.notify_disconnect();
}
self.is_stream_closed.store(true, Ordering::Relaxed);
return Poll::Ready(Err(()));
}
};
self.as_mut().send_result(item);
}
Expand Down Expand Up @@ -296,7 +308,10 @@ impl<SinkItem> Pipeline<SinkItem>
where
SinkItem: Send + 'static,
{
fn new<T>(sink_stream: T) -> (Self, impl Future<Output = ()>)
fn new<T>(
sink_stream: T,
disconnect_notifier: Option<Box<dyn DisconnectNotifier>>,
) -> (Self, impl Future<Output = ()>)
where
T: Sink<SinkItem, Error = RedisError> + Stream<Item = RedisResult<Value>> + 'static,
T: Send + 'static,
Expand All @@ -308,7 +323,13 @@ where
let (sender, mut receiver) = mpsc::channel(BUFFER_SIZE);
let push_manager: Arc<ArcSwap<PushManager>> =
Arc::new(ArcSwap::new(Arc::new(PushManager::default())));
let sink = PipelineSink::new::<SinkItem>(sink_stream, push_manager.clone());
let is_stream_closed = Arc::new(AtomicBool::new(false));
let sink = PipelineSink::new::<SinkItem>(
sink_stream,
push_manager.clone(),
disconnect_notifier,
is_stream_closed.clone(),
);
let f = stream::poll_fn(move |cx| receiver.poll_recv(cx))
.map(Ok)
.forward(sink)
Expand All @@ -317,6 +338,7 @@ where
Pipeline {
sender,
push_manager,
is_stream_closed,
},
f,
)
Expand Down Expand Up @@ -363,6 +385,10 @@ where
async fn set_push_manager(&mut self, push_manager: PushManager) {
self.push_manager.store(Arc::new(push_manager));
}

pub fn is_closed(&self) -> bool {
self.is_stream_closed.load(Ordering::Relaxed)
}
}

/// A connection object which can be cloned, allowing requests to be be sent concurrently
Expand Down Expand Up @@ -391,7 +417,7 @@ impl MultiplexedConnection {
pub async fn new<C>(
connection_info: &ConnectionInfo,
stream: C,
push_sender: Option<mpsc::UnboundedSender<PushInfo>>,
glide_connection_options: GlideConnectionOptions,
) -> RedisResult<(Self, impl Future<Output = ()>)>
where
C: Unpin + AsyncRead + AsyncWrite + Send + 'static,
Expand All @@ -400,7 +426,7 @@ impl MultiplexedConnection {
connection_info,
stream,
std::time::Duration::MAX,
push_sender,
glide_connection_options,
)
.await
}
Expand All @@ -411,7 +437,7 @@ impl MultiplexedConnection {
connection_info: &ConnectionInfo,
stream: C,
response_timeout: std::time::Duration,
push_sender: Option<mpsc::UnboundedSender<PushInfo>>,
glide_connection_options: GlideConnectionOptions,
) -> RedisResult<(Self, impl Future<Output = ()>)>
where
C: Unpin + AsyncRead + AsyncWrite + Send + 'static,
Expand All @@ -429,10 +455,11 @@ impl MultiplexedConnection {
let codec = ValueCodec::default()
.framed(stream)
.and_then(|msg| async move { msg });
let (mut pipeline, driver) = Pipeline::new(codec);
let (mut pipeline, driver) =
Pipeline::new(codec, glide_connection_options.disconnect_notifier);
let driver = boxed(driver);
let pm = PushManager::default();
if let Some(sender) = push_sender {
if let Some(sender) = glide_connection_options.push_sender {
pm.replace_sender(sender);
}

Expand Down Expand Up @@ -560,6 +587,10 @@ impl ConnectionLike for MultiplexedConnection {
fn get_db(&self) -> i64 {
self.db
}

fn is_closed(&self) -> bool {
self.pipeline.is_closed()
}
}
impl MultiplexedConnection {
/// Subscribes to a new channel.
Expand Down
Loading

0 comments on commit 426bb99

Please sign in to comment.