diff --git a/README.md b/README.md index 52e12d65..401e133d 100644 --- a/README.md +++ b/README.md @@ -13,9 +13,9 @@ With Mammoth, nothing is set in stone. You can easily add new models, datasets, Join our Discord Server for all your Mammoth-related questions → ![Discord Shield](https://discordapp.com/api/guilds/1164956257392799860/widget.png?style=shield) -## **NEW**: WIKI +## Documentation -We have created a [WIKI](https://aimagelab.github.io/mammoth/)! Check it out for more information on how to use Mammoth. +### Check out the official [DOCUMENTATION](https://aimagelab.github.io/mammoth/) for more information on how to use Mammoth!

Sequential MNIST diff --git a/datasets/seq_cifar100.py b/datasets/seq_cifar100.py index 8f71959b..451f9838 100644 --- a/datasets/seq_cifar100.py +++ b/datasets/seq_cifar100.py @@ -151,8 +151,10 @@ def get_batch_size(self): return 32 @staticmethod - def get_scheduler(model, args: Namespace) -> torch.optim.lr_scheduler: - scheduler = ContinualDataset.get_scheduler(model, args) + def get_scheduler(model, args: Namespace, reload_optim=True) -> torch.optim.lr_scheduler: + scheduler = ContinualDataset.get_scheduler(model, args, reload_optim) if scheduler is None: + if reload_optim: + model.opt = model.get_optimizer() scheduler = torch.optim.lr_scheduler.MultiStepLR(model.opt, [35, 45], gamma=0.1, verbose=False) return scheduler diff --git a/datasets/utils/continual_dataset.py b/datasets/utils/continual_dataset.py index 0ffc3915..7289ac89 100644 --- a/datasets/utils/continual_dataset.py +++ b/datasets/utils/continual_dataset.py @@ -157,10 +157,15 @@ def get_denormalization_transform() -> nn.Module: raise NotImplementedError @staticmethod - def get_scheduler(model, args: Namespace) -> torch.optim.lr_scheduler._LRScheduler: - """Returns the scheduler to be used for the current dataset.""" + def get_scheduler(model, args: Namespace, reload_optim=True) -> torch.optim.lr_scheduler._LRScheduler: + """ + Returns the scheduler to be used for the current dataset. + If `reload_optim` is True, the optimizer is reloaded from the model. This should be done at least ONCE every task + to ensure that the learning rate is reset to the initial value. + """ if args.lr_scheduler is not None: - model.opt = model.get_optimizer() + if reload_optim or not hasattr(model, 'opt'): + model.opt = model.get_optimizer() # check if lr_scheduler is in torch.optim.lr_scheduler supported_scheds = {sched_name.lower(): sched_name for sched_name in dir(scheds) if sched_name.lower() in ContinualDataset.AVAIL_SCHEDS} sched = None diff --git a/docs/datasets/index.rst b/docs/datasets/index.rst index c34db1c8..412126f7 100644 --- a/docs/datasets/index.rst +++ b/docs/datasets/index.rst @@ -38,6 +38,7 @@ each dataset **must statically define** all the necessary information to run a c - **get_denormalization_transform** static method (``callable``): returns the transform to apply on the tensors to revert the normalization. You can use the `DeNormalize` function defined in `datasets/transforms/denormalization.py`. + - **get_scheduler** static method (``callable``): returns the learning rate scheduler to use during train. *By default*, it also initializes the optimizer. This prevents errors due to the learning rate being continouosly reduced task after task. This behavior can be changed setting the argument ``reload_optim=False``. See :ref:`continual_dataset` for more details or **SequentialCIFAR10** in :ref:`seq_cifar10` for an example. diff --git a/models/utils/lider_model.py b/models/utils/lider_model.py index bcfe9179..77b55b49 100644 --- a/models/utils/lider_model.py +++ b/models/utils/lider_model.py @@ -160,9 +160,6 @@ def init_net(self, dataset): self.net.lip_coeffs = torch.autograd.Variable(torch.randn(len(teacher_feats), dtype=torch.float), requires_grad=True).to(self.device) self.net.lip_coeffs.data = budget_lip - self.opt = self.get_optimizer() - if hasattr(self, 'scheduler'): - self.scheduler = self.dataset.get_scheduler() self.net.train(was_training) diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py new file mode 100644 index 00000000..80af6160 --- /dev/null +++ b/tests/test_scheduler.py @@ -0,0 +1,121 @@ +import os +import sys + +import torch +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from utils.main import main +import pytest + + +def test_der_cifar100_defaultscheduler(): + N_TASKS = 10 + sys.argv = ['mammoth', + '--model', + 'der', + '--dataset', + 'seq-cifar100', + '--buffer_size', + '500', + '--alpha', + '0.3', + '--lr', + '0.03', + '--n_epochs', + '50', + '--non_verbose', + '1', + '--num_workers', + '0', + '--debug_mode', + '1', + '--savecheck', + '1', + '--seed', + '0'] + + log_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs', f'test_der_cifar100_defaultscheduler.log') + # log all outputs to file + if not os.path.exists(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs')): + os.mkdir(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs')) + sys.stdout = open(log_path, 'w', encoding='utf-8') + sys.stderr = sys.stdout + main() + + # read output file and search for the string 'Saving checkpoint into' + with open(log_path, 'r', encoding='utf-8') as f: + lines = f.readlines() + ckpt_name = [line for line in lines if 'Saving checkpoint into' in line] + assert any(ckpt_name), f'Checkpoint not saved in {log_path}' + + ckpt_base_name = ckpt_name[0].split('Saving checkpoint into')[-1].strip() + ckpt_paths = [os.path.join('checkpoints', ckpt_base_name + f'_{i}.pt') for i in range(N_TASKS)] + + + for ckpt_path in ckpt_paths: + assert os.path.exists(ckpt_path), f'Checkpoint file {ckpt_path} not found' + + ckpt = torch.load(ckpt_path) + opt, sched = ckpt['optimizer']['param_groups'][0], ckpt['scheduler'] + assert opt['initial_lr'] == 0.03, f'Learning rate not updated correctly in {ckpt_path}' + assert opt['lr']==opt['initial_lr']*0.1*0.1, f'Learning rate not updated correctly in {ckpt_path}' + assert list(sched['milestones'].keys()) == [35, 45], f'Milestones not updated correctly in {ckpt_path}' + assert sched['base_lrs']==[0.03], f'Base learning rate not updated correctly in {ckpt_path}' + + +def test_der_cifar100_customscheduler(): + N_TASKS = 10 + sys.argv = ['mammoth', + '--model', + 'der', + '--dataset', + 'seq-cifar100', + '--buffer_size', + '500', + '--alpha', + '0.3', + '--lr', + '0.1', + '--n_epochs', + '10', + '--non_verbose', + '1', + '--num_workers', + '0', + '--debug_mode', + '1', + '--savecheck', + '1', + '--lr_scheduler', + 'multisteplr', + '--lr_milestones', + '2','4','6','8', + '--seed', + '0'] + + log_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs', f'test_der_cifar100_customscheduler.der.cifar100.log') + # log all outputs to file + if not os.path.exists(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs')): + os.mkdir(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs')) + sys.stdout = open(log_path, 'w', encoding='utf-8') + sys.stderr = sys.stdout + main() + + # read output file and search for the string 'Saving checkpoint into' + with open(log_path, 'r', encoding='utf-8') as f: + lines = f.readlines() + ckpt_name = [line for line in lines if 'Saving checkpoint into' in line] + assert any(ckpt_name), f'Checkpoint not saved in {log_path}' + + ckpt_base_name = ckpt_name[0].split('Saving checkpoint into')[-1].strip() + ckpt_paths = [os.path.join('checkpoints', ckpt_base_name + f'_{i}.pt') for i in range(N_TASKS)] + + + for ckpt_path in ckpt_paths: + assert os.path.exists(ckpt_path), f'Checkpoint file {ckpt_path} not found' + + ckpt = torch.load(ckpt_path) + opt, sched = ckpt['optimizer']['param_groups'][0], ckpt['scheduler'] + assert opt['initial_lr'] == 0.1, f'Learning rate not updated correctly in {ckpt_path}' + assert opt['lr']==opt['initial_lr']*0.1*0.1*0.1*0.1, f'Learning rate not updated correctly in {ckpt_path}' + assert list(sched['milestones'].keys()) == [2,4,6,8], f'Milestones not updated correctly in {ckpt_path}' + assert sched['base_lrs']==[0.1], f'Base learning rate not updated correctly in {ckpt_path}' \ No newline at end of file diff --git a/utils/stats.py b/utils/stats.py index d3d36a68..924e5629 100644 --- a/utils/stats.py +++ b/utils/stats.py @@ -131,11 +131,13 @@ def update_stats(self, cpu_res, gpu_res): self._it += 1 alpha = 1 / self._it - self.avg_cpu_res = self.avg_cpu_res + alpha * (cpu_res - self.avg_cpu_res) - self.avg_gpu_res = {g: (g_res + alpha * (g_res - self.avg_gpu_res[g])) for g, g_res in enumerate(gpu_res)} + if self.initial_cpu_res is not None: + self.avg_cpu_res = self.avg_cpu_res + alpha * (cpu_res - self.avg_cpu_res) + self.max_cpu_res = max(self.max_cpu_res, cpu_res) - self.max_cpu_res = max(self.max_cpu_res, cpu_res) - self.max_gpu_res = {g: max(self.max_gpu_res[g], g_res) for g, g_res in enumerate(gpu_res)} + if self.initial_gpu_res is not None: + self.avg_gpu_res = {g: (g_res + alpha * (g_res - self.avg_gpu_res[g])) for g, g_res in enumerate(gpu_res)} + self.max_gpu_res = {g: max(self.max_gpu_res[g], g_res) for g, g_res in enumerate(gpu_res)} if self.logger is not None: self.logger.log_system_stats(cpu_res, gpu_res) diff --git a/utils/status.py b/utils/status.py index a799895c..72f5cf26 100644 --- a/utils/status.py +++ b/utils/status.py @@ -52,23 +52,22 @@ def reset(self) -> None: self.old_time = time() self.running_sum = 0 self.current_task_iter = 0 - self.last_task = 0 - def prog(self, i: int, max_iter: int, epoch: Union[int, str], + def prog(self, current_epoch_iter: int, max_epoch_iter: int, epoch: Union[int, str], task_number: int, loss: float) -> None: """ Prints out the progress bar on the stderr file. Args: - i: the current iteration - max_iter: the maximum number of iteration. If None, the progress bar is not printed. + current_epoch_iter: the current iteration of the epoch + max_epoch_iter: the maximum number of iteration for the task. If None, the progress bar is not printed. epoch: the epoch task_number: the task index loss: the current value of the loss function """ max_width = shutil.get_terminal_size((80, 20)).columns if not self.verbose: - if i == 0: + if current_epoch_iter == 0: if self.joint: padded_print('[ {} ] Joint | epoch {}\n'.format( datetime.now().strftime("%m-%d | %H:%M"), @@ -84,19 +83,19 @@ def prog(self, i: int, max_iter: int, epoch: Union[int, str], return timediff = time() - self.old_time - self.running_sum = self.running_sum + timediff + 1e-8 + self.running_sum += timediff + 1e-8 # Print the progress bar every update_every iterations - if (i and i % self.update_every == 0) or (max_iter is not None and i == max_iter - 1): - progress = min(float((i + 1) / max_iter), 1) if max_iter else 0 - progress_bar = ('█' * int(50 * progress)) + ('┈' * (50 - int(50 * progress))) if max_iter else '~N/A~' + if (current_epoch_iter and current_epoch_iter % self.update_every == 0) or (max_epoch_iter is not None and current_epoch_iter == max_epoch_iter - 1): + progress = min(float((current_epoch_iter + 1) / max_epoch_iter), 1) if max_epoch_iter else 0 + progress_bar = ('█' * int(50 * progress)) + ('┈' * (50 - int(50 * progress))) if max_epoch_iter else '~N/A~' if self.joint: padded_print('\r[ {} ] Joint | epoch {} | iter {}: |{}| {} ep/h | loss: {} | Time: {} ms/it'.format( datetime.now().strftime("%m-%d | %H:%M"), epoch, self.current_task_iter + 1, progress_bar, - round(3600 / (self.running_sum / i * max_iter), 2) if max_iter else 'N/A', + round(3600 / (max_epoch_iter * timediff), 2) if max_epoch_iter else 'N/A', round(loss, 8), round(1000 * timediff / self.update_every, 2) ), max_width=max_width, file=sys.stderr, end='', flush=True) @@ -107,18 +106,13 @@ def prog(self, i: int, max_iter: int, epoch: Union[int, str], epoch, self.current_task_iter + 1, progress_bar, - round(3600 / (self.running_sum / i * max_iter), 2) if max_iter else 'N/A', + round(3600 / (max_epoch_iter * timediff), 2) if max_epoch_iter else 'N/A', round(loss, 8), round(1000 * timediff / self.update_every, 2) ), max_width=max_width, file=sys.stderr, end='', flush=True) self.current_task_iter += 1 - - # def __del__(self): - # max_width = shutil.get_terminal_size((80, 20)).columns - # # if self.verbose: - # # print('\n', file=sys.stderr, flush=True) - # padded_print('\tLast task took: {} s'.format(round(self.running_sum, 2)), max_width=max_width, file=sys.stderr, flush=True) + self.old_time = time() def progress_bar(i: int, max_iter: int, epoch: Union[int, str], diff --git a/utils/training.py b/utils/training.py index 46815172..640efd6c 100644 --- a/utils/training.py +++ b/utils/training.py @@ -257,7 +257,7 @@ def train(model: ContinualModel, dataset: ContinualDataset, if dataset.SETTING == 'class-il': results_mask_classes[t - 1] = results_mask_classes[t - 1] + accs[1] - scheduler = dataset.get_scheduler(model, args) if not hasattr(model, 'scheduler') else model.scheduler + scheduler = dataset.get_scheduler(model, args, reload_optim=True) if not hasattr(model, 'scheduler') else model.scheduler epoch = 0 best_ea_metric = None