Group observations by response (WIP)
Yngve S. Kristiansen committed Mar 8, 2024
1 parent 79ab0a3 commit 7b5300d
Showing 7 changed files with 300 additions and 85 deletions.
31 changes: 17 additions & 14 deletions src/ert/analysis/_es_update.py
@@ -201,6 +201,7 @@ def _get_obs_and_measure_data(
observation_values = []
observation_errors = []
observations = ensemble.experiment.observations

for group in ensemble.experiment.response_info:
if group not in observations:
continue
@@ -224,24 +225,26 @@ def _get_obs_and_measure_data(
f"Observation: {observation} attached to response: {group}"
) from e

df = filtered_response.to_dataframe().reset_index()
grouped = filtered_response.groupby(
"obs_name", squeeze=True, restore_coord_dims=False
)

observation_keys.append(df["name"].to_list())
observation_values.append(df["observations"].to_list())
observation_errors.append(df["std"].to_list())
for obs_name, group_ds in grouped:
# df = observation.to_dataframe().reset_index()

observation_keys.append([obs_name] * group_ds["observations"].size)
observation_values.append(group_ds["observations"].data.ravel())
observation_errors.append(group_ds["std"].data.ravel())
measured_data.append(
group_ds["values"]
.transpose(..., "realization")
.values.reshape((-1, len(group_ds.realization)))
)

measured_data.append(
filtered_response["values"]
.transpose(..., "realization")
.values.reshape((-1, len(filtered_response.realization)))
)
ensemble.load_responses.cache_clear()
source_fs.load_responses.cache_clear()

# Measured_data, an array of 3 dimensions
# Outer dimension: One array per observation
# Mid dimension is ??? Sometimes length 1, sometimes nreals?
# Inner dimension: value is "values", index is realization
# measured_data: a list of ndarrays, each with shape (1, nreals)
# Each inner array holds one obs key along its first dimension and nreals "values"
return (
np.concatenate(measured_data, axis=0),
np.concatenate(observation_values),
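
For reference, a minimal runnable sketch of the grouping pattern above, using a hypothetical toy dataset in place of ert's real response data: groupby("obs_name") yields one block per observation, each reshaped to (n_points, nreals) and finally stacked with np.concatenate.

    import numpy as np
    import xarray as xr

    # Toy stand-in for a filtered response dataset (hypothetical values).
    nreals = 3
    ds = xr.Dataset(
        data_vars={
            "observations": (("obs_name", "index"), [[1.0, 2.0], [3.0, 4.0]]),
            "std": (("obs_name", "index"), [[0.1, 0.2], [0.3, 0.4]]),
            "values": (
                ("obs_name", "index", "realization"),
                np.arange(2 * 2 * nreals, dtype=float).reshape(2, 2, nreals),
            ),
        },
        coords={"obs_name": ["OBS_A", "OBS_B"], "realization": list(range(nreals))},
    )

    measured_data = []
    for obs_name, group_ds in ds.groupby("obs_name"):
        # One 2d block per observation, realizations along the last axis.
        measured_data.append(
            group_ds["values"]
            .transpose(..., "realization")
            .values.reshape((-1, len(group_ds.realization)))
        )

    print(np.concatenate(measured_data, axis=0).shape)  # (4, 3)
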
3 changes: 3 additions & 0 deletions src/ert/config/observation_vector.py
@@ -27,6 +27,9 @@ def __iter__(self) -> Iterable[Union[SummaryObservation, GenObservation]]:
def __len__(self) -> int:
return len(self.observations)

def to_dataset_info(self, active_list: List[int]) -> List[any]:

Check failure on line 30 in src/ert/config/observation_vector.py: GitHub Actions / type-checking (3.12): Missing return statement
Check failure on line 30 in src/ert/config/observation_vector.py: GitHub Actions / type-checking (3.12): Function "builtins.any" is not valid as a type
pass

def to_dataset(self, active_list: List[int]) -> xr.Dataset:
if self.observation_type == EnkfObservationImplementationType.GEN_OBS:
datasets = []
226 changes: 208 additions & 18 deletions src/ert/config/observations.py
@@ -1,9 +1,19 @@
import os
from datetime import datetime, timedelta
from typing import TYPE_CHECKING, Dict, Iterator, List, Optional, Tuple, Union
from typing import (
TYPE_CHECKING,
Dict,
Iterator,
List,
Optional,
Tuple,
Union,
)

import numpy as np
import pandas as pd
import xarray as xr
from pydantic import BaseModel, Field

from ert.validation import rangestring_to_list

@@ -35,34 +45,214 @@ def history_key(key: str) -> str:
return ":".join([keyword + "H"] + rest)


class _SummaryObsDataset(BaseModel):
observations: List[float] = Field(default_factory=lambda: [])
stds: List[float] = Field(default_factory=lambda: [])
times: List[int] = Field(default_factory=lambda: [])
summary_keywords: List[str] = Field(default_factory=lambda: [])
obs_names: List[str] = Field(default_factory=lambda: [])

def to_xarray(self) -> xr.Dataset:
return (
pd.DataFrame(
data={
"obs_name": self.obs_names,
"name": self.summary_keywords,
"time": self.times,
"observations": self.observations,
"std": self.stds,
},
)
.set_index(["obs_name", "name", "time"])
.to_xarray()
)
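
A minimal usage sketch of _SummaryObsDataset with hypothetical toy values: the parallel 1d lists become a pandas MultiIndex over (obs_name, name, time), which to_xarray() turns into dataset coordinates.

    summary_ds = _SummaryObsDataset(
        observations=[1.0, 2.0],
        stds=[0.1, 0.2],
        times=[0, 1],
        summary_keywords=["FOPR", "FOPR"],
        obs_names=["FOPR_OBS_1", "FOPR_OBS_2"],
    ).to_xarray()
    # summary_ds has data_vars "observations" and "std" on coords
    # (obs_name, name, time); index combinations not present become NaN.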


class _GenObsDataset(BaseModel):
observations: List[float] = Field(default_factory=lambda: [])
stds: List[float] = Field(default_factory=lambda: [])
indexes: List[int] = Field(default_factory=lambda: [])
report_steps: List[int] = Field(default_factory=lambda: [])
obs_names: List[str] = Field(default_factory=lambda: [])

def to_xarray(self) -> xr.Dataset:
return (
pd.DataFrame(
data={
"obs_name": self.obs_names,
"report_step": self.report_steps,
"index": self.indexes,
"observations": self.observations,
"std": self.stds,
}
)
.set_index(["obs_name", "report_step", "index"])
.to_xarray()
)

# return xr.Dataset(
# data_vars={
# "observations": (
# ["obs_name", "report_step", "index"],
# [[self.observations]],
# ),
# "std": (["obs_name", "report_step", "index"], [[self.stds]]),
# },
# coords={
# "index": self.indexes,
# "report_step": self.report_steps,
# "obs_name": self.obs_names,
# },
# )


class _GenObsAccumulator:
def __init__(self):

Check failure on line 110 in src/ert/config/observations.py: GitHub Actions / type-checking (3.12): Function is missing a return type annotation
self.obs: Dict[str, _GenObsDataset] = {}

def write(

Check failure on line 113 in src/ert/config/observations.py: GitHub Actions / type-checking (3.12): Function is missing a return type annotation
Check failure on line 113 in src/ert/config/observations.py: GitHub Actions / type-checking (3.12): Function is missing a type annotation for one or more arguments
self,
response_key: str,
obs_name: str,
report_step: int,
observations,
stds,
indexes: List[int],
):
# We assume the input lists all have the same length
if response_key not in self.obs:
self.obs[response_key] = _GenObsDataset()

vecs = self.obs[response_key]

vecs.observations.extend(observations)
vecs.stds.extend(stds)
vecs.indexes.extend(indexes)

for _ in observations:
vecs.obs_names.append(obs_name)
vecs.report_steps.append(report_step)

def to_xarrays_grouped_by_response(self) -> Dict[str, xr.Dataset]:
return {response_key: ds.to_xarray() for response_key, ds in self.obs.items()}
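
A minimal sketch of the accumulator in use, with hypothetical names and values: write() only extends flat per-response lists, and the xarray datasets are built once at the end, which is the point of avoiding per-observation xr.Dataset construction.

    acc = _GenObsAccumulator()
    acc.write(
        response_key="GEN_DATA_1",
        obs_name="OBS_1",
        report_step=0,
        observations=[1.0, 2.0],
        stds=[0.1, 0.2],
        indexes=[0, 1],
    )
    acc.write(
        response_key="GEN_DATA_1",
        obs_name="OBS_2",
        report_step=1,
        observations=[3.0],
        stds=[0.3],
        indexes=[0],
    )
    datasets = acc.to_xarrays_grouped_by_response()
    # One dataset per response key, indexed by (obs_name, report_step, index).
    assert list(datasets) == ["GEN_DATA_1"]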


class _SummaryObsAccumulator:
def __init__(self):

Check failure on line 141 in src/ert/config/observations.py: GitHub Actions / type-checking (3.12): Function is missing a return type annotation
self.obs: Dict[str, _SummaryObsDataset] = {}

def write(

Check failure on line 144 in src/ert/config/observations.py: GitHub Actions / type-checking (3.12): Function is missing a return type annotation
self,
response_key: str,
obs_names: List[str],
summary_keyword: str,
observations: List[float],
stds: List[float],
times: List[int],
):
# We assume the input lists all have the same length
if response_key not in self.obs:
self.obs[response_key] = _SummaryObsDataset()

vecs = self.obs[response_key]

vecs.obs_names.extend(obs_names)
vecs.observations.extend(observations)
vecs.stds.extend(stds)
vecs.times.extend(times)

for _ in observations:
vecs.summary_keywords.append(summary_keyword)

def to_xarrays_grouped_by_response(self) -> Dict[str, xr.Dataset]:
return {response_key: ds.to_xarray() for response_key, ds in self.obs.items()}


class EnkfObs:
def __init__(self, obs_vectors: Dict[str, ObsVector], obs_time: List[datetime]):
self.obs_vectors = obs_vectors
self.obs_time = obs_time

vecs: List[ObsVector] = [*self.obs_vectors.values()]
response_keys = set([x.data_key for x in vecs])
observations_by_response: Dict[str, List[xr.Dataset]] = {
k: [] for k in response_keys
}

for vec in vecs:
k = vec.data_key
ds = vec.to_dataset([])
assert k in observations_by_response

if "name" not in ds.dims:
ds = ds.expand_dims(name=[vec.observation_key])
gen_obs = _GenObsAccumulator()

Check failure on line 178 in src/ert/config/observations.py: GitHub Actions / type-checking (3.12): Call to untyped function "_GenObsAccumulator" in typed context
sum_obs = _SummaryObsAccumulator()

Check failure on line 179 in src/ert/config/observations.py: GitHub Actions / type-checking (3.12): Call to untyped function "_SummaryObsAccumulator" in typed context

observations_by_response[k].append(ds)
# Faster not to create a single xr.Dataset per observation and then
# merge/concat; this just accumulates 1d vecs before building a dataset
for vec in vecs:
if vec.observation_type == EnkfObservationImplementationType.GEN_OBS:
for report_step, node in vec.observations.items():
gen_obs.write(
response_key=vec.data_key,
obs_name=vec.observation_key,
report_step=report_step,

Check failure on line 190 in src/ert/config/observations.py: GitHub Actions / type-checking (3.12): Argument "report_step" to "write" of "_GenObsAccumulator" has incompatible type "int | datetime"; expected "int"
observations=node.values,
stds=node.stds,
indexes=node.indices,
)

merged_by_response: Dict[str, xr.Dataset] = {}
elif vec.observation_type == EnkfObservationImplementationType.SUMMARY_OBS:
observations = []
stds = []
dates = []
obs_keys = []

for the_date, obs in vec.observations.items():
assert isinstance(obs, SummaryObservation)
observations.append(obs.value)
stds.append(obs.std)
dates.append(the_date)
obs_keys.append(obs.observation_key)

sum_obs.write(
response_key=vec.data_key,
obs_names=obs_keys,
summary_keyword=vec.observation_key,
observations=observations,
stds=stds,
times=dates,
)
else:
raise ValueError("Unknown observation type")

for k in observations_by_response:
datasets = observations_by_response[k]
merged_by_response[k] = xr.concat(datasets, dim="name")
self.datasets: Dict[str, xr.Dataset] = {
**gen_obs.to_xarrays_grouped_by_response(),
**sum_obs.to_xarrays_grouped_by_response(),
}

self.datasets: Dict[str, xr.Dataset] = merged_by_response
for response_key, ds in self.datasets.items():
ds.attrs["response"] = response_key

# Alternate approach: merge xarray datasets. This seems to be a lot
# slower, as merging probably does checks to see whether the merge is
# OK; it is faster to accumulate some 1d vecs and then build the
# larger datasets
# vecs: List[ObsVector] = [*self.obs_vectors.values()]
# response_keys = set([x.data_key for x in vecs])
# observations_by_response: Dict[str, List[xr.Dataset]] = {
# k: [] for k in response_keys
# }

# for vec in vecs:
# k = vec.data_key
# ds = vec.to_dataset([])
# assert k in observations_by_response

#
# if "obs_name" not in ds.dims:
# ds = ds.expand_dims(obs_name=[vec.observation_key])
#
# observations_by_response[k].append(ds)
#
# merged_by_response: Dict[str, xr.Dataset] = {}
#
# for k in observations_by_response:
# datasets = observations_by_response[k]
# merged_by_response[k] = xr.combine_by_coords(datasets, join="inner")
#
# self.datasets: Dict[str, xr.Dataset] = merged_by_response

def __len__(self) -> int:
return len(self.obs_vectors)
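
A minimal sketch of consuming the result, assuming an EnkfObs instance (here called enkf_obs) already built from real obs_vectors and obs_time: each per-response dataset carries its response key as an attribute.

    # enkf_obs = EnkfObs(obs_vectors, obs_time)  # assumed constructed elsewhere
    for response_key, obs_ds in enkf_obs.datasets.items():
        # All observations attached to one response live in one dataset.
        assert obs_ds.attrs["response"] == response_key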