Skip to content

Commit

Permalink
training ncaltech101
Browse files Browse the repository at this point in the history
  • Loading branch information
danielgehrig18 committed Sep 2, 2024
1 parent edd6041 commit 3c9c717
Show file tree
Hide file tree
Showing 6 changed files with 315 additions and 8 deletions.
36 changes: 36 additions & 0 deletions config/dagr-l-ncaltech.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# DAGR-L training configuration for object detection on N-Caltech101.
# Paths are absolute paths on the training server — adapt for your machine.
path: "/data/storage/daniel/aegnn"
output_directory: "/data/storage/daniel/aegnn/logs"
# pooling dimensions (presumably height x width cells) at the network output — confirm in net.py
pooling_dim_at_output: 5x7

task: detection
dataset: ncaltech101

# network
radius: 0.01                # graph connectivity radius, as a fraction of sensor width
time_window_us: 1000000     # event time window in microseconds (1 s)
max_neighbors: 16           # max graph neighbors per node
n_nodes: 50000              # number of events sampled per example

batch_size: 64

activation: relu
edge_attr_dim: 2            # edge attribute dimensionality (2D Cartesian offsets)
aggr: sum                   # message aggregation in graph convolutions
kernel_size: 5
pooling_aggr: max

base_width: 0.5             # channel-width multipliers for the network stages
after_pool_width: 1
net_stem_width: 1
yolo_stem_width: 1
num_scales: 1               # number of detection output scales

# learning
weight_decay: 0.00001
clip: 0.1                   # gradient value-clipping threshold

aug_trans: 0.1              # random translation augmentation magnitude
aug_p_flip: 0               # horizontal flip probability (disabled)
aug_zoom: 1                 # zoom augmentation factor (disabled at 1)
l_r: 0.001                  # base learning rate (scaled by sqrt(batch_size/64) in the script)
tot_num_epochs: 801
182 changes: 182 additions & 0 deletions scripts/train_ncaltech101.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
# avoid matlab error on server
import os
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'

import torch
import tqdm
import wandb
from pathlib import Path
import argparse

from torch_geometric.data import DataLoader

from dagr.utils.logging import Checkpointer, set_up_logging_directory, log_hparams
from dagr.utils.buffers import DetectionBuffer
from dagr.utils.args import FLAGS
from dagr.utils.learning_rate_scheduler import LRSchedule

from dagr.data.augment import Augmentations
from dagr.utils.buffers import format_data
from dagr.data.ncaltech101_data import NCaltech101

from dagr.model.networks.dagr import DAGR
from dagr.model.networks.ema import ModelEMA

def gradients_broken(model):
    """Return True if any existing parameter gradient of *model* contains NaN.

    Only parameters whose ``.grad`` is populated are inspected. Inf values
    are deliberately not flagged here (a commented-out isinf check existed
    in an earlier revision); they are handled by gradient value clipping
    and by ``fix_gradients``/``nan_to_num`` downstream.

    Args:
        model: any ``torch.nn.Module`` whose gradients should be checked.

    Returns:
        bool: True if at least one gradient tensor contains a NaN entry.
    """
    # any() short-circuits on the first broken gradient, replacing the
    # original manual flag-and-break loop (which also left `name` unused).
    return any(
        torch.isnan(param.grad).any()
        for param in model.parameters()
        if param.grad is not None
    )

def fix_gradients(model):
    """Replace NaN entries in all existing parameter gradients with 0.0, in place.

    Note: with default arguments ``torch.nan_to_num`` also saturates +/-inf
    to the dtype's finite max/min; this matches the original behavior and
    acts as a safety net after gradient value clipping.

    Args:
        model: any ``torch.nn.Module``; parameters without gradients are skipped.
    """
    # iterate parameters() directly — the original used named_parameters()
    # but never read the name.
    for param in model.parameters():
        if param.grad is not None:
            param.grad = torch.nan_to_num(param.grad, nan=0.0)


def train(loader: DataLoader,
          model: torch.nn.Module,
          ema: ModelEMA,
          scheduler: torch.optim.lr_scheduler.LambdaLR,
          optimizer: torch.optim.Optimizer,
          args: argparse.Namespace,  # parsed FLAGS result (only args.clip is read here); was mis-annotated as ArgumentParser
          run_name=""):
    """Train *model* for one epoch over *loader*, updating the EMA copy each step.

    The model is expected to return a dict in training mode containing
    at least ``"total_loss"`` plus further per-component entries whose
    keys contain ``"loss"``. All losses are logged to wandb every step.

    Args:
        loader: training data loader yielding torch_geometric batches.
        model: detection network; must be on the GPU already.
        ema: exponential-moving-average tracker updated after each optimizer step.
        scheduler: per-iteration LR scheduler (stepped every batch, not every epoch).
        optimizer: the optimizer driving *model*'s parameters.
        args: parsed command-line flags; provides the gradient clip value.
        run_name: label shown in the tqdm progress bar.
    """

    model.train()

    for i, data in enumerate(tqdm.tqdm(loader, desc=f"Training {run_name}")):
        data = data.cuda(non_blocking=True)
        data = format_data(data)

        optimizer.zero_grad(set_to_none=True)

        model_outputs = model(data)

        # collect every loss component; pop total_loss so it is logged separately
        loss_dict = {k: v for k, v in model_outputs.items() if "loss" in k}
        loss = loss_dict.pop("total_loss")

        loss.backward()

        # clip by value first, then zero out any remaining NaN gradients
        torch.nn.utils.clip_grad_value_(model.parameters(), args.clip)

        fix_gradients(model)

        optimizer.step()
        scheduler.step()

        ema.update(model)

        training_logs = {f"training/loss/{k}": v for k, v in loss_dict.items()}
        wandb.log({"training/loss": loss.item(), "training/lr": scheduler.get_last_lr()[-1], **training_logs})

def run_test(loader: DataLoader,
             model: torch.nn.Module,
             dry_run_steps: int=-1,
             dataset="gen1"):
    """Evaluate *model* over *loader*, accumulating detections for mAP computation.

    Args:
        loader: validation/test loader; its dataset must expose ``height``,
            ``width`` and ``classes`` attributes.
        model: detection network in eval mode it is expected to return a
            ``(detections, targets)`` pair per batch — confirm against the
            model's forward().
        dry_run_steps: if > 0, stop after this many batches (smoke test).
        dataset: dataset name forwarded to the buffer update.

    Returns:
        The filled DetectionBuffer; call ``.compute()`` on it for metrics.
    """

    model.eval()

    mapcalc = DetectionBuffer(height=loader.dataset.height, width=loader.dataset.width, classes=loader.dataset.classes)

    for i, data in enumerate(tqdm.tqdm(loader)):
        data = data.cuda()
        data = format_data(data)

        detections, targets = model(data)
        # periodically release cached GPU memory during evaluation
        if i % 10 == 0:
            torch.cuda.empty_cache()

        mapcalc.update(detections, targets, dataset, data.height[0], data.width[0])

        if dry_run_steps > 0 and i == dry_run_steps:
            break

    torch.cuda.empty_cache()

    return mapcalc

if __name__ == '__main__':
    import torch_geometric
    import random
    import numpy as np

    # seed every RNG source for reproducibility
    seed = 42
    torch_geometric.seed.seed_everything(seed)
    torch.random.manual_seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    args = FLAGS()

    output_directory = set_up_logging_directory(args.dataset, args.task, args.output_directory, exp_name=args.exp_name)
    log_hparams(args)

    augmentations = Augmentations(args)

    print("init datasets")
    dataset_path = args.dataset_directory / args.dataset

    train_dataset = NCaltech101(dataset_path, "training", augmentations.transform_training, num_events=args.n_nodes)
    test_dataset = NCaltech101(dataset_path, "validation", augmentations.transform_testing, num_events=args.n_nodes)


    train_loader = DataLoader(train_dataset, follow_batch=['bbox', 'bbox0'], batch_size=args.batch_size, shuffle=True, num_workers=4, drop_last=True)
    num_iters_per_epoch = len(train_loader)

    # fixed random permutation of the validation set (seeded above), so
    # evaluation order is shuffled but reproducible across runs
    sampler = np.random.permutation(np.arange(len(test_dataset)))
    # NOTE(review): drop_last=True on the test loader silently discards the
    # final partial batch from evaluation — confirm this is intended.
    test_loader = DataLoader(test_dataset, sampler=sampler, follow_batch=['bbox', 'bbox0'], batch_size=args.batch_size, shuffle=False, num_workers=4, drop_last=True)

    print("init net")
    # load a dummy sample to get height, width
    model = DAGR(args, height=test_dataset.height, width=test_dataset.width)

    num_params = sum([np.prod(p.size()) for p in model.parameters()])
    print(f"Training with {num_params} number of parameters.")

    model = model.cuda()
    ema = ModelEMA(model)

    # scale LR with sqrt(batch_size / 64), the square-root scaling rule
    nominal_batch_size = 64
    lr = args.l_r * np.sqrt(args.batch_size) / np.sqrt(nominal_batch_size)
    optimizer = torch.optim.AdamW(list(model.parameters()), lr=lr, weight_decay=args.weight_decay)

    lr_func = LRSchedule(warmup_epochs=.3,
                         num_iters_per_epoch=num_iters_per_epoch,
                         tot_num_epochs=args.tot_num_epochs)

    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=lr_func)

    checkpointer = Checkpointer(output_directory=output_directory,
                                model=model, optimizer=optimizer,
                                scheduler=lr_scheduler, ema=ema,
                                args=args)

    # NOTE(review): the value assigned here is unconditionally overwritten by
    # `start_epoch = 0` on the next line, so resuming from an existing run
    # restores weights but restarts the epoch counter — verify intended.
    start_epoch = checkpointer.restore_if_existing(output_directory, resume_from_best=False)

    start_epoch = 0
    if "resume_checkpoint" in args:
        start_epoch = checkpointer.restore_checkpoint(args.resume_checkpoint, best=False)
        print(f"Resume from checkpoint at epoch {start_epoch}")

    # quick sanity-check evaluation (2 batches) before training starts
    with torch.no_grad():
        mapcalc = run_test(test_loader, ema.ema, dry_run_steps=2, dataset=args.dataset)
        mapcalc.compute()

    print("starting to train")
    for epoch in range(start_epoch, args.tot_num_epochs):
        train(train_loader, model, ema, lr_scheduler, optimizer, args, run_name=wandb.run.name)
        checkpointer.checkpoint(epoch, name=f"last_model")

        # run full evaluation only every third epoch
        if epoch % 3 > 0:
            continue

        with torch.no_grad():
            mapcalc = run_test(test_loader, ema.ema, dataset=args.dataset)
            metrics = mapcalc.compute()
            checkpointer.process(metrics, epoch)

6 changes: 3 additions & 3 deletions src/dagr/model/networks/dagr.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from dagr.model.networks.net import Net
from dagr.model.layers.spline_conv import SplineConvToDense
from dagr.model.layers.conv import ConvBlock
from dagr.model.utils import shallow_copy, init_subnetwork, voxel_size_to_params, postprocess_network_output, convert_to_evaluation_format, init_grid_and_stride
from dagr.model.utils import shallow_copy, init_subnetwork, voxel_size_to_params, postprocess_network_output, convert_to_evaluation_format, init_grid_and_stride, convert_to_training_format


class DAGR(YOLOX):
Expand Down Expand Up @@ -75,10 +75,10 @@ def forward(self, x: Data, reset=True, return_targets=True, filtering=True):
self.head.output_sizes = self.backbone.get_output_sizes()

if self.training:
targets = self.convert_to_training_format(x.bbox, x.bbox_batch, x.num_graphs)
targets = convert_to_training_format(x.bbox, x.bbox_batch, x.num_graphs)

if self.backbone.use_image:
targets0 = self.convert_to_training_format(x.bbox0, x.bbox0_batch, x.num_graphs)
targets0 = convert_to_training_format(x.bbox0, x.bbox0_batch, x.num_graphs)
targets = (targets, targets0)

# gt_target inputs need to be [l cx cy w h] in pixels
Expand Down
4 changes: 3 additions & 1 deletion src/dagr/model/networks/net.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def compute_pooling_at_each_layer(pooling_dim_at_output, num_layers):
poolings = torch.stack(poolings)
return poolings


class Net(torch.nn.Module):
def __init__(self, args, height, width):
super().__init__()
Expand All @@ -51,7 +52,7 @@ def __init__(self, args, height, width):
self.use_image = args.use_image
self.num_scales = args.num_scales

self.num_classes = dict(dsec=2).get(args.dataset, 2)
self.num_classes = dict(dsec=2, ncaltech101=100).get(args.dataset, 2)

self.events_to_graph = EV_TGN(args)

Expand All @@ -66,6 +67,7 @@ def __init__(self, args, height, width):
poolings = compute_pooling_at_each_layer(args.pooling_dim_at_output, num_layers=4)
max_vals_for_cartesian = 2*poolings[:,:2].max(-1).values
self.strides = torch.ceil(poolings[-2:,1] * height).numpy().astype("int32").tolist()
self.strides = self.strides[-self.num_scales:]

effective_radius = 2*float(int(args.radius * width + 2) / width)
self.edge_attrs = Cartesian(norm=True, cat=False, max_value=effective_radius)
Expand Down
1 change: 1 addition & 0 deletions src/dagr/utils/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def FLAGS():
# learning params
parser.add_argument('--aug_trans', default=argparse.SUPPRESS, type=float)
parser.add_argument('--aug_zoom', default=argparse.SUPPRESS, type=float)
parser.add_argument('--exp_name', default=argparse.SUPPRESS, type=str)
parser.add_argument('--l_r', default=argparse.SUPPRESS, type=float)
parser.add_argument('--no_eval', action="store_true")
parser.add_argument('--tot_num_epochs', default=argparse.SUPPRESS, type=int)
Expand Down
94 changes: 90 additions & 4 deletions src/dagr/utils/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,103 @@
from torch_geometric.data import Data


def set_up_logging_directory(dataset, task, output_directory):
class Checkpointer:
    """Saves and restores full training state (model, EMA, optimizer, scheduler).

    Best checkpoints are named ``best_model_mAP_<value>.pth`` so the achieved
    mAP can be recovered from the filename; the most recent state is always
    written to ``last_model.pth``.
    """

    def __init__(self, output_directory: Optional[Path] = None, args=None, optimizer=None, scheduler=None, ema=None, model=None):
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.ema = ema
        self.model = model

        # highest validation mAP seen so far; drives "best model" saving in process()
        self.mAP_max = 0
        self.output_directory = output_directory
        self.args = args

    def restore_if_existing(self, folder, resume_from_best=False):
        """Restore state from *folder* if a checkpoint exists there.

        Returns:
            The epoch stored in the restored checkpoint, or 0 when no
            checkpoint was found. (The original implementation returned
            None in all cases even though callers assigned the result to
            a start-epoch variable.)
        """
        checkpoint = self.search_for_checkpoint(folder, best=resume_from_best)
        if checkpoint is None:
            return 0
        print(f"Found existing checkpoint at {checkpoint}, resuming...")
        return self.restore_checkpoint(folder, best=resume_from_best)

    def mAP_from_checkpoint_name(self, checkpoint_name: Path):
        """Parse the mAP value out of a ``best_model_mAP_<value>.pth`` filename."""
        return float(str(checkpoint_name).split("_")[-1].split(".pth")[0])

    def search_for_checkpoint(self, resume_checkpoint: Path, best=False):
        """Find a checkpoint file in *resume_checkpoint*.

        With ``best=False`` prefer ``last_model.pth``; otherwise return the
        ``best_model_mAP_*.pth`` file with the highest mAP, or None if no
        suitable file exists.
        """
        checkpoints = list(resume_checkpoint.glob("*.pth"))
        if len(checkpoints) == 0:
            return None

        last = resume_checkpoint / "last_model.pth"
        if not best and last in checkpoints:
            return last

        # only rank the mAP-named checkpoints; last_model.pth has no score
        if last in checkpoints:
            checkpoints.remove(last)

        # fix: the original indexed sorted(...)[-1] here, which raised
        # IndexError when only last_model.pth existed and best=True
        if len(checkpoints) == 0:
            return None

        return max(checkpoints, key=lambda x: self.mAP_from_checkpoint_name(x.name))

    def restore_if_not_none(self, target, source):
        """Load *source* state-dict into *target* unless target is None."""
        if target is not None:
            target.load_state_dict(source)

    def restore_checkpoint(self, checkpoint_directory, best=False):
        """Restore model/EMA/optimizer/scheduler state from *checkpoint_directory*.

        Raises:
            AssertionError: if no checkpoint file is found.

        Returns:
            The epoch number stored in the checkpoint.
        """
        path = self.search_for_checkpoint(checkpoint_directory, best)
        assert path is not None, "No checkpoint found in {}".format(checkpoint_directory)
        print("Restoring checkpoint from {}".format(path))
        checkpoint = torch.load(path)

        checkpoint['model'] = self.fix_checkpoint(checkpoint['model'])
        checkpoint['ema'] = self.fix_checkpoint(checkpoint['ema'])

        if self.ema is not None:
            # fall back to the raw model weights if no EMA state was saved
            self.ema.ema.load_state_dict(checkpoint.get('ema', checkpoint['model']))
            self.ema.updates = checkpoint.get('ema_updates', 0)
        self.restore_if_not_none(self.model, checkpoint['model'])
        self.restore_if_not_none(self.optimizer, checkpoint['optimizer'])
        self.restore_if_not_none(self.scheduler, checkpoint['scheduler'])
        return checkpoint['epoch']

    def fix_checkpoint(self, state_dict):
        """Hook for remapping legacy state-dict keys; identity by default."""
        return state_dict

    def checkpoint(self, epoch: int, name: str=""):
        """Write the complete training state to ``<output_directory>/<name>.pth``."""
        self.output_directory.mkdir(exist_ok=True, parents=True)

        checkpoint = {
            "ema": self.ema.ema.state_dict(),
            "ema_updates": self.ema.updates,
            "model": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
            "scheduler": self.scheduler.state_dict(),
            "epoch": epoch,
            "args": self.args
        }

        torch.save(checkpoint, self.output_directory / f"{name}.pth")

    def process(self, data: Dict[str, float], epoch: int):
        """Log validation metrics to wandb and save a new best checkpoint if mAP improved."""
        mAP = data['mAP']
        data = {f"validation/metric/{k}": v for k, v in data.items()}
        data['epoch'] = epoch
        wandb.log(data)

        if mAP > self.mAP_max:
            self.checkpoint(epoch, name=f"best_model_mAP_{mAP}")
            self.mAP_max = mAP


def set_up_logging_directory(dataset, task, output_directory, exp_name="temp"):
    """Initialise wandb and create a per-run logging directory.

    NOTE(review): the scraped diff interleaved the old and new versions of two
    lines here (duplicate wandb.init calls and duplicate name assignments);
    this body keeps only the new versions, per the commit's added lines.

    Args:
        dataset: dataset name, used in the wandb project name and path.
        task: task name (e.g. "detection"), used likewise.
        output_directory: root Path under which run directories are created.
        exp_name: wandb run id; also becomes the run-directory name.

    Returns:
        Path of the created run directory ``<root>/<dataset>/<task>/<run_id>``.
    """
    project = f"low_latency-{dataset}-{task}"

    output_directory = output_directory / dataset / task
    output_directory.mkdir(parents=True, exist_ok=True)
    wandb.init(project=project, id=exp_name, entity="danielgehrig18", save_code=True, dir=str(output_directory))

    # run subdirectory is named after the wandb run id
    name = wandb.run.id
    output_directory = output_directory / name
    output_directory.mkdir(parents=True, exist_ok=True)
    # snapshot the source tree next to the logs for reproducibility
    os.system(f"cp -r {os.path.join(os.path.dirname(__file__), '../../low_latency_object_detection')} {str(output_directory)}")

    return output_directory

Expand Down

0 comments on commit 3c9c717

Please sign in to comment.