diff --git a/pyrs/nam_6_3.rwr b/pyrs/pyrs/data/nam_6_3.rwr similarity index 100% rename from pyrs/nam_6_3.rwr rename to pyrs/pyrs/data/nam_6_3.rwr diff --git a/pyrs/pyrs/passes.py b/pyrs/pyrs/passes.py new file mode 100644 index 00000000..05b6c89f --- /dev/null +++ b/pyrs/pyrs/passes.py @@ -0,0 +1,40 @@ +from pathlib import Path +from typing import Optional +import pkg_resources + +from pytket import Circuit +from pytket.passes import CustomPass +from pyrs.pyrs import passes, optimiser + + +def taso_pass( + rewriter: Optional[Path] = None, + max_threads: Optional[int] = None, + timeout: Optional[int] = None, + log_dir: Optional[Path] = None, + rebase: Optional[bool] = None, +) -> CustomPass: + """Construct a TASO pass. + + The Taso optimiser requires a pre-compiled rewriter produced by the + `compile-rewriter `_ + utility. If `rewriter` is not specified, a default one will be used. + + The arguments `max_threads`, `timeout`, `log_dir` and `rebase` are optional + and will be passed on to the TASO optimiser if provided.""" + if rewriter is None: + rewriter = pkg_resources.resource_filename("pyrs", "data/nam_6_3.rwr") + opt = optimiser.TasoOptimiser.load_precompiled(rewriter) + + def apply(circuit: Circuit) -> Circuit: + """Apply TASO optimisation to the circuit.""" + return passes.taso_optimise( + circuit, + opt, + max_threads, + timeout, + log_dir, + rebase, + ) + + return CustomPass(apply) diff --git a/pyrs/src/optimiser.rs b/pyrs/src/optimiser.rs index 6e3f1350..af31ec52 100644 --- a/pyrs/src/optimiser.rs +++ b/pyrs/src/optimiser.rs @@ -12,7 +12,7 @@ use crate::circuit::update_hugr; /// The circuit optimisation module. pub fn add_optimiser_module(py: Python, parent: &PyModule) -> PyResult<()> { let m = PyModule::new(py, "optimiser")?; - m.add_class::()?; + m.add_class::()?; parent.add_submodule(m) } @@ -22,10 +22,10 @@ pub fn add_optimiser_module(py: Python, parent: &PyModule) -> PyResult<()> { /// Currently only exposes loading from an ECC file using the constructor /// and optimising using default logging settings. #[pyclass(name = "TasoOptimiser")] -pub struct PyDefaultTasoOptimiser(DefaultTasoOptimiser); +pub struct PyTasoOptimiser(DefaultTasoOptimiser); #[pymethods] -impl PyDefaultTasoOptimiser { +impl PyTasoOptimiser { /// Create a new [`PyDefaultTasoOptimiser`] from a precompiled rewriter. #[staticmethod] pub fn load_precompiled(path: PathBuf) -> Self { @@ -80,7 +80,7 @@ impl PyDefaultTasoOptimiser { } } -impl PyDefaultTasoOptimiser { +impl PyTasoOptimiser { /// The Python optimise method, but on Hugrs. pub(super) fn optimise( &self, diff --git a/pyrs/src/pass.rs b/pyrs/src/pass.rs index 7b1e3a13..7cbc00a6 100644 --- a/pyrs/src/pass.rs +++ b/pyrs/src/pass.rs @@ -6,7 +6,7 @@ use tket_json_rs::circuit_json::SerialCircuit; use crate::{ circuit::{try_update_hugr, try_with_hugr}, - optimiser::PyDefaultTasoOptimiser, + optimiser::PyTasoOptimiser, }; #[pyfunction] @@ -44,7 +44,7 @@ fn rebase_nam(circ: &PyObject) -> PyResult<()> { /// TASO optimisation pass. /// /// HyperTKET's best attempt at optimising a circuit using circuit rewriting -/// and TASO. +/// and the given TASO optimiser. /// /// By default, the input circuit will be rebased to Nam, i.e. CX + Rz + H before /// optimising. This can be deactivated by setting `rebase` to `false`, in which @@ -55,14 +55,11 @@ fn rebase_nam(circ: &PyObject) -> PyResult<()> { /// 15min respectively. /// /// Log files will be written to the directory `log_dir` if specified. -/// -/// This requires a `nam_6_3.rwr` file in the current directory. The location -/// can alternatively be specified using the `rewriter_dir` argument. #[pyfunction] fn taso_optimise( circ: PyObject, + optimiser: &PyTasoOptimiser, max_threads: Option, - rewriter_dir: Option, timeout: Option, log_dir: Option, rebase: Option, @@ -70,7 +67,6 @@ fn taso_optimise( // Default parameter values let rebase = rebase.unwrap_or(true); let max_threads = max_threads.unwrap_or(num_cpus::get().try_into().unwrap()); - let rewrite_dir = rewriter_dir.unwrap_or(PathBuf::from(".")); let timeout = timeout.unwrap_or(30); // Create log directory if necessary if let Some(log_dir) = log_dir.as_ref() { @@ -94,9 +90,6 @@ fn taso_optimise( 1 => (vec![1], vec![timeout]), _ => unreachable!(), }; - // Load rewriter - // TODO: do not hardcode file name - let optimiser = PyDefaultTasoOptimiser::load_precompiled(rewrite_dir.join("nam_6_3.rwr")); // Optimise try_update_hugr(circ, |mut circ| { let n_cx = circ diff --git a/pyrs/test/test_pass.py b/pyrs/test/test_pass.py index 35cc3319..3ad701d8 100644 --- a/pyrs/test/test_pass.py +++ b/pyrs/test/test_pass.py @@ -1,9 +1,9 @@ from pytket import Circuit, OpType -from pyrs.pyrs import passes +from pyrs.passes import taso_pass def test_simple_taso_pass_no_opt(): c = Circuit(3).CCX(0, 1, 2) - c = passes.taso_optimise(c, max_threads = 1, timeout = 0) - print(c) - assert c.n_gates_of_type(OpType.CX) == 6 \ No newline at end of file + taso = taso_pass(max_threads = 1, timeout = 0) + taso.apply(c) + assert c.n_gates_of_type(OpType.CX) == 6 diff --git a/src/optimiser/taso/hugr_pqueue.rs b/src/optimiser/taso/hugr_pqueue.rs index 26ef4937..5738f67c 100644 --- a/src/optimiser/taso/hugr_pqueue.rs +++ b/src/optimiser/taso/hugr_pqueue.rs @@ -105,6 +105,7 @@ impl HugrPQ { } /// The cost function used by the queue. + #[allow(unused)] pub fn cost_fn(&self) -> &C { &self.cost_fn }