Skip to content

Commit

Permalink
Use rust gates for ConsolidateBlocks
Browse files Browse the repository at this point in the history
This commit moves to use rust gates for the ConsolidateBlocks transpiler
pass. Instead of generating the unitary matrices for the gates in a 2q
block Python side and passing that list to a rust function this commit
switches to passing a list of DAGOpNodes to the rust and then generating
the matrices inside the rust function directly. This is similar to what
was done in Qiskit#12650 for Optimize1qGatesDecomposition. Besides being faster
to get the matrix for standard gates, it also reduces the eager
construction of Python gate objects which was a significant source of
overhead after Qiskit#12459. To that end this builds on the thread of work in
the two PRs Qiskit#12692 and Qiskit#12701 which changed the access patterns for
other passes to minimize eager gate object construction.
  • Loading branch information
mtreinish committed Jul 2, 2024
1 parent fa774b3 commit e449357
Show file tree
Hide file tree
Showing 8 changed files with 89 additions and 73 deletions.
71 changes: 63 additions & 8 deletions crates/accelerate/src/convert_2q_block_matrix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
// copyright notice, and modified files need to carry a notice indicating
// that they have been altered from the originals.

use pyo3::intern;
use pyo3::prelude::*;
use pyo3::types::PyDict;
use pyo3::wrap_pyfunction;
use pyo3::Python;

Expand All @@ -20,32 +22,85 @@ use numpy::ndarray::{aview2, Array2, ArrayView2};
use numpy::{IntoPyArray, PyArray2, PyReadonlyArray2};
use smallvec::SmallVec;

use qiskit_circuit::bit_data::BitData;
use qiskit_circuit::circuit_instruction::{operation_type_to_py, CircuitInstruction};
use qiskit_circuit::dag_node::DAGOpNode;
use qiskit_circuit::gate_matrix::ONE_QUBIT_IDENTITY;
use qiskit_circuit::imports::QI_OPERATOR;
use qiskit_circuit::operations::{Operation, OperationType};

use crate::QiskitError;

fn get_matrix_from_inst<'py>(
py: Python<'py>,
inst: &'py CircuitInstruction,
) -> PyResult<Array2<Complex64>> {
match inst.operation.matrix(&inst.params) {
Some(mat) => Ok(mat),
None => match inst.operation {
OperationType::Standard(_) => Err(QiskitError::new_err(
"Parameterized gates can't be consolidated",
)),
OperationType::Gate(_) => Ok(QI_OPERATOR
.get_bound(py)
.call1((operation_type_to_py(py, inst)?,))?
.getattr(intern!(py, "data"))?
.extract::<PyReadonlyArray2<Complex64>>()?
.as_array()
.to_owned()),
_ => unreachable!("Only called for unitary ops"),
},
}
}

/// Return the matrix Operator resulting from a block of Instructions.
#[pyfunction]
#[pyo3(text_signature = "(op_list, /")]
pub fn blocks_to_matrix(
py: Python,
op_list: Vec<(PyReadonlyArray2<Complex64>, SmallVec<[u8; 2]>)>,
op_list: Vec<PyRef<DAGOpNode>>,
block_index_map_dict: &Bound<PyDict>,
) -> PyResult<Py<PyArray2<Complex64>>> {
// Build a BitData in block_index_map_dict order. block_index_map_dict is a dict of bits to
// indices mapping the order of the qargs in the block. There should only be 2 entries since
// there are only 2 qargs here (e.g. `{Qubit(): 0, Qubit(): 1}`) so we need to ensure that
// we added the qubits to bit data in the correct index order.
let mut index_map: Vec<PyObject> = (0..block_index_map_dict.len()).map(|_| py.None()).collect();
for bit_tuple in block_index_map_dict.items() {
let (bit, index): (PyObject, usize) = bit_tuple.extract()?;
index_map[index] = bit;
}
let mut bit_map: BitData<u32> = BitData::new(py, "qargs".to_string());
for bit in index_map {
bit_map.add(py, bit.bind(py), true)?;
}
let identity = aview2(&ONE_QUBIT_IDENTITY);
let input_matrix = op_list[0].0.as_array();
let mut matrix: Array2<Complex64> = match op_list[0].1.as_slice() {
let first_node = &op_list[0];
let input_matrix = get_matrix_from_inst(py, &first_node.instruction)?;
let mut matrix: Array2<Complex64> = match bit_map
.map_bits(first_node.instruction.qubits.bind(py).iter())?
.map(|x| x as u8)
.collect::<SmallVec<[u8; 2]>>()
.as_slice()
{
[0] => kron(&identity, &input_matrix),
[1] => kron(&input_matrix, &identity),
[0, 1] => input_matrix.to_owned(),
[1, 0] => change_basis(input_matrix),
[0, 1] => input_matrix,
[1, 0] => change_basis(input_matrix.view()),
[] => Array2::eye(4),
_ => unreachable!(),
};
for (op_matrix, q_list) in op_list.into_iter().skip(1) {
let op_matrix = op_matrix.as_array();
for node in op_list.into_iter().skip(1) {
let op_matrix = get_matrix_from_inst(py, &node.instruction)?;
let q_list = bit_map
.map_bits(node.instruction.qubits.bind(py).iter())?
.map(|x| x as u8)
.collect::<SmallVec<[u8; 2]>>();

let result = match q_list.as_slice() {
[0] => Some(kron(&identity, &op_matrix)),
[1] => Some(kron(&op_matrix, &identity)),
[1, 0] => Some(change_basis(op_matrix)),
[1, 0] => Some(change_basis(op_matrix.view())),
[] => Some(Array2::eye(4)),
_ => None,
};
Expand Down
8 changes: 6 additions & 2 deletions crates/circuit/src/bit_data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ impl PartialEq for BitAsKey {
impl Eq for BitAsKey {}

#[derive(Clone, Debug)]
pub(crate) struct BitData<T> {
pub struct BitData<T> {
/// The public field name (i.e. `qubits` or `clbits`).
description: String,
/// Registered Python bits.
Expand All @@ -81,7 +81,7 @@ pub(crate) struct BitData<T> {
cached: Py<PyList>,
}

pub(crate) struct BitNotFoundError<'py>(pub(crate) Bound<'py, PyAny>);
pub struct BitNotFoundError<'py>(pub(crate) Bound<'py, PyAny>);

impl<'py> From<BitNotFoundError<'py>> for PyErr {
fn from(error: BitNotFoundError) -> Self {
Expand Down Expand Up @@ -111,6 +111,10 @@ where
self.bits.len()
}

pub fn is_empty(&self) -> bool {
self.bits.is_empty()
}

/// Gets a reference to the underlying vector of Python bits.
#[inline]
pub fn bits(&self) -> &Vec<PyObject> {
Expand Down
1 change: 1 addition & 0 deletions crates/circuit/src/imports.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ pub static SINGLETON_GATE: ImportOnceCell =
pub static SINGLETON_CONTROLLED_GATE: ImportOnceCell =
ImportOnceCell::new("qiskit.circuit.singleton", "SingletonControlledGate");
pub static DEEPCOPY: ImportOnceCell = ImportOnceCell::new("copy", "deepcopy");
pub static QI_OPERATOR: ImportOnceCell = ImportOnceCell::new("qiskit.quantum_info", "Operator");

pub static WARNINGS_WARN: ImportOnceCell = ImportOnceCell::new("warnings", "warn");

Expand Down
2 changes: 1 addition & 1 deletion crates/circuit/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
// copyright notice, and modified files need to carry a notice indicating
// that they have been altered from the originals.

pub mod bit_data;
pub mod circuit_data;
pub mod circuit_instruction;
pub mod dag_node;
Expand All @@ -20,7 +21,6 @@ pub mod parameter_table;
pub mod slice;
pub mod util;

mod bit_data;
mod interner;

use pyo3::prelude::*;
Expand Down
13 changes: 7 additions & 6 deletions qiskit/dagcircuit/dagcircuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
from qiskit.dagcircuit.dagnode import DAGNode, DAGOpNode, DAGInNode, DAGOutNode
from qiskit.circuit.bit import Bit
from qiskit.pulse import Schedule
from qiskit._accelerate.circuit import StandardGate, PyGate

BitLocations = namedtuple("BitLocations", ("index", "registers"))
# The allowable arguments to :meth:`DAGCircuit.copy_empty_like`'s ``vars_mode``.
Expand Down Expand Up @@ -1347,9 +1348,9 @@ def replace_block_with_op(
for nd in node_block:
block_qargs |= set(nd.qargs)
block_cargs |= set(nd.cargs)
if (condition := getattr(nd.op, "condition", None)) is not None:
if (condition := getattr(nd, "condition", None)) is not None:
block_cargs.update(condition_resources(condition).clbits)
elif isinstance(nd.op, SwitchCaseOp):
elif nd.name in CONTROL_FLOW_OP_NAMES and isinstance(nd.op, SwitchCaseOp):
if isinstance(nd.op.target, Clbit):
block_cargs.add(nd.op.target)
elif isinstance(nd.op.target, ClassicalRegister):
Expand Down Expand Up @@ -1382,7 +1383,7 @@ def replace_block_with_op(
self._increment_op(op)

for nd in node_block:
self._decrement_op(nd.op)
self._decrement_op(nd)

return new_node

Expand Down Expand Up @@ -2176,10 +2177,10 @@ def collect_2q_runs(self):
def filter_fn(node):
if isinstance(node, DAGOpNode):
return (
isinstance(node.op, Gate)
isinstance(node._raw_op, (StandardGate, PyGate))
and len(node.qargs) <= 2
and not getattr(node.op, "condition", None)
and not node.op.is_parameterized()
and not getattr(node, "condition", None)
and not node.is_parameterized()
)
else:
return None
Expand Down
19 changes: 11 additions & 8 deletions qiskit/transpiler/passes/optimization/consolidate_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,11 @@
from qiskit.circuit.library.generalized_gates.unitary import UnitaryGate
from qiskit.circuit.library.standard_gates import CXGate
from qiskit.transpiler.basepasses import TransformationPass
from qiskit.circuit.controlflow import ControlFlowOp
from qiskit.transpiler.passmanager import PassManager
from qiskit.transpiler.passes.synthesis import unitary_synthesis
from qiskit.transpiler.passes.utils import _block_to_matrix
from qiskit.circuit.controlflow import CONTROL_FLOW_OP_NAMES
from qiskit._accelerate.convert_2q_block_matrix import blocks_to_matrix

from .collect_1q_runs import Collect1qRuns
from .collect_2q_blocks import Collect2qBlocks

Expand Down Expand Up @@ -105,14 +106,14 @@ def run(self, dag):
block_cargs = set()
for nd in block:
block_qargs |= set(nd.qargs)
if isinstance(nd, DAGOpNode) and getattr(nd.op, "condition", None):
block_cargs |= set(getattr(nd.op, "condition", None)[0])
if isinstance(nd, DAGOpNode) and getattr(nd, "condition", None):
block_cargs |= set(getattr(nd, "condition", None)[0])
all_block_gates.add(nd)
block_index_map = self._block_qargs_to_indices(dag, block_qargs)
for nd in block:
if nd.op.name == basis_gate_name:
if nd.name == basis_gate_name:
basis_count += 1
if self._check_not_in_basis(dag, nd.op.name, nd.qargs):
if self._check_not_in_basis(dag, nd.name, nd.qargs):
outside_basis = True
if len(block_qargs) > 2:
q = QuantumRegister(len(block_qargs))
Expand All @@ -124,7 +125,7 @@ def run(self, dag):
qc.append(nd.op, [q[block_index_map[i]] for i in nd.qargs])
unitary = UnitaryGate(Operator(qc), check_input=False)
else:
matrix = _block_to_matrix(block, block_index_map)
matrix = blocks_to_matrix(block, block_index_map)
unitary = UnitaryGate(matrix, check_input=False)

max_2q_depth = 20 # If depth > 20, there will be 1q gates to consolidate.
Expand Down Expand Up @@ -192,7 +193,9 @@ def _handle_control_flow_ops(self, dag):
pass_manager.append(Collect2qBlocks())

pass_manager.append(self)
for node in dag.op_nodes(ControlFlowOp):
for node in dag.op_nodes():
if node.name not in CONTROL_FLOW_OP_NAMES:
continue
node.op = node.op.replace_blocks(pass_manager.run(block) for block in node.op.blocks)
return dag

Expand Down
1 change: 0 additions & 1 deletion qiskit/transpiler/passes/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,3 @@

# Utility functions
from . import control_flow
from .block_to_matrix import _block_to_matrix
47 changes: 0 additions & 47 deletions qiskit/transpiler/passes/utils/block_to_matrix.py

This file was deleted.

0 comments on commit e449357

Please sign in to comment.