Skip to content

Commit

Permalink
Formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
oestradiol committed Sep 14, 2024
1 parent 145d6df commit f2bb5a6
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 25 deletions.
11 changes: 9 additions & 2 deletions atrium-streams-client/src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,10 @@ type StreamKind = WebSocketStream<MaybeTlsStream<TcpStream>>;
impl<P: Serialize + Send + Sync> EventStreamClient<<StreamKind as Stream>::Item, Error>
for WssClient<P>
{
async fn connect(&self, mut uri: String) -> Result<impl Stream<Item = <StreamKind as Stream>::Item>, Error> {
async fn connect(
&self,
mut uri: String,
) -> Result<impl Stream<Item = <StreamKind as Stream>::Item>, Error> {
let Self { params } = self;

// Query parameters
Expand Down Expand Up @@ -76,7 +79,11 @@ fn get_host(uri: &str) -> Result<(Uri, Box<str>), Error> {
/// Generate a request for the given URI and host.
/// It sets the necessary headers for a WebSocket connection,
/// plus the client's `AtprotoProxy` and `AtprotoAcceptLabelers` headers.
async fn gen_request<P: Serialize + Send + Sync>(client: &WssClient<P>, uri: &Uri, host: &str) -> Result<Request<()>, Error> {
async fn gen_request<P: Serialize + Send + Sync>(
client: &WssClient<P>,
uri: &Uri,
host: &str,
) -> Result<Request<()>, Error> {
let mut request = Request::builder()
.uri(uri)
.method("GET")
Expand Down
55 changes: 33 additions & 22 deletions atrium-streams-client/src/client/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,17 @@ use std::net::{Ipv4Addr, SocketAddr};
use atrium_streams::{atrium_api::com::atproto::sync::subscribe_repos, client::EventStreamClient};
use atrium_xrpc::http::{header::SEC_WEBSOCKET_KEY, HeaderMap, HeaderValue};
use futures::{SinkExt, StreamExt};
use tokio::{net::{TcpListener, TcpStream}, runtime::Runtime};
use tokio_tungstenite::{tungstenite::{handshake::server::{ErrorResponse, Request, Response}, Message}, WebSocketStream};
use tokio::{
net::{TcpListener, TcpStream},
runtime::Runtime,
};
use tokio_tungstenite::{
tungstenite::{
handshake::server::{ErrorResponse, Request, Response},
Message,
},
WebSocketStream,
};

use crate::WssClient;

Expand Down Expand Up @@ -35,15 +44,16 @@ fn client() {
Runtime::new().unwrap().block_on(fut);
}

async fn wss_client(uri: &str) -> (WssClient<subscribe_repos::ParametersData>, HeaderMap<HeaderValue>) {
let params = subscribe_repos::ParametersData {
cursor: None,
};
async fn wss_client(
uri: &str,
) -> (
WssClient<subscribe_repos::ParametersData>,
HeaderMap<HeaderValue>,
) {
let params = subscribe_repos::ParametersData { cursor: None };

let client = WssClient::builder().params(params).build();

let client = WssClient::builder()
.params(params)
.build();

let (uri, host) = get_host(uri).unwrap();
let req = gen_request(&client, &uri, &host).await.unwrap();
let headers = req.headers();
Expand All @@ -58,7 +68,7 @@ async fn mock_wss_server() -> (WebSocketStream<TcpStream>, HeaderMap, String) {
.await
.expect("Failed to bind to port!");

let headers: HeaderMap;
let headers: HeaderMap;
let route: String;
let (stream, _) = listener.accept().await.unwrap();
let (headers_, route_, stream) = extract_headers(stream).await;
Expand All @@ -68,21 +78,22 @@ async fn mock_wss_server() -> (WebSocketStream<TcpStream>, HeaderMap, String) {
(stream, headers, route)
}

async fn extract_headers(raw_stream: TcpStream) -> (HeaderMap<HeaderValue>, String, WebSocketStream<TcpStream>) {
async fn extract_headers(
raw_stream: TcpStream,
) -> (HeaderMap<HeaderValue>, String, WebSocketStream<TcpStream>) {
let mut headers: Option<HeaderMap<HeaderValue>> = None;
let mut route: Option<String> = None;

let copy_headers_callback = |request: &Request, response: Response| -> Result<Response, ErrorResponse> {
headers = Some(request.headers().clone());
route = Some(request.uri().path().to_owned());
Ok(response)
};
let copy_headers_callback =
|request: &Request, response: Response| -> Result<Response, ErrorResponse> {
headers = Some(request.headers().clone());
route = Some(request.uri().path().to_owned());
Ok(response)
};

let stream = tokio_tungstenite::accept_hdr_async(
raw_stream,
copy_headers_callback,
).await
let stream = tokio_tungstenite::accept_hdr_async(raw_stream, copy_headers_callback)
.await
.expect("Error during the websocket handshake occurred");

(headers.unwrap(), route.unwrap(), stream)
}
}
2 changes: 1 addition & 1 deletion atrium-streams/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ pub trait EventStreamClient<ConnectionPayload, ConnectionError> {
/// [`Result<M, E>`]
fn connect(
&self,
uri: String
uri: String,
) -> impl Future<Output = Result<impl Stream<Item = ConnectionPayload>, ConnectionError>> + Send;

/// Get the `atproto-proxy` header.
Expand Down

0 comments on commit f2bb5a6

Please sign in to comment.