Make EvaluatorTracker async #7702

Closed
36 changes: 25 additions & 11 deletions src/ert/cli/main.py
@@ -1,10 +1,12 @@
#!/usr/bin/env python
import asyncio
import contextlib
import logging
import os
import sys
from typing import Any, TextIO

from _ert.async_utils import new_event_loop
from _ert.threading import ErtThread
from ert.cli import (
ENSEMBLE_EXPERIMENT_MODE,
@@ -19,7 +21,10 @@
from ert.cli.workflow import execute_workflow
from ert.config import ErtConfig, QueueSystem
from ert.enkf_main import EnKFMain
from ert.ensemble_evaluator import EvaluatorServerConfig, EvaluatorTracker
from ert.ensemble_evaluator import (
AsyncEvaluatorTracker,
EvaluatorServerConfig,
)
from ert.namespace import Namespace
from ert.storage import open_storage
from ert.storage.local_storage import local_storage_set_ert_config
@@ -108,9 +113,8 @@
)
thread.start()

tracker = EvaluatorTracker(
model, ee_con_info=evaluator_server_config.get_connection_info()
)
loop = new_event_loop()
asyncio.set_event_loop(loop)
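# async_tracking (defined below) runs on this loop; the experiment itself
# keeps running in the thread started above.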

with contextlib.ExitStack() as exit_stack:
out: TextIO
@@ -121,15 +125,25 @@
else:
out = sys.stderr
monitor = Monitor(out=out, color_always=args.color_always)

try:
monitor.monitor(tracker.track())
except (SystemExit, KeyboardInterrupt, OSError):
# _base_service.py translates CTRL-c to OSError
print("\nKilling simulations...")
tracker.request_termination()
loop.run_until_complete(async_tracking(monitor, model, evaluator_server_config))

thread.join()
storage.close()

model.reraise_exception(ErtCliError)


async def async_tracking(
monitor: Monitor, model: Any, evaluator_server_config: EvaluatorServerConfig
) -> None:
tracker = AsyncEvaluatorTracker(
model, ee_con_info=evaluator_server_config.get_connection_info()
)
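# Run the drainer and the terminal monitor concurrently: _drain_monitor feeds
# evaluator events into the tracker's queue, while async_monitor renders the
# events yielded by track().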
try:
await asyncio.gather(
tracker._drain_monitor(), monitor.async_monitor(tracker.track())
)
except (SystemExit, KeyboardInterrupt, OSError):
# _base_service.py translates CTRL-c to OSError
print("\nKilling simulations...")
await tracker.request_termination()
23 changes: 22 additions & 1 deletion src/ert/cli/monitor.py
@@ -1,7 +1,8 @@
# -*- coding: utf-8 -*-
import asyncio
import sys
from datetime import datetime, timedelta
from typing import Dict, Iterator, Optional, TextIO, Tuple, Union
from typing import AsyncIterator, Dict, Iterator, Optional, TextIO, Tuple, Union

from tqdm import tqdm

@@ -77,6 +78,26 @@ def monitor(
self._print_job_errors()
return

async def async_monitor(
self,
events: AsyncIterator[Union[FullSnapshotEvent, SnapshotUpdateEvent, EndEvent]],
) -> None:
self._start_time = datetime.now()
async for event in events:
if isinstance(event, FullSnapshotEvent):
if event.snapshot is not None:
self._snapshots[event.iteration] = event.snapshot
self._progress = event.progress
elif isinstance(event, SnapshotUpdateEvent):
if event.partial_snapshot is not None:
self._snapshots[event.iteration].merge_event(event.partial_snapshot)
self._print_progress(event)
if isinstance(event, EndEvent):
self._print_result(event.failed, event.failed_msg)
self._print_job_errors()
return
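# Yield control so the drainer coroutine can keep feeding the tracker
# between events.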
await asyncio.sleep(0)

def _print_job_errors(self) -> None:
failed_jobs: Dict[Optional[str], int] = {}
for snapshot in self._snapshots.values():
3 changes: 2 additions & 1 deletion src/ert/ensemble_evaluator/__init__.py
@@ -7,12 +7,13 @@
)
from .config import EvaluatorServerConfig
from .evaluator import EnsembleEvaluator
from .evaluator_tracker import EvaluatorTracker
from .evaluator_tracker import AsyncEvaluatorTracker, EvaluatorTracker
from .event import EndEvent, FullSnapshotEvent, SnapshotUpdateEvent
from .monitor import Monitor
from .snapshot import PartialSnapshot, Snapshot

__all__ = (
"AsyncEvaluatorTracker",
"EndEvent",
"Ensemble",
"EnsembleBuilder",
201 changes: 200 additions & 1 deletion src/ert/ensemble_evaluator/evaluator_tracker.py
@@ -4,7 +4,7 @@
import logging
import queue
import time
from typing import TYPE_CHECKING, Dict, Iterator, Union
from typing import TYPE_CHECKING, AsyncIterator, Dict, Iterator, Union

from aiohttp import ClientError
from websockets.exceptions import ConnectionClosedError
@@ -238,3 +238,202 @@
while self._drainer_thread.is_alive():
self._clear_work_queue()
time.sleep(1)


class AsyncEvaluatorTracker:
DONE = "done"

def __init__(
self,
model: "BaseRunModel",
ee_con_info: EvaluatorConnectionInfo,
next_ensemble_evaluator_wait_time: int = 5,
):
self._model = model
self._ee_con_info = ee_con_info
self._next_ensemble_evaluator_wait_time = next_ensemble_evaluator_wait_time
self._iter_snapshot: Dict[int, Snapshot] = {}

self._work_queue: asyncio.Queue[Union[str, CloudEvent]] = asyncio.Queue()

self.done_flag = False
self.drainer_started = asyncio.Event()
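# drainer_started lets track() wait until _drain_monitor has connected to
# the evaluator; done_flag is set once track() has emitted its final event.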

async def _drain_monitor(self) -> None:
drainer_logger = logging.getLogger("ert.ensemble_evaluator.drainer")
while not self._model.isFinished():
try:
drainer_logger.debug("connecting to new monitor...")
with Monitor(self._ee_con_info) as monitor:
drainer_logger.debug("connected")
self.drainer_started.set()
async for event in monitor.async_track():
if event["type"] in (
EVTYPE_EE_SNAPSHOT,
EVTYPE_EE_SNAPSHOT_UPDATE,
):
await self._work_queue.put(event)
if event.data.get(STATUS) in [
ENSEMBLE_STATE_STOPPED,
ENSEMBLE_STATE_FAILED,
]:
drainer_logger.debug(
"observed evaluation stopped event, signal done"
)
monitor.signal_done()
if event.data.get(STATUS) == ENSEMBLE_STATE_CANCELLED:
drainer_logger.debug(
"observed evaluation cancelled event, exit drainer"
)
# Allow track() to emit an EndEvent.
await self._work_queue.put(EvaluatorTracker.DONE)
await self._work_queue.join()
return
elif event["type"] == EVTYPE_EE_TERMINATED:
drainer_logger.debug("got terminator event")
# This sleep needs to be there. Refer to issue #1250: `Authority
# on information about evaluations/experiments`
await asyncio.sleep(self._next_ensemble_evaluator_wait_time)
except (ConnectionRefusedError, ClientError) as e:
if not self._model.isFinished():
drainer_logger.debug(f"connection refused: {e}")
except ConnectionClosedError as e:
# The monitor connection closed unexpectedly
drainer_logger.debug(f"connection closed error: {e}")
except BaseException:
drainer_logger.exception("unexpected error: ")
# We really don't know what happened... shut down
# the thread and get out of here. The monitor has
# been stopped by the ctx-mgr
await self._work_queue.put(EvaluatorTracker.DONE)
await self._work_queue.join()
return
drainer_logger.debug(
"observed that model was finished, waiting for tasks to complete..."
)
# The model has finished, we indicate this by sending a DONE
await self._work_queue.put(EvaluatorTracker.DONE)
await self._work_queue.join()
drainer_logger.debug("tasks complete")

async def track(
self,
) -> AsyncIterator[Union[FullSnapshotEvent, SnapshotUpdateEvent, EndEvent]]:
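# Do not start consuming the queue before the drainer has connected.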
await self.drainer_started.wait()
while True:
event = await self._work_queue.get()
if isinstance(event, str):
with contextlib.suppress(GeneratorExit):
# consumers may exit at this point, make sure the last
# task is marked as done
if event == EvaluatorTracker.DONE:
yield EndEvent(
failed=self._model.hasRunFailed(),
failed_msg=self._model.getFailMessage(),
)
self._work_queue.task_done()
break
if event["type"] == EVTYPE_EE_SNAPSHOT:
iter_ = event.data["iter"]
snapshot = Snapshot(event.data)
self._iter_snapshot[iter_] = snapshot
yield FullSnapshotEvent(
phase_name=self._model.getPhaseName(),
current_phase=self._model.currentPhase(),
total_phases=self._model.phaseCount(),
indeterminate=self._model.isIndeterminate(),
progress=self._progress(),
iteration=iter_,
snapshot=copy.deepcopy(snapshot),
)
elif event["type"] == EVTYPE_EE_SNAPSHOT_UPDATE:
iter_ = event.data["iter"]
if iter_ not in self._iter_snapshot:
raise OutOfOrderSnapshotUpdateException(
f"got {EVTYPE_EE_SNAPSHOT_UPDATE} without having stored "
f"snapshot for iter {iter_}"
)
partial = PartialSnapshot(self._iter_snapshot[iter_]).from_cloudevent(
event
)
self._iter_snapshot[iter_].merge_event(partial)
yield SnapshotUpdateEvent(
phase_name=self._model.getPhaseName(),
current_phase=self._model.currentPhase(),
total_phases=self._model.phaseCount(),
indeterminate=self._model.isIndeterminate(),
progress=self._progress(),
iteration=iter_,
partial_snapshot=partial,
)
self._work_queue.task_done()
self.done_flag = True

def is_finished(self) -> bool:
return self.done_flag

def _progress(self) -> float:
"""Fraction of completed iterations over total iterations"""

if self.is_finished():
return 1.0
elif not self._iter_snapshot:
return 0.0
else:
# Calculate completed realizations
current_iter = max(list(self._iter_snapshot.keys()))
done_reals = 0
all_reals = self._iter_snapshot[current_iter].reals
if not all_reals:
# Empty ensemble or all realizations deactivated
return 1.0
for real in all_reals.values():
if real.status in [
REALIZATION_STATE_FINISHED,
REALIZATION_STATE_FAILED,
]:
done_reals += 1
real_progress = float(done_reals) / len(all_reals)

return (
(current_iter + real_progress) / self._model.phaseCount()
if self._model.phaseCount() != 1
else real_progress
)

def _clear_work_queue(self) -> None:
with contextlib.suppress(asyncio.QueueEmpty):
while True:
self._work_queue.get_nowait()
self._work_queue.task_done()

async def request_termination(self) -> None:
logger = logging.getLogger("ert.ensemble_evaluator.tracker")
# There might be some situations where the
# evaluation is finished or the evaluation
# is yet to start when calling this function.
# In these cases the monitor is not started
#
# To avoid waiting too long we exit if we are not
# able to connect to the monitor after 2 tries
#
# See issue: https://github.com/equinor/ert/issues/1250
#
try:
logger.debug("requesting termination...")
await wait_for_evaluator(
base_url=self._ee_con_info.url,
token=self._ee_con_info.token,
cert=self._ee_con_info.cert,
timeout=5,
)
logger.debug("requested termination")
except ClientError as e:
logger.warning(f"{__name__} - exception {e}")
return

with Monitor(self._ee_con_info) as monitor:
monitor.signal_cancel()
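# Keep draining queued events until track() reports done, so shutdown is
# not blocked on unconsumed work.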
while not self.done_flag:
self._clear_work_queue()
await asyncio.sleep(1)
33 changes: 32 additions & 1 deletion src/ert/ensemble_evaluator/monitor.py
@@ -1,8 +1,15 @@
import asyncio
import logging
import pickle
import uuid
from contextlib import ExitStack
from typing import TYPE_CHECKING, Any, Generator, Optional
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Generator,
Optional,
)

from cloudevents.conversion import to_json
from cloudevents.exceptions import DataUnmarshallerError
@@ -103,3 +110,27 @@ def track(self) -> Generator[CloudEvent, None, None]:
if event["type"] == identifiers.EVTYPE_EE_TERMINATED:
logger.debug(f"monitor-{self._id} client received terminated")
break

async def async_track(self) -> AsyncIterator[CloudEvent]:
with ExitStack() as stack:
duplexer = self._ws_duplexer
if not duplexer:
duplexer = SyncWebsocketDuplexer(
self._ee_con_info.client_uri,
self._ee_con_info.url,
self._ee_con_info.cert,
self._ee_con_info.token,
)
stack.callback(duplexer.stop)
for message in duplexer.receive():
try:
event = from_json(
str(message), data_unmarshaller=evaluator_unmarshaller
)
except DataUnmarshallerError:
event = from_json(str(message), data_unmarshaller=pickle.loads)
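# receive() is synchronous, so explicitly yield to the event loop
# between messages.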
await asyncio.sleep(0)
yield event
if event["type"] == identifiers.EVTYPE_EE_TERMINATED:
logger.debug(f"monitor-{self._id} client received terminated")
break