Skip to content

Commit

Permalink
Merge branch 'main' into heartbeat-client-app
Browse files Browse the repository at this point in the history
  • Loading branch information
panh99 authored Mar 28, 2024
2 parents 6690b47 + 531e0e3 commit 482baaa
Show file tree
Hide file tree
Showing 14 changed files with 344 additions and 117 deletions.
18 changes: 18 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
repos:
- repo: local
hooks:
- id: format-code
name: Format Code
entry: ./dev/format.sh
language: script
# Do not pass the staged file names to the script as arguments:
pass_filenames: false
stages: [commit]

- id: run-tests
name: Run Tests
entry: ./dev/test.sh
language: script
# Do not pass the staged file names to the script as arguments:
pass_filenames: false
stages: [commit]
27 changes: 27 additions & 0 deletions doc/source/contributor-tutorial-get-started-as-a-contributor.rst
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,33 @@ Run Linters and Tests

$ ./dev/test.sh

Add a pre-commit hook
~~~~~~~~~~~~~~~~~~~~~

Developers may integrate a pre-commit hook into their workflow using the `pre-commit <https://pre-commit.com/#install>`_ library. The pre-commit hook is configured to run two scripts: ``./dev/format.sh`` and ``./dev/test.sh``.

There are multiple ways developers can use this:

1. Install the pre-commit hook to your local git directory by simply running:

::
$ pre-commit install

- Each ``git commit`` will trigger the execution of formatting and linting/test scripts.
- If in a hurry, bypass the hook using ``--no-verify`` with the ``git commit`` command.
::
$ git commit --no-verify -m "Add new feature"
2. For developers who prefer not to install the hook permanently, it is possible to execute a one-time check prior to committing changes by using the following command:

::

$ pre-commit run --all-files
This executes the formatting and linting checks/tests on all the files without modifying the default behavior of ``git commit``.

Run GitHub Actions (CI) locally
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Expand Down
33 changes: 21 additions & 12 deletions examples/app-pytorch/server_custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,15 +103,19 @@ def main(driver: Driver, context: Context) -> None:
all_replies: List[Message] = []
while True:
replies = driver.pull_messages(message_ids=message_ids)
print(f"Got {len(replies)} results")
for res in replies:
print(f"Got 1 {'result' if res.has_content() else 'error'}")
all_replies += replies
if len(all_replies) == len(message_ids):
break
print("Pulling messages...")
time.sleep(3)

# Collect correct results
# Filter correct results
all_fitres = [
recordset_to_fitres(msg.content, keep_input=True) for msg in all_replies
recordset_to_fitres(msg.content, keep_input=True)
for msg in all_replies
if msg.has_content()
]
print(f"Received {len(all_fitres)} results")

Expand All @@ -128,16 +132,21 @@ def main(driver: Driver, context: Context) -> None:
)
metrics_results.append((fitres.num_examples, fitres.metrics))

# Aggregate parameters (FedAvg)
parameters_aggregated = ndarrays_to_parameters(aggregate(weights_results))
parameters = parameters_aggregated
if len(weights_results) > 0:
# Aggregate parameters (FedAvg)
parameters_aggregated = ndarrays_to_parameters(aggregate(weights_results))
parameters = parameters_aggregated

# Aggregate metrics
metrics_aggregated = weighted_average(metrics_results)
history.add_metrics_distributed_fit(
server_round=server_round, metrics=metrics_aggregated
)
print("Round ", server_round, " metrics: ", metrics_aggregated)
# Aggregate metrics
metrics_aggregated = weighted_average(metrics_results)
history.add_metrics_distributed_fit(
server_round=server_round, metrics=metrics_aggregated
)
print("Round ", server_round, " metrics: ", metrics_aggregated)
else:
print(
f"Round {server_round} got {len(weights_results)} results. Skipping aggregation..."
)

# Slow down the start of the next round
time.sleep(sleep_time)
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ check-wheel-contents = "==0.4.0"
GitPython = "==3.1.32"
PyGithub = "==2.1.1"
licensecheck = "==2024"
pre-commit = "==3.5.0"

[tool.isort]
line_length = 88
Expand Down
59 changes: 35 additions & 24 deletions src/py/flwr/client/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,10 @@
# ==============================================================================
"""Flower client app."""


import argparse
import sys
import time
from logging import DEBUG, INFO, WARN
from logging import DEBUG, ERROR, INFO, WARN
from pathlib import Path
from typing import Callable, ContextManager, Optional, Tuple, Type, Union

Expand All @@ -38,6 +37,7 @@
)
from flwr.common.exit_handlers import register_exit_handlers
from flwr.common.logger import log, warn_deprecated_feature, warn_experimental_feature
from flwr.common.message import Error
from flwr.common.object_ref import load_app, validate
from flwr.common.retry_invoker import RetryInvoker, exponential

Expand Down Expand Up @@ -482,32 +482,43 @@ def _load_client_app() -> ClientApp:
# Retrieve context for this run
context = node_state.retrieve_context(run_id=message.metadata.run_id)

# Load ClientApp instance
client_app: ClientApp = load_client_app_fn()
# Create an error reply message that will never be used to prevent
# the used-before-assignment linting error
reply_message = message.create_error_reply(
error=Error(code=0, reason="Unknown")
)

# Handle task message
out_message = client_app(message=message, context=context)
# Handle app loading and task message
try:
# Load ClientApp instance
client_app: ClientApp = load_client_app_fn()

# Update node state
node_state.update_context(
run_id=message.metadata.run_id,
context=context,
)
reply_message = client_app(message=message, context=context)
# Update node state
node_state.update_context(
run_id=message.metadata.run_id,
context=context,
)
except Exception as ex: # pylint: disable=broad-exception-caught
log(ERROR, "ClientApp raised an exception", exc_info=ex)

# Legacy grpc-bidi
if transport in ["grpc-bidi", None]:
# Raise exception, crash process
raise ex

# Don't update/change NodeState

# Create error message
# Reason example: "<class 'ZeroDivisionError'>:<'division by zero'>"
reason = str(type(ex)) + ":<'" + str(ex) + "'>"
reply_message = message.create_error_reply(
error=Error(code=0, reason=reason)
)

# Send
send(out_message)
log(
INFO,
"[RUN %s, ROUND %s]",
out_message.metadata.run_id,
out_message.metadata.group_id,
)
log(
INFO,
"Sent: %s reply to message %s",
out_message.metadata.message_type,
message.metadata.message_id,
)
send(reply_message)
log(INFO, "Sent reply")

# Unregister node
if delete_node is not None:
Expand Down
4 changes: 2 additions & 2 deletions src/py/flwr/client/message_handler/message_handler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def test_client_without_get_properties() -> None:
src_node_id=1123,
dst_node_id=0,
reply_to_message=message.metadata.message_id,
ttl=DEFAULT_TTL,
ttl=actual_msg.metadata.ttl, # computed based on [message].create_reply()
message_type=MessageTypeLegacy.GET_PROPERTIES,
),
content=expected_rs,
Expand Down Expand Up @@ -227,7 +227,7 @@ def test_client_with_get_properties() -> None:
src_node_id=1123,
dst_node_id=0,
reply_to_message=message.metadata.message_id,
ttl=DEFAULT_TTL,
ttl=actual_msg.metadata.ttl, # computed based on [message].create_reply()
message_type=MessageTypeLegacy.GET_PROPERTIES,
),
content=expected_rs,
Expand Down
52 changes: 38 additions & 14 deletions src/py/flwr/common/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,22 +297,33 @@ def _create_reply_metadata(self, ttl: float) -> Metadata:
partition_id=self.metadata.partition_id,
)

def create_error_reply(
self,
error: Error,
ttl: float,
) -> Message:
def create_error_reply(self, error: Error, ttl: float | None = None) -> Message:
"""Construct a reply message indicating an error happened.
Parameters
----------
error : Error
The error that was encountered.
ttl : float
Time-to-live for this message in seconds.
ttl : Optional[float] (default: None)
Time-to-live for this message in seconds. If unset, it will be set based
on the remaining time for the received message before it expires. This
follows the equation:
ttl = msg.meta.ttl - (reply.meta.created_at - msg.meta.created_at)
"""
# If no TTL passed, use default for message creation (will update after
# message creation)
ttl_ = DEFAULT_TTL if ttl is None else ttl
# Create reply with error
message = Message(metadata=self._create_reply_metadata(ttl), error=error)
message = Message(metadata=self._create_reply_metadata(ttl_), error=error)

if ttl is None:
# Set TTL equal to the remaining time for the received message to expire
ttl = self.metadata.ttl - (
message.metadata.created_at - self.metadata.created_at
)
message.metadata.ttl = ttl

return message

def create_reply(self, content: RecordSet, ttl: float | None = None) -> Message:
Expand All @@ -327,18 +338,31 @@ def create_reply(self, content: RecordSet, ttl: float | None = None) -> Message:
content : RecordSet
The content for the reply message.
ttl : Optional[float] (default: None)
Time-to-live for this message in seconds. If unset, it will use
the `common.DEFAULT_TTL` value.
Time-to-live for this message in seconds. If unset, it will be set based
on the remaining time for the received message before it expires. This
follows the equation:
ttl = msg.meta.ttl - (reply.meta.created_at - msg.meta.created_at)
Returns
-------
Message
A new `Message` instance representing the reply.
"""
if ttl is None:
ttl = DEFAULT_TTL
# If no TTL passed, use default for message creation (will update after
# message creation)
ttl_ = DEFAULT_TTL if ttl is None else ttl

return Message(
metadata=self._create_reply_metadata(ttl),
message = Message(
metadata=self._create_reply_metadata(ttl_),
content=content,
)

if ttl is None:
# Set TTL equal to the remaining time for the received message to expire
ttl = self.metadata.ttl - (
message.metadata.created_at - self.metadata.created_at
)
message.metadata.ttl = ttl

return message
52 changes: 49 additions & 3 deletions src/py/flwr/common/message_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import time
from contextlib import ExitStack
from typing import Any, Callable
from typing import Any, Callable, Optional

import pytest

Expand Down Expand Up @@ -73,17 +73,21 @@ def test_message_creation(
assert message.metadata.created_at < time.time()


def create_message_with_content() -> Message:
def create_message_with_content(ttl: Optional[float] = None) -> Message:
    """Create a Message with content.

    Parameters
    ----------
    ttl : Optional[float] (default: None)
        If given, overrides the TTL of the metadata produced by `RecordMaker`.
        If `None`, the generated TTL is left untouched.

    Returns
    -------
    Message
        A message carrying an empty `RecordSet` as content.
    """
    maker = RecordMaker(state=2)
    metadata = maker.metadata()
    # Compare against None explicitly so an explicit TTL of 0 is not ignored
    if ttl is not None:
        metadata.ttl = ttl
    return Message(metadata=metadata, content=RecordSet())


def create_message_with_error() -> Message:
def create_message_with_error(ttl: Optional[float] = None) -> Message:
    """Create a Message with error.

    Parameters
    ----------
    ttl : Optional[float] (default: None)
        If given, overrides the TTL of the metadata produced by `RecordMaker`.
        If `None`, the generated TTL is left untouched.

    Returns
    -------
    Message
        A message carrying an `Error` with code 1 instead of content.
    """
    maker = RecordMaker(state=2)
    metadata = maker.metadata()
    # Compare against None explicitly so an explicit TTL of 0 is not ignored
    if ttl is not None:
        metadata.ttl = ttl
    return Message(metadata=metadata, error=Error(code=1))


Expand Down Expand Up @@ -111,3 +115,45 @@ def test_altering_message(
message.error = Error(code=123)
if message.has_error():
message.content = RecordSet()


@pytest.mark.parametrize(
    "message_creation_fn,ttl,reply_ttl",
    [
        (create_message_with_content, 1e6, None),
        (create_message_with_error, 1e6, None),
        (create_message_with_content, 1e6, 3600),
        (create_message_with_error, 1e6, 3600),
    ],
)
def test_create_reply(
    message_creation_fn: Callable[[Optional[float]], Message],
    ttl: float,
    reply_ttl: Optional[float],
) -> None:
    """Test reply creation from message.

    Covers both reply flavours (`create_reply` and `create_error_reply`), each
    with an explicit reply TTL and with the TTL left unset.
    """
    message: Message = message_creation_fn(ttl)

    # Let some time pass so the reply gets a strictly later `created_at`
    time.sleep(0.1)

    if message.has_error():
        dummy_error = Error(code=0, reason="it crashed")
        reply_message = message.create_error_reply(dummy_error, ttl=reply_ttl)
    else:
        reply_message = message.create_reply(content=RecordSet(), ttl=reply_ttl)

    # Ensure reply has a higher timestamp
    assert message.metadata.created_at < reply_message.metadata.created_at
    if reply_ttl is not None:
        # Ensure the TTL is the one specified upon reply creation
        assert reply_message.metadata.ttl == reply_ttl
    else:
        # Ensure reply ttl is lower (since it uses remaining time left)
        assert message.metadata.ttl > reply_message.metadata.ttl

    # The reply must travel back along the reversed route and reference the
    # message it answers
    assert message.metadata.src_node_id == reply_message.metadata.dst_node_id
    assert message.metadata.dst_node_id == reply_message.metadata.src_node_id
    assert reply_message.metadata.reply_to_message == message.metadata.message_id
Loading

0 comments on commit 482baaa

Please sign in to comment.