diff --git a/src/ert/cli/workflow.py b/src/ert/cli/workflow.py index c41806dda56..c4dce0acdaf 100644 --- a/src/ert/cli/workflow.py +++ b/src/ert/cli/workflow.py @@ -1,8 +1,10 @@ from __future__ import annotations import logging +from pathlib import Path from typing import TYPE_CHECKING +from ert.runpaths import Runpaths from ert.workflow_runner import WorkflowRunner if TYPE_CHECKING: @@ -20,7 +22,26 @@ def execute_workflow( msg = "Workflow {} is not in the list of available workflows" logger.error(msg.format(workflow_name)) return - runner = WorkflowRunner(workflow=workflow, storage=storage, ert_config=ert_config) + + runner = WorkflowRunner( + workflow=workflow, + fixtures={ + "storage": storage, + "random_seed": ert_config.random_seed, + "reports_dir": str( + storage.path.parent / "reports" / Path(ert_config.user_config_file).stem + ), + "observation_settings": ert_config.analysis_config.observation_settings, + "es_settings": ert_config.analysis_config.es_module, + "run_paths": Runpaths( + jobname_format=ert_config.model_config.jobname_format_string, + runpath_format=ert_config.model_config.runpath_format_string, + filename=str(ert_config.runpath_file), + substitutions=ert_config.substitutions, + eclbase=ert_config.model_config.eclbase_format_string, + ), + }, + ) runner.run_blocking() if not all(v["completed"] for v in runner.workflowReport().values()): logger.error(f"Workflow {workflow_name} failed!") diff --git a/src/ert/config/ert_script.py b/src/ert/config/ert_script.py index 2248424a7d3..2ab8d019239 100644 --- a/src/ert/config/ert_script.py +++ b/src/ert/config/ert_script.py @@ -5,13 +5,12 @@ import logging import sys import traceback -import warnings from abc import abstractmethod from collections.abc import Callable -from types import MappingProxyType, ModuleType +from types import ModuleType from typing import TYPE_CHECKING, Any, TypeAlias -from typing_extensions import deprecated +from ert.config.workflow_fixtures import WorkflowFixtures if TYPE_CHECKING: from ert.config import ErtConfig @@ -39,11 +38,6 @@ def __init__( self._stdoutdata = "" self._stderrdata = "" - # Deprecated: - self._ert = None - self._ensemble = None - self._storage = None - @abstractmethod def run(self, *arg: Any, **kwarg: Any) -> Any: """ @@ -71,31 +65,6 @@ def stderrdata(self) -> str: self._stderrdata = self._stderrdata.decode() return self._stderrdata - @deprecated("Use fixtures to the run function instead") - def ert(self) -> ErtConfig | None: - logger.info(f"Accessing EnKFMain from workflow: {self.__class__.__name__}") - return self._ert - - @property - def ensemble(self) -> Ensemble | None: - warnings.warn( - "The ensemble property is deprecated, use the fixture to the run function instead", - DeprecationWarning, - stacklevel=1, - ) - logger.info(f"Accessing ensemble from workflow: {self.__class__.__name__}") - return self._ensemble - - @property - def storage(self) -> Storage | None: - warnings.warn( - "The storage property is deprecated, use the fixture to the run function instead", - DeprecationWarning, - stacklevel=1, - ) - logger.info(f"Accessing storage from workflow: {self.__class__.__name__}") - return self._storage - def isCancelled(self) -> bool: return self.__is_cancelled @@ -114,36 +83,70 @@ def initializeAndRun( self, argument_types: list[type[Any]], argument_values: list[str], - fixtures: dict[str, Any] | None = None, + fixtures: WorkflowFixtures | None = None, + **kwargs: dict[str, Any], ) -> Any: fixtures = {} if fixtures is None else fixtures - arguments = [] + workflow_args = [] for index, arg_value in enumerate(argument_values): arg_type = argument_types[index] if index < len(argument_types) else str if arg_value is not None: - arguments.append(arg_type(arg_value)) + workflow_args.append(arg_type(arg_value)) else: - arguments.append(None) - fixtures["workflow_args"] = arguments + workflow_args.append(None) + + fixtures["workflow_args"] = workflow_args + + fixture_args = [] + all_func_args = inspect.signature(self.run).parameters + is_using_wf_args_fixture = "workflow_args" in all_func_args + try: - func_args = inspect.signature(self.run).parameters + if not is_using_wf_args_fixture: + fixture_or_kw_arguments = list(all_func_args)[len(workflow_args) :] + else: + fixture_or_kw_arguments = list(all_func_args) + + func_args = {k: all_func_args[k] for k in fixture_or_kw_arguments} + + kwargs_defaults = { + k: v.default + for k, v in func_args.items() + if k not in fixtures + and v.kind != v.VAR_POSITIONAL + and not str(v).startswith("*") + and v.default != v.empty + } + use_kwargs = { + k: (kwargs or {}).get(k, default_value) + for k, default_value in ({**kwargs_defaults, **kwargs}).items() + } # If the user has specified *args, we skip injecting fixtures, and just # pass the user configured arguments if not any(p.kind == p.VAR_POSITIONAL for p in func_args.values()): try: - arguments = self.insert_fixtures(func_args, fixtures) + fixture_args = self.insert_fixtures(func_args, fixtures, use_kwargs) except ValueError as e: # This is here for backwards compatibility, the user does not have *argv # but positional arguments. Can not be mixed with using fixtures. logger.warning( f"Mixture of fixtures and positional arguments, err: {e}" ) - # Part of deprecation - self._ert = fixtures.get("ert_config") - self._ensemble = fixtures.get("ensemble") - self._storage = fixtures.get("storage") - return self.run(*arguments) + + positional_args = ( + fixture_args + if is_using_wf_args_fixture + else [*workflow_args, *fixture_args] + ) + if not positional_args and not use_kwargs: + return self.run() + elif positional_args and not use_kwargs: + return self.run(*positional_args) + elif not positional_args and use_kwargs: + return self.run(**use_kwargs) + else: + return self.run(*positional_args, **use_kwargs) except AttributeError as e: error_msg = str(e) if not hasattr(self, "run"): @@ -169,20 +172,22 @@ def initializeAndRun( def insert_fixtures( self, - func_args: MappingProxyType[str, inspect.Parameter], - fixtures: dict[str, Fixtures], + func_args: dict[str, inspect.Parameter], + fixtures: WorkflowFixtures, + kwargs: dict[str, Any], ) -> list[Any]: arguments = [] errors = [] for val in func_args: if val in fixtures: - arguments.append(fixtures[val]) - else: + arguments.append(fixtures.get(val)) + elif val not in kwargs: errors.append(val) if errors: + kwargs_str = ",".join(f"{k}='{v}'" for k, v in kwargs.items()) raise ValueError( f"Plugin: {self.__class__.__name__} misconfigured, arguments: {errors} " - f"not found in fixtures: {list(fixtures)}" + f"not found in fixtures: {list(fixtures)} or kwargs {kwargs_str}" ) return arguments diff --git a/src/ert/config/workflow_fixtures.py b/src/ert/config/workflow_fixtures.py new file mode 100644 index 00000000000..af708b5e251 --- /dev/null +++ b/src/ert/config/workflow_fixtures.py @@ -0,0 +1,23 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from PyQt6.QtWidgets import QWidget +from typing_extensions import TypedDict + +if TYPE_CHECKING: + from ert.config import ESSettings, UpdateSettings + from ert.runpaths import Runpaths + from ert.storage import Ensemble, Storage + + +class WorkflowFixtures(TypedDict, total=False): + ensemble: Ensemble + storage: Storage + random_seed: int | None + reports_dir: str + observation_settings: UpdateSettings + es_settings: ESSettings + run_paths: Runpaths + workflow_args: list[Any] + parent: QWidget diff --git a/src/ert/gui/tools/export/exporter.py b/src/ert/gui/tools/export/exporter.py index 1c405847bf7..95480601663 100644 --- a/src/ert/gui/tools/export/exporter.py +++ b/src/ert/gui/tools/export/exporter.py @@ -29,7 +29,7 @@ def run_export(self, parameters: list[Any]) -> None: export_job_runner = WorkflowJobRunner(self.export_job) user_warn = export_job_runner.run( - fixtures={"storage": self._notifier.storage, "ert_config": self.config}, + fixtures={"storage": self._notifier.storage}, arguments=parameters, ) if export_job_runner.hasFailed(): diff --git a/src/ert/gui/tools/plugins/plugin.py b/src/ert/gui/tools/plugins/plugin.py index 3795cb2b77b..cc8ec50cbfa 100644 --- a/src/ert/gui/tools/plugins/plugin.py +++ b/src/ert/gui/tools/plugins/plugin.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Any from ert import ErtScript +from ert.config.workflow_fixtures import WorkflowFixtures if TYPE_CHECKING: from PyQt6.QtWidgets import QWidget @@ -34,7 +35,7 @@ def getName(self) -> str: def getDescription(self) -> str: return self.__description - def getArguments(self, fixtures: dict[str, Any]) -> list[Any]: + def getArguments(self, fixtures: WorkflowFixtures) -> list[Any]: """ Returns a list of arguments. Either from GUI or from arbitrary code. If the user for example cancels in the GUI a CancelPluginException is raised. @@ -42,12 +43,8 @@ def getArguments(self, fixtures: dict[str, Any]) -> list[Any]: script = self.__loadPlugin() fixtures["parent"] = self.__parent_window func_args = inspect.signature(script.getArguments).parameters - arguments = script.insert_fixtures(func_args, fixtures) + arguments = script.insert_fixtures(dict(func_args), fixtures, {}) - # Part of deprecation - script._ert = fixtures.get("ert_config") - script._ensemble = fixtures.get("ensemble") - script._storage = fixtures.get("storage") return script.getArguments(*arguments) def setParentWindow(self, parent_window: QWidget | None) -> None: diff --git a/src/ert/gui/tools/plugins/plugin_runner.py b/src/ert/gui/tools/plugins/plugin_runner.py index 1fea02f6edc..c4415594830 100644 --- a/src/ert/gui/tools/plugins/plugin_runner.py +++ b/src/ert/gui/tools/plugins/plugin_runner.py @@ -2,16 +2,18 @@ import time from collections.abc import Callable +from pathlib import Path from typing import TYPE_CHECKING, Any from _ert.threading import ErtThread -from ert.config import CancelPluginException +from ert.config import CancelPluginException, ErtConfig +from ert.config.workflow_fixtures import WorkflowFixtures +from ert.runpaths import Runpaths from ert.workflow_runner import WorkflowJobRunner from .process_job_dialog import ProcessJobDialog if TYPE_CHECKING: - from ert.config import ErtConfig from ert.storage import LocalStorage from .plugin import Plugin @@ -22,10 +24,10 @@ def __init__( self, plugin: Plugin, ert_config: ErtConfig, storage: LocalStorage ) -> None: super().__init__() + self.ert_config = ert_config self.storage = storage self.__plugin = plugin - self.__plugin_finished_callback: Callable[[], None] = lambda: None self.__result = None @@ -33,11 +35,28 @@ def __init__( self.poll_thread: ErtThread | None = None def run(self) -> None: + ert_config = self.ert_config try: plugin = self.__plugin - arguments = plugin.getArguments( - fixtures={"storage": self.storage, "ert_config": self.ert_config} + fixtures={ + "storage": self.storage, + "random_seed": ert_config.random_seed, + "reports_dir": str( + self.storage.path.parent + / "reports" + / Path(ert_config.user_config_file).stem + ), + "observation_settings": ert_config.analysis_config.observation_settings, + "es_settings": ert_config.analysis_config.es_module, + "run_paths": Runpaths( + jobname_format=ert_config.model_config.jobname_format_string, + runpath_format=ert_config.model_config.runpath_format_string, + filename=str(ert_config.runpath_file), + substitutions=ert_config.substitutions, + eclbase=ert_config.model_config.eclbase_format_string, + ), + } ) dialog = ProcessJobDialog(plugin.getName(), plugin.getParentWindow()) dialog.setObjectName("process_job_dialog") @@ -45,7 +64,7 @@ def run(self) -> None: dialog.cancelConfirmed.connect(self.cancel) fixtures = { k: getattr(self, k) - for k in ["storage", "ert_config"] + for k in ["storage", "run_paths"] if getattr(self, k) } workflow_job_thread = ErtThread( @@ -71,7 +90,7 @@ def run(self) -> None: print("Plugin cancelled before execution!") def __runWorkflowJob( - self, arguments: list[Any] | None, fixtures: dict[str, Any] + self, arguments: list[Any] | None, fixtures: WorkflowFixtures ) -> None: self.__result = self._runner.run(arguments, fixtures=fixtures) diff --git a/src/ert/gui/tools/workflows/run_workflow_widget.py b/src/ert/gui/tools/workflows/run_workflow_widget.py index 2fe3157f2c7..795da59faa8 100644 --- a/src/ert/gui/tools/workflows/run_workflow_widget.py +++ b/src/ert/gui/tools/workflows/run_workflow_widget.py @@ -2,6 +2,7 @@ import time from collections.abc import Iterable +from pathlib import Path from typing import TYPE_CHECKING from PyQt6.QtCore import QSize, Qt @@ -20,6 +21,7 @@ from _ert.threading import ErtThread from ert.gui.ertwidgets import EnsembleSelector from ert.gui.tools.workflows.workflow_dialog import WorkflowDialog +from ert.runpaths import Runpaths from ert.workflow_runner import WorkflowRunner if TYPE_CHECKING: @@ -125,9 +127,26 @@ def startWorkflow(self) -> None: workflow = self.config.workflows[self.getCurrentWorkflowName()] self._workflow_runner = WorkflowRunner( - workflow, - storage=self.storage, - ensemble=self.source_ensemble_selector.currentData(), + workflow=workflow, + fixtures={ + "ensemble": self.source_ensemble_selector.currentData(), + "storage": self.storage, + "random_seed": self.config.random_seed, + "reports_dir": str( + self.storage.path.parent + / "reports" + / Path(self.config.user_config_file).stem + ), + "observation_settings": self.config.analysis_config.observation_settings, + "es_settings": self.config.analysis_config.es_module, + "run_paths": Runpaths( + jobname_format=self.config.model_config.jobname_format_string, + runpath_format=self.config.model_config.runpath_format_string, + filename=str(self.config.runpath_file), + substitutions=self.config.substitutions, + eclbase=self.config.model_config.eclbase_format_string, + ), + }, ) self._workflow_runner.run() diff --git a/src/ert/libres_facade.py b/src/ert/libres_facade.py index 37b59792ee1..712f0f3c38b 100644 --- a/src/ert/libres_facade.py +++ b/src/ert/libres_facade.py @@ -98,9 +98,6 @@ def get_field_parameters(self) -> list[str]: if isinstance(val, Field) ] - def get_gen_kw(self) -> list[str]: - return self.config.ensemble_config.get_keylist_gen_kw() - def get_ensemble_size(self) -> int: return self.config.model_config.num_realizations @@ -210,6 +207,7 @@ def run_ertscript( # type: ignore storage: Storage, ensemble: Ensemble, *args: Any, + **kwargs: dict[str, Any], ) -> Any: warnings.warn( "run_ertscript is deprecated, use the workflow runner", @@ -220,10 +218,19 @@ def run_ertscript( # type: ignore [], argument_values=args, fixtures={ - "ert_config": self.config, - "ensemble": ensemble, "storage": storage, + "ensemble": ensemble, + "reports_dir": ( + storage.path.parent + / "reports" + / Path(str(self.user_config_file)).stem + / ensemble.name + ), + "observation_settings": self.config.analysis_config.observation_settings, + "es_settings": self.config.analysis_config.es_module, + "random_seed": self.config.random_seed, }, + **kwargs, ) @classmethod diff --git a/src/ert/plugins/hook_implementations/workflows/csv_export.py b/src/ert/plugins/hook_implementations/workflows/csv_export.py index e271c27af89..eb8401ae4a5 100644 --- a/src/ert/plugins/hook_implementations/workflows/csv_export.py +++ b/src/ert/plugins/hook_implementations/workflows/csv_export.py @@ -1,13 +1,17 @@ import json import os from collections.abc import Sequence +from typing import TYPE_CHECKING import pandas as pd +import polars as pl from ert import ErtScript, LibresFacade -from ert.config import ErtConfig from ert.storage import Storage +if TYPE_CHECKING: + from ert.storage import Ensemble + def loadDesignMatrix(filename: str) -> pd.DataFrame: dm = pd.read_csv(filename, delim_whitespace=True) @@ -52,26 +56,23 @@ def getDescription() -> str: def run( self, - ert_config: ErtConfig, storage: Storage, workflow_args: Sequence[str], ) -> str: output_file = workflow_args[0] ensemble_data_as_json = None if len(workflow_args) < 2 else workflow_args[1] design_matrix_path = None if len(workflow_args) < 3 else workflow_args[2] - _ = True if len(workflow_args) < 4 else workflow_args[3] drop_const_cols = False if len(workflow_args) < 5 else workflow_args[4] - facade = LibresFacade(ert_config) ensemble_data_as_dict = ( json.loads(ensemble_data_as_json) if ensemble_data_as_json else {} ) # Use the keys (UUIDs as strings) to get ensembles - ensembles = [] + ensembles: list[Ensemble] = [] for ensemble_id in ensemble_data_as_dict: - assert self.storage is not None - ensemble = self.storage.get_ensemble(ensemble_id) + assert storage is not None + ensemble = storage.get_ensemble(ensemble_id) ensembles.append(ensemble) if design_matrix_path is not None: @@ -96,13 +97,20 @@ def run( if not design_matrix_data.empty: ensemble_data = ensemble_data.join(design_matrix_data, how="outer") - misfit_data = facade.load_all_misfit_data(ensemble) + misfit_data = LibresFacade.load_all_misfit_data(ensemble) if not misfit_data.empty: ensemble_data = ensemble_data.join(misfit_data, how="outer") + realizations = ensemble.get_realization_list_with_responses() + + try: + summary_data = ensemble.load_responses("summary", tuple(realizations)) + except (KeyError, ValueError): + summary_data = pl.DataFrame({}) - summary_data = ensemble.load_all_summary_data() - if not summary_data.empty: - ensemble_data = ensemble_data.join(summary_data, how="outer") + if not summary_data.is_empty(): + ensemble_data = ensemble_data.join( + summary_data.to_pandas(), how="outer" + ) else: ensemble_data["Date"] = None ensemble_data.set_index(["Date"], append=True, inplace=True) @@ -114,8 +122,8 @@ def run( ) data = pd.concat([data, ensemble_data]) - - data = data.reorder_levels(["Realization", "Iteration", "Date", "Ensemble"]) + if not data.empty: + data = data.reorder_levels(["Realization", "Iteration", "Date", "Ensemble"]) if drop_const_cols: data = data.loc[:, (data != data.iloc[0]).any()] diff --git a/src/ert/plugins/hook_implementations/workflows/export_misfit_data.py b/src/ert/plugins/hook_implementations/workflows/export_misfit_data.py index f7e44887a0c..77fc662b6d0 100644 --- a/src/ert/plugins/hook_implementations/workflows/export_misfit_data.py +++ b/src/ert/plugins/hook_implementations/workflows/export_misfit_data.py @@ -8,7 +8,6 @@ from ert.exceptions import StorageError if TYPE_CHECKING: - from ert.config import ErtConfig from ert.storage import Ensemble @@ -23,17 +22,14 @@ class ExportMisfitDataJob(ErtScript): ((response_value - observation_data) / observation_std)**2 """ - def run( - self, ert_config: ErtConfig, ensemble: Ensemble, workflow_args: list[Any] - ) -> None: + def run(self, ensemble: Ensemble, workflow_args: list[Any]) -> None: target_file = "misfit.hdf" if not workflow_args else workflow_args[0] realizations = ensemble.get_realization_list_with_responses() from ert import LibresFacade # noqa: PLC0415 (circular import) - facade = LibresFacade(ert_config) - misfit = facade.load_all_misfit_data(ensemble) + misfit = LibresFacade.load_all_misfit_data(ensemble) if len(realizations) == 0 or misfit.empty: raise StorageError("No responses loaded") misfit.columns = pd.Index([val.split(":")[1] for val in misfit.columns]) diff --git a/src/ert/plugins/hook_implementations/workflows/export_runpath.py b/src/ert/plugins/hook_implementations/workflows/export_runpath.py index 18e78be5f5e..b951f9a61d1 100644 --- a/src/ert/plugins/hook_implementations/workflows/export_runpath.py +++ b/src/ert/plugins/hook_implementations/workflows/export_runpath.py @@ -7,7 +7,7 @@ from ert.validation import rangestring_to_list if TYPE_CHECKING: - from ert.config import ErtConfig + from ert.storage import Ensemble class ExportRunpathJob(ErtScript): @@ -33,20 +33,18 @@ class ExportRunpathJob(ErtScript): file. """ - def run(self, ert_config: ErtConfig, workflow_args: list[Any]) -> None: + def run( + self, run_paths: Runpaths, ensemble: Ensemble, workflow_args: list[Any] + ) -> None: args = " ".join(workflow_args).split() # Make sure args is a list of words - run_paths = Runpaths( - jobname_format=ert_config.model_config.jobname_format_string, - runpath_format=ert_config.model_config.runpath_format_string, - filename=str(ert_config.runpath_file), - substitutions=ert_config.substitutions, - eclbase=ert_config.model_config.eclbase_format_string, - ) + assert ensemble + iter = ensemble.iteration + reals = ensemble.ensemble_size run_paths.write_runpath_list( *self.get_ranges( args, - ert_config.analysis_config.num_iterations, - ert_config.model_config.num_realizations, + iter, + reals, ) ) diff --git a/src/ert/run_models/base_run_model.py b/src/ert/run_models/base_run_model.py index 6389e091912..d5263ce7951 100644 --- a/src/ert/run_models/base_run_model.py +++ b/src/ert/run_models/base_run_model.py @@ -58,6 +58,7 @@ from ert.trace import tracer from ert.workflow_runner import WorkflowRunner +from ..config.workflow_fixtures import WorkflowFixtures from ..run_arg import RunArg from .event import ( AnalysisStatusEvent, @@ -181,13 +182,20 @@ def __init__( self.start_iteration = start_iteration self.restart = False + def reports_dir(self, ensemble_name: str) -> str: + return str( + self._storage.path.parent + / "reports" + / Path(str(self._user_config_file)).stem + / ensemble_name + ) + def log_at_startup(self) -> None: keys_to_drop = [ "_end_queue", "_queue_config", "_status_queue", "_storage", - "ert_config", "rng", "run_paths", "substitutions", @@ -677,11 +685,10 @@ def validate_successful_realizations_count(self) -> None: def run_workflows( self, runtime: HookRuntime, - storage: Storage | None = None, - ensemble: Ensemble | None = None, + fixtures: WorkflowFixtures, ) -> None: for workflow in self._hooked_workflows[runtime]: - WorkflowRunner(workflow, storage, ensemble).run_blocking() + WorkflowRunner(workflow=workflow, fixtures=fixtures).run_blocking() def _evaluate_and_postprocess( self, @@ -703,7 +710,16 @@ def _evaluate_and_postprocess( context_env=self._context_env, ) - self.run_workflows(HookRuntime.PRE_SIMULATION, self._storage, ensemble) + self.run_workflows( + HookRuntime.PRE_SIMULATION, + fixtures={ + "storage": self._storage, + "ensemble": ensemble, + "reports_dir": self.reports_dir(ensemble_name=ensemble.name), + "random_seed": self.random_seed, + "run_paths": self.run_paths, + }, + ) successful_realizations = self.run_ensemble_evaluator( run_args, ensemble, @@ -729,7 +745,16 @@ def _evaluate_and_postprocess( f"{self.ensemble_size - num_successful_realizations}" ) logger.info(f"Experiment run finished in: {self.get_runtime()}s") - self.run_workflows(HookRuntime.POST_SIMULATION, self._storage, ensemble) + self.run_workflows( + HookRuntime.POST_SIMULATION, + fixtures={ + "storage": self._storage, + "ensemble": ensemble, + "reports_dir": self.reports_dir(ensemble_name=ensemble.name), + "random_seed": self.random_seed, + "run_paths": self.run_paths, + }, + ) return num_successful_realizations @@ -794,6 +819,17 @@ def update( msg="Creating posterior ensemble..", ) ) + + workflow_fixtures: WorkflowFixtures = { + "storage": self._storage, + "ensemble": prior, + "observation_settings": self._update_settings, + "es_settings": self._analysis_settings, + "random_seed": self.random_seed, + "reports_dir": self.reports_dir(ensemble_name=prior.name), + "run_paths": self.run_paths, + } + posterior = self._storage.create_ensemble( prior.experiment, ensemble_size=prior.ensemble_size, @@ -802,8 +838,14 @@ def update( prior_ensemble=prior, ) if prior.iteration == 0: - self.run_workflows(HookRuntime.PRE_FIRST_UPDATE, self._storage, prior) - self.run_workflows(HookRuntime.PRE_UPDATE, self._storage, prior) + self.run_workflows( + HookRuntime.PRE_FIRST_UPDATE, + fixtures=workflow_fixtures, + ) + self.run_workflows( + HookRuntime.PRE_UPDATE, + fixtures=workflow_fixtures, + ) try: smoother_update( prior, @@ -825,5 +867,8 @@ def update( "Update algorithm failed for iteration:" f"{posterior.iteration}. The following error occurred: {e}" ) from e - self.run_workflows(HookRuntime.POST_UPDATE, self._storage, prior) + self.run_workflows( + HookRuntime.POST_UPDATE, + fixtures=workflow_fixtures, + ) return posterior diff --git a/src/ert/run_models/ensemble_experiment.py b/src/ert/run_models/ensemble_experiment.py index a7aac4d0563..92150db13f3 100644 --- a/src/ert/run_models/ensemble_experiment.py +++ b/src/ert/run_models/ensemble_experiment.py @@ -93,7 +93,10 @@ def run_experiment( raise ErtRunError(str(exc)) from exc if not restart: - self.run_workflows(HookRuntime.PRE_EXPERIMENT) + self.run_workflows( + HookRuntime.PRE_EXPERIMENT, + fixtures={"random_seed": self.random_seed}, + ) self.experiment = self._storage.create_experiment( name=self.experiment_name, parameters=( @@ -143,7 +146,14 @@ def run_experiment( self.ensemble, evaluator_server_config, ) - self.run_workflows(HookRuntime.POST_EXPERIMENT) + self.run_workflows( + HookRuntime.POST_EXPERIMENT, + fixtures={ + "random_seed": self.random_seed, + "storage": self._storage, + "ensemble": self.ensemble, + }, + ) @classmethod def name(cls) -> str: diff --git a/src/ert/run_models/ensemble_smoother.py b/src/ert/run_models/ensemble_smoother.py index 1a6f8cf3382..f30a8049896 100644 --- a/src/ert/run_models/ensemble_smoother.py +++ b/src/ert/run_models/ensemble_smoother.py @@ -74,7 +74,10 @@ def run_experiment( ) -> None: self.log_at_startup() self.restart = restart - self.run_workflows(HookRuntime.PRE_EXPERIMENT) + self.run_workflows( + HookRuntime.PRE_EXPERIMENT, + fixtures={"random_seed": self.random_seed}, + ) ensemble_format = self.target_ensemble_format experiment = self._storage.create_experiment( parameters=self._parameter_configuration, @@ -120,7 +123,14 @@ def run_experiment( posterior, evaluator_server_config, ) - self.run_workflows(HookRuntime.POST_EXPERIMENT) + self.run_workflows( + HookRuntime.POST_EXPERIMENT, + fixtures={ + "random_seed": self.random_seed, + "storage": self._storage, + "ensemble": posterior, + }, + ) @classmethod def name(cls) -> str: diff --git a/src/ert/run_models/multiple_data_assimilation.py b/src/ert/run_models/multiple_data_assimilation.py index 8b799a44af8..419c803e909 100644 --- a/src/ert/run_models/multiple_data_assimilation.py +++ b/src/ert/run_models/multiple_data_assimilation.py @@ -116,7 +116,10 @@ def run_experiment( f"Prior ensemble with ID: {id} does not exists" ) from err else: - self.run_workflows(HookRuntime.PRE_EXPERIMENT) + self.run_workflows( + HookRuntime.PRE_EXPERIMENT, + fixtures={"random_seed": self.random_seed}, + ) sim_args = {"weights": self._relative_weights} experiment = self._storage.create_experiment( parameters=self._parameter_configuration, @@ -171,7 +174,14 @@ def run_experiment( ) prior = posterior - self.run_workflows(HookRuntime.POST_EXPERIMENT) + self.run_workflows( + HookRuntime.POST_EXPERIMENT, + fixtures={ + "random_seed": self.random_seed, + "storage": self._storage, + "ensemble": prior, + }, + ) @staticmethod def parse_weights(weights: str) -> list[float]: diff --git a/src/ert/workflow_runner.py b/src/ert/workflow_runner.py index 232ed13a8f9..548c4e5438d 100644 --- a/src/ert/workflow_runner.py +++ b/src/ert/workflow_runner.py @@ -3,12 +3,10 @@ import logging from concurrent import futures from concurrent.futures import Future -from typing import TYPE_CHECKING, Any, Self +from typing import Any, Self -from ert.config import ErtConfig, ErtScript, ExternalErtScript, Workflow, WorkflowJob - -if TYPE_CHECKING: - from ert.storage import Ensemble, Storage +from ert.config import ErtScript, ExternalErtScript, Workflow, WorkflowJob +from ert.config.workflow_fixtures import WorkflowFixtures class WorkflowJobRunner: @@ -21,7 +19,7 @@ def __init__(self, workflow_job: WorkflowJob): def run( self, arguments: list[Any] | None = None, - fixtures: dict[str, Any] | None = None, + fixtures: WorkflowFixtures | None = None, ) -> Any: if arguments is None: arguments = [] @@ -57,7 +55,7 @@ def run( else: raise UserWarning("Unknown script type!") result = self.__script.initializeAndRun( # type: ignore - self.job.argument_types(), arguments, fixtures=fixtures + self.job.argument_types(), arguments, fixtures ) self.__running = False @@ -107,14 +105,10 @@ class WorkflowRunner: def __init__( self, workflow: Workflow, - storage: Storage | None = None, - ensemble: Ensemble | None = None, - ert_config: ErtConfig | None = None, + fixtures: WorkflowFixtures, ) -> None: self.__workflow = workflow - self.storage = storage - self.ensemble = ensemble - self.ert_config = ert_config + self.fixtures = fixtures self.__workflow_result: bool | None = None self._workflow_executor = futures.ThreadPoolExecutor(max_workers=1) @@ -150,18 +144,13 @@ def run_blocking(self) -> None: # Reset status self.__status = {} self.__running = True - fixtures = { - k: getattr(self, k) - for k in ["storage", "ensemble", "ert_config"] - if getattr(self, k) - } for job, args in self.__workflow: jobrunner = WorkflowJobRunner(job) self.__current_job = jobrunner if not self.__cancelled: logger.info(f"Workflow job {jobrunner.name} starting") - jobrunner.run(args, fixtures=fixtures) + jobrunner.run(args, fixtures=self.fixtures) self.__status[jobrunner.name] = { "stdout": jobrunner.stdoutdata(), "stderr": jobrunner.stderrdata(), diff --git a/tests/ert/unit_tests/cli/test_cli_workflow.py b/tests/ert/unit_tests/cli/test_cli_workflow.py index bde6a4e49e7..9c7dfddb938 100644 --- a/tests/ert/unit_tests/cli/test_cli_workflow.py +++ b/tests/ert/unit_tests/cli/test_cli_workflow.py @@ -12,7 +12,7 @@ def test_executing_workflow(storage): with ErtPluginContext(): with open("test_wf", "w", encoding="utf-8") as wf_file: - wf_file.write("EXPORT_RUNPATH") + wf_file.write("CSV_EXPORT test_workflow_output.csv") config_file = "poly.ert" with open(config_file, "a", encoding="utf-8") as file_handle: @@ -21,4 +21,4 @@ def test_executing_workflow(storage): rc = ErtConfig.from_file(config_file) args = Namespace(name="test_wf") execute_workflow(rc, storage, args.name) - assert os.path.isfile(".ert_runpath_list") + assert os.path.isfile("test_workflow_output.csv") diff --git a/tests/ert/unit_tests/cli/test_model_hook_order.py b/tests/ert/unit_tests/cli/test_model_hook_order.py index ea2b3872b7a..30ae271a8fa 100644 --- a/tests/ert/unit_tests/cli/test_model_hook_order.py +++ b/tests/ert/unit_tests/cli/test_model_hook_order.py @@ -14,15 +14,91 @@ ) EXPECTED_CALL_ORDER = [ - call(HookRuntime.PRE_EXPERIMENT), - call(HookRuntime.PRE_SIMULATION, ANY, ANY), - call(HookRuntime.POST_SIMULATION, ANY, ANY), - call(HookRuntime.PRE_FIRST_UPDATE, ANY, ANY), - call(HookRuntime.PRE_UPDATE, ANY, ANY), - call(HookRuntime.POST_UPDATE, ANY, ANY), - call(HookRuntime.PRE_SIMULATION, ANY, ANY), - call(HookRuntime.POST_SIMULATION, ANY, ANY), - call(HookRuntime.POST_EXPERIMENT), + call(HookRuntime.PRE_EXPERIMENT, fixtures={"random_seed": ANY}), + call( + HookRuntime.PRE_SIMULATION, + fixtures={ + "storage": ANY, + "ensemble": ANY, + "reports_dir": ANY, + "random_seed": ANY, + "run_paths": ANY, + }, + ), + call( + HookRuntime.POST_SIMULATION, + fixtures={ + "storage": ANY, + "ensemble": ANY, + "reports_dir": ANY, + "random_seed": ANY, + "run_paths": ANY, + }, + ), + call( + HookRuntime.PRE_FIRST_UPDATE, + fixtures={ + "storage": ANY, + "ensemble": ANY, + "reports_dir": ANY, + "random_seed": ANY, + "es_settings": ANY, + "observation_settings": ANY, + "run_paths": ANY, + }, + ), + call( + HookRuntime.PRE_UPDATE, + fixtures={ + "storage": ANY, + "ensemble": ANY, + "reports_dir": ANY, + "random_seed": ANY, + "es_settings": ANY, + "observation_settings": ANY, + "run_paths": ANY, + }, + ), + call( + HookRuntime.POST_UPDATE, + fixtures={ + "storage": ANY, + "ensemble": ANY, + "reports_dir": ANY, + "random_seed": ANY, + "es_settings": ANY, + "observation_settings": ANY, + "run_paths": ANY, + }, + ), + call( + HookRuntime.PRE_SIMULATION, + fixtures={ + "storage": ANY, + "ensemble": ANY, + "reports_dir": ANY, + "random_seed": ANY, + "run_paths": ANY, + }, + ), + call( + HookRuntime.POST_SIMULATION, + fixtures={ + "storage": ANY, + "ensemble": ANY, + "reports_dir": ANY, + "random_seed": ANY, + "run_paths": ANY, + }, + ), + call( + HookRuntime.POST_EXPERIMENT, + fixtures={ + "random_seed": ANY, + "storage": ANY, + "ensemble": ANY, + }, + ), ] diff --git a/tests/ert/unit_tests/config/test_ert_plugin.py b/tests/ert/unit_tests/config/test_ert_plugin.py index 134535685ff..655e431b6ac 100644 --- a/tests/ert/unit_tests/config/test_ert_plugin.py +++ b/tests/ert/unit_tests/config/test_ert_plugin.py @@ -71,17 +71,17 @@ def test_cancel_plugin(): def test_plugin_with_fixtures(): class FixturePlugin(ErtPlugin): - def run(self, ert_script): - return ert_script + def run(self, ensemble): + return ensemble plugin = FixturePlugin() fixture_mock = MagicMock() - assert plugin.initializeAndRun([], [], {"ert_script": fixture_mock}) == fixture_mock + assert plugin.initializeAndRun([], [], {"ensemble": fixture_mock}) == fixture_mock def test_plugin_with_missing_arguments(caplog): class FixturePlugin(ErtPlugin): - def run(self, arg_1, ert_script, fixture_2, arg_2="something"): + def run(self, arg_1, ensemble, run_paths, arg_2="something"): pass plugin = FixturePlugin() @@ -89,41 +89,78 @@ def run(self, arg_1, ert_script, fixture_2, arg_2="something"): fixture2_mock = MagicMock() with caplog.at_level(logging.WARNING): plugin.initializeAndRun( - [], [1, 2], {"ert_script": fixture_mock, "fixture_2": fixture2_mock} + [], + [], + {"ensemble": fixture_mock, "run_paths": fixture2_mock}, ) assert plugin.hasFailed() log = "\n".join(caplog.messages) assert "FixturePlugin misconfigured" in log - assert "['arg_1', 'arg_2'] not found in fixtures" in log + assert ("arguments: ['arg_1'] not found in fixtures") in log + + +def test_plugin_with_mixed_arguments(caplog): + fixture_mock = MagicMock() + fixture2_mock = MagicMock() + + class FixturePlugin(ErtPlugin): + def run(self, arg_0, arg_1, ensemble, fixture_2, arg_2="something"): + nonlocal fixture_mock + nonlocal fixture2_mock + + assert arg_0 == "1" + assert arg_1 == "2" + assert ensemble == fixture_mock + assert fixture_2 == fixture2_mock + assert arg_2 == "something else" + + plugin = FixturePlugin() + + plugin.initializeAndRun( + [], + [1, 2], + {"ensemble": fixture_mock, "fixture_2": fixture2_mock}, + arg_2="something_else", + ) def test_plugin_with_fixtures_and_enough_arguments(): class FixturePlugin(ErtPlugin): - def run(self, workflow_args, ert_script): - return workflow_args, ert_script + def run(self, workflow_args, ensemble): + return workflow_args, ensemble plugin = FixturePlugin() fixture_mock = MagicMock() - assert plugin.initializeAndRun([], [1, 2, 3], {"ert_script": fixture_mock}) == ( + assert plugin.initializeAndRun([], [1, 2, 3], {"ensemble": fixture_mock}) == ( ["1", "2", "3"], fixture_mock, ) -def test_plugin_with_default_arguments(capsys): +def test_plugin_with_fixtures_and_enough_positional_arguments(): class FixturePlugin(ErtPlugin): - def run(self, ert_script=None): - return ert_script + def run(self, a, b, c, ensemble): + return ([a, b, c], ensemble) plugin = FixturePlugin() fixture_mock = MagicMock() - assert ( - plugin.initializeAndRun([], [1, 2], {"ert_script": fixture_mock}) - == fixture_mock + assert plugin.initializeAndRun([], [1, 2, 3], {"ensemble": fixture_mock}) == ( + ["1", "2", "3"], + fixture_mock, ) +def test_plugin_with_default_arguments(): + class FixturePlugin(ErtPlugin): + def run(self, ensemble=None): + return ensemble + + plugin = FixturePlugin() + fixture_mock = MagicMock() + assert plugin.initializeAndRun([], [], {"ensemble": fixture_mock}) == fixture_mock + + def test_plugin_with_args(): class FixturePlugin(ErtPlugin): def run(self, *args): @@ -131,7 +168,7 @@ def run(self, *args): plugin = FixturePlugin() fixture_mock = MagicMock() - assert plugin.initializeAndRun([], [1, 2], {"ert_script": fixture_mock}) == ( + assert plugin.initializeAndRun([], [1, 2], {"ensemble": fixture_mock}) == ( "1", "2", ) @@ -144,20 +181,7 @@ def run(self, *args, **kwargs): plugin = FixturePlugin() fixture_mock = MagicMock() - assert plugin.initializeAndRun([], [1, 2], {"ert_script": fixture_mock}) == ( + assert plugin.initializeAndRun([], [1, 2], {"ensemble": fixture_mock}) == ( "1", "2", ) - - -def test_deprecated_properties(): - class FixturePlugin(ErtPlugin): - def run(self): - pass - - plugin = FixturePlugin() - ert_mock = MagicMock() - ensemble_mock = MagicMock() - plugin.initializeAndRun([], [], {"ert_config": ert_mock, "ensemble": ensemble_mock}) - with pytest.deprecated_call(): - assert (plugin.ert(), plugin.ensemble) == (ert_mock, ensemble_mock) diff --git a/tests/ert/unit_tests/plugins/test_export_misfit.py b/tests/ert/unit_tests/plugins/test_export_misfit.py index ae7b14cbca1..9e7a8e4166a 100644 --- a/tests/ert/unit_tests/plugins/test_export_misfit.py +++ b/tests/ert/unit_tests/plugins/test_export_misfit.py @@ -14,8 +14,8 @@ sys.platform.startswith("darwin"), reason="https://github.com/equinor/ert/issues/7533", ) -def test_export_misfit(snake_oil_case_storage, snake_oil_default_storage, snapshot): - ExportMisfitDataJob().run(snake_oil_case_storage, snake_oil_default_storage, []) +def test_export_misfit(snake_oil_default_storage, snapshot): + ExportMisfitDataJob().run(snake_oil_default_storage, []) result = pd.read_hdf("misfit.hdf") snapshot.assert_match( result.to_csv(), @@ -25,7 +25,7 @@ def test_export_misfit(snake_oil_case_storage, snake_oil_default_storage, snapsh def test_export_misfit_no_responses_in_storage(poly_case, new_ensemble): with pytest.raises(StorageError, match="No responses loaded"): - ExportMisfitDataJob().run(poly_case, new_ensemble, []) + ExportMisfitDataJob().run(new_ensemble, []) def test_export_misfit_data_job_is_loaded(): diff --git a/tests/ert/unit_tests/plugins/test_export_runpath.py b/tests/ert/unit_tests/plugins/test_export_runpath.py index d63693b1a51..e1feb51e454 100644 --- a/tests/ert/unit_tests/plugins/test_export_runpath.py +++ b/tests/ert/unit_tests/plugins/test_export_runpath.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from unittest.mock import Mock, patch +from unittest.mock import MagicMock, Mock, patch import pytest @@ -25,22 +25,41 @@ class WritingSetup: def writing_setup(setup_case): with patch.object(Runpaths, "write_runpath_list") as write_mock: config = setup_case("snake_oil", "snake_oil.ert") - yield WritingSetup(write_mock, ExportRunpathJob()), config + yield ( + WritingSetup( + write_mock, + ExportRunpathJob(), + ), + Runpaths( + jobname_format=config.model_config.jobname_format_string, + runpath_format=config.model_config.runpath_format_string, + filename=str(config.runpath_file), + substitutions=config.substitutions, + eclbase=config.model_config.eclbase_format_string, + ), + ) def test_export_runpath_empty_range(writing_setup): - writing_setup, config = writing_setup - writing_setup.export_job.run(config, []) + writing_setup, run_paths = writing_setup + + mock_ensemble = MagicMock() + mock_ensemble.iteration = 0 + mock_ensemble.ensemble_size = 25 + writing_setup.export_job.run(run_paths, mock_ensemble, []) writing_setup.write_mock.assert_called_with( - [0], - list(range(25)), + [0], list(range(mock_ensemble.ensemble_size)) ) def test_export_runpath_star_parameter(writing_setup): - writing_setup, config = writing_setup - writing_setup.export_job.run(config, ["* | *"]) + writing_setup, run_paths = writing_setup + + mock_ensemble = MagicMock() + mock_ensemble.iteration = 1 + mock_ensemble.ensemble_size = 25 + writing_setup.export_job.run(run_paths, mock_ensemble, ["* | *"]) writing_setup.write_mock.assert_called_with( list(range(1)), @@ -49,8 +68,12 @@ def test_export_runpath_star_parameter(writing_setup): def test_export_runpath_range_parameter(writing_setup): - writing_setup, config = writing_setup - writing_setup.export_job.run(config, ["* | 1-2"]) + writing_setup, run_paths = writing_setup + + mock_ensemble = MagicMock() + mock_ensemble.iteration = 0 + mock_ensemble.ensemble_size = 25 + writing_setup.export_job.run(run_paths, mock_ensemble, ["* | 1-2"]) writing_setup.write_mock.assert_called_with( [1, 2], @@ -59,8 +82,12 @@ def test_export_runpath_range_parameter(writing_setup): def test_export_runpath_comma_parameter(writing_setup): - writing_setup, config = writing_setup - writing_setup.export_job.run(config, ["3,4 | 1-2"]) + writing_setup, run_paths = writing_setup + + mock_ensemble = MagicMock() + mock_ensemble.iteration = 0 + mock_ensemble.ensemble_size = 25 + writing_setup.export_job.run(run_paths, mock_ensemble, ["3,4 | 1-2"]) writing_setup.write_mock.assert_called_with( [1, 2], @@ -69,8 +96,12 @@ def test_export_runpath_comma_parameter(writing_setup): def test_export_runpath_combination_parameter(writing_setup): - writing_setup, config = writing_setup - writing_setup.export_job.run(config, ["1,2-3 | 1-2"]) + writing_setup, run_paths = writing_setup + + mock_ensemble = MagicMock() + mock_ensemble.iteration = 0 + mock_ensemble.ensemble_size = 25 + writing_setup.export_job.run(run_paths, mock_ensemble, ["1,2-3 | 1-2"]) writing_setup.write_mock.assert_called_with( [1, 2], @@ -79,9 +110,13 @@ def test_export_runpath_combination_parameter(writing_setup): def test_export_runpath_bad_arguments(writing_setup): - writing_setup, config = writing_setup + writing_setup, run_paths = writing_setup + + mock_ensemble = MagicMock() + mock_ensemble.iteration = 0 + mock_ensemble.ensemble_size = 25 with pytest.raises(ValueError, match="Expected \\|"): - writing_setup.export_job.run(config, ["wat"]) + writing_setup.export_job.run(run_paths, mock_ensemble, ["wat"]) def test_export_runpath_job_is_loaded(): diff --git a/tests/ert/unit_tests/workflow_runner/test_workflow_runner.py b/tests/ert/unit_tests/workflow_runner/test_workflow_runner.py index 95e68c73936..8530d301f62 100644 --- a/tests/ert/unit_tests/workflow_runner/test_workflow_runner.py +++ b/tests/ert/unit_tests/workflow_runner/test_workflow_runner.py @@ -149,7 +149,7 @@ def test_workflow_run(): job, args = workflow[1] assert job.name == "DUMP" - WorkflowRunner(workflow).run_blocking() + WorkflowRunner(workflow, fixtures={}).run_blocking() with open("dump1", encoding="utf-8") as f: assert f.read() == "dump_text_1" @@ -169,7 +169,7 @@ def test_workflow_thread_cancel_ert_script(): assert len(workflow) == 3 - workflow_runner = WorkflowRunner(workflow) + workflow_runner = WorkflowRunner(workflow, fixtures={}) assert not workflow_runner.isRunning() @@ -208,7 +208,7 @@ def test_workflow_thread_cancel_external(): assert len(workflow) == 3 - workflow_runner = WorkflowRunner(workflow) + workflow_runner = WorkflowRunner(workflow, fixtures={}) assert not workflow_runner.isRunning() @@ -237,7 +237,7 @@ def test_workflow_failed_job(): workflow = Workflow.from_file("dump_workflow", Substitutions(), {"DUMP": dump_job}) assert len(workflow) == 2 - workflow_runner = WorkflowRunner(workflow) + workflow_runner = WorkflowRunner(workflow, fixtures={}) assert not workflow_runner.isRunning() with ( @@ -272,7 +272,7 @@ def test_workflow_success(): assert len(workflow) == 2 - workflow_runner = WorkflowRunner(workflow) + workflow_runner = WorkflowRunner(workflow, fixtures={}) assert not workflow_runner.isRunning() with workflow_runner: @@ -306,7 +306,7 @@ def test_workflow_stops_with_stopping_job(): job_dict={"DUMP": job_failing_dump}, ) - runner = WorkflowRunner(workflow) + runner = WorkflowRunner(workflow, fixtures={}) with pytest.raises(RuntimeError, match="Workflow job dump_failing_job failed"): runner.run_blocking() @@ -322,4 +322,4 @@ def test_workflow_stops_with_stopping_job(): ) # Expect no error raised - WorkflowRunner(workflow).run_blocking() + WorkflowRunner(workflow, fixtures={}).run_blocking()