From 832a93b4bbd1ee71f90a5369f166f28ec488223e Mon Sep 17 00:00:00 2001 From: Kurt Stolle Date: Wed, 13 Dec 2023 10:37:08 +0100 Subject: [PATCH] Autoformatting --- sources/unipercept/engine/_engine.py | 5 +++-- sources/unipercept/engine/writer.py | 13 ++++++++----- sources/unipercept/nn/layers/merge.py | 9 ++++----- sources/unipercept/render/_plot.py | 3 +-- 4 files changed, 16 insertions(+), 14 deletions(-) diff --git a/sources/unipercept/engine/_engine.py b/sources/unipercept/engine/_engine.py index e971654..00d29ee 100644 --- a/sources/unipercept/engine/_engine.py +++ b/sources/unipercept/engine/_engine.py @@ -873,7 +873,9 @@ def _inference_loop( # Output memory if results_path is None: results_remove_on_exit = True - results_path = file_io.Path(f"//scratch/{self._params.project_name}/{self._params.session_name}/{prefix}-results.h5") + results_path = file_io.Path( + f"//scratch/{self._params.project_name}/{self._params.session_name}/{prefix}-results.h5" + ) else: results_remove_on_exit = False results_mem = PersistentTensordictWriter(str(results_path), samples_total) @@ -971,7 +973,6 @@ def _inference_loop( if not key.startswith(f"{prefix}/"): metrics[f"{prefix}/{key}"] = metrics.pop(key) - return metrics, samples_processed def _store_visualizations(self, visuals: dict[str, pil_image.Image], prefix: str = "") -> None: diff --git a/sources/unipercept/engine/writer.py b/sources/unipercept/engine/writer.py index 5b54d7c..6d66422 100644 --- a/sources/unipercept/engine/writer.py +++ b/sources/unipercept/engine/writer.py @@ -55,7 +55,6 @@ def tensordict(self) -> TensorDictBase: raise NotImplementedError("Abstract property `tensordict` not implemented.") - class PersistentTensordictWriter(ResultsWriter): """ Writes results to a H5 file using PersistentTensorDict from multiple processes, uses a buffer to reduce the number of writes. @@ -90,17 +89,22 @@ def __del__(self): self.close() @staticmethod - def error_when_closed(fn: T.Callable[T.Concatenate[PersistentTensordictWriter, _P], _R]) -> T.Callable[T.Concatenate[PersistentTensordictWriter, _P], _R]: + def error_when_closed( + fn: T.Callable[T.Concatenate[PersistentTensordictWriter, _P], _R] + ) -> T.Callable[T.Concatenate[PersistentTensordictWriter, _P], _R]: @functools.wraps(fn) def wrapper(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: if self._is_closed: raise RuntimeError(f"{self.__class__.__name__} is closed") return fn(self, *args, **kwargs) + return wrapper - @functools.cached_property + @functools.cached_property def path(self) -> file_io.Path: - return file_io.Path(self._path) + p = file_io.Path(self._path) + p.parent.mkdir(parents=True, exist_ok=True) + return p def __len__(self) -> int: return self._size @@ -175,4 +179,3 @@ def tensordict(self) -> TensorDictBase: if self._td is None: self._td = self._td_factory(filename=self.path) return self._td - diff --git a/sources/unipercept/nn/layers/merge.py b/sources/unipercept/nn/layers/merge.py index 313c80e..f558b3f 100644 --- a/sources/unipercept/nn/layers/merge.py +++ b/sources/unipercept/nn/layers/merge.py @@ -34,23 +34,22 @@ class SemanticMerge(nn.Module): def __init__( self, - in_features: T.Iterable[str], - input_shape: T.Mapping[str, BackboneFeatureInfo], + in_features: T.Mapping[str, BackboneFeatureInfo], common_stride: int, out_channels: int, weight_method: WeightMethod | str = WeightMethod.FAST_ATTENTION, ): super().__init__() - self.in_features = list(in_features) + self.in_features = list(in_features.keys()) self.common_stride = int(common_stride) if isinstance(weight_method, str): weight_method = WeightMethod(weight_method) self.weight_method = weight_method - feature_strides = {k: T.cast(int, v.stride) for k, v in input_shape.items()} - feature_channels = {k: T.cast(int, v.channels) for k, v in input_shape.items()} + feature_strides = {k: T.cast(int, v.stride) for k, v in in_features.items()} + feature_channels = {k: T.cast(int, v.channels) for k, v in in_features.items()} self.scale_heads = nn.ModuleList() for in_feature in self.in_features: diff --git a/sources/unipercept/render/_plot.py b/sources/unipercept/render/_plot.py index 0ba05b8..ae1fd9b 100644 --- a/sources/unipercept/render/_plot.py +++ b/sources/unipercept/render/_plot.py @@ -24,8 +24,7 @@ from unipercept.data.tensors import PanopticMap from unipercept.model import InputData, ModelOutput -__all__ = ["plot_input_data", "plot_predictions", - "draw_image", "draw_image_segmentation", "draw_image_depth"] +__all__ = ["plot_input_data", "plot_predictions", "draw_image", "draw_image_segmentation", "draw_image_depth"] from unicore.utils.missing import MissingValue