Skip to content

Commit

Permalink
Add support for custom initial_sampler
Browse files Browse the repository at this point in the history
  • Loading branch information
Uri Granta committed Jan 31, 2024
1 parent c918ed2 commit a281406
Showing 1 changed file with 58 additions and 2 deletions.
60 changes: 58 additions & 2 deletions trieste/acquisition/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@

from __future__ import annotations

from typing import Any, Callable, List, Optional, Sequence, Tuple, Union, cast
import itertools
from typing import Any, Callable, Iterator, List, Optional, Sequence, Tuple, Union, cast

import greenlet as gr
import numpy as np
Expand Down Expand Up @@ -168,12 +169,63 @@ def optimize_discrete(
return _get_max_discrete_points(points, target_func)


InitialPointSampler = Callable[[SearchSpace, int, int], TensorType]
"""
Type alias for a function that generates initial samples for an optimizer.
Takes a search space, the number of samples to generate, and an offset, and returns
that many samples (at the offset).
"""


def sample_from_space(space: SearchSpace, num_samples: int, offset: int) -> TensorType:
    """
    Default initial point sampler, which draws ``num_samples`` random points from the
    search space. The ``offset`` argument is ignored, as random samples have no ordering.
    """
    del offset  # unused: random sampling needs no notion of position
    return space.sample(num_samples)


def sample_from_sequence(sequence: Sequence[TensorType]) -> InitialPointSampler:
    """
    Initial point sampler that returns points from a prebuilt sequence (e.g. a list or Tensor),
    sliced along its first dimension.

    :param sequence: A sequence of point tensors, or a single tensor of points.
    :return: An :const:`InitialPointSampler` returning consecutive slices of the sequence.
    :raise ValueError: If there are insufficiently many points to return.
    """

    def _sampler(space: SearchSpace, num_samples: int, offset: int) -> TensorType:
        if len(sequence) < offset + num_samples:
            raise ValueError(
                f"Insufficient samples ({offset+num_samples} required, "
                f"{len(sequence)} available)"
            )
        samples = sequence[offset : offset + num_samples]
        # A slice of a tensor is already a tensor; anything else (e.g. a list of
        # point tensors) must be concatenated. Check the slice we took, `samples` —
        # not the builtin `slice` type, which is never a tf.Tensor.
        if not isinstance(samples, tf.Tensor):
            samples = tf.concat(samples, axis=0)
        return samples

    return _sampler


def sample_from_iterator(iterator: Iterator[TensorType]) -> InitialPointSampler:
    """
    Initial point sampler that returns points from a prebuilt iterator (e.g. a generator).
    Each element yielded by the iterator is a tensor of points, concatenated along axis 0.

    :param iterator: Iterator yielding tensors of points.
    :return: An :const:`InitialPointSampler` that consumes points from the iterator.
        The iterator itself tracks position, so the ``offset`` argument is unused.
    :raise ValueError: If the iterator is exhausted with insufficiently many points.
    """

    def _sampler(space: SearchSpace, num_samples: int, offset: int) -> TensorType:
        # Consume at most num_samples elements; avoid shadowing the builtin `slice`.
        batch = list(itertools.islice(iterator, num_samples))
        # Fail early on an exhausted iterator: tf.concat([]) raises an unhelpful error.
        if not batch:
            raise ValueError(f"Insufficient samples ({num_samples} requested, 0 available)")
        samples = tf.concat(batch, axis=0)
        # Raise ValueError (not RuntimeError) as documented, matching sample_from_sequence.
        if len(samples) < num_samples:
            raise ValueError(
                f"Insufficient samples ({num_samples} requested, {len(samples)} available)"
            )
        return samples

    return _sampler


def generate_continuous_optimizer(
num_initial_samples: int = NUM_SAMPLES_MIN,
num_optimization_runs: int = 10,
num_recovery_runs: int = 10,
optimizer_args: Optional[dict[str, Any]] = None,
split_initial_samples: Optional[int] = 100_000,
initial_sampler: Callable[[SearchSpace, int, int], TensorType] = sample_from_space,
) -> AcquisitionOptimizer[Box | CollectionSearchSpace]:
"""
Generate a gradient-based optimizer for :class:'Box' and :class:'CollectionSearchSpace'
Expand Down Expand Up @@ -205,6 +257,10 @@ def generate_continuous_optimizer(
can be passed. Note that method, jac and bounds cannot/should not be changed.
:param split_initial_samples: Maximum number of initial samples to process at a time.
Decreasing this can reduce memory usage at the start, at the slight cost of performance.
:param initial_sampler: Function for generating initial samples. This should accept
a search space, the number of samples to generate and an offset. It may be called multiple
times (with different offsets) if `split_initial_samples` is specified. By default,
samples are generated by calling `space.sample(num_samples)`.
:return: The acquisition optimizer.
"""
if num_initial_samples <= 0:
Expand Down Expand Up @@ -271,7 +327,7 @@ def optimize_continuous(
samples = max(min(samples_left, split_initial_samples // V), 1)
samples_left -= samples

candidates = space.sample(samples)
candidates = initial_sampler(space, samples, num_initial_samples - samples_left)
if tf.rank(candidates) == 3:
# If samples is a tensor of rank 3, then it is a batch of samples. In this case
# the vectorization of the target function must be a multiple of the length of the
Expand Down

0 comments on commit a281406

Please sign in to comment.