Refactor load_parameters
frode-aarstad committed Mar 6, 2024
1 parent ae41eb2 commit dff37c3
Showing 6 changed files with 119 additions and 85 deletions.
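The change is mechanical: the source_fs/target_fs arguments become ensemble, source_ensemble, and target_ensemble, and the isinstance branching over GenKwConfig, SurfaceConfig, and FieldConfig inside _create_temporary_parameter_storage moves into a load_parameters method on each parameter config class. A minimal, self-contained sketch of the resulting dispatch pattern follows; every name below is an illustrative stand-in, not ert's real LocalEnsemble or ParameterConfig.

    import numpy as np


    class ParameterConfigLike:
        def load_parameters(self, ensemble, group, realizations):
            # Base behaviour, as added to ExtParamConfig in this commit.
            raise NotImplementedError()


    class ScalarConfigLike(ParameterConfigLike):
        def load_parameters(self, ensemble, group, realizations):
            # A concrete override; the real subclasses read from
            # LocalEnsemble storage instead of a plain dict.
            return ensemble[group][realizations].T


    def create_temporary_parameter_storage(ensemble, configs, group, realizations):
        # Mirrors the new _create_temporary_parameter_storage: one polymorphic
        # call replaces the isinstance checks on GenKwConfig/SurfaceConfig/FieldConfig.
        config_node = configs[group]
        return {group: config_node.load_parameters(ensemble, group, realizations)}


    data = {"COEFFS": np.arange(6.0).reshape(3, 2)}  # (n_realizations, n_parameters)
    configs = {"COEFFS": ScalarConfigLike()}
    storage = create_temporary_parameter_storage(data, configs, "COEFFS", np.array([0, 2]))
    print(storage["COEFFS"].shape)  # (2, 2): parameters by active realizations

The analysis code then needs only the single polymorphic call visible in the new _create_temporary_parameter_storage body below.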
115 changes: 41 additions & 74 deletions src/ert/analysis/_es_update.py
@@ -28,8 +28,6 @@
)
from typing_extensions import Self

from ert.config import FieldConfig, GenKwConfig, SurfaceConfig

from ..config.analysis_module import ESSettings, IESSettings
from . import misfit_preprocessor
from .event import AnalysisEvent, AnalysisStatusEvent, AnalysisTimeEvent
@@ -148,7 +146,7 @@ def __setitem__(


def _all_parameters(
source_fs: LocalEnsemble,
ensemble: LocalEnsemble,
iens_active_index: npt.NDArray[np.int_],
param_groups: List[str],
) -> Optional[npt.NDArray[np.double]]:
@@ -157,7 +155,7 @@ def _all_parameters(
temp_storage = TempStorage()
for param_group in param_groups:
_temp_storage = _create_temporary_parameter_storage(
source_fs, iens_active_index, param_group
ensemble, iens_active_index, param_group
)
temp_storage[param_group] = _temp_storage[param_group]
matrices = [temp_storage[p] for p in param_groups]
@@ -176,51 +174,20 @@ def _save_temp_storage_to_disk(


def _create_temporary_parameter_storage(
source_fs: LocalEnsemble,
ensemble: LocalEnsemble,
iens_active_index: npt.NDArray[np.int_],
param_group: str,
) -> TempStorage:
temp_storage = TempStorage()
t_genkw = 0.0
t_surface = 0.0
t_field = 0.0
_logger.debug("_create_temporary_parameter_storage() - start")
config_node = source_fs.experiment.parameter_configuration[param_group]
matrix: Union[npt.NDArray[np.double], xr.DataArray]
if isinstance(config_node, GenKwConfig):
t = time.perf_counter()
matrix = source_fs.load_parameters(param_group, iens_active_index)[
"values"
].values.T
t_genkw += time.perf_counter() - t
elif isinstance(config_node, SurfaceConfig):
t = time.perf_counter()
matrix = source_fs.load_parameters(param_group, iens_active_index)["values"]
t_surface += time.perf_counter() - t
elif isinstance(config_node, FieldConfig):
t = time.perf_counter()
ds = source_fs.load_parameters(param_group, iens_active_index)
ensemble_size = len(ds.realizations)
da = xr.DataArray(
[
np.ma.MaskedArray(data=d, mask=config_node.mask).compressed() # type: ignore
for d in ds["values"].values.reshape(ensemble_size, -1)
]
)
matrix = da.T.to_numpy()
t_field += time.perf_counter() - t
else:
raise NotImplementedError(f"{type(config_node)} is not supported")
temp_storage[param_group] = matrix
_logger.debug(
f"_create_temporary_parameter_storage() time_used gen_kw={t_genkw:.4f}s, \
surface={t_surface:.4f}s, field={t_field:.4f}s"
config_node = ensemble.experiment.parameter_configuration[param_group]
temp_storage[param_group] = config_node.load_parameters(
ensemble, param_group, iens_active_index
)
return temp_storage


def _get_obs_and_measure_data(
source_fs: LocalEnsemble,
ensemble: LocalEnsemble,
selected_observations: Iterable[str],
iens_active_index: npt.NDArray[np.int_],
) -> Tuple[
@@ -233,11 +200,11 @@ def _get_obs_and_measure_data(
observation_keys = []
observation_values = []
observation_errors = []
observations = source_fs.experiment.observations
observations = ensemble.experiment.observations
for obs in selected_observations:
observation = observations[obs]
group = observation.attrs["response"]
response = source_fs.load_responses(group, tuple(iens_active_index))
response = ensemble.load_responses(group, tuple(iens_active_index))
if "time" in observation.coords:
response = response.reindex(
time=observation.time, method="nearest", tolerance="1s" # type: ignore
@@ -258,7 +225,7 @@ def _get_obs_and_measure_data(
.transpose(..., "realization")
.values.reshape((-1, len(filtered_response.realization)))
)
source_fs.load_responses.cache_clear()
ensemble.load_responses.cache_clear()
return (
np.concatenate(measured_data, axis=0),
np.concatenate(observation_values),
@@ -268,7 +235,7 @@ def _get_obs_and_measure_data(


def _load_observations_and_responses(
source_fs: LocalEnsemble,
ensemble: LocalEnsemble,
alpha: float,
std_cutoff: float,
global_std_scaling: float,
@@ -284,7 +251,7 @@ def _load_observations_and_responses(
],
]:
S, observations, errors, obs_keys = _get_obs_and_measure_data(
source_fs,
ensemble,
selected_observations,
iens_ative_index,
)
@@ -418,8 +385,8 @@ def _copy_unupdated_parameters(
all_parameter_groups: Iterable[str],
updated_parameter_groups: Iterable[str],
iens_active_index: npt.NDArray[np.int_],
source_fs: LocalEnsemble,
target_fs: LocalEnsemble,
source_ensemble: LocalEnsemble,
target_ensemble: LocalEnsemble,
) -> None:
"""
Copies parameter groups that have not been updated from a source ensemble to a target ensemble.
@@ -432,8 +399,8 @@ def _copy_unupdated_parameters(
updated_parameter_groups (List[str]): A list of parameter groups that have already been updated.
iens_active_index (npt.NDArray[np.int_]): An array of indices for the active realizations in the
target ensemble.
source_fs (EnsembleReader): The file system of the source ensemble, from which parameters are copied.
target_fs (EnsembleAccessor): The file system of the target ensemble, to which parameters are saved.
source_ensemble (LocalEnsemble): The file system of the source ensemble, from which parameters are copied.
target_ensemble (LocalEnsemble): The file system of the target ensemble, to which parameters are saved.
Returns:
None: The function does not return any value but updates the target file system by copying over
@@ -447,8 +414,8 @@ def _copy_unupdated_parameters(
# Copy the non-updated parameter groups from source to target for each active realization
for parameter_group in not_updated_parameter_groups:
for realization in iens_active_index:
ds = source_fs.load_parameters(parameter_group, int(realization))
target_fs.save_parameters(parameter_group, realization, ds)
ds = source_ensemble.load_parameters(parameter_group, int(realization))
target_ensemble.save_parameters(parameter_group, realization, ds)


def analysis_ES(
@@ -461,8 +428,8 @@ def analysis_ES(
global_scaling: float,
smoother_snapshot: SmootherSnapshot,
ens_mask: npt.NDArray[np.bool_],
source_fs: LocalEnsemble,
target_fs: LocalEnsemble,
source_ensemble: LocalEnsemble,
target_ensemble: LocalEnsemble,
progress_callback: Callable[[AnalysisEvent], None],
misfit_process: bool,
) -> None:
@@ -484,7 +451,7 @@ def adaptive_localization_progress_callback(
update_snapshot,
),
) = _load_observations_and_responses(
source_fs,
source_ensemble,
alpha,
std_cutoff,
global_scaling,
@@ -529,8 +496,8 @@ def adaptive_localization_progress_callback(
np.fill_diagonal(T, T.diagonal() + 1)

for param_group in parameters:
source: Union[EnsembleReader, EnsembleAccessor]
source = source_fs
source: LocalEnsemble
source = source_ensemble
temp_storage = _create_temporary_parameter_storage(
source, iens_active_index, param_group
)
@@ -572,17 +539,17 @@ def adaptive_localization_progress_callback(
_logger.info(log_msg)
progress_callback(AnalysisStatusEvent(msg=log_msg))
start = time.time()
_save_temp_storage_to_disk(target_fs, temp_storage, iens_active_index)
_save_temp_storage_to_disk(target_ensemble, temp_storage, iens_active_index)
_logger.info(
f"Storing data for {param_group} completed in {(time.time() - start) / 60} minutes"
)

_copy_unupdated_parameters(
list(source_fs.experiment.parameter_configuration.keys()),
list(source_ensemble.experiment.parameter_configuration.keys()),
parameters,
iens_active_index,
source_fs,
target_fs,
source_ensemble,
target_ensemble,
)


@@ -595,8 +562,8 @@ def analysis_IES(
std_cutoff: float,
smoother_snapshot: SmootherSnapshot,
ens_mask: npt.NDArray[np.bool_],
source_fs: LocalEnsemble,
target_fs: LocalEnsemble,
source_ensemble: LocalEnsemble,
target_ensemble: LocalEnsemble,
sies_smoother: Optional[ies.SIES],
progress_callback: Callable[[AnalysisEvent], None],
misfit_preprocessor: bool,
@@ -620,7 +587,7 @@ def analysis_IES(
update_snapshot,
),
) = _load_observations_and_responses(
source_fs,
source_ensemble,
alpha,
std_cutoff,
1.0,
@@ -637,9 +604,9 @@ def analysis_IES(
if sies_smoother is None:
# The sies smoother must be initialized with the full parameter ensemble
# Get relevant active realizations
param_groups = list(source_fs.experiment.parameter_configuration.keys())
param_groups = list(source_ensemble.experiment.parameter_configuration.keys())
parameter_ensemble_active = _all_parameters(
source_fs, iens_active_index, param_groups
source_ensemble, iens_active_index, param_groups
)
sies_smoother = ies.SIES(
parameters=parameter_ensemble_active,
@@ -665,11 +632,11 @@ def analysis_IES(
sies_smoother.W[:, masking_of_initial_parameters] = proposed_W

for param_group in parameters:
source: Union[EnsembleReader, EnsembleAccessor] = target_fs
source: LocalEnsemble = target_ensemble
try:
target_fs.load_parameters(group=param_group, realizations=0)["values"]
target_ensemble.load_parameters(group=param_group, realizations=0)["values"]
except Exception:
source = source_fs
source = source_ensemble
temp_storage = _create_temporary_parameter_storage(
source, iens_active_index, param_group
)
@@ -679,14 +646,14 @@ def analysis_IES(
)

progress_callback(AnalysisStatusEvent(msg=f"Storing data for {param_group}.."))
_save_temp_storage_to_disk(target_fs, temp_storage, iens_active_index)
_save_temp_storage_to_disk(target_ensemble, temp_storage, iens_active_index)

_copy_unupdated_parameters(
list(source_fs.experiment.parameter_configuration.keys()),
list(source_ensemble.experiment.parameter_configuration.keys()),
parameters,
iens_active_index,
source_fs,
target_fs,
source_ensemble,
target_ensemble,
)

assert sies_smoother is not None, "sies_smoother should be initialized"
@@ -857,8 +824,8 @@ def iterative_smoother_update(
std_cutoff=update_settings.std_cutoff,
smoother_snapshot=smoother_snapshot,
ens_mask=ens_mask,
source_fs=prior_storage,
target_fs=posterior_storage,
source_ensemble=prior_storage,
target_ensemble=posterior_storage,
sies_smoother=sies_smoother,
progress_callback=progress_callback,
misfit_preprocessor=update_settings.misfit_preprocess,
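The _copy_unupdated_parameters docstring above describes the selection step: only groups outside the updated set are copied, realization by realization. A small runnable sketch of that logic, with hypothetical group names and active indices:

    # Sketch of the selection step in _copy_unupdated_parameters; group names
    # and indices are hypothetical, and the ert load/save calls are shown as comments.
    all_parameter_groups = ["PORO", "PERM_X", "COEFFS"]
    updated_parameter_groups = ["COEFFS"]
    iens_active_index = [0, 1, 3]

    not_updated_parameter_groups = set(all_parameter_groups) - set(
        updated_parameter_groups
    )
    for parameter_group in not_updated_parameter_groups:
        for realization in iens_active_index:
            # In ert:
            # ds = source_ensemble.load_parameters(parameter_group, int(realization))
            # target_ensemble.save_parameters(parameter_group, realization, ds)
            print(f"copy {parameter_group}, realization {realization}")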
13 changes: 12 additions & 1 deletion src/ert/config/ext_param_config.py
@@ -11,6 +11,8 @@
from .parameter_config import ParameterConfig

if TYPE_CHECKING:
import numpy.typing as npt

from ert.storage import LocalEnsemble

Number = Union[int, float]
@@ -91,10 +93,19 @@ def write_to_runpath(
json.dump(data, f)

def save_parameters(
self, ensemble: LocalEnsemble, group: str, realization: int, data: np.ndarray
self,
ensemble: LocalEnsemble,
group: str,
realization: int,
data: npt.NDArray[np.float_],
) -> None:
raise NotImplementedError()

def load_parameters(
self, ensemble: LocalEnsemble, group: str, realizations: npt.NDArray[np.int_]
) -> Union[npt.NDArray[np.float_], xr.DataArray]:
raise NotImplementedError()

@staticmethod
def to_dataset(data: DataType) -> xr.Dataset:
"""Flattens data to fit inside a dataset"""
19 changes: 18 additions & 1 deletion src/ert/config/field_config.py
@@ -180,7 +180,11 @@ def write_to_runpath(
_logger.debug(f"save() time_used {(time.perf_counter() - t):.4f}s")

def save_parameters(
self, ensemble: LocalEnsemble, group: str, realization: int, data: np.ndarray
self,
ensemble: LocalEnsemble,
group: str,
realization: int,
data: npt.NDArray[np.float_],
) -> None:
ma = np.ma.MaskedArray( # type: ignore
data=np.zeros(self.mask.size),
@@ -192,6 +196,19 @@ def save_parameters(
ds = xr.Dataset({"values": (["x", "y", "z"], ma.filled())}) # type: ignore
ensemble.save_parameters(group, realization, ds)

def load_parameters(
self, ensemble: LocalEnsemble, group: str, realizations: npt.NDArray[np.int_]
) -> Union[npt.NDArray[np.float_], xr.DataArray]:
ds = ensemble.load_parameters(group, realizations)
ensemble_size = len(ds.realizations)
da = xr.DataArray(
[
np.ma.MaskedArray(data=d, mask=self.mask).compressed() # type: ignore
for d in ds["values"].values.reshape(ensemble_size, -1)
]
)
return da.T.to_numpy()

def _fetch_from_ensemble(
self, real_nr: int, ensemble: LocalEnsemble
) -> xr.DataArray:
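The FieldConfig.load_parameters added above compresses each realization's 3D field down to its active cells before the update. A self-contained illustration of that masking step, using a hypothetical 2x2 grid with one inactive cell and the mask flattened to match the flattened values:

    import numpy as np

    # Hypothetical 2x2 grid; True marks an inactive cell, as in FieldConfig.mask.
    mask = np.array([[False, True], [False, False]])
    values = np.arange(8.0).reshape(2, 2, 2)  # (ensemble_size, nx, ny)

    ensemble_size = values.shape[0]
    rows = [
        np.ma.MaskedArray(data=d, mask=mask.ravel()).compressed()
        for d in values.reshape(ensemble_size, -1)
    ]
    matrix = np.array(rows).T  # (n_active_cells, ensemble_size) == (3, 2)
    print(matrix)

The final transpose puts realizations in columns, which is the orientation the smoother update works on.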
22 changes: 20 additions & 2 deletions src/ert/config/gen_kw_config.py
@@ -8,7 +8,16 @@
from dataclasses import dataclass
from hashlib import sha256
from pathlib import Path
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, TypedDict, overload
from typing import (
TYPE_CHECKING,
Callable,
Dict,
List,
Optional,
TypedDict,
Union,
overload,
)

import numpy as np
import pandas as pd
@@ -270,7 +279,11 @@ def write_to_runpath(
return {self.name: data}

def save_parameters(
self, ensemble: LocalEnsemble, group: str, realization: int, data: np.ndarray
self,
ensemble: LocalEnsemble,
group: str,
realization: int,
data: npt.NDArray[np.float_],
) -> None:
ds = xr.Dataset(
{
@@ -284,6 +297,11 @@ def save_parameters(
)
ensemble.save_parameters(group, realization, ds)

def load_parameters(
self, ensemble: LocalEnsemble, group: str, realizations: npt.NDArray[np.int_]
) -> Union[npt.NDArray[np.float_], xr.DataArray]:
return ensemble.load_parameters(group, realizations)["values"].values.T

def shouldUseLogScale(self, keyword: str) -> bool:
for tf in self.transfer_functions:
if tf.name == keyword:
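GenKwConfig.load_parameters returns the stored values transposed, so a round trip yields an (n_parameters, n_realizations) matrix. A short sketch under the assumption that the stored dataset has a "values" variable indexed by (realizations, names), which is what the loader above indexes into; the full on-disk schema written by save_parameters is partly elided in the hunk above.

    import numpy as np
    import xarray as xr

    # Assumed storage layout, with hypothetical parameter names.
    ds = xr.Dataset(
        {"values": (("realizations", "names"), np.array([[0.1, 2.0], [0.3, 2.5]]))},
        coords={"realizations": [0, 1], "names": ["MULTFLT", "PORO_SCALE"]},
    )
    matrix = ds["values"].values.T
    print(matrix.shape)  # (2, 2): one row per parameter, one column per realization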