
Commit

Merge pull request #42 from aimagelab/dev
Fix CIFAR-100 scheduler
loribonna committed Jul 1, 2024
2 parents 88c06ec + ef4b9e4 commit 26eb2eb
Showing 9 changed files with 154 additions and 32 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -13,9 +13,9 @@ With Mammoth, nothing is set in stone. You can easily add new models, datasets,

Join our Discord Server for all your Mammoth-related questions → ![Discord Shield](https://discordapp.com/api/guilds/1164956257392799860/widget.png?style=shield)

## **NEW**: WIKI
## Documentation

We have created a [WIKI](https://aimagelab.github.io/mammoth/)! Check it out for more information on how to use Mammoth.
### Check out the official [DOCUMENTATION](https://aimagelab.github.io/mammoth/) for more information on how to use Mammoth!

<p align="center">
<img width="112" height="112" src="seq_mnist.gif" alt="Sequential MNIST">
6 changes: 4 additions & 2 deletions datasets/seq_cifar100.py
@@ -151,8 +151,10 @@ def get_batch_size(self):
return 32

@staticmethod
def get_scheduler(model, args: Namespace) -> torch.optim.lr_scheduler:
scheduler = ContinualDataset.get_scheduler(model, args)
def get_scheduler(model, args: Namespace, reload_optim=True) -> torch.optim.lr_scheduler:
scheduler = ContinualDataset.get_scheduler(model, args, reload_optim)
if scheduler is None:
if reload_optim:
model.opt = model.get_optimizer()
scheduler = torch.optim.lr_scheduler.MultiStepLR(model.opt, [35, 45], gamma=0.1, verbose=False)
return scheduler
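For reference, a minimal standalone sketch of the decay this default schedule produces, assuming the dataset's 50 epochs per task and `--lr 0.03` (the dummy parameter stands in for the backbone; illustration only, not part of the diff):

```python
import torch

# Dummy parameter standing in for the backbone weights (illustration only).
param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.SGD([param], lr=0.03)
sched = torch.optim.lr_scheduler.MultiStepLR(opt, milestones=[35, 45], gamma=0.1)

for epoch in range(50):
    # ... one epoch of training on the current task would run here ...
    opt.step()
    sched.step()

# The lr has been cut twice (after epochs 35 and 45): 0.03 -> 0.003 -> 0.0003
print(opt.param_groups[0]['lr'])  # ~0.0003, matching initial_lr * 0.1 * 0.1
```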
11 changes: 8 additions & 3 deletions datasets/utils/continual_dataset.py
@@ -157,10 +157,15 @@ def get_denormalization_transform() -> nn.Module:
raise NotImplementedError

@staticmethod
def get_scheduler(model, args: Namespace) -> torch.optim.lr_scheduler._LRScheduler:
"""Returns the scheduler to be used for the current dataset."""
def get_scheduler(model, args: Namespace, reload_optim=True) -> torch.optim.lr_scheduler._LRScheduler:
"""
Returns the scheduler to be used for the current dataset.
If `reload_optim` is True, the optimizer is reloaded from the model. This should be done at least ONCE every task
to ensure that the learning rate is reset to the initial value.
"""
if args.lr_scheduler is not None:
model.opt = model.get_optimizer()
if reload_optim or not hasattr(model, 'opt'):
model.opt = model.get_optimizer()
# check if lr_scheduler is in torch.optim.lr_scheduler
supported_scheds = {sched_name.lower(): sched_name for sched_name in dir(scheds) if sched_name.lower() in ContinualDataset.AVAIL_SCHEDS}
sched = None
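A minimal sketch of the contract behind the new `reload_optim` argument, assuming a model exposing `get_optimizer()` as Mammoth models do (`ToyModel` and the task loop are simplified stand-ins, not the actual framework code):

```python
import torch
from torch import nn


class ToyModel(nn.Module):
    """Stand-in for a Mammoth ContinualModel (illustration only)."""

    def __init__(self, lr=0.1):
        super().__init__()
        self.net = nn.Linear(4, 2)
        self.lr = lr

    def get_optimizer(self):
        return torch.optim.SGD(self.net.parameters(), lr=self.lr)


def get_scheduler(model, reload_optim=True):
    # Rebuild the optimizer once per task so the lr restarts from its initial value.
    if reload_optim or not hasattr(model, 'opt'):
        model.opt = model.get_optimizer()
    return torch.optim.lr_scheduler.MultiStepLR(model.opt, [35, 45], gamma=0.1)


model = ToyModel()
for task in range(3):
    scheduler = get_scheduler(model, reload_optim=True)
    assert model.opt.param_groups[0]['lr'] == model.lr  # lr is reset at every task boundary
    for epoch in range(50):
        # ... train one epoch on the current task ...
        model.opt.step()
        scheduler.step()
```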
1 change: 1 addition & 0 deletions docs/datasets/index.rst
@@ -38,6 +38,7 @@ each dataset **must statically define** all the necessary information to run a c

- **get_denormalization_transform** static method (``callable``): returns the transform to apply on the tensors to revert the normalization. You can use the `DeNormalize` function defined in `datasets/transforms/denormalization.py`.

- **get_scheduler** static method (``callable``): returns the learning rate scheduler to use during training. *By default*, it also initializes the optimizer. This prevents errors due to the learning rate being continuously reduced task after task. This behavior can be changed by setting the argument ``reload_optim=False``.

See :ref:`continual_dataset` for more details or **SequentialCIFAR10** in :ref:`seq_cifar10` for an example.

3 changes: 0 additions & 3 deletions models/utils/lider_model.py
@@ -160,9 +160,6 @@ def init_net(self, dataset):

self.net.lip_coeffs = torch.autograd.Variable(torch.randn(len(teacher_feats), dtype=torch.float), requires_grad=True).to(self.device)
self.net.lip_coeffs.data = budget_lip
self.opt = self.get_optimizer()
if hasattr(self, 'scheduler'):
self.scheduler = self.dataset.get_scheduler()

self.net.train(was_training)

121 changes: 121 additions & 0 deletions tests/test_scheduler.py
@@ -0,0 +1,121 @@
import os
import sys

import torch
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils.main import main
import pytest


def test_der_cifar100_defaultscheduler():
N_TASKS = 10
sys.argv = ['mammoth',
'--model',
'der',
'--dataset',
'seq-cifar100',
'--buffer_size',
'500',
'--alpha',
'0.3',
'--lr',
'0.03',
'--n_epochs',
'50',
'--non_verbose',
'1',
'--num_workers',
'0',
'--debug_mode',
'1',
'--savecheck',
'1',
'--seed',
'0']

log_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs', f'test_der_cifar100_defaultscheduler.log')
# log all outputs to file
if not os.path.exists(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs')):
os.mkdir(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs'))
sys.stdout = open(log_path, 'w', encoding='utf-8')
sys.stderr = sys.stdout
main()

# read output file and search for the string 'Saving checkpoint into'
with open(log_path, 'r', encoding='utf-8') as f:
lines = f.readlines()
ckpt_name = [line for line in lines if 'Saving checkpoint into' in line]
assert any(ckpt_name), f'Checkpoint not saved in {log_path}'

ckpt_base_name = ckpt_name[0].split('Saving checkpoint into')[-1].strip()
ckpt_paths = [os.path.join('checkpoints', ckpt_base_name + f'_{i}.pt') for i in range(N_TASKS)]


for ckpt_path in ckpt_paths:
assert os.path.exists(ckpt_path), f'Checkpoint file {ckpt_path} not found'

ckpt = torch.load(ckpt_path)
opt, sched = ckpt['optimizer']['param_groups'][0], ckpt['scheduler']
assert opt['initial_lr'] == 0.03, f'Learning rate not updated correctly in {ckpt_path}'
assert opt['lr']==opt['initial_lr']*0.1*0.1, f'Learning rate not updated correctly in {ckpt_path}'
assert list(sched['milestones'].keys()) == [35, 45], f'Milestones not updated correctly in {ckpt_path}'
assert sched['base_lrs']==[0.03], f'Base learning rate not updated correctly in {ckpt_path}'


def test_der_cifar100_customscheduler():
N_TASKS = 10
sys.argv = ['mammoth',
'--model',
'der',
'--dataset',
'seq-cifar100',
'--buffer_size',
'500',
'--alpha',
'0.3',
'--lr',
'0.1',
'--n_epochs',
'10',
'--non_verbose',
'1',
'--num_workers',
'0',
'--debug_mode',
'1',
'--savecheck',
'1',
'--lr_scheduler',
'multisteplr',
'--lr_milestones',
'2','4','6','8',
'--seed',
'0']

log_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs', f'test_der_cifar100_customscheduler.der.cifar100.log')
# log all outputs to file
if not os.path.exists(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs')):
os.mkdir(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs'))
sys.stdout = open(log_path, 'w', encoding='utf-8')
sys.stderr = sys.stdout
main()

# read output file and search for the string 'Saving checkpoint into'
with open(log_path, 'r', encoding='utf-8') as f:
lines = f.readlines()
ckpt_name = [line for line in lines if 'Saving checkpoint into' in line]
assert any(ckpt_name), f'Checkpoint not saved in {log_path}'

ckpt_base_name = ckpt_name[0].split('Saving checkpoint into')[-1].strip()
ckpt_paths = [os.path.join('checkpoints', ckpt_base_name + f'_{i}.pt') for i in range(N_TASKS)]


for ckpt_path in ckpt_paths:
assert os.path.exists(ckpt_path), f'Checkpoint file {ckpt_path} not found'

ckpt = torch.load(ckpt_path)
opt, sched = ckpt['optimizer']['param_groups'][0], ckpt['scheduler']
assert opt['initial_lr'] == 0.1, f'Learning rate not updated correctly in {ckpt_path}'
assert opt['lr']==opt['initial_lr']*0.1*0.1*0.1*0.1, f'Learning rate not updated correctly in {ckpt_path}'
assert list(sched['milestones'].keys()) == [2,4,6,8], f'Milestones not updated correctly in {ckpt_path}'
assert sched['base_lrs']==[0.1], f'Base learning rate not updated correctly in {ckpt_path}'
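The expected values in the assertions above follow from MultiStepLR multiplying the lr by `gamma` at every milestone crossed within a task. A small sketch of that arithmetic, using the two configurations from the tests (`expected_final_lr` is a hypothetical helper for illustration):

```python
def expected_final_lr(initial_lr, milestones, n_epochs, gamma=0.1):
    """lr at the end of a task: one gamma factor per milestone crossed."""
    return initial_lr * gamma ** sum(1 for m in milestones if m <= n_epochs)


# Default CIFAR-100 schedule: lr 0.03, milestones [35, 45], 50 epochs -> 0.03 * 0.1**2
print(expected_final_lr(0.03, [35, 45], 50))      # 0.0003 (up to floating-point noise)
# Custom schedule from the second test: lr 0.1, milestones [2, 4, 6, 8], 10 epochs -> 0.1 * 0.1**4
print(expected_final_lr(0.1, [2, 4, 6, 8], 10))   # 1e-05 (up to floating-point noise)
```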
10 changes: 6 additions & 4 deletions utils/stats.py
@@ -131,11 +131,13 @@ def update_stats(self, cpu_res, gpu_res):
self._it += 1

alpha = 1 / self._it
self.avg_cpu_res = self.avg_cpu_res + alpha * (cpu_res - self.avg_cpu_res)
self.avg_gpu_res = {g: (g_res + alpha * (g_res - self.avg_gpu_res[g])) for g, g_res in enumerate(gpu_res)}
if self.initial_cpu_res is not None:
self.avg_cpu_res = self.avg_cpu_res + alpha * (cpu_res - self.avg_cpu_res)
self.max_cpu_res = max(self.max_cpu_res, cpu_res)

self.max_cpu_res = max(self.max_cpu_res, cpu_res)
self.max_gpu_res = {g: max(self.max_gpu_res[g], g_res) for g, g_res in enumerate(gpu_res)}
if self.initial_gpu_res is not None:
self.avg_gpu_res = {g: (g_res + alpha * (g_res - self.avg_gpu_res[g])) for g, g_res in enumerate(gpu_res)}
self.max_gpu_res = {g: max(self.max_gpu_res[g], g_res) for g, g_res in enumerate(gpu_res)}

if self.logger is not None:
self.logger.log_system_stats(cpu_res, gpu_res)
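The CPU branch above updates a running average with `alpha = 1 / self._it`; with that choice, `avg + alpha * (x - avg)` after the n-th sample is exactly the arithmetic mean of all samples seen so far. A minimal check of that identity (standalone, with made-up numbers):

```python
samples = [3.0, 5.0, 4.0, 8.0]

avg, n = 0.0, 0
for x in samples:
    n += 1
    avg = avg + (1 / n) * (x - avg)  # same form as the update in update_stats

print(avg)                          # 5.0
print(sum(samples) / len(samples))  # 5.0
```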
28 changes: 11 additions & 17 deletions utils/status.py
@@ -52,23 +52,22 @@ def reset(self) -> None:
self.old_time = time()
self.running_sum = 0
self.current_task_iter = 0
self.last_task = 0

def prog(self, i: int, max_iter: int, epoch: Union[int, str],
def prog(self, current_epoch_iter: int, max_epoch_iter: int, epoch: Union[int, str],
task_number: int, loss: float) -> None:
"""
Prints out the progress bar on the stderr file.
Args:
i: the current iteration
max_iter: the maximum number of iteration. If None, the progress bar is not printed.
current_epoch_iter: the current iteration of the epoch
max_epoch_iter: the maximum number of iteration for the task. If None, the progress bar is not printed.
epoch: the epoch
task_number: the task index
loss: the current value of the loss function
"""
max_width = shutil.get_terminal_size((80, 20)).columns
if not self.verbose:
if i == 0:
if current_epoch_iter == 0:
if self.joint:
padded_print('[ {} ] Joint | epoch {}\n'.format(
datetime.now().strftime("%m-%d | %H:%M"),
@@ -84,19 +83,19 @@ def prog(self, i: int, max_iter: int, epoch: Union[int, str],
return

timediff = time() - self.old_time
self.running_sum = self.running_sum + timediff + 1e-8
self.running_sum += timediff + 1e-8

# Print the progress bar every update_every iterations
if (i and i % self.update_every == 0) or (max_iter is not None and i == max_iter - 1):
progress = min(float((i + 1) / max_iter), 1) if max_iter else 0
progress_bar = ('█' * int(50 * progress)) + ('┈' * (50 - int(50 * progress))) if max_iter else '~N/A~'
if (current_epoch_iter and current_epoch_iter % self.update_every == 0) or (max_epoch_iter is not None and current_epoch_iter == max_epoch_iter - 1):
progress = min(float((current_epoch_iter + 1) / max_epoch_iter), 1) if max_epoch_iter else 0
progress_bar = ('█' * int(50 * progress)) + ('┈' * (50 - int(50 * progress))) if max_epoch_iter else '~N/A~'
if self.joint:
padded_print('\r[ {} ] Joint | epoch {} | iter {}: |{}| {} ep/h | loss: {} | Time: {} ms/it'.format(
datetime.now().strftime("%m-%d | %H:%M"),
epoch,
self.current_task_iter + 1,
progress_bar,
round(3600 / (self.running_sum / i * max_iter), 2) if max_iter else 'N/A',
round(3600 / (max_epoch_iter * timediff), 2) if max_epoch_iter else 'N/A',
round(loss, 8),
round(1000 * timediff / self.update_every, 2)
), max_width=max_width, file=sys.stderr, end='', flush=True)
Expand All @@ -107,18 +106,13 @@ def prog(self, i: int, max_iter: int, epoch: Union[int, str],
epoch,
self.current_task_iter + 1,
progress_bar,
round(3600 / (self.running_sum / i * max_iter), 2) if max_iter else 'N/A',
round(3600 / (max_epoch_iter * timediff), 2) if max_epoch_iter else 'N/A',
round(loss, 8),
round(1000 * timediff / self.update_every, 2)
), max_width=max_width, file=sys.stderr, end='', flush=True)

self.current_task_iter += 1

# def __del__(self):
# max_width = shutil.get_terminal_size((80, 20)).columns
# # if self.verbose:
# # print('\n', file=sys.stderr, flush=True)
# padded_print('\tLast task took: {} s'.format(round(self.running_sum, 2)), max_width=max_width, file=sys.stderr, flush=True)
self.old_time = time()


def progress_bar(i: int, max_iter: int, epoch: Union[int, str],
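The `ep/h` figure in the progress bar above comes from treating `timediff` as the time of one iteration: an epoch then takes roughly `max_epoch_iter * timediff` seconds, and `3600` divided by that is epochs per hour. A small sketch with hypothetical numbers (`epochs_per_hour` is an illustrative helper, not part of the codebase):

```python
def epochs_per_hour(seconds_per_iter, iters_per_epoch):
    # Same estimate as the progress bar: 3600 / (max_epoch_iter * timediff)
    return round(3600 / (iters_per_epoch * seconds_per_iter), 2)


print(epochs_per_hour(0.05, 1000))  # 72.0 -> at 50 ms/iter and 1000 iters/epoch, ~72 epochs per hour
```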
2 changes: 1 addition & 1 deletion utils/training.py
@@ -257,7 +257,7 @@ def train(model: ContinualModel, dataset: ContinualDataset,
if dataset.SETTING == 'class-il':
results_mask_classes[t - 1] = results_mask_classes[t - 1] + accs[1]

scheduler = dataset.get_scheduler(model, args) if not hasattr(model, 'scheduler') else model.scheduler
scheduler = dataset.get_scheduler(model, args, reload_optim=True) if not hasattr(model, 'scheduler') else model.scheduler

epoch = 0
best_ea_metric = None
