Skip to content

Commit

Permalink
ordereddict throughout
Browse files Browse the repository at this point in the history
  • Loading branch information
jafermarq committed Jan 17, 2024
1 parent f66990f commit 13d2ee3
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 13 deletions.
10 changes: 5 additions & 5 deletions src/py/flwr/common/parametersrecord.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@


from dataclasses import dataclass, field
from typing import Dict, List, Optional, OrderedDict
from typing import List, Optional, OrderedDict


@dataclass
Expand Down Expand Up @@ -64,14 +64,14 @@ class ParametersRecord:

def __init__(
self,
array_dict: Optional[Dict[str, Array]] = None,
array_dict: Optional[OrderedDict[str, Array]] = None,
keep_input: bool = False,
) -> None:
"""Construct a ParametersRecord object.
Parameters
----------
array_dict : Optional[Dict[str, Array]]
array_dict : Optional[OrderedDict[str, Array]]
A dictionary that stores serialized array-like or tensor-like objects.
keep_input : bool (default: False)
A boolean indicating whether parameters should be deleted from the input
Expand All @@ -87,12 +87,12 @@ def __init__(
if array_dict:
self.set_parameters(array_dict)

def set_parameters(self, array_dict: Dict[str, Array]) -> None:
def set_parameters(self, array_dict: OrderedDict[str, Array]) -> None:
"""Add parameters to record.
Parameters
----------
array_dict : Dict[str, Array]
array_dict : OrderedDict[str, Array]
A dictionary that stores serialized array-like or tensor-like objects.
"""
if any(not isinstance(k, str) for k in array_dict.keys()):
Expand Down
14 changes: 7 additions & 7 deletions src/py/flwr/common/recordset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""RecordSet tests."""


from typing import Callable, List, Type, Union
from typing import Callable, List, OrderedDict, Type, Union

import numpy as np
import pytest
Expand Down Expand Up @@ -100,9 +100,9 @@ def test_set_parameters_while_keeping_intputs() -> None:
"""Tests keep_input functionality in ParametersRecord."""
# Adding parameters to a record that doesn't erase entries in the input `array_dict`
p_record = ParametersRecord(keep_input=True)
array_dict = {
str(i): ndarray_to_array(ndarray) for i, ndarray in enumerate(get_ndarrays())
}
array_dict = OrderedDict(
{str(i): ndarray_to_array(ndarray) for i, ndarray in enumerate(get_ndarrays())}
)
p_record.set_parameters(array_dict)

# Creating a second parametersrecord passing the same `array_dict` (not erased)
Expand All @@ -116,9 +116,9 @@ def test_set_parameters_while_keeping_intputs() -> None:
def test_set_parameters_with_correct_types() -> None:
"""Test adding dictionary of Arrays to ParametersRecord."""
p_record = ParametersRecord()
array_dict = {
str(i): ndarray_to_array(ndarray) for i, ndarray in enumerate(get_ndarrays())
}
array_dict = OrderedDict(
{str(i): ndarray_to_array(ndarray) for i, ndarray in enumerate(get_ndarrays())}
)
p_record.set_parameters(array_dict)


Expand Down
6 changes: 5 additions & 1 deletion src/py/flwr/common/recordset_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
"""RecordSet utilities."""


from typing import OrderedDict

from .parametersrecord import Array, ParametersRecord
from .typing import Parameters

Expand Down Expand Up @@ -77,7 +79,9 @@ def parameters_to_parametersrecord(
else:
tensor = parameters.tensors.pop(0)
p_record.set_parameters(
{str(idx): Array(data=tensor, dtype="", stype=tensor_type, shape=[])}
OrderedDict(
{str(idx): Array(data=tensor, dtype="", stype=tensor_type, shape=[])}
)
)

return p_record

0 comments on commit 13d2ee3

Please sign in to comment.