Skip to content

Commit

Permalink
Add Agent and SessionManager
Browse files Browse the repository at this point in the history
  • Loading branch information
sugyan committed Nov 8, 2024
1 parent 5b3d3e8 commit 71f8cff
Show file tree
Hide file tree
Showing 5 changed files with 276 additions and 21 deletions.
2 changes: 1 addition & 1 deletion atrium-api/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
While `AtpServiceClient` can be used for simple XRPC calls, it is better to use `AtpAgent`, which has practical features such as session management.

```rust,no_run
use atrium_api::agent::{store::MemorySessionStore, AtpAgent};
use atrium_api::agent::atp_agent::{store::MemorySessionStore, AtpAgent};
use atrium_xrpc_client::reqwest::ReqwestClient;
#[tokio::main]
Expand Down
30 changes: 28 additions & 2 deletions atrium-api/src/agent.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
mod atp_agent;
pub mod atp_agent;
#[cfg(feature = "bluesky")]
pub mod bluesky;
mod inner;
mod session_manager;

pub use atp_agent::{AtpAgent, CredentialSession};
use crate::{client::Service, types::string::Did};
pub use session_manager::SessionManager;
use std::sync::Arc;

/// Supported proxy targets.
#[cfg(feature = "bluesky")]
Expand All @@ -21,3 +25,25 @@ impl AsRef<str> for AtprotoServiceType {
}
}
}

pub struct Agent<M>
where
M: SessionManager + Send + Sync,
{
session_manager: Arc<inner::Wrapper<M>>,
pub api: Service<inner::Wrapper<M>>,
}

impl<M> Agent<M>
where
M: SessionManager + Send + Sync,
{
pub fn new(session_manager: M) -> Self {
let session_manager = Arc::new(inner::Wrapper::new(session_manager));
let api = Service::new(session_manager.clone());
Self { session_manager, api }
}
pub async fn did(&self) -> Option<Did> {
self.session_manager.did().await
}
}
164 changes: 146 additions & 18 deletions atrium-api/src/agent/atp_agent.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
//! Implementation of [`AtpAgent`] and definitions of [`SessionStore`] for it.
mod inner;
mod store;
pub mod store;

use self::store::AtpSessionStore;
use super::inner::Wrapper;
use super::{Agent, SessionManager};
use crate::{
client::Service,
client::{com::atproto::Service as AtprotoService, Service},
did_doc::DidDocument,
types::{string::Did, TryFromUnknown},
};
use atrium_xrpc::{Error, XrpcClient};
use std::{ops::Deref, sync::Arc};
use atrium_xrpc::{Error, HttpClient, OutputDataOrBytes, XrpcClient, XrpcRequest};
use http::{Request, Response};
use serde::{de::DeserializeOwned, Serialize};
use std::{fmt::Debug, ops::Deref, sync::Arc};

/// Type alias for the [com::atproto::server::create_session::Output](crate::com::atproto::server::create_session::Output)
pub type AtpSession = crate::com::atproto::server::create_session::Output;
Expand All @@ -22,7 +26,7 @@ where
{
store: Arc<inner::Store<S>>,
inner: Arc<inner::Client<S, T>>,
pub api: Service<inner::Client<S, T>>,
atproto_service: AtprotoService<inner::Client<S, T>>,
}

impl<S, T> CredentialSession<S, T>
Expand All @@ -36,7 +40,7 @@ where
Self {
store: Arc::clone(&store),
inner: Arc::clone(&inner),
api: Service::new(Arc::clone(&inner)),
atproto_service: AtprotoService::new(Arc::clone(&inner)),
}
}
/// Start a new session with this agent.
Expand All @@ -46,9 +50,7 @@ where
password: impl AsRef<str>,
) -> Result<AtpSession, Error<crate::com::atproto::server::create_session::Error>> {
let result = self
.api
.com
.atproto
.atproto_service
.server
.create_session(
crate::com::atproto::server::create_session::InputData {
Expand All @@ -75,7 +77,7 @@ where
session: AtpSession,
) -> Result<(), Error<crate::com::atproto::server::get_session::Error>> {
self.store.set_session(session.clone()).await;
let result = self.api.com.atproto.server.get_session().await;
let result = self.atproto_service.server.get_session().await;
match result {
Ok(output) => {
assert_eq!(output.data.did, session.data.did);
Expand Down Expand Up @@ -142,14 +144,74 @@ where
}
}

impl<S, T> HttpClient for CredentialSession<S, T>
where
S: AtpSessionStore + Send + Sync,
T: XrpcClient + Send + Sync,
{
async fn send_http(
&self,
request: Request<Vec<u8>>,
) -> Result<Response<Vec<u8>>, Box<dyn std::error::Error + Send + Sync + 'static>> {
self.inner.send_http(request).await
}
}

impl<S, T> XrpcClient for CredentialSession<S, T>
where
S: AtpSessionStore + Send + Sync,
T: XrpcClient + Send + Sync,
{
fn base_uri(&self) -> String {
self.inner.base_uri()
}
async fn send_xrpc<P, I, O, E>(
&self,
request: &XrpcRequest<P, I>,
) -> Result<OutputDataOrBytes<O>, Error<E>>
where
P: Serialize + Send + Sync,
I: Serialize + Send + Sync,
O: DeserializeOwned + Send + Sync,
E: DeserializeOwned + Send + Sync + Debug,
{
self.inner.send_xrpc(request).await
}
}

impl<S, T> SessionManager for CredentialSession<S, T>
where
S: AtpSessionStore + Send + Sync,
T: XrpcClient + Send + Sync,
{
async fn did(&self) -> Option<Did> {
self.store.get_session().await.map(|session| session.data.did)
}
}

/// An ATP "Agent".
/// Manages session token lifecycles and provides convenience methods.
///
/// This will be deprecated in the near future. Use [`Agent`] directly
/// with a [`CredentialSession`] instead:
/// ```
/// use atrium_api::agent::atp_agent::{store::MemorySessionStore, CredentialSession};
/// use atrium_api::agent::Agent;
/// use atrium_xrpc_client::reqwest::ReqwestClient;
///
/// let session = CredentialSession::new(
/// ReqwestClient::new("https://bsky.social"),
/// MemorySessionStore::default(),
/// );
/// let agent = Agent::new(session);
/// ```
pub struct AtpAgent<S, T>
where
S: AtpSessionStore + Send + Sync,
T: XrpcClient + Send + Sync,
{
inner: CredentialSession<S, T>,
session_manager: Wrapper<CredentialSession<S, T>>,
inner: Agent<Wrapper<CredentialSession<S, T>>>,
}

impl<S, T> AtpAgent<S, T>
Expand All @@ -159,7 +221,62 @@ where
{
/// Create a new agent.
pub fn new(xrpc: T, store: S) -> Self {
Self { inner: CredentialSession::new(xrpc, store) }
let session_manager = Wrapper::new(CredentialSession::new(xrpc, store));
let inner = Agent::new(session_manager.clone());
Self { session_manager, inner }
}
/// Start a new session with this agent.
pub async fn login(
&self,
identifier: impl AsRef<str>,
password: impl AsRef<str>,
) -> Result<AtpSession, Error<crate::com::atproto::server::create_session::Error>> {
self.session_manager.login(identifier, password).await
}
// /// Resume a pre-existing session with this agent.
pub async fn resume_session(
&self,
session: AtpSession,
) -> Result<(), Error<crate::com::atproto::server::get_session::Error>> {
self.session_manager.resume_session(session).await
}
// /// Set the current endpoint.
pub fn configure_endpoint(&self, endpoint: String) {
self.session_manager.configure_endpoint(endpoint);
}
/// Configures the moderation services to be applied on requests.
pub fn configure_labelers_header(&self, labeler_dids: Option<Vec<(Did, bool)>>) {
self.session_manager.configure_labelers_header(labeler_dids);
}
/// Configures the atproto-proxy header to be applied on requests.
pub fn configure_proxy_header(&self, did: Did, service_type: impl AsRef<str>) {
self.session_manager.configure_proxy_header(did, service_type);
}
/// Configures the atproto-proxy header to be applied on requests.
///
/// Returns a new client service with the proxy header configured.
pub fn api_with_proxy(
&self,
did: Did,
service_type: impl AsRef<str>,
) -> Service<inner::Client<S, T>> {
self.session_manager.api_with_proxy(did, service_type)
}
/// Get the current session.
pub async fn get_session(&self) -> Option<AtpSession> {
self.session_manager.get_session().await
}
/// Get the current endpoint.
pub async fn get_endpoint(&self) -> String {
self.session_manager.get_endpoint().await
}
/// Get the current labelers header.
pub async fn get_labelers_header(&self) -> Option<Vec<String>> {
self.session_manager.get_labelers_header().await
}
/// Get the current proxy header.
pub async fn get_proxy_header(&self) -> Option<String> {
self.session_manager.get_proxy_header().await
}
}

Expand All @@ -168,7 +285,7 @@ where
S: AtpSessionStore + Send + Sync,
T: XrpcClient + Send + Sync,
{
type Target = CredentialSession<S, T>;
type Target = Agent<Wrapper<CredentialSession<S, T>>>;

fn deref(&self) -> &Self::Target {
&self.inner
Expand Down Expand Up @@ -366,7 +483,7 @@ mod tests {
..Default::default()
};
let agent = AtpAgent::new(client, MemorySessionStore::default());
agent.store.set_session(session_data.clone().into()).await;
agent.session_manager.store.set_session(session_data.clone().into()).await;
let output = agent
.api
.com
Expand Down Expand Up @@ -400,7 +517,7 @@ mod tests {
..Default::default()
};
let agent = AtpAgent::new(client, MemorySessionStore::default());
agent.store.set_session(session_data.clone().into()).await;
agent.session_manager.store.set_session(session_data.clone().into()).await;
let output = agent
.api
.com
Expand All @@ -411,7 +528,7 @@ mod tests {
.expect("get session should be succeeded");
assert_eq!(output.did.as_str(), "did:web:example.com");
assert_eq!(
agent.store.get_session().await.map(|session| session.data.access_jwt),
agent.session_manager.store.get_session().await.map(|session| session.data.access_jwt),
Some("access".into())
);
}
Expand Down Expand Up @@ -439,7 +556,7 @@ mod tests {
};
let counts = Arc::clone(&client.counts);
let agent = Arc::new(AtpAgent::new(client, MemorySessionStore::default()));
agent.store.set_session(session_data.clone().into()).await;
agent.session_manager.store.set_session(session_data.clone().into()).await;
let handles = (0..3).map(|_| {
let agent = Arc::clone(&agent);
tokio::spawn(async move { agent.api.com.atproto.server.get_session().await })
Expand All @@ -454,7 +571,7 @@ mod tests {
assert_eq!(output.did.as_str(), "did:web:example.com");
}
assert_eq!(
agent.store.get_session().await.map(|session| session.data.access_jwt),
agent.session_manager.store.get_session().await.map(|session| session.data.access_jwt),
Some("access".into())
);
assert_eq!(
Expand Down Expand Up @@ -790,4 +907,15 @@ mod tests {
Some(String::from("did:plc:test1#atproto_labeler"))
);
}

#[tokio::test]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
async fn test_agent_did() {
let session_data = session_data();
let client = MockClient { responses: MockResponses::default(), ..Default::default() };
let agent = AtpAgent::new(client, MemorySessionStore::default());
assert_eq!(agent.did().await, None);
agent.session_manager.store.set_session(session_data.clone().into()).await;
assert_eq!(agent.did().await, Some(session_data.did));
}
}
Loading

0 comments on commit 71f8cff

Please sign in to comment.