diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml
index 317aa88..ce4e4c7 100644
--- a/.github/workflows/ci.yaml
+++ b/.github/workflows/ci.yaml
@@ -35,7 +35,7 @@ jobs:
         run: |
           hatch build

-  tests:
+  lints:
    runs-on: ubuntu-latest
    needs:
      - precommit
@@ -51,22 +51,11 @@ jobs:
      - name: Install dependencies
        run: |
          pip install hatch
-      - name: Build package
-        run: |
-          pip install jaxlib jax
-      - name: Run pytest
-        run: |
-          hatch run test:test
-      - name: Copy file
+      - name: Run lints
        run: |
-          cp coverage.xml cobertura.xml
-      - name: Run codacy-coverage-reporter
-        uses: codacy/codacy-coverage-reporter-action@v1
-        with:
-          project-token: ${{ secrets.CODACY_PROJECT_TOKEN }}
-          coverage-reports: cobertura.xml
-          language: Python
-  lints:
+          hatch run test:lint
+
+  tests:
    runs-on: ubuntu-latest
    needs:
      - precommit
@@ -82,9 +71,16 @@ jobs:
      - name: Install dependencies
        run: |
          pip install hatch
-      - name: Run lints
+      - name: Build package
        run: |
-          hatch run test:lint
+          pip install jaxlib jax
+      - name: Run tests
+        run: |
+          hatch run test:test
+      - name: Upload coverage reports to Codecov
+        uses: codecov/codecov-action@v3
+        env:
+          CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}

  examples:
    runs-on: ubuntu-latest
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index e3c0a3e..e9f3c50 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -40,6 +40,8 @@ repos:
    language: python
    language_version: python3
    types: [python]
+    args: ["-c", "pyproject.toml"]
+    additional_dependencies: ["toml"]
    files: "(reconcile|examples)"

- repo: https://github.com/PyCQA/flake8
@@ -59,7 +61,13 @@ repos:
    files: "(reconcile|examples)"

- repo: https://github.com/jorisroovers/gitlint
-  rev: v0.18.0
+  rev: v0.19.1
  hooks:
-  - id: gitlint
-  - id: gitlint-ci
+    - id: gitlint
+    - id: gitlint-ci
+
+- repo: https://github.com/pycqa/pydocstyle
+  rev: 6.1.1
+  hooks:
+    - id: pydocstyle
+      additional_dependencies: ["toml"]
diff --git a/README.md b/README.md
index ff318bb..7cbe714 100644
--- a/README.md
+++ b/README.md
@@ -2,8 +2,6 @@
 [![status](http://www.repostatus.org/badges/latest/concept.svg)](http://www.repostatus.org/#concept)
 [![ci](https://github.com/dirmeier/reconcile/actions/workflows/ci.yaml/badge.svg)](https://github.com/dirmeier/reconcile/actions/workflows/ci.yaml)
-[![codacy badge](https://app.codacy.com/project/badge/Grade/f0a254348e894c7c85b4e979bc81f1d9)](https://www.codacy.com/gh/dirmeier/reconcile/dashboard?utm_source=github.com&utm_medium=referral&utm_content=dirmeier/reconcile&utm_campaign=Badge_Grade)
-[![codacy badge](https://app.codacy.com/project/badge/Coverage/f0a254348e894c7c85b4e979bc81f1d9)](https://www.codacy.com/gh/dirmeier/reconcile/dashboard?utm_source=github.com&utm_medium=referral&utm_content=dirmeier/reconcile&utm_campaign=Badge_Coverage)
 [![version](https://img.shields.io/pypi/v/probabilistic-reconciliation.svg?colorB=black&style=flat)](https://pypi.org/project/probabilistic-reconciliation/)

 > Probabilistic reconciliation of time series forecasts

@@ -15,8 +13,8 @@
 Reconcile implements probabilistic time series forecast reconciliation methods in

 1) Zambon, Lorenzo, Dario Azzimonti, and Giorgio Corani. ["Probabilistic reconciliation of forecasts via importance sampling."](https://doi.org/10.48550/arXiv.2210.02286) arXiv preprint arXiv:2210.02286 (2022).
 2) Panagiotelis, Anastasios, et al. ["Probabilistic forecast reconciliation: Properties, evaluation and score optimisation."](https://doi.org/10.1016/j.ejor.2022.07.040) European Journal of Operational Research (2022).
-The package implements methods to compute summing/aggregation matrices for grouped and hierarchical time series and reconciliation methods for probabilistic forecasts based on sampling and optimization, 
-and in the near future also some recent forecasting methods, such as proposed in [Benavoli, *et al.* (2021)](https://doi.org/10.1007/978-3-030-91445-5_2) or [Corani *et al.*, (2020)](https://arxiv.org/abs/2009.08102) via [GPJax](https://github.com/JaxGaussianProcesses/GPJax). 
+The package implements methods to compute summing/aggregation matrices for grouped and hierarchical time series and reconciliation methods for probabilistic forecasts based on sampling and optimization,
+and in the near future also some recent forecasting methods, such as proposed in [Benavoli, *et al.* (2021)](https://doi.org/10.1007/978-3-030-91445-5_2) or [Corani *et al.*, (2020)](https://arxiv.org/abs/2009.08102) via [GPJax](https://github.com/JaxGaussianProcesses/GPJax).

 ## Examples
diff --git a/examples/reconciliation.py b/examples/reconciliation.py
index 77a00f8..79fc5d6 100644
--- a/examples/reconciliation.py
+++ b/examples/reconciliation.py
@@ -9,7 +9,7 @@
 import pandas as pd
 from chex import Array, PRNGKey
 from jax import numpy as jnp
-from jax import random
+from jax import random as jr
 from jax.config import config
 from statsmodels.tsa.arima_process import arma_generate_sample

@@ -24,6 +24,7 @@ class GPForecaster(Forecaster):
     """Example implementation of a forecaster"""

     def __init__(self):
+        super().__init__()
         self._models: List = []
         self._xs: Array = None
         self._ys: Array = None
@@ -46,44 +47,42 @@ def fit(self, rng_key: PRNGKey, ys: Array, xs: Array, niter=2000):
         for i in np.arange(p):
             x, y = xs[:, [i], :], ys[:, [i], :]
             # fit a model for each time series
-            learned_params, _, D = self._fit_one(rng_key, x, y, niter)
+            opt_posterior, _, D = self._fit_one(rng_key, x, y, niter)
             # save the learned parameters and the original data
-            self._models[i] = learned_params, D
+            self._models[i] = opt_posterior, D

     def _fit_one(self, rng_key, x, y, niter):
         # here we use GPs to model the time series
         D = gpx.Dataset(X=x.reshape(-1, 1), y=y.reshape(-1, 1))
-        sgpr, q, likelihood = self._model(rng_key, D.n)
+        elbo, q, likelihood = self._model(rng_key, D.n)

-        parameter_state = gpx.initialise(sgpr, rng_key)
-        negative_elbo = jax.jit(sgpr.elbo(D, negative=True))
+        negative_elbo = jax.jit(elbo)
         optimiser = optax.adam(learning_rate=5e-3)
-        inference_state = gpx.fit(
+        opt_posterior, history = gpx.fit(
+            model=q,
             objective=negative_elbo,
-            parameter_state=parameter_state,
-            optax_optim=optimiser,
+            train_data=D,
+            optim=optimiser,
             num_iters=niter,
+            key=rng_key,
         )
-        learned_params, training_history = inference_state.unpack()
-        return learned_params, training_history, D
+        return opt_posterior, history, D

     @staticmethod
     def _model(rng_key, n):
-        z = random.uniform(rng_key, (20, 1))
+        z = jr.uniform(rng_key, (20, 1))
         prior = gpx.Prior(mean_function=gpx.Constant(), kernel=gpx.RBF())
         likelihood = gpx.Gaussian(num_datapoints=n)
         posterior = prior * likelihood
         q = gpx.CollapsedVariationalGaussian(
-            prior=prior,
-            likelihood=likelihood,
+            posterior=posterior,
             inducing_inputs=z,
         )
-        sgpr = gpx.CollapsedVI(posterior=posterior, variational_family=q)
-        return sgpr, q, likelihood
+        elbo = gpx.CollapsedELBO(negative=True)
+        return elbo, q, likelihood

     def posterior_predictive(self, rng_key, xs_test: Array):
-        """Compute the joint
-        posterior predictive distribution of all timeseries at xs_test"""
+        """Compute the joint posterior predictive
+        distribution at xs_test."""
         chex.assert_rank(xs_test, 3)

         q = xs_test.shape[1]
@@ -91,12 +90,12 @@ def posterior_predictive(self, rng_key, xs_test: Array):
         covs = [None] * q
         for i in np.arange(q):
             x_test = xs_test[:, [i], :].reshape(-1, 1)
-            learned_params, D = self._models[i]
+            opt_posterior, D = self._models[i]
             _, q, likelihood = self._model(rng_key, D.n)
-            latent_dist = q(learned_params, D)(x_test)
-            predictive_dist = likelihood(learned_params, latent_dist)
+            latent_dist = opt_posterior(x_test, train_data=D)
+            predictive_dist = opt_posterior.posterior.likelihood(latent_dist)
             means[i] = predictive_dist.mean()
-            cov = jnp.linalg.cholesky(predictive_dist.covariance_matrix)
+            cov = predictive_dist.scale_tril
             covs[i] = cov.reshape((1, *cov.shape))

         # here we stack the means and covariance functions of all
@@ -163,7 +162,7 @@ def run():
     forecaster = GPForecaster()
     forecaster.fit(
-        random.PRNGKey(1),
+        jr.PRNGKey(1),
         all_timeseries[:, :, :90],
         all_features[:, :, :90],
     )
@@ -171,11 +170,11 @@ def run():
     recon = ProbabilisticReconciliation(grouping, forecaster)
     # do reconciliation via sampling
     _ = recon.sample_reconciled_posterior_predictive(
-        random.PRNGKey(1), all_features, n_iter=100, n_warmup=50
+        jr.PRNGKey(1), all_features, n_iter=100, n_warmup=50
     )
     # do reconciliation via optimization of the energy score
     _ = recon.fit_reconciled_posterior_predictive(
-        random.PRNGKey(1), all_features, n_samples=100
+        jr.PRNGKey(1), all_features, n_samples=100
     )
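For reference, the call pattern the example migrates to, in isolation: a minimal, self-contained sketch assuming `gpjax>=0.6.9` (the pin added in `pyproject.toml` below); the synthetic data, inducing-point count, and step size are illustrative only.

```python
import gpjax as gpx
import jax
import optax
from jax import numpy as jnp
from jax import random as jr

key = jr.PRNGKey(0)
x = jnp.linspace(0.0, 1.0, 50).reshape(-1, 1)
y = jnp.sin(10.0 * x) + 0.1 * jr.normal(key, x.shape)
D = gpx.Dataset(X=x, y=y)

# posterior = prior * likelihood, approximated by a collapsed variational family
prior = gpx.Prior(mean_function=gpx.Constant(), kernel=gpx.RBF())
likelihood = gpx.Gaussian(num_datapoints=D.n)
q = gpx.CollapsedVariationalGaussian(
    posterior=prior * likelihood,
    inducing_inputs=jr.uniform(key, (20, 1)),
)

# gpx.fit now takes the model, objective, and data explicitly and
# returns the optimised model together with the training history
opt_posterior, history = gpx.fit(
    model=q,
    objective=jax.jit(gpx.CollapsedELBO(negative=True)),
    train_data=D,
    optim=optax.adam(learning_rate=5e-3),
    num_iters=500,
    key=key,
)

# prediction: variational posterior -> latent distribution -> predictive
latent_dist = opt_posterior(x, train_data=D)
predictive_dist = opt_posterior.posterior.likelihood(latent_dist)
mean, scale_tril = predictive_dist.mean(), predictive_dist.scale_tril
```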
diff --git a/pyproject.toml b/pyproject.toml
index 769ab32..51dbab3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -15,16 +15,20 @@ classifiers=[
  "Intended Audience :: Science/Research",
  "License :: OSI Approved :: Apache Software License",
  "Programming Language :: Python :: 3",
-  "Programming Language :: Python :: 3.8",
  "Programming Language :: Python :: 3.9",
+  "Programming Language :: Python :: 3.10",
+  "Programming Language :: Python :: 3.11",
+  "Programming Language :: Python :: 3.12",
]
requires-python = ">=3.9"
dependencies = [
  "blackjax-nightly>=0.9.6.post127",
  "distrax>=0.1.2",
  "chex>=0.1.5",
-  "flax>=0.6.1",
-  "gpjax>=0.5.9",
+  "jaxlib>=0.4.18",
+  "jax>=0.4.18",
+  "flax>=0.7.3",
+  "gpjax>=0.6.9",
  "optax>=0.1.3",
  "pandas>=1.5.1"
]
@@ -106,3 +110,7 @@ invalid-name,missing-module-docstring,R0801,E0633

[tool.bandit]
skips = ["B101"]
+
+[tool.pydocstyle]
+convention = 'google'
+match = '^reconcile/*((?!test).)*\.py'
diff --git a/reconcile/__init__.py b/reconcile/__init__.py
index 7dcf7ee..e66570b 100644
--- a/reconcile/__init__.py
+++ b/reconcile/__init__.py
@@ -1,6 +1,4 @@
-"""
-reconcile: Probabilistic reconciliation of time series forecasts
-"""
+"""reconcile: Probabilistic reconciliation of time series forecasts."""

__version__ = "0.0.4"
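The `[tool.pydocstyle]` table above, together with the pre-commit hook added earlier, gates the `reconcile` sources (tests excluded by the `match` regex) on Google-convention docstrings; the docstring rewrites in the hunks below conform to it. A hypothetical function in the shape the hook expects:

```python
def scale(x, factor):
    """Scale an array by a constant factor.

    Args:
        x: the array to scale
        factor: the multiplicative constant

    Returns:
        the scaled array
    """
    return x * factor
```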
diff --git a/reconcile/forecast.py b/reconcile/forecast.py
index 905cbd0..dd17f3a 100644
--- a/reconcile/forecast.py
+++ b/reconcile/forecast.py
@@ -1,29 +1,28 @@
+"""Forecasting module."""
+
 import abc
 from typing import Tuple

 import distrax
-from chex import Array, PRNGKey
+from chex import PRNGKey
+from jax import Array


 class Forecaster(metaclass=abc.ABCMeta):
-    """
-    Forecast base class
+    """Forecast base class.

     Needs to be inherited for using a custom forecaster
     """

     def __init__(self):
-        pass
+        """Construct a forecaster."""

     @property
     @abc.abstractmethod
     def data(self) -> Tuple[Array, Array]:
-        """
-        Returns the data set used for training
+        """Returns the data set used for training.

-        Returns
-        -------
-        Tuple
+        Returns:
             returns a tuple consisting of two chex.Arrays where the first
             element are the time series (Y), and the second element are
             the features (X)
         """
@@ -31,48 +30,38 @@ def data(self) -> Tuple[Array, Array]:

     @abc.abstractmethod
     def fit(self, rng_key: PRNGKey, ys: Array, xs: Array) -> None:
-        """
-        Fit the forecaster to data
+        """Fit the forecaster to data.

         Fit a forecaster for each base and upper time series. Can be implemented
-        as global model or by fitting one model per time series
-
-        Parameters
-        ----------
-        rng_key: chex.PRNGKey
-            a key for random number generation
-        ys: chex.Array
-            a (1 x P x N)-dimensional array of time series measurements where
-            the second axis (P) corresponds to the different time series
-            and the last axis (N) are measurements at different time points
-        xs: chex.Array
-            a (1 x P x N)-dimensional array of time points where
-            the second axis (P) corresponds to the different time series
-            and the last axis (N) are the time points for which measurements
-            are taken
+        as global model or by fitting one model per time series.
+
+        Args:
+            rng_key: a key for random number generation
+            ys: a (1 x P x N)-dimensional array of time series measurements
+                where the second axis (P) corresponds to the different time
+                series and the last axis (N) are measurements at different time
+                points
+            xs: a (1 x P x N)-dimensional array of time points where
+                the second axis (P) corresponds to the different time series
+                and the last axis (N) are the time points for which measurements
+                are taken
         """

     @abc.abstractmethod
     def posterior_predictive(
         self, rng_key: PRNGKey, xs_test: Array
     ) -> distrax.Distribution:
-        """
-        Computes the posterior predictive distribution at some input points
-
-        Parameters
-        ----------
-        rng_key: chex.PRNGKey
-            a key for random number generation
-        xs_test: chex.Array
-            a (1 x P x M)-dimensional array of time points where
-            the second axis (P) corresponds to the different time series
-            and the last axis (M) are the time points for which measurements
-            are to be predicted. The second axis, P, needs to have as many
-            elements as the original training data
-
-        Returns
-        -------
-        distrax.Distribution
+        """Computes the posterior predictive distribution at some input points.
+
+        Args:
+            rng_key: a key for random number generation
+            xs_test: a (1 x P x M)-dimensional array of time points where
+                the second axis (P) corresponds to the different time series
+                and the last axis (M) are the time points for which measurements
+                are to be predicted. The second axis, P, needs to have as many
+                elements as the original training data
+
+        Returns:
             returns a distrax Distribution with batch shape (,P) and event
             shape (,M), such that a single sample has shape (P, M) and
             multiple samples have shape (S, P, M)
         """
@@ -82,29 +71,22 @@ def posterior_predictive(
     @abc.abstractmethod
     def predictive_posterior_probability(
         self, rng_key: PRNGKey, ys_test: Array, xs_test: Array
     ) -> Array:
-        """
-        Evaluates the probability of an observation
-
-        Parameters
-        ----------
-        rng_key: chex.PRNGKey
-            a key for random number generation
-        ys_test: chex.Array
-            a (1 x P x M)-dimensional array of time points where
-            the second axis (P) corresponds to the different time series
-            and the last axis (M) are the time points for which measurements
-            are to be predicted. The second axis, P, needs to have as many
-            elements as the original training data
-        xs_test: chex.Array
-            a (1 x P x M)-dimensional array of time points where
-            the second axis (P) corresponds to the different time series
-            and the last axis (M) are the time points for which measurements
-            are to be predicted. The second axis, P, needs to have as many
-            elements as the original training data
-
-        Returns
-        -------
-        chex.Array
+        """Evaluates the probability of an observation.
+
+        Args:
+            rng_key: a key for random number generation
+            ys_test: a (1 x P x M)-dimensional array of observations where
+                the second axis (P) corresponds to the different time series
+                and the last axis (M) are the time points for which measurements
+                are to be predicted. The second axis, P, needs to have as many
+                elements as the original training data
+            xs_test: a (1 x P x M)-dimensional array of time points where
+                the second axis (P) corresponds to the different time series
+                and the last axis (M) are the time points for which measurements
+                are to be predicted. The second axis, P, needs to have as many
+                elements as the original training data
+
+        Returns:
             returns a chex Array of size P with the log predictive probability
             of the data given a fit
         """
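The `Forecaster` interface above leaves four abstract members to implement: the `data` property, `fit`, `posterior_predictive`, and `predictive_posterior_probability`. A minimal sketch of a conforming subclass, a hypothetical forecaster that simply predicts each series' training mean (not part of the package; only to show the required surface):

```python
import distrax
from jax import numpy as jnp

from reconcile.forecast import Forecaster


class MeanForecaster(Forecaster):
    """Hypothetical forecaster that predicts each series' training mean."""

    def __init__(self):
        """Construct the forecaster."""
        super().__init__()
        self._ys, self._xs = None, None

    @property
    def data(self):
        """Returns the data set used for training."""
        return self._ys, self._xs

    def fit(self, rng_key, ys, xs):
        """Store the (1 x P x N)-dimensional training arrays."""
        self._ys, self._xs = ys, xs

    def posterior_predictive(self, rng_key, xs_test):
        """Independent normals centered on each series' training mean."""
        loc = jnp.mean(self._ys[0], axis=-1, keepdims=True)  # shape (P, 1)
        loc = jnp.tile(loc, (1, xs_test.shape[-1]))  # shape (P, M)
        return distrax.MultivariateNormalDiag(
            loc=loc, scale_diag=jnp.ones_like(loc)
        )

    def predictive_posterior_probability(self, rng_key, ys_test, xs_test):
        """Log-probability of the observations under the predictive."""
        return self.posterior_predictive(rng_key, xs_test).log_prob(ys_test[0])
```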
diff --git a/reconcile/grouping.py b/reconcile/grouping.py
index 5ef3c6f..7aa18a4 100644
--- a/reconcile/grouping.py
+++ b/reconcile/grouping.py
@@ -1,3 +1,6 @@
+"""Grouping module."""
+
+
 import warnings
 from itertools import chain

@@ -8,26 +11,20 @@

 # pylint: disable=missing-function-docstring
-def as_list(maybe_list):
+def _as_list(maybe_list):
     return maybe_list if isinstance(maybe_list, list) else [maybe_list]


 # pylint: disable=missing-function-docstring,too-many-locals,unnecessary-comprehension  # noqa: E501
 class Grouping:
-    """
-    Class that represents a grouping/hierarchy of a grouped or hierarchical
-    time series
-    """
+    """Class that represents a grouping/hierarchy of a time series."""

     def __init__(self, groups: pd.DataFrame):
-        """
-        Initialize a grouping
+        """Initialize a grouping.

-        Parameters
-        ----------
-        groups: pd.DataFrame
+        Args:
+            groups: pd.DataFrame
         """
-
         self._p = groups.shape[0]
         self._groups = groups
         self._group_names = list(groups.columns)
@@ -40,43 +37,53 @@ def __init__(self, groups: pd.DataFrame):
         else:
             out_edges_per_level, labels, _ = self._hts_create_nodes()
             gmat = self._hts_create_g_mat(out_edges_per_level)
-            labels = [as_list(labels[key]) for key in sorted(labels.keys())]
+            labels = [_as_list(labels[key]) for key in sorted(labels.keys())]
             self._labels = list(chain(*labels))

         self._s_matrix = self._smatrix(gmat)
         self._n_all_timeseries = self._s_matrix.shape[0]

     def all_timeseries_column_names(self):
+        """Getter for column names of all time series."""
         return self._labels

     def bottom_timeseries_column_names(self):
+        """Getter for column names of bottom time series."""
         return self._labels[self.n_upper_timeseries :]

     @property
     def n_groups(self):
+        """Getter for number of groups."""
         return self._groups.shape[1]

     @property
     def n_all_timeseries(self):
+        """Getter for number of all time series."""
         return self._n_all_timeseries

     @property
     def n_bottom_timeseries(self):
+        """Getter for number of bottom time series."""
         return self._p

     @property
     def n_upper_timeseries(self):
+        """Getter for number of upper time series."""
         return self.n_all_timeseries - self.n_bottom_timeseries

     def all_timeseries(self, b: jnp.ndarray):
+        """Getter for all time series."""
         return jnp.einsum("...ijk,jl->...ilk", b, self._s_matrix.T.toarray())

     def summing_matrix(self):
+        """Getter for the summing matrix."""
         return self._s_matrix

     def extract_bottom_timeseries(self, y):
+        """Getter for the bottom time series."""
         return y[:, self.n_upper_timeseries :, :]

     def upper_time_series(self, b):
+        """Getter for upper time series."""
         y = self.all_timeseries(b)
         return y[:, : self.n_upper_timeseries, :]

@@ -85,11 +92,11 @@ def upper_time_series(self, b):
     @staticmethod
     def _paste0(a, b):
         return np.array([":".join([e, k]) for e, k in zip(a, b)])

     def _gts_create_g_mat(self):
-        """
-        Compute the 'G Matrix'. This is a direct transpilation of the method
+        """Compute the G Matrix.
+
+        This is a direct transpilation of the method
         'CreateGmat' of the R package 'hts' (version 6.0.2)
         """
-
         total_len = len(self._group_names)
         sub_len = [0]
         for group_name in self._group_names:
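The `all_timeseries` einsum above lifts a (1 x P x N) array of bottom series to all (upper plus bottom) series via the summing matrix S. A self-contained sketch of that contraction with a hand-built S for a hypothetical hierarchy of two bottom series and one total:

```python
from jax import numpy as jnp

# hypothetical summing matrix: one aggregate on top of two bottom series,
# upper series first, bottom series after (matching extract_bottom_timeseries)
S = jnp.array(
    [
        [1.0, 1.0],  # total = b1 + b2
        [1.0, 0.0],  # bottom series 1
        [0.0, 1.0],  # bottom series 2
    ]
)

b = jnp.ones((1, 2, 5))  # (batch x n_bottom x n_timepoints)

# the same contraction as Grouping.all_timeseries: y[l, k] = sum_j S[l, j] b[j, k]
y = jnp.einsum("...ijk,jl->...ilk", b, S.T)
print(y.shape)  # (1, 3, 5)
print(y[0, 0])  # the aggregated series: all entries equal 2
```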
diff --git a/reconcile/probabilistic_reconciliation.py b/reconcile/probabilistic_reconciliation.py
index c97149b..a122b17 100644
--- a/reconcile/probabilistic_reconciliation.py
+++ b/reconcile/probabilistic_reconciliation.py
@@ -1,3 +1,5 @@
+"""Probabilistic reconciliation module."""
+
 import logging
 from typing import Callable

@@ -9,7 +11,7 @@
 from flax.training.early_stopping import EarlyStopping
 from flax.training.train_state import TrainState
 from jax import numpy as jnp
-from jax import random
+from jax import random as jr

 from reconcile.forecast import Forecaster
 from reconcile.grouping import Grouping
@@ -19,11 +21,10 @@

 # pylint: disable=too-many-arguments,too-many-locals,arguments-differ
 class ProbabilisticReconciliation:
-    """
-    Probabilistic reconcilation of hierarchical time series class
-    """
+    """Probabilistic reconciliation of hierarchical time series class."""

     def __init__(self, grouping: Grouping, forecaster: Forecaster):
+        """Construct a ProbabilisticReconciliation object."""
         self._forecaster = forecaster
         self._grouping = grouping

@@ -35,41 +36,31 @@ def sample_reconciled_posterior_predictive(
         n_iter=2000,
         n_warmup=1000,
     ):
-        """
-        Probabilistic reconciliation using Markov Chain Monte Carlo
+        """Probabilistic reconciliation using Markov Chain Monte Carlo.

         Compute the reconciled bottom time series forecast by sampling from
         the joint density of bottom and upper predictive densities. The
         implementation and method loosely follow [1]_ but is not the same
         method (!).

-        Parameters
-        ----------
-        rng_key: chex.PRNGKey
-            a key for random number generation
-        xs_test: chex.Array
-            a (1 x P x N)-dimensional array of time points where
-            the second axis (P) corresponds to the different time series
-            and the last axis (N) are the time points for which predictions
-            are made. The second axis, P, needs to have as many elements as the
-            original training data
-        n_chains: int
-            number of chains to sample from
-        n_iter: int
-            number of samples to take per chain
-        n_warmup: int
-            number of samples to discard as burn-in from the chain
-
-        Returns
-        -------
-        chex.Array
+        Args:
+            rng_key: a key for random number generation
+            xs_test: a (1 x P x N)-dimensional array of time points where
+                the second axis (P) corresponds to the different time series
+                and the last axis (N) are the time points for which predictions
+                are made. The second axis, P, needs to have as many elements as
+                the original training data
+            n_chains: number of chains to sample from
+            n_iter: number of samples to take per chain
+            n_warmup: number of samples to discard as burn-in from the chain
+
+        Returns:
             returns a posterior sample of shape (n_iter x n_chains x P x N)
             representing the reconciled bottom time series forecast

-        References
-        ----------
-        .. [1] Zambon, Lorenzo, et al. "Probabilistic reconciliation of
-           forecasts via importance sampling." arXiv:2210.02286 (2022).
+        References:
+            .. [1] Zambon, Lorenzo, et al. "Probabilistic reconciliation of
+                forecasts via importance sampling." arXiv:2210.02286 (2022).
         """

         def _logprob_fn(b):
@@ -86,7 +77,7 @@ def lp(x):
             return _logprob_fn(**x)

-        curr_key, rng_key = random.split(rng_key, 2)
+        curr_key, rng_key = jr.split(rng_key, 2)
         initial_positions = self._forecaster.posterior_predictive(
             curr_key,
             xs_test,
@@ -95,7 +86,7 @@
             "b": self._grouping.extract_bottom_timeseries(initial_positions)
         }

-        init_keys = random.split(rng_key, n_chains)
+        init_keys = jr.split(rng_key, n_chains)
         warmup = blackjax.window_adaptation(blackjax.nuts, lp)
         initial_states, kernel_params = jax.vmap(
             lambda seed, param: warmup.run(seed, param)[0]
@@ -107,11 +98,11 @@
         def _inference_loop(rng_key, kernel, initial_state, num_samples):
             @jax.jit
             def _step(states, rng_key):
-                keys = jax.random.split(rng_key, n_chains)
+                keys = jr.split(rng_key, n_chains)
                 states, infos = jax.vmap(kernel)(keys, states)
                 return states, (states, infos)

-            curr_keys = jax.random.split(rng_key, num_samples)
+            curr_keys = jr.split(rng_key, num_samples)
             _, (states, _) = jax.lax.scan(_step, initial_state, curr_keys)
             return states
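The `_inference_loop` helper above is the usual JAX sampling pattern: split one PRNG key per iteration and `lax.scan` the transition kernel over the key sequence. Stripped of the blackjax specifics, a self-contained sketch with a stand-in random-walk kernel (hypothetical, purely to show the control flow):

```python
import jax
from jax import numpy as jnp
from jax import random as jr


def dummy_kernel(key, state):
    # hypothetical stand-in for a NUTS transition: a Gaussian random walk
    return state + 0.1 * jr.normal(key, state.shape)


def inference_loop(rng_key, kernel, initial_state, num_samples):
    def step(state, key):
        state = kernel(key, state)
        return state, state  # carry the state forward, record a sample

    keys = jr.split(rng_key, num_samples)
    _, samples = jax.lax.scan(step, initial_state, keys)
    return samples


samples = inference_loop(jr.PRNGKey(0), dummy_kernel, jnp.zeros((4,)), 100)
print(samples.shape)  # (100, 4)
```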
@@ -129,42 +120,33 @@ def fit_reconciled_posterior_predictive(
         net: Callable = None,
         n_iter: int = None,
     ):
-        """
-        Probabilistic reconciliation using energy score optimization
+        """Probabilistic reconciliation using energy score optimization.

         Compute the reconciled bottom time series forecast by optimization of
         an energy score. The implementation and method loosely follow [1]_
         but is not the exactly same method.

-        Parameters
-        ----------
-        rng_key: chex.PRNGKey
-            a key for random number generation
-        xs_test: chex.Array
-            a (1 x P x N)-dimensional array of time points where
-            the second axis (P) corresponds to the different time series
-            and the last axis (N) are the time points for which predictions
-            are made. The second axis, P, needs to have as many elements as the
-            original training data
-        n_samples: int
-            number of samples to return
-        net: Callable
-            a flax neural network that is used for the projection or None to use
-            the linear projection from [1]
-        n_iter: int
-            number of iterations to train the network or None for early stopping
-
-        Returns
-        -------
-        chex.Array
+        Args:
+            rng_key: a key for random number generation
+            xs_test: a (1 x P x N)-dimensional array of time points where
+                the second axis (P) corresponds to the different time series
+                and the last axis (N) are the time points for which predictions
+                are made. The second axis, P, needs to have as many elements as
+                the original training data
+            n_samples: number of samples to return
+            net: a flax neural network that is used for the projection or None
+                to use the linear projection from [1]
+            n_iter: number of iterations to train the network or None for
+                early stopping
+
+        Returns:
             returns a posterior sample of shape (n_samples x P x N)
             representing the reconciled bottom time series forecast

-        References
-        ----------
-        .. [1] Panagiotelis, Anastasios, et al. "Probabilistic forecast
-           reconciliation: Properties, evaluation and score optimisation."
-           European Journal of Operational Research (2022).
+        References:
+            .. [1] Panagiotelis, Anastasios, et al. "Probabilistic forecast
+                reconciliation: Properties, evaluation and score
+                optimisation." European Journal of Operational Research (2022).
         """

         def _projection(output_dim):
@@ -228,7 +210,7 @@ def loss_fn(params):
             early_stop = EarlyStopping(min_delta=0.1, patience=5)
             itr = 0
             while True:
-                sample_key, rng_key = random.split(rng_key)
+                sample_key, rng_key = jr.split(rng_key)
                 y_predictive_batch = predictive.sample(
                     seed=sample_key,
                     sample_shape=(batch_size, 2),