Skip to content

Commit

Permalink
update states
Browse files Browse the repository at this point in the history
  • Loading branch information
panh99 committed Sep 23, 2024
1 parent 85579df commit 68b3154
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 26 deletions.
23 changes: 14 additions & 9 deletions src/py/flwr/server/superlink/state/in_memory_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,15 +352,15 @@ def get_run(self, run_id: int) -> Optional[Run]:
return self.run_ids[run_id][0]

def get_run_status(self, run_ids: set[int]) -> dict[int, RunStatus]:
"""Retrieve the status information for the specified runs."""
"""Retrieve the statuses for the specified runs."""
with self.lock:
return {
run_id: self.run_ids[run_id][1]
for run_id in run_ids
if run_id in self.run_ids
}

def update_run_status(self, run_id: int, new_status_info: RunStatus) -> bool:
def update_run_status(self, run_id: int, new_status: RunStatus) -> bool:
"""Update the status of the run with the specified `run_id`."""
with self.lock:
# Check if the run_id exists
Expand All @@ -369,23 +369,28 @@ def update_run_status(self, run_id: int, new_status_info: RunStatus) -> bool:
return False

# Check if the status transition is valid
info = self.run_ids[run_id][1]
if not is_valid_transition(info, new_status_info):
status = self.run_ids[run_id][1]
if not is_valid_transition(status, new_status):
log(
ERROR,
'Invalid status transition: from "%s" to "%s"',
info.status,
new_status_info.status,
status.status,
new_status.status,
)
return False

# Check if the sub-status is valid
if not has_valid_sub_status(info):
log(ERROR, 'Invalid run status: "%s:%s"', info.status, info.sub_status)
if not has_valid_sub_status(status):
log(
ERROR,
'Invalid run status: "%s:%s"',
status.status,
status.sub_status,
)
return False

# Update the status
self.run_ids[run_id] = (self.run_ids[run_id][0], new_status_info)
self.run_ids[run_id] = (self.run_ids[run_id][0], new_status)
return True

def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
Expand Down
24 changes: 12 additions & 12 deletions src/py/flwr/server/superlink/state/sqlite_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@
override_config TEXT,
status TEXT,
sub_status TEXT,
reason TEXT
details TEXT
);
"""

Expand Down Expand Up @@ -706,7 +706,7 @@ def create_run(
query = (
"INSERT INTO run "
"(run_id, fab_id, fab_version, fab_hash, override_config, "
"status, sub_status, reason)"
"status, sub_status, details)"
"VALUES (?, ?, ?, ?, ?, ?, ?, ?);"
)
if fab_hash:
Expand Down Expand Up @@ -790,20 +790,20 @@ def get_run(self, run_id: int) -> Optional[Run]:
return None

def get_run_status(self, run_ids: set[int]) -> dict[int, RunStatus]:
"""Retrieve the status information for the specified runs."""
"""Retrieve the statuses for the specified runs."""
query = f"SELECT * FROM run WHERE run_id IN ({','.join(['?'] * len(run_ids))});"
rows = self.query(query, tuple(run_ids))

return {
row["run_id"]: RunStatus(
status=row["status"],
sub_status=row["sub_status"],
details=row["reason"],
details=row["details"],
)
for row in rows
}

def update_run_status(self, run_id: int, new_status_info: RunStatus) -> bool:
def update_run_status(self, run_id: int, new_status: RunStatus) -> bool:
"""Update the status of the run with the specified `run_id`."""
query = "SELECT * FROM run WHERE run_id = ?;"
rows = self.query(query, (run_id,))
Expand All @@ -818,14 +818,14 @@ def update_run_status(self, run_id: int, new_status_info: RunStatus) -> bool:
status = RunStatus(
status=row["status"],
sub_status=row["sub_status"],
details=row["reason"],
details=row["details"],
)
if not is_valid_transition(status, new_status_info):
if not is_valid_transition(status, new_status):
log(
ERROR,
'Invalid status transition: from "%s" to "%s"',
status.status,
new_status_info.status,
new_status.status,
)
return False

Expand All @@ -835,12 +835,12 @@ def update_run_status(self, run_id: int, new_status_info: RunStatus) -> bool:
return False

# Update the status
query = "UPDATE run SET status= ?, sub_status = ?, reason = ? "
query = "UPDATE run SET status= ?, sub_status = ?, details = ? "
query += "WHERE run_id = ?;"
data = (
new_status_info.status,
new_status_info.sub_status,
new_status_info.details,
new_status.status,
new_status.sub_status,
new_status.details,
run_id,
)
self.query(query, data)
Expand Down
10 changes: 5 additions & 5 deletions src/py/flwr/server/superlink/state/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def get_run(self, run_id: int) -> Optional[Run]:

@abc.abstractmethod
def get_run_status(self, run_ids: set[int]) -> dict[int, RunStatus]:
"""Retrieve the status information for the specified runs.
"""Retrieve the statuses for the specified runs.
Parameters
----------
Expand All @@ -195,8 +195,8 @@ def get_run_status(self, run_ids: set[int]) -> dict[int, RunStatus]:
Returns
-------
dict[int, StatusInfo]
A dictionary mapping each valid run ID to its corresponding `StatusInfo`.
dict[int, RunStatus]
A dictionary mapping each valid run ID to its corresponding status.
Notes
-----
Expand All @@ -205,14 +205,14 @@ def get_run_status(self, run_ids: set[int]) -> dict[int, RunStatus]:
"""

@abc.abstractmethod
def update_run_status(self, run_id: int, new_status_info: RunStatus) -> bool:
def update_run_status(self, run_id: int, new_status: RunStatus) -> bool:
"""Update the status of the run with the specified `run_id`.
Parameters
----------
run_id : int
The identifier of the run.
new_status_info : StatusInfo
new_status_info : RunStatus
The new status info to be assigned to the run.
Returns
Expand Down

0 comments on commit 68b3154

Please sign in to comment.