Skip to content

Commit

Permalink
Fix --persist mila code bug and intermittent connection errors [MT-78] (#101)
Browse files Browse the repository at this point in the history

* Fix --persist bug and unneeded cd $SCRATCH

Signed-off-by: Fabrice Normandin <[email protected]>

* Fix broken tests, add integration test dir

Signed-off-by: Fabrice Normandin <[email protected]>

* Fix check_passwordless bug, set banner_timeout=60

Signed-off-by: Fabrice Normandin <[email protected]>

* Remove unneeded mark and test run

Signed-off-by: Fabrice Normandin <[email protected]>

* Centralize references to connect_kwargs

Signed-off-by: Fabrice Normandin <[email protected]>

* Remove outdated todo

Signed-off-by: Fabrice Normandin <[email protected]>

* Fix bug in check_disk_quota and add test

Signed-off-by: Fabrice Normandin <[email protected]>

* Remove duplicate (moved) test_slurm_remote.py file

Signed-off-by: Fabrice Normandin <[email protected]>

* Add/improve integration test for `mila code`

Signed-off-by: Fabrice Normandin <[email protected]>

* Hide the 'which lfs' command

Signed-off-by: Fabrice Normandin <[email protected]>

* Fix typing error in test_commands.py

Signed-off-by: Fabrice Normandin <[email protected]>

* Fix bug and misleading type for `alloc` argument

`alloc` needs to be a list of strings, but it was typed as
`Sequence[str]`, which allows `str` to be passed (since `str`s are
sequences of `str`s).

This changes it to `list[str]` which is stricter and correct.

Incidentally, there was an undetected bug in the regression test at
`tests/integration/test_code_command.py::test_code` because I was
passing the allocation flags (a string) as salloc.

Signed-off-by: Fabrice Normandin <[email protected]>

* Change test fixture, adjust tests

Signed-off-by: Fabrice Normandin <[email protected]>

* Adjust the way we fetch SLURM accounts in tests

Signed-off-by: Fabrice Normandin <[email protected]>

* Fix unused imports in conftest.py

Signed-off-by: Fabrice Normandin <[email protected]>

* Fix typing error in python 3.8

Signed-off-by: Fabrice Normandin <[email protected]>

* Fix other type error in python 3.8

Signed-off-by: Fabrice Normandin <[email protected]>

* Apply suggestions from code review

Co-authored-by: satyaog <[email protected]>

* Add a `currently_in_a_test` function

Signed-off-by: Fabrice Normandin <[email protected]>

* Fix same issue in `mila serve` commands

Signed-off-by: Fabrice Normandin <[email protected]>

---------

Signed-off-by: Fabrice Normandin <[email protected]>
Co-authored-by: satyaog <[email protected]>
  • Loading branch information
lebrice and satyaog authored Feb 19, 2024
1 parent b823319 commit 7c733d1
Show file tree
Hide file tree
Showing 18 changed files with 561 additions and 279 deletions.
11 changes: 4 additions & 7 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ jobs:
python -m pip install --upgrade pip
pip install poetry
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
cache: poetry
Expand All @@ -59,9 +59,6 @@ jobs:
- name: Test with pytest
run: poetry run pytest --cov=milatools --cov-report=xml --cov-append

- name: Test with pytest (with -s flag)
run: poetry run pytest --cov=milatools --cov-report=xml --cov-append -s

- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@v3
with:
Expand Down Expand Up @@ -96,7 +93,7 @@ jobs:
options: --health-cmd="mysqladmin ping" --health-interval=10s --health-timeout=5s --health-retries=3

steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4

# NOTE: Replacing this with our customized version of
# - uses: koesterlab/setup-slurm-action@v1
Expand All @@ -120,7 +117,7 @@ jobs:
ssh -o 'StrictHostKeyChecking no' localhost id
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v3
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}

Expand All @@ -131,7 +128,7 @@ jobs:
poetry install --with=dev
- name: Launch integration tests
run: poetry run pytest tests/cli/test_slurm_remote.py --cov=milatools --cov-report=xml --cov-append -s -vvv --log-level=DEBUG
run: poetry run pytest tests/integration --cov=milatools --cov-report=xml --cov-append -s -vvv --log-level=DEBUG
timeout-minutes: 3
env:
SLURM_CLUSTER: localhost
Expand Down
132 changes: 80 additions & 52 deletions milatools/cli/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@
MilatoolsUserError,
SSHConnectionError,
T,
cluster_to_connect_kwargs,
currently_in_a_test,
get_fully_qualified_hostname_of_compute_node,
get_fully_qualified_name,
make_process,
Expand Down Expand Up @@ -492,7 +494,7 @@ def code(
persist: bool,
job: str | None,
node: str | None,
alloc: Sequence[str],
alloc: list[str],
cluster: Cluster = "mila",
):
"""Open a remote VSCode session on a compute node.
Expand Down Expand Up @@ -525,13 +527,12 @@ def code(
if command is None:
command = os.environ.get("MILATOOLS_CODE_COMMAND", "code")

if remote.hostname != "graham": # graham doesn't use lustre for $HOME
try:
check_disk_quota(remote)
except MilatoolsUserError:
raise
except Exception as exc:
logger.warning(f"Unable to check the disk-quota on the cluster: {exc}")
try:
check_disk_quota(remote)
except MilatoolsUserError:
raise
except Exception as exc:
logger.warning(f"Unable to check the disk-quota on the cluster: {exc}")

vscode_extensions_folder = Path.home() / ".vscode/extensions"
if vscode_extensions_folder.exists() and no_internet_on_compute_nodes(cluster):
Expand Down Expand Up @@ -618,6 +619,8 @@ def code(
"The editor was closed. Reopen it with <Enter>"
" or terminate the process with <Ctrl+C>"
)
if currently_in_a_test():
break
input()

except KeyboardInterrupt:
Expand Down Expand Up @@ -714,7 +717,7 @@ def serve_list(purge: bool):


class StandardServerArgs(TypedDict):
alloc: Sequence[str]
alloc: list[str]
"""Extra options to pass to slurm."""

job: str | None
Expand Down Expand Up @@ -922,7 +925,7 @@ def _standard_server(
name: str | None,
node: str | None,
job: str | None,
alloc: Sequence[str],
alloc: list[str],
port_pattern=None,
token_pattern=None,
):
Expand Down Expand Up @@ -1074,54 +1077,83 @@ def _standard_server(
proc.kill()


def _get_disk_quota_usage(
remote: Remote, print_command_output: bool = True
def _parse_lfs_quota_output(
lfs_quota_output: str,
) -> tuple[tuple[float, float], tuple[int, int]]:
"""Checks the disk quota on the $HOME filesystem on the mila cluster.
Returns whether the quota is exceeded, in terms of storage space or number of files.
"""
"""Parses space and # of files (usage, limit) from the output of `lfs quota`."""
lines = lfs_quota_output.splitlines()

header_line: str | None = None
header_line_index: int | None = None
for index, line in enumerate(lines):
if (
len(line_parts := line.strip().split()) == 9
and line_parts[0].lower() == "filesystem"
):
header_line = line
header_line_index = index
break
assert header_line
assert header_line_index is not None

values_line_parts: list[str] = []
# The next line may overflow to two (or maybe even more?) lines if the name of the
# $HOME dir is too long.
for content_line in lines[header_line_index + 1 :]:
additional_values = content_line.strip().split()
assert len(values_line_parts) < 9
values_line_parts.extend(additional_values)
if len(values_line_parts) == 9:
break

# NOTE: This is what the output of the command looks like on the Mila cluster:
#
# $ lfs quota -u $USER /home/mila
# Disk quotas for usr normandf (uid 1471600598):
# Filesystem kbytes quota limit grace files quota limit grace
# /home/mila 101440844 0 104857600 - 936140 0 1048576 -
# uid 1471600598 is using default block quota setting
# uid 1471600598 is using default file quota setting
#
home_disk_quota_output = remote.get_output(
"lfs quota -u $USER $HOME", hide=not print_command_output
)
lines = home_disk_quota_output.splitlines()
assert len(values_line_parts) == 9, values_line_parts
(
_filesystem,
used_kbytes,
_quota1,
_quota_kbytes,
limit_kbytes,
_grace1,
_grace_kbytes,
files,
_quota2,
_quota_files,
limit_files,
_grace2,
) = lines[2].strip().split()
_grace_files,
) = values_line_parts

used_gb = float(int(used_kbytes.strip()) / (1024) ** 2)
max_gb = float(int(limit_kbytes.strip()) / (1024) ** 2)
used_gb = int(used_kbytes.strip()) / (1024**2)
max_gb = int(limit_kbytes.strip()) / (1024**2)
used_files = int(files.strip())
max_files = int(limit_files.strip())
return (used_gb, max_gb), (used_files, max_files)


def check_disk_quota(remote: Remote) -> None:
cluster = (
"mila" # todo: if we run this on CC, then we should use `diskusage_report`
)
# todo: Check the disk-quota of other filesystems if needed.
filesystem = "$HOME"
cluster = remote.hostname

# NOTE: This is what the output of the command looks like on the Mila cluster:
#
# Disk quotas for usr normandf (uid 1471600598):
# Filesystem kbytes quota limit grace files quota limit grace
# /home/mila/n/normandf
# 95747836 0 104857600 - 908722 0 1048576 -
# uid 1471600598 is using default block quota setting
# uid 1471600598 is using default file quota setting

# Need to assert this, otherwise .get_output calls .run which would spawn a job!
assert not isinstance(remote, SlurmRemote)
if not remote.get_output("which lfs", hide=True):
logger.debug("Cluster doesn't have the lfs command. Skipping check.")
return

logger.debug("Checking disk quota on $HOME...")
(used_gb, max_gb), (used_files, max_files) = _get_disk_quota_usage(remote)

home_disk_quota_output = remote.get_output("lfs quota -u $USER $HOME", hide=True)
if "not on a mounted Lustre filesystem" in home_disk_quota_output:
logger.debug("Cluster doesn't use lustre on $HOME filesystem. Skipping check.")
return

(used_gb, max_gb), (used_files, max_files) = _parse_lfs_quota_output(
home_disk_quota_output
)
logger.debug(
f"Disk usage: {used_gb:.1f} / {max_gb} GiB and {used_files} / {max_files} files"
)
Expand All @@ -1144,7 +1176,7 @@ def check_disk_quota(remote: Remote) -> None:
if used_gb >= max_gb or used_files >= max_files:
raise MilatoolsUserError(
T.red(
f"ERROR: Your disk quota on the {filesystem} filesystem is exceeded! "
f"ERROR: Your disk quota on the $HOME filesystem is exceeded! "
f"({reason}).\n"
f"To fix this, login to the cluster with `ssh {cluster}` and free up "
f"some space, either by deleting files, or by moving them to a "
Expand All @@ -1153,24 +1185,20 @@ def check_disk_quota(remote: Remote) -> None:
)
if max(size_ratio, files_ratio) > 0.9:
warning_message = (
f"WARNING: You are getting pretty close to your disk quota on the $HOME "
f"You are getting pretty close to your disk quota on the $HOME "
f"filesystem: ({reason})\n"
"Please consider freeing up some space in your $HOME folder, either by "
"deleting files, or by moving them to a more suitable filesystem.\n"
+ freeing_up_space_instructions
)
# TODO: Perhaps we could use the logger or the warnings package instead of just
# printing?
# logger.warning(UserWarning(warning_message))
# warnings.warn(UserWarning(T.yellow(warning_message)))
print(UserWarning(T.yellow(warning_message)))
logger.warning(UserWarning(warning_message))


def _find_allocation(
remote: Remote,
node: str | None,
job: str | None,
alloc: Sequence[str],
alloc: list[str],
cluster: Cluster = "mila",
job_name: str = "mila-tools",
):
Expand All @@ -1179,11 +1207,11 @@ def _find_allocation(

if node is not None:
node_name = get_fully_qualified_hostname_of_compute_node(node, cluster=cluster)
return Remote(node_name)
return Remote(node_name, connect_kwargs=cluster_to_connect_kwargs.get(cluster))

elif job is not None:
node_name = remote.get_output(f"squeue --jobs {job} -ho %N")
return Remote(node_name)
return Remote(node_name, connect_kwargs=cluster_to_connect_kwargs.get(cluster))

else:
alloc = ["-J", job_name, *alloc]
Expand Down
10 changes: 8 additions & 2 deletions milatools/cli/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import paramiko.ssh_exception
from typing_extensions import deprecated

from .utils import CommandNotFoundError, T, shjoin
from .utils import CommandNotFoundError, T, cluster_to_connect_kwargs, shjoin

logger = get_logger(__name__)

Expand Down Expand Up @@ -76,7 +76,13 @@ def display(split_command: list[str] | tuple[str, ...] | str) -> None:

def check_passwordless(host: str) -> bool:
try:
with fabric.Connection(host) as connection:
connect_kwargs_for_host = {"allow_agent": False}
if host in cluster_to_connect_kwargs:
connect_kwargs_for_host.update(cluster_to_connect_kwargs[host])
with fabric.Connection(
host,
connect_kwargs=connect_kwargs_for_host,
) as connection:
results: fabric.runners.Result = connection.run(
"echo OK",
in_stream=False,
Expand Down
Loading

0 comments on commit 7c733d1

Please sign in to comment.