Skip to content

Commit

Permalink
Fix lint
Browse files Browse the repository at this point in the history
  • Loading branch information
edogab33 committed Mar 20, 2024
1 parent 83f4aeb commit 5dbb153
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 14 deletions.
1 change: 1 addition & 0 deletions baselines/flanders/flanders/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def get_dataset(path_to_data: Path, cid: str, partition: str, transform=None):
return TorchVisionFL(path_to_data, transform=transform)


# pylint: disable=too-many-arguments, too-many-locals
def get_dataloader(
path_to_data: str,
cid: str,
Expand Down
22 changes: 12 additions & 10 deletions baselines/flanders/flanders/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@ def main(cfg: DictConfig) -> None:
# 0. Set random seed
seed = cfg.seed
np.random.seed(seed)
np.random.set_state(np.random.RandomState(seed).get_state())
np.random.set_state(
np.random.RandomState(seed).get_state()
) # pylint: disable=no-member
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
Expand Down Expand Up @@ -104,7 +106,7 @@ def main(cfg: DictConfig) -> None:

# 3. Define your clients
# pylint: disable=no-else-return
def client_fn(cid: str, pool_size: int = 10, dataset_name: str = dataset_name):
def client_fn(cid: str, dataset_name: str = dataset_name):
client = clients[dataset_name][0]
if dataset_name in ["mnist", "fmnist", "cifar", "cifar100"]:
return client(cid, fed_dir)
Expand Down Expand Up @@ -242,10 +244,10 @@ def client_fn(cid: str, pool_size: int = 10, dataset_name: str = dataset_name):
rounds, test_loss = zip(*history.losses_centralized)
_, test_accuracy = zip(*history.metrics_centralized["accuracy"])
_, test_auc = zip(*history.metrics_centralized["auc"])
_, tp = zip(*history.metrics_centralized["TP"])
_, tn = zip(*history.metrics_centralized["TN"])
_, fp = zip(*history.metrics_centralized["FP"])
_, fn = zip(*history.metrics_centralized["FN"])
_, truep = zip(*history.metrics_centralized["TP"])
_, truen = zip(*history.metrics_centralized["TN"])
_, falsep = zip(*history.metrics_centralized["FP"])
_, falsen = zip(*history.metrics_centralized["FN"])

path_to_save = [os.path.join(save_path, "results.csv"), "outputs/all_results.csv"]

Expand All @@ -256,10 +258,10 @@ def client_fn(cid: str, pool_size: int = 10, dataset_name: str = dataset_name):
"loss": test_loss,
"accuracy": test_accuracy,
"auc": test_auc,
"TP": tp,
"TN": tn,
"FP": fp,
"FN": fn,
"TP": truep,
"TN": truen,
"FP": falsep,
"FN": falsen,
"attack_fn": [attack_fn for _ in range(len(rounds))],
"dataset_name": [dataset_name for _ in range(len(rounds))],
"num_malicious": [num_malicious for _ in range(len(rounds))],
Expand Down
2 changes: 1 addition & 1 deletion baselines/flanders/flanders/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class MnistNet(nn.Module):
"""Neural network for MNIST classification."""

def __init__(self):
super(MnistNet, self).__init__()
super().__init__()
self.fc1 = nn.Linear(28 * 28, 128)
self.fc2 = nn.Linear(128, 10)

Expand Down
2 changes: 1 addition & 1 deletion baselines/flanders/flanders/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ def fit_round(
# the server simulates an attacker that controls a fraction of the clients
if self.attack_fn is not None and server_round > self.warmup_rounds:
log(INFO, "Applying attack function")
results, others = self.attack_fn(
results, _ = self.attack_fn(
ordered_results,
clients_state,
omniscent=self.omniscent,
Expand Down
4 changes: 2 additions & 2 deletions baselines/flanders/flanders/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,8 +295,8 @@ def aggregate_fit(
parameters_aggregated = ndarrays_to_parameters(
self.aggregate_fn(weights_results, **aggregate_parameters)
)
except ValueError as e:
log(WARNING, f"Error in aggregate_fn: {e}")
except ValueError as err:
log(WARNING, "Error in aggregate_fn: %s", err)
parameters_aggregated = ndarrays_to_parameters(
aggregate(weights_results)
)
Expand Down

0 comments on commit 5dbb153

Please sign in to comment.