Skip to content

Commit

Permalink
Merge branch 'oxidize-dag' into oxidize-dag-layers
Browse files Browse the repository at this point in the history
  • Loading branch information
raynelfss authored Jul 10, 2024
2 parents 6253c89 + d478257 commit 3f2a10a
Show file tree
Hide file tree
Showing 2 changed files with 200 additions and 64 deletions.
6 changes: 6 additions & 0 deletions crates/circuit/src/circuit_instruction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,12 @@ impl PackedInstruction {
py_op,
}
}

pub fn is_parameterized(&self) -> bool {
self.params
.iter()
.any(|x| matches!(x, Param::ParameterExpression(_)))
}
}

/// A single instruction in a :class:`.QuantumCircuit`, comprised of the :attr:`operation` and
Expand Down
258 changes: 194 additions & 64 deletions crates/circuit/src/dag_circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ use rustworkx_core::traversal::{
descendants as core_descendants,
};
use std::borrow::Borrow;
use std::collections::VecDeque;
use std::convert::Infallible;
use std::f64::consts::PI;
use std::ffi::c_double;
Expand Down Expand Up @@ -2599,9 +2600,9 @@ def _format(operand):
Ok(tup.into_any().iter().unwrap().unbind())
}

/// Iterator for edge values and source and dest node
/// Iterator for edge values with source and destination node.
///
/// This works by returning the output edges from the specified nodes. If
/// This works by returning the outgoing edges from the specified nodes. If
/// no nodes are specified all edges from the graph are returned.
///
/// Args:
Expand All @@ -2610,19 +2611,47 @@ def _format(operand):
/// all edges are returned from the graph.
///
/// Yield:
/// edge: the edge in the same format as out_edges the tuple
/// (source node, destination node, edge data)
fn edges(&self, nodes: Option<Bound<PyAny>>) -> Py<PyIterator> {
// if nodes is None:
// nodes = self._multi_graph.nodes()
//
// elif isinstance(nodes, (DAGOpNode, DAGInNode, DAGOutNode)):
// nodes = [nodes]
// for node in nodes:
// raw_nodes = self._multi_graph.out_edges(node._node_id)
// for source, dest, edge in raw_nodes:
// yield (self._multi_graph[source], self._multi_graph[dest], edge)
todo!()
/// edge: the edge as a tuple with the format
/// (source node, destination node, edge wire)
fn edges(&self, nodes: Option<Bound<PyAny>>, py: Python) -> PyResult<Py<PyIterator>> {
let get_node_index = |obj: &Bound<PyAny>| -> PyResult<NodeIndex> {
Ok(obj.downcast::<DAGNode>()?.borrow().node.unwrap())
};

let actual_nodes: Vec<_> = match nodes {
None => self.dag.node_indices().collect(),
Some(nodes) => {
let mut out = Vec::new();
if let Ok(node) = get_node_index(&nodes) {
out.push(node);
} else {
for node in nodes.iter()? {
out.push(get_node_index(&node?)?);
}
}
out
}
};

let mut edges = Vec::new();
for node in actual_nodes {
for edge in self.dag.edges_directed(node, Outgoing) {
edges.push((
self.get_node(py, edge.source())?,
self.get_node(py, edge.target())?,
match edge.weight() {
Wire::Qubit(qubit) => self.qubits.get(*qubit).unwrap(),
Wire::Clbit(clbit) => self.clbits.get(*clbit).unwrap(),
},
))
}
}

Ok(PyTuple::new_bound(py, edges)
.into_any()
.iter()
.unwrap()
.unbind())
}

/// Get the list of "op" nodes in the dag.
Expand Down Expand Up @@ -3091,61 +3120,85 @@ def _format(operand):
/// in the circuit's basis.
///
/// Nodes must have only one successor to continue the run.
fn collect_runs(&self, namelist: &Bound<PyAny>) -> PyResult<Py<PyAny>> {
// def filter_fn(node):
// return (
// isinstance(node, DAGOpNode)
// and node.op.name in namelist
// and getattr(node.op, "condition", None) is None
// )
//
// group_list = rx.collect_runs(self._multi_graph, filter_fn)
// return {tuple(x) for x in group_list}
todo!()
#[pyo3(name = "collect_runs")]
fn py_collect_runs(&self, py: Python, namelist: &Bound<PyList>) -> PyResult<Py<PySet>> {
let mut name_list_set = HashSet::with_capacity(namelist.len());
for name in namelist.iter() {
name_list_set.insert(name.extract::<String>()?);
}
match self.collect_runs(name_list_set) {
Some(runs) => {
let run_iter = runs.map(|node_indices| {
PyTuple::new_bound(
py,
node_indices
.into_iter()
.map(|node_index| self.get_node(py, node_index).unwrap()),
)
.unbind()
});
let out_set = PySet::empty_bound(py)?;
for run_tuple in run_iter {
out_set.add(run_tuple)?;
}
Ok(out_set.unbind())
}
None => Err(PyRuntimeError::new_err(
"Invalid DAGCircuit, cycle encountered",
)),
}
}

/// Return a set of non-conditional runs of 1q "op" nodes.
fn collect_1q_runs(&self) -> PyResult<Py<PyList>> {
// def filter_fn(node):
// return (
// isinstance(node, DAGOpNode)
// and len(node.qargs) == 1
// and len(node.cargs) == 0
// and isinstance(node.op, Gate)
// and hasattr(node.op, "__array__")
// and getattr(node.op, "condition", None) is None
// and not node.op.is_parameterized()
// )
//
// return rx.collect_runs(self._multi_graph, filter_fn)
todo!()
#[pyo3(name = "collect_1q_runs")]
fn py_collect_1q_runs(&self, py: Python) -> PyResult<Py<PyList>> {
match self.collect_1q_runs() {
Some(runs) => {
let runs_iter = runs.map(|node_indices| {
PyList::new_bound(
py,
node_indices
.into_iter()
.map(|node_index| self.get_node(py, node_index).unwrap()),
)
.unbind()
});
let out_list = PyList::empty_bound(py);
for run_list in runs_iter {
out_list.append(run_list)?;
}
Ok(out_list.unbind())
}
None => Err(PyRuntimeError::new_err(
"Invalid DAGCircuit, cycle encountered",
)),
}
}

/// Return a set of non-conditional runs of 2q "op" nodes.
fn collect_2q_runs(&self) -> PyResult<Py<PyList>> {
// to_qid = {}
// for i, qubit in enumerate(self.qubits):
// to_qid[qubit] = i
//
// def filter_fn(node):
// if isinstance(node, DAGOpNode):
// return (
// isinstance(node.op, Gate)
// and len(node.qargs) <= 2
// and not getattr(node.op, "condition", None)
// and not node.op.is_parameterized()
// )
// else:
// return None
//
// def color_fn(edge):
// if isinstance(edge, Qubit):
// return to_qid[edge]
// else:
// return None
//
// return rx.collect_bicolor_runs(self._multi_graph, filter_fn, color_fn)
todo!()
#[pyo3(name = "collect_2q_runs")]
fn py_collect_2q_runs(&self, py: Python) -> PyResult<Py<PyList>> {
match self.collect_2q_runs() {
Some(runs) => {
let runs_iter = runs.into_iter().map(|node_indices| {
PyList::new_bound(
py,
node_indices
.into_iter()
.map(|node_index| self.get_node(py, node_index).unwrap()),
)
.unbind()
});
let out_list = PyList::empty_bound(py);
for run_list in runs_iter {
out_list.append(run_list)?;
}
Ok(out_list.unbind())
}
None => Err(PyRuntimeError::new_err(
"Invalid DAGCircuit, cycle encountered",
)),
}
}

/// Iterator for nodes that affect a given wire.
Expand Down Expand Up @@ -3368,6 +3421,83 @@ def _format(operand):
}

impl DAGCircuit {
/// Return an iterator of gate runs with non-conditional op nodes of given names
pub fn collect_runs(
&self,
namelist: HashSet<String>,
) -> Option<impl Iterator<Item = Vec<NodeIndex>> + '_> {
let filter_fn = move |node_index: NodeIndex| -> Result<bool, Infallible> {
let node = &self.dag[node_index];
match node {
NodeType::Operation(inst) => Ok(namelist.contains(inst.op.name())
&& match &inst.extra_attrs {
None => true,
Some(attrs) => attrs.condition.is_none(),
}),
_ => Ok(false),
}
};
rustworkx_core::dag_algo::collect_runs(&self.dag, filter_fn)
.map(|node_iter| node_iter.map(|x| x.unwrap()))
}

/// Return a set of non-conditional runs of 1q "op" nodes.
pub fn collect_1q_runs(&self) -> Option<impl Iterator<Item = Vec<NodeIndex>> + '_> {
let filter_fn = move |node_index: NodeIndex| -> Result<bool, Infallible> {
let node = &self.dag[node_index];
match node {
NodeType::Operation(inst) => Ok(inst.op.num_qubits() == 1
&& inst.op.num_clbits() == 0
&& inst.op.matrix(&inst.params).is_some()
&& match &inst.extra_attrs {
None => true,
Some(attrs) => attrs.condition.is_none(),
}),
_ => Ok(false),
}
};
rustworkx_core::dag_algo::collect_runs(&self.dag, filter_fn)
.map(|node_iter| node_iter.map(|x| x.unwrap()))
}

/// Return a set of non-conditional runs of 2q "op" nodes.
pub fn collect_2q_runs(&self) -> Option<Vec<Vec<NodeIndex>>> {
let filter_fn = move |node_index: NodeIndex| -> Result<Option<bool>, Infallible> {
let node = &self.dag[node_index];
match node {
NodeType::Operation(inst) => match &inst.op {
OperationType::Standard(gate) => Ok(Some(
gate.num_qubits() <= 2
&& match &inst.extra_attrs {
None => true,
Some(attrs) => attrs.condition.is_none(),
}
&& !inst.is_parameterized(),
)),
OperationType::Gate(gate) => Ok(Some(
gate.num_qubits() <= 2
&& match &inst.extra_attrs {
None => true,
Some(attrs) => attrs.condition.is_none(),
}
&& !inst.is_parameterized(),
)),
_ => Ok(Some(false)),
},
_ => Ok(None),
}
};

let color_fn = move |edge_index: EdgeIndex| -> Result<Option<usize>, Infallible> {
let wire = self.dag.edge_weight(edge_index).unwrap();
match wire {
Wire::Qubit(index) => Ok(Some(index.0 as usize)),
Wire::Clbit(_) => Ok(None),
}
};
rustworkx_core::dag_algo::collect_bicolor_runs(&self.dag, filter_fn, color_fn).unwrap()
}

fn increment_op(&mut self, op: String) {
match self.op_names.entry(op) {
hash_map::Entry::Occupied(mut o) => {
Expand Down

0 comments on commit 3f2a10a

Please sign in to comment.