From dbf8d6d214f5e15702f2e81be7db774d734b95b8 Mon Sep 17 00:00:00 2001 From: Jiahao Tan Date: Mon, 7 Oct 2024 19:03:18 +0800 Subject: [PATCH] fix(framework) Fix `aggregate_inplace()` in `strategy.py` (#3936) --- src/py/flwr/server/strategy/aggregate.py | 30 +++++++++++++++++------- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/src/py/flwr/server/strategy/aggregate.py b/src/py/flwr/server/strategy/aggregate.py index d5ee7340f8ea..94beacba0087 100644 --- a/src/py/flwr/server/strategy/aggregate.py +++ b/src/py/flwr/server/strategy/aggregate.py @@ -15,8 +15,8 @@ """Aggregation functions for strategy implementations.""" # mypy: disallow_untyped_calls=False -from functools import reduce -from typing import Any, Callable +from functools import partial, reduce +from typing import Any, Callable, Union import numpy as np @@ -52,17 +52,31 @@ def aggregate_inplace(results: list[tuple[ClientProxy, FitRes]]) -> NDArrays: fit_res.num_examples / num_examples_total for _, fit_res in results ] + def _try_inplace( + x: NDArray, y: Union[NDArray, float], np_binary_op: np.ufunc + ) -> NDArray: + return ( # type: ignore[no-any-return] + np_binary_op(x, y, out=x) + if np.can_cast(y, x.dtype, casting="same_kind") + else np_binary_op(x, np.array(y, x.dtype), out=x) + ) + # Let's do in-place aggregation # Get first result, then add up each other params = [ - scaling_factors[0] * x for x in parameters_to_ndarrays(results[0][1].parameters) + _try_inplace(x, scaling_factors[0], np_binary_op=np.multiply) + for x in parameters_to_ndarrays(results[0][1].parameters) ] - for i, (_, fit_res) in enumerate(results[1:]): + + for i, (_, fit_res) in enumerate(results[1:], start=1): res = ( - scaling_factors[i + 1] * x + _try_inplace(x, scaling_factors[i], np_binary_op=np.multiply) for x in parameters_to_ndarrays(fit_res.parameters) ) - params = [reduce(np.add, layer_updates) for layer_updates in zip(params, res)] + params = [ + reduce(partial(_try_inplace, np_binary_op=np.add), layer_updates) + for layer_updates in zip(params, res) + ] return params @@ -128,7 +142,7 @@ def aggregate_bulyan( Parameters ---------- - results: List[Tuple[NDArrays, int]] + results: list[tuple[NDArrays, int]] Weights and number of samples for each of the client. num_malicious: int The maximum number of malicious clients. @@ -332,7 +346,7 @@ def _aggregate_n_closest_weights( ---------- reference_weights: NDArrays The weights from which the distances will be computed - results: List[Tuple[NDArrays, int]] + results: list[tuple[NDArrays, int]] The weights from models beta_closest: int The number of the closest distance weights that will be averaged