From 8f1bdda3e9185931a2ed5f5fcf53cc834acddb3a Mon Sep 17 00:00:00 2001 From: Jason Chow Date: Fri, 1 Nov 2024 16:04:07 -0700 Subject: [PATCH] pass transforms around instead of making duplicates (#416) Summary: Pull Request resolved: https://github.com/facebookresearch/aepsych/pull/416 Instead of creating duplicate transforms whenever we need one, we create a single transform from the config and initialize the wrapped model and wrapped generators with that one transform. This passes the same transform object around and allows the transformations to learn parameters and still be synced up across wrapped objects. Differential Revision: D65155103 --- aepsych/strategy.py | 4 ++-- aepsych/transforms/parameters.py | 29 +++++++++++++++++++++-------- tests/test_transforms.py | 8 ++++++++ 3 files changed, 31 insertions(+), 10 deletions(-) diff --git a/aepsych/strategy.py b/aepsych/strategy.py index 5e9495ae3..7c9bcab26 100644 --- a/aepsych/strategy.py +++ b/aepsych/strategy.py @@ -410,9 +410,9 @@ def from_config(cls, config: Config, name: str) -> Strategy: stimuli_per_trial = config.getint(name, "stimuli_per_trial", fallback=1) outcome_types = config.getlist(name, "outcome_types", element_type=str) - generator = GeneratorWrapper.from_config(name, config) + generator = GeneratorWrapper.from_config(name, config, transforms) - model = ModelWrapper.from_config(name, config) + model = ModelWrapper.from_config(name, config, transforms) acqf_cls = config.getobj(name, "acqf", fallback=None) if acqf_cls is not None and hasattr(generator, "acqf"): diff --git a/aepsych/transforms/parameters.py b/aepsych/transforms/parameters.py index 15f349ed8..4e9b86967 100644 --- a/aepsych/transforms/parameters.py +++ b/aepsych/transforms/parameters.py @@ -186,7 +186,7 @@ def __init__( transforms (ChainedInputTransform, optional): A set of transforms to apply to parameters of this generator. If no transforms are passed, it will default to an identity transform. - **kwargs: Keyword arguments to pass to the model to initialize it if model + **kwargs: Keyword arguments to pass to the model to initialize it if model is a class. """ # Figure out what we need to do with generator @@ -251,6 +251,7 @@ def from_config( cls, name: str, config: Config, + transforms: Optional[ChainedInputTransform] = None, ) -> Any: """Returns a generator wrapped by GeneratorWrapper that will transform its inputs and outputs automatically with the same API as the underlying generator. @@ -258,16 +259,20 @@ def from_config( Args: name (str): Name of the generator class to look for in the config. config (Config): The config to build the Generator from. + transforms (ParameterTransforms, optional): Parameter transforms to wrap the + generator in. If not set, a new ParameterTransforms instance will be made + from the config. Returns: GeneratorWrapper: A configured Generator wrapped based on the config. The class's name will be ParameterWrapped. """ gen_cls = config.getobj(name, "generator", fallback=SobolGenerator) - transforms = ParameterTransforms.from_config(config) + if transforms is None: + transforms = ParameterTransforms.from_config(config) # We need transformed values from config but we don't want to edit config - transformed_config = transform_options(config) + transformed_config = transform_options(config, transforms) gen = gen_cls.from_config(transformed_config) @@ -315,7 +320,7 @@ def __init__( transforms (ChainedInputTransform, optional): A set of transforms to apply to parameters of this model. If no transforms are passed, it will default to an identity transform. - **kwargs: Keyword arguments to be passed to the model if the model is a + **kwargs: Keyword arguments to be passed to the model if the model is a class. """ # Alternative instantiation method for analysis (and not live) @@ -465,6 +470,7 @@ def from_config( cls, name: str, config: Config, + transforms: Optional[ParameterTransforms] = None, ) -> Any: """Returns a model wrapped by ModelWrapper that will transform its inputs and outputs automatically with the same API as the underlying model. @@ -472,6 +478,9 @@ def from_config( Args: name (str): Name of the Model class to look for in the config. config (Config): The config to build the Model from. + transforms (ParameterTransforms, optional): Parameter transforms to wrap the + model in. If not set, a new ParameterTransforms instance will be made + from the config. Returns: ModelWrapper: A configured Model wrapped based on the config. The @@ -482,17 +491,20 @@ def from_config( if model_cls is None: return None - transforms = ParameterTransforms.from_config(config) + if transforms is None: + transforms = ParameterTransforms.from_config(config) # Need transformed values - transformed_config = transform_options(config) + transformed_config = transform_options(config, transforms) model = model_cls.from_config(transformed_config) return cls(model, transforms) -def transform_options(config: Config) -> Config: +def transform_options( + config: Config, transforms: Optional[ChainedInputTransform] = None +) -> Config: """ Return a copy of the config with the options transformed. @@ -502,7 +514,8 @@ def transform_options(config: Config) -> Config: Returns: Config: A copy of the original config with relevant options transformed. """ - transforms = ParameterTransforms.from_config(config) + if transforms is None: + transforms = ParameterTransforms.from_config(config) configClone = deepcopy(config) diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 6d7326b01..40caf71a2 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -101,6 +101,14 @@ def test_model_init_equivalent(self): def test_transforms_in_strategy(self): for _strat in self.strat.strat_list: + # Check if the same transform is passed around everywhere + self.assertTrue(id(_strat.transforms) == id(_strat.generator.transforms)) + if _strat.model is not None: + self.assertTrue( + id(_strat.generator.transforms) == id(_strat.model.transforms) + ) + + # Check all the transform bits are the same for strat_transform, gen_transform in zip( _strat.transforms.items(), _strat.generator.transforms.items() ):