Skip to content

Commit a455915

Browse files
authored
Reimplementing _repr_pretty_ with tree/plot printer. (#870)
* re=implementing for lattice * Add truncation for large children list. * simply if statement. * adding test. * fixing test. * actually fixing test.
1 parent dd356ab commit a455915

File tree

4 files changed

+161
-27
lines changed

4 files changed

+161
-27
lines changed

src/bloqade/ir/location/bravais.py

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,9 @@ def cell_vectors(self) -> List[List[float]]:
205205
def cell_atoms(self) -> List[List[float]]:
206206
return [[0, 0]]
207207

208+
def _repr_pretty_(self, p, _):
209+
p.text(str(self))
210+
208211

209212
@dataclass
210213
class Square(BoundedBravais):
@@ -249,15 +252,15 @@ def __init__(
249252
def shape(self) -> Tuple[int, ...]:
250253
return (self.L1, self.L2)
251254

252-
def __repr__(self):
253-
return super().__repr__()
254-
255255
def cell_vectors(self) -> List[List[float]]:
256256
return [[1, 0], [0, 1]]
257257

258258
def cell_atoms(self) -> List[List[float]]:
259259
return [[0, 0]]
260260

261+
def _repr_pretty_(self, p, _):
262+
p.text(str(self))
263+
261264

262265
@dataclass(init=False)
263266
class Rectangular(BoundedBravais):
@@ -378,6 +381,9 @@ def cell_vectors(self) -> List[List[float]]:
378381
def cell_atoms(self) -> List[List[float]]:
379382
return [[0, 0]]
380383

384+
def _repr_pretty_(self, p, _):
385+
p.text(str(self))
386+
381387

382388
@dataclass
383389
class Honeycomb(BoundedBravais):
@@ -428,15 +434,15 @@ def __init__(
428434
def shape(self) -> Tuple[int, ...]:
429435
return (self.L1, self.L2)
430436

431-
def __repr__(self):
432-
return super().__repr__()
433-
434437
def cell_vectors(self) -> List[List[float]]:
435438
return [[1.0, 0.0], [1 / 2, np.sqrt(3) / 2]]
436439

437440
def cell_atoms(self) -> List[List[float]]:
438441
return [[0.0, 0.0], [1 / 2, 1 / (2 * np.sqrt(3))]]
439442

443+
def _repr_pretty_(self, p, _):
444+
p.text(str(self))
445+
440446

441447
@dataclass
442448
class Triangular(BoundedBravais):
@@ -484,15 +490,15 @@ def __init__(
484490
def shape(self) -> Tuple[int, ...]:
485491
return (self.L1, self.L2)
486492

487-
def __repr__(self):
488-
return super().__repr__()
489-
490493
def cell_vectors(self) -> List[List[float]]:
491494
return [[1.0, 0.0], [1 / 2, np.sqrt(3) / 2]]
492495

493496
def cell_atoms(self) -> List[List[float]]:
494497
return [[0.0, 0.0]]
495498

499+
def _repr_pretty_(self, p, _):
500+
p.text(str(self))
501+
496502

497503
@dataclass
498504
class Lieb(BoundedBravais):
@@ -535,9 +541,6 @@ def __init__(
535541
self.L2 = L2
536542
self.lattice_spacing = cast(lattice_spacing)
537543

538-
def __repr__(self):
539-
return super().__repr__()
540-
541544
def cell_vectors(self) -> List[List[float]]:
542545
return [[1.0, 0.0], [0.0, 1.0]]
543546

@@ -548,6 +551,9 @@ def cell_atoms(self) -> List[List[float]]:
548551
def shape(self) -> Tuple[int, ...]:
549552
return (self.L1, self.L2)
550553

554+
def _repr_pretty_(self, p, _):
555+
p.text(str(self))
556+
551557

552558
@dataclass
553559
class Kagome(BoundedBravais):
@@ -596,11 +602,11 @@ def __init__(
596602
def shape(self) -> Tuple[int, ...]:
597603
return (self.L1, self.L2)
598604

599-
def __repr__(self):
600-
return super().__repr__()
601-
602605
def cell_vectors(self) -> List[List[float]]:
603606
return [[1.0, 0.0], [1 / 2, np.sqrt(3) / 2]]
604607

605608
def cell_atoms(self) -> List[List[float]]:
606609
return [[0.0, 0.0], [1 / 2, 0], [1 / 4, np.sqrt(3) / 4]]
610+
611+
def _repr_pretty_(self, p, _):
612+
p.text(str(self))

src/bloqade/ir/location/location.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -120,9 +120,6 @@ def is_literal(x):
120120
ph.print(self)
121121
return ph.get_value()
122122

123-
def _repr_pretty_(self, p, cycle):
124-
Printer(p).print(self, cycle)
125-
126123
def print_node(self) -> str:
127124
return "AtomArrangement"
128125

@@ -518,6 +515,9 @@ def n_dims(self):
518515
def __str__(self):
519516
return "ParallelRegister:\n" + self.atom_arrangement.__str__()
520517

518+
def _repr_pretty_(self, p, _):
519+
p.text(str(self))
520+
521521
def _compile_to_list(
522522
self, __capabilities: Optional[QuEraCapabilities] = None, **assignments
523523
):
@@ -626,6 +626,9 @@ def __init__(
626626

627627
super().__init__()
628628

629+
def _repr_pretty_(self, p, _):
630+
p.text(str(self))
631+
629632
@property
630633
def n_atoms(self):
631634
return self.__n_atoms

src/bloqade/ir/tree_print.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -195,14 +195,33 @@ def print(self, node, cycle=None):
195195
else:
196196
children = list(children)
197197

198-
while not len(children) == 0:
198+
list_depth = 0
199+
trunc_list_print = False
200+
printed_list_trunc = False
201+
202+
trunc_list_depth = len(children) > 2 * MAX_TREE_DEPTH
203+
204+
while children:
199205
child_prefix = self.state.prefix
200206
if this_print_annotation:
201207
annotation, child = children.pop(0)
202208
else:
203209
child = children.pop(0)
204210
annotation = None
205211

212+
list_depth += 1
213+
trunc_list_print = trunc_list_depth and (
214+
list_depth > MAX_TREE_DEPTH and len(children) >= MAX_TREE_DEPTH
215+
)
216+
217+
if trunc_list_print:
218+
if not printed_list_trunc:
219+
self.p.text(self.charset.trunc)
220+
self.p.text("\n")
221+
printed_list_trunc = True
222+
223+
continue
224+
206225
self.p.text(self.state.prefix)
207226

208227
if len(children) == 0:
@@ -211,12 +230,10 @@ def print(self, node, cycle=None):
211230
len(self.charset.skip) + len(self.charset.dash) + 1
212231
)
213232

214-
if self.state.depth > 0 and self.state.last:
215-
is_last_leaf_child = True
216-
elif self.state.depth == 0:
217-
is_last_leaf_child = True
218-
else:
219-
is_last_leaf_child = False
233+
is_last_leaf_child = (
234+
self.state.depth > 0 and self.state.last or self.state.depth == 0
235+
)
236+
220237
else:
221238
self.p.text(self.charset.mid)
222239
child_prefix += self.charset.skip + " " * (len(self.charset.dash) + 1)

tests/test_lattice_pprint.py

Lines changed: 110 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
from bloqade import start, cast
2-
from bloqade.ir.location import Square, Rectangular
2+
from bloqade.ir.location import Square, Rectangular, Chain
33
import random
44
import numpy as np
55
import os
66
import bloqade.ir.tree_print as trp
77

88

99
trp.color_enabled = False
10-
10+
trp.MAX_TREE_DEPTH = 10
1111

1212
PROJECT_RELATIVE_PPRINT_TESTS_OUTPUT_PATH = os.path.join(
1313
os.getcwd(), "tests/data/expected_pprint_output"
@@ -58,6 +58,114 @@ def test_list_of_locations_pprint():
5858
)
5959

6060

61+
def test_list_trun():
62+
geo = Chain(20).add_position((-1, "a"))
63+
64+
assert str(geo) == (
65+
"AtomArrangement\n"
66+
"├─ Location: filled\n"
67+
"│ ├─ x\n"
68+
"│ │ ⇒ Literal: 0\n"
69+
"│ └─ y\n"
70+
"│ ⇒ Literal: 0\n"
71+
"├─ Location: filled\n"
72+
"│ ├─ x\n"
73+
"│ │ ⇒ Literal: 1\n"
74+
"│ └─ y\n"
75+
"│ ⇒ Literal: 0\n"
76+
"├─ Location: filled\n"
77+
"│ ├─ x\n"
78+
"│ │ ⇒ Literal: 2\n"
79+
"│ └─ y\n"
80+
"│ ⇒ Literal: 0\n"
81+
"├─ Location: filled\n"
82+
"│ ├─ x\n"
83+
"│ │ ⇒ Literal: 3\n"
84+
"│ └─ y\n"
85+
"│ ⇒ Literal: 0\n"
86+
"├─ Location: filled\n"
87+
"│ ├─ x\n"
88+
"│ │ ⇒ Literal: 4\n"
89+
"│ └─ y\n"
90+
"│ ⇒ Literal: 0\n"
91+
"├─ Location: filled\n"
92+
"│ ├─ x\n"
93+
"│ │ ⇒ Literal: 5\n"
94+
"│ └─ y\n"
95+
"│ ⇒ Literal: 0\n"
96+
"├─ Location: filled\n"
97+
"│ ├─ x\n"
98+
"│ │ ⇒ Literal: 6\n"
99+
"│ └─ y\n"
100+
"│ ⇒ Literal: 0\n"
101+
"├─ Location: filled\n"
102+
"│ ├─ x\n"
103+
"│ │ ⇒ Literal: 7\n"
104+
"│ └─ y\n"
105+
"│ ⇒ Literal: 0\n"
106+
"├─ Location: filled\n"
107+
"│ ├─ x\n"
108+
"│ │ ⇒ Literal: 8\n"
109+
"│ └─ y\n"
110+
"│ ⇒ Literal: 0\n"
111+
"├─ Location: filled\n"
112+
"│ ├─ x\n"
113+
"│ │ ⇒ Literal: 9\n"
114+
"│ └─ y\n"
115+
"│ ⇒ Literal: 0\n"
116+
"⋮\n"
117+
"├─ Location: filled\n"
118+
"│ ├─ x\n"
119+
"│ │ ⇒ Literal: 11\n"
120+
"│ └─ y\n"
121+
"│ ⇒ Literal: 0\n"
122+
"├─ Location: filled\n"
123+
"│ ├─ x\n"
124+
"│ │ ⇒ Literal: 12\n"
125+
"│ └─ y\n"
126+
"│ ⇒ Literal: 0\n"
127+
"├─ Location: filled\n"
128+
"│ ├─ x\n"
129+
"│ │ ⇒ Literal: 13\n"
130+
"│ └─ y\n"
131+
"│ ⇒ Literal: 0\n"
132+
"├─ Location: filled\n"
133+
"│ ├─ x\n"
134+
"│ │ ⇒ Literal: 14\n"
135+
"│ └─ y\n"
136+
"│ ⇒ Literal: 0\n"
137+
"├─ Location: filled\n"
138+
"│ ├─ x\n"
139+
"│ │ ⇒ Literal: 15\n"
140+
"│ └─ y\n"
141+
"│ ⇒ Literal: 0\n"
142+
"├─ Location: filled\n"
143+
"│ ├─ x\n"
144+
"│ │ ⇒ Literal: 16\n"
145+
"│ └─ y\n"
146+
"│ ⇒ Literal: 0\n"
147+
"├─ Location: filled\n"
148+
"│ ├─ x\n"
149+
"│ │ ⇒ Literal: 17\n"
150+
"│ └─ y\n"
151+
"│ ⇒ Literal: 0\n"
152+
"├─ Location: filled\n"
153+
"│ ├─ x\n"
154+
"│ │ ⇒ Literal: 18\n"
155+
"│ └─ y\n"
156+
"│ ⇒ Literal: 0\n"
157+
"├─ Location: filled\n"
158+
"│ ├─ x\n"
159+
"│ │ ⇒ Literal: 19\n"
160+
"│ └─ y\n"
161+
"│ ⇒ Literal: 0\n"
162+
"└─ Location: filled\n "
163+
"├─ x\n │ ⇒ Literal: -1"
164+
"\n └─ y"
165+
"\n ⇒ Variable: a"
166+
)
167+
168+
61169
def test_square_pprint():
62170
# full
63171
square_pprint_output_path = os.path.join(

0 commit comments

Comments
 (0)