Filter-out local datasets when calling base-rule #805

Merged (3 commits) on Jan 12, 2024
Changes from 2 commits
9 changes: 6 additions & 3 deletions tests/integration/test_ask_tell_optimization.py
@@ -16,7 +16,7 @@
import copy
import pickle
import tempfile
-from typing import Callable, Tuple, Union
+from typing import Callable, Mapping, Tuple, Union

import numpy.testing as npt
import pytest
@@ -36,14 +36,15 @@
from trieste.acquisition.utils import copy_to_local_models
from trieste.ask_tell_optimization import AskTellOptimizer
from trieste.bayesian_optimizer import OptimizationResult, Record
+from trieste.data import Dataset
from trieste.logging import set_step_number, tensorboard_writer
from trieste.models import TrainableProbabilisticModel
from trieste.models.gpflow import GaussianProcessRegression, build_gpr
from trieste.objectives import ScaledBranin, SimpleQuadratic
from trieste.objectives.utils import mk_batch_observer, mk_observer
from trieste.observer import OBJECTIVE
from trieste.space import Box, SearchSpace
-from trieste.types import State, TensorType
+from trieste.types import State, Tag, TensorType

# Optimizer parameters for testing against the branin function.
# We use a copy of these for a quicker test against a simple quadratic function
@@ -212,7 +213,9 @@ def _test_ask_tell_optimization_finds_minima(

# If query points are rank 3, then use a batched observer.
if tf.rank(new_point) == 3:
-    new_data_point = batch_observer(new_point)
+    new_data_point: Union[Mapping[Tag, Dataset], Dataset] = batch_observer(
+        new_point
+    )
else:
    new_data_point = observer(new_point)

28 changes: 28 additions & 0 deletions tests/unit/acquisition/test_rule.py
@@ -16,6 +16,7 @@
import copy
from collections.abc import Mapping
from typing import Callable, Optional
+from unittest.mock import ANY, MagicMock

import gpflow
import numpy as np
@@ -1798,6 +1799,33 @@ def test_multi_trust_region_box_updated_datasets_are_in_regions(
)


+def test_multi_trust_region_box_acquire_filters() -> None:
+    # Create some dummy models and datasets
+    models: Mapping[Tag, ANY] = {"global_tag": MagicMock()}
+    datasets: Mapping[Tag, ANY] = {
+        LocalizedTag("tag1", 1): MagicMock(),
+        LocalizedTag("tag1", 2): MagicMock(),
+        LocalizedTag("tag2", 1): MagicMock(),
+        LocalizedTag("tag2", 2): MagicMock(),
+        "global_tag": MagicMock(),
+    }
+
+    search_space = Box([0.0], [1.0])
+    mock_base_rule = MagicMock(spec=EfficientGlobalOptimization)
+    mock_base_rule.acquire.return_value = tf.constant([[[0.0], [0.0]]], dtype=tf.float64)
+
+    # Create a BatchTrustRegionBox instance with the mock base_rule.
+    subspaces = [SingleObjectiveTrustRegionBox(search_space) for _ in range(2)]
+    rule = BatchTrustRegionBox(subspaces, mock_base_rule)  # type: ignore[var-annotated]

Collaborator:

(what's the mypy error here?)

Collaborator Author:

The error is:

    tests/unit/acquisition/test_rule.py:1819: error: Need type annotation for "rule" [var-annotated]

It wants the type of the rule spelled out, as it seems it can't figure that out itself (via the base_rule). This is a standard error in most instantiations of TR rules, as they have a base rule (I think). It can be avoided by using the following (for this instance); however, for unit tests I chose not to be that verbose, since we instantiate the rules in many places.

    rule: BatchTrustRegionBox[ProbabilisticModel] = BatchTrustRegionBox(subspaces, mock_base_rule)
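
(As an aside, a minimal self-contained sketch of the mypy behaviour described above, with hypothetical names rather than trieste code: when a generic class's type parameter cannot be inferred from the constructor arguments, mypy asks for an explicit annotation.)

    from typing import Generic, TypeVar

    T = TypeVar("T")

    class Rule(Generic[T]):
        def __init__(self, base: object) -> None:  # T does not appear in the signature
            self.base = base

    rule_a = Rule(object())  # error: Need type annotation for "rule_a"  [var-annotated]
    rule_b: Rule[int] = Rule(object())  # OK: the type parameter is spelled out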


+    rule.acquire(search_space, models, datasets)(None)
+
+    # Only the global tags should be passed to the base_rule acquire call.
+    mock_base_rule.acquire.assert_called_once_with(
+        ANY, models, {"global_tag": datasets["global_tag"]}
+    )

def test_multi_trust_region_box_state_deepcopy() -> None:
    search_space = Box([0.0, 0.0], [1.0, 1.0])
    dataset = Dataset(
14 changes: 13 additions & 1 deletion trieste/acquisition/rule.py
@@ -1282,7 +1282,19 @@ def state_func(
        _points.append(rule.acquire(subspace, _models, _datasets))
    points = tf.stack(_points, axis=1)
else:
-    points = self._rule.acquire(acquisition_space, models, datasets)
+    # Filter out local datasets as this is the EGO rule with normal acquisition

Collaborator:
The fact that self._rules implies (EGO + no local models) isn't obvious here. Maybe rename ._rules to ._subspace_rules and change the initialisation (above) to:

        if self._rules is None and not (
            _num_local_models == 0 and isinstance(self._rule, EfficientGlobalOptimization)
        ):

?

Collaborator Author (@khurram-ghani), Jan 12, 2024:
I am not sure about renaming to _subspace_rules, but am happy to apply De Morgan's law to the condition expression.

Also, I'll change the comment slightly so it's not so definitive about being EGO. That is the case for now, but we may support more rules in the future.
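
For reference, a sketch of what the condition could look like with De Morgan's law applied to the suggested expression (the surrounding initialisation code is outside this diff, so the exact merged form is an assumption):

    # not (A and B)  ==  (not A) or (not B)
    if self._rules is None and (
        _num_local_models != 0 or not isinstance(self._rule, EfficientGlobalOptimization)
    ):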

+    # functions that don't expect local datasets.
+    # Note: no need to filter out local models, as setups with local models
+    # are handled above (i.e. we run the base rule sequentially for each subspace).
+    if datasets is not None:
+        _datasets = {
+            tag: dataset
+            for tag, dataset in datasets.items()
+            if not LocalizedTag.from_tag(tag).is_local
+        }
+    else:
+        _datasets = None
+    points = self._rule.acquire(acquisition_space, models, _datasets)

# We may modify the regions in filter_datasets later, so return a copy.
state_ = BatchTrustRegion.State(copy.deepcopy(acquisition_space))
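For illustration, a small sketch of how the new filter behaves on a mixed tag mapping; the import path for LocalizedTag is an assumption (it is not shown in this diff), and the string values stand in for real Dataset objects:

    from trieste.types import LocalizedTag  # assumed import path; not shown in this diff

    datasets = {
        "global_tag": "global data",              # plain (global) tag
        LocalizedTag("tag1", 0): "local data 0",  # local tags, one per subspace
        LocalizedTag("tag1", 1): "local data 1",
    }

    # Same comprehension as in the new code above: keep only the non-local entries.
    filtered = {
        tag: dataset
        for tag, dataset in datasets.items()
        if not LocalizedTag.from_tag(tag).is_local
    }

    assert filtered == {"global_tag": "global data"}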
2 changes: 1 addition & 1 deletion trieste/objectives/utils.py
@@ -65,7 +65,7 @@ def mk_multi_observer(**kwargs: Callable[[TensorType], TensorType]) -> MultiObse
def mk_batch_observer(
    objective_or_observer: Union[Callable[[TensorType], TensorType], Observer],
    default_key: Tag = OBJECTIVE,
-) -> Observer:
+) -> MultiObserver:
    """
    Create an observer that returns the data from ``objective`` or an existing ``observer``
    separately for each query point in a batch.
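Downstream, callers can now rely on always getting a tagged mapping back from the batch observer. A hedged usage sketch, mirroring the integration-test change earlier in this PR; the rank-3 batch shape and the choice of ScaledBranin are illustrative assumptions:

    import tensorflow as tf

    from trieste.objectives import ScaledBranin
    from trieste.objectives.utils import mk_batch_observer

    batch_observer = mk_batch_observer(ScaledBranin.objective)
    new_point = tf.random.uniform([3, 2, 2], dtype=tf.float64)  # rank-3 batch of query points
    new_data = batch_observer(new_point)  # a Mapping[Tag, Dataset]
    for tag, dataset in new_data.items():
        print(tag, dataset.query_points.shape)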