diff --git a/sshfs/pools/base.py b/sshfs/pools/base.py index 698f2e7..89ced8c 100644 --- a/sshfs/pools/base.py +++ b/sshfs/pools/base.py @@ -1,5 +1,6 @@ import asyncio from contextlib import AsyncExitStack, suppress +from typing import Optional from asyncssh.misc import ChannelOpenError @@ -15,9 +16,10 @@ def __init__( self, client, *, - max_channels=None, - timeout=MAX_TIMEOUT, - unsafe_terminate=True, + max_channels: Optional[int] = None, + timeout: int = MAX_TIMEOUT, + unsafe_terminate: bool = True, + sftp_client_kwargs: Optional[dict] = None, **kwargs, ): self.client = client @@ -38,6 +40,8 @@ def __init__( self.unsafe_terminate = unsafe_terminate self._stack = AsyncExitStack() + self.sftp_client_kwargs = sftp_client_kwargs or {} + async def _maybe_new_channel(self): # If there is no hard limit or the limit is not hit yet # try to create a new channel @@ -47,7 +51,7 @@ async def _maybe_new_channel(self): ): try: return await self._stack.enter_async_context( - self.client.start_sftp_client() + self.client.start_sftp_client(**self.sftp_client_kwargs) ) except ChannelOpenError: # If we can't create any more channels, then change diff --git a/sshfs/spec.py b/sshfs/spec.py index 6dba100..d5791a9 100644 --- a/sshfs/spec.py +++ b/sshfs/spec.py @@ -5,6 +5,7 @@ import weakref from contextlib import AsyncExitStack, suppress from datetime import datetime +from typing import Optional import asyncssh from asyncssh.sftp import SFTPOpUnsupported @@ -34,6 +35,7 @@ def __init__( host, *, pool_type=SFTPSoftChannelPool, + sftp_client_kwargs: Optional[dict] = None, **kwargs, ): """ @@ -45,15 +47,20 @@ def __init__( SSH host to connect. **kwargs: Any Any option that will be passed to either the top level - `AsyncFileSystem` or the `asyncssh.connect`. + `AsyncFileSystem` (e.g. timeout) + or the `asyncssh.connect`. pool_type: sshfs.pools.base.BaseSFTPChannelPool Pool manager to use (when doing concurrent operations together, pool managers offer the flexibility of prioritizing channels and deciding which to use). + sftp_client_kwargs: Optional[dict] + Parameters to pass to asyncssh.SSHClientConnection.start_sftp_client method + (e.g. env, send_env, path_encoding, path_errors, sftp_version). """ super().__init__(self, **kwargs) + _timeout = kwargs.pop("timeout", None) max_sessions = kwargs.pop("max_sessions", _DEFAULT_MAX_SESSIONS) if max_sessions <= _SHELL_CHANNELS: raise ValueError( @@ -61,6 +68,7 @@ def __init__( ) _client_args = kwargs.copy() _client_args.setdefault("known_hosts", None) + sftp_client_kwargs = sftp_client_kwargs or {} self._stack = AsyncExitStack() self.active_executors = 0 @@ -68,7 +76,9 @@ def __init__( host, pool_type, max_sftp_channels=max_sessions - _SHELL_CHANNELS, - **_client_args, + timeout=_timeout, # goes to sync_wrapper + connect_args=_client_args, # for asyncssh.connect + sftp_client_kwargs=sftp_client_kwargs, # for asyncssh.SSHClientConnection.start_sftp_client ) weakref.finalize( self, sync, self.loop, self._finalize, self._pool, self._stack @@ -89,13 +99,22 @@ def _get_kwargs_from_urls(urlpath): @wrap_exceptions async def _connect( - self, host, pool_type, max_sftp_channels, **client_args + self, + host, + pool_type, + max_sftp_channels, + connect_args, + sftp_client_kwargs, ): self._client_lock = asyncio.Semaphore(_SHELL_CHANNELS) - _raw_client = asyncssh.connect(host, **client_args) + _raw_client = asyncssh.connect(host, **connect_args) client = await self._stack.enter_async_context(_raw_client) - pool = pool_type(client, max_channels=max_sftp_channels) + pool = pool_type( + client, + max_channels=max_sftp_channels, + sftp_client_kwargs=sftp_client_kwargs, + ) return client, pool connect = sync_wrapper(_connect)