Skip to content

Commit

Permalink
✨ Add small version of MUAD dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
alafage committed Dec 10, 2024
1 parent 2586e37 commit d13e0c6
Showing 1 changed file with 42 additions and 14 deletions.
56 changes: 42 additions & 14 deletions torch_uncertainty/datasets/muad.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from pathlib import Path
from typing import Literal

from huggingface_hub import hf_hub_download

if util.find_spec("cv2"):
import cv2

Expand Down Expand Up @@ -38,17 +40,30 @@ class MUAD(VisionDataset):
"val": "957af9c1c36f0a85c33279e06b6cf8d8",
"val_depth": "0282030d281aeffee3335f713ba12373",
}

small_muad_url = "fira7s/Muad2"

_num_samples = {
"train": 3420,
"val": 492,
"test": ...,
"full": {
"train": 3420,
"val": 492,
"test": ...,
},
"small": {
"train": 400,
"val": 54,
"test": 112,
"ood": 20,
},
}

targets: list[Path] = []

def __init__(
self,
root: str | Path,
split: Literal["train", "val"],
split: Literal["train", "val", "test", "ood"],
version: Literal["small", "full"] = "full",
min_depth: float | None = None,
max_depth: float | None = None,
target_type: Literal["semantic", "depth"] = "semantic",
Expand All @@ -61,6 +76,8 @@ def __init__(
root (str): Root directory of dataset where directory 'leftImg8bit'
and 'leftLabel' or 'leftDepth' are located.
split (str, optional): The image split to use, 'train' or 'val'.
version (str, optional): The version of the dataset to use, 'small'
or 'full'. Defaults to 'full'.
min_depth (float, optional): The maximum depth value to use if
target_type is 'depth'. Defaults to None.
max_depth (float, optional): The maximum depth value to use if
Expand All @@ -86,20 +103,25 @@ def __init__(
"torch_uncertainty with the image option:"
"""pip install -U "torch_uncertainty[image]"."""
)

if version == "small" and target_type == "depth":
raise ValueError("Depth target is not available for the small version of MUAD.")

logging.info(
"MUAD is restricted to non-commercial use. By using MUAD, you "
"agree to the terms and conditions."
)
super().__init__(
root=Path(root) / "MUAD",
transforms=transforms,
)

dataset_root = Path(root) / "MUAD" if version == "full" else Path(root) / "MUAD_small"

super().__init__(dataset_root, transforms=transforms)
self.min_depth = min_depth
self.max_depth = max_depth

if split not in ["train", "val"]:
if split not in ["train", "val", "test", "ood"]:
raise ValueError(f"split must be one of ['train', 'val']. Got {split}.")
self.split = split
self.version = version
self.target_type = target_type

if not self.check_split_integrity("leftImg8bit"):
Expand Down Expand Up @@ -211,13 +233,12 @@ def __getitem__(self, index: int) -> tuple[tv_tensors.Image, tv_tensors.Mask]:
def check_split_integrity(self, folder: str) -> bool:
split_path = self.root / self.split
return (
split_path.is_dir()
and len(list((split_path / folder).glob("**/*"))) == self._num_samples[self.split]
split_path.is_dir() and len(list((split_path / folder).glob("**/*"))) == self.__len__()
)

def __len__(self) -> int:
"""The number of samples in the dataset."""
return self._num_samples[self.split]
return self._num_samples[self.version][self.split]

def _make_dataset(self, path: Path) -> None:
"""Create a list of samples and targets.
Expand All @@ -241,8 +262,15 @@ def _make_dataset(self, path: Path) -> None:

def _download(self, split: str) -> None:
"""Download and extract the chosen split of the dataset."""
split_url = self.base_url + split + ".zip"
download_and_extract_archive(split_url, self.root, md5=self.zip_md5[split])
if self.version == "small":
filename = f"{split}.zip"
downloaded_file = hf_hub_download(
repo_id=self.small_muad_url, filename=filename, repo_type="dataset"
)
shutil.unpack_archive(downloaded_file, extract_dir=self.root)
else:
split_url = self.base_url + split + ".zip"
download_and_extract_archive(split_url, self.root, md5=self.zip_md5[split])

@property
def color_palette(self) -> np.ndarray:
Expand Down

0 comments on commit d13e0c6

Please sign in to comment.