diff --git a/crates/accelerate/src/target_transpiler/gate_map.rs b/crates/accelerate/src/target_transpiler/gate_map.rs index 1be102cfe1ad..279273ce0a5f 100644 --- a/crates/accelerate/src/target_transpiler/gate_map.rs +++ b/crates/accelerate/src/target_transpiler/gate_map.rs @@ -12,16 +12,67 @@ use super::property_map::PropsMap; use hashbrown::{hash_set::IntoIter, HashSet}; -use indexmap::IndexMap; +use indexmap::{set::IntoIter as IndexSetIntoIter, IndexMap, IndexSet}; use itertools::Itertools; -use pyo3::{exceptions::PyKeyError, prelude::*, pyclass, types::PyDict}; +use pyo3::{ + exceptions::PyKeyError, + prelude::*, + pyclass, + types::{PyDict, PySet}, +}; type GateMapType = IndexMap; type GateMapIterType = IntoIter; +type GateMapKeysIter = IndexSetIntoIter; + +enum GateMapIterTypes { + Iter(GateMapIterType), + Keys(GateMapKeysIter), +} + +#[pyclass(sequence)] +pub struct GateMapKeys { + keys: IndexSet, +} + +#[pymethods] +impl GateMapKeys { + fn __iter__(slf: PyRef) -> PyResult> { + let iter = GateMapIter { + iter: GateMapIterTypes::Keys(slf.keys.clone().into_iter()), + }; + Py::new(slf.py(), iter) + } + + fn __eq__(slf: PyRef, other: Bound) -> PyResult { + for item in other.iter() { + let key = item.extract::()?; + if !(slf.keys.contains(&key)) { + return Ok(false); + } + } + Ok(true) + } + + fn __len__(slf: PyRef) -> usize { + slf.keys.len() + } + + fn __contains__(slf: PyRef, obj: String) -> PyResult { + Ok(slf.keys.contains(&obj)) + } + + fn __repr__(slf: PyRef) -> String { + let mut output = "gate_map_keys[".to_owned(); + output.push_str(slf.keys.iter().join(", ").as_str()); + output.push(']'); + output + } +} #[pyclass] pub struct GateMapIter { - iter: GateMapIterType, + iter: GateMapIterTypes, } #[pymethods] @@ -30,7 +81,10 @@ impl GateMapIter { slf } fn __next__(mut slf: PyRefMut<'_, Self>) -> Option { - slf.iter.next() + match &mut slf.iter { + GateMapIterTypes::Iter(iter) => iter.next(), + GateMapIterTypes::Keys(iter) => iter.next(), + } } } @@ -100,18 +154,21 @@ impl GateMap { pub fn __iter__(&self, py: Python<'_>) -> PyResult> { let iter = GateMapIter { - iter: self - .map - .keys() - .cloned() - .collect::>() - .into_iter(), + iter: GateMapIterTypes::Iter( + self.map + .keys() + .cloned() + .collect::>() + .into_iter(), + ), }; Py::new(py, iter) } - pub fn keys(&self) -> HashSet { - self.map.keys().cloned().collect() + pub fn keys(&self) -> GateMapKeys { + GateMapKeys { + keys: self.map.keys().cloned().collect::>(), + } } pub fn values(&self) -> Vec { diff --git a/crates/accelerate/src/target_transpiler/mod.rs b/crates/accelerate/src/target_transpiler/mod.rs index 23b9963a15bd..18e9340761e6 100644 --- a/crates/accelerate/src/target_transpiler/mod.rs +++ b/crates/accelerate/src/target_transpiler/mod.rs @@ -35,7 +35,7 @@ use qargs::{Qargs, QargsSet}; use self::{ exceptions::{QiskitError, TranspilerError}, - gate_map::{GateMap, GateMapIter}, + gate_map::{GateMap, GateMapIter, GateMapKeys}, property_map::PropsMapKeys, qargs::QargsOrTuple, }; @@ -1962,7 +1962,7 @@ impl Target { Ok(()) } - fn keys(&self) -> HashSet { + fn keys(&self) -> GateMapKeys { self.gate_map.keys() }