Skip to content

Commit ead8f86

Browse files
committed
Allow to disable ssh agent usage
1 parent 20b45e0 commit ead8f86

File tree

5 files changed

+188
-83
lines changed

5 files changed

+188
-83
lines changed

doc/source/SSHClient.rst

+9-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ API: SSHClient and SSHAuth.
1010
1111
SSHClient helper.
1212

13-
.. py:method:: __init__(host, port=22, username=None, password=None, *, auth=None, verbose=True, ssh_config=None, ssh_auth_map=None, sock=None, keepalive=1)
13+
.. py:method:: __init__(host, port=22, username=None, password=None, *, auth=None, verbose=True, ssh_config=None, ssh_auth_map=None, sock=None, keepalive=1, allow_ssh_agent=True)
1414
1515
:param host: remote hostname
1616
:type host: ``str``
@@ -32,6 +32,8 @@ API: SSHClient and SSHAuth.
3232
:type sock: paramiko.ProxyCommand | paramiko.Channel | socket.socket | None
3333
:param keepalive: keepalive period
3434
:type keepalive: int | bool
35+
:param allow_ssh_agent: use SSH Agent if available
36+
:type allow_ssh_agent: bool
3537

3638
.. note:: auth has priority over username/password/private_keys
3739
.. note::
@@ -49,6 +51,7 @@ API: SSHClient and SSHAuth.
4951
.. versionchanged:: 7.0.0 private_keys is removed
5052
.. versionchanged:: 7.0.0 keepalive_mode is removed
5153
.. versionchanged:: 7.4.0 return of keepalive_mode to prevent mix with keepalive period. Default is `False`
54+
.. versionchanged:: 8.0.0 expose SSH Agent usage override
5255

5356
.. py:attribute:: log_mask_re
5457
@@ -102,6 +105,11 @@ API: SSHClient and SSHAuth.
102105
``int | bool``
103106
Keepalive period for connection object.
104107

108+
.. py:attribute:: use_ssh_agent
109+
110+
``bool``
111+
Use SSH Agent if available.
112+
105113
.. py:method:: close()
106114
107115
Close connection

exec_helpers/_ssh_base.py

+92-65
Original file line numberDiff line numberDiff line change
@@ -454,6 +454,8 @@ class SSHClientBase(api.ExecHelper):
454454
:type sock: paramiko.ProxyCommand | paramiko.Channel | socket.socket | None
455455
:param keepalive: keepalive period
456456
:type keepalive: int | bool
457+
:param allow_ssh_agent: use SSH Agent if available
458+
:type allow_ssh_agent: bool
457459
458460
.. note:: auth has priority over username/password/private_keys
459461
.. note::
@@ -471,6 +473,7 @@ class SSHClientBase(api.ExecHelper):
471473
.. versionchanged:: 7.0.0 private_keys is removed
472474
.. versionchanged:: 7.0.0 keepalive_mode is removed
473475
.. versionchanged:: 7.4.0 return of keepalive_mode to prevent mix with keepalive period. Default is `False`
476+
.. versionchanged:: 8.0.0 expose SSH Agent usage override
474477
"""
475478

476479
__slots__ = (
@@ -486,6 +489,7 @@ class SSHClientBase(api.ExecHelper):
486489
"__ssh_config",
487490
"__sock",
488491
"__conn_chain",
492+
"__allow_agent",
489493
)
490494

491495
def __hash__(self) -> int:
@@ -509,8 +513,19 @@ def __init__(
509513
ssh_auth_map: dict[str, ssh_auth.SSHAuth] | ssh_auth.SSHAuthMapping | None = None,
510514
sock: paramiko.ProxyCommand | paramiko.Channel | socket.socket | None = None,
511515
keepalive: KeepAlivePeriodT = 1,
516+
allow_ssh_agent: bool = True,
512517
) -> None:
513518
"""Main SSH Client helper."""
519+
self.__sudo_mode = False
520+
self.__keepalive_period: int = int(keepalive)
521+
self.__keepalive_mode = False
522+
self.__verbose: bool = verbose
523+
self.__sock = sock
524+
525+
self.__ssh: paramiko.SSHClient
526+
self.__sftp: paramiko.SFTPClient | None = None
527+
self.__allow_agent = allow_ssh_agent
528+
514529
# Init ssh config. It's main source for connection parameters
515530
if isinstance(ssh_config, _ssh_helpers.HostsSSHConfigs):
516531
self.__ssh_config: _ssh_helpers.HostsSSHConfigs = ssh_config
@@ -533,35 +548,25 @@ def __init__(
533548
if self.hostname not in self.__auth_mapping and host in self.__auth_mapping:
534549
self.__auth_mapping[self.hostname] = self.__auth_mapping[host]
535550

536-
self.__sudo_mode = False
537-
self.__keepalive_period: int = int(keepalive)
538-
self.__keepalive_mode = False
539-
self.__verbose: bool = verbose
540-
self.__sock = sock
541-
542-
self.__ssh: paramiko.SSHClient
543-
self.__sftp: paramiko.SFTPClient | None = None
544-
545551
# Rebuild SSHAuth object if required.
546552
# Priority: auth > credentials > auth mapping
547-
if auth is not None:
548-
self.__auth_mapping[self.hostname] = real_auth = copy.copy(auth)
549-
elif self.hostname not in self.__auth_mapping or any((username, password)):
550-
self.__auth_mapping[self.hostname] = real_auth = ssh_auth.SSHAuth(
551-
username=username if username is not None else config.user,
552-
password=password,
553-
key_filename=config.identityfile,
554-
)
555-
else:
556-
real_auth = self.__auth_mapping[self.hostname]
553+
real_auth = self.__handle_explicit_auth(
554+
username=username,
555+
config_username=config.user,
556+
password=password,
557+
auth=auth,
558+
key_filename=config.identityfile,
559+
)
557560

558561
# Init super with host and real port and username
559562
mod_name = "exec_helpers" if self.__module__.startswith("exec_helpers") else self.__module__
560563
log_username: str = real_auth.username if real_auth.username is not None else getpass.getuser()
561564

562565
super().__init__(
563-
logger=logging.getLogger(f"{mod_name}.{self.__class__.__name__}").getChild(
564-
f"({log_username}@{host}:{self.port})"
566+
logger=logging.getLogger(
567+
f"{mod_name}.{self.__class__.__name__}",
568+
).getChild(
569+
f"({log_username}@{host}:{self.port})",
565570
)
566571
)
567572

@@ -577,6 +582,26 @@ def __init__(
577582

578583
self.__connect()
579584

585+
def __handle_explicit_auth(
586+
self,
587+
*,
588+
username: str | None,
589+
config_username: str | None,
590+
password: str | None,
591+
auth: ssh_auth.SSHAuth | None,
592+
key_filename: Iterable[str] | None,
593+
) -> ssh_auth.SSHAuth:
594+
if auth is not None:
595+
self.__auth_mapping[self.hostname] = auth
596+
elif self.hostname not in self.__auth_mapping or any((username, password)):
597+
self.__auth_mapping[self.hostname] = ssh_auth.SSHAuth(
598+
username=username if username is not None else config_username,
599+
password=password,
600+
key_filename=key_filename,
601+
)
602+
603+
return self.__auth_mapping[self.hostname]
604+
580605
def __rebuild_ssh_config(self) -> None:
581606
"""Rebuild main ssh config from available information."""
582607
self.__ssh_config[self.hostname] = self.__ssh_config[self.hostname].overridden_by(
@@ -598,7 +623,11 @@ def __build_connection_chain(self) -> list[tuple[_ssh_helpers.SSHConfig, ssh_aut
598623

599624
config = self.ssh_config[self.hostname]
600625
default_auth = ssh_auth.SSHAuth(username=config.user, key_filename=config.identityfile)
601-
auth = self.__auth_mapping.get_with_alt_hostname(config.hostname, self.hostname, default=default_auth)
626+
auth = self.__auth_mapping.get_with_alt_hostname(
627+
config.hostname,
628+
self.hostname,
629+
default=default_auth,
630+
)
602631
conn_chain.append((config, auth))
603632

604633
while config.proxyjump is not None:
@@ -621,6 +650,15 @@ def auth(self) -> ssh_auth.SSHAuth:
621650
"""
622651
return self.__auth_mapping[self.hostname]
623652

653+
@property
654+
def allow_ssh_agent(self) -> bool:
655+
"""Use SSH Agent if available.
656+
657+
:return: SSH Agent usage allowed
658+
:rtype: bool
659+
"""
660+
return self.__allow_agent
661+
624662
@property
625663
def hostname(self) -> str:
626664
"""Connected remote host name.
@@ -714,16 +752,15 @@ def __connect(self) -> None:
714752
"""Main method for connection open."""
715753
with self.lock:
716754
if self.__sock is not None:
717-
sock = self.__sock
718-
719755
self.__ssh = paramiko.SSHClient()
720756
self.__ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
721757
self.auth.connect(
722758
client=self.__ssh,
723759
hostname=self.hostname,
724760
port=self.port,
725761
log=self.__verbose,
726-
sock=sock,
762+
sock=self.__sock,
763+
allow_ssh_agent=self.allow_ssh_agent,
727764
)
728765
else:
729766
self.__ssh = self.__get_client()
@@ -745,15 +782,14 @@ def __get_client(self) -> paramiko.SSHClient:
745782
last_ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
746783

747784
config, auth = self.__conn_chain[0]
748-
if config.proxycommand:
749-
auth.connect(
750-
last_ssh_client,
751-
hostname=config.hostname,
752-
port=config.port or 22,
753-
sock=paramiko.ProxyCommand(config.proxycommand),
754-
)
755-
else:
756-
auth.connect(last_ssh_client, hostname=config.hostname, port=config.port or 22)
785+
786+
auth.connect(
787+
last_ssh_client,
788+
hostname=config.hostname,
789+
port=config.port or 22,
790+
sock=paramiko.ProxyCommand(config.proxycommand) if config.proxycommand else None,
791+
allow_ssh_agent=self.allow_ssh_agent,
792+
)
757793

758794
for config, auth in self.__conn_chain[1:]: # start has another logic, so do it out of cycle
759795
ssh = paramiko.SSHClient()
@@ -768,7 +804,13 @@ def __get_client(self) -> paramiko.SSHClient:
768804
dest_addr=(config.hostname, config.port or 22),
769805
src_addr=(config.proxyjump, 0),
770806
)
771-
auth.connect(ssh, hostname=config.hostname, port=config.port or 22, sock=sock)
807+
auth.connect(
808+
ssh,
809+
hostname=config.hostname,
810+
port=config.port or 22,
811+
sock=sock,
812+
allow_ssh_agent=self.allow_ssh_agent,
813+
)
772814
last_ssh_client = ssh
773815
continue
774816

@@ -1421,33 +1463,6 @@ def check_stderr(
14211463
**kwargs,
14221464
)
14231465

1424-
def _get_proxy_channel(
1425-
self,
1426-
port: int | None,
1427-
ssh_config: _ssh_helpers.SSHConfig,
1428-
) -> paramiko.Channel:
1429-
"""Get ssh proxy channel.
1430-
1431-
:param port: target port
1432-
:type port: int | None
1433-
:param ssh_config: pre-parsed ssh config
1434-
:type ssh_config: SSHConfig
1435-
:return: ssh channel for usage as socket for new connection over it
1436-
:rtype: paramiko.Channel
1437-
1438-
.. versionadded:: 6.0.0
1439-
"""
1440-
if port is not None:
1441-
dest_port: int = port
1442-
else:
1443-
dest_port = ssh_config.port if ssh_config.port is not None else 22
1444-
1445-
return self._ssh_transport.open_channel(
1446-
kind="direct-tcpip",
1447-
dest_addr=(ssh_config.hostname, dest_port),
1448-
src_addr=(self.hostname, 0),
1449-
)
1450-
14511466
def proxy_to(
14521467
self,
14531468
host: str,
@@ -1498,13 +1513,25 @@ def proxy_to(
14981513
else:
14991514
parsed_ssh_config = _ssh_helpers.parse_ssh_config(ssh_config, host)
15001515

1501-
hostname = parsed_ssh_config[host].hostname
1516+
host_config = parsed_ssh_config[host]
1517+
1518+
if port is not None:
1519+
dest_port: int = port
1520+
elif host_config.port is not None:
1521+
dest_port = host_config.port
1522+
else:
1523+
dest_port = 22
1524+
1525+
sock: paramiko.Channel = self._ssh_transport.open_channel(
1526+
kind="direct-tcpip",
1527+
dest_addr=(host_config.hostname, dest_port),
1528+
src_addr=(self.hostname, 0),
1529+
)
15021530

1503-
sock: paramiko.Channel = self._get_proxy_channel(port=port, ssh_config=parsed_ssh_config[hostname])
15041531
cls: type[Self] = self.__class__
15051532
return cls(
15061533
host=host,
1507-
port=port,
1534+
port=dest_port,
15081535
username=username,
15091536
password=password,
15101537
auth=auth,

exec_helpers/ssh_auth.py

+4
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,7 @@ def connect(
163163
log: bool = True,
164164
*,
165165
sock: paramiko.ProxyCommand | paramiko.Channel | socket.socket | None = None,
166+
allow_ssh_agent: bool = True,
166167
) -> None:
167168
"""Connect SSH client object using credentials.
168169
@@ -176,6 +177,8 @@ def connect(
176177
:type log: bool
177178
:param sock: socket for connection. Useful for ssh proxies support
178179
:type sock: paramiko.ProxyCommand | paramiko.Channel | socket.socket | None
180+
:param allow_ssh_agent: use SSH Agent if available
181+
:type allow_ssh_agent: bool
179182
:raises PasswordRequiredException: No password has been set, but required.
180183
:raises AuthenticationException: Authentication failed.
181184
"""
@@ -196,6 +199,7 @@ def connect(
196199
username=self.username,
197200
password=self.__password,
198201
key_filename=self.__key_filename, # type: ignore[arg-type] # types verified by not signature
202+
allow_agent=allow_ssh_agent,
199203
**kwargs,
200204
)
201205
if index != self.__key_index:

test/test_ssh_client_init_basic.py

+17-3
Original file line numberDiff line numberDiff line change
@@ -56,13 +56,21 @@ def __iter__(self):
5656
"username_password": {"host": host, "username": "user", "password": "password"},
5757
"auth": {
5858
"host": host,
59-
"auth": exec_helpers.SSHAuth(username="user", password="password", key=gen_private_keys(1).pop()),
59+
"auth": exec_helpers.SSHAuth(
60+
username="user",
61+
password="password",
62+
key=gen_private_keys(1).pop(),
63+
),
6064
},
6165
"auth_break": {
6266
"host": host,
6367
"username": "Invalid",
6468
"password": "Invalid",
65-
"auth": exec_helpers.SSHAuth(username="user", password="password", key=gen_private_keys(1).pop()),
69+
"auth": exec_helpers.SSHAuth(
70+
username="user",
71+
password="password",
72+
key=gen_private_keys(1).pop(),
73+
),
6674
},
6775
}
6876

@@ -90,7 +98,12 @@ def run_parameters(request):
9098
return configs[request.param]
9199

92100

93-
def test_init_base(paramiko_ssh_client, auto_add_policy, run_parameters, ssh_auth_logger):
101+
def test_init_base(
102+
paramiko_ssh_client,
103+
auto_add_policy,
104+
run_parameters,
105+
ssh_auth_logger,
106+
):
94107
"""Parametrized validation of SSH client initialization."""
95108
# Helper code
96109
_ssh = mock.call
@@ -117,6 +130,7 @@ def test_init_base(paramiko_ssh_client, auto_add_policy, run_parameters, ssh_aut
117130
port=port,
118131
username=username,
119132
key_filename=(),
133+
allow_agent=True,
120134
),
121135
_ssh.get_transport(),
122136
_ssh.get_transport().set_keepalive(1),

0 commit comments

Comments
 (0)