Skip to content

Commit

Permalink
Updated optimize_objective (input_constructors.py) to allow specifi…
Browse files Browse the repository at this point in the history
…cation of acquisition function as kwarg

Summary: Support an optional acquisition function instance as a keyword argument, useful for multi-fidelity acquisition functions.

Differential Revision: D62380369
  • Loading branch information
ltiao authored and facebook-github-bot committed Sep 9, 2024
1 parent 16853b4 commit 04dcb46
Showing 1 changed file with 16 additions and 12 deletions.
28 changes: 16 additions & 12 deletions botorch/acquisition/input_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -1532,6 +1532,7 @@ def optimize_objective(
model: Model,
bounds: Tensor,
q: int,
acq_function: Optional[AcquisitionFunction] = None,
objective: Optional[MCAcquisitionObjective] = None,
posterior_transform: Optional[PosteriorTransform] = None,
linear_constraints: Optional[tuple[Tensor, Tensor]] = None,
Expand Down Expand Up @@ -1574,18 +1575,21 @@ def optimize_objective(
if optimizer_options is None:
optimizer_options = {}

if objective is not None:
sampler_cls = SobolQMCNormalSampler if qmc else IIDNormalSampler
acq_function = qSimpleRegret(
model=model,
objective=objective,
posterior_transform=posterior_transform,
sampler=sampler_cls(sample_shape=torch.Size([mc_samples]), seed=seed_inner),
)
else:
acq_function = PosteriorMean(
model=model, posterior_transform=posterior_transform
)
if acq_function is None:
if objective is None:
acq_function = PosteriorMean(
model=model, posterior_transform=posterior_transform
)
else:
sampler_cls = SobolQMCNormalSampler if qmc else IIDNormalSampler
acq_function = qSimpleRegret(
model=model,
objective=objective,
posterior_transform=posterior_transform,
sampler=sampler_cls(
sample_shape=torch.Size([mc_samples]), seed=seed_inner
),
)

if fixed_features:
acq_function = FixedFeatureAcquisitionFunction(
Expand Down

0 comments on commit 04dcb46

Please sign in to comment.