diff --git a/CHANGES.md b/CHANGES.md index 2d35f52..81b6f75 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,5 +1,13 @@ # Changelog +## [1.2.3] - 2023-03-16 + +- Correct download progress bar in logs [#40](https://github.com/AntaresSimulatorTeam/antares-launcher/pull/40) + +- Correct SLURM job status checking [#43](https://github.com/AntaresSimulatorTeam/antares-launcher/pull/43) + +### Fixes + ## [1.2.2] - 2023-03-02 ### Fixes @@ -74,6 +82,8 @@ - Remove unnecessary Optional - Enable ssh_config_file to be `None` +[1.2.3]: https://github.com/AntaresSimulatorTeam/antares-launcher/releases/tag/v1.2.3 + [1.2.2]: https://github.com/AntaresSimulatorTeam/antares-launcher/releases/tag/v1.2.2 [1.2.1]: https://github.com/AntaresSimulatorTeam/antares-launcher/releases/tag/v1.2.1 diff --git a/antareslauncher/__init__.py b/antareslauncher/__init__.py index 560a59c..c3aa744 100644 --- a/antareslauncher/__init__.py +++ b/antareslauncher/__init__.py @@ -9,9 +9,9 @@ # Standard project metadata -__version__ = "1.2.2" +__version__ = "1.2.3" __author__ = "RTE, Antares Web Team" -__date__ = "2023-03-02" +__date__ = "2023-03-16" # noinspection SpellCheckingInspection __credits__ = "(c) Réseau de Transport de l’Électricité (RTE)" diff --git a/antareslauncher/main.py b/antareslauncher/main.py index bd312a4..076c054 100644 --- a/antareslauncher/main.py +++ b/antareslauncher/main.py @@ -179,10 +179,12 @@ def run_with( def verify_connection(connection, display): + # fmt: off if connection.test_connection(): display.show_message(f"SSH connection to {connection.host} established", __name__) else: raise Exception(f"Could not establish SSH connection to {connection.host}") + # fmt: on def get_ssh_config_dict(file_manager, json_ssh_config, ssh_dict: dict): diff --git a/antareslauncher/main_option_parser.py b/antareslauncher/main_option_parser.py index 3aae204..7030e82 100644 --- a/antareslauncher/main_option_parser.py +++ b/antareslauncher/main_option_parser.py @@ -160,7 +160,7 @@ def 
add_basic_arguments(self) -> MainOptionParser: self.parser.add_argument( "--other-options", dest="other_options", - help='Other options to pass to the antares launcher script', + help="Other options to pass to the antares launcher script", ) self.parser.add_argument( diff --git a/antareslauncher/parameters_reader.py b/antareslauncher/parameters_reader.py index c3f1708..c9189d6 100644 --- a/antareslauncher/parameters_reader.py +++ b/antareslauncher/parameters_reader.py @@ -26,15 +26,15 @@ def __init__(self, json_ssh_conf: Path, yaml_filepath: Path): with open(Path(yaml_filepath)) as yaml_file: self.yaml_content = yaml.load(yaml_file, Loader=yaml.FullLoader) or {} + # fmt: off self._wait_time = self._get_compulsory_value("DEFAULT_WAIT_TIME") self.time_limit = self._get_compulsory_value("DEFAULT_TIME_LIMIT") self.n_cpu = self._get_compulsory_value("DEFAULT_N_CPU") self.studies_in_dir = os.path.expanduser(self._get_compulsory_value("STUDIES_IN_DIR")) self.log_dir = os.path.expanduser(self._get_compulsory_value("LOG_DIR")) self.finished_dir = os.path.expanduser(self._get_compulsory_value("FINISHED_DIR")) - self.ssh_conf_file_is_required = self._get_compulsory_value( - "SSH_CONFIG_FILE_IS_REQUIRED" - ) + self.ssh_conf_file_is_required = self._get_compulsory_value("SSH_CONFIG_FILE_IS_REQUIRED") + # fmt: on alt1, alt2 = self._get_ssh_conf_file_alts() self.ssh_conf_alt1, self.ssh_conf_alt2 = alt1, alt2 @@ -50,7 +50,6 @@ def __init__(self, json_ssh_conf: Path, yaml_filepath: Path): ) def get_parser_parameters(self): - options = ParserParameters( default_wait_time=self._wait_time, default_time_limit=self.time_limit, @@ -65,7 +64,6 @@ def get_parser_parameters(self): return options def get_main_parameters(self) -> MainParameters: - main_parameters = MainParameters( json_dir=self.json_dir, default_json_db_name=self.json_db_name, @@ -108,5 +106,7 @@ def _get_ssh_dict_from_json(self) -> Dict[str, Any]: with open(self.json_ssh_conf) as ssh_connection_json: ssh_dict = 
json.load(ssh_connection_json) if "private_key_file" in ssh_dict: - ssh_dict["private_key_file"] = os.path.expanduser(ssh_dict["private_key_file"]) + ssh_dict["private_key_file"] = os.path.expanduser( + ssh_dict["private_key_file"] + ) return ssh_dict diff --git a/antareslauncher/remote_environnement/iremote_environment.py b/antareslauncher/remote_environnement/iremote_environment.py deleted file mode 100644 index 9580b67..0000000 --- a/antareslauncher/remote_environnement/iremote_environment.py +++ /dev/null @@ -1,78 +0,0 @@ -from abc import ABC, abstractmethod -from pathlib import Path -from typing import List, Optional - -from antareslauncher.remote_environnement.ssh_connection import SshConnection -from antareslauncher.study_dto import StudyDTO - -NOT_SUBMITTED_STATE = "not_submitted" -SUBMITTED_STATE = "submitted" -STARTED_STATE = "started" -FINISHED_STATE = "finished" -FINISHED_WITH_ERROR_STATE = "finished_with_error" - - -class GetJobStateErrorException(Exception): - pass - - -class NoRemoteBaseDirException(Exception): - pass - - -class NoLaunchScriptFoundException(Exception): - def __init__(self, remote_path: str): - msg = f"Launch script not found in remote server: '{remote_path}." 
- super().__init__(msg) - - -class KillJobErrorException(Exception): - pass - - -class SubmitJobErrorException(Exception): - pass - - -class GetJobStateOutputException(Exception): - pass - - -class IRemoteEnvironment(ABC): - """Class that represents the remote environment""" - - def __init__(self, _connection: SshConnection): - self.connection = _connection - self.remote_base_path = None - - @abstractmethod - def get_queue_info(self): - raise NotImplementedError - - @abstractmethod - def kill_remote_job(self, job_id): - raise NotImplementedError - - @abstractmethod - def upload_file(self, src): - raise NotImplementedError - - @abstractmethod - def download_logs(self, study: StudyDTO) -> List[Path]: - raise NotImplementedError - - @abstractmethod - def download_final_zip(self, study: StudyDTO) -> Optional[Path]: - raise NotImplementedError - - @abstractmethod - def clean_remote_server(self, study: StudyDTO) -> bool: - raise NotImplementedError - - @abstractmethod - def submit_job(self, _study: StudyDTO): - raise NotImplementedError - - @abstractmethod - def get_job_state_flags(self, _study: StudyDTO) -> [bool, bool, bool]: - raise NotImplementedError diff --git a/antareslauncher/remote_environnement/remote_environment_with_slurm.py b/antareslauncher/remote_environnement/remote_environment_with_slurm.py index bf3b15c..ab23197 100644 --- a/antareslauncher/remote_environnement/remote_environment_with_slurm.py +++ b/antareslauncher/remote_environnement/remote_environment_with_slurm.py @@ -1,18 +1,13 @@ +import enum import getpass +import re +import shlex import socket +import textwrap import time from pathlib import Path, PurePosixPath from typing import List, Optional -from antareslauncher.remote_environnement.iremote_environment import ( - GetJobStateErrorException, - GetJobStateOutputException, - IRemoteEnvironment, - KillJobErrorException, - NoLaunchScriptFoundException, - NoRemoteBaseDirException, - SubmitJobErrorException, -) from 
antareslauncher.remote_environnement.slurm_script_features import ( ScriptParametersDTO, SlurmScriptFeatures, @@ -20,14 +15,111 @@ from antareslauncher.remote_environnement.ssh_connection import SshConnection from antareslauncher.study_dto import StudyDTO -SLURM_STATE_FAILED = "FAILED" -SLURM_STATE_TIMEOUT = "TIMEOUT" -SLURM_STATE_CANCELLED = "CANCELLED" -SLURM_STATE_COMPLETED = "COMPLETED" -SLURM_STATE_RUNNING = "RUNNING" +class RemoteEnvBaseError(Exception): + """Base class of the `RemoteEnvironmentWithSlurm` exceptions""" + + +class GetJobStateError(RemoteEnvBaseError): + def __init__(self, job_id: int, job_name: str, reason: str): + msg = ( + f"Unable to retrieve the status of the SLURM job {job_id}" + f" (study job '{job_name})." + f" {reason}" + ) + super().__init__(msg) + + +class JobNotFoundError(RemoteEnvBaseError): + def __init__(self, job_id: int, job_name: str): + msg = ( + f"Unable to retrieve the status of the SLURM job {job_id}" + f" (study job '{job_name}): Job not found." + ) + super().__init__(msg) + + +class NoRemoteBaseDirError(RemoteEnvBaseError): + def __init__(self, remote_base_path: PurePosixPath): + msg = f"Unable to create the remote base directory: '{remote_base_path}" + super().__init__(msg) + + +class NoLaunchScriptFoundError(RemoteEnvBaseError): + def __init__(self, remote_path: str): + msg = f"Launch script not found in remote server: '{remote_path}." + super().__init__(msg) + + +class KillJobError(RemoteEnvBaseError): + def __init__(self, job_id: int, reason: str): + msg = f"Unable to kill the SLURM job {job_id}: {reason}" + super().__init__(msg) + + +class SubmitJobError(RemoteEnvBaseError): + def __init__(self, study_name: str, reason: str): + msg = f"Unable to sumit the Antares Job {study_name} to the SLURM: {reason}" + super().__init__(msg) + + +class JobStateCodes(enum.Enum): + # noinspection SpellCheckingInspection + """ + The `sacct` command returns the status of each task in a column named State or JobState. 
+ The possible values for this column depend on the cluster management system + you are using, but here are some of the most common values: + """ + # Job terminated due to launch failure, typically due to a hardware failure + # (e.g. unable to boot the node or block and the job can not be requeued). + BOOT_FAIL = "BOOT_FAIL" -class RemoteEnvironmentWithSlurm(IRemoteEnvironment): + # Job was explicitly cancelled by the user or system administrator. + # The job may or may not have been initiated. + CANCELLED = "CANCELLED" + + # Job has terminated all processes on all nodes with an exit code of zero. + COMPLETED = "COMPLETED" + + # Job terminated on deadline. + DEADLINE = "DEADLINE" + + # Job terminated with non-zero exit code or other failure condition. + FAILED = "FAILED" + + # Job terminated due to failure of one or more allocated nodes. + NODE_FAIL = "NODE_FAIL" + + # Job experienced out of memory error. + OUT_OF_MEMORY = "OUT_OF_MEMORY" + + # Job is awaiting resource allocation. + PENDING = "PENDING" + + # Job terminated due to preemption. + PREEMPTED = "PREEMPTED" + + # Job currently has an allocation. + RUNNING = "RUNNING" + + # Job was requeued. + REQUEUED = "REQUEUED" + + # Job is about to change size. + RESIZING = "RESIZING" + + # Sibling was removed from cluster due to other cluster starting the job. + REVOKED = "REVOKED" + + # Job has an allocation, but execution has been suspended and + # CPUs have been released for other jobs. + SUSPENDED = "SUSPENDED" + + # Job terminated upon reaching its time limit. 
+ TIMEOUT = "TIMEOUT" + + +class RemoteEnvironmentWithSlurm: """Class that represents the remote environment""" def __init__( @@ -35,33 +127,27 @@ def __init__( _connection: SshConnection, slurm_script_features: SlurmScriptFeatures, ): - super(RemoteEnvironmentWithSlurm, self).__init__(_connection=_connection) + self.connection = _connection self.slurm_script_features = slurm_script_features self.remote_base_path: str = "" self._initialise_remote_path() self._check_remote_script() def _initialise_remote_path(self): - self._set_remote_base_path() - if not self.connection.make_dir(self.remote_base_path): - raise NoRemoteBaseDirException - - def _set_remote_base_path(self): - remote_home_dir = self.connection.home_dir - self.remote_base_path = ( - str(remote_home_dir) - + "/REMOTE_" - + getpass.getuser() - + "_" - + socket.gethostname() + remote_home_dir = PurePosixPath(self.connection.home_dir) + remote_base_path = remote_home_dir.joinpath( + f"REMOTE_{getpass.getuser()}_{socket.gethostname()}" ) + self.remote_base_path = str(remote_base_path) + if not self.connection.make_dir(self.remote_base_path): + raise NoRemoteBaseDirError(remote_base_path) def _check_remote_script(self): remote_antares_script = self.slurm_script_features.solver_script_path if not self.connection.check_file_not_empty(remote_antares_script): - raise NoLaunchScriptFoundException(remote_antares_script) + raise NoLaunchScriptFoundError(remote_antares_script) - def get_queue_info(self): + def get_queue_info(self) -> str: """This function return the information from: squeue -u run-antares Returns: @@ -70,25 +156,23 @@ def get_queue_info(self): username = self.connection.username command = f"squeue -u {username} --Format=name:40,state:12,starttime:22,TimeUsed:12,timelimit:12" output, error = self.connection.execute_command(command) - if error: - return error - else: - return f"{username}@{self.connection.host}\n" + output + return error or f"{username}@{self.connection.host}\n{output}" - def 
kill_remote_job(self, job_id): + def kill_remote_job(self, job_id: int) -> None: """Kills job with ID Args: - job_id: Id of the job to kill + job_id: ID of the job to kill Raises: KillJobErrorException if the command raises an error """ - + # noinspection SpellCheckingInspection command = f"scancel {job_id}" _, error = self.connection.execute_command(command) if error: - raise KillJobErrorException + reason = f"The command [{command}] failed: {error}" + raise KillJobError(job_id, reason) @staticmethod def convert_time_limit_from_seconds_to_minutes(time_limit_seconds): @@ -102,9 +186,7 @@ def convert_time_limit_from_seconds_to_minutes(time_limit_seconds): """ minimum_duration_in_minutes = 1 time_limit_minutes = int(time_limit_seconds / 60) - if time_limit_minutes < minimum_duration_in_minutes: - time_limit_minutes = minimum_duration_in_minutes - return time_limit_minutes + return max(time_limit_minutes, minimum_duration_in_minutes) def compose_launch_command(self, script_params: ScriptParametersDTO): return self.slurm_script_features.compose_launch_command( @@ -138,100 +220,128 @@ def submit_job(self, my_study: StudyDTO): other_options=my_study.other_options or "", ) command = self.compose_launch_command(script_params) + output, error = self.connection.execute_command(command) if error: - raise SubmitJobErrorException - job_id = self._get_jobid_from_output_of_submit_command(output) - return job_id + reason = f"The command [{command}] failed: {error}" + raise SubmitJobError(my_study.name, reason) - @staticmethod - def _get_jobid_from_output_of_submit_command(output): - job_id = None - # SLURM squeue command returns f'Submitted {job_id}' if successful - stdout_list = str(output).split() - if stdout_list and stdout_list[0] == "Submitted": - job_id = int(stdout_list[-1]) - return job_id - - @staticmethod - def get_advancement_flags_from_state(state): - """Converts the slurm state of the job to 3 boolean values + # should match "Submitted batch job 123456" + if match := 
re.match(r"Submitted.*?(?P<job_id>\d+)", output, flags=re.IGNORECASE): + return int(match["job_id"]) - Args: - state: The job state string as obtained from Slurm + reason = ( + f"The command [{command}] return an non-parsable output:" + f"\n{textwrap.indent(output, 'OUTPUT> ')}" + ) + raise SubmitJobError(my_study.name, reason) - Returns: - started, finished, with_error: the booleans representing the advancement of the slurm_job + def get_job_state_flags( + self, + study, + *, + attempts=5, + sleep_time=0.5, + ) -> [bool, bool, bool]: """ - - if state == SLURM_STATE_RUNNING: - started = True - finished = False - with_error = False - elif state == SLURM_STATE_COMPLETED: - started = True - finished = True - with_error = False - elif ( - state.startswith(SLURM_STATE_CANCELLED) - or state.startswith(SLURM_STATE_TIMEOUT) - or state == SLURM_STATE_FAILED - ): - started = True - with_error = True - finished = True - # PENDING - else: - started = False - finished = False - with_error = False - - return started, finished, with_error - - def _check_job_state(self, job_id: int): - """Checks the slurm state of a study + Retrieves the current state of a SLURM job with the given job ID and name. Args: - job_id: The id of the job to be checked + study: The study to check. + attempts: The number of attempts to make to retrieve the job state. + sleep_time: The amount of time to wait between attempts, in seconds. Returns: - The slurm job state string if the server correctly returned id + started, finished, with_error: booleans representing the advancement of the SLURM job Raises: - GetJobStateErrorException if the job_state has not been obtained + GetJobStateError: If the job state cannot be retrieved after + the specified number of attempts.
""" - command = self._compose_command_to_get_state_as_one_word(job_id) - max_number_of_tries = 5 - seconds_to_wait = 0.5 - for _ in range(max_number_of_tries): + job_state: JobStateCodes = self._retrieve_job_state( + study.job_id, + study.name, + attempts=attempts, + sleep_time=sleep_time, + ) + return { + # JobStateCodes ------ started, finished, with_error + JobStateCodes.BOOT_FAIL: (False, False, False), + JobStateCodes.CANCELLED: (True, True, True), + JobStateCodes.COMPLETED: (True, True, False), + JobStateCodes.DEADLINE: (True, True, True), # similar to timeout + JobStateCodes.FAILED: (True, True, True), + JobStateCodes.NODE_FAIL: (True, True, True), + JobStateCodes.OUT_OF_MEMORY: (True, True, True), + JobStateCodes.PENDING: (False, False, False), + JobStateCodes.PREEMPTED: (False, False, False), + JobStateCodes.RUNNING: (True, False, False), + JobStateCodes.REQUEUED: (False, False, False), + JobStateCodes.RESIZING: (False, False, False), + JobStateCodes.REVOKED: (False, False, False), + JobStateCodes.SUSPENDED: (True, False, False), + JobStateCodes.TIMEOUT: (True, True, True), + }[job_state] + + def _retrieve_job_state( + self, + job_id: int, + job_name: str, + *, + attempts: int = 5, + sleep_time: float = 0.5, + ) -> JobStateCodes: + # Construct the command line arguments used to check the jobs state. + # See the man page: https://slurm.schedmd.com/sacct.html + # noinspection SpellCheckingInspection + delimiter = "," + # noinspection SpellCheckingInspection + args = [ + "sacct", + f"--jobs={job_id}", + f"--name={job_name}", + "--format=JobID,JobName,State", + "--parsable2", + f"--delimiter={delimiter}", + "--noheader", + ] + command = " ".join(shlex.quote(arg) for arg in args) + + # Makes several attempts to get the job state. + # I don't really know why, but it's better to reproduce the old behavior. 
+ output: Optional[str] + last_error: str = "" + for attempt in range(attempts): output, error = self.connection.execute_command(command) - if error: - raise GetJobStateErrorException - stdout = str(output).split() - if stdout: - return stdout[0] - time.sleep(seconds_to_wait) - - raise GetJobStateOutputException - - @staticmethod - def _compose_command_to_get_state_as_one_word(job_id): - return ( - f"sacct -j {int(job_id)} -n --format=state | head -1 " - + "| awk -F\" \" '{print $1}'" + if output is not None: + break + last_error = error + time.sleep(sleep_time) + else: + reason = ( + f" The command [{command}] failed after {attempts} attempts:" + f" {last_error}" + ) + raise GetJobStateError(job_id, job_name, reason) + + # When the output is empty it means that the job is not found + if not output.strip(): + return JobStateCodes.PENDING + + # Parse the output to extract the job state. + # The output must be a CSV-like string without header row. + for line in output.splitlines(): + parts = line.split(delimiter) + if len(parts) == 3: + out_job_id, out_job_name, out_state = parts + if out_job_id == str(job_id) and out_job_name == job_name: + return JobStateCodes(out_state) + + reason = ( + f" The command [{command}] return an non-parsable output:" + f"\n{textwrap.indent(output, 'OUTPUT> ')}" ) - - def get_job_state_flags(self, study) -> [bool, bool, bool]: - """Checks the job state of a submitted study and converts it to flags - - Args: - study: The study data transfer object - - Returns: - started, finished, with_error: The booleans representing the advancement of the slurm_job - """ - job_state = self._check_job_state(study.job_id) - return self.get_advancement_flags_from_state(job_state) + raise GetJobStateError(job_id, job_name, reason) def upload_file(self, src): """Uploads a file to the remote server @@ -335,9 +445,9 @@ def clean_remote_server(self, study: StudyDTO): Returns: True if all files have been removed, False otherwise """ - return_flag = False - if not
study.remote_server_is_clean: - return_flag = self.remove_remote_final_zipfile( - study - ) & self.remove_input_zipfile(study) - return return_flag + return ( + False + if study.remote_server_is_clean + else self.remove_remote_final_zipfile(study) + & self.remove_input_zipfile(study) + ) diff --git a/antareslauncher/remote_environnement/ssh_connection.py b/antareslauncher/remote_environnement/ssh_connection.py index 6ce7a0b..30ec32f 100644 --- a/antareslauncher/remote_environnement/ssh_connection.py +++ b/antareslauncher/remote_environnement/ssh_connection.py @@ -4,9 +4,8 @@ import socket import stat import time -from os.path import expanduser from pathlib import Path, PurePosixPath -from typing import Tuple, List +from typing import List, Tuple import paramiko @@ -50,37 +49,83 @@ def __init__(self, hostname: str, port: int, username: str): class DownloadMonitor: + """ + A class that monitors the progress of a download. + + Args: + total_size: The total size of the file being downloaded (in bytes). + msg: The message to display while downloading. Defaults to "Downloading...". + logger: A logger object for logging progress messages. Defaults to `None`. + + Attributes: + total_size: The total size of the file being downloaded. + msg: The message to display while downloading. + logger: A logger object for logging progress messages. + """ + def __init__(self, total_size: int, msg: str = "", logger=None) -> None: self.total_size = total_size self.msg = msg or "Downloading..." 
self.logger = logger or logging.getLogger(__name__) - self._start_time = time.time() - self._size = 0 + # The start time of the download, used to calculate the ETA + self._start_time: float = time.time() + # The amount of data (in bytes) that has been transferred so far + # for each file + self._transferred: int = 0 + # The total amount of data that has been transferred so far (in bytes) + self._accumulated: int = 0 + # The progress of the download, in tenths (0-10), used to throttle log messages self._progress: int = 0 def __call__(self, transferred: int, subtotal: int) -> None: + """ + Called when data is transferred during the download. + Updates the progress and logs a message if progress has changed. + + Args: + transferred: The amount of data transferred so far for the current file (in bytes). + subtotal: The total size of the file currently being transferred (in bytes). + """ if not self.total_size: return - self._size += transferred + self._transferred = transferred # Avoid emitting too many messages - rate = self._size / self.total_size + rate = (self._accumulated + self._transferred) / self.total_size if self._progress != int(rate * 10): self._progress = int(rate * 10) self.logger.info(str(self)) - def __str__(self): - rate = self._size / self.total_size - if self._size: + def __str__(self) -> str: + """ + Returns a string representation of the current progress. + """ + total_transferred = self._accumulated + self._transferred + rate = total_transferred / self.total_size + if total_transferred: # Calculate ETA and progress rate # 0 curr_size total_size # |----------->|--------------------------->| # 0 duration total_duration # 0% percent 100% duration = time.time() - self._start_time - eta = int(duration * (self.total_size - self._size) / self._size) + eta = int( + duration * (self.total_size - total_transferred) / total_transferred + ) return f"{self.msg:<20} ETA: {eta}s [{rate:.0%}]" return f"{self.msg:<20} ETA: ???
[{rate:.0%}]" + def accumulate(self): + """ + Accumulates the quantity transferred by the previous transfer and + the current transfer. + + This function is used to keep track of the total quantity of data transferred + across multiple transfers. The accumulated quantity is calculated by adding + the quantity transferred in the current transfer to the quantity transferred + in the previous transfer. + """ + self._accumulated += self._transferred + class SshConnection: """Class to _connect to remote server""" @@ -246,6 +291,7 @@ def execute_command(self, command: str): self.logger.info(f"Executing command on remote server: {command}") try: with self.ssh_client() as client: + self.logger.info(f"Running SSH command [{command}]...") stdin, stdout, stderr = client.exec_command(command, timeout=30) output = stdout.read().decode("utf-8") error = stderr.read().decode("utf-8") @@ -342,7 +388,9 @@ def download_files( The paths of the downloaded files on the local filesystem. """ try: - return self._download_files(src_dir, dst_dir, (pattern,) + patterns, remove=remove) + return self._download_files( + src_dir, dst_dir, (pattern,) + patterns, remove=remove + ) except TimeoutError as exc: self.logger.error(f"Timeout: {exc}", exc_info=True) return [] @@ -396,6 +444,7 @@ def _download_files( count = len(files_to_download) for no, filename in enumerate(files_to_download, 1): monitor.msg = f"Downloading '{filename}' [{no}/{count}]..." 
+ monitor.accumulate() src_path = src_dir.joinpath(filename) dst_path = dst_dir.joinpath(filename) sftp.get(str(src_path), str(dst_path), monitor) @@ -468,7 +517,7 @@ def check_file_not_empty(self, file_path): return result_flag def make_dir(self, dir_path): - """Creates a remote directory if it does not exists yet + """Creates a remote directory if it does not exist yet Args: dir_path: Remote path of the directory that will be created @@ -477,7 +526,7 @@ def make_dir(self, dir_path): True if path exists or the directory is successfully created, False otherwise Raises: - IOError if the path exists and it is a file + IOError if the path exists, and it is a file """ try: with self.ssh_client() as client: diff --git a/antareslauncher/use_cases/check_remote_queue/check_queue_controller.py b/antareslauncher/use_cases/check_remote_queue/check_queue_controller.py index 2480ae6..1173964 100644 --- a/antareslauncher/use_cases/check_remote_queue/check_queue_controller.py +++ b/antareslauncher/use_cases/check_remote_queue/check_queue_controller.py @@ -1,9 +1,7 @@ from dataclasses import dataclass from antareslauncher.data_repo.idata_repo import IDataRepo -from antareslauncher.use_cases.check_remote_queue.slurm_queue_show import ( - SlurmQueueShow, -) +from antareslauncher.use_cases.check_remote_queue.slurm_queue_show import SlurmQueueShow from antareslauncher.use_cases.retrieve.state_updater import StateUpdater diff --git a/antareslauncher/use_cases/check_remote_queue/slurm_queue_show.py b/antareslauncher/use_cases/check_remote_queue/slurm_queue_show.py index 7cb3f70..d7e2074 100644 --- a/antareslauncher/use_cases/check_remote_queue/slurm_queue_show.py +++ b/antareslauncher/use_cases/check_remote_queue/slurm_queue_show.py @@ -1,17 +1,17 @@ from dataclasses import dataclass from antareslauncher.display.idisplay import IDisplay -from antareslauncher.remote_environnement.iremote_environment import ( - IRemoteEnvironment, +from 
antareslauncher.remote_environnement.remote_environment_with_slurm import ( + RemoteEnvironmentWithSlurm, ) @dataclass class SlurmQueueShow: - env: IRemoteEnvironment + env: RemoteEnvironmentWithSlurm display: IDisplay def run(self): """Displays all the jobs un the slurm queue""" message = "Checking remote server queue\n" + self.env.get_queue_info() - self.display.show_message(message, __name__ + "." + __class__.__name__) + self.display.show_message(message, f"{__name__}.{__class__.__name__}") diff --git a/antareslauncher/use_cases/create_list/study_list_composer.py b/antareslauncher/use_cases/create_list/study_list_composer.py index 4502e95..3b254f5 100644 --- a/antareslauncher/use_cases/create_list/study_list_composer.py +++ b/antareslauncher/use_cases/create_list/study_list_composer.py @@ -5,7 +5,7 @@ from antareslauncher.data_repo.idata_repo import IDataRepo from antareslauncher.display.idisplay import IDisplay from antareslauncher.file_manager.file_manager import FileManager -from antareslauncher.study_dto import StudyDTO, Modes +from antareslauncher.study_dto import Modes, StudyDTO @dataclass diff --git a/antareslauncher/use_cases/kill_job/job_kill_controller.py b/antareslauncher/use_cases/kill_job/job_kill_controller.py index 3a72fbd..fd7fcad 100644 --- a/antareslauncher/use_cases/kill_job/job_kill_controller.py +++ b/antareslauncher/use_cases/kill_job/job_kill_controller.py @@ -2,14 +2,14 @@ from antareslauncher.data_repo.idata_repo import IDataRepo from antareslauncher.display.idisplay import IDisplay -from antareslauncher.remote_environnement.iremote_environment import ( - IRemoteEnvironment, +from antareslauncher.remote_environnement.remote_environment_with_slurm import ( + RemoteEnvironmentWithSlurm, ) @dataclass class JobKillController: - env: IRemoteEnvironment + env: RemoteEnvironmentWithSlurm display: IDisplay repo: IDataRepo diff --git a/antareslauncher/use_cases/launch/launch_controller.py b/antareslauncher/use_cases/launch/launch_controller.py 
index e86e328..730bfd6 100644 --- a/antareslauncher/use_cases/launch/launch_controller.py +++ b/antareslauncher/use_cases/launch/launch_controller.py @@ -2,15 +2,13 @@ from antareslauncher.data_repo.idata_repo import IDataRepo from antareslauncher.display.idisplay import IDisplay from antareslauncher.file_manager.file_manager import FileManager -from antareslauncher.remote_environnement.iremote_environment import ( - IRemoteEnvironment, +from antareslauncher.remote_environnement.remote_environment_with_slurm import ( + RemoteEnvironmentWithSlurm, ) from antareslauncher.study_dto import StudyDTO from antareslauncher.use_cases.launch.study_submitter import StudySubmitter from antareslauncher.use_cases.launch.study_zip_cleaner import StudyZipCleaner -from antareslauncher.use_cases.launch.study_zip_uploader import ( - StudyZipfileUploader, -) +from antareslauncher.use_cases.launch.study_zip_uploader import StudyZipfileUploader from antareslauncher.use_cases.launch.study_zipper import StudyZipper @@ -61,7 +59,7 @@ class LaunchController: def __init__( self, repo: IDataRepo, - env: IRemoteEnvironment, + env: RemoteEnvironmentWithSlurm, file_manager: FileManager, display: IDisplay, ): diff --git a/antareslauncher/use_cases/launch/study_submitter.py b/antareslauncher/use_cases/launch/study_submitter.py index 20d0237..91c62b9 100644 --- a/antareslauncher/use_cases/launch/study_submitter.py +++ b/antareslauncher/use_cases/launch/study_submitter.py @@ -2,8 +2,8 @@ from pathlib import Path from antareslauncher.display.idisplay import IDisplay -from antareslauncher.remote_environnement.iremote_environment import ( - IRemoteEnvironment, +from antareslauncher.remote_environnement.remote_environment_with_slurm import ( + RemoteEnvironmentWithSlurm, ) from antareslauncher.study_dto import StudyDTO @@ -13,7 +13,7 @@ class FailedSubmissionException(Exception): class StudySubmitter(object): - def __init__(self, env: IRemoteEnvironment, display: IDisplay): + def __init__(self, env: 
RemoteEnvironmentWithSlurm, display: IDisplay): self.env = env self.display = display self._current_study: StudyDTO = None diff --git a/antareslauncher/use_cases/launch/study_zip_uploader.py b/antareslauncher/use_cases/launch/study_zip_uploader.py index 09ae90b..b98fabe 100644 --- a/antareslauncher/use_cases/launch/study_zip_uploader.py +++ b/antareslauncher/use_cases/launch/study_zip_uploader.py @@ -2,14 +2,14 @@ from pathlib import Path from antareslauncher.display.idisplay import IDisplay -from antareslauncher.remote_environnement.iremote_environment import ( - IRemoteEnvironment, +from antareslauncher.remote_environnement.remote_environment_with_slurm import ( + RemoteEnvironmentWithSlurm, ) from antareslauncher.study_dto import StudyDTO class StudyZipfileUploader: - def __init__(self, env: IRemoteEnvironment, display: IDisplay): + def __init__(self, env: RemoteEnvironmentWithSlurm, display: IDisplay): self.env = env self.display = display self._current_study: StudyDTO = None diff --git a/antareslauncher/use_cases/retrieve/clean_remote_server.py b/antareslauncher/use_cases/retrieve/clean_remote_server.py index 8f158f1..744e9f2 100644 --- a/antareslauncher/use_cases/retrieve/clean_remote_server.py +++ b/antareslauncher/use_cases/retrieve/clean_remote_server.py @@ -2,7 +2,9 @@ from pathlib import Path from antareslauncher.display.idisplay import IDisplay -from antareslauncher.remote_environnement import iremote_environment +from antareslauncher.remote_environnement.remote_environment_with_slurm import ( + RemoteEnvironmentWithSlurm, +) from antareslauncher.study_dto import StudyDTO @@ -13,7 +15,7 @@ class RemoteServerNotCleanException(Exception): class RemoteServerCleaner: def __init__( self, - env: iremote_environment.IRemoteEnvironment, + env: RemoteEnvironmentWithSlurm, display: IDisplay, ): self._display = display diff --git a/antareslauncher/use_cases/retrieve/download_final_zip.py b/antareslauncher/use_cases/retrieve/download_final_zip.py index 
854e3e5..ed5c96a 100644 --- a/antareslauncher/use_cases/retrieve/download_final_zip.py +++ b/antareslauncher/use_cases/retrieve/download_final_zip.py @@ -1,7 +1,9 @@ import copy from antareslauncher.display.idisplay import IDisplay -from antareslauncher.remote_environnement import iremote_environment +from antareslauncher.remote_environnement.remote_environment_with_slurm import ( + RemoteEnvironmentWithSlurm, +) from antareslauncher.study_dto import StudyDTO @@ -12,7 +14,7 @@ class FinalZipNotDownloadedException(Exception): class FinalZipDownloader(object): def __init__( self, - env: iremote_environment.IRemoteEnvironment, + env: RemoteEnvironmentWithSlurm, display: IDisplay, ): self._env = env diff --git a/antareslauncher/use_cases/retrieve/log_downloader.py b/antareslauncher/use_cases/retrieve/log_downloader.py index 5eb5796..22248d9 100644 --- a/antareslauncher/use_cases/retrieve/log_downloader.py +++ b/antareslauncher/use_cases/retrieve/log_downloader.py @@ -3,14 +3,16 @@ from antareslauncher.display.idisplay import IDisplay from antareslauncher.file_manager.file_manager import FileManager -from antareslauncher.remote_environnement import iremote_environment +from antareslauncher.remote_environnement.remote_environment_with_slurm import ( + RemoteEnvironmentWithSlurm, +) from antareslauncher.study_dto import StudyDTO class LogDownloader: def __init__( self, - env: iremote_environment.IRemoteEnvironment, + env: RemoteEnvironmentWithSlurm, file_manager: FileManager, display: IDisplay, ): diff --git a/antareslauncher/use_cases/retrieve/retrieve_controller.py b/antareslauncher/use_cases/retrieve/retrieve_controller.py index 258e425..ca67f11 100644 --- a/antareslauncher/use_cases/retrieve/retrieve_controller.py +++ b/antareslauncher/use_cases/retrieve/retrieve_controller.py @@ -2,16 +2,12 @@ from antareslauncher.data_repo.idata_repo import IDataRepo from antareslauncher.display.idisplay import IDisplay from antareslauncher.file_manager.file_manager import 
FileManager -from antareslauncher.remote_environnement import iremote_environment -from antareslauncher.use_cases.retrieve.clean_remote_server import ( - RemoteServerCleaner, -) -from antareslauncher.use_cases.retrieve.download_final_zip import ( - FinalZipDownloader, -) -from antareslauncher.use_cases.retrieve.final_zip_extractor import ( - FinalZipExtractor, +from antareslauncher.remote_environnement.remote_environment_with_slurm import ( + RemoteEnvironmentWithSlurm, ) +from antareslauncher.use_cases.retrieve.clean_remote_server import RemoteServerCleaner +from antareslauncher.use_cases.retrieve.download_final_zip import FinalZipDownloader +from antareslauncher.use_cases.retrieve.final_zip_extractor import FinalZipExtractor from antareslauncher.use_cases.retrieve.log_downloader import LogDownloader from antareslauncher.use_cases.retrieve.state_updater import StateUpdater from antareslauncher.use_cases.retrieve.study_retriever import StudyRetriever @@ -21,7 +17,7 @@ class RetrieveController: def __init__( self, repo: IDataRepo, - env: iremote_environment.IRemoteEnvironment, + env: RemoteEnvironmentWithSlurm, file_manager: FileManager, display: IDisplay, state_updater: StateUpdater, diff --git a/antareslauncher/use_cases/retrieve/state_updater.py b/antareslauncher/use_cases/retrieve/state_updater.py index 3c2ae43..41b0677 100644 --- a/antareslauncher/use_cases/retrieve/state_updater.py +++ b/antareslauncher/use_cases/retrieve/state_updater.py @@ -2,17 +2,18 @@ from typing import List from antareslauncher.display.idisplay import IDisplay -from antareslauncher.remote_environnement import iremote_environment +from antareslauncher.remote_environnement.remote_environment_with_slurm import ( + RemoteEnvironmentWithSlurm, +) from antareslauncher.study_dto import StudyDTO class StateUpdater: def __init__( self, - env: iremote_environment.IRemoteEnvironment, + env: RemoteEnvironmentWithSlurm, display: IDisplay, ): - self._env = env self._display = display 
self._current_study: StudyDTO = None diff --git a/antareslauncher/use_cases/retrieve/study_retriever.py b/antareslauncher/use_cases/retrieve/study_retriever.py index 7a24cd0..48fa7c6 100644 --- a/antareslauncher/use_cases/retrieve/study_retriever.py +++ b/antareslauncher/use_cases/retrieve/study_retriever.py @@ -1,14 +1,8 @@ from antareslauncher.data_repo.data_reporter import DataReporter from antareslauncher.study_dto import StudyDTO -from antareslauncher.use_cases.retrieve.clean_remote_server import ( - RemoteServerCleaner, -) -from antareslauncher.use_cases.retrieve.download_final_zip import ( - FinalZipDownloader, -) -from antareslauncher.use_cases.retrieve.final_zip_extractor import ( - FinalZipExtractor, -) +from antareslauncher.use_cases.retrieve.clean_remote_server import RemoteServerCleaner +from antareslauncher.use_cases.retrieve.download_final_zip import FinalZipDownloader +from antareslauncher.use_cases.retrieve.final_zip_extractor import FinalZipExtractor from antareslauncher.use_cases.retrieve.log_downloader import LogDownloader from antareslauncher.use_cases.retrieve.state_updater import StateUpdater diff --git a/tests/integration/test_integration_check_queue_controller.py b/tests/integration/test_integration_check_queue_controller.py index c453efe..91e781f 100644 --- a/tests/integration/test_integration_check_queue_controller.py +++ b/tests/integration/test_integration_check_queue_controller.py @@ -20,7 +20,7 @@ class TestIntegrationCheckQueueController: def setup_method(self): - self.connection_mock = mock.Mock() + self.connection_mock = mock.Mock(home_dir="path/to/home") self.connection_mock.username = "username" self.connection_mock.execute_command = mock.Mock(return_value=("", "")) slurm_script_features = SlurmScriptFeatures("slurm_script_path") diff --git a/tests/integration/test_integration_job_kill_controller.py b/tests/integration/test_integration_job_kill_controller.py index df8eeee..e8c8731 100644 --- 
a/tests/integration/test_integration_job_kill_controller.py +++ b/tests/integration/test_integration_job_kill_controller.py @@ -16,7 +16,8 @@ class TestIntegrationJobKilController: def setup_method(self): slurm_script_features = SlurmScriptFeatures("slurm_script_path") - env = RemoteEnvironmentWithSlurm(mock.Mock(), slurm_script_features) + connection = mock.Mock(home_dir="path/to/home") + env = RemoteEnvironmentWithSlurm(connection, slurm_script_features) self.job_kill_controller = JobKillController(env, mock.Mock(), repo=mock.Mock()) @pytest.mark.integration_test diff --git a/tests/integration/test_integration_launch_controller.py b/tests/integration/test_integration_launch_controller.py index b7495b7..378e5eb 100644 --- a/tests/integration/test_integration_launch_controller.py +++ b/tests/integration/test_integration_launch_controller.py @@ -19,7 +19,7 @@ class TestIntegrationLaunchController: @pytest.fixture(scope="function") def launch_controller(self): - connection = mock.Mock() + connection = mock.Mock(home_dir="path/to/home") slurm_script_features = SlurmScriptFeatures("slurm_script_path") environment = RemoteEnvironmentWithSlurm(connection, slurm_script_features) study1 = mock.Mock() diff --git a/tests/unit/launcher/test_launch_controller.py b/tests/unit/launcher/test_launch_controller.py index 451db33..13cc4a9 100644 --- a/tests/unit/launcher/test_launch_controller.py +++ b/tests/unit/launcher/test_launch_controller.py @@ -5,6 +5,7 @@ import pytest +import antareslauncher.remote_environnement.remote_environment_with_slurm import antareslauncher.use_cases.launch.study_submitter import antareslauncher.use_cases.launch.study_zip_uploader from antareslauncher.data_repo.data_repo_tinydb import DataRepoTinydb @@ -12,24 +13,21 @@ from antareslauncher.data_repo.idata_repo import IDataRepo from antareslauncher.display.idisplay import IDisplay from antareslauncher.file_manager.file_manager import FileManager -from antareslauncher.remote_environnement import 
iremote_environment -from antareslauncher.remote_environnement.iremote_environment import ( - IRemoteEnvironment, +from antareslauncher.remote_environnement.remote_environment_with_slurm import ( + RemoteEnvironmentWithSlurm, ) from antareslauncher.study_dto import StudyDTO from antareslauncher.use_cases.launch import launch_controller from antareslauncher.use_cases.launch.launch_controller import StudyLauncher from antareslauncher.use_cases.launch.study_submitter import StudySubmitter from antareslauncher.use_cases.launch.study_zip_cleaner import StudyZipCleaner -from antareslauncher.use_cases.launch.study_zip_uploader import ( - StudyZipfileUploader, -) +from antareslauncher.use_cases.launch.study_zip_uploader import StudyZipfileUploader from antareslauncher.use_cases.launch.study_zipper import StudyZipper class TestStudyLauncher: def setup_method(self): - env = mock.Mock(spec_set=IRemoteEnvironment) + env = mock.Mock(spec_set=RemoteEnvironmentWithSlurm) display = mock.Mock(spec_set=IDisplay) file_manager = mock.Mock(spec_set=FileManager) repo = mock.Mock(spec_set=IDataRepo) @@ -84,7 +82,7 @@ def my_launch_controller(self): expected_study = StudyDTO(path="hello") list_of_studies = [copy.deepcopy(expected_study)] self.data_repo.get_list_of_studies = mock.Mock(return_value=list_of_studies) - remote_env_mock = mock.Mock(spec=iremote_environment.IRemoteEnvironment) + remote_env_mock = mock.Mock(spec=RemoteEnvironmentWithSlurm) file_manager_mock = mock.Mock() my_launcher = launch_controller.LaunchController( self.data_repo, remote_env_mock, file_manager_mock, self.display @@ -96,7 +94,7 @@ def test_with_one_study_the_compressor_is_called_once(self): list_of_studies = [my_study] self.data_repo.get_list_of_studies = mock.Mock(return_value=list_of_studies) - remote_env_mock = mock.Mock(spec=iremote_environment.IRemoteEnvironment) + remote_env_mock = mock.Mock(spec=RemoteEnvironmentWithSlurm) file_manager = mock.Mock(spec_set=FileManager) 
file_manager.zip_dir_excluding_subdir = mock.Mock() diff --git a/tests/unit/launcher/test_submitter.py b/tests/unit/launcher/test_submitter.py index 735c3c2..fa1384c 100644 --- a/tests/unit/launcher/test_submitter.py +++ b/tests/unit/launcher/test_submitter.py @@ -5,8 +5,8 @@ import antareslauncher.use_cases from antareslauncher.display.idisplay import IDisplay -from antareslauncher.remote_environnement.iremote_environment import ( - IRemoteEnvironment, +from antareslauncher.remote_environnement.remote_environment_with_slurm import ( + RemoteEnvironmentWithSlurm, ) from antareslauncher.study_dto import StudyDTO from antareslauncher.use_cases.launch.study_submitter import StudySubmitter @@ -14,7 +14,7 @@ class TestStudySubmitter: def setup_method(self): - self.remote_env = mock.Mock(spec_set=IRemoteEnvironment) + self.remote_env = mock.Mock(spec_set=RemoteEnvironmentWithSlurm) self.display_mock = mock.Mock(spec_set=IDisplay) self.study_submitter = StudySubmitter(self.remote_env, self.display_mock) diff --git a/tests/unit/launcher/test_zip_uploader.py b/tests/unit/launcher/test_zip_uploader.py index a0c300e..5cb013e 100644 --- a/tests/unit/launcher/test_zip_uploader.py +++ b/tests/unit/launcher/test_zip_uploader.py @@ -5,19 +5,19 @@ import pytest from antareslauncher.display.idisplay import IDisplay -from antareslauncher.remote_environnement.iremote_environment import ( - IRemoteEnvironment, +from antareslauncher.remote_environnement.remote_environment_with_slurm import ( + RemoteEnvironmentWithSlurm, ) from antareslauncher.study_dto import StudyDTO from antareslauncher.use_cases.launch.study_zip_uploader import ( - StudyZipfileUploader, FailedUploadException, + StudyZipfileUploader, ) class TestZipfileUploader: def setup_method(self): - self.remote_env = mock.Mock(spec_set=IRemoteEnvironment) + self.remote_env = mock.Mock(spec_set=RemoteEnvironmentWithSlurm) self.display_mock = mock.Mock(spec_set=IDisplay) self.study_uploader = StudyZipfileUploader(self.remote_env, 
self.display_mock) diff --git a/tests/unit/retriever/test_download_final_zip.py b/tests/unit/retriever/test_download_final_zip.py index c8de803..b477c7f 100644 --- a/tests/unit/retriever/test_download_final_zip.py +++ b/tests/unit/retriever/test_download_final_zip.py @@ -6,8 +6,8 @@ import pytest from antareslauncher.display.idisplay import IDisplay -from antareslauncher.remote_environnement.iremote_environment import ( - IRemoteEnvironment, +from antareslauncher.remote_environnement.remote_environment_with_slurm import ( + RemoteEnvironmentWithSlurm, ) from antareslauncher.study_dto import StudyDTO from antareslauncher.use_cases.retrieve.download_final_zip import ( @@ -18,7 +18,7 @@ class TestFinalZipDownloader: def setup_method(self): - self.remote_env = mock.Mock(spec_set=IRemoteEnvironment) + self.remote_env = mock.Mock(spec_set=RemoteEnvironmentWithSlurm) self.display_mock = mock.Mock(spec_set=IDisplay) self.final_zip_downloader = FinalZipDownloader( self.remote_env, self.display_mock @@ -103,7 +103,9 @@ def test_remote_env_is_called_if_final_zip_not_yet_downloaded( self, successfully_finished_zip_study ): final_zipfile_path = "results.zip" - self.remote_env.download_final_zip = mock.Mock(return_value=Path(final_zipfile_path)) + self.remote_env.download_final_zip = mock.Mock( + return_value=Path(final_zipfile_path) + ) new_study = self.final_zip_downloader.download(successfully_finished_zip_study) diff --git a/tests/unit/retriever/test_log_downloader.py b/tests/unit/retriever/test_log_downloader.py index 9c1bb26..2f729b8 100644 --- a/tests/unit/retriever/test_log_downloader.py +++ b/tests/unit/retriever/test_log_downloader.py @@ -4,15 +4,18 @@ import pytest +import antareslauncher.remote_environnement.remote_environment_with_slurm from antareslauncher.display.idisplay import IDisplay -from antareslauncher.remote_environnement import iremote_environment +from antareslauncher.remote_environnement.remote_environment_with_slurm import ( + 
RemoteEnvironmentWithSlurm, +) from antareslauncher.study_dto import StudyDTO from antareslauncher.use_cases.retrieve.log_downloader import LogDownloader class TestLogDownloader: def setup_method(self): - self.remote_env_mock = mock.Mock(spec=iremote_environment.IRemoteEnvironment) + self.remote_env_mock = mock.Mock(spec=RemoteEnvironmentWithSlurm) self.file_manager = mock.Mock() self.display_mock = mock.Mock(spec_set=IDisplay) self.log_downloader = LogDownloader( diff --git a/tests/unit/retriever/test_retrieve_controller.py b/tests/unit/retriever/test_retrieve_controller.py index 8a75e18..2a970a5 100644 --- a/tests/unit/retriever/test_retrieve_controller.py +++ b/tests/unit/retriever/test_retrieve_controller.py @@ -5,18 +5,19 @@ import pytest import antareslauncher +import antareslauncher.remote_environnement.remote_environment_with_slurm from antareslauncher.display.idisplay import IDisplay -from antareslauncher.remote_environnement import iremote_environment -from antareslauncher.study_dto import StudyDTO -from antareslauncher.use_cases.retrieve.retrieve_controller import ( - RetrieveController, +from antareslauncher.remote_environnement.remote_environment_with_slurm import ( + RemoteEnvironmentWithSlurm, ) +from antareslauncher.study_dto import StudyDTO +from antareslauncher.use_cases.retrieve.retrieve_controller import RetrieveController from antareslauncher.use_cases.retrieve.state_updater import StateUpdater class TestRetrieveController: def setup_method(self): - self.remote_env_mock = mock.Mock(spec=iremote_environment.IRemoteEnvironment) + self.remote_env_mock = mock.Mock(spec=RemoteEnvironmentWithSlurm) self.file_manager = mock.Mock() self.data_repo = mock.Mock() self.display = mock.Mock() diff --git a/tests/unit/retriever/test_server_cleaner.py b/tests/unit/retriever/test_server_cleaner.py index 6d021bc..4e0c77f 100644 --- a/tests/unit/retriever/test_server_cleaner.py +++ b/tests/unit/retriever/test_server_cleaner.py @@ -4,8 +4,11 @@ import pytest 
+import antareslauncher.remote_environnement.remote_environment_with_slurm from antareslauncher.display.idisplay import IDisplay -from antareslauncher.remote_environnement import iremote_environment +from antareslauncher.remote_environnement.remote_environment_with_slurm import ( + RemoteEnvironmentWithSlurm, +) from antareslauncher.study_dto import StudyDTO from antareslauncher.use_cases.retrieve.clean_remote_server import ( RemoteServerCleaner, @@ -15,7 +18,7 @@ class TestServerCleaner: def setup_method(self): - self.remote_env_mock = mock.Mock(spec=iremote_environment.IRemoteEnvironment) + self.remote_env_mock = mock.Mock(spec=RemoteEnvironmentWithSlurm) self.display_mock = mock.Mock(spec_set=IDisplay) self.remote_server_cleaner = RemoteServerCleaner( self.remote_env_mock, self.display_mock diff --git a/tests/unit/retriever/test_study_retriever.py b/tests/unit/retriever/test_study_retriever.py index 32fd776..3caab2d 100644 --- a/tests/unit/retriever/test_study_retriever.py +++ b/tests/unit/retriever/test_study_retriever.py @@ -8,19 +8,13 @@ from antareslauncher.data_repo.idata_repo import IDataRepo from antareslauncher.display.idisplay import IDisplay from antareslauncher.file_manager.file_manager import FileManager -from antareslauncher.remote_environnement.iremote_environment import ( - IRemoteEnvironment, +from antareslauncher.remote_environnement.remote_environment_with_slurm import ( + RemoteEnvironmentWithSlurm, ) from antareslauncher.study_dto import StudyDTO -from antareslauncher.use_cases.retrieve.clean_remote_server import ( - RemoteServerCleaner, -) -from antareslauncher.use_cases.retrieve.download_final_zip import ( - FinalZipDownloader, -) -from antareslauncher.use_cases.retrieve.final_zip_extractor import ( - FinalZipExtractor, -) +from antareslauncher.use_cases.retrieve.clean_remote_server import RemoteServerCleaner +from antareslauncher.use_cases.retrieve.download_final_zip import FinalZipDownloader +from 
antareslauncher.use_cases.retrieve.final_zip_extractor import FinalZipExtractor from antareslauncher.use_cases.retrieve.log_downloader import LogDownloader from antareslauncher.use_cases.retrieve.state_updater import StateUpdater from antareslauncher.use_cases.retrieve.study_retriever import StudyRetriever @@ -28,7 +22,7 @@ class TestStudyRetriever: def setup_method(self): - env = mock.Mock(spec_set=IRemoteEnvironment) + env = mock.Mock(spec_set=RemoteEnvironmentWithSlurm) display = mock.Mock(spec_set=IDisplay) file_manager = mock.Mock(spec_set=FileManager) repo = mock.Mock(spec_set=IDataRepo) diff --git a/tests/unit/test_remote_environment_with_slurm.py b/tests/unit/test_remote_environment_with_slurm.py index e10f928..2095ff8 100644 --- a/tests/unit/test_remote_environment_with_slurm.py +++ b/tests/unit/test_remote_environment_with_slurm.py @@ -7,16 +7,13 @@ import pytest -from antareslauncher.remote_environnement.iremote_environment import ( - GetJobStateErrorException, - GetJobStateOutputException, - KillJobErrorException, - NoLaunchScriptFoundException, - NoRemoteBaseDirException, - SubmitJobErrorException, -) from antareslauncher.remote_environnement.remote_environment_with_slurm import ( + GetJobStateError, + KillJobError, + NoLaunchScriptFoundError, + NoRemoteBaseDirError, RemoteEnvironmentWithSlurm, + SubmitJobError, ) from antareslauncher.remote_environnement.slurm_script_features import ( ScriptParametersDTO, @@ -51,9 +48,9 @@ def study(self) -> StudyDTO: """Dummy Study Data Transfer Object (DTO)""" return StudyDTO( time_limit=60, - path="study path", + path="path/to/study/91f1f911-4f4a-426f-b127-d0c2a2465b5f", n_cpu=42, - zipfile_path="zipfile_path", + zipfile_path="path/to/study/91f1f911-4f4a-426f-b127-d0c2a2465b5f-foo.zip", antares_version="700", local_final_zipfile_path="local_final_zipfile_path", run_mode=Modes.antares, @@ -63,7 +60,7 @@ def study(self) -> StudyDTO: def remote_env(self) -> RemoteEnvironmentWithSlurm: """SLURM remote environment 
(Mock)""" remote_home_dir = "remote_home_dir" - connection = mock.Mock() + connection = mock.Mock(home_dir="path/to/home") connection.home_dir = remote_home_dir slurm_script_features = SlurmScriptFeatures("slurm_script_path") return RemoteEnvironmentWithSlurm(connection, slurm_script_features) @@ -77,7 +74,7 @@ def test_initialise_remote_path_calls_connection_make_dir_with_correct_arguments remote_base_dir = ( f"{remote_home_dir}/REMOTE_{getpass.getuser()}_{socket.gethostname()}" ) - connection = mock.Mock() + connection = mock.Mock(home_dir="path/to/home") connection.home_dir = remote_home_dir connection.make_dir = mock.Mock(return_value=True) connection.check_file_not_empty = mock.Mock(return_value=True) @@ -92,12 +89,12 @@ def test_when_constructor_is_called_and_remote_base_path_cannot_be_created_then_ self, ): # given - connection = mock.Mock() + connection = mock.Mock(home_dir="path/to/home") slurm_script_features = SlurmScriptFeatures("slurm_script_path") # when connection.make_dir = mock.Mock(return_value=False) # then - with pytest.raises(NoRemoteBaseDirException): + with pytest.raises(NoRemoteBaseDirError): RemoteEnvironmentWithSlurm(connection, slurm_script_features) @pytest.mark.unit_test @@ -105,7 +102,7 @@ def test_when_constructor_is_called_then_connection_check_file_not_empty_is_call self, ): # given - connection = mock.Mock() + connection = mock.Mock(home_dir="path/to/home") connection.make_dir = mock.Mock(return_value=True) connection.check_file_not_empty = mock.Mock(return_value=True) slurm_script_features = SlurmScriptFeatures("slurm_script_path") @@ -121,14 +118,14 @@ def test_when_constructor_is_called_and_connection_check_file_not_empty_is_false ): # given remote_home_dir = "/applications/antares/" - connection = mock.Mock() + connection = mock.Mock(home_dir="path/to/home") connection.home_dir = remote_home_dir connection.make_dir = mock.Mock(return_value=True) slurm_script_features = SlurmScriptFeatures("slurm_script_path") # when 
connection.check_file_not_empty = mock.Mock(return_value=False) # then - with pytest.raises(NoLaunchScriptFoundException): + with pytest.raises(NoLaunchScriptFoundError): RemoteEnvironmentWithSlurm(connection, slurm_script_features) @pytest.mark.unit_test @@ -183,17 +180,16 @@ def test_when_kill_remote_job_is_called_and_exec_command_returns_error_exception error = "error" remote_env.connection.execute_command = mock.Mock(return_value=(output, error)) # then - with pytest.raises(KillJobErrorException): + with pytest.raises(KillJobError): remote_env.kill_remote_job(42) @pytest.mark.unit_test def test_when_submit_job_is_called_then_execute_command_is_called_with_specific_slurm_command( self, remote_env, study ): - # when - output = "output" - error = None - remote_env.connection.execute_command = mock.Mock(return_value=(output, error)) + # the SSH call output should match "Submitted batch job (?P\d+)" + output = "Submitted batch job 456789\n" + remote_env.connection.execute_command = mock.Mock(return_value=(output, "")) remote_env.submit_job(study) # then script_params = ScriptParametersDTO( @@ -209,7 +205,7 @@ def test_when_submit_job_is_called_then_execute_command_is_called_with_specific_ command = remote_env.slurm_script_features.compose_launch_command( remote_env.remote_base_path, script_params ) - remote_env.connection.execute_command.assert_called_with(command) + remote_env.connection.execute_command.assert_called_once_with(command) @pytest.mark.unit_test def test_when_submit_job_is_called_and_receives_submitted_420_returns_job_id_420( @@ -231,98 +227,88 @@ def test_when_submit_job_is_called_and_receives_error_then_exception_is_raised( error = "error" remote_env.connection.execute_command = mock.Mock(return_value=(output, error)) # then - with pytest.raises(SubmitJobErrorException): + with pytest.raises(SubmitJobError): remote_env.submit_job(study) @pytest.mark.unit_test - def 
test_when_check_job_state_is_called_then_execute_command_is_called_with_correct_command( - self, remote_env, study - ): - # given - output = "output" - error = "" - study.submitted = True + def test_get_job_state_flags__sacct_bad_output(self, remote_env, study): study.job_id = 42 - remote_env.connection.execute_command = mock.Mock(return_value=(output, error)) - # when - remote_env.get_job_state_flags(study) - # then - # noinspection SpellCheckingInspection - expected_command = ( - f"sacct -j {study.job_id} -n --format=state | head -1 " - + "| awk -F\" \" '{print $1}'" - ) - remote_env.connection.execute_command.assert_called_with(expected_command) - - @pytest.mark.unit_test - def test_given_submitted_study__when_check_job_state_gets_empty_output_it_tries_5_times_then_raises_exception( - self, remote_env, study - ): - # given - output = "" - error = "" - study.submitted = True - study.job_id = 42 - # when - remote_env.connection.execute_command = mock.Mock(return_value=(output, error)) - # then - with pytest.raises(GetJobStateOutputException): + # the output of `sacct` is not: JobID,JobName,State + output = "the sun is shining" + remote_env.connection.execute_command = mock.Mock(return_value=(output, "")) + with pytest.raises(GetJobStateError, match="non-parsable output") as ctx: remote_env.get_job_state_flags(study) - tries_number = remote_env.connection.execute_command.call_count - assert tries_number == 5 + assert output in str(ctx.value) + command = ( + "sacct" + f" --jobs={study.job_id}" + f" --name={study.name}" + " --format=JobID,JobName,State" + " --parsable2" + " --delimiter=," + " --noheader" + ) + remote_env.connection.execute_command.assert_called_once_with(command) @pytest.mark.unit_test - def test_given_a_submitted_study_when_execute_command_returns_an_error_then_an_exception_is_raised( - self, remote_env, study - ): - # given - output = "output" - error = "error" - study.submitted = True + def test_get_job_state_flags__sacct_call_fails(self, 
remote_env, study): study.job_id = 42 - # when - remote_env.connection.execute_command = mock.Mock(return_value=(output, error)) - # then - with pytest.raises(GetJobStateErrorException): - remote_env.get_job_state_flags(study) + # the `sacct` command fails, output = None (error) + error = "an error occurs" + remote_env.connection.execute_command = mock.Mock(return_value=(None, error)) + with pytest.raises(GetJobStateError, match="an error occurs"): + remote_env.get_job_state_flags(study, attempts=2, sleep_time=0.1) + command = ( + "sacct" + f" --jobs={study.job_id}" + f" --name={study.name}" + " --format=JobID,JobName,State" + " --parsable2" + " --delimiter=," + " --noheader" + ) + remote_env.connection.execute_command.mock_calls = [ + call(command), + call(command), + ] - # noinspection SpellCheckingInspection @pytest.mark.unit_test @pytest.mark.parametrize( - "output,expected_started, expected_finished, expected_with_error", + "state, expected", [ - ("PENDING", False, False, False), - ("RUNNING", True, False, False), - ("CANCELLED BY DUMMY", True, True, True), - ("TIMEOUT", True, True, True), - ("COMPLETED", True, True, False), - ("FAILED DUMMYWORD", True, True, True), + ("", (False, False, False)), + ("PENDING", (False, False, False)), + ("RUNNING", (True, False, False)), + ("CANCELLED", (True, True, True)), + ("TIMEOUT", (True, True, True)), + ("COMPLETED", (True, True, False)), + ("FAILED", (True, True, True)), ], ) - def test_given_state_when_get_job_state_flags_is_called_then_started_and_finished_and_with_error_are_correct( - self, - remote_env, - study, - output, - expected_started, - expected_finished, - expected_with_error, + def test_get_job_state_flags__nominal_case( + self, remote_env, study, state, expected ): - # given - error = "" - study.submitted = True + """ + Check that the "get_job_state_flags" method is correctly returning + the status flags ("started", "finished", and "with_error") + for a SLURM job in a specific state. 
+ """ study.job_id = 42 - remote_env.connection.execute_command = mock.Mock(return_value=(output, error)) - # when - ( - started, - finished, - with_error, - ) = remote_env.get_job_state_flags(study) - # then - assert started is expected_started - assert finished is expected_finished - assert with_error is expected_with_error + # the output of `sacct` should be: JobID,JobName,State + output = f"{study.job_id},{study.name},{state}" if state else "" + remote_env.connection.execute_command = mock.Mock(return_value=(output, "")) + actual = remote_env.get_job_state_flags(study) + assert actual == expected + command = ( + "sacct" + f" --jobs={study.job_id}" + f" --name={study.name}" + " --format=JobID,JobName,State" + " --parsable2" + " --delimiter=," + " --noheader" + ) + remote_env.connection.execute_command.assert_called_once_with(command) @pytest.mark.unit_test @pytest.mark.parametrize( diff --git a/tests/unit/test_ssh_connection.py b/tests/unit/test_ssh_connection.py index 52fc416..4086db3 100644 --- a/tests/unit/test_ssh_connection.py +++ b/tests/unit/test_ssh_connection.py @@ -7,13 +7,12 @@ import paramiko import pytest -from paramiko.sftp_attr import SFTPAttributes - from antareslauncher.remote_environnement.ssh_connection import ( + ConnectionFailedException, DownloadMonitor, SshConnection, - ConnectionFailedException, ) +from paramiko.sftp_attr import SFTPAttributes LOGGER = DownloadMonitor.__module__ @@ -27,16 +26,25 @@ def test_download_monitor__null_size(self, caplog): assert not caplog.text def test_download_monitor__nominal(self, caplog): - total_size = 1000 + """Simulate the downloading of two files""" + sizes = 1000, 1256 # two different sizes + total_size = sum(sizes) monitor = DownloadMonitor(total_size, msg="Downloading 'foo'") with caplog.at_level(level=logging.INFO, logger=LOGGER): - for _ in range(0, total_size, 250): - monitor(250, 0) - time.sleep(0.01) + for size in sizes: + monitor.accumulate() + for transferred in range(250, size + 1, 250): + 
monitor(transferred, 0) + time.sleep(0.01) assert caplog.messages == [ - "Downloading 'foo' ETA: 0s [25%]", - "Downloading 'foo' ETA: 0s [50%]", - "Downloading 'foo' ETA: 0s [75%]", + "Downloading 'foo' ETA: 0s [11%]", + "Downloading 'foo' ETA: 0s [22%]", + "Downloading 'foo' ETA: 0s [33%]", + "Downloading 'foo' ETA: 0s [44%]", + "Downloading 'foo' ETA: 0s [55%]", + "Downloading 'foo' ETA: 0s [66%]", + "Downloading 'foo' ETA: 0s [78%]", + "Downloading 'foo' ETA: 0s [89%]", "Downloading 'foo' ETA: 0s [100%]", ]