diff --git a/.github/workflows/rust_build.yml b/.github/workflows/rust_build.yml index 3bd52d8..d4d8c1a 100644 --- a/.github/workflows/rust_build.yml +++ b/.github/workflows/rust_build.yml @@ -36,3 +36,7 @@ jobs: run: cargo test - name: "Build in release mode" run: cargo build --release + - name: "Run CLI in debug and release mode" # tests clap, mainly + run: | + cargo run --release -- --help + cargo run -- --help diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..78d94de --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,67 @@ +use std::collections::BTreeMap; +use std::net::SocketAddr; +use std::path::PathBuf; +use std::time::Duration; + +use concread::arcache::ARCache; +use hashbrown::HashSet; +use ldap3_proto::{LdapFilter, LdapSearchScope}; +use openssl::ssl::SslConnector; +use serde::Deserialize; +use url::Url; + +pub mod proxy; + +use crate::proxy::{CachedValue, SearchCacheKey}; + +const MEGABYTES: usize = 1048576; + +pub struct AppState { + pub tls_params: SslConnector, + pub addrs: Vec, + // Cache later here. + pub binddn_map: BTreeMap, + pub cache: ARCache, + pub cache_entry_timeout: Duration, + pub max_incoming_ber_size: Option, + pub max_proxy_ber_size: Option, + pub allow_all_bind_dns: bool, +} + +#[derive(Debug, Clone, Deserialize, Default)] +pub struct DnConfig { + #[serde(default)] + pub allowed_queries: HashSet<(String, LdapSearchScope, LdapFilter)>, +} + +fn default_cache_bytes() -> usize { + 128 * MEGABYTES +} + +fn default_cache_entry_timeout() -> u64 { + 1800 +} + +#[derive(Debug, Deserialize)] +pub struct Config { + pub bind: SocketAddr, + pub tls_key: PathBuf, + pub tls_chain: PathBuf, + + #[serde(default = "default_cache_bytes")] + pub cache_bytes: usize, + #[serde(default = "default_cache_entry_timeout")] + pub cache_entry_timeout: u64, + + pub ldap_ca: PathBuf, + pub ldap_url: Url, + + pub max_incoming_ber_size: Option, + pub max_proxy_ber_size: Option, + + #[serde(default)] + pub allow_all_bind_dns: bool, + + #[serde(flatten)] + pub binddn_map: BTreeMap, +} diff --git a/src/main.rs b/src/main.rs index b3c0942..fc79efe 100644 --- a/src/main.rs +++ b/src/main.rs @@ -14,20 +14,15 @@ static ALLOC: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc; use clap::Parser; -use hashbrown::HashSet; use ldap3_proto::LdapCodec; -use ldap3_proto::{LdapFilter, LdapSearchScope}; -use serde::Deserialize; -use std::collections::BTreeMap; +use ldap_proxy::{AppState, Config}; use std::fs::File; use std::io::Read; -use std::net::SocketAddr; use std::path::PathBuf; use std::sync::Arc; use std::time::Duration; use tokio::sync::broadcast; use tracing_forest::{traits::*, util::*}; -use url::Url; use openssl::ssl::{Ssl, SslAcceptor, SslConnector, SslFiletype, SslMethod, SslVerifyMode}; use openssl::x509::X509; @@ -35,14 +30,10 @@ use tokio::net::TcpListener; use tokio_openssl::SslStream; use tokio_util::codec::{FramedRead, FramedWrite}; -use concread::arcache::{ARCache, ARCacheBuilder}; - -mod proxy; - -use crate::proxy::{CachedValue, SearchCacheKey}; +use concread::arcache::ARCacheBuilder; +use ldap_proxy::proxy::client_process; const DEFAULT_CONFIG_PATH: &str = "/etc/kanidm/ldap-proxy"; -const MEGABYTES: usize = 1048576; #[derive(Debug, clap::Parser)] struct Opt { @@ -53,56 +44,6 @@ struct Opt { config: PathBuf, } -fn default_cache_bytes() -> usize { - 128 * MEGABYTES -} - -fn default_cache_entry_timeout() -> u64 { - 1800 -} - -#[derive(Debug, Deserialize)] -struct Config { - bind: SocketAddr, - tls_key: PathBuf, - tls_chain: PathBuf, - - #[serde(default = "default_cache_bytes")] - cache_bytes: usize, - #[serde(default = "default_cache_entry_timeout")] - cache_entry_timeout: u64, - - ldap_ca: PathBuf, - ldap_url: Url, - - max_incoming_ber_size: Option, - max_proxy_ber_size: Option, - - #[serde(default)] - allow_all_bind_dns: bool, - - #[serde(flatten)] - binddn_map: BTreeMap, -} - -#[derive(Debug, Clone, Deserialize, Default)] -struct DnConfig { - #[serde(default)] - allowed_queries: HashSet<(String, LdapSearchScope, LdapFilter)>, -} - -pub(crate) struct AppState { - pub tls_params: SslConnector, - pub addrs: Vec, - // Cache later here. - pub binddn_map: BTreeMap, - pub cache: ARCache, - pub cache_entry_timeout: Duration, - pub max_incoming_ber_size: Option, - pub max_proxy_ber_size: Option, - pub allow_all_bind_dns: bool, -} - async fn ldaps_acceptor( listener: TcpListener, tls_parms: SslAcceptor, @@ -135,7 +76,7 @@ async fn ldaps_acceptor( let r = FramedRead::new(r, LdapCodec::new(max_incoming_ber_size)); let w = FramedWrite::new(w, LdapCodec::new(max_incoming_ber_size)); let c_app_state = app_state.clone(); - tokio::spawn(proxy::client_process(r, w, client_socket_addr, c_app_state)); + tokio::spawn(client_process(r, w, client_socket_addr, c_app_state)); } Err(e) => { error!("LDAP acceptor error, continuing -> {:?}", e); diff --git a/src/proxy.rs b/src/proxy.rs index 16f0147..d143b62 100644 --- a/src/proxy.rs +++ b/src/proxy.rs @@ -35,14 +35,14 @@ pub struct SearchCacheKey { #[derive(Debug, Clone)] pub struct CachedValue { - valid_until: Instant, - entries: Vec<(LdapSearchResultEntry, Vec)>, - result: LdapResult, - ctrl: Vec, + pub valid_until: Instant, + pub entries: Vec<(LdapSearchResultEntry, Vec)>, + pub result: LdapResult, + pub ctrl: Vec, } impl CachedValue { - fn size(&self) -> usize { + pub fn size(&self) -> usize { std::mem::size_of::() + self.entries.iter().map(|(e, _)| e.size()).sum::() } } @@ -72,7 +72,7 @@ fn bind_operror(msgid: i32, msg: &str) -> LdapMsg { } } -pub(crate) async fn client_process( +pub async fn client_process( mut r: FramedRead, mut w: FramedWrite, client_address: SocketAddr, @@ -423,14 +423,14 @@ pub(crate) async fn client_process( } #[derive(Debug, Clone)] -enum LdapError { +pub enum LdapError { TlsError, ConnectError, Transport, InvalidProtocolState, } -struct BasicLdapClient { +pub struct BasicLdapClient { r: FramedRead, w: FramedWrite, msg_counter: i32, diff --git a/tests/test_config.toml b/tests/test_config.toml new file mode 100644 index 0000000..816f498 --- /dev/null +++ b/tests/test_config.toml @@ -0,0 +1,26 @@ + +bind = "127.0.0.1:3636" +tls_chain = "/etc/ldap-proxy/chain.pem" +tls_key = "/etc/ldap-proxy/key.pem" + +ldap_ca = "/etc/ldap-proxy/ldap-ca.pem" +ldap_url = "ldaps://ldap.example.com" + +[""] +allowed_queries = [["", "base", "(objectclass=*)"]] + +["cn=John Cena,dc=dooo,dc=do,dc=do,dc=doooooo"] +allowed_queries = [ + [ + "", + "base", + "(objectclass=*)", + ], + [ + "o=kanidm", + "subtree", + "(objectclass=*)", + ], +] + +["cn=Administrator"] diff --git a/tests/tests.rs b/tests/tests.rs new file mode 100644 index 0000000..a66097f --- /dev/null +++ b/tests/tests.rs @@ -0,0 +1,37 @@ +// use ldap_proxy::proxy::BasicLdapClient; + +use ldap3_proto::proto::LdapResult; +use ldap_proxy::proxy::CachedValue; +use ldap_proxy::Config; +use std::time::{Duration, Instant}; + +#[test] +fn hello_world() { + assert_eq!(2 + 2, 4); +} + +#[test] +fn test_config_load() { + assert!(toml::from_str::("").is_err()); + + assert!(toml::from_str::(include_str!("test_config.toml")).is_ok()); + let config = toml::from_str::(include_str!("test_config.toml")).unwrap(); + + assert_eq!(config.ldap_ca.to_str(), Some("/etc/ldap-proxy/ldap-ca.pem")); +} + +#[test] +fn test_cachedvalue() { + let cv = CachedValue { + valid_until: Instant::now() + Duration::from_secs(60), + entries: Vec::with_capacity(5), + result: LdapResult { + code: ldap3_proto::LdapResultCode::Busy, + matcheddn: "dn=doo".to_string(), + message: "ohno".to_string(), + referral: Vec::with_capacity(5), + }, + ctrl: Vec::with_capacity(5), + }; + assert_eq!(cv.size(), 144); +}