Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Localization #6349

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 116 additions & 18 deletions src/ert/analysis/_es_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from iterative_ensemble_smoother.experimental import (
ensemble_smoother_update_step_row_scaling,
)
from tqdm import tqdm

from ert.config import Field, GenKwConfig, SurfaceConfig
from ert.realization_state import RealizationState
Expand Down Expand Up @@ -370,6 +371,12 @@ def _load_observations_and_responses(
)


def _split_by_batchsize(
arr: npt.NDArray[np.int_], batch_size: int
) -> List[npt.NDArray[np.int_]]:
return np.array_split(arr, int((arr.shape[0] / batch_size)) + 1)


def analysis_ES(
updatestep: UpdateConfiguration,
obs: EnkfObs,
Expand Down Expand Up @@ -417,21 +424,17 @@ def analysis_ES(

# pylint: disable=unsupported-assignment-operation
smoother_snapshot.update_step_snapshots[update_step.name] = update_snapshot
if len(observation_values) == 0:

num_obs = len(observation_values)
if num_obs == 0:
raise ErtAnalysisError(
f"No active observations for update step: {update_step.name}."
)
noise = rng.standard_normal(size=(len(observation_values), S.shape[1]))

smoother = ies.ES()
smoother.fit(
S,
observation_errors,
observation_values,
noise=noise,
truncation=module.get_truncation(),
inversion=ies.InversionType(module.inversion),
param_ensemble=param_ensemble,
)
truncation = module.get_truncation()
noise = rng.standard_normal(size=(num_obs, ensemble_size))

for param_group in update_step.parameters:
source: Union[EnsembleReader, EnsembleAccessor]
if target_fs.has_parameter_group(param_group.name):
Expand All @@ -441,15 +444,98 @@ def analysis_ES(
temp_storage = _create_temporary_parameter_storage(
source, iens_active_index, param_group.name
)
progress_callback(Progress(Task("Updating data", 2, 3), None))
if active_indices := param_group.index_list:
temp_storage[param_group.name][active_indices, :] = smoother.update(
temp_storage[param_group.name][active_indices, :]

if module.localization():
Y_prime = S - S.mean(axis=1, keepdims=True)
C_YY = Y_prime @ Y_prime.T / (ensemble_size - 1)
Sigma_Y = np.diag(np.sqrt(np.diag(C_YY)))
batch_size: int = 1000
correlation_threshold = module.localization_correlation_threshold(
ensemble_size
)
# for parameter in update_step.parameters:
num_params = temp_storage[param_group.name].shape[0]

print(
(
f"Running localization on {num_params} parameters,",
f"{num_obs} responses and {ensemble_size} realizations...",
)
)
batches = _split_by_batchsize(np.arange(0, num_params), batch_size)
for param_batch_idx in tqdm(batches):
X_local = temp_storage[param_group.name][param_batch_idx, :]
A = X_local - X_local.mean(axis=1, keepdims=True)
C_AA = A @ A.T / (ensemble_size - 1)

# State-measurement covariance matrix
C_AY = A @ Y_prime.T / (ensemble_size - 1)
Sigma_A = np.diag(np.sqrt(np.diag(C_AA)))

# State-measurement correlation matrix
c_AY = np.abs(
np.linalg.inv(Sigma_A) @ C_AY @ np.linalg.inv(Sigma_Y)
)
c_bool = c_AY > correlation_threshold
# Some parameters might be significantly correlated
# to the exact same responses.
# We want to call the update only once per such parameter group
# to speed up computation.
# Here we create a collection of unique sets of parameter-to-observation
# correlations.
param_correlation_sets: npt.NDArray[np.bool_] = np.unique(
c_bool, axis=0
)
# Drop the correlation set that does not correlate to any responses.
row_with_all_false = np.all(~param_correlation_sets, axis=1)
param_correlation_sets = param_correlation_sets[~row_with_all_false]

for param_correlation_set in param_correlation_sets:
# Find the rows matching the parameter group
matching_rows = np.all(c_bool == param_correlation_set, axis=1)
# Get the indices of the matching rows
row_indices = np.where(matching_rows)[0]
X_chunk = temp_storage[param_group.name][param_batch_idx, :][
row_indices, :
]
S_chunk = S[param_correlation_set, :]
observation_errors_loc = observation_errors[
param_correlation_set
]
observation_values_loc = observation_values[
param_correlation_set
]
smoother.fit(
S_chunk,
observation_errors_loc,
observation_values_loc,
noise=noise[param_correlation_set],
truncation=truncation,
inversion=ies.InversionType(module.inversion),
param_ensemble=param_ensemble,
)
temp_storage[param_group.name][
param_batch_idx[row_indices], :
] = smoother.update(X_chunk)
else:
temp_storage[param_group.name] = smoother.update(
temp_storage[param_group.name]
smoother.fit(
S,
observation_errors,
observation_values,
noise=noise,
truncation=truncation,
inversion=ies.InversionType(module.inversion),
param_ensemble=param_ensemble,
)
if active_indices := param_group.index_list:
temp_storage[param_group.name][active_indices, :] = smoother.update(
temp_storage[param_group.name][active_indices, :]
)
else:
temp_storage[param_group.name] = smoother.update(
temp_storage[param_group.name]
)

if params_with_row_scaling := _get_params_with_row_scaling(
temp_storage, update_step.row_scaling_parameters
):
Expand All @@ -465,7 +551,19 @@ def analysis_ES(
for row_scaling_parameter, (A, _) in zip(
update_step.row_scaling_parameters, params_with_row_scaling
):
_save_to_temp_storage(temp_storage, [row_scaling_parameter], A)
params_with_row_scaling = ensemble_smoother_update_step_row_scaling(
S,
params_with_row_scaling,
observation_errors,
observation_values,
noise,
module.get_truncation(),
ies.InversionType(module.inversion),
)
for row_scaling_parameter, (A, _) in zip(
update_step.row_scaling_parameters, params_with_row_scaling
):
_save_to_temp_storage(temp_storage, [row_scaling_parameter], A)

progress_callback(Progress(Task("Storing data", 3, 3), None))
_save_temp_storage_to_disk(target_fs, temp_storage, iens_active_index)
Expand Down
124 changes: 100 additions & 24 deletions src/ert/config/analysis_module.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import math
import sys
from typing import TYPE_CHECKING, Dict, List, Type, TypedDict, Union

Expand Down Expand Up @@ -33,6 +34,23 @@ class VariableInfo(TypedDict):
DEFAULT_IES_DEC_STEPLENGTH = 2.50
DEFAULT_ENKF_TRUNCATION = 0.98
DEFAULT_IES_INVERSION = 0
DEFAULT_LOCALIZATION = False
# Default threshold is a function of ensemble size which is not available here.
DEFAULT_LOCALIZATION_CORRELATION_THRESHOLD = -1


def correlation_threshold(ensemble_size: int, user_defined_threshold: float) -> float:
"""Decides whether or not to use user-defined or default threshold.

Default threshold taken from luo2022,
Continuous Hyper-parameter OPtimization (CHOP) in an ensemble Kalman filter
Section 2.3 - Localization in the CHOP problem
"""
default_threshold = 3 / math.sqrt(ensemble_size)
if user_defined_threshold == -1:
return default_threshold

return user_defined_threshold


class AnalysisMode(StrEnum):
Expand All @@ -58,6 +76,22 @@ def get_mode_variables(mode: AnalysisMode) -> Dict[str, "VariableInfo"]:
"step": 0.01,
"labelname": "Singular value truncation",
},
"LOCALIZATION": {
"type": bool,
"min": 0.0,
"value": DEFAULT_LOCALIZATION,
"max": 1.0,
"step": 1.0,
"labelname": "Adaptive localization",
},
"LOCALIZATION_CORRELATION_THRESHOLD": {
"type": float,
"min": 0.0,
"value": DEFAULT_LOCALIZATION_CORRELATION_THRESHOLD,
"max": 1.0,
"step": 0.1,
"labelname": "Adaptive localization correlation threshold",
},
}
ies_variables: Dict[str, "VariableInfo"] = {
"IES_MAX_STEPLENGTH": {
Expand Down Expand Up @@ -169,31 +203,47 @@ def set_var(self, var_name: str, value: Union[float, int, bool, str]) -> None:
self.handle_special_key_set(var_name, value)
elif var_name in self._variables:
var = self._variables[var_name]
try:
new_value = var["type"](value)
if new_value > var["max"]:
var["value"] = var["max"]
logger.warning(
f"New value {new_value} for key"
f" {var_name} is out of [{var['min']}, {var['max']}] "
f"using max value {var['max']}"
)
elif new_value < var["min"]:
var["value"] = var["min"]
logger.warning(
f"New value {new_value} for key"
f" {var_name} is out of [{var['min']}, {var['max']}] "
f"using min value {var['min']}"

if var["type"] is not bool:
try:
new_value = var["type"](value)
if new_value > var["max"]:
var["value"] = var["max"]
logger.warning(
f"New value {new_value} for key"
f" {var_name} is out of [{var['min']}, {var['max']}] "
f"using max value {var['max']}"
)
elif new_value < var["min"]:
var["value"] = var["min"]
logger.warning(
f"New value {new_value} for key"
f" {var_name} is out of [{var['min']}, {var['max']}] "
f"using min value {var['min']}"
)
else:
var["value"] = new_value

except ValueError as e:
raise ConfigValidationError(
f"Variable {var_name!r} with value {value!r} has "
f"incorrect type."
f" Expected type {var['type'].__name__!r} but received"
f" value {value!r} of type {type(value).__name__!r}"
) from e
else:
if not isinstance(var["value"], bool):
raise ValueError(
f"Variable {var_name} expected type {var['type']}"
f" received value `{value}` of type `{type(value)}`"
)
else:
var["value"] = new_value

except ValueError as e:
raise ConfigValidationError(
f"Variable {var_name!r} with value {value!r} has incorrect type."
f" Expected type {var['type'].__name__!r} but received"
f" value {value!r} of type {type(value).__name__!r}"
) from e
# When config is first read, `value` is a string
# that's either "False" or "True",
# but since bool("False") is True we need to convert it to bool.
if not isinstance(value, bool):
value = str(value).lower() != "false"

var["value"] = var["type"](value)
else:
raise ConfigValidationError(
f"Variable {var_name!r} not found in {self.name!r} analysis module"
Expand All @@ -210,6 +260,32 @@ def inversion(self, value: int) -> None:
def get_truncation(self) -> float:
return self.get_variable_value("ENKF_TRUNCATION")

def localization(self) -> bool:
return bool(self.get_variable_value("LOCALIZATION"))

def localization_correlation_threshold(self, ensemble_size: int) -> float:
return correlation_threshold(
ensemble_size, self.get_variable_value("LOCALIZATION_CORRELATION_THRESHOLD")
)

def get_steplength(self, iteration_nr: int) -> float:
"""
This is an implementation of Eq. (49), which calculates a suitable
step length for the update step, from the book:
Geir Evensen, Formulating the history matching problem with
consistent error statistics, Computational Geosciences (2021) 25:945 –970

Function not really used moved from C to keep the class interface consistent
should be investigated for possible removal.
"""
min_step_length = self.get_variable_value("IES_MIN_STEPLENGTH")
max_step_length = self.get_variable_value("IES_MAX_STEPLENGTH")
dec_step_length = self.get_variable_value("IES_DEC_STEPLENGTH")
step_length = min_step_length + (max_step_length - min_step_length) * pow(
2, -(iteration_nr - 1) / (dec_step_length - 1)
)
return step_length

def __repr__(self) -> str:
return f"AnalysisModule(name = {self.name})"

Expand Down
19 changes: 19 additions & 0 deletions src/ert/gui/ertwidgets/analysismodulevariablespanel.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
QWidget,
)

from ert.config.analysis_module import correlation_threshold
from ert.gui.ertwidgets.models.analysismodulevariablesmodel import (
AnalysisModuleVariablesModel,
)
Expand Down Expand Up @@ -41,10 +42,16 @@ def __init__(self, analysis_module_name: str, facade: LibresFacade):
variable_type = analysis_module_variables_model.getVariableType(
variable_name
)

variable_value = analysis_module_variables_model.getVariableValue(
self.facade, self._analysis_module_name, variable_name
)

if variable_name == "LOCALIZATION_CORRELATION_THRESHOLD":
variable_value = correlation_threshold(
self.facade.get_ensemble_size(), variable_value
)

label_name = analysis_module_variables_model.getVariableLabelName(
variable_name
)
Expand Down Expand Up @@ -123,6 +130,17 @@ def __init__(self, analysis_module_name: str, facade: LibresFacade):
lambda value: self.update_truncation_spinners(value, truncation_spinner)
)

localization_checkbox = self.widget_from_layout(layout, "LOCALIZATION")
localization_correlation_spinner = self.widget_from_layout(
layout, "LOCALIZATION_CORRELATION_THRESHOLD"
)
localization_correlation_spinner.setEnabled(localization_checkbox.isChecked())
localization_checkbox.stateChanged.connect(
lambda localization_is_on: localization_correlation_spinner.setEnabled(True)
if localization_is_on
else localization_correlation_spinner.setEnabled(False)
)

self.setLayout(layout)
self.blockSignals(False)

Expand Down Expand Up @@ -172,6 +190,7 @@ def createSpinBox(
def createCheckBox(self, variable_name, variable_value, variable_type):
spinner = QCheckBox()
spinner.setChecked(variable_value)
spinner.setObjectName(variable_name)
spinner.clicked.connect(
partial(self.valueChanged, variable_name, variable_type, spinner)
)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
versions:
'2023.04-pre':
scalar:
executable: /Users/fedacuric/opm-simulators/build/bin/flow
Loading
Loading