Skip to content

Commit

Permalink
Changed dataloading structure
Browse files Browse the repository at this point in the history
  • Loading branch information
kurt-stolle committed Nov 24, 2023
1 parent d64c1ba commit a11636c
Show file tree
Hide file tree
Showing 18 changed files with 868 additions and 670 deletions.
6 changes: 5 additions & 1 deletion configs/cityscapes/_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,11 @@
actions=[
L(up.data.ops.TorchvisionOp)(
transforms=[
L(transforms.CenterCrop)(size=(1024 - 32, 2048 - 32)),
L(transforms.RandomResize)(min_size=512, max_size=1024 + 512, antialias=True),
L(transforms.RandomHorizontalFlip)(),
L(transforms.RandomCrop)(size=(512, 1024), pad_if_needed=False),
L(transforms.RandomHorizontalFlip)(),
L(transforms.RandomPhotometricDistort)(),
]
),
],
Expand All @@ -45,9 +47,11 @@
actions=[
L(up.data.ops.TorchvisionOp)(
transforms=[
L(transforms.CenterCrop)(size=(1024 - 32, 2048 - 32)),
L(transforms.RandomResize)(min_size=512, max_size=1024, antialias=True),
L(transforms.RandomCrop)(size=(512, 1024), pad_if_needed=False),
L(transforms.RandomHorizontalFlip)(),
L(transforms.RandomPhotometricDistort)(),
]
),
L(up.data.ops.PseudoMotion)(frames=2, size=(512, 1024)),
Expand Down
2 changes: 1 addition & 1 deletion configs/cityscapes/multidvps_resnet18_fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
trainer.config.eval_steps = 1000
trainer.config.logging_steps = 10

model.backbone.base = L(up.nn.backbones.timm.TimmBackbone)(name="resnet18")
model.backbone.base = L(up.nn.backbones.timm.TimmBackbone)(name="resnet18d")
model.backbone.out_channels = 24
model.detector.localizer.encoder.out_channels = 64
model.feature_encoder.shared_encoder.out_channels = 64
Expand Down
2 changes: 1 addition & 1 deletion configs/cityscapes/multidvps_resnet50.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
weighted_num=7,
common_stride=4,
backbone=L(up.nn.backbones.fpn.FeaturePyramidNetwork)(
bottom_up=L(up.nn.backbones.timm.TimmBackbone)(name="resnet50"),
bottom_up=L(up.nn.backbones.timm.TimmBackbone)(name="resnet50d"),
in_features=["ext.2", "ext.3", "ext.4", "ext.5"],
out_channels=128,
norm=up.nn.layers.norm.LayerNormCHW,
Expand Down
8 changes: 8 additions & 0 deletions notebooks/backbones.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,14 @@
"File \u001b[0;32m/gpfs/home3/kstolle/projects/unipercept/sources/unipercept/nn/backbones/timm.py:69\u001b[0m, in \u001b[0;36mTimmBackbone.__init__\u001b[0;34m(self, name, pretrained, nodes, keys, **kwargs)\u001b[0m\n\u001b[1;32m 66\u001b[0m extractor \u001b[39m=\u001b[39m build_extractor(name, pretrained\u001b[39m=\u001b[39mpretrained, out_indices\u001b[39m=\u001b[39mnodes)\n\u001b[1;32m 67\u001b[0m info \u001b[39m=\u001b[39m infer_feature_info(extractor, dims)\n\u001b[0;32m---> 69\u001b[0m config \u001b[39m=\u001b[39m resolve_data_config({}, model\u001b[39m=\u001b[39mmodel)\n\u001b[1;32m 71\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39m\"\u001b[39m\u001b[39mmean\u001b[39m\u001b[39m\"\u001b[39m \u001b[39min\u001b[39;00m config:\n\u001b[1;32m 72\u001b[0m kwargs\u001b[39m.\u001b[39msetdefault(\u001b[39m\"\u001b[39m\u001b[39mmean\u001b[39m\u001b[39m\"\u001b[39m, config[\u001b[39m\"\u001b[39m\u001b[39mmean\u001b[39m\u001b[39m\"\u001b[39m])\n",
"\u001b[0;31mNameError\u001b[0m: name 'resolve_data_config' is not defined"
]
},
{
"ename": "",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[1;31mThe Kernel crashed while executing code in the the current cell or a previous cell. Please review the code in the cell(s) to identify a possible cause of the failure. Click <a href='https://aka.ms/vscodeJupyterKernelCrash'>here</a> for more info. View Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
]
}
],
"source": [
Expand Down
1,343 changes: 749 additions & 594 deletions notebooks/multidvps.ipynb

Large diffs are not rendered by default.

17 changes: 9 additions & 8 deletions scripts/hpc_run.sh
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
#!/bin/bash
# Script for training a model in a HPC environment with Slurm and multiple GPUs.
# Tested on the Snellius cluster provided by SURF.
#SBATCH --mail-type=BEGIN,END,FAIL
#SBATCH --partition=gpu
#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH --gpus=2
#SBATCH --cpus-per-gpu=18
#SBATCH --time=24:00:00
#SBATCH --job-name=unipercept

#SBATCH --mail-type=ALL
#SBATCH --partition=gpu --nodes 1

set -e

Expand All @@ -12,9 +16,6 @@ echo "Running on $(hostname)"
echo "Loading HPC modules"
source "./scripts/hpc_env.sh"

echo "Loading Python virtual environment"
source "./venv/bin/activate"

echo "Starting distributed training"
accelerate launch $(which unicli) $@
srun `realpath ./venv/bin/accelerate` launch `realpath ./venv/bin/unicli` $@
exit $?
4 changes: 4 additions & 0 deletions sources/unimodels/multidvps/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,11 @@

_M = T.TypeVar("_M", bound=nn.Module)

OPTIMIZE_ENABLED = False

def _maybe_optimize_submodule(module: _M, **kwargs) -> _M:
if not OPTIMIZE_ENABLED:
return module
try:
module = T.cast(_M, torch.compile(module, **kwargs))
except Exception as err:
Expand Down
4 changes: 2 additions & 2 deletions sources/unimodels/multidvps/logic/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,8 +370,8 @@ def merge_predictions(
# Check skipping conditions based on model config
if stuff_with_things and cat_st_train == 0:
continue # 0 is a special 'thing' class
# if stuff_all_classes and (cat_st in id_map_thing.values()):
# continue # Skip semantic classes that are also things
if stuff_all_classes and (cat_st in id_map_thing.values()):
continue # Skip semantic classes that are also things

# Select only pixels that belong to the current class and are not
# already present in the output panoptic segmentation
Expand Down
7 changes: 4 additions & 3 deletions sources/unipercept/data/_loader.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Defines functions for creating dataloaders for training and validation, using the common dataset format."""

from __future__ import annotations

import os
import dataclasses
import multiprocessing as M
import typing as T
Expand All @@ -26,13 +26,14 @@

_logger = get_logger(__name__)

DEFAULT_NUM_WORKERS = max(1, int(os.getenv("SLURM_CPUS_ON_NODE", M.cpu_count() // 2)))

@dataclasses.dataclass(slots=True, frozen=True)
class DataLoaderConfig:
drop_last: bool = False
pin_memory: bool = True
num_workers: int = max(1, M.cpu_count() // 2)
prefetch_factor: int | None = 4
num_workers: int = DEFAULT_NUM_WORKERS
prefetch_factor: int | None = 6
persistent_workers: bool | None = False


Expand Down
91 changes: 63 additions & 28 deletions sources/unipercept/data/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,17 @@
import torch.nn
import torch.types
import torch.utils.data as torch_data
import torchvision.transforms.v2 as tvt2
import torchvision.ops
from torchvision import disable_beta_transforms_warning as __disable_warning
from typing_extensions import override
from unicore.utils.pickle import as_picklable

from .tensors import BoundingBoxes, PanopticMap, DepthMap, BoundingBoxFormat
from unipercept.utils.logutils import get_logger

if T.TYPE_CHECKING:
import unipercept as up
from unipercept.model import InputData

_logger = get_logger(name=__file__)

Expand All @@ -35,9 +38,9 @@
__all__ = ["apply_dataset", "Op", "CloneOp", "NoOp", "TorchvisionOp"]


# ---------------------- #
# Baseclass and protocol #
# ---------------------- #
########################################################################################################################
# BASE CLASS FOR OPS
########################################################################################################################


class Op(torch.nn.Module, metaclass=abc.ABCMeta):
Expand All @@ -49,40 +52,40 @@ def __init__(self) -> None:
super().__init__()

@override
def forward(self, inputs: up.model.InputData) -> up.model.InputData:
def forward(self, inputs: InputData) -> InputData:
assert len(inputs.batch_size) == 0, f"Expected a single batched data point, got {inputs.batch_size}!"
inputs = self._run(inputs)
return inputs

@abc.abstractmethod
def _run(self, inputs: up.model.InputData) -> up.model.InputData:
def _run(self, inputs: InputData) -> InputData:
raise NotImplementedError(f"{self.__class__.__name__} is missing required implemention!")

if T.TYPE_CHECKING:

@override
def __call__(self, inputs: up.model.InputData) -> up.model.InputData:
def __call__(self, inputs: InputData) -> InputData:
...


# ------------------------------ #
# Basic and primitive operations #
# ------------------------------ #
########################################################################################################################
# BASIC OPS
########################################################################################################################


class NoOp(Op):
    """Do nothing: the inputs are returned unchanged."""

    @override
    def _run(self, inputs: InputData) -> InputData:
        # Only the current signature is kept; the stale `up.model.InputData`
        # variant was a bodiless duplicate `def` directly above it.
        return inputs


class PinOp(Op):
    """Pin the memory of the input data by calling its ``pin_memory()`` method."""

    @override
    def _run(self, inputs: InputData) -> InputData:
        # NOTE(review): presumably this places the tensors in page-locked host
        # memory (as `torch.Tensor.pin_memory` does) rather than moving them to
        # a device -- confirm against the `InputData.pin_memory` implementation.
        return inputs.pin_memory()

Expand All @@ -96,7 +99,7 @@ def __init__(self, *args, **kwargs) -> None:
self.register_forward_hook(self._log) # type: ignore

@staticmethod
def _log(mod, inputs: up.model.InputData, outputs: tuple[list[str], up.model.InputData]) -> None:
def _log(mod, inputs: InputData, outputs: tuple[list[str], InputData]) -> None:
ids_str = ", ".join(inputs.ids)
print(f"Applying ops on: '{ids_str}'...")

Expand All @@ -105,16 +108,15 @@ class CloneOp(Op):
"""Copy the input data."""

@override
def _run(self, inputs: up.model.InputData) -> up.model.InputData:
def _run(self, inputs: InputData) -> InputData:
inputs = inputs.clone(recurse=True)
return inputs


# ---------------------------------- #
# Torchvision transforms as Ops #
# ---------------------------------- #
########################################################################################################################
# TORCHVISION: Wrappers for torchvision transforms
########################################################################################################################

import torchvision.transforms.v2 as tvt2


class TorchvisionOp(Op):
Expand All @@ -136,7 +138,7 @@ def __init__(self, transforms: T.Sequence[tvt2.Transform] | tvt2.Transform, *, v
raise ValueError(f"Expected transforms to be a sequence or transform`, got {transforms}!")

@override
def _run(self, inputs: up.model.InputData) -> up.model.InputData:
def _run(self, inputs: InputData) -> InputData:
from .tensors.registry import pixel_maps

if inputs.motions is not None:
Expand Down Expand Up @@ -164,6 +166,9 @@ def _run(self, inputs: up.model.InputData) -> up.model.InputData:

return inputs

########################################################################################################################
# PSEUDO MOTION
########################################################################################################################

class PseudoMotion(Op):
def __init__(
Expand Down Expand Up @@ -192,17 +197,19 @@ def __init__(
tvt2.Resize(tuple(int(s * scale) for s in size_crop), antialias=True),
tvt2.RandomAdjustSharpness(1.5),
tvt2.RandomAffine(shear=(-shear, shear), degrees=(-rotation, rotation)),
tvt2.RandomPhotometricDistort(),
tvt2.GaussianBlur((5, 9))
]
)

@override
def _run(self, inputs: up.model.InputData) -> up.model.InputData:
def _run(self, inputs: InputData) -> InputData:
assert len(inputs.batch_size) == 0

bs = list(inputs.captures.batch_size)
assert bs[-1] == 1, f"Data already is a sequence: {inputs.captures.batch_size}"

inp_list: list[up.model.InputData] = []
inp_list: list[InputData] = []

for i in range(self._out_frames):
inp_prev = inputs if i == 0 else self._upscale(inp_list[i - 1].clone())
Expand All @@ -224,15 +231,43 @@ def _run(self, inputs: up.model.InputData) -> up.model.InputData:

return inputs

########################################################################################################################
# BOXES FROM MASKS
########################################################################################################################

class BoxesFromMasks(Op):
    """
    Adds bounding boxes for each ground truth mask in the input segmentation.

    For every capture, the instance masks of its segmentation (viewed as a
    ``PanopticMap``) are converted to XYXY boxes with
    ``torchvision.ops.masks_to_boxes`` and the per-capture results are stored
    as a list on ``inputs.captures.boxes``. Inputs whose captures carry no
    segmentations are returned untouched.
    """

    # No `__init__` override: the previous one only forwarded to
    # `super().__init__()` without arguments, which is what happens by default.

    @override
    def _run(self, inputs: InputData) -> InputData:
        assert len(inputs.batch_size) == 0

        caps = inputs.captures.fix_subtypes_()
        if caps.segmentations is None:
            # Nothing to derive boxes from; `captures.boxes` keeps its value.
            return inputs

        # The canvas size is taken from the trailing two image dimensions.
        h, w = inputs.captures.images.shape[-2:]

        boxes = []
        for cap in caps:
            panoptic = cap.segmentations.as_subclass(PanopticMap)
            masks = [m for _, m in panoptic.get_instance_masks()]
            if masks:
                xyxy = torchvision.ops.masks_to_boxes(torch.stack(masks))
            else:
                # `torch.stack` raises on an empty list, so a capture without
                # any instance mask gets an empty (0, 4) box tensor instead.
                xyxy = torch.zeros((0, 4), dtype=torch.float, device=panoptic.device)
            boxes.append(BoundingBoxes(xyxy, format=BoundingBoxFormat.XYXY, canvas_size=(h, w)))

        inputs.captures.boxes = boxes
        return inputs



# ------------------------------------------------- #
# Transformed versions of map and iterable datasets #
# ------------------------------------------------- #
########################################################################################################################
# TRANSFORMED DATASETS
########################################################################################################################

_D = T.TypeVar("_D", bound=torch_data.Dataset, contravariant=True)


class _TransformedIterable(torch_data.IterableDataset["up.model.InputData"], T.Generic[_D]):
class _TransformedIterable(torch_data.IterableDataset["InputData"], T.Generic[_D]):
"""Applies a sequence of transformations to an iterable dataset."""

__slots__ = ("_set", "_fns")
Expand Down Expand Up @@ -260,7 +295,7 @@ def __repr__(self):
return f"<{repr(self._set)} x {len(self._fns)} transforms>"

@override
def __iter__(self) -> T.Iterator[up.model.InputData]:
def __iter__(self) -> T.Iterator[InputData]:
it = iter(self._set)
while True:
try:
Expand All @@ -277,7 +312,7 @@ def __iter__(self) -> T.Iterator[up.model.InputData]:
yield inputs


class _TransformedMap(torch_data.Dataset["up.model.InputData"], T.Generic[_D]):
class _TransformedMap(torch_data.Dataset["InputData"], T.Generic[_D]):
"""Applies a sequence of transformations to an iterable dataset."""

__slots__ = ("_set", "_fns", "_retry", "_fallback_candidates")
Expand Down Expand Up @@ -308,7 +343,7 @@ def __repr__(self):
return f"<{repr(self._set)} x {len(self._fns)} transforms>"

@override
def __getitem__(self, idx: int | str) -> tuple[up.model.InputData]:
def __getitem__(self, idx: int | str) -> tuple[InputData]:
for _ in range(self._retry):
inputs = self._set[idx]
assert len(inputs.batch_size) == 0, f"Expected a single batched data point, got {inputs.batch_size}!"
Expand Down
1 change: 1 addition & 0 deletions sources/unipercept/data/sets/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,6 +601,7 @@ def _load_capture_data(
images=multi_read(read_image, "image", no_entries="error")(sources),
segmentations=multi_read(read_segmentation, "panoptic", no_entries="none")(sources, info),
depths=multi_read(read_depth_map, "depth", no_entries="none")(sources),
boxes=None,
batch_size=[num_caps],
)

Expand Down
3 changes: 1 addition & 2 deletions sources/unipercept/data/tensors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,5 @@
from . import helpers, registry
from ._depth import *
from ._flow import *
from ._image import *
from ._mask import *
from ._panoptic import *
from ._torchvision import *
13 changes: 0 additions & 13 deletions sources/unipercept/data/tensors/_image.py

This file was deleted.

Loading

0 comments on commit a11636c

Please sign in to comment.