Skip to content

Commit

Permalink
Improve everest_to_ert code
Browse files Browse the repository at this point in the history
  • Loading branch information
verveerpj committed Oct 7, 2024
1 parent 59bc774 commit 9c731a0
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 88 deletions.
13 changes: 13 additions & 0 deletions src/ert/config/ensemble_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
overload,
)

from ert.config.ext_param_config import ExtParamConfig
from ert.field_utils import get_shape

from .field import Field
Expand Down Expand Up @@ -83,6 +84,8 @@ def from_dict(cls, config_dict: ConfigDict) -> EnsembleConfig:
gen_kw_list = config_dict.get(ConfigKeys.GEN_KW, [])
surface_list = config_dict.get(ConfigKeys.SURFACE, [])
field_list = config_dict.get(ConfigKeys.FIELD, [])
ext_param_dict = config_dict.get("EXT_PARAM", {})

dims = None
if grid_file_path is not None:
try:
Expand All @@ -106,10 +109,20 @@ def make_field(field_list: List[str]) -> Field:
)
return Field.from_config_list(grid_file_path, dims, field_list)

def make_ext_param(
control_name: str, variables: Union[List[str], Dict[str, List[str]]]
) -> ExtParamConfig:
return ExtParamConfig(
name=control_name,
input_keys=variables,
output_file=control_name + ".json",
)

parameter_configs = (
[GenKwConfig.from_config_list(g) for g in gen_kw_list]
+ [SurfaceConfig.from_config_list(s) for s in surface_list]
+ [make_field(f) for f in field_list]
+ [make_ext_param(n, e) for n, e in ext_param_dict.items()]
)

response_configs: List[ResponseConfig] = []
Expand Down
4 changes: 1 addition & 3 deletions src/ert/config/ext_param_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,7 @@ class ExtParamConfig(ParameterConfig):
If a list of strings is given, the order is preserved.
"""

input_keys: Union[List[str], Dict[str, List[Tuple[str, str]]]] = field(
default_factory=list
)
input_keys: Union[List[str], Dict[str, List[str]]] = field(default_factory=list)
forward_init: bool = False
output_file: str = ""
forward_init_file: str = ""
Expand Down
49 changes: 15 additions & 34 deletions src/ert/simulator/batch_simulator.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,20 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
List,
Optional,
Tuple,
Union,
)

import numpy as np

from ert.config import ErtConfig, ExtParamConfig, GenDataConfig
from ert.config import ErtConfig, ExtParamConfig

from .batch_simulator_context import BatchContext

Expand All @@ -16,8 +26,8 @@ class BatchSimulator:
def __init__(
self,
ert_config: ErtConfig,
controls: Dict[str, List[str]],
results: List[str],
controls: Iterable[str],
results: Iterable[str],
callback: Optional[Callable[[BatchContext], None]] = None,
):
"""Will create simulator which can be used to run multiple simulations.
Expand Down Expand Up @@ -88,39 +98,10 @@ def callback(*args, **kwargs):
raise ValueError("The first argument must be valid ErtConfig instance")

self.ert_config = ert_config
self.control_keys = set(controls.keys())
self.control_keys = set(controls)
self.result_keys = set(results)
self.callback = callback

ens_config = self.ert_config.ensemble_config
for control_name, variables in controls.items():
ens_config.addNode(
ExtParamConfig(
name=control_name,
input_keys=variables,
output_file=control_name + ".json",
)
)

if "gen_data" not in ens_config:
ens_config.addNode(
GenDataConfig(
keys=results,
input_files=[f"{k}" for k in results],
report_steps_list=[None for _ in results],
)
)
else:
existing_gendata = ens_config.response_configs["gen_data"]
existing_keys = existing_gendata.keys
assert isinstance(existing_gendata, GenDataConfig)

for key in results:
if key not in existing_keys:
existing_gendata.keys.append(key)
existing_gendata.input_files.append(f"{key}")
existing_gendata.report_steps_list.append(None)

def _setup_sim(
self,
sim_id: int,
Expand Down
50 changes: 49 additions & 1 deletion src/everest/simulator/everest_to_ert.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,14 @@
import json
import logging
import os
from typing import Union
from typing import DefaultDict, Dict, List, Union

import everest
from everest.config import EverestConfig
from everest.config.control_variable_config import (
ControlVariableConfig,
ControlVariableGuessListConfig,
)
from everest.config.install_data_config import InstallDataConfig
from everest.config.install_job_config import InstallJobConfig
from everest.config.simulator_config import SimulatorConfig
Expand Down Expand Up @@ -455,6 +459,48 @@ def _extract_seed(ever_config: EverestConfig, ert_config):
ert_config["RANDOM_SEED"] = random_seed


def _extract_results(ever_config: EverestConfig, ert_config):
objectives_names = [
objective.name
for objective in ever_config.objective_functions
if objective.alias is None
]
constraint_names = [
constraint.name for constraint in (ever_config.output_constraints or [])
]
gen_data = ert_config.get("GEN_DATA", [])
for name in objectives_names + constraint_names:
gen_data.append((name, f"RESULT_FILE:{name}"))
ert_config["GEN_DATA"] = gen_data


def _extract_controls(ever_config: EverestConfig, ert_config):
def _get_variables(
variables: Union[
List[ControlVariableConfig], List[ControlVariableGuessListConfig]
],
) -> Union[List[str], Dict[str, List[str]]]:
if (
isinstance(variables[0], ControlVariableConfig)
and getattr(variables[0], "index", None) is None
):
return [var.name for var in variables]
result: DefaultDict[str, list] = collections.defaultdict(list)
for variable in variables:
if isinstance(variable, ControlVariableGuessListConfig):
result[variable.name].extend(
str(index + 1) for index, _ in enumerate(variable.initial_guess)
)
else:
result[variable.name].append(str(variable.index)) # type: ignore
return dict(result)

controls = ever_config.controls or []
ert_config["EXT_PARAM"] = {
control.name: _get_variables(control.variables) for control in controls
}


def everest_to_ert_config(ever_config: EverestConfig, site_config=None):
"""
Takes as input an Everest configuration, the site-config and converts them
Expand All @@ -475,5 +521,7 @@ def everest_to_ert_config(ever_config: EverestConfig, site_config=None):
_extract_model(ever_config, ert_config)
_extract_queue_system(ever_config, ert_config)
_extract_seed(ever_config, ert_config)
_extract_results(ever_config, ert_config)
_extract_controls(ever_config, ert_config)

return ert_config
77 changes: 27 additions & 50 deletions src/everest/simulator/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from collections import defaultdict
from datetime import datetime
from itertools import count
from typing import Any, DefaultDict, Dict, List, Mapping, Optional, Tuple, Union
from typing import Any, DefaultDict, Dict, List, Mapping, Optional, Tuple

import numpy as np
from numpy import float64
Expand All @@ -13,89 +13,66 @@
from ert.config import ErtConfig, HookRuntime
from ert.storage import open_storage
from everest.config import EverestConfig
from everest.config.control_variable_config import (
ControlVariableConfig,
ControlVariableGuessListConfig,
)
from everest.simulator.everest_to_ert import everest_to_ert_config


class Simulator(BatchSimulator):
"""Everest simulator: BatchSimulator"""

def __init__(self, ever_config: EverestConfig, callback=None):
def __init__(self, ever_config: EverestConfig, callback=None) -> None:
self._ert_config = ErtConfig.with_plugins().from_dict(
config_dict=everest_to_ert_config(
ever_config, site_config=ErtConfig.read_site_config()
)
)
controls_def = self._get_controls_def(ever_config)
results_def = self._get_results_def(ever_config)

super(Simulator, self).__init__(
self._ert_config, controls_def, results_def, callback=callback
self._ert_config,
self._get_controls(ever_config),
self._get_results(ever_config),
callback=callback,
)

self._function_aliases = self._get_aliases(ever_config)
self._experiment_id = None
self._batch = 0
self._cache: Optional[_SimulatorCache] = None
if ever_config.simulator is not None and ever_config.simulator.enable_cache:
self._cache = _SimulatorCache()

@staticmethod
def _get_variables(
variables: Union[
List[ControlVariableConfig], List[ControlVariableGuessListConfig]
],
) -> Union[List[str], Dict[str, List[str]]]:
if (
isinstance(variables[0], ControlVariableConfig)
and getattr(variables[0], "index", None) is None
):
return [var.name for var in variables]
result: DefaultDict[str, list] = defaultdict(list)
for variable in variables:
if isinstance(variable, ControlVariableGuessListConfig):
result[variable.name].extend(
str(index + 1) for index, _ in enumerate(variable.initial_guess)
)
else:
result[variable.name].append(str(variable.index)) # type: ignore
return dict(result) # { name : [ index ]

def _get_controls_def(
self, ever_config: EverestConfig
) -> Dict[str, Union[List[str], Dict[str, List[str]]]]:
def _get_controls(self, ever_config: EverestConfig) -> List[str]:
controls = ever_config.controls or []
return {
control.name: self._get_variables(control.variables) for control in controls
}
return [control.name for control in controls]

def _get_results_def(self, ever_config: EverestConfig):
self._function_aliases = {
def _get_results(self, ever_config: EverestConfig) -> List[str]:
objectives_names = [
objective.name
for objective in ever_config.objective_functions
if objective.alias is None
]

constraint_names = [
constraint.name for constraint in (ever_config.output_constraints or [])
]
return objectives_names + constraint_names

def _get_aliases(self, ever_config: EverestConfig) -> Dict[str, str]:
aliases = {
objective.name: objective.alias
for objective in ever_config.objective_functions
if objective.alias is not None
}

constraints = ever_config.output_constraints or []
for constraint in constraints:
if (
constraint.upper_bound is not None
and constraint.lower_bound is not None
):
self._function_aliases[f"{constraint.name}:lower"] = constraint.name
self._function_aliases[f"{constraint.name}:upper"] = constraint.name

objectives_names = [
objective.name
for objective in ever_config.objective_functions
if objective.name not in self._function_aliases
]
aliases[f"{constraint.name}:lower"] = constraint.name
aliases[f"{constraint.name}:upper"] = constraint.name

constraint_names = [
constraint.name for constraint in (ever_config.output_constraints or [])
]
return objectives_names + constraint_names
return aliases

def __call__(
self, control_values: NDArray[np.float64], metadata: EvaluatorContext
Expand Down

0 comments on commit 9c731a0

Please sign in to comment.