Skip to content

Commit

Permalink
v1.3.2
Browse files Browse the repository at this point in the history
Merge pull request #70 from AntaresSimulatorTeam/hotfix/v1.3.2
  • Loading branch information
laurent-laporte-pro authored Apr 11, 2024
2 parents 516f94d + a9324f7 commit e4c3d44
Show file tree
Hide file tree
Showing 7 changed files with 373 additions and 67 deletions.
15 changes: 15 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,21 @@ npx auto-changelog -l false --hide-empty-releases -v v1.3.1 -o CHANGES.out.md
```
-->

## [1.3.2] - 2024-04-11

### Build

- build: add a script to bump the version

### Changed

- feat(ssh): add retry loop around SSH Exceptions [`#68`](https://github.com/AntaresSimulatorTeam/antares-launcher/pull/68)

### Fixes

- fix(retriever): avoid infinite loop if sbatch command fails [`#69`](https://github.com/AntaresSimulatorTeam/antares-launcher/pull/69)


## [1.3.1] - 2023-09-26

### Changed
Expand Down
4 changes: 2 additions & 2 deletions antareslauncher/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@

# Standard project metadata

__version__ = "1.3.1"
__version__ = "1.3.2"
__author__ = "RTE, Antares Web Team"
__date__ = "2023-09-26"
__date__ = "2024-04-11"
# noinspection SpellCheckingInspection
__credits__ = "(c) Réseau de Transport de l’Électricité (RTE)"

Expand Down
197 changes: 134 additions & 63 deletions antareslauncher/remote_environnement/ssh_connection.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import contextlib
import fnmatch
import functools
import logging
import socket
import stat
Expand All @@ -14,14 +15,56 @@
LocalPath = Path


def retry(
    exception: t.Type[Exception],
    *exceptions: t.Type[Exception],
    delay_sec: float = 5,
    max_retry: int = 5,
    msg_fmt: str = "Retrying in {delay_sec} seconds...",
):
    """
    Decorator to retry a function call if it raises an exception.

    The decorated function is called once and, on failure, re-executed up to
    `max_retry` more times (so at most `max_retry + 1` calls in total).
    A warning is logged and the thread sleeps `delay_sec` seconds before each
    retry. The exception raised by the very last call is propagated as is.

    Args:
        exception: The exception to catch.
        exceptions: Additional exceptions to catch.
        delay_sec: The delay (in seconds) between each retry.
        max_retry: The maximum number of retries.
        msg_fmt: The message to display when retrying, with the following format keys:
            - delay_sec: The delay (in seconds) between each retry.
            - remaining: The number of remaining retries, counting the one
              that is about to start.

    Returns:
        The decorated function.
    """
    caught = (exception, *exceptions)

    def decorator(func):  # type: ignore
        @functools.wraps(func)
        def wrapper(*args, **kwargs):  # type: ignore
            for attempt in range(max_retry):
                try:
                    return func(*args, **kwargs)
                except caught:
                    logger = logging.getLogger(__name__)
                    # `attempt` retries have been consumed so far; the retry
                    # announced by this message is still part of the count.
                    remaining = max_retry - attempt
                    logger.warning(msg_fmt.format(delay_sec=delay_sec, remaining=remaining))
                    time.sleep(delay_sec)
            # Last attempt: any exception raised here reaches the caller.
            return func(*args, **kwargs)

        return wrapper

    return decorator


class SshConnectionError(Exception):
    """Base error raised for problems related to the SSH connection."""


class InvalidConfigError(SshConnectionError):
def __init__(self, config, msg=""):
def __init__(self, config: t.Mapping[str, t.Any], msg: str = ""):
err_msg = f"Invalid configuration error {config}"
if msg:
err_msg += f": {msg}"
Expand Down Expand Up @@ -106,7 +149,7 @@ def __str__(self) -> str:
return f"{self.msg:<20} ETA: {eta}s [{rate:.0%}]"
return f"{self.msg:<20} ETA: ??? [{rate:.0%}]"

def accumulate(self):
def accumulate(self) -> None:
"""
Accumulates the quantity transferred by the previous transfer and
the current transfer.
Expand Down Expand Up @@ -151,7 +194,7 @@ def __init__(self, config: t.Mapping[str, t.Any]):
self.initialize_home_dir()
self.logger.info(f"Connection created with host = {self.host} and username = {self.username}")

def _init_public_key(self, key_file_name, key_password):
def _init_public_key(self, key_file_name: str, key_password: str) -> bool:
"""Initialises self.private_key
Args:
Expand Down Expand Up @@ -234,7 +277,7 @@ def ssh_client(self) -> t.Generator[paramiko.SSHClient, None, None]:
self.logger.exception(f"paramiko.AuthenticationException: {paramiko.AuthenticationException}")
raise ConnectionFailedException(self.host, self.port, self.username) from e
except paramiko.SSHException as e:
self.logger.exception(f"paramiko.SSHException: {paramiko.SSHException}")
self.logger.exception(f"Paramiko SSH Exception: {e!r}")
raise ConnectionFailedException(self.host, self.port, self.username) from e
except socket.timeout as e:
self.logger.exception(f"socket.timeout: {socket.timeout}")
Expand All @@ -247,47 +290,75 @@ def ssh_client(self) -> t.Generator[paramiko.SSHClient, None, None]:
finally:
client.close()

def execute_command(self, command: str):
"""Executes a command on the remote host. Puts stderr and stdout in
self.ssh_error and self.ssh_output respectively
def execute_command(self, command: str) -> t.Tuple[t.Optional[str], str]:
"""
Runs an SSH command with a retry logic.
If it encounters an SSH Exception, it's going to sleep for 5 seconds.
The command will then be re-executed a maximum of 5 times.
It allows us to wait for the connection to be re-established.
This way, we avoid having a simulation failure due to an SSH error.
Args:
command: String containing the command that will be executed through the ssh connection
Returns:
output: The standard output of the command
error: The standard error of the command
"""
output = None

try:
with self.ssh_client() as client:
# fmt: off
self.logger.info(f"Running SSH command [{command}]...")
stdin, stdout, stderr = client.exec_command(command, timeout=30)
output = stdout.read().decode("utf-8").strip()
error = stderr.read().decode("utf-8").strip()
self.logger.info(f"SSH command stdout:\n{textwrap.indent(output, 'SSH OUTPUT> ')}")
self.logger.info(f"SSH command stderr:\n{textwrap.indent(error, 'SSH ERROR> ')}")
# fmt: on
output, error = self._exec_command(command)
except socket.timeout:
error = f"SSH command timed out: [{command}]"
self.logger.error(error)
except paramiko.SSHException as e:
error = f"SSH command failed to execute [{command}]: {e}"
self.logger.error(error)
except ConnectionFailedException as e:
error = f"SSH connection failed: {e}"

if error:
self.logger.error(error)

return output, error

@retry(
    socket.timeout,
    paramiko.SSHException,
    ConnectionFailedException,
    delay_sec=5,
    max_retry=5,
    msg_fmt=(
        "An SSH Error occurred, so the command did not succeed."
        " The command will be re-executed {remaining} times until it succeeds."
        " Retrying in {delay_sec} seconds..."
    ),
)
def _exec_command(self, command: str) -> t.Tuple[str, str]:
    """
    Run a single command on the remote host over a fresh SSH client.

    Both remote streams are read to the end, decoded as UTF-8, stripped of
    surrounding whitespace and written to the log before being returned.
    The `retry` decorator re-runs this method when a socket timeout, a
    Paramiko SSH exception or a connection failure is raised.

    Args:
        command: String containing the command that will be executed through the ssh connection

    Returns:
        output: The standard output of the command
        error: The standard error of the command
    """
    with self.ssh_client() as client:
        self.logger.info(f"Running SSH command [{command}]...")
        _stdin, out_stream, err_stream = client.exec_command(command, timeout=30)
        # Keep the original read order: stdout first, then stderr.
        std_out = out_stream.read().decode("utf-8").strip()
        std_err = err_stream.read().decode("utf-8").strip()
        self.logger.info(f"SSH command stdout:\n{textwrap.indent(std_out, 'SSH OUTPUT> ')}")
        self.logger.info(f"SSH command stderr:\n{textwrap.indent(std_err, 'SSH ERROR> ')}")
        return std_out, std_err

def upload_file(self, src: str, dst: str):
"""Uploads a file to a remote server via sftp protocol
Args:
src: Local file to upload
dst: Remote directory where the file will be uploaded
Returns:
Expand All @@ -300,18 +371,18 @@ def upload_file(self, src: str, dst: str):
sftp_client = client.open_sftp()
sftp_client.put(src, dst)
sftp_client.close()
except paramiko.SSHException:
self.logger.debug("Paramiko SSH Exception", exc_info=True)
except paramiko.SSHException as e:
self.logger.debug(f"Paramiko SSH Exception: {e!r}", exc_info=True)
result_flag = False
except IOError:
self.logger.debug("IO Error", exc_info=True)
except IOError as e:
self.logger.debug(f"IO Error: {e!r}", exc_info=True)
result_flag = False
except ConnectionFailedException:
self.logger.error("Failed to connect to remote host", exc_info=True)
except ConnectionFailedException as e:
self.logger.error(f"Failed to connect to remote host: {e!r}", exc_info=True)
result_flag = False
return result_flag

def download_file(self, src: str, dst: str):
def download_file(self, src: str, dst: str) -> bool:
"""Downloads a file from a remote server via sftp protocol
Args:
Expand All @@ -329,21 +400,21 @@ def download_file(self, src: str, dst: str):
sftp_client.get(src, dst)
sftp_client.close()
result_flag = True
except paramiko.SSHException:
self.logger.error("Paramiko SSH Exception", exc_info=True)
except paramiko.SSHException as e:
self.logger.error(f"Paramiko SSH Exception: {e!r}", exc_info=True)
result_flag = False
except ConnectionFailedException:
self.logger.error("Failed to connect to remote host", exc_info=True)
except ConnectionFailedException as e:
self.logger.error(f"Failed to connect to remote host: {e!r}", exc_info=True)
result_flag = False
return result_flag

def download_files(
self,
src_dir: RemotePath,
dst_dir: LocalPath,
pattern: str,
*patterns: str,
remove: bool = True,
self,
src_dir: RemotePath,
dst_dir: LocalPath,
pattern: str,
*patterns: str,
remove: bool = True,
) -> t.Sequence[LocalPath]:
"""
Download files matching the specified patterns from the remote
Expand All @@ -369,20 +440,20 @@ def download_files(
except TimeoutError as exc:
self.logger.error(f"Timeout: {exc}", exc_info=True)
return []
except paramiko.SSHException:
self.logger.error("Paramiko SSH Exception", exc_info=True)
except paramiko.SSHException as e:
self.logger.error(f"Paramiko SSH Exception: {e!r}", exc_info=True)
return []
except ConnectionFailedException:
self.logger.error("Failed to connect to remote host", exc_info=True)
except ConnectionFailedException as e:
self.logger.error(f"Failed to connect to remote host: {e!r}", exc_info=True)
return []

def _download_files(
self,
src_dir: RemotePath,
dst_dir: LocalPath,
patterns: t.Tuple[str, ...],
*,
remove: bool = True,
self,
src_dir: RemotePath,
dst_dir: LocalPath,
patterns: t.Tuple[str, ...],
*,
remove: bool = True,
) -> t.Sequence[LocalPath]:
"""
Download files matching the specified patterns from the remote
Expand Down Expand Up @@ -447,12 +518,12 @@ def check_remote_dir_exists(self, dir_path):
if stat.S_ISDIR(sftp_stat.st_mode):
result_flag = True
else:
raise IOError
raise IOError(f"Not a directory: '{dir_path}'")
except FileNotFoundError:
self.logger.debug("FileNotFoundError", exc_info=True)
result_flag = False
except ConnectionFailedException:
self.logger.error("Failed to connect to remote host", exc_info=True)
except ConnectionFailedException as e:
self.logger.error(f"Failed to connect to remote host: {e!r}", exc_info=True)
result_flag = False
return result_flag

Expand Down Expand Up @@ -482,8 +553,8 @@ def check_file_not_empty(self, file_path):
except FileNotFoundError:
self.logger.debug("FileNotFoundError", exc_info=True)
result_flag = False
except ConnectionFailedException:
self.logger.error("Failed to connect to remote host", exc_info=True)
except ConnectionFailedException as e:
self.logger.error(f"Failed to connect to remote host: {e!r}", exc_info=True)
result_flag = False
return result_flag

Expand Down Expand Up @@ -512,11 +583,11 @@ def make_dir(self, dir_path):
result_flag = True
finally:
sftp_client.close()
except paramiko.SSHException:
self.logger.debug("Paramiko SSHException", exc_info=True)
except paramiko.SSHException as e:
self.logger.debug(f"Paramiko SSH Exception: {e!r}", exc_info=True)
result_flag = False
except ConnectionFailedException:
self.logger.error("Failed to connect to remote host", exc_info=True)
except ConnectionFailedException as e:
self.logger.error(f"Failed to connect to remote host: {e!r}", exc_info=True)
result_flag = False
return result_flag

Expand Down Expand Up @@ -548,8 +619,8 @@ def remove_file(self, file_path):
result_flag = True
finally:
sftp_client.close()
except ConnectionFailedException:
self.logger.error("Failed to connect to remote host", exc_info=True)
except ConnectionFailedException as e:
self.logger.error(f"Failed to connect to remote host: {e!r}", exc_info=True)
result_flag = False
return result_flag

Expand Down Expand Up @@ -581,15 +652,15 @@ def remove_dir(self, dir_path):
result_flag = True
finally:
sftp_client.close()
except ConnectionFailedException:
self.logger.error("Failed to connect to remote host", exc_info=True)
except ConnectionFailedException as e:
self.logger.error(f"Failed to connect to remote host: {e!r}", exc_info=True)
result_flag = False
return result_flag

def test_connection(self):
def test_connection(self) -> bool:
try:
with self.ssh_client():
return True
except ConnectionFailedException:
self.logger.error("Failed to connect to remote host", exc_info=True)
except ConnectionFailedException as e:
self.logger.error(f"Failed to connect to remote host: {e!r}", exc_info=True)
return False
2 changes: 1 addition & 1 deletion antareslauncher/use_cases/retrieve/state_updater.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def run(self, study: StudyDTO) -> None:
Args:
study: The study data transfer object
"""
if not study.done:
if not study.done and not study.with_error:
# set current study job state flags
if study.job_id:
s, f, e = self._env.get_job_state_flags(study)
Expand Down
Loading

0 comments on commit e4c3d44

Please sign in to comment.