Skip to content

Commit

Permalink
Support for externally linked gates (#182)
Browse files Browse the repository at this point in the history
* preliminary support for externally linked gates

* trigger error if modifier present

* add docstring

* format code

* add tests + fix linter issues

* remove print

* Remove comment

Co-authored-by: Harshit Gupta <[email protected]>

* remove emtpy line after accepting suggested change

* update pyqasm dependency

* use pyqasm.loads instead of pyqasm.load

* ignore global phases for now

* fix linter issue

---------

Co-authored-by: Tobias Schmale <[email protected]>
Co-authored-by: Harshit Gupta <[email protected]>
  • Loading branch information
3 people authored Dec 9, 2024
1 parent 0259886 commit 176f0db
Show file tree
Hide file tree
Showing 6 changed files with 146 additions and 10 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ Discord = "https://discord.gg/TPBU2sa8Et"

[project.optional-dependencies]
cirq = ["cirq-core>=1.3.0,<1.5.0"]
qasm3 = ["pyqasm==0.0.3", "numpy"]
qasm3 = ["pyqasm==0.1.0a1", "numpy"]
test = ["qbraid>=0.8.3,<0.9.0", "pytest", "pytest-cov", "autoqasm>=0.1.0"]
lint = ["black", "isort", "pylint", "qbraid-cli>=0.8.7"]
docs = ["sphinx>=7.3.7,<=8.3.0", "sphinx-autodoc-typehints>=1.24,<2.6", "sphinx-rtd-theme>=2.0,<3.1", "docutils<0.22", "sphinx-copybutton"]
Expand Down
5 changes: 3 additions & 2 deletions qbraid_qir/qasm3/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,10 @@ def qasm3_to_qir(
elif not isinstance(program, str):
raise TypeError("Input quantum program must be of type openqasm3.ast.Program or str.")

qasm3_module = pyqasm.load(program)
qasm3_module.unroll()
external_gates: list[str] = kwargs.get("external_gates", [])

qasm3_module = pyqasm.loads(program)
qasm3_module.unroll(external_gates=external_gates)
if name is None:
name = generate_module_id()
llvm_module = qir_module(Context(), name)
Expand Down
66 changes: 64 additions & 2 deletions qbraid_qir/qasm3/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,16 @@ class QasmQIRVisitor:
Args:
initialize_runtime (bool): If True, quantum runtime will be initialized. Defaults to True.
record_output (bool): If True, output of the circuit will be recorded. Defaults to True.
external_gates (list[str]): List of custom gates that should not be unrolled.
Instead, these gates are marked for external linkage, as
qir-functions with the name "__quantum__qis__<GateName>__body"
"""

def __init__(
self,
initialize_runtime: bool = True,
record_output: bool = True,
external_gates: list[str] | None = None,
):
self._llvm_module: pyqir.Module
self._builder: pyqir.Builder
Expand All @@ -57,6 +61,12 @@ def __init__(
self._initialize_runtime: bool = initialize_runtime
self._record_output: bool = record_output

if external_gates is None:
external_gates = []
self._external_gates_map: dict[str, pyqir.Function | None] = {
external_gate: None for external_gate in external_gates
}

def visit_qasm3_module(self, module: QasmQIRModule) -> None:
"""
Visit a Qasm3 module.
Expand Down Expand Up @@ -319,6 +329,55 @@ def _visit_basic_gate_operation(self, operation: qasm3_ast.QuantumGate) -> None:
else:
qir_func(self._builder, *qubit_subset)

def _visit_external_gate_operation(self, operation: qasm3_ast.QuantumGate) -> None:
"""Visit an external gate operation element.
Args:
operation (qasm3_ast.QuantumGate): The gate operation to visit.
Returns:
None
Raises:
Qasm3ConversionError: If the number of qubits is invalid.
"""
logger.debug("Visiting external gate operation '%s'", str(operation))
op_name: str = operation.name.name
op_qubits = self._get_op_bits(operation)
op_qubit_count = len(op_qubits)

if len(operation.modifiers) > 0:
raise_qasm3_error(
"Modifiers on externally linked gates are not supported in pyqir",
err_type=NotImplementedError,
)

context = self._llvm_module.context
qir_function = self._external_gates_map[op_name]
if qir_function is None:
# First time seeing this external gate -> define new function
qir_function_arguments = [pyqir.Type.double(context)] * len(operation.arguments)
qir_function_arguments += [pyqir.qubit_type(context)] * op_qubit_count

qir_function = pyqir.Function(
pyqir.FunctionType(pyqir.Type.void(context), qir_function_arguments),
pyqir.Linkage.EXTERNAL,
f"__quantum__qis__{op_name}__body",
self._llvm_module,
)
self._external_gates_map[op_name] = qir_function

op_parameters = None
if len(operation.arguments) > 0: # parametric gate
op_parameters = self._get_op_parameters(operation)

if op_parameters is not None:
self._builder.call(qir_function, [*op_parameters, *op_qubits])
else:
self._builder.call(qir_function, op_qubits)

def _visit_generic_gate_operation(self, operation: qasm3_ast.QuantumGate) -> None:
"""Visit a gate operation element.
Expand All @@ -328,8 +387,10 @@ def _visit_generic_gate_operation(self, operation: qasm3_ast.QuantumGate) -> Non
Returns:
None
"""
# TODO: maybe needs to be extended for custom gates
self._visit_basic_gate_operation(operation)
if operation.name.name in self._external_gates_map:
self._visit_external_gate_operation(operation)
else:
self._visit_basic_gate_operation(operation)

def _get_branch_params(self, condition: Any) -> tuple[str, int, bool]:
"""
Expand Down Expand Up @@ -421,6 +482,7 @@ def visit_statement(self, statement: qasm3_ast.Statement) -> None:
qasm3_ast.QuantumBarrier: self._visit_barrier,
qasm3_ast.QuantumGate: self._visit_generic_gate_operation,
qasm3_ast.BranchingStatement: self._visit_branching_statement,
qasm3_ast.QuantumPhase: lambda x: None, # No operation
}

visitor_function = visit_map.get(type(statement))
Expand Down
2 changes: 1 addition & 1 deletion tests/cirq_qir/test_cirq_to_qir.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def test_triple_qubit_gates(circuit_name, request):
check_attributes(generated_qir, 3, 3)
func = get_entry_point_body(generated_qir)
assert func[0] == initialize_call_string()
assert func[1] == generic_op_call_string(qir_op, [0, 1, 2])
assert func[1] == generic_op_call_string(qir_op, [], [0, 1, 2])
assert func[5] == return_string()
assert len(func) == 6

Expand Down
29 changes: 29 additions & 0 deletions tests/qasm3_qir/converter/test_gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
from tests.qir_utils import (
check_attributes,
check_custom_qasm_gate_op,
check_custom_qasm_gate_op_with_external_gates,
check_generic_gate_op,
check_single_qubit_gate_op,
check_single_qubit_rotation_op,
check_three_qubit_gate_op,
Expand Down Expand Up @@ -144,6 +146,20 @@ def test_qasm_u3_gates():
check_single_qubit_rotation_op(generated_qir, 1, [0], [0.5, 0.5, 0.5], "u3")


def test_qasm_u3_gates_external():
qasm3_string = """
OPENQASM 3;
include "stdgates.inc";
qubit[2] q1;
u3(0.5, 0.5, 0.5) q1[0];
"""
result = qasm3_to_qir(qasm3_string, external_gates=["u3"])
generated_qir = str(result).splitlines()
check_attributes(generated_qir, 2, 0)
check_generic_gate_op(generated_qir, 1, [0], ["5.000000e-01"] * 3, "u3")


def test_qasm_u2_gates():
qasm3_string = """
OPENQASM 3;
Expand Down Expand Up @@ -171,6 +187,19 @@ def test_custom_ops(test_name, request):
check_custom_qasm_gate_op(generated_qir, gate_type)


@pytest.mark.parametrize("test_name", custom_op_tests)
def test_custom_ops_with_external_gates(test_name, request):
qasm3_string = request.getfixturevalue(test_name)
gate_type = test_name.removeprefix("Fixture_")
result = qasm3_to_qir(qasm3_string, external_gates=["custom", "custom1"])

generated_qir = str(result).splitlines()
check_attributes(generated_qir, 2, 0)

# Check for custom gate definition
check_custom_qasm_gate_op_with_external_gates(generated_qir, gate_type)


def test_pow_gate_modifier():
qasm3_string = """
OPENQASM 3;
Expand Down
52 changes: 48 additions & 4 deletions tests/qir_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,11 @@ def reset_call_string(qb: int) -> str:
return f"call void @__quantum__qis__reset__body({_qubit_string(qb)})"


def generic_op_call_string(name: str, qbs: list[int]) -> str:
args = ", ".join(_qubit_string(qb) for qb in qbs)
return f"call void @__quantum__qis__{name}__body({args})"
def generic_op_call_string(name: str, angles: list[str], qubits: list[int]) -> str:
angles = ["double " + angle for angle in angles]
qubits = [_qubit_string(q) for q in qubits]
parameters = ", ".join(angles + qubits)
return f"call void @__quantum__qis__{name}__body({parameters})"


def return_string() -> str:
Expand Down Expand Up @@ -235,6 +237,31 @@ def check_single_qubit_gate_op(
), f"Incorrect single qubit gate count: {expected_ops} expected, {op_count} actual"


def check_generic_gate_op(
qir: list[str], expected_ops: int, qubit_list: list[int], param_list: list[str], gate_name: str
):
entry_body = get_entry_point_body(qir)
op_count = 0

for line in entry_body:
gate_call_id = (
f"qis__{gate_name}" if "dg" not in gate_name else f"qis__{gate_name.removesuffix('dg')}"
)
if line.strip().startswith("call") and gate_call_id in line:
expected_line = generic_op_call_string(gate_name, param_list, qubit_list)
assert line.strip() == expected_line, (
"Incorrect single qubit gate call in qir"
+ f"Expected {expected_line}, found {line.strip()}"
)
op_count += 1

if op_count == expected_ops:
break

if op_count != expected_ops:
assert False, f"Incorrect gate count: {expected_ops} expected, {op_count} actual"


def check_two_qubit_gate_op(
qir: list[str], expected_ops: int, qubit_lists: list[int], gate_name: str
):
Expand Down Expand Up @@ -346,7 +373,7 @@ def check_three_qubit_gate_op(
for line in entry_body:
if line.strip().startswith("call") and f"qis__{gate_name}" in line:
assert line.strip() == generic_op_call_string(
gate_name, qubit_lists[q_id]
gate_name, [], qubit_lists[q_id]
), f"Incorrect three qubit gate call in qir - {line}"
op_count += 1
q_id += 1
Expand Down Expand Up @@ -427,6 +454,23 @@ def check_custom_qasm_gate_op(qir: list[str], test_type: str):
assert False, f"Unknown test type {test_type} for custom ops"


def check_custom_qasm_gate_op_with_external_gates(qir: list[str], test_type: str):
if test_type == "simple":
check_generic_gate_op(qir, 1, [0, 1], ["1.100000e+00"], "custom")
elif test_type == "nested":
check_generic_gate_op(
qir, 1, [0, 1], ["4.800000e+00", "1.000000e-01", "3.000000e-01"], "custom"
)
elif test_type == "complex":
# Only custom1 is external, custom2 and custom3 should be unrolled
check_generic_gate_op(qir, 1, [0], [], "custom1")
check_generic_gate_op(qir, 1, [0], ["1.000000e-01"], "ry")
check_generic_gate_op(qir, 1, [0], ["2.000000e-01"], "rz")
check_generic_gate_op(qir, 1, [0, 1], [], "cnot")
else:
assert False, f"Unknown test type {test_type} for custom ops"


def check_expressions(
qir: list[str], expected_ops: int, gates: list[str], expression_values, qubits: list[int]
):
Expand Down

0 comments on commit 176f0db

Please sign in to comment.