diff --git a/examples/callbacks/SAM.py b/examples/callbacks/SAM.py index dbb215b4..93f1c150 100644 --- a/examples/callbacks/SAM.py +++ b/examples/callbacks/SAM.py @@ -36,6 +36,8 @@ }, ) -# run a quick training loop -trainer = pl.Trainer(fast_dev_run=1000, callbacks=[SAM()]) +# run a quick training loop, skipping the first five steps +trainer = pl.Trainer( + fast_dev_run=100, callbacks=[SAM(skip_step_count=5, logging=True, log_level="INFO")] +) trainer.fit(task, datamodule=dm) diff --git a/matsciml/lightning/callbacks.py b/matsciml/lightning/callbacks.py index af7ffa54..3239e830 100644 --- a/matsciml/lightning/callbacks.py +++ b/matsciml/lightning/callbacks.py @@ -776,12 +776,12 @@ def __init__( super().__init__() self.rho = rho self.adaptive = adaptive - if skip_epoch_count and skip_epoch_count: + if skip_step_count and skip_epoch_count: raise ValueError( "`skip_epoch_count` and `skip_step_count` are mutually exclusive for SAM." ) self.skip_step_count = skip_step_count - if skip_epoch_count and not skip_epoch_count.is_integer(): + if skip_epoch_count and isinstance(skip_epoch_count, float): assert ( 0 < skip_epoch_count < 1.0 ), "Decimal `skip_epoch_count` passed not within [0,1]." @@ -808,8 +808,12 @@ def on_fit_start( self.max_steps = train_len * self.max_epochs # if a fractional epoch skip is specified, convert it to # an integer count for easier comparison - if self.skip_epoch_count and not self.skip_epoch_count.is_integer(): + if self.skip_epoch_count and isinstance(self.skip_epoch_count, float): self.skip_epoch_count = int(self.max_epochs * self.skip_epoch_count) + if self.logger: + self.logger.info( + f"Fractional epoch skip - will start SAM from epoch {self.skip_epoch_count}, max epochs {self.max_epochs}." + ) # add floating point epsilon for later use self.epsilon = torch.tensor( [torch.finfo(pl_module.dtype).eps],