Skip to content

Commit

Permalink
Remove: add_equiv, set_entry from rust-native methods.
Browse files Browse the repository at this point in the history
- Add `node_index` Rust native method.
- Use python set comparison for `Param` check.
  • Loading branch information
raynelfss committed Jul 3, 2024
1 parent 9d39cf6 commit e9e3921
Showing 1 changed file with 105 additions and 162 deletions.
267 changes: 105 additions & 162 deletions crates/circuit/src/equivalence.rs
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,7 @@ impl EquivalenceLibrary {
/// base (Optional[EquivalenceLibrary]): Base equivalence library to
/// be referenced if an entry is not found in this library.
#[new]
#[pyo3(signature= (base=None))]
fn new(base: Option<&EquivalenceLibrary>) -> Self {
if let Some(base) = base {
Self {
Expand Down Expand Up @@ -405,16 +406,51 @@ impl EquivalenceLibrary {
/// gate (Gate): A Gate instance.
/// equivalent_circuit (QuantumCircuit): A circuit equivalently
/// implementing the given Gate.
#[pyo3(name = "add_equivalence")]
fn py_add_equivalence(
fn add_equivalence(
&mut self,
py: Python,
gate: GateOper,
equivalent_circuit: CircuitRep,
mut equivalent_circuit: CircuitRep,
) -> PyResult<()> {
match self._add_equiv_native(gate, equivalent_circuit) {
Ok(_) => Ok(()),
Err(e) => Err(CircuitError::new_err(e.message)),
raise_if_shape_mismatch(&gate, &equivalent_circuit)?;
raise_if_param_mismatch(py, &gate.params, equivalent_circuit.parameters())?;

let key: Key = Key {
name: gate.operation.name().to_string(),
num_qubits: gate.operation.num_qubits(),
};
let equiv = Equivalence {
params: gate.params,
circuit: equivalent_circuit.clone(),
};

let target = self.set_default_node(key);
if let Some(node) = self._graph.node_weight_mut(target) {
node.equivs.push(equiv.clone());
}
let sources: HashSet<Key> =
HashSet::from_iter(equivalent_circuit.data().iter().map(|inst| Key {
name: inst.operation.name().to_string(),
num_qubits: inst.operation.num_qubits(),
}));
let edges = Vec::from_iter(sources.iter().map(|source| {
(
self.set_default_node(source.clone()),
target,
EdgeData {
index: self.rule_id,
num_gates: sources.len(),
rule: equiv.clone(),
source: source.clone(),
},
)
}));
for edge in edges {
self._graph.add_edge(edge.0, edge.1, edge.2);
}
self.rule_id += 1;
self.graph = None;
Ok(())
}

/// Check if a library contains any decompositions for gate.
Expand All @@ -441,12 +477,40 @@ impl EquivalenceLibrary {
/// gate (Gate): A Gate instance.
/// entry (List['QuantumCircuit']) : A list of QuantumCircuits, each
/// equivalently implementing the given Gate.
#[pyo3(name = "set_entry")]
fn py_set_entry(&mut self, gate: GateOper, entry: Vec<CircuitRep>) -> PyResult<()> {
match self._set_entry_native(gate, entry) {
Ok(_) => Ok(()),
Err(e) => Err(CircuitError::new_err(e.message)),
fn set_entry(
&mut self,
py: Python,
gate: GateOper,
mut entry: Vec<CircuitRep>,
) -> PyResult<()> {
for equiv in entry.iter_mut() {
raise_if_shape_mismatch(&gate, equiv)?;
raise_if_param_mismatch(py, &gate.params, equiv.parameters())?;
}

let key = Key {
name: gate.operation.name().to_string(),
num_qubits: gate.operation.num_qubits(),
};
let node_index = self.set_default_node(key);

if let Some(graph_ind) = self._graph.node_weight_mut(node_index) {
graph_ind.equivs.clear();
}

let edges: Vec<EdgeIndex> = self
._graph
.edges_directed(node_index, rustworkx_core::petgraph::Direction::Incoming)
.map(|x| x.id())
.collect();
for edge in edges {
self._graph.remove_edge(edge);
}
for equiv in entry {
self.add_equivalence(py, gate.clone(), equiv.clone())?
}
self.graph = None;
Ok(())
}

/// Gets the set of QuantumCircuits circuits from the library which
Expand All @@ -466,8 +530,7 @@ impl EquivalenceLibrary {
/// the library, from earliest to latest, from top to base. The
/// ordering of the StandardEquivalenceLibrary will not generally be
/// consistent across Qiskit versions.
#[pyo3(text_signature = "(gate, /,)")]
pub fn get_entry(&self, py: Python, gate: GateOper) -> PyResult<Py<PyList>> {
fn get_entry(&self, py: Python, gate: GateOper) -> PyResult<Py<PyList>> {
let key = Key {
name: gate.operation.name().to_string(),
num_qubits: gate.operation.num_qubits(),
Expand All @@ -486,7 +549,7 @@ impl EquivalenceLibrary {
}

#[getter]
fn get_graph(&mut self, py: Python<'_>) -> PyResult<PyObject> {
fn get_graph(&mut self, py: Python) -> PyResult<PyObject> {
if let Some(graph) = &self.graph {
Ok(graph.clone_ref(py))
} else {
Expand All @@ -509,19 +572,20 @@ impl EquivalenceLibrary {
}

#[pyo3(name = "keys")]
pub fn py_keys(slf: PyRef<Self>) -> PyResult<Bound<PySet>> {
fn py_keys(slf: PyRef<Self>) -> PyResult<Bound<PySet>> {
let py_set = PySet::empty_bound(slf.py())?;
for key in slf.keys() {
py_set.add(key.clone().into_py(slf.py()))?;
}
Ok(py_set)
}

fn node_index(&self, key: Key) -> usize {
self.key_to_node_index[&key].index()
#[pyo3(name = "node_index")]
fn py_node_index(&self, key: &Key) -> usize {
self.node_index(key).index()
}

fn __getstate__(slf: PyRef<Self>) -> PyResult<Bound<'_, PyDict>> {
fn __getstate__(slf: PyRef<Self>) -> PyResult<Bound<PyDict>> {
let ret = PyDict::new_bound(slf.py());
ret.set_item("rule_id", slf.rule_id)?;
let key_to_usize_node: Bound<PyDict> = PyDict::new_bound(slf.py());
Expand Down Expand Up @@ -549,7 +613,7 @@ impl EquivalenceLibrary {
Ok(ret)
}

fn __setstate__(mut slf: PyRefMut<Self>, state: &Bound<'_, PyDict>) -> PyResult<()> {
fn __setstate__(mut slf: PyRefMut<Self>, state: &Bound<PyDict>) -> PyResult<()> {
slf.rule_id = state.get_item("rule_id")?.unwrap().extract()?;
let graph_nodes_ref: Bound<PyAny> = state.get_item("graph_nodes")?.unwrap();
let graph_nodes: &Bound<PyList> = graph_nodes_ref.downcast()?;
Expand Down Expand Up @@ -604,7 +668,7 @@ impl EquivalenceLibrary {
}

/// Create a new node if key not found
fn set_default_node(&mut self, key: Key) -> NodeIndex {
pub fn set_default_node(&mut self, key: Key) -> NodeIndex {
if let Some(value) = self.key_to_node_index.get(&key) {
*value
} else {
Expand All @@ -617,166 +681,45 @@ impl EquivalenceLibrary {
}
}

/// Rust native equivalent to `EquivalenceLibrary.add_equivalence()`
///
/// Add a new equivalence to the library. Future queries for the Gate
/// will include the given circuit, in addition to all existing equivalences
/// (including those from base).
///
/// Parameterized Gates (those including `qiskit.circuit.Parameters` in their
/// `Gate.params`) can be marked equivalent to parameterized circuits,
/// provided the parameters match.
/// Retrieve the `NodeIndex` that represents a `Key`
///
/// # Arguments:
/// * `operation`: A Gate instance.
/// * `params`: A list of the gate's parameters.
/// * `equivalent_circuit`: A circuit equivalently implementing the given Gate.
pub fn add_equivalence(
&mut self,
operation: &OperationType,
params: &[Param],
equivalent_circuit: CircuitRep,
) -> Result<(), EquivalenceError> {
let gate = GateOper {
operation: operation.clone(),
params: params.to_vec().into(),
};
self._add_equiv_native(gate, equivalent_circuit)
}

fn _add_equiv_native(
&mut self,
gate: GateOper,
mut equivalent_circuit: CircuitRep,
) -> Result<(), EquivalenceError> {
raise_if_shape_mismatch(&gate, &equivalent_circuit)?;
raise_if_param_mismatch(&gate.params, equivalent_circuit.parameters())?;

let key: Key = Key {
name: gate.operation.name().to_string(),
num_qubits: gate.operation.num_qubits(),
};
let equiv = Equivalence {
params: gate.params,
circuit: equivalent_circuit.clone(),
};

let target = self.set_default_node(key);
if let Some(node) = self._graph.node_weight_mut(target) {
node.equivs.push(equiv.clone());
}
let sources: HashSet<Key> =
HashSet::from_iter(equivalent_circuit.data().iter().map(|inst| Key {
name: inst.operation.name().to_string(),
num_qubits: inst.operation.num_qubits(),
}));
let edges = Vec::from_iter(sources.iter().map(|source| {
(
self.set_default_node(source.clone()),
target,
EdgeData {
index: self.rule_id,
num_gates: sources.len(),
rule: equiv.clone(),
source: source.clone(),
},
)
}));
for edge in edges {
self._graph.add_edge(edge.0, edge.1, edge.2);
}
self.rule_id += 1;
self.graph = None;
Ok(())
}

/// Rust native equivalent to `EquivalenceLibrary.set_entry()`
/// Set the equivalence record for a Gate. Future queries for the Gate
/// will return only the circuits provided.
/// * `key`: The `Key` to look for.
///
/// Parameterized Gates (those including `qiskit.circuit.Parameters` in their
/// `Gate.params`) can be marked equivalent to parameterized circuits,
/// provided the parameters match.
///
/// # Arguments:
/// * `operation`: A Gate instance.
/// * `params`: A list of the gate's parameters.
/// * `entry` : A list of QuantumCircuits, each equivalently implementing the given Gate.
pub fn set_entry(
&mut self,
operation: &OperationType,
params: &[Param],
entry: Vec<CircuitRep>,
) -> Result<(), EquivalenceError> {
let gate = GateOper {
operation: operation.clone(),
params: params.to_vec().into(),
};
self._set_entry_native(gate, entry)
}

fn _set_entry_native(
&mut self,
gate: GateOper,
mut entry: Vec<CircuitRep>,
) -> Result<(), EquivalenceError> {
for equiv in entry.iter_mut() {
raise_if_shape_mismatch(&gate, equiv)?;
raise_if_param_mismatch(&gate.params, equiv.parameters())?;
}

let key = Key {
name: gate.operation.name().to_string(),
num_qubits: gate.operation.num_qubits(),
};
let node_index = self.set_default_node(key);

if let Some(graph_ind) = self._graph.node_weight_mut(node_index) {
graph_ind.equivs.clear();
}

let edges: Vec<EdgeIndex> = self
._graph
.edges_directed(node_index, rustworkx_core::petgraph::Direction::Incoming)
.map(|x| x.id())
.collect();
for edge in edges {
self._graph.remove_edge(edge);
}
for equiv in entry {
self._add_equiv_native(gate.clone(), equiv.clone())?
}
self.graph = None;
Ok(())
/// # Returns:
/// `NodeIndex`
pub fn node_index(&self, key: &Key) -> NodeIndex {
self.key_to_node_index[key]
}
}

fn raise_if_param_mismatch(
py: Python,
gate_params: &[Param],
circuit_parameters: &[Param],
) -> Result<(), EquivalenceError> {
let gate_params = gate_params
.iter()
.filter(|param| matches!(param, Param::ParameterExpression(_)))
.collect_vec();
if gate_params.len() == circuit_parameters.len()
&& gate_params.iter().any(|x| !circuit_parameters.contains(x))
{
return Err(EquivalenceError::new_err(format!(
) -> PyResult<()> {
let gate_params_obj = PySet::new_bound(
py,
gate_params
.iter()
.filter(|param| matches!(param, Param::ParameterExpression(_))),
)?;
if !gate_params_obj.eq(PySet::new_bound(py, circuit_parameters)?)? {
return Err(CircuitError::new_err(format!(
"Cannot add equivalence between circuit and gate \
of different parameters. Gate params: {:#?}. \
Circuit params: {:#?}.",
of different parameters. Gate params: {:?}. \
Circuit params: {:?}.",
gate_params, circuit_parameters
)));
}
Ok(())
}

fn raise_if_shape_mismatch(gate: &GateOper, circuit: &CircuitRep) -> Result<(), EquivalenceError> {
fn raise_if_shape_mismatch(gate: &GateOper, circuit: &CircuitRep) -> PyResult<()> {
if gate.operation.num_qubits() != circuit.num_qubits
|| gate.operation.num_clbits() != circuit.num_clbits
{
return Err(EquivalenceError::new_err(format!(
return Err(CircuitError::new_err(format!(
"Cannot add equivalence between circuit and gate \
of different shapes. Gate: {} qubits and {} clbits. \
Circuit: {} qubits and {} clbits.",
Expand Down Expand Up @@ -833,7 +776,7 @@ impl Display for EquivalenceError {
}
}

fn to_pygraph<N, E>(py: Python<'_>, pet_graph: &StableDiGraph<N, E>) -> PyResult<PyObject>
fn to_pygraph<N, E>(py: Python, pet_graph: &StableDiGraph<N, E>) -> PyResult<PyObject>
where
N: IntoPy<PyObject> + Clone,
E: IntoPy<PyObject> + Clone,
Expand Down

0 comments on commit e9e3921

Please sign in to comment.