Skip to content

Commit

Permalink
Experimental compile to improve training speed
Browse files Browse the repository at this point in the history
  • Loading branch information
kurt-stolle committed Nov 23, 2023
1 parent b493457 commit d64c1ba
Show file tree
Hide file tree
Showing 8 changed files with 714 additions and 857 deletions.
37 changes: 11 additions & 26 deletions notebooks/backbones.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -9,34 +9,19 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"8.26 ms ± 14.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n",
"TensorDict(\n",
" fields={\n",
" ext.1: Tensor(shape=torch.Size([2, 256, 128, 64]), device=cuda:0, dtype=torch.float32, is_shared=True),\n",
" ext.2: Tensor(shape=torch.Size([2, 512, 64, 32]), device=cuda:0, dtype=torch.float32, is_shared=True),\n",
" ext.3: Tensor(shape=torch.Size([2, 1024, 32, 16]), device=cuda:0, dtype=torch.float32, is_shared=True),\n",
" ext.4: Tensor(shape=torch.Size([2, 2048, 16, 8]), device=cuda:0, dtype=torch.float32, is_shared=True)},\n",
" batch_size=torch.Size([2]),\n",
" device=cuda:0,\n",
" is_shared=True)\n",
"5.21 ms ± 4.53 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n",
"TensorDict(\n",
" fields={\n",
" ext.1: Tensor(shape=torch.Size([2, 64, 256, 128]), device=cuda:0, dtype=torch.float32, is_shared=True),\n",
" ext.2: Tensor(shape=torch.Size([2, 256, 128, 64]), device=cuda:0, dtype=torch.float32, is_shared=True),\n",
" ext.3: Tensor(shape=torch.Size([2, 512, 64, 32]), device=cuda:0, dtype=torch.float32, is_shared=True),\n",
" ext.4: Tensor(shape=torch.Size([2, 1024, 32, 16]), device=cuda:0, dtype=torch.float32, is_shared=True),\n",
" ext.5: Tensor(shape=torch.Size([2, 2048, 16, 8]), device=cuda:0, dtype=torch.float32, is_shared=True)},\n",
" batch_size=torch.Size([2]),\n",
" device=cuda:0,\n",
" is_shared=True)\n"
"ename": "NameError",
"evalue": "name 'resolve_data_config' is not defined",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
"\u001b[1;32m/home/kstolle/projects/unipercept/notebooks/backbones.ipynb Cell 3\u001b[0m line \u001b[0;36m6\n\u001b[1;32m <a href='vscode-notebook-cell://tunnel%2Bsnellius-gcn10/home/kstolle/projects/unipercept/notebooks/backbones.ipynb#W1sdnNjb2RlLXJlbW90ZQ%3D%3D?line=1'>2</a>\u001b[0m \u001b[39mimport\u001b[39;00m \u001b[39mtorch\u001b[39;00m\n\u001b[1;32m <a href='vscode-notebook-cell://tunnel%2Bsnellius-gcn10/home/kstolle/projects/unipercept/notebooks/backbones.ipynb#W1sdnNjb2RlLXJlbW90ZQ%3D%3D?line=3'>4</a>\u001b[0m inp \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mrandn(\u001b[39m2\u001b[39m, \u001b[39m3\u001b[39m, \u001b[39m512\u001b[39m, \u001b[39m256\u001b[39m)\u001b[39m.\u001b[39mcuda()\n\u001b[0;32m----> <a href='vscode-notebook-cell://tunnel%2Bsnellius-gcn10/home/kstolle/projects/unipercept/notebooks/backbones.ipynb#W1sdnNjb2RlLXJlbW90ZQ%3D%3D?line=5'>6</a>\u001b[0m bb_timm \u001b[39m=\u001b[39m up\u001b[39m.\u001b[39;49mnn\u001b[39m.\u001b[39;49mbackbones\u001b[39m.\u001b[39;49mtimm\u001b[39m.\u001b[39;49mTimmBackbone(name\u001b[39m=\u001b[39;49m\u001b[39m\"\u001b[39;49m\u001b[39mresnet50\u001b[39;49m\u001b[39m\"\u001b[39;49m)\u001b[39m.\u001b[39mcuda()\n\u001b[1;32m <a href='vscode-notebook-cell://tunnel%2Bsnellius-gcn10/home/kstolle/projects/unipercept/notebooks/backbones.ipynb#W1sdnNjb2RlLXJlbW90ZQ%3D%3D?line=7'>8</a>\u001b[0m bb_tv \u001b[39m=\u001b[39m up\u001b[39m.\u001b[39mnn\u001b[39m.\u001b[39mbackbones\u001b[39m.\u001b[39mtorchvision\u001b[39m.\u001b[39mTorchvisionBackbone(name\u001b[39m=\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mresnet50\u001b[39m\u001b[39m\"\u001b[39m)\u001b[39m.\u001b[39mcuda()\n\u001b[1;32m <a href='vscode-notebook-cell://tunnel%2Bsnellius-gcn10/home/kstolle/projects/unipercept/notebooks/backbones.ipynb#W1sdnNjb2RlLXJlbW90ZQ%3D%3D?line=8'>9</a>\u001b[0m get_ipython()\u001b[39m.\u001b[39mrun_line_magic(\u001b[39m'\u001b[39m\u001b[39mtimeit\u001b[39m\u001b[39m'\u001b[39m, \u001b[39m'\u001b[39m\u001b[39mout_tv= 
bb_tv(inp)\u001b[39m\u001b[39m'\u001b[39m)\n",
"File \u001b[0;32m/gpfs/home3/kstolle/projects/unipercept/sources/unipercept/nn/backbones/timm.py:69\u001b[0m, in \u001b[0;36mTimmBackbone.__init__\u001b[0;34m(self, name, pretrained, nodes, keys, **kwargs)\u001b[0m\n\u001b[1;32m 66\u001b[0m extractor \u001b[39m=\u001b[39m build_extractor(name, pretrained\u001b[39m=\u001b[39mpretrained, out_indices\u001b[39m=\u001b[39mnodes)\n\u001b[1;32m 67\u001b[0m info \u001b[39m=\u001b[39m infer_feature_info(extractor, dims)\n\u001b[0;32m---> 69\u001b[0m config \u001b[39m=\u001b[39m resolve_data_config({}, model\u001b[39m=\u001b[39mmodel)\n\u001b[1;32m 71\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39m\"\u001b[39m\u001b[39mmean\u001b[39m\u001b[39m\"\u001b[39m \u001b[39min\u001b[39;00m config:\n\u001b[1;32m 72\u001b[0m kwargs\u001b[39m.\u001b[39msetdefault(\u001b[39m\"\u001b[39m\u001b[39mmean\u001b[39m\u001b[39m\"\u001b[39m, config[\u001b[39m\"\u001b[39m\u001b[39mmean\u001b[39m\u001b[39m\"\u001b[39m])\n",
"\u001b[0;31mNameError\u001b[0m: name 'resolve_data_config' is not defined"
]
}
],
Expand Down Expand Up @@ -65,7 +50,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": null,
"metadata": {},
"outputs": [
{
Expand Down
1,378 changes: 621 additions & 757 deletions notebooks/multidvps.ipynb

Large diffs are not rendered by default.

13 changes: 12 additions & 1 deletion sources/unimodels/multidvps/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
import math
import typing as T
from calendar import c
import warnings

import einops
import torch
import torch.nn as nn
from tensordict import TensorDict
from typing_extensions import override
from unimodels.multidvps import logic, modules
Expand Down Expand Up @@ -38,6 +40,15 @@
__all__ = ["MultiDVPS"]


_M = T.TypeVar("_M", bound=nn.Module)

def _maybe_optimize_submodule(module: _M, **kwargs) -> _M:
    """
    Best-effort wrap of ``module`` in ``torch.compile``.

    If compilation fails for any reason, a warning is emitted and the module
    is returned unchanged, so callers always get a usable module back.

    Parameters
    ----------
    module
        The submodule to (attempt to) compile.
    **kwargs
        Forwarded verbatim to ``torch.compile``.

    Returns
    -------
    The compiled module on success, otherwise the original ``module``.
    """
    try:
        compiled = torch.compile(module, **kwargs)
    except Exception as err:
        # Compilation is an optimization only; never let a failure here
        # break model construction. Runtime string kept byte-identical.
        warnings.warn(f"Could not compile submodule {module.__class__.__name__}: {err}")
        return module
    # Cast: torch.compile returns an OptimizedModule wrapper, but it is
    # interchangeable with the original module type for callers.
    return T.cast(_M, compiled)

class MultiDVPS(up.model.ModelBase):
"""Depth-Aware Video Panoptic Segmentation model using dynamic convolutions."""

Expand Down Expand Up @@ -94,7 +105,7 @@ def __init__(
self.depth_fixed = {} if depth_fixed is None else depth_fixed

# Submodules
self.backbone = backbone
self.backbone = _maybe_optimize_submodule(backbone)
self.detector = detector
self.kernel_mapper = kernel_mapper
self.fusion_thing = fusion_thing
Expand Down
12 changes: 9 additions & 3 deletions sources/unipercept/data/io/_read.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
import functools
import torch
from torchvision.io import ImageReadMode
from torchvision.io import read_image as _read_image
from PIL import Image as pil_image
from torchvision.transforms.v2.functional import to_image, to_dtype
from typing_extensions import deprecated
from unicore import file_io

Expand All @@ -25,8 +26,13 @@ def read_image(path: str, *, mode=ImageReadMode.RGB) -> up.data.tensors.Image:

from unipercept.data.tensors import Image

img = _read_image(path, mode=mode)
return img.as_subclass(Image)
img = pil_image.open(path).convert("RGB")
img = to_image(img)
img = to_dtype(img, torch.float32, scale=True)

assert img.shape[0] == 3, f"Expected image to have 3 channels, got {img.shape[0]}!"

return img


@functools.lru_cache(maxsize=MAX_CACHE)
Expand Down
45 changes: 24 additions & 21 deletions sources/unipercept/data/sets/vistas.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,10 @@
unified_id= 255,
name="construction--structure--tunnel",
),
SClass(color= (220, 20, 60), SType.THING, dataset_id=19, unified_id= 11, name="human--person"),
SClass(color= (6, 0, 0), SType.THING, dataset_id=20, unified_id= 12, name="human--rider--bicyclist"),
SClass(color= (6, 0, 100), SType.THING, dataset_id=21, unified_id= 12, name="human--rider--motorcyclist"),
SClass(color= (6, 0, 200), SType.THING, dataset_id=22, unified_id= 12, name="human--rider--other-rider"),
SClass(color= (220, 20, 60), kind=SType.THING, dataset_id=19, unified_id= 11, name="human--person"),
SClass(color= (6, 0, 0), kind=SType.THING, dataset_id=20, unified_id= 12, name="human--rider--bicyclist"),
SClass(color= (6, 0, 100), kind=SType.THING, dataset_id=21, unified_id= 12, name="human--rider--motorcyclist"),
SClass(color= (6, 0, 200), kind=SType.THING, dataset_id=22, unified_id= 12, name="human--rider--other-rider"),
SClass(color= (200, 128, 128), kind=SType.STUFF, dataset_id=23, unified_id= 0, name="marking--crosswalk-zebra"),
SClass(color= (6, 6, 6), kind=SType.STUFF, dataset_id=24, unified_id= 0, name="marking--general"),
SClass(color= (64, 170, 64), kind=SType.STUFF, dataset_id=25, unified_id= 9, name="nature--mountain"),
Expand Down Expand Up @@ -121,19 +121,19 @@
SClass(color= (192, 192, 192), kind=SType.STUFF, dataset_id=49, unified_id= 255, name="object--traffic-sign--back"),
SClass(color= (220, 220, 0), kind=SType.STUFF, dataset_id=50, unified_id= 7, name="object--traffic-sign--front"),
SClass(color= (140, 140, 20), kind=SType.STUFF, dataset_id=51, unified_id= 255, name="object--trash-can"),
SClass(color= (119, 11, 32), SType.THING, dataset_id=52, unified_id= 18, name="object--vehicle--bicycle"),
SClass(color= (119, 11, 32), kind=SType.THING, dataset_id=52, unified_id= 18, name="object--vehicle--bicycle"),
SClass(color= (150, 0, 6), kind=SType.STUFF, dataset_id=53, unified_id= 255, name="object--vehicle--boat"),
SClass(color= (0, 60, 100), SType.THING, dataset_id=54, unified_id= 15, name="object--vehicle--bus"),
SClass(color= (0, 0, 142), SType.THING, dataset_id=55, unified_id= 13, name="object--vehicle--car"),
SClass(color= (0, 60, 100), kind=SType.THING, dataset_id=54, unified_id= 15, name="object--vehicle--bus"),
SClass(color= (0, 0, 142), kind=SType.THING, dataset_id=55, unified_id= 13, name="object--vehicle--car"),
SClass(
color= (0, 0, 90),
kind=SType.STUFF,
dataset_id=56,
unified_id= 255,
name="object--vehicle--caravan",
), # Tim: class met id: 29, staat niet in cityscapes config
SClass(color= (0, 0, 230), SType.THING, dataset_id=57, unified_id= 17, name="object--vehicle--motorcycle"),
SClass(color= (0, 80, 100), SType.THING, dataset_id=58, unified_id= 16, name="object--vehicle--on-rails"),
SClass(color= (0, 0, 230), kind=SType.THING, dataset_id=57, unified_id= 17, name="object--vehicle--motorcycle"),
SClass(color= (0, 80, 100), kind=SType.THING, dataset_id=58, unified_id= 16, name="object--vehicle--on-rails"),
SClass(color= (128, 64, 64), kind=SType.STUFF, dataset_id=59, unified_id= 255, name="object--vehicle--other-vehicle"),
SClass(
color= (0, 0, 110),
Expand All @@ -142,7 +142,7 @@
unified_id= 255,
name="object--vehicle--trailer",
), # Tim: class met id: 30, staat niet in cityscapes config
SClass(color= (0, 0, 70), SType.THING, dataset_id=61, unified_id= 14, name="object--vehicle--truck"),
SClass(color= (0, 0, 70), kind=SType.THING, dataset_id=61, unified_id= 14, name="object--vehicle--truck"),
SClass(color= (0, 0, 192), kind=SType.STUFF, dataset_id=62, unified_id= 255, name="object--vehicle--wheeled-slow"),
SClass(color= (32, 32, 32), kind=SType.STUFF, dataset_id=63, unified_id= 255, name="void--car-mount"),
SClass(
Expand Down Expand Up @@ -222,10 +222,10 @@
unified_id= 255,
name="construction--structure--tunnel",
),
SClass(color= (220, 20, 60), SType.THING, dataset_id=19, unified_id= 11, name="human--person"),
SClass(color= (6, 0, 0), SType.THING, dataset_id=20, unified_id= 12, name="human--rider--bicyclist"),
SClass(color= (6, 0, 100), SType.THING, dataset_id=21, unified_id= 12, name="human--rider--motorcyclist"),
SClass(color= (6, 0, 200), SType.THING, dataset_id=22, unified_id= 12, name="human--rider--other-rider"),
SClass(color= (220, 20, 60), kind=SType.THING, dataset_id=19, unified_id= 11, name="human--person"),
SClass(color= (6, 0, 0), kind=SType.THING, dataset_id=20, unified_id= 12, name="human--rider--bicyclist"),
SClass(color= (6, 0, 100), kind=SType.THING, dataset_id=21, unified_id= 12, name="human--rider--motorcyclist"),
SClass(color= (6, 0, 200), kind=SType.THING, dataset_id=22, unified_id= 12, name="human--rider--other-rider"),
SClass(color= (200, 128, 128), kind=SType.STUFF, dataset_id=23, unified_id= 0, name="marking--crosswalk-zebra"),
SClass(color= (6, 6, 6), kind=SType.STUFF, dataset_id=24, unified_id= 0, name="marking--general"),
SClass(color= (64, 170, 64), kind=SType.STUFF, dataset_id=25, unified_id= 9, name="nature--mountain"),
Expand Down Expand Up @@ -261,19 +261,19 @@
SClass(color= (192, 192, 192), kind=SType.STUFF, dataset_id=49, unified_id= 255, name="object--traffic-sign--back"),
SClass(color= (220, 220, 0), kind=SType.STUFF, dataset_id=50, unified_id= 7, name="object--traffic-sign--front"),
SClass(color= (140, 140, 20), kind=SType.STUFF, dataset_id=51, unified_id= 255, name="object--trash-can"),
SClass(color= (119, 11, 32), SType.THING, dataset_id=52, unified_id= 18, name="object--vehicle--bicycle"),
SClass(color= (119, 11, 32), kind=SType.THING, dataset_id=52, unified_id= 18, name="object--vehicle--bicycle"),
SClass(color= (150, 0, 6), kind=SType.STUFF, dataset_id=53, unified_id= 255, name="object--vehicle--boat"),
SClass(color= (0, 60, 100), SType.THING, dataset_id=54, unified_id= 15, name="object--vehicle--bus"),
SClass(color= (0, 0, 142), SType.THING, dataset_id=55, unified_id= 13, name="object--vehicle--car"),
SClass(color= (0, 60, 100), kind=SType.THING, dataset_id=54, unified_id= 15, name="object--vehicle--bus"),
SClass(color= (0, 0, 142), kind=SType.THING, dataset_id=55, unified_id= 13, name="object--vehicle--car"),
SClass(
color= (0, 0, 90),
kind=SType.STUFF,
dataset_id=56,
unified_id= 255,
name="object--vehicle--caravan",
), # Tim: class met id: 29, staat niet in cityscapes config
SClass(color= (0, 0, 230), SType.THING, dataset_id=57, unified_id= 17, name="object--vehicle--motorcycle"),
SClass(color= (0, 80, 100), SType.THING, dataset_id=58, unified_id= 16, name="object--vehicle--on-rails"),
SClass(color= (0, 0, 230), kind=SType.THING, dataset_id=57, unified_id= 17, name="object--vehicle--motorcycle"),
SClass(color= (0, 80, 100), kind=SType.THING, dataset_id=58, unified_id= 16, name="object--vehicle--on-rails"),
SClass(color= (128, 64, 64), kind=SType.STUFF, dataset_id=59, unified_id= 255, name="object--vehicle--other-vehicle"),
SClass(
color= (0, 0, 110),
Expand All @@ -282,7 +282,7 @@
unified_id= 255,
name="object--vehicle--trailer",
), # Tim: class met id: 30, staat niet in cityscapes config
SClass(color= (0, 0, 70), SType.THING, dataset_id=61, unified_id= 14, name="object--vehicle--truck"),
SClass(color= (0, 0, 70), kind=SType.THING, dataset_id=61, unified_id= 14, name="object--vehicle--truck"),
SClass(color= (0, 0, 192), kind=SType.STUFF, dataset_id=62, unified_id= 255, name="object--vehicle--wheeled-slow"),
SClass(color= (32, 32, 32), kind=SType.STUFF, dataset_id=63, unified_id= 255, name="void--car-mount"),
SClass(
Expand Down Expand Up @@ -323,9 +323,12 @@ def _build_manifest(self) -> Manifest:

sequences: T.Mapping[str, ManifestSequence] = {}

image_list = list((split_dir / "images").glob("*.png"))
image_list = list((split_dir / "images").glob("*.jpg"))
image_list.sort(key=lambda p: p.stem)

if len(image_list) == 0:
raise RuntimeError(f"No images found in {split_dir}")

for image_path in tqdm(image_list, desc="Building manifest", unit="image"):
cap_key = image_path.stem
seq_key = cap_key # NOTE: no sequences
Expand Down
40 changes: 6 additions & 34 deletions sources/unipercept/nn/backbones/_normalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,6 @@
import torch.nn as nn
from typing_extensions import override

import unipercept.data.tensors

from torch.types import Device

__all__ = ["Normalizer"]


Expand All @@ -20,47 +16,23 @@ class Normalizer(nn.Module):
Normalizes input captures
"""

def __init__(self, mean: list[float], std: list[float], image_format: str | None = None):
def __init__(self, mean: list[float], std: list[float]):
super().__init__()

self.image_format = image_format
self.register_buffer("mean", torch.tensor(mean).view(-1, 1, 1), False)
self.register_buffer("std", torch.tensor(std).view(-1, 1, 1), False)
self.register_buffer("mean", torch.tensor(mean).view(-1, 1, 1))
self.register_buffer("std", torch.tensor(std).view(-1, 1, 1))
assert self.mean.shape == self.std.shape, f"{self.mean} and {self.std} have different shapes!"

@property
def device(self) -> Device:
return self.mean.device

@override
def denormalize(self, image: torch.Tensor) -> torch.Tensor:
"""
Denormalize an image to a float with values [0, 1]
"""
image = image.to(device=self.device, dtype=torch.float32) * self.std + self.mean # type: ignore

if self.image_format == "BGR":
image = image[[2, 1, 0], :, :]
elif self.image_format != "RGB":
raise ValueError(f"Unknown image format: {self.image_format}")

image /= 255.0
image = image.clamp(0, 1)

return image

def normalize(self, image: torch.Tensor) -> torch.Tensor:
"""
Normalize an image.
"""
return (image.to(device=self.device) - self.mean) / self.std # type: ignore
return (image - self.mean) / self.std # type: ignore

@override
def forward(self, data: unipercept.model.InputData) -> unipercept.model.InputData:
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Copy data to device and normalize each captures input image.
"""

data.captures.images = self.normalize(data.captures.images).as_subclass(unipercept.data.tensors.Image)

return data
return self.normalize(x)
Loading

0 comments on commit d64c1ba

Please sign in to comment.