Skip to content

Commit

Permalink
Implement adaptive localization
Browse files Browse the repository at this point in the history
Add option of running adaptive localization that can simply
be turned on and does not need any user input.
Only parameters that are significantly correlated to responses
will be updated.
Default value of what constitutes significant correlation is calculated
based on theory, but can be set by the user.
  • Loading branch information
dafeda committed Sep 21, 2023
1 parent b0b238e commit b6f14fe
Show file tree
Hide file tree
Showing 9 changed files with 338 additions and 48 deletions.
140 changes: 116 additions & 24 deletions src/ert/analysis/_es_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from iterative_ensemble_smoother.experimental import (
ensemble_smoother_update_step_row_scaling,
)
from tqdm import tqdm

from ert.config import Field, GenKwConfig, SurfaceConfig
from ert.realization_state import RealizationState
Expand Down Expand Up @@ -381,6 +382,12 @@ def _load_observations_and_responses(
)


def _split_by_batchsize(
arr: npt.NDArray[np.int_], batch_size: int
) -> List[npt.NDArray[np.int_]]:
return np.array_split(arr, int((arr.shape[0] / batch_size)) + 1)


def analysis_ES(
updatestep: "UpdateConfiguration",
obs: EnkfObs,
Expand All @@ -399,9 +406,13 @@ def analysis_ES(
iens_active_index = [i for i in range(len(ens_mask)) if ens_mask[i]]

progress_callback(Progress(Task("Loading data", 1, 3), None))
start = time.time()
temp_storage = _create_temporary_parameter_storage(
source_fs, ensemble_config, iens_active_index
)
end = time.time()
elapsed = end - start
print(f"Time to run _create_temporary_parameter_storage: {elapsed}")

ensemble_size = sum(ens_mask)
param_ensemble = _param_ensemble_for_projection(
Expand All @@ -410,6 +421,8 @@ def analysis_ES(

progress_callback(Progress(Task("Updating data", 2, 3), None))
for update_step in updatestep:
print("Loading responses and observations...")
start = time.time()
try:
S, (
observation_values,
Expand All @@ -426,35 +439,108 @@ def analysis_ES(
)
except IndexError as e:
raise ErtAnalysisError(e) from e
end = time.time()
elapsed = end - start
print(f"Time to run _load_observations_and_responses: {elapsed}")

# pylint: disable=unsupported-assignment-operation
smoother_snapshot.update_step_snapshots[update_step.name] = update_snapshot
if len(observation_values) == 0:

num_obs = len(observation_values)
if num_obs == 0:
raise ErtAnalysisError(
f"No active observations for update step: {update_step.name}."
)

noise = rng.standard_normal(size=(len(observation_values), S.shape[1]))

smoother = ies.ES()
for parameter in update_step.parameters:
smoother.fit(
S,
observation_errors,
observation_values,
noise=noise,
truncation=module.get_truncation(),
inversion=ies.InversionType(module.inversion),
param_ensemble=param_ensemble,
)
if active_indices := parameter.index_list:
temp_storage[parameter.name][active_indices, :] = smoother.update(
temp_storage[parameter.name][active_indices, :]
truncation = module.get_truncation()
noise = rng.standard_normal(size=(num_obs, ensemble_size))

if module.localization():
Y_prime = S - S.mean(axis=1, keepdims=True)
C_YY = Y_prime @ Y_prime.T / (ensemble_size - 1)
Sigma_Y = np.diag(np.sqrt(np.diag(C_YY)))
batch_size: int = 1000
for parameter in update_step.parameters:
num_params = temp_storage[parameter.name].shape[0]
correlation_threshold = module.localization_correlation_threshold(
ensemble_size
)
else:
temp_storage[parameter.name] = smoother.update(
temp_storage[parameter.name]

print(
(
f"Running localization on {num_params} parameters,",
f"{num_obs} responses and {ensemble_size} realizations...",
)
)
batches = _split_by_batchsize(np.arange(0, num_params), batch_size)
for param_batch_idx in tqdm(batches):
X_local = temp_storage[parameter.name][param_batch_idx, :]
A = X_local - X_local.mean(axis=1, keepdims=True)
C_AA = A @ A.T / (ensemble_size - 1)

# State-measurement covariance matrix
C_AY = A @ Y_prime.T / (ensemble_size - 1)
Sigma_A = np.diag(np.sqrt(np.diag(C_AA)))

# State-measurement correlation matrix
c_AY = np.abs(
np.linalg.inv(Sigma_A) @ C_AY @ np.linalg.inv(Sigma_Y)
)
c_bool = c_AY > correlation_threshold
# Some parameters might be significantly correlated
# to the exact same responses,
# making up what we call a `parameter group``.
# We want to call the update only once per such parameter group
# to speed up computation.
param_groups = np.unique(c_bool, axis=0)

# Drop the parameter group that does not correlate to any responses.
row_with_all_false = np.all(~param_groups, axis=1)
param_groups = param_groups[~row_with_all_false]

for grp in param_groups:
# Find the rows matching the parameter group
matching_rows = np.all(c_bool == grp, axis=1)
# Get the indices of the matching rows
row_indices = np.where(matching_rows)[0]
X_chunk = temp_storage[parameter.name][param_batch_idx, :][
row_indices, :
]
S_chunk = S[grp, :]
observation_errors_loc = observation_errors[grp]
observation_values_loc = observation_values[grp]
smoother.fit(
S_chunk,
observation_errors_loc,
observation_values_loc,
noise=noise[grp],
truncation=truncation,
inversion=ies.InversionType(module.inversion),
param_ensemble=param_ensemble,
)
temp_storage[parameter.name][
param_batch_idx[row_indices], :
] = smoother.update(X_chunk)
else:
for parameter in update_step.parameters:
smoother.fit(
S,
observation_errors,
observation_values,
noise=noise,
truncation=truncation,
inversion=ies.InversionType(module.inversion),
param_ensemble=param_ensemble,
)
if active_indices := parameter.index_list:
temp_storage[parameter.name][active_indices, :] = smoother.update(
temp_storage[parameter.name][active_indices, :]
)
else:
temp_storage[parameter.name] = smoother.update(
temp_storage[parameter.name]
)

if params_with_row_scaling := _get_params_with_row_scaling(
temp_storage, update_step.row_scaling_parameters
Expand All @@ -477,6 +563,9 @@ def analysis_ES(
_save_temp_storage_to_disk(
target_fs, ensemble_config, temp_storage, iens_active_index
)
end = time.time()
elapsed = end - start
print(f"Time to run _save_temporary_storage_to_disk: {elapsed}")


def analysis_IES(
Expand Down Expand Up @@ -511,9 +600,10 @@ def analysis_IES(
# Looping over local analysis update_step
for update_step in updatestep:
try:
S, (
Y, (
observation_values,
observation_errors,
_,
update_snapshot,
) = _load_observations_and_responses(
source_fs,
Expand All @@ -525,18 +615,20 @@ def analysis_IES(
update_step.observation_config(),
)
except IndexError as e:
raise ErtAnalysisError(e)
raise ErtAnalysisError(e) from e
# pylint: disable=unsupported-assignment-operation
smoother_snapshot.update_step_snapshots[update_step.name] = update_snapshot
if len(observation_values) == 0:

num_obs = len(observation_values)
if num_obs == 0:
raise ErtAnalysisError(
f"No active observations for update step: {update_step.name}."
)

noise = rng.standard_normal(size=(len(observation_values), S.shape[1]))
noise = rng.standard_normal(size=(len(observation_values), Y.shape[1]))
for parameter in update_step.parameters:
iterative_ensemble_smoother.fit(
S,
Y,
observation_errors,
observation_values,
noise=noise,
Expand Down
122 changes: 99 additions & 23 deletions src/ert/config/analysis_module.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import sys
import math
from typing import TYPE_CHECKING, Dict, List, Type, TypedDict, Union

from .parsing import ConfigValidationError
Expand Down Expand Up @@ -33,6 +34,23 @@ class VariableInfo(TypedDict):
DEFAULT_IES_DEC_STEPLENGTH = 2.50
DEFAULT_ENKF_TRUNCATION = 0.98
DEFAULT_IES_INVERSION = 0
DEFAULT_LOCALIZATION = False
# Default threshold is a function of ensemble size which is not available here.
DEFAULT_LOCALIZATION_CORRELATION_THRESHOLD = -1


def correlation_threshold(ensemble_size: int, user_defined_threshold: float) -> float:
"""Decides whether or not to use user-defined or default threshold.
Default threshold taken from luo2022,
Continuous Hyper-parameter OPtimization (CHOP) in an ensemble Kalman filter
Section 2.3 - Localization in the CHOP problem
"""
default_threshold = 3 / math.sqrt(ensemble_size)
if user_defined_threshold == -1:
return default_threshold

return user_defined_threshold


class AnalysisMode(StrEnum):
Expand All @@ -58,6 +76,22 @@ def get_mode_variables(mode: AnalysisMode) -> Dict[str, "VariableInfo"]:
"step": 0.01,
"labelname": "Singular value truncation",
},
"LOCALIZATION": {
"type": bool,
"min": 0.0,
"value": DEFAULT_LOCALIZATION,
"max": 1.0,
"step": 1.0,
"labelname": "Adaptive localization",
},
"LOCALIZATION_CORRELATION_THRESHOLD": {
"type": float,
"min": 0.0,
"value": DEFAULT_LOCALIZATION_CORRELATION_THRESHOLD,
"max": 1.0,
"step": 0.1,
"labelname": "Adaptive localization correlation threshold",
},
}
ies_variables: Dict[str, "VariableInfo"] = {
"IES_MAX_STEPLENGTH": {
Expand Down Expand Up @@ -169,31 +203,47 @@ def set_var(self, var_name: str, value: Union[float, int, bool, str]) -> None:
self.handle_special_key_set(var_name, value)
elif var_name in self._variables:
var = self._variables[var_name]
try:
new_value = var["type"](value)
if new_value > var["max"]:
var["value"] = var["max"]
logger.warning(
f"New value {new_value} for key"
f" {var_name} is out of [{var['min']}, {var['max']}] "
f"using max value {var['max']}"

if var["type"] is not bool:
try:
new_value = var["type"](value)
if new_value > var["max"]:
var["value"] = var["max"]
logger.warning(
f"New value {new_value} for key"
f" {var_name} is out of [{var['min']}, {var['max']}] "
f"using max value {var['max']}"
)
elif new_value < var["min"]:
var["value"] = var["min"]
logger.warning(
f"New value {new_value} for key"
f" {var_name} is out of [{var['min']}, {var['max']}] "
f"using min value {var['min']}"
)
else:
var["value"] = new_value

except ValueError:
raise ConfigValidationError(
f"Variable {var_name!r} with value {value!r} has "
f"incorrect type."
f" Expected type {var['type'].__name__!r} but received"
f" value {value!r} of type {type(value).__name__!r}"
)
elif new_value < var["min"]:
var["value"] = var["min"]
logger.warning(
f"New value {new_value} for key"
f" {var_name} is out of [{var['min']}, {var['max']}] "
f"using min value {var['min']}"
else:
if not isinstance(var["value"], bool):
raise ValueError(
f"Variable {var_name} expected type {var['type']}"
f" received value `{value}` of type `{type(value)}`"
)
else:
var["value"] = new_value

except ValueError:
raise ConfigValidationError(
f"Variable {var_name!r} with value {value!r} has incorrect type."
f" Expected type {var['type'].__name__!r} but received"
f" value {value!r} of type {type(value).__name__!r}"
)
# When config is first read, `value` is a string
# that's either "False" or "True",
# but since bool("False") is True we need to convert it to bool.
if not isinstance(value, bool):
value = str(value).lower() != "false"

var["value"] = var["type"](value)
else:
raise ConfigValidationError(
f"Variable {var_name!r} not found in {self.name!r} analysis module"
Expand All @@ -210,6 +260,32 @@ def inversion(self, value: int) -> None:
def get_truncation(self) -> float:
return self.get_variable_value("ENKF_TRUNCATION")

def localization(self) -> bool:
return bool(self.get_variable_value("LOCALIZATION"))

def localization_correlation_threshold(self, ensemble_size: int) -> float:
return correlation_threshold(
ensemble_size, self.get_variable_value("LOCALIZATION_CORRELATION_THRESHOLD")
)

def get_steplength(self, iteration_nr: int) -> float:
"""
This is an implementation of Eq. (49), which calculates a suitable
step length for the update step, from the book:
Geir Evensen, Formulating the history matching problem with
consistent error statistics, Computational Geosciences (2021) 25:945 –970
Function not really used moved from C to keep the class interface consistent
should be investigated for possible removal.
"""
min_step_length = self.get_variable_value("IES_MIN_STEPLENGTH")
max_step_length = self.get_variable_value("IES_MAX_STEPLENGTH")
dec_step_length = self.get_variable_value("IES_DEC_STEPLENGTH")
step_length = min_step_length + (max_step_length - min_step_length) * pow(
2, -(iteration_nr - 1) / (dec_step_length - 1)
)
return step_length

def __repr__(self) -> str:
return f"AnalysisModule(name = {self.name})"

Expand Down
Loading

0 comments on commit b6f14fe

Please sign in to comment.