Skip to content
This repository has been archived by the owner on May 11, 2023. It is now read-only.

Parameter state #21

Open
wants to merge 84 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
84 commits
Select commit Hold shift + click to select a range
b505f56
(1) Remove PyTree.
daniel-dodd Feb 7, 2023
bdcfacb
remove `dict`
daniel-dodd Feb 7, 2023
21b2619
Add stuff.
daniel-dodd Feb 7, 2023
ca6b5a3
Update structure.
daniel-dodd Feb 7, 2023
cf76b0a
Update README.md
daniel-dodd Feb 7, 2023
e3640db
Remove base. Move to module.
daniel-dodd Feb 8, 2023
a7f4f65
Update README.md
daniel-dodd Feb 8, 2023
2aa295c
Remove config. Drop distrax.
daniel-dodd Feb 8, 2023
1c22a87
Update README.md
daniel-dodd Feb 8, 2023
1cc7b54
Update README.md
daniel-dodd Feb 8, 2023
7438478
Refactor.
daniel-dodd Feb 8, 2023
8518e61
Merge branch 'equinox' of https://github.com/JaxGaussianProcesses/Jax…
daniel-dodd Feb 8, 2023
fc8010e
Create test_bijectors.py
daniel-dodd Feb 8, 2023
768716c
Create test_objective.py
daniel-dodd Feb 8, 2023
6b1029a
Update.
daniel-dodd Feb 8, 2023
a6d2d45
Test module.
daniel-dodd Feb 9, 2023
a2c6560
Update setup.py
daniel-dodd Feb 9, 2023
857cf8b
Update setup.py
daniel-dodd Feb 9, 2023
9ba5451
Update setup.py
daniel-dodd Feb 9, 2023
d5bf42d
Update setup.py
daniel-dodd Feb 9, 2023
7646a0f
Update fit.py
daniel-dodd Feb 9, 2023
e9eb8d2
Delete params.py
daniel-dodd Feb 9, 2023
6f3f63d
Update.
daniel-dodd Feb 12, 2023
4d037a6
Add test.
daniel-dodd Feb 12, 2023
4aee517
Update.
daniel-dodd Feb 12, 2023
02c5fa1
Update module.py
daniel-dodd Feb 12, 2023
509cf19
Update module.py
daniel-dodd Feb 12, 2023
63cad7f
Update fit.py
daniel-dodd Feb 12, 2023
026d045
Add some checks to fit.
daniel-dodd Feb 12, 2023
2765c30
Update checks.
daniel-dodd Feb 12, 2023
d36ad74
Add default trainable status to param field.
daniel-dodd Feb 12, 2023
fad45c0
Update module.py
daniel-dodd Feb 12, 2023
0fb7b8d
Update module.py
daniel-dodd Feb 12, 2023
d1a2d3d
Update bijectors.py
daniel-dodd Feb 12, 2023
3089fcb
Update bijectors.py
daniel-dodd Feb 12, 2023
fad89c6
Remove legacy verify dataset.
daniel-dodd Feb 12, 2023
0737314
Update fit.py
daniel-dodd Feb 12, 2023
daf5d22
Add compilation message to progress bar.
daniel-dodd Feb 12, 2023
15fce76
Leaf node checking.
daniel-dodd Feb 12, 2023
4253485
Update module.py
daniel-dodd Feb 12, 2023
51b33ff
Switch from functional to leaf approach
daniel-dodd Feb 12, 2023
7f4a7f6
Switch to metadata.
daniel-dodd Feb 12, 2023
580a36c
Add minimal verbose scan. Need to generalise and cleanup!
daniel-dodd Feb 12, 2023
29d44f6
Update module.py
daniel-dodd Feb 12, 2023
79dace9
Fix `jaxutils.fit` `__init__.py` import, add documentation to `module…
daniel-dodd Feb 14, 2023
c2cf3c2
Bugfix.
daniel-dodd Feb 17, 2023
fcdd52e
Update.
daniel-dodd Feb 19, 2023
6971942
Update module.py
daniel-dodd Feb 19, 2023
3767684
Fix list tuple breakages (SEE NOTE).
daniel-dodd Feb 19, 2023
f753734
Minimal functionality for accessing and setting trainables, transform…
daniel-dodd Feb 20, 2023
9997805
Push minimal meta functionality.
daniel-dodd Feb 22, 2023
6a6c244
Merge branch 'equinox' of https://github.com/JaxGaussianProcesses/Jax…
daniel-dodd Feb 22, 2023
ce52640
Python 3.8 does not like "|"
daniel-dodd Feb 22, 2023
c289133
Update test_module.py
daniel-dodd Feb 22, 2023
a47c8dd
Update test_module.py
daniel-dodd Feb 22, 2023
08920bf
Redesign PyTree and introduce Module.
daniel-dodd Mar 7, 2023
75574d9
Remove some redundant code.
daniel-dodd Mar 8, 2023
731f46e
ParameterState.
daniel-dodd Mar 9, 2023
3710e9c
Add stuff to __init__
daniel-dodd Mar 9, 2023
aadf1d0
Minimal test for `fit`.
daniel-dodd Mar 9, 2023
7b63890
Update fit tests.
daniel-dodd Mar 9, 2023
c95114e
Update setup.py
daniel-dodd Mar 9, 2023
2a67a50
Add basic dict methods
thomaspinder Mar 9, 2023
49559a4
Add priors
thomaspinder Mar 10, 2023
fd97b43
Add log prior density fn
thomaspinder Mar 10, 2023
37bb0ee
Add unit test
thomaspinder Mar 10, 2023
00109f7
Merge pull request #18 from JaxGaussianProcesses/priors
thomaspinder Mar 10, 2023
9f7a32a
Parameters unit tests
thomaspinder Mar 13, 2023
33a7c6a
Test key-structure for bijectors and distributions
thomaspinder Mar 13, 2023
de692bf
Test key-structure for bijectors and distributions
thomaspinder Mar 13, 2023
1c33661
Remove npm-files
thomaspinder Mar 13, 2023
92c9396
Action TODOs
thomaspinder Mar 13, 2023
dba579c
Lint and format files
thomaspinder Mar 13, 2023
c506a20
Add combine fn.
thomaspinder Mar 14, 2023
57f46d7
Add parameter method
thomaspinder Mar 14, 2023
9b16787
Augment add_parameter
thomaspinder Mar 15, 2023
5698dd7
simplify fit
thomaspinder Mar 19, 2023
d91c6fb
simplify fit
thomaspinder Mar 19, 2023
eaae509
Update tests
thomaspinder Mar 19, 2023
ee7cb47
Fix failing test
thomaspinder Mar 23, 2023
a4a2fa1
Fix failing test
thomaspinder Mar 23, 2023
f8d2288
Fix failing test
thomaspinder Mar 23, 2023
ae153b2
Switch PyTree version
thomaspinder Mar 23, 2023
0d34df5
Update workflow
thomaspinder Mar 23, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ orbs:
jobs:
build-and-test:
docker:
- image: cimg/python:3.8.0
- image: cimg/python:3.9.0
steps:
- checkout
- run:
Expand All @@ -16,6 +16,9 @@ jobs:
- python/install-packages:
pkg-manager: pip-dist
path-args: .[dev]
- run:
name: Simple PyTree
command: pip install https://github.com/cgarciae/simple-pytree/archive/refs/heads/improve-new-handling.zip
- run:
name: Run tests
command: pytest --cov=./ --cov-report=xml
Expand Down
18 changes: 18 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
repos:
- repo: https://github.com/PyCQA/autoflake
rev: v2.0.0
hooks:
- id: autoflake
args: ["--in-place", "--remove-unused-variables", "--remove-all-unused-imports", "--recursive"]
name: AutoFlake
description: "Format with AutoFlake"
stages: [commit]
- repo: https://github.com/psf/black
rev: 22.3.0
hooks:
- id: black
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: 'v0.0.254'
hooks:
- id: ruff
args: ['--fix']
78 changes: 61 additions & 17 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,46 +2,90 @@

[![CircleCI](https://dl.circleci.com/status-badge/img/gh/JaxGaussianProcesses/JaxUtils/tree/master.svg?style=svg)](https://dl.circleci.com/status-badge/redirect/gh/JaxGaussianProcesses/JaxUtils/tree/master)

`JaxUtils` provides utility functions for the [`JaxGaussianProcesses`]() ecosystem.</h2>
[`JaxUtils`](https://github.com/JaxGaussianProcesses/JaxUtils) is a lightweight library built on [`Equinox`](https://github.com/patrick-kidger/equinox) purposed to provide clean (and fast) model training functionality. This library also serves as a backend for the [`JaxGaussianProcesses`]() ecosystem.</h2>


# Contents

- [PyTree](#pytree)
- [Overview](#overview)
- [Module] (#module)
- [Objective] (#objective)
- [Vscan] (#vscan)
- [Fit] (#fit)
- [Bijectors](#bijectors)
- [Dataset](#dataset)

# PyTree
# Overview

## Overview
## Linear Model example.

`JaxUtils` is designed....


## Linear Model example.

`jaxutils.PyTree` is a mixin class for [registering a python class as a JAX PyTree](https://jax.readthedocs.io/en/latest/pytrees.html#extending-pytrees). You would define your Python class as follows.
We fit a simple one-dimensional linear regression model with a `weight` and a `bias` parameter.

### (1) Dataset

```python
class MyClass(jaxutils.PyTree):
...
# Import dependancies.
import jaxutils as ju
import jax.numpy as jnp
import jax.random as jr
import optax as ox
import matplotlib.pyplot as plt

# Simulate labels.
key = jr.PRNGKey(42)
X = jnp.linspace(0.0, 10.0, 100).reshape(-1, 1)
y = 2.0 * X + 1.0 + jr.normal(key, X.shape)

# Create dataset object.
D = ju.Dataset(X, y)
```

## Example
### (2) Model

A model is defined through inheriting from the `JaxUtils`'s `Module` object.
```python
import jaxutils
class LinearModel(ju.Module):
weight: float = ju.param(ju.Identity)
bias: float = ju.param(ju.Identity)

from jaxtyping import Float, Array
def __call__(self, x):
return self.weight * x + self.bias

class Line(jaxutils.PyTree):
def __init__(self, gradient: Float[Array, "1"], intercept: Float[Array, "1"]) -> None
self.gradient = gradient
self.intercept = intercept
model = LinearModel(weight=1.0, bias=1.0)
```
The parameters are marked via the `param` field, whose argument is the default `Bijector` transformation for mapping the parameters to the unconstrained space for optimisation. In this case both of our `weight` and `bias` parameters are defined on the reals, so we use the `Identity` transform. Just like in typicall `Equinox` code, we can (optionally) define a foward pass of the model through the `__call__` method.

### (3) Objective

We can define any objective function, such as the mean squared error, via inheriting from the `Objective` object as follows.
```python
class MeanSquaredError(ju.Objective):

def evaluate(self, model: LinearModel, train_data: ju.Dataset) -> float:
return jnp.mean((train_data.y - model(train_data.X)) ** 2)

def y(self, x: Float[Array, "N"]) -> Float[Array, "N"]
return x * self.gradient + self.intercept
loss = MeanSquaredError()
```

### (4) Train!

We are now ready to train our model. This can simply be done using the `fit` callable.
```python
# Optimisation loop.
model, hist = ju.fit(model=model, objective=loss, train_data=D, optim=optim, num_iters=1000)
```


# Dataset

## Overview

`jaxutils.Dataset` is a datset abstraction. In future, we wish to extend this to a heterotopic and isotopic data abstraction.
`jaxutils.Dataset` is a dataset abstraction.

## Example

Expand Down
36 changes: 12 additions & 24 deletions jaxutils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,12 @@
# limitations under the License.
# ==============================================================================


from .pytree import PyTree
from .data import Dataset, verify_dataset
from .dict import (
concat_dictionaries,
merge_dictionaries,
sort_dictionary,
dict_array_coercion,
)
from .parameters import (
ParameterState,
initialise,
recursive_items,
recursive_complete,
)
from .parameters import Parameters
from .bijectors import Identity, Softplus, FillScaleTriL
from .dataset import Dataset
from .fit import fit, get_batch
from .scan import vscan

__authors__ = "Thomas Pinder, Daniel Dodd"
__license__ = "MIT"
Expand All @@ -39,19 +30,16 @@
"https://github.com//JaxGaussianProcesses/JaxUtils/graphs/contributors"
)


__all__ = [
"PyTree",
"Parameters",
"Identity",
"Softplus",
"FillScaleTriL",
"Dataset",
"verify_dataset",
"concat_dictionaries",
"merge_dictionaries",
"sort_dictionary",
"dict_array_coercion",
"ParameterState",
"initialise",
"recursive_items",
"recursive_complete",
"fit",
"get_batch",
"vscan",
]

from . import _version
Expand Down
Loading