From e773d7a0d8f0e256749ee376587a2de8f58609d8 Mon Sep 17 00:00:00 2001 From: Ben Linsay Date: Tue, 29 Jun 2021 00:26:48 -0400 Subject: [PATCH] Use a threadpool to handle incoming requests Adds a pool of worker threads to handle incoming requests. Uses a bounded channel to pass incoming requests to each worker. Closes #24 --- Cargo.lock | 13 ++++++ Cargo.toml | 2 + src/main.rs | 123 ++++++++++++++++++++++++++++++++++++++++------------ 3 files changed, 110 insertions(+), 28 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 2569f98..f710f6d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1,5 +1,7 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. +version = 3 + [[package]] name = "anyhow" version = "1.0.32" @@ -385,12 +387,14 @@ dependencies = [ "anyhow", "atoi", "criterion", + "crossbeam-channel", "nix", "num-derive", "num-traits", "slog", "slog-async", "slog-term", + "threadpool", ] [[package]] @@ -735,6 +739,15 @@ dependencies = [ "lazy_static", ] +[[package]] +name = "threadpool" +version = "1.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d050e60b33d41c19108b32cea32164033a9013fe3b46cbd4457559bfbf77afaa" +dependencies = [ + "num_cpus", +] + [[package]] name = "time" version = "0.1.44" diff --git a/Cargo.toml b/Cargo.toml index e7ba67c..8041aea 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,9 +18,11 @@ atoi = "^0.3" slog = "^2.5" slog-async = "^2.5" slog-term = "^2.6" +crossbeam-channel = "^0.4" nix = "^0.18" num-derive = "^0.2" num-traits = "^0.2" +threadpool = "^1.8" [dev-dependencies] criterion = "^0.3" diff --git a/src/main.rs b/src/main.rs index e2e472d..637e1b4 100644 --- a/src/main.rs +++ b/src/main.rs @@ -49,15 +49,95 @@ use std::io::ErrorKind; use std::os::unix::fs::PermissionsExt; use std::os::unix::net::{UnixListener, UnixStream}; use std::path::Path; -use std::thread; +use std::time::Duration; use anyhow::{Context, Result}; +use crossbeam_channel as channel; use slog::{debug, error, o, Drain}; +use threadpool::ThreadPool; mod ffi; mod handlers; mod protocol; +fn main() -> Result<()> { + const SOCKET_PATH: &str = "/var/run/nscd/socket"; + const N_WORKERS: usize = 256; + const HANDOFF_TIMEOUT: Duration = Duration::from_secs(3); + + ffi::disable_internal_nscd(); + + let decorator = slog_term::TermDecorator::new().build(); + let drain = slog_term::FullFormat::new(decorator).build().fuse(); + let drain = slog_async::Async::new(drain).build().fuse(); + + let logger = slog::Logger::root(drain, slog::o!()); + + let (pool, handle) = worker_pool(logger.clone(), N_WORKERS); + + let path = Path::new(SOCKET_PATH); + std::fs::create_dir_all(path.parent().expect("socket path has no parent"))?; + std::fs::remove_file(path).ok(); + let listener = UnixListener::bind(path).context("could not bind to socket")?; + std::fs::set_permissions(path, std::fs::Permissions::from_mode(0o777))?; + + for stream in listener.incoming() { + match stream { + Ok(stream) => { + // if something goes wrong and it's multiple seconds until we + // get a response, kill the process. + // + // the timeout here is set such that nss will fall back to system + // libc before this timeout is hit - clients will already be + // giving up and going elsewhere so crashing the process should + // not make a bad situation worse. + match handle.send_timeout(stream, HANDOFF_TIMEOUT) { + Err(channel::SendTimeoutError::Timeout(_)) => { + anyhow::bail!("timed out waiting for an available worker: exiting") + } + Err(channel::SendTimeoutError::Disconnected(_)) => { + anyhow::bail!("aborting: worker channel is disconnected") + } + _ => { /*ok!*/ } + } + } + Err(err) => { + error!(logger, "error accepting connection"; "err" => %err); + break; + } + } + } + + // drop the worker handle so that the worker pool shuts down. every worker + // task should break and the process should exit. + std::mem::drop(handle); + pool.join(); + + Ok(()) +} + +fn worker_pool(log: slog::Logger, n_workers: usize) -> (ThreadPool, channel::Sender) { + let pool = ThreadPool::new(n_workers); + let (tx, rx) = channel::bounded(0); + + // TODO: figure out how to name the worker threads in each worker + // TODO: report actively working threads + for _ in 0..n_workers { + let log = log.clone(); + let rx = rx.clone(); + + pool.execute(move || loop { + let log = log.clone(); + match rx.recv() { + Ok(stream) => handle_stream(log, stream), + Err(channel::RecvError) => break, + } + }); + } + + (pool, tx) +} + /// Handle a new socket connection, reading the request and sending the response. fn handle_stream(log: slog::Logger, mut stream: UnixStream) { debug!(log, "accepted connection"; "stream" => ?stream); @@ -101,35 +181,22 @@ fn handle_stream(log: slog::Logger, mut stream: UnixStream) { } } -const SOCKET_PATH: &str = "/var/run/nscd/socket"; - -fn main() -> Result<()> { - ffi::disable_internal_nscd(); - - let decorator = slog_term::TermDecorator::new().build(); - let drain = slog_term::FullFormat::new(decorator).build().fuse(); - let drain = slog_async::Async::new(drain).build().fuse(); +#[cfg(test)] +mod pool_test { + use super::*; - let logger = slog::Logger::root(drain, slog::o!()); + #[test] + fn worker_shutdown() { + let logger = { + let decorator = slog_term::PlainDecorator::new(slog_term::TestStdoutWriter); + let drain = slog_term::FullFormat::new(decorator).build().fuse(); + let drain = slog_async::Async::new(drain).build().fuse(); + slog::Logger::root(drain, slog::o!()) + }; - let path = Path::new(SOCKET_PATH); - std::fs::create_dir_all(path.parent().expect("socket path has no parent"))?; - std::fs::remove_file(path).ok(); - let listener = UnixListener::bind(path).context("could not bind to socket")?; - std::fs::set_permissions(path, std::fs::Permissions::from_mode(0o777))?; + let (pool, handle) = worker_pool(logger, 123); - for stream in listener.incoming() { - match stream { - Ok(stream) => { - let thread_logger = logger.clone(); - thread::spawn(move || handle_stream(thread_logger, stream)); - } - Err(err) => { - error!(logger, "error accepting connection"; "err" => %err); - break; - } - } + std::mem::drop(handle); + pool.join(); } - - Ok(()) }