Skip to content

Commit

Permalink
refactor(framework) Support uint64 in Scalar and ConfigsRecordValue (#…
Browse files Browse the repository at this point in the history
…4243)

Co-authored-by: Heng Pan <[email protected]>
  • Loading branch information
mohammadnaseri and panh99 authored Sep 26, 2024
1 parent e6fb60d commit 7c38f2c
Show file tree
Hide file tree
Showing 8 changed files with 147 additions and 87 deletions.
27 changes: 18 additions & 9 deletions src/proto/flwr/proto/recordset.proto
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,15 @@ syntax = "proto3";

package flwr.proto;

message Int {
oneof int {
sint64 sint64 = 1;
uint64 uint64 = 2;
}
}

message DoubleList { repeated double vals = 1; }
message Sint64List { repeated sint64 vals = 1; }
message IntList { repeated Int vals = 1; }
message BoolList { repeated bool vals = 1; }
message StringList { repeated string vals = 1; }
message BytesList { repeated bytes vals = 1; }
Expand All @@ -35,10 +42,11 @@ message MetricsRecordValue {
// Single element
double double = 1;
sint64 sint64 = 2;
uint64 uint64 = 3;

// List types
DoubleList double_list = 21;
Sint64List sint64_list = 22;
IntList int_list = 22;
}
}

Expand All @@ -47,16 +55,17 @@ message ConfigsRecordValue {
// Single element
double double = 1;
sint64 sint64 = 2;
bool bool = 3;
string string = 4;
bytes bytes = 5;
uint64 uint64 = 3;
bool bool = 4;
string string = 5;
bytes bytes = 6;

// List types
DoubleList double_list = 21;
Sint64List sint64_list = 22;
BoolList bool_list = 23;
StringList string_list = 24;
BytesList bytes_list = 25;
IntList int_list = 22;
BoolList bool_list = 24;
StringList string_list = 25;
BytesList bytes_list = 26;
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/proto/flwr/proto/transport.proto
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ message Scalar {
// int32 int32 = 3;
// int64 int64 = 4;
// uint32 uint32 = 5;
// uint64 uint64 = 6;
uint64 uint64 = 6;
// sint32 sint32 = 7;
sint64 sint64 = 8;
// fixed32 fixed32 = 9;
Expand Down
42 changes: 32 additions & 10 deletions src/py/flwr/common/serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,12 @@
from flwr.proto.recordset_pb2 import BoolList, BytesList
from flwr.proto.recordset_pb2 import ConfigsRecord as ProtoConfigsRecord
from flwr.proto.recordset_pb2 import ConfigsRecordValue as ProtoConfigsRecordValue
from flwr.proto.recordset_pb2 import DoubleList
from flwr.proto.recordset_pb2 import DoubleList, Int, IntList
from flwr.proto.recordset_pb2 import MetricsRecord as ProtoMetricsRecord
from flwr.proto.recordset_pb2 import MetricsRecordValue as ProtoMetricsRecordValue
from flwr.proto.recordset_pb2 import ParametersRecord as ProtoParametersRecord
from flwr.proto.recordset_pb2 import RecordSet as ProtoRecordSet
from flwr.proto.recordset_pb2 import Sint64List, StringList
from flwr.proto.recordset_pb2 import StringList
from flwr.proto.run_pb2 import Run as ProtoRun
from flwr.proto.task_pb2 import Task, TaskIns, TaskRes
from flwr.proto.transport_pb2 import (
Expand Down Expand Up @@ -354,7 +354,9 @@ def scalar_to_proto(scalar: typing.Scalar) -> Scalar:
return Scalar(double=scalar)

if isinstance(scalar, int):
return Scalar(sint64=scalar)
if scalar >= 0:
return Scalar(uint64=scalar) # Use uint64 for non-negative integers
return Scalar(sint64=scalar) # Use sint64 for negative integers

if isinstance(scalar, str):
return Scalar(string=scalar)
Expand All @@ -374,23 +376,36 @@ def scalar_from_proto(scalar_msg: Scalar) -> typing.Scalar:
# === Record messages ===


_type_to_field = {
_type_to_field: dict[type, str] = {
float: "double",
int: "sint64",
int: "int",
bool: "bool",
str: "string",
bytes: "bytes",
}
_list_type_to_class_and_field = {
_list_type_to_class_and_field: dict[type, tuple[type[GrpcMessage], str]] = {
float: (DoubleList, "double_list"),
int: (Sint64List, "sint64_list"),
int: (IntList, "int_list"),
bool: (BoolList, "bool_list"),
str: (StringList, "string_list"),
bytes: (BytesList, "bytes_list"),
}
T = TypeVar("T")


def int_to_proto(value: int) -> Int:
"""Serialize a int to `Int`."""
if value >= 0:
return Int(uint64=value)
return Int(sint64=value)


def int_from_proto(value_proto: Int) -> int:
"""Deserialize a int from `Int`."""
fld = cast(str, value_proto.WhichOneof("int"))
return cast(int, getattr(value_proto, fld))


def _record_value_to_proto(
value: Any, allowed_types: list[type], proto_class: type[T]
) -> T:
Expand All @@ -403,12 +418,17 @@ def _record_value_to_proto(
# Single element
# Note: `isinstance(False, int) == True`.
if isinstance(value, t):
arg[_type_to_field[t]] = value
fld = _type_to_field[t]
if t is int:
fld = "uint64" if cast(int, value) >= 0 else "sint64"
arg[fld] = value
return proto_class(**arg)
# List
if isinstance(value, list) and all(isinstance(item, t) for item in value):
list_class, field_name = _list_type_to_class_and_field[t]
arg[field_name] = list_class(vals=value)
list_class, fld = _list_type_to_class_and_field[t]
if t is int:
value = [int_to_proto(v) for v in value]
arg[fld] = list_class(vals=value)
return proto_class(**arg)
# Invalid types
raise TypeError(
Expand All @@ -422,6 +442,8 @@ def _record_value_from_proto(value_proto: GrpcMessage) -> Any:
value_field = cast(str, value_proto.WhichOneof("value"))
if value_field.endswith("list"):
value = list(getattr(value_proto, value_field).vals)
if value_field == "int_list":
value = [int_from_proto(v) for v in value]
else:
value = getattr(value_proto, value_field)
return value
Expand Down
6 changes: 4 additions & 2 deletions src/py/flwr/common/serde_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@
def test_serialisation_deserialisation() -> None:
"""Test if the np.ndarray is identical after (de-)serialization."""
# Prepare
scalars = [True, b"bytestr", 3.14, 9000, "Hello"]
scalars = [True, b"bytestr", 3.14, 9000, "Hello", (1 << 63) + 1]

for scalar in scalars:
# Execute
Expand Down Expand Up @@ -178,7 +178,7 @@ def get_value(self, dtype: type[T]) -> T:
elif dtype == str:
ret = self.get_str(self.rng.randint(10, 100))
elif dtype == int:
ret = self.rng.randint(-1 << 30, 1 << 30)
ret = self.rng.randint(-1 << 63, (1 << 64) - 1)
elif dtype == float:
ret = (self.rng.random() - 0.5) * (2.0 ** self.rng.randint(0, 50))
elif dtype == bytes:
Expand Down Expand Up @@ -316,6 +316,7 @@ def test_metrics_record_serialization_deserialization() -> None:
# Prepare
maker = RecordMaker()
original = maker.metrics_record()
original["uint64"] = (1 << 63) + 321

# Execute
proto = metrics_record_to_proto(original)
Expand All @@ -331,6 +332,7 @@ def test_configs_record_serialization_deserialization() -> None:
# Prepare
maker = RecordMaker()
original = maker.configs_record()
original["uint64"] = (1 << 63) + 101

# Execute
proto = configs_record_to_proto(original)
Expand Down
72 changes: 37 additions & 35 deletions src/py/flwr/proto/recordset_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 7c38f2c

Please sign in to comment.