-
Notifications
You must be signed in to change notification settings - Fork 30
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Update xprc, use tokio::sync::RwLock for agent (#76)
- Loading branch information
Showing
2 changed files
with
38 additions
and
71 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,8 +6,8 @@ use atrium_xrpc::error::{Error, XrpcErrorKind}; | |
use atrium_xrpc::{HttpClient, OutputDataOrBytes, XrpcClient, XrpcRequest, XrpcResult}; | ||
use http::{Method, Request, Response}; | ||
use serde::{de::DeserializeOwned, Serialize}; | ||
use std::sync::{Arc, RwLock}; | ||
use tokio::sync::{Mutex, Notify}; | ||
use std::sync::Arc; | ||
use tokio::sync::{Mutex, Notify, RwLock}; | ||
|
||
/// Type alias for the [com::atproto::server::create_session::Output](crate::com::atproto::server::create_session::Output) | ||
pub type Session = crate::com::atproto::server::create_session::Output; | ||
|
@@ -33,22 +33,21 @@ where | |
} | ||
} | ||
|
||
#[async_trait] | ||
impl<T> XrpcClient for SessionAuthWrapper<T> | ||
where | ||
T: XrpcClient + Send + Sync, | ||
{ | ||
fn base_uri(&self) -> &str { | ||
self.inner.base_uri() | ||
} | ||
fn auth(&self, is_refresh: bool) -> Option<String> { | ||
self.session.read().ok().and_then(|lock| { | ||
lock.as_ref().map(|session| { | ||
if is_refresh { | ||
session.refresh_jwt.clone() | ||
} else { | ||
session.access_jwt.clone() | ||
} | ||
}) | ||
async fn auth(&self, is_refresh: bool) -> Option<String> { | ||
self.session.read().await.as_ref().map(|session| { | ||
if is_refresh { | ||
session.refresh_jwt.clone() | ||
} else { | ||
session.access_jwt.clone() | ||
} | ||
}) | ||
} | ||
} | ||
|
@@ -85,10 +84,7 @@ where | |
} | ||
async fn refresh_session_inner(&self) { | ||
if let Ok(output) = self.call_refresh_session().await { | ||
let mut session = self | ||
.session | ||
.write() | ||
.expect("write lock on session should not be poisoned"); | ||
let mut session = self.session.write().await; | ||
let did_doc = session.as_ref().and_then(|s| s.did_doc.clone()); | ||
let email = session.as_ref().and_then(|s| s.email.clone()); | ||
let email_confirmed = session.as_ref().and_then(|s| s.email_confirmed); | ||
|
@@ -102,10 +98,7 @@ where | |
refresh_jwt: output.refresh_jwt, | ||
}); | ||
} else { | ||
self.session | ||
.write() | ||
.expect("write lock on session should not be poisoned") | ||
.take(); | ||
self.session.write().await.take(); | ||
} | ||
} | ||
// same as `crate::client::com::atproto::server::Service::refresh_session()` | ||
|
@@ -167,8 +160,8 @@ where | |
fn base_uri(&self) -> &str { | ||
self.inner.base_uri() | ||
} | ||
fn auth(&self, is_refresh: bool) -> Option<String> { | ||
self.inner.auth(is_refresh) | ||
async fn auth(&self, is_refresh: bool) -> Option<String> { | ||
self.inner.auth(is_refresh).await | ||
} | ||
async fn send_xrpc<P, I, O, E>(&self, request: &XrpcRequest<P, I>) -> XrpcResult<O, E> | ||
where | ||
|
@@ -213,11 +206,8 @@ where | |
})); | ||
Self { api, session } | ||
} | ||
pub fn get_session(&self) -> Option<Session> { | ||
self.session | ||
.read() | ||
.expect("read lock on session should not be poisoned") | ||
.clone() | ||
pub async fn get_session(&self) -> Option<Session> { | ||
self.session.read().await.clone() | ||
} | ||
/// Start a new session with this agent. | ||
pub async fn login( | ||
|
@@ -235,42 +225,28 @@ where | |
password: password.into(), | ||
}) | ||
.await?; | ||
self.session | ||
.write() | ||
.expect("write lock on session should not be poisoned") | ||
.replace(result.clone()); | ||
self.session.write().await.replace(result.clone()); | ||
Ok(result) | ||
} | ||
/// Resume a pre-existing session with this agent. | ||
pub async fn resume_session( | ||
&self, | ||
session: Session, | ||
) -> Result<(), Error<crate::com::atproto::server::get_session::Error>> { | ||
self.session | ||
.write() | ||
.expect("write lock on session should not be poisoned") | ||
.replace(session.clone()); | ||
self.session.write().await.replace(session.clone()); | ||
let result = self.api.com.atproto.server.get_session().await; | ||
match result { | ||
Ok(output) => { | ||
assert_eq!(output.did, session.did); | ||
if let Some(session) = self | ||
.session | ||
.write() | ||
.expect("write lock on session should not be poisoned") | ||
.as_mut() | ||
{ | ||
if let Some(session) = self.session.write().await.as_mut() { | ||
session.email = output.email; | ||
session.email_confirmed = output.email_confirmed; | ||
session.handle = output.handle; | ||
} | ||
Ok(()) | ||
} | ||
Err(err) => { | ||
self.session | ||
.write() | ||
.expect("write lock on session should not be poisoned") | ||
.take(); | ||
self.session.write().await.take(); | ||
Err(err) | ||
} | ||
} | ||
|
@@ -319,12 +295,7 @@ mod tests { | |
} | ||
let mut body = Vec::new(); | ||
if let Some(nsid) = request.uri().path().strip_prefix("/xrpc/") { | ||
*self | ||
.counts | ||
.write() | ||
.expect("write lock on counts should not be poisoned") | ||
.entry(nsid.into()) | ||
.or_default() += 1; | ||
*self.counts.write().await.entry(nsid.into()).or_default() += 1; | ||
match nsid { | ||
"com.atproto.server.createSession" => { | ||
if let Some(output) = &self.responses.create_session { | ||
|
@@ -387,10 +358,10 @@ mod tests { | |
} | ||
} | ||
|
||
#[test] | ||
fn test_new() { | ||
#[tokio::test] | ||
async fn test_new() { | ||
let agent = AtpAgent::new(DummyClient::default()); | ||
assert_eq!(agent.get_session(), None); | ||
assert_eq!(agent.get_session().await, None); | ||
} | ||
|
||
#[tokio::test] | ||
|
@@ -412,7 +383,7 @@ mod tests { | |
.login("test", "pass") | ||
.await | ||
.expect("login should be succeeded"); | ||
assert_eq!(agent.get_session(), Some(session)); | ||
assert_eq!(agent.get_session().await, Some(session)); | ||
} | ||
// failure with `createSession` error | ||
{ | ||
|
@@ -427,7 +398,7 @@ mod tests { | |
.login("test", "bad") | ||
.await | ||
.expect_err("login should be failed"); | ||
assert_eq!(agent.get_session(), None); | ||
assert_eq!(agent.get_session().await, None); | ||
} | ||
} | ||
|
||
|
@@ -448,7 +419,7 @@ mod tests { | |
..Default::default() | ||
}; | ||
let agent = AtpAgent::new(client); | ||
agent.session.write().unwrap().replace(session); | ||
agent.session.write().await.replace(session); | ||
let output = agent | ||
.api | ||
.com | ||
|
@@ -478,7 +449,7 @@ mod tests { | |
..Default::default() | ||
}; | ||
let agent = AtpAgent::new(client); | ||
agent.session.write().unwrap().replace(session); | ||
agent.session.write().await.replace(session); | ||
let output = agent | ||
.api | ||
.com | ||
|
@@ -489,7 +460,7 @@ mod tests { | |
.expect("get session should be succeeded"); | ||
assert_eq!(output.did, "did"); | ||
assert_eq!( | ||
agent.get_session().map(|session| session.access_jwt), | ||
agent.get_session().await.map(|session| session.access_jwt), | ||
Some("access".into()) | ||
); | ||
} | ||
|
@@ -513,7 +484,7 @@ mod tests { | |
}; | ||
let counts = Arc::clone(&client.counts); | ||
let agent = Arc::new(AtpAgent::new(client)); | ||
agent.session.write().unwrap().replace(session); | ||
agent.session.write().await.replace(session); | ||
let handles = (0..3).map(|_| { | ||
let agent = Arc::clone(&agent); | ||
tokio::spawn(async move { agent.api.com.atproto.server.get_session().await }) | ||
|
@@ -528,14 +499,11 @@ mod tests { | |
assert_eq!(output.did, "did"); | ||
} | ||
assert_eq!( | ||
agent.get_session().map(|session| session.access_jwt), | ||
agent.get_session().await.map(|session| session.access_jwt), | ||
Some("access".into()) | ||
); | ||
assert_eq!( | ||
counts | ||
.read() | ||
.expect("read lock on counts should not be poisoned") | ||
.clone(), | ||
counts.read().await.clone(), | ||
HashMap::from_iter([ | ||
("com.atproto.server.refreshSession".into(), 1), | ||
("com.atproto.server.getSession".into(), 3) | ||
|
@@ -562,15 +530,15 @@ mod tests { | |
..Default::default() | ||
}; | ||
let agent = AtpAgent::new(client); | ||
assert_eq!(agent.get_session(), None); | ||
assert_eq!(agent.get_session().await, None); | ||
agent | ||
.resume_session(Session { | ||
email: Some(String::from("[email protected]")), | ||
..session.clone() | ||
}) | ||
.await | ||
.expect("resume_session should be succeeded"); | ||
assert_eq!(agent.get_session(), Some(session.clone())); | ||
assert_eq!(agent.get_session().await, Some(session.clone())); | ||
} | ||
// failure with `getSession` error | ||
{ | ||
|
@@ -581,12 +549,12 @@ mod tests { | |
..Default::default() | ||
}; | ||
let agent = AtpAgent::new(client); | ||
assert_eq!(agent.get_session(), None); | ||
assert_eq!(agent.get_session().await, None); | ||
agent | ||
.resume_session(session) | ||
.await | ||
.expect_err("resume_session should be failed"); | ||
assert_eq!(agent.get_session(), None); | ||
assert_eq!(agent.get_session().await, None); | ||
} | ||
} | ||
|
||
|
@@ -614,6 +582,6 @@ mod tests { | |
}) | ||
.await | ||
.expect("resume_session should be succeeded"); | ||
assert_eq!(agent.get_session(), Some(session)); | ||
assert_eq!(agent.get_session().await, Some(session)); | ||
} | ||
} |