Skip to content

Commit ebd6c66

Browse files
committed
edits with Enrique
1 parent 4303bff commit ebd6c66

29 files changed

+64
-57
lines changed

.pre-commit-config.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,10 @@ repos:
3838
- id: debug-statements
3939

4040
- repo: https://github.com/astral-sh/ruff-pre-commit
41-
# Ruff version.
4241
rev: v0.2.0
4342
hooks:
4443
# Run the linter.
44+
# TODO: include tests here
4545
- id: ruff
4646
files: ^src/
4747
args: [--fix]

pyproject.toml

+2-10
Original file line numberDiff line numberDiff line change
@@ -396,16 +396,7 @@ preview = true
396396
# NPY: NumPy-specific rules
397397
# RUF: Ruff-specific rules
398398
ignore = [
399-
'B008', # Do not perform function calls in argument defaults
400-
'B905', # B905 `zip()` without an explicit `strict=` parameter
401-
'E501', # Line too long (using Bugbear's B950 warning)
402-
'E701', # Multiple statements on one line, see https://github.com/psf/black/issues/3887
403-
'RUF012', # Mutable class attributes should be annotated with `typing.ClassVar`
404-
'RUF003', # Comment contains ambiguous character
405-
'RUF015', # Prefer `next(iter(new_def_type.pos_or_kw_args.values()))` over single element slice
406-
'B028', # No explicit `stacklevel` keyword argument found
407-
'RUF001', # String contains ambiguous character
408-
'RUF009' # Do not perform function call in dataclass defaults
399+
'E501' # Line too long
409400
]
410401
ignore-init-module-imports = true
411402
select = ['E', 'F', 'I', 'B', 'A', 'T10', 'ERA', 'NPY', 'RUF']
@@ -460,6 +451,7 @@ split-on-trailing-comma = false
460451
max-complexity = 15
461452

462453
[tool.ruff.lint.per-file-ignores]
454+
"src/gt4py/cartesian/*" = ["RUF012"]
463455
'src/gt4py/eve/extended_typing.py' = ['F401', 'F405']
464456
'src/gt4py/next/__init__.py' = ['F401']
465457

src/gt4py/cartesian/backend/base.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,9 @@ def check_options(self, options: gt_definitions.BuildOptions) -> None:
271271
unknown_options = set(options.backend_opts.keys()) - set(self.options.keys())
272272
if unknown_options:
273273
warnings.warn(
274-
f"Unknown options '{unknown_options}' for backend '{self.name}'", RuntimeWarning
274+
f"Unknown options '{unknown_options}' for backend '{self.name}'",
275+
RuntimeWarning,
276+
stacklevel=2,
275277
)
276278

277279
def make_module(

src/gt4py/cartesian/backend/dace_backend.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -516,7 +516,7 @@ def _postprocess_dace_code(code_objects, is_gpu, builder):
516516
break
517517
for i, line in enumerate(lines):
518518
if "#include <dace/dace.h>" in line:
519-
cuda_code = [co.clean_code for co in code_objects if co.title == "CUDA"][0]
519+
cuda_code = next(co.clean_code for co in code_objects if co.title == "CUDA")
520520
lines = lines[0:i] + cuda_code.split("\n") + lines[i + 1 :]
521521
break
522522

src/gt4py/cartesian/frontend/gtscript_frontend.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -601,7 +601,7 @@ def visit_If(self, node: ast.If):
601601
and node.test.func.id == "__INLINED"
602602
and len(node.test.args) == 1
603603
):
604-
warnings.warn(
604+
warnings.warn( # noqa: B028 -> no-explicit-stacklevel
605605
f"stencil {self.stencil_name}, line {node.lineno}, column {node.col_offset}: compile-time if condition via __INLINED deprecated",
606606
category=DeprecationWarning,
607607
)

src/gt4py/cartesian/gtc/passes/gtir_definitive_assignment_analysis.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,6 @@ def check(gtir_stencil_expr: gtir.Stencil) -> gtir.Stencil:
7474
"""Execute definitive assignment analysis and warn on errors."""
7575
invalid_accesses = analyze(gtir_stencil_expr)
7676
for invalid_access in invalid_accesses:
77-
warnings.warn(f"`{invalid_access.name}` may be uninitialized.")
77+
warnings.warn(f"`{invalid_access.name}` may be uninitialized.", stacklevel=2)
7878

7979
return gtir_stencil_expr

src/gt4py/cartesian/gtc/passes/oir_optimizations/vertical_loop_merging.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,9 @@ def _mergeable(a: oir.VerticalLoop, b: oir.VerticalLoop) -> bool:
3838
def _merge(a: oir.VerticalLoop, b: oir.VerticalLoop) -> oir.VerticalLoop:
3939
sections = a.sections + b.sections
4040
if a.caches or b.caches:
41-
warnings.warn("AdjacentLoopMerging pass removed previously declared caches")
41+
warnings.warn(
42+
"AdjacentLoopMerging pass removed previously declared caches", stacklevel=2
43+
)
4244
return oir.VerticalLoop(
4345
loop_order=a.loop_order,
4446
sections=sections,

src/gt4py/cartesian/stencil_object.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -420,7 +420,8 @@ def _validate_args( # Function is too complex
420420
warnings.warn(
421421
f"The layout of the field '{name}' is not recommended for this backend."
422422
f"This may lead to performance degradation. Please consider using the"
423-
f"provided allocators in `gt4py.storage`."
423+
f"provided allocators in `gt4py.storage`.",
424+
stacklevel=2,
424425
)
425426

426427
field_dtype = self.field_info[name].dtype

src/gt4py/cartesian/utils/base.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ def shashed_id(*args, length=10, hash_algorithm=None):
197197
return shash(*args, hash_algorithm=hash_algorithm)[:length]
198198

199199

200-
def classmethod_to_function(class_method, instance=None, owner=type(None), remove_cls_arg=False):
200+
def classmethod_to_function(class_method, instance=None, owner=None, remove_cls_arg=False):
201201
if remove_cls_arg:
202202
return functools.partial(class_method.__get__(instance, owner), None)
203203
else:

src/gt4py/cartesian/utils/meta.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import operator
2121
import platform
2222
import textwrap
23-
from typing import Callable, Dict, List, Tuple, Type
23+
from typing import Callable, Dict, Final, List, Tuple, Type
2424

2525
from packaging import version
2626

@@ -260,7 +260,7 @@ def generic_visit(self, node, **kwargs):
260260

261261

262262
class ASTEvaluator(ASTPass):
263-
AST_OP_TO_OP: Dict[Type, Callable] = {
263+
AST_OP_TO_OP: Final[Dict[Type, Callable]] = {
264264
# Arithmetic operations
265265
ast.UAdd: operator.pos,
266266
ast.USub: operator.neg,

src/gt4py/eve/datamodels/core.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -660,7 +660,7 @@ def get_fields(model: Union[DataModel, Type[DataModel]]) -> utils.FrozenNamespac
660660
>>> fields(Model) # doctest:+ELLIPSIS
661661
FrozenNamespace(...name=Attribute(name='name', default=NOTHING, ...
662662
663-
""" # doctest conventions confuse RST validator
663+
"""
664664
if not is_datamodel(model):
665665
raise TypeError(f"Invalid datamodel instance or class: '{model}'.")
666666
if not isinstance(model, type):
@@ -825,7 +825,8 @@ def concretize(
825825
RuntimeWarning(
826826
f"Existing '{class_name}' symbol in module '{module}' contains a reference"
827827
"to a different object."
828-
)
828+
),
829+
stacklevel=2,
829830
)
830831

831832
return concrete_cls

src/gt4py/eve/utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -695,7 +695,7 @@ def reset_sequence(self, start: int = 1, *, warn_unsafe: Optional[bool] = None)
695695
if warn_unsafe is None:
696696
warn_unsafe = self.warn_unsafe
697697
if warn_unsafe and start < next(self._counter):
698-
warnings.warn("Unsafe reset of UIDGenerator ({self})")
698+
warnings.warn("Unsafe reset of UIDGenerator ({self})", stacklevel=2)
699699
self._counter = itertools.count(start)
700700

701701
return self

src/gt4py/next/constructors.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
@eve.utils.with_fluid_partial
3030
def empty(
3131
domain: common.DomainLike,
32-
dtype: core_defs.DTypeLike = core_defs.Float64DType(()),
32+
dtype: core_defs.DTypeLike = core_defs.Float64DType(()), # noqa: B008 -> function-call-in-default-argument
3333
*,
3434
aligned_index: Optional[Sequence[common.NamedIndex]] = None,
3535
allocator: Optional[next_allocators.FieldBufferAllocationUtil] = None,
@@ -96,7 +96,7 @@ def empty(
9696
@eve.utils.with_fluid_partial
9797
def zeros(
9898
domain: common.DomainLike,
99-
dtype: core_defs.DTypeLike = core_defs.Float64DType(()),
99+
dtype: core_defs.DTypeLike = core_defs.Float64DType(()), # noqa: B008 -> function-call-in-default-argument
100100
*,
101101
aligned_index: Optional[Sequence[common.NamedIndex]] = None,
102102
allocator: Optional[next_allocators.FieldBufferAllocatorProtocol] = None,
@@ -128,7 +128,7 @@ def zeros(
128128
@eve.utils.with_fluid_partial
129129
def ones(
130130
domain: common.DomainLike,
131-
dtype: core_defs.DTypeLike = core_defs.Float64DType(()),
131+
dtype: core_defs.DTypeLike = core_defs.Float64DType(()), # noqa: B008 -> function-call-in-default-argument
132132
*,
133133
aligned_index: Optional[Sequence[common.NamedIndex]] = None,
134134
allocator: Optional[next_allocators.FieldBufferAllocatorProtocol] = None,

src/gt4py/next/ffront/decorator.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,8 @@ def __call__(self, *args, offset_provider: dict[str, Dimension], **kwargs) -> No
298298
warnings.warn(
299299
UserWarning(
300300
f"Field View Program '{self.itir.id}': Using Python execution, consider selecting a perfomance backend."
301-
)
301+
),
302+
stacklevel=2,
302303
)
303304
with next_embedded.context.new_context(offset_provider=offset_provider) as ctx:
304305
ctx.run(self.definition, *rewritten_args, **kwargs)

src/gt4py/next/ffront/foast_passes/type_deduction.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -319,15 +319,15 @@ def visit_ScanOperator(self, node: foast.ScanOperator, **kwargs) -> foast.ScanOp
319319
)
320320
new_definition = self.visit(node.definition, **kwargs)
321321
new_def_type = new_definition.type
322-
carry_type = list(new_def_type.pos_or_kw_args.values())[0]
322+
carry_type = next(iter(new_def_type.pos_or_kw_args.values()))
323323
if new_init.type != new_def_type.returns:
324324
raise errors.DSLError(
325325
node.location,
326326
f"Argument 'init' to scan operator '{node.id}' must have same type as its return: "
327327
f"expected '{new_def_type.returns}', got '{new_init.type}'.",
328328
)
329329
elif new_init.type != carry_type:
330-
carry_arg_name = list(new_def_type.pos_or_kw_args.keys())[0]
330+
carry_arg_name = next(iter(new_def_type.pos_or_kw_args.keys()))
331331
raise errors.DSLError(
332332
node.location,
333333
f"Argument 'init' to scan operator '{node.id}' must have same type as '{carry_arg_name}' argument: "

src/gt4py/next/ffront/foast_to_itir.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def visit_ScanOperator(self, node: foast.ScanOperator, **kwargs) -> itir.Functio
114114
# (this is the only place in the lowering were a variable is captured in a lifted lambda)
115115
lowering_utils.to_tuples_of_iterator(
116116
im.promote_to_const_iterator(func_definition.params[0].id),
117-
[*node.type.definition.pos_or_kw_args.values()][0],
117+
[*node.type.definition.pos_or_kw_args.values()][0], # noqa: RUF015 -> unnecessary-iterable-allocation-for-first-element
118118
),
119119
)(
120120
# the function itself returns a tuple of iterators, deref element-wise

src/gt4py/next/iterator/embedded.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1013,7 +1013,7 @@ def _shift_field_indices(
10131013
def np_as_located_field(
10141014
*axes: common.Dimension, origin: Optional[dict[common.Dimension, int]] = None
10151015
) -> Callable[[np.ndarray], common.Field]:
1016-
warnings.warn("`np_as_located_field()` is deprecated, use `gtx.as_field()`", DeprecationWarning)
1016+
warnings.warn("`np_as_located_field()` is deprecated, use `gtx.as_field()`", DeprecationWarning) # noqa: B028
10171017

10181018
origin = origin or {}
10191019

src/gt4py/next/iterator/pretty_parser.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
| "if" prec1 "then" prec1 "else" prec1 -> ifthenelse
4343
4444
?prec2: prec3
45-
| prec2 "∨" prec3 -> bool_or
45+
| prec2 "∨" prec3 -> bool_or
4646
4747
?prec3: prec4
4848
| prec3 "∧" prec4 -> bool_and
@@ -85,7 +85,7 @@
8585
8686
%import common (CNAME, SIGNED_FLOAT, SIGNED_INT, WS)
8787
%ignore WS
88-
"""
88+
""" # noqa: RUF001
8989

9090

9191
@lark_visitors.v_args(inline=True)

src/gt4py/next/iterator/pretty_printer.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,13 @@
3131
BINARY_OPS: Final = {
3232
"plus": "+",
3333
"minus": "-",
34-
"multiplies": "×",
34+
"multiplies": "×", # noqa: RUF001
3535
"divides": "/",
3636
"eq": "==",
3737
"less": "<",
3838
"greater": ">",
3939
"and_": "∧",
40-
"or_": "∨",
40+
"or_": "∨", # noqa: RUF001
4141
}
4242

4343
# replacements for builtin unary operations
@@ -195,11 +195,11 @@ def visit_FunCall(self, node: ir.FunCall, *, prec: int) -> list[str]:
195195
res = self._hmerge(dim, [": ["], start, [", "], end, [")"])
196196
return self._prec_parens(res, prec, PRECEDENCE["__call__"])
197197
if fun_name == "cartesian_domain" and len(node.args) >= 1:
198-
# cartesian_domain(x, y, ...) → c{ x × y × ... }
198+
# cartesian_domain(x, y, ...) → c{ x × y × ... } # noqa: RUF003
199199
args = self.visit(node.args, prec=PRECEDENCE["__call__"])
200200
return self._hmerge(["c⟨ "], *self._hinterleave(args, ", "), [" ⟩"])
201201
if fun_name == "unstructured_domain" and len(node.args) >= 1:
202-
# unstructured_domain(x, y, ...) → u{ x × y × ... }
202+
# unstructured_domain(x, y, ...) → u{ x × y × ... } # noqa: RUF003
203203
args = self.visit(node.args, prec=PRECEDENCE["__call__"])
204204
return self._hmerge(["u⟨ "], *self._hinterleave(args, ", "), [" ⟩"])
205205
if fun_name == "if_" and len(node.args) == 3:

src/gt4py/next/iterator/tracing.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import dataclasses
1616
import inspect
1717
import typing
18-
from typing import List
18+
from typing import ClassVar, List
1919

2020
from gt4py._core import definitions as core_defs
2121
from gt4py.eve import Node
@@ -208,8 +208,8 @@ def __bool__(self):
208208

209209

210210
class TracerContext:
211-
fundefs: List[FunctionDefinition] = []
212-
closures: List[StencilClosure] = []
211+
fundefs: ClassVar[List[FunctionDefinition]] = []
212+
closures: ClassVar[List[StencilClosure]] = []
213213

214214
@classmethod
215215
def add_fundef(cls, fun):

src/gt4py/next/iterator/transforms/collapse_tuple.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
# distribution for a copy of the license or check <https://www.gnu.org/licenses/>.
1212
#
1313
# SPDX-License-Identifier: GPL-3.0-or-later
14+
from __future__ import annotations
15+
1416
import dataclasses
1517
import enum
1618
import functools
@@ -110,12 +112,12 @@ class Flag(enum.Flag):
110112
INLINE_TRIVIAL_LET = enum.auto()
111113

112114
@classmethod
113-
def all(self): # shadowing a python builtin
115+
def all(self) -> CollapseTuple.Flag:
114116
return functools.reduce(operator.or_, self.__members__.values())
115117

116118
ignore_tuple_size: bool
117119
use_global_type_inference: bool
118-
flags: Flag = Flag.all()
120+
flags: Flag = Flag.all() # noqa: RUF009
119121

120122
PRESERVED_ANNEX_ATTRS = ("type",)
121123

src/gt4py/next/iterator/transforms/simple_inline_heuristic.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,11 @@ def is_eligible_for_inlining(node: ir.FunCall, is_scan_pass_context: bool) -> bo
3838

3939
assert isinstance(node.fun, ir.FunCall) # for mypy
4040
(stencil,) = node.fun.args
41-
# Dont inline scans, i.e. exclude `↑(scan(...))(...)`
41+
# Don't inline scans, i.e. exclude `↑(scan(...))(...)`
4242
if isinstance(stencil, ir.FunCall) and stencil.fun == ir.SymRef(id="scan"):
4343
return False
4444

45-
# Dont inline the first lifted function call within a scan, e.g. if the node given here
45+
# Don't inline the first lifted function call within a scan, e.g. if the node given here
4646
# is `↑(f)(args...)` and appears in a scan pass `scan(λ(acc, args...) → acc + ·↑(f)(args...))`
4747
# it should not be inlined.
4848
return not is_scan_pass_context

src/gt4py/next/iterator/type_inference.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -387,7 +387,7 @@ def _default_constraints():
387387
DOMAIN_DTYPE = Primitive(name="domain")
388388
OFFSET_TAG_DTYPE = Primitive(name="offset_tag")
389389

390-
# Some helpers to define the builtins types
390+
# Some helpers to define the builtins' types
391391
T0 = TypeVar.fresh()
392392
T1 = TypeVar.fresh()
393393
T2 = TypeVar.fresh()
@@ -558,16 +558,16 @@ def _infer_shift_location_types(shift_args, offset_provider, constraints):
558558
current_loc_out = current_loc_in
559559
for arg in shift_args:
560560
if not isinstance(arg, ir.OffsetLiteral):
561-
# probably some dynamically computed offset, thus we assume its a number not an axis and just ignore it (see comment below)
561+
# probably some dynamically computed offset, thus we assume it's a number not an axis and just ignore it (see comment below)
562562
continue
563563
offset = arg.value
564564
if isinstance(offset, int):
565-
continue # ignore application of (partial) shifts
565+
continue # ignore 'application' of (partial) shifts
566566
else:
567567
assert isinstance(offset, str)
568568
axis = offset_provider[offset]
569569
if isinstance(axis, gtx.Dimension):
570-
continue # Cartesian shifts dont change the location type
570+
continue # Cartesian shifts don't change the location type
571571
elif isinstance(axis, Connectivity):
572572
assert (
573573
axis.origin_axis.kind

src/gt4py/next/program_processors/codegens/gtfn/codegen.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
#
1313
# SPDX-License-Identifier: GPL-3.0-or-later
1414

15-
from typing import Any, Collection, Union
15+
from typing import Any, Collection, Final, Union
1616

1717
from gt4py.eve import codegen
1818
from gt4py.eve.codegen import FormatTemplate as as_fmt, MakoTemplate as as_mako
@@ -22,12 +22,12 @@
2222

2323

2424
class GTFNCodegen(codegen.TemplatedGenerator):
25-
_grid_type_str = {
25+
_grid_type_str: Final = {
2626
common.GridType.CARTESIAN: "cartesian",
2727
common.GridType.UNSTRUCTURED: "unstructured",
2828
}
2929

30-
_builtins_mapping = {
30+
_builtins_mapping: Final = {
3131
"abs": "std::abs",
3232
"sin": "std::sin",
3333
"cos": "std::cos",

src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,8 @@ def _preprocess_program(
193193
if runtime_lift_mode and runtime_lift_mode != self.lift_mode:
194194
warnings.warn(
195195
f"GTFN Backend was configured for LiftMode `{self.lift_mode!s}`, but "
196-
f"overriden to be {runtime_lift_mode!s} at runtime."
196+
f"overriden to be {runtime_lift_mode!s} at runtime.",
197+
stacklevel=2,
197198
)
198199

199200
if not self.enable_itir_transforms:

0 commit comments

Comments
 (0)