Skip to content

Commit

Permalink
dialects: (stim) Add parser for .stim formatted strings (#3122)
Browse files Browse the repository at this point in the history
Stim circuits are fomatted as detailed on this page:

https://github.com/quantumlib/Stim/blob/main/doc/file_format_stim_circuit.md

This pr adds a parser that can parse this format into stim dialect
operations, and adds tests for the parsing of existing operations

---------

Co-authored-by: Emilien Bauer <[email protected]>
Co-authored-by: Sasha Lopoukhine <[email protected]>
  • Loading branch information
3 people authored Nov 3, 2024
1 parent bd9aaf7 commit c017c10
Show file tree
Hide file tree
Showing 4 changed files with 361 additions and 12 deletions.
62 changes: 61 additions & 1 deletion tests/dialects/stim/test_stim_printer_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,12 @@
import pytest

from xdsl.dialects import stim
from xdsl.dialects.stim.ops import QubitAttr, QubitCoordsOp, QubitMappingAttr
from xdsl.dialects.stim.ops import (
QubitAttr,
QubitCoordsOp,
QubitMappingAttr,
)
from xdsl.dialects.stim.stim_parser import StimParseError, StimParser
from xdsl.dialects.stim.stim_printer_parser import StimPrintable, StimPrinter
from xdsl.dialects.test import TestOp
from xdsl.ir import Block, Region
Expand All @@ -20,6 +25,19 @@ def check_stim_print(program: StimPrintable, expected_stim: str):
assert expected_stim == res_io.getvalue()


def check_stim_roundtrip(program: str):
"""Check that the given program roundtrips exactly (including whitespaces)."""
stim_parser = StimParser(program)
stim_circuit = stim_parser.parse_circuit()

check_stim_print(stim_circuit, program)


################################################################################
# Test operations stim_print() #
################################################################################


def test_empty_circuit():
empty_block = Block()
empty_region = Region(empty_block)
Expand Down Expand Up @@ -60,3 +78,45 @@ def test_print_stim_qubit_coord_op():
qubit_annotation = QubitCoordsOp(qubit_coord)
expected_stim = "QUBIT_COORDS(0, 0) 0"
check_stim_print(qubit_annotation, expected_stim)


################################################################################
# Test stim parser and printer #
################################################################################


@pytest.mark.parametrize(
"program",
[(""), ("\n"), ("#hi"), ("# hi \n" "#hi\n")],
)
def test_stim_roundtrip_empty_circuit(program: str):
stim_parser = StimParser(program)
stim_circuit = stim_parser.parse_circuit()
check_stim_print(stim_circuit, "")


@pytest.mark.parametrize(
"program",
[
("QUBIT_COORDS() 0\n"),
("QUBIT_COORDS(0, 0) 0\n"),
("QUBIT_COORDS(0, 2) 1\n"),
("QUBIT_COORDS(0, 0) 0\n" "QUBIT_COORDS(1, 2) 2\n"),
],
)
def test_stim_roundtrip_qubit_coord_op(program: str):
check_stim_roundtrip(program)


def test_no_spaces_before_target():
with pytest.raises(StimParseError, match="Targets must be separated by spacing."):
program = "QUBIT_COORDS(1, 1)1"
parser = StimParser(program)
parser.parse_circuit()


def test_no_targets():
program = "QUBIT_COORDS(1, 1)"
with pytest.raises(StimParseError, match="Expected at least one target"):
parser = StimParser(program)
parser.parse_circuit()
24 changes: 15 additions & 9 deletions xdsl/dialects/stim/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from collections.abc import Sequence
from io import StringIO

from xdsl.dialects.builtin import ArrayAttr, IntAttr
from xdsl.dialects.builtin import ArrayAttr, FloatData, IntAttr
from xdsl.dialects.stim.stim_printer_parser import StimPrintable, StimPrinter
from xdsl.ir import ParametrizedAttribute, Region, TypeAttribute
from xdsl.irdl import (
Expand Down Expand Up @@ -63,26 +63,33 @@ class QubitMappingAttr(StimPrintable, ParametrizedAttribute):

name = "stim.qubit_coord"

coords: ParameterDef[ArrayAttr[IntAttr]]
coords: ParameterDef[ArrayAttr[FloatData | IntAttr]]
qubit_name: ParameterDef[QubitAttr]

def __init__(
self, coords: list[int] | ArrayAttr[IntAttr], qubit_name: int | QubitAttr
self,
coords: list[float] | ArrayAttr[FloatData | IntAttr],
qubit_name: int | QubitAttr,
) -> None:
if not isinstance(qubit_name, QubitAttr):
qubit_name = QubitAttr(qubit_name)
if not isinstance(coords, ArrayAttr):
coords = ArrayAttr(IntAttr(c) for c in coords)
coords = ArrayAttr(
(IntAttr(int(arg))) if (type(arg) is int) else (FloatData(arg))
for arg in coords
)
super().__init__(parameters=[coords, qubit_name])

@classmethod
def parse_parameters(
cls, parser: AttrParser
) -> tuple[ArrayAttr[IntAttr], QubitAttr]:
) -> tuple[ArrayAttr[FloatData | IntAttr], QubitAttr]:
parser.parse_punctuation("<")
coords = parser.parse_comma_separated_list(
delimiter=parser.Delimiter.PAREN,
parse=lambda: IntAttr(parser.parse_integer(allow_boolean=False)),
parse=lambda: IntAttr(x)
if type(x := parser.parse_number(allow_boolean=False)) is int
else FloatData(x),
)
parser.parse_punctuation(",")
qubit = parser.parse_attribute()
Expand Down Expand Up @@ -128,9 +135,8 @@ def verify(self, verify_nested_ops: bool = True) -> None:

def print_stim(self, printer: StimPrinter):
for op in self.body.block.ops:
if not isinstance(op, StimPrintable):
raise ValueError(f"Cannot print in stim format: {op}")
op.print_stim(printer)
printer.print_op(op)
printer.print_string("\n")
printer.print_string("")

def stim(self) -> str:
Expand Down
Loading

0 comments on commit c017c10

Please sign in to comment.