diff --git a/dinov2/configs/ssl_default_config.yaml b/dinov2/configs/ssl_default_config.yaml
index 484a5f874..cd12854cb 100644
--- a/dinov2/configs/ssl_default_config.yaml
+++ b/dinov2/configs/ssl_default_config.yaml
@@ -108,6 +108,7 @@ teacher:
   warmup_teacher_temp_epochs: 30
 optim:
   epochs: 100
+  max_iter:
   weight_decay: 0.04
   weight_decay_end: 0.4
   base_lr: 0.004  # learning rate for a batch size of 1024
diff --git a/dinov2/eval/utils.py b/dinov2/eval/utils.py
index 2e6138136..eb5f8ba56 100644
--- a/dinov2/eval/utils.py
+++ b/dinov2/eval/utils.py
@@ -68,7 +68,7 @@ def evaluate(
     metric_logger = MetricLogger(delimiter=" ", verbose=verbose)
     header = "Test"
 
-    for samples, targets, *_ in metric_logger.log_every(data_loader, 10, device, header):
+    for samples, targets, *_ in metric_logger.log_every(data_loader, device, 10, header):
         # given model went through ModelWithNormalize, outputs are already normalized
         outputs = model(samples.to(device))
         targets = targets.to(device)
diff --git a/dinov2/logging/helpers.py b/dinov2/logging/helpers.py
index b729684b7..727756108 100644
--- a/dinov2/logging/helpers.py
+++ b/dinov2/logging/helpers.py
@@ -53,13 +53,14 @@ def synchronize_between_processes(self):
     def add_meter(self, name, meter):
         self.meters[name] = meter
 
-    def dump_in_output_file(self, iteration, iter_time, data_time):
+    def dump_in_output_file(self, iteration, iter_time, data_time, cpu_time):
        if self.output_file is None or not distributed.is_main_process():
             return
         dict_to_dump = dict(
             iteration=iteration,
             iter_time=iter_time,
             data_time=data_time,
+            cpu_time=cpu_time,
         )
         dict_to_dump.update({k: v.median for k, v in self.meters.items()})
         with open(self.output_file, "a") as f:
@@ -67,15 +68,17 @@ def dump_in_output_file(self, iteration, iter_time, data_time):
         pass
 
     def log_every(
-        self, iterable, print_freq, gpu_id, header=None, n_iterations=None, start_iteration=0, print_log: bool = True
+        self, iterable, gpu_id, log_freq=None, header=None, n_iterations=None, start_iteration=0, print_log: bool = True
     ):
         i = start_iteration
         if not header:
             header = ""
         start_time = time.time()
         end = time.time()
+        cpu_end = time.process_time()
         iter_time = SmoothedValue(fmt="{avg:.6f}")
         data_time = SmoothedValue(fmt="{avg:.6f}")
+        cpu_time = SmoothedValue(fmt="{avg:.6f}")
 
         if n_iterations is None:
             n_iterations = len(iterable)
@@ -87,7 +90,7 @@
             ncols=80,
             unit_scale=1,
             initial=start_iteration,
-            total=n_iterations,
+            total=n_iterations - 1,
             leave=self.verbose,
             file=sys.stdout,
             disable=not (gpu_id in [-1, 0]),
@@ -96,11 +99,15 @@
         for obj in tqdm_iterable:
             data_time.update(time.time() - end)
             yield obj
+            cpu_time.update(time.process_time() - cpu_end)
             iter_time.update(time.time() - end)
-            if (i % print_freq == 0 or i == n_iterations - 1) and print_log:
-                self.dump_in_output_file(iteration=i, iter_time=iter_time.avg, data_time=data_time.avg)
+            if ((log_freq is not None and i % log_freq == 0) or i == n_iterations - 1) and print_log:
+                self.dump_in_output_file(
+                    iteration=i, iter_time=iter_time.avg, data_time=data_time.avg, cpu_time=cpu_time.avg
+                )
             i += 1
             end = time.time()
+            cpu_end = time.process_time()
             if i >= n_iterations:
                 break
         total_time = time.time() - start_time
diff --git a/dinov2/train/train.py b/dinov2/train/train.py
index 4e5a9ec1e..9a0805e12 100644
--- a/dinov2/train/train.py
+++ b/dinov2/train/train.py
@@ -7,6 +7,8 @@
 import logging
 import math
 import os
+import time
+import json
 import wandb
 import tqdm
 import datetime
@@ -23,7 +25,7 @@
 from dinov2.data.transforms import make_classification_eval_transform
 import dinov2.distributed as distributed
 from dinov2.fsdp import FSDPCheckpointer
-from dinov2.logging import MetricLogger
+from dinov2.logging import MetricLogger, SmoothedValue
 from dinov2.utils.config import setup, write_config
 from dinov2.utils.utils import CosineScheduler, initialize_wandb, load_weights
 from dinov2.models import build_model_from_cfg
@@ -32,7 +34,6 @@
 from dinov2.eval.metrics import AccuracyAveraging
 from dinov2.eval.utils import EarlyStoppingDINO
 
-
 from dinov2.train.ssl_meta_arch import SSLMetaArch
 
 
@@ -332,7 +333,10 @@ def do_train(cfg, model, gpu_id, run_distributed, resume=False):
 
     total_batch_size = cfg.train.batch_size_per_gpu * distributed.get_global_size()
     OFFICIAL_EPOCH_LENGTH = len(dataset) // total_batch_size
-    max_iter = cfg.optim.epochs * OFFICIAL_EPOCH_LENGTH
+    if cfg.optim.max_iter is not None:
+        max_iter = cfg.optim.max_iter
+    else:
+        max_iter = cfg.optim.epochs * OFFICIAL_EPOCH_LENGTH
 
     periodic_checkpointer = PeriodicCheckpointer(
         checkpointer,
@@ -383,14 +387,19 @@ def do_train(cfg, model, gpu_id, run_distributed, resume=False):
     iteration = start_iter
 
     logger.info("Starting training from iteration {}".format(start_iter))
-    metrics_file = os.path.join(cfg.train.output_dir, "training_metrics.json")
+    metrics_file = os.path.join(
+        cfg.train.output_dir, f"training_metrics_{cfg.student.arch}_{cfg.train.batch_size_per_gpu}.json"
+    )
     metric_logger = MetricLogger(delimiter=" ", output_file=metrics_file)
+    log_freq = 10  # log_freq has to be smaller than the window_size used when instantiating SmoothedValue (here and in MetricLogger)
     header = "Train"
 
+    forward_backward_time = SmoothedValue(fmt="{avg:.6f}")
+
     for data in metric_logger.log_every(
         data_loader,
-        10,
         gpu_id,
+        log_freq,
         header,
         max_iter,
         start_iter,
@@ -411,7 +420,18 @@ def do_train(cfg, model, gpu_id, run_distributed, resume=False):
         # compute losses
 
         optimizer.zero_grad(set_to_none=True)
+        forward_backward_start = time.time()
         loss_dict = model.forward_backward(data, teacher_temp=teacher_temp)
+        forward_backward_time.update(time.time() - forward_backward_start)
+
+        if metrics_file is not None and distributed.is_main_process():
+            if (log_freq is not None and iteration % log_freq == 0) or iteration == max_iter - 1:
+                dict_to_dump = dict(
+                    iteration=iteration,
+                    forward_backward_time=forward_backward_time.avg,
+                )
+                with open(metrics_file, "a") as f:
+                    f.write(json.dumps(dict_to_dump) + "\n")
 
         # clip gradients