diff --git a/pyproject.toml b/pyproject.toml index 2f4bf5c9..20a35970 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,6 +7,7 @@ authors = [] # TODO maintainers = [] # TODO include = ["pyproject.toml"] license = "Apache-2.0" +license_file = "LICENCE" readme = "README.md" packages = [{ include = "tket2-py" }] diff --git a/tket2-py/src/circuit.rs b/tket2-py/src/circuit.rs index af848e0d..a410d31b 100644 --- a/tket2-py/src/circuit.rs +++ b/tket2-py/src/circuit.rs @@ -13,19 +13,18 @@ use tket2::json::TKETDecode; use tket2::rewrite::CircuitRewrite; use tket_json_rs::circuit_json::SerialCircuit; -pub use self::convert::{try_update_hugr, try_with_hugr, update_hugr, with_hugr, T2Circuit}; +pub use self::convert::{try_update_hugr, try_with_hugr, update_hugr, with_hugr, Tk2Circuit}; /// The module definition pub fn module(py: Python) -> PyResult<&PyModule> { let m = PyModule::new(py, "_circuit")?; - m.add_class::()?; + m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_function(wrap_pyfunction!(validate_hugr, m)?)?; m.add_function(wrap_pyfunction!(to_hugr_dot, m)?)?; - m.add_function(wrap_pyfunction!(to_hugr, m)?)?; m.add("HugrError", py.get_type::())?; m.add("BuildError", py.get_type::())?; @@ -47,20 +46,14 @@ pub fn module(py: Python) -> PyResult<&PyModule> { /// Run the validation checks on a circuit. #[pyfunction] -pub fn validate_hugr(c: Py) -> PyResult<()> { - try_with_hugr(c, |hugr| hugr.validate(®ISTRY)) +pub fn validate_hugr(c: &PyAny) -> PyResult<()> { + try_with_hugr(c, |hugr, _| hugr.validate(®ISTRY)) } /// Return a Graphviz DOT string representation of the circuit. #[pyfunction] -pub fn to_hugr_dot(c: Py) -> PyResult { - with_hugr(c, |hugr| hugr.dot_string()) -} - -/// Downcast a python object to a [`Hugr`]. -#[pyfunction] -pub fn to_hugr(c: Py) -> PyResult { - with_hugr(c, |hugr| hugr.into()) +pub fn to_hugr_dot(c: &PyAny) -> PyResult { + with_hugr(c, |hugr, _| hugr.dot_string()) } /// A [`hugr::Node`] wrapper for Python. diff --git a/tket2-py/src/circuit/convert.rs b/tket2-py/src/circuit/convert.rs index 33d183ff..47993ccd 100644 --- a/tket2-py/src/circuit/convert.rs +++ b/tket2-py/src/circuit/convert.rs @@ -1,9 +1,11 @@ //! Utilities for calling Hugr functions on generic python objects. +use pyo3::exceptions::PyAttributeError; use pyo3::{prelude::*, PyTypeInfo}; use derive_more::From; use hugr::{Hugr, HugrView}; +use serde::Serialize; use tket2::extension::REGISTRY; use tket2::json::TKETDecode; use tket2::passes::CircuitChunks; @@ -14,78 +16,149 @@ use crate::pattern::rewrite::PyCircuitRewrite; /// A manager for tket 2 operations on a tket 1 Circuit. #[pyclass] #[derive(Clone, Debug, PartialEq, From)] -pub struct T2Circuit { +pub struct Tk2Circuit { /// Rust representation of the circuit. pub hugr: Hugr, } #[pymethods] -impl T2Circuit { +impl Tk2Circuit { + /// Convert a tket1 circuit to a [`Tk2Circuit`]. #[new] - fn from_circuit(circ: PyObject) -> PyResult { + pub fn from_tket1(circ: &PyAny) -> PyResult { Ok(Self { - hugr: with_hugr(circ, |hugr| hugr)?, + hugr: with_hugr(circ, |hugr, _| hugr)?, }) } - fn finish(&self) -> PyResult { - SerialCircuit::encode(&self.hugr)?.to_tket1_with_gil() + /// Convert the [`Tk2Circuit`] to a tket1 circuit. + pub fn to_tket1<'py>(&self, py: Python<'py>) -> PyResult<&'py PyAny> { + SerialCircuit::encode(&self.hugr)?.to_tket1(py) } - fn apply_match(&mut self, rw: PyCircuitRewrite) { + /// Apply a rewrite on the circuit. + pub fn apply_match(&mut self, rw: PyCircuitRewrite) { rw.rewrite.apply(&mut self.hugr).expect("Apply error."); } + + /// Encode the circuit as a HUGR json string. + // + // TODO: Bind a messagepack encoder/decoder too. + pub fn to_hugr_json(&self) -> PyResult { + Ok(serde_json::to_string(&self.hugr).unwrap()) + } + + /// Decode a HUGR json string to a circuit. + #[staticmethod] + pub fn from_hugr_json(json: &str) -> PyResult { + let hugr = serde_json::from_str(json) + .map_err(|e| PyErr::new::(format!("Invalid encoded HUGR: {e}")))?; + Ok(Tk2Circuit { hugr }) + } + + /// Encode the circuit as a tket1 json string. + /// + /// FIXME: Currently the encoded circuit cannot be loaded back due to + /// [https://github.com/CQCL/hugr/issues/683] + pub fn to_tket1_json(&self) -> PyResult { + Ok(serde_json::to_string(&SerialCircuit::encode(&self.hugr)?).unwrap()) + } + + /// Decode a tket1 json string to a circuit. + #[staticmethod] + pub fn from_tket1_json(json: &str) -> PyResult { + let tk1: SerialCircuit = serde_json::from_str(json) + .map_err(|e| PyErr::new::(format!("Invalid encoded HUGR: {e}")))?; + Ok(Tk2Circuit { + hugr: tk1.decode()?, + }) + } } -impl T2Circuit { - /// Tries to extract a T2Circuit from a python object. +impl Tk2Circuit { + /// Tries to extract a Tk2Circuit from a python object. /// - /// Returns an error if the py object is not a T2Circuit. - pub fn try_extract(circ: Py) -> PyResult { - Python::with_gil(|py| circ.as_ref(py).extract::()) + /// Returns an error if the py object is not a Tk2Circuit. + pub fn try_extract(circ: &PyAny) -> PyResult { + circ.extract::() + } +} + +/// A flag to indicate the encoding of a circuit. +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +pub enum CircuitType { + /// A `pytket` `Circuit`. + Tket1, + /// A tket2 `Tk2Circuit`, represented as a HUGR. + Tket2, +} + +impl CircuitType { + /// Converts a `Hugr` into the format indicated by the flag. + pub fn convert(self, py: Python, hugr: Hugr) -> PyResult<&PyAny> { + match self { + CircuitType::Tket1 => SerialCircuit::encode(&hugr)?.to_tket1(py), + CircuitType::Tket2 => Ok(Py::new(py, Tk2Circuit { hugr })?.into_ref(py)), + } } } -/// Apply a fallible function expecting a hugr on a pytket circuit. -pub fn try_with_hugr(circ: Py, f: F) -> PyResult +/// Apply a fallible function expecting a hugr on a python circuit. +/// +/// This method supports both `pytket.Circuit` and `Tk2Circuit` python objects. +pub fn try_with_hugr(circ: &PyAny, f: F) -> PyResult where E: Into, - F: FnOnce(Hugr) -> Result, + F: FnOnce(Hugr, CircuitType) -> Result, { - let hugr = Python::with_gil(|py| -> PyResult { - let circ = circ.as_ref(py); - match T2Circuit::extract(circ) { - // hugr circuit - Ok(t2circ) => Ok(t2circ.hugr), - // tket1 circuit - Err(_) => Ok(SerialCircuit::from_tket1(circ)?.decode()?), - } - })?; - (f)(hugr).map_err(|e| e.into()) + let (hugr, typ) = match Tk2Circuit::extract(circ) { + // hugr circuit + Ok(t2circ) => (t2circ.hugr, CircuitType::Tket2), + // tket1 circuit + Err(_) => ( + SerialCircuit::from_tket1(circ)?.decode()?, + CircuitType::Tket1, + ), + }; + (f)(hugr, typ).map_err(|e| e.into()) } -/// Apply a function expecting a hugr on a pytket circuit. -pub fn with_hugr(circ: Py, f: F) -> PyResult +/// Apply a function expecting a hugr on a python circuit. +/// +/// This method supports both `pytket.Circuit` and `Tk2Circuit` python objects. +pub fn with_hugr(circ: &PyAny, f: F) -> PyResult where - F: FnOnce(Hugr) -> T, + F: FnOnce(Hugr, CircuitType) -> T, { - try_with_hugr(circ, |hugr| Ok::((f)(hugr))) + try_with_hugr(circ, |hugr, typ| Ok::((f)(hugr, typ))) } -/// Apply a hugr-to-hugr function on a pytket circuit, and return the modified circuit. -pub fn try_update_hugr(circ: Py, f: F) -> PyResult> +/// Apply a fallible hugr-to-hugr function on a python circuit, and return the modified circuit. +/// +/// This method supports both `pytket.Circuit` and `Tk2Circuit` python objects. +/// The returned Hugr is converted to the matching python object. +pub fn try_update_hugr(circ: &PyAny, f: F) -> PyResult<&PyAny> where E: Into, - F: FnOnce(Hugr) -> Result, + F: FnOnce(Hugr, CircuitType) -> Result, { - let hugr = try_with_hugr(circ, f)?; - SerialCircuit::encode(&hugr)?.to_tket1_with_gil() + let py = circ.py(); + try_with_hugr(circ, |hugr, typ| { + let hugr = f(hugr, typ).map_err(|e| e.into())?; + typ.convert(py, hugr) + }) } -/// Apply a hugr-to-hugr function on a pytket circuit, and return the modified circuit. -pub fn update_hugr(circ: Py, f: F) -> PyResult> +/// Apply a hugr-to-hugr function on a python circuit, and return the modified circuit. +/// +/// This method supports both `pytket.Circuit` and `Tk2Circuit` python objects. +/// The returned Hugr is converted to the matching python object. +pub fn update_hugr(circ: &PyAny, f: F) -> PyResult<&PyAny> where - F: FnOnce(Hugr) -> Hugr, + F: FnOnce(Hugr, CircuitType) -> Hugr, { - let hugr = with_hugr(circ, f)?; - SerialCircuit::encode(&hugr)?.to_tket1_with_gil() + let py = circ.py(); + try_with_hugr(circ, |hugr, typ| { + let hugr = f(hugr, typ); + typ.convert(py, hugr) + }) } diff --git a/tket2-py/src/optimiser.rs b/tket2-py/src/optimiser.rs index 76d8c40e..38a8fd11 100644 --- a/tket2-py/src/optimiser.rs +++ b/tket2-py/src/optimiser.rs @@ -57,16 +57,16 @@ impl PyBadgerOptimiser { /// * `log_progress`: The path to a CSV file to log progress to. /// #[pyo3(name = "optimise")] - pub fn py_optimise( + pub fn py_optimise<'py>( &self, - circ: PyObject, + circ: &'py PyAny, timeout: Option, n_threads: Option, split_circ: Option, log_progress: Option, queue_size: Option, - ) -> PyResult { - update_hugr(circ, |circ| { + ) -> PyResult<&'py PyAny> { + update_hugr(circ, |circ, _| { self.optimise( circ, timeout, diff --git a/tket2-py/src/passes.rs b/tket2-py/src/passes.rs index c3363c4c..68aa743d 100644 --- a/tket2-py/src/passes.rs +++ b/tket2-py/src/passes.rs @@ -5,8 +5,7 @@ pub mod chunks; use std::{cmp::min, convert::TryInto, fs, num::NonZeroUsize, path::PathBuf}; use pyo3::{prelude::*, types::IntoPyDict}; -use tket2::{json::TKETDecode, op_matches, passes::apply_greedy_commutation, Circuit, T2Op}; -use tket_json_rs::circuit_json::SerialCircuit; +use tket2::{op_matches, passes::apply_greedy_commutation, Circuit, T2Op}; use crate::{ circuit::{try_update_hugr, try_with_hugr}, @@ -30,35 +29,33 @@ pub fn module(py: Python) -> PyResult<&PyModule> { } #[pyfunction] -fn greedy_depth_reduce(py_c: PyObject) -> PyResult<(PyObject, u32)> { - try_with_hugr(py_c, |mut h| { +fn greedy_depth_reduce(circ: &PyAny) -> PyResult<(&PyAny, u32)> { + let py = circ.py(); + try_with_hugr(circ, |mut h, typ| { let n_moves = apply_greedy_commutation(&mut h)?; - let py_c = SerialCircuit::encode(&h)?.to_tket1_with_gil()?; - PyResult::Ok((py_c, n_moves)) + let circ = typ.convert(py, h)?; + PyResult::Ok((circ, n_moves)) }) } /// Rebase a circuit to the Nam gate set (CX, Rz, H) using TKET1. /// -/// Acquires the python GIL to call TKET's `auto_rebase_pass`. -/// /// Equivalent to running the following code: /// ```python /// from pytket.passes.auto_rebase import auto_rebase_pass /// from pytket import OpType /// auto_rebase_pass({OpType.CX, OpType.Rz, OpType.H}).apply(circ)" // ``` -fn rebase_nam(circ: &PyObject) -> PyResult<()> { - Python::with_gil(|py| { - let auto_rebase = py - .import("pytket.passes.auto_rebase")? - .getattr("auto_rebase_pass")?; - let optype = py.import("pytket")?.getattr("OpType")?; - let locals = [("OpType", &optype)].into_py_dict(py); - let op_set = py.eval("{OpType.CX, OpType.Rz, OpType.H}", None, Some(locals))?; - let rebase_pass = auto_rebase.call1((op_set,))?.getattr("apply")?; - rebase_pass.call1((circ,)).map(|_| ()) - }) +fn rebase_nam(circ: &PyAny) -> PyResult<()> { + let py = circ.py(); + let auto_rebase = py + .import("pytket.passes.auto_rebase")? + .getattr("auto_rebase_pass")?; + let optype = py.import("pytket")?.getattr("OpType")?; + let locals = [("OpType", &optype)].into_py_dict(py); + let op_set = py.eval("{OpType.CX, OpType.Rz, OpType.H}", None, Some(locals))?; + let rebase_pass = auto_rebase.call1((op_set,))?.getattr("apply")?; + rebase_pass.call1((circ,)).map(|_| ()) } /// Badger optimisation pass. @@ -76,14 +73,14 @@ fn rebase_nam(circ: &PyObject) -> PyResult<()> { /// /// Log files will be written to the directory `log_dir` if specified. #[pyfunction] -fn badger_optimise( - circ: PyObject, +fn badger_optimise<'py>( + circ: &'py PyAny, optimiser: &PyBadgerOptimiser, max_threads: Option, timeout: Option, log_dir: Option, rebase: Option, -) -> PyResult { +) -> PyResult<&'py PyAny> { // Default parameter values let rebase = rebase.unwrap_or(true); let max_threads = max_threads.unwrap_or(num_cpus::get().try_into().unwrap()); @@ -94,7 +91,7 @@ fn badger_optimise( } // Rebase circuit if rebase { - rebase_nam(&circ)?; + rebase_nam(circ)?; } // Logic to choose how to split the circuit let badger_splits = |n_threads: NonZeroUsize| match n_threads.get() { @@ -111,7 +108,7 @@ fn badger_optimise( _ => unreachable!(), }; // Optimise - try_update_hugr(circ, |mut circ| { + try_update_hugr(circ, |mut circ, _| { let n_cx = circ .commands() .filter(|c| op_matches(c.optype(), T2Op::CX)) diff --git a/tket2-py/src/passes/chunks.rs b/tket2-py/src/passes/chunks.rs index 4642aa42..b1a71828 100644 --- a/tket2-py/src/passes/chunks.rs +++ b/tket2-py/src/passes/chunks.rs @@ -3,21 +3,19 @@ use derive_more::From; use pyo3::exceptions::PyAttributeError; use pyo3::prelude::*; -use tket2::json::TKETDecode; use tket2::passes::CircuitChunks; use tket2::Circuit; -use tket_json_rs::circuit_json::SerialCircuit; -use crate::circuit::{with_hugr, T2Circuit}; +use crate::circuit::convert::CircuitType; +use crate::circuit::{try_with_hugr, with_hugr}; /// Split a circuit into chunks of a given size. #[pyfunction] -pub fn chunks(c: Py, max_chunk_size: usize) -> PyResult { - with_hugr(c, |hugr| { - // TODO: Detect if the circuit is in tket1 format or T2Circuit. - let is_tket1 = true; +pub fn chunks(c: &PyAny, max_chunk_size: usize) -> PyResult { + with_hugr(c, |hugr, typ| { + // TODO: Detect if the circuit is in tket1 format or Tk2Circuit. let chunks = CircuitChunks::split(&hugr, max_chunk_size); - (chunks, is_tket1).into() + (chunks, typ).into() }) } @@ -32,38 +30,36 @@ pub fn chunks(c: Py, max_chunk_size: usize) -> PyResult pub struct PyCircuitChunks { /// Rust representation of the circuit chunks. pub chunks: CircuitChunks, - /// Whether to reassemble the circuit in the tket1 format. - pub in_tket1: bool, + /// Whether to reassemble the circuit in the tket1 or tket2 format. + pub original_type: CircuitType, } #[pymethods] impl PyCircuitChunks { /// Reassemble the chunks into a circuit. - fn reassemble(&self) -> PyResult> { + fn reassemble<'py>(&self, py: Python<'py>) -> PyResult<&'py PyAny> { let hugr = self.clone().chunks.reassemble()?; - Python::with_gil(|py| match self.in_tket1 { - true => Ok(SerialCircuit::encode(&hugr)?.to_tket1(py)?.into_py(py)), - false => Ok(T2Circuit { hugr }.into_py(py)), - }) + self.original_type.convert(py, hugr) } /// Returns clones of the split circuits. - fn circuits(&self) -> PyResult>> { + fn circuits<'py>(&self, py: Python<'py>) -> PyResult> { self.chunks .iter() - .map(|hugr| SerialCircuit::encode(hugr)?.to_tket1_with_gil()) + .map(|hugr| self.original_type.convert(py, hugr.clone())) .collect() } /// Replaces a chunk's circuit with an updated version. - fn update_circuit(&mut self, index: usize, new_circ: Py) -> PyResult<()> { - let hugr = SerialCircuit::from_tket1_with_gil(new_circ)?.decode()?; - if hugr.circuit_signature() != self.chunks[index].circuit_signature() { - return Err(PyAttributeError::new_err( - "The new circuit has a different signature.", - )); - } - self.chunks[index] = hugr; - Ok(()) + fn update_circuit(&mut self, index: usize, new_circ: &PyAny) -> PyResult<()> { + try_with_hugr(new_circ, |hugr, _| { + if hugr.circuit_signature() != self.chunks[index].circuit_signature() { + return Err(PyAttributeError::new_err( + "The new circuit has a different signature.", + )); + } + self.chunks[index] = hugr; + Ok(()) + }) } } diff --git a/tket2-py/src/pattern.rs b/tket2-py/src/pattern.rs index ed3899e1..322b5220 100644 --- a/tket2-py/src/pattern.rs +++ b/tket2-py/src/pattern.rs @@ -3,7 +3,7 @@ pub mod portmatching; pub mod rewrite; -use crate::circuit::{to_hugr, T2Circuit}; +use crate::circuit::Tk2Circuit; use hugr::Hugr; use pyo3::prelude::*; @@ -45,9 +45,9 @@ pub struct Rule(pub [Hugr; 2]); #[pymethods] impl Rule { #[new] - fn new_rule(l: PyObject, r: PyObject) -> PyResult { - let l = to_hugr(l)?; - let r = to_hugr(r)?; + fn new_rule(l: &PyAny, r: &PyAny) -> PyResult { + let l = Tk2Circuit::from_tket1(l)?; + let r = Tk2Circuit::from_tket1(r)?; Ok(Rule([l.hugr, r.hugr])) } @@ -71,7 +71,7 @@ impl RuleMatcher { Ok(Self { matcher, rights }) } - pub fn find_match(&self, target: &T2Circuit) -> PyResult> { + pub fn find_match(&self, target: &Tk2Circuit) -> PyResult> { let h = &target.hugr; if let Some(p_match) = self.matcher.find_matches_iter(h).next() { let r = self.rights.get(p_match.pattern_id().0).unwrap().clone(); diff --git a/tket2-py/src/pattern/portmatching.rs b/tket2-py/src/pattern/portmatching.rs index c73e0548..8623bac3 100644 --- a/tket2-py/src/pattern/portmatching.rs +++ b/tket2-py/src/pattern/portmatching.rs @@ -29,8 +29,8 @@ pub struct PyCircuitPattern { impl PyCircuitPattern { /// Construct a pattern from a TKET1 circuit #[new] - pub fn from_circuit(circ: Py) -> PyResult { - let pattern = try_with_hugr(circ, |circ| CircuitPattern::try_from_circuit(&circ))?; + pub fn from_circuit(circ: &PyAny) -> PyResult { + let pattern = try_with_hugr(circ, |circ, _| CircuitPattern::try_from_circuit(&circ))?; Ok(pattern.into()) } @@ -79,8 +79,8 @@ impl PyPatternMatcher { } /// Find all convex matches in a circuit. - pub fn find_matches(&self, circ: PyObject) -> PyResult> { - with_hugr(circ, |circ| { + pub fn find_matches(&self, circ: &PyAny) -> PyResult> { + with_hugr(circ, |circ, _| { self.matcher .find_matches(&circ) .into_iter() diff --git a/tket2-py/test/test_bindings.py b/tket2-py/test/test_bindings.py index 9385b6d4..28aa8c0c 100644 --- a/tket2-py/test/test_bindings.py +++ b/tket2-py/test/test_bindings.py @@ -3,10 +3,26 @@ from tket2 import passes from tket2.passes import greedy_depth_reduce -from tket2.circuit import T2Circuit +from tket2.circuit import Tk2Circuit, to_hugr_dot from tket2.pattern import Rule, RuleMatcher +def test_conversion(): + tk1 = Circuit(4).CX(0, 2).CX(1, 2).CX(1, 3) + tk1_dot = to_hugr_dot(tk1) + + tk2 = Tk2Circuit(tk1) + tk2_dot = to_hugr_dot(tk2) + + assert type(tk2) == Tk2Circuit + assert tk1_dot == tk2_dot + + tk1_back = tk2.to_tket1() + + assert tk1_back == tk1 + assert type(tk1_back) == Circuit + + @dataclass class DepthOptimisePass: def apply(self, circ: Circuit) -> Circuit: @@ -35,10 +51,17 @@ def test_chunks(): c2 = chunks.reassemble() assert c2.depth() == 3 + assert type(c2) == Circuit + + # Split and reassemble, with a tket2 circuit + tk2_chunks = passes.chunks(Tk2Circuit(c2), 2) + tk2 = tk2_chunks.reassemble() + + assert type(tk2) == Tk2Circuit def test_cx_rule(): - c = T2Circuit(Circuit(4).CX(0, 2).CX(1, 2).CX(1, 2)) + c = Tk2Circuit(Circuit(4).CX(0, 2).CX(1, 2).CX(1, 2)) rule = Rule(Circuit(2).CX(0, 1).CX(0, 1), Circuit(2)) matcher = RuleMatcher([rule]) @@ -47,13 +70,13 @@ def test_cx_rule(): c.apply_match(mtch) - out = c.finish() + out = c.to_tket1() assert out == Circuit(4).CX(0, 2) def test_multiple_rules(): - circ = T2Circuit(Circuit(3).CX(0, 1).H(0).H(1).H(2).Z(0).H(0).H(1).H(2)) + circ = Tk2Circuit(Circuit(3).CX(0, 1).H(0).H(1).H(2).Z(0).H(0).H(1).H(2)) rule1 = Rule(Circuit(1).H(0).Z(0).H(0), Circuit(1).X(0)) rule2 = Rule(Circuit(1).H(0).H(0), Circuit(1)) @@ -66,5 +89,5 @@ def test_multiple_rules(): assert match_count == 3 - out = circ.finish() + out = circ.to_tket1() assert out == Circuit(3).CX(0, 1).X(0)