Skip to content

Commit

Permalink
proxy: introduce Acceptor and Connector traits
Browse files Browse the repository at this point in the history
  • Loading branch information
cloneable committed Jan 2, 2025
1 parent 38c7a2a commit 9815d3e
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 19 deletions.
19 changes: 6 additions & 13 deletions proxy/src/bin/pg_sni_router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use clap::Arg;
use futures::future::Either;
use futures::TryFutureExt;
use itertools::Itertools;
use proxy::conn::{Acceptor, TokioTcpAcceptor};
use proxy::context::RequestContext;
use proxy::metrics::{Metrics, ThreadPoolMetrics};
use proxy::protocol2::ConnectionInfo;
Expand Down Expand Up @@ -122,7 +123,7 @@ async fn main() -> anyhow::Result<()> {
// Start listening for incoming client connections
let proxy_address: SocketAddr = args.get_one::<String>("listen").unwrap().parse()?;
info!("Starting sni router on {proxy_address}");
let proxy_listener = TcpListener::bind(proxy_address).await?;
let proxy_listener = TokioTcpAcceptor::bind(proxy_address).await?;

let cancellation_token = CancellationToken::new();

Expand All @@ -148,21 +149,17 @@ async fn main() -> anyhow::Result<()> {
match signal {}
}

async fn task_main(
async fn task_main<A: Acceptor>(
dest_suffix: Arc<String>,
tls_config: Arc<rustls::ServerConfig>,
tls_server_end_point: TlsServerEndPoint,
listener: tokio::net::TcpListener,
acceptor: A,
cancellation_token: CancellationToken,
) -> anyhow::Result<()> {
// When set for the server socket, the keepalive setting
// will be inherited by all accepted client sockets.
socket2::SockRef::from(&listener).set_keepalive(true)?;

let connections = tokio_util::task::task_tracker::TaskTracker::new();

while let Some(accept_result) =
run_until_cancelled(listener.accept(), &cancellation_token).await
run_until_cancelled(acceptor.accept(), &cancellation_token).await
{
let (socket, peer_addr) = accept_result?;

Expand All @@ -172,10 +169,6 @@ async fn task_main(

connections.spawn(
async move {
socket
.set_nodelay(true)
.context("failed to set socket option")?;

info!(%peer_addr, "serving");
let ctx = RequestContext::new(
session_id,
Expand All @@ -197,7 +190,7 @@ async fn task_main(
}

connections.close();
drop(listener);
drop(acceptor);

connections.wait().await;

Expand Down
97 changes: 97 additions & 0 deletions proxy/src/conn.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
use std::future::Future;
use std::io;
use std::net::SocketAddr;
use std::task::{Context, Poll};

use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::{TcpListener, TcpStream};

pub trait Acceptor {
type Connection: AsyncRead + AsyncWrite + Send + Unpin + 'static;
type Error: std::error::Error + Send + Sync + 'static;

#[inline]
fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
let _ = cx;
Poll::Ready(Ok(()))
}

fn accept(
&self,
) -> impl Future<Output = Result<(Self::Connection, SocketAddr), Self::Error>> + Send;
}

pub trait Connector {
type Connection: AsyncRead + AsyncWrite + Send + Unpin + 'static;
type Error: std::error::Error + Send + Sync + 'static;

#[inline]
fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
let _ = cx;
Poll::Ready(Ok(()))
}

fn connect(
&self,
addr: SocketAddr,
) -> impl Future<Output = Result<Self::Connection, Self::Error>> + Send;
}

pub struct TokioTcpAcceptor {
listener: TcpListener,
tcp_nodelay: Option<bool>,
tcp_keepalive: Option<bool>,
}

impl TokioTcpAcceptor {
pub async fn bind(addr: SocketAddr) -> io::Result<Self> {
let listener = TcpListener::bind(addr).await?;
// When set for the server socket, the keepalive setting
// will be inherited by all accepted client sockets.
socket2::SockRef::from(&listener).set_keepalive(true)?;
Ok(Self {
listener,
tcp_nodelay: Some(true),
tcp_keepalive: None,
})
}
}

impl Acceptor for TokioTcpAcceptor {
type Connection = TcpStream;
type Error = io::Error;

fn accept(&self) -> impl Future<Output = Result<(Self::Connection, SocketAddr), Self::Error>> {
async move {
let (stream, addr) = self.listener.accept().await?;

let socket = socket2::SockRef::from(&stream);
if let Some(nodelay) = self.tcp_nodelay {
socket.set_nodelay(nodelay)?;
}
if let Some(keepalive) = self.tcp_keepalive {
socket.set_keepalive(keepalive)?;
}

Ok((stream, addr))
}
}
}

pub struct TokioTcpConnector;

impl Connector for TokioTcpConnector {
type Connection = TcpStream;
type Error = io::Error;

fn connect(
&self,
addr: SocketAddr,
) -> impl Future<Output = Result<Self::Connection, Self::Error>> {
async move {
let socket = TcpStream::connect(addr).await?;
socket.set_nodelay(true)?;
Ok(socket)
}
}
}
9 changes: 3 additions & 6 deletions proxy/src/control_plane/mgmt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use tokio::net::{TcpListener, TcpStream};
use tokio_util::sync::CancellationToken;
use tracing::{error, info, info_span, Instrument};

use crate::conn::Acceptor;
use crate::control_plane::messages::{DatabaseInfo, KickSession};
use crate::waiters::{self, Waiter, Waiters};

Expand All @@ -26,19 +27,15 @@ pub(crate) fn notify(psql_session_id: &str, msg: ComputeReady) -> Result<(), wai

/// Management API listener task.
/// It spawns management response handlers needed for the console redirect auth flow.
pub async fn task_main(listener: TcpListener) -> anyhow::Result<Infallible> {
pub async fn task_main<A: Acceptor>(acceptor: A) -> anyhow::Result<Infallible> {
scopeguard::defer! {
info!("mgmt has shut down");
}

loop {
let (socket, peer_addr) = listener.accept().await?;
let (socket, peer_addr) = acceptor.accept().await?;
info!("accepted connection from {peer_addr}");

socket
.set_nodelay(true)
.context("failed to set client socket option")?;

let span = info_span!("mgmt", peer = %peer_addr);

tokio::task::spawn(
Expand Down
1 change: 1 addition & 0 deletions proxy/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ pub mod cancellation;
pub mod compute;
pub mod compute_ctl;
pub mod config;
pub mod conn;
pub mod console_redirect_proxy;
pub mod context;
pub mod control_plane;
Expand Down

0 comments on commit 9815d3e

Please sign in to comment.