Skip to content

Commit

Permalink
Add option to run misfit_preprocessor with selected observations
Browse files Browse the repository at this point in the history
  • Loading branch information
oyvindeide committed Feb 27, 2024
1 parent 5faa55b commit 8f24a80
Show file tree
Hide file tree
Showing 14 changed files with 553 additions and 70 deletions.
46 changes: 26 additions & 20 deletions src/ert/analysis/_es_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from collections import UserDict
from dataclasses import dataclass, field
from datetime import datetime
from fnmatch import fnmatch
from pathlib import Path
from typing import (
TYPE_CHECKING,
Expand All @@ -31,6 +32,7 @@

from ert.config import Field, GenKwConfig, SurfaceConfig

from ..config.analysis_config import UpdateSettings
from ..config.analysis_module import ESSettings, IESSettings
from . import misfit_preprocessor
from .event import AnalysisEvent, AnalysisStatusEvent, AnalysisTimeEvent
Expand Down Expand Up @@ -128,14 +130,6 @@ def __next__(self) -> T:
return result


@dataclass
class UpdateSettings:
std_cutoff: float = 1e-6
alpha: float = 3.0
misfit_preprocess: bool = False
min_required_realizations: int = 2


class TempStorage(UserDict): # type: ignore
def __getitem__(self, key: str) -> npt.NDArray[np.double]:
value: Union[npt.NDArray[np.double], xr.DataArray] = self.data[key]
Expand Down Expand Up @@ -350,14 +344,23 @@ def _get_obs_and_measure_data(
)


def _expand_wildcards(
input_list: npt.NDArray[np.str_], patterns: List[str]
) -> List[str]:
matches = []
for pattern in patterns:
matches.extend([val for val in input_list if fnmatch(val, pattern)])
return list(set(matches))


def _load_observations_and_responses(
source_fs: EnsembleReader,
alpha: float,
std_cutoff: float,
global_std_scaling: float,
iens_ative_index: npt.NDArray[np.int_],
selected_observations: List[Observation],
misfit_process: bool,
misfit_process: Optional[List[List[str]]],
update_step_name: str,
) -> Tuple[
npt.NDArray[np.float_],
Expand Down Expand Up @@ -388,9 +391,12 @@ def _load_observations_and_responses(
obs_mask = np.logical_and(ens_mean_mask, ens_std_mask)

if misfit_process:
scaling[obs_mask] *= misfit_preprocessor.main(
S[obs_mask], scaled_errors[obs_mask]
)
for input_group in misfit_process:
group = _expand_wildcards(obs_keys, input_group)
obs_group_mask = np.isin(obs_keys, group) & obs_mask
scaling[obs_group_mask] *= misfit_preprocessor.main(
S[obs_group_mask], scaled_errors[obs_group_mask]
)

update_snapshot = []
for (
Expand Down Expand Up @@ -608,7 +614,7 @@ def analysis_ES(
source_fs: EnsembleReader,
target_fs: EnsembleAccessor,
progress_callback: Callable[[AnalysisEvent], None],
misfit_process: bool,
misfit_process: Optional[List[List[str]]],
) -> None:
iens_active_index = np.flatnonzero(ens_mask)

Expand Down Expand Up @@ -777,7 +783,7 @@ def analysis_IES(
target_fs: EnsembleAccessor,
sies_smoother: Optional[ies.SIES],
progress_callback: Callable[[AnalysisEvent], None],
misfit_preprocessor: bool,
misfit_preprocessor: List[List[str]],
sies_step_length: Callable[[int], float],
initial_mask: npt.NDArray[np.bool_],
) -> ies.SIES:
Expand Down Expand Up @@ -933,10 +939,10 @@ def _write_update_report(path: Path, snapshot: SmootherSnapshot, run_id: str) ->


def _assert_has_enough_realizations(
ens_mask: npt.NDArray[np.bool_], analysis_config: UpdateSettings
ens_mask: npt.NDArray[np.bool_], min_required_realizations: int
) -> None:
active_realizations = ens_mask.sum()
if active_realizations < analysis_config.min_required_realizations:
if active_realizations < min_required_realizations:
raise ErtAnalysisError(
f"There are {active_realizations} active realisations left, which is "
"less than the minimum specified - stopping assimilation.",
Expand Down Expand Up @@ -977,7 +983,7 @@ def smoother_update(
analysis_config = UpdateSettings() if analysis_config is None else analysis_config
es_settings = ESSettings() if es_settings is None else es_settings
ens_mask = prior_storage.get_realization_mask_with_responses()
_assert_has_enough_realizations(ens_mask, analysis_config)
_assert_has_enough_realizations(ens_mask, analysis_config.min_required_realizations)

smoother_snapshot = _create_smoother_snapshot(
prior_storage.name,
Expand All @@ -998,7 +1004,7 @@ def smoother_update(
prior_storage,
posterior_storage,
progress_callback,
analysis_config.misfit_preprocess,
analysis_config.auto_scale,
)

if log_path is not None:
Expand Down Expand Up @@ -1037,7 +1043,7 @@ def iterative_smoother_update(
)

ens_mask = prior_storage.get_realization_mask_with_responses()
_assert_has_enough_realizations(ens_mask, update_settings)
_assert_has_enough_realizations(ens_mask, update_settings.min_required_realizations)

smoother_snapshot = _create_smoother_snapshot(
prior_storage.name,
Expand All @@ -1058,7 +1064,7 @@ def iterative_smoother_update(
target_fs=posterior_storage,
sies_smoother=sies_smoother,
progress_callback=progress_callback,
misfit_preprocessor=update_settings.misfit_preprocess,
misfit_preprocessor=update_settings.auto_scale,
sies_step_length=sies_step_length,
initial_mask=initial_mask,
)
Expand Down
24 changes: 3 additions & 21 deletions src/ert/cli/model_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import numpy as np

from ert.analysis._es_update import UpdateSettings
from ert.cli import (
ENSEMBLE_EXPERIMENT_MODE,
ENSEMBLE_SMOOTHER_MODE,
Expand All @@ -14,7 +13,8 @@
ITERATIVE_ENSEMBLE_SMOOTHER_MODE,
TEST_RUN_MODE,
)
from ert.config import ConfigWarning, ErtConfig, HookRuntime
from ert.config import ConfigWarning, ErtConfig
from ert.config.analysis_config import UpdateSettings
from ert.run_models import (
BaseRunModel,
EnsembleExperiment,
Expand All @@ -35,23 +35,13 @@
from ert.validation import ActiveRange

if TYPE_CHECKING:
from typing import List

import numpy.typing as npt

from ert.config import Workflow
from ert.namespace import Namespace
from ert.storage import StorageAccessor


def _misfit_preprocessor(workflows: List[Workflow]) -> bool:
for workflow in workflows:
for job, _ in workflow:
if job.name == "MISFIT_PREPROCESSOR":
return True
return False


def create_model(
config: ErtConfig,
storage: StorageAccessor,
Expand All @@ -65,15 +55,7 @@ def create_model(
"ensemble_size": config.model_config.num_realizations,
},
)
ert_analysis_config = config.analysis_config
update_settings = UpdateSettings(
std_cutoff=ert_analysis_config.std_cutoff,
alpha=ert_analysis_config.enkf_alpha,
misfit_preprocess=_misfit_preprocessor(
config.hooked_workflows[HookRuntime.PRE_FIRST_UPDATE]
),
min_required_realizations=ert_analysis_config.minimum_required_realizations,
)
update_settings = config.analysis_config.observation_settings

if args.mode == TEST_RUN_MODE:
return _setup_single_test_run(config, storage, args)
Expand Down
34 changes: 29 additions & 5 deletions src/ert/config/analysis_config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from __future__ import annotations

import logging
from dataclasses import dataclass, field
from math import ceil
from os.path import realpath
from pathlib import Path
Expand Down Expand Up @@ -43,6 +46,12 @@ def __init__(
self._min_realization = min_realization

options: Dict[str, Dict[str, Any]] = {"STD_ENKF": {}, "IES_ENKF": {}}
observation_settings: Dict[str, Any] = {
"alpha": alpha,
"std_cutoff": std_cutoff,
"auto_scale": [],
"min_required_realizations": min_realization,
}
analysis_set_var = [] if analysis_set_var is None else analysis_set_var
inversion_str_map: Final = {
"STD_ENKF": {
Expand All @@ -64,6 +73,16 @@ def __init__(
all_errors = []

for module_name, var_name, value in analysis_set_var:
if module_name == "OBSERVATIONS":
if var_name == "AUTO_SCALE":
observation_settings["auto_scale"].append(value.split(","))
else:
all_errors.append(
ConfigValidationError(
f"Unknown module: {module_name} with variable: {var_name}"
)
)
continue
if var_name in deprecated_keys:
errors.append(var_name)
continue
Expand Down Expand Up @@ -110,6 +129,7 @@ def __init__(
try:
self.es_module = ESSettings(**options["STD_ENKF"])
self.ies_module = IESSettings(**options["IES_ENKF"])
self.observation_settings = UpdateSettings(**observation_settings)
except ValidationError as err:
for error in err.errors():
error["loc"] = tuple(
Expand Down Expand Up @@ -226,7 +246,6 @@ def __repr__(self) -> str:
return (
"AnalysisConfig("
f"alpha={self._alpha}, "
f"std_cutoff={self._std_cutoff}, "
f"stop_long_running={self._stop_long_running}, "
f"max_runtime={self._max_runtime}, "
f"min_realization={self._min_realization}, "
Expand All @@ -247,10 +266,7 @@ def __eq__(self, other: object) -> bool:
if self.stop_long_running != other.stop_long_running:
return False

if self.std_cutoff != other.std_cutoff:
return False

if self.enkf_alpha != other.enkf_alpha:
if self.observation_settings != other.observation_settings:
return False

if self.ies_module != other.ies_module:
Expand All @@ -265,3 +281,11 @@ def __eq__(self, other: object) -> bool:
if self.minimum_required_realizations != other.minimum_required_realizations:
return False
return True


@dataclass
class UpdateSettings:
std_cutoff: float = 1e-6
alpha: float = 3.0
auto_scale: List[List[str]] = field(default_factory=list)
min_required_realizations: int = 2
20 changes: 19 additions & 1 deletion src/ert/config/ert_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,11 +634,29 @@ def _workflows_from_dict(
filename = path.basename(work[0]) if len(work) == 1 else work[1]
try:
existed = filename in workflows
workflows[filename] = Workflow.from_file(
workflow = Workflow.from_file(
work[0],
substitution_list,
workflow_jobs,
)
for job, args in workflow:
if job.name == "MISFIT_PREPROCESSOR":
message = (
(
f"MISFIT_PREPROCESSOR is removed, use SCALE_OBSERVATIONS"
f"option instead for: {filename}, "
f"example: SCALE_OBSERVATIONS * -- all observations"
),
)
if args:
# This means the user has configured a config file to the workflow
# so we can assume they have customized the obs groups
message += (
"example: SCALE_OBSERVATIONS 'obs_*' -- all observations starting with obs_"
"Add multiple lines of SCALE_OBSERVATIONS to set up multiple groups"
)
errors.append(ErrorInfo(message=message).set_context(work[0]))
workflows[filename] = workflow
if existed:
ConfigWarning.ert_context_warn(
f"Workflow {filename!r} was added twice", work[0]
Expand Down
8 changes: 1 addition & 7 deletions src/ert/gui/tools/run_analysis/run_analysis_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from qtpy.QtWidgets import QApplication, QMessageBox

from ert.analysis import ErtAnalysisError, smoother_update
from ert.analysis._es_update import UpdateSettings
from ert.analysis.event import AnalysisEvent, AnalysisStatusEvent, AnalysisTimeEvent
from ert.enkf_main import EnKFMain, _seed_sequence
from ert.gui.ertnotifier import ErtNotifier
Expand Down Expand Up @@ -48,12 +47,7 @@ def run(self):
error: Optional[str] = None
config = self._ert.ert_config
rng = np.random.default_rng(_seed_sequence(config.random_seed))
update_settings = UpdateSettings(
std_cutoff=config.analysis_config.std_cutoff,
alpha=config.analysis_config.enkf_alpha,
misfit_preprocess=False,
min_required_realizations=config.analysis_config.minimum_required_realizations,
)
update_settings = config.analysis_config.observation_settings
try:
smoother_update(
self._source_fs,
Expand Down
9 changes: 1 addition & 8 deletions src/ert/libres_facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from ert.shared.version import __version__
from ert.storage import EnsembleReader

from .analysis._es_update import UpdateSettings
from .enkf_main import EnKFMain, ensemble_context

_logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -80,18 +79,12 @@ def smoother_update(
) -> SmootherSnapshot:
if rng is None:
rng = np.random.default_rng()
analysis_config = UpdateSettings(
std_cutoff=self.config.analysis_config.std_cutoff,
alpha=self.config.analysis_config.enkf_alpha,
misfit_preprocess=misfit_process,
min_required_realizations=self.config.analysis_config.minimum_required_realizations,
)
update_snapshot = smoother_update(
prior_storage,
posterior_storage,
run_id,
self._enkf_main.update_configuration,
analysis_config,
self.config.analysis_config.observation_settings,
self.config.analysis_config.es_module,
rng,
progress_callback,
Expand Down
2 changes: 1 addition & 1 deletion src/ert/run_models/ensemble_smoother.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from ert.run_models.run_arguments import ESRunArguments
from ert.storage import StorageAccessor

from ..analysis._es_update import UpdateSettings
from ..config.analysis_config import UpdateSettings
from ..config.analysis_module import ESSettings
from .base_run_model import BaseRunModel, ErtRunError
from .event import RunModelStatusEvent, RunModelUpdateBeginEvent, RunModelUpdateEndEvent
Expand Down
2 changes: 1 addition & 1 deletion src/ert/run_models/iterated_ensemble_smoother.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from ert.run_models.run_arguments import SIESRunArguments
from ert.storage import EnsembleAccessor, StorageAccessor

from ..analysis._es_update import UpdateSettings
from ..config.analysis_config import UpdateSettings
from ..config.analysis_module import IESSettings
from .base_run_model import BaseRunModel, ErtRunError
from .event import RunModelStatusEvent, RunModelUpdateBeginEvent, RunModelUpdateEndEvent
Expand Down
2 changes: 1 addition & 1 deletion src/ert/run_models/multiple_data_assimilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from ert.run_models.run_arguments import ESMDARunArguments
from ert.storage import EnsembleAccessor, StorageAccessor

from ..analysis._es_update import UpdateSettings
from ..config.analysis_config import UpdateSettings
from ..config.analysis_module import ESSettings
from .base_run_model import BaseRunModel, ErtRunError
from .event import RunModelStatusEvent, RunModelUpdateBeginEvent, RunModelUpdateEndEvent
Expand Down
2 changes: 1 addition & 1 deletion tests/integration_tests/analysis/test_es_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@
from ert.analysis import ErtAnalysisError, UpdateConfiguration, smoother_update
from ert.analysis._es_update import (
TempStorage,
UpdateSettings,
_create_temporary_parameter_storage,
)
from ert.analysis.configuration import UpdateStep
from ert.cli import ENSEMBLE_SMOOTHER_MODE
from ert.config import AnalysisConfig, ErtConfig, GenDataConfig, GenKwConfig
from ert.config.analysis_config import UpdateSettings
from ert.config.analysis_module import ESSettings
from ert.storage import open_storage
from ert.storage.realization_storage_state import RealizationStorageState
Expand Down
Loading

0 comments on commit 8f24a80

Please sign in to comment.