Skip to content

Commit

Permalink
fix bug where variational GPs wouldn't use correct likelihoods; use b…
Browse files Browse the repository at this point in the history
…otorch inducing point selection; use botorch default mean/covar; update botorch/ax versions (#323)

Summary:
Pull Request resolved: #323

This ensures variational GPs will use correct likelihoods, and it moves some logic out of AEPsych and into botorch. Anecdotally, the botorch priors seem to produce better models, and since we are so resource-strapped on the AEPsych side, we should rely on botorch logic as much as possible.

Reviewed By: tymmsc

Differential Revision: D48891019

fbshipit-source-id: 24b288850a02c18f4515022e248d725ffe10279e
  • Loading branch information
crasanders authored and facebook-github-bot committed Dec 15, 2023
1 parent e79ef71 commit 9e16a29
Show file tree
Hide file tree
Showing 26 changed files with 519 additions and 276 deletions.
1 change: 1 addition & 0 deletions aepsych/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,4 +372,5 @@ def from_config(cls, config: Config, name: Optional[str] = None):
Config.register_module(gpytorch.likelihoods)
Config.register_module(gpytorch.kernels)
Config.register_module(botorch.acquisition)
Config.register_module(botorch.acquisition.multi_objective)
Config.registered_names["None"] = None
2 changes: 0 additions & 2 deletions aepsych/generators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from .manual_generator import ManualGenerator
from .monotonic_rejection_generator import MonotonicRejectionGenerator
from .monotonic_thompson_sampler_generator import MonotonicThompsonSamplerGenerator
from .multi_outcome_generator import MultiOutcomeOptimizationGenerator
from .optimize_acqf_generator import AxOptimizeAcqfGenerator, OptimizeAcqfGenerator
from .pairwise_optimize_acqf_generator import PairwiseOptimizeAcqfGenerator
from .pairwise_sobol_generator import PairwiseSobolGenerator
Expand All @@ -33,7 +32,6 @@
"AxOptimizeAcqfGenerator",
"AxSobolGenerator",
"IntensityAwareSemiPGenerator",
"MultiOutcomeOptimizationGenerator",
"AxRandomGenerator",
]

Expand Down
27 changes: 0 additions & 27 deletions aepsych/generators/multi_outcome_generator.py

This file was deleted.

7 changes: 2 additions & 5 deletions aepsych/generators/optimize_acqf_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from aepsych.config import Config, ConfigurableMixin
from aepsych.generators.base import AEPsychGenerationStep, AEPsychGenerator
from aepsych.models.base import ModelProtocol
from aepsych.models.surrogate import AEPsychSurrogate
from ax.models.torch.botorch_modular.surrogate import Surrogate
from aepsych.utils_logging import getLogger
from ax.modelbridge import Models
from ax.modelbridge.registry import Cont_X_trans
Expand Down Expand Up @@ -152,14 +152,11 @@ def get_config_options(cls, config: Config, name: str) -> Dict:
acqf_options = cls._get_acqf_options(acqf_cls, config)
gen_options = cls._get_gen_options(config)

max_fit_time = model_options["max_fit_time"]

model_kwargs = {
"surrogate": AEPsychSurrogate(
"surrogate": Surrogate(
botorch_model_class=model_class,
mll_class=model_class.get_mll_class(),
model_options=model_options,
max_fit_time=max_fit_time,
),
"acquisition_class": AEPsychAcquisition,
"botorch_acqf_class": acqf_cls,
Expand Down
10 changes: 8 additions & 2 deletions aepsych/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from .exact_gp import ContinuousRegressionGP, ExactGP
from .gp_classification import GPBetaRegressionModel, GPClassificationModel
from .gp_regression import GPRegressionModel
from .model_list import AEPsychModelListGP
from .monotonic_projection_gp import MonotonicProjectionGP
from .monotonic_rejection_gp import MonotonicRejectionGP
from .multitask_regression import IndependentMultitaskGPRModel, MultitaskGPRModel
Expand All @@ -21,8 +22,12 @@
semi_p_posterior_transform,
SemiParametricGPModel,
)
from .variational_gp import BetaRegressionGP, BinaryClassificationGP, OrdinalGP, VariationalGP

from .variational_gp import (
BetaRegressionGP,
BinaryClassificationGP,
OrdinalGP,
VariationalGP,
)


__all__ = [
Expand All @@ -44,6 +49,7 @@
"semi_p_posterior_transform",
"OrdinalGP",
"GPBetaRegressionModel",
"AEPsychModelListGP",
]

Config.register_module(sys.modules[__name__])
Loading

0 comments on commit 9e16a29

Please sign in to comment.