Skip to content

Commit

Permalink
Add stateful dark storage test
Browse files Browse the repository at this point in the history
  • Loading branch information
eivindjahren committed Oct 10, 2024
1 parent 06c6dc6 commit bc8b1fb
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 8 deletions.
6 changes: 5 additions & 1 deletion src/ert/dark_storage/endpoints/updates.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
from fastapi import APIRouter
from fastapi import APIRouter, Depends

from ert.dark_storage.enkf import get_storage

DEFAULT_STORAGE = Depends(get_storage)

router = APIRouter(tags=["ensemble"])

Expand Down
15 changes: 10 additions & 5 deletions src/ert/dark_storage/enkf.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,16 @@


def get_storage() -> Storage:
global _storage
global _storage # noqa
if _storage is None:
try:
return (_storage := open_storage(os.environ["ERT_STORAGE_ENS_PATH"]))
except RuntimeError as err:
raise InternalServerError(f"{err!s}") from err
_storage = update_storage()
_storage.refresh()
return _storage


def update_storage() -> Storage:
global _storage
try:
return (_storage := open_storage(os.environ["ERT_STORAGE_ENS_PATH"]))
except RuntimeError as err:
raise InternalServerError(f"{err!s}") from err
101 changes: 101 additions & 0 deletions tests/ert/unit_tests/dark_storage/test_dark_storage_state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import io
import os
from uuid import UUID

import hypothesis.strategies as st
import pandas as pd
import pytest
from hypothesis import assume, settings
from hypothesis.stateful import rule
from starlette.testclient import TestClient

from ert.dark_storage.app import app
from ert.dark_storage.enkf import update_storage
from tests.ert.unit_tests.storage.test_local_storage import StatefulStorageTest


@settings(max_examples=1000)
class DarkStorageStateTest(StatefulStorageTest):
def __init__(self):
super().__init__()
self.prev_no_token = os.environ.get("ERT_STORAGE_NO_TOKEN")
self.prev_ens_path = os.environ.get("ERT_STORAGE_ENS_PATH")
os.environ["ERT_STORAGE_NO_TOKEN"] = "yup"
os.environ["ERT_STORAGE_ENS_PATH"] = str(self.storage.path)
update_storage()
self.client = TestClient(app)

@rule()
def get_experiments_through_client(self):
self.client.get("/updates/storage")
response = self.client.get("/experiments")
experiment_records = response.json()
assert len(experiment_records) == len(list(self.storage.experiments))
for record in experiment_records:
storage_experiment = self.storage.get_experiment(UUID(record["id"]))
assert {UUID(i) for i in record["ensemble_ids"]} == {
ens.id for ens in storage_experiment.ensembles
}

@rule(model_experiment=StatefulStorageTest.experiments)
def get_observations_through_client(self, model_experiment):
response = self.client.get(f"/experiments/{model_experiment.uuid}/observations")
assert {r["name"] for r in response.json()} == set(
model_experiment.observations.keys()
)

@rule(model_experiment=StatefulStorageTest.experiments)
def get_ensembles_through_client(self, model_experiment):
response = self.client.get(f"/experiments/{model_experiment.uuid}/ensembles")
assert {r["id"] for r in response.json()} == {
str(uuid) for uuid in model_experiment.ensembles
}

@rule(model_ensemble=StatefulStorageTest.ensembles)
def get_responses_through_client(self, model_ensemble):
response = self.client.get(f"/ensembles/{model_ensemble.uuid}/responses")
response_names = {
k for r in model_ensemble.response_values.values() for k in r["name"].values
}
assert set(response.json().keys()) == response_names

@rule(model_ensemble=StatefulStorageTest.ensembles, data=st.data())
def get_response_csv_through_client(self, model_ensemble, data):
assume(model_ensemble.response_values)
print("Hit it!")
response_key, response_name = data.draw(
st.sampled_from(
[
(key, n)
for key, r in model_ensemble.response_values.items()
for n in r["name"].values
]
)
)
df = pd.read_parquet(
io.BytesIO(
self.client.get(
f"/ensembles/{model_ensemble.uuid}/records/{response_name}",
headers={"accept": "application/x-parquet"},
).content
)
)
assert set(df.columns) == set(
model_ensemble.response_values[response_key]
.sel(name=response_name)["time"]
.values
)

def teardown(self):
super().teardown()
if self.prev_no_token is not None:
os.environ["ERT_STORAGE_NO_TOKEN"] = self.prev_no_token
else:
del os.environ["ERT_STORAGE_NO_TOKEN"]
if self.prev_ens_path is not None:
os.environ["ERT_STORAGE_ENS_PATH"] = self.prev_ens_path
else:
del os.environ["ERT_STORAGE_ENS_PATH"]


TestDarkStorage = pytest.mark.integration_test(DarkStorageStateTest.TestCase)
7 changes: 5 additions & 2 deletions tests/ert/unit_tests/storage/test_local_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import polars
import pytest
import xarray as xr
from hypothesis import assume, given
from hypothesis import assume, given, note
from hypothesis.extra.numpy import arrays
from hypothesis.stateful import Bundle, RuleBasedStateMachine, initialize, rule

Expand Down Expand Up @@ -579,6 +579,7 @@ def __init__(self):
super().__init__()
self.tmpdir = tempfile.mkdtemp(prefix="StatefulStorageTest")
self.storage = open_storage(self.tmpdir + "/storage/", "w")
note(f"storage path is: {self.storage.path}")
self.model: Dict[UUID, Experiment] = {}
assert list(self.storage.ensembles) == []

Expand Down Expand Up @@ -816,7 +817,9 @@ def create_ensemble_from_prior(self, prior: Ensemble):
if (
list(prior.response_values.keys())
== [r.name for r in model_experiment.responses]
and not iens in prior.failure_messages
and iens not in prior.failure_messages
and prior_ensemble.get_ensemble_state()[iens]
!= RealizationStorageState.PARENT_FAILURE
):
state[iens] = RealizationStorageState.UNDEFINED
assert ensemble.get_ensemble_state() == state
Expand Down

0 comments on commit bc8b1fb

Please sign in to comment.