Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
oyvindeide committed Mar 22, 2024
1 parent 4b0b1ec commit 9489523
Show file tree
Hide file tree
Showing 12 changed files with 41 additions and 53 deletions.
2 changes: 1 addition & 1 deletion src/ert/cli/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def execute_workflow(
msg = "Workflow {} is not in the list of available workflows"
logger.error(msg.format(workflow_name))
return
runner = WorkflowRunner(workflow, storage)
runner = WorkflowRunner(workflow, storage, ert_config=ert_config)
runner.run_blocking()
if not all(v["completed"] for v in runner.workflowReport().values()):
logger.error(f"Workflow {workflow_name} failed!")
2 changes: 1 addition & 1 deletion src/ert/config/ert_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ class CancelPluginException(Exception):


class ErtPlugin(ErtScript, ABC):
def getArguments(self, parent: Any = None) -> List[Any]:
def getArguments(self, parent: Any = None, storage=None) -> List[Any]:

Check failure on line 12 in src/ert/config/ert_plugin.py

View workflow job for this annotation

GitHub Actions / type-checking (3.12)

Function is missing a type annotation for one or more arguments
return []

def getName(self) -> str:
Expand Down
26 changes: 3 additions & 23 deletions src/ert/config/ert_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,7 @@
import traceback
from abc import abstractmethod
from types import ModuleType
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type

if TYPE_CHECKING:
from ert.storage import Ensemble, Storage
from typing import Any, Callable, Dict, List, Optional, Type

logger = logging.getLogger(__name__)

Expand All @@ -20,12 +17,7 @@ class ErtScript:

def __init__(
self,
storage: Storage,
ensemble: Optional[Ensemble] = None,
) -> None:
self.__storage = storage
self.__ensemble = ensemble

self.__is_cancelled = False
self.__failed = False
self._stdoutdata = ""
Expand All @@ -51,18 +43,6 @@ def ert(self) -> None:
logger.info(f"Accessing EnKFMain from workflow: {self.__class__.__name__}")
raise NotImplementedError("The ert() function has been removed")

@property
def storage(self) -> Storage:
return self.__storage

@property
def ensemble(self) -> Optional[Ensemble]:
return self.__ensemble

@ensemble.setter
def ensemble(self, ensemble: Ensemble) -> None:
self.__ensemble = ensemble

def isCancelled(self) -> bool:
return self.__is_cancelled

Expand Down Expand Up @@ -134,7 +114,7 @@ def output_stack_trace(self, error: str = "") -> None:
@staticmethod
def loadScriptFromFile(
path: str,
) -> Callable[["Storage"], "ErtScript"]:
) -> Callable[[], "ErtScript"]:
module_name = f"ErtScriptModule_{ErtScript.__module_count}"
ErtScript.__module_count += 1

Expand All @@ -155,7 +135,7 @@ def loadScriptFromFile(
@staticmethod
def __findErtScriptImplementations(
module: ModuleType,
) -> Callable[["Storage"], "ErtScript"]:
) -> Callable[[], "ErtScript"]:
result = []
for _, member in inspect.getmembers(
module,
Expand Down
9 changes: 3 additions & 6 deletions src/ert/config/external_ert_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,14 @@
import codecs
import sys
from subprocess import PIPE, Popen
from typing import TYPE_CHECKING, Any, Optional
from typing import Any, Optional

from .ert_script import ErtScript

if TYPE_CHECKING:
from ert.storage import Storage


class ExternalErtScript(ErtScript):
def __init__(self, storage: Storage, executable: str):
super().__init__(storage, None)
def __init__(self, executable: str):
super().__init__()

self.__executable = executable
self.__job: Optional[Popen[bytes]] = None
Expand Down
7 changes: 2 additions & 5 deletions src/ert/gui/tools/plugins/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,7 @@ def __init__(self, notifier: "ErtNotifier", workflow_job: "WorkflowJob"):

def __loadPlugin(self) -> "ErtPlugin":
script_obj = ErtScript.loadScriptFromFile(self.__workflow_job.script)
script = script_obj(
self.__notifier._storage,
ensemble=self.__notifier.current_ensemble,
)
script = script_obj()
return script

def getName(self) -> str:
Expand All @@ -39,7 +36,7 @@ def getArguments(self) -> List[Any]:
If the user for example cancels in the GUI a CancelPluginException is raised.
"""
script = self.__loadPlugin()
return script.getArguments(self.__parent_window)
return script.getArguments(self.__parent_window, self.__notifier.storage)

def setParentWindow(self, parent_window):
self.__parent_window = parent_window
Expand Down
11 changes: 8 additions & 3 deletions src/ert/gui/tools/plugins/plugin_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@


class PluginRunner:
def __init__(self, plugin: "Plugin", ert_config: ErtConfig):
def __init__(self, plugin: "Plugin", ert_config: ErtConfig, storage):
super().__init__()
self.ert_config = ert_config
self.storage = storage
self.__plugin = plugin

self.__plugin_finished_callback = lambda: None
Expand All @@ -36,11 +37,15 @@ def run(self):
dialog.setObjectName("process_job_dialog")

dialog.cancelConfirmed.connect(self.cancel)

fixtures = {
k: getattr(self, k)
for k in ["storage", "ert_config"]
if getattr(self, k)
}
workflow_job_thread = ErtThread(
name="ert_gui_workflow_job_thread",
target=self.__runWorkflowJob,
args=(plugin, arguments, {"ert_config": self.ert_config}),
args=(plugin, arguments, fixtures),
daemon=True,
should_raise=False,
)
Expand Down
2 changes: 1 addition & 1 deletion src/ert/gui/tools/plugins/plugins_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def __init__(self, plugin_handler: "PluginHandler", notifier, ert_config):

menu = QMenu()
for plugin in plugin_handler:
plugin_runner = PluginRunner(plugin, ert_config)
plugin_runner = PluginRunner(plugin, ert_config, notifier.storage)
plugin_runner.setPluginFinishedCallback(self.trigger)

self.__plugins[plugin] = plugin_runner
Expand Down
1 change: 1 addition & 0 deletions src/ert/gui/tools/workflows/run_workflow_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ def startWorkflow(self):
workflow,
storage=self.storage,
ensemble=self.source_ensemble_selector.currentData(),
ert_config=self.config,
)
self._workflow_runner.run()

Expand Down
20 changes: 12 additions & 8 deletions src/ert/job_queue/workflow_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from typing_extensions import Self

from ert.config import ErtScript, ExternalErtScript, Workflow, WorkflowJob
from ert.config import ErtConfig, ErtScript, ExternalErtScript, Workflow, WorkflowJob

if TYPE_CHECKING:
from ert.storage import Ensemble, Storage
Expand All @@ -22,8 +22,6 @@ def __init__(self, workflow_job: WorkflowJob):

def run(
self,
storage: Optional[Storage] = None,
ensemble: Optional[Ensemble] = None,
arguments: Optional[List[Any]] = None,
fixtures: Optional[Dict[str, Any]] = None,
) -> Any:
Expand All @@ -44,15 +42,14 @@ def run(
)

if self.job.ert_script is not None:
self.__script = self.job.ert_script(storage, ensemble)
self.__script = self.job.ert_script()
if self.job.stop_on_fail is not None:
self.stop_on_fail = self.job.stop_on_fail
elif self.__script is not None:
self.stop_on_fail = self.__script.stop_on_fail or False

elif not self.job.internal:
self.__script = ExternalErtScript(
storage, # type: ignore
self.job.executable, # type: ignore
)

Expand Down Expand Up @@ -114,10 +111,12 @@ def __init__(
workflow: Workflow,
storage: Optional[Storage] = None,
ensemble: Optional[Ensemble] = None,
ert_config: Optional[ErtConfig] = None,
) -> None:
self.__workflow = workflow
self._storage = storage
self._ensemble = ensemble
self.storage = storage
self.ensemble = ensemble
self.ert_config = ert_config

self.__workflow_result: Optional[bool] = None
self._workflow_executor = futures.ThreadPoolExecutor(max_workers=1)
Expand Down Expand Up @@ -153,13 +152,18 @@ def run_blocking(self) -> None:
# Reset status
self.__status = {}
self.__running = True
fixtures = {
k: getattr(self, k)
for k in ["storage", "ensemble", "ert_config"]
if getattr(self, k)
}

for job, args in self.__workflow:
jobrunner = WorkflowJobRunner(job)
self.__current_job = jobrunner
if not self.__cancelled:
logger.info(f"Workflow job {jobrunner.name} starting")
jobrunner.run(self._storage, self._ensemble, args)
jobrunner.run(args, fixtures=fixtures)
self.__status[jobrunner.name] = {
"stdout": jobrunner.stdoutdata(),
"stderr": jobrunner.stderrdata(),
Expand Down
4 changes: 3 additions & 1 deletion src/ert/run_models/base_run_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,7 +496,9 @@ def validate(self) -> None:

def run_workflows(self, runtime, storage, ensemble):

Check failure on line 497 in src/ert/run_models/base_run_model.py

View workflow job for this annotation

GitHub Actions / type-checking (3.12)

Function is missing a type annotation
for workflow in self.ert_config.hooked_workflows[runtime]:
WorkflowRunner(workflow, storage, ensemble).run_blocking()
WorkflowRunner(
workflow, storage, ensemble, ert_config=self.ert_config
).run_blocking()

def _evaluate_and_postprocess(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,13 +83,14 @@ def inferIterationNumber(self, ensemble_name):
def run(
self,
output_file,
ert_config,
ensemble_list=None,
design_matrix_path=None,
infer_iteration=True,
drop_const_cols=False,
):
ensembles = []
facade = LibresFacade(self.ert())
facade = LibresFacade(ert_config)

if ensemble_list is not None:
if ensemble_list.strip() == "*":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ def run(
self,
output_file,
trajectory_path,
storage,
ensemble_list=None,
infer_iteration=True,
drop_const_cols=False,
Expand Down Expand Up @@ -135,7 +136,7 @@ def run(
ensemble_data = []

try:
ensemble = self.storage.get_ensemble_by_name(ensemble_name)
ensemble = storage.get_ensemble_by_name(ensemble_name)
except KeyError as exc:
raise UserWarning(
f"The ensemble '{ensemble_name}' does not exist!"
Expand Down Expand Up @@ -228,7 +229,7 @@ def run(
)
return export_info

def getArguments(self, parent=None):
def getArguments(self, parent, storage):
description = (
"The GEN_DATA RFT CSV export requires some information before it starts:"
)
Expand All @@ -243,7 +244,7 @@ def getArguments(self, parent=None):
trajectory_chooser = PathChooser(trajectory_model)
trajectory_chooser.setObjectName("trajectory_chooser")

all_ensemble_list = [ensemble.name for ensemble in self.storage.ensembles]
all_ensemble_list = [ensemble.name for ensemble in storage.ensembles]
list_edit = ListEditBox(all_ensemble_list)
list_edit.setObjectName("list_of_ensembles")

Expand Down

0 comments on commit 9489523

Please sign in to comment.