Skip to content

Commit

Permalink
feat: rewrite tracing (#267)
Browse files Browse the repository at this point in the history
Optionally store metadata about the rewrites applied to a circuit.

If the `rewrite-tracing` feature is disabled, this has no effect on the
runtime.

If the feature is enabled, circuits must opt-in to the tracing by
calling `RewriteTracer::enable_rewrite_tracing`.

Closes #216
  • Loading branch information
aborgna-q authored Dec 14, 2023
1 parent 8b50e90 commit f04e815
Show file tree
Hide file tree
Showing 9 changed files with 231 additions and 22 deletions.
7 changes: 5 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,13 @@ Version 2 of the TKET compiler.
## Features

- `pyo3`
This optional feature enables some python bindings via pyo3. See the `tket2-py` folder for more.
Enables some python bindings via pyo3. See the `tket2-py` folder for more.

- `portmatching`
This enables pattern matching using the `portmatching` crate.
Enables pattern matching using the `portmatching` crate.

- `rewrite-tracing`
Adds opt-in tracking of the rewrites applied to a circuit.

## Developing TKET2

Expand Down
2 changes: 1 addition & 1 deletion badger-optimiser/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ license-file = { workspace = true }
[dependencies]
clap = { version = "4.4.2", features = ["derive"] }
serde_json = "1.0"
tket2 = { workspace = true, features = ["portmatching"] }
tket2 = { workspace = true, features = ["portmatching", "rewrite-tracing"] }
quantinuum-hugr = { workspace = true }
itertools = { workspace = true }
tket-json-rs = { workspace = true }
Expand Down
12 changes: 11 additions & 1 deletion badger-optimiser/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ use tket2::json::{load_tk1_json_file, save_tk1_json_file};
use tket2::optimiser::badger::log::BadgerLogger;
use tket2::optimiser::badger::BadgerOptions;
use tket2::optimiser::{BadgerOptimiser, DefaultBadgerOptimiser};
use tket2::rewrite::trace::RewriteTracer;

#[cfg(feature = "peak_alloc")]
#[global_allocator]
Expand Down Expand Up @@ -104,6 +105,12 @@ struct CmdLineArgs {
help = "The priority queue size. Defaults to 100."
)]
queue_size: usize,
/// Trace each rewrite applied to the circuit.
#[arg(
long = "rewrite-tracing",
help = "Trace each rewrite applied to the circuit. Prints statistics for the best circuit at the end of the optimisation."
)]
rewrite_tracing: bool,
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
Expand All @@ -129,7 +136,10 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {

let badger_logger = BadgerLogger::new(circ_candidates_csv);

let circ = load_tk1_json_file(input_path)?;
let mut circ = load_tk1_json_file(input_path)?;
if opts.rewrite_tracing {
circ.enable_rewrite_tracing();
}

println!("Loading optimiser...");
let Ok(optimiser) = load_optimiser(ecc_path) else {
Expand Down
10 changes: 9 additions & 1 deletion tket2/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,17 @@ name = "tket2"
path = "src/lib.rs"

[features]
# Enables some python bindings
pyo3 = ["dep:pyo3"]

# Enables search and replace optimisation passes using the `portmatching` crate.
portmatching = ["dep:portmatching", "dep:rmp-serde"]

# Stores a trace of the applied rewrites
rewrite-tracing = []

default = []

[dependencies]
lazy_static = "1.4.0"
cgmath = "0.18.0"
Expand All @@ -44,7 +52,7 @@ strum_macros = "0.25.2"
strum = "0.25.0"
fxhash = "0.2.1"
rmp-serde = { version = "1.1.2", optional = true }
delegate = "0.10.0"
delegate = "0.11.0"
csv = { version = "1.2.2" }
chrono = { version = "0.4.30" }
bytemuck = "1.14.0"
Expand Down
19 changes: 13 additions & 6 deletions tket2/src/optimiser/badger.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ use crate::optimiser::badger::hugr_pqueue::{Entry, HugrPQ};
use crate::optimiser::badger::worker::BadgerWorker;
use crate::passes::CircuitChunks;
use crate::rewrite::strategy::RewriteStrategy;
use crate::rewrite::trace::RewriteTracer;
use crate::rewrite::Rewriter;
use crate::Circuit;

Expand Down Expand Up @@ -158,7 +159,8 @@ where

let mut best_circ = circ.clone();
let mut best_circ_cost = self.cost(circ);
logger.log_best(&best_circ_cost);
let num_rewrites = best_circ.rewrite_trace().map(|rs| rs.len());
logger.log_best(&best_circ_cost, num_rewrites);

// Hash of seen circuits. Dot not store circuits as this map gets huge
let hash = circ.circuit_hash().unwrap();
Expand All @@ -181,7 +183,8 @@ where
if cost < best_circ_cost {
best_circ = circ.clone();
best_circ_cost = cost.clone();
logger.log_best(&best_circ_cost);
let num_rewrites = best_circ.rewrite_trace().map(|rs| rs.len());
logger.log_best(&best_circ_cost, num_rewrites);
last_best_time = Instant::now();
}
circ_cnt += 1;
Expand Down Expand Up @@ -297,7 +300,8 @@ where
if cost < best_circ_cost {
best_circ = circ;
best_circ_cost = cost;
logger.log_best(&best_circ_cost);
let num_rewrites = best_circ.rewrite_trace().map(|rs| rs.len());
logger.log_best(&best_circ_cost, num_rewrites);
if let Some(t) = opt.progress_timeout {
progress_timeout_event = crossbeam_channel::at(Instant::now() + Duration::from_secs(t));
}
Expand Down Expand Up @@ -337,7 +341,8 @@ where
if cost < best_circ_cost {
best_circ = circ;
best_circ_cost = cost;
logger.log_best(&best_circ_cost);
let num_rewrites = best_circ.rewrite_trace().map(|rs| rs.len());
logger.log_best(&best_circ_cost, num_rewrites);
}
}
PriorityChannelLog::CircuitCount {
Expand Down Expand Up @@ -381,7 +386,8 @@ where
let mut chunks =
CircuitChunks::split_with_cost(circ, max_chunk_cost, |op| self.strategy.op_cost(op));

logger.log_best(circ_cost.clone());
let num_rewrites = circ.rewrite_trace().map(|rs| rs.len());
logger.log_best(circ_cost.clone(), num_rewrites);

let (joins, rx_work): (Vec<_>, Vec<_>) = chunks
.iter_mut()
Expand Down Expand Up @@ -420,7 +426,8 @@ where
let best_circ = chunks.reassemble()?;
let best_circ_cost = self.cost(&best_circ);
if best_circ_cost.clone() < circ_cost {
logger.log_best(best_circ_cost.clone());
let num_rewrites = best_circ.rewrite_trace().map(|rs| rs.len());
logger.log_best(best_circ_cost.clone(), num_rewrites);
}

logger.log_processing_end(opt.n_threads.get(), None, best_circ_cost, true, false);
Expand Down
13 changes: 11 additions & 2 deletions tket2/src/optimiser/badger/log.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,17 @@ impl<'w> BadgerLogger<'w> {

/// Log a new best candidate
#[inline]
pub fn log_best<C: Debug + serde::Serialize>(&mut self, best_cost: C) {
self.log(format!("new best of size {:?}", best_cost));
pub fn log_best<C: Debug + serde::Serialize>(
&mut self,
best_cost: C,
num_rewrites: Option<usize>,
) {
match num_rewrites {
Some(rs) => self.log(format!(
"new best of size {best_cost:?} after {rs} rewrites"
)),
None => self.log(format!("new best of size {:?}", best_cost)),
}
if let Some(csv_writer) = self.circ_candidates_csv.as_mut() {
csv_writer.serialize(BestCircSer::new(best_cost)).unwrap();
csv_writer.flush().unwrap();
Expand Down
20 changes: 14 additions & 6 deletions tket2/src/rewrite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
#[cfg(feature = "portmatching")]
pub mod ecc_rewriter;
pub mod strategy;
pub mod trace;

use bytemuck::TransparentWrapper;
#[cfg(feature = "portmatching")]
pub use ecc_rewriter::ECCRewriter;

use delegate::delegate;
use derive_more::{From, Into};
use hugr::hugr::views::sibling_subgraph::{InvalidReplacement, InvalidSubgraph};
use hugr::Node;
Expand All @@ -19,6 +19,8 @@ use hugr::{

use crate::circuit::Circuit;

use self::trace::RewriteTracer;

/// A subcircuit of a circuit.
#[derive(Debug, Clone, From, Into)]
#[repr(transparent)]
Expand Down Expand Up @@ -107,11 +109,17 @@ impl CircuitRewrite {
self.0.invalidation_set()
}

delegate! {
to self.0 {
/// Apply the rewrite rule to a circuit.
pub fn apply(self, circ: &mut impl HugrMut) -> Result<(), SimpleReplacementError>;
}
/// Apply the rewrite rule to a circuit.
#[inline]
pub fn apply(self, circ: &mut impl HugrMut) -> Result<(), SimpleReplacementError> {
circ.add_rewrite_trace(&self);
self.0.apply(circ)
}

/// Apply the rewrite rule to a circuit, without registering it in the rewrite trace.
#[inline]
pub fn apply_notrace(self, circ: &mut impl HugrMut) -> Result<(), SimpleReplacementError> {
self.0.apply(circ)
}
}

Expand Down
42 changes: 39 additions & 3 deletions tket2/src/rewrite/strategy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ use itertools::Itertools;
use crate::circuit::cost::{is_cx, is_quantum, CircuitCost, CostDelta, MajorMinorCost};
use crate::Circuit;

use super::trace::{RewriteTrace, RewriteTracer};
use super::CircuitRewrite;

/// Rewriting strategies for circuit optimisation.
Expand Down Expand Up @@ -219,6 +220,7 @@ impl<T: StrategyCost> RewriteStrategy for ExhaustiveGreedyStrategy<T> {
let mut curr_circ = circ.clone();
let mut changed_nodes = HashSet::new();
let mut cost_delta = Default::default();
let mut composed_rewrite_count = 0;
for (rewrite, delta) in &rewrites[i..] {
if !changed_nodes.is_empty()
&& rewrite
Expand All @@ -230,11 +232,15 @@ impl<T: StrategyCost> RewriteStrategy for ExhaustiveGreedyStrategy<T> {
changed_nodes.extend(rewrite.invalidation_set());
cost_delta += delta.clone();

composed_rewrite_count += 1;

rewrite
.clone()
.apply(&mut curr_circ)
.apply_notrace(&mut curr_circ)
.expect("Could not perform rewrite in exhaustive greedy strategy");
}

curr_circ.add_rewrite_trace(RewriteTrace::new(composed_rewrite_count));
rewrite_sets.circs.push(curr_circ);
rewrite_sets.cost_deltas.push(cost_delta);
}
Expand Down Expand Up @@ -462,6 +468,7 @@ mod tests {
use hugr::{Hugr, Node};
use itertools::Itertools;

use crate::rewrite::trace::REWRITE_TRACING_ENABLED;
use crate::{
circuit::Circuit,
rewrite::{CircuitRewrite, Subcircuit},
Expand Down Expand Up @@ -494,9 +501,16 @@ mod tests {

#[test]
fn test_greedy_strategy() {
let circ = n_cx(10);
let mut circ = n_cx(10);
let cx_gates = circ.commands().map(|cmd| cmd.node()).collect_vec();

assert_eq!(circ.rewrite_trace(), None);
circ.enable_rewrite_tracing();
match REWRITE_TRACING_ENABLED {
true => assert_eq!(circ.rewrite_trace(), Some(vec![])),
false => assert_eq!(circ.rewrite_trace(), None),
}

let rws = [
rw_to_empty(&circ, cx_gates[0..2].to_vec()),
rw_to_full(&circ, cx_gates[4..7].to_vec()),
Expand All @@ -508,12 +522,17 @@ mod tests {
let rewritten = strategy.apply_rewrites(rws, &circ);
assert_eq!(rewritten.len(), 1);
assert_eq!(rewritten.circs[0].num_gates(), 5);

if REWRITE_TRACING_ENABLED {
assert_eq!(rewritten.circs[0].rewrite_trace().unwrap().len(), 3);
}
}

#[test]
fn test_exhaustive_default_strategy() {
let circ = n_cx(10);
let mut circ = n_cx(10);
let cx_gates = circ.commands().map(|cmd| cmd.node()).collect_vec();
circ.enable_rewrite_tracing();

let rws = [
rw_to_empty(&circ, cx_gates[0..2].to_vec()),
Expand All @@ -527,6 +546,23 @@ mod tests {
let exp_circ_lens = HashSet::from_iter([3, 7, 9]);
let circ_lens: HashSet<_> = rewritten.circs.iter().map(|c| c.num_gates()).collect();
assert_eq!(circ_lens, exp_circ_lens);

if REWRITE_TRACING_ENABLED {
// Each strategy branch applies a single rewrite, composed of
// multiple individual elements from `rws`.
assert_eq!(
rewritten.circs[0].rewrite_trace().unwrap(),
vec![RewriteTrace::new(3)]
);
assert_eq!(
rewritten.circs[1].rewrite_trace().unwrap(),
vec![RewriteTrace::new(2)]
);
assert_eq!(
rewritten.circs[2].rewrite_trace().unwrap(),
vec![RewriteTrace::new(1)]
);
}
}

#[test]
Expand Down
Loading

0 comments on commit f04e815

Please sign in to comment.