diff --git a/crates/circuit/src/dag_circuit.rs b/crates/circuit/src/dag_circuit.rs index 514991d0c3ff..60291eca209e 100644 --- a/crates/circuit/src/dag_circuit.rs +++ b/crates/circuit/src/dag_circuit.rs @@ -2810,53 +2810,36 @@ def _format(operand): /// Returns set of the ancestors of a node as DAGOpNodes and DAGInNodes. #[pyo3(name="ancestors")] fn py_ancestors(&self, py: Python, node: &DAGNode) -> PyResult> { - if let Some(node_index) = node.node { - let ancestors = core_ancestors(&self.dag, node_index) - .filter_map(|node| self.get_node(py, node).ok()); - let ancestor_set = PySet::empty_bound(py)?; - for ancestor in ancestors { - ancestor_set.add(ancestor)?; - } - return Ok(ancestor_set.unbind()); - } else { - unreachable!("The provided {:?} is not properly initialized.", node) - } + let ancestors: PyResult> = self.ancestors(node.node.unwrap()) + .map(|node| self.get_node(py, node)).collect(); + Ok(PySet::new_bound(py, &ancestors?)?.unbind()) } /// Returns set of the descendants of a node as DAGOpNodes and DAGOutNodes. #[pyo3(name="descendants")] fn py_descendants(&self, py: Python, node: &DAGNode) -> PyResult> { - let descendants = self.descendants(node) - .filter_map(|node| self.get_node(py, node).ok()); - let descendant_set = PySet::empty_bound(py)?; - for descendant in descendants { - descendant_set.add(descendant)?; - } - return Ok(descendant_set.unbind()); + let descendants: PyResult> = self.descendants(node.node.unwrap()) + .map(|node| self.get_node(py, node)).collect(); + Ok(PySet::new_bound(py, &descendants?)?.unbind()) } /// Returns an iterator of tuples of (DAGNode, [DAGNodes]) where the DAGNode is the current node /// and [DAGNode] is its successors in BFS order. #[pyo3(name="bfs_successors")] - fn py_bfs_successors(&self, py: Python, node: &DAGNode) -> PyResult> { - // return iter(rx.bfs_successors(self._multi_graph, node._node_id)) - let successor_index = self.bfs_successors(node).filter_map(|(node, nodes)| { - match ( - self.get_node(py, node).ok(), - nodes - .iter() - .filter_map(|sub_node| self.get_node(py, *sub_node).ok()) - .collect::>(), - ) { - (Some(node), nodes) => Some((node, nodes)), - _ => None, - } - }); - let successor_list = PyList::empty_bound(py); - for successors in successor_index { - successor_list.append(successors)?; - } - return Ok(successor_list.unbind()); + fn py_bfs_successors(&self, py: Python, node: &DAGNode) -> PyResult> { + let successor_index: PyResult)>> = self.bfs_successors(node.node.unwrap()).map(|(node, nodes)| -> PyResult<(PyObject, Vec)> { + Ok(( + self.get_node(py, node)?, + nodes + .iter() + .map(|sub_node| self.get_node(py, *sub_node)) + .collect::>>()? + )) + }).collect(); + Ok(PyList::new_bound(py, successor_index?) + .into_any() + .iter()? + .unbind()) } @@ -3611,32 +3594,19 @@ impl DAGCircuit { } /// Returns an iterator of the ancestors indices of a node. - pub fn ancestors<'a>(&'a self, node: &DAGNode) -> impl Iterator + 'a { - if let Some(node_index) = node.node { - core_ancestors(&self.dag, node_index) - } else { - unreachable!("The provided {:?} is not properly initialized.", node) - } + pub fn ancestors<'a>(&'a self, node: NodeIndex) -> impl Iterator + 'a { + core_ancestors(&self.dag, node) } /// Returns an iterator of the descendants of a node as DAGOpNodes and DAGOutNodes. - pub fn descendants<'a>(&'a self, node: &'a DAGNode) -> impl Iterator + 'a { - if let Some(node_index) = node.node { - core_descendants(&self.dag, node_index) - } else { - unreachable!("The provided {:?} is not properly initialized.", node) - } + pub fn descendants<'a>(&'a self, node: NodeIndex) -> impl Iterator + 'a { + core_descendants(&self.dag, node) } /// Returns an iterator of tuples of (DAGNode, [DAGNodes]) where the DAGNode is the current node /// and [DAGNode] is its successors in BFS order. - pub fn bfs_successors<'a>(&'a self, node: &'a DAGNode) -> impl Iterator)> + 'a { - // return iter(rx.bfs_successors(self._multi_graph, node._node_id)) - if let Some(node_index) = node.node { - core_bfs_successors(&self.dag, node_index) - } else { - unreachable!("The provided {:?} is not properly initialized.", node) - } + pub fn bfs_successors<'a>(&'a self, node: NodeIndex) -> impl Iterator)> + 'a { + core_bfs_successors(&self.dag, node) } fn unpack_into(&self, py: Python, id: NodeIndex, weight: &NodeType) -> PyResult> {