Skip to content

Commit

Permalink
pass transforms around instead of making duplicates (facebookresearch…
Browse files Browse the repository at this point in the history
…#416)

Summary:

Instead of creating duplicate transforms whenever we need one, we create a single transform from the config and initialize the wrapped model and wrapped generators with that one transform. This passes the same transform object around and allows the transformations to learn parameters and still be synced up across wrapped objects.

Differential Revision: D65155103
  • Loading branch information
JasonKChow authored and facebook-github-bot committed Oct 30, 2024
1 parent e7e648b commit b85b9fb
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 8 deletions.
4 changes: 2 additions & 2 deletions aepsych/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,9 +410,9 @@ def from_config(cls, config: Config, name: str) -> Strategy:
stimuli_per_trial = config.getint(name, "stimuli_per_trial", fallback=1)
outcome_types = config.getlist(name, "outcome_types", element_type=str)

generator = GeneratorWrapper.from_config(name, config)
generator = GeneratorWrapper.from_config(name, config, transforms)

model = ModelWrapper.from_config(name, config)
model = ModelWrapper.from_config(name, config, transforms)

acqf_cls = config.getobj(name, "acqf", fallback=None)
if acqf_cls is not None and hasattr(generator, "acqf"):
Expand Down
19 changes: 13 additions & 6 deletions aepsych/transforms/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,12 +162,14 @@ def from_config(
cls,
name: str,
config: Config,
transforms: Optional[ChainedInputTransform] = None,
):
gen_cls = config.getobj(name, "generator", fallback=SobolGenerator)
transforms = ParameterTransforms.from_config(config)
if transforms is None:
transforms = ParameterTransforms.from_config(config)

# We need transformed values from config but we don't want to edit config
transformed_config = transform_options(config)
transformed_config = transform_options(config, transforms)

gen = gen_cls.from_config(transformed_config)

Expand Down Expand Up @@ -281,27 +283,32 @@ def from_config(
cls,
name: str,
config: Config,
transforms: Optional[ChainedInputTransform] = None,
):
# We don't always have models
model_cls = config.getobj(name, "model", fallback=None)
if model_cls is None:
return None

transforms = ParameterTransforms.from_config(config)
if transforms is None:
transforms = ParameterTransforms.from_config(config)

# Need transformed values
transformed_config = transform_options(config)
transformed_config = transform_options(config, transforms)

model = model_cls.from_config(transformed_config)

return cls(model, transforms)


def transform_options(config: Config) -> Config:
def transform_options(
config: Config, transforms: Optional[ChainedInputTransform] = None
) -> Config:
"""
Return a copy of the config with the options transformed. The config
"""
transforms = ParameterTransforms.from_config(config)
if transforms is None:
transforms = ParameterTransforms.from_config(config)

configClone = deepcopy(config)

Expand Down
8 changes: 8 additions & 0 deletions tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,14 @@ def test_model_init_equivalent(self):

def test_transforms_in_strategy(self):
for _strat in self.strat.strat_list:
# Check if the same transform is passed around everywhere
self.assertTrue(id(_strat.transforms) == id(_strat.generator.transforms))
if _strat.model is not None:
self.assertTrue(
id(_strat.generator.transforms) == id(_strat.model.transforms)
)

# Check all the transform bits are the same
for strat_transform, gen_transform in zip(
_strat.transforms.items(), _strat.generator.transforms.items()
):
Expand Down

0 comments on commit b85b9fb

Please sign in to comment.