Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allow passing kwargs to sftp client #41

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions sshfs/pools/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
from contextlib import AsyncExitStack, suppress
from typing import Optional

from asyncssh.misc import ChannelOpenError

Expand All @@ -15,9 +16,10 @@ def __init__(
self,
client,
*,
max_channels=None,
timeout=MAX_TIMEOUT,
unsafe_terminate=True,
max_channels: Optional[int] = None,
timeout: int = MAX_TIMEOUT,
unsafe_terminate: bool = True,
sftp_client_kwargs: Optional[dict] = None,
**kwargs,
):
self.client = client
Expand All @@ -38,6 +40,8 @@ def __init__(
self.unsafe_terminate = unsafe_terminate
self._stack = AsyncExitStack()

self.sftp_client_kwargs = sftp_client_kwargs or {}

async def _maybe_new_channel(self):
# If there is no hard limit or the limit is not hit yet
# try to create a new channel
Expand All @@ -47,7 +51,7 @@ async def _maybe_new_channel(self):
):
try:
return await self._stack.enter_async_context(
self.client.start_sftp_client()
self.client.start_sftp_client(**self.sftp_client_kwargs)
)
except ChannelOpenError:
# If we can't create any more channels, then change
Expand Down
29 changes: 24 additions & 5 deletions sshfs/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import weakref
from contextlib import AsyncExitStack, suppress
from datetime import datetime
from typing import Optional

import asyncssh
from asyncssh.sftp import SFTPOpUnsupported
Expand Down Expand Up @@ -34,6 +35,7 @@ def __init__(
host,
*,
pool_type=SFTPSoftChannelPool,
sftp_client_kwargs: Optional[dict] = None,
**kwargs,
):
"""
Expand All @@ -45,30 +47,38 @@ def __init__(
SSH host to connect.
**kwargs: Any
Any option that will be passed to either the top level
`AsyncFileSystem` or the `asyncssh.connect`.
`AsyncFileSystem` (e.g. timeout)
or the `asyncssh.connect`.
pool_type: sshfs.pools.base.BaseSFTPChannelPool
Pool manager to use (when doing concurrent operations together,
pool managers offer the flexibility of prioritizing channels
and deciding which to use).
sftp_client_kwargs: Optional[dict]
Parameters to pass to asyncssh.SSHClientConnection.start_sftp_client method
(e.g. env, send_env, path_encoding, path_errors, sftp_version).
"""

super().__init__(self, **kwargs)

_timeout = kwargs.pop("timeout", None)
max_sessions = kwargs.pop("max_sessions", _DEFAULT_MAX_SESSIONS)
if max_sessions <= _SHELL_CHANNELS:
raise ValueError(
f"max_sessions must be greater than {_SHELL_CHANNELS}"
)
_client_args = kwargs.copy()
_client_args.setdefault("known_hosts", None)
sftp_client_kwargs = sftp_client_kwargs or {}

self._stack = AsyncExitStack()
self.active_executors = 0
self._client, self._pool = self.connect(
host,
pool_type,
max_sftp_channels=max_sessions - _SHELL_CHANNELS,
**_client_args,
timeout=_timeout, # goes to sync_wrapper
connect_args=_client_args, # for asyncssh.connect
sftp_client_kwargs=sftp_client_kwargs, # for asyncssh.SSHClientConnection.start_sftp_client
)
weakref.finalize(
self, sync, self.loop, self._finalize, self._pool, self._stack
Expand All @@ -89,13 +99,22 @@ def _get_kwargs_from_urls(urlpath):

@wrap_exceptions
async def _connect(
self, host, pool_type, max_sftp_channels, **client_args
self,
host,
pool_type,
max_sftp_channels,
connect_args,
sftp_client_kwargs,
):
self._client_lock = asyncio.Semaphore(_SHELL_CHANNELS)

_raw_client = asyncssh.connect(host, **client_args)
_raw_client = asyncssh.connect(host, **connect_args)
client = await self._stack.enter_async_context(_raw_client)
pool = pool_type(client, max_channels=max_sftp_channels)
pool = pool_type(
client,
max_channels=max_sftp_channels,
sftp_client_kwargs=sftp_client_kwargs,
)
return client, pool

connect = sync_wrapper(_connect)
Expand Down
Loading