Skip to content

Commit

Permalink
Reimplementing _repr_pretty_ with tree/plot printer. (#870)
Browse files Browse the repository at this point in the history
* re=implementing  for lattice

* Add truncation for large children list.

* simply if statement.

* adding test.

* fixing test.

* actually fixing test.
  • Loading branch information
weinbe58 authored Jan 17, 2024
1 parent dd356ab commit a455915
Show file tree
Hide file tree
Showing 4 changed files with 161 additions and 27 deletions.
36 changes: 21 additions & 15 deletions src/bloqade/ir/location/bravais.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,9 @@ def cell_vectors(self) -> List[List[float]]:
def cell_atoms(self) -> List[List[float]]:
return [[0, 0]]

def _repr_pretty_(self, p, _):
p.text(str(self))


@dataclass
class Square(BoundedBravais):
Expand Down Expand Up @@ -249,15 +252,15 @@ def __init__(
def shape(self) -> Tuple[int, ...]:
return (self.L1, self.L2)

def __repr__(self):
return super().__repr__()

def cell_vectors(self) -> List[List[float]]:
return [[1, 0], [0, 1]]

def cell_atoms(self) -> List[List[float]]:
return [[0, 0]]

def _repr_pretty_(self, p, _):
p.text(str(self))


@dataclass(init=False)
class Rectangular(BoundedBravais):
Expand Down Expand Up @@ -378,6 +381,9 @@ def cell_vectors(self) -> List[List[float]]:
def cell_atoms(self) -> List[List[float]]:
return [[0, 0]]

def _repr_pretty_(self, p, _):
p.text(str(self))


@dataclass
class Honeycomb(BoundedBravais):
Expand Down Expand Up @@ -428,15 +434,15 @@ def __init__(
def shape(self) -> Tuple[int, ...]:
return (self.L1, self.L2)

def __repr__(self):
return super().__repr__()

def cell_vectors(self) -> List[List[float]]:
return [[1.0, 0.0], [1 / 2, np.sqrt(3) / 2]]

def cell_atoms(self) -> List[List[float]]:
return [[0.0, 0.0], [1 / 2, 1 / (2 * np.sqrt(3))]]

def _repr_pretty_(self, p, _):
p.text(str(self))


@dataclass
class Triangular(BoundedBravais):
Expand Down Expand Up @@ -484,15 +490,15 @@ def __init__(
def shape(self) -> Tuple[int, ...]:
return (self.L1, self.L2)

def __repr__(self):
return super().__repr__()

def cell_vectors(self) -> List[List[float]]:
return [[1.0, 0.0], [1 / 2, np.sqrt(3) / 2]]

def cell_atoms(self) -> List[List[float]]:
return [[0.0, 0.0]]

def _repr_pretty_(self, p, _):
p.text(str(self))


@dataclass
class Lieb(BoundedBravais):
Expand Down Expand Up @@ -535,9 +541,6 @@ def __init__(
self.L2 = L2
self.lattice_spacing = cast(lattice_spacing)

def __repr__(self):
return super().__repr__()

def cell_vectors(self) -> List[List[float]]:
return [[1.0, 0.0], [0.0, 1.0]]

Expand All @@ -548,6 +551,9 @@ def cell_atoms(self) -> List[List[float]]:
def shape(self) -> Tuple[int, ...]:
return (self.L1, self.L2)

def _repr_pretty_(self, p, _):
p.text(str(self))


@dataclass
class Kagome(BoundedBravais):
Expand Down Expand Up @@ -596,11 +602,11 @@ def __init__(
def shape(self) -> Tuple[int, ...]:
return (self.L1, self.L2)

def __repr__(self):
return super().__repr__()

def cell_vectors(self) -> List[List[float]]:
return [[1.0, 0.0], [1 / 2, np.sqrt(3) / 2]]

def cell_atoms(self) -> List[List[float]]:
return [[0.0, 0.0], [1 / 2, 0], [1 / 4, np.sqrt(3) / 4]]

def _repr_pretty_(self, p, _):
p.text(str(self))
9 changes: 6 additions & 3 deletions src/bloqade/ir/location/location.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,6 @@ def is_literal(x):
ph.print(self)
return ph.get_value()

def _repr_pretty_(self, p, cycle):
Printer(p).print(self, cycle)

def print_node(self) -> str:
return "AtomArrangement"

Expand Down Expand Up @@ -518,6 +515,9 @@ def n_dims(self):
def __str__(self):
return "ParallelRegister:\n" + self.atom_arrangement.__str__()

def _repr_pretty_(self, p, _):
p.text(str(self))

def _compile_to_list(
self, __capabilities: Optional[QuEraCapabilities] = None, **assignments
):
Expand Down Expand Up @@ -626,6 +626,9 @@ def __init__(

super().__init__()

def _repr_pretty_(self, p, _):
p.text(str(self))

@property
def n_atoms(self):
return self.__n_atoms
Expand Down
31 changes: 24 additions & 7 deletions src/bloqade/ir/tree_print.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,14 +195,33 @@ def print(self, node, cycle=None):
else:
children = list(children)

while not len(children) == 0:
list_depth = 0
trunc_list_print = False
printed_list_trunc = False

trunc_list_depth = len(children) > 2 * MAX_TREE_DEPTH

while children:
child_prefix = self.state.prefix
if this_print_annotation:
annotation, child = children.pop(0)
else:
child = children.pop(0)
annotation = None

list_depth += 1
trunc_list_print = trunc_list_depth and (
list_depth > MAX_TREE_DEPTH and len(children) >= MAX_TREE_DEPTH
)

if trunc_list_print:
if not printed_list_trunc:
self.p.text(self.charset.trunc)
self.p.text("\n")
printed_list_trunc = True

continue

self.p.text(self.state.prefix)

if len(children) == 0:
Expand All @@ -211,12 +230,10 @@ def print(self, node, cycle=None):
len(self.charset.skip) + len(self.charset.dash) + 1
)

if self.state.depth > 0 and self.state.last:
is_last_leaf_child = True
elif self.state.depth == 0:
is_last_leaf_child = True
else:
is_last_leaf_child = False
is_last_leaf_child = (
self.state.depth > 0 and self.state.last or self.state.depth == 0
)

else:
self.p.text(self.charset.mid)
child_prefix += self.charset.skip + " " * (len(self.charset.dash) + 1)
Expand Down
112 changes: 110 additions & 2 deletions tests/test_lattice_pprint.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from bloqade import start, cast
from bloqade.ir.location import Square, Rectangular
from bloqade.ir.location import Square, Rectangular, Chain
import random
import numpy as np
import os
import bloqade.ir.tree_print as trp


trp.color_enabled = False

trp.MAX_TREE_DEPTH = 10

PROJECT_RELATIVE_PPRINT_TESTS_OUTPUT_PATH = os.path.join(
os.getcwd(), "tests/data/expected_pprint_output"
Expand Down Expand Up @@ -58,6 +58,114 @@ def test_list_of_locations_pprint():
)


def test_list_trun():
geo = Chain(20).add_position((-1, "a"))

assert str(geo) == (
"AtomArrangement\n"
"├─ Location: filled\n"
"│ ├─ x\n"
"│ │ ⇒ Literal: 0\n"
"│ └─ y\n"
"│ ⇒ Literal: 0\n"
"├─ Location: filled\n"
"│ ├─ x\n"
"│ │ ⇒ Literal: 1\n"
"│ └─ y\n"
"│ ⇒ Literal: 0\n"
"├─ Location: filled\n"
"│ ├─ x\n"
"│ │ ⇒ Literal: 2\n"
"│ └─ y\n"
"│ ⇒ Literal: 0\n"
"├─ Location: filled\n"
"│ ├─ x\n"
"│ │ ⇒ Literal: 3\n"
"│ └─ y\n"
"│ ⇒ Literal: 0\n"
"├─ Location: filled\n"
"│ ├─ x\n"
"│ │ ⇒ Literal: 4\n"
"│ └─ y\n"
"│ ⇒ Literal: 0\n"
"├─ Location: filled\n"
"│ ├─ x\n"
"│ │ ⇒ Literal: 5\n"
"│ └─ y\n"
"│ ⇒ Literal: 0\n"
"├─ Location: filled\n"
"│ ├─ x\n"
"│ │ ⇒ Literal: 6\n"
"│ └─ y\n"
"│ ⇒ Literal: 0\n"
"├─ Location: filled\n"
"│ ├─ x\n"
"│ │ ⇒ Literal: 7\n"
"│ └─ y\n"
"│ ⇒ Literal: 0\n"
"├─ Location: filled\n"
"│ ├─ x\n"
"│ │ ⇒ Literal: 8\n"
"│ └─ y\n"
"│ ⇒ Literal: 0\n"
"├─ Location: filled\n"
"│ ├─ x\n"
"│ │ ⇒ Literal: 9\n"
"│ └─ y\n"
"│ ⇒ Literal: 0\n"
"⋮\n"
"├─ Location: filled\n"
"│ ├─ x\n"
"│ │ ⇒ Literal: 11\n"
"│ └─ y\n"
"│ ⇒ Literal: 0\n"
"├─ Location: filled\n"
"│ ├─ x\n"
"│ │ ⇒ Literal: 12\n"
"│ └─ y\n"
"│ ⇒ Literal: 0\n"
"├─ Location: filled\n"
"│ ├─ x\n"
"│ │ ⇒ Literal: 13\n"
"│ └─ y\n"
"│ ⇒ Literal: 0\n"
"├─ Location: filled\n"
"│ ├─ x\n"
"│ │ ⇒ Literal: 14\n"
"│ └─ y\n"
"│ ⇒ Literal: 0\n"
"├─ Location: filled\n"
"│ ├─ x\n"
"│ │ ⇒ Literal: 15\n"
"│ └─ y\n"
"│ ⇒ Literal: 0\n"
"├─ Location: filled\n"
"│ ├─ x\n"
"│ │ ⇒ Literal: 16\n"
"│ └─ y\n"
"│ ⇒ Literal: 0\n"
"├─ Location: filled\n"
"│ ├─ x\n"
"│ │ ⇒ Literal: 17\n"
"│ └─ y\n"
"│ ⇒ Literal: 0\n"
"├─ Location: filled\n"
"│ ├─ x\n"
"│ │ ⇒ Literal: 18\n"
"│ └─ y\n"
"│ ⇒ Literal: 0\n"
"├─ Location: filled\n"
"│ ├─ x\n"
"│ │ ⇒ Literal: 19\n"
"│ └─ y\n"
"│ ⇒ Literal: 0\n"
"└─ Location: filled\n "
"├─ x\n │ ⇒ Literal: -1"
"\n └─ y"
"\n ⇒ Variable: a"
)


def test_square_pprint():
# full
square_pprint_output_path = os.path.join(
Expand Down

0 comments on commit a455915

Please sign in to comment.