diff --git a/ipfshttpclient/http_requests.py b/ipfshttpclient/http_requests.py
index f7f10b1a..a41dfb49 100644
--- a/ipfshttpclient/http_requests.py
+++ b/ipfshttpclient/http_requests.py
@@ -5,7 +5,7 @@
 
 import math
 import http.client
-import os
+import requests
 import typing as ty
 import urllib.parse
 
@@ -19,13 +19,8 @@
 	addr_t, auth_t, cookies_t, headers_t, params_t, reqdata_sync_t, timeout_t,
 	Closable,
 )
-
-PATCH_REQUESTS = (os.environ.get("PY_IPFS_HTTP_CLIENT_PATCH_REQUESTS", "yes").lower()
-                  not in ("false", "no"))
-if PATCH_REQUESTS:
-	from . import requests_wrapper as requests
-elif not ty.TYPE_CHECKING:  # pragma: no cover (always enabled in production)
-	import requests
+from .requests_wrapper import PATCH_REQUESTS
+from .requests_wrapper import wrapped_session
 
 
 def map_args_to_requests(
@@ -70,7 +65,7 @@
 	return kwargs
 
 
-class ClientSync(ClientSyncBase[requests.Session]):  # type: ignore[name-defined]
+class ClientSync(ClientSyncBase[requests.Session]):
 	__slots__ = ("_base_url", "_default_timeout", "_request_proxies", "_session_props")
 	#_base_url: str
 	#_default_timeout: timeout_t
@@ -92,7 +87,8 @@ def _init(self, addr: addr_t, base: str, *,  # type: ignore[no-any-unimported]
 			params=params,
 		)
 		self._default_timeout = timeout
-		if PATCH_REQUESTS:  # pragma: no branch (always enabled in production)
+		
+		if PATCH_REQUESTS:
 			self._session_props["family"] = family
 			
 			# Ensure that no proxy lookups are done for the UDS pseudo-hostname
@@ -106,8 +102,8 @@ def _init(self, addr: addr_t, base: str, *,  # type: ignore[no-any-unimported]
 				"no_proxy": urllib.parse.quote(uds_path, safe=""),
 			}
 	
-	def _make_session(self) -> requests.Session:  # type: ignore[name-defined]
-		session = requests.Session()  # type: ignore[attr-defined]
+	def _make_session(self) -> requests.Session:
+		session = wrapped_session()
 		try:
 			for name, value in self._session_props.items():
 				setattr(session, name, value)
@@ -118,10 +114,10 @@ def _init(self, addr: addr_t, base: str, *,  # type: ignore[no-any-unimported]
 			session.close()
 			raise
 	
-	def _do_raise_for_status(self, response: requests.Request) -> None:  # type: ignore[name-defined]
+	def _do_raise_for_status(self, response: requests.Response) -> None:
 		try:
 			response.raise_for_status()
-		except requests.exceptions.HTTPError as error:  # type: ignore[attr-defined]
+		except requests.exceptions.HTTPError as error:
 			content = []
 			try:
 				decoder = encoding.get_encoding("json")
@@ -155,7 +151,7 @@
 			path = path[1:]
 		
 		url = urllib.parse.urljoin(self._base_url, path)
-		
+
 		try:
 			# Determine session object to use
 			closables, session = self._access_session()
@@ -172,13 +168,15 @@
 					timeout=(timeout if timeout is not None else self._default_timeout),
 				),
 				proxies=self._request_proxies,
-				data=data,
+				# requests.Session.request does not accept Optional[Iterator[bytes]],
+				# but we seem to be passing that here.
+				data=data,  # type: ignore[arg-type]
 				stream=True,
 			)
 			closables.insert(0, res)
 		except (requests.ConnectTimeout, requests.Timeout) as error:  # type: ignore[attr-defined]
 			raise exceptions.TimeoutError(error) from error
-		except requests.ConnectionError as error:  # type: ignore[attr-defined]
+		except requests.ConnectionError as error:
 			# Report protocol violations separately
 			#
 			# This used to happen because requests wouldn't catch
@@ -196,8 +194,8 @@
 			# (optionally incorporating the response message, if available)
 			self._do_raise_for_status(res)
 			
-			return closables, res.iter_content(chunk_size=chunk_size)
+			return closables, (portion for portion in res.iter_content(chunk_size=chunk_size))
 		except:
 			for closable in closables:
 				closable.close()
-			raise
\ No newline at end of file
+			raise
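
Illustration only, not part of the patch: the PY_IPFS_HTTP_CLIENT_PATCH_REQUESTS toggle that http_requests.py now imports as PATCH_REQUESTS from requests_wrapper keeps the patched session enabled unless the variable is explicitly set to "false" or "no" (case-insensitive). A minimal standalone sketch of that rule, using the hypothetical helper name patch_requests_enabled:

import os
import typing as ty

def patch_requests_enabled(env: ty.Mapping[str, str] = os.environ) -> bool:
	# Same expression as PATCH_REQUESTS below: only an explicit "false"/"no" disables the patch.
	return env.get("PY_IPFS_HTTP_CLIENT_PATCH_REQUESTS", "yes").lower() not in ("false", "no")

assert patch_requests_enabled({}) is True
assert patch_requests_enabled({"PY_IPFS_HTTP_CLIENT_PATCH_REQUESTS": "No"}) is False
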
diff --git a/ipfshttpclient/requests_wrapper.py b/ipfshttpclient/requests_wrapper.py
index 90213d1a..024b5feb 100644
--- a/ipfshttpclient/requests_wrapper.py
+++ b/ipfshttpclient/requests_wrapper.py
@@ -1,18 +1,20 @@
-# type: ignore
 """Exposes the full ``requests`` HTTP library API, while adding an extra
 ``family`` parameter to all HTTP request operations that may be used to restrict
 the address family used when resolving a domain-name to an IP address.
 """
+
+import os
 import socket
 import urllib.parse
 
 import requests
 import requests.adapters
-import urllib3
-import urllib3.connection
-import urllib3.exceptions
-import urllib3.poolmanager
-import urllib3.util.connection
+import typing as ty
+import urllib3  # type: ignore[import]
+import urllib3.connection  # type: ignore[import]
+import urllib3.exceptions  # type: ignore[import]
+import urllib3.poolmanager  # type: ignore[import]
+import urllib3.util.connection  # type: ignore[import]
 
 AF2NAME = {
 	int(socket.AF_INET): "ip4",
@@ -22,6 +24,11 @@
 	AF2NAME[int(socket.AF_UNIX)] = "unix"
 NAME2AF = {name: af for af, name in AF2NAME.items()}
 
+PATCH_REQUESTS = (
+	os.environ.get("PY_IPFS_HTTP_CLIENT_PATCH_REQUESTS", "yes").lower()
+	not in ("false", "no")
+)
+
 
 # This function is copied from urllib3/util/connection.py (that in turn copied
 # it from socket.py in the Python 2.7 standard library test suite) and accepts
@@ -31,9 +38,13 @@
 # The entire remainder of this file after this only exists to ensure that this
 # `family` parameter is exposed all the way up to request's `Session` interface,
 # storing it as part of the URL scheme while traversing most of the layers.
-def create_connection(address, timeout=socket._GLOBAL_DEFAULT_TIMEOUT,
-                      source_address=None, socket_options=None,
-                      family=socket.AF_UNSPEC):
+def create_connection(
+		address: ty.Tuple[str, int],
+		timeout: int = socket._GLOBAL_DEFAULT_TIMEOUT,  # type: ignore[attr-defined]
+		source_address: ty.Optional[ty.Union[ty.Tuple[str, int], str, bytes]] = None,
+		socket_options: ty.Optional[ty.List[ty.Tuple[int, int, ty.Union[int, bytes]]]] = None,
+		family: int = socket.AF_UNSPEC
+) -> socket.socket:
 	host, port = address
 	if host.startswith('['):
 		host = host.strip('[]')
@@ -44,7 +55,29 @@ def create_connection(address, timeout=socket._GLOBAL_DEFAULT_TIMEOUT,
 
 	# Extension for Unix domain sockets
 	if hasattr(socket, "AF_UNIX") and family == socket.AF_UNIX:
-		gai_result = [(socket.AF_UNIX, socket.SOCK_STREAM, 0, "", host)]
+		gai_result: ty.Union[
+			ty.List[
+				ty.Tuple[
+					socket.AddressFamily,
+					socket.SocketKind,
+					int,
+					str,
+					str
+				]
+			],
+			ty.List[
+				ty.Tuple[
+					socket.AddressFamily,
+					socket.SocketKind,
+					int,
+					str,
+					ty.Union[
+						ty.Tuple[str, int],
+						ty.Tuple[str, int, int, int]
+					]
+				]
+			]
+		] = [(socket.AF_UNIX, socket.SOCK_STREAM, 0, "", host)]
 	else:
 		gai_result = socket.getaddrinfo(host, port, family, socket.SOCK_STREAM)
 
@@ -59,7 +92,7 @@ def create_connection(address, timeout=socket._GLOBAL_DEFAULT_TIMEOUT,
 				for opt in socket_options:
 					sock.setsockopt(*opt)
 
-			if timeout is not socket._GLOBAL_DEFAULT_TIMEOUT:
+			if timeout is not socket._GLOBAL_DEFAULT_TIMEOUT:  # type: ignore[attr-defined]
 				sock.settimeout(timeout)
 			if source_address:
 				sock.bind(source_address)
@@ -69,7 +102,6 @@ def create_connection(address, timeout=socket._GLOBAL_DEFAULT_TIMEOUT,
 			err = e
 			if sock is not None:
 				sock.close()
-				sock = None
 
 	if err is not None:
 		raise err
@@ -79,8 +111,11 @@ def create_connection(address, timeout=socket._GLOBAL_DEFAULT_TIMEOUT,
 
 # Override the `urllib3` low-level Connection objects that do the actual work
 # of speaking HTTP
-def _kw_scheme_to_family(kw, base_scheme):
-	family = socket.AF_UNSPEC
+def _kw_scheme_to_family(
+		kw: ty.Dict[str, ty.Any],
+		base_scheme: str
+) -> ty.Union[int, socket.AddressFamily]:
+	family: ty.Union[int, socket.AddressFamily] = socket.AF_UNSPEC
 	scheme = kw.pop("scheme", None)
 	if isinstance(scheme, str):
 		parts = scheme.rsplit("+", 1)
@@ -90,26 +125,26 @@
 
 
 class ConnectionOverrideMixin:
-	def _new_conn(self):
+	def _new_conn(self) -> socket.socket:
 		extra_kw = {
-			"family": self.family
+			"family": self.family  # type: ignore[attr-defined]
 		}
-		if self.source_address:
-			extra_kw['source_address'] = self.source_address
+		if self.source_address:  # type: ignore[attr-defined]
+			extra_kw['source_address'] = self.source_address  # type: ignore[attr-defined]
 
-		if self.socket_options:
-			extra_kw['socket_options'] = self.socket_options
+		if self.socket_options:  # type: ignore[attr-defined]
+			extra_kw['socket_options'] = self.socket_options  # type: ignore[attr-defined]
 
 		try:
-			dns_host = getattr(self, "_dns_host", self.host)
+			dns_host = getattr(self, "_dns_host", self.host)  # type: ignore[attr-defined]
 			if hasattr(socket, "AF_UNIX") and extra_kw["family"] == socket.AF_UNIX:
 				dns_host = urllib.parse.unquote(dns_host)
 			conn = create_connection(
-				(dns_host, self.port), self.timeout, **extra_kw)
+				(dns_host, self.port), self.timeout, **extra_kw)  # type: ignore[attr-defined]
 		except socket.timeout:
 			raise urllib3.exceptions.ConnectTimeoutError(
 				self, "Connection to %s timed out. (connect timeout=%s)" %
-				(self.host, self.timeout))
+				(self.host, self.timeout))  # type: ignore[attr-defined]
 		except OSError as e:
 			raise urllib3.exceptions.NewConnectionError(
 				self, "Failed to establish a new connection: %s" % e)
@@ -117,39 +152,45 @@ def _new_conn(self):
 		return conn
 
 
-class HTTPConnection(ConnectionOverrideMixin, urllib3.connection.HTTPConnection):
-	def __init__(self, *args, **kw):
+class HTTPConnection(
+	ConnectionOverrideMixin,
+	urllib3.connection.HTTPConnection  # type: ignore[misc,no-any-unimported]
+):
+	def __init__(self, *args, **kw) -> None:  # type: ignore[no-untyped-def]
 		self.family = _kw_scheme_to_family(kw, "http")
 		super().__init__(*args, **kw)
 
 
-class HTTPSConnection(ConnectionOverrideMixin, urllib3.connection.HTTPSConnection):
-	def __init__(self, *args, **kw):
+class HTTPSConnection(
+	ConnectionOverrideMixin,
+	urllib3.connection.HTTPSConnection  # type: ignore[misc,no-any-unimported]
+):
+	def __init__(self, *args, **kw) -> None:  # type: ignore[no-untyped-def]
 		self.family = _kw_scheme_to_family(kw, "https")
 		super().__init__(*args, **kw)
 
 
 # Override the higher-level `urllib3` ConnectionPool objects that instantiate
 # one or more Connection objects and dispatch work between them
-class HTTPConnectionPool(urllib3.HTTPConnectionPool):
+class HTTPConnectionPool(urllib3.HTTPConnectionPool):  # type: ignore[misc,no-any-unimported]
 	ConnectionCls = HTTPConnection
 
 
-class HTTPSConnectionPool(urllib3.HTTPSConnectionPool):
+class HTTPSConnectionPool(urllib3.HTTPSConnectionPool):  # type: ignore[misc,no-any-unimported]
 	ConnectionCls = HTTPSConnection
 
 
 # Override the highest-level `urllib3` PoolManager to also properly support the
 # address family extended scheme values in URLs and pass these scheme values on
 # to the individual ConnectionPool objects
-class PoolManager(urllib3.PoolManager):
-	def __init__(self, *args, **kwargs):
+class PoolManager(urllib3.PoolManager):  # type: ignore[misc,no-any-unimported]
+	def __init__(self, *args, **kwargs) -> None:  # type: ignore[no-untyped-def]
 		super().__init__(*args, **kwargs)
 		
 		# Additionally to adding our variant of the usual HTTP and HTTPS
 		# pool classes, also add these for some variants of the default schemes
 		# that are limited to some specific address family only
-		self.pool_classes_by_scheme = {}
+		self.pool_classes_by_scheme: ty.Dict[str, type] = {}
 		for scheme, ConnectionPool in (("http", HTTPConnectionPool), ("https", HTTPSConnectionPool)):
 			self.pool_classes_by_scheme[scheme] = ConnectionPool
 		for name in AF2NAME.values():
@@ -159,8 +200,14 @@ def __init__(self, *args, **kwargs):
 	# These next two are only required to ensure that our custom `scheme` values
 	# will be passed down to the `*ConnectionPool`s and finally to the actual
 	# `*Connection`s as parameter
-	def _new_pool(self, scheme, host, port, request_context=None):
-		# Copied from `urllib3` to *not* surpress the `scheme` parameter
+	def _new_pool(
+			self,
+			scheme: str,
+			host: str,
+			port: int,
+			request_context: ty.Optional[ty.Dict[str, ty.Any]] = None
+	) -> ty.Union[HTTPConnectionPool, HTTPSConnectionPool]:
+		# Copied from `urllib3` to *not* suppress the `scheme` parameter
 		pool_cls = self.pool_classes_by_scheme[scheme]
 		if request_context is None:
 			request_context = self.connection_pool_kw.copy()
@@ -172,16 +219,26 @@
 			for kw in urllib3.poolmanager.SSL_KEYWORDS:
 				request_context.pop(kw, None)
 		
-		return pool_cls(host, port, **request_context)
-	
-	def connection_from_pool_key(self, pool_key, request_context=None):
+		return ty.cast(
+			ty.Union[HTTPConnectionPool, HTTPSConnectionPool],
+			pool_cls(host, port, **request_context)
+		)
+	
+	def connection_from_pool_key(
+			self,
+			pool_key: str,
+			request_context: ty.Optional[ty.Dict[str, ty.Any]] = None
+	) -> ty.Union[HTTPConnectionPool, HTTPSConnectionPool]:
 		# Copied from `urllib3` so that we continue to ensure that this will
 		# call `_new_pool`
 		with self.pools.lock:
-			pool = self.pools.get(pool_key)
+			pool: ty.Union[HTTPConnectionPool, HTTPSConnectionPool] = self.pools.get(pool_key)
 			if pool:
 				return pool
 			
+			if request_context is None:
+				raise ValueError('request_context required')
+			
 			scheme = request_context['scheme']
 			host = request_context['host']
 			port = request_context['port']
@@ -193,7 +250,13 @@
 # Override the lower-level `requests` adapter that invokes the `urllib3`
 # PoolManager objects
 class HTTPAdapter(requests.adapters.HTTPAdapter):
-	def init_poolmanager(self, connections, maxsize, block=False, **pool_kwargs):
+	def init_poolmanager(  # type: ignore[no-untyped-def]
+			self,
+			connections: int,
+			maxsize: int,
+			block: bool = False,
+			**pool_kwargs
+	) -> None:
 		# save these values for pickling (copied from `requests`)
 		self._pool_connections = connections
 		self._pool_maxsize = maxsize
@@ -206,9 +269,9 @@ def init_poolmanager(self, connections, maxsize, block=False, **pool_kwargs):
 # Override the highest-level `requests` Session object to accept the `family`
 # parameter for any request and encode its value as part of the URL scheme
 # when passing it down to the adapter
-class Session(requests.Session):
-	def __init__(self, *args, **kwargs):
-		super().__init__(*args, **kwargs)
+class PatchedSession(requests.Session):
+	def __init__(self) -> None:
+		super().__init__()
 		self.family = socket.AF_UNSPEC
 		
 		# Additionally to mounting our variant of the usual HTTP and HTTPS
@@ -220,63 +283,32 @@ def __init__(self, *args, **kwargs):
 		for name in AF2NAME.values():
 			self.mount("{0}+{1}://".format(scheme, name), adapter)
 	
-	def request(self, method, url, *args, **kwargs):
-		family = kwargs.pop("family", self.family)
-		if family != socket.AF_UNSPEC:
+	@staticmethod
+	def replace_scheme(family: int, url: str) -> str:
+		if family == socket.AF_UNSPEC:
+			return url
+		else:
 			# Inject provided address family value as extension to scheme
-			url = urllib.parse.urlparse(url)
-			url = url._replace(scheme="{0}+{1}".format(url.scheme, AF2NAME[int(family)]))
-			url = url.geturl()
-		return super().request(method, url, *args, **kwargs)
-
-
-session = Session
-
-
-# Import other `requests` stuff to make the top-level API of this more compatible
-from requests import (
-	__title__, __description__, __url__, __version__, __build__, __author__,
-	__author_email__, __license__, __copyright__, __cake__,
-
-	exceptions, utils, packages, codes,
-	Request, Response, PreparedRequest,
-	RequestException, Timeout, URLRequired, TooManyRedirects, HTTPError,
-	ConnectionError, FileModeWarning, ConnectTimeout, ReadTimeout
-)
-
-
-# Re-implement the top-level “session-less” API
-def request(method, url, **kwargs):
-	with Session() as session:
-		return session.request(method=method, url=url, **kwargs)
-
-
-def get(url, params=None, **kwargs):
-	kwargs.setdefault('allow_redirects', True)
-	return request('get', url, params=params, **kwargs)
-
-
-def options(url, **kwargs):
-	kwargs.setdefault('allow_redirects', True)
-	return request('options', url, **kwargs)
-
-
-def head(url, **kwargs):
-	kwargs.setdefault('allow_redirects', False)
-	return request('head', url, **kwargs)
-
-
-def post(url, data=None, json=None, **kwargs):
-	return request('post', url, data=data, json=json, **kwargs)
-
-
-def put(url, data=None, **kwargs):
-	return request('put', url, data=data, **kwargs)
+			parsed = urllib.parse.urlparse(url)
+			parsed = parsed._replace(scheme="{0}+{1}".format(parsed.scheme, AF2NAME[int(family)]))
+			return parsed.geturl()
+	
+	def request(  # type: ignore[override,no-untyped-def]
+			self,
+			method: str,
+			url: str,
+			*args,
+			**kwargs
+	) -> requests.Response:
+		family = kwargs.pop("family", self.family)
+		url = self.replace_scheme(family, url)
 
 
-def patch(url, data=None, **kwargs):
-	return request('patch', url, data=data, **kwargs)
+		return super().request(method, url, *args, **kwargs)
 
 
-def delete(url, **kwargs):
-	return request('delete', url, **kwargs)
+def wrapped_session() -> requests.Session:
+	if PATCH_REQUESTS:
+		return PatchedSession()
+	else:
+		return requests.Session()
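
Illustration only, not part of the patch: the scheme-encoding trick that PatchedSession.replace_scheme() uses to smuggle the address family through requests and urllib3 down to create_connection(). The snippet mirrors the code above; the shortened AF2NAME table and the example URL are supplied here purely for demonstration.

import socket
import urllib.parse

AF2NAME = {int(socket.AF_INET): "ip4", int(socket.AF_INET6): "ip6"}

def replace_scheme(family: int, url: str) -> str:
	# Append the family name to the URL scheme so that it survives until
	# _kw_scheme_to_family() strips it off again at the connection layer.
	if family == socket.AF_UNSPEC:
		return url
	parsed = urllib.parse.urlparse(url)
	parsed = parsed._replace(scheme="{0}+{1}".format(parsed.scheme, AF2NAME[int(family)]))
	return parsed.geturl()

print(replace_scheme(socket.AF_INET, "http://localhost:5001/api/v0/version"))
# http+ip4://localhost:5001/api/v0/version
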
diff --git a/tox.ini b/tox.ini
index 56584c6e..d04a4381 100644
--- a/tox.ini
+++ b/tox.ini
@@ -63,9 +63,11 @@
 skip_install = true
 deps =
 	mypy ~= 0.812
 	pytest ~= 6.2
-	{[testenv:py3-httpx]deps-exclusive}
+	types-requests
+	urllib3 ~= 1.26.4
+	{[testenv:py3-httpx]deps-exclusive}
 commands =
-	mypy --config-file=tox.ini {posargs} -p ipfshttpclient
+	mypy --config-file=tox.ini {posargs} -p ipfshttpclient --install-types --non-interactive
 # Pass down TERM environment variable to allow mypy output to be colorized
 # See: https://github.com/tox-dev/tox/issues/1441
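
Illustration only, not part of the patch: a hypothetical usage sketch of the new wrapped_session() factory. The daemon address and API path below are placeholders, and the per-request family keyword is only understood by the patched session, which is why it is guarded by PATCH_REQUESTS here.

import socket

from ipfshttpclient.requests_wrapper import PATCH_REQUESTS, wrapped_session

session = wrapped_session()
try:
	if PATCH_REQUESTS:
		# PatchedSession accepts an address-family restriction per request.
		response = session.request(
			"POST", "http://127.0.0.1:5001/api/v0/version", family=socket.AF_INET
		)
	else:
		# Plain requests.Session fallback: no family keyword available.
		response = session.request("POST", "http://127.0.0.1:5001/api/v0/version")
	print(response.status_code)
finally:
	session.close()
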