Improve train code:
* Move general stuff from pose estimator training to train.py (see the wiring sketch below)
* Autoformat
* Add smoketest back in
DaWelter committed Dec 7, 2024
1 parent 320b39c commit 8eed347
Showing 4 changed files with 382 additions and 233 deletions.
21 changes: 5 additions & 16 deletions run.sh
@@ -1,20 +1,9 @@
 #!/bin/bash
 
-# python scripts/train_poseestimator.py --lr 1.e-3 --epochs 500 --ds "repro_300_wlp+lapa_megaface_lp+wflw_lp+synface" \
-# --save-plot train.pdf \
-# --with-swa \
-# --with-nll-loss \
-# --roi-override original \
-# --no-onnx \
-# --backbone mobilenetv1 \
-# --outdir model_files/
-
-
-
-#--rampup_nll_losses \
-
-python scripts/train_poseestimator_lightning.py --ds "repro_300_wlp+lapa_megaface_lp+wflw_lp+synface" \
-    --epochs 10 \
+python scripts/train_poseestimator.py --lr 1.e-3 --epochs 1500 --ds "repro_300_wlp+lapa_megaface_lp+wflw_lp+synface" \
     --with-swa \
     --with-nll-loss \
-    --rampup-nll-losses
+    --backbone hybrid_vit \
+    --rampup-nll-losses
 
 # --outdir model_files/current/run0/
137 changes: 6 additions & 131 deletions scripts/train_poseestimator.py
@@ -18,10 +18,8 @@
 import tqdm
 
 import pytorch_lightning as pl
-from pytorch_lightning.callbacks import Callback, ModelCheckpoint
+from pytorch_lightning.callbacks import ModelCheckpoint
 
-# from pytorch_lightning.loggers import Logger,
-from pytorch_lightning.utilities import rank_zero_only
 import torch.optim as optim
 import torch
 import torch.nn as nn
@@ -32,8 +30,6 @@
 import trackertraincode.train as train
 import trackertraincode.pipelines
 
-from trackertraincode.neuralnets.io import complement_lightning_checkpoint
-from scripts.export_model import convert_posemodel_onnx
 from trackertraincode.datasets.batch import Batch
 from trackertraincode.pipelines import Tag
 
@@ -161,11 +157,6 @@ def create_optimizer(net, args: MyArgs):
     return optimizer, scheduler
 
 
-class SaveBestSpec(NamedTuple):
-    weights: List[float]
-    names: List[str]
-
-
 def setup_losses(args: MyArgs, net):
     C = train.Criterion
     cregularize = [
@@ -259,9 +250,7 @@ def wrapped(step):
         ),
     }
 
-    savebest = SaveBestSpec([1.0, 1.0, 1.0], ["rot", "xy", "sz"])
-
-    return train_criterions, test_criterions, savebest
+    return train_criterions, test_criterions
 
 
 def create_net(args: MyArgs):
@@ -281,7 +270,7 @@ def __init__(self, args: MyArgs):
         super().__init__()
         self._args = args
         self._model = create_net(args)
-        train_criterions, test_criterions, savebest = setup_losses(args, self._model)
+        train_criterions, test_criterions = setup_losses(args, self._model)
         self._train_criterions = train_criterions
         self._test_criterions = test_criterions
@@ -315,120 +304,6 @@ def model(self):
         return self._model
 
 
-class SwaCallback(Callback):
-    def __init__(self, start_epoch):
-        super().__init__()
-        self._swa_model: optim.swa_utils.AveragedModel | None = None
-        self._start_epoch = start_epoch
-
-    @property
-    def swa_model(self):
-        return self._swa_model.module
-
-    def on_train_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
-        assert isinstance(pl_module, LitModel)
-        self._swa_model = optim.swa_utils.AveragedModel(pl_module.model, device="cpu", use_buffers=True)
-
-    def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
-        assert isinstance(pl_module, LitModel)
-        if trainer.current_epoch > self._start_epoch:
-            self._swa_model.update_parameters(pl_module.model)
-
-    def on_train_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
-        assert self._swa_model is not None
-        swa_filename = join(trainer.default_root_dir, "swa.ckpt")
-        models.save_model(self._swa_model.module, swa_filename)
-
-
-class MetricsGraphing(Callback):
-    def __init__(self):
-        super().__init__()
-        self._visu: train.TrainHistoryPlotter | None = None
-        self._metrics_accumulator = defaultdict(list)
-
-    def on_train_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
-        assert self._visu is None
-        self._visu = train.TrainHistoryPlotter(save_filename=join(trainer.default_root_dir, "train.pdf"))
-
-    def on_train_batch_end(
-        self, trainer: pl.Trainer, pl_module: pl.LightningModule, outputs: Any, batch: Any, batch_idx: int
-    ):
-        mt_losses: dict[str, torch.Tensor] = outputs["mt_losses"]
-        for k, v in mt_losses.items():
-            self._visu.add_train_point(trainer.current_epoch, batch_idx, k, v.numpy())
-        self._visu.add_train_point(trainer.current_epoch, batch_idx, "loss", outputs["loss"].detach().cpu().numpy())
-
-    def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
-        if trainer.lr_scheduler_configs:
-            # Pick the first scheduler (there should only be one).
-            scheduler = next(iter(trainer.lr_scheduler_configs)).scheduler
-            # LR from the first parameter group.
-            last_lr = next(iter(scheduler.get_last_lr()))
-            self._visu.add_test_point(trainer.current_epoch, "lr", last_lr)
-
-        self._visu.summarize_train_values()
-
-    def on_validation_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
-        self._metrics_accumulator = defaultdict(list)
-
-    def on_validation_batch_end(
-        self,
-        trainer: pl.Trainer,
-        pl_module: pl.LightningModule,
-        outputs: list[train.LossVal],
-        batch: Any,
-        batch_idx: int,
-        dataloader_idx: int = 0,
-    ) -> None:
-        for val in outputs:
-            self._metrics_accumulator[val.name].append(val.val)
-
-    def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
-        if self._visu is None:
-            return
-        for k, v in self._metrics_accumulator.items():
-            self._visu.add_test_point(trainer.current_epoch - 1, k, torch.cat(v).mean().cpu().numpy())
-        if trainer.current_epoch > 0:
-            self._visu.update_graph()
-
-
-class SimpleProgressBar(Callback):
-    """Creates progress bars for total training time and per-epoch progress."""
-
-    def __init__(self, batchsize: int):
-        super().__init__()
-        self._bar: tqdm.tqdm | None = None
-        self._epoch_bar: tqdm.tqdm | None = None
-        self._batchsize = batchsize
-
-    def on_train_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
-        self._bar = tqdm.tqdm(total=trainer.max_epochs, desc='Training', position=0)
-        self._epoch_bar = tqdm.tqdm(total=trainer.num_training_batches * self._batchsize, desc="Epoch", position=1)
-
-    def on_train_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
-        self._bar.close()
-        self._epoch_bar.close()
-        self._bar = None
-        self._epoch_bar = None
-
-    def on_train_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
-        self._epoch_bar.reset(self._epoch_bar.total)
-
-    def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
-        self._bar.update(1)
-
-    def on_train_batch_end(
-        self,
-        trainer: pl.Trainer,
-        pl_module: pl.LightningModule,
-        outputs: Mapping[str, Any],
-        batch: list[Batch] | Batch,
-        batch_idx: int,
-    ) -> None:
-        n = sum(b.meta.batchsize for b in batch) if isinstance(batch, list) else batch.meta.batchsize
-        self._epoch_bar.update(n)
-
-
 def main():
     np.seterr(all="raise")
     cv2.setNumThreads(1)
@@ -499,13 +374,13 @@ def main():
         save_weights_only=False,
     )
 
-    progress_cb = SimpleProgressBar(args.batchsize)
+    progress_cb = train.SimpleProgressBar(args.batchsize)
 
-    callbacks = [MetricsGraphing(), checkpoint_cb, progress_cb]
+    callbacks = [train.MetricsGraphing(), checkpoint_cb, progress_cb]
 
     swa_callback = None
     if args.swa:
-        swa_callback = SwaCallback(start_epoch=args.epochs * 2 // 3)
+        swa_callback = train.SwaCallback(start_epoch=args.epochs * 2 // 3)
         callbacks.append(swa_callback)
 
     # TODO: inf norm?
144 changes: 138 additions & 6 deletions test/test_train.py
@@ -1,22 +1,154 @@
+from torch.utils.data import Dataset, DataLoader
-import time
 import torch
+from torch import nn
 import numpy as np
 import os
 import functools
-from typing import List
-from trackertraincode.datasets.batch import Batch, Metadata
+from typing import List, Any
+import itertools
+from pytorch_lightning.callbacks import ModelCheckpoint
+import pytorch_lightning as pl
 import matplotlib
 import matplotlib.pyplot
+import time
 
+from trackertraincode.datasets.batch import Batch, Metadata
 import trackertraincode.train as train


 def test_plotter():
     plotter = train.TrainHistoryPlotter()
-    names = [ 'foo', 'bar', 'baz', 'lr' ]
+    names = ['foo', 'bar', 'baz', 'lr']
     for e in range(4):
         for t in range(5):
             for name in names[:-2]:
-                plotter.add_train_point(e, t, name, 10. + e + np.random.normal(0., 1.,(1,)))
+                plotter.add_train_point(e, t, name, 10.0 + e + np.random.normal(0.0, 1.0, (1,)))
         for name in names[1:]:
-            plotter.add_test_point(e, name, 9. + e + np.random.normal())
+            plotter.add_test_point(e, name, 9.0 + e + np.random.normal())
     plotter.summarize_train_values()
     plotter.update_graph()
-    plotter.close()
+    plotter.close()
+
+
+class MseLoss(object):
+    def __call__(self, pred, batch):
+        return torch.nn.functional.mse_loss(pred['test_head_out'], batch['y'], reduction='none')
+
+
+class L1Loss(object):
+    def __call__(self, pred, batch):
+        return torch.nn.functional.l1_loss(pred['test_head_out'], batch['y'], reduction='none')
+
+
+class CosineDataset(Dataset):
+    def __init__(self, n):
+        self.n = n
+
+    def __len__(self):
+        return self.n
+
+    def __getitem__(self, i):
+        x = torch.rand((1,))
+        y = torch.cos(x)
+        return Batch(Metadata(0, batchsize=0), {'image': x, 'y': y})
+
+
+class MockupModel(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.layers = torch.nn.Sequential(torch.nn.Linear(1, 128), torch.nn.ReLU(), torch.nn.Linear(128, 1))
+
+    def forward(self, x: torch.Tensor):
+        return {'test_head_out': self.layers(x)}
+
+    def get_config(self):
+        return {}
+
+
+class LitModel(pl.LightningModule):
+    def __init__(self):
+        super().__init__()
+        self._model = MockupModel()
+        self._train_criterions = self.__setup_train_criterions()
+        self._test_criterion = train.Criterion('test_head_out_c1', MseLoss(), 1.0)
+
+    def __setup_train_criterions(self):
+        c1 = train.Criterion('c1', MseLoss(), 0.42)
+        c2 = train.Criterion('c2', L1Loss(), 0.7)
+        return train.CriterionGroup([c1, c2], 'test_head_out_')
+
+    def training_step(self, batch: Batch, batch_idx):
+        loss_sum, all_lossvals = train.default_compute_loss(
+            self._model, [batch], self.current_epoch, self._train_criterions
+        )
+        loss_val_by_name = {
+            name: val
+            for name, (val, _) in train.concatenated_lossvals_by_name(
+                itertools.chain.from_iterable(all_lossvals)
+            ).items()
+        }
+        self.log("loss", loss_sum, on_epoch=True, prog_bar=True, batch_size=batch.meta.batchsize)
+        return {"loss": loss_sum, "mt_losses": loss_val_by_name}
+
+    def validation_step(self, batch: Batch, batch_idx: int) -> torch.Tensor | dict[str, Any] | None:
+        images = batch["image"]
+        pred = self._model(images)
+        values = self._test_criterion.evaluate(pred, batch, batch_idx)
+        val_loss = torch.cat([(lv.val * lv.weight) for lv in values]).sum()
+        self.log("val_loss", val_loss, on_epoch=True, batch_size=batch.meta.batchsize)
+        return values
+
+    def configure_optimizers(self):
+        optimizer = torch.optim.Adam(self.model.parameters(), lr=1.0e-4)
+        scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=1.0e-4, total_steps=50)
+        return {"optimizer": optimizer, "lr_scheduler": scheduler}
+
+    @property
+    def model(self):
+        return self._model
+
+
+def test_train_smoketest(tmp_path):
+    batchsize = 32
+    epochs = 50
+    train_loader = DataLoader(CosineDataset(20), batch_size=batchsize, collate_fn=Batch.collate)
+    test_loader = DataLoader(CosineDataset(8), batch_size=batchsize, collate_fn=Batch.collate)
+    model = LitModel()
+    model_out_dir = os.path.join(tmp_path, 'models')
+
+    checkpoint_cb = ModelCheckpoint(
+        save_top_k=1,
+        save_last=True,
+        monitor="val_loss",
+        enable_version_counter=False,
+        filename="best",
+        dirpath=model_out_dir,
+        save_weights_only=False,
+    )
+
+    progress_cb = train.SimpleProgressBar(batchsize)
+    visu_cb = train.MetricsGraphing()
+    callbacks = [visu_cb, checkpoint_cb, progress_cb, train.SwaCallback(start_epoch=epochs // 2)]
+
+    trainer = pl.Trainer(
+        fast_dev_run=False,
+        gradient_clip_val=1.0,
+        gradient_clip_algorithm="norm",
+        default_root_dir=model_out_dir,
+        # limit_train_batches=((10 * 1024) // batchsize),
+        callbacks=callbacks,
+        enable_checkpointing=True,
+        max_epochs=epochs,
+        log_every_n_steps=1,
+        logger=False,
+        enable_progress_bar=False,
+    )
+
+    trainer.fit(model=model, train_dataloaders=train_loader, val_dataloaders=test_loader)
+
+    visu_cb.close()
+
+    assert os.path.isfile(tmp_path / 'models' / 'swa.ckpt')
+    assert os.path.isfile(tmp_path / 'models' / 'best.ckpt')
+    assert os.path.isfile(tmp_path / 'models' / 'train.pdf')