wip: working on ensembling
Kevin Maik Jablonka committed May 23, 2022
1 parent ee2a4ed commit e0acc3b
Showing 4 changed files with 63 additions and 20 deletions.
7 changes: 5 additions & 2 deletions src/pyepal/pal/core.py
@@ -300,7 +300,7 @@ def _get_max_wt( # pylint:disable=too-many-arguments
"""
max_uncertainty = -np.inf
maxid = 0

uncertainties = []
pooling_method = pooling_method.lower()

for i in range(0, len(unclassified_t)): # pylint:disable=consider-using-enumerate
@@ -316,11 +316,12 @@ def _get_max_wt( # pylint:disable=too-many-arguments
uncer = rectangle_ups[i, :] - rectangle_lows[i, :]

uncertainty = _pool(uncer, pooling_method)
uncertainties.append(uncertainty)
if uncertainty > max_uncertainty:
max_uncertainty = uncertainty
maxid = i

return maxid, max_uncertainty
return maxid, uncertainties
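
For reference, _pool is what collapses the per-objective uncertainty vector uncer into one scalar. A minimal sketch of the three pooling modes named in sample()'s (now removed) docstring below; _pool_sketch is a hypothetical stand-in, not pyepal's actual implementation:

    import numpy as np

    def _pool_sketch(uncer: np.ndarray, pooling_method: str) -> float:
        # "fro": Euclidean/Frobenius norm of the per-objective uncertainties
        if pooling_method == "fro":
            return float(np.linalg.norm(uncer))
        if pooling_method == "mean":
            return float(np.mean(uncer))
        if pooling_method == "median":
            return float(np.median(uncer))
        raise ValueError(f"Unknown pooling method: {pooling_method}")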


@jit(nopython=True)
@@ -355,6 +356,7 @@ def _get_max_wt_all( # pylint:disable=too-many-arguments
"""
max_uncertainty = -np.inf
maxid = 0
uncertainties = []

pooling_method = pooling_method.lower()

@@ -370,6 +372,7 @@ def _get_max_wt_all( # pylint:disable=too-many-arguments
else:
uncer = rectangle_ups[i, :] - rectangle_lows[i, :]
uncertainty = _pool(uncer, pooling_method)
uncertainties.append(uncertainty)
if uncertainty > max_uncertainty:
max_uncertainty = uncertainty
maxid = i
17 changes: 6 additions & 11 deletions src/pyepal/pal/pal_base.py
@@ -20,7 +20,7 @@
import logging
import warnings
from copy import deepcopy
from typing import Iterable, List, Union
from typing import Iterable, List, Tuple, Union

import numpy as np
from sklearn.metrics import mean_absolute_error
@@ -135,6 +135,7 @@ def __init__( # pylint:disable=too-many-arguments
# measurement_uncertainty is provided in update_train_set by the user
self.measurement_uncertainty = np.zeros((design_space_size, self.ndim))
self._has_train_set = False
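# method used to pool per-objective uncertainties into one scalar ("fro", "mean", or "median")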
self.pooling_method = pooling_method

def __repr__(self):
return f"pyepal at iteration {self.iteration}. \
@@ -513,7 +514,6 @@ def run_one_step( # pylint:disable=too-many-arguments
for _ in range(batch_size):
sampled_idx = self.sample(
exclude_idx=samples,
pooling_method=self.pooling_method,
sample_discarded=sample_discarded,
use_coef_var=use_coef_var,
)
@@ -705,21 +705,16 @@ def augment_design_space( # pylint: disable=invalid-name
def sample(
self,
exclude_idx: Union[np.array, None] = None,
pooling_method: str = "fro",
sample_discarded: bool = False,
use_coef_var: bool = True,
) -> int:
) -> Tuple[int, float]:
"""Runs the sampling step based on the size of the hyperrectangle.
I.e., favoring exploration.
Args:
exclude_idx (Union[np.array, None], optional):
Points in design space to exclude from sampling.
Defaults to None.
pooling_method (str): Method that is used to aggregate
the uncertainty in different objectives into one scalar.
Available options are: "fro" (Frobenius/Euclidean norm), "mean",
"median". Defaults to "fro".
sample_discarded (bool): if true, it will sample from all points
and not only from the unclassified and Pareto optimal ones
use_coef_var (bool): If True, uses the coefficient of variation instead of
@@ -754,7 +749,7 @@ def sample(
self.rectangle_ups,
self._means,
sampled_mask,
pooling_method,
self.pooling_method,
use_coef_var,
)
else:
@@ -765,8 +760,8 @@ def sample(
self.pareto_optimal,
self.unclassified,
sampled_mask,
pooling_method,
self.pooling_method,
use_coef_var,
)

return sampled_idx
return sampled_idx, _uncertainty
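
With this change, sample() hands back the pooled uncertainty alongside the sampled index. A minimal usage sketch, assuming pal is any fitted PAL subclass and measurements is a hypothetical array of ground-truth objective values:

    idx, uncertainty = pal.sample()
    pal.update_train_set(np.array([idx]), measurements[np.array([idx])])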
33 changes: 26 additions & 7 deletions src/pyepal/pal/pal_ensemble.py
@@ -2,20 +2,29 @@


class PALEnsemble:
def __init__(self, pal_list):
def __init__(self, pal_list, reuse_models=False):
self.pal_list = pal_list

# we just pick one class where we will update the models
self.head_pal = pal_list[0]
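# if reuse_models is True, the other members reuse the head's models in run_one_step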
self.reuse_models = reuse_models

@classmethod
def from_class_and_kwarg_lists(cls, pal_class, **kwargs):

# Throw error if there are no kwargs
if not kwargs:
raise ValueError("No kwargs provided")

pal_list = []
iterable_keys = []
for key, value in kwargs.items():
if isinstance(value, (list, tuple)):
iterable_keys.append(key)

# The problem here is that some arguments are themselves iterable while others are not,
# and we would still need to tell those apart. It would be much easier to simply accept
# a separate set of kwargs for every model.

if len(iterable_keys) == 0:
raise ValueError(
"No iterable keys found in kwargs. If you do not provide iterable keys, please use a single PAL instance."
@@ -42,7 +51,6 @@ def from_class_and_kwarg_lists(cls, pal_class, **kwargs):
def run_one_step(
self,
batch_size: int = 1,
pooling_method: str = "fro",
sample_discarded: bool = False,
use_coef_var: bool = True,
replace_mean: bool = True,
@@ -51,23 +59,34 @@ def run_one_step(
samples = []
uncertainties = []
head_samples, head_uncertainties = self.head_pal.run_one_step(
batch_size, pooling_method, sample_discarded, use_coef_var, replace_mean, replace_std
batch_size, sample_discarded, use_coef_var, replace_mean, replace_std
)
samples.extend(head_samples)
uncertainties.extend(head_uncertainties)

if isinstance(head_samples, int):
head_samples = [head_samples]
if isinstance(head_uncertainties, float):
head_uncertainties = [head_uncertainties]
uncertainties.extend(head_uncertainties)
samples.extend(head_samples)

for pal in self.pal_list[1:]:
this_samples, this_uncertainties = pal.run_one_step(
batch_size,
pooling_method,
sample_discarded,
use_coef_var,
replace_mean,
replace_std,
replace_models=self.head_pal.models,
replacement_models=self.head_pal.models if self.reuse_models else None,
)

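# standardize (z-score) this member's uncertainties so scores from different PAL instances are comparable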
this_uncertainties = np.array(this_uncertainties)
this_uncertainties = (
this_uncertainties - this_uncertainties.mean()
) / this_uncertainties.std()
if isinstance(this_samples, int):
this_samples = [this_samples]
if isinstance(this_uncertainties, float):
this_uncertainties = [this_uncertainties]
samples.extend(this_samples)
uncertainties.extend(this_uncertainties)
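
The remainder of this method is cut off in the diff; presumably the collected indices are ranked by their standardized uncertainty before being returned. A hedged sketch of that final step, under that assumption:

    # rank candidates so the most uncertain (after z-scoring) come first
    order = np.argsort(uncertainties)[::-1]
    samples = [samples[i] for i in order]
    uncertainties = [uncertainties[i] for i in order]
    return samples, uncertainties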

26 changes: 26 additions & 0 deletions tests/test_pal_ensemble.py
@@ -0,0 +1,26 @@
import numpy as np
import pytest

from pyepal.pal.pal_ensemble import PALEnsemble


def test_pal_ensemble_init(make_random_dataset):
from pyepal.pal.pal_gpy import PALGPy
from pyepal.models.gpr import build_model

X, y = make_random_dataset
sample_idx = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
# with pytest.raises(ValueError):
#     # Shouldn't work if there are no kwargs
#     ensemble = PALEnsemble.from_class_and_kwarg_lists(PALGPy, [])
m0 = build_model(X, y, 0) # pylint:disable=invalid-name
m1 = build_model(X, y, 1) # pylint:disable=invalid-name
m2 = build_model(X, y, 2) # pylint:disable=invalid-name

palgpy_instance = PALGPy(X, models=[m0, m1, m2], ndim=3, delta=0.01, pooling_method="fro")
palgpy_instance_2 = PALGPy(X, models=[m0, m1, m2], ndim=3, delta=0.01, pooling_method="mean")

pal_ensemble = PALEnsemble([palgpy_instance, palgpy_instance_2])
pal_ensemble.update_train_set(sample_idx, y[sample_idx])
sample, _ = pal_ensemble.run_one_step(1)
assert len(sample) == 1
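
The commented-out lines above target from_class_and_kwarg_lists; a hedged sketch of the fan-out that constructor appears to implement, reconstructed from the iterable_keys logic in pal_ensemble.py (the actual construction loop is elided from the diff, and it would still trip over inherently iterable arguments such as models, as the wip comment there notes):

    def fan_out_sketch(pal_class, **kwargs):
        # hypothetical reconstruction: list/tuple-valued kwargs supply one value
        # per ensemble member; scalar kwargs are shared by all members
        iterable_keys = [k for k, v in kwargs.items() if isinstance(v, (list, tuple))]
        n_members = len(kwargs[iterable_keys[0]])
        return [
            pal_class(**{k: v[i] if k in iterable_keys else v for k, v in kwargs.items()})
            for i in range(n_members)
        ]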
