Skip to content

Commit

Permalink
Add support for MEMORY_PER_JOB in OpenPBS
Browse files Browse the repository at this point in the history
  • Loading branch information
berland committed Feb 20, 2024
1 parent d35aa94 commit 1e6fea5
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 9 deletions.
16 changes: 9 additions & 7 deletions src/ert/scheduler/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING

from ert.config.parsing.queue_system import QueueSystem
from ert.scheduler.driver import Driver
Expand All @@ -17,12 +17,14 @@ def create_driver(config: QueueConfig) -> Driver:
if config.queue_system == QueueSystem.LOCAL:
return LocalDriver()
elif config.queue_system == QueueSystem.TORQUE:
queue_name: Optional[str] = None
for key, val in config.queue_options.get(QueueSystem.TORQUE, []):
if key == "QUEUE":
queue_name = val

return OpenPBSDriver(queue_name=queue_name)
queue_config = {
key: value
for key, value in config.queue_options.get(QueueSystem.TORQUE, [])
}
return OpenPBSDriver(
queue_name=queue_config.get("QUEUE"),
memory_per_job=queue_config.get("MEMORY_PER_JOB"),
)
elif config.queue_system == QueueSystem.LSF:
queue_config = {
key: value for key, value in config.queue_options.get(QueueSystem.LSF, [])
Expand Down
15 changes: 14 additions & 1 deletion src/ert/scheduler/openpbs_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,23 @@ class _Stat(BaseModel):
class OpenPBSDriver(Driver):
"""Driver targetting OpenPBS (https://github.com/openpbs/openpbs) / PBS Pro"""

def __init__(self, *, queue_name: Optional[str] = None) -> None:
def __init__(
self, *, queue_name: Optional[str] = None, memory_per_job: Optional[str] = None
) -> None:
super().__init__()

self._queue_name = queue_name
self._memory_per_job = memory_per_job

self._jobs: MutableMapping[str, Tuple[int, JobState]] = {}
self._iens2jobid: MutableMapping[int, str] = {}

def _resource_string(self) -> str:
resource_specifiers: List[str] = []
if self._memory_per_job is not None:
resource_specifiers += ["mem=" + self._memory_per_job]
return ":".join(resource_specifiers)

async def submit(
self,
iens: int,
Expand All @@ -61,13 +71,16 @@ async def submit(
) -> None:

arg_queue_name = ["-q", self._queue_name] if self._queue_name else []
resource_string = self._resource_string()
arg_resource_string = ["-l", resource_string] if resource_string else []

qsub_with_args: List[str] = [
"qsub",
"-koe", # Discard stdout/stderr of job
"-rn", # Don't restart on failure
f"-N{name}", # Set name of job
*arg_queue_name,
*arg_resource_string,
"--",
executable,
*args,
Expand Down
4 changes: 3 additions & 1 deletion tests/integration_tests/scheduler/bin/qsub
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ set -e

name="STDIN"

while getopts "N:r:k:" opt
while getopts "N:r:k:l:" opt
do
case "$opt" in
N)
Expand All @@ -13,6 +13,8 @@ do
;;
k)
;;
l)
;;
*)
echo "Unprocessed option ${opt}"
;;
Expand Down
42 changes: 42 additions & 0 deletions tests/unit_tests/scheduler/test_openpbs_driver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import os
import stat
from pathlib import Path

import pytest

from ert.scheduler import OpenPBSDriver


@pytest.fixture
def capturing_qsub(monkeypatch, tmp_path):
os.chdir(tmp_path)
bin_path = tmp_path / "bin"
bin_path.mkdir()
monkeypatch.setenv("PATH", f"{bin_path}:{os.environ['PATH']}")
qsub_path = bin_path / "qsub"
qsub_path.write_text(
"#!/bin/sh\necho $@ > captured_qsub_args; echo 'Job <1>'", encoding="utf-8"
)
qsub_path.chmod(qsub_path.stat().st_mode | stat.S_IEXEC)


@pytest.mark.usefixtures("capturing_qsub")
async def test_memory_per_job():
driver = OpenPBSDriver(memory_per_job="10gb")
await driver.submit(0, "sleep")
assert " -l mem=10gb " in Path("captured_qsub_args").read_text(encoding="utf-8")


@pytest.mark.usefixtures("capturing_qsub")
async def test_no_default_memory_per_job():
driver = OpenPBSDriver()
await driver.submit(0, "sleep")
assert " -l " not in Path("captured_qsub_args").read_text(encoding="utf-8")


@pytest.mark.usefixtures("capturing_qsub")
async def test_no_validation_of_memory_per_job():
# Validation will happen during config parsing
driver = OpenPBSDriver(memory_per_job="a_lot")
await driver.submit(0, "sleep")
assert " -l mem=a_lot " in Path("captured_qsub_args").read_text(encoding="utf-8")

0 comments on commit 1e6fea5

Please sign in to comment.