From dd322b31a712461070a07c79fb169a78f0d099bd Mon Sep 17 00:00:00 2001 From: Jason Chow Date: Fri, 1 Nov 2024 13:55:57 -0700 Subject: [PATCH] pass transforms around instead of making duplicates (#416) Summary: Pull Request resolved: https://github.com/facebookresearch/aepsych/pull/416 Instead of creating duplicate transforms whenever we need one, we create a single transform from the config and initialize the wrapped model and wrapped generators with that one transform. This passes the same transform object around and allows the transformations to learn parameters and still be synced up across wrapped objects. Differential Revision: D65155103 --- aepsych/strategy.py | 4 ++-- aepsych/transforms/parameters.py | 23 +++++++++++++++-------- tests/test_transforms.py | 8 ++++++++ 3 files changed, 25 insertions(+), 10 deletions(-) diff --git a/aepsych/strategy.py b/aepsych/strategy.py index 5e9495ae3..7c9bcab26 100644 --- a/aepsych/strategy.py +++ b/aepsych/strategy.py @@ -410,9 +410,9 @@ def from_config(cls, config: Config, name: str) -> Strategy: stimuli_per_trial = config.getint(name, "stimuli_per_trial", fallback=1) outcome_types = config.getlist(name, "outcome_types", element_type=str) - generator = GeneratorWrapper.from_config(name, config) + generator = GeneratorWrapper.from_config(name, config, transforms) - model = ModelWrapper.from_config(name, config) + model = ModelWrapper.from_config(name, config, transforms) acqf_cls = config.getobj(name, "acqf", fallback=None) if acqf_cls is not None and hasattr(generator, "acqf"): diff --git a/aepsych/transforms/parameters.py b/aepsych/transforms/parameters.py index 91af1fb92..e144a36e3 100644 --- a/aepsych/transforms/parameters.py +++ b/aepsych/transforms/parameters.py @@ -251,7 +251,8 @@ def from_config( cls, name: str, config: Config, - ) -> "GeneratorWrapper": + transforms: Optional[ChainedInputTransform] = None, + ) -> Any: """Returns a generator wrapped by GeneratorWrapper that will transform its inputs and outputs automatically with the same API as the underlying generator. @@ -264,10 +265,11 @@ def from_config( class's name will be ParameterWrapped. """ gen_cls = config.getobj(name, "generator", fallback=SobolGenerator) - transforms = ParameterTransforms.from_config(config) + if transforms is None: + transforms = ParameterTransforms.from_config(config) # We need transformed values from config but we don't want to edit config - transformed_config = transform_options(config) + transformed_config = transform_options(config, transforms) gen = gen_cls.from_config(transformed_config) @@ -465,7 +467,8 @@ def from_config( cls, name: str, config: Config, - ) -> "ModelWrapper": + transforms: Optional[ChainedInputTransform] = None, + ) -> Optional[Any]: """Returns a model wrapped by ModelWrapper that will transform its inputs and outputs automatically with the same API as the underlying model. @@ -482,17 +485,20 @@ def from_config( if model_cls is None: return None - transforms = ParameterTransforms.from_config(config) + if transforms is None: + transforms = ParameterTransforms.from_config(config) # Need transformed values - transformed_config = transform_options(config) + transformed_config = transform_options(config, transforms) model = model_cls.from_config(transformed_config) return cls(model, transforms) -def transform_options(config: Config) -> Config: +def transform_options( + config: Config, transforms: Optional[ChainedInputTransform] = None +) -> Config: """ Return a copy of the config with the options transformed. @@ -502,7 +508,8 @@ def transform_options(config: Config) -> Config: Returns: Config: A copy of the original config with relevant options transformed. """ - transforms = ParameterTransforms.from_config(config) + if transforms is None: + transforms = ParameterTransforms.from_config(config) configClone = deepcopy(config) diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 6d7326b01..40caf71a2 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -101,6 +101,14 @@ def test_model_init_equivalent(self): def test_transforms_in_strategy(self): for _strat in self.strat.strat_list: + # Check if the same transform is passed around everywhere + self.assertTrue(id(_strat.transforms) == id(_strat.generator.transforms)) + if _strat.model is not None: + self.assertTrue( + id(_strat.generator.transforms) == id(_strat.model.transforms) + ) + + # Check all the transform bits are the same for strat_transform, gen_transform in zip( _strat.transforms.items(), _strat.generator.transforms.items() ):