9
9
import zarr
10
10
11
11
from cubed .core .optimization import simple_optimize_dag
12
+ from cubed .primitive .blockwise import BlockwiseSpec
12
13
from cubed .primitive .types import PrimitiveOperation
13
14
from cubed .runtime .pipeline import visit_nodes
14
15
from cubed .runtime .types import ComputeEndEvent , ComputeStartEvent , CubedPipeline
@@ -72,6 +73,9 @@ def _new(
72
73
frame = inspect .currentframe ().f_back # go back one in the stack
73
74
stack_summaries = extract_stack_summaries (frame , limit = 10 )
74
75
76
+ first_cubed_i = min (i for i , s in enumerate (stack_summaries ) if s .is_cubed ())
77
+ first_cubed_summary = stack_summaries [first_cubed_i ]
78
+
75
79
op_name_unique = gensym ()
76
80
77
81
if primitive_op is None :
@@ -82,6 +86,7 @@ def _new(
82
86
op_name = op_name ,
83
87
type = "op" ,
84
88
stack_summaries = stack_summaries ,
89
+ op_display_name = f"{ op_name_unique } \n { first_cubed_summary .name } " ,
85
90
hidden = hidden ,
86
91
)
87
92
# array (when multiple outputs are supported there could be more than one)
@@ -101,6 +106,7 @@ def _new(
101
106
op_name = op_name ,
102
107
type = "op" ,
103
108
stack_summaries = stack_summaries ,
109
+ op_display_name = f"{ op_name_unique } \n { first_cubed_summary .name } " ,
104
110
hidden = hidden ,
105
111
primitive_op = primitive_op ,
106
112
pipeline = primitive_op .pipeline ,
@@ -160,6 +166,7 @@ def _create_lazy_zarr_arrays(self, dag):
160
166
name = name ,
161
167
op_name = op_name ,
162
168
type = "op" ,
169
+ op_display_name = name ,
163
170
primitive_op = primitive_op ,
164
171
pipeline = primitive_op .pipeline ,
165
172
)
@@ -307,19 +314,17 @@ def visualize(
307
314
tooltip = f"name: { n } \n "
308
315
node_type = d .get ("type" , None )
309
316
if node_type == "op" :
317
+ label = d ["op_display_name" ]
310
318
op_name = d ["op_name" ]
311
319
if op_name == "blockwise" :
312
320
d ["style" ] = '"rounded,filled"'
313
321
d ["fillcolor" ] = "#dcbeff"
314
- op_name_summary = "(bw)"
315
322
elif op_name == "rechunk" :
316
323
d ["style" ] = '"rounded,filled"'
317
324
d ["fillcolor" ] = "#aaffc3"
318
- op_name_summary = "(rc)"
319
325
else :
320
326
# creation function
321
327
d ["style" ] = "rounded"
322
- op_name_summary = ""
323
328
tooltip += f"op: { op_name } "
324
329
325
330
num_tasks = None
@@ -337,7 +342,7 @@ def visualize(
337
342
# remove pipeline attribute since it is a long string that causes graphviz to fail
338
343
if "pipeline" in d :
339
344
pipeline = d ["pipeline" ]
340
- if pipeline .config is not None :
345
+ if isinstance ( pipeline .config , BlockwiseSpec ) :
341
346
tooltip += (
342
347
f"\n num input blocks: { pipeline .config .num_input_blocks } "
343
348
)
@@ -350,11 +355,8 @@ def visualize(
350
355
first_cubed_i = min (
351
356
i for i , s in enumerate (stack_summaries ) if s .is_cubed ()
352
357
)
353
- first_cubed_summary = stack_summaries [first_cubed_i ]
354
358
caller_summary = stack_summaries [first_cubed_i - 1 ]
355
359
356
- label = f"{ first_cubed_summary .name } { op_name_summary } "
357
-
358
360
calls = " -> " .join (
359
361
[
360
362
s .name
@@ -384,7 +386,7 @@ def visualize(
384
386
nbytes = memory_repr (target .nbytes )
385
387
if n in array_display_names :
386
388
var_name = array_display_names [n ]
387
- label = f"{ n } ( { var_name } ) "
389
+ label = f"{ n } \n { var_name } "
388
390
tooltip += f"variable: { var_name } \n "
389
391
tooltip += f"shape: { target .shape } \n "
390
392
tooltip += f"chunks: { target .chunks } \n "
0 commit comments