diff --git a/crates/accelerate/src/basis/basis_translator/basis_search.rs b/crates/accelerate/src/basis/basis_translator/basis_search.rs index 82b8fc474b93..a1c9a6d192a3 100644 --- a/crates/accelerate/src/basis/basis_translator/basis_search.rs +++ b/crates/accelerate/src/basis/basis_translator/basis_search.rs @@ -14,14 +14,14 @@ use std::cell::RefCell; use hashbrown::{HashMap, HashSet}; use pyo3::prelude::*; + +use crate::equivalence::{CircuitRep, EdgeData, Equivalence, EquivalenceLibrary, Key, NodeData}; use qiskit_circuit::operations::{Operation, Param}; use rustworkx_core::petgraph::stable_graph::{EdgeReference, NodeIndex, StableDiGraph}; use rustworkx_core::petgraph::visit::Control; use rustworkx_core::traversal::{dijkstra_search, DijkstraEvent}; use smallvec::SmallVec; -use crate::equivalence::{CircuitRep, EdgeData, Equivalence, EquivalenceLibrary, Key, NodeData}; - #[pyfunction] #[pyo3(name = "basis_search")] /// Search for a set of transformations from source_basis to target_basis. @@ -68,8 +68,9 @@ pub(crate) fn basis_search( ) -> Option { // Build the visitor attributes: let mut num_gates_remaining_for_rule: HashMap = HashMap::default(); - let predecessors: RefCell> = RefCell::new(HashMap::default()); - let opt_cost_map: RefCell> = RefCell::new(HashMap::default()); + let predecessors: RefCell> = + RefCell::new(HashMap::default()); + let opt_cost_map: RefCell> = RefCell::new(HashMap::default()); let mut basis_transforms: Vec<(String, u32, SmallVec<[Param; 3]>, CircuitRep)> = vec![]; // Initialize visitor attributes: @@ -124,21 +125,26 @@ pub(crate) fn basis_search( }); // Edge cost function for Visitor - let edge_weight = - |edge: EdgeReference>| -> Result { - if edge.weight().is_none() { - return Ok(1); - } - let edge_data = edge.weight().as_ref().unwrap(); - let mut cost_tot = 0; - let borrowed_cost = opt_cost_map.borrow(); - for instruction in edge_data.rule.circuit.data.iter() { - let instruction_op = instruction.op.view(); - cost_tot += borrowed_cost[&(instruction_op.name(), instruction_op.num_qubits())]; - } - Ok(cost_tot - - borrowed_cost[&(edge_data.source.name.as_str(), edge_data.source.num_qubits)]) - }; + let edge_weight = |edge: EdgeReference>| -> Result { + if edge.weight().is_none() { + return Ok(1); + } + let edge_data = edge.weight().as_ref().unwrap(); + let mut cost_tot = 0; + let borrowed_cost = opt_cost_map.borrow(); + for instruction in edge_data.rule.circuit.data.iter() { + let instruction_op = instruction.op.view(); + cost_tot += borrowed_cost[&( + instruction_op.name().to_string(), + instruction_op.num_qubits(), + )]; + } + Ok(cost_tot + - borrowed_cost[&( + edge_data.source.name.to_string(), + edge_data.source.num_qubits, + )]) + }; let basis_transforms = match dijkstra_search( &equiv_lib.graph, @@ -148,14 +154,15 @@ pub(crate) fn basis_search( match event { DijkstraEvent::Discover(n, score) => { let gate_key = &equiv_lib.graph[n].key; - let gate = &(gate_key.name.as_str(), gate_key.num_qubits); + let gate = (gate_key.name.to_string(), gate_key.num_qubits); source_basis_remain.remove(gate_key); let mut borrowed_cost_map = opt_cost_map.borrow_mut(); - borrowed_cost_map - .entry(*gate) - .and_modify(|cost_ref| *cost_ref = score) - .or_insert(score); - if let Some(rule) = predecessors.borrow().get(gate) { + if let Some(entry) = borrowed_cost_map.get_mut(&gate) { + *entry = score; + } else { + borrowed_cost_map.insert(gate.clone(), score); + } + if let Some(rule) = predecessors.borrow().get(&gate) { // TODO: Logger basis_transforms.push(( gate_key.name.to_string(), @@ -174,7 +181,7 @@ pub(crate) fn basis_search( let gate = &equiv_lib.graph[target].key; predecessors .borrow_mut() - .entry((gate.name.as_str(), gate.num_qubits)) + .entry((gate.name.to_string(), gate.num_qubits)) .and_modify(|value| *value = edata.rule.clone()) .or_insert(edata.rule.clone()); }