from executorch.backends.cadence.aot.pass_utils import count_node
from executorch.exir import memory
from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.memory_planning import collect_specs_from_nodes
from executorch.exir.tests.models import MultiLayerPerceptron


class TestMemPlanningPasses(unittest.TestCase):
-    def test_calculate_peak_memory_pass(self):
+    def test_calculate_peak_memory_pass(self) -> None:
        class PeakMemoryTestModel(torch.nn.Module):
            def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
                super().__init__()
@@ -32,7 +33,7 @@ def forward(self, x: torch.Tensor):
                x = self.linear2(x)
                return x

-        def calculate_aligned_num_bytes(num: int, alignment: int = 16):
+        def calculate_aligned_num_bytes(num: int, alignment: int = 16) -> int:
            return math.ceil(num / alignment) * alignment

        # model 1
@@ -86,7 +87,7 @@ def calculate_aligned_num_bytes(num: int, alignment: int = 16):
        )  # Align data on a 16 byte boundary
        self.assertEqual(peak_usage, expected_peak_usage)

-    def test_zero_memory_pass(self):
+    def test_zero_memory_pass(self) -> None:
        class ZeroMem(torch.nn.Module):
            def forward(self, x):
                return x[:, 2::3, ...]
@@ -188,7 +189,7 @@ def _verify_select_nop_memory_alloc(self, node: torch.fx.Node) -> None:
                f"{spec=} {arg_spec=}",
            )

-    def verify_nop_memory_alloc(self, graph_module):
+    def verify_nop_memory_alloc(self, graph_module: torch.fx.GraphModule) -> None:
        for node in graph_module.graph.find_nodes(
            op="call_function", target=torch.ops.aten._cat_nop.out
        ):
@@ -204,7 +205,7 @@ def verify_nop_memory_alloc(self, graph_module):
        ):
            self._verify_select_nop_memory_alloc(node)

-    def test_optimize_cat_on_placeholders(self):
+    def test_optimize_cat_on_placeholders(self) -> None:
        class Cat(torch.nn.Module):
            def forward(self, x, y):
                return torch.ops.aten.cat((x, y))
@@ -228,7 +229,7 @@ def forward(self, x, y):
        self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
        self.verify_nop_memory_alloc(graph_module)

-    def test_optimize_cat_outermost(self):
+    def test_optimize_cat_outermost(self) -> None:
        class OptimizeCatFeasible1(torch.nn.Module):
            def forward(self, x, y):
                x1 = torch.add(x, 2.4, 3.1)
@@ -255,7 +256,7 @@ def forward(self, x, y):
        self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
        self.verify_nop_memory_alloc(graph_module)

-    def test_optimize_cat_non_outermost(self):
+    def test_optimize_cat_non_outermost(self) -> None:
        class OptimizeCatFeasible2(torch.nn.Module):
            def forward(self, x, y):
                x1 = torch.add(x, 2.4, 3.1)
@@ -282,7 +283,7 @@ def forward(self, x, y):
        self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
        self.verify_nop_memory_alloc(graph_module)

-    def test_no_optimize_cat_non_outermost(self):
+    def test_no_optimize_cat_non_outermost(self) -> None:
        class OptimizeCatInfeasible1(torch.nn.Module):
            def forward(self, x, y):
                x1 = torch.add(x, 2.4, 3.1)
@@ -308,7 +309,7 @@ def forward(self, x, y):
        self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1)
        self.verify_nop_memory_alloc(graph_module)

-    def test_no_optimize_cat_non_outermost1(self):
+    def test_no_optimize_cat_non_outermost1(self) -> None:
        class OptimizeCatInfeasible2(torch.nn.Module):
            def forward(self, x, y):
                x1 = torch.add(x, 2.4, 3.1)
@@ -335,7 +336,7 @@ def forward(self, x, y):
        self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1)
        self.verify_nop_memory_alloc(graph_module)

-    def test_optimize_cat_with_slice(self):
+    def test_optimize_cat_with_slice(self) -> None:
        class OptimizeCatSliceFeasible(torch.nn.Module):
            def forward(self, x):
                x1 = torch.add(x, 2.4, 3.1)
@@ -364,7 +365,7 @@ def forward(self, x):
        self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
        self.verify_nop_memory_alloc(graph_module)

-    def test_optimize_cat_with_slice_infeasible(self):
+    def test_optimize_cat_with_slice_infeasible(self) -> None:
        class OptimizeCatSliceInfeasible(torch.nn.Module):
            def forward(self, x, y):
                x1 = torch.add(x, 2.4, 3.1)
@@ -390,7 +391,7 @@ def forward(self, x, y):
        self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1)
        self.verify_nop_memory_alloc(graph_module)

-    def test_optimize_slice_Tensor(self):
+    def test_optimize_slice_Tensor(self) -> None:
        class SliceTensor(torch.nn.Module):
            def forward(self, x, y, z):
                x1 = torch.add(x, 2.4, 3.1)
@@ -452,7 +453,7 @@ def forward(self, x, y, z):
        )
        self.verify_nop_memory_alloc(graph_module)

-    def test_optimize_select_Tensor(self):
+    def test_optimize_select_Tensor(self) -> None:
        class SelectTensor(torch.nn.Module):
            def forward(self, x, y, z):
                x1 = torch.add(x, 2.4, 3.1)
@@ -519,7 +520,7 @@ def forward(self, x, y, z):

    # TODO: Test fails due to memory planning
    @unittest.expectedFailure
-    def test_optimize_cat_with_param(self):
+    def test_optimize_cat_with_param(self) -> None:
        class CatWithPadding(torch.nn.Module):
            def __init__(self, padding_shape):
                super().__init__()
@@ -547,7 +548,7 @@ def forward(self, x, y):
        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.cat.default), 1)
        self.verify_nop_memory_alloc(graph_module)

-    def test_optimize_cat_then_slice_on_mutable_buffer(self):
+    def test_optimize_cat_then_slice_on_mutable_buffer(self) -> None:
        class CatWithPadding(torch.nn.Module):
            def __init__(self, padding_shape):
                super().__init__()
@@ -572,7 +573,7 @@ def forward(self, x, y):
        self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
        self.verify_nop_memory_alloc(graph_module)

-    def test_optimize_cat_with_view(self):
+    def test_optimize_cat_with_view(self) -> None:
        class CatViewFeasible(torch.nn.Module):
            def forward(self, x, y):
                x1 = torch.add(x, 2.4, 3.1)
@@ -599,7 +600,7 @@ def forward(self, x, y):
        self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0)
        self.verify_nop_memory_alloc(graph_module)

-    def test_no_optimize_cat_with_repeated_args(self):
+    def test_no_optimize_cat_with_repeated_args(self) -> None:
        class CatViewInfeasible(torch.nn.Module):
            def forward(self, x):
                x1 = torch.add(x, 2.4, 3.1)
@@ -623,7 +624,7 @@ def forward(self, x):
        self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 0)
        self.verify_nop_memory_alloc(graph_module)

-    def test_no_optimize_cat_with_placeholder(self):
+    def test_no_optimize_cat_with_placeholder(self) -> None:
        class CatViewInfeasible(torch.nn.Module):
            def forward(self, x, y):
                # Repeat will be decomposed into a cat. The cat cannot be optimized
@@ -741,7 +742,7 @@ def forward(self, x) -> torch.Tensor:
        self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0)
        self.verify_nop_memory_alloc(graph_module)

-    def test_view_for_unallocated_output(self):
+    def test_view_for_unallocated_output(self) -> None:
        class Model(torch.nn.Module):
            def __init__(self, padding_shape):
                super().__init__()
@@ -764,3 +765,42 @@ def forward(self, x, y):
        )
        self.assertEqual(count_node(graph_module, memory.view), 1)
        self.verify_nop_memory_alloc(graph_module)
+
+    def test_start_alignment_constraints(self) -> None:
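+        # Check that memory planning respects a custom (non-power-of-two) start-address alignment.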
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x: torch.Tensor, y: torch.Tensor):
+                add_0 = torch.add(x, y)
+                add_1 = torch.add(x, add_0)
+                add_2 = torch.add(add_0, add_1)
+                add_3 = torch.add(add_1, add_2)
+                return add_3
+
+        model = Model()
+        inputs = (torch.randn(4, 17), torch.randn(4, 17))
+        for mem_algo in range(0, 2):
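+            # Exercise each memory planning algorithm selector (mem_algo = 0 and 1).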
+            graph_module = (
+                compiler.export_to_executorch_gen_etrecord(
+                    model,
+                    inputs,
+                    opt_level=1,
+                    mem_algo=mem_algo,
+                    alloc_graph_input=False,
+                    alloc_graph_output=False,
+                    mem_alignment=37,
+                )
+                .exported_program()
+                .graph_module
+            )
+            # Assert that all memory allocations start at offsets aligned to the 37-byte constraint
+            for spec in collect_specs_from_nodes(
+                graph_module.graph.nodes,
+                ignore_graph_input=True,
+                ignore_graph_output=True,
+            ):
+                if spec and spec.mem_offset:
+                    self.assertEqual(spec.mem_offset % 37, 0)