Skip to content

Commit

Permalink
Amend: fix video pipe
Browse files Browse the repository at this point in the history
  • Loading branch information
kurt-stolle committed Feb 26, 2024
1 parent 7cfe5ad commit e208aca
Show file tree
Hide file tree
Showing 7 changed files with 83 additions and 226 deletions.
215 changes: 0 additions & 215 deletions notebooks/evaluation.ipynb

This file was deleted.

5 changes: 5 additions & 0 deletions sources/unipercept/_api_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,11 @@
"prepare_images",
]


def __dir__() -> list[str]:
return __all__


_logger = get_logger(__name__)


Expand Down
24 changes: 14 additions & 10 deletions sources/unipercept/_api_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,67 +12,71 @@
__all__ = ["get_dataset", "get_info", "get_info_at", "list_datasets", "list_info"]


def __dir__() -> list[str]:
return __all__


@T.overload
def get_dataset( # noqa: D103
name: T.Literal["cityscapes"],
query: T.Literal["cityscapes"],
) -> type[unisets.cityscapes.CityscapesDataset]:
...


@T.overload
def get_dataset( # noqa: D103
name: T.Literal["cityscapes-vps"],
query: T.Literal["cityscapes-vps"],
) -> type[unisets.cityscapes.CityscapesVPSDataset]:
...


@T.overload
def get_dataset( # noqa: D103
name: T.Literal["kitti-360"],
query: T.Literal["kitti-360"],
) -> type[unisets.kitti_360.KITTI360Dataset]:
...


@T.overload
def get_dataset( # noqa: D103
name: T.Literal["kitti-step"],
query: T.Literal["kitti-step"],
) -> type[unisets.kitti_step.KITTISTEPDataset]:
...


@T.overload
def get_dataset( # noqa: D103
name: T.Literal["kitti-sem"],
query: T.Literal["kitti-sem"],
) -> type[unisets.kitti_sem.SemKITTIDataset]:
...


@T.overload
def get_dataset( # noqa: D103
name: T.Literal["vistas"],
query: T.Literal["vistas"],
) -> type[unisets.vistas.VistasDataset]:
...


@T.overload
def get_dataset( # noqa: D103
name: T.Literal["wilddash"],
query: T.Literal["wilddash"],
) -> type[unisets.wilddash.WildDashDataset]:
...


@T.overload
def get_dataset(name: str) -> type[unisets.PerceptionDataset]: # noqa: D103
def get_dataset(query: str) -> type[unisets.PerceptionDataset]: # noqa: D103
...


def get_dataset(name: str) -> type[unisets.PerceptionDataset]:
def get_dataset(query: str) -> type[unisets.PerceptionDataset]:
"""
Read a dataset from the catalog, returning the dataset **class** type.
"""
from unipercept.data.sets import catalog

return catalog.get_dataset(name)
return catalog.get_dataset(query)


def get_info(query: str) -> unisets.Metadata:
Expand Down
Empty file removed sources/unipercept/api/__init__.py
Empty file.
1 change: 0 additions & 1 deletion sources/unipercept/data/_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import math
import multiprocessing as M
import operator
import os
import typing as T
import warnings

Expand Down
1 change: 1 addition & 0 deletions sources/unipercept/render/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@
from ._colormap import *
from ._plot import *
from ._visualizer import *
from ._video import *
63 changes: 63 additions & 0 deletions sources/unipercept/render/_video.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from contextlib import contextmanager
import dataclasses as D
import functools
import os
import sys

import PIL.Image as pil_image

import typing as T
import tempfile

from unipercept.utils.typings import Pathable

__all__ = ["video_writer"]


@contextmanager
def video_writer(out: Pathable, *, fps: int, overwrite: bool = False):
"""
Used for writing a sequence of PIL images to a (temporary) directory, and then
encoding them into a video file using ``ffmpeg`` commands.
"""

from unipercept.file_io import Path

def _parse_output_path(out: Pathable) -> str:
out = Path(out)
if out.is_file():
if not overwrite:
msg = f"File {out!r} already exists, and overwrite is set to False."
raise FileExistsError(msg)
out.unlink()
else:
out.parent.mkdir(parents=True, exist_ok=True)
return str(out)

def _get_ffmpeg_path() -> str:
return "ffmpeg.exe" if sys.platform == "win32" else "ffmpeg"

def _get_ffmpeg_cmd(fps: int, dir: str, out: str) -> tuple[str, ...]:
frame_glob = os.path.join(dir, "*.png")
return (
_get_ffmpeg_path(),
f"-framerate {fps}",
"-pattern_type glob",
f"-i {frame_glob!r}",
"-c:v libx264",
"-pix_fmt yuv420p",
f"{out!r}",
)

def _save_image(im: pil_image.Image, *, dir: str):
next_frame = len(os.listdir(dir))
im.save(os.path.join(dir, f"{next_frame:010d}.png"))

with tempfile.TemporaryDirectory() as dir:
try:
yield functools.partial(_save_image, dir=dir)
finally:
cmd = " ".join(_get_ffmpeg_cmd(fps, dir, out=_parse_output_path(out)))

print(cmd, file=sys.stderr)
os.system(cmd)

0 comments on commit e208aca

Please sign in to comment.