From 42ccb60791ec863d0f50a267895272a3a22541d5 Mon Sep 17 00:00:00 2001 From: Michiel De Smet Date: Thu, 18 Aug 2022 11:50:09 +0200 Subject: [PATCH] Support custom cache for OAuth2 tokens --- README.md | 32 +++++++++++++++++++++++++ tests/unit/test_dbapi.py | 51 ++++++++++++++++++++++++++++++++++++++-- trino/auth.py | 26 +++++++++++++------- 3 files changed, 99 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index d45c6554..c0940d08 100644 --- a/README.md +++ b/README.md @@ -226,6 +226,38 @@ The OAuth2 token will be cached either per `trino.auth.OAuth2Authentication` ins ) ``` +A custom caching implementation can be provided by creating a class implementing the `trino.auth.OAuth2TokenCache` abstract class and passing an instance of it to `OAuth2Authentication` through the `cache` argument, e.g. `OAuth2Authentication(cache=my_custom_cache_impl)`. The custom caching implementation enables usage in multi-user environments (notebooks, web applications) in combination with a custom `redirect_auth_url_handler` as explained above. + +```python +from typing import Optional + +from trino.auth import OAuth2Authentication, OAuth2TokenCache +from trino.dbapi import connect + + +class MyCustomCacheImpl(OAuth2TokenCache): + def get_token_from_cache(self, host: str) -> Optional[str]: + # Retrieve your cached token from a distributed system + # and return it + pass + + def store_token_to_cache(self, host: str, token: str) -> None: + # Store your cached token in a distributed system + pass + + +def my_custom_redirect_handler(url: str) -> None: + # Ensure the url is opened by the user that should perform the authentication + pass + +conn = connect( + user="<username>", + auth=OAuth2Authentication(cache=MyCustomCacheImpl(), redirect_auth_url_handler=my_custom_redirect_handler), + http_scheme="https", + ... +) +``` + ### Certificate authentication `CertificateAuthentication` class can be used to connect to Trino cluster configured with [certificate based authentication](https://trino.io/docs/current/security/certificate.html). 
`CertificateAuthentication` requires paths to a valid client certificate and private key. diff --git a/tests/unit/test_dbapi.py b/tests/unit/test_dbapi.py index 7b1c72c2..a8964d75 100644 --- a/tests/unit/test_dbapi.py +++ b/tests/unit/test_dbapi.py @@ -11,7 +11,7 @@ # limitations under the License. import threading import uuid -from unittest.mock import patch +from unittest.mock import patch, MagicMock import httpretty from httpretty import httprettified @@ -20,7 +20,7 @@ from tests.unit.oauth_test_utils import _post_statement_requests, _get_token_requests, RedirectHandler, \ GetTokenCallback, REDIRECT_RESOURCE, TOKEN_RESOURCE, PostStatementCallback, SERVER_ADDRESS from trino import constants -from trino.auth import OAuth2Authentication +from trino.auth import OAuth2Authentication, OAuth2TokenCache from trino.dbapi import connect @@ -107,6 +107,53 @@ def test_token_retrieved_once_per_auth_instance(sample_post_response_data): assert len(_get_token_requests(challenge_id)) == 2 +@httprettified +def test_custom_token_cache_is_invoked(sample_post_response_data): + host = "coordinator" + token = str(uuid.uuid4()) + challenge_id = str(uuid.uuid4()) + + redirect_server = f"{REDIRECT_RESOURCE}/{challenge_id}" + token_server = f"{TOKEN_RESOURCE}/{challenge_id}" + + post_statement_callback = PostStatementCallback(redirect_server, token_server, [token], sample_post_response_data) + + # bind post statement + httpretty.register_uri( + method=httpretty.POST, + uri=f"{SERVER_ADDRESS}:8080{constants.URL_STATEMENT_PATH}", + body=post_statement_callback) + + # bind get token + get_token_callback = GetTokenCallback(token_server, token) + httpretty.register_uri( + method=httpretty.GET, + uri=token_server, + body=get_token_callback) + + redirect_handler = RedirectHandler() + + custom_cache = MagicMock(OAuth2TokenCache) + custom_cache.get_token_from_cache = MagicMock(side_effect=[None, token, token, token]) + custom_cache.store_token_to_cache = MagicMock() + + with connect( + host, + 
user="test", + auth=OAuth2Authentication(redirect_auth_url_handler=redirect_handler, cache=custom_cache), + http_scheme=constants.HTTPS + ) as conn: + conn.cursor().execute("SELECT 1") + conn.cursor().execute("SELECT 2") + conn.cursor().execute("SELECT 3") + + assert len(_get_token_requests(challenge_id)) == 1 + custom_cache.get_token_from_cache.assert_called_with(host) + assert custom_cache.get_token_from_cache.call_count == 4 + custom_cache.store_token_to_cache.assert_called_with(host, token) + assert custom_cache.store_token_to_cache.call_count == 1 + + @httprettified def test_token_retrieved_once_when_authentication_instance_is_shared(sample_post_response_data): token = str(uuid.uuid4()) diff --git a/trino/auth.py b/trino/auth.py index e6b4f04c..b1bf3e50 100644 --- a/trino/auth.py +++ b/trino/auth.py @@ -202,7 +202,7 @@ def __call__(self, url: str): handler(url) -class _OAuth2TokenCache(metaclass=abc.ABCMeta): +class OAuth2TokenCache(metaclass=abc.ABCMeta): """ Abstract class for OAuth token cache, inherit from this class to implement your own token cache. """ @@ -216,7 +216,7 @@ def store_token_to_cache(self, host: str, token: str) -> None: pass -class _OAuth2TokenInMemoryCache(_OAuth2TokenCache): +class _OAuth2TokenInMemoryCache(OAuth2TokenCache): """ In-memory token cache implementation. The token is stored per host, so multiple clients can share the same cache. 
""" @@ -231,7 +231,7 @@ def store_token_to_cache(self, host: str, token: str) -> None: self._cache[host] = token -class _OAuth2KeyRingTokenCache(_OAuth2TokenCache): +class _OAuth2KeyRingTokenCache(OAuth2TokenCache): """ Keyring Token Cache implementation """ @@ -272,10 +272,9 @@ class _OAuth2TokenBearer(AuthBase): MAX_OAUTH_ATTEMPTS = 5 _BEARER_PREFIX = re.compile(r"bearer", flags=re.IGNORECASE) - def __init__(self, redirect_auth_url_handler: Callable[[str], None]): + def __init__(self, redirect_auth_url_handler: Callable[[str], None], custom_cache: Optional[OAuth2TokenCache]): self._redirect_auth_url = redirect_auth_url_handler - keyring_cache = _OAuth2KeyRingTokenCache() - self._token_cache = keyring_cache if keyring_cache.is_keyring_available() else _OAuth2TokenInMemoryCache() + self._token_cache = self._setup_cache(custom_cache) self._token_lock = threading.Lock() self._inside_oauth_attempt_lock = threading.Lock() self._inside_oauth_attempt_blocker = threading.Event() @@ -291,6 +290,17 @@ def __call__(self, r): return r + def _setup_cache(self, custom_cache): + if custom_cache is not None: + if not isinstance(custom_cache, OAuth2TokenCache): + raise exceptions.TrinoAuthError("Custom cache does not implement `trino.auth.OAuth2TokenCache` " + "interface") + return custom_cache + keyring_cache = _OAuth2KeyRingTokenCache() + if keyring_cache.is_keyring_available(): + return keyring_cache + return _OAuth2TokenInMemoryCache() + def _authenticate(self, response, **kwargs): if not 400 <= response.status_code < 500: return response @@ -396,9 +406,9 @@ class OAuth2Authentication(Authentication): def __init__(self, redirect_auth_url_handler=CompositeRedirectHandler([ WebBrowserRedirectHandler(), ConsoleRedirectHandler() - ])): + ]), cache: Optional[OAuth2TokenCache] = None): self._redirect_auth_url = redirect_auth_url_handler - self._bearer = _OAuth2TokenBearer(self._redirect_auth_url) + self._bearer = _OAuth2TokenBearer(self._redirect_auth_url, custom_cache=cache) def 
set_http_session(self, http_session): http_session.auth = self._bearer