@@ -19,8 +19,8 @@ use std::{collections::HashMap, sync::Arc, time::SystemTime};
19
19
use thiserror:: Error ;
20
20
use tokio:: sync:: mpsc:: { channel, Receiver , Sender } ;
21
21
use tokio:: sync:: Mutex ;
22
- use tokio:: task:: JoinSet ;
23
22
use tokio:: { select, time, time:: Duration } ;
23
+ use tokio_util:: task:: TaskTracker ;
24
24
use triggered:: { Listener , Trigger } ;
25
25
26
26
use self :: defined_activity:: DefinedPaymentActivity ;
@@ -518,6 +518,9 @@ pub struct Simulation {
518
518
activity : Vec < ActivityDefinition > ,
519
519
/// Results logger that holds the simulation statistics.
520
520
results : Arc < Mutex < PaymentResultLogger > > ,
521
+ /// Track all tasks spawned for use in the simulation. When used in the `run` method, it will wait for
522
+ /// these tasks to complete before returning.
523
+ tasks : TaskTracker ,
521
524
/// High level triggers used to manage simulation tasks and shutdown.
522
525
shutdown_trigger : Trigger ,
523
526
shutdown_listener : Listener ,
@@ -546,13 +549,15 @@ impl Simulation {
546
549
cfg : SimulationCfg ,
547
550
nodes : HashMap < PublicKey , Arc < Mutex < dyn LightningNode > > > ,
548
551
activity : Vec < ActivityDefinition > ,
552
+ tasks : TaskTracker ,
549
553
) -> Self {
550
554
let ( shutdown_trigger, shutdown_listener) = triggered:: trigger ( ) ;
551
555
Self {
552
556
cfg,
553
557
nodes,
554
558
activity,
555
559
results : Arc :: new ( Mutex :: new ( PaymentResultLogger :: new ( ) ) ) ,
560
+ tasks,
556
561
shutdown_trigger,
557
562
shutdown_listener,
558
563
}
@@ -644,7 +649,19 @@ impl Simulation {
644
649
Ok ( ( ) )
645
650
}
646
651
652
+ /// run until the simulation completes or we hit an error.
653
+ /// Note that it will wait for the tasks in self.tasks to complete
654
+ /// before returning.
647
655
pub async fn run ( & self ) -> Result < ( ) , SimulationError > {
656
+ self . internal_run ( ) . await ?;
657
+ // Close our TaskTracker and wait for any background tasks
658
+ // spawned during internal_run to complete.
659
+ self . tasks . close ( ) ;
660
+ self . tasks . wait ( ) . await ;
661
+ Ok ( ( ) )
662
+ }
663
+
664
+ async fn internal_run ( & self ) -> Result < ( ) , SimulationError > {
648
665
if let Some ( total_time) = self . cfg . total_time {
649
666
log:: info!( "Running the simulation for {}s." , total_time. as_secs( ) ) ;
650
667
} else {
@@ -659,7 +676,6 @@ impl Simulation {
659
676
self . activity. len( ) ,
660
677
self . nodes. len( )
661
678
) ;
662
- let mut tasks = JoinSet :: new ( ) ;
663
679
664
680
// Before we start the simulation up, start tasks that will be responsible for gathering simulation data.
665
681
// The event channels are shared across our functionality:
@@ -668,21 +684,15 @@ impl Simulation {
668
684
// - Event Receiver: used by data reporting to receive events that have been simulated that need to be
669
685
// tracked and recorded.
670
686
let ( event_sender, event_receiver) = channel ( 1 ) ;
671
- self . run_data_collection ( event_receiver, & mut tasks) ;
687
+ self . run_data_collection ( event_receiver, & self . tasks ) ;
672
688
673
689
// Get an execution kit per activity that we need to generate and spin up consumers for each source node.
674
690
let activities = match self . activity_executors ( ) . await {
675
691
Ok ( a) => a,
676
692
Err ( e) => {
677
693
// If we encounter an error while setting up the activity_executors,
678
- // we need to shutdown and wait for tasks to finish. We have started background tasks in the
679
- // run_data_collection function, so we should shut those down before returning.
694
+ // we need to shutdown and return.
680
695
self . shutdown ( ) ;
681
- while let Some ( res) = tasks. join_next ( ) . await {
682
- if let Err ( e) = res {
683
- log:: error!( "Task exited with error: {e}." ) ;
684
- }
685
- }
686
696
return Err ( e) ;
687
697
} ,
688
698
} ;
@@ -692,40 +702,30 @@ impl Simulation {
692
702
. map ( |generator| generator. source_info . pubkey )
693
703
. collect ( ) ,
694
704
event_sender. clone ( ) ,
695
- & mut tasks,
705
+ & self . tasks ,
696
706
) ;
697
707
698
708
// Next, we'll spin up our actual producers that will be responsible for triggering the configured activity.
699
- // The producers will use their own JoinSet so that the simulation can be shutdown if they all finish.
700
- let mut producer_tasks = JoinSet :: new ( ) ;
709
+ // The producers will use their own TaskTracker so that the simulation can be shutdown if they all finish.
710
+ let producer_tasks = TaskTracker :: new ( ) ;
701
711
match self
702
- . dispatch_producers ( activities, consumer_channels, & mut producer_tasks)
712
+ . dispatch_producers ( activities, consumer_channels, & producer_tasks)
703
713
. await
704
714
{
705
715
Ok ( _) => { } ,
706
716
Err ( e) => {
707
- // If we encounter an error in dispatch_producers, we need to shutdown and wait for tasks to finish.
708
- // We have started background tasks in the run_data_collection function,
709
- // so we should shut those down before returning.
717
+ // If we encounter an error in dispatch_producers, we need to shutdown and return.
710
718
self . shutdown ( ) ;
711
- while let Some ( res) = tasks. join_next ( ) . await {
712
- if let Err ( e) = res {
713
- log:: error!( "Task exited with error: {e}." ) ;
714
- }
715
- }
716
719
return Err ( e) ;
717
720
} ,
718
721
}
719
722
720
723
// Start a task that waits for the producers to finish.
721
724
// If all producers finish, then there is nothing left to do and the simulation can be shutdown.
722
725
let producer_trigger = self . shutdown_trigger . clone ( ) ;
723
- tasks. spawn ( async move {
724
- while let Some ( res) = producer_tasks. join_next ( ) . await {
725
- if let Err ( e) = res {
726
- log:: error!( "Producer exited with error: {e}." ) ;
727
- }
728
- }
726
+ self . tasks . spawn ( async move {
727
+ producer_tasks. close ( ) ;
728
+ producer_tasks. wait ( ) . await ;
729
729
log:: info!( "All producers finished. Shutting down." ) ;
730
730
producer_trigger. trigger ( )
731
731
} ) ;
@@ -735,7 +735,7 @@ impl Simulation {
735
735
let t = self . shutdown_trigger . clone ( ) ;
736
736
let l = self . shutdown_listener . clone ( ) ;
737
737
738
- tasks. spawn ( async move {
738
+ self . tasks . spawn ( async move {
739
739
if time:: timeout ( total_time, l) . await . is_err ( ) {
740
740
log:: info!(
741
741
"Simulation run for {}s. Shutting down." ,
@@ -746,18 +746,7 @@ impl Simulation {
746
746
} ) ;
747
747
}
748
748
749
- // We always want to wait for all threads to exit, so we wait for all of them to exit and track any errors
750
- // that surface. It's okay if there are multiple and one is overwritten, we just want to know whether we
751
- // exited with an error or not.
752
- let mut success = true ;
753
- while let Some ( res) = tasks. join_next ( ) . await {
754
- if let Err ( e) = res {
755
- log:: error!( "Task exited with error: {e}." ) ;
756
- success = false ;
757
- }
758
- }
759
-
760
- success. then_some ( ( ) ) . ok_or ( SimulationError :: TaskError )
749
+ Ok ( ( ) )
761
750
}
762
751
763
752
pub fn shutdown ( & self ) {
@@ -777,7 +766,7 @@ impl Simulation {
777
766
fn run_data_collection (
778
767
& self ,
779
768
output_receiver : Receiver < SimulationOutput > ,
780
- tasks : & mut JoinSet < ( ) > ,
769
+ tasks : & TaskTracker ,
781
770
) {
782
771
let listener = self . shutdown_listener . clone ( ) ;
783
772
let shutdown = self . shutdown_trigger . clone ( ) ;
@@ -790,11 +779,17 @@ impl Simulation {
790
779
// psr: produce simulation results
791
780
let psr_listener = listener. clone ( ) ;
792
781
let psr_shutdown = shutdown. clone ( ) ;
782
+ let psr_tasks = tasks. clone ( ) ;
793
783
tasks. spawn ( async move {
794
784
log:: debug!( "Starting simulation results producer." ) ;
795
- if let Err ( e) =
796
- produce_simulation_results ( nodes, output_receiver, results_sender, psr_listener)
797
- . await
785
+ if let Err ( e) = produce_simulation_results (
786
+ nodes,
787
+ output_receiver,
788
+ results_sender,
789
+ psr_listener,
790
+ & psr_tasks,
791
+ )
792
+ . await
798
793
{
799
794
psr_shutdown. trigger ( ) ;
800
795
log:: error!( "Produce simulation results exited with error: {e:?}." ) ;
@@ -939,7 +934,7 @@ impl Simulation {
939
934
& self ,
940
935
consuming_nodes : HashSet < PublicKey > ,
941
936
output_sender : Sender < SimulationOutput > ,
942
- tasks : & mut JoinSet < ( ) > ,
937
+ tasks : & TaskTracker ,
943
938
) -> HashMap < PublicKey , Sender < SimulationEvent > > {
944
939
let mut channels = HashMap :: new ( ) ;
945
940
@@ -984,7 +979,7 @@ impl Simulation {
984
979
& self ,
985
980
executors : Vec < ExecutorKit > ,
986
981
producer_channels : HashMap < PublicKey , Sender < SimulationEvent > > ,
987
- tasks : & mut JoinSet < ( ) > ,
982
+ tasks : & TaskTracker ,
988
983
) -> Result < ( ) , SimulationError > {
989
984
for executor in executors {
990
985
let sender = producer_channels. get ( & executor. source_info . pubkey ) . ok_or (
@@ -1350,9 +1345,8 @@ async fn produce_simulation_results(
1350
1345
mut output_receiver : Receiver < SimulationOutput > ,
1351
1346
results : Sender < ( Payment , PaymentResult ) > ,
1352
1347
listener : Listener ,
1348
+ tasks : & TaskTracker ,
1353
1349
) -> Result < ( ) , SimulationError > {
1354
- let mut set = tokio:: task:: JoinSet :: new ( ) ;
1355
-
1356
1350
let result = loop {
1357
1351
tokio:: select! {
1358
1352
biased;
@@ -1365,7 +1359,7 @@ async fn produce_simulation_results(
1365
1359
match simulation_output{
1366
1360
SimulationOutput :: SendPaymentSuccess ( payment) => {
1367
1361
if let Some ( source_node) = nodes. get( & payment. source) {
1368
- set . spawn( track_payment_result(
1362
+ tasks . spawn( track_payment_result(
1369
1363
source_node. clone( ) , results. clone( ) , payment, listener. clone( )
1370
1364
) ) ;
1371
1365
} else {
@@ -1396,11 +1390,6 @@ async fn produce_simulation_results(
1396
1390
} ;
1397
1391
1398
1392
log:: debug!( "Simulation results producer exiting." ) ;
1399
- while let Some ( res) = set. join_next ( ) . await {
1400
- if let Err ( e) = res {
1401
- log:: error!( "Simulation results producer task exited with error: {e}." ) ;
1402
- }
1403
- }
1404
1393
1405
1394
result
1406
1395
}
@@ -1476,6 +1465,7 @@ mod tests {
1476
1465
use std:: sync:: Arc ;
1477
1466
use std:: time:: Duration ;
1478
1467
use tokio:: sync:: Mutex ;
1468
+ use tokio_util:: task:: TaskTracker ;
1479
1469
1480
1470
#[ test]
1481
1471
fn create_seeded_mut_rng ( ) {
@@ -1619,6 +1609,7 @@ mod tests {
1619
1609
crate :: SimulationCfg :: new ( Some ( 0 ) , 0 , 0.0 , None , None ) ,
1620
1610
clients,
1621
1611
vec ! [ activity_definition] ,
1612
+ TaskTracker :: new ( ) ,
1622
1613
) ;
1623
1614
assert ! ( simulation. validate_activity( ) . await . is_err( ) ) ;
1624
1615
}
0 commit comments