From aa953ab04e949bcf751440a81ffbc82a9a5bd31a Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Tue, 23 Jan 2024 11:36:45 -0500 Subject: [PATCH] Setup a SLURM cluster in the GitHub CI for integration tests [MT-34] (#84) * Test out GH Action to setup a fake SLURM cluster Signed-off-by: Fabrice Normandin * Change the scope to also run on PRs Signed-off-by: Fabrice Normandin * Change the command to use `srun (...) hostname` Signed-off-by: Fabrice Normandin * Test out running tests that call srun over ssh Signed-off-by: Fabrice Normandin * Use `poetry run pytest` instead of `pytest` Signed-off-by: Fabrice Normandin * Try to test the `ensure_allocation` method Signed-off-by: Fabrice Normandin * Simplify to avoid hanging on test setup Signed-off-by: Fabrice Normandin * Skip making a Connection (hopefully fixes hang) Signed-off-by: Fabrice Normandin * Try using a custom version of setup-slurm action Signed-off-by: Fabrice Normandin * Rename custom action file Signed-off-by: Fabrice Normandin * Try to fix the path to the custom action file Signed-off-by: Fabrice Normandin * Fix role number in custom action file Signed-off-by: Fabrice Normandin * Only mark one partition with Default: YES Signed-off-by: Fabrice Normandin * Only have `localhost` as a node Signed-off-by: Fabrice Normandin * Re-simplify test to check that slurm works Signed-off-by: Fabrice Normandin * Put the slurm playbook in a file Signed-off-by: Fabrice Normandin * Add main and unkillable partitions Signed-off-by: Fabrice Normandin * Trying to add tests using the local SLURM cluster Signed-off-by: Fabrice Normandin * Add `in_stream=False` to `run` and `simple_run` Signed-off-by: Fabrice Normandin * Simplify tests: greatly reduce need for -s flag Signed-off-by: Fabrice Normandin * `SlurmRemote.ensure_allocation` test works on Mila Signed-off-by: Fabrice Normandin * Try to make tests timeout instead of hang in CI Signed-off-by: Fabrice Normandin * Make slurm tests the integration tests in build Signed-off-by: Fabrice Normandin * Skip some tests for now to debug the CI issues Signed-off-by: Fabrice Normandin * Only run integration tests with slurm on linux :( Signed-off-by: Fabrice Normandin * Debugging hanging integration test Signed-off-by: Fabrice Normandin * Test if hanging test is due to nested sallocs Signed-off-by: Fabrice Normandin * Skip tests that use salloc/sbatch in GitHub CI :( Signed-off-by: Fabrice Normandin * Minor tying/docstring improvements to Remote class Signed-off-by: Fabrice Normandin * Add some tests for SlurmRemote.run and such Signed-off-by: Fabrice Normandin * Don't actually extract jobid from salloc for now Signed-off-by: Fabrice Normandin * Add sleeps so sacct can update to show recent jobs Signed-off-by: Fabrice Normandin * Mark tests that cause a hang in GitHub CI Signed-off-by: Fabrice Normandin * Add timeout of 3 minutes to integration tests step Signed-off-by: Fabrice Normandin * Remove check that fails in GitHub CI Signed-off-by: Fabrice Normandin * Update tests/cli/test_slurm_remote.py Co-authored-by: satyaog --------- Signed-off-by: Fabrice Normandin Signed-off-by: Fabrice Normandin Co-authored-by: satyaog --- .github/custom_setup_slurm_action/action.yml | 54 +++ .../slurm-playbook.yml | 74 ++++ .github/workflows/build.yml | 82 +++- milatools/cli/remote.py | 81 ++-- tests/cli/common.py | 2 - tests/cli/conftest.py | 8 +- tests/cli/test_remote.py | 17 +- tests/cli/test_slurm_remote.py | 394 ++++++++++++++++++ 8 files changed, 674 insertions(+), 38 deletions(-) create mode 100644 
.github/custom_setup_slurm_action/action.yml create mode 100644 .github/custom_setup_slurm_action/slurm-playbook.yml create mode 100644 tests/cli/test_slurm_remote.py diff --git a/.github/custom_setup_slurm_action/action.yml b/.github/custom_setup_slurm_action/action.yml new file mode 100644 index 00000000..f16e4665 --- /dev/null +++ b/.github/custom_setup_slurm_action/action.yml @@ -0,0 +1,54 @@ +name: "setup-slurm-action" +description: "Setup slurm cluster on GitHub Actions using https://github.com/galaxyproject/ansible-slurm" +branding: + icon: arrow-down-circle + color: blue +runs: + using: "composite" + steps: + # prior to slurm-setup we need the podmand-correct command + # see https://github.com/containers/podman/issues/13338 + - name: Download slurm ansible roles + shell: bash -e {0} + # ansible-galaxy role install https://github.com/galaxyproject/ansible-slurm/archive/1.0.1.tar.gz + run: | + ansible-galaxy role install https://github.com/mila-iqia/ansible-slurm/archive/1.1.2.tar.gz + + - name: Apt prerequisites + shell: bash -e {0} + run: | + sudo apt-get update + sudo apt-get install retry + + - name: Set XDG_RUNTIME_DIR + shell: bash -e {0} + run: | + mkdir -p /tmp/1002-runtime # work around podman issue (https://github.com/containers/podman/issues/13338) + echo XDG_RUNTIME_DIR=/tmp/1002-runtime >> $GITHUB_ENV + + - name: Setup slurm + shell: bash -e {0} + run: | + ansible-playbook ./.github/custom_setup_slurm_action/slurm-playbook.yml || (journalctl -xe && exit 1) + + - name: Add Slurm Account + shell: bash -e {0} + run: | + sudo retry --until=success -- sacctmgr -i create account "Name=runner" + sudo sacctmgr -i create user "Name=runner" "Account=runner" + + - name: Test srun submission + shell: bash -e {0} + run: | + srun -vvvv echo "hello world" + sudo cat /var/log/slurm/slurmd.log + + - name: Show partition info + shell: bash -e {0} + run: | + scontrol show partition + + - name: Test sbatch submission + shell: bash -e {0} + run: | + sbatch -vvvv -N 1 --mem 5 --wrap "echo 'hello world'" diff --git a/.github/custom_setup_slurm_action/slurm-playbook.yml b/.github/custom_setup_slurm_action/slurm-playbook.yml new file mode 100644 index 00000000..3b87a135 --- /dev/null +++ b/.github/custom_setup_slurm_action/slurm-playbook.yml @@ -0,0 +1,74 @@ +- name: Slurm all in One + hosts: localhost + roles: + - role: 1.1.2 + become: true + vars: + slurm_upgrade: true + slurm_roles: ["controller", "exec", "dbd"] + slurm_config_dir: /etc/slurm + slurm_config: + ClusterName: cluster + SlurmctldLogFile: /var/log/slurm/slurmctld.log + SlurmctldPidFile: /run/slurmctld.pid + SlurmdLogFile: /var/log/slurm/slurmd.log + SlurmdPidFile: /run/slurmd.pid + SlurmdSpoolDir: /tmp/slurmd # the default /var/lib/slurm/slurmd does not work because of noexec mounting in github actions + StateSaveLocation: /var/lib/slurm/slurmctld + AccountingStorageType: accounting_storage/slurmdbd + SelectType: select/cons_res + slurmdbd_config: + StorageType: accounting_storage/mysql + PidFile: /run/slurmdbd.pid + LogFile: /var/log/slurm/slurmdbd.log + StoragePass: root + StorageUser: root + StorageHost: 127.0.0.1 # see https://stackoverflow.com/questions/58222386/github-actions-using-mysql-service-throws-access-denied-for-user-rootlocalh + StoragePort: 8888 + DbdHost: localhost + slurm_create_user: yes + #slurm_munge_key: "../../../munge.key" + slurm_nodes: + - name: localhost + State: UNKNOWN + Sockets: 1 + CoresPerSocket: 2 + RealMemory: 2000 + # - name: cn-a[001-011] + # NodeAddr: localhost + # Gres: gpu:rtx8000:8 + # CPUs: 
40 + # Boards: 1 + # SocketsPerBoard: 2 + # CoresPerSocket: 20 + # ThreadsPerCore: 1 + # RealMemory: 386618 + # TmpDisk: 3600000 + # State: UNKNOWN + # Feature: x86_64,turing,48gb + # - name: "cn-c[001-010]" + # CoresPerSocket: 18 + # Gres: "gpu:rtx8000:8" + # Sockets: 2 + # ThreadsPerCore: 2 + slurm_partitions: + - name: long + Default: YES + MaxTime: UNLIMITED + Nodes: "localhost" + - name: main + Default: NO + MaxTime: UNLIMITED + Nodes: "localhost" + - name: unkillable + Default: NO + MaxTime: UNLIMITED + Nodes: "localhost" + slurm_user: + comment: "Slurm Workload Manager" + gid: 1002 + group: slurm + home: "/var/lib/slurm" + name: slurm + shell: "/bin/bash" + uid: 1002 diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 1f05ceb0..f1aaf146 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -3,8 +3,8 @@ name: Python package on: [push, pull_request] jobs: - pre-commit: - name: Run pre-commit checks + linting: + name: Run linting/pre-commit checks runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 @@ -16,8 +16,8 @@ jobs: - run: pre-commit install - run: pre-commit run --all-files - test: - needs: [pre-commit] + unit-tests: + needs: [linting] runs-on: ${{ matrix.platform }} strategy: max-parallel: 4 @@ -70,3 +70,77 @@ jobs: env_vars: PLATFORM,PYTHON name: codecov-umbrella fail_ci_if_error: false + + integration-tests: + name: integration tests + needs: [unit-tests] + runs-on: ${{ matrix.platform }} + + strategy: + max-parallel: 5 + matrix: + # TODO: We should ideally also run this with Windows/Mac clients and a Linux + # server. Unsure how to set that up with GitHub Actions though. + platform: [ubuntu-latest] + python-version: ['3.7', '3.8', '3.9', '3.10', '3.11'] + + # For the action to work, you have to supply a mysql + # service as defined below. 
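+    # NOTE: the port mapping and credentials below should line up with the
+    # `slurmdbd_config` section of .github/custom_setup_slurm_action/slurm-playbook.yml
+    # (StorageHost: 127.0.0.1, StoragePort: 8888, StorageUser/StoragePass: root),
+    # since slurmdbd connects to this MySQL service for job accounting.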
+ services: + mysql: + image: mysql:8.0 + env: + MYSQL_ROOT_PASSWORD: root + ports: + - "8888:3306" + options: --health-cmd="mysqladmin ping" --health-interval=10s --health-timeout=5s --health-retries=3 + + steps: + - uses: actions/checkout@v3 + + # NOTE: Replacing this with our customized version of + # - uses: koesterlab/setup-slurm-action@v1 + - uses: ./.github/custom_setup_slurm_action + + - name: Test if the slurm cluster is setup correctly + run: srun --nodes=1 --ntasks=1 --cpus-per-task=1 --mem=1G --time=00:01:00 hostname + + - name: Setup passwordless SSH access to localhost for tests + # Adapted from https://stackoverflow.com/a/60367309/6388696 + run: | + ssh-keygen -t ed25519 -f ~/.ssh/testkey -N '' + cat > ~/.ssh/config < ~/.ssh/authorized_keys + chmod og-rw ~ + ssh -o 'StrictHostKeyChecking no' localhost id + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v3 + with: + python-version: ${{ matrix.python-version }} + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install poetry + poetry install --with=dev + + - name: Launch integration tests + run: poetry run pytest tests/cli/test_slurm_remote.py --cov=milatools --cov-report=xml --cov-append -s -vvv --log-level=DEBUG + timeout-minutes: 3 + env: + SLURM_CLUSTER: localhost + + - name: Upload coverage reports to Codecov + uses: codecov/codecov-action@v3 + with: + file: ./coverage.xml + flags: integrationtests + env_vars: PLATFORM,PYTHON + name: codecov-umbrella + fail_ci_if_error: false diff --git a/milatools/cli/remote.py b/milatools/cli/remote.py index c17419ab..1b546896 100644 --- a/milatools/cli/remote.py +++ b/milatools/cli/remote.py @@ -146,6 +146,7 @@ def display(self, cmd: str) -> None: def _run( self, cmd: str, + *, hide: Literal[True, False, "out", "stdout", "err", "stderr"] = False, warn: bool = False, asynchronous: bool = False, @@ -170,43 +171,49 @@ def _run( ) def simple_run(self, cmd: str): - return self._run(cmd, hide=True) + return self._run(cmd, hide=True, in_stream=False) @overload def run( self, cmd: str, + *, display: bool | None = None, hide: bool = False, warn: bool = False, - asynchronous: Literal[True] = True, + asynchronous: Literal[False] = False, out_stream: TextIO | None = None, + in_stream: TextIO | bool = False, **kwargs, - ) -> invoke.runners.Promise: + ) -> invoke.runners.Result: ... @overload def run( self, cmd: str, + *, display: bool | None = None, hide: bool = False, warn: bool = False, - asynchronous: Literal[False] = False, + asynchronous: Literal[True] = True, out_stream: TextIO | None = None, + in_stream: TextIO | bool = False, **kwargs, - ) -> invoke.runners.Result: + ) -> invoke.runners.Promise: ... @overload def run( self, cmd: str, + *, display: bool | None = None, hide: bool = False, warn: bool = False, - asynchronous: bool = False, + asynchronous: bool = ..., out_stream: TextIO | None = None, + in_stream: TextIO | bool = False, **kwargs, ) -> invoke.runners.Result | invoke.runners.Promise: ... @@ -219,6 +226,7 @@ def run( warn: bool = False, asynchronous: bool = False, out_stream: TextIO | None = None, + in_stream: TextIO | bool = False, **kwargs, ) -> invoke.runners.Promise | invoke.runners.Result: """Run a command on the remote host, returning the `invoke.Result`. @@ -233,7 +241,7 @@ def run( Parameters ---------- cmd: The command to run - display: TODO: add a description of what this argument does. + display: Displays the host name and the command in colour before running it. 
hide: ``'out'`` (or ``'stdout'``) to hide only the stdout stream, \ ``hide='err'`` (or ``'stderr'``) to hide only stderr, or ``hide='both'`` \ (or ``True``) to hide both streams. @@ -262,6 +270,7 @@ def run( warn=warn, asynchronous=asynchronous, out_stream=out_stream, + in_stream=in_stream, **kwargs, ) @@ -277,6 +286,7 @@ def get_output( display=display, hide=hide, warn=warn, + in_stream=False, ).stdout.strip() @deprecated( @@ -305,7 +315,7 @@ def extract( pty: bool = True, hide: bool = False, **kwargs, - ) -> tuple[fabric.runners.Runner, dict[str, str]]: + ) -> tuple[fabric.runners.Remote, dict[str, str]]: # TODO: We pass this `QueueIO` class to `connection.run`, which expects a # file-like object and defaults to sys.stdout (a TextIO). However they only use # the `write` and `flush` methods, which means that this QueueIO is actually @@ -314,7 +324,7 @@ def extract( # reading from it, and pass that `io.StringIO` buffer to `self.run`. qio: TextIO = QueueIO() - proc = self.run( + promise = self.run( cmd, hide=hide, asynchronous=True, @@ -322,9 +332,12 @@ def extract( pty=pty, **kwargs, ) + runner = promise.runner + assert isinstance(runner, fabric.runners.Remote) + results: dict[str, str] = {} try: - for line in qio.readlines(lambda: proc.runner.process_is_finished): + for line in qio.readlines(lambda: runner.process_is_finished): print(line, end="") for name, patt in list(patterns.items()): m = re.search(patt, line) @@ -332,21 +345,21 @@ def extract( results[name] = m.groups()[0] patterns.pop(name) if not patterns and not wait: - return proc.runner, results + return runner, results # Check what the job id is when we sbatch m = re.search("^Submitted batch job ([0-9]+)", line) if m: results["batch_id"] = m.groups()[0] except KeyboardInterrupt: - proc.runner.kill() + runner.kill() if "batch_id" in results: # We need to preemptively cancel the job so that it doesn't # clutter the user's squeue when they Ctrl+C self.simple_run(f"scancel {results['batch_id']}") raise - proc.join() - return proc.runner, results + promise.join() + return runner, results def get(self, src: str, dest: str | None) -> fabric.transfer.Result: return self.connection.get(src, dest) @@ -457,14 +470,23 @@ def persist(self): def ensure_allocation( self, - ) -> tuple[NodeNameDict | NodeNameAndJobidDict, invoke.runners.Runner]: - """Requests a compute node from the cluster if not already allocated. - - Returns a dictionary with the `node_name`, and additionally the `jobid` if this - Remote is already connected to a compute node. + ) -> tuple[NodeNameDict | NodeNameAndJobidDict, fabric.runners.Remote]: + """Makes the `salloc` or `sbatch` call and waits until the job starts. + + When `persist=True`: + - uses `sbatch` + - returns a tuple with: + - a dict with the compute node name and the jobid + - a `fabric.runners.Remote` object connected to the *login* node. + + When `persist=False`: + - uses `salloc` + - returns a tuple with: + - a dict with the compute node name (without the jobid) + - a `fabric.runners.Remote` object connected to the *login* node. 
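+
+        A rough usage sketch (the allocation flags are just an example, and
+        `connection` is an existing `fabric.Connection` to a login node):
+
+        ```python
+        remote = SlurmRemote(connection=connection, alloc=["--time=00:01:00"])
+        results, login_node_runner = remote.ensure_allocation()
+        node_name = results["node_name"]  # e.g. "cn-c001"
+        ```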
""" if self._persist: - proc, results = self.extract( + login_node_runner, results = self.extract( "echo @@@ $(hostname) @@@ && sleep 1000d", patterns={ "node_name": "@@@ ([^ ]+) @@@", @@ -473,15 +495,26 @@ def ensure_allocation( hide=True, ) node_name = get_first_node_name(results["node_name"]) - return {"node_name": node_name, "jobid": results["jobid"]}, proc + return { + "node_name": node_name, + "jobid": results["jobid"], + }, login_node_runner else: remote = Remote(hostname="->", connection=self.connection).with_bash() - proc, results = remote.extract( + login_node_runner, results = remote.extract( shjoin(["salloc", *self.alloc]), - patterns={"node_name": "salloc: Nodes ([^ ]+) are ready for job"}, + patterns={ + "node_name": "salloc: Nodes ([^ ]+) are ready for job", + # TODO: This would also work! + # "jobid": "salloc: Granted job allocation ([0-9]+)", + }, ) # The node name can look like 'cn-c001', or 'cn-c[001-003]', or # 'cn-c[001,008]', or 'cn-c001,rtx8', etc. We will only connect to # a single one, though, so we will simply pick the first one. node_name = get_first_node_name(results["node_name"]) - return {"node_name": node_name}, proc + return { + "node_name": node_name, + # TODO: This would also work! + # "jobid": results["jobid"], + }, login_node_runner diff --git a/tests/cli/common.py b/tests/cli/common.py index de8eb31d..479b3c06 100644 --- a/tests/cli/common.py +++ b/tests/cli/common.py @@ -54,8 +54,6 @@ "-s" not in sys.argv, reason=REQUIRES_S_FLAG_REASON, ) - - requires_no_s_flag = pytest.mark.skipif( "-s" in sys.argv, reason="Passing pytest's -s flag makes this test fail.", diff --git a/tests/cli/conftest.py b/tests/cli/conftest.py index 1914600b..1bfb1bee 100644 --- a/tests/cli/conftest.py +++ b/tests/cli/conftest.py @@ -1,12 +1,16 @@ from __future__ import annotations +import os + import pytest from .common import REQUIRES_S_FLAG_REASON +in_github_ci = "PLATFORM" in os.environ + -@pytest.fixture(autouse=True) -def skip_if_s_flag_passed_and_test_doesnt_require_it( +@pytest.fixture(autouse=in_github_ci) +def skip_if_s_flag_passed_during_ci_run_and_test_doesnt_require_it( request: pytest.FixtureRequest, pytestconfig: pytest.Config ): capture_value = pytestconfig.getoption("-s") diff --git a/tests/cli/test_remote.py b/tests/cli/test_remote.py index 3d4884fa..a71d48aa 100644 --- a/tests/cli/test_remote.py +++ b/tests/cli/test_remote.py @@ -143,7 +143,6 @@ def test_init_with_connection( # Note: We could actually run this for real also! 
-@requires_s_flag @pytest.mark.parametrize("command_to_run", ["echo OK"]) @pytest.mark.parametrize("initial_transforms", [[]]) @pytest.mark.parametrize( @@ -278,7 +277,6 @@ def hide( assert output.err == "" -@requires_s_flag @pytest.mark.parametrize(("command", "expected_output"), [("echo OK", "OK")]) @pytest.mark.parametrize("asynchronous", [True, False]) @pytest.mark.parametrize("warn", [True, False]) @@ -320,6 +318,7 @@ def test_run( hide=hide, warn=warn, out_stream=None, + in_stream=False, ) remote.connection.local.assert_not_called() @@ -366,6 +365,7 @@ def test_get_output( hide=hide, warn=warn, out_stream=None, + in_stream=False, ) mock_result.stdout.strip.assert_called_once_with() @@ -471,7 +471,6 @@ def test_put(remote: Remote, tmp_path: Path): assert dest.read_text() == source_content -@requires_s_flag def test_puttext(remote: Remote, tmp_path: Path): _xfail_if_not_on_localhost(remote.hostname) dest_dir = tmp_path / "bar/baz" @@ -484,17 +483,22 @@ def test_puttext(remote: Remote, tmp_path: Path): out_stream=None, hide=True, warn=False, + in_stream=False, ) # The first argument of `put` will be the name of a temporary file. remote.connection.put.assert_called_once_with(unittest.mock.ANY, str(dest)) assert dest.read_text() == some_text -@requires_s_flag def test_home(remote: Remote): home_dir = remote.home() remote.connection.run.assert_called_once_with( - "echo $HOME", asynchronous=False, out_stream=None, warn=False, hide=True + "echo $HOME", + asynchronous=False, + out_stream=None, + warn=False, + hide=True, + in_stream=False, ) remote.connection.local.assert_not_called() if remote.hostname == "mila": @@ -550,7 +554,6 @@ def test_srun_transform(self, mock_connection: Connection): command = "bob" assert remote.srun_transform(command) == f"srun {alloc[0]} bash -c {command}" - @requires_s_flag def test_srun_transform_persist( self, mock_connection: Connection, @@ -724,6 +727,7 @@ def write_stuff( pty: bool, out_stream: QueueIO, warn: bool, + in_stream: bool, ): assert command == expected_command out_stream.write(f"salloc: Nodes {node} are ready for job") @@ -739,6 +743,7 @@ def write_stuff( asynchronous=True, out_stream=unittest.mock.ANY, pty=True, + in_stream=False, ) assert results == {"node_name": node} diff --git a/tests/cli/test_slurm_remote.py b/tests/cli/test_slurm_remote.py new file mode 100644 index 00000000..30e35b78 --- /dev/null +++ b/tests/cli/test_slurm_remote.py @@ -0,0 +1,394 @@ +"""Tests that use an actual SLURM cluster. + +The cluster to use can be specified by setting the SLURM_CLUSTER environment variable. +During the CI on GitHub, a small local slurm cluster is setup with a GitHub Action, and +SLURM_CLUSTER is set to "localhost". +""" +from __future__ import annotations + +import datetime +import functools +import os +import time +from logging import getLogger as get_logger + +import fabric.runners +import pytest + +from milatools.cli.remote import Remote, SlurmRemote + +logger = get_logger(__name__) + +SLURM_CLUSTER = os.environ.get("SLURM_CLUSTER") +JOB_NAME = "milatools_test" +WCKEY = "milatools_test" +MAX_JOB_DURATION = datetime.timedelta(seconds=10) +# BUG: pytest-timeout seems to cause issues with paramiko threads.. 
+# pytestmark = pytest.mark.timeout(60) + + +_SACCT_UPDATE_DELAY = datetime.timedelta(seconds=10) +"""How long after salloc/sbatch before we expect to see the job show up in sacct.""" + +requires_access_to_slurm_cluster = pytest.mark.skipif( + not SLURM_CLUSTER, + reason="Requires ssh access to a SLURM cluster.", +) + +# TODO: Import the value from `milatools.cli.utils` once the other PR adds it. +CLUSTERS = ["mila", "narval", "cedar", "beluga", "graham"] + + +@pytest.fixture(scope="session", params=[SLURM_CLUSTER]) +def cluster(request: pytest.FixtureRequest) -> str: + """Fixture that gives the hostname of the slurm cluster to use for tests. + + NOTE: The `cluster` can also be parametrized indirectly by tests, for example: + + ```python + @pytest.mark.parametrize("cluster", ["mila", "some_cluster"], indirect=True) + def test_something(remote: Remote): + ... # here the remote is connected to the cluster specified above! + ``` + """ + slurm_cluster_hostname = request.param + + if not slurm_cluster_hostname: + pytest.skip("Requires ssh access to a SLURM cluster.") + return slurm_cluster_hostname + + +def can_run_on_all_clusters(): + """Makes a given test run on all the clusters in `CLUSTERS`, *for real*! + + NOTE: (@lebrice): Unused here at the moment in the GitHub CI, but locally I'm + enabling it sometimes to test stuff on DRAC clusters. + """ + return pytest.mark.parametrize("cluster", CLUSTERS, indirect=True) + + +@pytest.fixture() +def login_node(cluster: str) -> Remote: + """Fixture that gives a Remote connected to the login node of the slurm cluster. + + NOTE: Making this a function-scoped fixture because the Connection object is of the + Remote is used when creating the SlurmRemotes. + """ + return Remote(cluster) + + +@pytest.fixture(scope="module", autouse=True) +def cancel_all_milatools_jobs_before_and_after_tests(cluster: str): + # Note: need to recreate this because login_node is a function-scoped fixture. + login_node = Remote(cluster) + login_node.run(f"scancel -u $USER --wckey={WCKEY}") + time.sleep(1) + yield + login_node.run(f"scancel -u $USER --wckey={WCKEY}") + time.sleep(1) + # Display the output of squeue just to be sure that the jobs were cancelled. + login_node._run("squeue --me", echo=True, in_stream=False) + + +@functools.lru_cache() +def get_slurm_account(cluster: str) -> str: + """Gets the SLURM account of the user using sacctmgr on the slurm cluster. + + When there are multiple accounts, this selects the first account, alphabetically. + + On DRAC cluster, this uses the `def` allocations instead of `rrg`, and when + the rest of the accounts are the same up to a '_cpu' or '_gpu' suffix, it uses + '_cpu'. + + For example: + + ```text + def-someprofessor_cpu <-- this one is used. + def-someprofessor_gpu + rrg-someprofessor_cpu + rrg-someprofessor_gpu + ``` + """ + # note: recreating the Connection here because this will be called for every test + # and we use functools.cache to cache the result, so the input has to be a simpler + # value like a string. 
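+    # (Because `lru_cache` keys on the `cluster` string, the `sacctmgr` call below
+    # runs at most once per cluster for the whole test session.)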
+ result = fabric.Connection(cluster).run( + "sacctmgr --noheader show associations where user=$USER format=Account%50", + echo=True, + in_stream=False, + ) + assert isinstance(result, fabric.runners.Result) + accounts: list[str] = [line.strip() for line in result.stdout.splitlines()] + assert accounts + logger.info(f"Accounts on the slurm cluster {cluster}: {accounts}") + account = sorted(accounts)[0] + logger.info(f"Using account {account} to launch jobs in tests.") + return account + + +def get_recent_jobs_info( + login_node: Remote, + since=datetime.timedelta(minutes=5), + fields=("JobID", "JobName", "Node", "State"), +) -> list[tuple[str, ...]]: + """Returns a list of fields for jobs that started recently.""" + # otherwise this would launch a job! + assert not isinstance(login_node, SlurmRemote) + lines = login_node.run( + f"sacct --noheader --allocations " + f"--starttime=now-{int(since.total_seconds())}seconds " + "--format=" + ",".join(f"{field}%40" for field in fields), + echo=True, + in_stream=False, + ).stdout.splitlines() + # note: using maxsplit because the State field can contain spaces: "canceled by ..." + return [tuple(line.strip().split(maxsplit=len(fields))) for line in lines] + + +def sleep_so_sacct_can_update(): + print("Sleeping so sacct can update...") + time.sleep(_SACCT_UPDATE_DELAY.total_seconds()) + + +@pytest.fixture() +def allocation_flags(cluster: str, request: pytest.FixtureRequest): + # note: thanks to lru_cache, this is only making one ssh connection per cluster. + account = get_slurm_account(cluster) + allocation_options = { + "job-name": JOB_NAME, + "wckey": WCKEY, + "account": account, + "nodes": 1, + "ntasks": 1, + "cpus-per-task": 1, + "mem": "1G", + "time": MAX_JOB_DURATION, + "oversubscribe": None, # allow multiple such jobs to share resources. + } + overrides = getattr(request, "param", {}) + assert isinstance(overrides, dict) + if overrides: + print(f"Overriding allocation options with {overrides}") + allocation_options.update(overrides) + return " ".join( + [ + f"--{key}={value}" if value is not None else f"--{key}" + for key, value in allocation_options.items() + ] + ) + + +@requires_access_to_slurm_cluster +def test_cluster_setup(login_node: Remote, allocation_flags: str): + """Sanity Checks for the SLURM cluster of the CI: checks that `srun` works. + + NOTE: This is more-so a test to check that the slurm cluster used in the GitHub CI + is setup correctly, rather than to check that the Remote/SlurmRemote work correctly. + """ + + job_id, compute_node = ( + login_node.get_output( + f"srun {allocation_flags} bash -c 'echo $SLURM_JOB_ID $SLURMD_NODENAME'" + ) + .strip() + .split() + ) + assert compute_node + assert job_id.isdigit() + + sleep_so_sacct_can_update() + + # NOTE: the job should be done by now, since `.run` of the Remote is called with + # asynchronous=False. + sacct_output = get_recent_jobs_info(login_node, fields=("JobID", "JobName", "Node")) + assert (job_id, JOB_NAME, compute_node) in sacct_output + + +@pytest.fixture +def salloc_slurm_remote(login_node: Remote, allocation_flags: str): + """Fixture that creates a `SlurmRemote` that uses `salloc` (persist=False). + + The SlurmRemote is essentially just a Remote with an added `ensure_allocation` + method as well as a transform that does `salloc` or `sbatch` with some allocation + flags before a command is run. 
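+
+    For example, calling `.run("hostname")` on this fixture's SlurmRemote executes
+    something like `srun <allocation_flags> bash -c hostname` on the login node (see
+    `test_run` below), so each call to `run` launches a new job.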
+ """ + return SlurmRemote( + connection=login_node.connection, + alloc=allocation_flags.split(), + ) + + +@pytest.fixture +def sbatch_slurm_remote(login_node: Remote, allocation_flags: str): + """Fixture that creates a `SlurmRemote` that uses `sbatch` (persist=True).""" + return SlurmRemote( + connection=login_node.connection, + alloc=allocation_flags.split(), + persist=True, + ) + + +## Tests for the SlurmRemote class: + + +@requires_access_to_slurm_cluster +def test_run( + login_node: Remote, + salloc_slurm_remote: SlurmRemote, +): + """Test for `SlurmRemote.run` with persist=False without an initial call to + `ensure_allocation`. + + This should use `srun` from the login node to run the command, and return a result. + """ + result = salloc_slurm_remote.run("echo $SLURM_JOB_ID $SLURMD_NODENAME", echo=True) + assert isinstance(result, fabric.runners.Result) + output_lines = result.stdout.strip().splitlines() + assert len(output_lines) == 1 + job_id, compute_node = output_lines[0].split() + + assert compute_node + assert job_id.isdigit() + + # This stuff gets printed out when you do an `srun` on the login node (at least it + # does on the Mila cluster. Doesn't seem to be the case when using the slurm cluster + # on localhost in the CI.) + if login_node.hostname != "localhost": + assert "srun: ----" in result.stderr + + # NOTE: the job should be done by now, since `.run` of the Remote is called with + # asynchronous=False. + + # This check is to make sure that even though the Remote and SlurmRemote share the + # same fabric.Connection object, this is actually running on the login node, and not + # on the compute node. + # TODO: Move this check to a test for salloc, since here it's working because the + # job is completed, not because it's being executed on the login node. + login_node_hostname = login_node.get_output("hostname") + assert login_node_hostname != compute_node + + sleep_so_sacct_can_update() + + sacct_output = get_recent_jobs_info(login_node, fields=("JobID", "JobName", "Node")) + assert (job_id, JOB_NAME, compute_node) in sacct_output + + +hangs_in_github_CI = pytest.mark.skipif( + SLURM_CLUSTER == "localhost", reason="BUG: Hangs in the GitHub CI.." +) + + +@hangs_in_github_CI +@requires_access_to_slurm_cluster +def test_ensure_allocation( + login_node: Remote, + salloc_slurm_remote: SlurmRemote, + capsys: pytest.CaptureFixture[str], +): + """Test that `ensure_allocation` calls salloc for a SlurmRemote with persist=False. + + Calling `ensure_allocation` on a SlurmRemote with persist=False should: + 1. Call `salloc` on the login node and retrieve the allocated node name, + 2. return a dict with the node name and a remote runner that is: + TODO: What should it be? + - connected to the login node (what's currently happening) + - connected to the compute node through the interactive terminal of salloc on + the login node. + + FIXME: It should be made impossible / a critical error to make more than a single + call to `run` or `ensure_allocation` on a SlurmRemote, because every call to `run` + creates a new job! + (This is because .run applies the transforms to the command, and the last + transform adds either `salloc` or `sbatch`, hence every call to `run` launches + either an interactive or batch job!) + """ + data, remote_runner = salloc_slurm_remote.ensure_allocation() + assert isinstance(remote_runner, fabric.runners.Remote) + + assert "node_name" in data + # NOTE: it very well could be if we also extracted it from the salloc output. 
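+    # (With persist=False, `ensure_allocation` only extracts the node name from the
+    # salloc output; extracting the jobid there is still a TODO in `SlurmRemote`.)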
+ assert "jobid" not in data + compute_node_from_salloc_output = data["node_name"] + + # Check that salloc was called. This would be printed to stderr by fabric if we were + # using `run`, but `ensure_allocation` uses `extract` which somehow doesn't print it + salloc_stdout, salloc_stderr = capsys.readouterr() + # assert not stdout + # assert not stderr + assert "salloc: Granted job allocation" in salloc_stdout + assert ( + f"salloc: Nodes {compute_node_from_salloc_output} are ready for job" + in salloc_stdout + ) + assert not salloc_stderr + + # Check that the returned remote runner is indeed connected to a *login* node (?!) + # NOTE: This is a fabric.runners.Remote, not a Remote or SlurmRemote of milatools, + # so calling `.run` doesn't launch a job. + result = remote_runner.run("hostname", echo=True, in_stream=False) + assert result + hostname_from_remote_runner = result.stdout.strip() + result2 = login_node.run("hostname", echo=True, in_stream=False) + assert result2 + hostname_from_login_node_runner = result2.stdout.strip() + assert hostname_from_remote_runner == hostname_from_login_node_runner + + # TODO: IF the remote runner was to be connected to the compute node through the + # same interactive terminal, then we'd use this: + # result = remote_runner.run( + # "echo $SLURM_JOB_ID $SLURMD_NODENAME", + # echo=True, + # echo_format=T.bold_cyan( + # f"({compute_node_from_salloc_output})" + " $ {command}" + # ), + # in_stream=False, + # ) + # assert result + # assert not result.stderr + # assert result.stdout.strip() + # job_id, compute_node = result.stdout.strip().split() + # # cn-a001 vs cn-a001.server.mila.quebec for example. + # assert compute_node.startswith(compute_node_from_salloc_output) + # assert compute_node != login_node.hostname # hopefully also works in CI... + + # NOTE: too brittle. + # if datetime.datetime.now() - start_time < MAX_JOB_DURATION: + # # Check that the job shows up as still running in the output of `sacct`, since + # # we should not have reached the end time yet. + # sacct_output = get_recent_jobs_info( + # login_node, fields=("JobName", "Node", "State") + # ) + # assert [JOB_NAME, compute_node_from_salloc_output, "RUNNING"] in sacct_output + + print(f"Sleeping for {MAX_JOB_DURATION.total_seconds()}s until job finishes...") + time.sleep(MAX_JOB_DURATION.total_seconds()) + + sacct_output = get_recent_jobs_info(login_node, fields=("JobName", "Node", "State")) + assert (JOB_NAME, compute_node_from_salloc_output, "COMPLETED") in sacct_output + + +@hangs_in_github_CI +@requires_access_to_slurm_cluster +def test_ensure_allocation_sbatch(login_node: Remote, sbatch_slurm_remote: SlurmRemote): + job_data, login_node_remote_runner = sbatch_slurm_remote.ensure_allocation() + print(job_data, login_node_remote_runner) + assert isinstance(login_node_remote_runner, fabric.runners.Remote) + + node_hostname = job_data["node_name"] + assert "jobid" in job_data + job_id_from_sbatch_extract = job_data["jobid"] + + sleep_so_sacct_can_update() + + job_infos = get_recent_jobs_info(login_node, fields=("JobId", "JobName", "Node")) + # NOTE: `.extract`, used by `ensure_allocation`, actually returns the full node + # hostname as an output (e.g. cn-a001.server.mila.quebec), but sacct only shows the + # node name. + assert any( + ( + job_id_from_sbatch_extract == job_id + and JOB_NAME == job_name + and node_hostname.startswith(node_name) + for job_id, job_name, node_name in job_infos + ) + )