diff --git a/aepsych/strategy.py b/aepsych/strategy.py index 5e9495ae3..7c9bcab26 100644 --- a/aepsych/strategy.py +++ b/aepsych/strategy.py @@ -410,9 +410,9 @@ def from_config(cls, config: Config, name: str) -> Strategy: stimuli_per_trial = config.getint(name, "stimuli_per_trial", fallback=1) outcome_types = config.getlist(name, "outcome_types", element_type=str) - generator = GeneratorWrapper.from_config(name, config) + generator = GeneratorWrapper.from_config(name, config, transforms) - model = ModelWrapper.from_config(name, config) + model = ModelWrapper.from_config(name, config, transforms) acqf_cls = config.getobj(name, "acqf", fallback=None) if acqf_cls is not None and hasattr(generator, "acqf"): diff --git a/aepsych/transforms/parameters.py b/aepsych/transforms/parameters.py index 15f349ed8..cb4cb5c22 100644 --- a/aepsych/transforms/parameters.py +++ b/aepsych/transforms/parameters.py @@ -186,7 +186,7 @@ def __init__( transforms (ChainedInputTransform, optional): A set of transforms to apply to parameters of this generator. If no transforms are passed, it will default to an identity transform. - **kwargs: Keyword arguments to pass to the model to initialize it if model + **kwargs: Keyword arguments to pass to the model to initialize it if model is a class. """ # Figure out what we need to do with generator @@ -251,6 +251,7 @@ def from_config( cls, name: str, config: Config, + transforms: Optional[ChainedInputTransform] = None, ) -> Any: """Returns a generator wrapped by GeneratorWrapper that will transform its inputs and outputs automatically with the same API as the underlying generator. @@ -264,10 +265,11 @@ def from_config( class's name will be ParameterWrapped. """ gen_cls = config.getobj(name, "generator", fallback=SobolGenerator) - transforms = ParameterTransforms.from_config(config) + if transforms is None: + transforms = ParameterTransforms.from_config(config) # We need transformed values from config but we don't want to edit config - transformed_config = transform_options(config) + transformed_config = transform_options(config, transforms) gen = gen_cls.from_config(transformed_config) @@ -315,7 +317,7 @@ def __init__( transforms (ChainedInputTransform, optional): A set of transforms to apply to parameters of this model. If no transforms are passed, it will default to an identity transform. - **kwargs: Keyword arguments to be passed to the model if the model is a + **kwargs: Keyword arguments to be passed to the model if the model is a class. """ # Alternative instantiation method for analysis (and not live) @@ -482,17 +484,20 @@ def from_config( if model_cls is None: return None - transforms = ParameterTransforms.from_config(config) + if transforms is None: + transforms = ParameterTransforms.from_config(config) # Need transformed values - transformed_config = transform_options(config) + transformed_config = transform_options(config, transforms) model = model_cls.from_config(transformed_config) return cls(model, transforms) -def transform_options(config: Config) -> Config: +def transform_options( + config: Config, transforms: Optional[ChainedInputTransform] = None +) -> Config: """ Return a copy of the config with the options transformed. @@ -502,7 +507,8 @@ def transform_options(config: Config) -> Config: Returns: Config: A copy of the original config with relevant options transformed. """ - transforms = ParameterTransforms.from_config(config) + if transforms is None: + transforms = ParameterTransforms.from_config(config) configClone = deepcopy(config) diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 6d7326b01..40caf71a2 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -101,6 +101,14 @@ def test_model_init_equivalent(self): def test_transforms_in_strategy(self): for _strat in self.strat.strat_list: + # Check if the same transform is passed around everywhere + self.assertTrue(id(_strat.transforms) == id(_strat.generator.transforms)) + if _strat.model is not None: + self.assertTrue( + id(_strat.generator.transforms) == id(_strat.model.transforms) + ) + + # Check all the transform bits are the same for strat_transform, gen_transform in zip( _strat.transforms.items(), _strat.generator.transforms.items() ):