11
11
# distribution for a copy of the license or check <https://www.gnu.org/licenses/>.
12
12
#
13
13
# SPDX-License-Identifier: GPL-3.0-or-later
14
- from typing import Any , Optional , cast
14
+ from typing import Any , Mapping , Optional , Sequence , cast
15
15
16
16
import dace
17
17
18
18
import gt4py .eve as eve
19
19
from gt4py .next import Dimension , DimensionKind , type_inference as next_typing
20
+ from gt4py .next .common import NeighborTable
20
21
from gt4py .next .iterator import ir as itir , type_inference as itir_typing
21
- from gt4py .next .iterator .embedded import NeighborTableOffsetProvider
22
22
from gt4py .next .iterator .ir import Expr , FunCall , Literal , SymRef
23
23
from gt4py .next .type_system import type_specifications as ts , type_translation
24
24
43
43
flatten_list ,
44
44
get_sorted_dims ,
45
45
map_nested_sdfg_symbols ,
46
- new_array_symbols ,
47
46
unique_name ,
48
47
unique_var_name ,
49
48
)
50
49
51
50
52
- def get_scan_args (stencil : Expr ) -> tuple [bool , Literal ]:
51
+ def _get_scan_args (stencil : Expr ) -> tuple [bool , Literal ]:
53
52
"""
54
53
Parse stencil expression to extract the scan arguments.
55
54
@@ -68,7 +67,7 @@ def get_scan_args(stencil: Expr) -> tuple[bool, Literal]:
68
67
return is_forward .value == "True" , init_carry
69
68
70
69
71
- def get_scan_dim (
70
+ def _get_scan_dim (
72
71
column_axis : Dimension ,
73
72
storage_types : dict [str , ts .TypeSpec ],
74
73
output : SymRef ,
@@ -93,6 +92,35 @@ def get_scan_dim(
93
92
)
94
93
95
94
95
+ def _make_array_shape_and_strides (
96
+ name : str ,
97
+ dims : Sequence [Dimension ],
98
+ neighbor_tables : Mapping [str , NeighborTable ],
99
+ sort_dims : bool ,
100
+ ) -> tuple [list [dace .symbol ], list [dace .symbol ]]:
101
+ """
102
+ Parse field dimensions and allocate symbols for array shape and strides.
103
+
104
+ For local dimensions, the size is known at compile-time and therefore
105
+ the corresponding array shape dimension is set to an integer literal value.
106
+
107
+ Returns
108
+ -------
109
+ tuple(shape, strides)
110
+ The output tuple fields are arrays of dace symbolic expressions.
111
+ """
112
+ dtype = dace .int64
113
+ sorted_dims = [dim for _ , dim in get_sorted_dims (dims )] if sort_dims else dims
114
+ shape = [
115
+ neighbor_tables [dim .value ].max_neighbors
116
+ if dim .kind == DimensionKind .LOCAL
117
+ else dace .symbol (unique_name (f"{ name } _shape{ i } " ), dtype )
118
+ for i , dim in enumerate (sorted_dims )
119
+ ]
120
+ strides = [dace .symbol (unique_name (f"{ name } _stride{ i } " ), dtype ) for i , _ in enumerate (shape )]
121
+ return shape , strides
122
+
123
+
96
124
class ItirToSDFG (eve .NodeVisitor ):
97
125
param_types : list [ts .TypeSpec ]
98
126
storage_types : dict [str , ts .TypeSpec ]
@@ -104,17 +132,27 @@ class ItirToSDFG(eve.NodeVisitor):
104
132
def __init__ (
105
133
self ,
106
134
param_types : list [ts .TypeSpec ],
107
- offset_provider : dict [str , NeighborTableOffsetProvider ],
135
+ offset_provider : dict [str , NeighborTable ],
108
136
column_axis : Optional [Dimension ] = None ,
109
137
):
110
138
self .param_types = param_types
111
139
self .column_axis = column_axis
112
140
self .offset_provider = offset_provider
113
141
self .storage_types = {}
114
142
115
- def add_storage (self , sdfg : dace .SDFG , name : str , type_ : ts .TypeSpec , has_offset : bool = True ):
143
+ def add_storage (
144
+ self ,
145
+ sdfg : dace .SDFG ,
146
+ name : str ,
147
+ type_ : ts .TypeSpec ,
148
+ neighbor_tables : Mapping [str , NeighborTable ],
149
+ has_offset : bool = True ,
150
+ sort_dimensions : bool = True ,
151
+ ):
116
152
if isinstance (type_ , ts .FieldType ):
117
- shape , strides = new_array_symbols (name , len (type_ .dims ))
153
+ shape , strides = _make_array_shape_and_strides (
154
+ name , type_ .dims , neighbor_tables , sort_dimensions
155
+ )
118
156
offset = (
119
157
[dace .symbol (unique_name (f"{ name } _offset{ i } _" )) for i in range (len (type_ .dims ))]
120
158
if has_offset
@@ -153,14 +191,23 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition):
153
191
154
192
# Add program parameters as SDFG storages.
155
193
for param , type_ in zip (node .params , self .param_types ):
156
- self .add_storage (program_sdfg , str (param .id ), type_ )
194
+ self .add_storage (program_sdfg , str (param .id ), type_ , neighbor_tables )
157
195
158
196
# Add connectivities as SDFG storages.
159
- for offset , table in neighbor_tables :
160
- scalar_kind = type_translation .get_scalar_kind (table .table .dtype )
161
- local_dim = Dimension ("ElementDim" , kind = DimensionKind .LOCAL )
162
- type_ = ts .FieldType ([table .origin_axis , local_dim ], ts .ScalarType (scalar_kind ))
163
- self .add_storage (program_sdfg , connectivity_identifier (offset ), type_ , has_offset = False )
197
+ for offset , offset_provider in neighbor_tables .items ():
198
+ scalar_kind = type_translation .get_scalar_kind (offset_provider .table .dtype )
199
+ local_dim = Dimension (offset , kind = DimensionKind .LOCAL )
200
+ type_ = ts .FieldType (
201
+ [offset_provider .origin_axis , local_dim ], ts .ScalarType (scalar_kind )
202
+ )
203
+ self .add_storage (
204
+ program_sdfg ,
205
+ connectivity_identifier (offset ),
206
+ type_ ,
207
+ neighbor_tables ,
208
+ has_offset = False ,
209
+ sort_dimensions = False ,
210
+ )
164
211
165
212
# Create a nested SDFG for all stencil closures.
166
213
for closure in node .closures :
@@ -222,7 +269,7 @@ def visit_StencilClosure(
222
269
223
270
input_names = [str (inp .id ) for inp in node .inputs ]
224
271
neighbor_tables = filter_neighbor_tables (self .offset_provider )
225
- connectivity_names = [connectivity_identifier (offset ) for offset , _ in neighbor_tables ]
272
+ connectivity_names = [connectivity_identifier (offset ) for offset in neighbor_tables . keys () ]
226
273
227
274
output_nodes = self .get_output_nodes (node , closure_sdfg , closure_state )
228
275
output_names = [k for k , _ in output_nodes .items ()]
@@ -400,11 +447,11 @@ def _visit_scan_stencil_closure(
400
447
output_name : str ,
401
448
) -> tuple [dace .SDFG , dict [str , str | dace .subsets .Subset ], int ]:
402
449
# extract scan arguments
403
- is_forward , init_carry_value = get_scan_args (node .stencil )
450
+ is_forward , init_carry_value = _get_scan_args (node .stencil )
404
451
# select the scan dimension based on program argument for column axis
405
452
assert self .column_axis
406
453
assert isinstance (node .output , SymRef )
407
- scan_dim , scan_dim_index , scan_dtype = get_scan_dim (
454
+ scan_dim , scan_dim_index , scan_dtype = _get_scan_dim (
408
455
self .column_axis ,
409
456
self .storage_types ,
410
457
node .output ,
@@ -570,7 +617,7 @@ def _visit_parallel_stencil_closure(
570
617
) -> tuple [dace .SDFG , dict [str , str | dace .subsets .Subset ], list [str ]]:
571
618
neighbor_tables = filter_neighbor_tables (self .offset_provider )
572
619
input_names = [str (inp .id ) for inp in node .inputs ]
573
- conn_names = [connectivity_identifier (offset ) for offset , _ in neighbor_tables ]
620
+ connectivity_names = [connectivity_identifier (offset ) for offset in neighbor_tables . keys () ]
574
621
575
622
# find the scan dimension, same as output dimension, and exclude it from the map domain
576
623
map_ranges = {}
@@ -583,7 +630,7 @@ def _visit_parallel_stencil_closure(
583
630
index_domain = {dim : f"i_{ dim } " for dim , _ in closure_domain }
584
631
585
632
input_arrays = [(name , self .storage_types [name ]) for name in input_names ]
586
- connectivity_arrays = [(array_table [name ], name ) for name in conn_names ]
633
+ connectivity_arrays = [(array_table [name ], name ) for name in connectivity_names ]
587
634
588
635
context , results = closure_to_tasklet_sdfg (
589
636
node ,
0 commit comments