Skip to content

Commit

Permalink
feat: TasoPass (#185)
Browse files Browse the repository at this point in the history
In hindsight this should have been part of #180.

However, note that the Python tests currently fail: I get a segmentation
fault, no idea why!

drive by: renamed `PyDefaultTasoOptimiser` -> `PyTasoOptimiser`.

---------

Co-authored-by: Agustin Borgna <[email protected]>
  • Loading branch information
lmondada and aborgna-q authored Oct 31, 2023
1 parent 94ed1f7 commit f6440ab
Show file tree
Hide file tree
Showing 6 changed files with 52 additions and 18 deletions.
File renamed without changes.
40 changes: 40 additions & 0 deletions pyrs/pyrs/passes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from pathlib import Path
from typing import Optional
import pkg_resources

from pytket import Circuit
from pytket.passes import CustomPass
from pyrs.pyrs import passes, optimiser


def taso_pass(
rewriter: Optional[Path] = None,
max_threads: Optional[int] = None,
timeout: Optional[int] = None,
log_dir: Optional[Path] = None,
rebase: Optional[bool] = None,
) -> CustomPass:
"""Construct a TASO pass.
The Taso optimiser requires a pre-compiled rewriter produced by the
`compile-rewriter <https://github.com/CQCL/tket2/tree/main/taso-optimiser>`_
utility. If `rewriter` is not specified, a default one will be used.
The arguments `max_threads`, `timeout`, `log_dir` and `rebase` are optional
and will be passed on to the TASO optimiser if provided."""
if rewriter is None:
rewriter = pkg_resources.resource_filename("pyrs", "data/nam_6_3.rwr")
opt = optimiser.TasoOptimiser.load_precompiled(rewriter)

def apply(circuit: Circuit) -> Circuit:
"""Apply TASO optimisation to the circuit."""
return passes.taso_optimise(
circuit,
opt,
max_threads,
timeout,
log_dir,
rebase,
)

return CustomPass(apply)
8 changes: 4 additions & 4 deletions pyrs/src/optimiser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use crate::circuit::update_hugr;
/// The circuit optimisation module.
pub fn add_optimiser_module(py: Python, parent: &PyModule) -> PyResult<()> {
let m = PyModule::new(py, "optimiser")?;
m.add_class::<PyDefaultTasoOptimiser>()?;
m.add_class::<PyTasoOptimiser>()?;

parent.add_submodule(m)
}
Expand All @@ -22,10 +22,10 @@ pub fn add_optimiser_module(py: Python, parent: &PyModule) -> PyResult<()> {
/// Currently only exposes loading from an ECC file using the constructor
/// and optimising using default logging settings.
#[pyclass(name = "TasoOptimiser")]
pub struct PyDefaultTasoOptimiser(DefaultTasoOptimiser);
pub struct PyTasoOptimiser(DefaultTasoOptimiser);

#[pymethods]
impl PyDefaultTasoOptimiser {
impl PyTasoOptimiser {
/// Create a new [`PyDefaultTasoOptimiser`] from a precompiled rewriter.
#[staticmethod]
pub fn load_precompiled(path: PathBuf) -> Self {
Expand Down Expand Up @@ -80,7 +80,7 @@ impl PyDefaultTasoOptimiser {
}
}

impl PyDefaultTasoOptimiser {
impl PyTasoOptimiser {
/// The Python optimise method, but on Hugrs.
pub(super) fn optimise(
&self,
Expand Down
13 changes: 3 additions & 10 deletions pyrs/src/pass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use tket_json_rs::circuit_json::SerialCircuit;

use crate::{
circuit::{try_update_hugr, try_with_hugr},
optimiser::PyDefaultTasoOptimiser,
optimiser::PyTasoOptimiser,
};

#[pyfunction]
Expand Down Expand Up @@ -44,7 +44,7 @@ fn rebase_nam(circ: &PyObject) -> PyResult<()> {
/// TASO optimisation pass.
///
/// HyperTKET's best attempt at optimising a circuit using circuit rewriting
/// and TASO.
/// and the given TASO optimiser.
///
/// By default, the input circuit will be rebased to Nam, i.e. CX + Rz + H before
/// optimising. This can be deactivated by setting `rebase` to `false`, in which
Expand All @@ -55,22 +55,18 @@ fn rebase_nam(circ: &PyObject) -> PyResult<()> {
/// 15min respectively.
///
/// Log files will be written to the directory `log_dir` if specified.
///
/// This requires a `nam_6_3.rwr` file in the current directory. The location
/// can alternatively be specified using the `rewriter_dir` argument.
#[pyfunction]
fn taso_optimise(
circ: PyObject,
optimiser: &PyTasoOptimiser,
max_threads: Option<NonZeroUsize>,
rewriter_dir: Option<PathBuf>,
timeout: Option<u64>,
log_dir: Option<PathBuf>,
rebase: Option<bool>,
) -> PyResult<PyObject> {
// Default parameter values
let rebase = rebase.unwrap_or(true);
let max_threads = max_threads.unwrap_or(num_cpus::get().try_into().unwrap());
let rewrite_dir = rewriter_dir.unwrap_or(PathBuf::from("."));
let timeout = timeout.unwrap_or(30);
// Create log directory if necessary
if let Some(log_dir) = log_dir.as_ref() {
Expand All @@ -94,9 +90,6 @@ fn taso_optimise(
1 => (vec![1], vec![timeout]),
_ => unreachable!(),
};
// Load rewriter
// TODO: do not hardcode file name
let optimiser = PyDefaultTasoOptimiser::load_precompiled(rewrite_dir.join("nam_6_3.rwr"));
// Optimise
try_update_hugr(circ, |mut circ| {
let n_cx = circ
Expand Down
8 changes: 4 additions & 4 deletions pyrs/test/test_pass.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from pytket import Circuit, OpType
from pyrs.pyrs import passes
from pyrs.passes import taso_pass


def test_simple_taso_pass_no_opt():
c = Circuit(3).CCX(0, 1, 2)
c = passes.taso_optimise(c, max_threads = 1, timeout = 0)
print(c)
assert c.n_gates_of_type(OpType.CX) == 6
taso = taso_pass(max_threads = 1, timeout = 0)
taso.apply(c)
assert c.n_gates_of_type(OpType.CX) == 6
1 change: 1 addition & 0 deletions src/optimiser/taso/hugr_pqueue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ impl<P: Ord, C> HugrPQ<P, C> {
}

/// The cost function used by the queue.
#[allow(unused)]
pub fn cost_fn(&self) -> &C {
&self.cost_fn
}
Expand Down

0 comments on commit f6440ab

Please sign in to comment.