Profiling #1

Merged · 3 commits · Apr 25, 2024
Changes from all commits
1 change: 1 addition & 0 deletions dinov2/configs/ssl_default_config.yaml
@@ -108,6 +108,7 @@ teacher:
   warmup_teacher_temp_epochs: 30
 optim:
   epochs: 100
+  max_iter:
   weight_decay: 0.04
   weight_decay_end: 0.4
   base_lr: 0.004  # learning rate for a batch size of 1024
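A note on the empty `max_iter:` key: assuming the configs are loaded with OmegaConf, as in upstream dinov2, an empty YAML value parses as `None`, which is what lets the training code in this PR fall back to the epoch-based schedule. A minimal sketch of that behavior:

```python
# Minimal sketch: an empty YAML value becomes None under OmegaConf,
# so leaving `max_iter:` blank keeps the epoch-based schedule.
from omegaconf import OmegaConf

cfg = OmegaConf.create("optim:\n  epochs: 100\n  max_iter:\n")
assert cfg.optim.max_iter is None
# train.py then computes: max_iter = cfg.optim.epochs * OFFICIAL_EPOCH_LENGTH
```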
2 changes: 1 addition & 1 deletion dinov2/eval/utils.py
@@ -68,7 +68,7 @@ def evaluate(
     metric_logger = MetricLogger(delimiter=" ", verbose=verbose)
     header = "Test"

-    for samples, targets, *_ in metric_logger.log_every(data_loader, 10, device, header):
+    for samples, targets, *_ in metric_logger.log_every(data_loader, device, 10, header):
         # given model went through ModelWithNormalize, outputs are already normalized
         outputs = model(samples.to(device))
         targets = targets.to(device)
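The call-site swap above matters because `log_every` now takes the GPU id before the logging frequency, so an unchanged positional caller would pass `10` as `gpu_id`. A self-contained toy mirroring the new argument order (the body is a simplified stand-in, not the repo's implementation):

```python
# Toy stand-in for the new argument order: gpu_id second, log_freq third.
# With the old call order, 10 would land in gpu_id and the device in log_freq.
def log_every(iterable, gpu_id, log_freq=None, header=""):
    show_progress = gpu_id in [-1, 0]  # matches the tqdm disable logic in helpers.py
    for i, obj in enumerate(iterable):
        if log_freq is not None and i % log_freq == 0:
            print(f"{header} iter {i} (progress bar: {show_progress})")
        yield obj

for batch in log_every(range(25), 0, 10, "Test"):
    pass
```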
17 changes: 12 additions & 5 deletions dinov2/logging/helpers.py
@@ -53,29 +53,32 @@ def synchronize_between_processes(self):
     def add_meter(self, name, meter):
         self.meters[name] = meter

-    def dump_in_output_file(self, iteration, iter_time, data_time):
+    def dump_in_output_file(self, iteration, iter_time, data_time, cpu_time):
         if self.output_file is None or not distributed.is_main_process():
             return
         dict_to_dump = dict(
             iteration=iteration,
             iter_time=iter_time,
             data_time=data_time,
+            cpu_time=cpu_time,
         )
         dict_to_dump.update({k: v.median for k, v in self.meters.items()})
         with open(self.output_file, "a") as f:
             f.write(json.dumps(dict_to_dump) + "\n")
         pass

     def log_every(
-        self, iterable, print_freq, gpu_id, header=None, n_iterations=None, start_iteration=0, print_log: bool = True
+        self, iterable, gpu_id, log_freq=None, header=None, n_iterations=None, start_iteration=0, print_log: bool = True
     ):
         i = start_iteration
         if not header:
             header = ""
         start_time = time.time()
         end = time.time()
+        cpu_end = time.process_time()
         iter_time = SmoothedValue(fmt="{avg:.6f}")
         data_time = SmoothedValue(fmt="{avg:.6f}")
+        cpu_time = SmoothedValue(fmt="{avg:.6f}")

         if n_iterations is None:
             n_iterations = len(iterable)
@@ -87,7 +90,7 @@ def log_every(
             ncols=80,
             unit_scale=1,
             initial=start_iteration,
-            total=n_iterations,
+            total=n_iterations - 1,
             leave=self.verbose,
             file=sys.stdout,
             disable=not (gpu_id in [-1, 0]),
@@ -96,11 +99,15 @@ def log_every(
         for obj in tqdm_iterable:
             data_time.update(time.time() - end)
             yield obj
+            cpu_time.update(time.process_time() - cpu_end)
             iter_time.update(time.time() - end)
-            if (i % print_freq == 0 or i == n_iterations - 1) and print_log:
-                self.dump_in_output_file(iteration=i, iter_time=iter_time.avg, data_time=data_time.avg)
+            if ((log_freq is not None and i % log_freq == 0) or i == n_iterations - 1) and print_log:
+                self.dump_in_output_file(
+                    iteration=i, iter_time=iter_time.avg, data_time=data_time.avg, cpu_time=cpu_time.avg
+                )
             i += 1
             end = time.time()
+            cpu_end = time.process_time()
             if i >= n_iterations:
                 break
         total_time = time.time() - start_time
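The new `cpu_time` meter is measured with `time.process_time()`, which counts CPU time of the current process only, while `iter_time` uses wall-clock `time.time()`. A gap between the two typically means the process is waiting on the GPU, the data loader, or I/O rather than computing. A runnable illustration:

```python
# Wall time vs. process CPU time: sleeping advances the first but not the second.
import time

wall_start, cpu_start = time.time(), time.process_time()
time.sleep(0.2)                       # waiting: wall clock advances, CPU time barely moves
sum(i * i for i in range(10**6))      # computing: both advance
wall = time.time() - wall_start
cpu = time.process_time() - cpu_start
print(f"wall={wall:.3f}s cpu={cpu:.3f}s")  # expect cpu << wall because of the sleep
```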
30 changes: 25 additions & 5 deletions dinov2/train/train.py
@@ -7,6 +7,8 @@
 import logging
 import math
 import os
+import time
+import json
 import wandb
 import tqdm
 import datetime
@@ -23,7 +25,7 @@
 from dinov2.data.transforms import make_classification_eval_transform
 import dinov2.distributed as distributed
 from dinov2.fsdp import FSDPCheckpointer
-from dinov2.logging import MetricLogger
+from dinov2.logging import MetricLogger, SmoothedValue
 from dinov2.utils.config import setup, write_config
 from dinov2.utils.utils import CosineScheduler, initialize_wandb, load_weights
 from dinov2.models import build_model_from_cfg
@@ -32,7 +34,6 @@
 from dinov2.eval.metrics import AccuracyAveraging
 from dinov2.eval.utils import EarlyStoppingDINO

-
 from dinov2.train.ssl_meta_arch import SSLMetaArch


@@ -332,7 +333,10 @@ def do_train(cfg, model, gpu_id, run_distributed, resume=False):

     total_batch_size = cfg.train.batch_size_per_gpu * distributed.get_global_size()
     OFFICIAL_EPOCH_LENGTH = len(dataset) // total_batch_size
-    max_iter = cfg.optim.epochs * OFFICIAL_EPOCH_LENGTH
+    if cfg.optim.max_iter is not None:
+        max_iter = cfg.optim.max_iter
+    else:
+        max_iter = cfg.optim.epochs * OFFICIAL_EPOCH_LENGTH

     periodic_checkpointer = PeriodicCheckpointer(
         checkpointer,
@@ -383,14 +387,19 @@ def do_train(cfg, model, gpu_id, run_distributed, resume=False):
     iteration = start_iter

     logger.info("Starting training from iteration {}".format(start_iter))
-    metrics_file = os.path.join(cfg.train.output_dir, "training_metrics.json")
+    metrics_file = os.path.join(
+        cfg.train.output_dir, f"training_metrics_{cfg.student.arch}_{cfg.train.batch_size_per_gpu}.json"
+    )
     metric_logger = MetricLogger(delimiter=" ", output_file=metrics_file)
+    log_freq = 10  # log_freq has to be smaller than the window_size used when instantiating SmoothedValue (here and in MetricLogger)
     header = "Train"

+    forward_backward_time = SmoothedValue(fmt="{avg:.6f}")
+
     for data in metric_logger.log_every(
         data_loader,
-        10,
         gpu_id,
+        log_freq,
         header,
         max_iter,
         start_iter,
@@ -411,7 +420,18 @@ def do_train(cfg, model, gpu_id, run_distributed, resume=False):
         # compute losses

         optimizer.zero_grad(set_to_none=True)
+        forward_backward_start = time.time()
         loss_dict = model.forward_backward(data, teacher_temp=teacher_temp)
+        forward_backward_time.update(time.time() - forward_backward_start)
+
+        if metrics_file is not None and distributed.is_main_process():
+            if (log_freq is not None and iteration % log_freq == 0) or iteration == max_iter - 1:
+                dict_to_dump = dict(
+                    iteration=iteration,
+                    forward_backward_time=forward_backward_time.avg,
+                )
+                with open(metrics_file, "a") as f:
+                    f.write(json.dumps(dict_to_dump) + "\n")

         # clip gradients

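Since both the logger and the training loop append one JSON object per line, the metrics file is JSON Lines and can be read back for analysis. A sketch, with a hypothetical filename following the `training_metrics_{arch}_{batch_size}.json` pattern above:

```python
# Read the appended JSON Lines metrics back for analysis.
import json

records = []
with open("training_metrics_vit_large_32.json") as f:  # hypothetical arch/batch size
    for line in f:
        records.append(json.loads(line))

fb = [r["forward_backward_time"] for r in records if "forward_backward_time" in r]
if fb:
    print(f"{len(fb)} samples, mean forward+backward time: {sum(fb) / len(fb):.4f}s")
```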