Skip to content

Commit

Permalink
amend sqlitestate
Browse files Browse the repository at this point in the history
  • Loading branch information
panh99 committed Sep 17, 2024
1 parent 60ff84b commit 1fcb722
Show file tree
Hide file tree
Showing 4 changed files with 153 additions and 24 deletions.
21 changes: 14 additions & 7 deletions src/py/flwr/server/superlink/fleet/vce/vce_api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
)
from flwr.common.recordset_compat import getpropertiesins_to_recordset
from flwr.common.serde import message_from_taskres, message_to_taskins
from flwr.common.typing import Run
from flwr.common.typing import Run, RunStatus
from flwr.server.superlink.fleet.vce.vce_api import (
NodeToPartitionMapping,
_register_nodes,
Expand Down Expand Up @@ -113,12 +113,19 @@ def register_messages_into_state(
) -> dict[UUID, float]:
"""Register `num_messages` into the state factory."""
state: InMemoryState = state_factory.state() # type: ignore
state.run_ids[run_id] = Run(
run_id=run_id,
fab_id="Mock/mock",
fab_version="v1.0.0",
fab_hash="hash",
override_config={},
state.run_ids[run_id] = (
Run(
run_id=run_id,
fab_id="Mock/mock",
fab_version="v1.0.0",
fab_hash="hash",
override_config={},
),
RunStatus(
phase="starting",
result="",
reason="",
),
)
# Artificially add TaskIns to state so they can be processed
# by the Simulation Engine logic
Expand Down
14 changes: 10 additions & 4 deletions src/py/flwr/server/superlink/state/in_memory_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,13 +293,19 @@ def create_run(
run_id = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES)

if run_id not in self.run_ids:
self.run_ids[run_id] = Run(
run = Run(
run_id=run_id,
fab_id=fab_id if fab_id else "",
fab_version=fab_version if fab_version else "",
fab_hash=fab_hash if fab_hash else "",
override_config=override_config,
)
initial_status = RunStatus(
phase="starting",
result="",
reason="",
)
self.run_ids[run_id] = (run, initial_status)
return run_id
log(ERROR, "Unexpected run creation failure.")
return 0
Expand Down Expand Up @@ -343,7 +349,7 @@ def get_run(self, run_id: int) -> Optional[Run]:
if run_id not in self.run_ids:
log(ERROR, "`run_id` is invalid")
return None
return self.run_ids[run_id]
return self.run_ids[run_id][0]

def get_run_status(self, run_ids: set[int]) -> dict[int, RunStatus]:
"""Get the status of the run with the specified `run_id`."""
Expand Down Expand Up @@ -375,11 +381,11 @@ def update_run_status(self, run_id: int, new_status: RunStatus) -> bool:

# Check if the result is valid
if not has_valid_result(status):
log(ERROR, 'Invalid run status: "%s:"', status.phase, status.result)
log(ERROR, 'Invalid run status: "%s:%s"', status.phase, status.result)
return False

# Update the status
self.run_ids[run_id][1] = new_status
self.run_ids[run_id] = (self.run_ids[run_id][0], new_status)
return True

def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
Expand Down
83 changes: 70 additions & 13 deletions src/py/flwr/server/superlink/state/sqlite_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,19 @@

from flwr.common import log, now
from flwr.common.constant import NODE_ID_NUM_BYTES, RUN_ID_NUM_BYTES
from flwr.common.typing import Run, UserConfig
from flwr.common.typing import Run, RunStatus, UserConfig
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
from flwr.proto.recordset_pb2 import RecordSet # pylint: disable=E0611
from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611
from flwr.server.utils.validator import validate_task_ins_or_res

from .state import State
from .utils import generate_rand_int_from_bytes, make_node_unavailable_taskres
from .utils import (
generate_rand_int_from_bytes,
has_valid_result,
is_valid_transition,
make_node_unavailable_taskres,
)

SQL_CREATE_TABLE_NODE = """
CREATE TABLE IF NOT EXISTS node(
Expand Down Expand Up @@ -67,7 +72,10 @@
fab_id TEXT,
fab_version TEXT,
fab_hash TEXT,
override_config TEXT
override_config TEXT,
status_phase TEXT,
status_result TEXT,
status_reason TEXT
);
"""

Expand Down Expand Up @@ -634,18 +642,15 @@ def create_run(
if self.query(query, (run_id,))[0]["COUNT(*)"] == 0:
query = (
"INSERT INTO run "
"(run_id, fab_id, fab_version, fab_hash, override_config)"
"VALUES (?, ?, ?, ?, ?);"
"(run_id, fab_id, fab_version, fab_hash, override_config, "
"status_phase, status_result, status_reason)"
"VALUES (?, ?, ?, ?, ?, ?, ?, ?);"
)
if fab_hash:
self.query(
query, (run_id, "", "", fab_hash, json.dumps(override_config))
)
else:
self.query(
query,
(run_id, fab_id, fab_version, "", json.dumps(override_config)),
)
fab_id, fab_version = "", ""
data = [run_id, fab_id, fab_version, fab_hash, json.dumps(override_config)]
data += ["starting", "", ""]
self.query(query, tuple(data))
return run_id
log(ERROR, "Unexpected run creation failure.")
return 0
Expand Down Expand Up @@ -719,6 +724,58 @@ def get_run(self, run_id: int) -> Optional[Run]:
log(ERROR, "`run_id` does not exist.")
return None

def get_run_status(self, run_ids: set[int]) -> dict[int, RunStatus]:
    """Return the stored status for each run in `run_ids`.

    The result maps every `run_id` found in the `run` table to a `RunStatus`
    built from its `status_*` columns. IDs without a matching row are simply
    absent from the returned dict.
    """
    # One `?` per requested ID, so the values are bound rather than inlined
    placeholders = ",".join("?" for _ in run_ids)
    query = f"SELECT * FROM run WHERE run_id IN ({placeholders});"

    statuses: dict[int, RunStatus] = {}
    for record in self.query(query, tuple(run_ids)):
        statuses[record["run_id"]] = RunStatus(
            phase=record["status_phase"],
            result=record["status_result"],
            reason=record["status_reason"],
        )
    return statuses

def update_run_status(self, run_id: int, new_status: RunStatus) -> bool:
    """Update the status of the run with the specified `run_id`.

    Returns `True` once `new_status` has been written to the `run` table.
    Returns `False` (after logging an error) when:

    - no run with `run_id` exists,
    - moving from the stored status to `new_status` is not a valid transition,
    - `new_status` itself does not carry a valid result.
    """
    query = "SELECT * FROM run WHERE run_id = ?;"
    rows = self.query(query, (run_id,))

    # Check if the run_id exists
    if not rows:
        log(ERROR, "`run_id` is invalid")
        return False

    # Check if the status transition is valid
    row = rows[0]
    current_status = RunStatus(
        phase=row["status_phase"],
        result=row["status_result"],
        reason=row["status_reason"],
    )
    if not is_valid_transition(current_status, new_status):
        log(
            ERROR,
            'Invalid status transition: from "%s" to "%s"',
            current_status.phase,
            new_status.phase,
        )
        return False

    # Check if the result of the *incoming* status is valid. Validating the
    # stored status here instead would let a malformed `new_status` be
    # persisted whenever the phase transition alone is legal.
    if not has_valid_result(new_status):
        log(
            ERROR,
            'Invalid run status: "%s:%s"',
            new_status.phase,
            new_status.result,
        )
        return False

    # Update the status
    query = "UPDATE run SET status_phase = ?, status_result = ?, status_reason = ? "
    query += "WHERE run_id = ?;"
    data = (new_status.phase, new_status.result, new_status.reason, run_id)
    self.query(query, data)
    return True

def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
"""Acknowledge a ping received from a node, serving as a heartbeat."""
# Update `online_until` and `ping_interval` for the given `node_id`
Expand Down
59 changes: 59 additions & 0 deletions src/py/flwr/server/superlink/state/state_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
private_key_to_bytes,
public_key_to_bytes,
)
from flwr.common.typing import RunStatus
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
from flwr.proto.recordset_pb2 import RecordSet # pylint: disable=E0611
from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611
Expand Down Expand Up @@ -62,6 +63,64 @@ def test_create_and_get_run(self) -> None:
assert run.fab_hash == "9f86d08"
assert run.override_config["test_key"] == "test_value"

def test_get_and_update_run_status(self) -> None:
    """Check that `get_run_status` reflects an `update_run_status` call."""
    # Prepare: create two runs and move only the second one to "running"
    state = self.state_factory()
    first_id = state.create_run(None, None, "9f86d08", {"test_key": "test_value"})
    second_id = state.create_run(None, None, "fffffff", {"mock_key": "mock_value"})
    state.update_run_status(second_id, RunStatus("running", "", ""))

    # Execute: fetch both statuses with a single call
    statuses = state.get_run_status({first_id, second_id})
    first_status, second_status = statuses[first_id], statuses[second_id]

    # Assert: the untouched run keeps its initial phase
    assert first_status.phase == "starting"
    assert second_status.phase == "running"

def test_status_transition_valid(self) -> None:
    """Test valid run status transitions."""
    # Prepare
    state = self.state_factory()
    run_id = state.create_run(None, None, "9f86d08", {"test_key": "test_value"})
    transitions = (
        RunStatus("running", "", ""),
        RunStatus("finished", "failed", "mock failure"),
    )

    # Execute and assert
    # Record the phase before the first update and after every accepted one
    observed_phases = [state.get_run_status({run_id})[run_id].phase]
    for next_status in transitions:
        assert state.update_run_status(run_id, next_status)
        observed_phases.append(state.get_run_status({run_id})[run_id].phase)

    assert observed_phases == ["starting", "running", "finished"]

def test_status_transition_invalid(self) -> None:
    """Test invalid run status transitions.

    The run is walked through "starting" -> "running" -> "finished". At each
    phase, every transition that should be rejected is attempted and must
    return `False`. The two setup transitions are asserted as well, so a broken
    setup fails at the step that broke rather than in a later negative check.
    """
    # Prepare
    state = self.state_factory()
    run_id = state.create_run(None, None, "9f86d08", {"test_key": "test_value"})

    # Execute and assert
    # Cannot transition to "starting" or "finished" from "starting"
    assert not state.update_run_status(run_id, RunStatus("starting", "", ""))
    assert not state.update_run_status(
        run_id, RunStatus("finished", "completed", "")
    )
    assert state.update_run_status(run_id, RunStatus("running", "", ""))
    # Cannot transition to "starting" or "running" from "running"
    assert not state.update_run_status(run_id, RunStatus("starting", "", ""))
    assert not state.update_run_status(run_id, RunStatus("running", "", ""))
    assert state.update_run_status(run_id, RunStatus("finished", "completed", ""))
    # Cannot transition to any status from "finished"
    assert not state.update_run_status(run_id, RunStatus("starting", "", ""))
    assert not state.update_run_status(run_id, RunStatus("running", "", ""))
    assert not state.update_run_status(run_id, RunStatus("finished", "failed", ""))

def test_get_task_ins_empty(self) -> None:
"""Validate that a new state has no TaskIns."""
# Prepare
Expand Down

0 comments on commit 1fcb722

Please sign in to comment.