Commit 4f36f9b

Merge branch 'main' into chex_test_for_rmh

junpenglao authored Sep 11, 2023
2 parents 168e0b2 + 655c36b
Showing 69 changed files with 2,744 additions and 2,034 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/release.yml
@@ -51,7 +51,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v1
with:
python-version: 3.8
python-version: 3.9
- name: Give PyPI some time to update the index
run: sleep 240
- name: Attempt install from PyPI
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
@@ -24,7 +24,7 @@ jobs:
- style
strategy:
matrix:
python-version: [ '3.8', '3.10']
python-version: [ '3.9', '3.11']
steps:
- uses: actions/checkout@v1
- name: Set up Python ${{ matrix.python-version }}
2 changes: 1 addition & 1 deletion .readthedocs.yaml
@@ -1,7 +1,7 @@
version: 2

python:
version: "3.8"
version: "3.9"
install:
- method: pip
path: .
59 changes: 59 additions & 0 deletions CONTRIBUTING.md
@@ -0,0 +1,59 @@
# Guidelines for Contributing

Thank you for your interest in contributing to Blackjax! We value the following contributions:

- Bug fixes
- Documentation
- High-level sampling algorithms from any family of algorithms: random walk,
  Hamiltonian Monte Carlo, sequential Monte Carlo, variational inference,
  inference compilation, etc.
- New building blocks, e.g. new metrics for HMC, integrators, etc.

## How to contribute?

1. Run `pip install -r requirements.txt` to install all the dev
dependencies.
2. Run `pre-commit run --all-files` and `make test` before pushing to the repo; CI should pass if
   these pass locally (a copy-pasteable sketch of this workflow follows below).
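
A minimal sketch of that workflow, assuming you work from a local clone of the repository:

```shell
# Install the development dependencies.
pip install -r requirements.txt

# Run the linters/formatters and the test suite before pushing.
pre-commit run --all-files
make test
```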

## Editing documentation

The Blackjax repository (and the [sampling-book](https://github.com/blackjax-devs/sampling-book)) provides examples in the form of Markdown documents. [Jupytext](https://github.com/mwouts/jupytext) can convert Jupyter notebooks to this format, and convert these documents back to Jupyter notebooks. Examples are rendered in the [documentation](https://blackjax-devs.github.io/blackjax/).

### Load examples in a Jupyter notebook

To convert any example file to a Jupyter notebook, run:

```shell
jupytext docs/examples/your_example_file.md --to notebook
```

You can then interact with the resulting notebook just like any other notebook.

### Convert a Jupyter notebook to Markdown

If you implemented your example in a Jupyter notebook, you can convert the `.ipynb` file to Markdown with the command below:

```shell
jupytext docs/examples/your_example_notebook.ipynb --to myst
```

Once the example file is converted to Markdown, you have two options for editing:

1. Edit the Markdown version directly; it is a regular Markdown file.
2. Edit the notebook version, then convert it back to Markdown with the command above once you are done editing. Jupytext will update the existing Markdown file as long as the example keeps the same file name.

**Please make sure to only commit the Markdown file.**

### Composing documentation with Sphinx

We use `Sphinx` to generate the documentation for this repo. We highly encourage you to check how your changes to the examples are rendered:

1. Add your documentation to `docs/examples.rst`
2. Run the command below:

```shell
make build-docs
```

3. Check the generated HTML documentation in `docs/_build` (one way to preview it locally is sketched below).
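
A convenient way to inspect the result, assuming a local Python installation (the exact output directory may differ depending on the build configuration):

```shell
# Build the docs, then serve the generated HTML locally (the path is illustrative).
make build-docs
python -m http.server --directory docs/_build 8000
# Then open http://localhost:8000 in a browser.
```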
21 changes: 3 additions & 18 deletions README.md
@@ -79,8 +79,8 @@ state = nuts.init(initial_position)
# Iterate
rng_key = jax.random.PRNGKey(0)
for _ in range(100):
_, rng_key = jax.random.split(rng_key)
state, _ = nuts.step(rng_key, state)
rng_key, nuts_key = jax.random.split(rng_key)
state, _ = nuts.step(nuts_key, state)
```

See [the documentation](https://blackjax-devs.github.io/blackjax/index.html) for more examples of how to use the library: how to write inference loops for one or several chains, how to use the Stan warmup, etc.
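
For context, the corrected key handling above generalizes to a compiled inference loop. A minimal sketch using `jax.lax.scan`, assuming `nuts` and `state` are constructed as in the README snippet; the loop itself is illustrative and not part of this diff:

```python
import jax

def inference_loop(rng_key, kernel, initial_state, num_samples):
    """Run `kernel` for `num_samples` steps, splitting the key once per step."""

    def one_step(state, rng_key):
        state, info = kernel(rng_key, state)
        return state, (state, info)

    keys = jax.random.split(rng_key, num_samples)
    _, (states, infos) = jax.lax.scan(one_step, initial_state, keys)
    return states, infos

# Hypothetical usage with the objects defined in the README snippet:
# states, infos = inference_loop(jax.random.PRNGKey(0), nuts.step, state, 1_000)
```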
@@ -128,22 +128,7 @@ passing parameters.

## Contributions

### What contributions?

We value the following contributions:
- Bug fixes
- Documentation
- High-level sampling algorithms from any family of algorithms: random walk,
hamiltonian monte carlo, sequential monte carlo, variational inference,
inference compilation, etc.
- New building blocks, e.g. new metrics for HMC, integrators, etc.

### How to contribute?

1. Run `pip install -r requirements.txt` to install all the dev
dependencies.
2. Run `pre-commit run --all-files` and `make test` before pushing on the repo; CI should pass if
these pass locally.
Please follow our [short guide](https://github.com/blackjax-devs/blackjax/blob/main/CONTRIBUTING.md).

## Citing Blackjax

44 changes: 22 additions & 22 deletions blackjax/__init__.py
@@ -1,30 +1,28 @@
from blackjax._version import __version__

from .adaptation.meads_adaptation import meads_adaptation
from .adaptation.pathfinder_adaptation import pathfinder_adaptation
from .adaptation.window_adaptation import window_adaptation
from .diagnostics import effective_sample_size as ess
from .diagnostics import potential_scale_reduction as rhat
from .kernels import (
adaptive_tempered_smc,
additive_step_random_walk,
csgld,
elliptical_slice,
ghmc,
hmc,
irmh,
mala,
meads_adaptation,
meanfield_vi,
mgrad_gaussian,
nuts,
orbital_hmc,
pathfinder,
pathfinder_adaptation,
rmh,
sghmc,
sgld,
tempered_smc,
window_adaptation,
)
from .mcmc.elliptical_slice import elliptical_slice
from .mcmc.ghmc import ghmc
from .mcmc.hmc import hmc
from .mcmc.mala import mala
from .mcmc.marginal_latent_gaussian import mgrad_gaussian
from .mcmc.nuts import nuts
from .mcmc.periodic_orbital import orbital_hmc
from .mcmc.random_walk import additive_step_random_walk, irmh, rmh
from .optimizers import dual_averaging, lbfgs
from .sgmcmc.csgld import csgld
from .sgmcmc.sghmc import sghmc
from .sgmcmc.sgld import sgld
from .sgmcmc.sgnht import sgnht
from .smc.adaptive_tempered import adaptive_tempered_smc
from .smc.tempered import tempered_smc
from .vi.meanfield_vi import meanfield_vi
from .vi.pathfinder import pathfinder
from .vi.svgd import svgd

__all__ = [
"__version__",
@@ -42,6 +40,7 @@
"ghmc",
"sgld", # stochastic gradient mcmc
"sghmc",
"sgnht",
"csgld",
"window_adaptation", # mcmc adaptation
"meads_adaptation",
@@ -50,6 +49,7 @@
"tempered_smc",
"meanfield_vi", # variational inference
"pathfinder",
"svgd",
"ess", # diagnostics
"rhat",
]
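
The import reshuffle above only changes where the algorithms are defined; the names listed in `__all__` remain reachable from the top level. A hedged sketch (the log-density and constructor arguments are illustrative placeholders, not a documented signature):

```python
import jax.numpy as jnp
import blackjax

# Standard normal log-density, used purely for illustration.
logdensity_fn = lambda x: -0.5 * jnp.sum(x**2)

# `blackjax.nuts` is now re-exported from `blackjax.mcmc.nuts`;
# the parameter values below are placeholders.
nuts = blackjax.nuts(logdensity_fn, step_size=1e-1, inverse_mass_matrix=jnp.ones(2))
state = nuts.init(jnp.zeros(2))
```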
27 changes: 27 additions & 0 deletions blackjax/adaptation/base.py
@@ -0,0 +1,27 @@
# Copyright 2020- The Blackjax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import NamedTuple

from blackjax.types import ArrayTree


class AdaptationResults(NamedTuple):
state: ArrayTree
parameters: dict


class AdaptationInfo(NamedTuple):
state: NamedTuple
info: NamedTuple
adaptation_state: NamedTuple
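
A hypothetical sketch of how an adaptation routine might package its output with the new container; `package_results` and its arguments are illustrative, not part of this diff:

```python
from blackjax.adaptation.base import AdaptationResults

def package_results(final_state, step_size, inverse_mass_matrix) -> AdaptationResults:
    """Bundle the last chain state with the tuned kernel parameters."""
    return AdaptationResults(
        state=final_state,
        parameters={
            "step_size": step_size,
            "inverse_mass_matrix": inverse_mass_matrix,
        },
    )
```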
20 changes: 11 additions & 9 deletions blackjax/adaptation/mass_matrix.py
@@ -18,12 +18,12 @@
parameters used in Hamiltonian Monte Carlo.
"""
from typing import Callable, NamedTuple, Tuple
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp

from blackjax.types import Array
from blackjax.types import Array, ArrayLike

__all__ = [
"WelfordAlgorithmState",
@@ -68,7 +68,7 @@ class MassMatrixAdaptationState(NamedTuple):

def mass_matrix_adaptation(
is_diagonal_matrix: bool = True,
) -> Tuple[Callable, Callable, Callable]:
) -> tuple[Callable, Callable, Callable]:
"""Adapts the values in the mass matrix by computing the covariance
between parameters.
@@ -111,7 +111,7 @@ def init(n_dims: int) -> MassMatrixAdaptationState:
return MassMatrixAdaptationState(inverse_mass_matrix, wc_state)

def update(
mm_state: MassMatrixAdaptationState, position: Array
mm_state: MassMatrixAdaptationState, position: ArrayLike
) -> MassMatrixAdaptationState:
"""Update the algorithm's state.
@@ -156,7 +156,7 @@ def final(mm_state: MassMatrixAdaptationState) -> MassMatrixAdaptationState:
return init, update, final


def welford_algorithm(is_diagonal_matrix: bool) -> Tuple[Callable, Callable, Callable]:
def welford_algorithm(is_diagonal_matrix: bool) -> tuple[Callable, Callable, Callable]:
r"""Welford's online estimator of covariance.
It is possible to compute the variance of a population of values in an
@@ -203,14 +203,16 @@ def init(n_dims: int) -> WelfordAlgorithmState:
m2 = jnp.zeros((n_dims, n_dims))
return WelfordAlgorithmState(mean, m2, sample_size)

def update(wa_state: WelfordAlgorithmState, value: Array) -> WelfordAlgorithmState:
def update(
wa_state: WelfordAlgorithmState, value: ArrayLike
) -> WelfordAlgorithmState:
"""Update the M2 matrix using the new value.
Parameters
----------
state:
wa_state:
The current state of the Welford Algorithm
position: Array, shape (1,)
value: Array, shape (1,)
The new sample (typically position of the chain) used to update m2
"""
@@ -229,7 +231,7 @@ def update(wa_state: WelfordAlgorithmState, value: Array) -> WelfordAlgorithmState:

def final(
wa_state: WelfordAlgorithmState,
) -> Tuple[Array, int, Array]:
) -> tuple[Array, int, Array]:
mean, m2, sample_size = wa_state
covariance = m2 / (sample_size - 1)
return covariance, sample_size, mean
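
For reference, a minimal standalone sketch of the Welford recursion mirrored by the `(mean, m2, sample_size)` state above, written element-wise for the diagonal case; this is illustrative, not the library implementation:

```python
import jax.numpy as jnp

def welford_update(mean, m2, sample_size, value):
    """One online update of the running mean and sum of squared deviations."""
    sample_size = sample_size + 1
    delta = value - mean                  # deviation from the old mean
    mean = mean + delta / sample_size
    updated_delta = value - mean          # deviation from the new mean
    m2 = m2 + delta * updated_delta
    return mean, m2, sample_size

def welford_final(mean, m2, sample_size):
    """Unbiased (co)variance estimate, as in `final` above."""
    covariance = m2 / (sample_size - 1)
    return covariance, sample_size, mean
```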
(Diff truncated: the remaining changed files are not shown here.)
