diff --git a/src/py/flwr/client/grpc_rere_client/connection.py b/src/py/flwr/client/grpc_rere_client/connection.py index 3778fd4061f9..8ef8e7ebf62a 100644 --- a/src/py/flwr/client/grpc_rere_client/connection.py +++ b/src/py/flwr/client/grpc_rere_client/connection.py @@ -21,7 +21,7 @@ from copy import copy from logging import DEBUG, ERROR from pathlib import Path -from typing import Callable, Iterator, Optional, Sequence, Tuple, Union, cast +from typing import Callable, Iterator, Optional, Sequence, Tuple, Type, Union, cast import grpc from cryptography.hazmat.primitives.asymmetric import ec @@ -73,6 +73,7 @@ def grpc_request_response( # pylint: disable=R0913, R0914, R0915 authentication_keys: Optional[ Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey] ] = None, + adapter_cls: Optional[Type[FleetStub]] = None, ) -> Iterator[ Tuple[ Callable[[], Optional[Message]], @@ -133,7 +134,9 @@ def grpc_request_response( # pylint: disable=R0913, R0914, R0915 channel.subscribe(on_channel_state_change) # Shared variables for inner functions - stub = FleetStub(channel) + if adapter_cls is None: + adapter_cls = FleetStub + stub = adapter_cls(channel) metadata: Optional[Metadata] = None node: Optional[Node] = None ping_thread: Optional[threading.Thread] = None