
Commit 54b371f

Fix memory_planning API to use run()
Differential Revision: D68939461
Pull Request resolved: #8622

1 parent 5cf0106
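Summary of the change: the memory-planning pass previously exposed only `__call__(graph_module)`, so helpers such as `collect_specs_from_graph_module`, `find_peak_memory_usages_per_memory`, and `find_peak_memory_usage` had no access to the program's `ExportGraphSignature`. This commit threads the signature through those helpers and adds a `run(graph_module, graph_signature)` entry point, keeping `__call__` as a backward-compatible wrapper that delegates to `run()` with no signature. A minimal caller-side migration sketch follows; the pass class name and import path are assumptions based on the file touched here (only `run()`/`__call__` and their parameters are confirmed by the diff):

```python
# Hypothetical migration sketch. `CadenceMemoryPlanning` is an assumed class
# name; the module path mirrors backends/cadence/aot/memory_planning.py.
from executorch.backends.cadence.aot.memory_planning import CadenceMemoryPlanning

exported_program = executorch_prog.exported_program()  # assumed export flow

memory_planning_pass = CadenceMemoryPlanning(...)  # construction args elided

# Before this commit: direct invocation, no way to hand over the signature.
result = memory_planning_pass(exported_program.graph_module)

# After this commit: pass the ExportGraphSignature explicitly via run().
result = memory_planning_pass.run(
    exported_program.graph_module,
    exported_program.graph_signature,  # Optional; defaults to None
)
```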

File tree: 2 files changed, +33 −18 lines


backends/cadence/aot/memory_planning.py (+22 −6)

```diff
@@ -46,6 +46,7 @@ def get_aligned_offset(pre_aligned_offset: int, alignment: int) -> int:
 
 def collect_specs_from_graph_module(
     graph_module: torch.fx.GraphModule,
+    graph_signature: ExportGraphSignature,
     alloc_graph_input: bool,
     alloc_graph_output: bool,
 ) -> Iterable[TensorSpec]:
@@ -56,6 +57,7 @@ def collect_specs_from_graph_module(
     # Collect the specs from all the nodes in the graph module, and return it
     return collect_specs_from_nodes(
         graph_module.graph.nodes,
+        graph_signature,
         ignore_graph_input=not alloc_graph_input,
         ignore_graph_output=not alloc_graph_output,
     )
@@ -107,7 +109,7 @@ def memory_available(spec: TensorSpec) -> bool:
     # Iterate over all the specs in sorted order
     for spec in sorted(
         collect_specs_from_graph_module(
-            graph_module, alloc_graph_input, alloc_graph_output
+            graph_module, graph_signature, alloc_graph_input, alloc_graph_output
         ),
         key=lambda spec: spec.allocated_memory,
         reverse=True,
@@ -182,7 +184,7 @@ def greedy_by_size_for_offset_calculation_with_hierarchy(
     # Iterate over all the specs in sorted order
     for spec in sorted(
         collect_specs_from_graph_module(
-            graph_module, alloc_graph_input, alloc_graph_output
+            graph_module, graph_signature, alloc_graph_input, alloc_graph_output
         ),
         key=lambda spec: spec.allocated_memory,
         reverse=True,
@@ -250,6 +252,7 @@ def greedy_by_size_for_offset_calculation_with_hierarchy(
 
 def find_peak_memory_usages_per_memory(
     graph_module: torch.fx.GraphModule,
+    graph_signature: ExportGraphSignature,
     alloc_graph_input: bool,
     alloc_graph_output: bool,
     mem_constraints: Optional[MemConstraints] = None,
@@ -265,7 +268,7 @@ def find_peak_memory_usages_per_memory(
 
     # go through all nodes in the graph, collect memory usage per spec.mem_id
     for spec in collect_specs_from_graph_module(
-        graph_module, alloc_graph_input, alloc_graph_output
+        graph_module, graph_signature, alloc_graph_input, alloc_graph_output
     ):
         if mem_constraints is not None and mem_constraints.skipped_spec(spec):
             continue
@@ -288,6 +291,7 @@ def find_peak_memory_usages_per_memory(
 
 def find_peak_memory_usage(
     graph_module: torch.fx.GraphModule,
+    graph_signature: ExportGraphSignature,
     alloc_graph_input: bool,
     alloc_graph_output: bool,
     mem_constraints: Optional[MemConstraints] = None,
@@ -303,7 +307,7 @@ def find_peak_memory_usage(
 
     # Iterate over all the node specs
     for spec in collect_specs_from_graph_module(
-        graph_module, alloc_graph_input, alloc_graph_output
+        graph_module, graph_signature, alloc_graph_input, alloc_graph_output
     ):
         if spec.lifetime[0] is None or (
            mem_constraints is not None and mem_constraints.skipped_spec(spec)
@@ -358,6 +362,7 @@ def print_memory_planning_info(
     # Get the peak memory usages per memory space
     peak_memory_usages_per_memory = find_peak_memory_usages_per_memory(
         executorch_prog.exported_program().graph_module,
+        executorch_prog.exported_program().graph_signature,
         alloc_graph_input,
         alloc_graph_output,
         mem_constraints,
@@ -393,6 +398,7 @@ def print_memory_planning_info(
     # Get the total peak memory usage across all memory spaces
     total_peak_memory_usage = find_peak_memory_usage(
         executorch_prog.exported_program().graph_module,
+        executorch_prog.exported_program().graph_signature,
         alloc_graph_input,
         alloc_graph_output,
         mem_constraints,
@@ -453,7 +459,17 @@ def _init_mem_algos(self) -> None:
             greedy_by_size_for_offset_calculation_with_hierarchy,
         ]
 
-    def __call__(self, graph_module: torch.fx.GraphModule) -> PassResult:
+    def __call__(
+        self,
+        graph_module: torch.fx.GraphModule,
+    ) -> PassResult:
+        return self.run(graph_module)
+
+    def run(
+        self,
+        graph_module: torch.fx.GraphModule,
+        graph_signature: Optional[ExportGraphSignature] = None,
+    ) -> PassResult:
         mem_constraints = MemConstraints(
             opt_level=self.opt_level,
             alloc_graph_input=self.alloc_graph_input,
@@ -475,6 +491,6 @@ def __call__(self, graph_module: torch.fx.GraphModule) -> PassResult:
             alloc_graph_output=self.alloc_graph_output,
             alignment=self.mem_alignment,
         )
-        mem_planning(graph_module)
+        mem_planning.run(graph_module, graph_signature)
 
         return PassResult(graph_module, True)
```
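The key structural change in this file is in the last two hunks: `__call__` becomes a thin wrapper so existing single-argument callers keep working, while `run()` carries the new optional `graph_signature` parameter and forwards it to the inner `mem_planning` pass. Reduced to a standalone, runnable sketch (the class name here is illustrative; only the `__call__`/`run` shape comes from the diff):

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class PassResult:
    graph_module: Any
    modified: bool


class MemoryPlanningLikePass:
    """Illustrative stand-in for the pass modified above (name is made up)."""

    def __call__(self, graph_module: Any) -> PassResult:
        # Backward-compatible entry point: callers that invoke the pass
        # directly still work, but cannot supply a graph signature.
        return self.run(graph_module)

    def run(
        self,
        graph_module: Any,
        graph_signature: Optional[Any] = None,
    ) -> PassResult:
        # New entry point: in the real pass, graph_signature is threaded down
        # to collect_specs_from_graph_module, presumably so spec collection
        # can identify graph inputs/outputs from the export signature.
        return PassResult(graph_module, True)


# Both call forms yield a PassResult; only run() can carry the signature.
gm = object()
assert MemoryPlanningLikePass()(gm).modified
assert MemoryPlanningLikePass().run(gm, graph_signature=None).modified
```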

backends/cadence/aot/tests/test_memory_passes.py (+11 −12)

```diff
@@ -46,14 +46,13 @@ def calculate_aligned_num_bytes(num: int, alignment: int = 16) -> int:
         inputs = (torch.ones(batch_size, input_dim),)
         model = PeakMemoryTestModel(input_dim, hidden_dim, output_dim)
 
-        graph_module = (
-            compiler.export_to_executorch_gen_etrecord(model, inputs)
-            .exported_program()
-            .graph_module
-        )
+        exported_program = compiler.export_to_executorch_gen_etrecord(
+            model, inputs
+        ).exported_program()
 
         peak_usage, _ = find_peak_memory_usage(
-            graph_module,
+            exported_program.graph_module,
+            exported_program.graph_signature,
             mem_constraints=None,
             alloc_graph_input=True,
             alloc_graph_output=True,
@@ -73,14 +72,13 @@ def calculate_aligned_num_bytes(num: int, alignment: int = 16) -> int:
             input_dim, hidden_dim, hidden_dim, hidden_dim, output_dim
         )
 
-        graph_module = (
-            compiler.export_to_executorch_gen_etrecord(model, inputs)
-            .exported_program()
-            .graph_module
-        )
+        exported_program = compiler.export_to_executorch_gen_etrecord(
+            model, inputs
+        ).exported_program()
 
         peak_usage, _ = find_peak_memory_usage(
-            graph_module,
+            exported_program.graph_module,
+            exported_program.graph_signature,
             mem_constraints=None,
             alloc_graph_input=True,
             alloc_graph_output=True,
@@ -111,6 +109,7 @@ def forward(self, x):
         graph_module.graph.eliminate_dead_code()
         peak_usage, _ = find_peak_memory_usage(
             graph_module,
+            executorch_prog.exported_program().graph_signature,
            alloc_graph_input=False,
            alloc_graph_output=False,
            mem_constraints=None,
```
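Caller-side takeaway from the test hunks above: tests now keep the `ExportedProgram` object around instead of immediately extracting `graph_module`, since both `graph_module` and `graph_signature` must be handed to `find_peak_memory_usage`. Excerpted below for reference; `compiler`, `model`, and `inputs` come from the surrounding test setup:

```python
# New caller shape: hold on to the ExportedProgram so both attributes
# are available to find_peak_memory_usage.
exported_program = compiler.export_to_executorch_gen_etrecord(
    model, inputs
).exported_program()

peak_usage, _ = find_peak_memory_usage(
    exported_program.graph_module,
    exported_program.graph_signature,
    mem_constraints=None,
    alloc_graph_input=True,
    alloc_graph_output=True,
)
```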
