Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use rust gates for ConsolidateBlocks #12704

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 96 additions & 8 deletions crates/accelerate/src/convert_2q_block_matrix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
// copyright notice, and modified files need to carry a notice indicating
// that they have been altered from the originals.

use pyo3::intern;
use pyo3::prelude::*;
use pyo3::types::PyDict;
use pyo3::wrap_pyfunction;
use pyo3::Python;

Expand All @@ -20,32 +22,84 @@ use numpy::ndarray::{aview2, Array2, ArrayView2};
use numpy::{IntoPyArray, PyArray2, PyReadonlyArray2};
use smallvec::SmallVec;

use qiskit_circuit::bit_data::BitData;
use qiskit_circuit::circuit_instruction::{operation_type_to_py, CircuitInstruction};
use qiskit_circuit::dag_node::DAGOpNode;
use qiskit_circuit::gate_matrix::ONE_QUBIT_IDENTITY;
use qiskit_circuit::imports::QI_OPERATOR;
use qiskit_circuit::operations::{Operation, OperationType};

use crate::QiskitError;

fn get_matrix_from_inst<'py>(
py: Python<'py>,
inst: &'py CircuitInstruction,
) -> PyResult<Array2<Complex64>> {
match inst.operation.matrix(&inst.params) {
Some(mat) => Ok(mat),
None => match inst.operation {
OperationType::Standard(_) => Err(QiskitError::new_err(
"Parameterized gates can't be consolidated",
)),
OperationType::Gate(_) => Ok(QI_OPERATOR
.get_bound(py)
.call1((operation_type_to_py(py, inst)?,))?
.getattr(intern!(py, "data"))?
.extract::<PyReadonlyArray2<Complex64>>()?
.as_array()
.to_owned()),
_ => unreachable!("Only called for unitary ops"),
},
}
}

/// Return the matrix Operator resulting from a block of Instructions.
#[pyfunction]
#[pyo3(text_signature = "(op_list, /")]
pub fn blocks_to_matrix(
py: Python,
op_list: Vec<(PyReadonlyArray2<Complex64>, SmallVec<[u8; 2]>)>,
op_list: Vec<PyRef<DAGOpNode>>,
block_index_map_dict: &Bound<PyDict>,
) -> PyResult<Py<PyArray2<Complex64>>> {
// Build a BitData in block_index_map_dict order. block_index_map_dict is a dict of bits to
// indices mapping the order of the qargs in the block. There should only be 2 entries since
// there are only 2 qargs here (e.g. `{Qubit(): 0, Qubit(): 1}`) so we need to ensure that
// we added the qubits to bit data in the correct index order.
let mut index_map: Vec<PyObject> = (0..block_index_map_dict.len()).map(|_| py.None()).collect();
for bit_tuple in block_index_map_dict.items() {
let (bit, index): (PyObject, usize) = bit_tuple.extract()?;
index_map[index] = bit;
}
let mut bit_map: BitData<u32> = BitData::new(py, "qargs".to_string());
for bit in index_map {
bit_map.add(py, bit.bind(py), true)?;
}
let identity = aview2(&ONE_QUBIT_IDENTITY);
let input_matrix = op_list[0].0.as_array();
let mut matrix: Array2<Complex64> = match op_list[0].1.as_slice() {
let first_node = &op_list[0];
let input_matrix = get_matrix_from_inst(py, &first_node.instruction)?;
let mut matrix: Array2<Complex64> = match bit_map
.map_bits(first_node.instruction.qubits.bind(py).iter())?
.collect::<Vec<_>>()
.as_slice()
{
[0] => kron(&identity, &input_matrix),
[1] => kron(&input_matrix, &identity),
[0, 1] => input_matrix.to_owned(),
[1, 0] => change_basis(input_matrix),
[0, 1] => input_matrix,
[1, 0] => change_basis(input_matrix.view()),
[] => Array2::eye(4),
_ => unreachable!(),
};
for (op_matrix, q_list) in op_list.into_iter().skip(1) {
let op_matrix = op_matrix.as_array();
for node in op_list.into_iter().skip(1) {
let op_matrix = get_matrix_from_inst(py, &node.instruction)?;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Line 91 would convey more information as [..] => unreachable!()

let q_list = bit_map
.map_bits(node.instruction.qubits.bind(py).iter())?
.map(|x| x as u8)
.collect::<SmallVec<[u8; 2]>>();

let result = match q_list.as_slice() {
[0] => Some(kron(&identity, &op_matrix)),
[1] => Some(kron(&op_matrix, &identity)),
[1, 0] => Some(change_basis(op_matrix)),
[1, 0] => Some(change_basis(op_matrix.view())),
[] => Some(Array2::eye(4)),
_ => None,
};
Expand All @@ -71,8 +125,42 @@ pub fn change_basis(matrix: ArrayView2<Complex64>) -> Array2<Complex64> {
trans_matrix
}

#[pyfunction]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Line 50 could be [..] => None

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These two suggestions, doing [..] => None, change lines not yet touched by this PR. Whether to do it here depends on whether in general you favor cleaning these things up in a PR or doing them in a separate PR. An argument against is that this is not strictly part of the point of the PR. An argument for is that a PR to do these things would not be a high priority, so it might never be done. So I favor making the change. OTOH, the arguments for leaving these as is are not unreasonable.

pub fn collect_2q_blocks_filter(node: &Bound<PyAny>) -> Option<bool> {
match node.downcast::<DAGOpNode>() {
Ok(bound_node) => {
let node = bound_node.borrow();
match &node.instruction.operation {
OperationType::Standard(gate) => Some(
gate.num_qubits() <= 2
&& node
.instruction
.extra_attrs
.as_ref()
.and_then(|attrs| attrs.condition.as_ref())
.is_none()
&& !node.is_parameterized(),
),
OperationType::Gate(gate) => Some(
gate.num_qubits() <= 2
&& node
.instruction
.extra_attrs
.as_ref()
.and_then(|attrs| attrs.condition.as_ref())
.is_none()
&& !node.is_parameterized(),
),
_ => Some(false),
}
}
Err(_) => None,
}
}

#[pymodule]
pub fn convert_2q_block_matrix(m: &Bound<PyModule>) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(blocks_to_matrix))?;
m.add_wrapped(wrap_pyfunction!(collect_2q_blocks_filter))?;
Ok(())
}
8 changes: 6 additions & 2 deletions crates/circuit/src/bit_data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ impl PartialEq for BitAsKey {
impl Eq for BitAsKey {}

#[derive(Clone, Debug)]
pub(crate) struct BitData<T> {
pub struct BitData<T> {
/// The public field name (i.e. `qubits` or `clbits`).
description: String,
/// Registered Python bits.
Expand All @@ -81,7 +81,7 @@ pub(crate) struct BitData<T> {
cached: Py<PyList>,
}

pub(crate) struct BitNotFoundError<'py>(pub(crate) Bound<'py, PyAny>);
pub struct BitNotFoundError<'py>(pub(crate) Bound<'py, PyAny>);

impl<'py> From<BitNotFoundError<'py>> for PyErr {
fn from(error: BitNotFoundError) -> Self {
Expand Down Expand Up @@ -111,6 +111,10 @@ where
self.bits.len()
}

pub fn is_empty(&self) -> bool {
self.bits.is_empty()
}

/// Gets a reference to the underlying vector of Python bits.
#[inline]
pub fn bits(&self) -> &Vec<PyObject> {
Expand Down
1 change: 1 addition & 0 deletions crates/circuit/src/imports.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ pub static SINGLETON_CONTROLLED_GATE: ImportOnceCell =
pub static CONTROLLED_GATE: ImportOnceCell =
ImportOnceCell::new("qiskit.circuit", "ControlledGate");
pub static DEEPCOPY: ImportOnceCell = ImportOnceCell::new("copy", "deepcopy");
pub static QI_OPERATOR: ImportOnceCell = ImportOnceCell::new("qiskit.quantum_info", "Operator");
pub static WARNINGS_WARN: ImportOnceCell = ImportOnceCell::new("warnings", "warn");

/// A mapping from the enum variant in crate::operations::StandardGate to the python
Expand Down
2 changes: 1 addition & 1 deletion crates/circuit/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
// copyright notice, and modified files need to carry a notice indicating
// that they have been altered from the originals.

pub mod bit_data;
pub mod circuit_data;
pub mod circuit_instruction;
pub mod dag_node;
Expand All @@ -20,7 +21,6 @@ pub mod parameter_table;
pub mod slice;
pub mod util;

mod bit_data;
mod interner;

use pyo3::prelude::*;
Expand Down
24 changes: 5 additions & 19 deletions qiskit/dagcircuit/dagcircuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
from qiskit.circuit.bit import Bit
from qiskit.pulse import Schedule
from qiskit._accelerate.euler_one_qubit_decomposer import collect_1q_runs_filter
from qiskit._accelerate.convert_2q_block_matrix import collect_2q_blocks_filter

BitLocations = namedtuple("BitLocations", ("index", "registers"))
# The allowable arguments to :meth:`DAGCircuit.copy_empty_like`'s ``vars_mode``.
Expand Down Expand Up @@ -1348,9 +1349,9 @@ def replace_block_with_op(
for nd in node_block:
block_qargs |= set(nd.qargs)
block_cargs |= set(nd.cargs)
if (condition := getattr(nd.op, "condition", None)) is not None:
if (condition := getattr(nd, "condition", None)) is not None:
block_cargs.update(condition_resources(condition).clbits)
elif isinstance(nd.op, SwitchCaseOp):
elif nd.name in CONTROL_FLOW_OP_NAMES and isinstance(nd.op, SwitchCaseOp):
if isinstance(nd.op.target, Clbit):
block_cargs.add(nd.op.target)
elif isinstance(nd.op.target, ClassicalRegister):
Expand Down Expand Up @@ -2158,28 +2159,13 @@ def collect_1q_runs(self) -> list[list[DAGOpNode]]:
def collect_2q_runs(self):
"""Return a set of non-conditional runs of 2q "op" nodes."""

to_qid = {}
for i, qubit in enumerate(self.qubits):
to_qid[qubit] = i

def filter_fn(node):
if isinstance(node, DAGOpNode):
return (
isinstance(node.op, Gate)
and len(node.qargs) <= 2
and not getattr(node.op, "condition", None)
and not node.op.is_parameterized()
)
else:
return None

def color_fn(edge):
if isinstance(edge, Qubit):
return to_qid[edge]
return self.find_bit(edge).index
else:
return None

return rx.collect_bicolor_runs(self._multi_graph, filter_fn, color_fn)
return rx.collect_bicolor_runs(self._multi_graph, collect_2q_blocks_filter, color_fn)

def nodes_on_wire(self, wire, only_ops=False):
"""
Expand Down
19 changes: 11 additions & 8 deletions qiskit/transpiler/passes/optimization/consolidate_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,11 @@
from qiskit.circuit.library.generalized_gates.unitary import UnitaryGate
from qiskit.circuit.library.standard_gates import CXGate
from qiskit.transpiler.basepasses import TransformationPass
from qiskit.circuit.controlflow import ControlFlowOp
from qiskit.transpiler.passmanager import PassManager
from qiskit.transpiler.passes.synthesis import unitary_synthesis
from qiskit.transpiler.passes.utils import _block_to_matrix
from qiskit.circuit.controlflow import CONTROL_FLOW_OP_NAMES
from qiskit._accelerate.convert_2q_block_matrix import blocks_to_matrix

from .collect_1q_runs import Collect1qRuns
from .collect_2q_blocks import Collect2qBlocks

Expand Down Expand Up @@ -105,14 +106,14 @@ def run(self, dag):
block_cargs = set()
for nd in block:
block_qargs |= set(nd.qargs)
if isinstance(nd, DAGOpNode) and getattr(nd.op, "condition", None):
block_cargs |= set(getattr(nd.op, "condition", None)[0])
if isinstance(nd, DAGOpNode) and getattr(nd, "condition", None):
block_cargs |= set(getattr(nd, "condition", None)[0])
all_block_gates.add(nd)
block_index_map = self._block_qargs_to_indices(dag, block_qargs)
for nd in block:
if nd.op.name == basis_gate_name:
if nd.name == basis_gate_name:
basis_count += 1
if self._check_not_in_basis(dag, nd.op.name, nd.qargs):
if self._check_not_in_basis(dag, nd.name, nd.qargs):
outside_basis = True
if len(block_qargs) > 2:
q = QuantumRegister(len(block_qargs))
Expand All @@ -124,7 +125,7 @@ def run(self, dag):
qc.append(nd.op, [q[block_index_map[i]] for i in nd.qargs])
unitary = UnitaryGate(Operator(qc), check_input=False)
else:
matrix = _block_to_matrix(block, block_index_map)
matrix = blocks_to_matrix(block, block_index_map)
unitary = UnitaryGate(matrix, check_input=False)

max_2q_depth = 20 # If depth > 20, there will be 1q gates to consolidate.
Expand Down Expand Up @@ -192,7 +193,9 @@ def _handle_control_flow_ops(self, dag):
pass_manager.append(Collect2qBlocks())

pass_manager.append(self)
for node in dag.op_nodes(ControlFlowOp):
for node in dag.op_nodes():
if node.name not in CONTROL_FLOW_OP_NAMES:
continue
node.op = node.op.replace_blocks(pass_manager.run(block) for block in node.op.blocks)
return dag

Expand Down
1 change: 0 additions & 1 deletion qiskit/transpiler/passes/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,3 @@

# Utility functions
from . import control_flow
from .block_to_matrix import _block_to_matrix
47 changes: 0 additions & 47 deletions qiskit/transpiler/passes/utils/block_to_matrix.py

This file was deleted.

Loading