From aafd73cf8b3ee1c55633da749cf1637ed67a343b Mon Sep 17 00:00:00 2001 From: Douglas Wilson <141026920+doug-q@users.noreply.github.com> Date: Tue, 17 Dec 2024 16:12:54 +0000 Subject: [PATCH] feat: Add monomorphization and constant folding to QSystemPass (#730) Closes #729 Note that constant folding is disabled by default as it currently does not work on modules. --------- Co-authored-by: Seyon Sivarajah <seyon.sivarajah@cambridgequantum.com> --- tket2-hseries/src/lib.rs | 60 ++++++++++++++++++++++++++++++++++++---- uv.lock | 2 +- 2 files changed, 56 insertions(+), 6 deletions(-) diff --git a/tket2-hseries/src/lib.rs b/tket2-hseries/src/lib.rs index 0a8be413..5808afde 100644 --- a/tket2-hseries/src/lib.rs +++ b/tket2-hseries/src/lib.rs @@ -1,12 +1,16 @@ //! Provides a preparation and validation workflow for Hugrs targeting //! Quantinuum H-series quantum computers. +use std::mem; + use derive_more::{Display, Error, From}; use hugr::{ algorithms::{ - force_order, + const_fold::{ConstFoldError, ConstantFoldPass}, + force_order, monomorphize, remove_polyfuncs, validation::{ValidatePassError, ValidationLevel}, }, - hugr::{hugrmut::HugrMut, HugrError}, + hugr::HugrError, + Hugr, HugrView, }; use tket2::Tk2Op; @@ -26,9 +30,21 @@ pub mod lazify_measure; /// Returns an error if this cannot be done. /// /// To construct a `QSystemPass` use [Default::default]. -#[derive(Debug, Clone, Copy, Default)] +#[derive(Debug, Clone, Copy)] pub struct QSystemPass { validation_level: ValidationLevel, + constant_fold: bool, + monomorphize: bool, +} + +impl Default for QSystemPass { + fn default() -> Self { + Self { + validation_level: ValidationLevel::default(), + constant_fold: false, + monomorphize: true, + } + } } #[derive(Error, Debug, Display, From)] @@ -43,12 +59,28 @@ pub enum QSystemPassError { ForceOrderError(HugrError), /// An error from the component [LowerTket2ToQSystemPass] pass. LowerTk2Error(LowerTk2Error), + /// An error from the component [ConstantFoldPass] pass. + ConstantFoldError(ConstFoldError), } impl QSystemPass { - /// Run `QSystemPass` on the given [HugrMut]. `registry` is used for + /// Run `QSystemPass` on the given [Hugr]. `registry` is used for /// validation, if enabled. - pub fn run(&self, hugr: &mut impl HugrMut) -> Result<(), QSystemPassError> { + pub fn run(&self, hugr: &mut Hugr) -> Result<(), QSystemPassError> { + if self.monomorphize { + self.validation_level.run_validated_pass(hugr, |hugr, _| { + let mut owned_hugr = Hugr::default(); + mem::swap(&mut owned_hugr, hugr); + owned_hugr = remove_polyfuncs(monomorphize(owned_hugr)); + mem::swap(&mut owned_hugr, hugr); + + Ok::<_, QSystemPassError>(()) + })?; + } + + if self.constant_fold { + self.constant_fold().run(hugr)?; + } self.lower_tk2().run(hugr)?; self.lazify_measure().run(hugr)?; self.validation_level.run_validated_pass(hugr, |hugr, _| { @@ -77,11 +109,29 @@ impl QSystemPass { LazifyMeasurePass::default().with_validation_level(self.validation_level) } + fn constant_fold(&self) -> ConstantFoldPass { + ConstantFoldPass::default().validation_level(self.validation_level) + } + /// Returns a new `QSystemPass` with the given [ValidationLevel]. pub fn with_validation_level(mut self, level: ValidationLevel) -> Self { self.validation_level = level; self } + + /// Returns a new `QSystemPass` with constant folding enabled according to + /// `constant_fold`. + pub fn with_constant_fold(mut self, constant_fold: bool) -> Self { + self.constant_fold = constant_fold; + self + } + + /// Returns a new `QSystemPass` with monomorphization enabled according to + /// `monomorphize`. + pub fn with_monormophize(mut self, monomorphize: bool) -> Self { + self.monomorphize = monomorphize; + self + } } #[cfg(test)] diff --git a/uv.lock b/uv.lock index e4ac503a..8c2b94fb 100644 --- a/uv.lock +++ b/uv.lock @@ -823,7 +823,7 @@ wheels = [ [[package]] name = "tket2" -version = "0.5.1" +version = "0.6.0" source = { editable = "tket2-py" } dependencies = [ { name = "hugr" },