Skip to content

Commit

Permalink
Use flatten_... func and add comment
Browse files Browse the repository at this point in the history
  • Loading branch information
khurram-ghani committed Dec 13, 2023
1 parent 1de7e20 commit a75e618
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 7 deletions.
4 changes: 2 additions & 2 deletions trieste/acquisition/rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -1010,8 +1010,8 @@ def update(
def _get_tags(self, tags: Set[Tag]) -> Tuple[Set[Tag], Set[Tag]]:
# Separate tags into local (matching index) and global tags (without matching
# local tag).
local_gtags = set()
global_tags = set()
local_gtags = set() # Set of global part of all local tags.
global_tags = set() # Set of all global tags.
for tag in tags:
ltag = LocalizedTag.from_tag(tag)
if not ltag.is_local:
Expand Down
9 changes: 4 additions & 5 deletions trieste/objectives/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,12 @@
from collections.abc import Callable
from typing import Mapping, Optional, Union, overload

import tensorflow as tf
from check_shapes import check_shapes

from ..data import Dataset
from ..observer import OBJECTIVE, MultiObserver, Observer, SingleObserver
from ..types import Tag, TensorType
from ..utils.misc import LocalizedTag
from ..utils.misc import LocalizedTag, flatten_leading_dims


@overload
Expand Down Expand Up @@ -83,7 +82,7 @@ def _observer(qps: TensorType) -> Mapping[Tag, Dataset]:
# Call objective with rank 2 query points by flattening batch dimension.
# Some objectives might only expect rank 2 query points, so this is safer.
batch_size = qps.shape[1]
flat_qps = tf.reshape(qps, [-1, qps.shape[-1]])
flat_qps, unflatten = flatten_leading_dims(qps)
obs_or_dataset = objective_or_observer(flat_qps)

if not isinstance(obs_or_dataset, (Mapping, Dataset)):
Expand All @@ -98,8 +97,8 @@ def _observer(qps: TensorType) -> Mapping[Tag, Dataset]:
for key, dataset in obs_or_dataset.items():
# Include overall dataset and per batch dataset.
flat_obs = dataset.observations
qps = tf.reshape(flat_qps, [-1, batch_size, flat_qps.shape[-1]])
obs = tf.reshape(flat_obs, [-1, batch_size, flat_obs.shape[-1]])
qps = unflatten(flat_qps)
obs = unflatten(flat_obs)
datasets[key] = dataset
for i in range(batch_size):
datasets[LocalizedTag(key, i)] = Dataset(qps[:, i], obs[:, i])
Expand Down

0 comments on commit a75e618

Please sign in to comment.