Skip to content

Commit

Permalink
fix(framework) Fix aggregate_inplace() in strategy.py (#3936)
Browse files Browse the repository at this point in the history
  • Loading branch information
KarhouTam authored Oct 7, 2024
1 parent 8db79a4 commit dbf8d6d
Showing 1 changed file with 22 additions and 8 deletions.
30 changes: 22 additions & 8 deletions src/py/flwr/server/strategy/aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
"""Aggregation functions for strategy implementations."""
# mypy: disallow_untyped_calls=False

from functools import reduce
from typing import Any, Callable
from functools import partial, reduce
from typing import Any, Callable, Union

import numpy as np

Expand Down Expand Up @@ -52,17 +52,31 @@ def aggregate_inplace(results: list[tuple[ClientProxy, FitRes]]) -> NDArrays:
fit_res.num_examples / num_examples_total for _, fit_res in results
]

def _try_inplace(
x: NDArray, y: Union[NDArray, float], np_binary_op: np.ufunc
) -> NDArray:
return ( # type: ignore[no-any-return]
np_binary_op(x, y, out=x)
if np.can_cast(y, x.dtype, casting="same_kind")
else np_binary_op(x, np.array(y, x.dtype), out=x)
)

# Let's do in-place aggregation
# Get first result, then add up each other
params = [
scaling_factors[0] * x for x in parameters_to_ndarrays(results[0][1].parameters)
_try_inplace(x, scaling_factors[0], np_binary_op=np.multiply)
for x in parameters_to_ndarrays(results[0][1].parameters)
]
for i, (_, fit_res) in enumerate(results[1:]):

for i, (_, fit_res) in enumerate(results[1:], start=1):
res = (
scaling_factors[i + 1] * x
_try_inplace(x, scaling_factors[i], np_binary_op=np.multiply)
for x in parameters_to_ndarrays(fit_res.parameters)
)
params = [reduce(np.add, layer_updates) for layer_updates in zip(params, res)]
params = [
reduce(partial(_try_inplace, np_binary_op=np.add), layer_updates)
for layer_updates in zip(params, res)
]

return params

Expand Down Expand Up @@ -128,7 +142,7 @@ def aggregate_bulyan(
Parameters
----------
results: List[Tuple[NDArrays, int]]
results: list[tuple[NDArrays, int]]
Weights and number of samples for each of the client.
num_malicious: int
The maximum number of malicious clients.
Expand Down Expand Up @@ -332,7 +346,7 @@ def _aggregate_n_closest_weights(
----------
reference_weights: NDArrays
The weights from which the distances will be computed
results: List[Tuple[NDArrays, int]]
results: list[tuple[NDArrays, int]]
The weights from models
beta_closest: int
The number of the closest distance weights that will be averaged
Expand Down

0 comments on commit dbf8d6d

Please sign in to comment.