diff --git a/src/py/flwr/common/record/configsrecord.py b/src/py/flwr/common/record/configsrecord.py index 704657601f50..471c85f0b961 100644 --- a/src/py/flwr/common/record/configsrecord.py +++ b/src/py/flwr/common/record/configsrecord.py @@ -15,7 +15,7 @@ """ConfigsRecord.""" -from typing import Dict, Optional, get_args +from typing import Dict, List, Optional, get_args from flwr.common.typing import ConfigsRecordValues, ConfigsScalar @@ -85,3 +85,39 @@ def __init__( self[k] = configs_dict[k] if not keep_input: del configs_dict[k] + + def count_bytes(self) -> int: + """Return number of Bytes stored in this object. + + This function counts booleans as occupying 1 Byte. + """ + + def get_var_bytes(value: ConfigsScalar) -> int: + """Return Bytes of value passed.""" + if isinstance(value, bool): + var_bytes = 1 + elif isinstance(value, (int, float)): + var_bytes = ( + 8 # the profobufing represents int/floats in ConfigRecords as 64bit + ) + if isinstance(value, (str, bytes)): + var_bytes = len(value) + return var_bytes + + num_bytes = 0 + + for k, v in self.items(): + if isinstance(v, List): + if isinstance(v[0], (bytes, str)): + # not all str are of equal length necessarily + # for both the footprint of each element is 1 Byte + num_bytes += int(sum(len(s) for s in v)) # type: ignore + else: + num_bytes += get_var_bytes(v[0]) * len(v) + else: + num_bytes += get_var_bytes(v) + + # We also count the bytes footprint of the keys + num_bytes += len(k) + + return num_bytes diff --git a/src/py/flwr/common/record/metricsrecord.py b/src/py/flwr/common/record/metricsrecord.py index 81b02303421b..2b6e584be390 100644 --- a/src/py/flwr/common/record/metricsrecord.py +++ b/src/py/flwr/common/record/metricsrecord.py @@ -15,7 +15,7 @@ """MetricsRecord.""" -from typing import Dict, Optional, get_args +from typing import Dict, List, Optional, get_args from flwr.common.typing import MetricsRecordValues, MetricsScalar @@ -84,3 +84,19 @@ def __init__( self[k] = metrics_dict[k] if not keep_input: del metrics_dict[k] + + def count_bytes(self) -> int: + """Return number of Bytes stored in this object.""" + num_bytes = 0 + + for k, v in self.items(): + if isinstance(v, List): + # both int and float normally take 4 bytes + # But MetricRecords are mapped to 64bit int/float + # during protobuffing + num_bytes += 8 * len(v) + else: + num_bytes += 8 + # We also count the bytes footprint of the keys + num_bytes += len(k) + return num_bytes diff --git a/src/py/flwr/common/record/parametersrecord.py b/src/py/flwr/common/record/parametersrecord.py index 17bf3f608db7..a4a71f751f97 100644 --- a/src/py/flwr/common/record/parametersrecord.py +++ b/src/py/flwr/common/record/parametersrecord.py @@ -117,3 +117,20 @@ def __init__( self[k] = array_dict[k] if not keep_input: del array_dict[k] + + def count_bytes(self) -> int: + """Return number of Bytes stored in this object. + + Note that a small amount of Bytes might also be included in this counting that + correspond to metadata of the serialized object (e.g. of NumPy array) needed for + deseralization. + """ + num_bytes = 0 + + for k, v in self.items(): + num_bytes += len(v.data) + + # We also count the bytes footprint of the keys + num_bytes += len(k) + + return num_bytes diff --git a/src/py/flwr/common/record/parametersrecord_test.py b/src/py/flwr/common/record/parametersrecord_test.py index 9633af7bda6d..e840e5e266e4 100644 --- a/src/py/flwr/common/record/parametersrecord_test.py +++ b/src/py/flwr/common/record/parametersrecord_test.py @@ -14,14 +14,26 @@ # ============================================================================== """Unit tests for ParametersRecord and Array.""" - import unittest +from collections import OrderedDict from io import BytesIO +from typing import List import numpy as np +import pytest + +from flwr.common import ndarray_to_bytes from ..constant import SType -from .parametersrecord import Array +from ..typing import NDArray +from .parametersrecord import Array, ParametersRecord + + +def _get_buffer_from_ndarray(array: NDArray) -> bytes: + """Return a bytes buffer froma given NumPy array.""" + buffer = BytesIO() + np.save(buffer, array, allow_pickle=False) + return buffer.getvalue() class TestArray(unittest.TestCase): @@ -31,16 +43,15 @@ def test_numpy_conversion_valid(self) -> None: """Test the numpy method with valid Array instance.""" # Prepare original_array = np.array([1, 2, 3], dtype=np.float32) - buffer = BytesIO() - np.save(buffer, original_array, allow_pickle=False) - buffer.seek(0) + + buffer = _get_buffer_from_ndarray(original_array) # Execute array_instance = Array( dtype=str(original_array.dtype), shape=list(original_array.shape), stype=SType.NUMPY, - data=buffer.read(), + data=buffer, ) converted_array = array_instance.numpy() @@ -60,3 +71,31 @@ def test_numpy_conversion_invalid(self) -> None: # Execute and assert with self.assertRaises(TypeError): array_instance.numpy() + + +@pytest.mark.parametrize( + "shape, dtype", + [ + ([100], "float32"), + ([31, 31], "int8"), + ([31, 153], "bool_"), # bool_ is represented as a whole Byte in NumPy + ], +) +def test_count_bytes(shape: List[int], dtype: str) -> None: + """Test bytes in a ParametersRecord are computed correctly.""" + original_array = np.random.randn(*shape).astype(np.dtype(dtype)) + + buff = ndarray_to_bytes(original_array) + + buffer = _get_buffer_from_ndarray(original_array) + + array_instance = Array( + dtype=str(original_array.dtype), + shape=list(original_array.shape), + stype=SType.NUMPY, + data=buffer, + ) + key_name = "data" + p_record = ParametersRecord(OrderedDict({key_name: array_instance})) + + assert len(buff) + len(key_name) == p_record.count_bytes() diff --git a/src/py/flwr/common/record/recordset_test.py b/src/py/flwr/common/record/recordset_test.py index bcf5c75a1e02..0e0b149881be 100644 --- a/src/py/flwr/common/record/recordset_test.py +++ b/src/py/flwr/common/record/recordset_test.py @@ -359,3 +359,42 @@ def test_set_configs_to_configsrecord_with_incorrect_types( with pytest.raises(TypeError): c_record.update(my_configs) + + +def test_count_bytes_metricsrecord() -> None: + """Test counting bytes in MetricsRecord.""" + data = {"a": 1, "b": 2.0, "c": [1, 2, 3], "d": [1.0, 2.0, 3.0, 4.0, 5.0]} + bytes_in_dict = 8 + 8 + 3 * 8 + 5 * 8 + bytes_in_dict += 4 # represnting the keys + + m_record = MetricsRecord() + m_record.update(OrderedDict(data)) + record_bytest_count = m_record.count_bytes() + assert bytes_in_dict == record_bytest_count + + +def test_count_bytes_configsrecord() -> None: + """Test counting bytes in ConfigsRecord.""" + data = {"a": 1, "b": 2.0, "c": [1, 2, 3], "d": [1.0, 2.0, 3.0, 4.0, 5.0]} + bytes_in_dict = 8 + 8 + 3 * 8 + 5 * 8 + bytes_in_dict += 4 # represnting the keys + + to_add = { + "aa": True, + "bb": "False", + "cc": bytes(9), + "dd": [True, False, False], + "ee": ["True", "False"], + "ff": [bytes(1), bytes(13), bytes(51)], + } + data = {**data, **to_add} + bytes_in_dict += 1 + 5 + 9 + 3 + (4 + 5) + (1 + 13 + 51) + bytes_in_dict += 12 # represnting the keys + + bytes_in_dict = int(bytes_in_dict) + + c_record = ConfigsRecord() + c_record.update(OrderedDict(data)) + + record_bytest_count = c_record.count_bytes() + assert bytes_in_dict == record_bytest_count