diff --git a/async-nats/Cargo.toml b/async-nats/Cargo.toml index b203526a2..1fde5eaf1 100644 --- a/async-nats/Cargo.toml +++ b/async-nats/Cargo.toml @@ -40,6 +40,7 @@ lazy_static = "1.4" base64 = "0.13" tokio-retry = "0.3" ring = "0.16" +rand = "0.8" [dev-dependencies] criterion = { version = "0.3", features = ["async_tokio"]} diff --git a/async-nats/src/connector.rs b/async-nats/src/connector.rs index 775cf8df3..7f178292a 100644 --- a/async-nats/src/connector.rs +++ b/async-nats/src/connector.rs @@ -28,6 +28,8 @@ use crate::ToServerAddrs; use crate::LANG; use crate::VERSION; use bytes::BytesMut; +use rand::seq::SliceRandom; +use rand::thread_rng; use std::cmp; use std::collections::HashMap; use std::io; @@ -99,8 +101,13 @@ impl Connector { pub(crate) async fn try_connect(&mut self) -> Result<(ServerInfo, Connection), io::Error> { let mut error = None; + let server_addrs = { + let mut rng = thread_rng(); + let mut server_addrs: Vec = self.servers.clone().into_keys().collect(); + server_addrs.shuffle(&mut rng); + server_addrs + }; - let server_addrs: Vec = self.servers.keys().cloned().collect(); for server_addr in server_addrs { let server_attempts = self.servers.get_mut(&server_addr).unwrap(); let duration = if *server_attempts == 0 { @@ -111,7 +118,6 @@ impl Connector { cmp::min(Duration::from_millis(2_u64.saturating_pow(exp)), max) }; - *server_attempts += 1; sleep(duration).await; diff --git a/async-nats/tests/jwt_tests.rs b/async-nats/tests/jwt_tests.rs index a85bdc465..18463e87c 100644 --- a/async-nats/tests/jwt_tests.rs +++ b/async-nats/tests/jwt_tests.rs @@ -12,6 +12,7 @@ // limitations under the License. mod client { + use async_nats::Event; use futures::stream::StreamExt; use std::path::PathBuf; @@ -34,6 +35,47 @@ mod client { .await .expect("published"); } + + #[cfg(not(target_os = "windows"))] + #[tokio::test] + async fn jwt_lame_duck_reconnect() { + let path = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + let c = nats_server::run_cluster("tests/configs/jwt.conf"); + + let (tx_recconect, mut rx_reconnect) = tokio::sync::mpsc::channel(10); + + let client = async_nats::ConnectOptions::with_credentials_file( + path.join("tests/configs/TestUser.creds"), + ) + .await + .unwrap() + .event_callback({ + let tx = tx_recconect.clone(); + move |event| { + let tx = tx.clone(); + async move { + if event == Event::Connected { + tx.send(()).await.unwrap(); + } + } + } + }) + .connect(c.client_url()) + .await + .unwrap(); + + let mut subscriber = client.subscribe("test".into()).await.unwrap(); + for i in 0..2 { + rx_reconnect.recv().await; + let mut subscribe = client.subscribe("test".into()).await.unwrap(); + client.publish("test".into(), "data".into()).await.unwrap(); + subscribe.next().await.unwrap(); + client.flush().await.unwrap(); + assert!(subscriber.next().await.is_some()); + nats_server::set_lame_duck_mode(&c.servers[i]); + } + } + #[tokio::test] async fn jwt_reconnect() { use async_nats::ServerAddr;