diff --git a/README.md b/README.md deleted file mode 100644 index b85b709..0000000 --- a/README.md +++ /dev/null @@ -1,2 +0,0 @@ -# dagr -Code for the paper "Low Latency Automotive Vision with Event Cameras", published in Nature diff --git a/config/dagr-l-dsec.yaml b/config/dagr-l-dsec.yaml new file mode 100644 index 0000000..a032e2d --- /dev/null +++ b/config/dagr-l-dsec.yaml @@ -0,0 +1,73 @@ +dataset_directory: "/data/storage/daniel/aegnn/" +output_directory: "/data/storage/daniel/aegnn/logs" + +task: detection +dataset: dsec + +# network +radius: 0.01 +time_window_us: 1000000 +max_neighbors: 16 +n_nodes: 50000 + +batch_size: 64 + +activation: relu +edge_attr_dim: 2 +aggr: sum +kernel_size: 5 +pooling_aggr: max + +base_width: 0.5 +after_pool_width: 1 +net_stem_width: 1 +yolo_stem_width: 1 +num_scales: 2 + +# learning +weight_decay: 0.00001 +clip: 0.1 + +pooling_dim_at_output: 5x7 + +aug_trans: 0.1 +aug_zoom: 1.5 +aug_p_flip: 0.5 + +img_net: resnet18 + +l_r: 0.0002 +tot_num_epochs: 801 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/config/dagr-m-dsec.yaml b/config/dagr-m-dsec.yaml new file mode 100644 index 0000000..cfca265 --- /dev/null +++ b/config/dagr-m-dsec.yaml @@ -0,0 +1,73 @@ +dataset_directory: "/data/storage/daniel/aegnn/" +output_directory: "/data/storage/daniel/aegnn/logs" + +task: detection +dataset: dsec + +# network +radius: 0.01 +time_window_us: 1000000 +max_neighbors: 16 +n_nodes: 50000 + +batch_size: 64 + +activation: relu +edge_attr_dim: 2 +aggr: sum +kernel_size: 5 +pooling_aggr: max + +base_width: 0.5 +after_pool_width: 1 +net_stem_width: 0.75 +yolo_stem_width: 0.75 +num_scales: 2 + +# learning +weight_decay: 0.00001 +clip: 0.1 + +pooling_dim_at_output: 5x7 + +aug_trans: 0.1 +aug_zoom: 1.5 +aug_p_flip: 0.5 + +img_net: resnet18 + +l_r: 0.0002 +tot_num_epochs: 801 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/config/dagr-n-dsec.yaml b/config/dagr-n-dsec.yaml new file mode 100644 index 0000000..6a53e74 --- /dev/null +++ b/config/dagr-n-dsec.yaml @@ -0,0 +1,73 @@ +dataset_directory: "/data/storage/daniel/aegnn/" +output_directory: "/data/storage/daniel/aegnn/logs" + +task: detection +dataset: dsec + +# network +radius: 0.01 +time_window_us: 1000000 +max_neighbors: 16 +n_nodes: 50000 + +batch_size: 64 + +activation: relu +edge_attr_dim: 2 +aggr: sum +kernel_size: 5 +pooling_aggr: max + +base_width: 0.5 +after_pool_width: 1 +net_stem_width: 0.25 +yolo_stem_width: 0.25 +num_scales: 2 + +# learning +weight_decay: 0.00001 +clip: 0.1 + +pooling_dim_at_output: 5x7 + +aug_trans: 0.1 +aug_zoom: 1.5 +aug_p_flip: 0.5 + +img_net: resnet18 + +l_r: 0.0002 +tot_num_epochs: 801 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/config/dagr-s-dsec.yaml b/config/dagr-s-dsec.yaml new file mode 100644 index 0000000..b5c80f5 --- /dev/null +++ b/config/dagr-s-dsec.yaml @@ -0,0 +1,72 @@ +dataset_directory: "/data/storage/daniel/aegnn/" +output_directory: "/data/storage/daniel/aegnn/logs" + +task: detection +dataset: dsec + +# network +radius: 0.01 +time_window_us: 1000000 +max_neighbors: 16 +n_nodes: 50000 + +batch_size: 64 + +activation: relu +edge_attr_dim: 2 +aggr: sum +kernel_size: 5 +pooling_aggr: max + +base_width: 0.5 +after_pool_width: 1 +net_stem_width: 0.5 +yolo_stem_width: 0.5 +num_scales: 2 + +# learning +weight_decay: 0.00001 +clip: 0.1 + +pooling_dim_at_output: 5x7 + +aug_trans: 0.1 +aug_zoom: 1.5 +aug_p_flip: 0.5 + +img_net: resnet18 + +l_r: 0.0002 +tot_num_epochs: 801 + + + + + + + + + + 
+ + + + + + + + + + + + + + + + + + + + + + diff --git a/download_and_install_dependencies.sh b/download_and_install_dependencies.sh new file mode 100644 index 0000000..619b34a --- /dev/null +++ b/download_and_install_dependencies.sh @@ -0,0 +1,25 @@ +#! /usr/bin/env bash +DAGR_DIR=$(pwd) + +# Download detectron2 for its fast mAP calculation function +mkdir $DAGR_DIR/libs +cd $DAGR_DIR/libs +git clone --no-checkout git@github.com:facebookresearch/detectron2.git +cd $DAGR_DIR/libs/detectron2/ +git checkout 32bd159d7263683e39bf4e87e5c4ac88bad2fd73 + +# Download YOLOX +cd $DAGR_DIR/libs +git clone --no-checkout git@github.com:Megvii-BaseDetection/YOLOX.git +cd $DAGR_DIR/libs/YOLOX +git checkout 618fd8c08b2bc5fac9ffbb19a3b7e039ea0d5b9a + +# Download dsec-det +cd $DAGR_DIR/libs +git clone git@github.com:uzh-rpg/dsec-det.git +cd $DAGR_DIR/libs/dsec-det +git checkout f3ea48b0eebef93b2052396fd23b3d40e6ff0363 + +pip install -e $DAGR_DIR/libs/dsec-det +pip install -e $DAGR_DIR/libs/detectron2 +pip install -e $DAGR_DIR/libs/YOLOX \ No newline at end of file diff --git a/download_example_data.sh b/download_example_data.sh new file mode 100644 index 0000000..ae6519e --- /dev/null +++ b/download_example_data.sh @@ -0,0 +1,10 @@ +#! /usr/bin/env bash +DAGR_DIR=$(pwd) +DATA_DIR=$DAGR_DIR/data + +mkdir $DATA_DIR +wget https://download.ifi.uzh.ch/rpg/dagr/data/dagr_s_50.pth -O $DATA_DIR/dagr_s_50.pth + +wget https://download.ifi.uzh.ch/rpg/dagr/data/DSEC_fragment.zip -O $DATA_DIR/DSEC_fragment.zip +unzip $DATA_DIR/DSEC_fragment.zip -d $DATA_DIR +rm -rf $DATA_DIR/DSEC_fragment.zip \ No newline at end of file diff --git a/install_env.sh b/install_env.sh new file mode 100644 index 0000000..08c91e6 --- /dev/null +++ b/install_env.sh @@ -0,0 +1,12 @@ +#! /usr/bin/env bash + +TORCH=$(python -c "import torch; print(torch.__version__)") +CUDA=$(python -c "import torch; print(torch.version.cuda)") +URL=https://data.pyg.org/whl/torch-${TORCH}+${CUDA}.html + +pip install --no-cache-dir torch-scatter -f $URL; +pip install --no-cache-dir torch-cluster -f $URL; +pip install --no-cache-dir torch-spline-conv -f $URL; +pip install --no-cache-dir torch-sparse -f $URL; +pip install torch-geometric; +pip install wandb numba hdf5plugin plotly matplotlib pycocotools opencv-python scikit-video pandas ruamel.yaml diff --git a/readme.md b/readme.md new file mode 100644 index 0000000..f572caa --- /dev/null +++ b/readme.md @@ -0,0 +1,149 @@ +# Low Latency Automotive Vision with Event Cameras + + + +This repository contains code from our 2024 Nature paper +**_Low Latency Automotive Vision with Event Cameras_** by Daniel Gehrig and Davide Scaramuzza. +If you use our code or refer to this project, please cite it using + +``` +@inproceedings{Gehrig24nature, + author = {Gehrig, Daniel and Scaramuzza, Davide}, + title = {Low Latency Automotive Vision with Event Cameras}, + booktitle = {Nature}, + year = {2024} +} +``` + +## Installation +First, download the github repository and its dependencies +```bash +WORK_DIR=/path/to/work/directory/ +cd $WORK_DIR +git clone git@github.com:uzh-rpg/low_latency_object_detection.git # TODO update url +DAGR_DIR=$WORK_DIR/low_latency_object_detection # TODO update url + +cd $DAGR_DIR +git checkout opensource # TODO remove + +``` +Then start by installing the main libraries. Make sure Anaconda (or better Mamba), PyTorch, and CUDA is installed. 
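+Before creating the environment below, it can help to verify that a CUDA-capable GPU and the conda tooling are visible on the machine; a minimal check, assuming the NVIDIA driver is installed and `conda` is on your `PATH`:
+```bash
+nvidia-smi
+conda --version
+```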
+```bash +cd $DAGR_DIR +conda create -y -n dagr python=3.8 +conda activate dagr +conda install -y pytorch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 cudatoolkit=11.3 -c pytorch +``` +Then install the pytorch-geometric libraries. This may take a while. +```bash +bash install_env.sh +``` +The above bash file will figure out the CUDA and Torch version, and install the appropriate pytorch-geometric packages. +Then, download and install additional dependencies locally +```bash +bash download_and_install_dependencies.sh +``` +Finally, install the dagr package +```bash +pip install -e . +``` + +## Run Example +After installing, you can download a data fragment, and checkpoint with +```bash +bash download_example_data.sh +``` +This will download a checkpoint and data fragment of DSEC-Detection on which you can test the code. +Once downloaded, run the following command +```bash +LOG_DIR=/path/to/log +DEVICE=1 +CUDA_VISIBLE_DEVICES=$DEVICE python scripts/run_test_interframe.py --config config/dagr-s-dsec.yaml \ + --use_image \ + --img_net resnet50 \ + --checkpoint data/dagr_s_50.pth \ + --batch_size 8 \ + --dataset_directory data/DSEC_fragment \ + --output_directory $LOG_DIR +``` +note the wandb directory as `$WANDB_DIR` and then visualize the detections with +```bash +python scripts/visualize_detections.py --detections_folder $LOG_DIR/$WANDB_DIR \ + --dataset_directory data/DSEC_fragment/test \ + --vis_time_step_us 1000 \ + --event_time_window_us 5000 \ + --sequence zurich_city_13_b +``` + +## Test on DSEC +Start by downloading the DSEC dataset and the additional labelled data introduced in this work. +To do so, follow [these instructions](https://github.com/uzh-rpg/dsec-det?tab=readme-ov-file#download-dsec). They are based on the scripts +of [dsec-det](https://github.com/uzh-rpg/dsec-det), which can be found in `libs/dsec-det/scripts`. +To continue, complete sections [Download DSEC](https://github.com/uzh-rpg/dsec-det?tab=readme-ov-file#download-dsec) until [Test Alignment](https://github.com/uzh-rpg/dsec-det?tab=readme-ov-file#test-alignment). +If you already downloaded DSEC, make sure `$DSEC_ROOT` points to it, and instead start at section [Download DSEC-extra +](https://github.com/uzh-rpg/dsec-det?tab=readme-ov-file#download-dsec-extra). + +After downloading all the data, change back to $DAGR_DIR, and start by downsampling the events +```bash +cd $DAGR_DIR +bash scripts/downsample_all_events.sh $DSEC_ROOT +``` + +### Running Evaluation +This repository implements three scripts for running evaluation of the model on DSEC-Det. +The first, evaluates the detection performance of the model after seeing one image, and the subsequent 50 milliseconds of events. 
+To run it, specify a device, and logging directory with type +```bash +LOG_DIR=/path/to/log +DEVICE=1 +CUDA_VISIBLE_DEVICES=$DEVICE python scripts/run_test.py --config config/dagr-s-dsec.yaml \ + --use_image \ + --img_net resnet50 \ + --checkpoint data/dagr_s_50.pth \ + --batch_size 8 \ + --dataset_directory $DSEC_ROOT \ + --output_directory $LOG_DIR +``` +Then, to evaluate the number of FLOPS generated in asynchronous mode, run +```bash +LOG_DIR=/path/to/log +DEVICE=1 +CUDA_VISIBLE_DEVICES=$DEVICE python scripts/count_flops.py --config config/eagr-s-dsec.yaml \ + --use_image \ + --img_net resnet50 \ + --checkpoint data/dagr_s_50.pth \ + --batch_size 8 \ + --dataset_directory $DSEC_ROOT \ + --output_directory $LOG_DIR +``` +Finally, to evaluate the interframe detection performance of our method run +```bash +LOG_DIR=/path/to/log +DEVICE=1 +CUDA_VISIBLE_DEVICES=$DEVICE python scripts/run_test_interframe.py --config config/eagr-s-dsec.yaml \ + --use_image \ + --img_net resnet50 \ + --checkpoint data/dagr_s_50.pth \ + --batch_size 8 \ + --dataset_directory $DSEC_ROOT \ + --output_directory $LOG_DIR \ + --num_interframe_steps 10 +``` +This last script will write the high-rate detections from our method into the folder `$LOG_DIR/$WANDB_DIR`, +where `$WANDB_DIR` is the automatically generated folder created by wandb. +To visualize the detections, use the following script: +```bash +python scripts/visualize_detections.py --detections_folder $LOG_DIR/$WANDB_DIR \ + --dataset_directory $DSEC_ROOT/test/ \ + --vis_time_step_us 1000 \ + --event_time_window_us 5000 \ + --sequence zurich_city_13_b + +``` +This will start a visualization window showing the detections over a given sequence. If you want to save the detections +to a video, use the `--write_to_output` flag, which will create a video in the folder `$LOG_DIR/$WANDB_DIR/visualization}`. 
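+If the flag writes out individual `%06d.png` frames rather than an encoded video file, they can be assembled afterwards; a minimal sketch with ffmpeg, assuming ffmpeg is installed and a 30 fps output is desired:
+```bash
+ffmpeg -framerate 30 -i $LOG_DIR/$WANDB_DIR/visualization/%06d.png -c:v libx264 -pix_fmt yuv420p detections.mp4
+```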
\ No newline at end of file diff --git a/scripts/count_flops.py b/scripts/count_flops.py new file mode 100644 index 0000000..dd9ae04 --- /dev/null +++ b/scripts/count_flops.py @@ -0,0 +1,75 @@ +import os +import tqdm +import torch +os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID' + +from torch_geometric.data import DataLoader + +from dagr.utils.args import FLOPS_FLAGS +from dagr.utils.buffers import DictBuffer, format_data + +from dagr.data.augment import Augmentations +from dagr.data.dsec_data import DSEC + +from dagr.model.networks.dagr import DAGR + +from dagr.asynchronous.evaluate_flops import evaluate_flops + + +if __name__ == '__main__': + import torch_geometric + seed = 42 + torch_geometric.seed.seed_everything(seed) + args = FLOPS_FLAGS() + assert "checkpoint" in args + + project = f"flops-{args.dataset}-{args.task}" + pbar = tqdm.tqdm(total=4) + + pbar.set_description("Loading dataset") + dataset_path = args.dataset_directory / args.dataset + print("init datasets") + dataset = DSEC(args.dataset_directory, "test", Augmentations.transform_testing, debug=True, min_bbox_diag=15, min_bbox_height=10) + loader = DataLoader(dataset, follow_batch=['bbox', "bbox0"], batch_size=args.batch_size, shuffle=False, num_workers=16) + pbar.update(1) + + pbar.set_description("Initializing net") + model = DAGR(args, height=dataset.height, width=dataset.width) + model = model.cuda() + model.eval() + pbar.update(1) + + assert "checkpoint" in args + checkpoint = torch.load(args.checkpoint) + model.load_state_dict(checkpoint['ema']) + pbar.update(1) + + model.cache_luts(radius=args.radius, height=dataset.height, width=dataset.width) + + pbar.set_description("Computing FLOPS") + buffer = DictBuffer() + args.output_directory.mkdir(parents=True, exist_ok=True) + pbar_flops = tqdm.tqdm(total=len(loader.dataset), desc="Computing FLOPS") + for i, data in enumerate(loader): + data = data.cuda(non_blocking=True) + data = format_data(data) + + flops_evaluation = evaluate_flops(model, data, + check_consistency=args.check_consistency, + return_all_samples=True, dense=args.dense) + if flops_evaluation is None: + continue + + buffer.update(flops_evaluation['flops_per_layer']) + buffer.save(args.output_directory / "flops_per_layer.pth") + tot_flops = sum(buffer.compute().values()) + + pbar_flops.set_description(f"Total FLOPS {tot_flops}") + pbar_flops.update(1) + + print(sum(buffer.compute().values())) + pbar.update(1) + + + + diff --git a/scripts/downsample_all_events.sh b/scripts/downsample_all_events.sh new file mode 100644 index 0000000..1209463 --- /dev/null +++ b/scripts/downsample_all_events.sh @@ -0,0 +1,10 @@ +#!/bin/bash + +DSEC_ROOT=$1 +for split in train test; do + for sequence in $DSEC_ROOT/$split/*/; do + infile=$sequence/events/left/events.h5 + outfile=$sequence/events/left/events_2x.h5 + python scripts/downsample_events.py --input_path $infile --output_path $outfile + done +done \ No newline at end of file diff --git a/scripts/downsample_events.py b/scripts/downsample_events.py new file mode 100644 index 0000000..8eed44e --- /dev/null +++ b/scripts/downsample_events.py @@ -0,0 +1,165 @@ +import argparse +import tqdm +import hdf5plugin +import h5py +import weakref +import numba + +import numpy as np + +from pathlib import Path + +from dsec_det.io import extract_from_h5_by_index, get_num_events + + +def _compression_opts(): + compression_level = 1 # {0, ..., 9} + shuffle = 2 # {0: none, 1: byte, 2: bit} + # From 
https://github.com/Blosc/c-blosc/blob/7435f28dd08606bd51ab42b49b0e654547becac4/blosc/blosc.h#L66-L71 + # define BLOSC_BLOSCLZ 0 + # define BLOSC_LZ4 1 + # define BLOSC_LZ4HC 2 + # define BLOSC_SNAPPY 3 + # define BLOSC_ZLIB 4 + # define BLOSC_ZSTD 5 + compressor_type = 5 + compression_opts = (0, 0, 0, 0, compression_level, shuffle, compressor_type) + return compression_opts + + +H5_BLOSC_COMPRESSION_FLAGS = dict( + compression=32001, + compression_opts=_compression_opts(), # Blosc + chunks=True +) + +def create_ms_to_idx(t_us): + t_ms = t_us // 1000 + x, counts = np.unique(t_ms, return_counts=True) + ms_to_idx = np.zeros(shape=(t_ms[-1] + 2,), dtype="uint64") + ms_to_idx[x + 1] = counts + ms_to_idx = ms_to_idx[:-1].cumsum() + return ms_to_idx + +class H5Writer: + def __init__(self, outfile): + assert not outfile.exists() + + self.h5f = h5py.File(outfile, 'a') + self._finalizer = weakref.finalize(self, self.close_callback, self.h5f) + + self.t_offset = None + self.num_events = 0 + + # create hdf5 datasets + shape = (2 ** 16,) + maxshape = (None,) + + self.h5f.create_dataset(f'events/x', shape=shape, dtype='u2', maxshape=maxshape, **H5_BLOSC_COMPRESSION_FLAGS) + self.h5f.create_dataset(f'events/y', shape=shape, dtype='u2', maxshape=maxshape, **H5_BLOSC_COMPRESSION_FLAGS) + self.h5f.create_dataset(f'events/p', shape=shape, dtype='u1', maxshape=maxshape, **H5_BLOSC_COMPRESSION_FLAGS) + self.h5f.create_dataset(f'events/t', shape=shape, dtype='u4', maxshape=maxshape, **H5_BLOSC_COMPRESSION_FLAGS) + + def create_ms_to_idx(self): + t_us = self.h5f['events/t'][()] + self.h5f.create_dataset(f'ms_to_idx', data=create_ms_to_idx(t_us), dtype='u8', **H5_BLOSC_COMPRESSION_FLAGS) + + @staticmethod + def close_callback(h5f: h5py.File): + h5f.close() + + def add_data(self, events): + if self.t_offset is None: + self.t_offset = events['t'][0] + self.h5f.create_dataset(f't_offset', data=self.t_offset, dtype='i8') + + events['t'] -= self.t_offset + size = len(events['t']) + self.num_events += size + + self.h5f[f'events/x'].resize(self.num_events, axis=0) + self.h5f[f'events/y'].resize(self.num_events, axis=0) + self.h5f[f'events/p'].resize(self.num_events, axis=0) + self.h5f[f'events/t'].resize(self.num_events, axis=0) + + self.h5f[f'events/x'][self.num_events-size:self.num_events] = events['x'] + self.h5f[f'events/y'][self.num_events-size:self.num_events] = events['y'] + self.h5f[f'events/p'][self.num_events-size:self.num_events] = events['p'] + self.h5f[f'events/t'][self.num_events-size:self.num_events] = events['t'] + + +def downsample_events(events, input_height, input_width, output_height, output_width, change_map=None): + # this subsamples events if they were generated with cv2.INTER_AREA + if change_map is None: + change_map = np.zeros((output_height, output_width), dtype="float32") + + fx = int(input_width / output_width) + fy = int(input_height / output_height) + + mask = np.zeros(shape=(len(events['t']),), dtype="bool") + mask, change_map = _filter_events_resize(events['x'], events['y'], events['p'], mask, change_map, fx, fy) + + events = {k: v[mask] for k, v in events.items()} + events['x'] = (events['x'] / fx).astype("uint16") + events['y'] = (events['y'] / fy).astype("uint16") + + return events, change_map + + +@numba.jit(nopython=True, cache=True) +def _filter_events_resize(x, y, p, mask, change_map, fx, fy): + # iterates through x,y,p of events, and increments cells of size fx x fy by 1/(fx*fy) + # if one of these cells reaches +-1, then reset the cell, and pass through that event. 
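+    # For example, with fx = fy = 2 and polarities in {-1, +1}, every event adds +-0.25 to its
+    # 2x2 output cell; the fourth same-polarity event in a cell pushes |change_map| to 1, is
+    # passed through, and the cell is decremented by that polarity again.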
+ # for memory reasons, this only returns the True/False for every event, indicating if + # the event was skipped or passed through. + for i in range(len(x)): + x_l = x[i] // fx + y_l = y[i] // fy + change_map[y_l, x_l] += p[i] * 1.0 / (fx * fy) + + if np.abs(change_map[y_l, x_l]) >= 1: + mask[i] = True + change_map[y_l, x_l] -= p[i] + + return mask, change_map + + + +if __name__ == '__main__': + parser = argparse.ArgumentParser("""Downsample events""") + parser.add_argument("--input_path", type=Path, required=True, help="Path to input events.h5. ") + parser.add_argument("--output_path", type=Path, required=True, help="Path where output events.h5 will be written.") + parser.add_argument("--input_height", type=int, default=480, help="Height of the input events resolution.") + parser.add_argument("--input_width", type=int, default=640, help="Width of the input events resolution") + parser.add_argument("--output_height", type=int, default=240, help="Height of the output events resolution.") + parser.add_argument("--output_width", type=int, default=320, help="Width of the output events resolution.") + args = parser.parse_args() + + num_events = get_num_events(args.input_path) + num_events_per_chunk = 100000 + num_iterations = num_events // num_events_per_chunk + + writer = H5Writer(args.output_path) + + change_map = None + pbar = tqdm.tqdm(total=num_iterations+1) + for i in range(num_iterations): + events = extract_from_h5_by_index(args.input_path, i * num_events_per_chunk, (i+1) * num_events_per_chunk) + events['p'] = 2 * events['p'].astype("int8") - 1 + downsampled_events, change_map = downsample_events(events, change_map=change_map, input_height=args.input_height, input_width=args.input_width, + output_height=args.output_height, output_width=args.output_width) + writer.add_data(downsampled_events) + pbar.update(1) + + events = extract_from_h5_by_index(args.input_path, num_iterations * num_events_per_chunk, num_events) + downsampled_events, change_map = downsample_events(events, change_map=change_map, input_height=args.input_height, + input_width=args.input_width, + output_height=args.output_height, output_width=args.output_width) + writer.add_data(downsampled_events) + pbar.update(1) + + writer.create_ms_to_idx() + + + + diff --git a/scripts/run_test.py b/scripts/run_test.py new file mode 100644 index 0000000..1562c74 --- /dev/null +++ b/scripts/run_test.py @@ -0,0 +1,66 @@ +# avoid matlab error on server +import os +import torch +import wandb +os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID' + +from torch_geometric.data import DataLoader +from dagr.utils.args import FLAGS + +from dagr.data.dsec_data import DSEC +from dagr.data.augment import Augmentations + +from dagr.model.networks.dagr import DAGR +from dagr.model.networks.ema import ModelEMA + +from dagr.utils.logging import set_up_logging_directory, log_hparams +from dagr.utils.testing import run_test_with_visualization + + +if __name__ == '__main__': + import torch_geometric + import random + import numpy as np + + seed = 42 + torch_geometric.seed.seed_everything(seed) + torch.random.manual_seed(seed) + torch.manual_seed(seed) + np.random.seed(seed) + random.seed(seed) + + args = FLAGS() + + output_directory = set_up_logging_directory(args.dataset, args.task, args.output_directory) + + project = f"low_latency-{args.dataset}-{args.task}" + print(f"PROJECT: {project}") + log_hparams(args) + + print("init datasets") + dataset_path = args.dataset_directory.parent / args.dataset + + test_dataset = DSEC(args.dataset_directory, "test", 
Augmentations.transform_testing, debug=False, min_bbox_diag=15, min_bbox_height=10) + + num_iters_per_epoch = 1 + + sampler = np.random.permutation(np.arange(len(test_dataset))) + test_loader = DataLoader(test_dataset, sampler=sampler, follow_batch=['bbox', 'bbox0'], batch_size=args.batch_size, shuffle=False, num_workers=4, drop_last=True) + + print("init net") + # load a dummy sample to get height, width + model = DAGR(args, height=test_dataset.height, width=test_dataset.width) + model = model.cuda() + ema = ModelEMA(model) + + assert "checkpoint" in args + checkpoint = torch.load(args.checkpoint) + ema.ema.load_state_dict(checkpoint['ema']) + ema.ema.cache_luts(radius=args.radius, height=test_dataset.height, width=test_dataset.width) + + with torch.no_grad(): + metrics = run_test_with_visualization(test_loader, ema.ema, dataset=args.dataset) + log_data = {f"testing/metric/{k}": v for k, v in metrics.items()} + wandb.log(log_data) + print(metrics['mAP']) + diff --git a/scripts/run_test_interframe.py b/scripts/run_test_interframe.py new file mode 100644 index 0000000..81e85b6 --- /dev/null +++ b/scripts/run_test_interframe.py @@ -0,0 +1,88 @@ +import torch +import tqdm +import wandb +import os +os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID' + +from torch_geometric.data import DataLoader +from pprint import pprint + +from dagr.utils.logging import set_up_logging_directory, log_hparams +from dagr.utils.args import FLAGS +from dagr.utils.testing import run_test_with_visualization + +from dagr.data.augment import Augmentations +from dagr.data.dsec_data import DSEC + +from dagr.model.networks.dagr import DAGR +from dagr.model.networks.ema import ModelEMA + + +def to_npy(detections): + n_boxes = len(detections['boxes']) + dtype = np.dtype([('t', ' 0 + assert args.event_time_window_us > 0 + + if args.write_to_output: + output_path = args.detections_folder / "visualization" + output_path.mkdir(parents=True, exist_ok=True) + + detections_file = args.detections_folder / f"detections_{args.sequence}.npy" + detections = np.load(detections_file) + detection_timestamps = np.unique(detections['t']) + + dsec_directory = DSECDirectory(args.dataset_directory / args.sequence) + + t0, t1 = load_start_and_end_time(dsec_directory) + + vis_timestamps = np.arange(t0, t1, step=args.vis_time_step_us) + step_index_to_image_index = compute_index(dsec_directory.images.timestamps, vis_timestamps) + step_index_to_boxes_index = compute_index(detection_timestamps, vis_timestamps) + scale = 2 + + for step, t in enumerate(vis_timestamps): + + # find most recent image + image_index = step_index_to_image_index[step] + image = load_image_with_index(dsec_directory, image_index) + + # find events within time window [image_timestamps, t] + events = load_events_in_timewindow(dsec_directory, t-args.event_time_window_us, t) + + # find most recent bounding boxes + boxes_index = step_index_to_boxes_index[step] + boxes_timestamp = detection_timestamps[boxes_index] + boxes = detections[detections['t'] == boxes_timestamp] + + # draw them on one image + scale = 2 + image = draw_events_on_image(image, events['x'], events['y'], events['p']) + image = draw_bbox_on_img(image, scale*boxes['x'], scale*boxes['y'], scale*boxes['w'], scale*boxes["h"], + boxes["class_id"], boxes['class_confidence'], conf=0.3, nms=0.65) + + if args.write_to_output: + cv2.imwrite(str(output_path / ("%06d.png" % step)), image) + else: + cv2.imshow("DSEC Det: Visualization", image) + cv2.waitKey(3) + diff --git a/setup.py b/setup.py new file mode 100644 index 
0000000..47063ca --- /dev/null +++ b/setup.py @@ -0,0 +1,17 @@ +from distutils.core import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +setup( + name='dagr', + packages=['dagr'], + package_dir={'':'src'}, + ext_modules=[ + CUDAExtension(name='asy_tools', + sources=['src/dagr/asynchronous/asy_tools/main.cu']), + CUDAExtension(name="ev_graph_cuda", + sources=['src/dagr/graph/ev_graph.cu']) + ], + cmdclass={ + 'build_ext': BuildExtension + } +) diff --git a/src/dagr/asynchronous/__init__.py b/src/dagr/asynchronous/__init__.py new file mode 100644 index 0000000..776a5a5 --- /dev/null +++ b/src/dagr/asynchronous/__init__.py @@ -0,0 +1,118 @@ +import logging + +import torch.nn +import torch_geometric +import inspect + +from torch.nn import ModuleList + +from .conv import make_conv_asynchronous +from .batch_norm import make_batch_norm_asynchronous +from .linear import make_linear_asynchronous +from .max_pool import make_max_pool_asynchronous +from .cartesian import make_cartesian_asynchronous + +from .flops import compute_flops_from_module + +from dagr.model.layers.spline_conv import MySplineConv +from dagr.model.layers.pooling import Pooling +from dagr.model.layers.components import BatchNormData, Cartesian, Linear + + + +from torch_geometric.data import Data, Batch +from typing import List + + +def is_data_or_data_list(ann): + return ann is Data or ann is Batch or ann is List[Data] + +def make_model_synchronous(module: torch.nn.Module): + module.forward = module.sync_forward + module.asy_flops_log = [] + + for key, nn in module.named_modules(): + if hasattr(nn, "sync_forward"): + nn.forward = nn.sync_forward + nn.asy_flops_log = [] + + return module + +def make_model_asynchronous(module, log_flops: bool = False): + """Module converter from synchronous to asynchronous & sparse processing for graph convolutional layers. + By overwriting parts of the module asynchronous processing can be enabled without the need of re-learning + and moving its weights and configuration. So, a convolutional layer can be converted by, for example: + + ``` + module = GCNConv(1, 2) + module = make_conv_asynchronous(module) + ``` + + :param module: convolutional module to transform. + :param grid_size: grid size (grid starting at 0, spanning to `grid_size`), >= `size` for pooling operations, + e.g. the image size. + :param r: update radius around new events. + :param edge_attributes: function for computing edge attributes (default = None), assumed to be the same over + all convolutional layers. + :param log_flops: log flops of asynchronous update. + """ + assert isinstance(module, torch.nn.Module), "module must be a `torch.nn.Module`" + model_forward = module.forward + module.sync_forward = module.forward + + module.asy_flops_log = [] if log_flops else None + + # Make all layers asynchronous that have an implemented asynchronous function. Otherwise use + # the synchronous forward function. 
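+    # Submodules without a dedicated asynchronous implementation are handled below: if their
+    # forward takes a Data/Batch argument they are converted recursively, otherwise they keep
+    # their synchronous forward.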
+ for key, nn in module._modules.items(): + nn_class_name = nn.__class__.__name__ + logging.debug(f"Making layer {key} of type {nn_class_name} asynchronous") + + if isinstance(nn, MySplineConv): + module._modules[key] = make_conv_asynchronous(nn, log_flops=log_flops) + + elif isinstance(nn, Pooling): + module._modules[key] = make_max_pool_asynchronous(nn, log_flops=log_flops) + + elif isinstance(nn, BatchNormData): + module._modules[key] = make_batch_norm_asynchronous(nn, log_flops=log_flops) + + elif isinstance(nn, Cartesian): + module._modules[key] = make_cartesian_asynchronous(nn, log_flops=log_flops) + + elif isinstance(nn, Linear): + module._modules[key] = make_linear_asynchronous(nn, log_flops=log_flops) + + elif isinstance(nn, ModuleList): + module._modules[key] = make_model_asynchronous(nn, log_flops=log_flops) + + else: + sign = inspect.signature(nn.forward) + first_arg = list(sign.parameters.values())[0] + + if not is_data_or_data_list(first_arg.annotation): + continue + + module._modules[key] = make_model_asynchronous(nn, log_flops=log_flops) + logging.debug(f"Asynchronous module for {nn_class_name} is being made asynchronous recursively.") + + def async_forward(data: torch_geometric.data.Data, *args, **kwargs): + out = model_forward(data, *args, **kwargs) + + if module.asy_flops_log is not None: + flops_count = [compute_flops_from_module(layer) for layer in module._modules.values()] + module.asy_flops_log.append(sum(flops_count)) + logging.debug(f"Model's modules update with overall {sum(flops_count)} flops") + + return out + + module.forward = async_forward + return module + + +__all__ = [ + "make_conv_asynchronous", + "make_linear_asynchronous", + "make_max_pool_asynchronous", + "make_model_asynchronous" +] diff --git a/src/dagr/asynchronous/asy_tools/main.cu b/src/dagr/asynchronous/asy_tools/main.cu new file mode 100644 index 0000000..b2fa80d --- /dev/null +++ b/src/dagr/asynchronous/asy_tools/main.cu @@ -0,0 +1,244 @@ +#include + +#include +#include +#include + + +#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) +#define CHECK_DEVICE(x, y) AT_ASSERTM(x.device().index() == y.device().index(), #x " and " #y " must be in same CUDA device") + + +template +__global__ void masked_isdiff_kernel( + int64_t* __restrict__ indices, + const scalar_t* __restrict__ x_old, + const scalar_t* __restrict__ x_new, + int K, int C, float atol, float rtol +) +{ + // linear index + const int lin_idx = blockIdx.x * blockDim.x + threadIdx.x; + + // check that thread is not out of valid range + if (lin_idx >= K) + return; + + // find out how many events to write, and what is the offset + int64_t temp = indices[lin_idx]; + indices[lin_idx] = -1; + int offset = temp*C; + for (int i=0; i atol + rtol * other) { + indices[lin_idx] = temp; + break; + } + } +} + +template +__global__ void masked_inplace_BN_kernel( + const int64_t* __restrict__ indices, + const scalar_t* __restrict__ x, + scalar_t* __restrict__ x_out, + const scalar_t* __restrict__ running_mean, + const scalar_t* __restrict__ running_var, + const scalar_t* __restrict__ weight, + const scalar_t* __restrict__ bias, + int K, int C, float eps +) +{ + // linear index + const int lin_idx = blockIdx.x * blockDim.x + threadIdx.x; + + // check that thread is not out of valid range + if (lin_idx >= K*C) + return; + + int i = lin_idx / C; + int c = lin_idx % C; + + int 
x_lin_idx = C * indices[i] + c; + x_out[x_lin_idx] = (x[x_lin_idx] - running_mean[c]) / (sqrt(running_var[c] + eps)) * weight[c] + bias[c]; +} + +void masked_inplace_BN( + const torch::Tensor& indices, + const torch::Tensor& x, + torch::Tensor& x_out, + const torch::Tensor& running_mean, + const torch::Tensor& running_var, + const torch::Tensor& weight, + const torch::Tensor& bias, + float eps + ) +{ + unsigned K = indices.size(0); + unsigned C = x.size(1); + + unsigned threads = 256; + dim3 blocks((K*C + threads - 1) / threads, 1); + + masked_inplace_BN_kernel<<>>( + indices.data(), + x.data(), + x_out.data(), + running_mean.data(), + running_var.data(), + weight.data(), + bias.data(), K, C, eps + ); +} + +torch::Tensor masked_isdiff( + const torch::Tensor& indices, // N -> num events + const torch::Tensor& x_old, // K -> num active pixels + const torch::Tensor& x_new, // K -> num active pixels + float atol, float rtol + ) +{ + CHECK_INPUT(indices); + CHECK_INPUT(x_old); + CHECK_INPUT(x_new); + + CHECK_DEVICE(indices, x_old); + CHECK_DEVICE(indices, x_new); + + unsigned K = indices.size(0); + unsigned C = x_old.size(1); + + unsigned threads = 256; + dim3 blocks((K + threads - 1) / threads, 1); + + masked_isdiff_kernel<<>>( + indices.data(), + x_old.data(), + x_new.data(), + K, C, atol, rtol + ); + + return indices.index({indices > -1}); +} + + +template +__global__ void masked_lin_kernel( + int64_t* __restrict__ indices, + const scalar_t* __restrict__ x_in, + scalar_t* __restrict__ x_out, + const scalar_t* __restrict__ weight, + const scalar_t* __restrict__ bias, + int K, int Cin, int Cout, bool add +) +{ + // linear index + const int lin_idx = blockIdx.x * blockDim.x + threadIdx.x; + + // check that thread is not out of valid range + if (lin_idx >= K*Cout) + return; + + int i = lin_idx / Cout; + int cout = lin_idx % Cout; + + int x_out_lin_idx = Cout * indices[i] + cout; + int x_int_lin_idx = Cin * indices[i]; + + if (!add) + x_out[x_out_lin_idx] = 0; + + for (int cin=0; cin +__global__ void masked_lin_no_bias_kernel( + int64_t* __restrict__ indices, + const scalar_t* __restrict__ x_in, + scalar_t* __restrict__ x_out, + const scalar_t* __restrict__ weight, + int K, int Cin, int Cout, bool add +) +{ + // linear index + const int lin_idx = blockIdx.x * blockDim.x + threadIdx.x; + + // check that thread is not out of valid range + if (lin_idx >= K*Cout) + return; + + int i = lin_idx / Cout; + int cout = lin_idx % Cout; + + int x_out_lin_idx = Cout * indices[i] + cout; + int x_int_lin_idx = Cin * indices[i]; + + if (!add) + x_out[x_out_lin_idx] = 0; + + for (int cin=0; cin<<>>( + indices.data(), + x_in.data(), + x_out.data(), + weight.data(), + K, Cin, Cout, add); +} + + +void masked_lin( + const torch::Tensor& indices, + const torch::Tensor& x_in, + torch::Tensor& x_out, + const torch::Tensor& weight, + const torch::Tensor& bias, + bool add + ) +{ + unsigned K = indices.size(0); + unsigned Cin = weight.size(1); + unsigned Cout = weight.size(0); + + unsigned threads = 256; + dim3 blocks((K*Cout + threads - 1) / threads, 1); + + masked_lin_kernel<<>>( + indices.data(), + x_in.data(), + x_out.data(), + weight.data(), + bias.data(), K, Cin, Cout, add); +} + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("masked_lin", &masked_lin, "Find edges from a queue of events."); + m.def("masked_lin_no_bias", &masked_lin_no_bias, "Find edges from a queue of events."); + m.def("masked_isdiff", &masked_isdiff, "Find edges from a queue of events."); + m.def("masked_inplace_BN", &masked_inplace_BN, 
"Find edges from a queue of events."); +} diff --git a/src/dagr/asynchronous/base/__init__.py b/src/dagr/asynchronous/base/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/dagr/asynchronous/base/base.py b/src/dagr/asynchronous/base/base.py new file mode 100644 index 0000000..9f0a618 --- /dev/null +++ b/src/dagr/asynchronous/base/base.py @@ -0,0 +1,28 @@ +from contextlib import contextmanager +import logging + + +def add_async_graph(module, log_flops: bool = False): + module.asy_graph = None + module.asy_flops_log = [] if log_flops else None + return module + + +def make_asynchronous(module, initialization_func, processing_func): + module.sync_forward = module.forward + def async_forward(*args, **kwargs): + with async_context(module, initialization_func, processing_func) as func: + output = func(module, *args, **kwargs) + return output + module.forward = async_forward + return module + + +@contextmanager +def async_context(module, initialization_func, processing_func): + if module.asy_graph is None: + logging.debug(f"Graph initialization of module {module}") + yield initialization_func + else: + logging.debug(f"Calling processing of module {module}") + yield processing_func diff --git a/src/dagr/asynchronous/base/utils.py b/src/dagr/asynchronous/base/utils.py new file mode 100644 index 0000000..c768e52 --- /dev/null +++ b/src/dagr/asynchronous/base/utils.py @@ -0,0 +1,55 @@ +import torch + +from typing import Tuple +import asy_tools + + +def _efficient_cat(data_list): + data_list = [d for d in data_list if len(d) > 0] + if len(data_list) == 1: + return data_list[0] + return torch.cat(data_list) + +def _efficient_cat_unique(data_list): + # first only keep elements that have len > 0 + data_list_filt = [data for data in data_list if data.shape[0] > 0] + if len(data_list_filt) == 1: + return data_list_filt[0] + elif len(data_list_filt) == 0: + return data_list[0] + else: + return torch.cat(data_list_filt).unique() + +def _to_hom(x, ones=None): + if ones is None or len(ones) < len(x): + ones = torch.ones_like(x[:,-1:]) + else: + ones = ones[:len(x)] + return torch.cat([x, ones], dim=-1) + +def _from_hom(x): + return x[:,:-1] / (x[:,-1:] + 1e-9) + +def graph_new_nodes(old_data, new_data): + return torch.arange(old_data.x.shape[0], new_data.x.shape[0], device=new_data.x.device, dtype=torch.long) + +def graph_changed_nodes(old_data, new_data) -> Tuple[torch.Tensor, torch.Tensor]: + len_x_old = old_data.x.shape[0] + len_pos_old = old_data.pos.shape[0] + x_new = new_data.x[:len_x_old] if len_x_old < new_data.x.shape[0] else new_data.x + pos_new = new_data.pos[:len_pos_old] if len_pos_old < new_data.pos.shape[0] else new_data.pos + + diff_idx = asy_tools.masked_isdiff(new_data.diff_idx, x_new, old_data.x, 1e-8, 1e-5) if new_data.diff_idx.numel() > 0 else new_data.diff_idx + diff_pos_idx = asy_tools.masked_isdiff(new_data.diff_pos_idx, pos_new, old_data.pos, 1e-8, 1e-5) if new_data.diff_pos_idx.numel() > 0 else new_data.diff_pos_idx + + return diff_idx, diff_pos_idx + +def torch_isin(query, database): + if hasattr(torch, "isin"): + return torch.isin(query, database) + else: + return (query.view(1, -1) == database.view(-1, 1)).any(0) + +def __remove_duplicate_from_A(a, b): + a_in_b = (a.view(2,1,-1) == b.view(2,-1,1)).all(0).any(0) + return a[:,~a_in_b] \ No newline at end of file diff --git a/src/dagr/asynchronous/batch_norm.py b/src/dagr/asynchronous/batch_norm.py new file mode 100644 index 0000000..e3274eb --- /dev/null +++ b/src/dagr/asynchronous/batch_norm.py @@ -0,0 +1,77 
@@ +import torch +import asy_tools +from torch_geometric.nn.norm import BatchNorm +import torch.nn.functional as F +from .base.base import make_asynchronous, add_async_graph +from .base.utils import graph_changed_nodes, graph_new_nodes + + +def __sync_forward(m, x): + return F.batch_norm(x, m.running_mean, m.running_var, m.weight, m.bias, False, m.momentum, m.eps) + + +def __graph_initialization(module: BatchNorm, data) -> torch.Tensor: + module.asy_graph = data.clone() + module.graph_out = data.clone() + module.graph_out.x = __sync_forward(module.module, data.x) + + # flops are not counted since BN can be fused with previous conv operator. + if module.asy_flops_log is not None: + flops = 0 + module.asy_flops_log.append(flops) + + return module.graph_out.clone() + +def __graph_processing(module: BatchNorm, data) -> torch.Tensor: + """Batch norms only execute simple normalization operation, which already is very efficient. The overhead + for looking for diff nodes would be much larger than computing the dense update. + + However, a new node slightly changes the feature distribution and therefore all activations, when calling + the dense implementation. Therefore, we approximate the distribution with the initial distribution as + num_new_events << num_initial_events. + """ + if len(module.asy_graph.x) < len(data.x): + diff_idx = graph_new_nodes(module.asy_graph, data) + module.graph_out.x = torch.cat([module.graph_out.x, torch.zeros_like(data.x[:len(diff_idx)])]) + else: + diff_idx, _ = graph_changed_nodes(module.asy_graph, data) + + if data.diff_idx.numel()>0: + asy_tools.masked_inplace_BN(data.diff_idx, data.x, + module.graph_out.x, + module.module.running_mean, + module.module.running_var, + module.module.weight, + module.module.bias, + module.module.eps) + + # If required, compute the flops of the asynchronous update operation. + if module.asy_flops_log is not None: + flops = 0 + module.asy_flops_log.append(flops) + + data.x = module.graph_out.x + + return data + + +def __check_support(module): + return True + + +def make_batch_norm_asynchronous(module: BatchNorm, log_flops: bool = False): + """Module converter from synchronous to asynchronous & sparse processing for batch norm (1d) layers. + By overwriting parts of the module asynchronous processing can be enabled without the need of re-learning + and moving its weights and configuration. So, a layer can be converted by, for example: + + ``` + module = BatchNorm(4) + module = make_batch_norm_asynchronous(module) + ``` + + :param module: batch norm module to transform. + :param log_flops: log flops of asynchronous update. 
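+
+    Note that the running statistics are used (training mode is off), so the asynchronous update
+    reduces to the per-channel affine normalization, applied only to changed nodes via
+    `asy_tools.masked_inplace_BN`; its flops are logged as zero since the operation can be fused
+    with the preceding convolution.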
+ """ + assert __check_support(module) + module = add_async_graph(module, log_flops=log_flops) + return make_asynchronous(module, __graph_initialization, __graph_processing) diff --git a/src/dagr/asynchronous/cartesian.py b/src/dagr/asynchronous/cartesian.py new file mode 100644 index 0000000..3857b52 --- /dev/null +++ b/src/dagr/asynchronous/cartesian.py @@ -0,0 +1,76 @@ +import torch + +from torch_geometric.nn.norm import BatchNorm +from .base.base import make_asynchronous, add_async_graph + +def __edge_attr(pos, edge_index, norm, max): + (row, col), pos = edge_index, pos + + cart = pos[row] - pos[col] + cart = cart.view(-1, 1) if cart.dim() == 1 else cart + + if norm and cart.numel() > 0: + max_value = cart.abs().max() if max is None else max + cart = cart / (2 * max_value) + 0.5 + + return cart + + +def __graph_initialization(module: BatchNorm, data) -> torch.Tensor: + module.asy_graph = data.clone() + module.graph_out = data.clone() + module.graph_out.edge_attr = __edge_attr(data.pos, data.edge_index, module.norm, module.max) + + # flops are not counted since BN can be fused with previous conv operator. + if module.asy_flops_log is not None: + flops = 2 * len(module.graph_out.edge_attr) + module.asy_flops_log.append(flops) + + return module.graph_out.clone() + + +def __graph_processing(module: BatchNorm, data) -> torch.Tensor: + """Batch norms only execute simple normalization operation, which already is very efficient. The overhead + for looking for diff nodes would be much larger than computing the dense update. + + However, a new node slightly changes the feature distribution and therefore all activations, when calling + the dense implementation. Therefore, we approximate the distribution with the initial distribution as + num_new_events << num_initial_events. + """ + module.graph_out.pos = torch.cat([module.asy_graph.pos, data.pos]) + module.graph_out.x = torch.cat([module.asy_graph.x, data.x]) + module.graph_out.edge_attr = __edge_attr(module.graph_out.pos, data.edge_index, module.norm, module.max) + module.graph_out.edge_index = data.edge_index + + # flops are not counted since BN can be fused with previous conv operator. + if module.asy_flops_log is not None: + flops = 2 * len(module.graph_out.edge_attr) + module.asy_flops_log.append(flops) + + if hasattr(data, "diff_idx"): + module.graph_out.diff_idx = data.diff_idx + module.graph_out.diff_pos_idx = data.diff_pos_idx + + return module.graph_out + + +def __check_support(module): + return True + + +def make_cartesian_asynchronous(module: BatchNorm, log_flops: bool = False): + """Module converter from synchronous to asynchronous & sparse processing for cartesian layers. + By overwriting parts of the module asynchronous processing can be enabled without the need of re-learning + and moving its weights and configuration. So, a layer can be converted by, for example: + + ``` + module = Cartesian() + module = make_cartesian_asynchronous(module) + ``` + + :param module: cartesian module to transform. + :param log_flops: log flops of asynchronous update. 
+ """ + assert __check_support(module) + module = add_async_graph(module, log_flops=log_flops) + return make_asynchronous(module, __graph_initialization, __graph_processing) \ No newline at end of file diff --git a/src/dagr/asynchronous/conv.py b/src/dagr/asynchronous/conv.py new file mode 100644 index 0000000..a5ee521 --- /dev/null +++ b/src/dagr/asynchronous/conv.py @@ -0,0 +1,266 @@ +import asy_tools +import torch +import torch_geometric.nn.conv + +from .base.base import make_asynchronous, add_async_graph +from .base.utils import graph_new_nodes, graph_changed_nodes, _efficient_cat_unique, torch_isin +from .flops import compute_flops_conv, compute_flops_cat +from torch_scatter import scatter_sum + + +def __conv(x, edge_index, edge_attr, mask, nn): + if edge_index.numel() > 0: + x_j = x[edge_index[0, :], :] + phi = nn.message(x_j, edge_attr=edge_attr[:, :nn.dim]) + y = nn.aggregate(phi, index=edge_index[1, :], ptr=None, dim_size=x.size()[0]) + else: + y = torch.zeros(size=(x.shape[0], nn.out_channels), dtype=x.dtype, device=x.device) + + if hasattr(nn, "root_weight") and nn.root_weight: + nn.lin_act = nn.lin(x) + y[mask] += nn.lin_act[mask] + + if hasattr(nn, "bias") and nn.bias is not None: + y[mask] += nn.bias + + return y + +def __graph_initialization(module, data, *args, **kwargs): + module.asy_graph = data.clone() + module.graph_out = data.clone() + + # Concat old and updated feature for output feature vector. + if hasattr(module.asy_graph, "active_clusters"): + mask = module.asy_graph.active_clusters + num_updated_elements = len(mask) + else: + mask = slice(None) + num_updated_elements = len(data.x) + + module.graph_out.x = __conv(data.x, data.edge_index, data.edge_attr, mask, module) + + # If required, compute the flops of the asynchronous update operation. Therefore, sum the flops for each node + # update, as they highly depend on the number of neighbors of this node. 
+ if module.asy_flops_log is not None: + flops = compute_flops_conv(module, num_times_apply_bias_and_root=num_updated_elements, num_edges=data.edge_index.shape[1]) + module.asy_flops_log.append(flops) + + if hasattr(module, "to_dense"): + mask = module.graph_out.active_clusters + batch = module.graph_out.batch if module.graph_out.batch is None else module.graph_out.batch[mask] + if batch is None: + batch = torch.zeros(len(module.graph_out.pos[mask]), dtype=torch.long, device=data.x.device) + return module.to_dense(module.graph_out.x[mask], + module.graph_out.pos[mask], + module.graph_out.pooling, + batch) + + return module.graph_out.clone() + +def __edges_with_src_node(node_idx, edge_index, edge_attr=None, node_idx_type="src", return_changed_edges=False, return_mask=False): + if node_idx.numel() == 0: + outputs = [torch.empty(size=(2,0), dtype=torch.long, device=node_idx.device)] + if edge_attr is not None: + outputs.append(torch.empty(size=(0,3), dtype=edge_attr.dtype, device=edge_attr.device)) + if return_mask: + outputs.append(torch.empty(size=(0,), dtype=torch.bool, device=node_idx.device)) + if len(outputs) == 1: + outputs = outputs[0] + return outputs + + if node_idx_type == "src": + mask = torch_isin(edge_index[0], node_idx) + elif node_idx_type == "dst": + mask = torch_isin(edge_index[1], node_idx) + elif node_idx_type == "both": + mask = torch_isin(edge_index[0], node_idx) | torch_isin(edge_index[1], node_idx) + else: + raise ValueError + + output = [edge_index[:,mask]] + if edge_attr is not None: + output.append(edge_attr[mask]) + if return_changed_edges: + output.append(mask.nonzero().ravel()) + if return_mask: + output.append(mask.nonzero().ravel()) + if len(output) == 1: + output = output[0] + return output + +def find_only_x(idx_new_comp, idx_diff, pos_idx_diff, edge): + return idx_new_comp[torch_isin(idx_new_comp, idx_diff) & ~torch_isin(idx_new_comp, pos_idx_diff) & ~torch_isin(idx_new_comp, edge)] + +def __graph_processing(module, data, *args, **kwargs): + """Asynchronous graph update for graph convolutional layer. + + After the initialization of the graph, only the nodes (and their receptive field) have to updated which either + have changed (different features) or have been added. Therefore, for updating the graph we have to first + compute the set of "diff" and "new" nodes to then do the convolutional message passing on this subgraph, + and add the resulting residuals to the graph. + + :param x: graph nodes features. 
+ """ + num_edges_image_feat = 0 + num_edges = 0 + num_times_apply_bias_and_root = 0 + new_nodes = len(data.x) > len(module.asy_graph.x) + + # first update the input graph + if new_nodes: + idx_new = graph_new_nodes(module.asy_graph, data) + + module.asy_graph.x = torch.cat([module.asy_graph.x, data.x[idx_new]]) + idx_new_comp = idx_new + + # when new edges are added through added events, make sure to add them, otherwise only update the edge attributes + module.asy_graph.edge_index = torch.cat([module.asy_graph.edge_index, data.edge_index], dim=-1) + module.asy_graph.edge_attr = torch.cat([module.asy_graph.edge_attr, data.edge_attr], dim=0) + + zero_row = torch.zeros(len(idx_new), module.out_channels, device=data.x.device) + module.graph_out.x = torch.cat([module.graph_out.x, zero_row]) + + data.diff_idx = idx_new_comp + pos_idx_diff = torch.zeros(size=(0,), dtype=torch.long, device=data.x.device) + + if idx_new_comp.numel() > 0: + edge_index_new, edge_attr_new = data.edge_index, data.edge_attr + num_edges += edge_index_new.shape[1] + else: + idx_diff, pos_idx_diff = graph_changed_nodes(module.asy_graph, data) + idx_new_comp = _efficient_cat_unique([pos_idx_diff, idx_diff, data.edge_index[1].unique()]) + data.diff_idx = idx_new_comp + + if idx_new_comp.numel() > 0: + # find out dests of idx new, idx diff and pos_idx_diff + edge_index_update_message, mask = __edges_with_src_node(idx_new_comp, module.asy_graph.edge_index, return_mask=True) + edge_attr_update_message = module.asy_graph.edge_attr[mask] + num_edges += edge_index_update_message.shape[1] + if hasattr(module.asy_graph, "active_clusters") and hasattr(data, "_changed_attr"): + module.asy_graph.edge_attr[data._changed_attr_indices] = data._changed_attr + edge_attr_update_message_new = module.asy_graph.edge_attr[mask] + else: + edge_attr_update_message_new = edge_attr_update_message + + # when new edges are added through added events, make sure to add them, otherwise only update the edge attributes + if data.edge_index.numel() > 0: + module.asy_graph.edge_index = torch.cat([module.asy_graph.edge_index, data.edge_index], dim=-1) + module.asy_graph.edge_attr = torch.cat([module.asy_graph.edge_attr, data.edge_attr], dim=0) + + if idx_new_comp.numel() > 0 and edge_index_update_message.numel() > 0: + # first compute update to y + x_old = module.asy_graph.x[edge_index_update_message[0], :] + phi_old = module.message(x_old, edge_attr=edge_attr_update_message) + + # new messages + x_new = data.x[edge_index_update_message[0], :] + phi_new = module.message(x_new, edge_attr=edge_attr_update_message_new) + scatter_sum(phi_new-phi_old, index=edge_index_update_message[1],out=module.graph_out.x, dim=0, dim_size=len(module.graph_out.x)) + + data.diff_idx = _efficient_cat_unique([data.diff_idx, edge_index_update_message[1]]) + num_edges += edge_index_update_message.shape[1] + + only_x = find_only_x(idx_new_comp, idx_diff, pos_idx_diff, data.edge_index[1]) + if only_x is not None and len(only_x) > 0: + idx_new_comp = idx_new_comp[~torch_isin(idx_new_comp, only_x)] + generalized_lin(module, data.x - module.asy_graph.x, module.graph_out.x, only_x) + num_times_apply_bias_and_root += len(only_x) + + if idx_new_comp.numel() > 0: + + # edge and attrs for newly computed + edge_index_new, edge_attr_new = __edges_with_src_node(idx_new_comp, edge_index=module.asy_graph.edge_index, + edge_attr=module.asy_graph.edge_attr, + node_idx_type="dst") + + edge_index_pos, _ = __edges_with_src_node(pos_idx_diff, edge_index=module.asy_graph.edge_index, + 
edge_attr=module.asy_graph.edge_attr, + node_idx_type="dst") + num_edges_image_feat = edge_index_pos.shape[1] + + num_edges += edge_index_new.shape[1] + module.graph_out.x[idx_new_comp] = 0 + + if idx_new_comp.numel() > 0: + if edge_index_new.shape[1] > 0: + num_edges += edge_index_new.shape[1] + # next compute all messages for computing new index + x_j = data.x[edge_index_new[0, :], :] + phi = module.message(x_j, edge_attr=edge_attr_new[:,:module.dim]) + scatter_sum(phi, out=module.graph_out.x, index=edge_index_new[1], dim=0, dim_size=len(module.graph_out.x)) + + num_times_apply_bias_and_root += len(idx_new_comp) + generalized_lin(module, data.x, module.graph_out.x, idx_new_comp) + + data.x = module.graph_out.x + data.diff_pos_idx = pos_idx_diff + + # If required, compute the flops of the asynchronous update operation. Therefore, sum the flops for each node + # update, as they highly depend on the number of neighbors of this node. + if module.asy_flops_log is not None: + cat = hasattr(data, "skipped") and data.skipped + data.skipped = False + flops = compute_flops_conv(module, num_times_apply_bias_and_root=len(idx_new_comp), num_edges=num_edges, + concatenation=cat, num_image_channels=getattr(data, "num_image_channels", -1)) + + if cat: + flops += compute_flops_cat(module, num_edges=num_edges_image_feat, + num_times_apply_bias_and_root=num_times_apply_bias_and_root, num_image_channels=getattr(data, "num_image_channels", -1)) + + + module.asy_flops_log.append(flops) + + if hasattr(module, "to_dense"): + if pos_idx_diff.numel() > 0 or idx_new_comp.numel() > 0: + mask = data.active_clusters + batch = data.batch if data.batch is None else data.batch[mask] + if batch is None: + batch = torch.zeros(len(module.graph_out.pos[mask]), dtype=torch.long, device=data.x.device) + + return module.to_dense(data.x[mask], + data.pos[mask], + data.pooling, + batch) + else: + return module.dense[:1] + + return data + +def generalized_lin(module, input, output, idx): + uses_bias = hasattr(module, "bias") and module.bias is not None + uses_weight = hasattr(module, "root_weight") and module.root_weight + if not uses_weight: + return + + if uses_bias: + asy_tools.masked_lin(idx, input, output, module.lin.weight.data, module.bias.data, True) + else: + asy_tools.masked_lin_no_bias(idx, input, output, module.lin.weight.data, True) + +def __check_support(module) -> bool: + if isinstance(module, torch_geometric.nn.conv.GCNConv): + if module.normalize is True: + raise NotImplementedError("GCNConvs with normalization are not yet supported!") + return True + + +def make_conv_asynchronous(module, log_flops: bool = False): + """Module converter from synchronous to asynchronous & sparse processing for graph convolutional layers. + By overwriting parts of the module asynchronous processing can be enabled without the need of re-learning + and moving its weights and configuration. So, a convolutional layer can be converted by, for example: + + ``` + module = GCNConv(1, 2) + module = make_conv_asynchronous(module) + ``` + + :param module: convolutional module to transform. + :param r: update radius around new events. + :param edge_attributes: function for computing edge attributes (default = None). + :param is_initial: layer initial layer of sequential or deeper (default = False). + :param log_flops: log flops of asynchronous update. 
+ """ + assert __check_support(module) + + module = add_async_graph(module, log_flops=log_flops) + return make_asynchronous(module, __graph_initialization, __graph_processing) diff --git a/src/dagr/asynchronous/evaluate_flops.py b/src/dagr/asynchronous/evaluate_flops.py new file mode 100644 index 0000000..4a2b413 --- /dev/null +++ b/src/dagr/asynchronous/evaluate_flops.py @@ -0,0 +1,261 @@ +import torch + +from torch_geometric.data import Batch, Data +from typing import List, Tuple +from collections import OrderedDict + +from . import make_model_asynchronous, make_model_synchronous + + +def split_data(data: Data, index: int)->Tuple[Data, Data]: + kwargs = dict(time_window=data.time_window, width=data.width, height=data.height) + + if hasattr(data, "image"): + kwargs['image'] = data.image + + data1 = Data(pos=data.pos[:index], x=data.x[:index], **kwargs) + data2 = Data(pos=data.pos[index:], x=data.x[index:], **kwargs) + + if hasattr(data, "pos_denorm"): + data1.pos_denorm = data.pos_denorm[:index] + data2.pos_denorm = data.pos_denorm[index:] + + return data1, data2 + +def forward_hook(inst, inp, out): + inp = inp[0] + + if type(inp) is list: + inp = inp[0].clone() + elif type(inp) is tuple or type(inp) is dict: + return + else: + inp = inp.clone() + + if type(out) is list: + out = out[0].clone() + elif type(out) is tuple or type(out) is dict: + return + else: + out = out.clone() + + if not hasattr(inst, "activations"): + inst.activations = [] + + if type(inp) is torch.Tensor: + inp = inp if len(inp.shape) == 2 else inp[0] + inp = Data(x=inp) + if type(out) is torch.Tensor: + out = out if len(out.shape) == 2 else out[0] + out = Data(x=out) + + if hasattr(inp, "active_clusters") and not hasattr(out, "active_clusters"): + out.active_clusters = inp.active_clusters + elif hasattr(out, "active_clusters") and not hasattr(inp, "active_clusters"): + inp.active_clusters = out.active_clusters + + inp = _mask_if_possible(inp) + out = _mask_if_possible(out) + + inst.activations.append((inp, out)) + +def _mask_if_possible(data): + mask = slice(None, None, None) + if hasattr(data,"active_clusters") and len(data.x) > data.active_clusters.max(): + mask = data.active_clusters + masked = Data() + if hasattr(data, "x"): + masked.x = data.x[mask] + if hasattr(data, "pos") and data.pos is not None: + masked.pos = data.pos[mask] + if hasattr(data, "edge_index"): + masked.edge_index = data.edge_index + masked.edge_attr = data.edge_attr + return masked + +def denorm(data): + denorm = torch.tensor([int(data.width), int(data.height), int(data.time_window)], device=data.pos.device) + data.pos_denorm = (denorm.view(1,-1) * data.pos + 1e-3).int() + data.batch = data.batch.int() + return data + +def evaluate_flops(model: torch.nn.Module, batch: Data, dense=False, + check_consistency=False, + return_all_samples=False) -> OrderedDict: + + flops_per_layer_batch = [] + + # for loop over batch + for i, data in enumerate(batch.to_data_list()): + events_initial, events_new = split_data(data, -1) + + events_initial = Batch.from_data_list([events_initial]) + events_new = Batch.from_data_list([events_new]) + data = Batch.from_data_list([data]) + + # prepare data for fast inference + data = denorm(data) + events_new = denorm(events_new) + events_initial = denorm(events_initial) + + # make a deep copy asynchronous version + handles = [] + if check_consistency: + for m in model.modules(): + handle = m.register_forward_hook(forward_hook) + handles.append(handle) + + with torch.no_grad(): + model.forward(data, reset=True, 
return_targets=False) + + model = make_model_asynchronous(model, log_flops=True) + + try: + with torch.no_grad(): + model.forward(events_initial, reset=True, return_targets=False) + model.forward(events_new, reset=False, return_targets=False) + + except Exception as e: + print(f"Crashed at index {i} with message {e}") + raise e + + index = 0 if dense else 1 + flops_per_layer = OrderedDict( + [ + (name, module.asy_flops_log[index]) for name, module in model.named_modules() \ + if hasattr(module, "asy_flops_log") and module.asy_flops_log is not None and len( + module.asy_flops_log) > 0 + ] + ) + + flops_per_layer = _filter_non_leaf_nodes(flops_per_layer) + flops_per_layer = _merge_to_level_flops(flops_per_layer, level=3) + + if not check_consistency: + flops_per_layer_batch.append(flops_per_layer) + + model = make_model_synchronous(model) + + if check_consistency: + # tests if outputs from 0th and 2nd run are equal + max_mistake_x_layer, max_mistake_pos_layer, global_summary = test_and_compare_activations(model, runs=[0,2]) + if max_mistake_x_layer[0] > 1e-3 or max_mistake_pos_layer[1] > 1e-3: + print(global_summary) + print(f"AssertionError(Failed at index {i}.)") + else: + flops_per_layer_batch.append(flops_per_layer) + print(global_summary) + + for handle in handles: + handle.remove() + for m in model.modules(): + if hasattr(m, "activations"): + del m.activations + + if len(flops_per_layer_batch) == 0: + return None + + # global average + flops_per_layer = _merge_list_flops(flops_per_layer_batch) + + output = {"flops_per_layer": flops_per_layer, "total_flops": sum(flops_per_layer.values())} + if return_all_samples: + output['flops_per_layer_batch'] = flops_per_layer_batch + + return output + +def _filter_non_leaf_nodes(flops_per_layer: OrderedDict)->OrderedDict: + filter_keys = [] + for q_name in flops_per_layer: + for name in flops_per_layer: + if q_name in name and q_name != name: + filter_keys.append(q_name) + break + for f in filter_keys: + flops_per_layer.pop(f) + return flops_per_layer + +def _merge_to_level_flops(flops_per_layer: OrderedDict, level=2)->OrderedDict: + known_flops = [] + known_keys = [] + for name, flops in flops_per_layer.items(): + layers = name.split(".") + layers_up_to_level = ".".join(layers[:level]) + if layers_up_to_level not in known_keys: + known_keys.append(layers_up_to_level) + known_flops.append(0) + index = known_keys.index(layers_up_to_level) + known_flops[index] += flops + + return OrderedDict(zip(known_keys, known_flops)) + +def _merge_list_flops(flops_per_layer_batch: List[OrderedDict])->OrderedDict: + return OrderedDict([(key, sum([f[key] for f in flops_per_layer_batch]) / len(flops_per_layer_batch)) for key in flops_per_layer_batch[0]]) + +def _summary(est, gt, prefix): + if len(est) != len(gt): + return "\tCannot compare since x do not have same length\n", None + max_diff, max_rel_diff, ind, max_ind = max_abs_diff(gt, est, threshold=1e-6) + + summary = f"\t{prefix} MAX DIFF: {max_diff} MAX REL DIFF: {max_rel_diff}\n" + if ind.numel() > 0: + summary += f"\t{prefix} IND: {max_ind.cpu().numpy().ravel().tolist()}\n" + return summary, max_diff + + +def max_rel_diff(x, y, threshold=None): + return error_above_threshold((x-y).abs() / (x.abs()+1e-6), threshold) + +def error_above_threshold(error, mag, threshold): + if threshold is None: + return error.max() + else: + error_ravel = error.ravel() + arg = error_ravel.argmax() + return error_ravel[arg], error_ravel[arg] / mag.ravel()[arg], (error > threshold).nonzero()[:,0].unique(), 
error.max(-1).values.argmax() + +def max_abs_diff(x, y, threshold=None, alpha=0): + error = (x-y).abs()-x.abs()*alpha + return error_above_threshold(error, x.abs(), threshold) + +def _print_summary_for_one(target, estimate, prefix=""): + max_diff_pos = None + if type(target) is torch.Tensor: + summary, max_diff_x = _summary(target, estimate, prefix) + else: + summary = "" + if target.pos is not None and estimate.pos is not None: + sub_summary, max_diff_pos = _summary(target.pos[:,:2], estimate.pos[:,:2], f"{prefix} POS") + summary += sub_summary + + sub_summary, max_diff_x = _summary(target.x, estimate.x, prefix=f"{prefix} X") + summary += sub_summary + + return summary, max_diff_x, max_diff_pos + +def print_summary_of_module(activations, runs=[0,2]): + target, estimate = [activations[i][1] for i in runs] + return _print_summary_for_one(target, estimate, "OUT") + +def test_and_compare_activations(model, runs=[0,2]): + num_mistakes = [] + global_summary = "" + for name, module in model.named_modules(): + if not hasattr(module, "activations"): + continue + else: + if len(module.activations) <= max(runs): + continue + + summary, max_diff_x, max_diff_pos = print_summary_of_module(module.activations, runs) + if max_diff_x is not None and max_diff_pos is not None: + num_mistakes.append([max_diff_x, max_diff_pos, name]) + global_summary += f"Inspecting {name}\n{summary}\n\n" + + max_mistake_x_layer = max(num_mistakes, key=lambda x: x[0]) + max_mistake_pos_layer = max(num_mistakes, key=lambda x: x[1]) + global_summary += f"Maximum mistakes: \n" \ + f"\t{max_mistake_x_layer}\n" \ + f"\t{max_mistake_pos_layer}" + + return max_mistake_x_layer, max_mistake_pos_layer, global_summary diff --git a/src/dagr/asynchronous/flops/__init__.py b/src/dagr/asynchronous/flops/__init__.py new file mode 100644 index 0000000..4ec3e1f --- /dev/null +++ b/src/dagr/asynchronous/flops/__init__.py @@ -0,0 +1,36 @@ +import logging +from torch.nn import ModuleList + +from .conv import compute_flops_conv, compute_flops_cat + + +def compute_flops_from_module(module) -> int: + """Compute flops from a GNN module (after the forward pass). + + Generally, there are two cases. Either the module is asynchronous, in which case it has an + `asy_flops_log` containing the flops used for the last forward pass. Otherwise, the module is + treated as contributing zero flops. + + :param module: module to infer the flops from. + """ + module_name = module.__class__.__name__ + + if hasattr(module, "asy_flops_log") and module.asy_flops_log is not None: + assert type(module.asy_flops_log) == list, "async. 
flops log must be a list" + if type(module) is ModuleList: + flops = sum([compute_flops_from_module(layer) for layer in module._modules.values()]) + else: + assert len(module.asy_flops_log) > 0, f"asynchronous flops log is empty for module {module.__class__.__name__}" + flops = module.asy_flops_log[-1] + else: + logging.debug(f"Module {module_name} is not asynchronous, using flops = 0") + return 0 + + logging.debug(f"Module {module_name} adds {flops} flops") + return flops + + +__all__ = [ + "compute_flops_conv", + "compute_flops_from_module" +] diff --git a/src/dagr/asynchronous/flops/conv.py b/src/dagr/asynchronous/flops/conv.py new file mode 100644 index 0000000..7d6e915 --- /dev/null +++ b/src/dagr/asynchronous/flops/conv.py @@ -0,0 +1,37 @@ +import torch + + +def compute_flops_conv(module: torch.nn.Module, num_times_apply_bias_and_root: int, num_edges: int, concatenation=False, num_image_channels=-1) -> int: + # Iterate over every different and every new node, and add the number of flops introduced + # by the node to the overall flops count of the layer. + ni = num_edges + + m_in = module.in_channels + + if concatenation: + m_in -= num_image_channels + + m_out = module.out_channels + + flops = ni * (2*m_in-1) * m_out + + if hasattr(module, "root_weight") and module.root_weight: + flops += num_times_apply_bias_and_root * module.lin.weight.shape[0] * (2*module.lin.weight.shape[1]-1) + + if hasattr(module, "bias") and module.bias is not None: + flops += num_times_apply_bias_and_root * module.lin.weight.shape[0] + + return flops + + +def compute_flops_cat(module, num_edges, num_times_apply_bias_and_root, num_image_channels): + ni = num_edges + m_in = num_image_channels + m_out = module.out_channels + + flops = ni * (2 * m_in - 1) * m_out + + if hasattr(module, "root_weight") and module.root_weight: + flops += num_times_apply_bias_and_root * module.lin.weight.shape[0] * (2*m_in-1) + + return flops \ No newline at end of file diff --git a/src/dagr/asynchronous/linear.py b/src/dagr/asynchronous/linear.py new file mode 100644 index 0000000..cd1e9f0 --- /dev/null +++ b/src/dagr/asynchronous/linear.py @@ -0,0 +1,88 @@ +import numpy as np +import torch +import torch_geometric +import asy_tools + +from torch.nn import Linear +import torch.nn.functional as F +from .base.base import make_asynchronous, add_async_graph +from .base.utils import graph_new_nodes, graph_changed_nodes + + +def __graph_initialization(module: Linear, data) -> torch.Tensor: + mask = data.active_clusters if hasattr(data, "active_clusters") else slice(None, None, None) + x = data.x[mask] + weight = module.mlp.weight + bias = module.mlp.bias + + y = torch.zeros(size=(len(data.x), weight.shape[0]), dtype=torch.float32, device=data.pos.device) + y[mask] = F.linear(x, weight, bias) + + module.asy_graph = data.clone() + module.graph_out = torch_geometric.data.Data(x=y, pos=data.pos) + if hasattr(data, "active_clusters"): + module.graph_out.active_clusters = data.active_clusters + + if module.asy_flops_log is not None: + flops = int(np.prod(x.size()) * y.size()[-1]) + module.asy_flops_log.append(flops) + + return module.graph_out.clone() + +def __graph_processing(module: Linear, data) -> torch.Tensor: + if len(module.asy_graph.x) < len(data.x): + diff_idx = graph_new_nodes(module.asy_graph, data) + diff_pos_idx = diff_idx.clone() + module.graph_out.x = torch.cat([module.graph_out.x, torch.zeros_like(module.graph_out.x[:len(diff_idx)])]) + else: + diff_idx, diff_pos_idx = graph_changed_nodes(module.asy_graph, data) + + weight = 
module.mlp.weight + bias = module.mlp.bias + + # Update the graph with the new values (only there where it has changed). + if diff_idx.numel() > 0: + if bias is not None: + asy_tools.masked_lin(diff_idx, data.x, module.graph_out.x, weight.data, bias.data, False) + else: + asy_tools.masked_lin_no_bias(diff_idx, data.x, module.graph_out.x, weight.data, False) + + # If required, compute the flops of the asynchronous update operation. + if module.asy_flops_log is not None: + cin = weight.shape[1] + cat = hasattr(data, "skipped") and data.skipped + data.skipped = False + + if cat: + cin -= data.num_image_channels + + flops = diff_idx.numel() * int(weight.shape[0] * (2*cin-1)) + flops += diff_idx.numel() * weight.shape[0] + module.asy_flops_log.append(flops) + + data.diff_idx = diff_idx + data.diff_pos_idx = diff_pos_idx + data.x = module.graph_out.x + + return data + +def __check_support(module: Linear): + return True + + +def make_linear_asynchronous(module: Linear, log_flops: bool = False): + """Module converter from synchronous to asynchronous & sparse processing for linear layers. + By overwriting parts of the module asynchronous processing can be enabled without the need of re-learning + and moving its weights and configuration. So, a linear layer can be converted by, for example: + + ``` + module = Linear(4, 2) + module = make_linear_asynchronous(module) + ``` + + :param module: linear module to transform. + :param log_flops: log flops of asynchronous update. + """ + assert __check_support(module) + module = add_async_graph(module, log_flops=log_flops) + return make_asynchronous(module, __graph_initialization, __graph_processing) diff --git a/src/dagr/asynchronous/max_pool.py b/src/dagr/asynchronous/max_pool.py new file mode 100644 index 0000000..7053ff2 --- /dev/null +++ b/src/dagr/asynchronous/max_pool.py @@ -0,0 +1,273 @@ +import logging +import torch + +from torch_geometric.data import Data +from torch_scatter import scatter_max, scatter_sum + +from .base.base import add_async_graph, make_asynchronous +from .base.utils import graph_changed_nodes, graph_new_nodes, _efficient_cat_unique, torch_isin, _efficient_cat +from .conv import __edges_with_src_node +from .base.utils import _to_hom, _from_hom, __remove_duplicate_from_A + + +def pool_edge(cluster, edge_index, self_loop): + edge_index = cluster[edge_index] + if self_loop: + edge_index = edge_index.unique(dim=-1) + else: + edge_index = edge_index[:,edge_index[0]!=edge_index[1]].unique(dim=-1) + + if len(edge_index) > 0: + return edge_index + return torch.zeros((2,0), dtype=torch.long, device=cluster.device) + + +def compute_attrs(transform, edge_index, pos): + return (pos[edge_index[0]] - pos[edge_index[1]]) / (2 * transform.max) + 0.5 + + +def __dense_process(module, data: Data, *args, **kwargs) -> Data: + # compute the cache to compute the output graph. This contains + # 1. the cluster assignment for each input feature -> dim num_input_nodes + # 2. the sum of positions for each feature in each cluster -> max_num_clusters + # 3. the count of positions for each feature -> max_num_clusters + # 4. 
which input nodes went to the computation of which output_node -> max_num_clusters x num_output + cluster_index = __get_global_cluster_index(module, pos=data.pos[:,:module.dim]) + x, pos = data.x, data.pos + edge_index = pool_edge(cluster_index, data.edge_index, module.self_loop) + + if hasattr(module.asy_graph, "active_clusters"): + active_cluster_index = cluster_index[module.asy_graph.active_clusters] + new_cluster_index = torch.full_like(cluster_index, fill_value=-1) + new_cluster_index[module.asy_graph.active_clusters] = active_cluster_index + cluster_index = new_cluster_index + x = x[module.asy_graph.active_clusters] + pos = pos[module.asy_graph.active_clusters] + else: + active_cluster_index = cluster_index + + pos_hom = scatter_sum(_to_hom(pos[:,:module.dim]), active_cluster_index, dim=0, dim_size=module.num_grid_cells) + output_pos = _from_hom(pos_hom) + + module.wh_inv = 1/ torch.Tensor([data.width[0], data.height[0]]).to(output_pos.device).view(1,-1) + output_pos[:,:2] = module.round_to_pixel(output_pos[:,:2], wh_inv=module.wh_inv) + + active_clusters = torch.unique(active_cluster_index) + + cache = Data(cluster_index=cluster_index, pos_hom=pos_hom) + + if module.aggr == 'max': + output_x = torch.full(size=(module.num_grid_cells, x.shape[1]), fill_value=-torch.inf, device=x.device) + _, output_argmax = scatter_max(x, active_cluster_index, dim=0, out=output_x, dim_size=module.num_grid_cells) + cache.output_argmax = output_argmax + else: + x_hom = _to_hom(x) + cache.output_x_hom = scatter_sum(x_hom, active_cluster_index, dim=0, dim_size=module.num_grid_cells) + output_x = _from_hom(cache.output_x_hom) + + module.ones = torch.ones_like(output_x[:,:1]) + + # construct output. This contains: + # the output graph -> has num_unique_clusters nodes + if module.keep_temporal_ordering: + t = pos[:, -1] if pos.shape[-1] > 2 else data.t_max[active_cluster_index] + output_t = torch.full(size=(module.num_grid_cells,), fill_value=-torch.inf, device=x.device) + t_max, _ = scatter_max(t, active_cluster_index, dim=0, out=output_t, dim_size=module.num_grid_cells) + if edge_index.shape[1] > 0: + t_src, t_dst = t_max[edge_index] + edge_index = edge_index[:, t_dst > t_src] + + output_graph = Data(x=output_x, + pos=output_pos, + edge_index=edge_index, + active_clusters=active_clusters, + width=data.width, + height=data.height) + + if module.keep_temporal_ordering: + output_graph.t_max = output_t + + if module.transform is not None: + output_graph = module.transform(output_graph) + + return output_graph, cache + +def __graph_initialization(module, data: Data, *args, **kwargs) -> Data: + """Graph initialization for asynchronous update. + + Both the input as well as the output graph have to be stored, in order to avoid repeated computation. The + input graph is used for spotting changed or new nodes (as for other asyn. layers), while the output graph + is compared to the set of diff & new nodes, in order to be updated. Depending on the type of pooling (max, mean, + average, etc) not only the output voxel feature have to be stored but also aggregations over all nodes in + one output voxel such as the sum or count. + + Next to the features the node positions are averaged over all nodes in the voxel, as well. To do so, + position aggregations (count, sum) are stored and updated, too. 
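+
+     The caching idea, reduced to a toy example (a standalone sketch with made-up tensors, not this
+     module's API): keep the per-voxel reduction cached, and when a node changes re-reduce only the
+     voxels it feeds.
+
+     ```
+     import torch
+     from torch_scatter import scatter_max
+
+     x = torch.tensor([[1.], [3.], [2.], [5.]])                 # node features
+     cluster = torch.tensor([0, 0, 1, 1])                        # voxel assignment per node
+     voxel_max, _ = scatter_max(x, cluster, dim=0, dim_size=2)   # cached output: [[3.], [5.]]
+
+     x[1] = 0.5                                                  # one node changes
+     dirty = torch.tensor([0])                                   # voxels that must be refreshed
+     mask = torch.isin(cluster, dirty)                           # nodes feeding those voxels
+     voxel_max[dirty] = -torch.inf                               # reset, then re-reduce only those
+     scatter_max(x[mask], cluster[mask], dim=0, out=voxel_max)
+     ```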
+ """ + module.asy_graph = data.clone() + module.graph_out, module.cache = __dense_process(module, data) + module.graph_out.pooling = module.voxel_size + + logging.debug(f"Resulting in coarse graph {module.graph_out}") + + # Compute number of floating point operations (no cat, flatten, etc.). + if module.asy_flops_log is not None: + unique_clusters = len(module.graph_out.active_clusters) + flops = 6 * unique_clusters # pos and scatter with index + flops += module.graph_out.x.shape[1] * unique_clusters + module.graph_out.edge_index.numel() # every edge has to be re-assigned + module.asy_flops_log.append(flops) + + return module.graph_out.clone() + +#@profile +def __graph_process(module, data, *args, **kwargs) -> Data: + new_nodes = len(data.x) > len(module.asy_graph.x) + + if new_nodes: + new_idx = graph_new_nodes(module.asy_graph, data) + + module.asy_graph.x = torch.cat([module.asy_graph.x, data.x[new_idx]]) + module.asy_graph.pos = torch.cat([module.asy_graph.pos, data.pos[new_idx]]) + + new_cluster_idx = __get_global_cluster_index(module, data.pos[new_idx, :module.dim]) + + # add to active clusters + if new_idx.numel() > 0: + module.graph_out.active_clusters = torch.cat([new_cluster_idx, module.graph_out.active_clusters]).sort().values.unique() + + module.cache.cluster_index = torch.cat([module.cache.cluster_index, new_cluster_idx]) + diff_pos_idx = new_idx + new_pos_hom = _to_hom(data.pos[new_idx, :module.dim], module.ones) + recomp_pos_new = new_cluster_idx + recomp_x_new = new_cluster_idx + if recomp_x_new.numel() > 0: + recomp_x_new = recomp_x_new#.clone() + + num_diff_x = 0#len(diff_idx) + num_new = len(new_idx) + scatter_sum(new_pos_hom, new_cluster_idx, out=module.cache.pos_hom, dim=0) + + if recomp_x_new.numel() > 0: + if module.aggr == "max": + mask = torch.cat([module.cache.output_argmax[recomp_x_new].ravel(), new_idx]).unique() + else: + mask = torch_isin(module.cache.cluster_index, recomp_x_new) + + else: + num_new = 0 + + diff_idx, diff_pos_idx = graph_changed_nodes(module.asy_graph, data) + num_diff_x = len(diff_idx) + + recomp_x_new = None + recomp_pos_new = None + + if diff_pos_idx.numel()> 0: + inactive = torch_isin(diff_pos_idx, module.asy_graph.active_clusters) + old_pos = module.asy_graph.pos[diff_pos_idx[inactive], :module.dim] + module.asy_graph.pos[diff_pos_idx] = data.pos[diff_pos_idx] + + old_pos_hom = _to_hom(old_pos, module.ones) + old_cluster_idx_pos = __get_global_cluster_index(module, old_pos) + new_pos_hom = _to_hom(data.pos[diff_pos_idx, :module.dim], module.ones) + all_pos = torch.cat([-old_pos_hom, new_pos_hom]) + new_cluster_idx_pos = __get_global_cluster_index(module, data.pos[diff_pos_idx, :module.dim]) + module.cache.cluster_index[diff_pos_idx] = new_cluster_idx_pos + + recomp_x_new = new_cluster_idx_pos + recomp_pos_new = _efficient_cat([old_cluster_idx_pos, new_cluster_idx_pos]) + # todo stupid bug, shallow copy could occur + if recomp_pos_new.numel()>0 and recomp_pos_new.data_ptr() == recomp_x_new.data_ptr(): + recomp_pos_new = recomp_pos_new#.clone() + scatter_sum(all_pos, recomp_pos_new, out=module.cache.pos_hom, dim=0) + + if diff_idx.numel() > 0: + cluster_idx_x = __get_global_cluster_index(module, module.asy_graph.pos[diff_idx, :module.dim]) + recomp_x_new = cluster_idx_x if recomp_x_new is None else _efficient_cat_unique([recomp_x_new, cluster_idx_x]) + + if recomp_x_new is not None and recomp_x_new.numel() > 0: + mask = torch_isin(module.cache.cluster_index, recomp_x_new) + if module.aggr == "max": + module.graph_out.x[recomp_x_new] = 
-torch.inf + + if recomp_x_new is not None and recomp_x_new.numel() > 0: + if module.aggr == "max": + scatter_max(data.x[mask], module.cache.cluster_index[mask], out=module.graph_out.x, dim=0) + else: + delta_x_hom = _to_hom(data.x[mask], module.ones) # + valid = ~torch.isinf(module.asy_graph.x[mask][:,0]) + delta_x_hom[valid] -= _to_hom(module.asy_graph.x[mask][valid], module.ones) + scatter_sum(delta_x_hom, module.cache.cluster_index[mask], out=module.cache.output_x_hom, dim=0) + module.graph_out.x[recomp_x_new] = _from_hom(module.cache.output_x_hom[recomp_x_new]) + + # find the edges which are associated with changed positions since these need their attrs updated + # however, here we can only look at the x,y values. If only the third attr changes, then we don't need to do anything + if recomp_pos_new is not None and recomp_pos_new.numel() > 0: + # update pos with the updated positions + new_pos = _from_hom(module.cache.pos_hom[recomp_pos_new]) + new_pos[:,:2] = module.round_to_pixel(new_pos[:,:2], wh_inv=module.wh_inv) + module.graph_out.pos[recomp_pos_new,:module.dim] = new_pos + update_edge_index, changed_edges = __edges_with_src_node(recomp_pos_new, module.graph_out.edge_index, node_idx_type="both", return_changed_edges=True) + if module.transform is not None: + module.graph_out._changed_attr = compute_attrs(module.transform, update_edge_index, module.graph_out.pos) + module.graph_out._changed_attr_indices = changed_edges + + # also handle edges which come from new connections at the input. These first need to be pooled + # then check if they are actually new. + if data.edge_index.numel() > 0: + coarse_edge_index = pool_edge(module.cache.cluster_index, data.edge_index, module.self_loop) + module.graph_out.edge_index = __remove_duplicate_from_A(coarse_edge_index, module.graph_out.edge_index) + else: + module.graph_out.edge_index = data.edge_index#torch.empty((2, 0), dtype=torch.long, device=data.x.device) + + if module.transform is not None: + if module.graph_out.edge_index.numel() > 0: + module.graph_out.edge_attr = compute_attrs(module.transform, module.graph_out.edge_index, module.graph_out.pos) + else: + module.graph_out.edge_attr = data.edge_attr + + module.graph_out.diff_idx = recomp_x_new.unique() if recomp_x_new is not None else diff_idx + module.graph_out.diff_pos_idx = recomp_pos_new.unique() if recomp_pos_new is not None else diff_pos_idx + + if module.asy_flops_log is not None: + num_recomp_x = 0 if recomp_x_new is None else len(recomp_x_new) + num_recomp_pos = 0 if recomp_pos_new is None else len(recomp_pos_new) + flops = 0 + flops += num_recomp_x * module.graph_out.x.shape[1] # perform max + flops += num_recomp_pos # recompute pos + flops += 4 * len(diff_pos_idx) # subtract and add pos twice + flops += len(diff_pos_idx) + num_diff_x # get cluster center for each index + flops += num_new * 2 # add twice, also compute cluster center + module.asy_flops_log.append(flops) + + return module.graph_out + +def __get_global_cluster_index(module, pos) -> torch.LongTensor: + n_pos_dim = 2#pos.shape[1] + voxel_size = module.voxel_size[:n_pos_dim]#, device=pos.device) + pos_vertex = (pos[:,:2] / voxel_size).long() + x_v, y_v = pos_vertex.t() + grid_size = (1 / voxel_size + 1e-3).long() + cluster_idx = x_v + grid_size[0] * y_v + return cluster_idx + + +def make_max_pool_asynchronous(module, log_flops: bool = False): + """Module converter from synchronous to asynchronous & sparse processing for graph max pooling layer. 
+ By overwriting parts of the module asynchronous processing can be enabled without the need re-creating the + object. So, a max pooling layer can be converted by, for example: + + ``` + module = MaxPool([4, 4]) + module = make_max_pool_asynchronous(module) + ``` + + :param module: standard max pooling module. + :param grid_size: grid size (grid starting at 0, spanning to `grid_size`), >= `size`. + :param r: update radius around new events. + :param log_flops: log flops of asynchronous update. + """ + + module = add_async_graph(module, log_flops=log_flops) + module = make_asynchronous(module, __graph_initialization, __graph_process) + return module diff --git a/src/dagr/data/augment.py b/src/dagr/data/augment.py new file mode 100644 index 0000000..96630aa --- /dev/null +++ b/src/dagr/data/augment.py @@ -0,0 +1,289 @@ +import torch + +from torch_geometric.transforms import BaseTransform +from torch_geometric.data import Data +from typing import List + +import cv2 +import numpy as np +import numba +import torch_geometric.transforms as T + + +@numba.njit +def _add_event(x, y, xlim, ylim, p, i, count, pos, mask, threshold=1): + count[ylim, xlim] += float(p * (1 - abs(x - xlim)) * (1 - abs(y - ylim))) + pol = 1 if count[ylim, xlim] > 0 else -1 + + if pol * count[ylim, xlim] > threshold: + count[ylim, xlim] -= pol * threshold + + mask[i] = True + pos[i, 0] = xlim + pos[i, 1] = ylim + + +@numba.njit +def _subsample(pos: np.ndarray, polarity: np.ndarray, mask: np.ndarray, count: np.ndarray, threshold=1): + for i in range(len(pos)): + x, y = pos[i] + x0, x1 = int(x), int(x+1) + y0, y1 = int(y), int(y+1) + + _add_event(x, y, x0, y0, polarity[i,0], i=i, count=count, pos=pos, mask=mask, threshold=threshold) + _add_event(x, y, x1, y0, polarity[i,0], i=i, count=count, pos=pos, mask=mask, threshold=threshold) + _add_event(x, y, x0, y1, polarity[i,0], i=i, count=count, pos=pos, mask=mask, threshold=threshold) + _add_event(x, y, x1, y1, polarity[i,0], i=i, count=count, pos=pos, mask=mask, threshold=threshold) + + +def _crop_events(data, left, right, not_crop_idx=None): + if not_crop_idx is None: + not_crop_idx = torch.all((data.pos >= left) & (data.pos <= right), dim=1) + + data.x = data.x[not_crop_idx] + data.pos = data.pos[not_crop_idx] + + if hasattr(data, "t"): + data.t = data.t[not_crop_idx] + + return data + +def _crop_image(image, left, right): + xmin, ymin = left + xmax, ymax = right + image[:ymin, :] = 0 + image[ymax:, :] = 0 + image[:, :xmin] = 0 + image[:, xmax:] = 0 + return image + +def _resize_image(image, height, width, bg=None): + new_image = cv2.resize(image, (width, height), interpolation=cv2.INTER_NEAREST) + px = (new_image.shape[1] - image.shape[1])//2 + py = (new_image.shape[0] - image.shape[0])//2 + + if px >= 0: + bg = new_image[py:py+image.shape[0], px:px+image.shape[1]] + else: + assert bg is not None + bg[-py:-py+new_image.shape[0], -px:-px+new_image.shape[1]] = new_image + + return bg + +def _crop_bbox(bbox: torch.Tensor, left: torch.Tensor, right: torch.Tensor): + bbox = bbox.clone() + bbox[:,2:4] += bbox[:,:2] + bbox[:,:2] = torch.clamp(bbox[:,:2], min=left, max=right) + bbox[:,2:4] = torch.clamp(bbox[:,2:4], min=left, max=right) + bbox[:,2:4] -= bbox[:,:2] + return bbox + +def _scale_and_clip(x, scale): + return int(torch.clamp(x * scale, min=0, max=scale-1)) + + +class RandomHFlip(BaseTransform): + def __init__(self, p: float): + self.p = p + + def __call__(self, data: Data): + if torch.rand(1) > self.p: + return data + + data.pos[:,0] = data.width - 1 - data.pos[:,0] + + if 
hasattr(data, "image"): + data.image = np.ascontiguousarray(data.image[:,::-1]) + + if hasattr(data, "bbox"): + data.bbox[:, 0] = data.width - 1 - (data.bbox[:, 0] + data.bbox[:, 2]) + + if hasattr(data, "bbox0"): + data.bbox0[:, 0] = data.width - 1 - (data.bbox0[:, 0] + data.bbox0[:, 2]) + + return data + + +class Crop(BaseTransform): + r"""Crop with max and min values, has to be called before a graph is generated. + + Args: + min (List[float]): min value per dimension + max (List[float]): max value per dimension + """ + def __init__(self, min: List[float], max: List[float]): + self.min = torch.as_tensor(min) + self.max = torch.as_tensor(max) + + def init(self, height, width): + size = [width, height] + self.max = torch.IntTensor([_scale_and_clip(m, s) for m, s in zip(self.max, size)]) + self.min = torch.IntTensor([_scale_and_clip(m, s) for m, s in zip(self.min, size)]) + + def __call__(self, data: Data): + data = _crop_events(data, self.min, self.max) + + if hasattr(data, "image"): + data.image = _crop_image(data.image, self.min, self.max) + + # crop bbox to dimension + if hasattr(data, "bbox"): + data.bbox = _crop_bbox(data.bbox, self.min, self.max) + + if hasattr(data, "bbox0"): + data.bbox0 = _crop_bbox(data.bbox0, self.min, self.max) + + return data + + +class RandomZoom(BaseTransform): + def __init__(self, zoom, subsample=False): + self.zoom = zoom + self.subsample = subsample + self.image = None + + if subsample: + self._count = None + + def _subsample(self, data, zoom, count): + pos_zoom = data.pos.numpy() + + mask = np.zeros(len(data.pos), dtype="bool") + _subsample(pos_zoom, data.x.numpy(), mask, count, threshold=1/(float(zoom)**2)) + + data.pos = torch.from_numpy(pos_zoom[mask].astype("int16")) # implicit cast to int + data.x = data.x[mask] + if hasattr(data, "t"): + data.t = data.t[mask] + + return data + + def init(self, height, width): + self.image = np.zeros((height, width, 3), dtype="uint8") + self._count = np.zeros((height + 1, width + 1), dtype="float32") + + def __call__(self, data): + zoom = torch.rand(1) * (self.zoom[1] - self.zoom[0]) + self.zoom[0] + width, height = int(np.ceil(data.width * zoom)), int(np.ceil(data.height * zoom)) + H, W = self.image.shape[:2] + + data.pos[:, 0] = ((data.pos[:, 0] - W // 2) * zoom + W // 2).to(torch.int16) + data.pos[:, 1] = ((data.pos[:, 1] - H // 2) * zoom + H // 2).to(torch.int16) + + if self.subsample and zoom < 1: + data = self._subsample(data, float(zoom), count=self._count.copy()) + + if hasattr(data, "image"): + data.image = _resize_image(data.image, width=width, height=height, bg=self.image.copy() if zoom < 1 else None) + + if hasattr(data, "bbox"): + data.bbox[:,2:4] *= zoom + data.bbox[:,0] = ((data.bbox[:,0] - W//2) * zoom + W//2) + data.bbox[:,1] = ((data.bbox[:,1] - H//2) * zoom + H//2) + + if hasattr(data, "bbox0"): + data.bbox0[:,2:4] *= zoom + data.bbox0[:,0] = ((data.bbox0[:,0] - W//2) * zoom + W//2) + data.bbox0[:,1] = ((data.bbox0[:,1] - H//2) * zoom + H//2) + + return data + + +class RandomCrop(BaseTransform): + r"""Random crop, assumes all pos values are in [0,1] + + Args: + size (List[float]): crop size per dimension + dim (List[int]): dimension of the crop, default = [0,1] + p float: only to random crop with a probability of p + """ + def __init__(self, size: List[float] = [0.75, 0.75], dim: List[int]=[0,1], p=0.5): + self.size = torch.as_tensor(size) + self.dim = dim + self.p = p + + def init(self, height, width): + size = torch.IntTensor([width, height]) + self.size = torch.IntTensor([_scale_and_clip(s, ss) 
for s, ss in zip(self.size, size)]) + self.left_max = size - self.size + + def __call__(self, data: Data): + if torch.rand(1) > self.p: + return data + + left = (torch.rand(len(self.dim)) * self.left_max).to(torch.int16) + right = left + self.size + + data = _crop_events(data, left, right) + + if hasattr(data, "image"): + data.image = _crop_image(data.image, left, right) + + # crop bbox to new crop dimension + if hasattr(data, "bbox"): + data.bbox = _crop_bbox(data.bbox, left, right) + + if hasattr(data, "bbox0"): + data.bbox0 = _crop_bbox(data.bbox0, left, right) + + return data + + +class RandomTranslate(BaseTransform): + r"""Random crop, assumes all pos values are in [0,1] + + Args: + size (float): crop size per dimension + dim (int): dimension of the crop, default = [0,1] + """ + def __init__(self, size: List[float]): + self.size = torch.as_tensor(size).float() + self.image = None + + def init(self, height, width): + size = [width, height] + self.size = torch.IntTensor([_scale_and_clip(s, ss) for s, ss in zip(self.size, size)]) + self.image = np.zeros((height + 2 * self.size[1], width + 2 * self.size[0], 3), dtype="uint8") + + def pad(self, image, bg): + px = (bg.shape[1] - image.shape[1])//2 + py = (bg.shape[0] - image.shape[0])//2 + bg[py:py + image.shape[0], px:px + image.shape[1]] = image + return bg + + def __call__(self, data: Data): + move_px = (self.size * (torch.rand(len(self.size)) * 2 - 1)).to(torch.int16) + data.pos = data.pos + move_px + + if hasattr(data, "image"): + image = self.pad(data.image, self.image.copy()) + data.image = image[self.size[1]-move_px[1]:self.size[1]-move_px[1]+data.height, \ + self.size[0]-move_px[0]:self.size[0]-move_px[0]+data.width] + + if hasattr(data, "bbox"): + data.bbox[:,:2] += move_px + + if hasattr(data, "bbox0"): + data.bbox0[:,:2] += move_px + + return data + + +class Augmentations: + transform_testing = T.Compose([ + Crop([0, 0], [1, 1]), + ]) + + def __init__(self, args): + self.transform_training = T.Compose([ + RandomHFlip(p=args.aug_p_flip), + RandomCrop([0.75, 0.75], p=0.2), + RandomZoom(zoom=[1, args.aug_zoom], subsample=True), + RandomTranslate([args.aug_trans, args.aug_trans, 0]), + Crop([0, 0], [1, 1]), + ]) + +def init_transforms(transforms, height, width): + for t in transforms: + if hasattr(t, "init"): + t.init(height=height, width=width) \ No newline at end of file diff --git a/src/dagr/data/dsec_data.py b/src/dagr/data/dsec_data.py new file mode 100644 index 0000000..3870673 --- /dev/null +++ b/src/dagr/data/dsec_data.py @@ -0,0 +1,182 @@ +from pathlib import Path +from typing import Optional, Callable + +from torch_geometric.data import Dataset + +import numpy as np +import cv2 + +import torch +from functools import lru_cache + +from dsec_det.dataset import DSECDet + +from dsec_det.io import yaml_file_to_dict +from dagr.data.dsec_utils import filter_tracks, crop_tracks, rescale_tracks, compute_class_mapping, map_classes, filter_small_bboxes +from dsec_det.directory import BaseDirectory +from dagr.data.augment import init_transforms +from dagr.data.utils import to_data + + +def tracks_to_array(tracks): + return np.stack([tracks['x'], tracks['y'], tracks['w'], tracks['h'], tracks['class_id']], axis=1) + + + +def interpolate_tracks(detections_0, detections_1, t): + assert len(detections_1) == len(detections_0) + if len(detections_0) == 0: + return detections_1 + + t0 = detections_0['t'][0] + t1 = detections_1['t'][0] + + assert t0 < t1 + + # need to sort detections + detections_0 = 
detections_0[detections_0['track_id'].argsort()] + detections_1 = detections_1[detections_1['track_id'].argsort()] + + r = ( t - t0 ) / ( t1 - t0 ) + detections_out = detections_0.copy() + for k in 'xywh': + detections_out[k] = detections_0[k] * (1 - r) + detections_1[k] * r + + return detections_out + +class EventDirectory(BaseDirectory): + @property + @lru_cache + def event_file(self): + return self.root / "left/events_2x.h5" + + +class DSEC(Dataset): + MAPPING = dict(pedestrian="pedestrian", rider=None, car="car", bus="car", truck="car", bicycle=None, + motorcycle=None, train=None) + def __init__(self, + root: Path, + split: str, + transform: Optional[Callable]=None, + debug=False, + min_bbox_diag=0, + min_bbox_height=0, + scale=2, + cropped_height=430, + only_perfect_tracks=False, + demo=False): + + Dataset.__init__(self) + + split_config = None + if not demo: + split_config = yaml_file_to_dict(Path(__file__).parent / "dsec_split.yaml") + assert split in split_config.keys(), f"'{split}' not in {list(split_config.keys())}" + + self.dataset = DSECDet(root=root, split=split, sync="back", debug=debug, split_config=split_config) + + for directory in self.dataset.directories.values(): + directory.events = EventDirectory(directory.events.root) + + self.scale = scale + self.width = self.dataset.width // scale + self.height = cropped_height // scale + self.classes = ("car", "pedestrian") + self.time_window = 1000000 + self.min_bbox_height = min_bbox_height + self.min_bbox_diag = min_bbox_diag + self.debug = debug + self.num_us = -1 + + self.class_remapping = compute_class_mapping(self.classes, self.dataset.classes, self.MAPPING) + + if transform is not None and hasattr(transform, "transforms"): + init_transforms(transform.transforms, self.height, self.dataset.width) + + self.transform = transform + + self.image_index_pairs, self.track_masks = filter_tracks(dataset=self.dataset, image_width=self.width, + image_height=self.height, + class_remapping=self.class_remapping, + min_bbox_height=min_bbox_height, + min_bbox_diag=min_bbox_diag, + only_perfect_tracks=only_perfect_tracks, + scale=scale) + + def set_num_us(self, num_us): + self.num_us = num_us + + def __len__(self): + return sum(len(d) for d in self.image_index_pairs.values()) + + def preprocess_detections(self, detections): + detections = rescale_tracks(detections, self.scale) + detections = crop_tracks(detections, self.width, self.height) + detections['class_id'], _ = map_classes(detections['class_id'], self.class_remapping) + return detections + + def preprocess_events(self, events): + mask = events['y'] < self.height + events = {k: v[mask] for k, v in events.items()} + if len(events['t']) > 0: + events['t'] = self.time_window + events['t'] - events['t'][-1] + events['p'] = 2 * events['p'].reshape((-1,1)).astype("int8") - 1 + return events + + def preprocess_image(self, image): + image = image[:self.scale * self.height] + image = cv2.resize(image, (self.width, self.height), interpolation=cv2.INTER_CUBIC) + image = torch.from_numpy(image).permute(2, 0, 1) + image = image.unsqueeze(0) + return image + + def __getitem__(self, idx): + dataset, image_index_pairs, track_masks, idx = self.rel_index(idx) + image_index_0, image_index_1 = image_index_pairs[idx] + image_ts_0, image_ts_1 = dataset.images.timestamps[[image_index_0, image_index_1]] + + detections_0 = self.dataset.get_tracks(image_index_0, mask=track_masks, directory_name=dataset.root.name) + detections_1 = self.dataset.get_tracks(image_index_1, mask=track_masks, 
directory_name=dataset.root.name) + + detections_0 = self.preprocess_detections(detections_0) + detections_1 = self.preprocess_detections(detections_1) + + image_0 = self.dataset.get_image(image_index_0, directory_name=dataset.root.name) + image_0 = self.preprocess_image(image_0) + + events = self.dataset.get_events(image_index_0, directory_name=dataset.root.name) + + if self.num_us >= 0: + image_ts_1 = image_ts_0 + self.num_us + events = {k: v[events['t'] < image_ts_1] for k, v in events.items()} + detections_1 = interpolate_tracks(detections_0, detections_1, image_ts_1) + + # here, the timestamp of the events is no longer absolute + events = self.preprocess_events(events) + + # convert to torch geometric data + data = to_data(**events, bbox=tracks_to_array(detections_1), bbox0=tracks_to_array(detections_0), t0=image_ts_0, t1=image_ts_1, + width=self.width, height=self.height, time_window=self.time_window, + image=image_0, sequence=str(dataset.root.name)) + + if self.transform is not None: + data = self.transform(data) + + # remove bboxes if they have 0 width or height + mask = filter_small_bboxes(data.bbox[:, 2], data.bbox[:, 3], self.min_bbox_height, self.min_bbox_diag) + data.bbox = data.bbox[mask] + mask = filter_small_bboxes(data.bbox0[:, 2], data.bbox0[:, 3], self.min_bbox_height, self.min_bbox_diag) + data.bbox0 = data.bbox0[mask] + + return data + + def rel_index(self, idx): + for folder in self.dataset.subsequence_directories: + name = folder.name + image_index_pairs = self.image_index_pairs[name] + directory = self.dataset.directories[name] + track_mask = self.track_masks[name] + if idx < len(image_index_pairs): + return directory, image_index_pairs, track_mask, idx + idx -= len(image_index_pairs) + raise IndexError \ No newline at end of file diff --git a/src/dagr/data/dsec_split.yaml b/src/dagr/data/dsec_split.yaml new file mode 100644 index 0000000..5e6d63d --- /dev/null +++ b/src/dagr/data/dsec_split.yaml @@ -0,0 +1,63 @@ +train: + - thun_00_a + - interlaken_00_c + - interlaken_00_d + - interlaken_00_e + - interlaken_00_f + - interlaken_00_g + - zurich_city_00_a + - zurich_city_00_b + - zurich_city_01_a + - zurich_city_01_b + - zurich_city_01_c + - zurich_city_01_d + - zurich_city_01_e + - zurich_city_01_f + - zurich_city_02_a + - zurich_city_02_b + - zurich_city_02_c + - zurich_city_02_d + - zurich_city_02_e + - zurich_city_03_a + - zurich_city_04_a + - zurich_city_04_b + - zurich_city_04_c + - zurich_city_04_d + - zurich_city_04_e + - zurich_city_04_f + - zurich_city_05_a + - zurich_city_05_b + - zurich_city_06_a + - zurich_city_07_a + - zurich_city_08_a + - zurich_city_09_a + - zurich_city_09_b + - zurich_city_09_c + - zurich_city_09_d + - zurich_city_09_e + - zurich_city_10_a + - zurich_city_10_b + - zurich_city_11_a + - zurich_city_11_b + - zurich_city_11_c +val: + - zurich_city_16_a + - zurich_city_17_a + - zurich_city_18_a + - zurich_city_19_a + - zurich_city_20_a + - zurich_city_21_a +test: + - thun_01_a + - thun_01_b + - thun_02_a + - interlaken_00_a + - interlaken_00_b + - interlaken_01_a + - zurich_city_12_a + - zurich_city_13_a + - zurich_city_13_b + - zurich_city_14_a + - zurich_city_14_b + - zurich_city_14_c + - zurich_city_15_a \ No newline at end of file diff --git a/src/dagr/data/dsec_utils.py b/src/dagr/data/dsec_utils.py new file mode 100644 index 0000000..78c2fd3 --- /dev/null +++ b/src/dagr/data/dsec_utils.py @@ -0,0 +1,191 @@ +import numpy as np +import h5py + + +def construct_pairs(indices, n=2): + indices = np.sort(indices) + indices = 
np.stack([indices[i:i+1-n] for i in range(n-1)] + [indices[n-1:]]) + mask = np.ones_like(indices[0]) > 0 + for i, row in enumerate(indices): + mask = mask & (indices[0] + i == row) + indices = indices[...,mask].T + return indices + +def rescale_tracks(tracks, scale): + tracks = tracks.copy() + for k in "xywh": + tracks[k] /= scale + return tracks + +def crop_tracks(tracks, width, height): + tracks = tracks.copy() + x1, y1 = tracks['x'], tracks['y'] + x2, y2 = x1 + tracks['w'], y1 + tracks['h'] + + x1 = np.clip(x1, 0, width-1) + x2 = np.clip(x2, 0, width-1) + + y1 = np.clip(y1, 0, height-1) + y2 = np.clip(y2, 0, height-1) + + tracks['x'] = x1 + tracks['y'] = y1 + tracks['w'] = x2-x1 + tracks['h'] = y2-y1 + + return tracks + +def map_classes(class_ids, old_to_new_mapping): + new_class_ids = old_to_new_mapping[class_ids] + mask = new_class_ids > -1 + return new_class_ids, mask + +def filter_small_bboxes(w, h, bbox_height=20, bbox_diag=30): + """ + Filter out tracks that are too small. + """ + diag = np.sqrt(h ** 2 + w ** 2) + return (diag > bbox_diag) & (w > bbox_height) & (h > bbox_height) + +def filter_tracks(dataset, image_width, image_height, class_remapping, min_bbox_height=0, min_bbox_diag=0, scale=1, only_perfect_tracks=False): + image_index_pairs = {} + track_masks = {} + + for directory_path in dataset.subsequence_directories: + tracks = dataset.directories[directory_path.name].tracks.tracks + image_timestamps = dataset.directories[directory_path.name].images.timestamps + + tracks_rescaled = rescale_tracks(tracks, scale) + tracks_rescaled = crop_tracks(tracks_rescaled, image_width, image_height) + + _, class_mask = map_classes(tracks_rescaled['class_id'], class_remapping) + size_mask = filter_small_bboxes(tracks_rescaled['w'], tracks_rescaled['h'], min_bbox_height, min_bbox_diag) + final_mask = size_mask & class_mask + + # 1. stores indices of images which are valid, i.e. 
survived all filters above + valid_image_indices = np.unique(np.nonzero(np.isin(image_timestamps, tracks_rescaled[final_mask]['t']))[0]) + valid_image_index_pairs = construct_pairs(valid_image_indices, 2) + + if only_perfect_tracks: + valid_image_timestamp_brackets = image_timestamps[valid_image_index_pairs] + img_idx_to_track_idx = compute_img_idx_to_track_idx(tracks['t'], valid_image_timestamp_brackets) + mask = filter_by_only_perfect_tracks(tracks_rescaled, img_idx_to_track_idx, tracks_mask=final_mask) + valid_image_index_pairs = valid_image_index_pairs[mask] + + image_index_pairs[directory_path.name] = valid_image_index_pairs + track_masks[directory_path.name] = final_mask + + return image_index_pairs, track_masks + +def _load_events(file, t0, num_events=None, num_us=None, height=None, time_window=None): + with h5py.File(file, 'r') as f: + ms = int((t0 - f['t_offset'][()]) / 1e3) + idx0 = int(f['ms_to_idx'][ms]) + + if num_events is not None: + idx1 = idx0 + num_events + if num_us is not None: + idx1 = int(f['ms_to_idx'][ms + int(num_us / 1e3)]) + + idx0, idx1 = sorted([idx0, idx1]) + idx0 = idx0 if idx0 >= 0 else 0 + idx1 = idx1 if idx1 >= 0 else 0 + + # load all events + events = {k: f[f'events/{k}'][idx0:idx1] for k in "xytp"} + + tq = events['t'][-1] if idx1 > idx0 else f[f'events/t'][max([idx1 - 1, idx0])] + + # cast to desired types + p = 2 * events["p"][..., None].astype("int8") - 1 + t_ev = events['t'][..., None] + xy = np.stack([events['x'], events['y']], axis=-1).astype("int16") + + if time_window is not None: + t = (time_window - tq + t_ev).astype('int32') + else: + t = tq.copy() + + # we have to add the offset here + tq += f['t_offset'][()] + tq = tq.astype("int64") + + # crop events to crop height + mask = (t[:, 0] > 0) + if height is not None: + mask &= (xy[:, 1] < height) + + events = (xy[mask], t[mask], p[mask]) + + return events, tq + + +def filter_by_only_perfect_tracks(tracks, img_idx_to_track_idx, tracks_mask=None): + i0, i1 = img_idx_to_track_idx + mask = np.ones_like(i0[0]) > 0 + for i in range(i0.shape[1]): + track = [tracks[i0[j][i]:i1[j][i]] for j in range(len(i0))] + if tracks_mask is not None: + track_mask = [tracks_mask[i0[j][i]:i1[j][i]] for j in range(len(i0))] + track = [t[m] for t, m in zip(track, track_mask)] + mask[i] = not is_invalid_track(track) + return mask + +def is_invalid_track(track): + track = [tr[tr['track_id'].argsort()] for tr in track] + + i_tr = track[0] + for c_tr in track[1:]: + if len(i_tr) != len(c_tr): + return True + if not (c_tr['track_id'] == i_tr['track_id']).all(): + return True + iou = compute_iou(i_tr, c_tr) + min_iou = np.min(iou) + if min_iou < 0.10: + return True + else: + return False + +def compute_iou(track0, track1): + x1, x2 = track0['x'], track0['x'] + track0['w'] + y1, y2 = track0['y'], track0['y'] + track0['h'] + + x1g, x2g = track1['x'], track1['x'] + track1['w'] + y1g, y2g = track1['y'], track1['y'] + track1['h'] + + # Intersection keypoints + xkis1 = np.max(np.stack([x1, x1g]), axis=0) + ykis1 = np.max(np.stack([y1, y1g]), axis=0) + xkis2 = np.min(np.stack([x2, x2g]), axis=0) + ykis2 = np.min(np.stack([y2, y2g]), axis=0) + + intsct = np.zeros_like(x1) + mask = (ykis2 > ykis1) & (xkis2 > xkis1) + intsct[mask] = (xkis2[mask] - xkis1[mask]) * (ykis2[mask] - ykis1[mask]) + union = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g) - intsct + 1e-9 + iou = intsct / union + + return iou + + +def compute_indices_for_contiguous_parts(x): + x, counts = np.unique(x, return_counts=True) + idx = np.concatenate([np.array([0]), 
counts]).cumsum() + return np.stack([idx[:-1], idx[1:]], axis=-1) + +def _compute_img_idx_to_track_idx(t, t_query): + new_img_idx = compute_indices_for_contiguous_parts(t) + mask = np.isin(np.unique(t), t_query) + new_img_idx = new_img_idx[mask].T + return new_img_idx + +def compute_img_idx_to_track_idx(t, t_query): + return np.stack([_compute_img_idx_to_track_idx(t, t_q) for t_q in t_query.T]) + +def compute_class_mapping(classes, all_classes, mapping): + output_mapping = [] + for i, c in enumerate(all_classes): + mapped_class = mapping[c] + output_mapping.append(classes.index(mapped_class) if mapped_class in classes else -1) + return np.array(output_mapping) diff --git a/src/dagr/data/utils.py b/src/dagr/data/utils.py new file mode 100644 index 0000000..c34e953 --- /dev/null +++ b/src/dagr/data/utils.py @@ -0,0 +1,20 @@ +import numpy as np +import torch +from torch_geometric.data import Data + + +def to_data(**kwargs): + # convert all tracks to correct format + for k, v in kwargs.items(): + if k.startswith("bbox"): + kwargs[k] = torch.from_numpy(v) + + xy = np.stack([kwargs['x'], kwargs['y']], axis=-1).astype("int16") + t = kwargs['t'].astype("int32") + p = kwargs['p'].reshape((-1,1)) + + kwargs['x'] = torch.from_numpy(p) + kwargs['pos'] = torch.from_numpy(xy) + kwargs['t'] = torch.from_numpy(t) + + return Data(**kwargs) \ No newline at end of file diff --git a/src/dagr/graph/ev_graph.cu b/src/dagr/graph/ev_graph.cu new file mode 100755 index 0000000..a2eaca6 --- /dev/null +++ b/src/dagr/graph/ev_graph.cu @@ -0,0 +1,283 @@ +#include + +#include +#include +#include "spiral.h" +#include + + +#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) +#define CHECK_DEVICE(x, y) AT_ASSERTM(x.device().index() == y.device().index(), #x " and " #y " must be in same CUDA device") + + +__global__ void fill_edges_cuda_kernel( + const int32_t* __restrict__ batch, + const int32_t* __restrict__ pos, + const int32_t* __restrict__ all_timestamps, + const int32_t* __restrict__ indices, + const int32_t* __restrict__ event_queue, + int64_t* __restrict__ edges, + // int64_t* __restrict__ num_neighbors_array, + int B, int Q, int H, int W, int N, int K, float radius, float delta_t_us, int max_num_neighbors, int min_index +) +{ + // linear index + const int event_idx = blockIdx.x * blockDim.x + threadIdx.x; + + // check that thread is not out of valid range + if (event_idx >= N) + return; + + int radius_int = radius; + int num_neighbors = 0; + + int offset = event_idx * max_num_neighbors; + + int b = batch[event_idx]; + int x = pos[3 * event_idx + 0]; + int y = pos[3 * event_idx + 1]; + int ts_event = pos[3 * event_idx + 2]; + + // first add self edge + edges[offset + num_neighbors + K * 0] = indices[event_idx]-min_index; + edges[offset + num_neighbors + K * 1] = indices[event_idx]-min_index; + num_neighbors++; + + SpiralOut spiral; + for (int i=0; i= max_num_neighbors) break; + for (int q=0; q= 0) && (y_neighbor >= 0) && (x_neighbor < W) && (y_neighbor < H))) break; + + int64_t queue_idx = x_neighbor + W * y_neighbor + H * W * q + H * W * Q * b; + int idx = event_queue[queue_idx]; + + // break if exceeded max num neighbors or no more events in queue + if (idx < min_index) break; + + if (indices[event_idx] > idx) { + int32_t ts_neighbor = all_timestamps[idx-min_index]; + int32_t dt_us = ts_event - ts_neighbor; + + // if delta t is too large, no 
edge is added + if (dt_us > delta_t_us) continue; + + edges[offset + num_neighbors + K * 0] = idx-min_index; + edges[offset + num_neighbors + K * 1] = indices[event_idx]-min_index; + num_neighbors++; + if (num_neighbors >= max_num_neighbors) break; + } + } + spiral.goNext(); + } + //num_neighbors_array[event_idx] = num_neighbors; +} + +void fill_edges_cuda( + const torch::Tensor& batch, // N + const torch::Tensor& pos, // N x 3 + const torch::Tensor& all_timestamps, // N + const torch::Tensor& event_queue, // B x Q x H x W + const torch::Tensor& indices, // N + const int max_num_neighbors, + const float radius, + const float delta_t_us, + torch::Tensor& edges, // 2 x E + const int min_index + ) +{ + CHECK_INPUT(batch); + CHECK_INPUT(pos); + CHECK_INPUT(event_queue); + CHECK_INPUT(all_timestamps); + CHECK_INPUT(edges); + CHECK_INPUT(indices); + + CHECK_DEVICE(batch, event_queue); + CHECK_DEVICE(batch, pos); + CHECK_DEVICE(batch, edges); + CHECK_DEVICE(batch, indices); + CHECK_DEVICE(batch, all_timestamps); + + unsigned N = batch.size(0); + unsigned B = event_queue.size(0); + unsigned Q = event_queue.size(1); + unsigned H = event_queue.size(2); + unsigned W = event_queue.size(3); + unsigned K = edges.size(1); + + unsigned threads = 256; + dim3 blocks((N + threads - 1) / threads, 1); + + fill_edges_cuda_kernel<<>>( + batch.data(), + pos.data(), + all_timestamps.data(), + indices.data(), + event_queue.data(), + edges.data(), + //num_neighbors.data(), + B, Q, H, W, N, K, radius, delta_t_us, max_num_neighbors, min_index + ); +} + +template +__global__ void insert_in_queue_single_cuda_kernel( + const scalar_t* __restrict__ indices, + const scalar_t* __restrict__ events, + scalar_t* __restrict__ queue, + int B, int Q, int H, int W, int K +) +{ + // linear index + const int lin_idx = blockIdx.x * blockDim.x + threadIdx.x; + + // check that thread is not out of valid range + if (lin_idx >= K) + return; + + // find out how many events to write, and what is the offset + int counts = 1; + int offset = 0; + + // find out the x, y coords where to write the indices + int x = events[0]; + int y = events[1]; + int b = 0; + + // write indices. break if queue size or counter is exceeded + for (int q=Q-1; q>=0; q--) { + int index = b * H * W * Q + q * H * W + y * W + x; + // for the current position, get the one at q - shift. + // if q - shift goes in the negative, take from indices instead + if (q >= counts) { + int shifted_index = b * H * W * Q + (q-counts) * H * W + y * W + x; + queue[index] = queue[shifted_index]; + } else { + queue[index] = indices[offset + counts - 1 - q]; + } + } +} + + +template +__global__ void insert_in_queue_cuda_kernel( + const scalar_t* __restrict__ indices, + const scalar_t* __restrict__ unique_coords, + const scalar_t* __restrict__ cumsum_counts, + scalar_t* __restrict__ queue, + int B, int Q, int H, int W, int K +) +{ + // linear index + const int lin_idx = blockIdx.x * blockDim.x + threadIdx.x; + + // check that thread is not out of valid range + if (lin_idx >= K) + return; + + // find out how many events to write, and what is the offset + int counts, offset; + if (lin_idx > 0) { + offset = cumsum_counts[lin_idx-1]; + counts = cumsum_counts[lin_idx] - offset; + } else { + offset = 0; + counts = cumsum_counts[lin_idx]; + } + + // find out the x, y coords where to write the indices + int x = unique_coords[lin_idx] % W; + int y = ((unique_coords[lin_idx] - x)/ W) % H; + int b = unique_coords[lin_idx] / (W*H); + + // write indices. 
break if queue size or counter is exceeded + for (int q=Q-1; q>=0; q--) { + int index = b * H * W * Q + q * H * W + y * W + x; + // for the current position, get the one at q - shift. + // if q - shift goes in the negative, take from indices instead + if (q >= counts) { + int shifted_index = b * H * W * Q + (q-counts) * H * W + y * W + x; + queue[index] = queue[shifted_index]; + } else { + queue[index] = indices[offset + counts - 1 - q]; + } + } +} + + +torch::Tensor insert_in_queue_single_cuda( + const torch::Tensor& indices, // 1 + const torch::Tensor& events, // 4 x 1 + const torch::Tensor& queue // B x Q x H x W + ) +{ + unsigned W = queue.size(3); + unsigned H = queue.size(2); + unsigned Q = queue.size(1); + unsigned B = queue.size(0); + unsigned K = 1; + + unsigned threads = 256; + dim3 blocks((K + threads - 1) / threads, 1); + + insert_in_queue_single_cuda_kernel<<>>( + indices.data(), + events.data(), + queue.data(), + B, Q, H, W, K + ); + + return queue; +} + + +torch::Tensor insert_in_queue_cuda( + const torch::Tensor& indices, // N -> num events + const torch::Tensor& unique_coords, // K -> num active pixels + const torch::Tensor& cumsum_counts, // K -> num active pixels + const torch::Tensor& queue // B x Q x H x W + ) +{ + CHECK_INPUT(indices); + CHECK_INPUT(unique_coords); + CHECK_INPUT(cumsum_counts); + CHECK_INPUT(queue); + + CHECK_DEVICE(indices, queue); + CHECK_DEVICE(indices, unique_coords); + CHECK_DEVICE(indices, cumsum_counts); + CHECK_DEVICE(indices, queue); + + unsigned W = queue.size(3); + unsigned H = queue.size(2); + unsigned Q = queue.size(1); + unsigned B = queue.size(0); + unsigned K = unique_coords.size(0); + + unsigned threads = 256; + dim3 blocks((K + threads - 1) / threads, 1); + + insert_in_queue_cuda_kernel<<>>( + indices.data(), + unique_coords.data(), + cumsum_counts.data(), + queue.data(), + B, Q, H, W, K + ); + + return queue; +} + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("fill_edges_cuda", &fill_edges_cuda, "Find edges from a queue of events."); + m.def("insert_in_queue_cuda", &insert_in_queue_cuda, "Insert events into queue."); + m.def("insert_in_queue_single_cuda", &insert_in_queue_single_cuda, "Insert single events into queue."); +} diff --git a/src/dagr/graph/ev_graph.py b/src/dagr/graph/ev_graph.py new file mode 100755 index 0000000..28ec1ee --- /dev/null +++ b/src/dagr/graph/ev_graph.py @@ -0,0 +1,166 @@ +import torch +from .utils import _insert_events_into_queue, _search_for_edges + + +def move_to_cuda(func): + def wrapper(self, x, *args, **kwargs): + device = x.device + on_cpu = device == "cpu" + if on_cpu: + x = x.to("cuda") + ret = func(self, x, *args, **kwargs) + if on_cpu: + ret = ret.cpu() + return ret + return wrapper + + +class AsyncGraph: + def __init__(self, width=640, + height=480, + batch_size=1, + max_num_neighbors=16, + max_queue_size=512, + radius=7, + delta_t_us=600000): + self.radius = radius + self.delta_t_us = delta_t_us + self.event_queue = None + + self.max_index = 0 + self.min_index = 0 + self.max_queue_size = max_queue_size + self.max_num_neighbors = max_num_neighbors + self.width = width + self.height = height + self.batch_size = batch_size + self.device = None + + self.edges = torch.zeros((2,0), dtype=torch.long) + self.all_timestamps = torch.zeros((0,), dtype=torch.int32) + self.new_indices = None + self.edge_buffer = None + self.event_queue = None + + def initialize(self, n_ev, device): + self.edges = torch.zeros((2,0), dtype=torch.long, device=device) + self.all_timestamps = torch.zeros((0,), 
dtype=torch.int32, device=device) + self.new_indices = torch.arange(n_ev, dtype=torch.int32, device=device) + self.edge_buffer = torch.full((2, self.max_num_neighbors * n_ev), dtype=torch.int64, fill_value=-1, device=device) + self.event_queue = torch.full((self.batch_size, self.max_queue_size, self.height, self.width), fill_value=-1, device=device, dtype=torch.int32) + + def reset(self): + self.edges = torch.zeros((2,0), dtype=torch.long, device=self.device) + self.all_timestamps = torch.zeros((0,), dtype=torch.int32, device=self.device) + self.max_index = 0 + self.min_index = 0 + if self.edge_buffer is not None: + self.edge_buffer.fill_(-1) + if self.event_queue is not None: + self.event_queue.fill_(-1) + + @move_to_cuda + def forward(self, batch, pos, collect_edges=True): + n_ev = len(batch) + + if self.device is None: + self.device = batch.device + self.initialize(n_ev, self.device) + + if len(batch) == 0: + return torch.zeros((2,0), device=self.device, dtype=torch.int32) + + assert type(batch) is torch.Tensor and batch.dtype == torch.int32, [type(batch), batch.dtype] + + self.all_timestamps = torch.cat([self.all_timestamps, pos[:,2]]) + + # insert events into queue, they have an ever growing index + if n_ev > len(self.new_indices): + self.new_indices = torch.arange(0, n_ev, dtype=torch.int32, device=self.device) + self.edge_buffer = torch.full((2, self.max_num_neighbors * n_ev), dtype=torch.int64, fill_value=-1, device=self.device) + + indices = self.max_index + self.new_indices[:n_ev] + self.max_index += n_ev + + self.event_queue = _insert_events_into_queue(batch, pos, indices=indices, queue=self.event_queue) + + # read out edges from event queue, they need to correspond to indices + # from the current nodes + self.edge_buffer.fill_(-1) + edge_indices = _search_for_edges(batch, pos, + all_timestamps=self.all_timestamps.contiguous(), + indices=indices, + queue=self.event_queue, + max_num_neighbors=self.max_num_neighbors, + radius=self.radius, + delta_t_us=self.delta_t_us, + edges=self.edge_buffer, + min_index=self.min_index) + + if collect_edges: + self.edges = torch.cat([self.edges, edge_indices], dim=-1) + + return edge_indices + + +class SlidingWindowGraph(AsyncGraph): + def __init__(self, width=640, + height=480, + batch_size=1, + max_num_neighbors=16, + max_queue_size=1024, + radius=7, + delta_t_us=600000): + AsyncGraph.__init__(self, width, height, batch_size, max_num_neighbors, + max_queue_size, radius, delta_t_us) + + @property + def init(self): + return len(self.all_timestamps) > 0 + + def delete_nodes(self, n_delete, delete_edges=True, return_edges=True): + # delete nodes + self.all_timestamps = self.all_timestamps[n_delete:] + self.min_index += n_delete + + # the current edges do not correspond to + # the nodes anymore, so they need to be decremented + if delete_edges: + mask = (self.edges[0] < n_delete) | (self.edges[1] < n_delete) + deleted_edges = self.edges[:,mask].clone() + self.edges = self.edges[:,~mask] + + self.edges.add_(-n_delete) + + if delete_edges and return_edges: + return deleted_edges + + @move_to_cuda + def forward(self, batch, pos, return_node_counts=False, return_total_edges=False, delete_nodes=True, collect_edges=True): + n_delete = len(batch) if self.init else 0 + + # first find the interactions + edges = AsyncGraph.forward(self, batch, pos, collect_edges=collect_edges) + + if return_total_edges: + total_edges = self.edges.clone() + + if return_node_counts: + tot_nodes = len(self.all_timestamps) + + ret = [edges] + + if delete_nodes: + deleted_edges = 
self.delete_nodes(n_delete) + ret.append(deleted_edges) + + if return_total_edges: + ret.append(total_edges) + + if return_node_counts: + ret.append([n_delete, len(batch), tot_nodes]) + + if len(ret) == 1: + ret = ret[0] + + return ret diff --git a/src/dagr/graph/spiral.h b/src/dagr/graph/spiral.h new file mode 100644 index 0000000..46e2cc7 --- /dev/null +++ b/src/dagr/graph/spiral.h @@ -0,0 +1,16 @@ +class SpiralOut{ +protected: + unsigned layer; + unsigned leg; +public: + int x, y; //read these as output from next, do not modify. + __device__ SpiralOut():layer(1),leg(0),x(0),y(0){} + __device__ void goNext(){ + switch(leg){ + case 0: ++x; if(x == layer) ++leg; break; + case 1: ++y; if(y == layer) ++leg; break; + case 2: --x; if(-x == layer) ++leg; break; + case 3: --y; if(-y == layer){ leg = 0; ++layer; } break; + } + } +}; \ No newline at end of file diff --git a/src/dagr/graph/utils.py b/src/dagr/graph/utils.py new file mode 100644 index 0000000..609700c --- /dev/null +++ b/src/dagr/graph/utils.py @@ -0,0 +1,23 @@ +import torch +import ev_graph_cuda +from typing import Union + + +def _insert_events_into_queue(batch, pos, indices, queue: torch.LongTensor): + if len(batch) > 1: + height, width = queue.shape[-2:] + lin_coords = pos[:,0] + width * pos[:,1] + width*height*batch + sorted_lin_coords, sort_index = torch.sort(lin_coords, stable=True, descending=False) + sorted_indices = indices[sort_index].int() + unique_coords, unique_counter = torch.unique_consecutive(sorted_lin_coords, return_counts=True) + cumsum_counter = torch.cumsum(unique_counter, dim=0).int() + queue = ev_graph_cuda.insert_in_queue_cuda(sorted_indices, unique_coords, cumsum_counter, queue) + else: + queue = ev_graph_cuda.insert_in_queue_single_cuda(indices, pos, queue) + + return queue + +def _search_for_edges(batch, pos, all_timestamps, queue, indices, max_num_neighbors, radius, delta_t_us, edges, min_index): + ev_graph_cuda.fill_edges_cuda(batch, pos, all_timestamps, queue, indices, max_num_neighbors, radius, delta_t_us, edges, min_index) + edges = edges[:,(edges[1]>=0)] + return edges diff --git a/src/dagr/model/layers/components.py b/src/dagr/model/layers/components.py new file mode 100644 index 0000000..444cd0e --- /dev/null +++ b/src/dagr/model/layers/components.py @@ -0,0 +1,35 @@ +import torch + +from torch_geometric.nn import BatchNorm +from torch_geometric.data import Data + +import torch_geometric.transforms as T + + +class BatchNormData(BatchNorm): + def forward(self, data: Data): + data.x = BatchNorm.forward(self, data.x) + return data + + +class Linear(torch.nn.Module): + def __init__(self, ic, oc, bias=True): + torch.nn.Module.__init__(self) + self.mlp = torch.nn.Linear(ic, oc, bias=bias) + + def forward(self, data: Data): + data.x = self.mlp(data.x) + return data + + +class Cartesian(torch.nn.Module): + def __init__(self, *args, **kwargs): + super().__init__() + T.Cartesian.__init__(self, *args, **kwargs) + + def forward(self, data): + if data.edge_index.shape[1] > 0: + return T.Cartesian.__call__(self, data) + else: + data.edge_attr = torch.zeros((0, 3), dtype=data.x.dtype, device=data.x.device) + return data \ No newline at end of file diff --git a/src/dagr/model/layers/conv.py b/src/dagr/model/layers/conv.py new file mode 100644 index 0000000..b0c5d90 --- /dev/null +++ b/src/dagr/model/layers/conv.py @@ -0,0 +1,72 @@ +import torch + +from torch_geometric.data import Data + +from dagr.model.layers.components import BatchNormData, Linear +from dagr.model.layers.spline_conv import MySplineConv +from 
dagr.model.utils import shallow_copy + + +class ConvBlock(torch.nn.Module): + def __init__(self, in_channels: int, out_channels: int, args, degree=1) -> None: + super(ConvBlock, self).__init__() + self.dim = args.edge_attr_dim + self.activation = getattr(torch.nn.functional, args.activation, torch.nn.functional.elu) + self.conv = MySplineConv(in_channels=in_channels, + out_channels=out_channels, + args=args, + bias=False, + degree=degree) + + self.norm = BatchNormData(in_channels=out_channels) + + def forward(self, data: Data) -> torch.Tensor: + data = self.conv(data) + data = self.norm(data) + data.x = self.activation(data.x) + + return data + + +class ConvBlockWithSkip(torch.nn.Module): + def __init__(self, in_channel: int, out_channel: int, skip_in_channel: int, args) -> None: + super(ConvBlockWithSkip, self).__init__() + self.dim = args.edge_attr_dim + + self.conv = MySplineConv(in_channels=in_channel, + out_channels=out_channel, + args=args, + bias=False) + + self.activation = getattr(torch.nn.functional, args.activation, torch.nn.functional.elu) + self.norm = BatchNormData(in_channels=out_channel) + + self.lin = Linear(skip_in_channel, out_channel, bias=False) + self.norm_skip = BatchNormData(in_channels=out_channel) + + def forward(self, data: Data, data_skip: Data): + data = self.conv(data) + + data_skip = self.lin(data_skip) + data_skip = self.norm_skip(data_skip) + + data = self.norm(data) + data.x = self.activation(data.x + data_skip.x) + + return data + + +class Layer(torch.nn.Module): + def __init__(self, in_channels: int, out_channels: int, args) -> None: + super(Layer, self).__init__() + self.in_channel = in_channels + self.out_channel = out_channels + + self.conv_block1 = ConvBlock(in_channels, out_channels, args) + self.conv_block2 = ConvBlockWithSkip(out_channels, out_channels, in_channels, args=args) + + def forward(self, data: Data) -> torch.Tensor: + data_skip = shallow_copy(data) + data = self.conv_block1(data) + output = self.conv_block2(data, data_skip) + return output diff --git a/src/dagr/model/layers/ev_tgn.py b/src/dagr/model/layers/ev_tgn.py new file mode 100644 index 0000000..1cce4de --- /dev/null +++ b/src/dagr/model/layers/ev_tgn.py @@ -0,0 +1,59 @@ +import torch + +from torch_geometric.data import Batch, Data +from dagr.graph.ev_graph import SlidingWindowGraph + + +def _get_value_as_int(obj, key): + val = getattr(obj, key) + return val if type(val) is int else val[0] + +def denormalize_pos(events): + if hasattr(events, "pos_denorm"): + return events.pos_denorm + + denorm = torch.tensor([int(events.width[0]), int(events.height[0]), int(events.time_window[0])], device=events.pos.device) + return (denorm.view(1,-1) * events.pos + 1e-3).int() + + +class EV_TGN(torch.nn.Module): + def __init__(self, args): + torch.nn.Module.__init__(self) + self.radius = args.radius + self.max_neighbors = args.max_neighbors + self.max_queue_size = 128 + self.graph_creators = None + + def init_graph_creator(self, data): + delta_t_us = int(self.radius * _get_value_as_int(data, "time_window")) + radius = int(self.radius * _get_value_as_int(data, "width")+1) + batch_size = data.num_graphs + width = int(_get_value_as_int(data, "width")) + height = int(_get_value_as_int(data, "height")) + self.graph_creators = SlidingWindowGraph(width=width, height=height, + max_num_neighbors=self.max_neighbors, + max_queue_size=self.max_queue_size, + batch_size=batch_size, + radius=radius, delta_t_us=delta_t_us) + + def forward(self, events: Data, reset=True): + if events.batch is None: + events = 
Batch.from_data_list([events]) + + # before we start, are the new events used to generate the graph, or are the new nodes attached to the network? + # if the first, then don't delete old events, if the second, delete as many events as are coming in. + if self.graph_creators is None: + self.init_graph_creator(events) + else: + if reset: + self.graph_creators.reset() + + pos = denormalize_pos(events) + #pos = torch.cat([events.batch.view(-1,1), pos, events.x.int()], dim=1).int() + # properties of the edges + # src_i <= dst_i + # dst_i <= dst_j if i 0: + self.bn = BatchNormData(in_channels) + + @property + def num_grid_cells(self): + return (1/self.voxel_size+1e-3).int().prod() + + def round_to_pixel(self, pos, wh_inv): + torch.div(pos+1e-5, wh_inv, out=pos, rounding_mode='floor') + return pos * wh_inv + + def forward(self, data: Data): + if data.x.shape[0] == 0: + return data + + pos = torch.cat([data.pos, data.batch.float().view(-1,1)], dim=-1) + cluster = grid_cluster(pos, size=self.voxel_size, start=self.start, end=self.end) + unique_clusters, cluster, perm, _ = consecutive_cluster(cluster) + edge_index = cluster[data.edge_index] + if self.self_loop: + edge_index = edge_index.unique(dim=-1) + else: + edge_index = edge_index[:, edge_index[0]!=edge_index[1]] + if edge_index.shape[1] > 0: + edge_index = edge_index.unique(dim=-1) + + batch = None if data.batch is None else data.batch[perm] + pos = None if data.pos is None else pool_pos(cluster, data.pos) + + if self.keep_temporal_ordering: + t_max, _ = torch_scatter.scatter_max(data.pos[:,-1], cluster, dim=0) + t_src, t_dst = t_max[edge_index] + edge_index = edge_index[:, t_dst > t_src] + + if self.aggr == 'max': + x, argmax = torch_scatter.scatter_max(data.x, cluster, dim=0) + else: + x = _avg_pool_x(cluster, data.x) + + new_data = Batch(batch=batch, x=x, edge_index=edge_index, pos=pos) + + if hasattr(data, "height"): + new_data.height = data.height + new_data.width = data.width + + # round x and y coordinates to the center of the voxel grid + new_data.pos[:,:2] = self.round_to_pixel(new_data.pos[:,:2], wh_inv=self.wh_inv) + + if self.transform is not None: + if new_data.edge_index.numel() > 0: + new_data = self.transform(new_data) + else: + new_data.edge_attr = torch.zeros(size=(0,pos.shape[1]), dtype=pos.dtype, device=pos.device) + + if self.bn is not None: + new_data = self.bn(new_data) + + return new_data diff --git a/src/dagr/model/layers/spline_conv.py b/src/dagr/model/layers/spline_conv.py new file mode 100644 index 0000000..6a4c853 --- /dev/null +++ b/src/dagr/model/layers/spline_conv.py @@ -0,0 +1,118 @@ +import torch + +from torch_geometric.nn.conv import SplineConv +from torch_geometric.data import Data +from torch_geometric.transforms.to_sparse_tensor import ToSparseTensor +from torch_spline_conv import spline_basis + + +class MySplineConv(SplineConv): + def __init__(self, in_channels, out_channels, args, bias=False, degree=1, **kwargs): + self.reproducible = True + self.to_sparse_tensor = ToSparseTensor(attr="edge_attr", remove_edge_index=False) + super().__init__(in_channels=in_channels, out_channels=out_channels, bias=bias, degree=degree, + dim=args.edge_attr_dim, aggr=args.aggr, kernel_size=args.kernel_size) + + def init_lut(self, height, width, rx=None, Mx=None, ry=None, My=None): + # attr is assumed to be computed as attr = (x_i - x_j)/(2M) + 0.5 + # where -r <= x_i - x_j <= r. So remapping to integers gives + # lut_index = 2M*attr - M + r. 
and 0 <= lut_index <= 2r + + ry = ry or rx + My = My or Mx + self.attr_remapping_matrix = torch.Tensor([[2 * Mx * width, 0, - Mx * width + rx], + [ 0, 2 * My * height, - My * height + ry]]) + + # generate all possible dx, dy + dxy = torch.stack(torch.meshgrid(torch.arange(-rx, rx+1), torch.arange(-ry, ry+1))).float() + dxy[0] = dxy[0] / (2 * Mx * width) + 0.5 + dxy[1] = dxy[1] / (2 * My * height) + 0.5 + edge_attr = dxy.view((2,-1)).t() + + bil_w, indices = spline_basis(edge_attr.to(self.weight.data.device), self.kernel_size, self.is_open_spline, self.degree) + lut_weights = (bil_w[...,None,None] * self.weight[indices]).sum(1) + _, cin, cout = lut_weights.shape + self.lut_weights = lut_weights.view((2 * rx + 1, 2 * ry + 1, cin, cout)) + + self.message = self.message_lut + + def message_lut(self, x_j, edge_attr): + # index = (attr - 0.5) * 2 * M + r + dx_index = (edge_attr[:,0] * self.attr_remapping_matrix[0,0] + self.attr_remapping_matrix[0,-1]+1e-3).long() + dy_index = (edge_attr[:,1] * self.attr_remapping_matrix[1,1] + self.attr_remapping_matrix[1,-1]+1e-3).long() + + weights = self.lut_weights[dx_index, dy_index] # N x C_out x C_in + x_out = torch.einsum("nio,ni->no", weights, x_j) + + return x_out + + def forward(self, data: Data)->Data: + if self.reproducible: + # first check we already computed the adjacency matrix + if not hasattr(data, "adj_t"): + data.edge_attr = data.edge_attr[:,:self.dim] + data = self.to_sparse_tensor(data) + data.x = self._forward(data.x, + edge_index=data.adj_t) + else: + data.x = self._forward(data.x, + edge_index=data.edge_index, + edge_attr=data.edge_attr[:, :self.dim], + size=(data.x.shape[0], data.x.shape[0])) + return data + + def _forward(self, x, edge_index, edge_attr=None, size=None): + """""" + # propagate_type: (x: OptPairTensor, edge_attr: OptTensor) + if edge_index.numel() > 0: + out = self.propagate(edge_index, x=(x, x), edge_attr=edge_attr, size=size) + else: + out = torch.zeros((x.size(0), self.out_channels), dtype=x.dtype, device=x.device) + + if x is not None and self.root_weight: + out += self.lin(x) + + if self.bias is not None: + out += self.bias + + return out + +def to_dense(self, x, pos, pooling, batch=None, batch_size=None): + if hasattr(self, "batch_size"): + B = self.batch_size + elif batch_size is not None: + self.batch_size = batch_size + B = batch_size + elif batch is None: + batch = torch.zeros(size=(len(x),), dtype=torch.long, device=x.device) + B = 1 + self.batch_size = B + else: + B = batch.max().item() + 1 + self.batch_size = B + + if not hasattr(self, "dense"): + W, H = (1 / pooling[:2] + 1e-3).long() + C = x.shape[-1] + self.dense = torch.zeros(size=(B, C, H, W), dtype=x.dtype, device=x.device) + + est_x, est_y = (pos[:, :2] / pooling[:2]).t().long() + + self.dense = self.dense.detach() + self.dense.zero_() + + dense = self.dense[:B] if B < self.dense.shape[0] else self.dense + dense[batch.long(), :, est_y, est_x] = x + + return dense + + +class SplineConvToDense(MySplineConv): + def forward(self, data: Data, batch_size: int=None)->torch.Tensor: + data = super().forward(data) + if data.batch is None: + data.batch = torch.zeros(len(data.x), dtype=torch.long, device=data.x.device) + return self.to_dense(data.x, data.pos, data.pooling, data.batch, batch_size=batch_size) + + def to_dense(self, x, pos, pooling, batch=None, batch_size=None): + return to_dense(self, x, pos, pooling, batch, batch_size=batch_size) \ No newline at end of file diff --git a/src/dagr/model/networks/dagr.py b/src/dagr/model/networks/dagr.py new file 
mode 100644 index 0000000..bd36057 --- /dev/null +++ b/src/dagr/model/networks/dagr.py @@ -0,0 +1,290 @@ +import torch + +import torch.nn.functional as F + +from torch_geometric.data import Data +from yolox.models import YOLOX, YOLOXHead, IOUloss + +from dagr.model.networks.net import Net +from dagr.model.layers.spline_conv import SplineConvToDense +from dagr.model.layers.conv import ConvBlock +from dagr.model.utils import shallow_copy, init_subnetwork, voxel_size_to_params, postprocess_network_output, convert_to_evaluation_format, init_grid_and_stride + + +class DAGR(YOLOX): + def __init__(self, args, height, width): + self.conf_threshold = 0.001 + self.nms_threshold = 0.65 + + self.height = height + self.width = width + + backbone = Net(args, height=height, width=width) + head = GNNHead(num_classes=backbone.num_classes, + in_channels=backbone.out_channels, + in_channels_cnn=backbone.out_channels_cnn, + strides=backbone.strides, + args=args) + + super().__init__(backbone=backbone, head=head) + + if "img_net_checkpoint" in args: + state_dict = torch.load(args.img_net_checkpoint) + init_subnetwork(self, state_dict['ema'], "backbone.net.", freeze=True) + init_subnetwork(self, state_dict['ema'], "head.cnn_head.") + + def cache_luts(self, width, height, radius): + M = 2 * float(int(radius * width + 2) / width) + r = int(radius * width+1) + self.backbone.conv_block1.conv_block1.conv.init_lut(height=height, width=width, Mx=M, rx=r) + self.backbone.conv_block1.conv_block2.conv.init_lut(height=height, width=width, Mx=M, rx=r) + + rx, ry, M = voxel_size_to_params(self.backbone.pool1, height, width) + self.backbone.layer2.conv_block1.conv.init_lut(height=height, width=width, Mx=M, rx=rx, ry=ry) + self.backbone.layer2.conv_block2.conv.init_lut(height=height, width=width, Mx=M, rx=rx, ry=ry) + + rx, ry, M = voxel_size_to_params(self.backbone.pool2, height, width) + self.backbone.layer3.conv_block1.conv.init_lut(height=height, width=width, Mx=M, rx=rx, ry=ry) + self.backbone.layer3.conv_block2.conv.init_lut(height=height, width=width, Mx=M, rx=rx, ry=ry) + + rx, ry, M = voxel_size_to_params(self.backbone.pool3, height, width) + self.backbone.layer4.conv_block1.conv.init_lut(height=height, width=width, Mx=M, rx=rx, ry=ry) + self.backbone.layer4.conv_block2.conv.init_lut(height=height, width=width, Mx=M, rx=rx, ry=ry) + + self.head.stem1.conv.init_lut(height=height, width=width, Mx=M, rx=rx, ry=ry) + self.head.cls_conv1.conv.init_lut(height=height, width=width, Mx=M, rx=rx, ry=ry) + self.head.reg_conv1.conv.init_lut(height=height, width=width, Mx=M, rx=rx, ry=ry) + self.head.cls_pred1.init_lut(height=height, width=width, Mx=M, rx=rx, ry=ry) + self.head.reg_pred1.init_lut(height=height, width=width, Mx=M, rx=rx, ry=ry) + self.head.obj_pred1.init_lut(height=height, width=width, Mx=M, rx=rx, ry=ry) + + rx, ry, M = voxel_size_to_params(self.backbone.pool4, height, width) + self.backbone.layer5.conv_block1.conv.init_lut(height=height, width=width, Mx=M, rx=rx, ry=ry) + self.backbone.layer5.conv_block2.conv.init_lut(height=height, width=width, Mx=M, rx=rx, ry=ry) + + if self.head.num_scales > 1: + self.head.stem2.conv.init_lut(height=height, width=width, Mx=M, rx=rx, ry=ry) + self.head.cls_conv2.conv.init_lut(height=height, width=width, Mx=M, rx=rx, ry=ry) + self.head.reg_conv2.conv.init_lut(height=height, width=width, Mx=M, rx=rx, ry=ry) + self.head.cls_pred2.init_lut(height=height, width=width, Mx=M, rx=rx, ry=ry) + self.head.reg_pred2.init_lut(height=height, width=width, Mx=M, rx=rx, ry=ry) + 
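+ # note: each init_lut call precomputes the spline-convolution weights for every integer (dx, dy) offset at the layer's resolution, + # so that message passing in MySplineConv reduces to a table lookup (message_lut) at inference time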
self.head.obj_pred2.init_lut(height=height, width=width, Mx=M, rx=rx, ry=ry) + + def forward(self, x: Data, reset=True, return_targets=True, filtering=True): + if not hasattr(self.head, "output_sizes"): + self.head.output_sizes = self.backbone.get_output_sizes() + + if self.training: + targets = self.convert_to_training_format(x.bbox, x.bbox_batch, x.num_graphs) + + if self.backbone.use_image: + targets0 = self.convert_to_training_format(x.bbox0, x.bbox0_batch, x.num_graphs) + targets = (targets, targets0) + + # gt_target inputs need to be [l cx cy w h] in pixels + outputs = YOLOX.forward(self, x, targets) + + return outputs + + x.reset = reset + + outputs = YOLOX.forward(self, x) + + detections = postprocess_network_output(outputs, self.backbone.num_classes, self.conf_threshold, self.nms_threshold, filtering=filtering, + height=self.height, width=self.width) + + ret = [detections] + + if return_targets and hasattr(x, 'bbox'): + targets = convert_to_evaluation_format(x) + ret.append(targets) + + return ret + + +class CNNHead(YOLOXHead): + def forward(self, xin): + outputs = dict(cls_output=[], reg_output=[], obj_output=[]) + + for k, (cls_conv, reg_conv, x) in enumerate(zip(self.cls_convs, self.reg_convs, xin)): + x = self.stems[k](x) + cls_x = x + reg_x = x + + cls_feat = cls_conv(cls_x) + reg_feat = reg_conv(reg_x) + + outputs["cls_output"].append(self.cls_preds[k](cls_feat)) + outputs["reg_output"].append(self.reg_preds[k](reg_feat)) + outputs["obj_output"].append(self.obj_preds[k](reg_feat)) + + return outputs + + +class GNNHead(YOLOXHead): + def __init__( + self, + num_classes, + strides=[8, 16, 32], + in_channels=[256, 512, 1024], + in_channels_cnn=[256, 512, 1024], + act="silu", + depthwise=False, + args=None + ): + YOLOXHead.__init__(self, num_classes, args.yolo_stem_width, strides, in_channels, act, depthwise) + + self.num_scales = args.num_scales + self.use_image = args.use_image + self.batch_size = args.batch_size + self.no_events = args.no_events + + self.in_channels = in_channels + self.n_anchors = 1 + self.num_classes = num_classes + + n_reg = max(in_channels) + self.stem1 = ConvBlock(in_channels=in_channels[0], out_channels=n_reg, args=args) + self.cls_conv1 = ConvBlock(in_channels=n_reg, out_channels=n_reg, args=args) + self.cls_pred1 = SplineConvToDense(in_channels=n_reg, out_channels=self.n_anchors * self.num_classes, bias=True, args=args) + self.reg_conv1 = ConvBlock(in_channels=n_reg, out_channels=n_reg, args=args) + self.reg_pred1 = SplineConvToDense(in_channels=n_reg, out_channels=4, bias=True, args=args) + self.obj_pred1 = SplineConvToDense(in_channels=n_reg, out_channels=self.n_anchors, bias=True, args=args) + + if self.num_scales > 1: + self.stem2 = ConvBlock(in_channels=in_channels[1], out_channels=n_reg, args=args) + self.cls_conv2 = ConvBlock(in_channels=n_reg, out_channels=n_reg, args=args) + self.cls_pred2 = SplineConvToDense(in_channels=n_reg, out_channels=self.n_anchors * self.num_classes, bias=True, args=args) + self.reg_conv2 = ConvBlock(in_channels=n_reg, out_channels=n_reg, args=args) + self.reg_pred2 = SplineConvToDense(in_channels=n_reg, out_channels=4, bias=True, args=args) + self.obj_pred2 = SplineConvToDense(in_channels=n_reg, out_channels=self.n_anchors, bias=True, args=args) + + if self.use_image: + self.cnn_head = CNNHead(num_classes=num_classes, strides=strides, in_channels=in_channels_cnn) + + self.use_l1 = False + self.l1_loss = torch.nn.L1Loss(reduction="none") + self.bcewithlog_loss = torch.nn.BCEWithLogitsLoss(reduction="none") + 
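+ # loss terms follow the standard YOLOX head: BCE-with-logits for objectness and class scores, IoU loss for the boxes, and an optional L1 term when use_l1 is enabled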
self.iou_loss = IOUloss(reduction="none") + self.strides = strides + self.grids = [torch.zeros(1)] * len(in_channels) + + self.grid_cache = None + self.stride_cache = None + self.cache = [] + + def process_feature(self, x, stem, cls_conv, reg_conv, cls_pred, reg_pred, obj_pred, batch_size, cache): + x = stem(x) + + cls_feat = cls_conv(shallow_copy(x)) + reg_feat = reg_conv(x) + + # we need to provide the batch size, since sometimes it cannot be found from the data, especially when nodes=0 + cls_output = cls_pred(cls_feat, batch_size=batch_size) + reg_output = reg_pred(shallow_copy(reg_feat), batch_size=batch_size) + obj_output = obj_pred(reg_feat, batch_size=batch_size) + + return cls_output, reg_output, obj_output + + def forward(self, xin: Data, labels=None, imgs=None): + # for events + image outputs + hybrid_out = dict(outputs=[], origin_preds=[], x_shifts=[], y_shifts=[], expanded_strides=[]) + image_out = dict(outputs=[], origin_preds=[], x_shifts=[], y_shifts=[], expanded_strides=[]) + + if self.use_image: + xin, image_feat = xin + + if labels is not None: + if self.use_image: + labels, image_labels = labels + + # resize image, and process with CNN + image_feat = [torch.nn.functional.interpolate(f, o) for f, o in zip(image_feat, self.output_sizes)] + out_cnn = self.cnn_head(image_feat) + + # collect outputs from image alone, so the image network also learns to detect on its own. + for k in [0, 1]: + self.collect_outputs(out_cnn["cls_output"][k], + out_cnn["reg_output"][k], + out_cnn["obj_output"][k], + k, self.strides[k], ret=image_out) + + batch_size = len(out_cnn["cls_output"][0]) if self.use_image else self.batch_size + cls_output, reg_output, obj_output = self.process_feature(xin[0], self.stem1, self.cls_conv1, self.reg_conv1, + self.cls_pred1, self.reg_pred1, self.obj_pred1, batch_size=batch_size, cache=self.cache) + + if self.use_image: + cls_output[:batch_size] += out_cnn["cls_output"][0].detach() + reg_output[:batch_size] += out_cnn["reg_output"][0].detach() + obj_output[:batch_size] += out_cnn["obj_output"][0].detach() + + self.collect_outputs(cls_output, reg_output, obj_output, 0, self.strides[0], ret=hybrid_out) + + if self.num_scales > 1: + cls_output, reg_output, obj_output = self.process_feature(xin[1], self.stem2, self.cls_conv2, + self.reg_conv2, self.cls_pred2, self.reg_pred2, + self.obj_pred2, batch_size=batch_size, cache=self.cache) + if self.use_image: + batch_size = out_cnn["cls_output"][0].shape[0] + cls_output[:batch_size] += out_cnn["cls_output"][1].detach() + reg_output[:batch_size] += out_cnn["reg_output"][1].detach() + obj_output[:batch_size] += out_cnn["obj_output"][1].detach() + + self.collect_outputs(cls_output, reg_output, obj_output, 1, self.strides[1], ret=hybrid_out) + + if self.training: + # if we are only training the image detectors (pretraining), + # we only need to minimize the loss at detections from the image branch.
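+ # when use_image is set, only the image-branch outputs (image_out) are supervised here; otherwise the event/hybrid outputs collected in hybrid_out are used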
+ if self.use_image: + return self.get_losses( + imgs, + image_out['x_shifts'], + image_out['y_shifts'], + image_out['expanded_strides'], + image_labels, + torch.cat(image_out['outputs'], 1), + image_out['origin_preds'], + dtype=image_out['x_shifts'][0].dtype, + ) + else: + return self.get_losses( + imgs, + hybrid_out['x_shifts'], + hybrid_out['y_shifts'], + hybrid_out['expanded_strides'], + labels, + torch.cat(hybrid_out['outputs'], 1), + hybrid_out['origin_preds'], + dtype=xin[0].x.dtype, + ) + else: + out = image_out['outputs'] if self.no_events else hybrid_out['outputs'] + + self.hw = [x.shape[-2:] for x in out] + # [batch, n_anchors_all, 85] + outputs = torch.cat([x.flatten(start_dim=2) for x in out], dim=2).permute(0, 2, 1) + + return self.decode_outputs(outputs, dtype=out[0].type()) + + def collect_outputs(self, cls_output, reg_output, obj_output, k, stride_this_level, ret=None): + if self.training: + output = torch.cat([reg_output, obj_output, cls_output], 1) + output, grid = self.get_output_and_grid(output, k, stride_this_level, output.type()) + ret['x_shifts'].append(grid[:, :, 0]) + ret['y_shifts'].append(grid[:, :, 1]) + ret['expanded_strides'].append(torch.zeros(1, grid.shape[1]).fill_(stride_this_level).type_as(output)) + else: + output = torch.cat( + [reg_output, obj_output.sigmoid(), cls_output.sigmoid()], 1 + ) + + ret['outputs'].append(output) + + def decode_outputs(self, outputs, dtype): + if self.grid_cache is None: + self.grid_cache, self.stride_cache = init_grid_and_stride(self.hw, self.strides, dtype) + + outputs[..., :2] = (outputs[..., :2] + self.grid_cache) * self.stride_cache + outputs[..., 2:4] = torch.exp(outputs[..., 2:4]) * self.stride_cache + return outputs + diff --git a/src/dagr/model/networks/ema.py b/src/dagr/model/networks/ema.py new file mode 100644 index 0000000..4e498db --- /dev/null +++ b/src/dagr/model/networks/ema.py @@ -0,0 +1,51 @@ +import torch +import math +from copy import deepcopy + + +class ModelEMA: + """ + Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models + Keep a moving average of everything in the model state_dict (parameters and buffers). + This is intended to allow functionality like + https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage + A smoothed version of the weights is necessary for some training schemes to perform well. + This class is sensitive where it is initialized in the sequence of model init, + GPU assignment and distributed training wrappers. + """ + + def __init__(self, model, decay=0.9999, updates=0): + """ + Args: + model (nn.Module): model to apply EMA. + decay (float): ema decay reate. + updates (int): counter of EMA updates. 
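+ The effective decay is ramped up as decay * (1 - exp(-updates / 2000)) (see below), so early updates track the raw weights more closely. + Typical usage (sketch): build ema = ModelEMA(model) once, call ema.update(model) after every optimizer step, and evaluate with ema.ema.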
+ """ + # Create EMA(FP32) + self.ema = deepcopy(model).eval() + + try: + # if we do not do this, all the hooks will be activated for the other model, which will create + # a lot of memory usage + self.ema.backbone.net.remove_hooks() + self.ema.backbone.net.register_hooks() + except: + pass + + self.updates = updates + # decay exponential ramp (to help early epochs) + self.decay = lambda x: decay * (1 - math.exp(-x / 2000)) + for p in self.ema.parameters(): + p.requires_grad_(False) + + def update(self, model): + # Update EMA parameters + with torch.no_grad(): + self.updates += 1 + d = self.decay(self.updates) + + msd = model.state_dict() + for k, v in self.ema.state_dict().items(): + if v.dtype.is_floating_point: + v *= d + v += (1.0 - d) * msd[k].detach() diff --git a/src/dagr/model/networks/net.py b/src/dagr/model/networks/net.py new file mode 100644 index 0000000..4d0b3d1 --- /dev/null +++ b/src/dagr/model/networks/net.py @@ -0,0 +1,223 @@ +import torch + +import torch_geometric.transforms as T + +from torch_geometric.data import Data +from dagr.model.layers.ev_tgn import EV_TGN +from dagr.model.layers.pooling import Pooling +from dagr.model.layers.conv import Layer +from dagr.model.layers.components import Cartesian +from dagr.model.networks.net_img import HookModule +from dagr.model.utils import shallow_copy +from torchvision.models import resnet18, resnet34, resnet50 + + +def sampling_skip(data, image_feat): + image_feat_at_nodes = sample_features(data, image_feat) + return torch.cat((data.x, image_feat_at_nodes), dim=1) + +def compute_pooling_at_each_layer(pooling_dim_at_output, num_layers): + py, px = map(int, pooling_dim_at_output.split("x")) + pooling_base = torch.tensor([1.0 / px, 1.0 / py, 1.0 / 1]) + poolings = [] + for i in range(num_layers): + pooling = pooling_base / 2 ** (3 - i) + pooling[-1] = 1 + poolings.append(pooling) + poolings = torch.stack(poolings) + return poolings + +class Net(torch.nn.Module): + def __init__(self, args, height, width): + super().__init__() + + channels = [1, int(args.base_width*32), int(args.after_pool_width*64), + int(args.net_stem_width*128), + int(args.net_stem_width*128), + int(args.net_stem_width*128)] + + self.out_channels_cnn = [] + if args.use_image: + img_net = eval(args.img_net) + self.out_channels_cnn = [256, 256] + self.net = HookModule(img_net(pretrained=True), + input_channels=3, + height=height, width=width, + feature_layers=["conv1", "layer1", "layer2", "layer3", "layer4"], + output_layers=["layer3", "layer4"], + feature_channels=channels[1:], + output_channels=self.out_channels_cnn) + + self.use_image = args.use_image + self.num_scales = args.num_scales + + self.num_classes = dict(dsec=2).get(args.dataset, 2) + + self.events_to_graph = EV_TGN(args) + + output_channels = channels[1:] + self.out_channels = output_channels[-2:] + + input_channels = channels[:-1] + if self.use_image: + input_channels = [input_channels[i] + self.net.feature_channels[i] for i in range(len(input_channels))] + + # parse x and y pooling dimensions at output + poolings = compute_pooling_at_each_layer(args.pooling_dim_at_output, num_layers=4) + max_vals_for_cartesian = 2*poolings[:,:2].max(-1).values + self.strides = torch.ceil(poolings[-2:,1] * height).numpy().astype("int32").tolist() + + effective_radius = 2*float(int(args.radius * width + 2) / width) + self.edge_attrs = Cartesian(norm=True, cat=False, max_value=effective_radius) + + self.conv_block1 = Layer(2+input_channels[0], output_channels[0], args=args) + + cart1 = T.Cartesian(norm=True, 
cat=False, max_value=2*effective_radius) + self.pool1 = Pooling(poolings[0], width=width, height=height, batch_size=args.batch_size, + transform=cart1, aggr=args.pooling_aggr, keep_temporal_ordering=args.keep_temporal_ordering) + + self.layer2 = Layer(input_channels[1]+2, output_channels[1], args=args) + + cart2 = T.Cartesian(norm=True, cat=False, max_value=max_vals_for_cartesian[1]) + self.pool2 = Pooling(poolings[1], width=width, height=height, batch_size=args.batch_size, + transform=cart2, aggr=args.pooling_aggr, keep_temporal_ordering=args.keep_temporal_ordering) + + self.layer3 = Layer(input_channels[2]+2, output_channels[2], args=args) + + cart3 = T.Cartesian(norm=True, cat=False, max_value=max_vals_for_cartesian[2]) + self.pool3 = Pooling(poolings[2], width=width, height=height, batch_size=args.batch_size, + transform=cart3, aggr=args.pooling_aggr, keep_temporal_ordering=args.keep_temporal_ordering) + + self.layer4 = Layer(input_channels[3]+2, output_channels[3], args=args) + + cart4 = T.Cartesian(norm=True, cat=False, max_value=max_vals_for_cartesian[3]) + self.pool4 = Pooling(poolings[3], width=width, height=height, batch_size=args.batch_size, + transform=cart4, aggr='mean', keep_temporal_ordering=args.keep_temporal_ordering) + + self.layer5 = Layer(input_channels[4]+2, output_channels[4], args=args) + + self.cache = [] + + def get_output_sizes(self): + poolings = [self.pool3.voxel_size[:2], self.pool4.voxel_size[:2]] + output_sizes = [(1 / p + 1e-3).cpu().int().numpy().tolist()[::-1] for p in poolings] + return output_sizes + + def forward(self, data: Data, reset=True): + if self.use_image: + image_feat, image_outputs = self.net(data.image) + + if hasattr(data, 'reset'): + reset = data.reset + + data = self.events_to_graph(data, reset=reset) + + if self.use_image: + data.x = sampling_skip(data, image_feat[0].detach()) + data.skipped = True + data.num_image_channels = image_feat[0].shape[1] + + data = self.edge_attrs(data) + data.edge_attr = torch.clamp(data.edge_attr, min=0, max=1) + rel_delta = data.pos[:, :2] + data.x = torch.cat((data.x, rel_delta), dim=1) + data = self.conv_block1(data) + + if self.use_image: + data.x = sampling_skip(data, image_feat[1].detach()) + + data = self.pool1(data) + + if self.use_image: + data.skipped = True + data.num_image_channels = image_feat[1].shape[1] + + rel_delta = data.pos[:,:2] + data.x = torch.cat((data.x, rel_delta), dim=1) + data = self.layer2(data) + + if self.use_image: + data.x = sampling_skip(data, image_feat[2].detach()) + + data = self.pool2(data) + + if self.use_image: + data.skipped = True + data.num_image_channels = image_feat[2].shape[1] + + rel_delta = data.pos[:,:2] + data.x = torch.cat((data.x, rel_delta), dim=1) + data = self.layer3(data) + + if self.use_image: + data.x = sampling_skip(data, image_feat[3].detach()) + + data = self.pool3(data) + + if self.use_image: + data.skipped = True + data.num_image_channels = image_feat[3].shape[1] + + rel_delta = data.pos[:,:2] + data.x = torch.cat((data.x, rel_delta), dim=1) + data = self.layer4(data) + + out3 = shallow_copy(data) + out3.pooling = self.pool3.voxel_size[:3] + + if self.use_image: + data.x = sampling_skip(data, image_feat[4].detach()) + + data = self.pool4(data) + + if self.use_image: + data.skipped = True + data.num_image_channels = image_feat[4].shape[1] + + rel_delta = data.pos[:,:2] + data.x = torch.cat((data.x, rel_delta), dim=1) + data = self.layer5(data) + + out4 = data + out4.pooling = self.pool4.voxel_size[:3] + + output = [out3, out4] + + if self.use_image: 
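+ # the hooked CNN outputs are returned alongside the event features so the detection head can fuse both branches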
+ return output[-self.num_scales:], image_outputs[-self.num_scales:] + return output[-self.num_scales:] + + +def sample_features(data, image_feat, image_sample_mode="bilinear"): + if data.batch is None or len(data.batch) != len(data.pos): + data.batch = torch.zeros(len(data.pos), dtype=torch.long, device=data.x.device) + return _sample_features(data.pos[:,0] * data.width[0], + data.pos[:,1] * data.height[0], + data.batch.float(), image_feat, + data.width[0], + data.height[0], + image_feat.shape[0], + image_sample_mode) + +def _sample_features(x, y, b, image_feat, width, height, batch_size, image_sample_mode): + x = 2 * x / (width - 1) - 1 + y = 2 * y / (height - 1) - 1 + + batch_size = batch_size if batch_size > 1 else 2 + b = 2 * b / (batch_size - 1) - 1 + + grid = torch.stack((x, y, b), dim=-1).view(1, 1, 1,-1, 3) # N x D_out x H_out x W_out x 3 (N=1, D_out=1, H_out=1) + image_feat = image_feat.permute(1,0,2,3).unsqueeze(0) # N x C x D x H x W (N=1) + + image_feat_sampled = torch.nn.functional.grid_sample(image_feat, + grid=grid, + mode=image_sample_mode, + align_corners=True) # N x C x H_out x W_out (H_out=1, N=1) + + image_feat_sampled = image_feat_sampled.view(image_feat.shape[1], -1).t() + + return image_feat_sampled + + + + diff --git a/src/dagr/model/networks/net_img.py b/src/dagr/model/networks/net_img.py new file mode 100644 index 0000000..6a38c4c --- /dev/null +++ b/src/dagr/model/networks/net_img.py @@ -0,0 +1,135 @@ +import torch + + +class Layer(torch.nn.Module): + def __init__(self, input_channels, output_channels): + super(Layer, self).__init__() + self.conv1 = torch.nn.Conv2d(input_channels, output_channels, kernel_size=3, stride=1, padding=1) + self.bn1 = torch.nn.BatchNorm2d(output_channels) + + self.conv2 = torch.nn.Conv2d(output_channels, output_channels, kernel_size=3, stride=1, padding=1) + self.bn2 = torch.nn.BatchNorm2d(output_channels) + + self.dwc = torch.nn.Conv2d(input_channels, output_channels, kernel_size=1, stride=1, padding=0) + self.bn_skip = torch.nn.BatchNorm2d(output_channels) + self.act = torch.nn.ReLU() + + def forward(self, x): + x_skip = x.clone() + x = self.act(self.bn1(self.conv1(x))) + x = self.bn2(self.conv2(x)) + x = x + self.bn_skip(self.dwc(x_skip)) + return self.act(x) + + +class ConvBlockDense(torch.nn.Module): + def __init__(self, in_channels, out_channels, bias=False, act=torch.nn.ReLU(), bn=True): + super(ConvBlockDense, self).__init__() + self.conv = torch.nn.Conv2d(in_channels, out_channels, bias=bias, kernel_size=3, stride=1, padding=1) + self.bn = torch.nn.BatchNorm2d(out_channels) + self.act = act + self.use_bn = bn + + def forward(self, x): + x = self.conv(x) + if self.use_bn: + x = self.bn(x) + if self.act is not None: + x = self.act(x) + return x + + +class HookModule(torch.nn.Module): + """ + Define the module, then you can determine which features are extracted, and which outputs are extracted. + For each you can decide if they are mapped to a lower dimension or not. 
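+ Features and outputs are captured with forward hooks on the named layers; when feature_channels / output_channels are given, 1x1 convolutions remap the hooked tensors to those widths. + Illustrative use (layer names are just examples): HookModule(resnet18(pretrained=True), height=480, width=640, feature_layers=["layer1"], output_layers=["layer4"]), after which forward(x) returns (features, outputs).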
+ + """ + def __init__(self, module, height, width, input_channels=3, feature_layers=(), output_layers=(), feature_channels=None, output_channels=None): + torch.nn.Module.__init__(self) + self.module = module.cpu() + + if input_channels != 3: + self.module.conv1 = torch.nn.Conv2d(in_channels=input_channels, out_channels=self.module.conv1.out_channels, + kernel_size=self.module.conv1.kernel_size, + padding=self.module.conv1.padding, + bias=False) + + self.feature_layers = feature_layers + self.output_layers = output_layers + + self.hooks = [] + self.features = [] + self.outputs = [] + self.register_hooks() + + self.feature_channels = [] + self.output_channels = [] + self.compute_channels_with_dummy(shape=(1, input_channels, height, width)) + + self.feature_dconv = torch.nn.ModuleList() + if feature_channels is not None: + assert len(feature_channels) == len(self.feature_channels) + self.feature_dconv = torch.nn.ModuleList( + [ + torch.nn.Conv2d(in_channels=cin, out_channels=cout, kernel_size=1, stride=1, padding=0) + for cin, cout in zip(self.feature_channels, feature_channels) + ] + ) + self.feature_channels = feature_channels + + self.output_dconv = torch.nn.ModuleList() + if output_channels is not None: + assert len(output_channels) == len(self.output_channels) + self.output_dconv = torch.nn.ModuleList( + [ + torch.nn.Conv2d(in_channels=cin, out_channels=cout, kernel_size=1, stride=1, padding=0) + for cin, cout in zip(self.output_channels, output_channels) + ] + ) + self.output_channels = output_channels + + def extract_layer(self, module, layer): + if len(layer) == 0: + return module + else: + return self.extract_layer(module._modules[layer[0]], layer[1:]) + + def compute_channels_with_dummy(self, shape): + dummy_input = torch.zeros(shape) + self.module.forward(dummy_input) + self.feature_channels = [f.shape[1] for f in self.features] + self.output_channels = [o.shape[1] for o in self.outputs] + self.features = [] + self.outputs = [] + + def remove_hooks(self): + for h in self.hooks: + h.remove() + + def register_hooks(self): + self.features = [] + self.outputs = [] + features_hook = lambda m, i, o: self.features.append(o) + outputs_hook = lambda m, i, o: self.outputs.append(o) + for l in self.feature_layers: + hook_id = self.extract_layer(self.module, l.split(".")).register_forward_hook(features_hook) + self.hooks.append(hook_id) + for l in self.output_layers: + hook_id = self.extract_layer(self.module, l.split(".")).register_forward_hook(outputs_hook) + self.hooks.append(hook_id) + + def forward(self, x): + self.features = [] + self.outputs = [] + self.module(x) + + features = self.features + if len(self.feature_dconv) > 0: + features = [dconv(f) for f, dconv in zip(self.features, self.feature_dconv)] + + outputs = self.outputs + if len(self.output_dconv) > 0: + outputs = [dconv(o) for o, dconv in zip(self.outputs, self.output_dconv)] + + return features, outputs \ No newline at end of file diff --git a/src/dagr/model/utils.py b/src/dagr/model/utils.py new file mode 100644 index 0000000..6320d3c --- /dev/null +++ b/src/dagr/model/utils.py @@ -0,0 +1,165 @@ +import torchvision +import torch + +import numpy as np + +from torch_geometric.data import Data + + +def init_subnetwork(net, state_dict, name="backbone.net.", freeze=False): + assert name.endswith(".") + + # get submodule + attrs = name.split(".")[:-1] + for attr in attrs: + net = getattr(net, attr) + + # load weights and freeze + sub_state_dict = {k.replace(name, ""): v for k, v in state_dict.items() if name in k} + 
net.load_state_dict(sub_state_dict) + + if freeze: + for param in net.parameters(): + param.requires_grad = False + +def batched_nms_coordinate_trick(boxes, scores, idxs, iou_threshold, width, height): + # adopted from torchvision nms, but faster + if boxes.numel() == 0: + return torch.empty((0,), dtype=torch.int64, device=boxes.device) + max_dim = max([width, height]) + offsets = idxs * float(max_dim + 1) + boxes_for_nms = boxes + offsets[:, None] + keep = torchvision.ops.nms(boxes_for_nms, scores, iou_threshold) + return keep + +def convert_to_evaluation_format(data): + targets = [] + for d in data.to_data_list(): + bbox = d.bbox.clone() + bbox[:,2:4] += bbox[:,:2] + targets.append({ + "boxes": bbox[:,:4], + "labels": bbox[:, 4].long() # class 0 is background class + }) + return targets + +def convert_to_training_format(bbox, batch, batch_size): + max_detections = 100 + targets = torch.zeros(size=(batch_size, max_detections, 5), dtype=torch.float32, device=bbox.device) + unique, counts = torch.unique(batch, return_counts=True) + counter = _sequential_counter(counts) + + bbox = bbox.clone() + # xywhlc pix -> lcxcywh pix + bbox[:, :2] += bbox[:, 2:4] * .5 + bbox = torch.roll(bbox[:, :5], dims=1, shifts=1) + + targets[batch, counter] = bbox + + return targets + +def postprocess_network_output(prediction, num_classes, conf_thre=0.01, nms_thre=0.65, height=640, width=640, filtering=True): + prediction[..., :2] -= prediction[...,2:4] / 2 # cxcywh->xywh + prediction[..., 2:4] += prediction[...,:2] + + output = [] + for i, image_pred in enumerate(prediction): + + # If none are remaining => process next image + if len(image_pred) == 0: + output.append({ + "boxes": torch.zeros(0, 4, dtype=torch.float32), + "scores": torch.zeros(0, dtype=torch.float), + "labels": torch.zeros(0, dtype=torch.long) + }) + continue + + # Get score and class with highest confidence + class_conf, class_pred = torch.max(image_pred[:, 5: 5 + num_classes], 1, keepdim=True) + image_pred[:, 4:5] *= class_conf + + conf_mask = (image_pred[:, 4] * class_conf.squeeze() >= conf_thre).squeeze() + # Detections ordered as (x1, y1, x2, y2, obj_conf, class_conf, class_pred) + detections = torch.cat((image_pred[:, :5], class_pred), 1) + + if filtering: + detections = detections[conf_mask] + + if len(detections) == 0: + output.append({ + "boxes": torch.zeros(0, 4, dtype=torch.float32), + "scores": torch.zeros(0, dtype=torch.float), + "labels": torch.zeros(0, dtype=torch.long) + }) + continue + + nms_out_index = batched_nms_coordinate_trick(detections[:, :4], detections[:, 4], detections[:, 5], + nms_thre, width=width, height=height) + + if filtering: + detections = detections[nms_out_index] + + output.append({ + "boxes": detections[:, :4], + "scores": detections[:, 4], + "labels": detections[:, -1].long() + }) + + return output + +def voxel_size_to_params(pooling_layer, height, width): + rx = int(np.ceil(2*pooling_layer.voxel_size[0].cpu().numpy() * width)) + ry = int(np.ceil(2*pooling_layer.voxel_size[1].cpu().numpy() * height)) + M = pooling_layer.transform.max + return rx, ry, M + + +def init_grid_and_stride(hw, strides, dtype): + grids = [] + all_strides = [] + for (hsize, wsize), stride in zip(hw, strides): + yv, xv = torch.meshgrid(torch.arange(hsize), torch.arange(wsize), indexing="ij") + grid = torch.stack((xv, yv), 2).view(1, -1, 2) + grids.append(grid) + shape = grid.shape[:2] + all_strides.append(torch.full((*shape, 1), stride)) + + grid_cache = torch.cat(grids, dim=1).type(dtype) + stride_cache = torch.cat(all_strides, 
dim=1).type(dtype) + + return grid_cache, stride_cache + +def _sequential_counter(counts: torch.LongTensor): + """ + Returns a torch tensor which counts up for each count + Example: counts = [2,4,6,2,4] then the output will be + output = [0,1,0,1,2,3,0,1,2,3,4,5,0,1,0,1,2,3] + """ + assert counts.dtype == torch.long + assert len(counts.shape) > 0 + assert (counts >= 0).all() + + len_counter = counts.sum() + tensors_kwargs = dict(device=counts.device, dtype=torch.long) + + # first construct delta function, which has value c_N at position sum_k=0^N c_k + delta = torch.zeros(size=(len_counter,), **tensors_kwargs) + x_coord = counts.cumsum(dim=0) + delta[x_coord[:-1]] = counts[:-1] + + # next construct step function, and the result it a linear function minus this step function + step = delta.cumsum(dim=0) + counter = torch.arange(len_counter, **tensors_kwargs) - step + + return counter + +def shallow_copy(data): + out = Data(x=data.x.clone(), edge_index=data.edge_index, edge_attr=data.edge_attr, pos=data.pos, batch=data.batch) + for key in ["active_clusters", "_changed_attr", "_changed_attr_indices","diff_idx", "diff_pos_idx", "pooling", "num_image_channels", "skipped", "pooled"]: + if hasattr(data, key): + setattr(out, key, getattr(data, key)) + for key in ["diff_idx", "diff_pos_idx"]: + if hasattr(data, key): + setattr(out, key, getattr(data, key).clone()) + return out + diff --git a/src/dagr/utils/args.py b/src/dagr/utils/args.py new file mode 100644 index 0000000..3f3be32 --- /dev/null +++ b/src/dagr/utils/args.py @@ -0,0 +1,107 @@ +import argparse +import yaml + +from pathlib import Path + + +def BASE_FLAGS(): + parser = argparse.ArgumentParser("") + parser.add_argument('--dataset_directory', type=Path, default=argparse.SUPPRESS, help="Path to the directory containing the dataset.") + parser.add_argument('--output_directory', type=Path, default=argparse.SUPPRESS, help="Path to the logging directory.") + parser.add_argument("--checkpoint", type=Path, default=argparse.SUPPRESS, help="Path to the directory containing the checkpoint.") + parser.add_argument("--img_net", default=argparse.SUPPRESS, type=str) + parser.add_argument("--img_net_checkpoint", type=Path, default=argparse.SUPPRESS) + + parser.add_argument("--config", type=Path, default="../config/detection.yaml") + parser.add_argument("--use_image", action="store_true") + parser.add_argument("--no_events", action="store_true") + parser.add_argument("--keep_temporal_ordering", action="store_true") + + # task params + parser.add_argument("--task", default=argparse.SUPPRESS, type=str) + parser.add_argument("--dataset", default=argparse.SUPPRESS, type=str) + + # graph params + parser.add_argument('--radius', default=argparse.SUPPRESS, type=float) + parser.add_argument('--time_window_us', default=argparse.SUPPRESS, type=int) + parser.add_argument('--max_neighbors', default=argparse.SUPPRESS, type=int) + parser.add_argument('--n_nodes', default=argparse.SUPPRESS, type=int) + + # learning params + parser.add_argument('--batch_size', default=argparse.SUPPRESS, type=int) + + # network params + parser.add_argument("--activation", default=argparse.SUPPRESS, type=str, help="Can be one of ['Hardshrink', 'Hardsigmoid', 'Hardswish', 'ReLU', 'ReLU6', 'SoftShrink', 'HardTanh']") + parser.add_argument("--edge_attr_dim", default=argparse.SUPPRESS, type=int) + parser.add_argument("--aggr", default=argparse.SUPPRESS, type=str) + parser.add_argument("--kernel_size", default=argparse.SUPPRESS, type=int) + parser.add_argument("--pooling_aggr", 
default=argparse.SUPPRESS, type=str) + + parser.add_argument("--base_width", default=argparse.SUPPRESS, type=float) + parser.add_argument("--after_pool_width", default=argparse.SUPPRESS, type=float) + parser.add_argument('--net_stem_width', default=argparse.SUPPRESS, type=float) + parser.add_argument("--yolo_stem_width", default=argparse.SUPPRESS, type=float) + parser.add_argument("--num_scales", default=argparse.SUPPRESS, type=int) + parser.add_argument('--pooling_dim_at_output', default=argparse.SUPPRESS) + parser.add_argument('--weight_decay', default=argparse.SUPPRESS, type=float) + parser.add_argument('--clip', default=argparse.SUPPRESS, type=float) + + parser.add_argument('--aug_p_flip', default=argparse.SUPPRESS, type=float) + + return parser + +def FLAGS(): + parser = BASE_FLAGS() + + # learning params + parser.add_argument('--aug_trans', default=argparse.SUPPRESS, type=float) + parser.add_argument('--aug_zoom', default=argparse.SUPPRESS, type=float) + parser.add_argument('--l_r', default=argparse.SUPPRESS, type=float) + parser.add_argument('--tot_num_epochs', default=argparse.SUPPRESS, type=int) + + parser.add_argument('--run_test', action="store_true") + + parser.add_argument('--num_interframe_steps', type=int, default=10) + + args = parser.parse_args() + + if args.config != "": + args = parse_config(args, args.config) + + args.dataset_directory = Path(args.dataset_directory) + args.output_directory = Path(args.output_directory) + + if "checkpoint" in args: + args.checkpoint = Path(args.checkpoint) + + return args + +def FLOPS_FLAGS(): + parser = BASE_FLAGS() + + # for flop eval + parser.add_argument("--check_consistency", action="store_true") + parser.add_argument("--dense", action="store_true") + + # for runtime eval + args = parser.parse_args() + + if args.config != "": + args = parse_config(args, args.config) + + args.dataset_directory = Path(args.dataset_directory) + args.output_directory = Path(args.output_directory) + + if "checkpoint" in args: + args.checkpoint = Path(args.checkpoint) + + return args + + +def parse_config(args: argparse.ArgumentParser, config: Path): + with config.open() as f: + config = yaml.load(f, Loader=yaml.SafeLoader) + for k, v in config.items(): + if k not in args: + setattr(args, k, v) + return args diff --git a/src/dagr/utils/buffers.py b/src/dagr/utils/buffers.py new file mode 100644 index 0000000..1bed03d --- /dev/null +++ b/src/dagr/utils/buffers.py @@ -0,0 +1,146 @@ +import numpy as np +import torch + +from typing import List, Dict +from pathlib import Path + +from .coco_eval import evaluate_detection + + +def diag_filter(bbox, height: int, width: int, min_box_diagonal: int = 30, min_box_side: int = 20): + bbox[..., 0::2] = torch.clamp(bbox[..., 0::2], 0, width - 1) + bbox[..., 1::2] = torch.clamp(bbox[..., 1::2], 0, height - 1) + w, h = (bbox[..., 2:] - bbox[..., :2]).t() + diag = torch.sqrt(w ** 2 + h ** 2) + mask = (diag > min_box_diagonal) & (w > min_box_side) & (h > min_box_side) + return mask + + +def filter_bboxes(detections: List[Dict[str, torch.Tensor]], height: int, width: int, min_box_diagonal: int = 30, + min_box_side: int = 20): + filtered_bboxes = [] + for d in detections: + bbox = d["boxes"] + + # first clamp boxes to image + mask = diag_filter(bbox, height, width, min_box_diagonal, min_box_side) + bbox = {k: v[mask] for k, v in d.items()} + + filtered_bboxes.append(bbox) + + return filtered_bboxes + +def format_data(data, normalizer=None): + if normalizer is None: + normalizer = torch.stack([data.width[0], data.height[0], 
data.time_window[0]], dim=-1) + + if hasattr(data, "image"): + data.image = data.image.float() / 255.0 + + data.pos = torch.cat([data.pos, data.t.view((-1,1))], dim=-1) + data.t = None + data.x = data.x.float() + data.pos = data.pos / normalizer + return data + +def bbox_t_to_ndarray(bbox, t): + dtype = [('t', ' 0: + output = {k: np.concatenate(v) for k, v in output.items() if len(v) > 0} + + return output + +def to_cpu(data_list: List[Dict[str, torch.Tensor]]): + return [{k: v.cpu() for k, v in d.items()} for d in data_list] + +class Buffer: + def __init__(self): + self.buffer = [] + + def extend(self, elements: List[Dict[str, torch.Tensor]]): + self.buffer.extend(to_cpu(elements)) + + def clear(self): + self.buffer.clear() + + def __iter__(self): + return iter(self.buffer) + + def __next__(self): + return next(self.buffer) + + + +class DetectionBuffer: + def __init__(self, height: int, width: int, classes: List[str]): + self.height = height + self.width = width + self.classes = classes + self.detections = Buffer() + self.ground_truth = Buffer() + + def compile(self, sequences, timestamps): + detections = compile(self.detections, sequences, timestamps) + groundtruth = compile(self.ground_truth, sequences, timestamps) + return detections, groundtruth + + def update(self, detections: List[Dict[str, torch.Tensor]], groundtruth: List[Dict[str, torch.Tensor]], dataset: str, height=None, width=None): + self.detections.extend(detections) + self.ground_truth.extend(groundtruth) + + def compute(self)->Dict[str, float]: + output = evaluate_detection(self.ground_truth.buffer, self.detections.buffer, height=self.height, width=self.width, classes=self.classes) + output = {k.replace("AP", "mAP"): v for k, v in output.items()} + self.detections.clear() + self.ground_truth.clear() + return output + + +class DictBuffer: + def __init__(self): + self.running_mean = None + self.n = 0 + + def __recursive_mean(self, mn: float, s: float): + return self.n / (self.n + 1) * mn + s / (self.n + 1) + + def update(self, dictionary: Dict[str, float]): + if self.running_mean is None: + self.running_mean = {k: 0 for k in dictionary} + + self.running_mean = {k: self.__recursive_mean(self.running_mean[k], dictionary[k]) for k in dictionary} + self.n += 1 + + def save(self, path): + torch.save(self.running_mean, path) + + def compute(self)->Dict[str, float]: + return self.running_mean + diff --git a/src/dagr/utils/coco_eval.py b/src/dagr/utils/coco_eval.py new file mode 100644 index 0000000..ee3cf75 --- /dev/null +++ b/src/dagr/utils/coco_eval.py @@ -0,0 +1,233 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import contextlib +from pycocotools.coco import COCO +from detectron2.evaluation.fast_eval_api import COCOeval_opt as COCOeval +#from detectron2.evaluation.fast_eval_api import COCOeval + +import numpy as np +from typing import List, Dict, Tuple +from torch import Tensor + +BBOX_DTYPE = np.dtype({'names':['t','x','y','w','h','class_id','track_id','class_confidence'], 'formats':[' Tuple[Dict, Dict]: + """ + Compute detection KPIs on list of boxes in the numpy format, using the COCO python API + https://github.com/cocodataset/cocoapi + KPIs are only computed on timestamps where there is actual at least one box + (fully empty frames are not considered) + :param gt_boxes_list: list of numpy array for GT boxes (one per file) + :param dt_boxes_list: list of numpy array for detected boxes + :param classes: iterable of classes names + :param 
height: int for box size statistics + :param width: int for box size statistics + :param time_tol: int size of the temporal window in micro seconds to look for a detection around a gt box + """ + flattened_gt = [] + flattened_dt = [] + for gt_boxes, dt_boxes in zip(gt_boxes_list, dt_boxes_list): + gt_boxes = _to_prophesee(gt_boxes) + dt_boxes = _to_prophesee(dt_boxes) + + assert np.all(gt_boxes['t'][1:] >= gt_boxes['t'][:-1]) + assert np.all(dt_boxes['t'][1:] >= dt_boxes['t'][:-1]) + + all_ts = np.unique(gt_boxes['t']) + + gt_win, dt_win = _match_times(all_ts, gt_boxes, dt_boxes, time_tol) + flattened_gt = flattened_gt + gt_win + flattened_dt = flattened_dt + dt_win + + + num_detections = sum([d.size for d in flattened_dt]) + if num_detections == 0: + # Corner case at the very beginning of the training. + print('no detections for evaluation found.') + return None + + categories = [{"id": id + 1, "name": class_name, "supercategory": "none"} + for id, class_name in enumerate(classes)] + + return _to_coco_format(flattened_gt, flattened_dt, categories, height=height, width=width), len(flattened_gt) + + + +def evaluate_detection(gt_boxes_list: List[Dict[str, Tensor]], + dt_boxes_list: List[Dict[str, Tensor]], + classes: str=("car", "pedestrian"), + height: int=240, + width: int=304, + time_tol: int=50000) -> Dict[str, float]: + """ + Compute detection KPIs on list of boxes in the numpy format, using the COCO python API + https://github.com/cocodataset/cocoapi + KPIs are only computed on timestamps where there is actual at least one box + (fully empty frames are not considered) + :param gt_boxes_list: list of numpy array for GT boxes (one per file) + :param dt_boxes_list: list of numpy array for detected boxes + :param classes: iterable of classes names + :param height: int for box size statistics + :param width: int for box size statistics + :param time_tol: int size of the temporal window in micro seconds to look for a detection around a gt box + """ + output = _convert_to_coco_format(gt_boxes_list, + dt_boxes_list, + classes, + height, + width, + time_tol) + + if output is None: + out_keys = ('AP', 'AP_50', 'AP_75', 'AP_S', 'AP_M', 'AP_L') + return {k: 0 for k in out_keys} + else: + (dataset, results), num_gts = output + return _coco_eval(dataset, results, num_gts) + +def _to_prophesee(det: Dict[str, Tensor]): + num_bboxes = len(det['boxes']) + out = np.zeros(shape=(num_bboxes,), dtype=BBOX_DTYPE) + det = {k: v.cpu().numpy() for k, v in det.items()} + x1, y1, x2, y2 = det['boxes'].T + out["x"] = x1 + out["y"] = y1 + out["w"] = x2-x1 + out["h"] = y2-y1 + out["class_id"] = det["labels"] + out["class_confidence"] = det.get("scores", np.ones(shape=(num_bboxes,), dtype="float32")) + return out + +def _match_times(all_ts, gt_boxes, dt_boxes, time_tol): + """ + match ground truth boxes and ground truth detections at all timestamps using a specified tolerance + return a list of boxes vectors + """ + gt_size = len(gt_boxes) + dt_size = len(dt_boxes) + + windowed_gt = [] + windowed_dt = [] + + low_gt, high_gt = 0, 0 + low_dt, high_dt = 0, 0 + for ts in all_ts: + + while low_gt < gt_size and gt_boxes[low_gt]['t'] < ts: + low_gt += 1 + # the high index is at least as big as the low one + high_gt = max(low_gt, high_gt) + while high_gt < gt_size and gt_boxes[high_gt]['t'] <= ts: + high_gt += 1 + + # detection are allowed to be inside a window around the right detection timestamp + low = ts - time_tol + high = ts + time_tol + while low_dt < dt_size and dt_boxes[low_dt]['t'] < low: + low_dt += 1 + # the high 
index is at least as big as the low one + high_dt = max(low_dt, high_dt) + while high_dt < dt_size and dt_boxes[high_dt]['t'] <= high: + high_dt += 1 + + windowed_gt.append(gt_boxes[low_gt:high_gt]) + windowed_dt.append(dt_boxes[low_dt:high_dt]) + + return windowed_gt, windowed_dt + + +def _coco_eval(dataset, results, num_gts): + """simple helper function wrapping around COCO's Python API + :params: gts iterable of numpy boxes for the ground truth + :params: detections iterable of numpy boxes for the detections + :params: height int + :params: width int + :params: labelmap iterable of class labels + """ + + + # Meaning: https://cocodataset.org/#detection-eval + out_keys = ('AP', 'AP_50', 'AP_75', 'AP_S', 'AP_M', 'AP_L') + out_dict = {k: 0.0 for k in out_keys} + + + coco_gt = COCO() + coco_gt.dataset = dataset + coco_gt.createIndex() + coco_pred = coco_gt.loadRes(results) + + coco_eval = COCOeval(coco_gt, coco_pred, 'bbox') + coco_eval.params.imgIds = np.arange(1, num_gts + 1, dtype=int) + coco_eval.evaluate() + coco_eval.accumulate() + + with open(os.devnull, 'w') as f, contextlib.redirect_stdout(f): + # info: https://stackoverflow.com/questions/8391411/how-to-block-calls-to-print + coco_eval.summarize() + for idx, key in enumerate(out_keys): + out_dict[key] = coco_eval.stats[idx] + return out_dict + + + +def _to_coco_format(gts, detections, categories, height=240, width=304): + """ + utilitary function producing our data in a COCO usable format + """ + annotations = [] + results = [] + images = [] + + # to dictionary + for image_id, (gt, pred) in enumerate(zip(gts, detections)): + im_id = image_id + 1 + + images.append( + {"date_captured": "2019", + "file_name": "n.a", + "id": im_id, + "license": 1, + "url": "", + "height": height, + "width": width}) + + for bbox in gt: + x1, y1 = bbox['x'], bbox['y'] + w, h = bbox['w'], bbox['h'] + area = w * h + + annotation = { + "area": float(area), + "iscrowd": False, + "image_id": im_id, + "bbox": [x1, y1, w, h], + "category_id": int(bbox['class_id']) + 1, + "id": len(annotations) + 1 + } + annotations.append(annotation) + + for bbox in pred: + + image_result = { + 'image_id': im_id, + 'category_id': int(bbox['class_id']) + 1, + 'score': float(bbox['class_confidence']), + 'bbox': [bbox['x'], bbox['y'], bbox['w'], bbox['h']], + } + results.append(image_result) + + dataset = {"info": {}, + "licenses": [], + "type": 'instances', + "images": images, + "annotations": annotations, + "categories": categories} + return dataset, results diff --git a/src/dagr/utils/learning_rate_scheduler.py b/src/dagr/utils/learning_rate_scheduler.py new file mode 100644 index 0000000..f59ca82 --- /dev/null +++ b/src/dagr/utils/learning_rate_scheduler.py @@ -0,0 +1,48 @@ +from functools import partial +import math +from typing import List + +import numpy as np + + +class LRSchedule: + def __init__(self, + warmup_epochs: float, + num_iters_per_epoch: int, + tot_num_epochs: int, + min_lr_ratio: float=0.05, + warmup_lr_start: float=0, + steps_at_iteration=[50000], + reduction_at_step=0.5): + + warmup_total_iters = num_iters_per_epoch * warmup_epochs + total_iters = tot_num_epochs * num_iters_per_epoch + no_aug_iters = 0 + self.lr_func = partial(_yolox_warm_cos_lr, min_lr_ratio, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iters, steps_at_iteration, reduction_at_step) + + def __call__(self, *args, **kwargs)->float: + return self.lr_func(*args, **kwargs) + + +def _yolox_warm_cos_lr( + min_lr_ratio: float, + total_iters: int, + warmup_total_iters: int, + 
warmup_lr_start: float, + no_aug_iter: int, + steps_at_iteration: List[int], + reduction_at_step: float, + iters: int)->float: + """Cosine learning rate with warm up.""" + min_lr = min_lr_ratio + if iters < warmup_total_iters: + # lr = (lr - warmup_lr_start) * iters / float(warmup_total_iters) + warmup_lr_start + lr = (1 - warmup_lr_start) * pow(iters / float(warmup_total_iters), 2) + warmup_lr_start + else: + lr = min_lr + 0.5 * (1 - min_lr) * (1.0 + math.cos(math.pi * (iters - warmup_total_iters) / (total_iters - warmup_total_iters - no_aug_iter))) + + for step in steps_at_iteration: + if iters >= step: + lr *= reduction_at_step + + return lr \ No newline at end of file diff --git a/src/dagr/utils/logging.py b/src/dagr/utils/logging.py new file mode 100644 index 0000000..85020de --- /dev/null +++ b/src/dagr/utils/logging.py @@ -0,0 +1,125 @@ +import torch +import wandb +import os + +from typing import List, Dict, Optional +from torch_geometric.data import Batch +from pathlib import PosixPath +from pprint import pprint +from pathlib import Path + +from torch_geometric.data import Data + + +def set_up_logging_directory(dataset, task, output_directory): + project = f"low_latency-{dataset}-{task}" + + output_directory = output_directory / dataset / task + output_directory.mkdir(parents=True, exist_ok=True) + wandb.init(project=project, entity="rpg", save_code=True, dir=str(output_directory)) + + name = wandb.run.name + output_directory = output_directory / name + output_directory.mkdir(parents=True, exist_ok=True) + os.system(f"cp -r {os.path.join(os.path.dirname(__file__), '../../low_latency_object_detection')} {str(output_directory)}") + + return output_directory + +def log_hparams(args): + hparams = {k: str(v) if type(v) is PosixPath else v for k, v in vars(args).items()} + pprint(hparams) + wandb.log(hparams) + +def log_bboxes(data: Batch, + targets: List[Dict[str, torch.Tensor]], + detections: List[Dict[str, torch.Tensor]], + class_names: List[str], + bidx: int, + key: str): + + gt_bbox = [] + det_bbox = [] + images = [] + for b, datum in enumerate(data.to_data_list()): + image = visualize_events(datum) + image = torch.cat([image, image], dim=1) + images.append(image) + + if len(detections) > 0: + det = detections[b] + det = torch.cat([det['boxes'], det['labels'].view(-1,1), det['scores'].view(-1,1)], dim=-1) + det[:, [0, 2]] += b * datum.width + det_bbox.append(det) + + if len(targets) > 0: + tar = targets[b] + tar = torch.cat([tar['boxes'], tar['labels'].view(-1, 1), torch.ones_like(tar['labels'].view(-1, 1))], dim=-1) + tar[:, [0, 2]] += b * datum.width + tar[:, [1, 3]] += datum.height + gt_bbox.append(tar) + + if b == bidx-1: + break + + pred_bbox = torch.cat(det_bbox) + gt_bbox = torch.cat(gt_bbox) + images = torch.cat(images, dim=-1) + + gt_bbox[:,[0,2]] /= (bidx * datum.width) + gt_bbox[:,[1,3]] /= (2 * datum.height) + + pred_bbox[:,[0,2]] /= (bidx * datum.width) + pred_bbox[:,[1,3]] /= (2 * datum.height) + + image = __convert_to_wandb_data(images.detach().float().cpu(), + gt_bbox.detach().cpu(), + pred_bbox.detach().cpu(), + class_names) + + wandb.log({key: image}) + +def visualize_events(data: Data)->torch.Tensor: + x, y = data.pos[:,:2].long().t() + p = data.x[:,0].long() + + if hasattr(data, "image"): + image = data.image[0].clone() + else: + image = torch.full(size=(3, data.height, data.width), fill_value=255, device=p.device, dtype=torch.uint8) + + is_pos = p == 1 + image[:, y[is_pos], x[is_pos]] = torch.tensor([[0],[0],[255]], dtype=torch.uint8, device=p.device) + image[:, 
y[~is_pos], x[~is_pos]] = torch.tensor([[255],[0],[0]], dtype=torch.uint8, device=p.device) + + return image + +def __convert_to_wandb_data(image: torch.Tensor, gt: torch.Tensor, p: torch.Tensor, class_names: List[str])->wandb.Image: + return wandb.Image(image, boxes={ + "predictions": __parse_bboxes(p, class_names, suffix="P"), + "ground_truth": __parse_bboxes(gt, class_names) + }) + +def __parse_bboxes(bboxes: torch.Tensor, class_names: List[str], suffix: str="GT"): + # bbox N x 6 -> xyxycs + return { + "box_data": [__parse_bbox(bbox, class_names, suffix) for bbox in bboxes], + "class_labels": dict(enumerate(class_names)) + } + +def __parse_bbox(bbox: torch.Tensor, class_names: List[str], suffix: str="GT"): + # bbox xyxycs + return { + "position": { + "minX": float(bbox[0]), + "minY": float(bbox[1]), + "maxX": float(bbox[2]), + "maxY": float(bbox[3]) + }, + "class_id": int(bbox[-2]), + "scores": { + "object score": float(bbox[-1]) + }, + "bbox_caption": f"{suffix} - {class_names[int(bbox[-2])]}" + } + + diff --git a/src/dagr/utils/testing.py b/src/dagr/utils/testing.py new file mode 100644 index 0000000..b7645b7 --- /dev/null +++ b/src/dagr/utils/testing.py @@ -0,0 +1,50 @@ +import torch +from dagr.utils.logging import log_bboxes +from dagr.utils.buffers import DetectionBuffer, format_data +import tqdm + +def to_npy(detections): + return [{k: v.cpu().numpy() for k, v in d.items()} for d in detections] + +def format_detections(sequences, t, detections): + detections = to_npy(detections) + for i, det in enumerate(detections): + det['sequence'] = sequences[i] + det['t'] = t[i] + return detections + +def run_test_with_visualization(loader, model, dataset: str, log_every_n_batch=-1, name="", compile_detections=False): + model.eval() + mapcalc = DetectionBuffer(height=loader.dataset.height, width=loader.dataset.width, + classes=loader.dataset.classes) + + counter = 0 + if compile_detections: + compiled_detections = [] + + for i, data in enumerate(tqdm.tqdm(loader, desc=f"Testing {name}")): + data = data.cuda(non_blocking=True) + data_for_visualization = data.clone() + + data = format_data(data) + detections, targets = model(data.clone()) + + if compile_detections: + compiled_detections.extend(format_detections(data.sequence, data.t1, detections)) + + if log_every_n_batch > 0 and counter % log_every_n_batch == 0: + log_bboxes(data_for_visualization, targets=targets, detections=detections, bidx=4, + class_names=loader.dataset.classes, key="testing/evaluated_bboxes") + + mapcalc.update(detections, targets, dataset, data.height[0], data.width[0]) + + if i % 5 == 0: + torch.cuda.empty_cache() + + counter += 1 + + torch.cuda.empty_cache() + + data = mapcalc.compute() + + return (data, compiled_detections) if compile_detections else data \ No newline at end of file diff --git a/src/dagr/visualization/bbox_viz.py b/src/dagr/visualization/bbox_viz.py new file mode 100644 index 0000000..953a4dc --- /dev/null +++ b/src/dagr/visualization/bbox_viz.py @@ -0,0 +1,72 @@ +import numpy as np +import cv2 +import torchvision +import torch + + +_COLORS = np.array([[0.000, 0.8, 0.1], [1, 0.67, 0.00]]) +class_names = ["car", "pedestrian"] + + +def draw_bbox_on_img(img, x, y, w, h, labels, scores=None, conf=0.5, nms=0.45, label="", linewidth=2): + if scores is not None: + mask = filter_boxes(x, y, w, h, labels, scores, conf, nms) + x = x[mask] + y = y[mask] + w = w[mask] + h = h[mask] + labels = labels[mask] + scores = scores[mask] + + for i in range(len(x)): + if scores is not None and scores[i] < conf: + continue 
+ + x0 = int(x[i]) + y0 = int(y[i]) + x1 = int(x[i] + w[i]) + y1 = int(y[i] + h[i]) + cls_id = int(labels[i]) + + color = (_COLORS[cls_id] * 255).astype(np.uint8).tolist() + + text = f"{label}-{class_names[cls_id]}" + + if scores is not None: + text += f":{scores[i] * 100: .1f}" + + txt_color = (0, 0, 0) if np.mean(_COLORS[cls_id]) > 0.5 else (255, 255, 255) + font = cv2.FONT_HERSHEY_SIMPLEX + + txt_size = cv2.getTextSize(text, font, 0.4, 1)[0] + cv2.rectangle(img, (x0, y0), (x1, y1), color, linewidth) + + txt_bk_color = (_COLORS[cls_id] * 255 * 0.7).astype(np.uint8).tolist() + txt_height = int(1.5*txt_size[1]) + cv2.rectangle( + img, + (x0, y0 - txt_height), + (x0 + txt_size[0] + 1, y0 + 1), + txt_bk_color, + -1 + ) + cv2.putText(img, text, (x0, y0 + txt_size[1]-txt_height), font, 0.4, txt_color, thickness=1) + return img + +def filter_boxes(x, y, w, h, labels, scores, conf, nms): + mask = scores > conf + + x1, y1 = x + w, y + h + box_coords = np.stack([x, y, x1, y1], axis=-1) + + nms_out_index = torchvision.ops.batched_nms( + torch.from_numpy(box_coords), + torch.from_numpy(np.ascontiguousarray(scores)), + torch.from_numpy(labels), + nms + ) + + nms_mask = np.ones_like(mask) == 0 + nms_mask[nms_out_index] = True + + return mask & nms_mask diff --git a/src/dagr/visualization/event_viz.py b/src/dagr/visualization/event_viz.py new file mode 100644 index 0000000..d5d4340 --- /dev/null +++ b/src/dagr/visualization/event_viz.py @@ -0,0 +1,10 @@ +import numba + +@numba.jit(nopython=True) +def draw_events_on_image(img, x, y, p, alpha=0.5): + img_copy = img.copy() + for i in range(len(p)): + if y[i] < len(img): + img[y[i], x[i], :] = alpha * img_copy[y[i], x[i], :] + img[y[i], x[i], int(p[i])-1] += 255 * (1-alpha) + return img \ No newline at end of file
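
The remaining notes are usage sketches for the utilities added in this patch; they are illustrative only and not part of the diff. The flag parsing above relies on argparse.SUPPRESS so that values given on the command line win, while every unset flag falls back to the YAML config. A minimal, self-contained sketch of that interplay (the flag subset and the example config path are assumptions):

import argparse
from pathlib import Path

import yaml


def parse_config(args: argparse.Namespace, config: Path) -> argparse.Namespace:
    # same logic as above: only fill keys that the command line did not set
    with config.open() as f:
        cfg = yaml.load(f, Loader=yaml.SafeLoader)
    for k, v in cfg.items():
        if k not in args:
            setattr(args, k, v)
    return args


parser = argparse.ArgumentParser()
parser.add_argument("--config", type=Path, default=Path("config/dagr-s-dsec.yaml"))
parser.add_argument("--batch_size", default=argparse.SUPPRESS, type=int)
parser.add_argument("--l_r", default=argparse.SUPPRESS, type=float)

args = parser.parse_args(["--batch_size", "32"])  # --l_r deliberately not given
args = parse_config(args, args.config)

print(args.batch_size)  # 32, taken from the command line
print(args.l_r)         # whatever l_r is set to in the YAML file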
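
The evaluation path (DetectionBuffer, evaluate_detection, COCOeval) accepts torchvision-style dictionaries with "boxes", "labels" and "scores". It can be smoke-tested with one synthetic detection that matches its ground-truth box exactly, in which case mAP should come out close to 1.0. The import path assumes the package is installed as dagr; pycocotools and detectron2 must be available, and the image size is an arbitrary example:

import torch

from dagr.utils.buffers import DetectionBuffer

buffer = DetectionBuffer(height=480, width=640, classes=["car", "pedestrian"])

target = [{
    "boxes": torch.tensor([[100.0, 120.0, 220.0, 200.0]]),  # x1, y1, x2, y2
    "labels": torch.tensor([0]),                             # 0 -> "car"
}]
prediction = [{
    "boxes": torch.tensor([[100.0, 120.0, 220.0, 200.0]]),
    "labels": torch.tensor([0]),
    "scores": torch.tensor([0.9]),
}]

buffer.update(prediction, target, dataset="dsec")
metrics = buffer.compute()          # keys are remapped from "AP*" to "mAP*"
print(metrics["mAP"], metrics["mAP_50"])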
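
DictBuffer keeps an incremental mean over dictionaries of scalars, which is how per-batch metrics can be averaged over an epoch. A tiny example, again assuming the dagr import path:

from dagr.utils.buffers import DictBuffer

metrics = DictBuffer()
metrics.update({"loss": 1.0, "mAP": 0.2})
metrics.update({"loss": 0.5, "mAP": 0.4})
print(metrics.compute())   # approximately {'loss': 0.75, 'mAP': 0.3}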
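
LRSchedule returns a multiplier on the base learning rate (quadratic warm-up, then cosine decay, with optional step reductions), so it plugs directly into a LambdaLR that is stepped once per iteration rather than once per epoch. A minimal sketch with placeholder model, loader size and hyper-parameters:

import torch

from dagr.utils.learning_rate_scheduler import LRSchedule

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4, weight_decay=1e-5)

iters_per_epoch = 1000   # len(train_loader) in practice
schedule = LRSchedule(warmup_epochs=2,
                      num_iters_per_epoch=iters_per_epoch,
                      tot_num_epochs=801)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=schedule)

for it in range(5 * iters_per_epoch):
    # ... forward pass, loss.backward() ...
    optimizer.step()
    scheduler.step()     # advance the warm-up / cosine multiplier by one iteration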
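
Finally, the two visualization helpers compose as follows. The events and the single box are synthetic, the polarity-to-channel mapping (p in {0, 1}) is an assumption, and OpenCV, torchvision and numba need to be installed:

import numpy as np

from dagr.visualization.bbox_viz import draw_bbox_on_img
from dagr.visualization.event_viz import draw_events_on_image

height, width = 480, 640
img = np.full((height, width, 3), 255, dtype=np.uint8)

n = 5000
x = np.random.randint(0, width, size=n)
y = np.random.randint(0, height, size=n)
p = np.random.randint(0, 2, size=n)      # polarity selects the colour channel

img = draw_events_on_image(img, x, y, p, alpha=0.5)
img = draw_bbox_on_img(img,
                       x=np.array([100.0], dtype=np.float32),
                       y=np.array([120.0], dtype=np.float32),
                       w=np.array([120.0], dtype=np.float32),
                       h=np.array([80.0], dtype=np.float32),
                       labels=np.array([0]),
                       scores=np.array([0.9], dtype=np.float32),
                       conf=0.5, label="pred")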