Skip to content

Commit 800dca2

Browse files
committed
refactor[next]: stricter typing in ffront
1 parent 7a980f9 commit 800dca2

22 files changed

+339
-236
lines changed

pyproject.toml

+5
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,11 @@ warn_unused_ignores = false
206206
disallow_incomplete_defs = false
207207
module = 'gt4py.next.*'
208208

209+
[[tool.mypy.overrides]]
210+
# TODO: temporarily to propagate it to all of next
211+
disallow_incomplete_defs = true
212+
module = 'gt4py.next.ffront.*'
213+
209214
[[tool.mypy.overrides]]
210215
ignore_errors = true
211216
module = 'gt4py.next.ffront.decorator'

src/gt4py/next/ffront/decorator.py

+12-25
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
import typing
2525
import warnings
2626
from collections.abc import Callable
27-
from typing import Generator, Generic, TypeVar
27+
from typing import Generic, TypeVar
2828

2929
from gt4py import eve
3030
from gt4py._core import definitions as core_defs
@@ -72,25 +72,6 @@
7272
DEFAULT_BACKEND: Callable = None
7373

7474

75-
def _field_constituents_shape_and_dims(
76-
arg, arg_type: ts.FieldType | ts.ScalarType | ts.TupleType
77-
) -> Generator[tuple[tuple[int, ...], list[Dimension]]]:
78-
if isinstance(arg_type, ts.TupleType):
79-
for el, el_type in zip(arg, arg_type.types):
80-
yield from _field_constituents_shape_and_dims(el, el_type)
81-
elif isinstance(arg_type, ts.FieldType):
82-
dims = type_info.extract_dims(arg_type)
83-
if hasattr(arg, "shape"):
84-
assert len(arg.shape) == len(dims)
85-
yield (arg.shape, dims)
86-
else:
87-
yield (None, dims)
88-
elif isinstance(arg_type, ts.ScalarType):
89-
yield (None, [])
90-
else:
91-
raise ValueError("Expected 'FieldType' or 'TupleType' thereof.")
92-
93-
9475
# TODO(tehrengruber): Decide if and how programs can call other programs. As a
9576
# result Program could become a GTCallable.
9677
@dataclasses.dataclass(frozen=True)
@@ -179,7 +160,7 @@ def with_backend(self, backend: ppi.ProgramExecutor) -> Program:
179160
def with_grid_type(self, grid_type: GridType) -> Program:
180161
return dataclasses.replace(self, grid_type=grid_type)
181162

182-
def with_bound_args(self, **kwargs) -> ProgramWithBoundArgs:
163+
def with_bound_args(self, **kwargs: Any) -> ProgramWithBoundArgs:
183164
"""
184165
Bind scalar, i.e. non field, program arguments.
185166
@@ -229,7 +210,7 @@ def itir(self) -> itir.FencilDefinition:
229210
)
230211
).program
231212

232-
def __call__(self, *args, offset_provider: dict[str, Dimension], **kwargs) -> None:
213+
def __call__(self, *args, offset_provider: dict[str, Dimension], **kwargs: Any) -> None:
233214
if self.backend is None:
234215
warnings.warn(
235216
UserWarning(
@@ -376,7 +357,9 @@ def program(
376357

377358
def program_inner(definition: types.FunctionType) -> Program:
378359
return Program.from_function(
379-
definition, DEFAULT_BACKEND if backend is eve.NOTHING else backend, grid_type
360+
definition,
361+
DEFAULT_BACKEND if backend is eve.NOTHING else backend,
362+
grid_type,
380363
)
381364

382365
return program_inner if definition is None else program_inner(definition)
@@ -634,9 +617,13 @@ def field_operator(definition=None, *, backend=eve.NOTHING, grid_type=None):
634617
... ...
635618
"""
636619

637-
def field_operator_inner(definition: types.FunctionType) -> FieldOperator[foast.FieldOperator]:
620+
def field_operator_inner(
621+
definition: types.FunctionType,
622+
) -> FieldOperator[foast.FieldOperator]:
638623
return FieldOperator.from_function(
639-
definition, DEFAULT_BACKEND if backend is eve.NOTHING else backend, grid_type
624+
definition,
625+
DEFAULT_BACKEND if backend is eve.NOTHING else backend,
626+
grid_type,
640627
)
641628

642629
return field_operator_inner if definition is None else field_operator_inner(definition)

src/gt4py/next/ffront/dialect_parser.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def apply(
8484
return output_ast
8585

8686
@classmethod
87-
def apply_to_function(cls, function: Callable):
87+
def apply_to_function(cls, function: Callable) -> DialectRootT:
8888
src = SourceDefinition.from_function(function)
8989
closure_vars = get_closure_vars_from_function(function)
9090
annotations = typing.get_type_hints(function)
@@ -96,7 +96,10 @@ def _preprocess_definition_ast(cls, definition_ast: ast.AST) -> ast.AST:
9696

9797
@classmethod
9898
def _postprocess_dialect_ast(
99-
cls, output_ast: DialectRootT, closure_vars: dict[str, Any], annotations: dict[str, Any]
99+
cls,
100+
output_ast: DialectRootT,
101+
closure_vars: dict[str, Any],
102+
annotations: dict[str, Any],
100103
) -> DialectRootT:
101104
return output_ast
102105

src/gt4py/next/ffront/foast_passes/closure_var_folding.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,18 @@ class ClosureVarFolding(NodeTranslator, traits.VisitorWithSymbolTableTrait):
3535

3636
@classmethod
3737
def apply(
38-
cls, node: foast.FunctionDefinition | foast.FieldOperator, closure_vars: dict[str, Any]
38+
cls,
39+
node: foast.FunctionDefinition | foast.FieldOperator,
40+
closure_vars: dict[str, Any],
3941
) -> foast.FunctionDefinition:
4042
return cls(closure_vars=closure_vars).visit(node)
4143

4244
def visit_Name(
43-
self, node: foast.Name, current_closure_vars, symtable, **kwargs
45+
self,
46+
node: foast.Name,
47+
current_closure_vars: dict[str, Any],
48+
symtable: dict[str, foast.Symbol],
49+
**kwargs: Any,
4450
) -> foast.Name | foast.Constant:
4551
if node.id in symtable:
4652
definition = symtable[node.id]
@@ -50,7 +56,7 @@ def visit_Name(
5056
return foast.Constant(value=value, location=node.location)
5157
return node
5258

53-
def visit_Attribute(self, node: foast.Attribute, **kwargs) -> foast.Constant:
59+
def visit_Attribute(self, node: foast.Attribute, **kwargs: Any) -> foast.Constant:
5460
value = self.visit(node.value, **kwargs)
5561
if isinstance(value, foast.Constant):
5662
if hasattr(value.value, node.attr):
@@ -59,6 +65,6 @@ def visit_Attribute(self, node: foast.Attribute, **kwargs) -> foast.Constant:
5965
raise errors.DSLError(node.location, "Attribute access only applicable to constants.")
6066

6167
def visit_FunctionDefinition(
62-
self, node: foast.FunctionDefinition, **kwargs
68+
self, node: foast.FunctionDefinition, **kwargs: Any
6369
) -> foast.FunctionDefinition:
6470
return self.generic_visit(node, current_closure_vars=node.closure_vars, **kwargs)

src/gt4py/next/ffront/foast_passes/closure_var_type_deduction.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def apply(
3939
return cls(closure_vars=closure_vars).visit(node)
4040

4141
def visit_FunctionDefinition(
42-
self, node: foast.FunctionDefinition, **kwargs
42+
self, node: foast.FunctionDefinition, **kwargs: Any
4343
) -> foast.FunctionDefinition:
4444
new_closure_vars: list[foast.Symbol] = []
4545
for sym in node.closure_vars:

src/gt4py/next/ffront/foast_passes/dead_closure_var_elimination.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,10 @@
1515
from typing import Any
1616

1717
import gt4py.next.ffront.field_operator_ast as foast
18-
from gt4py.eve import NodeTranslator, traits
18+
from gt4py import eve
1919

2020

21-
class DeadClosureVarElimination(NodeTranslator, traits.VisitorWithSymbolTableTrait):
21+
class DeadClosureVarElimination(eve.NodeTranslator, eve.traits.VisitorWithSymbolTableTrait):
2222
"""Remove closure variable symbols that are not referenced in the AST."""
2323

2424
_referenced_symbols: list[foast.Symbol]
@@ -27,7 +27,9 @@ class DeadClosureVarElimination(NodeTranslator, traits.VisitorWithSymbolTableTra
2727
def apply(cls, node: foast.FunctionDefinition) -> foast.FunctionDefinition:
2828
return cls().visit(node)
2929

30-
def visit_Name(self, node: foast.Name, symtable, **kwargs: Any) -> foast.Name:
30+
def visit_Name(
31+
self, node: foast.Name, symtable: dict[str, foast.Symbol], **kwargs: Any
32+
) -> foast.Name:
3133
if node.id in symtable:
3234
self._referenced_symbols.append(symtable[node.id])
3335
return node

src/gt4py/next/ffront/foast_passes/iterable_unpack.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def _unique_tuple_symbol(self, node: foast.TupleTargetAssign) -> foast.Symbol[An
5757
self.unique_tuple_symbol_id += 1
5858
return sym
5959

60-
def visit_BlockStmt(self, node: foast.BlockStmt, **kwargs) -> foast.BlockStmt:
60+
def visit_BlockStmt(self, node: foast.BlockStmt, **kwargs: Any) -> foast.BlockStmt:
6161
unrolled_stmts: list[foast.Assign | foast.BlockStmt | foast.Return] = []
6262

6363
for stmt in node.stmts:
@@ -79,7 +79,10 @@ def visit_BlockStmt(self, node: foast.BlockStmt, **kwargs) -> foast.BlockStmt:
7979
slice_indices = list(range(lower, upper))
8080
tuple_slice = [
8181
foast.Subscript(
82-
value=tuple_name, index=i, type=el_type, location=stmt.location
82+
value=tuple_name,
83+
index=i,
84+
type=el_type,
85+
location=stmt.location,
8386
)
8487
for i in slice_indices
8588
]
@@ -98,7 +101,10 @@ def visit_BlockStmt(self, node: foast.BlockStmt, **kwargs) -> foast.BlockStmt:
98101
new_assign = foast.Assign(
99102
target=subtarget,
100103
value=foast.Subscript(
101-
value=tuple_name, index=index, type=el_type, location=stmt.location
104+
value=tuple_name,
105+
index=index,
106+
type=el_type,
107+
location=stmt.location,
102108
),
103109
location=stmt.location,
104110
)

src/gt4py/next/ffront/foast_passes/type_alias_replacement.py

+10-5
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,9 @@ class TypeAliasReplacement(NodeTranslator, traits.VisitorWithSymbolTableTrait):
3737

3838
@classmethod
3939
def apply(
40-
cls, node: foast.FunctionDefinition | foast.FieldOperator, closure_vars: dict[str, Any]
40+
cls,
41+
node: foast.FunctionDefinition | foast.FieldOperator,
42+
closure_vars: dict[str, Any],
4143
) -> tuple[foast.FunctionDefinition, dict[str, Any]]:
4244
foast_node = cls(closure_vars=closure_vars).visit(node)
4345
new_closure_vars = closure_vars.copy()
@@ -53,10 +55,12 @@ def is_type_alias(self, node_id: SymbolName | SymbolRef) -> bool:
5355
and node_id not in TYPE_BUILTIN_NAMES
5456
)
5557

56-
def visit_Name(self, node: foast.Name, **kwargs) -> foast.Name:
58+
def visit_Name(self, node: foast.Name, **kwargs: Any) -> foast.Name:
5759
if self.is_type_alias(node.id):
5860
return foast.Name(
59-
id=self.closure_vars[node.id].__name__, location=node.location, type=node.type
61+
id=self.closure_vars[node.id].__name__,
62+
location=node.location,
63+
type=node.type,
6064
)
6165
return node
6266

@@ -79,7 +83,8 @@ def _update_closure_var_symbols(
7983
kw_only_args={},
8084
pos_only_args=[ts.DeferredType(constraint=ts.ScalarType)],
8185
returns=cast(
82-
ts.DataType, from_type_hint(self.closure_vars[var.id])
86+
ts.DataType,
87+
from_type_hint(self.closure_vars[var.id]),
8388
),
8489
),
8590
namespace=dialect_ast_enums.Namespace.CLOSURE,
@@ -94,7 +99,7 @@ def _update_closure_var_symbols(
9499
return new_closure_vars
95100

96101
def visit_FunctionDefinition(
97-
self, node: foast.FunctionDefinition, **kwargs
102+
self, node: foast.FunctionDefinition, **kwargs: Any
98103
) -> foast.FunctionDefinition:
99104
return foast.FunctionDefinition(
100105
id=node.id,

0 commit comments

Comments
 (0)