From fb39cf392269eb8d628dc4187b48a5b5e89a849f Mon Sep 17 00:00:00 2001 From: Jason Chow Date: Wed, 30 Oct 2024 15:16:04 -0700 Subject: [PATCH] pass transforms around instead of making duplicates (#416) Summary: Instead of creating duplicate transforms whenever we need one, we create a single transform from the config and initialize the wrapped model and wrapped generators with that one transform. This passes the same transform object around and allows the transformations to learn parameters and still be synced up across wrapped objects. Differential Revision: D65155103 --- aepsych/strategy.py | 4 ++-- aepsych/transforms/parameters.py | 19 +++++++++++++------ tests/test_transforms.py | 8 ++++++++ 3 files changed, 23 insertions(+), 8 deletions(-) diff --git a/aepsych/strategy.py b/aepsych/strategy.py index f3dc8f62b..d8f462d3f 100644 --- a/aepsych/strategy.py +++ b/aepsych/strategy.py @@ -410,9 +410,9 @@ def from_config(cls, config: Config, name: str) -> Strategy: stimuli_per_trial = config.getint(name, "stimuli_per_trial", fallback=1) outcome_types = config.getlist(name, "outcome_types", element_type=str) - generator = GeneratorWrapper.from_config(name, config) + generator = GeneratorWrapper.from_config(name, config, transforms) - model = ModelWrapper.from_config(name, config) + model = ModelWrapper.from_config(name, config, transforms) acqf_cls = config.getobj(name, "acqf", fallback=None) if acqf_cls is not None and hasattr(generator, "acqf"): diff --git a/aepsych/transforms/parameters.py b/aepsych/transforms/parameters.py index 685bb919d..e55aff07b 100644 --- a/aepsych/transforms/parameters.py +++ b/aepsych/transforms/parameters.py @@ -162,12 +162,14 @@ def from_config( cls, name: str, config: Config, + transforms: Optional[ChainedInputTransform] = None, ): gen_cls = config.getobj(name, "generator", fallback=SobolGenerator) - transforms = ParameterTransforms.from_config(config) + if transforms is None: + transforms = ParameterTransforms.from_config(config) # We need transformed values from config but we don't want to edit config - transformed_config = transform_options(config) + transformed_config = transform_options(config, transforms) gen = gen_cls.from_config(transformed_config) @@ -281,27 +283,32 @@ def from_config( cls, name: str, config: Config, + transforms: Optional[ChainedInputTransform] = None, ): # We don't always have models model_cls = config.getobj(name, "model", fallback=None) if model_cls is None: return None - transforms = ParameterTransforms.from_config(config) + if transforms is None: + transforms = ParameterTransforms.from_config(config) # Need transformed values - transformed_config = transform_options(config) + transformed_config = transform_options(config, transforms) model = model_cls.from_config(transformed_config) return cls(model, transforms) -def transform_options(config: Config) -> Config: +def transform_options( + config: Config, transforms: Optional[ChainedInputTransform] = None +) -> Config: """ Return a copy of the config with the options transformed. The config """ - transforms = ParameterTransforms.from_config(config) + if transforms is None: + transforms = ParameterTransforms.from_config(config) configClone = deepcopy(config) diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 3514ce62f..134190117 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -100,6 +100,14 @@ def test_model_init_equivalent(self): def test_transforms_in_strategy(self): for _strat in self.strat.strat_list: + # Check if the same transform is passed around everywhere + self.assertTrue(id(_strat.transforms) == id(_strat.generator.transforms)) + if _strat.model is not None: + self.assertTrue( + id(_strat.generator.transforms) == id(_strat.model.transforms) + ) + + # Check all the transform bits are the same for strat_transform, gen_transform in zip( _strat.transforms.items(), _strat.generator.transforms.items() ):