From 54332449890a23b1d8749fdeef2a92571b9db30a Mon Sep 17 00:00:00 2001 From: Jonathan Karlsen Date: Thu, 18 Apr 2024 15:39:18 +0200 Subject: [PATCH] init --- src/ert/cli/main.py | 36 +++- src/ert/cli/monitor.py | 23 +- src/ert/ensemble_evaluator/__init__.py | 3 +- .../ensemble_evaluator/evaluator_tracker.py | 201 +++++++++++++++++- src/ert/ensemble_evaluator/monitor.py | 33 ++- 5 files changed, 281 insertions(+), 15 deletions(-) diff --git a/src/ert/cli/main.py b/src/ert/cli/main.py index 97a98386c1d..6e3a03870da 100644 --- a/src/ert/cli/main.py +++ b/src/ert/cli/main.py @@ -1,10 +1,12 @@ #!/usr/bin/env python +import asyncio import contextlib import logging import os import sys from typing import Any, TextIO +from _ert.async_utils import new_event_loop from _ert.threading import ErtThread from ert.cli import ( ENSEMBLE_EXPERIMENT_MODE, @@ -19,7 +21,10 @@ from ert.cli.workflow import execute_workflow from ert.config import ErtConfig, QueueSystem from ert.enkf_main import EnKFMain -from ert.ensemble_evaluator import EvaluatorServerConfig, EvaluatorTracker +from ert.ensemble_evaluator import ( + AsyncEvaluatorTracker, + EvaluatorServerConfig, +) from ert.namespace import Namespace from ert.storage import open_storage from ert.storage.local_storage import local_storage_set_ert_config @@ -108,9 +113,8 @@ def run_cli(args: Namespace, _: Any = None) -> None: ) thread.start() - tracker = EvaluatorTracker( - model, ee_con_info=evaluator_server_config.get_connection_info() - ) + loop = new_event_loop() + asyncio.set_event_loop(loop) with contextlib.ExitStack() as exit_stack: out: TextIO @@ -121,15 +125,25 @@ def run_cli(args: Namespace, _: Any = None) -> None: else: out = sys.stderr monitor = Monitor(out=out, color_always=args.color_always) - - try: - monitor.monitor(tracker.track()) - except (SystemExit, KeyboardInterrupt, OSError): - # _base_service.py translates CTRL-c to OSError - print("\nKilling simulations...") - tracker.request_termination() + loop.run_until_complete(async_tracking(monitor, model, evaluator_server_config)) thread.join() storage.close() model.reraise_exception(ErtCliError) + + +async def async_tracking( + monitor: Monitor, model, evaluator_server_config: EvaluatorServerConfig +): + tracker = AsyncEvaluatorTracker( + model, ee_con_info=evaluator_server_config.get_connection_info() + ) + try: + await asyncio.gather( + tracker._drain_monitor(), monitor.async_monitor(tracker.track()) + ) + except (SystemExit, KeyboardInterrupt, OSError): + # _base_service.py translates CTRL-c to OSError + print("\nKilling simulations...") + await tracker.request_termination() diff --git a/src/ert/cli/monitor.py b/src/ert/cli/monitor.py index f051870079e..9f450c02335 100644 --- a/src/ert/cli/monitor.py +++ b/src/ert/cli/monitor.py @@ -1,7 +1,8 @@ # -*- coding: utf-8 -*- +import asyncio import sys from datetime import datetime, timedelta -from typing import Dict, Iterator, Optional, TextIO, Tuple, Union +from typing import AsyncIterator, Dict, Iterator, Optional, TextIO, Tuple, Union from tqdm import tqdm @@ -77,6 +78,26 @@ def monitor( self._print_job_errors() return + async def async_monitor( + self, + events: AsyncIterator[Union[FullSnapshotEvent, SnapshotUpdateEvent, EndEvent]], + ) -> None: + self._start_time = datetime.now() + async for event in events: + if isinstance(event, FullSnapshotEvent): + if event.snapshot is not None: + self._snapshots[event.iteration] = event.snapshot + self._progress = event.progress + elif isinstance(event, SnapshotUpdateEvent): + if event.partial_snapshot is 
not None: + self._snapshots[event.iteration].merge_event(event.partial_snapshot) + self._print_progress(event) + if isinstance(event, EndEvent): + self._print_result(event.failed, event.failed_msg) + self._print_job_errors() + return + await asyncio.sleep(0) + def _print_job_errors(self) -> None: failed_jobs: Dict[Optional[str], int] = {} for snapshot in self._snapshots.values(): diff --git a/src/ert/ensemble_evaluator/__init__.py b/src/ert/ensemble_evaluator/__init__.py index ea093217a94..35f64cb9f1c 100644 --- a/src/ert/ensemble_evaluator/__init__.py +++ b/src/ert/ensemble_evaluator/__init__.py @@ -7,12 +7,13 @@ ) from .config import EvaluatorServerConfig from .evaluator import EnsembleEvaluator -from .evaluator_tracker import EvaluatorTracker +from .evaluator_tracker import AsyncEvaluatorTracker, EvaluatorTracker from .event import EndEvent, FullSnapshotEvent, SnapshotUpdateEvent from .monitor import Monitor from .snapshot import PartialSnapshot, Snapshot __all__ = ( + "AsyncEvaluatorTracker", "EndEvent", "Ensemble", "EnsembleBuilder", diff --git a/src/ert/ensemble_evaluator/evaluator_tracker.py b/src/ert/ensemble_evaluator/evaluator_tracker.py index 56369724efb..34b66027cfd 100644 --- a/src/ert/ensemble_evaluator/evaluator_tracker.py +++ b/src/ert/ensemble_evaluator/evaluator_tracker.py @@ -4,7 +4,7 @@ import logging import queue import time -from typing import TYPE_CHECKING, Dict, Iterator, Union +from typing import TYPE_CHECKING, AsyncIterator, Dict, Iterator, Union from aiohttp import ClientError from websockets.exceptions import ConnectionClosedError @@ -238,3 +238,202 @@ def request_termination(self) -> None: while self._drainer_thread.is_alive(): self._clear_work_queue() time.sleep(1) + + +class AsyncEvaluatorTracker: + DONE = "done" + + def __init__( + self, + model: "BaseRunModel", + ee_con_info: EvaluatorConnectionInfo, + next_ensemble_evaluator_wait_time: int = 5, + ): + self._model = model + self._ee_con_info = ee_con_info + self._next_ensemble_evaluator_wait_time = next_ensemble_evaluator_wait_time + self._iter_snapshot: Dict[int, Snapshot] = {} + + self._work_queue: asyncio.Queue[Union[str, CloudEvent]] = asyncio.Queue() + + self.done_flag = False + self.drainer_started = asyncio.Event() + + async def _drain_monitor(self) -> None: + drainer_logger = logging.getLogger("ert.ensemble_evaluator.drainer") + while not self._model.isFinished(): + try: + drainer_logger.debug("connecting to new monitor...") + with Monitor(self._ee_con_info) as monitor: + drainer_logger.debug("connected") + self.drainer_started.set() + async for event in monitor.async_track(): + if event["type"] in ( + EVTYPE_EE_SNAPSHOT, + EVTYPE_EE_SNAPSHOT_UPDATE, + ): + await self._work_queue.put(event) + if event.data.get(STATUS) in [ + ENSEMBLE_STATE_STOPPED, + ENSEMBLE_STATE_FAILED, + ]: + drainer_logger.debug( + "observed evaluation stopped event, signal done" + ) + monitor.signal_done() + if event.data.get(STATUS) == ENSEMBLE_STATE_CANCELLED: + drainer_logger.debug( + "observed evaluation cancelled event, exit drainer" + ) + # Allow track() to emit an EndEvent. + await self._work_queue.put(EvaluatorTracker.DONE) + await self._work_queue.join() + return + elif event["type"] == EVTYPE_EE_TERMINATED: + drainer_logger.debug("got terminator event") + # This sleep needs to be there. 
Refer to issue #1250: `Authority + # on information about evaluations/experiments` + await asyncio.sleep(self._next_ensemble_evaluator_wait_time) + except (ConnectionRefusedError, ClientError) as e: + if not self._model.isFinished(): + drainer_logger.debug(f"connection refused: {e}") + except ConnectionClosedError as e: + # The monitor connection closed unexpectedly + drainer_logger.debug(f"connection closed error: {e}") + except BaseException: + drainer_logger.exception("unexpected error: ") + # We really don't know what happened... shut down + # the thread and get out of here. The monitor has + # been stopped by the ctx-mgr + await self._work_queue.put(EvaluatorTracker.DONE) + await self._work_queue.join() + return + drainer_logger.debug( + "observed that model was finished, waiting tasks completion..." + ) + # The model has finished, we indicate this by sending a DONE + await self._work_queue.put(EvaluatorTracker.DONE) + await self._work_queue.join() + drainer_logger.debug("tasks complete") + + async def track( + self, + ) -> AsyncIterator[Union[FullSnapshotEvent, SnapshotUpdateEvent, EndEvent]]: + await self.drainer_started.wait() + while True: + event = await self._work_queue.get() + if isinstance(event, str): + with contextlib.suppress(GeneratorExit): + # consumers may exit at this point, make sure the last + # task is marked as done + if event == EvaluatorTracker.DONE: + yield EndEvent( + failed=self._model.hasRunFailed(), + failed_msg=self._model.getFailMessage(), + ) + self._work_queue.task_done() + break + if event["type"] == EVTYPE_EE_SNAPSHOT: + iter_ = event.data["iter"] + snapshot = Snapshot(event.data) + self._iter_snapshot[iter_] = snapshot + yield FullSnapshotEvent( + phase_name=self._model.getPhaseName(), + current_phase=self._model.currentPhase(), + total_phases=self._model.phaseCount(), + indeterminate=self._model.isIndeterminate(), + progress=self._progress(), + iteration=iter_, + snapshot=copy.deepcopy(snapshot), + ) + elif event["type"] == EVTYPE_EE_SNAPSHOT_UPDATE: + iter_ = event.data["iter"] + if iter_ not in self._iter_snapshot: + raise OutOfOrderSnapshotUpdateException( + f"got {EVTYPE_EE_SNAPSHOT_UPDATE} without having stored " + f"snapshot for iter {iter_}" + ) + partial = PartialSnapshot(self._iter_snapshot[iter_]).from_cloudevent( + event + ) + self._iter_snapshot[iter_].merge_event(partial) + yield SnapshotUpdateEvent( + phase_name=self._model.getPhaseName(), + current_phase=self._model.currentPhase(), + total_phases=self._model.phaseCount(), + indeterminate=self._model.isIndeterminate(), + progress=self._progress(), + iteration=iter_, + partial_snapshot=partial, + ) + self._work_queue.task_done() + self.done_flag = True + + def is_finished(self) -> bool: + return self.done_flag + + def _progress(self) -> float: + """Fraction of completed iterations over total iterations""" + + if self.is_finished(): + return 1.0 + elif not self._iter_snapshot: + return 0.0 + else: + # Calculate completed realizations + current_iter = max(list(self._iter_snapshot.keys())) + done_reals = 0 + all_reals = self._iter_snapshot[current_iter].reals + if not all_reals: + # Empty ensemble or all realizations deactivated + return 1.0 + for real in all_reals.values(): + if real.status in [ + REALIZATION_STATE_FINISHED, + REALIZATION_STATE_FAILED, + ]: + done_reals += 1 + real_progress = float(done_reals) / len(all_reals) + + return ( + (current_iter + real_progress) / self._model.phaseCount() + if self._model.phaseCount() != 1 + else real_progress + ) + + def _clear_work_queue(self) 
-> None:
+        with contextlib.suppress(asyncio.QueueEmpty):
+            while True:
+                self._work_queue.get_nowait()
+                self._work_queue.task_done()
+
+    async def request_termination(self) -> None:
+        logger = logging.getLogger("ert.ensemble_evaluator.tracker")
+        # There might be some situations where the
+        # evaluation is finished or the evaluation
+        # is yet to start when calling this function.
+        # In these cases the monitor is not started
+        #
+        # To avoid waiting too long we exit if we are not
+        # able to connect to the monitor after 2 tries
+        #
+        # See issue: https://github.com/equinor/ert/issues/1250
+        #
+        try:
+            logger.debug("requesting termination...")
+            await wait_for_evaluator(
+                base_url=self._ee_con_info.url,
+                token=self._ee_con_info.token,
+                cert=self._ee_con_info.cert,
+                timeout=5,
+            )
+            logger.debug("requested termination")
+        except ClientError as e:
+            logger.warning(f"{__name__} - exception {e}")
+            return
+
+        with Monitor(self._ee_con_info) as monitor:
+            monitor.signal_cancel()
+        while not self.is_finished():
+            self._clear_work_queue()
+            await asyncio.sleep(1)
diff --git a/src/ert/ensemble_evaluator/monitor.py b/src/ert/ensemble_evaluator/monitor.py
index 16e7870a15c..09a9f47b27f 100644
--- a/src/ert/ensemble_evaluator/monitor.py
+++ b/src/ert/ensemble_evaluator/monitor.py
@@ -1,8 +1,15 @@
+import asyncio
 import logging
 import pickle
 import uuid
 from contextlib import ExitStack
-from typing import TYPE_CHECKING, Any, Generator, Optional
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    AsyncIterator,
+    Generator,
+    Optional,
+)
 
 from cloudevents.conversion import to_json
 from cloudevents.exceptions import DataUnmarshallerError
@@ -103,3 +110,27 @@ def track(self) -> Generator[CloudEvent, None, None]:
             if event["type"] == identifiers.EVTYPE_EE_TERMINATED:
                 logger.debug(f"monitor-{self._id} client received terminated")
                 break
+
+    async def async_track(self) -> AsyncIterator[CloudEvent]:
+        with ExitStack() as stack:
+            duplexer = self._ws_duplexer
+            if not duplexer:
+                duplexer = SyncWebsocketDuplexer(
+                    self._ee_con_info.client_uri,
+                    self._ee_con_info.url,
+                    self._ee_con_info.cert,
+                    self._ee_con_info.token,
+                )
+                stack.callback(duplexer.stop)
+            for message in duplexer.receive():
+                try:
+                    event = from_json(
+                        str(message), data_unmarshaller=evaluator_unmarshaller
+                    )
+                except DataUnmarshallerError:
+                    event = from_json(str(message), data_unmarshaller=pickle.loads)
+                await asyncio.sleep(0)
+                yield event
+                if event["type"] == identifiers.EVTYPE_EE_TERMINATED:
+                    logger.debug(f"monitor-{self._id} client received terminated")
+                    break
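
For readers skimming the diff, a minimal sketch (not part of the patch) of how the new async pieces compose, mirroring run_cli/async_tracking above; the model and evaluator_server_config arguments are assumed to come from run_cli's existing setup code and are not constructed here:

import asyncio
import sys

from _ert.async_utils import new_event_loop
from ert.cli.monitor import Monitor
from ert.ensemble_evaluator import AsyncEvaluatorTracker, EvaluatorServerConfig


def run_tracking(model, evaluator_server_config: EvaluatorServerConfig) -> None:
    # A dedicated event loop for the CLI, as in run_cli above.
    loop = new_event_loop()
    asyncio.set_event_loop(loop)
    monitor = Monitor(out=sys.stderr, color_always=False)

    async def _drive() -> None:
        tracker = AsyncEvaluatorTracker(
            model, ee_con_info=evaluator_server_config.get_connection_info()
        )
        try:
            # Drain evaluator events into the tracker's queue while the CLI
            # monitor consumes tracker.track() concurrently on the same loop.
            await asyncio.gather(
                tracker._drain_monitor(), monitor.async_monitor(tracker.track())
            )
        except (SystemExit, KeyboardInterrupt, OSError):
            # _base_service.py translates CTRL-c to OSError
            await tracker.request_termination()

    loop.run_until_complete(_drive())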