From a8c7f9dd648648705137d4cb7088cd1e187dd8f4 Mon Sep 17 00:00:00 2001 From: SabrinaJewson Date: Sat, 30 Jul 2022 18:48:24 +0100 Subject: [PATCH] Fix ALPN --- Cargo.toml | 2 +- src/stream.rs | 27 +++++++++++++++++++++++---- 2 files changed, 24 insertions(+), 5 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 9f8601f..40b2406 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,7 +23,7 @@ futures-util = { version = "0.3", default-features = false } bytes = "1.0" hyper-tls = { version = "0.5.0", optional = true } tokio-native-tls = { version = "0.3.0", optional = true } -native-tls = { version = "0.2", optional = true } +native-tls = { version = "0.2", optional = true, features = ["alpn"] } openssl = { version = "0.10", optional = true } tokio-openssl = { version = "0.6", optional = true } tokio-rustls = { version = "0.22", optional = true } diff --git a/src/stream.rs b/src/stream.rs index 4f45be6..b06d515 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -95,18 +95,37 @@ impl AsyncWrite for ProxyStream { impl Connection for ProxyStream { fn connected(&self) -> Connected { - match self { + let mut is_h2 = false; + + let connected = match self { ProxyStream::NoProxy(s) => s.connected(), ProxyStream::Regular(s) => s.connected().proxy(true), #[cfg(feature = "tls")] - ProxyStream::Secured(s) => s.get_ref().get_ref().get_ref().connected().proxy(true), + ProxyStream::Secured(s) => { + let stream = s.get_ref(); + is_h2 = stream.negotiated_alpn().ok().flatten().as_deref() == Some(b"h2"); + stream.get_ref().get_ref().connected().proxy(true) + } #[cfg(feature = "rustls-base")] - ProxyStream::Secured(s) => s.get_ref().0.connected().proxy(true), + ProxyStream::Secured(s) => { + let (underlying, tls) = s.get_ref(); + is_h2 = tls.alpn_protocol() == Some(b"h2"); + underlying.connected().proxy(true) + } #[cfg(feature = "openssl-tls")] - ProxyStream::Secured(s) => s.get_ref().connected().proxy(true), + ProxyStream::Secured(s) => { + is_h2 = s.ssl().selected_alpn_protocol() == Some(b"h2"); + s.get_ref().connected().proxy(true) + } + }; + + if is_h2 { + connected.negotiated_h2() + } else { + connected } } }