From 010af92af6ded063b7f3615be83bdf63209df151 Mon Sep 17 00:00:00 2001 From: Jason Chow Date: Thu, 31 Oct 2024 22:25:51 -0700 Subject: [PATCH] pass transforms around instead of making duplicates (#416) Summary: Pull Request resolved: https://github.com/facebookresearch/aepsych/pull/416 Instead of creating duplicate transforms whenever we need one, we create a single transform from the config and initialize the wrapped model and wrapped generators with that one transform. This passes the same transform object around and allows the transformations to learn parameters and still be synced up across wrapped objects. Differential Revision: D65155103 --- aepsych/strategy.py | 4 ++-- aepsych/transforms/parameters.py | 19 +++++++++++++------ tests/test_transforms.py | 8 ++++++++ 3 files changed, 23 insertions(+), 8 deletions(-) diff --git a/aepsych/strategy.py b/aepsych/strategy.py index 5e9495ae3..7c9bcab26 100644 --- a/aepsych/strategy.py +++ b/aepsych/strategy.py @@ -410,9 +410,9 @@ def from_config(cls, config: Config, name: str) -> Strategy: stimuli_per_trial = config.getint(name, "stimuli_per_trial", fallback=1) outcome_types = config.getlist(name, "outcome_types", element_type=str) - generator = GeneratorWrapper.from_config(name, config) + generator = GeneratorWrapper.from_config(name, config, transforms) - model = ModelWrapper.from_config(name, config) + model = ModelWrapper.from_config(name, config, transforms) acqf_cls = config.getobj(name, "acqf", fallback=None) if acqf_cls is not None and hasattr(generator, "acqf"): diff --git a/aepsych/transforms/parameters.py b/aepsych/transforms/parameters.py index b2a4f9bf6..59e1680d7 100644 --- a/aepsych/transforms/parameters.py +++ b/aepsych/transforms/parameters.py @@ -249,6 +249,7 @@ def from_config( cls, name: str, config: Config, + transforms: Optional[ChainedInputTransform] = None, ) -> "GeneratorWrapper": """Returns a generator wrapped by GeneratorWrapper that will transform its inputs and outputs automatically with the same API as the underlying generator. @@ -262,10 +263,11 @@ def from_config( class's name will be ParameterWrapped. """ gen_cls = config.getobj(name, "generator", fallback=SobolGenerator) - transforms = ParameterTransforms.from_config(config) + if transforms is None: + transforms = ParameterTransforms.from_config(config) # We need transformed values from config but we don't want to edit config - transformed_config = transform_options(config) + transformed_config = transform_options(config, transforms) gen = gen_cls.from_config(transformed_config) @@ -463,6 +465,7 @@ def from_config( cls, name: str, config: Config, + transforms: Optional[ChainedInputTransform] = None, ) -> "ModelWrapper": """Returns a model wrapped by ModelWrapper that will transform its inputs and outputs automatically with the same API as the underlying model. @@ -480,17 +483,20 @@ def from_config( if model_cls is None: return None - transforms = ParameterTransforms.from_config(config) + if transforms is None: + transforms = ParameterTransforms.from_config(config) # Need transformed values - transformed_config = transform_options(config) + transformed_config = transform_options(config, transforms) model = model_cls.from_config(transformed_config) return cls(model, transforms) -def transform_options(config: Config) -> Config: +def transform_options( + config: Config, transforms: Optional[ChainedInputTransform] = None +) -> Config: """ Return a copy of the config with the options transformed. @@ -500,7 +506,8 @@ def transform_options(config: Config) -> Config: Returns: Config: A copy of the original config with relevant options transformed. """ - transforms = ParameterTransforms.from_config(config) + if transforms is None: + transforms = ParameterTransforms.from_config(config) configClone = deepcopy(config) diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 3514ce62f..134190117 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -100,6 +100,14 @@ def test_model_init_equivalent(self): def test_transforms_in_strategy(self): for _strat in self.strat.strat_list: + # Check if the same transform is passed around everywhere + self.assertTrue(id(_strat.transforms) == id(_strat.generator.transforms)) + if _strat.model is not None: + self.assertTrue( + id(_strat.generator.transforms) == id(_strat.model.transforms) + ) + + # Check all the transform bits are the same for strat_transform, gen_transform in zip( _strat.transforms.items(), _strat.generator.transforms.items() ):