19
19
20
20
import gt4py .eve as eve
21
21
import gt4py .next as gtx
22
- from gt4py .eve import Coerced , NodeTranslator
22
+ from gt4py .eve import Coerced , NodeTranslator , PreserveLocationVisitor
23
23
from gt4py .eve .traits import SymbolTableTrait
24
24
from gt4py .eve .utils import UIDGenerator
25
25
from gt4py .next import common
@@ -267,6 +267,7 @@ def split_closures(node: ir.FencilDefinition, offset_provider) -> FencilWithTemp
267
267
stencil = stencil ,
268
268
output = im .ref (tmp_sym .id ),
269
269
inputs = [closure_param_arg_mapping [param .id ] for param in lift_expr .args ], # type: ignore[attr-defined]
270
+ location = current_closure .location ,
270
271
)
271
272
)
272
273
@@ -294,6 +295,7 @@ def split_closures(node: ir.FencilDefinition, offset_provider) -> FencilWithTemp
294
295
output = current_closure .output ,
295
296
inputs = current_closure .inputs
296
297
+ [ir .SymRef (id = sym .id ) for sym in extracted_lifts .keys ()],
298
+ location = current_closure .location ,
297
299
)
298
300
)
299
301
else :
@@ -307,6 +309,7 @@ def split_closures(node: ir.FencilDefinition, offset_provider) -> FencilWithTemp
307
309
+ [ir .Sym (id = tmp .id ) for tmp in tmps ]
308
310
+ [ir .Sym (id = AUTO_DOMAIN .fun .id )], # type: ignore[attr-defined] # value is a global constant
309
311
closures = list (reversed (closures )),
312
+ location = node .location ,
310
313
),
311
314
params = node .params ,
312
315
tmps = [Temporary (id = tmp .id ) for tmp in tmps ],
@@ -333,6 +336,7 @@ def prune_unused_temporaries(node: FencilWithTemporaries) -> FencilWithTemporari
333
336
function_definitions = node .fencil .function_definitions ,
334
337
params = [p for p in node .fencil .params if p .id not in unused_tmps ],
335
338
closures = closures ,
339
+ location = node .fencil .location ,
336
340
),
337
341
params = node .params ,
338
342
tmps = [tmp for tmp in node .tmps if tmp .id not in unused_tmps ],
@@ -456,6 +460,7 @@ def update_domains(
456
460
stencil = closure .stencil ,
457
461
output = closure .output ,
458
462
inputs = closure .inputs ,
463
+ location = closure .location ,
459
464
)
460
465
else :
461
466
domain = closure .domain
@@ -521,6 +526,7 @@ def update_domains(
521
526
function_definitions = node .fencil .function_definitions ,
522
527
params = node .fencil .params [:- 1 ], # remove `_gtmp_auto_domain` param again
523
528
closures = list (reversed (closures )),
529
+ location = node .fencil .location ,
524
530
),
525
531
params = node .params ,
526
532
tmps = node .tmps ,
@@ -580,7 +586,7 @@ def convert_type(dtype):
580
586
# TODO(tehrengruber): Add support for dynamic shifts (e.g. the distance is a symbol). This can be
581
587
# tricky: For every lift statement that is dynamically shifted we can not compute bounds anymore
582
588
# and hence also not extract as a temporary.
583
- class CreateGlobalTmps (NodeTranslator ):
589
+ class CreateGlobalTmps (PreserveLocationVisitor , NodeTranslator ):
584
590
"""Main entry point for introducing global temporaries.
585
591
586
592
Transforms an existing iterator IR fencil into a fencil with global temporaries.
0 commit comments