diff --git a/atrium-api/Cargo.toml b/atrium-api/Cargo.toml index 5423668d..d08a7720 100644 --- a/atrium-api/Cargo.toml +++ b/atrium-api/Cargo.toml @@ -18,6 +18,7 @@ cid = { version = "0.10.1", features = ["serde-codec"] } http = "0.2.9" serde = { version = "1.0.160", features = ["derive"] } serde_bytes = "0.11.9" +tokio = { version = "1.33.0", features = ["sync"] } [dev-dependencies] futures = "0.3.28" diff --git a/atrium-api/src/agent.rs b/atrium-api/src/agent.rs index 1956bde4..b2936809 100644 --- a/atrium-api/src/agent.rs +++ b/atrium-api/src/agent.rs @@ -7,6 +7,7 @@ use atrium_xrpc::{HttpClient, OutputDataOrBytes, XrpcClient, XrpcRequest, XrpcRe use http::{Method, Request, Response}; use serde::{de::DeserializeOwned, Serialize}; use std::sync::{Arc, RwLock}; +use tokio::sync::{Mutex, Notify}; /// Type alias for the [com::atproto::server::create_session::Output](crate::com::atproto::server::create_session::Output) pub type Session = crate::com::atproto::server::create_session::Output; @@ -58,14 +59,31 @@ where { session: Arc>>, inner: T, + is_refreshing: Arc>, + notify: Arc, } impl RefreshWrapper where T: XrpcClient + Send + Sync, { + // Internal helper to refresh sessions + // - Wraps the actual implementation to ensure only one refresh is attempted at a time. async fn refresh_session(&self) { - // TODO: Wraps the actual implementation in a promise-guard to ensure only one refresh is attempted at a time. + { + let mut is_refreshing = self.is_refreshing.lock().await; + if *is_refreshing { + drop(is_refreshing); + return self.notify.notified().await; + } + *is_refreshing = true; + } + // TODO: Ensure `is_refreshing` is reliably set to false even in the event of unexpected errors within `refresh_session_inner()`. + self.refresh_session_inner().await; + *self.is_refreshing.lock().await = false; + self.notify.notify_waiters(); + } + async fn refresh_session_inner(&self) { if let Ok(output) = self.call_refresh_session().await { let mut session = self .session @@ -83,8 +101,14 @@ where handle: output.handle, refresh_jwt: output.refresh_jwt, }); + } else { + self.session + .write() + .expect("write lock on session should not be poisoned") + .take(); } } + // same as `crate::client::com::atproto::server::Service::refresh_session()` async fn call_refresh_session( &self, ) -> Result< @@ -154,6 +178,7 @@ where E: DeserializeOwned + Send + Sync, { let result = self.inner.send_xrpc(request).await; + // handle session-refreshes as needed if Self::is_expired(&result) { self.refresh_session().await; self.inner.send_xrpc(request).await @@ -183,6 +208,8 @@ where session: Arc::clone(&session), inner: xrpc, }, + is_refreshing: Arc::new(Mutex::new(false)), + notify: Arc::new(Notify::new()), })); Self { api, session } } @@ -253,6 +280,7 @@ where #[cfg(test)] mod tests { use super::*; + use futures::future::join_all; use std::collections::HashMap; #[derive(Default)] @@ -464,54 +492,53 @@ mod tests { ); } - // TODO: fix this test - // #[tokio::test] - // async fn test_xrpc_get_session_with_duplicated_refresh() { - // let mut session = session(); - // session.access_jwt = String::from("expired"); - // let client = DummyClient { - // responses: DummyResponses { - // get_session: Some(crate::com::atproto::server::get_session::Output { - // did: session.did.clone(), - // email: session.email.clone(), - // email_confirmed: session.email_confirmed, - // handle: session.handle.clone(), - // }), - // ..Default::default() - // }, - // ..Default::default() - // }; - // let counts = Arc::clone(&client.counts); - // let agent = Arc::new(AtpAgent::new(client)); - // agent.session.write().unwrap().replace(session); - // let handles = (0..3).map(|_| { - // let agent = Arc::clone(&agent); - // tokio::spawn(async move { agent.api.com.atproto.server.get_session().await }) - // }); - // let results = join_all(handles).await; - // for result in &results { - // let output = result - // .as_ref() - // .expect("task should be successfully executed") - // .as_ref() - // .expect("get session should be succeeded"); - // assert_eq!(output.did, "did"); - // } - // assert_eq!( - // agent.get_session().map(|session| session.access_jwt), - // Some("access".into()) - // ); - // assert_eq!( - // counts - // .read() - // .expect("read lock on counts should not be poisoned") - // .clone(), - // HashMap::from_iter([ - // ("com.atproto.server.refreshSession".into(), 1), - // ("com.atproto.server.getSession".into(), 3) - // ]) - // ); - // } + #[tokio::test] + async fn test_xrpc_get_session_with_duplicated_refresh() { + let mut session = session(); + session.access_jwt = String::from("expired"); + let client = DummyClient { + responses: DummyResponses { + get_session: Some(crate::com::atproto::server::get_session::Output { + did: session.did.clone(), + email: session.email.clone(), + email_confirmed: session.email_confirmed, + handle: session.handle.clone(), + }), + ..Default::default() + }, + ..Default::default() + }; + let counts = Arc::clone(&client.counts); + let agent = Arc::new(AtpAgent::new(client)); + agent.session.write().unwrap().replace(session); + let handles = (0..3).map(|_| { + let agent = Arc::clone(&agent); + tokio::spawn(async move { agent.api.com.atproto.server.get_session().await }) + }); + let results = join_all(handles).await; + for result in &results { + let output = result + .as_ref() + .expect("task should be successfully executed") + .as_ref() + .expect("get session should be succeeded"); + assert_eq!(output.did, "did"); + } + assert_eq!( + agent.get_session().map(|session| session.access_jwt), + Some("access".into()) + ); + assert_eq!( + counts + .read() + .expect("read lock on counts should not be poisoned") + .clone(), + HashMap::from_iter([ + ("com.atproto.server.refreshSession".into(), 1), + ("com.atproto.server.getSession".into(), 3) + ]) + ); + } #[tokio::test] async fn test_resume_session() {