Skip to content

Commit

Permalink
feat: Add monomorphization and constant folding to QSystemPass (#730)
Browse files Browse the repository at this point in the history
Closes #729

Note that constant folding is disabled by default as it currently does
not work on modules.

---------

Co-authored-by: Seyon Sivarajah <[email protected]>
  • Loading branch information
doug-q and ss2165 authored Dec 17, 2024
1 parent 95090a2 commit aafd73c
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 6 deletions.
60 changes: 55 additions & 5 deletions tket2-hseries/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
//! Provides a preparation and validation workflow for Hugrs targeting
//! Quantinuum H-series quantum computers.
use std::mem;

use derive_more::{Display, Error, From};
use hugr::{
algorithms::{
force_order,
const_fold::{ConstFoldError, ConstantFoldPass},
force_order, monomorphize, remove_polyfuncs,
validation::{ValidatePassError, ValidationLevel},
},
hugr::{hugrmut::HugrMut, HugrError},
hugr::HugrError,
Hugr, HugrView,
};
use tket2::Tk2Op;

Expand All @@ -26,9 +30,21 @@ pub mod lazify_measure;
/// Returns an error if this cannot be done.
///
/// To construct a `QSystemPass` use [Default::default].
#[derive(Debug, Clone, Copy, Default)]
#[derive(Debug, Clone, Copy)]
pub struct QSystemPass {
validation_level: ValidationLevel,
constant_fold: bool,
monomorphize: bool,
}

impl Default for QSystemPass {
fn default() -> Self {
Self {
validation_level: ValidationLevel::default(),
constant_fold: false,
monomorphize: true,
}
}
}

#[derive(Error, Debug, Display, From)]
Expand All @@ -43,12 +59,28 @@ pub enum QSystemPassError {
ForceOrderError(HugrError),
/// An error from the component [LowerTket2ToQSystemPass] pass.
LowerTk2Error(LowerTk2Error),
/// An error from the component [ConstantFoldPass] pass.
ConstantFoldError(ConstFoldError),
}

impl QSystemPass {
/// Run `QSystemPass` on the given [HugrMut]. `registry` is used for
/// Run `QSystemPass` on the given [Hugr]. `registry` is used for
/// validation, if enabled.
pub fn run(&self, hugr: &mut impl HugrMut) -> Result<(), QSystemPassError> {
pub fn run(&self, hugr: &mut Hugr) -> Result<(), QSystemPassError> {
if self.monomorphize {
self.validation_level.run_validated_pass(hugr, |hugr, _| {
let mut owned_hugr = Hugr::default();
mem::swap(&mut owned_hugr, hugr);
owned_hugr = remove_polyfuncs(monomorphize(owned_hugr));
mem::swap(&mut owned_hugr, hugr);

Ok::<_, QSystemPassError>(())
})?;
}

if self.constant_fold {
self.constant_fold().run(hugr)?;
}
self.lower_tk2().run(hugr)?;
self.lazify_measure().run(hugr)?;
self.validation_level.run_validated_pass(hugr, |hugr, _| {
Expand Down Expand Up @@ -77,11 +109,29 @@ impl QSystemPass {
LazifyMeasurePass::default().with_validation_level(self.validation_level)
}

fn constant_fold(&self) -> ConstantFoldPass {
ConstantFoldPass::default().validation_level(self.validation_level)
}

/// Returns a new `QSystemPass` with the given [ValidationLevel].
pub fn with_validation_level(mut self, level: ValidationLevel) -> Self {
self.validation_level = level;
self
}

/// Returns a new `QSystemPass` with constant folding enabled according to
/// `constant_fold`.
pub fn with_constant_fold(mut self, constant_fold: bool) -> Self {
self.constant_fold = constant_fold;
self
}

/// Returns a new `QSystemPass` with monomorphization enabled according to
/// `monomorphize`.
pub fn with_monormophize(mut self, monomorphize: bool) -> Self {
self.monomorphize = monomorphize;
self
}
}

#[cfg(test)]
Expand Down
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit aafd73c

Please sign in to comment.