Skip to content

Commit

Permalink
Add: implement rest of search functions.
Browse files Browse the repository at this point in the history
  • Loading branch information
raynelfss committed Jul 21, 2024
1 parent 7fcba6e commit f753fd7
Showing 1 changed file with 47 additions and 15 deletions.
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
use ahash::{HashMap, HashSet};
use pyo3::prelude::*;
use qiskit_circuit::{equivalence::{CircuitRep, EdgeData, Equivalence, EquivalenceLibrary, Key, NodeData}, operations::Param};
use qiskit_circuit::{
equivalence::{CircuitRep, EdgeData, Equivalence, Key, NodeData},
operations::Param,
};
use rustworkx_core::{
petgraph::{graph::{EdgeIndex, NodeIndex}, stable_graph::{StableDiGraph, StableGraph}, visit::Control},
traversal::{dijkstra_search, DijkstraEvent}
petgraph::{
graph::{NodeIndex},
stable_graph::{EdgeReference, StableDiGraph, StableGraph},
visit::{Control, EdgeRef},
},
};

pub struct BasisSearchVisitor<'a>
{
pub struct BasisSearchVisitor<'a> {
graph: &'a StableDiGraph<NodeData, EdgeData>,
target_basis: HashSet<&'a str>,
target_basis: HashSet<&'a Key>,
source_gates_remain: HashSet<&'a Key>,
num_gates_remain_for_rule: HashMap<usize, usize>,
basis_transforms: Vec<(&'a str, u32, &'a [Param], &'a CircuitRep)>,
Expand All @@ -21,7 +26,7 @@ impl<'a> BasisSearchVisitor<'a> {
pub fn new(
graph: &'a StableGraph<NodeData, EdgeData>,
source_basis: HashSet<&'a Key>,
target_basis: HashSet<&'a str>,
target_basis: HashSet<&'a Key>,
) -> Self {
let mut save_index = usize::MAX;
let mut num_gates_remain_for_rule = HashMap::default();
Expand Down Expand Up @@ -52,8 +57,13 @@ impl<'a> BasisSearchVisitor<'a> {
self.opt_cost_map.insert(gate, score);
}
if let Some(rule) = self.predecessors.get(gate) {
// Logger
self.basis_transforms.push((gate.name.as_str(), gate.num_qubits, &rule.params, &rule.circuit));
// TODO: Logger
self.basis_transforms.push((
gate.name.as_str(),
gate.num_qubits,
&rule.params,
&rule.circuit,
));
}
if self.source_gates_remain.is_empty() {
self.basis_transforms.reverse();
Expand All @@ -62,7 +72,7 @@ impl<'a> BasisSearchVisitor<'a> {
Control::Continue
}

pub fn examine_edge(&self, edge: EdgeIndex) -> Control<()> {
pub fn examine_edge(&mut self, edge: EdgeReference<'a, EdgeData>) -> Control<()> {
// _, target, edata = edge
// if edata is None:
// return
Expand All @@ -74,23 +84,38 @@ impl<'a> BasisSearchVisitor<'a> {
// # this `rule`. if `target` is already in basis, it's not beneficial to use this rule.
// if self._num_gates_remain_for_rule[edata.index] > 0 or target in self.target_basis:
// raise rustworkx.visit.PruneSearch
todo!()
let (target, edata) = (edge.target(), edge.weight());

// TODO: How should I handle a null edge_weight?
self.num_gates_remain_for_rule
.entry(edata.index)
.and_modify(|val| *val -= 1)
.or_default();
let target = &self.graph[target].key;

if self.num_gates_remain_for_rule[&edata.index] > 0 || self.target_basis.contains(target) {
return Control::Prune;
}
Control::Continue
}

pub fn edge_relaxed(&self, edge: EdgeIndex) -> Contol<()> {
pub fn edge_relaxed(&mut self, edge: EdgeReference<'a, EdgeData>) -> Control<()> {
// _, target, edata = edge
// if edata is not None:
// gate = self.graph[target].key
// self._predecessors[gate] = edata.rule
todo!()
let (target, edata) = (edge.target(), edge.weight());
let gate = &self.graph[target].key;
self.predecessors.insert(gate, &edata.rule);
Control::Continue
}
/// Returns the cost of an edge.
///
///
/// This function computes the cost of this edge rule by summing
/// the costs of all gates in the rule equivalence circuit. In the
/// end, we need to subtract the cost of the source since `dijkstra`
/// will later add it.
pub fn edge_cost(&self, edge_data: EdgeData) -> u32 {
pub fn edge_cost(&self, _edge_data: EdgeData) -> u32 {
// if edge_data is None:
// # the target of the edge is a gate in the target basis,
// # so we return a default value of 1.
Expand All @@ -103,5 +128,12 @@ impl<'a> BasisSearchVisitor<'a> {

// return cost_tot - self._opt_cost_map[edge_data.source]
todo!()
// TODO: Handle None case
// let mut cost_tot = 0;
// for instruction in edge_data.rule.circuit {
// let key = Key(name=instruction.operation.name, num_qubit=instruction.num_qubits);
// cost_tot += self.opt_cost_map[key]
// }
// return cost_tot - self.opt_cost_map[edge_data.source];
}
}

0 comments on commit f753fd7

Please sign in to comment.