
fix(framework) Fix aggregate_inplace() in strategy.py #3936

Merged
20 commits merged into adap:main from fix-aggregate-inplace on Oct 7, 2024

Conversation

KarhouTam
Copy link
Contributor

@KarhouTam KarhouTam commented Jul 28, 2024

Issue

I found that aggregate_inplace() does not actually perform its aggregation in place.

Description

params = [
    scaling_factors[0] * x for x in parameters_to_ndarrays(results[0][1].parameters)
]
for i, (_, fit_res) in enumerate(results[1:]):
    res = (
        scaling_factors[i + 1] * x
        for x in parameters_to_ndarrays(fit_res.parameters)
    )
    params = [reduce(np.add, layer_updates) for layer_updates in zip(params, res)]

scaling_factors[0] * x and scaling_factors[i + 1] * x each allocate a fresh array to temporarily store the output, which introduces additional memory overhead. Instead, an API like np.add(x, y, out=x) lets NumPy write the result directly into x.
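For illustration (a minimal sketch, not Flower code), the difference between an out-of-place operation and a ufunc call with `out=`:

```python
import numpy as np

x = np.ones(3)
y = np.full(3, 2.0)

# Out-of-place: `x + y` allocates a fresh array for the result
z = x + y
assert not np.shares_memory(z, x)

# In-place: the result is written directly into x's existing buffer,
# so no temporary array is allocated
out = np.add(x, y, out=x)
assert out is x
assert x.tolist() == [3.0, 3.0, 3.0]
```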

Proposal

Explanation

Pass the out argument to the NumPy function calls. The NumPy documentation describes out as:

out : ndarray, None, or tuple of ndarray and None, optional
A location into which the result is stored. If provided, it must have a shape that the inputs broadcast to. If not provided or None, a freshly-allocated array is returned.

However, the out argument requires x and y to have compatible dtypes, which covers the most common cases. To handle the unusual ones, I use a wrapper that falls back to an explicit cast.

Proof

I use memory_profiler to check the memory usage.

Flower Version

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    13    120.7 MiB    120.7 MiB           1   @profile
    14                                         def flwr_aggregate_inplace(results):
    15                                             """Compute in-place weighted average."""
    16                                             # Count total examples
    17    120.7 MiB      0.0 MiB          23       num_examples_total = sum(fit_res.num_examples for (_, fit_res) in results)
    18                                         
    19                                             # Compute scaling factors for each result
    20    120.7 MiB      0.0 MiB          24       scaling_factors = [
    21    120.7 MiB      0.0 MiB          11           fit_res.num_examples / num_examples_total for _, fit_res in results
    22                                             ]
    23                                         
    24                                             # Let's do in-place aggregation
    25                                             # Get first result, then add up each other
    26    122.4 MiB      0.0 MiB          44       params = [
    27    122.4 MiB      1.7 MiB          21           scaling_factors[0] * x for x in parameters_to_ndarrays(results[0][1].parameters)
    28                                             ]
    29    123.2 MiB      0.0 MiB          10       for i, (_, fit_res) in enumerate(results[1:]):
    30    123.2 MiB      0.0 MiB         559           res = (
    31    123.2 MiB      0.3 MiB         180               scaling_factors[i + 1] * x
    32    123.2 MiB      0.0 MiB         189               for x in parameters_to_ndarrays(fit_res.parameters)
    33                                                 )
    34    123.2 MiB      0.6 MiB         207           params = [reduce(np.add, layer_updates) for layer_updates in zip(params, res)]
    35                                         
    36    123.2 MiB      0.0 MiB           1       return params

Proposed Version

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    39    119.9 MiB    119.9 MiB           1   @profile
    40                                         def my_aggregate_inplace(results: List[Tuple[ClientProxy, FitRes]]) -> NDArrays:
    41                                             """Compute in-place weighted average."""
    42                                             # Count total examples
    43    119.9 MiB      0.0 MiB          23       num_examples_total = sum(fit_res.num_examples for _, fit_res in results)
    44                                         
    45                                             # Compute scaling factors for each result
    46    119.9 MiB      0.0 MiB          24       scaling_factors = [
    47    119.9 MiB      0.0 MiB          11           fit_res.num_examples / num_examples_total for _, fit_res in results
    48                                             ]
    49                                         
    50    121.2 MiB      0.0 MiB         385       def _try_inplace(
    51    119.9 MiB      0.0 MiB           3           x: NDArray, y: Union[NDArray, float], np_binary_op: np.ufunc
    52    119.9 MiB      0.0 MiB           1       ) -> NDArray:
    53    121.2 MiB      0.0 MiB         380           return (  # type: ignore[no-any-return]
    54    121.2 MiB      0.0 MiB         380               np_binary_op(x, y, out=x)
    55    121.2 MiB      0.0 MiB         380               if np.can_cast(y, x.dtype, casting="same_kind")
    56                                                     else np_binary_op(x, np.array(y, x.dtype), out=x)
    57                                                 )
    58                                         
    59                                             # Let's do in-place aggregation
    60                                             # Get first result, then add up each other
    61                                         
    62    119.9 MiB      0.0 MiB          44       params = [
    63    119.9 MiB      0.0 MiB          20           _try_inplace(x, scaling_factors[0], np_binary_op=np.multiply)
    64    119.9 MiB      0.0 MiB          21           for x in parameters_to_ndarrays(results[0][1].parameters)
    65                                             ]
    66                                         
    67    121.2 MiB      0.0 MiB          10       for i, (_, fit_res) in enumerate(results[1:], start=1):
    68    121.2 MiB      0.0 MiB         559           res = (
    69    121.2 MiB      0.0 MiB         180               _try_inplace(x, scaling_factors[i], np_binary_op=np.multiply)
    70    121.2 MiB      1.3 MiB         189               for x in parameters_to_ndarrays(fit_res.parameters)
    71                                                 )
    72    121.2 MiB      0.0 MiB         396           params = [
    73    121.2 MiB      0.0 MiB         180               reduce(partial(_try_inplace, np_binary_op=np.add), layer_updates)
    74    121.2 MiB      0.0 MiB         189               for layer_updates in zip(params, res)
    75                                                 ]
    76                                         
    77    121.2 MiB      0.0 MiB           1       return params

Checklist

  • Implement proposed change
  • Write tests
  • Update documentation
  • Make CI checks pass
  • Ping maintainers on Slack (channel #contributions)

Any other comments?

@KarhouTam KarhouTam marked this pull request as draft July 28, 2024 02:39
@KarhouTam KarhouTam force-pushed the fix-aggregate-inplace branch from 81ee2a3 to 0a0d60f Compare July 28, 2024 10:08
@KarhouTam KarhouTam marked this pull request as ready for review July 28, 2024 11:45
@jafermarq
Copy link
Contributor

Hi @KarhouTam, thanks for taking a close look at that function, which is used by many of the strategies. What's the ratio of memory footprint saved by the change you propose? Is it possible to give an analytical x% memory-savings figure?

@KarhouTam
Copy link
Contributor Author

KarhouTam commented Jul 31, 2024

@jafermarq Thanks for the review.

What's the ratio of memory footprint saved by the change you propose? Is it possible to give an analytical x% memory-savings figure?

Sure.

TL;DR: Close to 50% memory savings.

The results shown above were calculated with the following input results:

test_parameters = [np.random.randn(512, 10) for _ in range(20)]

# 10 results, each result has 10 parameters
results = [
    (
        None,
        FitRes(
            status=None,
            parameters=Parameters(
                tensors=[ndarray_to_bytes(param) for param in test_parameters],
                tensor_type="numpy.ndarray",
            ),
            num_examples=1,
            metrics=None,
        ),
    )
    for _ in range(10)
]

From the profiles above, the Flower version allocates 2.6 MiB while the proposed version allocates 1.3 MiB, roughly 50% less. For confirmation, I ran the test again with different inputs (the results used in fedavg_test.py): Flower's version uses 154.4 - 136.1 = 18.3 MiB, while the proposed version uses 145.6 - 136.7 = 8.9 MiB.

So the saving is close to 50%.
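Informally, the ~50% matches intuition: the original code allocates a fresh array for every scaled layer and every addition, while the in-place version reuses the buffers already returned by parameters_to_ndarrays. A toy buffer-identity check (hypothetical, not part of the PR):

```python
import numpy as np

layers = [np.random.randn(4, 4) for _ in range(3)]

# Out-of-place scaling: every result is a brand-new buffer,
# roughly doubling peak memory while both copies are alive
scaled = [0.5 * x for x in layers]
assert not any(np.shares_memory(a, b) for a, b in zip(scaled, layers))

# In-place scaling: the original buffers are reused, so peak memory
# stays at roughly the size of the deserialized parameters alone
scaled_ip = [np.multiply(x, 0.5, out=x) for x in layers]
assert all(a is b for a, b in zip(scaled_ip, layers))
```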

Full Example

First I wrap these functions by memory_profiler.profile:

# =============== Flower Version ==============
@profile
def flwr_aggregate_inplace(results):
    """Compute in-place weighted average."""
    # Count total examples
    num_examples_total = sum(fit_res.num_examples for (_, fit_res) in results)

    # Compute scaling factors for each result
    scaling_factors = [
        fit_res.num_examples / num_examples_total for _, fit_res in results
    ]

    # Let's do in-place aggregation
    # Get first result, then add up each other
    params = [
        scaling_factors[0] * x for x in parameters_to_ndarrays(results[0][1].parameters)
    ]
    for i, (_, fit_res) in enumerate(results[1:]):
        res = (
            scaling_factors[i + 1] * x
            for x in parameters_to_ndarrays(fit_res.parameters)
        )
        params = [reduce(np.add, layer_updates) for layer_updates in zip(params, res)]

    return params

# =============== Proposed Version ==============
@profile
def my_aggregate_inplace(results: List[Tuple[ClientProxy, FitRes]]) -> NDArrays:
    """Compute in-place weighted average."""
    # Count total examples
    num_examples_total = sum(fit_res.num_examples for _, fit_res in results)

    # Compute scaling factors for each result
    scaling_factors = [
        fit_res.num_examples / num_examples_total for _, fit_res in results
    ]

    def _try_inplace(
        x: NDArray, y: Union[NDArray, float], np_binary_op: np.ufunc
    ) -> NDArray:
        return (  # type: ignore[no-any-return]
            np_binary_op(x, y, out=x)
            if np.can_cast(y, x.dtype, casting="same_kind")
            else np_binary_op(x, np.array(y, x.dtype), out=x)
        )

    # Let's do in-place aggregation
    # Get first result, then add up each other

    params = [
        _try_inplace(x, scaling_factors[0], np_binary_op=np.multiply)
        for x in parameters_to_ndarrays(results[0][1].parameters)
    ]

    for i, (_, fit_res) in enumerate(results[1:], start=1):
        res = (
            _try_inplace(x, scaling_factors[i], np_binary_op=np.multiply)
            for x in parameters_to_ndarrays(fit_res.parameters)
        )
        params = [
            reduce(partial(_try_inplace, np_binary_op=np.add), layer_updates)
            for layer_updates in zip(params, res)
        ]

    return params

And set the input argument results as

weights0_0 = np.random.randn(100, 64)
weights0_1 = np.random.randn(314, 628, 3)
weights1_0 = np.random.randn(100, 64)
weights1_1 = np.random.randn(314, 628, 3)

results: List[Tuple[ClientProxy, FitRes]] = [
    (
        None,
        FitRes(
            status=Status(code=Code.OK, message="Success"),
            parameters=ndarrays_to_parameters([weights0_0, weights0_1]),
            num_examples=1,
            metrics={},
        ),
    ),
    (
        None,
        FitRes(
            status=Status(code=Code.OK, message="Success"),
            parameters=ndarrays_to_parameters([weights1_0, weights1_1]),
            num_examples=5,
            metrics={},
        ),
    ),
]

This is the same input as the one used in fedavg_test.py. The test functions are:

import gc
LOOPS = 1
def test_flwr_aggregate_inplace():
    gc.disable()
    start = time.time()
    for _ in range(LOOPS):
        a = flwr_aggregate_inplace(results)
    end = time.time()
    gc.enable()
    print("flwr's aggregate_inplace cost time: ", end - start)


def test_my_aggregate_inplace():
    gc.disable()
    start = time.time()
    for _ in range(LOOPS):
        b = my_aggregate_inplace(results)
    end = time.time()
    gc.enable()
    print("my aggregate_inplace cost time: ", end - start)

Results

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    22    136.1 MiB    136.1 MiB           1   @profile
    23                                         def flwr_aggregate_inplace(results):
    24                                             """Compute in-place weighted average."""
    25                                             # Count total examples
    26    136.1 MiB      0.0 MiB           7       num_examples_total = sum(fit_res.num_examples for (_, fit_res) in results)
    27                                         
    28                                             # Compute scaling factors for each result
    29    136.1 MiB      0.0 MiB           8       scaling_factors = [
    30    136.1 MiB      0.0 MiB           3           fit_res.num_examples / num_examples_total for _, fit_res in results
    31                                             ]
    32                                         
    33                                             # Let's do in-place aggregation
    34                                             # Get first result, then add up each other
    35    145.4 MiB      0.0 MiB           8       params = [
    36    145.4 MiB      9.2 MiB           3           scaling_factors[0] * x for x in parameters_to_ndarrays(results[0][1].parameters)
    37                                             ]
    38    154.4 MiB      0.0 MiB           2       for i, (_, fit_res) in enumerate(results[1:]):
    39    154.4 MiB      0.0 MiB           9           res = (
    40    149.9 MiB      4.6 MiB           2               scaling_factors[i + 1] * x
    41    145.4 MiB      0.0 MiB           3               for x in parameters_to_ndarrays(fit_res.parameters)
    42                                                 )
    43    154.4 MiB      4.5 MiB           5           params = [reduce(np.add, layer_updates) for layer_updates in zip(params, res)]
    44                                         
    45    154.4 MiB      0.0 MiB           1       return params


flwr's aggregate_inplace cost time:  0.03412580490112305

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    48    136.7 MiB    136.7 MiB           1   @profile
    49                                         def my_aggregate_inplace(results: List[Tuple[ClientProxy, FitRes]]) -> NDArrays:
    50                                             """Compute in-place weighted average."""
    51                                             # Count total examples
    52    136.7 MiB      0.0 MiB           7       num_examples_total = sum(fit_res.num_examples for _, fit_res in results)
    53                                         
    54                                             # Compute scaling factors for each result
    55    136.7 MiB      0.0 MiB           8       scaling_factors = [
    56    136.7 MiB      0.0 MiB           3           fit_res.num_examples / num_examples_total for _, fit_res in results
    57                                             ]
    58                                         
    59    145.6 MiB      0.0 MiB          11       def _try_inplace(
    60    136.7 MiB      0.0 MiB           3           x: NDArray, y: Union[NDArray, float], np_binary_op: np.ufunc
    61    136.7 MiB      0.0 MiB           1       ) -> NDArray:
    62    145.6 MiB      0.0 MiB           6           return (  # type: ignore[no-any-return]
    63    145.6 MiB      0.0 MiB           6               np_binary_op(x, y, out=x)
    64    145.6 MiB      0.0 MiB           6               if np.can_cast(y, x.dtype, casting="same_kind")
    65                                                     else np_binary_op(x, np.array(y, x.dtype), out=x)
    66                                                 )
    67                                         
    68                                             # Let's do in-place aggregation
    69                                             # Get first result, then add up each other
    70                                         
    71    141.0 MiB      0.0 MiB           8       params = [
    72    141.0 MiB      0.0 MiB           2           _try_inplace(x, scaling_factors[0], np_binary_op=np.multiply)
    73    141.0 MiB      4.3 MiB           3           for x in parameters_to_ndarrays(results[0][1].parameters)
    74                                             ]
    75                                         
    76    145.6 MiB      0.0 MiB           2       for i, (_, fit_res) in enumerate(results[1:], start=1):
    77    145.6 MiB      0.0 MiB           9           res = (
    78    145.6 MiB      0.0 MiB           2               _try_inplace(x, scaling_factors[i], np_binary_op=np.multiply)
    79    145.6 MiB      4.6 MiB           3               for x in parameters_to_ndarrays(fit_res.parameters)
    80                                                 )
    81    145.6 MiB      0.0 MiB           8           params = [
    82    145.6 MiB      0.0 MiB           2               reduce(partial(_try_inplace, np_binary_op=np.add), layer_updates)
    83    145.6 MiB      0.0 MiB           3               for layer_updates in zip(params, res)
    84                                                 ]
    85                                         
    86    145.6 MiB      0.0 MiB           1       return params


my aggregate_inplace cost time:  0.008509635925292969

@jafermarq
Copy link
Contributor

Hi @KarhouTam, this looks super good. I spent some time on Friday looking into it closely. We aim to merge it for the next release of Flower; bear with us for a few more days. 🙏

panh99
panh99 previously approved these changes Aug 23, 2024
Copy link
Contributor

@panh99 panh99 left a comment


lgtm!

@danieljanes
Copy link
Member

Thanks for the PR and the helpful description @KarhouTam! Great improvement.

The merge window for Flower 1.11 is pretty much closed (apart from a few scheduled PRs). We'll merge this after the 1.11 release to have it included in Flower 1.12.

@KarhouTam
Copy link
Contributor Author

Conflicts resolved. This PR is ready to be merged anytime. 🤗

@KarhouTam
Copy link
Contributor Author

Hi, @panh99 @danieljanes.
Is this PR still blocked on anything? I noticed it hasn't been merged yet, but no new questions or reviews have been posted.

@panh99
Copy link
Contributor

panh99 commented Oct 6, 2024

Hi, @panh99 @danieljanes. Is this PR still blocked on anything? I noticed it hasn't been merged yet, but no new questions or reviews have been posted.

Hi @KarhouTam! We are preparing the 1.12.0 release, and your PR will be merged soon, I believe. Thanks for the reminder!

@danieljanes danieljanes enabled auto-merge (squash) October 7, 2024 10:58
@danieljanes
Copy link
Member

danieljanes commented Oct 7, 2024

Thanks for the PR @KarhouTam, approved! This will be included in the upcoming Flower 1.12 release (ETA: end of this week or early next week)

@danieljanes danieljanes merged commit dbf8d6d into adap:main Oct 7, 2024
50 checks passed
@KarhouTam KarhouTam deleted the fix-aggregate-inplace branch October 7, 2024 13:37