SAM callback step/epoch skipping configuration #284

Merged · 15 commits · Sep 9, 2024
131 changes: 114 additions & 17 deletions matsciml/lightning/callbacks.py
@@ -709,7 +709,15 @@ def on_exception(


class SAM(Callback):
def __init__(self, rho: float = 0.05, adaptive: bool = False) -> None:
def __init__(
self,
rho: float = 0.05,
adaptive: bool = False,
skip_step_count: int | None = None,
skip_epoch_count: int | float | None = None,
logging: bool = False,
log_level: Literal["DEBUG", "INFO", "WARNING"] = "WARNING",
) -> None:
"""
Set up the ``SAM (Sharpness Aware Minimization)`` callback.
https://arxiv.org/abs/2010.01412
@@ -734,6 +742,27 @@ def __init__(self, rho: float = 0.05, adaptive: bool = False) -> None:
adaptive : bool
A boolean flag indicating whether to adaptively normalize weights.
Defaults to False.
skip_step_count : int | None, default None
Specifies an integer number of steps to skip before SAM takes effect.
The default, None, starts SAM from the first step. Mutually exclusive
with ``skip_epoch_count``.
skip_epoch_count : int | float | None, default None
Specifies the number of epochs to skip before SAM takes effect.
If an integer is passed, it is the exact number of epochs to wait
before SAM is used. If a float in (0, 1) is passed, it is the
fraction of ``trainer.max_epochs`` to wait before SAM triggers.
The default, None, does not wait any epochs before invoking SAM.
Mutually exclusive with ``skip_step_count``.
logging : bool, default False
If set to True, logs the behavior of SAM to the console, which is
useful for debugging.
log_level : Literal["DEBUG", "INFO", "WARNING"], default "WARNING"
Sets the logging level when ``logging`` is enabled. The default,
"WARNING", will not report when SAM is being skipped, but will still
warn the user when the gradient norm is smaller than machine epsilon.
Set the level to "INFO" or lower to see when SAM is *not* running.

Examples
--------
@@ -747,6 +776,46 @@ def __init__(self, rho: float = 0.05, adaptive: bool = False) -> None:
super().__init__()
self.rho = rho
self.adaptive = adaptive
if skip_epoch_count and skip_step_count:
raise ValueError(
"`skip_epoch_count` and `skip_step_count` are mutually exclusive for SAM."
)
self.skip_step_count = skip_step_count
if isinstance(skip_epoch_count, float) and not skip_epoch_count.is_integer():
assert (
0 < skip_epoch_count < 1.0
), "Fractional `skip_epoch_count` passed is not within (0, 1)."
self.skip_epoch_count = skip_epoch_count
if logging:
self.logger = getLogger("matsciml.callbacks.SAM")
self.logger.setLevel(log_level)
else:
self.logger = None
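
A minimal usage sketch of the new arguments; the trainer settings and the commented-out fit call below are illustrative rather than part of this diff:

import pytorch_lightning as pl

from matsciml.lightning.callbacks import SAM

# enable SAM only after the first 10% of epochs, and log when it toggles on
sam = SAM(rho=0.05, adaptive=False, skip_epoch_count=0.1, logging=True, log_level="INFO")
trainer = pl.Trainer(max_epochs=50, callbacks=[sam])
# trainer.fit(task, datamodule=dm)  # task and datamodule assumed to be defined elsewhere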

def on_fit_start(
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"
) -> None:
# determine the maximum number of epochs
self.max_epochs = trainer.max_epochs
# if it's been specified explicitly, just use that
if trainer.max_steps > -1:
self.max_steps = trainer.max_steps
else:
# work out the total number of expected steps
if not trainer.train_dataloader:
trainer.fit_loop.setup_data()
train_len = len(trainer.train_dataloader)
self.max_steps = train_len * self.max_epochs
# if a fractional epoch skip is specified, convert it to
# an integer count for easier comparison
if isinstance(self.skip_epoch_count, float) and not self.skip_epoch_count.is_integer():
self.skip_epoch_count = int(self.max_epochs * self.skip_epoch_count)
# add floating point epsilon for later use
self.epsilon = torch.tensor(
[torch.finfo(pl_module.dtype).eps],
dtype=pl_module.dtype,
device=pl_module.device,
)
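
As a worked example of the bookkeeping above (all numbers purely illustrative): suppose ``trainer.max_epochs`` is 50, the train dataloader yields 200 batches per epoch, and the callback was built with ``skip_epoch_count=0.2``.

max_steps = 200 * 50              # 10000 expected optimizer steps over the run
skip_epoch_count = int(50 * 0.2)  # rewritten to 10, so SAM starts perturbing at epoch 10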

@staticmethod
def _get_params(optimizer: Optimizer) -> Iterator[torch.Tensor]:
@@ -773,6 +842,24 @@ def on_train_batch_start(
) -> None:
self.batch = batch
self.batch_idx = batch_idx
# add flag to determine if we should run SAM on this step
current_step = trainer.global_step
current_epoch = trainer.current_epoch
# by default we start SAM, and toggle off if conditions met
start_sam = True
if self.skip_epoch_count and current_epoch < self.skip_epoch_count:
start_sam = False
if self.logger:
self.logger.info(
"Required number of epochs not met; not running SAM yet."
)
if self.skip_step_count and current_step < self.skip_step_count:
start_sam = False
if self.logger:
self.logger.info(
"Required number of steps not met; not running SAM yet."
)
self.start_sam = start_sam

def extract_optimizer_specific_loss(self, task, optimizer, loss):
optimizer_names = copy(task.optimizer_names)
@@ -814,21 +901,25 @@ def on_before_optimizer_step(
task: BaseTaskModule,
optimizer: Optimizer,
) -> None:
optimizer_is_used = self.is_optimizer_used(task, optimizer)
if optimizer_is_used:
with torch.no_grad():
org_weights = self._first_step(optimizer)
with torch.enable_grad():
loss = task._compute_losses(self.batch)
# this is for the multitask case where there is more than one optimizer
if not isinstance(task.optimizers(), Optimizer):
loss = self.extract_optimizer_specific_loss(task, optimizer, loss)
loss = self._get_loss(loss)
if loss is not None:
if torch.isfinite(loss):
trainer.strategy.backward(loss, optimizer=optimizer)
with torch.no_grad():
self._second_step(optimizer, org_weights)
# check whether SAM should be active yet before running the two-step update
if self.start_sam:
optimizer_is_used = self.is_optimizer_used(task, optimizer)
if optimizer_is_used:
with torch.no_grad():
org_weights = self._first_step(optimizer)
with torch.enable_grad():
loss = task._compute_losses(self.batch)
# this is for the multitask case where there is more than one optimizer
if not isinstance(task.optimizers(), Optimizer):
loss = self.extract_optimizer_specific_loss(
task, optimizer, loss
)
loss = self._get_loss(loss)
if loss is not None:
if torch.isfinite(loss):
trainer.strategy.backward(loss, optimizer=optimizer)
with torch.no_grad():
self._second_step(optimizer, org_weights)
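
For context, the calls above follow the two-step SAM procedure from the paper cited in the docstring; a rough sketch of the non-adaptive case is below (the adaptive variant additionally scales the perturbation by the parameter magnitudes):

# schematic of the two steps for each parameter p that has a gradient:
#   e_p = rho * p.grad / grad_norm   # _first_step: move towards locally higher loss
#   p <- p + e_p                     # loss and gradients are recomputed at the perturbed point
#   p <- original p                  # _second_step: weights restored, so the base optimizer
#                                    # steps using the gradients from the perturbed point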

def _norm_weights(self, p: torch.Tensor) -> torch.Tensor:
return torch.abs(p) if self.adaptive else torch.ones_like(p)
@@ -847,7 +938,13 @@ def _first_step(self, optimizer: Optimizer) -> Dict[torch.Tensor, torch.Tensor]:
"""
org_weights dictionary stores original weights and perturbed weights
"""
scale = self.rho / (self._grad_norm(optimizer) + 1e-5)
# take the larger value of the two; hopefully not epsilon!
grad_norm = torch.maximum(self._grad_norm(optimizer), self.epsilon)
if grad_norm == self.epsilon and self.logger:
self.logger.warning(
f"Gradient norm smaller than machine epsilon at batch number {self.batch_idx}."
)
scale = self.rho / grad_norm
org_weights: Dict[torch.Tensor, torch.Tensor] = {}
for p in self._get_params(optimizer):
if p.grad is None: