Skip to content

Commit

Permalink
global phase
Browse files Browse the repository at this point in the history
  • Loading branch information
arulandu committed Jan 20, 2025
1 parent f6d4d77 commit 11eeb84
Showing 1 changed file with 6 additions and 11 deletions.
17 changes: 6 additions & 11 deletions src/pyqasm/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,9 @@ def _draw_mpl(module: Qasm3Module, idle_wires=True) -> plt.Figure:
depths = dict()
for k in line_nums.keys():
depths[k] = -1

global_phase = sum([Qasm3ExprEvaluator.evaluate_expression(s.argument)[0] for s in statements if isinstance(s, ast.QuantumPhase)])
statements = [s for s in statements if not isinstance(s, ast.QuantumPhase)]

moments = []
for statement in statements:
Expand All @@ -109,12 +112,6 @@ def _draw_mpl(module: Qasm3Module, idle_wires=True) -> plt.Figure:
depth = 1 + max(depths[qubit_key], depths[target_key])
for k in [qubit_key, target_key]:
depths[k] = depth
elif isinstance(statement, ast.QuantumPhase):
qubits = [_identifier_to_key(q) for q in statement.qubits]
if len(qubits) > 0:
depth = 1 + max([depths[q] for q in qubits])
for q in qubits:
depths[q] = depth
elif isinstance(statement, ast.QuantumBarrier):
qubits = [_identifier_to_key(q) for q in statement.qubits]
depth = 1 + max([depths[q] for q in qubits])
Expand Down Expand Up @@ -175,6 +172,7 @@ def _draw_mpl(module: Qasm3Module, idle_wires=True) -> plt.Figure:
ax = axs[sidx]
x = 0
if sidx == 0:
if global_phase != 0: _mpl_draw_global_phase(global_phase, ax, x)
for k in module._qubit_registers.keys():
for i in range(module._qubit_registers[k]):
if (k, i) in line_nums:
Expand Down Expand Up @@ -213,11 +211,11 @@ def _identifier_to_key(identifier: ast.Identifier | ast.IndexedIdentifier) -> tu
def _mpl_line_to_y(line_num: int) -> float:
return line_num * (GATE_BOX_HEIGHT + LINE_SPACING)

def _mpl_draw_global_phase(global_phase: float, ax: plt.Axes, x: float):
ax.text(x, -0.75, f"Global Phase: {global_phase:.3f}", ha="left", va="center")

def _mpl_draw_qubit_label(qubit: tuple[str, int], line_num: int, ax: plt.Axes, x: float):
ax.text(x, _mpl_line_to_y(line_num), f"{qubit[0]}[{qubit[1]}]", ha="right", va="center")


def _mpl_draw_creg_label(creg: str, size: int, line_num: int, ax: plt.Axes, x: float):
ax.text(x, _mpl_line_to_y(line_num), f"{creg[0]}", ha="right", va="center")

Expand Down Expand Up @@ -290,9 +288,6 @@ def _mpl_draw_statement(
_mpl_draw_measurement(
line_nums[qubit_key], line_nums[(target_key[0], -1)], target_key[1], ax, x
)
elif isinstance(statement, ast.QuantumPhase):
# TODO: draw gphase
pass
elif isinstance(statement, ast.QuantumBarrier):
lines = [line_nums[_identifier_to_key(q)] for q in statement.qubits]
_mpl_draw_barrier(lines, ax, x)
Expand Down

0 comments on commit 11eeb84

Please sign in to comment.