Skip to content

Commit

Permalink
Add input constructor for qMultiFidelityHypervolumeKnowledgeGradient (p…
Browse files Browse the repository at this point in the history
…ytorch#2524)

Summary:
Pull Request resolved: pytorch#2524

Adds new input constructors for qMultiFidelityHypervolumeKnowledgeGradient.

Reviewed By: Balandat

Differential Revision: D62459735
  • Loading branch information
ltiao authored and facebook-github-bot committed Sep 12, 2024
1 parent 127fc08 commit 9e8b786
Show file tree
Hide file tree
Showing 3 changed files with 218 additions and 84 deletions.
99 changes: 84 additions & 15 deletions botorch/acquisition/input_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
from botorch.acquisition.multi_objective.hypervolume_knowledge_gradient import (
_get_hv_value_function,
qHypervolumeKnowledgeGradient,
qMultiFidelityHypervolumeKnowledgeGradient,
)
from botorch.acquisition.multi_objective.logei import (
qLogExpectedHypervolumeImprovement,
Expand Down Expand Up @@ -1274,21 +1275,6 @@ def construct_inputs_qKG(
return inputs_qkg


def _get_ref_point(
objective_thresholds: Tensor,
objective: Optional[MCMultiOutputObjective] = None,
) -> Tensor:

if objective is None:
ref_point = objective_thresholds
elif isinstance(objective, RiskMeasureMCObjective):
ref_point = objective.preprocessing_function(objective_thresholds)
else:
ref_point = objective(objective_thresholds)

return ref_point


@acqf_input_constructor(qHypervolumeKnowledgeGradient)
def construct_inputs_qHVKG(
model: Model,
Expand Down Expand Up @@ -1381,6 +1367,74 @@ def construct_inputs_qMFKG(
}


@acqf_input_constructor(qMultiFidelityHypervolumeKnowledgeGradient)
def construct_inputs_qMFHVKG(
model: Model,
training_data: MaybeDict[SupervisedDataset],
bounds: list[tuple[float, float]],
target_fidelities: dict[int, Union[int, float]],
objective_thresholds: Tensor,
objective: Optional[MCMultiOutputObjective] = None,
posterior_transform: Optional[PosteriorTransform] = None,
fidelity_weights: Optional[dict[int, float]] = None,
cost_intercept: float = 1.0,
num_trace_observations: int = 0,
num_fantasies: int = 8,
num_pareto: int = 10,
**optimize_objective_kwargs: TOptimizeObjectiveKwargs,
) -> dict[str, Any]:
r"""Construct kwargs for `qMultiFidelityHypervolumeKnowledgeGradient` constructor."""

inputs_mf = construct_inputs_mf_base(
target_fidelities=target_fidelities,
fidelity_weights=fidelity_weights,
cost_intercept=cost_intercept,
num_trace_observations=num_trace_observations,
)

if num_trace_observations > 0:
raise NotImplementedError(
"Trace observations are not currently supported "
"by `qMultiFidelityHypervolumeKnowledgeGradient`."
)

del inputs_mf["expand"]

X = _get_dataset_field(training_data, "X", first_only=True)
_bounds = torch.as_tensor(bounds, dtype=X.dtype, device=X.device)

ref_point = _get_ref_point(
objective_thresholds=objective_thresholds, objective=objective
)

acq_function = _get_hv_value_function(
model=model,
ref_point=ref_point,
use_posterior_mean=True,
objective=objective,
)

_, current_value = optimize_objective(
model=model,
bounds=_bounds.t(),
q=num_pareto,
acq_function=acq_function,
fixed_features=target_fidelities,
**optimize_objective_kwargs,
)

return {
"model": model,
"objective": objective,
"ref_point": ref_point,
"num_fantasies": num_fantasies,
"num_pareto": num_pareto,
"current_value": current_value.detach().cpu().max(),
"target_fidelities": target_fidelities,
**inputs_mf,
}


@acqf_input_constructor(qMultiFidelityMaxValueEntropy)
def construct_inputs_qMFMES(
model: Model,
Expand Down Expand Up @@ -1806,3 +1860,18 @@ def construct_inputs_NIPV(
"posterior_transform": posterior_transform,
}
return inputs


def _get_ref_point(
objective_thresholds: Tensor,
objective: Optional[MCMultiOutputObjective] = None,
) -> Tensor:

if objective is None:
ref_point = objective_thresholds
elif isinstance(objective, RiskMeasureMCObjective):
ref_point = objective.preprocessing_function(objective_thresholds)
else:
ref_point = objective(objective_thresholds)

return ref_point
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,10 @@ def __init__(
)
self.project = project
if kwargs.get("expand") is not None:
raise NotImplementedError("Trace observations are not currently supported.")
raise NotImplementedError(
"Trace observations are not currently supported "
"by `qMultiFidelityHypervolumeKnowledgeGradient`."
)
self.expand = lambda X: X
self.valfunc_cls = valfunc_cls
self.valfunc_argfac = valfunc_argfac
Expand Down
Loading

0 comments on commit 9e8b786

Please sign in to comment.