From 40e70a7308a37b352ba9d9088897e64db98e11cf Mon Sep 17 00:00:00 2001 From: honhimW Date: Wed, 30 Oct 2024 15:50:46 +0800 Subject: [PATCH] feat: ssh tunnel --- examples/ssh_tunnel.rs | 18 ++++---- src/redis_opt.rs | 35 ++++++--------- src/ssh_tunnel.rs | 96 ++++++++++++++++++++++++++++++------------ 3 files changed, 94 insertions(+), 55 deletions(-) diff --git a/examples/ssh_tunnel.rs b/examples/ssh_tunnel.rs index 5e33ae8..27dd709 100644 --- a/examples/ssh_tunnel.rs +++ b/examples/ssh_tunnel.rs @@ -5,14 +5,11 @@ use crate::common::client::build_pool; use anyhow::Error; use anyhow::Result; use async_trait::async_trait; +use ratisui::ssh_tunnel::SshTunnel; use redis::cmd; -use russh::client::{Config, Handler}; +use russh::client::Handler; use russh::keys::key; -use std::net::{Ipv4Addr, SocketAddrV4}; -use std::sync::Arc; -use tokio::net::TcpListener; -use ratisui::ssh_tunnel; -use ratisui::ssh_tunnel::SshTunnel; +use std::ops::Deref; const SSH_HOST: &str = "10.37.1.133"; const SSH_PORT: u16 = 22; @@ -21,6 +18,8 @@ const SSH_PASSWORD: &str = "123"; const REDIS_HOST: &str = "redis-16430.c1.asia-northeast1-1.gce.redns.redis-cloud.com"; const REDIS_PORT: u16 = 16430; +const REDIS_USER: Some(String) = Some(String::from("default")); +const REDIS_PASSWORD: Some(String) = Some(String::from("9JRCAjglNSTc4pXWOggLT7BKljwuoSSy")); const LOCAL_HOST: &str = "127.0.0.1"; @@ -51,18 +50,21 @@ async fn main() -> Result<()> { let pool = build_pool(common::client::Config { host: addr.ip().to_string(), port: addr.port(), - username: Some(String::from("default")), - password: Some("9JRCAjglNSTc4pXWOggLT7BKljwuoSSy".to_string()), + username: REDIS_USER.deref().clone(), + password: REDIS_PASSWORD.deref().clone(), ..Default::default() })?; let mut connection = pool.get().await?; let pong: String = cmd("PING").query_async(&mut connection).await?; + assert_eq!(pool.status().size, 1); assert!("PONG".eq_ignore_ascii_case(pong.as_str())); let mut connection = pool.get().await?; let pong: String = cmd("PING").query_async(&mut connection).await?; + assert_eq!(pool.status().size, 3); assert!("PONG".eq_ignore_ascii_case(pong.as_str())); let mut connection = pool.get().await?; let pong: String = cmd("PING").query_async(&mut connection).await?; + assert_eq!(pool.status().size, 3); assert!("PONG".eq_ignore_ascii_case(pong.as_str())); ssh_tunnel.close().await?; assert!(!ssh_tunnel.is_connected()); diff --git a/src/redis_opt.rs b/src/redis_opt.rs index a09f5bd..06aadd7 100644 --- a/src/redis_opt.rs +++ b/src/redis_opt.rs @@ -6,20 +6,18 @@ use anyhow::{anyhow, Context, Error, Result}; use crossbeam_channel::Sender; use deadpool_redis::redis::cmd; use deadpool_redis::{Pool, Runtime}; +use futures::future::join_all; use futures::StreamExt; -use log::{debug, error, info}; +use log::{info}; use once_cell::sync::Lazy; -use redis::cluster::ClusterClient; use redis::ConnectionAddr::{Tcp, TcpTls}; use redis::{AsyncCommands, AsyncIter, Client, Cmd, ConnectionAddr, ConnectionInfo, ConnectionLike, FromRedisValue, RedisConnectionInfo, ScanOptions, ToRedisArgs, Value, VerbatimFormat}; use std::collections::HashMap; use std::future::Future; use std::ops::DerefMut; -use std::sync::{Arc, RwLock}; +use std::sync::RwLock; use std::task::Poll; use std::time::{Duration, Instant}; -use futures::future::join_all; -use tokio::join; use tokio::time::interval; #[macro_export] @@ -374,22 +372,17 @@ impl RedisOperations { let (id, node_holder) = result?; let host; let port; - if let Some(ref ssh_tunnel) = node_holder.ssh_tunnel { - host = ssh_tunnel.host.clone(); - port = ssh_tunnel.port; - } else { - match &node_holder.pool.manager().client.get_connection_info().addr { - Tcp(h, p) => { - host = h.clone(); - port = *p; - } - TcpTls { host: h, port: p, .. } => { - host = h.clone(); - port = *p; - } - _ => { - return Err(anyhow!("Not supported connection type")) - } + match &node_holder.pool.manager().client.get_connection_info().addr { + Tcp(h, p) => { + host = h.clone(); + port = *p; + } + TcpTls { host: h, port: p, .. } => { + host = h.clone(); + port = *p; + } + _ => { + return Err(anyhow!("Not supported connection type")) } } node_holders.insert(id, node_holder); diff --git a/src/ssh_tunnel.rs b/src/ssh_tunnel.rs index f2a0cc5..d24a4a5 100644 --- a/src/ssh_tunnel.rs +++ b/src/ssh_tunnel.rs @@ -1,9 +1,9 @@ use anyhow::{Error, Result}; use async_trait::async_trait; use log::{error, info, warn}; -use russh::client::{Config, Handler, Msg}; +use russh::client::{Config, Handler}; use russh::keys::key; -use russh::{Channel, Disconnect}; +use russh::Disconnect; use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4}; use std::sync::Arc; use tokio::io::AsyncWriteExt; @@ -24,7 +24,6 @@ pub struct SshTunnel { } impl SshTunnel { - pub fn new(host: String, port: u16, username: String, password: String, forwarding_host: String, forwarding_port: u16) -> Self { let (tx, rx) = tokio::sync::watch::channel::(1); Self { @@ -34,7 +33,8 @@ impl SshTunnel { password, forwarding_host, forwarding_port, - tx, rx, + tx, + rx, is_connected: false, } } @@ -45,36 +45,44 @@ impl SshTunnel { format!("{}:{}", self.host, self.port), IHandler {}, ).await?; - ssh_client.authenticate_password(self.username.clone(), self.password.clone()).await?; let listener = TcpListener::bind(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0)).await?; let addr = listener.local_addr()?; + let forwarding_host = self.forwarding_host.clone(); + let forwarding_port = self.forwarding_port as u32; - let channel = ssh_client.channel_open_direct_tcpip( - self.forwarding_host.clone(), - self.forwarding_port as u32, - Ipv4Addr::LOCALHOST.to_string(), - addr.port() as u32, - ).await?; - - let mut remote_stream = channel.into_stream(); - let mut rx_clone = self.rx.clone(); + let rx_clone = self.rx.clone(); tokio::spawn(async move { - if let Ok((mut local_stream, _)) = listener.accept().await { - select! { - result = tokio::io::copy_bidirectional_with_sizes(&mut local_stream, &mut remote_stream, 255, 8 * 1024) => { - if let Err(e) = result { - error!("Error during bidirectional copy: {}", e); + loop { + let mut rx_clone_clone = rx_clone.clone(); + if let Ok((mut local_stream, _)) = listener.accept().await { + let channel = ssh_client.channel_open_direct_tcpip( + forwarding_host.clone(), + forwarding_port, + Ipv4Addr::LOCALHOST.to_string(), + addr.port() as u32, + ).await?; + let mut remote_stream = channel.into_stream(); + tokio::spawn(async move { + select! { + result = tokio::io::copy_bidirectional_with_sizes(&mut local_stream, &mut remote_stream, 255, 8 * 1024) => { + if let Err(e) = result { + error!("Error during bidirectional copy: {}", e); + } + warn!("Bidirectional copy stopped"); + } + _ = rx_clone_clone.changed() => { + info!("Received close signal"); + } } - warn!("Bidirectional copy stopped"); - } - _ = rx_clone.changed() => { - info!("Received close signal"); - } + let _ = remote_stream.shutdown().await; + }); + } + if rx_clone.has_changed()? { + ssh_client.disconnect(Disconnect::ByApplication, "exit", "none").await?; + break; } } - ssh_client.disconnect(Disconnect::ByApplication, "exit", "none").await?; - remote_stream.shutdown().await?; drop(listener); info!("Stream closed"); Ok::<(), Error>(()) @@ -90,6 +98,7 @@ impl SshTunnel { Ok(()) } + #[allow(unused)] pub fn is_connected(&self) -> bool { self.is_connected } @@ -103,4 +112,39 @@ impl Handler for IHandler { async fn check_server_key(&mut self, _: &key::PublicKey) -> Result { Ok(true) } +} + +#[cfg(test)] +mod test { + use anyhow::Result; + use std::net::{Ipv4Addr, SocketAddrV4}; + use std::time::{Duration, Instant}; + use tokio::net::{TcpListener, TcpStream}; + + #[tokio::test] + async fn tcp_listener() -> Result<()> { + let listener = TcpListener::bind(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0)).await?; + let addr = listener.local_addr()?; + + tokio::spawn(async move { + let now = Instant::now(); + loop { + if let Ok((mut stream, _)) = listener.accept().await { + println!("{:?} {:?}", now.elapsed(), stream); + } else { + println!("No connection"); + } + } + }); + tokio::time::sleep(Duration::from_secs(1)).await; + let x = TcpStream::connect(addr).await?; + tokio::time::sleep(Duration::from_secs(1)).await; + let x = TcpStream::connect(addr).await?; + tokio::time::sleep(Duration::from_secs(1)).await; + let x = TcpStream::connect(addr).await?; + tokio::time::sleep(Duration::from_secs(1)).await; + let x = TcpStream::connect(addr).await?; + tokio::time::sleep(Duration::from_secs(1)).await; + Ok(()) + } } \ No newline at end of file