Skip to content

Commit

Permalink
feat: ssh tunnel
Browse files Browse the repository at this point in the history
  • Loading branch information
honhimW committed Oct 30, 2024
1 parent 6ff465b commit 40e70a7
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 55 deletions.
18 changes: 10 additions & 8 deletions examples/ssh_tunnel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,11 @@ use crate::common::client::build_pool;
use anyhow::Error;
use anyhow::Result;
use async_trait::async_trait;
use ratisui::ssh_tunnel::SshTunnel;
use redis::cmd;
use russh::client::{Config, Handler};
use russh::client::Handler;
use russh::keys::key;
use std::net::{Ipv4Addr, SocketAddrV4};
use std::sync::Arc;
use tokio::net::TcpListener;
use ratisui::ssh_tunnel;
use ratisui::ssh_tunnel::SshTunnel;
use std::ops::Deref;

const SSH_HOST: &str = "10.37.1.133";
const SSH_PORT: u16 = 22;
Expand All @@ -21,6 +18,8 @@ const SSH_PASSWORD: &str = "123";

const REDIS_HOST: &str = "redis-16430.c1.asia-northeast1-1.gce.redns.redis-cloud.com";
const REDIS_PORT: u16 = 16430;
const REDIS_USER: Some(String) = Some(String::from("default"));
const REDIS_PASSWORD: Some(String) = Some(String::from("9JRCAjglNSTc4pXWOggLT7BKljwuoSSy"));

const LOCAL_HOST: &str = "127.0.0.1";

Expand Down Expand Up @@ -51,18 +50,21 @@ async fn main() -> Result<()> {
let pool = build_pool(common::client::Config {
host: addr.ip().to_string(),
port: addr.port(),
username: Some(String::from("default")),
password: Some("9JRCAjglNSTc4pXWOggLT7BKljwuoSSy".to_string()),
username: REDIS_USER.deref().clone(),
password: REDIS_PASSWORD.deref().clone(),
..Default::default()
})?;
let mut connection = pool.get().await?;
let pong: String = cmd("PING").query_async(&mut connection).await?;
assert_eq!(pool.status().size, 1);
assert!("PONG".eq_ignore_ascii_case(pong.as_str()));
let mut connection = pool.get().await?;
let pong: String = cmd("PING").query_async(&mut connection).await?;
assert_eq!(pool.status().size, 3);
assert!("PONG".eq_ignore_ascii_case(pong.as_str()));
let mut connection = pool.get().await?;
let pong: String = cmd("PING").query_async(&mut connection).await?;
assert_eq!(pool.status().size, 3);
assert!("PONG".eq_ignore_ascii_case(pong.as_str()));
ssh_tunnel.close().await?;
assert!(!ssh_tunnel.is_connected());
Expand Down
35 changes: 14 additions & 21 deletions src/redis_opt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,18 @@ use anyhow::{anyhow, Context, Error, Result};
use crossbeam_channel::Sender;
use deadpool_redis::redis::cmd;
use deadpool_redis::{Pool, Runtime};
use futures::future::join_all;
use futures::StreamExt;
use log::{debug, error, info};
use log::{info};
use once_cell::sync::Lazy;
use redis::cluster::ClusterClient;
use redis::ConnectionAddr::{Tcp, TcpTls};
use redis::{AsyncCommands, AsyncIter, Client, Cmd, ConnectionAddr, ConnectionInfo, ConnectionLike, FromRedisValue, RedisConnectionInfo, ScanOptions, ToRedisArgs, Value, VerbatimFormat};
use std::collections::HashMap;
use std::future::Future;
use std::ops::DerefMut;
use std::sync::{Arc, RwLock};
use std::sync::RwLock;
use std::task::Poll;
use std::time::{Duration, Instant};
use futures::future::join_all;
use tokio::join;
use tokio::time::interval;

#[macro_export]
Expand Down Expand Up @@ -374,22 +372,17 @@ impl RedisOperations {
let (id, node_holder) = result?;
let host;
let port;
if let Some(ref ssh_tunnel) = node_holder.ssh_tunnel {
host = ssh_tunnel.host.clone();
port = ssh_tunnel.port;
} else {
match &node_holder.pool.manager().client.get_connection_info().addr {
Tcp(h, p) => {
host = h.clone();
port = *p;
}
TcpTls { host: h, port: p, .. } => {
host = h.clone();
port = *p;
}
_ => {
return Err(anyhow!("Not supported connection type"))
}
match &node_holder.pool.manager().client.get_connection_info().addr {
Tcp(h, p) => {
host = h.clone();
port = *p;
}
TcpTls { host: h, port: p, .. } => {
host = h.clone();
port = *p;
}
_ => {
return Err(anyhow!("Not supported connection type"))
}
}
node_holders.insert(id, node_holder);
Expand Down
96 changes: 70 additions & 26 deletions src/ssh_tunnel.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use anyhow::{Error, Result};
use async_trait::async_trait;
use log::{error, info, warn};
use russh::client::{Config, Handler, Msg};
use russh::client::{Config, Handler};
use russh::keys::key;
use russh::{Channel, Disconnect};
use russh::Disconnect;
use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
use std::sync::Arc;
use tokio::io::AsyncWriteExt;
Expand All @@ -24,7 +24,6 @@ pub struct SshTunnel {
}

impl SshTunnel {

pub fn new(host: String, port: u16, username: String, password: String, forwarding_host: String, forwarding_port: u16) -> Self {
let (tx, rx) = tokio::sync::watch::channel::<u8>(1);
Self {
Expand All @@ -34,7 +33,8 @@ impl SshTunnel {
password,
forwarding_host,
forwarding_port,
tx, rx,
tx,
rx,
is_connected: false,
}
}
Expand All @@ -45,36 +45,44 @@ impl SshTunnel {
format!("{}:{}", self.host, self.port),
IHandler {},
).await?;

ssh_client.authenticate_password(self.username.clone(), self.password.clone()).await?;
let listener = TcpListener::bind(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0)).await?;
let addr = listener.local_addr()?;
let forwarding_host = self.forwarding_host.clone();
let forwarding_port = self.forwarding_port as u32;

let channel = ssh_client.channel_open_direct_tcpip(
self.forwarding_host.clone(),
self.forwarding_port as u32,
Ipv4Addr::LOCALHOST.to_string(),
addr.port() as u32,
).await?;

let mut remote_stream = channel.into_stream();
let mut rx_clone = self.rx.clone();
let rx_clone = self.rx.clone();
tokio::spawn(async move {
if let Ok((mut local_stream, _)) = listener.accept().await {
select! {
result = tokio::io::copy_bidirectional_with_sizes(&mut local_stream, &mut remote_stream, 255, 8 * 1024) => {
if let Err(e) = result {
error!("Error during bidirectional copy: {}", e);
loop {
let mut rx_clone_clone = rx_clone.clone();
if let Ok((mut local_stream, _)) = listener.accept().await {
let channel = ssh_client.channel_open_direct_tcpip(
forwarding_host.clone(),
forwarding_port,
Ipv4Addr::LOCALHOST.to_string(),
addr.port() as u32,
).await?;
let mut remote_stream = channel.into_stream();
tokio::spawn(async move {
select! {
result = tokio::io::copy_bidirectional_with_sizes(&mut local_stream, &mut remote_stream, 255, 8 * 1024) => {
if let Err(e) = result {
error!("Error during bidirectional copy: {}", e);
}
warn!("Bidirectional copy stopped");
}
_ = rx_clone_clone.changed() => {
info!("Received close signal");
}
}
warn!("Bidirectional copy stopped");
}
_ = rx_clone.changed() => {
info!("Received close signal");
}
let _ = remote_stream.shutdown().await;
});
}
if rx_clone.has_changed()? {
ssh_client.disconnect(Disconnect::ByApplication, "exit", "none").await?;
break;
}
}
ssh_client.disconnect(Disconnect::ByApplication, "exit", "none").await?;
remote_stream.shutdown().await?;
drop(listener);
info!("Stream closed");
Ok::<(), Error>(())
Expand All @@ -90,6 +98,7 @@ impl SshTunnel {
Ok(())
}

#[allow(unused)]
pub fn is_connected(&self) -> bool {
self.is_connected
}
Expand All @@ -103,4 +112,39 @@ impl Handler for IHandler {
async fn check_server_key(&mut self, _: &key::PublicKey) -> Result<bool, Self::Error> {
Ok(true)
}
}

#[cfg(test)]
mod test {
use anyhow::Result;
use std::net::{Ipv4Addr, SocketAddrV4};
use std::time::{Duration, Instant};
use tokio::net::{TcpListener, TcpStream};

#[tokio::test]
async fn tcp_listener() -> Result<()> {
let listener = TcpListener::bind(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0)).await?;
let addr = listener.local_addr()?;

tokio::spawn(async move {
let now = Instant::now();
loop {
if let Ok((mut stream, _)) = listener.accept().await {
println!("{:?} {:?}", now.elapsed(), stream);
} else {
println!("No connection");
}
}
});
tokio::time::sleep(Duration::from_secs(1)).await;
let x = TcpStream::connect(addr).await?;
tokio::time::sleep(Duration::from_secs(1)).await;
let x = TcpStream::connect(addr).await?;
tokio::time::sleep(Duration::from_secs(1)).await;
let x = TcpStream::connect(addr).await?;
tokio::time::sleep(Duration::from_secs(1)).await;
let x = TcpStream::connect(addr).await?;
tokio::time::sleep(Duration::from_secs(1)).await;
Ok(())
}
}

0 comments on commit 40e70a7

Please sign in to comment.