Skip to content

Commit

Permalink
Merge branch 'develop' into uri/update_contribution_guidelines
Browse files Browse the repository at this point in the history
  • Loading branch information
uri-granta authored Jan 15, 2024
2 parents 50bdae2 + 82f929d commit 4176eaf
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 7 deletions.
9 changes: 6 additions & 3 deletions tests/integration/test_ask_tell_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import copy
import pickle
import tempfile
from typing import Callable, Tuple, Union
from typing import Callable, Mapping, Tuple, Union

import numpy.testing as npt
import pytest
Expand All @@ -36,14 +36,15 @@
from trieste.acquisition.utils import copy_to_local_models
from trieste.ask_tell_optimization import AskTellOptimizer
from trieste.bayesian_optimizer import OptimizationResult, Record
from trieste.data import Dataset
from trieste.logging import set_step_number, tensorboard_writer
from trieste.models import TrainableProbabilisticModel
from trieste.models.gpflow import GaussianProcessRegression, build_gpr
from trieste.objectives import ScaledBranin, SimpleQuadratic
from trieste.objectives.utils import mk_batch_observer, mk_observer
from trieste.observer import OBJECTIVE
from trieste.space import Box, SearchSpace
from trieste.types import State, TensorType
from trieste.types import State, Tag, TensorType

# Optimizer parameters for testing against the branin function.
# We use a copy of these for a quicker test against a simple quadratic function
Expand Down Expand Up @@ -212,7 +213,9 @@ def _test_ask_tell_optimization_finds_minima(

# If query points are rank 3, then use a batched observer.
if tf.rank(new_point) == 3:
new_data_point = batch_observer(new_point)
new_data_point: Union[Mapping[Tag, Dataset], Dataset] = batch_observer(
new_point
)
else:
new_data_point = observer(new_point)

Expand Down
28 changes: 28 additions & 0 deletions tests/unit/acquisition/test_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import copy
from collections.abc import Mapping
from typing import Callable, Optional
from unittest.mock import ANY, MagicMock

import gpflow
import numpy as np
Expand Down Expand Up @@ -1798,6 +1799,33 @@ def test_multi_trust_region_box_updated_datasets_are_in_regions(
)


def test_multi_trust_region_box_acquire_filters() -> None:
# Create some dummy models and datasets
models: Mapping[Tag, ANY] = {"global_tag": MagicMock()}
datasets: Mapping[Tag, ANY] = {
LocalizedTag("tag1", 1): MagicMock(),
LocalizedTag("tag1", 2): MagicMock(),
LocalizedTag("tag2", 1): MagicMock(),
LocalizedTag("tag2", 2): MagicMock(),
"global_tag": MagicMock(),
}

search_space = Box([0.0], [1.0])
mock_base_rule = MagicMock(spec=EfficientGlobalOptimization)
mock_base_rule.acquire.return_value = tf.constant([[[0.0], [0.0]]], dtype=tf.float64)

# Create a BatchTrustRegionBox instance with the mock base_rule.
subspaces = [SingleObjectiveTrustRegionBox(search_space) for _ in range(2)]
rule: BatchTrustRegionBox[ProbabilisticModel] = BatchTrustRegionBox(subspaces, mock_base_rule)

rule.acquire(search_space, models, datasets)(None)

# Only the global tags should be passed to the base_rule acquire call.
mock_base_rule.acquire.assert_called_once_with(
ANY, models, {"global_tag": datasets["global_tag"]}
)


def test_multi_trust_region_box_state_deepcopy() -> None:
search_space = Box([0.0, 0.0], [1.0, 1.0])
dataset = Dataset(
Expand Down
18 changes: 15 additions & 3 deletions trieste/acquisition/rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -1234,8 +1234,8 @@ def acquire(
# Otherwise, run the base rule as is (i.e as a batch), once with all models and datasets.
# Note: this should only trigger on the first call to `acquire`, as after that we will
# have a list of rules in `self._rules`.
if self._rules is None and (
_num_local_models > 0 or not isinstance(self._rule, EfficientGlobalOptimization)
if self._rules is None and not (
_num_local_models == 0 and isinstance(self._rule, EfficientGlobalOptimization)
):
self._rules = [copy.deepcopy(self._rule) for _ in range(num_subspaces)]

Expand Down Expand Up @@ -1282,7 +1282,19 @@ def state_func(
_points.append(rule.acquire(subspace, _models, _datasets))
points = tf.stack(_points, axis=1)
else:
points = self._rule.acquire(acquisition_space, models, datasets)
# Filter out local datasets as this is a rule (currently only EGO) with normal
# acquisition functions that don't expect local datasets.
# Note: no need to filter out local models, as setups with local models
# are handled above (i.e. we run the base rule sequentially for each subspace).
if datasets is not None:
_datasets = {
tag: dataset
for tag, dataset in datasets.items()
if not LocalizedTag.from_tag(tag).is_local
}
else:
_datasets = None
points = self._rule.acquire(acquisition_space, models, _datasets)

# We may modify the regions in filter_datasets later, so return a copy.
state_ = BatchTrustRegion.State(copy.deepcopy(acquisition_space))
Expand Down
2 changes: 1 addition & 1 deletion trieste/objectives/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def mk_multi_observer(**kwargs: Callable[[TensorType], TensorType]) -> MultiObse
def mk_batch_observer(
objective_or_observer: Union[Callable[[TensorType], TensorType], Observer],
default_key: Tag = OBJECTIVE,
) -> Observer:
) -> MultiObserver:
"""
Create an observer that returns the data from ``objective`` or an existing ``observer``
separately for each query point in a batch.
Expand Down

0 comments on commit 4176eaf

Please sign in to comment.