Skip to content

Commit d5cfa7d

Browse files
authored
feat[next][dace]: Add more debug info to DaCe (#1384)
* Add more debug info to DaCe (pass SourceLocation from past/foast to itir, and from itir to the SDFG): Preserve Location through Visitors
1 parent 8bd5a41 commit d5cfa7d

29 files changed

+288
-120
lines changed

src/gt4py/eve/__init__.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,12 @@
5858
field,
5959
frozenmodel,
6060
)
61-
from .traits import SymbolTableTrait, ValidatedSymbolTableTrait, VisitorWithSymbolTableTrait
61+
from .traits import (
62+
PreserveLocationVisitor,
63+
SymbolTableTrait,
64+
ValidatedSymbolTableTrait,
65+
VisitorWithSymbolTableTrait,
66+
)
6267
from .trees import (
6368
bfs_walk_items,
6469
bfs_walk_values,
@@ -113,6 +118,7 @@
113118
"SymbolTableTrait",
114119
"ValidatedSymbolTableTrait",
115120
"VisitorWithSymbolTableTrait",
121+
"PreserveLocationVisitor",
116122
# trees
117123
"bfs_walk_items",
118124
"bfs_walk_values",

src/gt4py/eve/traits.py

+8
Original file line numberDiff line numberDiff line change
@@ -172,3 +172,11 @@ def visit(self, node: concepts.RootNode, **kwargs: Any) -> Any:
172172
kwargs["symtable"] = kwargs["symtable"].parents
173173

174174
return result
175+
176+
177+
class PreserveLocationVisitor(visitors.NodeVisitor):
178+
def visit(self, node: concepts.RootNode, **kwargs: Any) -> Any:
179+
result = super().visit(node, **kwargs)
180+
if hasattr(node, "location") and hasattr(result, "location"):
181+
result.location = node.location
182+
return result

src/gt4py/next/ffront/foast_to_itir.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import dataclasses
1616
from typing import Any, Callable, Optional
1717

18-
from gt4py.eve import NodeTranslator
18+
from gt4py.eve import NodeTranslator, PreserveLocationVisitor
1919
from gt4py.eve.utils import UIDGenerator
2020
from gt4py.next.ffront import (
2121
dialect_ast_enums,
@@ -39,7 +39,7 @@ def promote_to_list(
3939

4040

4141
@dataclasses.dataclass
42-
class FieldOperatorLowering(NodeTranslator):
42+
class FieldOperatorLowering(PreserveLocationVisitor, NodeTranslator):
4343
"""
4444
Lower FieldOperator AST (FOAST) to Iterator IR (ITIR).
4545
@@ -61,7 +61,7 @@ class FieldOperatorLowering(NodeTranslator):
6161
<class 'gt4py.next.iterator.ir.FunctionDefinition'>
6262
>>> lowered.id
6363
SymbolName('fieldop')
64-
>>> lowered.params
64+
>>> lowered.params # doctest: +ELLIPSIS
6565
[Sym(id=SymbolName('inp'), kind='Iterator', dtype=('float64', False))]
6666
"""
6767

@@ -142,7 +142,7 @@ def visit_IfStmt(
142142
self, node: foast.IfStmt, *, inner_expr: Optional[itir.Expr], **kwargs
143143
) -> itir.Expr:
144144
# the lowered if call doesn't need to be lifted as the condition can only originate
145-
# from a scalar value (and not a field)
145+
# from a scalar value (and not a field)
146146
assert (
147147
isinstance(node.condition.type, ts.ScalarType)
148148
and node.condition.type.kind == ts.ScalarKind.BOOL

src/gt4py/next/ffront/past_to_itir.py

+16-4
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,9 @@ def _flatten_tuple_expr(
4040
raise ValueError("Only 'past.Name', 'past.Subscript' or 'past.TupleExpr' thereof are allowed.")
4141

4242

43-
class ProgramLowering(traits.VisitorWithSymbolTableTrait, NodeTranslator):
43+
class ProgramLowering(
44+
traits.PreserveLocationVisitor, traits.VisitorWithSymbolTableTrait, NodeTranslator
45+
):
4446
"""
4547
Lower Program AST (PAST) to Iterator IR (ITIR).
4648
@@ -151,6 +153,7 @@ def _visit_stencil_call(self, node: past.Call, **kwargs) -> itir.StencilClosure:
151153
stencil=itir.SymRef(id=node.func.id),
152154
inputs=[*lowered_args, *lowered_kwargs.values()],
153155
output=output,
156+
location=node.location,
154157
)
155158

156159
def _visit_slice_bound(
@@ -175,17 +178,22 @@ def _visit_slice_bound(
175178
lowered_bound = self.visit(slice_bound, **kwargs)
176179
else:
177180
raise AssertionError("Expected 'None' or 'past.Constant'.")
181+
if slice_bound:
182+
lowered_bound.location = slice_bound.location
178183
return lowered_bound
179184

180185
def _construct_itir_out_arg(self, node: past.Expr) -> itir.Expr:
181186
if isinstance(node, past.Name):
182-
return itir.SymRef(id=node.id)
187+
return itir.SymRef(id=node.id, location=node.location)
183188
elif isinstance(node, past.Subscript):
184-
return self._construct_itir_out_arg(node.value)
189+
itir_node = self._construct_itir_out_arg(node.value)
190+
itir_node.location = node.location
191+
return itir_node
185192
elif isinstance(node, past.TupleExpr):
186193
return itir.FunCall(
187194
fun=itir.SymRef(id="make_tuple"),
188195
args=[self._construct_itir_out_arg(el) for el in node.elts],
196+
location=node.location,
189197
)
190198
else:
191199
raise ValueError(
@@ -247,7 +255,11 @@ def _construct_itir_domain_arg(
247255
else:
248256
raise AssertionError()
249257

250-
return itir.FunCall(fun=itir.SymRef(id=domain_builtin), args=domain_args)
258+
return itir.FunCall(
259+
fun=itir.SymRef(id=domain_builtin),
260+
args=domain_args,
261+
location=(node_domain or out_field).location,
262+
)
251263

252264
def _construct_itir_initialized_domain_arg(
253265
self,

src/gt4py/next/iterator/ir.py

+3
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,15 @@
1717

1818
import gt4py.eve as eve
1919
from gt4py.eve import Coerced, SymbolName, SymbolRef, datamodels
20+
from gt4py.eve.concepts import SourceLocation
2021
from gt4py.eve.traits import SymbolTableTrait, ValidatedSymbolTableTrait
2122
from gt4py.eve.utils import noninstantiable
2223

2324

2425
@noninstantiable
2526
class Node(eve.Node):
27+
location: Optional[SourceLocation] = eve.field(default=None, repr=False, compare=False)
28+
2629
def __str__(self) -> str:
2730
from gt4py.next.iterator.pretty_printer import pformat
2831

src/gt4py/next/iterator/transforms/collapse_list_get.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from gt4py.next.iterator import ir
1717

1818

19-
class CollapseListGet(eve.NodeTranslator):
19+
class CollapseListGet(eve.PreserveLocationVisitor, eve.NodeTranslator):
2020
"""Simplifies expressions containing `list_get`.
2121
2222
Examples

src/gt4py/next/iterator/transforms/collapse_tuple.py

+1-8
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def _get_tuple_size(elem: ir.Node, node_types: Optional[dict] = None) -> int | t
4848

4949

5050
@dataclass(frozen=True)
51-
class CollapseTuple(eve.NodeTranslator):
51+
class CollapseTuple(eve.PreserveLocationVisitor, eve.NodeTranslator):
5252
"""
5353
Simplifies `make_tuple`, `tuple_get` calls.
5454
@@ -88,13 +88,6 @@ def apply(
8888
node_types,
8989
).visit(node)
9090

91-
return cls(
92-
ignore_tuple_size,
93-
collapse_make_tuple_tuple_get,
94-
collapse_tuple_get_make_tuple,
95-
use_global_type_inference,
96-
).visit(node)
97-
9891
def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node:
9992
if (
10093
self.collapse_make_tuple_tuple_get

src/gt4py/next/iterator/transforms/constant_folding.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,12 @@
1212
#
1313
# SPDX-License-Identifier: GPL-3.0-or-later
1414

15-
from gt4py.eve import NodeTranslator
15+
from gt4py.eve import NodeTranslator, PreserveLocationVisitor
1616
from gt4py.next.iterator import embedded, ir
1717
from gt4py.next.iterator.ir_utils import ir_makers as im
1818

1919

20-
class ConstantFolding(NodeTranslator):
20+
class ConstantFolding(PreserveLocationVisitor, NodeTranslator):
2121
@classmethod
2222
def apply(cls, node: ir.Node) -> ir.Node:
2323
return cls().visit(node)

src/gt4py/next/iterator/transforms/cse.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,20 @@
1717
import operator
1818
import typing
1919

20-
from gt4py.eve import NodeTranslator, NodeVisitor, SymbolTableTrait, VisitorWithSymbolTableTrait
20+
from gt4py.eve import (
21+
NodeTranslator,
22+
NodeVisitor,
23+
PreserveLocationVisitor,
24+
SymbolTableTrait,
25+
VisitorWithSymbolTableTrait,
26+
)
2127
from gt4py.eve.utils import UIDGenerator
2228
from gt4py.next.iterator import ir
2329
from gt4py.next.iterator.transforms.inline_lambdas import inline_lambda
2430

2531

2632
@dataclasses.dataclass
27-
class _NodeReplacer(NodeTranslator):
33+
class _NodeReplacer(PreserveLocationVisitor, NodeTranslator):
2834
PRESERVED_ANNEX_ATTRS = ("type",)
2935

3036
expr_map: dict[int, ir.SymRef]
@@ -72,7 +78,7 @@ def _is_collectable_expr(node: ir.Node) -> bool:
7278

7379

7480
@dataclasses.dataclass
75-
class CollectSubexpressions(VisitorWithSymbolTableTrait, NodeVisitor):
81+
class CollectSubexpressions(PreserveLocationVisitor, VisitorWithSymbolTableTrait, NodeVisitor):
7682
@dataclasses.dataclass
7783
class SubexpressionData:
7884
#: A list of node ids with equal hash and a set of collected child subexpression ids
@@ -341,7 +347,7 @@ def extract_subexpression(
341347

342348

343349
@dataclasses.dataclass(frozen=True)
344-
class CommonSubexpressionElimination(NodeTranslator):
350+
class CommonSubexpressionElimination(PreserveLocationVisitor, NodeTranslator):
345351
"""
346352
Perform common subexpression elimination.
347353

src/gt4py/next/iterator/transforms/eta_reduction.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,11 @@
1212
#
1313
# SPDX-License-Identifier: GPL-3.0-or-later
1414

15-
from gt4py.eve import NodeTranslator
15+
from gt4py.eve import NodeTranslator, PreserveLocationVisitor
1616
from gt4py.next.iterator import ir
1717

1818

19-
class EtaReduction(NodeTranslator):
19+
class EtaReduction(PreserveLocationVisitor, NodeTranslator):
2020
"""Eta reduction: simplifies `λ(args...) → f(args...)` to `f`."""
2121

2222
def visit_Lambda(self, node: ir.Lambda) -> ir.Node:

src/gt4py/next/iterator/transforms/fuse_maps.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def _is_reduce(node: ir.Node) -> TypeGuard[ir.FunCall]:
3838

3939

4040
@dataclasses.dataclass(frozen=True)
41-
class FuseMaps(traits.VisitorWithSymbolTableTrait, NodeTranslator):
41+
class FuseMaps(traits.PreserveLocationVisitor, traits.VisitorWithSymbolTableTrait, NodeTranslator):
4242
"""
4343
Fuses nested `map_`s.
4444
@@ -66,6 +66,7 @@ def _as_lambda(self, fun: ir.SymRef | ir.Lambda, param_count: int) -> ir.Lambda:
6666
return ir.Lambda(
6767
params=params,
6868
expr=ir.FunCall(fun=fun, args=[ir.SymRef(id=p.id) for p in params]),
69+
location=fun.location,
6970
)
7071

7172
def visit_FunCall(self, node: ir.FunCall, **kwargs):
@@ -99,6 +100,7 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs):
99100
ir.FunCall(
100101
fun=inner_op,
101102
args=[ir.SymRef(id=param.id) for param in inner_op.params],
103+
location=node.location,
102104
)
103105
)
104106
)

src/gt4py/next/iterator/transforms/global_tmps.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
import gt4py.eve as eve
2121
import gt4py.next as gtx
22-
from gt4py.eve import Coerced, NodeTranslator
22+
from gt4py.eve import Coerced, NodeTranslator, PreserveLocationVisitor
2323
from gt4py.eve.traits import SymbolTableTrait
2424
from gt4py.eve.utils import UIDGenerator
2525
from gt4py.next import common
@@ -267,6 +267,7 @@ def split_closures(node: ir.FencilDefinition, offset_provider) -> FencilWithTemp
267267
stencil=stencil,
268268
output=im.ref(tmp_sym.id),
269269
inputs=[closure_param_arg_mapping[param.id] for param in lift_expr.args], # type: ignore[attr-defined]
270+
location=current_closure.location,
270271
)
271272
)
272273

@@ -294,6 +295,7 @@ def split_closures(node: ir.FencilDefinition, offset_provider) -> FencilWithTemp
294295
output=current_closure.output,
295296
inputs=current_closure.inputs
296297
+ [ir.SymRef(id=sym.id) for sym in extracted_lifts.keys()],
298+
location=current_closure.location,
297299
)
298300
)
299301
else:
@@ -307,6 +309,7 @@ def split_closures(node: ir.FencilDefinition, offset_provider) -> FencilWithTemp
307309
+ [ir.Sym(id=tmp.id) for tmp in tmps]
308310
+ [ir.Sym(id=AUTO_DOMAIN.fun.id)], # type: ignore[attr-defined] # value is a global constant
309311
closures=list(reversed(closures)),
312+
location=node.location,
310313
),
311314
params=node.params,
312315
tmps=[Temporary(id=tmp.id) for tmp in tmps],
@@ -333,6 +336,7 @@ def prune_unused_temporaries(node: FencilWithTemporaries) -> FencilWithTemporari
333336
function_definitions=node.fencil.function_definitions,
334337
params=[p for p in node.fencil.params if p.id not in unused_tmps],
335338
closures=closures,
339+
location=node.fencil.location,
336340
),
337341
params=node.params,
338342
tmps=[tmp for tmp in node.tmps if tmp.id not in unused_tmps],
@@ -456,6 +460,7 @@ def update_domains(
456460
stencil=closure.stencil,
457461
output=closure.output,
458462
inputs=closure.inputs,
463+
location=closure.location,
459464
)
460465
else:
461466
domain = closure.domain
@@ -521,6 +526,7 @@ def update_domains(
521526
function_definitions=node.fencil.function_definitions,
522527
params=node.fencil.params[:-1], # remove `_gtmp_auto_domain` param again
523528
closures=list(reversed(closures)),
529+
location=node.fencil.location,
524530
),
525531
params=node.params,
526532
tmps=node.tmps,
@@ -580,7 +586,7 @@ def convert_type(dtype):
580586
# TODO(tehrengruber): Add support for dynamic shifts (e.g. the distance is a symbol). This can be
581587
# tricky: For every lift statement that is dynamically shifted we can not compute bounds anymore
582588
# and hence also not extract as a temporary.
583-
class CreateGlobalTmps(NodeTranslator):
589+
class CreateGlobalTmps(PreserveLocationVisitor, NodeTranslator):
584590
"""Main entry point for introducing global temporaries.
585591
586592
Transforms an existing iterator IR fencil into a fencil with global temporaries.

src/gt4py/next/iterator/transforms/inline_fundefs.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,11 @@
1414

1515
from typing import Any, Dict, Set
1616

17-
from gt4py.eve import NOTHING, NodeTranslator
17+
from gt4py.eve import NOTHING, NodeTranslator, PreserveLocationVisitor
1818
from gt4py.next.iterator import ir
1919

2020

21-
class InlineFundefs(NodeTranslator):
21+
class InlineFundefs(PreserveLocationVisitor, NodeTranslator):
2222
def visit_SymRef(self, node: ir.SymRef, *, symtable: Dict[str, Any]):
2323
if node.id in symtable and isinstance((symbol := symtable[node.id]), ir.FunctionDefinition):
2424
return ir.Lambda(
@@ -31,7 +31,7 @@ def visit_FencilDefinition(self, node: ir.FencilDefinition):
3131
return self.generic_visit(node, symtable=node.annex.symtable)
3232

3333

34-
class PruneUnreferencedFundefs(NodeTranslator):
34+
class PruneUnreferencedFundefs(PreserveLocationVisitor, NodeTranslator):
3535
def visit_FunctionDefinition(
3636
self, node: ir.FunctionDefinition, *, referenced: Set[str], second_pass: bool
3737
):

src/gt4py/next/iterator/transforms/inline_into_scan.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,9 @@ def _lambda_and_lift_inliner(node: ir.FunCall) -> ir.FunCall:
5353
return inlined
5454

5555

56-
class InlineIntoScan(traits.VisitorWithSymbolTableTrait, NodeTranslator):
56+
class InlineIntoScan(
57+
traits.PreserveLocationVisitor, traits.VisitorWithSymbolTableTrait, NodeTranslator
58+
):
5759
"""
5860
Inline non-SymRef arguments into the scan.
5961
@@ -100,6 +102,5 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs):
100102
new_scan = ir.FunCall(
101103
fun=ir.SymRef(id="scan"), args=[new_scanpass, *original_scan_call.args[1:]]
102104
)
103-
result = ir.FunCall(fun=new_scan, args=[ir.SymRef(id=ref) for ref in refs_in_args])
104-
return result
105+
return ir.FunCall(fun=new_scan, args=[ir.SymRef(id=ref) for ref in refs_in_args])
105106
return self.generic_visit(node, **kwargs)

0 commit comments

Comments
 (0)