From c710541df9b8f8d7feff0c084cdf88d489fbc151 Mon Sep 17 00:00:00 2001 From: Luca Mondada Date: Tue, 20 Aug 2024 15:05:35 +0200 Subject: [PATCH] feat: Fast hashing for StaticSizeCircuit --- tket2/src/circuit.rs | 2 +- tket2/src/static_circ.rs | 3 +- tket2/src/static_circ/hash.rs | 168 ++++++++++++++++++++++++++ tket2/src/static_circ/rewrite.rs | 196 +++++++++++++++++++++++++++++++ 4 files changed, 367 insertions(+), 2 deletions(-) create mode 100644 tket2/src/static_circ/hash.rs create mode 100644 tket2/src/static_circ/rewrite.rs diff --git a/tket2/src/circuit.rs b/tket2/src/circuit.rs index 117f90b8..f126db38 100644 --- a/tket2/src/circuit.rs +++ b/tket2/src/circuit.rs @@ -9,7 +9,7 @@ pub mod units; use std::iter::Sum; pub use command::{Command, CommandIterator}; -pub use hash::CircuitHash; +pub use hash::{CircuitHash, HashError}; use hugr::hugr::views::{DescendantsGraph, ExtractHugr, HierarchyView}; use itertools::Either::{Left, Right}; diff --git a/tket2/src/static_circ.rs b/tket2/src/static_circ.rs index 0bcc11ec..1048139d 100644 --- a/tket2/src/static_circ.rs +++ b/tket2/src/static_circ.rs @@ -1,6 +1,8 @@ //! A 2d array-like representation of simple quantum circuits. +mod hash; mod match_op; +mod rewrite; use std::{collections::BTreeMap, fmt, rc::Rc}; @@ -36,7 +38,6 @@ pub struct OpLocation { impl StaticSizeCircuit { /// Returns the number of qubits in the circuit. - #[allow(unused)] pub fn qubit_count(&self) -> usize { self.qubit_ops.len() } diff --git a/tket2/src/static_circ/hash.rs b/tket2/src/static_circ/hash.rs new file mode 100644 index 00000000..622525b7 --- /dev/null +++ b/tket2/src/static_circ/hash.rs @@ -0,0 +1,168 @@ +use std::{ + hash::{Hash, Hasher}, + ops::Range, +}; + +use cgmath::num_traits::{WrappingAdd, WrappingShl}; + +use crate::circuit::{CircuitHash, HashError}; + +use super::{ + rewrite::{OpInterval, StaticRewrite}, + MatchOp, StaticQubitIndex, StaticSizeCircuit, +}; + +pub struct UpdatableHash { + cum_hash: Vec>, +} + +impl UpdatableHash { + pub fn with_static(circuit: &StaticSizeCircuit) -> Self { + let num_qubits = circuit.qubit_count(); + let mut cum_hash = Vec::with_capacity(num_qubits); + + for row in circuit.qubit_ops.iter() { + let mut prev_hash = 0; + let mut row_hash = Vec::with_capacity(row.len()); + for op in row.iter() { + let hash = Self::hash_op(op); + let combined_hash = prev_hash.wrapping_shl(5).wrapping_add(&hash); + row_hash.push(combined_hash); + prev_hash = combined_hash; + } + cum_hash.push(row_hash); + } + + Self { cum_hash } + } + + /// Compute the hash of the circuit that results from applying the given rewrite. + pub fn hash_rewrite(&self, circuit: &StaticSizeCircuit, rewrite: &StaticRewrite) -> u64 + where + F: Fn(StaticQubitIndex) -> StaticQubitIndex, + { + let new_hash = Self::with_static(&rewrite.replacement); + hash_iter((0..circuit.qubit_count()).map(|i| { + if let Some(interval) = rewrite.subcircuit.op_indices.get(&StaticQubitIndex(i)) { + splice(&self.cum_hash[i], interval, &new_hash.cum_hash[i]) + } else { + *self.cum_hash[i].last().unwrap() + } + })) + } + + fn hash_op(op: &MatchOp) -> u64 { + let mut hasher = fxhash::FxHasher::default(); + op.hash(&mut hasher); + hasher.finish() + } +} + +/// Compute the hash that results from replacing the ops in the range [start, end) +/// with the new ops (given by `new_cum_hashes`). +fn splice(cum_hashes: &[u64], interval: &OpInterval, new_cum_hashes: &[u64]) -> u64 { + let Range { start, end } = interval.0; + let mut hash = 0; + if start > 0 { + hash = hash.wrapping_add(&cum_hashes[start - 1]); + } + if !new_cum_hashes.is_empty() { + hash = hash.wrapping_shl(5 * (new_cum_hashes.len() as u32)); + hash = hash.wrapping_add(new_cum_hashes[new_cum_hashes.len() - 1]); + } + if end < cum_hashes.len() { + hash = hash.wrapping_shl(5 * (cum_hashes.len() - end) as u32); + hash = hash.wrapping_add(hash_delta(cum_hashes, end..cum_hashes.len())); + } + hash +} + +/// The hash "contribution" that comes from within the range [start, end). +fn hash_delta(cum_hashes: &[u64], Range { start, end }: Range) -> u64 { + if start >= end { + return 0; + } + let end_hash = if end > 0 { cum_hashes[end - 1] } else { 0 }; + let start_hash = if start > 0 { cum_hashes[start - 1] } else { 0 }; + let start_hash_shifted = start_hash.wrapping_shl(5 * (end - start) as u32); + end_hash.wrapping_sub(start_hash_shifted) +} + +fn hash_iter(iter: impl Iterator) -> u64 { + let mut hasher = fxhash::FxHasher::default(); + for item in iter { + item.hash(&mut hasher); + } + hasher.finish() +} +impl CircuitHash for StaticSizeCircuit { + fn circuit_hash(&self) -> Result { + let hash_updater = UpdatableHash::with_static(self); + Ok(hash_iter( + hash_updater + .cum_hash + .iter() + .map(|row| row.last().unwrap_or(&0)), + )) + } +} + +#[cfg(test)] +mod tests { + use crate::{static_circ::rewrite::StaticSubcircuit, utils::build_simple_circuit, Tk2Op}; + + use super::*; + + #[test] + fn test_rewrite_circuit() { + // Create initial circuit + let circuit = build_simple_circuit(2, |circ| { + circ.append(Tk2Op::H, [0])?; + circ.append(Tk2Op::CX, [0, 1])?; + circ.append(Tk2Op::CX, [0, 1])?; + circ.append(Tk2Op::CX, [0, 1])?; + circ.append(Tk2Op::H, [1])?; + Ok(()) + }) + .unwrap(); + + let initial_circuit: StaticSizeCircuit = (&circuit).try_into().unwrap(); + + // Create subcircuit to be replaced + let subcircuit = StaticSubcircuit { + op_indices: vec![ + (StaticQubitIndex(0), OpInterval(0..2)), + (StaticQubitIndex(1), OpInterval(0..1)), + ] + .into_iter() + .collect(), + }; + + let circuit = build_simple_circuit(2, |circ| { + circ.append(Tk2Op::CX, [0, 1])?; + circ.append(Tk2Op::H, [0])?; + Ok(()) + }) + .unwrap(); + + let replacement_circuit: StaticSizeCircuit = (&circuit).try_into().unwrap(); + + // Define qubit mapping + let qubit_map = |qb: StaticQubitIndex| qb; + + let rewrite = StaticRewrite { + subcircuit, + replacement: replacement_circuit, + qubit_map, + }; + + // Perform rewrite + let rewritten_circuit = initial_circuit.apply_rewrite(&rewrite).unwrap(); + + // Assert the hash of the rewritten circuit matches the spliced hash + let hash_updater = UpdatableHash::with_static(&initial_circuit); + let rewritten_hash = hash_updater.hash_rewrite(&initial_circuit, &rewrite); + let expected_hash = rewritten_circuit.circuit_hash().unwrap(); + assert_eq!(rewritten_hash, expected_hash); + } +} diff --git a/tket2/src/static_circ/rewrite.rs b/tket2/src/static_circ/rewrite.rs new file mode 100644 index 00000000..0bf39b61 --- /dev/null +++ b/tket2/src/static_circ/rewrite.rs @@ -0,0 +1,196 @@ +use std::{collections::BTreeMap, ops::Range, rc::Rc}; + +use derive_more::{From, Into}; +use thiserror::Error; + +use super::{OpLocation, StaticQubitIndex, StaticSizeCircuit}; + +/// An interval of operation indices. +#[derive(Debug, Clone, PartialEq, Eq, From, Into)] +pub(super) struct OpInterval(pub Range); + +/// A subcircuit of a static circuit. +#[derive(Debug, Clone, PartialEq, Eq, From, Into)] +pub struct StaticSubcircuit { + /// Maps qubit indices to the intervals of operations on that qubit. + pub(super) op_indices: BTreeMap, +} + +impl StaticSubcircuit { + /// The subcircuit before `self`. + fn before(&self, circuit: &StaticSizeCircuit) -> Self { + let mut op_indices = BTreeMap::new(); + for qb in circuit.qubits_iter() { + if let Some(interval) = self.op_indices.get(&qb) { + let start = interval.0.start; + op_indices.insert(qb, OpInterval(0..start)); + } else { + op_indices.insert(qb, OpInterval(0..circuit.qubit_ops(qb).len())); + } + } + StaticSubcircuit { op_indices } + } + + /// The subcircuit after `self`. + fn after(&self, circuit: &StaticSizeCircuit) -> Self { + let op_indices = self + .op_indices + .iter() + .map(|(&qb, interval)| (qb, OpInterval(interval.0.end..circuit.qubit_ops(qb).len()))) + .collect(); + StaticSubcircuit { op_indices } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Error)] +#[error("invalid subcircuit")] +pub struct InvalidSubcircuitError; + +impl StaticSizeCircuit { + fn subcircuit(&self, subcircuit: &StaticSubcircuit) -> Result { + let Self { + mut qubit_ops, + mut op_locations, + } = self.clone(); + for (qb, interval) in subcircuit.op_indices.iter() { + for op in qubit_ops[qb.0].drain(interval.0.end..) { + op_locations.remove(&Rc::as_ptr(&op)); + } + for op in qubit_ops[qb.0].drain(..interval.0.start) { + op_locations.remove(&Rc::as_ptr(&op)); + } + } + let ret = Self { + qubit_ops, + op_locations, + }; + ret.check_valid()?; + Ok(ret) + } + + fn append( + &mut self, + other: &StaticSizeCircuit, + qubit_map: impl Fn(StaticQubitIndex) -> StaticQubitIndex, + ) { + for (qb, ops) in other.qubit_ops.iter().enumerate() { + let new_qb = qubit_map(StaticQubitIndex(qb)); + for op in ops.iter() { + let op_idx = self.qubit_ops[new_qb.0].len(); + self.qubit_ops[new_qb.0].push(op.clone()); + self.op_locations + .entry(Rc::as_ptr(op)) + .or_default() + .push(OpLocation { + qubit: new_qb, + op_idx, + }); + } + } + } + + fn check_valid(&self) -> Result<(), InvalidSubcircuitError> { + for op in self.all_ops_iter() { + if self.op_locations.get(&Rc::as_ptr(op)).is_none() { + return Err(InvalidSubcircuitError); + } + } + Ok(()) + } +} + +/// A rewrite that applies on a static circuit. +pub struct StaticRewrite { + /// The subcircuit to be replaced. + pub subcircuit: StaticSubcircuit, + /// The replacement circuit. + pub replacement: StaticSizeCircuit, + /// The qubit map. + pub qubit_map: F, +} + +impl StaticSizeCircuit { + /// Rewrite a subcircuit in the circuit with a replacement circuit. + pub fn apply_rewrite( + &self, + rewrite: &StaticRewrite, + ) -> Result + where + F: Fn(StaticQubitIndex) -> StaticQubitIndex, + { + let mut new_circ = self.subcircuit(&rewrite.subcircuit.before(self))?; + new_circ.append(&rewrite.replacement, &rewrite.qubit_map); + let after = self.subcircuit(&rewrite.subcircuit.after(self))?; + new_circ.append(&after, |qb| qb); + Ok(new_circ) + } +} + +#[cfg(test)] +mod tests { + use crate::{utils::build_simple_circuit, Tk2Op}; + + use super::*; + + #[test] + fn test_rewrite_circuit() { + // Create initial circuit + let circuit = build_simple_circuit(2, |circ| { + circ.append(Tk2Op::H, [0])?; + circ.append(Tk2Op::CX, [0, 1])?; + circ.append(Tk2Op::CX, [0, 1])?; + circ.append(Tk2Op::CX, [0, 1])?; + circ.append(Tk2Op::H, [1])?; + Ok(()) + }) + .unwrap(); + + let initial_circuit: StaticSizeCircuit = (&circuit).try_into().unwrap(); + + // Create subcircuit to be replaced + let subcircuit = StaticSubcircuit { + op_indices: vec![ + (StaticQubitIndex(0), OpInterval(0..2)), + (StaticQubitIndex(1), OpInterval(0..1)), + ] + .into_iter() + .collect(), + }; + + let circuit = build_simple_circuit(2, |circ| { + circ.append(Tk2Op::CX, [0, 1])?; + circ.append(Tk2Op::H, [0])?; + Ok(()) + }) + .unwrap(); + + let replacement_circuit: StaticSizeCircuit = (&circuit).try_into().unwrap(); + + // Define qubit mapping + let qubit_map = |qb: StaticQubitIndex| qb; + + let rewrite = StaticRewrite { + subcircuit, + replacement: replacement_circuit, + qubit_map, + }; + + // Perform rewrite + let rewritten_circuit = initial_circuit.apply_rewrite(&rewrite).unwrap(); + + // Expected circuit after rewrite + let circuit = build_simple_circuit(2, |circ| { + circ.append(Tk2Op::CX, [0, 1])?; + circ.append(Tk2Op::H, [0])?; + circ.append(Tk2Op::CX, [0, 1])?; + circ.append(Tk2Op::CX, [0, 1])?; + circ.append(Tk2Op::H, [1])?; + Ok(()) + }) + .unwrap(); + let expected_circuit: StaticSizeCircuit = (&circuit).try_into().unwrap(); + + // Assert the rewritten circuit matches the expected circuit + assert_eq!(rewritten_circuit, expected_circuit); + } +}