From 380bf89b213796b731258a2e78d610d147902ee2 Mon Sep 17 00:00:00 2001
From: Lorenzo Stella
Date: Fri, 27 Oct 2023 10:10:35 +0200
Subject: [PATCH] Backports for v0.14.0rc2 (#3032)

* Fix WaveNet inputs (#3022)

* Fix version range for lightning (#3023)

* Allow VCS install using GLUONTS_FALLBACK_VERSION (#3028)

* Make `mypy` checks opt-out instead of opt-in, fix several type issues (#3027)

Co-authored-by: Oleksandr Shchur

* No API docs for nursery. (#3030)

* Fix: #3030. (#3031)

* Add support for Pydantic v1 and v2. (#3026)

---------

Co-authored-by: Oleksandr Shchur
Co-authored-by: ddelange <14880945+ddelange@users.noreply.github.com>
Co-authored-by: Jasper
---
 .github/workflows/style_type_checks.yml       |  1 +
 docs/Makefile                                 | 21 ++++--
 requirements/requirements-pytorch.txt         |  9 +--
 requirements/requirements.txt                 |  2 +-
 setup.py                                      | 24 +++---
 src/gluonts/core/component.py                 | 11 ++-
 src/gluonts/core/serde/_base.py               |  3 +-
 src/gluonts/core/serde/_dataclass.py          |  6 +-
 src/gluonts/core/settings.py                  |  3 +-
 src/gluonts/dataset/.typesafe                 |  0
 src/gluonts/dataset/common.py                 |  4 +-
 src/gluonts/dataset/loader.py                 |  2 +-
 src/gluonts/ev/aggregations.py                | 11 +++
 src/gluonts/ev/metrics.py                     |  5 +-
 src/gluonts/ev/stats.py                       |  1 +
 src/gluonts/evaluation/.typesafe              |  0
 src/gluonts/exceptions.py                     |  2 +-
 src/gluonts/ext/naive_2/_predictor.py         |  2 +-
 src/gluonts/ext/prophet/.typesafe             |  0
 .../ext/r_forecast/_hierarchical_predictor.py |  6 +-
 src/gluonts/ext/r_forecast/_predictor.py      |  4 +-
 .../ext/r_forecast/_univariate_predictor.py   |  8 +-
 src/gluonts/ext/rotbaum/_model.py             | 15 ++--
 src/gluonts/ext/rotbaum/_predictor.py         | 37 +++++++---
 src/gluonts/ext/rotbaum/_preprocess.py        | 24 +++---
 src/gluonts/ext/rotbaum/_types.py             |  3 +-
 src/gluonts/itertools.py                      | 10 +--
 src/gluonts/meta/_version.py                  |  4 +-
 src/gluonts/model/evaluation.py               |  4 +-
 src/gluonts/model/forecast.py                 | 21 ++++-
 src/gluonts/model/forecast_generator.py       | 10 ++-
 src/gluonts/model/npts/.typesafe              |  0
 src/gluonts/model/predictor.py                | 11 +--
 src/gluonts/model/seasonal_naive/.typesafe    |  0
 src/gluonts/model/trivial/.typesafe           |  0
 src/gluonts/model/trivial/mean.py             |  2 +-
 src/gluonts/mx/batchify.py                    |  2 +-
 src/gluonts/mx/block/.typesafe                |  0
 src/gluonts/mx/distribution/.typesafe         |  0
 src/gluonts/mx/kernels/.typesafe              |  0
 src/gluonts/mx/model/canonical/_network.py    |  4 +-
 src/gluonts/mx/model/deep_factor/.typesafe    |  0
 src/gluonts/mx/model/deepar/.typesafe         |  0
 src/gluonts/mx/model/deepstate/.typesafe      |  0
 src/gluonts/mx/model/deepvar/.typesafe        |  0
 .../mx/model/deepvar_hierarchical/.typesafe   |  0
 src/gluonts/mx/model/estimator.py             | 18 +++--
 src/gluonts/mx/model/forecast.py              |  3 +-
 src/gluonts/mx/model/gp_forecaster/.typesafe  |  0
 .../model/gp_forecaster/gaussian_process.py   |  4 +-
 src/gluonts/mx/model/gpvar/.typesafe          |  0
 src/gluonts/mx/model/lstnet/_network.py       |  2 +-
 src/gluonts/mx/model/n_beats/.typesafe        |  0
 src/gluonts/mx/model/n_beats/_ensemble.py     |  2 +-
 src/gluonts/mx/model/predictor.py             | 11 +--
 src/gluonts/mx/model/renewal/_predictor.py    |  6 +-
 src/gluonts/mx/model/seq2seq/.typesafe        |  0
 .../mx/model/simple_feedforward/.typesafe     |  0
 src/gluonts/mx/model/tft/.typesafe            |  0
 src/gluonts/mx/model/tpp/.typesafe            |  0
 src/gluonts/mx/model/transformer/.typesafe    |  0
 src/gluonts/mx/model/wavenet/.typesafe        |  0
 src/gluonts/mx/prelude.py                     |  4 +-
 src/gluonts/mx/representation/.typesafe       |  0
 src/gluonts/mx/trainer/.typesafe              |  0
 src/gluonts/mx/trainer/callback.py            |  2 +-
 .../mx/trainer/learning_rate_scheduler.py     |  2 +-
 .../{core/.typesafe => nursery/.typeunsafe}   |  0
 .../_precision_recall_utils.py                | 12 +--
 .../supervised_metrics/bounded_pr_auc.py      |  6 +-
 src/gluonts/pydantic.py                       | 74 +++++++++++++++++++
 src/gluonts/shell/.typesafe                   |  0
 src/gluonts/shell/sagemaker/train.py          |  2 +-
 src/gluonts/shell/serve/__init__.py           |  2 +-
 src/gluonts/shell/serve/app.py                |  3 +-
 src/gluonts/testutil/.typesafe                |  0
 src/gluonts/time_feature/.typesafe            |  0
 src/gluonts/time_feature/_base.py             |  3 +-
 src/gluonts/torch/batchify.py                 |  8 +-
 src/gluonts/torch/distributions/__init__.py   |  3 -
 .../torch/distributions/binned_uniforms.py    | 10 +--
 .../distributions/distribution_output.py      | 10 +--
 .../torch/distributions/generalized_pareto.py |  4 +-
 .../implicit_quantile_network.py              |  4 +-
 src/gluonts/torch/distributions/isqf.py       |  8 +-
 .../torch/distributions/negative_binomial.py  |  2 +-
 .../torch/distributions/piecewise_linear.py   |  6 +-
 .../distributions/spliced_binned_pareto.py    | 20 ++---
 src/gluonts/torch/distributions/studentT.py   |  2 +-
 .../torch/distributions/truncated_normal.py   |  5 +-
 src/gluonts/torch/model/deep_npts/.typesafe   |  0
 src/gluonts/torch/model/deepar/.typesafe      |  0
 src/gluonts/torch/model/estimator.py          | 20 ++---
 src/gluonts/torch/model/forecast.py           |  9 ++-
 src/gluonts/torch/model/lightning_util.py     |  4 -
 .../mqf2.py => model/mqf2/distribution.py}    | 30 ++++----
 src/gluonts/torch/model/mqf2/estimator.py     | 11 ++-
 src/gluonts/torch/model/mqf2/icnn_utils.py    |  6 +-
 .../torch/model/mqf2/lightning_module.py      |  3 +-
 src/gluonts/torch/model/mqf2/module.py        | 10 +--
 src/gluonts/torch/model/patch_tst/module.py   |  6 +-
 src/gluonts/torch/model/predictor.py          |  6 +-
 .../torch/model/simple_feedforward/.typesafe  |  0
 src/gluonts/torch/model/tft/layers.py         | 10 +--
 src/gluonts/torch/model/tft/module.py         | 36 ++++++---
 src/gluonts/torch/model/wavenet/estimator.py  |  3 +-
 .../torch/model/wavenet/lightning_module.py   |  4 +
 src/gluonts/torch/model/wavenet/module.py     | 36 +++++----
 src/gluonts/torch/modules/feature.py          | 15 ++--
 src/gluonts/torch/modules/lookup_table.py     |  1 +
 src/gluonts/torch/modules/loss.py             | 10 ++-
 src/gluonts/torch/modules/quantile_output.py  |  2 +-
 src/gluonts/torch/prelude.py                  |  4 +-
 src/gluonts/torch/util.py                     |  2 +-
 src/gluonts/transform/.typesafe               |  0
 src/gluonts/transform/sampler.py              |  2 +-
 src/gluonts/zebras/.typesafe                  |  0
 src/gluonts/zebras/schema.py                  |  3 +-
 test/core/test_serde_dataclass.py             |  3 +-
 test/core/test_serde_flat.py                  |  3 +-
 test/core/test_settings.py                    |  3 +-
 test/ext/naive_2/test_predictors.py           |  2 +-
 .../test_mx_distribution_inference.py         |  2 +-
 test/mx/test_mx_serde.py                      |  2 +-
 test/torch/model/test_mqf2_modules.py         |  2 +-
 .../test_torch_distribution_inference.py      |  2 +-
 126 files changed, 458 insertions(+), 304 deletions(-)
 delete mode 100644 src/gluonts/dataset/.typesafe
 delete mode 100644 src/gluonts/evaluation/.typesafe
 delete mode 100644 src/gluonts/ext/prophet/.typesafe
 delete mode 100644 src/gluonts/model/npts/.typesafe
 delete mode 100644 src/gluonts/model/seasonal_naive/.typesafe
 delete mode 100644 src/gluonts/model/trivial/.typesafe
 delete mode 100644 src/gluonts/mx/block/.typesafe
 delete mode 100644 src/gluonts/mx/distribution/.typesafe
 delete mode 100644 src/gluonts/mx/kernels/.typesafe
 delete mode 100644 src/gluonts/mx/model/deep_factor/.typesafe
 delete mode 100644 src/gluonts/mx/model/deepar/.typesafe
 delete mode 100644 src/gluonts/mx/model/deepstate/.typesafe
 delete mode 100644 src/gluonts/mx/model/deepvar/.typesafe
 delete mode 100644 src/gluonts/mx/model/deepvar_hierarchical/.typesafe
 delete mode 100644 src/gluonts/mx/model/gp_forecaster/.typesafe
 delete mode 100644 src/gluonts/mx/model/gpvar/.typesafe
 delete mode 100644 src/gluonts/mx/model/n_beats/.typesafe
 delete mode 100644 src/gluonts/mx/model/seq2seq/.typesafe
 delete mode 100644 src/gluonts/mx/model/simple_feedforward/.typesafe
 delete mode 100644 src/gluonts/mx/model/tft/.typesafe
 delete mode 100644 src/gluonts/mx/model/tpp/.typesafe
 delete mode 100644 src/gluonts/mx/model/transformer/.typesafe
 delete mode 100644 src/gluonts/mx/model/wavenet/.typesafe
 delete mode 100644 src/gluonts/mx/representation/.typesafe
 delete mode 100644 src/gluonts/mx/trainer/.typesafe
 rename src/gluonts/{core/.typesafe => nursery/.typeunsafe} (100%)
 create mode 100644 src/gluonts/pydantic.py
 delete mode 100644 src/gluonts/shell/.typesafe
 delete mode 100644 src/gluonts/testutil/.typesafe
 delete mode 100644 src/gluonts/time_feature/.typesafe
 delete mode 100644 src/gluonts/torch/model/deep_npts/.typesafe
 delete mode 100644 src/gluonts/torch/model/deepar/.typesafe
 rename src/gluonts/torch/{distributions/mqf2.py => model/mqf2/distribution.py} (94%)
 delete mode 100644 src/gluonts/torch/model/simple_feedforward/.typesafe
 delete mode 100644 src/gluonts/transform/.typesafe
 delete mode 100644 src/gluonts/zebras/.typesafe

diff --git a/.github/workflows/style_type_checks.yml b/.github/workflows/style_type_checks.yml
index a5d1721deb..f8ffcc4ca8 100644
--- a/.github/workflows/style_type_checks.yml
+++ b/.github/workflows/style_type_checks.yml
@@ -22,6 +22,7 @@ jobs:
         pip install click black mypy
         pip install types-python-dateutil
         pip install types-waitress
+        pip install types-PyYAML
     - name: Style and type checks
       run: |
         just black
diff --git a/docs/Makefile b/docs/Makefile
index ccd2836f46..4e5cce0da9 100644
--- a/docs/Makefile
+++ b/docs/Makefile
@@ -8,10 +8,6 @@ SPHINXBUILD   ?= sphinx-build
 SOURCEDIR     = .
 BUILDDIR      = _build
 
-APIDOC        = sphinx-apidoc
-APIDOC_OPTS   = --implicit-namespaces --separate --module-first
-APIDOC_ROOT   = gluonts
-
 # Put it first so that "make" without argument is like "make help".
 help:
 	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
@@ -19,10 +15,19 @@ help:
 .PHONY: help Makefile
 
 apidoc:
-	@rm -Rf api/$(APIDOC_ROOT)
-	@$(APIDOC) $(APIDOC_OPTS) -o api/$(APIDOC_ROOT) ../src/$(APIDOC_ROOT) setup* */bin/* test docs *pycache*
-	@rm -Rf api/$(APIDOC_ROOT)/modules.rst
-	@sed -i"" -e "s/$(APIDOC_ROOT) package/API Docs/" api/$(APIDOC_ROOT)/$(APIDOC_ROOT).rst
+	@rm -Rf api/gluonts
+
+	@sphinx-apidoc \
+		--implicit-namespaces \
+		--separate \
+		--module-first \
+		-o api/gluonts \
+		../src/gluonts \
+		../src/gluonts/nursery/* \
+		../src/gluonts/pydantic.py
+
+	@rm -Rf api/gluonts/modules.rst
+	@sed -i"" -e "s/gluonts package/API Docs/" api/gluonts/gluonts.rst
 
 # Catch-all target: route all unknown targets to Sphinx using the new
 # "make mode" option.  $(O) is meant as a shortcut for $(SPHINXOPTS).
diff --git a/requirements/requirements-pytorch.txt b/requirements/requirements-pytorch.txt
index 03f4e997ab..683bb101d7 100644
--- a/requirements/requirements-pytorch.txt
+++ b/requirements/requirements-pytorch.txt
@@ -1,9 +1,6 @@
 torch>=1.9,<3
-lightning>=1.8,<2.2
+lightning>=2.0,<2.2
 # Capping `lightning` does not cap `pytorch_lightning`, so we cap manually
-pytorch_lightning>=1.8,<2.2
-# Need to pin protobuf (for now)
-# See: https://github.com/PyTorchLightning/pytorch-lightning/issues/13159
-protobuf~=3.19.0
+pytorch_lightning>=2.0,<2.2
 scipy~=1.10; python_version > "3.7.0"
-scipy~=1.7.3; python_version <= "3.7.0"
\ No newline at end of file
+scipy~=1.7.3; python_version <= "3.7.0"
diff --git a/requirements/requirements.txt b/requirements/requirements.txt
index c833b78584..49df084e85 100644
--- a/requirements/requirements.txt
+++ b/requirements/requirements.txt
@@ -1,6 +1,6 @@
 numpy~=1.16
 pandas>=1.0,<3
-pydantic~=1.7
+pydantic>=1.7,<3
 tqdm~=4.23
 toolz~=0.10
diff --git a/setup.py b/setup.py
index ec0b6e76d2..8673c79c8b 100644
--- a/setup.py
+++ b/setup.py
@@ -57,18 +57,24 @@ def run(self):
         # otherwise a module-not-found error is thrown
         import mypy.api
 
-        folders = [
-            str(p.parent.resolve()) for p in ROOT.glob("src/**/.typesafe")
+        excluded_folders = [
+            str(p.parent.relative_to(ROOT)) for p in ROOT.glob("src/**/.typeunsafe")
         ]
 
-        print(
-            "The following folders contain a `.typesafe` marker file "
-            "and will be type-checked with `mypy`:"
-        )
-        for folder in folders:
+        if len(excluded_folders) > 0:
+            print(
+                "The following folders contain a `.typeunsafe` marker file "
+                "and will *not* be type-checked with `mypy`:"
+            )
+            for folder in excluded_folders:
                 print(f"  {folder}")
 
-        std_out, std_err, exit_code = mypy.api.run(folders)
+        args = [str(ROOT / "src")]
+        for folder in excluded_folders:
+            args.append("--exclude")
+            args.append(folder)
+
+        std_out, std_err, exit_code = mypy.api.run(args)
 
         print(std_out, file=sys.stdout)
         print(std_err, file=sys.stderr)
@@ -78,7 +84,7 @@ def run(self):
             f"""
             Mypy command
 
-                mypy {" ".join(folders)}
+                mypy {" ".join(args)}
 
             returned a non-zero exit code. Fix the type errors listed above
             and then run
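For illustration only — a minimal sketch of the exclude-based `mypy` run that `setup.py` now performs, separate from the patch itself; the marker folder shown is hypothetical:

    import mypy.api

    # Folders marked with a `.typeunsafe` file are skipped; everything
    # else under src/ is type-checked by default (opt-out, not opt-in).
    excluded = ["src/gluonts/nursery"]  # hypothetical marker location
    args = ["src"]
    for folder in excluded:
        args += ["--exclude", folder]

    # mypy.api.run returns (stdout, stderr, exit_status)
    std_out, std_err, exit_code = mypy.api.run(args)
    print(std_out)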
diff --git a/src/gluonts/core/component.py b/src/gluonts/core/component.py
index 2bfb5e10a7..8fcdc07f8f 100644
--- a/src/gluonts/core/component.py
+++ b/src/gluonts/core/component.py
@@ -19,10 +19,15 @@
 from typing import Any, Type, TypeVar
 
 import numpy as np
-from pydantic import BaseConfig, BaseModel, ValidationError, create_model
 
 from gluonts.core import fqname_for
 from gluonts.exceptions import GluonTSHyperparametersError
+from gluonts.pydantic import (
+    BaseConfig,
+    BaseModel,
+    ValidationError,
+    create_model,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -252,7 +257,7 @@ def validated(base_model=None):
     >>> c = ComplexNumber(y=None)
     Traceback (most recent call last):
         ...
-    pydantic.error_wrappers.ValidationError: 1 validation error for
+    pydantic.v1.error_wrappers.ValidationError: 1 validation error for
     ComplexNumberModel
     y
       none is not an allowed value (type=type_error.none.not_allowed)
@@ -262,7 +267,7 @@ def validated(base_model=None):
     accessed through the ``Model`` attribute of the decorated initializer.
 
     >>> ComplexNumber.__init__.Model
-
+
 
     The Pydantic model is synthesized automatically from on the parameter
     names and types of the decorated initializer. In the ``ComplexNumber``
diff --git a/src/gluonts/core/serde/_base.py b/src/gluonts/core/serde/_base.py
index 5432bd8105..9efee3aab8 100644
--- a/src/gluonts/core/serde/_base.py
+++ b/src/gluonts/core/serde/_base.py
@@ -20,9 +20,8 @@
 
 from toolz.dicttoolz import valmap
 
-from pydantic import BaseModel
-
 from gluonts.core import fqname_for
+from gluonts.pydantic import BaseModel
 
 bad_type_msg = textwrap.dedent(
     """
diff --git a/src/gluonts/core/serde/_dataclass.py b/src/gluonts/core/serde/_dataclass.py
index 8cf583b32c..1b31dd7adb 100644
--- a/src/gluonts/core/serde/_dataclass.py
+++ b/src/gluonts/core/serde/_dataclass.py
@@ -24,10 +24,8 @@
     TypeVar,
 )
 
-import pydantic
-import pydantic.dataclasses
-
 from gluonts.itertools import select
+from gluonts.pydantic import pydantic, dataclass as pydantic_dataclass
 
 T = TypeVar("T")
 
@@ -152,7 +150,7 @@ class Config:
         arbitrary_types_allowed = True
 
     # make `cls` a dataclass
-    pydantic.dataclasses.dataclass(
+    pydantic_dataclass(
         init=init,
         repr=repr,
         eq=eq,
diff --git a/src/gluonts/core/settings.py b/src/gluonts/core/settings.py
index fa3ff55dba..2b08ca2085 100644
--- a/src/gluonts/core/settings.py
+++ b/src/gluonts/core/settings.py
@@ -76,8 +76,7 @@ def fn(debug):
 from operator import attrgetter
 from typing import Any
 
-import pydantic
-from pydantic.utils import deep_update
+from gluonts.pydantic import pydantic, deep_update
 
 
 class ListElement:
diff --git a/src/gluonts/dataset/.typesafe b/src/gluonts/dataset/.typesafe
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/src/gluonts/dataset/common.py b/src/gluonts/dataset/common.py
index eb39ea1a5f..8bde27dd2d 100644
--- a/src/gluonts/dataset/common.py
+++ b/src/gluonts/dataset/common.py
@@ -22,19 +22,17 @@
 import pandas as pd
 from pandas.tseries.frequencies import to_offset
 
-import pydantic
-
 from gluonts import json
 from gluonts.itertools import Cached, Map
 from gluonts.dataset.field_names import FieldName
 from gluonts.dataset.schema import Translator
 from gluonts.exceptions import GluonTSDataError
+from gluonts.pydantic import pydantic
 
 from . import Dataset, DatasetCollection, DataEntry, DataBatch  # noqa
 from . import jsonl, DatasetWriter
 
-
 arrow: Optional[ModuleType]
 try:
diff --git a/src/gluonts/dataset/loader.py b/src/gluonts/dataset/loader.py
index e87fc53afa..0339502ee7 100644
--- a/src/gluonts/dataset/loader.py
+++ b/src/gluonts/dataset/loader.py
@@ -15,7 +15,6 @@
 from typing import Callable, Iterable, Optional
 
 import numpy as np
-from pydantic import BaseModel
 
 from gluonts.dataset import DataBatch, Dataset
 from gluonts.itertools import (
@@ -25,6 +24,7 @@
     batcher,
     rows_to_columns,
 )
+from gluonts.pydantic import BaseModel
 from gluonts.transform import (
     AdhocTransform,
     Identity,
diff --git a/src/gluonts/ev/aggregations.py b/src/gluonts/ev/aggregations.py
index c56753c854..2a8a8231f8 100644
--- a/src/gluonts/ev/aggregations.py
+++ b/src/gluonts/ev/aggregations.py
@@ -48,6 +48,8 @@ class Sum(Aggregation):
     partial_result: Optional[Union[List[np.ndarray], np.ndarray]] = None
 
     def step(self, values: np.ndarray) -> None:
+        assert self.axis is None or isinstance(self.axis, tuple)
+
         summed_values = np.ma.sum(values, axis=self.axis)
 
         if self.axis is None or 0 in self.axis:
@@ -57,9 +59,12 @@ def step(self, values: np.ndarray) -> None:
         else:
             if self.partial_result is None:
                 self.partial_result = []
+            assert isinstance(self.partial_result, list)
             self.partial_result.append(summed_values)
 
     def get(self) -> np.ndarray:
+        assert self.axis is None or isinstance(self.axis, tuple)
+
         if self.axis is None or 0 in self.axis:
             return np.ma.copy(self.partial_result)
 
@@ -85,6 +90,8 @@ class Mean(Aggregation):
     n: Optional[Union[int, np.ndarray]] = None
 
     def step(self, values: np.ndarray) -> None:
+        assert self.axis is None or isinstance(self.axis, tuple)
+
         if self.axis is None or 0 in self.axis:
             summed_values = np.ma.sum(values, axis=self.axis)
             if self.partial_result is None:
@@ -101,10 +108,14 @@ def step(self, values: np.ndarray) -> None:
                 self.partial_result = []
 
             mean_values = np.ma.mean(values, axis=self.axis)
+            assert isinstance(self.partial_result, list)
             self.partial_result.append(mean_values)
 
     def get(self) -> np.ndarray:
+        assert self.axis is None or isinstance(self.axis, tuple)
+
         if self.axis is None or 0 in self.axis:
+            assert isinstance(self.partial_result, np.ndarray)
             return self.partial_result / self.n
 
         return np.ma.concatenate(self.partial_result)
diff --git a/src/gluonts/ev/metrics.py b/src/gluonts/ev/metrics.py
index 9408ccf045..7d3c0b382b 100644
--- a/src/gluonts/ev/metrics.py
+++ b/src/gluonts/ev/metrics.py
@@ -139,6 +139,9 @@ def __call__(self, axis: Optional[int] = None) -> Metric:
 
 
 class BaseMetricDefinition:
+    def __call__(self, axis):
+        raise NotImplementedError()
+
     def __add__(self, other) -> MetricDefinitionCollection:
         if isinstance(other, MetricDefinitionCollection):
             return other + self
@@ -154,7 +157,7 @@ def add(self, *others):
 
 @dataclass
 class MetricDefinitionCollection(BaseMetricDefinition):
-    metrics: List[MetricDefinition]
+    metrics: List[BaseMetricDefinition]
 
     def __call__(self, axis: Optional[int] = None) -> MetricCollection:
         return MetricCollection([metric(axis=axis) for metric in self.metrics])
diff --git a/src/gluonts/ev/stats.py b/src/gluonts/ev/stats.py
index 10be26aeda..250fc1bbe9 100644
--- a/src/gluonts/ev/stats.py
+++ b/src/gluonts/ev/stats.py
@@ -18,6 +18,7 @@
 
 def num_masked_target_values(data: Dict[str, np.ndarray]) -> np.ndarray:
     if np.ma.isMaskedArray(data["label"]):
+        assert isinstance(data["label"], np.ma.MaskedArray)
         return data["label"].mask.astype(float)
     else:
         return np.zeros(data["label"].shape)
diff --git a/src/gluonts/evaluation/.typesafe b/src/gluonts/evaluation/.typesafe
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/src/gluonts/exceptions.py b/src/gluonts/exceptions.py
index 406755a16c..2ff046bfd6 100644
--- a/src/gluonts/exceptions.py
+++ b/src/gluonts/exceptions.py
@@ -13,7 +13,7 @@
 
 from typing import Any
 
-from pydantic.error_wrappers import ValidationError, display_errors
+from gluonts.pydantic import ValidationError, display_errors
 
 
 class GluonTSException(Exception):
diff --git a/src/gluonts/ext/naive_2/_predictor.py b/src/gluonts/ext/naive_2/_predictor.py
index 7320730a6a..ca4937b091 100644
--- a/src/gluonts/ext/naive_2/_predictor.py
+++ b/src/gluonts/ext/naive_2/_predictor.py
@@ -21,7 +21,7 @@
 from gluonts.model.predictor import RepresentablePredictor
 
 
-def seasonality_test(past_ts_data: np.array, season_length: int) -> bool:
+def seasonality_test(past_ts_data: np.ndarray, season_length: int) -> bool:
     """
     Test the time series for seasonal patterns by performing a 90% auto-
    correlation test.
diff --git a/src/gluonts/ext/prophet/.typesafe b/src/gluonts/ext/prophet/.typesafe
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/src/gluonts/ext/r_forecast/_hierarchical_predictor.py b/src/gluonts/ext/r_forecast/_hierarchical_predictor.py
index f1415896ed..63a0374443 100644
--- a/src/gluonts/ext/r_forecast/_hierarchical_predictor.py
+++ b/src/gluonts/ext/r_forecast/_hierarchical_predictor.py
@@ -32,7 +32,7 @@
     "erm",
 ]
 
-HIERARCHICAL_SAMPLE_FORECAST_METHODS = []  # TODO: Add `depbu_mint`.
+HIERARCHICAL_SAMPLE_FORECAST_METHODS: List[str] = []  # TODO: Add `depbu_mint`.
 
 SUPPORTED_HIERARCHICAL_METHODS = (
     HIERARCHICAL_POINT_FORECAST_METHODS + HIERARCHICAL_SAMPLE_FORECAST_METHODS
@@ -200,6 +200,8 @@ def _get_r_forecast(self, data: Dict) -> Dict:
         else:
             hier_ts = self._hts_pkg.gts(y_bottom_ts, groups=nodes)
 
+        assert isinstance(self.params["num_samples"], int)
+
         forecast = self._r_method(hier_ts, r_params)
 
         all_forecasts = list(forecast)
@@ -244,7 +246,7 @@ def _forecast_dict_to_obj(
         forecast_dict: Dict,
         forecast_start_date: pd.Timestamp,
         item_id: Optional[str],
-        info: Dict,
+        info: Optional[Dict],
     ) -> SampleForecast:
         samples = np.array(forecast_dict["samples"])
diff --git a/src/gluonts/ext/r_forecast/_predictor.py b/src/gluonts/ext/r_forecast/_predictor.py
index fe09bfeb05..09b93417d1 100644
--- a/src/gluonts/ext/r_forecast/_predictor.py
+++ b/src/gluonts/ext/r_forecast/_predictor.py
@@ -91,7 +91,7 @@ def __init__(
         self,
         freq: str,
         prediction_length: int,
-        period: int = None,
+        period: Optional[int] = None,
         trunc_length: Optional[int] = None,
         save_info: bool = False,
         r_file_prefix: str = "",
@@ -215,7 +215,7 @@ def _forecast_dict_to_obj(
         forecast_dict: Dict,
         forecast_start_date: pd.Timestamp,
         item_id: Optional[str],
-        info: Dict,
+        info: Optional[Dict],
     ) -> Forecast:
         """
         Returns object of type `gluonts.model.Forecast`.
diff --git a/src/gluonts/ext/r_forecast/_univariate_predictor.py b/src/gluonts/ext/r_forecast/_univariate_predictor.py
index 5416a79f86..c8d13d7478 100644
--- a/src/gluonts/ext/r_forecast/_univariate_predictor.py
+++ b/src/gluonts/ext/r_forecast/_univariate_predictor.py
@@ -11,7 +11,7 @@
 # express or implied. See the License for the specific language governing
 # permissions and limitations under the License.
 
-from typing import Dict, Optional
+from typing import Dict, Optional, Any
 
 import numpy as np
 import pandas as pd
@@ -79,7 +79,7 @@ def __init__(
         freq: str,
         prediction_length: int,
         method_name: str = "ets",
-        period: int = None,
+        period: Optional[int] = None,
         trunc_length: Optional[int] = None,
         save_info: bool = False,
         params: Dict = dict(),
@@ -100,7 +100,7 @@ def __init__(
         self.method_name = method_name
         self._r_method = self._robjects.r[method_name]
 
-        self.params = {
+        self.params: Dict[str, Any] = {
             "prediction_length": self.prediction_length,
             "frequency": self.period,
         }
@@ -185,7 +185,7 @@ def _forecast_dict_to_obj(
         forecast_dict: Dict,
         forecast_start_date: pd.Timestamp,
         item_id: Optional[str],
-        info: Dict,
+        info: Optional[Dict],
     ) -> QuantileForecast:
         stats_dict = {"mean": forecast_dict["mean"]}
diff --git a/src/gluonts/ext/rotbaum/_model.py b/src/gluonts/ext/rotbaum/_model.py
index eb6b345ffa..c2924a8171 100644
--- a/src/gluonts/ext/rotbaum/_model.py
+++ b/src/gluonts/ext/rotbaum/_model.py
@@ -11,7 +11,7 @@
 # express or implied. See the License for the specific language governing
 # permissions and limitations under the License.
 
-from typing import Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union
 
 import copy
 import numpy as np
@@ -105,11 +105,7 @@ def __init__(
         else:
             self.model = self._create_xgboost_model(xgboost_params)
         self.min_bin_size = min_bin_size
-        self.sorted_train_preds = None
-        self.x_train_is_dataframe = None
-        self.id_to_bins = None
-        self.preds_to_id = None
-        self.quantile_dicts = defaultdict(dict)
+        self.quantile_dicts: Dict[Any, dict] = defaultdict(dict)
 
     @staticmethod
     def _create_xgboost_model(model_params: Optional[dict] = None):
@@ -157,6 +153,7 @@ def fit(
         # doens't like lists
         if max_sample_size and x_train_is_dataframe:
             assert max_sample_size > 0
+            assert isinstance(x_train, pd.DataFrame)
             sample_size = min(max_sample_size, len(x_train))
             x_train = x_train.sample(
                 n=min(sample_size, len(x_train)),
@@ -295,7 +292,7 @@ def preprocess_df(self, df: pd.DataFrame, min_bin_size: int = 100) -> Dict:
         return dic
 
     @classmethod
-    def get_closest_pt(cls, sorted_list: List, num: int) -> int:
+    def get_closest_pt(cls, sorted_list: List[float], num: float) -> float:
         """
         Given a sorted list of floats, returns the number closest to num.
@@ -365,7 +362,7 @@ def predict(
             preds = self.model.predict(x_test)
             predicted_values = [
                 self._get_and_cache_quantile_computation(
-                    self.get_closest_pt(self.sorted_train_preds, pred),
+                    self.get_closest_pt(self.sorted_train_preds, pred),  # type: ignore
                     quantile,
                 )
                 for pred in preds
@@ -381,7 +378,7 @@ def predict(
                 )
                 predicted_values.append(
                     self._get_and_cache_quantile_computation(
-                        closest_pred, quantile
+                        closest_pred, quantile  # type: ignore
                     )
                 )
         return predicted_values
diff --git a/src/gluonts/ext/rotbaum/_predictor.py b/src/gluonts/ext/rotbaum/_predictor.py
index 860a7e564a..a34f49e350 100644
--- a/src/gluonts/ext/rotbaum/_predictor.py
+++ b/src/gluonts/ext/rotbaum/_predictor.py
@@ -14,7 +14,7 @@
 import concurrent.futures
 import logging
 from itertools import chain
-from typing import Iterator, List, Optional
+from typing import Iterator, List, Optional, Any, Dict
 
 from toolz import first
 import numpy as np
@@ -58,7 +58,7 @@ def __init__(
         self.item_id = None
         self.lead_time = None
 
-    def quantile(self, q: float) -> np.ndarray:
+    def quantile(self, q: float) -> np.ndarray:  # type: ignore
         """
         Returns np.array, where the i^th entry is the estimate of the q
         quantile of the conditional distribution of the value of the i^th step
@@ -114,7 +114,7 @@ def __init__(
         use_past_feat_dynamic_real: bool = False,
         use_feat_dynamic_real: bool = False,
         use_feat_dynamic_cat: bool = False,
-        cardinality: Cardinality = "auto",
+        cardinality: Cardinality = "auto",  # type: ignore
         one_hot_encode: bool = False,
         model_params: Optional[dict] = None,
         max_workers: Optional[int] = None,
@@ -177,7 +177,7 @@ def __init__(
         self.min_bin_size = min_bin_size
         self.quantiles = quantiles
         self.model = model
-        self.model_list = None
+        self.model_list: Optional[list] = None
 
         logger.info(
             "If using the Evaluator class with a TreePredictor, set"
@@ -208,10 +208,10 @@ def train(
         self.preprocess_object.preprocess_from_list(
             ts_list=list(training_data), change_internal_variables=True
         )
-        feature_data, target_data = (
-            self.preprocess_object.feature_data,
-            self.preprocess_object.target_data,
-        )
+
+        feature_data = self.preprocess_object.feature_data
+        target_data = np.array(self.preprocess_object.target_data)
+
         n_models = self.prediction_length
         logging.info(f"Length of forecast horizon: {n_models}")
         if self.method == "QuantileRegression":
@@ -232,6 +232,10 @@ def train(
                 )
                 for _ in range(n_models)
             ]
+
+        assert self.model_list is not None
+        assert feature_data is not None
+
         if train_QRX_only_using_timestep != -1:
             assert (
                 0
@@ -244,7 +248,7 @@ def train(
             )
             self.model_list[train_QRX_only_using_timestep].fit(
                 feature_data,
-                np.array(target_data)[:, train_QRX_only_using_timestep],
+                target_data[:, train_QRX_only_using_timestep],
             )
             self.model_list = [
                 QRX(
@@ -256,7 +260,6 @@ def train(
                 else self.model_list[i]
                 for i in range(n_models)
             ]
-            target_data = np.array(target_data)
         with concurrent.futures.ThreadPoolExecutor(
             max_workers=self.max_workers
         ) as executor:
@@ -302,7 +305,7 @@ def train(
         )
         return self
 
-    def predict(
+    def predict(  # type: ignore
         self, dataset: Dataset, num_samples: Optional[int] = None
     ) -> Iterator[Forecast]:
         """
@@ -372,6 +375,9 @@ def explain(
                 "supported for "
                 "QuantileRegression"
             )
+
+        assert self.model_list is not None
+
         importances = np.array(
             [
                 [
@@ -399,7 +405,14 @@ def explain(
         )
         num_feat_dynamic_real = self.preprocess_object.num_feat_dynamic_real
         num_feat_dynamic_cat = self.preprocess_object.num_feat_dynamic_cat
-        coordinate_map = {}
+
+        assert num_feat_static_real is not None
+        assert num_feat_static_cat is not None
+        assert num_past_feat_dynamic_real is not None
+        assert num_feat_dynamic_real is not None
+        assert num_feat_dynamic_cat is not None
+
+        coordinate_map: Dict[str, Any] = {}
         coordinate_map["target"] = (0, dynamic_length)
         coordinate_map["feat_static_real"] = [
             (dynamic_length + i, dynamic_length + i + 1)
diff --git a/src/gluonts/ext/rotbaum/_preprocess.py b/src/gluonts/ext/rotbaum/_preprocess.py
index b8e59e500b..0f39095976 100644
--- a/src/gluonts/ext/rotbaum/_preprocess.py
+++ b/src/gluonts/ext/rotbaum/_preprocess.py
@@ -15,7 +15,7 @@
 import logging
 from enum import Enum
 from itertools import chain, starmap
-from typing import Dict, List, Tuple, Union, Optional
+from typing import Dict, List, Tuple, Union, Optional, Sequence
 
 import numpy as np
 
@@ -85,8 +85,8 @@ def __init__(
         self.max_n_datapts = max_n_datapts
         self.kwargs = kwargs
         self.num_samples = num_samples
-        self.feature_data = None
-        self.target_data = None
+        self.feature_data: Optional[list] = None
+        self.target_data: Optional[list] = None
 
         if seed is not None:
             np.random.seed(seed)
@@ -150,8 +150,10 @@ def preprocess_from_single_ts(self, time_series: Dict) -> Tuple:
         if max_num_context_windows < 1:
             return [[]], [[]]
 
+        assert self.num_samples is not None
+
         if self.num_samples > 0:
-            locations = [
+            locations: Sequence[int] = [
                 np.random.randint(max_num_context_windows)
                 for _ in range(self.num_samples)
             ]
@@ -195,7 +197,7 @@ def infer_feature_characteristics(self, ts):
 
     def preprocess_from_list(
         self, ts_list, change_internal_variables: bool = True
-    ) -> Tuple:
+    ) -> Optional[Tuple]:
         """
         Applies self.preprocess_from_single_ts for each time series in
         ts_list, and collates the results into self.feature_data and
         self.target_data.
@@ -217,10 +219,10 @@ def preprocess_from_list(
         feature_data, target_data = [], []
         self.num_samples = self.get_num_samples(ts_list)
 
-        if isinstance(self.cardinality, str):
+        if isinstance(self.cardinality, str):  # type: ignore
             self.cardinality = (
                 self.infer_cardinalities(ts_list)
-                if self.cardinality == "auto"
+                if self.cardinality == "auto"  # type: ignore
                 else []
             )
 
@@ -243,7 +245,9 @@ def preprocess_from_list(
             )
         )
         if change_internal_variables:
-            self.feature_data, self.target_data = feature_data, target_data
+            self.feature_data = feature_data
+            self.target_data = target_data
+            return None
         else:
             return feature_data, target_data
 
@@ -271,7 +275,7 @@ def get_num_samples(self, ts_list) -> int:
             n_windows_per_time_series = -1
         return n_windows_per_time_series
 
-    def infer_cardinalities(self):
+    def infer_cardinalities(self, ts_list):
         raise NotImplementedError
 
 
@@ -336,7 +340,7 @@ def __init__(
     @classmethod
     def _pre_transform(
         cls, time_series_window, subtract_mean, count_nans
-    ) -> Tuple:
+    ) -> list:
         """
         Makes features given time series window. Returns list of features,
         one for every step of the lag (equaling mean-adjusted lag features);
         and a
diff --git a/src/gluonts/ext/rotbaum/_types.py b/src/gluonts/ext/rotbaum/_types.py
index 1c736a4d1d..d8dbf1d019 100644
--- a/src/gluonts/ext/rotbaum/_types.py
+++ b/src/gluonts/ext/rotbaum/_types.py
@@ -14,7 +14,8 @@
 from typing import List, Union, Optional
 
 import numpy as np
-from pydantic import BaseModel, root_validator
+
+from gluonts.pydantic import BaseModel, root_validator
 
 
 class FeatureImportanceResult(BaseModel):
""" - iterables: Collection[Iterable] + iterables: Collection[SizedIterable] def __iter__(self): yield from itertools.chain.from_iterable(self.iterables) @@ -179,7 +179,7 @@ class Fuse: ``Collection``s. """ - collections: List[Collection] + collections: List[Sequence] _lengths: List[int] = field(default_factory=list) def __post_init__(self): @@ -263,7 +263,7 @@ def __repr__(self): return f"Fuse" -def split(xs: Collection, indices: List[int]) -> List[Collection]: +def split(xs: Sequence, indices: List[int]) -> List[Sequence]: """Split ``xs`` into subsets given ``indices``. >>> split("abcdef", [1, 3]) @@ -282,7 +282,7 @@ def split(xs: Collection, indices: List[int]) -> List[Collection]: ] -def split_into(xs: Collection, n: int) -> Collection: +def split_into(xs: Sequence, n: int) -> Sequence: """Split ``xs`` into ``n`` parts of similar size. >>> split_into("abcd", 2) @@ -726,7 +726,7 @@ def replace(values: Sequence[T], idx: int, value: T) -> Sequence[T]: xs = list(values) xs[idx] = value - return type(values)(xs) + return type(values)(xs) # type: ignore def chop( diff --git a/src/gluonts/meta/_version.py b/src/gluonts/meta/_version.py index 44e409800a..b660e5f627 100644 --- a/src/gluonts/meta/_version.py +++ b/src/gluonts/meta/_version.py @@ -60,9 +60,11 @@ def get_version_and_cmdclass(version_file): ) """ +import os import subprocess from pathlib import Path +FALLBACK_VERSION = os.environ.get("GLUONTS_FALLBACK_VERSION", "0.0.0") GIT_DESCRIBE = [ "git", @@ -250,4 +252,4 @@ def make_release_tree(self, base_dir, files): return {"sdist": sdist, "build_py": build_py} -__version__ = get_version(fallback="0.0.0") +__version__ = get_version(fallback=FALLBACK_VERSION) diff --git a/src/gluonts/model/evaluation.py b/src/gluonts/model/evaluation.py index dbbdd2e2fd..ce537c2d88 100644 --- a/src/gluonts/model/evaluation.py +++ b/src/gluonts/model/evaluation.py @@ -62,7 +62,7 @@ def _get_data_batch( seasonality: Optional[int] = None, mask_invalid_label: bool = True, allow_nan_forecast: bool = False, -) -> dict: +) -> ChainMap: forecast_dict = BatchForecast(forecast, allow_nan=allow_nan_forecast) freq = forecast.start_date.freqstr @@ -85,7 +85,7 @@ def _get_data_batch( ), } - return ChainMap(other_data, forecast_dict) + return ChainMap(other_data, forecast_dict) # type: ignore def evaluate_forecasts_raw( diff --git a/src/gluonts/model/forecast.py b/src/gluonts/model/forecast.py index 91ef5f100c..c69ae385fc 100644 --- a/src/gluonts/model/forecast.py +++ b/src/gluonts/model/forecast.py @@ -18,16 +18,17 @@ import numpy as np import pandas as pd -from pydantic.dataclasses import dataclass from gluonts.core.component import validated +from gluonts.pydantic import dataclass from gluonts import maybe + logger = logging.getLogger(__name__) def _linear_interpolation( - xs: np.ndarray, ys: np.ndarray, x: float + xs: List[float], ys: List[np.ndarray], x: float ) -> np.ndarray: assert sorted(xs) == xs assert len(xs) == len(ys) @@ -247,9 +248,12 @@ class Forecast: item_id: Optional[str] info: Optional[Dict] prediction_length: int - mean: np.ndarray _index = None + @property + def mean(self) -> np.ndarray: + raise NotImplementedError() + def quantile(self, q: Union[float, str]) -> np.ndarray: """ Compute a quantile from the predicted distribution. 
@@ -433,7 +437,7 @@ def __init__( self.samples = samples self._sorted_samples_value = None self._mean = None - self._dim = None + self._dim: Optional[int] = None self.item_id = item_id self.info = info @@ -469,6 +473,7 @@ def mean(self) -> np.ndarray: """ if self._mean is None: self._mean = np.mean(self.samples, axis=0) + assert self._mean is not None return self._mean @property @@ -540,7 +545,7 @@ def to_quantile_forecast(self, quantiles: List[str]) -> "QuantileForecast": return QuantileForecast( forecast_arrays=np.array( [ - self.quantile(q) if q != "mean" else self.mean() + self.quantile(q) if q != "mean" else self.mean for q in quantiles ] ), @@ -593,7 +598,7 @@ def __init__( ] self.item_id = item_id self.info = info - self._dim = None + self._dim: Optional[int] = None shape = self.forecast_array.shape assert shape[0] == len(self.forecast_keys), ( @@ -634,7 +639,9 @@ def quantile(self, inference_quantile: Union[float, str]) -> np.ndarray: return exp_tail_approximation.right(inference_quantile) else: return _linear_interpolation( - quantiles, quantile_predictions, inference_quantile + quantiles, + quantile_predictions, + inference_quantile, ) def copy_dim(self, dim: int) -> "QuantileForecast": diff --git a/src/gluonts/model/forecast_generator.py b/src/gluonts/model/forecast_generator.py index 32e7d23e56..c4a6f3e38f 100644 --- a/src/gluonts/model/forecast_generator.py +++ b/src/gluonts/model/forecast_generator.py @@ -165,10 +165,12 @@ def __call__( outputs = output_transform(batch, outputs) collected_samples.append(outputs) num_collected_samples += outputs[0].shape[0] - outputs = [ - np.concatenate(s)[:num_samples] - for s in zip(*collected_samples) - ] + outputs = np.stack( + [ + np.concatenate(s)[:num_samples] + for s in zip(*collected_samples) + ] + ) assert len(outputs[0]) == num_samples i = -1 for i, output in enumerate(outputs): diff --git a/src/gluonts/model/npts/.typesafe b/src/gluonts/model/npts/.typesafe deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/src/gluonts/model/predictor.py b/src/gluonts/model/predictor.py index 0d1ac5e3d1..1b64124574 100644 --- a/src/gluonts/model/predictor.py +++ b/src/gluonts/model/predictor.py @@ -20,7 +20,7 @@ from pathlib import Path from pydoc import locate from tempfile import TemporaryDirectory -from typing import TYPE_CHECKING, Callable, Iterator, Optional +from typing import TYPE_CHECKING, Callable, Iterator, Optional, Dict, Any import numpy as np @@ -116,6 +116,7 @@ def deserialize(cls, path: Path, **kwargs) -> "Predictor": tpe = locate(tpe_str) assert tpe is not None, f"Cannot locate {tpe_str}." 
+ assert isinstance(tpe, type) # ensure that predictor_cls is a subtype of Predictor if not issubclass(tpe, Predictor): @@ -179,7 +180,7 @@ def serialize(self, path: Path) -> None: print(dump_json(self), file=fp) @classmethod - def deserialize(cls, path: Path) -> "RepresentablePredictor": + def deserialize(cls, path: Path) -> "RepresentablePredictor": # type: ignore with (path / "predictor.json").open("r") as fp: return load_json(fp.read()) @@ -257,8 +258,8 @@ def __init__( ) self._chunk_size = chunk_size self._num_running_workers = 0 - self._input_queues = [] - self._output_queue = None + self._input_queues: list = [] + self._output_queue: Optional[mp.Queue] = None def _grouper(self, iterable, n): iterator = iter(iterable) @@ -305,7 +306,7 @@ def predict(self, dataset: Dataset, **kwargs) -> Iterator[Forecast]: self._send_idx = 0 self._next_idx = 0 - self._data_buffer = {} + self._data_buffer: Dict[int, Any] = {} worker_ids = list(range(self._num_workers)) diff --git a/src/gluonts/model/seasonal_naive/.typesafe b/src/gluonts/model/seasonal_naive/.typesafe deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/src/gluonts/model/trivial/.typesafe b/src/gluonts/model/trivial/.typesafe deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/src/gluonts/model/trivial/mean.py b/src/gluonts/model/trivial/mean.py index ae53b51c72..ff9ce17b2c 100644 --- a/src/gluonts/model/trivial/mean.py +++ b/src/gluonts/model/trivial/mean.py @@ -14,7 +14,6 @@ from typing import Optional import numpy as np -from pydantic import PositiveInt from gluonts.core.component import validated from gluonts.dataset.common import DataEntry, Dataset @@ -24,6 +23,7 @@ from gluonts.model.forecast import SampleForecast from gluonts.model.predictor import RepresentablePredictor from gluonts.model.trivial.constant import ConstantPredictor +from gluonts.pydantic import PositiveInt class MeanPredictor(RepresentablePredictor): diff --git a/src/gluonts/mx/batchify.py b/src/gluonts/mx/batchify.py index 0d3807828e..ea1de2adb0 100644 --- a/src/gluonts/mx/batchify.py +++ b/src/gluonts/mx/batchify.py @@ -21,7 +21,7 @@ def pad_to_size( - x: np.array, size: int, axis: int = 0, is_right_pad: bool = True + x: np.ndarray, size: int, axis: int = 0, is_right_pad: bool = True ): """ Pads `xs` with 0 on the right (default) on the specified axis, which is the diff --git a/src/gluonts/mx/block/.typesafe b/src/gluonts/mx/block/.typesafe deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/src/gluonts/mx/distribution/.typesafe b/src/gluonts/mx/distribution/.typesafe deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/src/gluonts/mx/kernels/.typesafe b/src/gluonts/mx/kernels/.typesafe deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/src/gluonts/mx/model/canonical/_network.py b/src/gluonts/mx/model/canonical/_network.py index eadb1be764..ae0ce01956 100644 --- a/src/gluonts/mx/model/canonical/_network.py +++ b/src/gluonts/mx/model/canonical/_network.py @@ -70,7 +70,7 @@ def hybrid_forward(self, F, x, *args, **kwargs): class CanonicalTrainingNetwork(CanonicalNetworkBase): - def hybrid_forward( + def hybrid_forward( # type: ignore self, F, feat_static_cat: Tensor, # (batch_size, num_features) @@ -121,7 +121,7 @@ def __init__( self.prediction_len = prediction_len self.num_parallel_samples = num_parallel_samples - def hybrid_forward( + def hybrid_forward( # type: ignore self, F, feat_static_cat: Tensor, diff --git a/src/gluonts/mx/model/deep_factor/.typesafe 
b/src/gluonts/mx/model/deep_factor/.typesafe deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/src/gluonts/mx/model/deepar/.typesafe b/src/gluonts/mx/model/deepar/.typesafe deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/src/gluonts/mx/model/deepstate/.typesafe b/src/gluonts/mx/model/deepstate/.typesafe deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/src/gluonts/mx/model/deepvar/.typesafe b/src/gluonts/mx/model/deepvar/.typesafe deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/src/gluonts/mx/model/deepvar_hierarchical/.typesafe b/src/gluonts/mx/model/deepvar_hierarchical/.typesafe deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/src/gluonts/mx/model/estimator.py b/src/gluonts/mx/model/estimator.py index 65eff45666..1cf992a796 100644 --- a/src/gluonts/mx/model/estimator.py +++ b/src/gluonts/mx/model/estimator.py @@ -11,7 +11,7 @@ # express or implied. See the License for the specific language governing # permissions and limitations under the License. -from typing import NamedTuple, Optional, Type +from typing import NamedTuple, Optional, Type, Union import numpy as np @@ -29,9 +29,9 @@ from gluonts.mx.model.predictor import GluonPredictor from gluonts.mx.trainer import Trainer from gluonts.mx.util import copy_parameters -from gluonts.transform import Transformation +from gluonts.pydantic import ValidationError +from gluonts.transform import Transformation, TransformedDataset from mxnet.gluon import HybridBlock -from pydantic import ValidationError class TrainOutput(NamedTuple): @@ -177,7 +177,9 @@ def train_model( transformation = self.create_transformation() with env._let(max_idle_transforms=max(len(training_data), 100)): - transformed_training_data = transformation.apply(training_data) + transformed_training_data: Union[ + TransformedDataset, Cached + ] = transformation.apply(training_data) if cache_data: transformed_training_data = Cached(transformed_training_data) @@ -190,9 +192,9 @@ def train_model( if validation_data is not None: with env._let(max_idle_transforms=max(len(validation_data), 100)): - transformed_validation_data = transformation.apply( - validation_data - ) + transformed_validation_data: Union[ + TransformedDataset, Cached + ] = transformation.apply(validation_data) if cache_data: transformed_validation_data = Cached( transformed_validation_data @@ -243,7 +245,7 @@ def train( def train_from( self, - predictor: Predictor, + predictor: GluonPredictor, training_data: Dataset, validation_data: Optional[Dataset] = None, shuffle_buffer_length: Optional[int] = None, diff --git a/src/gluonts/mx/model/forecast.py b/src/gluonts/mx/model/forecast.py index d6bcafa3ca..1422831af5 100644 --- a/src/gluonts/mx/model/forecast.py +++ b/src/gluonts/mx/model/forecast.py @@ -82,6 +82,7 @@ def mean(self) -> np.ndarray: return self._mean else: self._mean = self.distribution.mean.asnumpy() + assert isinstance(self._mean, np.ndarray) return self._mean @property @@ -107,7 +108,7 @@ def to_sample_forecast(self, num_samples: int = 200) -> SampleForecast: def to_quantile_forecast(self, quantiles: List[Union[float, str]]): return QuantileForecast( forecast_arrays=np.array([self.quantile(q) for q in quantiles]), - forecast_keys=quantiles, + forecast_keys=[str(Quantile.parse(level)) for level in quantiles], start_date=self.start_date, item_id=self.item_id, info=self.info, diff --git a/src/gluonts/mx/model/gp_forecaster/.typesafe b/src/gluonts/mx/model/gp_forecaster/.typesafe deleted file mode 100644 index 
e69de29bb2..0000000000 diff --git a/src/gluonts/mx/model/gp_forecaster/gaussian_process.py b/src/gluonts/mx/model/gp_forecaster/gaussian_process.py index 13737e4797..50d5d1821d 100644 --- a/src/gluonts/mx/model/gp_forecaster/gaussian_process.py +++ b/src/gluonts/mx/model/gp_forecaster/gaussian_process.py @@ -11,7 +11,7 @@ # express or implied. See the License for the specific language governing # permissions and limitations under the License. -from typing import List, Optional, Tuple, Type +from typing import Optional, Tuple, Type import mxnet as mx import numpy as np @@ -335,7 +335,7 @@ def plot( mean: Optional[Tensor] = None, std: Optional[Tensor] = None, samples: Optional[Tensor] = None, - axis: Optional[List] = None, + axis: Optional[Tuple[float, float, float, float]] = None, ) -> None: """ This method plots the sampled GP distribution at the test points in diff --git a/src/gluonts/mx/model/gpvar/.typesafe b/src/gluonts/mx/model/gpvar/.typesafe deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/src/gluonts/mx/model/lstnet/_network.py b/src/gluonts/mx/model/lstnet/_network.py index 2ffd18c73b..1f8e1b8225 100644 --- a/src/gluonts/mx/model/lstnet/_network.py +++ b/src/gluonts/mx/model/lstnet/_network.py @@ -283,7 +283,7 @@ def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.loss_fn = loss.L1Loss() - def hybrid_forward( + def hybrid_forward( # type: ignore self, F, past_target: Tensor, diff --git a/src/gluonts/mx/model/n_beats/.typesafe b/src/gluonts/mx/model/n_beats/.typesafe deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/src/gluonts/mx/model/n_beats/_ensemble.py b/src/gluonts/mx/model/n_beats/_ensemble.py index 6ccfe19f5f..f1f8f8eb81 100644 --- a/src/gluonts/mx/model/n_beats/_ensemble.py +++ b/src/gluonts/mx/model/n_beats/_ensemble.py @@ -20,7 +20,6 @@ import mxnet as mx import numpy as np -from pydantic import ValidationError from gluonts.core import fqname_for from gluonts.core.component import from_hyperparameters, validated @@ -33,6 +32,7 @@ from gluonts.model.forecast import Forecast, SampleForecast from gluonts.mx.model.predictor import RepresentableBlockPredictor from gluonts.mx.trainer import Trainer +from gluonts.pydantic import ValidationError from ._estimator import NBEATSEstimator from ._network import VALID_LOSS_FUNCTIONS diff --git a/src/gluonts/mx/model/predictor.py b/src/gluonts/mx/model/predictor.py index 1db9d5db5b..4c2b95da1e 100644 --- a/src/gluonts/mx/model/predictor.py +++ b/src/gluonts/mx/model/predictor.py @@ -77,7 +77,7 @@ class GluonPredictor(Predictor): def __init__( self, input_names: List[str], - prediction_net: BlockType, + prediction_net, batch_size: int, prediction_length: int, ctx: mx.Context, @@ -102,7 +102,7 @@ def __init__( self.dtype = dtype @property - def network(self) -> BlockType: + def network(self): return self.prediction_net def hybridize(self, batch: DataBatch) -> None: @@ -238,7 +238,7 @@ def serialize_prediction_net(self, path: Path) -> None: export_symb_block(self.prediction_net, path, "prediction_net") @classmethod - def deserialize( + def deserialize( # type: ignore cls, path: Path, ctx: Optional[mx.Context] = None ) -> "SymbolBlockPredictor": ctx = ctx if ctx is not None else get_mxnet_context() @@ -288,7 +288,7 @@ class RepresentableBlockPredictor(GluonPredictor): def __init__( self, - prediction_net: BlockType, + prediction_net, batch_size: int, prediction_length: int, ctx: mx.Context, @@ -319,6 +319,7 @@ def as_symbol_block_predictor( dataset: Optional[Dataset] = 
None, ) -> SymbolBlockPredictor: if batch is None: + assert dataset is not None data_loader = InferenceDataLoader( dataset, transform=self.input_transform, @@ -358,7 +359,7 @@ def serialize_prediction_net(self, path: Path) -> None: export_repr_block(self.prediction_net, path, "prediction_net") @classmethod - def deserialize( + def deserialize( # type: ignore cls, path: Path, ctx: Optional[mx.Context] = None ) -> "RepresentableBlockPredictor": ctx = ctx if ctx is not None else get_mxnet_context() diff --git a/src/gluonts/mx/model/renewal/_predictor.py b/src/gluonts/mx/model/renewal/_predictor.py index 5b6930e28f..e5b90b4ffe 100644 --- a/src/gluonts/mx/model/renewal/_predictor.py +++ b/src/gluonts/mx/model/renewal/_predictor.py @@ -65,7 +65,7 @@ class DeepRenewalProcessPredictor(RepresentableBlockPredictor): def __init__( self, - prediction_net: BlockType, + prediction_net, batch_size: int, prediction_length: int, ctx: mx.Context, @@ -92,7 +92,7 @@ def __init__( if input_names is not None: self.input_names = input_names - def predict( + def predict( # type: ignore self, dataset: Dataset, num_samples: Optional[int] = None, @@ -121,7 +121,7 @@ def predict( ) @classmethod - def deserialize( + def deserialize( # type: ignore cls, path: Path, ctx: Optional[mx.Context] = None ) -> "DeepRenewalProcessPredictor": repr_predictor = super().deserialize(path, ctx) diff --git a/src/gluonts/mx/model/seq2seq/.typesafe b/src/gluonts/mx/model/seq2seq/.typesafe deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/src/gluonts/mx/model/simple_feedforward/.typesafe b/src/gluonts/mx/model/simple_feedforward/.typesafe deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/src/gluonts/mx/model/tft/.typesafe b/src/gluonts/mx/model/tft/.typesafe deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/src/gluonts/mx/model/tpp/.typesafe b/src/gluonts/mx/model/tpp/.typesafe deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/src/gluonts/mx/model/transformer/.typesafe b/src/gluonts/mx/model/transformer/.typesafe deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/src/gluonts/mx/model/wavenet/.typesafe b/src/gluonts/mx/model/wavenet/.typesafe deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/src/gluonts/mx/prelude.py b/src/gluonts/mx/prelude.py index e1820dba56..1bad5cf011 100644 --- a/src/gluonts/mx/prelude.py +++ b/src/gluonts/mx/prelude.py @@ -13,8 +13,10 @@ # flake8: noqa: F401, F403 +from typing import List + from .component import * from .serde import * from .model.forecast_generator import * -__all__ = [] +__all__: List[str] = [] diff --git a/src/gluonts/mx/representation/.typesafe b/src/gluonts/mx/representation/.typesafe deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/src/gluonts/mx/trainer/.typesafe b/src/gluonts/mx/trainer/.typesafe deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/src/gluonts/mx/trainer/callback.py b/src/gluonts/mx/trainer/callback.py index 2524d9a9e1..002207879d 100644 --- a/src/gluonts/mx/trainer/callback.py +++ b/src/gluonts/mx/trainer/callback.py @@ -21,11 +21,11 @@ import mxnet.gluon.nn as nn import mxnet as mx from mxnet import gluon -from pydantic import BaseModel, PrivateAttr # First-party imports from gluonts.core.component import validated from gluonts.mx.util import copy_parameters +from gluonts.pydantic import BaseModel, PrivateAttr logger = logging.getLogger(__name__) diff --git a/src/gluonts/mx/trainer/learning_rate_scheduler.py 
b/src/gluonts/mx/trainer/learning_rate_scheduler.py index f7d6a695fc..91ecd43aa8 100644 --- a/src/gluonts/mx/trainer/learning_rate_scheduler.py +++ b/src/gluonts/mx/trainer/learning_rate_scheduler.py @@ -14,7 +14,6 @@ from dataclasses import field from typing import Dict, Any, Optional -from pydantic.dataclasses import dataclass from typing_extensions import Literal import numpy as np @@ -23,6 +22,7 @@ import mxnet.gluon.nn as nn from gluonts.core.component import validated +from gluonts.pydantic import dataclass from .callback import Callback diff --git a/src/gluonts/core/.typesafe b/src/gluonts/nursery/.typeunsafe similarity index 100% rename from src/gluonts/core/.typesafe rename to src/gluonts/nursery/.typeunsafe diff --git a/src/gluonts/nursery/anomaly_detection/supervised_metrics/_precision_recall_utils.py b/src/gluonts/nursery/anomaly_detection/supervised_metrics/_precision_recall_utils.py index f133f77dd5..c6ebbb3d31 100644 --- a/src/gluonts/nursery/anomaly_detection/supervised_metrics/_precision_recall_utils.py +++ b/src/gluonts/nursery/anomaly_detection/supervised_metrics/_precision_recall_utils.py @@ -21,10 +21,10 @@ class PrecisionRecallAndWeights(NamedTuple): - precisions: np.array - recalls: np.array - precision_weights: np.array - recall_weights: np.array + precisions: np.ndarray + recalls: np.ndarray + precision_weights: np.ndarray + recall_weights: np.ndarray def singleton_precision_recall( @@ -63,7 +63,7 @@ def singleton_precision_recall( def precision_recall_curve_per_ts( labels: List[bool], scores: List[float], - thresholds: np.array, + thresholds: np.ndarray, partial_filter: Optional[Callable] = None, singleton_curve: bool = False, precision_recall_fn: Callable = buffered_precision_recall, @@ -126,7 +126,7 @@ def aggregate_precision_recall_curve( label_score_iterable: Iterable An iterable that gives 2-tuples of np.arrays (of identical length), corresponding to `true_labels` and `pred_scores` respectively. - thresholds: np.array + thresholds: np.ndarray An np.array of score thresholds for which to compute precision recall values. If the filter_type argument is provided, these are the threshold values of the filter. If not, they will be applied as a single step hard threshold to diff --git a/src/gluonts/nursery/anomaly_detection/supervised_metrics/bounded_pr_auc.py b/src/gluonts/nursery/anomaly_detection/supervised_metrics/bounded_pr_auc.py index f3acb03633..928396ed59 100644 --- a/src/gluonts/nursery/anomaly_detection/supervised_metrics/bounded_pr_auc.py +++ b/src/gluonts/nursery/anomaly_detection/supervised_metrics/bounded_pr_auc.py @@ -16,16 +16,16 @@ def bounded_pr_auc( - precisions: np.array, recalls: np.array, lower_bound: float = 0 + precisions: np.ndarray, recalls: np.ndarray, lower_bound: float = 0 ) -> float: """ Bounded PR AUC --> AUC when recall > lower_bound. Parameters ---------- - precisions : np.array + precisions : np.ndarray precisions of different thresholds - recalls : np.array + recalls : np.ndarray recalls of different thresholds lower_bound : float lower bound of recalls diff --git a/src/gluonts/pydantic.py b/src/gluonts/pydantic.py new file mode 100644 index 0000000000..99269afea7 --- /dev/null +++ b/src/gluonts/pydantic.py @@ -0,0 +1,74 @@ +# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. 
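For reference, the `GLUONTS_FALLBACK_VERSION` change to `src/gluonts/meta/_version.py` earlier in this patch lets VCS installs (source trees without git metadata, where `git describe` cannot run) pick a version via an environment variable. A hedged usage sketch, with a hypothetical version value:

    import os

    # Hypothetical: set before installing from a VCS archive without git
    # history, so version detection falls back to this value.
    os.environ["GLUONTS_FALLBACK_VERSION"] = "0.14.0"

    # Mirrors the patched logic in src/gluonts/meta/_version.py:
    FALLBACK_VERSION = os.environ.get("GLUONTS_FALLBACK_VERSION", "0.0.0")
    print(FALLBACK_VERSION)  # -> "0.14.0"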
diff --git a/src/gluonts/pydantic.py b/src/gluonts/pydantic.py
new file mode 100644
index 0000000000..99269afea7
--- /dev/null
+++ b/src/gluonts/pydantic.py
@@ -0,0 +1,74 @@
+# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License").
+# You may not use this file except in compliance with the License.
+# A copy of the License is located at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# or in the "license" file accompanying this file. This file is distributed
+# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
+# express or implied. See the License for the specific language governing
+# permissions and limitations under the License.
+
+
+"""This modules contains pydantic imports, which are used throughout the codebase."""
+
+from pydantic import __version__
+
+PYDANTIC_V2 = __version__.startswith("2")
+
+if PYDANTIC_V2:
+    import pydantic.v1 as pydantic
+    from pydantic.v1 import (
+        BaseConfig,
+        BaseModel,
+        create_model,
+        root_validator,
+        PositiveInt,
+        PrivateAttr,
+        Field,
+        parse_obj_as,
+        PositiveFloat,
+        BaseSettings,
+    )
+    from pydantic.v1.error_wrappers import ValidationError, display_errors
+    from pydantic.v1.utils import deep_update
+    from pydantic.v1.dataclasses import dataclass
+else:
+    import pydantic  # type: ignore[no-redef]
+    from pydantic import (  # type: ignore[no-redef, assignment]
+        BaseConfig,
+        BaseModel,
+        create_model,
+        root_validator,
+        PositiveInt,
+        PrivateAttr,
+        Field,
+        parse_obj_as,
+        PositiveFloat,
+        BaseSettings,
+    )
+    from pydantic.error_wrappers import ValidationError, display_errors  # type: ignore[no-redef]
+    from pydantic.utils import deep_update  # type: ignore[no-redef]
+    from pydantic.dataclasses import dataclass  # type: ignore[no-redef]
+
+
+__all__ = [
+    "BaseConfig",
+    "BaseModel",
+    "BaseSettings",
+    "Field",
+    "PositiveFloat",
+    "PositiveInt",
+    "PrivateAttr",
+    "ValidationError",
+    "__version__",
+    "create_model",
+    "dataclass",
+    "deep_update",
+    "display_errors",
+    "parse_obj_as",
+    "pydantic",
+    "root_validator",
+]
diff --git a/src/gluonts/shell/.typesafe b/src/gluonts/shell/.typesafe
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/src/gluonts/shell/sagemaker/train.py b/src/gluonts/shell/sagemaker/train.py
index af7ddad673..8f9f0eec87 100644
--- a/src/gluonts/shell/sagemaker/train.py
+++ b/src/gluonts/shell/sagemaker/train.py
@@ -16,7 +16,7 @@
 from pathlib import Path
 from typing import Dict, Optional, Tuple
 
-from pydantic import BaseModel
+from gluonts.pydantic import BaseModel
 
 from .dyn import install_and_restart
 from .params import decode_sagemaker_parameters
diff --git a/src/gluonts/shell/serve/__init__.py b/src/gluonts/shell/serve/__init__.py
index fc8875281f..ac1951ab24 100644
--- a/src/gluonts/shell/serve/__init__.py
+++ b/src/gluonts/shell/serve/__init__.py
@@ -17,13 +17,13 @@
 from typing import List, Optional, Type, Union
 
 from flask import Flask
-from pydantic import BaseSettings
 
 import gluonts
 from gluonts.core import fqname_for
 from gluonts.model.estimator import Estimator
 from gluonts.model.predictor import Predictor
 from gluonts.shell.env import ServeEnv
+from gluonts.pydantic import BaseSettings
 
 from .app import make_app
diff --git a/src/gluonts/shell/serve/app.py b/src/gluonts/shell/serve/app.py
index 66fa38d0d3..f0df9dbe1d 100644
--- a/src/gluonts/shell/serve/app.py
+++ b/src/gluonts/shell/serve/app.py
@@ -22,14 +22,13 @@
 from typing_extensions import Literal
 
 from flask import Flask, Response, jsonify, request
-from pydantic import BaseModel, Field
 
 from gluonts.dataset.common import ListDataset
 from gluonts.dataset.jsonl import encode_json
 from gluonts.model.forecast import Forecast, Quantile
+from gluonts.pydantic import BaseModel, Field
 from gluonts.shell.util import forecaster_type_by_name
 
-
 logger = logging.getLogger("gluonts.serve")
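To illustrate the compatibility shim added above in `src/gluonts/pydantic.py`: modules now import pydantic names from one place, and the shim resolves them to `pydantic.v1` under Pydantic v2 and to plain `pydantic` under v1, so the v1 API is available either way. A minimal downstream-usage sketch (the `Dummy` model is hypothetical):

    # Works unchanged under pydantic 1.x and 2.x, because the shim
    # re-exports the v1 API in both cases.
    from gluonts.pydantic import BaseModel, PositiveInt

    class Dummy(BaseModel):  # hypothetical example model
        prediction_length: PositiveInt

    print(Dummy(prediction_length=24).prediction_length)  # -> 24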
a/src/gluonts/testutil/.typesafe b/src/gluonts/testutil/.typesafe deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/src/gluonts/time_feature/.typesafe b/src/gluonts/time_feature/.typesafe deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/src/gluonts/time_feature/_base.py b/src/gluonts/time_feature/_base.py index 1bfc97f176..ab6ab20935 100644 --- a/src/gluonts/time_feature/_base.py +++ b/src/gluonts/time_feature/_base.py @@ -17,7 +17,8 @@ import pandas as pd from pandas.tseries import offsets from pandas.tseries.frequencies import to_offset -from pydantic import BaseModel + +from gluonts.pydantic import BaseModel TimeFeature = Callable[[pd.PeriodIndex], np.ndarray] diff --git a/src/gluonts/torch/batchify.py b/src/gluonts/torch/batchify.py index 3ec1ba15fc..04f2010b7f 100644 --- a/src/gluonts/torch/batchify.py +++ b/src/gluonts/torch/batchify.py @@ -11,7 +11,7 @@ # express or implied. See the License for the specific language governing # permissions and limitations under the License. -from typing import List, Optional +from typing import List import numpy as np import torch @@ -19,7 +19,7 @@ from gluonts.dataset.common import DataBatch -def stack(data, device: Optional[torch.device] = None): +def stack(data, device: torch.types.Device = None): if isinstance(data[0], np.ndarray): data = torch.tensor(np.array(data), device=device) elif isinstance(data[0], (list, tuple)): @@ -27,9 +27,7 @@ def stack(data, device: Optional[torch.device] = None): return data -def batchify( - data: List[dict], device: Optional[torch.device] = None -) -> DataBatch: +def batchify(data: List[dict], device: torch.types.Device = None) -> DataBatch: return { key: stack(data=[item[key] for item in data], device=device) for key in data[0].keys() diff --git a/src/gluonts/torch/distributions/__init__.py b/src/gluonts/torch/distributions/__init__.py index e279e023df..375d0abe36 100644 --- a/src/gluonts/torch/distributions/__init__.py +++ b/src/gluonts/torch/distributions/__init__.py @@ -28,7 +28,6 @@ ImplicitQuantileNetworkOutput, ) from .isqf import ISQF, ISQFOutput -from .mqf2 import MQF2Distribution, MQF2DistributionOutput from .negative_binomial import NegativeBinomialOutput from .piecewise_linear import PiecewiseLinear, PiecewiseLinearOutput from .spliced_binned_pareto import ( @@ -53,8 +52,6 @@ "ISQF", "ISQFOutput", "LaplaceOutput", - "MQF2Distribution", - "MQF2DistributionOutput", "NegativeBinomialOutput", "NormalOutput", "PiecewiseLinear", diff --git a/src/gluonts/torch/distributions/binned_uniforms.py b/src/gluonts/torch/distributions/binned_uniforms.py index beba0cddbe..d8672e8028 100644 --- a/src/gluonts/torch/distributions/binned_uniforms.py +++ b/src/gluonts/torch/distributions/binned_uniforms.py @@ -43,9 +43,9 @@ def __init__( self, bins_lower_bound: float, bins_upper_bound: float, - logits: torch.tensor, + logits: torch.Tensor, numb_bins: int = 100, - validate_args: bool = None, + validate_args: Optional[bool] = None, ): assert bins_lower_bound < bins_upper_bound, ( f"bins_lower_bound {bins_lower_bound} needs to less than " @@ -62,7 +62,7 @@ def __init__( super(BinnedUniforms, self).__init__( batch_shape=logits.shape[:-1], - event_shape=logits.shape[-1], + event_shape=logits.shape[-1:], validate_args=validate_args, ) @@ -467,7 +467,7 @@ def __init__( ) @classmethod - def domain_map(cls, logits: torch.Tensor) -> torch.Tensor: + def domain_map(cls, logits: torch.Tensor) -> torch.Tensor: # type: ignore logits = torch.abs(logits) return logits @@ -475,7 +475,7 @@ def 
domain_map(cls, logits: torch.Tensor) -> torch.Tensor: def distribution( self, distr_args, - loc: Optional[torch.Tensor] = 0, + loc: Optional[torch.Tensor] = None, scale: Optional[torch.Tensor] = None, ) -> BinnedUniforms: return self.distr_cls( diff --git a/src/gluonts/torch/distributions/distribution_output.py b/src/gluonts/torch/distributions/distribution_output.py index 3290ab05c8..1380f6bbe6 100644 --- a/src/gluonts/torch/distributions/distribution_output.py +++ b/src/gluonts/torch/distributions/distribution_output.py @@ -180,7 +180,7 @@ class NormalOutput(DistributionOutput): distr_cls: type = Normal @classmethod - def domain_map(cls, loc: torch.Tensor, scale: torch.Tensor): + def domain_map(cls, loc: torch.Tensor, scale: torch.Tensor): # type: ignore scale = F.softplus(scale) return loc.squeeze(-1), scale.squeeze(-1) @@ -194,7 +194,7 @@ class LaplaceOutput(DistributionOutput): distr_cls: type = Laplace @classmethod - def domain_map(cls, loc: torch.Tensor, scale: torch.Tensor): + def domain_map(cls, loc: torch.Tensor, scale: torch.Tensor): # type: ignore scale = F.softplus(scale) return loc.squeeze(-1), scale.squeeze(-1) @@ -208,7 +208,7 @@ class BetaOutput(DistributionOutput): distr_cls: type = Beta @classmethod - def domain_map( + def domain_map( # type: ignore cls, concentration1: torch.Tensor, concentration0: torch.Tensor ): epsilon = np.finfo(cls._dtype).eps # machine epsilon @@ -230,7 +230,7 @@ class GammaOutput(DistributionOutput): distr_cls: type = Gamma @classmethod - def domain_map(cls, concentration: torch.Tensor, rate: torch.Tensor): + def domain_map(cls, concentration: torch.Tensor, rate: torch.Tensor): # type: ignore epsilon = np.finfo(cls._dtype).eps # machine epsilon concentration = F.softplus(concentration) + epsilon rate = F.softplus(rate) + epsilon @@ -250,7 +250,7 @@ class PoissonOutput(DistributionOutput): distr_cls: type = Poisson @classmethod - def domain_map(cls, rate: torch.Tensor): + def domain_map(cls, rate: torch.Tensor): # type: ignore rate_pos = F.softplus(rate).clone() return (rate_pos.squeeze(-1),) diff --git a/src/gluonts/torch/distributions/generalized_pareto.py b/src/gluonts/torch/distributions/generalized_pareto.py index 09caf5b93a..a35a2ba1dc 100644 --- a/src/gluonts/torch/distributions/generalized_pareto.py +++ b/src/gluonts/torch/distributions/generalized_pareto.py @@ -157,7 +157,7 @@ def __init__( ) @classmethod - def domain_map( + def domain_map( # type: ignore cls, xi: torch.Tensor, beta: torch.Tensor, @@ -170,7 +170,7 @@ def domain_map( def distribution( self, distr_args, - loc: Optional[torch.Tensor] = 0, + loc: Optional[torch.Tensor] = None, scale: Optional[torch.Tensor] = None, ) -> GeneralizedPareto: return self.distr_cls( diff --git a/src/gluonts/torch/distributions/implicit_quantile_network.py b/src/gluonts/torch/distributions/implicit_quantile_network.py index 605b228e4e..1c18b4d446 100644 --- a/src/gluonts/torch/distributions/implicit_quantile_network.py +++ b/src/gluonts/torch/distributions/implicit_quantile_network.py @@ -17,7 +17,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from torch.distributions import Distribution, Beta +from torch.distributions import Distribution, Beta, constraints from gluonts.core.component import validated from gluonts.torch.distributions import DistributionOutput @@ -114,7 +114,7 @@ class ImplicitQuantileNetwork(Distribution): corresponding outputs. 
""" - arg_constraints = {} + arg_constraints: Dict[str, constraints.Constraint] = {} def __init__( self, outputs: torch.Tensor, taus: torch.Tensor, validate_args=None diff --git a/src/gluonts/torch/distributions/isqf.py b/src/gluonts/torch/distributions/isqf.py index 9a5bb924e1..ff375b719b 100644 --- a/src/gluonts/torch/distributions/isqf.py +++ b/src/gluonts/torch/distributions/isqf.py @@ -429,7 +429,7 @@ def cdf_spline(self, z: torch.Tensor) -> torch.Tensor: mask_sum_s0_minus = torch.cat( [ mask_sum_s0[..., 1:], - torch.zeros_like(qk_y_expand, dtype=bool), + torch.zeros_like(qk_y_expand, dtype=torch.bool), ], dim=-1, ) @@ -696,7 +696,7 @@ def rsample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor: return sample @property - def batch_shape(self) -> torch.Size(): + def batch_shape(self) -> torch.Size: return self.beta_l.shape @@ -741,7 +741,7 @@ def __init__( } @classmethod - def domain_map( + def domain_map( # type: ignore cls, spline_knots: torch.Tensor, spline_heights: torch.Tensor, @@ -791,7 +791,7 @@ def domain_map( def distribution( self, distr_args, - loc: Optional[torch.Tensor] = 0, + loc: Optional[torch.Tensor] = None, scale: Optional[torch.Tensor] = None, ) -> ISQF: """ diff --git a/src/gluonts/torch/distributions/negative_binomial.py b/src/gluonts/torch/distributions/negative_binomial.py index 4af37c3dec..909b20fb1b 100644 --- a/src/gluonts/torch/distributions/negative_binomial.py +++ b/src/gluonts/torch/distributions/negative_binomial.py @@ -69,7 +69,7 @@ class NegativeBinomialOutput(DistributionOutput): distr_cls: type = NegativeBinomial @classmethod - def domain_map(cls, total_count: torch.Tensor, logits: torch.Tensor): + def domain_map(cls, total_count: torch.Tensor, logits: torch.Tensor): # type: ignore total_count = F.softplus(total_count) return total_count.squeeze(-1), logits.squeeze(-1) diff --git a/src/gluonts/torch/distributions/piecewise_linear.py b/src/gluonts/torch/distributions/piecewise_linear.py index fe1ffa7dc0..cdc3d52758 100644 --- a/src/gluonts/torch/distributions/piecewise_linear.py +++ b/src/gluonts/torch/distributions/piecewise_linear.py @@ -194,7 +194,7 @@ def rsample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor: return sample @property - def batch_shape(self) -> torch.Size(): + def batch_shape(self) -> torch.Size: return self.gamma.shape @@ -216,7 +216,7 @@ def __init__(self, num_pieces: int) -> None: ) @classmethod - def domain_map( + def domain_map( # type: ignore cls, gamma: torch.Tensor, slopes: torch.Tensor, @@ -231,7 +231,7 @@ def domain_map( def distribution( self, distr_args, - loc: Optional[torch.Tensor] = 0, + loc: Optional[torch.Tensor] = None, scale: Optional[torch.Tensor] = None, ) -> PiecewiseLinear: if scale is None: diff --git a/src/gluonts/torch/distributions/spliced_binned_pareto.py b/src/gluonts/torch/distributions/spliced_binned_pareto.py index c636cb336c..5df3b6822c 100644 --- a/src/gluonts/torch/distributions/spliced_binned_pareto.py +++ b/src/gluonts/torch/distributions/spliced_binned_pareto.py @@ -50,11 +50,11 @@ def __init__( self, bins_lower_bound: float, bins_upper_bound: float, - logits: torch.tensor, - upper_gp_xi: torch.tensor, - upper_gp_beta: torch.tensor, - lower_gp_xi: torch.tensor, - lower_gp_beta: torch.tensor, + logits: torch.Tensor, + upper_gp_xi: torch.Tensor, + upper_gp_beta: torch.Tensor, + lower_gp_xi: torch.Tensor, + lower_gp_beta: torch.Tensor, numb_bins: int = 100, tail_percentile_gen_pareto: float = 0.05, validate_args=None, @@ -99,7 +99,7 @@ def __init__( # TODO: # - need 
another implementation of the mean dependent on the tails - def log_prob(self, x: torch.tensor, for_training=True): + def log_prob(self, x: torch.Tensor, for_training=True): """ Arguments ---------- @@ -183,7 +183,7 @@ def pdf(self, x): # one tends to train with the log-prob return torch.exp(self.log_prob(x, for_training=False)) - def _inverse_cdf(self, quantiles: torch.tensor): + def _inverse_cdf(self, quantiles: torch.Tensor): """ Inverse cdf of a tensor of quantile `quantiles` 'quantiles' is of shape (*batch_shape) with values between (0.0, 1.0) @@ -225,7 +225,7 @@ def _inverse_cdf(self, quantiles: torch.tensor): return icdf_value - def cdf(self, x: torch.tensor): + def cdf(self, x: torch.Tensor): """ Cumulative density tensor for a tensor of data points `x`. 'x' is expected to be of shape (*batch_shape) @@ -312,7 +312,7 @@ def __init__( ) @classmethod - def domain_map( + def domain_map( # type: ignore cls, logits: torch.Tensor, upper_gp_xi: torch.Tensor, @@ -334,7 +334,7 @@ def domain_map( def distribution( self, distr_args, - loc: Optional[torch.Tensor] = 0, + loc: Optional[torch.Tensor] = None, scale: Optional[torch.Tensor] = None, ) -> BinnedUniforms: return self.distr_cls( diff --git a/src/gluonts/torch/distributions/studentT.py b/src/gluonts/torch/distributions/studentT.py index 44dae43682..135f64bf92 100644 --- a/src/gluonts/torch/distributions/studentT.py +++ b/src/gluonts/torch/distributions/studentT.py @@ -64,7 +64,7 @@ class StudentTOutput(DistributionOutput): distr_cls: type = StudentT @classmethod - def domain_map( + def domain_map( # type: ignore cls, df: torch.Tensor, loc: torch.Tensor, scale: torch.Tensor ): epsilon = torch.finfo(scale.dtype).eps diff --git a/src/gluonts/torch/distributions/truncated_normal.py b/src/gluonts/torch/distributions/truncated_normal.py index a3ff655b1f..6328d48759 100644 --- a/src/gluonts/torch/distributions/truncated_normal.py +++ b/src/gluonts/torch/distributions/truncated_normal.py @@ -256,7 +256,7 @@ def __init__( } @classmethod - def domain_map( + def domain_map( # type: ignore cls, loc: torch.Tensor, scale: torch.Tensor, @@ -278,6 +278,9 @@ def distribution( ) -> Distribution: (loc, scale) = distr_args + assert isinstance(loc, torch.Tensor) + assert isinstance(scale, torch.Tensor) + return TruncatedNormal( loc=loc, scale=scale, diff --git a/src/gluonts/torch/model/deep_npts/.typesafe b/src/gluonts/torch/model/deep_npts/.typesafe deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/src/gluonts/torch/model/deepar/.typesafe b/src/gluonts/torch/model/deepar/.typesafe deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/src/gluonts/torch/model/estimator.py b/src/gluonts/torch/model/estimator.py index ba91f0d725..7cca653a15 100644 --- a/src/gluonts/torch/model/estimator.py +++ b/src/gluonts/torch/model/estimator.py @@ -11,7 +11,7 @@ # express or implied. See the License for the specific language governing # permissions and limitations under the License. 
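The `gluonts.torch.distributions` hunks above are mostly typing corrections: `torch.tensor` (a factory function) becomes `torch.Tensor` (the actual type) in annotations, mirroring the earlier `np.array` to `np.ndarray` fixes; `-> torch.Size():` becomes `-> torch.Size:`; type-incorrect `loc: Optional[torch.Tensor] = 0` defaults become `None`; and the deliberately narrowed `domain_map` overrides are marked `# type: ignore`. For reference, a standalone sketch of the `domain_map` pattern these classes share, written here as a plain function rather than the GluonTS class:

    import torch
    import torch.nn.functional as F

    def normal_domain_map(loc: torch.Tensor, scale: torch.Tensor):
        # Map unconstrained network outputs into the distribution's valid
        # parameter domain; softplus keeps the scale strictly positive.
        return loc.squeeze(-1), F.softplus(scale).squeeze(-1)

    raw = torch.randn(8, 2)
    loc, scale = normal_domain_map(raw[:, 0:1], raw[:, 1:2])
    assert (scale > 0).all()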
-from typing import NamedTuple, Optional, Iterable, Dict, Any +from typing import NamedTuple, Optional, Iterable, Dict, Any, Union import logging import numpy as np @@ -24,7 +24,7 @@ from gluonts.itertools import Cached from gluonts.model import Estimator, Predictor from gluonts.torch.model.predictor import PyTorchPredictor -from gluonts.transform import Transformation +from gluonts.transform import Transformation, TransformedDataset logger = logging.getLogger(__name__) @@ -156,16 +156,18 @@ def train_model( transformation = self.create_transformation() with env._let(max_idle_transforms=max(len(training_data), 100)): - transformed_training_data = transformation.apply( - training_data, is_train=True - ) + transformed_training_data: Union[ + Cached, TransformedDataset + ] = transformation.apply(training_data, is_train=True) if cache_data: transformed_training_data = Cached(transformed_training_data) training_network = self.create_lightning_module() training_data_loader = self.create_training_data_loader( - transformed_training_data, + Cached(transformed_training_data) + if cache_data + else transformed_training_data, training_network, shuffle_buffer_length=shuffle_buffer_length, ) @@ -174,9 +176,9 @@ def train_model( if validation_data is not None: with env._let(max_idle_transforms=max(len(validation_data), 100)): - transformed_validation_data = transformation.apply( - validation_data, is_train=True - ) + transformed_validation_data: Union[ + Cached, TransformedDataset + ] = transformation.apply(validation_data, is_train=True) if cache_data: transformed_validation_data = Cached( transformed_validation_data diff --git a/src/gluonts/torch/model/forecast.py b/src/gluonts/torch/model/forecast.py index fd5cded1c0..0b7aa41d61 100644 --- a/src/gluonts/torch/model/forecast.py +++ b/src/gluonts/torch/model/forecast.py @@ -74,8 +74,9 @@ def mean(self) -> np.ndarray: if self._mean is not None: return self._mean else: - self._mean = self.distribution.mean.cpu().numpy() - return self._mean + _mean = self.distribution.mean.cpu().numpy() + self._mean = _mean + return _mean @property def mean_ts(self) -> pd.Series: @@ -96,7 +97,9 @@ def quantile(self, level: Union[float, str]) -> np.ndarray: def to_sample_forecast(self, num_samples: int = 200) -> SampleForecast: return SampleForecast( - samples=self.distribution.sample((num_samples,)).cpu().numpy(), + samples=self.distribution.sample(torch.Size((num_samples,))) + .cpu() + .numpy(), start_date=self.start_date, item_id=self.item_id, info=self.info, diff --git a/src/gluonts/torch/model/lightning_util.py b/src/gluonts/torch/model/lightning_util.py index 73e2396140..4a631986f8 100644 --- a/src/gluonts/torch/model/lightning_util.py +++ b/src/gluonts/torch/model/lightning_util.py @@ -11,12 +11,8 @@ # express or implied. See the License for the specific language governing # permissions and limitations under the License. 
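Both the `train_model` hunk above and the MQF2 changes that follow lean on `gluonts.itertools`: `Cached` materializes a transformed dataset during the first pass so later epochs skip re-applying the transformation, and `prod` replaces the removed `MQF2Distribution.get_numel` helper for counting the elements of a `torch.Size`. A minimal sketch of the two behaviors being relied on:

    import torch
    from gluonts.itertools import Cached, prod

    data = Cached(range(3))          # any iterable can be wrapped
    assert list(data) == [0, 1, 2]   # the first pass fills the cache
    assert list(data) == [0, 1, 2]   # later passes replay the cache

    # prod matches the removed torch-based helper:
    shape = torch.Size((2, 3, 4))
    assert prod(shape) == 24 == torch.prod(torch.tensor(shape)).item()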
-from packaging import version - import lightning.pytorch as pl def has_validation_loop(trainer: pl.Trainer): - if version.parse(pl.__version__) < version.parse("2.0.0"): - return trainer._data_connector._val_dataloader_source.is_defined() return trainer.fit_loop.epoch_loop.val_loop._data_source.is_defined() diff --git a/src/gluonts/torch/distributions/mqf2.py b/src/gluonts/torch/model/mqf2/distribution.py similarity index 94% rename from src/gluonts/torch/distributions/mqf2.py rename to src/gluonts/torch/model/mqf2/distribution.py index 8a1df95a90..8e4c4cd24c 100644 --- a/src/gluonts/torch/distributions/mqf2.py +++ b/src/gluonts/torch/model/mqf2/distribution.py @@ -18,8 +18,9 @@ from torch.distributions.normal import Normal from gluonts.core.component import validated - -from .distribution_output import DistributionOutput +from gluonts.itertools import prod +from gluonts.torch.distributions import DistributionOutput +from gluonts.torch.model.mqf2.module import PICNN class MQF2Distribution(torch.distributions.Distribution): @@ -60,7 +61,7 @@ class MQF2Distribution(torch.distributions.Distribution): def __init__( self, - picnn: torch.nn.Module, + picnn: PICNN, hidden_state: torch.Tensor, prediction_length: int, is_energy_score: bool = True, @@ -86,7 +87,7 @@ def __init__( if len(self.hidden_state.shape) > 2 else 1 ) - self.numel_batch = MQF2Distribution.get_numel(self.batch_shape) + self.numel_batch = prod(self.batch_shape) # mean zero and std one mu = torch.tensor( @@ -211,7 +212,7 @@ def rsample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor: numel_batch = self.numel_batch prediction_length = self.prediction_length - num_samples_per_batch = MQF2Distribution.get_numel(sample_shape) + num_samples_per_batch = prod(sample_shape) num_samples = num_samples_per_batch * numel_batch hidden_state_repeat = self.hidden_state.repeat_interleave( @@ -265,20 +266,14 @@ def quantile( return result - @staticmethod - def get_numel(tensor_shape: torch.Size) -> int: - # Auxiliary function - # compute number of elements specified in a torch.Size() - return torch.prod(torch.tensor(tensor_shape)).item() - @property def batch_shape(self) -> torch.Size: # last dimension is the hidden state size return self.hidden_state.shape[:-1] @property - def event_shape(self) -> Tuple: - return () + def event_shape(self) -> torch.Size: + return torch.Size() @property def event_dim(self) -> int: @@ -311,18 +306,17 @@ def __init__( self.beta = beta @classmethod - def domain_map( + def domain_map( # type: ignore cls, hidden_state: torch.Tensor, ) -> Tuple: # A null function to be called by ArgProj return () - def distribution( + def distribution( # type: ignore self, picnn: torch.nn.Module, hidden_state: torch.Tensor, - loc: Optional[torch.Tensor] = 0, scale: Optional[torch.Tensor] = None, ) -> MQF2Distribution: distr = self.distr_cls( @@ -339,7 +333,7 @@ def distribution( return distr else: return TransformedMQF2Distribution( - distr, [AffineTransform(loc=loc, scale=scale)] + distr, [AffineTransform(loc=0.0, scale=scale)] ) @property @@ -370,6 +364,8 @@ def scale_input( z = t._inverse(y) scale *= t.scale + assert isinstance(scale, torch.Tensor) + return z, scale def repeat_scale(self, scale: torch.Tensor) -> torch.Tensor: diff --git a/src/gluonts/torch/model/mqf2/estimator.py b/src/gluonts/torch/model/mqf2/estimator.py index 9048efb281..3fc5a05759 100644 --- a/src/gluonts/torch/model/mqf2/estimator.py +++ b/src/gluonts/torch/model/mqf2/estimator.py @@ -13,14 +13,13 @@ from typing import List, Optional, Dict, Any 
+from gluonts.core.component import validated +from gluonts.time_feature import TimeFeature from gluonts.torch.model.deepar.estimator import DeepAREstimator from gluonts.torch.modules.loss import NegativeLogLikelihood, EnergyScore -from gluonts.torch.distributions import MQF2DistributionOutput - -from . import MQF2MultiHorizonLightningModule -from gluonts.core.component import validated -from gluonts.time_feature import TimeFeature +from .lightning_module import MQF2MultiHorizonLightningModule +from .distribution import MQF2DistributionOutput class MQF2MultiHorizonEstimator(DeepAREstimator): @@ -183,7 +182,7 @@ def __init__( self.threshold_input = threshold_input self.estimate_logdet = estimate_logdet - def create_lightning_module(self) -> MQF2MultiHorizonLightningModule: + def create_lightning_module(self) -> MQF2MultiHorizonLightningModule: # type: ignore return MQF2MultiHorizonLightningModule( loss=self.loss, lr=self.lr, diff --git a/src/gluonts/torch/model/mqf2/icnn_utils.py b/src/gluonts/torch/model/mqf2/icnn_utils.py index 04fd13cd0f..7ca2925b71 100644 --- a/src/gluonts/torch/model/mqf2/icnn_utils.py +++ b/src/gluonts/torch/model/mqf2/icnn_utils.py @@ -21,7 +21,7 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -from typing import Optional, Tuple, List +from typing import Optional, Tuple, List, Union from cpflows.flows import SequentialFlow, DeepConvexFlow @@ -110,10 +110,10 @@ def get_potential( def forward_transform( self, x: torch.Tensor, - logdet: Optional[torch.Tensor] = 0, + logdet: Optional[Union[float, torch.Tensor]] = 0.0, context: Optional[torch.Tensor] = None, extra: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: if self.estimate_logdet: return self.forward_transform_stochastic( x, logdet, context=context, extra=extra diff --git a/src/gluonts/torch/model/mqf2/lightning_module.py b/src/gluonts/torch/model/mqf2/lightning_module.py index 16916c3c41..4e6d1f48e8 100644 --- a/src/gluonts/torch/model/mqf2/lightning_module.py +++ b/src/gluonts/torch/model/mqf2/lightning_module.py @@ -19,7 +19,8 @@ from gluonts.core.component import validated from gluonts.torch.modules.loss import DistributionLoss, EnergyScore -from . 
import MQF2MultiHorizonModel + +from .module import MQF2MultiHorizonModel class MQF2MultiHorizonLightningModule(pl.LightningModule): diff --git a/src/gluonts/torch/model/mqf2/module.py b/src/gluonts/torch/model/mqf2/module.py index 582827beeb..53e7632f93 100644 --- a/src/gluonts/torch/model/mqf2/module.py +++ b/src/gluonts/torch/model/mqf2/module.py @@ -17,13 +17,11 @@ from gluonts.core.component import validated from gluonts.torch.model.deepar.module import DeepARModel -from gluonts.torch.distributions import ( - DistributionOutput, - MQF2DistributionOutput, -) from cpflows.flows import ActNorm from cpflows.icnn import PICNN + +from .distribution import MQF2DistributionOutput from .icnn_utils import DeepConvexNet, SequentialNet @@ -38,7 +36,7 @@ def __init__( num_feat_static_real: int, num_feat_static_cat: int, cardinality: List[int], - distr_output: Optional[DistributionOutput] = None, + distr_output: Optional[MQF2DistributionOutput] = None, embedding_dimension: Optional[List[int]] = None, num_layers: int = 2, hidden_size: int = 40, @@ -113,7 +111,7 @@ def __init__( ActNorm(prediction_length), ] - self.picnn = SequentialNet(networks) + self.picnn = SequentialNet(networks) # type: ignore @torch.jit.ignore def output_distribution( diff --git a/src/gluonts/torch/model/patch_tst/module.py b/src/gluonts/torch/model/patch_tst/module.py index b405d2c121..ebc873d1de 100644 --- a/src/gluonts/torch/model/patch_tst/module.py +++ b/src/gluonts/torch/model/patch_tst/module.py @@ -30,10 +30,10 @@ class SinusoidalPositionalEmbedding(nn.Embedding): def __init__(self, num_positions: int, embedding_dim: int) -> None: super().__init__(num_positions, embedding_dim) - self.weight = self._init_weight(self.weight) + self.weight = self._init_weight(self.weight) # type: ignore @staticmethod - def _init_weight(out: nn.Parameter) -> nn.Parameter: + def _init_weight(out: torch.Tensor) -> torch.Tensor: """ Features are not interleaved. The cos features are in the 2nd half of the vector. 
[dim // 2:] """ @@ -54,7 +54,7 @@ def _init_weight(out: nn.Parameter) -> nn.Parameter: return out @torch.no_grad() - def forward( + def forward( # type: ignore self, input_ids_shape: torch.Size, past_key_values_length: int = 0 ) -> torch.Tensor: """`input_ids_shape` is expected to be [bsz x seqlen x ...].""" diff --git a/src/gluonts/torch/model/predictor.py b/src/gluonts/torch/model/predictor.py index 58432148a3..59a4fec3be 100644 --- a/src/gluonts/torch/model/predictor.py +++ b/src/gluonts/torch/model/predictor.py @@ -71,7 +71,7 @@ def to(self, device: Union[str, torch.device]) -> "PyTorchPredictor": def network(self) -> nn.Module: return self.prediction_net - def predict( + def predict( # type: ignore self, dataset: Dataset, num_samples: Optional[int] = None ) -> Iterator[Forecast]: inference_data_loader = InferenceDataLoader( @@ -103,11 +103,13 @@ def serialize(self, path: Path) -> None: ) @classmethod - def deserialize( + def deserialize( # type: ignore cls, path: Path, device: Optional[Union[str, torch.device]] = None ) -> "PyTorchPredictor": predictor = super().deserialize(path) + assert isinstance(predictor, cls) + if device is not None: device = resolve_device(device) predictor.to(device) diff --git a/src/gluonts/torch/model/simple_feedforward/.typesafe b/src/gluonts/torch/model/simple_feedforward/.typesafe deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/src/gluonts/torch/model/tft/layers.py b/src/gluonts/torch/model/tft/layers.py index d13cdcc633..da3469f062 100644 --- a/src/gluonts/torch/model/tft/layers.py +++ b/src/gluonts/torch/model/tft/layers.py @@ -22,15 +22,13 @@ class FeatureEmbedder(BaseFeatureEmbedder): - def forward(self, features: torch.Tensor) -> List[torch.Tensor]: + def forward(self, features: torch.Tensor) -> List[torch.Tensor]: # type: ignore concat_features = super().forward(features=features) if self._num_features > 1: - features = torch.chunk(concat_features, self._num_features, dim=-1) + return torch.chunk(concat_features, self._num_features, dim=-1) else: - features = [concat_features] - - return features + return [concat_features] class FeatureProjector(nn.Module): @@ -316,7 +314,7 @@ def forward( self, x: torch.Tensor, static: torch.Tensor, - mask: Optional[torch.Tensor] = None, + mask: torch.Tensor, ) -> torch.Tensor: expanded_static = static.repeat( (1, self.context_length + self.prediction_length, 1) diff --git a/src/gluonts/torch/model/tft/module.py b/src/gluonts/torch/model/tft/module.py index 534442738f..72de20e2f9 100644 --- a/src/gluonts/torch/model/tft/module.py +++ b/src/gluonts/torch/model/tft/module.py @@ -92,7 +92,9 @@ def __init__( self.target_proj = nn.Linear(in_features=1, out_features=self.d_var) # Past-only dynamic features if self.d_past_feat_dynamic_real: - self.past_feat_dynamic_proj = FeatureProjector( + self.past_feat_dynamic_proj: Optional[ + FeatureProjector + ] = FeatureProjector( feature_dims=self.d_past_feat_dynamic_real, embedding_dims=[self.d_var] * len(self.d_past_feat_dynamic_real), @@ -101,7 +103,9 @@ def __init__( self.past_feat_dynamic_proj = None if self.c_past_feat_dynamic_cat: - self.past_feat_dynamic_embed = FeatureEmbedder( + self.past_feat_dynamic_embed: Optional[ + FeatureEmbedder + ] = FeatureEmbedder( cardinalities=self.c_past_feat_dynamic_cat, embedding_dims=[self.d_var] * len(self.c_past_feat_dynamic_cat), @@ -111,7 +115,9 @@ def __init__( # Known dynamic features if self.d_feat_dynamic_real: - self.feat_dynamic_proj = FeatureProjector( + self.feat_dynamic_proj: Optional[ + FeatureProjector + 
] = FeatureProjector( feature_dims=self.d_feat_dynamic_real, embedding_dims=[self.d_var] * len(self.d_feat_dynamic_real), ) @@ -119,7 +125,9 @@ def __init__( self.feat_dynamic_proj = None if self.c_feat_dynamic_cat: - self.feat_dynamic_embed = FeatureEmbedder( + self.feat_dynamic_embed: Optional[ + FeatureEmbedder + ] = FeatureEmbedder( cardinalities=self.c_feat_dynamic_cat, embedding_dims=[self.d_var] * len(self.c_feat_dynamic_cat), ) @@ -128,7 +136,9 @@ def __init__( # Static features if self.d_feat_static_real: - self.feat_static_proj = FeatureProjector( + self.feat_static_proj: Optional[ + FeatureProjector + ] = FeatureProjector( feature_dims=self.d_feat_static_real, embedding_dims=[self.d_var] * len(self.d_feat_static_real), ) @@ -136,7 +146,9 @@ def __init__( self.feat_static_proj = None if self.c_feat_static_cat: - self.feat_static_embed = FeatureEmbedder( + self.feat_static_embed: Optional[ + FeatureEmbedder + ] = FeatureEmbedder( cardinalities=self.c_feat_static_cat, embedding_dims=[self.d_var] * len(self.c_feat_static_cat), ) @@ -262,12 +274,12 @@ def _preprocess( self, past_target: torch.Tensor, # [N, T] past_observed_values: torch.Tensor, # [N, T] - feat_static_real: torch.Tensor, # [N, D_sr] - feat_static_cat: torch.Tensor, # [N, D_sc] - feat_dynamic_real: torch.Tensor, # [N, T + H, D_dr] - feat_dynamic_cat: torch.Tensor, # [N, T + H, D_dc] - past_feat_dynamic_real: torch.Tensor, # [N, T, D_pr] - past_feat_dynamic_cat: torch.Tensor, # [N, T, D_pc] + feat_static_real: Optional[torch.Tensor], # [N, D_sr] + feat_static_cat: Optional[torch.Tensor], # [N, D_sc] + feat_dynamic_real: Optional[torch.Tensor], # [N, T + H, D_dr] + feat_dynamic_cat: Optional[torch.Tensor], # [N, T + H, D_dc] + past_feat_dynamic_real: Optional[torch.Tensor], # [N, T, D_pr] + past_feat_dynamic_cat: Optional[torch.Tensor], # [N, T, D_pc] ) -> Tuple[ List[torch.Tensor], List[torch.Tensor], diff --git a/src/gluonts/torch/model/wavenet/estimator.py b/src/gluonts/torch/model/wavenet/estimator.py index e7fea4b0d7..2da9e87d91 100644 --- a/src/gluonts/torch/model/wavenet/estimator.py +++ b/src/gluonts/torch/model/wavenet/estimator.py @@ -54,6 +54,7 @@ PREDICTION_INPUT_NAMES = [ "feat_static_cat", + "feat_static_real", "past_target", "past_observed_values", "past_time_feat", @@ -95,7 +96,7 @@ def __init__( num_batches_per_epoch: int = 50, num_parallel_samples: int = 100, negative_data: bool = False, - trainer_kwargs: Dict[str, Any] = None, + trainer_kwargs: Optional[Dict[str, Any]] = None, ) -> None: """WaveNet estimator that uses the architecture proposed in [Oord et al., 2016] with quantized targets. The model is trained diff --git a/src/gluonts/torch/model/wavenet/lightning_module.py b/src/gluonts/torch/model/wavenet/lightning_module.py index 85dd0a6671..f4e7043ebb 100644 --- a/src/gluonts/torch/model/wavenet/lightning_module.py +++ b/src/gluonts/torch/model/wavenet/lightning_module.py @@ -53,6 +53,7 @@ def training_step(self, batch, batch_idx: int): # type: ignore Execute training step. 
""" feat_static_cat = batch["feat_static_cat"] + feat_static_real = batch["feat_static_real"] past_target = batch["past_target"] past_observed_values = batch["past_observed_values"] past_time_feat = batch["past_time_feat"] @@ -63,6 +64,7 @@ def training_step(self, batch, batch_idx: int): # type: ignore train_loss = self.model.loss( feat_static_cat=feat_static_cat, + feat_static_real=feat_static_real, past_target=past_target, past_observed_values=past_observed_values, past_time_feat=past_time_feat, @@ -87,6 +89,7 @@ def validation_step(self, batch, batch_idx: int): # type: ignore Execute validation step. """ feat_static_cat = batch["feat_static_cat"] + feat_static_real = batch["feat_static_real"] past_target = batch["past_target"] past_observed_values = batch["past_observed_values"] past_time_feat = batch["past_time_feat"] @@ -97,6 +100,7 @@ def validation_step(self, batch, batch_idx: int): # type: ignore val_loss = self.model.loss( feat_static_cat=feat_static_cat, + feat_static_real=feat_static_real, past_target=past_target, past_observed_values=past_observed_values, past_time_feat=past_time_feat, diff --git a/src/gluonts/torch/model/wavenet/module.py b/src/gluonts/torch/model/wavenet/module.py index c34c5203ab..386e6a7e25 100644 --- a/src/gluonts/torch/model/wavenet/module.py +++ b/src/gluonts/torch/model/wavenet/module.py @@ -70,7 +70,7 @@ def __init__( kernel_size=1, ) - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: u = self.conv_sigmoid(x) * self.conv_tanh(x) s = self.conv_skip(u) if not self.return_dense_output: @@ -143,6 +143,7 @@ def __init__( + num_feat_dynamic_real + num_feat_static_real + int(use_log_scale_feature) # the log(scale) + + 1 # for observed value indicator ) self.use_log_scale_feature = use_log_scale_feature @@ -185,6 +186,7 @@ def __init__( bias=True, ) with torch.no_grad(): + assert self.conv_project.bias is not None self.conv_project.bias.zero_() self.conv1 = nn.Conv1d( @@ -217,6 +219,7 @@ def get_receptive_field(dilation_depth: int, num_stacks: int) -> int: def get_full_features( self, feat_static_cat: torch.Tensor, + feat_static_real: torch.Tensor, past_observed_values: torch.Tensor, past_time_feat: torch.Tensor, future_time_feat: torch.Tensor, @@ -230,6 +233,8 @@ def get_full_features( ---------- feat_static_cat Static categorical features: (batch_size, num_cat_features) + feat_static_real + Static real-valued features: (batch_size, num_feat_static_real) past_observed_values Observed value indicator for the past target: (batch_size, receptive_field) @@ -256,6 +261,7 @@ def get_full_features( static_feat = torch.cat( [static_feat, torch.log(scale + 1.0)], dim=1 ) + static_feat = torch.cat([static_feat, feat_static_real], dim=1) repeated_static_feat = torch.repeat_interleave( static_feat[..., None], self.prediction_length + self.receptive_field, @@ -303,8 +309,7 @@ def target_feature_embedding( def base_net( self, inputs: torch.Tensor, - prediction_mode: bool = False, - queues: List[torch.Tensor] = None, + queues: Optional[List[torch.Tensor]] = None, ) -> Tuple[torch.Tensor, List[torch.Tensor]]: """Forward pass through the WaveNet. @@ -313,9 +318,6 @@ def base_net( inputs A tensor of inputs Shape: (batch_size, num_residual_channels, sequence_length) - prediction_mode, optional - Flag indicating whether the network is being used - for prediction, by default False queues, optional Convolutional queues containing past computations. 
This speeds up predictions and must be provided @@ -324,6 +326,7 @@ def base_net( [Paine et al., 2016] "Fast wavenet generation algorithm." arXiv preprint arXiv:1611.09482 (2016). + Returns ------- A tensor containing the unnormalized outputs of the network of @@ -331,17 +334,12 @@ def base_net( convolutional queues for each layer. The queue corresponding to layer `l` has shape: (batch_size, num_residual_channels, 2^l). """ - if prediction_mode: - assert ( - queues is not None - ), "Queues cannot be empty in prediction mode!" - skip_outs = [] queues_next = [] out = inputs for i, layer in enumerate(self.residuals): skip, out = layer(out) - if prediction_mode: + if queues is not None: trimmed_skip = skip if i + 1 < len(self.residuals): out = torch.cat([queues[i], out], dim=-1) @@ -361,6 +359,7 @@ def base_net( def loss( self, feat_static_cat: torch.Tensor, + feat_static_real: torch.Tensor, past_target: torch.Tensor, past_observed_values: torch.Tensor, past_time_feat: torch.Tensor, @@ -375,6 +374,8 @@ def loss( ---------- feat_static_cat Static categorical features: (batch_size, num_cat_features) + feat_static_real + Static real-valued features: (batch_size, num_feat_static_real) past_target Past target: (batch_size, receptive_field) past_observed_values @@ -401,6 +402,7 @@ def loss( full_target = torch.cat([past_target, future_target], dim=-1).long() full_features = self.get_full_features( feat_static_cat=feat_static_cat, + feat_static_real=feat_static_real, past_observed_values=past_observed_values, past_time_feat=past_time_feat, future_time_feat=future_time_feat, @@ -410,7 +412,7 @@ def loss( input_embedding = self.target_feature_embedding( target=full_target[..., :-1], features=full_features[..., 1:] ) - logits, _ = self.base_net(input_embedding, prediction_mode=False) + logits, _ = self.base_net(input_embedding) labels = full_target[..., self.receptive_field :] loss_weight = torch.cat( [past_observed_values, future_observed_values], dim=-1 @@ -457,6 +459,7 @@ def _initialize_conv_queues( def forward( self, feat_static_cat: torch.Tensor, + feat_static_real: torch.Tensor, past_target: torch.Tensor, past_observed_values: torch.Tensor, past_time_feat: torch.Tensor, @@ -472,6 +475,8 @@ def forward( ---------- feat_static_cat Static categorical features: (batch_size, num_cat_features) + feat_static_real + Static real-valued features: (batch_size, num_feat_static_real) past_target Past target: (batch_size, receptive_field) past_observed_values @@ -508,6 +513,7 @@ def forward( past_target = past_target.long() full_features = self.get_full_features( feat_static_cat=feat_static_cat, + feat_static_real=feat_static_real, past_observed_values=past_observed_values, past_time_feat=past_time_feat, future_time_feat=future_time_feat, @@ -550,9 +556,7 @@ def forward( current_features, num_parallel_samples, dim=0 ), ) - logits, queues = self.base_net( - input_embedding, prediction_mode=True, queues=queues - ) + logits, queues = self.base_net(input_embedding, queues=queues) if temperature > 0.0: probs = torch.softmax(logits / temperature, dim=-1) diff --git a/src/gluonts/torch/modules/feature.py b/src/gluonts/torch/modules/feature.py index c7e3ae9df9..852d3372c9 100644 --- a/src/gluonts/torch/modules/feature.py +++ b/src/gluonts/torch/modules/feature.py @@ -61,9 +61,11 @@ def __init__( super().__init__() self.T = T - self.embeddings = nn.ModuleDict( - {"embed_static": embed_static, "embed_dynamic": embed_dynamic} - ) + self.embeddings = nn.ModuleDict() + if embed_static is not None: + 
self.embeddings["embed_static"] = embed_static + if embed_dynamic is not None: + self.embeddings["embed_dynamic"] = embed_dynamic def forward( self, @@ -82,15 +84,14 @@ def forward( return torch.cat(processed_features, dim=-1) def process_static_cat(self, feature: torch.Tensor) -> torch.Tensor: - if self.embeddings["embed_static"] is not None: + if "embed_static" in self.embeddings: feature = self.embeddings["embed_static"](feature) return feature.unsqueeze(1).expand(-1, self.T, -1).float() def process_dynamic_cat(self, feature: torch.Tensor) -> torch.Tensor: - if self.embeddings["embed_dynamic"] is None: - return feature.float() - else: + if "embed_dynamic" in self.embeddings: return self.embeddings["embed_dynamic"](feature) + return feature.float() def process_static_real(self, feature: torch.Tensor) -> torch.Tensor: return feature.unsqueeze(1).expand(-1, self.T, -1) diff --git a/src/gluonts/torch/modules/lookup_table.py b/src/gluonts/torch/modules/lookup_table.py index b412f0ba96..edf362925a 100644 --- a/src/gluonts/torch/modules/lookup_table.py +++ b/src/gluonts/torch/modules/lookup_table.py @@ -32,6 +32,7 @@ def __init__(self, bin_values: torch.Tensor): self.register_buffer("bin_values", bin_values) def forward(self, indices: torch.Tensor) -> torch.Tensor: + assert isinstance(self.bin_values, torch.Tensor) indices = torch.clamp(indices, 0, self.bin_values.shape[0] - 1) return torch.index_select( self.bin_values, 0, indices.reshape(-1) diff --git a/src/gluonts/torch/modules/loss.py b/src/gluonts/torch/modules/loss.py index aa8f829d69..1f91af824e 100644 --- a/src/gluonts/torch/modules/loss.py +++ b/src/gluonts/torch/modules/loss.py @@ -12,7 +12,8 @@ # permissions and limitations under the License. import torch -from pydantic import BaseModel + +from gluonts.pydantic import BaseModel class DistributionLoss(BaseModel): @@ -77,6 +78,7 @@ class CRPS(DistributionLoss): def __call__( self, input: torch.distributions.Distribution, target: torch.Tensor ) -> torch.Tensor: + assert hasattr(input, "crps") return input.crps(target) @@ -84,9 +86,13 @@ class QuantileLoss(DistributionLoss): def __call__( self, input: torch.distributions.Distribution, target: torch.Tensor ) -> torch.Tensor: + assert hasattr(input, "quantile_loss") return input.quantile_loss(target) class EnergyScore(DistributionLoss): - def __call__(self, input, target: torch.Tensor) -> torch.Tensor: + def __call__( + self, input: torch.distributions.Distribution, target: torch.Tensor + ) -> torch.Tensor: + assert hasattr(input, "energy_score") return input.energy_score(target) diff --git a/src/gluonts/torch/modules/quantile_output.py b/src/gluonts/torch/modules/quantile_output.py index b55e5dd0bf..11b6735316 100644 --- a/src/gluonts/torch/modules/quantile_output.py +++ b/src/gluonts/torch/modules/quantile_output.py @@ -43,7 +43,7 @@ def __init__(self, quantiles: List[float]) -> None: def quantiles(self) -> List[float]: return self._quantiles - def domain_map(self, quantiles_pred: torch.Tensor): + def domain_map(self, quantiles_pred: torch.Tensor): # type: ignore return quantiles_pred def quantile_loss( diff --git a/src/gluonts/torch/prelude.py b/src/gluonts/torch/prelude.py index a66ac9fd05..cfb29821a4 100644 --- a/src/gluonts/torch/prelude.py +++ b/src/gluonts/torch/prelude.py @@ -13,7 +13,9 @@ # flake8: noqa: F401, F403 +from typing import List + from .component import * from .model.forecast_generator import * -__all__ = [] +__all__: List[str] = [] diff --git a/src/gluonts/torch/util.py b/src/gluonts/torch/util.py index 
a23334ef42..12a2a1e4f0 100644 --- a/src/gluonts/torch/util.py +++ b/src/gluonts/torch/util.py @@ -37,7 +37,7 @@ def resolve_device( def copy_parameters( net_source: torch.nn.Module, net_dest: torch.nn.Module, - strict: Optional[bool] = True, + strict: bool = True, ) -> None: """ Copies parameters from one network to another. diff --git a/src/gluonts/transform/.typesafe b/src/gluonts/transform/.typesafe deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/src/gluonts/transform/sampler.py b/src/gluonts/transform/sampler.py index a64045fd56..eb687bda92 100644 --- a/src/gluonts/transform/sampler.py +++ b/src/gluonts/transform/sampler.py @@ -14,9 +14,9 @@ from typing import Tuple import numpy as np -from pydantic import BaseModel from gluonts.dataset.stat import ScaleHistogram +from gluonts.pydantic import BaseModel class InstanceSampler(BaseModel): diff --git a/src/gluonts/zebras/.typesafe b/src/gluonts/zebras/.typesafe deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/src/gluonts/zebras/schema.py b/src/gluonts/zebras/schema.py index 3b08a6b472..5bfa777fb2 100644 --- a/src/gluonts/zebras/schema.py +++ b/src/gluonts/zebras/schema.py @@ -15,15 +15,16 @@ from typing import Any, Callable, Optional, Union, Type, Dict import numpy as np -from pydantic import parse_obj_as, BaseModel from gluonts.itertools import partition +from gluonts.pydantic import parse_obj_as, BaseModel from ._freq import Freq from ._period import Period from ._time_frame import time_frame, TimeFrame from ._split_frame import split_frame, SplitFrame + """ This module provides tooling to extract ``zebras.TimeFrame`` and ``zebras.SplitFrame`` instances from Python dictionaries:: diff --git a/test/core/test_serde_dataclass.py b/test/core/test_serde_dataclass.py index b093266636..324e33c746 100644 --- a/test/core/test_serde_dataclass.py +++ b/test/core/test_serde_dataclass.py @@ -13,9 +13,8 @@ from typing import List -from pydantic import BaseModel - from gluonts.core import serde +from gluonts.pydantic import BaseModel @serde.dataclass diff --git a/test/core/test_serde_flat.py b/test/core/test_serde_flat.py index 12d3702d86..3838dafcb0 100644 --- a/test/core/test_serde_flat.py +++ b/test/core/test_serde_flat.py @@ -11,10 +11,9 @@ # express or implied. See the License for the specific language governing # permissions and limitations under the License. -from pydantic import BaseModel - from gluonts.core import serde from gluonts.core.component import equals +from gluonts.pydantic import BaseModel class A(BaseModel): diff --git a/test/core/test_settings.py b/test/core/test_settings.py index 283e13250a..bdc969e809 100644 --- a/test/core/test_settings.py +++ b/test/core/test_settings.py @@ -11,9 +11,8 @@ # express or implied. See the License for the specific language governing # permissions and limitations under the License. 
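The `copy_parameters` tweak above (`strict: Optional[bool] = True` to `strict: bool = True`) reflects a recurring typing fix in this patch: `Optional[X]` means "X or None", not "argument with a default". Parameters that are never passed `None` lose the wrapper, while ones that genuinely default to `None` gain it (as `trainer_kwargs` does in the WaveNet estimator). A small standalone illustration, not GluonTS code:

    from typing import Optional

    def copy_strictly(strict: bool = True) -> None:
        ...  # has a default, but may never be None

    def maybe_strict(strict: Optional[bool] = None) -> None:
        # None is a meaningful "unset" value and must be handled
        strict = True if strict is None else strict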
-from pydantic import BaseModel - from gluonts.core.settings import Settings, let +from gluonts.pydantic import BaseModel class MySettings(Settings): diff --git a/test/ext/naive_2/test_predictors.py b/test/ext/naive_2/test_predictors.py index 25ac8bb4b2..225a895f29 100644 --- a/test/ext/naive_2/test_predictors.py +++ b/test/ext/naive_2/test_predictors.py @@ -18,7 +18,6 @@ import pandas as pd import pytest from flaky import flaky -from pydantic import PositiveInt from gluonts.dataset.artificial import constant_dataset from gluonts.dataset.common import Dataset @@ -27,6 +26,7 @@ from gluonts.ext.naive_2 import Naive2Predictor from gluonts.model.predictor import Predictor from gluonts.model.seasonal_naive import SeasonalNaivePredictor +from gluonts.pydantic import PositiveInt from gluonts.time_feature import get_seasonality diff --git a/test/mx/distribution/test_mx_distribution_inference.py b/test/mx/distribution/test_mx_distribution_inference.py index d8859224dc..d347d0b35c 100644 --- a/test/mx/distribution/test_mx_distribution_inference.py +++ b/test/mx/distribution/test_mx_distribution_inference.py @@ -22,7 +22,6 @@ import mxnet as mx import numpy as np import pytest -from pydantic import PositiveFloat, PositiveInt from gluonts.mx.model.tpp.distribution import ( Loglogistic, @@ -84,6 +83,7 @@ from gluonts.mx.distribution.transformed_distribution_output import ( TransformedDistributionOutput, ) +from gluonts.pydantic import PositiveFloat, PositiveInt pytestmark = pytest.mark.timeout(60) NUM_SAMPLES = 2000 diff --git a/test/mx/test_mx_serde.py b/test/mx/test_mx_serde.py index b2d14b8c38..7f06ca7219 100644 --- a/test/mx/test_mx_serde.py +++ b/test/mx/test_mx_serde.py @@ -17,10 +17,10 @@ import mxnet as mx import numpy as np -from pydantic import BaseModel from gluonts.core import serde from gluonts.core.component import equals +from gluonts.pydantic import BaseModel class CategoricalFeatureInfo(BaseModel): diff --git a/test/torch/model/test_mqf2_modules.py b/test/torch/model/test_mqf2_modules.py index 85fa21337f..a8c2685259 100644 --- a/test/torch/model/test_mqf2_modules.py +++ b/test/torch/model/test_mqf2_modules.py @@ -16,8 +16,8 @@ import pytest import torch -from gluonts.torch.distributions import MQF2DistributionOutput from gluonts.torch.model.mqf2 import MQF2MultiHorizonLightningModule +from gluonts.torch.model.mqf2.distribution import MQF2DistributionOutput @pytest.mark.parametrize( diff --git a/test/torch/modules/test_torch_distribution_inference.py b/test/torch/modules/test_torch_distribution_inference.py index fe3da93496..06616fcd76 100644 --- a/test/torch/modules/test_torch_distribution_inference.py +++ b/test/torch/modules/test_torch_distribution_inference.py @@ -21,7 +21,6 @@ import pytest import torch import torch.nn as nn -from pydantic import PositiveFloat, PositiveInt from scipy.special import softmax from torch.distributions import ( Beta, @@ -35,6 +34,7 @@ from torch.optim import SGD from torch.utils.data import DataLoader, TensorDataset +from gluonts.pydantic import PositiveFloat, PositiveInt from gluonts.torch.distributions import ( BetaOutput, DistributionOutput,