Skip to content

Commit 1fd4e09

Browse files
committed
Merge branch 'main' into fix-settings-after-ruff
2 parents 8d9f253 + 0151b05 commit 1fd4e09

File tree

10 files changed

+427
-16
lines changed

10 files changed

+427
-16
lines changed

src/gt4py/cartesian/frontend/gtscript_frontend.py

+34-2
Original file line numberDiff line numberDiff line change
@@ -690,6 +690,25 @@ class ParsingContext(enum.Enum):
690690
COMPUTATION = 2
691691

692692

693+
_DATADIMS_INDEXER = "A"
694+
695+
696+
def _is_datadims_indexing_name(name: str):
697+
return name.endswith(f".{_DATADIMS_INDEXER}")
698+
699+
700+
def _trim_indexing_symbol(name: str):
701+
return name[: -1 * (len(_DATADIMS_INDEXER) + 1)]
702+
703+
704+
def _is_datadims_indexing_node(node):
705+
return (
706+
isinstance(node.value, ast.Attribute)
707+
and node.value.attr == _DATADIMS_INDEXER
708+
and isinstance(node.value.value, ast.Name)
709+
)
710+
711+
693712
class IRMaker(ast.NodeVisitor):
694713
def __init__(
695714
self,
@@ -1037,6 +1056,10 @@ def visit_Name(self, node: ast.Name) -> nodes.Ref:
10371056
result = nodes.VarRef(name=symbol, loc=nodes.Location.from_ast_node(node))
10381057
elif self._is_local_symbol(symbol):
10391058
raise AssertionError("Logic error")
1059+
elif _is_datadims_indexing_name(symbol):
1060+
result = nodes.FieldRef.datadims_index(
1061+
name=_trim_indexing_symbol(symbol), loc=nodes.Location.from_ast_node(node)
1062+
)
10401063
else:
10411064
raise AssertionError(f"Missing '{symbol}' symbol definition")
10421065

@@ -1145,12 +1168,16 @@ def visit_Subscript(self, node: ast.Subscript):
11451168
field_axes = self.fields[result.name].axes
11461169
if index is not None:
11471170
if len(field_axes) != len(index):
1171+
ro_field_message = ""
1172+
if len(field_axes) == 0:
1173+
ro_field_message = f"Did you mean .A{index}?"
11481174
raise GTScriptSyntaxError(
11491175
f"Incorrect offset specification detected. Found {index}, "
1150-
f"but the field has dimensions ({', '.join(field_axes)})"
1176+
f"but the field has dimensions ({', '.join(field_axes)}). "
1177+
f"{ro_field_message}"
11511178
)
11521179
result.offset = {axis: value for axis, value in zip(field_axes, index)}
1153-
elif isinstance(node.value, ast.Subscript):
1180+
elif isinstance(node.value, ast.Subscript) or _is_datadims_indexing_node(node):
11541181
result.data_index = [
11551182
(
11561183
nodes.ScalarLiteral(value=value, data_type=nodes.DataType.INT32)
@@ -1601,6 +1628,11 @@ def visit_Assign(self, node: ast.Assign):
16011628
elif isinstance(t, ast.Subscript):
16021629
if isinstance(t.value, ast.Name):
16031630
name_node = t.value
1631+
elif _is_datadims_indexing_node(t):
1632+
raise GTScriptSyntaxError(
1633+
message="writing to an GlobalTable ('A' global indexation) is forbidden",
1634+
loc=nodes.Location.from_ast_node(node),
1635+
)
16041636
elif isinstance(t.value, ast.Subscript) and isinstance(t.value.value, ast.Name):
16051637
name_node = t.value.value
16061638
else:

src/gt4py/cartesian/frontend/nodes.py

+4
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,10 @@ def at_center(
354354
name=name, offset={axis: 0 for axis in axes}, data_index=data_index or [], loc=loc
355355
)
356356

357+
@classmethod
358+
def datadims_index(cls, name: str, loc=None):
359+
return cls(name=name, offset={}, data_index=[], loc=loc)
360+
357361

358362
@attribclass
359363
class Cast(Expr):

src/gt4py/cartesian/gtscript.py

+13
Original file line numberDiff line numberDiff line change
@@ -694,10 +694,23 @@ def __getitem__(self, field_spec):
694694
return _FieldDescriptor(dtype, axes, data_dims)
695695

696696

697+
class _GlobalTableDescriptorMaker(_FieldDescriptorMaker):
698+
def __getitem__(self, field_spec):
699+
if not isinstance(field_spec, collections.abc.Collection) and not len(field_spec) == 2:
700+
raise ValueError("GlobalTable is defined by a tuple (type, [axes_size..])")
701+
702+
dtype, data_dims = field_spec
703+
704+
return _FieldDescriptor(dtype, [], data_dims)
705+
706+
697707
# GTScript builtins: variable annotations
698708
Field = _FieldDescriptorMaker()
699709
"""Field descriptor."""
700710

711+
GlobalTable = _GlobalTableDescriptorMaker()
712+
"""Data array with no spatial dimension descriptor."""
713+
701714

702715
class _SequenceDescriptor:
703716
def __init__(self, dtype, length):

src/gt4py/next/otf/languages.py

+5
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,11 @@ class Python(LanguageTag):
5757
...
5858

5959

60+
class SDFG(LanguageTag):
61+
settings_class = LanguageSettings
62+
...
63+
64+
6065
class NanobindSrcL(LanguageTag): ...
6166

6267

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
# GT4Py - GridTools Framework
2+
#
3+
# Copyright (c) 2014-2023, ETH Zurich
4+
# All rights reserved.
5+
#
6+
# This file is part of the GT4Py project and the GridTools framework.
7+
# GT4Py is free software: you can redistribute it and/or modify it under
8+
# the terms of the GNU General Public License as published by the
9+
# Free Software Foundation, either version 3 of the License, or any later
10+
# version. See the LICENSE.txt file at the top-level directory of this
11+
# distribution for a copy of the license or check <https://www.gnu.org/licenses/>.
12+
#
13+
# SPDX-License-Identifier: GPL-3.0-or-later
14+
15+
import functools
16+
17+
import factory
18+
19+
import gt4py._core.definitions as core_defs
20+
from gt4py.next import config
21+
from gt4py.next.otf import recipes, stages
22+
from gt4py.next.program_processors.runners.dace_iterator.workflow import (
23+
DaCeCompilationStepFactory,
24+
DaCeTranslationStepFactory,
25+
convert_args,
26+
)
27+
from gt4py.next.program_processors.runners.gtfn import GTFNBackendFactory
28+
29+
30+
def _no_bindings(inp: stages.ProgramSource) -> stages.CompilableSource:
31+
return stages.CompilableSource(program_source=inp, binding_source=None)
32+
33+
34+
class DaCeWorkflowFactory(factory.Factory):
35+
class Meta:
36+
model = recipes.OTFCompileWorkflow
37+
38+
class Params:
39+
device_type: core_defs.DeviceType = core_defs.DeviceType.CPU
40+
cmake_build_type: config.CMakeBuildType = factory.LazyFunction(
41+
lambda: config.CMAKE_BUILD_TYPE
42+
)
43+
use_field_canonical_representation: bool = False
44+
45+
translation = factory.SubFactory(
46+
DaCeTranslationStepFactory,
47+
device_type=factory.SelfAttribute("..device_type"),
48+
use_field_canonical_representation=factory.SelfAttribute(
49+
"..use_field_canonical_representation"
50+
),
51+
)
52+
bindings = _no_bindings
53+
compilation = factory.SubFactory(
54+
DaCeCompilationStepFactory,
55+
cache_lifetime=factory.LazyFunction(lambda: config.BUILD_CACHE_LIFETIME),
56+
cmake_build_type=factory.SelfAttribute("..cmake_build_type"),
57+
)
58+
decoration = factory.LazyAttribute(
59+
lambda o: functools.partial(
60+
convert_args,
61+
device=o.device_type,
62+
use_field_canonical_representation=o.use_field_canonical_representation,
63+
)
64+
)
65+
66+
67+
class DaCeBackendFactory(GTFNBackendFactory):
68+
class Params:
69+
otf_workflow = factory.SubFactory(
70+
DaCeWorkflowFactory,
71+
device_type=factory.SelfAttribute("..device_type"),
72+
use_field_canonical_representation=factory.SelfAttribute(
73+
"..use_field_canonical_representation"
74+
),
75+
)
76+
name = factory.LazyAttribute(
77+
lambda o: f"run_dace_{o.name_device}{o.name_temps}{o.name_cached}{o.name_postfix}"
78+
)
79+
auto_optimize = factory.Trait(
80+
otf_workflow__translation__auto_optimize=True,
81+
name_temps="_opt",
82+
)
83+
use_field_canonical_representation: bool = False
84+
85+
86+
run_dace_cpu = DaCeBackendFactory(cached=True, auto_optimize=True)
87+
88+
run_dace_gpu = DaCeBackendFactory(gpu=True, cached=True, auto_optimize=True)

0 commit comments

Comments
 (0)