Description
We should be able to set the module classes that are subject to activation checkpointing via the config.
Currently, it is hardcoded:
```python
import torch

from modalities.models.gpt2.gpt2_model import GPT2Block  # import path assumed


def is_module_to_apply_activation_checkpointing(submodule: torch.nn.Module) -> bool:
    # Hardcoded: only GPT2Block submodules are wrapped for activation checkpointing.
    return isinstance(submodule, GPT2Block)
```
see: https://github.com/Modalities/modalities/blob/main/src/modalities/activation_checkpointing.py#L15
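A minimal sketch of what a config-driven check could look like, assuming the config exposes a list of fully qualified class names (the config key `activation_checkpointing_modules` and the helper names here are hypothetical, not part of the current codebase):

```python
# Hypothetical sketch: resolve module class names listed in the config
# and check submodules against them instead of a hardcoded isinstance check.
from importlib import import_module

import torch


def resolve_class(qualified_name: str) -> type:
    """Turn a fully qualified class name from the config into the class object."""
    module_path, _, class_name = qualified_name.rpartition(".")
    return getattr(import_module(module_path), class_name)


def is_module_to_apply_activation_checkpointing(
    submodule: torch.nn.Module, checkpointed_classes: tuple[type, ...]
) -> bool:
    # True if the submodule is an instance of any class listed in the config.
    return isinstance(submodule, checkpointed_classes)


# Usage (assumed config structure):
# checkpointed_classes = tuple(
#     resolve_class(name) for name in config.activation_checkpointing_modules
# )
# check_fn = lambda m: is_module_to_apply_activation_checkpointing(m, checkpointed_classes)
```

This would keep the existing behaviour as the default (a config listing only the GPT2Block class) while letting users checkpoint other block types without touching the code.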