Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Minor updates #12

Merged
merged 3 commits into from
Oct 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 14 additions & 18 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ jobs:
run: |
hatch build

tests:
lints:
runs-on: ubuntu-latest
needs:
- precommit
Expand All @@ -51,22 +51,11 @@ jobs:
- name: Install dependencies
run: |
pip install hatch
- name: Build package
run: |
pip install jaxlib jax
- name: Run pytest
run: |
hatch run test:test
- name: Copy file
- name: Run lints
run: |
cp coverage.xml cobertura.xml
- name: Run codacy-coverage-reporter
uses: codacy/codacy-coverage-reporter-action@v1
with:
project-token: ${{ secrets.CODACY_PROJECT_TOKEN }}
coverage-reports: cobertura.xml
language: Python
lints:
hatch run test:lint

tests:
runs-on: ubuntu-latest
needs:
- precommit
Expand All @@ -82,9 +71,16 @@ jobs:
- name: Install dependencies
run: |
pip install hatch
- name: Run lints
- name: Build package
run: |
hatch run test:lint
pip install jaxlib jax
- name: Run tests
run: |
hatch run test:test
- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@v3
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}

examples:
runs-on: ubuntu-latest
Expand Down
14 changes: 11 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ repos:
language: python
language_version: python3
types: [python]
args: ["-c", "pyproject.toml"]
additional_dependencies: ["toml"]
files: "(reconcile|examples)"

- repo: https://github.com/PyCQA/flake8
Expand All @@ -59,7 +61,13 @@ repos:
files: "(reconcile|examples)"

- repo: https://github.com/jorisroovers/gitlint
rev: v0.18.0
rev: v0.19.1
hooks:
- id: gitlint
- id: gitlint-ci
- id: gitlint
- id: gitlint-ci

- repo: https://github.com/pycqa/pydocstyle
rev: 6.1.1
hooks:
- id: pydocstyle
additional_dependencies: ["toml"]
6 changes: 2 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@

[![status](http://www.repostatus.org/badges/latest/concept.svg)](http://www.repostatus.org/#concept)
[![ci](https://github.com/dirmeier/reconcile/actions/workflows/ci.yaml/badge.svg)](https://github.com/dirmeier/reconcile/actions/workflows/ci.yaml)
[![codacy badge](https://app.codacy.com/project/badge/Grade/f0a254348e894c7c85b4e979bc81f1d9)](https://www.codacy.com/gh/dirmeier/reconcile/dashboard?utm_source=github.com&utm_medium=referral&utm_content=dirmeier/reconcile&utm_campaign=Badge_Grade)
[![codacy badge](https://app.codacy.com/project/badge/Coverage/f0a254348e894c7c85b4e979bc81f1d9)](https://www.codacy.com/gh/dirmeier/reconcile/dashboard?utm_source=github.com&utm_medium=referral&utm_content=dirmeier/reconcile&utm_campaign=Badge_Coverage)
[![version](https://img.shields.io/pypi/v/probabilistic-reconciliation.svg?colorB=black&style=flat)](https://pypi.org/project/probabilistic-reconciliation/)

> Probabilistic reconciliation of time series forecasts
Expand All @@ -15,8 +13,8 @@ Reconcile implements probabilistic time series forecast reconciliation methods i
1) Zambon, Lorenzo, Dario Azzimonti, and Giorgio Corani. ["Probabilistic reconciliation of forecasts via importance sampling."](https://doi.org/10.48550/arXiv.2210.02286) arXiv preprint arXiv:2210.02286 (2022).
2) Panagiotelis, Anastasios, et al. ["Probabilistic forecast reconciliation: Properties, evaluation and score optimisation."](https://doi.org/10.1016/j.ejor.2022.07.040) European Journal of Operational Research (2022).

The package implements methods to compute summing/aggregation matrices for grouped and hierarchical time series and reconciliation methods for probabilistic forecasts based on sampling and optimization,
and in the near future also some recent forecasting methods, such as proposed in [Benavoli, *et al.* (2021)](https://doi.org/10.1007/978-3-030-91445-5_2) or [Corani *et al.*, (2020)](https://arxiv.org/abs/2009.08102) via [GPJax](https://github.com/JaxGaussianProcesses/GPJax).
The package implements methods to compute summing/aggregation matrices for grouped and hierarchical time series and reconciliation methods for probabilistic forecasts based on sampling and optimization,
and in the near future also some recent forecasting methods, such as proposed in [Benavoli, *et al.* (2021)](https://doi.org/10.1007/978-3-030-91445-5_2) or [Corani *et al.*, (2020)](https://arxiv.org/abs/2009.08102) via [GPJax](https://github.com/JaxGaussianProcesses/GPJax).

## Examples

Expand Down
49 changes: 24 additions & 25 deletions examples/reconciliation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import pandas as pd
from chex import Array, PRNGKey
from jax import numpy as jnp
from jax import random
from jax import random as jr
from jax.config import config
from statsmodels.tsa.arima_process import arma_generate_sample

Expand All @@ -24,6 +24,7 @@ class GPForecaster(Forecaster):
"""Example implementation of a forecaster"""

def __init__(self):
super().__init__()
self._models: List = []
self._xs: Array = None
self._ys: Array = None
Expand All @@ -46,57 +47,55 @@ def fit(self, rng_key: PRNGKey, ys: Array, xs: Array, niter=2000):
for i in np.arange(p):
x, y = xs[:, [i], :], ys[:, [i], :]
# fit a model for each time series
learned_params, _, D = self._fit_one(rng_key, x, y, niter)
opt_posterior, _, D = self._fit_one(rng_key, x, y, niter)
# save the learned parameters and the original data
self._models[i] = learned_params, D
self._models[i] = opt_posterior, D

def _fit_one(self, rng_key, x, y, niter):
# here we use GPs to model the time series
D = gpx.Dataset(X=x.reshape(-1, 1), y=y.reshape(-1, 1))
sgpr, q, likelihood = self._model(rng_key, D.n)
elbo, q, likelihood = self._model(rng_key, D.n)

parameter_state = gpx.initialise(sgpr, rng_key)
negative_elbo = jax.jit(sgpr.elbo(D, negative=True))
negative_elbo = jax.jit(elbo)
optimiser = optax.adam(learning_rate=5e-3)
inference_state = gpx.fit(
opt_posterior, history = gpx.fit(
model=q,
objective=negative_elbo,
parameter_state=parameter_state,
optax_optim=optimiser,
train_data=D,
optim=optimiser,
num_iters=niter,
key=rng_key,
)
learned_params, training_history = inference_state.unpack()
return learned_params, training_history, D
return opt_posterior, history, D

@staticmethod
def _model(rng_key, n):
z = random.uniform(rng_key, (20, 1))
z = jr.uniform(rng_key, (20, 1))
prior = gpx.Prior(mean_function=gpx.Constant(), kernel=gpx.RBF())
likelihood = gpx.Gaussian(num_datapoints=n)
posterior = prior * likelihood
q = gpx.CollapsedVariationalGaussian(
prior=prior,
likelihood=likelihood,
posterior=posterior,
inducing_inputs=z,
)
sgpr = gpx.CollapsedVI(posterior=posterior, variational_family=q)
return sgpr, q, likelihood
elbo = gpx.CollapsedELBO(negative=True)
return elbo, q, likelihood

def posterior_predictive(self, rng_key, xs_test: Array):
"""Compute the joint
posterior predictive distribution of all timeseries at xs_test"""
"""Compute the joint posterior predictive distribution at xs_test."""
chex.assert_rank(xs_test, 3)

q = xs_test.shape[1]
means = [None] * q
covs = [None] * q
for i in np.arange(q):
x_test = xs_test[:, [i], :].reshape(-1, 1)
learned_params, D = self._models[i]
opt_posterior, D = self._models[i]
_, q, likelihood = self._model(rng_key, D.n)
latent_dist = q(learned_params, D)(x_test)
predictive_dist = likelihood(learned_params, latent_dist)
latent_dist = opt_posterior(x_test, train_data=D)
predictive_dist = opt_posterior.posterior.likelihood(latent_dist)
means[i] = predictive_dist.mean()
cov = jnp.linalg.cholesky(predictive_dist.covariance_matrix)
cov = predictive_dist.scale_tril
covs[i] = cov.reshape((1, *cov.shape))

# here we stack the means and covariance functions of all
Expand Down Expand Up @@ -163,19 +162,19 @@ def run():

forecaster = GPForecaster()
forecaster.fit(
random.PRNGKey(1),
jr.PRNGKey(1),
all_timeseries[:, :, :90],
all_features[:, :, :90],
)

recon = ProbabilisticReconciliation(grouping, forecaster)
# do reconciliation via sampling
_ = recon.sample_reconciled_posterior_predictive(
random.PRNGKey(1), all_features, n_iter=100, n_warmup=50
jr.PRNGKey(1), all_features, n_iter=100, n_warmup=50
)
# do reconciliation via optimization of the energy score
_ = recon.fit_reconciled_posterior_predictive(
random.PRNGKey(1), all_features, n_samples=100
jr.PRNGKey(1), all_features, n_samples=100
)


Expand Down
14 changes: 11 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,20 @@ classifiers=[
"Intended Audience :: Science/Research",
"License :: OSI Approved :: Apache Software License",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
]
requires-python = ">=3.9"
dependencies = [
"blackjax-nightly>=0.9.6.post127",
"distrax>=0.1.2",
"chex>=0.1.5",
"flax>=0.6.1",
"gpjax>=0.5.9",
"jaxlib>=0.4.18",
"jax>=0.4.18",
"flax>=0.7.3",
"gpjax>=0.6.9",
"optax>=0.1.3",
"pandas>=1.5.1"
]
Expand Down Expand Up @@ -106,3 +110,7 @@ invalid-name,missing-module-docstring,R0801,E0633

[tool.bandit]
skips = ["B101"]

[tool.pydocstyle]
convention= 'google'
match = '^reconcile/*((?!test).)*\.py'
4 changes: 1 addition & 3 deletions reconcile/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
"""
reconcile: Probabilistic reconciliation of time series forecasts
"""
"""reconcile: Probabilistic reconciliation of time series forecasts."""

__version__ = "0.0.4"

Expand Down
Loading