feat!: add gradient support for estimators
Note: gradients are currently only supported when choosing the jax backend.
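A minimal usage sketch of what this enables (the constructor arguments besides backend="jax" and the parameter names are placeholders, not taken from this commit): the estimator is built on the jax backend, so _get_gradient_function wraps its __call__ with jax.grad and gradient() returns one derivative per parameter.

estimator = SympyUnbinnedNLL(model, dataset, phsp, backend="jax")  # arguments other than backend are hypothetical
initial_parameters = {"Width": 0.3, "Position": 1.5}               # hypothetical parameter names
estimator(initial_parameters)           # negative log likelihood value
estimator.gradient(initial_parameters)  # e.g. {"Width": ..., "Position": ...}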
spflueger committed Feb 9, 2021
1 parent 0b82775 commit 5d6f2bb
Showing 3 changed files with 43 additions and 8 deletions.
21 changes: 18 additions & 3 deletions src/tensorwaves/estimator.py
@@ -4,7 +4,6 @@
"""
from typing import Callable, Dict, List, Union

import numpy as np
import sympy
import tensorflow as tf

@@ -70,14 +69,26 @@ def __call__(self, new_parameters: Dict[str, float]) -> float:
log_lh = tf.reduce_sum(logs)
return -log_lh.numpy()

def gradient(self) -> np.ndarray:
def gradient(self, parameters: Dict[str, float]) -> Dict[str, float]:
raise NotImplementedError("Gradient not implemented.")

@property
def parameters(self) -> List[str]:
return list(self.__model.parameters.keys())


def _get_gradient_function(function, backend):
def not_implemented(parameters: Dict[str, float]) -> List[float]:
raise NotImplementedError("Gradient not implemented.")

if isinstance(backend, str) and backend == "jax":
import jax

return jax.grad(function)

return not_implemented


class SympyUnbinnedNLL(Estimator):
"""Unbinned negative log likelihood estimator.
@@ -101,6 +112,7 @@ def __init__(
backend: Union[str, tuple, dict] = "numpy",
) -> None:
processed_backend = process_backend_argument(backend)
self.__gradient = _get_gradient_function(self.__call__, backend)

model_expr = model.expression.doit()

@@ -116,7 +128,7 @@ def find_function_in_backend(name: str) -> Callable:
and name in processed_backend
):
return processed_backend[name]
elif isinstance(processed_backend, (tuple, list)):
if isinstance(processed_backend, (tuple, list)):
for module in processed_backend:
if name in module.__dict__:
return module.__dict__[name]
@@ -177,3 +189,6 @@ def __update_parameters(self, parameters: Dict[str, float]) -> None:
@property
def parameters(self) -> List[str]:
return list(self.__parameters.keys())

def gradient(self, parameters: Dict[str, float]) -> Dict[str, float]:
return self.__gradient(parameters)
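The gradient can be returned as a mapping because jax.grad preserves the structure of a pytree input: differentiating a function that takes a dict of floats yields a dict with the same keys. A minimal, self-contained illustration of that mechanism (a toy function, not part of this commit):

import jax

def toy_loss(parameters):
    # scalar function of a parameter mapping, analogous to the estimator's __call__
    return parameters["a"] ** 2 + 3.0 * parameters["b"]

toy_gradient = jax.grad(toy_loss)
print(toy_gradient({"a": 2.0, "b": 1.0}))  # {'a': 4.0, 'b': 3.0} (values are JAX scalars)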
4 changes: 4 additions & 0 deletions src/tensorwaves/interfaces.py
@@ -50,6 +50,10 @@ def __call__(self, parameters: Dict[str, float]) -> float:
def parameters(self) -> Iterable[str]:
"""Get list of parameter names."""

@abstractmethod
def gradient(self, parameters: Dict[str, float]) -> Dict[str, float]:
"""Calculate gradient for given parameter mapping."""


class Kinematics(ABC):
"""Abstract interface for computation of kinematic variables."""
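Because gradient is added as an abstractmethod, every class implementing the Estimator interface now has to define it, even if only as a stub that raises NotImplementedError (as the TensorFlow-based estimator above does). A sketch of the required signature in a hypothetical subclass:

from typing import Dict, List

class MyEstimator(Estimator):  # Estimator from tensorwaves.interfaces
    def __call__(self, parameters: Dict[str, float]) -> float:
        ...

    @property
    def parameters(self) -> List[str]:
        ...

    def gradient(self, parameters: Dict[str, float]) -> Dict[str, float]:
        # derivative of __call__ with respect to each parameter, keyed by name
        raise NotImplementedError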
26 changes: 21 additions & 5 deletions src/tensorwaves/optimizer/minuit.py
@@ -5,7 +5,7 @@
import time
from copy import deepcopy
from datetime import datetime
from typing import Dict, Optional
from typing import Dict, List, Optional

from iminuit import Minuit
from tqdm import tqdm
@@ -21,10 +21,15 @@ class Minuit2(Optimizer):
Implements the `~.interfaces.Optimizer` interface.
"""

def __init__(self, callback: Optional[Callback] = None) -> None:
def __init__(
self,
callback: Optional[Callback] = None,
use_analytic_gradient: bool = False,
) -> None:
self.__callback: Callback = CallbackList([])
if callback is not None:
self.__callback = callback
self.__use_gradient = use_analytic_gradient

def optimize(
self, estimator: Estimator, initial_parameters: Dict[str, float]
@@ -33,11 +38,14 @@ def optimize(
progress_bar = tqdm()
n_function_calls = 0

def update_parameters(pars: list) -> None:
for i, k in enumerate(parameters.keys()):
parameters[k] = pars[i]

def wrapped_function(pars: list) -> float:
nonlocal n_function_calls
n_function_calls += 1
for i, k in enumerate(parameters.keys()):
parameters[k] = pars[i]
update_parameters(pars)
estimator_value = estimator(parameters)
progress_bar.set_postfix({"estimator": estimator_value})
progress_bar.update()
@@ -54,15 +62,23 @@ def wrapped_function(pars: list) -> float:
self.__callback.on_iteration_end(n_function_calls, logs)
return estimator_value

def wrapped_gradient(pars: list) -> List[float]:
update_parameters(pars)
grad = estimator.gradient(parameters)
return [grad[x] for x in parameters.keys()]

minuit = Minuit(
wrapped_function,
tuple(parameters.values()),
grad=wrapped_gradient if self.__use_gradient else None,
name=tuple(parameters),
)
minuit.errors = tuple(
0.1 * x if x != 0.0 else 0.1 for x in parameters.values()
)
minuit.errordef = Minuit.LIKELIHOOD
minuit.errordef = (
Minuit.LIKELIHOOD
) # that error definition should be defined in the estimator

start_time = time.time()
minuit.migrad()
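To use the analytic gradient in a fit, the optimizer has to be constructed with the new flag; iminuit then calls wrapped_gradient through its grad= argument instead of estimating derivatives numerically. A short sketch (estimator and initial_parameters are assumed to come from the usage sketch near the top):

optimizer = Minuit2(use_analytic_gradient=True)  # requires an estimator whose gradient() is implemented, e.g. the jax backend
result = optimizer.optimize(estimator, initial_parameters)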
