Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use Cirq Transforms for Gate Decomposition (#93) #184

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 62 additions & 21 deletions qbraid_qir/cirq/passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,68 @@

"""
import itertools
from typing import Iterable
from typing import Iterable, List, Sequence, Type, Union

import cirq
from cirq.protocols.decompose_protocol import DecomposeResult

from .exceptions import CirqConversionError
from .opsets import map_cirq_op_to_pyqir_callable


class QirTargetGateSet(cirq.TwoQubitCompilationTargetGateset):
def __init__(
self,
*,
atol: float = 1e-8,
allow_partial_czs: bool = False,
additional_gates: Sequence[
Union[Type["cirq.Gate"], "cirq.Gate", "cirq.GateFamily"]
] = (),
preserve_moment_structure: bool = True,
) -> None:
super().__init__(
cirq.IdentityGate,
cirq.HPowGate,
cirq.XPowGate,
cirq.YPowGate,
cirq.ZPowGate,
cirq.SWAP,
cirq.CNOT,
cirq.CZ,
cirq.TOFFOLI,
cirq.ResetChannel,
*additional_gates,
name="QirTargetGateset",
preserve_moment_structure=preserve_moment_structure,
)
self.allow_partial_czs = allow_partial_czs
self.atol = atol

@property
def postprocess_transformers(self) -> List["cirq.TRANSFORMER"]:
return []

def _decompose_single_qubit_operation(
self, op: "cirq.Operation", moment_idx: int
) -> DecomposeResult:
qubit = op.qubits[0]
mat = cirq.unitary(op)
for gate in cirq.single_qubit_matrix_to_gates(mat, self.atol):
yield gate(qubit)

def _decompose_two_qubit_operation(self, op: "cirq.Operation", _) -> "cirq.OP_TREE":
if not cirq.has_unitary(op):
return NotImplemented
return cirq.two_qubit_matrix_to_cz_operations(
op.qubits[0],
op.qubits[1],
cirq.unitary(op),
allow_partial_czs=self.allow_partial_czs,
atol=self.atol,
)


def _decompose_gate_op(operation: cirq.Operation) -> Iterable[cirq.OP_TREE]:
"""Decomposes a single Cirq gate operation into a sequence of operations
that are directly supported by PyQIR.
Expand All @@ -36,12 +90,10 @@ def _decompose_gate_op(operation: cirq.Operation) -> Iterable[cirq.OP_TREE]:
_ = map_cirq_op_to_pyqir_callable(operation)
return [operation]
except CirqConversionError:
pass
new_ops = cirq.decompose_once(operation, flatten=True, default=[operation])
if len(new_ops) == 1 and new_ops[0] == operation:
raise CirqConversionError("Couldn't convert circuit to QIR gate set.")
return list(itertools.chain.from_iterable(map(_decompose_gate_op, new_ops)))

new_ops = cirq.decompose_once(operation, flatten=True, default=[operation])
if len(new_ops) == 1 and new_ops[0] == operation:
raise CirqConversionError("Couldn't convert circuit to QIR gate set.")
return list(itertools.chain.from_iterable(map(_decompose_gate_op, new_ops)))

def _decompose_unsupported_gates(circuit: cirq.Circuit) -> cirq.Circuit:
"""
Expand All @@ -53,21 +105,10 @@ def _decompose_unsupported_gates(circuit: cirq.Circuit) -> cirq.Circuit:
Returns:
cirq.Circuit: A new circuit with unsupported gates decomposed.
"""
new_circuit = cirq.Circuit()
for moment in circuit:
new_ops = []
for operation in moment:
if isinstance(operation, cirq.GateOperation):
decomposed_ops = list(_decompose_gate_op(operation))
new_ops.extend(decomposed_ops)
elif isinstance(operation, cirq.ClassicallyControlledOperation):
new_ops.append(operation)
else:
new_ops.append(operation)

new_circuit.append(new_ops)
return new_circuit

circuit = cirq.optimize_for_target_gateset(circuit, gateset=QirTargetGateSet(), ignore_failures=True, max_num_passes=1)

return circuit

def preprocess_circuit(circuit: cirq.Circuit) -> cirq.Circuit:
"""
Expand Down
17 changes: 13 additions & 4 deletions qbraid_qir/cirq/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import logging
from abc import ABCMeta, abstractmethod

import numpy as np
import cirq
import pyqir
import pyqir._native
Expand Down Expand Up @@ -108,6 +109,13 @@ def handle_measurement(pyqir_func):
for qubit, result in zip(qubits, results):
self._measured_qubits[pyqir.qubit_id(qubit)] = True
pyqir_func(self._builder, qubit, result)

def get_rot_gate_angle(operation: cirq.Operation):
if isinstance(operation.gate, (cirq.ops.XPowGate, cirq.ops.YPowGate, cirq.ops.ZPowGate)):
angle = operation.gate.exponent * np.pi
else:
angle = operation.gate._rads
return angle

# dealing with conditional gates
if isinstance(operation, cirq.ClassicallyControlledOperation):
Expand All @@ -121,9 +129,10 @@ def handle_measurement(pyqir_func):

# pylint: disable=unnecessary-lambda-assignment
if op_str in ["Rx", "Ry", "Rz"]:
angle = get_rot_gate_angle(operation._sub_operation)
pyqir_func = lambda: temp_pyqir_func(
self._builder,
operation._sub_operation.gate._rads, # type: ignore[union-attr]
angle, # type: ignore[union-attr]
*qubits,
)
else:
Expand All @@ -144,11 +153,11 @@ def _branch(conds, pyqir_func):
_branch(conditions, pyqir_func)
else:
pyqir_func, op_str = map_cirq_op_to_pyqir_callable(operation)

if op_str.startswith("measure"):
handle_measurement(pyqir_func)
elif op_str in ["Rx", "Ry", "Rz"]:
pyqir_func(self._builder, operation.gate._rads, *qubits) # type: ignore[union-attr]
elif op_str in ["Rx", "Ry", "Rz"]:
angle = get_rot_gate_angle(operation)
pyqir_func(self._builder, angle, *qubits) # type: ignore[union-attr]
else:
pyqir_func(self._builder, *qubits)

Expand Down
63 changes: 59 additions & 4 deletions tests/cirq_qir/test_cirq_preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,61 @@

# pylint: disable=redefined-outer-name

def _match_global_phase(a: np.ndarray, b: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
"""
Matches the global phase of two numpy arrays.

This function aligns the global phases of two matrices by applying a phase shift based
on the position of the largest entry in one matrix. It computes and adjusts for the
phase difference at this position, resulting in two phase-aligned matrices. The output,
(a', b'), indicates that a' == b' is equivalent to a == b * exp(i * t) for some phase t.

Args:
a (np.ndarray): The first input matrix.
b (np.ndarray): The second input matrix.

Returns:
tuple[np.ndarray, np.ndarray]: A tuple of the two matrices `(a', b')`, adjusted for
global phase. If shapes of `a` and `b` do not match or
either is empty, returns copies of the original matrices.
"""
if a.shape != b.shape or a.size == 0:
return np.copy(a), np.copy(b)

k = max(np.ndindex(*a.shape), key=lambda t: abs(b[t]))

def dephase(v):
r = np.real(v)
i = np.imag(v)

if i == 0:
return -1 if r < 0 else 1
if r == 0:
return 1j if i < 0 else -1j

return np.exp(-1j * np.arctan2(i, r))

return a * dephase(a[k]), b * dephase(b[k])


def _assert_allclose_up_to_global_phase(a: np.ndarray, b: np.ndarray, atol: float, **kwargs) -> None:
"""
Checks if two numpy arrays are equal up to a global phase, within
a specified tolerance, i.e. if a ~= b * exp(i t) for some t.

Args:
a (np.ndarray): The first input array.
b (np.ndarray): The second input array.
atol (float): The absolute error tolerance.
**kwargs: Additional keyword arguments to pass to `np.testing.assert_allclose`.

Raises:
AssertionError: The matrices aren't nearly equal up to global phase.

"""
a, b = _match_global_phase(a, b)
np.testing.assert_allclose(actual=a, desired=b, atol=atol, **kwargs)


@pytest.fixture
def gridqubit_circuit():
Expand All @@ -40,17 +95,17 @@ def test_convert_gridqubits_to_linequbits(gridqubit_circuit):
linequbit_circuit = preprocess_circuit(gridqubit_circuit)
for qubit in linequbit_circuit.all_qubits():
assert isinstance(qubit, cirq.LineQubit), "Qubit is not a LineQubit"
assert np.allclose(
linequbit_circuit.unitary(), gridqubit_circuit.unitary()
_assert_allclose_up_to_global_phase(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

qbraid is already a test dependency so you can just use qbraid.interface.assert_allclose_up_to_global_phase instead of copying over those same funcs

linequbit_circuit.unitary(), gridqubit_circuit.unitary(), atol=1e-6
), "Circuits are not equal"


def test_convert_namedqubits_to_linequbits(namedqubit_circuit):
linequbit_circuit = preprocess_circuit(namedqubit_circuit)
for qubit in linequbit_circuit.all_qubits():
assert isinstance(qubit, cirq.LineQubit), "Qubit is not a LineQubit"
assert np.allclose(
linequbit_circuit.unitary(), namedqubit_circuit.unitary()
_assert_allclose_up_to_global_phase(
linequbit_circuit.unitary(), namedqubit_circuit.unitary(), atol=1e-6
), "Circuits are not equal"


Expand Down
Loading