diff --git a/src/proto/flwr/proto/task.proto b/src/proto/flwr/proto/task.proto index 25e65e59cedc..cf77d110acab 100644 --- a/src/proto/flwr/proto/task.proto +++ b/src/proto/flwr/proto/task.proto @@ -25,7 +25,7 @@ import "flwr/proto/error.proto"; message Task { Node producer = 1; Node consumer = 2; - string created_at = 3; + double created_at = 3; string delivered_at = 4; double pushed_at = 5; double ttl = 6; diff --git a/src/py/flwr/client/message_handler/message_handler.py b/src/py/flwr/client/message_handler/message_handler.py index 87014f436cf7..e5acbe0cc9d0 100644 --- a/src/py/flwr/client/message_handler/message_handler.py +++ b/src/py/flwr/client/message_handler/message_handler.py @@ -172,6 +172,7 @@ def validate_out_message(out_message: Message, in_message_metadata: Metadata) -> and out_meta.reply_to_message == in_meta.message_id and out_meta.group_id == in_meta.group_id and out_meta.message_type == in_meta.message_type + and out_meta.created_at > in_meta.created_at ): return True return False diff --git a/src/py/flwr/client/message_handler/message_handler_test.py b/src/py/flwr/client/message_handler/message_handler_test.py index e3f6487421cc..2a510b291c49 100644 --- a/src/py/flwr/client/message_handler/message_handler_test.py +++ b/src/py/flwr/client/message_handler/message_handler_test.py @@ -15,6 +15,7 @@ """Client-side message handler tests.""" +import time import unittest import uuid from copy import copy @@ -169,7 +170,18 @@ def test_client_without_get_properties() -> None: ) assert actual_msg.content == expected_msg.content - assert actual_msg.metadata == expected_msg.metadata + # metadata.created_at will differ so let's exclude it from checks + attrs = actual_msg.metadata.__annotations__ + attrs_keys = list(attrs.keys()) + attrs_keys.remove("_created_at") + # metadata.created_at will differ so let's exclude it from checks + for attr in attrs_keys: + assert getattr(actual_msg.metadata, attr) == getattr( + expected_msg.metadata, attr + ) + + # Ensure the message created last has a higher timestamp + assert actual_msg.metadata.created_at < expected_msg.metadata.created_at def test_client_with_get_properties() -> None: @@ -222,7 +234,17 @@ def test_client_with_get_properties() -> None: ) assert actual_msg.content == expected_msg.content - assert actual_msg.metadata == expected_msg.metadata + attrs = actual_msg.metadata.__annotations__ + attrs_keys = list(attrs.keys()) + attrs_keys.remove("_created_at") + # metadata.created_at will differ so let's exclude it from checks + for attr in attrs_keys: + assert getattr(actual_msg.metadata, attr) == getattr( + expected_msg.metadata, attr + ) + + # Ensure the message created last has a higher timestamp + assert actual_msg.metadata.created_at < expected_msg.metadata.created_at class TestMessageValidation(unittest.TestCase): @@ -241,6 +263,11 @@ def setUp(self) -> None: ttl=DEFAULT_TTL, message_type="mock", ) + # We need to set created_at in this way + # since this `self.in_metadata` is used for tests + # without it ever being part of a Message + self.in_metadata.created_at = time.time() + self.valid_out_metadata = Metadata( run_id=123, message_id="", @@ -281,6 +308,10 @@ def test_invalid_message_run_id(self) -> None: value = 999 elif isinstance(value, str): value = "999" + elif isinstance(value, float): + if attr == "_created_at": + # make it be in 1h the past + value = value - 3600 setattr(invalid_metadata, attr, value) # Add to list invalid_metadata_list.append(invalid_metadata) diff --git a/src/py/flwr/common/message.py b/src/py/flwr/common/message.py index 25607179764d..6e0ab9149828 100644 --- a/src/py/flwr/common/message.py +++ b/src/py/flwr/common/message.py @@ -16,6 +16,7 @@ from __future__ import annotations +import time from dataclasses import dataclass from .record import RecordSet @@ -62,6 +63,7 @@ class Metadata: # pylint: disable=too-many-instance-attributes _ttl: float _message_type: str _partition_id: int | None + _created_at: float # Unix timestamp (in seconds) to be set upon message creation def __init__( # pylint: disable=too-many-arguments self, @@ -125,6 +127,16 @@ def group_id(self, value: str) -> None: """Set group_id.""" self._group_id = value + @property + def created_at(self) -> float: + """Unix timestamp when the message was created.""" + return self._created_at + + @created_at.setter + def created_at(self, value: float) -> None: + """Set creation timestamp for this messages.""" + self._created_at = value + @property def ttl(self) -> float: """Time-to-live for this message.""" @@ -214,6 +226,9 @@ def __init__( ) -> None: self._metadata = metadata + # Set message creation timestamp + self._metadata.created_at = time.time() + if not (content is None) ^ (error is None): raise ValueError("Either `content` or `error` must be set, but not both.") diff --git a/src/py/flwr/common/message_test.py b/src/py/flwr/common/message_test.py index ba628bb3235a..cd5a7d72272f 100644 --- a/src/py/flwr/common/message_test.py +++ b/src/py/flwr/common/message_test.py @@ -14,7 +14,7 @@ # ============================================================================== """Message tests.""" - +import time from contextlib import ExitStack from typing import Any, Callable @@ -62,12 +62,16 @@ def test_message_creation( if context: stack.enter_context(context) - _ = Message( + current_time = time.time() + message = Message( metadata=metadata, content=None if content_fn is None else content_fn(maker), error=None if error_fn is None else error_fn(0), ) + assert message.metadata.created_at > current_time + assert message.metadata.created_at < time.time() + def create_message_with_content() -> Message: """Create a Message with content.""" diff --git a/src/py/flwr/common/serde.py b/src/py/flwr/common/serde.py index 6c7a077d2f9f..84932b806aff 100644 --- a/src/py/flwr/common/serde.py +++ b/src/py/flwr/common/serde.py @@ -575,6 +575,7 @@ def message_to_taskins(message: Message) -> TaskIns: task=Task( producer=Node(node_id=0, anonymous=True), # Assume driver node consumer=Node(node_id=md.dst_node_id, anonymous=False), + created_at=md.created_at, ttl=md.ttl, ancestry=[md.reply_to_message] if md.reply_to_message != "" else [], task_type=md.message_type, @@ -601,7 +602,7 @@ def message_from_taskins(taskins: TaskIns) -> Message: ) # Construct Message - return Message( + message = Message( metadata=metadata, content=( recordset_from_proto(taskins.task.recordset) @@ -614,6 +615,8 @@ def message_from_taskins(taskins: TaskIns) -> Message: else None ), ) + message.metadata.created_at = taskins.task.created_at + return message def message_to_taskres(message: Message) -> TaskRes: @@ -626,6 +629,7 @@ def message_to_taskres(message: Message) -> TaskRes: task=Task( producer=Node(node_id=md.src_node_id, anonymous=False), consumer=Node(node_id=0, anonymous=True), # Assume driver node + created_at=md.created_at, ttl=md.ttl, ancestry=[md.reply_to_message] if md.reply_to_message != "" else [], task_type=md.message_type, @@ -652,7 +656,7 @@ def message_from_taskres(taskres: TaskRes) -> Message: ) # Construct the Message - return Message( + message = Message( metadata=metadata, content=( recordset_from_proto(taskres.task.recordset) @@ -665,3 +669,5 @@ def message_from_taskres(taskres: TaskRes) -> Message: else None ), ) + message.metadata.created_at = taskres.task.created_at + return message diff --git a/src/py/flwr/proto/task_pb2.py b/src/py/flwr/proto/task_pb2.py index 3546f01efded..5f6e9e7be583 100644 --- a/src/py/flwr/proto/task_pb2.py +++ b/src/py/flwr/proto/task_pb2.py @@ -18,7 +18,7 @@ from flwr.proto import error_pb2 as flwr_dot_proto_dot_error__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/task.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x1a\x66lwr/proto/transport.proto\x1a\x16\x66lwr/proto/error.proto\"\x89\x02\n\x04Task\x12\"\n\x08producer\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\"\n\x08\x63onsumer\x18\x02 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x12\n\ncreated_at\x18\x03 \x01(\t\x12\x14\n\x0c\x64\x65livered_at\x18\x04 \x01(\t\x12\x11\n\tpushed_at\x18\x05 \x01(\x01\x12\x0b\n\x03ttl\x18\x06 \x01(\x01\x12\x10\n\x08\x61ncestry\x18\x07 \x03(\t\x12\x11\n\ttask_type\x18\x08 \x01(\t\x12(\n\trecordset\x18\t \x01(\x0b\x32\x15.flwr.proto.RecordSet\x12 \n\x05\x65rror\x18\n \x01(\x0b\x32\x11.flwr.proto.Error\"\\\n\x07TaskIns\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Task\"\\\n\x07TaskRes\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Taskb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/task.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x1a\x66lwr/proto/transport.proto\x1a\x16\x66lwr/proto/error.proto\"\x89\x02\n\x04Task\x12\"\n\x08producer\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\"\n\x08\x63onsumer\x18\x02 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x12\n\ncreated_at\x18\x03 \x01(\x01\x12\x14\n\x0c\x64\x65livered_at\x18\x04 \x01(\t\x12\x11\n\tpushed_at\x18\x05 \x01(\x01\x12\x0b\n\x03ttl\x18\x06 \x01(\x01\x12\x10\n\x08\x61ncestry\x18\x07 \x03(\t\x12\x11\n\ttask_type\x18\x08 \x01(\t\x12(\n\trecordset\x18\t \x01(\x0b\x32\x15.flwr.proto.RecordSet\x12 \n\x05\x65rror\x18\n \x01(\x0b\x32\x11.flwr.proto.Error\"\\\n\x07TaskIns\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Task\"\\\n\x07TaskRes\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Taskb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) diff --git a/src/py/flwr/proto/task_pb2.pyi b/src/py/flwr/proto/task_pb2.pyi index 8f0549ceddc9..455791ac9e6e 100644 --- a/src/py/flwr/proto/task_pb2.pyi +++ b/src/py/flwr/proto/task_pb2.pyi @@ -30,7 +30,7 @@ class Task(google.protobuf.message.Message): def producer(self) -> flwr.proto.node_pb2.Node: ... @property def consumer(self) -> flwr.proto.node_pb2.Node: ... - created_at: typing.Text + created_at: builtins.float delivered_at: typing.Text pushed_at: builtins.float ttl: builtins.float @@ -45,7 +45,7 @@ class Task(google.protobuf.message.Message): *, producer: typing.Optional[flwr.proto.node_pb2.Node] = ..., consumer: typing.Optional[flwr.proto.node_pb2.Node] = ..., - created_at: typing.Text = ..., + created_at: builtins.float = ..., delivered_at: typing.Text = ..., pushed_at: builtins.float = ..., ttl: builtins.float = ..., diff --git a/src/py/flwr/server/compat/driver_client_proxy.py b/src/py/flwr/server/compat/driver_client_proxy.py index 99ba50d3e2d1..7fdc07d620f2 100644 --- a/src/py/flwr/server/compat/driver_client_proxy.py +++ b/src/py/flwr/server/compat/driver_client_proxy.py @@ -132,6 +132,13 @@ def _send_receive_recordset( ttl=DEFAULT_TTL, ), ) + + # This would normally be recorded upon common.Message creation + # but this compatibility stack doesn't create Messages, + # so we need to inject `created_at` manually (needed for + # taskins validation by server.utils.validator) + task_ins.task.created_at = time.time() + push_task_ins_req = driver_pb2.PushTaskInsRequest( # pylint: disable=E1101 task_ins_list=[task_ins] ) diff --git a/src/py/flwr/server/superlink/state/in_memory_state.py b/src/py/flwr/server/superlink/state/in_memory_state.py index cba4ab98a6d5..6fc57707ac36 100644 --- a/src/py/flwr/server/superlink/state/in_memory_state.py +++ b/src/py/flwr/server/superlink/state/in_memory_state.py @@ -18,7 +18,6 @@ import os import threading import time -from datetime import datetime from logging import ERROR from typing import Dict, List, Optional, Set, Tuple from uuid import UUID, uuid4 @@ -52,13 +51,11 @@ def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]: log(ERROR, "`run_id` is invalid") return None - # Create task_id and created_at + # Create task_id task_id = uuid4() - created_at: datetime = now() # Store TaskIns task_ins.task_id = str(task_id) - task_ins.task.created_at = created_at.isoformat() with self.lock: self.task_ins_store[task_id] = task_ins @@ -113,13 +110,11 @@ def store_task_res(self, task_res: TaskRes) -> Optional[UUID]: log(ERROR, "`run_id` is invalid") return None - # Create task_id and created_at + # Create task_id task_id = uuid4() - created_at: datetime = now() # Store TaskRes task_res.task_id = str(task_id) - task_res.task.created_at = created_at.isoformat() with self.lock: self.task_res_store[task_id] = task_res diff --git a/src/py/flwr/server/superlink/state/sqlite_state.py b/src/py/flwr/server/superlink/state/sqlite_state.py index e1c1215000b9..6996d51d2a9b 100644 --- a/src/py/flwr/server/superlink/state/sqlite_state.py +++ b/src/py/flwr/server/superlink/state/sqlite_state.py @@ -19,7 +19,6 @@ import re import sqlite3 import time -from datetime import datetime from logging import DEBUG, ERROR from typing import Any, Dict, List, Optional, Set, Tuple, Union, cast from uuid import UUID, uuid4 @@ -59,7 +58,7 @@ producer_node_id INTEGER, consumer_anonymous BOOLEAN, consumer_node_id INTEGER, - created_at TEXT, + created_at REAL, delivered_at TEXT, pushed_at REAL, ttl REAL, @@ -80,7 +79,7 @@ producer_node_id INTEGER, consumer_anonymous BOOLEAN, consumer_node_id INTEGER, - created_at TEXT, + created_at REAL, delivered_at TEXT, pushed_at REAL, ttl REAL, @@ -195,13 +194,11 @@ def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]: log(ERROR, errors) return None - # Create task_id and created_at + # Create task_id task_id = uuid4() - created_at: datetime = now() # Store TaskIns task_ins.task_id = str(task_id) - task_ins.task.created_at = created_at.isoformat() data = (task_ins_to_dict(task_ins),) columns = ", ".join([f":{key}" for key in data[0]]) query = f"INSERT INTO task_ins VALUES({columns});" @@ -330,11 +327,9 @@ def store_task_res(self, task_res: TaskRes) -> Optional[UUID]: # Create task_id task_id = uuid4() - created_at: datetime = now() # Store TaskIns task_res.task_id = str(task_id) - task_res.task.created_at = created_at.isoformat() data = (task_res_to_dict(task_res),) columns = ", ".join([f":{key}" for key in data[0]]) query = f"INSERT INTO task_res VALUES({columns});" diff --git a/src/py/flwr/server/superlink/state/state_test.py b/src/py/flwr/server/superlink/state/state_test.py index 6ab511d3f847..1757cfac4255 100644 --- a/src/py/flwr/server/superlink/state/state_test.py +++ b/src/py/flwr/server/superlink/state/state_test.py @@ -74,7 +74,7 @@ def test_store_task_ins_one(self) -> None: consumer_node_id=consumer_node_id, anonymous=False, run_id=run_id ) - assert task_ins.task.created_at == "" # pylint: disable=no-member + assert task_ins.task.created_at < time.time() # pylint: disable=no-member assert task_ins.task.delivered_at == "" # pylint: disable=no-member # Execute @@ -91,12 +91,9 @@ def test_store_task_ins_one(self) -> None: actual_task = actual_task_ins.task - assert actual_task.created_at != "" assert actual_task.delivered_at != "" - assert datetime.fromisoformat(actual_task.created_at) > datetime( - 2020, 1, 1, tzinfo=timezone.utc - ) + assert actual_task.created_at < actual_task.pushed_at assert datetime.fromisoformat(actual_task.delivered_at) > datetime( 2020, 1, 1, tzinfo=timezone.utc ) @@ -439,6 +436,7 @@ def create_task_ins( task_type="mock", recordset=RecordSet(parameters={}, metrics={}, configs={}), ttl=DEFAULT_TTL, + created_at=time.time(), ), ) task.task.pushed_at = time.time() @@ -463,6 +461,7 @@ def create_task_res( task_type="mock", recordset=RecordSet(parameters={}, metrics={}, configs={}), ttl=DEFAULT_TTL, + created_at=time.time(), ), ) task_res.task.pushed_at = time.time() diff --git a/src/py/flwr/server/utils/validator.py b/src/py/flwr/server/utils/validator.py index d8b287b0f674..c0b0ec85761c 100644 --- a/src/py/flwr/server/utils/validator.py +++ b/src/py/flwr/server/utils/validator.py @@ -32,8 +32,13 @@ def validate_task_ins_or_res(tasks_ins_res: Union[TaskIns, TaskRes]) -> List[str validation_errors.append("`task` does not set field `task`") # Created/delivered/TTL/Pushed - if tasks_ins_res.task.created_at != "": - validation_errors.append("`created_at` must be an empty str") + if ( + tasks_ins_res.task.created_at < 1711497600.0 + ): # unix timestamp of 27 March 2024 00h:00m:00s UTC + validation_errors.append( + "`created_at` must be a float that records the unix timestamp " + "in seconds when the message was created." + ) if tasks_ins_res.task.delivered_at != "": validation_errors.append("`delivered_at` must be an empty str") if tasks_ins_res.task.ttl <= 0: diff --git a/src/py/flwr/server/utils/validator_test.py b/src/py/flwr/server/utils/validator_test.py index c896af998bea..61fe094c23d4 100644 --- a/src/py/flwr/server/utils/validator_test.py +++ b/src/py/flwr/server/utils/validator_test.py @@ -99,6 +99,7 @@ def create_task_ins( task_type="mock", recordset=RecordSet(parameters={}, metrics={}, configs={}), ttl=DEFAULT_TTL, + created_at=time.time(), ), ) @@ -123,6 +124,7 @@ def create_task_res( task_type="mock", recordset=RecordSet(parameters={}, metrics={}, configs={}), ttl=DEFAULT_TTL, + created_at=time.time(), ), )