Skip to content

Commit

Permalink
Pre-merge
Browse files Browse the repository at this point in the history
  • Loading branch information
kurt-stolle committed Mar 5, 2024
1 parent 8f9ca4d commit 87f4b5a
Show file tree
Hide file tree
Showing 18 changed files with 368 additions and 140 deletions.
3 changes: 2 additions & 1 deletion sources/unipercept/_api_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,10 @@ def _read_model_wandb(path: str) -> str:
from unipercept import file_io

run = _wandb_read_run(path)
import wandb
from wandb.sdk.wandb_run import Run

import wandb

assert path.startswith(WANDB_RUN_PREFIX)

_logger.info("Reading W&B model checkpoint from %s", path)
Expand Down
1 change: 0 additions & 1 deletion sources/unipercept/_api_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ def get_dataset(
...

@T.overload
@TX.deprecated("Use 'name' instead of 'query'.")
def get_dataset(
query: None, # noqa: U100
**kwargs: T.Any, # noqa: U100
Expand Down
20 changes: 14 additions & 6 deletions sources/unipercept/cli/evaluate.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Training and evaluation entry point."""
"""Evaluation entry point."""
from __future__ import annotations

import argparse
Expand Down Expand Up @@ -35,12 +35,17 @@ def evaluate(p: argparse.ArgumentParser):
type=str,
help="path to load model weights from, if not are inferred from the configuration path",
)

p.add_argument(
"suites",
nargs="*",
"--path",
"-o",
type=up.file_io.Path,
help="path to store outputs from evaluation",
)
p.add_argument(
"--suite",
nargs="+",
type=str,
help="evaluation suites to run (default: all available)",
help="evaluation suite to run",
)

return _main
Expand Down Expand Up @@ -68,8 +73,11 @@ def _main(args):

engine = up.create_engine(lazy_config)
model_factory = up.create_model_factory(lazy_config, weights=args.weights or None)

try:
results = engine.run_evaluation(model_factory)
results = engine.run_evaluation(
model_factory, suites=args.suite if len(args.suite) > 0 else None
)
_logger.info(
"Evaluation results: \n%s", up.log.create_table(results, format="long")
)
Expand Down
22 changes: 21 additions & 1 deletion sources/unipercept/data/_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,8 @@ class DataLoaderFactory:
"""

dataset: PerceptionDataset
actions: T.Sequence[Op]
sampler: SamplerFactory
actions: T.Sequence[Op] = D.field(default_factory=list)
config: DataLoaderConfig = D.field(default_factory=DataLoaderConfig)
iterable: bool = D.field(
default=False,
Expand All @@ -117,6 +117,26 @@ class DataLoaderFactory:
},
)

@classmethod
def with_training_defaults(cls, dataset: PerceptionDataset, **kwargs) -> T.Self:
"""Create a loader factory with default settings for inference mode."""
return cls(
dataset=dataset,
sampler=SamplerFactory(sampler="inference"),
config=DataLoaderConfig(drop_last=False),
**kwargs,
)

@classmethod
def with_inference_defaults(cls, dataset: PerceptionDataset, **kwargs) -> T.Self:
"""Create a loader factory with default settings for training mode."""
return cls(
dataset=dataset,
sampler=SamplerFactory(sampler="training"),
config=DataLoaderConfig(drop_last=True),
**kwargs,
)

def __call__(
self, batch_size: int | None = None, /, use_distributed: bool = True
) -> DataLoader:
Expand Down
22 changes: 13 additions & 9 deletions sources/unipercept/data/sets/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,19 @@

from __future__ import annotations

import abc
import dataclasses
import dataclasses as D
import enum as E
import functools
import typing as T

import torch
import torch.utils.data
import typing_extensions as TX

from unipercept.data.tensors import PanopticMap
from unipercept.data.types import Manifest, QueueItem
from unipercept.utils.camera import build_calibration_matrix
from unipercept.utils.catalog import DataManager
from unipercept.utils.dataset import Dataset as _BaseDataset
Expand All @@ -20,11 +23,11 @@
from unipercept.utils.tensorclass import Tensorclass

if T.TYPE_CHECKING:
from unipercept.data.collect import ExtractIndividualFrames
from unipercept.model import CaptureData
from unipercept.data.collect import QueueGeneratorType
import unipercept
from unipercept.data.collect import ExtractIndividualFrames, QueueGeneratorType
from unipercept.data.types.coco import COCOCategory
from unipercept.model import CaptureData, ModelOutput

from unipercept.data.types import COCOCategory, Manifest, QueueItem

__all__ = [
"PerceptionDataset",
Expand Down Expand Up @@ -527,7 +530,8 @@ class PerceptionDataset(
"""Baseclass for datasets that are composed of captures and motions."""

queue_fn: T.Callable[[Manifest], QueueGeneratorType] = dataclasses.field(
default_factory=_individual_frames_queue
default_factory=_individual_frames_queue,
metadata={"help": "Queue generator", "locate": True},
)

@TX.override
Expand Down Expand Up @@ -589,17 +593,17 @@ def _load_capture_data(

@classmethod
def _load_motion_data(
cls, sources: T.Sequence[up.data.types.MotionSources], info: Metadata
) -> up.model.MotionData:
cls, sources: T.Sequence[unipercept.data.types.MotionSources], info: Metadata
) -> unipercept.model.MotionData:
raise NotImplementedError(f"{cls.__name__} does not implement motion sources!")

_data_cache: T.ClassVar[dict[str, up.model.InputData]] = {}
_data_cache: T.ClassVar[dict[str, unipercept.model.InputData]] = {}

@classmethod
@TX.override
def _load_data(
cls, key: str, item: QueueItem, info: Metadata
) -> up.model.InputData:
) -> unipercept.model.InputData:
from unipercept.model import CameraModel, InputData

# Check for cache hit, should be a memmaped tensor
Expand Down
32 changes: 27 additions & 5 deletions sources/unipercept/data/sets/cityscapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,18 @@
from __future__ import annotations

import dataclasses as D
import functools
import operator
import re
import typing as T
from dataclasses import dataclass
from datetime import datetime
from importlib import metadata
from typing import Iterable, Literal, Mapping

import torch.utils.data
import typing_extensions as TX
from tensordict import TensorDictBase
from typing_extensions import override

from unipercept import file_io
Expand All @@ -23,10 +28,14 @@
SType,
create_metadata,
)
from unipercept.evaluators import Evaluator
from unipercept.model import ModelOutput
from unipercept.utils.formatter import formatter

if T.TYPE_CHECKING:
import unipercept as up
from unipercept.data import DataLoaderFactory
from unipercept.data.types.coco import COCOResultPanoptic

__all__ = ["CityscapesDataset", "CityscapesVPSDataset"]

Expand Down Expand Up @@ -64,6 +73,16 @@ def attach_id(cls, path: str) -> tuple[T.Self, str]:
path,
)

@property
def primary_key(self) -> str:
"""Return a canonical primary key for the file.
Returns
-------
Primary key, e.g. "berlin_000123_000019"
"""
return f"{self.city}_{self.drive}_{self.frame}"

def __lt__(self, other: FileID) -> bool:
return D.astuple(self) < D.astuple(other)

Expand All @@ -77,8 +96,8 @@ def __ge__(self, other: FileID) -> bool:
return D.astuple(self) >= D.astuple(other)


def get_primary_key(seq_key: str, idx: int) -> str:
return f"{seq_key}_{idx:06d}"
# def get_primary_key(seq_key: str, idx: int) -> str:
# return f"{seq_key}_{idx:06d}"


def get_sequence_key(seq_idx: int) -> str:
Expand Down Expand Up @@ -289,6 +308,7 @@ def to_canonical(self) -> up.data.types.PinholeModelParameters:
]


@functools.lru_cache()
def get_info():
return create_metadata(
CLASSES,
Expand All @@ -309,8 +329,10 @@ class CityscapesDataset(PerceptionDataset, info=get_info, id="cityscapes"):
Link: https://www.cityscapes-dataset.com/
"""

split: Literal["train", "val", "test"]
root: str = "//datasets/cityscapes"
split: Literal["train", "val", "test"] = D.field(metadata={"help": "Dataset split"})
root: str = D.field(
default="//datasets/cityscapes", metadata={"help": "Root directory"}
)

path_image = formatter("{self.root}/leftImg8bit/{self.split}")
path_panoptic = formatter("{self.root}/gtFine/cityscapes_panoptic_{self.split}")
Expand Down Expand Up @@ -437,7 +459,7 @@ def _build_manifest(self) -> up.data.types.Manifest:
camera = CAMERA.to_canonical() # TODO: read from json
captures: list[up.data.types.CaptureRecord] = [
{
"primary_key": get_primary_key(seq_key, i),
"primary_key": id.primary_key, # get_primary_key(seq_key, i),
"sources": sources_map[id],
}
for i, id in enumerate(ids)
Expand Down
25 changes: 24 additions & 1 deletion sources/unipercept/data/tensors/_panoptic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@
import typing as T
from enum import StrEnum, auto

import PIL.Image as pil_image
import safetensors.torch as safetensors
import torch
from torchvision.tv_tensors import Mask as _Mask
from typing_extensions import override

from unipercept import file_io
from unipercept.data.tensors.helpers import write_png
from unipercept.data.types.coco import COCOResultPanoptic, COCOResultPanopticSegment
from unipercept.utils.typings import Pathable

from .registry import pixel_maps
Expand Down Expand Up @@ -373,7 +375,9 @@ def remove_instances_(self, semantic_list: T.Iterable[int]) -> None:
for class_ in semantic_list:
self[can_map == class_] = class_ * self.DIVISOR

def translate_semantic_(self, translation: dict[int, int]) -> None:
def translate_semantic_(
self, translation: dict[int, int], inverse: bool = False
) -> None:
"""
Apply a translation to the class labels. The translation is a dictionary mapping old class IDs to
new class IDs. All old class IDs that are not in the dictionary are mapped to ``ignore_label``.
Expand All @@ -388,13 +392,32 @@ def translate_semantic_(self, translation: dict[int, int]) -> None:
old_id,
new_id,
) in translation.items():
if inverse:
old_id, new_id = new_id, old_id

mask = sem_map == old_id
self[mask] = new_id * self.DIVISOR + ins_map[mask]

def get_nonempty(self) -> _Mask:
"""Return a new instance with only the non-empty pixels."""
return self[self != self.IGNORE * self.DIVISOR].as_subclass(_Mask)

def to_coco(self) -> tuple[pil_image.Image, list[COCOResultPanopticSegment]]:
segm = torch.zeros_like(self, dtype=torch.int32)

segments_info = []

for i, (sem_id, ins_id, mask) in enumerate(self.get_masks()):
coco_id = i + 1
segm[mask] = coco_id
segments_info.append(
COCOResultPanopticSegment(id=coco_id, category_id=sem_id)
)

img = pil_image.fromarray(segm.numpy().astype("uint8"), mode="L")

return img, segments_info


# def transform_label_map(label_map: torch.Tensor, transform: Transform) -> PanopticMap:
# map_uint8 = np.zeros((*label_map.shape, 3), dtype=np.uint8)
Expand Down
3 changes: 1 addition & 2 deletions sources/unipercept/data/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

import typing_extensions as TX

from . import sanity
from ._coco import *
from . import coco, sanity
from ._info import *
from ._manifest import *
48 changes: 0 additions & 48 deletions sources/unipercept/data/types/_coco.py

This file was deleted.

Loading

0 comments on commit 87f4b5a

Please sign in to comment.