diff --git a/src/algos/base_class.py b/src/algos/base_class.py
index 78ee9bde..c69ddb6f 100644
--- a/src/algos/base_class.py
+++ b/src/algos/base_class.py
@@ -24,6 +24,7 @@
     get_dataset,
     non_iid_balanced,
     balanced_subset,
+    gia_client_dataset,
    CacheDataset,
    TransformDataset,
    CorruptDataset,
@@ -37,6 +38,7 @@
 )
 from utils.types import ConfigType
 from utils.dropout_utils import NodeDropout
+from utils.gias import gia_main
 
 import torchvision.transforms as T  # type: ignore
 import os
@@ -119,6 +121,9 @@ def __init__(
         dropout_rng = random.Random(dropout_seed)
         self.dropout = NodeDropout(self.node_id, config["dropout_dicts"], dropout_rng)
 
+        if "gia" in config and self.node_id in config["gia_attackers"]:
+            self.gia_attacker = True
+
         self.log_memory = config.get("log_memory", False)
 
         self.stats : Dict[str, int | float | List[int]] = {}
@@ -205,13 +210,15 @@ def set_model_parameters(self, config: Dict[str, Any]) -> None:
             optim = torch.optim.SGD
         else:
             raise ValueError(f"Unknown optimizer: {optim_name}.")
+        # if "gia" in config:
+        #     print("setting optim to gia")
+        #     optim = torch.optim.SGD
         num_classes = self.dset_obj.num_cls
         num_channels = self.dset_obj.num_channels
         self.model = self.model_utils.get_model(
-            config["model"],
-            self.dset,
-            self.device,
-            self.device_ids,
+            model_name=config["model"],
+            dset=self.dset,
+            device=self.device,
             num_classes=num_classes,
             num_channels=num_channels,
             pretrained=config.get("pretrained", False),
@@ -271,6 +278,11 @@ def get_model_weights(self) -> Dict[str, Tensor]:
         """
         message = {"sender": self.node_id, "round": self.round, "model": self.model.state_dict()}
 
+        if "gia" in self.config and hasattr(self, 'images') and hasattr(self, 'labels'):
+            # also stream images and labels
+            message["images"] = self.images
+            message["labels"] = self.labels
+
         # Move to CPU before sending
         for key in message["model"].keys():
             message["model"][key] = message["model"][key].to("cpu")
@@ -442,6 +454,17 @@ def __init__(
         super().__init__(config, comm_utils)
         self.server_node = 0
         self.set_parameters(config)
+        if "gia" in config:
+            if int(self.node_id) in self.config["gia_attackers"]:
+                self.gia_attacker = True
+            self.params_s = dict()
+            self.params_t = dict()
+            # Track neighbor updates with a dictionary mapping neighbor_id to their updates
+            self.neighbor_updates = defaultdict(list)
+            # Track which neighbors we've already attacked
+            self.attacked_neighbors = set()
+
+        self.base_params = [key for key, _ in self.model.named_parameters()]
 
     def set_parameters(self, config: Dict[str, Any]) -> None:
         """
@@ -457,175 +480,216 @@ def set_parameters(self, config: Dict[str, Any]) -> None:
         self.set_shared_exp_parameters(config)
         self.set_data_parameters(config)
 
+        # After setting the data loaders, save the client dataset.
+        # TODO: verify that the .data and .labels fields are correct
+        if "gia" in config:
+            # Extract data and labels
+            train_data = torch.stack([data[0] for data in self.train_dset])
+            train_labels = torch.tensor([data[1] for data in self.train_dset])
+
+            self.log_utils.log_gia_image(train_data,
+                                         train_labels,
+                                         self.node_id)
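Note: `gia_client_dataset` is imported above but its implementation is not part of this patch. A minimal sketch of the contract the GIA code paths below rely on, ten distinct labels with exactly one train and one test image per label; the body and return order here are inferred from the call sites and are an assumption, not the actual utils.data implementation:

    import random
    from typing import List, Tuple
    from torch.utils.data import Dataset, Subset

    def gia_client_dataset(train_dset: Dataset, test_dset: Dataset,
                           num_labels: int = 10) -> Tuple[Subset, Subset, List[int], List[int]]:
        # Pick `num_labels` distinct classes, then one train and one test sample per class.
        all_labels = sorted({label for _, label in train_dset})
        classes = random.sample(all_labels, num_labels)
        train_indices = [next(i for i, (_, y) in enumerate(train_dset) if y == cls) for cls in classes]
        test_indices = [next(i for i, (_, y) in enumerate(test_dset) if y == cls) for cls in classes]
        return Subset(train_dset, train_indices), Subset(test_dset, test_indices), classes, train_indices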
     def set_data_parameters(self, config: ConfigType) -> None:
         # Train set and test set from original dataset
         train_dset = self.dset_obj.train_dset
         test_dset = self.dset_obj.test_dset
 
-        # print("num train", len(train_dset))
-        # print("num test", len(test_dset))
-
-        if config.get("test_samples_per_class", None) is not None:
-            test_dset, _ = balanced_subset(test_dset, config["test_samples_per_class"])
-
-        samples_per_user = config["samples_per_user"]
-        batch_size: int = config["batch_size"]  # type: ignore
-        print(f"samples per user: {samples_per_user}, batch size: {batch_size}")
-
-        # Support user specific dataset
-        if isinstance(config["dset"], dict):
-
-            def is_same_dest(dset):
-                # Consider all variations of cifar10 as the same dataset
-                # To avoid having exactly same original dataset (without
-                # considering transformation) on multiple users
-                if self.dset == "cifar10" or self.dset.startswith("cifar10_"):
-                    return dset == "cifar10" or dset.startswith("cifar10_")
-                else:
-                    return dset == self.dset
-
-            users_with_same_dset = sorted(
-                [int(k) for k, v in config["dset"].items() if is_same_dest(v)]
-            )
-        else:
-            users_with_same_dset = list(range(1, config["num_users"] + 1))
-        user_idx = users_with_same_dset.index(self.node_id)
-
-        cls_prior = None
-        # If iid, each user has random samples from the whole dataset (no
-        # overlap between users)
-        if config["train_label_distribution"] == "iid":
-            indices = np.random.permutation(len(train_dset))
-            train_indices = indices[
-                user_idx * samples_per_user : (user_idx + 1) * samples_per_user
-            ]
-            train_dset = Subset(train_dset, train_indices)
-            classes = list(set([train_dset[i][1] for i in range(len(train_dset))]))
-        # If non_iid, each user get random samples from its support classes
-        # (mulitple users might have same images)
-        elif config["train_label_distribution"] == "support":
-            classes = config["support"][str(self.node_id)]
-            support_classes_dataset, indices = filter_by_class(train_dset, classes)
-            train_dset, sel_indices = random_samples(
-                support_classes_dataset, samples_per_user
-            )
-            train_indices = [indices[i] for i in sel_indices]
-        elif config["train_label_distribution"].endswith("non_iid"):
-            alpha = config.get("alpha_data", 0.4)
-            if config["train_label_distribution"] == "inter_domain_non_iid":
-                # Hack to get the same class prior for all users with the same dataset
-                # While keeping the same random state for all users
-                if isinstance(config["dset"], dict) and isinstance(
-                    config["dset"], dict
-                ):
-                    cls_priors = []
-                    dsets = list(config["dset"].values())
-                    for _ in dsets:
-                        n_cls = self.dset_obj.num_cls
-                        cls_priors.append(
-                            np.random.dirichlet(
-                                alpha=[alpha] * n_cls, size=len(users_with_same_dset)
-                            )
-                        )
-                    cls_prior = cls_priors[dsets.index(self.dset)]
-            train_y, train_idx_split, cls_prior = non_iid_balanced(
-                self.dset_obj,
-                len(users_with_same_dset),
-                samples_per_user,
-                alpha,
-                cls_priors=cls_prior,
-                is_train=True,
-            )
-            train_indices = train_idx_split[self.node_id - 1]
-            train_dset = Subset(train_dset, train_indices)
-            classes = np.unique(train_y[user_idx]).tolist()
-            # One plot per dataset
-            # if user_idx == 0:
-            #     print("using non_iid_balanced", alpha)
-            #     self.plot_utils.plot_training_distribution(train_y,
-            #                                                self.dset, users_with_same_dset)
-        elif config["train_label_distribution"] == "shard":
-            raise NotImplementedError
-            # classes_per_user = config["shards"]["classes_per_user"]
-            # samples_per_shard = samples_per_user // classes_per_user
-            # train_dset = build_shards_dataset(train_dset, samples_per_shard,
-            #                                   classes_per_user, self.node_id)
-        else:
-            raise ValueError(
-                "Unknown train label distribution: {}.".format(
-                    config["train_label_distribution"]
-                )
-            )
-
-        if self.dset.startswith("domainnet"):
-            train_transform = T.Compose(
-                [
-                    T.RandomResizedCrop(32, scale=(0.75, 1)),
-                    T.RandomHorizontalFlip(),
-                    # T.ToTensor()
-                ]
-            )
-
-            # Cache before transform to preserve transform randomness
-            train_dset = TransformDataset(CacheDataset(train_dset), train_transform)
-
-        if config.get("malicious_type", None) == "corrupt_data":
-            corruption_fn_name = config.get("corruption_fn", "gaussian_noise")
-            severity = config.get("corrupt_severity", 1)
-            train_dset = CorruptDataset(CacheDataset(train_dset), corruption_fn_name, severity)
-            print("created train dataset with corruption function: ", corruption_fn_name)
-
-        self.classes_of_interest = classes
-
-        val_prop = config.get("validation_prop", 0)
-        val_dset = None
-        if val_prop > 0:
-            val_size = int(val_prop * len(train_dset))
-            train_size = len(train_dset) - val_size
-            train_dset, val_dset = torch.utils.data.random_split(
-                train_dset, [train_size, val_size]
-            )
-            # self.val_dloader = DataLoader(val_dset, batch_size=batch_size*len(self.device_ids),
-            #                               shuffle=True)
-            self.val_dloader = DataLoader(val_dset, batch_size=batch_size, shuffle=True)
-
-        assert isinstance(train_dset, torch.utils.data.Dataset), "train_dset must be a Dataset"
-        self.train_indices = train_indices
-        self.train_dset = train_dset
-        self.dloader = DataLoader(train_dset, batch_size=batch_size, shuffle=True)  # type: ignore
-
-        if config["test_label_distribution"] == "iid":
-            pass
-        # If non_iid, each users ge the whole test set for each of its
-        # support classes
-        elif config["test_label_distribution"] == "support":
-            classes = config["support"][str(self.node_id)]
-            test_dset, _ = filter_by_class(test_dset, classes)
-        elif config["test_label_distribution"] == "non_iid":
-
-            test_y, test_idx_split, _ = non_iid_balanced(
-                self.dset_obj,
-                len(users_with_same_dset),
-                config["test_samples_per_user"],
-                is_train=False,
-            )
-
-            train_indices = test_idx_split[self.node_id - 1]
-            test_dset = Subset(test_dset, train_indices)
-        else:
-            raise ValueError(
-                "Unknown test label distribution: {}.".format(
-                    config["test_label_distribution"]
-                )
-            )
-
-        if self.dset.startswith("domainnet"):
-            test_dset = CacheDataset(test_dset)
-
-        self._test_loader = DataLoader(test_dset, batch_size=batch_size)
-        # TODO: fix print_data_summary
-        # self.print_data_summary(train_dset, test_dset, val_dset=val_dset)
config.get("test_samples_per_class", None) is not None: + test_dset, _ = balanced_subset(test_dset, config["test_samples_per_class"]) + + samples_per_user = config["samples_per_user"] + batch_size: int = config["batch_size"] # type: ignore + print(f"samples per user: {samples_per_user}, batch size: {batch_size}") + + # Support user specific dataset + if isinstance(config["dset"], dict): + + def is_same_dest(dset): + # Consider all variations of cifar10 as the same dataset + # To avoid having exactly same original dataset (without + # considering transformation) on multiple users + if self.dset == "cifar10" or self.dset.startswith("cifar10_"): + return dset == "cifar10" or dset.startswith("cifar10_") + else: + return dset == self.dset + + users_with_same_dset = sorted( + [int(k) for k, v in config["dset"].items() if is_same_dest(v)] + ) + else: + users_with_same_dset = list(range(1, config["num_users"] + 1)) + user_idx = users_with_same_dset.index(self.node_id) + + cls_prior = None + # If iid, each user has random samples from the whole dataset (no + # overlap between users) + if config["train_label_distribution"] == "iid": + indices = np.random.permutation(len(train_dset)) + train_indices = indices[ + user_idx * samples_per_user : (user_idx + 1) * samples_per_user + ] + train_dset = Subset(train_dset, train_indices) + classes = list(set([train_dset[i][1] for i in range(len(train_dset))])) + # If non_iid, each user get random samples from its support classes + # (mulitple users might have same images) + elif config["train_label_distribution"] == "support": + classes = config["support"][str(self.node_id)] + support_classes_dataset, indices = filter_by_class(train_dset, classes) + train_dset, sel_indices = random_samples( + support_classes_dataset, samples_per_user + ) + train_indices = [indices[i] for i in sel_indices] + elif config["train_label_distribution"].endswith("non_iid"): + alpha = config.get("alpha_data", 0.4) + if config["train_label_distribution"] == "inter_domain_non_iid": + # Hack to get the same class prior for all users with the same dataset + # While keeping the same random state for all users + if isinstance(config["dset"], dict) and isinstance( + config["dset"], dict + ): + cls_priors = [] + dsets = list(config["dset"].values()) + for _ in dsets: + n_cls = self.dset_obj.num_cls + cls_priors.append( + np.random.dirichlet( + alpha=[alpha] * n_cls, size=len(users_with_same_dset) + ) ) - ) - cls_prior = cls_priors[dsets.index(self.dset)] - train_y, train_idx_split, cls_prior = non_iid_balanced( - self.dset_obj, - len(users_with_same_dset), - samples_per_user, - alpha, - cls_priors=cls_prior, - is_train=True, - ) - train_indices = train_idx_split[self.node_id - 1] - train_dset = Subset(train_dset, train_indices) - classes = np.unique(train_y[user_idx]).tolist() - # One plot per dataset - # if user_idx == 0: - # print("using non_iid_balanced", alpha) - # self.plot_utils.plot_training_distribution(train_y, - # self.dset, users_with_same_dset) - elif config["train_label_distribution"] == "shard": - raise NotImplementedError - # classes_per_user = config["shards"]["classes_per_user"] - # samples_per_shard = samples_per_user // classes_per_user - # train_dset = build_shards_dataset(train_dset, samples_per_shard, - # classes_per_user, self.node_id) - else: - raise ValueError( - "Unknown train label distribution: {}.".format( - config["train_label_distribution"] + cls_prior = cls_priors[dsets.index(self.dset)] + train_y, train_idx_split, cls_prior = non_iid_balanced( + self.dset_obj, 
+                    len(users_with_same_dset),
+                    samples_per_user,
+                    alpha,
+                    cls_priors=cls_prior,
+                    is_train=True,
+                )
+                train_indices = train_idx_split[self.node_id - 1]
+                train_dset = Subset(train_dset, train_indices)
+                classes = np.unique(train_y[user_idx]).tolist()
+                # One plot per dataset
+                # if user_idx == 0:
+                #     print("using non_iid_balanced", alpha)
+                #     self.plot_utils.plot_training_distribution(train_y,
+                #                                                self.dset, users_with_same_dset)
+            elif config["train_label_distribution"] == "shard":
+                raise NotImplementedError
+                # classes_per_user = config["shards"]["classes_per_user"]
+                # samples_per_shard = samples_per_user // classes_per_user
+                # train_dset = build_shards_dataset(train_dset, samples_per_shard,
+                #                                   classes_per_user, self.node_id)
+            else:
+                raise ValueError(
+                    "Unknown train label distribution: {}.".format(
+                        config["train_label_distribution"]
+                    )
+                )
+
+            if self.dset.startswith("domainnet"):
+                train_transform = T.Compose(
+                    [
+                        T.RandomResizedCrop(32, scale=(0.75, 1)),
+                        T.RandomHorizontalFlip(),
+                        # T.ToTensor()
+                    ]
+                )
+
+                # Cache before transform to preserve transform randomness
+                train_dset = TransformDataset(CacheDataset(train_dset), train_transform)
+
+            if config.get("malicious_type", None) == "corrupt_data":
+                corruption_fn_name = config.get("corruption_fn", "gaussian_noise")
+                severity = config.get("corrupt_severity", 1)
+                train_dset = CorruptDataset(CacheDataset(train_dset), corruption_fn_name, severity)
+                print("created train dataset with corruption function: ", corruption_fn_name)
+
+            self.classes_of_interest = classes
+
+            val_prop = config.get("validation_prop", 0)
+            val_dset = None
+            if val_prop > 0:
+                val_size = int(val_prop * len(train_dset))
+                train_size = len(train_dset) - val_size
+                train_dset, val_dset = torch.utils.data.random_split(
+                    train_dset, [train_size, val_size]
+                )
+                # self.val_dloader = DataLoader(val_dset, batch_size=batch_size*len(self.device_ids),
+                #                               shuffle=True)
+                self.val_dloader = DataLoader(val_dset, batch_size=batch_size, shuffle=True)
+
+            assert isinstance(train_dset, torch.utils.data.Dataset), "train_dset must be a Dataset"
+            self.train_indices = train_indices
+            self.train_dset = train_dset
+            self.dloader = DataLoader(train_dset, batch_size=batch_size, shuffle=True)  # type: ignore
+
+            if config["test_label_distribution"] == "iid":
+                pass
+            # If non_iid, each user gets the whole test set for each of its
+            # support classes
+            elif config["test_label_distribution"] == "support":
+                classes = config["support"][str(self.node_id)]
+                test_dset, _ = filter_by_class(test_dset, classes)
+            elif config["test_label_distribution"] == "non_iid":
+
+                test_y, test_idx_split, _ = non_iid_balanced(
+                    self.dset_obj,
+                    len(users_with_same_dset),
+                    config["test_samples_per_user"],
+                    is_train=False,
+                )
+
+                train_indices = test_idx_split[self.node_id - 1]
+                test_dset = Subset(test_dset, train_indices)
+            else:
+                raise ValueError(
+                    "Unknown test label distribution: {}.".format(
+                        config["test_label_distribution"]
+                    )
+                )
+
+            if self.dset.startswith("domainnet"):
+                test_dset = CacheDataset(test_dset)
+
+            self._test_loader = DataLoader(test_dset, batch_size=batch_size)
+            # TODO: fix print_data_summary
+            # self.print_data_summary(train_dset, test_dset, val_dset=val_dset)
 
     def local_train(self, round: int, epochs: int = 1, **kwargs: Any) -> Tuple[float, float, float]:
         """
@@ -639,7 +703,7 @@
         avg_loss, avg_acc = 0, 0
         for _ in range(epochs):
             tr_loss, tr_acc = self.model_utils.train(
-                self.model, self.optim, self.dloader, self.loss_fn, self.device, malicious_type=self.config.get("malicious_type", "normal"), config=self.config,
+                self.model, self.optim, self.dloader, self.loss_fn, self.device, malicious_type=self.config.get("malicious_type", "normal"), config=self.config, node_id=self.node_id, gia=self.config.get("gia", False)
             )
             avg_loss += tr_loss
             avg_acc += tr_acc
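Note: the client-side attack below needs exactly two model snapshots per victim because, for plain SGD, the weight difference between consecutive updates is a scaled sum of the victim's gradients; this is presumably also why the GIA configs later in this patch pin the optimizer to "sgd". A self-contained sanity check of that identity (one local step, known learning rate):

    import torch
    import torch.nn as nn

    torch.manual_seed(0)
    model, lr = nn.Linear(8, 2), 0.1
    opt = torch.optim.SGD(model.parameters(), lr=lr)
    before = [p.detach().clone() for p in model.parameters()]

    loss = model(torch.randn(4, 8)).pow(2).mean()
    opt.zero_grad()
    loss.backward()
    grads = [p.grad.detach().clone() for p in model.parameters()]
    opt.step()

    for p, p0, g in zip(model.parameters(), before, grads):
        assert torch.allclose((p0 - p.detach()) / lr, g, atol=1e-6)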
@@ -693,6 +757,67 @@
             assert "model" in repr, "Model not found in the received message"
             self.set_model_weights(repr["model"])
 
+    def receive_attack_and_aggregate(self, neighbors: List[int], round: int, num_neighbors: int) -> None:
+        """
+        Receives updates and launches a GIA attack once the second update from a neighbor is seen.
+        """
+        print("CLIENT RECEIVING ATTACK AND AGGREGATING")
+        if self.is_working:
+            # Receive the model updates from the neighbors
+            model_updates = self.comm_utils.receive(node_ids=neighbors)
+            assert len(model_updates) == num_neighbors
+
+            for neighbor_info in model_updates:
+                neighbor_id = neighbor_info["sender"]
+                neighbor_model = neighbor_info["model"]
+                neighbor_model = OrderedDict(
+                    (key, value) for key, value in neighbor_model.items()
+                    if key in self.base_params
+                )
+
+                neighbor_images = neighbor_info["images"]
+                neighbor_labels = neighbor_info["labels"]
+
+                # Store this update
+                self.neighbor_updates[neighbor_id].append({
+                    "model": neighbor_model,
+                    "images": neighbor_images,
+                    "labels": neighbor_labels
+                })
+
+                # Check if we have 2 updates from this neighbor and haven't attacked them yet
+                if len(self.neighbor_updates[neighbor_id]) == 2 and neighbor_id not in self.attacked_neighbors:
+                    print(f"Client {self.node_id} attacking {neighbor_id}!")
+
+                    # Get the two parameter sets for the attack
+                    p_s = self.neighbor_updates[neighbor_id][0]["model"]
+                    p_t = self.neighbor_updates[neighbor_id][1]["model"]
+
+                    # Launch the attack
+                    if result := gia_main(p_s,
+                                          p_t,
+                                          self.base_params,
+                                          self.model,
+                                          neighbor_labels,
+                                          neighbor_images,
+                                          self.node_id):
+                        output, stats = result
+
+                        # Log the reconstruction and its stats as an image
+                        self.log_utils.log_gia_image(output, neighbor_labels, neighbor_id, label=f"round_{round}_reconstruction")
+                        self.log_utils.log_summary(f"round {round} gia targeting {neighbor_id} stats: {stats}")
+                    else:
+                        self.log_utils.log_summary(f"Client {self.node_id} failed to attack {neighbor_id} in round {round}!")
+                        print(f"Client {self.node_id} failed to attack {neighbor_id}!")
+                        continue
+
+                    # Mark this neighbor as attacked
+                    self.attacked_neighbors.add(neighbor_id)
+
+                    # Optionally, clear the stored updates to save memory
+                    del self.neighbor_updates[neighbor_id]
+
+            self.aggregate(model_updates, keys_to_ignore=self.model_keys_to_ignore)
 
     def receive_pushed_and_aggregate(self, remove_multi = True) -> None:
         model_updates = self.comm_utils.receive_pushed()
@@ -714,6 +839,7 @@
         else:
             print("No one pushed model updates for this round.")
 
+
     def run_protocol(self) -> None:
         raise NotImplementedError
@@ -776,8 +902,11 @@ def set_data_parameters(self, config: Dict[str, Any]) -> None:
         """Add docstring here"""
         test_dset = self.dset_obj.test_dset
         batch_size = config["batch_size"]
-        self._test_loader = DataLoader(test_dset, batch_size=batch_size)
-
+        if "gia" not in config:
+            self._test_loader = DataLoader(test_dset, batch_size=batch_size)
+        else:
+            _, test_data, labels, indices = gia_client_dataset(self.dset_obj.train_dset, test_dset)
+            self._test_loader = DataLoader(test_data, batch_size=10)
 
     def aggregate(
         self, representation_list: List[OrderedDict[str, Any]], **kwargs: Any
     ) -> OrderedDict[str, Tensor]:
@@ -801,7 +930,6 @@ def get_model(self, **kwargs: Any) -> Any:
 
     def run_protocol(self) -> None:
         raise NotImplementedError
 
-
 class CommProtocol(object):
     """
     Communication protocol tags for the server and users
@@ -839,6 +967,7 @@ def __init__(
         keys = self.model_utils.get_last_layer_keys(self.get_model_weights())
         self.model_keys_to_ignore.extend(keys)
 
+
     def local_test(self, **kwargs: Any) -> Tuple[float, float]:
         """
         Test the model locally, not to be used in the traditional FedAvg
@@ -1013,6 +1142,8 @@ def receive_and_aggregate_streaming(self, neighbors: List[int]) -> None:
         self.set_model_weights(agg_wts)
 
     def receive_and_aggregate(self, neighbors: List[int]) -> None:
+        if hasattr(self, "gia_attacker"):
+            self.receive_attack_and_aggregate(neighbors, self.round, len(neighbors))
+            return
         if self.streaming_aggregation:
             self.receive_and_aggregate_streaming(neighbors)
         else:
@@ -1022,7 +1153,6 @@
             model_updates = self.comm_utils.receive(node_ids=neighbors)
             # Aggregate the representations
             self.aggregate(model_updates, keys_to_ignore=self.model_keys_to_ignore)
 
-
     def get_collaborator_weights(
         self, reprs_dict: Dict[int, OrderedDict[int, Tensor]]
     ) -> Dict[int, float]:
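Note: `base_params` is built from `named_parameters()` while the received payload is a `state_dict()`, so the OrderedDict filtering above strips buffers such as BatchNorm running statistics before the attack operates on the weights. A small demonstration of what gets dropped:

    import torch.nn as nn
    from collections import OrderedDict

    model = nn.Sequential(nn.Conv2d(3, 4, 3), nn.BatchNorm2d(4))
    base_params = [k for k, _ in model.named_parameters()]      # weights and biases only
    state = model.state_dict()                                  # also contains running_mean/var, etc.
    filtered = OrderedDict((k, v) for k, v in state.items() if k in base_params)
    print(sorted(set(state) - set(filtered)))
    # ['1.num_batches_tracked', '1.running_mean', '1.running_var']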
diff --git a/src/algos/fl.py b/src/algos/fl.py
index 5bf68637..320bb413 100644
--- a/src/algos/fl.py
+++ b/src/algos/fl.py
@@ -11,12 +11,17 @@
 from algos.attack_bad_weights import BadWeightsAttack
 from algos.attack_sign_flip import SignFlipAttack
 
+from utils.gias import gia_main
+
+import pickle
+
 
 class FedAvgClient(BaseClient):
     def __init__(
         self, config: Dict[str, Any], comm_utils: CommunicationManager
     ) -> None:
         super().__init__(config, comm_utils)
         self.config = config
+        self.random_params = self.model.state_dict()
 
     def local_test(self, **kwargs: Any) -> Tuple[float, float, float]:
         """
@@ -68,6 +73,16 @@ def get_model_weights(self, **kwargs: Any) -> Dict[str, Any]:
         # move the model to cpu before sending
         for key in message["model"].keys():
             message["model"][key] = message["model"][key].to("cpu")
+
+        # assert hasattr(self, 'images') and hasattr(self, 'labels'), "Images and labels not found"
+        if "gia" in self.config and hasattr(self, 'images') and hasattr(self, 'labels'):
+            # also stream images and labels
+            message["images"] = self.images.to("cpu")
+            message["labels"] = self.labels.to("cpu")
+
+            message["random_params"] = self.random_params
+            for key in message["random_params"].keys():
+                message["random_params"][key] = message["random_params"][key].to("cpu")
 
         return message  # type: ignore
@@ -155,16 +170,73 @@ def test(self, **kwargs: Any) -> Tuple[float, float, float]:
         self.stats["test_loss"], self.stats["test_acc"], self.stats["test_time"] = test_loss, test_acc, time_taken
         return test_loss, test_acc, time_taken
 
+    def receive_attack_and_aggregate(self, round: int, attack_start_round: int, attack_end_round: int, dump_file_name: str = ""):
+        reprs = self.comm_utils.all_gather()
+
+        if dump_file_name:
+            with open(dump_file_name, "wb") as f:
+                pickle.dump(reprs, f)
+
+        # Handle GIA-specific logic
+        if "gia" in self.config:
+            print("Server Running GIA attack")
+            base_params = [key for key, _ in self.model.named_parameters()]
+            print(base_params)
+
+            for rep in reprs:
+                client_id = rep["sender"]
+                assert "images" in rep and "labels" in rep, "Images and labels not found in representation"
+                model_state_dict = rep["model"]
+
+                # Extract relevant model parameters
+                model_params = OrderedDict(
+                    (key, value) for key, value in model_state_dict.items()
+                    if key in base_params
+                )
+
+                random_params = rep["random_params"]
+                random_params = OrderedDict(
+                    (key, value) for key, value in random_params.items()
+                    if key in base_params
+                )
+
+                # Store parameters based on attack start and end rounds
+                if round == attack_start_round:
+                    self.params_s[client_id - 1] = model_params
+                elif round == attack_end_round:
+                    self.params_t[client_id - 1] = model_params
+                    images = rep["images"]
+                    labels = rep["labels"]
+
+                    # Launch GIA attack
+                    p_s, p_t = self.params_s[client_id - 1], self.params_t[client_id - 1]
+                    gia_main(p_s, p_t, base_params, self.model, labels, images, client_id)
+
+        avg_wts = self.aggregate(reprs)
+        self.set_representation(avg_wts)
+
 
     def receive_and_aggregate(self):
         reprs = self.comm_utils.all_gather()
         avg_wts = self.aggregate(reprs)
         self.set_representation(avg_wts)
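Note: the server-side variant above snapshots each client's parameters at `attack_start_round` and attacks once the `attack_end_round` snapshot arrives. The round bookkeeping in isolation (a sketch assuming the two-round window used by single_round's defaults below):

    params_s, params_t = {}, {}

    def on_client_update(rnd, client_id, model_params, start=0, end=1):
        """Return the (start, end) snapshot pair once it is complete."""
        if rnd == start:
            params_s[client_id] = model_params
        elif rnd == end:
            params_t[client_id] = model_params
            return params_s[client_id], params_t[client_id]  # ready for gia_main
        return None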
""" - self.receive_and_aggregate() + + # Normal training when outside the attack range + + if round < attack_start_round or round > attack_end_round: + self.receive_and_aggregate() + else: + self.receive_attack_and_aggregate(round, attack_start_round, attack_end_round, dump_file_name) + def run_protocol(self): print(f"Client {self.node_id} ready to start training") diff --git a/src/algos/fl_inversionAttack.py b/src/algos/fl_inversionAttack.py new file mode 100644 index 00000000..17e6d95b --- /dev/null +++ b/src/algos/fl_inversionAttack.py @@ -0,0 +1,460 @@ +import numpy as np +import networkx as nx +import matplotlib.pyplot as plt +import scipy as sp +from typing import Any, Dict, List +import torch +from fractions import Fraction +import random + +from utils.communication.comm_utils import CommunicationManager +from utils.log_utils import LogUtils +from algos.fl import FedAvgClient, FedAvgServer +from algos.fl_static import FedStaticNode, FedStaticServer + +import inversefed + +def LaplacianGossipMatrix(G): + max_degree = max([G.degree(node) for node in G.nodes()]) + W = np.eye(G.number_of_nodes()) - 1/max_degree * nx.laplacian_matrix(G).toarray() + return W + +def get_non_attackers_neighbors(G, attackers): + """ + G : networkx graph + attackers : list of the nodes considered as attackers + returns : non repetetive list of the neighbors of the attackers + """ + return sorted(set(n for attacker in attackers for n in G.neighbors(attacker)).difference(set(attackers))) + +def GLS(X, y, cov): + """ + Returns the generalized least squares estimator b, such as + Xb = y + e + e being a noise of covariance matrix cov + """ + X_n, X_m = X.shape + y_m = len(y) + s_n = len(cov) + assert s_n == X_n, "Dimension mismatch" + try: + inv_cov = np.linalg.inv(cov) + except Exception as e: + print("WARNING : The covariance matrix is not invertible, using pseudo inverse instead") + inv_cov = np.linalg.pinv(cov) + return np.linalg.inv(X.T@inv_cov@X)@ X.T@inv_cov@y + +class ReconstructOptim(): + def __init__(self, G, n_iter, attackers, gossip_matrix = LaplacianGossipMatrix, targets_only = False): + """ + A class to reconstruct the intial values used in a decentralized parallel gd algorithm + This class depends only on the graph and the attack parameters n_iter and attackers + It doesn't depend on the actual updates of one particular execution + G: networkx graph, we require the nodes to be indexed from 0 to n-1 + n_iter: number of gossip iterations n_iter >= 1 + attackers: indices of the attacker nodes + gossip_matrix: function that returns the gossip matrix of the graph + + same script as https://github.com/AbdellahElmrini/decAttack/tree/master + """ + self.G = G + self.n_iter = n_iter + self.attackers = attackers + self.n_attackers = len(attackers) + self.W = gossip_matrix(self.G) + self.Wt = torch.tensor(self.W, dtype = torch.float64) + self.build_knowledge_matrix_dec() + + def build_knowledge_matrix_dec(self, centralized=False): + """ + Building a simplified knowledge matrix including only the targets as unknowns + This matrix encodes the system of equations that the attackers receive during the learning + We assume that the n_a attackers appear in the beginning of the gossip matrix + returns : + knowledge_matrix : A matrix of shape m * n, where m = self.n_iter*len(neighbors), n = number of targets + """ + if not centralized: + W = self.W + att_matrix = [] + n_targets = len(self.W) - self.n_attackers + for neighbor in get_non_attackers_neighbors(self.G, self.attackers): + 
+
+class ReconstructOptim():
+    def __init__(self, G, n_iter, attackers, gossip_matrix=LaplacianGossipMatrix, targets_only=False):
+        """
+        A class to reconstruct the initial values used in a decentralized parallel GD algorithm.
+        This class depends only on the graph and the attack parameters n_iter and attackers;
+        it does not depend on the actual updates of one particular execution.
+        G: networkx graph, we require the nodes to be indexed from 0 to n-1
+        n_iter: number of gossip iterations, n_iter >= 1
+        attackers: indices of the attacker nodes
+        gossip_matrix: function that returns the gossip matrix of the graph
+
+        Same script as https://github.com/AbdellahElmrini/decAttack/tree/master
+        """
+        self.G = G
+        self.n_iter = n_iter
+        self.attackers = attackers
+        self.n_attackers = len(attackers)
+        self.W = gossip_matrix(self.G)
+        self.Wt = torch.tensor(self.W, dtype=torch.float64)
+        self.build_knowledge_matrix_dec()
+
+    def build_knowledge_matrix_dec(self, centralized=False):
+        """
+        Build a simplified knowledge matrix including only the targets as unknowns.
+        This matrix encodes the system of equations that the attackers receive during the learning.
+        We assume that the n_a attackers appear in the beginning of the gossip matrix.
+        returns :
+            knowledge_matrix : A matrix of shape m * n, where m = self.n_iter * len(neighbors), n = number of targets
+        """
+        if not centralized:
+            W = self.W
+            att_matrix = []
+            n_targets = len(self.W) - self.n_attackers
+            for neighbor in get_non_attackers_neighbors(self.G, self.attackers):
+                att_matrix.append(np.eye(1, n_targets, neighbor - self.n_attackers)[0])  # Shifting the index of the neighbor to start from 0
+
+            pW_TT = np.identity(n_targets)
+
+            for _ in range(1, self.n_iter):
+                pW_TT = W[self.n_attackers:, self.n_attackers:] @ pW_TT + np.identity(n_targets)
+                for neighbor in get_non_attackers_neighbors(self.G, self.attackers):
+                    att_matrix.append(pW_TT[neighbor - self.n_attackers])  # Assuming this neighbor is not an attacker
+
+            self.target_knowledge_matrix = np.array(att_matrix)
+            return self.target_knowledge_matrix
+        else:
+            # Simplify for centralized FL: no gossip matrix, direct aggregation from clients
+            n_targets = len(self.W) - self.n_attackers  # Number of clients (non-attackers)
+
+            att_matrix = []
+            for client in range(n_targets):
+                att_matrix.append(np.eye(1, n_targets, client)[0])  # Identity matrix row for each client
+
+            self.target_knowledge_matrix = np.array(att_matrix)
+            return self.target_knowledge_matrix
+
+    def build_cov_target_only(self, sigma):  # NewName: build_covariance_matrix
+        """
+        Build the covariance matrix of the system of equations received by the attackers.
+        The number of columns corresponds to the number of targets in the system.
+        See the pseudo code at algorithm 6 in the report.
+        return :
+            cov : a matrix of size m * m, where m = self.n_iter * len(neighbors)
+        """
+        W = self.W
+        W_TT = W[self.n_attackers:, self.n_attackers:]
+        neighbors = get_non_attackers_neighbors(self.G, self.attackers)
+
+        m = self.n_iter * len(neighbors)
+
+        cov = np.zeros((m, m))
+        # We iteratively fill this matrix line by line in a triangular fashion (as it is a symmetric matrix)
+        i = 0
+        while i < m:
+            for it1 in range(self.n_iter):
+                for neighbor1 in neighbors:
+                    j = it1 * len(neighbors)
+                    for it2 in range(it1, self.n_iter):
+                        for neighbor2 in neighbors:
+                            s = 0
+                            for t in range(it1 + 1):
+                                s += np.linalg.matrix_power(W_TT, it1 + it2 - 2 * t)[neighbor1, neighbor2]
+                            cov[i, j] = sigma**2 * s
+                            cov[j, i] = cov[i, j]
+                            j += 1
+                    i += 1
+        return cov
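Note: a toy instantiation of the knowledge matrix on a 3-node path graph with node 0 as the single attacker; every message the attacker intercepts is a known linear combination of the two targets' initial values, one row per (iteration, neighbor) pair:

    import networkx as nx

    G = nx.path_graph(3)                      # nodes 0-1-2; attackers must come first in the indexing
    R = ReconstructOptim(G, n_iter=2, attackers=[0])
    print(R.target_knowledge_matrix)
    # row 0: the first message from neighbor 1 is its own value      -> [1.0, 0.0]
    # row 1: after one gossip step it has mixed in half of target 2  -> [1.0, 0.5]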
+    def reconstruct_GLS_target_only(self, v, X_A, sigma):
+        """
+        Reconstruct the initial gradients from the values received by the attackers after self.n_iter iterations.
+        This method uses the GLS estimator.
+        v (nd.array) : vector containing the values received by the attackers (in the order defined by the gossip)
+        sigma (float) : variance
+        returns :
+            x_hat : a vector of shape n * v.shape[1], where n is the number of nodes
+        """
+        cov = self.build_cov_target_only(sigma)
+        n_targets = len(self.W) - self.n_attackers
+        neighbors = np.array(get_non_attackers_neighbors(self.G, self.attackers))
+        n_neighbors = len(neighbors)
+        v = v[self.n_attackers:]  # v[:self.n_attackers] are the attacker sent updates, which are the same as X_A[:self.n_attackers]
+        d = v[0].shape[0]
+        W_TA = self.Wt[self.n_attackers:, :self.n_attackers]
+        W_TT = self.Wt[self.n_attackers:, self.n_attackers:]
+        pW_TT = np.identity(n_targets, dtype=np.float64)
+        new_v = []
+        B_t = np.zeros((n_targets, d), dtype=np.float64)
+        for it in range(self.n_iter):
+            X_A_t = X_A[it*self.n_attackers:(it+1)*self.n_attackers]
+            pW_TT = W_TT @ pW_TT + np.identity(n_targets, dtype=np.float64)
+            theta_T_t = v[it*n_neighbors:(it+1)*n_neighbors]
+            new_v.extend(theta_T_t - B_t[neighbors - self.n_attackers])
+            B_t = W_TT @ B_t + W_TA @ X_A_t
+        v = np.array(new_v)
+        try:
+            return GLS(self.target_knowledge_matrix, v, cov)
+        except Exception as e:
+            print(e)
+            print("Building the knowledge matrix failed")
+            raise
+
+    def reconstruct_LS_target_only(self, v, X_A):
+        """
+        Reconstruct the initial gradients from the values received by the attackers after self.n_iter iterations.
+        This method uses a least squares estimator.
+        v (nd.array) : vector containing the values received by the attackers (in the order defined by the gossip).
+            v looks like (X_A^0, \theta_T^{0+}, X_A^1, \theta_T^{1+}, ..., X_A^T, \theta_T^{T+}),
+            where X_A^t are the attacker sent updates at iteration t and \theta_T^{t+} are the target sent updates at iteration t
+        X_A (nd.array) : vector of size n_a * self.n_iter, containing the attacker sent updates at each iteration
+        returns :
+            x_hat : a vector of shape n_target * v.shape[1], where n_target is the number of target nodes
+        """
+        # Preprocessing v to adapt it to the target-only knowledge matrix
+        n_targets = len(self.W) - self.n_attackers
+        neighbors = np.array(get_non_attackers_neighbors(self.G, self.attackers))
+        n_neighbors = len(neighbors)
+        v = v[self.n_attackers:]  # v[:self.n_attackers] are the attacker sent updates, which are the same as X_A[:self.n_attackers]
+        d = v[0].shape[0]
+        W_TA = self.Wt[self.n_attackers:, :self.n_attackers]
+        W_TT = self.Wt[self.n_attackers:, self.n_attackers:]
+        # pW_TT = np.identity(n_targets, dtype=np.float32)
+        new_v = []
+        B_t = np.zeros((n_targets, d), dtype=np.float64)
+        for it in range(self.n_iter):
+            X_A_t = X_A[it*self.n_attackers:(it+1)*self.n_attackers]
+            # pW_TT = W_TT @ pW_TT + np.identity(n_targets, dtype=np.float64)
+            theta_T_t = v[it*n_neighbors:(it+1)*n_neighbors]
+            new_v.extend(theta_T_t - B_t[neighbors - self.n_attackers])
+
+            B_t = W_TT @ B_t + W_TA @ X_A_t
+
+        v = torch.stack(new_v).numpy()
+
+        try:
+            return np.linalg.lstsq(self.target_knowledge_matrix, v)[0]
+        except Exception as e:
+            print(e)
+            print("Building the knowledge matrix failed")
+            raise
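Note: an end-to-end sanity check of the linear model behind both estimators; if the preprocessed observations satisfy v = K @ x for target values x, least squares recovers x exactly (toy graph from the previous note):

    import numpy as np

    K = R.target_knowledge_matrix             # shape (2, 2) on the toy path graph
    x = np.array([[1.0, 2.0], [3.0, 4.0]])    # two targets, d = 2
    v = K @ x
    x_hat, *_ = np.linalg.lstsq(K, v, rcond=None)
    assert np.allclose(x_hat, x)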
+
+class GradientInversionFedAvgClient(FedAvgClient):
+    """
+    Implements ground truth for evaluating the inversion attack.
+    """
+    def __init__(self, config: Dict[str, Any], node_id: int, comm: CommunicationManager, log: LogUtils):
+        super(GradientInversionFedAvgClient, self).__init__(config, node_id, comm, log)
+        # Get ground truth and labels for evaluation
+        self.ground_truth, self.labels = self.extract_ground_truth(num_images=config["num_images"])  # set reconstruction number
+
+        # TODO somehow get the server to access the ground truth and labels for evaluation
+        self.comm_utils.send(0, [self.ground_truth, self.labels])
+
+    def extract_ground_truth(self, num_images=10):
+        """
+        Randomly extract a batch of ground truth images and labels from self.dloader for gradient inversion attacks.
+
+        Args:
+            num_images (int): Number of images to extract.
+
+        Returns:
+            ground_truth (torch.Tensor): Tensor containing the extracted ground truth images.
+            labels (torch.Tensor): Tensor containing the corresponding labels.
+        """
+        # Convert the dataset to a list of (image, label) tuples
+        data = list(self.dloader.dataset)
+
+        # Randomly sample `num_images` images and labels
+        sampled_data = random.sample(data, num_images)
+
+        # Separate images and labels
+        ground_truth = [img for img, label in sampled_data]
+        labels = [torch.as_tensor((label,)) for img, label in sampled_data]
+
+        # Stack into tensors
+        ground_truth = torch.stack(ground_truth)
+        labels = torch.cat(labels)
+
+        return ground_truth, labels
+
+
+class GradientInversionFedAvgServer(FedAvgServer):
+    """
+    Implements a gradient inversion attack to reconstruct training images from other nodes.
+    """
+    def __init__(self, config: Dict[str, Any], comm: CommunicationManager, log: LogUtils):
+        super(GradientInversionFedAvgServer, self).__init__(config, comm, log)
+
+        # TODO somehow obtain the clients' ground truth and labels for evaluation
+        self.ground_truth, self.labels = self.obtain_ground_truth()  # should be one list per client
+
+    def obtain_ground_truth(self):
+        """
+        Obtain the ground truth images and labels from the clients for evaluation.
+        """
+        ground_truth, labels = [], []
+        client_list = self.comm_utils.receive([i for i in range(self.num_users)])
+        # TODO 1) sort the received items
+        # TODO 2) add tag to indicate we are receiving dummy data
+        for i in range(len(client_list)):
+            ground_truth_i, labels_i = client_list[i][:10], client_list[i][10:]
+            ground_truth.append(ground_truth_i)
+            labels.append(labels_i)
+        return ground_truth, labels
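Note: reconstructions come out in normalized space, so saving or plotting them (the "TODO somehow save output?" further below) requires mapping back through the dataset statistics. A small hypothetical helper, assuming the dm/ds tensors built in inverting_gradients_attack:

    import torch

    def denormalize(img: torch.Tensor, dm: torch.Tensor, ds: torch.Tensor) -> torch.Tensor:
        """Undo per-channel normalization and clamp to the valid pixel range."""
        return torch.clamp(img * ds + dm, 0, 1)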
+    def inverting_gradients_attack(self):
+        """
+        Set up the inversion attack for the server.
+
+        Based on the reconstruction-from-weights script:
+        https://github.com/JonasGeiping/invertinggradients/blob/1157b61c6704df42c497ab9eb074c75da5204334/Recovery%20from%20Weight%20Updates.ipynb
+        """
+        setup = inversefed.utils.system_startup()
+        if self.dset == "cifar10":
+            # TODO figure out whether we actually have the dm and ds values in our codebase
+            dm = torch.as_tensor(inversefed.consts.cifar10_mean, **setup)[:, None, None]
+            ds = torch.as_tensor(inversefed.consts.cifar10_std, **setup)[:, None, None]
+
+        # Extract input parameters (this should be the averaged server-side params after a round of FedAvg)
+        input_params_s = self.comm_utils.all_gather()  # [client1 params, client2 params, ...]
+        self.single_round()
+        input_params_t = self.comm_utils.all_gather()
+
+        # Get the param difference for each client [client1 param diff, client2 param diff, ...]
+        param_diffs = []
+
+        # Loop over each client's parameters (assumes input_params_s and input_params_t are lists of lists)
+        for client_params_s, client_params_t in zip(input_params_s, input_params_t):
+            client_param_diff = [
+                param_t - param_s  # element-wise difference of the tensors
+                for param_s, param_t in zip(client_params_s, client_params_t)
+            ]
+            param_diffs.append(client_param_diff)
+
+        assert len(param_diffs) == self.num_users == len(self.ground_truth), "Number of clients does not match number of param differences"
+
+        config = dict(signed=True,
+                      boxed=True,
+                      cost_fn='sim',
+                      indices='def',
+                      weights='equal',
+                      lr=0.1,
+                      optim='adam',
+                      restarts=1,
+                      max_iterations=8_000,
+                      total_variation=1e-6,
+                      init='randn',
+                      filter='none',
+                      lr_decay=True,
+                      scoring_choice='loss')
+
+        for client_i in range(self.num_users):
+            # TODO assume that client i corresponds to the order of received params
+            ground_truth_i, labels_i, params_i = self.ground_truth[client_i], self.labels[client_i], param_diffs[client_i]
+
+            local_steps = 1  # number of local steps for client training
+            local_lr = self.config["model_lr"]  # learning rate for client training
+            use_updates = False
+            rec_machine = inversefed.FedAvgReconstructor(self.model, (dm, ds), local_steps, local_lr, config,
+                                                         use_updates=use_updates)
+            output, stats = rec_machine.reconstruct(params_i, labels_i, img_shape=(3, 32, 32))  # TODO verify img_shape and change it based on dataset
+            test_mse = (output.detach() - ground_truth_i).pow(2).mean()
+            feat_mse = (self.model(output.detach()) - self.model(ground_truth_i)).pow(2).mean()
+            test_psnr = inversefed.metrics.psnr(output, ground_truth_i, factor=1/ds)
+
+            # Optional plotting:
+            # plot(output)
+            # plt.title(f"Rec. loss: {stats['opt']:2.4f} | MSE: {test_mse:2.4f} "
+            #           f"| PSNR: {test_psnr:4.2f} | FMSE: {feat_mse:2.4e} |")
+
+        return output, test_mse, test_psnr, feat_mse
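Note: the per-client loop above reports MSE, feature MSE, and PSNR; the same numbers can be reproduced directly with the metrics shipped in this patch (src/inversefed/metrics.py), e.g. on dummy tensors:

    import torch
    from inversefed.metrics import psnr, total_variation

    output, truth = torch.rand(10, 3, 32, 32), torch.rand(10, 3, 32, 32)
    print("MSE :", (output - truth).pow(2).mean().item())
    print("PSNR:", psnr(output, truth, factor=1.0))
    print("TV  :", total_variation(output).item())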
+    def run_protocol(self):
+        """
+        Basically a carbon copy of fl.py's run_protocol, except the attack is launched at the end.
+        """
+        self.log_utils.log_console("Starting clients federated averaging")
+        start_epochs = self.config.get("start_epochs", 0)
+        total_epochs = self.config["epochs"]
+        for round in range(start_epochs, total_epochs):
+
+            if round == total_epochs - 1:
+                self.log_utils.log_console("Launching inversion attack")
+                output, test_mse, test_psnr, feat_mse = self.inverting_gradients_attack()
+                self.log_utils.log_console("Inversion attack complete")
+                self.log_utils.log_summary(
+                    f"Round {round} inversion attack complete. Test MSE: {test_mse}, Test PSNR: {test_psnr}, Feature MSE: {feat_mse}"
+                )
+                # TODO somehow save output?
+
+            self.log_utils.log_console("Starting round {}".format(round))
+            self.log_utils.log_summary("Starting round {}".format(round))
+            self.single_round()
+
+            self.log_utils.log_console("Server testing the model")
+            loss, acc, time_taken = self.test()
+            self.log_utils.log_tb(f"test_acc/clients", acc, round)
+            self.log_utils.log_tb(f"test_loss/clients", loss, round)
+            self.log_utils.log_console(
+                "Round: {} test_acc:{:.4f}, test_loss:{:.4f}, time taken {:.2f} seconds".format(
+                    round, acc, loss, time_taken
+                )
+            )
+            # self.log_utils.log_summary("Round: {} test_acc:{:.4f}, test_loss:{:.4f}, time taken {:.2f} seconds".format(round, acc, loss, time_taken))
+            self.log_utils.log_console("Round {} complete".format(round))
+            self.log_utils.log_summary(
+                "Round {} complete".format(
+                    round,
+                )
+            )
+
+
+class GradientInversionFedStaticServer(FedStaticServer):
+    """
+    Implements a gradient inversion attack to reconstruct training images from other nodes;
+    can handle colluding neighbors.
+
+    Based on the method proposed by https://github.com/AbdellahElmrini/decAttack/tree/master
+
+    The reconstruction method uses InvertingGradients by Jonas Geiping: https://github.com/JonasGeiping/invertinggradients
+
+    The order of stacked params is just the keys of attacker / collaborator IDs in ascending order.
+    """
+    def __init__(self, config: Dict[str, Any], G: nx.Graph):
+        # Construct graph
+        # TODO need to recheck this instantiation depending on the graph implementation
+        # TODO keep a copy of the weights at the 0th round (when everyone finishes training)
+        self.G = G
+        self.neighbors = [i for i in range(self.num_users) if i != self.node_id]  # for the server, neighbors are all the clients
+        self.attackers = [self.node_id]  # for the server, the attacker is itself
+        self.end_round = config["rounds"]
+
+    def get_model_parameters(self, ids_list: List[int]):
+        """
+        Returns stacked model parameters,
+        modeled after Decentralized.get_model_params in the decAttack codebase.
+
+        TODO verify the actual params getting sent
+        """
+        param_from_collaborators = self.comm_utils.receive(ids_list)
+        params = [[] for _ in self.model.parameters()]
+
+        for i in range(len(self.neighbors)):
+            neighbor_id = self.neighbors[i]
+            for j, param in enumerate(param_from_collaborators[neighbor_id]):
+                params[j].append(param)
+
+        for j in range(len(params)):
+            params[j] = torch.stack(params[j])
+
+        return params
+
+    def get_node_weights(self):
+        """
+        Helper function that obtains the param updates of attackers.
+        Uses commProtocol to get the params from the nodes.
+
+        TODO double check that neighbors include attacking nodes as well
+        """
+        # Issue for FL where the server is the attacker: the attacker gradient is the averaged gradient from its neighbors
+        return self.get_model_parameters(self.neighbors), self.get_model_parameters(self.attackers)
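Note: launch_attack below reduces every node's update to a single flat vector before the linear reconstruction. The flatten-and-diff pattern in isolation:

    import torch
    import torch.nn as nn

    model = nn.Linear(4, 2)
    before = [p.detach().clone() for p in model.parameters()]
    # ... a local training step would modify the parameters here ...
    after = [p.detach() for p in model.parameters()]
    delta = torch.cat([(a - b).flatten() for a, b in zip(after, before)])
    print(delta.shape)  # torch.Size([10]): 4*2 weights + 2 biases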
+ """ + # Build reconstruction class: + R = ReconstructOptim(self.G, n_iter=1, attackers=self.attackers) + + # Initial parameters (before aggregation) + neighbor_params0, attacker_params0 = self.get_node_weights() + + # Run a single round of FedAVG to update server's representation + self.single_round() + + # Get the updated parameters after aggregation + neighbor_params_i, attacker_params_i = self.get_node_weights() + + # Collect the difference in parameters for attack + sent_params = [] + attacker_params = [] + + # In centralized FL, server is the attacker + for i in range(len(neighbor_params0)): # Loop over all clients + # Calculate the difference between the initial and updated parameters for neighbors (clients) + sent_params.append(torch.cat([(neighbor_params_i[j][i] - neighbor_params0[j][i]).flatten().detach() for j in range(self.n_params)]).cpu()) + + # For the server (attacker), compute the difference between its initial and updated parameters + attacker_params.append(torch.cat([(attacker_params_i[j] - attacker_params0[j]).flatten().detach() for j in range(self.n_params)]).cpu()) + + # Use the collected parameters to reconstruct the images + x_hat = R.reconstruct_LS_target_only(sent_params, attacker_params) + diff --git a/src/algos/fl_static.py b/src/algos/fl_static.py index 59d7baa7..9a476093 100644 --- a/src/algos/fl_static.py +++ b/src/algos/fl_static.py @@ -2,12 +2,15 @@ Module for FedStaticClient and FedStaticServer in Federated Learning. """ from typing import Any, Dict, OrderedDict, List +from collections import OrderedDict, defaultdict + from utils.communication.comm_utils import CommunicationManager import torch from algos.base_class import BaseFedAvgClient from algos.topologies.collections import select_topology +from utils.gias import gia_main class FedStaticNode(BaseFedAvgClient): """ @@ -25,7 +28,6 @@ def get_neighbors(self) -> List[int]: """ Returns a list of neighbours for the client. """ - neighbors = self.topology.sample_neighbours(self.num_collaborators) self.stats["neighbors"] = neighbors @@ -43,6 +45,7 @@ def run_protocol(self) -> None: ) total_rounds = self.config["rounds"] epochs_per_round = self.config.get("epochs_per_round", 1) + for it in range(start_round, total_rounds): self.round_init() @@ -52,6 +55,7 @@ def run_protocol(self) -> None: ) self.local_round_done() # Collect the representations from all other nodes from the server + neighbors = self.get_neighbors() # TODO: Log the neighbors self.receive_and_aggregate(neighbors) @@ -61,8 +65,6 @@ def run_protocol(self) -> None: self.round_finalize() - - class FedStaticServer(BaseFedAvgClient): """ Federated Static Server Class. It does not do anything. 
diff --git a/src/configs/algo_config.py b/src/configs/algo_config.py
index 2f65fdf3..3d859009 100644
--- a/src/configs/algo_config.py
+++ b/src/configs/algo_config.py
@@ -31,14 +31,25 @@ def get_malicious_types(malicious_config_list: List[ConfigType]) -> Dict[str, str]:
 traditional_fl: ConfigType = {
     # Collaboration setup
     "algo": "fedavg",
-    "rounds": 2,
-
+    "rounds": 5,
     # Model parameters
     "model": "resnet10",
     "model_lr": 3e-4,
     "batch_size": 256,
 }
 
+test_fl_inversion: ConfigType = {
+    # Collaboration setup
+    "algo": "fedavg",
+    "rounds": 5,
+    "optimizer": "sgd",
+    # Model parameters
+    "model": "resnet10",
+    "model_lr": 3e-4,
+    # "batch_size": 256,
+    "gia": True,
+}
+
 fedweight: ConfigType = {
     "algo": "fedweight",
     "num_rep": 1,
@@ -192,9 +203,10 @@
     # Collaboration setup
     "algo": "fedstatic",
     "topology": {"name": "watts_strogatz", "k": 3, "p": 0.2},  # type: ignore
-    "rounds": 20,
+    "rounds": 3,
 
     # Model parameters
+    "optimizer": "sgd",  # TODO: comment out for real training
     "model": "resnet10",
     "model_lr": 3e-4,
     "batch_size": 256,
@@ -350,5 +362,4 @@
     malicious_traditional_model_update_attack,
 ]
 
-
 default_config_list: List[ConfigType] = [traditional_fl]
diff --git a/src/configs/sys_config.py b/src/configs/sys_config.py
index 5ea4d69c..99a33fea 100644
--- a/src/configs/sys_config.py
+++ b/src/configs/sys_config.py
@@ -316,7 +316,7 @@ def get_digit_five_support(num_users: int, domains: List[str] = DIGIT_FIVE):
     "exp_keys": [],
 }
 
-num_users = 9
+num_users = 4
 
 dropout_dict = {
     "distribution_dict": {  # leave dict empty to disable dropout
@@ -335,6 +335,7 @@ def get_digit_five_support(num_users: int, domains: List[str] = DIGIT_FIVE):
 # for swift or fedavgpush, just modify the algo_configs list
 # for swift, synchronous should preferable be False
 gpu_ids = [2, 3, 5, 6]
+
 grpc_system_config: ConfigType = {
     "exp_id": "static",
     "num_users": num_users,
@@ -365,6 +366,26 @@ def get_digit_five_support(num_users: int, domains: List[str] = DIGIT_FIVE):
     }
 }
 
+grpc_system_config_gia: ConfigType = {
+    "exp_id": "static",
+    "num_users": num_users,
+    "num_collaborators": NUM_COLLABORATORS,
+    "comm": {"type": "GRPC", "synchronous": True, "peer_ids": ["localhost:50048"]},  # The super-node
+    "dset": CIFAR10_DSET,
+    "dump_dir": DUMP_DIR,
+    "dpath": CIAR10_DPATH,
+    "seed": 2,
+    "device_ids": get_device_ids(num_users, gpu_ids),
+    # "algos": get_algo_configs(num_users=num_users, algo_configs=default_config_list),  # type: ignore
+    "algos": get_algo_configs(num_users=num_users, algo_configs=[fedstatic]),  # type: ignore
+    "samples_per_user": 50000 // num_users,  # distributed equally
+    "train_label_distribution": "iid",
+    "test_label_distribution": "iid",
+    "exp_keys": [],
+    "dropout_dicts": dropout_dicts,
+    "gia": True,
+    "gia_attackers": [1]
+}
 
 current_config = grpc_system_config
 # current_config = mpi_system_config
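Note: how the two new keys are consumed (cf. the BaseNode changes at the top of this patch): a node becomes a GIA attacker iff "gia" is present and its id is listed. To actually run the experiment, current_config would also have to point at the new dict; the patch leaves grpc_system_config selected:

    config = {"gia": True, "gia_attackers": [1]}
    node_id = 1
    assert "gia" in config and node_id in config["gia_attackers"]

    # in sys_config.py:
    # current_config = grpc_system_config_gia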
diff --git a/src/inversefed/__init__.py b/src/inversefed/__init__.py
new file mode 100644
index 00000000..0dacb9c7
--- /dev/null
+++ b/src/inversefed/__init__.py
@@ -0,0 +1,20 @@
+"""Library of routines."""
+
+from inversefed import nn
+from inversefed.nn import construct_model, MetaMonkey
+
+from inversefed.data import construct_dataloaders
+from inversefed.training import train
+from inversefed import utils
+
+from .optimization_strategy import training_strategy
+
+
+from .reconstruction_algorithms import GradientReconstructor, FedAvgReconstructor
+
+from .options import options
+from inversefed import metrics
+
+__all__ = ['train', 'construct_dataloaders', 'construct_model', 'MetaMonkey',
+           'training_strategy', 'nn', 'utils', 'options',
+           'metrics', 'GradientReconstructor', 'FedAvgReconstructor']
diff --git a/src/inversefed/consts.py b/src/inversefed/consts.py
new file mode 100644
index 00000000..0154c38d
--- /dev/null
+++ b/src/inversefed/consts.py
@@ -0,0 +1,16 @@
+"""Setup constants, ymmv."""
+
+PIN_MEMORY = True
+NON_BLOCKING = False
+BENCHMARK = True
+MULTITHREAD_DATAPROCESSING = 4
+
+
+cifar10_mean = [0.4914672374725342, 0.4822617471218109, 0.4467701315879822]
+cifar10_std = [0.24703224003314972, 0.24348513782024384, 0.26158785820007324]
+cifar100_mean = [0.5071598291397095, 0.4866936206817627, 0.44120192527770996]
+cifar100_std = [0.2673342823982239, 0.2564384639263153, 0.2761504650115967]
+mnist_mean = (0.13066373765468597,)
+mnist_std = (0.30810782313346863,)
+imagenet_mean = [0.485, 0.456, 0.406]
+imagenet_std = [0.229, 0.224, 0.225]
diff --git a/src/inversefed/medianfilt.py b/src/inversefed/medianfilt.py
new file mode 100644
index 00000000..f0039da0
--- /dev/null
+++ b/src/inversefed/medianfilt.py
@@ -0,0 +1,54 @@
+"""This is code for median pooling from https://gist.github.com/rwightman.
+
+https://gist.github.com/rwightman/f2d3849281624be7c0f11c85c87c1598
+"""
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.modules.utils import _pair, _quadruple
+
+
+class MedianPool2d(nn.Module):
+    """Median pool (usable as median filter when stride=1) module.
+
+    Args:
+        kernel_size: size of pooling kernel, int or 2-tuple
+        stride: pool stride, int or 2-tuple
+        padding: pool padding, int or 4-tuple (l, r, t, b) as in pytorch F.pad
+        same: override padding and enforce same padding, boolean
+    """
+
+    def __init__(self, kernel_size=3, stride=1, padding=0, same=True):
+        """Initialize with kernel_size, stride, padding."""
+        super().__init__()
+        self.k = _pair(kernel_size)
+        self.stride = _pair(stride)
+        self.padding = _quadruple(padding)  # convert to l, r, t, b
+        self.same = same
+
+    def _padding(self, x):
+        if self.same:
+            ih, iw = x.size()[2:]
+            if ih % self.stride[0] == 0:
+                ph = max(self.k[0] - self.stride[0], 0)
+            else:
+                ph = max(self.k[0] - (ih % self.stride[0]), 0)
+            if iw % self.stride[1] == 0:
+                pw = max(self.k[1] - self.stride[1], 0)
+            else:
+                pw = max(self.k[1] - (iw % self.stride[1]), 0)
+            pl = pw // 2
+            pr = pw - pl
+            pt = ph // 2
+            pb = ph - pt
+            padding = (pl, pr, pt, pb)
+        else:
+            padding = self.padding
+        return padding
+
+    def forward(self, x):
+        # using existing pytorch functions and tensor ops so that we get autograd,
+        # would likely be more efficient to implement from scratch at C/Cuda level
+        x = F.pad(x, self._padding(x), mode='reflect')
+        x = x.unfold(2, self.k[0], self.stride[0]).unfold(3, self.k[1], self.stride[1])
+        x = x.contiguous().view(x.size()[:4] + (-1,)).median(dim=-1)[0]
+        return x
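Note: with stride=1 and same=True, MedianPool2d acts as a shape-preserving median filter, e.g. for denoising a finished reconstruction (the attack config earlier in this patch sets filter='none'):

    import torch

    filt = MedianPool2d(kernel_size=3, stride=1, same=True)
    x = torch.rand(1, 3, 32, 32)
    assert filt(x).shape == x.shape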
diff --git a/src/inversefed/metrics.py b/src/inversefed/metrics.py
new file mode 100644
index 00000000..c227dc42
--- /dev/null
+++ b/src/inversefed/metrics.py
@@ -0,0 +1,106 @@
+"""This is code based on https://sudomake.ai/inception-score-explained/."""
+import torch
+import torchvision
+
+from collections import defaultdict
+
+
+class InceptionScore(torch.nn.Module):
+    """Class that manages and returns the inception score of images."""
+
+    def __init__(self, batch_size=32, setup=dict(device=torch.device('cpu'), dtype=torch.float)):
+        """Initialize with setup and target inception batch size."""
+        super().__init__()
+        self.preprocessing = torch.nn.Upsample(size=(299, 299), mode='bilinear', align_corners=False)
+        self.model = torchvision.models.inception_v3(pretrained=True).to(**setup)
+        self.model.eval()
+        self.batch_size = batch_size
+
+    def forward(self, image_batch):
+        """Image batch should have dimensions BCHW and should be normalized.
+
+        B should be divisible by self.batch_size.
+        """
+        B, C, H, W = image_batch.shape
+        batches = B // self.batch_size
+        scores = []
+        for batch in range(batches):
+            input = self.preprocessing(image_batch[batch * self.batch_size: (batch + 1) * self.batch_size])
+            scores.append(self.model(input))
+        prob_yx = torch.nn.functional.softmax(torch.cat(scores, 0), dim=1)
+        entropy = torch.where(prob_yx > 0, -prob_yx * prob_yx.log(), torch.zeros_like(prob_yx))
+        return entropy.sum()
+
+
+def psnr(img_batch, ref_batch, batched=False, factor=1.0):
+    """Standard PSNR."""
+    def get_psnr(img_in, img_ref):
+        mse = ((img_in - img_ref)**2).mean()
+        if mse > 0 and torch.isfinite(mse):
+            return (10 * torch.log10(factor**2 / mse))
+        elif not torch.isfinite(mse):
+            return img_batch.new_tensor(float('nan'))
+        else:
+            return img_batch.new_tensor(float('inf'))
+
+    if batched:
+        psnr = get_psnr(img_batch.detach(), ref_batch)
+    else:
+        [B, C, m, n] = img_batch.shape
+        psnrs = []
+        for sample in range(B):
+            psnrs.append(get_psnr(img_batch.detach()[sample, :, :, :], ref_batch[sample, :, :, :]))
+        psnr = torch.stack(psnrs, dim=0).mean()
+
+    return psnr.item()
+
+
+def total_variation(x):
+    """Anisotropic TV."""
+    dx = torch.mean(torch.abs(x[:, :, :, :-1] - x[:, :, :, 1:]))
+    dy = torch.mean(torch.abs(x[:, :, :-1, :] - x[:, :, 1:, :]))
+    return dx + dy
+
+
+def activation_errors(model, x1, x2):
+    """Compute activation-level error metrics for every module in the network."""
+    model.eval()
+
+    device = next(model.parameters()).device
+
+    hooks = []
+    data = defaultdict(dict)
+    inputs = torch.cat((x1, x2), dim=0)
+    separator = x1.shape[0]
+
+    def check_activations(self, input, output):
+        module_name = str(*[name for name, mod in model.named_modules() if self is mod])
+        try:
+            layer_inputs = input[0].detach()
+            residual = (layer_inputs[:separator] - layer_inputs[separator:]).pow(2)
+            se_error = residual.sum()
+            mse_error = residual.mean()
+            sim = torch.nn.functional.cosine_similarity(layer_inputs[:separator].flatten(),
+                                                        layer_inputs[separator:].flatten(),
+                                                        dim=0, eps=1e-8).detach()
+            data['se'][module_name] = se_error.item()
+            data['mse'][module_name] = mse_error.item()
+            data['sim'][module_name] = sim.item()
+        except (KeyboardInterrupt, SystemExit):
+            raise
+        except AttributeError:
+            pass
+
+    for name, module in model.named_modules():
+        hooks.append(module.register_forward_hook(check_activations))
+
+    try:
+        outputs = model(inputs.to(device))
+        for hook in hooks:
+            hook.remove()
+    except Exception as e:
+        for hook in hooks:
+            hook.remove()
+        raise
+
+    return data
diff --git a/src/inversefed/nn/README.md b/src/inversefed/nn/README.md
new file mode 100644
index 00000000..99211199
--- /dev/null
+++ b/src/inversefed/nn/README.md
@@ -0,0 +1 @@
+# Models and modules are implemented here
\ No newline at end of file
diff --git a/src/inversefed/nn/__init__.py b/src/inversefed/nn/__init__.py
new file mode 100644
index 00000000..2ba9a68c
--- /dev/null
+++ b/src/inversefed/nn/__init__.py
@@ -0,0 +1,6 @@
+"""Experimental modules and unexperimental model hooks."""
+
+from .models import construct_model
+from .modules import MetaMonkey
+
+__all__ = ['construct_model', 'MetaMonkey']
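Note: a smoke test for the InceptionScore module above; the batch size must divide the number of images, and inputs are bilinearly upsampled to 299x299 before being scored (downloads torchvision's pretrained Inception-v3 on first use):

    import torch

    scorer = InceptionScore(batch_size=8)
    imgs = torch.rand(8, 3, 64, 64)
    print(scorer(imgs).item())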
diff --git a/src/inversefed/nn/densenet.py b/src/inversefed/nn/densenet.py
new file mode 100644
index 00000000..97f5fa16
--- /dev/null
+++ b/src/inversefed/nn/densenet.py
@@ -0,0 +1,92 @@
+"""DenseNet in PyTorch."""
+"""Adaptation we did with ******."""
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class _Bottleneck(nn.Module):
+    def __init__(self, in_planes, growth_rate):
+        super().__init__()
+        self.bn1 = nn.BatchNorm2d(in_planes)
+        self.conv1 = nn.Conv2d(in_planes, 4 * growth_rate, kernel_size=1, bias=False)
+        self.bn2 = nn.BatchNorm2d(4 * growth_rate)
+        self.conv2 = nn.Conv2d(4 * growth_rate, growth_rate, kernel_size=3, padding=1, bias=False)
+
+    def forward(self, x):
+        out = self.conv1(F.relu(self.bn1(x)))
+        out = self.conv2(F.relu(self.bn2(out)))
+        out = torch.cat([out, x], 1)
+        return out
+
+
+class _Transition(nn.Module):
+    def __init__(self, in_planes, out_planes):
+        super().__init__()
+        self.bn = nn.BatchNorm2d(in_planes)
+        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, bias=False)
+
+    def forward(self, x):
+        out = self.conv(F.relu(self.bn(x)))
+        out = F.avg_pool2d(out, 2)
+        return out
+
+
+class _DenseNet(nn.Module):
+    def __init__(self, block, nblocks, growth_rate=12, reduction=0.5, num_classes=10):
+        super().__init__()
+        self.growth_rate = growth_rate
+
+        num_planes = 2 * growth_rate
+        self.conv1 = nn.Conv2d(3, num_planes, kernel_size=3, padding=1, bias=False)
+
+        self.dense1 = self._make_dense_layers(block, num_planes, nblocks[0])
+        num_planes += nblocks[0] * growth_rate
+        out_planes = int(math.floor(num_planes * reduction))
+        self.trans1 = _Transition(num_planes, out_planes)
+        num_planes = out_planes
+
+        self.dense2 = self._make_dense_layers(block, num_planes, nblocks[1])
+        num_planes += nblocks[1] * growth_rate
+        out_planes = int(math.floor(num_planes * reduction))
+        self.trans2 = _Transition(num_planes, out_planes)
+        num_planes = out_planes
+
+        self.dense3 = self._make_dense_layers(block, num_planes, nblocks[2])
+        num_planes += nblocks[2] * growth_rate
+        out_planes = int(math.floor(num_planes * reduction))
+        # self.trans3 = Transition(num_planes, out_planes)
+        # num_planes = out_planes
+
+        # self.dense4 = self._make_dense_layers(block, num_planes, nblocks[3])
+        # num_planes += nblocks[3]*growth_rate
+
+        self.bn = nn.BatchNorm2d(num_planes)
+        num_planes = 132 * growth_rate // 12 * 2 * 2
+        self.linear = nn.Linear(num_planes, num_classes)
+
+    def _make_dense_layers(self, block, in_planes, nblock):
+        layers = []
+        for i in range(nblock):
+            layers.append(block(in_planes, self.growth_rate))
+            in_planes += self.growth_rate
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        out = self.conv1(x)
+        out = self.trans1(self.dense1(out))
+        out = self.trans2(self.dense2(out))
+        out = self.dense3(out)
+        # out = self.trans3(self.dense3(out))
+        # out = self.dense4(out)
+        out = F.avg_pool2d(F.relu(self.bn(out)), 4)
+        out = out.view(out.size(0), -1)
+        out = self.linear(out)
+        return out
+
+
+def densenet_cifar(num_classes=10):
+    """Instantiate the smallest DenseNet."""
+    return _DenseNet(_Bottleneck, [6, 6, 6, 0], growth_rate=12, num_classes=num_classes)
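Note: a shape check for the CIFAR-sized DenseNet defined above (the `132 * growth_rate // 12 * 2 * 2` input size of the linear layer matches the 528 features left after the final 4x4 average pool on 32x32 inputs):

    import torch

    model = densenet_cifar(num_classes=10)
    out = model(torch.rand(2, 3, 32, 32))
    assert out.shape == (2, 10)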
torch.nn as nn + +from torchvision.models.resnet import Bottleneck +from .revnet import iRevNet +from .densenet import _DenseNet, _Bottleneck + +from collections import OrderedDict +import numpy as np +from ..utils import set_random_seed + + + + +def construct_model(model, num_classes=10, seed=None, num_channels=3, modelkey=None): + """Return various models.""" + if modelkey is None: + if seed is None: + model_init_seed = np.random.randint(0, 2**32 - 10) + else: + model_init_seed = seed + else: + model_init_seed = modelkey + set_random_seed(model_init_seed) + + if model in ['ConvNet', 'ConvNet64']: + model = ConvNet(width=64, num_channels=num_channels, num_classes=num_classes) + elif model == 'ConvNet8': + model = ConvNet(width=64, num_channels=num_channels, num_classes=num_classes) + elif model == 'ConvNet16': + model = ConvNet(width=64, num_channels=num_channels, num_classes=num_classes) + elif model == 'ConvNet32': + model = ConvNet(width=64, num_channels=num_channels, num_classes=num_classes) + elif model == 'BeyondInferringMNIST': + model = torch.nn.Sequential(OrderedDict([ + ('conv1', torch.nn.Conv2d(1, 32, 3, stride=2, padding=1)), + ('relu0', torch.nn.LeakyReLU()), + ('conv2', torch.nn.Conv2d(32, 64, 3, stride=1, padding=1)), + ('relu1', torch.nn.LeakyReLU()), + ('conv3', torch.nn.Conv2d(64, 128, 3, stride=2, padding=1)), + ('relu2', torch.nn.LeakyReLU()), + ('conv4', torch.nn.Conv2d(128, 256, 3, stride=1, padding=1)), + ('relu3', torch.nn.LeakyReLU()), + ('flatt', torch.nn.Flatten()), + ('linear0', torch.nn.Linear(12544, 12544)), + ('relu4', torch.nn.LeakyReLU()), + ('linear1', torch.nn.Linear(12544, 10)), + ('softmax', torch.nn.Softmax(dim=1)) + ])) + elif model == 'BeyondInferringCifar': + model = torch.nn.Sequential(OrderedDict([ + ('conv1', torch.nn.Conv2d(3, 32, 3, stride=2, padding=1)), + ('relu0', torch.nn.LeakyReLU()), + ('conv2', torch.nn.Conv2d(32, 64, 3, stride=1, padding=1)), + ('relu1', torch.nn.LeakyReLU()), + ('conv3', torch.nn.Conv2d(64, 128, 3, stride=2, padding=1)), + ('relu2', torch.nn.LeakyReLU()), + ('conv4', torch.nn.Conv2d(128, 256, 3, stride=1, padding=1)), + ('relu3', torch.nn.LeakyReLU()), + ('flatt', torch.nn.Flatten()), + ('linear0', torch.nn.Linear(12544, 12544)), + ('relu4', torch.nn.LeakyReLU()), + ('linear1', torch.nn.Linear(12544, 10)), + ('softmax', torch.nn.Softmax(dim=1)) + ])) + elif model == 'MLP': + width = 1024 + model = torch.nn.Sequential(OrderedDict([ + ('flatten', torch.nn.Flatten()), + ('linear0', torch.nn.Linear(3072, width)), + ('relu0', torch.nn.ReLU()), + ('linear1', torch.nn.Linear(width, width)), + ('relu1', torch.nn.ReLU()), + ('linear2', torch.nn.Linear(width, width)), + ('relu2', torch.nn.ReLU()), + ('linear3', torch.nn.Linear(width, num_classes))])) + elif model == 'TwoLP': + width = 2048 + model = torch.nn.Sequential(OrderedDict([ + ('flatten', torch.nn.Flatten()), + ('linear0', torch.nn.Linear(3072, width)), + ('relu0', torch.nn.ReLU()), + ('linear3', torch.nn.Linear(width, num_classes))])) + elif model == 'ResNet20': + model = ResNet(torchvision.models.resnet.BasicBlock, [3, 3, 3], num_classes=num_classes, base_width=16) + elif model == 'ResNet20-nostride': + model = ResNet(torchvision.models.resnet.BasicBlock, [3, 3, 3], num_classes=num_classes, base_width=16, + strides=[1, 1, 1, 1]) + elif model == 'ResNet20-10': + model = ResNet(torchvision.models.resnet.BasicBlock, [3, 3, 3], num_classes=num_classes, base_width=16 * 10) + elif model == 'ResNet20-4': + model = ResNet(torchvision.models.resnet.BasicBlock, [3, 3, 3], 
num_classes=num_classes, base_width=16 * 4) + elif model == 'ResNet20-4-unpooled': + model = ResNet(torchvision.models.resnet.BasicBlock, [3, 3, 3], num_classes=num_classes, base_width=16 * 4, + pool='max') + elif model == 'ResNet28-10': + model = ResNet(torchvision.models.resnet.BasicBlock, [4, 4, 4], num_classes=num_classes, base_width=16 * 10) + elif model == 'ResNet32': + model = ResNet(torchvision.models.resnet.BasicBlock, [5, 5, 5], num_classes=num_classes, base_width=16) + elif model == 'ResNet32-10': + model = ResNet(torchvision.models.resnet.BasicBlock, [5, 5, 5], num_classes=num_classes, base_width=16 * 10) + elif model == 'ResNet44': + model = ResNet(torchvision.models.resnet.BasicBlock, [7, 7, 7], num_classes=num_classes, base_width=16) + elif model == 'ResNet56': + model = ResNet(torchvision.models.resnet.BasicBlock, [9, 9, 9], num_classes=num_classes, base_width=16) + elif model == 'ResNet110': + model = ResNet(torchvision.models.resnet.BasicBlock, [18, 18, 18], num_classes=num_classes, base_width=16) + elif model == 'ResNet18': + model = ResNet(torchvision.models.resnet.BasicBlock, [2, 2, 2, 2], num_classes=num_classes, base_width=64) + elif model == 'ResNet34': + model = ResNet(torchvision.models.resnet.BasicBlock, [3, 4, 6, 3], num_classes=num_classes, base_width=64) + elif model == 'ResNet50': + model = ResNet(torchvision.models.resnet.Bottleneck, [3, 4, 6, 3], num_classes=num_classes, base_width=64) + elif model == 'ResNet50-2': + model = ResNet(torchvision.models.resnet.Bottleneck, [3, 4, 6, 3], num_classes=num_classes, base_width=64 * 2) + elif model == 'ResNet101': + model = ResNet(torchvision.models.resnet.Bottleneck, [3, 4, 23, 3], num_classes=num_classes, base_width=64) + elif model == 'ResNet152': + model = ResNet(torchvision.models.resnet.Bottleneck, [3, 8, 36, 3], num_classes=num_classes, base_width=64) + elif model == 'MobileNet': + inverted_residual_setting = [ + # t, c, n, s + [1, 16, 1, 1], + [6, 24, 2, 1], # cifar adaptation, cf.https://github.com/kuangliu/pytorch-cifar/blob/master/models/mobilenetv2.py + [6, 32, 3, 2], + [6, 64, 4, 2], + [6, 96, 3, 1], + [6, 160, 3, 2], + [6, 320, 1, 1], + ] + model = torchvision.models.MobileNetV2(num_classes=num_classes, + inverted_residual_setting=inverted_residual_setting, + width_mult=1.0) + model.features[0] = torchvision.models.mobilenet.ConvBNReLU(num_channels, 32, stride=1) # this is fixed to width=1 + elif model == 'MNASNet': + model = torchvision.models.MNASNet(1.0, num_classes=num_classes, dropout=0.2) + elif model == 'DenseNet121': + model = torchvision.models.DenseNet(growth_rate=32, block_config=(6, 12, 24, 16), + num_init_features=64, bn_size=4, drop_rate=0, num_classes=num_classes, + memory_efficient=False) + elif model == 'DenseNet40': + model = _DenseNet(_Bottleneck, [6, 6, 6, 0], growth_rate=12, num_classes=num_classes) + elif model == 'DenseNet40-4': + model = _DenseNet(_Bottleneck, [6, 6, 6, 0], growth_rate=12 * 4, num_classes=num_classes) + elif model == 'SRNet3': + model = SRNet(upscale_factor=3, num_channels=num_channels) + elif model == 'SRNet1': + model = SRNet(upscale_factor=1, num_channels=num_channels) + elif model == 'iRevNet': + if num_classes <= 100: + in_shape = [num_channels, 32, 32] # only for cifar right now + model = iRevNet(nBlocks=[18, 18, 18], nStrides=[1, 2, 2], + nChannels=[16, 64, 256], nClasses=num_classes, + init_ds=0, dropout_rate=0.1, affineBN=True, + in_shape=in_shape, mult=4) + else: + in_shape = [3, 224, 224] # only for imagenet + model = iRevNet(nBlocks=[6, 16, 72, 6], 
nStrides=[2, 2, 2, 2], + nChannels=[24, 96, 384, 1536], nClasses=num_classes, + init_ds=2, dropout_rate=0.1, affineBN=True, + in_shape=in_shape, mult=4) + elif model == 'LeNetZhu': + model = LeNetZhu(num_channels=num_channels, num_classes=num_classes) + else: + raise NotImplementedError('Model not implemented.') + + print(f'Model initialized with random key {model_init_seed}.') + return model, model_init_seed + + +class ResNet(torchvision.models.ResNet): + """ResNet generalization for CIFAR thingies.""" + + def __init__(self, block, layers, num_classes=10, zero_init_residual=False, + groups=1, base_width=64, replace_stride_with_dilation=None, + norm_layer=None, strides=[1, 2, 2, 2], pool='avg'): + """Initialize as usual. Layers and strides are scriptable.""" + super(torchvision.models.ResNet, self).__init__() # nn.Module + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self._norm_layer = norm_layer + + + self.dilation = 1 + if replace_stride_with_dilation is None: + # each element in the tuple indicates if we should replace + # the 2x2 stride with a dilated convolution instead + replace_stride_with_dilation = [False, False, False, False] + if len(replace_stride_with_dilation) != 4: + raise ValueError("replace_stride_with_dilation should be None " + "or a 4-element tuple, got {}".format(replace_stride_with_dilation)) + self.groups = groups + + self.inplanes = base_width + self.base_width = 64 # Do this to circumvent BasicBlock errors. The value is not actually used. + self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False) + self.bn1 = norm_layer(self.inplanes) + self.relu = nn.ReLU(inplace=True) + + self.layers = torch.nn.ModuleList() + width = self.inplanes + for idx, layer in enumerate(layers): + self.layers.append(self._make_layer(block, width, layer, stride=strides[idx], dilate=replace_stride_with_dilation[idx])) + width *= 2 + + self.pool = nn.AdaptiveAvgPool2d((1, 1)) if pool == 'avg' else nn.AdaptiveMaxPool2d((1, 1)) + self.fc = nn.Linear(width // 2 * block.expansion, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + # Zero-initialize the last BN in each residual branch, + # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
+            # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
+            if zero_init_residual:
+                for m in self.modules():
+                    if isinstance(m, Bottleneck):
+                        nn.init.constant_(m.bn3.weight, 0)
+                    elif isinstance(m, torchvision.models.resnet.BasicBlock):
+                        nn.init.constant_(m.bn2.weight, 0)
+
+
+    def _forward_impl(self, x):
+        # See note [TorchScript super()]
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+
+        for layer in self.layers:
+            x = layer(x)
+
+        x = self.pool(x)
+        x = torch.flatten(x, 1)
+        x = self.fc(x)
+
+        return x
+
+
+class ConvNet(torch.nn.Module):
+    """ConvNetBN."""
+
+    def __init__(self, width=32, num_classes=10, num_channels=3):
+        """Init with width and num classes."""
+        super().__init__()
+        self.model = torch.nn.Sequential(OrderedDict([
+            ('conv0', torch.nn.Conv2d(num_channels, 1 * width, kernel_size=3, padding=1)),
+            ('bn0', torch.nn.BatchNorm2d(1 * width)),
+            ('relu0', torch.nn.ReLU()),
+
+            ('conv1', torch.nn.Conv2d(1 * width, 2 * width, kernel_size=3, padding=1)),
+            ('bn1', torch.nn.BatchNorm2d(2 * width)),
+            ('relu1', torch.nn.ReLU()),
+
+            ('conv2', torch.nn.Conv2d(2 * width, 2 * width, kernel_size=3, padding=1)),
+            ('bn2', torch.nn.BatchNorm2d(2 * width)),
+            ('relu2', torch.nn.ReLU()),
+
+            ('conv3', torch.nn.Conv2d(2 * width, 4 * width, kernel_size=3, padding=1)),
+            ('bn3', torch.nn.BatchNorm2d(4 * width)),
+            ('relu3', torch.nn.ReLU()),
+
+            ('conv4', torch.nn.Conv2d(4 * width, 4 * width, kernel_size=3, padding=1)),
+            ('bn4', torch.nn.BatchNorm2d(4 * width)),
+            ('relu4', torch.nn.ReLU()),
+
+            ('conv5', torch.nn.Conv2d(4 * width, 4 * width, kernel_size=3, padding=1)),
+            ('bn5', torch.nn.BatchNorm2d(4 * width)),
+            ('relu5', torch.nn.ReLU()),
+
+            ('pool0', torch.nn.MaxPool2d(3)),
+
+            ('conv6', torch.nn.Conv2d(4 * width, 4 * width, kernel_size=3, padding=1)),
+            ('bn6', torch.nn.BatchNorm2d(4 * width)),
+            ('relu6', torch.nn.ReLU()),
+
+            ('conv7', torch.nn.Conv2d(4 * width, 4 * width, kernel_size=3, padding=1)),
+            ('bn7', torch.nn.BatchNorm2d(4 * width)),
+            ('relu7', torch.nn.ReLU()),
+
+            ('pool1', torch.nn.MaxPool2d(3)),
+            ('flatten', torch.nn.Flatten()),
+            ('linear', torch.nn.Linear(36 * width, num_classes))
+        ]))
+
+    def forward(self, input):
+        return self.model(input)
+
+
+class LeNetZhu(nn.Module):
+    """LeNet variant from https://github.com/mit-han-lab/dlg/blob/master/models/vision.py."""
+
+    def __init__(self, num_classes=10, num_channels=3):
+        """3-Layer sigmoid Conv with large linear layer."""
+        super().__init__()
+        act = nn.Sigmoid
+        self.body = nn.Sequential(
+            nn.Conv2d(num_channels, 12, kernel_size=5, padding=5 // 2, stride=2),
+            act(),
+            nn.Conv2d(12, 12, kernel_size=5, padding=5 // 2, stride=2),
+            act(),
+            nn.Conv2d(12, 12, kernel_size=5, padding=5 // 2, stride=1),
+            act(),
+        )
+        self.fc = nn.Sequential(
+            nn.Linear(768, num_classes)
+        )
+        for module in self.modules():
+            self.weights_init(module)
+
+    @staticmethod
+    def weights_init(m):
+        if hasattr(m, "weight"):
+            m.weight.data.uniform_(-0.5, 0.5)
+        if hasattr(m, "bias"):
+            m.bias.data.uniform_(-0.5, 0.5)
+
+    def forward(self, x):
+        out = self.body(x)
+        out = out.view(out.size(0), -1)
+        # print(out.size())
+        out = self.fc(out)
+        return out
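A quick shape check for the factory above; a minimal sketch, assuming CIFAR-sized 32x32 inputs:

    import torch
    model, seed = construct_model('ConvNet', num_classes=10, num_channels=3)
    out = model(torch.randn(2, 3, 32, 32))
    assert out.shape == (2, 10)  # two MaxPool2d(3) stages: 32 -> 10 -> 3, so the head sees 36 * width features
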
frameworks.""" +import torch +import torch.nn.functional as F +from collections import OrderedDict +from functools import partial +import warnings + +from ..consts import BENCHMARK +torch.backends.cudnn.benchmark = BENCHMARK + +DEBUG = False # Emit warning messages when patching. Use this to bootstrap new architectures. + +class MetaMonkey(torch.nn.Module): + """Trace a networks and then replace its module calls with functional calls. + + This allows for backpropagation w.r.t to weights for "normal" PyTorch networks. + """ + + def __init__(self, net): + """Init with network.""" + super().__init__() + self.net = net + self.parameters = OrderedDict(net.named_parameters()) + + + def forward(self, inputs, parameters=None): + """Live Patch ... :> ...""" + # If no parameter dictionary is given, everything is normal + if parameters is None: + return self.net(inputs) + + # But if not ... + param_gen = iter(parameters.values()) + method_pile = [] + counter = 0 + + for name, module in self.net.named_modules(): + if isinstance(module, torch.nn.Conv2d): + ext_weight = next(param_gen) + if module.bias is not None: + ext_bias = next(param_gen) + else: + ext_bias = None + + method_pile.append(module.forward) + module.forward = partial(F.conv2d, weight=ext_weight, bias=ext_bias, stride=module.stride, + padding=module.padding, dilation=module.dilation, groups=module.groups) + elif isinstance(module, torch.nn.BatchNorm2d): + if module.momentum is None: + exponential_average_factor = 0.0 + else: + exponential_average_factor = module.momentum + + if module.training and module.track_running_stats: + if module.num_batches_tracked is not None: + module.num_batches_tracked += 1 + if module.momentum is None: # use cumulative moving average + exponential_average_factor = 1.0 / float(module.num_batches_tracked) + else: # use exponential moving average + exponential_average_factor = module.momentum + + ext_weight = next(param_gen) + ext_bias = next(param_gen) + method_pile.append(module.forward) + module.forward = partial(F.batch_norm, running_mean=module.running_mean, running_var=module.running_var, + weight=ext_weight, bias=ext_bias, + training=module.training or not module.track_running_stats, + momentum=exponential_average_factor, eps=module.eps) + + elif isinstance(module, torch.nn.Linear): + lin_weights = next(param_gen) + lin_bias = next(param_gen) + method_pile.append(module.forward) + module.forward = partial(F.linear, weight=lin_weights, bias=lin_bias) + + elif next(module.parameters(), None) is None: + # Pass over modules that do not contain parameters + pass + elif isinstance(module, torch.nn.Sequential): + # Pass containers + pass + else: + # Warn for other containers + if DEBUG: + warnings.warn(f'Patching for module {module.__class__} is not implemented.') + + output = self.net(inputs) + + # Undo Patch + for name, module in self.net.named_modules(): + if isinstance(module, torch.nn.modules.conv.Conv2d): + module.forward = method_pile.pop(0) + elif isinstance(module, torch.nn.BatchNorm2d): + module.forward = method_pile.pop(0) + elif isinstance(module, torch.nn.Linear): + module.forward = method_pile.pop(0) + + return output diff --git a/src/inversefed/nn/revnet.py b/src/inversefed/nn/revnet.py new file mode 100644 index 00000000..841a2cb2 --- /dev/null +++ b/src/inversefed/nn/revnet.py @@ -0,0 +1,192 @@ +"""https://github.com/jhjacobsen/pytorch-i-revnet/blob/master/models/iRevNet.py. 
+ +Code for "i-RevNet: Deep Invertible Networks" +https://openreview.net/pdf?id=HJsjkMb0Z +ICLR, 2018 + + +(c) Joern-Henrik Jacobsen, 2018 +""" + +""" +MIT License + +Copyright (c) 2018 Jörn Jacobsen + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from .revnet_utils import split, merge, injective_pad, psi + + +class irevnet_block(nn.Module): + """This is an i-revnet block from Jacobsen et al.""" + + def __init__(self, in_ch, out_ch, stride=1, first=False, dropout_rate=0., + affineBN=True, mult=4): + """Build invertible bottleneck block.""" + super(irevnet_block, self).__init__() + self.first = first + self.pad = 2 * out_ch - in_ch + self.stride = stride + self.inj_pad = injective_pad(self.pad) + self.psi = psi(stride) + if self.pad != 0 and stride == 1: + in_ch = out_ch * 2 + print('') + print('| Injective iRevNet |') + print('') + layers = [] + if not first: + layers.append(nn.BatchNorm2d(in_ch // 2, affine=affineBN)) + layers.append(nn.ReLU(inplace=True)) + layers.append(nn.Conv2d(in_ch // 2, int(out_ch // mult), kernel_size=3, + stride=stride, padding=1, bias=False)) + layers.append(nn.BatchNorm2d(int(out_ch // mult), affine=affineBN)) + layers.append(nn.ReLU(inplace=True)) + layers.append(nn.Conv2d(int(out_ch // mult), int(out_ch // mult), + kernel_size=3, padding=1, bias=False)) + layers.append(nn.Dropout(p=dropout_rate)) + layers.append(nn.BatchNorm2d(int(out_ch // mult), affine=affineBN)) + layers.append(nn.ReLU(inplace=True)) + layers.append(nn.Conv2d(int(out_ch // mult), out_ch, kernel_size=3, + padding=1, bias=False)) + self.bottleneck_block = nn.Sequential(*layers) + + def forward(self, x): + """Bijective or injective block forward.""" + if self.pad != 0 and self.stride == 1: + x = merge(x[0], x[1]) + x = self.inj_pad.forward(x) + x1, x2 = split(x) + x = (x1, x2) + x1 = x[0] + x2 = x[1] + Fx2 = self.bottleneck_block(x2) + if self.stride == 2: + x1 = self.psi.forward(x1) + x2 = self.psi.forward(x2) + y1 = Fx2 + x1 + return (x2, y1) + + def inverse(self, x): + """Bijective or injecitve block inverse.""" + x2, y1 = x[0], x[1] + if self.stride == 2: + x2 = self.psi.inverse(x2) + Fx2 = - self.bottleneck_block(x2) + x1 = Fx2 + y1 + if self.stride == 2: + x1 = self.psi.inverse(x1) + if self.pad != 0 and self.stride == 1: + x = merge(x1, x2) + x = self.inj_pad.inverse(x) + x1, x2 = split(x) + x = (x1, x2) + else: + x = (x1, x2) + return x + + +class iRevNet(nn.Module): + """This is an i-revnet from Jacobsen et al.""" + + def 
__init__(self, nBlocks, nStrides, nClasses, nChannels=None, init_ds=2, + dropout_rate=0., affineBN=True, in_shape=None, mult=4): + """Init with e.g. nBlocks=[18, 18, 18], nStrides = [1, 2, 2].""" + super(iRevNet, self).__init__() + self.ds = in_shape[2] // 2**(nStrides.count(2) + init_ds // 2) + self.init_ds = init_ds + self.in_ch = in_shape[0] * 2**self.init_ds + self.nBlocks = nBlocks + self.first = True + + print('') + print(' == Building iRevNet %d == ' % (sum(nBlocks) * 3 + 1)) + if not nChannels: + nChannels = [self.in_ch // 2, self.in_ch // 2 * 4, + self.in_ch // 2 * 4**2, self.in_ch // 2 * 4**3] + + self.init_psi = psi(self.init_ds) + self.stack = self.irevnet_stack(irevnet_block, nChannels, nBlocks, + nStrides, dropout_rate=dropout_rate, + affineBN=affineBN, in_ch=self.in_ch, + mult=mult) + self.bn1 = nn.BatchNorm2d(nChannels[-1] * 2, momentum=0.9) + self.linear = nn.Linear(nChannels[-1] * 2, nClasses) + + def irevnet_stack(self, _block, nChannels, nBlocks, nStrides, dropout_rate, + affineBN, in_ch, mult): + """Create stack of irevnet blocks.""" + block_list = nn.ModuleList() + strides = [] + channels = [] + for channel, depth, stride in zip(nChannels, nBlocks, nStrides): + strides = strides + ([stride] + [1] * (depth - 1)) + channels = channels + ([channel] * depth) + for channel, stride in zip(channels, strides): + block_list.append(_block(in_ch, channel, stride, + first=self.first, + dropout_rate=dropout_rate, + affineBN=affineBN, mult=mult)) + in_ch = 2 * channel + self.first = False + return block_list + + def forward(self, x, return_bijection=False): + """Irevnet forward.""" + n = self.in_ch // 2 + if self.init_ds != 0: + x = self.init_psi.forward(x) + out = (x[:, :n, :, :], x[:, n:, :, :]) + for block in self.stack: + out = block.forward(out) + out_bij = merge(out[0], out[1]) + out = F.relu(self.bn1(out_bij)) + out = F.avg_pool2d(out, self.ds) + out = out.view(out.size(0), -1) + out = self.linear(out) + if return_bijection: + return out, out_bij + else: + return out + + def inverse(self, out_bij): + """Irevnet inverse.""" + out = split(out_bij) + for i in range(len(self.stack)): + out = self.stack[-1 - i].inverse(out) + out = merge(out[0], out[1]) + if self.init_ds != 0: + x = self.init_psi.inverse(out) + else: + x = out + return x + + +if __name__ == '__main__': + model = iRevNet(nBlocks=[6, 16, 72, 6], nStrides=[2, 2, 2, 2], + nChannels=None, nClasses=1000, init_ds=2, + dropout_rate=0., affineBN=True, in_shape=[3, 224, 224], + mult=4) + y = model(torch.randn(1, 3, 224, 224)) + print(y.size()) diff --git a/src/inversefed/nn/revnet_utils.py b/src/inversefed/nn/revnet_utils.py new file mode 100644 index 00000000..aa3eaafb --- /dev/null +++ b/src/inversefed/nn/revnet_utils.py @@ -0,0 +1,132 @@ +"""https://github.com/jhjacobsen/pytorch-i-revnet/blob/master/models/model_utils.py. 
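Since invertibility is the point of the iRevNet above, a small round-trip sketch (a deliberately tiny configuration; the sizes here are illustrative assumptions):

    import torch
    net = iRevNet(nBlocks=[2], nStrides=[1], nChannels=[16], nClasses=10,
                  init_ds=2, dropout_rate=0., affineBN=True, in_shape=[3, 32, 32], mult=4)
    x = torch.randn(1, 3, 32, 32)
    out, out_bij = net(x, return_bijection=True)
    x_rec = net.inverse(out_bij)
    print((x - x_rec).abs().max())  # ~0 up to floating point error
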
+ +Code for "i-RevNet: Deep Invertible Networks" +https://openreview.net/pdf?id=HJsjkMb0Z +ICLR, 2018 + + +(c) Joern-Henrik Jacobsen, 2018 +""" + +""" +MIT License + +Copyright (c) 2018 Jörn Jacobsen + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +""" + +import torch +import torch.nn as nn + +from torch.nn import Parameter + + +def split(x): + n = int(x.size()[1] / 2) + x1 = x[:, :n, :, :].contiguous() + x2 = x[:, n:, :, :].contiguous() + return x1, x2 + + +def merge(x1, x2): + return torch.cat((x1, x2), 1) + + +class injective_pad(nn.Module): + def __init__(self, pad_size): + super(injective_pad, self).__init__() + self.pad_size = pad_size + self.pad = nn.ZeroPad2d((0, 0, 0, pad_size)) + + def forward(self, x): + x = x.permute(0, 2, 1, 3) + x = self.pad(x) + return x.permute(0, 2, 1, 3) + + def inverse(self, x): + return x[:, :x.size(1) - self.pad_size, :, :] + + +class psi(nn.Module): + def __init__(self, block_size): + super(psi, self).__init__() + self.block_size = block_size + self.block_size_sq = block_size * block_size + + def inverse(self, input): + output = input.permute(0, 2, 3, 1) + (batch_size, d_height, d_width, d_depth) = output.size() + s_depth = int(d_depth / self.block_size_sq) + s_width = int(d_width * self.block_size) + s_height = int(d_height * self.block_size) + t_1 = output.contiguous().view(batch_size, d_height, d_width, self.block_size_sq, s_depth) + spl = t_1.split(self.block_size, 3) + stack = [t_t.contiguous().view(batch_size, d_height, s_width, s_depth) for t_t in spl] + output = torch.stack(stack, 0).transpose(0, 1).permute(0, 2, 1, 3, 4).contiguous().view(batch_size, s_height, s_width, s_depth) + output = output.permute(0, 3, 1, 2) + return output.contiguous() + + def forward(self, input): + output = input.permute(0, 2, 3, 1) + (batch_size, s_height, s_width, s_depth) = output.size() + d_depth = s_depth * self.block_size_sq + d_height = int(s_height / self.block_size) + t_1 = output.split(self.block_size, 2) + stack = [t_t.contiguous().view(batch_size, d_height, d_depth) for t_t in t_1] + output = torch.stack(stack, 1) + output = output.permute(0, 2, 1, 3) + output = output.permute(0, 3, 1, 2) + return output.contiguous() + + +class ListModule(object): + def __init__(self, module, prefix, *args): + self.module = module + self.prefix = prefix + self.num_module = 0 + for new_module in args: + self.append(new_module) + + def append(self, new_module): + if not isinstance(new_module, nn.Module): + raise ValueError('Not a Module') + else: + self.module.add_module(self.prefix + 
str(self.num_module), new_module)
+            self.num_module += 1
+
+    def __len__(self):
+        return self.num_module
+
+    def __getitem__(self, i):
+        if i < 0 or i >= self.num_module:
+            raise IndexError('Out of bound')
+        return getattr(self.module, self.prefix + str(i))
+
+
+def get_all_params(var, all_params):
+    if isinstance(var, Parameter):
+        all_params[id(var)] = var.nelement()
+    elif hasattr(var, "creator") and var.creator is not None:
+        if var.creator.previous_functions is not None:
+            for j in var.creator.previous_functions:
+                get_all_params(j[0], all_params)
+    elif hasattr(var, "previous_functions"):
+        for j in var.previous_functions:
+            get_all_params(j[0], all_params)
diff --git a/src/inversefed/optimization_strategy.py b/src/inversefed/optimization_strategy.py
new file mode 100644
index 00000000..2c73483d
--- /dev/null
+++ b/src/inversefed/optimization_strategy.py
@@ -0,0 +1,78 @@
+"""Optimization setups."""
+
+from dataclasses import dataclass
+
+
+def training_strategy(strategy, lr=None, epochs=None, dryrun=False):
+    """Parse training strategy."""
+    if strategy == 'conservative':
+        defs = ConservativeStrategy(lr, epochs, dryrun)
+    elif strategy == 'adam':
+        defs = AdamStrategy(lr, epochs, dryrun)
+    else:
+        raise ValueError('Unknown training strategy.')
+    return defs
+
+
+@dataclass
+class Strategy:
+    """Default usual parameters, not intended for parsing."""
+
+    epochs : int
+    batch_size : int
+    optimizer : str
+    lr : float
+    scheduler : str
+    weight_decay : float
+    validate : int
+    warmup: bool
+    dryrun : bool
+    dropout : float
+    augmentations : bool
+
+    def __init__(self, lr=None, epochs=None, dryrun=False):
+        """Defaulted parameters. Apply overwrites from args."""
+        if epochs is not None:
+            self.epochs = epochs
+        if lr is not None:
+            self.lr = lr
+        if dryrun:
+            self.dryrun = dryrun
+        self.validate = 10
+
+@dataclass
+class ConservativeStrategy(Strategy):
+    """Default usual parameters, defines a config object."""
+
+    def __init__(self, lr=None, epochs=None, dryrun=False):
+        """Initialize training hyperparameters."""
+        self.lr = 0.1
+        self.epochs = 120
+        self.batch_size = 128
+        self.optimizer = 'SGD'
+        self.scheduler = 'linear'
+        self.warmup = False
+        self.weight_decay : float = 5e-4
+        self.dropout = 0.0
+        self.augmentations = True
+        self.dryrun = False
+        super().__init__(lr=lr, epochs=epochs, dryrun=dryrun)
+
+
+@dataclass
+class AdamStrategy(Strategy):
+    """Start slowly. Use a tame Adam."""
+
+    def __init__(self, lr=None, epochs=None, dryrun=False):
+        """Initialize training hyperparameters."""
+        self.lr = 1e-3 / 10
+        self.epochs = 120
+        self.batch_size = 32
+        self.optimizer = 'AdamW'
+        self.scheduler = 'linear'
+        self.warmup = True
+        self.weight_decay : float = 5e-4
+        self.dropout = 0.0
+        self.augmentations = True
+        self.dryrun = False
+        super().__init__(lr=lr, epochs=epochs, dryrun=dryrun)
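With the arguments forwarded in the super().__init__ calls above, overrides passed to training_strategy actually take effect; a quick sketch:

    defs = training_strategy('conservative', lr=0.05, epochs=10)
    print(defs.lr, defs.epochs, defs.optimizer)  # 0.05 10 SGD
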
diff --git a/src/inversefed/options.py b/src/inversefed/options.py
new file mode 100644
index 00000000..b8bc3ebe
--- /dev/null
+++ b/src/inversefed/options.py
@@ -0,0 +1,52 @@
+"""Parser options."""
+
+import argparse
+
+def options():
+    """Construct the central argument parser, filled with useful defaults."""
+    parser = argparse.ArgumentParser(description='Reconstruct some image from a trained model.')
+
+    # Central:
+    parser.add_argument('--model', default='ConvNet', type=str, help='Vision model.')
+    parser.add_argument('--dataset', default='CIFAR10', type=str)
+    parser.add_argument('--dtype', default='float', type=str, help='Data type used during reconstruction [Not during training!].')
+
+
+    parser.add_argument('--trained_model', action='store_true', help='Use a trained model.')
+    parser.add_argument('--epochs', default=120, type=int, help='If using a trained model, how many epochs was it trained?')
+
+    parser.add_argument('--accumulation', default=0, type=int, help='Accumulation 0 is rec. from gradient, accumulation > 0 is reconstruction from fed. averaging.')
+    parser.add_argument('--num_images', default=1, type=int, help='How many images should be recovered from the given gradient.')
+    parser.add_argument('--target_id', default=None, type=int, help='Cifar validation image used for reconstruction.')
+    parser.add_argument('--label_flip', action='store_true', help='Dishonest server permuting weights in classification layer.')
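+    # NB: the --signed and --boxed flags below use action='store_false', so both options
+    # default to True and passing the flag turns the corresponding constraint off.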
+    # Rec. parameters
+    parser.add_argument('--optim', default='ours', type=str, help='Use our reconstruction method or the DLG method.')
+
+    parser.add_argument('--restarts', default=1, type=int, help='How many restarts to run.')
+    parser.add_argument('--cost_fn', default='sim', type=str, help='Choice of cost function.')
+    parser.add_argument('--indices', default='def', type=str, help='Choice of indices from the parameter list.')
+    parser.add_argument('--weights', default='equal', type=str, help='Weigh the parameter list differently.')
+
+    parser.add_argument('--optimizer', default='adam', type=str, help='Optimizer used for the reconstruction.')
+    parser.add_argument('--signed', action='store_false', help='Do not use signed gradients.')
+    parser.add_argument('--boxed', action='store_false', help='Do not use box constraints.')
+
+    parser.add_argument('--scoring_choice', default='loss', type=str, help='How to find the best image between all restarts.')
+    parser.add_argument('--init', default='randn', type=str, help='Choice of image initialization.')
+    parser.add_argument('--tv', default=1e-4, type=float, help='Weight of TV penalty.')
+
+
+    # Files and folders:
+    parser.add_argument('--save_image', action='store_true', help='Save the output to a file.')
+
+    parser.add_argument('--image_path', default='images/', type=str)
+    parser.add_argument('--model_path', default='models/', type=str)
+    parser.add_argument('--table_path', default='tables/', type=str)
+    parser.add_argument('--data_path', default='~/data', type=str)
+
+    # Debugging:
+    parser.add_argument('--name', default='iv', type=str, help='Name tag for the result table and model.')
+    parser.add_argument('--deterministic', action='store_true', help='Disable CUDNN non-determinism.')
+    parser.add_argument('--dryrun', action='store_true', help='Run everything for just one step to test functionality.')
+    return parser
diff --git a/src/inversefed/reconstruction_algorithms.py b/src/inversefed/reconstruction_algorithms.py
new file mode 100644
index 00000000..63965d85
--- /dev/null
+++ b/src/inversefed/reconstruction_algorithms.py
@@ -0,0 +1,392 @@
+"""Mechanisms for image reconstruction from parameter gradients."""
+
+import torch
+from collections import defaultdict, OrderedDict
+from inversefed.nn import MetaMonkey
+from .metrics import total_variation as TV
+from .metrics import InceptionScore
+from .medianfilt import MedianPool2d
+from copy import deepcopy
+
+import time
+
+DEFAULT_CONFIG = dict(signed=False,
+                      boxed=True,
+                      cost_fn='sim',
+                      indices='def',
+                      weights='equal',
+                      lr=0.1,
+                      optim='adam',
+                      restarts=1,
+                      max_iterations=4800,
+                      total_variation=1e-1,
+                      init='randn',
+                      filter='none',
+                      lr_decay=True,
+                      scoring_choice='loss')
+
+def _label_to_onehot(target, num_classes=100):
+    target = torch.unsqueeze(target, 1)
+    onehot_target = torch.zeros(target.size(0), num_classes, device=target.device)
+    onehot_target.scatter_(1, target, 1)
+    return onehot_target
+
+def _validate_config(config):
+    for key in DEFAULT_CONFIG.keys():
+        if config.get(key) is None:
+            config[key] = DEFAULT_CONFIG[key]
+    for key in config.keys():
+        if DEFAULT_CONFIG.get(key) is None:
+            raise ValueError(f'Unknown key in config dict: {key}!')
+    return config
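To see how the pieces fit together, a minimal end-to-end sketch driving the class defined below (a toy linear classifier; the mean_std values and the lowered max_iterations are illustrative assumptions):

    import torch
    torch.manual_seed(0)
    model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 32 * 32, 10))
    x, y = torch.rand(1, 3, 32, 32), torch.tensor([3])
    loss = torch.nn.CrossEntropyLoss()(model(x), y)
    input_gradient = [g.detach() for g in torch.autograd.grad(loss, model.parameters())]
    config = dict(DEFAULT_CONFIG, max_iterations=200)
    rec = GradientReconstructor(model, mean_std=(torch.tensor(0.0), torch.tensor(1.0)),
                                config=config, num_images=1)
    output, stats = rec.reconstruct(input_gradient, labels=y, img_shape=(3, 32, 32))
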
+class GradientReconstructor():
+    """Instantiate a reconstruction algorithm."""
+
+    def __init__(self, model, mean_std=(0.0, 1.0), config=DEFAULT_CONFIG, num_images=1):
+        """Initialize with algorithm setup."""
+        self.config = _validate_config(config)
+        self.model = model
+        self.setup = dict(device=next(model.parameters()).device, dtype=next(model.parameters()).dtype)
+
+        self.mean_std = mean_std
+        self.num_images = num_images
+
+        if self.config['scoring_choice'] == 'inception':
+            self.inception = InceptionScore(batch_size=1, setup=self.setup)
+
+        self.loss_fn = torch.nn.CrossEntropyLoss(reduction='mean')
+        self.iDLG = True
+
+    def reconstruct(self, input_data, labels, img_shape=(3, 32, 32), dryrun=False, eval=True, tol=None):
+        """Reconstruct image from gradient."""
+        start_time = time.time()
+        if eval:
+            self.model.eval()
+
+        stats = defaultdict(list)
+        x = self._init_images(img_shape)
+        scores = torch.zeros(self.config['restarts'])
+
+        if labels is None:
+            if self.num_images == 1 and self.iDLG:
+                # iDLG trick: with cross-entropy, the last linear layer's weight gradient is
+                # (softmax - onehot) times the (typically nonnegative) features, so only the
+                # true class row sums to a negative value.
+                last_weight_min = torch.argmin(torch.sum(input_data[-2], dim=-1), dim=-1)
+                labels = last_weight_min.detach().reshape((1,)).requires_grad_(False)
+                self.reconstruct_label = False
+            else:
+                # DLG label recovery
+                # However this also improves conditioning for some LBFGS cases
+                self.reconstruct_label = True
+
+                def loss_fn(pred, labels):
+                    labels = torch.nn.functional.softmax(labels, dim=-1)
+                    return torch.mean(torch.sum(- labels * torch.nn.functional.log_softmax(pred, dim=-1), 1))
+                self.loss_fn = loss_fn
+        else:
+            assert labels.shape[0] == self.num_images
+            self.reconstruct_label = False
+
+        try:
+            for trial in range(self.config['restarts']):
+                x_trial, labels = self._run_trial(x[trial], input_data, labels, dryrun=dryrun)
+                # Finalize
+                scores[trial] = self._score_trial(x_trial, input_data, labels)
+                x[trial] = x_trial
+                if tol is not None and scores[trial] <= tol:
+                    break
+                if dryrun:
+                    break
+        except KeyboardInterrupt:
+            print('Trial procedure manually interrupted.')
+            pass
+
+        # Choose optimal result:
+        if self.config['scoring_choice'] in ['pixelmean', 'pixelmedian']:
+            x_optimal, stats = self._average_trials(x, labels, input_data, stats)
+        else:
+            print('Choosing optimal result ...')
+            scores[~torch.isfinite(scores)] = float('inf')  # guard against NaN/-Inf scores without misaligning trial indices
+ optimal_index = torch.argmin(scores) + print(f'Optimal result score: {scores[optimal_index]:2.4f}') + stats['opt'] = scores[optimal_index].item() + x_optimal = x[optimal_index] + + print(f'Total time: {time.time()-start_time}.') + return x_optimal.detach(), stats + + def _init_images(self, img_shape): + if self.config['init'] == 'randn': + return torch.randn((self.config['restarts'], self.num_images, *img_shape), **self.setup) + elif self.config['init'] == 'rand': + return (torch.rand((self.config['restarts'], self.num_images, *img_shape), **self.setup) - 0.5) * 2 + elif self.config['init'] == 'zeros': + return torch.zeros((self.config['restarts'], self.num_images, *img_shape), **self.setup) + else: + raise ValueError() + + def _run_trial(self, x_trial, input_data, labels, dryrun=False): + x_trial.requires_grad = True + if self.reconstruct_label: + output_test = self.model(x_trial) + labels = torch.randn(output_test.shape[1]).to(**self.setup).requires_grad_(True) + + if self.config['optim'] == 'adam': + optimizer = torch.optim.Adam([x_trial, labels], lr=self.config['lr']) + elif self.config['optim'] == 'sgd': # actually gd + optimizer = torch.optim.SGD([x_trial, labels], lr=0.01, momentum=0.9, nesterov=True) + elif self.config['optim'] == 'LBFGS': + optimizer = torch.optim.LBFGS([x_trial, labels]) + else: + raise ValueError() + else: + if self.config['optim'] == 'adam': + optimizer = torch.optim.Adam([x_trial], lr=self.config['lr']) + elif self.config['optim'] == 'sgd': # actually gd + optimizer = torch.optim.SGD([x_trial], lr=0.01, momentum=0.9, nesterov=True) + elif self.config['optim'] == 'LBFGS': + optimizer = torch.optim.LBFGS([x_trial]) + else: + raise ValueError() + + max_iterations = self.config['max_iterations'] + dm, ds = self.mean_std + if self.config['lr_decay']: + scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, + milestones=[max_iterations // 2.667, max_iterations // 1.6, + + max_iterations // 1.142], gamma=0.1) # 3/8 5/8 7/8 + try: + for iteration in range(max_iterations): + closure = self._gradient_closure(optimizer, x_trial, input_data, labels) + rec_loss = optimizer.step(closure) + if self.config['lr_decay']: + scheduler.step() + + with torch.no_grad(): + # Project into image space + if self.config['boxed']: + x_trial.data = torch.max(torch.min(x_trial, (1 - dm) / ds), -dm / ds) + + if (iteration + 1 == max_iterations) or iteration % 500 == 0: + print(f'It: {iteration}. Rec. 
loss: {rec_loss.item():2.4f}.') + + if (iteration + 1) % 500 == 0: + if self.config['filter'] == 'none': + pass + elif self.config['filter'] == 'median': + x_trial.data = MedianPool2d(kernel_size=3, stride=1, padding=1, same=False)(x_trial) + else: + raise ValueError() + + if dryrun: + break + except KeyboardInterrupt: + print(f'Recovery interrupted manually in iteration {iteration}!') + pass + return x_trial.detach(), labels + + def _gradient_closure(self, optimizer, x_trial, input_gradient, label): + + def closure(): + optimizer.zero_grad() + self.model.zero_grad() + loss = self.loss_fn(self.model(x_trial), label) + gradient = torch.autograd.grad(loss, self.model.parameters(), create_graph=True) + rec_loss = reconstruction_costs([gradient], input_gradient, + cost_fn=self.config['cost_fn'], indices=self.config['indices'], + weights=self.config['weights']) + + if self.config['total_variation'] > 0: + rec_loss += self.config['total_variation'] * TV(x_trial) + rec_loss.backward() + if self.config['signed']: + x_trial.grad.sign_() + return rec_loss + return closure + + def _score_trial(self, x_trial, input_gradient, label): + if self.config['scoring_choice'] == 'loss': + self.model.zero_grad() + x_trial.grad = None + loss = self.loss_fn(self.model(x_trial), label) + gradient = torch.autograd.grad(loss, self.model.parameters(), create_graph=False) + return reconstruction_costs([gradient], input_gradient, + cost_fn=self.config['cost_fn'], indices=self.config['indices'], + weights=self.config['weights']) + elif self.config['scoring_choice'] == 'tv': + return TV(x_trial) + elif self.config['scoring_choice'] == 'inception': + # We do not care about diversity here! + return self.inception(x_trial) + elif self.config['scoring_choice'] in ['pixelmean', 'pixelmedian']: + return 0.0 + else: + raise ValueError() + + def _average_trials(self, x, labels, input_data, stats): + print(f'Computing a combined result via {self.config["scoring_choice"]} ...') + if self.config['scoring_choice'] == 'pixelmedian': + x_optimal, _ = x.median(dim=0, keepdims=False) + elif self.config['scoring_choice'] == 'pixelmean': + x_optimal = x.mean(dim=0, keepdims=False) + + self.model.zero_grad() + if self.reconstruct_label: + labels = self.model(x_optimal).softmax(dim=1) + loss = self.loss_fn(self.model(x_optimal), labels) + gradient = torch.autograd.grad(loss, self.model.parameters(), create_graph=False) + stats['opt'] = reconstruction_costs([gradient], input_data, + cost_fn=self.config['cost_fn'], + indices=self.config['indices'], + weights=self.config['weights']) + print(f'Optimal result score: {stats["opt"]:2.4f}') + return x_optimal, stats + + + +class FedAvgReconstructor(GradientReconstructor): + """Reconstruct an image from weights after n gradient descent steps.""" + + def __init__(self, model, mean_std=(0.0, 1.0), local_steps=2, local_lr=1e-4, + config=DEFAULT_CONFIG, num_images=1, use_updates=True, batch_size=0): + """Initialize with model, (mean, std) and config.""" + super().__init__(model, mean_std, config, num_images) + self.local_steps = local_steps + self.local_lr = local_lr + self.use_updates = use_updates + self.batch_size = batch_size + + def _gradient_closure(self, optimizer, x_trial, input_parameters, labels): + def closure(): + optimizer.zero_grad() + self.model.zero_grad() + parameters = loss_steps(self.model, x_trial, labels, loss_fn=self.loss_fn, + local_steps=self.local_steps, lr=self.local_lr, + use_updates=self.use_updates, + batch_size=self.batch_size) + rec_loss = 
reconstruction_costs([parameters], input_parameters, + cost_fn=self.config['cost_fn'], indices=self.config['indices'], + weights=self.config['weights']) + + if self.config['total_variation'] > 0: + rec_loss += self.config['total_variation'] * TV(x_trial) + rec_loss.backward() + if self.config['signed']: + x_trial.grad.sign_() + return rec_loss + return closure + + def _score_trial(self, x_trial, input_parameters, labels): + if self.config['scoring_choice'] == 'loss': + self.model.zero_grad() + parameters = loss_steps(self.model, x_trial, labels, loss_fn=self.loss_fn, + local_steps=self.local_steps, lr=self.local_lr, use_updates=self.use_updates) + return reconstruction_costs([parameters], input_parameters, + cost_fn=self.config['cost_fn'], indices=self.config['indices'], + weights=self.config['weights']) + elif self.config['scoring_choice'] == 'tv': + return TV(x_trial) + elif self.config['scoring_choice'] == 'inception': + # We do not care about diversity here! + return self.inception(x_trial) + + +def loss_steps(model, inputs, labels, loss_fn=torch.nn.CrossEntropyLoss(), lr=1e-4, local_steps=4, use_updates=True, batch_size=0): + """Take a few gradient descent steps to fit the model to the given input.""" + patched_model = MetaMonkey(model) + if use_updates: + patched_model_origin = deepcopy(patched_model) + for i in range(local_steps): + if batch_size == 0: + outputs = patched_model(inputs, patched_model.parameters) + labels_ = labels + else: + idx = i % (inputs.shape[0] // batch_size) + outputs = patched_model(inputs[idx * batch_size:(idx + 1) * batch_size], patched_model.parameters) + labels_ = labels[idx * batch_size:(idx + 1) * batch_size] + loss = loss_fn(outputs, labels_).sum() + grad = torch.autograd.grad(loss, patched_model.parameters.values(), + retain_graph=True, create_graph=True, only_inputs=True) + + patched_model.parameters = OrderedDict((name, param - lr * grad_part) + for ((name, param), grad_part) + in zip(patched_model.parameters.items(), grad)) + + if use_updates: + patched_model.parameters = OrderedDict((name, param - param_origin) + for ((name, param), (name_origin, param_origin)) + in zip(patched_model.parameters.items(), patched_model_origin.parameters.items())) + return list(patched_model.parameters.values()) + + +def reconstruction_costs(gradients, input_gradient, cost_fn='l2', indices='def', weights='equal'): + """Input gradient is given data.""" + if isinstance(indices, list): + pass + elif indices == 'def': + indices = torch.arange(len(input_gradient)) + elif indices == 'batch': + indices = torch.randperm(len(input_gradient))[:8] + elif indices == 'topk-1': + _, indices = torch.topk(torch.stack([p.norm() for p in input_gradient], dim=0), 4) + elif indices == 'top10': + _, indices = torch.topk(torch.stack([p.norm() for p in input_gradient], dim=0), 10) + elif indices == 'top50': + _, indices = torch.topk(torch.stack([p.norm() for p in input_gradient], dim=0), 50) + elif indices in ['first', 'first4']: + indices = torch.arange(0, 4) + elif indices == 'first5': + indices = torch.arange(0, 5) + elif indices == 'first10': + indices = torch.arange(0, 10) + elif indices == 'first50': + indices = torch.arange(0, 50) + elif indices == 'last5': + indices = torch.arange(len(input_gradient))[-5:] + elif indices == 'last10': + indices = torch.arange(len(input_gradient))[-10:] + elif indices == 'last50': + indices = torch.arange(len(input_gradient))[-50:] + else: + raise ValueError() + + ex = input_gradient[0] + if weights == 'linear': + weights = 
torch.arange(len(input_gradient), 0, -1, dtype=ex.dtype, device=ex.device) / len(input_gradient) + elif weights == 'exp': + weights = torch.arange(len(input_gradient), 0, -1, dtype=ex.dtype, device=ex.device) + weights = weights.softmax(dim=0) + weights = weights / weights[0] + else: + weights = input_gradient[0].new_ones(len(input_gradient)) + + total_costs = 0 + for trial_gradient in gradients: + pnorm = [0, 0] + costs = 0 + if indices == 'topk-2': + _, indices = torch.topk(torch.stack([p.norm().detach() for p in trial_gradient], dim=0), 4) + for i in indices: + if cost_fn == 'l2': + costs += ((trial_gradient[i] - input_gradient[i]).pow(2)).sum() * weights[i] + elif cost_fn == 'l1': + costs += ((trial_gradient[i] - input_gradient[i]).abs()).sum() * weights[i] + elif cost_fn == 'max': + costs += ((trial_gradient[i] - input_gradient[i]).abs()).max() * weights[i] + elif cost_fn == 'sim': + costs -= (trial_gradient[i] * input_gradient[i]).sum() * weights[i] + pnorm[0] += trial_gradient[i].pow(2).sum() * weights[i] + pnorm[1] += input_gradient[i].pow(2).sum() * weights[i] + elif cost_fn == 'simlocal': + costs += 1 - torch.nn.functional.cosine_similarity(trial_gradient[i].flatten(), + input_gradient[i].flatten(), + 0, 1e-10) * weights[i] + if cost_fn == 'sim': + costs = 1 + costs / pnorm[0].sqrt() / pnorm[1].sqrt() + + # Accumulate final costs + total_costs += costs + return total_costs / len(gradients) diff --git a/src/inversefed/training/README.md b/src/inversefed/training/README.md new file mode 100644 index 00000000..5897a788 --- /dev/null +++ b/src/inversefed/training/README.md @@ -0,0 +1 @@ +# Training routines are implemented here \ No newline at end of file diff --git a/src/inversefed/training/__init__.py b/src/inversefed/training/__init__.py new file mode 100644 index 00000000..d1456816 --- /dev/null +++ b/src/inversefed/training/__init__.py @@ -0,0 +1,5 @@ +"""Basic training routines and loss functions.""" + +from .training_routine import train + +__all__ = ['train'] diff --git a/src/inversefed/training/scheduler.py b/src/inversefed/training/scheduler.py new file mode 100644 index 00000000..e5c4a247 --- /dev/null +++ b/src/inversefed/training/scheduler.py @@ -0,0 +1,95 @@ +"""This file is part of https://github.com/ildoonet/pytorch-gradual-warmup-lr. + +MIT License + +Copyright (c) 2019 Ildoo Kim + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
+ +""" + +from torch.optim.lr_scheduler import _LRScheduler +from torch.optim.lr_scheduler import ReduceLROnPlateau + + +class GradualWarmupScheduler(_LRScheduler): + """Gradually warm-up(increasing) learning rate in optimizer. + + Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'. + + Args: + optimizer (Optimizer): Wrapped optimizer. + multiplier: target learning rate = base lr * multiplier + total_epoch: target learning rate is reached at total_epoch, gradually + after_scheduler: after target_epoch, use this scheduler(eg. ReduceLROnPlateau) + + """ + + def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None): + """Initialize the warm-up start. + + Usage: + + scheduler_normal = torch.optim.lr_scheduler.MultiStepLR(optimizer) + scheduler_warmup = GradualWarmupScheduler(optimizer, multiplier=8, total_epoch=10, after_scheduler=scheduler_normal) + """ + self.multiplier = multiplier + if self.multiplier < 1.: + raise ValueError('multiplier should be greater thant or equal to 1.') + self.total_epoch = total_epoch + self.after_scheduler = after_scheduler + self.finished = False + super().__init__(optimizer) + + def get_lr(self): + if self.last_epoch > self.total_epoch: + if self.after_scheduler: + if not self.finished: + self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs] + self.finished = True + return self.after_scheduler.get_lr() + return [base_lr * self.multiplier for base_lr in self.base_lrs] + + return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs] + + def step_ReduceLROnPlateau(self, metrics, epoch=None): + if epoch is None: + epoch = self.last_epoch + 1 + self.last_epoch = epoch if epoch != 0 else 1 # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning + if self.last_epoch <= self.total_epoch: + warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs] + for param_group, lr in zip(self.optimizer.param_groups, warmup_lr): + param_group['lr'] = lr + else: + if epoch is None: + self.after_scheduler.step(metrics, None) + else: + self.after_scheduler.step(metrics, epoch - self.total_epoch) + + def step(self, epoch=None, metrics=None): + if type(self.after_scheduler) != ReduceLROnPlateau: + if self.finished and self.after_scheduler: + if epoch is None: + self.after_scheduler.step(None) + else: + self.after_scheduler.step(epoch - self.total_epoch) + else: + return super(GradualWarmupScheduler, self).step(epoch) + else: + self.step_ReduceLROnPlateau(metrics, epoch) diff --git a/src/inversefed/training/training_routine.py b/src/inversefed/training/training_routine.py new file mode 100644 index 00000000..8e16db84 --- /dev/null +++ b/src/inversefed/training/training_routine.py @@ -0,0 +1,124 @@ +"""Implement the .train function.""" + +import torch +import numpy as np + +from collections import defaultdict + +from .scheduler import GradualWarmupScheduler + +from ..consts import BENCHMARK, NON_BLOCKING +torch.backends.cudnn.benchmark = BENCHMARK + +def train(model, loss_fn, trainloader, validloader, defs, setup=dict(dtype=torch.float, device=torch.device('cpu'))): + """Run the main interface. 
Train a network with specifications from the Strategy object.""" + stats = defaultdict(list) + optimizer, scheduler = set_optimizer(model, defs) + + for epoch in range(defs.epochs): + model.train() + step(model, loss_fn, trainloader, optimizer, scheduler, defs, setup, stats) + + if epoch % defs.validate == 0 or epoch == (defs.epochs - 1): + model.eval() + validate(model, loss_fn, validloader, defs, setup, stats) + # Print information about loss and accuracy + print_status(epoch, loss_fn, optimizer, stats) + + if defs.dryrun: + break + if not (np.isfinite(stats['train_losses'][-1])): + print('Loss is NaN/Inf ... terminating early ...') + break + + return stats + +def step(model, loss_fn, dataloader, optimizer, scheduler, defs, setup, stats): + """Step through one epoch.""" + epoch_loss, epoch_metric = 0, 0 + for batch, (inputs, targets) in enumerate(dataloader): + # Prep Mini-Batch + optimizer.zero_grad() + + # Transfer to GPU + inputs = inputs.to(**setup) + targets = targets.to(device=setup['device'], non_blocking=NON_BLOCKING) + + # Get loss + outputs = model(inputs) + loss, _, _ = loss_fn(outputs, targets) + + + epoch_loss += loss.item() + + loss.backward() + optimizer.step() + + metric, name, _ = loss_fn.metric(outputs, targets) + epoch_metric += metric.item() + + if defs.scheduler == 'cyclic': + scheduler.step() + if defs.dryrun: + break + if defs.scheduler == 'linear': + scheduler.step() + + stats['train_losses'].append(epoch_loss / (batch + 1)) + stats['train_' + name].append(epoch_metric / (batch + 1)) + + +def validate(model, loss_fn, dataloader, defs, setup, stats): + """Validate model effectiveness of val dataset.""" + epoch_loss, epoch_metric = 0, 0 + with torch.no_grad(): + for batch, (inputs, targets) in enumerate(dataloader): + # Transfer to GPU + inputs = inputs.to(**setup) + targets = targets.to(device=setup['device'], non_blocking=NON_BLOCKING) + + # Get loss and metric + outputs = model(inputs) + loss, _, _ = loss_fn(outputs, targets) + metric, name, _ = loss_fn.metric(outputs, targets) + + epoch_loss += loss.item() + epoch_metric += metric.item() + + if defs.dryrun: + break + + stats['valid_losses'].append(epoch_loss / (batch + 1)) + stats['valid_' + name].append(epoch_metric / (batch + 1)) + +def set_optimizer(model, defs): + """Build model optimizer and scheduler from defs. + + The linear scheduler drops the learning rate in intervals. + # Example: epochs=160 leads to drops at 60, 100, 140. + """ + if defs.optimizer == 'SGD': + optimizer = torch.optim.SGD(model.parameters(), lr=defs.lr, momentum=0.9, + weight_decay=defs.weight_decay, nesterov=True) + elif defs.optimizer == 'AdamW': + optimizer = torch.optim.AdamW(model.parameters(), lr=defs.lr, weight_decay=defs.weight_decay) + + if defs.scheduler == 'linear': + scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, + milestones=[120 // 2.667, 120 // 1.6, + 120 // 1.142], gamma=0.1) + # Scheduler is fixed to 120 epochs so that calls with fewer epochs are equal in lr drops. 
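+    # (Concretely: 120 // 2.667 == 44.0, 120 // 1.6 == 75.0, 120 // 1.142 == 105.0,
+    # i.e. drops at roughly 3/8, 5/8 and 7/8 of the fixed 120-epoch schedule.)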
+ + if defs.warmup: + scheduler = GradualWarmupScheduler(optimizer, multiplier=10, total_epoch=10, after_scheduler=scheduler) + + return optimizer, scheduler + + +def print_status(epoch, loss_fn, optimizer, stats): + """Print basic console printout every defs.validation epochs.""" + current_lr = optimizer.param_groups[0]['lr'] + name, format = loss_fn.metric() + print(f'Epoch: {epoch}| lr: {current_lr:.4f} | ' + f'Train loss is {stats["train_losses"][-1]:6.4f}, Train {name}: {stats["train_" + name][-1]:{format}} | ' + f'Val loss is {stats["valid_losses"][-1]:6.4f}, Val {name}: {stats["valid_" + name][-1]:{format}} |') diff --git a/src/inversefed/utils.py b/src/inversefed/utils.py new file mode 100644 index 00000000..cf9c2d9f --- /dev/null +++ b/src/inversefed/utils.py @@ -0,0 +1,70 @@ +"""Various utilities.""" + +import os +import csv + +import torch +import random +import numpy as np + +import socket +import datetime + + +def system_startup(args=None, defs=None): + """Print useful system information.""" + # Choose GPU device and print status information: + device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu') + setup = dict(device=device, dtype=torch.float) # non_blocking=NON_BLOCKING + print('Currently evaluating -------------------------------:') + print(datetime.datetime.now().strftime("%A, %d. %B %Y %I:%M%p")) + print(f'CPUs: {torch.get_num_threads()}, GPUs: {torch.cuda.device_count()} on {socket.gethostname()}.') + if args is not None: + print(args) + if defs is not None: + print(repr(defs)) + if torch.cuda.is_available(): + print(f'GPU : {torch.cuda.get_device_name(device=device)}') + return setup + +def save_to_table(out_dir, name, dryrun, **kwargs): + """Save keys to .csv files. Function adapted from Micah.""" + # Check for file + if not os.path.isdir(out_dir): + os.makedirs(out_dir) + fname = os.path.join(out_dir, f'table_{name}.csv') + fieldnames = list(kwargs.keys()) + + # Read or write header + try: + with open(fname, 'r') as f: + reader = csv.reader(f, delimiter='\t') + header = [line for line in reader][0] + except Exception as e: + print('Creating a new .csv table...') + with open(fname, 'w') as f: + writer = csv.DictWriter(f, delimiter='\t', fieldnames=fieldnames) + writer.writeheader() + if not dryrun: + # Add row for this experiment + with open(fname, 'a') as f: + writer = csv.DictWriter(f, delimiter='\t', fieldnames=fieldnames) + writer.writerow(kwargs) + print('\nResults saved to ' + fname + '.') + else: + print(f'Would save results to {fname}.') + print(f'Would save these keys: {fieldnames}.') + +def set_random_seed(seed=233): + """233 = 144 + 89 is my favorite number.""" + torch.manual_seed(seed + 1) + torch.cuda.manual_seed(seed + 2) + torch.cuda.manual_seed_all(seed + 3) + np.random.seed(seed + 4) + torch.cuda.manual_seed_all(seed + 5) + random.seed(seed + 6) + +def set_deterministic(): + """Switch pytorch into a deterministic computation mode.""" + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False diff --git a/src/utils/communication/comm_utils.py b/src/utils/communication/comm_utils.py index beb68431..c57d02d7 100644 --- a/src/utils/communication/comm_utils.py +++ b/src/utils/communication/comm_utils.py @@ -1,11 +1,12 @@ from enum import Enum from utils.communication.grpc.main import GRPCCommunication -from typing import Any, Dict, List, TYPE_CHECKING +from typing import Any, Dict, List, Tuple, TYPE_CHECKING # from utils.communication.mpi import MPICommUtils if TYPE_CHECKING: from algos.base_class 
diff --git a/src/utils/communication/comm_utils.py b/src/utils/communication/comm_utils.py
index beb68431..c57d02d7 100644
--- a/src/utils/communication/comm_utils.py
+++ b/src/utils/communication/comm_utils.py
@@ -1,11 +1,12 @@
 from enum import Enum
 from utils.communication.grpc.main import GRPCCommunication
-from typing import Any, Dict, List, TYPE_CHECKING
+from typing import Any, Dict, List, Tuple, TYPE_CHECKING
 
 # from utils.communication.mpi import MPICommUtils
 
 if TYPE_CHECKING:
     from algos.base_class import BaseNode
+import numpy as np
 
 
 class CommunicationType(Enum):
     MPI = 1
@@ -59,7 +60,7 @@ def send(self, dest: str | int | List[str | int], data: Any, tag: int = 0):
         else:
             print(f"Sending data to {dest}")
             self.comm.send(dest=int(dest), data=data)
-    
+
     def receive(self, node_ids: List[int]) -> Any:
         """
         Receive data from the specified node
diff --git a/src/utils/data_utils.py b/src/utils/data_utils.py
index 7889b8ff..28c3040d 100644
--- a/src/utils/data_utils.py
+++ b/src/utils/data_utils.py
@@ -365,3 +365,106 @@ def non_iid_balanced(
     clnt_y = np.asarray(clnt_y)
 
     return clnt_y, clnt_idx, cls_priors
+
+
+def gia_client_dataset(train_dataset, test_dataset, num_labels=10, n=1):
+    """
+    Select the first num_labels classes and n random images per class from both train and test datasets.
+
+    Args:
+        train_dataset: Training dataset object with __getitem__ returning (image, label) tuples
+        test_dataset: Test dataset object with __getitem__ returning (image, label) tuples
+        num_labels (int): Number of leading classes to keep (labels 0 through num_labels - 1)
+        n (int): Number of images to select per class (default is 1)
+
+    Returns:
+        filtered_train_dataset: Subset of training dataset with n images per selected label
+        filtered_test_dataset: Subset of test dataset with n images per selected label
+        selected_labels: List of selected label indices
+        train_indices: List of indices for the selected training images
+    """
+
+    def get_ordered_indices(dataset):
+        label_to_indices = {i: [] for i in range(num_labels)}
+        for idx in range(len(dataset)):
+            label = dataset[idx][1]
+            if label < num_labels:
+                label_to_indices[label].append(idx)
+
+        ordered_indices = []
+        selected_labels = []
+        for label in range(num_labels):
+            # Re-seed from OS entropy, then shuffle so each call picks a fresh subset
+            np.random.seed(None)
+            np.random.shuffle(label_to_indices[label])
+            for i in range(n):
+                if i < len(label_to_indices[label]):
+                    random_idx = label_to_indices[label][i]
+                    ordered_indices.append(random_idx)
+                    selected_labels.append(label)
+
+        return ordered_indices, selected_labels
+
+    # Get ordered indices and selected labels for both datasets
+    final_train_indices, train_selected_labels = get_ordered_indices(train_dataset)
+    final_test_indices, test_selected_labels = get_ordered_indices(test_dataset)
+
+    # Create the subsets
+    filtered_train_dataset = Subset(train_dataset, final_train_indices)
+    filtered_test_dataset = Subset(test_dataset, final_test_indices)
+
+    # Create selected_labels in ascending order
+    selected_labels = sorted(set(train_selected_labels))
+
+    # Verify ordering
+    for i in range(len(final_train_indices)):
+        assert filtered_train_dataset[i][1] == train_selected_labels[i], \
+            f"Train label at position {i} is not {train_selected_labels[i]}"
+    for i in range(len(final_test_indices)):
+        assert filtered_test_dataset[i][1] == test_selected_labels[i], \
+            f"Test label at position {i} is not {test_selected_labels[i]}"
+
+    return filtered_train_dataset, filtered_test_dataset, selected_labels, final_train_indices
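An illustrative use on CIFAR-10-shaped datasets (`train_dset` and `test_dset` are hypothetical torchvision-style datasets). With `num_labels=10` and `n=1` the result is one image per class, ordered by ascending label, which is exactly the ordering the asserts above verify:

```python
# One image per class, labels 0..9 in order; the reconstruction code later
# relies on this deterministic label ordering.
filtered_train, filtered_test, labels, train_idx = gia_client_dataset(
    train_dset, test_dset, num_labels=10, n=1)
assert [filtered_train[i][1] for i in range(len(labels))] == labels
```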
+
+
+def gia_server_testset(test_dataset, num_labels=10, num_images_per_label=4):
+    """
+    Select num_labels random classes and num_images_per_label random images per class from the test dataset.
+
+    Args:
+        test_dataset: Test dataset object with __getitem__ returning (image, label) tuples
+        num_labels (int): Number of unique labels to select
+        num_images_per_label (int): Number of images to select per label
+
+    Returns:
+        filtered_test_dataset: Subset of test dataset with num_images_per_label images per selected label
+        selected_labels: List of selected label indices
+        test_indices: List of indices for the selected test images
+    """
+    # Get all unique labels from the test dataset
+    all_labels = list(set([test_dataset[i][1] for i in range(len(test_dataset))]))
+
+    # Randomly select labels
+    selected_labels = sorted(np.random.choice(all_labels, size=num_labels, replace=False))
+
+    # Process test dataset
+    temp_test_dataset, test_all_indices = filter_by_class(test_dataset, selected_labels)
+    test_label_to_indices = {}
+    for idx in range(len(temp_test_dataset)):
+        label = temp_test_dataset[idx][1]
+        if label not in test_label_to_indices:
+            test_label_to_indices[label] = []
+        test_label_to_indices[label].append(test_all_indices[idx])
+
+    # Select num_images_per_label random images per label for the test dataset
+    final_test_indices = []
+    for label in selected_labels:
+        test_label_indices = test_label_to_indices[label]
+
+        # Ensure there are at least 'num_images_per_label' images per label
+        if len(test_label_indices) >= num_images_per_label:
+            selected_test_indices = np.random.choice(test_label_indices, size=num_images_per_label, replace=False)
+            final_test_indices.extend(selected_test_indices)
+        else:
+            raise ValueError(f"Not enough images in class {label} to select {num_images_per_label} images.")
+
+    # Create final dataset with exactly num_images_per_label images per label
+    filtered_test_dataset = Subset(test_dataset, final_test_indices)
+
+    return filtered_test_dataset, selected_labels, final_test_indices
\ No newline at end of file
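A short sketch of the server-side counterpart: building a balanced probe set and wrapping it in a loader. `test_dset` is again a hypothetical torchvision-style dataset:

```python
# Balanced probe set: 10 classes x 4 images, served in a single batch.
from torch.utils.data import DataLoader

probe_set, probe_labels, probe_idx = gia_server_testset(
    test_dset, num_labels=10, num_images_per_label=4)
probe_loader = DataLoader(probe_set, batch_size=40, shuffle=False)
```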
diff --git a/src/utils/gias.py b/src/utils/gias.py
new file mode 100644
index 00000000..7b914f57
--- /dev/null
+++ b/src/utils/gias.py
@@ -0,0 +1,128 @@
+# /////////////// Gradient Inversion Helpers ///////////////
+
+import inversefed
+import matplotlib.pyplot as plt
+
+import torch
+import pickle
+
+# based on InvertingGradients code by Jonas Geiping
+# code found in https://github.com/JonasGeiping/invertinggradients/tree/1157b61c6704df42c497ab9eb074c75da5204334
+
+
+def compute_param_delta(param_s, param_t, basic_params):
+    """
+    Generates the input for reconstruction: the per-parameter update (param_t - param_s).
+    Assumes param_s and param_t come from the same client.
+
+    basic_params: list of parameter names present in the model
+    """
+    assert len(param_s) != 0 and len(param_t) != 0, "Empty parameters"
+    # NOTE: assumes a CUDA device is available; deltas are materialized on the GPU.
+    return [(param_t[name].to("cuda") - param_s[name].to("cuda")).detach()
+            for name in basic_params if name in param_s and name in param_t]
+
+
+def reconstruct_gradient(param_diff, target_labels, target_images, lr, local_steps, model, client_id=0):
+    """
+    Reconstructs the client's training images from a parameter update,
+    following the Geiping et al. InvertingGradients technique.
+    """
+    print("length of param diff: ", len(param_diff))
+    with open(f"param_diff_{client_id}.pkl", "wb") as f:
+        pickle.dump(param_diff, f)
+    setup = inversefed.utils.system_startup()
+    for p in range(len(param_diff)):
+        param_diff[p] = param_diff[p].to(setup['device'])
+    target_labels = target_labels.to(setup['device'])
+    target_images = target_images.to(setup['device'])
+
+    # CIFAR-10 normalization statistics
+    mean, std = [0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]
+    dm = torch.as_tensor(mean, **setup)[:, None, None]
+    ds = torch.as_tensor(std, **setup)[:, None, None]
+    model = model.to(setup['device'])
+    config = dict(signed=True,
+                  boxed=True,
+                  cost_fn='sim',
+                  indices='def',
+                  weights='equal',
+                  lr=0.1,
+                  optim='adam',
+                  restarts=1,
+                  max_iterations=8_000,
+                  total_variation=1e-6,
+                  init='randn',
+                  filter='none',
+                  lr_decay=True,
+                  scoring_choice='loss')
+
+    # Sanity check: tied to the number of named parameters in the specific
+    # model architecture used in these experiments.
+    assert len(param_diff) == 38
+
+    rec_machine = inversefed.FedAvgReconstructor(model, (dm, ds), local_steps, lr, config,
+                                                 use_updates=True, num_images=len(target_labels))
+
+    output, stats = rec_machine.reconstruct(param_diff, target_labels, img_shape=(3, 32, 32))
+
+    # Compute reconstruction accuracy
+    test_mse = (output.detach() - target_images).pow(2).mean()
+    feat_mse = (model(output.detach()) - model(target_images)).pow(2).mean()
+    test_psnr = inversefed.metrics.psnr(output, target_images, factor=1 / ds)
+    print(f"Client {client_id} Test MSE: {test_mse:.2e}, Test PSNR: {test_psnr:.2f}, Feature MSE: {feat_mse:.2e}")
+
+    grid_plot(output, target_labels, ds, dm, stats, test_mse, feat_mse, test_psnr,
+              save_path=f"gias_output_client_{client_id}.png")
+    return output, stats
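In miniature, this is what `compute_param_delta` consumes: two snapshots of the same client's weights, taken before (`param_s`) and after (`param_t`) local training. The list of `theta_t - theta_s` tensors is what the reconstructor inverts. A sketch, assuming `model` is any `torch.nn.Module` and a CUDA device is available (matching the hardcoded `.to("cuda")` above):

```python
# Snapshot, train locally, snapshot again, then diff the named parameters.
param_s = {k: v.detach().clone() for k, v in model.state_dict().items()}
# ... client performs its local training steps here ...
param_t = {k: v.detach().clone() for k, v in model.state_dict().items()}
base_params = [name for name, _ in model.named_parameters()]  # excludes buffers
delta = compute_param_delta(param_s, param_t, base_params)
```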
+
+
+def grid_plot(tensor, labels, ds, dm, stats, test_mse, feat_mse, test_psnr, save_path=None):
+    """Plot a row of reconstructed images with their labels and quality metrics."""
+    tensor = tensor.clone().detach()
+    tensor.mul_(ds).add_(dm).clamp_(0, 1)
+
+    # Hardcoded to a single row of 10 panels; zip drops any images beyond that.
+    fig, axes = plt.subplots(1, 10, figsize=(24, 24))
+    for im, l, ax in zip(tensor, labels, axes.flatten()):
+        ax.imshow(im.permute(1, 2, 0).cpu())
+        ax.set_title(l)
+        ax.axis('off')
+    plt.title(f"Rec. loss: {stats['opt']:2.4f} | MSE: {test_mse:2.4f} "
+              f"| PSNR: {test_psnr:4.2f} | FMSE: {feat_mse:2.4e} |")
+
+    if save_path:
+        plt.savefig(save_path, bbox_inches='tight')
+
+
+def gia_main(param_s, param_t, base_params, model, target_labels, target_images, client_id):
+    """
+    Main entry point for the Gradient Inversion Attack.
+    Returns results moved back to their original devices.
+    """
+    # Store original devices
+    model_device = next(model.parameters()).device
+    target_labels_device = target_labels.device
+    target_images_device = target_images.device
+
+    # Store original parameter devices
+    param_s_devices = {name: param_s[name].device for name in base_params if name in param_s}
+    param_t_devices = {name: param_t[name].device for name in base_params if name in param_t}
+
+    param_diff = compute_param_delta(param_s, param_t, base_params)
+
+    # Check if all elements in param_diff are zero tensors
+    if all((diff == 0).all() for diff in param_diff):
+        print("Parameter differences contain only zeros for client ", client_id)
+        return None
+
+    output, stats = reconstruct_gradient(param_diff, target_labels, target_images, 3e-4, 1, model, client_id)
+
+    # Move output back to the target_images device (it is a reconstruction of the images)
+    if output is not None:
+        output = output.to(target_images_device)
+
+    # Move model back to its original device
+    model.to(model_device)
+
+    # Move parameters back to their original devices
+    for name in base_params:
+        if name in param_s:
+            param_s[name] = param_s[name].to(param_s_devices[name])
+        if name in param_t:
+            param_t[name] = param_t[name].to(param_t_devices[name])
+
+    # Move labels and images back to their original devices
+    target_labels = target_labels.to(target_labels_device)
+    target_images = target_images.to(target_images_device)
+
+    return output, stats
\ No newline at end of file
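One plausible way an attacker node could wire `gia_main` into its receive loop, reusing the bookkeeping fields initialized in `base_class` (`neighbor_updates`, `attacked_neighbors`, `base_params`) and the `images`/`labels` keys streamed with each model. This is a sketch of the trigger logic, not the repo's exact algorithm:

```python
# Once two updates from the same neighbor have been observed, invert the
# difference between them; mark the neighbor so it is attacked only once.
for neighbor_id, updates in self.neighbor_updates.items():
    if len(updates) >= 2 and neighbor_id not in self.attacked_neighbors:
        p_s, p_t = updates[-2]["model"], updates[-1]["model"]
        result = gia_main(p_s, p_t, self.base_params, self.model,
                          updates[-1]["labels"], updates[-1]["images"], neighbor_id)
        if result is not None:
            self.attacked_neighbors.add(neighbor_id)
```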
+ """ + # Move data and target to CPU if they are on a GPU, and detach from the computation graph + data = data.cpu().detach() + target = target.cpu().detach() + + # Normalize and clamp the data to the valid range [0, 1] + data = data.mul(ds).add(dm) + data.clamp_(0, 1) + + # Set up grid size for plotting (e.g., 2 rows of 5 images if batch size is 10) + batch_size = data.size(0) + rows = 1 + cols = batch_size // rows if batch_size % rows == 0 else batch_size + + fig, axes = plt.subplots(rows, cols, figsize=(12, 6)) + axes = axes.flatten() + + # Loop over each image and label in the batch + for i in range(batch_size): + # Convert image to numpy format for plotting + img = data[i].permute(1, 2, 0).numpy() + + # Plot the image and label + axes[i].imshow(img) + axes[i].set_title(f"Label: {target[i].item()}") + axes[i].axis("off") + + plt.tight_layout() + + log_lab = "base" if not label else label + + plt.savefig(f"{self.log_dir}/{node_id}_{log_lab}.png") + plt.close() + + # Log images to TensorBoard + grid_img = make_grid(data, normalize=True, scale_each=True) + self.writer.add_image(f"gia_images_node_{node_id}_{log_lab}", grid_img.numpy(), node_id) + self.writer.add_text(f"gia_labels_node_{node_id}_{log_lab}", str(target.tolist()), node_id) + def log_console(self, msg: str): """ Log a message to the console. diff --git a/src/utils/model_utils.py b/src/utils/model_utils.py index fbcabfd6..3e8b50e5 100644 --- a/src/utils/model_utils.py +++ b/src/utils/model_utils.py @@ -13,6 +13,7 @@ import yolo from utils.types import ConfigType +from inversefed.reconstruction_algorithms import loss_steps class ModelUtils: def __init__(self, device: torch.device, config: ConfigType) -> None: @@ -32,7 +33,6 @@ def get_model( model_name: str, dset: str, device: torch.device, - device_ids: List[int], pretrained: bool = False, **kwargs: Any, ): @@ -181,6 +181,7 @@ def train_classification( test_loader: DataLoader[Any] | None = None, **kwargs: Any, ) -> Tuple[float, float]: + model.train() correct = 0 train_loss = 0 for batch_idx, (data, target) in enumerate(dloader): @@ -193,24 +194,71 @@ def train_classification( output = model(data, position=position) if kwargs.get("apply_softmax", False): + print("here, applying softmax") output = nn.functional.log_softmax(output, dim=1) # type: ignore + if kwargs.get("gia", False): + # Sum the loss and create gradient graph like in loss_steps + # Use modified loss_steps function that returns loss + model.eval() + param_updates = loss_steps( + model, + data, + target, + loss_fn=loss_fn, + lr=3e-4, + local_steps=1, + use_updates=True, # Must be True to get parameter differences + batch_size=10 + ) + + # save parameter update for sanity check + # with open(f"param_updates_{node_id}.pkl", "wb") as f: + # pickle.dump(param_updates, f) + model.train() + + # Apply the updates to the model parameters + with torch.no_grad(): + for param, update in zip(model.parameters(), param_updates): + param.data.add_(update) # Directly add the update differences + + # Compute loss for tracking (without gradients since we've already updated) + with torch.no_grad(): + position = kwargs.get("position", 0) + output = model(data, position=position) + if kwargs.get("apply_softmax", False): + output = nn.functional.log_softmax(output, dim=1) + loss = loss_fn(output, target) + train_loss += loss.item() + + else: + # Standard training procedure + optim.zero_grad() + position = kwargs.get("position", 0) + output = model(data, position=position) + + if kwargs.get("apply_softmax", False): + output = 
diff --git a/src/utils/model_utils.py b/src/utils/model_utils.py
index fbcabfd6..3e8b50e5 100644
--- a/src/utils/model_utils.py
+++ b/src/utils/model_utils.py
@@ -13,6 +13,7 @@
 import yolo
 
 from utils.types import ConfigType
+from inversefed.reconstruction_algorithms import loss_steps
 
 
 class ModelUtils:
     def __init__(self, device: torch.device, config: ConfigType) -> None:
@@ -32,7 +33,6 @@ def get_model(
         model_name: str,
         dset: str,
         device: torch.device,
-        device_ids: List[int],
         pretrained: bool = False,
         **kwargs: Any,
     ):
@@ -181,6 +181,7 @@ def train_classification(
         test_loader: DataLoader[Any] | None = None,
         **kwargs: Any,
     ) -> Tuple[float, float]:
+        model.train()
         correct = 0
         train_loss = 0
         for batch_idx, (data, target) in enumerate(dloader):
@@ -193,24 +194,71 @@ def train_classification(
             output = model(data, position=position)
 
             if kwargs.get("apply_softmax", False):
                 output = nn.functional.log_softmax(output, dim=1)  # type: ignore
+            if kwargs.get("gia", False):
+                # Replay the local update the attacker assumes: run loss_steps
+                # (use_updates=True) to obtain parameter differences instead of
+                # taking optimizer steps directly.
+                model.eval()
+                param_updates = loss_steps(
+                    model,
+                    data,
+                    target,
+                    loss_fn=loss_fn,
+                    lr=3e-4,
+                    local_steps=1,
+                    use_updates=True,  # Must be True to get parameter differences
+                    batch_size=10
+                )
+                model.train()
+
+                # Apply the updates to the model parameters
+                with torch.no_grad():
+                    for param, update in zip(model.parameters(), param_updates):
+                        param.data.add_(update)  # Directly add the update differences
+
+                # Compute loss for tracking (without gradients, since the
+                # parameters have already been updated)
+                with torch.no_grad():
+                    position = kwargs.get("position", 0)
+                    output = model(data, position=position)
+                    if kwargs.get("apply_softmax", False):
+                        output = nn.functional.log_softmax(output, dim=1)
+                    loss = loss_fn(output, target)
+                    train_loss += loss.item()
+
+            else:
+                # Standard training procedure
+                optim.zero_grad()
+                position = kwargs.get("position", 0)
+                output = model(data, position=position)
+
+                if kwargs.get("apply_softmax", False):
+                    output = nn.functional.log_softmax(output, dim=1)
+
+                loss = loss_fn(output, target)
+                loss.backward()
+                optim.step()
+                train_loss += loss.item()
-            loss = loss_fn(output, target)
-            loss.backward()
-            optim.step()
-            train_loss += loss.item()
-            pred = output.argmax(dim=1, keepdim=True)
-            # view_as() is used to make sure the shape of pred and target are
-            # the same
-            if len(target.size()) > 1:
-                target = target.argmax(dim=1, keepdim=True)
-            correct += pred.eq(target.view_as(pred)).sum().item()
+
+            # Compute accuracy
+            with torch.no_grad():
+                output = model(data, position=position)
+                pred = output.argmax(dim=1, keepdim=True)
+                # view_as() is used to make sure the shape of pred and target
+                # are the same
+                if len(target.size()) > 1:
+                    target = target.argmax(dim=1, keepdim=True)
+                correct += pred.eq(target.view_as(pred)).sum().item()
 
         if test_loader is not None:
-            # TODO: implement test loader for pascal
             test_loss, test_acc = self.test(model, test_loader, loss_fn, device)
             print(
-                f"Train Loss: {train_loss/(batch_idx+1):.6f} | Train Acc: {correct/((batch_idx+1)*len(data)):.6f} | Test Loss: {test_loss:.6f} | Test Acc: {test_acc:.6f}"
+                f"Train Loss: {train_loss/(batch_idx+1):.6f} | "
+                f"Train Acc: {correct/((batch_idx+1)*len(data)):.6f} | "
+                f"Test Loss: {test_loss:.6f} | "
+                f"Test Acc: {test_acc:.6f}"
             )
 
         acc = correct / len(dloader.dataset)
@@ -522,4 +570,4 @@ def get_memory_usage(self):
         """
         Get the memory usage
         """
-        return torch.cuda.memory_allocated(self.device)
+        return torch.cuda.memory_allocated(self.device)
\ No newline at end of file
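A sketch of how a client might route its local step through the GIA-compatible path. The argument order is assumed from the signature fragment above (only `test_loader` and `**kwargs` are shown in the diff), and the attribute names on `self` are illustrative; `gia=True` makes the loop apply `loss_steps` updates (`lr=3e-4`, `local_steps=1`) instead of optimizer steps, so the attacker's FedAvg assumptions hold:

```python
# Hypothetical call site: train one round with the loss_steps-based update
# path so the resulting parameter delta matches what the reconstructor expects.
loss, acc = self.model_utils.train_classification(
    self.model, self.optim, self.dloader, self.loss_fn, self.device, gia=True)
```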