diff --git a/atrium-api/Cargo.toml b/atrium-api/Cargo.toml index af32f585..360b21ea 100644 --- a/atrium-api/Cargo.toml +++ b/atrium-api/Cargo.toml @@ -10,9 +10,8 @@ repository = "https://github.com/sugyan/atrium" license = "MIT" keywords = ["atproto", "bluesky"] - [dependencies] -atrium-xrpc = "0.5.0" +atrium-xrpc = "0.7.0" async-trait = "0.1.68" cid = { version = "0.10.1", features = ["serde-codec"] } http = "0.2.9" diff --git a/atrium-api/src/agent.rs b/atrium-api/src/agent.rs index 6c2991b2..a7bee0a1 100644 --- a/atrium-api/src/agent.rs +++ b/atrium-api/src/agent.rs @@ -6,8 +6,8 @@ use atrium_xrpc::error::{Error, XrpcErrorKind}; use atrium_xrpc::{HttpClient, OutputDataOrBytes, XrpcClient, XrpcRequest, XrpcResult}; use http::{Method, Request, Response}; use serde::{de::DeserializeOwned, Serialize}; -use std::sync::{Arc, RwLock}; -use tokio::sync::{Mutex, Notify}; +use std::sync::Arc; +use tokio::sync::{Mutex, Notify, RwLock}; /// Type alias for the [com::atproto::server::create_session::Output](crate::com::atproto::server::create_session::Output) pub type Session = crate::com::atproto::server::create_session::Output; @@ -33,6 +33,7 @@ where } } +#[async_trait] impl XrpcClient for SessionAuthWrapper where T: XrpcClient + Send + Sync, @@ -40,15 +41,13 @@ where fn base_uri(&self) -> &str { self.inner.base_uri() } - fn auth(&self, is_refresh: bool) -> Option { - self.session.read().ok().and_then(|lock| { - lock.as_ref().map(|session| { - if is_refresh { - session.refresh_jwt.clone() - } else { - session.access_jwt.clone() - } - }) + async fn auth(&self, is_refresh: bool) -> Option { + self.session.read().await.as_ref().map(|session| { + if is_refresh { + session.refresh_jwt.clone() + } else { + session.access_jwt.clone() + } }) } } @@ -85,10 +84,7 @@ where } async fn refresh_session_inner(&self) { if let Ok(output) = self.call_refresh_session().await { - let mut session = self - .session - .write() - .expect("write lock on session should not be poisoned"); + let mut session = self.session.write().await; let did_doc = session.as_ref().and_then(|s| s.did_doc.clone()); let email = session.as_ref().and_then(|s| s.email.clone()); let email_confirmed = session.as_ref().and_then(|s| s.email_confirmed); @@ -102,10 +98,7 @@ where refresh_jwt: output.refresh_jwt, }); } else { - self.session - .write() - .expect("write lock on session should not be poisoned") - .take(); + self.session.write().await.take(); } } // same as `crate::client::com::atproto::server::Service::refresh_session()` @@ -167,8 +160,8 @@ where fn base_uri(&self) -> &str { self.inner.base_uri() } - fn auth(&self, is_refresh: bool) -> Option { - self.inner.auth(is_refresh) + async fn auth(&self, is_refresh: bool) -> Option { + self.inner.auth(is_refresh).await } async fn send_xrpc(&self, request: &XrpcRequest) -> XrpcResult where @@ -213,11 +206,8 @@ where })); Self { api, session } } - pub fn get_session(&self) -> Option { - self.session - .read() - .expect("read lock on session should not be poisoned") - .clone() + pub async fn get_session(&self) -> Option { + self.session.read().await.clone() } /// Start a new session with this agent. pub async fn login( @@ -235,10 +225,7 @@ where password: password.into(), }) .await?; - self.session - .write() - .expect("write lock on session should not be poisoned") - .replace(result.clone()); + self.session.write().await.replace(result.clone()); Ok(result) } /// Resume a pre-existing session with this agent. @@ -246,20 +233,12 @@ where &self, session: Session, ) -> Result<(), Error> { - self.session - .write() - .expect("write lock on session should not be poisoned") - .replace(session.clone()); + self.session.write().await.replace(session.clone()); let result = self.api.com.atproto.server.get_session().await; match result { Ok(output) => { assert_eq!(output.did, session.did); - if let Some(session) = self - .session - .write() - .expect("write lock on session should not be poisoned") - .as_mut() - { + if let Some(session) = self.session.write().await.as_mut() { session.email = output.email; session.email_confirmed = output.email_confirmed; session.handle = output.handle; @@ -267,10 +246,7 @@ where Ok(()) } Err(err) => { - self.session - .write() - .expect("write lock on session should not be poisoned") - .take(); + self.session.write().await.take(); Err(err) } } @@ -319,12 +295,7 @@ mod tests { } let mut body = Vec::new(); if let Some(nsid) = request.uri().path().strip_prefix("/xrpc/") { - *self - .counts - .write() - .expect("write lock on counts should not be poisoned") - .entry(nsid.into()) - .or_default() += 1; + *self.counts.write().await.entry(nsid.into()).or_default() += 1; match nsid { "com.atproto.server.createSession" => { if let Some(output) = &self.responses.create_session { @@ -387,10 +358,10 @@ mod tests { } } - #[test] - fn test_new() { + #[tokio::test] + async fn test_new() { let agent = AtpAgent::new(DummyClient::default()); - assert_eq!(agent.get_session(), None); + assert_eq!(agent.get_session().await, None); } #[tokio::test] @@ -412,7 +383,7 @@ mod tests { .login("test", "pass") .await .expect("login should be succeeded"); - assert_eq!(agent.get_session(), Some(session)); + assert_eq!(agent.get_session().await, Some(session)); } // failure with `createSession` error { @@ -427,7 +398,7 @@ mod tests { .login("test", "bad") .await .expect_err("login should be failed"); - assert_eq!(agent.get_session(), None); + assert_eq!(agent.get_session().await, None); } } @@ -448,7 +419,7 @@ mod tests { ..Default::default() }; let agent = AtpAgent::new(client); - agent.session.write().unwrap().replace(session); + agent.session.write().await.replace(session); let output = agent .api .com @@ -478,7 +449,7 @@ mod tests { ..Default::default() }; let agent = AtpAgent::new(client); - agent.session.write().unwrap().replace(session); + agent.session.write().await.replace(session); let output = agent .api .com @@ -489,7 +460,7 @@ mod tests { .expect("get session should be succeeded"); assert_eq!(output.did, "did"); assert_eq!( - agent.get_session().map(|session| session.access_jwt), + agent.get_session().await.map(|session| session.access_jwt), Some("access".into()) ); } @@ -513,7 +484,7 @@ mod tests { }; let counts = Arc::clone(&client.counts); let agent = Arc::new(AtpAgent::new(client)); - agent.session.write().unwrap().replace(session); + agent.session.write().await.replace(session); let handles = (0..3).map(|_| { let agent = Arc::clone(&agent); tokio::spawn(async move { agent.api.com.atproto.server.get_session().await }) @@ -528,14 +499,11 @@ mod tests { assert_eq!(output.did, "did"); } assert_eq!( - agent.get_session().map(|session| session.access_jwt), + agent.get_session().await.map(|session| session.access_jwt), Some("access".into()) ); assert_eq!( - counts - .read() - .expect("read lock on counts should not be poisoned") - .clone(), + counts.read().await.clone(), HashMap::from_iter([ ("com.atproto.server.refreshSession".into(), 1), ("com.atproto.server.getSession".into(), 3) @@ -562,7 +530,7 @@ mod tests { ..Default::default() }; let agent = AtpAgent::new(client); - assert_eq!(agent.get_session(), None); + assert_eq!(agent.get_session().await, None); agent .resume_session(Session { email: Some(String::from("test@example.com")), @@ -570,7 +538,7 @@ mod tests { }) .await .expect("resume_session should be succeeded"); - assert_eq!(agent.get_session(), Some(session.clone())); + assert_eq!(agent.get_session().await, Some(session.clone())); } // failure with `getSession` error { @@ -581,12 +549,12 @@ mod tests { ..Default::default() }; let agent = AtpAgent::new(client); - assert_eq!(agent.get_session(), None); + assert_eq!(agent.get_session().await, None); agent .resume_session(session) .await .expect_err("resume_session should be failed"); - assert_eq!(agent.get_session(), None); + assert_eq!(agent.get_session().await, None); } } @@ -614,6 +582,6 @@ mod tests { }) .await .expect("resume_session should be succeeded"); - assert_eq!(agent.get_session(), Some(session)); + assert_eq!(agent.get_session().await, Some(session)); } }