Skip to content

Commit

Permalink
Fix batch observer return type to by MultiObserver
Browse files Browse the repository at this point in the history
  • Loading branch information
khurram-ghani committed Jan 10, 2024
1 parent b4ff3f6 commit 7bdb0b3
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
9 changes: 6 additions & 3 deletions tests/integration/test_ask_tell_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import copy
import pickle
import tempfile
from typing import Callable, Tuple, Union
from typing import Callable, Mapping, Tuple, Union

import numpy.testing as npt
import pytest
Expand All @@ -36,14 +36,15 @@
from trieste.acquisition.utils import copy_to_local_models
from trieste.ask_tell_optimization import AskTellOptimizer
from trieste.bayesian_optimizer import OptimizationResult, Record
from trieste.data import Dataset
from trieste.logging import set_step_number, tensorboard_writer
from trieste.models import TrainableProbabilisticModel
from trieste.models.gpflow import GaussianProcessRegression, build_gpr
from trieste.objectives import ScaledBranin, SimpleQuadratic
from trieste.objectives.utils import mk_batch_observer, mk_observer
from trieste.observer import OBJECTIVE
from trieste.space import Box, SearchSpace
from trieste.types import State, TensorType
from trieste.types import State, Tag, TensorType

# Optimizer parameters for testing against the branin function.
# We use a copy of these for a quicker test against a simple quadratic function
Expand Down Expand Up @@ -212,7 +213,9 @@ def _test_ask_tell_optimization_finds_minima(

# If query points are rank 3, then use a batched observer.
if tf.rank(new_point) == 3:
new_data_point = batch_observer(new_point)
new_data_point: Union[Mapping[Tag, Dataset], Dataset] = batch_observer(
new_point
)
else:
new_data_point = observer(new_point)

Expand Down
2 changes: 1 addition & 1 deletion trieste/objectives/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def mk_multi_observer(**kwargs: Callable[[TensorType], TensorType]) -> MultiObse
def mk_batch_observer(
objective_or_observer: Union[Callable[[TensorType], TensorType], Observer],
default_key: Tag = OBJECTIVE,
) -> Observer:
) -> MultiObserver:
"""
Create an observer that returns the data from ``objective`` or an existing ``observer``
separately for each query point in a batch.
Expand Down

0 comments on commit 7bdb0b3

Please sign in to comment.