Skip to content

Commit

Permalink
Fix: Rebalance PackedInstruction to avoid having dead references.
Browse files Browse the repository at this point in the history
- Make all attributes of the `PackedInstruction` private to allow for control of the underlying cached `PyObject` reference stored in Python. This ensures that no dead references are left as soon as the gate is modified from the Rust side.
- Reimplement `Clone` trait for `PackedInstruction`s to clear the cached gate whenever we decide make calls to `clone()`.
- Re-organize rest of the code to adapt to the new changes.
  • Loading branch information
raynelfss committed Dec 8, 2024
1 parent 582070d commit de6bd6f
Show file tree
Hide file tree
Showing 24 changed files with 676 additions and 608 deletions.
4 changes: 2 additions & 2 deletions crates/accelerate/src/barrier_before_final_measurement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,13 @@ pub fn barrier_before_final_measurements(
let NodeType::Operation(ref inst) = dag.dag()[*node] else {
unreachable!();
};
if !FINAL_OP_NAMES.contains(&inst.op.name()) {
if !FINAL_OP_NAMES.contains(&inst.op().name()) {
return false;
}
let is_final_op = dag.bfs_successors(*node).all(|(_, child_successors)| {
!child_successors.iter().any(|suc| match dag.dag()[*suc] {
NodeType::Operation(ref suc_inst) => {
!FINAL_OP_NAMES.contains(&suc_inst.op.name())
!FINAL_OP_NAMES.contains(&suc_inst.op().name())
}
_ => false,
})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ pub(crate) fn basis_search(
let mut cost_tot = 0;
let borrowed_cost = opt_cost_map.borrow();
for instruction in edge_data.rule.circuit.0.iter() {
let instruction_op = instruction.op.view();
let instruction_op = instruction.op().view();
cost_tot += borrowed_cost[&(
instruction_op.name().to_string(),
instruction_op.num_qubits(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ pub(super) fn compose_transforms<'a>(
.filter_map(|node| {
if let Some(NodeType::Operation(op)) = dag.dag().node_weight(node) {
if (gate_name.as_str(), *gate_num_qubits)
== (op.op.name(), op.op.num_qubits())
== (op.op().name(), op.op().num_qubits())
{
Some((
node,
Expand Down Expand Up @@ -144,11 +144,11 @@ fn get_gates_num_params(
for node in dag.op_nodes(true) {
if let Some(NodeType::Operation(op)) = dag.dag().node_weight(node) {
example_gates.insert(
(op.op.name().to_string(), op.op.num_qubits()),
(op.op().name().to_string(), op.op().num_qubits()),
op.params_view().len(),
);
if op.op.control_flow() {
let blocks = op.op.blocks();
if op.op().control_flow() {
let blocks = op.op().blocks();
for block in blocks {
get_gates_num_params_circuit(&block, example_gates)?;
}
Expand All @@ -168,11 +168,11 @@ fn get_gates_num_params_circuit(
) -> PyResult<()> {
for inst in circuit.iter() {
example_gates.insert(
(inst.op.name().to_string(), inst.op.num_qubits()),
(inst.op().name().to_string(), inst.op().num_qubits()),
inst.params_view().len(),
);
if inst.op.control_flow() {
let blocks = inst.op.blocks();
if inst.op().control_flow() {
let blocks = inst.op().blocks();
for block in blocks {
get_gates_num_params_circuit(&block, example_gates)?;
}
Expand Down
184 changes: 82 additions & 102 deletions crates/accelerate/src/basis/basis_translator/mod.rs

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions crates/accelerate/src/check_map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ fn recurse<'py>(
};
for node in dag.op_nodes(false) {
if let NodeType::Operation(inst) = &dag.dag()[node] {
let qubits = dag.get_qargs(inst.qubits);
if inst.op.control_flow() {
if let OperationRef::Instruction(py_inst) = inst.op.view() {
let qubits = dag.get_qargs(inst.qubits());
if inst.op().control_flow() {
if let OperationRef::Instruction(py_inst) = inst.op().view() {
let raw_blocks = py_inst.instruction.getattr(py, "blocks")?;
let circuit_to_dag = CIRCUIT_TO_DAG.get_bound(py);
for raw_block in raw_blocks.bind(py).iter().unwrap() {
Expand Down Expand Up @@ -71,7 +71,7 @@ fn recurse<'py>(
&& !check_qubits(qubits)
{
return Ok(Some((
inst.op.name().to_string(),
inst.op().name().to_string(),
[qubits[0].0, qubits[1].0],
)));
}
Expand Down
16 changes: 8 additions & 8 deletions crates/accelerate/src/commutation_analysis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,25 +82,25 @@ pub(crate) fn analyze_commutations_inner(
if let (NodeType::Operation(packed_inst0), NodeType::Operation(packed_inst1)) =
(&dag.dag()[current_gate_idx], &dag.dag()[*prev_gate_idx])
{
let op1 = packed_inst0.op.view();
let op2 = packed_inst1.op.view();
let op1 = packed_inst0.op().view();
let op2 = packed_inst1.op().view();
let params1 = packed_inst0.params_view();
let params2 = packed_inst1.params_view();
let qargs1 = dag.get_qargs(packed_inst0.qubits);
let qargs2 = dag.get_qargs(packed_inst1.qubits);
let cargs1 = dag.get_cargs(packed_inst0.clbits);
let cargs2 = dag.get_cargs(packed_inst1.clbits);
let qargs1 = dag.get_qargs(packed_inst0.qubits());
let qargs2 = dag.get_qargs(packed_inst1.qubits());
let cargs1 = dag.get_cargs(packed_inst0.clbits());
let cargs2 = dag.get_cargs(packed_inst1.clbits());

all_commute = commutation_checker.commute_inner(
py,
&op1,
params1,
&packed_inst0.extra_attrs,
packed_inst0.extra_attrs(),
qargs1,
cargs1,
&op2,
params2,
&packed_inst1.extra_attrs,
packed_inst1.extra_attrs(),
qargs2,
cargs2,
MAX_NUM_QUBITS,
Expand Down
14 changes: 7 additions & 7 deletions crates/accelerate/src/commutation_cancellation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,12 +116,12 @@ pub(crate) fn cancel_commutations(
NodeType::Operation(instr) => instr,
_ => panic!("Unexpected type in commutation set."),
};
let num_qargs = dag.get_qargs(instr.qubits).len();
let num_qargs = dag.get_qargs(instr.qubits()).len();
// no support for cancellation of parameterized gates
if instr.is_parameterized() {
continue;
}
if let Some(op_gate) = instr.op.try_standard_gate() {
if let Some(op_gate) = instr.op().try_standard_gate() {
if num_qargs == 1 && SUPPORTED_GATES.contains(&op_gate) {
cancellation_sets
.entry(CancellationSetKey {
Expand Down Expand Up @@ -158,8 +158,8 @@ pub(crate) fn cancel_commutations(
}
// Don't deal with Y rotation, because Y rotation doesn't commute with
// CNOT, so it should be dealt with by optimized1qgate pass
if num_qargs == 2 && dag.get_qargs(instr.qubits)[0] == wire {
let second_qarg = dag.get_qargs(instr.qubits)[1];
if num_qargs == 2 && dag.get_qargs(instr.qubits())[0] == wire {
let second_qarg = dag.get_qargs(instr.qubits())[1];
cancellation_sets
.entry(CancellationSetKey {
gate: GateOrRotation::Gate(op_gate),
Expand Down Expand Up @@ -202,14 +202,14 @@ pub(crate) fn cancel_commutations(
NodeType::Operation(instr) => instr,
_ => panic!("Unexpected type in commutation set run."),
};
let node_op_name = node_op.op.name();
let node_op_name = node_op.op().name();

let node_angle = if ROTATION_GATES.contains(&node_op_name) {
match node_op.params_view().first() {
Some(Param::Float(f)) => Ok(*f),
_ => return Err(QiskitError::new_err(format!(
"Rotational gate with parameter expression encountered in cancellation {:?}",
node_op.op
node_op.op()
)))
}
} else if HALF_TURNS.contains(&node_op_name) {
Expand All @@ -227,7 +227,7 @@ pub(crate) fn cancel_commutations(
total_angle += node_angle?;

let Param::Float(new_phase) = node_op
.op
.op()
.definition(node_op.params_view())
.unwrap()
.global_phase()
Expand Down
25 changes: 14 additions & 11 deletions crates/accelerate/src/consolidate_blocks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,8 @@ pub(crate) fn consolidate_blocks(
if !is_supported(
target,
basis_gates.as_ref(),
inst.op.name(),
dag.get_qargs(inst.qubits),
inst.op().name(),
dag.get_qargs(inst.qubits()),
) {
all_block_gates.insert(inst_node);
let matrix = match get_matrix_from_inst(py, inst) {
Expand All @@ -124,16 +124,16 @@ pub(crate) fn consolidate_blocks(
let mut outside_basis = false;
for node in &block {
let inst = dag.dag()[*node].unwrap_operation();
block_qargs.extend(dag.get_qargs(inst.qubits));
block_qargs.extend(dag.get_qargs(inst.qubits()));
all_block_gates.insert(*node);
if inst.op.name() == basis_gate_name {
if inst.op().name() == basis_gate_name {
basis_count += 1;
}
if !is_supported(
target,
basis_gates.as_ref(),
inst.op.name(),
dag.get_qargs(inst.qubits),
inst.op().name(),
dag.get_qargs(inst.qubits()),
) {
outside_basis = true;
}
Expand All @@ -154,9 +154,12 @@ pub(crate) fn consolidate_blocks(
let inst = dag.dag()[*node].unwrap_operation();

Ok((
inst.op.clone(),
inst.params_view().iter().cloned().collect(),
dag.get_qargs(inst.qubits)
inst.op().clone(),
inst.params_view()
.iter()
.map(|param| param.clone_ref(py))
.collect(),
dag.get_qargs(inst.qubits())
.iter()
.map(|x| Qubit::new(block_index_map[x]))
.collect(),
Expand Down Expand Up @@ -243,13 +246,13 @@ pub(crate) fn consolidate_blocks(
}
let first_inst_node = run[0];
let first_inst = dag.dag()[first_inst_node].unwrap_operation();
let first_qubits = dag.get_qargs(first_inst.qubits);
let first_qubits = dag.get_qargs(first_inst.qubits());

if run.len() == 1
&& !is_supported(
target,
basis_gates.as_ref(),
first_inst.op.name(),
first_inst.op().name(),
first_qubits,
)
{
Expand Down
8 changes: 4 additions & 4 deletions crates/accelerate/src/convert_2q_block_matrix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,13 @@ use crate::QiskitError;

#[inline]
pub fn get_matrix_from_inst(py: Python, inst: &PackedInstruction) -> PyResult<Array2<Complex64>> {
if let Some(mat) = inst.op.matrix(inst.params_view()) {
if let Some(mat) = inst.op().matrix(inst.params_view()) {
Ok(mat)
} else if inst.op.try_standard_gate().is_some() {
} else if inst.op().try_standard_gate().is_some() {
Err(QiskitError::new_err(
"Parameterized gates can't be consolidated",
))
} else if let OperationRef::Gate(gate) = inst.op.view() {
} else if let OperationRef::Gate(gate) = inst.op().view() {
Ok(QI_OPERATOR
.get_bound(py)
.call1((gate.gate.clone_ref(py),))?
Expand Down Expand Up @@ -75,7 +75,7 @@ pub fn blocks_to_matrix(
let inst = dag.dag()[*node].unwrap_operation();
let op_matrix = get_matrix_from_inst(py, inst)?;
match dag
.get_qargs(inst.qubits)
.get_qargs(inst.qubits())
.iter()
.map(map_bits)
.collect::<Vec<_>>()
Expand Down
20 changes: 10 additions & 10 deletions crates/accelerate/src/elide_permutations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,20 +40,20 @@ fn run(py: Python, dag: &mut DAGCircuit) -> PyResult<Option<(DAGCircuit, Vec<usi
let mut new_dag = dag.copy_empty_like(py, "alike")?;
for node_index in dag.topological_op_nodes()? {
if let NodeType::Operation(inst) = &dag.dag()[node_index] {
match (inst.op.name(), inst.condition()) {
match (inst.op().name(), inst.condition()) {
("swap", None) => {
let qargs = dag.get_qargs(inst.qubits);
let qargs = dag.get_qargs(inst.qubits());
let index0 = qargs[0].index();
let index1 = qargs[1].index();
mapping.swap(index0, index1);
}
("permutation", None) => {
if let Param::Obj(ref pyobj) = inst.params.as_ref().unwrap()[0] {
if let Param::Obj(ref pyobj) = inst.params_view()[0] {
let pyarray: PyReadonlyArray1<i32> = pyobj.extract(py)?;
let pattern = pyarray.as_array();

let qindices: Vec<usize> = dag
.get_qargs(inst.qubits)
.get_qargs(inst.qubits())
.iter()
.map(|q| q.index())
.collect();
Expand All @@ -75,22 +75,22 @@ fn run(py: Python, dag: &mut DAGCircuit) -> PyResult<Option<(DAGCircuit, Vec<usi
}
_ => {
// General instruction
let qargs = dag.get_qargs(inst.qubits);
let cargs = dag.get_cargs(inst.clbits);
let qargs = dag.get_qargs(inst.qubits());
let cargs = dag.get_cargs(inst.clbits());
let mapped_qargs: Vec<Qubit> = qargs
.iter()
.map(|q| Qubit::new(mapping[q.index()]))
.collect();

new_dag.apply_operation_back(
py,
inst.op.clone(),
inst.op().clone(),
&mapped_qargs,
cargs,
inst.params.as_deref().cloned(),
inst.extra_attrs.clone(),
(!inst.params_view().is_empty()).then_some(inst.params_view().into()),
inst.extra_attrs().clone(),
#[cfg(feature = "cache_pygates")]
inst.py_op.get().map(|x| x.clone_ref(py)),
None,
)?;
}
}
Expand Down
2 changes: 1 addition & 1 deletion crates/accelerate/src/equivalence.rs
Original file line number Diff line number Diff line change
Expand Up @@ -617,7 +617,7 @@ impl EquivalenceLibrary {
equivalent_circuit
.0
.iter()
.map(|inst| Key::from_operation(&inst.op)),
.map(|inst| Key::from_operation(inst.op())),
);
let edges = Vec::from_iter(sources.iter().map(|source| {
(
Expand Down
8 changes: 4 additions & 4 deletions crates/accelerate/src/euler_one_qubit_decomposer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1090,7 +1090,7 @@ pub(crate) fn optimize_1q_gates_decomposition(
None => raw_run.len() as f64,
};
let qubit: PhysicalQubit = if let NodeType::Operation(inst) = &dag.dag()[raw_run[0]] {
PhysicalQubit::new(dag.get_qargs(inst.qubits)[0].0)
PhysicalQubit::new(dag.get_qargs(inst.qubits())[0].0)
} else {
unreachable!("nodes in runs will always be op nodes")
};
Expand Down Expand Up @@ -1178,9 +1178,9 @@ pub(crate) fn optimize_1q_gates_decomposition(
let node = &dag.dag()[*node_index];
if let NodeType::Operation(inst) = node {
if let Some(target) = target {
error *= compute_error_term_from_target(inst.op.name(), target, qubit);
error *= compute_error_term_from_target(inst.op().name(), target, qubit);
}
inst.op.matrix(inst.params_view()).unwrap()
inst.op().matrix(inst.params_view()).unwrap()
} else {
unreachable!("Can only have op nodes here")
}
Expand Down Expand Up @@ -1219,7 +1219,7 @@ pub(crate) fn optimize_1q_gates_decomposition(
if let Some(basis) = basis_gates {
for node in &raw_run {
if let NodeType::Operation(inst) = &dag.dag()[*node] {
if !basis.contains(inst.op.name()) {
if !basis.contains(inst.op().name()) {
outside_basis = true;
break;
}
Expand Down
Loading

0 comments on commit de6bd6f

Please sign in to comment.