Expand connection retry logic #69

Merged: 22 commits (Oct 31, 2023)
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -7,6 +7,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [UNRELEASED]

### Changed

- Expand connection retry logic to cover more cases
- Move exec script into a separate file

## [0.23.0] - 2023-10-20

### Changed
3 changes: 3 additions & 0 deletions codecov.yml
@@ -0,0 +1,3 @@
ignore:
  # This script is read into a string and formatted. Never executed directly.
  - covalent_ssh_plugin/exec.py
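The exclusion above is worth a note: `exec.py` is a template rather than an importable module. Its `{placeholders}` are only filled in at dispatch time, so coverage tooling would never see it executed. A rough sketch of that read-and-format step, with made-up paths (the real substitution lives in `_write_function_files` in `ssh.py` below):

```python
from pathlib import Path

# Read the template and substitute the placeholders (paths here are hypothetical).
template = Path("covalent_ssh_plugin/exec.py").read_text(encoding="utf-8")
script = template.format(
    remote_result_file="/tmp/covalent/result_abc123.pkl",
    remote_function_file="/tmp/covalent/function_abc123.pkl",
    current_remote_workdir="/tmp/covalent/workdir",
)
# `script` is now plain Python, ready to be copied to the remote host and run there.
```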
46 changes: 46 additions & 0 deletions covalent_ssh_plugin/exec.py
@@ -0,0 +1,46 @@
"""
Load task `fn` from pickle. Run it. Save the result.
"""
import os
import sys
from pathlib import Path

result = None
exception = None

# NOTE: Paths must be substituted-in here by the executor.
remote_result_file = Path("{remote_result_file}").resolve()
remote_function_file = Path("{remote_function_file}").resolve()
current_remote_workdir = Path("{current_remote_workdir}").resolve()

try:
# Make sure cloudpickle is available.
import cloudpickle as pickle
except Exception as e:
import pickle

with open(remote_result_file, "wb") as f_out:
pickle.dump((None, e), f_out)
sys.exit(1) # Error.

current_dir = os.getcwd()

# Read the function object and arguments from pickle file.
with open(remote_function_file, "rb") as f_in:
fn, args, kwargs = pickle.load(f_in)

try:
# Execute the task `fn` inside the remote workdir.
current_remote_workdir.mkdir(parents=True, exist_ok=True)
os.chdir(current_remote_workdir)

result = fn(*args, **kwargs)

except Exception as e:
exception = e
finally:
os.chdir(current_dir)

# Save the result to pickle file.
with open(remote_result_file, "wb") as f_out:
pickle.dump((result, exception), f_out)
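Whatever happens remotely, the script always writes a two-element `(result, exception)` tuple, which keeps the consuming side simple. A minimal sketch of reading it back once the file has been copied to the local cache (the path and the re-raise are illustrative assumptions, not the plugin's actual handling):

```python
import cloudpickle as pickle

# Hypothetical local copy of the remote result file.
with open("result_abc123.pkl", "rb") as f_in:
    result, exception = pickle.load(f_in)

if exception is not None:
    raise exception  # surface the remote task's failure locally
```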
151 changes: 81 additions & 70 deletions covalent_ssh_plugin/ssh.py
@@ -66,8 +66,10 @@ class SSHExecutor(RemoteExecutor):
            then the execution is run on the local machine.
        remote_workdir: The working directory on the remote server used for storing files produced from workflows.
        create_unique_workdir: Whether to create unique sub-directories for each node / task / electron.
        poll_freq: Number of seconds to wait for before retrying the result poll
        do_cleanup: Whether to delete all the intermediate files or not
        poll_freq: Number of seconds to wait before retrying the result poll.
        do_cleanup: Whether to delete all intermediate files (if True).
        max_connection_attempts: Maximum number of attempts to establish an SSH connection.
        retry_wait_time: Time to wait (in seconds) before reattempting the connection.
    """

    def __init__(
@@ -85,6 +87,8 @@ def __init__(
        poll_freq: int = 15,
        do_cleanup: bool = True,
        retry_connect: bool = True,
        max_connection_attempts: int = 5,
        retry_wait_time: int = 5,
    ) -> None:

        remote_cache = (
@@ -113,6 +117,8 @@ def __init__(

        self.do_cleanup = do_cleanup
        self.retry_connect = retry_connect
        self.max_connection_attempts = max_connection_attempts
        self.retry_wait_time = retry_wait_time

        ssh_key_file = ssh_key_file or get_config("executors.ssh.ssh_key_file")
        self.ssh_key_file = str(Path(ssh_key_file).expanduser().resolve())
@@ -124,7 +130,7 @@ def _write_function_files(
        args: list,
        kwargs: dict,
        current_remote_workdir: str = ".",
    ) -> None:
    ) -> Tuple[str, str, str, str, str]:
        """
        Helper function to pickle the function to be executed to file, and write the
        python script which calls the function.
@@ -143,53 +149,26 @@ def _write_function_files(
        with open(function_file, "wb") as f_out:
            pickle.dump((fn, args, kwargs), f_out)
        remote_function_file = os.path.join(self.remote_cache, f"function_{operation_id}.pkl")
        remote_result_file = os.path.join(self.remote_cache, f"result_{operation_id}.pkl")

        # Write the code that the remote server will use to execute the function.

        message = f"Function file names:\nLocal function file: {function_file}\n"
        message += f"Remote function file: {remote_function_file}"
        app_log.debug(message)

        remote_result_file = os.path.join(self.remote_cache, f"result_{operation_id}.pkl")
        exec_script = "\n".join(
            [
                "import os",
                "import sys",
                "from pathlib import Path",
                "",
                "result = None",
                "exception = None",
                "",
                "try:",
                "    import cloudpickle as pickle",
                "except Exception as e:",
                "    import pickle",
                f"    with open('{remote_result_file}','wb') as f_out:",
                "        pickle.dump((None, e), f_out)",
                "    exit()",
                "",
                f"with open('{remote_function_file}', 'rb') as f_in:",
                "    fn, args, kwargs = pickle.load(f_in)",
                "    current_dir = os.getcwd()",
                "    try:",
                f"        Path('{current_remote_workdir}').mkdir(parents=True, exist_ok=True)",
                f"        os.chdir('{current_remote_workdir}')",
                "        result = fn(*args, **kwargs)",
                "    except Exception as e:",
                "        exception = e",
                "    finally:",
                "        os.chdir(current_dir)",
                "",
                "",
                f"with open('{remote_result_file}','wb') as f_out:",
                "    pickle.dump((result, exception), f_out)",
                "",
            ]
        )
        exec_blank = Path(__file__).parent / "exec.py"
        script_file = os.path.join(self.cache_dir, f"exec_{operation_id}.py")
        remote_script_file = os.path.join(self.remote_cache, f"exec_{operation_id}.py")
        with open(script_file, "w") as f_out:
            f_out.write(exec_script)

        with open(exec_blank, "r", encoding="utf-8") as f_blank:
            exec_script = f_blank.read().format(
                remote_result_file=remote_result_file,
                remote_function_file=remote_function_file,
                current_remote_workdir=current_remote_workdir,
            )
        with open(script_file, "w", encoding="utf-8") as f_out:
            f_out.write(exec_script)

        return (
            function_file,
@@ -228,9 +207,9 @@ def _on_ssh_fail(
        app_log.error(message)
        raise RuntimeError(message)

    async def _client_connect(self) -> Tuple[bool, asyncssh.SSHClientConnection]:
    async def _client_connect(self) -> Tuple[bool, Optional[asyncssh.SSHClientConnection]]:
        """
        Helper function for connecting to the remote host through the paramiko module.
        Attempts connection to the remote host.

        Args:
            None
@@ -241,35 +220,65 @@ async def _client_connect(self) -> Tuple[bool, asyncssh.SSHClientConnection]:

        ssh_success = False
        conn = None
        if os.path.exists(self.ssh_key_file):
            retries = 6 if self.retry_connect else 1
            for _ in range(retries):
                try:
                    conn = await asyncssh.connect(
                        self.hostname,
                        username=self.username,
                        client_keys=[self.ssh_key_file],
                        known_hosts=None,
                    )

                    ssh_success = True
                except (socket.gaierror, ValueError, TimeoutError, ConnectionRefusedError) as e:
                    app_log.error(e)
        if not os.path.exists(self.ssh_key_file):
            message = f"SSH key file {self.ssh_key_file} does not exist."
            app_log.error(message)
            raise RuntimeError(message)

                if conn is not None:
                    break
        try:
            conn = await self._attempt_client_connect()
            ssh_success = conn is not None
        except (socket.gaierror, ValueError, TimeoutError) as e:
            app_log.error(e)

                await asyncio.sleep(5)
        return ssh_success, conn

            if conn is None and not self.run_local_on_ssh_fail:
                raise RuntimeError("Could not connect to remote host.")
    async def _attempt_client_connect(self) -> Optional[asyncssh.SSHClientConnection]:
        """
        Helper function that catches specific errors and retries connecting to the remote host.

        else:
            message = f"no SSH key file found at {self.ssh_key_file}. Cannot connect to host."
            app_log.error(message)
            raise RuntimeError(message)
        Args:
            None. The attempt limit comes from `self.max_connection_attempts`.

        return ssh_success, conn
        Returns:
            An `SSHClientConnection` object if successful, None otherwise.
        """

        # Retry connecting if any of these errors happen:
        _retry_errs = (
            ConnectionRefusedError,
            OSError,  # e.g. Network unreachable
        )

        address = f"{self.username}@{self.hostname}"
        attempt_max = self.max_connection_attempts

        attempt = 0
        while attempt < attempt_max:

            try:
                # Exit here if the connection is successful.
                return await asyncssh.connect(
                    self.hostname,
                    username=self.username,
                    client_keys=[self.ssh_key_file],
                    known_hosts=None,
                )
            except _retry_errs as err:

                if not self.retry_connect:
                    app_log.error(f"{err} ({address} | retry disabled).")
                    raise err

                app_log.warning(f"{err} ({address} | retry {attempt+1}/{attempt_max})")
                await asyncio.sleep(self.retry_wait_time)

            finally:
                attempt += 1

        # Failed to connect to client.
        return None

    async def cleanup(
        self,
@@ -290,7 +299,7 @@ async def cleanup(
            script_file: Path to the script file to be deleted locally
            result_file: Path to the result file to be deleted locally
            remote_function_file: Path to the function file to be deleted on remote
            remote_script_file: Path to the sccript file to be deleted on remote
            remote_script_file: Path to the script file to be deleted on remote
            remote_result_file: Path to the result file to be deleted on remote

        Returns:
@@ -540,9 +549,11 @@ async def run(
        app_log.debug("Running function file in remote machine...")
        result = await self.submit_task(conn, remote_script_file)

        if result_err := result.stderr.strip():
            app_log.warning(result_err)
            return self._on_ssh_fail(function, args, kwargs, result_err)
        if result.exit_status != 0:
            message = result.stderr.strip()
            message = message or f"Task exited with nonzero exit status {result.exit_status}."
            app_log.warning(message)
            return self._on_ssh_fail(function, args, kwargs, message)

        if not await self._poll_task(conn, remote_result_file):
            message = (
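Taken together, the new knobs are set when the executor is constructed. A hedged usage sketch, assuming `SSHExecutor` is importable from the package root and that `username`, `hostname`, and `ssh_key_file` keep their existing meanings (host details below are placeholders):

```python
import covalent as ct
from covalent_ssh_plugin import SSHExecutor

executor = SSHExecutor(
    username="ubuntu",                 # placeholder credentials
    hostname="remote.example.com",
    ssh_key_file="~/.ssh/id_rsa",
    retry_connect=True,                # keep retrying transient connection errors
    max_connection_attempts=10,        # new in this PR: give up after 10 attempts
    retry_wait_time=30,                # new in this PR: wait 30 s between attempts
)

@ct.electron(executor=executor)
def echo(x):
    return x
```

With `retry_connect=False`, the first `ConnectionRefusedError` or `OSError` is re-raised immediately instead of being retried.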
20 changes: 20 additions & 0 deletions tests/functional_tests/basic_workflow_test.py
@@ -27,3 +27,23 @@ def basic_workflow(a, b):
    print(result)

    assert status == str(ct.status.COMPLETED)


@pytest.mark.functional_tests
def test_basic_workflow_failure():
    @ct.electron(executor="ssh")
    def join_words(a, b):
        raise Exception(f"{', '.join([a, b])} -- but something went wrong!")

    @ct.lattice
    def basic_workflow_that_will_fail(a, b):
        return join_words(a, b)

    # Dispatch the workflow
    dispatch_id = ct.dispatch(basic_workflow_that_will_fail)("Hello", "World")
    result = ct.get_result(dispatch_id=dispatch_id, wait=True)
    status = str(result.status)

    print(result)

    assert status == str(ct.status.FAILED)