From 02e181382df4b8f880bf92f92aafb357de50f161 Mon Sep 17 00:00:00 2001 From: Chong Shen Ng Date: Thu, 19 Sep 2024 18:17:12 +0100 Subject: [PATCH 1/4] fix(framework:skip) Fix SuperExec log capturing (#4242) --- src/py/flwr/superexec/deployment.py | 2 ++ src/py/flwr/superexec/exec_servicer.py | 11 +++++++++-- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/src/py/flwr/superexec/deployment.py b/src/py/flwr/superexec/deployment.py index d60870995d39..331fd817228e 100644 --- a/src/py/flwr/superexec/deployment.py +++ b/src/py/flwr/superexec/deployment.py @@ -167,6 +167,8 @@ def start_run( # Execute the command proc = subprocess.Popen( # pylint: disable=consider-using-with command, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, text=True, ) log(INFO, "Started run %s", str(run_id)) diff --git a/src/py/flwr/superexec/exec_servicer.py b/src/py/flwr/superexec/exec_servicer.py index 8bf384312fca..ebb12b5ddbd2 100644 --- a/src/py/flwr/superexec/exec_servicer.py +++ b/src/py/flwr/superexec/exec_servicer.py @@ -16,6 +16,7 @@ import select +import sys import threading import time from collections.abc import Generator @@ -95,7 +96,8 @@ def StreamLogs( # pylint: disable=C0103 # is returned at this point and the server ends the stream. if self.runs[request.run_id].proc.poll() is not None: log(INFO, "All logs for run ID `%s` returned", request.run_id) - return + context.set_code(grpc.StatusCode.OK) + context.cancel() time.sleep(1.0) # Sleep briefly to avoid busy waiting @@ -115,7 +117,12 @@ def _capture_logs( ) # Read from std* and append to RunTracker.logs for stream in ready_to_read: - line = stream.readline().rstrip() + # Flush stdout to view output in real time + readline = stream.readline() + sys.stdout.write(readline) + sys.stdout.flush() + # Append to logs + line = readline.rstrip() if line: run.logs.append(f"{line}") From 247cada644c059df4575063d2ed3010a0c6a753b Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Fri, 20 Sep 2024 16:32:04 +0100 Subject: [PATCH 2/4] fix(examples:skip) Specify `min_fit_clients` in `FedAvg` in the Secure Aggregation example (#4247) --- examples/flower-secure-aggregation/secaggexample/server_app.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/flower-secure-aggregation/secaggexample/server_app.py b/examples/flower-secure-aggregation/secaggexample/server_app.py index 0f1b594317fa..0b95d68e4183 100644 --- a/examples/flower-secure-aggregation/secaggexample/server_app.py +++ b/examples/flower-secure-aggregation/secaggexample/server_app.py @@ -40,6 +40,7 @@ def main(driver: Driver, context: Context) -> None: strategy = FedAvg( # Select all available clients fraction_fit=1.0, + min_fit_clients=5, # Disable evaluation in demo fraction_evaluate=(0.0 if is_demo else context.run_config["fraction-evaluate"]), min_available_clients=5, From b6babe95addcbd1380f3894cc3ea71772119b72c Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Mon, 23 Sep 2024 11:37:28 +0100 Subject: [PATCH 3/4] feat(framework) Add new RPCs to `Control` service (#4241) --- src/proto/flwr/proto/control.proto | 7 +++ src/proto/flwr/proto/run.proto | 20 ++++++ src/py/flwr/proto/control_pb2.py | 6 +- src/py/flwr/proto/control_pb2_grpc.py | 68 ++++++++++++++++++++ src/py/flwr/proto/control_pb2_grpc.pyi | 26 ++++++++ src/py/flwr/proto/run_pb2.py | 32 +++++++--- src/py/flwr/proto/run_pb2.pyi | 86 ++++++++++++++++++++++++++ 7 files changed, 233 insertions(+), 12 deletions(-) diff --git a/src/proto/flwr/proto/control.proto b/src/proto/flwr/proto/control.proto index 9747c2e3ea11..8b75c66fccaa 100644 --- a/src/proto/flwr/proto/control.proto +++ b/src/proto/flwr/proto/control.proto @@ -22,4 +22,11 @@ import "flwr/proto/run.proto"; service Control { // Request to create a new run rpc CreateRun(CreateRunRequest) returns (CreateRunResponse) {} + + // Get the status of a given run + rpc GetRunStatus(GetRunStatusRequest) returns (GetRunStatusResponse) {} + + // Update the status of a given run + rpc UpdateRunStatus(UpdateRunStatusRequest) + returns (UpdateRunStatusResponse) {} } diff --git a/src/proto/flwr/proto/run.proto b/src/proto/flwr/proto/run.proto index ada72610e182..2c9bd877f66c 100644 --- a/src/proto/flwr/proto/run.proto +++ b/src/proto/flwr/proto/run.proto @@ -28,6 +28,15 @@ message Run { string fab_hash = 5; } +message RunStatus { + // "starting", "running", "finished" + string status = 1; + // "completed", "failed", "stopped" or "" (non-finished) + string sub_status = 2; + // failure details + string details = 3; +} + // CreateRun message CreateRunRequest { string fab_id = 1; @@ -40,3 +49,14 @@ message CreateRunResponse { uint64 run_id = 1; } // GetRun message GetRunRequest { uint64 run_id = 1; } message GetRunResponse { Run run = 1; } + +// UpdateRunStatus +message UpdateRunStatusRequest { + uint64 run_id = 1; + RunStatus run_status = 2; +} +message UpdateRunStatusResponse {} + +// GetRunStatus +message GetRunStatusRequest { repeated uint64 run_ids = 1; } +message GetRunStatusResponse { map run_status_dict = 1; } diff --git a/src/py/flwr/proto/control_pb2.py b/src/py/flwr/proto/control_pb2.py index 2b8776509d32..eb1c18d8dcff 100644 --- a/src/py/flwr/proto/control_pb2.py +++ b/src/py/flwr/proto/control_pb2.py @@ -15,13 +15,13 @@ from flwr.proto import run_pb2 as flwr_dot_proto_dot_run__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x18\x66lwr/proto/control.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/run.proto2U\n\x07\x43ontrol\x12J\n\tCreateRun\x12\x1c.flwr.proto.CreateRunRequest\x1a\x1d.flwr.proto.CreateRunResponse\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x18\x66lwr/proto/control.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/run.proto2\x88\x02\n\x07\x43ontrol\x12J\n\tCreateRun\x12\x1c.flwr.proto.CreateRunRequest\x1a\x1d.flwr.proto.CreateRunResponse\"\x00\x12S\n\x0cGetRunStatus\x12\x1f.flwr.proto.GetRunStatusRequest\x1a .flwr.proto.GetRunStatusResponse\"\x00\x12\\\n\x0fUpdateRunStatus\x12\".flwr.proto.UpdateRunStatusRequest\x1a#.flwr.proto.UpdateRunStatusResponse\"\x00\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.control_pb2', _globals) if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None - _globals['_CONTROL']._serialized_start=62 - _globals['_CONTROL']._serialized_end=147 + _globals['_CONTROL']._serialized_start=63 + _globals['_CONTROL']._serialized_end=327 # @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/control_pb2_grpc.py b/src/py/flwr/proto/control_pb2_grpc.py index 987b9d8d7433..a59f90f15935 100644 --- a/src/py/flwr/proto/control_pb2_grpc.py +++ b/src/py/flwr/proto/control_pb2_grpc.py @@ -19,6 +19,16 @@ def __init__(self, channel): request_serializer=flwr_dot_proto_dot_run__pb2.CreateRunRequest.SerializeToString, response_deserializer=flwr_dot_proto_dot_run__pb2.CreateRunResponse.FromString, ) + self.GetRunStatus = channel.unary_unary( + '/flwr.proto.Control/GetRunStatus', + request_serializer=flwr_dot_proto_dot_run__pb2.GetRunStatusRequest.SerializeToString, + response_deserializer=flwr_dot_proto_dot_run__pb2.GetRunStatusResponse.FromString, + ) + self.UpdateRunStatus = channel.unary_unary( + '/flwr.proto.Control/UpdateRunStatus', + request_serializer=flwr_dot_proto_dot_run__pb2.UpdateRunStatusRequest.SerializeToString, + response_deserializer=flwr_dot_proto_dot_run__pb2.UpdateRunStatusResponse.FromString, + ) class ControlServicer(object): @@ -31,6 +41,20 @@ def CreateRun(self, request, context): context.set_details('Method not implemented!') raise NotImplementedError('Method not implemented!') + def GetRunStatus(self, request, context): + """Get the status of a given run + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def UpdateRunStatus(self, request, context): + """Update the status of a given run + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + def add_ControlServicer_to_server(servicer, server): rpc_method_handlers = { @@ -39,6 +63,16 @@ def add_ControlServicer_to_server(servicer, server): request_deserializer=flwr_dot_proto_dot_run__pb2.CreateRunRequest.FromString, response_serializer=flwr_dot_proto_dot_run__pb2.CreateRunResponse.SerializeToString, ), + 'GetRunStatus': grpc.unary_unary_rpc_method_handler( + servicer.GetRunStatus, + request_deserializer=flwr_dot_proto_dot_run__pb2.GetRunStatusRequest.FromString, + response_serializer=flwr_dot_proto_dot_run__pb2.GetRunStatusResponse.SerializeToString, + ), + 'UpdateRunStatus': grpc.unary_unary_rpc_method_handler( + servicer.UpdateRunStatus, + request_deserializer=flwr_dot_proto_dot_run__pb2.UpdateRunStatusRequest.FromString, + response_serializer=flwr_dot_proto_dot_run__pb2.UpdateRunStatusResponse.SerializeToString, + ), } generic_handler = grpc.method_handlers_generic_handler( 'flwr.proto.Control', rpc_method_handlers) @@ -65,3 +99,37 @@ def CreateRun(request, flwr_dot_proto_dot_run__pb2.CreateRunResponse.FromString, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def GetRunStatus(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/flwr.proto.Control/GetRunStatus', + flwr_dot_proto_dot_run__pb2.GetRunStatusRequest.SerializeToString, + flwr_dot_proto_dot_run__pb2.GetRunStatusResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def UpdateRunStatus(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/flwr.proto.Control/UpdateRunStatus', + flwr_dot_proto_dot_run__pb2.UpdateRunStatusRequest.SerializeToString, + flwr_dot_proto_dot_run__pb2.UpdateRunStatusResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/src/py/flwr/proto/control_pb2_grpc.pyi b/src/py/flwr/proto/control_pb2_grpc.pyi index c38b7f15d125..7817e2b12e31 100644 --- a/src/py/flwr/proto/control_pb2_grpc.pyi +++ b/src/py/flwr/proto/control_pb2_grpc.pyi @@ -13,6 +13,16 @@ class ControlStub: flwr.proto.run_pb2.CreateRunResponse] """Request to create a new run""" + GetRunStatus: grpc.UnaryUnaryMultiCallable[ + flwr.proto.run_pb2.GetRunStatusRequest, + flwr.proto.run_pb2.GetRunStatusResponse] + """Get the status of a given run""" + + UpdateRunStatus: grpc.UnaryUnaryMultiCallable[ + flwr.proto.run_pb2.UpdateRunStatusRequest, + flwr.proto.run_pb2.UpdateRunStatusResponse] + """Update the status of a given run""" + class ControlServicer(metaclass=abc.ABCMeta): @abc.abstractmethod @@ -23,5 +33,21 @@ class ControlServicer(metaclass=abc.ABCMeta): """Request to create a new run""" pass + @abc.abstractmethod + def GetRunStatus(self, + request: flwr.proto.run_pb2.GetRunStatusRequest, + context: grpc.ServicerContext, + ) -> flwr.proto.run_pb2.GetRunStatusResponse: + """Get the status of a given run""" + pass + + @abc.abstractmethod + def UpdateRunStatus(self, + request: flwr.proto.run_pb2.UpdateRunStatusRequest, + context: grpc.ServicerContext, + ) -> flwr.proto.run_pb2.UpdateRunStatusResponse: + """Update the status of a given run""" + pass + def add_ControlServicer_to_server(servicer: ControlServicer, server: grpc.Server) -> None: ... diff --git a/src/py/flwr/proto/run_pb2.py b/src/py/flwr/proto/run_pb2.py index 99ca4df5c44c..d59cc26fbb48 100644 --- a/src/py/flwr/proto/run_pb2.py +++ b/src/py/flwr/proto/run_pb2.py @@ -16,7 +16,7 @@ from flwr.proto import transport_pb2 as flwr_dot_proto_dot_transport__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x14\x66lwr/proto/run.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/fab.proto\x1a\x1a\x66lwr/proto/transport.proto\"\xd5\x01\n\x03Run\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12\x0e\n\x06\x66\x61\x62_id\x18\x02 \x01(\t\x12\x13\n\x0b\x66\x61\x62_version\x18\x03 \x01(\t\x12<\n\x0foverride_config\x18\x04 \x03(\x0b\x32#.flwr.proto.Run.OverrideConfigEntry\x12\x10\n\x08\x66\x61\x62_hash\x18\x05 \x01(\t\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"\xeb\x01\n\x10\x43reateRunRequest\x12\x0e\n\x06\x66\x61\x62_id\x18\x01 \x01(\t\x12\x13\n\x0b\x66\x61\x62_version\x18\x02 \x01(\t\x12I\n\x0foverride_config\x18\x03 \x03(\x0b\x32\x30.flwr.proto.CreateRunRequest.OverrideConfigEntry\x12\x1c\n\x03\x66\x61\x62\x18\x04 \x01(\x0b\x32\x0f.flwr.proto.Fab\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"#\n\x11\x43reateRunResponse\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"\x1f\n\rGetRunRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\".\n\x0eGetRunResponse\x12\x1c\n\x03run\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Runb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x14\x66lwr/proto/run.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/fab.proto\x1a\x1a\x66lwr/proto/transport.proto\"\xd5\x01\n\x03Run\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12\x0e\n\x06\x66\x61\x62_id\x18\x02 \x01(\t\x12\x13\n\x0b\x66\x61\x62_version\x18\x03 \x01(\t\x12<\n\x0foverride_config\x18\x04 \x03(\x0b\x32#.flwr.proto.Run.OverrideConfigEntry\x12\x10\n\x08\x66\x61\x62_hash\x18\x05 \x01(\t\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"@\n\tRunStatus\x12\x0e\n\x06status\x18\x01 \x01(\t\x12\x12\n\nsub_status\x18\x02 \x01(\t\x12\x0f\n\x07\x64\x65tails\x18\x03 \x01(\t\"\xeb\x01\n\x10\x43reateRunRequest\x12\x0e\n\x06\x66\x61\x62_id\x18\x01 \x01(\t\x12\x13\n\x0b\x66\x61\x62_version\x18\x02 \x01(\t\x12I\n\x0foverride_config\x18\x03 \x03(\x0b\x32\x30.flwr.proto.CreateRunRequest.OverrideConfigEntry\x12\x1c\n\x03\x66\x61\x62\x18\x04 \x01(\x0b\x32\x0f.flwr.proto.Fab\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"#\n\x11\x43reateRunResponse\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"\x1f\n\rGetRunRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\".\n\x0eGetRunResponse\x12\x1c\n\x03run\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Run\"S\n\x16UpdateRunStatusRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12)\n\nrun_status\x18\x02 \x01(\x0b\x32\x15.flwr.proto.RunStatus\"\x19\n\x17UpdateRunStatusResponse\"&\n\x13GetRunStatusRequest\x12\x0f\n\x07run_ids\x18\x01 \x03(\x04\"\xb1\x01\n\x14GetRunStatusResponse\x12L\n\x0frun_status_dict\x18\x01 \x03(\x0b\x32\x33.flwr.proto.GetRunStatusResponse.RunStatusDictEntry\x1aK\n\x12RunStatusDictEntry\x12\x0b\n\x03key\x18\x01 \x01(\x04\x12$\n\x05value\x18\x02 \x01(\x0b\x32\x15.flwr.proto.RunStatus:\x02\x38\x01\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -27,18 +27,32 @@ _globals['_RUN_OVERRIDECONFIGENTRY']._serialized_options = b'8\001' _globals['_CREATERUNREQUEST_OVERRIDECONFIGENTRY']._options = None _globals['_CREATERUNREQUEST_OVERRIDECONFIGENTRY']._serialized_options = b'8\001' + _globals['_GETRUNSTATUSRESPONSE_RUNSTATUSDICTENTRY']._options = None + _globals['_GETRUNSTATUSRESPONSE_RUNSTATUSDICTENTRY']._serialized_options = b'8\001' _globals['_RUN']._serialized_start=87 _globals['_RUN']._serialized_end=300 _globals['_RUN_OVERRIDECONFIGENTRY']._serialized_start=227 _globals['_RUN_OVERRIDECONFIGENTRY']._serialized_end=300 - _globals['_CREATERUNREQUEST']._serialized_start=303 - _globals['_CREATERUNREQUEST']._serialized_end=538 + _globals['_RUNSTATUS']._serialized_start=302 + _globals['_RUNSTATUS']._serialized_end=366 + _globals['_CREATERUNREQUEST']._serialized_start=369 + _globals['_CREATERUNREQUEST']._serialized_end=604 _globals['_CREATERUNREQUEST_OVERRIDECONFIGENTRY']._serialized_start=227 _globals['_CREATERUNREQUEST_OVERRIDECONFIGENTRY']._serialized_end=300 - _globals['_CREATERUNRESPONSE']._serialized_start=540 - _globals['_CREATERUNRESPONSE']._serialized_end=575 - _globals['_GETRUNREQUEST']._serialized_start=577 - _globals['_GETRUNREQUEST']._serialized_end=608 - _globals['_GETRUNRESPONSE']._serialized_start=610 - _globals['_GETRUNRESPONSE']._serialized_end=656 + _globals['_CREATERUNRESPONSE']._serialized_start=606 + _globals['_CREATERUNRESPONSE']._serialized_end=641 + _globals['_GETRUNREQUEST']._serialized_start=643 + _globals['_GETRUNREQUEST']._serialized_end=674 + _globals['_GETRUNRESPONSE']._serialized_start=676 + _globals['_GETRUNRESPONSE']._serialized_end=722 + _globals['_UPDATERUNSTATUSREQUEST']._serialized_start=724 + _globals['_UPDATERUNSTATUSREQUEST']._serialized_end=807 + _globals['_UPDATERUNSTATUSRESPONSE']._serialized_start=809 + _globals['_UPDATERUNSTATUSRESPONSE']._serialized_end=834 + _globals['_GETRUNSTATUSREQUEST']._serialized_start=836 + _globals['_GETRUNSTATUSREQUEST']._serialized_end=874 + _globals['_GETRUNSTATUSRESPONSE']._serialized_start=877 + _globals['_GETRUNSTATUSRESPONSE']._serialized_end=1054 + _globals['_GETRUNSTATUSRESPONSE_RUNSTATUSDICTENTRY']._serialized_start=979 + _globals['_GETRUNSTATUSRESPONSE_RUNSTATUSDICTENTRY']._serialized_end=1054 # @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/run_pb2.pyi b/src/py/flwr/proto/run_pb2.pyi index 26b69e7eed27..cec90c4d2d4c 100644 --- a/src/py/flwr/proto/run_pb2.pyi +++ b/src/py/flwr/proto/run_pb2.pyi @@ -52,6 +52,29 @@ class Run(google.protobuf.message.Message): def ClearField(self, field_name: typing_extensions.Literal["fab_hash",b"fab_hash","fab_id",b"fab_id","fab_version",b"fab_version","override_config",b"override_config","run_id",b"run_id"]) -> None: ... global___Run = Run +class RunStatus(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + STATUS_FIELD_NUMBER: builtins.int + SUB_STATUS_FIELD_NUMBER: builtins.int + DETAILS_FIELD_NUMBER: builtins.int + status: typing.Text + """"starting", "running", "finished" """ + + sub_status: typing.Text + """"completed", "failed", "stopped" or "" (non-finished)""" + + details: typing.Text + """failure details""" + + def __init__(self, + *, + status: typing.Text = ..., + sub_status: typing.Text = ..., + details: typing.Text = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["details",b"details","status",b"status","sub_status",b"sub_status"]) -> None: ... +global___RunStatus = RunStatus + class CreateRunRequest(google.protobuf.message.Message): """CreateRun""" DESCRIPTOR: google.protobuf.descriptor.Descriptor @@ -126,3 +149,66 @@ class GetRunResponse(google.protobuf.message.Message): def HasField(self, field_name: typing_extensions.Literal["run",b"run"]) -> builtins.bool: ... def ClearField(self, field_name: typing_extensions.Literal["run",b"run"]) -> None: ... global___GetRunResponse = GetRunResponse + +class UpdateRunStatusRequest(google.protobuf.message.Message): + """UpdateRunStatus""" + DESCRIPTOR: google.protobuf.descriptor.Descriptor + RUN_ID_FIELD_NUMBER: builtins.int + RUN_STATUS_FIELD_NUMBER: builtins.int + run_id: builtins.int + @property + def run_status(self) -> global___RunStatus: ... + def __init__(self, + *, + run_id: builtins.int = ..., + run_status: typing.Optional[global___RunStatus] = ..., + ) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["run_status",b"run_status"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["run_id",b"run_id","run_status",b"run_status"]) -> None: ... +global___UpdateRunStatusRequest = UpdateRunStatusRequest + +class UpdateRunStatusResponse(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + def __init__(self, + ) -> None: ... +global___UpdateRunStatusResponse = UpdateRunStatusResponse + +class GetRunStatusRequest(google.protobuf.message.Message): + """GetRunStatus""" + DESCRIPTOR: google.protobuf.descriptor.Descriptor + RUN_IDS_FIELD_NUMBER: builtins.int + @property + def run_ids(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ... + def __init__(self, + *, + run_ids: typing.Optional[typing.Iterable[builtins.int]] = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["run_ids",b"run_ids"]) -> None: ... +global___GetRunStatusRequest = GetRunStatusRequest + +class GetRunStatusResponse(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + class RunStatusDictEntry(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + KEY_FIELD_NUMBER: builtins.int + VALUE_FIELD_NUMBER: builtins.int + key: builtins.int + @property + def value(self) -> global___RunStatus: ... + def __init__(self, + *, + key: builtins.int = ..., + value: typing.Optional[global___RunStatus] = ..., + ) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ... + + RUN_STATUS_DICT_FIELD_NUMBER: builtins.int + @property + def run_status_dict(self) -> google.protobuf.internal.containers.MessageMap[builtins.int, global___RunStatus]: ... + def __init__(self, + *, + run_status_dict: typing.Optional[typing.Mapping[builtins.int, global___RunStatus]] = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["run_status_dict",b"run_status_dict"]) -> None: ... +global___GetRunStatusResponse = GetRunStatusResponse From b370b0618fe53d6f01e5486c3269ee09dc0dcc6e Mon Sep 17 00:00:00 2001 From: Chong Shen Ng Date: Mon, 23 Sep 2024 13:54:22 +0100 Subject: [PATCH 4/4] feat(framework) Add `flwr log` (#3577) Co-authored-by: Charles Beauville Co-authored-by: jafermarq Co-authored-by: Taner Topal --- src/py/flwr/cli/app.py | 2 + src/py/flwr/cli/log.py | 196 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 198 insertions(+) create mode 100644 src/py/flwr/cli/log.py diff --git a/src/py/flwr/cli/app.py b/src/py/flwr/cli/app.py index 93effea6df98..8baccb4638fc 100644 --- a/src/py/flwr/cli/app.py +++ b/src/py/flwr/cli/app.py @@ -19,6 +19,7 @@ from .build import build from .install import install +from .log import log from .new import new from .run import run @@ -35,6 +36,7 @@ app.command()(run) app.command()(build) app.command()(install) +app.command()(log) typer_click_object = get_command(app) diff --git a/src/py/flwr/cli/log.py b/src/py/flwr/cli/log.py new file mode 100644 index 000000000000..6915de1e00c5 --- /dev/null +++ b/src/py/flwr/cli/log.py @@ -0,0 +1,196 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Flower command line interface `log` command.""" + +import sys +import time +from logging import DEBUG, ERROR, INFO +from pathlib import Path +from typing import Annotated, Optional + +import grpc +import typer + +from flwr.cli.config_utils import load_and_validate +from flwr.common.grpc import GRPC_MAX_MESSAGE_LENGTH, create_channel +from flwr.common.logger import log as logger + +CONN_REFRESH_PERIOD = 60 # Connection refresh period for log streaming (seconds) + + +# pylint: disable=unused-argument +def stream_logs(run_id: int, channel: grpc.Channel, period: int) -> None: + """Stream logs from the beginning of a run with connection refresh.""" + + +# pylint: disable=unused-argument +def print_logs(run_id: int, channel: grpc.Channel, timeout: int) -> None: + """Print logs from the beginning of a run.""" + + +def on_channel_state_change(channel_connectivity: str) -> None: + """Log channel connectivity.""" + logger(DEBUG, channel_connectivity) + + +def log( + run_id: Annotated[ + int, + typer.Argument(help="The Flower run ID to query"), + ], + app: Annotated[ + Path, + typer.Argument(help="Path of the Flower project to run"), + ] = Path("."), + federation: Annotated[ + Optional[str], + typer.Argument(help="Name of the federation to run the app on"), + ] = None, + stream: Annotated[ + bool, + typer.Option( + "--stream/--show", + help="Flag to stream or print logs from the Flower run", + ), + ] = True, +) -> None: + """Get logs from a Flower project run.""" + typer.secho("Loading project configuration... ", fg=typer.colors.BLUE) + + pyproject_path = app / "pyproject.toml" if app else None + config, errors, warnings = load_and_validate(path=pyproject_path) + + if config is None: + typer.secho( + "Project configuration could not be loaded.\n" + "pyproject.toml is invalid:\n" + + "\n".join([f"- {line}" for line in errors]), + fg=typer.colors.RED, + bold=True, + ) + sys.exit() + + if warnings: + typer.secho( + "Project configuration is missing the following " + "recommended properties:\n" + "\n".join([f"- {line}" for line in warnings]), + fg=typer.colors.RED, + bold=True, + ) + + typer.secho("Success", fg=typer.colors.GREEN) + + federation = federation or config["tool"]["flwr"]["federations"].get("default") + + if federation is None: + typer.secho( + "❌ No federation name was provided and the project's `pyproject.toml` " + "doesn't declare a default federation (with a SuperExec address or an " + "`options.num-supernodes` value).", + fg=typer.colors.RED, + bold=True, + ) + raise typer.Exit(code=1) + + # Validate the federation exists in the configuration + federation_config = config["tool"]["flwr"]["federations"].get(federation) + if federation_config is None: + available_feds = { + fed for fed in config["tool"]["flwr"]["federations"] if fed != "default" + } + typer.secho( + f"❌ There is no `{federation}` federation declared in the " + "`pyproject.toml`.\n The following federations were found:\n\n" + + "\n".join(available_feds), + fg=typer.colors.RED, + bold=True, + ) + raise typer.Exit(code=1) + + if "address" not in federation_config: + typer.secho( + "❌ `flwr log` currently works with `SuperExec`. Ensure that the correct" + "`SuperExec` address is provided in the `pyproject.toml`.", + fg=typer.colors.RED, + bold=True, + ) + raise typer.Exit(code=1) + + _log_with_superexec(federation_config, run_id, stream) + + +# pylint: disable-next=too-many-branches +def _log_with_superexec( + federation_config: dict[str, str], + run_id: int, + stream: bool, +) -> None: + insecure_str = federation_config.get("insecure") + if root_certificates := federation_config.get("root-certificates"): + root_certificates_bytes = Path(root_certificates).read_bytes() + if insecure := bool(insecure_str): + typer.secho( + "❌ `root_certificates` were provided but the `insecure` parameter" + "is set to `True`.", + fg=typer.colors.RED, + bold=True, + ) + raise typer.Exit(code=1) + else: + root_certificates_bytes = None + if insecure_str is None: + typer.secho( + "❌ To disable TLS, set `insecure = true` in `pyproject.toml`.", + fg=typer.colors.RED, + bold=True, + ) + raise typer.Exit(code=1) + if not (insecure := bool(insecure_str)): + typer.secho( + "❌ No certificate were given yet `insecure` is set to `False`.", + fg=typer.colors.RED, + bold=True, + ) + raise typer.Exit(code=1) + + channel = create_channel( + server_address=federation_config["address"], + insecure=insecure, + root_certificates=root_certificates_bytes, + max_message_length=GRPC_MAX_MESSAGE_LENGTH, + interceptors=None, + ) + channel.subscribe(on_channel_state_change) + + if stream: + try: + while True: + logger(INFO, "Starting logstream for run_id `%s`", run_id) + stream_logs(run_id, channel, CONN_REFRESH_PERIOD) + time.sleep(2) + logger(DEBUG, "Reconnecting to logstream") + except KeyboardInterrupt: + logger(INFO, "Exiting logstream") + except grpc.RpcError as e: + # pylint: disable=E1101 + if e.code() == grpc.StatusCode.NOT_FOUND: + logger(ERROR, "Invalid run_id `%s`, exiting", run_id) + if e.code() == grpc.StatusCode.CANCELLED: + pass + finally: + channel.close() + else: + logger(INFO, "Printing logstream for run_id `%s`", run_id) + print_logs(run_id, channel, timeout=5)