Skip to content

Commit

Permalink
Fix strategy API for aggregate_fit()
Browse files Browse the repository at this point in the history
  • Loading branch information
edogab33 committed Mar 21, 2024
1 parent 5dbb153 commit 4abf7d4
Show file tree
Hide file tree
Showing 8 changed files with 53 additions and 44 deletions.
20 changes: 11 additions & 9 deletions baselines/flanders/flanders/attacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,16 +218,18 @@ def fang_attack(
# Set corrupted clients' updates to w_1
results = [
(
proxy,
FitRes(
fitres.status,
parameters=ndarrays_to_parameters(corrupted_params),
num_examples=fitres.num_examples,
metrics=fitres.metrics,
),
(
proxy,
FitRes(
fitres.status,
parameters=ndarrays_to_parameters(corrupted_params),
num_examples=fitres.num_examples,
metrics=fitres.metrics,
),
)
if states[fitres.metrics["cid"]]
else (proxy, fitres)
)
if states[fitres.metrics["cid"]]
else (proxy, fitres)
for proxy, fitres in ordered_results
]

Expand Down
3 changes: 2 additions & 1 deletion baselines/flanders/flanders/client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Clients implementation for Flanders."""

from collections import OrderedDict
from pathlib import Path
from typing import Tuple
Expand Down Expand Up @@ -80,7 +81,7 @@ def fit(self, parameters, config):
return (
get_params(self.net),
len(trainloader.dataset),
{"cid": self.cid},
{"cid": self.cid, "malicious": config["malicious"]},
)

def evaluate(self, parameters, config):
Expand Down
4 changes: 2 additions & 2 deletions baselines/flanders/flanders/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,8 @@ def mnist_transformation(img):
class TorchVisionFL(VisionDataset):
"""TorchVision FL class.
Use this class by either passing a path to a torch file (.pt) containing
(data, targets) or pass the data, targets directly instead.
Use this class by either passing a path to a torch file (.pt) containing (data,
targets) or pass the data, targets directly instead.
This is just a trimmed down version of torchvision.datasets.MNIST.
"""
Expand Down
6 changes: 3 additions & 3 deletions baselines/flanders/flanders/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from .utils import fmnist_evaluate, l2_norm, mnist_evaluate


# pylint: disable=too-many-locals
# pylint: disable=too-many-locals, too-many-branches, too-many-statements
@hydra.main(config_path="conf", config_name="base", version_base=None)
def main(cfg: DictConfig) -> None:
"""Run the baseline.
Expand All @@ -36,8 +36,8 @@ def main(cfg: DictConfig) -> None:
seed = cfg.seed
np.random.seed(seed)
np.random.set_state(
np.random.RandomState(seed).get_state()
) # pylint: disable=no-member
np.random.RandomState(seed).get_state() # pylint: disable=no-member
)
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
Expand Down
1 change: 1 addition & 0 deletions baselines/flanders/flanders/models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Models for FLANDERS experiments."""

import itertools

import torch
Expand Down
29 changes: 14 additions & 15 deletions baselines/flanders/flanders/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,11 +221,15 @@ def fit_round(
)

# Create dict clients_state to keep track of malicious clients
# and send the information to the clients
clients_state = {}
for _, (proxy, _) in enumerate(client_instructions):
for _, (proxy, ins) in enumerate(client_instructions):
clients_state[proxy.cid] = False
ins.config["malicious"] = False
if proxy.cid in self.malicious_lst:
clients_state[proxy.cid] = True
ins.config["malicious"] = True

# Sort clients states
clients_state = {k: clients_state[k] for k in sorted(clients_state)}
log(
Expand Down Expand Up @@ -326,20 +330,18 @@ def fit_round(
else:
results = ordered_results

# Aggregate training results
log(INFO, "fit_round - Aggregating training results")
good_clients_idx = []
malicious_clients_idx = []
aggregated_result = self.strategy.aggregate_fit(
server_round, results, failures
)
if isinstance(self.strategy, Flanders):
# Aggregate training results
log(INFO, "fit_round - Aggregating training results")
aggregated_result = self.strategy.aggregate_fit(
server_round, results, failures, clients_state
)
(
parameters_aggregated,
metrics_aggregated,
good_clients_idx,
malicious_clients_idx,
) = aggregated_result
parameters_aggregated, metrics_aggregated = aggregated_result
malicious_clients_idx = metrics_aggregated["malicious_clients_idx"]
good_clients_idx = metrics_aggregated["good_clients_idx"]

log(INFO, "Malicious clients: %s", malicious_clients_idx)

log(INFO, "clients_state: %s", clients_state)
Expand Down Expand Up @@ -376,9 +378,6 @@ def fit_round(
else:
# Aggregate training results
log(INFO, "fit_round - Aggregating training results")
aggregated_result = self.strategy.aggregate_fit(
server_round, results, failures
)
parameters_aggregated, metrics_aggregated = aggregated_result

self.clients_state = clients_state
Expand Down
8 changes: 5 additions & 3 deletions baselines/flanders/flanders/strategies/dnc.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,11 @@ def configure_fit(
fit_ins_list = [
FitIns(
parameters,
{}
if not self.on_fit_config_fn
else self.on_fit_config_fn(server_round),
(
{}
if not self.on_fit_config_fn
else self.on_fit_config_fn(server_round)
),
)
for _ in range(sample_size)
]
Expand Down
26 changes: 15 additions & 11 deletions baselines/flanders/flanders/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,9 +162,11 @@ def configure_fit(
fit_ins_list = [
FitIns(
parameters,
{}
if not self.on_fit_config_fn
else self.on_fit_config_fn(server_round),
(
{}
if not self.on_fit_config_fn
else self.on_fit_config_fn(server_round)
),
)
for _ in range(sample_size)
]
Expand All @@ -186,7 +188,6 @@ def aggregate_fit(
server_round: int,
results: List[Tuple[ClientProxy, FitRes]],
failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
clients_state: Dict[str, bool],
) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
"""Apply MAR forecasting to exclude malicious clients from FedAvg.
Expand Down Expand Up @@ -269,7 +270,11 @@ def aggregate_fit(
# Check that self.aggregate_fn has num_malicious parameter
if "num_malicious" in self.aggregate_fn.__code__.co_varnames:
# Count the number of malicious clients in
# good_clients_idx by checking clients_state
# good_clients_idx by checking FitRes
clients_state = {
str(fit_res.metrics["cid"]): fit_res.metrics["malicious"]
for _, fit_res in results
}
num_malicious = sum([clients_state[str(cid)] for cid in good_clients_idx])
log(
INFO,
Expand Down Expand Up @@ -313,12 +318,11 @@ def aggregate_fit(
elif server_round == 1: # Only log this warning once
log(WARNING, "No fit_metrics_aggregation_fn provided")

return (
parameters_aggregated,
metrics_aggregated,
good_clients_idx,
malicious_clients_idx,
)
# Add good_clients_idx and malicious_clients_idx to metrics_aggregated
metrics_aggregated["good_clients_idx"] = good_clients_idx
metrics_aggregated["malicious_clients_idx"] = malicious_clients_idx

return parameters_aggregated, metrics_aggregated


# pylint: disable=too-many-locals, too-many-arguments, invalid-name
Expand Down

0 comments on commit 4abf7d4

Please sign in to comment.