diff --git a/examples/llm-flowertune/requirements.txt b/examples/llm-flowertune/requirements.txt index 7c66612eb2a5..2d0e65da3615 100644 --- a/examples/llm-flowertune/requirements.txt +++ b/examples/llm-flowertune/requirements.txt @@ -7,3 +7,4 @@ scipy==1.11.2 peft==0.4.0 fschat[model_worker,webui]==0.2.35 transformers==4.38.1 +hf_transfer==0.1.8 diff --git a/pyproject.toml b/pyproject.toml index 0cbcf1c877d9..8faf16136ae8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,6 +59,7 @@ flower-supernode = "flwr.client:run_supernode" flower-client-app = "flwr.client:run_client_app" flower-server-app = "flwr.server:run_server_app" flower-simulation = "flwr.simulation.run_simulation:run_simulation_from_cli" +flwr-clientapp = "flwr.client.supernode:flwr_clientapp" [tool.poetry.dependencies] python = "^3.8" diff --git a/src/proto/flwr/proto/clientappio.proto b/src/proto/flwr/proto/clientappio.proto new file mode 100644 index 000000000000..d73ed086f40d --- /dev/null +++ b/src/proto/flwr/proto/clientappio.proto @@ -0,0 +1,40 @@ +syntax = "proto3"; + +package flwr.proto; + +import "flwr/proto/fab.proto"; +import "flwr/proto/run.proto"; +import "flwr/proto/message.proto"; + +service ClientAppIo { + // Get Message, Context, and Run + rpc PullClientAppInputs(PullClientAppInputsRequest) + returns (PullClientAppInputsResponse) {} + + // Send updated Message and Context + rpc PushClientAppOutputs(PushClientAppOutputsRequest) + returns (PushClientAppOutputsResponse) {} +} + +enum ClientAppOutputCode { + SUCCESS = 0; + DEADLINE_EXCEEDED = 1; + UNKNOWN_ERROR = 2; +} +message ClientAppOutputStatus { + ClientAppOutputCode code = 1; + string message = 2; +} + +message PullClientAppInputsRequest { sint64 token = 1; } +message PullClientAppInputsResponse { + Message message = 1; + Context context = 2; + Run run = 3; +} +message PushClientAppOutputsRequest { + sint64 token = 1; + Message message = 2; + Context context = 3; +} +message PushClientAppOutputsResponse { ClientAppOutputStatus status = 1; } diff --git a/src/proto/flwr/proto/message.proto b/src/proto/flwr/proto/message.proto new file mode 100644 index 000000000000..d568522e761e --- /dev/null +++ b/src/proto/flwr/proto/message.proto @@ -0,0 +1,46 @@ +// Copyright 2024 Flower Labs GmbH. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================== + +syntax = "proto3"; + +package flwr.proto; + +import "flwr/proto/error.proto"; +import "flwr/proto/recordset.proto"; +import "flwr/proto/transport.proto"; + +message Message { + Metadata metadata = 1; + RecordSet content = 2; + Error error = 3; +} + +message Context { + sint64 node_id = 1; + map node_config = 2; + RecordSet state = 3; + map run_config = 4; +} + +message Metadata { + sint64 run_id = 1; + string message_id = 2; + sint64 src_node_id = 3; + sint64 dst_node_id = 4; + string reply_to_message = 5; + string group_id = 6; + double ttl = 7; + string message_type = 8; +} diff --git a/src/py/flwr/client/supernode/__init__.py b/src/py/flwr/client/supernode/__init__.py index bc505f693875..128d0286d625 100644 --- a/src/py/flwr/client/supernode/__init__.py +++ b/src/py/flwr/client/supernode/__init__.py @@ -15,10 +15,12 @@ """Flower SuperNode.""" +from .app import flwr_clientapp as flwr_clientapp from .app import run_client_app as run_client_app from .app import run_supernode as run_supernode __all__ = [ + "flwr_clientapp", "run_client_app", "run_supernode", ] diff --git a/src/py/flwr/client/supernode/app.py b/src/py/flwr/client/supernode/app.py index 861ccbe34ece..5840b57c0ab6 100644 --- a/src/py/flwr/client/supernode/app.py +++ b/src/py/flwr/client/supernode/app.py @@ -120,6 +120,31 @@ def run_client_app() -> None: register_exit_handlers(event_type=EventType.RUN_CLIENT_APP_LEAVE) +def flwr_clientapp() -> None: + """Run process-isolated Flower ClientApp.""" + log(INFO, "Starting Flower ClientApp") + + parser = argparse.ArgumentParser( + description="Run a Flower ClientApp", + ) + parser.add_argument( + "--address", + help="Address of SuperNode ClientAppIo gRPC servicer", + ) + parser.add_argument( + "--token", + help="Unique token generated by SuperNode for each ClientApp execution", + ) + args = parser.parse_args() + log( + DEBUG, + "Staring isolated `ClientApp` connected to SuperNode ClientAppIo at %s " + "with the token %s", + args.address, + args.token, + ) + + def _warn_deprecated_server_arg(args: argparse.Namespace) -> None: """Warn about the deprecated argument `--server`.""" if args.server != ADDRESS_FLEET_API_GRPC_RERE: diff --git a/src/py/flwr/proto/clientappio_pb2.py b/src/py/flwr/proto/clientappio_pb2.py new file mode 100644 index 000000000000..2234e3c2a8af --- /dev/null +++ b/src/py/flwr/proto/clientappio_pb2.py @@ -0,0 +1,41 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: flwr/proto/clientappio.proto +# Protobuf Python Version: 4.25.0 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +from flwr.proto import fab_pb2 as flwr_dot_proto_dot_fab__pb2 +from flwr.proto import run_pb2 as flwr_dot_proto_dot_run__pb2 +from flwr.proto import message_pb2 as flwr_dot_proto_dot_message__pb2 + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1c\x66lwr/proto/clientappio.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/fab.proto\x1a\x14\x66lwr/proto/run.proto\x1a\x18\x66lwr/proto/message.proto\"W\n\x15\x43lientAppOutputStatus\x12-\n\x04\x63ode\x18\x01 \x01(\x0e\x32\x1f.flwr.proto.ClientAppOutputCode\x12\x0f\n\x07message\x18\x02 \x01(\t\"+\n\x1aPullClientAppInputsRequest\x12\r\n\x05token\x18\x01 \x01(\x12\"\x87\x01\n\x1bPullClientAppInputsResponse\x12$\n\x07message\x18\x01 \x01(\x0b\x32\x13.flwr.proto.Message\x12$\n\x07\x63ontext\x18\x02 \x01(\x0b\x32\x13.flwr.proto.Context\x12\x1c\n\x03run\x18\x03 \x01(\x0b\x32\x0f.flwr.proto.Run\"x\n\x1bPushClientAppOutputsRequest\x12\r\n\x05token\x18\x01 \x01(\x12\x12$\n\x07message\x18\x02 \x01(\x0b\x32\x13.flwr.proto.Message\x12$\n\x07\x63ontext\x18\x03 \x01(\x0b\x32\x13.flwr.proto.Context\"Q\n\x1cPushClientAppOutputsResponse\x12\x31\n\x06status\x18\x01 \x01(\x0b\x32!.flwr.proto.ClientAppOutputStatus*L\n\x13\x43lientAppOutputCode\x12\x0b\n\x07SUCCESS\x10\x00\x12\x15\n\x11\x44\x45\x41\x44LINE_EXCEEDED\x10\x01\x12\x11\n\rUNKNOWN_ERROR\x10\x02\x32\xe4\x01\n\x0b\x43lientAppIo\x12h\n\x13PullClientAppInputs\x12&.flwr.proto.PullClientAppInputsRequest\x1a\'.flwr.proto.PullClientAppInputsResponse\"\x00\x12k\n\x14PushClientAppOutputs\x12\'.flwr.proto.PushClientAppOutputsRequest\x1a(.flwr.proto.PushClientAppOutputsResponse\"\x00\x62\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.clientappio_pb2', _globals) +if _descriptor._USE_C_DESCRIPTORS == False: + DESCRIPTOR._options = None + _globals['_CLIENTAPPOUTPUTCODE']._serialized_start=591 + _globals['_CLIENTAPPOUTPUTCODE']._serialized_end=667 + _globals['_CLIENTAPPOUTPUTSTATUS']._serialized_start=114 + _globals['_CLIENTAPPOUTPUTSTATUS']._serialized_end=201 + _globals['_PULLCLIENTAPPINPUTSREQUEST']._serialized_start=203 + _globals['_PULLCLIENTAPPINPUTSREQUEST']._serialized_end=246 + _globals['_PULLCLIENTAPPINPUTSRESPONSE']._serialized_start=249 + _globals['_PULLCLIENTAPPINPUTSRESPONSE']._serialized_end=384 + _globals['_PUSHCLIENTAPPOUTPUTSREQUEST']._serialized_start=386 + _globals['_PUSHCLIENTAPPOUTPUTSREQUEST']._serialized_end=506 + _globals['_PUSHCLIENTAPPOUTPUTSRESPONSE']._serialized_start=508 + _globals['_PUSHCLIENTAPPOUTPUTSRESPONSE']._serialized_end=589 + _globals['_CLIENTAPPIO']._serialized_start=670 + _globals['_CLIENTAPPIO']._serialized_end=898 +# @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/clientappio_pb2.pyi b/src/py/flwr/proto/clientappio_pb2.pyi new file mode 100644 index 000000000000..31c9dc4c6d14 --- /dev/null +++ b/src/py/flwr/proto/clientappio_pb2.pyi @@ -0,0 +1,110 @@ +""" +@generated by mypy-protobuf. Do not edit manually! +isort:skip_file +""" +import builtins +import flwr.proto.message_pb2 +import flwr.proto.run_pb2 +import google.protobuf.descriptor +import google.protobuf.internal.enum_type_wrapper +import google.protobuf.message +import typing +import typing_extensions + +DESCRIPTOR: google.protobuf.descriptor.FileDescriptor + +class _ClientAppOutputCode: + ValueType = typing.NewType('ValueType', builtins.int) + V: typing_extensions.TypeAlias = ValueType +class _ClientAppOutputCodeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[_ClientAppOutputCode.ValueType], builtins.type): + DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor + SUCCESS: _ClientAppOutputCode.ValueType # 0 + DEADLINE_EXCEEDED: _ClientAppOutputCode.ValueType # 1 + UNKNOWN_ERROR: _ClientAppOutputCode.ValueType # 2 +class ClientAppOutputCode(_ClientAppOutputCode, metaclass=_ClientAppOutputCodeEnumTypeWrapper): + pass + +SUCCESS: ClientAppOutputCode.ValueType # 0 +DEADLINE_EXCEEDED: ClientAppOutputCode.ValueType # 1 +UNKNOWN_ERROR: ClientAppOutputCode.ValueType # 2 +global___ClientAppOutputCode = ClientAppOutputCode + + +class ClientAppOutputStatus(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + CODE_FIELD_NUMBER: builtins.int + MESSAGE_FIELD_NUMBER: builtins.int + code: global___ClientAppOutputCode.ValueType + message: typing.Text + def __init__(self, + *, + code: global___ClientAppOutputCode.ValueType = ..., + message: typing.Text = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["code",b"code","message",b"message"]) -> None: ... +global___ClientAppOutputStatus = ClientAppOutputStatus + +class PullClientAppInputsRequest(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + TOKEN_FIELD_NUMBER: builtins.int + token: builtins.int + def __init__(self, + *, + token: builtins.int = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["token",b"token"]) -> None: ... +global___PullClientAppInputsRequest = PullClientAppInputsRequest + +class PullClientAppInputsResponse(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + MESSAGE_FIELD_NUMBER: builtins.int + CONTEXT_FIELD_NUMBER: builtins.int + RUN_FIELD_NUMBER: builtins.int + @property + def message(self) -> flwr.proto.message_pb2.Message: ... + @property + def context(self) -> flwr.proto.message_pb2.Context: ... + @property + def run(self) -> flwr.proto.run_pb2.Run: ... + def __init__(self, + *, + message: typing.Optional[flwr.proto.message_pb2.Message] = ..., + context: typing.Optional[flwr.proto.message_pb2.Context] = ..., + run: typing.Optional[flwr.proto.run_pb2.Run] = ..., + ) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["context",b"context","message",b"message","run",b"run"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["context",b"context","message",b"message","run",b"run"]) -> None: ... +global___PullClientAppInputsResponse = PullClientAppInputsResponse + +class PushClientAppOutputsRequest(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + TOKEN_FIELD_NUMBER: builtins.int + MESSAGE_FIELD_NUMBER: builtins.int + CONTEXT_FIELD_NUMBER: builtins.int + token: builtins.int + @property + def message(self) -> flwr.proto.message_pb2.Message: ... + @property + def context(self) -> flwr.proto.message_pb2.Context: ... + def __init__(self, + *, + token: builtins.int = ..., + message: typing.Optional[flwr.proto.message_pb2.Message] = ..., + context: typing.Optional[flwr.proto.message_pb2.Context] = ..., + ) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["context",b"context","message",b"message"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["context",b"context","message",b"message","token",b"token"]) -> None: ... +global___PushClientAppOutputsRequest = PushClientAppOutputsRequest + +class PushClientAppOutputsResponse(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + STATUS_FIELD_NUMBER: builtins.int + @property + def status(self) -> global___ClientAppOutputStatus: ... + def __init__(self, + *, + status: typing.Optional[global___ClientAppOutputStatus] = ..., + ) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["status",b"status"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["status",b"status"]) -> None: ... +global___PushClientAppOutputsResponse = PushClientAppOutputsResponse diff --git a/src/py/flwr/proto/clientappio_pb2_grpc.py b/src/py/flwr/proto/clientappio_pb2_grpc.py new file mode 100644 index 000000000000..b244ef4a5b1d --- /dev/null +++ b/src/py/flwr/proto/clientappio_pb2_grpc.py @@ -0,0 +1,101 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc + +from flwr.proto import clientappio_pb2 as flwr_dot_proto_dot_clientappio__pb2 + + +class ClientAppIoStub(object): + """Missing associated documentation comment in .proto file.""" + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. + """ + self.PullClientAppInputs = channel.unary_unary( + '/flwr.proto.ClientAppIo/PullClientAppInputs', + request_serializer=flwr_dot_proto_dot_clientappio__pb2.PullClientAppInputsRequest.SerializeToString, + response_deserializer=flwr_dot_proto_dot_clientappio__pb2.PullClientAppInputsResponse.FromString, + ) + self.PushClientAppOutputs = channel.unary_unary( + '/flwr.proto.ClientAppIo/PushClientAppOutputs', + request_serializer=flwr_dot_proto_dot_clientappio__pb2.PushClientAppOutputsRequest.SerializeToString, + response_deserializer=flwr_dot_proto_dot_clientappio__pb2.PushClientAppOutputsResponse.FromString, + ) + + +class ClientAppIoServicer(object): + """Missing associated documentation comment in .proto file.""" + + def PullClientAppInputs(self, request, context): + """Get Message, Context, and Run + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def PushClientAppOutputs(self, request, context): + """Send updated Message and Context + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_ClientAppIoServicer_to_server(servicer, server): + rpc_method_handlers = { + 'PullClientAppInputs': grpc.unary_unary_rpc_method_handler( + servicer.PullClientAppInputs, + request_deserializer=flwr_dot_proto_dot_clientappio__pb2.PullClientAppInputsRequest.FromString, + response_serializer=flwr_dot_proto_dot_clientappio__pb2.PullClientAppInputsResponse.SerializeToString, + ), + 'PushClientAppOutputs': grpc.unary_unary_rpc_method_handler( + servicer.PushClientAppOutputs, + request_deserializer=flwr_dot_proto_dot_clientappio__pb2.PushClientAppOutputsRequest.FromString, + response_serializer=flwr_dot_proto_dot_clientappio__pb2.PushClientAppOutputsResponse.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'flwr.proto.ClientAppIo', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + + + # This class is part of an EXPERIMENTAL API. +class ClientAppIo(object): + """Missing associated documentation comment in .proto file.""" + + @staticmethod + def PullClientAppInputs(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/flwr.proto.ClientAppIo/PullClientAppInputs', + flwr_dot_proto_dot_clientappio__pb2.PullClientAppInputsRequest.SerializeToString, + flwr_dot_proto_dot_clientappio__pb2.PullClientAppInputsResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def PushClientAppOutputs(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/flwr.proto.ClientAppIo/PushClientAppOutputs', + flwr_dot_proto_dot_clientappio__pb2.PushClientAppOutputsRequest.SerializeToString, + flwr_dot_proto_dot_clientappio__pb2.PushClientAppOutputsResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/src/py/flwr/proto/clientappio_pb2_grpc.pyi b/src/py/flwr/proto/clientappio_pb2_grpc.pyi new file mode 100644 index 000000000000..4503e11f17ae --- /dev/null +++ b/src/py/flwr/proto/clientappio_pb2_grpc.pyi @@ -0,0 +1,40 @@ +""" +@generated by mypy-protobuf. Do not edit manually! +isort:skip_file +""" +import abc +import flwr.proto.clientappio_pb2 +import grpc + +class ClientAppIoStub: + def __init__(self, channel: grpc.Channel) -> None: ... + PullClientAppInputs: grpc.UnaryUnaryMultiCallable[ + flwr.proto.clientappio_pb2.PullClientAppInputsRequest, + flwr.proto.clientappio_pb2.PullClientAppInputsResponse] + """Get Message, Context, and Run""" + + PushClientAppOutputs: grpc.UnaryUnaryMultiCallable[ + flwr.proto.clientappio_pb2.PushClientAppOutputsRequest, + flwr.proto.clientappio_pb2.PushClientAppOutputsResponse] + """Send updated Message and Context""" + + +class ClientAppIoServicer(metaclass=abc.ABCMeta): + @abc.abstractmethod + def PullClientAppInputs(self, + request: flwr.proto.clientappio_pb2.PullClientAppInputsRequest, + context: grpc.ServicerContext, + ) -> flwr.proto.clientappio_pb2.PullClientAppInputsResponse: + """Get Message, Context, and Run""" + pass + + @abc.abstractmethod + def PushClientAppOutputs(self, + request: flwr.proto.clientappio_pb2.PushClientAppOutputsRequest, + context: grpc.ServicerContext, + ) -> flwr.proto.clientappio_pb2.PushClientAppOutputsResponse: + """Send updated Message and Context""" + pass + + +def add_ClientAppIoServicer_to_server(servicer: ClientAppIoServicer, server: grpc.Server) -> None: ... diff --git a/src/py/flwr/proto/message_pb2.py b/src/py/flwr/proto/message_pb2.py new file mode 100644 index 000000000000..1dfa5656ea79 --- /dev/null +++ b/src/py/flwr/proto/message_pb2.py @@ -0,0 +1,41 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: flwr/proto/message.proto +# Protobuf Python Version: 4.25.0 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +from flwr.proto import error_pb2 as flwr_dot_proto_dot_error__pb2 +from flwr.proto import recordset_pb2 as flwr_dot_proto_dot_recordset__pb2 +from flwr.proto import transport_pb2 as flwr_dot_proto_dot_transport__pb2 + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x18\x66lwr/proto/message.proto\x12\nflwr.proto\x1a\x16\x66lwr/proto/error.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x1a\x66lwr/proto/transport.proto\"{\n\x07Message\x12&\n\x08metadata\x18\x01 \x01(\x0b\x32\x14.flwr.proto.Metadata\x12&\n\x07\x63ontent\x18\x02 \x01(\x0b\x32\x15.flwr.proto.RecordSet\x12 \n\x05\x65rror\x18\x03 \x01(\x0b\x32\x11.flwr.proto.Error\"\xbf\x02\n\x07\x43ontext\x12\x0f\n\x07node_id\x18\x01 \x01(\x12\x12\x38\n\x0bnode_config\x18\x02 \x03(\x0b\x32#.flwr.proto.Context.NodeConfigEntry\x12$\n\x05state\x18\x03 \x01(\x0b\x32\x15.flwr.proto.RecordSet\x12\x36\n\nrun_config\x18\x04 \x03(\x0b\x32\".flwr.proto.Context.RunConfigEntry\x1a\x45\n\x0fNodeConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\x1a\x44\n\x0eRunConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"\xa7\x01\n\x08Metadata\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\x12\x12\n\nmessage_id\x18\x02 \x01(\t\x12\x13\n\x0bsrc_node_id\x18\x03 \x01(\x12\x12\x13\n\x0b\x64st_node_id\x18\x04 \x01(\x12\x12\x18\n\x10reply_to_message\x18\x05 \x01(\t\x12\x10\n\x08group_id\x18\x06 \x01(\t\x12\x0b\n\x03ttl\x18\x07 \x01(\x01\x12\x14\n\x0cmessage_type\x18\x08 \x01(\tb\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.message_pb2', _globals) +if _descriptor._USE_C_DESCRIPTORS == False: + DESCRIPTOR._options = None + _globals['_CONTEXT_NODECONFIGENTRY']._options = None + _globals['_CONTEXT_NODECONFIGENTRY']._serialized_options = b'8\001' + _globals['_CONTEXT_RUNCONFIGENTRY']._options = None + _globals['_CONTEXT_RUNCONFIGENTRY']._serialized_options = b'8\001' + _globals['_MESSAGE']._serialized_start=120 + _globals['_MESSAGE']._serialized_end=243 + _globals['_CONTEXT']._serialized_start=246 + _globals['_CONTEXT']._serialized_end=565 + _globals['_CONTEXT_NODECONFIGENTRY']._serialized_start=426 + _globals['_CONTEXT_NODECONFIGENTRY']._serialized_end=495 + _globals['_CONTEXT_RUNCONFIGENTRY']._serialized_start=497 + _globals['_CONTEXT_RUNCONFIGENTRY']._serialized_end=565 + _globals['_METADATA']._serialized_start=568 + _globals['_METADATA']._serialized_end=735 +# @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/message_pb2.pyi b/src/py/flwr/proto/message_pb2.pyi new file mode 100644 index 000000000000..68b98430a59a --- /dev/null +++ b/src/py/flwr/proto/message_pb2.pyi @@ -0,0 +1,122 @@ +""" +@generated by mypy-protobuf. Do not edit manually! +isort:skip_file +""" +import builtins +import flwr.proto.error_pb2 +import flwr.proto.recordset_pb2 +import flwr.proto.transport_pb2 +import google.protobuf.descriptor +import google.protobuf.internal.containers +import google.protobuf.message +import typing +import typing_extensions + +DESCRIPTOR: google.protobuf.descriptor.FileDescriptor + +class Message(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + METADATA_FIELD_NUMBER: builtins.int + CONTENT_FIELD_NUMBER: builtins.int + ERROR_FIELD_NUMBER: builtins.int + @property + def metadata(self) -> global___Metadata: ... + @property + def content(self) -> flwr.proto.recordset_pb2.RecordSet: ... + @property + def error(self) -> flwr.proto.error_pb2.Error: ... + def __init__(self, + *, + metadata: typing.Optional[global___Metadata] = ..., + content: typing.Optional[flwr.proto.recordset_pb2.RecordSet] = ..., + error: typing.Optional[flwr.proto.error_pb2.Error] = ..., + ) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["content",b"content","error",b"error","metadata",b"metadata"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["content",b"content","error",b"error","metadata",b"metadata"]) -> None: ... +global___Message = Message + +class Context(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + class NodeConfigEntry(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + KEY_FIELD_NUMBER: builtins.int + VALUE_FIELD_NUMBER: builtins.int + key: typing.Text + @property + def value(self) -> flwr.proto.transport_pb2.Scalar: ... + def __init__(self, + *, + key: typing.Text = ..., + value: typing.Optional[flwr.proto.transport_pb2.Scalar] = ..., + ) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ... + + class RunConfigEntry(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + KEY_FIELD_NUMBER: builtins.int + VALUE_FIELD_NUMBER: builtins.int + key: typing.Text + @property + def value(self) -> flwr.proto.transport_pb2.Scalar: ... + def __init__(self, + *, + key: typing.Text = ..., + value: typing.Optional[flwr.proto.transport_pb2.Scalar] = ..., + ) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ... + + NODE_ID_FIELD_NUMBER: builtins.int + NODE_CONFIG_FIELD_NUMBER: builtins.int + STATE_FIELD_NUMBER: builtins.int + RUN_CONFIG_FIELD_NUMBER: builtins.int + node_id: builtins.int + @property + def node_config(self) -> google.protobuf.internal.containers.MessageMap[typing.Text, flwr.proto.transport_pb2.Scalar]: ... + @property + def state(self) -> flwr.proto.recordset_pb2.RecordSet: ... + @property + def run_config(self) -> google.protobuf.internal.containers.MessageMap[typing.Text, flwr.proto.transport_pb2.Scalar]: ... + def __init__(self, + *, + node_id: builtins.int = ..., + node_config: typing.Optional[typing.Mapping[typing.Text, flwr.proto.transport_pb2.Scalar]] = ..., + state: typing.Optional[flwr.proto.recordset_pb2.RecordSet] = ..., + run_config: typing.Optional[typing.Mapping[typing.Text, flwr.proto.transport_pb2.Scalar]] = ..., + ) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["state",b"state"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["node_config",b"node_config","node_id",b"node_id","run_config",b"run_config","state",b"state"]) -> None: ... +global___Context = Context + +class Metadata(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + RUN_ID_FIELD_NUMBER: builtins.int + MESSAGE_ID_FIELD_NUMBER: builtins.int + SRC_NODE_ID_FIELD_NUMBER: builtins.int + DST_NODE_ID_FIELD_NUMBER: builtins.int + REPLY_TO_MESSAGE_FIELD_NUMBER: builtins.int + GROUP_ID_FIELD_NUMBER: builtins.int + TTL_FIELD_NUMBER: builtins.int + MESSAGE_TYPE_FIELD_NUMBER: builtins.int + run_id: builtins.int + message_id: typing.Text + src_node_id: builtins.int + dst_node_id: builtins.int + reply_to_message: typing.Text + group_id: typing.Text + ttl: builtins.float + message_type: typing.Text + def __init__(self, + *, + run_id: builtins.int = ..., + message_id: typing.Text = ..., + src_node_id: builtins.int = ..., + dst_node_id: builtins.int = ..., + reply_to_message: typing.Text = ..., + group_id: typing.Text = ..., + ttl: builtins.float = ..., + message_type: typing.Text = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["dst_node_id",b"dst_node_id","group_id",b"group_id","message_id",b"message_id","message_type",b"message_type","reply_to_message",b"reply_to_message","run_id",b"run_id","src_node_id",b"src_node_id","ttl",b"ttl"]) -> None: ... +global___Metadata = Metadata diff --git a/src/py/flwr/proto/message_pb2_grpc.py b/src/py/flwr/proto/message_pb2_grpc.py new file mode 100644 index 000000000000..2daafffebfc8 --- /dev/null +++ b/src/py/flwr/proto/message_pb2_grpc.py @@ -0,0 +1,4 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc + diff --git a/src/py/flwr/proto/message_pb2_grpc.pyi b/src/py/flwr/proto/message_pb2_grpc.pyi new file mode 100644 index 000000000000..f3a5a087ef5d --- /dev/null +++ b/src/py/flwr/proto/message_pb2_grpc.pyi @@ -0,0 +1,4 @@ +""" +@generated by mypy-protobuf. Do not edit manually! +isort:skip_file +""" diff --git a/src/py/flwr/server/superlink/ffs/__init__.py b/src/py/flwr/server/superlink/ffs/__init__.py new file mode 100644 index 000000000000..0273d2a630e1 --- /dev/null +++ b/src/py/flwr/server/superlink/ffs/__init__.py @@ -0,0 +1,24 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Flower File Storage for large objects.""" + + +from .disk_ffs import DiskFfs as DiskFfs +from .ffs import Ffs as Ffs + +__all__ = [ + "DiskFfs", + "Ffs", +] diff --git a/src/py/flwr/server/superlink/ffs/disk_ffs.py b/src/py/flwr/server/superlink/ffs/disk_ffs.py new file mode 100644 index 000000000000..5331af500464 --- /dev/null +++ b/src/py/flwr/server/superlink/ffs/disk_ffs.py @@ -0,0 +1,104 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Disk based Flower File Storage.""" + +import hashlib +import json +from pathlib import Path +from typing import Dict, List, Tuple + +from flwr.server.superlink.ffs.ffs import Ffs + + +class DiskFfs(Ffs): # pylint: disable=R0904 + """Disk-based Flower File Storage interface for large objects.""" + + def __init__(self, base_dir: str) -> None: + """Create a new DiskFfs instance. + + Parameters + ---------- + base_dir : str + The base directory to store the objects. + """ + self.base_dir = Path(base_dir) + + def put(self, content: bytes, meta: Dict[str, str]) -> str: + """Store bytes and metadata and return key (hash of content). + + Parameters + ---------- + content : bytes + The content to be stored. + meta : Dict[str, str] + The metadata to be stored. + + Returns + ------- + key : str + The key (sha256hex hash) of the content. + """ + content_hash = hashlib.sha256(content).hexdigest() + + self.base_dir.mkdir(exist_ok=True, parents=True) + (self.base_dir / content_hash).write_bytes(content) + (self.base_dir / f"{content_hash}.META").write_text(json.dumps(meta)) + + return content_hash + + def get(self, key: str) -> Tuple[bytes, Dict[str, str]]: + """Return tuple containing the object content and metadata. + + Parameters + ---------- + key : str + The sha256hex hash of the object to be retrieved. + + Returns + ------- + Tuple[bytes, Dict[str, str]] + A tuple containing the object content and metadata. + """ + content = (self.base_dir / key).read_bytes() + meta = json.loads((self.base_dir / f"{key}.META").read_text()) + + return content, meta + + def delete(self, key: str) -> None: + """Delete object with hash. + + Parameters + ---------- + key : str + The sha256hex hash of the object to be deleted. + """ + (self.base_dir / key).unlink() + (self.base_dir / f"{key}.META").unlink() + + def list(self) -> List[str]: + """List all keys. + + Return all available keys in this `Ffs` instance. + This can be combined with, for example, + the `delete` method to delete objects. + + Returns + ------- + List[str] + A list of all available keys. + """ + return [ + item.name for item in self.base_dir.iterdir() if not item.suffix == ".META" + ] diff --git a/src/py/flwr/server/superlink/ffs/ffs.py b/src/py/flwr/server/superlink/ffs/ffs.py new file mode 100644 index 000000000000..622988141c9d --- /dev/null +++ b/src/py/flwr/server/superlink/ffs/ffs.py @@ -0,0 +1,79 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Abstract base class for Flower File Storage interface.""" + + +import abc +from typing import Dict, List, Tuple + + +class Ffs(abc.ABC): # pylint: disable=R0904 + """Abstract Flower File Storage interface for large objects.""" + + @abc.abstractmethod + def put(self, content: bytes, meta: Dict[str, str]) -> str: + """Store bytes and metadata and return sha256hex hash of data as str. + + Parameters + ---------- + content : bytes + The content to be stored. + meta : Dict[str, str] + The metadata to be stored. + + Returns + ------- + key : str + The key (sha256hex hash) of the content. + """ + + @abc.abstractmethod + def get(self, key: str) -> Tuple[bytes, Dict[str, str]]: + """Return tuple containing the object content and metadata. + + Parameters + ---------- + key : str + The key (sha256hex hash) of the object to be retrieved. + + Returns + ------- + Tuple[bytes, Dict[str, str]] + A tuple containing the object content and metadata. + """ + + @abc.abstractmethod + def delete(self, key: str) -> None: + """Delete object with hash. + + Parameters + ---------- + key : str + The key (sha256hex hash) of the object to be deleted. + """ + + @abc.abstractmethod + def list(self) -> List[str]: + """List keys of all stored objects. + + Return all available keys in this `Ffs` instance. + This can be combined with, for example, + the `delete` method to delete objects. + + Returns + ------- + List[str] + A list of all available keys. + """ diff --git a/src/py/flwr/server/superlink/ffs/ffs_test.py b/src/py/flwr/server/superlink/ffs/ffs_test.py new file mode 100644 index 000000000000..3b25ac7b206a --- /dev/null +++ b/src/py/flwr/server/superlink/ffs/ffs_test.py @@ -0,0 +1,150 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests all Ffs implemenations have to conform to.""" +# pylint: disable=invalid-name, disable=R0904 + +import hashlib +import json +import os +import tempfile +import unittest +from abc import abstractmethod +from typing import Dict + +from flwr.server.superlink.ffs import DiskFfs, Ffs + + +class FfsTest(unittest.TestCase): + """Test all ffs implementations.""" + + # This is to True in each child class + __test__ = False + + tmp_dir: tempfile.TemporaryDirectory # type: ignore + + @abstractmethod + def ffs_factory(self) -> Ffs: + """Provide Ffs implementation to test.""" + raise NotImplementedError() + + def test_put(self) -> None: + """Test if object can be stored.""" + # Prepare + ffs: Ffs = self.ffs_factory() + content = b"content" + hash_expected = hashlib.sha256(content).hexdigest() + + # Execute + hash_actual = ffs.put(b"content", {"meta": "data"}) + + # Assert + assert isinstance(hash_actual, str) + assert len(hash_actual) == 64 + assert hash_actual == hash_expected + + # Check if file was created + assert {hash_expected, f"{hash_expected}.META"} == set( + os.listdir(self.tmp_dir.name) + ) + + def test_get(self) -> None: + """Test if object can be retrieved.""" + # Prepare + ffs: Ffs = self.ffs_factory() + content_expected = b"content" + hash_expected = hashlib.sha256(content_expected).hexdigest() + meta_expected: Dict[str, str] = {"meta_key": "meta_value"} + + with open(os.path.join(self.tmp_dir.name, hash_expected), "wb") as file: + file.write(content_expected) + + with open( + os.path.join(self.tmp_dir.name, f"{hash_expected}.META"), + "w", + encoding="utf-8", + ) as file: + json.dump(meta_expected, file) + + # Execute + content_actual, meta_actual = ffs.get(hash_expected) + + # Assert + assert content_actual == content_expected + assert meta_actual == meta_expected + + def test_delete(self) -> None: + """Test if object can be deleted.""" + # Prepare + ffs: Ffs = self.ffs_factory() + content_expected = b"content" + hash_expected = hashlib.sha256(content_expected).hexdigest() + meta_expected: Dict[str, str] = {"meta_key": "meta_value"} + + with open(os.path.join(self.tmp_dir.name, hash_expected), "wb") as file: + file.write(content_expected) + + with open( + os.path.join(self.tmp_dir.name, f"{hash_expected}.META"), + "w", + encoding="utf-8", + ) as file: + json.dump(meta_expected, file) + + # Execute + ffs.delete(hash_expected) + + # Assert + assert set() == set(os.listdir(self.tmp_dir.name)) + + def test_list(self) -> None: + """Test if object hashes can be listed.""" + # Prepare + ffs: Ffs = self.ffs_factory() + content_expected = b"content" + hash_expected = hashlib.sha256(content_expected).hexdigest() + meta_expected: Dict[str, str] = {"meta_key": "meta_value"} + + with open(os.path.join(self.tmp_dir.name, hash_expected), "wb") as file: + file.write(content_expected) + + with open( + os.path.join(self.tmp_dir.name, f"{hash_expected}.META"), + "w", + encoding="utf-8", + ) as file: + json.dump(meta_expected, file) + + # Execute + hashes = ffs.list() + + # Assert + assert {hash_expected} == set(hashes) + + +class DiskFfsTest(FfsTest, unittest.TestCase): + """Test DiskFfs implementation.""" + + __test__ = True + + def ffs_factory(self) -> DiskFfs: + """Return SqliteState with file-based database.""" + # pylint: disable-next=consider-using-with,attribute-defined-outside-init + self.tmp_dir = tempfile.TemporaryDirectory() + ffs = DiskFfs(self.tmp_dir.name) + return ffs + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/src/py/flwr_tool/protoc_test.py b/src/py/flwr_tool/protoc_test.py index 6ba7daa0efea..6f9127304f25 100644 --- a/src/py/flwr_tool/protoc_test.py +++ b/src/py/flwr_tool/protoc_test.py @@ -28,4 +28,4 @@ def test_directories() -> None: def test_proto_file_count() -> None: """Test if the correct number of proto files were captured by the glob.""" - assert len(PROTO_FILES) == 11 + assert len(PROTO_FILES) == 13