Skip to content

Commit

Permalink
Add option to run misfit_preprocessor with selected observations
Browse files Browse the repository at this point in the history
  • Loading branch information
oyvindeide committed Mar 14, 2024
1 parent 9bdea67 commit c930ed5
Show file tree
Hide file tree
Showing 16 changed files with 638 additions and 72 deletions.
52 changes: 29 additions & 23 deletions src/ert/analysis/_es_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum, auto
from fnmatch import fnmatch
from pathlib import Path
from typing import (
TYPE_CHECKING,
Expand All @@ -30,6 +31,7 @@
)
from typing_extensions import Self

from ..config.analysis_config import ObservationGroups, UpdateSettings
from ..config.analysis_module import ESSettings, IESSettings
from . import misfit_preprocessor
from .event import AnalysisEvent, AnalysisStatusEvent, AnalysisTimeEvent
Expand Down Expand Up @@ -136,14 +138,6 @@ def __next__(self) -> T:
return result


@dataclass
class UpdateSettings:
std_cutoff: float = 1e-6
alpha: float = 3.0
misfit_preprocess: bool = False
min_required_realizations: int = 2


class TempStorage(UserDict): # type: ignore
def __getitem__(self, key: str) -> npt.NDArray[np.double]:
value: Union[npt.NDArray[np.double], xr.DataArray] = self.data[key]
Expand Down Expand Up @@ -252,14 +246,23 @@ def _get_obs_and_measure_data(
)


def _expand_wildcards(
input_list: npt.NDArray[np.str_], patterns: List[str]
) -> List[str]:
matches = []
for pattern in patterns:
matches.extend([val for val in input_list if fnmatch(val, pattern)])
return list(set(matches))


def _load_observations_and_responses(
ensemble: Ensemble,
alpha: float,
std_cutoff: float,
global_std_scaling: float,
iens_ative_index: npt.NDArray[np.int_],
selected_observations: Iterable[str],
misfit_process: bool,
auto_scale_observations: Optional[List[ObservationGroups]],
) -> Tuple[
npt.NDArray[np.float_],
Tuple[
Expand Down Expand Up @@ -288,10 +291,13 @@ def _load_observations_and_responses(
ens_mean_mask = abs(observations - ens_mean) <= alpha * (ens_std + scaled_errors)
obs_mask = np.logical_and(ens_mean_mask, ens_std_mask)

if misfit_process:
scaling[obs_mask] *= misfit_preprocessor.main(
S[obs_mask], scaled_errors[obs_mask]
)
if auto_scale_observations:
for input_group in auto_scale_observations:
group = _expand_wildcards(obs_keys, input_group)
obs_group_mask = np.isin(obs_keys, group) & obs_mask
scaling[obs_group_mask] *= misfit_preprocessor.main(
S[obs_group_mask], scaled_errors[obs_group_mask]
)

update_snapshot = []
for (
Expand Down Expand Up @@ -458,7 +464,7 @@ def analysis_ES(
source_ensemble: Ensemble,
target_ensemble: Ensemble,
progress_callback: Callable[[AnalysisEvent], None],
misfit_process: bool,
auto_scale_observations: Optional[List[ObservationGroups]],
) -> None:
iens_active_index = np.flatnonzero(ens_mask)

Expand All @@ -484,7 +490,7 @@ def adaptive_localization_progress_callback(
global_scaling,
iens_active_index,
observations,
misfit_process,
auto_scale_observations,
)
num_obs = len(observation_values)

Expand Down Expand Up @@ -592,7 +598,7 @@ def analysis_IES(
target_ensemble: Ensemble,
sies_smoother: Optional[ies.SIES],
progress_callback: Callable[[AnalysisEvent], None],
misfit_preprocessor: bool,
auto_scale_observations: List[ObservationGroups],
sies_step_length: Callable[[int], float],
initial_mask: npt.NDArray[np.bool_],
) -> ies.SIES:
Expand All @@ -619,7 +625,7 @@ def analysis_IES(
1.0,
iens_active_index,
observations,
misfit_preprocessor,
auto_scale_observations,
)

smoother_snapshot.update_step_snapshots = update_snapshot
Expand Down Expand Up @@ -745,10 +751,10 @@ def _write_update_report(path: Path, snapshot: SmootherSnapshot, run_id: str) ->


def _assert_has_enough_realizations(
ens_mask: npt.NDArray[np.bool_], analysis_config: UpdateSettings
ens_mask: npt.NDArray[np.bool_], min_required_realizations: int
) -> None:
active_realizations = ens_mask.sum()
if active_realizations < analysis_config.min_required_realizations:
if active_realizations < min_required_realizations:
raise ErtAnalysisError(
f"There are {active_realizations} active realisations left, which is "
"less than the minimum specified - stopping assimilation.",
Expand Down Expand Up @@ -790,7 +796,7 @@ def smoother_update(
analysis_config = UpdateSettings() if analysis_config is None else analysis_config
es_settings = ESSettings() if es_settings is None else es_settings
ens_mask = prior_storage.get_realization_mask_with_responses()
_assert_has_enough_realizations(ens_mask, analysis_config)
_assert_has_enough_realizations(ens_mask, analysis_config.min_required_realizations)

smoother_snapshot = _create_smoother_snapshot(
prior_storage.name,
Expand All @@ -812,7 +818,7 @@ def smoother_update(
prior_storage,
posterior_storage,
progress_callback,
analysis_config.misfit_preprocess,
analysis_config.auto_scale_observations,
)

if log_path is not None:
Expand Down Expand Up @@ -847,7 +853,7 @@ def iterative_smoother_update(
rng = np.random.default_rng()

ens_mask = prior_storage.get_realization_mask_with_responses()
_assert_has_enough_realizations(ens_mask, update_settings)
_assert_has_enough_realizations(ens_mask, update_settings.min_required_realizations)

smoother_snapshot = _create_smoother_snapshot(
prior_storage.name,
Expand All @@ -869,7 +875,7 @@ def iterative_smoother_update(
target_ensemble=posterior_storage,
sies_smoother=sies_smoother,
progress_callback=progress_callback,
misfit_preprocessor=update_settings.misfit_preprocess,
auto_scale_observations=update_settings.auto_scale_observations,
sies_step_length=sies_step_length,
initial_mask=initial_mask,
)
Expand Down
24 changes: 3 additions & 21 deletions src/ert/cli/model_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import numpy as np

from ert.analysis._es_update import UpdateSettings
from ert.cli import (
ENSEMBLE_EXPERIMENT_MODE,
ENSEMBLE_SMOOTHER_MODE,
Expand All @@ -14,7 +13,8 @@
ITERATIVE_ENSEMBLE_SMOOTHER_MODE,
TEST_RUN_MODE,
)
from ert.config import ConfigWarning, ErtConfig, HookRuntime
from ert.config import ConfigWarning, ErtConfig
from ert.config.analysis_config import UpdateSettings
from ert.run_models import (
BaseRunModel,
EnsembleExperiment,
Expand All @@ -35,23 +35,13 @@
from ert.validation import ActiveRange

if TYPE_CHECKING:
from typing import List

import numpy.typing as npt

from ert.config import Workflow
from ert.namespace import Namespace
from ert.storage import Storage


def _misfit_preprocessor(workflows: List[Workflow]) -> bool:
for workflow in workflows:
for job, _ in workflow:
if job.name == "MISFIT_PREPROCESSOR":
return True
return False


def create_model(
config: ErtConfig,
storage: Storage,
Expand All @@ -65,15 +55,7 @@ def create_model(
"ensemble_size": config.model_config.num_realizations,
},
)
ert_analysis_config = config.analysis_config
update_settings = UpdateSettings(
std_cutoff=ert_analysis_config.std_cutoff,
alpha=ert_analysis_config.enkf_alpha,
misfit_preprocess=_misfit_preprocessor(
config.hooked_workflows[HookRuntime.PRE_FIRST_UPDATE]
),
min_required_realizations=ert_analysis_config.minimum_required_realizations,
)
update_settings = config.analysis_config.observation_settings

if args.mode == TEST_RUN_MODE:
return _setup_single_test_run(config, storage, args)
Expand Down
40 changes: 35 additions & 5 deletions src/ert/config/analysis_config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from __future__ import annotations

import logging
from dataclasses import dataclass, field
from math import ceil
from os.path import realpath
from pathlib import Path
Expand Down Expand Up @@ -43,6 +46,12 @@ def __init__(
self._min_realization = min_realization

options: Dict[str, Dict[str, Any]] = {"STD_ENKF": {}, "IES_ENKF": {}}
observation_settings: Dict[str, Any] = {
"alpha": alpha,
"std_cutoff": std_cutoff,
"auto_scale_observations": [],
"min_required_realizations": min_realization,
}
analysis_set_var = [] if analysis_set_var is None else analysis_set_var
inversion_str_map: Final = {
"STD_ENKF": {
Expand All @@ -64,6 +73,19 @@ def __init__(
all_errors = []

for module_name, var_name, value in analysis_set_var:
if module_name == "OBSERVATIONS":
if var_name == "AUTO_SCALE":
observation_settings["auto_scale_observations"].append(
value.split(",")
)
else:
all_errors.append(
ConfigValidationError(
f"Unknown variable: {var_name} for: ANALYSIS_SET_VAR OBSERVATIONS {var_name}"
"Valid options: AUTO_SCALE"
)
)
continue
if var_name in deprecated_keys:
errors.append(var_name)
continue
Expand Down Expand Up @@ -110,6 +132,7 @@ def __init__(
try:
self.es_module = ESSettings(**options["STD_ENKF"])
self.ies_module = IESSettings(**options["IES_ENKF"])
self.observation_settings = UpdateSettings(**observation_settings)
except ValidationError as err:
for error in err.errors():
error["loc"] = tuple(
Expand Down Expand Up @@ -226,7 +249,6 @@ def __repr__(self) -> str:
return (
"AnalysisConfig("
f"alpha={self._alpha}, "
f"std_cutoff={self._std_cutoff}, "
f"stop_long_running={self._stop_long_running}, "
f"max_runtime={self._max_runtime}, "
f"min_realization={self._min_realization}, "
Expand All @@ -247,10 +269,7 @@ def __eq__(self, other: object) -> bool:
if self.stop_long_running != other.stop_long_running:
return False

if self.std_cutoff != other.std_cutoff:
return False

if self.enkf_alpha != other.enkf_alpha:
if self.observation_settings != other.observation_settings:
return False

if self.ies_module != other.ies_module:
Expand All @@ -265,3 +284,14 @@ def __eq__(self, other: object) -> bool:
if self.minimum_required_realizations != other.minimum_required_realizations:
return False
return True


ObservationGroups = List[str]


@dataclass
class UpdateSettings:
std_cutoff: float = 1e-6
alpha: float = 3.0
auto_scale_observations: List[ObservationGroups] = field(default_factory=list)
min_required_realizations: int = 2
8 changes: 1 addition & 7 deletions src/ert/gui/tools/run_analysis/run_analysis_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from qtpy.QtWidgets import QApplication, QMessageBox

from ert.analysis import ErtAnalysisError, smoother_update
from ert.analysis._es_update import UpdateSettings
from ert.analysis.event import AnalysisEvent, AnalysisStatusEvent, AnalysisTimeEvent
from ert.enkf_main import EnKFMain, _seed_sequence
from ert.gui.ertnotifier import ErtNotifier
Expand Down Expand Up @@ -48,12 +47,7 @@ def run(self):
error: Optional[str] = None
config = self._ert.ert_config
rng = np.random.default_rng(_seed_sequence(config.random_seed))
update_settings = UpdateSettings(
std_cutoff=config.analysis_config.std_cutoff,
alpha=config.analysis_config.enkf_alpha,
misfit_preprocess=False,
min_required_realizations=config.analysis_config.minimum_required_realizations,
)
update_settings = config.analysis_config.observation_settings
try:
smoother_update(
self._source_fs,
Expand Down
9 changes: 1 addition & 8 deletions src/ert/libres_facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
from ert.shared.version import __version__
from ert.storage import Ensemble

from .analysis._es_update import UpdateSettings
from .enkf_main import EnKFMain, ensemble_context

_logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -92,19 +91,13 @@ def smoother_update(
) -> SmootherSnapshot:
if rng is None:
rng = np.random.default_rng()
analysis_config = UpdateSettings(
std_cutoff=self.config.analysis_config.std_cutoff,
alpha=self.config.analysis_config.enkf_alpha,
misfit_preprocess=misfit_process,
min_required_realizations=self.config.analysis_config.minimum_required_realizations,
)
update_snapshot = smoother_update(
prior_storage,
posterior_storage,
run_id,
observations,
parameters,
analysis_config,
self.config.analysis_config.observation_settings,
self.config.analysis_config.es_module,
rng,
progress_callback,
Expand Down
2 changes: 1 addition & 1 deletion src/ert/run_models/ensemble_smoother.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from ert.run_models.run_arguments import ESRunArguments
from ert.storage import Storage

from ..analysis._es_update import UpdateSettings
from ..config.analysis_config import UpdateSettings
from ..config.analysis_module import ESSettings
from .base_run_model import BaseRunModel, ErtRunError
from .event import RunModelStatusEvent, RunModelUpdateBeginEvent, RunModelUpdateEndEvent
Expand Down
2 changes: 1 addition & 1 deletion src/ert/run_models/iterated_ensemble_smoother.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from ert.run_models.run_arguments import SIESRunArguments
from ert.storage import Ensemble, Storage

from ..analysis._es_update import UpdateSettings
from ..config.analysis_config import UpdateSettings
from ..config.analysis_module import IESSettings
from .base_run_model import BaseRunModel, ErtRunError
from .event import RunModelStatusEvent, RunModelUpdateBeginEvent, RunModelUpdateEndEvent
Expand Down
2 changes: 1 addition & 1 deletion src/ert/run_models/multiple_data_assimilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from ert.run_models.run_arguments import ESMDARunArguments
from ert.storage import Ensemble, Storage

from ..analysis._es_update import UpdateSettings
from ..config.analysis_config import UpdateSettings
from ..config.analysis_module import ESSettings
from .base_run_model import BaseRunModel, ErtRunError
from .event import RunModelStatusEvent, RunModelUpdateBeginEvent, RunModelUpdateEndEvent
Expand Down
5 changes: 5 additions & 0 deletions src/ert/shared/hook_implementations/workflows/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .disable_parameters import DisableParametersUpdate
from .export_misfit_data import ExportMisfitDataJob
from .export_runpath import ExportRunpathJob
from .misfit_preprocessor import MisfitPreprocessor


@hook_implementation
Expand All @@ -18,3 +19,7 @@ def legacy_ertscript_workflow(config):

workflow = config.add_workflow(DisableParametersUpdate, "DISABLE_PARAMETERS")
workflow.description = DisableParametersUpdate.__doc__

workflow = config.add_workflow(MisfitPreprocessor, "MISFIT_PREPROCESSOR")
workflow.description = MisfitPreprocessor.__doc__
workflow.category = "observations.correlation"
Loading

0 comments on commit c930ed5

Please sign in to comment.