diff --git a/dinov2/configs/ssl_default_config.yaml b/dinov2/configs/ssl_default_config.yaml index cd12854cb..6ce2f39cd 100644 --- a/dinov2/configs/ssl_default_config.yaml +++ b/dinov2/configs/ssl_default_config.yaml @@ -62,11 +62,11 @@ train: batch_size_per_gpu: 64 dataset_path: ImageNet:split=TRAIN output_dir: . - saveckp_freq: 20 seed: 0 num_workers: 8 cache_dataset: true centering: "centering" # or "sinkhorn_knopp" + save_frequency: 0.1 # save every 10% of an epoch save_every: 5 tune: tune_every: @@ -132,8 +132,6 @@ crops: - 0.32 global_crops_size: 224 local_crops_size: 96 -evaluation: - eval_period_iterations: 12500 wandb: enable: false project: '' diff --git a/dinov2/configs/train/vit_base_14.yaml b/dinov2/configs/train/vit_base_14.yaml index 518899aa5..ed978b242 100644 --- a/dinov2/configs/train/vit_base_14.yaml +++ b/dinov2/configs/train/vit_base_14.yaml @@ -4,7 +4,7 @@ ibot: separate_head: true train: batch_size_per_gpu: 128 - dataset_path: PathologyFoundation:root=/root/data + dataset_path: PathologyFoundation:root=/pathology-fm-data/tarballs:extra=/pathology-fm-data/entries centering: sinkhorn_knopp num_workers: 8 tune: @@ -27,6 +27,7 @@ teacher: momentum_teacher: 0.994 optim: epochs: 100 + max_iter: 50000 weight_decay_end: 0.2 base_lr: 2.0e-03 # learning rate for a batch size of 1024, will get scaled in apply_scaling_rules_to_cfg warmup_epochs: 16 @@ -34,11 +35,11 @@ optim: crops: local_crops_size: 98 wandb: - enable: true + enable: false project: 'dinov2' username: 'vlfm' - exp_name: 'profiling' - tags: ['${wandb.exp_name}', 'patch', '${student.arch}'] + exp_name: 'dinov2' + tags: ['${wandb.exp_name}', '${student.arch}'] dir: '/home/user' group: resume_id: diff --git a/dinov2/configs/train/vit_large_14.yaml b/dinov2/configs/train/vit_large_14.yaml index 204087a38..a78bb9875 100644 --- a/dinov2/configs/train/vit_large_14.yaml +++ b/dinov2/configs/train/vit_large_14.yaml @@ -6,7 +6,7 @@ ibot: head_n_prototypes: 131072 train: batch_size_per_gpu: 128 - dataset_path: PathologyFoundation:root=/root/data + dataset_path: PathologyFoundation:root=/pathology-fm-data/tarballs:extra=/pathology-fm-data/entries centering: sinkhorn_knopp num_workers: 8 tune: @@ -29,6 +29,7 @@ teacher: momentum_teacher: 0.994 optim: epochs: 100 + max_iter: 50000 weight_decay_end: 0.2 base_lr: 2.0e-03 # learning rate for a batch size of 1024, will get scaled in apply_scaling_rules_to_cfg warmup_epochs: 80 @@ -39,8 +40,8 @@ wandb: enable: true project: 'dinov2' username: 'vlfm' - exp_name: 'profiling' - tags: ['${wandb.exp_name}', 'patch', '${student.arch}'] + exp_name: 'dinov2' + tags: ['${wandb.exp_name}', '${student.arch}'] dir: '/home/user' group: resume_id: diff --git a/dinov2/configs/train/vit_small_14.yaml b/dinov2/configs/train/vit_small_14.yaml index 6c0e75989..5b0070346 100644 --- a/dinov2/configs/train/vit_small_14.yaml +++ b/dinov2/configs/train/vit_small_14.yaml @@ -4,7 +4,7 @@ ibot: separate_head: true train: batch_size_per_gpu: 128 - dataset_path: PathologyFoundation:root=/root/data + dataset_path: PathologyFoundation:root=/pathology-fm-data/tarballs:extra=/pathology-fm-data/entries centering: sinkhorn_knopp num_workers: 8 tune: @@ -27,6 +27,7 @@ teacher: momentum_teacher: 0.994 optim: epochs: 100 + max_iter: 50000 weight_decay_end: 0.2 base_lr: 2.0e-03 # learning rate for a batch size of 1024, will get scaled in apply_scaling_rules_to_cfg warmup_epochs: 16 @@ -34,11 +35,11 @@ optim: crops: local_crops_size: 98 wandb: - enable: true + enable: false project: 'dinov2' username: 'vlfm' - exp_name: 'profiling' - tags: ['${wandb.exp_name}', 'patch', '${student.arch}'] + exp_name: 'dinov2' + tags: ['${wandb.exp_name}', '${student.arch}'] dir: '/home/user' group: resume_id: diff --git a/dinov2/configs/train/vit_tiny_14.yaml b/dinov2/configs/train/vit_tiny_14.yaml index 23535120b..8b6325497 100644 --- a/dinov2/configs/train/vit_tiny_14.yaml +++ b/dinov2/configs/train/vit_tiny_14.yaml @@ -4,7 +4,7 @@ ibot: separate_head: true train: batch_size_per_gpu: 32 - dataset_path: PathologyFoundation:root=/root/data + dataset_path: PathologyFoundation:root=/pathology-fm-data/tarballs:extra=/pathology-fm-data/entries centering: sinkhorn_knopp num_workers: 8 tune: @@ -27,6 +27,7 @@ teacher: momentum_teacher: 0.994 optim: epochs: 100 + max_iter: 50000 weight_decay_end: 0.2 base_lr: 2.0e-03 # learning rate for a batch size of 1024, will get scaled in apply_scaling_rules_to_cfg warmup_epochs: 16 @@ -37,8 +38,8 @@ wandb: enable: false project: 'dinov2' username: 'vlfm' - exp_name: 'profiling' - tags: ['${wandb.exp_name}', 'patch', '${student.arch}'] + exp_name: 'dinov2' + tags: ['${wandb.exp_name}', '${student.arch}'] dir: '/home/user' group: resume_id: diff --git a/dinov2/train/train.py b/dinov2/train/train.py index 46ae7e000..9c00d96c1 100644 --- a/dinov2/train/train.py +++ b/dinov2/train/train.py @@ -136,15 +136,14 @@ def update_log_dict( log_dict.update({f"{name}": value}) -def do_test(cfg, model, iteration): +def save_checkpoint(cfg, model, iteration): new_state_dict = model.teacher.state_dict() if distributed.is_main_process(): - iterstring = str(iteration) - eval_dir = Path(cfg.train.output_dir, "eval", iterstring) - eval_dir.mkdir(exist_ok=True, parents=True) + checkpoint_dir = Path(cfg.train.output_dir, "checkpoints", "teacher") + checkpoint_dir.mkdir(exist_ok=True, parents=True) # save teacher checkpoint - teacher_ckp_path = Path(eval_dir, "teacher_checkpoint.pth") + teacher_ckp_path = Path(checkpoint_dir, f"teacher_{iteration}.pth") torch.save({"teacher": new_state_dict}, teacher_ckp_path) @@ -326,6 +325,7 @@ def do_train(cfg, model, resume=False): total_batch_size = cfg.train.batch_size_per_gpu * distributed.get_global_size() OFFICIAL_EPOCH_LENGTH = len(dataset) // total_batch_size + save_every = int(cfg.train.save_frequency * OFFICIAL_EPOCH_LENGTH) if cfg.optim.max_iter is not None: max_iter = cfg.optim.max_iter else: @@ -506,14 +506,15 @@ def do_train(cfg, model, resume=False): ) break - # save snapshot and log to wandb + # log to wandb + if distributed.is_main_process() and cfg.wandb.enable and iteration % OFFICIAL_EPOCH_LENGTH == 0: wandb.log(log_dict, step=epoch) # checkpointing and testing - if cfg.evaluation.eval_period_iterations > 0 and (iteration + 1) % cfg.evaluation.eval_period_iterations == 0: - do_test(cfg, model, f"training_{iteration}") + if cfg.train.save_frequency > 0 and (iteration + 1) % save_every == 0: + save_checkpoint(cfg, model, iteration + 1) torch.cuda.synchronize() periodic_checkpointer.step(iteration, run_distributed=run_distributed) @@ -521,6 +522,7 @@ def do_train(cfg, model, resume=False): iteration = iteration + 1 # gather stats from all processes + metric_logger.synchronize_between_processes() train_stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()} return train_stats @@ -542,7 +544,7 @@ def main(args): .get("iteration", -1) + 1 ) - return do_test(cfg, model, f"manual_{iteration}") + return save_checkpoint(cfg, model, iteration) do_train(cfg, model, resume=not args.no_resume) diff --git a/dinov2/utils/config.py b/dinov2/utils/config.py index 411ea1633..69cc124b9 100644 --- a/dinov2/utils/config.py +++ b/dinov2/utils/config.py @@ -1,12 +1,8 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - import math import logging import os import datetime +import torch from pathlib import Path from omegaconf import OmegaConf @@ -50,20 +46,28 @@ def get_cfg_from_args(args): def default_setup(args, cfg): - run_id = datetime.datetime.now().strftime("%Y-%m-%d_%H_%M") - # set up wandb - if cfg.wandb.enable: - key = os.environ.get("WANDB_API_KEY") - wandb_run = utils.initialize_wandb(cfg, key=key) - wandb_run.define_metric("epoch", summary="max") - run_id = wandb_run.id + distributed.enable(overwrite=True) - output_dir = Path(cfg.train.output_dir, run_id) if distributed.is_main_process(): - output_dir.mkdir(exist_ok=True, parents=True) + run_id = datetime.datetime.now().strftime("%Y-%m-%d_%H_%M") + # set up wandb + if cfg.wandb.enable: + key = os.environ.get("WANDB_API_KEY") + wandb_run = utils.initialize_wandb(cfg, key=key) + wandb_run.define_metric("epoch", summary="max") + run_id = wandb_run.id + else: + run_id = "" + + if distributed.is_enabled(): + obj = [run_id] + torch.distributed.broadcast_object_list(obj, 0, device=torch.device(f"cuda:{distributed.get_local_rank()}")) + run_id = obj[0] + + output_dir = Path(cfg.train.output_dir, run_id) + output_dir.mkdir(exist_ok=True, parents=True) cfg.train.output_dir = str(output_dir) - distributed.enable(overwrite=True) seed = getattr(args, "seed", 0) rank = distributed.get_global_rank()