From 8cb84a437d81ae7a1185ff9ccb208bf3fe8fe2cd Mon Sep 17 00:00:00 2001 From: Chong Shen Ng Date: Thu, 7 Nov 2024 16:23:13 +0000 Subject: [PATCH] feat(framework) Add first implementation of `NodeState` (#4439) Co-authored-by: Heng Pan Co-authored-by: Daniel J. Beutel --- src/py/flwr/client/app.py | 16 +++-- src/py/flwr/client/nodestate/__init__.py | 25 +++++++ .../client/nodestate/in_memory_nodestate.py | 38 ++++++++++ src/py/flwr/client/nodestate/nodestate.py | 30 ++++++++ .../client/nodestate/nodestate_factory.py | 37 ++++++++++ .../flwr/client/nodestate/nodestate_test.py | 69 +++++++++++++++++++ 6 files changed, 209 insertions(+), 6 deletions(-) create mode 100644 src/py/flwr/client/nodestate/__init__.py create mode 100644 src/py/flwr/client/nodestate/in_memory_nodestate.py create mode 100644 src/py/flwr/client/nodestate/nodestate.py create mode 100644 src/py/flwr/client/nodestate/nodestate_factory.py create mode 100644 src/py/flwr/client/nodestate/nodestate_test.py diff --git a/src/py/flwr/client/app.py b/src/py/flwr/client/app.py index e803eaf88864..1f22e2a8376c 100644 --- a/src/py/flwr/client/app.py +++ b/src/py/flwr/client/app.py @@ -32,6 +32,7 @@ from flwr.cli.install import install_from_fab from flwr.client.client import Client from flwr.client.client_app import ClientApp, LoadClientAppError +from flwr.client.nodestate.nodestate_factory import NodeStateFactory from flwr.client.typing import ClientFnExt from flwr.common import GRPC_MAX_MESSAGE_LENGTH, Context, EventType, Message, event from flwr.common.address import parse_address @@ -365,6 +366,8 @@ def _on_backoff(retry_state: RetryState) -> None: # DeprecatedRunInfoStore gets initialized when the first connection is established run_info_store: Optional[DeprecatedRunInfoStore] = None + state_factory = NodeStateFactory() + state = state_factory.state() runs: dict[int, Run] = {} @@ -396,13 +399,14 @@ def _on_backoff(retry_state: RetryState) -> None: ) else: # Call create_node fn to register node - node_id: Optional[int] = ( # pylint: disable=assignment-from-none - create_node() - ) # pylint: disable=not-callable - if node_id is None: - raise ValueError("Node registration failed") + # and store node_id in state + if (node_id := create_node()) is None: + raise ValueError( + "Failed to register SuperNode with the SuperLink" + ) + state.set_node_id(node_id) run_info_store = DeprecatedRunInfoStore( - node_id=node_id, + node_id=state.get_node_id(), node_config=node_config, ) diff --git a/src/py/flwr/client/nodestate/__init__.py b/src/py/flwr/client/nodestate/__init__.py new file mode 100644 index 000000000000..a207e1d81948 --- /dev/null +++ b/src/py/flwr/client/nodestate/__init__.py @@ -0,0 +1,25 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Flower NodeState.""" + +from .in_memory_nodestate import InMemoryNodeState as InMemoryNodeState +from .nodestate import NodeState as NodeState +from .nodestate_factory import NodeStateFactory as NodeStateFactory + +__all__ = [ + "InMemoryNodeState", + "NodeState", + "NodeStateFactory", +] diff --git a/src/py/flwr/client/nodestate/in_memory_nodestate.py b/src/py/flwr/client/nodestate/in_memory_nodestate.py new file mode 100644 index 000000000000..fd88f3af28c7 --- /dev/null +++ b/src/py/flwr/client/nodestate/in_memory_nodestate.py @@ -0,0 +1,38 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""In-memory NodeState implementation.""" + + +from typing import Optional + +from flwr.client.nodestate.nodestate import NodeState + + +class InMemoryNodeState(NodeState): + """In-memory NodeState implementation.""" + + def __init__(self) -> None: + # Store node_id + self.node_id: Optional[int] = None + + def set_node_id(self, node_id: Optional[int]) -> None: + """Set the node ID.""" + self.node_id = node_id + + def get_node_id(self) -> int: + """Get the node ID.""" + if self.node_id is None: + raise ValueError("Node ID not set") + return self.node_id diff --git a/src/py/flwr/client/nodestate/nodestate.py b/src/py/flwr/client/nodestate/nodestate.py new file mode 100644 index 000000000000..6ae30f49fcc1 --- /dev/null +++ b/src/py/flwr/client/nodestate/nodestate.py @@ -0,0 +1,30 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Abstract base class NodeState.""" + +import abc +from typing import Optional + + +class NodeState(abc.ABC): + """Abstract NodeState.""" + + @abc.abstractmethod + def set_node_id(self, node_id: Optional[int]) -> None: + """Set the node ID.""" + + @abc.abstractmethod + def get_node_id(self) -> int: + """Get the node ID.""" diff --git a/src/py/flwr/client/nodestate/nodestate_factory.py b/src/py/flwr/client/nodestate/nodestate_factory.py new file mode 100644 index 000000000000..3d52f0272bd4 --- /dev/null +++ b/src/py/flwr/client/nodestate/nodestate_factory.py @@ -0,0 +1,37 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Factory class that creates NodeState instances.""" + +import threading +from typing import Optional + +from .in_memory_nodestate import InMemoryNodeState +from .nodestate import NodeState + + +class NodeStateFactory: + """Factory class that creates NodeState instances.""" + + def __init__(self) -> None: + self.state_instance: Optional[NodeState] = None + self.lock = threading.RLock() + + def state(self) -> NodeState: + """Return a State instance and create it, if necessary.""" + # Lock access to NodeStateFactory to prevent returning different instances + with self.lock: + if self.state_instance is None: + self.state_instance = InMemoryNodeState() + return self.state_instance diff --git a/src/py/flwr/client/nodestate/nodestate_test.py b/src/py/flwr/client/nodestate/nodestate_test.py new file mode 100644 index 000000000000..f7088b1f8ac6 --- /dev/null +++ b/src/py/flwr/client/nodestate/nodestate_test.py @@ -0,0 +1,69 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests all NodeState implementations have to conform to.""" + +import unittest +from abc import abstractmethod + +from flwr.client.nodestate import InMemoryNodeState, NodeState + + +class StateTest(unittest.TestCase): + """Test all state implementations.""" + + # This is to True in each child class + __test__ = False + + @abstractmethod + def state_factory(self) -> NodeState: + """Provide state implementation to test.""" + raise NotImplementedError() + + def test_get_set_node_id(self) -> None: + """Test set_node_id.""" + # Prepare + state: NodeState = self.state_factory() + node_id = 123 + + # Execute + state.set_node_id(node_id) + + retrieved_node_id = state.get_node_id() + + # Assert + assert node_id == retrieved_node_id + + def test_get_node_id_fails(self) -> None: + """Test get_node_id fails correctly if node_id is not set.""" + # Prepare + state: NodeState = self.state_factory() + + # Execute and assert + with self.assertRaises(ValueError): + state.get_node_id() + + +class InMemoryStateTest(StateTest): + """Test InMemoryState implementation.""" + + __test__ = True + + def state_factory(self) -> NodeState: + """Return InMemoryState.""" + return InMemoryNodeState() + + +if __name__ == "__main__": + unittest.main(verbosity=2)