Skip to content

Commit

Permalink
proxy: introduce Acceptor and Connector traits
Browse files Browse the repository at this point in the history
  • Loading branch information
cloneable committed Jan 3, 2025
1 parent c08759f commit 4f88c4b
Show file tree
Hide file tree
Showing 10 changed files with 254 additions and 47 deletions.
6 changes: 3 additions & 3 deletions proxy/src/bin/local_proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use proxy::cancellation::CancellationHandlerMain;
use proxy::config::{
self, AuthenticationConfig, ComputeConfig, HttpConfig, ProxyConfig, RetryConfig,
};
use proxy::conn::TokioTcpAcceptor;
use proxy::control_plane::locks::ApiLocks;
use proxy::control_plane::messages::{EndpointJwksResponse, JwksSettings};
use proxy::http::health_server::AppMetrics;
Expand All @@ -36,7 +37,6 @@ project_build_tag!(BUILD_TAG);

use clap::Parser;
use thiserror::Error;
use tokio::net::TcpListener;
use tokio::sync::Notify;
use tokio::task::JoinSet;
use tokio_util::sync::CancellationToken;
Expand Down Expand Up @@ -166,8 +166,8 @@ async fn main() -> anyhow::Result<()> {
}
};

let metrics_listener = TcpListener::bind(args.metrics).await?.into_std()?;
let http_listener = TcpListener::bind(args.http).await?;
let metrics_listener = TokioTcpAcceptor::bind(args.metrics).await?;
let http_listener = TokioTcpAcceptor::bind(args.http).await?;
let shutdown = CancellationToken::new();

// todo: should scale with CU
Expand Down
17 changes: 5 additions & 12 deletions proxy/src/bin/pg_sni_router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use clap::Arg;
use futures::future::Either;
use futures::TryFutureExt;
use itertools::Itertools;
use proxy::conn::{Acceptor, TokioTcpAcceptor};
use proxy::context::RequestContext;
use proxy::metrics::{Metrics, ThreadPoolMetrics};
use proxy::protocol2::ConnectionInfo;
Expand Down Expand Up @@ -122,7 +123,7 @@ async fn main() -> anyhow::Result<()> {
// Start listening for incoming client connections
let proxy_address: SocketAddr = args.get_one::<String>("listen").unwrap().parse()?;
info!("Starting sni router on {proxy_address}");
let proxy_listener = TcpListener::bind(proxy_address).await?;
let proxy_listener = TokioTcpAcceptor::bind(proxy_address).await?;

let cancellation_token = CancellationToken::new();

Expand Down Expand Up @@ -152,17 +153,13 @@ async fn task_main(
dest_suffix: Arc<String>,
tls_config: Arc<rustls::ServerConfig>,
tls_server_end_point: TlsServerEndPoint,
listener: tokio::net::TcpListener,
acceptor: TokioTcpAcceptor,
cancellation_token: CancellationToken,
) -> anyhow::Result<()> {
// When set for the server socket, the keepalive setting
// will be inherited by all accepted client sockets.
socket2::SockRef::from(&listener).set_keepalive(true)?;

let connections = tokio_util::task::task_tracker::TaskTracker::new();

while let Some(accept_result) =
run_until_cancelled(listener.accept(), &cancellation_token).await
run_until_cancelled(acceptor.accept(), &cancellation_token).await
{
let (socket, peer_addr) = accept_result?;

Expand All @@ -172,10 +169,6 @@ async fn task_main(

connections.spawn(
async move {
socket
.set_nodelay(true)
.context("failed to set socket option")?;

info!(%peer_addr, "serving");
let ctx = RequestContext::new(
session_id,
Expand All @@ -197,7 +190,7 @@ async fn task_main(
}

connections.close();
drop(listener);
drop(acceptor);

connections.wait().await;

Expand Down
10 changes: 5 additions & 5 deletions proxy/src/bin/proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use proxy::config::{
self, remote_storage_from_toml, AuthenticationConfig, CacheOptions, ComputeConfig, HttpConfig,
ProjectInfoCacheOptions, ProxyConfig, ProxyProtocolV2,
};
use proxy::conn::TokioTcpAcceptor;
use proxy::context::parquet::ParquetUploadArgs;
use proxy::http::health_server::AppMetrics;
use proxy::metrics::Metrics;
Expand All @@ -27,7 +28,6 @@ use proxy::serverless::GlobalConnPoolOptions;
use proxy::tls::client_config::compute_client_config_with_root_certs;
use proxy::{auth, control_plane, http, serverless, usage_metrics};
use remote_storage::RemoteStorageConfig;
use tokio::net::TcpListener;
use tokio::sync::Mutex;
use tokio::task::JoinSet;
use tokio_util::sync::CancellationToken;
Expand Down Expand Up @@ -353,17 +353,17 @@ async fn main() -> anyhow::Result<()> {
// Check that we can bind to address before further initialization
let http_address: SocketAddr = args.http.parse()?;
info!("Starting http on {http_address}");
let http_listener = TcpListener::bind(http_address).await?.into_std()?;
let http_listener = TokioTcpAcceptor::bind(http_address).await?;

let mgmt_address: SocketAddr = args.mgmt.parse()?;
info!("Starting mgmt on {mgmt_address}");
let mgmt_listener = TcpListener::bind(mgmt_address).await?;
let mgmt_listener = TokioTcpAcceptor::bind(mgmt_address).await?;

let proxy_listener = if !args.is_auth_broker {
let proxy_address: SocketAddr = args.proxy.parse()?;
info!("Starting proxy on {proxy_address}");

Some(TcpListener::bind(proxy_address).await?)
Some(TokioTcpAcceptor::bind(proxy_address).await?)
} else {
None
};
Expand All @@ -373,7 +373,7 @@ async fn main() -> anyhow::Result<()> {
let serverless_listener = if let Some(serverless_address) = args.wss {
let serverless_address: SocketAddr = serverless_address.parse()?;
info!("Starting wss on {serverless_address}");
Some(TcpListener::bind(serverless_address).await?)
Some(TokioTcpAcceptor::bind(serverless_address).await?)
} else if args.is_auth_broker {
bail!("wss arg must be present for auth-broker")
} else {
Expand Down
221 changes: 221 additions & 0 deletions proxy/src/conn.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
use std::future::{poll_fn, Future};
use std::io;
use std::net::SocketAddr;
use std::pin::Pin;
use std::task::{Context, Poll};

use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};

pub trait Acceptor {
type Connection: AsyncRead + AsyncWrite + Send + Unpin + 'static;
type Error: std::error::Error + Send + Sync + 'static;

#[inline]
fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
let _ = cx;
Poll::Ready(Ok(()))
}

fn accept(
&self,
) -> impl Future<Output = Result<(Self::Connection, SocketAddr), Self::Error>> + Send;
}

pub trait Connector {
type Connection: AsyncRead + AsyncWrite + Send + Unpin + 'static;
type Error: std::error::Error + Send + Sync + 'static;

#[inline]
fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
let _ = cx;
Poll::Ready(Ok(()))
}

fn connect(
&self,
addr: SocketAddr,
) -> impl Future<Output = Result<Self::Connection, Self::Error>> + Send;
}

pub struct TokioTcpAcceptor {
listener: TcpListener,
tcp_nodelay: Option<bool>,
tcp_keepalive: Option<bool>,
}

impl TokioTcpAcceptor {
pub async fn bind<A: ToSocketAddrs>(addr: A) -> io::Result<Self> {
let listener = TcpListener::bind(addr).await?;
// When set for the server socket, the keepalive setting
// will be inherited by all accepted client sockets.
socket2::SockRef::from(&listener).set_keepalive(true)?;
Ok(Self {
listener,
tcp_nodelay: Some(true),
tcp_keepalive: None,
})
}

pub fn into_std(self) -> io::Result<std::net::TcpListener> {
self.listener.into_std()
}
}

impl Acceptor for TokioTcpAcceptor {
type Connection = TcpStream;
type Error = io::Error;

fn accept(&self) -> impl Future<Output = Result<(Self::Connection, SocketAddr), Self::Error>> {
async move {
let (stream, addr) = self.listener.accept().await?;

let socket = socket2::SockRef::from(&stream);
if let Some(nodelay) = self.tcp_nodelay {
socket.set_nodelay(nodelay)?;
}
if let Some(keepalive) = self.tcp_keepalive {
socket.set_keepalive(keepalive)?;
}

Ok((stream, addr))
}
}
}

pub struct TokioTcpConnector;

impl Connector for TokioTcpConnector {
type Connection = TcpStream;
type Error = io::Error;

fn connect(
&self,
addr: SocketAddr,
) -> impl Future<Output = Result<Self::Connection, Self::Error>> {
async move {
let socket = TcpStream::connect(addr).await?;
socket.set_nodelay(true)?;
Ok(socket)
}
}
}

pub trait Stream: AsyncRead + AsyncWrite + Send + Unpin + 'static {}

impl<T: AsyncRead + AsyncWrite + Send + Unpin + 'static> Stream for T {}

pub trait AsyncRead {
fn readable(&self) -> impl Future<Output = io::Result<()>> + Send
where
Self: Send + Sync,
{
poll_fn(move |cx| self.poll_read_ready(cx))
}

fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>>;

fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>>;

fn poll_read_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &mut [io::IoSliceMut<'_>],
) -> Poll<io::Result<usize>>;
}

pub trait AsyncWrite {
fn writable(&self) -> impl Future<Output = io::Result<()>> + Send
where
Self: Send + Sync,
{
poll_fn(move |cx| self.poll_write_ready(cx))
}

fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>>;

fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>>;

fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[io::IoSlice<'_>],
) -> Poll<io::Result<usize>>;

fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>>;

fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>>;
}

impl AsyncRead for tokio::net::TcpStream {
fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
tokio::net::TcpStream::poll_read_ready(self, cx)
}

fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
match tokio::net::TcpStream::try_read(Pin::new(&mut *self).get_mut(), buf) {
Ok(n) => Poll::Ready(Ok(n)),
Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
cx.waker().wake_by_ref();
Poll::Pending
}
Err(e) => Poll::Ready(Err(e)),
}
}

fn poll_read_vectored(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &mut [io::IoSliceMut<'_>],
) -> Poll<io::Result<usize>> {
match tokio::net::TcpStream::try_read_vectored(Pin::new(&mut *self).get_mut(), bufs) {
Ok(n) => Poll::Ready(Ok(n)),
Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
cx.waker().wake_by_ref();
Poll::Pending
}
Err(e) => Poll::Ready(Err(e)),
}
}
}

impl AsyncWrite for tokio::net::TcpStream {
fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
tokio::net::TcpStream::poll_write_ready(self, cx)
}

fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
<Self as tokio::io::AsyncWrite>::poll_write(self, cx, buf)
}

fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[io::IoSlice<'_>],
) -> Poll<io::Result<usize>> {
<Self as tokio::io::AsyncWrite>::poll_write_vectored(self, cx, bufs)
}

fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
<Self as tokio::io::AsyncWrite>::poll_flush(self, cx)
}

fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
<Self as tokio::io::AsyncWrite>::poll_shutdown(self, cx)
}
}
11 changes: 4 additions & 7 deletions proxy/src/console_redirect_proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use tracing::{debug, error, info, Instrument};
use crate::auth::backend::ConsoleRedirectBackend;
use crate::cancellation::{CancellationHandlerMain, CancellationHandlerMainInternal};
use crate::config::{ProxyConfig, ProxyProtocolV2};
use crate::conn::{Acceptor, TokioTcpAcceptor};
use crate::context::RequestContext;
use crate::error::ReportableError;
use crate::metrics::{Metrics, NumClientConnectionsGuard};
Expand All @@ -22,23 +23,19 @@ use crate::proxy::{
pub async fn task_main(
config: &'static ProxyConfig,
backend: &'static ConsoleRedirectBackend,
listener: tokio::net::TcpListener,
acceptor: TokioTcpAcceptor,
cancellation_token: CancellationToken,
cancellation_handler: Arc<CancellationHandlerMain>,
) -> anyhow::Result<()> {
scopeguard::defer! {
info!("proxy has shut down");
}

// When set for the server socket, the keepalive setting
// will be inherited by all accepted client sockets.
socket2::SockRef::from(&listener).set_keepalive(true)?;

let connections = tokio_util::task::task_tracker::TaskTracker::new();
let cancellations = tokio_util::task::task_tracker::TaskTracker::new();

while let Some(accept_result) =
run_until_cancelled(listener.accept(), &cancellation_token).await
run_until_cancelled(acceptor.accept(), &cancellation_token).await
{
let (socket, peer_addr) = accept_result?;

Expand Down Expand Up @@ -131,7 +128,7 @@ pub async fn task_main(

connections.close();
cancellations.close();
drop(listener);
drop(acceptor);

// Drain connections
connections.wait().await;
Expand Down
Loading

0 comments on commit 4f88c4b

Please sign in to comment.