Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix misfit config so it is possible to configure a subset of observations #7263

Merged
merged 3 commits into from
Mar 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 29 additions & 24 deletions src/ert/analysis/_es_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import logging
import time
from collections import UserDict, defaultdict
from dataclasses import dataclass
from datetime import datetime
from fnmatch import fnmatch
from pathlib import Path
from typing import (
TYPE_CHECKING,
Expand All @@ -29,6 +29,7 @@
)
from typing_extensions import Self

from ..config.analysis_config import ObservationGroups, UpdateSettings
from ..config.analysis_module import ESSettings, IESSettings
from . import misfit_preprocessor
from .event import (
Expand Down Expand Up @@ -95,14 +96,6 @@ def __next__(self) -> T:
return result


@dataclass
class UpdateSettings:
    """Tunable settings consumed by the smoother update entry points."""

    # NOTE(review): presumably forwarded as the ``std_cutoff`` argument of
    # ``_load_observations_and_responses`` -- confirm at the call site.
    std_cutoff: float = 1e-6
    # NOTE(review): presumably the ``alpha`` factor used in the check
    # ``abs(observations - ens_mean) <= alpha * (ens_std + scaled_errors)``
    # -- confirm at the call site.
    alpha: float = 3.0
    # Passed on to the analysis step; when true, the scaling of the active
    # observations is multiplied by the result of ``misfit_preprocessor.main``.
    misfit_preprocess: bool = False
    # Lower bound on active realizations; ``_assert_has_enough_realizations``
    # raises ``ErtAnalysisError`` when fewer are left.
    min_required_realizations: int = 2


class TempStorage(UserDict): # type: ignore
def __getitem__(self, key: str) -> npt.NDArray[np.double]:
value: Union[npt.NDArray[np.double], xr.DataArray] = self.data[key]
Expand Down Expand Up @@ -211,14 +204,23 @@ def _get_obs_and_measure_data(
)


def _expand_wildcards(
input_list: npt.NDArray[np.str_], patterns: List[str]
) -> List[str]:
matches = []
for pattern in patterns:
matches.extend([val for val in input_list if fnmatch(val, pattern)])
return list(set(matches))


def _load_observations_and_responses(
ensemble: Ensemble,
alpha: float,
std_cutoff: float,
global_std_scaling: float,
iens_active_index: npt.NDArray[np.int_],
selected_observations: Iterable[str],
misfit_process: bool,
auto_scale_observations: Optional[List[ObservationGroups]],
) -> Tuple[
npt.NDArray[np.float_],
Tuple[
Expand Down Expand Up @@ -247,10 +249,13 @@ def _load_observations_and_responses(
ens_mean_mask = abs(observations - ens_mean) <= alpha * (ens_std + scaled_errors)
obs_mask = np.logical_and(ens_mean_mask, ens_std_mask)

if misfit_process:
scaling[obs_mask] *= misfit_preprocessor.main(
S[obs_mask], scaled_errors[obs_mask]
)
if auto_scale_observations:
for input_group in auto_scale_observations:
group = _expand_wildcards(obs_keys, input_group)
sondreso marked this conversation as resolved.
Show resolved Hide resolved
obs_group_mask = np.isin(obs_keys, group) & obs_mask
scaling[obs_group_mask] *= misfit_preprocessor.main(
S[obs_group_mask], scaled_errors[obs_group_mask]
)

update_snapshot = []
for (
Expand Down Expand Up @@ -417,7 +422,7 @@ def analysis_ES(
source_ensemble: Ensemble,
target_ensemble: Ensemble,
progress_callback: Callable[[AnalysisEvent], None],
misfit_process: bool,
auto_scale_observations: Optional[List[ObservationGroups]],
) -> None:
iens_active_index = np.flatnonzero(ens_mask)

Expand All @@ -443,7 +448,7 @@ def adaptive_localization_progress_callback(
global_scaling,
iens_active_index,
observations,
misfit_process,
auto_scale_observations,
)
num_obs = len(observation_values)

Expand Down Expand Up @@ -556,7 +561,7 @@ def analysis_IES(
target_ensemble: Ensemble,
sies_smoother: Optional[ies.SIES],
progress_callback: Callable[[AnalysisEvent], None],
misfit_preprocessor: bool,
auto_scale_observations: List[ObservationGroups],
sies_step_length: Callable[[int], float],
initial_mask: npt.NDArray[np.bool_],
) -> ies.SIES:
Expand All @@ -583,7 +588,7 @@ def analysis_IES(
1.0,
iens_active_index,
observations,
misfit_preprocessor,
auto_scale_observations,
)

smoother_snapshot.update_step_snapshots = update_snapshot
Expand Down Expand Up @@ -721,10 +726,10 @@ def _write_update_report(


def _assert_has_enough_realizations(
ens_mask: npt.NDArray[np.bool_], analysis_config: UpdateSettings
ens_mask: npt.NDArray[np.bool_], min_required_realizations: int
) -> None:
active_realizations = ens_mask.sum()
if active_realizations < analysis_config.min_required_realizations:
if active_realizations < min_required_realizations:
raise ErtAnalysisError(
f"There are {active_realizations} active realisations left, which is "
"less than the minimum specified - stopping assimilation.",
Expand Down Expand Up @@ -767,7 +772,7 @@ def smoother_update(
analysis_config = UpdateSettings() if analysis_config is None else analysis_config
es_settings = ESSettings() if es_settings is None else es_settings
ens_mask = prior_storage.get_realization_mask_with_responses()
_assert_has_enough_realizations(ens_mask, analysis_config)
_assert_has_enough_realizations(ens_mask, analysis_config.min_required_realizations)

smoother_snapshot = _create_smoother_snapshot(
prior_storage.name,
Expand All @@ -790,7 +795,7 @@ def smoother_update(
prior_storage,
posterior_storage,
progress_callback,
analysis_config.misfit_preprocess,
analysis_config.auto_scale_observations,
)
except Exception as e:
raise e
Expand Down Expand Up @@ -828,7 +833,7 @@ def iterative_smoother_update(
rng = np.random.default_rng()

ens_mask = prior_storage.get_realization_mask_with_responses()
_assert_has_enough_realizations(ens_mask, update_settings)
_assert_has_enough_realizations(ens_mask, update_settings.min_required_realizations)

smoother_snapshot = _create_smoother_snapshot(
prior_storage.name,
Expand All @@ -851,7 +856,7 @@ def iterative_smoother_update(
target_ensemble=posterior_storage,
sies_smoother=sies_smoother,
progress_callback=progress_callback,
misfit_preprocessor=update_settings.misfit_preprocess,
auto_scale_observations=update_settings.auto_scale_observations,
sies_step_length=sies_step_length,
initial_mask=initial_mask,
)
Expand Down
24 changes: 3 additions & 21 deletions src/ert/cli/model_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import numpy as np

from ert.analysis._es_update import UpdateSettings
from ert.cli import (
ENSEMBLE_EXPERIMENT_MODE,
ENSEMBLE_SMOOTHER_MODE,
Expand All @@ -14,7 +13,8 @@
ITERATIVE_ENSEMBLE_SMOOTHER_MODE,
TEST_RUN_MODE,
)
from ert.config import ConfigWarning, ErtConfig, HookRuntime
from ert.config import ConfigWarning, ErtConfig
from ert.config.analysis_config import UpdateSettings
from ert.run_models import (
BaseRunModel,
EnsembleExperiment,
Expand All @@ -35,23 +35,13 @@
from ert.validation import ActiveRange

if TYPE_CHECKING:
from typing import List

import numpy.typing as npt

from ert.config import Workflow
from ert.namespace import Namespace
from ert.storage import Storage


def _misfit_preprocessor(workflows: List[Workflow]) -> bool:
for workflow in workflows:
for job, _ in workflow:
if job.name == "MISFIT_PREPROCESSOR":
return True
return False


def create_model(
config: ErtConfig,
storage: Storage,
Expand All @@ -65,15 +55,7 @@ def create_model(
"ensemble_size": config.model_config.num_realizations,
},
)
ert_analysis_config = config.analysis_config
update_settings = UpdateSettings(
std_cutoff=ert_analysis_config.std_cutoff,
alpha=ert_analysis_config.enkf_alpha,
misfit_preprocess=_misfit_preprocessor(
config.hooked_workflows[HookRuntime.PRE_FIRST_UPDATE]
),
min_required_realizations=ert_analysis_config.minimum_required_realizations,
)
update_settings = config.analysis_config.observation_settings

if args.mode == TEST_RUN_MODE:
return _setup_single_test_run(config, storage, args)
Expand Down
47 changes: 41 additions & 6 deletions src/ert/config/analysis_config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from __future__ import annotations

import logging
from dataclasses import dataclass, field
from math import ceil
from os.path import realpath
from pathlib import Path
Expand All @@ -19,6 +22,15 @@
logger = logging.getLogger(__name__)

DEFAULT_ANALYSIS_MODE = AnalysisMode.ENSEMBLE_SMOOTHER
ObservationGroups = List[str]


@dataclass
class UpdateSettings:
    """Observation-related update settings assembled by ``AnalysisConfig``.

    ``AnalysisConfig.__init__`` builds one instance from its ``alpha``,
    ``std_cutoff`` and ``min_realization`` arguments plus any
    ``ANALYSIS_SET_VAR OBSERVATIONS AUTO_SCALE`` entries, and exposes it as
    ``observation_settings``.
    """

    # NOTE(review): presumably a cutoff on ensemble standard deviation used
    # when filtering observations -- confirm against the update code.
    std_cutoff: float = 1e-6
    # NOTE(review): presumably the outlier-detection factor -- confirm
    # against the update code.
    alpha: float = 3.0
    # One entry per AUTO_SCALE setting; each entry is the list obtained by
    # splitting the configured value on ",".  default_factory gives every
    # instance its own list.
    auto_scale_observations: List[ObservationGroups] = field(default_factory=list)
    # NOTE(review): presumably the minimum number of active realizations
    # required to run an update -- confirm against the update code.
    min_required_realizations: int = 2


class AnalysisConfig:
Expand All @@ -43,6 +55,12 @@ def __init__(
self._min_realization = min_realization

options: Dict[str, Dict[str, Any]] = {"STD_ENKF": {}, "IES_ENKF": {}}
observation_settings: Dict[str, Any] = {
"alpha": alpha,
"std_cutoff": std_cutoff,
"auto_scale_observations": [],
"min_required_realizations": min_realization,
}
analysis_set_var = [] if analysis_set_var is None else analysis_set_var
inversion_str_map: Final = {
"STD_ENKF": {
Expand All @@ -64,6 +82,19 @@ def __init__(
all_errors = []

for module_name, var_name, value in analysis_set_var:
if module_name == "OBSERVATIONS":
if var_name == "AUTO_SCALE":
observation_settings["auto_scale_observations"].append(
value.split(",")
)
else:
all_errors.append(
ConfigValidationError(
f"Unknown variable: {var_name} for: ANALYSIS_SET_VAR OBSERVATIONS {var_name}"
"Valid options: AUTO_SCALE"
)
)
continue
if var_name in deprecated_keys:
errors.append(var_name)
continue
Expand Down Expand Up @@ -95,7 +126,14 @@ def __init__(

var_name = "inversion"
key = var_name.lower()
options[module_name][key] = value
try:
options[module_name][key] = value
except KeyError:
all_errors.append(
ConfigValidationError(
f"Invalid configuration: ANALYSIS_SET_VAR {module_name} {var_name}"
)
)

if errors:
all_errors.append(
Expand All @@ -110,6 +148,7 @@ def __init__(
try:
self.es_module = ESSettings(**options["STD_ENKF"])
self.ies_module = IESSettings(**options["IES_ENKF"])
self.observation_settings = UpdateSettings(**observation_settings)
except ValidationError as err:
for error in err.errors():
error["loc"] = tuple(
Expand Down Expand Up @@ -226,7 +265,6 @@ def __repr__(self) -> str:
return (
"AnalysisConfig("
f"alpha={self._alpha}, "
f"std_cutoff={self._std_cutoff}, "
f"stop_long_running={self._stop_long_running}, "
f"max_runtime={self._max_runtime}, "
f"min_realization={self._min_realization}, "
Expand All @@ -247,10 +285,7 @@ def __eq__(self, other: object) -> bool:
if self.stop_long_running != other.stop_long_running:
return False

if self.std_cutoff != other.std_cutoff:
return False

if self.enkf_alpha != other.enkf_alpha:
if self.observation_settings != other.observation_settings:
sondreso marked this conversation as resolved.
Show resolved Hide resolved
return False

if self.ies_module != other.ies_module:
Expand Down
8 changes: 1 addition & 7 deletions src/ert/gui/tools/run_analysis/run_analysis_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from qtpy.QtWidgets import QApplication, QMessageBox

from ert.analysis import ErtAnalysisError, smoother_update
from ert.analysis._es_update import UpdateSettings
from ert.analysis.event import AnalysisEvent, AnalysisStatusEvent, AnalysisTimeEvent
from ert.enkf_main import EnKFMain, _seed_sequence
from ert.gui.ertnotifier import ErtNotifier
Expand Down Expand Up @@ -48,12 +47,7 @@ def run(self):
error: Optional[str] = None
config = self._ert.ert_config
rng = np.random.default_rng(_seed_sequence(config.random_seed))
update_settings = UpdateSettings(
std_cutoff=config.analysis_config.std_cutoff,
alpha=config.analysis_config.enkf_alpha,
misfit_preprocess=False,
min_required_realizations=config.analysis_config.minimum_required_realizations,
)
update_settings = config.analysis_config.observation_settings
try:
smoother_update(
self._source_ensemble,
Expand Down
9 changes: 1 addition & 8 deletions src/ert/libres_facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
from ert.shared.version import __version__
from ert.storage import Ensemble

from .analysis._es_update import UpdateSettings
from .enkf_main import EnKFMain, ensemble_context

_logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -92,19 +91,13 @@ def smoother_update(
) -> SmootherSnapshot:
if rng is None:
rng = np.random.default_rng()
analysis_config = UpdateSettings(
std_cutoff=self.config.analysis_config.std_cutoff,
alpha=self.config.analysis_config.enkf_alpha,
misfit_preprocess=misfit_process,
min_required_realizations=self.config.analysis_config.minimum_required_realizations,
)
update_snapshot = smoother_update(
prior_storage,
posterior_storage,
run_id,
observations,
parameters,
analysis_config,
self.config.analysis_config.observation_settings,
self.config.analysis_config.es_module,
rng,
progress_callback,
Expand Down
2 changes: 1 addition & 1 deletion src/ert/run_models/ensemble_smoother.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from ert.run_models.run_arguments import ESRunArguments
from ert.storage import Storage

from ..analysis._es_update import UpdateSettings
from ..config.analysis_config import UpdateSettings
from ..config.analysis_module import ESSettings
from .base_run_model import BaseRunModel, ErtRunError
from .event import RunModelStatusEvent, RunModelUpdateBeginEvent, RunModelUpdateEndEvent
Expand Down
2 changes: 1 addition & 1 deletion src/ert/run_models/iterated_ensemble_smoother.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from ert.run_models.run_arguments import SIESRunArguments
from ert.storage import Ensemble, Storage

from ..analysis._es_update import UpdateSettings
from ..config.analysis_config import UpdateSettings
from ..config.analysis_module import IESSettings
from .base_run_model import BaseRunModel, ErtRunError
from .event import RunModelStatusEvent, RunModelUpdateBeginEvent, RunModelUpdateEndEvent
Expand Down
2 changes: 1 addition & 1 deletion src/ert/run_models/multiple_data_assimilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from ert.run_models.run_arguments import ESMDARunArguments
from ert.storage import Ensemble, Storage

from ..analysis._es_update import UpdateSettings
from ..config.analysis_config import UpdateSettings
from ..config.analysis_module import ESSettings
from .base_run_model import BaseRunModel, ErtRunError
from .event import RunModelStatusEvent, RunModelUpdateBeginEvent, RunModelUpdateEndEvent
Expand Down
Loading
Loading