Skip to content

Commit

Permalink
Escape slashes in plotapi
Browse files Browse the repository at this point in the history
  • Loading branch information
eivindjahren committed Oct 10, 2024
1 parent bc8b1fb commit 0b3f534
Show file tree
Hide file tree
Showing 7 changed files with 128 additions and 21 deletions.
2 changes: 1 addition & 1 deletion src/ert/dark_storage/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def data_for_key(
"summary", tuple(ensemble.get_realization_list_with_responses("summary"))
)
summary_keys = summary_data["response_key"].unique().to_list()
except (ValueError, KeyError):
except (ValueError, KeyError, polars.exceptions.ColumnNotFoundError):
summary_data = polars.DataFrame()
summary_keys = []

Expand Down
4 changes: 4 additions & 0 deletions src/ert/dark_storage/endpoints/records.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import io
from typing import Any, Dict, List, Mapping, Union
from urllib.parse import unquote
from uuid import UUID, uuid4

import numpy as np
Expand Down Expand Up @@ -34,6 +35,7 @@ async def get_record_observations(
ensemble_id: UUID,
response_name: str,
) -> List[js.ObservationOut]:
response_name = unquote(response_name)
ensemble = storage.get_ensemble(ensemble_id)
obs_keys = get_observation_keys_for_response(ensemble, response_name)
obss = get_observations_for_obs_keys(ensemble, obs_keys)
Expand Down Expand Up @@ -74,6 +76,7 @@ async def get_ensemble_record(
ensemble_id: UUID,
accept: Annotated[Union[str, None], Header()] = None,
) -> Any:
name = unquote(name)
dataframe = data_for_key(storage.get_ensemble(ensemble_id), name)
media_type = accept if accept is not None else "text/csv"
if media_type == "application/x-parquet":
Expand Down Expand Up @@ -153,6 +156,7 @@ def get_ensemble_responses(
def get_std_dev(
*, storage: Storage = DEFAULT_STORAGE, ensemble_id: UUID, key: str, z: int
) -> Response:
key = unquote(key)
ensemble = storage.get_ensemble(ensemble_id)
try:
da = ensemble.calculate_std_dev_for_parameter(key)["values"]
Expand Down
14 changes: 11 additions & 3 deletions src/ert/gui/tools/plot/plot_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from itertools import combinations as combi
from json.decoder import JSONDecodeError
from typing import Any, Dict, List, NamedTuple, Optional
from urllib.parse import quote

import httpx
import numpy as np
Expand Down Expand Up @@ -42,6 +43,10 @@ def __init__(self) -> None:
self._all_ensembles: Optional[List[EnsembleObject]] = None
self._timeout = 120

@staticmethod
def escape(s: str) -> str:
return quote(quote(s, safe=""))

def _get_ensemble_by_id(self, id: str) -> Optional[EnsembleObject]:
for ensemble in self.get_all_ensembles():
if ensemble.id == id:
Expand Down Expand Up @@ -162,8 +167,9 @@ def data_for_key(self, ensemble_id: str, key: str) -> pd.DataFrame:
return pd.DataFrame()

with StorageService.session() as client:
print(key)
response = client.get(
f"/ensembles/{ensemble.id}/records/{key}",
f"/ensembles/{ensemble.id}/records/{PlotApi.escape(key)}",
headers={"accept": "application/x-parquet"},
timeout=self._timeout,
)
Expand Down Expand Up @@ -195,8 +201,9 @@ def observations_for_key(self, ensemble_ids: List[str], key: str) -> pd.DataFram
continue

with StorageService.session() as client:
print(key)
response = client.get(
f"/ensembles/{ensemble.id}/records/{key}/observations",
f"/ensembles/{ensemble.id}/records/{PlotApi.escape(key)}/observations",
timeout=self._timeout,
)
self._check_response(response)
Expand Down Expand Up @@ -260,8 +267,9 @@ def std_dev_for_parameter(
return np.array([])

with StorageService.session() as client:
print(key)
response = client.get(
f"/ensembles/{ensemble.id}/records/{key}/std_dev",
f"/ensembles/{ensemble.id}/records/{PlotApi.escape(key)}/std_dev",
params={"z": z},
timeout=self._timeout,
)
Expand Down
30 changes: 19 additions & 11 deletions tests/ert/unit_tests/dark_storage/test_dark_storage_state.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import io
import os
from urllib.parse import quote
from uuid import UUID

import hypothesis.strategies as st
import pandas as pd
import pytest
from hypothesis import assume, settings
from hypothesis import assume
from hypothesis.stateful import rule
from starlette.testclient import TestClient

Expand All @@ -14,7 +15,10 @@
from tests.ert.unit_tests.storage.test_local_storage import StatefulStorageTest


@settings(max_examples=1000)
def escape(s):
return quote(quote(quote(s, safe="")))


class DarkStorageStateTest(StatefulStorageTest):
def __init__(self):
super().__init__()
Expand All @@ -40,9 +44,11 @@ def get_experiments_through_client(self):
@rule(model_experiment=StatefulStorageTest.experiments)
def get_observations_through_client(self, model_experiment):
response = self.client.get(f"/experiments/{model_experiment.uuid}/observations")
assert {r["name"] for r in response.json()} == set(
model_experiment.observations.keys()
)
assert {r["name"] for r in response.json()} == {
key
for _, ds in model_experiment.observations.items()
for key in ds["observation_key"]
}

@rule(model_experiment=StatefulStorageTest.experiments)
def get_ensembles_through_client(self, model_experiment):
Expand All @@ -55,14 +61,15 @@ def get_ensembles_through_client(self, model_experiment):
def get_responses_through_client(self, model_ensemble):
response = self.client.get(f"/ensembles/{model_ensemble.uuid}/responses")
response_names = {
k for r in model_ensemble.response_values.values() for k in r["name"].values
k
for r in model_ensemble.response_values.values()
for k in r["response_key"]
}
assert set(response.json().keys()) == response_names

@rule(model_ensemble=StatefulStorageTest.ensembles, data=st.data())
def get_response_csv_through_client(self, model_ensemble, data):
assume(model_ensemble.response_values)
print("Hit it!")
response_key, response_name = data.draw(
st.sampled_from(
[
Expand All @@ -75,16 +82,17 @@ def get_response_csv_through_client(self, model_ensemble, data):
df = pd.read_parquet(
io.BytesIO(
self.client.get(
f"/ensembles/{model_ensemble.uuid}/records/{response_name}",
f"/ensembles/{model_ensemble.uuid}/records/{escape(response_name)}",
headers={"accept": "application/x-parquet"},
).content
)
)
assert set(df.columns) == set(
model_ensemble.response_values[response_key]
assert {dt[:10] for dt in df.columns} == {
str(dt)[:10]
for dt in model_ensemble.response_values[response_key]
.sel(name=response_name)["time"]
.values
)
}

def teardown(self):
super().teardown()
Expand Down
6 changes: 3 additions & 3 deletions tests/ert/unit_tests/gui/tools/plot/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,9 +209,9 @@ def mocked_requests_get(*args, **kwargs):

records = {
"/ensembles/ens_id_3/records/FOPR": summary_parquet_data,
"/ensembles/ens_id_3/records/BPR:1,3,8": summary_parquet_data,
"/ensembles/ens_id_3/records/SNAKE_OIL_PARAM:BPR_138_PERSISTENCE": parameter_parquet_data,
"/ensembles/ens_id_3/records/SNAKE_OIL_PARAM:OP1_DIVERGENCE_SCALE": parameter_parquet_data,
"/ensembles/ens_id_3/records/BPR%25253A1%25252C3%25252C8": summary_parquet_data,
"/ensembles/ens_id_3/records/SNAKE_OIL_PARAM%25253ABPR_138_PERSISTENCE": parameter_parquet_data,
"/ensembles/ens_id_3/records/SNAKE_OIL_PARAM%25253AOP1_DIVERGENCE_SCALE": parameter_parquet_data,
"/ensembles/ens_id_3/records/SNAKE_OIL_WPR_DIFF@199": gen_parquet_data,
"/ensembles/ens_id_3/records/FOPRH": history_parquet_data,
}
Expand Down
91 changes: 89 additions & 2 deletions tests/ert/unit_tests/gui/tools/plot/test_plot_api.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,39 @@
from datetime import datetime
from textwrap import dedent
from urllib.parse import quote

import httpx
import pandas as pd
import polars
import pytest
from pandas.testing import assert_frame_equal

from ert.gui.tools.plot.plot_api import PlotApiKeyDefinition
from starlette.testclient import TestClient

from ert.config import SummaryConfig
from ert.dark_storage.app import app
from ert.dark_storage.enkf import update_storage
from ert.gui.tools.plot.plot_api import PlotApi, PlotApiKeyDefinition
from ert.services import StorageService
from ert.storage import open_storage
from tests.ert.unit_tests.gui.tools.plot.conftest import MockResponse


@pytest.fixture(autouse=True)
def use_testclient(monkeypatch):
client = TestClient(app)
monkeypatch.setattr(StorageService, "session", lambda: client)

def test_escape(s: str) -> str:
"""
Workaround for issue with TestClient:
https://github.com/encode/starlette/issues/1060
"""
print("TESTESCAPING")
return quote(quote(quote(s, safe="")))

PlotApi.escape = test_escape


def test_key_def_structure(api):
key_defs = api.all_data_type_keys()
fopr = next(x for x in key_defs if x.key == "FOPR")
Expand Down Expand Up @@ -146,3 +173,63 @@ def test_plot_api_request_errors(api):

with pytest.raises(httpx.RequestError):
api.data_for_key(ensemble.id, "should_not_be_there")


def test_plot_api_handles_urlescape(tmp_path, monkeypatch):
with open_storage(tmp_path / "storage", mode="w") as storage:
monkeypatch.setenv("ERT_STORAGE_NO_TOKEN", "yup")
monkeypatch.setenv("ERT_STORAGE_ENS_PATH", storage.path)
update_storage()
api = PlotApi()
key = "WBHP:46/3-7S"
date = datetime(year=2024, month=10, day=4)
experiment = storage.create_experiment(
parameters=[],
responses=[
SummaryConfig(
name="summary",
input_files=["CASE.UNSMRY", "CASE.SMSPEC"],
keys=[key],
)
],
observations={
"summary": polars.DataFrame(
{
"response_key": key,
"observation_key": "sumobs",
"time": polars.Series([date]).dt.cast_time_unit("ms"),
"observations": polars.Series([1.0], dtype=polars.Float32),
"std": polars.Series([1.0], dtype=polars.Float32),
}
)
},
)
ensemble = experiment.create_ensemble(ensemble_size=1, name="ensemble")
assert api.data_for_key(str(ensemble.id), key).empty
df = polars.DataFrame(
{
"response_key": [key],
"time": [polars.Series([date]).dt.cast_time_unit("ms")],
"values": [polars.Series([1.0], dtype=polars.Float32)],
}
)
df = df.explode("values", "time")
ensemble.save_response(
"summary",
df,
0,
)
assert api.data_for_key(str(ensemble.id), key).to_csv() == dedent(
"""\
Realization,2024-10-04
0,1.0
"""
)
assert api.observations_for_key([str(ensemble.id)], key).to_csv() == dedent(
"""\
,0
STD,1.0
OBS,1.0
key_index,2024-10-04 00:00:00
"""
)
2 changes: 1 addition & 1 deletion tests/ert/unit_tests/storage/test_local_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,7 +545,7 @@ class Experiment:
ensembles: Dict[UUID, Ensemble] = field(default_factory=dict)
parameters: List[ParameterConfig] = field(default_factory=list)
responses: List[ResponseConfig] = field(default_factory=list)
observations: Dict[str, xr.Dataset] = field(default_factory=dict)
observations: Dict[str, polars.DataFrame] = field(default_factory=dict)


class StatefulStorageTest(RuleBasedStateMachine):
Expand Down

0 comments on commit 0b3f534

Please sign in to comment.