Skip to content

Commit

Permalink
Thin out some unnecessary files (#14)
Browse files Browse the repository at this point in the history
  • Loading branch information
dirmeier authored Jul 20, 2024
1 parent 9d12628 commit 5f84ffb
Show file tree
Hide file tree
Showing 8 changed files with 56 additions and 140 deletions.
8 changes: 4 additions & 4 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ jobs:
- precommit
strategy:
matrix:
python-version: [3.9]
python-version: [3.11]
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
Expand All @@ -41,7 +41,7 @@ jobs:
- precommit
strategy:
matrix:
python-version: [3.9]
python-version: [3.11]
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
Expand All @@ -61,7 +61,7 @@ jobs:
- precommit
strategy:
matrix:
python-version: [3.9]
python-version: [3.11]
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
Expand All @@ -88,7 +88,7 @@ jobs:
- precommit
strategy:
matrix:
python-version: [3.9]
python-version: [3.11]
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
Expand Down
55 changes: 5 additions & 50 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,61 +13,16 @@ repos:
- id: requirements-txt-fixer
- id: trailing-whitespace

- repo: https://github.com/asottile/pyupgrade
rev: v2.29.1
hooks:
- id: pyupgrade
args: [--py38-plus]

- repo: https://github.com/psf/black
rev: 22.3.0
hooks:
- id: black
args: ["--config=pyproject.toml"]
files: "(reconcile|examples)"

- repo: https://github.com/pycqa/isort
rev: 5.12.0
hooks:
- id: isort
args: ["--settings-path=pyproject.toml"]
files: "(reconcile|examples)"

- repo: https://github.com/pycqa/bandit
rev: 1.7.1
hooks:
- id: bandit
language: python
language_version: python3
types: [python]
args: ["-c", "pyproject.toml"]
additional_dependencies: ["toml"]
files: "(reconcile|examples)"

- repo: https://github.com/PyCQA/flake8
rev: 5.0.1
hooks:
- id: flake8
additional_dependencies: [
flake8-typing-imports==1.14.0,
flake8-pyproject==1.1.0.post0
]

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.910-1
hooks:
- id: mypy
args: ["--ignore-missing-imports"]
files: "(reconcile|examples)"

- repo: https://github.com/jorisroovers/gitlint
rev: v0.19.1
hooks:
- id: gitlint
- id: gitlint-ci

- repo: https://github.com/pycqa/pydocstyle
rev: 6.1.1
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.3.0
hooks:
- id: pydocstyle
additional_dependencies: ["toml"]
- id: ruff
args: [ --fix ]
- id: ruff-format
5 changes: 0 additions & 5 deletions Makefile

This file was deleted.

31 changes: 16 additions & 15 deletions examples/reconciliation.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,38 @@
from typing import List

import chex
import distrax
import gpjax as gpx
import jax
import numpy as np
import optax
import pandas as pd
from chex import Array, PRNGKey
from jax import numpy as jnp
from jax import random as jr
from jax.config import config
from statsmodels.tsa.arima_process import arma_generate_sample

from reconcile.forecast import Forecaster
from reconcile.grouping import Grouping
from reconcile.probabilistic_reconciliation import ProbabilisticReconciliation

config.update("jax_enable_x64", True)
jax.config.update("jax_enable_x64", True)


class GPForecaster(Forecaster):
"""Example implementation of a forecaster"""

def __init__(self):
super().__init__()
self._models: List = []
self._xs: Array = None
self._ys: Array = None
self._models: list = []
self._xs: jax.Array = None
self._ys: jax.Array = None

@property
def data(self):
"""Returns the data"""
return self._ys, self._xs

def fit(self, rng_key: PRNGKey, ys: Array, xs: Array, niter=2000):
def fit(
self, rng_key: jr.PRNGKey, ys: jax.Array, xs: jax.Array, niter=2000
):
"""Fit a model to each of the time series"""

self._xs = xs
Expand Down Expand Up @@ -71,17 +69,20 @@ def _fit_one(self, rng_key, x, y, niter):
@staticmethod
def _model(rng_key, n):
z = jr.uniform(rng_key, (20, 1))
prior = gpx.Prior(mean_function=gpx.Constant(), kernel=gpx.RBF())
likelihood = gpx.Gaussian(num_datapoints=n)
prior = gpx.gps.Prior(
mean_function=gpx.mean_functions.Constant(),
kernel=gpx.kernels.RBF(),
)
likelihood = gpx.likelihoods.Gaussian(num_datapoints=n)
posterior = prior * likelihood
q = gpx.CollapsedVariationalGaussian(
q = gpx.variational_families.CollapsedVariationalGaussian(
posterior=posterior,
inducing_inputs=z,
)
elbo = gpx.CollapsedELBO(negative=True)
elbo = gpx.objectives.CollapsedELBO(negative=True)
return elbo, q, likelihood

def posterior_predictive(self, rng_key, xs_test: Array):
def posterior_predictive(self, rng_key, xs_test: jax.Array):
"""Compute the joint posterior predictive distribution at xs_test."""
chex.assert_rank(xs_test, 3)

Expand Down Expand Up @@ -109,7 +110,7 @@ def posterior_predictive(self, rng_key, xs_test: Array):
return posterior_predictive

def predictive_posterior_probability(
self, rng_key: PRNGKey, ys_test: Array, xs_test: Array
self, rng_key: jr.PRNGKey, ys_test: jax.Array, xs_test: jax.Array
):
"""Compute the log predictive posterior probability of an observation"""
chex.assert_rank([ys_test, xs_test], [3, 3])
Expand Down
60 changes: 16 additions & 44 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ license = "Apache-2.0"
homepage = "https://github.com/dirmeier/reconcile"
keywords = ["probabilistic reconciliation", "forecasting", "timeseries", "hierarchical time series"]
classifiers=[
"Development Status :: 1 - Planning",
"Development Status :: 3 - Alpha",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: Apache Software License",
"Programming Language :: Python :: 3",
Expand All @@ -33,11 +33,13 @@ dependencies = [
"pandas>=1.5.1"
]
dynamic = ["version"]
packages = [{include = "reconcile"}]

[project.urls]
homepage = "https://github.com/dirmeier/reconcile"

[tool.hatch.build.targets.wheel]
packages = ["reconcile"]

[tool.hatch.version]
path = "reconcile/__init__.py"

Expand All @@ -50,15 +52,15 @@ exclude = [

[tool.hatch.envs.test]
dependencies = [
"pylint>=2.15.10",
"ruff>=0.3.0",
"pytest>=7.2.0",
"pytest-cov>=4.0.0",
"gpjax>=0.5.0",
"statsmodels>=0.13.2"
]

[tool.hatch.envs.test.scripts]
lint = 'pylint reconcile'
lint = 'ruff check reconcile examples'
test = 'pytest -v --doctest-modules --cov=./reconcile --cov-report=xml reconcile'

[tool.hatch.envs.examples]
Expand All @@ -70,47 +72,17 @@ dependencies = [
[tool.hatch.envs.examples.scripts]
reconciliation = 'python examples/reconciliation.py'

[tool.black]
line-length = 80
extend-ignore = "E203"
target-version = ['py39']
exclude = '''
/(
\.eggs
| \.git
| \.hg
| \.mypy_cache
| \.tox
| \.venv
| _build
| buck-out
| build
| dist
)/
'''


[tool.isort]
profile = "black"
line_length = 80
include_trailing_comma = true


[tool.flake8]
max-line-length = 80
extend-ignore = ["E203", "W503", "E731"]
per-file-ignores = [
'__init__.py:F401',
]

[tool.pylint.messages_control]
disable = """
invalid-name,missing-module-docstring,R0801,E0633
"""

[tool.bandit]
skips = ["B101"]

[tool.pydocstyle]
[tool.ruff]
fix = true
line-length = 80

[tool.ruff.lint]
select = ["E4", "E7", "E9", "F"]
extend-select = ["UP", "I", "PL", "S"]
ignore =["S101", "PLR2004", "PLR0913", "E2"]

[tool.ruff.lint.pydocstyle]
convention= 'google'
match = '^reconcile/*((?!test).)*\.py'
11 changes: 5 additions & 6 deletions reconcile/forecast.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
"""Forecasting module."""

import abc
from typing import Tuple

import distrax
from chex import PRNGKey
from jax import Array
from jax import random as jr


class Forecaster(metaclass=abc.ABCMeta):
Expand All @@ -19,7 +18,7 @@ def __init__(self):

@property
@abc.abstractmethod
def data(self) -> Tuple[Array, Array]:
def data(self) -> tuple[Array, Array]:
"""Returns the data set used for training.
Returns:
Expand All @@ -29,7 +28,7 @@ def data(self) -> Tuple[Array, Array]:
"""

@abc.abstractmethod
def fit(self, rng_key: PRNGKey, ys: Array, xs: Array) -> None:
def fit(self, rng_key: jr.PRNGKey, ys: Array, xs: Array) -> None:
"""Fit the forecaster to data.
Fit a forecaster for each base and upper time series. Can be implemented
Expand All @@ -49,7 +48,7 @@ def fit(self, rng_key: PRNGKey, ys: Array, xs: Array) -> None:

@abc.abstractmethod
def posterior_predictive(
self, rng_key: PRNGKey, xs_test: Array
self, rng_key: jr.PRNGKey, xs_test: Array
) -> distrax.Distribution:
"""Computes the posterior predictive distribution at some input points.
Expand All @@ -69,7 +68,7 @@ def posterior_predictive(

@abc.abstractmethod
def predictive_posterior_probability(
self, rng_key: PRNGKey, ys_test: Array, xs_test: Array
self, rng_key: jr.PRNGKey, ys_test: Array, xs_test: Array
) -> Array:
"""Evaluates the probability of an observation.
Expand Down
11 changes: 2 additions & 9 deletions reconcile/grouping.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Grouping module."""


import warnings
from itertools import chain

Expand Down Expand Up @@ -115,18 +114,12 @@ def _gts_create_g_mat(self):
token[i] = []

for i in range(total_len):
token[i].append(
temp_tokens[
cs[i],
]
)
token[i].append(temp_tokens[cs[i],])
if sub_len[i + 1] >= 2:
for j in range(1, sub_len[i + 1]):
col = self._paste0(
token[i][j - 1],
temp_tokens[
cs[i] + j,
],
temp_tokens[cs[i] + j,],
)
token[i].append(col)
token[i] = np.vstack(token[i])
Expand Down
Loading

0 comments on commit 5f84ffb

Please sign in to comment.