Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

QasmModule Circuit Drawer #122

Open
wants to merge 20 commits into
base: main
Choose a base branch
from
Open
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ classifiers = [
"Operating System :: Unix",
"Operating System :: MacOS",
]
dependencies = ["numpy", "openqasm3[parser]>=1.0.0,<2.0.0"]
dependencies = ["numpy", "matplotlib", "openqasm3[parser]>=1.0.0,<2.0.0"]
arulandu marked this conversation as resolved.
Show resolved Hide resolved

[project.urls]
"Source Code" = "https://github.com/qBraid/pyqasm"
Expand Down
15 changes: 15 additions & 0 deletions src/pyqasm/maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -1236,6 +1236,21 @@ def map_qasm_inv_op_to_callable(op_name: str):
raise ValidationError(f"Unsupported / undeclared QASM operation: {op_name}")


REV_CTRL_GATE_MAP = {
"cx": "x",
"cy": "y",
"cz": "z",
"crx": "rx",
"cry": "ry",
"crz": "rz",
"cp": "p",
"ch": "h",
"cu": "u",
"cswap": "swap",
"ccx": "cx",
}


# pylint: disable=inconsistent-return-statements
def qasm_variable_type_cast(openqasm_type, var_name, base_size, rhs_value):
"""Cast the variable type to the type to match, if possible.
Expand Down
4 changes: 4 additions & 0 deletions src/pyqasm/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,3 +551,7 @@ def accept(self, visitor):
Args:
visitor (QasmVisitor): The visitor to accept
"""

@abstractmethod
def draw(self):
"""Draw the module"""
5 changes: 5 additions & 0 deletions src/pyqasm/modules/qasm2.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from pyqasm.exceptions import ValidationError
from pyqasm.modules.base import QasmModule
from pyqasm.modules.qasm3 import Qasm3Module
from pyqasm.printer import draw


class Qasm2Module(QasmModule):
Expand Down Expand Up @@ -105,3 +106,7 @@ def accept(self, visitor):
final_stmt_list = visitor.finalize(unrolled_stmt_list)

self.unrolled_ast.statements = final_stmt_list

def draw(self):
"""Draw the module"""
return draw(self.to_qasm3())
5 changes: 5 additions & 0 deletions src/pyqasm/modules/qasm3.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from openqasm3.printer import dumps

from pyqasm.modules.base import QasmModule
from pyqasm.printer import draw


class Qasm3Module(QasmModule):
Expand Down Expand Up @@ -48,3 +49,7 @@ def accept(self, visitor):
final_stmt_list = visitor.finalize(unrolled_stmt_list)

self._unrolled_ast.statements = final_stmt_list

def draw(self):
"""Draw the module"""
return draw(self)
297 changes: 297 additions & 0 deletions src/pyqasm/printer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,297 @@
# Copyright (C) 2024 qBraid
#
# This file is part of pyqasm
#
# Pyqasm is free software released under the GNU General Public License v3
# or later. You can redistribute and/or modify it under the terms of the GPL v3.
# See the LICENSE file in the project root or <https://www.gnu.org/licenses/gpl-3.0.html>.
#
# THERE IS NO WARRANTY for pyqasm, as per Section 15 of the GPL v3.

"""
Module with analysis functions for QASM visitor

"""
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Optional, Union

import matplotlib as mpl
import openqasm3.ast as ast
from matplotlib import pyplot as plt

from pyqasm.expressions import Qasm3ExprEvaluator
from pyqasm.maps import (
FIVE_QUBIT_OP_MAP,
FOUR_QUBIT_OP_MAP,
ONE_QUBIT_OP_MAP,
ONE_QUBIT_ROTATION_MAP,
REV_CTRL_GATE_MAP,
THREE_QUBIT_OP_MAP,
TWO_QUBIT_OP_MAP,
)

if TYPE_CHECKING:
from pyqasm.modules.base import Qasm3Module

DEFAULT_GATE_COLOR = "#d4b6e8"
HADAMARD_GATE_COLOR = "#f0a6a6"

GATE_BOX_WIDTH, GATE_BOX_HEIGHT = 0.6, 0.6
GATE_SPACING = 0.2
LINE_SPACING = 0.6
TEXT_MARGIN = 0.6
FRAME_PADDING = 0.2


def draw(module: Qasm3Module, output="mpl"):
if output == "mpl":
return _draw_mpl(module)
else:
raise NotImplementedError(f"{output} drawing for Qasm3Module is unsupported")


def _draw_mpl(module: Qasm3Module) -> plt.Figure:
module.unroll()
module.remove_includes()
module.remove_barriers()

n_lines = module._num_qubits + module._num_clbits
statements = module._statements

# compute line numbers per qubit + max depth
line_nums = dict()
line_num = -1
max_depth = 0

for clbit_reg in module._classical_registers.keys():
size = module._classical_registers[clbit_reg]
line_num += size
for i in range(size):
line_nums[(clbit_reg, i)] = line_num
line_num -= 1
line_num += size

for qubit_reg in module._qubit_registers.keys():
size = module._qubit_registers[qubit_reg]
line_num += size
for i in range(size):
line_nums[(qubit_reg, i)] = line_num
depth = module._qubit_depths[(qubit_reg, i)]._total_ops()
max_depth = max(max_depth, depth)
line_num -= 1
line_num += size

# compute moments
depths = dict()
for k in line_nums.keys():
depths[k] = -1

moments = []
for statement in statements:
if "Declaration" in str(type(statement)):
continue
if isinstance(statement, ast.QuantumGate):
qubits = [_identifier_to_key(q) for q in statement.qubits]
depth = 1 + max([depths[q] for q in qubits])
for q in qubits:
depths[q] = depth
elif isinstance(statement, ast.QuantumMeasurementStatement):
qubit_key = _identifier_to_key(statement.measure.qubit)
target_key = _identifier_to_key(statement.target)
depth = 1 + max(depths[qubit_key], depths[target_key])
for k in [qubit_key, target_key]:
depths[k] = depth
elif isinstance(statement, ast.QuantumBarrier):
pass
elif isinstance(statement, ast.QuantumReset):
pass
else:
raise NotImplementedError(f"Unsupported statement: {statement}")

if depth >= len(moments):
moments.append([])
moments[depth].append(statement)

width = 0
for moment in moments:
width += _mpl_get_moment_width(moment)
width += TEXT_MARGIN

fig, ax = plt.subplots(
figsize=(width, n_lines * GATE_BOX_HEIGHT + LINE_SPACING * (n_lines - 1))
)
ax.set_ylim(
-GATE_BOX_HEIGHT / 2 - FRAME_PADDING / 2,
n_lines * GATE_BOX_HEIGHT
+ LINE_SPACING * (n_lines - 1)
- GATE_BOX_HEIGHT / 2
+ FRAME_PADDING / 2,
)
ax.set_xlim(-FRAME_PADDING / 2, width)
ax.axis("off")
# ax.set_aspect('equal')
# plt.tight_layout()

x = 0
for k in module._qubit_registers.keys():
for i in range(module._qubit_registers[k]):
line_num = line_nums[(k, i)]
_mpl_draw_qubit_label((k, i), line_num, ax, x)
for k in module._classical_registers.keys():
for i in range(module._classical_registers[k]):
line_num = line_nums[(k, i)]
_mpl_draw_clbit_label((k, i), line_num, ax, x)
x += TEXT_MARGIN
x0 = x
for moment in moments:
dx = _mpl_get_moment_width(moment)
_mpl_draw_lines(dx, line_nums, ax, x)
x += dx
x = x0
for moment in moments:
dx = _mpl_get_moment_width(moment)
for statement in moment:
_mpl_draw_statement(statement, line_nums, ax, x)
x += dx

return fig


def _identifier_to_key(identifier: ast.Identifier | ast.IndexedIdentifier) -> tuple[str, int]:
if isinstance(identifier, ast.Identifier):
return identifier.name, -1
else:
return (
identifier.name.name,
Qasm3ExprEvaluator.evaluate_expression(identifier.indices[0][0])[0],
)


def _mpl_line_to_y(line_num: int) -> float:
return line_num * (GATE_BOX_HEIGHT + LINE_SPACING)


def _mpl_draw_qubit_label(qubit: tuple[str, int], line_num: int, ax: plt.Axes, x: float):
ax.text(x, _mpl_line_to_y(line_num), f"{qubit[0]}[{qubit[1]}]", ha="right", va="center")


def _mpl_draw_clbit_label(clbit: tuple[str, int], line_num: int, ax: plt.Axes, x: float):
ax.text(x, _mpl_line_to_y(line_num), f"{clbit[0]}[{clbit[1]}]", ha="right", va="center")


def _mpl_draw_lines(width, line_nums: dict[tuple[str, int], int], ax: plt.Axes, x: float):
for k in line_nums.keys():
y = _mpl_line_to_y(line_nums[k])
ax.hlines(
xmin=x - width / 2, xmax=x + width / 2, y=y, color="black", linestyle="-", zorder=-10
)


def _mpl_get_moment_width(moment: list[ast.QuantumStatement]) -> float:
return max([_mpl_get_statement_width(s) for s in moment])


def _mpl_get_statement_width(statement: ast.QuantumStatement) -> float:
return GATE_BOX_WIDTH + GATE_SPACING


def _mpl_draw_statement(
statement: ast.QuantumStatement, line_nums: dict[tuple[str, int], int], ax: plt.Axes, x: float
):
if isinstance(statement, ast.QuantumGate):
args = [Qasm3ExprEvaluator.evaluate_expression(arg)[0] for arg in statement.arguments]
lines = [line_nums[_identifier_to_key(q)] for q in statement.qubits]
_mpl_draw_gate(statement, args, lines, ax, x)
elif isinstance(statement, ast.QuantumMeasurementStatement):
qubit_key = _identifier_to_key(statement.measure.qubit)
target_key = _identifier_to_key(statement.target)
_mpl_draw_measurement(line_nums[qubit_key], line_nums[target_key], ax, x)
else:
raise NotImplementedError(f"Unsupported statement: {statement}")


def _mpl_draw_gate(
gate: ast.QuantumGate, args: list[Any], lines: list[int], ax: plt.Axes, x: float
):
name = gate.name.name
if name in REV_CTRL_GATE_MAP:
i = 0
while name in REV_CTRL_GATE_MAP:
name = REV_CTRL_GATE_MAP[name]
_draw_mpl_control(lines[i], lines[-1], ax, x)
i += 1
lines = lines[i:]
gate.name.name = name

if name in ONE_QUBIT_OP_MAP or name in ONE_QUBIT_ROTATION_MAP:
_draw_mpl_one_qubit_gate(gate, args, lines[0], ax, x)
elif name in TWO_QUBIT_OP_MAP:
if name == "swap":
_draw_mpl_swap(lines[0], lines[1], ax, x)
else:
raise NotImplementedError(f"Unsupported gate: {name}")
else:
raise NotImplementedError(f"Unsupported gate: {name}")


# TODO: switch to moment based system. go progressively, calculating required width for each moment, center the rest. this makes position calculations not to bad. if we overflow, start a new figure.


def _draw_mpl_one_qubit_gate(
gate: ast.QuantumGate, args: list[Any], line: int, ax: plt.Axes, x: float
):
color = DEFAULT_GATE_COLOR
if gate.name.name == "h":
color = HADAMARD_GATE_COLOR
text = gate.name.name.upper()
if len(args) > 0:
text += f"\n({', '.join([f'{a:.3f}' if isinstance(a, float) else str(a) for a in args])})"

y = _mpl_line_to_y(line)
rect = plt.Rectangle(
(x - GATE_BOX_WIDTH / 2, y - GATE_BOX_HEIGHT / 2),
GATE_BOX_WIDTH,
GATE_BOX_HEIGHT,
facecolor=color,
edgecolor="none",
)
ax.add_patch(rect)
ax.text(x, y, text, ha="center", va="center")


def _draw_mpl_control(ctrl_line: int, target_line: int, ax: plt.Axes, x: float):
y1 = _mpl_line_to_y(ctrl_line)
y2 = _mpl_line_to_y(target_line)
ax.vlines(x=x, ymin=min(y1, y2), ymax=max(y1, y2), color="black", linestyle="-", zorder=-1)
ax.plot(x, y1, "ko", markersize=8, markerfacecolor="black")


def _draw_mpl_swap(line1: int, line2: int, ax: plt.Axes, x: float):
y1 = _mpl_line_to_y(line1)
y2 = _mpl_line_to_y(line2)
ax.vlines(x=x, ymin=min(y1, y2), ymax=max(y1, y2), color="black", linestyle="-")
ax.plot(x, y1, "x", markersize=8, color="black")
ax.plot(x, y2, "x", markersize=8, color="black")


def _mpl_draw_measurement(qbit_line: int, cbit_line: int, ax: plt.Axes, x: float):
y1 = _mpl_line_to_y(qbit_line)
y2 = _mpl_line_to_y(cbit_line)

rect = plt.Rectangle(
(x - GATE_BOX_WIDTH / 2, y1 - GATE_BOX_HEIGHT / 2),
GATE_BOX_WIDTH,
GATE_BOX_HEIGHT,
facecolor="gray",
edgecolor="none",
)
ax.add_patch(rect)
ax.text(x, y1, "M", ha="center", va="center")
ax.vlines(
x=x - 0.025, ymin=min(y1, y2), ymax=max(y1, y2), color="gray", linestyle="-", zorder=-1
)
ax.vlines(
x=x + 0.025, ymin=min(y1, y2), ymax=max(y1, y2), color="gray", linestyle="-", zorder=-1
)
ax.plot(x, y2 + 0.1, "v", markersize=16, color="gray")
Binary file added test.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading