Skip to content

Commit

Permalink
Fix: Avoid checking validity of NodeIndex
Browse files Browse the repository at this point in the history
- Make rust_native `ancestors`, `descendants` and `bfs_successors` return instances of `NodeIndex`.
  • Loading branch information
raynelfss committed Jul 2, 2024
1 parent 80defc5 commit 96c7873
Showing 1 changed file with 26 additions and 56 deletions.
82 changes: 26 additions & 56 deletions crates/circuit/src/dag_circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2810,53 +2810,36 @@ def _format(operand):
/// Returns set of the ancestors of a node as DAGOpNodes and DAGInNodes.
#[pyo3(name="ancestors")]
fn py_ancestors(&self, py: Python, node: &DAGNode) -> PyResult<Py<PySet>> {
if let Some(node_index) = node.node {
let ancestors = core_ancestors(&self.dag, node_index)
.filter_map(|node| self.get_node(py, node).ok());
let ancestor_set = PySet::empty_bound(py)?;
for ancestor in ancestors {
ancestor_set.add(ancestor)?;
}
return Ok(ancestor_set.unbind());
} else {
unreachable!("The provided {:?} is not properly initialized.", node)
}
let ancestors: PyResult<Vec<PyObject>> = self.ancestors(node.node.unwrap())
.map(|node| self.get_node(py, node)).collect();
Ok(PySet::new_bound(py, &ancestors?)?.unbind())
}

/// Returns set of the descendants of a node as DAGOpNodes and DAGOutNodes.
#[pyo3(name="descendants")]
fn py_descendants(&self, py: Python, node: &DAGNode) -> PyResult<Py<PySet>> {
let descendants = self.descendants(node)
.filter_map(|node| self.get_node(py, node).ok());
let descendant_set = PySet::empty_bound(py)?;
for descendant in descendants {
descendant_set.add(descendant)?;
}
return Ok(descendant_set.unbind());
let descendants: PyResult<Vec<PyObject>> = self.descendants(node.node.unwrap())
.map(|node| self.get_node(py, node)).collect();
Ok(PySet::new_bound(py, &descendants?)?.unbind())
}

/// Returns an iterator of tuples of (DAGNode, [DAGNodes]) where the DAGNode is the current node
/// and [DAGNode] is its successors in BFS order.
#[pyo3(name="bfs_successors")]
fn py_bfs_successors(&self, py: Python, node: &DAGNode) -> PyResult<Py<PyList>> {
// return iter(rx.bfs_successors(self._multi_graph, node._node_id))
let successor_index = self.bfs_successors(node).filter_map(|(node, nodes)| {
match (
self.get_node(py, node).ok(),
nodes
.iter()
.filter_map(|sub_node| self.get_node(py, *sub_node).ok())
.collect::<Vec<_>>(),
) {
(Some(node), nodes) => Some((node, nodes)),
_ => None,
}
});
let successor_list = PyList::empty_bound(py);
for successors in successor_index {
successor_list.append(successors)?;
}
return Ok(successor_list.unbind());
fn py_bfs_successors(&self, py: Python, node: &DAGNode) -> PyResult<Py<PyIterator>> {
let successor_index: PyResult<Vec<(PyObject, Vec<PyObject>)>> = self.bfs_successors(node.node.unwrap()).map(|(node, nodes)| -> PyResult<(PyObject, Vec<PyObject>)> {
Ok((
self.get_node(py, node)?,
nodes
.iter()
.map(|sub_node| self.get_node(py, *sub_node))
.collect::<PyResult<Vec<_>>>()?
))
}).collect();
Ok(PyList::new_bound(py, successor_index?)
.into_any()
.iter()?
.unbind())
}


Expand Down Expand Up @@ -3611,32 +3594,19 @@ impl DAGCircuit {
}

/// Returns an iterator of the ancestors indices of a node.
pub fn ancestors<'a>(&'a self, node: &DAGNode) -> impl Iterator<Item = NodeIndex> + 'a {
if let Some(node_index) = node.node {
core_ancestors(&self.dag, node_index)
} else {
unreachable!("The provided {:?} is not properly initialized.", node)
}
pub fn ancestors<'a>(&'a self, node: NodeIndex) -> impl Iterator<Item = NodeIndex> + 'a {
core_ancestors(&self.dag, node)
}

/// Returns an iterator of the descendants of a node as DAGOpNodes and DAGOutNodes.
pub fn descendants<'a>(&'a self, node: &'a DAGNode) -> impl Iterator<Item= NodeIndex> + 'a {
if let Some(node_index) = node.node {
core_descendants(&self.dag, node_index)
} else {
unreachable!("The provided {:?} is not properly initialized.", node)
}
pub fn descendants<'a>(&'a self, node: NodeIndex) -> impl Iterator<Item= NodeIndex> + 'a {
core_descendants(&self.dag, node)
}

/// Returns an iterator of tuples of (DAGNode, [DAGNodes]) where the DAGNode is the current node
/// and [DAGNode] is its successors in BFS order.
pub fn bfs_successors<'a>(&'a self, node: &'a DAGNode) -> impl Iterator<Item = (NodeIndex, Vec<NodeIndex>)> + 'a {
// return iter(rx.bfs_successors(self._multi_graph, node._node_id))
if let Some(node_index) = node.node {
core_bfs_successors(&self.dag, node_index)
} else {
unreachable!("The provided {:?} is not properly initialized.", node)
}
pub fn bfs_successors<'a>(&'a self, node: NodeIndex) -> impl Iterator<Item = (NodeIndex, Vec<NodeIndex>)> + 'a {
core_bfs_successors(&self.dag, node)
}

fn unpack_into(&self, py: Python, id: NodeIndex, weight: &NodeType) -> PyResult<Py<PyAny>> {
Expand Down

0 comments on commit 96c7873

Please sign in to comment.