diff --git a/.gitignore b/.gitignore
index 79c564d..cf67214 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,7 +1,6 @@
 *.so
 *.egg-info/
 *.pyc
-data/
 build/
 libs/
-.idea/
\ No newline at end of file
+.idea/
diff --git a/config/dagr-l-ncaltech.yaml b/config/dagr-l-ncaltech.yaml
new file mode 100644
index 0000000..b2ab882
--- /dev/null
+++ b/config/dagr-l-ncaltech.yaml
@@ -0,0 +1,36 @@
+path: "/data/storage/daniel/aegnn"
+output_directory: "/data/storage/daniel/aegnn/logs"
+pooling_dim_at_output: 5x7
+
+task: detection
+dataset: ncaltech101
+
+# network
+radius: 0.01
+time_window_us: 1000000
+max_neighbors: 16
+n_nodes: 50000
+
+batch_size: 64
+
+activation: relu
+edge_attr_dim: 2
+aggr: sum
+kernel_size: 5
+pooling_aggr: max
+
+base_width: 0.5
+after_pool_width: 1
+net_stem_width: 1
+yolo_stem_width: 1
+num_scales: 1
+
+# learning
+weight_decay: 0.00001
+clip: 0.1
+
+aug_trans: 0.1
+aug_p_flip: 0
+aug_zoom: 1
+l_r: 0.001
+tot_num_epochs: 801
\ No newline at end of file
diff --git a/readme.md b/readme.md
index c5c2174..ac934fc 100644
--- a/readme.md
+++ b/readme.md
@@ -147,3 +147,34 @@ python scripts/visualize_detections.py --detections_folder $LOG_DIR/$WANDB_DIR \
 ```
 This will start a visualization window showing the detections over a given sequence. If you want to save the detections to a video, use the `--write_to_output` flag, which will create a video in the folder `$LOG_DIR/$WANDB_DIR/visualization`.
+
+## Training
+To train on N-Caltech101, download the files with
+
+```bash
+wget https://download.ifi.uzh.ch/rpg/dagr/data/ncaltech101.zip -P $DAGR_DIR/data/
+cd $DAGR_DIR/data/
+unzip ncaltech101.zip
+rm -rf ncaltech101.zip
+```
+
+Then run training with
+
+```bash
+python scripts/train_ncaltech101.py --config config/dagr-l-ncaltech.yaml \
+                                    --exp_name ncaltech_l \
+                                    --dataset_directory $DAGR_DIR/data/ \
+                                    --output_directory $DAGR_DIR/logs/
+```
+To train on DSEC, make a symlink to the data directory via
+```bash
+ln -s $DSEC_ROOT $DAGR_DIR/data/dsec
+```
+Then run training with
+```bash
+python scripts/train_dsec.py --config config/dagr-s-dsec.yaml \
+                             --exp_name dsec_s_50 \
+                             --dataset_directory $DAGR_DIR/data/ \
+                             --output_directory $DAGR_DIR/logs/ \
+                             --use_image --img_net resnet50 --batch_size 32
+```
diff --git a/scripts/train_dsec.py b/scripts/train_dsec.py
new file mode 100644
index 0000000..301d208
--- /dev/null
+++ b/scripts/train_dsec.py
@@ -0,0 +1,184 @@
# avoid matlab error on server
import os
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'

import torch
import tqdm
import wandb
from pathlib import Path
import argparse

from torch_geometric.data import DataLoader

from dagr.utils.logging import Checkpointer, set_up_logging_directory, log_hparams
from dagr.utils.buffers import DetectionBuffer
from dagr.utils.args import FLAGS
from dagr.utils.learning_rate_scheduler import LRSchedule

from dagr.data.augment import Augmentations
from dagr.utils.buffers import format_data
from dagr.data.dsec_data import DSEC

from dagr.model.networks.dagr import DAGR
from dagr.model.networks.ema import ModelEMA


def gradients_broken(model):
    valid_gradients = True
    for name, param in model.named_parameters():
        if param.grad is not None:
            # valid_gradients = not (torch.isnan(param.grad).any() or torch.isinf(param.grad).any())
            valid_gradients = not (torch.isnan(param.grad).any())
            if not valid_gradients:
                break
    return not valid_gradients

def fix_gradients(model):
    for name, param in model.named_parameters():
        if param.grad is not None:
            param.grad = torch.nan_to_num(param.grad, nan=0.0)
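            # nan=0.0 zeroes out any NaN entries in the gradient, so the
            # following optimizer.step() cannot propagate NaNs into the weights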


def train(loader: DataLoader,
          model: torch.nn.Module,
          ema: ModelEMA,
          scheduler: torch.optim.lr_scheduler.LambdaLR,
          optimizer: torch.optim.Optimizer,
          args: argparse.Namespace,
          run_name=""):

    model.train()

    for i, data in enumerate(tqdm.tqdm(loader, desc=f"Training {run_name}")):
        data = data.cuda(non_blocking=True)
        data = format_data(data)

        optimizer.zero_grad(set_to_none=True)

        model_outputs = model(data)

        loss_dict = {k: v for k, v in model_outputs.items() if "loss" in k}
        loss = loss_dict.pop("total_loss")

        loss.backward()

        torch.nn.utils.clip_grad_value_(model.parameters(), args.clip)

        fix_gradients(model)

        optimizer.step()
        scheduler.step()

        ema.update(model)

        training_logs = {f"training/loss/{k}": v for k, v in loss_dict.items()}
        wandb.log({"training/loss": loss.item(), "training/lr": scheduler.get_last_lr()[-1], **training_logs})

def run_test(loader: DataLoader,
             model: torch.nn.Module,
             dry_run_steps: int=-1,
             dataset="gen1"):

    model.eval()

    mapcalc = DetectionBuffer(height=loader.dataset.height, width=loader.dataset.width, classes=loader.dataset.classes)

    for i, data in enumerate(tqdm.tqdm(loader)):
        data = data.cuda()
        data = format_data(data)

        detections, targets = model(data)
        if i % 10 == 0:
            torch.cuda.empty_cache()

        mapcalc.update(detections, targets, dataset, data.height[0], data.width[0])

        if dry_run_steps > 0 and i == dry_run_steps:
            break

    torch.cuda.empty_cache()

    return mapcalc

if __name__ == '__main__':
    import torch_geometric
    import random
    import numpy as np

    seed = 42
    torch_geometric.seed.seed_everything(seed)
    torch.random.manual_seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    args = FLAGS()

    output_directory = set_up_logging_directory(args.dataset, args.task, args.output_directory, exp_name=args.exp_name)
    log_hparams(args)

    augmentations = Augmentations(args)

    print("init datasets")
    dataset_path = args.dataset_directory / args.dataset

    train_dataset = DSEC(root=dataset_path, split="train", transform=augmentations.transform_training, debug=False,
                         min_bbox_diag=15, min_bbox_height=10)
    test_dataset = DSEC(root=dataset_path, split="val", transform=augmentations.transform_testing, debug=False,
                        min_bbox_diag=15, min_bbox_height=10)

    train_loader = DataLoader(train_dataset, follow_batch=['bbox', 'bbox0'], batch_size=args.batch_size, shuffle=True, num_workers=4, drop_last=True)
    num_iters_per_epoch = len(train_loader)

    sampler = np.random.permutation(np.arange(len(test_dataset)))
    test_loader = DataLoader(test_dataset, sampler=sampler, follow_batch=['bbox', 'bbox0'], batch_size=args.batch_size, shuffle=False, num_workers=4, drop_last=True)

    print("init net")
    # load a dummy sample to get height, width
    model = DAGR(args, height=test_dataset.height, width=test_dataset.width)

    num_params = sum([np.prod(p.size()) for p in model.parameters()])
    print(f"Training with {num_params} parameters.")

    model = model.cuda()
    ema = ModelEMA(model)

    nominal_batch_size = 64
    lr = args.l_r * np.sqrt(args.batch_size) / np.sqrt(nominal_batch_size)
    optimizer = torch.optim.AdamW(list(model.parameters()), lr=lr, weight_decay=args.weight_decay)

    lr_func = LRSchedule(warmup_epochs=.3,
                         num_iters_per_epoch=num_iters_per_epoch,
                         tot_num_epochs=args.tot_num_epochs)

    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=lr_func)
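    # The learning rate above is scaled with the square root of the batch size,
    # relative to the nominal batch size of 64: e.g. --batch_size 32 turns
    # l_r=0.001 into 0.001 * sqrt(32/64) ≈ 7.1e-4 before the warmup schedule.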

    checkpointer = Checkpointer(output_directory=output_directory,
                                model=model, optimizer=optimizer,
                                scheduler=lr_scheduler, ema=ema,
                                args=args)

    checkpointer.restore_if_existing(output_directory, resume_from_best=False)

    start_epoch = 0
    if "resume_checkpoint" in args:
        start_epoch = checkpointer.restore_checkpoint(args.resume_checkpoint, best=False)
        print(f"Resume from checkpoint at epoch {start_epoch}")

    with torch.no_grad():
        mapcalc = run_test(test_loader, ema.ema, dry_run_steps=2, dataset=args.dataset)
        mapcalc.compute()

    print("starting to train")
    for epoch in range(start_epoch, args.tot_num_epochs):
        train(train_loader, model, ema, lr_scheduler, optimizer, args, run_name=wandb.run.name)
        checkpointer.checkpoint(epoch, name="last_model")

        if epoch % 3 > 0:
            continue

        with torch.no_grad():
            mapcalc = run_test(test_loader, ema.ema, dataset=args.dataset)
            metrics = mapcalc.compute()
            checkpointer.process(metrics, epoch)
diff --git a/scripts/train_ncaltech101.py b/scripts/train_ncaltech101.py
new file mode 100644
index 0000000..2fb6505
--- /dev/null
+++ b/scripts/train_ncaltech101.py
@@ -0,0 +1,182 @@
# avoid matlab error on server
import os
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'

import torch
import tqdm
import wandb
from pathlib import Path
import argparse

from torch_geometric.data import DataLoader

from dagr.utils.logging import Checkpointer, set_up_logging_directory, log_hparams
from dagr.utils.buffers import DetectionBuffer
from dagr.utils.args import FLAGS
from dagr.utils.learning_rate_scheduler import LRSchedule

from dagr.data.augment import Augmentations
from dagr.utils.buffers import format_data
from dagr.data.ncaltech101_data import NCaltech101

from dagr.model.networks.dagr import DAGR
from dagr.model.networks.ema import ModelEMA

def gradients_broken(model):
    valid_gradients = True
    for name, param in model.named_parameters():
        if param.grad is not None:
            # valid_gradients = not (torch.isnan(param.grad).any() or torch.isinf(param.grad).any())
            valid_gradients = not (torch.isnan(param.grad).any())
            if not valid_gradients:
                break
    return not valid_gradients

def fix_gradients(model):
    for name, param in model.named_parameters():
        if param.grad is not None:
            param.grad = torch.nan_to_num(param.grad, nan=0.0)


def train(loader: DataLoader,
          model: torch.nn.Module,
          ema: ModelEMA,
          scheduler: torch.optim.lr_scheduler.LambdaLR,
          optimizer: torch.optim.Optimizer,
          args: argparse.Namespace,
          run_name=""):

    model.train()

    for i, data in enumerate(tqdm.tqdm(loader, desc=f"Training {run_name}")):
        data = data.cuda(non_blocking=True)
        data = format_data(data)

        optimizer.zero_grad(set_to_none=True)

        model_outputs = model(data)

        loss_dict = {k: v for k, v in model_outputs.items() if "loss" in k}
        loss = loss_dict.pop("total_loss")

        loss.backward()

        torch.nn.utils.clip_grad_value_(model.parameters(), args.clip)

        fix_gradients(model)

        optimizer.step()
        scheduler.step()

        ema.update(model)

        training_logs = {f"training/loss/{k}": v for k, v in loss_dict.items()}
        wandb.log({"training/loss": loss.item(), "training/lr": scheduler.get_last_lr()[-1], **training_logs})

def run_test(loader: DataLoader,
             model: torch.nn.Module,
             dry_run_steps: int=-1,
             dataset="gen1"):

    model.eval()

    mapcalc = DetectionBuffer(height=loader.dataset.height, width=loader.dataset.width, classes=loader.dataset.classes)

    for i, data in enumerate(tqdm.tqdm(loader)):
        data = data.cuda()
        data = format_data(data)

        detections, targets = model(data)
        if i % 10 == 0:
            torch.cuda.empty_cache()

        mapcalc.update(detections, targets, dataset, data.height[0], data.width[0])

        if dry_run_steps > 0 and i == dry_run_steps:
            break

    torch.cuda.empty_cache()

    return mapcalc

if __name__ == '__main__':
    import torch_geometric
    import random
    import numpy as np

    seed = 42
    torch_geometric.seed.seed_everything(seed)
    torch.random.manual_seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    args = FLAGS()

    output_directory = set_up_logging_directory(args.dataset, args.task, args.output_directory, exp_name=args.exp_name)
    log_hparams(args)

    augmentations = Augmentations(args)

    print("init datasets")
    dataset_path = args.dataset_directory / args.dataset

    train_dataset = NCaltech101(dataset_path, "training", augmentations.transform_training, num_events=args.n_nodes)
    test_dataset = NCaltech101(dataset_path, "validation", augmentations.transform_testing, num_events=args.n_nodes)

    train_loader = DataLoader(train_dataset, follow_batch=['bbox', 'bbox0'], batch_size=args.batch_size, shuffle=True, num_workers=4, drop_last=True)
    num_iters_per_epoch = len(train_loader)

    sampler = np.random.permutation(np.arange(len(test_dataset)))
    test_loader = DataLoader(test_dataset, sampler=sampler, follow_batch=['bbox', 'bbox0'], batch_size=args.batch_size, shuffle=False, num_workers=4, drop_last=True)

    print("init net")
    # load a dummy sample to get height, width
    model = DAGR(args, height=test_dataset.height, width=test_dataset.width)

    num_params = sum([np.prod(p.size()) for p in model.parameters()])
    print(f"Training with {num_params} parameters.")

    model = model.cuda()
    ema = ModelEMA(model)

    nominal_batch_size = 64
    lr = args.l_r * np.sqrt(args.batch_size) / np.sqrt(nominal_batch_size)
    optimizer = torch.optim.AdamW(list(model.parameters()), lr=lr, weight_decay=args.weight_decay)

    lr_func = LRSchedule(warmup_epochs=.3,
                         num_iters_per_epoch=num_iters_per_epoch,
                         tot_num_epochs=args.tot_num_epochs)

    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=lr_func)

    checkpointer = Checkpointer(output_directory=output_directory,
                                model=model, optimizer=optimizer,
                                scheduler=lr_scheduler, ema=ema,
                                args=args)

    checkpointer.restore_if_existing(output_directory, resume_from_best=False)

    start_epoch = 0
    if "resume_checkpoint" in args:
        start_epoch = checkpointer.restore_checkpoint(args.resume_checkpoint, best=False)
        print(f"Resume from checkpoint at epoch {start_epoch}")

    with torch.no_grad():
        mapcalc = run_test(test_loader, ema.ema, dry_run_steps=2, dataset=args.dataset)
        mapcalc.compute()

    print("starting to train")
    for epoch in range(start_epoch, args.tot_num_epochs):
        train(train_loader, model, ema, lr_scheduler, optimizer, args, run_name=wandb.run.name)
        checkpointer.checkpoint(epoch, name="last_model")

        if epoch % 3 > 0:
            continue

        with torch.no_grad():
            mapcalc = run_test(test_loader, ema.ema, dataset=args.dataset)
            metrics = mapcalc.compute()
            checkpointer.process(metrics, epoch)
diff --git a/src/dagr/data/augment.py b/src/dagr/data/augment.py
index 96630aa..ab56a1e 100644
--- a/src/dagr/data/augment.py
+++ b/src/dagr/data/augment.py
@@ -58,7 +58,11 @@ def _crop_image(image, left, right):
     return image
 
 def _resize_image(image, height, width, bg=None):
+
+    image = image[0].permute(1, 2, 0).numpy()
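+    # (1, C, H, W) torch tensor -> (H, W, C) numpy array for cv2.resize below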
 new_image = cv2.resize(image, (width, height), interpolation=cv2.INTER_NEAREST)
+
 px = (new_image.shape[1] - image.shape[1])//2
 py = (new_image.shape[0] - image.shape[0])//2
@@ -68,6 +71,8 @@ def _resize_image(image, height, width, bg=None):
         assert bg is not None
         bg[-py:-py+new_image.shape[0], -px:-px+new_image.shape[1]] = new_image
 
+    bg = torch.from_numpy(bg).permute(2, 0, 1)[None]
+
     return bg
 
 def _crop_bbox(bbox: torch.Tensor, left: torch.Tensor, right: torch.Tensor):
@@ -93,7 +98,10 @@ def __call__(self, data: Data):
         data.pos[:,0] = data.width - 1 - data.pos[:,0]
 
         if hasattr(data, "image"):
-            data.image = np.ascontiguousarray(data.image[:,::-1])
+            image = data.image[0].permute(1,2,0).numpy()
+            image = np.ascontiguousarray(image[:,::-1])
+            image = torch.from_numpy(image).permute(2, 0, 1)[None]
+            data.image = image
 
         if hasattr(data, "bbox"):
             data.bbox[:, 0] = data.width - 1 - (data.bbox[:, 0] + data.bbox[:, 2])
@@ -256,9 +264,11 @@ def __call__(self, data: Data):
             data.pos = data.pos + move_px
 
         if hasattr(data, "image"):
-            image = self.pad(data.image, self.image.copy())
-            data.image = image[self.size[1]-move_px[1]:self.size[1]-move_px[1]+data.height, \
-                               self.size[0]-move_px[0]:self.size[0]-move_px[0]+data.width]
+            image = data.image[0].permute(1, 2, 0).numpy()
+            image = self.pad(image, self.image.copy())
+            image = image[self.size[1]-move_px[1]:self.size[1]-move_px[1]+data.height, \
+                          self.size[0]-move_px[0]:self.size[0]-move_px[0]+data.width]
+            data.image = torch.from_numpy(image).permute(2, 0, 1)[None]
 
         if hasattr(data, "bbox"):
             data.bbox[:,:2] += move_px
diff --git a/src/dagr/data/dsec_data.py b/src/dagr/data/dsec_data.py
index 60233bd..ef7a933 100644
--- a/src/dagr/data/dsec_data.py
+++ b/src/dagr/data/dsec_data.py
@@ -95,7 +95,7 @@ def __init__(self,
         self.class_remapping = compute_class_mapping(self.classes, self.dataset.classes, self.MAPPING)
 
         if transform is not None and hasattr(transform, "transforms"):
-            init_transforms(transform.transforms, self.height, self.dataset.width)
+            init_transforms(transform.transforms, self.height, self.width)
 
         self.transform = transform
         self.no_eval = no_eval
diff --git a/src/dagr/data/ncaltech101_data.py b/src/dagr/data/ncaltech101_data.py
new file mode 100644
index 0000000..95c0bdd
--- /dev/null
+++ b/src/dagr/data/ncaltech101_data.py
@@ -0,0 +1,84 @@
+import numpy as np
+import torch
+import hdf5plugin
+import h5py
+
+from pathlib import Path
+from typing import Optional, Callable
+from torch.utils.data import Dataset
+from torch_geometric.data import Data
+from dagr.data.augment import init_transforms
+from dagr.data.utils import to_data
+
+
+class NCaltech101(Dataset):
+
+    def __init__(self, root: Path, split, transform: Optional[Callable[[Data], Data]] = None, num_events: int = 50000):
+        super().__init__()
+        self.load_dir = root / split
+        self.classes = sorted([d.name for d in self.load_dir.glob("*")])
+        self.num_classes = len(self.classes)
+        self.files = sorted(list(self.load_dir.rglob("*.h5")))
+        self.height = 180
+        self.width = 240
+        if transform is not None and hasattr(transform, "transforms"):
+            init_transforms(transform.transforms, self.height, self.width)
+        self.transform = transform
+        self.time_window = 1000000
+        self.num_events = num_events
+
+    def __len__(self):
+        return len(self.files)
+
+    def preprocess(self, data):
+        data.t -= (data.t[-1] - self.time_window + 1)
+        return data
+
+    def load_events(self, f_path):
+        return _load_events(f_path, self.num_events)
+
+    def __getitem__(self, idx):
+        f_path = self.files[idx]
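+        # N-Caltech101 stores one sub-directory per category, so the parent
+        # folder name of each recording doubles as its class label below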
+        target = self.classes.index(str(f_path.parent.name))
+
+        events = self.load_events(f_path)
+        data = to_data(**events, bbox=self.load_bboxes(f_path, target),
+                       t0=events['t'][0], t1=events['t'], width=self.width, height=self.height,
+                       time_window=self.time_window)
+
+        data = self.preprocess(data)
+
+        data = self.transform(data) if self.transform is not None else data
+
+        if not hasattr(data, "t"):
+            data.t = data.pos[:, -1:]
+            data.pos = data.pos[:, :2].type(torch.int16)
+
+        return data
+
+    def load_bboxes(self, raw_file: Path, class_id):
+        rel_path = str(raw_file.relative_to(self.load_dir))
+        rel_path = rel_path.replace("image_", "annotation_").replace(".h5", ".bin")
+        annotation_file = self.load_dir / "../annotations" / rel_path
+        with annotation_file.open("rb") as fh:
+            annotations = np.fromfile(fh, dtype=np.int16)
+            annotations = np.array(annotations[2:10])
+
+        return np.array([
+            annotations[0], annotations[1],   # upper-left corner
+            annotations[2] - annotations[0],  # width
+            annotations[5] - annotations[1],  # height
+            class_id,
+            1
+        ]).astype("float32").reshape((1,-1))
+
+def _load_events(f_path, num_events):
+    with h5py.File(str(f_path)) as fh:
+        fh = fh['events']
+        x = fh["x"][-num_events:]
+        y = fh["y"][-num_events:]
+        t = fh["t"][-num_events:]
+        p = fh["p"][-num_events:]
+    return dict(x=x, y=y, t=t, p=p)
diff --git a/src/dagr/model/layers/components.py b/src/dagr/model/layers/components.py
index 99a6547..3559ccc 100644
--- a/src/dagr/model/layers/components.py
+++ b/src/dagr/model/layers/components.py
@@ -29,7 +29,7 @@ def __init__(self, *args, **kwargs):
 
     def forward(self, data):
         if data.edge_index.shape[1] > 0:
-            return T.Cartesian.forward(self, data)
+            return T.Cartesian.__call__(self, data)
         else:
             data.edge_attr = torch.zeros((0, 3), dtype=data.x.dtype, device=data.x.device)
-            return data
\ No newline at end of file
+            return data
diff --git a/src/dagr/model/networks/dagr.py b/src/dagr/model/networks/dagr.py
index bd36057..45f8eaf 100644
--- a/src/dagr/model/networks/dagr.py
+++ b/src/dagr/model/networks/dagr.py
@@ -8,7 +8,7 @@ from dagr.model.networks.net import Net
 from dagr.model.layers.spline_conv import SplineConvToDense
 from dagr.model.layers.conv import ConvBlock
-from dagr.model.utils import shallow_copy, init_subnetwork, voxel_size_to_params, postprocess_network_output, convert_to_evaluation_format, init_grid_and_stride
+from dagr.model.utils import shallow_copy, init_subnetwork, voxel_size_to_params, postprocess_network_output, convert_to_evaluation_format, init_grid_and_stride, convert_to_training_format
 
 class DAGR(YOLOX):
@@ -24,6 +24,7 @@ def __init__(self, args, height, width):
                         in_channels=backbone.out_channels,
                         in_channels_cnn=backbone.out_channels_cnn,
                         strides=backbone.strides,
+                        pretrain_cnn=args.pretrain_cnn,
                         args=args)
 
         super().__init__(backbone=backbone, head=head)
@@ -75,10 +76,10 @@ def forward(self, x: Data, reset=True, return_targets=True, filtering=True):
             self.head.output_sizes = self.backbone.get_output_sizes()
 
         if self.training:
-            targets = self.convert_to_training_format(x.bbox, x.bbox_batch, x.num_graphs)
+            targets = convert_to_training_format(x.bbox, x.bbox_batch, x.num_graphs)
 
             if self.backbone.use_image:
-                targets0 = self.convert_to_training_format(x.bbox0, x.bbox0_batch, x.num_graphs)
+                targets0 = convert_to_training_format(x.bbox0, x.bbox0_batch, x.num_graphs)
                 targets = (targets, targets0)
 
             # gt_target inputs need to be [l cx cy w h] in pixels
@@ -130,10 +131,12 @@ def __init__(
         in_channels_cnn=[256, 512, 1024],
         act="silu",
         depthwise=False,
+        pretrain_cnn=False,
         args=None
     ):
         YOLOXHead.__init__(self, num_classes, args.yolo_stem_width, strides, in_channels, act, depthwise)
 
+        self.pretrain_cnn = pretrain_cnn
         self.num_scales = args.num_scales
         self.use_image = args.use_image
         self.batch_size = args.batch_size
@@ -236,7 +239,7 @@ def forward(self, xin: Data, labels=None, imgs=None):
         # if we are only training the image detectors (pretraining),
         # we only need to minimize the loss at detections from the image branch.
         if self.use_image:
-            return self.get_losses(
+            losses_image = self.get_losses(
                 imgs,
                 image_out['x_shifts'],
                 image_out['y_shifts'],
                 image_out['expanded_strides'],
                 labels[1],
                 torch.cat(image_out['outputs'], 1),
                 image_out['origin_preds'],
                 dtype=image_out['x_shifts'][0].dtype,
             )
+
+            if not self.pretrain_cnn:
+                losses_events = self.get_losses(
+                    imgs,
+                    hybrid_out['x_shifts'],
+                    hybrid_out['y_shifts'],
+                    hybrid_out['expanded_strides'],
+                    labels,
+                    torch.cat(hybrid_out['outputs'], 1),
+                    hybrid_out['origin_preds'],
+                    dtype=xin[0].x.dtype,
+                )
+
+                losses_image = list(losses_image)
+                losses_events = list(losses_events)
+
+                for i in range(5):
+                    losses_image[i] = losses_image[i] + losses_events[i]
+
+            return losses_image
         else:
             return self.get_losses(
                 imgs,
diff --git a/src/dagr/model/networks/net.py b/src/dagr/model/networks/net.py
index 4d0b3d1..bb9965e 100644
--- a/src/dagr/model/networks/net.py
+++ b/src/dagr/model/networks/net.py
@@ -27,6 +27,7 @@ def compute_pooling_at_each_layer(pooling_dim_at_output, num_layers):
     poolings = torch.stack(poolings)
     return poolings
 
+
 class Net(torch.nn.Module):
     def __init__(self, args, height, width):
         super().__init__()
@@ -51,7 +52,7 @@ def __init__(self, args, height, width):
         self.use_image = args.use_image
         self.num_scales = args.num_scales
 
-        self.num_classes = dict(dsec=2).get(args.dataset, 2)
+        self.num_classes = dict(dsec=2, ncaltech101=100).get(args.dataset, 2)
 
         self.events_to_graph = EV_TGN(args)
@@ -66,6 +67,7 @@ def __init__(self, args, height, width):
         poolings = compute_pooling_at_each_layer(args.pooling_dim_at_output, num_layers=4)
         max_vals_for_cartesian = 2*poolings[:,:2].max(-1).values
         self.strides = torch.ceil(poolings[-2:,1] * height).numpy().astype("int32").tolist()
+        self.strides = self.strides[-self.num_scales:]
 
         effective_radius = 2*float(int(args.radius * width + 2) / width)
         self.edge_attrs = Cartesian(norm=True, cat=False, max_value=effective_radius)
diff --git a/src/dagr/utils/args.py b/src/dagr/utils/args.py
index fdb7773..468d727 100644
--- a/src/dagr/utils/args.py
+++ b/src/dagr/utils/args.py
@@ -15,6 +15,7 @@ def BASE_FLAGS():
     parser.add_argument("--config", type=Path, default="../config/detection.yaml")
     parser.add_argument("--use_image", action="store_true")
     parser.add_argument("--no_events", action="store_true")
+    parser.add_argument("--pretrain_cnn", action="store_true")
     parser.add_argument("--keep_temporal_ordering", action="store_true")
 
     # task params
@@ -56,6 +57,7 @@ def FLAGS():
     # learning params
     parser.add_argument('--aug_trans', default=argparse.SUPPRESS, type=float)
     parser.add_argument('--aug_zoom', default=argparse.SUPPRESS, type=float)
+    parser.add_argument('--exp_name', default=argparse.SUPPRESS, type=str)
     parser.add_argument('--l_r', default=argparse.SUPPRESS, type=float)
     parser.add_argument('--no_eval', action="store_true")
     parser.add_argument('--tot_num_epochs', default=argparse.SUPPRESS, type=int)
diff --git a/src/dagr/utils/logging.py b/src/dagr/utils/logging.py
index 5d30bd7..cc365d9 100644
--- a/src/dagr/utils/logging.py
+++ b/src/dagr/utils/logging.py
@@ -11,17 +11,104 @@
 from torch_geometric.data import Data
+from typing import Dict, Optional
 
-def set_up_logging_directory(dataset, task, output_directory):
+class Checkpointer:
+    def __init__(self, output_directory: Optional[Path] = None, args=None, optimizer=None, scheduler=None, ema=None, model=None):
+        self.optimizer = optimizer
+        self.scheduler = scheduler
+        self.ema = ema
+        self.model = model
+
+        self.mAP_max = 0
+        self.output_directory = output_directory
+        self.args = args
+
+    def restore_if_existing(self, folder, resume_from_best=False):
+        checkpoint = self.search_for_checkpoint(folder, best=resume_from_best)
+        if checkpoint is not None:
+            print(f"Found existing checkpoint at {checkpoint}, resuming...")
+            self.restore_checkpoint(folder, best=resume_from_best)
+
+    def mAP_from_checkpoint_name(self, checkpoint_name: Path):
+        return float(str(checkpoint_name).split("_")[-1].split(".pth")[0])
+
+    def search_for_checkpoint(self, resume_checkpoint: Path, best=False):
+        checkpoints = list(resume_checkpoint.glob("*.pth"))
+        if len(checkpoints) == 0:
+            return None
+
+        if not best:
+            if resume_checkpoint / "last_model.pth" in checkpoints:
+                return resume_checkpoint / "last_model.pth"
+
+        # remove "last_model.pth" from checkpoints
+        if resume_checkpoint / "last_model.pth" in checkpoints:
+            checkpoints.remove(resume_checkpoint / "last_model.pth")
+
+        checkpoints = sorted(checkpoints, key=lambda x: self.mAP_from_checkpoint_name(x.name))
+        return checkpoints[-1]
+
+    def restore_if_not_none(self, target, source):
+        if target is not None:
+            target.load_state_dict(source)
+
+    def restore_checkpoint(self, checkpoint_directory, best=False):
+        path = self.search_for_checkpoint(checkpoint_directory, best)
+        assert path is not None, "No checkpoint found in {}".format(checkpoint_directory)
+        print("Restoring checkpoint from {}".format(path))
+        checkpoint = torch.load(path)
+
+        checkpoint['model'] = self.fix_checkpoint(checkpoint['model'])
+        checkpoint['ema'] = self.fix_checkpoint(checkpoint['ema'])
+
+        if self.ema is not None:
+            self.ema.ema.load_state_dict(checkpoint.get('ema', checkpoint['model']))
+            self.ema.updates = checkpoint.get('ema_updates', 0)
+        self.restore_if_not_none(self.model, checkpoint['model'])
+        self.restore_if_not_none(self.optimizer, checkpoint['optimizer'])
+        self.restore_if_not_none(self.scheduler, checkpoint['scheduler'])
+        return checkpoint['epoch']
+
+    def fix_checkpoint(self, state_dict):
+        return state_dict
+
+    def checkpoint(self, epoch: int, name: str=""):
+        self.output_directory.mkdir(exist_ok=True, parents=True)
+
+        checkpoint = {
+            "ema": self.ema.ema.state_dict(),
+            "ema_updates": self.ema.updates,
+            "model": self.model.state_dict(),
+            "optimizer": self.optimizer.state_dict(),
+            "scheduler": self.scheduler.state_dict(),
+            "epoch": epoch,
+            "args": self.args
+        }
+
+        torch.save(checkpoint, self.output_directory / f"{name}.pth")
+
+    def process(self, data: Dict[str, float], epoch: int):
+        mAP = data['mAP']
+        data = {f"validation/metric/{k}": v for k, v in data.items()}
+        data['epoch'] = epoch
+        wandb.log(data)
+
+        if mAP > self.mAP_max:
+            self.checkpoint(epoch, name=f"best_model_mAP_{mAP}")
+            self.mAP_max = mAP
+
+
+def set_up_logging_directory(dataset, task, output_directory, exp_name="temp"):
     project = f"low_latency-{dataset}-{task}"
     output_directory = output_directory / dataset / task
     output_directory.mkdir(parents=True, exist_ok=True)
-    wandb.init(project=project, entity="rpg", save_code=True, dir=str(output_directory))
+    wandb.init(project=project, id=exp_name, entity="danielgehrig18", save_code=True,
+               dir=str(output_directory))
 
-    name = wandb.run.name
+    name = wandb.run.id
     output_directory = output_directory / name
     output_directory.mkdir(parents=True, exist_ok=True)
-    os.system(f"cp -r {os.path.join(os.path.dirname(__file__), '../../low_latency_object_detection')} {str(output_directory)}")
 
     return output_directory
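
For reference, the `Checkpointer` introduced in `src/dagr/utils/logging.py` encodes the validation mAP directly into the filename of each best checkpoint (`best_model_mAP_<mAP>.pth`) and later recovers it by parsing the name. A minimal, self-contained sketch of that round trip (the file names below are hypothetical):

```python
# Hypothetical contents of an output directory after a few validation rounds.
names = ["last_model.pth", "best_model_mAP_0.31.pth", "best_model_mAP_0.42.pth"]

def map_from_name(name: str) -> float:
    # Mirrors Checkpointer.mAP_from_checkpoint_name: the score is the last
    # underscore-separated token, with the ".pth" suffix stripped.
    return float(name.split("_")[-1].split(".pth")[0])

# search_for_checkpoint(best=True) first drops last_model.pth (its name carries
# no score) and then returns the checkpoint with the highest encoded mAP.
best = max((n for n in names if n != "last_model.pth"), key=map_from_name)
assert best == "best_model_mAP_0.42.pth"
```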