Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Check for rewrite composition in badger #255

Merged
merged 2 commits into from
Nov 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added test_files/nam_6_3.rwr
Binary file not shown.
34 changes: 22 additions & 12 deletions tket2/src/circuit/hash.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use hugr::hugr::views::{HierarchyView, SiblingGraph};
use hugr::ops::{LeafOp, OpName, OpType};
use hugr::{HugrView, Node};
use petgraph::visit::{self as pg, Walker};
use thiserror::Error;

use super::Circuit;

Expand All @@ -22,29 +23,30 @@ pub trait CircuitHash<'circ>: HugrView {
///
/// Adapted from Quartz (Apache 2.0)
/// <https://github.com/quantum-compiler/quartz/blob/2e13eb7ffb3c5c5fe96cf5b4246f4fd7512e111e/src/quartz/tasograph/tasograph.cpp#L410>
fn circuit_hash(&'circ self) -> u64;
fn circuit_hash(&'circ self) -> Result<u64, HashError>;
}

impl<'circ, T> CircuitHash<'circ> for T
where
T: HugrView,
{
fn circuit_hash(&'circ self) -> u64 {
fn circuit_hash(&'circ self) -> Result<u64, HashError> {
let mut node_hashes = HashState::default();

for node in pg::Topo::new(&self.as_petgraph())
.iter(&self.as_petgraph())
.filter(|&n| n != self.root())
{
let hash = hash_node(self, node, &mut node_hashes);
let hash = hash_node(self, node, &mut node_hashes)?;
if node_hashes.set_hash(node, hash).is_some() {
panic!("Hash already set for node {node}");
}
}

// If the output node has no hash, the topological sort failed due to a cycle.
node_hashes
.node_hash(self.output())
.expect("Output hash has not been set")
.ok_or(HashError::CyclicCircuit)
}
}

Expand Down Expand Up @@ -95,14 +97,14 @@ fn hashable_op(op: &OpType) -> impl Hash {
/// # Panics
/// - If the command is a container node, or if it is a parametric CustomOp.
/// - If the hash of any of its predecessors has not been set.
fn hash_node(circ: &impl HugrView, node: Node, state: &mut HashState) -> u64 {
fn hash_node(circ: &impl HugrView, node: Node, state: &mut HashState) -> Result<u64, HashError> {
let op = circ.get_optype(node);
let mut hasher = FxHasher64::default();

// Hash the node children
if circ.children(node).count() > 0 {
let container: SiblingGraph = SiblingGraph::try_new(circ, node).unwrap();
container.circuit_hash().hash(&mut hasher);
container.circuit_hash()?.hash(&mut hasher);
}

// Hash the node operation
Expand All @@ -121,7 +123,15 @@ fn hash_node(circ: &impl HugrView, node: Node, state: &mut HashState) -> u64 {
.fold(0, |total, hash| hash ^ total);
input_hash.hash(&mut hasher);
}
hasher.finish()
Ok(hasher.finish())
}

/// Errors that can occur while hashing a hugr.
#[derive(Debug, Clone, PartialEq, Eq, Error)]
pub enum HashError {
/// The circuit contains a cycle.
#[error("The circuit contains a cycle.")]
CyclicCircuit,
}

#[cfg(test)]
Expand All @@ -144,7 +154,7 @@ mod test {
Ok(())
})
.unwrap();
let hash1 = circ1.circuit_hash();
let hash1 = circ1.circuit_hash().unwrap();

// A circuit built in a different order should have the same hash
let circ2 = build_simple_circuit(2, |circ| {
Expand All @@ -154,7 +164,7 @@ mod test {
Ok(())
})
.unwrap();
let hash2 = circ2.circuit_hash();
let hash2 = circ2.circuit_hash().unwrap();

assert_eq!(hash1, hash2);

Expand All @@ -166,7 +176,7 @@ mod test {
Ok(())
})
.unwrap();
let hash3 = circ3.circuit_hash();
let hash3 = circ3.circuit_hash().unwrap();

assert_ne!(hash1, hash3);
}
Expand All @@ -176,7 +186,7 @@ mod test {
let c_str = r#"{"bits": [], "commands": [{"args": [["q", [0]]], "op": {"params": ["0.5"], "type": "Rz"}}], "created_qubits": [], "discarded_qubits": [], "implicit_permutation": [[["q", [0]], ["q", [0]]]], "phase": "0.0", "qubits": [["q", [0]]]}"#;
let ser: circuit_json::SerialCircuit = serde_json::from_str(c_str).unwrap();
let circ: Hugr = ser.decode().unwrap();
circ.circuit_hash();
circ.circuit_hash().unwrap();
}

#[test]
Expand All @@ -188,7 +198,7 @@ mod test {
for c_str in [c_str1, c_str2] {
let ser: circuit_json::SerialCircuit = serde_json::from_str(c_str).unwrap();
let circ: Hugr = ser.decode().unwrap();
all_hashes.push(circ.circuit_hash());
all_hashes.push(circ.circuit_hash().unwrap());
}
assert_ne!(all_hashes[0], all_hashes[1]);
}
Expand Down
69 changes: 65 additions & 4 deletions tket2/src/optimiser/badger.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,16 +141,19 @@ where
logger.log_best(&best_circ_cost);

// Hash of seen circuits. Dot not store circuits as this map gets huge
let hash = circ.circuit_hash().unwrap();
let mut seen_hashes = FxHashSet::default();
seen_hashes.insert(circ.circuit_hash());
seen_hashes.insert(hash);

// The priority queue of circuits to be processed (this should not get big)
let cost_fn = {
let strategy = self.strategy.clone();
move |circ: &'_ Hugr| strategy.circuit_cost(circ)
};
let cost = (cost_fn)(circ);

let mut pq = HugrPQ::new(cost_fn, queue_size);
pq.push(circ.clone());
pq.push_unchecked(circ.clone(), hash, cost);

let mut circ_cnt = 0;
let mut timeout_flag = false;
Expand All @@ -169,7 +172,13 @@ where
continue;
}

let new_circ_hash = new_circ.circuit_hash();
let Ok(new_circ_hash) = new_circ.circuit_hash() else {
// The composed rewrites produced a loop.
//
// See [https://github.com/CQCL/tket2/discussions/242]
continue;
};

if !seen_hashes.insert(new_circ_hash) {
// Ignore this circuit: we've already seen it
continue;
Expand Down Expand Up @@ -218,7 +227,7 @@ where
};
let (pq, rx_log) = HugrPriorityChannel::init(cost_fn.clone(), queue_size);

let initial_circ_hash = circ.circuit_hash();
let initial_circ_hash = circ.circuit_hash().unwrap();
let mut best_circ = circ.clone();
let mut best_circ_cost = self.cost(&best_circ);

Expand Down Expand Up @@ -436,6 +445,7 @@ mod tests {
};
use rstest::{fixture, rstest};

use crate::json::load_tk1_json_str;
use crate::{extension::REGISTRY, Circuit, T2Op};

use super::{BadgerOptimiser, DefaultBadgerOptimiser};
Expand Down Expand Up @@ -466,11 +476,45 @@ mod tests {
h.finish_hugr_with_outputs([qb], &REGISTRY).unwrap()
}

/// This hugr corresponds to the qasm circuit:
///
/// ```skip
/// OPENQASM 2.0;
/// include "qelib1.inc";
///
/// qreg q[5];
/// cx q[4],q[1];
/// cx q[3],q[4];
/// cx q[1],q[2];
/// cx q[4],q[0];
/// u3(0.5*pi,0.0*pi,0.5*pi) q[1];
/// cx q[0],q[2];
/// cx q[3],q[1];
/// cx q[0],q[2];
/// ```
const NON_COMPOSABLE: &str = r#"{"phase":"0.0","commands":[{"op":{"type":"CX","n_qb":2,"signature":["Q","Q"]},"args":[["q",[4]],["q",[1]]]},{"op":{"type":"CX","n_qb":2,"signature":["Q","Q"]},"args":[["q",[1]],["q",[2]]]},{"op":{"type":"U3","params":["0.5","0","0.5"],"signature":["Q"]},"args":[["q",[1]]]},{"op":{"type":"CX","n_qb":2,"signature":["Q","Q"]},"args":[["q",[3]],["q",[4]]]},{"op":{"type":"CX","n_qb":2,"signature":["Q","Q"]},"args":[["q",[4]],["q",[0]]]},{"op":{"type":"CX","n_qb":2,"signature":["Q","Q"]},"args":[["q",[0]],["q",[2]]]},{"op":{"type":"CX","n_qb":2,"signature":["Q","Q"]},"args":[["q",[0]],["q",[2]]]},{"op":{"type":"CX","n_qb":2,"signature":["Q","Q"]},"args":[["q",[3]],["q",[1]]]}],"qubits":[["q",[0]],["q",[1]],["q",[2]],["q",[3]],["q",[4]]],"bits":[],"implicit_permutation":[[["q",[0]],["q",[0]]],[["q",[1]],["q",[1]]],[["q",[2]],["q",[2]]],[["q",[3]],["q",[3]]],[["q",[4]],["q",[4]]]]}"#;

/// A Hugr that would trigger non-composable rewrites, if we applied them blindly from nam_6_3 matches.
#[fixture]
fn non_composable_rw_hugr() -> Hugr {
load_tk1_json_str(NON_COMPOSABLE).unwrap()
}

/// A badger optimiser using a reduced set of rewrite rules.
#[fixture]
fn badger_opt() -> DefaultBadgerOptimiser {
BadgerOptimiser::default_with_eccs_json_file("../test_files/small_eccs.json").unwrap()
}

/// A badger optimiser using the complete nam_6_3 rewrite set.
///
/// NOTE: This takes a few seconds to load.
/// Use [`badger_opt`] if possible.
#[fixture]
fn badger_opt_full() -> DefaultBadgerOptimiser {
BadgerOptimiser::default_with_rewriter_binary("../test_files/nam_6_3.rwr").unwrap()
}

#[rstest]
fn rz_rz_cancellation(rz_rz: Hugr, badger_opt: DefaultBadgerOptimiser) {
let opt_rz = badger_opt.optimise(&rz_rz, None, 1.try_into().unwrap(), false, 4);
Expand All @@ -483,4 +527,21 @@ mod tests {
let mut opt_rz = badger_opt.optimise(&rz_rz, Some(0), 2.try_into().unwrap(), false, 4);
opt_rz.update_validate(&REGISTRY).unwrap();
}

#[rstest]
#[ignore = "Loading the ECC set is really slow (~5 seconds)"]
fn non_composable_rewrites(
non_composable_rw_hugr: Hugr,
badger_opt_full: DefaultBadgerOptimiser,
) {
let mut opt = badger_opt_full.optimise(
&non_composable_rw_hugr,
Some(0),
1.try_into().unwrap(),
false,
10,
);
// No rewrites applied.
opt.update_validate(&REGISTRY).unwrap();
}
}
2 changes: 1 addition & 1 deletion tket2/src/optimiser/badger/hugr_pqueue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ impl<P: Ord, C> HugrPQ<P, C> {
where
C: Fn(&Hugr) -> P,
{
let hash = hugr.circuit_hash();
let hash = hugr.circuit_hash().unwrap();
let cost = (self.cost_fn)(&hugr);
self.push_unchecked(hugr, hash, cost);
}
Expand Down
8 changes: 7 additions & 1 deletion tket2/src/optimiser/badger/worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,13 @@ where
return None;
}

let hash = c.circuit_hash();
let Ok(hash) = c.circuit_hash() else {
// The composed rewrites were not valid.
//
// See [https://github.com/CQCL/tket2/discussions/242]
return None;
};

Some(Work {
cost: new_cost,
hash,
Expand Down