From f78aa64be9afa2ab96c03f384b9fe78d71fd4942 Mon Sep 17 00:00:00 2001 From: Oliver Gould Date: Thu, 10 Oct 2024 14:18:31 +0000 Subject: [PATCH 1/4] tls: Avoid InsertParam parameter. We don't actually use InsertParam all that much--only in the TLS server (which is obviously why it was included here). This change removes the InsertParam in favor of using a tuple, generally reducing boilerplate. It turns out that the TLS stack already has a map_target to handle turning the tuple-target into a Tls type, so it shouldn't be needed. --- linkerd/app/outbound/src/tls.rs | 9 --------- linkerd/app/outbound/src/tls/logical/tests.rs | 9 --------- linkerd/tls/src/detect_sni.rs | 17 ++++++----------- 3 files changed, 6 insertions(+), 29 deletions(-) diff --git a/linkerd/app/outbound/src/tls.rs b/linkerd/app/outbound/src/tls.rs index 208f1cce6b..d0ecd468c5 100644 --- a/linkerd/app/outbound/src/tls.rs +++ b/linkerd/app/outbound/src/tls.rs @@ -135,15 +135,6 @@ impl svc::ExtractParam for DetectParams { } } -impl svc::InsertParam for DetectParams { - type Target = (ServerName, T); - - #[inline] - fn insert_param(&self, sni: ServerName, target: T) -> Self::Target { - (sni, target) - } -} - // === impl TlsMetrics === impl TlsMetrics { diff --git a/linkerd/app/outbound/src/tls/logical/tests.rs b/linkerd/app/outbound/src/tls/logical/tests.rs index f9bf76b142..34f22a90f9 100644 --- a/linkerd/app/outbound/src/tls/logical/tests.rs +++ b/linkerd/app/outbound/src/tls/logical/tests.rs @@ -128,15 +128,6 @@ impl svc::ExtractParam for DetectParams { } } -impl svc::InsertParam for DetectParams { - type Target = (tls::ServerName, T); - - #[inline] - fn insert_param(&self, sni: tls::ServerName, target: T) -> Self::Target { - (sni, target) - } -} - fn spawn_io( client_hello: Vec, ) -> ( diff --git a/linkerd/tls/src/detect_sni.rs b/linkerd/tls/src/detect_sni.rs index 45747d63ed..41e66a0069 100644 --- a/linkerd/tls/src/detect_sni.rs +++ b/linkerd/tls/src/detect_sni.rs @@ -4,7 +4,7 @@ use crate::{ }; use linkerd_error::Error; use linkerd_io as io; -use linkerd_stack::{layer, ExtractParam, InsertParam, NewService, Service, ServiceExt}; +use linkerd_stack::{layer, ExtractParam, NewService, Service, ServiceExt}; use std::{ future::Future, pin::Pin, @@ -29,11 +29,10 @@ pub struct NewDetectSni { } #[derive(Clone, Debug)] -pub struct DetectSni { +pub struct DetectSni { target: T, inner: N, timeout: Timeout, - params: P, } impl NewDetectSni { @@ -54,7 +53,7 @@ where P: ExtractParam + Clone, N: Clone, { - type Service = DetectSni; + type Service = DetectSni; fn new_service(&self, target: T) -> Self::Service { let timeout = self.params.extract_param(&target); @@ -62,18 +61,15 @@ where target, timeout, inner: self.inner.clone(), - params: self.params.clone(), } } } -impl Service for DetectSni +impl Service for DetectSni where T: Clone + Send + Sync + 'static, - P: InsertParam + Clone + Send + Sync + 'static, - P::Target: Send + 'static, I: io::AsyncRead + io::Peek + io::AsyncWrite + Send + Sync + Unpin + 'static, - N: NewService + Clone + Send + 'static, + N: NewService<(ServerName, T), Service = S> + Clone + Send + 'static, S: Service> + Send, S::Error: Into, S::Future: Send, @@ -90,7 +86,6 @@ where fn call(&mut self, io: I) -> Self::Future { let target = self.target.clone(); let new_accept = self.inner.clone(); - let params = self.params.clone(); // Detect the SNI from a ClientHello (or timeout). let Timeout(timeout) = self.timeout; @@ -100,7 +95,7 @@ where let sni = sni.ok_or(NoSniFoundError)?; debug!("detected SNI: {:?}", sni); - let svc = new_accept.new_service(params.insert_param(sni, target)); + let svc = new_accept.new_service((sni, target)); svc.oneshot(io).await.map_err(Into::into) }) } From 67c0e849f8368473f722c6cdb0bffaecca5c8635 Mon Sep 17 00:00:00 2001 From: Oliver Gould Date: Thu, 10 Oct 2024 14:49:32 +0000 Subject: [PATCH 2/4] tls: Remove ExtractParam from detect_sni Similarly, we don't actually care about extracting a timeout from the target. Using an ExtractParam causes needless boilerplate. This change updates the stack module to simply take a timeout parameter at construction time. --- linkerd/app/outbound/src/tls.rs | 18 +------- linkerd/app/outbound/src/tls/logical/tests.rs | 13 ------ .../outbound/src/tls/logical/tests/basic.rs | 2 +- linkerd/tls/src/detect_sni.rs | 44 ++++++++++--------- linkerd/tls/src/lib.rs | 2 +- 5 files changed, 27 insertions(+), 52 deletions(-) diff --git a/linkerd/app/outbound/src/tls.rs b/linkerd/app/outbound/src/tls.rs index d0ecd468c5..0ee3c94a6f 100644 --- a/linkerd/app/outbound/src/tls.rs +++ b/linkerd/app/outbound/src/tls.rs @@ -7,7 +7,7 @@ use linkerd_app_core::{ core::Resolve, }, svc, - tls::{self, detect_sni::NewDetectSni, server::Timeout, ServerName}, + tls::{NewDetectSni, ServerName}, transport::addrs::*, Error, }; @@ -25,9 +25,6 @@ struct Tls { parent: T, } -#[derive(Clone)] -struct DetectParams(Timeout); - pub fn spawn_routes( mut route_rx: watch::Receiver, init: Routes, @@ -97,13 +94,11 @@ impl Outbound { .push_tls_concrete(resolve) .push_tls_logical() .map_stack(|config, _rt, stk| { - let detect_timeout = Timeout(config.proxy.detect_protocol_timeout); - stk.push_new_idle_cached(config.discovery_idle_timeout) // Use a dedicated target type to configure parameters for // the TLS stack. It also helps narrow the cache key. .push_map_target(|(sni, parent): (ServerName, T)| Tls { sni, parent }) - .push(NewDetectSni::layer(DetectParams(detect_timeout))) + .push(NewDetectSni::layer(config.proxy.detect_protocol_timeout)) .arc_new_clone_tcp() }) } @@ -126,15 +121,6 @@ where } } -// === impl DetectParams === - -impl svc::ExtractParam for DetectParams { - #[inline] - fn extract_param(&self, _: &T) -> tls::server::Timeout { - self.0 - } -} - // === impl TlsMetrics === impl TlsMetrics { diff --git a/linkerd/app/outbound/src/tls/logical/tests.rs b/linkerd/app/outbound/src/tls/logical/tests.rs index 34f22a90f9..f12aa283e5 100644 --- a/linkerd/app/outbound/src/tls/logical/tests.rs +++ b/linkerd/app/outbound/src/tls/logical/tests.rs @@ -3,7 +3,6 @@ use crate::test_util::*; use linkerd_app_core::{ io, svc::{self, NewService}, - tls, transport::addrs::*, Result, }; @@ -42,9 +41,6 @@ struct ConnectTcp { srvs: Arc>>, } -#[derive(Clone)] -struct DetectParams; - // === impl MockServer === impl MockServer { @@ -119,15 +115,6 @@ impl>> svc::Service for ConnectTcp { } } -// === impl DetectParams === - -impl svc::ExtractParam for DetectParams { - #[inline] - fn extract_param(&self, _: &T) -> tls::server::Timeout { - tls::server::Timeout(Duration::from_secs(1)) - } -} - fn spawn_io( client_hello: Vec, ) -> ( diff --git a/linkerd/app/outbound/src/tls/logical/tests/basic.rs b/linkerd/app/outbound/src/tls/logical/tests/basic.rs index ad8e0364e7..4d14e34af2 100644 --- a/linkerd/app/outbound/src/tls/logical/tests/basic.rs +++ b/linkerd/app/outbound/src/tls/logical/tests/basic.rs @@ -35,7 +35,7 @@ async fn routes() { .map_stack(|config, _rt, stk| { stk.push_new_idle_cached(config.discovery_idle_timeout) .push_map_target(|(sni, parent): (ServerName, _)| Tls { sni, parent }) - .push(NewDetectSni::layer(DetectParams)) + .push(NewDetectSni::layer(Duration::from_secs(1))) .arc_new_clone_tcp() }) .into_inner(); diff --git a/linkerd/tls/src/detect_sni.rs b/linkerd/tls/src/detect_sni.rs index 41e66a0069..ceabbb25cb 100644 --- a/linkerd/tls/src/detect_sni.rs +++ b/linkerd/tls/src/detect_sni.rs @@ -1,10 +1,10 @@ use crate::{ - server::{detect_sni, DetectIo, Timeout}, + server::{detect_sni, DetectIo}, ServerName, }; use linkerd_error::Error; use linkerd_io as io; -use linkerd_stack::{layer, ExtractParam, NewService, Service, ServiceExt}; +use linkerd_stack::{layer, NewService, Service, ServiceExt}; use std::{ future::Future, pin::Pin, @@ -23,44 +23,47 @@ pub struct SniDetectionTimeoutError; pub struct NoSniFoundError; #[derive(Clone, Debug)] -pub struct NewDetectSni { - params: P, +pub struct NewDetectSni { inner: N, + timeout: time::Duration, } #[derive(Clone, Debug)] pub struct DetectSni { target: T, inner: N, - timeout: Timeout, + timeout: time::Duration, } -impl NewDetectSni { - pub fn new(params: P, inner: N) -> Self { - Self { inner, params } +impl NewDetectSni { + fn new(timeout: time::Duration, inner: N) -> Self { + Self { inner, timeout } } - pub fn layer(params: P) -> impl layer::Layer + Clone - where - P: Clone, - { - layer::mk(move |inner| Self::new(params.clone(), inner)) + pub fn layer(timeout: time::Duration) -> impl layer::Layer + Clone { + layer::mk(move |inner| Self::new(timeout, inner)) } } -impl NewService for NewDetectSni +impl NewService for NewDetectSni where - P: ExtractParam + Clone, N: Clone, { type Service = DetectSni; fn new_service(&self, target: T) -> Self::Service { - let timeout = self.params.extract_param(&target); - DetectSni { + DetectSni::new(self.timeout, target, self.inner.clone()) + } +} + +// === impl DetectSni === + +impl DetectSni { + fn new(timeout: time::Duration, target: T, inner: N) -> Self { + Self { target, + inner, timeout, - inner: self.inner.clone(), } } } @@ -88,13 +91,12 @@ where let new_accept = self.inner.clone(); // Detect the SNI from a ClientHello (or timeout). - let Timeout(timeout) = self.timeout; - let detect = time::timeout(timeout, detect_sni(io)); + let detect = time::timeout(self.timeout, detect_sni(io)); Box::pin(async move { let (sni, io) = detect.await.map_err(|_| SniDetectionTimeoutError)??; let sni = sni.ok_or(NoSniFoundError)?; - debug!("detected SNI: {:?}", sni); + debug!(?sni, "Detected TLS"); let svc = new_accept.new_service((sni, target)); svc.oneshot(io).await.map_err(Into::into) }) diff --git a/linkerd/tls/src/lib.rs b/linkerd/tls/src/lib.rs index 308141e23f..7f61c069ae 100755 --- a/linkerd/tls/src/lib.rs +++ b/linkerd/tls/src/lib.rs @@ -7,7 +7,7 @@ pub mod server; pub use self::{ client::{Client, ClientTls, ConditionalClientTls, ConnectMeta, NoClientTls, ServerId}, - detect_sni::NewDetectSni, + detect_sni::{DetectSni, NewDetectSni}, server::{ClientId, ConditionalServerTls, NewDetectTls, NoServerTls, ServerTls}, }; From f6202d661163d7fdf5692b3a30f2a93e1e3496b8 Mon Sep 17 00:00:00 2001 From: Oliver Gould Date: Thu, 10 Oct 2024 15:03:42 +0000 Subject: [PATCH 3/4] tls: Make the detect_tls module private We now only need to export the NewDetectSni type. The module reexport is not necessary. --- linkerd/tls/src/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/linkerd/tls/src/lib.rs b/linkerd/tls/src/lib.rs index 7f61c069ae..e1c148b353 100755 --- a/linkerd/tls/src/lib.rs +++ b/linkerd/tls/src/lib.rs @@ -2,7 +2,7 @@ #![forbid(unsafe_code)] pub mod client; -pub mod detect_sni; +mod detect_sni; pub mod server; pub use self::{ From ae88d2eecb7570b12161e4983c7890f56a7cdb88 Mon Sep 17 00:00:00 2001 From: Oliver Gould Date: Thu, 10 Oct 2024 15:20:17 +0000 Subject: [PATCH 4/4] tls: Reorganize NewDetectRequiredSni under the server module Because the DetectTls and DetectSni types are so similar -- and implemented in the context of a server inspecting a provided connection (and not a client establishing a TLS connection), this change reorganizes the module: * The DetectSni types are renamed to DetectRequiredSni to better reflect their purpose and difference from the DetectTls type. * The detect_sni module is renamed and moved to server::required_sni. This module is private and the relevant types are reexported from the server module. --- linkerd/app/outbound/src/tls.rs | 6 ++-- .../outbound/src/tls/logical/tests/basic.rs | 4 +-- linkerd/tls/src/lib.rs | 7 ++-- linkerd/tls/src/server.rs | 6 +++- .../{detect_sni.rs => server/required_sni.rs} | 34 +++++++++++++------ 5 files changed, 38 insertions(+), 19 deletions(-) rename linkerd/tls/src/{detect_sni.rs => server/required_sni.rs} (66%) diff --git a/linkerd/app/outbound/src/tls.rs b/linkerd/app/outbound/src/tls.rs index 0ee3c94a6f..4787665c08 100644 --- a/linkerd/app/outbound/src/tls.rs +++ b/linkerd/app/outbound/src/tls.rs @@ -7,7 +7,7 @@ use linkerd_app_core::{ core::Resolve, }, svc, - tls::{NewDetectSni, ServerName}, + tls::{NewDetectRequiredSni, ServerName}, transport::addrs::*, Error, }; @@ -98,7 +98,9 @@ impl Outbound { // Use a dedicated target type to configure parameters for // the TLS stack. It also helps narrow the cache key. .push_map_target(|(sni, parent): (ServerName, T)| Tls { sni, parent }) - .push(NewDetectSni::layer(config.proxy.detect_protocol_timeout)) + .push(NewDetectRequiredSni::layer( + config.proxy.detect_protocol_timeout, + )) .arc_new_clone_tcp() }) } diff --git a/linkerd/app/outbound/src/tls/logical/tests/basic.rs b/linkerd/app/outbound/src/tls/logical/tests/basic.rs index 4d14e34af2..1eb1150e7c 100644 --- a/linkerd/app/outbound/src/tls/logical/tests/basic.rs +++ b/linkerd/app/outbound/src/tls/logical/tests/basic.rs @@ -2,7 +2,7 @@ use super::*; use crate::tls::Tls; use linkerd_app_core::{ svc::ServiceExt, - tls::{NewDetectSni, ServerName}, + tls::{NewDetectRequiredSni, ServerName}, trace, NameAddr, }; use linkerd_proxy_client_policy as client_policy; @@ -35,7 +35,7 @@ async fn routes() { .map_stack(|config, _rt, stk| { stk.push_new_idle_cached(config.discovery_idle_timeout) .push_map_target(|(sni, parent): (ServerName, _)| Tls { sni, parent }) - .push(NewDetectSni::layer(Duration::from_secs(1))) + .push(NewDetectRequiredSni::layer(Duration::from_secs(1))) .arc_new_clone_tcp() }) .into_inner(); diff --git a/linkerd/tls/src/lib.rs b/linkerd/tls/src/lib.rs index e1c148b353..4d6b0f6136 100755 --- a/linkerd/tls/src/lib.rs +++ b/linkerd/tls/src/lib.rs @@ -2,13 +2,14 @@ #![forbid(unsafe_code)] pub mod client; -mod detect_sni; pub mod server; pub use self::{ client::{Client, ClientTls, ConditionalClientTls, ConnectMeta, NoClientTls, ServerId}, - detect_sni::{DetectSni, NewDetectSni}, - server::{ClientId, ConditionalServerTls, NewDetectTls, NoServerTls, ServerTls}, + server::{ + ClientId, ConditionalServerTls, NewDetectRequiredSni, NewDetectTls, NoServerTls, + NoSniFoundError, ServerTls, SniDetectionTimeoutError, + }, }; use linkerd_dns_name as dns; diff --git a/linkerd/tls/src/server.rs b/linkerd/tls/src/server.rs index 1c85c92ee6..fd61e10d97 100644 --- a/linkerd/tls/src/server.rs +++ b/linkerd/tls/src/server.rs @@ -1,4 +1,5 @@ mod client_hello; +mod required_sni; use crate::{NegotiatedProtocol, ServerName}; use bytes::BytesMut; @@ -18,6 +19,8 @@ use thiserror::Error; use tokio::time::{self, Duration}; use tracing::{debug, trace, warn}; +pub use self::required_sni::{NewDetectRequiredSni, NoSniFoundError, SniDetectionTimeoutError}; + /// Describes the authenticated identity of a remote client. #[derive(Clone, Debug, Eq, PartialEq, Hash)] pub struct ClientId(pub id::Id); @@ -65,6 +68,7 @@ pub struct NewDetectTls { _local_identity: std::marker::PhantomData L>, } +/// A param type used to indicate the timeout after which detection should fail. #[derive(Copy, Clone, Debug)] pub struct Timeout(pub Duration); @@ -192,7 +196,7 @@ where } /// Peek or buffer the provided stream to determine an SNI value. -pub(crate) async fn detect_sni(mut io: I) -> io::Result<(Option, DetectIo)> +async fn detect_sni(mut io: I) -> io::Result<(Option, DetectIo)> where I: io::Peek + io::AsyncRead + io::AsyncWrite + Send + Sync + Unpin, { diff --git a/linkerd/tls/src/detect_sni.rs b/linkerd/tls/src/server/required_sni.rs similarity index 66% rename from linkerd/tls/src/detect_sni.rs rename to linkerd/tls/src/server/required_sni.rs index ceabbb25cb..a20b5e8c31 100644 --- a/linkerd/tls/src/detect_sni.rs +++ b/linkerd/tls/src/server/required_sni.rs @@ -22,20 +22,32 @@ pub struct SniDetectionTimeoutError; #[error("Could not find SNI")] pub struct NoSniFoundError; +/// A NewService that instruments an inner stack with knowledge of the +/// connection's TLS ServerName (i.e. from an SNI header). +/// +/// This differs from the parent module's NewDetectTls in a a few ways: +/// +/// - It requires that all connections have an SNI. +/// - It assumes that these connections may not be terminated locally, so there +/// is no concept of a local server name. +/// - There are no special affordances for mutually authenticated TLS, so we +/// make no attempt to detect the client's identity. +/// - The detection timeout is fixed and cannot vary per target (for +/// convenience, to reduce needless boilerplate). #[derive(Clone, Debug)] -pub struct NewDetectSni { +pub struct NewDetectRequiredSni { inner: N, timeout: time::Duration, } #[derive(Clone, Debug)] -pub struct DetectSni { +pub struct DetectRequiredSni { target: T, inner: N, timeout: time::Duration, } -impl NewDetectSni { +impl NewDetectRequiredSni { fn new(timeout: time::Duration, inner: N) -> Self { Self { inner, timeout } } @@ -45,20 +57,20 @@ impl NewDetectSni { } } -impl NewService for NewDetectSni +impl NewService for NewDetectRequiredSni where N: Clone, { - type Service = DetectSni; + type Service = DetectRequiredSni; fn new_service(&self, target: T) -> Self::Service { - DetectSni::new(self.timeout, target, self.inner.clone()) + DetectRequiredSni::new(self.timeout, target, self.inner.clone()) } } // === impl DetectSni === -impl DetectSni { +impl DetectRequiredSni { fn new(timeout: time::Duration, target: T, inner: N) -> Self { Self { target, @@ -68,7 +80,7 @@ impl DetectSni { } } -impl Service for DetectSni +impl Service for DetectRequiredSni where T: Clone + Send + Sync + 'static, I: io::AsyncRead + io::Peek + io::AsyncWrite + Send + Sync + Unpin + 'static, @@ -93,10 +105,10 @@ where // Detect the SNI from a ClientHello (or timeout). let detect = time::timeout(self.timeout, detect_sni(io)); Box::pin(async move { - let (sni, io) = detect.await.map_err(|_| SniDetectionTimeoutError)??; - let sni = sni.ok_or(NoSniFoundError)?; - + let (res, io) = detect.await.map_err(|_| SniDetectionTimeoutError)??; + let sni = res.ok_or(NoSniFoundError)?; debug!(?sni, "Detected TLS"); + let svc = new_accept.new_service((sni, target)); svc.oneshot(io).await.map_err(Into::into) })