Skip to content

Commit

Permalink
extend Run
Browse files Browse the repository at this point in the history
  • Loading branch information
panh99 committed Nov 13, 2024
1 parent 8c82df5 commit e0d82c4
Show file tree
Hide file tree
Showing 16 changed files with 164 additions and 125 deletions.
5 changes: 5 additions & 0 deletions src/proto/flwr/proto/run.proto
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ message Run {
string fab_version = 3;
map<string, Scalar> override_config = 4;
string fab_hash = 5;
string pending_at = 6;
string starting_at = 7;
string running_at = 8;
string finished_at = 9;
RunStatus status = 10;
}

message RunStatus {
Expand Down
2 changes: 1 addition & 1 deletion src/py/flwr/client/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,7 @@ def _on_backoff(retry_state: RetryState) -> None:
runs[run_id] = get_run(run_id)
# If get_run is None, i.e., in grpc-bidi mode
else:
runs[run_id] = Run(run_id, "", "", "", {})
runs[run_id] = Run.create_empty(run_id=run_id)

run: Run = runs[run_id]
if get_fab is not None and run.fab_hash:
Expand Down
5 changes: 5 additions & 0 deletions src/py/flwr/client/clientapp/clientappio_servicer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,11 @@ def test_set_inputs(self) -> None:
fab_version="ipsum",
fab_hash="dolor",
override_config=self.maker.user_config(),
pending_at="2021-01-01T00:00:00Z",
starting_at="",
running_at="",
finished_at="",
status=typing.RunStatus(status="pending", sub_status="", details=""),
)
fab = typing.Fab(
hash_str="abc123#$%",
Expand Down
14 changes: 2 additions & 12 deletions src/py/flwr/client/grpc_rere_client/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,7 @@
from flwr.common.logger import log
from flwr.common.message import Message, Metadata
from flwr.common.retry_invoker import RetryInvoker
from flwr.common.serde import (
message_from_taskins,
message_to_taskres,
user_config_from_proto,
)
from flwr.common.serde import message_from_taskins, message_to_taskres, run_from_proto
from flwr.common.typing import Fab, Run
from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
Expand Down Expand Up @@ -287,13 +283,7 @@ def get_run(run_id: int) -> Run:
)

# Return fab_id and fab_version
return Run(
run_id,
get_run_response.run.fab_id,
get_run_response.run.fab_version,
get_run_response.run.fab_hash,
user_config_from_proto(get_run_response.run.override_config),
)
return run_from_proto(get_run_response.run)

def get_fab(fab_hash: str) -> Fab:
# Call FleetAPI
Expand Down
18 changes: 4 additions & 14 deletions src/py/flwr/client/rest_client/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,7 @@
from flwr.common.logger import log
from flwr.common.message import Message, Metadata
from flwr.common.retry_invoker import RetryInvoker
from flwr.common.serde import (
message_from_taskins,
message_to_taskres,
user_config_from_proto,
)
from flwr.common.serde import message_from_taskins, message_to_taskres, run_from_proto
from flwr.common.typing import Fab, Run
from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
Expand Down Expand Up @@ -361,15 +357,9 @@ def get_run(run_id: int) -> Run:
# Send the request
res = _request(req, GetRunResponse, PATH_GET_RUN)
if res is None:
return Run(run_id, "", "", "", {})

return Run(
run_id,
res.run.fab_id,
res.run.fab_version,
res.run.fab_hash,
user_config_from_proto(res.run.override_config),
)
return Run.create_empty(run_id)

return run_from_proto(res.run)

def get_fab(fab_hash: str) -> Fab:
# Construct the request
Expand Down
10 changes: 10 additions & 0 deletions src/py/flwr/common/serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -872,6 +872,11 @@ def run_to_proto(run: typing.Run) -> ProtoRun:
fab_version=run.fab_version,
fab_hash=run.fab_hash,
override_config=user_config_to_proto(run.override_config),
pending_at=run.pending_at,
starting_at=run.starting_at,
running_at=run.running_at,
finished_at=run.finished_at,
status=run_status_to_proto(run.status),
)
return proto

Expand All @@ -884,6 +889,11 @@ def run_from_proto(run_proto: ProtoRun) -> typing.Run:
fab_version=run_proto.fab_version,
fab_hash=run_proto.fab_hash,
override_config=user_config_from_proto(run_proto.override_config),
pending_at=run_proto.pending_at,
starting_at=run_proto.starting_at,
running_at=run_proto.running_at,
finished_at=run_proto.finished_at,
status=run_status_from_proto(run_proto.status),
)
return run

Expand Down
5 changes: 5 additions & 0 deletions src/py/flwr/common/serde_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,11 @@ def test_run_serialization_deserialization() -> None:
fab_version="ipsum",
fab_hash="hash",
override_config=maker.user_config(),
pending_at="2021-01-01T00:00:00Z",
starting_at="2021-01-02T23:02:11Z",
running_at="2021-01-03T12:00:50Z",
finished_at="",
status=typing.RunStatus(status="running", sub_status="", details="OK"),
)

# Execute
Expand Down
41 changes: 31 additions & 10 deletions src/py/flwr/common/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,23 +208,44 @@ class ClientMessage:


@dataclass
class Run:
class RunStatus:
"""Run status information."""

status: str
sub_status: str
details: str


@dataclass
class Run: # pylint: disable=too-many-instance-attributes
"""Run details."""

run_id: int
fab_id: str
fab_version: str
fab_hash: str
override_config: UserConfig


@dataclass
class RunStatus:
"""Run status information."""

status: str
sub_status: str
details: str
pending_at: str
starting_at: str
running_at: str
finished_at: str
status: RunStatus

@classmethod
def create_empty(cls, run_id: int = 0) -> "Run":
"""Return an empty Run instance."""
return cls(
run_id=run_id,
fab_id="",
fab_version="",
fab_hash="",
override_config={},
pending_at="",
starting_at="",
running_at="",
finished_at="",
status=RunStatus(status="", sub_status="", details=""),
)


@dataclass
Expand Down
60 changes: 30 additions & 30 deletions src/py/flwr/proto/run_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

19 changes: 18 additions & 1 deletion src/py/flwr/proto/run_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -37,21 +37,38 @@ class Run(google.protobuf.message.Message):
FAB_VERSION_FIELD_NUMBER: builtins.int
OVERRIDE_CONFIG_FIELD_NUMBER: builtins.int
FAB_HASH_FIELD_NUMBER: builtins.int
PENDING_AT_FIELD_NUMBER: builtins.int
STARTING_AT_FIELD_NUMBER: builtins.int
RUNNING_AT_FIELD_NUMBER: builtins.int
FINISHED_AT_FIELD_NUMBER: builtins.int
STATUS_FIELD_NUMBER: builtins.int
run_id: builtins.int
fab_id: typing.Text
fab_version: typing.Text
@property
def override_config(self) -> google.protobuf.internal.containers.MessageMap[typing.Text, flwr.proto.transport_pb2.Scalar]: ...
fab_hash: typing.Text
pending_at: typing.Text
starting_at: typing.Text
running_at: typing.Text
finished_at: typing.Text
@property
def status(self) -> global___RunStatus: ...
def __init__(self,
*,
run_id: builtins.int = ...,
fab_id: typing.Text = ...,
fab_version: typing.Text = ...,
override_config: typing.Optional[typing.Mapping[typing.Text, flwr.proto.transport_pb2.Scalar]] = ...,
fab_hash: typing.Text = ...,
pending_at: typing.Text = ...,
starting_at: typing.Text = ...,
running_at: typing.Text = ...,
finished_at: typing.Text = ...,
status: typing.Optional[global___RunStatus] = ...,
) -> None: ...
def ClearField(self, field_name: typing_extensions.Literal["fab_hash",b"fab_hash","fab_id",b"fab_id","fab_version",b"fab_version","override_config",b"override_config","run_id",b"run_id"]) -> None: ...
def HasField(self, field_name: typing_extensions.Literal["status",b"status"]) -> builtins.bool: ...
def ClearField(self, field_name: typing_extensions.Literal["fab_hash",b"fab_hash","fab_id",b"fab_id","fab_version",b"fab_version","finished_at",b"finished_at","override_config",b"override_config","pending_at",b"pending_at","run_id",b"run_id","running_at",b"running_at","starting_at",b"starting_at","status",b"status"]) -> None: ...
global___Run = Run

class RunStatus(google.protobuf.message.Message):
Expand Down
14 changes: 2 additions & 12 deletions src/py/flwr/server/driver/grpc_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,7 @@
from flwr.common.constant import SERVERAPPIO_API_DEFAULT_ADDRESS
from flwr.common.grpc import create_channel
from flwr.common.logger import log
from flwr.common.serde import (
message_from_taskres,
message_to_taskins,
user_config_from_proto,
)
from flwr.common.serde import message_from_taskres, message_to_taskins, run_from_proto
from flwr.common.typing import Run
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
Expand Down Expand Up @@ -119,13 +115,7 @@ def set_run(self, run_id: int) -> None:
res: GetRunResponse = self._stub.GetRun(req)
if not res.HasField("run"):
raise RuntimeError(f"Cannot find the run with ID: {run_id}")
self._run = Run(
run_id=res.run.run_id,
fab_id=res.run.fab_id,
fab_version=res.run.fab_version,
fab_hash=res.run.fab_hash,
override_config=user_config_from_proto(res.run.override_config),
)
self._run = run_from_proto(res.run)

@property
def run(self) -> Run:
Expand Down
11 changes: 8 additions & 3 deletions src/py/flwr/server/driver/inmemory_driver_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,16 @@
from unittest.mock import MagicMock, patch
from uuid import uuid4

from flwr.common import ConfigsRecord, RecordSet
from flwr.common.constant import NODE_ID_NUM_BYTES, PING_MAX_INTERVAL
from flwr.common import ConfigsRecord, RecordSet, now
from flwr.common.constant import NODE_ID_NUM_BYTES, PING_MAX_INTERVAL, Status
from flwr.common.message import Error
from flwr.common.serde import (
error_to_proto,
message_from_taskins,
message_to_taskres,
recordset_to_proto,
)
from flwr.common.typing import Run
from flwr.common.typing import Run, RunStatus
from flwr.proto.task_pb2 import Task, TaskRes # pylint: disable=E0611
from flwr.server.superlink.linkstate import (
InMemoryLinkState,
Expand Down Expand Up @@ -94,6 +94,11 @@ def setUp(self) -> None:
fab_version="v1.0.0",
fab_hash="9f86d08",
override_config={"test_key": "test_value"},
pending_at=now().isoformat(),
starting_at="",
running_at="",
finished_at="",
status=RunStatus(status=Status.PENDING, sub_status="", details=""),
)
state_factory = MagicMock(state=lambda: self.state)
self.driver = InMemoryDriver(state_factory=state_factory)
Expand Down
17 changes: 11 additions & 6 deletions src/py/flwr/server/superlink/fleet/vce/vce_api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
Metadata,
RecordSet,
Scalar,
now,
)
from flwr.common.constant import Status
from flwr.common.recordset_compat import getpropertiesins_to_recordset
Expand Down Expand Up @@ -122,11 +123,15 @@ def register_messages_into_state(
fab_version="v1.0.0",
fab_hash="hash",
override_config={},
),
RunStatus(
status=Status.PENDING,
sub_status="",
details="",
pending_at=now().isoformat(),
starting_at="",
running_at="",
finished_at="",
status=RunStatus(
status=Status.PENDING,
sub_status="",
details="",
),
),
)
# Artificially add TaskIns to state so they can be processed
Expand Down Expand Up @@ -210,7 +215,7 @@ def start_and_shutdown(
if not app_dir:
app_dir = _autoresolve_app_dir()

run = Run(run_id=1234, fab_id="", fab_version="", fab_hash="", override_config={})
run = Run.create_empty(run_id=1234)

start_vce(
num_supernodes=num_supernodes,
Expand Down
Loading

0 comments on commit e0d82c4

Please sign in to comment.