From 8eed347c8bf0bf5f23a7d3937e6aa2e2997159bc Mon Sep 17 00:00:00 2001 From: Michael Welter Date: Sat, 7 Dec 2024 10:45:02 +0100 Subject: [PATCH] Improve train code: * Move general stuff from pose estimator training to train.py * Autoformat * Add smoketest back in --- run.sh | 21 +-- scripts/train_poseestimator.py | 137 +-------------- test/test_train.py | 144 ++++++++++++++- trackertraincode/train.py | 313 ++++++++++++++++++++++++--------- 4 files changed, 382 insertions(+), 233 deletions(-) diff --git a/run.sh b/run.sh index 3e0017a..2e20a5f 100644 --- a/run.sh +++ b/run.sh @@ -1,20 +1,9 @@ #!/bin/bash -# python scripts/train_poseestimator.py --lr 1.e-3 --epochs 500 --ds "repro_300_wlp+lapa_megaface_lp+wflw_lp+synface" \ -# --save-plot train.pdf \ -# --with-swa \ -# --with-nll-loss \ -# --roi-override original \ -# --no-onnx \ -# --backbone mobilenetv1 \ -# --outdir model_files/ - - - -#--rampup_nll_losses \ - -python scripts/train_poseestimator_lightning.py --ds "repro_300_wlp+lapa_megaface_lp+wflw_lp+synface" \ - --epochs 10 \ +python scripts/train_poseestimator.py --lr 1.e-3 --epochs 1500 --ds "repro_300_wlp+lapa_megaface_lp+wflw_lp+synface" \ --with-swa \ --with-nll-loss \ - --rampup-nll-losses \ No newline at end of file + --backbone hybrid_vit \ + --rampup-nll-losses + +# --outdir model_files/current/run0/ \ No newline at end of file diff --git a/scripts/train_poseestimator.py b/scripts/train_poseestimator.py index e91dddb..24b53cc 100644 --- a/scripts/train_poseestimator.py +++ b/scripts/train_poseestimator.py @@ -18,10 +18,8 @@ import tqdm import pytorch_lightning as pl -from pytorch_lightning.callbacks import Callback, ModelCheckpoint +from pytorch_lightning.callbacks import ModelCheckpoint -# from pytorch_lightning.loggers import Logger, -from pytorch_lightning.utilities import rank_zero_only import torch.optim as optim import torch import torch.nn as nn @@ -32,8 +30,6 @@ import trackertraincode.train as train import trackertraincode.pipelines -from 
trackertraincode.neuralnets.io import complement_lightning_checkpoint -from scripts.export_model import convert_posemodel_onnx from trackertraincode.datasets.batch import Batch from trackertraincode.pipelines import Tag @@ -161,11 +157,6 @@ def create_optimizer(net, args: MyArgs): return optimizer, scheduler -class SaveBestSpec(NamedTuple): - weights: List[float] - names: List[str] - - def setup_losses(args: MyArgs, net): C = train.Criterion cregularize = [ @@ -259,9 +250,7 @@ def wrapped(step): ), } - savebest = SaveBestSpec([1.0, 1.0, 1.0], ["rot", "xy", "sz"]) - - return train_criterions, test_criterions, savebest + return train_criterions, test_criterions def create_net(args: MyArgs): @@ -281,7 +270,7 @@ def __init__(self, args: MyArgs): super().__init__() self._args = args self._model = create_net(args) - train_criterions, test_criterions, savebest = setup_losses(args, self._model) + train_criterions, test_criterions = setup_losses(args, self._model) self._train_criterions = train_criterions self._test_criterions = test_criterions @@ -315,120 +304,6 @@ def model(self): return self._model -class SwaCallback(Callback): - def __init__(self, start_epoch): - super().__init__() - self._swa_model: optim.swa_utils.AveragedModel | None = None - self._start_epoch = start_epoch - - @property - def swa_model(self): - return self._swa_model.module - - def on_train_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: - assert isinstance(pl_module, LitModel) - self._swa_model = optim.swa_utils.AveragedModel(pl_module.model, device="cpu", use_buffers=True) - - def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: - assert isinstance(pl_module, LitModel) - if trainer.current_epoch > self._start_epoch: - self._swa_model.update_parameters(pl_module.model) - - def on_train_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: - assert self._swa_model is not None - swa_filename = 
join(trainer.default_root_dir, f"swa.ckpt") - models.save_model(self._swa_model.module, swa_filename) - - -class MetricsGraphing(Callback): - def __init__(self): - super().__init__() - self._visu: train.TrainHistoryPlotter | None = None - self._metrics_accumulator = defaultdict(list) - - def on_train_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: - assert self._visu is None - self._visu = train.TrainHistoryPlotter(save_filename=join(trainer.default_root_dir, "train.pdf")) - - def on_train_batch_end( - self, trainer: pl.Trainer, pl_module: pl.LightningModule, outputs: Any, batch: Any, batch_idx: int - ): - mt_losses: dict[str, torch.Tensor] = outputs["mt_losses"] - for k, v in mt_losses.items(): - self._visu.add_train_point(trainer.current_epoch, batch_idx, k, v.numpy()) - self._visu.add_train_point(trainer.current_epoch, batch_idx, "loss", outputs["loss"].detach().cpu().numpy()) - - def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: - if trainer.lr_scheduler_configs: # scheduler is not None: - scheduler = next( - iter(trainer.lr_scheduler_configs) - ).scheduler # Pick the first scheduler (and there should only be one) - last_lr = next(iter(scheduler.get_last_lr())) # LR from the first parameter group - self._visu.add_test_point(trainer.current_epoch, "lr", last_lr) - - self._visu.summarize_train_values() - - def on_validation_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: - self._metrics_accumulator = defaultdict(list) - - def on_validation_batch_end( - self, - trainer: pl.Trainer, - pl_module: pl.LightningModule, - outputs: list[train.LossVal], - batch: Any, - batch_idx: int, - dataloader_idx: int = 0, - ) -> None: - for val in outputs: - self._metrics_accumulator[val.name].append(val.val) - - def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: - if self._visu is None: - return - for k, v in self._metrics_accumulator.items(): - 
self._visu.add_test_point(trainer.current_epoch - 1, k, torch.cat(v).mean().cpu().numpy()) - if trainer.current_epoch > 0: - self._visu.update_graph() - - -class SimpleProgressBar(Callback): - """Creates progress bars for total training time and progress of per epoch.""" - - def __init__(self, batchsize: int): - super().__init__() - self._bar: tqdm.tqdm | None = None - self._epoch_bar: tqdm.tqdm | None = None - self._batchsize = batchsize - - def on_train_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: - self._bar = tqdm.tqdm(total=trainer.max_epochs, desc='Training', position=0) - self._epoch_bar = tqdm.tqdm(total=trainer.num_training_batches * self._batchsize, desc="Epoch", position=1) - - def on_train_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: - self._bar.close() - self._epoch_bar.close() - self._bar = None - self._epoch_bar = None - - def on_train_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: - self._epoch_bar.reset(self._epoch_bar.total) - - def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: - self._bar.update(1) - - def on_train_batch_end( - self, - trainer: pl.Trainer, - pl_module: pl.LightningModule, - outputs: Mapping[str, Any], - batch: list[Batch] | Batch, - batch_idx: int, - ) -> None: - n = sum(b.meta.batchsize for b in batch) if isinstance(batch, list) else batch.meta.batchsize - self._epoch_bar.update(n) - - def main(): np.seterr(all="raise") cv2.setNumThreads(1) @@ -499,13 +374,13 @@ def main(): save_weights_only=False, ) - progress_cb = SimpleProgressBar(args.batchsize) + progress_cb = train.SimpleProgressBar(args.batchsize) - callbacks = [MetricsGraphing(), checkpoint_cb, progress_cb] + callbacks = [train.MetricsGraphing(), checkpoint_cb, progress_cb] swa_callback = None if args.swa: - swa_callback = SwaCallback(start_epoch=args.epochs * 2 // 3) + swa_callback = train.SwaCallback(start_epoch=args.epochs * 2 // 3) 
callbacks.append(swa_callback) # TODO: inf norm? diff --git a/test/test_train.py b/test/test_train.py index 2674c5f..1f42f7b 100644 --- a/test/test_train.py +++ b/test/test_train.py @@ -1,22 +1,154 @@ from torch.utils.data import Dataset, DataLoader import time import torch +from torch import nn import numpy as np +import os import functools -from typing import List -from trackertraincode.datasets.batch import Batch, Metadata +from typing import List, Any +import itertools +from pytorch_lightning.callbacks import ModelCheckpoint +import pytorch_lightning as pl +import matplotlib +import matplotlib.pyplot +import time +from trackertraincode.datasets.batch import Batch, Metadata import trackertraincode.train as train + def test_plotter(): plotter = train.TrainHistoryPlotter() - names = [ 'foo', 'bar', 'baz', 'lr' ] + names = ['foo', 'bar', 'baz', 'lr'] for e in range(4): for t in range(5): for name in names[:-2]: - plotter.add_train_point(e, t, name, 10. + e + np.random.normal(0., 1.,(1,))) + plotter.add_train_point(e, t, name, 10.0 + e + np.random.normal(0.0, 1.0, (1,))) for name in names[1:]: - plotter.add_test_point(e, name, 9. 
+ e + np.random.normal()) + plotter.add_test_point(e, name, 9.0 + e + np.random.normal()) plotter.summarize_train_values() plotter.update_graph() - plotter.close() \ No newline at end of file + plotter.close() + + +class MseLoss(object): + def __call__(self, pred, batch): + return torch.nn.functional.mse_loss(pred['test_head_out'], batch['y'], reduction='none') + + +class L1Loss(object): + def __call__(self, pred, batch): + return torch.nn.functional.l1_loss(pred['test_head_out'], batch['y'], reduction='none') + + +class CosineDataset(Dataset): + def __init__(self, n): + self.n = n + + def __len__(self): + return self.n + + def __getitem__(self, i): + x = torch.rand((1,)) + y = torch.cos(x) + return Batch(Metadata(0, batchsize=0), {'image': x, 'y': y}) + + +class MockupModel(nn.Module): + def __init__(self): + super().__init__() + self.layers = torch.nn.Sequential(torch.nn.Linear(1, 128), torch.nn.ReLU(), torch.nn.Linear(128, 1)) + + def forward(self, x: torch.Tensor): + return {'test_head_out': self.layers(x)} + + def get_config(self): + return {} + + +class LitModel(pl.LightningModule): + def __init__(self): + super().__init__() + self._model = MockupModel() + self._train_criterions = self.__setup_train_criterions() + self._test_criterion = train.Criterion('test_head_out_c1', MseLoss(), 1.0) + + def __setup_train_criterions(self): + c1 = train.Criterion('c1', MseLoss(), 0.42) + c2 = train.Criterion('c2', L1Loss(), 0.7) + return train.CriterionGroup([c1, c2], 'test_head_out_') + + def training_step(self, batch: Batch, batch_idx): + loss_sum, all_lossvals = train.default_compute_loss( + self._model, [batch], self.current_epoch, self._train_criterions + ) + loss_val_by_name = { + name: val + for name, (val, _) in train.concatenated_lossvals_by_name( + itertools.chain.from_iterable(all_lossvals) + ).items() + } + self.log("loss", loss_sum, on_epoch=True, prog_bar=True, batch_size=batch.meta.batchsize) + return {"loss": loss_sum, "mt_losses": loss_val_by_name} + + def 
validation_step(self, batch: Batch, batch_idx: int) -> torch.Tensor | dict[str, Any] | None: + images = batch["image"] + pred = self._model(images) + values = self._test_criterion.evaluate(pred, batch, batch_idx) + val_loss = torch.cat([(lv.val * lv.weight) for lv in values]).sum() + self.log("val_loss", val_loss, on_epoch=True, batch_size=batch.meta.batchsize) + return values + + def configure_optimizers(self): + optimizer = torch.optim.Adam(self.model.parameters(), lr=1.0e-4) + scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=1.0e-4, total_steps=50) + return {"optimizer": optimizer, "lr_scheduler": scheduler} + + @property + def model(self): + return self._model + + +def test_train_smoketest(tmp_path): + batchsize = 32 + epochs = 50 + train_loader = DataLoader(CosineDataset(20), batch_size=batchsize, collate_fn=Batch.collate) + test_loader = DataLoader(CosineDataset(8), batch_size=batchsize, collate_fn=Batch.collate) + model = LitModel() + model_out_dir = os.path.join(tmp_path, 'models') + + checkpoint_cb = ModelCheckpoint( + save_top_k=1, + save_last=True, + monitor="val_loss", + enable_version_counter=False, + filename="best", + dirpath=model_out_dir, + save_weights_only=False, + ) + + progress_cb = train.SimpleProgressBar(batchsize) + visu_cb = train.MetricsGraphing() + callbacks = [visu_cb, checkpoint_cb, progress_cb, train.SwaCallback(start_epoch=epochs // 2)] + + trainer = pl.Trainer( + fast_dev_run=False, + gradient_clip_val=1.0, + gradient_clip_algorithm="norm", + default_root_dir=model_out_dir, + # limit_train_batches=((10 * 1024) // batchsize), + callbacks=callbacks, + enable_checkpointing=True, + max_epochs=epochs, + log_every_n_steps=1, + logger=False, + enable_progress_bar=False, + ) + + trainer.fit(model=model, train_dataloaders=train_loader, val_dataloaders=test_loader) + + visu_cb.close() + + assert os.path.isfile(tmp_path / 'models' / 'swa.ckpt') + assert os.path.isfile(tmp_path / 'models' / 'best.ckpt') + assert 
os.path.isfile(tmp_path / 'models' / 'train.pdf') diff --git a/trackertraincode/train.py b/trackertraincode/train.py index 5d150c3..65431c4 100644 --- a/trackertraincode/train.py +++ b/trackertraincode/train.py @@ -3,7 +3,7 @@ from collections import namedtuple, defaultdict import numpy as np from os.path import join, isdir -from typing import Any, Dict, List, Sequence, Tuple, Union, Optional, Callable, NamedTuple +from typing import Any, Dict, List, Sequence, Tuple, Union, Optional, Callable, NamedTuple, Mapping, Protocol import tqdm import multiprocessing import queue @@ -14,21 +14,25 @@ import os import math +import pytorch_lightning as pl +from pytorch_lightning.callbacks import Callback from torch import Tensor import torch import torch.nn as nn from torch.optim.lr_scheduler import LambdaLR, CyclicLR +from torch.optim.swa_utils import AveragedModel +from trackertraincode.neuralnets.io import save_model from trackertraincode.datasets.batch import Batch class LossVal(NamedTuple): - val : Tensor - weight : float - name : str + val: Tensor + weight: float + name: str -def concatenated_lossvals_by_name(vals : list[LossVal]): +def concatenated_lossvals_by_name(vals: list[LossVal]): '''Sorts by name and concatenates. Assumes that names can occur multiple times. 
Then corresponding weights and @@ -43,20 +47,18 @@ def concatenated_lossvals_by_name(vals : list[LossVal]): for v in vals: value_lists[v.name].append(v.val) weight_lists[v.name].append(v.weight) - return { - k:(torch.concat(value_lists[k]),torch.concat(weight_lists[k])) for k in value_lists - } + return {k: (torch.concat(value_lists[k]), torch.concat(weight_lists[k])) for k in value_lists} class Criterion(NamedTuple): - name : str - f : Callable[[Batch,Batch],Tensor] - w : Union[float,Callable[[int],float]] + name: str + f: Callable[[Batch, Batch], Tensor] + w: Union[float, Callable[[int], float]] def evaluate(self, pred, batch, step) -> List[LossVal]: - val = self.f(pred,batch) + val = self.f(pred, batch) w = self._eval_weight(step) - return [ LossVal(val, w, self.name) ] + return [LossVal(val, w, self.name)] def _eval_weight(self, step): if isinstance(self.w, float): @@ -66,9 +68,9 @@ def _eval_weight(self, step): class CriterionGroup(NamedTuple): - criterions : List[Union['CriterionGroup',Criterion]] - name : str = '' - w : Union[float,Callable[[int],float]] = 1.0 + criterions: List[Union['CriterionGroup', Criterion]] + name: str = '' + w: Union[float, Callable[[int], float]] = 1.0 def _eval_weight(self, step): if isinstance(self.w, float): @@ -79,39 +81,39 @@ def _eval_weight(self, step): def evaluate(self, pred, batch, step) -> List[LossVal]: w = self._eval_weight(step) lossvals = sum((c.evaluate(pred, batch, step) for c in self.criterions), start=[]) - lossvals = [ LossVal(v.val,v.weight*w,self.name+v.name) for v in lossvals ] + lossvals = [LossVal(v.val, v.weight * w, self.name + v.name) for v in lossvals] return lossvals @dataclasses.dataclass class History: - train : List[Any] = dataclasses.field(default_factory=list) - test : List[Any] = dataclasses.field(default_factory=list) - current_train_buffer : List[Any] = dataclasses.field(default_factory=list) - logplot : bool = True + train: List[Any] = dataclasses.field(default_factory=list) + test: List[Any] = 
dataclasses.field(default_factory=list) + current_train_buffer: List[Any] = dataclasses.field(default_factory=list) + logplot: bool = True class TrainHistoryPlotter(object): - def __init__(self, save_filename = None): + def __init__(self, save_filename=None): self.histories = defaultdict(History) self.queue = multiprocessing.Queue(maxsize=100) - self.plotting = multiprocessing.Process(None,self.run_plotting,args=(self.queue, save_filename)) + self.plotting = multiprocessing.Process(None, self.run_plotting, args=(self.queue, save_filename)) self.plotting.start() @staticmethod - def ensure_axes_are_ready(fig : pyplot.Figure, axes, last_rows, histories): + def ensure_axes_are_ready(fig: pyplot.Figure, axes, last_rows, histories): num_rows = len(histories) if num_rows != last_rows: if num_rows > 5: - r, c = (num_rows+1)//2, 2 + r, c = (num_rows + 1) // 2, 2 else: r, c = num_rows, 1 fig.clear() - fig.set_figheight(3*r) + fig.set_figheight(3 * r) axes = fig.subplots(r, c) if c > 1: axes = axes.ravel() - if num_rows==1: + if num_rows == 1: axes = [axes] else: for ax in axes: @@ -144,8 +146,8 @@ def update_actual_graphs(histories, fig, axes, num_rows): fig, axes, num_rows = TrainHistoryPlotter.ensure_axes_are_ready(fig, axes, num_rows, histories) for ax, (name, history) in zip(axes, histories.items()): if name == 'lr': - t, lr = np.array(history.test).T - ax.plot(t, lr, label = 'lr', marker='o', color='k') + t, lr = np.array(history.test).T + ax.plot(t, lr, label='lr', marker='o', color='k') ax.set(yscale='log') ax.grid(axis='y', which='both') ax.legend() @@ -160,7 +162,7 @@ def update_actual_graphs(histories, fig, axes, num_rows): t, x, xerr = np.array(history.train).T ax.errorbar(t, x, yerr=xerr, label=name, color='r') if history.test: - ax.plot(*np.array(history.test).T, label='test '+name, marker='x', color='b') + ax.plot(*np.array(history.test).T, label='test ' + name, marker='x', color='b') # FIXME: Hack with `startswith('nll')` if history.logplot and (not 
name.startswith('nll')) and (not name == 'loss'): ax.set(yscale='log') @@ -191,10 +193,10 @@ def add_train_point(self, epoch, step, name, value): self.histories[name].current_train_buffer.append((epoch, value)) def add_test_point(self, epoch, name, value): - self.histories[name].test.append((epoch,value)) + self.histories[name].test.append((epoch, value)) @staticmethod - def summarize_single_train_history(k, h : History): + def summarize_single_train_history(k, h: History): if not h.current_train_buffer: return epochs, values = zip(*h.current_train_buffer) @@ -206,11 +208,12 @@ def summarize_single_train_history(k, h : History): h.train.append((np.average(epochs), np.average(values), np.std(values))) except FloatingPointError: with np.printoptions(precision=4, suppress=True, threshold=20000): - print (f"Floating point error at {k} in epochs {np.average(epochs)} with values:\n {str(values)} of which there are {len(values)}\n") + print( + f"Floating point error at {k} in epochs {np.average(epochs)} with values:\n {str(values)} of which there are {len(values)}\n" + ) h.train.append((np.average(epochs), np.nan, np.nan)) h.current_train_buffer = [] - def summarize_train_values(self): for k, h in self.histories.items(): TrainHistoryPlotter.summarize_single_train_history(k, h) @@ -223,6 +226,7 @@ def update_graph(self): def close(self): self.queue.put(None) + self.plotting.join() class ConsoleTrainOutput(object): @@ -233,14 +237,14 @@ def add_train_point(self, epoch, step, name, value): self.histories[name].current_train_buffer.append((epoch, value)) def add_test_point(self, epoch, name, value): - self.histories[name].test.append((epoch,value)) + self.histories[name].test.append((epoch, value)) def summarize_train_values(self): for k, h in self.histories.items(): TrainHistoryPlotter.summarize_single_train_history(k, h) def update_graph(self): - print ("Losses:") + print("Losses:") for name, h in self.histories.items(): if h.train: epoch, mean, std = h.train[-1] @@ -252,29 
+256,30 @@ def update_graph(self): test_str = f'{val:.4f}' else: test_str = '----' - print (f"{name}: Train: {train_str}, Test: {test_str}") + print(f"{name}: Train: {train_str}, Test: {test_str}") h.test = [] h.train = [] def close(self): pass + class DebugData(NamedTuple): - parameters : dict[str,Tensor] - batches : list[Batch] - preds : dict[str,Tensor] - lossvals : list[list[LossVal]] + parameters: dict[str, Tensor] + batches: list[Batch] + preds: dict[str, Tensor] + lossvals: list[list[LossVal]] def is_bad(self): '''Checks data for badness. - + Currently NANs and input value range. - + Return: True if so. ''' - #TODO: decouple for name of input tensor - for k,v in self.parameters.items(): + # TODO: decouple for name of input tensor + for k, v in self.parameters.items(): if torch.any(torch.isnan(v)): print(f"{k} is NAN") return True @@ -284,10 +289,12 @@ def is_bad(self): print(f"{k} is NAN") return True inputs = b['image'] - if torch.amin(inputs)<-2. or torch.amax(inputs)>2.: - print(f"Input image {inputs.shape} exceeds value limits with {torch.amin(inputs)} to {torch.amax(inputs)}") + if torch.amin(inputs) < -2.0 or torch.amax(inputs) > 2.0: + print( + f"Input image {inputs.shape} exceeds value limits with {torch.amin(inputs)} to {torch.amax(inputs)}" + ) return True - for k,v in self.preds.items(): + for k, v in self.preds.items(): if torch.any(torch.isnan(v)): print(f"{k} is NAN") return True @@ -298,29 +305,34 @@ def is_bad(self): return True return False -class DebugCallback(): + +class DebugCallback: '''For dumping a history of stuff when problems are detected.''' + def __init__(self): self.history_length = 3 - self.debug_data : List[DebugData] = [] + self.debug_data: List[DebugData] = [] self.filename = '/tmp/notgood.pkl' - - def observe(self, net_pre_update : nn.Module, batches : list[Batch], preds : dict[str,Tensor], lossvals : list[list[LossVal]]): + + def observe( + self, net_pre_update: nn.Module, batches: list[Batch], preds: dict[str, Tensor], 
lossvals: list[list[LossVal]] + ): '''Record and check. Args: batches: Actually sub-batches lossvals: One list of loss terms per sub-batch ''' dd = DebugData( - {k:v.detach().to('cpu', non_blocking=True,copy=True) for k,v in net_pre_update.state_dict().items()}, - [b.to('cpu', non_blocking=True,copy=True) for b in batches ], - {k:v.detach().to('cpu', non_blocking=True,copy=True) for k,v in preds.items()}, - lossvals + {k: v.detach().to('cpu', non_blocking=True, copy=True) for k, v in net_pre_update.state_dict().items()}, + [b.to('cpu', non_blocking=True, copy=True) for b in batches], + {k: v.detach().to('cpu', non_blocking=True, copy=True) for k, v in preds.items()}, + lossvals, ) if len(self.debug_data) >= self.history_length: self.debug_data.pop(0) self.debug_data.append(dd) - torch.cuda.current_stream().synchronize() + if torch.cuda.is_available(): + torch.cuda.current_stream().synchronize() if dd.is_bad(): with open(self.filename, 'wb') as f: pickle.dump(self.debug_data, f) @@ -330,7 +342,17 @@ def observe(self, net_pre_update : nn.Module, batches : list[Batch], preds : dic # g_debug = DebugCallback() -def default_compute_loss(net, batch : List[Batch], current_epoch : int, loss : dict[Any, Criterion | CriterionGroup] | Criterion | CriterionGroup): +def default_compute_loss( + net, + batch: List[Batch], + current_epoch: int, + loss: dict[Any, Criterion | CriterionGroup] | Criterion | CriterionGroup, +): + """ + Return: + Loss sum for backprop + LossVals - one nested list per batch item. Tensors transferred to the cpu. 
+ """ # global g_debug inputs = torch.concat([b['image'] for b in batch], dim=0) @@ -338,26 +360,30 @@ def default_compute_loss(net, batch : List[Batch], current_epoch : int, loss : d preds = net(inputs) lossvals_by_name = defaultdict(list) - all_lossvals : list[list[LossVal]] = [] + all_lossvals: list[list[LossVal]] = [] # Iterate over different datasets / loss configurations offset = 0 for subset in batch: - frames_in_subset, = subset.meta.prefixshape - subpreds = { k:v[offset:offset+frames_in_subset,...] for k,v in preds.items() } + (frames_in_subset,) = subset.meta.prefixshape + subpreds = {k: v[offset : offset + frames_in_subset, ...] for k, v in preds.items()} # Get loss function and evaluate - loss_func_of_subset : Union[Criterion,CriterionGroup] = loss[subset.meta.tag] if isinstance(loss, dict) else loss - multi_task_terms : List[LossVal] = loss_func_of_subset.evaluate(subpreds, subset, current_epoch) + loss_func_of_subset: Union[Criterion, CriterionGroup] = ( + loss[subset.meta.tag] if isinstance(loss, dict) else loss + ) + multi_task_terms: List[LossVal] = loss_func_of_subset.evaluate(subpreds, subset, current_epoch) # Support loss weighting by datasets if 'dataset_weight' in subset: dataset_weight = subset['dataset_weight'] assert dataset_weight.size(0) == subset.meta.batchsize - multi_task_terms = [ v._replace(weight=v.weight*dataset_weight) for v in multi_task_terms ] + multi_task_terms = [v._replace(weight=v.weight * dataset_weight) for v in multi_task_terms] else: # Else, make the weight member a tensor the same shape as the loss values - multi_task_terms = [ v._replace(weight=v.val.new_full(size=v.val.shape,fill_value=v.weight)) for v in multi_task_terms ] + multi_task_terms = [ + v._replace(weight=v.val.new_full(size=v.val.shape, fill_value=v.weight)) for v in multi_task_terms + ] all_lossvals.append(multi_task_terms) del multi_task_terms, loss_func_of_subset @@ -365,55 +391,182 @@ def default_compute_loss(net, batch : List[Batch], current_epoch : 
int, loss : d offset += frames_in_subset batchsize = sum(subset.meta.batchsize for subset in batch) - # Concatenate the loss values over the sub-batches. + # Concatenate the loss values over the sub-batches. lossvals_by_name = concatenated_lossvals_by_name(itertools.chain.from_iterable(all_lossvals)) # Compute weighted average, dividing by the batch size which is equivalent to substituting missing losses by 0. - loss_sum = torch.concat([ (values*weights) for values,weights in lossvals_by_name.values() ]).sum() / batchsize + loss_sum = torch.concat([(values * weights) for values, weights in lossvals_by_name.values()]).sum() / batchsize # Transfer to CPU for loss_list in all_lossvals: for i, v in enumerate(loss_list): - loss_list[i] = v._replace(val = v.val.detach().to('cpu', non_blocking=True)) - - torch.cuda.current_stream().synchronize() + loss_list[i] = v._replace(val=v.val.detach().to('cpu', non_blocking=True)) + if torch.cuda.is_available(): + torch.cuda.current_stream().synchronize() return loss_sum, all_lossvals +class LightningModelWrapper(Protocol): + @property + def model() -> nn.Module: ... 
+ + +class SwaCallback(Callback): + def __init__(self, start_epoch): + super().__init__() + self._swa_model: AveragedModel | None = None + self._start_epoch = start_epoch + + @property + def swa_model(self): + return self._swa_model.module + + def on_train_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: + self._swa_model = AveragedModel(pl_module.model, device="cpu", use_buffers=True) + + def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: + if trainer.current_epoch > self._start_epoch: + self._swa_model.update_parameters(pl_module.model) + + def on_train_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: + assert self._swa_model is not None + swa_filename = join(trainer.default_root_dir, f"swa.ckpt") + save_model(self._swa_model.module, swa_filename) + + +class MetricsGraphing(Callback): + def __init__(self): + super().__init__() + self._visu: TrainHistoryPlotter | None = None + self._metrics_accumulator = defaultdict(list) + + def on_train_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: + assert self._visu is None + self._visu = TrainHistoryPlotter(save_filename=join(trainer.default_root_dir, "train.pdf")) + + def on_train_batch_end( + self, trainer: pl.Trainer, pl_module: pl.LightningModule, outputs: Any, batch: Any, batch_idx: int + ): + mt_losses: dict[str, torch.Tensor] = outputs["mt_losses"] + for k, v in mt_losses.items(): + self._visu.add_train_point(trainer.current_epoch, batch_idx, k, v.numpy()) + self._visu.add_train_point(trainer.current_epoch, batch_idx, "loss", outputs["loss"].detach().cpu().numpy()) + + def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: + if trainer.lr_scheduler_configs: # scheduler is not None: + scheduler = next( + iter(trainer.lr_scheduler_configs) + ).scheduler # Pick the first scheduler (and there should only be one) + last_lr = next(iter(scheduler.get_last_lr())) # LR from the first 
parameter group + self._visu.add_test_point(trainer.current_epoch, "lr", last_lr) + + self._visu.summarize_train_values() + + def on_validation_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: + self._metrics_accumulator = defaultdict(list) + + def on_validation_batch_end( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + outputs: list[LossVal], + batch: Any, + batch_idx: int, + dataloader_idx: int = 0, + ) -> None: + for val in outputs: + self._metrics_accumulator[val.name].append(val.val) + + def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: + if self._visu is None: + return + for k, v in self._metrics_accumulator.items(): + self._visu.add_test_point(trainer.current_epoch - 1, k, torch.cat(v).mean().cpu().numpy()) + if trainer.current_epoch > 0: + self._visu.update_graph() + + def close(self): + self._visu.close() + + +class SimpleProgressBar(Callback): + """Creates progress bars for total training time and progress of per epoch.""" + + def __init__(self, batchsize: int): + super().__init__() + self._bar: tqdm.tqdm | None = None + self._epoch_bar: tqdm.tqdm | None = None + self._batchsize = batchsize + + def on_train_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: + self._bar = tqdm.tqdm(total=trainer.max_epochs, desc='Training', position=0) + self._epoch_bar = tqdm.tqdm(total=trainer.num_training_batches * self._batchsize, desc="Epoch", position=1) + + def on_train_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: + self._bar.close() + self._epoch_bar.close() + self._bar = None + self._epoch_bar = None + + def on_train_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: + self._epoch_bar.reset(self._epoch_bar.total) + + def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: + self._bar.update(1) + + def on_train_batch_end( + self, + trainer: pl.Trainer, + pl_module: 
pl.LightningModule, + outputs: Mapping[str, Any], + batch: list[Batch] | Batch, + batch_idx: int, + ) -> None: + n = sum(b.meta.batchsize for b in batch) if isinstance(batch, list) else batch.meta.batchsize + self._epoch_bar.update(n) + + ########################################## ## Schedules ########################################## + def TriangularSchedule(optimizer, min_lr, lr, num_steps, *args, **kwargs): - num_steps_up = min(max(1,num_steps*3//10), 33) + num_steps_up = min(max(1, num_steps * 3 // 10), 33) num_steps_down = num_steps - num_steps_up - return CyclicLR(optimizer, min_lr, lr, num_steps_up, num_steps_down, *args, mode='triangular', cycle_momentum=False, **kwargs) + return CyclicLR( + optimizer, min_lr, lr, num_steps_up, num_steps_down, *args, mode='triangular', cycle_momentum=False, **kwargs + ) def LinearUpThenSteps(optimizer, num_up, gamma, steps): steps = [0] + steps + def lr_func(i): if i < num_up: - return ((i+1)/num_up) + return (i + 1) / num_up else: - step_index = [j for j,step in enumerate(steps) if i>step][-1] + step_index = [j for j, step in enumerate(steps) if i > step][-1] return gamma**step_index + return LambdaLR(optimizer, lr_func) def ExponentialUpThenSteps(optimizer, num_up, gamma, steps): steps = [0] + steps + def lr_func(i): - eps = 1.e-2 + eps = 1.0e-2 scale = math.log(eps) if i < num_up: - f = ((i+1)/num_up) - #return torch.sigmoid((f - 0.5) * 15.) + f = (i + 1) / num_up + # return torch.sigmoid((f - 0.5) * 15.) # a * exp(f / l) | f=1 == 1. # a * exp(f / l) | f=0 ~= eps # => a = eps # => ln(1./eps) = 1./l - return eps * math.exp(-scale*f) + return eps * math.exp(-scale * f) else: - step_index = [j for j,step in enumerate(steps) if i>step][-1] + step_index = [j for j, step in enumerate(steps) if i > step][-1] return gamma**step_index - return LambdaLR(optimizer, lr_func) \ No newline at end of file + + return LambdaLR(optimizer, lr_func)