Skip to content

Commit

Permalink
feat(api): Implement refresh_session wrapper (#60)
Browse files Browse the repository at this point in the history
  • Loading branch information
sugyan authored Nov 5, 2023
1 parent 78f7495 commit e4e7265
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 49 deletions.
1 change: 1 addition & 0 deletions atrium-api/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ cid = { version = "0.10.1", features = ["serde-codec"] }
http = "0.2.9"
serde = { version = "1.0.160", features = ["derive"] }
serde_bytes = "0.11.9"
tokio = { version = "1.33.0", features = ["sync"] }

[dev-dependencies]
futures = "0.3.28"
Expand Down
125 changes: 76 additions & 49 deletions atrium-api/src/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use atrium_xrpc::{HttpClient, OutputDataOrBytes, XrpcClient, XrpcRequest, XrpcRe
use http::{Method, Request, Response};
use serde::{de::DeserializeOwned, Serialize};
use std::sync::{Arc, RwLock};
use tokio::sync::{Mutex, Notify};

/// Type alias for the [com::atproto::server::create_session::Output](crate::com::atproto::server::create_session::Output)
pub type Session = crate::com::atproto::server::create_session::Output;
Expand Down Expand Up @@ -58,14 +59,31 @@ where
{
session: Arc<RwLock<Option<Session>>>,
inner: T,
is_refreshing: Arc<Mutex<bool>>,
notify: Arc<Notify>,
}

impl<T> RefreshWrapper<T>
where
T: XrpcClient + Send + Sync,
{
// Internal helper to refresh sessions
// - Wraps the actual implementation to ensure only one refresh is attempted at a time.
async fn refresh_session(&self) {
// TODO: Wraps the actual implementation in a promise-guard to ensure only one refresh is attempted at a time.
{
let mut is_refreshing = self.is_refreshing.lock().await;
if *is_refreshing {
drop(is_refreshing);
return self.notify.notified().await;
}
*is_refreshing = true;
}
// TODO: Ensure `is_refreshing` is reliably set to false even in the event of unexpected errors within `refresh_session_inner()`.
self.refresh_session_inner().await;
*self.is_refreshing.lock().await = false;
self.notify.notify_waiters();
}
async fn refresh_session_inner(&self) {
if let Ok(output) = self.call_refresh_session().await {
let mut session = self
.session
Expand All @@ -83,8 +101,14 @@ where
handle: output.handle,
refresh_jwt: output.refresh_jwt,
});
} else {
self.session
.write()
.expect("write lock on session should not be poisoned")
.take();
}
}
// same as `crate::client::com::atproto::server::Service::refresh_session()`
async fn call_refresh_session(
&self,
) -> Result<
Expand Down Expand Up @@ -154,6 +178,7 @@ where
E: DeserializeOwned + Send + Sync,
{
let result = self.inner.send_xrpc(request).await;
// handle session-refreshes as needed
if Self::is_expired(&result) {
self.refresh_session().await;
self.inner.send_xrpc(request).await
Expand Down Expand Up @@ -183,6 +208,8 @@ where
session: Arc::clone(&session),
inner: xrpc,
},
is_refreshing: Arc::new(Mutex::new(false)),
notify: Arc::new(Notify::new()),
}));
Self { api, session }
}
Expand Down Expand Up @@ -253,6 +280,7 @@ where
#[cfg(test)]
mod tests {
use super::*;
use futures::future::join_all;
use std::collections::HashMap;

#[derive(Default)]
Expand Down Expand Up @@ -464,54 +492,53 @@ mod tests {
);
}

// TODO: fix this test
// #[tokio::test]
// async fn test_xrpc_get_session_with_duplicated_refresh() {
// let mut session = session();
// session.access_jwt = String::from("expired");
// let client = DummyClient {
// responses: DummyResponses {
// get_session: Some(crate::com::atproto::server::get_session::Output {
// did: session.did.clone(),
// email: session.email.clone(),
// email_confirmed: session.email_confirmed,
// handle: session.handle.clone(),
// }),
// ..Default::default()
// },
// ..Default::default()
// };
// let counts = Arc::clone(&client.counts);
// let agent = Arc::new(AtpAgent::new(client));
// agent.session.write().unwrap().replace(session);
// let handles = (0..3).map(|_| {
// let agent = Arc::clone(&agent);
// tokio::spawn(async move { agent.api.com.atproto.server.get_session().await })
// });
// let results = join_all(handles).await;
// for result in &results {
// let output = result
// .as_ref()
// .expect("task should be successfully executed")
// .as_ref()
// .expect("get session should be succeeded");
// assert_eq!(output.did, "did");
// }
// assert_eq!(
// agent.get_session().map(|session| session.access_jwt),
// Some("access".into())
// );
// assert_eq!(
// counts
// .read()
// .expect("read lock on counts should not be poisoned")
// .clone(),
// HashMap::from_iter([
// ("com.atproto.server.refreshSession".into(), 1),
// ("com.atproto.server.getSession".into(), 3)
// ])
// );
// }
#[tokio::test]
async fn test_xrpc_get_session_with_duplicated_refresh() {
let mut session = session();
session.access_jwt = String::from("expired");
let client = DummyClient {
responses: DummyResponses {
get_session: Some(crate::com::atproto::server::get_session::Output {
did: session.did.clone(),
email: session.email.clone(),
email_confirmed: session.email_confirmed,
handle: session.handle.clone(),
}),
..Default::default()
},
..Default::default()
};
let counts = Arc::clone(&client.counts);
let agent = Arc::new(AtpAgent::new(client));
agent.session.write().unwrap().replace(session);
let handles = (0..3).map(|_| {
let agent = Arc::clone(&agent);
tokio::spawn(async move { agent.api.com.atproto.server.get_session().await })
});
let results = join_all(handles).await;
for result in &results {
let output = result
.as_ref()
.expect("task should be successfully executed")
.as_ref()
.expect("get session should be succeeded");
assert_eq!(output.did, "did");
}
assert_eq!(
agent.get_session().map(|session| session.access_jwt),
Some("access".into())
);
assert_eq!(
counts
.read()
.expect("read lock on counts should not be poisoned")
.clone(),
HashMap::from_iter([
("com.atproto.server.refreshSession".into(), 1),
("com.atproto.server.getSession".into(), 3)
])
);
}

#[tokio::test]
async fn test_resume_session() {
Expand Down

0 comments on commit e4e7265

Please sign in to comment.