Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding some more utility functions and mock test #13

Merged
merged 6 commits into from
Dec 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 50 additions & 19 deletions qbraid_qir/cirq/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,20 @@

import logging
from abc import ABCMeta, abstractmethod
from typing import FrozenSet
from typing import FrozenSet, List

import cirq
import pyqir.rt as rt
from pyqir import BasicBlock, Builder, Constant, IntType, PointerType, entry_point
import pyqir
from pyqir import (
BasicBlock,
Builder,
Constant,
IntType,
PointerType,
const,
entry_point,
)

from qbraid_qir.cirq.elements import CirqModule

Expand Down Expand Up @@ -73,31 +82,53 @@ def finalize(self):
self._builder.ret(None)

def record_output(self, module: CirqModule):
raise NotImplementedError

def visit_qid(self, qid: cirq.Qid):
_log.debug(f"Visiting qid '{str(qid)}'")
if isinstance(qid, cirq.LineQubit):
pass
elif isinstance(qid, cirq.GridQubit):
pass
elif isinstance(qid, cirq.NamedQubit):
pass
else:
raise ValueError(f"Qid of type {type(qid)} not supported.")
if self._record_output == False:
return

i8p = PointerType(IntType(self._module.context, 8))

# qiskit inverts the ordering of the results within each register
# but keeps the overall register ordering
# here we logically loop from n-1 to 0, decrementing in order to
# invert the register output. The second parameter is an exclusive
# range so we need to go to -1 instead of 0
logical_id_base = 0
for size in module.reg_sizes:
rt.array_record_output(
self._builder,
const(IntType(self._module.context, 64), size),
Constant.null(i8p),
)
for index in range(size - 1, -1, -1):
result_ref = pyqir.result(self._module.context, logical_id_base + index)
rt.result_record_output(self._builder, result_ref, Constant.null(i8p))
logical_id_base += size

def visit_register(self, qids: List[cirq.Qid]):
_log.debug(f"Visiting qid '{str(qids)}'")
if not isinstance(qids, list):
raise TypeError("Parameter is not a list.")

if not all(isinstance(x, cirq.Qid) for x in qids):
raise TypeError("All elements in the list must be of type cirq.Qid.")
# self._qubit_labels[qid] = len(self._qubit_labels)
self._qubit_labels.update({bit: n + len(self._qubit_labels) for n, bit in enumerate(qids)})
_log.debug(
f"Added label for qubits {qids}"
)

def process_composite_operation(self, operation: cirq.Operation):
# e.g. operation.gate.sub_gate
# e.g. operation.gate.sub_gate, this functionality might exist elsewhere.
raise NotImplementedError

def visit_operation(self, operation: cirq.Operation, qids: FrozenSet[cirq.Qid]):
raise NotImplementedError
qlabels = [self._qubit_labels.get(bit) for bit in qids]
qubits = [pyqir.qubit(self._module.context, n) for n in qlabels]
results = [pyqir.result(self._module.context, n) for n in qlabels]
# call some function that depends on qubits and results

def ir(self) -> str:
return str(self._module)

def bitcode(self) -> bytes:
return self._module.bitcode()

def _map_profile_to_capabilities(self, profile: str):
raise NotImplementedError
87 changes: 87 additions & 0 deletions tests/fixtures/basic_gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,14 @@
"Z": "z",
}

_rotations = {"Rx": "rx", "Ry": "ry", "Rz": "rz"}

_two_qubit_gates = {"CX": "cnot", "CZ": "cz", "SWAP": "swap"}

_three_qubit_gates = {"CCX": "ccx"}

_measurements = {"measure": "mz"}


def _fixture_name(s: str) -> str:
return f"Fixture_{s}"
Expand All @@ -38,6 +46,14 @@ def _fixture_name(s: str) -> str:
def _map_gate_name(gate_name: str) -> str:
if gate_name in _one_qubit_gates:
return _one_qubit_gates[gate_name]
elif gate in _measurements:
return _measurements[gate]
elif gate in _rotations:
return _rotations[gate]
elif gate in _two_qubit_gates:
return _two_qubit_gates[gate]
elif gate in _three_qubit_gates:
return _three_qubit_gates[gate]

raise ValueError(f"Unknown Cirq gate {gate_name}")

Expand All @@ -58,4 +74,75 @@ def test_fixture():
name = _fixture_name(gate)
locals()[name] = _generate_one_qubit_fixture(gate)


def _generate_rotation_fixture(gate_name: str):
@pytest.fixture()
def test_fixture():
circuit = cirq.Circuit()
q = cirq.NamedQubit("q")
circuit.append(getattr(cirq, gate_name)(rads=0.5)(q))
return _map_gate_name(gate_name), circuit

return test_fixture


# Generate rotation gate fixtures
for gate in _rotations.keys():
name = _fixture_name(gate)
locals()[name] = _generate_rotation_fixture(gate)


def _generate_two_qubit_fixture(gate_name: str):
@pytest.fixture()
def test_fixture():
circuit = cirq.Circuit()
qs = cirq.LineQubit(2)
circuit.append(getattr(cirq, gate_name)(qs[0], qs[1]))
return _map_gate_name(gate_name), circuit

return test_fixture


# Generate double-qubit gate fixtures
for gate in _two_qubit_gates.keys():
name = _fixture_name(gate)
locals()[name] = _generate_two_qubit_fixture(gate)


def _generate_three_qubit_fixture(gate_name: str):
@pytest.fixture()
def test_fixture():
circuit = cirq.Circuit()
qs = cirq.LineQubit(3)
circuit.append(getattr(cirq, gate_name)(qs[0], qs[1], qs[2]))
return _map_gate_name(gate_name), circuit

return test_fixture


# Generate three-qubit gate fixtures
for gate in _three_qubit_gates.keys():
name = _fixture_name(gate)
locals()[name] = _generate_three_qubit_fixture(gate)


def _generate_measurement_fixture(gate_name: str):
@pytest.fixture()
def test_fixture():
circuit = cirq.Circuit()
q = cirq.NamedQubit("q")
circuit.append(getattr(cirq, gate_name)(q))
return _map_gate_name(gate_name), circuit

return test_fixture


for gate in _measurements.keys():
name = _fixture_name(gate)
locals()[name] = _generate_measurement_fixture(gate)

single_op_tests = [_fixture_name(s) for s in _one_qubit_gates]
rotation_tests = [_fixture_name(s) for s in _rotations.keys()]
double_op_tests = [_fixture_name(s) for s in _two_qubit_gates.keys()]
triple_op_tests = [_fixture_name(s) for s in _three_qubit_gates.keys()]
measurement_tests = [_fixture_name(s) for s in _measurements.keys()]
12 changes: 12 additions & 0 deletions tests/test_cirq_to_qir.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import cirq
import pytest

from tests.fixtures.basic_gates import single_op_tests
import tests.test_utils as test_utils
from qbraid_qir.cirq.convert import cirq_to_qir, generate_module_id
from qbraid_qir.exceptions import QirConversionError

Expand All @@ -36,6 +38,16 @@ def test_cirq_to_qir_conversion_error():
with pytest.raises(QirConversionError):
cirq_to_qir(circuit)

@pytest.mark.parametrize("circuit_name", single_op_tests)
def test_single_qubit_gates(circuit_name, request):
qir_op, circuit = request.getfixturevalue(circuit_name)
generated_qir = str(cirq_to_qir(circuit)[0]).splitlines()
func = test_utils.get_entry_point_body(generated_qir)
assert func[0] == test_utils.initialize_call_string()
assert func[1] == test_utils.single_op_call_string(qir_op, 0)
assert func[2] == test_utils.return_string()
assert len(func) == 3


def test_verify_qir_bell_fixture(pyqir_bell):
"""Test that pyqir fixture generates code equal to test_qir_bell.ll file."""
Expand Down
36 changes: 36 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from typing import List

from pyqir import is_entry_point, Module, Function, Context


def _qubit_string(qubit: int) -> str:
if qubit == 0:
return "%Qubit* null"
else:
return f"%Qubit* inttoptr (i64 {qubit} to %Qubit*)"


def initialize_call_string() -> str:
return "call void @__quantum__rt__initialize(i8* null)"


def single_op_call_string(name: str, qb: int) -> str:
return f"call void @__quantum__qis__{name}__body({_qubit_string(qb)})"


def get_entry_point(mod: Module) -> Function:
func = next(filter(is_entry_point, mod.functions))
assert func is not None, "No main function found"
return func

def get_entry_point_body(qir: List[str]) -> List[str]:
joined = "\n".join(qir)
mod = Module.from_ir(Context(), joined)
func = next(filter(is_entry_point, mod.functions))
assert func is not None, "No main function found"
lines = str(func).splitlines()[2:-1]
return list(map(lambda line: line.strip(), lines))


def return_string() -> str:
return "ret void"
Loading