
Commit

fixed multiple wandb logging when distributed + added teacher checkpoint saving every p% iteration
clemsgrs committed Sep 18, 2024
1 parent eff7616 commit c09982b
Showing 7 changed files with 49 additions and 41 deletions.
4 changes: 1 addition & 3 deletions dinov2/configs/ssl_default_config.yaml
@@ -62,11 +62,11 @@ train:
batch_size_per_gpu: 64
dataset_path: ImageNet:split=TRAIN
output_dir: .
saveckp_freq: 20
seed: 0
num_workers: 8
cache_dataset: true
centering: "centering" # or "sinkhorn_knopp"
save_frequency: 0.1 # save every 10% of an epoch
save_every: 5
tune:
tune_every:
@@ -132,8 +132,6 @@ crops:
- 0.32
global_crops_size: 224
local_crops_size: 96
evaluation:
eval_period_iterations: 12500
wandb:
enable: false
project: ''
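The fractional save_frequency above replaces the old iteration-based saveckp_freq and eval_period_iterations settings. As a rough sketch of how a fraction of an epoch becomes a concrete iteration interval (mirroring the save_every computation added to train.py further down; the dataset and batch sizes here are made-up example values, not values from this commit):

# Illustrative values only; the real numbers come from the dataset and cfg.
dataset_len = 1_000_000
batch_size_per_gpu = 64
num_gpus = 8
save_frequency = 0.1  # save every 10% of an epoch

total_batch_size = batch_size_per_gpu * num_gpus           # 512
official_epoch_length = dataset_len // total_batch_size    # 1953 iterations per epoch
save_every = int(save_frequency * official_epoch_length)   # checkpoint every 195 iterations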
9 changes: 5 additions & 4 deletions dinov2/configs/train/vit_base_14.yaml
@@ -4,7 +4,7 @@ ibot:
separate_head: true
train:
batch_size_per_gpu: 128
dataset_path: PathologyFoundation:root=/root/data
dataset_path: PathologyFoundation:root=/pathology-fm-data/tarballs:extra=/pathology-fm-data/entries
centering: sinkhorn_knopp
num_workers: 8
tune:
@@ -27,18 +27,19 @@ teacher:
momentum_teacher: 0.994
optim:
epochs: 100
max_iter: 50000
weight_decay_end: 0.2
base_lr: 2.0e-03 # learning rate for a batch size of 1024, will get scaled in apply_scaling_rules_to_cfg
warmup_epochs: 16
layerwise_decay: 1.0
crops:
local_crops_size: 98
wandb:
enable: true
enable: false
project: 'dinov2'
username: 'vlfm'
exp_name: 'profiling'
tags: ['${wandb.exp_name}', 'patch', '${student.arch}']
exp_name: 'dinov2'
tags: ['${wandb.exp_name}', '${student.arch}']
dir: '/home/user'
group:
resume_id:
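The tags entries rely on OmegaConf interpolation, so the new values resolve from wandb.exp_name and student.arch at load time. A minimal sketch of that resolution, assuming the config is read with OmegaConf as in dinov2/utils/config.py (the literal values below are just examples):

from omegaconf import OmegaConf

# Toy config reproducing the interpolated tags; values are illustrative.
cfg = OmegaConf.create(
    {
        "student": {"arch": "vit_base"},
        "wandb": {"exp_name": "dinov2", "tags": ["${wandb.exp_name}", "${student.arch}"]},
    }
)
print(OmegaConf.to_container(cfg, resolve=True)["wandb"]["tags"])
# -> ['dinov2', 'vit_base']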
7 changes: 4 additions & 3 deletions dinov2/configs/train/vit_large_14.yaml
@@ -6,7 +6,7 @@ ibot:
head_n_prototypes: 131072
train:
batch_size_per_gpu: 128
dataset_path: PathologyFoundation:root=/root/data
dataset_path: PathologyFoundation:root=/pathology-fm-data/tarballs:extra=/pathology-fm-data/entries
centering: sinkhorn_knopp
num_workers: 8
tune:
@@ -29,6 +29,7 @@ teacher:
momentum_teacher: 0.994
optim:
epochs: 100
max_iter: 50000
weight_decay_end: 0.2
base_lr: 2.0e-03 # learning rate for a batch size of 1024, will get scaled in apply_scaling_rules_to_cfg
warmup_epochs: 80
@@ -39,8 +40,8 @@ wandb:
enable: true
project: 'dinov2'
username: 'vlfm'
exp_name: 'profiling'
tags: ['${wandb.exp_name}', 'patch', '${student.arch}']
exp_name: 'dinov2'
tags: ['${wandb.exp_name}', '${student.arch}']
dir: '/home/user'
group:
resume_id:
9 changes: 5 additions & 4 deletions dinov2/configs/train/vit_small_14.yaml
@@ -4,7 +4,7 @@ ibot:
separate_head: true
train:
batch_size_per_gpu: 128
dataset_path: PathologyFoundation:root=/root/data
dataset_path: PathologyFoundation:root=/pathology-fm-data/tarballs:extra=/pathology-fm-data/entries
centering: sinkhorn_knopp
num_workers: 8
tune:
@@ -27,18 +27,19 @@ teacher:
momentum_teacher: 0.994
optim:
epochs: 100
max_iter: 50000
weight_decay_end: 0.2
base_lr: 2.0e-03 # learning rate for a batch size of 1024, will get scaled in apply_scaling_rules_to_cfg
warmup_epochs: 16
layerwise_decay: 1.0
crops:
local_crops_size: 98
wandb:
enable: true
enable: false
project: 'dinov2'
username: 'vlfm'
exp_name: 'profiling'
tags: ['${wandb.exp_name}', 'patch', '${student.arch}']
exp_name: 'dinov2'
tags: ['${wandb.exp_name}', '${student.arch}']
dir: '/home/user'
group:
resume_id:
7 changes: 4 additions & 3 deletions dinov2/configs/train/vit_tiny_14.yaml
@@ -4,7 +4,7 @@ ibot:
separate_head: true
train:
batch_size_per_gpu: 32
dataset_path: PathologyFoundation:root=/root/data
dataset_path: PathologyFoundation:root=/pathology-fm-data/tarballs:extra=/pathology-fm-data/entries
centering: sinkhorn_knopp
num_workers: 8
tune:
@@ -27,6 +27,7 @@ teacher:
momentum_teacher: 0.994
optim:
epochs: 100
max_iter: 50000
weight_decay_end: 0.2
base_lr: 2.0e-03 # learning rate for a batch size of 1024, will get scaled in apply_scaling_rules_to_cfg
warmup_epochs: 16
@@ -37,8 +38,8 @@ wandb:
enable: false
project: 'dinov2'
username: 'vlfm'
exp_name: 'profiling'
tags: ['${wandb.exp_name}', 'patch', '${student.arch}']
exp_name: 'dinov2'
tags: ['${wandb.exp_name}', '${student.arch}']
dir: '/home/user'
group:
resume_id:
20 changes: 11 additions & 9 deletions dinov2/train/train.py
@@ -136,15 +136,14 @@ def update_log_dict(
log_dict.update({f"{name}": value})


def do_test(cfg, model, iteration):
def save_checkpoint(cfg, model, iteration):
new_state_dict = model.teacher.state_dict()

if distributed.is_main_process():
iterstring = str(iteration)
eval_dir = Path(cfg.train.output_dir, "eval", iterstring)
eval_dir.mkdir(exist_ok=True, parents=True)
checkpoint_dir = Path(cfg.train.output_dir, "checkpoints", "teacher")
checkpoint_dir.mkdir(exist_ok=True, parents=True)
# save teacher checkpoint
teacher_ckp_path = Path(eval_dir, "teacher_checkpoint.pth")
teacher_ckp_path = Path(checkpoint_dir, f"teacher_{iteration}.pth")
torch.save({"teacher": new_state_dict}, teacher_ckp_path)


@@ -326,6 +325,7 @@ def do_train(cfg, model, resume=False):

total_batch_size = cfg.train.batch_size_per_gpu * distributed.get_global_size()
OFFICIAL_EPOCH_LENGTH = len(dataset) // total_batch_size
save_every = int(cfg.train.save_frequency * OFFICIAL_EPOCH_LENGTH)
if cfg.optim.max_iter is not None:
max_iter = cfg.optim.max_iter
else:
@@ -506,21 +506,23 @@
)
break

# save snapshot and log to wandb
# log to wandb

if distributed.is_main_process() and cfg.wandb.enable and iteration % OFFICIAL_EPOCH_LENGTH == 0:
wandb.log(log_dict, step=epoch)

# checkpointing and testing

if cfg.evaluation.eval_period_iterations > 0 and (iteration + 1) % cfg.evaluation.eval_period_iterations == 0:
do_test(cfg, model, f"training_{iteration}")
if cfg.train.save_frequency > 0 and (iteration + 1) % save_every == 0:
save_checkpoint(cfg, model, iteration + 1)
torch.cuda.synchronize()

periodic_checkpointer.step(iteration, run_distributed=run_distributed)

iteration = iteration + 1

# gather stats from all processes

metric_logger.synchronize_between_processes()
train_stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()}
return train_stats
@@ -542,7 +544,7 @@ def main(args):
.get("iteration", -1)
+ 1
)
return do_test(cfg, model, f"manual_{iteration}")
return save_checkpoint(cfg, model, iteration)

do_train(cfg, model, resume=not args.no_resume)

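save_checkpoint above writes the teacher weights on the main process as {"teacher": state_dict} into checkpoints/teacher/teacher_{iteration}.pth under cfg.train.output_dir. A minimal loading sketch under that assumption (the output directory and iteration number below are placeholders, not values from this commit):

from pathlib import Path

import torch

output_dir = Path("outputs/my_run")  # placeholder path
ckpt_path = output_dir / "checkpoints" / "teacher" / "teacher_12500.pth"  # placeholder iteration

# The file was written as torch.save({"teacher": new_state_dict}, ...),
# so the weights sit under the "teacher" key.
state = torch.load(ckpt_path, map_location="cpu")
teacher_state_dict = state["teacher"]
print(f"{len(teacher_state_dict)} tensors in the teacher checkpoint")

# To restore the weights, load them into a model with the same architecture:
# teacher.load_state_dict(teacher_state_dict)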
34 changes: 19 additions & 15 deletions dinov2/utils/config.py
@@ -1,12 +1,8 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import math
import logging
import os
import datetime
import torch

from pathlib import Path
from omegaconf import OmegaConf
@@ -50,20 +46,28 @@ def get_cfg_from_args(args):


def default_setup(args, cfg):
run_id = datetime.datetime.now().strftime("%Y-%m-%d_%H_%M")
# set up wandb
if cfg.wandb.enable:
key = os.environ.get("WANDB_API_KEY")
wandb_run = utils.initialize_wandb(cfg, key=key)
wandb_run.define_metric("epoch", summary="max")
run_id = wandb_run.id
distributed.enable(overwrite=True)

output_dir = Path(cfg.train.output_dir, run_id)
if distributed.is_main_process():
output_dir.mkdir(exist_ok=True, parents=True)
run_id = datetime.datetime.now().strftime("%Y-%m-%d_%H_%M")
# set up wandb
if cfg.wandb.enable:
key = os.environ.get("WANDB_API_KEY")
wandb_run = utils.initialize_wandb(cfg, key=key)
wandb_run.define_metric("epoch", summary="max")
run_id = wandb_run.id
else:
run_id = ""

if distributed.is_enabled():
obj = [run_id]
torch.distributed.broadcast_object_list(obj, 0, device=torch.device(f"cuda:{distributed.get_local_rank()}"))
run_id = obj[0]

output_dir = Path(cfg.train.output_dir, run_id)
output_dir.mkdir(exist_ok=True, parents=True)
cfg.train.output_dir = str(output_dir)

distributed.enable(overwrite=True)
seed = getattr(args, "seed", 0)
rank = distributed.get_global_rank()

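The wandb fix boils down to letting only rank 0 create the run (and hence the run id used to name the output directory) and broadcasting that string to the other ranks, so every process writes into the same directory instead of each starting its own wandb run. A standalone sketch of that broadcast pattern, assuming the process group has already been initialized (e.g. by torchrun) and one GPU per process:

import torch
import torch.distributed as dist

rank = dist.get_rank()
local_rank = rank % torch.cuda.device_count()

# Only rank 0 decides the run id (the wandb run id, or a timestamp fallback).
run_id = "example_run_id" if rank == 0 else ""

obj = [run_id]
dist.broadcast_object_list(obj, src=0, device=torch.device(f"cuda:{local_rank}"))
run_id = obj[0]  # identical on every rank after the broadcast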
