diff --git a/torch_uncertainty/datasets/classification/cub.py b/torch_uncertainty/datasets/classification/cub.py index 0924b40e..079d20ec 100644 --- a/torch_uncertainty/datasets/classification/cub.py +++ b/torch_uncertainty/datasets/classification/cub.py @@ -21,7 +21,7 @@ def __init__( train: bool = True, transform: Callable | None = None, target_transform: Callable | None = None, - load_attributes: bool = False, + return_attributes: bool = False, download: bool = False, ): """The Caltech-UCSD Birds-200-2011 dataset. @@ -34,8 +34,8 @@ def __init__( returns a transformed version. E.g, transforms.RandomCrop. Defaults to None. target_transform (callable, optional): A function/transform that takes in the target and transforms it. Defaults to None. - load_attributes (bool, optional): If True, loads the attributes of the dataset and - returns them instead of the images. Defaults to False. + return_attributes (bool, optional): If True, returns the attributes instead of the images. + Defaults to False. download (bool, optional): If True, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again. Defaults to False. @@ -57,16 +57,26 @@ def __init__( super().__init__(Path(root) / "CUB_200_2011" / "images", transform, target_transform) training_idx = self._load_train_idx() - self.attributes, self.certainties = self._load_attributes() - if load_attributes: - self.samples = zip(self.attributes, [sam[1] for sam in self.samples], strict=False) - self.attribute_names = self._load_attribute_names() - self.loader = torch.nn.Identity() + self.attributes, self.uncertainties = self._load_attributes() + self.attribute_names = self._load_attribute_names() + self.classnames = self._load_classnames() self.samples = [sample for i, sample in enumerate(self.samples) if training_idx[i] == train] self._labels = [label for i, label in enumerate(self.targets) if training_idx[i] == train] + self.attributes = rearrange( + torch.masked_select(self.attributes, training_idx.unsqueeze(-1) == train), + "(n c) -> n c", + c=312, + ) + self.uncertainties = rearrange( + torch.masked_select(self.uncertainties, training_idx.unsqueeze(-1) == train), + "(n c) -> n c", + c=312, + ) - self.classnames = self._load_classnames() + if return_attributes: + self.samples = zip(self.attributes, [sam[1] for sam in self.samples], strict=False) + self.loader = torch.nn.Identity() def _load_classnames(self) -> list[str]: """Load the classnames of the dataset. @@ -92,16 +102,18 @@ def _load_attributes(self) -> tuple[Tensor, Tensor]: """Load the attributes associated to each image. Returns: - tuple[Tensor, Tensor]: The presence of the 312 attributes along with their certainty. + tuple[Tensor, Tensor]: The presence of the 312 attributes along with their uncertainty. + The uncertainty is 0 for certain samples and 1 for non-visible attributes. """ - attributes, certainty = [], [] + attributes, uncertainty = [], [] with (self.folder_root / "CUB_200_2011" / "attributes" / "image_attribute_labels.txt").open( "r" ) as f: - attributes = [int(line.split(" ")[2]) for line in f] - certainty = [(int(line.split(" ")[3]) - 1) / 3 for line in f] + for line in f: + attributes.append(int(line.split(" ")[2])) + uncertainty.append(1 - (int(line.split(" ")[3]) - 1) / 3) return rearrange(torch.as_tensor(attributes), "(n c) -> n c", c=312), rearrange( - torch.as_tensor(certainty), "(n c) -> n c", c=312 + torch.as_tensor(uncertainty), "(n c) -> n c", c=312 ) def _load_attribute_names(self) -> list[str]: