Skip to content

Commit f272dec

Browse files
authored
Merge pull request #229 from elnosh/use-tasktracker
fix memory buildup from JoinSet
2 parents a1ca5ac + 3a94d61 commit f272dec

File tree

6 files changed

+70
-78
lines changed

6 files changed

+70
-78
lines changed

Cargo.lock

Lines changed: 4 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

sim-cli/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ rand = "0.8.5"
2525
hex = {version = "0.4.3"}
2626
futures = "0.3.30"
2727
console-subscriber = { version = "0.4.0", optional = true}
28+
tokio-util = { version = "0.7.13", features = ["rt"] }
2829

2930
[features]
3031
dev = ["console-subscriber"]

sim-cli/src/main.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ use std::collections::HashMap;
1212
use std::path::PathBuf;
1313
use std::sync::Arc;
1414
use tokio::sync::Mutex;
15+
use tokio_util::task::TaskTracker;
1516

1617
/// The default directory where the simulation files are stored and where the results will be written to.
1718
pub const DEFAULT_DATA_DIR: &str = ".";
@@ -209,6 +210,7 @@ async fn main() -> anyhow::Result<()> {
209210
None
210211
};
211212

213+
let tasks = TaskTracker::new();
212214
let sim = Simulation::new(
213215
SimulationCfg::new(
214216
cli.total_time,
@@ -219,6 +221,7 @@ async fn main() -> anyhow::Result<()> {
219221
),
220222
clients,
221223
validated_activities,
224+
tasks,
222225
);
223226
let sim2 = sim.clone();
224227

simln-lib/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ serde_millis = "0.1.1"
3232
rand_distr = "0.4.3"
3333
mockall = "0.12.1"
3434
rand_chacha = "0.3.1"
35+
tokio-util = { version = "0.7.13", features = ["rt"] }
3536

3637
[dev-dependencies]
3738
ntest = "0.9.0"

simln-lib/src/lib.rs

Lines changed: 46 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ use std::{collections::HashMap, sync::Arc, time::SystemTime};
1919
use thiserror::Error;
2020
use tokio::sync::mpsc::{channel, Receiver, Sender};
2121
use tokio::sync::Mutex;
22-
use tokio::task::JoinSet;
2322
use tokio::{select, time, time::Duration};
23+
use tokio_util::task::TaskTracker;
2424
use triggered::{Listener, Trigger};
2525

2626
use self::defined_activity::DefinedPaymentActivity;
@@ -518,6 +518,9 @@ pub struct Simulation {
518518
activity: Vec<ActivityDefinition>,
519519
/// Results logger that holds the simulation statistics.
520520
results: Arc<Mutex<PaymentResultLogger>>,
521+
/// Tracks all tasks spawned for use in the simulation. Once the simulation has run without error, the
522+
/// `run` method closes this tracker and waits for these tasks to complete before returning.
523+
tasks: TaskTracker,
521524
/// High level triggers used to manage simulation tasks and shutdown.
522525
shutdown_trigger: Trigger,
523526
shutdown_listener: Listener,
@@ -546,13 +549,15 @@ impl Simulation {
546549
cfg: SimulationCfg,
547550
nodes: HashMap<PublicKey, Arc<Mutex<dyn LightningNode>>>,
548551
activity: Vec<ActivityDefinition>,
552+
tasks: TaskTracker,
549553
) -> Self {
550554
let (shutdown_trigger, shutdown_listener) = triggered::trigger();
551555
Self {
552556
cfg,
553557
nodes,
554558
activity,
555559
results: Arc::new(Mutex::new(PaymentResultLogger::new())),
560+
tasks,
556561
shutdown_trigger,
557562
shutdown_listener,
558563
}
@@ -644,7 +649,19 @@ impl Simulation {
644649
Ok(())
645650
}
646651

652+
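/// Runs until the simulation completes or we hit an error.
653+
/// Note that, when the simulation runs without error, it will wait for the tasks in `self.tasks` to
654+
/// complete before returning; if an error occurs, it is propagated immediately without waiting.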
647655
pub async fn run(&self) -> Result<(), SimulationError> {
656+
self.internal_run().await?;
657+
// Close our TaskTracker and wait for any background tasks
658+
// spawned during internal_run to complete.
659+
self.tasks.close();
660+
self.tasks.wait().await;
661+
Ok(())
662+
}
663+
664+
async fn internal_run(&self) -> Result<(), SimulationError> {
648665
if let Some(total_time) = self.cfg.total_time {
649666
log::info!("Running the simulation for {}s.", total_time.as_secs());
650667
} else {
@@ -659,7 +676,6 @@ impl Simulation {
659676
self.activity.len(),
660677
self.nodes.len()
661678
);
662-
let mut tasks = JoinSet::new();
663679

664680
// Before we start the simulation up, start tasks that will be responsible for gathering simulation data.
665681
// The event channels are shared across our functionality:
@@ -668,21 +684,15 @@ impl Simulation {
668684
// - Event Receiver: used by data reporting to receive events that have been simulated that need to be
669685
// tracked and recorded.
670686
let (event_sender, event_receiver) = channel(1);
671-
self.run_data_collection(event_receiver, &mut tasks);
687+
self.run_data_collection(event_receiver, &self.tasks);
672688

673689
// Get an execution kit per activity that we need to generate and spin up consumers for each source node.
674690
let activities = match self.activity_executors().await {
675691
Ok(a) => a,
676692
Err(e) => {
677693
// If we encounter an error while setting up the activity_executors,
678-
// we need to shutdown and wait for tasks to finish. We have started background tasks in the
679-
// run_data_collection function, so we should shut those down before returning.
694+
// we need to shut down and return.
680695
self.shutdown();
681-
while let Some(res) = tasks.join_next().await {
682-
if let Err(e) = res {
683-
log::error!("Task exited with error: {e}.");
684-
}
685-
}
686696
return Err(e);
687697
},
688698
};
@@ -692,40 +702,30 @@ impl Simulation {
692702
.map(|generator| generator.source_info.pubkey)
693703
.collect(),
694704
event_sender.clone(),
695-
&mut tasks,
705+
&self.tasks,
696706
);
697707

698708
// Next, we'll spin up our actual producers that will be responsible for triggering the configured activity.
699-
// The producers will use their own JoinSet so that the simulation can be shutdown if they all finish.
700-
let mut producer_tasks = JoinSet::new();
709+
// The producers will use their own TaskTracker so that the simulation can be shut down if they all finish.
710+
let producer_tasks = TaskTracker::new();
701711
match self
702-
.dispatch_producers(activities, consumer_channels, &mut producer_tasks)
712+
.dispatch_producers(activities, consumer_channels, &producer_tasks)
703713
.await
704714
{
705715
Ok(_) => {},
706716
Err(e) => {
707-
// If we encounter an error in dispatch_producers, we need to shutdown and wait for tasks to finish.
708-
// We have started background tasks in the run_data_collection function,
709-
// so we should shut those down before returning.
717+
// If we encounter an error in dispatch_producers, we need to shut down and return.
710718
self.shutdown();
711-
while let Some(res) = tasks.join_next().await {
712-
if let Err(e) = res {
713-
log::error!("Task exited with error: {e}.");
714-
}
715-
}
716719
return Err(e);
717720
},
718721
}
719722

720723
// Start a task that waits for the producers to finish.
721724
// If all producers finish, then there is nothing left to do and the simulation can be shutdown.
722725
let producer_trigger = self.shutdown_trigger.clone();
723-
tasks.spawn(async move {
724-
while let Some(res) = producer_tasks.join_next().await {
725-
if let Err(e) = res {
726-
log::error!("Producer exited with error: {e}.");
727-
}
728-
}
726+
self.tasks.spawn(async move {
727+
producer_tasks.close();
728+
producer_tasks.wait().await;
729729
log::info!("All producers finished. Shutting down.");
730730
producer_trigger.trigger()
731731
});
@@ -735,7 +735,7 @@ impl Simulation {
735735
let t = self.shutdown_trigger.clone();
736736
let l = self.shutdown_listener.clone();
737737

738-
tasks.spawn(async move {
738+
self.tasks.spawn(async move {
739739
if time::timeout(total_time, l).await.is_err() {
740740
log::info!(
741741
"Simulation run for {}s. Shutting down.",
@@ -746,18 +746,7 @@ impl Simulation {
746746
});
747747
}
748748

749-
// We always want to wait for all threads to exit, so we wait for all of them to exit and track any errors
750-
// that surface. It's okay if there are multiple and one is overwritten, we just want to know whether we
751-
// exited with an error or not.
752-
let mut success = true;
753-
while let Some(res) = tasks.join_next().await {
754-
if let Err(e) = res {
755-
log::error!("Task exited with error: {e}.");
756-
success = false;
757-
}
758-
}
759-
760-
success.then_some(()).ok_or(SimulationError::TaskError)
749+
Ok(())
761750
}
762751

763752
pub fn shutdown(&self) {
@@ -777,7 +766,7 @@ impl Simulation {
777766
fn run_data_collection(
778767
&self,
779768
output_receiver: Receiver<SimulationOutput>,
780-
tasks: &mut JoinSet<()>,
769+
tasks: &TaskTracker,
781770
) {
782771
let listener = self.shutdown_listener.clone();
783772
let shutdown = self.shutdown_trigger.clone();
@@ -790,11 +779,17 @@ impl Simulation {
790779
// psr: produce simulation results
791780
let psr_listener = listener.clone();
792781
let psr_shutdown = shutdown.clone();
782+
let psr_tasks = tasks.clone();
793783
tasks.spawn(async move {
794784
log::debug!("Starting simulation results producer.");
795-
if let Err(e) =
796-
produce_simulation_results(nodes, output_receiver, results_sender, psr_listener)
797-
.await
785+
if let Err(e) = produce_simulation_results(
786+
nodes,
787+
output_receiver,
788+
results_sender,
789+
psr_listener,
790+
&psr_tasks,
791+
)
792+
.await
798793
{
799794
psr_shutdown.trigger();
800795
log::error!("Produce simulation results exited with error: {e:?}.");
@@ -939,7 +934,7 @@ impl Simulation {
939934
&self,
940935
consuming_nodes: HashSet<PublicKey>,
941936
output_sender: Sender<SimulationOutput>,
942-
tasks: &mut JoinSet<()>,
937+
tasks: &TaskTracker,
943938
) -> HashMap<PublicKey, Sender<SimulationEvent>> {
944939
let mut channels = HashMap::new();
945940

@@ -984,7 +979,7 @@ impl Simulation {
984979
&self,
985980
executors: Vec<ExecutorKit>,
986981
producer_channels: HashMap<PublicKey, Sender<SimulationEvent>>,
987-
tasks: &mut JoinSet<()>,
982+
tasks: &TaskTracker,
988983
) -> Result<(), SimulationError> {
989984
for executor in executors {
990985
let sender = producer_channels.get(&executor.source_info.pubkey).ok_or(
@@ -1350,9 +1345,8 @@ async fn produce_simulation_results(
13501345
mut output_receiver: Receiver<SimulationOutput>,
13511346
results: Sender<(Payment, PaymentResult)>,
13521347
listener: Listener,
1348+
tasks: &TaskTracker,
13531349
) -> Result<(), SimulationError> {
1354-
let mut set = tokio::task::JoinSet::new();
1355-
13561350
let result = loop {
13571351
tokio::select! {
13581352
biased;
@@ -1365,7 +1359,7 @@ async fn produce_simulation_results(
13651359
match simulation_output{
13661360
SimulationOutput::SendPaymentSuccess(payment) => {
13671361
if let Some(source_node) = nodes.get(&payment.source) {
1368-
set.spawn(track_payment_result(
1362+
tasks.spawn(track_payment_result(
13691363
source_node.clone(), results.clone(), payment, listener.clone()
13701364
));
13711365
} else {
@@ -1396,11 +1390,6 @@ async fn produce_simulation_results(
13961390
};
13971391

13981392
log::debug!("Simulation results producer exiting.");
1399-
while let Some(res) = set.join_next().await {
1400-
if let Err(e) = res {
1401-
log::error!("Simulation results producer task exited with error: {e}.");
1402-
}
1403-
}
14041393

14051394
result
14061395
}
@@ -1476,6 +1465,7 @@ mod tests {
14761465
use std::sync::Arc;
14771466
use std::time::Duration;
14781467
use tokio::sync::Mutex;
1468+
use tokio_util::task::TaskTracker;
14791469

14801470
#[test]
14811471
fn create_seeded_mut_rng() {
@@ -1619,6 +1609,7 @@ mod tests {
16191609
crate::SimulationCfg::new(Some(0), 0, 0.0, None, None),
16201610
clients,
16211611
vec![activity_definition],
1612+
TaskTracker::new(),
16221613
);
16231614
assert!(simulation.validate_activity().await.is_err());
16241615
}

0 commit comments

Comments
 (0)