Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add created_at to Metadata #3174

Merged
merged 22 commits into from
Mar 27, 2024
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/proto/flwr/proto/task.proto
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import "flwr/proto/error.proto";
message Task {
Node producer = 1;
Node consumer = 2;
string created_at = 3;
double created_at = 3;
string delivered_at = 4;
double pushed_at = 5;
double ttl = 6;
Expand Down
1 change: 1 addition & 0 deletions src/py/flwr/client/message_handler/message_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ def validate_out_message(out_message: Message, in_message_metadata: Metadata) ->
and out_meta.reply_to_message == in_meta.message_id
and out_meta.group_id == in_meta.group_id
and out_meta.message_type == in_meta.message_type
and out_meta.created_at > in_meta.created_at
):
return True
return False
35 changes: 33 additions & 2 deletions src/py/flwr/client/message_handler/message_handler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Client-side message handler tests."""


import time
import unittest
import uuid
from copy import copy
Expand Down Expand Up @@ -169,7 +170,18 @@ def test_client_without_get_properties() -> None:
)

assert actual_msg.content == expected_msg.content
assert actual_msg.metadata == expected_msg.metadata
# metadata.created_at will differ so let's exclude it from checks
attrs = actual_msg.metadata.__annotations__
attrs_keys = list(attrs.keys())
attrs_keys.remove("_created_at")
# metadata.created_at will differ so let's exclude it from checks
for attr in attrs_keys:
assert getattr(actual_msg.metadata, attr) == getattr(
expected_msg.metadata, attr
)

# Ensure the message created last has a higher timestamp
assert actual_msg.metadata.created_at < expected_msg.metadata.created_at


def test_client_with_get_properties() -> None:
Expand Down Expand Up @@ -222,7 +234,17 @@ def test_client_with_get_properties() -> None:
)

assert actual_msg.content == expected_msg.content
assert actual_msg.metadata == expected_msg.metadata
attrs = actual_msg.metadata.__annotations__
attrs_keys = list(attrs.keys())
attrs_keys.remove("_created_at")
# metadata.created_at will differ so let's exclude it from checks
for attr in attrs_keys:
assert getattr(actual_msg.metadata, attr) == getattr(
expected_msg.metadata, attr
)

# Ensure the message created last has a higher timestamp
assert actual_msg.metadata.created_at < expected_msg.metadata.created_at


class TestMessageValidation(unittest.TestCase):
Expand All @@ -241,6 +263,11 @@ def setUp(self) -> None:
ttl=DEFAULT_TTL,
message_type="mock",
)
# We need to set created_at in this way
# since this `self.in_metadata` is used for tests
# without it ever being part of a Message
self.in_metadata.created_at = time.time()

self.valid_out_metadata = Metadata(
run_id=123,
message_id="",
Expand Down Expand Up @@ -281,6 +308,10 @@ def test_invalid_message_run_id(self) -> None:
value = 999
elif isinstance(value, str):
value = "999"
elif isinstance(value, float):
if attr == "_created_at":
# make it be in 1h the past
value = value - 3600
setattr(invalid_metadata, attr, value)
# Add to list
invalid_metadata_list.append(invalid_metadata)
Expand Down
15 changes: 15 additions & 0 deletions src/py/flwr/common/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from __future__ import annotations

import time
from dataclasses import dataclass

from .record import RecordSet
Expand Down Expand Up @@ -62,6 +63,7 @@ class Metadata: # pylint: disable=too-many-instance-attributes
_ttl: float
_message_type: str
_partition_id: int | None
_created_at: float # Unix timestamp (in seconds) to be set upon message creation.

def __init__( # pylint: disable=too-many-arguments
self,
Expand Down Expand Up @@ -125,6 +127,16 @@ def group_id(self, value: str) -> None:
"""Set group_id."""
self._group_id = value

@property
def created_at(self) -> float:
"""Unix timestamp when the message was created."""
return self._created_at

@created_at.setter
def created_at(self, value: float) -> None:
"""Set creation timestamp for this messages."""
self._created_at = value

@property
def ttl(self) -> float:
"""Time-to-live for this message."""
Expand Down Expand Up @@ -214,6 +226,9 @@ def __init__(
) -> None:
self._metadata = metadata

# Set message creation timestamp
self._metadata.created_at = time.time()

if not (content is None) ^ (error is None):
raise ValueError("Either `content` or `error` must be set, but not both.")

Expand Down
7 changes: 5 additions & 2 deletions src/py/flwr/common/message_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# ==============================================================================
"""Message tests."""


import time
from contextlib import ExitStack
from typing import Any, Callable

Expand Down Expand Up @@ -62,12 +62,15 @@ def test_message_creation(
if context:
stack.enter_context(context)

_ = Message(
message = Message(
metadata=metadata,
content=None if content_fn is None else content_fn(maker),
error=None if error_fn is None else error_fn(0),
)

assert message.metadata.created_at > 0
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could record the timestamp before the message gets created and assert that created_at is greater.

assert message.metadata.created_at < time.time()


def create_message_with_content() -> Message:
"""Create a Message with content."""
Expand Down
11 changes: 9 additions & 2 deletions src/py/flwr/common/serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,6 +575,7 @@ def message_to_taskins(message: Message) -> TaskIns:
task=Task(
producer=Node(node_id=0, anonymous=True), # Assume driver node
consumer=Node(node_id=md.dst_node_id, anonymous=False),
created_at=md.created_at,
ttl=md.ttl,
ancestry=[md.reply_to_message] if md.reply_to_message != "" else [],
task_type=md.message_type,
Expand All @@ -601,7 +602,7 @@ def message_from_taskins(taskins: TaskIns) -> Message:
)

# Construct Message
return Message(
message = Message(
metadata=metadata,
content=(
recordset_from_proto(taskins.task.recordset)
Expand All @@ -614,6 +615,8 @@ def message_from_taskins(taskins: TaskIns) -> Message:
else None
),
)
message.metadata.created_at = taskins.task.created_at
return message


def message_to_taskres(message: Message) -> TaskRes:
Expand All @@ -626,6 +629,7 @@ def message_to_taskres(message: Message) -> TaskRes:
task=Task(
producer=Node(node_id=md.src_node_id, anonymous=False),
consumer=Node(node_id=0, anonymous=True), # Assume driver node
created_at=md.created_at,
ttl=md.ttl,
ancestry=[md.reply_to_message] if md.reply_to_message != "" else [],
task_type=md.message_type,
Expand All @@ -652,7 +656,7 @@ def message_from_taskres(taskres: TaskRes) -> Message:
)

# Construct the Message
return Message(
message = Message(
metadata=metadata,
content=(
recordset_from_proto(taskres.task.recordset)
Expand All @@ -665,3 +669,6 @@ def message_from_taskres(taskres: TaskRes) -> Message:
else None
),
)

message.metadata.created_at = taskres.task.created_at
return message
2 changes: 1 addition & 1 deletion src/py/flwr/proto/task_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions src/py/flwr/proto/task_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class Task(google.protobuf.message.Message):
def producer(self) -> flwr.proto.node_pb2.Node: ...
@property
def consumer(self) -> flwr.proto.node_pb2.Node: ...
created_at: typing.Text
created_at: builtins.float
delivered_at: typing.Text
pushed_at: builtins.float
ttl: builtins.float
Expand All @@ -45,7 +45,7 @@ class Task(google.protobuf.message.Message):
*,
producer: typing.Optional[flwr.proto.node_pb2.Node] = ...,
consumer: typing.Optional[flwr.proto.node_pb2.Node] = ...,
created_at: typing.Text = ...,
created_at: builtins.float = ...,
delivered_at: typing.Text = ...,
pushed_at: builtins.float = ...,
ttl: builtins.float = ...,
Expand Down
7 changes: 7 additions & 0 deletions src/py/flwr/server/compat/driver_client_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,13 @@ def _send_receive_recordset(
ttl=DEFAULT_TTL,
),
)

# This would normally be recorded upon common.Message creation
# but this compatibility stack doesn't create Messages,
# so we need to inject `created_at` manually (needed for
# taskins validation by server.utils.validator)
task_ins.task.created_at = time.time()

push_task_ins_req = driver_pb2.PushTaskInsRequest( # pylint: disable=E1101
task_ins_list=[task_ins]
)
Expand Down
9 changes: 2 additions & 7 deletions src/py/flwr/server/superlink/state/in_memory_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

import os
import threading
from datetime import datetime
from logging import ERROR
from typing import Dict, List, Optional, Set
from uuid import UUID, uuid4
Expand Down Expand Up @@ -50,13 +49,11 @@ def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]:
log(ERROR, "`run_id` is invalid")
return None

# Create task_id and created_at
# Create task_id
task_id = uuid4()
created_at: datetime = now()

# Store TaskIns
task_ins.task_id = str(task_id)
task_ins.task.created_at = created_at.isoformat()
with self.lock:
self.task_ins_store[task_id] = task_ins

Expand Down Expand Up @@ -111,13 +108,11 @@ def store_task_res(self, task_res: TaskRes) -> Optional[UUID]:
log(ERROR, "`run_id` is invalid")
return None

# Create task_id and created_at
# Create task_id
task_id = uuid4()
created_at: datetime = now()

# Store TaskRes
task_res.task_id = str(task_id)
task_res.task.created_at = created_at.isoformat()
with self.lock:
self.task_res_store[task_id] = task_res

Expand Down
11 changes: 3 additions & 8 deletions src/py/flwr/server/superlink/state/sqlite_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import os
import re
import sqlite3
from datetime import datetime
from logging import DEBUG, ERROR
from typing import Any, Dict, List, Optional, Set, Tuple, Union, cast
from uuid import UUID, uuid4
Expand Down Expand Up @@ -52,7 +51,7 @@
producer_node_id INTEGER,
consumer_anonymous BOOLEAN,
consumer_node_id INTEGER,
created_at TEXT,
created_at REAL,
delivered_at TEXT,
pushed_at REAL,
ttl REAL,
Expand All @@ -73,7 +72,7 @@
producer_node_id INTEGER,
consumer_anonymous BOOLEAN,
consumer_node_id INTEGER,
created_at TEXT,
created_at REAL,
delivered_at TEXT,
pushed_at REAL,
ttl REAL,
Expand Down Expand Up @@ -187,13 +186,11 @@ def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]:
log(ERROR, errors)
return None

# Create task_id and created_at
# Create task_id
task_id = uuid4()
created_at: datetime = now()

# Store TaskIns
task_ins.task_id = str(task_id)
task_ins.task.created_at = created_at.isoformat()
data = (task_ins_to_dict(task_ins),)
columns = ", ".join([f":{key}" for key in data[0]])
query = f"INSERT INTO task_ins VALUES({columns});"
Expand Down Expand Up @@ -322,11 +319,9 @@ def store_task_res(self, task_res: TaskRes) -> Optional[UUID]:

# Create task_id
task_id = uuid4()
created_at: datetime = now()

# Store TaskIns
task_res.task_id = str(task_id)
task_res.task.created_at = created_at.isoformat()
data = (task_res_to_dict(task_res),)
columns = ", ".join([f":{key}" for key in data[0]])
query = f"INSERT INTO task_res VALUES({columns});"
Expand Down
9 changes: 4 additions & 5 deletions src/py/flwr/server/superlink/state/state_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def test_store_task_ins_one(self) -> None:
consumer_node_id=consumer_node_id, anonymous=False, run_id=run_id
)

assert task_ins.task.created_at == "" # pylint: disable=no-member
assert task_ins.task.created_at < time.time() # pylint: disable=no-member
assert task_ins.task.delivered_at == "" # pylint: disable=no-member

# Execute
Expand All @@ -90,12 +90,9 @@ def test_store_task_ins_one(self) -> None:

actual_task = actual_task_ins.task

assert actual_task.created_at != ""
assert actual_task.delivered_at != ""

assert datetime.fromisoformat(actual_task.created_at) > datetime(
2020, 1, 1, tzinfo=timezone.utc
)
assert actual_task.created_at < actual_task.pushed_at
assert datetime.fromisoformat(actual_task.delivered_at) > datetime(
2020, 1, 1, tzinfo=timezone.utc
)
Expand Down Expand Up @@ -419,6 +416,7 @@ def create_task_ins(
task_type="mock",
recordset=RecordSet(parameters={}, metrics={}, configs={}),
ttl=DEFAULT_TTL,
created_at=time.time(),
),
)
task.task.pushed_at = time.time()
Expand All @@ -443,6 +441,7 @@ def create_task_res(
task_type="mock",
recordset=RecordSet(parameters={}, metrics={}, configs={}),
ttl=DEFAULT_TTL,
created_at=time.time(),
),
)
task_res.task.pushed_at = time.time()
Expand Down
9 changes: 7 additions & 2 deletions src/py/flwr/server/utils/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,13 @@ def validate_task_ins_or_res(tasks_ins_res: Union[TaskIns, TaskRes]) -> List[str
validation_errors.append("`task` does not set field `task`")

# Created/delivered/TTL/Pushed
if tasks_ins_res.task.created_at != "":
validation_errors.append("`created_at` must be an empty str")
if (
tasks_ins_res.task.created_at < 1711497600.0
): # unix timestamp of 27 March 2024 00h:00m:00s UTC
validation_errors.append(
"`created_at` must be a float that records the unix timestamp "
"in seconds when the message was created."
)
if tasks_ins_res.task.delivered_at != "":
validation_errors.append("`delivered_at` must be an empty str")
if tasks_ins_res.task.ttl <= 0:
Expand Down
Loading