From 8cc82c44e9c854831ec9b3ad6e722bd4fd7257b9 Mon Sep 17 00:00:00 2001 From: KarhouTam Date: Sun, 28 Jul 2024 09:03:40 +0800 Subject: [PATCH 1/5] fix --- src/py/flwr/server/strategy/aggregate.py | 30 +++++++++++++++++------- 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/src/py/flwr/server/strategy/aggregate.py b/src/py/flwr/server/strategy/aggregate.py index c668b55eebe6..1be416a07589 100644 --- a/src/py/flwr/server/strategy/aggregate.py +++ b/src/py/flwr/server/strategy/aggregate.py @@ -15,7 +15,7 @@ """Aggregation functions for strategy implementations.""" # mypy: disallow_untyped_calls=False -from functools import reduce +from functools import partial, reduce from typing import Any, Callable, List, Tuple import numpy as np @@ -45,24 +45,36 @@ def aggregate(results: List[Tuple[NDArrays, int]]) -> NDArrays: def aggregate_inplace(results: List[Tuple[ClientProxy, FitRes]]) -> NDArrays: """Compute in-place weighted average.""" # Count total examples - num_examples_total = sum(fit_res.num_examples for (_, fit_res) in results) + num_examples_total = sum(fit_res.num_examples for _, fit_res in results) # Compute scaling factors for each result - scaling_factors = [ - fit_res.num_examples / num_examples_total for _, fit_res in results - ] + scaling_factors = np.array( + [fit_res.num_examples / num_examples_total for _, fit_res in results], + dtype=np.float32, + ) + + def _try_inplace( + x: NDArray, y: NDArray, np_binary_op: Callable[[NDArray, NDArray], NDArray] + ): + # `(x, y, out=x)` requires x and y has the same dtype, which will do the inplace-job + return np_binary_op(x, y, out=x) if x.dtype == y.dtype else np_binary_op(x, y) # Let's do in-place aggregation # Get first result, then add up each other + params = [ - scaling_factors[0] * x for x in parameters_to_ndarrays(results[0][1].parameters) + _try_inplace(x, scaling_factors[0], np_binary_op=np.multiply) + for x in parameters_to_ndarrays(results[0][1].parameters) ] - for i, (_, fit_res) in enumerate(results[1:]): + for i, (_, fit_res) in enumerate(results[1:], start=1): res = ( - scaling_factors[i + 1] * x + _try_inplace(x, scaling_factors[i], np_binary_op=np.multiply) for x in parameters_to_ndarrays(fit_res.parameters) ) - params = [reduce(np.add, layer_updates) for layer_updates in zip(params, res)] + params = [ + reduce(partial(_try_inplace, np_binary_op=np.add), layer_updates) + for layer_updates in zip(params, res) + ] return params From 51657217d9435943f09be493087a6435de426314 Mon Sep 17 00:00:00 2001 From: KarhouTam Date: Sun, 28 Jul 2024 10:09:40 +0800 Subject: [PATCH 2/5] mypy fix --- src/py/flwr/server/strategy/aggregate.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/py/flwr/server/strategy/aggregate.py b/src/py/flwr/server/strategy/aggregate.py index 1be416a07589..e108f2ece518 100644 --- a/src/py/flwr/server/strategy/aggregate.py +++ b/src/py/flwr/server/strategy/aggregate.py @@ -53,9 +53,7 @@ def aggregate_inplace(results: List[Tuple[ClientProxy, FitRes]]) -> NDArrays: dtype=np.float32, ) - def _try_inplace( - x: NDArray, y: NDArray, np_binary_op: Callable[[NDArray, NDArray], NDArray] - ): + def _try_inplace(x: NDArray, y: NDArray, np_binary_op: np.ufunc) -> Any: # `(x, y, out=x)` requires x and y has the same dtype, which will do the inplace-job return np_binary_op(x, y, out=x) if x.dtype == y.dtype else np_binary_op(x, y) From 00cd859ae81b903e3188d0a55b358c7c7e2821c4 Mon Sep 17 00:00:00 2001 From: KarhouTam Date: Sun, 28 Jul 2024 12:34:14 +0800 Subject: [PATCH 3/5] fix output type of `_try_inplace` --- src/py/flwr/server/strategy/aggregate.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/src/py/flwr/server/strategy/aggregate.py b/src/py/flwr/server/strategy/aggregate.py index e108f2ece518..b82c5ab4e1d4 100644 --- a/src/py/flwr/server/strategy/aggregate.py +++ b/src/py/flwr/server/strategy/aggregate.py @@ -16,7 +16,7 @@ # mypy: disallow_untyped_calls=False from functools import partial, reduce -from typing import Any, Callable, List, Tuple +from typing import Any, Callable, List, Tuple, Union import numpy as np @@ -48,14 +48,18 @@ def aggregate_inplace(results: List[Tuple[ClientProxy, FitRes]]) -> NDArrays: num_examples_total = sum(fit_res.num_examples for _, fit_res in results) # Compute scaling factors for each result - scaling_factors = np.array( - [fit_res.num_examples / num_examples_total for _, fit_res in results], - dtype=np.float32, - ) + scaling_factors = [ + fit_res.num_examples / num_examples_total for _, fit_res in results + ] - def _try_inplace(x: NDArray, y: NDArray, np_binary_op: np.ufunc) -> Any: - # `(x, y, out=x)` requires x and y has the same dtype, which will do the inplace-job - return np_binary_op(x, y, out=x) if x.dtype == y.dtype else np_binary_op(x, y) + def _try_inplace( + x: NDArray, y: Union[NDArray, float], np_binary_op: np.ufunc + ) -> NDArray: + return ( + np_binary_op(x, y, out=x) + if np.can_cast(y, x.dtype, casting="same_kind") + else np_binary_op(x, y) + ) # type: ignore[no-any-return] # Let's do in-place aggregation # Get first result, then add up each other From 0a0d60f235122fe8d8969eaee6ddeafda33043b6 Mon Sep 17 00:00:00 2001 From: KarhouTam Date: Sun, 28 Jul 2024 18:08:09 +0800 Subject: [PATCH 4/5] fix check --- src/py/flwr/server/strategy/aggregate.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/py/flwr/server/strategy/aggregate.py b/src/py/flwr/server/strategy/aggregate.py index b82c5ab4e1d4..eebad83226af 100644 --- a/src/py/flwr/server/strategy/aggregate.py +++ b/src/py/flwr/server/strategy/aggregate.py @@ -55,11 +55,11 @@ def aggregate_inplace(results: List[Tuple[ClientProxy, FitRes]]) -> NDArrays: def _try_inplace( x: NDArray, y: Union[NDArray, float], np_binary_op: np.ufunc ) -> NDArray: - return ( + return ( # type: ignore[no-any-return] np_binary_op(x, y, out=x) if np.can_cast(y, x.dtype, casting="same_kind") else np_binary_op(x, y) - ) # type: ignore[no-any-return] + ) # Let's do in-place aggregation # Get first result, then add up each other From 9a3f1f7bf8802de6e347a3825dc1d84c9df87757 Mon Sep 17 00:00:00 2001 From: KarhouTam Date: Sun, 28 Jul 2024 19:36:49 +0800 Subject: [PATCH 5/5] update --- src/py/flwr/server/strategy/aggregate.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/py/flwr/server/strategy/aggregate.py b/src/py/flwr/server/strategy/aggregate.py index eebad83226af..4e34ccca4c4f 100644 --- a/src/py/flwr/server/strategy/aggregate.py +++ b/src/py/flwr/server/strategy/aggregate.py @@ -58,7 +58,7 @@ def _try_inplace( return ( # type: ignore[no-any-return] np_binary_op(x, y, out=x) if np.can_cast(y, x.dtype, casting="same_kind") - else np_binary_op(x, y) + else np_binary_op(x, np.array(y, x.dtype), out=x) ) # Let's do in-place aggregation @@ -68,6 +68,7 @@ def _try_inplace( _try_inplace(x, scaling_factors[0], np_binary_op=np.multiply) for x in parameters_to_ndarrays(results[0][1].parameters) ] + for i, (_, fit_res) in enumerate(results[1:], start=1): res = ( _try_inplace(x, scaling_factors[i], np_binary_op=np.multiply)