From 33a9d083304db0d57951f579aac9bba11b358707 Mon Sep 17 00:00:00 2001 From: Charles Beauville Date: Thu, 15 Aug 2024 18:52:44 +0200 Subject: [PATCH] Add message_handler --- .../fleet/message_handler/message_handler.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/src/py/flwr/server/superlink/fleet/message_handler/message_handler.py b/src/py/flwr/server/superlink/fleet/message_handler/message_handler.py index 30865f04d373..80ac1c406756 100644 --- a/src/py/flwr/server/superlink/fleet/message_handler/message_handler.py +++ b/src/py/flwr/server/superlink/fleet/message_handler/message_handler.py @@ -19,7 +19,9 @@ from typing import List, Optional from uuid import UUID -from flwr.common.serde import user_config_to_proto +from flwr.common.serde import fab_to_proto, user_config_to_proto +from flwr.common.typing import Fab +from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611 from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611 CreateNodeRequest, CreateNodeResponse, @@ -40,6 +42,7 @@ Run, ) from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611 +from flwr.server.superlink.ffs.ffs import Ffs from flwr.server.superlink.state import State @@ -124,5 +127,14 @@ def get_run( fab_id=run.fab_id, fab_version=run.fab_version, override_config=user_config_to_proto(run.override_config), + fab_hash=run.fab_hash, ) ) + + +def get_fab( + request: GetFabRequest, ffs: Ffs # pylint: disable=W0613 +) -> GetFabResponse: + """Get FAB.""" + fab = Fab(request.hash_str, ffs.get(request.hash_str)[0]) + return GetFabResponse(fab=fab_to_proto(fab))