pass transforms around instead of making duplicates (facebookresearch#416)

Summary:
Pull Request resolved: facebookresearch#416

Instead of creating a duplicate transform every time one is needed, we create a single transform from the config and initialize the wrapped model and wrapped generators with it. Passing the same transform object around lets the transforms learn parameters while staying in sync across all wrapped objects.
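
To make the mechanism concrete, here is a minimal, self-contained sketch of the sharing pattern. ScaleTransform and Wrapper are hypothetical stand-ins for illustration only, not AEPsych's real classes:

class ScaleTransform:
    def __init__(self) -> None:
        self.scale = 1.0  # a "learned" parameter

    def fit(self, scale: float) -> None:
        self.scale = scale


class Wrapper:
    def __init__(self, transforms: ScaleTransform) -> None:
        # Store a reference to the transform, never a copy
        self.transforms = transforms


shared = ScaleTransform()
model = Wrapper(shared)      # both wrappers receive the same instance
generator = Wrapper(shared)

model.transforms.fit(2.5)                 # the model updates the learned parameter...
assert generator.transforms.scale == 2.5  # ...and the generator sees it immediately
assert model.transforms is generator.transforms

Because both wrappers hold a reference to the same object, a parameter learned through one is immediately visible through the other; duplicated transforms would each learn their own copy and drift apart.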

Differential Revision: D65155103
JasonKChow authored and facebook-github-bot committed Nov 1, 2024
1 parent 36bff86 commit 6ffadda
Showing 3 changed files with 25 additions and 10 deletions.
4 changes: 2 additions & 2 deletions aepsych/strategy.py
@@ -410,9 +410,9 @@ def from_config(cls, config: Config, name: str) -> Strategy:
         stimuli_per_trial = config.getint(name, "stimuli_per_trial", fallback=1)
         outcome_types = config.getlist(name, "outcome_types", element_type=str)

-        generator = GeneratorWrapper.from_config(name, config)
+        generator = GeneratorWrapper.from_config(name, config, transforms)

-        model = ModelWrapper.from_config(name, config)
+        model = ModelWrapper.from_config(name, config, transforms)

         acqf_cls = config.getobj(name, "acqf", fallback=None)
         if acqf_cls is not None and hasattr(generator, "acqf"):
23 changes: 15 additions & 8 deletions aepsych/transforms/parameters.py
@@ -251,7 +251,8 @@ def from_config(
         cls,
         name: str,
         config: Config,
-    ) -> "GeneratorWrapper":
+        transforms: Optional[ChainedInputTransform] = None,
+    ) -> Any:
         """Returns a generator wrapped by GeneratorWrapper that will transform its inputs
         and outputs automatically with the same API as the underlying generator.
@@ -264,10 +265,11 @@
             class's name will be ParameterWrapped<Generator.__name__>.
         """
         gen_cls = config.getobj(name, "generator", fallback=SobolGenerator)
-        transforms = ParameterTransforms.from_config(config)
+        if transforms is None:
+            transforms = ParameterTransforms.from_config(config)

         # We need transformed values from config but we don't want to edit config
-        transformed_config = transform_options(config)
+        transformed_config = transform_options(config, transforms)

         gen = gen_cls.from_config(transformed_config)

@@ -465,7 +467,8 @@ def from_config(
         cls,
         name: str,
         config: Config,
-    ) -> "ModelWrapper":
+        transforms: Optional[ChainedInputTransform] = None,
+    ) -> Optional[Any]:
         """Returns a model wrapped by ModelWrapper that will transform its inputs
         and outputs automatically with the same API as the underlying model.
@@ -482,17 +485,20 @@
         if model_cls is None:
             return None

-        transforms = ParameterTransforms.from_config(config)
+        if transforms is None:
+            transforms = ParameterTransforms.from_config(config)

         # Need transformed values
-        transformed_config = transform_options(config)
+        transformed_config = transform_options(config, transforms)

         model = model_cls.from_config(transformed_config)

         return cls(model, transforms)


-def transform_options(config: Config) -> Config:
+def transform_options(
+    config: Config, transforms: Optional[ChainedInputTransform] = None
+) -> Config:
     """
     Return a copy of the config with the options transformed.
@@ -502,7 +508,8 @@
     Returns:
         Config: A copy of the original config with relevant options transformed.
     """
-    transforms = ParameterTransforms.from_config(config)
+    if transforms is None:
+        transforms = ParameterTransforms.from_config(config)

     configClone = deepcopy(config)

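Taken together with the strategy.py hunk above, the intended call pattern is sketched below. The import paths and the bare Config() construction are assumptions for illustration, not verified against the full repo:

# Sketch of the call pattern these hunks enable (assumed import paths).
from aepsych.config import Config
from aepsych.transforms.parameters import (
    GeneratorWrapper,
    ModelWrapper,
    ParameterTransforms,
)

config = Config()  # placeholder: a real Config would be populated from an ini string
name = "init_strat"  # hypothetical strategy section name

# Build the transforms once from the config...
transforms = ParameterTransforms.from_config(config)

# ...then hand the same object to both wrappers so learned transform
# parameters stay in sync between them.
generator = GeneratorWrapper.from_config(name, config, transforms)
model = ModelWrapper.from_config(name, config, transforms)

# Omitting the third argument keeps the old behavior: each from_config call
# falls back to building its own ParameterTransforms from the config.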
8 changes: 8 additions & 0 deletions tests/test_transforms.py
@@ -101,6 +101,14 @@ def test_model_init_equivalent(self):

     def test_transforms_in_strategy(self):
         for _strat in self.strat.strat_list:
+            # Check if the same transform is passed around everywhere
+            self.assertTrue(id(_strat.transforms) == id(_strat.generator.transforms))
+            if _strat.model is not None:
+                self.assertTrue(
+                    id(_strat.generator.transforms) == id(_strat.model.transforms)
+                )
+
             # Check all the transform bits are the same
             for strat_transform, gen_transform in zip(
                 _strat.transforms.items(), _strat.generator.transforms.items()
             ):
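A side note on the identity checks in this test: for two live objects, id(x) == id(y) is equivalent to x is y, and unittest provides assertIs for exactly this comparison. A standalone illustration, separate from the commit:

import unittest


class SharedObjectTest(unittest.TestCase):
    def test_shared_identity(self) -> None:
        shared = []          # any object works; identity, not equality, is tested
        a, b = shared, shared
        self.assertIs(a, b)              # idiomatic identity assertion
        self.assertTrue(id(a) == id(b))  # equivalent form used in the diff


if __name__ == "__main__":
    unittest.main()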
