diff --git a/Cargo.lock b/Cargo.lock index 2ac5f2334..9fe67645c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5781,11 +5781,13 @@ dependencies = [ "console", "copy_dir", "docker-compose-runner", + "futures", "scylla", "serde", "strum 0.25.0", "time", "tokio", + "tokio-stream", ] [[package]] diff --git a/windsock/Cargo.toml b/windsock/Cargo.toml index 7686c1f2b..fffbcdeaf 100644 --- a/windsock/Cargo.toml +++ b/windsock/Cargo.toml @@ -13,10 +13,12 @@ bincode.workspace = true clap.workspace = true console = "0.15.5" copy_dir = "0.1.2" +futures.workspace = true serde = { workspace = true, features = ["derive"] } strum = { version = "0.25.0", features = ["derive"] } time = { version = "0.3.25", features = ["serde"] } tokio.workspace = true +tokio-stream = "0.1.14" [dev-dependencies] scylla = { version = "0.10.0", features = ["ssl"] } diff --git a/windsock/src/bench.rs b/windsock/src/bench.rs index 733d7eb05..a8523992a 100644 --- a/windsock/src/bench.rs +++ b/windsock/src/bench.rs @@ -3,6 +3,7 @@ use crate::report::{report_builder, Report, ReportArchive}; use crate::tables::ReportColumn; use anyhow::Result; use async_trait::async_trait; +use futures::StreamExt; use serde::{Deserialize, Serialize}; use std::collections::{HashMap, HashSet}; use std::fmt::Write; @@ -10,6 +11,8 @@ use std::path::PathBuf; use std::time::{Duration, Instant}; use tokio::sync::mpsc::UnboundedSender; use tokio::task::JoinHandle; +use tokio_stream::wrappers::IntervalStream; +use tokio_stream::Stream; pub struct BenchState { bench: Box, @@ -284,45 +287,71 @@ pub trait BenchTask: Clone + Send + Sync + 'static { reporter: UnboundedSender, operations_per_second: Option, ) -> Vec> { - let mut tasks = vec![]; - // 100 is a generally nice amount of tasks to have, but if we have more tasks than OPS the throughput is very unstable - let task_count = operations_per_second.map(|x| x.min(100)).unwrap_or(100); - - let allocated_time_per_op = operations_per_second - .map(|ops| (Duration::from_secs(1) * task_count as u32) / ops as u32); - for i in 0..task_count { - let task = self.clone(); - let reporter = reporter.clone(); - tasks.push(tokio::spawn(async move { - // spread load out over a second - tokio::time::sleep(Duration::from_nanos((1_000_000_000 / task_count) * i)).await; - - let mut interval = allocated_time_per_op.map(tokio::time::interval); - - loop { - if let Some(interval) = &mut interval { - interval.tick().await; - } - - let operation_start = Instant::now(); - let report = match task.run_one_operation().await { - Ok(()) => Report::QueryCompletedIn(operation_start.elapsed()), - Err(message) => Report::QueryErrored { - completed_in: operation_start.elapsed(), - message, - }, - }; - if reporter.send(report).is_err() { - // The benchmark has completed and the reporter no longer wants to receive reports so just shutdown - return; - } - } - })); + match operations_per_second { + Some(ops) => { + spawn_tasks_inner( + self.clone(), + interval_stream(ops), + reporter, + operations_per_second, + ) + .await + } + None => { + spawn_tasks_inner( + self.clone(), + futures::stream::repeat_with(|| ()), + reporter, + operations_per_second, + ) + .await + } } + } +} - // sleep until all tasks have started running - tokio::time::sleep(Duration::from_secs(1)).await; +async fn spawn_tasks_inner( + task: T, + stream: impl Stream + std::marker::Unpin + Send + 'static, + reporter: UnboundedSender, + operations_per_second: Option, +) -> Vec> { + // TODO: remove vec + vec![tokio::spawn(async move { + // 100 is a generally nice amount of tasks to have, but if we have more tasks than OPS the throughput is very unstable + let task_count = operations_per_second.map(|x| x.min(500)).unwrap_or(500); + + let mut result_stream = stream + .map(|_| async { + let start = Instant::now(); + let task = task.clone(); + let result = tokio::task::spawn(async move { task.run_one_operation().await }) + .await + .unwrap(); + (result, start.elapsed()) + }) + .buffer_unordered(task_count as usize); + + while let Some((res, elapsed)) = result_stream.next().await { + let report = match res { + Ok(()) => Report::QueryCompletedIn(elapsed), + Err(message) => Report::QueryErrored { + completed_in: elapsed, + message, + }, + }; + if reporter.send(report).is_err() { + // The benchmark has completed and the reporter no longer wants to receive reports so just shutdown + break; + } + } + })] +} - tasks - } +/// Create a stream that emits the configured events per second +fn interval_stream(events_per_second: u64) -> IntervalStream { + let mut interval = + tokio::time::interval(Duration::from_nanos(1_000_000_000 / events_per_second)); + interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay); + IntervalStream::new(interval) }