Skip to content

Commit

Permalink
impr(proxy): Decouple ip_allowlist from the CancelClosure
Browse files Browse the repository at this point in the history
This PR removes the direct dependency of the IP allowlist from CancelClosure,
allowing for more scalable and flexible IP restrictions
and enabling the future use of Redis-based CancelMap storage.

Changes:
 - Introduce a new BackendAuth async trait that retrieves the IP allowlist
   through existing authentication methods;
 - Improve cancellation error handling by instrumtent() async
   cancel_sesion() rather than dropping it.
 - Set and store IP allowlist for SCRAM Proxy to consistently perform
   ip allowance check
  • Loading branch information
awarus committed Dec 19, 2024
1 parent 65042cb commit 335063b
Show file tree
Hide file tree
Showing 12 changed files with 535 additions and 67 deletions.
82 changes: 77 additions & 5 deletions proxy/src/auth/backend/console_redirect.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,22 @@
use async_trait::async_trait;
use postgres_client::config::SslMode;
use pq_proto::BeMessage as Be;
use std::sync::Arc;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{info, info_span};

use super::ComputeCredentialKeys;
use crate::auth::backend::{BackendAuth, ComputeUserInfo, GetAuthInfoError, MaybeOwned};
use crate::auth::IpPattern;
use crate::cache::Cached;
use crate::config::AuthenticationConfig;
use crate::context::RequestContext;
use crate::control_plane::{self, CachedNodeInfo, NodeInfo};
use crate::control_plane::{
self, client::cplane_proxy_v1, CachedAllowedIps, CachedNodeInfo, CachedRoleSecret, NodeInfo,
};
use crate::error::{ReportableError, UserFacingError};
use crate::metrics::{CacheOutcome, Metrics};
use crate::proxy::connect_compute::ComputeConnectBackend;
use crate::stream::PqStream;
use crate::{auth, compute, waiters};
Expand All @@ -29,8 +34,9 @@ pub(crate) enum ConsoleRedirectError {
}

#[derive(Debug)]
pub struct ConsoleRedirectBackend {
pub struct ConsoleRedirectBackend<'a> {
console_uri: reqwest::Url,
api: MaybeOwned<'a, cplane_proxy_v1::AuthControlPlaneClient>,
}

impl UserFacingError for ConsoleRedirectError {
Expand Down Expand Up @@ -71,9 +77,75 @@ pub(crate) fn new_psql_session_id() -> String {
hex::encode(rand::random::<[u8; 8]>())
}

impl ConsoleRedirectBackend {
pub fn new(console_uri: reqwest::Url) -> Self {
Self { console_uri }
#[async_trait]
impl BackendAuth for ConsoleRedirectBackend<'_> {
async fn auth_ips(
&self,
user_info: &ComputeUserInfo,
session_id: &uuid::Uuid,
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), GetAuthInfoError> {
let normalized_ep = &user_info.endpoint.normalize();
if let Some(allowed_ips) = self.api.caches.project_info.get_allowed_ips(normalized_ep) {
Metrics::get()
.proxy
.allowed_ips_cache_misses
.inc(CacheOutcome::Hit);
return Ok((allowed_ips, None));
}

Metrics::get()
.proxy
.allowed_ips_cache_misses
.inc(CacheOutcome::Miss);

let auth_info = self.api.do_get_auth_req(user_info, session_id).await?;

let allowed_ips = Arc::new(auth_info.allowed_ips);
let user = &user_info.user;
if let Some(project_id) = auth_info.project_id {
let normalized_ep_int = normalized_ep.into();
self.api.caches.project_info.insert_role_secret(
project_id,
normalized_ep_int,
user.into(),
auth_info.secret.clone(),
);
self.api.caches.project_info.insert_allowed_ips(
project_id,
normalized_ep_int,
allowed_ips.clone(),
);
}
Ok((
Cached::new_uncached(allowed_ips),
Some(Cached::new_uncached(auth_info.secret)),
))
}

async fn get_allowed_ips(
&self,
user_info: &ComputeUserInfo,
session_id: &uuid::Uuid,
) -> auth::Result<Vec<auth::IpPattern>> {
self.auth_ips(user_info, session_id)
.await
.map(|(ips, _)| ips.as_ref().clone())
.map_err(|e| e.into())
}
}

impl std::fmt::Debug for MaybeOwned<'_, cplane_proxy_v1::AuthControlPlaneClient> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "MaybeOwned")
}
}

impl ConsoleRedirectBackend<'static> {
pub fn new(console_uri: reqwest::Url, api: cplane_proxy_v1::AuthControlPlaneClient) -> Self {
Self {
console_uri,
api: auth::backend::MaybeOwned::Owned(api),
}
}

pub(crate) async fn authenticate(
Expand Down
75 changes: 66 additions & 9 deletions proxy/src/auth/backend/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{debug, info, warn};

use crate::auth::credentials::check_peer_addr_is_in_list;
use crate::auth::{self, validate_password_and_exchange, AuthError, ComputeUserInfoMaybeEndpoint};
use crate::auth::{
self, validate_password_and_exchange, AuthError, ComputeUserInfoMaybeEndpoint, IpPattern,
};
use crate::cache::Cached;
use crate::config::AuthenticationConfig;
use crate::context::RequestContext;
Expand Down Expand Up @@ -131,7 +133,7 @@ pub(crate) struct ComputeUserInfoNoEndpoint {
pub(crate) options: NeonOptions,
}

#[derive(Debug, Clone)]
#[derive(Debug, Clone, Default)]
pub(crate) struct ComputeUserInfo {
pub(crate) endpoint: EndpointId,
pub(crate) user: RoleName,
Expand Down Expand Up @@ -244,6 +246,21 @@ impl AuthenticationConfig {
}
}

#[async_trait::async_trait]
pub(crate) trait BackendAuth {
async fn auth_ips(
&self,
user_info: &ComputeUserInfo,
session_id: &uuid::Uuid,
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), GetAuthInfoError>;

async fn get_allowed_ips(
&self,
user_info: &ComputeUserInfo,
session_id: &uuid::Uuid,
) -> auth::Result<Vec<auth::IpPattern>>;
}

/// True to its name, this function encapsulates our current auth trade-offs.
/// Here, we choose the appropriate auth flow based on circumstances.
///
Expand All @@ -256,7 +273,7 @@ async fn auth_quirks(
allow_cleartext: bool,
config: &'static AuthenticationConfig,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> auth::Result<ComputeCredentials> {
) -> auth::Result<(ComputeCredentials, Option<Vec<IpPattern>>)> {
// If there's no project so far, that entails that client doesn't
// support SNI or other means of passing the endpoint (project) name.
// We now expect to see a very specific payload in the place of password.
Expand Down Expand Up @@ -315,7 +332,7 @@ async fn auth_quirks(
)
.await
{
Ok(keys) => Ok(keys),
Ok(keys) => Ok((keys, Some(allowed_ips.as_ref().clone()))),
Err(e) => {
if e.is_password_failed() {
// The password could have been changed, so we invalidate the cache.
Expand Down Expand Up @@ -385,7 +402,7 @@ impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint> {
allow_cleartext: bool,
config: &'static AuthenticationConfig,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> auth::Result<Backend<'a, ComputeCredentials>> {
) -> auth::Result<(Backend<'a, ComputeCredentials>, Option<Vec<IpPattern>>)> {
let res = match self {
Self::ControlPlane(api, user_info) => {
debug!(
Expand All @@ -394,7 +411,7 @@ impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint> {
"performing authentication using the console"
);

let credentials = auth_quirks(
let (credentials, ip_allowlist) = auth_quirks(
ctx,
&*api,
user_info,
Expand All @@ -404,7 +421,7 @@ impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint> {
endpoint_rate_limiter,
)
.await?;
Backend::ControlPlane(api, credentials)
Ok((Backend::ControlPlane(api, credentials), ip_allowlist))
}
Self::Local(_) => {
return Err(auth::AuthError::bad_auth_method("invalid for local proxy"))
Expand All @@ -413,7 +430,7 @@ impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint> {

// TODO: replace with some metric
info!("user successfully authenticated");
Ok(res)
res
}
}

Expand Down Expand Up @@ -441,6 +458,32 @@ impl Backend<'_, ComputeUserInfo> {
}
}

//ComputeUserInfoMaybeEndpoint
#[async_trait::async_trait]
impl BackendAuth for Backend<'_, ()> {
async fn auth_ips(
&self,
user_info: &ComputeUserInfo,
session_id: &uuid::Uuid,
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), GetAuthInfoError> {
match self {
Self::ControlPlane(api, ()) => api.get_allowed_ips(user_info, session_id).await,
Self::Local(_) => Ok((Cached::new_uncached(Arc::new(vec![])), None)),
}
}

async fn get_allowed_ips(
&self,
user_info: &ComputeUserInfo,
session_id: &uuid::Uuid,
) -> auth::Result<Vec<auth::IpPattern>> {
self.auth_ips(user_info, session_id)
.await
.map(|(ips, _)| ips.as_ref().clone())
.map_err(|e| e.into())
}
}

#[async_trait::async_trait]
impl ComputeConnectBackend for Backend<'_, ComputeCredentials> {
async fn wake_compute(
Expand Down Expand Up @@ -505,6 +548,20 @@ mod tests {
Ok(CachedRoleSecret::new_uncached(Some(self.secret.clone())))
}

async fn get_allowed_ips(
&self,
_user_info: &super::ComputeUserInfo,
_session_id: &uuid::Uuid,
) -> Result<
(CachedAllowedIps, Option<CachedRoleSecret>),
control_plane::errors::GetAuthInfoError,
> {
Ok((
CachedAllowedIps::new_uncached(Arc::new(self.ips.clone())),
Some(CachedRoleSecret::new_uncached(Some(self.secret.clone()))),
))
}

async fn get_allowed_ips_and_secret(
&self,
_ctx: &RequestContext,
Expand Down Expand Up @@ -786,7 +843,7 @@ mod tests {
.await
.unwrap();

assert_eq!(creds.info.endpoint, "my-endpoint");
assert_eq!(creds.0.info.endpoint, "my-endpoint");

handle.await.unwrap();
}
Expand Down
34 changes: 31 additions & 3 deletions proxy/src/bin/proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -657,7 +657,9 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
/// auth::Backend is created at proxy startup, and lives forever.
fn build_auth_backend(
args: &ProxyCliArgs,
) -> anyhow::Result<Either<&'static auth::Backend<'static, ()>, &'static ConsoleRedirectBackend>> {
) -> anyhow::Result<
Either<&'static auth::Backend<'static, ()>, &'static ConsoleRedirectBackend<'static>>,
> {
match &args.auth_backend {
AuthBackendType::ControlPlaneV1 => {
let wake_compute_cache_config: CacheOptions = args.wake_compute_cache.parse()?;
Expand Down Expand Up @@ -735,9 +737,35 @@ fn build_auth_backend(
}

AuthBackendType::ConsoleRedirect => {
let url = args.uri.parse()?;
let backend = ConsoleRedirectBackend::new(url);
let wake_compute_cache_config: CacheOptions = args.wake_compute_cache.parse()?;
let project_info_cache_config: ProjectInfoCacheOptions =
args.project_info_cache.parse()?;
let endpoint_cache_config: config::EndpointCacheConfig =
args.endpoint_cache_config.parse()?;

info!("Using NodeInfoCache (wake_compute) with options={wake_compute_cache_config:?}");
info!(
"Using AllowedIpsCache (wake_compute) with options={project_info_cache_config:?}"
);
info!("Using EndpointCacheConfig with options={endpoint_cache_config:?}");

let caches = Box::leak(Box::new(control_plane::caches::ApiCaches::new(
wake_compute_cache_config,
project_info_cache_config,
endpoint_cache_config,
)));

let url = args.uri.clone().parse()?;
let ep_url: proxy::url::ApiUrl = args.auth_endpoint.parse()?;
let endpoint = http::Endpoint::new(ep_url, http::new_client());

let api = control_plane::client::cplane_proxy_v1::AuthControlPlaneClient::new(
endpoint,
args.control_plane_token.clone(),
caches,
);

let backend = ConsoleRedirectBackend::new(url, api);
let config = Box::leak(Box::new(backend));

Ok(Either::Right(config))
Expand Down
Loading

0 comments on commit 335063b

Please sign in to comment.