Skip to content

Commit

Permalink
🔨 Finish reworking CUB
Browse files Browse the repository at this point in the history
  • Loading branch information
o-laurent committed Jan 15, 2025
1 parent f6802aa commit c3110a7
Showing 1 changed file with 26 additions and 14 deletions.
40 changes: 26 additions & 14 deletions torch_uncertainty/datasets/classification/cub.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def __init__(
train: bool = True,
transform: Callable | None = None,
target_transform: Callable | None = None,
load_attributes: bool = False,
return_attributes: bool = False,
download: bool = False,
):
"""The Caltech-UCSD Birds-200-2011 dataset.
Expand All @@ -34,8 +34,8 @@ def __init__(
returns a transformed version. E.g., transforms.RandomCrop. Defaults to None.
target_transform (callable, optional): A function/transform that takes in the target
and transforms it. Defaults to None.
load_attributes (bool, optional): If True, loads the attributes of the dataset and
returns them instead of the images. Defaults to False.
return_attributes (bool, optional): If True, returns the attributes instead of the images.
Defaults to False.
download (bool, optional): If True, downloads the dataset from the internet and puts it
in root directory. If dataset is already downloaded, it is not downloaded again.
Defaults to False.
Expand All @@ -57,16 +57,26 @@ def __init__(
super().__init__(Path(root) / "CUB_200_2011" / "images", transform, target_transform)

training_idx = self._load_train_idx()
self.attributes, self.certainties = self._load_attributes()
if load_attributes:
self.samples = zip(self.attributes, [sam[1] for sam in self.samples], strict=False)
self.attribute_names = self._load_attribute_names()
self.loader = torch.nn.Identity()
self.attributes, self.uncertainties = self._load_attributes()
self.attribute_names = self._load_attribute_names()
self.classnames = self._load_classnames()

self.samples = [sample for i, sample in enumerate(self.samples) if training_idx[i] == train]
self._labels = [label for i, label in enumerate(self.targets) if training_idx[i] == train]
self.attributes = rearrange(
torch.masked_select(self.attributes, training_idx.unsqueeze(-1) == train),
"(n c) -> n c",
c=312,
)
self.uncertainties = rearrange(
torch.masked_select(self.uncertainties, training_idx.unsqueeze(-1) == train),
"(n c) -> n c",
c=312,
)

self.classnames = self._load_classnames()
if return_attributes:
self.samples = zip(self.attributes, [sam[1] for sam in self.samples], strict=False)
self.loader = torch.nn.Identity()

def _load_classnames(self) -> list[str]:
"""Load the classnames of the dataset.
Expand All @@ -92,16 +102,18 @@ def _load_attributes(self) -> tuple[Tensor, Tensor]:
"""Load the attributes associated to each image.
Returns:
tuple[Tensor, Tensor]: The presence of the 312 attributes along with their certainty.
tuple[Tensor, Tensor]: The presence of the 312 attributes along with their uncertainty.
The uncertainty is 0 for attributes annotated with full certainty and 1 for non-visible attributes.
"""
attributes, certainty = [], []
attributes, uncertainty = [], []
with (self.folder_root / "CUB_200_2011" / "attributes" / "image_attribute_labels.txt").open(
"r"
) as f:
attributes = [int(line.split(" ")[2]) for line in f]
certainty = [(int(line.split(" ")[3]) - 1) / 3 for line in f]
for line in f:
attributes.append(int(line.split(" ")[2]))
uncertainty.append(1 - (int(line.split(" ")[3]) - 1) / 3)
return rearrange(torch.as_tensor(attributes), "(n c) -> n c", c=312), rearrange(
torch.as_tensor(certainty), "(n c) -> n c", c=312
torch.as_tensor(uncertainty), "(n c) -> n c", c=312
)

def _load_attribute_names(self) -> list[str]:
Expand Down

0 comments on commit c3110a7

Please sign in to comment.