Skip to content

Commit

Permalink
Added a check for DNS updates
Browse files Browse the repository at this point in the history
  • Loading branch information
barshaul committed Aug 29, 2023
1 parent d7d96a7 commit b9c6b2e
Show file tree
Hide file tree
Showing 9 changed files with 113 additions and 23 deletions.
4 changes: 4 additions & 0 deletions redis-test/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,10 @@ impl AioConnectionLike for MockRedisConnection {
fn get_db(&self) -> i64 {
0
}

fn get_ip(&self) -> Option<String> {
None
}
}

#[cfg(test)]
Expand Down
47 changes: 37 additions & 10 deletions redis/src/aio/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ use ::tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use ::tokio::net::lookup_host;
use combine::{parser::combinator::AnySendSyncPartialState, stream::PointerOffset};
use futures_util::future::select_ok;

use futures_util::{
future::FutureExt,
stream::{Stream, StreamExt},
Expand All @@ -36,6 +37,7 @@ pub struct Connection<C = Pin<Box<dyn AsyncStream + Send + Sync>>> {
// This flag is checked when attempting to send a command, and if it's raised, we attempt to
// exit the pubsub state before executing the new request.
pubsub: bool,
ip: Option<String>,
}

fn assert_sync<T: Sync>() {}
Expand All @@ -53,13 +55,15 @@ impl<C> Connection<C> {
decoder,
db,
pubsub,
ip,
} = self;
Connection {
con: f(con),
buf,
decoder,
db,
pubsub,
ip,
}
}
}
Expand All @@ -70,13 +74,18 @@ where
{
/// Constructs a new `Connection` out of a `AsyncRead + AsyncWrite` object
/// and a `RedisConnectionInfo`
pub async fn new(connection_info: &RedisConnectionInfo, con: C) -> RedisResult<Self> {
pub async fn new(
connection_info: &RedisConnectionInfo,
con: C,
ip: Option<String>,
) -> RedisResult<Self> {
let mut rv = Connection {
con,
buf: Vec::new(),
decoder: combine::stream::Decoder::new(),
db: connection_info.db,
pubsub: false,
ip,
};
authenticate(connection_info, &mut rv).await?;
Ok(rv)
Expand Down Expand Up @@ -172,16 +181,16 @@ where
/// Constructs a new `Connection` out of a `async_std::io::AsyncRead + async_std::io::AsyncWrite` object
/// and a `RedisConnectionInfo`
pub async fn new_async_std(connection_info: &RedisConnectionInfo, con: C) -> RedisResult<Self> {
Connection::new(connection_info, async_std::AsyncStdWrapped::new(con)).await
Connection::new(connection_info, async_std::AsyncStdWrapped::new(con), None).await
}
}

pub(crate) async fn connect<C>(connection_info: &ConnectionInfo) -> RedisResult<Connection<C>>
where
C: Unpin + RedisRuntime + AsyncRead + AsyncWrite + Send,
{
let con = connect_simple::<C>(connection_info).await?;
Connection::new(&connection_info.redis, con).await
let (con, ip) = connect_simple::<C>(connection_info).await?;
Connection::new(&connection_info.redis, con, ip).await
}

impl<C> ConnectionLike for Connection<C>
Expand Down Expand Up @@ -254,6 +263,10 @@ where
fn get_db(&self) -> i64 {
self.db
}

fn get_ip(&self) -> Option<String> {
todo!()
}
}

/// Represents a `PubSub` connection.
Expand Down Expand Up @@ -383,11 +396,20 @@ pub(crate) async fn get_socket_addrs(

pub(crate) async fn connect_simple<T: RedisRuntime>(
connection_info: &ConnectionInfo,
) -> RedisResult<T> {
) -> RedisResult<(T, Option<String>)> {
Ok(match connection_info.addr {
ConnectionAddr::Tcp(ref host, port) => {
let socket_addrs = get_socket_addrs(host, port).await?;
select_ok(socket_addrs.map(<T>::connect_tcp)).await?.0
select_ok(socket_addrs.map(|socket_addr| {
Box::pin(async move {
Ok::<_, RedisError>((
<T>::connect_tcp(socket_addr).await?,
Some(socket_addr.ip().to_string()),
))
})
}))
.await?
.0
}

#[cfg(any(feature = "tls-native-tls", feature = "tls-rustls"))]
Expand All @@ -397,9 +419,14 @@ pub(crate) async fn connect_simple<T: RedisRuntime>(
insecure,
} => {
let socket_addrs = get_socket_addrs(host, port).await?;
select_ok(
socket_addrs.map(|socket_addr| <T>::connect_tcp_tls(host, socket_addr, insecure)),
)
select_ok(socket_addrs.map(|socket_addr| {
Box::pin(async move {
Ok::<_, RedisError>((
<T>::connect_tcp_tls(host, socket_addr, insecure).await?,
Some(socket_addr.ip().to_string()),
))
})
}))
.await?
.0
}
Expand All @@ -413,7 +440,7 @@ pub(crate) async fn connect_simple<T: RedisRuntime>(
}

#[cfg(unix)]
ConnectionAddr::Unix(ref path) => <T>::connect_unix(path).await?,
ConnectionAddr::Unix(ref path) => (<T>::connect_unix(path).await?, None),

#[cfg(not(unix))]
ConnectionAddr::Unix(_) => {
Expand Down
4 changes: 4 additions & 0 deletions redis/src/aio/connection_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -231,4 +231,8 @@ impl ConnectionLike for ConnectionManager {
fn get_db(&self) -> i64 {
self.client.connection_info().redis.db
}

fn get_ip(&self) -> Option<String> {
None
}
}
4 changes: 4 additions & 0 deletions redis/src/aio/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,10 @@ pub trait ConnectionLike {
/// also might be incorrect if the connection like object is not
/// actually connected.
fn get_db(&self) -> i64;

/// Returns the connection's IP if it's a TCP connection and the connection is established,
/// otherwise returns None.
fn get_ip(&self) -> Option<String>;
}

async fn authenticate<C>(connection_info: &RedisConnectionInfo, con: &mut C) -> RedisResult<()>
Expand Down
15 changes: 11 additions & 4 deletions redis/src/aio/multiplexed_connection.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use super::ConnectionLike;
use crate::aio::authenticate;
use crate::cmd::Cmd;
use crate::connection::RedisConnectionInfo;
use crate::connection::ConnectionInfo;
#[cfg(any(feature = "tokio-comp", feature = "async-std-comp"))]
use crate::parser::ValueCodec;
use crate::types::{RedisError, RedisFuture, RedisResult, Value};
Expand Down Expand Up @@ -311,6 +311,7 @@ where
pub struct MultiplexedConnection {
pipeline: Pipeline<Vec<u8>, Value, RedisError>,
db: i64,
ip: Option<String>,
}

impl Debug for MultiplexedConnection {
Expand All @@ -326,8 +327,9 @@ impl MultiplexedConnection {
/// Constructs a new `MultiplexedConnection` out of a `AsyncRead + AsyncWrite` object
/// and a `ConnectionInfo`
pub async fn new<C>(
connection_info: &RedisConnectionInfo,
connection_info: &ConnectionInfo,
stream: C,
ip: Option<String>,
) -> RedisResult<(Self, impl Future<Output = ()>)>
where
C: Unpin + AsyncRead + AsyncWrite + Send + 'static,
Expand All @@ -348,10 +350,11 @@ impl MultiplexedConnection {
let driver = boxed(driver);
let mut con = MultiplexedConnection {
pipeline,
db: connection_info.db,
db: connection_info.redis.db,
ip,
};
let driver = {
let auth = authenticate(connection_info, &mut con);
let auth = authenticate(&connection_info.redis, &mut con);
futures_util::pin_mut!(auth);

match futures_util::future::select(auth, driver).await {
Expand Down Expand Up @@ -419,4 +422,8 @@ impl ConnectionLike for MultiplexedConnection {
fn get_db(&self) -> i64 {
self.db
}

fn get_ip(&self) -> Option<String> {
self.ip.clone()
}
}
18 changes: 10 additions & 8 deletions redis/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ impl Client {
/// Returns an async connection from the client.
#[cfg(any(feature = "tokio-comp", feature = "async-std-comp"))]
pub async fn get_async_connection(&self) -> RedisResult<crate::aio::Connection> {
let con = match Runtime::locate() {
let (con, ip) = match Runtime::locate() {
#[cfg(feature = "tokio-comp")]
Runtime::Tokio => {
self.get_simple_async_connection::<crate::aio::tokio::Tokio>()
Expand All @@ -85,7 +85,7 @@ impl Client {
}
};

crate::aio::Connection::new(&self.connection_info.redis, con).await
crate::aio::Connection::new(&self.connection_info.redis, con, ip).await
}

/// Returns an async connection from the client.
Expand Down Expand Up @@ -268,19 +268,21 @@ impl Client {
where
T: crate::aio::RedisRuntime,
{
let con = self.get_simple_async_connection::<T>().await?;
crate::aio::MultiplexedConnection::new(&self.connection_info.redis, con).await
let (con, ip) = self.get_simple_async_connection::<T>().await?;
crate::aio::MultiplexedConnection::new(&self.connection_info, con, ip).await
}

async fn get_simple_async_connection<T>(
&self,
) -> RedisResult<Pin<Box<dyn crate::aio::AsyncStream + Send + Sync>>>
) -> RedisResult<(
Pin<Box<dyn crate::aio::AsyncStream + Send + Sync>>,
Option<String>,
)>
where
T: crate::aio::RedisRuntime,
{
Ok(crate::aio::connect_simple::<T>(&self.connection_info)
.await?
.boxed())
let (conn, ip) = crate::aio::connect_simple::<T>(&self.connection_info).await?;
Ok((conn.boxed(), ip))
}

#[cfg(feature = "connection-manager")]
Expand Down
33 changes: 32 additions & 1 deletion redis/src/cluster_async/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1110,13 +1110,40 @@ where
}
}

/// This function takes a node's address, examines if its host has encountered a DNS change, where the node's endpoint now leads to a different IP address.
/// If no socket addresses are discovered for the node's host address, or if it's a non-DNS address, it returns false.
/// In case the node's host address resolves to socket addresses and none of them match the current connection's IP, a DNS change is detected, resulting in a true return.
async fn is_dns_changed(addr: &str, curr_ip: &String) -> bool {
let (host, port) = match get_host_and_port_from_addr(addr) {
Some((host, port)) => (host, port),
None => return false,
};
let updated_addresses = match get_socket_addrs(host, port).await {
Ok(socket_addrs) => socket_addrs,
Err(_) => return false,
};
for socket_addr in updated_addresses {
if socket_addr.ip().to_string() == *curr_ip {
return false;
}
}
true
}

async fn get_or_create_conn(
addr: &str,
conn_option: Option<ConnectionFuture<C>>,
params: &ClusterParams,
) -> RedisResult<C> {
if let Some(conn) = conn_option {
let mut conn = conn.await;
if let Some(ip) = conn.get_ip() {
// Check for a DNS change
if Self::is_dns_changed(addr, &ip).await {
// A DNS change is detected, create a new connection
return connect_and_check(addr, params.clone()).await;
}
};
match check_connection(&mut conn, params.connection_timeout.into()).await {
Ok(_) => Ok(conn),
Err(_) => connect_and_check(addr, params.clone()).await,
Expand Down Expand Up @@ -1289,7 +1316,12 @@ where
fn get_db(&self) -> i64 {
0
}

fn get_ip(&self) -> Option<String> {
None
}
}

/// Implements the process of connecting to a Redis server
/// and obtaining a connection handle.
pub trait Connect: Sized {
Expand Down Expand Up @@ -1338,7 +1370,6 @@ async fn check_connection<C>(conn: &mut C, timeout: futures_time::time::Duration
where
C: ConnectionLike + Send + 'static,
{
// TODO: Add a check to re-resolve DNS addresses to verify we that we have a connection to the right node
crate::cmd("PING")
.query_async::<_, String>(conn)
.timeout(timeout)
Expand Down
7 changes: 7 additions & 0 deletions redis/tests/support/mock_cluster.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ static HANDLERS: Lazy<RwLock<HashMap<String, Handler>>> = Lazy::new(Default::def
pub struct MockConnection {
pub handler: Handler,
pub port: u16,
pub ip: String,
}

#[cfg(feature = "cluster-async")]
Expand All @@ -50,6 +51,7 @@ impl cluster_async::Connect for MockConnection {
.unwrap_or_else(|| panic!("Handler `{name}` were not installed"))
.clone(),
port,
ip: name.clone(),
}))
}
}
Expand All @@ -73,6 +75,7 @@ impl cluster::Connect for MockConnection {
.unwrap_or_else(|| panic!("Handler `{name}` were not installed"))
.clone(),
port,
ip: name.clone(),
})
}

Expand Down Expand Up @@ -225,6 +228,10 @@ impl aio::ConnectionLike for MockConnection {
fn get_db(&self) -> i64 {
0
}

fn get_ip(&self) -> Option<String> {
None
}
}

impl redis::ConnectionLike for MockConnection {
Expand Down
4 changes: 4 additions & 0 deletions redis/tests/test_cluster_async.rs
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,10 @@ impl ConnectionLike for ErrorConnection {
fn get_db(&self) -> i64 {
self.inner.get_db()
}

fn get_ip(&self) -> Option<String> {
None
}
}

#[test]
Expand Down

0 comments on commit b9c6b2e

Please sign in to comment.