diff --git a/src/pl_datamodules/hf_datamodule.py b/src/pl_datamodules/hf_datamodule.py
deleted file mode 100644
index 9f536c7..0000000
--- a/src/pl_datamodules/hf_datamodule.py
+++ /dev/null
@@ -1,136 +0,0 @@
-import torch
-
-from typing import Optional
-from .base_datamodule import GroupDataModule
-from torchvision.transforms import transforms
-from datasets import load_dataset
-from ..datasets.utils import (
-    ReweightedDataset,
-    UndersampledByGroupDataset,
-    split_dataset,
-)
-from ..datasets.mnli_dataset import MNLIDataset
-from transformers.tokenization_utils import PreTrainedTokenizer
-
-
-class MNLIDataModule(GroupDataModule):
-    dataset_name = "multi_nli"
-    num_classes = 3  # entailment (0), neutral (1), contradiction (2)
-    dims = None  # TODO
-
-    def __init__(
-        self,
-        tokenizer: PreTrainedTokenizer,
-        train_frac: float,
-        new_group_sizes: Optional = None,
-        new_group_fracs: Optional = None,
-        **kwargs,
-    ):
-        super().__init__(**kwargs)
-        self.tokenizer = tokenizer
-        self.train_frac = train_frac
-        self.new_group_sizes = new_group_sizes
-        self.new_group_fracs = new_group_fracs
-
-    def prepare_data(self):
-        """Download data if needed. This method is called only from a single GPU.
-        Do not use it to assign state (self.x = y)."""
-
-        _ = load_dataset(self.dataset_name, cache_dir=self.data_dir)
-
-    def setup(self, stage=None):
-        """Load data. Set variables: self.train_dataset, self.data_val, self.test_dataset."""
-        full_dataset = load_dataset(self.dataset_name, cache_dir=self.data_dir)
-
-        train_transform = init_transform(self.tokenizer)
-        eval_transform = init_transform(self.tokenizer)
-
-        train_dataset = MNLIDataset(
-            full_dataset["train"].filter(lambda example: example["label"] != -1),
-            input_transform=train_transform,
-            frac=self.train_frac,
-        )
-        if self.new_group_sizes is not None or self.new_group_fracs is not None:
-            new_train_dataset = UndersampledByGroupDataset(
-                train_dataset,
-                train_dataset.group_array,
-                self.new_group_sizes,
-                self.new_group_fracs,
-            )
-            new_train_dataset.y_array = train_dataset.y_array[new_train_dataset.indices]
-            new_train_dataset.group_array = train_dataset.group_array[
-                new_train_dataset.indices
-            ]
-            train_dataset = new_train_dataset
-
-        num_train_full = len(train_dataset)
-        num_train = int(0.8 * num_train_full)  # 80/20 split
-        # WARNING: val dataset will have the same transforms as the train dataset!
-        train_dataset, val_dataset = split_dataset(
-            train_dataset, [num_train], shuffle=True, seed=0
-        )  # always use seed 0 for split
-
-        test_dataset = MNLIDataset(
-            full_dataset["validation_matched"].filter(
-                lambda example: example["label"] != -1
-            ),
-            input_transform=eval_transform,
-        )
-
-        self.train_y_counter, self.train_g_counter, _ = self.compute_weights(
-            train_dataset
-        )
-        print(f"Train class counts: {self.train_y_counter}")
-        print(f"Train group counts: {self.train_g_counter}")
-
-        self.val_y_counter, self.val_g_counter, val_weights = self.compute_weights(
-            val_dataset
-        )
-        print(f"Val class counts: {self.val_y_counter}")
-        print(f"Val group counts: {self.val_g_counter}")
-        val_dataset = ReweightedDataset(val_dataset, weights=val_weights)
-
-        self.test_y_counter, self.test_g_counter, test_weights = self.compute_weights(
-            test_dataset
-        )
-        print(f"Test class counts: {self.test_y_counter}")
-        print(f"Test group counts: {self.test_g_counter}")
-        test_dataset = ReweightedDataset(test_dataset, weights=test_weights)
-
-        self.train_dataset = train_dataset
-        self.val_dataset = val_dataset
-        self.test_dataset = test_dataset
-
-
-def init_transform(tokenizer):
-    def transform_inference(premise__hypothesis):
-        premise, hypothesis = premise__hypothesis
-        encodings = tokenizer(
-            premise,
-            hypothesis,
-            truncation=True,
-            padding="max_length",
-            return_tensors="pt",
-        )
-        if tokenizer.name_or_path == "bert-base-uncased":
-            x = torch.stack(
-                (
-                    encodings["input_ids"],
-                    encodings["attention_mask"],
-                    encodings["token_type_ids"],
-                ),
-                dim=2,
-            )
-        elif tokenizer.name_or_path == "distilbert-base-uncased":
-            x = torch.stack(
-                (encodings["input_ids"], encodings["attention_mask"]), dim=2
-            )
-        else:
-            raise RuntimeError
-        x = torch.squeeze(
-            x, dim=0
-        )  # First shape dim is always 1 since we're not in batch mode
-        return x
-
-    transform = transforms.Lambda(lambda x: transform_inference(x))
-    return transform
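Aside: the deleted `init_transform` packs `input_ids` and `attention_mask` (plus `token_type_ids` for BERT) into a single tensor of shape `(max_length, num_fields)` so that each dataset item is one `x`. A minimal round-trip sketch of that packing; the model-side unpacking is an assumption for illustration, not code from this repo:

```python
# Sketch of the packing done by transform_inference above, plus the
# unpacking it implies on the model side (the unpacking names are assumed).
import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
encodings = tokenizer(
    "A premise.",
    "A hypothesis.",
    truncation=True,
    padding="max_length",
    return_tensors="pt",
)
x = torch.stack((encodings["input_ids"], encodings["attention_mask"]), dim=2)
x = torch.squeeze(x, dim=0)  # (max_length, 2): one dataset item
batch = x.unsqueeze(0)  # a DataLoader would collate items to (batch, max_length, 2)
input_ids, attention_mask = batch.unbind(dim=2)  # recover the tokenizer fields
assert input_ids.shape == attention_mask.shape == (1, tokenizer.model_max_length)
```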
diff --git a/src/pl_datamodules/imbalanced_datamodule.py b/src/pl_datamodules/imbalanced_datamodule.py
deleted file mode 100644
index c4d1408..0000000
--- a/src/pl_datamodules/imbalanced_datamodule.py
+++ /dev/null
@@ -1,192 +0,0 @@
-import torch
-from typing import List, Sequence, Optional
-from math import prod
-from pathlib import Path
-
-from ..datasets.imbalanced_cifar10_dataset import ImbalancedCIFAR10
-from ..datasets.imbalanced_moons_dataset import ImbalancedMoonsDataset
-from torch.utils.data import Dataset
-from torchvision.transforms import transforms
-
-
-class ImbalancedCIFAR10DataModule(BaseDataModule):
-    dims = (3, 32, 32)
-
-    def __init__(
-        self,
-        desired_classes: List[int] = None,
-        num_undersample_per_train_class: List[int] = None,
-        num_oversample_per_train_class: List[int] = None,
-        num_undersample_per_test_class: List[int] = None,
-        num_oversample_per_test_class: List[int] = None,
-        **kwargs,
-    ):
-        """
-        desired_classes: None indicates keep all classes
-        num_*sample_per_*_class: None indicates don't *sample
-        """
-        super().__init__(**kwargs)
-
-        desired_classes = (
-            list(set(desired_classes))
-            if desired_classes is not None
-            else list(range(10))
-        )
-
-        if (
-            num_undersample_per_train_class is None
-            and num_oversample_per_train_class is None
-        ):
-            num_undersample_per_train_class = [5000] * len(desired_classes)
-            num_oversample_per_train_class = [5000] * len(desired_classes)
-        if (
-            num_undersample_per_train_class is not None
-            and num_oversample_per_train_class is None
-        ):
-            num_oversample_per_train_class = list(num_undersample_per_train_class)
-
-        if (
-            num_undersample_per_test_class is None
-            and num_oversample_per_test_class is None
-        ):
-            num_undersample_per_test_class = [1000] * len(desired_classes)
-            num_oversample_per_test_class = [1000] * len(desired_classes)
-
-        if (
-            num_undersample_per_test_class is not None
-            and num_oversample_per_test_class is None
-        ):
-            num_oversample_per_test_class = list(num_undersample_per_test_class)
-
-        self.num_classes = len(desired_classes)
-        self.desired_classes = desired_classes
-        self.num_undersample_per_train_class = num_undersample_per_train_class
-        self.num_oversample_per_train_class = num_oversample_per_train_class
-        self.num_undersample_per_test_class = num_undersample_per_test_class
-        self.num_oversample_per_test_class = num_oversample_per_test_class
-
-        transforms_list = [
-            transforms.ToTensor(),
-            transforms.Normalize(
-                (0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)
-            ),  # see https://github.com/kuangliu/pytorch-cifar/issues/19 for example
-        ]
-        if self.flatten_input:
-            transforms_list.append(transforms.Lambda(lambda x: torch.flatten(x)))
-            self.dims = (prod(self.dims),)
-
-        self.transforms = transforms.Compose(transforms_list)
-
-    def prepare_data(self):
-        """Download data if needed. This method is called only from a single GPU.
-        Do not use it to assign state (self.x = y)."""
-        self.prepare_train_dataset()
-        self.prepare_val_dataset()
-
-    def prepare_train_dataset(self):
-        dataset = ImbalancedCIFAR10(
-            self.data_dir,
-            desired_classes=self.desired_classes,
-            num_undersample_per_class=self.num_undersample_per_train_class,
-            num_oversample_per_class=self.num_oversample_per_train_class,
-            train=True,
-            download=True,
-            transform=self.transforms,
-        )
-        return dataset
-
-    def prepare_val_dataset(self):
-        # Note: Using test set as validation set
-        dataset = ImbalancedCIFAR10(
-            self.data_dir,
-            desired_classes=self.desired_classes,
-            num_undersample_per_class=self.num_undersample_per_test_class,
-            num_oversample_per_class=self.num_oversample_per_test_class,
-            train=False,
-            download=True,
-            transform=self.transforms,
-        )
-        return dataset
-
-    def setup(self, stage=None):
-        """Load data. Set variables: self.train_dataset, self.data_val, self.test_dataset."""
-        # get only datapoints that belong to classes
-        self.train_dataset = self.prepare_train_dataset()
-        self.val_dataset = self.prepare_val_dataset()
-        # TODO: don't validate on test set?
-
-
-class ImbalancedMoonsDataModule(BaseDataModule):
-    dims = (2,)
-
-    def __init__(
-        self,
-        num_undersample_per_train_class: List[int] = None,
-        num_oversample_per_train_class: List[int] = None,
-        num_undersample_per_test_class: List[int] = None,
-        num_oversample_per_test_class: List[int] = None,
-        **kwargs,
-    ):
-        """
-        desired_classes: None indicates keep all classes
-        num_*sample_per_*_class: None indicates don't *sample
-        """
-        super().__init__(**kwargs)
-
-        if (
-            num_undersample_per_train_class is None
-            and num_oversample_per_train_class is None
-        ):
-            num_undersample_per_train_class = [512, 512]
-            num_oversample_per_train_class = [512, 512]
-        if (
-            num_undersample_per_train_class is not None
-            and num_oversample_per_train_class is None
-        ):
-            num_oversample_per_train_class = list(num_undersample_per_train_class)
-
-        if (
-            num_undersample_per_test_class is None
-            and num_oversample_per_test_class is None
-        ):
-            num_undersample_per_test_class = [512, 512]
-            num_oversample_per_test_class = [512, 512]
-
-        if (
-            num_undersample_per_test_class is not None
-            and num_oversample_per_test_class is None
-        ):
-            num_oversample_per_test_class = list(num_undersample_per_test_class)
-
-        self.num_classes = 2
-        self.num_undersample_per_train_class = num_undersample_per_train_class
-        self.num_oversample_per_train_class = num_oversample_per_train_class
-        self.num_undersample_per_test_class = num_undersample_per_test_class
-        self.num_oversample_per_test_class = num_oversample_per_test_class
-
-    def prepare_data(self):
-        """Download data if needed. This method is called only from a single GPU.
-        Do not use it to assign state (self.x = y)."""
-        self.prepare_train_dataset()
-        self.prepare_val_dataset()
-
-    def prepare_train_dataset(self):
-        dataset = ImbalancedMoonsDataset(
-            num_undersample_per_class=self.num_undersample_per_train_class,
-            num_oversample_per_class=self.num_oversample_per_train_class,
-            num_samples=(512, 512),
-        )
-        return dataset
-
-    def prepare_val_dataset(self):
-        dataset = ImbalancedMoonsDataset(
-            num_undersample_per_class=self.num_undersample_per_train_class,
-            num_oversample_per_class=self.num_oversample_per_train_class,
-            num_samples=(512, 512),
-        )
-        return dataset
-
-    def setup(self, stage=None):
-        """Load data. Set variables: self.train_dataset, self.data_val, self.test_dataset."""
-        self.train_dataset = self.prepare_train_dataset()
-        self.val_dataset = self.prepare_val_dataset()
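Aside: the `num_undersample_per_*_class` lists cap each class at a target count before `num_oversample_per_*_class` replicates it back up, which is how a long-tailed split like `[5000, 500]` is built. The `ImbalancedCIFAR10` implementation itself is not part of this diff, so the sketch below is only an assumed reading of that contract:

```python
# Hypothetical sketch of the per-class undersample-then-oversample contract
# implied by the constructor arguments above; ImbalancedCIFAR10's real
# implementation is not shown in this diff.
import numpy as np

def resample_indices(labels, desired_classes, num_undersample, num_oversample, seed=0):
    rng = np.random.default_rng(seed)
    out = []
    for cls, n_under, n_over in zip(desired_classes, num_undersample, num_oversample):
        idx = np.flatnonzero(np.asarray(labels) == cls)
        idx = rng.permutation(idx)[:n_under]  # undersample: cap the class size
        reps = -(-n_over // len(idx))  # ceil division
        out.append(np.tile(idx, reps)[:n_over])  # oversample: repeat indices
    return np.concatenate(out)

# e.g. a 10:1 two-class split:
labels = np.repeat(np.arange(10), 5000)  # stand-in for the CIFAR-10 train labels
indices = resample_indices(
    labels, desired_classes=[0, 1], num_undersample=[5000, 500], num_oversample=[5000, 500]
)
assert len(indices) == 5500
```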
diff --git a/src/pl_datamodules/inaturalist_datamodule.py b/src/pl_datamodules/inaturalist_datamodule.py
deleted file mode 100644
index e69de29..0000000
diff --git a/src/pl_datamodules/utils.py b/src/pl_datamodules/utils.py
index 637e698..6e906ed 100644
--- a/src/pl_datamodules/utils.py
+++ b/src/pl_datamodules/utils.py
@@ -5,8 +5,6 @@
     ImbalancedCIFAR10DataModule,
     ImbalancedCIFAR100DataModule,
 )
-from .wilds_datamodule import WILDSDataModule
-from .hf_datamodule import MNLIDataModule
 from .base_datamodule import GroupDataModule
 from torch.utils.data import DataLoader, Subset
 
diff --git a/src/pl_datamodules/wilds_datamodule.py b/src/pl_datamodules/wilds_datamodule.py
deleted file mode 100644
index 139a374..0000000
--- a/src/pl_datamodules/wilds_datamodule.py
+++ /dev/null
@@ -1,230 +0,0 @@
-import torch
-import wilds
-from omegaconf import DictConfig
-from wilds.common.grouper import CombinatorialGrouper
-
-from typing import List, Optional
-from math import prod
-from .base_datamodule import (
-    GroupDataModule,
-    IMAGENET_DEFAULT_MEAN,
-    IMAGENET_DEFAULT_STD,
-)
-from torchvision.transforms import transforms
-from ..datasets.wilds_dataset import WILDSDataset
-from ..datasets.utils import ReweightedDataset
-from transformers.tokenization_utils import PreTrainedTokenizer
-
-
-# Default dataset settings: https://github.com/p-lambda/wilds/blob/e95bba8408aff524b48b96a4e7648df72773ad60/examples/configs/datasets.py
-class WILDSDataModule(GroupDataModule):
-    def __init__(
-        self,
-        dataset_name,
-        train_transform: DictConfig,
-        eval_transform: DictConfig,
-        num_classes: int,
-        groupby_fields: List[str],
-        tokenizer: Optional[PreTrainedTokenizer] = None,
-        resolution: Optional[List[int]] = None,
-        split_scheme="official",
-        download=True,
-        **kwargs,
-    ):
-        super().__init__(**kwargs)
-        self.dataset_name = dataset_name
-        self.download = download
-        self.split_scheme = split_scheme
-        self.train_transform = train_transform
-        self.eval_transform = eval_transform
-        if resolution is not None:
-            self.dims = (3,) + tuple(resolution)
-        else:
-            # TODO
-            self.dims = None
-        self.num_classes = num_classes
-        self.groupby_fields = groupby_fields
-        self.tokenizer = tokenizer
-
-        if self.flatten_input:
-            self.dims = (prod(self.dims),)
-
-    def prepare_data(self):
-        """Download data if needed. This method is called only from a single GPU.
-        Do not use it to assign state (self.x = y)."""
-
-        # Initializing wilds_dataset.WILDSDataset will create data dir and download
-        _: wilds.datasets.wilds_dataset.WILDSDataset = wilds.get_dataset(
-            dataset=self.dataset_name,
-            root_dir=self.data_dir,
-            download=self.download,
-            split_scheme=self.split_scheme,
-        )
-
-    def setup(self, stage=None):
-        """Load data. Set variables: self.train_dataset, self.data_val, self.test_dataset."""
-        full_dataset = wilds.get_dataset(
-            dataset=self.dataset_name,
-            root_dir=self.data_dir,
-            download=self.download,
-            split_scheme=self.split_scheme,
-        )
-        train_transforms_list = initialize_transform(
-            dataset=full_dataset, tokenizer=self.tokenizer, **self.train_transform
-        )
-        eval_transforms_list = initialize_transform(
-            dataset=full_dataset, tokenizer=self.tokenizer, **self.eval_transform
-        )
-
-        if self.flatten_input:
-            raise NotImplementedError
-            # train_transforms_list.append(transforms.Lambda(lambda x: torch.flatten(x)))
-            # eval_transforms_list.append(transforms.Lambda(lambda x: torch.flatten(x)))
-        self.train_transform = transforms.Compose(train_transforms_list)
-        self.eval_transform = transforms.Compose(eval_transforms_list)
-
-        grouper = CombinatorialGrouper(full_dataset, groupby_fields=self.groupby_fields)
-        train_dataset = WILDSDataset(
-            full_dataset.get_subset("train", transform=self.train_transform), grouper
-        )
-        # Note that some datasets from the WILDS dataset actually use other groupers for
-        # eval, such as https://github.com/p-lambda/wilds/blob/e95bba8408aff524b48b96a4e7648df72773ad60/wilds/datasets/fmow_dataset.py#L203
-        val_dataset = WILDSDataset(
-            full_dataset.get_subset("val", transform=self.eval_transform), grouper
-        )
-        test_dataset = WILDSDataset(
-            full_dataset.get_subset("test", transform=self.eval_transform), grouper
-        )
-
-        self.train_y_counter, self.train_g_counter, _ = self.compute_weights(
-            train_dataset
-        )
-        print(f"Train class counts: {self.train_y_counter}")
-        print(f"Train group counts: {self.train_g_counter}")
-        self.val_y_counter, self.val_g_counter, val_weights = self.compute_weights(
-            val_dataset
-        )
-        print(f"Val class counts: {self.val_y_counter}")
-        print(f"Val group counts: {self.val_g_counter}")
-        self.test_y_counter, self.test_g_counter, test_weights = self.compute_weights(
-            test_dataset
-        )
-        print(f"Test class counts: {self.test_y_counter}")
-        print(f"Test group counts: {self.test_g_counter}")
-
-        val_dataset = ReweightedDataset(val_dataset, weights=val_weights)
-        test_dataset = ReweightedDataset(test_dataset, weights=test_weights)
-
-        self.train_dataset = train_dataset
-        self.val_dataset = val_dataset
-        self.test_dataset = test_dataset
-
-
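Aside: `setup` forwards `**self.train_transform` into `initialize_transform` (defined just below), so the `train_transform`/`eval_transform` configs must carry the remaining parameters, `transform_name` and `config`. A sketch of one plausible config; the field values are illustrative, not defaults from this repo:

```python
# Sketch of the DictConfig shape that **self.train_transform must provide for
# the initialize_transform call in setup above. Values are illustrative only.
from omegaconf import OmegaConf

train_transform = OmegaConf.create(
    {
        "transform_name": "image_resize_and_center_crop",
        "config": {"resize_scale": 256 / 224, "target_resolution": [224, 224]},
    }
)
# initialize_transform(dataset=full_dataset, tokenizer=None, **train_transform)
# -> [Resize, CenterCrop, ToTensor, Normalize], per the helpers below
```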
-def initialize_transform(
-    transform_name: Optional[str],
-    config: DictConfig,
-    dataset: wilds.datasets.wilds_dataset.WILDSDataset,
-    tokenizer,
-):
-    if transform_name is None:
-        return []
-    elif transform_name == "bert":
-        return initialize_text_transform(config, tokenizer)
-    elif transform_name == "image_base":
-        return initialize_image_base_transform(config, dataset)
-    elif transform_name == "image_resize_and_center_crop":
-        return initialize_image_resize_and_center_crop_transform(config, dataset)
-    elif transform_name == "poverty_train":
-        return initialize_poverty_train_transform()
-    else:
-        raise ValueError(f"{transform_name} not recognized")
-
-
-def initialize_text_transform(config: DictConfig, tokenizer):
-    assert config.max_token_length is not None
-
-    def transform_text(text):
-        tokens = tokenizer(
-            text,
-            padding="max_length",
-            truncation=True,
-            max_length=config.max_token_length,
-            return_tensors="pt",
-        )
-        if tokenizer.name_or_path == "bert-base-uncased":
-            x = torch.stack(
-                (
-                    tokens["input_ids"],
-                    tokens["attention_mask"],
-                    tokens["token_type_ids"],
-                ),
-                dim=2,
-            )
-        elif tokenizer.name_or_path == "distilbert-base-uncased":
-            x = torch.stack((tokens["input_ids"], tokens["attention_mask"]), dim=2)
-        x = torch.squeeze(x, dim=0)  # First shape dim is always 1
-        return x
-
-    transform = transforms.Lambda(lambda x: transform_text(x))
-    return [transform]
-
-
-def initialize_image_base_transform(
-    config: DictConfig, dataset: wilds.datasets.wilds_dataset.WILDSDataset
-):
-    transform_steps = []
-    if dataset.original_resolution is not None and min(
-        dataset.original_resolution
-    ) != max(dataset.original_resolution):
-        crop_size = min(dataset.original_resolution)
-        transform_steps.append(transforms.CenterCrop(crop_size))
-    if config.target_resolution is not None and config.dataset != "fmow":
-        transform_steps.append(transforms.Resize(config.target_resolution))
-    transform_steps += [
-        transforms.ToTensor(),
-        transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
-    ]
-    return transform_steps
-
-
-def initialize_image_resize_and_center_crop_transform(
-    config: DictConfig, dataset: wilds.datasets.wilds_dataset.WILDSDataset
-):
-    """
-    Resizes the image to a slightly larger square then crops the center.
-    """
-    assert dataset.original_resolution is not None
-    assert config.resize_scale is not None
-    scaled_resolution = tuple(
-        int(res * config.resize_scale) for res in dataset.original_resolution
-    )
-    if config.target_resolution is not None:
-        target_resolution = config.target_resolution
-    else:
-        target_resolution = dataset.original_resolution
-    transforms_list = [
-        transforms.Resize(scaled_resolution),
-        transforms.CenterCrop(target_resolution),
-        transforms.ToTensor(),
-        transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
-    ]
-    return transforms_list
-
-
-def initialize_poverty_train_transform():
-    transforms_ls = [
-        transforms.ToPILImage(),
-        transforms.RandomHorizontalFlip(),
-        transforms.RandomVerticalFlip(),
-        transforms.ColorJitter(brightness=0.8, contrast=0.8, saturation=0.8, hue=0.1),
-        transforms.ToTensor(),
-    ]
-    rgb_transform = transforms.Compose(transforms_ls)
-
-    def transform_rgb(img):
-        # bgr to rgb and back to bgr
-        img[:3] = rgb_transform(img[:3][[2, 1, 0]])[[2, 1, 0]]
-        return img
-
-    transform = transforms.Lambda(lambda x: transform_rgb(x))
-    return [transform]
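Aside: the `img[:3][[2, 1, 0]]` indexing in `transform_rgb` reverses the first three channels, so the color jitter runs in RGB space on BGR-ordered imagery and the result is written back as BGR. A quick self-contained check of that channel flip:

```python
# Quick check of the BGR <-> RGB channel flip used by transform_rgb above:
# indexing with [[2, 1, 0]] reverses the first three channels, and applying
# it twice is the identity.
import torch

img = torch.arange(3 * 2 * 2, dtype=torch.float32).reshape(3, 2, 2)  # fake BGR image
rgb = img[[2, 1, 0]]  # channels reordered to RGB
assert torch.equal(rgb[[2, 1, 0]], img)  # flipping twice restores the input
```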