Add Weights and Biases Integration #383

Open · wants to merge 16 commits into main
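This PR adds an optional Weights & Biases logger that runs alongside the existing TensorBoard-based NanoDetLightningLogger. A minimal invocation sketch, assuming the --use-wandb and --eval-samples flags added to tools/train.py below (CONFIG_PATH stands in for the repo's existing config argument): python tools/train.py CONFIG_PATH --use-wandb --eval-samples 16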
2 changes: 1 addition & 1 deletion .github/workflows/workflow.yml
@@ -18,7 +18,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install flake8==3.9.2 flake8-bugbear flake8-comprehensions isort==5.8.0
python -m pip install flake8==3.9.2 flake8-bugbear flake8-comprehensions isort==5.8.0 wandb
python -m pip install black==21.6b0
flake8 --version
- name: Lint
65 changes: 31 additions & 34 deletions nanodet/trainer/task.py
@@ -77,6 +77,7 @@ def training_step(self, batch, batch_idx):
preds, loss, loss_states = self.model.forward_train(batch)

# log train losses
log_dict = {}
if self.global_step % self.cfg.log.interval == 0:
lr = self.optimizers().param_groups[0]["lr"]
log_msg = "Train|Epoch{}/{}|Iter{}({})| lr:{:.2e}| ".format(
@@ -86,19 +87,16 @@
batch_idx,
lr,
)
self.scalar_summary("Train_loss/lr", "Train", lr, self.global_step)
for loss_name in loss_states:
log_msg += "{}:{:.4f}| ".format(
loss_name, loss_states[loss_name].mean().item()
)
self.scalar_summary(
"Train_loss/" + loss_name,
"Train",
loss_states[loss_name].mean().item(),
self.global_step,
)
self.logger.info(log_msg)

log_dict[loss_name] = loss_states[loss_name].mean().item()
log_dict["lr"] = lr
for logger in self.logger:
logger.info(log_msg)
logger.log_metrics(log_dict, self.global_step, prefix="Train/")

return loss

def training_epoch_end(self, outputs: List[Any]) -> None:
@@ -125,11 +123,12 @@ def validation_step(self, batch, batch_idx):
log_msg += "{}:{:.4f}| ".format(
loss_name, loss_states[loss_name].mean().item()
)
self.logger.info(log_msg)

for logger in self.logger:
logger.info(log_msg)
dets = self.model.head.post_process(preds, batch)
return dets


def validation_epoch_end(self, validation_step_outputs):
"""
Called at the end of the validation epoch with the
@@ -173,10 +172,14 @@ def validation_epoch_end(self, validation_step_outputs):
warnings.warn(
"Warning! Save_key is not in eval results! Only save model last!"
)
self.logger.log_metrics(eval_results, self.current_epoch + 1)

for logger in self.logger:
logger.log_metrics(eval_results, self.current_epoch + 1, "Val_metrics/")
logger.log_val_results(all_results, self.cfg)
else:
self.logger.info("Skip val on rank {}".format(self.local_rank))

for logger in self.logger:
logger.info("Skip val on rank {}".format(self.local_rank))

def test_step(self, batch, batch_idx):
dets = self.predict(batch, batch_idx)
return dets
@@ -204,7 +207,8 @@ def test_epoch_end(self, test_step_outputs):
for k, v in eval_results.items():
f.write("{}: {}\n".format(k, v))
else:
self.logger.info("Skip test on rank {}".format(self.local_rank))
for logger in self.logger:
logger.info("Skip test on rank {}".format(self.local_rank))

def configure_optimizers(self):
"""
@@ -285,25 +289,14 @@ def get_progress_bar_dict(self):
items.pop("loss", None)
return items

def scalar_summary(self, tag, phase, value, step):
"""
Write Tensorboard scalar summary log.
Args:
tag: Name for the tag
phase: 'Train' or 'Val'
value: Value to record
step: Step value to record

"""
if self.local_rank < 1:
self.logger.experiment.add_scalars(tag, {phase: value}, step)

def info(self, string):
self.logger.info(string)
for logger in self.logger:
logger.info(string)

@rank_zero_only
def save_model_state(self, path):
self.logger.info("Saving model to {}".format(path))
for logger in self.logger:
logger.info("Saving model to {}".format(path))
state_dict = (
self.weight_averager.state_dict()
if self.weight_averager
@@ -318,7 +311,8 @@ def on_train_start(self) -> None:

def on_pretrain_routine_end(self) -> None:
if "weight_averager" in self.cfg.model:
self.logger.info("Weight Averaging is enabled")
for logger in self.logger:
logger.info("Weight Averaging is enabled")
if self.weight_averager and self.weight_averager.has_inited():
self.weight_averager.to(self.weight_averager.device)
return
@@ -346,14 +340,17 @@ def on_test_epoch_start(self) -> None:
def on_load_checkpoint(self, checkpointed_state: Dict[str, Any]) -> None:
if self.weight_averager:
avg_params = convert_avg_params(checkpointed_state)
info = ''
if len(avg_params) != len(self.model.state_dict()):
self.logger.info(
info = (
"Weight averaging is enabled but average state does not"
"match the model"
)
)
else:
self.weight_averager = build_weight_averager(
self.cfg.model.weight_averager, device=self.device
)
self.weight_averager.load_state_dict(avg_params)
self.logger.info("Loaded average state from checkpoint.")
info = "Loaded average state from checkpoint."
for logger in self.logger:
logger.info(info)
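
The recurring change in task.py is that self.logger is now a list, so every logging call fans out to each attached logger. A sketch of the informal, duck-typed interface each logger in that list is assumed to satisfy (method names are taken from the call sites above; the class and its bodies are illustrative):

class LoggerInterface:
    """Hypothetical sketch of the interface task.py relies on."""

    def info(self, string):
        """Plain-text log line (a no-op in NanoDetWandbLogger)."""

    def log_metrics(self, metrics, step, prefix=""):
        """Scalar metrics keyed by name, logged at a global step."""

    def log_val_results(self, results, cfg):
        """Raw per-image validation results, for richer visualizations."""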
2 changes: 1 addition & 1 deletion nanodet/util/__init__.py
@@ -7,7 +7,7 @@
)
from .config import cfg, load_config
from .flops_counter import get_model_complexity_info
from .logger import AverageMeter, Logger, MovingAverage, NanoDetLightningLogger
from .logger import AverageMeter, Logger, MovingAverage, NanoDetLightningLogger, NanoDetWandbLogger
from .misc import images_to_levels, multi_apply, unmap
from .path import collect_files, mkdir
from .rank_filter import rank_filter
109 changes: 105 additions & 4 deletions nanodet/util/logger.py
@@ -22,6 +22,9 @@
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.utilities.cloud_io import get_filesystem
from termcolor import colored
from pytorch_lightning.loggers import WandbLogger
from nanodet.data.dataset import build_dataset
import wandb

from .path import mkdir

@@ -153,7 +156,6 @@ def experiment(self):
"the dependencies to use torch.utils.tensorboard "
"(applicable to PyTorch 1.1 or higher)"
) from None

self._experiment = SummaryWriter(log_dir=self.log_dir, **self._kwargs)
return self._experiment

@@ -209,10 +211,10 @@ def log_hyperparams(self, params):
self.logger.info(f"hyperparams: {params}")

@rank_zero_only
def log_metrics(self, metrics, step):
self.logger.info(f"Val_metrics: {metrics}")
def log_metrics(self, metrics, step, prefix="Val_metrics/"):
self.logger.info(f"{prefix}: {metrics}")
for k, v in metrics.items():
self.experiment.add_scalars("Val_metrics/" + k, {"Val": v}, step)
self.experiment.add_scalars(prefix + k, {"Val": v}, step)

@rank_zero_only
def save(self):
@@ -223,3 +225,102 @@ def finalize(self, status):
self.experiment.flush()
self.experiment.close()
self.save()

@rank_zero_only
def log_val_results(self, results, cfg):
'''
A method to process and log insights/metrics/visualizations from
validation results
'''
pass



class NanoDetWandbLogger(WandbLogger):
def __init__(self, save_dir="./", num_eval_samples=16, **kwargs):
super().__init__(save_dir=save_dir, **kwargs)
self._num_eval_samples = num_eval_samples
self._id_to_name = None

@rank_zero_only
def info(self, string):
pass

@rank_zero_only
def log_metrics(self, metrics, step, prefix=""):
self._prefix = prefix
super().log_metrics(metrics, step)

@rank_zero_only
def log_val_results(self, results, cfg):
'''
A method to process and log insights/metrics/visualizations from
validation results
'''
self._log_eval_table(results, cfg)

@rank_zero_only
def dump_cfg(self, cfg_node):
pass

def _log_eval_sample(self, sample_id, sample_path, preds, classes):
"""
Creates one validation image sample
Args:
sample_id (int):
sample_path (Path):
preds (Dict[int,List[float]]):
classes (List[str]):
"""

wandb_classes = wandb.Classes([{'id': id, 'name': name} for id, name in enumerate(classes)])
class_id_to_label = {k: v for k, v in enumerate(classes)}
box_data = []
avg_conf_per_class = [0] * len(classes)
pred_class_count = {}
for cls in preds:
for *xyxy, conf in preds[cls]:
if conf >= 0.25:
box_data.append(
{"position": {"minX": xyxy[0], "minY": xyxy[1], "maxX": xyxy[2], "maxY": xyxy[3]},
"class_id": cls,
"box_caption": f"{classes[cls]} {conf:.3f}",
"scores": {"class_score": conf},
"domain": "pixel"})
avg_conf_per_class[cls] += conf

if cls in pred_class_count:
pred_class_count[cls] += 1
else:
pred_class_count[cls] = 1

for pred_class in pred_class_count.keys():
avg_conf_per_class[pred_class] = avg_conf_per_class[pred_class] / pred_class_count[pred_class]

boxes = {"predictions": {"box_data": box_data, "class_labels": class_id_to_label}}

wandb_img = wandb.Image(sample_path, boxes=boxes, classes=wandb_classes)
return [sample_id, wandb_img, *avg_conf_per_class]



@rank_zero_only
def _log_eval_table(self, results, cfg):
max_len = min(len(results.keys()), self._num_eval_samples)
eval_table_rows = []
# create id to file name mappings
if max_len > 0 and self._id_to_name is None:
val_dataset = build_dataset(cfg.data.val, "test")
self._id_to_name = {data['img_info']['id']: data['img_info']['file_name'] for data in val_dataset}

for res_id in list(results.keys())[:max_len]:
res = results[res_id]
file_path = os.path.join(cfg.data.val.img_path, self._id_to_name[res_id])
eval_table_rows.append(self._log_eval_sample(
res_id, file_path, res, cfg.class_names
))
if eval_table_rows:
self.log_table(key="eval_samples",
columns=["id", "prediction", *cfg.class_names],
data=eval_table_rows
)
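
For reference, a standalone sketch of the wandb bounding-box payload that _log_eval_sample builds; the image path, coordinates, and class names here are invented for illustration:

import wandb

# Illustrative values only; mirrors the payload shape used in _log_eval_sample.
classes = ["person", "car"]
wandb_classes = wandb.Classes([{"id": i, "name": n} for i, n in enumerate(classes)])
box = {
    "position": {"minX": 10.0, "minY": 20.0, "maxX": 110.0, "maxY": 220.0},
    "class_id": 0,
    "box_caption": "person 0.910",
    "scores": {"class_score": 0.91},
    "domain": "pixel",  # absolute pixel coordinates rather than 0-1 fractions
}
img = wandb.Image(
    "sample.jpg",
    boxes={"predictions": {"box_data": [box], "class_labels": dict(enumerate(classes))}},
    classes=wandb_classes,
)
# A run would then log it with, e.g., wandb.log({"eval_sample": img}).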
1 change: 1 addition & 0 deletions requirements.txt
@@ -13,3 +13,4 @@ torch>=1.7
torchmetrics
torchvision
tqdm
wandb
2 changes: 1 addition & 1 deletion tests/test_trainer/test_lightning_task.py
@@ -13,7 +13,7 @@ class DummyTrainer(nn.Module):
global_step = 0
local_rank = 0
use_ddp = False
logger = NanoDetLightningLogger(tempfile.TemporaryDirectory().name)
logger = [NanoDetLightningLogger(tempfile.TemporaryDirectory().name)]

def save_checkpoint(self, *args, **kwargs):
pass
13 changes: 8 additions & 5 deletions tests/test_utils/test_logger.py
@@ -2,23 +2,26 @@

from torch.utils.tensorboard import SummaryWriter

from nanodet.util import NanoDetLightningLogger, cfg, load_config
from nanodet.util import NanoDetLightningLogger, NanoDetWandbLogger, cfg, load_config


def test_logger():
tmp_dir = tempfile.TemporaryDirectory()
logger = NanoDetLightningLogger(tmp_dir.name)

wandb_logger = NanoDetWandbLogger(save_dir=tmp_dir.name, project="ci_test", anonymous="must")
writer = logger.experiment
assert isinstance(writer, SummaryWriter)

logger.info("test")

logger.log_hyperparams({"lr": 1})

wandb_logger.log_hyperparams({"lr": 1})

logger.log_metrics({"mAP": 30.1}, 1)

wandb_logger.log_hyperparams({"mAP": 30.1})

load_config(cfg, "./config/legacy_v0.x_configs/nanodet-m.yml")
logger.dump_cfg(cfg)

logger.finalize(None)
wandb_logger.finalize(None)
10 changes: 8 additions & 2 deletions tools/train.py
@@ -26,6 +26,7 @@
from nanodet.trainer.task import TrainingTask
from nanodet.util import (
NanoDetLightningLogger,
NanoDetWandbLogger,
cfg,
convert_old_model,
load_config,
@@ -41,6 +42,9 @@ def parse_args():
"--local_rank", default=-1, type=int, help="node rank for distributed training"
)
parser.add_argument("--seed", type=int, default=None, help="random seed")
parser.add_argument('--use-wandb', action='store_true', help="use wandb logger")
parser.add_argument(
"--eval-samples", type=int, default=16, help="number of evaluation samples to log")
args = parser.parse_args()
return args

@@ -112,7 +116,9 @@ def main(args):
)

accelerator = None if len(cfg.device.gpu_ids) <= 1 else "ddp"

loggers = [logger]
if args.use_wandb:
loggers.append(NanoDetWandbLogger(save_dir="./", num_eval_samples=args.eval_samples, name="test", project="NnN"))
trainer = pl.Trainer(
default_root_dir=cfg.save_dir,
max_epochs=cfg.schedule.total_epochs,
@@ -123,7 +129,7 @@
num_sanity_val_steps=0,
resume_from_checkpoint=model_resume_path,
callbacks=[ProgressBar(refresh_rate=0)], # disable tqdm bar
logger=logger,
logger=loggers,
benchmark=True,
gradient_clip_val=cfg.get("grad_clip", 0.0),
)
Expand Down