diff --git a/aepsych/strategy.py b/aepsych/strategy.py
index 5e9495ae3..7c9bcab26 100644
--- a/aepsych/strategy.py
+++ b/aepsych/strategy.py
@@ -410,9 +410,9 @@ def from_config(cls, config: Config, name: str) -> Strategy:
         stimuli_per_trial = config.getint(name, "stimuli_per_trial", fallback=1)
         outcome_types = config.getlist(name, "outcome_types", element_type=str)

-        generator = GeneratorWrapper.from_config(name, config)
+        generator = GeneratorWrapper.from_config(name, config, transforms)

-        model = ModelWrapper.from_config(name, config)
+        model = ModelWrapper.from_config(name, config, transforms)

         acqf_cls = config.getobj(name, "acqf", fallback=None)
         if acqf_cls is not None and hasattr(generator, "acqf"):
diff --git a/aepsych/transforms/parameters.py b/aepsych/transforms/parameters.py
index 91af1fb92..e144a36e3 100644
--- a/aepsych/transforms/parameters.py
+++ b/aepsych/transforms/parameters.py
@@ -251,7 +251,8 @@ def from_config(
         cls,
         name: str,
         config: Config,
-    ) -> "GeneratorWrapper":
+        transforms: Optional[ChainedInputTransform] = None,
+    ) -> Any:
         """Returns a generator wrapped by GeneratorWrapper that will transform its inputs
         and outputs automatically with the same API as the underlying generator.

@@ -264,10 +265,11 @@
                 class's name will be ParameterWrapped.
         """
         gen_cls = config.getobj(name, "generator", fallback=SobolGenerator)
-        transforms = ParameterTransforms.from_config(config)
+        if transforms is None:
+            transforms = ParameterTransforms.from_config(config)

         # We need transformed values from config but we don't want to edit config
-        transformed_config = transform_options(config)
+        transformed_config = transform_options(config, transforms)

         gen = gen_cls.from_config(transformed_config)

@@ -465,7 +467,8 @@ def from_config(
         cls,
         name: str,
         config: Config,
-    ) -> "ModelWrapper":
+        transforms: Optional[ChainedInputTransform] = None,
+    ) -> Optional[Any]:
         """Returns a model wrapped by ModelWrapper that will transform its inputs
         and outputs automatically with the same API as the underlying model.

@@ -482,17 +485,20 @@
         if model_cls is None:
             return None

-        transforms = ParameterTransforms.from_config(config)
+        if transforms is None:
+            transforms = ParameterTransforms.from_config(config)

         # Need transformed values
-        transformed_config = transform_options(config)
+        transformed_config = transform_options(config, transforms)

         model = model_cls.from_config(transformed_config)

         return cls(model, transforms)


-def transform_options(config: Config) -> Config:
+def transform_options(
+    config: Config, transforms: Optional[ChainedInputTransform] = None
+) -> Config:
     """
     Return a copy of the config with the options transformed.

@@ -502,7 +508,8 @@
     Returns:
         Config: A copy of the original config with relevant options transformed.
     """
-    transforms = ParameterTransforms.from_config(config)
+    if transforms is None:
+        transforms = ParameterTransforms.from_config(config)

     configClone = deepcopy(config)

diff --git a/tests/test_transforms.py b/tests/test_transforms.py
index 6d7326b01..40caf71a2 100644
--- a/tests/test_transforms.py
+++ b/tests/test_transforms.py
@@ -101,6 +101,14 @@ def test_model_init_equivalent(self):

     def test_transforms_in_strategy(self):
         for _strat in self.strat.strat_list:
+            # Check if the same transform is passed around everywhere
+            self.assertTrue(id(_strat.transforms) == id(_strat.generator.transforms))
+            if _strat.model is not None:
+                self.assertTrue(
+                    id(_strat.generator.transforms) == id(_strat.model.transforms)
+                )
+
+            # Check all the transform bits are the same
             for strat_transform, gen_transform in zip(
                 _strat.transforms.items(), _strat.generator.transforms.items()
             ):