Skip to content

Commit

Permalink
Rename task_* to message_* (#2944)
Browse files Browse the repository at this point in the history
  • Loading branch information
panh99 authored Feb 14, 2024
1 parent cbbb813 commit 314492c
Show file tree
Hide file tree
Showing 14 changed files with 109 additions and 109 deletions.
4 changes: 2 additions & 2 deletions src/py/flwr/client/clientapp.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from typing import List, Optional, cast

from flwr.client.message_handler.message_handler import (
handle_legacy_message_from_tasktype,
handle_legacy_message_from_msgtype,
)
from flwr.client.mod.utils import make_ffn
from flwr.client.typing import ClientFn, Mod
Expand Down Expand Up @@ -63,7 +63,7 @@ def ffn(
message: Message,
context: Context,
) -> Message: # pylint: disable=invalid-name
out_message = handle_legacy_message_from_tasktype(
out_message = handle_legacy_message_from_msgtype(
client_fn=client_fn, message=message, context=context
)
return out_message
Expand Down
40 changes: 20 additions & 20 deletions src/py/flwr/client/grpc_client/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@
from flwr.common import serde
from flwr.common.configsrecord import ConfigsRecord
from flwr.common.constant import (
TASK_TYPE_EVALUATE,
TASK_TYPE_FIT,
TASK_TYPE_GET_PARAMETERS,
TASK_TYPE_GET_PROPERTIES,
MESSAGE_TYPE_EVALUATE,
MESSAGE_TYPE_FIT,
MESSAGE_TYPE_GET_PARAMETERS,
MESSAGE_TYPE_GET_PROPERTIES,
)
from flwr.common.grpc import create_channel
from flwr.common.logger import log
Expand Down Expand Up @@ -133,33 +133,33 @@ def receive() -> Message:

# ServerMessage proto --> *Ins --> RecordSet
field = proto.WhichOneof("msg")
task_type = ""
message_type = ""
if field == "get_properties_ins":
recordset = compat.getpropertiesins_to_recordset(
serde.get_properties_ins_from_proto(proto.get_properties_ins)
)
task_type = TASK_TYPE_GET_PROPERTIES
message_type = MESSAGE_TYPE_GET_PROPERTIES
elif field == "get_parameters_ins":
recordset = compat.getparametersins_to_recordset(
serde.get_parameters_ins_from_proto(proto.get_parameters_ins)
)
task_type = TASK_TYPE_GET_PARAMETERS
message_type = MESSAGE_TYPE_GET_PARAMETERS
elif field == "fit_ins":
recordset = compat.fitins_to_recordset(
serde.fit_ins_from_proto(proto.fit_ins), False
)
task_type = TASK_TYPE_FIT
message_type = MESSAGE_TYPE_FIT
elif field == "evaluate_ins":
recordset = compat.evaluateins_to_recordset(
serde.evaluate_ins_from_proto(proto.evaluate_ins), False
)
task_type = TASK_TYPE_EVALUATE
message_type = MESSAGE_TYPE_EVALUATE
elif field == "reconnect_ins":
recordset = RecordSet()
recordset.set_configs(
"config", ConfigsRecord({"seconds": proto.reconnect_ins.seconds})
)
task_type = "reconnect"
message_type = "reconnect"
else:
raise ValueError(
"Unsupported instruction in ServerMessage, "
Expand All @@ -170,44 +170,44 @@ def receive() -> Message:
return Message(
metadata=Metadata(
run_id=0,
task_id=str(uuid.uuid4()),
message_id=str(uuid.uuid4()),
group_id="",
ttl="",
node_id=0,
task_type=task_type,
message_type=message_type,
),
content=recordset,
)

def send(message: Message) -> None:
# Retrieve RecordSet and task_type
# Retrieve RecordSet and message_type
recordset = message.content
task_type = message.metadata.task_type
message_type = message.metadata.message_type

# RecordSet --> *Res --> *Res proto -> ClientMessage proto
if task_type == TASK_TYPE_GET_PROPERTIES:
if message_type == MESSAGE_TYPE_GET_PROPERTIES:
getpropres = compat.recordset_to_getpropertiesres(recordset)
msg_proto = ClientMessage(
get_properties_res=serde.get_properties_res_to_proto(getpropres)
)
elif task_type == TASK_TYPE_GET_PARAMETERS:
elif message_type == MESSAGE_TYPE_GET_PARAMETERS:
getparamres = compat.recordset_to_getparametersres(recordset, False)
msg_proto = ClientMessage(
get_parameters_res=serde.get_parameters_res_to_proto(getparamres)
)
elif task_type == TASK_TYPE_FIT:
elif message_type == MESSAGE_TYPE_FIT:
fitres = compat.recordset_to_fitres(recordset, False)
msg_proto = ClientMessage(fit_res=serde.fit_res_to_proto(fitres))
elif task_type == TASK_TYPE_EVALUATE:
elif message_type == MESSAGE_TYPE_EVALUATE:
evalres = compat.recordset_to_evaluateres(recordset)
msg_proto = ClientMessage(evaluate_res=serde.evaluate_res_to_proto(evalres))
elif task_type == "reconnect":
elif message_type == "reconnect":
reason = cast(Reason.ValueType, recordset.get_configs("config")["reason"])
msg_proto = ClientMessage(
disconnect_res=ClientMessage.DisconnectRes(reason=reason)
)
else:
raise ValueError(f"Invalid task type: {task_type}")
raise ValueError(f"Invalid task type: {message_type}")

# Send ClientMessage proto
return queue.put(msg_proto, block=False)
Expand Down
12 changes: 6 additions & 6 deletions src/py/flwr/client/grpc_client/connection_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

from flwr.common import recordset_compat as compat
from flwr.common.configsrecord import ConfigsRecord
from flwr.common.constant import TASK_TYPE_GET_PROPERTIES
from flwr.common.constant import MESSAGE_TYPE_GET_PROPERTIES
from flwr.common.message import Message, Metadata
from flwr.common.recordset import RecordSet
from flwr.common.typing import Code, GetPropertiesRes, Status
Expand All @@ -46,11 +46,11 @@
MESSAGE_GET_PROPERTIES = Message(
metadata=Metadata(
run_id=0,
task_id="",
message_id="",
group_id="",
node_id=0,
ttl="",
task_type=TASK_TYPE_GET_PROPERTIES,
message_type=MESSAGE_TYPE_GET_PROPERTIES,
),
content=compat.getpropertiesres_to_recordset(
GetPropertiesRes(Status(Code.OK, ""), {})
Expand All @@ -59,11 +59,11 @@
MESSAGE_DISCONNECT = Message(
metadata=Metadata(
run_id=0,
task_id="",
message_id="",
group_id="",
node_id=0,
ttl="",
task_type="reconnect",
message_type="reconnect",
),
content=RecordSet(configs={"config": ConfigsRecord({"reason": 0})}),
)
Expand Down Expand Up @@ -134,7 +134,7 @@ def run_client() -> int:
message = receive()

messages_received += 1
if message.metadata.task_type == "reconnect": # type: ignore
if message.metadata.message_type == "reconnect": # type: ignore
send(MESSAGE_DISCONNECT)
break

Expand Down
32 changes: 16 additions & 16 deletions src/py/flwr/client/message_handler/message_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@
from flwr.client.typing import ClientFn
from flwr.common.configsrecord import ConfigsRecord
from flwr.common.constant import (
TASK_TYPE_EVALUATE,
TASK_TYPE_FIT,
TASK_TYPE_GET_PARAMETERS,
TASK_TYPE_GET_PROPERTIES,
MESSAGE_TYPE_EVALUATE,
MESSAGE_TYPE_FIT,
MESSAGE_TYPE_GET_PARAMETERS,
MESSAGE_TYPE_GET_PROPERTIES,
)
from flwr.common.context import Context
from flwr.common.message import Message, Metadata
Expand Down Expand Up @@ -75,7 +75,7 @@ def handle_control_message(message: Message) -> Tuple[Optional[Message], int]:
sleep_duration : int
Number of seconds that the client should disconnect from the server.
"""
if message.metadata.task_type == "reconnect":
if message.metadata.message_type == "reconnect":
# Retrieve ReconnectIns from recordset
recordset = message.content
seconds = cast(int, recordset.get_configs("config")["seconds"])
Expand All @@ -90,11 +90,11 @@ def handle_control_message(message: Message) -> Tuple[Optional[Message], int]:
out_message = Message(
metadata=Metadata(
run_id=0,
task_id="",
message_id="",
group_id="",
node_id=0,
ttl="",
task_type="reconnect",
message_type="reconnect",
),
content=recordset,
)
Expand All @@ -105,25 +105,25 @@ def handle_control_message(message: Message) -> Tuple[Optional[Message], int]:
return None, 0


def handle_legacy_message_from_tasktype(
def handle_legacy_message_from_msgtype(
client_fn: ClientFn, message: Message, context: Context
) -> Message:
"""Handle legacy message in the inner most mod."""
client = client_fn("-1")

client.set_context(context)

task_type = message.metadata.task_type
message_type = message.metadata.message_type

# Handle GetPropertiesIns
if task_type == TASK_TYPE_GET_PROPERTIES:
if message_type == MESSAGE_TYPE_GET_PROPERTIES:
get_properties_res = maybe_call_get_properties(
client=client,
get_properties_ins=recordset_to_getpropertiesins(message.content),
)
out_recordset = getpropertiesres_to_recordset(get_properties_res)
# Handle GetParametersIns
elif task_type == TASK_TYPE_GET_PARAMETERS:
elif message_type == MESSAGE_TYPE_GET_PARAMETERS:
get_parameters_res = maybe_call_get_parameters(
client=client,
get_parameters_ins=recordset_to_getparametersins(message.content),
Expand All @@ -132,31 +132,31 @@ def handle_legacy_message_from_tasktype(
get_parameters_res, keep_input=False
)
# Handle FitIns
elif task_type == TASK_TYPE_FIT:
elif message_type == MESSAGE_TYPE_FIT:
fit_res = maybe_call_fit(
client=client,
fit_ins=recordset_to_fitins(message.content, keep_input=True),
)
out_recordset = fitres_to_recordset(fit_res, keep_input=False)
# Handle EvaluateIns
elif task_type == TASK_TYPE_EVALUATE:
elif message_type == MESSAGE_TYPE_EVALUATE:
evaluate_res = maybe_call_evaluate(
client=client,
evaluate_ins=recordset_to_evaluateins(message.content, keep_input=True),
)
out_recordset = evaluateres_to_recordset(evaluate_res)
else:
raise ValueError(f"Invalid task type: {task_type}")
raise ValueError(f"Invalid task type: {message_type}")

# Return Message
out_message = Message(
metadata=Metadata(
run_id=0,
task_id="",
message_id="",
group_id="",
node_id=0,
ttl="",
task_type=task_type,
message_type=message_type,
),
content=out_recordset,
)
Expand Down
20 changes: 10 additions & 10 deletions src/py/flwr/client/message_handler/message_handler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,12 @@
)
from flwr.common import recordset_compat as compat
from flwr.common import typing
from flwr.common.constant import TASK_TYPE_GET_PROPERTIES
from flwr.common.constant import MESSAGE_TYPE_GET_PROPERTIES
from flwr.common.context import Context
from flwr.common.message import Message, Metadata
from flwr.common.recordset import RecordSet

from .message_handler import handle_legacy_message_from_tasktype
from .message_handler import handle_legacy_message_from_msgtype


class ClientWithoutProps(Client):
Expand Down Expand Up @@ -122,17 +122,17 @@ def test_client_without_get_properties() -> None:
message = Message(
metadata=Metadata(
run_id=0,
task_id=str(uuid.uuid4()),
message_id=str(uuid.uuid4()),
group_id="",
node_id=0,
ttl="",
task_type=TASK_TYPE_GET_PROPERTIES,
message_type=MESSAGE_TYPE_GET_PROPERTIES,
),
content=recordset,
)

# Execute
actual_msg = handle_legacy_message_from_tasktype(
actual_msg = handle_legacy_message_from_msgtype(
client_fn=_get_client_fn(client),
message=message,
context=Context(state=RecordSet()),
Expand All @@ -150,7 +150,7 @@ def test_client_without_get_properties() -> None:
expected_msg = Message(message.metadata, expected_rs)

assert actual_msg.content == expected_msg.content
assert actual_msg.metadata.task_type == expected_msg.metadata.task_type
assert actual_msg.metadata.message_type == expected_msg.metadata.message_type


def test_client_with_get_properties() -> None:
Expand All @@ -161,17 +161,17 @@ def test_client_with_get_properties() -> None:
message = Message(
metadata=Metadata(
run_id=0,
task_id=str(uuid.uuid4()),
message_id=str(uuid.uuid4()),
group_id="",
node_id=0,
ttl="",
task_type=TASK_TYPE_GET_PROPERTIES,
message_type=MESSAGE_TYPE_GET_PROPERTIES,
),
content=recordset,
)

# Execute
actual_msg = handle_legacy_message_from_tasktype(
actual_msg = handle_legacy_message_from_msgtype(
client_fn=_get_client_fn(client),
message=message,
context=Context(state=RecordSet()),
Expand All @@ -189,4 +189,4 @@ def test_client_with_get_properties() -> None:
expected_msg = Message(message.metadata, expected_rs)

assert actual_msg.content == expected_msg.content
assert actual_msg.metadata.task_type == expected_msg.metadata.task_type
assert actual_msg.metadata.message_type == expected_msg.metadata.message_type
8 changes: 4 additions & 4 deletions src/py/flwr/client/mod/secure_aggregation/secaggplus_mod.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from flwr.common import ndarray_to_bytes, parameters_to_ndarrays
from flwr.common import recordset_compat as compat
from flwr.common.configsrecord import ConfigsRecord
from flwr.common.constant import TASK_TYPE_FIT
from flwr.common.constant import MESSAGE_TYPE_FIT
from flwr.common.context import Context
from flwr.common.logger import log
from flwr.common.message import Message, Metadata
Expand Down Expand Up @@ -168,7 +168,7 @@ def secaggplus_mod(
) -> Message:
"""Handle incoming message and return results, following the SecAgg+ protocol."""
# Ignore non-fit messages
if msg.metadata.task_type != TASK_TYPE_FIT:
if msg.metadata.message_type != MESSAGE_TYPE_FIT:
return call_next(msg, ctxt)

# Retrieve local state
Expand Down Expand Up @@ -209,11 +209,11 @@ def secaggplus_mod(
return Message(
metadata=Metadata(
run_id=0,
task_id="",
message_id="",
group_id="",
node_id=0,
ttl="",
task_type=TASK_TYPE_FIT,
message_type=MESSAGE_TYPE_FIT,
),
content=RecordSet(configs={RECORD_KEY_CONFIGS: ConfigsRecord(res, False)}),
)
Expand Down
Loading

0 comments on commit 314492c

Please sign in to comment.