Skip to content

Commit

Permalink
Only pass global datasets to EGO parallel rule
Browse files Browse the repository at this point in the history
  • Loading branch information
khurram-ghani committed Jan 10, 2024
1 parent 7bdb0b3 commit 683c38d
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 1 deletion.
28 changes: 28 additions & 0 deletions tests/unit/acquisition/test_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import copy
from collections.abc import Mapping
from typing import Callable, Optional
from unittest.mock import ANY, MagicMock

import gpflow
import numpy as np
Expand Down Expand Up @@ -1798,6 +1799,33 @@ def test_multi_trust_region_box_updated_datasets_are_in_regions(
)


def test_multi_trust_region_box_acquire_filters() -> None:
# Create some dummy models and datasets
models: Mapping[Tag, ANY] = {"global_tag": MagicMock()}
datasets: Mapping[Tag, ANY] = {
LocalizedTag("tag1", 1): MagicMock(),
LocalizedTag("tag1", 2): MagicMock(),
LocalizedTag("tag2", 1): MagicMock(),
LocalizedTag("tag2", 2): MagicMock(),
"global_tag": MagicMock(),
}

search_space = Box([0.0], [1.0])
mock_base_rule = MagicMock(spec=EfficientGlobalOptimization)
mock_base_rule.acquire.return_value = tf.constant([[[0.0], [0.0]]], dtype=tf.float64)

# Create a BatchTrustRegionBox instance with the mock base_rule.
subspaces = [SingleObjectiveTrustRegionBox(search_space) for _ in range(2)]
rule = BatchTrustRegionBox(subspaces, mock_base_rule) # type: ignore[var-annotated]

rule.acquire(search_space, models, datasets)(None)

# Only the global tags should be passed to the base_rule acquire call.
mock_base_rule.acquire.assert_called_once_with(
ANY, models, {"global_tag": datasets["global_tag"]}
)


def test_multi_trust_region_box_state_deepcopy() -> None:
search_space = Box([0.0, 0.0], [1.0, 1.0])
dataset = Dataset(
Expand Down
14 changes: 13 additions & 1 deletion trieste/acquisition/rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -1282,7 +1282,19 @@ def state_func(
_points.append(rule.acquire(subspace, _models, _datasets))
points = tf.stack(_points, axis=1)
else:
points = self._rule.acquire(acquisition_space, models, datasets)
# Filter out local datasets as this is the EGO rule with normal acquisition
# functions that don't expect local datasets.
# Note: no need to filter out local models, as setups with local models
# are handled above (i.e. we run the base rule sequentially for each subspace).
if datasets is not None:
_datasets = {
tag: dataset
for tag, dataset in datasets.items()
if not LocalizedTag.from_tag(tag).is_local
}
else:
_datasets = None
points = self._rule.acquire(acquisition_space, models, _datasets)

# We may modify the regions in filter_datasets later, so return a copy.
state_ = BatchTrustRegion.State(copy.deepcopy(acquisition_space))
Expand Down

0 comments on commit 683c38d

Please sign in to comment.