diff --git a/aepsych/__init__.py b/aepsych/__init__.py index 34427e6ea..47960f050 100644 --- a/aepsych/__init__.py +++ b/aepsych/__init__.py @@ -11,7 +11,16 @@ from gpytorch.likelihoods import BernoulliLikelihood, GaussianLikelihood -from . import acquisition, config, factory, generators, models, strategy, utils +from . import ( + acquisition, + config, + factory, + generators, + models, + strategy, + transforms, + utils, +) from .config import Config from .likelihoods import BernoulliObjectiveLikelihood from .models import GPClassificationModel @@ -26,6 +35,7 @@ "factory", "models", "strategy", + "transforms", "utils", "generators", # classes diff --git a/aepsych/benchmark/benchmark.py b/aepsych/benchmark/benchmark.py index 834551e4b..917690b7e 100644 --- a/aepsych/benchmark/benchmark.py +++ b/aepsych/benchmark/benchmark.py @@ -77,6 +77,7 @@ def make_benchmark_list(self, **bench_config) -> List[Dict[str, float]]: List[dict[str, float]]: List of dictionaries, each of which can be passed to aepsych.config.Config. """ + # This could be a generator but then we couldn't # know how many params we have, tqdm wouldn't work, etc, # so we materialize the full list. @@ -154,6 +155,9 @@ def run_experiment( np.random.seed(seed) config_dict["common"]["lb"] = str(problem.lb.tolist()) config_dict["common"]["ub"] = str(problem.ub.tolist()) + config_dict["common"]["parnames"] = str( + [f"par{i}" for i in range(len(problem.ub.tolist()))] + ) config_dict["problem"] = problem.metadata materialized_config = self.materialize_config(config_dict) diff --git a/aepsych/config.py b/aepsych/config.py index 212b6e3d8..c10302648 100644 --- a/aepsych/config.py +++ b/aepsych/config.py @@ -6,12 +6,23 @@ # LICENSE file in the root directory of this source tree. import abc import ast -import re import configparser import json +import re import warnings from types import ModuleType -from typing import Any, ClassVar, Dict, List, Mapping, Optional, Sequence, TypeVar +from typing import ( + Any, + Callable, + ClassVar, + Dict, + Dict, + List, + Mapping, + Optional, + Sequence, + TypeVar, +) import botorch import gpytorch import numpy as np @@ -21,6 +32,7 @@ _T = TypeVar("_T") + class Config(configparser.ConfigParser): # names in these packages can be referred to by string name @@ -75,7 +87,6 @@ def _get( fallback=configparser._UNSET, **kwargs, ): - """ Override configparser to: 1. Return from common if a section doesn't exist. This comes @@ -107,8 +118,8 @@ def _get( ) # Convert config into a dictionary (eliminate duplicates from defaulted 'common' section.) - def to_dict(self, deduplicate: bool = True) -> dict: - _dict = {} + def to_dict(self, deduplicate: bool = True) -> Dict[str, Any]: + _dict: Dict[str, Any] = {} for section in self: _dict[section] = {} for setting in self[section]: @@ -160,8 +171,10 @@ def update( warnings.warn( "ub and lb have been defined in common section, ignoring parameter specific blocks, be very careful!" ) - elif "parnames" in self["common"]: # it's possible to pass no parnames - par_names = self.getlist("common", "parnames", element_type=str, fallback = []) + elif "parnames" in self["common"]: # it's possible to pass no parnames + par_names = self.getlist( + "common", "parnames", element_type=str, fallback=[] + ) lb = [None] * len(par_names) ub = [None] * len(par_names) for i, par_name in enumerate(par_names): @@ -174,14 +187,15 @@ def update( self["common"]["lb"] = f"[{', '.join(lb)}]" self["common"]["ub"] = f"[{', '.join(ub)}]" - # Deprecation warning for "experiment" section if "experiment" in self: for i in self["experiment"]: self["common"][i] = self["experiment"][i] del self["experiment"] - def _str_to_list(self, v: str, element_type: _T = float) -> List[_T]: + def _str_to_list( + self, v: str, element_type: Callable[[_T], _T] = float + ) -> List[_T]: v = re.sub(r"\n ", ",", v) v = re.sub(r"(? None: # Checking if param_type is set if "par_type" not in param_block: - raise ValueError(f"Parameter {param_name} is missing the param_type setting.") + raise ValueError( + f"Parameter {param_name} is missing the param_type setting." + ) # Each parameter type has a different set of required settings - if param_block['par_type'] == "continuous": + if param_block["par_type"] == "continuous": # Check if bounds exist if "lower_bound" not in param_block: - raise ValueError(f"Parameter {param_name} is missing the lower_bound setting.") + raise ValueError( + f"Parameter {param_name} is missing the lower_bound setting." + ) if "upper_bound" not in param_block: - raise ValueError(f"Parameter {param_name} is missing the upper_bound setting.") + raise ValueError( + f"Parameter {param_name} is missing the upper_bound setting." + ) else: - raise ValueError(f"Parameter {param_name} has an unsupported parameter type {param_block['par_type']}.") - + raise ValueError( + f"Parameter {param_name} has an unsupported parameter type {param_block['par_type']}." + ) def __repr__(self) -> str: return f"Config at {hex(id(self))}: \n {str(self)}" diff --git a/aepsych/config.pyi b/aepsych/config.pyi index adb078bcf..c57a8a9fc 100644 --- a/aepsych/config.pyi +++ b/aepsych/config.pyi @@ -7,10 +7,24 @@ import abc import configparser -from typing import Any, ClassVar, Dict, List, Mapping, Optional, TypeVar, Union +from typing import ( + Any, + Callable, + ClassVar, + Dict, + List, + Mapping, + Optional, + TypeVar, + Union, +) import numpy as np import torch +from botorch.models.transforms.input import ( + ChainedInputTransform, + ReversibleInputTransform, +) _T = TypeVar("_T") _ET = TypeVar("_ET") @@ -50,7 +64,7 @@ class Config(configparser.ConfigParser): raw: bool = ..., vars: Optional[Mapping[str, str]] = ..., fallback: _T = ..., - element_type: _ET = ..., + element_type: Callable[[_ET], _ET] = ..., ) -> Union[_T, List[_ET]]: ... def getarray( self, @@ -61,10 +75,29 @@ class Config(configparser.ConfigParser): vars: Optional[Mapping[str, str]] = ..., fallback: _T = ..., ) -> Union[np.ndarray, _T]: ... + def getboolean( + self, + section: str, + option: str, + *, + raw: bool = ..., + vars: Mapping[str, str] | None = ..., + fallback: _T = ..., + ) -> bool | _T: ... + def getfloat( + self, + section: str, + option: str, + *, + raw: bool = ..., + vars: Mapping[str, str] | None = ..., + fallback: _T = ..., + ) -> float | _T: ... @classmethod def register_module(cls: _T, module): ... def jsonifyMetadata(self) -> str: ... def jsonifyAll(self) -> str: ... + def to_dict(self, deduplicate: bool = ...) -> Dict[str, Any]: ... class ConfigurableMixin(abc.ABC): @classmethod diff --git a/aepsych/generators/base.py b/aepsych/generators/base.py index 77fa94c75..480425d09 100644 --- a/aepsych/generators/base.py +++ b/aepsych/generators/base.py @@ -4,19 +4,19 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import abc -from inspect import signature -from typing import Any, Dict, Generic, Protocol, runtime_checkable, TypeVar, Optional import re +from inspect import signature +from typing import Any, Dict, Generic, Optional, Protocol, runtime_checkable, TypeVar import torch from aepsych.config import Config from aepsych.models.base import AEPsychMixin from botorch.acquisition import ( AcquisitionFunction, - NoisyExpectedImprovement, - qNoisyExpectedImprovement, LogNoisyExpectedImprovement, + NoisyExpectedImprovement, qLogNoisyExpectedImprovement, + qNoisyExpectedImprovement, ) @@ -43,6 +43,9 @@ class AEPsychGenerator(abc.ABC, Generic[AEPsychModelType]): stimuli_per_trial = 1 max_asks: Optional[int] = None + acqf: AcquisitionFunction + acqf_kwargs: Dict[str, Any] + def __init__( self, ) -> None: @@ -81,7 +84,7 @@ def _get_acqf_options(cls, acqf: AcquisitionFunction, config: Config) -> Dict[st elif re.search( r"^\[.*\]$", v, flags=re.DOTALL ): # use regex to check if the value is a list - extra_acqf_args[k] = config._str_to_list(v) # type: ignore + extra_acqf_args[k] = config._str_to_list(v) # type: ignore else: # otherwise try a float try: diff --git a/aepsych/models/base.py b/aepsych/models/base.py index 0baee322b..feef49456 100644 --- a/aepsych/models/base.py +++ b/aepsych/models/base.py @@ -64,7 +64,7 @@ def bounds(self) -> torch.Tensor: def dim(self) -> int: pass - def posterior(self, x: torch.Tensor) -> GPyTorchPosterior: + def posterior(self, X: torch.Tensor) -> GPyTorchPosterior: pass def predict(self, x: torch.Tensor, **kwargs) -> torch.Tensor: @@ -103,7 +103,9 @@ def update( ) -> None: pass - def p_below_threshold(self, x, f_thresh) -> torch.Tensor: + def p_below_threshold( + self, x: torch.Tensor, f_thresh: torch.Tensor + ) -> torch.Tensor: pass @@ -374,11 +376,11 @@ def _fit_mll( ) return res - def p_below_threshold(self, x: torch.Tensor, f_thresh: torch.Tensor) -> torch.Tensor: + def p_below_threshold(self, x: torch.Tensor, f_thresh: torch.Tensor) -> torch.Tensor: f, var = self.predict(x) f_thresh = f_thresh.reshape(-1, 1) f = f.reshape(1, -1) var = var.reshape(1, -1) - + z = (f_thresh - f) / var.sqrt() - return torch.distributions.Normal(0, 1).cdf(z) # Use PyTorch's CDF equivalent \ No newline at end of file + return torch.distributions.Normal(0, 1).cdf(z) # Use PyTorch's CDF equivalent diff --git a/aepsych/plotting.py b/aepsych/plotting.py index 3fe9fc47a..73461dff3 100644 --- a/aepsych/plotting.py +++ b/aepsych/plotting.py @@ -12,6 +12,7 @@ from matplotlib.axes import Axes import numpy as np +import torch from aepsych.strategy import Strategy from aepsych.utils import get_lse_contour, get_lse_interval, make_scaled_sobol from scipy.stats import norm @@ -156,7 +157,7 @@ def _plot_strat_1d( assert x is not None and y is not None, "No data to plot!" if strat.model is not None: - grid = strat.model.dim_grid(gridsize=gridsize) + grid = strat.model.dim_grid(gridsize=gridsize).cpu() samps = norm.cdf(strat.model.sample(grid, num_samples=10000).detach()) phimean = samps.mean(0) else: @@ -178,10 +179,17 @@ def _plot_strat_1d( if target_level is not None: from aepsych.utils import interpolate_monotonic + lb = strat.transforms.untransform(strat.lb)[0] + ub = strat.transforms.untransform(strat.ub)[0] + threshold_samps = [ interpolate_monotonic( - grid, s, target_level, strat.lb[0], strat.ub[0] - ).cpu().numpy() + x=grid.squeeze(), + y=s, + z=target_level, + min_x=lb, + max_x=ub, + ) for s in samps ] thresh_med = np.mean(threshold_samps) @@ -201,13 +209,17 @@ def _plot_strat_1d( true_f = true_testfun(grid) ax.plot(grid, true_f.squeeze(), label="True function") if target_level is not None: - true_thresh = interpolate_monotonic( - grid, - true_f.squeeze(), - target_level, - strat.lb[0], - strat.ub[0], - ).cpu().numpy() + true_thresh = ( + interpolate_monotonic( + grid, + true_f.squeeze(), + target_level, + strat.lb[0], + strat.ub[0], + ) + .cpu() + .numpy() + ) ax.plot( true_thresh, @@ -266,25 +278,28 @@ def _plot_strat_2d( else: raise RuntimeError("Cannot plot without a model!") - extent = np.r_[strat.lb[0], strat.ub[0], strat.lb[1], strat.ub[1]] + lb = strat.transforms.untransform(strat.lb) + ub = strat.transforms.untransform(strat.ub) + + extent = np.r_[lb[0], ub[0], lb[1], ub[1]] colormap = ax.imshow( phimean, aspect="auto", origin="lower", extent=extent, alpha=0.5 ) if flipx: - extent = np.r_[strat.lb[0], strat.ub[0], strat.ub[1], strat.lb[1]] + extent = np.r_[lb[0], ub[0], ub[1], lb[1]] colormap = ax.imshow( phimean, aspect="auto", origin="upper", extent=extent, alpha=0.5 ) else: - extent = np.r_[strat.lb[0], strat.ub[0], strat.lb[1], strat.ub[1]] + extent = np.r_[lb[0], ub[0], lb[1], ub[1]] colormap = ax.imshow( phimean, aspect="auto", origin="lower", extent=extent, alpha=0.5 ) # hacky relabel to be in logspace if logx: - locs: np.ndarray = np.arange(strat.lb[0], strat.ub[0]) + locs: np.ndarray = np.arange(lb[0], ub[0]) ax.set_xticks(ticks=locs) ax.set_xticklabels(2.0**locs) @@ -292,8 +307,8 @@ def _plot_strat_2d( ax.plot(x[y == 1, 0], x[y == 1, 1], "bo", alpha=0.7, label=yes_label) if target_level is not None: # plot threshold - mono_grid = np.linspace(strat.lb[1], strat.ub[1], num=gridsize) - context_grid = np.linspace(strat.lb[0], strat.ub[0], num=gridsize) + mono_grid = np.linspace(lb[1], ub[1], num=gridsize) + context_grid = np.linspace(lb[0], ub[0], num=gridsize) thresh_75, lower, upper = get_lse_interval( model=strat.model, mono_grid=mono_grid, @@ -310,14 +325,19 @@ def _plot_strat_2d( label=f"Est. {target_level*100:.0f}% threshold \n(with {cred_level*100:.0f}% posterior \nmass shaded)", ) ax.fill_between( - context_grid, lower.cpu().numpy(), upper.cpu().numpy(), alpha=0.3, hatch="///", edgecolor="gray" + context_grid, + lower.cpu().numpy(), + upper.cpu().numpy(), + alpha=0.3, + hatch="///", + edgecolor="gray", ) if true_testfun is not None: true_f = true_testfun(grid).reshape(gridsize, gridsize) true_thresh = get_lse_contour( - true_f, mono_grid, level=target_level, lb=strat.lb[-1], ub=strat.ub[-1] - ).cpu().numpy() + true_f, mono_grid, level=target_level, lb=lb[-1], ub=ub[-1] + ) ax.plot(context_grid, true_thresh, label="Ground truth threshold") ax.set_xlabel(xlabel) @@ -379,7 +399,7 @@ def plot_strat_3d( if not isinstance(contour_levels_list, Sized): raise TypeError("contour_levels_list must be Sized (e.g., a list or an array).") - + # slice_vals is either a list of values or an integer number of values to slice on if isinstance(slice_vals, int): slices = np.linspace(strat.lb[slice_dim], strat.ub[slice_dim], slice_vals) @@ -388,14 +408,13 @@ def plot_strat_3d( raise TypeError("slice_vals must be either an integer or a list of values") else: slices = np.array(slice_vals) - - # make mypy happy, note that this can't be more specific + + # make mypy happy, note that this can't be more specific # because of https://github.com/numpy/numpy/issues/24738 - axs: np.ndarray + axs: np.ndarray[Any, Any] _, axs = plt.subplots(1, len(slices), constrained_layout=True, figsize=(20, 3)) # type: ignore assert len(slices) > 1, "Must have at least 2 slices" - for _i, dim_val in enumerate(slices): img = plot_slice( axs[_i], diff --git a/aepsych/strategy.py b/aepsych/strategy.py index 81d5c071c..f3dc8f62b 100644 --- a/aepsych/strategy.py +++ b/aepsych/strategy.py @@ -22,6 +22,7 @@ from aepsych.utils import _process_bounds, make_scaled_sobol from aepsych.utils_logging import getLogger from botorch.exceptions.errors import ModelFittingError +from botorch.models.transforms.input import ChainedInputTransform logger = getLogger() @@ -56,7 +57,7 @@ def __init__( lb: Union[np.ndarray, torch.Tensor], ub: Union[np.ndarray, torch.Tensor], stimuli_per_trial: int, - outcome_types: Sequence[Type[str]], + outcome_types: List[str], dim: Optional[int] = None, min_total_tells: int = 0, min_asks: int = 0, @@ -68,6 +69,7 @@ def __init__( min_post_range: Optional[float] = None, name: str = "", run_indefinitely: bool = False, + transforms: ChainedInputTransform = ChainedInputTransform(**{}), ) -> None: """Initialize the strategy object. @@ -93,6 +95,11 @@ def __init__( name (str): The name of the strategy. Defaults to the empty string. run_indefinitely (bool): If true, the strategy will run indefinitely until finish() is explicitly called. Other stopping criteria will be ignored. Defaults to False. + transforms (ReversibleInputTransform, optional): Transforms + to apply parameters. This is immediately applied to lb/ub, thus lb/ub + should be defined in raw parameter space for initialization. However, + if the lb/ub attribute are access from an initialized Strategy object, + it will be returned in transformed space. """ self.is_finished = False @@ -129,6 +136,11 @@ def __init__( self.max_asks = max_asks or generator.max_asks self.keep_most_recent = keep_most_recent + self.transforms = transforms + if self.transforms is not None: + self.lb = self.transforms.transform(self.lb) + self.ub = self.transforms.transform(self.ub) + self.min_post_range = min_post_range if self.min_post_range is not None: assert model is not None, "min_post_range must be None if model is None!" @@ -136,6 +148,12 @@ def __init__( lb=self.lb, ub=self.ub, size=self._n_eval_points ) + # this grid needs to be in untransformed space because it goes through a + # transform wrapped model + if self.transforms is not None: + self.eval_grid = self.transforms.untransform(self.eval_grid) + + # similar to ub/lb/grid, x is in raw parameter space self.x: Optional[torch.Tensor] = None self.y: Optional[torch.Tensor] = None self.n: int = 0 @@ -339,7 +357,7 @@ def fit(self) -> None: if self.can_fit: if self.keep_most_recent is not None: try: - + self.model.fit( # type: ignore self.x[-self.keep_most_recent :], # type: ignore self.y[-self.keep_most_recent :], # type: ignore @@ -359,10 +377,10 @@ def fit(self) -> None: warnings.warn("Cannot fit: no model has been initialized!", RuntimeWarning) def update(self) -> None: - + if self.can_fit: if self.keep_most_recent is not None: - try: + try: self.model.update( # type: ignore self.x[-self.keep_most_recent :], # type: ignore self.y[-self.keep_most_recent :], # type: ignore @@ -387,17 +405,14 @@ def from_config(cls, config: Config, name: str) -> Strategy: ub = config.gettensor(name, "ub") dim = config.getint(name, "dim", fallback=None) + transforms = ParameterTransforms.from_config(config) + stimuli_per_trial = config.getint(name, "stimuli_per_trial", fallback=1) outcome_types = config.getlist(name, "outcome_types", element_type=str) - gen_cls = config.getobj(name, "generator", fallback=SobolGenerator) - generator = gen_cls.from_config(config) + generator = GeneratorWrapper.from_config(name, config) - model_cls = config.getobj(name, "model", fallback=None) - if model_cls is not None: - model = model_cls.from_config(config) - else: - model = None + model = ModelWrapper.from_config(name, config) acqf_cls = config.getobj(name, "acqf", fallback=None) if acqf_cls is not None and hasattr(generator, "acqf"): @@ -440,6 +455,7 @@ def from_config(cls, config: Config, name: str) -> Strategy: stimuli_per_trial=stimuli_per_trial, outcome_types=outcome_types, dim=dim, + transforms=transforms, model=model, generator=generator, min_asks=min_asks, diff --git a/aepsych/transforms/__init__.py b/aepsych/transforms/__init__.py new file mode 100644 index 000000000..3ce58f7e8 --- /dev/null +++ b/aepsych/transforms/__init__.py @@ -0,0 +1,20 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta, Inc. and its affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .parameters import ( + GeneratorWrapper, + ModelWrapper, + ParameterTransforms, + transform_options, +) + +__all__ = [ + "GeneratorWrapper", + "ModelWrapper", + "ParameterTransforms", + "transform_options", +] diff --git a/aepsych/transforms/parameters.py b/aepsych/transforms/parameters.py new file mode 100644 index 000000000..685bb919d --- /dev/null +++ b/aepsych/transforms/parameters.py @@ -0,0 +1,390 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta, Inc. and its affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +import ast +from abc import ABC, abstractmethod +from copy import deepcopy +from typing import Any, List, Literal, Optional, Type + +import numpy as np +import torch +from aepsych.config import Config +from aepsych.generators import SobolGenerator +from aepsych.generators.base import AEPsychGenerator +from aepsych.models.base import AEPsychMixin, ModelProtocol +from botorch.acquisition import AcquisitionFunction +from botorch.models.transforms.input import ChainedInputTransform, Log10 +from botorch.models.transforms.utils import subset_transform +from botorch.posteriors import Posterior +from torch import Tensor + +_TRANSFORMABLE = [ + "lb", + "ub", + "points", + "window", +] + + +class ParameterTransforms(ChainedInputTransform): + """ + Holds set of transformations to be applied to parameters. The ParameterTransform + objects can be used by themselves to transform values or can be passed to Generator + or Model wrappers to consistently transform parameters. ParameterTransforms can + transform values into transformed space and also untransform values from transformed + space back into raw space. + """ + + @classmethod + def from_config(cls, config: Config): + parnames: List[str] = config.getlist("common", "parnames", element_type=str) + transformDict = {} + for i, par in enumerate(parnames): + # This is the order that transforms are potentially applied, order matters + + # Log scale + if config.getboolean(par, "log_scale", fallback=False): + lb = config.getfloat(par, "lower_bound") + if lb < 0.0: + transformDict[f"{par}_Log10Plus"] = Log10Plus( + indices=[i], constant=np.abs(lb) + 1.0 + ) + elif lb < 1.0: + transformDict[f"{par}_Log10Plus"] = Log10Plus( + indices=[i], constant=1.0 + ) + else: + transformDict[f"{par}_Log10"] = Log10(indices=[i]) + + return cls(**transformDict) + + +class ParameterTransformWrapper(ABC): + """ + Abstract base class for parameter transform wrappers. __getattr__ is overridden to + allow base object attributes to be surfaced smoothly. Methods that require the + transforms should be overridden in the wrapper class to apply the transform + operations. + """ + + transforms: ChainedInputTransform + _base_obj: object = None + + def __getattr__(self, name): + return getattr(self._base_obj, name) + + @classmethod + @abstractmethod + def from_config(cls, name: str, config: Config): + pass + + +class GeneratorWrapper(ParameterTransformWrapper): + _base_obj: AEPsychGenerator + + def __init__( + self, + generator: Type | AEPsychGenerator, + transforms: ChainedInputTransform = ChainedInputTransform(**{}), + **kwargs, + ) -> None: + f""" + Wraps a Generator with parameter transforms. This will transform any relevant + generator arguments (e.g., bounds) to be transformed into the transformed space + and ensure all generator outputs to be untransformed into raw space. The wrapper + surfaces critical components of the API of the generator such that the wrapper + can be used much like the raw generator. + + Bounds are returned in the transformed space, this is necessary to handle + parameters that would not have sensible raw parameter space. If bounds are + manually set (e.g., `Wrapper(**kwargs).lb = lb)`, ensure that they are + correctly transformed and in a correctly shaped Tensor. If the bounds are + being set in init (e.g., `Wrapper(Type, lb=lb, ub=ub)`, `lb` and `ub` + should be in the raw parameter space. + + Args: + model (Type | AEPsychGenerator): Generator to wrap, this could either be a + completely initialized generator or just the generator class. An + initialized generator is expected to have been initialized in the + transformed parameter space (i.e., bounds are transformed). If a + generator class is passed, **kwargs will be used to initialize the + generator, note that the bounds are expected to be in raw parameter + space, thus the transforms are applied to it. + transforms (ChainedInputTransform, optional): A set of transforms to apply + to parameters of this generator. If no transforms are passed, it will + default to an identity transform. + """ + # Figure out what we need to do with generator + if isinstance(generator, type): + if "lb" in kwargs: + kwargs["lb"] = transforms.transform(kwargs["lb"].float()) + if "ub" in kwargs: + kwargs["ub"] = transforms.transform(kwargs["ub"].float()) + _base_obj = generator(**kwargs) + else: + _base_obj = generator + + self._base_obj = _base_obj + self.transforms = transforms + + # This lets us emit we're the class we're wrapping + self.__class__ = type( + f"ParameterTransformed{_base_obj.__class__.__name__}", + (self.__class__, _base_obj.__class__), + {}, + ) + + def gen(self, num_points: int, model: Optional[AEPsychMixin] = None) -> Tensor: + x = self._base_obj.gen(num_points, model) + return self.transforms.untransform(x) + + @property + def acqf(self) -> AcquisitionFunction | None: + return self._base_obj.acqf + + @acqf.setter + def acqf(self, value: AcquisitionFunction): + self._base_obj.acqf = value + + @property + def acqf_kwargs(self) -> dict | None: + return self._base_obj.acqf_kwargs + + @acqf_kwargs.setter + def acqf_kwargs(self, value: dict): + self._base_obj.acqf_kwargs = value + + @classmethod + def from_config( + cls, + name: str, + config: Config, + ): + gen_cls = config.getobj(name, "generator", fallback=SobolGenerator) + transforms = ParameterTransforms.from_config(config) + + # We need transformed values from config but we don't want to edit config + transformed_config = transform_options(config) + + gen = gen_cls.from_config(transformed_config) + + return cls(gen, transforms) + + def _get_acqf_options(self, acqf: AcquisitionFunction, config: Config): + return self._base_obj._get_acqf_options(acqf, config) + + +class ModelWrapper(ParameterTransformWrapper): + _base_obj: ModelProtocol + + def __init__( + self, + model: Type | ModelProtocol, + transforms: ChainedInputTransform = ChainedInputTransform(**{}), + **kwargs, + ) -> None: + f""" + Wraps a Model with parameter transforms. This will transform any relevant + model arguments (e.g., bounds) and model data (e.g., training data, x) to be + transformed into the transformed space. The wrapper surfaces the API of the + raw model such that the wrapper can be used like a raw model. + + Bounds are returned in the transformed space, this is necessary to handle + parameters that would not have sensible raw parameter space. If bounds are + manually set (e.g., `Wrapper(**kwargs).lb = lb)`, ensure that they are + correctly transformed and in a correctly shaped Tensor. If the bounds are + being set in init (e.g., `Wrapper(Type, lb=lb, ub=ub)`, `lb` and `ub` + should be in the raw parameter space. + + Args: + model (Type | ModelProtocol): Model to wrap, this could either be a + completely initialized model or just the model class. An initialized + model is expected to have been initialized in the transformed + parameter space (i.e., bounds are transformed). If a model class is + passed, **kwargs will be used to initialize the model. Note that the + bounds in this case are expected to be in raw parameter space, thus the + transforms are applied to it. + transforms (ChainedInputTransform, optional): A set of transforms to apply + to parameters of this model. If no transforms are passed, it will + default to an identity transform. + """ + # Alternative instantiation method for analysis (and not live) + if isinstance(model, type): + if "lb" in kwargs: + kwargs["lb"] = transforms.transform(kwargs["lb"].float()) + if "ub" in kwargs: + kwargs["ub"] = transforms.transform(kwargs["ub"].float()) + _base_obj = model(**kwargs) + else: + _base_obj = model + + self._base_obj = _base_obj + self.transforms = transforms + + # This lets us emit we're the class we're wrapping + self.__class__ = type( + f"ParameterTransformed{_base_obj.__class__.__name__}", + (self.__class__, _base_obj.__class__), + {}, + ) + + def predict(self, x: Tensor, **kwargs) -> Tensor: + if len(x.shape) == 1: + x = x.unsqueeze(-1) + x = self.transforms.transform(x) + return self._base_obj.predict(x, **kwargs) + + def predict_probability(self, x: Tensor, **kwargs) -> Tensor: + if len(x.shape) == 1: + x = x.unsqueeze(-1) + x = self.transforms.transform(x) + return self._base_obj.predict_probability(x, **kwargs) + + def sample(self, x: Tensor, num_samples: int) -> Tensor: + if len(x.shape) == 1: + x = x.unsqueeze(-1) + x = self.transforms.transform(x) + return self._base_obj.sample(x, num_samples) + + def dim_grid(self, gridsize: int = 30) -> Tensor: + grid = self._base_obj.dim_grid(gridsize) + return self.transforms.untransform(grid) + + def posterior(self, X: Tensor, **kwargs) -> Posterior: + # This ensures X is a tensor with the right shape + X = Tensor(X) + return self._base_obj.posterior(X=X, **kwargs) + + def fit(self, train_x: Tensor, train_y: Tensor, **kwargs: Any) -> None: + if len(train_x.shape) == 1: + train_x = train_x.unsqueeze(-1) + train_x = self.transforms.transform(train_x) + self._base_obj.fit(train_x, train_y, **kwargs) + + def update(self, train_x: Tensor, train_y: Tensor, **kwargs: Any) -> None: + if len(train_x.shape) == 1: + train_x = train_x.unsqueeze(-1) + train_x = self.transforms.transform(train_x) + self._base_obj.update(train_x, train_y, **kwargs) + + def p_below_threshold(self, x: Tensor, f_thresh: torch.Tensor) -> torch.Tensor: + if len(x.shape) == 1: + x = x.unsqueeze(-1) + x = self.transforms.transform(x) + return self._base_obj.p_below_threshold(x, f_thresh) + + @classmethod + def from_config( + cls, + name: str, + config: Config, + ): + # We don't always have models + model_cls = config.getobj(name, "model", fallback=None) + if model_cls is None: + return None + + transforms = ParameterTransforms.from_config(config) + + # Need transformed values + transformed_config = transform_options(config) + + model = model_cls.from_config(transformed_config) + + return cls(model, transforms) + + +def transform_options(config: Config) -> Config: + """ + Return a copy of the config with the options transformed. The config + """ + transforms = ParameterTransforms.from_config(config) + + configClone = deepcopy(config) + + # Can't use self.sections() to avoid default section behavior + for section, options in config.to_dict().items(): + for option, value in options.items(): + if option in _TRANSFORMABLE: + value = ast.literal_eval(value) + value = np.array(value, dtype=float) + value = torch.tensor(value).to(torch.float64) + + value = transforms.transform(value) + + def _arr_to_list(iter): + if hasattr(iter, "__iter__"): + iter = list(iter) + iter = [_arr_to_list(element) for element in iter] + return iter + return iter + + # Recursively turn back into str + configClone[section][option] = str(_arr_to_list(value.numpy())) + + return configClone + + +class Log10Plus(Log10): + r"""Base-10 log transform that we add a constant to the values""" + + def __init__( + self, + indices: list[int], + constant: float = 1.0, + transform_on_train: bool = True, + transform_on_eval: bool = True, + transform_on_fantasize: bool = True, + reverse: bool = False, + ) -> None: + r"""Initalize transform + + Args: + indices: The indices of the inputs to log transform. + constant: The constant to add to inputs before log transforming. Default: 1.0 + transform_on_train: A boolean indicating whether to apply the + transforms in train() mode. Default: True. + transform_on_eval: A boolean indicating whether to apply the + transform in eval() mode. Default: True. + transform_on_fantasize: A boolean indicating whether to apply the + transform when called from within a `fantasize` call. Default: True. + reverse: A boolean indicating whether the forward pass should untransform + the inputs. + """ + super().__init__( + indices=indices, + transform_on_train=transform_on_train, + transform_on_eval=transform_on_eval, + transform_on_fantasize=transform_on_fantasize, + reverse=reverse, + ) + self.register_buffer("constant", torch.tensor(constant, dtype=torch.long)) + + @subset_transform + def _transform(self, X: Tensor) -> Tensor: + r"""Add the constant then log transform the inputs. + + Args: + X: A `batch_shape x n x d`-dim tensor of inputs. + + Returns: + A `batch_shape x n x d`-dim tensor of transformed inputs. + """ + X = X + (torch.ones_like(X) * self.constant) + return X.log10() + + @subset_transform + def _untransform(self, X: Tensor) -> Tensor: + r"""Reverse the log transformation then subtract the constant. + + Args: + X: A `batch_shape x n x d`-dim tensor of transformed inputs. + + Returns: + A `batch_shape x n x d`-dim tensor of untransformed inputs. + """ + X = 10.0**X + return X - (torch.ones_like(X) * self.constant) diff --git a/aepsych/utils.py b/aepsych/utils.py index 5bce10ce4..420240fd7 100644 --- a/aepsych/utils.py +++ b/aepsych/utils.py @@ -143,6 +143,9 @@ def get_lse_interval( dim=-1 ).reshape(-1, model.dim) + if model.transforms is not None: + xgrid = model.transforms.untransform(xgrid) + samps = model.sample(xgrid, num_samples=n_samps, **kwargs) samps = [s.reshape((gridsize,) * model.dim) for s in samps] diff --git a/configs/parameter_settings_example.ini b/configs/parameter_settings_example.ini index 162126f6e..040060f94 100644 --- a/configs/parameter_settings_example.ini +++ b/configs/parameter_settings_example.ini @@ -6,9 +6,10 @@ target = 0.75 strategy_names = [init_strat, opt_strat] [contPar] -par_type = continuous -lower_bound = 0 -upper_bound = 1 +par_type = continuous # we only support continuous right now +lower_bound = 0 # lower bound for this parameter in raw parameter space +upper_bound = 1 # upper bound for this parameter in raw parameter space +log_scale = True # this parameter will be transformed to log-scale space for the model # Strategy blocks below [init_strat] diff --git a/docs/parameters.md b/docs/parameters.md new file mode 100644 index 000000000..4aa0157d1 --- /dev/null +++ b/docs/parameters.md @@ -0,0 +1,60 @@ +--- +id: parameters +title: Advanced Parameter Configuration +--- + +This page provides an overview of additional controls for parameters, including +parameter transformations. Generally, parameters should be defined in the natural raw +parameter space and AEPsych will handle transforming the parameters into a form usable +by the models. This means that the server will always suggest parameters in response to +an `ask` in raw parameter space and you should always `tell` the server the +results of a trial also in the raw parameter space. This remains true no matter +what parameter types are used and whatever transformations are used. + +