diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b7f0c8101e..a06fe75dca 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -103,7 +103,7 @@ repos: # Add all type stubs from typeshed - types-all args: [--no-install-types] - exclude: | + exclude: |- (?x)^( setup.py | build/.* | @@ -120,4 +120,4 @@ repos: tests/next_tests/past_common_fixtures.py | tests/next_tests/toy_connectivity.py | tests/.* - )$ \ No newline at end of file + )$ diff --git a/pyproject.toml b/pyproject.toml index 29ed5f6bbf..0a1b92fa59 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -473,4 +473,4 @@ version = {attr = 'gt4py.__about__.__version__'} 'gt4py' = ['py.typed', '*.md', '*.rst'] [tool.setuptools.packages] -find = {namespaces = false, where = ['src']} \ No newline at end of file +find = {namespaces = false, where = ['src']} diff --git a/src/gt4py/_core/definitions.py b/src/gt4py/_core/definitions.py index a550db4f2e..0b1fff1420 100644 --- a/src/gt4py/_core/definitions.py +++ b/src/gt4py/_core/definitions.py @@ -74,21 +74,24 @@ BoolScalar: TypeAlias = Union[bool_, bool] BoolT = TypeVar("BoolT", bound=BoolScalar) BOOL_TYPES: Final[Tuple[type, ...]] = cast( - Tuple[type, ...], BoolScalar.__args__ # type: ignore[attr-defined] + Tuple[type, ...], + BoolScalar.__args__, # type: ignore[attr-defined] ) IntScalar: TypeAlias = Union[int8, int16, int32, int64, int] IntT = TypeVar("IntT", bound=IntScalar) INT_TYPES: Final[Tuple[type, ...]] = cast( - Tuple[type, ...], IntScalar.__args__ # type: ignore[attr-defined] + Tuple[type, ...], + IntScalar.__args__, # type: ignore[attr-defined] ) UnsignedIntScalar: TypeAlias = Union[uint8, uint16, uint32, uint64] UnsignedIntT = TypeVar("UnsignedIntT", bound=UnsignedIntScalar) UINT_TYPES: Final[Tuple[type, ...]] = cast( - Tuple[type, ...], UnsignedIntScalar.__args__ # type: ignore[attr-defined] + Tuple[type, ...], + UnsignedIntScalar.__args__, # type: ignore[attr-defined] ) @@ -100,7 +103,8 @@ FloatingScalar: TypeAlias = 
Union[float32, float64, float] FloatingT = TypeVar("FloatingT", bound=FloatingScalar) FLOAT_TYPES: Final[Tuple[type, ...]] = cast( - Tuple[type, ...], FloatingScalar.__args__ # type: ignore[attr-defined] + Tuple[type, ...], + FloatingScalar.__args__, # type: ignore[attr-defined] ) @@ -165,23 +169,28 @@ class DTypeKind(eve.StrEnum): @overload -def dtype_kind(sc_type: Type[BoolT]) -> Literal[DTypeKind.BOOL]: ... +def dtype_kind(sc_type: Type[BoolT]) -> Literal[DTypeKind.BOOL]: + ... @overload -def dtype_kind(sc_type: Type[IntT]) -> Literal[DTypeKind.INT]: ... +def dtype_kind(sc_type: Type[IntT]) -> Literal[DTypeKind.INT]: + ... @overload -def dtype_kind(sc_type: Type[UnsignedIntT]) -> Literal[DTypeKind.UINT]: ... +def dtype_kind(sc_type: Type[UnsignedIntT]) -> Literal[DTypeKind.UINT]: + ... @overload -def dtype_kind(sc_type: Type[FloatingT]) -> Literal[DTypeKind.FLOAT]: ... +def dtype_kind(sc_type: Type[FloatingT]) -> Literal[DTypeKind.FLOAT]: + ... @overload -def dtype_kind(sc_type: Type[ScalarT]) -> DTypeKind: ... +def dtype_kind(sc_type: Type[ScalarT]) -> DTypeKind: + ... def dtype_kind(sc_type: Type[ScalarT]) -> DTypeKind: @@ -355,7 +364,8 @@ class GTDimsInterface(Protocol): """ @property - def __gt_dims__(self) -> Tuple[str, ...]: ... + def __gt_dims__(self) -> Tuple[str, ...]: + ... class GTOriginInterface(Protocol): @@ -366,7 +376,8 @@ class GTOriginInterface(Protocol): """ @property - def __gt_origin__(self) -> Tuple[int, ...]: ... + def __gt_origin__(self) -> Tuple[int, ...]: + ... # -- Device representation -- @@ -436,45 +447,64 @@ def __iter__(self) -> Iterator[DeviceTypeT | int]: class NDArrayObject(Protocol): @property - def ndim(self) -> int: ... + def ndim(self) -> int: + ... @property - def shape(self) -> tuple[int, ...]: ... + def shape(self) -> tuple[int, ...]: + ... @property - def dtype(self) -> Any: ... + def dtype(self) -> Any: + ... - def item(self) -> Any: ... + def item(self) -> Any: + ... 
- def astype(self, dtype: npt.DTypeLike) -> NDArrayObject: ... + def astype(self, dtype: npt.DTypeLike) -> NDArrayObject: + ... - def __getitem__(self, item: Any) -> NDArrayObject: ... + def __getitem__(self, item: Any) -> NDArrayObject: + ... - def __abs__(self) -> NDArrayObject: ... + def __abs__(self) -> NDArrayObject: + ... - def __neg__(self) -> NDArrayObject: ... + def __neg__(self) -> NDArrayObject: + ... - def __add__(self, other: NDArrayObject | Scalar) -> NDArrayObject: ... + def __add__(self, other: NDArrayObject | Scalar) -> NDArrayObject: + ... - def __radd__(self, other: Any) -> NDArrayObject: ... + def __radd__(self, other: Any) -> NDArrayObject: + ... - def __sub__(self, other: NDArrayObject | Scalar) -> NDArrayObject: ... + def __sub__(self, other: NDArrayObject | Scalar) -> NDArrayObject: + ... - def __rsub__(self, other: Any) -> NDArrayObject: ... + def __rsub__(self, other: Any) -> NDArrayObject: + ... - def __mul__(self, other: NDArrayObject | Scalar) -> NDArrayObject: ... + def __mul__(self, other: NDArrayObject | Scalar) -> NDArrayObject: + ... - def __rmul__(self, other: Any) -> NDArrayObject: ... + def __rmul__(self, other: Any) -> NDArrayObject: + ... - def __floordiv__(self, other: NDArrayObject | Scalar) -> NDArrayObject: ... + def __floordiv__(self, other: NDArrayObject | Scalar) -> NDArrayObject: + ... - def __rfloordiv__(self, other: Any) -> NDArrayObject: ... + def __rfloordiv__(self, other: Any) -> NDArrayObject: + ... - def __truediv__(self, other: NDArrayObject | Scalar) -> NDArrayObject: ... + def __truediv__(self, other: NDArrayObject | Scalar) -> NDArrayObject: + ... - def __rtruediv__(self, other: Any) -> NDArrayObject: ... + def __rtruediv__(self, other: Any) -> NDArrayObject: + ... - def __pow__(self, other: NDArrayObject | Scalar) -> NDArrayObject: ... + def __pow__(self, other: NDArrayObject | Scalar) -> NDArrayObject: + ... 
def __eq__(self, other: NDArrayObject | Scalar) -> NDArrayObject: # type: ignore[override] # mypy wants to return `bool` ... @@ -494,8 +524,11 @@ def __lt__(self, other: NDArrayObject | Scalar) -> NDArrayObject: # type: ignor def __le__(self, other: NDArrayObject | Scalar) -> NDArrayObject: # type: ignore[misc] # Forward operator is not callable ... - def __and__(self, other: NDArrayObject | Scalar) -> NDArrayObject: ... + def __and__(self, other: NDArrayObject | Scalar) -> NDArrayObject: + ... - def __or__(self, other: NDArrayObject | Scalar) -> NDArrayObject: ... + def __or__(self, other: NDArrayObject | Scalar) -> NDArrayObject: + ... - def __xor(self, other: NDArrayObject | Scalar) -> NDArrayObject: ... + def __xor(self, other: NDArrayObject | Scalar) -> NDArrayObject: + ... diff --git a/src/gt4py/cartesian/backend/base.py b/src/gt4py/cartesian/backend/base.py index 669110161e..62e36de721 100644 --- a/src/gt4py/cartesian/backend/base.py +++ b/src/gt4py/cartesian/backend/base.py @@ -305,7 +305,8 @@ def make_module_source(self, *, args_data: Optional[ModuleData] = None, **kwargs class MakeModuleSourceCallable(Protocol): - def __call__(self, *, args_data: Optional[ModuleData] = None, **kwargs: Any) -> str: ... + def __call__(self, *, args_data: Optional[ModuleData] = None, **kwargs: Any) -> str: + ... 
class PurePythonBackendCLIMixin(CLIBackendMixin): diff --git a/src/gt4py/cartesian/backend/dace_backend.py b/src/gt4py/cartesian/backend/dace_backend.py index b02c765ad7..f75511f9ba 100644 --- a/src/gt4py/cartesian/backend/dace_backend.py +++ b/src/gt4py/cartesian/backend/dace_backend.py @@ -684,14 +684,14 @@ def generate_entry_params(self, stencil_ir: gtir.Stencil, sdfg: dace.SDFG) -> Li if name in sdfg.arrays: data = sdfg.arrays[name] assert isinstance(data, dace.data.Array) - res[name] = ( - "py::{pybind_type} {name}, std::array {name}_origin".format( - pybind_type=( - "object" if self.backend.storage_info["device"] == "gpu" else "buffer" - ), - name=name, - ndim=len(data.shape), - ) + res[ + name + ] = "py::{pybind_type} {name}, std::array {name}_origin".format( + pybind_type=( + "object" if self.backend.storage_info["device"] == "gpu" else "buffer" + ), + name=name, + ndim=len(data.shape), ) elif name in sdfg.symbols and not name.startswith("__"): assert name in sdfg.symbols diff --git a/src/gt4py/cartesian/backend/pyext_builder.py b/src/gt4py/cartesian/backend/pyext_builder.py index 1ffa5a412d..e12669ae0f 100644 --- a/src/gt4py/cartesian/backend/pyext_builder.py +++ b/src/gt4py/cartesian/backend/pyext_builder.py @@ -179,7 +179,8 @@ def build_pybind_ext( build_path: str, target_path: str, **kwargs: str, -) -> Tuple[str, str]: ... +) -> Tuple[str, str]: + ... @overload @@ -197,7 +198,8 @@ def build_pybind_ext( build_ext_class: Type = None, verbose: bool = False, clean: bool = False, -) -> Tuple[str, str]: ... +) -> Tuple[str, str]: + ... 
def build_pybind_ext( diff --git a/src/gt4py/cartesian/frontend/gtscript_frontend.py b/src/gt4py/cartesian/frontend/gtscript_frontend.py index 2df8c106ce..8a813fba73 100644 --- a/src/gt4py/cartesian/frontend/gtscript_frontend.py +++ b/src/gt4py/cartesian/frontend/gtscript_frontend.py @@ -618,7 +618,7 @@ def visit_If(self, node: ast.If): def _make_temp_decls( - descriptors: Dict[str, gtscript._FieldDescriptor] + descriptors: Dict[str, gtscript._FieldDescriptor], ) -> Dict[str, nodes.FieldDecl]: return { name: nodes.FieldDecl( diff --git a/src/gt4py/cartesian/gtc/common.py b/src/gt4py/cartesian/gtc/common.py index 9357c34632..1e0364d721 100644 --- a/src/gt4py/cartesian/gtc/common.py +++ b/src/gt4py/cartesian/gtc/common.py @@ -906,7 +906,7 @@ def data_type_to_typestr(dtype: DataType) -> str: def op_to_ufunc( op: Union[ UnaryOperator, ArithmeticOperator, ComparisonOperator, LogicalOperator, NativeFunction - ] + ], ) -> np.ufunc: if not isinstance( op, (UnaryOperator, ArithmeticOperator, ComparisonOperator, LogicalOperator, NativeFunction) diff --git a/src/gt4py/cartesian/gtc/cuir/oir_to_cuir.py b/src/gt4py/cartesian/gtc/cuir/oir_to_cuir.py index de1ca93557..567d128c29 100644 --- a/src/gt4py/cartesian/gtc/cuir/oir_to_cuir.py +++ b/src/gt4py/cartesian/gtc/cuir/oir_to_cuir.py @@ -29,7 +29,8 @@ class SymbolNameCreator(Protocol): - def __call__(self, name: str) -> str: ... + def __call__(self, name: str) -> str: + ... 
def _make_axis_offset_expr(bound: common.AxisBound, axis_index: int) -> cuir.Expr: diff --git a/src/gt4py/cartesian/gtc/dace/expansion/tasklet_codegen.py b/src/gt4py/cartesian/gtc/dace/expansion/tasklet_codegen.py index 779fca0c8d..e2ce48ec74 100644 --- a/src/gt4py/cartesian/gtc/dace/expansion/tasklet_codegen.py +++ b/src/gt4py/cartesian/gtc/dace/expansion/tasklet_codegen.py @@ -58,7 +58,9 @@ def _visit_offset( context_info = copy.deepcopy(access_info) context_info.variable_offset_axes = [] ranges = make_dace_subset( - access_info, context_info, data_dims=() # data_index added in visit_IndexAccess + access_info, + context_info, + data_dims=(), # data_index added in visit_IndexAccess ) ranges.offset(sym_offsets, negative=False) res = dace.subsets.Range([r for i, r in enumerate(ranges.ranges) if int_sizes[i] != 1]) diff --git a/src/gt4py/cartesian/gtc/dace/utils.py b/src/gt4py/cartesian/gtc/dace/utils.py index 8d8a0c90f7..cfde545f40 100644 --- a/src/gt4py/cartesian/gtc/dace/utils.py +++ b/src/gt4py/cartesian/gtc/dace/utils.py @@ -419,7 +419,7 @@ def flatten_list(list_or_node: Union[List[Any], eve.Node]): def collect_toplevel_computation_nodes( - list_or_node: Union[List[Any], eve.Node] + list_or_node: Union[List[Any], eve.Node], ) -> List["dcir.ComputationNode"]: class ComputationNodeCollector(eve.NodeVisitor): def visit_ComputationNode(self, node: dcir.ComputationNode, *, collection: List): @@ -431,7 +431,7 @@ def visit_ComputationNode(self, node: dcir.ComputationNode, *, collection: List) def collect_toplevel_iteration_nodes( - list_or_node: Union[List[Any], eve.Node] + list_or_node: Union[List[Any], eve.Node], ) -> List["dcir.IterationNode"]: class IterationNodeCollector(eve.NodeVisitor): def visit_IterationNode(self, node: dcir.IterationNode, *, collection: List): diff --git a/src/gt4py/cartesian/gtc/gtcpp/oir_to_gtcpp.py b/src/gt4py/cartesian/gtc/gtcpp/oir_to_gtcpp.py index 58cddffd5f..82991af1d4 100644 --- a/src/gt4py/cartesian/gtc/gtcpp/oir_to_gtcpp.py +++ 
b/src/gt4py/cartesian/gtc/gtcpp/oir_to_gtcpp.py @@ -94,7 +94,8 @@ def _make_axis_offset_expr( class SymbolNameCreator(Protocol): - def __call__(self, name: str) -> str: ... + def __call__(self, name: str) -> str: + ... class OIRToGTCpp(eve.NodeTranslator, eve.VisitorWithSymbolTableTrait): diff --git a/src/gt4py/cartesian/gtc/passes/gtir_k_boundary.py b/src/gt4py/cartesian/gtc/passes/gtir_k_boundary.py index 43ab047c6e..24ea38b36a 100644 --- a/src/gt4py/cartesian/gtc/passes/gtir_k_boundary.py +++ b/src/gt4py/cartesian/gtc/passes/gtir_k_boundary.py @@ -22,7 +22,7 @@ def _iter_field_names( - node: Union[gtir.Stencil, gtir.ParAssignStmt] + node: Union[gtir.Stencil, gtir.ParAssignStmt], ) -> eve.utils.XIterable[gtir.FieldAccess]: return node.walk_values().if_isinstance(gtir.FieldDecl).getattr("name").unique() diff --git a/src/gt4py/cartesian/stencil_object.py b/src/gt4py/cartesian/stencil_object.py index 69ce980bda..d6d1237229 100644 --- a/src/gt4py/cartesian/stencil_object.py +++ b/src/gt4py/cartesian/stencil_object.py @@ -92,7 +92,7 @@ def _extract_array_infos( def _extract_stencil_arrays( - array_infos: Dict[str, Optional[ArgsInfo]] + array_infos: Dict[str, Optional[ArgsInfo]], ) -> Dict[str, Optional[FieldType]]: return {name: info.array if info is not None else None for name, info in array_infos.items()} @@ -283,7 +283,7 @@ def __call__(self, *args, **kwargs) -> None: @staticmethod def _make_origin_dict( - origin: Union[Dict[str, Tuple[int, ...]], Tuple[int, ...], int, None] + origin: Union[Dict[str, Tuple[int, ...]], Tuple[int, ...], int, None], ) -> Dict[str, Tuple[int, ...]]: try: if isinstance(origin, dict): diff --git a/src/gt4py/cartesian/testing/suites.py b/src/gt4py/cartesian/testing/suites.py index 99ad14f87c..77b5189ebf 100644 --- a/src/gt4py/cartesian/testing/suites.py +++ b/src/gt4py/cartesian/testing/suites.py @@ -392,10 +392,7 @@ class StencilTestSuite(metaclass=SuiteMeta): .. 
code-block:: python - { - 'float_symbols' : (np.float32, np.float64), - 'int_symbols' : (int, np.int_, np.int64) - } + {"float_symbols": (np.float32, np.float64), "int_symbols": (int, np.int_, np.int64)} domain_range : `Sequence` of pairs like `((int, int), (int, int) ... )` Required class attribute. diff --git a/src/gt4py/cartesian/type_hints.py b/src/gt4py/cartesian/type_hints.py index 3a776ba847..a1af6b93d1 100644 --- a/src/gt4py/cartesian/type_hints.py +++ b/src/gt4py/cartesian/type_hints.py @@ -21,7 +21,8 @@ class StencilFunc(Protocol): __name__: str __module__: str - def __call__(self, *args: Any, **kwargs: Dict[str, Any]) -> None: ... + def __call__(self, *args: Any, **kwargs: Dict[str, Any]) -> None: + ... class AnnotatedStencilFunc(StencilFunc, Protocol): diff --git a/src/gt4py/cartesian/utils/attrib.py b/src/gt4py/cartesian/utils/attrib.py index 46bbf3dcfd..1bf82ba9df 100644 --- a/src/gt4py/cartesian/utils/attrib.py +++ b/src/gt4py/cartesian/utils/attrib.py @@ -240,13 +240,16 @@ def attribute(of, optional=False, **kwargs): class AttributeClassLike: - def validate(self): ... + def validate(self): + ... @property - def attributes(self): ... + def attributes(self): + ... @property - def as_dict(self): ... + def as_dict(self): + ... def attribclass(cls_or_none=None, **kwargs): diff --git a/src/gt4py/eve/codegen.py b/src/gt4py/eve/codegen.py index 72f0e8858f..cab5931148 100644 --- a/src/gt4py/eve/codegen.py +++ b/src/gt4py/eve/codegen.py @@ -305,13 +305,13 @@ def indented(self, steps: int = 1) -> Iterator[TextBlock]: common `indent - append - dedent` workflows. Examples: - >>> block = TextBlock(); - >>> block.append('first line') # doctest: +ELLIPSIS + >>> block = TextBlock() + >>> block.append("first line") # doctest: +ELLIPSIS <...> >>> with block.indented(): - ... block.append('second line'); # doctest: +ELLIPSIS + ... 
block.append("second line") # doctest: +ELLIPSIS <...> - >>> block.append('third line') # doctest: +ELLIPSIS + >>> block.append("third line") # doctest: +ELLIPSIS <...> >>> print(block.text) first line @@ -476,7 +476,9 @@ def render_values(self, **kwargs: Any) -> str: message += f" (created at {self.definition_loc[0]}:{self.definition_loc[1]})" try: loc_info = re.search(r"line (\d+), col (\d+)", str(e)) - message += f" rendering error at template line: {loc_info[1]}, column {loc_info[2]}." # type: ignore + message += ( + f" rendering error at template line: {loc_info[1]}, column {loc_info[2]}." # type: ignore + ) except Exception: message += " rendering error." @@ -541,7 +543,9 @@ def __init__(self, definition: mako_tpl.Template, **kwargs: Any) -> None: if self.definition_loc: message += f" created at {self.definition_loc[0]}:{self.definition_loc[1]}" try: - message += f" (error likely around line {e.lineno}, column: {getattr(e, 'pos', '?')})" # type: ignore # assume Mako exception + message += ( + f" (error likely around line {e.lineno}, column: {getattr(e, 'pos', '?')})" # type: ignore # assume Mako exception + ) except Exception: message = f"{message}:\n---\n{definition}\n---\n" @@ -641,13 +645,15 @@ def __init_subclass__(cls, *, inherit_templates: bool = True, **kwargs: Any) -> @overload @classmethod - def apply(cls, root: LeafNode, **kwargs: Any) -> str: ... + def apply(cls, root: LeafNode, **kwargs: Any) -> str: + ... @overload @classmethod def apply( # noqa: F811 # redefinition of symbol cls, root: CollectionNode, **kwargs: Any - ) -> Collection[str]: ... + ) -> Collection[str]: + ... 
@classmethod def apply( # noqa: F811 # redefinition of symbol diff --git a/src/gt4py/eve/datamodels/core.py b/src/gt4py/eve/datamodels/core.py index 11ad824aab..7dd8e3ec26 100644 --- a/src/gt4py/eve/datamodels/core.py +++ b/src/gt4py/eve/datamodels/core.py @@ -84,7 +84,8 @@ class _AttrsClassTP(Protocol): class DataModelTP(_AttrsClassTP, xtyping.DevToolsPrettyPrintable, Protocol): - def __init__(self, *args: Any, **kwargs: Any) -> None: ... + def __init__(self, *args: Any, **kwargs: Any) -> None: + ... __datamodel_fields__: ClassVar[utils.FrozenNamespace[Attribute]] = cast( utils.FrozenNamespace[Attribute], None @@ -115,7 +116,8 @@ class GenericDataModelTP(DataModelTP, Protocol): @classmethod def __class_getitem__( cls: Type[GenericDataModelTP], args: Union[Type, Tuple[Type, ...]] - ) -> Union[DataModelTP, GenericDataModelTP]: ... + ) -> Union[DataModelTP, GenericDataModelTP]: + ... _DM = TypeVar("_DM", bound="DataModel") @@ -278,7 +280,8 @@ def datamodel( coerce: bool = _COERCE_DEFAULT, generic: bool = _GENERIC_DEFAULT, type_validation_factory: Optional[FieldTypeValidatorFactory] = DefaultFieldTypeValidatorFactory, -) -> Callable[[Type[_T]], Type[_T]]: ... +) -> Callable[[Type[_T]], Type[_T]]: + ... @overload @@ -297,7 +300,8 @@ def datamodel( # noqa: F811 # redefinion of unused symbol coerce: bool = _COERCE_DEFAULT, generic: bool = _GENERIC_DEFAULT, type_validation_factory: Optional[FieldTypeValidatorFactory] = DefaultFieldTypeValidatorFactory, -) -> Type[_T]: ... +) -> Type[_T]: + ... # TODO(egparedes): Use @dataclass_transform(eq_default=True, field_specifiers=("field",)) @@ -406,7 +410,8 @@ def __call__( type_validation_factory: Optional[ FieldTypeValidatorFactory ] = DefaultFieldTypeValidatorFactory, - ) -> Union[Type[_T], Callable[[Type[_T]], Type[_T]]]: ... + ) -> Union[Type[_T], Callable[[Type[_T]], Type[_T]]]: + ... 
frozenmodel: _DataModelDecoratorTP = functools.partial(datamodel, frozen=True) @@ -419,11 +424,13 @@ def __call__( if xtyping.TYPE_CHECKING: class DataModel(DataModelTP): - def __init__(self, *args: Any, **kwargs: Any) -> None: ... + def __init__(self, *args: Any, **kwargs: Any) -> None: + ... def __pretty__( self, fmt: Callable[[Any], Any], **kwargs: Any - ) -> Generator[Any, None, None]: ... + ) -> Generator[Any, None, None]: + ... else: # TODO(egparedes): use @dataclass_transform(eq_default=True, field_specifiers=("field",)) @@ -559,7 +566,7 @@ def field( >>> from typing import List >>> @datamodel ... class C: - ... mylist: List[int] = field(default_factory=lambda : [1, 2, 3]) + ... mylist: List[int] = field(default_factory=lambda: [1, 2, 3]) >>> c = C() >>> c.mylist [1, 2, 3] @@ -694,7 +701,7 @@ def asdict( ... x: int ... y: int >>> c = C(x=1, y=2) - >>> assert asdict(c) == {'x': 1, 'y': 2} + >>> assert asdict(c) == {"x": 1, "y": 2} """ # noqa: RST301 # sphinx.napoleon conventions confuse RST validator if not is_datamodel(instance) or isinstance(instance, type): raise TypeError(f"Invalid datamodel instance: '{instance}'.") @@ -807,7 +814,10 @@ def concretize( """ # noqa: RST301 # doctest conventions confuse RST validator concrete_cls: Type[DataModelT] = _make_concrete_with_cache( - datamodel_cls, *type_args, class_name=class_name, module=module # type: ignore[arg-type] + datamodel_cls, + *type_args, + class_name=class_name, + module=module, # type: ignore[arg-type] ) assert isinstance(concrete_cls, type) and is_datamodel(concrete_cls) @@ -1180,7 +1190,8 @@ def _make_datamodel( # noqa: C901 # too complex but still readable and documen cls.__attrs_pre_init__ = cls.__pre_init__ # type: ignore[attr-defined] # adding new attribute if "__attrs_post_init__" in cls.__dict__ and not hasattr( - cls.__attrs_post_init__, _DATAMODEL_TAG # type: ignore[attr-defined] # mypy doesn't know about __attr_post_init__ + cls.__attrs_post_init__, + _DATAMODEL_TAG, # type: 
ignore[attr-defined] # mypy doesn't know about __attr_post_init__ ): raise TypeError(f"'{cls.__name__}' class contains forbidden custom '__attrs_post_init__'.") cls.__attrs_post_init__ = _make_post_init(has_post_init="__post_init__" in cls.__dict__) # type: ignore[attr-defined] # adding new attribute @@ -1366,7 +1377,8 @@ class GenericDataModel(GenericDataModelTP): @classmethod def __class_getitem__( cls: Type[GenericDataModelTP], args: Union[Type, Tuple[Type, ...]] - ) -> Union[DataModelTP, GenericDataModelTP]: ... + ) -> Union[DataModelTP, GenericDataModelTP]: + ... else: diff --git a/src/gt4py/eve/extended_typing.py b/src/gt4py/eve/extended_typing.py index 82076d1a9c..1f9aa83bc7 100644 --- a/src/gt4py/eve/extended_typing.py +++ b/src/gt4py/eve/extended_typing.py @@ -177,16 +177,19 @@ class NonDataDescriptor(Protocol[_C, _V]): @overload def __get__( self, _instance: Literal[None], _owner_type: Optional[Type[_C]] = None - ) -> NonDataDescriptor[_C, _V]: ... + ) -> NonDataDescriptor[_C, _V]: + ... @overload def __get__( # noqa: F811 # redefinion of unused member self, _instance: _C, _owner_type: Optional[Type[_C]] = None - ) -> _V: ... + ) -> _V: + ... def __get__( # noqa: F811 # redefinion of unused member self, _instance: Optional[_C], _owner_type: Optional[Type[_C]] = None - ) -> _V | NonDataDescriptor[_C, _V]: ... + ) -> _V | NonDataDescriptor[_C, _V]: + ... class DataDescriptor(NonDataDescriptor[_C, _V], Protocol): @@ -195,9 +198,11 @@ class DataDescriptor(NonDataDescriptor[_C, _V], Protocol): See https://docs.python.org/3/howto/descriptor.html for further information. """ - def __set__(self, _instance: _C, _value: _V) -> None: ... + def __set__(self, _instance: _C, _value: _V) -> None: + ... - def __delete__(self, _instance: _C) -> None: ... + def __delete__(self, _instance: _C) -> None: + ... # -- Based on typeshed definitions -- @@ -215,20 +220,26 @@ class HashlibAlgorithm(Protocol): block_size: int name: str - def __init__(self, data: ReadableBuffer = ...) 
-> None: ... + def __init__(self, data: ReadableBuffer = ...) -> None: + ... - def copy(self) -> HashlibAlgorithm: ... + def copy(self) -> HashlibAlgorithm: + ... - def update(self, data: ReadableBuffer) -> None: ... + def update(self, data: ReadableBuffer) -> None: + ... - def digest(self) -> bytes: ... + def digest(self) -> bytes: + ... - def hexdigest(self) -> str: ... + def hexdigest(self) -> str: + ... # -- Third party protocols -- class SupportsArray(Protocol): - def __array__(self, dtype: Optional[npt.DTypeLike] = None, /) -> npt.NDArray[Any]: ... + def __array__(self, dtype: Optional[npt.DTypeLike] = None, /) -> npt.NDArray[Any]: + ... def supports_array(value: Any) -> TypeGuard[SupportsArray]: @@ -237,7 +248,8 @@ def supports_array(value: Any) -> TypeGuard[SupportsArray]: class ArrayInterface(Protocol): @property - def __array_interface__(self) -> Dict[str, Any]: ... + def __array_interface__(self) -> Dict[str, Any]: + ... class ArrayInterfaceTypedDict(TypedDict): @@ -253,7 +265,8 @@ class ArrayInterfaceTypedDict(TypedDict): class StrictArrayInterface(Protocol): @property - def __array_interface__(self) -> ArrayInterfaceTypedDict: ... + def __array_interface__(self) -> ArrayInterfaceTypedDict: + ... def supports_array_interface(value: Any) -> TypeGuard[ArrayInterface]: @@ -262,7 +275,8 @@ def supports_array_interface(value: Any) -> TypeGuard[ArrayInterface]: class CUDAArrayInterface(Protocol): @property - def __cuda_array_interface__(self) -> Dict[str, Any]: ... + def __cuda_array_interface__(self) -> Dict[str, Any]: + ... class CUDAArrayInterfaceTypedDict(TypedDict): @@ -278,7 +292,8 @@ class CUDAArrayInterfaceTypedDict(TypedDict): class StrictCUDAArrayInterface(Protocol): @property - def __cuda_array_interface__(self) -> CUDAArrayInterfaceTypedDict: ... + def __cuda_array_interface__(self) -> CUDAArrayInterfaceTypedDict: + ... 
def supports_cuda_array_interface(value: Any) -> TypeGuard[CUDAArrayInterface]: @@ -290,15 +305,19 @@ def supports_cuda_array_interface(value: Any) -> TypeGuard[CUDAArrayInterface]: class MultiStreamDLPackBuffer(Protocol): - def __dlpack__(self, *, stream: Optional[int] = None) -> Any: ... + def __dlpack__(self, *, stream: Optional[int] = None) -> Any: + ... - def __dlpack_device__(self) -> DLPackDevice: ... + def __dlpack_device__(self) -> DLPackDevice: + ... class SingleStreamDLPackBuffer(Protocol): - def __dlpack__(self, *, stream: None = None) -> Any: ... + def __dlpack__(self, *, stream: None = None) -> Any: + ... - def __dlpack_device__(self) -> DLPackDevice: ... + def __dlpack_device__(self) -> DLPackDevice: + ... DLPackBuffer: TypeAlias = Union[MultiStreamDLPackBuffer, SingleStreamDLPackBuffer] @@ -314,9 +333,8 @@ def supports_dlpack(value: Any) -> TypeGuard[DLPackBuffer]: class DevToolsPrettyPrintable(Protocol): """Used by python-devtools (https://python-devtools.helpmanual.io/).""" - def __pretty__( - self, fmt: Callable[[Any], Any], **kwargs: Any - ) -> Generator[Any, None, None]: ... + def __pretty__(self, fmt: Callable[[Any], Any], **kwargs: Any) -> Generator[Any, None, None]: + ... # -- Added functionality -- @@ -339,7 +357,8 @@ def extended_runtime_checkable( *, instance_check_shortcut: bool = True, subclass_check_with_data_members: bool = False, -) -> Callable[[_ProtoT], _ProtoT]: ... +) -> Callable[[_ProtoT], _ProtoT]: + ... @overload @@ -348,7 +367,8 @@ def extended_runtime_checkable( *, instance_check_shortcut: bool = True, subclass_check_with_data_members: bool = False, -) -> _ProtoT: ... +) -> _ProtoT: + ... 
def extended_runtime_checkable( # noqa: C901 # too complex but unavoidable @@ -660,7 +680,7 @@ def eval_forward_ref( Examples: >>> from typing import Dict, Tuple - >>> print("Result:", eval_forward_ref('Dict[str, Tuple[int, float]]')) + >>> print("Result:", eval_forward_ref("Dict[str, Tuple[int, float]]")) Result: ...ict[str, ...uple[int, float]] """ @@ -724,21 +744,23 @@ def infer_type( # noqa: C901 # function is complex but well organized in indep >>> infer_type(frozenset([1, 2, 3])) frozenset[int] - >>> infer_type({'a': 0, 'b': 1}) + >>> infer_type({"a": 0, "b": 1}) dict[str, int] - >>> infer_type({'a': 0, 'b': 'B'}) + >>> infer_type({"a": 0, "b": "B"}) dict[str, ...Any] >>> print("Result:", infer_type(lambda a, b: a + b)) Result: ...Callable[[...Any, ...Any], ...Any] # Note that some patch versions of cpython3.9 show weird behaviors - >>> def f(a: int, b) -> int: ... + >>> def f(a: int, b) -> int: + ... ... >>> print("Result:", infer_type(f)) Result: ...Callable[[...int..., ...Any], int] - >>> def f(a: int, b) -> int: ... + >>> def f(a: int, b) -> int: + ... ... >>> print("Result:", infer_type(f)) Result: ...Callable[..., int] @@ -755,7 +777,7 @@ def infer_type( # noqa: C901 # function is complex but well organized in indep ... @extended_infer_type.register(float) ... @extended_infer_type.register(complex) ... def _infer_type_number(value, *, annotate_callable_kwargs: bool = False): - ... return numbers.Number + ... return numbers.Number >>> extended_infer_type(3.4) >>> infer_type(3.4) diff --git a/src/gt4py/eve/pattern_matching.py b/src/gt4py/eve/pattern_matching.py index 16b0d4e0e4..be3e1db160 100644 --- a/src/gt4py/eve/pattern_matching.py +++ b/src/gt4py/eve/pattern_matching.py @@ -31,9 +31,9 @@ class and all attributes of the pattern (recursively) match the Examples: >>> class Foo: - ... def __init__(self, bar, baz): - ... self.bar = bar - ... self.baz = baz + ... def __init__(self, bar, baz): + ... self.bar = bar + ... 
self.baz = baz >>> assert ObjectPattern(Foo, bar=1).match(Foo(1, 2)) """ diff --git a/src/gt4py/eve/trees.py b/src/gt4py/eve/trees.py index 7bfd22cdf7..74c5bd41bb 100644 --- a/src/gt4py/eve/trees.py +++ b/src/gt4py/eve/trees.py @@ -62,10 +62,12 @@ class TreeLike(abc.ABC): # noqa: B024 class Tree(Protocol): @abc.abstractmethod - def iter_children_values(self) -> Iterable: ... + def iter_children_values(self) -> Iterable: + ... @abc.abstractmethod - def iter_children_items(self) -> Iterable[Tuple[TreeKey, Any]]: ... + def iter_children_items(self) -> Iterable[Tuple[TreeKey, Any]]: + ... TreeLike.register(Tree) diff --git a/src/gt4py/eve/type_definitions.py b/src/gt4py/eve/type_definitions.py index 1ee981f548..8543a85bb3 100644 --- a/src/gt4py/eve/type_definitions.py +++ b/src/gt4py/eve/type_definitions.py @@ -98,7 +98,8 @@ class ConstrainedStr(str): class keyword argument or as class variable. Examples: - >>> class OnlyLetters(ConstrainedStr, regex=re.compile(r"^[a-zA-Z]*$")): pass + >>> class OnlyLetters(ConstrainedStr, regex=re.compile(r"^[a-zA-Z]*$")): + ... pass >>> OnlyLetters("aabbCC") OnlyLetters('aabbCC') diff --git a/src/gt4py/eve/type_validation.py b/src/gt4py/eve/type_validation.py index 124957fa20..65f492ebfe 100644 --- a/src/gt4py/eve/type_validation.py +++ b/src/gt4py/eve/type_validation.py @@ -110,7 +110,8 @@ def __call__( globalns: Optional[Dict[str, Any]] = None, localns: Optional[Dict[str, Any]] = None, **kwargs: Any, - ) -> FixedTypeValidator: ... + ) -> FixedTypeValidator: + ... @overload def __call__( # noqa: F811 # redefinion of unused member @@ -122,7 +123,8 @@ def __call__( # noqa: F811 # redefinion of unused member globalns: Optional[Dict[str, Any]] = None, localns: Optional[Dict[str, Any]] = None, **kwargs: Any, - ) -> Optional[FixedTypeValidator]: ... + ) -> Optional[FixedTypeValidator]: + ... 
@abc.abstractmethod def __call__( # noqa: F811 # redefinion of unused member @@ -167,7 +169,8 @@ def __call__( globalns: Optional[Dict[str, Any]] = None, localns: Optional[Dict[str, Any]] = None, **kwargs: Any, - ) -> FixedTypeValidator: ... + ) -> FixedTypeValidator: + ... @overload def __call__( # noqa: F811 # redefinion of unused member @@ -179,7 +182,8 @@ def __call__( # noqa: F811 # redefinion of unused member globalns: Optional[Dict[str, Any]] = None, localns: Optional[Dict[str, Any]] = None, **kwargs: Any, - ) -> Optional[FixedTypeValidator]: ... + ) -> Optional[FixedTypeValidator]: + ... def __call__( # noqa: F811,C901 # redefinion of unused member / complex but well organized in cases self, diff --git a/src/gt4py/eve/utils.py b/src/gt4py/eve/utils.py index 8e634c4b11..3440c84b62 100644 --- a/src/gt4py/eve/utils.py +++ b/src/gt4py/eve/utils.py @@ -92,7 +92,7 @@ def isinstancechecker(type_info: Union[Type, Iterable[Type]]) -> Callable[[Any], >>> checker = isinstancechecker((int, str)) >>> checker(3) True - >>> checker('3') + >>> checker("3") True >>> checker(3.3) False @@ -117,17 +117,17 @@ def attrchecker(*names: str) -> Callable[[Any], bool]: Examples: >>> from collections import namedtuple - >>> Point = namedtuple('Point', ['x', 'y']) + >>> Point = namedtuple("Point", ["x", "y"]) >>> point = Point(1.0, 2.0) - >>> checker = attrchecker('x') + >>> checker = attrchecker("x") >>> checker(point) True - >>> checker = attrchecker('x', 'y') + >>> checker = attrchecker("x", "y") >>> checker(point) True - >>> checker = attrchecker('z') + >>> checker = attrchecker("z") >>> checker(point) False @@ -144,19 +144,19 @@ def attrgetter_(*names: str, default: Any = NOTHING) -> Callable[[Any], Any]: Examples: >>> from collections import namedtuple - >>> Point = namedtuple('Point', ['x', 'y']) + >>> Point = namedtuple("Point", ["x", "y"]) >>> point = Point(1.0, 2.0) - >>> getter = attrgetter_('x') + >>> getter = attrgetter_("x") >>> getter(point) 1.0 >>> import math - >>> 
getter = attrgetter_('z', default=math.nan) + >>> getter = attrgetter_("z", default=math.nan) >>> getter(point) nan >>> import math - >>> getter = attrgetter_('x', 'y', 'z', default=math.nan) + >>> getter = attrgetter_("x", "y", "z", default=math.nan) >>> getter(point) (1.0, 2.0, nan) @@ -187,12 +187,12 @@ def getitem_(obj: Any, key: Any, default: Any = NOTHING) -> Any: Similar to :func:`operator.getitem()` but accepts a default value. Examples: - >>> d = {'a': 1} - >>> getitem_(d, 'a') + >>> d = {"a": 1} + >>> getitem_(d, "a") 1 - >>> d = {'a': 1} - >>> getitem_(d, 'b', 'default') + >>> d = {"a": 1} + >>> getitem_(d, "b", "default") 'default' """ @@ -213,13 +213,13 @@ def itemgetter_(key: Any, default: Any = NOTHING) -> Callable[[Any], Any]: Similar to :func:`operator.itemgetter()` but accepts a default value. Examples: - >>> d = {'a': 1} - >>> getter = itemgetter_('a') + >>> d = {"a": 1} + >>> getter = itemgetter_("a") >>> getter(d) 1 - >>> d = {'a': 1} - >>> getter = itemgetter_('b', 'default') + >>> d = {"a": 1} + >>> getter = itemgetter_("b", "default") >>> getter(d) 'default' @@ -241,13 +241,15 @@ def partial(self, *args: Any, **kwargs: Any) -> fluid_partial: @overload def with_fluid_partial( func: Literal[None] = None, *args: Any, **kwargs: Any -) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: ... +) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: + ... @overload def with_fluid_partial( # noqa: F811 # redefinition of unused function func: Callable[_P, _T], *args: Any, **kwargs: Any -) -> Callable[_P, _T]: ... +) -> Callable[_P, _T]: + ... def with_fluid_partial( # noqa: F811 # redefinition of unused function @@ -269,7 +271,6 @@ def with_fluid_partial( # noqa: F811 # redefinition of unused function >>> @with_fluid_partial ... def add(a, b): ... return a + b - ... 
>>> add.partial(1)(2) 3 """ @@ -284,13 +285,15 @@ def _decorator(func: Callable[..., Any]) -> Callable[..., Any]: @overload def optional_lru_cache( func: Literal[None] = None, *, maxsize: Optional[int] = 128, typed: bool = False -) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: ... +) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: + ... @overload def optional_lru_cache( # noqa: F811 # redefinition of unused function func: Callable[_P, _T], *, maxsize: Optional[int] = 128, typed: bool = False -) -> Callable[_P, _T]: ... +) -> Callable[_P, _T]: + ... def optional_lru_cache( # noqa: F811 # redefinition of unused function @@ -303,7 +306,6 @@ def optional_lru_cache( # noqa: F811 # redefinition of unused function ... def func(a, b): ... print(f"Inside func({a}, {b})") ... return a + b - ... >>> print(func(1, 3)) Inside func(1, 3) 4 @@ -346,14 +348,11 @@ def register_subclasses(*subclasses: Type) -> Callable[[Type], Type]: >>> import abc >>> class MyVirtualSubclassA: ... pass - ... >>> class MyVirtualSubclassB: - ... pass - ... + ... pass >>> @register_subclasses(MyVirtualSubclassA, MyVirtualSubclassB) ... class MyBaseClass(abc.ABC): - ... pass - ... + ... pass >>> issubclass(MyVirtualSubclassA, MyBaseClass) and issubclass(MyVirtualSubclassB, MyBaseClass) True @@ -822,7 +821,7 @@ def if_isinstance(self, *types: Type) -> XIterable[T]: Equivalent to ``xiter(item for item in self if isinstance(item, types))``. Examples: - >>> it = xiter([1, '2', 3.3, [4, 5], {6, 7}]) + >>> it = xiter([1, "2", 3.3, [4, 5], {6, 7}]) >>> list(it.if_isinstance(int, float)) [1, 3.3] @@ -835,7 +834,7 @@ def if_not_isinstance(self, *types: Type) -> XIterable[T]: Equivalent to ``xiter(item for item in self if not isinstance(item, types))``. 
Examples: - >>> it = xiter([1, '2', 3.3, [4, 5], {6, 7}]) + >>> it = xiter([1, "2", 3.3, [4, 5], {6, 7}]) >>> list(it.if_not_isinstance(int, float)) ['2', [4, 5], {6, 7}] @@ -942,12 +941,12 @@ def if_hasattr(self, *names: str) -> XIterable[T]: Equivalent to ``filter(attrchecker(names), self)``. Examples: - >>> it = xiter([1, '2', 3.3, [4, 5], {6, 7}]) - >>> list(it.if_hasattr('__len__')) + >>> it = xiter([1, "2", 3.3, [4, 5], {6, 7}]) + >>> list(it.if_hasattr("__len__")) ['2', [4, 5], {6, 7}] - >>> it = xiter([1, '2', 3.3, [4, 5], {6, 7}]) - >>> list(it.if_hasattr('__len__', 'index')) + >>> it = xiter([1, "2", 3.3, [4, 5], {6, 7}]) + >>> list(it.if_hasattr("__len__", "index")) ['2', [4, 5]] """ @@ -968,13 +967,13 @@ def getattr( # noqa # A003: shadowing a python builtin Examples: >>> from collections import namedtuple - >>> Point = namedtuple('Point', ['x', 'y']) + >>> Point = namedtuple("Point", ["x", "y"]) >>> it = xiter([Point(1.0, -1.0), Point(2.0, -2.0), Point(3.0, -3.0)]) - >>> list(it.getattr('y')) + >>> list(it.getattr("y")) [-1.0, -2.0, -3.0] >>> it = xiter([Point(1.0, -1.0), Point(2.0, -2.0), Point(3.0, -3.0)]) - >>> list(it.getattr('x', 'z', default=None)) + >>> list(it.getattr("x", "z", default=None)) [(1.0, None), (2.0, None), (3.0, None)] """ @@ -991,16 +990,18 @@ def getitem(self, *indices: Union[int, str], default: Any = NOTHING) -> XIterabl For detailed information check :func:`toolz.itertoolz.pluck` reference. - >>> it = xiter([('a', 1), ('b', 2), ('c', 3)]) + >>> it = xiter([("a", 1), ("b", 2), ("c", 3)]) >>> list(it.getitem(0)) ['a', 'b', 'c'] - >>> it = xiter([ - ... dict(name="AA", age=20, country="US"), - ... dict(name="BB", age=30, country="UK"), - ... dict(name="CC", age=40, country="EU"), - ... dict(country="CH") - ... ]) + >>> it = xiter( + ... [ + ... dict(name="AA", age=20, country="US"), + ... dict(name="BB", age=30, country="UK"), + ... dict(name="CC", age=40, country="EU"), + ... dict(country="CH"), + ... ] + ... 
) >>> list(it.getitem("name", "age", default=None)) [('AA', 20), ('BB', 30), ('CC', 40), (None, None)] @@ -1023,12 +1024,12 @@ def chain(self, *others: Iterable) -> XIterable[Union[T, S]]: For detailed information check :func:`itertools.chain` reference. Examples: - >>> it_a, it_b = xiter(range(2)), xiter(['a', 'b']) + >>> it_a, it_b = xiter(range(2)), xiter(["a", "b"]) >>> list(it_a.chain(it_b)) [0, 1, 'a', 'b'] >>> it_a = xiter(range(2)) - >>> list(it_a.chain(['a', 'b'], ['A', 'B'])) + >>> list(it_a.chain(["a", "b"], ["A", "B"])) [0, 1, 'a', 'b', 'A', 'B'] """ @@ -1092,7 +1093,7 @@ def product( For detailed information check :func:`itertools.product` reference. Examples: - >>> it_a, it_b = xiter([0, 1]), xiter(['a', 'b']) + >>> it_a, it_b = xiter([0, 1]), xiter(["a", "b"]) >>> list(it_a.product(it_b)) [(0, 'a'), (0, 'b'), (1, 'a'), (1, 'b')] @@ -1189,16 +1190,16 @@ def zip( # noqa # A003: shadowing a python builtin Examples: >>> it_a = xiter(range(3)) - >>> it_b = ['a', 'b', 'c'] + >>> it_b = ["a", "b", "c"] >>> list(it_a.zip(it_b)) [(0, 'a'), (1, 'b'), (2, 'c')] >>> it = xiter(range(3)) - >>> list(it.zip(['a', 'b', 'c'], ['A', 'B', 'C'])) + >>> list(it.zip(["a", "b", "c"], ["A", "B", "C"])) [(0, 'a', 'A'), (1, 'b', 'B'), (2, 'c', 'C')] >>> it = xiter(range(5)) - >>> list(it.zip(['a', 'b', 'c'], ['A', 'B', 'C'], fill=None)) + >>> list(it.zip(["a", "b", "c"], ["A", "B", "C"], fill=None)) [(0, 'a', 'A'), (1, 'b', 'B'), (2, 'c', 'C'), (3, None, None), (4, None, None)] """ @@ -1216,7 +1217,7 @@ def unzip(self) -> XIterable[Tuple[Any, ...]]: For detailed information check :func:`zip` reference. Examples: - >>> it = xiter([('a', 1), ('b', 2), ('c', 3)]) + >>> it = xiter([("a", 1), ("b", 2), ("c", 3)]) >>> list(it.unzip()) [('a', 'b', 'c'), (1, 2, 3)] @@ -1224,10 +1225,12 @@ def unzip(self) -> XIterable[Tuple[Any, ...]]: return XIterable(zip(*self.iterator)) @typing.overload - def islice(self, __stop: int) -> XIterable[T]: ... 
+ def islice(self, __stop: int) -> XIterable[T]: + ... @typing.overload - def islice(self, __start: int, __stop: int, __step: int = 1) -> XIterable[T]: ... + def islice(self, __start: int, __stop: int, __step: int = 1) -> XIterable[T]: + ... def islice( self, @@ -1296,7 +1299,7 @@ def unique(self, *, key: Union[NOTHING, Callable] = NOTHING) -> XIterable[T]: >>> list(it.unique()) [1, 2, 3] - >>> it = xiter(['cat', 'mouse', 'dog', 'hen']) + >>> it = xiter(["cat", "mouse", "dog", "hen"]) >>> list(it.unique(key=len)) ['cat', 'mouse'] @@ -1309,17 +1312,18 @@ def unique(self, *, key: Union[NOTHING, Callable] = NOTHING) -> XIterable[T]: @typing.overload def groupby( self, key: str, *other_keys: str, as_dict: bool = False - ) -> XIterable[Tuple[Any, List[T]]]: ... + ) -> XIterable[Tuple[Any, List[T]]]: + ... @typing.overload - def groupby( - self, key: List[Any], *, as_dict: bool = False - ) -> XIterable[Tuple[Any, List[T]]]: ... + def groupby(self, key: List[Any], *, as_dict: bool = False) -> XIterable[Tuple[Any, List[T]]]: + ... @typing.overload def groupby( self, key: Callable[[T], Any], *, as_dict: bool = False - ) -> XIterable[Tuple[Any, List[T]]]: ... + ) -> XIterable[Tuple[Any, List[T]]]: + ... def groupby( self, @@ -1347,29 +1351,29 @@ def groupby( For detailed information check :func:`toolz.itertoolz.groupby` reference. Examples: - >>> it = xiter([(1.0, -1.0), (1.0,-2.0), (2.2, -3.0)]) + >>> it = xiter([(1.0, -1.0), (1.0, -2.0), (2.2, -3.0)]) >>> list(it.groupby([0])) [(1.0, [(1.0, -1.0), (1.0, -2.0)]), (2.2, [(2.2, -3.0)])] >>> data = [ - ... {'x': 1.0, 'y': -1.0, 'z': 1.0}, - ... {'x': 1.0, 'y': -2.0, 'z': 1.0}, - ... {'x': 2.2, 'y': -3.0, 'z': 2.2} + ... {"x": 1.0, "y": -1.0, "z": 1.0}, + ... {"x": 1.0, "y": -2.0, "z": 1.0}, + ... {"x": 2.2, "y": -3.0, "z": 2.2}, ... 
] - >>> list(xiter(data).groupby(['x'])) + >>> list(xiter(data).groupby(["x"])) [(1.0, [{'x': 1.0, 'y': -1.0, 'z': 1.0}, {'x': 1.0, 'y': -2.0, 'z': 1.0}]), (2.2, [{'x': 2.2, 'y': -3.0, 'z': 2.2}])] - >>> list(xiter(data).groupby(['x', 'z'])) + >>> list(xiter(data).groupby(["x", "z"])) [((1.0, 1.0), [{'x': 1.0, 'y': -1.0, 'z': 1.0}, {'x': 1.0, 'y': -2.0, 'z': 1.0}]), ((2.2, 2.2), [{'x': 2.2, 'y': -3.0, 'z': 2.2}])] >>> from collections import namedtuple - >>> Point = namedtuple('Point', ['x', 'y', 'z']) + >>> Point = namedtuple("Point", ["x", "y", "z"]) >>> data = [Point(1.0, -2.0, 1.0), Point(1.0, -2.0, 1.0), Point(2.2, 3.0, 2.0)] - >>> list(xiter(data).groupby('x')) + >>> list(xiter(data).groupby("x")) [(1.0, [Point(x=1.0, y=-2.0, z=1.0), Point(x=1.0, y=-2.0, z=1.0)]), (2.2, [Point(x=2.2, y=3.0, z=2.0)])] - >>> list(xiter(data).groupby('x', 'z')) + >>> list(xiter(data).groupby("x", "z")) [((1.0, 1.0), [Point(x=1.0, y=-2.0, z=1.0), Point(x=1.0, y=-2.0, z=1.0)]), ((2.2, 2.0), [Point(x=2.2, y=3.0, z=2.0)])] - >>> it = xiter(['Alice', 'Bob', 'Charlie', 'Dan', 'Edith', 'Frank']) + >>> it = xiter(["Alice", "Bob", "Charlie", "Dan", "Edith", "Frank"]) >>> list(it.groupby(len)) [(5, ['Alice', 'Edith', 'Frank']), (3, ['Bob', 'Dan']), (7, ['Charlie'])] @@ -1432,8 +1436,12 @@ def reduce(self, bin_op_func: Callable[[Any, T], Any], *, init: Any = None) -> A >>> it.reduce((lambda accu, i: accu + i), init=0) 10 - >>> it = xiter(['a', 'b', 'c', 'd', 'e']) - >>> sorted(it.reduce((lambda accu, item: (accu or set()) | {item} if item in 'aeiou' else accu))) + >>> it = xiter(["a", "b", "c", "d", "e"]) + >>> sorted( + ... it.reduce( + ... (lambda accu, item: (accu or set()) | {item} if item in "aeiou" else accu) + ... ) + ... ) ['a', 'e'] """ @@ -1447,7 +1455,8 @@ def reduceby( *, as_dict: Literal[False], init: Union[S, NothingType], - ) -> XIterable[Tuple[str, S]]: ... + ) -> XIterable[Tuple[str, S]]: + ... 
@typing.overload def reduceby( @@ -1458,7 +1467,8 @@ def reduceby( *attr_keys: str, as_dict: Literal[False], init: Union[S, NothingType], - ) -> XIterable[Tuple[Tuple[str, ...], S]]: ... + ) -> XIterable[Tuple[Tuple[str, ...], S]]: + ... @typing.overload def reduceby( @@ -1468,7 +1478,8 @@ def reduceby( *, as_dict: Literal[True], init: Union[S, NothingType], - ) -> Dict[str, S]: ... + ) -> Dict[str, S]: + ... @typing.overload def reduceby( @@ -1479,7 +1490,8 @@ def reduceby( *attr_keys: str, as_dict: Literal[True], init: Union[S, NothingType], - ) -> Dict[Tuple[str, ...], S]: ... + ) -> Dict[Tuple[str, ...], S]: + ... @typing.overload def reduceby( @@ -1489,7 +1501,8 @@ def reduceby( *, as_dict: Literal[False], init: Union[S, NothingType], - ) -> XIterable[Tuple[K, S]]: ... + ) -> XIterable[Tuple[K, S]]: + ... @typing.overload def reduceby( @@ -1499,7 +1512,8 @@ def reduceby( *, as_dict: Literal[True], init: Union[S, NothingType], - ) -> Dict[K, S]: ... + ) -> Dict[K, S]: + ... @typing.overload def reduceby( @@ -1509,7 +1523,8 @@ def reduceby( *, as_dict: Literal[False], init: Union[S, NothingType], - ) -> XIterable[Tuple[K, S]]: ... + ) -> XIterable[Tuple[K, S]]: + ... @typing.overload def reduceby( @@ -1519,7 +1534,8 @@ def reduceby( *, as_dict: Literal[True], init: Union[S, NothingType], - ) -> Dict[K, S]: ... + ) -> Dict[K, S]: + ... def reduceby( self, @@ -1558,30 +1574,34 @@ def reduceby( For detailed information check :func:`toolz.itertoolz.reduceby` reference. Examples: - >>> it = xiter([(1.0, -1.0), (1.0,-2.0), (2.2, -3.0)]) + >>> it = xiter([(1.0, -1.0), (1.0, -2.0), (2.2, -3.0)]) >>> list(it.reduceby((lambda accu, _: accu + 1), [0], init=0)) [(1.0, 2), (2.2, 1)] >>> data = [ - ... {'x': 1.0, 'y': -1.0, 'z': 1.0}, - ... {'x': 1.0, 'y': -2.0, 'z': 1.0}, - ... {'x': 2.2, 'y': -3.0, 'z': 2.2} + ... {"x": 1.0, "y": -1.0, "z": 1.0}, + ... {"x": 1.0, "y": -2.0, "z": 1.0}, + ... {"x": 2.2, "y": -3.0, "z": 2.2}, ... 
] - >>> list(xiter(data).reduceby((lambda accu, _: accu + 1), ['x'], init=0)) + >>> list(xiter(data).reduceby((lambda accu, _: accu + 1), ["x"], init=0)) [(1.0, 2), (2.2, 1)] - >>> list(xiter(data).reduceby((lambda accu, _: accu + 1), ['x', 'z'], init=0)) + >>> list(xiter(data).reduceby((lambda accu, _: accu + 1), ["x", "z"], init=0)) [((1.0, 1.0), 2), ((2.2, 2.2), 1)] >>> from collections import namedtuple - >>> Point = namedtuple('Point', ['x', 'y', 'z']) + >>> Point = namedtuple("Point", ["x", "y", "z"]) >>> data = [Point(1.0, -2.0, 1.0), Point(1.0, -2.0, 1.0), Point(2.2, 3.0, 2.0)] - >>> list(xiter(data).reduceby((lambda accu, _: accu + 1), 'x', init=0)) + >>> list(xiter(data).reduceby((lambda accu, _: accu + 1), "x", init=0)) [(1.0, 2), (2.2, 1)] - >>> list(xiter(data).reduceby((lambda accu, _: accu + 1), 'x', 'z', init=0)) + >>> list(xiter(data).reduceby((lambda accu, _: accu + 1), "x", "z", init=0)) [((1.0, 1.0), 2), ((2.2, 2.0), 1)] - >>> it = xiter(['Alice', 'Bob', 'Charlie', 'Dan', 'Edith', 'Frank']) - >>> list(it.reduceby(lambda nvowels, name: nvowels + sum(i in 'aeiou' for i in name), len, init=0)) + >>> it = xiter(["Alice", "Bob", "Charlie", "Dan", "Edith", "Frank"]) + >>> list( + ... it.reduceby( + ... lambda nvowels, name: nvowels + sum(i in "aeiou" for i in name), len, init=0 + ... ) + ... ) [(5, 4), (3, 2), (7, 3)] """ # noqa: RST203, RST301 # sphinx.napoleon conventions confuse RST validator diff --git a/src/gt4py/next/allocators.py b/src/gt4py/next/allocators.py index 30775dbab9..2a94766999 100644 --- a/src/gt4py/next/allocators.py +++ b/src/gt4py/next/allocators.py @@ -58,7 +58,8 @@ class FieldBufferAllocatorProtocol(Protocol[core_defs.DeviceTypeT]): @property @abc.abstractmethod - def __gt_device_type__(self) -> core_defs.DeviceTypeT: ... + def __gt_device_type__(self) -> core_defs.DeviceTypeT: + ... 
@abc.abstractmethod def __gt_allocate__( @@ -67,7 +68,8 @@ def __gt_allocate__( dtype: core_defs.DType[core_defs.ScalarT], device_id: int = 0, aligned_index: Optional[Sequence[common.NamedIndex]] = None, # absolute position - ) -> core_allocators.TensorBuffer[core_defs.DeviceTypeT, core_defs.ScalarT]: ... + ) -> core_allocators.TensorBuffer[core_defs.DeviceTypeT, core_defs.ScalarT]: + ... def is_field_allocator(obj: Any) -> TypeGuard[FieldBufferAllocatorProtocol]: @@ -85,7 +87,8 @@ class FieldBufferAllocatorFactoryProtocol(Protocol[core_defs.DeviceTypeT]): @property @abc.abstractmethod - def __gt_allocator__(self) -> FieldBufferAllocatorProtocol[core_defs.DeviceTypeT]: ... + def __gt_allocator__(self) -> FieldBufferAllocatorProtocol[core_defs.DeviceTypeT]: + ... def is_field_allocator_factory(obj: Any) -> TypeGuard[FieldBufferAllocatorFactoryProtocol]: @@ -175,9 +178,9 @@ def __gt_allocate__( if TYPE_CHECKING: - __TensorFieldAllocatorAsFieldAllocatorInterfaceT: type[FieldBufferAllocatorProtocol] = ( - BaseFieldBufferAllocator - ) + __TensorFieldAllocatorAsFieldAllocatorInterfaceT: type[ + FieldBufferAllocatorProtocol + ] = BaseFieldBufferAllocator def horizontal_first_layout_mapper( diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index fdf515d2f8..a800bbcf69 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -161,7 +161,8 @@ def __repr__(self) -> str: return f"UnitRange({self.start}, {self.stop})" @overload - def __getitem__(self, index: int) -> int: ... + def __getitem__(self, index: int) -> int: + ... @overload def __getitem__(self, index: slice) -> UnitRange: # noqa: F811 # redefine unused @@ -414,7 +415,8 @@ def is_finite(cls, obj: Domain) -> TypeGuard[FiniteDomain]: return all(UnitRange.is_finite(rng) for rng in obj.ranges) @overload - def __getitem__(self, index: int) -> tuple[Dimension, _Rng]: ... + def __getitem__(self, index: int) -> tuple[Dimension, _Rng]: + ... 
@overload def __getitem__(self, index: slice) -> Self: # noqa: F811 # redefine unused @@ -423,7 +425,8 @@ def __getitem__(self, index: slice) -> Self: # noqa: F811 # redefine unused @overload def __getitem__( # noqa: F811 # redefine unused self, index: Dimension - ) -> tuple[Dimension, _Rng]: ... + ) -> tuple[Dimension, _Rng]: + ... def __getitem__( # noqa: F811 # redefine unused self, index: int | slice | Dimension @@ -569,7 +572,8 @@ def _broadcast_ranges( _R = TypeVar("_R", _Value, tuple[_Value, ...]) class GTBuiltInFuncDispatcher(Protocol): - def __call__(self, func: fbuiltins.BuiltInFunction[_R, _P], /) -> Callable[_P, _R]: ... + def __call__(self, func: fbuiltins.BuiltInFunction[_R, _P], /) -> Callable[_P, _R]: + ... # TODO(havogt): we need to describe when this interface should be used instead of the `Field` protocol. @@ -598,48 +602,60 @@ class Field(GTFieldInterface, Protocol[DimsT, core_defs.ScalarT]): __gt_builtin_func__: ClassVar[GTBuiltInFuncDispatcher] @property - def domain(self) -> Domain: ... + def domain(self) -> Domain: + ... @property def __gt_domain__(self) -> Domain: return self.domain @property - def codomain(self) -> type[core_defs.ScalarT] | Dimension: ... + def codomain(self) -> type[core_defs.ScalarT] | Dimension: + ... @property - def dtype(self) -> core_defs.DType[core_defs.ScalarT]: ... + def dtype(self) -> core_defs.DType[core_defs.ScalarT]: + ... @property - def ndarray(self) -> core_defs.NDArrayObject: ... + def ndarray(self) -> core_defs.NDArrayObject: + ... def __str__(self) -> str: return f"⟨{self.domain!s} → {self.dtype}⟩" @abc.abstractmethod - def asnumpy(self) -> np.ndarray: ... + def asnumpy(self) -> np.ndarray: + ... @abc.abstractmethod - def remap(self, index_field: ConnectivityField | fbuiltins.FieldOffset) -> Field: ... + def remap(self, index_field: ConnectivityField | fbuiltins.FieldOffset) -> Field: + ... @abc.abstractmethod - def restrict(self, item: AnyIndexSpec) -> Field: ... 
+ def restrict(self, item: AnyIndexSpec) -> Field: + ... @abc.abstractmethod - def as_scalar(self) -> core_defs.ScalarT: ... + def as_scalar(self) -> core_defs.ScalarT: + ... # Operators @abc.abstractmethod - def __call__(self, index_field: ConnectivityField | fbuiltins.FieldOffset) -> Field: ... + def __call__(self, index_field: ConnectivityField | fbuiltins.FieldOffset) -> Field: + ... @abc.abstractmethod - def __getitem__(self, item: AnyIndexSpec) -> Field: ... + def __getitem__(self, item: AnyIndexSpec) -> Field: + ... @abc.abstractmethod - def __abs__(self) -> Field: ... + def __abs__(self) -> Field: + ... @abc.abstractmethod - def __neg__(self) -> Field: ... + def __neg__(self) -> Field: + ... @abc.abstractmethod def __invert__(self) -> Field: @@ -654,37 +670,48 @@ def __ne__(self, other: Any) -> Field: # type: ignore[override] # mypy wants re ... @abc.abstractmethod - def __add__(self, other: Field | core_defs.ScalarT) -> Field: ... + def __add__(self, other: Field | core_defs.ScalarT) -> Field: + ... @abc.abstractmethod - def __radd__(self, other: Field | core_defs.ScalarT) -> Field: ... + def __radd__(self, other: Field | core_defs.ScalarT) -> Field: + ... @abc.abstractmethod - def __sub__(self, other: Field | core_defs.ScalarT) -> Field: ... + def __sub__(self, other: Field | core_defs.ScalarT) -> Field: + ... @abc.abstractmethod - def __rsub__(self, other: Field | core_defs.ScalarT) -> Field: ... + def __rsub__(self, other: Field | core_defs.ScalarT) -> Field: + ... @abc.abstractmethod - def __mul__(self, other: Field | core_defs.ScalarT) -> Field: ... + def __mul__(self, other: Field | core_defs.ScalarT) -> Field: + ... @abc.abstractmethod - def __rmul__(self, other: Field | core_defs.ScalarT) -> Field: ... + def __rmul__(self, other: Field | core_defs.ScalarT) -> Field: + ... @abc.abstractmethod - def __floordiv__(self, other: Field | core_defs.ScalarT) -> Field: ... + def __floordiv__(self, other: Field | core_defs.ScalarT) -> Field: + ... 
@abc.abstractmethod - def __rfloordiv__(self, other: Field | core_defs.ScalarT) -> Field: ... + def __rfloordiv__(self, other: Field | core_defs.ScalarT) -> Field: + ... @abc.abstractmethod - def __truediv__(self, other: Field | core_defs.ScalarT) -> Field: ... + def __truediv__(self, other: Field | core_defs.ScalarT) -> Field: + ... @abc.abstractmethod - def __rtruediv__(self, other: Field | core_defs.ScalarT) -> Field: ... + def __rtruediv__(self, other: Field | core_defs.ScalarT) -> Field: + ... @abc.abstractmethod - def __pow__(self, other: Field | core_defs.ScalarT) -> Field: ... + def __pow__(self, other: Field | core_defs.ScalarT) -> Field: + ... @abc.abstractmethod def __and__(self, other: Field | core_defs.ScalarT) -> Field: @@ -712,7 +739,8 @@ def is_field( @extended_runtime_checkable class MutableField(Field[DimsT, core_defs.ScalarT], Protocol[DimsT, core_defs.ScalarT]): @abc.abstractmethod - def __setitem__(self, index: AnyIndexSpec, value: Field | core_defs.ScalarT) -> None: ... + def __setitem__(self, index: AnyIndexSpec, value: Field | core_defs.ScalarT) -> None: + ... def is_mutable_field( @@ -736,7 +764,8 @@ class ConnectivityKind(enum.Flag): class ConnectivityField(Field[DimsT, core_defs.IntegralScalar], Protocol[DimsT, DimT]): @property @abc.abstractmethod - def codomain(self) -> DimT: ... + def codomain(self) -> DimT: + ... @property def kind(self) -> ConnectivityKind: @@ -747,7 +776,8 @@ def kind(self) -> ConnectivityKind: ) @abc.abstractmethod - def inverse_image(self, image_range: UnitRange | NamedRange) -> Sequence[NamedRange]: ... + def inverse_image(self, image_range: UnitRange | NamedRange) -> Sequence[NamedRange]: + ... 
# Operators def __abs__(self) -> Never: @@ -983,11 +1013,11 @@ def promote_dims(*dims_list: Sequence[Dimension]) -> list[Dimension]: >>> I, J, K = (Dimension(value=dim) for dim in ["I", "J", "K"]) >>> promote_dims([I, J], [I, J, K]) == [I, J, K] True - >>> promote_dims([I, J], [K]) # doctest: +ELLIPSIS + >>> promote_dims([I, J], [K]) # doctest: +ELLIPSIS Traceback (most recent call last): ... ValueError: Dimensions can not be promoted. Could not determine order of the following dimensions: J, K. - >>> promote_dims([I, J], [J, I]) # doctest: +ELLIPSIS + >>> promote_dims([I, J], [J, I]) # doctest: +ELLIPSIS Traceback (most recent call last): ... ValueError: Dimensions can not be promoted. The following dimensions appear in contradicting order: I, J. @@ -1052,9 +1082,9 @@ class FieldBuiltinFuncRegistry: dispatching (via ChainMap) to its parent's registries. """ - _builtin_func_map: collections.ChainMap[fbuiltins.BuiltInFunction, Callable] = ( - collections.ChainMap() - ) + _builtin_func_map: collections.ChainMap[ + fbuiltins.BuiltInFunction, Callable + ] = collections.ChainMap() def __init_subclass__(cls, **kwargs): cls._builtin_func_map = collections.ChainMap( diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index c39408ba3a..e78746a294 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -455,10 +455,12 @@ def restrict(self, index: common.AnyIndexSpec) -> common.Field: # -- Specialized implementations for builtin operations on array fields -- NdArrayField.register_builtin_func( - fbuiltins.abs, NdArrayField.__abs__ # type: ignore[attr-defined] + fbuiltins.abs, + NdArrayField.__abs__, # type: ignore[attr-defined] ) NdArrayField.register_builtin_func( - fbuiltins.power, NdArrayField.__pow__ # type: ignore[attr-defined] + fbuiltins.power, + NdArrayField.__pow__, # type: ignore[attr-defined] ) # TODO gamma @@ -472,18 +474,23 @@ def restrict(self, index: 
common.AnyIndexSpec) -> common.Field: NdArrayField.register_builtin_func(getattr(fbuiltins, name), _make_builtin(name, name)) NdArrayField.register_builtin_func( - fbuiltins.minimum, _make_builtin("minimum", "minimum") # type: ignore[attr-defined] + fbuiltins.minimum, + _make_builtin("minimum", "minimum"), ) NdArrayField.register_builtin_func( - fbuiltins.maximum, _make_builtin("maximum", "maximum") # type: ignore[attr-defined] + fbuiltins.maximum, + _make_builtin("maximum", "maximum"), ) NdArrayField.register_builtin_func( - fbuiltins.fmod, _make_builtin("fmod", "fmod") # type: ignore[attr-defined] + fbuiltins.fmod, + _make_builtin("fmod", "fmod"), ) NdArrayField.register_builtin_func(fbuiltins.where, _make_builtin("where", "where")) -def _make_reduction(builtin_name: str, array_builtin_name: str) -> Callable[ +def _make_reduction( + builtin_name: str, array_builtin_name: str +) -> Callable[ ..., NdArrayField[common.DimsT, core_defs.ScalarT], ]: diff --git a/src/gt4py/next/embedded/operators.py b/src/gt4py/next/embedded/operators.py index fc3ccda335..34b63639b7 100644 --- a/src/gt4py/next/embedded/operators.py +++ b/src/gt4py/next/embedded/operators.py @@ -42,7 +42,9 @@ class ScanOperator(EmbeddedOperator[_R, _P]): init: core_defs.Scalar | tuple[core_defs.Scalar | tuple, ...] 
axis: common.Dimension - def __call__(self, *args: common.Field | core_defs.Scalar, **kwargs: common.Field | core_defs.Scalar) -> common.Field: # type: ignore[override] # we cannot properly type annotate relative to self.fun + def __call__( + self, *args: common.Field | core_defs.Scalar, **kwargs: common.Field | core_defs.Scalar + ) -> common.Field: # type: ignore[override] # we cannot properly type annotate relative to self.fun scan_range = embedded_context.closure_column_range.get() assert self.axis == scan_range[0] scan_axis = scan_range[0] @@ -143,7 +145,7 @@ def impl(target: common.MutableField, source: common.Field): def _intersect_scan_args( - *args: core_defs.Scalar | common.Field | tuple[core_defs.Scalar | common.Field | tuple, ...] + *args: core_defs.Scalar | common.Field | tuple[core_defs.Scalar | common.Field | tuple, ...], ) -> common.Domain: return embedded_common.intersect_domains( *[arg.domain for arg in utils.flatten_nested_tuple(args) if common.is_field(arg)] @@ -151,7 +153,7 @@ def _intersect_scan_args( def _get_array_ns( - *args: core_defs.Scalar | common.Field | tuple[core_defs.Scalar | common.Field | tuple, ...] + *args: core_defs.Scalar | common.Field | tuple[core_defs.Scalar | common.Field | tuple, ...], ) -> ModuleType: for arg in utils.flatten_nested_tuple(args): if hasattr(arg, "array_ns"): diff --git a/src/gt4py/next/ffront/ast_passes/remove_docstrings.py b/src/gt4py/next/ffront/ast_passes/remove_docstrings.py index 653456f6c5..afa8b730b1 100644 --- a/src/gt4py/next/ffront/ast_passes/remove_docstrings.py +++ b/src/gt4py/next/ffront/ast_passes/remove_docstrings.py @@ -27,17 +27,15 @@ class RemoveDocstrings(ast.NodeTransformer): >>> def example_docstring(): ... a = 1 ... "This is a docstring" + ... ... def example_docstring_2(): - ... a = 2.0 - ... "This is a new docstring" - ... return a + ... a = 2.0 + ... "This is a new docstring" + ... return a + ... ... a = example_docstring_2() ... return a - >>> print(ast.unparse( - ... 
RemoveDocstrings.apply( - ... ast.parse(inspect.getsource(example_docstring)) - ... ) - ... )) + >>> print(ast.unparse(RemoveDocstrings.apply(ast.parse(inspect.getsource(example_docstring))))) def example_docstring(): a = 1 diff --git a/src/gt4py/next/ffront/ast_passes/simple_assign.py b/src/gt4py/next/ffront/ast_passes/simple_assign.py index 8b079bb8c1..966b234e79 100644 --- a/src/gt4py/next/ffront/ast_passes/simple_assign.py +++ b/src/gt4py/next/ffront/ast_passes/simple_assign.py @@ -61,11 +61,7 @@ class SingleAssignTargetPass(NodeYielder): ... a = b = 1 ... return a, b >>> - >>> print(ast.unparse( - ... SingleAssignTargetPass.apply( - ... ast.parse(inspect.getsource(foo)) - ... ) - ... )) + >>> print(ast.unparse(SingleAssignTargetPass.apply(ast.parse(inspect.getsource(foo))))) def foo(): __sat_tmp0 = 1 a = __sat_tmp0 diff --git a/src/gt4py/next/ffront/ast_passes/single_static_assign.py b/src/gt4py/next/ffront/ast_passes/single_static_assign.py index ee1e29a8e8..02545e360b 100644 --- a/src/gt4py/next/ffront/ast_passes/single_static_assign.py +++ b/src/gt4py/next/ffront/ast_passes/single_static_assign.py @@ -107,11 +107,7 @@ class SingleStaticAssignPass(ast.NodeTransformer): ... a = 3 + a ... return a - >>> print(ast.unparse( - ... SingleStaticAssignPass.apply( - ... ast.parse(inspect.getsource(foo)) - ... ) - ... )) + >>> print(ast.unparse(SingleStaticAssignPass.apply(ast.parse(inspect.getsource(foo))))) def foo(): aᐞ0 = 1 aᐞ1 = 2 + aᐞ0 diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index a556d0ea34..4e663a65b8 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -490,13 +490,13 @@ def itir(self): @typing.overload -def program(definition: types.FunctionType) -> Program: ... +def program(definition: types.FunctionType) -> Program: + ... @typing.overload -def program( - *, backend: Optional[ppi.ProgramExecutor] -) -> Callable[[types.FunctionType], Program]: ... 
+def program(*, backend: Optional[ppi.ProgramExecutor]) -> Callable[[types.FunctionType], Program]: + ... def program( @@ -510,17 +510,17 @@ def program( Examples: >>> @program # noqa: F821 # doctest: +SKIP - ... def program(in_field: Field[[TDim], float64], out_field: Field[[TDim], float64]): # noqa: F821 + ... def program(in_field: Field[[TDim], float64], out_field: Field[[TDim], float64]): # noqa: F821 ... field_op(in_field, out=out_field) - >>> program(in_field, out=out_field) # noqa: F821 # doctest: +SKIP + >>> program(in_field, out=out_field) # noqa: F821 # doctest: +SKIP >>> # the backend can optionally be passed if already decided >>> # not passing it will result in embedded execution by default >>> # the above is equivalent to >>> @program(backend="roundtrip") # noqa: F821 # doctest: +SKIP - ... def program(in_field: Field[[TDim], float64], out_field: Field[[TDim], float64]): # noqa: F821 + ... def program(in_field: Field[[TDim], float64], out_field: Field[[TDim], float64]): # noqa: F821 ... field_op(in_field, out=out_field) - >>> program(in_field, out=out_field) # noqa: F821 # doctest: +SKIP + >>> program(in_field, out=out_field) # noqa: F821 # doctest: +SKIP """ def program_inner(definition: types.FunctionType) -> Program: @@ -750,13 +750,15 @@ def __call__( @typing.overload def field_operator( definition: types.FunctionType, *, backend: Optional[ppi.ProgramExecutor] -) -> FieldOperator[foast.FieldOperator]: ... +) -> FieldOperator[foast.FieldOperator]: + ... @typing.overload def field_operator( *, backend: Optional[ppi.ProgramExecutor] -) -> Callable[[types.FunctionType], FieldOperator[foast.FieldOperator]]: ... +) -> Callable[[types.FunctionType], FieldOperator[foast.FieldOperator]]: + ... def field_operator(definition=None, *, backend=eve.NOTHING, grid_type=None): @@ -765,14 +767,14 @@ def field_operator(definition=None, *, backend=eve.NOTHING, grid_type=None): Examples: >>> @field_operator # doctest: +SKIP - ... 
def field_op(in_field: Field[[TDim], float64]) -> Field[[TDim], float64]: # noqa: F821 + ... def field_op(in_field: Field[[TDim], float64]) -> Field[[TDim], float64]: # noqa: F821 ... ... >>> field_op(in_field, out=out_field) # noqa: F821 # doctest: +SKIP >>> # the backend can optionally be passed if already decided >>> # not passing it will result in embedded execution by default >>> @field_operator(backend="roundtrip") # doctest: +SKIP - ... def field_op(in_field: Field[[TDim], float64]) -> Field[[TDim], float64]: # noqa: F821 + ... def field_op(in_field: Field[[TDim], float64]) -> Field[[TDim], float64]: # noqa: F821 ... ... """ @@ -793,7 +795,8 @@ def scan_operator( init: core_defs.Scalar, backend: Optional[str], grid_type: GridType, -) -> FieldOperator[foast.ScanOperator]: ... +) -> FieldOperator[foast.ScanOperator]: + ... @typing.overload @@ -804,7 +807,8 @@ def scan_operator( init: core_defs.Scalar, backend: Optional[str], grid_type: GridType, -) -> Callable[[types.FunctionType], FieldOperator[foast.ScanOperator]]: ... +) -> Callable[[types.FunctionType], FieldOperator[foast.ScanOperator]]: + ... def scan_operator( @@ -838,9 +842,9 @@ def scan_operator( >>> KDim = gtx.Dimension("K", kind=gtx.DimensionKind.VERTICAL) >>> inp = gtx.as_field([KDim], np.ones((10,))) >>> out = gtx.as_field([KDim], np.zeros((10,))) - >>> @gtx.scan_operator(axis=KDim, forward=True, init=0.) + >>> @gtx.scan_operator(axis=KDim, forward=True, init=0.0) ... def scan_operator(carry: float, val: float) -> float: - ... return carry+val + ... 
return carry + val >>> scan_operator(inp, out=out, offset_provider={}) # doctest: +SKIP >>> out.array() # doctest: +SKIP array([ 1., 2., 3., 4., 5., 6., 7., 8., 9., 10.]) diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index 493493f697..64704ef195 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -146,9 +146,7 @@ def __call__(self, mask: MaskT, true_field: FieldT, false_field: FieldT) -> _R: raise ValueError( "Tuple of different size not allowed." ) # TODO(havogt) find a strategy to unify parsing and embedded error messages - return tuple( - where(mask, t, f) for t, f in zip(true_field, false_field) - ) # type: ignore[return-value] # `tuple` is not `_R` + return tuple(where(mask, t, f) for t, f in zip(true_field, false_field)) # type: ignore[return-value] # `tuple` is not `_R` return super().__call__(mask, true_field, false_field) diff --git a/src/gt4py/next/ffront/field_operator_ast.py b/src/gt4py/next/ffront/field_operator_ast.py index 322a6df2e0..6b772227b2 100644 --- a/src/gt4py/next/ffront/field_operator_ast.py +++ b/src/gt4py/next/ffront/field_operator_ast.py @@ -153,7 +153,8 @@ class Call(Expr): kwargs: dict[str, Expr] -class Stmt(LocatedNode): ... +class Stmt(LocatedNode): + ... 
class Starred(Expr): diff --git a/src/gt4py/next/ffront/foast_introspection.py b/src/gt4py/next/ffront/foast_introspection.py index 404b99d1a0..08efa426ea 100644 --- a/src/gt4py/next/ffront/foast_introspection.py +++ b/src/gt4py/next/ffront/foast_introspection.py @@ -30,23 +30,23 @@ def deduce_stmt_return_kind(node: foast.Stmt) -> StmtReturnKind: Example with ``StmtReturnKind.UNCONDITIONAL_RETURN``:: if cond: - return 1 + return 1 else: - return 2 + return 2 Example with ``StmtReturnKind.CONDITIONAL_RETURN``:: if cond: - return 1 + return 1 else: - result = 2 + result = 2 Example with ``StmtReturnKind.NO_RETURN``:: if cond: - result = 1 + result = 1 else: - result = 2 + result = 2 """ if isinstance(node, foast.IfStmt): return_kinds = ( diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index 146a8ef400..a62d889db1 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -41,7 +41,9 @@ def with_altered_scalar_kind( >>> print(with_altered_scalar_kind(scalar_t, ts.ScalarKind.BOOL)) bool - >>> field_t = ts.FieldType(dims=[Dimension(value="I")], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)) + >>> field_t = ts.FieldType( + ... dims=[Dimension(value="I")], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64) + ... ) >>> print(with_altered_scalar_kind(field_t, ts.ScalarKind.FLOAT32)) Field[[I], float32] """ @@ -67,9 +69,14 @@ def construct_tuple_type( Examples: --------- >>> from gt4py.next import Dimension - >>> mask_type = ts.FieldType(dims=[Dimension(value="I")], dtype=ts.ScalarType(kind=ts.ScalarKind.BOOL)) + >>> mask_type = ts.FieldType( + ... dims=[Dimension(value="I")], dtype=ts.ScalarType(kind=ts.ScalarKind.BOOL) + ... 
) >>> true_branch_types = [ts.ScalarType(kind=ts.ScalarKind), ts.ScalarType(kind=ts.ScalarKind)] - >>> false_branch_types = [ts.FieldType(dims=[Dimension(value="I")], dtype=ts.ScalarType(kind=ts.ScalarKind)), ts.ScalarType(kind=ts.ScalarKind)] + >>> false_branch_types = [ + ... ts.FieldType(dims=[Dimension(value="I")], dtype=ts.ScalarType(kind=ts.ScalarKind)), + ... ts.ScalarType(kind=ts.ScalarKind), + ... ] >>> print(construct_tuple_type(true_branch_types, false_branch_types, mask_type)) [FieldType(dims=[Dimension(value='I', kind=)], dtype=ScalarType(kind=, shape=None)), FieldType(dims=[Dimension(value='I', kind=)], dtype=ScalarType(kind=, shape=None))] """ @@ -104,18 +111,20 @@ def promote_to_mask_type( >>> dtype = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) >>> promote_to_mask_type(ts.FieldType(dims=[I, J], dtype=bool_type), ts.ScalarType(kind=dtype)) FieldType(dims=[Dimension(value='I', kind=), Dimension(value='J', kind=)], dtype=ScalarType(kind=ScalarType(kind=, shape=None), shape=None)) - >>> promote_to_mask_type(ts.FieldType(dims=[I, J], dtype=bool_type), ts.FieldType(dims=[I], dtype=dtype)) + >>> promote_to_mask_type( + ... ts.FieldType(dims=[I, J], dtype=bool_type), ts.FieldType(dims=[I], dtype=dtype) + ... ) FieldType(dims=[Dimension(value='I', kind=), Dimension(value='J', kind=)], dtype=ScalarType(kind=, shape=None)) - >>> promote_to_mask_type(ts.FieldType(dims=[I], dtype=bool_type), ts.FieldType(dims=[I,J], dtype=dtype)) + >>> promote_to_mask_type( + ... ts.FieldType(dims=[I], dtype=bool_type), ts.FieldType(dims=[I, J], dtype=dtype) + ... 
) FieldType(dims=[Dimension(value='I', kind=), Dimension(value='J', kind=)], dtype=ScalarType(kind=, shape=None)) """ if isinstance(input_type, ts.ScalarType) or not all( item in input_type.dims for item in mask_type.dims ): return_dtype = input_type.dtype if isinstance(input_type, ts.FieldType) else input_type - return type_info.promote( - input_type, ts.FieldType(dims=mask_type.dims, dtype=return_dtype) - ) # type: ignore + return type_info.promote(input_type, ts.FieldType(dims=mask_type.dims, dtype=return_dtype)) # type: ignore else: return input_type @@ -233,8 +242,9 @@ class FieldOperatorTypeDeduction(traits.VisitorWithSymbolTableTrait, NodeTransla DeferredType(constraint=None) >>> typed_fieldop = FieldOperatorTypeDeduction.apply(untyped_fieldop) - >>> assert typed_fieldop.body.stmts[0].value.type == ts.FieldType(dtype=ts.ScalarType( - ... kind=ts.ScalarKind.FLOAT64), dims=[IDim]) + >>> assert typed_fieldop.body.stmts[0].value.type == ts.FieldType( + ... dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64), dims=[IDim] + ... ) """ @classmethod @@ -438,9 +448,9 @@ def visit_IfStmt(self, node: foast.IfStmt, **kwargs) -> foast.IfStmt: f"got types '{true_type}' and '{false_type}.", ) # TODO: properly patch symtable (new node?) - symtable[sym].type = new_node.annex.propagated_symbols[sym].type = ( - new_true_branch.annex.symtable[sym].type - ) + symtable[sym].type = ( + new_node.annex.propagated_symbols[sym].type + ) = new_true_branch.annex.symtable[sym].type return new_node diff --git a/src/gt4py/next/ffront/foast_to_itir.py b/src/gt4py/next/ffront/foast_to_itir.py index 2e5c158c23..db32ffac57 100644 --- a/src/gt4py/next/ffront/foast_to_itir.py +++ b/src/gt4py/next/ffront/foast_to_itir.py @@ -54,7 +54,7 @@ class FieldOperatorLowering(PreserveLocationVisitor, NodeTranslator): >>> >>> IDim = Dimension("IDim") >>> def fieldop(inp: Field[[IDim], "float64"]): - ... return inp + ... 
return inp >>> >>> parsed = FieldOperatorParser.apply_to_function(fieldop) >>> lowered = FieldOperatorLowering.apply(parsed) @@ -62,7 +62,7 @@ class FieldOperatorLowering(PreserveLocationVisitor, NodeTranslator): >>> lowered.id SymbolName('fieldop') - >>> lowered.params # doctest: +ELLIPSIS + >>> lowered.params # doctest: +ELLIPSIS [Sym(id=SymbolName('inp'), kind='Iterator', dtype=('float64', False))] """ @@ -440,4 +440,5 @@ def _map(self, op, *args, **kwargs): return im.promote_to_lifted_stencil(im.call(op))(*lowered_args) -class FieldOperatorLoweringError(Exception): ... +class FieldOperatorLoweringError(Exception): + ... diff --git a/src/gt4py/next/ffront/func_to_foast.py b/src/gt4py/next/ffront/func_to_foast.py index 0831fc3bb2..1ff3acc205 100644 --- a/src/gt4py/next/ffront/func_to_foast.py +++ b/src/gt4py/next/ffront/func_to_foast.py @@ -66,14 +66,16 @@ class FieldOperatorParser(DialectParser[foast.FunctionDefinition]): If a syntax error is encountered, it will point to the location in the source code. >>> def wrong_syntax(inp: Field[[IDim], int]): - ... for i in [1, 2, 3]: # for is not part of the field operator syntax + ... for i in [1, 2, 3]: # for is not part of the field operator syntax ... tmp = inp ... return tmp >>> - >>> try: # doctest: +ELLIPSIS + >>> try: # doctest: +ELLIPSIS ... FieldOperatorParser.apply_to_function(wrong_syntax) ... except errors.DSLError as err: - ... print(f"Error at [{err.location.line}, {err.location.column}] in {err.location.filename})") + ... print( + ... f"Error at [{err.location.line}, {err.location.column}] in {err.location.filename})" + ... 
) Error at [2, 5] in ...func_to_foast.FieldOperatorParser[...]>) """ diff --git a/src/gt4py/next/ffront/lowering_utils.py b/src/gt4py/next/ffront/lowering_utils.py index f2b221083a..bf8b87b8f0 100644 --- a/src/gt4py/next/ffront/lowering_utils.py +++ b/src/gt4py/next/ffront/lowering_utils.py @@ -30,8 +30,14 @@ def to_tuples_of_iterator(expr: itir.Expr | str, arg_type: ts.TypeSpec): Supports arbitrary nesting. - >>> print(to_tuples_of_iterator("arg", ts.TupleType(types=[ts.FieldType(dims=[], - ... dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT32))]))) # doctest: +ELLIPSIS + >>> print( + ... to_tuples_of_iterator( + ... "arg", + ... ts.TupleType( + ... types=[ts.FieldType(dims=[], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT32))] + ... ), + ... ) + ... ) # doctest: +ELLIPSIS (λ(__toi_...) → {(↑(λ(it) → (·it)[0]))(__toi_...)})(arg) """ param = f"__toi_{_expr_hash(expr)}" @@ -56,8 +62,14 @@ def to_iterator_of_tuples(expr: itir.Expr | str, arg_type: ts.TypeSpec): Supports arbitrary nesting. - >>> print(to_iterator_of_tuples("arg", ts.TupleType(types=[ts.FieldType(dims=[], - ... dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT32))]))) # doctest: +ELLIPSIS + >>> print( + ... to_iterator_of_tuples( + ... "arg", + ... ts.TupleType( + ... types=[ts.FieldType(dims=[], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT32))] + ... ), + ... ) + ... ) # doctest: +ELLIPSIS (λ(__iot_...) 
→ (↑(λ(__iot_el_0) → {·__iot_el_0}))(__iot_...[0]))(arg) """ param = f"__iot_{_expr_hash(expr)}" @@ -66,7 +78,10 @@ def to_iterator_of_tuples(expr: itir.Expr | str, arg_type: ts.TypeSpec): ti_ffront.promote_scalars_to_zero_dim_field(type_) for type_ in type_info.primitive_constituents(arg_type) ] - assert all(isinstance(type_, ts.FieldType) and type_.dims == type_constituents[0].dims for type_ in type_constituents) # type: ignore[attr-defined] # ensure by assert above + assert all( + isinstance(type_, ts.FieldType) and type_.dims == type_constituents[0].dims + for type_ in type_constituents + ) # type: ignore[attr-defined] # ensure by assert above def fun(_, path): param_name = "__iot_el" diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index 8be9309630..a636e8822f 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -57,18 +57,19 @@ class ProgramLowering( >>> IDim = Dimension("IDim") >>> >>> def fieldop(inp: Field[[IDim], "float64"]) -> Field[[IDim], "float64"]: - ... ... + ... ... >>> def program(inp: Field[[IDim], "float64"], out: Field[[IDim], "float64"]): - ... fieldop(inp, out=out) + ... fieldop(inp, out=out) >>> >>> parsed = ProgramParser.apply_to_function(program) # doctest: +SKIP >>> fieldop_def = ir.FunctionDefinition( ... id="fieldop", ... params=[ir.Sym(id="inp")], - ... expr=ir.FunCall(fun=ir.SymRef(id="deref"), pos_only_args=[ir.SymRef(id="inp")]) + ... expr=ir.FunCall(fun=ir.SymRef(id="deref"), pos_only_args=[ir.SymRef(id="inp")]), + ... ) # doctest: +SKIP + >>> lowered = ProgramLowering.apply( + ... parsed, [fieldop_def], grid_type=GridType.CARTESIAN ... ) # doctest: +SKIP - >>> lowered = ProgramLowering.apply(parsed, [fieldop_def], - ... 
grid_type=GridType.CARTESIAN) # doctest: +SKIP >>> type(lowered) # doctest: +SKIP >>> lowered.id # doctest: +SKIP diff --git a/src/gt4py/next/ffront/program_ast.py b/src/gt4py/next/ffront/program_ast.py index 4ff8265f70..14151fc243 100644 --- a/src/gt4py/next/ffront/program_ast.py +++ b/src/gt4py/next/ffront/program_ast.py @@ -93,7 +93,8 @@ class Slice(Expr): step: Literal[None] -class Stmt(LocatedNode): ... +class Stmt(LocatedNode): + ... class Program(LocatedNode, SymbolTableTrait): diff --git a/src/gt4py/next/ffront/source_utils.py b/src/gt4py/next/ffront/source_utils.py index baf3037d5e..d6037189f3 100644 --- a/src/gt4py/next/ffront/source_utils.py +++ b/src/gt4py/next/ffront/source_utils.py @@ -106,11 +106,11 @@ class SourceDefinition: >>> def foo(a): ... return a >>> src_def = SourceDefinition.from_function(foo) - >>> print(src_def) # doctest:+ELLIPSIS + >>> print(src_def) # doctest:+ELLIPSIS SourceDefinition(source='def foo(a):...', filename='...', line_offset=0, column_offset=0) >>> source, filename, starting_line = src_def - >>> print(source) # doctest:+ELLIPSIS + >>> print(source) # doctest:+ELLIPSIS def foo(a): return a ... diff --git a/src/gt4py/next/ffront/type_info.py b/src/gt4py/next/ffront/type_info.py index c25b7dd829..2072c4164a 100644 --- a/src/gt4py/next/ffront/type_info.py +++ b/src/gt4py/next/ffront/type_info.py @@ -171,7 +171,7 @@ def _scan_param_promotion(param: ts.TypeSpec, arg: ts.TypeSpec) -> ts.FieldType -------- >>> _scan_param_promotion( ... ts.ScalarType(kind=ts.ScalarKind.INT64), - ... ts.FieldType(dims=[common.Dimension("I")], dtype=ts.ScalarKind.FLOAT64) + ... ts.FieldType(dims=[common.Dimension("I")], dtype=ts.ScalarKind.FLOAT64), ... 
) FieldType(dims=[Dimension(value='I', kind=)], dtype=ScalarType(kind=, shape=None)) """ @@ -180,7 +180,9 @@ def _as_field(dtype: ts.TypeSpec, path: tuple[int, ...]) -> ts.FieldType: assert isinstance(dtype, ts.ScalarType) try: el_type = reduce( - lambda type_, idx: type_.types[idx], path, arg # type: ignore[attr-defined] + lambda type_, idx: type_.types[idx], + path, + arg, # type: ignore[attr-defined] ) return ts.FieldType(dims=type_info.extract_dims(el_type), dtype=dtype) except (IndexError, AttributeError): diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index a45b81a773..4855ae0096 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -80,7 +80,8 @@ ) -class SparseTag(Tag): ... +class SparseTag(Tag): + ... class NeighborTableOffsetProvider: @@ -155,11 +156,14 @@ class ItIterator(Protocol): `ItIterator` to avoid name clashes with `Iterator` from `typing` and `collections.abc`. """ - def shift(self, *offsets: OffsetPart) -> ItIterator: ... + def shift(self, *offsets: OffsetPart) -> ItIterator: + ... - def can_deref(self) -> bool: ... + def can_deref(self) -> bool: + ... - def deref(self) -> Any: ... + def deref(self) -> Any: + ... @runtime_checkable @@ -168,11 +172,13 @@ class LocatedField(Protocol): @property @abc.abstractmethod - def dims(self) -> tuple[common.Dimension, ...]: ... + def dims(self) -> tuple[common.Dimension, ...]: + ... # TODO(havogt): define generic Protocol to provide a concrete return type @abc.abstractmethod - def field_getitem(self, indices: NamedFieldIndices) -> Any: ... + def field_getitem(self, indices: NamedFieldIndices) -> Any: + ... @property def __gt_origin__(self) -> tuple[int, ...]: @@ -185,7 +191,8 @@ class MutableLocatedField(LocatedField, Protocol): # TODO(havogt): define generic Protocol to provide a concrete return type @abc.abstractmethod - def field_setitem(self, indices: NamedFieldIndices, value: Any) -> None: ... 
+ def field_setitem(self, indices: NamedFieldIndices, value: Any) -> None: + ... #: Column range used in column mode (`column_axis != None`) in the current closure execution context. @@ -573,7 +580,7 @@ def execute_shift( def _is_list_of_complete_offsets( - complete_offsets: list[tuple[Any, Any]] + complete_offsets: list[tuple[Any, Any]], ) -> TypeGuard[list[CompleteOffset]]: return all( isinstance(tag, Tag) and isinstance(offset, (int, np.integer)) @@ -702,7 +709,8 @@ def _make_tuple( named_indices: NamedFieldIndices, *, column_axis: Tag, -) -> tuple[tuple | Column, ...]: ... +) -> tuple[tuple | Column, ...]: + ... @overload @@ -718,7 +726,8 @@ def _make_tuple( @overload def _make_tuple( field_or_tuple: LocatedField, named_indices: NamedFieldIndices, *, column_axis: Tag -) -> Column: ... +) -> Column: + ... @overload @@ -727,7 +736,8 @@ def _make_tuple( named_indices: NamedFieldIndices, *, column_axis: Literal[None] = None, -) -> npt.DTypeLike | Undefined: ... +) -> npt.DTypeLike | Undefined: + ... def _make_tuple( @@ -968,11 +978,13 @@ def get_ordered_indices(axes: Iterable[Axis], pos: NamedFieldIndices) -> tuple[F @overload -def _shift_range(range_or_index: range, offset: int) -> slice: ... +def _shift_range(range_or_index: range, offset: int) -> slice: + ... @overload -def _shift_range(range_or_index: common.IntIndex, offset: int) -> common.IntIndex: ... +def _shift_range(range_or_index: common.IntIndex, offset: int) -> common.IntIndex: + ... def _shift_range(range_or_index: range | common.IntIndex, offset: int) -> ArrayIndex: @@ -986,11 +998,13 @@ def _shift_range(range_or_index: range | common.IntIndex, offset: int) -> ArrayI @overload -def _range2slice(r: range) -> slice: ... +def _range2slice(r: range) -> slice: + ... @overload -def _range2slice(r: common.IntIndex) -> common.IntIndex: ... +def _range2slice(r: common.IntIndex) -> common.IntIndex: + ... 
def _range2slice(r: range | common.IntIndex) -> slice | common.IntIndex: @@ -1294,7 +1308,8 @@ def impl(it: ItIterator) -> ItIterator: DT = TypeVar("DT") -class _List(tuple, Generic[DT]): ... +class _List(tuple, Generic[DT]): + ... @dataclasses.dataclass(frozen=True) @@ -1429,7 +1444,8 @@ def is_tuple_of_field(field) -> bool: ) -class TupleFieldMeta(type): ... +class TupleFieldMeta(type): + ... class TupleField(metaclass=TupleFieldMeta): diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index 10caecc591..37abbec9e7 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -45,9 +45,9 @@ class Sym(Node): # helper # TODO(tehrengruber): Revisit. Using strings is a workaround to avoid coupling with the # type inference. kind: typing.Literal["Iterator", "Value", None] = None - dtype: Optional[tuple[str, bool]] = ( - None # format: name of primitive type, boolean indicating if it is a list - ) + dtype: Optional[ + tuple[str, bool] + ] = None # format: name of primitive type, boolean indicating if it is a list @datamodels.validator("kind") def _kind_validator(self: datamodels.DataModelTP, attribute: datamodels.Attribute, value: str): @@ -63,7 +63,8 @@ def _dtype_validator(self: datamodels.DataModelTP, attribute: datamodels.Attribu @noninstantiable -class Expr(Node): ... +class Expr(Node): + ... class Literal(Expr): diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index f6655e9d41..b504a3bfe7 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -243,17 +243,17 @@ class let: -------- >>> str(let("a", "b")("a")) # doctest: +ELLIPSIS '(λ(a) → a)(b)' - >>> str(let(("a", 1), - ... ("b", 2) - ... )(plus("a", "b"))) + >>> str(let(("a", 1), ("b", 2))(plus("a", "b"))) '(λ(a, b) → a + b)(1, 2)' """ @typing.overload - def __init__(self, var: str | itir.Sym, init_form: itir.Expr | str): ... 
+ def __init__(self, var: str | itir.Sym, init_form: itir.Expr | str): + ... @typing.overload - def __init__(self, *args: Iterable[tuple[str | itir.Sym, itir.Expr | str]]): ... + def __init__(self, *args: Iterable[tuple[str | itir.Sym, itir.Expr | str]]): + ... def __init__(self, *args): if all(isinstance(arg, tuple) and len(arg) == 2 for arg in args): @@ -301,7 +301,7 @@ def literal_from_value(val: core_defs.Scalar) -> itir.Literal: """ Make a literal node from a value. - >>> literal_from_value(1.) + >>> literal_from_value(1.0) Literal(value='1.0', type='float64') >>> literal_from_value(1) Literal(value='1', type='int32') diff --git a/src/gt4py/next/iterator/runtime.py b/src/gt4py/next/iterator/runtime.py index 5de4839b55..e12ae84dbc 100644 --- a/src/gt4py/next/iterator/runtime.py +++ b/src/gt4py/next/iterator/runtime.py @@ -43,10 +43,12 @@ def offset(value): return Offset(value) -class CartesianDomain(dict): ... +class CartesianDomain(dict): + ... -class UnstructuredDomain(dict): ... +class UnstructuredDomain(dict): + ... # dependency inversion, register fendef for embedded execution or for tracing/parsing here diff --git a/src/gt4py/next/iterator/transforms/cse.py b/src/gt4py/next/iterator/transforms/cse.py index f9cf272c45..32714232a6 100644 --- a/src/gt4py/next/iterator/transforms/cse.py +++ b/src/gt4py/next/iterator/transforms/cse.py @@ -244,23 +244,28 @@ def extract_subexpression( >>> expr = im.plus(im.plus("x", "y"), im.plus(im.plus("x", "y"), "z")) >>> predicate = lambda subexpr, num_occurences: num_occurences > 1 >>> new_expr, extracted_subexprs, _ = extract_subexpression( - ... expr, predicate, UIDGenerator(prefix="_subexpr")) + ... expr, predicate, UIDGenerator(prefix="_subexpr") + ... ) >>> print(new_expr) _subexpr_1 + (_subexpr_1 + z) >>> for sym, subexpr in extracted_subexprs.items(): - ... print(f"`{sym}`: `{subexpr}`") + ... 
print(f"`{sym}`: `{subexpr}`") `_subexpr_1`: `x + y` The order of the extraction can be configured using `deepest_expr_first`. By default, the nodes closer to the root are eliminated first: - >>> expr = im.plus(im.plus(im.plus("x", "y"), im.plus("x", "y")), im.plus(im.plus("x", "y"), im.plus("x", "y"))) - >>> new_expr, extracted_subexprs, ignored_children = extract_subexpression(expr, predicate, - ... UIDGenerator(prefix="_subexpr"), deepest_expr_first=False) + >>> expr = im.plus( + ... im.plus(im.plus("x", "y"), im.plus("x", "y")), + ... im.plus(im.plus("x", "y"), im.plus("x", "y")), + ... ) + >>> new_expr, extracted_subexprs, ignored_children = extract_subexpression( + ... expr, predicate, UIDGenerator(prefix="_subexpr"), deepest_expr_first=False + ... ) >>> print(new_expr) _subexpr_1 + _subexpr_1 >>> for sym, subexpr in extracted_subexprs.items(): - ... print(f"`{sym}`: `{subexpr}`") + ... print(f"`{sym}`: `{subexpr}`") `_subexpr_1`: `x + y + (x + y)` Since `(x+y)` is a child of one of the expressions it is ignored: @@ -270,13 +275,21 @@ def extract_subexpression( Setting `deepest_expr_first` will extract nodes deeper in the tree first: - >>> expr = im.plus(im.plus(im.plus("x", "y"), im.plus("x", "y")), im.plus(im.plus("x", "y"), im.plus("x", "y"))) - >>> new_expr, extracted_subexprs, _ = extract_subexpression(expr, predicate, - ... UIDGenerator(prefix="_subexpr"), once_only=True, deepest_expr_first=True) + >>> expr = im.plus( + ... im.plus(im.plus("x", "y"), im.plus("x", "y")), + ... im.plus(im.plus("x", "y"), im.plus("x", "y")), + ... ) + >>> new_expr, extracted_subexprs, _ = extract_subexpression( + ... expr, + ... predicate, + ... UIDGenerator(prefix="_subexpr"), + ... once_only=True, + ... deepest_expr_first=True, + ... ) >>> print(new_expr) _subexpr_1 + _subexpr_1 + (_subexpr_1 + _subexpr_1) >>> for sym, subexpr in extracted_subexprs.items(): - ... print(f"`{sym}`: `{subexpr}`") + ... 
print(f"`{sym}`: `{subexpr}`") `_subexpr_1`: `x + y` Note that this requires `once_only` to be set right now. diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index 4f4fd053b2..1fb6340083 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -318,7 +318,9 @@ def always_extract_heuristics(_): domain=AUTO_DOMAIN, stencil=stencil, output=im.ref(tmp_sym.id), - inputs=[closure_param_arg_mapping[param.id] for param in lift_expr.args], # type: ignore[attr-defined] + inputs=[ + closure_param_arg_mapping[param.id] for param in lift_expr.args + ], location=current_closure.location, ) ) diff --git a/src/gt4py/next/otf/compilation/compiler.py b/src/gt4py/next/otf/compilation/compiler.py index fb85d074df..c17ff1a759 100644 --- a/src/gt4py/next/otf/compilation/compiler.py +++ b/src/gt4py/next/otf/compilation/compiler.py @@ -44,7 +44,8 @@ def __call__( self, source: stages.CompilableSource[SrcL, LS, TgtL], cache_lifetime: config.BuildCacheLifetime, - ) -> stages.BuildSystemProject[SrcL, LS, TgtL]: ... + ) -> stages.BuildSystemProject[SrcL, LS, TgtL]: + ... @dataclasses.dataclass(frozen=True) @@ -95,4 +96,5 @@ class Meta: model = Compiler -class CompilationError(RuntimeError): ... +class CompilationError(RuntimeError): + ... diff --git a/src/gt4py/next/otf/languages.py b/src/gt4py/next/otf/languages.py index 2397878271..b0d01d91ab 100644 --- a/src/gt4py/next/otf/languages.py +++ b/src/gt4py/next/otf/languages.py @@ -57,7 +57,8 @@ class Python(LanguageTag): ... -class NanobindSrcL(LanguageTag): ... +class NanobindSrcL(LanguageTag): + ... 
class Cpp(NanobindSrcL): diff --git a/src/gt4py/next/otf/stages.py b/src/gt4py/next/otf/stages.py index bd7f59e7aa..106015ccaf 100644 --- a/src/gt4py/next/otf/stages.py +++ b/src/gt4py/next/otf/stages.py @@ -107,13 +107,15 @@ class BuildSystemProject(Protocol[SrcL_co, SettingT_co, TgtL_co]): and is not responsible for importing the results into Python. """ - def build(self) -> None: ... + def build(self) -> None: + ... class CompiledProgram(Protocol): """Executable python representation of a program.""" - def __call__(self, *args, **kwargs) -> None: ... + def __call__(self, *args, **kwargs) -> None: + ... def _unique_libs(*args: interface.LibraryDependency) -> tuple[interface.LibraryDependency, ...]: @@ -122,8 +124,14 @@ def _unique_libs(*args: interface.LibraryDependency) -> tuple[interface.LibraryD Examples: --------- - >>> libs_a = (interface.LibraryDependency("foo", "1.2.3"), interface.LibraryDependency("common", "1.0.0")) - >>> libs_b = (interface.LibraryDependency("common", "1.0.0"), interface.LibraryDependency("bar", "1.2.3")) + >>> libs_a = ( + ... interface.LibraryDependency("foo", "1.2.3"), + ... interface.LibraryDependency("common", "1.0.0"), + ... ) + >>> libs_b = ( + ... interface.LibraryDependency("common", "1.0.0"), + ... interface.LibraryDependency("bar", "1.2.3"), + ... ) >>> _unique_libs(*libs_a, *libs_b) (LibraryDependency(name='foo', version='1.2.3'), LibraryDependency(name='common', version='1.0.0'), LibraryDependency(name='bar', version='1.2.3')) """ diff --git a/src/gt4py/next/otf/step_types.py b/src/gt4py/next/otf/step_types.py index 43def259ab..5eeb5c495b 100644 --- a/src/gt4py/next/otf/step_types.py +++ b/src/gt4py/next/otf/step_types.py @@ -46,7 +46,8 @@ class BindingStep(Protocol[SrcL, LS, TgtL]): def __call__( self, program_source: stages.ProgramSource[SrcL, LS] - ) -> stages.CompilableSource[SrcL, LS, TgtL]: ... + ) -> stages.CompilableSource[SrcL, LS, TgtL]: + ... 
class CompilationStep( @@ -55,6 +56,5 @@ class CompilationStep( ): """Compile program source code and bindings into a python callable (CompilableSource -> CompiledProgram).""" - def __call__( - self, source: stages.CompilableSource[SrcL, LS, TgtL] - ) -> stages.CompiledProgram: ... + def __call__(self, source: stages.CompilableSource[SrcL, LS, TgtL]) -> stages.CompiledProgram: + ... diff --git a/src/gt4py/next/otf/workflow.py b/src/gt4py/next/otf/workflow.py index 4bdb4bbb41..a47d02eb05 100644 --- a/src/gt4py/next/otf/workflow.py +++ b/src/gt4py/next/otf/workflow.py @@ -39,10 +39,10 @@ def make_step(function: Workflow[StartT, EndT]) -> ChainableWorkflowMixin[StartT --------- >>> @make_step ... def times_two(x: int) -> int: - ... return x * 2 + ... return x * 2 >>> def stringify(x: int) -> str: - ... return str(x) + ... return str(x) >>> # create a workflow int -> int -> str >>> times_two.chain(stringify)(3) @@ -61,7 +61,8 @@ class Workflow(Protocol[StartT_contra, EndT_co]): - take a single input argument """ - def __call__(self, inp: StartT_contra) -> EndT_co: ... + def __call__(self, inp: StartT_contra) -> EndT_co: + ... class ReplaceEnabledWorkflowMixin(Workflow[StartT_contra, EndT_co], Protocol): @@ -102,25 +103,21 @@ class NamedStepSequence( >>> import dataclasses >>> def parse(x: str) -> int: - ... return int(x) + ... return int(x) >>> def plus_half(x: int) -> float: - ... return x + 0.5 + ... return x + 0.5 >>> def stringify(x: float) -> str: - ... return str(x) + ... return str(x) >>> @dataclasses.dataclass(frozen=True) ... class ParseOpPrint(NamedStepSequence[str, str]): - ... parse: Workflow[str, int] - ... op: Workflow[int, float] - ... print: Workflow[float, str] + ... parse: Workflow[str, int] + ... op: Workflow[int, float] + ... print: Workflow[float, str] - >>> pop = ParseOpPrint( - ... parse=parse, - ... op=plus_half, - ... print=stringify - ... 
) + >>> pop = ParseOpPrint(parse=parse, op=plus_half, print=stringify) >>> pop.step_order ['parse', 'op', 'print'] @@ -129,7 +126,7 @@ class NamedStepSequence( '73.5' >>> def plus_tenth(x: int) -> float: - ... return x + 0.1 + ... return x + 0.1 >>> pop.replace(op=plus_tenth)(73) @@ -169,13 +166,13 @@ class StepSequence(ChainableWorkflowMixin[StartT, EndT]): Examples: --------- >>> def plus_one(x: int) -> int: - ... return x + 1 + ... return x + 1 >>> def plus_half(x: int) -> float: - ... return x + 0.5 + ... return x + 0.5 >>> def stringify(x: float) -> str: - ... return str(x) + ... return str(x) >>> StepSequence.start(plus_one).chain(plus_half).chain(stringify)(73) '74.5' @@ -222,8 +219,8 @@ class CachedStep( Examples: --------- >>> def heavy_computation(x: int) -> int: - ... print("This might take a while...") - ... return x + ... print("This might take a while...") + ... return x >>> cached_step = CachedStep(step=heavy_computation) @@ -241,9 +238,7 @@ class CachedStep( """ step: Workflow[StartT, EndT] - hash_function: Callable[[StartT], HashT] = dataclasses.field( - default=hash - ) # type: ignore[assignment] + hash_function: Callable[[StartT], HashT] = dataclasses.field(default=hash) # type: ignore[assignment] _cache: dict[HashT, EndT] = dataclasses.field(repr=False, init=False, default_factory=dict) diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_im_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_im_ir.py index a62f50fc44..f0843919fe 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_im_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_im_ir.py @@ -21,7 +21,8 @@ from gt4py.next.program_processors.codegens.gtfn.gtfn_ir_common import Expr, Sym, SymRef -class Stmt(Node): ... +class Stmt(Node): + ... class AssignStmt(Stmt): @@ -34,7 +35,8 @@ class InitStmt(AssignStmt): init_type: str = "auto" -class EmptyListInitializer(Expr): ... +class EmptyListInitializer(Expr): + ... 
class Conditional(Stmt): diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir_common.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir_common.py index cb9aeffb90..79d4c18828 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir_common.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir_common.py @@ -25,7 +25,8 @@ class Sym(Node): # helper id: Coerced[SymbolName] # noqa: A003 -class Expr(Node): ... +class Expr(Node): + ... class SymRef(Expr): diff --git a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py index 842080f8ae..98eff62d60 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py @@ -489,7 +489,7 @@ def visit_StencilClosure( @staticmethod def _merge_scans( - executions: list[Union[StencilExecution, ScanExecution]] + executions: list[Union[StencilExecution, ScanExecution]], ) -> list[Union[StencilExecution, ScanExecution]]: def merge(a: ScanExecution, b: ScanExecution) -> ScanExecution: assert a.backend == b.backend diff --git a/src/gt4py/next/program_processors/processor_interface.py b/src/gt4py/next/program_processors/processor_interface.py index 0c280202b8..9bf4c623c5 100644 --- a/src/gt4py/next/program_processors/processor_interface.py +++ b/src/gt4py/next/program_processors/processor_interface.py @@ -40,12 +40,14 @@ class ProgramProcessorCallable(Protocol[OutputT]): - def __call__(self, program: itir.FencilDefinition, *args, **kwargs) -> OutputT: ... + def __call__(self, program: itir.FencilDefinition, *args, **kwargs) -> OutputT: + ... class ProgramProcessor(ProgramProcessorCallable[OutputT], Protocol[OutputT, ProcessorKindT]): @property - def kind(self) -> type[ProcessorKindT]: ... + def kind(self) -> type[ProcessorKindT]: + ... 
class ProgramFormatter(ProgramProcessor[str, "ProgramFormatter"], Protocol): @@ -232,12 +234,14 @@ class ProgramBackend( ProgramProcessor[None, "ProgramExecutor"], next_allocators.FieldBufferAllocatorFactoryProtocol[core_defs.DeviceTypeT], Protocol[core_defs.DeviceTypeT], -): ... +): + ... def is_program_backend(obj: Callable) -> TypeGuard[ProgramBackend]: return is_processor_kind( - obj, ProgramExecutor # type: ignore[type-abstract] # ProgramExecutor is abstract + obj, + ProgramExecutor, # type: ignore[type-abstract] # ProgramExecutor is abstract ) and next_allocators.is_field_allocator_factory(obj) @@ -245,5 +249,6 @@ def is_program_backend_for( obj: Callable, device: core_defs.DeviceTypeT ) -> TypeGuard[ProgramBackend[core_defs.DeviceTypeT]]: return is_processor_kind( - obj, ProgramExecutor # type: ignore[type-abstract] # ProgramExecutor is abstract + obj, + ProgramExecutor, # type: ignore[type-abstract] # ProgramExecutor is abstract ) and next_allocators.is_field_allocator_factory_for(obj, device) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index 1263cff502..386a36d328 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -298,10 +298,13 @@ def build_sdfg_from_itir( for nested_sdfg in sdfg.all_sdfgs_recursive(): if not nested_sdfg.debuginfo: - _, frameinfo = warnings.warn( - f"{nested_sdfg.label} does not have debuginfo. Consider adding them in the corresponding nested sdfg." - ), getframeinfo( - currentframe() # type: ignore + _, frameinfo = ( + warnings.warn( + f"{nested_sdfg.label} does not have debuginfo. Consider adding them in the corresponding nested sdfg." 
+ ), + getframeinfo( + currentframe() # type: ignore + ), ) nested_sdfg.debuginfo = dace.dtypes.DebugInfo( start_line=frameinfo.lineno, diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index 04af4a5283..8f77d4af75 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -130,9 +130,9 @@ class Params: translation = factory.SubFactory( gtfn_module.GTFNTranslationStepFactory, device_type=factory.SelfAttribute("..device_type") ) - bindings: workflow.Workflow[stages.ProgramSource, stages.CompilableSource] = ( - nanobind.bind_source - ) + bindings: workflow.Workflow[ + stages.ProgramSource, stages.CompilableSource + ] = nanobind.bind_source compilation = factory.SubFactory( compiler.CompilerFactory, cache_lifetime=factory.LazyFunction(lambda: config.BUILD_CACHE_LIFETIME), diff --git a/src/gt4py/next/type_inference.py b/src/gt4py/next/type_inference.py index 10ae524451..9b5d9070e3 100644 --- a/src/gt4py/next/type_inference.py +++ b/src/gt4py/next/type_inference.py @@ -94,11 +94,13 @@ def visit_TypeVar(self, node: V, *, index_map: dict[int, int]) -> V: @typing.overload -def freshen(dtypes: list[T]) -> list[T]: ... +def freshen(dtypes: list[T]) -> list[T]: + ... @typing.overload -def freshen(dtypes: T) -> T: ... +def freshen(dtypes: T) -> T: + ... def freshen(dtypes: list[T] | T) -> list[T] | T: @@ -323,13 +325,15 @@ def _handle_constraint(self, constraint: tuple[_Box, _Box]) -> bool: @typing.overload def unify( dtypes: list[Type], constraints: set[tuple[Type, Type]] -) -> tuple[list[Type], list[tuple[Type, Type]]]: ... +) -> tuple[list[Type], list[tuple[Type, Type]]]: + ... @typing.overload def unify( dtypes: Type, constraints: set[tuple[Type, Type]] -) -> tuple[Type, list[tuple[Type, Type]]]: ... +) -> tuple[Type, list[tuple[Type, Type]]]: + ... 
def unify( diff --git a/src/gt4py/next/type_system/type_info.py b/src/gt4py/next/type_system/type_info.py index 5cfb901ff1..32adbf6e8b 100644 --- a/src/gt4py/next/type_system/type_info.py +++ b/src/gt4py/next/type_system/type_info.py @@ -29,7 +29,7 @@ def _number_to_ordinal_number(number: int) -> str: Convert number into ordinal number. >>> for i in range(0, 5): - ... print(_number_to_ordinal_number(i)) + ... print(_number_to_ordinal_number(i)) 0th 1st 2nd @@ -91,14 +91,16 @@ def type_class(symbol_type: ts.TypeSpec) -> Type[ts.TypeSpec]: def primitive_constituents( symbol_type: ts.TypeSpec, with_path_arg: typing.Literal[False] = False, -) -> XIterable[ts.TypeSpec]: ... +) -> XIterable[ts.TypeSpec]: + ... @typing.overload def primitive_constituents( symbol_type: ts.TypeSpec, with_path_arg: typing.Literal[True], -) -> XIterable[tuple[ts.TypeSpec, tuple[str, ...]]]: ... +) -> XIterable[tuple[ts.TypeSpec, tuple[str, ...]]]: + ... def primitive_constituents( @@ -150,7 +152,11 @@ def apply_to_primitive_constituents( >>> int_type = ts.ScalarType(kind=ts.ScalarKind.INT64) >>> tuple_type = ts.TupleType(types=[int_type, int_type]) - >>> print(apply_to_primitive_constituents(tuple_type, lambda primitive_type: ts.FieldType(dims=[], dtype=primitive_type))) + >>> print( + ... apply_to_primitive_constituents( + ... tuple_type, lambda primitive_type: ts.FieldType(dims=[], dtype=primitive_type) + ... ) + ... ) tuple[Field[[], int64], Field[[], int64]] """ if isinstance(symbol_type, ts.TupleType): @@ -298,7 +304,9 @@ def is_type_or_tuple_of_type(type_: ts.TypeSpec, expected_type: type | tuple) -> >>> field_type = ts.FieldType(dims=[], dtype=scalar_type) >>> is_type_or_tuple_of_type(field_type, ts.FieldType) True - >>> is_type_or_tuple_of_type(ts.TupleType(types=[scalar_type, field_type]), (ts.ScalarType, ts.FieldType)) + >>> is_type_or_tuple_of_type( + ... ts.TupleType(types=[scalar_type, field_type]), (ts.ScalarType, ts.FieldType) + ... 
) True >>> is_type_or_tuple_of_type(scalar_type, ts.FieldType) False @@ -318,7 +326,9 @@ def is_tuple_of_type(type_: ts.TypeSpec, expected_type: type | tuple) -> TypeGua >>> field_type = ts.FieldType(dims=[], dtype=scalar_type) >>> is_tuple_of_type(field_type, ts.FieldType) False - >>> is_tuple_of_type(ts.TupleType(types=[scalar_type, field_type]), (ts.ScalarType, ts.FieldType)) + >>> is_tuple_of_type( + ... ts.TupleType(types=[scalar_type, field_type]), (ts.ScalarType, ts.FieldType) + ... ) True >>> is_tuple_of_type(ts.TupleType(types=[scalar_type]), ts.FieldType) False @@ -381,38 +391,37 @@ def is_concretizable(symbol_type: ts.TypeSpec, to_type: ts.TypeSpec) -> bool: Examples: --------- >>> is_concretizable( - ... ts.ScalarType(kind=ts.ScalarKind.INT64), - ... to_type=ts.ScalarType(kind=ts.ScalarKind.INT64) + ... ts.ScalarType(kind=ts.ScalarKind.INT64), to_type=ts.ScalarType(kind=ts.ScalarKind.INT64) ... ) True >>> is_concretizable( ... ts.ScalarType(kind=ts.ScalarKind.INT64), - ... to_type=ts.ScalarType(kind=ts.ScalarKind.FLOAT64) + ... to_type=ts.ScalarType(kind=ts.ScalarKind.FLOAT64), ... ) False >>> is_concretizable( ... ts.DeferredType(constraint=None), - ... to_type=ts.FieldType(dtype=ts.ScalarType(kind=ts.ScalarKind.BOOL), dims=[]) + ... to_type=ts.FieldType(dtype=ts.ScalarType(kind=ts.ScalarKind.BOOL), dims=[]), ... ) True >>> is_concretizable( ... ts.DeferredType(constraint=ts.DataType), - ... to_type=ts.FieldType(dtype=ts.ScalarType(kind=ts.ScalarKind.BOOL), dims=[]) + ... to_type=ts.FieldType(dtype=ts.ScalarType(kind=ts.ScalarKind.BOOL), dims=[]), ... ) True >>> is_concretizable( ... ts.DeferredType(constraint=ts.OffsetType), - ... to_type=ts.FieldType(dtype=ts.ScalarType(kind=ts.ScalarKind.BOOL), dims=[]) + ... to_type=ts.FieldType(dtype=ts.ScalarType(kind=ts.ScalarKind.BOOL), dims=[]), ... ) False >>> is_concretizable( ... ts.DeferredType(constraint=ts.TypeSpec), - ... to_type=ts.DeferredType(constraint=ts.ScalarType) + ... 
to_type=ts.DeferredType(constraint=ts.ScalarType), ... ) True @@ -437,17 +446,14 @@ def promote(*types: ts.FieldType | ts.ScalarType) -> ts.FieldType | ts.ScalarTyp >>> dtype = ts.ScalarType(kind=ts.ScalarKind.INT64) >>> I, J, K = (common.Dimension(value=dim) for dim in ["I", "J", "K"]) >>> promoted: ts.FieldType = promote( - ... ts.FieldType(dims=[I, J], dtype=dtype), - ... ts.FieldType(dims=[I, J, K], dtype=dtype), - ... dtype + ... ts.FieldType(dims=[I, J], dtype=dtype), ts.FieldType(dims=[I, J, K], dtype=dtype), dtype ... ) >>> promoted.dims == [I, J, K] and promoted.dtype == dtype True >>> promote( - ... ts.FieldType(dims=[I, J], dtype=dtype), - ... ts.FieldType(dims=[K], dtype=dtype) - ... ) # doctest: +ELLIPSIS + ... ts.FieldType(dims=[I, J], dtype=dtype), ts.FieldType(dims=[K], dtype=dtype) + ... ) # doctest: +ELLIPSIS Traceback (most recent call last): ... ValueError: Dimensions can not be promoted. Could not determine order of the following dimensions: J, K. @@ -700,7 +706,12 @@ def function_signature_incompatibilities_field( # TODO: This code does not handle ellipses for dimensions. Fix it. assert field_type.dims is not ... if field_type.dims and source_dim not in field_type.dims: - yield f"Incompatible offset can not shift field defined on " f"{', '.join([dim.value for dim in field_type.dims])} from " f"{source_dim.value} to target dim(s): " f"{', '.join([dim.value for dim in target_dims])}" + yield ( + f"Incompatible offset can not shift field defined on " + f"{', '.join([dim.value for dim in field_type.dims])} from " + f"{source_dim.value} to target dim(s): " + f"{', '.join([dim.value for dim in target_dims])}" + ) def accepts_args( @@ -724,7 +735,7 @@ def accepts_args( ... pos_only_args=[bool_type], ... pos_or_kw_args={"foo": bool_type}, ... kw_only_args={}, - ... returns=ts.VoidType() + ... returns=ts.VoidType(), ... 
) >>> accepts_args(func_type, with_args=[bool_type], with_kwargs={"foo": bool_type}) True diff --git a/src/gt4py/next/utils.py b/src/gt4py/next/utils.py index ec459906e0..21932afd70 100644 --- a/src/gt4py/next/utils.py +++ b/src/gt4py/next/utils.py @@ -21,10 +21,10 @@ class RecursionGuard: Context manager to guard against inifinite recursion. >>> def foo(i): - ... with RecursionGuard(i): - ... if i % 2 == 0: - ... foo(i) - ... return i + ... with RecursionGuard(i): + ... if i % 2 == 0: + ... foo(i) + ... return i >>> foo(3) 3 >>> foo(2) # doctest:+ELLIPSIS diff --git a/src/gt4py/storage/allocators.py b/src/gt4py/storage/allocators.py index 0482ec1e65..061f79f146 100644 --- a/src/gt4py/storage/allocators.py +++ b/src/gt4py/storage/allocators.py @@ -156,7 +156,8 @@ class BufferAllocator(Protocol[core_defs.DeviceTypeT]): """Protocol for buffer allocators.""" @property - def device_type(self) -> core_defs.DeviceTypeT: ... + def device_type(self) -> core_defs.DeviceTypeT: + ... def allocate( self, @@ -320,17 +321,20 @@ class _NumPyLibStridesModule(Protocol): @staticmethod def as_strided( ndarray: core_defs.NDArrayObject, **kwargs: Any - ) -> core_defs.NDArrayObject: ... + ) -> core_defs.NDArrayObject: + ... stride_tricks: _NumPyLibStridesModule lib: _NumPyLibModule @staticmethod - def empty(shape: Tuple[int, ...], dtype: Any) -> _NDBuffer: ... + def empty(shape: Tuple[int, ...], dtype: Any) -> _NDBuffer: + ... @staticmethod - def byte_bounds(ndarray: _NDBuffer) -> Tuple[int, int]: ... + def byte_bounds(ndarray: _NDBuffer) -> Tuple[int, int]: + ... def is_valid_nplike_allocation_ns(obj: Any) -> TypeGuard[ValidNumPyLikeAllocationNS]: diff --git a/src/gt4py/storage/cartesian/layout.py b/src/gt4py/storage/cartesian/layout.py index 26e34e35d6..65b1967448 100644 --- a/src/gt4py/storage/cartesian/layout.py +++ b/src/gt4py/storage/cartesian/layout.py @@ -73,7 +73,7 @@ def check_layout(layout_map, strides): def layout_maker_factory( - base_layout: Tuple[int, ...] 
+ base_layout: Tuple[int, ...], ) -> Callable[[Tuple[str, ...]], Tuple[int, ...]]: def layout_maker(dimensions: Tuple[str, ...]) -> Tuple[int, ...]: mask = [dim in dimensions for dim in "IJK"] diff --git a/tests/cartesian_tests/integration_tests/feature_tests/test_exec_info.py b/tests/cartesian_tests/integration_tests/feature_tests/test_exec_info.py index 6b8c02e41c..5814daa495 100644 --- a/tests/cartesian_tests/integration_tests/feature_tests/test_exec_info.py +++ b/tests/cartesian_tests/integration_tests/feature_tests/test_exec_info.py @@ -49,7 +49,10 @@ def advection_def( @staticmethod def diffusion_def( - in_phi: gtscript.Field[float], out_phi: gtscript.Field[float], *, alpha: float # type: ignore + in_phi: gtscript.Field[float], + out_phi: gtscript.Field[float], + *, + alpha: float, # type: ignore ): with computation(PARALLEL), interval(...): # type: ignore # noqa lap1 = ( diff --git a/tests/cartesian_tests/integration_tests/multi_feature_tests/stencil_definitions.py b/tests/cartesian_tests/integration_tests/multi_feature_tests/stencil_definitions.py index 4ac239fdd2..79056c2914 100644 --- a/tests/cartesian_tests/integration_tests/multi_feature_tests/stencil_definitions.py +++ b/tests/cartesian_tests/integration_tests/multi_feature_tests/stencil_definitions.py @@ -142,7 +142,11 @@ def native_functions(field_a: Field3D, field_b: Field3D): field_b = ( trunc_res if isfinite(trunc_res) - else field_a if isinf(trunc_res) else field_b if isnan(trunc_res) else 0.0 + else field_a + if isinf(trunc_res) + else field_b + if isnan(trunc_res) + else 0.0 ) diff --git a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_suites.py b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_suites.py index 6110e29cdb..10019343ab 100644 --- a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_suites.py +++ b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_suites.py @@ -165,9 +165,11 @@ def definition(field_a, field_b, field_c, 
field_out, *, weight, alpha_factor): factor = alpha_factor else: factor = 1.0 - field_out = factor * field_a[ # noqa: F841 # Local name is assigned to but never used - 0, 0, 0 - ] - (1 - factor) * (field_b[0, 0, 0] - weight * field_c[0, 0, 0]) + field_out = ( + factor + * field_a[0, 0, 0] # noqa: F841 # Local name is assigned to but never used + - (1 - factor) * (field_b[0, 0, 0] - weight * field_c[0, 0, 0]) + ) def validation( field_a, field_b, field_c, field_out, *, weight, alpha_factor, domain, origin, **kwargs @@ -225,9 +227,10 @@ def definition(u, diffusion, *, weight): laplacian = 4.0 * u[0, 0, 0] - (u[1, 0, 0] + u[-1, 0, 0] + u[0, 1, 0] + u[0, -1, 0]) flux_i = laplacian[1, 0, 0] - laplacian[0, 0, 0] flux_j = laplacian[0, 1, 0] - laplacian[0, 0, 0] - diffusion = u[ # noqa: F841 # Local name is assigned to but never used - 0, 0, 0 - ] - weight * (flux_i[0, 0, 0] - flux_i[-1, 0, 0] + flux_j[0, 0, 0] - flux_j[0, -1, 0]) + diffusion = ( + u[0, 0, 0] # noqa: F841 # Local name is assigned to but never used + - weight * (flux_i[0, 0, 0] - flux_i[-1, 0, 0] + flux_j[0, 0, 0] - flux_j[0, -1, 0]) + ) def validation(u, diffusion, *, weight, domain, origin, **kwargs): laplacian = 4.0 * u[1:-1, 1:-1, :] - ( @@ -290,9 +293,10 @@ def definition(u, diffusion, *, weight): with computation(PARALLEL), interval(...): laplacian = lap_op(u=u) flux_i, flux_j = fwd_diff(field=laplacian) - diffusion = u[ # noqa: F841 # Local name is assigned to but never used - 0, 0, 0 - ] - weight * (flux_i[0, 0, 0] - flux_i[-1, 0, 0] + flux_j[0, 0, 0] - flux_j[0, -1, 0]) + diffusion = ( + u[0, 0, 0] # noqa: F841 # Local name is assigned to but never used + - weight * (flux_i[0, 0, 0] - flux_i[-1, 0, 0] + flux_j[0, 0, 0] - flux_j[0, -1, 0]) + ) def validation(u, diffusion, *, weight, domain, origin, **kwargs): laplacian = 4.0 * u[1:-1, 1:-1, :] - ( @@ -330,9 +334,10 @@ def definition(u, diffusion, *, weight): flux_j = fwd_diff_op_y(field=laplacian) else: flux_i, flux_j = 
fwd_diff_op_xy(field=laplacian) - diffusion = u[ # noqa: F841 # Local name is assigned to but never used - 0, 0, 0 - ] - weight * (flux_i[0, 0, 0] - flux_i[-1, 0, 0] + flux_j[0, 0, 0] - flux_j[0, -1, 0]) + diffusion = ( + u[0, 0, 0] # noqa: F841 # Local name is assigned to but never used + - weight * (flux_i[0, 0, 0] - flux_i[-1, 0, 0] + flux_j[0, 0, 0] - flux_j[0, -1, 0]) + ) def validation(u, diffusion, *, weight, domain, origin, **kwargs): laplacian = 4.0 * u[1:-1, 1:-1, :] - ( @@ -792,9 +797,7 @@ class TestVariableKRead(gt_testing.StencilTestSuite): def definition(field_in, field_out, index): with computation(PARALLEL), interval(1, None): - field_out = field_in[ # noqa: F841 # Local name is assigned to but never used - 0, 0, index - ] + field_out = field_in[0, 0, index] # noqa: F841 # Local name is assigned to but never used def validation(field_in, field_out, index, *, domain, origin): field_out[:, :, 1:] = field_in[:, :, (np.arange(field_in.shape[-1]) + index)[1:]] diff --git a/tests/cartesian_tests/unit_tests/frontend_tests/test_defir_to_gtir.py b/tests/cartesian_tests/unit_tests/frontend_tests/test_defir_to_gtir.py index 83195a898a..934597f4e5 100644 --- a/tests/cartesian_tests/unit_tests/frontend_tests/test_defir_to_gtir.py +++ b/tests/cartesian_tests/unit_tests/frontend_tests/test_defir_to_gtir.py @@ -45,7 +45,8 @@ def defir_to_gtir(): def test_stencil_definition( - defir_to_gtir, ijk_domain # noqa: F811 [redefinition, reason: fixture] + defir_to_gtir, + ijk_domain, # noqa: F811 [redefinition, reason: fixture] ): stencil_definition = ( TDefinition(name="definition", domain=ijk_domain, fields=["a", "b"]) diff --git a/tests/cartesian_tests/unit_tests/test_gtc/test_common.py b/tests/cartesian_tests/unit_tests/test_gtc/test_common.py index 8cfff12df4..1a51cad736 100644 --- a/tests/cartesian_tests/unit_tests/test_gtc/test_common.py +++ b/tests/cartesian_tests/unit_tests/test_gtc/test_common.py @@ -312,13 +312,15 @@ def 
test_symbolref_validation_for_valid_tree(): SymbolTableRootNode( nodes=[SymbolChildNode(name="foo"), SymbolRefChildNode(name="foo")], ) - SymbolTableRootNode( # noqa: B018 - nodes=[ - SymbolChildNode(name="foo"), - SymbolRefChildNode(name="foo"), - SymbolRefChildNode(name="foo"), - ], - ), + ( + SymbolTableRootNode( # noqa: B018 + nodes=[ + SymbolChildNode(name="foo"), + SymbolRefChildNode(name="foo"), + SymbolRefChildNode(name="foo"), + ], + ), + ) SymbolTableRootNode( nodes=[ SymbolChildNode(name="outer_scope"), diff --git a/tests/eve_tests/unit_tests/test_extended_typing.py b/tests/eve_tests/unit_tests/test_extended_typing.py index d90a577bf9..7213bc2c66 100644 --- a/tests/eve_tests/unit_tests/test_extended_typing.py +++ b/tests/eve_tests/unit_tests/test_extended_typing.py @@ -413,10 +413,12 @@ class B: def test_is_protocol(): class AProtocol(typing.Protocol): - def do_something(self, value: int) -> int: ... + def do_something(self, value: int) -> int: + ... class NotProtocol(AProtocol): - def do_something_else(self, value: float) -> float: ... + def do_something_else(self, value: float) -> float: + ... class AXProtocol(xtyping.Protocol): A = 1 @@ -425,7 +427,8 @@ class NotXProtocol(AXProtocol): A = 1 class AgainProtocol(AProtocol, xtyping.Protocol): - def do_something_else(self, value: float) -> float: ... + def do_something_else(self, value: float) -> float: + ... assert xtyping.is_protocol(AProtocol) assert xtyping.is_protocol(AXProtocol) @@ -437,13 +440,16 @@ def do_something_else(self, value: float) -> float: ... def test_get_partial_type_hints(): - def f1(a: int) -> float: ... + def f1(a: int) -> float: + ... assert xtyping.get_partial_type_hints(f1) == {"a": int, "return": float} - class MissingRef: ... + class MissingRef: + ... - def f_partial(a: int) -> MissingRef: ... + def f_partial(a: int) -> MissingRef: + ... 
# This is expected behavior because this test file uses # 'from __future__ import annotations' and therefore local @@ -461,7 +467,8 @@ def f_partial(a: int) -> MissingRef: ... "return": int, } - def f_nested_partial(a: int) -> Dict[str, MissingRef]: ... + def f_nested_partial(a: int) -> Dict[str, MissingRef]: + ... assert xtyping.get_partial_type_hints(f_nested_partial) == { "a": int, @@ -493,7 +500,8 @@ def test_eval_forward_ref(): == Dict[str, Tuple[int, float]] ) - class MissingRef: ... + class MissingRef: + ... assert ( xtyping.eval_forward_ref("Callable[[int], MissingRef]", localns={"MissingRef": MissingRef}) @@ -515,9 +523,7 @@ class MissingRef: ... globalns={"Annotated": Annotated, "Callable": Callable}, localns={"MissingRef": MissingRef}, ) - ) == Callable[ - [int], MissingRef - ] or ( # some patch versions of cpython3.9 show weird behaviors + ) == Callable[[int], MissingRef] or ( # some patch versions of cpython3.9 show weird behaviors sys.version_info >= (3, 9) and sys.version_info < (3, 10) and (ref == Callable[[Annotated[int, "Foo"]], MissingRef]) @@ -551,16 +557,19 @@ def test_infer_type(): assert xtyping.infer_type(str) == Type[str] - class A: ... + class A: + ... assert xtyping.infer_type(A()) == A assert xtyping.infer_type(A) == Type[A] - def f1(): ... + def f1(): + ... assert xtyping.infer_type(f1) == Callable[[], Any] - def f2(a: int, b: float) -> None: ... + def f2(a: int, b: float) -> None: + ... assert xtyping.infer_type(f2) == Callable[[int, float], type(None)] @@ -568,7 +577,8 @@ def f3( a: Dict[Tuple[str, ...], List[int]], b: List[Callable[[List[int]], Set[Set[int]]]], c: Type[List[int]], - ) -> Any: ... + ) -> Any: + ... assert ( xtyping.infer_type(f3) @@ -582,7 +592,8 @@ def f3( ] ) - def f4(a: int, b: float, *, foo: Tuple[str, ...] = ()) -> None: ... + def f4(a: int, b: float, *, foo: Tuple[str, ...] = ()) -> None: + ... 
assert xtyping.infer_type(f4) == Callable[[int, float], type(None)] assert ( diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index 56b220e0e9..43d2e340a3 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -61,10 +61,12 @@ class ProgramBackendId(_PythonObjectIdMixin, str, enum.Enum): class ExecutionAndAllocatorDescriptor(Protocol): # Used for test infrastructure, consider implementing this in gt4py when refactoring otf @property - def executor(self) -> Optional[ppi.ProgramExecutor]: ... + def executor(self) -> Optional[ppi.ProgramExecutor]: + ... @property - def allocator(self) -> next_allocators.FieldBufferAllocatorProtocol: ... + def allocator(self) -> next_allocators.FieldBufferAllocatorProtocol: + ... @dataclasses.dataclass(frozen=True) diff --git a/tests/next_tests/integration_tests/cases.py b/tests/next_tests/integration_tests/cases.py index 8513c98d89..6bedc74117 100644 --- a/tests/next_tests/integration_tests/cases.py +++ b/tests/next_tests/integration_tests/cases.py @@ -93,7 +93,8 @@ class DataInitializer(Protocol): @property - def scalar_value(self) -> ScalarValue: ... + def scalar_value(self) -> ScalarValue: + ... def scalar(self, dtype: np.typing.DTypeLike) -> ScalarValue: # some unlikely numpy dtypes are picky about arguments @@ -104,7 +105,8 @@ def field( allocator: next_allocators.FieldBufferAllocatorProtocol, sizes: dict[gtx.Dimension, int], dtype: np.typing.DTypeLike, - ) -> FieldValue: ... + ) -> FieldValue: + ... def from_case( self: Self, @@ -245,19 +247,22 @@ def __getattr__(self, name: str) -> Any: @typing.overload -def make_builder(*args: Callable) -> Callable[..., Builder]: ... +def make_builder(*args: Callable) -> Callable[..., Builder]: + ... @typing.overload def make_builder( *args: Literal[None], **kwargs: dict[str, Any] -) -> Callable[[Callable], Callable[..., Builder]]: ... +) -> Callable[[Callable], Callable[..., Builder]]: + ... 
@typing.overload def make_builder( *args: Optional[Callable], **kwargs: dict[str, Any] -) -> Callable[[Callable], Callable[..., Builder]] | Callable[..., Builder]: ... +) -> Callable[[Callable], Callable[..., Builder]] | Callable[..., Builder]: + ... # TODO(ricoh): Think about improving the type hints using `typing.ParamSpec`. @@ -298,7 +303,8 @@ def setter(self: Builder) -> Builder: argspec = inspect.getfullargspec(func) @dataclasses.dataclass(frozen=True) - class NewBuilder(Builder): ... + class NewBuilder(Builder): + ... for argname in argspec.args + argspec.kwonlyargs: setattr(NewBuilder, argname, make_setter(argname)) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py index d8c4696073..913f4d1fb6 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py @@ -48,11 +48,14 @@ def no_backend(program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> Non OPTIONAL_PROCESSORS = [] if dace_iterator: OPTIONAL_PROCESSORS.append(next_tests.definitions.OptionalProgramBackendId.DACE_CPU) - OPTIONAL_PROCESSORS.append( - pytest.param( - next_tests.definitions.OptionalProgramBackendId.DACE_GPU, marks=pytest.mark.requires_gpu - ) - ), + ( + OPTIONAL_PROCESSORS.append( + pytest.param( + next_tests.definitions.OptionalProgramBackendId.DACE_GPU, + marks=pytest.mark.requires_gpu, + ) + ), + ) @pytest.fixture( @@ -134,22 +137,28 @@ def debug_itir(tree): class MeshDescriptor(Protocol): @property - def name(self) -> str: ... + def name(self) -> str: + ... @property - def num_vertices(self) -> int: ... + def num_vertices(self) -> int: + ... @property - def num_cells(self) -> int: ... + def num_cells(self) -> int: + ... @property - def num_edges(self) -> int: ... + def num_edges(self) -> int: + ... 
@property - def num_levels(self) -> int: ... + def num_levels(self) -> int: + ... @property - def offset_provider(self) -> dict[str, common.Connectivity]: ... + def offset_provider(self) -> dict[str, common.Connectivity]: + ... def simple_mesh() -> MeshDescriptor: diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index b71031c54d..ece57345a4 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -779,7 +779,7 @@ def test_scan_nested_tuple_output(forward, cartesian_case): @gtx.scan_operator(axis=KDim, forward=forward, init=init) def simple_scan_operator( - carry: tuple[int32, tuple[int32, int32]] + carry: tuple[int32, tuple[int32, int32]], ) -> tuple[int32, tuple[int32, int32]]: return (carry[0] + 1, (carry[1][0] + 1, carry[1][1] + 1)) @@ -1085,7 +1085,7 @@ def prog(inp: cases.IKField, k_index: gtx.Field[[KDim], gtx.IndexType], out: cas def test_undefined_symbols(cartesian_case): - with pytest.raises(errors.DSLError, match="Undeclared symbol"): + with pytest.raises(errors.DSLError, match="Undeclared symbol"): @gtx.field_operator(backend=cartesian_case.executor) def return_undefined(): diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py index 05824fa779..1cbd01f6bd 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py @@ -168,7 +168,7 @@ def fencil(edge_f: cases.EField, out: cases.VField): fencil, ref=lambda edge_f: 3 * np.sum( - -edge_f[v2e_table] ** 2 * 2, + -(edge_f[v2e_table] ** 2) * 2, axis=1, initial=0, where=v2e_table != common.SKIP_VALUE, diff --git 
a/tests/next_tests/past_common_fixtures.py b/tests/next_tests/past_common_fixtures.py index 3ac931f319..756d81b6d9 100644 --- a/tests/next_tests/past_common_fixtures.py +++ b/tests/next_tests/past_common_fixtures.py @@ -44,7 +44,7 @@ def identity(in_field: gtx.Field[[IDim], "float64"]) -> gtx.Field[[IDim], "float def make_tuple_op(): @gtx.field_operator() def make_tuple_op_impl( - inp: gtx.Field[[IDim], float64] + inp: gtx.Field[[IDim], float64], ) -> Tuple[gtx.Field[[IDim], float64], gtx.Field[[IDim], float64]]: return inp, inp diff --git a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_itir.py b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_itir.py index a0035348ad..9ebd991e36 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_itir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_itir.py @@ -526,13 +526,15 @@ def reduction(e1: gtx.Field[[Edge], float64], e2: gtx.Field[[Vertex, V2EDim], fl def test_builtin_int_constructors(): - def int_constrs() -> tuple[ - int32, - int32, - int64, - int32, - int64, - ]: + def int_constrs() -> ( + tuple[ + int32, + int32, + int64, + int32, + int64, + ] + ): return 1, int32(1), int64(1), int32("1"), int64("1") parsed = FieldOperatorParser.apply_to_function(int_constrs) @@ -550,15 +552,17 @@ def int_constrs() -> tuple[ def test_builtin_float_constructors(): - def float_constrs() -> tuple[ - float, - float, - float32, - float64, - float, - float32, - float64, - ]: + def float_constrs() -> ( + tuple[ + float, + float, + float32, + float64, + float, + float32, + float64, + ] + ): return ( 0.1, float(0.1), diff --git a/tests/next_tests/unit_tests/test_type_inference.py b/tests/next_tests/unit_tests/test_type_inference.py index 3db67320f1..74178e7548 100644 --- a/tests/next_tests/unit_tests/test_type_inference.py +++ b/tests/next_tests/unit_tests/test_type_inference.py @@ -20,7 +20,8 @@ class Foo(ti.Type): bar: ti.Type baz: ti.Type - class Bar(ti.Type): ... + class Bar(ti.Type): + ... 
r = ti._Renamer() actual = [