diff --git a/openeo/extra/job_management/__init__.py b/openeo/extra/job_management/__init__.py index 43e1a5c2f..78a07d123 100644 --- a/openeo/extra/job_management/__init__.py +++ b/openeo/extra/job_management/__init__.py @@ -32,12 +32,16 @@ from requests.adapters import HTTPAdapter, Retry from openeo import BatchJob, Connection +from openeo.extra.job_management._thread_worker import ( _JobManagerWorkerThreadPool, + _JobStartTask) + from openeo.internal.processes.parse import ( Parameter, Process, parse_remote_process_definition, ) from openeo.rest import OpenEoApiError +from openeo.rest.auth.auth import BearerAuth from openeo.util import LazyLoadCache, deep_get, repr_truncate, rfc3339 _log = logging.getLogger(__name__) @@ -105,6 +109,7 @@ def get_by_status(self, statuses: List[str], max=None) -> pd.DataFrame: """ ... + def _start_job_default(row: pd.Series, connection: Connection, *args, **kwargs): raise NotImplementedError("No 'start_job' callable provided") @@ -186,6 +191,7 @@ def start_job( # Expected columns in the job DB dataframes. # TODO: make this part of public API when settled? + # TODO: move non official statuses to seperate column (not_started, queued_for_start) _COLUMN_REQUIREMENTS: Mapping[str, _ColumnProperties] = { "id": _ColumnProperties(dtype="str"), "backend_name": _ColumnProperties(dtype="str"), @@ -222,6 +228,7 @@ def __init__( datetime.timedelta(seconds=cancel_running_job_after) if cancel_running_job_after is not None else None ) self._thread = None + self._worker_pool = None def add_backend( self, @@ -358,6 +365,7 @@ def start_job_thread(self, start_job: Callable[[], BatchJob], job_db: JobDatabas _log.info(f"Resuming `run_jobs` from existing {job_db}") self._stop_thread = False + self._worker_pool = _JobManagerWorkerThreadPool() def run_loop(): @@ -365,14 +373,19 @@ def run_loop(): stats = collections.defaultdict(int) while ( - sum(job_db.count_by_status(statuses=["not_started", "created", "queued", "running"]).values()) > 0 + sum( + job_db.count_by_status( + statuses=["not_started", "created", "queued", "queued_for_start", "running"] + ).values() + ) + > 0 and not self._stop_thread ): - self._job_update_loop(job_db=job_db, start_job=start_job) + self._job_update_loop(job_db=job_db, start_job=start_job, stats=stats) stats["run_jobs loop"] += 1 + # Show current stats and sleep _log.info(f"Job status histogram: {job_db.count_by_status()}. Run stats: {dict(stats)}") - # Do sequence of micro-sleeps to allow for quick thread exit for _ in range(int(max(1, self.poll_sleep))): time.sleep(1) if self._stop_thread: @@ -391,6 +404,8 @@ def stop_job_thread(self, timeout_seconds: Optional[float] = _UNSET): .. versionadded:: 0.32.0 """ + self._worker_pool.shutdown() + if self._thread is not None: self._stop_thread = True if timeout_seconds is _UNSET: @@ -493,7 +508,16 @@ def run_jobs( # TODO: support user-provided `stats` stats = collections.defaultdict(int) - while sum(job_db.count_by_status(statuses=["not_started", "created", "queued", "running"]).values()) > 0: + self._worker_pool = _JobManagerWorkerThreadPool() + + while ( + sum( + job_db.count_by_status( + statuses=["not_started", "created", "queued_for_start", "queued", "running"] + ).values() + ) + > 0 + ): self._job_update_loop(job_db=job_db, start_job=start_job, stats=stats) stats["run_jobs loop"] += 1 @@ -502,6 +526,10 @@ def run_jobs( time.sleep(self.poll_sleep) stats["sleep"] += 1 + + # TODO; run post process after shutdown once more to ensure completion? 
+ self._worker_pool.shutdown() + return stats def _job_update_loop( @@ -524,7 +552,7 @@ def _job_update_loop( not_started = job_db.get_by_status(statuses=["not_started"], max=200).copy() if len(not_started) > 0: # Check number of jobs running at each backend - running = job_db.get_by_status(statuses=["created", "queued", "running"]) + running = job_db.get_by_status(statuses=["created", "queued", "queued_for_start", "running"]) stats["job_db get_by_status"] += 1 per_backend = running.groupby("backend_name").size().to_dict() _log.info(f"Running per backend: {per_backend}") @@ -541,7 +569,9 @@ def _job_update_loop( stats["job_db persist"] += 1 total_added += 1 - # Act on jobs + self._process_threadworker_updates(self._worker_pool, job_db, stats) + + # TODO: move this back closer to the `_track_statuses` call above, once job done/error handling is also handled in threads? for job, row in jobs_done: self.on_job_done(job, row) @@ -551,7 +581,6 @@ def _job_update_loop( for job, row in jobs_cancel: self.on_job_cancel(job, row) - def _launch_job(self, start_job, df, i, backend_name, stats: Optional[dict] = None): """Helper method for launching jobs @@ -598,6 +627,7 @@ def _launch_job(self, start_job, df, i, backend_name, stats: Optional[dict] = No df.loc[i, "start_time"] = rfc3339.utcnow() if job: df.loc[i, "id"] = job.job_id + _log.info(f"Job created: {job.job_id}") with ignore_connection_errors(context="get status"): status = job.status() stats["job get status"] += 1 @@ -605,19 +635,84 @@ def _launch_job(self, start_job, df, i, backend_name, stats: Optional[dict] = No if status == "created": # start job if not yet done by callback try: - job.start() - stats["job start"] += 1 - df.loc[i, "status"] = job.status() - stats["job get status"] += 1 + job_con = job.connection + task = _JobStartTask( + root_url=job_con.root_url, + bearer_token=job_con.auth.bearer if isinstance(job_con.auth, BearerAuth) else None, + job_id=job.job_id, + ) + _log.info(f"Submitting task {task} to thread pool") + self._worker_pool.submit_task(task) + + stats["job_queued_for_start"] += 1 + df.loc[i, "status"] = "queued_for_start" except OpenEoApiError as e: - _log.error(e) - df.loc[i, "status"] = "start_failed" - stats["job start error"] += 1 + _log.info(f"Failed submitting task {task} to thread pool with error: {e}") + df.loc[i, "status"] = "queued_for_start_failed" + stats["job queued for start failed"] += 1 else: # TODO: what is this "skipping" about actually? df.loc[i, "status"] = "skipped" stats["start_job skipped"] += 1 + def _process_threadworker_updates( + self, + worker_pool: _JobManagerWorkerThreadPool, + job_db: JobDatabaseInterface, + stats: dict + ) -> None: + """Processes asynchronous job updates from worker threads and applies them to the job database and statistics. + + This wrapper function is responsible for: + 1. Collecting completed results from the worker thread pool + 2. applying database updates for each job result + 3. applying statistics updates + 4. Handles errors with comprehensive logging + + :param worker_pool: + Thread pool instance managing the asynchronous job operations. + Should provide a `process_futures()` method returning completed job results. + + :param job_db: + Job database implementing the :py:class:`JobDatabaseInterface` interface. + Used to persist job status updates and metadata. + Must support the `_update_row(job_id: str, updates: dict)` method. + + :param stats: + Dictionary tracking operational statistics that will be updated in-place. 
+            Expected to handle string keys with integer values.
+            Statistics will be updated with counts from completed job results.
+
+        :return:
+            None: all updates are applied in-place to the ``job_db`` and ``stats`` parameters.
+        """
+        results = worker_pool.process_futures()
+        stats_updates = collections.defaultdict(int)
+
+        for result in results:
+            try:
+                # Handle job database updates
+                if result.db_update:
+                    _log.debug(f"Processing update for job {result.job_id}")
+                    job_db._update_row(job_id=result.job_id, updates=result.db_update)
+
+                # Aggregate statistics updates
+                if result.stats_update:
+                    for key, count in result.stats_update.items():
+                        stats_updates[key] += int(count)
+            except Exception as e:
+                _log.error(f"Failed to apply worker updates for job {result.job_id}: {e}")
+
+        # Apply all stat updates
+        for key, count in stats_updates.items():
+            stats[key] = stats.get(key, 0) + count
+
     def on_job_done(self, job: BatchJob, row):
         """
         Handles jobs that have finished. Can be overridden to provide custom behaviour.
@@ -673,7 +768,7 @@ def _cancel_prolonged_job(self, job: BatchJob, row):
         try:
             # Ensure running start time is valid
             job_running_start_time = rfc3339.parse_datetime(row.get("running_start_time"), with_timezone=True)
-
+
             # Parse the current time into a datetime object with timezone info
             current_time = rfc3339.parse_datetime(rfc3339.utcnow(), with_timezone=True)
@@ -681,12 +776,11 @@
             elapsed = current_time - job_running_start_time

             if elapsed > self._cancel_running_job_after:
-
                 _log.info(
                     f"Cancelling long-running job {job.job_id} (after {elapsed}, running since {job_running_start_time})"
                 )
                 job.stop()
-
+
         except Exception as e:
             _log.error(f"Unexpected error while handling job {job.job_id}: {e}")
@@ -715,7 +809,7 @@ def _track_statuses(self, job_db: JobDatabaseInterface, stats: Optional[dict] =
         """
         stats = stats if stats is not None else collections.defaultdict(int)

-        active = job_db.get_by_status(statuses=["created", "queued", "running"]).copy()
+        active = job_db.get_by_status(statuses=["created", "queued", "queued_for_start", "running"]).copy()

         jobs_done = []
         jobs_error = []
@@ -749,7 +843,7 @@ def _track_statuses(self, job_db: JobDatabaseInterface, stats: Optional[dict] =
                         stats["job canceled"] += 1
                         jobs_cancel.append((the_job, active.loc[i]))

-                    if previous_status in {"created", "queued"} and new_status == "running":
+                    if previous_status in {"created", "queued", "queued_for_start"} and new_status == "running":
                         stats["job started running"] += 1
                         active.loc[i, "running_start_time"] = rfc3339.utcnow()
@@ -782,7 +876,6 @@ def _track_statuses(self, job_db: JobDatabaseInterface, stats: Optional[dict] =

         return jobs_done, jobs_error, jobs_cancel

-
 def _format_usage_stat(job_metadata: dict, field: str) -> str:
     value = deep_get(job_metadata, "usage", field, "value", default=0)
     unit = deep_get(job_metadata, "usage", field, "unit", default="")
@@ -877,6 +970,55 @@ def _merge_into_df(self, df: pd.DataFrame):
         else:
             self._df = df

+    def _update_row(self, job_id: str, updates: dict):
+        """
+        Apply the updates given as a dictionary to the dataframe row corresponding to the given job_id.
+
+        :param job_id: id of the job whose row should be updated.
+        :param updates: dictionary mapping column names to new values (e.g. a status update).
+
+        :return: None; the update is applied to the in-memory dataframe and persisted.
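+
+        A minimal usage sketch (the job id and column values below are hypothetical,
+        mirroring the accompanying tests):
+
+        .. code-block:: python
+
+            # Hypothetical example: set a new status and cost for an existing row.
+            db._update_row(job_id="job-123", updates={"status": "queued", "costs": 42.5})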
+ """ + if self._df is None: + raise ValueError("Job database not initialized") + + # Create boolean mask for target row + mask = self._df["id"] == job_id + match_count = mask.sum() + + # Handle row identification issues + #TODO: make this more robust, e.g. falling back on the row index? + if match_count == 0: + _log.error(f"Job {job_id!r} not found in database") + return + if match_count > 1: + _log.error(f"Duplicate job ID {job_id!r} found in database") + return + + # Get valid columns + valid_columns = set(self._df.columns) + filtered_updates = {} + + # Validate update keys s + for key, value in updates.items(): + if key in valid_columns: + filtered_updates[key] = value + else: + _log.warning(f"Ignoring invalid column {key!r} in update for job {job_id}") + + # Bulk update + if not filtered_updates: + return + try: + # Update all columns in a single operation + self._df.loc[mask, list(filtered_updates.keys())] = list(filtered_updates.values()) + self.persist(self._df) + except Exception as e: + _log.error(f"Failed to persist row update for job {job_id}: {e}") + + + + class CsvJobDatabase(FullDataFrameJobDatabase): """ @@ -932,6 +1074,8 @@ def persist(self, df: pd.DataFrame): self.path.parent.mkdir(parents=True, exist_ok=True) self.df.to_csv(self.path, index=False) + + class ParquetJobDatabase(FullDataFrameJobDatabase): """ diff --git a/openeo/extra/job_management/_thread_worker.py b/openeo/extra/job_management/_thread_worker.py new file mode 100644 index 000000000..8b26e88c3 --- /dev/null +++ b/openeo/extra/job_management/_thread_worker.py @@ -0,0 +1,193 @@ +import concurrent.futures +import logging +from dataclasses import dataclass, field +from typing import Optional, Any, List, Dict, Tuple +import openeo +from abc import ABC, abstractmethod + + +_log = logging.getLogger(__name__) + +@dataclass +class _TaskResult: + """ + Container for the result of a task execution. + Used to communicate the outcome of job-related tasks. + + :param job_id: + The ID of the job this result is associated with. + + :param db_update: + Optional dictionary describing updates to apply to a job database, + such as status changes. Defaults to an empty dict. + + :param stats_update: + Optional dictionary capturing statistical counters or metrics, + e.g., number of successful starts or errors. Defaults to an empty dict. + """ + job_id: str # Mandatory + db_update: Dict[str, Any] = field(default_factory=dict) # Optional + stats_update: Dict[str, int] = field(default_factory=dict) # Optional + +class Task(ABC): + """ + Abstract base class for asynchronous tasks. + + A task encapsulates a unit of work, typically executed asynchronously, + and returns a `_TaskResult` with job-related metadata and updates. + + Implementations must override the `execute` method to define the task logic. + """ + + @abstractmethod + def execute(self) -> _TaskResult: + """Execute the task and return a raw result""" + pass + +@dataclass +class _JobStartTask(Task): + """ + Task for starting a backend job asynchronously. + + Connects to an OpenEO backend using the provided URL and optional token, + retrieves the specified job, and attempts to start it. + + Usage example: + + .. code-block:: python + + task = _JobStartTask( + job_id="1234", + root_url="https://openeo.test", + bearer_token="secret" + ) + result = task.execute() + + :param job_id: + Identifier of the job to start on the backend. + + :param root_url: + The root URL of the OpenEO backend to connect to. + + :param bearer_token: + Optional Bearer token used for authentication. 
+
+    :raises ValueError:
+        If any of the input parameters are invalid (e.g., empty strings).
+    """
+
+    job_id: str
+    root_url: str
+    bearer_token: Optional[str]
+
+    def __post_init__(self) -> None:
+        # Validate inputs at construction time, so that obvious problems surface
+        # when the task is created instead of later in a worker thread.
+        if not isinstance(self.root_url, str) or not self.root_url.strip():
+            raise ValueError(f"root_url must be a non-empty string, got {self.root_url!r}")
+        if self.bearer_token is not None and (not isinstance(self.bearer_token, str) or not self.bearer_token.strip()):
+            raise ValueError(f"bearer_token must be a non-empty string or None, got {self.bearer_token!r}")
+        if not isinstance(self.job_id, str) or not self.job_id.strip():
+            raise ValueError(f"job_id must be a non-empty string, got {self.job_id!r}")
+
+    def execute(self) -> _TaskResult:
+        """
+        Executes the job start process using the openEO connection.
+
+        Authenticates if a bearer token is provided, retrieves the job by ID,
+        and attempts to start it.
+
+        :returns:
+            A `_TaskResult` with status and statistics metadata, indicating
+            success or failure of the job start.
+        """
+        try:
+            conn = openeo.connect(self.root_url)
+            if self.bearer_token:
+                conn.authenticate_bearer_token(self.bearer_token)
+            job = conn.job(self.job_id)
+            job.start()
+            _log.info(f"Job {self.job_id} started successfully")
+            return _TaskResult(
+                job_id=self.job_id,
+                db_update={"status": "queued"},
+                stats_update={"job start": 1},
+            )
+        except Exception as e:
+            _log.error(f"Failed to start job {self.job_id}: {e}")
+            return _TaskResult(
+                job_id=self.job_id,
+                db_update={"status": "start_failed"},
+                stats_update={"start_job error": 1},
+            )
+
+
+class _JobManagerWorkerThreadPool:
+    """
+    Thread pool-based worker that manages the execution of asynchronous tasks.
+
+    Internally wraps a `ThreadPoolExecutor` and manages submission,
+    tracking, and result processing of tasks.
+
+    :param max_workers:
+        Maximum number of concurrent threads to use for execution.
+        Defaults to 2.
+    """
+
+    def __init__(self, max_workers: int = 2):
+        self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)
+        self._future_task_pairs: List[Tuple[concurrent.futures.Future, Task]] = []
+
+    def submit_task(self, task: Task) -> None:
+        """
+        Submit a task to the thread pool executor.
+
+        Tasks are scheduled for asynchronous execution and tracked
+        internally to allow later processing of their results.
+
+        :param task:
+            An instance of `Task` to be executed.
+        """
+        future = self._executor.submit(task.execute)
+        self._future_task_pairs.append((future, task))  # Track pairs
+
+    def process_futures(self) -> List[_TaskResult]:
+        """
+        Process and retrieve results from completed tasks.
+
+        This method checks which futures have finished without blocking
+        and collects their results.
+
+        :returns:
+            A list of `_TaskResult` objects from completed tasks.
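+
+        A minimal polling sketch (``task`` stands for any ``Task`` implementation,
+        e.g. a ``_JobStartTask``):
+
+        .. code-block:: python
+
+            pool = _JobManagerWorkerThreadPool()
+            pool.submit_task(task)
+            # Non-blocking: only returns results of tasks that have finished so far;
+            # unfinished tasks stay tracked for a later call.
+            for result in pool.process_futures():
+                print(result.job_id, result.db_update, result.stats_update)
+            pool.shutdown()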
+ """ + results = [] + to_keep = [] + + # Use timeout=0 to avoid blocking and check for completed futures + done, _ = concurrent.futures.wait( + [f for f, _ in self._future_task_pairs], timeout=0, + return_when=concurrent.futures.FIRST_COMPLETED + ) + + # Process completed futures and their tasks + for future, task in self._future_task_pairs: + if future in done: + try: + result = future.result() + + except Exception as e: + + _log.exception(f"Error processing task: {e}") + result = _TaskResult( + job_id=task.job_id, + db_update={"status": "start_failed"}, + stats_update={"start_job error": 1}) + + results.append(result) + else: + to_keep.append((future, task)) + + self._future_task_pairs = to_keep + return results + + def shutdown(self) -> None: + """Shuts down the thread pool gracefully.""" + _log.info("Shutting down thread pool") + self._executor.shutdown(wait=True) diff --git a/scl_native.tif b/scl_native.tif new file mode 100644 index 000000000..6f9264a20 Binary files /dev/null and b/scl_native.tif differ diff --git a/scl_resampled.tif b/scl_resampled.tif new file mode 100644 index 000000000..a3e26fd12 Binary files /dev/null and b/scl_resampled.tif differ diff --git a/tests/extra/job_management/test_job_management.py b/tests/extra/job_management/test_job_management.py index 46108f0fa..fbb41c1b6 100644 --- a/tests/extra/job_management/test_job_management.py +++ b/tests/extra/job_management/test_job_management.py @@ -6,7 +6,7 @@ import threading from pathlib import Path from time import sleep -from typing import Callable, Union +from typing import Union from unittest import mock import dirty_equals @@ -25,6 +25,9 @@ import pytest import requests import shapely.geometry +import collections +import time + import openeo import openeo.extra.job_management @@ -38,10 +41,17 @@ create_job_db, get_job_db, ) + from openeo.rest._testing import OPENEO_BACKEND, DummyBackend, build_capabilities from openeo.util import rfc3339 from openeo.utils.version import ComparableVersion +from openeo.extra.job_management._thread_worker import ( + Task, + _TaskResult, + _JobManagerWorkerThreadPool, +) + @pytest.fixture def con(requests_mock) -> openeo.Connection: @@ -80,6 +90,26 @@ def sleep_mock(): with mock.patch("time.sleep") as sleep: yield sleep +class DummyTask(Task): + """ + A Task that simply sleeps and then returns a predetermined _TaskResult. 
+ """ + def __init__(self, job_id, db_update, stats_update, delay=0.0): + self.job_id = job_id + self._db_update = db_update or {} + self._stats_update = stats_update or {} + self._delay = delay + + def execute(self) -> _TaskResult: + if self._delay: + time.sleep(self._delay) + return _TaskResult( + job_id=self.job_id, + db_update=self._db_update, + stats_update=self._stats_update, + ) + + class TestMultiBackendJobManager: @@ -93,6 +123,7 @@ def job_manager(self, job_manager_root_dir, dummy_backend_foo, dummy_backend_bar manager.add_backend("foo", connection=dummy_backend_foo.connection) manager.add_backend("bar", connection=dummy_backend_bar.connection) return manager + @staticmethod def _create_year_job(row, connection, **kwargs): @@ -466,6 +497,7 @@ def start_job(row, connection_provider, connection, **kwargs): ("job-2018", "finished", "foo"), ] + @httpretty.activate(allow_net_connect=False, verbose=True) @pytest.mark.parametrize("http_error_status", [502, 503, 504]) def test_resilient_backend_reports_error_when_max_retries_exceeded(self, tmp_path, http_error_status, sleep_mock): @@ -590,14 +622,13 @@ def get_status(job_id, current_status): time_machine.move_to(create_time) job_db_path = tmp_path / "jobs.csv" + # Mock sleep() to not actually sleep, but skip one hour at a time with mock.patch.object(openeo.extra.job_management.time, "sleep", new=lambda s: time_machine.shift(60 * 60)): job_manager.run_jobs(df=df, start_job=self._create_year_job, job_db=job_db_path) final_df = CsvJobDatabase(job_db_path).read() - assert final_df.iloc[0].to_dict() == dirty_equals.IsPartialDict( - id="job-2024", status=expected_status, running_start_time="2024-09-01T10:00:00Z" - ) + assert dirty_equals.IsPartialDict(id="job-2024", status=expected_status) == final_df.iloc[0].to_dict() assert dummy_backend_foo.batch_jobs == { "job-2024": { @@ -644,10 +675,11 @@ def test_status_logging(self, tmp_path, job_manager, job_manager_root_dir, sleep run_stats = job_manager.run_jobs(job_db=job_db, start_job=self._create_year_job) assert run_stats == dirty_equals.IsPartialDict({"start_job call": 5, "job finished": 5}) - needle = re.compile(r"Job status histogram:.*'queued': 4.*Run stats:.*'start_job call': 4") + needle = re.compile(r"Job status histogram:.*'finished': 5.*Run stats:.*'job_queued_for_start': 5") assert needle.search(caplog.text) + @pytest.mark.parametrize( ["create_time", "start_time", "running_start_time", "end_time", "end_status", "cancel_after_seconds"], [ @@ -720,6 +752,49 @@ def get_status(job_id, current_status): filled_running_start_time = final_df.iloc[0]["running_start_time"] assert isinstance(rfc3339.parse_datetime(filled_running_start_time), datetime.datetime) + + + + def test_process_threadworker_updates(self, job_manager, tmp_path): + + csv_path = tmp_path / "jobs.csv" + df = pd.DataFrame([ + {"id": "job-1", "status": "created"}, + {"id": "job-2", "status": "created"}, + ]) + job_db = CsvJobDatabase(csv_path).initialize_from_df(df) + + pool = _JobManagerWorkerThreadPool(max_workers=2) + + # Submit two dummy tasks with different delays and updates + t1 = DummyTask("job-1", {"status": "done"}, {"a": 1}, delay=0.05) + t2 = DummyTask("job-2", {"status": "failed"}, {"b": 2}, delay=0.1) + pool.submit_task(t1) + pool.submit_task(t2) + + # Wait for all futures to be done + # We access the internal list of (future, task) pairs to check .done() + start = time.time() + timeout = 2.0 + while time.time() - start < timeout: + pairs = list(pool._future_task_pairs) + if all(future.done() for future, _ in 
pairs): + break + time.sleep(0.01) + else: + pytest.skip("Tasks did not complete within timeout") + + # Now invoke the real update loop + stats = collections.defaultdict(int) + job_manager._process_threadworker_updates(pool, job_db, stats) + + # Check that the in-memory database was updated + df = job_db.df + assert df.loc[df.id == "job-1", "status"].iloc[0] == "done" + assert df.loc[df.id == "job-2", "status"].iloc[0] == "failed" + + # And that our stats were aggregated + assert stats == {"a": 1, "b": 2} JOB_DB_DF_BASICS = pd.DataFrame( @@ -835,6 +910,62 @@ def test_count_by_status(self, tmp_path, db_class): "running": 2, } + @pytest.mark.parametrize("db_class", [CsvJobDatabase, ParquetJobDatabase]) + def test_update_existing_row(self, tmp_path, db_class): + path = tmp_path / "jobs.db" + df = pd.DataFrame({"id": ["job-123"], "status": ["created"], "costs": [0.0]}) + db = db_class(path).initialize_from_df(df) + + db._update_row("job-123", {"status": "queued", "costs": 42.5}) + updated = db.read() + + assert updated.loc[0, "status"] == "queued" + assert updated.loc[0, "costs"] == 42.5 + + @pytest.mark.parametrize("db_class", [CsvJobDatabase, ParquetJobDatabase]) + def test_update_unknown_job_id(self, tmp_path, db_class, caplog): + path = tmp_path / "jobs.db" + df = pd.DataFrame({"id": ["job-123"], "status": ["created"]}) + db = db_class(path).initialize_from_df(df) + + db._update_row("nonexistent-job", {"status": "queued"}) + + assert "not found in database" in caplog.text + # Ensure no updates happened + assert db.read().loc[0, "status"] == "created" + + @pytest.mark.parametrize("db_class", [CsvJobDatabase, ParquetJobDatabase]) + def test_update_duplicate_job_id(self, tmp_path, db_class, caplog): + path = tmp_path / "jobs.db" + df = pd.DataFrame({"id": ["job-123", "job-123"], "status": ["created", "created"]}) + db = db_class(path).initialize_from_df(df) + + db._update_row("job-123", {"status": "queued"}) + + assert "Duplicate job ID" in caplog.text + assert set(db.read()["status"]) == {"created"} + + @pytest.mark.parametrize("db_class", [CsvJobDatabase, ParquetJobDatabase]) + def test_update_with_invalid_column(self, tmp_path, db_class, caplog): + path = tmp_path / "jobs.db" + df = pd.DataFrame({"id": ["job-123"], "status": ["created"]}) + db = db_class(path).initialize_from_df(df) + + db._update_row("job-123", {"not_a_column": "value", "status": "finished"}) + + assert "Ignoring invalid column 'not_a_column'" in caplog.text + assert db.read().loc[0, "status"] == "finished" + + @pytest.mark.parametrize("db_class", [CsvJobDatabase, ParquetJobDatabase]) + def test_update_with_no_valid_fields(self, tmp_path, db_class): + path = tmp_path / "jobs.db" + df = pd.DataFrame({"id": ["job-123"], "status": ["created"]}) + db = db_class(path).initialize_from_df(df) + + db._update_row("job-123", {"invalid_field": "value"}) + + assert db.read().loc[0, "status"] == "created" + class TestCsvJobDatabase: @@ -1761,3 +1892,4 @@ def test_with_job_manager_parameter_column_map( "description": "Process 'increment' (namespace https://remote.test/increment.json) with {'data': 5, 'increment': 200}", }, } + diff --git a/tests/extra/job_management/test_thread_worker.py b/tests/extra/job_management/test_thread_worker.py new file mode 100644 index 000000000..3c531edb2 --- /dev/null +++ b/tests/extra/job_management/test_thread_worker.py @@ -0,0 +1,170 @@ +import time +import pytest +import pandas as pd +import requests + +# Import the refactored classes and helper functions from your codebase. 
+# Adjust the import paths as needed. +from openeo.extra.job_management._thread_worker import ( + _JobManagerWorkerThreadPool, + _JobStartTask) + +# --- Fixtures and Helpers --- + +@pytest.fixture +def worker_pool(): + """Fixture for creating and cleaning up a worker thread pool.""" + pool = _JobManagerWorkerThreadPool(max_workers=2) + yield pool + pool.shutdown() + +@pytest.fixture +def sample_dataframe(): + """Creates a pandas DataFrame for job tracking.""" + df = pd.DataFrame([ + {"id": "job-123", "status": "queued_for_start", "other_field": "foo"}, + {"id": "job-456", "status": "queued_for_start", "other_field": "bar"}, + {"id": "job-789", "status": "other", "other_field": "baz"} + ]) + return df + +@pytest.fixture +def initial_stats(): + """Returns a dictionary with initial stats counters.""" + return {"job start": 0, "job start failed": 0} + +@pytest.fixture +def successful_backend_mock(requests_mock): + """ + Returns a helper to set up a successful backend. + Mocks a version check, job start, and job status check. + """ + def _setup(root_url: str, job_id: str, status: str = "queued"): + # Backend version check + requests_mock.get(root_url, json={"api_version": "1.1.0"}) + # Job start: assume that the job start endpoint returns a JSON response (simulate the backend behavior) + requests_mock.post(f"{root_url}/jobs/{job_id}/results", json={"job_id": job_id, "status": status}, status_code=202) + # Job status check + requests_mock.get(f"{root_url}/jobs/{job_id}", json={"job_id": job_id, "status": status}) + return _setup + +@pytest.fixture +def valid_task(): + """Fixture to create a valid _JobStartTask instance.""" + return _JobStartTask( + root_url="https://foo.test", + bearer_token="test-token", + job_id="test-job-123" + ) + +import time + +def wait_for_results(worker_pool, timeout=3.0, interval=0.1): + """ + Wait for the worker pool to return results, with timeout safety. + Raises: + TimeoutError if no results are available within timeout. + """ + start = time.time() + while time.time() - start < timeout: + results = worker_pool.process_futures() + if results: + return results + time.sleep(interval) + raise TimeoutError(f"Timed out after {timeout}s waiting for worker pool results.") + +# --- Tests for the Worker Thread Pool and Futures Postprocessing --- + +class TestJobManagerWorkerThreadPool: + def test_worker_thread_lifecycle(self, worker_pool): + """Test that the worker thread pool starts and shuts down as expected.""" + + # Before shutdown, the executor should be active + assert not worker_pool._executor._shutdown + worker_pool.shutdown() + assert worker_pool._executor._shutdown + + + def test_submit_and_process_successful_task( + self, worker_pool, valid_task, successful_backend_mock, requests_mock + ): + """Test successful submission and processing of a task.""" + # Setup successful backend responses for the valid task. 
+ successful_backend_mock(valid_task.root_url, valid_task.job_id) + worker_pool.submit_task(valid_task) + + # Wait for the task to complete + results = wait_for_results(worker_pool) + + # Unpack and assert + for result in results: + # Check that we updated the DB to "queued" + assert result.db_update == {"status": "queued"} + + # Check that the stats_update reflects one "job start" + assert result.stats_update == {"job start": 1} + + + def test_network_failure_in_task(self, worker_pool, valid_task, requests_mock): + """Test that a task encountering a network failure returns a failed result.""" + # Simulate a connection error + requests_mock.get(valid_task.root_url, + exc=requests.exceptions.ConnectionError("Backend down")) + worker_pool.submit_task(valid_task) + + results = wait_for_results(worker_pool) + + for result in results: + # On failure we set status to "start_failed" + assert result.db_update == {"status": "start_failed"} + + # And we increment the "start_job error" counter + assert result.stats_update == {"start_job error": 1} + + + def test_mixed_success_and_failure_tasks( + self, worker_pool, requests_mock, successful_backend_mock + ): + """Test processing multiple tasks with mixed outcomes.""" + # Success case + task_success = _JobStartTask( + root_url="https://foo.test", + bearer_token="token", + job_id="job-ok" + ) + successful_backend_mock(task_success.root_url, task_success.job_id) + + # Failure case + task_fail = _JobStartTask( + root_url="https://bar.test", + bearer_token="token", + job_id="job-fail" + ) + requests_mock.get(task_fail.root_url, + exc=requests.exceptions.ConnectionError("Network error")) + + worker_pool.submit_task(task_success) + worker_pool.submit_task(task_fail) + + results = wait_for_results(worker_pool) + + # Verify each outcome by job_id + for result in results: + if result.job_id == "job-ok": + assert result.db_update == {"status": "queued"} + assert result.stats_update == {"job start": 1} + elif result.job_id == "job-fail": + assert result.db_update == {"status": "start_failed"} + assert result.stats_update == {"start_job error": 1} + else: + pytest.skip(f"Unexpected task {result.job_id}") + + def test_worker_pool_bookkeeping(self, worker_pool, valid_task, successful_backend_mock, requests_mock): + """Ensure that processed futures are removed from the pool's internal tracking.""" + successful_backend_mock(valid_task.root_url, valid_task.job_id) + worker_pool.submit_task(valid_task) + results = wait_for_results(worker_pool) + + # Assuming your refactoring clears out futures after processing, + # the internal list (or maps) should be empty. + assert len(worker_pool._future_task_pairs) == 0 \ No newline at end of file