diff --git a/.github/workflows/e2e.yml b/.github/workflows/e2e.yml index aba3726017fd..7ac339aa43c8 100644 --- a/.github/workflows/e2e.yml +++ b/.github/workflows/e2e.yml @@ -51,6 +51,64 @@ jobs: short_sha: ${{ steps.upload.outputs.SHORT_SHA }} dir: ${{ steps.upload.outputs.DIR }} + superexec: + runs-on: ubuntu-22.04 + timeout-minutes: 10 + needs: wheel + strategy: + fail-fast: false + matrix: + python-version: ["3.9", "3.10", "3.11"] + directory: [e2e-bare-auth] + connection: [secure, insecure] + engine: [deployment-engine, simulation-engine] + authentication: [no-auth, client-auth] + exclude: + - connection: insecure + authentication: client-auth + name: | + SuperExec / + Python ${{ matrix.python-version }} / + ${{ matrix.connection }} / + ${{ matrix.authentication }} / + ${{ matrix.engine }} + defaults: + run: + working-directory: e2e/${{ matrix.directory }} + steps: + - uses: actions/checkout@v4 + - name: Bootstrap + uses: ./.github/actions/bootstrap + with: + python-version: ${{ matrix.python-version }} + poetry-skip: 'true' + - name: Install Flower from repo + if: ${{ github.repository != 'adap/flower' || github.event.pull_request.head.repo.fork || github.actor == 'dependabot[bot]' }} + working-directory: ./ + run: | + if [[ "${{ matrix.engine }}" == "simulation-engine" ]]; then + python -m pip install ".[simulation]" + else + python -m pip install . + fi + - name: Download and install Flower wheel from artifact store + if: ${{ github.repository == 'adap/flower' && !github.event.pull_request.head.repo.fork && github.actor != 'dependabot[bot]' }} + run: | + # Define base URL for wheel file + WHEEL_URL="https://${{ env.ARTIFACT_BUCKET }}/py/${{ needs.wheel.outputs.dir }}/${{ needs.wheel.outputs.short_sha }}/${{ needs.wheel.outputs.whl_path }}" + if [[ "${{ matrix.engine }}" == "simulation-engine" ]]; then + python -m pip install "flwr[simulation] @ ${WHEEL_URL}" + else + python -m pip install "${WHEEL_URL}" + fi + - name: > + Run SuperExec test / + ${{ matrix.connection }} / + ${{ matrix.authentication }} / + ${{ matrix.engine }} + working-directory: e2e/${{ matrix.directory }} + run: ./../test_superexec.sh "${{ matrix.connection }}" "${{ matrix.authentication}}" "${{ matrix.engine }}" + frameworks: runs-on: ubuntu-22.04 timeout-minutes: 10 diff --git a/benchmarks/flowertune-llm/README.md b/benchmarks/flowertune-llm/README.md index 45cd8a828a89..cab9b9156514 100644 --- a/benchmarks/flowertune-llm/README.md +++ b/benchmarks/flowertune-llm/README.md @@ -13,13 +13,13 @@ As the first step, please register for a Flower account on [flower.ai/login](htt Then, create a new Python environment and install Flower. > [!TIP] -> We recommend using `pyenv` with the `virtualenv` plugin to create your environment. Other managers, such as Conda, will likely work as well. Check the [documentation](https://flower.ai/docs/framework/how-to-install-flower.html) for alternative ways to install Flower. +> We recommend using `pyenv` with the `virtualenv` plugin to create your environment with Python >= 3.10.0. Other managers, such as Conda, will likely work as well. Check the [documentation](https://flower.ai/docs/framework/how-to-install-flower.html) for alternative ways to install Flower. ```shell pip install flwr ``` -In the new environment, create a new Flower project using the `FlowerTune` template. You will be prompted for a name to give to your project, your username, and for your choice of LLM challenge: +In the new environment, create a new Flower project using the `FlowerTune` template. You will be prompted for a name to give to your app/project, your username, and for your choice of LLM challenge: ```shell flwr new --framework=FlowerTune ``` @@ -64,5 +64,5 @@ following the `README.md` in [`evaluation`](https://github.com/adap/flower/tree/ > [!NOTE] -> If you have any questions about running FlowerTune LLM challenges or evaluation, please feel free to make posts at [Flower Discuss](https://discuss.flower.ai) forum, +> If you have any questions about running FlowerTune LLM challenges or evaluation, please feel free to make posts at our dedicated [FlowerTune Category](https://discuss.flower.ai/c/flowertune-llm-leaderboard/) on [Flower Discuss](https://discuss.flower.ai) forum, or join our [Slack channel](https://flower.ai/join-slack/) to ask questions in the `#flowertune-llm-leaderboard` channel. diff --git a/datasets/doc/source/index.rst b/datasets/doc/source/index.rst index 070655550fa1..d6b51fc84ad6 100644 --- a/datasets/doc/source/index.rst +++ b/datasets/doc/source/index.rst @@ -3,14 +3,7 @@ Flower Datasets Flower Datasets (``flwr-datasets``) is a library that enables the quick and easy creation of datasets for federated learning/analytics/evaluation. It enables heterogeneity (non-iidness) simulation and division of datasets with the preexisting notion of IDs. The library was created by the ``Flower Labs`` team that also created `Flower `_ : A Friendly Federated Learning Framework. -.. raw:: html - - - - +Try out an interactive demo to generate code and visualize heterogeneous divisions at the :ref:`bottom of this page`. Flower Datasets Framework ------------------------- @@ -142,7 +135,6 @@ What makes Flower Datasets stand out from other libraries? * New custom partitioning schemes (``Partitioner`` subclasses) integrated with the whole ecosystem. - Join the Flower Community ------------------------- @@ -153,3 +145,16 @@ The Flower Community is growing quickly - we're a friendly group of researchers, :shadow: Join us on Slack + +.. _demo: +Demo +---- + +.. raw:: html + + + + diff --git a/datasets/flwr_datasets/partitioner/pathological_partitioner.py b/datasets/flwr_datasets/partitioner/pathological_partitioner.py index 350383f344e7..d114ccbda02f 100644 --- a/datasets/flwr_datasets/partitioner/pathological_partitioner.py +++ b/datasets/flwr_datasets/partitioner/pathological_partitioner.py @@ -225,7 +225,7 @@ def _determine_partition_id_to_unique_labels(self) -> None: if self._class_assignment_mode == "first-deterministic": # if self._first_class_deterministic_assignment: for partition_id in range(self._num_partitions): - label = partition_id % num_unique_classes + label = self._unique_labels[partition_id % num_unique_classes] self._partition_id_to_unique_labels[partition_id].append(label) while ( diff --git a/datasets/flwr_datasets/partitioner/pathological_partitioner_test.py b/datasets/flwr_datasets/partitioner/pathological_partitioner_test.py index 18707a56bd98..5a3b13bb1436 100644 --- a/datasets/flwr_datasets/partitioner/pathological_partitioner_test.py +++ b/datasets/flwr_datasets/partitioner/pathological_partitioner_test.py @@ -18,7 +18,7 @@ import unittest import numpy as np -from parameterized import parameterized +from parameterized import parameterized, parameterized_class import datasets from datasets import Dataset @@ -26,7 +26,10 @@ def _dummy_dataset_setup( - num_samples: int, partition_by: str, num_unique_classes: int + num_samples: int, + partition_by: str, + num_unique_classes: int, + string_partition_by: bool = False, ) -> Dataset: """Create a dummy dataset for testing.""" data = { @@ -35,6 +38,8 @@ def _dummy_dataset_setup( )[:num_samples], "features": np.random.randn(num_samples), } + if string_partition_by: + data[partition_by] = data[partition_by].astype(str) return Dataset.from_dict(data) @@ -51,6 +56,7 @@ def _dummy_heterogeneous_dataset_setup( return Dataset.from_dict(data) +@parameterized_class(("string_partition_by",), [(False,), (True,)]) class TestClassConstrainedPartitioner(unittest.TestCase): """Unit tests for PathologicalPartitioner.""" @@ -94,7 +100,8 @@ def test_first_class_deterministic_assignment(self) -> None: Test if all the classes are used (which has to be the case, given num_partitions >= than the number of unique classes). """ - dataset = _dummy_dataset_setup(100, "labels", 10) + partition_by = "labels" + dataset = _dummy_dataset_setup(100, partition_by, 10) partitioner = PathologicalPartitioner( num_partitions=10, partition_by="labels", @@ -103,7 +110,12 @@ def test_first_class_deterministic_assignment(self) -> None: ) partitioner.dataset = dataset partitioner.load_partition(0) - expected_classes = set(range(10)) + expected_classes = set( + range(10) + # pylint: disable=unsubscriptable-object + if isinstance(dataset[partition_by][0], int) + else [str(i) for i in range(10)] + ) actual_classes = set() for pid in range(10): partition = partitioner.load_partition(pid) @@ -141,6 +153,9 @@ def test_deterministic_class_assignment( for i in range(num_classes_per_partition) ] ) + # pylint: disable=unsubscriptable-object + if isinstance(dataset["labels"][0], str): + expected_labels = [str(label) for label in expected_labels] actual_labels = sorted(np.unique(partition["labels"])) self.assertTrue( np.array_equal(expected_labels, actual_labels), @@ -166,6 +181,9 @@ def test_too_many_partitions_for_a_class( "labels": np.array([num_unique_classes - 1] * (num_samples // 2)), "features": np.random.randn(num_samples // 2), } + # pylint: disable=unsubscriptable-object + if isinstance(dataset_1["labels"][0], str): + data["labels"] = data["labels"].astype(str) dataset_2 = Dataset.from_dict(data) dataset = datasets.concatenate_datasets([dataset_1, dataset_2]) diff --git a/e2e/e2e-bare-auth/certificate.conf b/e2e/e2e-bare-auth/certificate.conf index ea97fcbb700d..04a2ed388174 100644 --- a/e2e/e2e-bare-auth/certificate.conf +++ b/e2e/e2e-bare-auth/certificate.conf @@ -18,3 +18,4 @@ subjectAltName = @alt_names DNS.1 = localhost IP.1 = ::1 IP.2 = 127.0.0.1 +IP.3 = 0.0.0.0 diff --git a/e2e/test_superexec.sh b/e2e/test_superexec.sh new file mode 100755 index 000000000000..ae79128c6ac1 --- /dev/null +++ b/e2e/test_superexec.sh @@ -0,0 +1,122 @@ +#!/bin/bash +set -e + +# Set connectivity parameters +case "$1" in + secure) + ./generate.sh + server_arg='--ssl-ca-certfile ../certificates/ca.crt + --ssl-certfile ../certificates/server.pem + --ssl-keyfile ../certificates/server.key' + client_arg='--root-certificates ../certificates/ca.crt' + # For $superexec_arg, note special ordering of single- and double-quotes + superexec_arg='--executor-config 'root-certificates=\"../certificates/ca.crt\"'' + superexec_arg="$server_arg $superexec_arg" + ;; + insecure) + server_arg='--insecure' + client_arg=$server_arg + superexec_arg=$server_arg + ;; +esac + +# Set authentication parameters +case "$2" in + client-auth) + server_auth='--auth-list-public-keys ../keys/client_public_keys.csv + --auth-superlink-private-key ../keys/server_credentials + --auth-superlink-public-key ../keys/server_credentials.pub' + client_auth_1='--auth-supernode-private-key ../keys/client_credentials_1 + --auth-supernode-public-key ../keys/client_credentials_1.pub' + client_auth_2='--auth-supernode-private-key ../keys/client_credentials_2 + --auth-supernode-public-key ../keys/client_credentials_2.pub' + server_address='127.0.0.1:9092' + ;; + *) + server_auth='' + client_auth_1='' + client_auth_2='' + server_address='127.0.0.1:9092' + ;; +esac + +# Set engine +case "$3" in + deployment-engine) + superexec_engine_arg='--executor flwr.superexec.deployment:executor' + ;; + simulation-engine) + superexec_engine_arg='--executor flwr.superexec.simulation:executor + --executor-config 'num-supernodes=10'' + ;; +esac + + +# Create and install Flower app +flwr new e2e-tmp-test --framework numpy --username flwrlabs +cd e2e-tmp-test +# Remove flwr dependency from `pyproject.toml`. Seems necessary so that it does +# not override the wheel dependency +if [[ "$OSTYPE" == "darwin"* ]]; then + # macOS (Darwin) system + sed -i '' '/flwr\[simulation\]/d' pyproject.toml +else + # Non-macOS system (Linux) + sed -i '/flwr\[simulation\]/d' pyproject.toml +fi +pip install -e . --no-deps + +# Check if the first argument is 'insecure' +if [ "$1" == "insecure" ]; then + # If $1 is 'insecure', append the first line + echo -e $"\n[tool.flwr.federations.superexec]\naddress = \"127.0.0.1:9093\"\ninsecure = true" >> pyproject.toml +else + # Otherwise, append the second line + echo -e $"\n[tool.flwr.federations.superexec]\naddress = \"127.0.0.1:9093\"\nroot-certificates = \"../certificates/ca.crt\"" >> pyproject.toml +fi + +timeout 2m flower-superlink $server_arg $server_auth & +sl_pid=$! +sleep 2 + +timeout 2m flower-supernode ./ $client_arg \ + --superlink $server_address $client_auth_1 \ + --node-config "partition-id=0 num-partitions=2" --max-retries 0 & +cl1_pid=$! +sleep 2 + +timeout 2m flower-supernode ./ $client_arg \ + --superlink $server_address $client_auth_2 \ + --node-config "partition-id=1 num-partitions=2" --max-retries 0 & +cl2_pid=$! +sleep 2 + +timeout 2m flower-superexec $superexec_arg $superexec_engine_arg 2>&1 | tee flwr_output.log & +se_pid=$(pgrep -f "flower-superexec") +sleep 2 + +timeout 1m flwr run --run-config num-server-rounds=1 ../e2e-tmp-test superexec + +# Initialize a flag to track if training is successful +found_success=false +timeout=120 # Timeout after 120 seconds +elapsed=0 + +# Check for "Success" in a loop with a timeout +while [ "$found_success" = false ] && [ $elapsed -lt $timeout ]; do + if grep -q "Run finished" flwr_output.log; then + echo "Training worked correctly!" + found_success=true + kill $cl1_pid; kill $cl2_pid; sleep 1; kill $sl_pid; kill $se_pid; + else + echo "Waiting for training ... ($elapsed seconds elapsed)" + fi + # Sleep for a short period and increment the elapsed time + sleep 2 + elapsed=$((elapsed + 2)) +done + +if [ "$found_success" = false ]; then + echo "Training had an issue and timed out." + kill $cl1_pid; kill $cl2_pid; kill $sl_pid; kill $se_pid; +fi diff --git a/glossary/flower-datasets.mdx b/glossary/flower-datasets.mdx new file mode 100644 index 000000000000..24537dfe223b --- /dev/null +++ b/glossary/flower-datasets.mdx @@ -0,0 +1,27 @@ +--- +title: "Flower Datasets" +description: "Flower Datasets is a library that enables the creation of datasets for federated learning by partitioning centralized datasets to exhibit heterogeneity or using naturally partitioned datasets." +date: "2024-05-24" +author: + name: "Adam Narożniak" + position: "ML Engineer at Flower Labs" + website: "https://discuss.flower.ai/u/adam.narozniak/summary" +related: + - text: "Flower Datasets documentation" + link: "https://flower.ai/docs/datasets/" + - text: "Flower Datasets GitHub page" + link: "https://github.com/adap/flower/tree/main/datasets" +--- + +Flower Datasets is a library that enables the creation of datasets for federated learning/analytics/evaluation by partitioning centralized datasets to exhibit heterogeneity or using naturally partitioned datasets. It was created by the Flower Labs team, which also created Flower - a Friendly Federated Learning Framework. + +The key features include: +* downloading datasets (HuggingFace `datasets` are used under the hood), +* partitioning (simulate different levels of heterogeneity by using one of the implemented partitioning schemes or create your own), +* creating centralized datasets (easily utilize centralized versions of the datasets), +* reproducibility (repeat the experiments with the same results), +* visualization (display the created partitions), +* ML agnostic (easy integration with all popular ML frameworks). + + +It is a supplementary library to Flower, with which it integrates easily. diff --git a/pyproject.toml b/pyproject.toml index 81c1369f6552..87059cf5c867 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -73,7 +73,7 @@ protobuf = "^4.25.2" cryptography = "^42.0.4" pycryptodome = "^3.18.0" iterators = "^0.0.2" -typer = { version = "^0.9.0", extras = ["all"] } +typer = "^0.12.5" tomli = "^2.0.1" tomli-w = "^1.0.0" pathspec = "^0.12.1" diff --git a/src/proto/flwr/proto/recordset.proto b/src/proto/flwr/proto/recordset.proto index 9ee4a9572d92..939e97cf46e3 100644 --- a/src/proto/flwr/proto/recordset.proto +++ b/src/proto/flwr/proto/recordset.proto @@ -17,15 +17,9 @@ syntax = "proto3"; package flwr.proto; -message Int { - oneof int { - sint64 sint64 = 1; - uint64 uint64 = 2; - } -} - message DoubleList { repeated double vals = 1; } -message IntList { repeated Int vals = 1; } +message SintList { repeated sint64 vals = 1; } +message UintList { repeated uint64 vals = 1; } message BoolList { repeated bool vals = 1; } message StringList { repeated string vals = 1; } message BytesList { repeated bytes vals = 1; } @@ -46,7 +40,8 @@ message MetricsRecordValue { // List types DoubleList double_list = 21; - IntList int_list = 22; + SintList sint_list = 22; + UintList uint_list = 23; } } @@ -62,7 +57,8 @@ message ConfigsRecordValue { // List types DoubleList double_list = 21; - IntList int_list = 22; + SintList sint_list = 22; + UintList uint_list = 23; BoolList bool_list = 24; StringList string_list = 25; BytesList bytes_list = 26; diff --git a/src/py/flwr/client/grpc_rere_client/client_interceptor_test.py b/src/py/flwr/client/grpc_rere_client/client_interceptor_test.py index 27f759a71713..a029b926423f 100644 --- a/src/py/flwr/client/grpc_rere_client/client_interceptor_test.py +++ b/src/py/flwr/client/grpc_rere_client/client_interceptor_test.py @@ -16,12 +16,13 @@ import base64 +import inspect import threading import unittest from collections.abc import Sequence from concurrent import futures from logging import DEBUG, INFO, WARN -from typing import Optional, Union +from typing import Optional, Union, get_args import grpc @@ -47,6 +48,7 @@ PushTaskResRequest, PushTaskResResponse, ) +from flwr.proto.fleet_pb2_grpc import FleetServicer from flwr.proto.node_pb2 import Node # pylint: disable=E0611 from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611 from flwr.proto.task_pb2 import Task, TaskIns # pylint: disable=E0611 @@ -437,6 +439,20 @@ def test_without_servicer(self) -> None: assert self._servicer.received_client_metadata() is None + def test_fleet_requests_included(self) -> None: + """Test if all Fleet requests are included in the authentication mode.""" + # Prepare + requests = get_args(Request) + rpc_names = {req.__qualname__.removesuffix("Request") for req in requests} + expected_rpc_names = { + name + for name, ref in inspect.getmembers(FleetServicer) + if inspect.isfunction(ref) + } + + # Assert + assert expected_rpc_names == rpc_names + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/src/py/flwr/client/grpc_rere_client/grpc_adapter.py b/src/py/flwr/client/grpc_rere_client/grpc_adapter.py index 3dce14c14956..69ea29d5b7b3 100644 --- a/src/py/flwr/client/grpc_rere_client/grpc_adapter.py +++ b/src/py/flwr/client/grpc_rere_client/grpc_adapter.py @@ -24,10 +24,14 @@ from flwr.common import log from flwr.common.constant import ( + GRPC_ADAPTER_METADATA_FLOWER_PACKAGE_NAME_KEY, + GRPC_ADAPTER_METADATA_FLOWER_PACKAGE_VERSION_KEY, GRPC_ADAPTER_METADATA_FLOWER_VERSION_KEY, + GRPC_ADAPTER_METADATA_MESSAGE_MODULE_KEY, + GRPC_ADAPTER_METADATA_MESSAGE_QUALNAME_KEY, GRPC_ADAPTER_METADATA_SHOULD_EXIT_KEY, ) -from flwr.common.version import package_version +from flwr.common.version import package_name, package_version from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611 from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611 CreateNodeRequest, @@ -62,9 +66,16 @@ def _send_and_receive( self, request: GrpcMessage, response_type: type[T], **kwargs: Any ) -> T: # Serialize request + req_cls = request.__class__ container_req = MessageContainer( - metadata={GRPC_ADAPTER_METADATA_FLOWER_VERSION_KEY: package_version}, - grpc_message_name=request.__class__.__qualname__, + metadata={ + GRPC_ADAPTER_METADATA_FLOWER_PACKAGE_NAME_KEY: package_name, + GRPC_ADAPTER_METADATA_FLOWER_PACKAGE_VERSION_KEY: package_version, + GRPC_ADAPTER_METADATA_FLOWER_VERSION_KEY: package_version, + GRPC_ADAPTER_METADATA_MESSAGE_MODULE_KEY: req_cls.__module__, + GRPC_ADAPTER_METADATA_MESSAGE_QUALNAME_KEY: req_cls.__qualname__, + }, + grpc_message_name=req_cls.__qualname__, grpc_message_content=request.SerializeToString(), ) diff --git a/src/py/flwr/common/constant.py b/src/py/flwr/common/constant.py index eabe324f41c5..ffd58478aa48 100644 --- a/src/py/flwr/common/constant.py +++ b/src/py/flwr/common/constant.py @@ -60,8 +60,6 @@ # IDs RUN_ID_NUM_BYTES = 8 NODE_ID_NUM_BYTES = 8 -GRPC_ADAPTER_METADATA_FLOWER_VERSION_KEY = "flower-version" -GRPC_ADAPTER_METADATA_SHOULD_EXIT_KEY = "should-exit" # Constants for FAB APP_DIR = "apps" @@ -72,8 +70,13 @@ PARTITION_ID_KEY = "partition-id" NUM_PARTITIONS_KEY = "num-partitions" -GRPC_ADAPTER_METADATA_FLOWER_VERSION_KEY = "flower-version" +# Constants for keys in `metadata` of `MessageContainer` in `grpc-adapter` +GRPC_ADAPTER_METADATA_FLOWER_PACKAGE_NAME_KEY = "flower-package-name" +GRPC_ADAPTER_METADATA_FLOWER_PACKAGE_VERSION_KEY = "flower-package-version" +GRPC_ADAPTER_METADATA_FLOWER_VERSION_KEY = "flower-version" # Deprecated GRPC_ADAPTER_METADATA_SHOULD_EXIT_KEY = "should-exit" +GRPC_ADAPTER_METADATA_MESSAGE_MODULE_KEY = "grpc-message-module" +GRPC_ADAPTER_METADATA_MESSAGE_QUALNAME_KEY = "grpc-message-qualname" class MessageType: diff --git a/src/py/flwr/common/serde.py b/src/py/flwr/common/serde.py index 5cb5a87c49df..54790992b40d 100644 --- a/src/py/flwr/common/serde.py +++ b/src/py/flwr/common/serde.py @@ -33,12 +33,12 @@ from flwr.proto.recordset_pb2 import BoolList, BytesList from flwr.proto.recordset_pb2 import ConfigsRecord as ProtoConfigsRecord from flwr.proto.recordset_pb2 import ConfigsRecordValue as ProtoConfigsRecordValue -from flwr.proto.recordset_pb2 import DoubleList, Int, IntList +from flwr.proto.recordset_pb2 import DoubleList from flwr.proto.recordset_pb2 import MetricsRecord as ProtoMetricsRecord from flwr.proto.recordset_pb2 import MetricsRecordValue as ProtoMetricsRecordValue from flwr.proto.recordset_pb2 import ParametersRecord as ProtoParametersRecord from flwr.proto.recordset_pb2 import RecordSet as ProtoRecordSet -from flwr.proto.recordset_pb2 import StringList +from flwr.proto.recordset_pb2 import SintList, StringList, UintList from flwr.proto.run_pb2 import Run as ProtoRun from flwr.proto.task_pb2 import Task, TaskIns, TaskRes from flwr.proto.transport_pb2 import ( @@ -340,6 +340,7 @@ def metrics_from_proto(proto: Any) -> typing.Metrics: # === Scalar messages === +INT64_MAX_VALUE = 9223372036854775807 # (1 << 63) - 1 def scalar_to_proto(scalar: typing.Scalar) -> Scalar: @@ -354,9 +355,10 @@ def scalar_to_proto(scalar: typing.Scalar) -> Scalar: return Scalar(double=scalar) if isinstance(scalar, int): - if scalar >= 0: - return Scalar(uint64=scalar) # Use uint64 for non-negative integers - return Scalar(sint64=scalar) # Use sint64 for negative integers + # Use uint64 for integers larger than the maximum value of sint64 + if scalar > INT64_MAX_VALUE: + return Scalar(uint64=scalar) + return Scalar(sint64=scalar) if isinstance(scalar, str): return Scalar(string=scalar) @@ -378,14 +380,14 @@ def scalar_from_proto(scalar_msg: Scalar) -> typing.Scalar: _type_to_field: dict[type, str] = { float: "double", - int: "int", + int: "sint64", bool: "bool", str: "string", bytes: "bytes", } _list_type_to_class_and_field: dict[type, tuple[type[GrpcMessage], str]] = { float: (DoubleList, "double_list"), - int: (IntList, "int_list"), + int: (SintList, "sint_list"), bool: (BoolList, "bool_list"), str: (StringList, "string_list"), bytes: (BytesList, "bytes_list"), @@ -393,17 +395,9 @@ def scalar_from_proto(scalar_msg: Scalar) -> typing.Scalar: T = TypeVar("T") -def int_to_proto(value: int) -> Int: - """Serialize a int to `Int`.""" - if value >= 0: - return Int(uint64=value) - return Int(sint64=value) - - -def int_from_proto(value_proto: Int) -> int: - """Deserialize a int from `Int`.""" - fld = cast(str, value_proto.WhichOneof("int")) - return cast(int, getattr(value_proto, fld)) +def _is_uint64(value: Any) -> bool: + """Check if a value is uint64.""" + return isinstance(value, int) and value > INT64_MAX_VALUE def _record_value_to_proto( @@ -419,15 +413,16 @@ def _record_value_to_proto( # Note: `isinstance(False, int) == True`. if isinstance(value, t): fld = _type_to_field[t] - if t is int: - fld = "uint64" if cast(int, value) >= 0 else "sint64" + if t is int and _is_uint64(value): + fld = "uint64" arg[fld] = value return proto_class(**arg) # List if isinstance(value, list) and all(isinstance(item, t) for item in value): list_class, fld = _list_type_to_class_and_field[t] - if t is int: - value = [int_to_proto(v) for v in value] + # Use UintList if any element is of type `uint64`. + if t is int and any(_is_uint64(v) for v in value): + list_class, fld = UintList, "uint_list" arg[fld] = list_class(vals=value) return proto_class(**arg) # Invalid types @@ -442,8 +437,6 @@ def _record_value_from_proto(value_proto: GrpcMessage) -> Any: value_field = cast(str, value_proto.WhichOneof("value")) if value_field.endswith("list"): value = list(getattr(value_proto, value_field).vals) - if value_field == "int_list": - value = [int_from_proto(v) for v in value] else: value = getattr(value_proto, value_field) return value diff --git a/src/py/flwr/common/serde_test.py b/src/py/flwr/common/serde_test.py index 4887f804b8c2..19e9889158a0 100644 --- a/src/py/flwr/common/serde_test.py +++ b/src/py/flwr/common/serde_test.py @@ -170,7 +170,7 @@ def get_str(self, length: Optional[int] = None) -> str: length = self.rng.randint(1, 10) return "".join(self.rng.choices(char_pool, k=length)) - def get_value(self, dtype: type[T]) -> T: + def get_value(self, dtype: Union[type[T], str]) -> T: """Create a value of a given type.""" ret: Any = None if dtype == bool: @@ -178,11 +178,13 @@ def get_value(self, dtype: type[T]) -> T: elif dtype == str: ret = self.get_str(self.rng.randint(10, 100)) elif dtype == int: - ret = self.rng.randint(-1 << 63, (1 << 64) - 1) + ret = self.rng.randint(-1 << 63, (1 << 63) - 1) elif dtype == float: ret = (self.rng.random() - 0.5) * (2.0 ** self.rng.randint(0, 50)) elif dtype == bytes: ret = self.randbytes(self.rng.randint(10, 100)) + elif dtype == "uint": + ret = self.rng.randint(0, (1 << 64) - 1) else: raise NotImplementedError(f"Unsupported dtype: {dtype}") return cast(T, ret) @@ -317,6 +319,7 @@ def test_metrics_record_serialization_deserialization() -> None: maker = RecordMaker() original = maker.metrics_record() original["uint64"] = (1 << 63) + 321 + original["list of uint64"] = [maker.get_value("uint") for _ in range(30)] # Execute proto = metrics_record_to_proto(original) @@ -333,6 +336,7 @@ def test_configs_record_serialization_deserialization() -> None: maker = RecordMaker() original = maker.configs_record() original["uint64"] = (1 << 63) + 101 + original["list of uint64"] = [maker.get_value("uint") for _ in range(100)] # Execute proto = configs_record_to_proto(original) diff --git a/src/py/flwr/proto/recordset_pb2.py b/src/py/flwr/proto/recordset_pb2.py index 89999f5b818c..6b169f869ab4 100644 --- a/src/py/flwr/proto/recordset_pb2.py +++ b/src/py/flwr/proto/recordset_pb2.py @@ -14,7 +14,7 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1a\x66lwr/proto/recordset.proto\x12\nflwr.proto\"0\n\x03Int\x12\x10\n\x06sint64\x18\x01 \x01(\x12H\x00\x12\x10\n\x06uint64\x18\x02 \x01(\x04H\x00\x42\x05\n\x03int\"\x1a\n\nDoubleList\x12\x0c\n\x04vals\x18\x01 \x03(\x01\"(\n\x07IntList\x12\x1d\n\x04vals\x18\x01 \x03(\x0b\x32\x0f.flwr.proto.Int\"\x18\n\x08\x42oolList\x12\x0c\n\x04vals\x18\x01 \x03(\x08\"\x1a\n\nStringList\x12\x0c\n\x04vals\x18\x01 \x03(\t\"\x19\n\tBytesList\x12\x0c\n\x04vals\x18\x01 \x03(\x0c\"B\n\x05\x41rray\x12\r\n\x05\x64type\x18\x01 \x01(\t\x12\r\n\x05shape\x18\x02 \x03(\x05\x12\r\n\x05stype\x18\x03 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\x0c\"\xab\x01\n\x12MetricsRecordValue\x12\x10\n\x06\x64ouble\x18\x01 \x01(\x01H\x00\x12\x10\n\x06sint64\x18\x02 \x01(\x12H\x00\x12\x10\n\x06uint64\x18\x03 \x01(\x04H\x00\x12-\n\x0b\x64ouble_list\x18\x15 \x01(\x0b\x32\x16.flwr.proto.DoubleListH\x00\x12\'\n\x08int_list\x18\x16 \x01(\x0b\x32\x13.flwr.proto.IntListH\x00\x42\x07\n\x05value\"\xe5\x02\n\x12\x43onfigsRecordValue\x12\x10\n\x06\x64ouble\x18\x01 \x01(\x01H\x00\x12\x10\n\x06sint64\x18\x02 \x01(\x12H\x00\x12\x10\n\x06uint64\x18\x03 \x01(\x04H\x00\x12\x0e\n\x04\x62ool\x18\x04 \x01(\x08H\x00\x12\x10\n\x06string\x18\x05 \x01(\tH\x00\x12\x0f\n\x05\x62ytes\x18\x06 \x01(\x0cH\x00\x12-\n\x0b\x64ouble_list\x18\x15 \x01(\x0b\x32\x16.flwr.proto.DoubleListH\x00\x12\'\n\x08int_list\x18\x16 \x01(\x0b\x32\x13.flwr.proto.IntListH\x00\x12)\n\tbool_list\x18\x18 \x01(\x0b\x32\x14.flwr.proto.BoolListH\x00\x12-\n\x0bstring_list\x18\x19 \x01(\x0b\x32\x16.flwr.proto.StringListH\x00\x12+\n\nbytes_list\x18\x1a \x01(\x0b\x32\x15.flwr.proto.BytesListH\x00\x42\x07\n\x05value\"M\n\x10ParametersRecord\x12\x11\n\tdata_keys\x18\x01 \x03(\t\x12&\n\x0b\x64\x61ta_values\x18\x02 \x03(\x0b\x32\x11.flwr.proto.Array\"\x8f\x01\n\rMetricsRecord\x12\x31\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32#.flwr.proto.MetricsRecord.DataEntry\x1aK\n\tDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12-\n\x05value\x18\x02 \x01(\x0b\x32\x1e.flwr.proto.MetricsRecordValue:\x02\x38\x01\"\x8f\x01\n\rConfigsRecord\x12\x31\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32#.flwr.proto.ConfigsRecord.DataEntry\x1aK\n\tDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12-\n\x05value\x18\x02 \x01(\x0b\x32\x1e.flwr.proto.ConfigsRecordValue:\x02\x38\x01\"\x97\x03\n\tRecordSet\x12\x39\n\nparameters\x18\x01 \x03(\x0b\x32%.flwr.proto.RecordSet.ParametersEntry\x12\x33\n\x07metrics\x18\x02 \x03(\x0b\x32\".flwr.proto.RecordSet.MetricsEntry\x12\x33\n\x07\x63onfigs\x18\x03 \x03(\x0b\x32\".flwr.proto.RecordSet.ConfigsEntry\x1aO\n\x0fParametersEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12+\n\x05value\x18\x02 \x01(\x0b\x32\x1c.flwr.proto.ParametersRecord:\x02\x38\x01\x1aI\n\x0cMetricsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12(\n\x05value\x18\x02 \x01(\x0b\x32\x19.flwr.proto.MetricsRecord:\x02\x38\x01\x1aI\n\x0c\x43onfigsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12(\n\x05value\x18\x02 \x01(\x0b\x32\x19.flwr.proto.ConfigsRecord:\x02\x38\x01\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1a\x66lwr/proto/recordset.proto\x12\nflwr.proto\"\x1a\n\nDoubleList\x12\x0c\n\x04vals\x18\x01 \x03(\x01\"\x18\n\x08SintList\x12\x0c\n\x04vals\x18\x01 \x03(\x12\"\x18\n\x08UintList\x12\x0c\n\x04vals\x18\x01 \x03(\x04\"\x18\n\x08\x42oolList\x12\x0c\n\x04vals\x18\x01 \x03(\x08\"\x1a\n\nStringList\x12\x0c\n\x04vals\x18\x01 \x03(\t\"\x19\n\tBytesList\x12\x0c\n\x04vals\x18\x01 \x03(\x0c\"B\n\x05\x41rray\x12\r\n\x05\x64type\x18\x01 \x01(\t\x12\r\n\x05shape\x18\x02 \x03(\x05\x12\r\n\x05stype\x18\x03 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\x0c\"\xd8\x01\n\x12MetricsRecordValue\x12\x10\n\x06\x64ouble\x18\x01 \x01(\x01H\x00\x12\x10\n\x06sint64\x18\x02 \x01(\x12H\x00\x12\x10\n\x06uint64\x18\x03 \x01(\x04H\x00\x12-\n\x0b\x64ouble_list\x18\x15 \x01(\x0b\x32\x16.flwr.proto.DoubleListH\x00\x12)\n\tsint_list\x18\x16 \x01(\x0b\x32\x14.flwr.proto.SintListH\x00\x12)\n\tuint_list\x18\x17 \x01(\x0b\x32\x14.flwr.proto.UintListH\x00\x42\x07\n\x05value\"\x92\x03\n\x12\x43onfigsRecordValue\x12\x10\n\x06\x64ouble\x18\x01 \x01(\x01H\x00\x12\x10\n\x06sint64\x18\x02 \x01(\x12H\x00\x12\x10\n\x06uint64\x18\x03 \x01(\x04H\x00\x12\x0e\n\x04\x62ool\x18\x04 \x01(\x08H\x00\x12\x10\n\x06string\x18\x05 \x01(\tH\x00\x12\x0f\n\x05\x62ytes\x18\x06 \x01(\x0cH\x00\x12-\n\x0b\x64ouble_list\x18\x15 \x01(\x0b\x32\x16.flwr.proto.DoubleListH\x00\x12)\n\tsint_list\x18\x16 \x01(\x0b\x32\x14.flwr.proto.SintListH\x00\x12)\n\tuint_list\x18\x17 \x01(\x0b\x32\x14.flwr.proto.UintListH\x00\x12)\n\tbool_list\x18\x18 \x01(\x0b\x32\x14.flwr.proto.BoolListH\x00\x12-\n\x0bstring_list\x18\x19 \x01(\x0b\x32\x16.flwr.proto.StringListH\x00\x12+\n\nbytes_list\x18\x1a \x01(\x0b\x32\x15.flwr.proto.BytesListH\x00\x42\x07\n\x05value\"M\n\x10ParametersRecord\x12\x11\n\tdata_keys\x18\x01 \x03(\t\x12&\n\x0b\x64\x61ta_values\x18\x02 \x03(\x0b\x32\x11.flwr.proto.Array\"\x8f\x01\n\rMetricsRecord\x12\x31\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32#.flwr.proto.MetricsRecord.DataEntry\x1aK\n\tDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12-\n\x05value\x18\x02 \x01(\x0b\x32\x1e.flwr.proto.MetricsRecordValue:\x02\x38\x01\"\x8f\x01\n\rConfigsRecord\x12\x31\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32#.flwr.proto.ConfigsRecord.DataEntry\x1aK\n\tDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12-\n\x05value\x18\x02 \x01(\x0b\x32\x1e.flwr.proto.ConfigsRecordValue:\x02\x38\x01\"\x97\x03\n\tRecordSet\x12\x39\n\nparameters\x18\x01 \x03(\x0b\x32%.flwr.proto.RecordSet.ParametersEntry\x12\x33\n\x07metrics\x18\x02 \x03(\x0b\x32\".flwr.proto.RecordSet.MetricsEntry\x12\x33\n\x07\x63onfigs\x18\x03 \x03(\x0b\x32\".flwr.proto.RecordSet.ConfigsEntry\x1aO\n\x0fParametersEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12+\n\x05value\x18\x02 \x01(\x0b\x32\x1c.flwr.proto.ParametersRecord:\x02\x38\x01\x1aI\n\x0cMetricsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12(\n\x05value\x18\x02 \x01(\x0b\x32\x19.flwr.proto.MetricsRecord:\x02\x38\x01\x1aI\n\x0c\x43onfigsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12(\n\x05value\x18\x02 \x01(\x0b\x32\x19.flwr.proto.ConfigsRecord:\x02\x38\x01\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -31,40 +31,40 @@ _globals['_RECORDSET_METRICSENTRY']._serialized_options = b'8\001' _globals['_RECORDSET_CONFIGSENTRY']._options = None _globals['_RECORDSET_CONFIGSENTRY']._serialized_options = b'8\001' - _globals['_INT']._serialized_start=42 - _globals['_INT']._serialized_end=90 - _globals['_DOUBLELIST']._serialized_start=92 - _globals['_DOUBLELIST']._serialized_end=118 - _globals['_INTLIST']._serialized_start=120 - _globals['_INTLIST']._serialized_end=160 - _globals['_BOOLLIST']._serialized_start=162 - _globals['_BOOLLIST']._serialized_end=186 - _globals['_STRINGLIST']._serialized_start=188 - _globals['_STRINGLIST']._serialized_end=214 - _globals['_BYTESLIST']._serialized_start=216 - _globals['_BYTESLIST']._serialized_end=241 - _globals['_ARRAY']._serialized_start=243 - _globals['_ARRAY']._serialized_end=309 - _globals['_METRICSRECORDVALUE']._serialized_start=312 - _globals['_METRICSRECORDVALUE']._serialized_end=483 - _globals['_CONFIGSRECORDVALUE']._serialized_start=486 - _globals['_CONFIGSRECORDVALUE']._serialized_end=843 - _globals['_PARAMETERSRECORD']._serialized_start=845 - _globals['_PARAMETERSRECORD']._serialized_end=922 - _globals['_METRICSRECORD']._serialized_start=925 - _globals['_METRICSRECORD']._serialized_end=1068 - _globals['_METRICSRECORD_DATAENTRY']._serialized_start=993 - _globals['_METRICSRECORD_DATAENTRY']._serialized_end=1068 - _globals['_CONFIGSRECORD']._serialized_start=1071 - _globals['_CONFIGSRECORD']._serialized_end=1214 - _globals['_CONFIGSRECORD_DATAENTRY']._serialized_start=1139 - _globals['_CONFIGSRECORD_DATAENTRY']._serialized_end=1214 - _globals['_RECORDSET']._serialized_start=1217 - _globals['_RECORDSET']._serialized_end=1624 - _globals['_RECORDSET_PARAMETERSENTRY']._serialized_start=1395 - _globals['_RECORDSET_PARAMETERSENTRY']._serialized_end=1474 - _globals['_RECORDSET_METRICSENTRY']._serialized_start=1476 - _globals['_RECORDSET_METRICSENTRY']._serialized_end=1549 - _globals['_RECORDSET_CONFIGSENTRY']._serialized_start=1551 - _globals['_RECORDSET_CONFIGSENTRY']._serialized_end=1624 + _globals['_DOUBLELIST']._serialized_start=42 + _globals['_DOUBLELIST']._serialized_end=68 + _globals['_SINTLIST']._serialized_start=70 + _globals['_SINTLIST']._serialized_end=94 + _globals['_UINTLIST']._serialized_start=96 + _globals['_UINTLIST']._serialized_end=120 + _globals['_BOOLLIST']._serialized_start=122 + _globals['_BOOLLIST']._serialized_end=146 + _globals['_STRINGLIST']._serialized_start=148 + _globals['_STRINGLIST']._serialized_end=174 + _globals['_BYTESLIST']._serialized_start=176 + _globals['_BYTESLIST']._serialized_end=201 + _globals['_ARRAY']._serialized_start=203 + _globals['_ARRAY']._serialized_end=269 + _globals['_METRICSRECORDVALUE']._serialized_start=272 + _globals['_METRICSRECORDVALUE']._serialized_end=488 + _globals['_CONFIGSRECORDVALUE']._serialized_start=491 + _globals['_CONFIGSRECORDVALUE']._serialized_end=893 + _globals['_PARAMETERSRECORD']._serialized_start=895 + _globals['_PARAMETERSRECORD']._serialized_end=972 + _globals['_METRICSRECORD']._serialized_start=975 + _globals['_METRICSRECORD']._serialized_end=1118 + _globals['_METRICSRECORD_DATAENTRY']._serialized_start=1043 + _globals['_METRICSRECORD_DATAENTRY']._serialized_end=1118 + _globals['_CONFIGSRECORD']._serialized_start=1121 + _globals['_CONFIGSRECORD']._serialized_end=1264 + _globals['_CONFIGSRECORD_DATAENTRY']._serialized_start=1189 + _globals['_CONFIGSRECORD_DATAENTRY']._serialized_end=1264 + _globals['_RECORDSET']._serialized_start=1267 + _globals['_RECORDSET']._serialized_end=1674 + _globals['_RECORDSET_PARAMETERSENTRY']._serialized_start=1445 + _globals['_RECORDSET_PARAMETERSENTRY']._serialized_end=1524 + _globals['_RECORDSET_METRICSENTRY']._serialized_start=1526 + _globals['_RECORDSET_METRICSENTRY']._serialized_end=1599 + _globals['_RECORDSET_CONFIGSENTRY']._serialized_start=1601 + _globals['_RECORDSET_CONFIGSENTRY']._serialized_end=1674 # @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/recordset_pb2.pyi b/src/py/flwr/proto/recordset_pb2.pyi index bbd8d78f87c2..91d17e3e6473 100644 --- a/src/py/flwr/proto/recordset_pb2.pyi +++ b/src/py/flwr/proto/recordset_pb2.pyi @@ -11,45 +11,41 @@ import typing_extensions DESCRIPTOR: google.protobuf.descriptor.FileDescriptor -class Int(google.protobuf.message.Message): +class DoubleList(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor - SINT64_FIELD_NUMBER: builtins.int - UINT64_FIELD_NUMBER: builtins.int - sint64: builtins.int - uint64: builtins.int + VALS_FIELD_NUMBER: builtins.int + @property + def vals(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.float]: ... def __init__(self, *, - sint64: builtins.int = ..., - uint64: builtins.int = ..., + vals: typing.Optional[typing.Iterable[builtins.float]] = ..., ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal["int",b"int","sint64",b"sint64","uint64",b"uint64"]) -> builtins.bool: ... - def ClearField(self, field_name: typing_extensions.Literal["int",b"int","sint64",b"sint64","uint64",b"uint64"]) -> None: ... - def WhichOneof(self, oneof_group: typing_extensions.Literal["int",b"int"]) -> typing.Optional[typing_extensions.Literal["sint64","uint64"]]: ... -global___Int = Int + def ClearField(self, field_name: typing_extensions.Literal["vals",b"vals"]) -> None: ... +global___DoubleList = DoubleList -class DoubleList(google.protobuf.message.Message): +class SintList(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor VALS_FIELD_NUMBER: builtins.int @property - def vals(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.float]: ... + def vals(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ... def __init__(self, *, - vals: typing.Optional[typing.Iterable[builtins.float]] = ..., + vals: typing.Optional[typing.Iterable[builtins.int]] = ..., ) -> None: ... def ClearField(self, field_name: typing_extensions.Literal["vals",b"vals"]) -> None: ... -global___DoubleList = DoubleList +global___SintList = SintList -class IntList(google.protobuf.message.Message): +class UintList(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor VALS_FIELD_NUMBER: builtins.int @property - def vals(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Int]: ... + def vals(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ... def __init__(self, *, - vals: typing.Optional[typing.Iterable[global___Int]] = ..., + vals: typing.Optional[typing.Iterable[builtins.int]] = ..., ) -> None: ... def ClearField(self, field_name: typing_extensions.Literal["vals",b"vals"]) -> None: ... -global___IntList = IntList +global___UintList = UintList class BoolList(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor @@ -114,7 +110,8 @@ class MetricsRecordValue(google.protobuf.message.Message): SINT64_FIELD_NUMBER: builtins.int UINT64_FIELD_NUMBER: builtins.int DOUBLE_LIST_FIELD_NUMBER: builtins.int - INT_LIST_FIELD_NUMBER: builtins.int + SINT_LIST_FIELD_NUMBER: builtins.int + UINT_LIST_FIELD_NUMBER: builtins.int double: builtins.float """Single element""" @@ -125,18 +122,21 @@ class MetricsRecordValue(google.protobuf.message.Message): """List types""" pass @property - def int_list(self) -> global___IntList: ... + def sint_list(self) -> global___SintList: ... + @property + def uint_list(self) -> global___UintList: ... def __init__(self, *, double: builtins.float = ..., sint64: builtins.int = ..., uint64: builtins.int = ..., double_list: typing.Optional[global___DoubleList] = ..., - int_list: typing.Optional[global___IntList] = ..., + sint_list: typing.Optional[global___SintList] = ..., + uint_list: typing.Optional[global___UintList] = ..., ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal["double",b"double","double_list",b"double_list","int_list",b"int_list","sint64",b"sint64","uint64",b"uint64","value",b"value"]) -> builtins.bool: ... - def ClearField(self, field_name: typing_extensions.Literal["double",b"double","double_list",b"double_list","int_list",b"int_list","sint64",b"sint64","uint64",b"uint64","value",b"value"]) -> None: ... - def WhichOneof(self, oneof_group: typing_extensions.Literal["value",b"value"]) -> typing.Optional[typing_extensions.Literal["double","sint64","uint64","double_list","int_list"]]: ... + def HasField(self, field_name: typing_extensions.Literal["double",b"double","double_list",b"double_list","sint64",b"sint64","sint_list",b"sint_list","uint64",b"uint64","uint_list",b"uint_list","value",b"value"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["double",b"double","double_list",b"double_list","sint64",b"sint64","sint_list",b"sint_list","uint64",b"uint64","uint_list",b"uint_list","value",b"value"]) -> None: ... + def WhichOneof(self, oneof_group: typing_extensions.Literal["value",b"value"]) -> typing.Optional[typing_extensions.Literal["double","sint64","uint64","double_list","sint_list","uint_list"]]: ... global___MetricsRecordValue = MetricsRecordValue class ConfigsRecordValue(google.protobuf.message.Message): @@ -148,7 +148,8 @@ class ConfigsRecordValue(google.protobuf.message.Message): STRING_FIELD_NUMBER: builtins.int BYTES_FIELD_NUMBER: builtins.int DOUBLE_LIST_FIELD_NUMBER: builtins.int - INT_LIST_FIELD_NUMBER: builtins.int + SINT_LIST_FIELD_NUMBER: builtins.int + UINT_LIST_FIELD_NUMBER: builtins.int BOOL_LIST_FIELD_NUMBER: builtins.int STRING_LIST_FIELD_NUMBER: builtins.int BYTES_LIST_FIELD_NUMBER: builtins.int @@ -165,7 +166,9 @@ class ConfigsRecordValue(google.protobuf.message.Message): """List types""" pass @property - def int_list(self) -> global___IntList: ... + def sint_list(self) -> global___SintList: ... + @property + def uint_list(self) -> global___UintList: ... @property def bool_list(self) -> global___BoolList: ... @property @@ -181,14 +184,15 @@ class ConfigsRecordValue(google.protobuf.message.Message): string: typing.Text = ..., bytes: builtins.bytes = ..., double_list: typing.Optional[global___DoubleList] = ..., - int_list: typing.Optional[global___IntList] = ..., + sint_list: typing.Optional[global___SintList] = ..., + uint_list: typing.Optional[global___UintList] = ..., bool_list: typing.Optional[global___BoolList] = ..., string_list: typing.Optional[global___StringList] = ..., bytes_list: typing.Optional[global___BytesList] = ..., ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal["bool",b"bool","bool_list",b"bool_list","bytes",b"bytes","bytes_list",b"bytes_list","double",b"double","double_list",b"double_list","int_list",b"int_list","sint64",b"sint64","string",b"string","string_list",b"string_list","uint64",b"uint64","value",b"value"]) -> builtins.bool: ... - def ClearField(self, field_name: typing_extensions.Literal["bool",b"bool","bool_list",b"bool_list","bytes",b"bytes","bytes_list",b"bytes_list","double",b"double","double_list",b"double_list","int_list",b"int_list","sint64",b"sint64","string",b"string","string_list",b"string_list","uint64",b"uint64","value",b"value"]) -> None: ... - def WhichOneof(self, oneof_group: typing_extensions.Literal["value",b"value"]) -> typing.Optional[typing_extensions.Literal["double","sint64","uint64","bool","string","bytes","double_list","int_list","bool_list","string_list","bytes_list"]]: ... + def HasField(self, field_name: typing_extensions.Literal["bool",b"bool","bool_list",b"bool_list","bytes",b"bytes","bytes_list",b"bytes_list","double",b"double","double_list",b"double_list","sint64",b"sint64","sint_list",b"sint_list","string",b"string","string_list",b"string_list","uint64",b"uint64","uint_list",b"uint_list","value",b"value"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["bool",b"bool","bool_list",b"bool_list","bytes",b"bytes","bytes_list",b"bytes_list","double",b"double","double_list",b"double_list","sint64",b"sint64","sint_list",b"sint_list","string",b"string","string_list",b"string_list","uint64",b"uint64","uint_list",b"uint_list","value",b"value"]) -> None: ... + def WhichOneof(self, oneof_group: typing_extensions.Literal["value",b"value"]) -> typing.Optional[typing_extensions.Literal["double","sint64","uint64","bool","string","bytes","double_list","sint_list","uint_list","bool_list","string_list","bytes_list"]]: ... global___ConfigsRecordValue = ConfigsRecordValue class ParametersRecord(google.protobuf.message.Message): diff --git a/src/py/flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py b/src/py/flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py index dbfbb236a7e4..75aa6d370511 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +++ b/src/py/flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py @@ -21,7 +21,15 @@ import grpc from google.protobuf.message import Message as GrpcMessage +from flwr.common.constant import ( + GRPC_ADAPTER_METADATA_FLOWER_PACKAGE_NAME_KEY, + GRPC_ADAPTER_METADATA_FLOWER_PACKAGE_VERSION_KEY, + GRPC_ADAPTER_METADATA_FLOWER_VERSION_KEY, + GRPC_ADAPTER_METADATA_MESSAGE_MODULE_KEY, + GRPC_ADAPTER_METADATA_MESSAGE_QUALNAME_KEY, +) from flwr.common.logger import log +from flwr.common.version import package_name, package_version from flwr.proto import grpcadapter_pb2_grpc # pylint: disable=E0611 from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611 from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611 @@ -52,9 +60,16 @@ def _handle( ) -> MessageContainer: req = request_type.FromString(msg_container.grpc_message_content) res = handler(req) + res_cls = res.__class__ return MessageContainer( - metadata={}, - grpc_message_name=res.__class__.__qualname__, + metadata={ + GRPC_ADAPTER_METADATA_FLOWER_PACKAGE_NAME_KEY: package_name, + GRPC_ADAPTER_METADATA_FLOWER_PACKAGE_VERSION_KEY: package_version, + GRPC_ADAPTER_METADATA_FLOWER_VERSION_KEY: package_version, + GRPC_ADAPTER_METADATA_MESSAGE_MODULE_KEY: res_cls.__module__, + GRPC_ADAPTER_METADATA_MESSAGE_QUALNAME_KEY: res_cls.__qualname__, + }, + grpc_message_name=res_cls.__qualname__, grpc_message_content=res.SerializeToString(), )