Skip to content

Commit

Permalink
Suppoting thread_limit in parallel backoff scheduler
Browse files Browse the repository at this point in the history
  • Loading branch information
eytans committed Jul 15, 2024
1 parent 591f234 commit 370ff6d
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 13 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ serde_json = { version = "1.0.81", optional = true }
saturating = "0.1.0"
rayon = { version = "1.10.0", optional = true }
crossbeam = { version = "0.8.4", optional = true, features = ["crossbeam-channel"] }
num_cpus = "1.16.0"

[dev-dependencies]
ordered-float = "3.0.0"
Expand Down
35 changes: 22 additions & 13 deletions src/run.rs
Original file line number Diff line number Diff line change
Expand Up @@ -996,11 +996,12 @@ where L: Language + Send + Sync,
#[cfg(feature = "parallel")]
pub struct ParallelBackoffScheduler {
scheduler: BackoffScheduler,
thread_limit: usize,
}

impl Default for ParallelBackoffScheduler {
fn default() -> Self {
ParallelBackoffScheduler {scheduler: Default::default()}
ParallelBackoffScheduler {scheduler: Default::default(), thread_limit: num_cpus::get()}
}
}

Expand All @@ -1011,6 +1012,11 @@ impl ParallelBackoffScheduler {
self.scheduler.rule_stats(*name);
}
}

fn with_thread_limit(mut self, thread_limit: usize) -> Self {
self.thread_limit = thread_limit;
self
}
}

impl<L, N> RewriteScheduler<L, N> for ParallelBackoffScheduler where
Expand Down Expand Up @@ -1041,18 +1047,21 @@ impl<L, N> RewriteScheduler<L, N> for ParallelBackoffScheduler where
let stats = self.scheduler.stats.remove(&rw.name).unwrap();
(*rw, stats)
}).collect::<Vec<_>>();
let res = with_stats.par_iter_mut().enumerate().try_for_each(|(i, (rw, stats))| {
debug!("Searching rw {}", rw.name);
let results = BackoffScheduler::search_with_stats(iteration, egraph, rw, stats);
if results.len() > 0 {
channel.0.send((i, results)).expect("Channel should be big enough for all messages");
}
let elapsed = start_time.elapsed();
if elapsed > time_limit {
Err(StopReason::TimeLimit(elapsed.as_secs_f64()))
} else {
Ok(())
}
let pool = rayon::ThreadPoolBuilder::new().num_threads(self.thread_limit).build().unwrap();
let res = pool.install(|| {
with_stats.par_iter_mut().enumerate().try_for_each(|(i, (rw, stats))| {
debug!("Searching rw {}", rw.name);
let results = BackoffScheduler::search_with_stats(iteration, egraph, rw, stats);
if results.len() > 0 {
channel.0.send((i, results)).expect("Channel should be big enough for all messages");
}
let elapsed = start_time.elapsed();
if elapsed > time_limit {
Err(StopReason::TimeLimit(elapsed.as_secs_f64()))
} else {
Ok(())
}
})
});
drop(channel.0);
debug!("Finished searching rewrites in parallel. Collecting results");
Expand Down

0 comments on commit 370ff6d

Please sign in to comment.