Skip to content

Commit

Permalink
Fix bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
edogab33 committed Mar 19, 2024
1 parent 6b422b7 commit ed4601e
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 19 deletions.
3 changes: 2 additions & 1 deletion baselines/flanders/flanders/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,8 @@ def fit(self, num_rounds, timeout):
self.good_clients_idx,
)

metrics_cen = {key: self.confusion_matrix[key] for key in ["TP", "TN", "FP", "FN"]}
for key, val in self.confusion_matrix.items():
metrics_cen[key] = val

log(
INFO,
Expand Down
22 changes: 4 additions & 18 deletions baselines/flanders/flanders/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def aggregate_fit(
anomaly_scores = self.distance_function(
ground_truth, predicted_matrix[:, :, 0]
)
log(DEBUG, "Anomaly scores: %s", anomaly_scores)
log(INFO, "Anomaly scores: %s", anomaly_scores)

log(INFO, "Selecting good clients")
good_clients_idx = sorted(
Expand All @@ -244,13 +244,13 @@ def aggregate_fit(
) # noqa

avg_anomaly_score_gc = np.mean(anomaly_scores[good_clients_idx])
log(DEBUG, "Average anomaly score for good clients", avg_anomaly_score_gc)
log(INFO, "Average anomaly score for good clients: %s", avg_anomaly_score_gc)

avg_anomaly_score_m = np.mean(anomaly_scores[malicious_clients_idx])
log(DEBUG, "Average anomaly score for malicious clients", avg_anomaly_score_m)
log(INFO, "Average anomaly score for malicious clients: %s", avg_anomaly_score_m)

results = np.array(results)[good_clients_idx].tolist()
log(DEBUG, "Good clients: %s", good_clients_idx)
log(INFO, "Good clients: %s", good_clients_idx)

log(INFO, "Applying aggregate_fn")
# Convert results
Expand Down Expand Up @@ -297,20 +297,6 @@ def aggregate_fit(
malicious_clients_idx,
)

def evaluate(
self, server_round: int, parameters: Parameters
) -> Optional[Tuple[float, Dict[str, Scalar]]]:
"""Evaluate model parameters using an evaluation function."""
if self.evaluate_fn is None:
# No evaluation function provided
return None
parameters_ndarrays = parameters_to_ndarrays(parameters)
eval_res = self.evaluate_fn(server_round, parameters_ndarrays, {})
if eval_res is None:
return None
loss, metrics = eval_res
return loss, metrics


# pylint: disable=too-many-locals, too-many-arguments, invalid-name
def mar(X, pred_step, alpha=1, beta=1, maxiter=100):
Expand Down

0 comments on commit ed4601e

Please sign in to comment.