Skip to content

Commit

Permalink
Rewrite gui tests to have a clean app/storage for each test
Browse files Browse the repository at this point in the history
  • Loading branch information
JHolba committed Mar 15, 2024
1 parent a195daa commit 77e3139
Show file tree
Hide file tree
Showing 3 changed files with 202 additions and 169 deletions.
290 changes: 172 additions & 118 deletions tests/unit_tests/gui/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
import shutil
import stat
import time
from contextlib import contextmanager
from datetime import datetime as dt
from textwrap import dedent
from typing import List, Type, TypeVar
from typing import Generator, List, Tuple, Type, TypeVar
from unittest.mock import MagicMock, Mock

import pytest
Expand Down Expand Up @@ -36,7 +37,6 @@
from ert.gui.ertwidgets.ensembleselector import EnsembleSelector
from ert.gui.ertwidgets.storage_widget import StorageWidget
from ert.gui.main import ErtMainWindow, GUILogHandler, _setup_main_window
from ert.gui.simulation.ensemble_experiment_panel import EnsembleExperimentPanel
from ert.gui.simulation.run_dialog import RunDialog
from ert.gui.simulation.simulation_panel import SimulationPanel
from ert.gui.simulation.view import RealizationWidget
Expand All @@ -46,7 +46,7 @@
)
from ert.run_models import EnsembleExperiment, MultipleDataAssimilation
from ert.services import StorageService
from ert.storage import open_storage
from ert.storage import Storage, open_storage
from tests.unit_tests.gui.simulation.test_run_path_dialog import handle_run_path_dialog


Expand All @@ -61,85 +61,174 @@ def handle_manage_dialog():
manage_tool.trigger()


@pytest.fixture(name="opened_main_window", scope="module")
def opened_main_window_fixture(source_root, tmpdir_factory) -> ErtMainWindow:
with pytest.MonkeyPatch.context() as mp:
tmp_path = tmpdir_factory.mktemp("test-data")
shutil.copytree(
os.path.join(source_root, "test-data", "poly_example"),
tmp_path / "test_data",
@pytest.fixture
def opened_main_window(
source_root, tmp_path, monkeypatch
) -> Generator[ErtMainWindow, None, None]:
monkeypatch.chdir(tmp_path)
_new_poly_example(source_root, tmp_path)
with _open_main_window(tmp_path) as (
gui,
storage,
config,
), StorageService.init_service(
project=os.path.abspath(config.ens_path),
):
_add_default_ensemble(storage, gui, config)
yield gui


def _new_poly_example(source_root, destination):
shutil.copytree(
os.path.join(source_root, "test-data", "poly_example"),
destination,
dirs_exist_ok=True,
)

with fileinput.input(destination / "poly.ert", inplace=True) as fin:
for line in fin:
if "NUM_REALIZATIONS" in line:
# Decrease the number of realizations to speed up the test,
# if there is flakyness, this can be increased.
print("NUM_REALIZATIONS 20", end="\n")
else:
print(line, end="")


def _add_default_ensemble(storage: Storage, gui: ErtMainWindow, config: ErtConfig):
gui.notifier.set_current_ensemble(
storage.create_experiment(
parameters=config.ensemble_config.parameter_configuration,
observations=config.observations,
).create_ensemble(
name="default",
ensemble_size=config.model_config.num_realizations,
)
mp.chdir(tmp_path / "test_data")
with fileinput.input("poly.ert", inplace=True) as fin:
for line in fin:
if "NUM_REALIZATIONS" in line:
# Decrease the number of realizations to speed up the test,
# if there is flakyness, this can be increased.
print("NUM_REALIZATIONS 20", end="\n")
else:
print(line, end="")
config = ErtConfig.from_file("poly.ert")
poly_case = EnKFMain(config)
args_mock = Mock()
args_mock.config = "poly.ert"

with StorageService.init_service(
project=os.path.abspath(config.ens_path),
), open_storage(config.ens_path, mode="w") as storage:
gui = _setup_main_window(poly_case, args_mock, GUILogHandler(), storage)
gui.notifier.set_current_ensemble(
storage.create_experiment(
parameters=config.ensemble_config.parameter_configuration,
observations=config.observations,
).create_ensemble(
name="default",
ensemble_size=config.model_config.num_realizations,
)
)
yield gui
gui.close()
)


@contextmanager
def _open_main_window(
path,
) -> Generator[Tuple[ErtMainWindow, Storage, ErtConfig], None, None]:
config = ErtConfig.from_file(path / "poly.ert")
poly_case = EnKFMain(config)

args_mock = Mock()
args_mock.config = "poly.ert"
with open_storage(config.ens_path, mode="w") as storage:
gui = _setup_main_window(poly_case, args_mock, GUILogHandler(), storage)
yield gui, storage, config
gui.close()


@pytest.fixture
def opened_main_window_clean(source_root, tmpdir):
with pytest.MonkeyPatch.context() as mp:
shutil.copytree(
os.path.join(source_root, "test-data", "poly_example"),
tmpdir / "test_data",
)
mp.chdir(tmpdir / "test_data")
def opened_main_window_clean(source_root, tmp_path, monkeypatch):
monkeypatch.chdir(tmp_path)
_new_poly_example(source_root, tmp_path)
with _open_main_window(tmp_path) as (gui, _, config), StorageService.init_service(
project=os.path.abspath(config.ens_path),
):
yield gui

with fileinput.input("poly.ert", inplace=True) as fin:
for line in fin:
if "NUM_REALIZATIONS" in line:
# Decrease the number of realizations to speed up the test,
# if there is flakyness, this can be increased.
print("NUM_REALIZATIONS 20", end="\n")
else:
print(line, end="")

poly_case = EnKFMain(ErtConfig.from_file("poly.ert"))
args_mock = Mock()
args_mock.config = "poly.ert"
@pytest.fixture(scope="module")
def _esmda_run(run_experiment, source_root, tmp_path_factory):
path = tmp_path_factory.mktemp("test-data")
_new_poly_example(source_root, path)
with pytest.MonkeyPatch.context() as mp, _open_main_window(path) as (
gui,
storage,
config,
):
mp.chdir(path)
_add_default_ensemble(storage, gui, config)
run_experiment(MultipleDataAssimilation, gui)

with StorageService.init_service(
project=os.path.abspath(poly_case.ert_config.ens_path),
), open_storage(poly_case.ert_config.ens_path, mode="w") as storage:
gui = _setup_main_window(poly_case, args_mock, GUILogHandler(), storage)
yield gui
return path


@pytest.fixture(scope="module")
def esmda_has_run(run_experiment):
# Runs a default ES-MDA run
run_experiment(MultipleDataAssimilation)
def _ensemble_experiment_run(run_experiment, source_root, tmp_path_factory):
path = tmp_path_factory.mktemp("test-data")
_new_poly_example(source_root, path)
with pytest.MonkeyPatch.context() as mp, _open_main_window(path) as (
gui,
storage,
config,
):
mp.chdir(path)
with open("poly_eval.py", "w", encoding="utf-8") as f:
f.write(
dedent(
"""\
#!/usr/bin/env python3
import numpy as np
import sys
import json
def _load_coeffs(filename):
with open(filename, encoding="utf-8") as f:
return json.load(f)["COEFFS"]
def _evaluate(coeffs, x):
return coeffs["a"] * x**2 + coeffs["b"] * x + coeffs["c"]
if __name__ == "__main__":
if np.random.random(1) > 0.5:
sys.exit(1)
coeffs = _load_coeffs("parameters.json")
output = [_evaluate(coeffs, x) for x in range(10)]
with open("poly.out", "w", encoding="utf-8") as f:
f.write("\\n".join(map(str, output)))
"""
)
)
os.chmod(
"poly_eval.py",
os.stat("poly_eval.py").st_mode
| stat.S_IXUSR
| stat.S_IXGRP
| stat.S_IXOTH,
)
_add_default_ensemble(storage, gui, config)
run_experiment(EnsembleExperiment, gui)

return path


@pytest.fixture
def esmda_has_run(_esmda_run, tmp_path, monkeypatch):
monkeypatch.chdir(tmp_path)
shutil.copytree(_esmda_run, tmp_path, dirs_exist_ok=True)
with _open_main_window(tmp_path) as (
gui,
_,
config,
), StorageService.init_service(
project=os.path.abspath(config.ens_path),
):
yield gui


@pytest.fixture
def ensemble_experiment_has_run(_ensemble_experiment_run, tmp_path, monkeypatch):
monkeypatch.chdir(tmp_path)
shutil.copytree(_ensemble_experiment_run, tmp_path, dirs_exist_ok=True)
with _open_main_window(tmp_path) as (
gui,
_,
config,
), StorageService.init_service(
project=os.path.abspath(config.ens_path),
):
yield gui


@pytest.fixture(name="run_experiment", scope="module")
def run_experiment_fixture(request, opened_main_window):
def func(experiment_mode):
def run_experiment_fixture(request):
def func(experiment_mode, gui):
qtbot = QtBot(request)
gui = opened_main_window
with contextlib.suppress(FileNotFoundError):
shutil.rmtree("poly_out")
# Select correct experiment in the simulation panel
Expand All @@ -148,6 +237,11 @@ def func(experiment_mode):
simulation_mode_combo = simulation_panel.findChild(QComboBox)
assert isinstance(simulation_mode_combo, QComboBox)
simulation_mode_combo.setCurrentText(experiment_mode.name())
simulation_settings = simulation_panel._simulation_widgets[
simulation_panel.getCurrentSimulationModel()
]
if hasattr(simulation_settings, "_ensemble_name_field"):
simulation_settings._ensemble_name_field.setText("iter-0")

# Click start simulation and agree to the message
start_simulation = simulation_panel.findChild(QWidget, name="start_simulation")
Expand All @@ -158,10 +252,14 @@ def handle_dialog():
)

QTimer.singleShot(
500, lambda: handle_run_path_dialog(gui, qtbot, delete_run_path=False)
500,
lambda: handle_run_path_dialog(gui, qtbot, delete_run_path=False),
)

if not experiment_mode.name() in ("Ensemble experiment", "Evaluate ensemble"):
if not experiment_mode.name() in (
"Ensemble experiment",
"Evaluate ensemble",
):
QTimer.singleShot(500, handle_dialog)
qtbot.mouseClick(start_simulation, Qt.LeftButton)

Expand Down Expand Up @@ -190,50 +288,6 @@ def handle_dialog():
return func


@pytest.fixture(scope="module")
def ensemble_experiment_has_run(opened_main_window, run_experiment, request):
gui = opened_main_window

simulation_panel = get_child(gui, SimulationPanel)
simulation_mode_combo = get_child(simulation_panel, QComboBox)
simulation_settings = get_child(simulation_panel, EnsembleExperimentPanel)
simulation_mode_combo.setCurrentText(EnsembleExperiment.name())

simulation_settings._ensemble_name_field.setText("iter-0")

with open("poly_eval.py", "w", encoding="utf-8") as f:
f.write(
dedent(
"""\
#!/usr/bin/env python
import numpy as np
import sys
import json
def _load_coeffs(filename):
with open(filename, encoding="utf-8") as f:
return json.load(f)["COEFFS"]
def _evaluate(coeffs, x):
return coeffs["a"] * x**2 + coeffs["b"] * x + coeffs["c"]
if __name__ == "__main__":
if np.random.random(1) > 0.5:
sys.exit(1)
coeffs = _load_coeffs("parameters.json")
output = [_evaluate(coeffs, x) for x in range(10)]
with open("poly.out", "w", encoding="utf-8") as f:
f.write("\\n".join(map(str, output)))
"""
)
)
os.chmod(
"poly_eval.py",
os.stat("poly_eval.py").st_mode | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH,
)
run_experiment(EnsembleExperiment)


@pytest.fixture()
def full_snapshot() -> Snapshot:
real = RealizationSnapshot(
Expand Down Expand Up @@ -301,8 +355,8 @@ def large_snapshot() -> Snapshot:
status=FORWARD_MODEL_STATE_START,
stdout=f"job_{i}.stdout",
stderr=f"job_{i}.stderr",
start_time=dt(1999, 1, 1).isoformat(),
end_time=dt(2019, 1, 1).isoformat(),
start_time=dt(1999, 1, 1),
end_time=dt(2019, 1, 1),
)
real_ids = [str(i) for i in range(0, 150)]
return builder.build(real_ids, REALIZATION_STATE_UNKNOWN)
Expand All @@ -321,8 +375,8 @@ def small_snapshot() -> Snapshot:
status=FORWARD_MODEL_STATE_START,
stdout=f"job_{i}.stdout",
stderr=f"job_{i}.stderr",
start_time=dt(1999, 1, 1).isoformat(),
end_time=dt(2019, 1, 1).isoformat(),
start_time=dt(1999, 1, 1),
end_time=dt(2019, 1, 1),
)
real_ids = [str(i) for i in range(0, 5)]
return builder.build(real_ids, REALIZATION_STATE_UNKNOWN)
Expand Down
Loading

0 comments on commit 77e3139

Please sign in to comment.