fix(framework) Fix `aggregate_inplace()` in `strategy.py` #3936
Conversation
Force-pushed from `81ee2a3` to `0a0d60f`
Hi @KarhouTam, thanks for taking a close look at that function used by many of the strategies. What's the ratio of memory footprint saved by the change you propose? Is it possible to give an analytical x% memory savings figure?
@jafermarq Thanks for the review first.

Sure. TL;DR: close to 50% memory savings. The results were calculated using the input shown below, and we can see that the Flower version uses more memory, so I think the saving is close to 50%.

Full Example

First I wrap these functions with `@profile`:

```python
from functools import partial, reduce
from typing import List, Tuple, Union

import numpy as np
from flwr.common import (
    Code,
    FitRes,
    NDArray,
    NDArrays,
    Status,
    ndarrays_to_parameters,
    parameters_to_ndarrays,
)
from flwr.server.client_proxy import ClientProxy
from memory_profiler import profile


# =============== Flower Version ==============
@profile
def flwr_aggregate_inplace(results):
    """Compute in-place weighted average."""
    # Count total examples
    num_examples_total = sum(fit_res.num_examples for (_, fit_res) in results)
    # Compute scaling factors for each result
    scaling_factors = [
        fit_res.num_examples / num_examples_total for _, fit_res in results
    ]
    # Let's do in-place aggregation
    # Get first result, then add up each other
    params = [
        scaling_factors[0] * x
        for x in parameters_to_ndarrays(results[0][1].parameters)
    ]
    for i, (_, fit_res) in enumerate(results[1:]):
        res = (
            scaling_factors[i + 1] * x
            for x in parameters_to_ndarrays(fit_res.parameters)
        )
        params = [
            reduce(np.add, layer_updates) for layer_updates in zip(params, res)
        ]
    return params


# =============== Proposed Version ==============
@profile
def my_aggregate_inplace(results: List[Tuple[ClientProxy, FitRes]]) -> NDArrays:
    """Compute in-place weighted average."""
    # Count total examples
    num_examples_total = sum(fit_res.num_examples for _, fit_res in results)
    # Compute scaling factors for each result
    scaling_factors = [
        fit_res.num_examples / num_examples_total for _, fit_res in results
    ]

    def _try_inplace(
        x: NDArray, y: Union[NDArray, float], np_binary_op: np.ufunc
    ) -> NDArray:
        return (  # type: ignore[no-any-return]
            np_binary_op(x, y, out=x)
            if np.can_cast(y, x.dtype, casting="same_kind")
            else np_binary_op(x, np.array(y, x.dtype), out=x)
        )

    # Let's do in-place aggregation
    # Get first result, then add up each other
    params = [
        _try_inplace(x, scaling_factors[0], np_binary_op=np.multiply)
        for x in parameters_to_ndarrays(results[0][1].parameters)
    ]
    for i, (_, fit_res) in enumerate(results[1:], start=1):
        res = (
            _try_inplace(x, scaling_factors[i], np_binary_op=np.multiply)
            for x in parameters_to_ndarrays(fit_res.parameters)
        )
        params = [
            reduce(partial(_try_inplace, np_binary_op=np.add), layer_updates)
            for layer_updates in zip(params, res)
        ]
    return params
```

And set the input argument:

```python
weights0_0 = np.random.randn(100, 64)
weights0_1 = np.random.randn(314, 628, 3)
weights1_0 = np.random.randn(100, 64)
weights1_1 = np.random.randn(314, 628, 3)

results: List[Tuple[ClientProxy, FitRes]] = [
    (
        None,
        FitRes(
            status=Status(code=Code.OK, message="Success"),
            parameters=ndarrays_to_parameters([weights0_0, weights0_1]),
            num_examples=1,
            metrics={},
        ),
    ),
    (
        None,
        FitRes(
            status=Status(code=Code.OK, message="Success"),
            parameters=ndarrays_to_parameters([weights1_0, weights1_1]),
            num_examples=5,
            metrics={},
        ),
    ),
]
```

which is the same as the one in . And the test functions are:

```python
import gc
import time

LOOPS = 1


def test_flwr_aggregate_inplace():
    gc.disable()
    start = time.time()
    for _ in range(LOOPS):
        a = flwr_aggregate_inplace(results)
    end = time.time()
    gc.enable()
    print("flwr's aggregate_inplace cost time: ", end - start)


def test_my_aggregate_inplace():
    gc.disable()
    start = time.time()
    for _ in range(LOOPS):
        b = my_aggregate_inplace(results)
    end = time.time()
    gc.enable()
    print("my aggregate_inplace cost time: ", end - start)
```

Results
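Since the profiler tables themselves don't survive here, a rough way to reproduce the direction of the saving (a minimal sketch using only NumPy and the stdlib `tracemalloc`, not the original `memory_profiler` setup; array shapes borrowed from the example above, and the ~33% gap it shows is illustrative, not the exact 50% figure from the full pipeline):

```python
import tracemalloc

import numpy as np


def peak_bytes(fn):
    """Run fn and return the peak traced allocation in bytes."""
    tracemalloc.start()
    fn()
    _, peak = tracemalloc.get_traced_memory()
    tracemalloc.stop()
    return peak


a = np.random.randn(314, 628, 3)
b = np.random.randn(314, 628, 3)


def out_of_place():
    acc = np.multiply(a, 0.5)  # buffer 1
    tmp = np.multiply(b, 0.5)  # buffer 2
    acc = np.add(acc, tmp)     # buffer 3 allocated for the sum
    return acc


def in_place():
    acc = np.multiply(a, 0.5)  # buffer 1
    tmp = np.multiply(b, 0.5)  # buffer 2
    np.add(acc, tmp, out=acc)  # no new buffer: result written into acc
    return acc


print(peak_bytes(out_of_place) > peak_bytes(in_place))  # True
```

The out-of-place variant holds three array-sized buffers alive at its peak, while the `out=` variant never needs more than two.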
Hi @KarhouTam, this looks super good. I spent some time on Friday looking into it closely. We aim to get it merged for the next release of Flower; bear with us for just a few days. 🙏
lgtm!
Thanks for the PR and the helpful description @KarhouTam! Great improvement. The merge window for Flower 1.11 is pretty much closed (apart from a few scheduled PRs). We'll merge this after the 1.11 release to have it included in Flower 1.12.
# Conflicts: src/py/flwr/server/strategy/aggregate.py
…fix-aggregate-inplace
Conflicts solved. This PR is ready to be merged anytime. 🤗
Hi, @panh99 @danieljanes.
Hi @KarhouTam! We are preparing the 1.12.0 release, and your PR will be merged soon, I believe. Thanks for the reminder!
Thanks for the PR @KarhouTam, approved! This will be included in the upcoming Flower 1.12 release (ETA: end of this week or early next week) |
Issue

I found that `aggregate_inplace()` actually does not do its job in place.

Description

flower/src/py/flwr/server/strategy/aggregate.py
Lines 57 to 65 in ad811b5

`scaling_factors[0] * x` and `scaling_factors[i + 1] * x` would introduce additional memory overhead to temporarily store the output. Instead, using an API like `np.add(x, y, out=x)` lets `numpy` write the output to `x` directly.
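To illustrate the `out=` semantics (a minimal standalone sketch, not the Flower code itself):

```python
import numpy as np

x = np.ones(4)
y = np.full(4, 2.0)

# Out-of-place: NumPy allocates a brand-new array for the result
z = np.add(x, y)
print(z is x)  # False: a temporary output buffer was created

# In-place: the result is written straight into x's existing buffer
r = np.add(x, y, out=x)
print(r is x)  # True: no extra allocation for the output
print(x)       # [3. 3. 3. 3.]
```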
Proposal

Explanation

Add the `out` argument to the `numpy` function calls. Explanation from `numpy`:

However, the `out` argument requires `x` and `y` to have the same `dtype`, which is the most common case. To deal with the unusual cases, I use a wrapper and set a fallback.
Proof

I use `memory_profiler` to check the memory usage.

Flower version

Proposed Version

Checklist

Any other comments?