Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
kurt-stolle committed Mar 5, 2024
1 parent 87f4b5a commit f091a7b
Show file tree
Hide file tree
Showing 10 changed files with 96 additions and 30 deletions.
3 changes: 1 addition & 2 deletions sources/unipercept/_api_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,8 @@ def _read_model_wandb(path: str) -> str:
from unipercept import file_io

run = _wandb_read_run(path)
from wandb.sdk.wandb_run import Run

import wandb
from wandb.sdk.wandb_run import Run

assert path.startswith(WANDB_RUN_PREFIX)

Expand Down
5 changes: 4 additions & 1 deletion sources/unipercept/cli/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,10 @@ def _main(args):

try:
results = engine.run_evaluation(
model_factory, suites=args.suite if len(args.suite) > 0 else None
model_factory,
suites=args.suite
if args.suite is not None and len(args.suite) > 0
else None,
)
_logger.info(
"Evaluation results: \n%s", up.log.create_table(results, format="long")
Expand Down
8 changes: 4 additions & 4 deletions sources/unipercept/data/_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,8 @@ def with_training_defaults(cls, dataset: PerceptionDataset, **kwargs) -> T.Self:
"""Create a loader factory with default settings for inference mode."""
return cls(
dataset=dataset,
sampler=SamplerFactory(sampler="inference"),
config=DataLoaderConfig(drop_last=False),
sampler=SamplerFactory(sampler="training"),
config=DataLoaderConfig(drop_last=True),
**kwargs,
)

Expand All @@ -132,8 +132,8 @@ def with_inference_defaults(cls, dataset: PerceptionDataset, **kwargs) -> T.Self
"""Create a loader factory with default settings for training mode."""
return cls(
dataset=dataset,
sampler=SamplerFactory(sampler="training"),
config=DataLoaderConfig(drop_last=True),
sampler=SamplerFactory(sampler="inference"),
config=DataLoaderConfig(drop_last=False),
**kwargs,
)

Expand Down
2 changes: 1 addition & 1 deletion sources/unipercept/engine/_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import torch.optim
import torch.types
import torch.utils.data
import wandb
from omegaconf import DictConfig, OmegaConf
from PIL import Image as pil_image
from tabulate import tabulate
Expand All @@ -35,7 +36,6 @@
from torch.utils.data import Dataset
from typing_extensions import override

import wandb
from unipercept import file_io
from unipercept.data import DataLoaderFactory
from unipercept.engine._params import EngineParams, EvaluationSuite, TrainingStage
Expand Down
58 changes: 53 additions & 5 deletions sources/unipercept/engine/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
barrier,
check_main_process,
cpus_available,
gather,
gather_tensordict,
get_process_count,
main_process_first,
Expand All @@ -43,7 +44,12 @@
from unipercept.utils.tensorclass import Tensorclass
from unipercept.utils.typings import Pathable

__all__ = ["ResultsWriter", "PersistentTensordictWriter", "MemmapTensordictWriter"]
__all__ = [
"ResultsWriter",
"PersistentTensordictWriter",
"MemmapTensordictWriter",
"MemoryTensordictWriter",
]

_logger = get_logger(__name__)

Expand Down Expand Up @@ -120,10 +126,17 @@ def __getitem__(self, index: int | slice | tuple[int, int]) -> Tensor:
loc[1] += self._index[0]

tensors: list[Tensor | None] = [None] * (loc[1] - loc[0])
with concurrent.futures.ThreadPoolExecutor() as pool:
for i, t in pool.map(lambda i: (i, self._load_at(i)), range(*loc)):
with concurrent.futures.ThreadPoolExecutor(max_workers=5) as pool:
futs = []
for i in range(*loc):
futs.append(pool.submit(self._load_at, i))
for fut in futs:
i, t = fut.result()
tensors[i] = t

# for i, t in pool.map(lambda i: (i, self._load_at(i)), range(*loc)):
# tensors[i] = t

if any(t is None for t in tensors):
raise RuntimeError("Some tensors were not loaded")
return torch.stack(T.cast(list[torch.Tensor], tensors))
Expand Down Expand Up @@ -265,11 +278,11 @@ def items(self) -> T.Iterator[T.Tuple[str, T.Any]]:

@classmethod
@TX.override
def from_module(cls, *args, **kwargs):
def _from_module(cls, *args, **kwargs):
raise NotImplementedError(f"{cls.__name__} does not support from_module")

@TX.override
def to_module(self, *args, **kwargs):
def _to_module(self, *args, **kwargs):
raise NotImplementedError(
f"{self.__class__.__name__} does not support to_module"
)
Expand Down Expand Up @@ -754,6 +767,41 @@ def tensordict(self) -> TensorDictBase:
return self._td


class MemoryTensordictWriter(ResultsWriter):
"""
Writer that stores results in a large list of TensorDicts, which is not memory
efficient but very fast to write to and read from.
When flushed, the results are synchronized to the main process.
The resulting tensordict is a LazyStackedTensorDict of the list of TensorDicts.
"""

def __init__(self):
self._results: list[TensorDictBase] = []

@TX.override
def add(self, data: TensorDictBase) -> None:
self._results.append(data.cpu())

@TX.override
def flush(self) -> None:
if self._is_closed:
msg = f"{self.__class__.__name__} is closed"
raise RuntimeError(msg)
# TODO
raise NotImplementedError()

@TX.override
def close(self) -> None:
self._is_closed = True

@property
@TX.override
def tensordict(self) -> TensorDictBase:
if not self._is_closed:
raise RuntimeError("ResultsWriter is not closed")
return LazyStackedTensorDict(self._results)


def _find_memmap_indices(path: Pathable) -> T.Tuple[int, int]:
"""
Find the indices of the first and last memory-mapped file in a directory.
Expand Down
43 changes: 29 additions & 14 deletions sources/unipercept/evaluators/_panoptic.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,26 +265,41 @@ def compute_pq(
allow_stuff_instances=allow_stuff_instances,
num_categories=num_categories,
)

_logger.debug("Creating MP context")
mp_context = M.get_context("spawn" if device.type != "cpu" else None)
progress_bar = tqdm(
desc="Computing panoptic metrics",
dynamic_ncols=True,
total=sample_amt,
disable=not check_main_process(local=True) or not self.show_progress,
)
# mp_context = M.get_context("spawn" if device.type != "cpu" else None)
# with concurrent.futures.ProcessPoolExecutor(
# min(cpus_available(), M.cpu_count() // 2, 32), mp_context=mp_context
# ) as pool:
with concurrent.futures.ThreadPoolExecutor() as pool:
for result in pool.map(compute_at, range(sample_amt)):
progress_bar.update(1)
if result is None:
continue
iou += result[0]
tp += result[1]
fp += result[2]
fn += result[3]
progress_bar.close()
try:
# progress_bar.set_postfix_str("init")
# with concurrent.futures.ProcessPoolExecutor(
# min(cpus_available(), 16), mp_context=mp_context
# ) as pool:
with concurrent.futures.ThreadPoolExecutor() as pool:
# if True:
# indices = list(range(sample_amt))
# for result in pool.map(compute_at, indices):
progress_bar.set_postfix_str("dispatch")
futs = []
for n in range(sample_amt):
futs.append(pool.submit(compute_at, n))
progress_bar.set_postfix_str("update")
for fut in futs:
result = fut.result()
# for result in map(compute_at, indices):
progress_bar.update(1)
if result is None:
continue
iou += result[0]
tp += result[1]
fp += result[2]
fn += result[3]
finally:
progress_bar.close()
# Compute PQ = SQ * RQ
sq = H.stable_divide(iou, tp)
rq = H.stable_divide(tp, tp + 0.5 * fp + 0.5 * fn)
Expand Down
2 changes: 1 addition & 1 deletion sources/unipercept/integrations/wandb_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@

import torch.nn as nn
import typing_extensions as TX
import wandb
import wandb.errors

import wandb
from unipercept import file_io
from unipercept.config import get_env
from unipercept.engine import EngineParams
Expand Down
1 change: 1 addition & 0 deletions sources/unipercept/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import accelerate.utils
import torch
import torch.distributed
import torch.types
import torch.utils.data
from tensordict import TensorDict, TensorDictBase
Expand Down
2 changes: 1 addition & 1 deletion tests/unipercept/engine/test_engine_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def test_memmap_writer(benchmark, tmp_path: Path):
@benchmark
def write_then_read():
path = tmp_path / "memmap_writer"
writer = MemmapTensordictWriter(path, NUM_SAMPLES)
writer = MemmapTensordictWriter(path, NUM_SAMPLES, write_offset=0)
_run_write_read_benchmark(writer)
shutil.rmtree(path)

Expand Down
2 changes: 1 addition & 1 deletion tests/unipercept/utils/test_utils_iopath_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from pathlib import Path

import pytest

import wandb

from unipercept import file_io
from unipercept.utils.iopath_handlers import WebDAVPathHandler

Expand Down

0 comments on commit f091a7b

Please sign in to comment.