From a935340aa94d563509bc05e2ef13a488e851c64a Mon Sep 17 00:00:00 2001 From: honhimW Date: Mon, 21 Oct 2024 19:24:01 +0800 Subject: [PATCH 1/7] update --- Cargo.toml | 4 ++++ examples/ssh_tunnel.rs | 48 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+) create mode 100644 examples/ssh_tunnel.rs diff --git a/Cargo.toml b/Cargo.toml index 529da2e..da726c0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -69,6 +69,10 @@ tree-sitter-highlight = "0.23" tree-sitter-json = "0.23" tree-sitter-html = "0.23" tree-sitter-ron = { path = "ratisui-tree-sitter-ron" } +russh = { version = "0.45.0" } +russh-keys = "0.45.0" +async-trait = "0.1.82" +env_logger = "0.11.5" # git crates diff --git a/examples/ssh_tunnel.rs b/examples/ssh_tunnel.rs new file mode 100644 index 0000000..9e55e73 --- /dev/null +++ b/examples/ssh_tunnel.rs @@ -0,0 +1,48 @@ +use std::io::Read; +use anyhow::Result; +use russh::client::{Config, Handle, Handler, Msg}; +use russh::Channel; +use std::net::{SocketAddrV4, TcpStream}; +use std::str::FromStr; +use std::sync::Arc; +use async_trait::async_trait; +use russh_keys::key::PublicKey; +use tokio::io::AsyncReadExt; + +async fn create_ssh_tunnel() -> Result<(TcpStream, Channel)> { + let config = Config::default(); + let handler = IHandler {}; + let mut client = russh::client::connect(Arc::new(config), SocketAddrV4::from_str("xxx:xx")?, handler).await?; + let x = client.authenticate_password("guest", "123").await?; + println!("{}", x); + let channel = client.channel_open_direct_tcpip( + "xxx", + 6379, + "127.0.0.1", + 6379, + ).await?; + Ok((TcpStream::connect("127.0.0.1:6379")?, channel)) +} + +#[tokio::main] +async fn main() -> Result<()> { + let (mut stream, _channel) = create_ssh_tunnel().await?; + + let mut buffer = vec![0; 1024]; + let i = stream.read_to_end(&mut buffer)?; + + Ok(()) +} + +struct IHandler {} + +#[async_trait] +impl Handler for IHandler { + type Error = anyhow::Error; + + async fn check_server_key(&mut self, _: &PublicKey) -> std::result::Result { + Ok(true) + } + + +} From f0d52bbf9d099669f0ab1122660e6d9087c62a14 Mon Sep 17 00:00:00 2001 From: honhimW Date: Tue, 22 Oct 2024 19:13:47 +0800 Subject: [PATCH 2/7] update: example ssh_tunnel --- Cargo.toml | 7 +--- examples/ssh_tunnel.rs | 86 +++++++++++++++++++++++++----------------- 2 files changed, 54 insertions(+), 39 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index da726c0..6d74b38 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -69,11 +69,8 @@ tree-sitter-highlight = "0.23" tree-sitter-json = "0.23" tree-sitter-html = "0.23" tree-sitter-ron = { path = "ratisui-tree-sitter-ron" } -russh = { version = "0.45.0" } -russh-keys = "0.45.0" -async-trait = "0.1.82" -env_logger = "0.11.5" - +russh = { version = "0.45.0", default-features = false } +async-trait = "0.1.83" # git crates [build-dependencies] diff --git a/examples/ssh_tunnel.rs b/examples/ssh_tunnel.rs index 9e55e73..52c176e 100644 --- a/examples/ssh_tunnel.rs +++ b/examples/ssh_tunnel.rs @@ -1,48 +1,66 @@ -use std::io::Read; use anyhow::Result; -use russh::client::{Config, Handle, Handler, Msg}; -use russh::Channel; -use std::net::{SocketAddrV4, TcpStream}; -use std::str::FromStr; -use std::sync::Arc; +use anyhow::Error; use async_trait::async_trait; -use russh_keys::key::PublicKey; -use tokio::io::AsyncReadExt; - -async fn create_ssh_tunnel() -> Result<(TcpStream, Channel)> { - let config = Config::default(); - let handler = IHandler {}; - let mut client = russh::client::connect(Arc::new(config), SocketAddrV4::from_str("xxx:xx")?, handler).await?; - let x = client.authenticate_password("guest", "123").await?; - println!("{}", x); - let channel = client.channel_open_direct_tcpip( - "xxx", - 6379, - "127.0.0.1", - 6379, - ).await?; - Ok((TcpStream::connect("127.0.0.1:6379")?, channel)) -} +use russh::client::{Config, Handler}; +use std::net::{Ipv4Addr, SocketAddrV4}; +use std::sync::Arc; +use russh::keys::key; +use tokio::net::TcpListener; -#[tokio::main] -async fn main() -> Result<()> { - let (mut stream, _channel) = create_ssh_tunnel().await?; +const SSH_HOST: &str = "10.37.1.133"; +const SSH_PORT: u16 = 22; - let mut buffer = vec![0; 1024]; - let i = stream.read_to_end(&mut buffer)?; +const REDIS_HOST: &str = "10.37.1.132"; +const REDIS_PORT: u16 = 6379; - Ok(()) -} +const LOCAL_HOST: &str = "127.0.0.1"; -struct IHandler {} +struct IHandler; #[async_trait] impl Handler for IHandler { - type Error = anyhow::Error; - - async fn check_server_key(&mut self, _: &PublicKey) -> std::result::Result { + type Error = Error; + async fn check_server_key(&mut self, _: &key::PublicKey) -> Result { Ok(true) } +} +#[tokio::main] +async fn main() -> Result<()> { + let mut client = russh::client::connect( + Arc::new(Config::default()), + format!("{SSH_HOST}:{SSH_PORT}"), + IHandler {}, + ).await?; + client.authenticate_password("guest", "123").await?; + let listener = TcpListener::bind(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0)).await?; + let addr = listener.local_addr()?; + + let channel = client.channel_open_direct_tcpip( + REDIS_HOST, + REDIS_PORT as u32, + LOCAL_HOST, + addr.port() as u32, + ).await?; + + let mut remote_stream = channel.into_stream(); + tokio::spawn(async move { + if let Ok((mut local_stream, _)) = listener.accept().await { + tokio::io::copy_bidirectional_with_sizes(&mut local_stream, &mut remote_stream, 255, 8 * 1024).await?; + } + Ok::<(), Error>(()) + }); + + let client = redis::Client::open(format!("redis://:123456@{LOCAL_HOST}:{}/", addr.port()))?; + let mut con = client.get_connection()?; + + let pong: String = redis::cmd("PING").query(&mut con)?; + println!("Redis PING response: {}", pong); + let pong: String = redis::cmd("PING").query(&mut con)?; + println!("Redis PING response: {}", pong); + let pong: String = redis::cmd("PING").query(&mut con)?; + println!("Redis PING response: {}", pong); + + Ok(()) } From a8a9a5f9562f0e9ac5efaaf541c8daedf4ae8682 Mon Sep 17 00:00:00 2001 From: honhimW Date: Wed, 23 Oct 2024 19:21:07 +0800 Subject: [PATCH 3/7] todo: single node ssh block --- examples/common/client.rs | 44 +++++- examples/ssh_tunnel.rs | 121 ++++++++++++----- src/components/database_editor.rs | 166 ++++++++++++++++++++++- src/components/list_table.rs | 2 +- src/configuration.rs | 29 +++- src/lib.rs | 3 +- src/main.rs | 2 + src/redis_opt.rs | 215 ++++++++++++++++++++---------- src/ssh_tunnel.rs | 106 +++++++++++++++ 9 files changed, 559 insertions(+), 129 deletions(-) create mode 100644 src/ssh_tunnel.rs diff --git a/examples/common/client.rs b/examples/common/client.rs index d40153b..4823f3d 100644 --- a/examples/common/client.rs +++ b/examples/common/client.rs @@ -4,19 +4,51 @@ use redis::ConnectionAddr::Tcp; use redis::{Cmd, ConnectionInfo, ProtocolVersion, RedisConnectionInfo}; pub fn dead_pool() -> Result { + build_pool(Config { + host: "redis-16430.c1.asia-northeast1-1.gce.redns.redis-cloud.com".to_string(), + port: 16430, + username: Some(String::from("default")), + password: Some("9JRCAjglNSTc4pXWOggLT7BKljwuoSSy".to_string()), + db: 0, + protocol: ProtocolVersion::RESP3, + }) +} + +pub fn build_pool(config: Config) -> Result { let config = deadpool_redis::Config::from_connection_info(ConnectionInfo { - addr: Tcp("redis-16430.c1.asia-northeast1-1.gce.redns.redis-cloud.com".to_string(), 16430), + addr: Tcp(config.host, config.port), redis: RedisConnectionInfo { - db: 0, - username: Some(String::from("default")), - password: Some("9JRCAjglNSTc4pXWOggLT7BKljwuoSSy".to_string()), - protocol: ProtocolVersion::RESP3, + db: config.db as i64, + username: config.username, + password: config.password, + protocol: config.protocol, }, }); - config.create_pool(Some(Runtime::Tokio1)).context("Failed to create pool") } +pub struct Config { + pub host: String, + pub port: u16, + pub username: Option, + pub password: Option, + pub db: u8, + pub protocol: ProtocolVersion, +} + +impl Default for Config { + fn default() -> Self { + Self { + host: "127.0.0.1".to_string(), + port: 6379, + username: None, + password: None, + db: 0, + protocol: ProtocolVersion::RESP3, + } + } +} + #[macro_export] macro_rules! str_cmd { ($cmd:expr) => {{ diff --git a/examples/ssh_tunnel.rs b/examples/ssh_tunnel.rs index 52c176e..400980e 100644 --- a/examples/ssh_tunnel.rs +++ b/examples/ssh_tunnel.rs @@ -1,17 +1,26 @@ -use anyhow::Result; +#[path = "common/lib.rs"] +mod common; + +use crate::common::client::build_pool; use anyhow::Error; +use anyhow::Result; use async_trait::async_trait; +use redis::cmd; use russh::client::{Config, Handler}; +use russh::keys::key; use std::net::{Ipv4Addr, SocketAddrV4}; use std::sync::Arc; -use russh::keys::key; use tokio::net::TcpListener; +use ratisui::ssh_tunnel; +use ratisui::ssh_tunnel::SshTunnel; const SSH_HOST: &str = "10.37.1.133"; const SSH_PORT: u16 = 22; +const SSH_USER: &str = "guest"; +const SSH_PASSWORD: &str = "123"; -const REDIS_HOST: &str = "10.37.1.132"; -const REDIS_PORT: u16 = 6379; +const REDIS_HOST: &str = "redis-16430.c1.asia-northeast1-1.gce.redns.redis-cloud.com"; +const REDIS_PORT: u16 = 16430; const LOCAL_HOST: &str = "127.0.0.1"; @@ -25,42 +34,80 @@ impl Handler for IHandler { } } +// #[tokio::main] +// async fn main() -> Result<()> { +// let mut client = russh::client::connect( +// Arc::new(Config::default()), +// format!("{SSH_HOST}:{SSH_PORT}"), +// IHandler {}, +// ).await?; +// +// client.authenticate_password(SSH_USER, SSH_PASSWORD).await?; +// let listener = TcpListener::bind(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0)).await?; +// let addr = listener.local_addr()?; +// +// let channel = client.channel_open_direct_tcpip( +// REDIS_HOST, +// REDIS_PORT as u32, +// LOCAL_HOST, +// addr.port() as u32, +// ).await?; +// +// let mut remote_stream = channel.into_stream(); +// tokio::spawn(async move { +// if let Ok((mut local_stream, _)) = listener.accept().await { +// tokio::io::copy_bidirectional_with_sizes(&mut local_stream, &mut remote_stream, 255, 8 * 1024).await?; +// } +// Ok::<(), Error>(()) +// }); +// +// let pool = build_pool(common::client::Config { +// port: addr.port(), +// username: Some(String::from("default")), +// password: Some("9JRCAjglNSTc4pXWOggLT7BKljwuoSSy".to_string()), +// ..Default::default() +// })?; +// let mut connection = pool.get().await?; +// let pong: String = cmd("PING").query_async(&mut connection).await?; +// assert!("PONG".eq_ignore_ascii_case(pong.as_str())); +// +// Ok(()) +// } + #[tokio::main] async fn main() -> Result<()> { - let mut client = russh::client::connect( - Arc::new(Config::default()), - format!("{SSH_HOST}:{SSH_PORT}"), - IHandler {}, - ).await?; - - client.authenticate_password("guest", "123").await?; - let listener = TcpListener::bind(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0)).await?; - let addr = listener.local_addr()?; - - let channel = client.channel_open_direct_tcpip( - REDIS_HOST, - REDIS_PORT as u32, - LOCAL_HOST, - addr.port() as u32, - ).await?; + let mut ssh_tunnel = SshTunnel::new( + SSH_HOST.to_string(), + SSH_PORT, + SSH_USER.to_string(), + SSH_PASSWORD.to_string(), + REDIS_HOST.to_string(), + REDIS_PORT, + ); - let mut remote_stream = channel.into_stream(); - tokio::spawn(async move { - if let Ok((mut local_stream, _)) = listener.accept().await { - tokio::io::copy_bidirectional_with_sizes(&mut local_stream, &mut remote_stream, 255, 8 * 1024).await?; - } - Ok::<(), Error>(()) - }); + let addr = ssh_tunnel.open_ssh_tunnel().await?; + println!("{}", addr); - let client = redis::Client::open(format!("redis://:123456@{LOCAL_HOST}:{}/", addr.port()))?; - let mut con = client.get_connection()?; - - let pong: String = redis::cmd("PING").query(&mut con)?; - println!("Redis PING response: {}", pong); - let pong: String = redis::cmd("PING").query(&mut con)?; - println!("Redis PING response: {}", pong); - let pong: String = redis::cmd("PING").query(&mut con)?; - println!("Redis PING response: {}", pong); + let pool = build_pool(common::client::Config { + host: addr.ip().to_string(), + port: addr.port(), + username: Some(String::from("default")), + password: Some("9JRCAjglNSTc4pXWOggLT7BKljwuoSSy".to_string()), + ..Default::default() + })?; + let mut connection = pool.get().await?; + let pong: String = cmd("PING").query_async(&mut connection).await?; + assert!("PONG".eq_ignore_ascii_case(pong.as_str())); + drop(connection); + let mut connection = pool.get().await?; + let pong: String = cmd("PING").query_async(&mut connection).await?; + assert!("PONG".eq_ignore_ascii_case(pong.as_str())); + drop(connection); + let mut connection = pool.get().await?; + let pong: String = cmd("PING").query_async(&mut connection).await?; + assert!("PONG".eq_ignore_ascii_case(pong.as_str())); + ssh_tunnel.close().await?; + assert!(!ssh_tunnel.is_connected()); Ok(()) -} +} \ No newline at end of file diff --git a/src/components/database_editor.rs b/src/components/database_editor.rs index 1c74c92..6f8ad8e 100644 --- a/src/components/database_editor.rs +++ b/src/components/database_editor.rs @@ -1,6 +1,6 @@ use crate::app::{Listenable, Renderable}; use crate::components::servers::Data; -use crate::configuration::{Database, Protocol}; +use crate::configuration::{Database, Protocol, SshTunnel}; use ratatui::crossterm::event::{KeyCode, KeyEvent, KeyEventKind, KeyModifiers}; use ratatui::layout::Constraint::{Fill, Length, Percentage}; use ratatui::layout::{Layout, Rect}; @@ -26,6 +26,11 @@ pub struct Form { use_tls: bool, db_text_area: TextArea<'static>, protocol: Protocol, + use_ssh_tunnel: bool, + ssh_host_text_area: TextArea<'static>, + ssh_port_text_area: TextArea<'static>, + ssh_username_text_area: TextArea<'static>, + ssh_password_text_area: TextArea<'static>, } #[derive(Default, Eq, PartialEq, EnumCount, EnumIter, Display)] @@ -49,6 +54,16 @@ enum Editing { Db, #[strum(serialize = "Protocol")] Protocol, + #[strum(serialize = "Use SSH Tunnel")] + UseSshTunnel, + #[strum(serialize = "SSH Host")] + SshHost, + #[strum(serialize = "SSH Port")] + SshPort, + #[strum(serialize = "SSH Username")] + SshUsername, + #[strum(serialize = "SSH Password")] + SshPassword, } fn cursor_style() -> Style { @@ -70,6 +85,11 @@ impl Default for Form { use_tls: false, db_text_area: TextArea::default(), protocol: Protocol::RESP3, + use_ssh_tunnel: false, + ssh_host_text_area: TextArea::default(), + ssh_port_text_area: TextArea::default(), + ssh_username_text_area: TextArea::default(), + ssh_password_text_area: TextArea::default(), }; form.name_text_area.set_placeholder_text("must not be blank"); form.name_text_area.set_placeholder_style(Style::default().fg(tailwind::RED.c700).dim()); @@ -78,16 +98,26 @@ impl Default for Form { form.username_text_area.set_placeholder_text(""); form.password_text_area.set_placeholder_text(""); form.db_text_area.set_placeholder_text("0"); + form.ssh_host_text_area.set_placeholder_text("127.0.0.1"); + form.ssh_port_text_area.set_placeholder_text("22"); + form.ssh_username_text_area.set_placeholder_text("root"); + form.ssh_password_text_area.set_placeholder_text(""); + form.name_text_area.set_cursor_style(Style::default()); form.host_text_area.set_cursor_style(Style::default()); form.port_text_area.set_cursor_style(Style::default()); form.username_text_area.set_cursor_style(Style::default()); form.password_text_area.set_cursor_style(Style::default()); form.db_text_area.set_cursor_style(Style::default()); + form.ssh_host_text_area.set_cursor_style(Style::default()); + form.ssh_port_text_area.set_cursor_style(Style::default()); + form.ssh_username_text_area.set_cursor_style(Style::default()); + form.ssh_password_text_area.set_cursor_style(Style::default()); form.name_text_area.insert_str(Uuid::new_v4().to_string()); form.name_text_area.select_all(); form.password_text_area.set_mask_char('•'); + form.ssh_password_text_area.set_mask_char('•'); form } @@ -107,6 +137,13 @@ impl Form { form.use_tls = data.database.use_tls; form.db_text_area.insert_str(data.db.clone()); form.protocol = data.database.protocol.clone(); + form.use_ssh_tunnel = data.database.use_ssh_tunnel; + if let Some(ref ssh_tunnel) = data.database.ssh_tunnel { + form.ssh_host_text_area.insert_str(ssh_tunnel.host.clone()); + form.ssh_port_text_area.insert_str(ssh_tunnel.port.to_string()); + form.ssh_username_text_area.insert_str(ssh_tunnel.username.clone()); + form.ssh_password_text_area.insert_str(ssh_tunnel.password.clone()); + } form } @@ -127,15 +164,29 @@ impl Form { let use_tls = self.use_tls; let db = self.db_text_area.lines().get(0).cloned().filter(|x| !x.is_empty()).unwrap_or(self.db_text_area.placeholder_text().to_string()).parse::().unwrap_or(0); let protocol = self.protocol.clone(); + let use_ssh_tunnel = self.use_ssh_tunnel; + let ssh_tunnel = if use_ssh_tunnel { + let ssh_host = self.ssh_host_text_area.lines().get(0).cloned().filter(|x| !x.is_empty()).unwrap_or(self.ssh_host_text_area.placeholder_text().to_string()); + let ssh_port = self.ssh_port_text_area.lines().get(0).cloned().filter(|x| !x.is_empty()).unwrap_or(self.ssh_port_text_area.placeholder_text().to_string()).parse::().unwrap_or(6379); + let ssh_username = self.ssh_username_text_area.lines().get(0).cloned().filter(|x| !x.is_empty()).unwrap_or(self.ssh_username_text_area.placeholder_text().to_string()); + let ssh_password = self.ssh_password_text_area.lines().get(0).cloned().filter(|x| !x.is_empty()).unwrap_or_default(); + Some(SshTunnel { + host: ssh_host, + port: ssh_port, + username: ssh_username, + password: ssh_password, + }) + } else { None }; Database { host, port, username, password, use_tls, - use_ssh_tunnel: false, db, protocol, + use_ssh_tunnel, + ssh_tunnel, } } @@ -148,6 +199,12 @@ impl Form { self.next(); } } + if !self.use_ssh_tunnel { + let editing = self.current(); + if editing == Editing::SshHost || editing == Editing::SshPort || editing == Editing::SshUsername || editing == Editing::SshPassword { + self.next(); + } + } self.change_editing(); } @@ -160,6 +217,12 @@ impl Form { self.prev(); } } + if !self.use_ssh_tunnel { + let editing = self.current(); + if editing == Editing::SshHost || editing == Editing::SshPort || editing == Editing::SshUsername || editing == Editing::SshPassword { + self.prev(); + } + } self.change_editing(); } @@ -291,12 +354,70 @@ impl Form { frame.render_widget(value, rc[1]); } + fn render_use_ssh_tunnel(&self, frame: &mut Frame, rect: Rect) { + let horizontal = Layout::horizontal([Length(18), Fill(0)]); + let rc = horizontal.split(rect); + let key = self.span(Editing::UseSshTunnel); + let value = Span::raw(if self.use_ssh_tunnel { "◄ Yes ►" } else { "◄ No ►" }).style(key.style); + frame.render_widget(key, rc[0]); + frame.render_widget(value, rc[1]); + } + + fn render_ssh_host_port(&mut self, frame: &mut Frame, rect: Rect) { + let horizontal = Layout::horizontal([Percentage(65), Percentage(35)]); + let rc = horizontal.split(rect); + let host_area = Layout::horizontal([Length(18), Fill(0)]).split(rc[0]); + let port_area = Layout::horizontal([Length(9), Fill(0)]).split(rc[1]); + { + let key = self.span(Editing::SshHost); + self.ssh_host_text_area.set_style(key.style); + let value = &self.ssh_host_text_area; + frame.render_widget(key, host_area[0]); + frame.render_widget(value, host_area[1]); + } + { + let key = self.span(Editing::SshPort); + self.ssh_port_text_area.set_style(key.style); + let value = &self.ssh_port_text_area; + frame.render_widget(key, port_area[0]); + frame.render_widget(value, port_area[1]); + } + } + + fn render_ssh_username(&mut self, frame: &mut Frame, rect: Rect) { + let horizontal = Layout::horizontal([Length(18), Fill(0)]); + let rc = horizontal.split(rect); + let key = self.span(Editing::SshUsername); + self.ssh_username_text_area.set_style(key.style); + let value = &self.ssh_username_text_area; + frame.render_widget(key, rc[0]); + frame.render_widget(value, rc[1]); + } + + fn render_ssh_password(&mut self, frame: &mut Frame, rect: Rect) { + let horizontal = Layout::horizontal([Length(18), Fill(0)]); + let rc = horizontal.split(rect); + let key = self.span(Editing::SshPassword); + self.ssh_password_text_area.set_style(key.style); + let value = &self.ssh_password_text_area; + frame.render_widget(key, rc[0]); + frame.render_widget(value, rc[1]); + } + } impl Renderable for Form { fn render_frame(&mut self, frame: &mut Frame, rect: Rect) -> anyhow::Result<()> { - let blank_length = (rect.height - 10) / 2; - let area = Layout::vertical([Length(blank_length), Length(10), Length(blank_length)]).split(rect)[1]; + let mut total_height = 9; + if self.enabled_authentication { + total_height += 2; + } + let ssh_height = 3; + if self.use_ssh_tunnel { + total_height += ssh_height; + } + let blank_length = (rect.height - total_height) / 2; + let area = Layout::vertical([Length(blank_length), Length(total_height), Length(blank_length)]).split(rect)[1]; let area = Layout::horizontal([Percentage(20), Percentage(60), Percentage(20)]).split(area)[1]; // let area = centered_rect(50, 70, rect); frame.render_widget(Clear::default(), area); @@ -306,6 +427,13 @@ impl Renderable for Form { let block_inner_area = block .inner(area); let block_inner_area = Layout::horizontal([Length(1), Fill(0), Length(1)]).split(block_inner_area)[1]; + let inner_area_vertical = Layout::vertical([Fill(0), Length(ssh_height)]).split(block_inner_area); + let base_area = if self.use_ssh_tunnel { + inner_area_vertical[0] + } else { + block_inner_area + }; + if !self.enabled_authentication { let vertical = Layout::vertical([ Length(1), // name @@ -314,14 +442,16 @@ impl Renderable for Form { Length(1), // tls Length(1), // db Length(1), // protocol + Length(1), // use ssh ]); - let rc = vertical.split(block_inner_area); + let rc = vertical.split(base_area); self.render_name(frame, rc[0]); self.render_host_port(frame, rc[1]); self.render_enabled_auth(frame, rc[2]); self.render_use_tls(frame, rc[3]); self.render_db(frame, rc[4]); self.render_protocol(frame, rc[5]); + self.render_use_ssh_tunnel(frame, rc[6]); } else { let vertical = Layout::vertical([ Length(1), // name @@ -332,8 +462,9 @@ impl Renderable for Form { Length(1), // tls Length(1), // db Length(1), // protocol + Length(1), // use ssh ]); - let rc = vertical.split(block_inner_area); + let rc = vertical.split(base_area); self.render_name(frame, rc[0]); self.render_host_port(frame, rc[1]); self.render_enabled_auth(frame, rc[2]); @@ -342,6 +473,18 @@ impl Renderable for Form { self.render_use_tls(frame, rc[5]); self.render_db(frame, rc[6]); self.render_protocol(frame, rc[7]); + self.render_use_ssh_tunnel(frame, rc[8]); + } + + if self.use_ssh_tunnel { + let rc = Layout::vertical([ + Length(1), // host + port + Length(1), // username + Length(1), // password + ]).split(inner_area_vertical[1]); + self.render_ssh_host_port(frame, rc[0]); + self.render_ssh_username(frame, rc[1]); + self.render_ssh_password(frame, rc[2]); } frame.render_widget(block, area); Ok(()) @@ -387,6 +530,10 @@ impl Listenable for Form { Editing::Username => Some(&mut self.username_text_area), Editing::Password => Some(&mut self.password_text_area), Editing::Db => Some(&mut self.db_text_area), + Editing::SshHost => Some(&mut self.ssh_host_text_area), + Editing::SshPort => Some(&mut self.ssh_port_text_area), + Editing::SshUsername => Some(&mut self.ssh_username_text_area), + Editing::SshPassword => Some(&mut self.ssh_password_text_area), _ => None, }; if let Some(text_area) = editor { @@ -411,7 +558,7 @@ impl Listenable for Form { text_area.redo(); } input => { - if editing == Editing::Port || editing == Editing::Db { + if editing == Editing::Port || editing == Editing::Db || editing == Editing::SshPort { if input.code == KeyCode::Backspace { text_area.input(input); } else { @@ -454,6 +601,9 @@ impl Listenable for Form { Ok(true) } else { match key_event.code { + KeyCode::Esc => { + return Ok(false); + }, KeyCode::Char('h') | KeyCode::Left => { match editing { Editing::EnabledAuthentication => self.enabled_authentication = !self.enabled_authentication, @@ -462,6 +612,7 @@ impl Listenable for Form { Protocol::RESP2 => Protocol::RESP3, Protocol::RESP3 => Protocol::RESP2, }, + Editing::UseSshTunnel => self.use_ssh_tunnel = !self.use_ssh_tunnel, _ => {} } } @@ -473,6 +624,7 @@ impl Listenable for Form { Protocol::RESP2 => Protocol::RESP3, Protocol::RESP3 => Protocol::RESP2, }, + Editing::UseSshTunnel => self.use_ssh_tunnel = !self.use_ssh_tunnel, _ => {} } } diff --git a/src/components/list_table.rs b/src/components/list_table.rs index 6e65cc6..3288c0e 100644 --- a/src/components/list_table.rs +++ b/src/components/list_table.rs @@ -98,7 +98,7 @@ impl ListValue { Self { state: TableState::default().with_selected(0), longest_item_lens: constraint_len_calculator(&vec), - scroll_state: ScrollbarState::new((vec.len() - 1) * ITEM_HEIGHT), + scroll_state: ScrollbarState::new((vec.len().saturating_sub(1)) * ITEM_HEIGHT), colors: TableColors::new(&tailwind::GRAY), items: vec, } diff --git a/src/configuration.rs b/src/configuration.rs index d6fe395..323e872 100644 --- a/src/configuration.rs +++ b/src/configuration.rs @@ -110,27 +110,48 @@ impl Databases { } } -#[derive(Serialize, Deserialize, Clone)] +#[derive(Serialize, Deserialize, Clone, Debug)] pub struct Database { pub host: String, pub port: u16, pub username: Option, - #[serde(default, serialize_with = "to_base64", deserialize_with = "from_base64")] + #[serde(default, serialize_with = "to_base64_option", deserialize_with = "from_base64_option")] pub password: Option, pub use_tls: bool, pub use_ssh_tunnel: bool, pub db: u32, pub protocol: Protocol, + pub ssh_tunnel: Option, +} + +#[derive(Serialize, Deserialize, Clone, Debug)] +pub struct SshTunnel { + pub host: String, + pub port: u16, + pub username: String, + #[serde(default, serialize_with = "to_base64", deserialize_with = "from_base64")] + pub password: String, +} + +fn to_base64(password: &String, s: S) -> Result { + s.serialize_str(BASE64_STANDARD.encode(password).as_str()) +} + +fn from_base64<'d, S: Deserializer<'d>>(deserializer: S) -> Result { + let base64 = String::deserialize(deserializer)?; + let bytes = BASE64_STANDARD.decode(base64).map_err(|_| S::Error::custom("decode base64 error"))?; + let string = String::from_utf8(bytes).map_err(|_| S::Error::custom("decode utf-8 error"))?; + Ok(string) } -fn to_base64(password: &Option, s: S) -> Result { +fn to_base64_option(password: &Option, s: S) -> Result { match password { Some(p) => s.serialize_some(&BASE64_STANDARD.encode(p)), None => s.serialize_none(), } } -fn from_base64<'d, S: Deserializer<'d>>(deserializer: S) -> Result, S::Error> { +fn from_base64_option<'d, S: Deserializer<'d>>(deserializer: S) -> Result, S::Error> { let option = Option::::deserialize(deserializer)?; match option { Some(p) => { diff --git a/src/lib.rs b/src/lib.rs index c881c80..7e11ce7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,5 @@ pub mod redis_opt; pub mod configuration; pub mod utils; -pub mod bus; \ No newline at end of file +pub mod bus; +pub mod ssh_tunnel; \ No newline at end of file diff --git a/src/main.rs b/src/main.rs index a167494..daf4c20 100644 --- a/src/main.rs +++ b/src/main.rs @@ -10,6 +10,7 @@ mod tabs; mod components; mod utils; mod bus; +mod ssh_tunnel; use crate::app::{App, AppEvent, AppState, Listenable, Renderable}; use crate::components::fps::FpsCalculator; @@ -50,6 +51,7 @@ async fn main() -> Result<()> { if let Some(database) = db_config.databases.get(&db) { let database_clone = database.clone(); tokio::spawn(async move { + info!("{:?}", &database_clone); match switch_client(db.clone(), &database_clone) { Ok(_) => { info!("Successfully connected to default database '{db}'"); diff --git a/src/redis_opt.rs b/src/redis_opt.rs index a3bb49b..bbbd224 100644 --- a/src/redis_opt.rs +++ b/src/redis_opt.rs @@ -1,21 +1,25 @@ +use crate::bus::{publish_event, publish_msg, GlobalEvent, Message}; use crate::configuration::{to_protocol_version, Database}; +use crate::ssh_tunnel::SshTunnel; use crate::utils::split_args; -use crate::bus::{publish_event, GlobalEvent}; use anyhow::{anyhow, Context, Error, Result}; -use crossbeam_channel::{Sender}; +use crossbeam_channel::Sender; use deadpool_redis::redis::cmd; use deadpool_redis::{Pool, Runtime}; use futures::StreamExt; -use log::{info}; +use log::{debug, error, info}; use once_cell::sync::Lazy; use redis::cluster::ClusterClient; use redis::ConnectionAddr::{Tcp, TcpTls}; use redis::{AsyncCommands, AsyncIter, Client, Cmd, ConnectionAddr, ConnectionInfo, ConnectionLike, FromRedisValue, RedisConnectionInfo, ScanOptions, ToRedisArgs, Value, VerbatimFormat}; use std::collections::HashMap; use std::future::Future; -use std::sync::RwLock; +use std::ops::DerefMut; +use std::sync::{Arc, RwLock}; use std::task::Poll; use std::time::{Duration, Instant}; +use futures::future::join_all; +use tokio::join; use tokio::time::interval; #[macro_export] @@ -87,24 +91,42 @@ pub fn redis_operations() -> Option { } pub fn switch_client(name: impl Into, database: &Database) -> Result<()> { - let client = build_client(&database)?; let name = name.into(); - let mut operation = RedisOperations::new(name, database.clone(), client); - operation.initialize()?; + let database = database.clone(); + tokio::spawn(async move { + let result = async { + let (pool, tunnel) = build_pool(&database).await?; + let mut operation = RedisOperations::new(name, database.clone(), pool, tunnel)?; + operation.initialize().await?; + let result = REDIS_OPERATIONS.write(); + match result { + Ok(mut x) => { + if let Some(o) = x.deref_mut() { + o.close(); + } + *x = Some(operation); + } + Err(e) => { + return Err(anyhow!("Failed to switch client: {}", e)); + } + } + let _ = publish_event(GlobalEvent::ClientChanged); + Ok::<(), Error>(()) + }.await; - let result = REDIS_OPERATIONS.write(); - match result { - Ok(mut x) => { - *x = Some(operation); - } - Err(e) => { - return Err(anyhow!("Failed to switch client: {}", e)); + match result { + Ok(_) => { + let _ = publish_msg(Message::info("Connected".to_string())); + } + Err(e) => { + let _ = publish_msg(Message::error(format!("Failed to switch client: {}", e))); + } } - } - let _ = publish_event(GlobalEvent::ClientChanged); + }); Ok(()) } +#[allow(unused)] fn build_client(database: &Database) -> Result { let mut client = Client::open(ConnectionInfo { addr: Tcp(database.host.clone(), database.port), @@ -122,19 +144,45 @@ fn build_client(database: &Database) -> Result { } } -fn build_pool(database: &Database) -> Result { - let info = ConnectionInfo { - addr: Tcp(database.host.clone(), database.port), - redis: RedisConnectionInfo { - db: database.db as i64, - username: database.username.clone(), - password: database.password.clone(), - protocol: to_protocol_version(database.protocol.clone()), - }, +async fn build_pool(database: &Database) -> Result<(Pool, Option)> { + let mut ssh_tunnel_option = None; + let info = if database.use_ssh_tunnel && database.ssh_tunnel.is_some() { + let tunnel = database.ssh_tunnel.clone().unwrap(); + let mut ssh_tunnel = SshTunnel::new( + tunnel.host.clone(), + tunnel.port, + tunnel.username.clone(), + tunnel.password.clone(), + database.host.clone(), + database.port, + ); + let addr = ssh_tunnel.open_ssh_tunnel().await?; + info!("SSH-Tunnel listening on: {} <==> {}:{}", addr, tunnel.host, tunnel.port); + ssh_tunnel_option = Some(ssh_tunnel); + + ConnectionInfo { + addr: Tcp(addr.ip().to_string(), addr.port()), + redis: RedisConnectionInfo { + db: database.db as i64, + username: database.username.clone(), + password: database.password.clone(), + protocol: to_protocol_version(database.protocol.clone()), + }, + } + } else { + ConnectionInfo { + addr: Tcp(database.host.clone(), database.port), + redis: RedisConnectionInfo { + db: database.db as i64, + username: database.username.clone(), + password: database.password.clone(), + protocol: to_protocol_version(database.protocol.clone()), + }, + } }; let config = deadpool_redis::Config::from_connection_info(deadpool_redis::ConnectionInfo::from(info)); let pool = config.create_pool(Some(Runtime::Tokio1))?; - Ok(pool) + Ok((pool, ssh_tunnel_option)) } #[derive(Clone)] @@ -142,35 +190,54 @@ pub struct RedisOperations { #[allow(unused)] pub name: String, database: Database, - client: Client, pool: Pool, + ssh_tunnel: Option, server_info: Option, - cluster_client: Option, + is_cluster: bool, nodes: HashMap, cluster_pool: Option, } #[derive(Clone, Debug)] struct NodeClientHolder { - node_client: Client, pool: Pool, + ssh_tunnel: Option, is_master: bool, } impl RedisOperations { - fn new(name: impl Into, database: Database, client: Client) -> Self { - let info = deadpool_redis::ConnectionInfo::from(client.get_connection_info().clone()); - let config = deadpool_redis::Config::from_connection_info(info); - let pool = config.create_pool(Some(Runtime::Tokio1)).unwrap(); - Self { + fn new(name: impl Into, database: Database, pool: Pool, tunnel: Option) -> Result { + Ok(Self { name: name.into(), database, - client, pool, + ssh_tunnel: tunnel, server_info: None, - cluster_client: None, + is_cluster: false, nodes: HashMap::new(), cluster_pool: None, + }) + } + + fn close(&mut self) { + self.pool.close(); + if let Some(ref ssh_tunnel) = self.ssh_tunnel { + let mut tunnel = ssh_tunnel.clone(); + tokio::spawn(async move { + tunnel.close().await + }); + } + if let Some(ref mut cluster_pool) = self.cluster_pool { + cluster_pool.close(); + } + for (_, node_holder) in self.nodes.iter_mut() { + node_holder.pool.close(); + if let Some(ref ssh_tunnel) = node_holder.ssh_tunnel { + let mut tunnel = ssh_tunnel.clone(); + tokio::spawn(async move { + tunnel.close().await + }); + } } } @@ -190,7 +257,7 @@ impl RedisOperations { // } pub fn is_cluster(&self) -> bool { - self.cluster_client.is_some() + self.is_cluster } fn print(&self) { @@ -198,25 +265,22 @@ impl RedisOperations { info!("Cluster mode"); info!("Cluster nodes: {}", self.nodes.len()); for (s, node) in self.nodes.iter() { - info!("{s} - location: {} - master: {}", node.node_client.get_connection_info().addr, node.is_master); + info!("{s} - location: {} - master: {}", node.pool.manager().client.get_connection_info().addr, node.is_master); } } else { - info!("Standalone mode: {}", self.client.get_connection_info().addr); + info!("Standalone mode: {}", self.pool.manager().client.get_connection_info().addr); } } - fn initialize(&mut self) -> Result<()> { - let mut connection = self.client.get_connection()?; - let value = connection.req_command(&Cmd::new().arg("INFO").arg("SERVER"))?; + async fn initialize(&mut self) -> Result<()> { + let mut connection = self.pool.get().await?; + let value: Value = Cmd::new().arg("INFO").arg("SERVER").query_async(&mut connection).await?; + drop(connection); if let Value::VerbatimString { text, .. } = value { // RESP3 self.server_info = Some(text); let redis_mode = self.get_server_info("redis_mode").context("there will always contain redis_mode property")?; if redis_mode == "cluster" { - self.initialize_cluster()?; - } else { - let config = deadpool_redis::Config::from_connection_info(deadpool_redis::ConnectionInfo::from(self.client.get_connection_info().clone())); - let pool = config.create_pool(Some(Runtime::Tokio1))?; - self.pool = pool; + self.initialize_cluster().await?; } self.print(); Ok(()) @@ -225,11 +289,7 @@ impl RedisOperations { self.server_info = Some(text); let redis_mode = self.get_server_info("redis_mode").context("there will always contain redis_mode property")?; if redis_mode == "cluster" { - self.initialize_cluster()?; - } else { - let config = deadpool_redis::Config::from_connection_info(deadpool_redis::ConnectionInfo::from(self.client.get_connection_info().clone())); - let pool = config.create_pool(Some(Runtime::Tokio1))?; - self.pool = pool; + self.initialize_cluster().await?; } self.print(); Ok(()) @@ -238,9 +298,10 @@ impl RedisOperations { } } - fn initialize_cluster(&mut self) -> Result<()> { - let mut connection = self.client.get_connection()?; - let cluster_slots = connection.req_command(&Cmd::new().arg("CLUSTER").arg("SLOTS"))?; + async fn initialize_cluster(&mut self) -> Result<()> { + self.is_cluster = true; + let mut connection = self.pool.get().await?; + let cluster_slots: Value = cmd("CLUSTER").arg("SLOTS").query_async(&mut connection).await?; if let Value::Array { 0: item, .. } = cluster_slots { let mut redis_nodes: Vec<(String, u16, String)> = Vec::new(); for slot in item { @@ -272,17 +333,14 @@ impl RedisOperations { } let mut cluster_client_infos: Vec = Vec::new(); let mut node_holders: HashMap = HashMap::new(); - let connection_info = self.client.get_connection_info(); + let connection_info = self.pool.manager().client.get_connection_info(); for (host, port, _) in redis_nodes.clone() { cluster_client_infos.push(ConnectionInfo { addr: Tcp(host.clone(), port.clone()), redis: connection_info.redis.clone(), }); } - let cluster_client = ClusterClient::new(cluster_client_infos)?; - self.cluster_client = Some(cluster_client); - - let cluster_nodes = connection.req_command(&Cmd::new().arg("CLUSTER").arg("NODES"))?; + let cluster_nodes: Value = cmd("CLUSTER").arg("NODES").query_async(&mut connection).await?; let mut node_kind_map: HashMap = HashMap::new(); if let Value::VerbatimString { text, .. } = cluster_nodes { for line in text.lines() { @@ -291,18 +349,29 @@ impl RedisOperations { node_kind_map.insert(split[0].to_string(), node_kind.contains("master")); } } + let mut futures = vec![]; for (host, port, id) in redis_nodes.clone() { let mut database = Database::from(self.database.clone()); database.host = host; database.port = port; - let node_client = build_client(&database)?; - let pool = build_pool(&database)?; let is_master = node_kind_map.get(&id).unwrap_or(&false); - node_holders.insert(id, NodeClientHolder { - node_client, - pool, - is_master: *is_master, - }); + let future = async move { + if let Ok((pool, tunnel)) = build_pool(&database).await { + Ok((id, NodeClientHolder { + pool, + ssh_tunnel: tunnel, + is_master: *is_master, + })) + } else { + Err(anyhow!("Failed to initialize node")) + } + }; + futures.push(future) + } + let results = join_all(futures).await; + for result in results { + let (id, node_holder) = result?; + node_holders.insert(id, node_holder); } self.nodes = node_holders; let mut cluster_urls = vec![]; @@ -395,13 +464,13 @@ impl RedisOperations { let mut streams = vec![]; if self.is_cluster() { for (_, holder) in self.nodes.iter() { - let mut monitor = holder.node_client.get_async_monitor().await?; + let mut monitor = holder.pool.manager().client.get_async_monitor().await?; let _ = monitor.monitor().await?; let stream = monitor.into_on_message::(); streams.push(stream); } } else { - let mut monitor = self.client.get_async_monitor().await?; + let mut monitor = self.pool.manager().client.get_async_monitor().await?; let _ = monitor.monitor().await?; let stream = monitor.into_on_message::(); streams.push(stream); @@ -478,14 +547,14 @@ impl RedisOperations { if self.is_cluster() { for (_, holder) in self.nodes.iter() { if holder.is_master { - let mut pub_sub = holder.node_client.get_async_pubsub().await?; + let mut pub_sub = holder.pool.manager().client.get_async_pubsub().await?; pub_sub.subscribe(&key).await?; let stream = pub_sub.into_on_message(); streams.push(stream); } } } else { - let mut pub_sub = self.client.get_async_pubsub().await?; + let mut pub_sub = self.pool.manager().client.get_async_pubsub().await?; pub_sub.subscribe(&key).await?; let stream = pub_sub.into_on_message(); streams.push(stream); @@ -565,14 +634,14 @@ impl RedisOperations { if self.is_cluster() { for (_, holder) in self.nodes.iter() { if holder.is_master { - let mut pub_sub = holder.node_client.get_async_pubsub().await?; + let mut pub_sub = holder.pool.manager().client.get_async_pubsub().await?; pub_sub.psubscribe(&key).await?; let stream = pub_sub.into_on_message(); streams.push(stream); } } } else { - let mut pub_sub = self.client.get_async_pubsub().await?; + let mut pub_sub = self.pool.manager().client.get_async_pubsub().await?; pub_sub.psubscribe(&key).await?; let stream = pub_sub.into_on_message(); streams.push(stream); diff --git a/src/ssh_tunnel.rs b/src/ssh_tunnel.rs new file mode 100644 index 0000000..61737d2 --- /dev/null +++ b/src/ssh_tunnel.rs @@ -0,0 +1,106 @@ +use anyhow::{Error, Result}; +use async_trait::async_trait; +use log::{error, info, warn}; +use russh::client::{Config, Handler, Msg}; +use russh::keys::key; +use russh::{Channel, Disconnect}; +use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4}; +use std::sync::Arc; +use tokio::io::AsyncWriteExt; +use tokio::net::TcpListener; +use tokio::select; + +#[derive(Clone, Debug)] +pub struct SshTunnel { + pub host: String, + pub port: u16, + pub username: String, + pub password: String, + pub forwarding_host: String, + pub forwarding_port: u16, + tx: tokio::sync::watch::Sender, + rx: tokio::sync::watch::Receiver, + is_connected: bool, +} + +impl SshTunnel { + + pub fn new(host: String, port: u16, username: String, password: String, forwarding_host: String, forwarding_port: u16) -> Self { + let (tx, rx) = tokio::sync::watch::channel::(1); + Self { + host, + port, + username, + password, + forwarding_host, + forwarding_port, + tx, rx, + is_connected: false, + } + } + + pub async fn open_ssh_tunnel(&mut self) -> Result { + let mut ssh_client = russh::client::connect( + Arc::new(Config::default()), + format!("{}:{}", self.host, self.port), + IHandler {}, + ).await?; + + ssh_client.authenticate_password(self.username.clone(), self.password.clone()).await?; + let listener = TcpListener::bind(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0)).await?; + let addr = listener.local_addr()?; + + let channel = ssh_client.channel_open_direct_tcpip( + self.forwarding_host.clone(), + self.forwarding_port as u32, + Ipv4Addr::LOCALHOST.to_string(), + addr.port() as u32, + ).await?; + + let mut remote_stream = channel.into_stream(); + let mut rx_clone = self.rx.clone(); + tokio::spawn(async move { + if let Ok((mut local_stream, _)) = listener.accept().await { + select! { + result = tokio::io::copy_bidirectional_with_sizes(&mut local_stream, &mut remote_stream, 255, 8 * 1024) => { + if let Err(e) = result { + error!("Error during bidirectional copy: {}", e); + } + warn!("Bidirectional copy stopped"); + } + _ = rx_clone.changed() => { + info!("Received close signal"); + } + } + } + ssh_client.disconnect(Disconnect::ByApplication, "exit", "none").await?; + remote_stream.shutdown().await?; + drop(listener); + info!("Stream closed"); + Ok::<(), Error>(()) + }); + + self.is_connected = true; + Ok(addr) + } + + pub async fn close(&mut self) -> Result<()> { + self.tx.send(0)?; + self.is_connected = false; + Ok(()) + } + + pub fn is_connected(&self) -> bool { + self.is_connected + } +} + +struct IHandler; + +#[async_trait] +impl Handler for IHandler { + type Error = Error; + async fn check_server_key(&mut self, _: &key::PublicKey) -> Result { + Ok(true) + } +} \ No newline at end of file From c6f5a44388fa5d98b9a85685b18e0d2893127135 Mon Sep 17 00:00:00 2001 From: honhimW Date: Thu, 24 Oct 2024 09:55:52 +0800 Subject: [PATCH 4/7] update: example --- examples/redis_client.rs | 6 +++ examples/ssh_tunnel.rs | 2 - src/redis_opt.rs | 85 ++++++++++++++++++++++++---------------- 3 files changed, 58 insertions(+), 35 deletions(-) diff --git a/examples/redis_client.rs b/examples/redis_client.rs index e1edb83..bd28b46 100644 --- a/examples/redis_client.rs +++ b/examples/redis_client.rs @@ -12,6 +12,12 @@ async fn main() -> Result<()> { let pool = dead_pool()?; let client = &pool.manager().client; + let mut connection = pool.get().await?; + let x: Value = cmd("ping").query_async(&mut connection).await?; + dbg!(x); + let mut connection = pool.get().await?; + let x: Value = cmd("ping").query_async(&mut connection).await?; + dbg!(x); let mut connection = pool.get().await?; let x: Value = cmd("ping").query_async(&mut connection).await?; dbg!(x); diff --git a/examples/ssh_tunnel.rs b/examples/ssh_tunnel.rs index 400980e..e3b3e0b 100644 --- a/examples/ssh_tunnel.rs +++ b/examples/ssh_tunnel.rs @@ -98,11 +98,9 @@ async fn main() -> Result<()> { let mut connection = pool.get().await?; let pong: String = cmd("PING").query_async(&mut connection).await?; assert!("PONG".eq_ignore_ascii_case(pong.as_str())); - drop(connection); let mut connection = pool.get().await?; let pong: String = cmd("PING").query_async(&mut connection).await?; assert!("PONG".eq_ignore_ascii_case(pong.as_str())); - drop(connection); let mut connection = pool.get().await?; let pong: String = cmd("PING").query_async(&mut connection).await?; assert!("PONG".eq_ignore_ascii_case(pong.as_str())); diff --git a/src/redis_opt.rs b/src/redis_opt.rs index bbbd224..1021b39 100644 --- a/src/redis_opt.rs +++ b/src/redis_opt.rs @@ -369,13 +369,31 @@ impl RedisOperations { futures.push(future) } let results = join_all(futures).await; + let mut cluster_urls = vec![]; for result in results { let (id, node_holder) = result?; + let host; + let port; + if let Some(ref ssh_tunnel) = node_holder.ssh_tunnel { + host = ssh_tunnel.host.clone(); + port = ssh_tunnel.port; + } else { + match &node_holder.pool.manager().client.get_connection_info().addr { + Tcp(h, p) => { + host = h.clone(); + port = *p; + } + TcpTls { host: h, port: p, .. } => { + host = h.clone(); + port = *p; + } + _ => { + return Err(anyhow!("Not supported connection type")) + } + } + } node_holders.insert(id, node_holder); - } - self.nodes = node_holders; - let mut cluster_urls = vec![]; - for (host, port, _) in redis_nodes.clone() { + let addr: ConnectionAddr; if self.database.use_tls { addr = TcpTls { @@ -398,6 +416,7 @@ impl RedisOperations { }; cluster_urls.push(deadpool_redis::ConnectionInfo::from(info)) } + self.nodes = node_holders; let config = deadpool_redis::cluster::Config { urls: None, connections: Some(cluster_urls), @@ -838,6 +857,35 @@ impl RedisOperations { } } + pub async fn mem_usage(&self, key: K) -> Result { + if self.is_cluster() { + let pool = &self.cluster_pool.clone().context("should be cluster")?; + let mut connection = pool.get().await?; + let v: Value = cmd("MEMORY").arg("USAGE") + .arg(key) + .arg("SAMPLES").arg("0") + .query_async(&mut connection) + .await?; + if let Value::Int(int) = v { + Ok(int) + } else { + Ok(0) + } + } else { + let mut connection = self.pool.get().await?; + let v: Value = cmd("MEMORY").arg("USAGE") + .arg(key) + .arg("SAMPLES").arg("0") + .query_async(&mut connection) + .await?; + if let Value::Int(int) = v { + Ok(int) + } else { + Ok(0) + } + } + } + pub async fn expire(&self, key: K, seconds: i64) -> Result<()> { if self.is_cluster() { let pool = &self.cluster_pool.clone().context("should be cluster")?; @@ -964,35 +1012,6 @@ impl RedisOperations { // } // } - pub async fn mem_usage(&self, key: K) -> Result { - if self.is_cluster() { - let pool = &self.cluster_pool.clone().context("should be cluster")?; - let mut connection = pool.get().await?; - let v: Value = cmd("MEMORY").arg("USAGE") - .arg(key) - .arg("SAMPLES").arg("0") - .query_async(&mut connection) - .await?; - if let Value::Int(int) = v { - Ok(int) - } else { - Ok(0) - } - } else { - let mut connection = self.pool.get().await?; - let v: Value = cmd("MEMORY").arg("USAGE") - .arg(key) - .arg("SAMPLES").arg("0") - .query_async(&mut connection) - .await?; - if let Value::Int(int) = v { - Ok(int) - } else { - Ok(0) - } - } - } - pub async fn del(&self, key: K) -> Result<()> { if self.is_cluster() { let pool = &self.cluster_pool.clone().context("should be cluster")?; From c665a54f51f7918ec202ad278c8c6b4b98c47ecb Mon Sep 17 00:00:00 2001 From: honhimW Date: Thu, 24 Oct 2024 10:17:24 +0800 Subject: [PATCH 5/7] update --- examples/ssh_tunnel.rs | 40 ---------------------------------------- 1 file changed, 40 deletions(-) diff --git a/examples/ssh_tunnel.rs b/examples/ssh_tunnel.rs index e3b3e0b..5b0bfc1 100644 --- a/examples/ssh_tunnel.rs +++ b/examples/ssh_tunnel.rs @@ -34,46 +34,6 @@ impl Handler for IHandler { } } -// #[tokio::main] -// async fn main() -> Result<()> { -// let mut client = russh::client::connect( -// Arc::new(Config::default()), -// format!("{SSH_HOST}:{SSH_PORT}"), -// IHandler {}, -// ).await?; -// -// client.authenticate_password(SSH_USER, SSH_PASSWORD).await?; -// let listener = TcpListener::bind(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0)).await?; -// let addr = listener.local_addr()?; -// -// let channel = client.channel_open_direct_tcpip( -// REDIS_HOST, -// REDIS_PORT as u32, -// LOCAL_HOST, -// addr.port() as u32, -// ).await?; -// -// let mut remote_stream = channel.into_stream(); -// tokio::spawn(async move { -// if let Ok((mut local_stream, _)) = listener.accept().await { -// tokio::io::copy_bidirectional_with_sizes(&mut local_stream, &mut remote_stream, 255, 8 * 1024).await?; -// } -// Ok::<(), Error>(()) -// }); -// -// let pool = build_pool(common::client::Config { -// port: addr.port(), -// username: Some(String::from("default")), -// password: Some("9JRCAjglNSTc4pXWOggLT7BKljwuoSSy".to_string()), -// ..Default::default() -// })?; -// let mut connection = pool.get().await?; -// let pong: String = cmd("PING").query_async(&mut connection).await?; -// assert!("PONG".eq_ignore_ascii_case(pong.as_str())); -// -// Ok(()) -// } - #[tokio::main] async fn main() -> Result<()> { let mut ssh_tunnel = SshTunnel::new( From 6ff465bbde7e7d7223c29eb5f03391e37ed3ffec Mon Sep 17 00:00:00 2001 From: honhimW Date: Thu, 24 Oct 2024 15:59:37 +0800 Subject: [PATCH 6/7] update --- examples/ssh_tunnel.rs | 2 +- src/redis_opt.rs | 2 +- src/ssh_tunnel.rs | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/ssh_tunnel.rs b/examples/ssh_tunnel.rs index 5b0bfc1..5e33ae8 100644 --- a/examples/ssh_tunnel.rs +++ b/examples/ssh_tunnel.rs @@ -45,7 +45,7 @@ async fn main() -> Result<()> { REDIS_PORT, ); - let addr = ssh_tunnel.open_ssh_tunnel().await?; + let addr = ssh_tunnel.open().await?; println!("{}", addr); let pool = build_pool(common::client::Config { diff --git a/src/redis_opt.rs b/src/redis_opt.rs index 1021b39..a09f5bd 100644 --- a/src/redis_opt.rs +++ b/src/redis_opt.rs @@ -156,7 +156,7 @@ async fn build_pool(database: &Database) -> Result<(Pool, Option)> { database.host.clone(), database.port, ); - let addr = ssh_tunnel.open_ssh_tunnel().await?; + let addr = ssh_tunnel.open().await?; info!("SSH-Tunnel listening on: {} <==> {}:{}", addr, tunnel.host, tunnel.port); ssh_tunnel_option = Some(ssh_tunnel); diff --git a/src/ssh_tunnel.rs b/src/ssh_tunnel.rs index 61737d2..f2a0cc5 100644 --- a/src/ssh_tunnel.rs +++ b/src/ssh_tunnel.rs @@ -39,7 +39,7 @@ impl SshTunnel { } } - pub async fn open_ssh_tunnel(&mut self) -> Result { + pub async fn open(&mut self) -> Result { let mut ssh_client = russh::client::connect( Arc::new(Config::default()), format!("{}:{}", self.host, self.port), From 40e70a7308a37b352ba9d9088897e64db98e11cf Mon Sep 17 00:00:00 2001 From: honhimW Date: Wed, 30 Oct 2024 15:50:46 +0800 Subject: [PATCH 7/7] feat: ssh tunnel --- examples/ssh_tunnel.rs | 18 ++++---- src/redis_opt.rs | 35 ++++++--------- src/ssh_tunnel.rs | 96 ++++++++++++++++++++++++++++++------------ 3 files changed, 94 insertions(+), 55 deletions(-) diff --git a/examples/ssh_tunnel.rs b/examples/ssh_tunnel.rs index 5e33ae8..27dd709 100644 --- a/examples/ssh_tunnel.rs +++ b/examples/ssh_tunnel.rs @@ -5,14 +5,11 @@ use crate::common::client::build_pool; use anyhow::Error; use anyhow::Result; use async_trait::async_trait; +use ratisui::ssh_tunnel::SshTunnel; use redis::cmd; -use russh::client::{Config, Handler}; +use russh::client::Handler; use russh::keys::key; -use std::net::{Ipv4Addr, SocketAddrV4}; -use std::sync::Arc; -use tokio::net::TcpListener; -use ratisui::ssh_tunnel; -use ratisui::ssh_tunnel::SshTunnel; +use std::ops::Deref; const SSH_HOST: &str = "10.37.1.133"; const SSH_PORT: u16 = 22; @@ -21,6 +18,8 @@ const SSH_PASSWORD: &str = "123"; const REDIS_HOST: &str = "redis-16430.c1.asia-northeast1-1.gce.redns.redis-cloud.com"; const REDIS_PORT: u16 = 16430; +const REDIS_USER: Some(String) = Some(String::from("default")); +const REDIS_PASSWORD: Some(String) = Some(String::from("9JRCAjglNSTc4pXWOggLT7BKljwuoSSy")); const LOCAL_HOST: &str = "127.0.0.1"; @@ -51,18 +50,21 @@ async fn main() -> Result<()> { let pool = build_pool(common::client::Config { host: addr.ip().to_string(), port: addr.port(), - username: Some(String::from("default")), - password: Some("9JRCAjglNSTc4pXWOggLT7BKljwuoSSy".to_string()), + username: REDIS_USER.deref().clone(), + password: REDIS_PASSWORD.deref().clone(), ..Default::default() })?; let mut connection = pool.get().await?; let pong: String = cmd("PING").query_async(&mut connection).await?; + assert_eq!(pool.status().size, 1); assert!("PONG".eq_ignore_ascii_case(pong.as_str())); let mut connection = pool.get().await?; let pong: String = cmd("PING").query_async(&mut connection).await?; + assert_eq!(pool.status().size, 3); assert!("PONG".eq_ignore_ascii_case(pong.as_str())); let mut connection = pool.get().await?; let pong: String = cmd("PING").query_async(&mut connection).await?; + assert_eq!(pool.status().size, 3); assert!("PONG".eq_ignore_ascii_case(pong.as_str())); ssh_tunnel.close().await?; assert!(!ssh_tunnel.is_connected()); diff --git a/src/redis_opt.rs b/src/redis_opt.rs index a09f5bd..06aadd7 100644 --- a/src/redis_opt.rs +++ b/src/redis_opt.rs @@ -6,20 +6,18 @@ use anyhow::{anyhow, Context, Error, Result}; use crossbeam_channel::Sender; use deadpool_redis::redis::cmd; use deadpool_redis::{Pool, Runtime}; +use futures::future::join_all; use futures::StreamExt; -use log::{debug, error, info}; +use log::{info}; use once_cell::sync::Lazy; -use redis::cluster::ClusterClient; use redis::ConnectionAddr::{Tcp, TcpTls}; use redis::{AsyncCommands, AsyncIter, Client, Cmd, ConnectionAddr, ConnectionInfo, ConnectionLike, FromRedisValue, RedisConnectionInfo, ScanOptions, ToRedisArgs, Value, VerbatimFormat}; use std::collections::HashMap; use std::future::Future; use std::ops::DerefMut; -use std::sync::{Arc, RwLock}; +use std::sync::RwLock; use std::task::Poll; use std::time::{Duration, Instant}; -use futures::future::join_all; -use tokio::join; use tokio::time::interval; #[macro_export] @@ -374,22 +372,17 @@ impl RedisOperations { let (id, node_holder) = result?; let host; let port; - if let Some(ref ssh_tunnel) = node_holder.ssh_tunnel { - host = ssh_tunnel.host.clone(); - port = ssh_tunnel.port; - } else { - match &node_holder.pool.manager().client.get_connection_info().addr { - Tcp(h, p) => { - host = h.clone(); - port = *p; - } - TcpTls { host: h, port: p, .. } => { - host = h.clone(); - port = *p; - } - _ => { - return Err(anyhow!("Not supported connection type")) - } + match &node_holder.pool.manager().client.get_connection_info().addr { + Tcp(h, p) => { + host = h.clone(); + port = *p; + } + TcpTls { host: h, port: p, .. } => { + host = h.clone(); + port = *p; + } + _ => { + return Err(anyhow!("Not supported connection type")) } } node_holders.insert(id, node_holder); diff --git a/src/ssh_tunnel.rs b/src/ssh_tunnel.rs index f2a0cc5..d24a4a5 100644 --- a/src/ssh_tunnel.rs +++ b/src/ssh_tunnel.rs @@ -1,9 +1,9 @@ use anyhow::{Error, Result}; use async_trait::async_trait; use log::{error, info, warn}; -use russh::client::{Config, Handler, Msg}; +use russh::client::{Config, Handler}; use russh::keys::key; -use russh::{Channel, Disconnect}; +use russh::Disconnect; use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4}; use std::sync::Arc; use tokio::io::AsyncWriteExt; @@ -24,7 +24,6 @@ pub struct SshTunnel { } impl SshTunnel { - pub fn new(host: String, port: u16, username: String, password: String, forwarding_host: String, forwarding_port: u16) -> Self { let (tx, rx) = tokio::sync::watch::channel::(1); Self { @@ -34,7 +33,8 @@ impl SshTunnel { password, forwarding_host, forwarding_port, - tx, rx, + tx, + rx, is_connected: false, } } @@ -45,36 +45,44 @@ impl SshTunnel { format!("{}:{}", self.host, self.port), IHandler {}, ).await?; - ssh_client.authenticate_password(self.username.clone(), self.password.clone()).await?; let listener = TcpListener::bind(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0)).await?; let addr = listener.local_addr()?; + let forwarding_host = self.forwarding_host.clone(); + let forwarding_port = self.forwarding_port as u32; - let channel = ssh_client.channel_open_direct_tcpip( - self.forwarding_host.clone(), - self.forwarding_port as u32, - Ipv4Addr::LOCALHOST.to_string(), - addr.port() as u32, - ).await?; - - let mut remote_stream = channel.into_stream(); - let mut rx_clone = self.rx.clone(); + let rx_clone = self.rx.clone(); tokio::spawn(async move { - if let Ok((mut local_stream, _)) = listener.accept().await { - select! { - result = tokio::io::copy_bidirectional_with_sizes(&mut local_stream, &mut remote_stream, 255, 8 * 1024) => { - if let Err(e) = result { - error!("Error during bidirectional copy: {}", e); + loop { + let mut rx_clone_clone = rx_clone.clone(); + if let Ok((mut local_stream, _)) = listener.accept().await { + let channel = ssh_client.channel_open_direct_tcpip( + forwarding_host.clone(), + forwarding_port, + Ipv4Addr::LOCALHOST.to_string(), + addr.port() as u32, + ).await?; + let mut remote_stream = channel.into_stream(); + tokio::spawn(async move { + select! { + result = tokio::io::copy_bidirectional_with_sizes(&mut local_stream, &mut remote_stream, 255, 8 * 1024) => { + if let Err(e) = result { + error!("Error during bidirectional copy: {}", e); + } + warn!("Bidirectional copy stopped"); + } + _ = rx_clone_clone.changed() => { + info!("Received close signal"); + } } - warn!("Bidirectional copy stopped"); - } - _ = rx_clone.changed() => { - info!("Received close signal"); - } + let _ = remote_stream.shutdown().await; + }); + } + if rx_clone.has_changed()? { + ssh_client.disconnect(Disconnect::ByApplication, "exit", "none").await?; + break; } } - ssh_client.disconnect(Disconnect::ByApplication, "exit", "none").await?; - remote_stream.shutdown().await?; drop(listener); info!("Stream closed"); Ok::<(), Error>(()) @@ -90,6 +98,7 @@ impl SshTunnel { Ok(()) } + #[allow(unused)] pub fn is_connected(&self) -> bool { self.is_connected } @@ -103,4 +112,39 @@ impl Handler for IHandler { async fn check_server_key(&mut self, _: &key::PublicKey) -> Result { Ok(true) } +} + +#[cfg(test)] +mod test { + use anyhow::Result; + use std::net::{Ipv4Addr, SocketAddrV4}; + use std::time::{Duration, Instant}; + use tokio::net::{TcpListener, TcpStream}; + + #[tokio::test] + async fn tcp_listener() -> Result<()> { + let listener = TcpListener::bind(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0)).await?; + let addr = listener.local_addr()?; + + tokio::spawn(async move { + let now = Instant::now(); + loop { + if let Ok((mut stream, _)) = listener.accept().await { + println!("{:?} {:?}", now.elapsed(), stream); + } else { + println!("No connection"); + } + } + }); + tokio::time::sleep(Duration::from_secs(1)).await; + let x = TcpStream::connect(addr).await?; + tokio::time::sleep(Duration::from_secs(1)).await; + let x = TcpStream::connect(addr).await?; + tokio::time::sleep(Duration::from_secs(1)).await; + let x = TcpStream::connect(addr).await?; + tokio::time::sleep(Duration::from_secs(1)).await; + let x = TcpStream::connect(addr).await?; + tokio::time::sleep(Duration::from_secs(1)).await; + Ok(()) + } } \ No newline at end of file