Skip to content

Commit 3160adc

Browse files
committed
Add debug log statements to optimizer to see why nodes are fused
1 parent 7056a72 commit 3160adc

File tree

2 files changed

+65
-3
lines changed

2 files changed

+65
-3
lines changed

cubed/core/optimization.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import logging
2+
13
import networkx as nx
24

35
from cubed.primitive.blockwise import (
@@ -7,6 +9,8 @@
79
fuse_multiple,
810
)
911

12+
logger = logging.getLogger(__name__)
13+
1014

1115
def simple_optimize_dag(dag):
1216
"""Apply map blocks fusion."""
@@ -108,16 +112,20 @@ def can_fuse_predecessors(
108112

109113
# if node itself can't be fused then there is nothing to fuse
110114
if not is_fusable(nodes[name]):
115+
logger.debug("can't fuse %s since it is not fusable", name)
111116
return False
112117

113118
# if no predecessor ops can be fused then there is nothing to fuse
114119
if all(not is_fusable(nodes[pre]) for pre in predecessor_ops(dag, name)):
120+
logger.debug("can't fuse %s since no predecessor ops can be fused", name)
115121
return False
116122

117123
# if node is in never_fuse or always_fuse list then it overrides logic below
118124
if never_fuse is not None and name in never_fuse:
125+
logger.debug("can't fuse %s since it is in 'never_fuse'", name)
119126
return False
120127
if always_fuse is not None and name in always_fuse:
128+
logger.debug("can fuse %s since it is in 'always_fuse'", name)
121129
return True
122130

123131
# if there is more than a single predecessor op, and the total number of source arrays to
@@ -128,6 +136,12 @@ def can_fuse_predecessors(
128136
for pre in predecessor_ops(dag, name)
129137
)
130138
if total_source_arrays > max_total_source_arrays:
139+
logger.debug(
140+
"can't fuse %s since total number of source arrays (%s) exceeds max (%s)",
141+
name,
142+
total_source_arrays,
143+
max_total_source_arrays,
144+
)
131145
return False
132146

133147
predecessor_primitive_ops = [
@@ -136,6 +150,7 @@ def can_fuse_predecessors(
136150
if is_fusable(nodes[pre])
137151
]
138152
return can_fuse_multiple_primitive_ops(
153+
name,
139154
nodes[name]["primitive_op"],
140155
predecessor_primitive_ops,
141156
max_total_num_input_blocks=max_total_num_input_blocks,
@@ -219,6 +234,8 @@ def multiple_inputs_optimize_dag(
219234
):
220235
"""Fuse multiple inputs."""
221236
for name in list(nx.topological_sort(dag)):
237+
if name.startswith("array-"):
238+
continue
222239
dag = fuse_predecessors(
223240
dag,
224241
name,

cubed/primitive/blockwise.py

Lines changed: 48 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import itertools
2+
import logging
23
import math
34
from collections.abc import Iterator
45
from dataclasses import dataclass
@@ -23,6 +24,9 @@
2324

2425
from .types import CubedArrayProxy, MemoryModeller, PrimitiveOperation
2526

27+
logger = logging.getLogger(__name__)
28+
29+
2630
sym_counter = 0
2731

2832

@@ -352,6 +356,7 @@ def can_fuse_primitive_ops(
352356

353357

354358
def can_fuse_multiple_primitive_ops(
359+
name: str,
355360
primitive_op: PrimitiveOperation,
356361
predecessor_primitive_ops: List[PrimitiveOperation],
357362
*,
@@ -362,27 +367,67 @@ def can_fuse_multiple_primitive_ops(
362367
):
363368
# If the peak projected memory for running all the predecessor ops in
364369
# order is larger than allowed_mem then we can't fuse.
365-
if peak_projected_mem(predecessor_primitive_ops) > primitive_op.allowed_mem:
370+
peak_projected = peak_projected_mem(predecessor_primitive_ops)
371+
if peak_projected > primitive_op.allowed_mem:
372+
logger.debug(
373+
"can't fuse %s since peak projected memory for predecessor ops (%s) is greater than allowed (%s)",
374+
name,
375+
peak_projected,
376+
primitive_op.allowed_mem,
377+
)
366378
return False
367379
# If the number of input blocks for each input is not uniform, then we
368380
# can't fuse. (This should never happen since all operations are
369381
# currently uniform, and fused operations are too if fuse is applied in
370382
# topological order.)
371383
num_input_blocks = primitive_op.pipeline.config.num_input_blocks
372384
if not all(num_input_blocks[0] == n for n in num_input_blocks):
385+
logger.debug(
386+
"can't fuse %s since number of input blocks for each input is not uniform: %s",
387+
name,
388+
num_input_blocks,
389+
)
373390
return False
374391
if max_total_num_input_blocks is None:
375392
# If max total input blocks not specified, then only fuse if num
376393
# tasks of predecessor ops match.
377-
return all(
394+
ret = all(
378395
primitive_op.num_tasks == p.num_tasks for p in predecessor_primitive_ops
379396
)
397+
if ret:
398+
logger.debug(
399+
"can fuse %s since num tasks of predecessor ops match", name
400+
)
401+
else:
402+
logger.debug(
403+
"can't fuse %s since num tasks of predecessor ops do not match",
404+
name,
405+
)
406+
return ret
380407
else:
381408
total_num_input_blocks = 0
382409
for ni, p in zip(num_input_blocks, predecessor_primitive_ops):
383410
for nj in p.pipeline.config.num_input_blocks:
384411
total_num_input_blocks += ni * nj
385-
return total_num_input_blocks <= max_total_num_input_blocks
412+
ret = total_num_input_blocks <= max_total_num_input_blocks
413+
if ret:
414+
logger.debug(
415+
"can fuse %s since total number of input blocks (%s) does not exceed max (%s)",
416+
name,
417+
total_num_input_blocks,
418+
max_total_num_input_blocks,
419+
)
420+
else:
421+
logger.debug(
422+
"can't fuse %s since total number of input blocks (%s) exceeds max (%s)",
423+
name,
424+
total_num_input_blocks,
425+
max_total_num_input_blocks,
426+
)
427+
return ret
428+
logger.debug(
429+
"can't fuse %s since primitive op and predecessors are not all candidates", name
430+
)
386431
return False
387432

388433

0 commit comments

Comments
 (0)