Skip to content

Commit

Permalink
refactor: make dvc-ssh compatible with asyncssh>=2.19.0 (#100)
Browse files Browse the repository at this point in the history
  • Loading branch information
skshetry authored Feb 5, 2025
1 parent cdefaf5 commit e452803
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 150 deletions.
61 changes: 20 additions & 41 deletions dvc_ssh/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,27 @@
import threading
from typing import ClassVar

from funcy import memoize, silent, wrap_prop, wrap_with
from funcy import memoize, wrap_prop, wrap_with

from dvc.utils.objects import cached_property
from dvc_objects.fs.base import FileSystem
from dvc_objects.fs.utils import as_atomic

DEFAULT_PORT = 22


@wrap_with(threading.Lock())
@memoize
def ask_password(host, user, port, desc):
prompt = f"Enter a {desc} for"
if user:
prompt += f" {user}@{host}"
else:
prompt += f" {host}"
if port:
prompt += f":{port}"
prompt += ":\n"

try:
return getpass.getpass(
f"Enter a {desc} for " f"host '{host}' port '{port}' user '{user}':\n"
)
return getpass.getpass(prompt)
except EOFError:
return None

Expand All @@ -41,8 +46,6 @@ def unstrip_protocol(self, path: str) -> str:
return f"ssh://{host}:{port}/{path}"

def _prepare_credentials(self, **config):
from sshfs.config import parse_config

from .client import InteractiveSSHClient

self.CAN_TRAVERSE = True
Expand All @@ -51,44 +54,25 @@ def _prepare_credentials(self, **config):
login_info["client_factory"] = config.get(
"client_factory", InteractiveSSHClient
)
try:
user_ssh_config = parse_config(host=config["host"])
except FileNotFoundError:
user_ssh_config = {}

login_info["host"] = user_ssh_config.get("Hostname", config["host"])

login_info["username"] = (
config.get("user")
or config.get("username")
or user_ssh_config.get("User")
or getpass.getuser()
)
login_info["port"] = (
config.get("port")
or silent(int)(user_ssh_config.get("Port"))
or DEFAULT_PORT
)
login_info["host"] = config["host"]
if username := (config.get("user") or config.get("username")):
login_info["username"] = username
if port := config.get("port"):
login_info["port"] = port

for option in ("password", "passphrase"):
login_info[option] = config.get(option)

if config.get(f"ask_{option}") and login_info[option] is None:
login_info[option] = ask_password(
login_info["host"],
login_info["username"],
login_info["port"],
login_info.get("username"),
login_info.get("port"),
option,
)

raw_keys = []
if config.get("keyfile"):
raw_keys.append(config.get("keyfile"))
elif user_ssh_config.get("IdentityFile"):
raw_keys.extend(user_ssh_config.get("IdentityFile"))

if raw_keys:
login_info["client_keys"] = [os.path.expanduser(key) for key in raw_keys]
if keyfile := config.get("keyfile"):
login_info["client_keys"] = [os.path.expanduser(keyfile)]

login_info["timeout"] = config.get("timeout", 1800)

Expand All @@ -108,11 +92,6 @@ def _prepare_credentials(self, **config):

login_info["gss_auth"] = config.get("gss_auth", False)
login_info["agent_forwarding"] = config.get("agent_forwarding", True)
login_info["proxy_command"] = user_ssh_config.get("ProxyCommand")

# We are going to automatically add stuff to known_hosts
# something like paramiko's AutoAddPolicy()
login_info["known_hosts"] = None

if "max_sessions" in config:
login_info["max_sessions"] = config["max_sessions"]
Expand Down
120 changes: 11 additions & 109 deletions dvc_ssh/tests/test_fs.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
# pylint: disable=W0212,W0613
import getpass
import os
from unittest.mock import mock_open, patch

import pytest

from dvc_ssh import DEFAULT_PORT, SSHFileSystem
from dvc_ssh import SSHFileSystem


def test_get_kwargs_from_urls():
Expand Down Expand Up @@ -67,97 +65,24 @@ def test_passphrase(mocker, password, passphrase):
assert connect.call_args[1]["passphrase"] == passphrase


mock_ssh_config = """
Host example.com
User ubuntu
HostName 1.2.3.4
Port 1234
IdentityFile ~/.ssh/not_default.key
"""


@pytest.mark.parametrize(
"config,expected_host",
[
({"host": "example.com"}, "1.2.3.4"),
({"host": "not_in_ssh_config.com"}, "not_in_ssh_config.com"),
],
)
@patch(
"builtins.open",
new_callable=mock_open,
read_data=mock_ssh_config,
)
def test_ssh_host_override_from_config(mock_file, config, expected_host):
fs = SSHFileSystem(**config)
assert fs.fs_args["host"] == expected_host


@pytest.mark.parametrize(
"config,expected_user",
[
({"host": "example.com", "user": "test1"}, "test1"),
({"host": "example.com"}, "ubuntu"),
({"host": "not_in_ssh_config.com", "user": "test1"}, "test1"),
({"host": "not_in_ssh_config.com"}, getpass.getuser()),
],
)
@patch(
"builtins.open",
new_callable=mock_open,
read_data=mock_ssh_config,
)
def test_ssh_user(mock_file, config, expected_user):
fs = SSHFileSystem(**config)
assert fs.fs_args["username"] == expected_user
def test_ssh_user():
fs = SSHFileSystem(host="example.com", user="test")
assert fs.fs_args["username"] == "test"


@pytest.mark.parametrize(
"config,expected_port",
[
({"host": "example.com"}, 1234),
({"host": "example.com", "port": 4321}, 4321),
({"host": "not_in_ssh_config.com"}, DEFAULT_PORT),
({"host": "not_in_ssh_config.com", "port": 2222}, 2222),
],
)
@patch(
"builtins.open",
new_callable=mock_open,
read_data=mock_ssh_config,
)
def test_ssh_port(mock_file, config, expected_port):
fs = SSHFileSystem(**config)
assert fs.fs_args["port"] == expected_port
def test_ssh_port():
fs = SSHFileSystem(host="example.com", port=4321)
assert fs.fs_args["port"] == 4321


@pytest.mark.parametrize(
"config,expected_keyfile",
[
(
{"host": "example.com", "keyfile": "dvc_config.key"},
["dvc_config.key"],
),
(
{"host": "example.com"},
["~/.ssh/not_default.key"],
),
(
{
"host": "not_in_ssh_config.com",
"keyfile": "dvc_config.key",
},
["dvc_config.key"],
),
({"host": "not_in_ssh_config.com"}, None),
({"host": "example.com", "keyfile": "dvc_config.key"}, ["dvc_config.key"]),
({"host": "example.com"}, None),
],
)
@patch(
"builtins.open",
new_callable=mock_open,
read_data=mock_ssh_config,
)
def test_ssh_keyfile(mock_file, config, expected_keyfile):
def test_ssh_keyfile(config, expected_keyfile):
fs = SSHFileSystem(**config)
expected_keyfiles = (
[os.path.expanduser(path) for path in expected_keyfile]
Expand All @@ -167,24 +92,6 @@ def test_ssh_keyfile(mock_file, config, expected_keyfile):
assert fs.fs_args.get("client_keys") == expected_keyfiles


mock_ssh_multi_key_config = """
IdentityFile file_1
Host example.com
IdentityFile file_2
"""


@patch(
"builtins.open",
new_callable=mock_open,
read_data=mock_ssh_multi_key_config,
)
def test_ssh_multi_identity_files(mock_file):
fs = SSHFileSystem(host="example.com")
assert fs.fs_args.get("client_keys") == ["file_1", "file_2"]


@pytest.mark.parametrize(
"config,expected_gss_auth",
[
Expand All @@ -193,11 +100,6 @@ def test_ssh_multi_identity_files(mock_file):
({"host": "not_in_ssh_config.com"}, False),
],
)
@patch(
"builtins.open",
new_callable=mock_open,
read_data=mock_ssh_config,
)
def test_ssh_gss_auth(mock_file, config, expected_gss_auth):
def test_ssh_gss_auth(config, expected_gss_auth):
fs = SSHFileSystem(**config)
assert fs.fs_args["gss_auth"] == expected_gss_auth

0 comments on commit e452803

Please sign in to comment.