From b58bf6020feddcef5f33593269016241ebd4f5b7 Mon Sep 17 00:00:00 2001 From: Raynel Sanchez Date: Sun, 25 Aug 2024 13:27:17 -0400 Subject: [PATCH] Fix: Keep track of Vars for add_from_iter - Remove `from_iter` --- crates/circuit/src/dag_circuit.rs | 55 ++++++++++++++++++++++--------- 1 file changed, 40 insertions(+), 15 deletions(-) diff --git a/crates/circuit/src/dag_circuit.rs b/crates/circuit/src/dag_circuit.rs index 457a7852ec95..3ed5da718183 100644 --- a/crates/circuit/src/dag_circuit.rs +++ b/crates/circuit/src/dag_circuit.rs @@ -6243,14 +6243,15 @@ impl DAGCircuit { // Create HashSets to keep track of each bit/var's last node let mut qubit_last_nodes: HashMap = HashMap::default(); let mut clbit_last_nodes: HashMap = HashMap::default(); - // TODO: Keep track of vars + // TODO: Refactor once Vars are in rust + // Dict [ Var: (int, VarWeight)] + let vars_last_nodes: Bound = PyDict::new_bound(py); // Store new nodes to return let mut new_nodes = vec![]; for instr in iter { let op_name = instr.op.name(); - // TODO: Use _vars - let (all_cbits, _vars): (Vec, Option>) = { + let (all_cbits, vars): (Vec, Option>) = { // Check if the clbits are already included if self.may_have_additional_wires(py, &instr) { let mut clbits: HashSet = @@ -6296,7 +6297,6 @@ impl DAGCircuit { }; qubit_last_nodes .entry(*qubit) - .and_modify(|val| *val = (new_node, qubit_last_node.1.clone())) .or_insert((new_node, qubit_last_node.1.clone())); nodes_to_connect.insert(qubit_last_node); } @@ -6318,38 +6318,63 @@ impl DAGCircuit { }; clbit_last_nodes .entry(clbit) - .and_modify(|val| *val = (new_node, clbit_last_node.1.clone())) .or_insert((new_node, clbit_last_node.1.clone())); nodes_to_connect.insert(clbit_last_node); } - // TODO: Check all the vars in this instruction. + // If available, check all the vars in this instruction + if let Some(vars_available) = vars { + for var in vars_available { + let var_last_node = if vars_last_nodes.contains(&var)? { + let (node, wire): (usize, PyObject) = + vars_last_nodes.get_item(&var)?.unwrap().extract()?; + (NodeIndex::new(node), Wire::Var(wire)) + } else { + let output_node = self.var_output_map.get(py, &var).unwrap(); + let (edge_id, predecessor_node) = self + .dag + .edges_directed(output_node, Incoming) + .next() + .map(|edge| (edge.id(), (edge.source(), edge.weight().clone()))) + .unwrap(); + self.dag.remove_edge(edge_id); + predecessor_node + }; + if let Wire::Var(var) = &var_last_node.1 { + vars_last_nodes.set_item(var, (new_node.index(), var))? + } + nodes_to_connect.insert(var_last_node); + } + } + + // Add all of the new edges for (node, wire) in nodes_to_connect { self.dag.add_edge(node, new_node, wire); } } - // Add the output_nodes back + // Add the output_nodes back to qargs for (qubit, (node, wire)) in qubit_last_nodes { let output_node = self.qubit_io_map[qubit.0 as usize][1]; self.dag.add_edge(node, output_node, wire); } + // Add the output_nodes back to cargs for (clbit, (node, wire)) in clbit_last_nodes { let output_node = self.clbit_io_map[clbit.0 as usize][1]; self.dag.add_edge(node, output_node, wire); } - Ok(new_nodes) - } + // Add the output_nodes back to vars + for item in vars_last_nodes.items() { + let (var, (node, wire)): (PyObject, (usize, PyObject)) = item.extract()?; + let output_node = self.var_output_map.get(py, &var).unwrap(); + self.dag + .add_edge(NodeIndex::new(node), output_node, Wire::Var(wire)); + } - /// Creates an instance of DAGCircuit from an iterator over `PackedInstruction`. - fn from_iter(_py: Python, _iter: I) -> PyResult - where - I: IntoIterator, - { - todo!() + Ok(new_nodes) } }