diff --git a/redis/src/cluster_async/mod.rs b/redis/src/cluster_async/mod.rs index 3009c99b3..16ba80e50 100644 --- a/redis/src/cluster_async/mod.rs +++ b/redis/src/cluster_async/mod.rs @@ -28,6 +28,17 @@ mod connections_logic; pub mod testing { pub use super::connections_logic::*; } +use crate::{ + cluster_slotmap::SlotMap, + cluster_topology::SLOT_SIZE, + cmd, + commands::cluster_scan::{cluster_scan, ClusterScanArgs, ObjectType, ScanStateRC}, + FromRedisValue, InfoDict, +}; +#[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))] +use async_std::task::{spawn, JoinHandle}; +#[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))] +use futures::executor::block_on; use std::{ collections::{HashMap, HashSet}, fmt, io, mem, @@ -40,14 +51,8 @@ use std::{ task::{self, Poll}, time::SystemTime, }; - -use crate::{ - cluster_slotmap::SlotMap, - cluster_topology::SLOT_SIZE, - cmd, - commands::cluster_scan::{cluster_scan, ClusterScanArgs, ObjectType, ScanStateRC}, - FromRedisValue, InfoDict, -}; +#[cfg(feature = "tokio-comp")] +use tokio::task::JoinHandle; use crate::{ aio::{get_socket_addrs, ConnectionLike, MultiplexedConnection, Runtime}, @@ -88,7 +93,6 @@ use backoff_tokio::{Error as BackoffError, ExponentialBackoff}; use dispose::{Disposable, Dispose}; use futures::{future::BoxFuture, prelude::*, ready}; use pin_project_lite::pin_project; -use std::sync::atomic::AtomicBool; use tokio::sync::{ mpsc, oneshot::{self, Receiver}, @@ -393,13 +397,18 @@ pub(crate) struct ClusterConnInner { #[allow(clippy::complexity)] in_flight_requests: stream::FuturesUnordered>>>, refresh_error: Option, - // A flag indicating the connection's closure and the requirement to shut down all related tasks. - shutdown_flag: Arc, + // Handler of the periodic check task. + periodic_checks_handler: Option>, } impl Dispose for ClusterConnInner { fn dispose(self) { - self.shutdown_flag.store(true, Ordering::Relaxed); + if let Some(handle) = self.periodic_checks_handler { + #[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))] + block_on(handle.cancel()); + #[cfg(feature = "tokio-comp")] + handle.abort() + } } } @@ -913,13 +922,12 @@ where ), subscriptions_by_address: RwLock::new(Default::default()), }); - let shutdown_flag = Arc::new(AtomicBool::new(false)); - let connection = ClusterConnInner { + let mut connection = ClusterConnInner { inner, in_flight_requests: Default::default(), refresh_error: None, state: ConnectionState::PollComplete, - shutdown_flag: shutdown_flag.clone(), + periodic_checks_handler: None, }; Self::refresh_slots_and_subscriptions_with_retries( connection.inner.clone(), @@ -928,15 +936,16 @@ where .await?; if let Some(duration) = topology_checks_interval { - let periodic_task = ClusterConnInner::periodic_topology_check( - connection.inner.clone(), - duration, - shutdown_flag, - ); + let periodic_task = + ClusterConnInner::periodic_topology_check(connection.inner.clone(), duration); #[cfg(feature = "tokio-comp")] - tokio::spawn(periodic_task); + { + connection.periodic_checks_handler = Some(tokio::spawn(periodic_task)); + } #[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))] - AsyncStd::spawn(periodic_task); + { + connection.periodic_checks_handler = Some(spawn(periodic_task)); + } } Ok(Disposable::new(connection)) @@ -1307,15 +1316,8 @@ where topology_changed } - async fn periodic_topology_check( - inner: Arc>, - interval_duration: Duration, - shutdown_flag: Arc, - ) { + async fn periodic_topology_check(inner: Arc>, interval_duration: Duration) { loop { - if shutdown_flag.load(Ordering::Relaxed) { - return; - } let _ = boxed_sleep(interval_duration).await; let topology_changed = Self::check_topology_and_refresh_if_diff(inner.clone(), &RefreshPolicy::Throttable)