Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Fast hashing for StaticSizeCircuit #555

Open
wants to merge 1 commit into
base: feat/StaticSizeCircuit
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tket2/src/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ pub mod units;
use std::iter::Sum;

pub use command::{Command, CommandIterator};
pub use hash::CircuitHash;
pub use hash::{CircuitHash, HashError};
use hugr::hugr::views::{DescendantsGraph, ExtractHugr, HierarchyView};
use itertools::Either::{Left, Right};

Expand Down
3 changes: 2 additions & 1 deletion tket2/src/static_circ.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
//! A 2d array-like representation of simple quantum circuits.

mod hash;
mod match_op;
mod rewrite;

use std::{collections::BTreeMap, fmt, rc::Rc};

Expand Down Expand Up @@ -36,7 +38,6 @@ pub struct OpLocation {

impl StaticSizeCircuit {
/// Returns the number of qubits in the circuit.
#[allow(unused)]
pub fn qubit_count(&self) -> usize {
self.qubit_ops.len()
}
Expand Down
168 changes: 168 additions & 0 deletions tket2/src/static_circ/hash.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
use std::{
hash::{Hash, Hasher},
ops::Range,
};

use cgmath::num_traits::{WrappingAdd, WrappingShl};

use crate::circuit::{CircuitHash, HashError};

use super::{
rewrite::{OpInterval, StaticRewrite},
MatchOp, StaticQubitIndex, StaticSizeCircuit,
};

pub struct UpdatableHash {
cum_hash: Vec<Vec<u64>>,
}

impl UpdatableHash {
pub fn with_static(circuit: &StaticSizeCircuit) -> Self {
let num_qubits = circuit.qubit_count();
let mut cum_hash = Vec::with_capacity(num_qubits);

for row in circuit.qubit_ops.iter() {
let mut prev_hash = 0;
let mut row_hash = Vec::with_capacity(row.len());
for op in row.iter() {
let hash = Self::hash_op(op);
let combined_hash = prev_hash.wrapping_shl(5).wrapping_add(&hash);
row_hash.push(combined_hash);
prev_hash = combined_hash;
}
cum_hash.push(row_hash);
}

Self { cum_hash }
}

/// Compute the hash of the circuit that results from applying the given rewrite.
pub fn hash_rewrite<F>(&self, circuit: &StaticSizeCircuit, rewrite: &StaticRewrite<F>) -> u64
where
F: Fn(StaticQubitIndex) -> StaticQubitIndex,
{
let new_hash = Self::with_static(&rewrite.replacement);
hash_iter((0..circuit.qubit_count()).map(|i| {
if let Some(interval) = rewrite.subcircuit.op_indices.get(&StaticQubitIndex(i)) {
splice(&self.cum_hash[i], interval, &new_hash.cum_hash[i])
} else {
*self.cum_hash[i].last().unwrap()
}
}))
}

fn hash_op(op: &MatchOp) -> u64 {
let mut hasher = fxhash::FxHasher::default();
op.hash(&mut hasher);
hasher.finish()
}
}

/// Compute the hash that results from replacing the ops in the range [start, end)
/// with the new ops (given by `new_cum_hashes`).
fn splice(cum_hashes: &[u64], interval: &OpInterval, new_cum_hashes: &[u64]) -> u64 {
let Range { start, end } = interval.0;
let mut hash = 0;
if start > 0 {
hash = hash.wrapping_add(&cum_hashes[start - 1]);
}
if !new_cum_hashes.is_empty() {
hash = hash.wrapping_shl(5 * (new_cum_hashes.len() as u32));
hash = hash.wrapping_add(new_cum_hashes[new_cum_hashes.len() - 1]);
}
if end < cum_hashes.len() {
hash = hash.wrapping_shl(5 * (cum_hashes.len() - end) as u32);
hash = hash.wrapping_add(hash_delta(cum_hashes, end..cum_hashes.len()));
}
hash
}

/// The hash "contribution" that comes from within the range [start, end).
fn hash_delta(cum_hashes: &[u64], Range { start, end }: Range<usize>) -> u64 {
if start >= end {
return 0;
}
let end_hash = if end > 0 { cum_hashes[end - 1] } else { 0 };
let start_hash = if start > 0 { cum_hashes[start - 1] } else { 0 };
let start_hash_shifted = start_hash.wrapping_shl(5 * (end - start) as u32);
end_hash.wrapping_sub(start_hash_shifted)
}

fn hash_iter<T: Hash>(iter: impl Iterator<Item = T>) -> u64 {
let mut hasher = fxhash::FxHasher::default();
for item in iter {
item.hash(&mut hasher);
}
hasher.finish()
}
impl CircuitHash for StaticSizeCircuit {
fn circuit_hash(&self) -> Result<u64, HashError> {
let hash_updater = UpdatableHash::with_static(self);
Ok(hash_iter(
hash_updater
.cum_hash
.iter()
.map(|row| row.last().unwrap_or(&0)),
))
}
}

#[cfg(test)]
mod tests {
use crate::{static_circ::rewrite::StaticSubcircuit, utils::build_simple_circuit, Tk2Op};

use super::*;

#[test]
fn test_rewrite_circuit() {
// Create initial circuit
let circuit = build_simple_circuit(2, |circ| {
circ.append(Tk2Op::H, [0])?;
circ.append(Tk2Op::CX, [0, 1])?;
circ.append(Tk2Op::CX, [0, 1])?;
circ.append(Tk2Op::CX, [0, 1])?;
circ.append(Tk2Op::H, [1])?;
Ok(())
})
.unwrap();

let initial_circuit: StaticSizeCircuit = (&circuit).try_into().unwrap();

// Create subcircuit to be replaced
let subcircuit = StaticSubcircuit {
op_indices: vec![
(StaticQubitIndex(0), OpInterval(0..2)),
(StaticQubitIndex(1), OpInterval(0..1)),
]
.into_iter()
.collect(),
};

let circuit = build_simple_circuit(2, |circ| {
circ.append(Tk2Op::CX, [0, 1])?;
circ.append(Tk2Op::H, [0])?;
Ok(())
})
.unwrap();

let replacement_circuit: StaticSizeCircuit = (&circuit).try_into().unwrap();

// Define qubit mapping
let qubit_map = |qb: StaticQubitIndex| qb;

let rewrite = StaticRewrite {
subcircuit,
replacement: replacement_circuit,
qubit_map,
};

// Perform rewrite
let rewritten_circuit = initial_circuit.apply_rewrite(&rewrite).unwrap();

// Assert the hash of the rewritten circuit matches the spliced hash
let hash_updater = UpdatableHash::with_static(&initial_circuit);
let rewritten_hash = hash_updater.hash_rewrite(&initial_circuit, &rewrite);
let expected_hash = rewritten_circuit.circuit_hash().unwrap();
assert_eq!(rewritten_hash, expected_hash);
}
}
196 changes: 196 additions & 0 deletions tket2/src/static_circ/rewrite.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
use std::{collections::BTreeMap, ops::Range, rc::Rc};

use derive_more::{From, Into};
use thiserror::Error;

use super::{OpLocation, StaticQubitIndex, StaticSizeCircuit};

/// An interval of operation indices.
#[derive(Debug, Clone, PartialEq, Eq, From, Into)]
pub(super) struct OpInterval(pub Range<usize>);

/// A subcircuit of a static circuit.
#[derive(Debug, Clone, PartialEq, Eq, From, Into)]
pub struct StaticSubcircuit {
/// Maps qubit indices to the intervals of operations on that qubit.
pub(super) op_indices: BTreeMap<StaticQubitIndex, OpInterval>,
}

impl StaticSubcircuit {
/// The subcircuit before `self`.
fn before(&self, circuit: &StaticSizeCircuit) -> Self {
let mut op_indices = BTreeMap::new();
for qb in circuit.qubits_iter() {
if let Some(interval) = self.op_indices.get(&qb) {
let start = interval.0.start;
op_indices.insert(qb, OpInterval(0..start));
} else {
op_indices.insert(qb, OpInterval(0..circuit.qubit_ops(qb).len()));
}
}
StaticSubcircuit { op_indices }
}

/// The subcircuit after `self`.
fn after(&self, circuit: &StaticSizeCircuit) -> Self {
let op_indices = self
.op_indices
.iter()
.map(|(&qb, interval)| (qb, OpInterval(interval.0.end..circuit.qubit_ops(qb).len())))
.collect();
StaticSubcircuit { op_indices }
}
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Error)]
#[error("invalid subcircuit")]
pub struct InvalidSubcircuitError;

impl StaticSizeCircuit {
fn subcircuit(&self, subcircuit: &StaticSubcircuit) -> Result<Self, InvalidSubcircuitError> {
let Self {
mut qubit_ops,
mut op_locations,
} = self.clone();
for (qb, interval) in subcircuit.op_indices.iter() {
for op in qubit_ops[qb.0].drain(interval.0.end..) {
op_locations.remove(&Rc::as_ptr(&op));
}
for op in qubit_ops[qb.0].drain(..interval.0.start) {
op_locations.remove(&Rc::as_ptr(&op));
}
}
let ret = Self {
qubit_ops,
op_locations,
};
ret.check_valid()?;
Ok(ret)
}

fn append(
&mut self,
other: &StaticSizeCircuit,
qubit_map: impl Fn(StaticQubitIndex) -> StaticQubitIndex,
) {
for (qb, ops) in other.qubit_ops.iter().enumerate() {
let new_qb = qubit_map(StaticQubitIndex(qb));
for op in ops.iter() {
let op_idx = self.qubit_ops[new_qb.0].len();
self.qubit_ops[new_qb.0].push(op.clone());
self.op_locations
.entry(Rc::as_ptr(op))
.or_default()
.push(OpLocation {
qubit: new_qb,
op_idx,
});
}
}
}

fn check_valid(&self) -> Result<(), InvalidSubcircuitError> {
for op in self.all_ops_iter() {
if self.op_locations.get(&Rc::as_ptr(op)).is_none() {
return Err(InvalidSubcircuitError);
}
}
Ok(())
}
}

/// A rewrite that applies on a static circuit.
pub struct StaticRewrite<F> {
/// The subcircuit to be replaced.
pub subcircuit: StaticSubcircuit,
/// The replacement circuit.
pub replacement: StaticSizeCircuit,
/// The qubit map.
pub qubit_map: F,
}

impl StaticSizeCircuit {
/// Rewrite a subcircuit in the circuit with a replacement circuit.
pub fn apply_rewrite<F>(
&self,
rewrite: &StaticRewrite<F>,
) -> Result<StaticSizeCircuit, InvalidSubcircuitError>
where
F: Fn(StaticQubitIndex) -> StaticQubitIndex,
{
let mut new_circ = self.subcircuit(&rewrite.subcircuit.before(self))?;
new_circ.append(&rewrite.replacement, &rewrite.qubit_map);
let after = self.subcircuit(&rewrite.subcircuit.after(self))?;
new_circ.append(&after, |qb| qb);
Ok(new_circ)
}
}

#[cfg(test)]
mod tests {
use crate::{utils::build_simple_circuit, Tk2Op};

use super::*;

#[test]
fn test_rewrite_circuit() {
// Create initial circuit
let circuit = build_simple_circuit(2, |circ| {
circ.append(Tk2Op::H, [0])?;
circ.append(Tk2Op::CX, [0, 1])?;
circ.append(Tk2Op::CX, [0, 1])?;
circ.append(Tk2Op::CX, [0, 1])?;
circ.append(Tk2Op::H, [1])?;
Ok(())
})
.unwrap();

let initial_circuit: StaticSizeCircuit = (&circuit).try_into().unwrap();

// Create subcircuit to be replaced
let subcircuit = StaticSubcircuit {
op_indices: vec![
(StaticQubitIndex(0), OpInterval(0..2)),
(StaticQubitIndex(1), OpInterval(0..1)),
]
.into_iter()
.collect(),
};

let circuit = build_simple_circuit(2, |circ| {
circ.append(Tk2Op::CX, [0, 1])?;
circ.append(Tk2Op::H, [0])?;
Ok(())
})
.unwrap();

let replacement_circuit: StaticSizeCircuit = (&circuit).try_into().unwrap();

// Define qubit mapping
let qubit_map = |qb: StaticQubitIndex| qb;

let rewrite = StaticRewrite {
subcircuit,
replacement: replacement_circuit,
qubit_map,
};

// Perform rewrite
let rewritten_circuit = initial_circuit.apply_rewrite(&rewrite).unwrap();

// Expected circuit after rewrite
let circuit = build_simple_circuit(2, |circ| {
circ.append(Tk2Op::CX, [0, 1])?;
circ.append(Tk2Op::H, [0])?;
circ.append(Tk2Op::CX, [0, 1])?;
circ.append(Tk2Op::CX, [0, 1])?;
circ.append(Tk2Op::H, [1])?;
Ok(())
})
.unwrap();
let expected_circuit: StaticSizeCircuit = (&circuit).try_into().unwrap();

// Assert the rewritten circuit matches the expected circuit
assert_eq!(rewritten_circuit, expected_circuit);
}
}