Commit b1d76c9

Add option to enforce alignment constraint when planning memory
Differential Revision: D68762973
Pull Request resolved: #8003
1 parent 4facb18 commit b1d76c9

3 files changed: +82 -27 lines


backends/cadence/aot/compiler.py (+2)

@@ -264,6 +264,7 @@ def export_to_executorch_gen_etrecord(
     alloc_graph_output: bool = True,
     memory_config: Optional[MemoryConfig] = None,
     dump_graphs: bool = False,
+    mem_alignment: int = 1,
 ) -> ExecutorchProgramManager:
     cadence_passes = get_cadence_passes(opt_level)
     edge_prog_manager = export_to_edge(model, inputs, dump_graphs)
@@ -290,6 +291,7 @@ def export_to_executorch_gen_etrecord(
         mem_algo=mem_algo,
         alloc_graph_input=alloc_graph_input,
         alloc_graph_output=alloc_graph_output,
+        mem_alignment=mem_alignment,
     )

     # Get executorch program after Cadence specific passes
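Note: a minimal usage sketch of the new keyword follows. The model, inputs, and alignment value are illustrative assumptions; only the mem_alignment argument itself comes from this commit.

    import torch

    from executorch.backends.cadence.aot import compiler


    class TinyAdd(torch.nn.Module):
        def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            return torch.add(x, y)


    # Ask the memory planner to round every buffer's start offset up to a
    # 16-byte boundary; mem_alignment defaults to 1 (no extra constraint).
    prog_manager = compiler.export_to_executorch_gen_etrecord(
        TinyAdd(),
        (torch.randn(4, 17), torch.randn(4, 17)),
        opt_level=1,
        mem_alignment=16,
    )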

backends/cadence/aot/memory_planning.py (+23 -8)

@@ -9,6 +9,7 @@
 import collections
 import itertools
 import logging
+import math
 import typing
 from functools import partial
 from typing import Iterable, List, Optional, Tuple
@@ -39,6 +40,10 @@ def get_size(memory_config: MemoryConfig, exir_id: int) -> int:
     return memory_config.memory_sizes[exir_id - 1]


+def get_aligned_offset(pre_aligned_offset: int, alignment: int) -> int:
+    return int(math.ceil(pre_aligned_offset / alignment) * alignment)
+
+
 def collect_specs_from_graph_module(
     graph_module: torch.fx.GraphModule,
     alloc_graph_input: bool,
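The new helper simply rounds an offset up to the next multiple of the alignment. A few illustrative values (the numbers below are examples, not from the commit):

    from executorch.backends.cadence.aot.memory_planning import get_aligned_offset

    # ceil(100 / 37) * 37 == 3 * 37 == 111
    assert get_aligned_offset(100, 37) == 111
    assert get_aligned_offset(111, 37) == 111  # already a multiple of 37, unchanged
    assert get_aligned_offset(0, 37) == 0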
@@ -95,9 +100,9 @@ def overlap(spec: TensorSpec) -> Optional[TensorSpec]:
         return None

     def memory_available(spec: TensorSpec) -> bool:
-        return spec.mem_offset + spec.allocated_memory <= get_size(
-            memory_config, spec.mem_id
-        )
+        return get_aligned_offset(
+            spec.mem_offset + spec.allocated_memory, alignment
+        ) <= get_size(memory_config, spec.mem_id)

     # Iterate over all the specs in sorted order
     for spec in sorted(
@@ -116,7 +121,9 @@ def memory_available(spec: TensorSpec) -> bool:
             continue
         spec.mem_offset = 0
         while memory_available(spec) and (overlapped := overlap(spec)):
-            spec.mem_offset = overlapped.mem_offset + overlapped.allocated_memory
+            spec.mem_offset = get_aligned_offset(
+                overlapped.mem_offset + overlapped.allocated_memory, alignment
+            )
         if memory_available(spec):
             allocated_buffers[spec.mem_id].append(spec)
             bufsizes[spec.mem_id] = max(
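To see what the aligned placement loop changes in practice, here is a standalone sketch. It is simplified: the real planner walks TensorSpec lifetimes, overlap checks, and per-memory capacity, all omitted here.

    import math


    def get_aligned_offset(pre_aligned_offset: int, alignment: int) -> int:
        return int(math.ceil(pre_aligned_offset / alignment) * alignment)


    def place_back_to_back(sizes, alignment):
        # Mimic the modified loop: each next candidate offset is the previous
        # buffer's end, rounded up to the alignment.
        offsets, offset = [], 0
        for size in sizes:
            offsets.append(offset)
            offset = get_aligned_offset(offset + size, alignment)
        return offsets


    print(place_back_to_back([100, 50, 20], alignment=1))   # [0, 100, 150]
    print(place_back_to_back([100, 50, 20], alignment=37))  # [0, 111, 185], every start is a multiple of 37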
@@ -202,13 +209,16 @@ def greedy_by_size_for_offset_calculation_with_hierarchy(
             # calculation of gap incorrect. Moving it out will make the algorithm degenerate
             # to the naive one, reusing 0 tensor. The paper may have a typo here.
             prev_offset = max(
-                allocated_spec.mem_offset + allocated_spec.allocated_memory,
+                get_aligned_offset(
+                    allocated_spec.mem_offset + allocated_spec.allocated_memory,
+                    alignment,
+                ),
                 prev_offset,
             )
         if spec.mem_offset is None:
-            if prev_offset + spec.allocated_memory > get_size(
-                memory_config, spec.mem_id
-            ):
+            if get_aligned_offset(
+                prev_offset + spec.allocated_memory, alignment
+            ) > get_size(memory_config, spec.mem_id):
                 continue
             else:
                 spec.mem_offset = prev_offset
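The capacity check in the size-based greedy algorithm changes the same way: a candidate placement is rejected when its aligned end would exceed the memory size. An illustrative check mirroring that condition (numbers chosen for the example):

    import math


    def get_aligned_offset(pre_aligned_offset: int, alignment: int) -> int:
        return int(math.ceil(pre_aligned_offset / alignment) * alignment)


    def fits(prev_offset: int, allocated_memory: int, mem_size: int, alignment: int) -> bool:
        # Mirrors the modified condition: compare the aligned end against the arena size.
        return get_aligned_offset(prev_offset + allocated_memory, alignment) <= mem_size


    print(fits(prev_offset=74, allocated_memory=30, mem_size=110, alignment=1))   # True: 104 <= 110
    print(fits(prev_offset=74, allocated_memory=30, mem_size=110, alignment=37))  # False: 111 > 110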
@@ -423,6 +433,7 @@ def __init__(
                 ]
             ]
         ] = None,
+        mem_alignment: int = 1,
     ) -> None:
         self._init_mem_algos()

@@ -433,6 +444,9 @@ def __init__(
         self.alloc_graph_output = alloc_graph_output
         self.additional_constraint_gen_passes = additional_constraint_gen_passes

+        assert mem_alignment > 0, "mem_alignment must be positive"
+        self.mem_alignment = mem_alignment
+
     def _init_mem_algos(self) -> None:
         self.available_mem_algos = [
             position_based_greedy_with_hierarchy,
@@ -459,6 +473,7 @@ def __call__(self, graph_module: torch.fx.GraphModule) -> PassResult:
             allow_lifetime_and_storage_overlap=(self.opt_level >= 2),
             alloc_graph_input=self.alloc_graph_input,
             alloc_graph_output=self.alloc_graph_output,
+            alignment=self.mem_alignment,
         )
         mem_planning(graph_module)

backends/cadence/aot/tests/test_memory_passes.py (+57 -19)

@@ -14,11 +14,12 @@
 from executorch.backends.cadence.aot.pass_utils import count_node
 from executorch.exir import memory
 from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.memory_planning import collect_specs_from_nodes
 from executorch.exir.tests.models import MultiLayerPerceptron


 class TestMemPlanningPasses(unittest.TestCase):
-    def test_calculate_peak_memory_pass(self):
+    def test_calculate_peak_memory_pass(self) -> None:
         class PeakMemoryTestModel(torch.nn.Module):
             def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
                 super().__init__()
@@ -32,7 +33,7 @@ def forward(self, x: torch.Tensor):
                 x = self.linear2(x)
                 return x

-        def calculate_aligned_num_bytes(num: int, alignment: int = 16):
+        def calculate_aligned_num_bytes(num: int, alignment: int = 16) -> int:
            return math.ceil(num / alignment) * alignment

        # model 1
@@ -86,7 +87,7 @@ def calculate_aligned_num_bytes(num: int, alignment: int = 16):
         )  # Align data on a 16 byte boundary
         self.assertEqual(peak_usage, expected_peak_usage)

-    def test_zero_memory_pass(self):
+    def test_zero_memory_pass(self) -> None:
         class ZeroMem(torch.nn.Module):
             def forward(self, x):
                 return x[:, 2::3, ...]
@@ -188,7 +189,7 @@ def _verify_select_nop_memory_alloc(self, node: torch.fx.Node) -> None:
                 f"{spec=} {arg_spec=}",
             )

-    def verify_nop_memory_alloc(self, graph_module):
+    def verify_nop_memory_alloc(self, graph_module: torch.fx.GraphModule) -> None:
         for node in graph_module.graph.find_nodes(
             op="call_function", target=torch.ops.aten._cat_nop.out
         ):
@@ -204,7 +205,7 @@ def verify_nop_memory_alloc(self, graph_module):
         ):
             self._verify_select_nop_memory_alloc(node)

-    def test_optimize_cat_on_placeholders(self):
+    def test_optimize_cat_on_placeholders(self) -> None:
         class Cat(torch.nn.Module):
             def forward(self, x, y):
                 return torch.ops.aten.cat((x, y))
@@ -228,7 +229,7 @@ def forward(self, x, y):
         self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
         self.verify_nop_memory_alloc(graph_module)

-    def test_optimize_cat_outermost(self):
+    def test_optimize_cat_outermost(self) -> None:
         class OptimizeCatFeasible1(torch.nn.Module):
             def forward(self, x, y):
                 x1 = torch.add(x, 2.4, 3.1)
@@ -255,7 +256,7 @@ def forward(self, x, y):
         self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
         self.verify_nop_memory_alloc(graph_module)

-    def test_optimize_cat_non_outermost(self):
+    def test_optimize_cat_non_outermost(self) -> None:
         class OptimizeCatFeasible2(torch.nn.Module):
             def forward(self, x, y):
                 x1 = torch.add(x, 2.4, 3.1)
@@ -282,7 +283,7 @@ def forward(self, x, y):
         self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
         self.verify_nop_memory_alloc(graph_module)

-    def test_no_optimize_cat_non_outermost(self):
+    def test_no_optimize_cat_non_outermost(self) -> None:
         class OptimizeCatInfeasible1(torch.nn.Module):
             def forward(self, x, y):
                 x1 = torch.add(x, 2.4, 3.1)
@@ -308,7 +309,7 @@ def forward(self, x, y):
         self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1)
         self.verify_nop_memory_alloc(graph_module)

-    def test_no_optimize_cat_non_outermost1(self):
+    def test_no_optimize_cat_non_outermost1(self) -> None:
         class OptimizeCatInfeasible2(torch.nn.Module):
             def forward(self, x, y):
                 x1 = torch.add(x, 2.4, 3.1)
@@ -335,7 +336,7 @@ def forward(self, x, y):
         self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1)
         self.verify_nop_memory_alloc(graph_module)

-    def test_optimize_cat_with_slice(self):
+    def test_optimize_cat_with_slice(self) -> None:
         class OptimizeCatSliceFeasible(torch.nn.Module):
             def forward(self, x):
                 x1 = torch.add(x, 2.4, 3.1)
@@ -364,7 +365,7 @@ def forward(self, x):
         self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
         self.verify_nop_memory_alloc(graph_module)

-    def test_optimize_cat_with_slice_infeasible(self):
+    def test_optimize_cat_with_slice_infeasible(self) -> None:
         class OptimizeCatSliceInfeasible(torch.nn.Module):
             def forward(self, x, y):
                 x1 = torch.add(x, 2.4, 3.1)
@@ -390,7 +391,7 @@ def forward(self, x, y):
         self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1)
         self.verify_nop_memory_alloc(graph_module)

-    def test_optimize_slice_Tensor(self):
+    def test_optimize_slice_Tensor(self) -> None:
         class SliceTensor(torch.nn.Module):
             def forward(self, x, y, z):
                 x1 = torch.add(x, 2.4, 3.1)
@@ -452,7 +453,7 @@ def forward(self, x, y, z):
         )
         self.verify_nop_memory_alloc(graph_module)

-    def test_optimize_select_Tensor(self):
+    def test_optimize_select_Tensor(self) -> None:
         class SelectTensor(torch.nn.Module):
             def forward(self, x, y, z):
                 x1 = torch.add(x, 2.4, 3.1)
@@ -519,7 +520,7 @@ def forward(self, x, y, z):

     # TODO: Test fails due to memory planning
     @unittest.expectedFailure
-    def test_optimize_cat_with_param(self):
+    def test_optimize_cat_with_param(self) -> None:
         class CatWithPadding(torch.nn.Module):
             def __init__(self, padding_shape):
                 super().__init__()
@@ -547,7 +548,7 @@ def forward(self, x, y):
         self.assertEqual(count_node(graph_module, exir_ops.edge.aten.cat.default), 1)
         self.verify_nop_memory_alloc(graph_module)

-    def test_optimize_cat_then_slice_on_mutable_buffer(self):
+    def test_optimize_cat_then_slice_on_mutable_buffer(self) -> None:
         class CatWithPadding(torch.nn.Module):
             def __init__(self, padding_shape):
                 super().__init__()
@@ -572,7 +573,7 @@ def forward(self, x, y):
         self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
         self.verify_nop_memory_alloc(graph_module)

-    def test_optimize_cat_with_view(self):
+    def test_optimize_cat_with_view(self) -> None:
         class CatViewFeasible(torch.nn.Module):
             def forward(self, x, y):
                 x1 = torch.add(x, 2.4, 3.1)
@@ -599,7 +600,7 @@ def forward(self, x, y):
         self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0)
         self.verify_nop_memory_alloc(graph_module)

-    def test_no_optimize_cat_with_repeated_args(self):
+    def test_no_optimize_cat_with_repeated_args(self) -> None:
         class CatViewInfeasible(torch.nn.Module):
             def forward(self, x):
                 x1 = torch.add(x, 2.4, 3.1)
@@ -623,7 +624,7 @@ def forward(self, x):
         self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 0)
         self.verify_nop_memory_alloc(graph_module)

-    def test_no_optimize_cat_with_placeholder(self):
+    def test_no_optimize_cat_with_placeholder(self) -> None:
         class CatViewInfeasible(torch.nn.Module):
             def forward(self, x, y):
                 # Repeat will be decomposed into a cat. The cat cannot be optimized
@@ -741,7 +742,7 @@ def forward(self, x) -> torch.Tensor:
         self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0)
         self.verify_nop_memory_alloc(graph_module)

-    def test_view_for_unallocated_output(self):
+    def test_view_for_unallocated_output(self) -> None:
         class Model(torch.nn.Module):
             def __init__(self, padding_shape):
                 super().__init__()
@@ -764,3 +765,40 @@ def forward(self, x, y):
         )
         self.assertEqual(count_node(graph_module, memory.view), 1)
         self.verify_nop_memory_alloc(graph_module)
+
+    def test_start_alignment_constraints(self) -> None:
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x: torch.Tensor, y: torch.Tensor):
+                add_0 = torch.add(x, y)
+                add_1 = torch.add(x, add_0)
+                add_2 = torch.add(add_0, add_1)
+                add_3 = torch.add(add_1, add_2)
+                return add_3
+
+        model = Model()
+        inputs = (torch.randn(4, 17), torch.randn(4, 17))
+        for mem_algo in range(0, 2):
+            graph_module = (
+                compiler.export_to_executorch_gen_etrecord(
+                    model,
+                    inputs,
+                    opt_level=1,
+                    mem_algo=mem_algo,
+                    alloc_graph_input=False,
+                    alloc_graph_output=False,
+                    mem_alignment=37,
+                )
+                .exported_program()
+                .graph_module
+            )
+            # Assert that every planned allocation starts at an offset aligned
+            # to the requested 37-byte boundary.
+            for spec in collect_specs_from_nodes(
+                graph_module.graph.nodes,
+                ignore_graph_input=True,
+                ignore_graph_output=True,
+            ):
+                if spec and spec.mem_offset:
+                    self.assertEqual(spec.mem_offset % 37, 0)
