Skip to content

Commit

Permalink
Merge pull request #640 from DHI/dataset-type
Browse files Browse the repository at this point in the history
Dataset type hints
  • Loading branch information
ecomodeller authored Feb 12, 2024
2 parents 2f9a120 + b69038d commit 905a098
Show file tree
Hide file tree
Showing 9 changed files with 222 additions and 169 deletions.
5 changes: 4 additions & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,8 @@
"restructuredtext.confPath": "${workspaceFolder}\\docs",
"files.eol": "\n",
"python.formatting.provider": "black",
"editor.formatOnSave": true
"editor.formatOnSave": true,
"mypy-type-checker.args": [
"--disallow-incomplete-defs",
],
}
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ test:
pytest --disable-warnings

typecheck:
mypy $(LIB)/ --config-file pyproject.toml
mypy $(LIB)/

coverage:
pytest --cov-report html --cov=$(LIB) tests/
Expand Down
28 changes: 20 additions & 8 deletions mikeio/dataset/_data_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,21 @@
from .._spectral import plot_2dspectrum

if TYPE_CHECKING:
from ..dataset import DataArray
from ..dataset import DataArray, Dataset


class _DataArrayPlotter:
"""Context aware plotter (sensible plotting according to geometry)"""

def __init__(self, da: "DataArray") -> None:
self.da = da

def __call__(self, ax:Axes | None=None, figsize=None, **kwargs):
def __call__(
self,
ax: Axes | None = None,
figsize: Tuple[float, float] | None = None,
**kwargs: Any,
) -> Axes:
"""Plot DataArray according to geometry
Parameters
Expand Down Expand Up @@ -65,7 +71,13 @@ def _get_fig_ax(ax=None, figsize=None):
fig = plt.gcf()
return fig, ax

def hist(self, ax: Axes| None=None, figsize:Tuple[float,float] | None=None, title: str | None=None, **kwargs: Any) -> Axes:
def hist(
self,
ax: Axes | None = None,
figsize: Tuple[float, float] | None = None,
title: str | None = None,
**kwargs: Any,
) -> Axes:
"""Plot DataArray as histogram (using ax.hist)
Parameters
Expand Down Expand Up @@ -97,7 +109,7 @@ def hist(self, ax: Axes| None=None, figsize:Tuple[float,float] | None=None, titl
ax.set_title(title)
return self._hist(ax, **kwargs)

def _hist(self, ax: Axes, **kwargs):
def _hist(self, ax: Axes, **kwargs: Any) -> Any:
result = ax.hist(self.da.values.ravel(), **kwargs)
ax.set_xlabel(self._label_txt())
return result
Expand Down Expand Up @@ -575,7 +587,7 @@ def _plot_2dspectrum(self, **kwargs):
**kwargs,
)

def _get_title(self):
def _get_title(self) -> str:
txt = f"{self.da.time[0]}"
x, y = self.da.geometry.x, self.da.geometry.y
if x is not None and y is not None:
Expand All @@ -587,7 +599,7 @@ def _get_title(self):


class _DataArrayPlotterLineSpectrum(_DataArrayPlotterGrid1D):
def __init__(self, da) -> None:
def __init__(self, da: DataArray) -> None:
if da.n_timesteps > 1:
Hm0 = da[0].to_Hm0()
else:
Expand All @@ -596,7 +608,7 @@ def __init__(self, da) -> None:


class _DataArrayPlotterAreaSpectrum(_DataArrayPlotterFM):
def __init__(self, da) -> None:
def __init__(self, da: DataArray) -> None:
if da.n_timesteps > 1:
Hm0 = da[0].to_Hm0()
else:
Expand All @@ -605,7 +617,7 @@ def __init__(self, da) -> None:


class _DatasetPlotter:
def __init__(self, ds) -> None:
def __init__(self, ds: Dataset) -> None:
self.ds = ds

def __call__(self, figsize=None, **kwargs):
Expand Down
84 changes: 51 additions & 33 deletions mikeio/dataset/_dataarray.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations
import warnings
from copy import deepcopy
from pathlib import Path
from datetime import datetime
from functools import cached_property
from typing import (
Expand Down Expand Up @@ -1561,7 +1562,7 @@ def aggregate(
if isinstance(axis, int):
axes = (axis,)
else:
axes = axis
axes = axis # type: ignore

dims = tuple([d for i, d in enumerate(self.dims) if i not in axes])

Expand Down Expand Up @@ -1695,35 +1696,37 @@ def _quantile(self, q, *, axis: int | str = 0, func=np.quantile, **kwargs: Any):

# ============= MATH operations ===========

def __radd__(self, other) -> "DataArray":
def __radd__(self, other: "DataArray" | float) -> "DataArray":
return self.__add__(other)

def __add__(self, other) -> "DataArray":
return self._apply_math_operation(other, np.add, "+")
def __add__(self, other: "DataArray" | float) -> "DataArray":
return self._apply_math_operation(other, np.add, txt="+")

def __rsub__(self, other) -> "DataArray":
def __rsub__(self, other: "DataArray" | float) -> "DataArray":
return other + self.__neg__()

def __sub__(self, other) -> "DataArray":
return self._apply_math_operation(other, np.subtract, "-")
def __sub__(self, other: "DataArray" | float) -> "DataArray":
return self._apply_math_operation(other, np.subtract, txt="-")

def __rmul__(self, other) -> "DataArray":
def __rmul__(self, other: "DataArray" | float) -> "DataArray":
return self.__mul__(other)

def __mul__(self, other) -> "DataArray":
return self._apply_math_operation(other, np.multiply, "x") # x in place of *
def __mul__(self, other: "DataArray" | float) -> "DataArray":
return self._apply_math_operation(
other, np.multiply, txt="x"
) # x in place of *

def __pow__(self, other) -> "DataArray":
return self._apply_math_operation(other, np.power, "**")
def __pow__(self, other: float) -> "DataArray":
return self._apply_math_operation(other, np.power, txt="**")

def __truediv__(self, other) -> "DataArray":
return self._apply_math_operation(other, np.divide, "/")
def __truediv__(self, other: "DataArray" | float) -> "DataArray":
return self._apply_math_operation(other, np.divide, txt="/")

def __floordiv__(self, other) -> "DataArray":
return self._apply_math_operation(other, np.floor_divide, "//")
def __floordiv__(self, other: "DataArray" | float) -> "DataArray":
return self._apply_math_operation(other, np.floor_divide, txt="//")

def __mod__(self, other) -> "DataArray":
return self._apply_math_operation(other, np.mod, "%")
def __mod__(self, other: float) -> "DataArray":
return self._apply_math_operation(other, np.mod, txt="%")

def __neg__(self) -> "DataArray":
return self._apply_unary_math_operation(np.negative)
Expand All @@ -1734,7 +1737,7 @@ def __pos__(self) -> "DataArray":
def __abs__(self) -> "DataArray":
return self._apply_unary_math_operation(np.abs)

def _apply_unary_math_operation(self, func) -> "DataArray":
def _apply_unary_math_operation(self, func: Callable) -> "DataArray":
try:
data = func(self.values)

Expand All @@ -1745,7 +1748,9 @@ def _apply_unary_math_operation(self, func) -> "DataArray":
new_da.values = data
return new_da

def _apply_math_operation(self, other, func, txt="with") -> "DataArray":
def _apply_math_operation(
self, other: "DataArray" | float, func: Callable, *, txt: str
) -> "DataArray":
"""Apply a binary math operation with a scalar, an array or another DataArray"""
try:
other_values = other.values if hasattr(other, "values") else other
Expand All @@ -1766,7 +1771,9 @@ def _apply_math_operation(self, other, func, txt="with") -> "DataArray":

return new_da

def _keep_EUM_after_math_operation(self, other, func) -> bool:
def _keep_EUM_after_math_operation(
self, other: "DataArray" | float, func: Callable
) -> bool:
"""Does the math operation falsify the EUM?"""
if hasattr(other, "shape") and hasattr(other, "ndim"):
# other is array-like, so maybe we cannot keep EUM
Expand All @@ -1786,19 +1793,19 @@ def _keep_EUM_after_math_operation(self, other, func) -> bool:

# ============= Logical indexing ===========

def __lt__(self, other) -> "DataArray":
def __lt__(self, other) -> "DataArray": # type: ignore
bmask = self.values < self._other_to_values(other)
return self._boolmask_to_new_DataArray(bmask)

def __gt__(self, other) -> "DataArray":
def __gt__(self, other) -> "DataArray": # type: ignore
bmask = self.values > self._other_to_values(other)
return self._boolmask_to_new_DataArray(bmask)

def __le__(self, other) -> "DataArray":
def __le__(self, other) -> "DataArray": # type: ignore
bmask = self.values <= self._other_to_values(other)
return self._boolmask_to_new_DataArray(bmask)

def __ge__(self, other) -> "DataArray":
def __ge__(self, other) -> "DataArray": # type: ignore
bmask = self.values >= self._other_to_values(other)
return self._boolmask_to_new_DataArray(bmask)

Expand All @@ -1811,10 +1818,12 @@ def __ne__(self, other) -> "DataArray": # type: ignore
return self._boolmask_to_new_DataArray(bmask)

@staticmethod
def _other_to_values(other):
def _other_to_values(
other: "DataArray" | np.ndarray,
) -> np.ndarray:
return other.values if isinstance(other, DataArray) else other

def _boolmask_to_new_DataArray(self, bmask) -> "DataArray":
def _boolmask_to_new_DataArray(self, bmask) -> "DataArray": # type: ignore
return DataArray(
data=bmask,
time=self.time,
Expand All @@ -1833,7 +1842,7 @@ def _to_dataset(self):
{self.name: self}
) # Single-item Dataset (All info is contained in the DataArray, no need for additional info)

def to_dfs(self, filename, **kwargs: Any) -> None:
def to_dfs(self, filename: str | Path, **kwargs: Any) -> None:
"""Write data to a new dfs file
Parameters
Expand Down Expand Up @@ -1994,7 +2003,10 @@ def _time_by_agg_axis(
return time

@staticmethod
def _get_time_idx_list(time: pd.DatetimeIndex, steps):
def _get_time_idx_list(
time: pd.DatetimeIndex,
steps: int | Iterable[int] | str | datetime | pd.DatetimeIndex | slice,
) -> list[int] | slice:
"""Find list of idx in DatetimeIndex"""

return _get_time_idx_list(time, steps)
Expand All @@ -2004,7 +2016,7 @@ def _n_selected_timesteps(time: Sized, k: slice | Sized) -> int:
return _n_selected_timesteps(time, k)

@staticmethod
def _is_boolean_mask(x) -> bool:
def _is_boolean_mask(x: Any) -> bool:
if hasattr(x, "dtype"): # isinstance(x, (np.ndarray, DataArray)):
return x.dtype == np.dtype("bool")
return False
Expand All @@ -2016,14 +2028,16 @@ def _get_by_boolean_mask(data: np.ndarray, mask: np.ndarray) -> np.ndarray:
return data[mask]

@staticmethod
def _set_by_boolean_mask(data: np.ndarray, mask: np.ndarray, value) -> None:
def _set_by_boolean_mask(
data: np.ndarray, mask: np.ndarray, value: np.ndarray
) -> None:
if data.shape != mask.shape:
data[np.broadcast_to(mask, data.shape)] = value
else:
data[mask] = value

@staticmethod
def _parse_time(time) -> pd.DatetimeIndex:
def _parse_time(time: Any) -> pd.DatetimeIndex:
"""Allow anything that we can create a DatetimeIndex from"""
if time is None:
time = [pd.Timestamp(2018, 1, 1)] # TODO is this the correct epoch?
Expand All @@ -2043,7 +2057,11 @@ def _parse_time(time) -> pd.DatetimeIndex:
return index

@staticmethod
def _parse_axis(data_shape, dims, axis) -> int | Tuple[int]:
def _parse_axis(
data_shape: Tuple[int, ...],
dims: Tuple[str, ...],
axis: int | Tuple[int, ...] | str | None,
) -> int | Tuple[int, ...]:
# TODO change to return tuple always
# axis = 0 if axis == "time" else axis
if (axis == "spatial") or (axis == "space"):
Expand Down
Loading

0 comments on commit 905a098

Please sign in to comment.