diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index ce4e4c7..fe48ad8 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -21,7 +21,7 @@ jobs: - precommit strategy: matrix: - python-version: [3.9] + python-version: [3.11] steps: - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} @@ -41,7 +41,7 @@ jobs: - precommit strategy: matrix: - python-version: [3.9] + python-version: [3.11] steps: - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} @@ -61,7 +61,7 @@ jobs: - precommit strategy: matrix: - python-version: [3.9] + python-version: [3.11] steps: - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} @@ -88,7 +88,7 @@ jobs: - precommit strategy: matrix: - python-version: [3.9] + python-version: [3.11] steps: - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e9f3c50..6d120ec 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -13,46 +13,6 @@ repos: - id: requirements-txt-fixer - id: trailing-whitespace -- repo: https://github.com/asottile/pyupgrade - rev: v2.29.1 - hooks: - - id: pyupgrade - args: [--py38-plus] - -- repo: https://github.com/psf/black - rev: 22.3.0 - hooks: - - id: black - args: ["--config=pyproject.toml"] - files: "(reconcile|examples)" - -- repo: https://github.com/pycqa/isort - rev: 5.12.0 - hooks: - - id: isort - args: ["--settings-path=pyproject.toml"] - files: "(reconcile|examples)" - -- repo: https://github.com/pycqa/bandit - rev: 1.7.1 - hooks: - - id: bandit - language: python - language_version: python3 - types: [python] - args: ["-c", "pyproject.toml"] - additional_dependencies: ["toml"] - files: "(reconcile|examples)" - -- repo: https://github.com/PyCQA/flake8 - rev: 5.0.1 - hooks: - - id: flake8 - additional_dependencies: [ - flake8-typing-imports==1.14.0, - flake8-pyproject==1.1.0.post0 - ] - - repo: https://github.com/pre-commit/mirrors-mypy rev: v0.910-1 hooks: @@ -60,14 +20,9 @@ repos: args: ["--ignore-missing-imports"] files: "(reconcile|examples)" -- repo: https://github.com/jorisroovers/gitlint - rev: v0.19.1 - hooks: - - id: gitlint - - id: gitlint-ci - -- repo: https://github.com/pycqa/pydocstyle - rev: 6.1.1 +- repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.3.0 hooks: - - id: pydocstyle - additional_dependencies: ["toml"] + - id: ruff + args: [ --fix ] + - id: ruff-format diff --git a/Makefile b/Makefile deleted file mode 100644 index f8802c2..0000000 --- a/Makefile +++ /dev/null @@ -1,5 +0,0 @@ -PKG_VERSION=`hatch version` - -tag: - git tag -a v${PKG_VERSION} -m v${PKG_VERSION} - git push --tag diff --git a/examples/reconciliation.py b/examples/reconciliation.py index 79fc5d6..e1c2ab1 100644 --- a/examples/reconciliation.py +++ b/examples/reconciliation.py @@ -1,5 +1,3 @@ -from typing import List - import chex import distrax import gpjax as gpx @@ -7,17 +5,15 @@ import numpy as np import optax import pandas as pd -from chex import Array, PRNGKey from jax import numpy as jnp from jax import random as jr -from jax.config import config from statsmodels.tsa.arima_process import arma_generate_sample from reconcile.forecast import Forecaster from reconcile.grouping import Grouping from reconcile.probabilistic_reconciliation import ProbabilisticReconciliation -config.update("jax_enable_x64", True) +jax.config.update("jax_enable_x64", True) class GPForecaster(Forecaster): @@ -25,16 +21,18 @@ class GPForecaster(Forecaster): def __init__(self): super().__init__() - self._models: List = [] - self._xs: Array = None - self._ys: Array = None + self._models: list = [] + self._xs: jax.Array = None + self._ys: jax.Array = None @property def data(self): """Returns the data""" return self._ys, self._xs - def fit(self, rng_key: PRNGKey, ys: Array, xs: Array, niter=2000): + def fit( + self, rng_key: jr.PRNGKey, ys: jax.Array, xs: jax.Array, niter=2000 + ): """Fit a model to each of the time series""" self._xs = xs @@ -71,17 +69,20 @@ def _fit_one(self, rng_key, x, y, niter): @staticmethod def _model(rng_key, n): z = jr.uniform(rng_key, (20, 1)) - prior = gpx.Prior(mean_function=gpx.Constant(), kernel=gpx.RBF()) - likelihood = gpx.Gaussian(num_datapoints=n) + prior = gpx.gps.Prior( + mean_function=gpx.mean_functions.Constant(), + kernel=gpx.kernels.RBF(), + ) + likelihood = gpx.likelihoods.Gaussian(num_datapoints=n) posterior = prior * likelihood - q = gpx.CollapsedVariationalGaussian( + q = gpx.variational_families.CollapsedVariationalGaussian( posterior=posterior, inducing_inputs=z, ) - elbo = gpx.CollapsedELBO(negative=True) + elbo = gpx.objectives.CollapsedELBO(negative=True) return elbo, q, likelihood - def posterior_predictive(self, rng_key, xs_test: Array): + def posterior_predictive(self, rng_key, xs_test: jax.Array): """Compute the joint posterior predictive distribution at xs_test.""" chex.assert_rank(xs_test, 3) @@ -109,7 +110,7 @@ def posterior_predictive(self, rng_key, xs_test: Array): return posterior_predictive def predictive_posterior_probability( - self, rng_key: PRNGKey, ys_test: Array, xs_test: Array + self, rng_key: jr.PRNGKey, ys_test: jax.Array, xs_test: jax.Array ): """Compute the log predictive posterior probability of an observation""" chex.assert_rank([ys_test, xs_test], [3, 3]) diff --git a/pyproject.toml b/pyproject.toml index 51dbab3..92a277c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ license = "Apache-2.0" homepage = "https://github.com/dirmeier/reconcile" keywords = ["probabilistic reconciliation", "forecasting", "timeseries", "hierarchical time series"] classifiers=[ - "Development Status :: 1 - Planning", + "Development Status :: 3 - Alpha", "Intended Audience :: Science/Research", "License :: OSI Approved :: Apache Software License", "Programming Language :: Python :: 3", @@ -33,11 +33,13 @@ dependencies = [ "pandas>=1.5.1" ] dynamic = ["version"] -packages = [{include = "reconcile"}] [project.urls] homepage = "https://github.com/dirmeier/reconcile" +[tool.hatch.build.targets.wheel] +packages = ["reconcile"] + [tool.hatch.version] path = "reconcile/__init__.py" @@ -50,7 +52,7 @@ exclude = [ [tool.hatch.envs.test] dependencies = [ - "pylint>=2.15.10", + "ruff>=0.3.0", "pytest>=7.2.0", "pytest-cov>=4.0.0", "gpjax>=0.5.0", @@ -58,7 +60,7 @@ dependencies = [ ] [tool.hatch.envs.test.scripts] -lint = 'pylint reconcile' +lint = 'ruff check reconcile examples' test = 'pytest -v --doctest-modules --cov=./reconcile --cov-report=xml reconcile' [tool.hatch.envs.examples] @@ -70,47 +72,17 @@ dependencies = [ [tool.hatch.envs.examples.scripts] reconciliation = 'python examples/reconciliation.py' -[tool.black] -line-length = 80 -extend-ignore = "E203" -target-version = ['py39'] -exclude = ''' -/( - \.eggs - | \.git - | \.hg - | \.mypy_cache - | \.tox - | \.venv - | _build - | buck-out - | build - | dist -)/ -''' - - -[tool.isort] -profile = "black" -line_length = 80 -include_trailing_comma = true - - -[tool.flake8] -max-line-length = 80 -extend-ignore = ["E203", "W503", "E731"] -per-file-ignores = [ - '__init__.py:F401', -] - -[tool.pylint.messages_control] -disable = """ -invalid-name,missing-module-docstring,R0801,E0633 -""" - [tool.bandit] skips = ["B101"] -[tool.pydocstyle] +[tool.ruff] +fix = true +line-length = 80 + +[tool.ruff.lint] +select = ["E4", "E7", "E9", "F"] +extend-select = ["UP", "I", "PL", "S"] +ignore =["S101", "PLR2004", "PLR0913", "E2"] + +[tool.ruff.lint.pydocstyle] convention= 'google' -match = '^reconcile/*((?!test).)*\.py' diff --git a/reconcile/forecast.py b/reconcile/forecast.py index dd17f3a..879f719 100644 --- a/reconcile/forecast.py +++ b/reconcile/forecast.py @@ -1,11 +1,10 @@ """Forecasting module.""" import abc -from typing import Tuple import distrax -from chex import PRNGKey from jax import Array +from jax import random as jr class Forecaster(metaclass=abc.ABCMeta): @@ -19,7 +18,7 @@ def __init__(self): @property @abc.abstractmethod - def data(self) -> Tuple[Array, Array]: + def data(self) -> tuple[Array, Array]: """Returns the data set used for training. Returns: @@ -29,7 +28,7 @@ def data(self) -> Tuple[Array, Array]: """ @abc.abstractmethod - def fit(self, rng_key: PRNGKey, ys: Array, xs: Array) -> None: + def fit(self, rng_key: jr.PRNGKey, ys: Array, xs: Array) -> None: """Fit the forecaster to data. Fit a forecaster for each base and upper time series. Can be implemented @@ -49,7 +48,7 @@ def fit(self, rng_key: PRNGKey, ys: Array, xs: Array) -> None: @abc.abstractmethod def posterior_predictive( - self, rng_key: PRNGKey, xs_test: Array + self, rng_key: jr.PRNGKey, xs_test: Array ) -> distrax.Distribution: """Computes the posterior predictive distribution at some input points. @@ -69,7 +68,7 @@ def posterior_predictive( @abc.abstractmethod def predictive_posterior_probability( - self, rng_key: PRNGKey, ys_test: Array, xs_test: Array + self, rng_key: jr.PRNGKey, ys_test: Array, xs_test: Array ) -> Array: """Evaluates the probability of an observation. diff --git a/reconcile/grouping.py b/reconcile/grouping.py index 7aa18a4..f88cd0b 100644 --- a/reconcile/grouping.py +++ b/reconcile/grouping.py @@ -1,6 +1,5 @@ """Grouping module.""" - import warnings from itertools import chain @@ -115,18 +114,12 @@ def _gts_create_g_mat(self): token[i] = [] for i in range(total_len): - token[i].append( - temp_tokens[ - cs[i], - ] - ) + token[i].append(temp_tokens[cs[i],]) if sub_len[i + 1] >= 2: for j in range(1, sub_len[i + 1]): col = self._paste0( token[i][j - 1], - temp_tokens[ - cs[i] + j, - ], + temp_tokens[cs[i] + j,], ) token[i].append(col) token[i] = np.vstack(token[i]) diff --git a/reconcile/probabilistic_reconciliation.py b/reconcile/probabilistic_reconciliation.py index a122b17..6b7823c 100644 --- a/reconcile/probabilistic_reconciliation.py +++ b/reconcile/probabilistic_reconciliation.py @@ -6,7 +6,6 @@ import blackjax import jax import optax -from chex import Array, PRNGKey from flax import linen as nn from flax.training.early_stopping import EarlyStopping from flax.training.train_state import TrainState @@ -30,8 +29,8 @@ def __init__(self, grouping: Grouping, forecaster: Forecaster): def sample_reconciled_posterior_predictive( self, - rng_key: PRNGKey, - xs_test: Array, + rng_key: jr.PRNGKey, + xs_test: jax.Array, n_chains=4, n_iter=2000, n_warmup=1000, @@ -114,8 +113,8 @@ def _step(states, rng_key): def fit_reconciled_posterior_predictive( self, - rng_key: PRNGKey, - xs_test: Array, + rng_key: jr.PRNGKey, + xs_test: jax.Array, n_samples=2000, net: Callable = None, n_iter: int = None, @@ -152,7 +151,7 @@ def fit_reconciled_posterior_predictive( def _projection(output_dim): class _network(nn.Module): @nn.compact - def __call__(self, x: Array): + def __call__(self, x: jax.Array): x = x.swapaxes(-2, -1) x = nn.Sequential( [ @@ -168,7 +167,9 @@ def __call__(self, x: Array): return _network() if net is None else net() - def _loss(y: Array, y_reconciled_0: Array, y_reconciled_1: Array): + def _loss( + y: jax.Array, y_reconciled_0: jax.Array, y_reconciled_1: jax.Array + ): y = y.reshape((1, *y.shape)) y = jnp.tile(y, [y_reconciled_0.shape[0], 1, 1, 1]) lhs = jnp.linalg.norm(y_reconciled_0 - y, axis=2, keepdims=True)