Skip to content

Commit

Permalink
feat(framework) Validate node IDs for TaskIns in `LinkState.store_task_ins` (#4378)
Browse files Browse the repository at this point in the history
  • Loading branch information
panh99 authored Nov 5, 2024
1 parent 57e49fe commit 87bea16
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 34 deletions.
3 changes: 1 addition & 2 deletions src/py/flwr/server/driver/inmemory_driver_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,8 @@
def push_messages(driver: InMemoryDriver, num_nodes: int) -> tuple[Iterable[str], int]:
"""Help push messages to state."""
for _ in range(num_nodes):
driver.state.create_node(ping_interval=PING_MAX_INTERVAL)
node_id = driver.state.create_node(ping_interval=PING_MAX_INTERVAL)
num_messages = 3
node_id = 1
msgs = [
driver.create_message(RecordSet(), "message_type", node_id, "")
for _ in range(num_messages)
Expand Down
19 changes: 18 additions & 1 deletion src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,25 @@ def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]:
return None
# Validate run_id
if task_ins.run_id not in self.run_ids:
log(ERROR, "`run_id` is invalid")
log(ERROR, "Invalid run ID for TaskIns: %s", task_ins.run_id)
return None
# Validate source node ID
if task_ins.task.producer.node_id != 0:
log(
ERROR,
"Invalid source node ID for TaskIns: %s",
task_ins.task.producer.node_id,
)
return None
# Validate destination node ID
if not task_ins.task.consumer.anonymous:
if task_ins.task.consumer.node_id not in self.node_ids:
log(
ERROR,
"Invalid destination node ID for TaskIns: %s",
task_ins.task.consumer.node_id,
)
return None

# Create task_id
task_id = uuid4()
Expand Down
79 changes: 54 additions & 25 deletions src/py/flwr/server/superlink/linkstate/linkstate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,19 +192,19 @@ def test_get_task_res_empty(self) -> None:
def test_store_task_ins_one(self) -> None:
"""Test store_task_ins."""
# Prepare
consumer_node_id = 1
state = self.state_factory()
node_id = state.create_node(1e3)
run_id = state.create_run(None, None, "9f86d08", {})
task_ins = create_task_ins(
consumer_node_id=consumer_node_id, anonymous=False, run_id=run_id
consumer_node_id=node_id, anonymous=False, run_id=run_id
)

assert task_ins.task.created_at < time.time() # pylint: disable=no-member
assert task_ins.task.delivered_at == "" # pylint: disable=no-member

# Execute
state.store_task_ins(task_ins=task_ins)
task_ins_list = state.get_task_ins(node_id=consumer_node_id, limit=10)
task_ins_list = state.get_task_ins(node_id=node_id, limit=10)

# Assert
assert len(task_ins_list) == 1
Expand All @@ -224,20 +224,39 @@ def test_store_task_ins_one(self) -> None:
)
assert actual_task.ttl > 0

def test_store_task_ins_invalid_node_id(self) -> None:
    """Test store_task_ins with invalid node_id.

    Two TaskIns are built and each is expected to be rejected, i.e.
    `store_task_ins` returns `None`:

    * one whose consumer node ID differs from the node created here;
    * one addressed to the created node but with its producer node ID
      overwritten with 61016.
    """
    # Prepare
    link_state = self.state_factory()
    known_node_id = link_state.create_node(1e3)
    run_id = link_state.create_run(None, None, "9f86d08", {})

    # Choose a consumer ID that cannot collide with the node created above
    unknown_node_id = 61017 if known_node_id == 61016 else 61016

    bad_consumer_task = create_task_ins(
        consumer_node_id=unknown_node_id, anonymous=False, run_id=run_id
    )
    bad_producer_task = create_task_ins(
        consumer_node_id=known_node_id, anonymous=False, run_id=run_id
    )
    bad_producer_task.task.producer.node_id = 61016

    # Execute and assert
    for rejected_task in (bad_consumer_task, bad_producer_task):
        assert link_state.store_task_ins(rejected_task) is None

def test_store_and_delete_tasks(self) -> None:
"""Test delete_tasks."""
# Prepare
consumer_node_id = 1
state = self.state_factory()
node_id = state.create_node(1e3)
run_id = state.create_run(None, None, "9f86d08", {})
task_ins_0 = create_task_ins(
consumer_node_id=consumer_node_id, anonymous=False, run_id=run_id
consumer_node_id=node_id, anonymous=False, run_id=run_id
)
task_ins_1 = create_task_ins(
consumer_node_id=consumer_node_id, anonymous=False, run_id=run_id
consumer_node_id=node_id, anonymous=False, run_id=run_id
)
task_ins_2 = create_task_ins(
consumer_node_id=consumer_node_id, anonymous=False, run_id=run_id
consumer_node_id=node_id, anonymous=False, run_id=run_id
)

# Insert three TaskIns
Expand All @@ -250,11 +269,11 @@ def test_store_and_delete_tasks(self) -> None:
assert task_id_2

# Get TaskIns to mark them delivered
_ = state.get_task_ins(node_id=consumer_node_id, limit=None)
_ = state.get_task_ins(node_id=node_id, limit=None)

# Insert one TaskRes and retrieve it to mark it as delivered
task_res_0 = create_task_res(
producer_node_id=consumer_node_id,
producer_node_id=node_id,
anonymous=False,
ancestry=[str(task_id_0)],
run_id=run_id,
Expand All @@ -265,7 +284,7 @@ def test_store_and_delete_tasks(self) -> None:

# Insert one TaskRes, but don't retrieve it
task_res_1: TaskRes = create_task_res(
producer_node_id=consumer_node_id,
producer_node_id=node_id,
anonymous=False,
ancestry=[str(task_id_1)],
run_id=run_id,
Expand Down Expand Up @@ -332,8 +351,11 @@ def test_task_ins_store_identity_and_fail_retrieving_anonymous(self) -> None:
"""Store identity TaskIns and fail retrieving it as anonymous."""
# Prepare
state: LinkState = self.state_factory()
node_id = state.create_node(1e3)
run_id = state.create_run(None, None, "9f86d08", {})
task_ins = create_task_ins(consumer_node_id=1, anonymous=False, run_id=run_id)
task_ins = create_task_ins(
consumer_node_id=node_id, anonymous=False, run_id=run_id
)

# Execute
_ = state.store_task_ins(task_ins)
Expand All @@ -346,12 +368,15 @@ def test_task_ins_store_identity_and_retrieve_identity(self) -> None:
"""Store identity TaskIns and retrieve it."""
# Prepare
state: LinkState = self.state_factory()
node_id = state.create_node(1e3)
run_id = state.create_run(None, None, "9f86d08", {})
task_ins = create_task_ins(consumer_node_id=1, anonymous=False, run_id=run_id)
task_ins = create_task_ins(
consumer_node_id=node_id, anonymous=False, run_id=run_id
)

# Execute
task_ins_uuid = state.store_task_ins(task_ins)
task_ins_list = state.get_task_ins(node_id=1, limit=None)
task_ins_list = state.get_task_ins(node_id=node_id, limit=None)

# Assert
assert len(task_ins_list) == 1
Expand All @@ -363,14 +388,17 @@ def test_task_ins_store_delivered_and_fail_retrieving(self) -> None:
"""Fail retrieving delivered task."""
# Prepare
state: LinkState = self.state_factory()
node_id = state.create_node(1e3)
run_id = state.create_run(None, None, "9f86d08", {})
task_ins = create_task_ins(consumer_node_id=1, anonymous=False, run_id=run_id)
task_ins = create_task_ins(
consumer_node_id=node_id, anonymous=False, run_id=run_id
)

# Execute
_ = state.store_task_ins(task_ins)

# 1st get: set to delivered
task_ins_list = state.get_task_ins(node_id=1, limit=None)
task_ins_list = state.get_task_ins(node_id=node_id, limit=None)

assert len(task_ins_list) == 1

Expand Down Expand Up @@ -874,11 +902,11 @@ def test_store_task_res_limit_ttl(self) -> None:
def test_get_task_ins_not_return_expired(self) -> None:
"""Test get_task_ins not to return expired tasks."""
# Prepare
consumer_node_id = 1
state = self.state_factory()
node_id = state.create_node(1e3)
run_id = state.create_run(None, None, "9f86d08", {})
task_ins = create_task_ins(
consumer_node_id=consumer_node_id, anonymous=False, run_id=run_id
consumer_node_id=node_id, anonymous=False, run_id=run_id
)
task_ins.task.created_at = time.time() - 5
task_ins.task.ttl = 5.0
Expand All @@ -894,11 +922,11 @@ def test_get_task_ins_not_return_expired(self) -> None:
def test_get_task_res_not_return_expired(self) -> None:
"""Test get_task_res not to return TaskRes if its TaskIns is expired."""
# Prepare
consumer_node_id = 1
state = self.state_factory()
node_id = state.create_node(1e3)
run_id = state.create_run(None, None, "9f86d08", {})
task_ins = create_task_ins(
consumer_node_id=consumer_node_id, anonymous=False, run_id=run_id
consumer_node_id=node_id, anonymous=False, run_id=run_id
)
task_ins.task.created_at = time.time() - 5
task_ins.task.ttl = 5.1
Expand Down Expand Up @@ -948,19 +976,19 @@ def test_get_task_res_return_if_not_expired(self) -> None:
"""Test get_task_res to return TaskRes if its TaskIns exists and is not
expired."""
# Prepare
consumer_node_id = 1
state = self.state_factory()
node_id = state.create_node(1e3)
run_id = state.create_run(None, None, "9f86d08", {})
task_ins = create_task_ins(
consumer_node_id=consumer_node_id, anonymous=False, run_id=run_id
consumer_node_id=node_id, anonymous=False, run_id=run_id
)
task_ins.task.created_at = time.time() - 5
task_ins.task.ttl = 7.1

task_id = state.store_task_ins(task_ins=task_ins)

task_res = create_task_res(
producer_node_id=1,
producer_node_id=node_id,
anonymous=False,
ancestry=[str(task_id)],
run_id=run_id,
Expand All @@ -980,17 +1008,18 @@ def test_store_task_res_fail_if_consumer_producer_id_mismatch(self) -> None:
"""Test store_task_res to fail if there is a mismatch between the
consumer_node_id of taskIns and the producer_node_id of taskRes."""
# Prepare
consumer_node_id = 1
state = self.state_factory()
node_id = state.create_node(1e3)
run_id = state.create_run(None, None, "9f86d08", {})
task_ins = create_task_ins(
consumer_node_id=consumer_node_id, anonymous=False, run_id=run_id
consumer_node_id=node_id, anonymous=False, run_id=run_id
)

task_id = state.store_task_ins(task_ins=task_ins)

task_res = create_task_res(
producer_node_id=100, # different than consumer_node_id
# Different than consumer_node_id
producer_node_id=100 if node_id != 100 else 101,
anonymous=False,
ancestry=[str(task_id)],
run_id=run_id,
Expand Down
31 changes: 25 additions & 6 deletions src/py/flwr/server/superlink/linkstate/sqlite_linkstate.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,6 @@ def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]:
if any(errors):
log(ERROR, errors)
return None

# Create task_id
task_id = uuid4()

Expand All @@ -284,16 +283,36 @@ def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]:
data[0], ["run_id", "producer_node_id", "consumer_node_id"]
)

# Validate run_id
query = "SELECT run_id FROM run WHERE run_id = ?;"
if not self.query(query, (data[0]["run_id"],)):
log(ERROR, "Invalid run ID for TaskIns: %s", task_ins.run_id)
return None
# Validate source node ID
if task_ins.task.producer.node_id != 0:
log(
ERROR,
"Invalid source node ID for TaskIns: %s",
task_ins.task.producer.node_id,
)
return None
# Validate destination node ID
query = "SELECT node_id FROM node WHERE node_id = ?;"
if not task_ins.task.consumer.anonymous:
if not self.query(query, (data[0]["consumer_node_id"],)):
log(
ERROR,
"Invalid destination node ID for TaskIns: %s",
task_ins.task.consumer.node_id,
)
return None

columns = ", ".join([f":{key}" for key in data[0]])
query = f"INSERT INTO task_ins VALUES({columns});"

# Only invalid run_id can trigger IntegrityError.
# This may need to be changed in the future version with more integrity checks.
try:
self.query(query, data)
except sqlite3.IntegrityError:
log(ERROR, "`run` is invalid")
return None
self.query(query, data)

return task_id

Expand Down

0 comments on commit 87bea16

Please sign in to comment.