diff --git a/src/py/flwr/client/grpc_rere_client/client_interceptor_test.py b/src/py/flwr/client/grpc_rere_client/client_interceptor_test.py index 27f759a71713..a029b926423f 100644 --- a/src/py/flwr/client/grpc_rere_client/client_interceptor_test.py +++ b/src/py/flwr/client/grpc_rere_client/client_interceptor_test.py @@ -16,12 +16,13 @@ import base64 +import inspect import threading import unittest from collections.abc import Sequence from concurrent import futures from logging import DEBUG, INFO, WARN -from typing import Optional, Union +from typing import Optional, Union, get_args import grpc @@ -47,6 +48,7 @@ PushTaskResRequest, PushTaskResResponse, ) +from flwr.proto.fleet_pb2_grpc import FleetServicer from flwr.proto.node_pb2 import Node # pylint: disable=E0611 from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611 from flwr.proto.task_pb2 import Task, TaskIns # pylint: disable=E0611 @@ -437,6 +439,20 @@ def test_without_servicer(self) -> None: assert self._servicer.received_client_metadata() is None + def test_fleet_requests_included(self) -> None: + """Test if all Fleet requests are included in the authentication mode.""" + # Prepare + requests = get_args(Request) + rpc_names = {req.__qualname__.removesuffix("Request") for req in requests} + expected_rpc_names = { + name + for name, ref in inspect.getmembers(FleetServicer) + if inspect.isfunction(ref) + } + + # Assert + assert expected_rpc_names == rpc_names + if __name__ == "__main__": unittest.main(verbosity=2)