Skip to content

Commit

Permalink
Malicious attack: Gradient inversion (#137)
Browse files Browse the repository at this point in the history
* initial push, need to debug

* first commit - testing grpc

* initial commit

* adding support for 10 image dataset

* added gia for fl

* debug commit

* debugging training

* working now but accuracy is very low

* rewrote using loss steps for much better performance

* gia for flStatic - need to add support for multiple attackers

* fl static - need to optimize for attacker to track num rounds

* fixed PR changes

* fixed logging

* migrated attack code to basenode

---------

Co-authored-by: Abhishek Singh <[email protected]>
  • Loading branch information
photonshi and tremblerz authored Nov 15, 2024
1 parent 8e3b4ac commit 5e960ab
Show file tree
Hide file tree
Showing 30 changed files with 3,085 additions and 186 deletions.
448 changes: 289 additions & 159 deletions src/algos/base_class.py

Large diffs are not rendered by default.

78 changes: 75 additions & 3 deletions src/algos/fl.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,17 @@
from algos.attack_bad_weights import BadWeightsAttack
from algos.attack_sign_flip import SignFlipAttack

from utils.gias import gia_main

import pickle

class FedAvgClient(BaseClient):
def __init__(
self, config: Dict[str, Any], comm_utils: CommunicationManager
) -> None:
super().__init__(config, comm_utils)
self.config = config
self.random_params = self.model.state_dict()

def local_test(self, **kwargs: Any) -> Tuple[float, float, float]:
"""
Expand Down Expand Up @@ -68,6 +73,16 @@ def get_model_weights(self, **kwargs: Any) -> Dict[str, Any]:
# move the model to cpu before sending
for key in message["model"].keys():
message["model"][key] = message["model"][key].to("cpu")

# assert hasattr(self, 'images') and hasattr(self, 'labels'), "Images and labels not found"
if "gia" in self.config and hasattr(self, 'images') and hasattr(self, 'labels'):
# also stream image and labels
message["images"] = self.images.to("cpu")
message["labels"] = self.labels.to("cpu")

message["random_params"] = self.random_params
for key in message["random_params"].keys():
message["random_params"][key] = message["random_params"][key].to("cpu")

return message # type: ignore

Expand Down Expand Up @@ -155,16 +170,73 @@ def test(self, **kwargs: Any) -> Tuple[float, float, float]:
self.stats["test_loss"], self.stats["test_acc"], self.stats["test_time"] = test_loss, test_acc, time_taken
return test_loss, test_acc, time_taken

def receive_attack_and_aggregate(self, round: int, attack_start_round: int, attack_end_round: int, dump_file_name: str = ""):
reprs = self.comm_utils.all_gather()

with open(dump_file_name, "wb") as f:
pickle.dump(reprs, f)

# Handle GIA-specific logic
if "gia" in self.config:
print("Server Running GIA attack")
base_params = [key for key, _ in self.model.named_parameters()]
print(base_params)

for rep in reprs:
client_id = rep["sender"]
assert "images" in rep and "labels" in rep, "Images and labels not found in representation"
model_state_dict = rep["model"]

# Extract relevant model parameters
model_params = OrderedDict(
(key, value) for key, value in model_state_dict.items()
if key in base_params
)

random_params = rep["random_params"]
random_params = OrderedDict(
(key, value) for key, value in random_params.items()
if key in base_params
)

# Store parameters based on attack start and end rounds
if round == attack_start_round:
self.params_s[client_id - 1] = model_params
elif round == attack_end_round:
self.params_t[client_id - 1] = model_params
images = rep["images"]
labels = rep["labels"]

# Launch GIA attack
p_s, p_t = self.params_s[client_id - 1], self.params_t[client_id - 1]
gia_main(p_s, p_t, base_params, self.model, labels, images, client_id)

avg_wts = self.aggregate(reprs)
self.set_representation(avg_wts)


def receive_and_aggregate(self):
reprs = self.comm_utils.all_gather()
avg_wts = self.aggregate(reprs)
self.set_representation(avg_wts)

def single_round(self):
def single_round(self, round: int, attack_start_round: int = 0, attack_end_round: int = 1):
"""
Runs the whole training procedure
Runs the whole training procedure.
Parameters:
round (int): Current round of training.
attack_start_round (int): The starting round to initiate the attack.
attack_end_round (int): The last round for the attack to be performed.
"""
self.receive_and_aggregate()

# Normal training when outside the attack range

if round < attack_start_round or round > attack_end_round:
self.receive_and_aggregate()
else:
self.receive_attack_and_aggregate(round, attack_start_round, attack_end_round, dump_file_name)


def run_protocol(self):
print(f"Client {self.node_id} ready to start training")
Expand Down
Loading

0 comments on commit 5e960ab

Please sign in to comment.