diff --git a/.circleci/config.yml b/.circleci/config.yml
index d33688e..a0d6a88 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -7,7 +7,7 @@ orbs:
 jobs:
   build-and-test:
     docker:
-      - image: cimg/python:3.8.0
+      - image: cimg/python:3.9.0
     steps:
       - checkout
       - run:
@@ -16,6 +16,9 @@ jobs:
       - python/install-packages:
           pkg-manager: pip-dist
          path-args: .[dev]
+      - run:
+          name: Simple PyTree
+          command: pip install https://github.com/cgarciae/simple-pytree/archive/refs/heads/improve-new-handling.zip
       - run:
           name: Run tests
           command: pytest --cov=./ --cov-report=xml
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000..bfb814b
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,18 @@
+repos:
+  - repo: https://github.com/PyCQA/autoflake
+    rev: v2.0.0
+    hooks:
+      - id: autoflake
+        args: ["--in-place", "--remove-unused-variables", "--remove-all-unused-imports", "--recursive"]
+        name: AutoFlake
+        description: "Format with AutoFlake"
+        stages: [commit]
+  - repo: https://github.com/psf/black
+    rev: 22.3.0
+    hooks:
+      - id: black
+  - repo: https://github.com/charliermarsh/ruff-pre-commit
+    rev: 'v0.0.254'
+    hooks:
+      - id: ruff
+        args: ['--fix']
\ No newline at end of file
diff --git a/README.md b/README.md
index 3a3b0bf..d22e170 100644
--- a/README.md
+++ b/README.md
@@ -2,46 +2,90 @@
 [![CircleCI](https://dl.circleci.com/status-badge/img/gh/JaxGaussianProcesses/JaxUtils/tree/master.svg?style=svg)](https://dl.circleci.com/status-badge/redirect/gh/JaxGaussianProcesses/JaxUtils/tree/master)
 
-`JaxUtils` provides utility functions for the [`JaxGaussianProcesses`]() ecosystem.
+[`JaxUtils`](https://github.com/JaxGaussianProcesses/JaxUtils) is a lightweight library built on [`Equinox`](https://github.com/patrick-kidger/equinox), designed to provide clean (and fast) model training functionality. This library also serves as a backend for the [`JaxGaussianProcesses`](https://github.com/JaxGaussianProcesses) ecosystem.
+
 # Contents
-- [PyTree](#pytree)
+- [Overview](#overview)
+- [Module](#module)
+- [Objective](#objective)
+- [Vscan](#vscan)
+- [Fit](#fit)
+- [Bijectors](#bijectors)
 - [Dataset](#dataset)
 
-# PyTree
+# Overview
 
-## Overview
+## Linear model example
 
-`jaxutils.PyTree` is a mixin class for [registering a python class as a JAX PyTree](https://jax.readthedocs.io/en/latest/pytrees.html#extending-pytrees). You would define your Python class as follows.
+To illustrate the library, we fit a simple one-dimensional linear regression model with a `weight` and a `bias` parameter.
+
+### (1) Dataset
 
 ```python
-class MyClass(jaxutils.PyTree):
-    ...
+# Import dependencies.
+import jaxutils as ju
+import jax.numpy as jnp
+import jax.random as jr
+import optax as ox
+import matplotlib.pyplot as plt
+
+# Simulate noisy data from a linear relationship.
+key = jr.PRNGKey(42)
+X = jnp.linspace(0.0, 10.0, 100).reshape(-1, 1)
+y = 2.0 * X + 1.0 + jr.normal(key, X.shape)
+
+# Create dataset object.
+D = ju.Dataset(X, y)
 ```
-## Example
+### (2) Model
+A model is defined by inheriting from `JaxUtils`'s `Module` object.
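+Under the hood, a `Module` behaves like an `Equinox` module: a frozen dataclass that is registered as a JAX PyTree. As rough intuition for the `param` marker (a hypothetical sketch only, not `JaxUtils`'s actual implementation), it can be pictured as a dataclass field that records the bijector and trainability as metadata:
+
+```python
+# Hypothetical sketch of a `param`-style field marker, for intuition only;
+# the real JaxUtils implementation may differ.
+import dataclasses
+from typing import Any
+
+def param(bijector: Any, trainable: bool = True) -> dataclasses.Field:
+    """Mark a dataclass field as a parameter carrying a default bijector."""
+    return dataclasses.field(metadata={"bijector": bijector, "trainable": trainable})
+```
+
+With that picture in mind, the linear model reads:
+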
```python
+class LinearModel(ju.Module):
+    weight: float = ju.param(ju.Identity)
+    bias: float = ju.param(ju.Identity)
-from jaxtyping import Float, Array
+
+    def __call__(self, x):
+        return self.weight * x + self.bias
-class Line(jaxutils.PyTree):
-    def __init__(self, gradient: Float[Array, "1"], intercept: Float[Array, "1"]) -> None
-        self.gradient = gradient
-        self.intercept = intercept
+
+model = LinearModel(weight=1.0, bias=1.0)
+```
+The parameters are marked via the `param` field, whose argument is the default `Bijector` transformation for mapping the parameter to the unconstrained space during optimisation. In this case, both our `weight` and `bias` parameters are defined on the reals, so we use the `Identity` transform. Just like in typical `Equinox` code, we can (optionally) define a forward pass of the model through the `__call__` method.
+
+### (3) Objective
+
+We can define any objective function, such as the mean squared error, by inheriting from the `Objective` object as follows.
+```python
+class MeanSquaredError(ju.Objective):
+
+    def evaluate(self, model: LinearModel, train_data: ju.Dataset) -> float:
+        return jnp.mean((train_data.y - model(train_data.X)) ** 2)
-    def y(self, x: Float[Array, "N"]) -> Float[Array, "N"]
-        return x * self.gradient + self.intercept
+
+loss = MeanSquaredError()
 ```
+
+### (4) Train!
+
+We are now ready to train our model. This is done using the `fit` callable.
+```python
+# Optimisation loop, using an Optax optimiser.
+optim = ox.adam(learning_rate=0.01)
+model, hist = ju.fit(model=model, objective=loss, train_data=D, optim=optim, num_iters=1000)
+```
+
 
 # Dataset
 
 ## Overview
 
-`jaxutils.Dataset` is a datset abstraction. In future, we wish to extend this to a heterotopic and isotopic data abstraction.
+`jaxutils.Dataset` is a dataset abstraction.
 
 ## Example
diff --git a/jaxutils/__init__.py b/jaxutils/__init__.py
index 8fb732f..ef1b894 100644
--- a/jaxutils/__init__.py
+++ b/jaxutils/__init__.py
@@ -13,21 +13,12 @@
 # limitations under the License.
 # ==============================================================================
-
 from .pytree import PyTree
-from .data import Dataset, verify_dataset
-from .dict import (
-    concat_dictionaries,
-    merge_dictionaries,
-    sort_dictionary,
-    dict_array_coercion,
-)
-from .parameters import (
-    ParameterState,
-    initialise,
-    recursive_items,
-    recursive_complete,
-)
+from .parameters import Parameters
+from .bijectors import Identity, Softplus, FillScaleTriL
+from .dataset import Dataset
+from .fit import fit, get_batch
+from .scan import vscan
 
 __authors__ = "Thomas Pinder, Daniel Dodd"
 __license__ = "MIT"
@@ -39,19 +30,16 @@
     "https://github.com//JaxGaussianProcesses/JaxUtils/graphs/contributors"
 )
 
-
 __all__ = [
     "PyTree",
+    "Parameters",
+    "Identity",
+    "Softplus",
+    "FillScaleTriL",
     "Dataset",
-    "verify_dataset",
-    "concat_dictionaries",
-    "merge_dictionaries",
-    "sort_dictionary",
-    "dict_array_coercion",
-    "ParameterState",
-    "initialise",
-    "recursive_items",
-    "recursive_complete",
+    "fit",
+    "get_batch",
+    "vscan",
 ]
 
 from . import _version
diff --git a/jaxutils/_version.py b/jaxutils/_version.py
index 4c681d4..c61a5e0 100644
--- a/jaxutils/_version.py
+++ b/jaxutils/_version.py
@@ -1,4 +1,3 @@
-
 # This file helps to compute a version number in source trees obtained from
 # git-archive tarball (such as those provided by githubs download-from-tag
 # feature).
Distribution tarballs (built by setup.py sdist) and build @@ -61,17 +60,18 @@ class NotThisMethod(Exception): def register_vcs_handler(vcs, method): # decorator """Create decorator to mark a method as the handler of a VCS.""" + def decorate(f): """Store f in HANDLERS[vcs][method].""" if vcs not in HANDLERS: HANDLERS[vcs] = {} HANDLERS[vcs][method] = f return f + return decorate -def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, - env=None): +def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env=None): """Call the given command(s).""" assert isinstance(commands, list) process = None @@ -87,10 +87,14 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, try: dispcmd = str([command] + args) # remember shell=False, so use git.cmd on windows, not just git - process = subprocess.Popen([command] + args, cwd=cwd, env=env, - stdout=subprocess.PIPE, - stderr=(subprocess.PIPE if hide_stderr - else None), **popen_kwargs) + process = subprocess.Popen( + [command] + args, + cwd=cwd, + env=env, + stdout=subprocess.PIPE, + stderr=(subprocess.PIPE if hide_stderr else None), + **popen_kwargs, + ) break except OSError: e = sys.exc_info()[1] @@ -125,15 +129,21 @@ def versions_from_parentdir(parentdir_prefix, root, verbose): for _ in range(3): dirname = os.path.basename(root) if dirname.startswith(parentdir_prefix): - return {"version": dirname[len(parentdir_prefix):], - "full-revisionid": None, - "dirty": False, "error": None, "date": None} + return { + "version": dirname[len(parentdir_prefix) :], + "full-revisionid": None, + "dirty": False, + "error": None, + "date": None, + } rootdirs.append(root) root = os.path.dirname(root) # up a level if verbose: - print("Tried directories %s but none started with prefix %s" % - (str(rootdirs), parentdir_prefix)) + print( + "Tried directories %s but none started with prefix %s" + % (str(rootdirs), parentdir_prefix) + ) raise NotThisMethod("rootdir doesn't start with parentdir_prefix") @@ -192,7 +202,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of # just "foo-1.0". If we see a "tag: " prefix, prefer those. TAG = "tag: " - tags = {r[len(TAG):] for r in refs if r.startswith(TAG)} + tags = {r[len(TAG) :] for r in refs if r.startswith(TAG)} if not tags: # Either we're using git < 1.8.3, or there really are no tags. We use # a heuristic: assume all version tags have a digit. The old git %d @@ -201,7 +211,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): # between branches and tags. By ignoring refnames without digits, we # filter out many common branch names like "release" and # "stabilization", as well as "HEAD" and "master". - tags = {r for r in refs if re.search(r'\d', r)} + tags = {r for r in refs if re.search(r"\d", r)} if verbose: print("discarding '%s', no digits" % ",".join(refs - tags)) if verbose: @@ -209,24 +219,31 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): for ref in sorted(tags): # sorting will prefer e.g. 
"2.0" over "2.0rc1" if ref.startswith(tag_prefix): - r = ref[len(tag_prefix):] + r = ref[len(tag_prefix) :] # Filter out refs that exactly match prefix or that don't start # with a number once the prefix is stripped (mostly a concern # when prefix is '') - if not re.match(r'\d', r): + if not re.match(r"\d", r): continue if verbose: print("picking %s" % r) - return {"version": r, - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": None, - "date": date} + return { + "version": r, + "full-revisionid": keywords["full"].strip(), + "dirty": False, + "error": None, + "date": date, + } # no suitable tags, so version is "0+unknown", but full hex is still there if verbose: print("no suitable tags, using unknown + full revision id") - return {"version": "0+unknown", - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": "no suitable tags", "date": None} + return { + "version": "0+unknown", + "full-revisionid": keywords["full"].strip(), + "dirty": False, + "error": "no suitable tags", + "date": None, + } @register_vcs_handler("git", "pieces_from_vcs") @@ -248,8 +265,7 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command): env.pop("GIT_DIR", None) runner = functools.partial(runner, env=env) - _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root, - hide_stderr=not verbose) + _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root, hide_stderr=not verbose) if rc != 0: if verbose: print("Directory %s not under git control" % root) @@ -257,10 +273,19 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command): # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] # if there isn't one, this yields HEX[-dirty] (no NUM) - describe_out, rc = runner(GITS, [ - "describe", "--tags", "--dirty", "--always", "--long", - "--match", f"{tag_prefix}[[:digit:]]*" - ], cwd=root) + describe_out, rc = runner( + GITS, + [ + "describe", + "--tags", + "--dirty", + "--always", + "--long", + "--match", + f"{tag_prefix}[[:digit:]]*", + ], + cwd=root, + ) # --long was added in git-1.5.5 if describe_out is None: raise NotThisMethod("'git describe' failed") @@ -275,8 +300,7 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command): pieces["short"] = full_out[:7] # maybe improved later pieces["error"] = None - branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"], - cwd=root) + branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"], cwd=root) # --abbrev-ref was added in git-1.6.3 if rc != 0 or branch_name is None: raise NotThisMethod("'git rev-parse --abbrev-ref' returned error") @@ -316,17 +340,16 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command): dirty = git_describe.endswith("-dirty") pieces["dirty"] = dirty if dirty: - git_describe = git_describe[:git_describe.rindex("-dirty")] + git_describe = git_describe[: git_describe.rindex("-dirty")] # now we have TAG-NUM-gHEX or HEX if "-" in git_describe: # TAG-NUM-gHEX - mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) + mo = re.search(r"^(.+)-(\d+)-g([0-9a-f]+)$", git_describe) if not mo: # unparsable. Maybe git-describe is misbehaving? 
- pieces["error"] = ("unable to parse git-describe output: '%s'" - % describe_out) + pieces["error"] = "unable to parse git-describe output: '%s'" % describe_out return pieces # tag @@ -335,10 +358,12 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command): if verbose: fmt = "tag '%s' doesn't start with prefix '%s'" print(fmt % (full_tag, tag_prefix)) - pieces["error"] = ("tag '%s' doesn't start with prefix '%s'" - % (full_tag, tag_prefix)) + pieces["error"] = "tag '%s' doesn't start with prefix '%s'" % ( + full_tag, + tag_prefix, + ) return pieces - pieces["closest-tag"] = full_tag[len(tag_prefix):] + pieces["closest-tag"] = full_tag[len(tag_prefix) :] # distance: number of commits since tag pieces["distance"] = int(mo.group(2)) @@ -387,8 +412,7 @@ def render_pep440(pieces): rendered += ".dirty" else: # exception #1 - rendered = "0+untagged.%d.g%s" % (pieces["distance"], - pieces["short"]) + rendered = "0+untagged.%d.g%s" % (pieces["distance"], pieces["short"]) if pieces["dirty"]: rendered += ".dirty" return rendered @@ -417,8 +441,7 @@ def render_pep440_branch(pieces): rendered = "0" if pieces["branch"] != "master": rendered += ".dev0" - rendered += "+untagged.%d.g%s" % (pieces["distance"], - pieces["short"]) + rendered += "+untagged.%d.g%s" % (pieces["distance"], pieces["short"]) if pieces["dirty"]: rendered += ".dirty" return rendered @@ -579,11 +602,13 @@ def render_git_describe_long(pieces): def render(pieces, style): """Render the given version pieces into the requested style.""" if pieces["error"]: - return {"version": "unknown", - "full-revisionid": pieces.get("long"), - "dirty": None, - "error": pieces["error"], - "date": None} + return { + "version": "unknown", + "full-revisionid": pieces.get("long"), + "dirty": None, + "error": pieces["error"], + "date": None, + } if not style or style == "default": style = "pep440" # the default @@ -607,9 +632,13 @@ def render(pieces, style): else: raise ValueError("unknown style '%s'" % style) - return {"version": rendered, "full-revisionid": pieces["long"], - "dirty": pieces["dirty"], "error": None, - "date": pieces.get("date")} + return { + "version": rendered, + "full-revisionid": pieces["long"], + "dirty": pieces["dirty"], + "error": None, + "date": pieces.get("date"), + } def get_versions(): @@ -623,8 +652,7 @@ def get_versions(): verbose = cfg.verbose try: - return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, - verbose) + return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, verbose) except NotThisMethod: pass @@ -633,13 +661,16 @@ def get_versions(): # versionfile_source is the relative path from the top of the source # tree (where the .git directory might live) to this file. Invert # this to find the root from __file__. 
- for _ in cfg.versionfile_source.split('/'): + for _ in cfg.versionfile_source.split("/"): root = os.path.dirname(root) except NameError: - return {"version": "0+unknown", "full-revisionid": None, - "dirty": None, - "error": "unable to find root of source tree", - "date": None} + return { + "version": "0+unknown", + "full-revisionid": None, + "dirty": None, + "error": "unable to find root of source tree", + "date": None, + } try: pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose) @@ -653,6 +684,10 @@ def get_versions(): except NotThisMethod: pass - return {"version": "0+unknown", "full-revisionid": None, - "dirty": None, - "error": "unable to compute version", "date": None} + return { + "version": "0+unknown", + "full-revisionid": None, + "dirty": None, + "error": "unable to compute version", + "date": None, + } diff --git a/jaxutils/bijectors.py b/jaxutils/bijectors.py new file mode 100644 index 0000000..4676542 --- /dev/null +++ b/jaxutils/bijectors.py @@ -0,0 +1,37 @@ +# Copyright 2023 The JaxGaussianProcesses Contributors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import distrax as dx +import jax.numpy as jnp +import tensorflow_probability.substrates.jax.bijectors as tfb + +Identity = dx.Lambda(forward=lambda x: x, inverse=lambda x: x) + +Softplus = dx.Lambda( + forward=lambda x: jnp.log(1 + jnp.exp(x)), + inverse=lambda x: jnp.log(jnp.exp(x) - 1.0), +) + +FillScaleTriL = dx.Chain( + [ + tfb.FillScaleTriL(diag_shift=jnp.array(1e-6)), + ] +) + +__all__ = [ + "Identity", + "Softplus", + "FillScaleTriL", +] diff --git a/jaxutils/config.py b/jaxutils/config.py deleted file mode 100644 index cabbbfd..0000000 --- a/jaxutils/config.py +++ /dev/null @@ -1,150 +0,0 @@ -# Copyright 2022 The GPJax Contributors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== - -import jax -import distrax as dx -import jax.numpy as jnp -import jax.random as jr -import tensorflow_probability.substrates.jax.bijectors as tfb -from ml_collections import ConfigDict - -__config = None - -Identity = dx.Lambda(forward=lambda x: x, inverse=lambda x: x) -Softplus = dx.Lambda( - forward=lambda x: jnp.log(1 + jnp.exp(x)), - inverse=lambda x: jnp.log(jnp.exp(x) - 1.0), -) - - -def reset_global_config() -> None: - global __config - __config = get_default_config() - - -def get_global_config() -> ConfigDict: - """Get the global config file used within GPJax. - - Returns: - ConfigDict: A `ConfigDict` describing parameter transforms and default values. - """ - global __config - - if __config is None: - __config = get_default_config() - return __config - - # If the global config is available, check if the x64 state has changed - x64_state = jax.config.x64_enabled - - # If the x64 state has not changed, return the existing global config - if x64_state is __config.x64_state: - return __config - - # If the x64 state has changed, return the updated global config - update_x64_sensitive_settings() - return __config - - -def update_x64_sensitive_settings() -> None: - """Update the global config if x64 state changes.""" - global __config - - # Update the x64 state - x64_state = jax.config.x64_enabled - __config.x64_state = x64_state - - # Update the x64 sensitive bijectors - FillScaleTriL = dx.Chain( - [ - tfb.FillScaleTriL(diag_shift=jnp.array(__config.jitter)), - ] - ) - - transformations = __config.transformations - transformations.triangular_transform = FillScaleTriL - - -def get_default_config() -> ConfigDict: - """Construct and return the default config file. - - Returns: - ConfigDict: A `ConfigDict` describing parameter transforms and default values. - """ - - config = ConfigDict(type_safe=False) - config.key = jr.PRNGKey(123) - - # Set the x64 state - config.x64_state = jax.config.x64_enabled - - # Covariance matrix stabilising jitter - config.jitter = 1e-6 - - FillScaleTriL = dx.Chain( - [ - tfb.FillScaleTriL(diag_shift=jnp.array(config.jitter)), - ] - ) - - # Default bijections - config.transformations = transformations = ConfigDict() - transformations.positive_transform = Softplus - transformations.identity_transform = Identity - transformations.triangular_transform = FillScaleTriL - - # Default parameter transforms - transformations.alpha = "positive_transform" - transformations.lengthscale = "positive_transform" - transformations.variance = "positive_transform" - transformations.smoothness = "positive_transform" - transformations.shift = "positive_transform" - transformations.obs_noise = "positive_transform" - transformations.latent = "identity_transform" - transformations.basis_fns = "identity_transform" - transformations.offset = "identity_transform" - transformations.inducing_inputs = "identity_transform" - transformations.variational_mean = "identity_transform" - transformations.variational_root_covariance = "triangular_transform" - transformations.natural_vector = "identity_transform" - transformations.natural_matrix = "identity_transform" - transformations.expectation_vector = "identity_transform" - transformations.expectation_matrix = "identity_transform" - - return config - - -# This function is created for testing purposes only -def get_global_config_if_exists() -> ConfigDict: - """Get the global config file used within GPJax if it is available. 
- - Returns: - ConfigDict: A `ConfigDict` describing parameter transforms and default values. - """ - global __config - return __config - - -def add_parameter(param_name: str, bijection: dx.Bijector) -> None: - """Add a parameter and its corresponding transform to GPJax's config file. - - Args: - param_name (str): The name of the parameter that is to be added. - bijection (dx.Bijector): The bijection that should be used to unconstrain the parameter's value. - """ - lookup_name = f"{param_name}_transform" - get_global_config() - __config.transformations[lookup_name] = bijection - __config.transformations[param_name] = lookup_name diff --git a/jaxutils/data.py b/jaxutils/data.py deleted file mode 100644 index 0a9747e..0000000 --- a/jaxutils/data.py +++ /dev/null @@ -1,120 +0,0 @@ -# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -from __future__ import annotations - -import jax.numpy as jnp -from jaxtyping import Array, Float -from typing import Optional - -from .pytree import PyTree - -class Dataset(PyTree): - """Dataset class.""" - - #TODO: Consider HeterotopicDataset and IsotopicDataset abstractions. - - def __init__( - self, - X: Optional[Float[Array, "N D"]] = None, - y: Optional[Float[Array, "N Q"]] = None, - ) -> None: - """ - Args: - X(Float[Array, "N D"]]): Input data. - y(Float[Array, "N Q"]]): Output data. - - Returns: - Dataset: A dataset object. - """ - - _check_shape(X, y) - self.X = X - self.y = y - - def __repr__(self) -> str: - return ( - f"- Number of datapoints: {self.X.shape[0]}\n- Dimension: {self.X.shape[1]}" - ) - - def is_supervised(self) -> bool: - """Returns True if the dataset is supervised.""" - return self.X is not None and self.y is not None - - def is_unsupervised(self) -> bool: - """Returns True if the dataset is unsupervised.""" - return self.X is None and self.y is not None - - - def __add__(self, other: Dataset) -> Dataset: - """Combines two datasets into one. The right-hand dataset is stacked beneath left.""" - x = jnp.concatenate((self.X, other.X)) - y = jnp.concatenate((self.y, other.y)) - - return Dataset(X=x, y=y) - - @property - def n(self) -> int: - """The number of observations in the dataset.""" - return self.X.shape[0] - - @property - def in_dim(self) -> int: - """The dimension of the input data.""" - return self.X.shape[1] - - @property - def out_dim(self) -> int: - """The dimension of the output data.""" - return self.y.shape[1] - - -def verify_dataset(ds: Dataset) -> None: - """Apply a series of checks to the dataset to ensure that downstream operations are safe.""" - assert ds.X.ndim == 2, ( - "2-dimensional training inputs are required. Current dimension:" - f" {ds.X.ndim}." - ) - if ds.y is not None: - assert ds.y.ndim == 2, ( - "2-dimensional training outputs are required. Current dimension:" - f" {ds.y.ndim}." 
- ) - assert ds.X.shape[0] == ds.y.shape[0], ( - "Number of inputs must equal the number of outputs. \nCurrent" - f" counts:\n- X: {ds.X.shape[0]}\n- y: {ds.y.shape[0]}" - ) - - -def _check_shape(X: Float[Array, "N D"], y: Float[Array, "N Q"]) -> None: - """Checks that the shapes of X and y are compatible.""" - if X is not None and y is not None: - if X.shape[0] != y.shape[0]: - raise ValueError( - f"X and y must have the same number of rows. Got X.shape={X.shape} and y.shape={y.shape}." - ) - - if X is not None and X.ndim != 2: - raise ValueError( - f"X must be a 2-dimensional array. Got X.ndim={X.ndim}." - ) - - if y is not None and y.ndim != 2: - raise ValueError( - f"y must be a 2-dimensional array. Got y.ndim={y.ndim}." - ) - -__all__ = [ - "Dataset", -] diff --git a/jaxutils/dataset.py b/jaxutils/dataset.py new file mode 100644 index 0000000..b5b9e90 --- /dev/null +++ b/jaxutils/dataset.py @@ -0,0 +1,101 @@ +# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import annotations + +import jax.numpy as jnp +from jaxtyping import Array, Float +from typing import Optional +from simple_pytree import Pytree +from dataclasses import dataclass + + +@dataclass +class Dataset(Pytree): + """Base class for datasets. + + Attributes: + X (Optional[Float[Array, "N D"]]): Input data. + y (Optional[Float[Array, "N Q"]]): Output data. + """ + + X: Optional[Float[Array, "N D"]] = None + y: Optional[Float[Array, "N Q"]] = None + + def __post_init__(self) -> None: + """Checks that the shapes of X and y are compatible.""" + _check_shape(self.X, self.y) + + def __repr__(self) -> str: + """Returns a string representation of the dataset.""" + repr = ( + f"- Number of observations: {self.n}\n- Input dimension:" + f" {self.in_dim}\n- Output dimension: {self.out_dim}" + ) + return repr + + def is_supervised(self) -> bool: + """Returns `True` if the dataset is supervised.""" + return self.X is not None and self.y is not None + + def is_unsupervised(self) -> bool: + """Returns `True` if the dataset is unsupervised.""" + return self.X is None and self.y is not None + + def __add__(self, other: Dataset) -> Dataset: + """Combine two datasets. 
Right hand dataset is stacked beneath the left.""" + X = jnp.concatenate((self.X, other.X)) + y = jnp.concatenate((self.y, other.y)) + + return Dataset(X=X, y=y) + + @property + def n(self) -> int: + """Number of observations.""" + return self.X.shape[0] + + @property + def in_dim(self) -> int: + """Dimension of the inputs, X.""" + return self.X.shape[1] + + @property + def out_dim(self) -> int: + """Dimension of the outputs, y.""" + return self.y.shape[1] + + +def _check_shape(X: Float[Array, "N D"], y: Float[Array, "N Q"]) -> None: + """Checks that the shapes of X and y are compatible.""" + if X is not None and y is not None: + if X.shape[0] != y.shape[0]: + raise ValueError( + "Inputs, X, and outputs, y, must have the same number of rows." + f" Got X.shape={X.shape} and y.shape={y.shape}." + ) + + if X is not None and X.ndim != 2: + raise ValueError( + f"Inputs, X, must be a 2-dimensional array. Got X.ndim={X.ndim}." + ) + + if y is not None and y.ndim != 2: + raise ValueError( + f"Outputs, y, must be a 2-dimensional array. Got y.ndim={y.ndim}." + ) + + +__all__ = [ + "Dataset", +] diff --git a/jaxutils/dict.py b/jaxutils/dict.py deleted file mode 100644 index 9b71e80..0000000 --- a/jaxutils/dict.py +++ /dev/null @@ -1,105 +0,0 @@ -# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -from typing import Callable, Dict, Tuple - -import jax - - -def concat_dictionaries(a: Dict, b: Dict) -> Dict: - """ - Append one dictionary below another. If duplicate keys exist, then the - key-value pair of the second supplied dictionary will be used. - - Args: - a (Dict): The first dictionary. - b (Dict): The second dictionary. - - Returns: - Dict: The merged dictionary. - """ - return {**a, **b} - - -def merge_dictionaries(base_dict: Dict, in_dict: Dict) -> Dict: - """ - This will return a complete dictionary based on the keys of the first - matrix. If the same key should exist in the second matrix, then the - key-value pair from the first dictionary will be overwritten. The purpose of - this is that the base_dict will be a complete dictionary of values such that - an incomplete second dictionary can be used to update specific key-value - pairs. - - Args: - base_dict (Dict): Complete dictionary of key-value pairs. - in_dict (Dict): Subset of key-values pairs such that values from this - dictionary will take precedent. - - Returns: - Dict: A dictionary with the same keys as the base_dict, but with - values from the in_dict. - """ - for k, _ in base_dict.items(): - if k in in_dict.keys(): - base_dict[k] = in_dict[k] - return base_dict - - -def sort_dictionary(base_dict: Dict) -> Dict: - """ - Sort a dictionary based on the dictionary's key values. - - Args: - base_dict (Dict): The dictionary to be sorted. - - Returns: - Dict: The dictionary sorted alphabetically on the dictionary's keys. 
- """ - return dict(sorted(base_dict.items())) - - -def dict_array_coercion(params: Dict) -> Tuple[Callable, Callable]: - """ - Construct the logic required to map a dictionary of parameters to an array - of parameters. The values of the dictionary can themselves be dictionaries; - the function should work recursively. - - Args: - params (Dict): The dictionary of parameters that we would like to map - into an array. - - Returns: - Tuple[Callable, Callable]: A pair of functions, the first of which maps - a dictionary to an array, and the second of which maps an array to a - dictionary. The remapped dictionary is equal in structure to the original - dictionary. - """ - flattened_pytree = jax.tree_util.tree_flatten(params) - - def dict_to_array(parameter_dict) -> jax.Array: - return jax.tree_util.tree_flatten(parameter_dict)[0] - - def array_to_dict(parameter_array) -> Dict: - return jax.tree_util.tree_unflatten(flattened_pytree[1], parameter_array) - - return dict_to_array, array_to_dict - - -__all__ = [ - "concat_dictionaries", - "merge_dictionaries", - "sort_dictionary", - "dict_array_coercion", -] diff --git a/jaxutils/fit.py b/jaxutils/fit.py new file mode 100644 index 0000000..0e89908 --- /dev/null +++ b/jaxutils/fit.py @@ -0,0 +1,199 @@ +# Copyright 2023 The JaxGaussianProcesses Contributors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from typing import Callable, Optional + +import jax +import jax.random as jr +import optax as ox + +from jax.random import KeyArray +from jax._src.random import _check_prng_key +from jaxtyping import Array, Float +from typing import Any + +from .parameters import Parameters +from .dataset import Dataset +from .scan import vscan + + +def fit( + *, + objective, + train_data: Dataset, + optim: ox.GradientTransformation, + params: Parameters = None, + fn: Callable[[Parameters, Dataset], Float[Array, "1"]] = None, + num_iters: Optional[int] = 100, + batch_size: Optional[int] = -1, + key: Optional[KeyArray] = jr.PRNGKey(42), + log_rate: Optional[int] = 10, + verbose: Optional[bool] = True, + unroll: int = 1, +) -> Parameters: + """Train a Module model with respect to a supplied Objective function. Optimisers + used here should originate from Optax. + + Args: + params (Parameters): The parameters to be optimised. + objective (Callable[[Parameters, Dataset], Float[Array, "1"]]): The objective + function that we are optimising with respect to. + train_data (Dataset): The training data to be used for the optimisation. + optim (GradientTransformation): The Optax optimiser that is to be used for + learning a parameter set. + num_iters (Optional[int]): The number of optimisation steps to run. Defaults to + 100. + batch_size (Optional[int]): The size of the mini-batch to use. Defaults to -1 + (i.e. full batch). + key (Optional[KeyArray]): The random key to use for the optimisation batch + selection. Defaults to jr.PRNGKey(42). 
+        log_rate (Optional[int]): How frequently the objective function's value should
+            be printed. Defaults to 10.
+        verbose (Optional[bool]): Whether to display the training progress bar. Defaults
+            to True.
+        unroll (int): The number of unrolled steps to use for the optimisation.
+            Defaults to 1.
+
+    Returns:
+        Parameters: The optimised parameters, with the training history stored on
+            the returned object.
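+
+    Example:
+        A usage sketch (assuming an `objective` exposing `init_params` and
+        `step`, and a training dataset `D`):
+
+        >>> import optax as ox
+        >>> optim = ox.adam(learning_rate=0.01)
+        >>> params = fit(objective=objective, train_data=D, optim=optim, num_iters=100)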
+    """
+    if params is None:
+        params = objective.init_params(key)
+
+    if fn is None:
+        fn = jax.jit(objective.step)
+
+    # Check inputs.
+    _check_train_data(train_data)
+    _check_optim(optim)
+    _check_num_iters(num_iters)
+    _check_batch_size(batch_size)
+    _check_prng_key(key)
+    _check_log_rate(log_rate)
+    _check_verbose(verbose)
+
+    # Unconstrained space loss fn. with stop-gradient rule for non-trainable params.
+    def loss(params: Parameters, batch: Dataset) -> Float[Array, "1"]:
+        params = params.stop_gradients()
+        return fn(params.constrain(), batch)
+
+    # Unconstrained space params.
+    params = params.unconstrain()
+
+    # Initialise optimiser state.
+    state = optim.init(params)
+
+    # Mini-batch random keys to scan over.
+    iter_keys = jr.split(key, num_iters)
+
+    # Optimisation step.
+    def step(carry, key):
+        params, opt_state = carry
+
+        if batch_size != -1:
+            batch = get_batch(train_data, batch_size, key)
+        else:
+            batch = train_data
+
+        loss_val, loss_gradient = jax.value_and_grad(loss)(params, batch)
+        updates, opt_state = optim.update(loss_gradient, opt_state, params)
+        params = ox.apply_updates(params, updates)
+
+        carry = params, opt_state
+        return carry, loss_val
+
+    # Optimisation scan.
+    scan = vscan if verbose else jax.lax.scan
+
+    # Optimisation loop.
+    (params, _), history = scan(step, (params, state), iter_keys, unroll=unroll)
+
+    # Constrained space.
+    params = params.constrain()
+    params = params.update_training_history(history)
+
+    return params
+
+
+def get_batch(train_data: Dataset, batch_size: int, key: KeyArray) -> Dataset:
+    """Batch the data into mini-batches. Sampling is done with replacement.
+
+    Args:
+        train_data (Dataset): The training dataset.
+        batch_size (int): The batch size.
+        key (KeyArray): The random key to use for the batch selection.
+
+    Returns:
+        Dataset: The batched dataset.
+    """
+    x, y, n = train_data.X, train_data.y, train_data.n
+
+    # Subsample mini-batch indices with replacement.
+    indices = jr.choice(key, n, (batch_size,), replace=True)
+
+    return Dataset(X=x[indices], y=y[indices])
+
+
+def _check_train_data(train_data: Any) -> None:
+    """Check that the train_data is of type Dataset."""
+    if not isinstance(train_data, Dataset):
+        raise TypeError("train_data must be of type jaxutils.Dataset")
+
+
+def _check_optim(optim: Any) -> None:
+    """Check that the optimiser is of type GradientTransformation."""
+    if not isinstance(optim, ox.GradientTransformation):
+        raise TypeError("optim must be of type optax.GradientTransformation")
+
+
+def _check_num_iters(num_iters: Any) -> None:
+    """Check that the number of iterations is of type int and positive."""
+    if not isinstance(num_iters, int):
+        raise TypeError("num_iters must be of type int")
+
+    if not num_iters > 0:
+        raise ValueError("num_iters must be positive")
+
+
+def _check_log_rate(log_rate: Any) -> None:
+    """Check that the log rate is of type int and positive."""
+    if not isinstance(log_rate, int):
+        raise TypeError("log_rate must be of type int")
+
+    if not log_rate > 0:
+        raise ValueError("log_rate must be positive")
+
+
+def _check_verbose(verbose: Any) -> None:
+    """Check that verbose is of type bool."""
+    if not isinstance(verbose, bool):
+        raise TypeError("verbose must be of type bool")
+
+
+def _check_batch_size(batch_size: Any) -> None:
+    """Check that the batch size is an int, and positive unless it is -1 (full batch)."""
+    if not isinstance(batch_size, int):
+        raise TypeError("batch_size must be of type int")
+
+    if not batch_size == -1:
+        if not batch_size > 0:
+            raise ValueError("batch_size must be positive")
+
+
+__all__ = [
+    "fit",
+    "get_batch",
+]
diff --git a/jaxutils/parameters.py b/jaxutils/parameters.py
index b573c0a..f9a6951 100644
--- a/jaxutils/parameters.py
+++ b/jaxutils/parameters.py
@@ -1,4 +1,4 @@
-# Copyright 2022 The GPJax Contributors. All Rights Reserved.
+# Copyright 2023 The JaxGaussianProcesses Contributors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,408 +13,322 @@
 # limitations under the License.
 # ==============================================================================
 
-import warnings
-from copy import deepcopy
-from typing import Dict, Tuple
-from warnings import warn
+from __future__ import annotations
 
-import distrax as dx
+import jax.tree_util as jtu
 import jax
 import jax.numpy as jnp
-import jax.random as jr
-from jax.random import KeyArray
-from jaxtyping import Array, Float
-from jaxutils import PyTree
+from typing import Any, Callable, Dict
+from .bijectors import Identity
+from simple_pytree import Pytree, static_field
+from jax.tree_util import tree_flatten, tree_structure
+from jaxtyping import Float, Array
+from distrax import Bijector, Distribution
+from collections.abc import KeysView, ValuesView, ItemsView
 
-from .config import Identity, get_global_config
-from .dict import merge_dictionaries
-
-################################
-# Base operations
-################################
-class ParameterState(PyTree):
+class Parameters(Pytree, dict):
     """
     The state of the model. This includes the parameter set, which parameters
     are to be trained and bijectors that allow parameters to be constrained and
     unconstrained.
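+
+    Example (a usage sketch; assumes the default `Identity` bijectors):
+        >>> import jax.numpy as jnp
+        >>> params = Parameters(params={"weight": jnp.array([1.0])})
+        >>> unconstrained = params.unconstrain()
+        >>> constrained = unconstrained.constrain()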
""" - def __init__(self, params: Dict, trainables: Dict, bijectors: Dict) -> None: - self.params = params - self.trainables = trainables - self.bijectors = bijectors - - def unpack(self): - """Unpack the state into a tuple of parameters, trainables and bijectors. - - Returns: - Tuple[Dict, Dict, Dict]: The parameters, trainables and bijectors. - """ - return self.params, self.trainables, self.bijectors - - -def initialise(model, key: KeyArray = None, **kwargs) -> ParameterState: - """ - Initialise the stateful parameters of any GPJax object. This function also - returns the trainability status of each parameter and set of bijectors that - allow parameters to be constrained and unconstrained. - - Args: - model: The GPJax object that is to be initialised. - key (KeyArray, optional): The random key that is to be used for - initialisation. Defaults to None. - - Returns: - ParameterState: The state of the model. This includes the parameter - set, which parameters are to be trained and bijectors that allow - parameters to be constrained and unconstrained. - """ - - if key is None: - warn("No PRNGKey specified. Defaulting to seed 123.", UserWarning, stacklevel=2) - key = jr.PRNGKey(123) - - # Initialise the parameters. - if hasattr(model, "init_params"): - params = model.init_params(key) - - elif hasattr(model, "_initialise_params"): - warn( - "`_initialise_params` is deprecated. Please use `init_params` instead.", - DeprecationWarning, - stacklevel=2, + _param_dict: Dict + _bijector_dict: Dict = static_field() + _trainable_dict: Dict = static_field() + _training_history: list = static_field() + + def __init__( + self, + # TODO: Should block inplace updates e.g., `params.params['a']=jnp.array([2.])` + params: Dict, + bijectors: Dict = None, + trainables: Dict = None, + priors: Dict = None, + training_history=None, + ): + + if bijectors is None: + bijectors = jtu.tree_map(lambda _: Identity, params) + + if trainables is None: + trainables = jtu.tree_map(lambda _: True, params) + + if priors is None: + priors = jtu.tree_map(lambda _: None, params) + + self._param_dict = params + self._trainable_dict = trainables + self._bijector_dict = bijectors + self._prior_dict = priors + self._training_history = training_history + + def __repr__(self) -> str: + return f"Parameters({self.params.__repr__()})" + + def __getitem__(self, __name: str) -> Any: + return self._param_dict.__getitem__(__name) + + def __setitem__(self, __name: str, __value: Any) -> None: + return self._param_dict.__setitem__(__name, __value) + + def __eq__(self, other: Parameters) -> bool: + return self.params == other.params + + @property + def params(self) -> Dict: + return self._param_dict + + def update_params(self, value: Dict) -> Parameters: + self._validate_update(value, self.params, "params") + return Parameters( + value, + self.bijectors, + self.trainables, + self.priors, + self.training_history, ) - params = model._initialise_params(key) - - else: - raise AttributeError("No `init_params` or `_initialise_params` method found.") - - if kwargs: - _validate_kwargs(kwargs, params) - for k, v in kwargs.items(): - params[k] = merge_dictionaries(params[k], v) - - bijectors = build_bijectors(params) - trainables = build_trainables(params) - - return ParameterState( - params=params, - trainables=trainables, - bijectors=bijectors, - ) - -def _validate_kwargs(kwargs, params): - for k, v in kwargs.items(): - if k not in params.keys(): - raise ValueError(f"Parameter {k} is not a valid parameter.") - - -def recursive_items(d1: Dict, d2: Dict): 
- """ - Recursive loop over pair of dictionaries whereby the value of a given key in - either dictionary can be itself a dictionary. - - Args: - d1 (_type_): _description_ - d2 (_type_): _description_ - - Yields: - _type_: _description_ - """ - for key, value in d1.items(): - if type(value) is dict: - yield from recursive_items(value, d2[key]) - else: - yield (key, value, d2[key]) + @staticmethod + def _validate_update( + value: dict, + comparison: dict, + name: str, + lambda_expression: Callable[[Any], bool] = None, + ): + if tree_structure(comparison, lambda_expression) != tree_structure( + value, lambda_expression + ): + raise ValueError( + f"The structure of the {name} has changed. Please ensure" + f" updates to {name} do not alter the strcuture." + ) + + @property + def bijectors(self) -> Dict: + return self._bijector_dict + + def update_bijectors(self, value: Dict) -> Parameters: + self._validate_update( + value, + self.bijectors, + "bijectors", + lambda x: isinstance(x, Bijector), + ) + return Parameters( + self.params, + value, + self.trainables, + self.priors, + self.training_history, + ) + @property + def trainables(self) -> Dict: + return self._trainable_dict + + def update_trainables(self, value: Dict) -> Parameters: + self._validate_update(value, self.trainables, "trainables") + return Parameters( + self.params, + self.bijectors, + value, + self.priors, + self.training_history, + ) -def recursive_complete(d1: Dict, d2: Dict) -> Dict: - """ - Recursive loop over pair of dictionaries whereby the value of a given key in - either dictionary can be itself a dictionary. If the value of the key in the - second dictionary is None, the value of the key in the first dictionary is - used. + @property + def priors(self) -> Dict: + return self._prior_dict - Args: - d1 (Dict): The reference dictionary. - d2 (Dict): The potentially incomplete dictionary. + def update_priors(self, value: Dict) -> Parameters: + self._validate_update( + value, + self.priors, + "priors", + lambda x: isinstance(x, Distribution), + ) + return Parameters( + self.params, + self.bijectors, + self.trainables, + value, + self.training_history, + ) - Returns: - Dict: A completed form of the second dictionary. - """ - for key, value in d1.items(): - if type(value) is dict: - if key in d2.keys(): - recursive_complete(value, d2[key]) - else: - if key in d2.keys(): - d1[key] = d2[key] - return d1 + @property + def training_history(self) -> list: + return self._training_history + + def update_training_history(self, value: list) -> Parameters: + return Parameters( + self.params, + self.bijectors, + self.trainables, + self.priors, + value, + ) + def unpack( + self, + ) -> Dict[str, Dict[str, Any]]: + """Unpack the state into a tuple of parameters, trainables and bijectors. -################################ -# Parameter transformation -################################ -def build_bijectors(params: Dict) -> Dict: - """ - For each parameter, build the bijection pair that allows the parameter to be - constrained and unconstrained. + Returns: + Dict[str, Dict[str, Any]]: The parameters, trainables and bijectors. + """ + contents = { + "params": self.params, + "trainables": self.trainables, + "bijectors": self.bijectors, + "priors": self.priors, + } + return contents - Args: - params (Dict): _description_ + def constrain(self) -> Parameters: + """Use the bijectors to transform the parameters to a constrained space. - Returns: - Dict: A dictionary that maps each parameter to a bijection. 
- """ - bijectors = copy_dict_structure(params) - config = get_global_config() - transform_set = config["transformations"] + Returns: + Parameters: A new Parameters object with the constrained parameter values. + """ + return self.update_params( + jtu.tree_map( + lambda param, trans: trans.forward(param), + self.params, + self.bijectors, + ) + ) - def recursive_bijectors_list(ps, bs): - return [recursive_bijectors(ps[i], bs[i]) for i in range(len(bs))] + def unconstrain(self) -> Parameters: + """Use the bijectors to transform the parameters to an unconstrained space. - def recursive_bijectors(ps, bs) -> Tuple[Dict, Dict]: - if type(ps) is list: - bs = recursive_bijectors_list(ps, bs) + Returns: + Parameters: A new Parameters object with the unconstrained parameter + values. + """ + return self.update_params( + jtu.tree_map( + lambda param, trans: trans.inverse(param), + self.params, + self.bijectors, + ) + ) + def add_parameter( + self, + key: str, + *, + parameter: Parameters = None, + value: jax.Array = None, + prior: Distribution = None, + bijector: Bijector = Identity, + trainability: bool = True, + ) -> None: + """Add a parameter to the Parameters object. + + Args: + key (str): The name of the parameter. + value (jax.Array): The value of the parameter. + prior (Distribution): The prior distribution of the parameter. + bijector (Bijector): The bijector to transform the parameter. + trainability (bool): The trainability of the parameter. + """ + if key in self.keys(): + raise ValueError(f"Parameter with key: {key} already exists.") else: - for key, value in ps.items(): - if type(value) is dict: - recursive_bijectors(value, bs[key]) - elif type(value) is list: - bs[key] = recursive_bijectors_list(value, bs[key]) - else: - if key in transform_set.keys(): - transform_type = transform_set[key] - bijector = transform_set[transform_type] - else: - bijector = Identity - warnings.warn( - f"Parameter {key} has no transform. Defaulting to identity transfom." - ) - bs[key] = bijector - return bs - - return recursive_bijectors(params, bijectors) - - -def constrain(params: Dict, bijectors: Dict) -> Dict: - """ - Transform the parameters to the constrained space for corresponding - bijectors. - - Args: - params (Dict): The parameters that are to be transformed. - bijectors (Dict): The bijectors that are to be used for - transformation. - - Returns: - Dict: A transformed parameter set. The dictionary is equal in - structure to the input params dictionary. - """ - map = lambda param, trans: trans.forward(param) - - return jax.tree_util.tree_map(map, params, bijectors) - - -def unconstrain(params: Dict, bijectors: Dict) -> Dict: - """Transform the parameters to the unconstrained space for corresponding - bijectors. - - Args: - params (Dict): The parameters that are to be transformed. - bijectors (Dict): The corresponding dictionary of transforms that - should be applied to the parameter set. - - Returns: - Dict: A transformed parameter set. The dictionary is equal in - structure to the input params dictionary. - """ - - map = lambda param, trans: trans.inverse(param) - - return jax.tree_util.tree_map(map, params, bijectors) - - -################################ -# Priors -################################ -def log_density( - param: Float[Array, "D"], density: dx.Distribution -) -> Float[Array, "1"]: - """Compute the log density of a parameter given a distribution. - - Args: - param (Float[Array, "D"]): The parameter that is to be evaluated. 
- density (dx.Distribution): The distribution that is to be evaluated. - - Returns: - Float[Array, "1"]: The log density of the parameter. - """ - if type(density) == type(None): - log_prob = jnp.array(0.0) - else: - log_prob = jnp.sum(density.log_prob(param)) - return log_prob - - -def copy_dict_structure(params: Dict) -> Dict: - """Copy the structure of a dictionary. - - Args: - params (Dict): The dictionary that is to be copied. - - Returns: - Dict: A copy of the input dictionary. - """ - # Copy dictionary structure - prior_container = deepcopy(params) - # Set all values to zero - prior_container = jax.tree_util.tree_map(lambda _: None, prior_container) - return prior_container - - -def structure_priors(params: Dict, priors: Dict) -> Dict: - """First create a dictionary with equal structure to the parameters. - Then, for each supplied prior, overwrite the None value if it exists. - - Args: - params (Dict): [description] - priors (Dict): [description] - - Returns: - Dict: [description] - """ - prior_container = copy_dict_structure(params) - # Where a prior has been supplied, override the None value by the prior distribution. - complete_prior = recursive_complete(prior_container, priors) - return complete_prior - - -def evaluate_priors(params: Dict, priors: Dict) -> Dict: - """ - Recursive loop over pair of dictionaries that correspond to a parameter's - current value and the parameter's respective prior distribution. For - parameters where a prior distribution is specified, the log-prior density is - evaluated at the parameter's current value. - - Args: params (Dict): Dictionary containing the current set of parameter - estimates. priors (Dict): Dictionary specifying the parameters' prior - distributions. - - Returns: - Dict: The log-prior density, summed over all parameters. - """ - lpd = jnp.array(0.0) - if priors is not None: - for name, param, prior in recursive_items(params, priors): - lpd += log_density(param, prior) - return lpd - + if parameter is None: + self.params[key] = value + self.priors[key] = prior + self.bijectors[key] = bijector + self.trainables[key] = trainability + else: + contents = parameter.unpack() + self.params[key] = contents["params"] + self.priors[key] = contents["priors"] + self.bijectors[key] = contents["bijectors"] + self.trainables[key] = contents["trainables"] + + def stop_gradients(self) -> Parameters: + def _stop_grad(param: Dict, trainable: Dict) -> Dict: + return jax.lax.cond(trainable, lambda x: x, jax.lax.stop_gradient, param) + + return self.update_params( + jtu.tree_map( + lambda param, trainable: _stop_grad(param, trainable), + self.params, + self.trainables, + ) + ) -def prior_checks(priors: Dict) -> Dict: - """ - Run checks on the parameters' prior distributions. This checks that for - Gaussian processes that are constructed with non-conjugate likelihoods, the - prior distribution on the function's latent values is a unit Gaussian. + def items(self) -> ItemsView: + """Return the items of the parameters.""" + return self.params.items() - Args: - priors (Dict): Dictionary specifying the parameters' prior distributions. + def keys(self) -> KeysView: + """Return the keys of the parameters.""" + return self.params.keys() - Returns: - Dict: Dictionary specifying the parameters' prior distributions. 
- """ - if "latent" in priors.keys(): - latent_prior = priors["latent"] - if latent_prior is not None: - if not isinstance(latent_prior, dx.Normal): - warnings.warn( - f"A {type(latent_prior)} distribution prior has been placed on" - " the latent function. It is strongly advised that a" - " unit Gaussian prior is used." - ) - else: - warnings.warn("Placing unit Gaussian prior on latent function.") - priors["latent"] = dx.Normal(loc=0.0, scale=1.0) - else: - priors["latent"] = dx.Normal(loc=0.0, scale=1.0) + def values(self) -> ValuesView: + """Return the values of the parameters.""" + return self.params.values() - return priors + def log_prior_density(self) -> Array[Float, "1"]: + """ + Recursive loop over pair of dictionaries that correspond to a parameter's + current value and the parameter's respective prior distribution. For + parameters where a prior distribution is specified, the log-prior density is + evaluated at the parameter's current value. + Args: params (Dict): Dictionary containing the current set of parameter + estimates. priors (Dict): Dictionary specifying the parameters' prior + distributions. -def build_trainables(params: Dict, status: bool = True) -> Dict: - """ - Construct a dictionary of trainable statuses for each parameter. By default, - every parameter within the model is trainable. + Returns: + Dict: The log-prior density, summed over all parameters. + """ - Args: - params (Dict): The parameter set for which trainable statuses should be - derived from. - status (bool): The status of each parameter. Default is True. + def log_density(param, prior): + # TODO: Use a jax.lax.cond be used here? + if prior is not None: + return jnp.sum(prior.log_prob(param)) + else: + return jnp.array(0.0) - Returns: - Dict: A dictionary of boolean trainability statuses. The dictionary is - equal in structure to the input params dictionary. - """ - # Copy dictionary structure - prior_container = deepcopy(params) - # Set all values to zero - prior_container = jax.tree_util.tree_map(lambda _: status, prior_container) - return prior_container + log_prior_density_dict = jtu.tree_map(log_density, self.params, self.priors) + leaves, _ = tree_flatten(log_prior_density_dict) + return sum(leaves) + def combine(self, other: Parameters, left_key: str, right_key: str) -> Parameters: + """Combine two sets of parameters into a single set of parameters. -def _stop_grad(param: Dict, trainable: Dict) -> Dict: - """ - When taking a gradient, we want to stop the gradient from flowing through a - parameter if it is not trainable. This is achieved using the model's - dictionary of parameters and the corresponding trainability status. + Args: + other (Parameters): The other set of parameters. + left_key (str): The key to use for the left (i.e., `self`) set of + parameters. + right_key (str): The key to use for the right (i.e., `other`) set of + parameters - Args: - param (Dict): The parameter set for which trainable statuses should be - derived from. - trainable (Dict): A boolean value denoting the training status the `param`. + Returns: + Parameters: A nested set of parameters. + """ - Returns: - Dict: The gradient is stopped for non-trainable parameters. 
- """ - return jax.lax.cond(trainable, lambda x: x, jax.lax.stop_gradient, param) + self_contents = self.unpack() + other_contents = other.unpack() + + combined_contents = {} + for k in self_contents.keys(): + combined_contents[k] = { + left_key: self_contents[k], + right_key: other_contents[k], + } + + return Parameters( + params=combined_contents["params"], + bijectors=combined_contents["bijectors"], + trainables=combined_contents["trainables"], + priors=combined_contents["priors"], + ) -def trainable_params(params: Dict, trainables: Dict) -> Dict: - """ - Stop the gradients flowing through parameters whose trainable status is - False. - - Args: - params (Dict): The parameter set for which trainable statuses should - be derived from. - trainables (Dict): A dictionary of boolean trainability statuses. The - dictionary is equal in structure to the input params dictionary. - - Returns: - Dict: A dictionary parameters. The dictionary is equal in structure to - the input params dictionary. - """ - return jax.tree_util.tree_map( - lambda param, trainable: _stop_grad(param, trainable), params, trainables - ) - - -__all__ = [ - "ParameterState", - "initialise", - "recursive_items", - "recursive_complete", - "build_bijectors", - "constrain", - "unconstrain", - "log_density", - "copy_dict_structure", - "structure_priors", - "evaluate_priors", - "prior_checks", - "build_trainables", - "trainable_params", -] +__all__ = ["Parameters"] diff --git a/jaxutils/pytree.py b/jaxutils/pytree.py index d6677b2..e6578c1 100644 --- a/jaxutils/pytree.py +++ b/jaxutils/pytree.py @@ -15,12 +15,15 @@ import abc import jax - from typing import Any +# TODO: To drop this in place of simple_pytree's Pytree. + class PyTree(metaclass=abc.ABCMeta): - """An abstract base class for a JAX compatible pytree. Adapted from `distrax._src.utils.jittable.Jittable`.""" + """An abstract base class for a JAX compatible pytree. Adapted from + `distrax._src.utils.jittable.Jittable`. + """ def __new__(cls, *args, **kwargs): # Discard the parameters to this function because the constructor is not diff --git a/jaxutils/scan.py b/jaxutils/scan.py new file mode 100644 index 0000000..2119648 --- /dev/null +++ b/jaxutils/scan.py @@ -0,0 +1,164 @@ +# Copyright 2023 The JaxGaussianProcesses Contributors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from typing import Callable, List, Optional, Tuple, TypeVar, Any +from jax import lax +from jax.experimental import host_callback as hcb +from tqdm.auto import trange + +import jax.tree_util as jtu +import jax +import jax.numpy as jnp + +Carry = TypeVar("Carry") +X = TypeVar("X") +Y = TypeVar("Y") + + +def _callback(cond: bool, func: Callable, *args: Any) -> None: + """Callback a function for a given argument if a condition is true. + + Args: + cond (bool): The condition. + func (Callable): The function to call. + *args (Any): The arguments to pass to the function. 
+ """ + + # lax.cond requires a result, so we use a dummy result. + _dummy_result = 0 + + def _do_callback(_) -> int: + """Perform the callback.""" + return hcb.id_tap(func, *args, result=_dummy_result) + + def _not_callback(_) -> int: + """Do nothing.""" + return _dummy_result + + _ = lax.cond(cond, _do_callback, _not_callback, operand=None) + + +def vscan( + f: Callable[[Carry, X], Tuple[Carry, Y]], + init: Carry, + xs: X, + length: Optional[int] = None, + reverse: Optional[bool] = False, + unroll: Optional[int] = 1, + log_rate: Optional[int] = 10, + log_value: Optional[bool] = True, +) -> Tuple[Carry, List[Y]]: + """Scan with verbose output. + + This is based on code from the excellent blog post: + https://www.jeremiecoullon.com/2021/01/29/jax_progress_bar/. + + Example: + >>> def f(carry, x): + ... return carry + x, carry + x + >>> init = 0 + >>> xs = jnp.arange(10) + >>> vscan(f, init, xs) + (45, DeviceArray([ 0, 1, 3, 6, 10, 15, 21, 28, 36, 45], dtype=int32)) + + Args: + f (Callable[[Carry, X], Tuple[Carry, Y]]): A function that takes in a carry and + an input and returns a tuple of a new carry and an output. + init (Carry): The initial carry. + xs (X): The inputs. + length (Optional[int]): The length of the inputs. If None, then the length of + the inputs is inferred. + reverse (bool): Whether to scan in reverse. + unroll (int): The number of iterations to unroll. + log_rate (int): The rate at which to log the progress bar. + log_value (bool): Whether to log the value of the objective function. + + Returns: + Tuple[Carry, List[Y]]: A tuple of the final carry and the outputs. + """ + + # TODO: Scope out lower level API for jax.lax.scan, to avoid the need for finding + # the length of the inputs / check inputs. + # TODO: Scope out lower level API for tqdm, for more control over the progress bar. + # Need to check this. + _xs_flat = jtu.tree_leaves(xs) + _length = length if length is not None else len(_xs_flat[0]) + _iter_nums = jnp.arange(_length) + _remainder = _length % log_rate + + _progress_bar = trange(_length) + _progress_bar.set_description("Compiling...", refresh=True) + + def _set_running(args: Any, transform: Any) -> None: + """Set the tqdm progress bar to running.""" + _progress_bar.set_description("Running", refresh=False) + + def _update_tqdm(args: Any, transform: Any) -> None: + """Update the tqdm progress bar with the latest objective value.""" + _value, _iter_num = args + _progress_bar.update(_iter_num) + + if log_value and _value is not None: + _progress_bar.set_postfix({"Value": f"{_value: .2f}"}) + + def _close_tqdm(args: Any, transform: Any) -> None: + """Close the tqdm progress bar.""" + _progress_bar.close() + + def _body_fun(carry: Carry, iter_num_and_x: Tuple[int, X]) -> Tuple[Carry, Y]: + + # Unpack iter_num and x. + iter_num, x = iter_num_and_x + + # Compute body function. + carry, y = f(carry, x) + + # Conditions for iteration number. + _is_first: bool = iter_num == 0 + _is_multiple: bool = (iter_num % log_rate == 0) & ( + iter_num != _length - _remainder + ) + _is_remainder: bool = iter_num == _length - _remainder + _is_last: bool = iter_num == _length - 1 + + # Update progress bar, if first of log_rate. + _callback(_is_first, _set_running, (y, log_rate)) + + # Update progress bar, if multiple of log_rate. + _callback(_is_multiple, _update_tqdm, (y, log_rate)) + + # Update progress bar, if remainder. + _callback(_is_remainder, _update_tqdm, (y, _remainder)) + + # Close progress bar, if last iteration. 
+ _callback(_is_last, _close_tqdm, (y, None)) + + return carry, y + + carry, ys = jax.lax.scan( + _body_fun, + init, + (_iter_nums, xs), + length=length, + reverse=reverse, + unroll=unroll, + ) + + return carry, ys + + +__all__ = [ + "vscan", +] diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..a6c31d2 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,4 @@ +[tool.ruff] +line-length = 88 +update-check = false +ignore = ["F722"] \ No newline at end of file diff --git a/setup.cfg b/setup.cfg index c92efe6..f59799d 100644 --- a/setup.cfg +++ b/setup.cfg @@ -9,3 +9,6 @@ tag_prefix = v exclude = versioneer.py jaxutils/_version.py + +ignore = + F722 diff --git a/setup.py b/setup.py index dd69fb9..e22782d 100644 --- a/setup.py +++ b/setup.py @@ -38,12 +38,14 @@ def get_versions(): versioneer.get_versions = get_versions -REQUIRES = ["jax>=0.4.0", - "jaxlib>=0.4.0", - "jaxtyping", - "ml-collections==0.1.0", - "distrax>=0.1.2", - ] +REQUIRES = [ + "jax>=0.4.0", + "jaxlib>=0.4.0", + "jaxtyping", + "optax", + "tqdm", + "distrax", +] EXTRAS = { "dev": [ diff --git a/tests/test_config.py b/tests/test_config.py deleted file mode 100644 index 08ceaf1..0000000 --- a/tests/test_config.py +++ /dev/null @@ -1,76 +0,0 @@ -# Copyright 2022 The GPJax Contributors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -import jax -import distrax as dx -from jax.config import config -from ml_collections import ConfigDict - -from jaxutils.config import ( - Identity, - add_parameter, - get_global_config, - get_global_config_if_exists, # ignore: unused-import -) - -# Enable Float64 for more stable matrix inversions. -config.update("jax_enable_x64", True) - -# TODO: Fix this test. -# This test needs to be run first to ensure that the global config is not set on library import. 
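The `_callback` helper in `jaxutils/scan.py` above gates a host-side function behind `jax.lax.cond`, so it only fires on selected iterations of a traced loop. Below is a minimal standalone sketch of the same pattern, using the `host_callback.id_tap` API that `scan.py` itself imports (note that `host_callback` has since been deprecated in newer JAX releases); the helper names are illustrative:

```python
import jax.numpy as jnp
from jax import lax
from jax.experimental import host_callback as hcb


def tap_every_10(i):
    """Fire a host-side print from inside a traced loop, every 10th step."""

    def _tap(arg, _transform):
        print(f"step {arg}")

    # Both branches of lax.cond must return the same type, hence the dummy 0
    # threaded through id_tap's `result` argument.
    _ = lax.cond(
        i % 10 == 0,
        lambda _: hcb.id_tap(_tap, i, result=0),
        lambda _: 0,
        None,
    )


def body(carry, x):
    tap_every_10(x)
    return carry + x, carry


carry, _ = lax.scan(body, jnp.int32(0), jnp.arange(30))
```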
-# def test_config_on_library_import(): -# config = get_global_config_if_exists() -# assert config is None - - -def test_add_parameter(): - add_parameter("test_parameter", Identity) - config = get_global_config() - assert "test_parameter" in config.transformations - assert "test_parameter_transform" in config.transformations - assert config.transformations["test_parameter"] == "test_parameter_transform" - assert isinstance(config.transformations["test_parameter_transform"], dx.Bijector) - - -def test_add_parameter(): - config = get_global_config() - add_parameter("test_parameter", Identity) - config = get_global_config() - assert "test_parameter" in config.transformations - assert "test_parameter_transform" in config.transformations - assert config.transformations["test_parameter"] == "test_parameter_transform" - assert isinstance(config.transformations["test_parameter_transform"], dx.Bijector) - - -def test_get_global_config(): - config = get_global_config() - assert isinstance(config, ConfigDict) - assert isinstance(config.transformations, ConfigDict) - - -def test_x64_based_config_update(): - cached_jax_precision = jax.config.x64_enabled - - jax.config.update("jax_enable_x64", True) - config = get_global_config() - assert config.x64_state is True - - jax.config.update("jax_enable_x64", False) - config = get_global_config() - assert config.x64_state is False - - # Reset the JAX precision to the original value. - jax.config.update("jax_enable_x64", cached_jax_precision) - get_global_config() diff --git a/tests/test_data.py b/tests/test_dataset.py similarity index 88% rename from tests/test_data.py rename to tests/test_dataset.py index 3ed45ba..58a7688 100644 --- a/tests/test_data.py +++ b/tests/test_dataset.py @@ -15,7 +15,7 @@ import jax.numpy as jnp import pytest -from jaxutils.data import Dataset, verify_dataset +from jaxutils.dataset import Dataset @pytest.mark.parametrize("n", [1, 10]) @@ -26,14 +26,15 @@ def test_dataset(n: int, outd: int, ind: int, n2: int) -> None: x = jnp.ones((n, ind)) y = jnp.ones((n, outd)) d = Dataset(X=x, y=y) - - verify_dataset(d) + assert d.n == n assert d.in_dim == ind assert d.out_dim == outd - - assert d.__repr__() == f"- Number of datapoints: {n}\n- Dimension: {ind}" - + assert ( + d.__repr__() + == f"- Number of observations: {n}\n- Input dimension: {ind}\n- Output" + f" dimension: {outd}" + ) # Test combine datasets. 
x2 = 2 * jnp.ones((n2, ind)) @@ -54,37 +55,36 @@ def test_dataset(n: int, outd: int, ind: int, n2: int) -> None: dunsup = Dataset(y=y) assert dunsup.is_unsupervised() is True + @pytest.mark.parametrize("nx, ny", [(1, 2), (2, 1), (10, 5), (5, 10)]) @pytest.mark.parametrize("outd", [1, 2, 10]) @pytest.mark.parametrize("ind", [1, 2, 10]) def test_dataset_assertions(nx: int, ny: int, outd: int, ind: int) -> None: x = jnp.ones((nx, ind)) y = jnp.ones((ny, outd)) - + with pytest.raises(ValueError): - ds = Dataset(X=x, y=y) + Dataset(X=x, y=y) + @pytest.mark.parametrize("n", [1, 2, 10]) @pytest.mark.parametrize("outd", [1, 2, 10]) @pytest.mark.parametrize("ind", [1, 2, 10]) def test_2d_inputs(n: int, outd: int, ind: int) -> None: x = jnp.ones((n, ind)) - y = jnp.ones((n, )) + y = jnp.ones((n,)) with pytest.raises(ValueError): - ds = Dataset(X=x, y=y) + Dataset(X=x, y=y) x = jnp.ones((n,)) y = jnp.ones((n, outd)) with pytest.raises(ValueError): - ds = Dataset(X=x, y=y) - + Dataset(X=x, y=y) - def test_y_none() -> None: x = jnp.ones((10, 1)) d = Dataset(X=x) - verify_dataset(d) assert d.y is None diff --git a/tests/test_dict.py b/tests/test_dict.py deleted file mode 100644 index cfefc97..0000000 --- a/tests/test_dict.py +++ /dev/null @@ -1,67 +0,0 @@ -# Copyright 2022 The GPJax Contributors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -import jax.numpy as jnp -import pytest -from jax.config import config - -from jaxutils.dict import ( - concat_dictionaries, - dict_array_coercion, - merge_dictionaries, - sort_dictionary, -) - -# Enable Float64 for more stable matrix inversions. -config.update("jax_enable_x64", True) - - -def test_concat_dict(): - d1 = {"a": 1, "b": 2} - d2 = {"c": 3, "d": 4} - d = concat_dictionaries(d1, d2) - assert list(d.keys()) == ["a", "b", "c", "d"] - assert list(d.values()) == [1, 2, 3, 4] - - -def test_merge_dicts(): - d1 = {"a": 1, "b": 2} - d2 = {"b": 3} - d = merge_dictionaries(d1, d2) - assert list(d.keys()) == ["a", "b"] - assert list(d.values()) == [1, 3] - - -def test_sort_dict(): - unsorted = {"b": 1, "a": 2} - sorted_dict = sort_dictionary(unsorted) - assert list(sorted_dict.keys()) == ["a", "b"] - assert list(sorted_dict.values()) == [2, 1] - - -@pytest.mark.parametrize("d", [1, 2, 10]) -def test_array_coercion(d): - params = { - "kernel": { - "lengthscale": jnp.array([1.0] * d), - "variance": jnp.array([1.0]), - }, - "likelihood": {"obs_noise": jnp.array([1.0])}, - "mean_function": {}, - } - dict_to_array, array_to_dict = dict_array_coercion(params) - assert array_to_dict(dict_to_array(params)) == params - assert isinstance(dict_to_array(params), list) - assert isinstance(array_to_dict(dict_to_array(params)), dict) diff --git a/tests/test_fit.py b/tests/test_fit.py new file mode 100644 index 0000000..729cfee --- /dev/null +++ b/tests/test_fit.py @@ -0,0 +1,78 @@ +# Copyright 2023 The JaxGaussianProcesses Contributors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from jaxutils.dataset import Dataset +from jaxutils.fit import fit +from jaxutils.parameters import Parameters +import jax.numpy as jnp +import jax.random as jr +import optax as ox +import abc +from dataclasses import dataclass +from simple_pytree import Pytree, static_field +from jaxtyping import Array, Float +from typing import Any + + +### Base class for objective functions: +@dataclass +class Objective(Pytree): + model: Any = static_field() + + @abc.abstractmethod + def step(self, params: Parameters, train_data: Dataset) -> Float[Array, "1"]: + raise NotImplementedError + + def __call__(self, params: Parameters, train_data: Dataset) -> Float[Array, "1"]: + return self.step(params, train_data) + + +def test_simple_linear_model(): + # (1) Create a dataset: + X = jnp.linspace(0.0, 10.0, 100).reshape(-1, 1) + y = 2.0 * X + 1.0 + 10 * jr.normal(jr.PRNGKey(0), X.shape).reshape(-1, 1) + D = Dataset(X, y) + + # (2) Define your model: + class LinearModel(Pytree): + def __call__(self, params: Parameters, x): + return params["weight"] * x + params["bias"] + + def init_params(self, key): + return Parameters({"weight": 1.0, "bias": 1.0}) + + # (3) Define your objective: + class MeanSquaredError(Objective): + def step(self, params: Parameters, train_data: Dataset) -> Float[Array, "1"]: + return jnp.mean((train_data.y - self.model(params, train_data.X)) ** 2) + + def init_params(self, key): + return self.model.init_params(key) + + model = LinearModel() + objective = MeanSquaredError(model) + params = model.init_params(jr.PRNGKey(0)) + + # (4) Train! + trained_params = fit( + objective=objective, + train_data=D, + optim=ox.sgd(0.001), + num_iters=100, + ) + + assert len(trained_params.training_history) == 100 + assert isinstance(trained_params, Parameters) + assert objective(trained_params, D) < objective(params, D) diff --git a/tests/test_parameters.py b/tests/test_parameters.py index b532b23..02a1efc 100644 --- a/tests/test_parameters.py +++ b/tests/test_parameters.py @@ -1,273 +1,302 @@ -# # Copyright 2022 The GPJax Contributors. All Rights Reserved. -# # -# # Licensed under the Apache License, Version 2.0 (the "License"); -# # you may not use this file except in compliance with the License. -# # You may obtain a copy of the License at -# # -# # http://www.apache.org/licenses/LICENSE-2.0 -# # -# # Unless required by applicable law or agreed to in writing, software -# # distributed under the License is distributed on an "AS IS" BASIS, -# # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# # See the License for the specific language governing permissions and -# # limitations under the License. 
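For orientation, here is a hedged sketch of the training loop that `fit` plausibly wraps around the `Objective` protocol exercised in `test_fit.py`. The diff does not show `jaxutils/fit.py`, so the structure below (a plain `optax` loop over `Parameters.update_params`) is an assumption for illustration, not the library's implementation:

```python
import jax
import optax as ox


def fit_sketch(objective, params, train_data, optim, num_iters):
    """Illustrative only: a plain optax loop over a jaxutils-style objective."""
    opt_state = optim.init(params.params)
    history = []

    for _ in range(num_iters):
        # Differentiate the objective with respect to the raw parameter values.
        loss, grads = jax.value_and_grad(
            lambda p: objective(params.update_params(p), train_data)
        )(params.params)
        updates, opt_state = optim.update(grads, opt_state)
        params = params.update_params(ox.apply_updates(params.params, updates))
        history.append(loss)

    return params, history
```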
-# # ============================================================================== - -# import typing as tp - -# import distrax as dx -# import jax.numpy as jnp -# import jax.random as jr -# import pytest -# from jax.config import config - -# from gpjax.gps import Prior -# from gpjax.kernels import RBF -# from gpjax.likelihoods import Bernoulli, Gaussian -# from gpjax.parameters import ( -# build_bijectors, -# build_trainables, -# constrain, -# copy_dict_structure, -# evaluate_priors, -# initialise, -# log_density, -# prior_checks, -# recursive_complete, -# recursive_items, -# structure_priors, -# unconstrain, -# ) - -# # Enable Float64 for more stable matrix inversions. -# config.update("jax_enable_x64", True) - -# ######################### -# # Test base functionality -# ######################### -# @pytest.mark.parametrize("lik", [Gaussian]) -# def test_initialise(lik): -# key = jr.PRNGKey(123) -# posterior = Prior(kernel=RBF()) * lik(num_datapoints=10) -# params, _, _ = initialise(posterior, key).unpack() -# assert list(sorted(params.keys())) == [ -# "kernel", -# "likelihood", -# "mean_function", -# ] - - -# def test_non_conjugate_initialise(): -# posterior = Prior(kernel=RBF()) * Bernoulli(num_datapoints=10) -# params, _, _ = initialise(posterior, jr.PRNGKey(123)).unpack() -# assert list(sorted(params.keys())) == [ -# "kernel", -# "latent", -# "likelihood", -# "mean_function", -# ] - - -# ######################### -# # Test priors -# ######################### -# @pytest.mark.parametrize("x", [-1.0, 0.0, 1.0]) -# def test_lpd(x): -# val = jnp.array(x) -# dist = dx.Normal(loc=0.0, scale=1.0) -# lpd = log_density(val, dist) -# assert lpd is not None -# assert log_density(val, None) == 0.0 - - -# @pytest.mark.parametrize("lik", [Gaussian, Bernoulli]) -# def test_prior_template(lik): -# posterior = Prior(kernel=RBF()) * lik(num_datapoints=10) -# params, _, _ = initialise(posterior, jr.PRNGKey(123)).unpack() -# prior_container = copy_dict_structure(params) -# for ( -# k, -# v1, -# v2, -# ) in recursive_items(params, prior_container): -# assert v2 == None - - -# @pytest.mark.parametrize("lik", [Gaussian, Bernoulli]) -# def test_recursive_complete(lik): -# posterior = Prior(kernel=RBF()) * lik(num_datapoints=10) -# params, _, _ = initialise(posterior, jr.PRNGKey(123)).unpack() -# priors = {"kernel": {}} -# priors["kernel"]["lengthscale"] = dx.Laplace(loc=0.0, scale=1.0) -# container = copy_dict_structure(params) -# complete_priors = recursive_complete(container, priors) -# for ( -# k, -# v1, -# v2, -# ) in recursive_items(params, complete_priors): -# if k == "lengthscale": -# assert isinstance(v2, dx.Laplace) -# else: -# assert v2 == None - - -# def test_prior_evaluation(): -# """ -# Test the regular setup that every parameter has a corresponding prior distribution attached to its unconstrained -# value. -# """ -# params = { -# "kernel": { -# "lengthscale": jnp.array([1.0]), -# "variance": jnp.array([1.0]), -# }, -# "likelihood": {"obs_noise": jnp.array([1.0])}, -# } -# priors = { -# "kernel": { -# "lengthscale": dx.Gamma(1.0, 1.0), -# "variance": dx.Gamma(2.0, 2.0), -# }, -# "likelihood": {"obs_noise": dx.Gamma(3.0, 3.0)}, -# } -# lpd = evaluate_priors(params, priors) -# assert pytest.approx(lpd) == -2.0110168 - - -# def test_none_prior(): -# """ -# Test that multiple dispatch is working in the case of no priors. 
-# """ -# params = { -# "kernel": { -# "lengthscale": jnp.array([1.0]), -# "variance": jnp.array([1.0]), -# }, -# "likelihood": {"obs_noise": jnp.array([1.0])}, -# } -# priors = copy_dict_structure(params) -# lpd = evaluate_priors(params, priors) -# assert lpd == 0.0 - - -# def test_incomplete_priors(): -# """ -# Test the case where a user specifies priors for some, but not all, parameters. -# """ -# params = { -# "kernel": { -# "lengthscale": jnp.array([1.0]), -# "variance": jnp.array([1.0]), -# }, -# "likelihood": {"obs_noise": jnp.array([1.0])}, -# } -# priors = { -# "kernel": { -# "lengthscale": dx.Gamma(1.0, 1.0), -# "variance": dx.Gamma(2.0, 2.0), -# }, -# } -# container = copy_dict_structure(params) -# complete_priors = recursive_complete(container, priors) -# lpd = evaluate_priors(params, complete_priors) -# assert pytest.approx(lpd) == -1.6137061 - - -# @pytest.mark.parametrize("num_datapoints", [1, 10]) -# def test_checks(num_datapoints): -# incomplete_priors = {"lengthscale": jnp.array([1.0])} -# posterior = Prior(kernel=RBF()) * Bernoulli(num_datapoints=num_datapoints) -# priors = prior_checks(incomplete_priors) -# assert "latent" in priors.keys() -# assert "variance" not in priors.keys() -# assert isinstance(priors["latent"], dx.Normal) - - -# def test_structure_priors(): -# posterior = Prior(kernel=RBF()) * Gaussian(num_datapoints=10) -# params, _, _ = initialise(posterior, jr.PRNGKey(123)).unpack() -# priors = { -# "kernel": { -# "lengthscale": dx.Gamma(1.0, 1.0), -# "variance": dx.Gamma(2.0, 2.0), -# }, -# } -# structured_priors = structure_priors(params, priors) - -# def recursive_fn(d1, d2, fn: tp.Callable[[tp.Any], tp.Any]): -# for key, value in d1.items(): -# if type(value) is dict: -# yield from recursive_fn(value, d2[key], fn) -# else: -# yield fn(key, key) - -# for v in recursive_fn(params, structured_priors, lambda k1, k2: k1 == k2): -# assert v - - -# @pytest.mark.parametrize("latent_prior", [dx.Laplace(0.0, 1.0), dx.Laplace(0.0, 1.0)]) -# def test_prior_checks(latent_prior): -# priors = { -# "kernel": {"lengthscale": None, "variance": None}, -# "mean_function": {}, -# "liklelihood": {"variance": None}, -# "latent": None, -# } -# new_priors = prior_checks(priors) -# assert "latent" in new_priors.keys() -# assert isinstance(new_priors["latent"], dx.Normal) - -# priors = { -# "kernel": {"lengthscale": None, "variance": None}, -# "mean_function": {}, -# "liklelihood": {"variance": None}, -# } -# new_priors = prior_checks(priors) -# assert "latent" in new_priors.keys() -# assert isinstance(new_priors["latent"], dx.Normal) - -# priors = { -# "kernel": {"lengthscale": None, "variance": None}, -# "mean_function": {}, -# "liklelihood": {"variance": None}, -# "latent": latent_prior, -# } -# with pytest.warns(UserWarning): -# new_priors = prior_checks(priors) -# assert "latent" in new_priors.keys() -# assert isinstance(new_priors["latent"], dx.Laplace) - - -# ######################### -# # Test transforms -# ######################### -# @pytest.mark.parametrize("num_datapoints", [1, 10]) -# @pytest.mark.parametrize("likelihood", [Gaussian, Bernoulli]) -# def test_output(num_datapoints, likelihood): -# posterior = Prior(kernel=RBF()) * likelihood(num_datapoints=num_datapoints) -# params, _, bijectors = initialise(posterior, jr.PRNGKey(123)).unpack() - -# assert isinstance(bijectors, dict) -# for k, v1, v2 in recursive_items(bijectors, bijectors): -# assert isinstance(v1.forward, tp.Callable) -# assert isinstance(v2.inverse, tp.Callable) - -# unconstrained_params = 
unconstrain(params, bijectors) -# assert ( -# unconstrained_params["kernel"]["lengthscale"] != params["kernel"]["lengthscale"] -# ) -# backconstrained_params = constrain(unconstrained_params, bijectors) -# for k, v1, v2 in recursive_items(params, unconstrained_params): -# assert v1.dtype == v2.dtype - -# for k, v1, v2 in recursive_items(params, backconstrained_params): -# assert all(v1 == v2) - -# augmented_params = params -# augmented_params["test_param"] = jnp.array([1.0]) -# a_bijectors = build_bijectors(augmented_params) - -# assert "test_param" in list(a_bijectors.keys()) -# assert a_bijectors["test_param"].forward(jnp.array([1.0])) == 1.0 -# assert a_bijectors["test_param"].inverse(jnp.array([1.0])) == 1.0 +from jaxutils.parameters import Parameters +from jaxutils.bijectors import Softplus, Identity +import jax +import pytest +import jax.numpy as jnp +import distrax as dx +from jax.config import config +import typing as tp + +config.update("jax_enable_x64", True) + + +def build_params( + param_vals: tp.Dict, + set_priors: bool, + set_trainables: bool, + set_bijectors: bool, +) -> tp.Tuple[Parameters, tp.Dict]: + priors = {k: dx.Normal(0.0, 1.0) for k in param_vals.keys()} if set_priors else None + trainables = {k: True for k in param_vals.keys()} if set_trainables else None + bijections = {k: Identity for k in param_vals.keys()} if set_bijectors else None + params = Parameters( + params=param_vals, + priors=priors, + bijectors=bijections, + trainables=trainables, + ) + truth = { + "params": param_vals, + "priors": priors, + "trainables": trainables, + "bijectors": bijections, + } + return params, truth + + +@pytest.mark.parametrize("jit_compile", [False, True]) +def test_priors(jit_compile): + # Vanilla test for case where every parameter has a defined prior + param_vals = {"a": jnp.array([1.0]), "b": jnp.array([2.0])} + priors = {"a": dx.Normal(0.0, 1.0), "b": dx.Normal(0.0, 1.0)} + + params = Parameters(params=param_vals, priors=priors) + if jit_compile: + lpd = jax.jit(params.log_prior_density)() + else: + lpd = params.log_prior_density() + assert pytest.approx(lpd, 0.00001) == -4.3378773 + assert isinstance(lpd, jax.Array) + + # Check fn. works for no priors + priors = {"a": None, "b": None} + params = Parameters(params=param_vals, priors=priors) + if jit_compile: + lpd = jax.jit(params.log_prior_density)() + else: + lpd = params.log_prior_density() + assert pytest.approx(lpd, 0.00001) == 0.0 + assert isinstance(lpd, jax.Array) + + # Check the fn. 
works for nested structures with incomplete priors
+    param_vals = {
+        "a": jnp.array([1.0]),
+        "b": {"a": jnp.array([10.0]), "b": jnp.array([3.0])},
+    }
+    priors = {"a": None, "b": {"a": dx.Normal(0, 1.0), "b": dx.Gamma(2.0, 2.0)}}
+    params = Parameters(params=param_vals, priors=priors)
+    if jit_compile:
+        lpd = jax.jit(params.log_prior_density)()
+    else:
+        lpd = params.log_prior_density()
+    assert pytest.approx(lpd, 0.00001) == -54.434032
+    assert isinstance(lpd, jax.Array)
+
+    # Check that prior initialisation works - by default, there are no priors
+    params = Parameters(param_vals)
+    if jit_compile:
+        lpd = jax.jit(params.log_prior_density)()
+    else:
+        lpd = params.log_prior_density()
+    assert pytest.approx(lpd, 0.00001) == 0.0
+    assert isinstance(lpd, jax.Array)
+
+
+def test_constrain_unconstrain():
+    param_vals = {"a": jnp.array([1.0]), "b": jnp.array([2.0])}
+    bijections = {"a": Softplus, "b": Softplus}
+    params = Parameters(params=param_vals, bijectors=bijections)
+
+    unconstrain_fn = params.unconstrain
+
+    unconstrained_params = unconstrain_fn()
+
+    assert isinstance(unconstrained_params, Parameters)
+    assert isinstance(unconstrained_params.params, dict)
+
+    constrain_fn = unconstrained_params.constrain
+    constrained_params = constrain_fn()
+    assert isinstance(constrained_params, Parameters)
+    assert isinstance(constrained_params.params, dict)
+
+    assert constrained_params == params
+
+
+def test_update_param():
+    param_vals = {"a": jnp.array([1.0]), "b": jnp.array([2.0])}
+    bijections = {"a": Softplus, "b": Softplus}
+    params = Parameters(params=param_vals, bijectors=bijections)
+
+    updated_param_vals = {"a": jnp.array([2.0]), "b": jnp.array([3.0])}
+    updated_params = params.update_params(updated_param_vals)
+
+    # Check the updated params are correct
+    assert updated_params.params == updated_param_vals
+    # Check that nothing else has changed
+    assert updated_params.bijectors == params.bijectors
+    assert updated_params.priors == params.priors
+    assert updated_params.trainables == params.trainables
+
+    # Check that an incomplete key structure raises an error
+    with pytest.raises(ValueError):
+        updated_params = params.update_params({"a": jnp.array([2.0])})
+
+
+def test_bijector_update():
+    param_vals = {"a": jnp.array([1.0]), "b": jnp.array([2.0])}
+    bijections = {"a": Softplus, "b": Softplus}
+    params = Parameters(params=param_vals, bijectors=bijections)
+
+    updated_bijections = {"a": Softplus, "b": Identity}
+    updated_params = params.update_bijectors(updated_bijections)
+
+    # Check that bijections have been updated
+    assert updated_params.bijectors == updated_bijections
+    # Check all else is equal
+    assert updated_params == params
+    assert updated_params.trainables == params.trainables
+    assert updated_params.priors == params.priors
+
+    # Check that an incomplete key structure raises an error
+    with pytest.raises(ValueError):
+        updated_params = params.update_bijectors({"a": Identity})
+
+
+def test_trainables_update():
+    param_vals = {"a": jnp.array([1.0]), "b": jnp.array([2.0])}
+    trainables = {"a": True, "b": True}
+    params = Parameters(params=param_vals, trainables=trainables)
+
+    updated_trainables = {"a": True, "b": False}
+    updated_params = params.update_trainables(updated_trainables)
+
+    # Check that trainables have been updated
+    assert updated_params.trainables == updated_trainables
+    # Check all else is equal
+    assert updated_params == params
+    assert updated_params.bijectors == params.bijectors
+    assert updated_params.priors == params.priors
+
+    # Check that an incomplete key structure raises an error
+    with pytest.raises(ValueError):
+        updated_params = params.update_trainables({"a": True})
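The `update_*` methods exercised above share a contract: they return a new `Parameters` rather than mutating in place, and they reject updates whose key structure does not match the existing parameters. A hedged sketch of the structural guard such methods plausibly apply (illustrative, not the library's code):

```python
import jax.tree_util as jtu


def check_structure(current: dict, update: dict) -> None:
    """Raise if the update does not mirror the existing key structure."""
    if jtu.tree_structure(current) != jtu.tree_structure(update):
        raise ValueError("Update must have the same key structure as the parameters.")


check_structure({"a": 1.0, "b": 2.0}, {"a": 3.0, "b": 4.0})  # passes
# check_structure({"a": 1.0, "b": 2.0}, {"a": 3.0})  # raises ValueError
```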
+
+
+def test_priors_update():
+    param_vals = {"a": jnp.array([1.0]), "b": jnp.array([2.0])}
+    priors = {"a": dx.Normal(0.0, 1.0), "b": dx.Normal(0.0, 1.0)}
+    params = Parameters(params=param_vals, priors=priors)
+
+    updated_priors = {"a": dx.Normal(0.0, 1.0), "b": dx.Gamma(3.0, 3.0)}
+    updated_params = params.update_priors(updated_priors)
+
+    # Check that priors have been updated
+    assert updated_params.priors == updated_priors
+    # Check all else is equal
+    assert updated_params == params
+    assert updated_params.trainables == params.trainables
+    assert updated_params.bijectors == params.bijectors
+
+    # Check that an incomplete key structure raises an error
+    with pytest.raises(ValueError):
+        updated_params = params.update_priors({"a": dx.Gamma(3.0, 3.0)})
+
+
+def param_equality(params, truth, set_priors, set_trainables, set_bijectors):
+    assert params["params"] == truth["params"]
+
+    if set_trainables:
+        assert params["trainables"] == truth["trainables"]
+    else:
+        assert params["trainables"] == {k: True for k in truth["params"]}
+
+    if set_bijectors:
+        assert params["bijectors"] == truth["bijectors"]
+    else:
+        assert params["bijectors"] == {k: Identity for k in truth["params"]}
+
+    if set_priors:
+        assert params["priors"] == truth["priors"]
+    else:
+        assert params["priors"] == {k: None for k in truth["params"]}
+
+
+@pytest.mark.parametrize("set_priors", [True, False])
+@pytest.mark.parametrize("set_trainables", [True, False])
+@pytest.mark.parametrize("set_bijectors", [True, False])
+def test_unpack(set_priors, set_trainables, set_bijectors):
+    init_param_vals = {"a": jnp.array([1.0]), "b": jnp.array([2.0])}
+    params, truth = build_params(
+        init_param_vals,
+        set_priors,
+        set_trainables,
+        set_bijectors,
+    )
+    contents = params.unpack()
+
+    param_equality(contents, truth, set_priors, set_trainables, set_bijectors)
+    assert isinstance(contents["params"], dict)
+    assert isinstance(contents["trainables"], dict)
+    assert isinstance(contents["bijectors"], dict)
+    assert isinstance(contents["priors"], dict)
+
+
+@pytest.mark.parametrize("set_priors", [True, False])
+@pytest.mark.parametrize("set_trainables", [True, False])
+@pytest.mark.parametrize("set_bijectors", [True, False])
+def test_combine(set_priors, set_trainables, set_bijectors):
+    p1, truth1 = build_params(
+        {"a": jnp.array([1.0])}, set_priors, set_trainables, set_bijectors
+    )
+    p2, truth2 = build_params(
+        {"b": jnp.array([2.0])}, set_priors, set_trainables, set_bijectors
+    )
+
+    p = p1.combine(p2, left_key="x", right_key="y")
+    assert isinstance(p, Parameters)
+    assert p.params == {"x": truth1["params"], "y": truth2["params"]}
+    assert list(p.keys()) == ["x", "y"]
+
+    if set_trainables:
+        assert p.trainables == {
+            "x": truth1["trainables"],
+            "y": truth2["trainables"],
+        }
+    else:
+        assert p.trainables == {"x": {"a": True}, "y": {"b": True}}
+
+    if set_bijectors:
+        assert p.bijectors == {
+            "x": truth1["bijectors"],
+            "y": truth2["bijectors"],
+        }
+    else:
+        assert p.bijectors == {"x": {"a": Identity}, "y": {"b": Identity}}
+
+    if set_priors:
+        assert p.priors == {"x": truth1["priors"], "y": truth2["priors"]}
+    else:
+        assert p.priors == {"x": {"a": None}, "y": {"b": None}}
+
+
+@pytest.mark.parametrize("prior", [dx.Normal(0, 1), dx.Gamma(2.0, 2.0), None])
+@pytest.mark.parametrize("trainable", [True, False])
+@pytest.mark.parametrize("bijector", [Softplus, Identity])
+def test_add_parameter(prior, trainable, bijector):
+    p = Parameters({"a": 
jnp.array([1.0])}) + p.add_parameter( + key="b", + value=jnp.array([2.0]), + prior=prior, + trainability=trainable, + bijector=bijector, + ) + + assert "b" in p.keys() + assert p["b"] == jnp.array([2.0]) + assert p.trainables["b"] == trainable + assert p.bijectors["b"] == bijector + assert p.priors["b"] == prior + + # Test adding a parameter with a parameter object + p = Parameters({"a": jnp.array([1.0])}) + p2 = Parameters( + params={"c": jnp.array([2.0])}, + bijectors={"c": bijector}, + priors={"c": prior}, + trainables={"c": trainable}, + ) + p.add_parameter( + key="b", + parameter=p2, + ) + + assert "b" in p.keys() + assert p["b"] == p2.params + assert p.trainables["b"] == p2.trainables + assert p.bijectors["b"] == p2.bijectors + assert p.priors["b"] == p2.priors + + # Check that trying to overwrite a parameter raises an error + with pytest.raises(ValueError): + p.add_parameter(key="b", parameter=p2) diff --git a/tests/test_pytree.py b/tests/test_pytree.py deleted file mode 100644 index 8efc699..0000000 --- a/tests/test_pytree.py +++ /dev/null @@ -1,124 +0,0 @@ -# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -"""Adapted from Distrax._src.utils.test_jittable.""" - -import pytest - -import jax -import jax.numpy as jnp -import numpy as np - -from typing import Any -from jaxtyping import Float, Array -from jaxutils import pytree - - -class DummyJittable(pytree.PyTree): - def __init__(self, values: Float[Array, "N"]): - self.name = "dummy" # Non-JAX property, cannot be traced. - self.data = {"params": values} # Tree property, must be traced recursively. - - -def test_jittable() -> None: - @jax.jit - def get_params(obj): - return obj.data["params"] - - obj = DummyJittable(jnp.ones((5,))) - np.testing.assert_array_equal(get_params(obj), obj.data["params"]) - - -def test_vmappable() -> None: - def do_sum(obj): - return obj.data["params"].sum() - - obj = DummyJittable(jnp.array([[1, 2, 3], [4, 5, 6]])) - - np.testing.assert_array_equal(do_sum(obj), obj.data["params"].sum()) - - np.testing.assert_array_equal( - jax.vmap(do_sum, in_axes=0)(obj), obj.data["params"].sum(axis=1) - ) - - np.testing.assert_array_equal( - jax.vmap(do_sum, in_axes=1)(obj), obj.data["params"].sum(axis=0) - ) - - -def test_traceable() -> None: - @jax.jit - def inner_fn(obj): - obj.data["params"] *= 3 # Modification after passing to jitted fn. - return obj.data["params"].sum() - - def loss_fn(params): - obj = DummyJittable(params) - obj.data["params"] *= 2 # Modification before passing to jitted fn. - return inner_fn(obj) - - params = np.ones((5,)) - # Both modifications will be traced if data tree is correctly traversed. - grad_expected = params * 2 * 3 - grad = jax.grad(loss_fn)(params) - np.testing.assert_array_equal(grad, grad_expected) - - params = jnp.ones((5,)) - # Both modifications will be traced if data tree is correctly traversed. 
- grad_expected = params * 2 * 3 - grad = jax.grad(loss_fn)(params) - np.testing.assert_array_equal(grad, grad_expected) - - -def test_different_jittables_to_compiled_function() -> None: - @jax.jit - def add_one_to_params(obj): - obj.data["params"] = obj.data["params"] + 1 - return obj - - add_one_to_params(DummyJittable(np.zeros((5,)))) - add_one_to_params(DummyJittable(np.ones((5,)))) - - add_one_to_params(DummyJittable(jnp.zeros((5,)))) - add_one_to_params(DummyJittable(jnp.ones((5,)))) - - -def test_modifying_object_data_does_not_leak_tracers() -> None: - @jax.jit - def add_one_to_params(obj): - obj.data["params"] = obj.data["params"] + 1 - return obj - - dummy = DummyJittable(jnp.ones((5,))) - dummy_out = add_one_to_params(dummy) - dummy_out.data["params"] -= 1 - - -def test_metadata_modification_statements_are_removed_by_compilation() -> None: - @jax.jit - def add_char_to_name(obj): - obj.name += "_x" - return obj - - dummy = DummyJittable(jnp.ones((5,))) - dummy_out = add_char_to_name(dummy) - dummy_out = add_char_to_name(dummy) # `name` change has been compiled out. - dummy_out.name += "y" - assert dummy_out.name == "dummy_xy" - - -@pytest.mark.parametrize("x", [1, 1.0, True, None]) -def test_is_jax_type(x: Any) -> None: - assert pytree.is_jax_type(x) == False diff --git a/tests/test_vscan.py b/tests/test_vscan.py new file mode 100644 index 0000000..cf1a9a2 --- /dev/null +++ b/tests/test_vscan.py @@ -0,0 +1,30 @@ +# Copyright 2023 The JaxGaussianProcesses Contributors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from jaxutils.scan import vscan + +import jax.numpy as jnp + +# TODO: Thorough checks on vscan. +def test_vscan(): + def body(c, x): + a, b = x + return c, a + b + + xs = (jnp.arange(10), jnp.arange(10)) + c, ys = vscan(body, 0, xs) + + assert c == 0 + assert jnp.all(ys == jnp.arange(10) * 2) diff --git a/versioneer.py b/versioneer.py index 18e34c2..ccc663b 100644 --- a/versioneer.py +++ b/versioneer.py @@ -1,4 +1,3 @@ - # Version: 0.28 """The Versioneer - like a rocketeer, but for versions. @@ -348,11 +347,13 @@ def get_root(): setup_py = os.path.join(root, "setup.py") versioneer_py = os.path.join(root, "versioneer.py") if not (os.path.exists(setup_py) or os.path.exists(versioneer_py)): - err = ("Versioneer was unable to run the project root directory. " - "Versioneer requires setup.py to be executed from " - "its immediate directory (like 'python setup.py COMMAND'), " - "or in a way that lets it use sys.argv[0] to find the root " - "(like 'python path/to/setup.py COMMAND').") + err = ( + "Versioneer was unable to run the project root directory. " + "Versioneer requires setup.py to be executed from " + "its immediate directory (like 'python setup.py COMMAND'), " + "or in a way that lets it use sys.argv[0] to find the root " + "(like 'python path/to/setup.py COMMAND')." 
+ ) raise VersioneerBadRootError(err) try: # Certain runtime workflows (setup.py install/develop in a setuptools @@ -365,8 +366,10 @@ def get_root(): me_dir = os.path.normcase(os.path.splitext(my_path)[0]) vsr_dir = os.path.normcase(os.path.splitext(versioneer_py)[0]) if me_dir != vsr_dir and "VERSIONEER_PEP518" not in globals(): - print("Warning: build in %s is using versioneer.py from %s" - % (os.path.dirname(my_path), versioneer_py)) + print( + "Warning: build in %s is using versioneer.py from %s" + % (os.path.dirname(my_path), versioneer_py) + ) except NameError: pass return root @@ -384,9 +387,9 @@ def get_config_from_root(root): section = None if pyproject_toml.exists() and have_tomllib: try: - with open(pyproject_toml, 'rb') as fobj: + with open(pyproject_toml, "rb") as fobj: pp = tomllib.load(fobj) - section = pp['tool']['versioneer'] + section = pp["tool"]["versioneer"] except (tomllib.TOMLDecodeError, KeyError): pass if not section: @@ -398,7 +401,7 @@ def get_config_from_root(root): section = parser["versioneer"] cfg = VersioneerConfig() - cfg.VCS = section['VCS'] + cfg.VCS = section["VCS"] cfg.style = section.get("style", "") cfg.versionfile_source = section.get("versionfile_source") cfg.versionfile_build = section.get("versionfile_build") @@ -421,15 +424,16 @@ class NotThisMethod(Exception): def register_vcs_handler(vcs, method): # decorator """Create decorator to mark a method as the handler of a VCS.""" + def decorate(f): """Store f in HANDLERS[vcs][method].""" HANDLERS.setdefault(vcs, {})[method] = f return f + return decorate -def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, - env=None): +def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env=None): """Call the given command(s).""" assert isinstance(commands, list) process = None @@ -445,10 +449,14 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, try: dispcmd = str([command] + args) # remember shell=False, so use git.cmd on windows, not just git - process = subprocess.Popen([command] + args, cwd=cwd, env=env, - stdout=subprocess.PIPE, - stderr=(subprocess.PIPE if hide_stderr - else None), **popen_kwargs) + process = subprocess.Popen( + [command] + args, + cwd=cwd, + env=env, + stdout=subprocess.PIPE, + stderr=(subprocess.PIPE if hide_stderr else None), + **popen_kwargs, + ) break except OSError: e = sys.exc_info()[1] @@ -471,7 +479,9 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, return stdout, process.returncode -LONG_VERSION_PY['git'] = r''' +LONG_VERSION_PY[ + "git" +] = r''' # This file helps to compute a version number in source trees obtained from # git-archive tarball (such as those provided by githubs download-from-tag # feature). Distribution tarballs (built by setup.py sdist) and build @@ -1187,7 +1197,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of # just "foo-1.0". If we see a "tag: " prefix, prefer those. TAG = "tag: " - tags = {r[len(TAG):] for r in refs if r.startswith(TAG)} + tags = {r[len(TAG) :] for r in refs if r.startswith(TAG)} if not tags: # Either we're using git < 1.8.3, or there really are no tags. We use # a heuristic: assume all version tags have a digit. The old git %d @@ -1196,7 +1206,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): # between branches and tags. 
By ignoring refnames without digits, we # filter out many common branch names like "release" and # "stabilization", as well as "HEAD" and "master". - tags = {r for r in refs if re.search(r'\d', r)} + tags = {r for r in refs if re.search(r"\d", r)} if verbose: print("discarding '%s', no digits" % ",".join(refs - tags)) if verbose: @@ -1204,24 +1214,31 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): for ref in sorted(tags): # sorting will prefer e.g. "2.0" over "2.0rc1" if ref.startswith(tag_prefix): - r = ref[len(tag_prefix):] + r = ref[len(tag_prefix) :] # Filter out refs that exactly match prefix or that don't start # with a number once the prefix is stripped (mostly a concern # when prefix is '') - if not re.match(r'\d', r): + if not re.match(r"\d", r): continue if verbose: print("picking %s" % r) - return {"version": r, - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": None, - "date": date} + return { + "version": r, + "full-revisionid": keywords["full"].strip(), + "dirty": False, + "error": None, + "date": date, + } # no suitable tags, so version is "0+unknown", but full hex is still there if verbose: print("no suitable tags, using unknown + full revision id") - return {"version": "0+unknown", - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": "no suitable tags", "date": None} + return { + "version": "0+unknown", + "full-revisionid": keywords["full"].strip(), + "dirty": False, + "error": "no suitable tags", + "date": None, + } @register_vcs_handler("git", "pieces_from_vcs") @@ -1243,8 +1260,7 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command): env.pop("GIT_DIR", None) runner = functools.partial(runner, env=env) - _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root, - hide_stderr=not verbose) + _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root, hide_stderr=not verbose) if rc != 0: if verbose: print("Directory %s not under git control" % root) @@ -1252,10 +1268,19 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command): # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] # if there isn't one, this yields HEX[-dirty] (no NUM) - describe_out, rc = runner(GITS, [ - "describe", "--tags", "--dirty", "--always", "--long", - "--match", f"{tag_prefix}[[:digit:]]*" - ], cwd=root) + describe_out, rc = runner( + GITS, + [ + "describe", + "--tags", + "--dirty", + "--always", + "--long", + "--match", + f"{tag_prefix}[[:digit:]]*", + ], + cwd=root, + ) # --long was added in git-1.5.5 if describe_out is None: raise NotThisMethod("'git describe' failed") @@ -1270,8 +1295,7 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command): pieces["short"] = full_out[:7] # maybe improved later pieces["error"] = None - branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"], - cwd=root) + branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"], cwd=root) # --abbrev-ref was added in git-1.6.3 if rc != 0 or branch_name is None: raise NotThisMethod("'git rev-parse --abbrev-ref' returned error") @@ -1311,17 +1335,16 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command): dirty = git_describe.endswith("-dirty") pieces["dirty"] = dirty if dirty: - git_describe = git_describe[:git_describe.rindex("-dirty")] + git_describe = git_describe[: git_describe.rindex("-dirty")] # now we have TAG-NUM-gHEX or HEX if "-" in git_describe: # TAG-NUM-gHEX - mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) + mo = 
re.search(r"^(.+)-(\d+)-g([0-9a-f]+)$", git_describe) if not mo: # unparsable. Maybe git-describe is misbehaving? - pieces["error"] = ("unable to parse git-describe output: '%s'" - % describe_out) + pieces["error"] = "unable to parse git-describe output: '%s'" % describe_out return pieces # tag @@ -1330,10 +1353,12 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command): if verbose: fmt = "tag '%s' doesn't start with prefix '%s'" print(fmt % (full_tag, tag_prefix)) - pieces["error"] = ("tag '%s' doesn't start with prefix '%s'" - % (full_tag, tag_prefix)) + pieces["error"] = "tag '%s' doesn't start with prefix '%s'" % ( + full_tag, + tag_prefix, + ) return pieces - pieces["closest-tag"] = full_tag[len(tag_prefix):] + pieces["closest-tag"] = full_tag[len(tag_prefix) :] # distance: number of commits since tag pieces["distance"] = int(mo.group(2)) @@ -1407,15 +1432,21 @@ def versions_from_parentdir(parentdir_prefix, root, verbose): for _ in range(3): dirname = os.path.basename(root) if dirname.startswith(parentdir_prefix): - return {"version": dirname[len(parentdir_prefix):], - "full-revisionid": None, - "dirty": False, "error": None, "date": None} + return { + "version": dirname[len(parentdir_prefix) :], + "full-revisionid": None, + "dirty": False, + "error": None, + "date": None, + } rootdirs.append(root) root = os.path.dirname(root) # up a level if verbose: - print("Tried directories %s but none started with prefix %s" % - (str(rootdirs), parentdir_prefix)) + print( + "Tried directories %s but none started with prefix %s" + % (str(rootdirs), parentdir_prefix) + ) raise NotThisMethod("rootdir doesn't start with parentdir_prefix") @@ -1444,11 +1475,13 @@ def versions_from_file(filename): contents = f.read() except OSError: raise NotThisMethod("unable to read _version.py") - mo = re.search(r"version_json = '''\n(.*)''' # END VERSION_JSON", - contents, re.M | re.S) + mo = re.search( + r"version_json = '''\n(.*)''' # END VERSION_JSON", contents, re.M | re.S + ) if not mo: - mo = re.search(r"version_json = '''\r\n(.*)''' # END VERSION_JSON", - contents, re.M | re.S) + mo = re.search( + r"version_json = '''\r\n(.*)''' # END VERSION_JSON", contents, re.M | re.S + ) if not mo: raise NotThisMethod("no version_json in _version.py") return json.loads(mo.group(1)) @@ -1457,8 +1490,7 @@ def versions_from_file(filename): def write_to_version_file(filename, versions): """Write the given version number to the given _version.py file.""" os.unlink(filename) - contents = json.dumps(versions, sort_keys=True, - indent=1, separators=(",", ": ")) + contents = json.dumps(versions, sort_keys=True, indent=1, separators=(",", ": ")) with open(filename, "w") as f: f.write(SHORT_VERSION_PY % contents) @@ -1490,8 +1522,7 @@ def render_pep440(pieces): rendered += ".dirty" else: # exception #1 - rendered = "0+untagged.%d.g%s" % (pieces["distance"], - pieces["short"]) + rendered = "0+untagged.%d.g%s" % (pieces["distance"], pieces["short"]) if pieces["dirty"]: rendered += ".dirty" return rendered @@ -1520,8 +1551,7 @@ def render_pep440_branch(pieces): rendered = "0" if pieces["branch"] != "master": rendered += ".dev0" - rendered += "+untagged.%d.g%s" % (pieces["distance"], - pieces["short"]) + rendered += "+untagged.%d.g%s" % (pieces["distance"], pieces["short"]) if pieces["dirty"]: rendered += ".dirty" return rendered @@ -1682,11 +1712,13 @@ def render_git_describe_long(pieces): def render(pieces, style): """Render the given version pieces into the requested style.""" if pieces["error"]: - return 
{"version": "unknown", - "full-revisionid": pieces.get("long"), - "dirty": None, - "error": pieces["error"], - "date": None} + return { + "version": "unknown", + "full-revisionid": pieces.get("long"), + "dirty": None, + "error": pieces["error"], + "date": None, + } if not style or style == "default": style = "pep440" # the default @@ -1710,9 +1742,13 @@ def render(pieces, style): else: raise ValueError("unknown style '%s'" % style) - return {"version": rendered, "full-revisionid": pieces["long"], - "dirty": pieces["dirty"], "error": None, - "date": pieces.get("date")} + return { + "version": rendered, + "full-revisionid": pieces["long"], + "dirty": pieces["dirty"], + "error": None, + "date": pieces.get("date"), + } class VersioneerBadRootError(Exception): @@ -1735,8 +1771,9 @@ def get_versions(verbose=False): handlers = HANDLERS.get(cfg.VCS) assert handlers, "unrecognized VCS '%s'" % cfg.VCS verbose = verbose or cfg.verbose - assert cfg.versionfile_source is not None, \ - "please set versioneer.versionfile_source" + assert ( + cfg.versionfile_source is not None + ), "please set versioneer.versionfile_source" assert cfg.tag_prefix is not None, "please set versioneer.tag_prefix" versionfile_abs = os.path.join(root, cfg.versionfile_source) @@ -1790,9 +1827,13 @@ def get_versions(verbose=False): if verbose: print("unable to compute version") - return {"version": "0+unknown", "full-revisionid": None, - "dirty": None, "error": "unable to compute version", - "date": None} + return { + "version": "0+unknown", + "full-revisionid": None, + "dirty": None, + "error": "unable to compute version", + "date": None, + } def get_version(): @@ -1845,6 +1886,7 @@ def run(self): print(" date: %s" % vers.get("date")) if vers["error"]: print(" error: %s" % vers["error"]) + cmds["version"] = cmd_version # we override "build_py" in setuptools @@ -1866,8 +1908,8 @@ def run(self): # but the build_py command is not expected to copy any files. # we override different "build_py" commands for both environments - if 'build_py' in cmds: - _build_py = cmds['build_py'] + if "build_py" in cmds: + _build_py = cmds["build_py"] else: from setuptools.command.build_py import build_py as _build_py @@ -1884,14 +1926,14 @@ def run(self): # now locate _version.py in the new build/ directory and replace # it with an updated value if cfg.versionfile_build: - target_versionfile = os.path.join(self.build_lib, - cfg.versionfile_build) + target_versionfile = os.path.join(self.build_lib, cfg.versionfile_build) print("UPDATING %s" % target_versionfile) write_to_version_file(target_versionfile, versions) + cmds["build_py"] = cmd_build_py - if 'build_ext' in cmds: - _build_ext = cmds['build_ext'] + if "build_ext" in cmds: + _build_ext = cmds["build_ext"] else: from setuptools.command.build_ext import build_ext as _build_ext @@ -1911,19 +1953,22 @@ def run(self): # it with an updated value if not cfg.versionfile_build: return - target_versionfile = os.path.join(self.build_lib, - cfg.versionfile_build) + target_versionfile = os.path.join(self.build_lib, cfg.versionfile_build) if not os.path.exists(target_versionfile): - print(f"Warning: {target_versionfile} does not exist, skipping " - "version update. This can happen if you are running build_ext " - "without first running build_py.") + print( + f"Warning: {target_versionfile} does not exist, skipping " + "version update. This can happen if you are running build_ext " + "without first running build_py." 
+ ) return print("UPDATING %s" % target_versionfile) write_to_version_file(target_versionfile, versions) + cmds["build_ext"] = cmd_build_ext if "cx_Freeze" in sys.modules: # cx_freeze enabled? from cx_Freeze.dist import build_exe as _build_exe + # nczeczulin reports that py2exe won't like the pep440-style string # as FILEVERSION, but it can be used for PRODUCTVERSION, e.g. # setup(console=[{ @@ -1944,17 +1989,21 @@ def run(self): os.unlink(target_versionfile) with open(cfg.versionfile_source, "w") as f: LONG = LONG_VERSION_PY[cfg.VCS] - f.write(LONG % - {"DOLLAR": "$", - "STYLE": cfg.style, - "TAG_PREFIX": cfg.tag_prefix, - "PARENTDIR_PREFIX": cfg.parentdir_prefix, - "VERSIONFILE_SOURCE": cfg.versionfile_source, - }) + f.write( + LONG + % { + "DOLLAR": "$", + "STYLE": cfg.style, + "TAG_PREFIX": cfg.tag_prefix, + "PARENTDIR_PREFIX": cfg.parentdir_prefix, + "VERSIONFILE_SOURCE": cfg.versionfile_source, + } + ) + cmds["build_exe"] = cmd_build_exe del cmds["build_py"] - if 'py2exe' in sys.modules: # py2exe enabled? + if "py2exe" in sys.modules: # py2exe enabled? try: from py2exe.setuptools_buildexe import py2exe as _py2exe except ImportError: @@ -1973,18 +2022,22 @@ def run(self): os.unlink(target_versionfile) with open(cfg.versionfile_source, "w") as f: LONG = LONG_VERSION_PY[cfg.VCS] - f.write(LONG % - {"DOLLAR": "$", - "STYLE": cfg.style, - "TAG_PREFIX": cfg.tag_prefix, - "PARENTDIR_PREFIX": cfg.parentdir_prefix, - "VERSIONFILE_SOURCE": cfg.versionfile_source, - }) + f.write( + LONG + % { + "DOLLAR": "$", + "STYLE": cfg.style, + "TAG_PREFIX": cfg.tag_prefix, + "PARENTDIR_PREFIX": cfg.parentdir_prefix, + "VERSIONFILE_SOURCE": cfg.versionfile_source, + } + ) + cmds["py2exe"] = cmd_py2exe # sdist farms its file list building out to egg_info - if 'egg_info' in cmds: - _egg_info = cmds['egg_info'] + if "egg_info" in cmds: + _egg_info = cmds["egg_info"] else: from setuptools.command.egg_info import egg_info as _egg_info @@ -1997,7 +2050,7 @@ def find_sources(self): # Modify the filelist and normalize it root = get_root() cfg = get_config_from_root(root) - self.filelist.append('versioneer.py') + self.filelist.append("versioneer.py") if cfg.versionfile_source: # There are rare cases where versionfile_source might not be # included by default, so we must be explicit @@ -2010,18 +2063,21 @@ def find_sources(self): # We will instead replicate their final normalization (to unicode, # and POSIX-style paths) from setuptools import unicode_utils - normalized = [unicode_utils.filesys_decode(f).replace(os.sep, '/') - for f in self.filelist.files] - manifest_filename = os.path.join(self.egg_info, 'SOURCES.txt') - with open(manifest_filename, 'w') as fobj: - fobj.write('\n'.join(normalized)) + normalized = [ + unicode_utils.filesys_decode(f).replace(os.sep, "/") + for f in self.filelist.files + ] + + manifest_filename = os.path.join(self.egg_info, "SOURCES.txt") + with open(manifest_filename, "w") as fobj: + fobj.write("\n".join(normalized)) - cmds['egg_info'] = cmd_egg_info + cmds["egg_info"] = cmd_egg_info # we override different "sdist" commands for both environments - if 'sdist' in cmds: - _sdist = cmds['sdist'] + if "sdist" in cmds: + _sdist = cmds["sdist"] else: from setuptools.command.sdist import sdist as _sdist @@ -2043,8 +2099,10 @@ def make_release_tree(self, base_dir, files): # updated value target_versionfile = os.path.join(base_dir, cfg.versionfile_source) print("UPDATING %s" % target_versionfile) - write_to_version_file(target_versionfile, - self._versioneer_generated_versions) + 
write_to_version_file( + target_versionfile, self._versioneer_generated_versions + ) + cmds["sdist"] = cmd_sdist return cmds @@ -2104,11 +2162,9 @@ def do_setup(): root = get_root() try: cfg = get_config_from_root(root) - except (OSError, configparser.NoSectionError, - configparser.NoOptionError) as e: + except (OSError, configparser.NoSectionError, configparser.NoOptionError) as e: if isinstance(e, (OSError, configparser.NoSectionError)): - print("Adding sample versioneer config to setup.cfg", - file=sys.stderr) + print("Adding sample versioneer config to setup.cfg", file=sys.stderr) with open(os.path.join(root, "setup.cfg"), "a") as f: f.write(SAMPLE_CONFIG) print(CONFIG_ERROR, file=sys.stderr) @@ -2117,15 +2173,18 @@ def do_setup(): print(" creating %s" % cfg.versionfile_source) with open(cfg.versionfile_source, "w") as f: LONG = LONG_VERSION_PY[cfg.VCS] - f.write(LONG % {"DOLLAR": "$", - "STYLE": cfg.style, - "TAG_PREFIX": cfg.tag_prefix, - "PARENTDIR_PREFIX": cfg.parentdir_prefix, - "VERSIONFILE_SOURCE": cfg.versionfile_source, - }) - - ipy = os.path.join(os.path.dirname(cfg.versionfile_source), - "__init__.py") + f.write( + LONG + % { + "DOLLAR": "$", + "STYLE": cfg.style, + "TAG_PREFIX": cfg.tag_prefix, + "PARENTDIR_PREFIX": cfg.parentdir_prefix, + "VERSIONFILE_SOURCE": cfg.versionfile_source, + } + ) + + ipy = os.path.join(os.path.dirname(cfg.versionfile_source), "__init__.py") if os.path.exists(ipy): try: with open(ipy, "r") as f: