From 62d504dd9e7c9fbce8d99d306cffb529429f69e5 Mon Sep 17 00:00:00 2001 From: Charles Beauville Date: Sun, 7 Jul 2024 14:06:39 +0200 Subject: [PATCH 1/8] feat(framework) Add proto changes for config overrides --- src/proto/flwr/proto/driver.proto | 1 + src/proto/flwr/proto/exec.proto | 5 +++- src/proto/flwr/proto/run.proto | 1 + src/py/flwr/proto/common_pb2.py | 24 +++++++++++++++ src/py/flwr/proto/common_pb2.pyi | 7 +++++ src/py/flwr/proto/common_pb2_grpc.py | 4 +++ src/py/flwr/proto/common_pb2_grpc.pyi | 4 +++ src/py/flwr/proto/driver_pb2.py | 42 +++++++++++++++------------ src/py/flwr/proto/driver_pb2.pyi | 19 +++++++++++- src/py/flwr/proto/exec_pb2.py | 26 ++++++++++------- src/py/flwr/proto/exec_pb2.pyi | 20 ++++++++++++- src/py/flwr/proto/run_pb2.py | 18 +++++++----- src/py/flwr/proto/run_pb2.pyi | 20 ++++++++++++- 13 files changed, 150 insertions(+), 41 deletions(-) create mode 100644 src/py/flwr/proto/common_pb2.py create mode 100644 src/py/flwr/proto/common_pb2.pyi create mode 100644 src/py/flwr/proto/common_pb2_grpc.py create mode 100644 src/py/flwr/proto/common_pb2_grpc.pyi diff --git a/src/proto/flwr/proto/driver.proto b/src/proto/flwr/proto/driver.proto index edbd5d91bb5b..77dc52b3258b 100644 --- a/src/proto/flwr/proto/driver.proto +++ b/src/proto/flwr/proto/driver.proto @@ -42,6 +42,7 @@ service Driver { message CreateRunRequest { string fab_id = 1; string fab_version = 2; + map override_config = 3; } message CreateRunResponse { sint64 run_id = 1; } diff --git a/src/proto/flwr/proto/exec.proto b/src/proto/flwr/proto/exec.proto index 8e5f53b02ca8..d0d8dfcbb273 100644 --- a/src/proto/flwr/proto/exec.proto +++ b/src/proto/flwr/proto/exec.proto @@ -25,7 +25,10 @@ service Exec { rpc StreamLogs(StreamLogsRequest) returns (stream StreamLogsResponse) {} } -message StartRunRequest { bytes fab_file = 1; } +message StartRunRequest { + bytes fab_file = 1; + map override_config = 2; +} message StartRunResponse { sint64 run_id = 1; } message StreamLogsRequest { sint64 run_id = 1; } message StreamLogsResponse { string log_output = 1; } diff --git a/src/proto/flwr/proto/run.proto b/src/proto/flwr/proto/run.proto index 76a7fd91532f..e41748381cab 100644 --- a/src/proto/flwr/proto/run.proto +++ b/src/proto/flwr/proto/run.proto @@ -21,6 +21,7 @@ message Run { sint64 run_id = 1; string fab_id = 2; string fab_version = 3; + map override_config = 4; } message GetRunRequest { sint64 run_id = 1; } message GetRunResponse { Run run = 1; } diff --git a/src/py/flwr/proto/common_pb2.py b/src/py/flwr/proto/common_pb2.py new file mode 100644 index 000000000000..8a6430137f05 --- /dev/null +++ b/src/py/flwr/proto/common_pb2.py @@ -0,0 +1,24 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: flwr/proto/common.proto +# Protobuf Python Version: 4.25.0 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x17\x66lwr/proto/common.proto\x12\nflwr.protob\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.common_pb2', _globals) +if _descriptor._USE_C_DESCRIPTORS == False: + DESCRIPTOR._options = None +# @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/common_pb2.pyi b/src/py/flwr/proto/common_pb2.pyi new file mode 100644 index 000000000000..e08fa11c2caa --- /dev/null +++ b/src/py/flwr/proto/common_pb2.pyi @@ -0,0 +1,7 @@ +""" +@generated by mypy-protobuf. Do not edit manually! +isort:skip_file +""" +import google.protobuf.descriptor + +DESCRIPTOR: google.protobuf.descriptor.FileDescriptor diff --git a/src/py/flwr/proto/common_pb2_grpc.py b/src/py/flwr/proto/common_pb2_grpc.py new file mode 100644 index 000000000000..2daafffebfc8 --- /dev/null +++ b/src/py/flwr/proto/common_pb2_grpc.py @@ -0,0 +1,4 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc + diff --git a/src/py/flwr/proto/common_pb2_grpc.pyi b/src/py/flwr/proto/common_pb2_grpc.pyi new file mode 100644 index 000000000000..f3a5a087ef5d --- /dev/null +++ b/src/py/flwr/proto/common_pb2_grpc.pyi @@ -0,0 +1,4 @@ +""" +@generated by mypy-protobuf. Do not edit manually! +isort:skip_file +""" diff --git a/src/py/flwr/proto/driver_pb2.py b/src/py/flwr/proto/driver_pb2.py index a2458b445563..07975937328d 100644 --- a/src/py/flwr/proto/driver_pb2.py +++ b/src/py/flwr/proto/driver_pb2.py @@ -17,29 +17,33 @@ from flwr.proto import run_pb2 as flwr_dot_proto_dot_run__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x17\x66lwr/proto/driver.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x15\x66lwr/proto/task.proto\x1a\x14\x66lwr/proto/run.proto\"7\n\x10\x43reateRunRequest\x12\x0e\n\x06\x66\x61\x62_id\x18\x01 \x01(\t\x12\x13\n\x0b\x66\x61\x62_version\x18\x02 \x01(\t\"#\n\x11\x43reateRunResponse\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\"!\n\x0fGetNodesRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\"3\n\x10GetNodesResponse\x12\x1f\n\x05nodes\x18\x01 \x03(\x0b\x32\x10.flwr.proto.Node\"@\n\x12PushTaskInsRequest\x12*\n\rtask_ins_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskIns\"\'\n\x13PushTaskInsResponse\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"F\n\x12PullTaskResRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"A\n\x13PullTaskResResponse\x12*\n\rtask_res_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskRes2\x84\x03\n\x06\x44river\x12J\n\tCreateRun\x12\x1c.flwr.proto.CreateRunRequest\x1a\x1d.flwr.proto.CreateRunResponse\"\x00\x12G\n\x08GetNodes\x12\x1b.flwr.proto.GetNodesRequest\x1a\x1c.flwr.proto.GetNodesResponse\"\x00\x12P\n\x0bPushTaskIns\x12\x1e.flwr.proto.PushTaskInsRequest\x1a\x1f.flwr.proto.PushTaskInsResponse\"\x00\x12P\n\x0bPullTaskRes\x12\x1e.flwr.proto.PullTaskResRequest\x1a\x1f.flwr.proto.PullTaskResResponse\"\x00\x12\x41\n\x06GetRun\x12\x19.flwr.proto.GetRunRequest\x1a\x1a.flwr.proto.GetRunResponse\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x17\x66lwr/proto/driver.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x15\x66lwr/proto/task.proto\x1a\x14\x66lwr/proto/run.proto\"\xb9\x01\n\x10\x43reateRunRequest\x12\x0e\n\x06\x66\x61\x62_id\x18\x01 \x01(\t\x12\x13\n\x0b\x66\x61\x62_version\x18\x02 \x01(\t\x12I\n\x0foverride_config\x18\x03 \x03(\x0b\x32\x30.flwr.proto.CreateRunRequest.OverrideConfigEntry\x1a\x35\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"#\n\x11\x43reateRunResponse\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\"!\n\x0fGetNodesRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\"3\n\x10GetNodesResponse\x12\x1f\n\x05nodes\x18\x01 \x03(\x0b\x32\x10.flwr.proto.Node\"@\n\x12PushTaskInsRequest\x12*\n\rtask_ins_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskIns\"\'\n\x13PushTaskInsResponse\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"F\n\x12PullTaskResRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"A\n\x13PullTaskResResponse\x12*\n\rtask_res_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskRes2\x84\x03\n\x06\x44river\x12J\n\tCreateRun\x12\x1c.flwr.proto.CreateRunRequest\x1a\x1d.flwr.proto.CreateRunResponse\"\x00\x12G\n\x08GetNodes\x12\x1b.flwr.proto.GetNodesRequest\x1a\x1c.flwr.proto.GetNodesResponse\"\x00\x12P\n\x0bPushTaskIns\x12\x1e.flwr.proto.PushTaskInsRequest\x1a\x1f.flwr.proto.PushTaskInsResponse\"\x00\x12P\n\x0bPullTaskRes\x12\x1e.flwr.proto.PullTaskResRequest\x1a\x1f.flwr.proto.PullTaskResResponse\"\x00\x12\x41\n\x06GetRun\x12\x19.flwr.proto.GetRunRequest\x1a\x1a.flwr.proto.GetRunResponse\"\x00\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.driver_pb2', _globals) if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None - _globals['_CREATERUNREQUEST']._serialized_start=107 - _globals['_CREATERUNREQUEST']._serialized_end=162 - _globals['_CREATERUNRESPONSE']._serialized_start=164 - _globals['_CREATERUNRESPONSE']._serialized_end=199 - _globals['_GETNODESREQUEST']._serialized_start=201 - _globals['_GETNODESREQUEST']._serialized_end=234 - _globals['_GETNODESRESPONSE']._serialized_start=236 - _globals['_GETNODESRESPONSE']._serialized_end=287 - _globals['_PUSHTASKINSREQUEST']._serialized_start=289 - _globals['_PUSHTASKINSREQUEST']._serialized_end=353 - _globals['_PUSHTASKINSRESPONSE']._serialized_start=355 - _globals['_PUSHTASKINSRESPONSE']._serialized_end=394 - _globals['_PULLTASKRESREQUEST']._serialized_start=396 - _globals['_PULLTASKRESREQUEST']._serialized_end=466 - _globals['_PULLTASKRESRESPONSE']._serialized_start=468 - _globals['_PULLTASKRESRESPONSE']._serialized_end=533 - _globals['_DRIVER']._serialized_start=536 - _globals['_DRIVER']._serialized_end=924 + _globals['_CREATERUNREQUEST_OVERRIDECONFIGENTRY']._options = None + _globals['_CREATERUNREQUEST_OVERRIDECONFIGENTRY']._serialized_options = b'8\001' + _globals['_CREATERUNREQUEST']._serialized_start=108 + _globals['_CREATERUNREQUEST']._serialized_end=293 + _globals['_CREATERUNREQUEST_OVERRIDECONFIGENTRY']._serialized_start=240 + _globals['_CREATERUNREQUEST_OVERRIDECONFIGENTRY']._serialized_end=293 + _globals['_CREATERUNRESPONSE']._serialized_start=295 + _globals['_CREATERUNRESPONSE']._serialized_end=330 + _globals['_GETNODESREQUEST']._serialized_start=332 + _globals['_GETNODESREQUEST']._serialized_end=365 + _globals['_GETNODESRESPONSE']._serialized_start=367 + _globals['_GETNODESRESPONSE']._serialized_end=418 + _globals['_PUSHTASKINSREQUEST']._serialized_start=420 + _globals['_PUSHTASKINSREQUEST']._serialized_end=484 + _globals['_PUSHTASKINSRESPONSE']._serialized_start=486 + _globals['_PUSHTASKINSRESPONSE']._serialized_end=525 + _globals['_PULLTASKRESREQUEST']._serialized_start=527 + _globals['_PULLTASKRESREQUEST']._serialized_end=597 + _globals['_PULLTASKRESRESPONSE']._serialized_start=599 + _globals['_PULLTASKRESRESPONSE']._serialized_end=664 + _globals['_DRIVER']._serialized_start=667 + _globals['_DRIVER']._serialized_end=1055 # @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/driver_pb2.pyi b/src/py/flwr/proto/driver_pb2.pyi index 2d8d11fb59a3..95d4c9785ff1 100644 --- a/src/py/flwr/proto/driver_pb2.pyi +++ b/src/py/flwr/proto/driver_pb2.pyi @@ -16,16 +16,33 @@ DESCRIPTOR: google.protobuf.descriptor.FileDescriptor class CreateRunRequest(google.protobuf.message.Message): """CreateRun""" DESCRIPTOR: google.protobuf.descriptor.Descriptor + class OverrideConfigEntry(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + KEY_FIELD_NUMBER: builtins.int + VALUE_FIELD_NUMBER: builtins.int + key: typing.Text + value: typing.Text + def __init__(self, + *, + key: typing.Text = ..., + value: typing.Text = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ... + FAB_ID_FIELD_NUMBER: builtins.int FAB_VERSION_FIELD_NUMBER: builtins.int + OVERRIDE_CONFIG_FIELD_NUMBER: builtins.int fab_id: typing.Text fab_version: typing.Text + @property + def override_config(self) -> google.protobuf.internal.containers.ScalarMap[typing.Text, typing.Text]: ... def __init__(self, *, fab_id: typing.Text = ..., fab_version: typing.Text = ..., + override_config: typing.Optional[typing.Mapping[typing.Text, typing.Text]] = ..., ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["fab_id",b"fab_id","fab_version",b"fab_version"]) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["fab_id",b"fab_id","fab_version",b"fab_version","override_config",b"override_config"]) -> None: ... global___CreateRunRequest = CreateRunRequest class CreateRunResponse(google.protobuf.message.Message): diff --git a/src/py/flwr/proto/exec_pb2.py b/src/py/flwr/proto/exec_pb2.py index 7b037a9454c0..4aee0f4a882f 100644 --- a/src/py/flwr/proto/exec_pb2.py +++ b/src/py/flwr/proto/exec_pb2.py @@ -14,21 +14,25 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/exec.proto\x12\nflwr.proto\"#\n\x0fStartRunRequest\x12\x10\n\x08\x66\x61\x62_file\x18\x01 \x01(\x0c\"\"\n\x10StartRunResponse\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\"#\n\x11StreamLogsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\"(\n\x12StreamLogsResponse\x12\x12\n\nlog_output\x18\x01 \x01(\t2\xa0\x01\n\x04\x45xec\x12G\n\x08StartRun\x12\x1b.flwr.proto.StartRunRequest\x1a\x1c.flwr.proto.StartRunResponse\"\x00\x12O\n\nStreamLogs\x12\x1d.flwr.proto.StreamLogsRequest\x1a\x1e.flwr.proto.StreamLogsResponse\"\x00\x30\x01\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/exec.proto\x12\nflwr.proto\"\xa4\x01\n\x0fStartRunRequest\x12\x10\n\x08\x66\x61\x62_file\x18\x01 \x01(\x0c\x12H\n\x0foverride_config\x18\x02 \x03(\x0b\x32/.flwr.proto.StartRunRequest.OverrideConfigEntry\x1a\x35\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\"\n\x10StartRunResponse\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\"#\n\x11StreamLogsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\"(\n\x12StreamLogsResponse\x12\x12\n\nlog_output\x18\x01 \x01(\t2\xa0\x01\n\x04\x45xec\x12G\n\x08StartRun\x12\x1b.flwr.proto.StartRunRequest\x1a\x1c.flwr.proto.StartRunResponse\"\x00\x12O\n\nStreamLogs\x12\x1d.flwr.proto.StreamLogsRequest\x1a\x1e.flwr.proto.StreamLogsResponse\"\x00\x30\x01\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.exec_pb2', _globals) if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None - _globals['_STARTRUNREQUEST']._serialized_start=37 - _globals['_STARTRUNREQUEST']._serialized_end=72 - _globals['_STARTRUNRESPONSE']._serialized_start=74 - _globals['_STARTRUNRESPONSE']._serialized_end=108 - _globals['_STREAMLOGSREQUEST']._serialized_start=110 - _globals['_STREAMLOGSREQUEST']._serialized_end=145 - _globals['_STREAMLOGSRESPONSE']._serialized_start=147 - _globals['_STREAMLOGSRESPONSE']._serialized_end=187 - _globals['_EXEC']._serialized_start=190 - _globals['_EXEC']._serialized_end=350 + _globals['_STARTRUNREQUEST_OVERRIDECONFIGENTRY']._options = None + _globals['_STARTRUNREQUEST_OVERRIDECONFIGENTRY']._serialized_options = b'8\001' + _globals['_STARTRUNREQUEST']._serialized_start=38 + _globals['_STARTRUNREQUEST']._serialized_end=202 + _globals['_STARTRUNREQUEST_OVERRIDECONFIGENTRY']._serialized_start=149 + _globals['_STARTRUNREQUEST_OVERRIDECONFIGENTRY']._serialized_end=202 + _globals['_STARTRUNRESPONSE']._serialized_start=204 + _globals['_STARTRUNRESPONSE']._serialized_end=238 + _globals['_STREAMLOGSREQUEST']._serialized_start=240 + _globals['_STREAMLOGSREQUEST']._serialized_end=275 + _globals['_STREAMLOGSRESPONSE']._serialized_start=277 + _globals['_STREAMLOGSRESPONSE']._serialized_end=317 + _globals['_EXEC']._serialized_start=320 + _globals['_EXEC']._serialized_end=480 # @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/exec_pb2.pyi b/src/py/flwr/proto/exec_pb2.pyi index 466812808da8..8065fc1de1b4 100644 --- a/src/py/flwr/proto/exec_pb2.pyi +++ b/src/py/flwr/proto/exec_pb2.pyi @@ -4,6 +4,7 @@ isort:skip_file """ import builtins import google.protobuf.descriptor +import google.protobuf.internal.containers import google.protobuf.message import typing import typing_extensions @@ -12,13 +13,30 @@ DESCRIPTOR: google.protobuf.descriptor.FileDescriptor class StartRunRequest(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor + class OverrideConfigEntry(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + KEY_FIELD_NUMBER: builtins.int + VALUE_FIELD_NUMBER: builtins.int + key: typing.Text + value: typing.Text + def __init__(self, + *, + key: typing.Text = ..., + value: typing.Text = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ... + FAB_FILE_FIELD_NUMBER: builtins.int + OVERRIDE_CONFIG_FIELD_NUMBER: builtins.int fab_file: builtins.bytes + @property + def override_config(self) -> google.protobuf.internal.containers.ScalarMap[typing.Text, typing.Text]: ... def __init__(self, *, fab_file: builtins.bytes = ..., + override_config: typing.Optional[typing.Mapping[typing.Text, typing.Text]] = ..., ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["fab_file",b"fab_file"]) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["fab_file",b"fab_file","override_config",b"override_config"]) -> None: ... global___StartRunRequest = StartRunRequest class StartRunResponse(google.protobuf.message.Message): diff --git a/src/py/flwr/proto/run_pb2.py b/src/py/flwr/proto/run_pb2.py index 13f06e7169aa..d6531201f647 100644 --- a/src/py/flwr/proto/run_pb2.py +++ b/src/py/flwr/proto/run_pb2.py @@ -14,17 +14,21 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x14\x66lwr/proto/run.proto\x12\nflwr.proto\":\n\x03Run\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\x12\x0e\n\x06\x66\x61\x62_id\x18\x02 \x01(\t\x12\x13\n\x0b\x66\x61\x62_version\x18\x03 \x01(\t\"\x1f\n\rGetRunRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\".\n\x0eGetRunResponse\x12\x1c\n\x03run\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Runb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x14\x66lwr/proto/run.proto\x12\nflwr.proto\"\xaf\x01\n\x03Run\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\x12\x0e\n\x06\x66\x61\x62_id\x18\x02 \x01(\t\x12\x13\n\x0b\x66\x61\x62_version\x18\x03 \x01(\t\x12<\n\x0foverride_config\x18\x04 \x03(\x0b\x32#.flwr.proto.Run.OverrideConfigEntry\x1a\x35\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\x1f\n\rGetRunRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\".\n\x0eGetRunResponse\x12\x1c\n\x03run\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Runb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.run_pb2', _globals) if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None - _globals['_RUN']._serialized_start=36 - _globals['_RUN']._serialized_end=94 - _globals['_GETRUNREQUEST']._serialized_start=96 - _globals['_GETRUNREQUEST']._serialized_end=127 - _globals['_GETRUNRESPONSE']._serialized_start=129 - _globals['_GETRUNRESPONSE']._serialized_end=175 + _globals['_RUN_OVERRIDECONFIGENTRY']._options = None + _globals['_RUN_OVERRIDECONFIGENTRY']._serialized_options = b'8\001' + _globals['_RUN']._serialized_start=37 + _globals['_RUN']._serialized_end=212 + _globals['_RUN_OVERRIDECONFIGENTRY']._serialized_start=159 + _globals['_RUN_OVERRIDECONFIGENTRY']._serialized_end=212 + _globals['_GETRUNREQUEST']._serialized_start=214 + _globals['_GETRUNREQUEST']._serialized_end=245 + _globals['_GETRUNRESPONSE']._serialized_start=247 + _globals['_GETRUNRESPONSE']._serialized_end=293 # @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/run_pb2.pyi b/src/py/flwr/proto/run_pb2.pyi index 401d27855a41..3c58c04c1734 100644 --- a/src/py/flwr/proto/run_pb2.pyi +++ b/src/py/flwr/proto/run_pb2.pyi @@ -4,6 +4,7 @@ isort:skip_file """ import builtins import google.protobuf.descriptor +import google.protobuf.internal.containers import google.protobuf.message import typing import typing_extensions @@ -12,19 +13,36 @@ DESCRIPTOR: google.protobuf.descriptor.FileDescriptor class Run(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor + class OverrideConfigEntry(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + KEY_FIELD_NUMBER: builtins.int + VALUE_FIELD_NUMBER: builtins.int + key: typing.Text + value: typing.Text + def __init__(self, + *, + key: typing.Text = ..., + value: typing.Text = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ... + RUN_ID_FIELD_NUMBER: builtins.int FAB_ID_FIELD_NUMBER: builtins.int FAB_VERSION_FIELD_NUMBER: builtins.int + OVERRIDE_CONFIG_FIELD_NUMBER: builtins.int run_id: builtins.int fab_id: typing.Text fab_version: typing.Text + @property + def override_config(self) -> google.protobuf.internal.containers.ScalarMap[typing.Text, typing.Text]: ... def __init__(self, *, run_id: builtins.int = ..., fab_id: typing.Text = ..., fab_version: typing.Text = ..., + override_config: typing.Optional[typing.Mapping[typing.Text, typing.Text]] = ..., ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["fab_id",b"fab_id","fab_version",b"fab_version","run_id",b"run_id"]) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["fab_id",b"fab_id","fab_version",b"fab_version","override_config",b"override_config","run_id",b"run_id"]) -> None: ... global___Run = Run class GetRunRequest(google.protobuf.message.Message): From 8587e73442c088d6d45122b7f413519a6fe37c94 Mon Sep 17 00:00:00 2001 From: Charles Beauville Date: Sun, 7 Jul 2024 14:46:47 +0200 Subject: [PATCH 2/8] feat(framework) Add override config to Run --- src/py/flwr/client/app.py | 25 +++++++--- .../client/grpc_adapter_client/connection.py | 3 +- src/py/flwr/client/grpc_client/connection.py | 3 +- .../client/grpc_rere_client/connection.py | 12 +++-- src/py/flwr/client/rest_client/connection.py | 14 ++++-- src/py/flwr/server/driver/grpc_driver.py | 1 + .../server/driver/inmemory_driver_test.py | 10 ++-- .../superlink/driver/driver_servicer.py | 6 ++- .../grpc_rere/server_interceptor_test.py | 4 +- .../superlink/fleet/vce/vce_api_test.py | 4 +- .../server/superlink/state/in_memory_state.py | 12 ++++- .../server/superlink/state/sqlite_state.py | 23 +++++++-- src/py/flwr/server/superlink/state/state.py | 9 +++- .../flwr/server/superlink/state/state_test.py | 47 ++++++++++--------- src/py/flwr/simulation/run_simulation.py | 2 +- 15 files changed, 120 insertions(+), 55 deletions(-) diff --git a/src/py/flwr/client/app.py b/src/py/flwr/client/app.py index d2d5a79f32f3..36757da18960 100644 --- a/src/py/flwr/client/app.py +++ b/src/py/flwr/client/app.py @@ -19,6 +19,7 @@ import time from dataclasses import dataclass from logging import DEBUG, ERROR, INFO, WARN +from pathlib import Path from typing import Callable, ContextManager, Dict, Optional, Tuple, Type, Union from cryptography.hazmat.primitives.asymmetric import ec @@ -29,6 +30,7 @@ from flwr.client.typing import ClientFnExt from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, Message, event from flwr.common.address import parse_address +from flwr.common.config import get_fused_config from flwr.common.constant import ( MISSING_EXTRA_REST, TRANSPORT_TYPE_GRPC_ADAPTER, @@ -41,6 +43,7 @@ from flwr.common.logger import log, warn_deprecated_feature from flwr.common.message import Error from flwr.common.retry_invoker import RetryInvoker, RetryState, exponential +from flwr.common.typing import Run from .grpc_adapter_client.connection import grpc_adapter from .grpc_client.connection import grpc_connection @@ -192,6 +195,7 @@ def _start_client_internal( max_retries: Optional[int] = None, max_wait_time: Optional[float] = None, partition_id: Optional[int] = None, + flwr_dir: Optional[Path] = None, ) -> None: """Start a Flower client node which connects to a Flower server. @@ -235,9 +239,16 @@ class `flwr.client.Client` (default: None) The maximum duration before the client stops trying to connect to the server in case of connection error. If set to None, there is no limit to the total time. - partitioni_id: Optional[int] (default: None) + partition_id: Optional[int] (default: None) The data partition index associated with this node. Better suited for prototyping purposes. + flwr_dir: Optional[Path] (default: None) + The path containing installed Flower Apps. + By default, this value is equal to: + + - `$FLWR_HOME/` if `$FLWR_HOME` is defined + - `$XDG_DATA_HOME/.flwr/` if `$XDG_DATA_HOME` is defined + - `$HOME/.flwr/` in all other cases """ if insecure is None: insecure = root_certificates is None @@ -315,8 +326,7 @@ def _on_backoff(retry_state: RetryState) -> None: ) node_state = NodeState(partition_id=partition_id) - # run_id -> (fab_id, fab_version) - run_info: Dict[int, Tuple[str, str]] = {} + run_info: Dict[int, Run] = {} while not app_state_tracker.interrupt: sleep_duration: int = 0 @@ -371,13 +381,14 @@ def _on_backoff(retry_state: RetryState) -> None: run_info[run_id] = get_run(run_id) # If get_run is None, i.e., in grpc-bidi mode else: - run_info[run_id] = ("", "") + run_info[run_id] = Run(run_id, "", "", {}) # Register context for this run node_state.register_context(run_id=run_id) # Retrieve context for this run context = node_state.retrieve_context(run_id=run_id) + context.run_config = get_fused_config(run_info[run_id], flwr_dir) # Create an error reply message that will never be used to prevent # the used-before-assignment linting error @@ -388,7 +399,9 @@ def _on_backoff(retry_state: RetryState) -> None: # Handle app loading and task message try: # Load ClientApp instance - client_app: ClientApp = load_client_app_fn(*run_info[run_id]) + client_app: ClientApp = load_client_app_fn( + run_info[run_id].fab_id, run_info[run_id].fab_version + ) # Execute ClientApp reply_message = client_app(message=message, context=context) @@ -573,7 +586,7 @@ def _init_connection(transport: Optional[str], server_address: str) -> Tuple[ Callable[[Message], None], Optional[Callable[[], None]], Optional[Callable[[], None]], - Optional[Callable[[int], Tuple[str, str]]], + Optional[Callable[[int], Run]], ] ], ], diff --git a/src/py/flwr/client/grpc_adapter_client/connection.py b/src/py/flwr/client/grpc_adapter_client/connection.py index e4e32b3accd0..971b630e470b 100644 --- a/src/py/flwr/client/grpc_adapter_client/connection.py +++ b/src/py/flwr/client/grpc_adapter_client/connection.py @@ -27,6 +27,7 @@ from flwr.common.logger import log from flwr.common.message import Message from flwr.common.retry_invoker import RetryInvoker +from flwr.common.typing import Run @contextmanager @@ -45,7 +46,7 @@ def grpc_adapter( # pylint: disable=R0913 Callable[[Message], None], Optional[Callable[[], None]], Optional[Callable[[], None]], - Optional[Callable[[int], Tuple[str, str]]], + Optional[Callable[[int], Run]], ] ]: """Primitives for request/response-based interaction with a server via GrpcAdapter. diff --git a/src/py/flwr/client/grpc_client/connection.py b/src/py/flwr/client/grpc_client/connection.py index 8c049861c672..3e9f261c1ecf 100644 --- a/src/py/flwr/client/grpc_client/connection.py +++ b/src/py/flwr/client/grpc_client/connection.py @@ -38,6 +38,7 @@ from flwr.common.grpc import create_channel from flwr.common.logger import log from flwr.common.retry_invoker import RetryInvoker +from flwr.common.typing import Run from flwr.proto.transport_pb2 import ( # pylint: disable=E0611 ClientMessage, Reason, @@ -73,7 +74,7 @@ def grpc_connection( # pylint: disable=R0913, R0915 Callable[[Message], None], Optional[Callable[[], None]], Optional[Callable[[], None]], - Optional[Callable[[int], Tuple[str, str]]], + Optional[Callable[[int], Run]], ] ]: """Establish a gRPC connection to a gRPC server. diff --git a/src/py/flwr/client/grpc_rere_client/connection.py b/src/py/flwr/client/grpc_rere_client/connection.py index 34dc0e417383..8062ce28fcc7 100644 --- a/src/py/flwr/client/grpc_rere_client/connection.py +++ b/src/py/flwr/client/grpc_rere_client/connection.py @@ -41,6 +41,7 @@ from flwr.common.message import Message, Metadata from flwr.common.retry_invoker import RetryInvoker from flwr.common.serde import message_from_taskins, message_to_taskres +from flwr.common.typing import Run from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611 CreateNodeRequest, DeleteNodeRequest, @@ -80,7 +81,7 @@ def grpc_request_response( # pylint: disable=R0913, R0914, R0915 Callable[[Message], None], Optional[Callable[[], None]], Optional[Callable[[], None]], - Optional[Callable[[int], Tuple[str, str]]], + Optional[Callable[[int], Run]], ] ]: """Primitives for request/response-based interaction with a server. @@ -266,7 +267,7 @@ def send(message: Message) -> None: # Cleanup metadata = None - def get_run(run_id: int) -> Tuple[str, str]: + def get_run(run_id: int) -> Run: # Call FleetAPI get_run_request = GetRunRequest(run_id=run_id) get_run_response: GetRunResponse = retry_invoker.invoke( @@ -275,7 +276,12 @@ def get_run(run_id: int) -> Tuple[str, str]: ) # Return fab_id and fab_version - return get_run_response.run.fab_id, get_run_response.run.fab_version + return Run( + run_id, + get_run_response.run.fab_id, + get_run_response.run.fab_version, + dict(get_run_response.run.override_config.items()), + ) try: # Yield methods diff --git a/src/py/flwr/client/rest_client/connection.py b/src/py/flwr/client/rest_client/connection.py index db5bd7eb6770..0efa5731ae51 100644 --- a/src/py/flwr/client/rest_client/connection.py +++ b/src/py/flwr/client/rest_client/connection.py @@ -41,6 +41,7 @@ from flwr.common.message import Message, Metadata from flwr.common.retry_invoker import RetryInvoker from flwr.common.serde import message_from_taskins, message_to_taskres +from flwr.common.typing import Run from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611 CreateNodeRequest, CreateNodeResponse, @@ -91,7 +92,7 @@ def http_request_response( # pylint: disable=,R0913, R0914, R0915 Callable[[Message], None], Optional[Callable[[], None]], Optional[Callable[[], None]], - Optional[Callable[[int], Tuple[str, str]]], + Optional[Callable[[int], Run]], ] ]: """Primitives for request/response-based interaction with a server. @@ -344,16 +345,21 @@ def send(message: Message) -> None: res.results, # pylint: disable=no-member ) - def get_run(run_id: int) -> Tuple[str, str]: + def get_run(run_id: int) -> Run: # Construct the request req = GetRunRequest(run_id=run_id) # Send the request res = _request(req, GetRunResponse, PATH_GET_RUN) if res is None: - return "", "" + return Run(run_id, "", "", {}) - return res.run.fab_id, res.run.fab_version + return Run( + run_id, + res.run.fab_id, + res.run.fab_version, + dict(res.run.override_config.items()), + ) try: # Yield methods diff --git a/src/py/flwr/server/driver/grpc_driver.py b/src/py/flwr/server/driver/grpc_driver.py index e614df659e3f..ae4b3d2519fb 100644 --- a/src/py/flwr/server/driver/grpc_driver.py +++ b/src/py/flwr/server/driver/grpc_driver.py @@ -206,6 +206,7 @@ def _get_stub_and_run_id(self) -> Tuple[GrpcDriverStub, int]: run_id=res.run.run_id, fab_id=res.run.fab_id, fab_version=res.run.fab_version, + override_config=dict(res.run.override_config.items()), ) return self.stub, self._run.run_id diff --git a/src/py/flwr/server/driver/inmemory_driver_test.py b/src/py/flwr/server/driver/inmemory_driver_test.py index 0cc1c5a53e13..d0f32e830f7d 100644 --- a/src/py/flwr/server/driver/inmemory_driver_test.py +++ b/src/py/flwr/server/driver/inmemory_driver_test.py @@ -86,7 +86,10 @@ def setUp(self) -> None: for _ in range(self.num_nodes) ] self.state.get_run.return_value = Run( - run_id=61016, fab_id="mock/mock", fab_version="v1.0.0" + run_id=61016, + fab_id="mock/mock", + fab_version="v1.0.0", + override_config={"test_key": "test_value"}, ) state_factory = MagicMock(state=lambda: self.state) self.driver = InMemoryDriver(run_id=61016, state_factory=state_factory) @@ -98,6 +101,7 @@ def test_get_run(self) -> None: self.assertEqual(self.driver.run.run_id, 61016) self.assertEqual(self.driver.run.fab_id, "mock/mock") self.assertEqual(self.driver.run.fab_version, "v1.0.0") + self.assertEqual(self.driver.run.override_config["test_key"], "test_value") def test_get_nodes(self) -> None: """Test retrieval of nodes.""" @@ -223,7 +227,7 @@ def test_task_store_consistency_after_push_pull_sqlitestate(self) -> None: # Prepare state = StateFactory("").state() self.driver = InMemoryDriver( - state.create_run("", ""), MagicMock(state=lambda: state) + state.create_run("", "", {}), MagicMock(state=lambda: state) ) msg_ids, node_id = push_messages(self.driver, self.num_nodes) assert isinstance(state, SqliteState) @@ -249,7 +253,7 @@ def test_task_store_consistency_after_push_pull_inmemory_state(self) -> None: # Prepare state_factory = StateFactory(":flwr-in-memory-state:") state = state_factory.state() - self.driver = InMemoryDriver(state.create_run("", ""), state_factory) + self.driver = InMemoryDriver(state.create_run("", "", {}), state_factory) msg_ids, node_id = push_messages(self.driver, self.num_nodes) assert isinstance(state, InMemoryState) diff --git a/src/py/flwr/server/superlink/driver/driver_servicer.py b/src/py/flwr/server/superlink/driver/driver_servicer.py index 03128f02158e..7f8ded3bdb85 100644 --- a/src/py/flwr/server/superlink/driver/driver_servicer.py +++ b/src/py/flwr/server/superlink/driver/driver_servicer.py @@ -69,7 +69,11 @@ def CreateRun( """Create run ID.""" log(DEBUG, "DriverServicer.CreateRun") state: State = self.state_factory.state() - run_id = state.create_run(request.fab_id, request.fab_version) + run_id = state.create_run( + request.fab_id, + request.fab_version, + dict(request.override_config.items()), + ) return CreateRunResponse(run_id=run_id) def PushTaskIns( diff --git a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py index 01499102b7d8..798e71435585 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py +++ b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py @@ -328,7 +328,7 @@ def test_successful_get_run_with_metadata(self) -> None: self.state.create_node( ping_interval=30, public_key=public_key_to_bytes(self._client_public_key) ) - run_id = self.state.create_run("", "") + run_id = self.state.create_run("", "", {}) request = GetRunRequest(run_id=run_id) shared_secret = generate_shared_key( self._client_private_key, self._server_public_key @@ -359,7 +359,7 @@ def test_unsuccessful_get_run_with_metadata(self) -> None: self.state.create_node( ping_interval=30, public_key=public_key_to_bytes(self._client_public_key) ) - run_id = self.state.create_run("", "") + run_id = self.state.create_run("", "", {}) request = GetRunRequest(run_id=run_id) client_private_key, _ = generate_key_pairs() shared_secret = generate_shared_key(client_private_key, self._server_public_key) diff --git a/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py b/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py index df9f2cc96f95..c0bf506fd2b6 100644 --- a/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py +++ b/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py @@ -82,7 +82,9 @@ def register_messages_into_state( ) -> Dict[UUID, float]: """Register `num_messages` into the state factory.""" state: InMemoryState = state_factory.state() # type: ignore - state.run_ids[run_id] = Run(run_id=run_id, fab_id="Mock/mock", fab_version="v1.0.0") + state.run_ids[run_id] = Run( + run_id=run_id, fab_id="Mock/mock", fab_version="v1.0.0", override_config={} + ) # Artificially add TaskIns to state so they can be processed # by the Simulation Engine logic nodes_cycle = cycle(nodes_mapping.keys()) # we have more messages than supernodes diff --git a/src/py/flwr/server/superlink/state/in_memory_state.py b/src/py/flwr/server/superlink/state/in_memory_state.py index 5a4e4eb0fd9a..bc4bd4478a23 100644 --- a/src/py/flwr/server/superlink/state/in_memory_state.py +++ b/src/py/flwr/server/superlink/state/in_memory_state.py @@ -275,7 +275,12 @@ def get_node_id(self, client_public_key: bytes) -> Optional[int]: """Retrieve stored `node_id` filtered by `client_public_keys`.""" return self.public_key_to_node_id.get(client_public_key) - def create_run(self, fab_id: str, fab_version: str) -> int: + def create_run( + self, + fab_id: str, + fab_version: str, + override_config: Dict[str, str], + ) -> int: """Create a new run for the specified `fab_id` and `fab_version`.""" # Sample a random int64 as run_id with self.lock: @@ -283,7 +288,10 @@ def create_run(self, fab_id: str, fab_version: str) -> int: if run_id not in self.run_ids: self.run_ids[run_id] = Run( - run_id=run_id, fab_id=fab_id, fab_version=fab_version + run_id=run_id, + fab_id=fab_id, + fab_version=fab_version, + override_config=override_config, ) return run_id log(ERROR, "Unexpected run creation failure.") diff --git a/src/py/flwr/server/superlink/state/sqlite_state.py b/src/py/flwr/server/superlink/state/sqlite_state.py index 725f7c2dff4b..49f40653750e 100644 --- a/src/py/flwr/server/superlink/state/sqlite_state.py +++ b/src/py/flwr/server/superlink/state/sqlite_state.py @@ -18,6 +18,7 @@ import re import sqlite3 import time +from ast import literal_eval from logging import DEBUG, ERROR from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union, cast from uuid import UUID, uuid4 @@ -63,7 +64,8 @@ CREATE TABLE IF NOT EXISTS run( run_id INTEGER UNIQUE, fab_id TEXT, - fab_version TEXT + fab_version TEXT, + overrides TEXT ); """ @@ -613,7 +615,12 @@ def get_node_id(self, client_public_key: bytes) -> Optional[int]: return node_id return None - def create_run(self, fab_id: str, fab_version: str) -> int: + def create_run( + self, + fab_id: str, + fab_version: str, + override_config: Dict[str, str], + ) -> int: """Create a new run for the specified `fab_id` and `fab_version`.""" # Sample a random int64 as run_id run_id = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES) @@ -622,8 +629,11 @@ def create_run(self, fab_id: str, fab_version: str) -> int: query = "SELECT COUNT(*) FROM run WHERE run_id = ?;" # If run_id does not exist if self.query(query, (run_id,))[0]["COUNT(*)"] == 0: - query = "INSERT INTO run (run_id, fab_id, fab_version) VALUES (?, ?, ?);" - self.query(query, (run_id, fab_id, fab_version)) + query = ( + "INSERT INTO run (run_id, fab_id, fab_version, overrides)" + "VALUES (?, ?, ?, ?);" + ) + self.query(query, (run_id, fab_id, fab_version, str(override_config))) return run_id log(ERROR, "Unexpected run creation failure.") return 0 @@ -687,7 +697,10 @@ def get_run(self, run_id: int) -> Optional[Run]: try: row = self.query(query, (run_id,))[0] return Run( - run_id=run_id, fab_id=row["fab_id"], fab_version=row["fab_version"] + run_id=run_id, + fab_id=row["fab_id"], + fab_version=row["fab_version"], + override_config=literal_eval(row["overrides"]), ) except sqlite3.IntegrityError: log(ERROR, "`run_id` does not exist.") diff --git a/src/py/flwr/server/superlink/state/state.py b/src/py/flwr/server/superlink/state/state.py index 65e2c63cab69..c93f6ba756b8 100644 --- a/src/py/flwr/server/superlink/state/state.py +++ b/src/py/flwr/server/superlink/state/state.py @@ -16,7 +16,7 @@ import abc -from typing import List, Optional, Set +from typing import Dict, List, Optional, Set from uuid import UUID from flwr.common.typing import Run @@ -157,7 +157,12 @@ def get_node_id(self, client_public_key: bytes) -> Optional[int]: """Retrieve stored `node_id` filtered by `client_public_keys`.""" @abc.abstractmethod - def create_run(self, fab_id: str, fab_version: str) -> int: + def create_run( + self, + fab_id: str, + fab_version: str, + override_config: Dict[str, str], + ) -> int: """Create a new run for the specified `fab_id` and `fab_version`.""" @abc.abstractmethod diff --git a/src/py/flwr/server/superlink/state/state_test.py b/src/py/flwr/server/superlink/state/state_test.py index 373202d5cde6..5f0d23ffc4d8 100644 --- a/src/py/flwr/server/superlink/state/state_test.py +++ b/src/py/flwr/server/superlink/state/state_test.py @@ -52,7 +52,7 @@ def test_create_and_get_run(self) -> None: """Test if create_run and get_run work correctly.""" # Prepare state: State = self.state_factory() - run_id = state.create_run("Mock/mock", "v1.0.0") + run_id = state.create_run("Mock/mock", "v1.0.0", {"test_key": "test_value"}) # Execute run = state.get_run(run_id) @@ -62,6 +62,7 @@ def test_create_and_get_run(self) -> None: assert run.run_id == run_id assert run.fab_id == "Mock/mock" assert run.fab_version == "v1.0.0" + assert run.override_config["test_key"] == "test_value" def test_get_task_ins_empty(self) -> None: """Validate that a new state has no TaskIns.""" @@ -90,7 +91,7 @@ def test_store_task_ins_one(self) -> None: # Prepare consumer_node_id = 1 state = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) task_ins = create_task_ins( consumer_node_id=consumer_node_id, anonymous=False, run_id=run_id ) @@ -125,7 +126,7 @@ def test_store_and_delete_tasks(self) -> None: # Prepare consumer_node_id = 1 state = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) task_ins_0 = create_task_ins( consumer_node_id=consumer_node_id, anonymous=False, run_id=run_id ) @@ -199,7 +200,7 @@ def test_task_ins_store_anonymous_and_retrieve_anonymous(self) -> None: """ # Prepare state: State = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) task_ins = create_task_ins(consumer_node_id=0, anonymous=True, run_id=run_id) # Execute @@ -214,7 +215,7 @@ def test_task_ins_store_anonymous_and_fail_retrieving_identitiy(self) -> None: """Store anonymous TaskIns and fail to retrieve it.""" # Prepare state: State = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) task_ins = create_task_ins(consumer_node_id=0, anonymous=True, run_id=run_id) # Execute @@ -228,7 +229,7 @@ def test_task_ins_store_identity_and_fail_retrieving_anonymous(self) -> None: """Store identity TaskIns and fail retrieving it as anonymous.""" # Prepare state: State = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) task_ins = create_task_ins(consumer_node_id=1, anonymous=False, run_id=run_id) # Execute @@ -242,7 +243,7 @@ def test_task_ins_store_identity_and_retrieve_identity(self) -> None: """Store identity TaskIns and retrieve it.""" # Prepare state: State = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) task_ins = create_task_ins(consumer_node_id=1, anonymous=False, run_id=run_id) # Execute @@ -259,7 +260,7 @@ def test_task_ins_store_delivered_and_fail_retrieving(self) -> None: """Fail retrieving delivered task.""" # Prepare state: State = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) task_ins = create_task_ins(consumer_node_id=1, anonymous=False, run_id=run_id) # Execute @@ -302,7 +303,7 @@ def test_task_res_store_and_retrieve_by_task_ins_id(self) -> None: """Store TaskRes retrieve it by task_ins_id.""" # Prepare state: State = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) task_ins_id = uuid4() task_res = create_task_res( producer_node_id=0, @@ -323,7 +324,7 @@ def test_node_ids_initial_state(self) -> None: """Test retrieving all node_ids and empty initial state.""" # Prepare state: State = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) # Execute retrieved_node_ids = state.get_nodes(run_id) @@ -335,7 +336,7 @@ def test_create_node_and_get_nodes(self) -> None: """Test creating a client node.""" # Prepare state: State = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) node_ids = [] # Execute @@ -352,7 +353,7 @@ def test_create_node_public_key(self) -> None: # Prepare state: State = self.state_factory() public_key = b"mock" - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) # Execute node_id = state.create_node(ping_interval=10, public_key=public_key) @@ -368,7 +369,7 @@ def test_create_node_public_key_twice(self) -> None: # Prepare state: State = self.state_factory() public_key = b"mock" - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) node_id = state.create_node(ping_interval=10, public_key=public_key) # Execute @@ -390,7 +391,7 @@ def test_delete_node(self) -> None: """Test deleting a client node.""" # Prepare state: State = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) node_id = state.create_node(ping_interval=10) # Execute @@ -405,7 +406,7 @@ def test_delete_node_public_key(self) -> None: # Prepare state: State = self.state_factory() public_key = b"mock" - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) node_id = state.create_node(ping_interval=10, public_key=public_key) # Execute @@ -422,7 +423,7 @@ def test_delete_node_public_key_none(self) -> None: # Prepare state: State = self.state_factory() public_key = b"mock" - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) node_id = 0 # Execute & Assert @@ -441,7 +442,7 @@ def test_delete_node_wrong_public_key(self) -> None: state: State = self.state_factory() public_key = b"mock" wrong_public_key = b"mock_mock" - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) node_id = state.create_node(ping_interval=10, public_key=public_key) # Execute & Assert @@ -460,7 +461,7 @@ def test_get_node_id_wrong_public_key(self) -> None: state: State = self.state_factory() public_key = b"mock" wrong_public_key = b"mock_mock" - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) # Execute state.create_node(ping_interval=10, public_key=public_key) @@ -475,7 +476,7 @@ def test_get_nodes_invalid_run_id(self) -> None: """Test retrieving all node_ids with invalid run_id.""" # Prepare state: State = self.state_factory() - state.create_run("mock/mock", "v1.0.0") + state.create_run("mock/mock", "v1.0.0", {}) invalid_run_id = 61016 state.create_node(ping_interval=10) @@ -489,7 +490,7 @@ def test_num_task_ins(self) -> None: """Test if num_tasks returns correct number of not delivered task_ins.""" # Prepare state: State = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) task_0 = create_task_ins(consumer_node_id=0, anonymous=True, run_id=run_id) task_1 = create_task_ins(consumer_node_id=0, anonymous=True, run_id=run_id) @@ -507,7 +508,7 @@ def test_num_task_res(self) -> None: """Test if num_tasks returns correct number of not delivered task_res.""" # Prepare state: State = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) task_0 = create_task_res( producer_node_id=0, anonymous=True, ancestry=["1"], run_id=run_id ) @@ -608,7 +609,7 @@ def test_acknowledge_ping(self) -> None: """Test if acknowledge_ping works and if get_nodes return online nodes.""" # Prepare state: State = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) node_ids = [state.create_node(ping_interval=10) for _ in range(100)] for node_id in node_ids[:70]: state.acknowledge_ping(node_id, ping_interval=30) @@ -627,7 +628,7 @@ def test_node_unavailable_error(self) -> None: """Test if get_task_res return TaskRes containing node unavailable error.""" # Prepare state: State = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) node_id_0 = state.create_node(ping_interval=90) node_id_1 = state.create_node(ping_interval=30) # Create and store TaskIns diff --git a/src/py/flwr/simulation/run_simulation.py b/src/py/flwr/simulation/run_simulation.py index 7c7a412a245b..91805dc5ed7b 100644 --- a/src/py/flwr/simulation/run_simulation.py +++ b/src/py/flwr/simulation/run_simulation.py @@ -209,7 +209,7 @@ def _main_loop( serverapp_th = None try: # Create run (with empty fab_id and fab_version) - run_id_ = state_factory.state().create_run("", "") + run_id_ = state_factory.state().create_run("", "", {}) if run_id: _override_run_id(state_factory, run_id_to_replace=run_id_, run_id=run_id) From 51eae29b02ff559313684a36859dd7fb8dd51202 Mon Sep 17 00:00:00 2001 From: Charles Beauville Date: Sun, 7 Jul 2024 14:52:42 +0200 Subject: [PATCH 3/8] Add overrides to Run --- src/py/flwr/common/typing.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/py/flwr/common/typing.py b/src/py/flwr/common/typing.py index f51830955679..04d2cf5bbf7f 100644 --- a/src/py/flwr/common/typing.py +++ b/src/py/flwr/common/typing.py @@ -194,3 +194,4 @@ class Run: run_id: int fab_id: str fab_version: str + override_config: Dict[str, str] From 06d5cb09b96b4ac80670b5b264c22a95aa25a5a2 Mon Sep 17 00:00:00 2001 From: Charles Beauville Date: Sun, 7 Jul 2024 14:59:13 +0200 Subject: [PATCH 4/8] Remove non-existing function call --- src/py/flwr/client/app.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/py/flwr/client/app.py b/src/py/flwr/client/app.py index 36757da18960..e00dcd4f174a 100644 --- a/src/py/flwr/client/app.py +++ b/src/py/flwr/client/app.py @@ -30,7 +30,6 @@ from flwr.client.typing import ClientFnExt from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, Message, event from flwr.common.address import parse_address -from flwr.common.config import get_fused_config from flwr.common.constant import ( MISSING_EXTRA_REST, TRANSPORT_TYPE_GRPC_ADAPTER, @@ -388,7 +387,6 @@ def _on_backoff(retry_state: RetryState) -> None: # Retrieve context for this run context = node_state.retrieve_context(run_id=run_id) - context.run_config = get_fused_config(run_info[run_id], flwr_dir) # Create an error reply message that will never be used to prevent # the used-before-assignment linting error From b228b5d0814fb2f3af4bccee47d8a08e5d875d26 Mon Sep 17 00:00:00 2001 From: Charles Beauville Date: Sun, 7 Jul 2024 15:03:31 +0200 Subject: [PATCH 5/8] Remove unused argument --- src/py/flwr/client/app.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/src/py/flwr/client/app.py b/src/py/flwr/client/app.py index e00dcd4f174a..1bc02362eacf 100644 --- a/src/py/flwr/client/app.py +++ b/src/py/flwr/client/app.py @@ -194,7 +194,6 @@ def _start_client_internal( max_retries: Optional[int] = None, max_wait_time: Optional[float] = None, partition_id: Optional[int] = None, - flwr_dir: Optional[Path] = None, ) -> None: """Start a Flower client node which connects to a Flower server. @@ -241,13 +240,6 @@ class `flwr.client.Client` (default: None) partition_id: Optional[int] (default: None) The data partition index associated with this node. Better suited for prototyping purposes. - flwr_dir: Optional[Path] (default: None) - The path containing installed Flower Apps. - By default, this value is equal to: - - - `$FLWR_HOME/` if `$FLWR_HOME` is defined - - `$XDG_DATA_HOME/.flwr/` if `$XDG_DATA_HOME` is defined - - `$HOME/.flwr/` in all other cases """ if insecure is None: insecure = root_certificates is None From 8898eae90f00d9ac11dd766e818b47a6f6f91b87 Mon Sep 17 00:00:00 2001 From: Charles Beauville Date: Sun, 7 Jul 2024 15:04:30 +0200 Subject: [PATCH 6/8] Remove unused import --- src/py/flwr/client/app.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/py/flwr/client/app.py b/src/py/flwr/client/app.py index 1bc02362eacf..646b75e2d0a2 100644 --- a/src/py/flwr/client/app.py +++ b/src/py/flwr/client/app.py @@ -19,7 +19,6 @@ import time from dataclasses import dataclass from logging import DEBUG, ERROR, INFO, WARN -from pathlib import Path from typing import Callable, ContextManager, Dict, Optional, Tuple, Type, Union from cryptography.hazmat.primitives.asymmetric import ec From 21c83a175b54b69a07b5d98a10f0b87ec4f66e1a Mon Sep 17 00:00:00 2001 From: Charles Beauville Date: Mon, 8 Jul 2024 15:37:47 +0200 Subject: [PATCH 7/8] More consistent naming --- src/py/flwr/server/superlink/state/sqlite_state.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/py/flwr/server/superlink/state/sqlite_state.py b/src/py/flwr/server/superlink/state/sqlite_state.py index 49f40653750e..1b07b26f75ca 100644 --- a/src/py/flwr/server/superlink/state/sqlite_state.py +++ b/src/py/flwr/server/superlink/state/sqlite_state.py @@ -62,10 +62,10 @@ SQL_CREATE_TABLE_RUN = """ CREATE TABLE IF NOT EXISTS run( - run_id INTEGER UNIQUE, - fab_id TEXT, - fab_version TEXT, - overrides TEXT + run_id INTEGER UNIQUE, + fab_id TEXT, + fab_version TEXT, + override_config TEXT ); """ @@ -630,7 +630,7 @@ def create_run( # If run_id does not exist if self.query(query, (run_id,))[0]["COUNT(*)"] == 0: query = ( - "INSERT INTO run (run_id, fab_id, fab_version, overrides)" + "INSERT INTO run (run_id, fab_id, fab_version, override_config)" "VALUES (?, ?, ?, ?);" ) self.query(query, (run_id, fab_id, fab_version, str(override_config))) @@ -700,7 +700,7 @@ def get_run(self, run_id: int) -> Optional[Run]: run_id=run_id, fab_id=row["fab_id"], fab_version=row["fab_version"], - override_config=literal_eval(row["overrides"]), + override_config=literal_eval(row["override_config"]), ) except sqlite3.IntegrityError: log(ERROR, "`run_id` does not exist.") From e45ffea15124819e3c9d426f15590b7459e2f603 Mon Sep 17 00:00:00 2001 From: Charles Beauville Date: Mon, 8 Jul 2024 16:16:24 +0200 Subject: [PATCH 8/8] Use JSON for sqlite state --- src/py/flwr/server/superlink/state/sqlite_state.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/py/flwr/server/superlink/state/sqlite_state.py b/src/py/flwr/server/superlink/state/sqlite_state.py index 1b07b26f75ca..ea6f349b9f9a 100644 --- a/src/py/flwr/server/superlink/state/sqlite_state.py +++ b/src/py/flwr/server/superlink/state/sqlite_state.py @@ -15,10 +15,10 @@ """SQLite based implemenation of server state.""" +import json import re import sqlite3 import time -from ast import literal_eval from logging import DEBUG, ERROR from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union, cast from uuid import UUID, uuid4 @@ -633,7 +633,9 @@ def create_run( "INSERT INTO run (run_id, fab_id, fab_version, override_config)" "VALUES (?, ?, ?, ?);" ) - self.query(query, (run_id, fab_id, fab_version, str(override_config))) + self.query( + query, (run_id, fab_id, fab_version, json.dumps(override_config)) + ) return run_id log(ERROR, "Unexpected run creation failure.") return 0 @@ -700,7 +702,7 @@ def get_run(self, run_id: int) -> Optional[Run]: run_id=run_id, fab_id=row["fab_id"], fab_version=row["fab_version"], - override_config=literal_eval(row["override_config"]), + override_config=json.loads(row["override_config"]), ) except sqlite3.IntegrityError: log(ERROR, "`run_id` does not exist.")