diff --git a/bayeux/_src/optimize/optimistix.py b/bayeux/_src/optimize/optimistix.py new file mode 100644 index 0000000..8b34dac --- /dev/null +++ b/bayeux/_src/optimize/optimistix.py @@ -0,0 +1,110 @@ +# Copyright 2023 The bayeux Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""optimistix specific code.""" +from bayeux._src.optimize import shared +import optimistix + + +class _OptimistixOptimizer(shared.Optimizer): + """Base class for optimistix optimizers.""" + + def get_kwargs(self, **kwargs): + kwargs = self.default_kwargs() | kwargs + solver = getattr(optimistix, self.optimizer) + minimise_kwargs = shared.get_optimizer_kwargs( + optimistix.minimise, kwargs, ignore_required={"y0", "solver", "fn"}) + for k in minimise_kwargs: + if k in kwargs: + minimise_kwargs[k] = kwargs[k] + extra_parameters = shared.get_extra_kwargs(kwargs) + _ = extra_parameters.pop("num_iters") + return {solver: shared.get_optimizer_kwargs(solver, kwargs), + optimistix.minimise: minimise_kwargs, + "extra_parameters": extra_parameters} + + def default_kwargs(self) -> dict[str, float]: + return {"rtol": 1e-5, "atol": 1e-5} + + def _prep_args(self, seed, kwargs): + fun, initial_state, apply_transform = super()._prep_args(seed, kwargs) + def f(x, _): + return fun(x) + return f, initial_state, apply_transform + + def __call__(self, seed, **kwargs): + kwargs = self.get_kwargs(**kwargs) + fun, initial_state, apply_transform = self._prep_args(seed, kwargs) + + solver_fn = getattr(optimistix, self.optimizer) + def run(x0): + solver = solver_fn(**kwargs[solver_fn]) + return optimistix.minimise( + fn=fun, + solver=solver, + y0=x0, + **kwargs[optimistix.minimise]).value + chain_method = kwargs["extra_parameters"]["chain_method"] + mapped_run = self._map_optimizer(chain_method, run) + ret = mapped_run(initial_state) + if apply_transform: + return shared.OptimizerResults( + params=self.transform_fn(ret), state=None, loss=None) + else: + return shared.OptimizerResults(ret, state=None, loss=None) + + +class BFGS(_OptimistixOptimizer): + name = "optimistix_bfgs" + optimizer = "BFGS" + + +class Chord(_OptimistixOptimizer): + name = "optimistix_chord" + optimizer = "Chord" + + +class Dogleg(_OptimistixOptimizer): + name = "optimistix_dogleg" + optimizer = "Dogleg" + + +class GaussNewton(_OptimistixOptimizer): + name = "optimistix_gauss_newton" + optimizer = "GaussNewton" + + +class IndirectLevenbergMarquardt(_OptimistixOptimizer): + name = "optimistix_indirect_levenberg_marquardt" + optimizer = "IndirectLevenbergMarquardt" + + +class LevenbergMarquardt(_OptimistixOptimizer): + name = "optimistix_levenberg_marquardt" + optimizer = "LevenbergMarquardt" + + +class NelderMead(_OptimistixOptimizer): + name = "optimistix_nelder_mead" + optimizer = "NelderMead" + + +class Newton(_OptimistixOptimizer): + name = "optimistix_newton" + optimizer = "Newton" + + +class NonlinearCG(_OptimistixOptimizer): + name = "optimistix_nonlinear_cg" + optimizer = "NonlinearCG" diff --git a/bayeux/optimize/__init__.py b/bayeux/optimize/__init__.py index 620ac6b..3601da1 100644 --- a/bayeux/optimize/__init__.py +++ b/bayeux/optimize/__init__.py @@ -26,6 +26,28 @@ from bayeux._src.optimize.jaxopt import NonlinearCG __all__.extend(["BFGS", "GradientDescent", "LBFGS", "NonlinearCG"]) +if importlib.util.find_spec("optimistix") is not None: + from bayeux._src.optimize.optimistix import BFGS as optimistix_BFGS + from bayeux._src.optimize.optimistix import Chord + from bayeux._src.optimize.optimistix import Dogleg + from bayeux._src.optimize.optimistix import GaussNewton + from bayeux._src.optimize.optimistix import IndirectLevenbergMarquardt + from bayeux._src.optimize.optimistix import LevenbergMarquardt + from bayeux._src.optimize.optimistix import NelderMead + from bayeux._src.optimize.optimistix import Newton + from bayeux._src.optimize.optimistix import NonlinearCG as optimistix_NonlinearCG + + __all__.extend([ + "optimistix_BFGS", + "Chord", + "Dogleg", + "GaussNewton", + "IndirectLevenbergMarquardt", + "LevenbergMarquardt", + "NelderMead", + "Newton", + "optimistix_NonlinearCG"]) + if importlib.util.find_spec("optax") is not None: from bayeux._src.optimize.optax import AdaBelief from bayeux._src.optimize.optax import Adafactor diff --git a/bayeux/tests/optimize_test.py b/bayeux/tests/optimize_test.py index 0d7d00a..8df48cd 100644 --- a/bayeux/tests/optimize_test.py +++ b/bayeux/tests/optimize_test.py @@ -67,14 +67,28 @@ def test_optimizers(method, linear_model): # pylint: disable=redefined-outer-na else: num_iters = 1_000 + if method.startswith("optimistix"): + num_iters = 10_000 # should stop automatically before then + atol = 0.2 + else: + atol = 1e-2 + assert optimizer.debug(seed=seed, verbosity=0) num_particles = 6 params = optimizer( - seed=seed, num_particles=num_particles, num_iters=num_iters).params + seed=seed, + num_particles=num_particles, + num_iters=num_iters, + atol=atol, + max_steps=num_iters, + throw=False).params expected = np.repeat(solution[..., np.newaxis], num_particles, axis=-1).T - if method != "optax_adafactor": - np.testing.assert_allclose(expected, params.w, atol=1e-2) + if method not in { + "optax_adafactor", + "optimistix_chord", + "optimistix_nelder_mead"}: + np.testing.assert_allclose(expected, params.w, atol=atol) def test_initial_state(): diff --git a/pyproject.toml b/pyproject.toml index 872b090..1c311b5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,6 +19,7 @@ dependencies = [ "oryx>=0.2.5", "arviz", "optax", + "optimistix", "blackjax", "numpyro", "jaxopt",