Skip to content

Commit

Permalink
fix the subprocess thing almost
Browse files Browse the repository at this point in the history
  • Loading branch information
valentin-krasontovitsch committed Aug 25, 2023
1 parent 66a13c0 commit 350b9ca
Show file tree
Hide file tree
Showing 20 changed files with 245 additions and 142 deletions.
75 changes: 46 additions & 29 deletions src/ert/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,36 +3,42 @@
import logging
import time
from pathlib import Path
from typing import Callable, Iterable, Tuple
from typing import TYPE_CHECKING, Callable, Iterable, Mapping, Tuple

from ert.config import EnsembleConfig, ParameterConfig, SummaryConfig
from ert.config import ParameterConfig, ResponseConfig, SummaryConfig
from ert.run_arg import RunArg

from .load_status import LoadResult, LoadStatus
from .realization_state import RealizationState

CallbackArgs = Tuple[RunArg, EnsembleConfig]
Callback = Callable[[RunArg, EnsembleConfig], LoadResult]
if TYPE_CHECKING:
from ert.storage import EnsembleAccessor

CallbackArgs = Tuple[RunArg, Mapping[str, ResponseConfig]]
Callback = Callable[[RunArg, Mapping[str, ResponseConfig]], LoadResult]

logger = logging.getLogger(__name__)


def _read_parameters(
run_arg: RunArg, parameter_configuration: Iterable[ParameterConfig]
runpath: str,
iens: int,
parameter_configurations: Iterable[ParameterConfig],
ensemble_storage: EnsembleAccessor,
) -> LoadResult:
result = LoadResult(LoadStatus.LOAD_SUCCESSFUL, "")
error_msg = ""
for config_node in parameter_configuration:
for config_node in parameter_configurations:
if not config_node.forward_init:
continue
try:
start_time = time.perf_counter()
logger.info(f"Starting to load parameter: {config_node.name}")
ds = config_node.read_from_runpath(Path(run_arg.runpath), run_arg.iens)
run_arg.ensemble_storage.save_parameters(config_node.name, run_arg.iens, ds)
logger.info(
print(f"Starting to load parameter: {config_node.name}")
ds = config_node.read_from_runpath(Path(runpath), iens)
ensemble_storage.save_parameters(config_node.name, iens, ds)
print(
f"Saved {config_node.name} to storage",
extra={"Time": f"{(time.perf_counter() - start_time):.4f}s"},
{"Time": f"{(time.perf_counter() - start_time):.4f}s"},
)
except ValueError as err:
error_msg += str(err)
Expand All @@ -41,20 +47,22 @@ def _read_parameters(


def _write_responses_to_storage(
ens_config: EnsembleConfig, run_arg: RunArg
response_configs: Mapping[str, ResponseConfig],
runpath: str,
iens: int,
ensemble_storage: EnsembleAccessor,
) -> LoadResult:
errors = []
for config in ens_config.response_configs.values():
if isinstance(config, SummaryConfig):
for config in response_configs.values():
if isinstance(config, SummaryConfig) and not config.keys:
# Nothing to load, should not be handled here, should never be
# added in the first place
if not config.keys:
continue
continue
try:
start_time = time.perf_counter()
logger.info(f"Starting to load response: {config.name}")
ds = config.read_from_file(run_arg.runpath, run_arg.iens)
run_arg.ensemble_storage.save_response(config.name, ds, run_arg.iens)
ds = config.read_from_file(runpath, iens)
ensemble_storage.save_response(config.name, ds, iens)
logger.info(
f"Saved {config.name} to storage",
extra={"Time": f"{(time.perf_counter() - start_time):.4f}s"},
Expand All @@ -68,7 +76,8 @@ def _write_responses_to_storage(

def forward_model_ok(
run_arg: RunArg,
ens_conf: EnsembleConfig,
response_configs: Mapping[str, ResponseConfig],
update_state_map: bool = True,
) -> LoadResult:
parameters_result = LoadResult(LoadStatus.LOAD_SUCCESSFUL, "")
response_result = LoadResult(LoadStatus.LOAD_SUCCESSFUL, "")
Expand All @@ -77,15 +86,22 @@ def forward_model_ok(
# handles parameters
if run_arg.itr == 0:
parameters_result = _read_parameters(
run_arg,
run_arg.runpath,
run_arg.iens,
run_arg.ensemble_storage.experiment.parameter_configuration.values(),
run_arg.ensemble_storage,
)

if parameters_result.status == LoadStatus.LOAD_SUCCESSFUL:
response_result = _write_responses_to_storage(ens_conf, run_arg)
response_result = _write_responses_to_storage(
response_configs,
run_arg.runpath,
run_arg.iens,
run_arg.ensemble_storage,
)

except Exception as err:
logging.exception(f"Failed to load results for realization {run_arg.iens}")
except BaseException as err: # pylint: disable=broad-exception-caught
logger.exception(f"Failed to load results for realization {run_arg.iens}")
parameters_result = LoadResult(
LoadStatus.LOAD_FAILURE,
"Failed to load results for realization "
Expand All @@ -96,15 +112,16 @@ def forward_model_ok(
if response_result.status != LoadStatus.LOAD_SUCCESSFUL:
final_result = response_result

run_arg.ensemble_storage.state_map[run_arg.iens] = (
RealizationState.HAS_DATA
if final_result.status == LoadStatus.LOAD_SUCCESSFUL
else RealizationState.LOAD_FAILURE
)
if update_state_map:
run_arg.ensemble_storage.state_map[run_arg.iens] = (
RealizationState.HAS_DATA
if final_result.status == LoadStatus.LOAD_SUCCESSFUL
else RealizationState.LOAD_FAILURE
)

return final_result


def forward_model_exit(run_arg: RunArg, _: EnsembleConfig) -> LoadResult:
def forward_model_exit(run_arg: RunArg, _: Mapping[str, ResponseConfig]) -> LoadResult:
run_arg.ensemble_storage.state_map[run_arg.iens] = RealizationState.LOAD_FAILURE
return LoadResult(None, "")
2 changes: 2 additions & 0 deletions src/ert/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from .parsing import ConfigValidationError, ConfigWarning
from .queue_config import QueueConfig
from .queue_system import QueueSystem
from .response_config import ResponseConfig
from .summary_config import SummaryConfig
from .summary_observation import SummaryObservation
from .surface_config import SurfaceConfig
Expand Down Expand Up @@ -50,6 +51,7 @@
"ModelConfig",
"ParameterConfig",
"PriorDict",
"ResponseConfig",
"QueueConfig",
"QueueSystem",
"SummaryConfig",
Expand Down
5 changes: 3 additions & 2 deletions src/ert/config/ensemble_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def _get_abs_path(file: Optional[str]) -> Optional[str]:

class EnsembleConfig:
@staticmethod
def _load_refcase(refcase_file: Optional[str]) -> Optional[EclSum]:
def load_refcase(refcase_file: Optional[str]) -> Optional[EclSum]:
if refcase_file is None:
return None

Expand Down Expand Up @@ -101,9 +101,10 @@ def __init__( # noqa: 501 pylint: disable=too-many-arguments, too-many-branches

self._grid_file = _get_abs_path(grid_file)
self._refcase_file = _get_abs_path(ref_case_file)
self.refcase: Optional[EclSum] = self._load_refcase(self._refcase_file)
self.refcase: Optional[EclSum] = self.load_refcase(self._refcase_file)
self.parameter_configs: Dict[str, ParameterConfig] = {}
self.response_configs: Dict[str, ResponseConfig] = {}
self._ecl_base = ecl_base.replace("%d", "<IENS>") if ecl_base else None

for gene_data in _gen_data_list:
self.addNode(self.gen_data_node(gene_data))
Expand Down
2 changes: 2 additions & 0 deletions src/ert/config/ert_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,8 @@ def from_dict(cls, config_dict) -> Self:
if errors:
raise ConfigValidationError.from_collected(errors)

# transform all values in dict to regular types instead of context types

try:
ensemble_config = EnsembleConfig.from_dict(config_dict=config_dict)
except ConfigValidationError as err:
Expand Down
9 changes: 7 additions & 2 deletions src/ert/config/gen_kw_config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from __future__ import annotations

import filecmp
import logging
import math
import os
import shutil
from dataclasses import dataclass
from hashlib import sha256
Expand Down Expand Up @@ -163,7 +165,7 @@ def transform(self, array: npt.ArrayLike) -> npt.NDArray[np.float64]:
def _values_from_file(
realization: int, name_format: str, keys: List[str]
) -> npt.NDArray[np.double]:
file_name = name_format % realization
file_name = name_format % realization # noqa: S001
df = pd.read_csv(file_name, delim_whitespace=True, header=None)
# This means we have a key: value mapping in the
# file otherwise it is just a list of values
Expand Down Expand Up @@ -258,7 +260,10 @@ def save_experiment_data(self, experiment_path: Path) -> None:
self.template_file_path = Path(
experiment_path / incoming_template_file_path.name
)
shutil.copyfile(incoming_template_file_path, self.template_file_path)
if not os.path.exists(self.template_file_path) or not filecmp.cmp(
incoming_template_file_path, self.template_file_path
):
shutil.copyfile(incoming_template_file_path, self.template_file_path)


@dataclass
Expand Down
17 changes: 16 additions & 1 deletion src/ert/config/parsing/context_values.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, TypeVar, Union, no_type_check
from typing import List, Tuple, Type, TypeVar, Union, no_type_check

from .file_context_token import FileContextToken

Expand Down Expand Up @@ -63,6 +63,21 @@ class ContextString(str):
def from_token(cls, token: FileContextToken) -> "ContextString":
return cls(val=str(token), token=token, keyword_token=token)

# def __getnewargs__(self):
# return (self.__str__(), self.token, self.keyword_token)

def __reduce__(self) -> Tuple[Type[str], Tuple[str]]:
return (str, (self.__str__(),))

# def __getstate__(self):
# print("I'm being pickled")
# print(f"{self.__dict__}")
# return self.__dict__

# def __setstate__(self, d):
# print("I'm being unpickled with these values: " + repr(d))
# self.__dict__ = d

def __new__(
cls, val: str, token: FileContextToken, keyword_token: FileContextToken
) -> "ContextString":
Expand Down
18 changes: 17 additions & 1 deletion src/ert/config/parsing/file_context_token.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, cast
from typing import List, Tuple, Type, cast

from lark import Token

Expand Down Expand Up @@ -28,6 +28,22 @@ def __new__(cls, token: Token, filename: str) -> "FileContextToken":
inst_fct.filename = filename
return inst_fct

def __reduce__(
self,
) -> Tuple[Type["FileContextToken"], Tuple[Token, str]]:
# return (str, (self.__str__(),))
token = Token(
self.type,
self.value,
self.start_pos,
self.line,
self.column,
self.end_line,
self.end_column,
self.end_pos,
)
return (self.__class__, (token, self.filename))

def __repr__(self) -> str:
return f"{self.value!r}"

Expand Down
10 changes: 9 additions & 1 deletion src/ert/config/summary_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import logging
from dataclasses import dataclass
from fnmatch import fnmatch
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Tuple, Type

import numpy as np
import xarray as xr
Expand All @@ -25,6 +25,14 @@ class SummaryConfig(ResponseConfig):
keys: List[str]
refcase: Optional[EclSum] = None

def __reduce__(
self,
) -> Tuple[Type["SummaryConfig"], Tuple[str, str, List[str], None]]:
# we have to pass `name`, which is an argument to ResponseConfig's __new__ -
# this is very implicit and hidden and was hard to find. TODO can we make this
# more explicit?
return (self.__class__, (self.name, self.input_file, self.keys, None))

def __eq__(self, other: object) -> bool:
if not isinstance(other, SummaryConfig):
return False
Expand Down
7 changes: 4 additions & 3 deletions src/ert/ensemble_evaluator/_builder/_legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
Callable,
Dict,
List,
Mapping,
Optional,
Tuple,
TypeVar,
Expand All @@ -32,15 +33,15 @@

if TYPE_CHECKING:
from ert.callbacks import Callback
from ert.config import AnalysisConfig, EnsembleConfig, QueueConfig
from ert.config import AnalysisConfig, QueueConfig, ResponseConfig
from ert.run_arg import RunArg

from ..config import EvaluatorServerConfig
from ._realization import Realization

MsgType = TypeVar("MsgType", CloudEvent, _ert_com_protocol.DispatcherMessage)

CONCURRENT_INTERNALIZATION = 10
CONCURRENT_INTERNALIZATION = 20

logger = logging.getLogger(__name__)
event_logger = logging.getLogger("ert.event_log")
Expand Down Expand Up @@ -138,7 +139,7 @@ def setup_timeout_callback(
cloudevent_unary_send: Callable[[MsgType], Awaitable[None]],
event_generator: Callable[[str, Optional[int]], MsgType],
) -> Tuple[Callback, asyncio.Task[None]]:
def on_timeout(run_args: RunArg, _: EnsembleConfig) -> LoadResult:
def on_timeout(run_args: RunArg, _: Mapping[str, ResponseConfig]) -> LoadResult:
timeout_queue.put_nowait(
event_generator(identifiers.EVTYPE_FM_STEP_TIMEOUT, run_args.iens)
)
Expand Down
2 changes: 1 addition & 1 deletion src/ert/job_queue/job_queue_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from .queue import JobQueue


CONCURRENT_INTERNALIZATION = 1
CONCURRENT_INTERNALIZATION = 10


# TODO: there's no need for this class, all the behavior belongs in the queue
Expand Down
Loading

0 comments on commit 350b9ca

Please sign in to comment.