Skip to content

Commit

Permalink
feat: Update xprc, use tokio::sync::RwLock for agent (#76)
Browse files Browse the repository at this point in the history
  • Loading branch information
sugyan authored Nov 13, 2023
1 parent bb15656 commit 3218eb0
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 71 deletions.
3 changes: 1 addition & 2 deletions atrium-api/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,8 @@ repository = "https://github.com/sugyan/atrium"
license = "MIT"
keywords = ["atproto", "bluesky"]


[dependencies]
atrium-xrpc = "0.5.0"
atrium-xrpc = "0.7.0"
async-trait = "0.1.68"
cid = { version = "0.10.1", features = ["serde-codec"] }
http = "0.2.9"
Expand Down
106 changes: 37 additions & 69 deletions atrium-api/src/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ use atrium_xrpc::error::{Error, XrpcErrorKind};
use atrium_xrpc::{HttpClient, OutputDataOrBytes, XrpcClient, XrpcRequest, XrpcResult};
use http::{Method, Request, Response};
use serde::{de::DeserializeOwned, Serialize};
use std::sync::{Arc, RwLock};
use tokio::sync::{Mutex, Notify};
use std::sync::Arc;
use tokio::sync::{Mutex, Notify, RwLock};

/// Type alias for the [com::atproto::server::create_session::Output](crate::com::atproto::server::create_session::Output)
pub type Session = crate::com::atproto::server::create_session::Output;
Expand All @@ -33,22 +33,21 @@ where
}
}

#[async_trait]
impl<T> XrpcClient for SessionAuthWrapper<T>
where
T: XrpcClient + Send + Sync,
{
fn base_uri(&self) -> &str {
self.inner.base_uri()
}
fn auth(&self, is_refresh: bool) -> Option<String> {
self.session.read().ok().and_then(|lock| {
lock.as_ref().map(|session| {
if is_refresh {
session.refresh_jwt.clone()
} else {
session.access_jwt.clone()
}
})
async fn auth(&self, is_refresh: bool) -> Option<String> {
self.session.read().await.as_ref().map(|session| {
if is_refresh {
session.refresh_jwt.clone()
} else {
session.access_jwt.clone()
}
})
}
}
Expand Down Expand Up @@ -85,10 +84,7 @@ where
}
async fn refresh_session_inner(&self) {
if let Ok(output) = self.call_refresh_session().await {
let mut session = self
.session
.write()
.expect("write lock on session should not be poisoned");
let mut session = self.session.write().await;
let did_doc = session.as_ref().and_then(|s| s.did_doc.clone());
let email = session.as_ref().and_then(|s| s.email.clone());
let email_confirmed = session.as_ref().and_then(|s| s.email_confirmed);
Expand All @@ -102,10 +98,7 @@ where
refresh_jwt: output.refresh_jwt,
});
} else {
self.session
.write()
.expect("write lock on session should not be poisoned")
.take();
self.session.write().await.take();
}
}
// same as `crate::client::com::atproto::server::Service::refresh_session()`
Expand Down Expand Up @@ -167,8 +160,8 @@ where
fn base_uri(&self) -> &str {
self.inner.base_uri()
}
fn auth(&self, is_refresh: bool) -> Option<String> {
self.inner.auth(is_refresh)
async fn auth(&self, is_refresh: bool) -> Option<String> {
self.inner.auth(is_refresh).await
}
async fn send_xrpc<P, I, O, E>(&self, request: &XrpcRequest<P, I>) -> XrpcResult<O, E>
where
Expand Down Expand Up @@ -213,11 +206,8 @@ where
}));
Self { api, session }
}
pub fn get_session(&self) -> Option<Session> {
self.session
.read()
.expect("read lock on session should not be poisoned")
.clone()
pub async fn get_session(&self) -> Option<Session> {
self.session.read().await.clone()
}
/// Start a new session with this agent.
pub async fn login(
Expand All @@ -235,42 +225,28 @@ where
password: password.into(),
})
.await?;
self.session
.write()
.expect("write lock on session should not be poisoned")
.replace(result.clone());
self.session.write().await.replace(result.clone());
Ok(result)
}
/// Resume a pre-existing session with this agent.
pub async fn resume_session(
&self,
session: Session,
) -> Result<(), Error<crate::com::atproto::server::get_session::Error>> {
self.session
.write()
.expect("write lock on session should not be poisoned")
.replace(session.clone());
self.session.write().await.replace(session.clone());
let result = self.api.com.atproto.server.get_session().await;
match result {
Ok(output) => {
assert_eq!(output.did, session.did);
if let Some(session) = self
.session
.write()
.expect("write lock on session should not be poisoned")
.as_mut()
{
if let Some(session) = self.session.write().await.as_mut() {
session.email = output.email;
session.email_confirmed = output.email_confirmed;
session.handle = output.handle;
}
Ok(())
}
Err(err) => {
self.session
.write()
.expect("write lock on session should not be poisoned")
.take();
self.session.write().await.take();
Err(err)
}
}
Expand Down Expand Up @@ -319,12 +295,7 @@ mod tests {
}
let mut body = Vec::new();
if let Some(nsid) = request.uri().path().strip_prefix("/xrpc/") {
*self
.counts
.write()
.expect("write lock on counts should not be poisoned")
.entry(nsid.into())
.or_default() += 1;
*self.counts.write().await.entry(nsid.into()).or_default() += 1;
match nsid {
"com.atproto.server.createSession" => {
if let Some(output) = &self.responses.create_session {
Expand Down Expand Up @@ -387,10 +358,10 @@ mod tests {
}
}

#[test]
fn test_new() {
#[tokio::test]
async fn test_new() {
let agent = AtpAgent::new(DummyClient::default());
assert_eq!(agent.get_session(), None);
assert_eq!(agent.get_session().await, None);
}

#[tokio::test]
Expand All @@ -412,7 +383,7 @@ mod tests {
.login("test", "pass")
.await
.expect("login should be succeeded");
assert_eq!(agent.get_session(), Some(session));
assert_eq!(agent.get_session().await, Some(session));
}
// failure with `createSession` error
{
Expand All @@ -427,7 +398,7 @@ mod tests {
.login("test", "bad")
.await
.expect_err("login should be failed");
assert_eq!(agent.get_session(), None);
assert_eq!(agent.get_session().await, None);
}
}

Expand All @@ -448,7 +419,7 @@ mod tests {
..Default::default()
};
let agent = AtpAgent::new(client);
agent.session.write().unwrap().replace(session);
agent.session.write().await.replace(session);
let output = agent
.api
.com
Expand Down Expand Up @@ -478,7 +449,7 @@ mod tests {
..Default::default()
};
let agent = AtpAgent::new(client);
agent.session.write().unwrap().replace(session);
agent.session.write().await.replace(session);
let output = agent
.api
.com
Expand All @@ -489,7 +460,7 @@ mod tests {
.expect("get session should be succeeded");
assert_eq!(output.did, "did");
assert_eq!(
agent.get_session().map(|session| session.access_jwt),
agent.get_session().await.map(|session| session.access_jwt),
Some("access".into())
);
}
Expand All @@ -513,7 +484,7 @@ mod tests {
};
let counts = Arc::clone(&client.counts);
let agent = Arc::new(AtpAgent::new(client));
agent.session.write().unwrap().replace(session);
agent.session.write().await.replace(session);
let handles = (0..3).map(|_| {
let agent = Arc::clone(&agent);
tokio::spawn(async move { agent.api.com.atproto.server.get_session().await })
Expand All @@ -528,14 +499,11 @@ mod tests {
assert_eq!(output.did, "did");
}
assert_eq!(
agent.get_session().map(|session| session.access_jwt),
agent.get_session().await.map(|session| session.access_jwt),
Some("access".into())
);
assert_eq!(
counts
.read()
.expect("read lock on counts should not be poisoned")
.clone(),
counts.read().await.clone(),
HashMap::from_iter([
("com.atproto.server.refreshSession".into(), 1),
("com.atproto.server.getSession".into(), 3)
Expand All @@ -562,15 +530,15 @@ mod tests {
..Default::default()
};
let agent = AtpAgent::new(client);
assert_eq!(agent.get_session(), None);
assert_eq!(agent.get_session().await, None);
agent
.resume_session(Session {
email: Some(String::from("[email protected]")),
..session.clone()
})
.await
.expect("resume_session should be succeeded");
assert_eq!(agent.get_session(), Some(session.clone()));
assert_eq!(agent.get_session().await, Some(session.clone()));
}
// failure with `getSession` error
{
Expand All @@ -581,12 +549,12 @@ mod tests {
..Default::default()
};
let agent = AtpAgent::new(client);
assert_eq!(agent.get_session(), None);
assert_eq!(agent.get_session().await, None);
agent
.resume_session(session)
.await
.expect_err("resume_session should be failed");
assert_eq!(agent.get_session(), None);
assert_eq!(agent.get_session().await, None);
}
}

Expand Down Expand Up @@ -614,6 +582,6 @@ mod tests {
})
.await
.expect("resume_session should be succeeded");
assert_eq!(agent.get_session(), Some(session));
assert_eq!(agent.get_session().await, Some(session));
}
}

0 comments on commit 3218eb0

Please sign in to comment.