diff --git a/atrium-api/src/agent/inner.rs b/atrium-api/src/agent/inner.rs index da1e9876..f3bf2e66 100644 --- a/atrium-api/src/agent/inner.rs +++ b/atrium-api/src/agent/inner.rs @@ -1,13 +1,17 @@ use super::{Session, SessionStore}; use crate::did_doc::DidDocument; -use crate::types::string::Did; -use crate::types::TryFromUnknown; -use atrium_xrpc::error::{Error, Result, XrpcErrorKind}; -use atrium_xrpc::{HttpClient, OutputDataOrBytes, XrpcClient, XrpcRequest}; +use crate::types::{string::Did, TryFromUnknown}; +use atrium_xrpc::{ + error::{Error, Result, XrpcErrorKind}, + types::AuthorizationToken, + HttpClient, OutputDataOrBytes, XrpcClient, XrpcRequest, +}; use http::{Method, Request, Response}; use serde::{de::DeserializeOwned, Serialize}; -use std::fmt::Debug; -use std::sync::{Arc, RwLock}; +use std::{ + fmt::Debug, + sync::{Arc, RwLock}, +}; use tokio::sync::{Mutex, Notify}; struct WrapperClient { @@ -72,13 +76,13 @@ where fn base_uri(&self) -> String { self.store.get_endpoint() } - async fn authentication_token(&self, is_refresh: bool) -> Option { + async fn authorization_token(&self, is_refresh: bool) -> Option { self.store.get_session().await.map(|session| { - if is_refresh { + AuthorizationToken::Bearer(if is_refresh { session.data.refresh_jwt } else { session.data.access_jwt - } + }) }) } async fn atproto_proxy_header(&self) -> Option { diff --git a/atrium-xrpc/src/traits.rs b/atrium-xrpc/src/traits.rs index f04e3176..13d65df9 100644 --- a/atrium-xrpc/src/traits.rs +++ b/atrium-xrpc/src/traits.rs @@ -1,11 +1,9 @@ -use crate::error::Error; -use crate::error::{XrpcError, XrpcErrorKind}; -use crate::types::{Header, NSID_REFRESH_SESSION}; +use crate::error::{Error, XrpcError, XrpcErrorKind}; +use crate::types::{AuthorizationToken, Header, NSID_REFRESH_SESSION}; use crate::{InputDataOrBytes, OutputDataOrBytes, XrpcRequest}; use http::{Method, Request, Response}; use serde::{de::DeserializeOwned, Serialize}; -use std::fmt::Debug; -use std::future::Future; +use std::{fmt::Debug, future::Future}; /// An abstract HTTP client. #[cfg_attr(not(target_arch = "wasm32"), trait_variant::make(Send))] @@ -32,9 +30,12 @@ type XrpcResult = core::result::Result, self::Error String; - /// Get the authentication token to use `Authorization` header. + /// Get the authorization token to use `Authorization` header. #[allow(unused_variables)] - fn authentication_token(&self, is_refresh: bool) -> impl Future> { + fn authorization_token( + &self, + is_refresh: bool, + ) -> impl Future> { async { None } } /// Get the `atproto-proxy` header. @@ -102,12 +103,10 @@ where builder = builder.header(Header::ContentType, encoding); } if let Some(token) = client - .authentication_token( - request.method == Method::POST && request.nsid == NSID_REFRESH_SESSION, - ) + .authorization_token(request.method == Method::POST && request.nsid == NSID_REFRESH_SESSION) .await { - builder = builder.header(Header::Authorization, format!("Bearer {}", token)); + builder = builder.header(Header::Authorization, token); } if let Some(proxy) = client.atproto_proxy_header().await { builder = builder.header(Header::AtprotoProxy, proxy); diff --git a/atrium-xrpc/src/types.rs b/atrium-xrpc/src/types.rs index 0fb23863..e4a29e52 100644 --- a/atrium-xrpc/src/types.rs +++ b/atrium-xrpc/src/types.rs @@ -1,9 +1,25 @@ -use http::header::{AUTHORIZATION, CONTENT_TYPE}; -use http::{HeaderName, Method}; +use http::header::{HeaderName, HeaderValue, InvalidHeaderValue, AUTHORIZATION, CONTENT_TYPE}; +use http::Method; use serde::{de::DeserializeOwned, Serialize}; pub(crate) const NSID_REFRESH_SESSION: &str = "com.atproto.server.refreshSession"; +pub enum AuthorizationToken { + Bearer(String), + Dpop(String), +} + +impl TryFrom for HeaderValue { + type Error = InvalidHeaderValue; + + fn try_from(token: AuthorizationToken) -> Result { + HeaderValue::from_str(&match token { + AuthorizationToken::Bearer(t) => format!("Bearer {t}"), + AuthorizationToken::Dpop(t) => format!("DPoP {t}"), + }) + } +} + /// HTTP headers which can be used in XPRC requests. pub enum Header { ContentType,