Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Moar (ok some) tests #23

Merged
merged 2 commits into from
Mar 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .github/workflows/rust_build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,7 @@ jobs:
run: cargo test
- name: "Build in release mode"
run: cargo build --release
- name: "Run CLI in debug and release mode" # tests clap, mainly
run: |
cargo run --release -- --help
cargo run -- --help
67 changes: 67 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
use std::collections::BTreeMap;
use std::net::SocketAddr;
use std::path::PathBuf;
use std::time::Duration;

use concread::arcache::ARCache;
use hashbrown::HashSet;
use ldap3_proto::{LdapFilter, LdapSearchScope};
use openssl::ssl::SslConnector;
use serde::Deserialize;
use url::Url;

pub mod proxy;

use crate::proxy::{CachedValue, SearchCacheKey};

const MEGABYTES: usize = 1048576;

pub struct AppState {
pub tls_params: SslConnector,
pub addrs: Vec<SocketAddr>,
// Cache later here.
pub binddn_map: BTreeMap<String, DnConfig>,
pub cache: ARCache<SearchCacheKey, CachedValue>,
pub cache_entry_timeout: Duration,
pub max_incoming_ber_size: Option<usize>,
pub max_proxy_ber_size: Option<usize>,
pub allow_all_bind_dns: bool,
}

#[derive(Debug, Clone, Deserialize, Default)]
pub struct DnConfig {
#[serde(default)]
pub allowed_queries: HashSet<(String, LdapSearchScope, LdapFilter)>,
}

fn default_cache_bytes() -> usize {
128 * MEGABYTES
}

fn default_cache_entry_timeout() -> u64 {
1800
}

#[derive(Debug, Deserialize)]
pub struct Config {
pub bind: SocketAddr,
pub tls_key: PathBuf,
pub tls_chain: PathBuf,

#[serde(default = "default_cache_bytes")]
pub cache_bytes: usize,
#[serde(default = "default_cache_entry_timeout")]
pub cache_entry_timeout: u64,

pub ldap_ca: PathBuf,
pub ldap_url: Url,

pub max_incoming_ber_size: Option<usize>,
pub max_proxy_ber_size: Option<usize>,

#[serde(default)]
pub allow_all_bind_dns: bool,

#[serde(flatten)]
pub binddn_map: BTreeMap<String, DnConfig>,
}
67 changes: 4 additions & 63 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,35 +14,26 @@
static ALLOC: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc;

use clap::Parser;
use hashbrown::HashSet;
use ldap3_proto::LdapCodec;
use ldap3_proto::{LdapFilter, LdapSearchScope};
use serde::Deserialize;
use std::collections::BTreeMap;
use ldap_proxy::{AppState, Config};
use std::fs::File;
use std::io::Read;
use std::net::SocketAddr;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::broadcast;
use tracing_forest::{traits::*, util::*};
use url::Url;

use openssl::ssl::{Ssl, SslAcceptor, SslConnector, SslFiletype, SslMethod, SslVerifyMode};
use openssl::x509::X509;
use tokio::net::TcpListener;
use tokio_openssl::SslStream;
use tokio_util::codec::{FramedRead, FramedWrite};

use concread::arcache::{ARCache, ARCacheBuilder};

mod proxy;

use crate::proxy::{CachedValue, SearchCacheKey};
use concread::arcache::ARCacheBuilder;
use ldap_proxy::proxy::client_process;

const DEFAULT_CONFIG_PATH: &str = "/etc/kanidm/ldap-proxy";
const MEGABYTES: usize = 1048576;

#[derive(Debug, clap::Parser)]
struct Opt {
Expand All @@ -53,56 +44,6 @@ struct Opt {
config: PathBuf,
}

fn default_cache_bytes() -> usize {
128 * MEGABYTES
}

fn default_cache_entry_timeout() -> u64 {
1800
}

#[derive(Debug, Deserialize)]
struct Config {
bind: SocketAddr,
tls_key: PathBuf,
tls_chain: PathBuf,

#[serde(default = "default_cache_bytes")]
cache_bytes: usize,
#[serde(default = "default_cache_entry_timeout")]
cache_entry_timeout: u64,

ldap_ca: PathBuf,
ldap_url: Url,

max_incoming_ber_size: Option<usize>,
max_proxy_ber_size: Option<usize>,

#[serde(default)]
allow_all_bind_dns: bool,

#[serde(flatten)]
binddn_map: BTreeMap<String, DnConfig>,
}

#[derive(Debug, Clone, Deserialize, Default)]
struct DnConfig {
#[serde(default)]
allowed_queries: HashSet<(String, LdapSearchScope, LdapFilter)>,
}

pub(crate) struct AppState {
pub tls_params: SslConnector,
pub addrs: Vec<SocketAddr>,
// Cache later here.
pub binddn_map: BTreeMap<String, DnConfig>,
pub cache: ARCache<SearchCacheKey, CachedValue>,
pub cache_entry_timeout: Duration,
pub max_incoming_ber_size: Option<usize>,
pub max_proxy_ber_size: Option<usize>,
pub allow_all_bind_dns: bool,
}

async fn ldaps_acceptor(
listener: TcpListener,
tls_parms: SslAcceptor,
Expand Down Expand Up @@ -135,7 +76,7 @@ async fn ldaps_acceptor(
let r = FramedRead::new(r, LdapCodec::new(max_incoming_ber_size));
let w = FramedWrite::new(w, LdapCodec::new(max_incoming_ber_size));
let c_app_state = app_state.clone();
tokio::spawn(proxy::client_process(r, w, client_socket_addr, c_app_state));
tokio::spawn(client_process(r, w, client_socket_addr, c_app_state));
}
Err(e) => {
error!("LDAP acceptor error, continuing -> {:?}", e);
Expand Down
16 changes: 8 additions & 8 deletions src/proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,14 @@ pub struct SearchCacheKey {

#[derive(Debug, Clone)]
pub struct CachedValue {
valid_until: Instant,
entries: Vec<(LdapSearchResultEntry, Vec<LdapControl>)>,
result: LdapResult,
ctrl: Vec<LdapControl>,
pub valid_until: Instant,
pub entries: Vec<(LdapSearchResultEntry, Vec<LdapControl>)>,
pub result: LdapResult,
pub ctrl: Vec<LdapControl>,
}

impl CachedValue {
fn size(&self) -> usize {
pub fn size(&self) -> usize {
std::mem::size_of::<Self>() + self.entries.iter().map(|(e, _)| e.size()).sum::<usize>()
}
}
Expand Down Expand Up @@ -72,7 +72,7 @@ fn bind_operror(msgid: i32, msg: &str) -> LdapMsg {
}
}

pub(crate) async fn client_process<W: AsyncWrite + Unpin, R: AsyncRead + Unpin>(
pub async fn client_process<W: AsyncWrite + Unpin, R: AsyncRead + Unpin>(
mut r: FramedRead<R, LdapCodec>,
mut w: FramedWrite<W, LdapCodec>,
client_address: SocketAddr,
Expand Down Expand Up @@ -423,14 +423,14 @@ pub(crate) async fn client_process<W: AsyncWrite + Unpin, R: AsyncRead + Unpin>(
}

#[derive(Debug, Clone)]
enum LdapError {
pub enum LdapError {
TlsError,
ConnectError,
Transport,
InvalidProtocolState,
}

struct BasicLdapClient {
pub struct BasicLdapClient {
r: FramedRead<CR, LdapCodec>,
w: FramedWrite<CW, LdapCodec>,
msg_counter: i32,
Expand Down
26 changes: 26 additions & 0 deletions tests/test_config.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@

bind = "127.0.0.1:3636"
tls_chain = "/etc/ldap-proxy/chain.pem"
tls_key = "/etc/ldap-proxy/key.pem"

ldap_ca = "/etc/ldap-proxy/ldap-ca.pem"
ldap_url = "ldaps://ldap.example.com"

[""]
allowed_queries = [["", "base", "(objectclass=*)"]]

["cn=John Cena,dc=dooo,dc=do,dc=do,dc=doooooo"]
allowed_queries = [
[
"",
"base",
"(objectclass=*)",
],
[
"o=kanidm",
"subtree",
"(objectclass=*)",
],
]

["cn=Administrator"]
37 changes: 37 additions & 0 deletions tests/tests.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// use ldap_proxy::proxy::BasicLdapClient;

use ldap3_proto::proto::LdapResult;
use ldap_proxy::proxy::CachedValue;
use ldap_proxy::Config;
use std::time::{Duration, Instant};

#[test]
fn hello_world() {
assert_eq!(2 + 2, 4);
}

#[test]
fn test_config_load() {
assert!(toml::from_str::<Config>("").is_err());

assert!(toml::from_str::<Config>(include_str!("test_config.toml")).is_ok());
let config = toml::from_str::<Config>(include_str!("test_config.toml")).unwrap();

assert_eq!(config.ldap_ca.to_str(), Some("/etc/ldap-proxy/ldap-ca.pem"));
}

#[test]
fn test_cachedvalue() {
let cv = CachedValue {
valid_until: Instant::now() + Duration::from_secs(60),
entries: Vec::with_capacity(5),
result: LdapResult {
code: ldap3_proto::LdapResultCode::Busy,
matcheddn: "dn=doo".to_string(),
message: "ohno".to_string(),
referral: Vec::with_capacity(5),
},
ctrl: Vec::with_capacity(5),
};
assert_eq!(cv.size(), 144);
}