Skip to content

Commit

Permalink
Autoformatting
Browse files Browse the repository at this point in the history
  • Loading branch information
kurt-stolle committed Dec 13, 2023
1 parent 816fe99 commit 832a93b
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 14 deletions.
5 changes: 3 additions & 2 deletions sources/unipercept/engine/_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -873,7 +873,9 @@ def _inference_loop(
# Output memory
if results_path is None:
results_remove_on_exit = True
results_path = file_io.Path(f"//scratch/{self._params.project_name}/{self._params.session_name}/{prefix}-results.h5")
results_path = file_io.Path(
f"//scratch/{self._params.project_name}/{self._params.session_name}/{prefix}-results.h5"
)
else:
results_remove_on_exit = False
results_mem = PersistentTensordictWriter(str(results_path), samples_total)
Expand Down Expand Up @@ -971,7 +973,6 @@ def _inference_loop(
if not key.startswith(f"{prefix}/"):
metrics[f"{prefix}/{key}"] = metrics.pop(key)


return metrics, samples_processed

def _store_visualizations(self, visuals: dict[str, pil_image.Image], prefix: str = "") -> None:
Expand Down
13 changes: 8 additions & 5 deletions sources/unipercept/engine/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ def tensordict(self) -> TensorDictBase:
raise NotImplementedError("Abstract property `tensordict` not implemented.")



class PersistentTensordictWriter(ResultsWriter):
"""
Writes results to a H5 file using PersistentTensorDict from multiple processes, uses a buffer to reduce the number of writes.
Expand Down Expand Up @@ -90,17 +89,22 @@ def __del__(self):
self.close()

@staticmethod
def error_when_closed(fn: T.Callable[T.Concatenate[PersistentTensordictWriter, _P], _R]) -> T.Callable[T.Concatenate[PersistentTensordictWriter, _P], _R]:
def error_when_closed(
fn: T.Callable[T.Concatenate[PersistentTensordictWriter, _P], _R]
) -> T.Callable[T.Concatenate[PersistentTensordictWriter, _P], _R]:
@functools.wraps(fn)
def wrapper(self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
if self._is_closed:
raise RuntimeError(f"{self.__class__.__name__} is closed")
return fn(self, *args, **kwargs)

return wrapper

@functools.cached_property
@functools.cached_property
def path(self) -> file_io.Path:
return file_io.Path(self._path)
p = file_io.Path(self._path)
p.parent.mkdir(parents=True, exist_ok=True)
return p

def __len__(self) -> int:
return self._size
Expand Down Expand Up @@ -175,4 +179,3 @@ def tensordict(self) -> TensorDictBase:
if self._td is None:
self._td = self._td_factory(filename=self.path)
return self._td

9 changes: 4 additions & 5 deletions sources/unipercept/nn/layers/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,23 +34,22 @@ class SemanticMerge(nn.Module):

def __init__(
self,
in_features: T.Iterable[str],
input_shape: T.Mapping[str, BackboneFeatureInfo],
in_features: T.Mapping[str, BackboneFeatureInfo],
common_stride: int,
out_channels: int,
weight_method: WeightMethod | str = WeightMethod.FAST_ATTENTION,
):
super().__init__()

self.in_features = list(in_features)
self.in_features = list(in_features.keys())
self.common_stride = int(common_stride)

if isinstance(weight_method, str):
weight_method = WeightMethod(weight_method)
self.weight_method = weight_method

feature_strides = {k: T.cast(int, v.stride) for k, v in input_shape.items()}
feature_channels = {k: T.cast(int, v.channels) for k, v in input_shape.items()}
feature_strides = {k: T.cast(int, v.stride) for k, v in in_features.items()}
feature_channels = {k: T.cast(int, v.channels) for k, v in in_features.items()}

self.scale_heads = nn.ModuleList()
for in_feature in self.in_features:
Expand Down
3 changes: 1 addition & 2 deletions sources/unipercept/render/_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@
from unipercept.data.tensors import PanopticMap
from unipercept.model import InputData, ModelOutput

__all__ = ["plot_input_data", "plot_predictions",
"draw_image", "draw_image_segmentation", "draw_image_depth"]
__all__ = ["plot_input_data", "plot_predictions", "draw_image", "draw_image_segmentation", "draw_image_depth"]

from unicore.utils.missing import MissingValue

Expand Down

0 comments on commit 832a93b

Please sign in to comment.