@@ -46,6 +46,7 @@ def get_aligned_offset(pre_aligned_offset: int, alignment: int) -> int:
 
 def collect_specs_from_graph_module(
     graph_module: torch.fx.GraphModule,
+    graph_signature: ExportGraphSignature,
     alloc_graph_input: bool,
     alloc_graph_output: bool,
 ) -> Iterable[TensorSpec]:
@@ -56,6 +57,7 @@ def collect_specs_from_graph_module(
     # Collect the specs from all the nodes in the graph module, and return it
     return collect_specs_from_nodes(
         graph_module.graph.nodes,
+        graph_signature,
         ignore_graph_input=not alloc_graph_input,
         ignore_graph_output=not alloc_graph_output,
     )
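Every caller now has to thread the ExportGraphSignature through to the spec collector. A minimal usage sketch (not part of the diff; `ep` is an assumed ExportedProgram obtained from torch.export):

    specs = collect_specs_from_graph_module(
        ep.graph_module,
        ep.graph_signature,  # the newly required argument
        alloc_graph_input=True,
        alloc_graph_output=True,
    )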
@@ -107,7 +109,7 @@ def memory_available(spec: TensorSpec) -> bool:
     # Iterate over all the specs in sorted order
     for spec in sorted(
         collect_specs_from_graph_module(
-            graph_module, alloc_graph_input, alloc_graph_output
+            graph_module, graph_signature, alloc_graph_input, alloc_graph_output
         ),
         key=lambda spec: spec.allocated_memory,
         reverse=True,
@@ -182,7 +184,7 @@ def greedy_by_size_for_offset_calculation_with_hierarchy(
     # Iterate over all the specs in sorted order
     for spec in sorted(
         collect_specs_from_graph_module(
-            graph_module, alloc_graph_input, alloc_graph_output
+            graph_module, graph_signature, alloc_graph_input, alloc_graph_output
        ),
         key=lambda spec: spec.allocated_memory,
         reverse=True,
@@ -250,6 +252,7 @@ def greedy_by_size_for_offset_calculation_with_hierarchy(
 
 def find_peak_memory_usages_per_memory(
     graph_module: torch.fx.GraphModule,
+    graph_signature: ExportGraphSignature,
     alloc_graph_input: bool,
     alloc_graph_output: bool,
     mem_constraints: Optional[MemConstraints] = None,
@@ -265,7 +268,7 @@ def find_peak_memory_usages_per_memory(
 
     # go through all nodes in the graph, collect memory usage per spec.mem_id
     for spec in collect_specs_from_graph_module(
-        graph_module, alloc_graph_input, alloc_graph_output
+        graph_module, graph_signature, alloc_graph_input, alloc_graph_output
     ):
         if mem_constraints is not None and mem_constraints.skipped_spec(spec):
             continue
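As the in-code comment says, this loop accumulates usage per `spec.mem_id`. A simplified sketch of that aggregation under the new signature (the exact accumulation in the real function may differ; `mem_offset` and `allocated_memory` are TensorSpec fields assumed from context):

    peak: DefaultDict[int, int] = defaultdict(int)
    for spec in collect_specs_from_graph_module(
        graph_module, graph_signature, alloc_graph_input, alloc_graph_output
    ):
        # a buffer placed at mem_offset extends its arena to offset + size
        peak[spec.mem_id] = max(peak[spec.mem_id], spec.mem_offset + spec.allocated_memory)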
@@ -288,6 +291,7 @@ def find_peak_memory_usages_per_memory(
 
 def find_peak_memory_usage(
     graph_module: torch.fx.GraphModule,
+    graph_signature: ExportGraphSignature,
     alloc_graph_input: bool,
     alloc_graph_output: bool,
     mem_constraints: Optional[MemConstraints] = None,
@@ -303,7 +307,7 @@ def find_peak_memory_usage(
 
     # Iterate over all the node specs
     for spec in collect_specs_from_graph_module(
-        graph_module, alloc_graph_input, alloc_graph_output
+        graph_module, graph_signature, alloc_graph_input, alloc_graph_output
     ):
         if spec.lifetime[0] is None or (
             mem_constraints is not None and mem_constraints.skipped_spec(spec)
@@ -358,6 +362,7 @@ def print_memory_planning_info(
     # Get the peak memory usages per memory space
     peak_memory_usages_per_memory = find_peak_memory_usages_per_memory(
         executorch_prog.exported_program().graph_module,
+        executorch_prog.exported_program().graph_signature,
         alloc_graph_input,
         alloc_graph_output,
         mem_constraints,
@@ -393,6 +398,7 @@ def print_memory_planning_info(
     # Get the total peak memory usage across all memory spaces
     total_peak_memory_usage = find_peak_memory_usage(
         executorch_prog.exported_program().graph_module,
+        executorch_prog.exported_program().graph_signature,
         alloc_graph_input,
         alloc_graph_output,
         mem_constraints,
@@ -453,7 +459,17 @@ def _init_mem_algos(self) -> None:
             greedy_by_size_for_offset_calculation_with_hierarchy,
         ]
 
-    def __call__(self, graph_module: torch.fx.GraphModule) -> PassResult:
+    def __call__(
+        self,
+        graph_module: torch.fx.GraphModule,
+    ) -> PassResult:
+        return self.run(graph_module)
+
+    def run(
+        self,
+        graph_module: torch.fx.GraphModule,
+        graph_signature: Optional[ExportGraphSignature] = None,
+    ) -> PassResult:
         mem_constraints = MemConstraints(
             opt_level=self.opt_level,
             alloc_graph_input=self.alloc_graph_input,
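The `__call__`/`run` split preserves the single-argument pass interface for existing callers while letting signature-aware callers supply an ExportGraphSignature explicitly. A sketch of both entry points (`planner` stands in for an instance of this pass; its class name and constructor are not shown in the diff):

    result = planner(graph_module)                          # legacy path; graph_signature stays None
    result = planner.run(graph_module, ep.graph_signature)  # signature-aware path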
@@ -475,6 +491,6 @@ def __call__(self, graph_module: torch.fx.GraphModule) -> PassResult:
             alloc_graph_output=self.alloc_graph_output,
             alignment=self.mem_alignment,
         )
-        mem_planning(graph_module)
+        mem_planning.run(graph_module, graph_signature)
 
         return PassResult(graph_module, True)
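For context on why the signature is worth threading this far down: it lets the planner tell real user inputs and outputs apart from parameters and buffers that also appear as graph placeholders. An illustrative helper using the public ExportGraphSignature API (hypothetical, not part of this change):

    def is_user_input(node: torch.fx.Node, sig: ExportGraphSignature) -> bool:
        # user_inputs lists the placeholder names bound to user-provided
        # inputs, as opposed to lifted parameters and buffers
        return node.op == "placeholder" and node.name in sig.user_inputs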