diff --git a/russh/examples/client_exec_interactive.rs b/russh/examples/client_exec_interactive.rs index dcf977e7..c7477989 100644 --- a/russh/examples/client_exec_interactive.rs +++ b/russh/examples/client_exec_interactive.rs @@ -112,7 +112,13 @@ impl Session { // use publickey authentication, with or without certificate if openssh_cert.is_none() { let auth_res = session - .authenticate_publickey(user, Arc::new(key_pair)) + .authenticate_publickey( + user, + PrivateKeyWithHashAlg::new( + Arc::new(key_pair), + session.best_supported_rsa_hash().await?.flatten(), + ), + ) .await?; if !auth_res.success() { diff --git a/russh/examples/client_exec_simple.rs b/russh/examples/client_exec_simple.rs index 70d9c2cb..5405893c 100644 --- a/russh/examples/client_exec_simple.rs +++ b/russh/examples/client_exec_simple.rs @@ -99,7 +99,13 @@ impl Session { let mut session = client::connect(config, addrs, sh).await?; let auth_res = session - .authenticate_publickey(user, Arc::new(key_pair)) + .authenticate_publickey( + user, + PrivateKeyWithHashAlg::new( + Arc::new(key_pair), + session.best_supported_rsa_hash().await.unwrap().flatten(), + ), + ) .await?; if !auth_res.success() { diff --git a/russh/src/auth.rs b/russh/src/auth.rs index dd0146e3..0c1607b1 100644 --- a/russh/src/auth.rs +++ b/russh/src/auth.rs @@ -23,6 +23,7 @@ use thiserror::Error; use tokio::io::{AsyncRead, AsyncWrite}; use crate::helpers::NameList; +use crate::keys::PrivateKeyWithHashAlg; use crate::CryptoVec; #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -191,10 +192,7 @@ pub enum Method { password: String, }, PublicKey { - key: Arc, - /// None = based on server-sig-algs - /// Some(None) = SHA1 - hash_alg: Option>, + key: PrivateKeyWithHashAlg, }, OpenSshCertificate { key: Arc, diff --git a/russh/src/client/encrypted.rs b/russh/src/client/encrypted.rs index ef76b9f4..12d6f3ae 100644 --- a/russh/src/client/encrypted.rs +++ b/russh/src/client/encrypted.rs @@ -16,12 +16,11 @@ use std::cell::RefCell; use std::convert::TryInto; use std::ops::Deref; use std::str::FromStr; -use std::sync::Arc; use bytes::Bytes; use log::{debug, error, info, trace, warn}; use ssh_encoding::{Decode, Encode, Reader}; -use ssh_key::{Algorithm, HashAlg, PrivateKey}; +use ssh_key::Algorithm; use super::IncomingSshPacket; use crate::auth::AuthRequest; @@ -29,7 +28,6 @@ use crate::cert::PublicKeyOrCertificate; use crate::client::{Handler, Msg, Prompt, Reply, Session}; use crate::helpers::{map_err, sign_with_hash_alg, AlgorithmExt, EncodedExt, NameList}; use crate::keys::key::parse_public_key; -use crate::keys::PrivateKeyWithHashAlg; use crate::parsing::{ChannelOpenConfirmation, ChannelType, OpenChannelMessage}; use crate::session::{Encrypted, EncryptedState, GlobalRequestResponse}; use crate::{ @@ -291,17 +289,15 @@ impl Session { fn handle_server_sig_algs_ext(&mut self, r: &mut impl Reader) -> Result<(), Error> { let algs = NameList::decode(r)?; debug!("* server-sig-algs"); - if let Some(ref mut enc) = self.common.encrypted { - enc.server_sig_algs = Some( - algs.0 - .iter() - .filter_map(|x| Algorithm::from_str(x).ok()) - .inspect(|x| { - debug!(" * {x:?}"); - }) - .collect::>(), - ); - } + self.server_sig_algs = Some( + algs.0 + .iter() + .filter_map(|x| Algorithm::from_str(x).ok()) + .inspect(|x| { + debug!(" * {x:?}"); + }) + .collect::>(), + ); Ok(()) } @@ -848,39 +844,6 @@ impl Session { } impl Encrypted { - fn pick_hash_alg_for_key( - &self, - key: Arc, - hash_alg_choice: Option>, - ) -> Result { - Ok(match hash_alg_choice { - Some(hash_alg) => PrivateKeyWithHashAlg::new(key.clone(), hash_alg)?, - None => { - if key.algorithm().is_rsa() { - PrivateKeyWithHashAlg::new(key.clone(), self.best_key_hash_alg())? - } else { - PrivateKeyWithHashAlg::new(key.clone(), None)? - } - } - }) - } - - fn best_key_hash_alg(&self) -> Option { - if let Some(ref ssa) = self.server_sig_algs { - let possible_algs = [ - Some(ssh_key::HashAlg::Sha512), - Some(ssh_key::HashAlg::Sha256), - None, - ]; - for alg in possible_algs.into_iter() { - if ssa.contains(&Algorithm::Rsa { hash: alg }) { - return alg; - } - } - } - None - } - fn write_auth_request( &mut self, user: &str, @@ -905,12 +868,7 @@ impl Encrypted { password.encode(&mut self.write)?; true } - auth::Method::PublicKey { - ref key, - ref hash_alg, - } => { - let key = self.pick_hash_alg_for_key(key.clone(), *hash_alg)?; - + auth::Method::PublicKey { ref key } => { user.encode(&mut self.write)?; "ssh-connection".encode(&mut self.write)?; "publickey".encode(&mut self.write)?; @@ -994,13 +952,12 @@ impl Encrypted { buffer: &mut CryptoVec, ) -> Result<(), crate::Error> { match method { - auth::Method::PublicKey { key, hash_alg } => { - let key = self.pick_hash_alg_for_key(key.clone(), *hash_alg)?; + auth::Method::PublicKey { key } => { let i0 = - self.client_make_to_sign(user, &PublicKeyOrCertificate::from(&key), buffer)?; + self.client_make_to_sign(user, &PublicKeyOrCertificate::from(key), buffer)?; // Extend with self-signature. - sign_with_hash_alg(&key, buffer)?.encode(&mut *buffer)?; + sign_with_hash_alg(key, buffer)?.encode(&mut *buffer)?; push_packet!(self.write, { #[allow(clippy::indexing_slicing)] // length checked diff --git a/russh/src/client/mod.rs b/russh/src/client/mod.rs index bd8f1dfd..9728a82a 100644 --- a/russh/src/client/mod.rs +++ b/russh/src/client/mod.rs @@ -47,7 +47,7 @@ use kex::ClientKex; use log::{debug, error, trace}; use russh_util::time::Instant; use ssh_encoding::Decode; -use ssh_key::{Certificate, HashAlg, PrivateKey, PublicKey}; +use ssh_key::{Algorithm, Certificate, HashAlg, PrivateKey, PublicKey}; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf}; use tokio::pin; use tokio::sync::mpsc::{ @@ -59,6 +59,7 @@ pub use crate::auth::AuthResult; use crate::channels::{Channel, ChannelMsg, ChannelRef, WindowSizeRef}; use crate::cipher::{self, clear, OpeningKey}; use crate::kex::{KexCause, KexProgress, SessionKexState}; +use crate::keys::PrivateKeyWithHashAlg; use crate::msg::{is_kex_msg, validate_server_msg_strict_kex}; use crate::session::{CommonSession, EncryptedState, GlobalRequestResponse, NewKeys}; use crate::ssh_read::SshRead; @@ -90,6 +91,7 @@ pub struct Session { inbound_channel_sender: Sender, inbound_channel_receiver: Receiver, open_global_requests: VecDeque, + server_sig_algs: Option>, } impl Drop for Session { @@ -180,6 +182,9 @@ pub enum Msg { }, Channel(ChannelId, ChannelMsg), Rekey, + GetServerSigAlgs { + reply_channel: oneshot::Sender>>, + }, } impl From<(ChannelId, ChannelMsg)> for Msg { @@ -359,44 +364,22 @@ impl Handle { } /// Perform public key-based SSH authentication. - /// This method will automatically select the best hash function - /// if the server supports the `server-sig-algs` protocol extension - /// and will fall back to SHA-1 otherwise. + /// + /// For RSA keys, you'll need to decide on which hash algorithm to use. + /// This is the difference between what is also known as + /// `ssh-rsa`, `rsa-sha2-256`, and `rsa-sha2-512` "keys" in OpenSSH. + /// You can use [Handle::best_supported_rsa_hash] to automatically + /// figure out the best hash algorithm for RSA keys. pub async fn authenticate_publickey>( &mut self, user: U, - key: Arc, + key: PrivateKeyWithHashAlg, ) -> Result { let user = user.into(); self.sender .send(Msg::Authenticate { user, - method: auth::Method::PublicKey { - key, - hash_alg: None, - }, - }) - .await - .map_err(|_| crate::Error::SendError)?; - self.wait_recv_reply().await - } - - /// Perform public key-based SSH authentication - /// with an explicit hash algorithm selection (for RSA keys). - pub async fn authenticate_publickey_with_hash>( - &mut self, - user: U, - key: Arc, - hash_alg: Option, - ) -> Result { - let user = user.into(); - self.sender - .send(Msg::Authenticate { - user, - method: auth::Method::PublicKey { - key, - hash_alg: Some(hash_alg), - }, + method: auth::Method::PublicKey { key }, }) .await .map_err(|_| crate::Error::SendError)?; @@ -507,6 +490,39 @@ impl Handle { } } + /// Returns the best RSA hash algorithm supported by the server, + /// as indicated by the `server-sig-algs` extension. + /// If the server does not support the extension, + /// `None` is returned. In this case you may still attempt an authentication + /// with `rsa-sha2-256` or `rsa-sha2-512` and hope for the best. + /// If the server supports the extension, but does not support `rsa-sha2-*`, + /// `Some(None)` is returned. + pub async fn best_supported_rsa_hash(&self) -> Result>, Error> { + let (sender, receiver) = oneshot::channel(); + + self.sender + .send(Msg::GetServerSigAlgs { + reply_channel: sender, + }) + .await + .map_err(|_| crate::Error::SendError)?; + + if let Some(ssa) = receiver.await.map_err(|_| Error::Inconsistent)? { + let possible_algs = [ + Some(ssh_key::HashAlg::Sha512), + Some(ssh_key::HashAlg::Sha256), + None, + ]; + for alg in possible_algs.into_iter() { + if ssa.contains(&Algorithm::Rsa { hash: alg }) { + return Ok(Some(alg)); + } + } + } + + Ok(None) + } + /// Request a session channel (the most basic type of /// channel). This function returns `Some(..)` immediately if the /// connection is authenticated, but the channel only becomes @@ -896,6 +912,7 @@ impl Session { pending_reads: Vec::new(), pending_len: 0, open_global_requests: VecDeque::new(), + server_sig_algs: None, } } @@ -1255,6 +1272,9 @@ impl Session { } Msg::Channel(id, ChannelMsg::Close) => self.close(id)?, Msg::Rekey => self.initiate_rekey()?, + Msg::GetServerSigAlgs { reply_channel } => { + let _ = reply_channel.send(self.server_sig_algs.clone()); + } msg => { // should be unreachable, since the receiver only gets // messages from methods implemented within russh diff --git a/russh/src/keys/agent/server.rs b/russh/src/keys/agent/server.rs index 50ab8f0b..bdcedbb1 100644 --- a/russh/src/keys/agent/server.rs +++ b/russh/src/keys/agent/server.rs @@ -341,7 +341,7 @@ impl, - hash_alg: Option, - ) -> Result { - if hash_alg.is_some() && !key.algorithm().is_rsa() { - return Err(crate::keys::Error::InvalidParameters); + mut hash_alg: Option, + ) -> Self { + if !key.algorithm().is_rsa() { + hash_alg = None; } - Ok(Self { key, hash_alg }) + Self { key, hash_alg } } pub fn algorithm(&self) -> Algorithm { diff --git a/russh/src/server/kex.rs b/russh/src/server/kex.rs index 881a62bc..4116ef44 100644 --- a/russh/src/server/kex.rs +++ b/russh/src/server/kex.rs @@ -272,8 +272,7 @@ impl ServerKex { // Hash signature debug!("signing with key {:?}", key); let signature = sign_with_hash_alg( - &PrivateKeyWithHashAlg::new(Arc::new(key.clone()), signature_hash_alg) - .map_err(Into::into)?, + &PrivateKeyWithHashAlg::new(Arc::new(key.clone()), signature_hash_alg), &hash, ) .map_err(Into::into)?; diff --git a/russh/src/session.rs b/russh/src/session.rs index 18b9b2fe..d25fa4be 100644 --- a/russh/src/session.rs +++ b/russh/src/session.rs @@ -21,7 +21,6 @@ use std::num::Wrapping; use byteorder::{BigEndian, ByteOrder}; use log::{debug, trace}; use ssh_encoding::Encode; -use ssh_key::Algorithm; use tokio::sync::oneshot; use crate::cipher::OpeningKey; @@ -53,7 +52,6 @@ pub(crate) struct Encrypted { pub client_compression: crate::compression::Compression, pub decompress: crate::compression::Decompress, pub rekey_wanted: bool, - pub server_sig_algs: Option>, } pub(crate) struct CommonSession { @@ -153,7 +151,6 @@ impl CommonSession { client_compression: newkeys.names.client_compression, decompress: crate::compression::Decompress::None, rekey_wanted: false, - server_sig_algs: None, }); self.remote_to_local = newkeys.cipher.remote_to_local; self.packet_writer diff --git a/russh/src/tests.rs b/russh/src/tests.rs index db054157..dbe6bbb2 100644 --- a/russh/src/tests.rs +++ b/russh/src/tests.rs @@ -9,6 +9,7 @@ mod compress { use std::sync::{Arc, Mutex}; use async_trait::async_trait; + use keys::PrivateKeyWithHashAlg; use log::debug; use rand_core::OsRng; use ssh_key::PrivateKey; @@ -52,7 +53,10 @@ mod compress { let authenticated = session .authenticate_publickey( std::env::var("USER").unwrap_or("user".to_owned()), - Arc::new(client_key), + PrivateKeyWithHashAlg::new( + Arc::new(client_key), + session.best_supported_rsa_hash().await.unwrap().flatten(), + ), ) .await .unwrap() @@ -139,6 +143,7 @@ mod compress { mod channels { use async_trait::async_trait; + use keys::PrivateKeyWithHashAlg; use rand_core::OsRng; use server::Session; use ssh_key::PrivateKey; @@ -195,7 +200,7 @@ mod channels { let authenticated = session .authenticate_publickey( std::env::var("USER").unwrap_or("user".to_owned()), - Arc::new(client_key), + PrivateKeyWithHashAlg::new(Arc::new(client_key), None), ) .await .unwrap(); diff --git a/russh/tests/test_backpressure.rs b/russh/tests/test_backpressure.rs index 597fb07d..5c48e1b7 100644 --- a/russh/tests/test_backpressure.rs +++ b/russh/tests/test_backpressure.rs @@ -4,6 +4,7 @@ use std::sync::Arc; use futures::FutureExt; use rand::RngCore; use rand_core::OsRng; +use russh::keys::PrivateKeyWithHashAlg; use russh::server::{self, Auth, Msg, Server as _, Session}; use russh::{client, Channel, ChannelMsg}; use ssh_key::PrivateKey; @@ -40,7 +41,13 @@ async fn stream(addr: SocketAddr, data: &[u8], tx: watch::Sender<()>) -> Result< let mut session = russh::client::connect(config, addr, Client).await?; let channel = match session - .authenticate_publickey("user", key) + .authenticate_publickey( + "user", + PrivateKeyWithHashAlg::new( + key, + session.best_supported_rsa_hash().await.unwrap().flatten(), + ), + ) .await .map(|x| x.success()) { diff --git a/russh/tests/test_data_stream.rs b/russh/tests/test_data_stream.rs index 85b53938..45fe5c23 100644 --- a/russh/tests/test_data_stream.rs +++ b/russh/tests/test_data_stream.rs @@ -3,6 +3,7 @@ use std::sync::Arc; use rand::RngCore; use rand_core::OsRng; +use russh::keys::PrivateKeyWithHashAlg; use russh::server::{self, Auth, Msg, Server as _, Session}; use russh::{client, Channel}; use ssh_key::PrivateKey; @@ -35,7 +36,13 @@ async fn stream(addr: SocketAddr, data: &[u8]) -> Result<(), anyhow::Error> { let mut session = russh::client::connect(config, addr, Client).await?; let mut channel = match session - .authenticate_publickey("user", key) + .authenticate_publickey( + "user", + PrivateKeyWithHashAlg::new( + key, + session.best_supported_rsa_hash().await.unwrap().flatten(), + ), + ) .await .map(|x| x.success()) {