From 6ef1cd4fed25d109d178ba8df02d6c8282c5949e Mon Sep 17 00:00:00 2001 From: photonshi Date: Mon, 2 Dec 2024 17:28:11 +0000 Subject: [PATCH 1/2] fixed attack invocation in fl --- src/algos/fl.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/algos/fl.py b/src/algos/fl.py index 569213e..7d586b4 100644 --- a/src/algos/fl.py +++ b/src/algos/fl.py @@ -229,12 +229,13 @@ def single_round(self, round: int, attack_start_round: int = 0, attack_end_round attack_end_round (int): The last round for the attack to be performed. """ - # Normal training when outside the attack range - - if round < attack_start_round or round > attack_end_round: - self.receive_and_aggregate() + # Determine if the attack should be performed + attack_in_progress = self.gia_attacker and attack_start_round <= round <= attack_end_round + + if attack_in_progress: + self.receive_attack_and_aggregate(round, attack_start_round, attack_end_round) else: - self.receive_attack_and_aggregate(round, attack_start_round, attack_end_round, dump_file_name) + self.receive_and_aggregate() def run_protocol(self): From 43a663ad65de861a1c8ddb5c328612cf431e199a Mon Sep 17 00:00:00 2001 From: photonshi Date: Mon, 2 Dec 2024 17:29:50 +0000 Subject: [PATCH 2/2] removed dump file name in fl --- src/algos/fl.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/algos/fl.py b/src/algos/fl.py index 7d586b4..721912f 100644 --- a/src/algos/fl.py +++ b/src/algos/fl.py @@ -169,12 +169,9 @@ def test(self, **kwargs: Any) -> Tuple[float, float, float]: self.stats["test_loss"], self.stats["test_acc"], self.stats["test_time"] = test_loss, test_acc, time_taken return test_loss, test_acc, time_taken - def receive_attack_and_aggregate(self, round: int, attack_start_round: int, attack_end_round: int, dump_file_name: str = ""): + def receive_attack_and_aggregate(self, round: int, attack_start_round: int, attack_end_round: int): reprs = self.comm_utils.all_gather() - with open(dump_file_name, "wb") as f: - pickle.dump(reprs, f) - # Handle GIA-specific logic if "gia" in self.config: print("Server Running GIA attack")