From 370ff6ddd3453506ec4ebc993ef561c6dc80da0f Mon Sep 17 00:00:00 2001 From: Eytan Singher Date: Mon, 15 Jul 2024 22:55:18 +0300 Subject: [PATCH] Suppoting thread_limit in parallel backoff scheduler --- Cargo.toml | 1 + src/run.rs | 35 ++++++++++++++++++++++------------- 2 files changed, 23 insertions(+), 13 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 0e0f73e3..bc207e6c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -36,6 +36,7 @@ serde_json = { version = "1.0.81", optional = true } saturating = "0.1.0" rayon = { version = "1.10.0", optional = true } crossbeam = { version = "0.8.4", optional = true, features = ["crossbeam-channel"] } +num_cpus = "1.16.0" [dev-dependencies] ordered-float = "3.0.0" diff --git a/src/run.rs b/src/run.rs index a7762fee..d403c244 100644 --- a/src/run.rs +++ b/src/run.rs @@ -996,11 +996,12 @@ where L: Language + Send + Sync, #[cfg(feature = "parallel")] pub struct ParallelBackoffScheduler { scheduler: BackoffScheduler, + thread_limit: usize, } impl Default for ParallelBackoffScheduler { fn default() -> Self { - ParallelBackoffScheduler {scheduler: Default::default()} + ParallelBackoffScheduler {scheduler: Default::default(), thread_limit: num_cpus::get()} } } @@ -1011,6 +1012,11 @@ impl ParallelBackoffScheduler { self.scheduler.rule_stats(*name); } } + + fn with_thread_limit(mut self, thread_limit: usize) -> Self { + self.thread_limit = thread_limit; + self + } } impl RewriteScheduler for ParallelBackoffScheduler where @@ -1041,18 +1047,21 @@ impl RewriteScheduler for ParallelBackoffScheduler where let stats = self.scheduler.stats.remove(&rw.name).unwrap(); (*rw, stats) }).collect::>(); - let res = with_stats.par_iter_mut().enumerate().try_for_each(|(i, (rw, stats))| { - debug!("Searching rw {}", rw.name); - let results = BackoffScheduler::search_with_stats(iteration, egraph, rw, stats); - if results.len() > 0 { - channel.0.send((i, results)).expect("Channel should be big enough for all messages"); - } - let elapsed = start_time.elapsed(); - if elapsed > time_limit { - Err(StopReason::TimeLimit(elapsed.as_secs_f64())) - } else { - Ok(()) - } + let pool = rayon::ThreadPoolBuilder::new().num_threads(self.thread_limit).build().unwrap(); + let res = pool.install(|| { + with_stats.par_iter_mut().enumerate().try_for_each(|(i, (rw, stats))| { + debug!("Searching rw {}", rw.name); + let results = BackoffScheduler::search_with_stats(iteration, egraph, rw, stats); + if results.len() > 0 { + channel.0.send((i, results)).expect("Channel should be big enough for all messages"); + } + let elapsed = start_time.elapsed(); + if elapsed > time_limit { + Err(StopReason::TimeLimit(elapsed.as_secs_f64())) + } else { + Ok(()) + } + }) }); drop(channel.0); debug!("Finished searching rewrites in parallel. Collecting results");