Skip to content

Commit

Permalink
Refactoring code
Browse files Browse the repository at this point in the history
  • Loading branch information
oestradiol committed Sep 19, 2024
1 parent 5b04cab commit 078241f
Show file tree
Hide file tree
Showing 14 changed files with 83 additions and 129 deletions.
3 changes: 0 additions & 3 deletions atrium-api/src/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,6 @@ mod tests {
use crate::com::atproto::server::create_session::OutputData;
use crate::did_doc::{DidDocument, Service, VerificationMethod};
use crate::types::TryIntoUnknown;
use async_trait::async_trait;
use atrium_xrpc::HttpClient;
use http::{HeaderMap, HeaderName, HeaderValue, Request, Response};
use std::collections::HashMap;
Expand All @@ -189,8 +188,6 @@ mod tests {
headers: Arc<RwLock<Vec<HeaderMap<HeaderValue>>>>,
}

#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
impl HttpClient for MockClient {
async fn send_http(
&self,
Expand Down
11 changes: 0 additions & 11 deletions atrium-api/src/agent/inner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use super::{Session, SessionStore};
use crate::did_doc::DidDocument;
use crate::types::string::Did;
use crate::types::TryFromUnknown;
use async_trait::async_trait;
use atrium_xrpc::error::{Error, Result, XrpcErrorKind};
use atrium_xrpc::{HttpClient, OutputDataOrBytes, XrpcClient, XrpcRequest};
use http::{Method, Request, Response, Uri};
Expand Down Expand Up @@ -51,8 +50,6 @@ impl<S, T> Clone for WrapperClient<S, T> {
}
}

#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
impl<S, T> HttpClient for WrapperClient<S, T>
where
S: Send + Sync,
Expand All @@ -67,8 +64,6 @@ where
}
}

#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
impl<S, T> XrpcClient for WrapperClient<S, T>
where
S: SessionStore + Send + Sync,
Expand Down Expand Up @@ -231,8 +226,6 @@ where
}
}

#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
impl<S, T> HttpClient for Client<S, T>
where
S: Send + Sync,
Expand All @@ -247,8 +240,6 @@ where
}
}

#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
impl<S, T> XrpcClient for Client<S, T>
where
S: SessionStore + Send + Sync,
Expand Down Expand Up @@ -321,8 +312,6 @@ impl<S> Store<S> {
}
}

#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
impl<S> SessionStore for Store<S>
where
S: SessionStore + Send + Sync,
Expand Down
11 changes: 5 additions & 6 deletions atrium-api/src/agent/store.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
mod memory;

use std::future::Future;

pub use self::memory::MemorySessionStore;
pub(crate) use super::Session;
use async_trait::async_trait;

#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
pub trait SessionStore {
#[must_use]
async fn get_session(&self) -> Option<Session>;
fn get_session(&self) -> impl Future<Output = Option<Session>> + Send;
#[must_use]
async fn set_session(&self, session: Session);
fn set_session(&self, session: Session) -> impl Future<Output = ()> + Send;
#[must_use]
async fn clear_session(&self);
fn clear_session(&self) -> impl Future<Output = ()> + Send;
}
3 changes: 0 additions & 3 deletions atrium-api/src/agent/store/memory.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use super::{Session, SessionStore};
use async_trait::async_trait;
use std::sync::Arc;
use tokio::sync::RwLock;

Expand All @@ -8,8 +7,6 @@ pub struct MemorySessionStore {
session: Arc<RwLock<Option<Session>>>,
}

#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
impl SessionStore for MemorySessionStore {
async fn get_session(&self) -> Option<Session> {
self.session.read().await.clone()
Expand Down
2 changes: 0 additions & 2 deletions atrium-xrpc-client/src/isahc.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
#![doc = "XrpcClient implementation for [isahc]"]
use async_trait::async_trait;
use atrium_xrpc::http::{Request, Response};
use atrium_xrpc::{HttpClient, XrpcClient};
use isahc::{AsyncReadResponseExt, HttpClient as Client};
Expand Down Expand Up @@ -52,7 +51,6 @@ impl IsahcClientBuilder {
}
}

#[async_trait]
impl HttpClient for IsahcClient {
async fn send_http(
&self,
Expand Down
3 changes: 0 additions & 3 deletions atrium-xrpc-client/src/reqwest.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
#![doc = "XrpcClient implementation for [reqwest]"]
use async_trait::async_trait;
use atrium_xrpc::http::{Request, Response};
use atrium_xrpc::{HttpClient, XrpcClient};
use reqwest::Client;
Expand Down Expand Up @@ -48,8 +47,6 @@ impl ReqwestClientBuilder {
}
}

#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
impl HttpClient for ReqwestClient {
async fn send_http(
&self,
Expand Down
3 changes: 0 additions & 3 deletions atrium-xrpc/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ mod tests {
use super::*;
use crate::error::{XrpcError, XrpcErrorKind};
use crate::{HttpClient, XrpcClient};
use async_trait::async_trait;
use http::{Request, Response};
#[cfg(target_arch = "wasm32")]
use wasm_bindgen_test::*;
Expand All @@ -24,8 +23,6 @@ mod tests {
body: Vec<u8>,
}

#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
impl HttpClient for DummyClient {
async fn send_http(
&self,
Expand Down
129 changes: 64 additions & 65 deletions atrium-xrpc/src/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,18 @@ use crate::error::Error;
use crate::error::{XrpcError, XrpcErrorKind};
use crate::types::{Header, NSID_REFRESH_SESSION};
use crate::{InputDataOrBytes, OutputDataOrBytes, XrpcRequest};
use async_trait::async_trait;
use http::{Method, Request, Response};
use serde::{de::DeserializeOwned, Serialize};
use std::fmt::Debug;
use std::future::Future;

/// An abstract HTTP client.
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
pub trait HttpClient {
/// Send an HTTP request and return the response.
async fn send_http(
fn send_http(
&self,
request: Request<Vec<u8>>,
) -> core::result::Result<Response<Vec<u8>>, Box<dyn std::error::Error + Send + Sync + 'static>>;
) -> impl Future<Output = core::result::Result<Response<Vec<u8>>, Box<dyn std::error::Error + Send + Sync + 'static>>> + Send;
}

type XrpcResult<O, E> = core::result::Result<OutputDataOrBytes<O>, self::Error<E>>;
Expand All @@ -24,87 +22,88 @@ type XrpcResult<O, E> = core::result::Result<OutputDataOrBytes<O>, self::Error<E
///
/// [`send_xrpc()`](XrpcClient::send_xrpc) method has a default implementation,
/// which wraps the [`HttpClient::send_http()`]` method to handle input and output as an XRPC Request.
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
pub trait XrpcClient: HttpClient {
/// The base URI of the XRPC server.
fn base_uri(&self) -> String;
/// Get the authentication token to use `Authorization` header.
#[allow(unused_variables)]
async fn authentication_token(&self, is_refresh: bool) -> Option<String> {
None
fn authentication_token(&self, is_refresh: bool) -> impl Future<Output = Option<String>> + Send {
async { None }
}
/// Get the `atproto-proxy` header.
async fn atproto_proxy_header(&self) -> Option<String> {
None
fn atproto_proxy_header(&self) -> impl Future<Output = Option<String>> + Send {
async { None }
}
/// Get the `atproto-accept-labelers` header.
async fn atproto_accept_labelers_header(&self) -> Option<Vec<String>> {
None
fn atproto_accept_labelers_header(&self) -> impl Future<Output = Option<Vec<String>>> + Send {
async { None }
}
/// Send an XRPC request and return the response.
async fn send_xrpc<P, I, O, E>(&self, request: &XrpcRequest<P, I>) -> XrpcResult<O, E>
fn send_xrpc<P, I, O, E>(&self, request: &XrpcRequest<P, I>) -> impl Future<Output = XrpcResult<O, E>> + Send
where
P: Serialize + Send + Sync,
I: Serialize + Send + Sync,
O: DeserializeOwned + Send + Sync,
E: DeserializeOwned + Send + Sync + Debug,
Self: Sync,
{
let mut uri = format!("{}/xrpc/{}", self.base_uri(), request.nsid);
// Query parameters
if let Some(p) = &request.parameters {
serde_html_form::to_string(p).map(|qs| {
uri += "?";
uri += &qs;
})?;
};
let mut builder = Request::builder().method(&request.method).uri(&uri);
// Headers
if let Some(encoding) = &request.encoding {
builder = builder.header(Header::ContentType, encoding);
}
if let Some(token) = self
.authentication_token(
request.method == Method::POST && request.nsid == NSID_REFRESH_SESSION,
)
.await
{
builder = builder.header(Header::Authorization, format!("Bearer {}", token));
}
if let Some(proxy) = self.atproto_proxy_header().await {
builder = builder.header(Header::AtprotoProxy, proxy);
}
if let Some(accept_labelers) = self.atproto_accept_labelers_header().await {
builder = builder.header(Header::AtprotoAcceptLabelers, accept_labelers.join(", "));
}
// Body
let body = if let Some(input) = &request.input {
match input {
InputDataOrBytes::Data(data) => serde_json::to_vec(&data)?,
InputDataOrBytes::Bytes(bytes) => bytes.clone(),
async {
let mut uri = format!("{}/xrpc/{}", self.base_uri(), request.nsid);
// Query parameters
if let Some(p) = &request.parameters {
serde_html_form::to_string(p).map(|qs| {
uri += "?";
uri += &qs;
})?;
};
let mut builder = Request::builder().method(&request.method).uri(&uri);
// Headers
if let Some(encoding) = &request.encoding {
builder = builder.header(Header::ContentType, encoding);
}
} else {
Vec::new()
};
// Send
let (parts, body) =
self.send_http(builder.body(body)?).await.map_err(Error::HttpClient)?.into_parts();
if parts.status.is_success() {
if parts
.headers
.get(http::header::CONTENT_TYPE)
.and_then(|value| value.to_str().ok())
.map_or(false, |content_type| content_type.starts_with("application/json"))
if let Some(token) = self
.authentication_token(
request.method == Method::POST && request.nsid == NSID_REFRESH_SESSION,
)
.await
{
Ok(OutputDataOrBytes::Data(serde_json::from_slice(&body)?))
builder = builder.header(Header::Authorization, format!("Bearer {}", token));
}
if let Some(proxy) = self.atproto_proxy_header().await {
builder = builder.header(Header::AtprotoProxy, proxy);
}
if let Some(accept_labelers) = self.atproto_accept_labelers_header().await {
builder = builder.header(Header::AtprotoAcceptLabelers, accept_labelers.join(", "));
}
// Body
let body = if let Some(input) = &request.input {
match input {
InputDataOrBytes::Data(data) => serde_json::to_vec(&data)?,
InputDataOrBytes::Bytes(bytes) => bytes.clone(),
}
} else {
Vec::new()
};
// Send
let (parts, body) =
self.send_http(builder.body(body)?).await.map_err(Error::HttpClient)?.into_parts();
if parts.status.is_success() {
if parts
.headers
.get(http::header::CONTENT_TYPE)
.and_then(|value| value.to_str().ok())
.map_or(false, |content_type| content_type.starts_with("application/json"))
{
Ok(OutputDataOrBytes::Data(serde_json::from_slice(&body)?))
} else {
Ok(OutputDataOrBytes::Bytes(body))
}
} else {
Ok(OutputDataOrBytes::Bytes(body))
Err(Error::XrpcResponse(XrpcError {
status: parts.status,
error: serde_json::from_slice::<XrpcErrorKind<E>>(&body).ok(),
}))
}
} else {
Err(Error::XrpcResponse(XrpcError {
status: parts.status,
error: serde_json::from_slice::<XrpcErrorKind<E>>(&body).ok(),
}))
}
}
}
2 changes: 0 additions & 2 deletions bsky-sdk/src/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -260,13 +260,11 @@ where
#[cfg(test)]
mod tests {
use super::*;
use async_trait::async_trait;
use atrium_api::agent::Session;

#[derive(Clone)]
struct NoopStore;

#[async_trait]
impl SessionStore for NoopStore {
async fn get_session(&self) -> Option<Session> {
unimplemented!()
Expand Down
2 changes: 0 additions & 2 deletions bsky-sdk/src/agent/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,6 @@ impl Default for BskyAgentBuilder<ReqwestClient, MemorySessionStore> {
#[cfg(test)]
mod tests {
use super::*;
use async_trait::async_trait;
use atrium_api::agent::Session;
use atrium_api::com::atproto::server::create_session::OutputData;

Expand All @@ -125,7 +124,6 @@ mod tests {

struct MockSessionStore;

#[async_trait]
impl SessionStore for MockSessionStore {
async fn get_session(&self) -> Option<Session> {
Some(session())
Expand Down
13 changes: 6 additions & 7 deletions bsky-sdk/src/agent/config.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
//! Configuration for the [`BskyAgent`](super::BskyAgent).
mod file;

use std::future::Future;

use crate::error::{Error, Result};
use async_trait::async_trait;
use atrium_api::agent::Session;
pub use file::FileStore;
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -46,20 +47,18 @@ impl Default for Config {
}

/// The trait for loading configuration data.
#[async_trait]
pub trait Loader {
/// Loads the configuration data.
async fn load(
fn load(
&self,
) -> core::result::Result<Config, Box<dyn std::error::Error + Send + Sync + 'static>>;
) -> impl Future<Output = core::result::Result<Config, Box<dyn std::error::Error + Send + Sync + 'static>>> + Send;
}

/// The trait for saving configuration data.
#[async_trait]
pub trait Saver {
/// Saves the configuration data.
async fn save(
fn save(
&self,
config: &Config,
) -> core::result::Result<(), Box<dyn std::error::Error + Send + Sync + 'static>>;
) -> impl Future<Output = core::result::Result<(), Box<dyn std::error::Error + Send + Sync + 'static>>> + Send;
}
Loading

0 comments on commit 078241f

Please sign in to comment.