Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Some adaptations to workflows #9994

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion src/ert/cli/workflow.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from __future__ import annotations

import logging
from pathlib import Path
from typing import TYPE_CHECKING

from ert.runpaths import Runpaths
from ert.workflow_runner import WorkflowRunner

if TYPE_CHECKING:
Expand All @@ -20,7 +22,26 @@ def execute_workflow(
msg = "Workflow {} is not in the list of available workflows"
logger.error(msg.format(workflow_name))
return
runner = WorkflowRunner(workflow=workflow, storage=storage, ert_config=ert_config)

runner = WorkflowRunner(
workflow=workflow,
fixtures={
"storage": storage,
"random_seed": ert_config.random_seed,
"reports_dir": str(
storage.path.parent / "reports" / Path(ert_config.user_config_file).stem
),
"observation_settings": ert_config.analysis_config.observation_settings,
"es_settings": ert_config.analysis_config.es_module,
"run_paths": Runpaths(
jobname_format=ert_config.model_config.jobname_format_string,
runpath_format=ert_config.model_config.runpath_format_string,
filename=str(ert_config.runpath_file),
substitutions=ert_config.substitutions,
eclbase=ert_config.model_config.eclbase_format_string,
),
},
)
runner.run_blocking()
if not all(v["completed"] for v in runner.workflowReport().values()):
logger.error(f"Workflow {workflow_name} failed!")
105 changes: 55 additions & 50 deletions src/ert/config/ert_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,12 @@
import logging
import sys
import traceback
import warnings
from abc import abstractmethod
from collections.abc import Callable
from types import MappingProxyType, ModuleType
from types import ModuleType
from typing import TYPE_CHECKING, Any, TypeAlias

from typing_extensions import deprecated
from ert.config.workflow_fixtures import WorkflowFixtures

if TYPE_CHECKING:
from ert.config import ErtConfig
Expand Down Expand Up @@ -39,11 +38,6 @@ def __init__(
self._stdoutdata = ""
self._stderrdata = ""

# Deprecated:
self._ert = None
self._ensemble = None
self._storage = None

@abstractmethod
def run(self, *arg: Any, **kwarg: Any) -> Any:
"""
Expand Down Expand Up @@ -71,31 +65,6 @@ def stderrdata(self) -> str:
self._stderrdata = self._stderrdata.decode()
return self._stderrdata

@deprecated("Use fixtures to the run function instead")
def ert(self) -> ErtConfig | None:
logger.info(f"Accessing EnKFMain from workflow: {self.__class__.__name__}")
return self._ert

@property
def ensemble(self) -> Ensemble | None:
warnings.warn(
"The ensemble property is deprecated, use the fixture to the run function instead",
DeprecationWarning,
stacklevel=1,
)
logger.info(f"Accessing ensemble from workflow: {self.__class__.__name__}")
return self._ensemble

@property
def storage(self) -> Storage | None:
warnings.warn(
"The storage property is deprecated, use the fixture to the run function instead",
DeprecationWarning,
stacklevel=1,
)
logger.info(f"Accessing storage from workflow: {self.__class__.__name__}")
return self._storage

def isCancelled(self) -> bool:
return self.__is_cancelled

Expand All @@ -114,36 +83,70 @@ def initializeAndRun(
self,
argument_types: list[type[Any]],
argument_values: list[str],
fixtures: dict[str, Any] | None = None,
fixtures: WorkflowFixtures | None = None,
**kwargs: dict[str, Any],
) -> Any:
fixtures = {} if fixtures is None else fixtures
arguments = []
workflow_args = []
for index, arg_value in enumerate(argument_values):
arg_type = argument_types[index] if index < len(argument_types) else str

if arg_value is not None:
arguments.append(arg_type(arg_value))
workflow_args.append(arg_type(arg_value))
else:
arguments.append(None)
fixtures["workflow_args"] = arguments
workflow_args.append(None)

fixtures["workflow_args"] = workflow_args

fixture_args = []
all_func_args = inspect.signature(self.run).parameters
is_using_wf_args_fixture = "workflow_args" in all_func_args

try:
func_args = inspect.signature(self.run).parameters
if not is_using_wf_args_fixture:
fixture_or_kw_arguments = list(all_func_args)[len(workflow_args) :]
else:
fixture_or_kw_arguments = list(all_func_args)

func_args = {k: all_func_args[k] for k in fixture_or_kw_arguments}

kwargs_defaults = {
k: v.default
for k, v in func_args.items()
if k not in fixtures
and v.kind != v.VAR_POSITIONAL
and not str(v).startswith("*")
and v.default != v.empty
}
use_kwargs = {
k: (kwargs or {}).get(k, default_value)
for k, default_value in ({**kwargs_defaults, **kwargs}).items()
}
# If the user has specified *args, we skip injecting fixtures, and just
# pass the user configured arguments
if not any(p.kind == p.VAR_POSITIONAL for p in func_args.values()):
try:
arguments = self.insert_fixtures(func_args, fixtures)
fixture_args = self.insert_fixtures(func_args, fixtures, use_kwargs)
except ValueError as e:
# This is here for backwards compatibility, the user does not have *argv
# but positional arguments. Can not be mixed with using fixtures.
logger.warning(
f"Mixture of fixtures and positional arguments, err: {e}"
)
# Part of deprecation
self._ert = fixtures.get("ert_config")
self._ensemble = fixtures.get("ensemble")
self._storage = fixtures.get("storage")
return self.run(*arguments)

positional_args = (
fixture_args
if is_using_wf_args_fixture
else [*workflow_args, *fixture_args]
)
if not positional_args and not use_kwargs:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This we could keep this as it used to be? I dont think kwargs were really supported before, they just ended up being workflow_args in a list?

return self.run()
elif positional_args and not use_kwargs:
return self.run(*positional_args)
elif not positional_args and use_kwargs:
return self.run(**use_kwargs)
else:
return self.run(*positional_args, **use_kwargs)
except AttributeError as e:
error_msg = str(e)
if not hasattr(self, "run"):
Expand All @@ -169,20 +172,22 @@ def initializeAndRun(

def insert_fixtures(
self,
func_args: MappingProxyType[str, inspect.Parameter],
fixtures: dict[str, Fixtures],
func_args: dict[str, inspect.Parameter],
fixtures: WorkflowFixtures,
kwargs: dict[str, Any],
) -> list[Any]:
arguments = []
errors = []
for val in func_args:
jonathan-eq marked this conversation as resolved.
Show resolved Hide resolved
if val in fixtures:
arguments.append(fixtures[val])
else:
arguments.append(fixtures.get(val))
elif val not in kwargs:
errors.append(val)
if errors:
kwargs_str = ",".join(f"{k}='{v}'" for k, v in kwargs.items())
raise ValueError(
f"Plugin: {self.__class__.__name__} misconfigured, arguments: {errors} "
f"not found in fixtures: {list(fixtures)}"
f"not found in fixtures: {list(fixtures)} or kwargs {kwargs_str}"
)
return arguments

Expand Down
21 changes: 21 additions & 0 deletions src/ert/config/workflow_fixtures.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any

from typing_extensions import TypedDict

if TYPE_CHECKING:
from ert.config import ESSettings, UpdateSettings
from ert.runpaths import Runpaths
from ert.storage import Ensemble, Storage


class WorkflowFixtures(TypedDict, total=False):
ensemble: Ensemble
storage: Storage
random_seed: int | None
reports_dir: str
observation_settings: UpdateSettings
es_settings: ESSettings
run_paths: Runpaths
workflow_args: list[Any]
2 changes: 1 addition & 1 deletion src/ert/gui/tools/export/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def run_export(self, parameters: list[Any]) -> None:

export_job_runner = WorkflowJobRunner(self.export_job)
user_warn = export_job_runner.run(
fixtures={"storage": self._notifier.storage, "ert_config": self.config},
fixtures={"storage": self._notifier.storage},
arguments=parameters,
)
if export_job_runner.hasFailed():
Expand Down
10 changes: 3 additions & 7 deletions src/ert/gui/tools/plugins/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import TYPE_CHECKING, Any

from ert import ErtScript
from ert.config.workflow_fixtures import WorkflowFixtures

if TYPE_CHECKING:
from PyQt6.QtWidgets import QWidget
Expand Down Expand Up @@ -34,20 +35,15 @@ def getName(self) -> str:
def getDescription(self) -> str:
return self.__description

def getArguments(self, fixtures: dict[str, Any]) -> list[Any]:
def getArguments(self, fixtures: WorkflowFixtures) -> list[Any]:
"""
Returns a list of arguments. Either from GUI or from arbitrary code.
If the user for example cancels in the GUI a CancelPluginException is raised.
"""
script = self.__loadPlugin()
fixtures["parent"] = self.__parent_window
func_args = inspect.signature(script.getArguments).parameters
arguments = script.insert_fixtures(func_args, fixtures)
arguments = script.insert_fixtures(dict(func_args), fixtures, {})

# Part of deprecation
script._ert = fixtures.get("ert_config")
script._ensemble = fixtures.get("ensemble")
script._storage = fixtures.get("storage")
return script.getArguments(*arguments)

def setParentWindow(self, parent_window: QWidget | None) -> None:
Expand Down
33 changes: 26 additions & 7 deletions src/ert/gui/tools/plugins/plugin_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,18 @@

import time
from collections.abc import Callable
from pathlib import Path
from typing import TYPE_CHECKING, Any

from _ert.threading import ErtThread
from ert.config import CancelPluginException
from ert.config import CancelPluginException, ErtConfig
from ert.config.workflow_fixtures import WorkflowFixtures
from ert.runpaths import Runpaths
from ert.workflow_runner import WorkflowJobRunner

from .process_job_dialog import ProcessJobDialog

if TYPE_CHECKING:
from ert.config import ErtConfig
from ert.storage import LocalStorage

from .plugin import Plugin
Expand All @@ -22,30 +24,47 @@ def __init__(
self, plugin: Plugin, ert_config: ErtConfig, storage: LocalStorage
) -> None:
super().__init__()

self.ert_config = ert_config
self.storage = storage
self.__plugin = plugin

self.__plugin_finished_callback: Callable[[], None] = lambda: None

self.__result = None
self._runner = WorkflowJobRunner(plugin.getWorkflowJob())
self.poll_thread: ErtThread | None = None

def run(self) -> None:
ert_config = self.ert_config
try:
plugin = self.__plugin

arguments = plugin.getArguments(
fixtures={"storage": self.storage, "ert_config": self.ert_config}
fixtures={
"storage": self.storage,
"random_seed": ert_config.random_seed,
"reports_dir": str(
self.storage.path.parent
/ "reports"
/ Path(ert_config.user_config_file).stem
),
"observation_settings": ert_config.analysis_config.observation_settings,
"es_settings": ert_config.analysis_config.es_module,
"run_paths": Runpaths(
jobname_format=ert_config.model_config.jobname_format_string,
runpath_format=ert_config.model_config.runpath_format_string,
filename=str(ert_config.runpath_file),
substitutions=ert_config.substitutions,
eclbase=ert_config.model_config.eclbase_format_string,
),
}
)
dialog = ProcessJobDialog(plugin.getName(), plugin.getParentWindow())
dialog.setObjectName("process_job_dialog")

dialog.cancelConfirmed.connect(self.cancel)
fixtures = {
k: getattr(self, k)
for k in ["storage", "ert_config"]
for k in ["storage", "run_paths"]
if getattr(self, k)
}
workflow_job_thread = ErtThread(
Expand All @@ -71,7 +90,7 @@ def run(self) -> None:
print("Plugin cancelled before execution!")

def __runWorkflowJob(
self, arguments: list[Any] | None, fixtures: dict[str, Any]
self, arguments: list[Any] | None, fixtures: WorkflowFixtures
) -> None:
self.__result = self._runner.run(arguments, fixtures=fixtures)

Expand Down
25 changes: 22 additions & 3 deletions src/ert/gui/tools/workflows/run_workflow_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import time
from collections.abc import Iterable
from pathlib import Path
from typing import TYPE_CHECKING

from PyQt6.QtCore import QSize, Qt
Expand All @@ -20,6 +21,7 @@
from _ert.threading import ErtThread
from ert.gui.ertwidgets import EnsembleSelector
from ert.gui.tools.workflows.workflow_dialog import WorkflowDialog
from ert.runpaths import Runpaths
from ert.workflow_runner import WorkflowRunner

if TYPE_CHECKING:
Expand Down Expand Up @@ -125,9 +127,26 @@ def startWorkflow(self) -> None:

workflow = self.config.workflows[self.getCurrentWorkflowName()]
self._workflow_runner = WorkflowRunner(
workflow,
storage=self.storage,
ensemble=self.source_ensemble_selector.currentData(),
workflow=workflow,
fixtures={
"ensemble": self.source_ensemble_selector.currentData(),
"storage": self.storage,
"random_seed": self.config.random_seed,
"reports_dir": str(
self.storage.path.parent
/ "reports"
/ Path(self.config.user_config_file).stem
),
"observation_settings": self.config.analysis_config.observation_settings,
"es_settings": self.config.analysis_config.es_module,
"run_paths": Runpaths(
jobname_format=self.config.model_config.jobname_format_string,
runpath_format=self.config.model_config.runpath_format_string,
filename=str(self.config.runpath_file),
substitutions=self.config.substitutions,
eclbase=self.config.model_config.eclbase_format_string,
),
},
)
self._workflow_runner.run()

Expand Down
Loading
Loading