Skip to content

Commit

Permalink
improved config setup & readability of train.py
Browse files Browse the repository at this point in the history
  • Loading branch information
clemsgrs committed Aug 20, 2024
1 parent a118e31 commit 082523b
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 54 deletions.
4 changes: 2 additions & 2 deletions dinov2/configs/train/vit_tiny_14.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ dino:
ibot:
separate_head: true
train:
batch_size_per_gpu: 128
batch_size_per_gpu: 32
dataset_path: PathologyFoundation:root=/root/data
centering: sinkhorn_knopp
num_workers: 8
Expand Down Expand Up @@ -34,7 +34,7 @@ optim:
crops:
local_crops_size: 98
wandb:
enable: true
enable: false
project: 'dinov2'
username: 'vlfm'
exp_name: 'profiling'
Expand Down
2 changes: 1 addition & 1 deletion dinov2/eval/knn.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ def eval_knn(
persistent_workers=persistent_workers,
verbose=verbose,
)
num_classes = query_labels.max() + 1
num_classes = len(torch.unique(query_labels))
metric_collection = build_metric(num_classes=num_classes, average_type=accuracy_averaging)

device = torch.cuda.current_device()
Expand Down
59 changes: 11 additions & 48 deletions dinov2/train/train.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,3 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import argparse
import logging
import math
Expand All @@ -11,7 +6,6 @@
import json
import wandb
import tqdm
import datetime
from functools import partial
from typing import Optional
from pathlib import Path
Expand All @@ -27,7 +21,7 @@
from dinov2.fsdp import FSDPCheckpointer
from dinov2.logging import MetricLogger, SmoothedValue
from dinov2.utils.config import setup, write_config
from dinov2.utils.utils import CosineScheduler, initialize_wandb, load_weights
from dinov2.utils.utils import CosineScheduler, load_weights
from dinov2.models import build_model_from_cfg
from dinov2.eval.knn import eval_knn_with_model
from dinov2.eval.setup import get_autocast_dtype
Expand Down Expand Up @@ -64,7 +58,7 @@ def get_args_parser(add_help: bool = True):
parser.add_argument(
"--output-dir",
"--output_dir",
default="",
default="output",
type=str,
help="Output directory to save logs and checkpoints",
)
Expand Down Expand Up @@ -161,7 +155,6 @@ def do_tune(
query_dataset,
test_dataset,
output_dir,
gpu_id,
verbose: bool = True,
):
# in DINOv2, they have on SSLMetaArch class
Expand Down Expand Up @@ -189,8 +182,8 @@ def do_tune(

student = student.to(torch.device("cuda"))
teacher = teacher.to(torch.device("cuda"))
# student = student.to(torch.device(f"cuda:{gpu_id}"))
# teacher = teacher.to(torch.device(f"cuda:{gpu_id}"))
# student = student.to(torch.device(f"cuda:{distributed.get_global_rank()}"))
# teacher = teacher.to(torch.device(f"cuda:{distributed.get_global_rank()}"))
if verbose:
tqdm.tqdm.write(f"Loading epoch {epoch} weights...")
student_weights = model.student.state_dict()
Expand All @@ -217,7 +210,7 @@ def do_tune(
temperature=cfg.tune.knn.temperature,
autocast_dtype=autocast_dtype,
accuracy_averaging=AccuracyAveraging.MEAN_ACCURACY,
gpu_id=gpu_id,
gpu_id=distributed.get_global_rank(),
gather_on_cpu=cfg.tune.knn.gather_on_cpu,
batch_size=cfg.tune.knn.batch_size,
num_workers=0,
Expand All @@ -237,7 +230,7 @@ def do_tune(
temperature=cfg.tune.knn.temperature,
autocast_dtype=autocast_dtype,
accuracy_averaging=AccuracyAveraging.MEAN_ACCURACY,
gpu_id=gpu_id,
gpu_id=distributed.get_global_rank(),
gather_on_cpu=cfg.tune.knn.gather_on_cpu,
batch_size=cfg.tune.knn.batch_size,
num_workers=0,
Expand All @@ -263,7 +256,7 @@ def do_tune(
return results


def do_train(cfg, model, gpu_id, run_distributed, resume=False):
def do_train(cfg, model, resume=False):
model.train()
inputs_dtype = torch.half
fp16_scaler = model.fp16_scaler # for mixed precision training
Expand Down Expand Up @@ -396,7 +389,7 @@ def do_train(cfg, model, gpu_id, run_distributed, resume=False):

for data in metric_logger.log_every(
data_loader,
gpu_id,
distributed.get_global_rank(),
log_freq,
header,
max_iter,
Expand Down Expand Up @@ -493,7 +486,6 @@ def do_train(cfg, model, gpu_id, run_distributed, resume=False):
query_dataset,
test_dataset,
results_save_dir,
gpu_id,
verbose=False,
)

Expand All @@ -502,7 +494,7 @@ def do_train(cfg, model, gpu_id, run_distributed, resume=False):
for name, value in metrics_dict.items():
update_log_dict(log_dict, f"tune/{model_name}.{name}", value, step="epoch")

early_stopper(epoch, tune_results, periodic_checkpointer, run_distributed, iteration)
early_stopper(epoch, tune_results, periodic_checkpointer, distributed.is_enabled(), iteration)
if early_stopper.early_stop and cfg.tune.early_stopping.enable:
stop = True

Expand All @@ -523,7 +515,7 @@ def do_train(cfg, model, gpu_id, run_distributed, resume=False):
do_test(cfg, model, f"training_{iteration}")
torch.cuda.synchronize()

periodic_checkpointer.step(iteration, run_distributed=run_distributed)
periodic_checkpointer.step(iteration, run_distributed=distributed.is_enabled())

iteration = iteration + 1

Expand All @@ -535,35 +527,6 @@ def do_train(cfg, model, gpu_id, run_distributed, resume=False):

def main(args):
cfg = setup(args)

run_distributed = torch.cuda.device_count() > 1
if run_distributed:
gpu_id = int(os.environ["LOCAL_RANK"])
else:
gpu_id = -1

if distributed.is_main_process():
print(f"torch.cuda.device_count(): {torch.cuda.device_count()}")
run_id = datetime.datetime.now().strftime("%Y-%m-%d_%H_%M")
# set up wandb
if cfg.wandb.enable:
key = os.environ.get("WANDB_API_KEY")
wandb_run = initialize_wandb(cfg, key=key)
wandb_run.define_metric("epoch", summary="max")
run_id = wandb_run.id
else:
run_id = ""

if run_distributed:
obj = [run_id]
torch.distributed.broadcast_object_list(obj, 0, device=torch.device(f"cuda:{gpu_id}"))
run_id = obj[0]

output_dir = Path(cfg.train.output_dir, run_id)
if distributed.is_main_process():
output_dir.mkdir(exist_ok=True, parents=True)
cfg.train.output_dir = str(output_dir)

if distributed.is_main_process():
write_config(cfg, cfg.train.output_dir)

Expand All @@ -580,7 +543,7 @@ def main(args):
)
return do_test(cfg, model, f"manual_{iteration}")

do_train(cfg, model, gpu_id, run_distributed, resume=not args.no_resume)
do_train(cfg, model, resume=not args.no_resume)


if __name__ == "__main__":
Expand Down
22 changes: 19 additions & 3 deletions dinov2/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
import math
import logging
import os
import datetime

from pathlib import Path
from omegaconf import OmegaConf

import dinov2.distributed as distributed
Expand Down Expand Up @@ -47,25 +49,39 @@ def get_cfg_from_args(args):
return cfg


def default_setup(args):
def default_setup(args, cfg):
run_id = datetime.datetime.now().strftime("%Y-%m-%d_%H_%M")
# set up wandb
if cfg.wandb.enable:
key = os.environ.get("WANDB_API_KEY")
wandb_run = utils.initialize_wandb(cfg, key=key)
wandb_run.define_metric("epoch", summary="max")
run_id = wandb_run.id

output_dir = Path(cfg.train.output_dir, run_id)
if distributed.is_main_process():
output_dir.mkdir(exist_ok=True, parents=True)
cfg.train.output_dir = str(output_dir)

distributed.enable(overwrite=True)
seed = getattr(args, "seed", 0)
rank = distributed.get_global_rank()

global logger
setup_logging(output=args.output_dir, level=logging.INFO)
setup_logging(output=cfg.train.output_dir, level=logging.INFO)
logger = logging.getLogger("dinov2")

utils.fix_random_seeds(seed + rank)
logger.info("git:\n {}\n".format(utils.get_sha()))
logger.info("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())))
return cfg


def setup(args):
"""
Create configs and perform basic setups.
"""
cfg = get_cfg_from_args(args)
default_setup(args)
cfg = default_setup(args, cfg)
apply_scaling_rules_to_cfg(cfg)
return cfg

0 comments on commit 082523b

Please sign in to comment.