Skip to content

Commit

Permalink
Merge branch 'main' into get-federation-options
Browse files Browse the repository at this point in the history
  • Loading branch information
jafermarq authored Nov 7, 2024
2 parents e916634 + 8cb84a4 commit 93da4be
Show file tree
Hide file tree
Showing 13 changed files with 233 additions and 127 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/e2e.yml
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ jobs:
- connection: insecure
authentication: client-auth
name: |
SuperExec /
Exec API /
Python ${{ matrix.python-version }} /
${{ matrix.connection }} /
${{ matrix.authentication }} /
Expand Down
32 changes: 20 additions & 12 deletions e2e/test_exec_api.sh
Original file line number Diff line number Diff line change
Expand Up @@ -81,17 +81,19 @@ timeout 2m flower-superlink $combined_args --executor-config "$executor_config"
sl_pid=$(pgrep -f "flower-superlink")
sleep 2

timeout 2m flower-supernode ./ $client_arg \
--superlink $server_address $client_auth_1 \
--node-config "partition-id=0 num-partitions=2" --max-retries 0 &
cl1_pid=$!
sleep 2
if [ "$3" = "deployment-engine" ]; then
timeout 2m flower-supernode ./ $client_arg \
--superlink $server_address $client_auth_1 \
--node-config "partition-id=0 num-partitions=2" --max-retries 0 &
cl1_pid=$!
sleep 2

timeout 2m flower-supernode ./ $client_arg \
--superlink $server_address $client_auth_2 \
--node-config "partition-id=1 num-partitions=2" --max-retries 0 &
cl2_pid=$!
sleep 2
timeout 2m flower-supernode ./ $client_arg \
--superlink $server_address $client_auth_2 \
--node-config "partition-id=1 num-partitions=2" --max-retries 0 &
cl2_pid=$!
sleep 2
fi

timeout 1m flwr run --run-config num-server-rounds=1 ../e2e-tmp-test e2e

Expand All @@ -105,7 +107,10 @@ while [ "$found_success" = false ] && [ $elapsed -lt $timeout ]; do
if grep -q "Run finished" flwr_output.log; then
echo "Training worked correctly!"
found_success=true
kill $cl1_pid; kill $cl2_pid; sleep 1; kill $sl_pid;
if $3 = "deployment-engine"; then
kill $cl1_pid; kill $cl2_pid;
fi
sleep 1; kill $sl_pid;
else
echo "Waiting for training ... ($elapsed seconds elapsed)"
fi
Expand All @@ -116,5 +121,8 @@ done

if [ "$found_success" = false ]; then
echo "Training had an issue and timed out."
kill $cl1_pid; kill $cl2_pid; kill $sl_pid;
if $3 = "deployment-engine"; then
kill $cl1_pid; kill $cl2_pid;
fi
kill $sl_pid;
fi
16 changes: 10 additions & 6 deletions src/py/flwr/client/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from flwr.cli.install import install_from_fab
from flwr.client.client import Client
from flwr.client.client_app import ClientApp, LoadClientAppError
from flwr.client.nodestate.nodestate_factory import NodeStateFactory
from flwr.client.typing import ClientFnExt
from flwr.common import GRPC_MAX_MESSAGE_LENGTH, Context, EventType, Message, event
from flwr.common.address import parse_address
Expand Down Expand Up @@ -365,6 +366,8 @@ def _on_backoff(retry_state: RetryState) -> None:

# DeprecatedRunInfoStore gets initialized when the first connection is established
run_info_store: Optional[DeprecatedRunInfoStore] = None
state_factory = NodeStateFactory()
state = state_factory.state()

runs: dict[int, Run] = {}

Expand Down Expand Up @@ -396,13 +399,14 @@ def _on_backoff(retry_state: RetryState) -> None:
)
else:
# Call create_node fn to register node
node_id: Optional[int] = ( # pylint: disable=assignment-from-none
create_node()
) # pylint: disable=not-callable
if node_id is None:
raise ValueError("Node registration failed")
# and store node_id in state
if (node_id := create_node()) is None:
raise ValueError(
"Failed to register SuperNode with the SuperLink"
)
state.set_node_id(node_id)
run_info_store = DeprecatedRunInfoStore(
node_id=node_id,
node_id=state.get_node_id(),
node_config=node_config,
)

Expand Down
25 changes: 25 additions & 0 deletions src/py/flwr/client/nodestate/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Flower NodeState."""

from .in_memory_nodestate import InMemoryNodeState as InMemoryNodeState
from .nodestate import NodeState as NodeState
from .nodestate_factory import NodeStateFactory as NodeStateFactory

__all__ = [
"InMemoryNodeState",
"NodeState",
"NodeStateFactory",
]
38 changes: 38 additions & 0 deletions src/py/flwr/client/nodestate/in_memory_nodestate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""In-memory NodeState implementation."""


from typing import Optional

from flwr.client.nodestate.nodestate import NodeState


class InMemoryNodeState(NodeState):
"""In-memory NodeState implementation."""

def __init__(self) -> None:
# Store node_id
self.node_id: Optional[int] = None

def set_node_id(self, node_id: Optional[int]) -> None:
"""Set the node ID."""
self.node_id = node_id

def get_node_id(self) -> int:
"""Get the node ID."""
if self.node_id is None:
raise ValueError("Node ID not set")
return self.node_id
30 changes: 30 additions & 0 deletions src/py/flwr/client/nodestate/nodestate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Abstract base class NodeState."""

import abc
from typing import Optional


class NodeState(abc.ABC):
"""Abstract NodeState."""

@abc.abstractmethod
def set_node_id(self, node_id: Optional[int]) -> None:
"""Set the node ID."""

@abc.abstractmethod
def get_node_id(self) -> int:
"""Get the node ID."""
37 changes: 37 additions & 0 deletions src/py/flwr/client/nodestate/nodestate_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Factory class that creates NodeState instances."""

import threading
from typing import Optional

from .in_memory_nodestate import InMemoryNodeState
from .nodestate import NodeState


class NodeStateFactory:
"""Factory class that creates NodeState instances."""

def __init__(self) -> None:
self.state_instance: Optional[NodeState] = None
self.lock = threading.RLock()

def state(self) -> NodeState:
"""Return a State instance and create it, if necessary."""
# Lock access to NodeStateFactory to prevent returning different instances
with self.lock:
if self.state_instance is None:
self.state_instance = InMemoryNodeState()
return self.state_instance
69 changes: 69 additions & 0 deletions src/py/flwr/client/nodestate/nodestate_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests all NodeState implementations have to conform to."""

import unittest
from abc import abstractmethod

from flwr.client.nodestate import InMemoryNodeState, NodeState


class StateTest(unittest.TestCase):
"""Test all state implementations."""

# This is to True in each child class
__test__ = False

@abstractmethod
def state_factory(self) -> NodeState:
"""Provide state implementation to test."""
raise NotImplementedError()

def test_get_set_node_id(self) -> None:
"""Test set_node_id."""
# Prepare
state: NodeState = self.state_factory()
node_id = 123

# Execute
state.set_node_id(node_id)

retrieved_node_id = state.get_node_id()

# Assert
assert node_id == retrieved_node_id

def test_get_node_id_fails(self) -> None:
"""Test get_node_id fails correctly if node_id is not set."""
# Prepare
state: NodeState = self.state_factory()

# Execute and assert
with self.assertRaises(ValueError):
state.get_node_id()


class InMemoryStateTest(StateTest):
"""Test InMemoryState implementation."""

__test__ = True

def state_factory(self) -> NodeState:
"""Return InMemoryState."""
return InMemoryNodeState()


if __name__ == "__main__":
unittest.main(verbosity=2)
1 change: 0 additions & 1 deletion src/py/flwr/common/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,6 @@ class ErrorCode:
UNKNOWN = 0
LOAD_CLIENT_APP_EXCEPTION = 1
CLIENT_APP_RAISED_EXCEPTION = 2
NODE_UNAVAILABLE = 3

def __new__(cls) -> ErrorCode:
"""Prevent instantiation."""
Expand Down
16 changes: 0 additions & 16 deletions src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
generate_rand_int_from_bytes,
has_valid_sub_status,
is_valid_transition,
make_node_unavailable_taskres,
)


Expand Down Expand Up @@ -257,21 +256,6 @@ def get_task_res(self, task_ids: set[UUID]) -> list[TaskRes]:
task_res_list.append(task_res)
replied_task_ids.add(reply_to)

# Check if the node is offline
for task_id in task_ids - replied_task_ids:
task_ins = self.task_ins_store.get(task_id)
if task_ins is None:
continue
node_id = task_ins.task.consumer.node_id
online_until, _ = self.node_ids[node_id]
# Generate a TaskRes containing an error reply if the node is offline.
if online_until < time.time():
err_taskres = make_node_unavailable_taskres(
ref_taskins=task_ins,
)
self.task_res_store[UUID(err_taskres.task_id)] = err_taskres
task_res_list.append(err_taskres)

# Mark all of them as delivered
delivered_at = now().isoformat()
for task_res in task_res_list:
Expand Down
44 changes: 1 addition & 43 deletions src/py/flwr/server/superlink/linkstate/linkstate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from uuid import UUID

from flwr.common import DEFAULT_TTL, ConfigsRecord, Context, RecordSet, now
from flwr.common.constant import ErrorCode, Status, SubStatus
from flwr.common.constant import Status, SubStatus
from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
generate_key_pairs,
private_key_to_bytes,
Expand Down Expand Up @@ -786,48 +786,6 @@ def test_acknowledge_ping(self) -> None:
# Assert
self.assertSetEqual(actual_node_ids, set(node_ids[70:]))

def test_node_unavailable_error(self) -> None:
"""Test if get_task_res return TaskRes containing node unavailable error."""
# Prepare
state: LinkState = self.state_factory()
run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord())
node_id_0 = state.create_node(ping_interval=90)
node_id_1 = state.create_node(ping_interval=30)
# Create and store TaskIns
task_ins_0 = create_task_ins(
consumer_node_id=node_id_0, anonymous=False, run_id=run_id
)
task_ins_1 = create_task_ins(
consumer_node_id=node_id_1, anonymous=False, run_id=run_id
)
task_id_0 = state.store_task_ins(task_ins=task_ins_0)
task_id_1 = state.store_task_ins(task_ins=task_ins_1)
assert task_id_0 is not None and task_id_1 is not None

# Get TaskIns to mark them delivered
state.get_task_ins(node_id=node_id_0, limit=None)

# Create and store TaskRes
task_res_0 = create_task_res(
producer_node_id=node_id_0,
anonymous=False,
ancestry=[str(task_id_0)],
run_id=run_id,
)
state.store_task_res(task_res_0)

# Execute
current_time = time.time()
task_res_list: list[TaskRes] = []
with patch("time.time", side_effect=lambda: current_time + 50):
task_res_list = state.get_task_res({task_id_0, task_id_1})

# Assert
assert len(task_res_list) == 2
err_taskres = task_res_list[1]
assert err_taskres.task.HasField("error")
assert err_taskres.task.error.code == ErrorCode.NODE_UNAVAILABLE

def test_store_task_res_task_ins_expired(self) -> None:
"""Test behavior of store_task_res when the TaskIns it references is expired."""
# Prepare
Expand Down
Loading

0 comments on commit 93da4be

Please sign in to comment.