1
1
import itertools
2
+ import logging
2
3
import math
3
4
from collections .abc import Iterator
4
5
from dataclasses import dataclass
23
24
24
25
from .types import CubedArrayProxy , MemoryModeller , PrimitiveOperation
25
26
27
+ logger = logging .getLogger (__name__ )
28
+
29
+
26
30
sym_counter = 0
27
31
28
32
@@ -352,6 +356,7 @@ def can_fuse_primitive_ops(
352
356
353
357
354
358
def can_fuse_multiple_primitive_ops (
359
+ name : str ,
355
360
primitive_op : PrimitiveOperation ,
356
361
predecessor_primitive_ops : List [PrimitiveOperation ],
357
362
* ,
@@ -362,27 +367,67 @@ def can_fuse_multiple_primitive_ops(
362
367
):
363
368
# If the peak projected memory for running all the predecessor ops in
364
369
# order is larger than allowed_mem then we can't fuse.
365
- if peak_projected_mem (predecessor_primitive_ops ) > primitive_op .allowed_mem :
370
+ peak_projected = peak_projected_mem (predecessor_primitive_ops )
371
+ if peak_projected > primitive_op .allowed_mem :
372
+ logger .debug (
373
+ "can't fuse %s since peak projected memory for predecessor ops (%s) is greater than allowed (%s)" ,
374
+ name ,
375
+ peak_projected ,
376
+ primitive_op .allowed_mem ,
377
+ )
366
378
return False
367
379
# If the number of input blocks for each input is not uniform, then we
368
380
# can't fuse. (This should never happen since all operations are
369
381
# currently uniform, and fused operations are too if fuse is applied in
370
382
# topological order.)
371
383
num_input_blocks = primitive_op .pipeline .config .num_input_blocks
372
384
if not all (num_input_blocks [0 ] == n for n in num_input_blocks ):
385
+ logger .debug (
386
+ "can't fuse %s since number of input blocks for each input is not uniform: %s" ,
387
+ name ,
388
+ num_input_blocks ,
389
+ )
373
390
return False
374
391
if max_total_num_input_blocks is None :
375
392
# If max total input blocks not specified, then only fuse if num
376
393
# tasks of predecessor ops match.
377
- return all (
394
+ ret = all (
378
395
primitive_op .num_tasks == p .num_tasks for p in predecessor_primitive_ops
379
396
)
397
+ if ret :
398
+ logger .debug (
399
+ "can fuse %s since num tasks of predecessor ops match" , name
400
+ )
401
+ else :
402
+ logger .debug (
403
+ "can't fuse %s since num tasks of predecessor ops do not match" ,
404
+ name ,
405
+ )
406
+ return ret
380
407
else :
381
408
total_num_input_blocks = 0
382
409
for ni , p in zip (num_input_blocks , predecessor_primitive_ops ):
383
410
for nj in p .pipeline .config .num_input_blocks :
384
411
total_num_input_blocks += ni * nj
385
- return total_num_input_blocks <= max_total_num_input_blocks
412
+ ret = total_num_input_blocks <= max_total_num_input_blocks
413
+ if ret :
414
+ logger .debug (
415
+ "can fuse %s since total number of input blocks (%s) does not exceed max (%s)" ,
416
+ name ,
417
+ total_num_input_blocks ,
418
+ max_total_num_input_blocks ,
419
+ )
420
+ else :
421
+ logger .debug (
422
+ "can't fuse %s since total number of input blocks (%s) exceeds max (%s)" ,
423
+ name ,
424
+ total_num_input_blocks ,
425
+ max_total_num_input_blocks ,
426
+ )
427
+ return ret
428
+ logger .debug (
429
+ "can't fuse %s since primitive op and predecessors are not all candidates" , name
430
+ )
386
431
return False
387
432
388
433
0 commit comments