19
19
import gt4py .eve as eve
20
20
from gt4py .next import Dimension , DimensionKind , type_inference as next_typing
21
21
from gt4py .next .common import NeighborTable
22
- from gt4py .next .iterator import ir as itir , type_inference as itir_typing
23
- from gt4py .next .iterator .ir import Expr , FunCall , Literal , SymRef
22
+ from gt4py .next .iterator import (
23
+ ir as itir ,
24
+ transforms as itir_transforms ,
25
+ type_inference as itir_typing ,
26
+ )
27
+ from gt4py .next .iterator .ir import Expr , FunCall , Literal , Sym , SymRef
24
28
from gt4py .next .type_system import type_specifications as ts , type_translation
25
29
26
30
from .itir_to_tasklet import (
36
40
from .utility import (
37
41
add_mapped_nested_sdfg ,
38
42
as_dace_type ,
43
+ as_scalar_type ,
39
44
connectivity_identifier ,
40
45
create_memlet_at ,
41
46
create_memlet_full ,
44
49
flatten_list ,
45
50
get_sorted_dims ,
46
51
map_nested_sdfg_symbols ,
52
+ new_array_symbols ,
47
53
unique_name ,
48
54
unique_var_name ,
49
55
)
@@ -154,12 +160,14 @@ def __init__(
154
160
self ,
155
161
param_types : list [ts .TypeSpec ],
156
162
offset_provider : dict [str , NeighborTable ],
163
+ tmps : list [itir_transforms .global_tmps .Temporary ],
157
164
column_axis : Optional [Dimension ] = None ,
158
165
):
159
166
self .param_types = param_types
160
167
self .column_axis = column_axis
161
168
self .offset_provider = offset_provider
162
169
self .storage_types = {}
170
+ self .tmps = tmps
163
171
164
172
def add_storage (
165
173
self ,
@@ -189,6 +197,70 @@ def add_storage(
189
197
raise NotImplementedError ()
190
198
self .storage_types [name ] = type_
191
199
200
+ def add_storage_for_temporaries (
201
+ self , node_params : list [Sym ], defs_state : dace .SDFGState , program_sdfg : dace .SDFG
202
+ ) -> dict [str , str ]:
203
+ symbol_map : dict [str , TaskletExpr ] = {}
204
+ # The shape of temporary arrays might be defined based on scalar values passed as program arguments.
205
+ # Here we collect these values in a symbol map.
206
+ tmp_ids = set (tmp .id for tmp in self .tmps )
207
+ for sym in node_params :
208
+ if sym .id not in tmp_ids and sym .kind != "Iterator" :
209
+ name_ = str (sym .id )
210
+ type_ = self .storage_types [name_ ]
211
+ assert isinstance (type_ , ts .ScalarType )
212
+ symbol_map [name_ ] = SymbolExpr (name_ , as_dace_type (type_ ))
213
+
214
+ tmp_symbols : dict [str , str ] = {}
215
+ for tmp in self .tmps :
216
+ tmp_name = str (tmp .id )
217
+
218
+ # We visit the domain of the temporary field, passing the set of available symbols.
219
+ assert isinstance (tmp .domain , itir .FunCall )
220
+ self .node_types .update (itir_typing .infer_all (tmp .domain ))
221
+ domain_ctx = Context (program_sdfg , defs_state , symbol_map )
222
+ tmp_domain = self ._visit_domain (tmp .domain , domain_ctx )
223
+
224
+ # We build the FieldType for this temporary array.
225
+ dims : list [Dimension ] = []
226
+ for dim , _ in tmp_domain :
227
+ dims .append (
228
+ Dimension (
229
+ value = dim ,
230
+ kind = (
231
+ DimensionKind .VERTICAL
232
+ if self .column_axis is not None and self .column_axis .value == dim
233
+ else DimensionKind .HORIZONTAL
234
+ ),
235
+ )
236
+ )
237
+ assert isinstance (tmp .dtype , str )
238
+ type_ = ts .FieldType (dims = dims , dtype = as_scalar_type (tmp .dtype ))
239
+ self .storage_types [tmp_name ] = type_
240
+
241
+ # N.B.: skip generation of symbolic strides and just let dace assign default strides, for now.
242
+ # Another option, in the future, is to use symbolic strides and apply auto-tuning or some heuristics
243
+ # to assign optimal stride values.
244
+ tmp_shape , _ = new_array_symbols (tmp_name , len (dims ))
245
+ tmp_offset = [
246
+ dace .symbol (unique_name (f"{ tmp_name } _offset{ i } " )) for i in range (len (dims ))
247
+ ]
248
+ _ , tmp_array = program_sdfg .add_array (
249
+ tmp_name , tmp_shape , as_dace_type (type_ .dtype ), offset = tmp_offset , transient = True
250
+ )
251
+
252
+ # Loop through all dimensions to visit the symbolic expressions for array shape and offset.
253
+ # These expressions are later mapped to interstate symbols.
254
+ for (_ , (begin , end )), offset_sym , shape_sym in zip (
255
+ tmp_domain ,
256
+ tmp_array .offset ,
257
+ tmp_array .shape ,
258
+ ):
259
+ tmp_symbols [str (offset_sym )] = f"0 - { begin .value } "
260
+ tmp_symbols [str (shape_sym )] = f"{ end .value } - { begin .value } "
261
+
262
+ return tmp_symbols
263
+
192
264
def get_output_nodes (
193
265
self , closure : itir .StencilClosure , sdfg : dace .SDFG , state : dace .SDFGState
194
266
) -> dict [str , dace .nodes .AccessNode ]:
@@ -204,7 +276,7 @@ def get_output_nodes(
204
276
def visit_FencilDefinition (self , node : itir .FencilDefinition ):
205
277
program_sdfg = dace .SDFG (name = node .id )
206
278
program_sdfg .debuginfo = dace_debuginfo (node )
207
- last_state = program_sdfg .add_state ("program_entry" , True )
279
+ entry_state = program_sdfg .add_state ("program_entry" , is_start_block = True )
208
280
self .node_types = itir_typing .infer_all (node )
209
281
210
282
# Filter neighbor tables from offset providers.
@@ -214,6 +286,20 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition):
214
286
for param , type_ in zip (node .params , self .param_types ):
215
287
self .add_storage (program_sdfg , str (param .id ), type_ , neighbor_tables )
216
288
289
+ if self .tmps :
290
+ tmp_symbols = self .add_storage_for_temporaries (node .params , entry_state , program_sdfg )
291
+ # on the first interstate edge define symbols for shape and offsets of temporary arrays
292
+ last_state = program_sdfg .add_state ("init_symbols_for_temporaries" )
293
+ program_sdfg .add_edge (
294
+ entry_state ,
295
+ last_state ,
296
+ dace .InterstateEdge (
297
+ assignments = tmp_symbols ,
298
+ ),
299
+ )
300
+ else :
301
+ last_state = entry_state
302
+
217
303
# Add connectivities as SDFG storages.
218
304
for offset , offset_provider in neighbor_tables .items ():
219
305
scalar_kind = type_translation .get_scalar_kind (offset_provider .table .dtype )
0 commit comments