Group observations by response (WIP)
Yngve S. Kristiansen committed Mar 8, 2024
1 parent 79ab0a3 commit d2eedf6
Showing 7 changed files with 258 additions and 54 deletions.
31 changes: 17 additions & 14 deletions src/ert/analysis/_es_update.py
@@ -201,6 +201,7 @@ def _get_obs_and_measure_data(
observation_values = []
observation_errors = []
observations = ensemble.experiment.observations

for group in ensemble.experiment.response_info:
if group not in observations:
continue
@@ -224,24 +225,26 @@ def _get_obs_and_measure_data(
f"Observation: {observation} attached to response: {group}"
) from e

df = filtered_response.to_dataframe().reset_index()
grouped = filtered_response.groupby(
"obs_name", squeeze=True, restore_coord_dims=False
)

observation_keys.append(df["name"].to_list())
observation_values.append(df["observations"].to_list())
observation_errors.append(df["std"].to_list())
for obs_name, group_ds in grouped:
# df = observation.to_dataframe().reset_index()

observation_keys.append([obs_name] * group_ds["observations"].size)
observation_values.append(group_ds["observations"].data.ravel())
observation_errors.append(group_ds["std"].data.ravel())
measured_data.append(
group_ds["values"]
.transpose(..., "realization")
.values.reshape((-1, len(group_ds.realization)))
)

measured_data.append(
filtered_response["values"]
.transpose(..., "realization")
.values.reshape((-1, len(filtered_response.realization)))
)
ensemble.load_responses.cache_clear()
source_fs.load_responses.cache_clear()

# Measured_data, an array of 3 dimensions
# Outer dimension: One array per observation
# Mid dimension is ??? Sometimes length 1, sometimes nreals?
# Inner dimension: value is "values", index is realization
# Measured_data: an array of ndarrays, each with shape (1, nreals)
# Each inner array corresponds to one obs key and holds nreals "values"
return (
np.concatenate(measured_data, axis=0),
np.concatenate(observation_values),
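For context, a minimal standalone sketch of the groupby pattern the new loop relies on, with an invented dataset (coordinate and variable names mirror the diff; the data is made up):

import numpy as np
import xarray as xr

# Invented observation dataset: two obs names, two realizations.
ds = xr.Dataset(
    data_vars={"values": (("obs_name", "realization"), np.arange(4.0).reshape(2, 2))},
    coords={"obs_name": ["OBS1", "OBS2"], "realization": [0, 1]},
)

measured_data = []
for obs_name, group_ds in ds.groupby("obs_name"):
    # Each group flattens to shape (n_values_in_group, nreals)
    measured_data.append(
        group_ds["values"]
        .transpose(..., "realization")
        .values.reshape((-1, len(ds.realization)))
    )

stacked = np.concatenate(measured_data, axis=0)  # shape (2, 2): one row per obs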
3 changes: 3 additions & 0 deletions src/ert/config/observation_vector.py
@@ -27,6 +27,9 @@ def __iter__(self) -> Iterable[Union[SummaryObservation, GenObservation]]:
def __len__(self) -> int:
return len(self.observations)

def to_dataset_info(self, active_list: List[int]) -> List[any]:

Check failure on line 30, type-checking (3.12): Missing return statement
Check failure on line 30, type-checking (3.12): Function "builtins.any" is not valid as a type
pass

def to_dataset(self, active_list: List[int]) -> xr.Dataset:
if self.observation_type == EnkfObservationImplementationType.GEN_OBS:
datasets = []
226 changes: 208 additions & 18 deletions src/ert/config/observations.py
@@ -1,9 +1,19 @@
import os
from datetime import datetime, timedelta
from typing import TYPE_CHECKING, Dict, Iterator, List, Optional, Tuple, Union
from typing import (
TYPE_CHECKING,
Dict,
Iterator,
List,
Optional,
Tuple,
Union,
)

import numpy as np
import pandas as pd
import xarray as xr
from pydantic import BaseModel, Field

from ert.validation import rangestring_to_list

@@ -35,34 +45,214 @@ def history_key(key: str) -> str:
return ":".join([keyword + "H"] + rest)


class _SummaryObsDataset(BaseModel):
observations: List[float] = Field(default_factory=lambda: [])
stds: List[float] = Field(default_factory=lambda: [])
times: List[int] = Field(default_factory=lambda: [])
summary_keywords: List[str] = Field(default_factory=lambda: [])
obs_names: List[str] = Field(default_factory=lambda: [])

def to_xarray(self) -> xr.Dataset:
return (
pd.DataFrame(
data={
"obs_name": self.obs_names,
"name": self.summary_keywords,
"time": self.times,
"observations": self.observations,
"std": self.stds,
},
)
.set_index(["obs_name", "name", "time"])
.to_xarray()
)


class _GenObsDataset(BaseModel):
observations: List[float] = Field(default_factory=lambda: [])
stds: List[float] = Field(default_factory=lambda: [])
indexes: List[int] = Field(default_factory=lambda: [])
report_steps: List[int] = Field(default_factory=lambda: [])
obs_names: List[str] = Field(default_factory=lambda: [])

def to_xarray(self) -> xr.Dataset:
return (
pd.DataFrame(
data={
"obs_name": self.obs_names,
"report_step": self.report_steps,
"index": self.indexes,
"observations": self.observations,
"std": self.stds,
}
)
.set_index(["obs_name", "report_step", "index"])
.to_xarray()
)

# return xr.Dataset(
# data_vars={
# "observations": (
# ["obs_name", "report_step", "index"],
# [[self.observations]],
# ),
# "std": (["obs_name", "report_step", "index"], [[self.stds]]),
# },
# coords={
# "index": self.indexes,
# "report_step": self.report_steps,
# "obs_name": self.obs_names,
# },
# )
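
For reference, a toy sketch (data invented) of how the pd.DataFrame(...).set_index([...]).to_xarray() route used by both to_xarray methods behaves: each index level becomes a dataset dimension, each remaining column a data variable, and absent level combinations come back as NaN since the result is the dense product of the levels.

import pandas as pd

df = pd.DataFrame(
    {
        "obs_name": ["OBS1", "OBS1"],
        "report_step": [0, 0],
        "index": [10, 20],
        "observations": [1.1, 2.2],
        "std": [0.1, 0.2],
    }
)
ds = df.set_index(["obs_name", "report_step", "index"]).to_xarray()
# ds.sizes -> {'obs_name': 1, 'report_step': 1, 'index': 2}
# ds["observations"].sel(obs_name="OBS1", report_step=0, index=20) -> array(2.2)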


class _GenObsAccumulator:
def __init__(self):

Check failure on line 110, type-checking (3.12): Function is missing a return type annotation
self.obs: Dict[str, _GenObsDataset] = {}

def write(

Check failure on line 113, type-checking (3.12): Function is missing a return type annotation
Check failure on line 113, type-checking (3.12): Function is missing a type annotation for one or more arguments
self,
response_key: str,
obs_name: str,
report_step: int,
observations,
stds,
indexes: List[int],
):
# We assume the input lists all have the same length
if response_key not in self.obs:
self.obs[response_key] = _GenObsDataset()

vecs = self.obs[response_key]

vecs.observations.extend(observations)
vecs.stds.extend(stds)
vecs.indexes.extend(indexes)

for _ in observations:
vecs.obs_names.append(obs_name)
vecs.report_steps.append(report_step)

def to_xarrays_grouped_by_response(self) -> Dict[str, xr.Dataset]:
return {response_key: ds.to_xarray() for response_key, ds in self.obs.items()}


class _SummaryObsAccumulator:
def __init__(self):

Check failure on line 141, type-checking (3.12): Function is missing a return type annotation
self.obs: Dict[str, _SummaryObsDataset] = {}

def write(

Check failure on line 144, type-checking (3.12): Function is missing a return type annotation
self,
response_key: str,
obs_names: List[str],
summary_keyword: str,
observations: List[float],
stds: List[float],
times: List[int],
):
# We assume the input lists all have the same length
if response_key not in self.obs:
self.obs[response_key] = _SummaryObsDataset()

vecs = self.obs[response_key]

vecs.obs_names.extend(obs_names)
vecs.observations.extend(observations)
vecs.stds.extend(stds)
vecs.times.extend(times)

for _ in observations:
vecs.summary_keywords.append(summary_keyword)

def to_xarrays_grouped_by_response(self) -> Dict[str, xr.Dataset]:
return {response_key: ds.to_xarray() for response_key, ds in self.obs.items()}
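
A hypothetical usage sketch of the accumulators above (keys and values invented): write plain lists per response key, then materialize one dataset per response at the end.

acc = _GenObsAccumulator()
acc.write(
    response_key="GEN_DATA_1",  # invented response key
    obs_name="OBS1",
    report_step=0,
    observations=[1.0, 2.0],
    stds=[0.1, 0.2],
    indexes=[0, 1],
)
acc.write(
    response_key="GEN_DATA_1",
    obs_name="OBS2",
    report_step=0,
    observations=[3.0],
    stds=[0.3],
    indexes=[0],
)
datasets = acc.to_xarrays_grouped_by_response()
# {"GEN_DATA_1": <xr.Dataset with dims obs_name=2, report_step=1, index=2>}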


class EnkfObs:
def __init__(self, obs_vectors: Dict[str, ObsVector], obs_time: List[datetime]):
self.obs_vectors = obs_vectors
self.obs_time = obs_time

vecs: List[ObsVector] = [*self.obs_vectors.values()]
response_keys = set([x.data_key for x in vecs])
observations_by_response: Dict[str, List[xr.Dataset]] = {
k: [] for k in response_keys
}

for vec in vecs:
k = vec.data_key
ds = vec.to_dataset([])
assert k in observations_by_response

if "name" not in ds.dims:
ds = ds.expand_dims(name=[vec.observation_key])
gen_obs = _GenObsAccumulator()

Check failure on line 178, type-checking (3.12): Call to untyped function "_GenObsAccumulator" in typed context
sum_obs = _SummaryObsAccumulator()

Check failure on line 179, type-checking (3.12): Call to untyped function "_SummaryObsAccumulator" in typed context

observations_by_response[k].append(ds)
# Faster than creating one xr.Dataset per observation
# and then merging/concatenating:
# this just accumulates 1-D vectors before building the dataset
for vec in vecs:
if vec.observation_type == EnkfObservationImplementationType.GEN_OBS:
for report_step, node in vec.observations.items():
gen_obs.write(
response_key=vec.data_key,
obs_name=vec.observation_key,
report_step=report_step,

Check failure on line 190, type-checking (3.12): Argument "report_step" to "write" of "_GenObsAccumulator" has incompatible type "int | datetime"; expected "int"
observations=node.values,
stds=node.stds,
indexes=node.indices,
)

merged_by_response: Dict[str, xr.Dataset] = {}
elif vec.observation_type == EnkfObservationImplementationType.SUMMARY_OBS:
observations = []
stds = []
dates = []
obs_keys = []

for the_date, obs in vec.observations.items():
assert isinstance(obs, SummaryObservation)
observations.append(obs.value)
stds.append(obs.std)
dates.append(the_date)
obs_keys.append(obs.observation_key)

sum_obs.write(
response_key=vec.data_key,
obs_names=obs_keys,
summary_keyword=vec.observation_key,
observations=observations,
stds=stds,
times=dates,
)
else:
raise ValueError("Unknown observation type")

for k in observations_by_response:
datasets = observations_by_response[k]
merged_by_response[k] = xr.concat(datasets, dim="name")
self.datasets: Dict[str, xr.Dataset] = {
**gen_obs.to_xarrays_grouped_by_response(),
**sum_obs.to_xarrays_grouped_by_response(),
}

self.datasets: Dict[str, xr.Dataset] = merged_by_response
for response_key, ds in self.datasets.items():
ds.attrs["response"] = response_key

# Alternate approach: merge xarray datasets.
# Seems to be a lot slower, as merging probably runs checks to verify
# the merge is OK; it is faster to accumulate some 1-D vectors and then
# build the larger datasets
# vecs: List[ObsVector] = [*self.obs_vectors.values()]
# response_keys = set([x.data_key for x in vecs])
# observations_by_response: Dict[str, List[xr.Dataset]] = {
# k: [] for k in response_keys
# }

# for vec in vecs:
# k = vec.data_key
# ds = vec.to_dataset([])
# assert k in observations_by_response

#
# if "obs_name" not in ds.dims:
# ds = ds.expand_dims(obs_name=[vec.observation_key])
#
# observations_by_response[k].append(ds)
#
# merged_by_response: Dict[str, xr.Dataset] = {}
#
# for k in observations_by_response:
# datasets = observations_by_response[k]
# merged_by_response[k] = xr.combine_by_coords(datasets, join="inner")
#
# self.datasets: Dict[str, xr.Dataset] = merged_by_response

def __len__(self) -> int:
return len(self.obs_vectors)
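For comparison, a sketch of the rejected merge-based path kept in the comments above, on invented data (xr.combine_by_coords is the call the comment names):

import xarray as xr

ds1 = xr.Dataset(
    {"observations": (("obs_name", "index"), [[1.0, 2.0]])},
    coords={"obs_name": ["OBS1"], "index": [0, 1]},
)
ds2 = xr.Dataset(
    {"observations": (("obs_name", "index"), [[3.0, 4.0]])},
    coords={"obs_name": ["OBS2"], "index": [0, 1]},
)
merged = xr.combine_by_coords([ds1, ds2], join="inner")
# Same end result as the accumulator path; the commit notes this is
# noticeably slower, presumably due to per-merge validation.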
34 changes: 19 additions & 15 deletions src/ert/dark_storage/common.py
@@ -148,17 +148,20 @@ def data_for_key(

def get_all_observations(experiment: Experiment) -> List[Dict[str, Any]]:
observations = []
for key, dataset in experiment.observations.items():
observation = {
"name": key,
"values": list(dataset["observations"].values.flatten()),
"errors": list(dataset["std"].values.flatten()),
}
if "time" in dataset.coords:
observation["x_axis"] = _prepare_x_axis(dataset["time"].values.flatten()) # type: ignore
else:
observation["x_axis"] = _prepare_x_axis(dataset["index"].values.flatten()) # type: ignore
observations.append(observation)
for response_key, dataset in experiment.observations.items():

Check warning on line 151, check-style (3.12): Loop control variable `response_key` not used within loop body
x_coord_key = "time" if "time" in dataset.coords else "index"

for obs_name in dataset["name"].values.flatten():
ds_for_name = dataset.sel(name=obs_name)
df_for_name = ds_for_name.to_dataframe().reset_index()
observations.append({
"name": obs_name,
"values": df_for_name["values"].to_list(),
"errors": df_for_name["std"].to_list(),
"x_axis": _prepare_x_axis(df_for_name[x_coord_key].to_list())
})


observations.sort(key=lambda x: x["x_axis"]) # type: ignore
return observations
@@ -171,15 +174,16 @@ def get_observations_for_obs_keys(
experiment_observations = ensemble.experiment.observations
for key in observation_keys:
dataset = experiment_observations[key]
df = dataset.to_dataframe().reset_index()
observation = {
"name": key,
"values": list(dataset["observations"].values.flatten()),
"errors": list(dataset["std"].values.flatten()),
"values": list(df["observations"].to_list()),
"errors": list(df["std"].to_list()),
}
if "time" in dataset.coords:
observation["x_axis"] = _prepare_x_axis(dataset["time"].values.flatten()) # type: ignore
observation["x_axis"] = _prepare_x_axis(df["time"].to_list()) # type: ignore
else:
observation["x_axis"] = _prepare_x_axis(dataset["index"].values.flatten()) # type: ignore
observation["x_axis"] = _prepare_x_axis(df["index"].to_list()) # type: ignore
observations.append(observation)

observations.sort(key=lambda x: x["x_axis"]) # type: ignore
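A small sketch of the per-observation extraction pattern used in get_all_observations above, run on an invented dataset:

import xarray as xr

dataset = xr.Dataset(
    {
        "observations": (("name", "index"), [[1.0, 2.0], [3.0, 4.0]]),
        "std": (("name", "index"), [[0.1, 0.1], [0.2, 0.2]]),
    },
    coords={"name": ["OBS1", "OBS2"], "index": [0, 1]},
)

for obs_name in dataset["name"].values.flatten():
    df = dataset.sel(name=obs_name).to_dataframe().reset_index()
    # One entry per observation name, as in the endpoint code above
    print(obs_name, df["observations"].to_list(), df["std"].to_list())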
6 changes: 3 additions & 3 deletions src/ert/dark_storage/endpoints/records.py
@@ -113,12 +113,12 @@ def get_ensemble_responses(
response_names_with_observations = set()
for dataset in ensemble.experiment.observations.values():
if dataset.attrs["response"] == "summary" and "name" in dataset.coords:
response_name = dataset.name.values.flatten()[0]
response_names_with_observations.add(response_name)
summary_kw_names = dataset.name.values.flatten()
response_names_with_observations = response_names_with_observations.union(set(summary_kw_names))
else:
response_name = dataset.attrs["response"]
if "report_step" in dataset.coords:
report_step = dataset.report_step.values.flatten()[0]
report_step = dataset.report_step.values.flatten()
response_names_with_observations.add(response_name + "@" + str(report_step))

for name in ensemble.get_summary_keyset():
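The change above widens the bookkeeping from the first summary keyword to all of them; a trivial sketch with invented keywords:

# Every summary keyword named by an observation dataset now counts as
# observed, not just the first one. All names here are invented.
summary_kw_names = ["FOPR", "WOPR:OP1"]
response_names_with_observations = {"FGPR"}
response_names_with_observations |= set(summary_kw_names)
# {'FGPR', 'FOPR', 'WOPR:OP1'}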
11 changes: 7 additions & 4 deletions src/ert/storage/local_experiment.py
@@ -208,7 +208,10 @@ def update_parameters(self) -> List[str]:
@cached_property
def observations(self) -> Dict[str, xr.Dataset]:
observations = list(self.mount_point.glob("observations/*"))
return {
observation.name: xr.open_dataset(observation, engine="scipy")
for observation in observations
}
obs_by_response_name = {}

for obs_file in observations:
ds = xr.open_dataset(obs_file, engine="scipy")
obs_by_response_name[ds.attrs["response"]] = ds

return obs_by_response_name
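
A round-trip sketch of the attrs-based keying introduced above (file name and response key invented). It assumes one observation file per response, which is exactly what grouping observations by response provides:

import xarray as xr

ds = xr.Dataset({"observations": ("index", [1.0, 2.0])}, coords={"index": [0, 1]})
ds.attrs["response"] = "summary"
ds.to_netcdf("obs-summary.nc", engine="scipy")

loaded = xr.open_dataset("obs-summary.nc", engine="scipy")
assert loaded.attrs["response"] == "summary"  # the file name no longer matters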