Skip to content

Commit

Permalink
Merge pull request #6445 from MetRonnie/flow-nums
Browse files Browse the repository at this point in the history
Ensure `cylc trigger` does not fall back to `flow=none` by default
  • Loading branch information
oliver-sanders authored Oct 24, 2024
2 parents 576170c + 5f8a1f4 commit b702dbb
Show file tree
Hide file tree
Showing 6 changed files with 145 additions and 28 deletions.
1 change: 1 addition & 0 deletions changes.d/6445.fix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Ensure `cylc trigger` does not fall back to `flow=none` when there are no active flows.
30 changes: 22 additions & 8 deletions cylc/flow/rundb.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,24 @@
Dict,
Iterable,
List,
Optional,
Set,
Tuple,
Optional,
Union
Union,
)

from cylc.flow import LOG
from cylc.flow.exceptions import PlatformLookupError
from cylc.flow.util import deserialise_set
import cylc.flow.flags
from cylc.flow.util import (
deserialise_set,
serialise_set,
)


if TYPE_CHECKING:
from pathlib import Path

from cylc.flow.flow_mgr import FlowNums


Expand Down Expand Up @@ -800,13 +805,22 @@ def select_prev_instances(
)
]

def select_latest_flow_nums(self):
def select_latest_flow_nums(self) -> Optional['FlowNums']:
"""Return a list of the most recent previous flow numbers."""
stmt = rf'''
SELECT flow_nums, MAX(time_created) FROM {self.TABLE_TASK_STATES}
''' # nosec (table name is code constant)
flow_nums_str = list(self.connect().execute(stmt))[0][0]
return deserialise_set(flow_nums_str)
SELECT
flow_nums, MAX(time_created)
FROM
{self.TABLE_TASK_STATES}
WHERE
flow_nums != ?
''' # nosec B608 (table name is code constant)
# Exclude flow=none:
params = [serialise_set()]
flow_nums_str = self.connect().execute(stmt, params).fetchone()[0]
if flow_nums_str:
return deserialise_set(flow_nums_str)
return None

def select_task_outputs(
self, name: str, point: str
Expand Down
17 changes: 9 additions & 8 deletions cylc/flow/task_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -2051,19 +2051,20 @@ def _set_prereqs_tdef(
if self._set_prereqs_itask(itask, prereqs, flow_nums):
self.add_to_pool(itask)

def _get_active_flow_nums(self) -> Set[int]:
"""Return active flow numbers.
def _get_active_flow_nums(self) -> 'FlowNums':
"""Return all active flow numbers.
If there are no active flows (e.g. on restarting a completed workflow)
return the most recent active flows.
Or, if there are no flows in the workflow history (e.g. after
`cylc remove`), return flow=1.
"""
fnums = set()
for itask in self.get_tasks():
fnums.update(itask.flow_nums)
if not fnums:
fnums = self.workflow_db_mgr.pri_dao.select_latest_flow_nums()
return fnums
return (
set().union(*(itask.flow_nums for itask in self.get_tasks()))
or self.workflow_db_mgr.pri_dao.select_latest_flow_nums()
or {1}
)

def remove_tasks(self, items):
"""Remove tasks from the pool (forced by command)."""
Expand Down
15 changes: 10 additions & 5 deletions cylc/flow/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
Callable,
Dict,
List,
Optional,
Sequence,
Tuple,
)
Expand Down Expand Up @@ -148,17 +149,21 @@ def cli_format(cmd: List[str]):
return ' '.join(cmd)


def serialise_set(flow_nums: set) -> str:
def serialise_set(flow_nums: Optional[set] = None) -> str:
"""Convert set to json, sorted.
For use when a sorted result is needed for consistency.
Example:
>>> serialise_set({'3','2'})
'["2", "3"]'
Examples:
>>> serialise_set({'b', 'a'})
'["a", "b"]'
>>> serialise_set({3, 2})
'[2, 3]'
>>> serialise_set()
'[]'
"""
return json.dumps(sorted(flow_nums))
return json.dumps(sorted(flow_nums or ()))


def deserialise_set(flow_num_str: str) -> set:
Expand Down
68 changes: 61 additions & 7 deletions tests/unit/test_rundb.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,24 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.

import contextlib
import os
import json
from pathlib import Path
import sqlite3
from types import SimpleNamespace
from typing import (
List,
Optional,
Tuple,
)
import unittest
from unittest import mock
from tempfile import mktemp
from types import SimpleNamespace

import pytest

from cylc.flow.exceptions import PlatformLookupError
from cylc.flow.flow_mgr import FlowNums
from cylc.flow.rundb import CylcWorkflowDAO
from cylc.flow.util import serialise_set


GLOBAL_CONFIG = """
Expand Down Expand Up @@ -93,10 +99,8 @@ def test_select_task_job_sqlite_error(self):
@contextlib.contextmanager
def create_temp_db():
"""Create and tidy a temporary database for testing purposes."""
temp_db = mktemp()
conn = sqlite3.connect(temp_db)
yield (temp_db, conn)
os.remove(temp_db)
conn = sqlite3.connect(':memory:')
yield conn
conn.close() # doesn't raise error on re-invocation


Expand Down Expand Up @@ -175,3 +179,53 @@ def callback(index, row):
match='not defined.*\n.*foo.*\n.*bar'
):
dao.select_task_pool_for_restart(callback)


@pytest.mark.parametrize(
'values, expected',
[
pytest.param(
[
({1, 2}, '2021-01-01T00:00:00'),
({3, 4}, '2021-01-01T00:00:02'),
({5, 6}, '2021-01-01T00:00:01'),
],
{3, 4},
id="basic"
),
pytest.param(
[
({2}, '2021-01-01T00:00:00'),
(set(), '2021-01-01T00:00:01'),
(set(), '2021-01-01T00:00:02'),
],
{2},
id="ignore flow=none"
),
pytest.param(
[
(set(), '2021-01-01T00:00:01'),
(set(), '2021-01-01T00:00:02'),
],
None,
id="all flow=none"
),
],
)
def test_select_latest_flow_nums(
values: List[Tuple[FlowNums, str]], expected: Optional[FlowNums]
):
with CylcWorkflowDAO(':memory:') as dao:
conn = dao.connect()
conn.execute(
"CREATE TABLE task_states (flow_nums TEXT, time_created TEXT)"
)
for (fnums, timestamp) in values:
conn.execute(
"INSERT INTO task_states VALUES ("
f"{json.dumps(serialise_set(fnums))}, {json.dumps(timestamp)}"
")"
)
conn.commit()

assert dao.select_latest_flow_nums() == expected
42 changes: 42 additions & 0 deletions tests/unit/test_task_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,45 @@
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.


from typing import List
from unittest.mock import Mock

import pytest

from cylc.flow.flow_mgr import FlowNums
from cylc.flow.task_pool import TaskPool


@pytest.mark.parametrize('pool, db_fnums, expected', [
pytest.param(
[{1, 2}, {2, 3}],
{5, 6},
{1, 2, 3},
id="all-active"
),
pytest.param(
[set(), set()],
{5, 6},
{5, 6},
id="from-db"
),
pytest.param(
[set()],
set(),
{1},
id="fallback" # see https://github.com/cylc/cylc-flow/pull/6445
),
])
def test_get_active_flow_nums(
pool: List[FlowNums], db_fnums: FlowNums, expected
):
mock_task_pool = Mock(
get_tasks=lambda: [Mock(flow_nums=fnums) for fnums in pool],
)
mock_task_pool.workflow_db_mgr.pri_dao.select_latest_flow_nums = (
lambda: db_fnums
)

assert TaskPool._get_active_flow_nums(mock_task_pool) == expected

0 comments on commit b702dbb

Please sign in to comment.