diff --git a/src/main.rs b/src/main.rs index ad6e57b..71733ab 100644 --- a/src/main.rs +++ b/src/main.rs @@ -157,11 +157,16 @@ async fn main() { settings::EndpointType::Http2(endpoint_info) => { crate::http_server::start_h2_server(listen_end_point, app.clone(), endpoint_info); } - settings::EndpointType::Tcp { remote_addr, debug } => { + settings::EndpointType::Tcp { + remote_addr, + debug, + whitelisted_ip, + } => { crate::tcp_port_forward::start_tcp( app.clone(), listen_end_point, remote_addr, + whitelisted_ip, debug, ); } diff --git a/src/settings/end_point_settings.rs b/src/settings/end_point_settings.rs index 7fde3c4..41355b7 100644 --- a/src/settings/end_point_settings.rs +++ b/src/settings/end_point_settings.rs @@ -2,7 +2,10 @@ use std::collections::HashMap; use serde::*; -use crate::http_proxy_pass::{HttpType, ProxyPassEndpointInfo}; +use crate::{ + http_proxy_pass::{HttpType, ProxyPassEndpointInfo}, + types::WhiteListedIpList, +}; use super::{ EndpointType, GoogleAuthSettings, LocationSettings, ModifyHttpHeadersSettings, @@ -144,9 +147,13 @@ impl EndpointSettings { } }, super::ProxyPassTo::Tcp(remote_addr) => { + let mut whitelisted_ip = WhiteListedIpList::new(); + whitelisted_ip.apply(self.whitelisted_ip.as_deref()); + return Ok(EndpointType::Tcp { remote_addr, debug: self.get_debug(), + whitelisted_ip, }); } } diff --git a/src/settings/end_point_type.rs b/src/settings/end_point_type.rs index 901951c..5f51d61 100644 --- a/src/settings/end_point_type.rs +++ b/src/settings/end_point_type.rs @@ -2,7 +2,7 @@ use std::sync::Arc; use my_ssh::SshCredentials; -use crate::http_proxy_pass::ProxyPassEndpointInfo; +use crate::{http_proxy_pass::ProxyPassEndpointInfo, types::WhiteListedIpList}; use super::{RemoteHost, SslCertificateId}; @@ -22,6 +22,7 @@ pub enum EndpointType { Tcp { remote_addr: std::net::SocketAddr, debug: bool, + whitelisted_ip: WhiteListedIpList, }, TcpOverSsh { ssh_credentials: Arc, diff --git a/src/tcp_port_forward/start_tcp.rs b/src/tcp_port_forward/start_tcp.rs index f886cbc..452c866 100644 --- a/src/tcp_port_forward/start_tcp.rs +++ b/src/tcp_port_forward/start_tcp.rs @@ -3,21 +3,29 @@ use std::sync::Arc; use rust_extensions::date_time::AtomicDateTimeAsMicroseconds; use tokio::{io::AsyncWriteExt, net::TcpStream, sync::Mutex}; -use crate::app::AppContext; +use crate::{app::AppContext, types::WhiteListedIpList}; pub fn start_tcp( app: Arc, listen_addr: std::net::SocketAddr, remote_addr: std::net::SocketAddr, + whitelisted_ip: WhiteListedIpList, debug: bool, ) { - tokio::spawn(tcp_server_accept_loop(app, listen_addr, remote_addr, debug)); + tokio::spawn(tcp_server_accept_loop( + app, + listen_addr, + remote_addr, + whitelisted_ip, + debug, + )); } async fn tcp_server_accept_loop( app: Arc, listen_addr: std::net::SocketAddr, remote_addr: std::net::SocketAddr, + whitelisted_ip: WhiteListedIpList, debug: bool, ) { let listener = tokio::net::TcpListener::bind(listen_addr).await; @@ -37,6 +45,18 @@ async fn tcp_server_accept_loop( loop { let (mut server_stream, socket_addr) = listener.accept().await.unwrap(); + if !whitelisted_ip.is_whitelisted(&socket_addr.ip()) { + if debug { + println!( + "Incoming connection from {} is not whitelisted. Closing it", + socket_addr + ); + } + + let _ = server_stream.shutdown().await; + continue; + } + let remote_tcp_connection_result = tokio::time::timeout( app.connection_settings.remote_connect_timeout, TcpStream::connect(remote_addr),