Skip to content

Commit

Permalink
Merge branch 'main' into handle-flower-callable-exception
Browse files Browse the repository at this point in the history
  • Loading branch information
jafermarq authored Mar 27, 2024
2 parents bac30f5 + 3acdf47 commit 0152490
Show file tree
Hide file tree
Showing 39 changed files with 241 additions and 109 deletions.
6 changes: 3 additions & 3 deletions examples/app-pytorch/client_low_level.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,16 @@ def hello_world_mod(msg, ctx, call_next) -> Message:
@app.train()
def train(msg: Message, ctx: Context):
print("`train` is not implemented, echoing original message")
return msg.create_reply(msg.content, ttl="")
return msg.create_reply(msg.content)


@app.evaluate()
def eval(msg: Message, ctx: Context):
print("`evaluate` is not implemented, echoing original message")
return msg.create_reply(msg.content, ttl="")
return msg.create_reply(msg.content)


@app.query()
def query(msg: Message, ctx: Context):
print("`query` is not implemented, echoing original message")
return msg.create_reply(msg.content, ttl="")
return msg.create_reply(msg.content)
3 changes: 2 additions & 1 deletion examples/app-pytorch/server_custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
Message,
MessageType,
Metrics,
DEFAULT_TTL,
)
from flwr.common.recordset_compat import fitins_to_recordset, recordset_to_fitres
from flwr.server import Driver, History
Expand Down Expand Up @@ -89,7 +90,7 @@ def main(driver: Driver, context: Context) -> None:
message_type=MessageType.TRAIN,
dst_node_id=node_id,
group_id=str(server_round),
ttl="",
ttl=DEFAULT_TTL,
)
messages.append(message)

Expand Down
12 changes: 10 additions & 2 deletions examples/app-pytorch/server_low_level.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,15 @@
import time

import flwr as fl
from flwr.common import Context, NDArrays, Message, MessageType, Metrics, RecordSet
from flwr.common import (
Context,
NDArrays,
Message,
MessageType,
Metrics,
RecordSet,
DEFAULT_TTL,
)
from flwr.server import Driver


Expand All @@ -30,7 +38,7 @@ def main(driver: Driver, context: Context) -> None:
message_type=MessageType.TRAIN,
dst_node_id=node_id,
group_id=str(server_round),
ttl="",
ttl=DEFAULT_TTL,
)
messages.append(message)

Expand Down
1 change: 1 addition & 0 deletions examples/llm-flowertune/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ bitsandbytes==0.41.3
scipy==1.11.2
peft==0.4.0
fschat[model_worker,webui]==0.2.35
transformers==4.38.1
5 changes: 5 additions & 0 deletions src/proto/flwr/proto/fleet.proto
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import "flwr/proto/task.proto";
service Fleet {
rpc CreateNode(CreateNodeRequest) returns (CreateNodeResponse) {}
rpc DeleteNode(DeleteNodeRequest) returns (DeleteNodeResponse) {}
rpc Ping(PingRequest) returns (PingResponse) {}

// Retrieve one or more tasks, if possible
//
Expand All @@ -43,6 +44,10 @@ message CreateNodeResponse { Node node = 1; }
message DeleteNodeRequest { Node node = 1; }
message DeleteNodeResponse {}

// Ping messages
message PingRequest { Node node = 1; }
message PingResponse { bool success = 1; }

// PullTaskIns messages
message PullTaskInsRequest {
Node node = 1;
Expand Down
2 changes: 1 addition & 1 deletion src/proto/flwr/proto/task.proto
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ message Task {
Node consumer = 2;
string created_at = 3;
string delivered_at = 4;
string ttl = 5;
double ttl = 5;
repeated string ancestry = 6;
string task_type = 7;
RecordSet recordset = 8;
Expand Down
8 changes: 4 additions & 4 deletions src/py/flwr/client/client_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def train(self) -> Callable[[ClientAppCallable], ClientAppCallable]:
>>> def train(message: Message, context: Context) -> Message:
>>> print("ClientApp training running")
>>> # Create and return an echo reply message
>>> return message.create_reply(content=message.content(), ttl="")
>>> return message.create_reply(content=message.content())
"""

def train_decorator(train_fn: ClientAppCallable) -> ClientAppCallable:
Expand Down Expand Up @@ -143,7 +143,7 @@ def evaluate(self) -> Callable[[ClientAppCallable], ClientAppCallable]:
>>> def evaluate(message: Message, context: Context) -> Message:
>>> print("ClientApp evaluation running")
>>> # Create and return an echo reply message
>>> return message.create_reply(content=message.content(), ttl="")
>>> return message.create_reply(content=message.content())
"""

def evaluate_decorator(evaluate_fn: ClientAppCallable) -> ClientAppCallable:
Expand Down Expand Up @@ -171,7 +171,7 @@ def query(self) -> Callable[[ClientAppCallable], ClientAppCallable]:
>>> def query(message: Message, context: Context) -> Message:
>>> print("ClientApp query running")
>>> # Create and return an echo reply message
>>> return message.create_reply(content=message.content(), ttl="")
>>> return message.create_reply(content=message.content())
"""

def query_decorator(query_fn: ClientAppCallable) -> ClientAppCallable:
Expand Down Expand Up @@ -218,7 +218,7 @@ def _registration_error(fn_name: str) -> ValueError:
>>> print("ClientApp {fn_name} running")
>>> # Create and return an echo reply message
>>> return message.create_reply(
>>> content=message.content(), ttl=""
>>> content=message.content()
>>> )
""",
)
3 changes: 2 additions & 1 deletion src/py/flwr/client/grpc_client/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from typing import Callable, Iterator, Optional, Tuple, Union, cast

from flwr.common import (
DEFAULT_TTL,
GRPC_MAX_MESSAGE_LENGTH,
ConfigsRecord,
Message,
Expand Down Expand Up @@ -180,7 +181,7 @@ def receive() -> Message:
dst_node_id=0,
reply_to_message="",
group_id="",
ttl="",
ttl=DEFAULT_TTL,
message_type=message_type,
),
content=recordset,
Expand Down
6 changes: 3 additions & 3 deletions src/py/flwr/client/grpc_client/connection_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

import grpc

from flwr.common import ConfigsRecord, Message, Metadata, RecordSet
from flwr.common import DEFAULT_TTL, ConfigsRecord, Message, Metadata, RecordSet
from flwr.common import recordset_compat as compat
from flwr.common.constant import MessageTypeLegacy
from flwr.common.retry_invoker import RetryInvoker, exponential
Expand All @@ -50,7 +50,7 @@
dst_node_id=0,
reply_to_message="",
group_id="",
ttl="",
ttl=DEFAULT_TTL,
message_type=MessageTypeLegacy.GET_PROPERTIES,
),
content=compat.getpropertiesres_to_recordset(
Expand All @@ -65,7 +65,7 @@
dst_node_id=0,
reply_to_message="",
group_id="",
ttl="",
ttl=DEFAULT_TTL,
message_type="reconnect",
),
content=RecordSet(configs_records={"config": ConfigsRecord({"reason": 0})}),
Expand Down
4 changes: 2 additions & 2 deletions src/py/flwr/client/message_handler/message_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def handle_control_message(message: Message) -> Tuple[Optional[Message], int]:
reason = cast(int, disconnect_msg.disconnect_res.reason)
recordset = RecordSet()
recordset.configs_records["config"] = ConfigsRecord({"reason": reason})
out_message = message.create_reply(recordset, ttl="")
out_message = message.create_reply(recordset)
# Return TaskRes and sleep duration
return out_message, sleep_duration

Expand Down Expand Up @@ -143,7 +143,7 @@ def handle_legacy_message_from_msgtype(
raise ValueError(f"Invalid message type: {message_type}")

# Return Message
return message.create_reply(out_recordset, ttl="")
return message.create_reply(out_recordset)


def _reconnect(
Expand Down
13 changes: 7 additions & 6 deletions src/py/flwr/client/message_handler/message_handler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from flwr.client import Client
from flwr.client.typing import ClientFn
from flwr.common import (
DEFAULT_TTL,
Code,
Context,
EvaluateIns,
Expand Down Expand Up @@ -131,7 +132,7 @@ def test_client_without_get_properties() -> None:
src_node_id=0,
dst_node_id=1123,
reply_to_message="",
ttl="",
ttl=DEFAULT_TTL,
message_type=MessageTypeLegacy.GET_PROPERTIES,
),
content=recordset,
Expand Down Expand Up @@ -161,7 +162,7 @@ def test_client_without_get_properties() -> None:
src_node_id=1123,
dst_node_id=0,
reply_to_message=message.metadata.message_id,
ttl="",
ttl=DEFAULT_TTL,
message_type=MessageTypeLegacy.GET_PROPERTIES,
),
content=expected_rs,
Expand All @@ -184,7 +185,7 @@ def test_client_with_get_properties() -> None:
src_node_id=0,
dst_node_id=1123,
reply_to_message="",
ttl="",
ttl=DEFAULT_TTL,
message_type=MessageTypeLegacy.GET_PROPERTIES,
),
content=recordset,
Expand Down Expand Up @@ -214,7 +215,7 @@ def test_client_with_get_properties() -> None:
src_node_id=1123,
dst_node_id=0,
reply_to_message=message.metadata.message_id,
ttl="",
ttl=DEFAULT_TTL,
message_type=MessageTypeLegacy.GET_PROPERTIES,
),
content=expected_rs,
Expand All @@ -237,7 +238,7 @@ def setUp(self) -> None:
dst_node_id=20,
reply_to_message="",
group_id="group1",
ttl="60",
ttl=DEFAULT_TTL,
message_type="mock",
)
self.valid_out_metadata = Metadata(
Expand All @@ -247,7 +248,7 @@ def setUp(self) -> None:
dst_node_id=10,
reply_to_message="qwerty",
group_id="group1",
ttl="60",
ttl=DEFAULT_TTL,
message_type="mock",
)
self.common_content = RecordSet()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def secaggplus_mod(

# Return message
out_content.configs_records[RECORD_KEY_CONFIGS] = ConfigsRecord(res, False)
return msg.create_reply(out_content, ttl="")
return msg.create_reply(out_content)


def check_stage(current_stage: str, configs: ConfigsRecord) -> None:
Expand Down
13 changes: 10 additions & 3 deletions src/py/flwr/client/mod/secure_aggregation/secaggplus_mod_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,14 @@
from typing import Callable, Dict, List

from flwr.client.mod import make_ffn
from flwr.common import ConfigsRecord, Context, Message, Metadata, RecordSet
from flwr.common import (
DEFAULT_TTL,
ConfigsRecord,
Context,
Message,
Metadata,
RecordSet,
)
from flwr.common.constant import MessageType
from flwr.common.secure_aggregation.secaggplus_constants import (
RECORD_KEY_CONFIGS,
Expand All @@ -38,7 +45,7 @@ def get_test_handler(
"""."""

def empty_ffn(_msg: Message, _2: Context) -> Message:
return _msg.create_reply(RecordSet(), ttl="")
return _msg.create_reply(RecordSet())

app = make_ffn(empty_ffn, [secaggplus_mod])

Expand All @@ -51,7 +58,7 @@ def func(configs: Dict[str, ConfigsRecordValues]) -> ConfigsRecord:
dst_node_id=123,
reply_to_message="",
group_id="",
ttl="",
ttl=DEFAULT_TTL,
message_type=MessageType.TRAIN,
),
content=RecordSet(
Expand Down
3 changes: 2 additions & 1 deletion src/py/flwr/client/mod/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from flwr.client.typing import ClientAppCallable, Mod
from flwr.common import (
DEFAULT_TTL,
ConfigsRecord,
Context,
Message,
Expand Down Expand Up @@ -84,7 +85,7 @@ def _get_dummy_flower_message() -> Message:
src_node_id=0,
dst_node_id=0,
reply_to_message="",
ttl="",
ttl=DEFAULT_TTL,
message_type="mock",
),
)
Expand Down
2 changes: 2 additions & 0 deletions src/py/flwr/common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from .grpc import GRPC_MAX_MESSAGE_LENGTH
from .logger import configure as configure
from .logger import log as log
from .message import DEFAULT_TTL
from .message import Error as Error
from .message import Message as Message
from .message import Metadata as Metadata
Expand Down Expand Up @@ -87,6 +88,7 @@
"Message",
"MessageType",
"MessageTypeLegacy",
"DEFAULT_TTL",
"Metadata",
"Metrics",
"MetricsAggregationFn",
Expand Down
Loading

0 comments on commit 0152490

Please sign in to comment.