Skip to content

Commit

Permalink
ran pre-commit
Browse files Browse the repository at this point in the history
  • Loading branch information
nfarabullini committed Feb 22, 2024
1 parent 9abded7 commit bc50c07
Show file tree
Hide file tree
Showing 80 changed files with 839 additions and 548 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ repos:
# Add all type stubs from typeshed
- types-all
args: [--no-install-types]
exclude: |
exclude: |-
(?x)^(
setup.py |
build/.* |
Expand All @@ -120,4 +120,4 @@ repos:
tests/next_tests/past_common_fixtures.py |
tests/next_tests/toy_connectivity.py |
tests/.*
)$
)$
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -473,4 +473,4 @@ version = {attr = 'gt4py.__about__.__version__'}
'gt4py' = ['py.typed', '*.md', '*.rst']

[tool.setuptools.packages]
find = {namespaces = false, where = ['src']}
find = {namespaces = false, where = ['src']}
99 changes: 66 additions & 33 deletions src/gt4py/_core/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,21 +74,24 @@
BoolScalar: TypeAlias = Union[bool_, bool]
BoolT = TypeVar("BoolT", bound=BoolScalar)
BOOL_TYPES: Final[Tuple[type, ...]] = cast(
Tuple[type, ...], BoolScalar.__args__ # type: ignore[attr-defined]
Tuple[type, ...],
BoolScalar.__args__, # type: ignore[attr-defined]
)


IntScalar: TypeAlias = Union[int8, int16, int32, int64, int]
IntT = TypeVar("IntT", bound=IntScalar)
INT_TYPES: Final[Tuple[type, ...]] = cast(
Tuple[type, ...], IntScalar.__args__ # type: ignore[attr-defined]
Tuple[type, ...],
IntScalar.__args__, # type: ignore[attr-defined]
)


UnsignedIntScalar: TypeAlias = Union[uint8, uint16, uint32, uint64]
UnsignedIntT = TypeVar("UnsignedIntT", bound=UnsignedIntScalar)
UINT_TYPES: Final[Tuple[type, ...]] = cast(
Tuple[type, ...], UnsignedIntScalar.__args__ # type: ignore[attr-defined]
Tuple[type, ...],
UnsignedIntScalar.__args__, # type: ignore[attr-defined]
)


Expand All @@ -100,7 +103,8 @@
FloatingScalar: TypeAlias = Union[float32, float64, float]
FloatingT = TypeVar("FloatingT", bound=FloatingScalar)
FLOAT_TYPES: Final[Tuple[type, ...]] = cast(
Tuple[type, ...], FloatingScalar.__args__ # type: ignore[attr-defined]
Tuple[type, ...],
FloatingScalar.__args__, # type: ignore[attr-defined]
)


Expand Down Expand Up @@ -165,23 +169,28 @@ class DTypeKind(eve.StrEnum):


@overload
def dtype_kind(sc_type: Type[BoolT]) -> Literal[DTypeKind.BOOL]: ...
def dtype_kind(sc_type: Type[BoolT]) -> Literal[DTypeKind.BOOL]:
...


@overload
def dtype_kind(sc_type: Type[IntT]) -> Literal[DTypeKind.INT]: ...
def dtype_kind(sc_type: Type[IntT]) -> Literal[DTypeKind.INT]:
...


@overload
def dtype_kind(sc_type: Type[UnsignedIntT]) -> Literal[DTypeKind.UINT]: ...
def dtype_kind(sc_type: Type[UnsignedIntT]) -> Literal[DTypeKind.UINT]:
...


@overload
def dtype_kind(sc_type: Type[FloatingT]) -> Literal[DTypeKind.FLOAT]: ...
def dtype_kind(sc_type: Type[FloatingT]) -> Literal[DTypeKind.FLOAT]:
...


@overload
def dtype_kind(sc_type: Type[ScalarT]) -> DTypeKind: ...
def dtype_kind(sc_type: Type[ScalarT]) -> DTypeKind:
...


def dtype_kind(sc_type: Type[ScalarT]) -> DTypeKind:
Expand Down Expand Up @@ -355,7 +364,8 @@ class GTDimsInterface(Protocol):
"""

@property
def __gt_dims__(self) -> Tuple[str, ...]: ...
def __gt_dims__(self) -> Tuple[str, ...]:
...


class GTOriginInterface(Protocol):
Expand All @@ -366,7 +376,8 @@ class GTOriginInterface(Protocol):
"""

@property
def __gt_origin__(self) -> Tuple[int, ...]: ...
def __gt_origin__(self) -> Tuple[int, ...]:
...


# -- Device representation --
Expand Down Expand Up @@ -436,45 +447,64 @@ def __iter__(self) -> Iterator[DeviceTypeT | int]:

class NDArrayObject(Protocol):
@property
def ndim(self) -> int: ...
def ndim(self) -> int:
...

@property
def shape(self) -> tuple[int, ...]: ...
def shape(self) -> tuple[int, ...]:
...

@property
def dtype(self) -> Any: ...
def dtype(self) -> Any:
...

def item(self) -> Any: ...
def item(self) -> Any:
...

def astype(self, dtype: npt.DTypeLike) -> NDArrayObject: ...
def astype(self, dtype: npt.DTypeLike) -> NDArrayObject:
...

def __getitem__(self, item: Any) -> NDArrayObject: ...
def __getitem__(self, item: Any) -> NDArrayObject:
...

def __abs__(self) -> NDArrayObject: ...
def __abs__(self) -> NDArrayObject:
...

def __neg__(self) -> NDArrayObject: ...
def __neg__(self) -> NDArrayObject:
...

def __add__(self, other: NDArrayObject | Scalar) -> NDArrayObject: ...
def __add__(self, other: NDArrayObject | Scalar) -> NDArrayObject:
...

def __radd__(self, other: Any) -> NDArrayObject: ...
def __radd__(self, other: Any) -> NDArrayObject:
...

def __sub__(self, other: NDArrayObject | Scalar) -> NDArrayObject: ...
def __sub__(self, other: NDArrayObject | Scalar) -> NDArrayObject:
...

def __rsub__(self, other: Any) -> NDArrayObject: ...
def __rsub__(self, other: Any) -> NDArrayObject:
...

def __mul__(self, other: NDArrayObject | Scalar) -> NDArrayObject: ...
def __mul__(self, other: NDArrayObject | Scalar) -> NDArrayObject:
...

def __rmul__(self, other: Any) -> NDArrayObject: ...
def __rmul__(self, other: Any) -> NDArrayObject:
...

def __floordiv__(self, other: NDArrayObject | Scalar) -> NDArrayObject: ...
def __floordiv__(self, other: NDArrayObject | Scalar) -> NDArrayObject:
...

def __rfloordiv__(self, other: Any) -> NDArrayObject: ...
def __rfloordiv__(self, other: Any) -> NDArrayObject:
...

def __truediv__(self, other: NDArrayObject | Scalar) -> NDArrayObject: ...
def __truediv__(self, other: NDArrayObject | Scalar) -> NDArrayObject:
...

def __rtruediv__(self, other: Any) -> NDArrayObject: ...
def __rtruediv__(self, other: Any) -> NDArrayObject:
...

def __pow__(self, other: NDArrayObject | Scalar) -> NDArrayObject: ...
def __pow__(self, other: NDArrayObject | Scalar) -> NDArrayObject:
...

def __eq__(self, other: NDArrayObject | Scalar) -> NDArrayObject: # type: ignore[override] # mypy wants to return `bool`
...
Expand All @@ -494,8 +524,11 @@ def __lt__(self, other: NDArrayObject | Scalar) -> NDArrayObject: # type: ignor
def __le__(self, other: NDArrayObject | Scalar) -> NDArrayObject: # type: ignore[misc] # Forward operator is not callable
...

def __and__(self, other: NDArrayObject | Scalar) -> NDArrayObject: ...
def __and__(self, other: NDArrayObject | Scalar) -> NDArrayObject:
...

def __or__(self, other: NDArrayObject | Scalar) -> NDArrayObject: ...
def __or__(self, other: NDArrayObject | Scalar) -> NDArrayObject:
...

def __xor(self, other: NDArrayObject | Scalar) -> NDArrayObject: ...
def __xor(self, other: NDArrayObject | Scalar) -> NDArrayObject:
...
3 changes: 2 additions & 1 deletion src/gt4py/cartesian/backend/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,8 @@ def make_module_source(self, *, args_data: Optional[ModuleData] = None, **kwargs


class MakeModuleSourceCallable(Protocol):
def __call__(self, *, args_data: Optional[ModuleData] = None, **kwargs: Any) -> str: ...
def __call__(self, *, args_data: Optional[ModuleData] = None, **kwargs: Any) -> str:
...


class PurePythonBackendCLIMixin(CLIBackendMixin):
Expand Down
16 changes: 8 additions & 8 deletions src/gt4py/cartesian/backend/dace_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -684,14 +684,14 @@ def generate_entry_params(self, stencil_ir: gtir.Stencil, sdfg: dace.SDFG) -> Li
if name in sdfg.arrays:
data = sdfg.arrays[name]
assert isinstance(data, dace.data.Array)
res[name] = (
"py::{pybind_type} {name}, std::array<gt::int_t,{ndim}> {name}_origin".format(
pybind_type=(
"object" if self.backend.storage_info["device"] == "gpu" else "buffer"
),
name=name,
ndim=len(data.shape),
)
res[
name
] = "py::{pybind_type} {name}, std::array<gt::int_t,{ndim}> {name}_origin".format(
pybind_type=(
"object" if self.backend.storage_info["device"] == "gpu" else "buffer"
),
name=name,
ndim=len(data.shape),
)
elif name in sdfg.symbols and not name.startswith("__"):
assert name in sdfg.symbols
Expand Down
6 changes: 4 additions & 2 deletions src/gt4py/cartesian/backend/pyext_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,8 @@ def build_pybind_ext(
build_path: str,
target_path: str,
**kwargs: str,
) -> Tuple[str, str]: ...
) -> Tuple[str, str]:
...


@overload
Expand All @@ -197,7 +198,8 @@ def build_pybind_ext(
build_ext_class: Type = None,
verbose: bool = False,
clean: bool = False,
) -> Tuple[str, str]: ...
) -> Tuple[str, str]:
...


def build_pybind_ext(
Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/cartesian/frontend/gtscript_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -618,7 +618,7 @@ def visit_If(self, node: ast.If):


def _make_temp_decls(
descriptors: Dict[str, gtscript._FieldDescriptor]
descriptors: Dict[str, gtscript._FieldDescriptor],
) -> Dict[str, nodes.FieldDecl]:
return {
name: nodes.FieldDecl(
Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/cartesian/gtc/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -906,7 +906,7 @@ def data_type_to_typestr(dtype: DataType) -> str:
def op_to_ufunc(
op: Union[
UnaryOperator, ArithmeticOperator, ComparisonOperator, LogicalOperator, NativeFunction
]
],
) -> np.ufunc:
if not isinstance(
op, (UnaryOperator, ArithmeticOperator, ComparisonOperator, LogicalOperator, NativeFunction)
Expand Down
3 changes: 2 additions & 1 deletion src/gt4py/cartesian/gtc/cuir/oir_to_cuir.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@


class SymbolNameCreator(Protocol):
def __call__(self, name: str) -> str: ...
def __call__(self, name: str) -> str:
...


def _make_axis_offset_expr(bound: common.AxisBound, axis_index: int) -> cuir.Expr:
Expand Down
4 changes: 3 additions & 1 deletion src/gt4py/cartesian/gtc/dace/expansion/tasklet_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,9 @@ def _visit_offset(
context_info = copy.deepcopy(access_info)
context_info.variable_offset_axes = []
ranges = make_dace_subset(
access_info, context_info, data_dims=() # data_index added in visit_IndexAccess
access_info,
context_info,
data_dims=(), # data_index added in visit_IndexAccess
)
ranges.offset(sym_offsets, negative=False)
res = dace.subsets.Range([r for i, r in enumerate(ranges.ranges) if int_sizes[i] != 1])
Expand Down
4 changes: 2 additions & 2 deletions src/gt4py/cartesian/gtc/dace/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,7 @@ def flatten_list(list_or_node: Union[List[Any], eve.Node]):


def collect_toplevel_computation_nodes(
list_or_node: Union[List[Any], eve.Node]
list_or_node: Union[List[Any], eve.Node],
) -> List["dcir.ComputationNode"]:
class ComputationNodeCollector(eve.NodeVisitor):
def visit_ComputationNode(self, node: dcir.ComputationNode, *, collection: List):
Expand All @@ -431,7 +431,7 @@ def visit_ComputationNode(self, node: dcir.ComputationNode, *, collection: List)


def collect_toplevel_iteration_nodes(
list_or_node: Union[List[Any], eve.Node]
list_or_node: Union[List[Any], eve.Node],
) -> List["dcir.IterationNode"]:
class IterationNodeCollector(eve.NodeVisitor):
def visit_IterationNode(self, node: dcir.IterationNode, *, collection: List):
Expand Down
3 changes: 2 additions & 1 deletion src/gt4py/cartesian/gtc/gtcpp/oir_to_gtcpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,8 @@ def _make_axis_offset_expr(


class SymbolNameCreator(Protocol):
def __call__(self, name: str) -> str: ...
def __call__(self, name: str) -> str:
...


class OIRToGTCpp(eve.NodeTranslator, eve.VisitorWithSymbolTableTrait):
Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/cartesian/gtc/passes/gtir_k_boundary.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@


def _iter_field_names(
node: Union[gtir.Stencil, gtir.ParAssignStmt]
node: Union[gtir.Stencil, gtir.ParAssignStmt],
) -> eve.utils.XIterable[gtir.FieldAccess]:
return node.walk_values().if_isinstance(gtir.FieldDecl).getattr("name").unique()

Expand Down
4 changes: 2 additions & 2 deletions src/gt4py/cartesian/stencil_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def _extract_array_infos(


def _extract_stencil_arrays(
array_infos: Dict[str, Optional[ArgsInfo]]
array_infos: Dict[str, Optional[ArgsInfo]],
) -> Dict[str, Optional[FieldType]]:
return {name: info.array if info is not None else None for name, info in array_infos.items()}

Expand Down Expand Up @@ -283,7 +283,7 @@ def __call__(self, *args, **kwargs) -> None:

@staticmethod
def _make_origin_dict(
origin: Union[Dict[str, Tuple[int, ...]], Tuple[int, ...], int, None]
origin: Union[Dict[str, Tuple[int, ...]], Tuple[int, ...], int, None],
) -> Dict[str, Tuple[int, ...]]:
try:
if isinstance(origin, dict):
Expand Down
5 changes: 1 addition & 4 deletions src/gt4py/cartesian/testing/suites.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,10 +392,7 @@ class StencilTestSuite(metaclass=SuiteMeta):
.. code-block:: python
{
'float_symbols' : (np.float32, np.float64),
'int_symbols' : (int, np.int_, np.int64)
}
{"float_symbols": (np.float32, np.float64), "int_symbols": (int, np.int_, np.int64)}
domain_range : `Sequence` of pairs like `((int, int), (int, int) ... )`
Required class attribute.
Expand Down
3 changes: 2 additions & 1 deletion src/gt4py/cartesian/type_hints.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ class StencilFunc(Protocol):
__name__: str
__module__: str

def __call__(self, *args: Any, **kwargs: Dict[str, Any]) -> None: ...
def __call__(self, *args: Any, **kwargs: Dict[str, Any]) -> None:
...


class AnnotatedStencilFunc(StencilFunc, Protocol):
Expand Down
Loading

0 comments on commit bc50c07

Please sign in to comment.