Skip to content

Commit 4491736

Browse files
committed
Show both operation ID and name (e.g. "op-005 add") in plan visualization and progress bars.
Change plan visualization to always show array or op ID at top.
1 parent 41aac1b commit 4491736

File tree

4 files changed

+21
-11
lines changed

4 files changed

+21
-11
lines changed

cubed/core/plan.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import zarr
1010

1111
from cubed.core.optimization import simple_optimize_dag
12+
from cubed.primitive.blockwise import BlockwiseSpec
1213
from cubed.primitive.types import PrimitiveOperation
1314
from cubed.runtime.pipeline import visit_nodes
1415
from cubed.runtime.types import ComputeEndEvent, ComputeStartEvent, CubedPipeline
@@ -72,6 +73,9 @@ def _new(
7273
frame = inspect.currentframe().f_back # go back one in the stack
7374
stack_summaries = extract_stack_summaries(frame, limit=10)
7475

76+
first_cubed_i = min(i for i, s in enumerate(stack_summaries) if s.is_cubed())
77+
first_cubed_summary = stack_summaries[first_cubed_i]
78+
7579
op_name_unique = gensym()
7680

7781
if primitive_op is None:
@@ -82,6 +86,7 @@ def _new(
8286
op_name=op_name,
8387
type="op",
8488
stack_summaries=stack_summaries,
89+
op_display_name=f"{op_name_unique}\n{first_cubed_summary.name}",
8590
hidden=hidden,
8691
)
8792
# array (when multiple outputs are supported there could be more than one)
@@ -101,6 +106,7 @@ def _new(
101106
op_name=op_name,
102107
type="op",
103108
stack_summaries=stack_summaries,
109+
op_display_name=f"{op_name_unique}\n{first_cubed_summary.name}",
104110
hidden=hidden,
105111
primitive_op=primitive_op,
106112
pipeline=primitive_op.pipeline,
@@ -160,6 +166,7 @@ def _create_lazy_zarr_arrays(self, dag):
160166
name=name,
161167
op_name=op_name,
162168
type="op",
169+
op_display_name=name,
163170
primitive_op=primitive_op,
164171
pipeline=primitive_op.pipeline,
165172
)
@@ -307,19 +314,17 @@ def visualize(
307314
tooltip = f"name: {n}\n"
308315
node_type = d.get("type", None)
309316
if node_type == "op":
317+
label = d["op_display_name"]
310318
op_name = d["op_name"]
311319
if op_name == "blockwise":
312320
d["style"] = '"rounded,filled"'
313321
d["fillcolor"] = "#dcbeff"
314-
op_name_summary = "(bw)"
315322
elif op_name == "rechunk":
316323
d["style"] = '"rounded,filled"'
317324
d["fillcolor"] = "#aaffc3"
318-
op_name_summary = "(rc)"
319325
else:
320326
# creation function
321327
d["style"] = "rounded"
322-
op_name_summary = ""
323328
tooltip += f"op: {op_name}"
324329

325330
num_tasks = None
@@ -337,7 +342,7 @@ def visualize(
337342
# remove pipeline attribute since it is a long string that causes graphviz to fail
338343
if "pipeline" in d:
339344
pipeline = d["pipeline"]
340-
if pipeline.config is not None:
345+
if isinstance(pipeline.config, BlockwiseSpec):
341346
tooltip += (
342347
f"\nnum input blocks: {pipeline.config.num_input_blocks}"
343348
)
@@ -350,11 +355,8 @@ def visualize(
350355
first_cubed_i = min(
351356
i for i, s in enumerate(stack_summaries) if s.is_cubed()
352357
)
353-
first_cubed_summary = stack_summaries[first_cubed_i]
354358
caller_summary = stack_summaries[first_cubed_i - 1]
355359

356-
label = f"{first_cubed_summary.name} {op_name_summary}"
357-
358360
calls = " -> ".join(
359361
[
360362
s.name
@@ -384,7 +386,7 @@ def visualize(
384386
nbytes = memory_repr(target.nbytes)
385387
if n in array_display_names:
386388
var_name = array_display_names[n]
387-
label = f"{n} ({var_name})"
389+
label = f"{n}\n{var_name}"
388390
tooltip += f"variable: {var_name}\n"
389391
tooltip += f"shape: {target.shape}\n"
390392
tooltip += f"chunks: {target.chunks}\n"

cubed/extensions/rich.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,10 @@ def on_compute_start(self, event):
4141
progress_tasks = {}
4242
for name, node in visit_nodes(event.dag, event.resume):
4343
num_tasks = node["primitive_op"].num_tasks
44-
progress_task = progress.add_task(f"{name}", start=False, total=num_tasks)
44+
op_display_name = node["op_display_name"].replace("\n", " ")
45+
progress_task = progress.add_task(
46+
f"{op_display_name}", start=False, total=num_tasks
47+
)
4548
progress_tasks[name] = progress_task
4649

4750
self.logger_aware_progress = logger_aware_progress

cubed/extensions/tqdm.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,13 @@ def on_compute_start(self, event):
2121
i = 0
2222
for name, node in visit_nodes(event.dag, event.resume):
2323
num_tasks = node["primitive_op"].num_tasks
24+
op_display_name = node["op_display_name"].replace("\n", " ")
2425
self.pbars[name] = tqdm(
25-
*self.args, desc=name, total=num_tasks, position=i, **self.kwargs
26+
*self.args,
27+
desc=op_display_name,
28+
total=num_tasks,
29+
position=i,
30+
**self.kwargs,
2631
)
2732
i = i + 1
2833

cubed/tests/test_core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -457,7 +457,7 @@ def test_visualize(tmp_path):
457457
a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=xp.float64, chunks=(2, 2))
458458
b = cubed.random.random((3, 3), chunks=(2, 2))
459459
c = xp.add(a, b)
460-
d = c * 2
460+
d = c.rechunk((3, 1))
461461
e = c * 3
462462

463463
f = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2))

0 commit comments

Comments
 (0)