Skip to content

Commit

Permalink
Backports for v0.14.0rc2 (#3032)
Browse files Browse the repository at this point in the history
* Fix WaveNet inputs (#3022)

* Fix version range for lightning (#3023)

* Allow VCS install using GLUONTS_FALLBACK_VERSION (#3028)

* Make `mypy` checks opt-out instead of opt-in, fix several type issues (#3027)

Co-authored-by: Oleksandr Shchur <[email protected]>

* No API docs for nursery. (#3030)

* Fix: #3030. (#3031)

* Add support for Pydantic v1 and v2. (#3026)

---------

Co-authored-by: Oleksandr Shchur <[email protected]>
Co-authored-by: ddelange <[email protected]>
Co-authored-by: Jasper <[email protected]>
  • Loading branch information
4 people authored Oct 27, 2023
1 parent be52a5f commit 380bf89
Show file tree
Hide file tree
Showing 126 changed files with 458 additions and 304 deletions.
1 change: 1 addition & 0 deletions .github/workflows/style_type_checks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ jobs:
pip install click black mypy
pip install types-python-dateutil
pip install types-waitress
pip install types-PyYAML
- name: Style and type checks
run: |
just black
Expand Down
21 changes: 13 additions & 8 deletions docs/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,26 @@ SPHINXBUILD ?= sphinx-build
SOURCEDIR = .
BUILDDIR = _build

APIDOC = sphinx-apidoc
APIDOC_OPTS = --implicit-namespaces --separate --module-first
APIDOC_ROOT = gluonts

# Put it first so that "make" without argument is like "make help".
help:
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

.PHONY: help Makefile

apidoc:
@rm -Rf api/$(APIDOC_ROOT)
@$(APIDOC) $(APIDOC_OPTS) -o api/$(APIDOC_ROOT) ../src/$(APIDOC_ROOT) setup* */bin/* test docs *pycache*
@rm -Rf api/$(APIDOC_ROOT)/modules.rst
@sed -i"" -e "s/$(APIDOC_ROOT) package/API Docs/" api/$(APIDOC_ROOT)/$(APIDOC_ROOT).rst
@rm -Rf api/gluonts

@sphinx-apidoc \
--implicit-namespaces \
--separate \
--module-first \
-o api/gluonts \
../src/gluonts \
../src/gluonts/nursery/* \
../src/gluonts/pydantic.py

@rm -Rf api/gluonts/modules.rst
@sed -i"" -e "s/gluonts package/API Docs/" api/gluonts/gluonts.rst

# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
Expand Down
9 changes: 3 additions & 6 deletions requirements/requirements-pytorch.txt
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
torch>=1.9,<3
lightning>=1.8,<2.2
lightning>=2.0,<2.2
# Capping `lightning` does not cap `pytorch_lightning`, so we cap manually
pytorch_lightning>=1.8,<2.2
# Need to pin protobuf (for now)
# See: https://github.com/PyTorchLightning/pytorch-lightning/issues/13159
protobuf~=3.19.0
pytorch_lightning>=2.0,<2.2
scipy~=1.10; python_version > "3.7.0"
scipy~=1.7.3; python_version <= "3.7.0"
scipy~=1.7.3; python_version <= "3.7.0"
2 changes: 1 addition & 1 deletion requirements/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
numpy~=1.16
pandas>=1.0,<3
pydantic~=1.7
pydantic>=1.7,<3
tqdm~=4.23
toolz~=0.10

Expand Down
24 changes: 15 additions & 9 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,18 +57,24 @@ def run(self):
# otherwise a module-not-found error is thrown
import mypy.api

folders = [
str(p.parent.resolve()) for p in ROOT.glob("src/**/.typesafe")
excluded_folders = [
str(p.parent.relative_to(ROOT)) for p in ROOT.glob("src/**/.typeunsafe")
]

print(
"The following folders contain a `.typesafe` marker file "
"and will be type-checked with `mypy`:"
)
for folder in folders:
if len(excluded_folders) > 0:
print(
"The following folders contain a `.typeunsafe` marker file "
"and will *not* be type-checked with `mypy`:"
)
for folder in excluded_folders:
print(f" {folder}")

std_out, std_err, exit_code = mypy.api.run(folders)
args = [str(ROOT / "src")]
for folder in excluded_folders:
args.append("--exclude")
args.append(folder)

std_out, std_err, exit_code = mypy.api.run(args)

print(std_out, file=sys.stdout)
print(std_err, file=sys.stderr)
Expand All @@ -78,7 +84,7 @@ def run(self):
f"""
Mypy command
mypy {" ".join(folders)}
mypy {" ".join(args)}
returned a non-zero exit code. Fix the type errors listed above
and then run
Expand Down
11 changes: 8 additions & 3 deletions src/gluonts/core/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,15 @@
from typing import Any, Type, TypeVar

import numpy as np
from pydantic import BaseConfig, BaseModel, ValidationError, create_model

from gluonts.core import fqname_for
from gluonts.exceptions import GluonTSHyperparametersError
from gluonts.pydantic import (
BaseConfig,
BaseModel,
ValidationError,
create_model,
)


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -252,7 +257,7 @@ def validated(base_model=None):
>>> c = ComplexNumber(y=None)
Traceback (most recent call last):
...
pydantic.error_wrappers.ValidationError: 1 validation error for
pydantic.v1.error_wrappers.ValidationError: 1 validation error for
ComplexNumberModel
y
none is not an allowed value (type=type_error.none.not_allowed)
Expand All @@ -262,7 +267,7 @@ def validated(base_model=None):
accessed through the ``Model`` attribute of the decorated initializer.
>>> ComplexNumber.__init__.Model
<class 'pydantic.main.ComplexNumberModel'>
<class 'pydantic.v1.main.ComplexNumberModel'>
The Pydantic model is synthesized automatically from on the parameter
names and types of the decorated initializer. In the ``ComplexNumber``
Expand Down
3 changes: 1 addition & 2 deletions src/gluonts/core/serde/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,8 @@

from toolz.dicttoolz import valmap

from pydantic import BaseModel

from gluonts.core import fqname_for
from gluonts.pydantic import BaseModel

bad_type_msg = textwrap.dedent(
"""
Expand Down
6 changes: 2 additions & 4 deletions src/gluonts/core/serde/_dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,8 @@
TypeVar,
)

import pydantic
import pydantic.dataclasses

from gluonts.itertools import select
from gluonts.pydantic import pydantic, dataclass as pydantic_dataclass

T = TypeVar("T")

Expand Down Expand Up @@ -152,7 +150,7 @@ class Config:
arbitrary_types_allowed = True

# make `cls` a dataclass
pydantic.dataclasses.dataclass(
pydantic_dataclass(
init=init,
repr=repr,
eq=eq,
Expand Down
3 changes: 1 addition & 2 deletions src/gluonts/core/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,7 @@ def fn(debug):
from operator import attrgetter
from typing import Any

import pydantic
from pydantic.utils import deep_update
from gluonts.pydantic import pydantic, deep_update


class ListElement:
Expand Down
Empty file removed src/gluonts/dataset/.typesafe
Empty file.
4 changes: 1 addition & 3 deletions src/gluonts/dataset/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,17 @@
import pandas as pd
from pandas.tseries.frequencies import to_offset

import pydantic

from gluonts import json
from gluonts.itertools import Cached, Map
from gluonts.dataset.field_names import FieldName
from gluonts.dataset.schema import Translator
from gluonts.exceptions import GluonTSDataError
from gluonts.pydantic import pydantic


from . import Dataset, DatasetCollection, DataEntry, DataBatch # noqa
from . import jsonl, DatasetWriter


arrow: Optional[ModuleType]

try:
Expand Down
2 changes: 1 addition & 1 deletion src/gluonts/dataset/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from typing import Callable, Iterable, Optional

import numpy as np
from pydantic import BaseModel

from gluonts.dataset import DataBatch, Dataset
from gluonts.itertools import (
Expand All @@ -25,6 +24,7 @@
batcher,
rows_to_columns,
)
from gluonts.pydantic import BaseModel
from gluonts.transform import (
AdhocTransform,
Identity,
Expand Down
11 changes: 11 additions & 0 deletions src/gluonts/ev/aggregations.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ class Sum(Aggregation):
partial_result: Optional[Union[List[np.ndarray], np.ndarray]] = None

def step(self, values: np.ndarray) -> None:
assert self.axis is None or isinstance(self.axis, tuple)

summed_values = np.ma.sum(values, axis=self.axis)

if self.axis is None or 0 in self.axis:
Expand All @@ -57,9 +59,12 @@ def step(self, values: np.ndarray) -> None:
else:
if self.partial_result is None:
self.partial_result = []
assert isinstance(self.partial_result, list)
self.partial_result.append(summed_values)

def get(self) -> np.ndarray:
assert self.axis is None or isinstance(self.axis, tuple)

if self.axis is None or 0 in self.axis:
return np.ma.copy(self.partial_result)

Expand All @@ -85,6 +90,8 @@ class Mean(Aggregation):
n: Optional[Union[int, np.ndarray]] = None

def step(self, values: np.ndarray) -> None:
assert self.axis is None or isinstance(self.axis, tuple)

if self.axis is None or 0 in self.axis:
summed_values = np.ma.sum(values, axis=self.axis)
if self.partial_result is None:
Expand All @@ -101,10 +108,14 @@ def step(self, values: np.ndarray) -> None:
self.partial_result = []

mean_values = np.ma.mean(values, axis=self.axis)
assert isinstance(self.partial_result, list)
self.partial_result.append(mean_values)

def get(self) -> np.ndarray:
assert self.axis is None or isinstance(self.axis, tuple)

if self.axis is None or 0 in self.axis:
assert isinstance(self.partial_result, np.ndarray)
return self.partial_result / self.n

return np.ma.concatenate(self.partial_result)
5 changes: 4 additions & 1 deletion src/gluonts/ev/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,9 @@ def __call__(self, axis: Optional[int] = None) -> Metric:


class BaseMetricDefinition:
def __call__(self, axis):
raise NotImplementedError()

def __add__(self, other) -> MetricDefinitionCollection:
if isinstance(other, MetricDefinitionCollection):
return other + self
Expand All @@ -154,7 +157,7 @@ def add(self, *others):

@dataclass
class MetricDefinitionCollection(BaseMetricDefinition):
metrics: List[MetricDefinition]
metrics: List[BaseMetricDefinition]

def __call__(self, axis: Optional[int] = None) -> MetricCollection:
return MetricCollection([metric(axis=axis) for metric in self.metrics])
Expand Down
1 change: 1 addition & 0 deletions src/gluonts/ev/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

def num_masked_target_values(data: Dict[str, np.ndarray]) -> np.ndarray:
if np.ma.isMaskedArray(data["label"]):
assert isinstance(data["label"], np.ma.MaskedArray)
return data["label"].mask.astype(float)
else:
return np.zeros(data["label"].shape)
Expand Down
Empty file removed src/gluonts/evaluation/.typesafe
Empty file.
2 changes: 1 addition & 1 deletion src/gluonts/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from typing import Any

from pydantic.error_wrappers import ValidationError, display_errors
from gluonts.pydantic import ValidationError, display_errors


class GluonTSException(Exception):
Expand Down
2 changes: 1 addition & 1 deletion src/gluonts/ext/naive_2/_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from gluonts.model.predictor import RepresentablePredictor


def seasonality_test(past_ts_data: np.array, season_length: int) -> bool:
def seasonality_test(past_ts_data: np.ndarray, season_length: int) -> bool:
"""
Test the time series for seasonal patterns by performing a 90% auto-
correlation test.
Expand Down
Empty file removed src/gluonts/ext/prophet/.typesafe
Empty file.
6 changes: 4 additions & 2 deletions src/gluonts/ext/r_forecast/_hierarchical_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
"erm",
]

HIERARCHICAL_SAMPLE_FORECAST_METHODS = [] # TODO: Add `depbu_mint`.
HIERARCHICAL_SAMPLE_FORECAST_METHODS: List[str] = [] # TODO: Add `depbu_mint`.

SUPPORTED_HIERARCHICAL_METHODS = (
HIERARCHICAL_POINT_FORECAST_METHODS + HIERARCHICAL_SAMPLE_FORECAST_METHODS
Expand Down Expand Up @@ -200,6 +200,8 @@ def _get_r_forecast(self, data: Dict) -> Dict:
else:
hier_ts = self._hts_pkg.gts(y_bottom_ts, groups=nodes)

assert isinstance(self.params["num_samples"], int)

forecast = self._r_method(hier_ts, r_params)

all_forecasts = list(forecast)
Expand Down Expand Up @@ -244,7 +246,7 @@ def _forecast_dict_to_obj(
forecast_dict: Dict,
forecast_start_date: pd.Timestamp,
item_id: Optional[str],
info: Dict,
info: Optional[Dict],
) -> SampleForecast:
samples = np.array(forecast_dict["samples"])

Expand Down
4 changes: 2 additions & 2 deletions src/gluonts/ext/r_forecast/_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def __init__(
self,
freq: str,
prediction_length: int,
period: int = None,
period: Optional[int] = None,
trunc_length: Optional[int] = None,
save_info: bool = False,
r_file_prefix: str = "",
Expand Down Expand Up @@ -215,7 +215,7 @@ def _forecast_dict_to_obj(
forecast_dict: Dict,
forecast_start_date: pd.Timestamp,
item_id: Optional[str],
info: Dict,
info: Optional[Dict],
) -> Forecast:
"""
Returns object of type `gluonts.model.Forecast`.
Expand Down
8 changes: 4 additions & 4 deletions src/gluonts/ext/r_forecast/_univariate_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

from typing import Dict, Optional
from typing import Dict, Optional, Any

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -79,7 +79,7 @@ def __init__(
freq: str,
prediction_length: int,
method_name: str = "ets",
period: int = None,
period: Optional[int] = None,
trunc_length: Optional[int] = None,
save_info: bool = False,
params: Dict = dict(),
Expand All @@ -100,7 +100,7 @@ def __init__(
self.method_name = method_name
self._r_method = self._robjects.r[method_name]

self.params = {
self.params: Dict[str, Any] = {
"prediction_length": self.prediction_length,
"frequency": self.period,
}
Expand Down Expand Up @@ -185,7 +185,7 @@ def _forecast_dict_to_obj(
forecast_dict: Dict,
forecast_start_date: pd.Timestamp,
item_id: Optional[str],
info: Dict,
info: Optional[Dict],
) -> QuantileForecast:
stats_dict = {"mean": forecast_dict["mean"]}

Expand Down
Loading

0 comments on commit 380bf89

Please sign in to comment.