From 13d2ee31423b7cb81da67726e195543c48570019 Mon Sep 17 00:00:00 2001 From: jafermarq Date: Wed, 17 Jan 2024 18:15:15 +0000 Subject: [PATCH] ordereddict throughout --- src/py/flwr/common/parametersrecord.py | 10 +++++----- src/py/flwr/common/recordset_test.py | 14 +++++++------- src/py/flwr/common/recordset_utils.py | 6 +++++- 3 files changed, 17 insertions(+), 13 deletions(-) diff --git a/src/py/flwr/common/parametersrecord.py b/src/py/flwr/common/parametersrecord.py index bf1bd3ad147b..3d40c0488baa 100644 --- a/src/py/flwr/common/parametersrecord.py +++ b/src/py/flwr/common/parametersrecord.py @@ -16,7 +16,7 @@ from dataclasses import dataclass, field -from typing import Dict, List, Optional, OrderedDict +from typing import List, Optional, OrderedDict @dataclass @@ -64,14 +64,14 @@ class ParametersRecord: def __init__( self, - array_dict: Optional[Dict[str, Array]] = None, + array_dict: Optional[OrderedDict[str, Array]] = None, keep_input: bool = False, ) -> None: """Construct a ParametersRecord object. Parameters ---------- - array_dict : Optional[Dict[str, Array]] + array_dict : Optional[OrderedDict[str, Array]] A dictionary that stores serialized array-like or tensor-like objects. keep_input : bool (default: False) A boolean indicating whether parameters should be deleted from the input @@ -87,12 +87,12 @@ def __init__( if array_dict: self.set_parameters(array_dict) - def set_parameters(self, array_dict: Dict[str, Array]) -> None: + def set_parameters(self, array_dict: OrderedDict[str, Array]) -> None: """Add parameters to record. Parameters ---------- - array_dict : Dict[str, Array] + array_dict : OrderedDict[str, Array] A dictionary that stores serialized array-like or tensor-like objects. """ if any(not isinstance(k, str) for k in array_dict.keys()): diff --git a/src/py/flwr/common/recordset_test.py b/src/py/flwr/common/recordset_test.py index 8a87100f48e0..90c06dcdb109 100644 --- a/src/py/flwr/common/recordset_test.py +++ b/src/py/flwr/common/recordset_test.py @@ -15,7 +15,7 @@ """RecordSet tests.""" -from typing import Callable, List, Type, Union +from typing import Callable, List, OrderedDict, Type, Union import numpy as np import pytest @@ -100,9 +100,9 @@ def test_set_parameters_while_keeping_intputs() -> None: """Tests keep_input functionality in ParametersRecord.""" # Adding parameters to a record that doesn't erase entries in the input `array_dict` p_record = ParametersRecord(keep_input=True) - array_dict = { - str(i): ndarray_to_array(ndarray) for i, ndarray in enumerate(get_ndarrays()) - } + array_dict = OrderedDict( + {str(i): ndarray_to_array(ndarray) for i, ndarray in enumerate(get_ndarrays())} + ) p_record.set_parameters(array_dict) # Creating a second parametersrecord passing the same `array_dict` (not erased) @@ -116,9 +116,9 @@ def test_set_parameters_while_keeping_intputs() -> None: def test_set_parameters_with_correct_types() -> None: """Test adding dictionary of Arrays to ParametersRecord.""" p_record = ParametersRecord() - array_dict = { - str(i): ndarray_to_array(ndarray) for i, ndarray in enumerate(get_ndarrays()) - } + array_dict = OrderedDict( + {str(i): ndarray_to_array(ndarray) for i, ndarray in enumerate(get_ndarrays())} + ) p_record.set_parameters(array_dict) diff --git a/src/py/flwr/common/recordset_utils.py b/src/py/flwr/common/recordset_utils.py index a4252cb972d5..c1e724fa2758 100644 --- a/src/py/flwr/common/recordset_utils.py +++ b/src/py/flwr/common/recordset_utils.py @@ -15,6 +15,8 @@ """RecordSet utilities.""" +from typing import OrderedDict + from .parametersrecord import Array, ParametersRecord from .typing import Parameters @@ -77,7 +79,9 @@ def parameters_to_parametersrecord( else: tensor = parameters.tensors.pop(0) p_record.set_parameters( - {str(idx): Array(data=tensor, dtype="", stype=tensor_type, shape=[])} + OrderedDict( + {str(idx): Array(data=tensor, dtype="", stype=tensor_type, shape=[])} + ) ) return p_record