diff --git a/tests/unit/acquisition/test_rule.py b/tests/unit/acquisition/test_rule.py index 1287c70393..b73c1794a8 100644 --- a/tests/unit/acquisition/test_rule.py +++ b/tests/unit/acquisition/test_rule.py @@ -16,6 +16,7 @@ import copy from collections.abc import Mapping from typing import Callable, Optional +from unittest.mock import ANY, MagicMock import gpflow import numpy as np @@ -1798,6 +1799,33 @@ def test_multi_trust_region_box_updated_datasets_are_in_regions( ) +def test_multi_trust_region_box_acquire_filters() -> None: + # Create some dummy models and datasets + models: Mapping[Tag, ANY] = {"global_tag": MagicMock()} + datasets: Mapping[Tag, ANY] = { + LocalizedTag("tag1", 1): MagicMock(), + LocalizedTag("tag1", 2): MagicMock(), + LocalizedTag("tag2", 1): MagicMock(), + LocalizedTag("tag2", 2): MagicMock(), + "global_tag": MagicMock(), + } + + search_space = Box([0.0], [1.0]) + mock_base_rule = MagicMock(spec=EfficientGlobalOptimization) + mock_base_rule.acquire.return_value = tf.constant([[[0.0], [0.0]]], dtype=tf.float64) + + # Create a BatchTrustRegionBox instance with the mock base_rule. + subspaces = [SingleObjectiveTrustRegionBox(search_space) for _ in range(2)] + rule = BatchTrustRegionBox(subspaces, mock_base_rule) # type: ignore[var-annotated] + + rule.acquire(search_space, models, datasets)(None) + + # Only the global tags should be passed to the base_rule acquire call. + mock_base_rule.acquire.assert_called_once_with( + ANY, models, {"global_tag": datasets["global_tag"]} + ) + + def test_multi_trust_region_box_state_deepcopy() -> None: search_space = Box([0.0, 0.0], [1.0, 1.0]) dataset = Dataset( diff --git a/trieste/acquisition/rule.py b/trieste/acquisition/rule.py index 1ed6353303..fa7b695746 100644 --- a/trieste/acquisition/rule.py +++ b/trieste/acquisition/rule.py @@ -1282,7 +1282,19 @@ def state_func( _points.append(rule.acquire(subspace, _models, _datasets)) points = tf.stack(_points, axis=1) else: - points = self._rule.acquire(acquisition_space, models, datasets) + # Filter out local datasets as this is the EGO rule with normal acquisition + # functions that don't expect local datasets. + # Note: no need to filter out local models, as setups with local models + # are handled above (i.e. we run the base rule sequentially for each subspace). + if datasets is not None: + _datasets = { + tag: dataset + for tag, dataset in datasets.items() + if not LocalizedTag.from_tag(tag).is_local + } + else: + _datasets = None + points = self._rule.acquire(acquisition_space, models, _datasets) # We may modify the regions in filter_datasets later, so return a copy. state_ = BatchTrustRegion.State(copy.deepcopy(acquisition_space))