diff --git a/juju/machine.py b/juju/machine.py index 60554fd09..b0d77ef65 100644 --- a/juju/machine.py +++ b/juju/machine.py @@ -3,14 +3,15 @@ import ipaddress import logging +import typing import pyrfc3339 -from . import model, tag, jasyncio +from . import jasyncio, model, tag from .annotationhelper import _get_annotations, _set_annotations from .client import client from .errors import JujuError -from juju.utils import juju_ssh_key_paths +from juju.utils import juju_ssh_key_paths, block_until log = logging.getLogger(__name__) @@ -70,7 +71,7 @@ def _format_addr(self, addr): return fmt.format(ipaddr) async def scp_to(self, source, destination, user='ubuntu', proxy=False, - scp_opts=''): + scp_opts='', wait_for_active=False, timeout=None): """Transfer files to this machine. :param str source: Local path of file(s) to transfer @@ -79,10 +80,13 @@ async def scp_to(self, source, destination, user='ubuntu', proxy=False, :param bool proxy: Proxy through the Juju API server :param scp_opts: Additional options to the `scp` command :type scp_opts: str or list + :param bool wait_for_active: Wait until the machine is ready to take in ssh commands. + :param int timeout: Time in seconds to wait until the machine becomes ready. """ if proxy: raise NotImplementedError('proxy option is not implemented') - + if wait_for_active: + await block_until(lambda: self.addresses, timeout=timeout) try: # if dns_name is an IP address format it appropriately address = self._format_addr(self.dns_name) @@ -93,7 +97,7 @@ async def scp_to(self, source, destination, user='ubuntu', proxy=False, await self._scp(source, destination, scp_opts) async def scp_from(self, source, destination, user='ubuntu', proxy=False, - scp_opts=''): + scp_opts='', wait_for_active=False, timeout=None): """Transfer files from this machine. :param str source: Remote path of file(s) to transfer @@ -102,10 +106,13 @@ async def scp_from(self, source, destination, user='ubuntu', proxy=False, :param bool proxy: Proxy through the Juju API server :param scp_opts: Additional options to the `scp` command :type scp_opts: str or list + :param bool wait_for_active: Wait until the machine is ready to take in ssh commands. + :param int timeout: Time in seconds to wait until the machine becomes ready. """ if proxy: raise NotImplementedError('proxy option is not implemented') - + if wait_for_active: + await block_until(lambda: self.addresses, timeout=timeout) try: # if dns_name is an IP address format it appropriately address = self._format_addr(self.dns_name) @@ -129,23 +136,37 @@ async def _scp(self, source, destination, scp_opts): ] cmd.extend(scp_opts.split() if isinstance(scp_opts, str) else scp_opts) cmd.extend([source, destination]) - process = await jasyncio.create_subprocess_exec(*cmd) - await process.wait() + # There's a bit of a gap between the time that the machine is assigned an IP and the ssh + # service is up and listening, which creates a race for the ssh command. So we retry a + # couple of times until either we run out of attempts, or the ssh command succeeds to + # mitigate that effect. + # TODO (cderici): refactor the ssh and scp subcommand processing into a single method. + retry_backoff = 2 + retries = 10 + for _ in range(retries): + process = await jasyncio.create_subprocess_exec(*cmd) + await process.wait() + if process.returncode == 0: + break + await jasyncio.sleep(retry_backoff) if process.returncode != 0: - raise JujuError("command failed: %s" % cmd) + raise JujuError(f"command failed after {retries} attempts: {cmd}") async def ssh( - self, command, user='ubuntu', proxy=False, ssh_opts=None): + self, command, user='ubuntu', proxy=False, ssh_opts=None, wait_for_active=False, timeout=None): """Execute a command over SSH on this machine. :param str command: Command to execute :param str user: Remote username :param bool proxy: Proxy through the Juju API server :param str ssh_opts: Additional options to the `ssh` command - + :param bool wait_for_active: Wait until the machine is ready to take in ssh commands. + :param int timeout: Time in seconds to wait until the machine becomes ready. """ if proxy: raise NotImplementedError('proxy option is not implemented') + if wait_for_active: + await block_until(lambda: self.addresses, timeout=timeout) address = self.dns_name destination = "{}@{}".format(user, address) _, id_path = juju_ssh_key_paths() @@ -159,14 +180,32 @@ async def ssh( if ssh_opts: cmd.extend(ssh_opts.split() if isinstance(ssh_opts, str) else ssh_opts) cmd.extend([command]) - process = await jasyncio.create_subprocess_exec( - *cmd, stdout=jasyncio.subprocess.PIPE, stderr=jasyncio.subprocess.PIPE) - stdout, stderr = await process.communicate() + + # There's a bit of a gap between the time that the machine is assigned an IP and the ssh + # service is up and listening, which creates a race for the ssh command. So we retry a + # couple of times until either we run out of attempts, or the ssh command succeeds to + # mitigate that effect. + retry_backoff = 2 + retries = 10 + for _ in range(retries): + process = await jasyncio.create_subprocess_exec( + *cmd, stdout=jasyncio.subprocess.PIPE, stderr=jasyncio.subprocess.PIPE) + stdout, stderr = await process.communicate() + if process.returncode == 0: + break + await jasyncio.sleep(retry_backoff) if process.returncode != 0: - raise JujuError("command failed: %s with %s" % (cmd, stderr.decode())) + raise JujuError(f"command failed: {cmd} after {retries} attempts, with {stderr.decode()}") # stdout is a bytes-like object, returning a string might be more useful return stdout.decode() + @property + def addresses(self) -> typing.List[str]: + """Returns the machine addresses. + + """ + return self.safe_data['addresses'] or [] + @property def agent_status(self): """Returns the current Juju agent status string. @@ -221,11 +260,10 @@ def dns_name(self): May return None if no suitable address is found. """ - addresses = self.safe_data['addresses'] or [] ordered_addresses = [] ordered_scopes = ['public', 'local-cloud', 'local-fan'] for scope in ordered_scopes: - for address in addresses: + for address in self.addresses: if scope == address['scope']: ordered_addresses.append(address) for address in ordered_addresses: diff --git a/tests/integration/test_machine.py b/tests/integration/test_machine.py index f0f82c188..1a6804152 100644 --- a/tests/integration/test_machine.py +++ b/tests/integration/test_machine.py @@ -6,6 +6,7 @@ import pytest from .. import base +from juju.machine import Machine @base.bootstrapped @@ -36,3 +37,12 @@ async def test_status(): machine.status_message.lower() == 'running' and machine.agent_status == 'started')), timeout=480) + + +@base.bootstrapped +async def test_machine_ssh(): + async with base.CleanModel() as model: + machine: Machine = await model.add_machine() + out = await machine.ssh("echo hello world!", wait_for_active=True) + + assert out == "hello world!\n"