From 854761a43c39a8ebcd3c55275bf616a4552f67e1 Mon Sep 17 00:00:00 2001
From: unknown
Date: Sun, 9 Jan 2022 13:15:47 +0100
Subject: [PATCH 1/2] losses, top1, top5 added to iteration hook

---
 robustness/train.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/robustness/train.py b/robustness/train.py
index f955289..364e708 100644
--- a/robustness/train.py
+++ b/robustness/train.py
@@ -501,7 +501,7 @@ def _model_loop(args, loop_type, loader, model, opt, epoch, adv, writer):
 
         # USER-DEFINED HOOK
         if has_attr(args, 'iteration_hook'):
-            args.iteration_hook(model, i, loop_type, inp, target)
+            args.iteration_hook(model, i, loop_type, inp, target, losses, top1, top5)
 
         iterator.set_description(desc)
         iterator.refresh()

From bf822ae5466365d209f275ba542ce7a3c89de61b Mon Sep 17 00:00:00 2001
From: unknown
Date: Sun, 9 Jan 2022 13:22:17 +0100
Subject: [PATCH 2/2] losses, top1, top5 and docstring

---
 robustness/train.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/robustness/train.py b/robustness/train.py
index 364e708..1f2f816 100644
--- a/robustness/train.py
+++ b/robustness/train.py
@@ -254,7 +254,7 @@ def train_model(args, model, loaders, *, checkpoint=None, dp_device_ids=None,
             If given, this function is called every training iteration by the
             training loop (useful for custom logging). The function is given
             arguments `model, iteration #, loop_type [train/eval],
-            current_batch_ims, current_batch_labels`.
+            current_batch_ims, current_batch_labels, losses (AverageMeter), top1 accuracy (AverageMeter), top5 accuracy (AverageMeter)`.
         epoch hook (function, optional)
             Similar to iteration_hook but called every epoch instead, and
             given arguments `model, log_info` where `log_info` is a
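
For reference, a minimal sketch of a user-defined `iteration_hook` using the extended signature introduced by this patch might look like the following. The hook name, logging interval, and print format are illustrative only; `losses`, `top1`, and `top5` are the running meter objects from the training loop, and the `.avg` attribute is assumed to hold the running average as in the library's `AverageMeter` helper.

```python
# Hypothetical usage sketch (not part of this patch): an iteration_hook
# that consumes the new losses / top1 / top5 meter arguments.
def my_iteration_hook(model, i, loop_type, inp, target, losses, top1, top5):
    # Assumption: each meter exposes a running average via `.avg`.
    if i % 100 == 0:
        print(f"[{loop_type}] iter {i}: "
              f"loss {losses.avg:.4f} | top1 {top1.avg:.2f} | top5 {top5.avg:.2f}")

# Attached to the args object passed to train_model, e.g.:
# train_args.iteration_hook = my_iteration_hook
```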