From aafd73cf8b3ee1c55633da749cf1637ed67a343b Mon Sep 17 00:00:00 2001
From: Douglas Wilson <141026920+doug-q@users.noreply.github.com>
Date: Tue, 17 Dec 2024 16:12:54 +0000
Subject: [PATCH] feat: Add monomorphization and constant folding to
 QSystemPass (#730)

Closes #729

Note that constant folding is disabled by default as it currently does
not work on modules.

---------

Co-authored-by: Seyon Sivarajah <seyon.sivarajah@cambridgequantum.com>
---
 tket2-hseries/src/lib.rs | 60 ++++++++++++++++++++++++++++++++++++----
 uv.lock                  |  2 +-
 2 files changed, 56 insertions(+), 6 deletions(-)

diff --git a/tket2-hseries/src/lib.rs b/tket2-hseries/src/lib.rs
index 0a8be413..5808afde 100644
--- a/tket2-hseries/src/lib.rs
+++ b/tket2-hseries/src/lib.rs
@@ -1,12 +1,16 @@
 //! Provides a preparation and validation workflow for Hugrs targeting
 //! Quantinuum H-series quantum computers.
+use std::mem;
+
 use derive_more::{Display, Error, From};
 use hugr::{
     algorithms::{
-        force_order,
+        const_fold::{ConstFoldError, ConstantFoldPass},
+        force_order, monomorphize, remove_polyfuncs,
         validation::{ValidatePassError, ValidationLevel},
     },
-    hugr::{hugrmut::HugrMut, HugrError},
+    hugr::HugrError,
+    Hugr, HugrView,
 };
 use tket2::Tk2Op;
 
@@ -26,9 +30,21 @@ pub mod lazify_measure;
 /// Returns an error if this cannot be done.
 ///
 /// To construct a `QSystemPass` use [Default::default].
-#[derive(Debug, Clone, Copy, Default)]
+#[derive(Debug, Clone, Copy)]
 pub struct QSystemPass {
     validation_level: ValidationLevel,
+    constant_fold: bool,
+    monomorphize: bool,
+}
+
+impl Default for QSystemPass {
+    fn default() -> Self {
+        Self {
+            validation_level: ValidationLevel::default(),
+            constant_fold: false,
+            monomorphize: true,
+        }
+    }
 }
 
 #[derive(Error, Debug, Display, From)]
@@ -43,12 +59,28 @@ pub enum QSystemPassError {
     ForceOrderError(HugrError),
     /// An error from the component [LowerTket2ToQSystemPass] pass.
     LowerTk2Error(LowerTk2Error),
+    /// An error from the component [ConstantFoldPass] pass.
+    ConstantFoldError(ConstFoldError),
 }
 
 impl QSystemPass {
-    /// Run `QSystemPass` on the given [HugrMut]. `registry` is used for
+    /// Run `QSystemPass` on the given [Hugr]. `registry` is used for
     /// validation, if enabled.
-    pub fn run(&self, hugr: &mut impl HugrMut) -> Result<(), QSystemPassError> {
+    pub fn run(&self, hugr: &mut Hugr) -> Result<(), QSystemPassError> {
+        if self.monomorphize {
+            self.validation_level.run_validated_pass(hugr, |hugr, _| {
+                let mut owned_hugr = Hugr::default();
+                mem::swap(&mut owned_hugr, hugr);
+                owned_hugr = remove_polyfuncs(monomorphize(owned_hugr));
+                mem::swap(&mut owned_hugr, hugr);
+
+                Ok::<_, QSystemPassError>(())
+            })?;
+        }
+
+        if self.constant_fold {
+            self.constant_fold().run(hugr)?;
+        }
         self.lower_tk2().run(hugr)?;
         self.lazify_measure().run(hugr)?;
         self.validation_level.run_validated_pass(hugr, |hugr, _| {
@@ -77,11 +109,29 @@ impl QSystemPass {
         LazifyMeasurePass::default().with_validation_level(self.validation_level)
     }
 
+    fn constant_fold(&self) -> ConstantFoldPass {
+        ConstantFoldPass::default().validation_level(self.validation_level)
+    }
+
     /// Returns a new `QSystemPass` with the given [ValidationLevel].
     pub fn with_validation_level(mut self, level: ValidationLevel) -> Self {
         self.validation_level = level;
         self
     }
+
+    /// Returns a new `QSystemPass` with constant folding enabled according to
+    /// `constant_fold`.
+    pub fn with_constant_fold(mut self, constant_fold: bool) -> Self {
+        self.constant_fold = constant_fold;
+        self
+    }
+
+    /// Returns a new `QSystemPass` with monomorphization enabled according to
+    /// `monomorphize`.
+    pub fn with_monormophize(mut self, monomorphize: bool) -> Self {
+        self.monomorphize = monomorphize;
+        self
+    }
 }
 
 #[cfg(test)]
diff --git a/uv.lock b/uv.lock
index e4ac503a..8c2b94fb 100644
--- a/uv.lock
+++ b/uv.lock
@@ -823,7 +823,7 @@ wheels = [
 
 [[package]]
 name = "tket2"
-version = "0.5.1"
+version = "0.6.0"
 source = { editable = "tket2-py" }
 dependencies = [
     { name = "hugr" },