diff --git a/src/configs/algo_config.py b/src/configs/algo_config.py
index 974ca1b..f0f4976 100644
--- a/src/configs/algo_config.py
+++ b/src/configs/algo_config.py
@@ -361,6 +361,5 @@ def get_malicious_types(malicious_config_list: List[ConfigType]) -> Dict[str, str]:
     malicious_traditional_model_update_attack,
 ]
 
-default_config_list: List[ConfigType] = [traditional_fl]
 # default_config_list: List[ConfigType] = [fedstatic, fedstatic, fedstatic, fedstatic]
 
diff --git a/src/inversefed/data/__init__.py b/src/inversefed/data/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/inversefed/data/data.py b/src/inversefed/data/data.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/inversefed/data/data_processing.py b/src/inversefed/data/data_processing.py
new file mode 100644
index 0000000..f6e36e3
--- /dev/null
+++ b/src/inversefed/data/data_processing.py
@@ -0,0 +1,209 @@
+"""Repeatable code parts concerning data loading."""
+
+
+import torch
+import torchvision
+import torchvision.transforms as transforms
+
+import os
+
+from ..consts import *
+
+from .data import _build_bsds_sr, _build_bsds_dn
+from .loss import Classification, PSNR
+
+
+def construct_dataloaders(dataset, defs, data_path='~/data', shuffle=True, normalize=True):
+    """Return the loss function and train/valid dataloaders for the given dataset, with optional augmentation and normalization."""
+    path = os.path.expanduser(data_path)
+
+    if dataset == 'CIFAR10':
+        trainset, validset = _build_cifar10(path, defs.augmentations, normalize)
+        loss_fn = Classification()
+    elif dataset == 'CIFAR100':
+        trainset, validset = _build_cifar100(path, defs.augmentations, normalize)
+        loss_fn = Classification()
+    elif dataset == 'MNIST':
+        trainset, validset = _build_mnist(path, defs.augmentations, normalize)
+        loss_fn = Classification()
+    elif dataset == 'MNIST_GRAY':
+        trainset, validset = _build_mnist_gray(path, defs.augmentations, normalize)
+        loss_fn = Classification()
+    elif dataset == 'ImageNet':
+        trainset, validset = _build_imagenet(path, defs.augmentations, normalize)
+        loss_fn = Classification()
+    elif dataset == 'BSDS-SR':
+        trainset, validset = _build_bsds_sr(path, defs.augmentations, normalize, upscale_factor=3, RGB=True)
+        loss_fn = PSNR()
+    elif dataset == 'BSDS-DN':
+        trainset, validset = _build_bsds_dn(path, defs.augmentations, normalize, noise_level=25 / 255, RGB=False)
+        loss_fn = PSNR()
+    elif dataset == 'BSDS-RGB':
+        trainset, validset = _build_bsds_dn(path, defs.augmentations, normalize, noise_level=25 / 255, RGB=True)
+        loss_fn = PSNR()
+    else:
+        raise ValueError(f'Unknown dataset {dataset}.')
+
+    if MULTITHREAD_DATAPROCESSING:
+        num_workers = min(torch.get_num_threads(), MULTITHREAD_DATAPROCESSING) if torch.get_num_threads() > 1 else 0
+    else:
+        num_workers = 0
+
+    trainloader = torch.utils.data.DataLoader(trainset, batch_size=min(defs.batch_size, len(trainset)),
+                                              shuffle=shuffle, drop_last=True, num_workers=num_workers, pin_memory=PIN_MEMORY)
+    validloader = torch.utils.data.DataLoader(validset, batch_size=min(defs.batch_size, len(validset)),
+                                              shuffle=False, drop_last=False, num_workers=num_workers, pin_memory=PIN_MEMORY)
+
+    return loss_fn, trainloader, validloader
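+
+# Usage sketch (illustrative only, not part of the original module): `defs`
+# is assumed to be any object exposing the `augmentations` and `batch_size`
+# attributes read above, and `model` stands in for your classifier.
+#
+#   loss_fn, trainloader, validloader = construct_dataloaders('CIFAR10', defs)
+#   for images, labels in trainloader:
+#       value, name, fmt = loss_fn(model(images), labels)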
+
+
+def _build_cifar10(data_path, augmentations=True, normalize=True):
+    """Define CIFAR-10 with everything considered."""
+    # Load data
+    trainset = torchvision.datasets.CIFAR10(root=data_path, train=True, download=True, transform=transforms.ToTensor())
+    validset = torchvision.datasets.CIFAR10(root=data_path, train=False, download=True, transform=transforms.ToTensor())
+
+    if cifar10_mean is None:
+        data_mean, data_std = _get_meanstd(trainset)
+    else:
+        data_mean, data_std = cifar10_mean, cifar10_std
+
+    # Organize preprocessing
+    transform = transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Normalize(data_mean, data_std) if normalize else transforms.Lambda(lambda x: x)])
+    if augmentations:
+        transform_train = transforms.Compose([
+            transforms.RandomCrop(32, padding=4),
+            transforms.RandomHorizontalFlip(),
+            transform])
+        trainset.transform = transform_train
+    else:
+        trainset.transform = transform
+    validset.transform = transform
+
+    return trainset, validset
+
+
+def _build_cifar100(data_path, augmentations=True, normalize=True):
+    """Define CIFAR-100 with everything considered."""
+    # Load data
+    trainset = torchvision.datasets.CIFAR100(root=data_path, train=True, download=True, transform=transforms.ToTensor())
+    validset = torchvision.datasets.CIFAR100(root=data_path, train=False, download=True, transform=transforms.ToTensor())
+
+    if cifar100_mean is None:
+        data_mean, data_std = _get_meanstd(trainset)
+    else:
+        data_mean, data_std = cifar100_mean, cifar100_std
+
+    # Organize preprocessing
+    transform = transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Normalize(data_mean, data_std) if normalize else transforms.Lambda(lambda x: x)])
+    if augmentations:
+        transform_train = transforms.Compose([
+            transforms.RandomCrop(32, padding=4),
+            transforms.RandomHorizontalFlip(),
+            transform])
+        trainset.transform = transform_train
+    else:
+        trainset.transform = transform
+    validset.transform = transform
+
+    return trainset, validset
+
+
+def _build_mnist(data_path, augmentations=True, normalize=True):
+    """Define MNIST with everything considered."""
+    # Load data
+    trainset = torchvision.datasets.MNIST(root=data_path, train=True, download=True, transform=transforms.ToTensor())
+    validset = torchvision.datasets.MNIST(root=data_path, train=False, download=True, transform=transforms.ToTensor())
+
+    if mnist_mean is None:
+        cc = torch.cat([trainset[i][0].reshape(-1) for i in range(len(trainset))], dim=0)
+        data_mean = (torch.mean(cc, dim=0).item(),)
+        data_std = (torch.std(cc, dim=0).item(),)
+    else:
+        data_mean, data_std = mnist_mean, mnist_std
+
+    # Organize preprocessing
+    transform = transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Normalize(data_mean, data_std) if normalize else transforms.Lambda(lambda x: x)])
+    if augmentations:
+        transform_train = transforms.Compose([
+            transforms.RandomCrop(28, padding=4),
+            transforms.RandomHorizontalFlip(),
+            transform])
+        trainset.transform = transform_train
+    else:
+        trainset.transform = transform
+    validset.transform = transform
+
+    return trainset, validset
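+
+# The inline statistics above are the single-channel analogue of _get_meanstd
+# below; an equivalent sketch over (image, label) pairs:
+#
+#   cc = torch.cat([img.reshape(-1) for img, _ in trainset], dim=0)
+#   data_mean, data_std = (cc.mean().item(),), (cc.std().item(),)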
+
+
+def _build_mnist_gray(data_path, augmentations=True, normalize=True):
+    """Define grayscale MNIST with everything considered."""
+    # Load data
+    trainset = torchvision.datasets.MNIST(root=data_path, train=True, download=True, transform=transforms.ToTensor())
+    validset = torchvision.datasets.MNIST(root=data_path, train=False, download=True, transform=transforms.ToTensor())
+
+    if mnist_mean is None:
+        cc = torch.cat([trainset[i][0].reshape(-1) for i in range(len(trainset))], dim=0)
+        data_mean = (torch.mean(cc, dim=0).item(),)
+        data_std = (torch.std(cc, dim=0).item(),)
+    else:
+        data_mean, data_std = mnist_mean, mnist_std
+
+    # Organize preprocessing
+    transform = transforms.Compose([
+        transforms.Grayscale(num_output_channels=1),
+        transforms.ToTensor(),
+        transforms.Normalize(data_mean, data_std) if normalize else transforms.Lambda(lambda x: x)])
+    if augmentations:
+        transform_train = transforms.Compose([
+            transforms.Grayscale(num_output_channels=1),
+            transforms.RandomCrop(28, padding=4),
+            transforms.RandomHorizontalFlip(),
+            transform])
+        trainset.transform = transform_train
+    else:
+        trainset.transform = transform
+    validset.transform = transform
+
+    return trainset, validset
+
+
+def _build_imagenet(data_path, augmentations=True, normalize=True):
+    """Define ImageNet with everything considered."""
+    # Load data
+    trainset = torchvision.datasets.ImageNet(root=data_path, split='train', transform=transforms.ToTensor())
+    validset = torchvision.datasets.ImageNet(root=data_path, split='val', transform=transforms.ToTensor())
+
+    if imagenet_mean is None:
+        data_mean, data_std = _get_meanstd(trainset)
+    else:
+        data_mean, data_std = imagenet_mean, imagenet_std
+
+    # Organize preprocessing
+    transform = transforms.Compose([
+        transforms.Resize(256),
+        transforms.CenterCrop(224),
+        transforms.ToTensor(),
+        transforms.Normalize(data_mean, data_std) if normalize else transforms.Lambda(lambda x: x)])
+    if augmentations:
+        transform_train = transforms.Compose([
+            transforms.RandomResizedCrop(224),
+            transforms.RandomHorizontalFlip(),
+            transforms.ToTensor(),
+            transforms.Normalize(data_mean, data_std) if normalize else transforms.Lambda(lambda x: x)])
+        trainset.transform = transform_train
+    else:
+        trainset.transform = transform
+    validset.transform = transform
+
+    return trainset, validset
+
+
+def _get_meanstd(dataset):
+    cc = torch.cat([dataset[i][0].reshape(3, -1) for i in range(len(dataset))], dim=1)
+    data_mean = torch.mean(cc, dim=1).tolist()
+    data_std = torch.std(cc, dim=1).tolist()
+    return data_mean, data_std
\ No newline at end of file
diff --git a/src/inversefed/data/datasets.py b/src/inversefed/data/datasets.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/inversefed/data/loss.py b/src/inversefed/data/loss.py
new file mode 100644
index 0000000..f43ce93
--- /dev/null
+++ b/src/inversefed/data/loss.py
@@ -0,0 +1,114 @@
+"""Define various loss functions and bundle them with appropriate metrics."""
+
+import torch
+import numpy as np
+
+
+class Loss:
+    """Abstract class, containing necessary methods.
+
+    Abstract class to collect information about the 'higher-level' loss function, used to train an energy-based model,
+    containing the evaluation of the loss function, its gradients w.r.t. first and second argument, and evaluations
+    of the actual metric that is targeted.
+
+    """
+
+    def __init__(self):
+        """Init."""
+        pass
+
+    def __call__(self, reference, argmin):
+        """Return l(x, y)."""
+        raise NotImplementedError()
+        return value, name, format
+
+    def metric(self, reference, argmin):
+        """The actually sought metric."""
+        raise NotImplementedError()
+        return value, name, format
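+
+# Call convention (inferred from the subclasses below): calling a loss with no
+# arguments returns only (name, format) for building log headers, while
+# calling it with tensors also returns the value:
+#
+#   name, fmt = loss_fn()              # header mode
+#   value, name, fmt = loss_fn(x, y)   # evaluation mode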
+
+
+class PSNR(Loss):
+    """A classical MSE target.
+
+    The minimized criterion is MSE Loss, the actual metric is average PSNR.
+    """
+
+    def __init__(self):
+        """Init with torch MSE."""
+        self.loss_fn = torch.nn.MSELoss(size_average=None, reduce=None, reduction='mean')
+
+    def __call__(self, x=None, y=None):
+        """Return l(x, y)."""
+        name = 'MSE'
+        format = '.6f'
+        if x is None:
+            return name, format
+        else:
+            value = 0.5 * self.loss_fn(x, y)
+            return value, name, format
+
+    def metric(self, x=None, y=None):
+        """The actually sought metric."""
+        name = 'avg PSNR'
+        format = '.3f'
+        if x is None:
+            return name, format
+        else:
+            value = self.psnr_compute(x, y)
+            return value, name, format
+
+    @staticmethod
+    def psnr_compute(img_batch, ref_batch, batched=False, factor=1.0):
+        """Standard PSNR."""
+        def get_psnr(img_in, img_ref):
+            mse = ((img_in - img_ref)**2).mean()
+            if mse > 0 and torch.isfinite(mse):
+                return (10 * torch.log10(factor**2 / mse)).item()
+            elif not torch.isfinite(mse):
+                return float('nan')
+            else:
+                return float('inf')
+
+        if batched:
+            psnr = get_psnr(img_batch.detach(), ref_batch)
+        else:
+            [B, C, m, n] = img_batch.shape
+            psnrs = []
+            for sample in range(B):
+                psnrs.append(get_psnr(img_batch.detach()[sample, :, :, :], ref_batch[sample, :, :, :]))
+            psnr = np.mean(psnrs)
+
+        return psnr
+
+
+class Classification(Loss):
+    """A classical NLL loss for classification. Evaluation has the softmax baked in.
+
+    The minimized criterion is cross entropy, the actual metric is total accuracy.
+    """
+
+    def __init__(self):
+        """Init with torch CrossEntropyLoss."""
+        self.loss_fn = torch.nn.CrossEntropyLoss(weight=None, size_average=None, ignore_index=-100,
+                                                 reduce=None, reduction='mean')
+
+    def __call__(self, x=None, y=None):
+        """Return l(x, y)."""
+        name = 'CrossEntropy'
+        format = '1.5f'
+        if x is None:
+            return name, format
+        else:
+            value = self.loss_fn(x, y)
+            return value, name, format
+
+    def metric(self, x=None, y=None):
+        """The actually sought metric."""
+        name = 'Accuracy'
+        format = '6.2%'
+        if x is None:
+            return name, format
+        else:
+            value = (x.data.argmax(dim=1) == y).sum().float() / y.shape[0]
+            return value.detach(), name, format
\ No newline at end of file
diff --git a/src/utils/model_utils.py b/src/utils/model_utils.py
index 3e8b50e..2f85d65 100644
--- a/src/utils/model_utils.py
+++ b/src/utils/model_utils.py
@@ -13,8 +13,6 @@ import yolo
 
 from utils.types import ConfigType
 
-from inversefed.reconstruction_algorithms import loss_steps
-
 class ModelUtils:
     def __init__(self, device: torch.device, config: ConfigType) -> None:
         self.device = device
@@ -197,6 +195,7 @@ def train_classification(
             print("here, applying softmax")
             output = nn.functional.log_softmax(output, dim=1)  # type: ignore
         if kwargs.get("gia", False):
+            from inversefed.reconstruction_algorithms import loss_steps
            # Sum the loss and create gradient graph like in loss_steps
             # Use modified loss_steps function that returns loss
             model.eval()
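
A minimal usage sketch of the new loss.py API (illustrative only; the random tensors stand in for real model outputs, and `inversefed.data.loss` must be importable from `src`):

    import torch
    from inversefed.data.loss import Classification

    loss_fn = Classification()
    logits = torch.randn(8, 10)              # fake batch: 8 samples, 10 classes
    labels = torch.randint(0, 10, (8,))

    value, name, fmt = loss_fn(logits, labels)                 # 'CrossEntropy'
    acc, metric, metric_fmt = loss_fn.metric(logits, labels)   # 'Accuracy'
    print(f'{name}: {value.item():{fmt}}, {metric}: {acc.item():{metric_fmt}}')

Note on the model_utils.py hunk: moving the `loss_steps` import from module scope into the `gia` branch means `inversefed` is only imported when gradient inversion is actually requested; the motivation (avoiding an import-time dependency or cycle between `utils.model_utils` and `inversefed.reconstruction_algorithms`) is inferred from the diff, not stated in it.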