diff --git a/atrium-api/Cargo.toml b/atrium-api/Cargo.toml index eff18b5b..03147540 100644 --- a/atrium-api/Cargo.toml +++ b/atrium-api/Cargo.toml @@ -21,3 +21,4 @@ serde_bytes = "0.11.9" [dev-dependencies] tokio = { version = "1.28.0", features = ["macros", "rt-multi-thread"] } +serde_json = "1.0.96" diff --git a/atrium-api/src/agent.rs b/atrium-api/src/agent.rs index 76e344bf..63cdcf54 100644 --- a/atrium-api/src/agent.rs +++ b/atrium-api/src/agent.rs @@ -1,13 +1,12 @@ //! An ATP "Agent". //! Manages session token lifecycles and provides all XRPC methods. +use crate::client::AtpServiceClient; use async_trait::async_trait; +use atrium_xrpc::error::Error; use atrium_xrpc::{HttpClient, XrpcClient}; use http::{Request, Response}; -use std::error::Error; use std::sync::{Arc, RwLock}; -use crate::client::AtpServiceClient; - /// Type alias for the [com::atproto::server::create_session::Output](crate::com::atproto::server::create_session::Output) pub type Session = crate::com::atproto::server::create_session::Output; @@ -27,7 +26,7 @@ where async fn send( &self, req: Request>, - ) -> Result>, Box> { + ) -> Result>, Box> { HttpClient::send(&self.xrpc, req).await } } @@ -74,9 +73,168 @@ where let api = AtpServiceClient::new(Arc::new(base)); Self { api, session } } - pub fn set_session(&mut self, session: Session) { - if let Ok(mut lock) = self.session.write() { - *lock = Some(session); + pub fn get_session(&self) -> Option { + self.session.read().expect("read lock").clone() + } + pub async fn login( + &self, + identifier: &str, + password: &str, + ) -> Result> { + let result = self + .api + .com + .atproto + .server + .create_session(crate::com::atproto::server::create_session::Input { + identifier: identifier.into(), + password: password.into(), + }) + .await?; + self.session + .write() + .expect("write lock") + .replace(result.clone()); + Ok(result) + } + pub async fn resume_session( + &self, + session: Session, + ) -> Result<(), Error> { + self.session + .write() + .expect("write lock") + .replace(session.clone()); + match self.api.com.atproto.server.get_session().await { + Ok(result) => { + assert_eq!(result.did, session.did); + self.session.write().expect("write lock").replace(Session { + email: result.email, + handle: result.handle, + ..session + }); + Ok(()) + } + Err(err) => { + self.session.write().expect("write lock").take(); + Err(err) + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + struct DummyClient { + session: Option, + } + + #[async_trait] + impl HttpClient for DummyClient { + async fn send( + &self, + _req: Request>, + ) -> Result>, Box> { + let builder = + Response::builder().header(http::header::CONTENT_TYPE, "application/json"); + if let Some(session) = &self.session { + Ok(builder + .status(http::StatusCode::OK) + .body(serde_json::to_vec(&session)?)?) + } else { + Ok(builder + .status(http::StatusCode::UNAUTHORIZED) + .body(serde_json::to_vec( + &atrium_xrpc::error::ErrorResponseBody { + error: Some(String::from("AuthenticationRequired")), + message: Some(String::from("Invalid identifier or password")), + }, + )?)?) + } + } + } + + impl XrpcClient for DummyClient { + fn host(&self) -> &str { + "http://localhost:8080" + } + } + + #[test] + fn new_agent() { + let agent = AtpAgent::new(atrium_xrpc::client::reqwest::ReqwestClient::new( + "http://localhost:8080".into(), + )); + assert_eq!(agent.get_session(), None); + } + + #[tokio::test] + async fn login() { + let session = Session { + access_jwt: "access".into(), + did: "did".into(), + email: None, + handle: "handle".into(), + refresh_jwt: "refresh".into(), + }; + // success + { + let client = DummyClient { + session: Some(session.clone()), + }; + let agent = AtpAgent::new(client); + agent.login("test", "pass").await.expect("failed to login"); + assert_eq!(agent.get_session(), Some(session)); + } + // failure with `createSession` error + { + let client = DummyClient { session: None }; + let agent = AtpAgent::new(client); + agent + .login("test", "bad") + .await + .expect_err("should failed to login"); + assert_eq!(agent.get_session(), None); + } + } + + #[tokio::test] + async fn resume_session() { + let session = Session { + access_jwt: "access".into(), + did: "did".into(), + email: None, + handle: "handle".into(), + refresh_jwt: "refresh".into(), + }; + // success + { + let client = DummyClient { + session: Some(session.clone()), + }; + let agent = AtpAgent::new(client); + assert_eq!(agent.get_session(), None); + agent + .resume_session(Session { + email: Some(String::from("test@example.com")), + ..session.clone() + }) + .await + .expect("failed to resume session"); + assert_eq!(agent.get_session(), Some(session.clone())); + } + // failure with `getSession` error + { + let client = DummyClient { session: None }; + let agent = AtpAgent::new(client); + assert_eq!(agent.get_session(), None); + agent + .resume_session(session) + .await + .expect_err("should failed to resume session"); + assert_eq!(agent.get_session(), None); } } }