Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support custom cache for OAuth2 tokens #225

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,38 @@ The OAuth2 token will be cached either per `trino.auth.OAuth2Authentication` ins
)
```

A custom caching implementation can be provided by creating a class that implements the `trino.auth.OAuth2TokenCache` abstract class and passing an instance of it to the constructor, as in `OAuth2Authentication(cache=my_custom_cache_impl)`. A custom cache makes it possible to use OAuth2 authentication in multi-user environments (notebooks, web applications) when combined with a custom `redirect_auth_url_handler`, as explained above.

```python
from typing import Optional

from trino.auth import OAuth2Authentication, OAuth2TokenCache
from trino.dbapi import connect


class MyCustomCacheImpl(OAuth2TokenCache):
    """Example token cache; replace the placeholder bodies with your own storage logic."""

    def get_token_from_cache(self, host: str) -> Optional[str]:
        """Retrieve your cached token from a distributed system and return it."""
        return None

    def store_token_to_cache(self, host: str, token: str) -> None:
        """Store your cached token in a distributed system."""
        return None


def my_custom_redirect_handler(url: str) -> None:
    """Placeholder: ensure the url is opened by the user that should perform the authentication."""
    return None

conn = connect(
user="<username>",
auth=OAuth2Authentication(cache=MyCustomCacheImpl(), redirect_auth_url_handler=my_custom_redirect_handler),
http_scheme="https",
...
)
```

### Certificate authentication

`CertificateAuthentication` class can be used to connect to Trino cluster configured with [certificate based authentication](https://trino.io/docs/current/security/certificate.html). `CertificateAuthentication` requires paths to a valid client certificate and private key.
Expand Down
51 changes: 49 additions & 2 deletions tests/unit/test_dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# limitations under the License.
import threading
import uuid
from unittest.mock import patch
from unittest.mock import patch, MagicMock

import httpretty
from httpretty import httprettified
Expand All @@ -20,7 +20,7 @@
from tests.unit.oauth_test_utils import _post_statement_requests, _get_token_requests, RedirectHandler, \
GetTokenCallback, REDIRECT_RESOURCE, TOKEN_RESOURCE, PostStatementCallback, SERVER_ADDRESS
from trino import constants
from trino.auth import OAuth2Authentication
from trino.auth import OAuth2Authentication, OAuth2TokenCache
from trino.dbapi import connect


Expand Down Expand Up @@ -107,6 +107,53 @@ def test_token_retrieved_once_per_auth_instance(sample_post_response_data):
assert len(_get_token_requests(challenge_id)) == 2


@httprettified
def test_custom_token_cache_is_invoked(sample_post_response_data):
    """
    Run three statements through a connection whose OAuth2 authentication uses a mocked
    ``OAuth2TokenCache`` and check how often the cache and the token endpoint are hit.
    """
    host = "coordinator"
    token = str(uuid.uuid4())
    challenge_id = str(uuid.uuid4())

    redirect_server = f"{REDIRECT_RESOURCE}/{challenge_id}"
    token_server = f"{TOKEN_RESOURCE}/{challenge_id}"

    # bind post statement
    statement_callback = PostStatementCallback(redirect_server, token_server, [token], sample_post_response_data)
    httpretty.register_uri(
        method=httpretty.POST,
        uri=f"{SERVER_ADDRESS}:8080{constants.URL_STATEMENT_PATH}",
        body=statement_callback,
    )

    # bind get token
    token_callback = GetTokenCallback(token_server, token)
    httpretty.register_uri(
        method=httpretty.GET,
        uri=token_server,
        body=token_callback,
    )

    # The first lookup misses (None); the three following lookups return the token.
    lookup = MagicMock(side_effect=[None, token, token, token])
    store = MagicMock()
    cache_mock = MagicMock(spec=OAuth2TokenCache)
    cache_mock.get_token_from_cache = lookup
    cache_mock.store_token_to_cache = store

    authentication = OAuth2Authentication(redirect_auth_url_handler=RedirectHandler(), cache=cache_mock)
    with connect(host, user="test", auth=authentication, http_scheme=constants.HTTPS) as conn:
        for statement in ("SELECT 1", "SELECT 2", "SELECT 3"):
            conn.cursor().execute(statement)

    # Exactly one request reached the token endpoint.
    assert len(_get_token_requests(challenge_id)) == 1
    # The cache was read four times (last call keyed by host) and written exactly once.
    lookup.assert_called_with(host)
    assert lookup.call_count == 4
    store.assert_called_once_with(host, token)


@httprettified
def test_token_retrieved_once_when_authentication_instance_is_shared(sample_post_response_data):
token = str(uuid.uuid4())
Expand Down
26 changes: 18 additions & 8 deletions trino/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def __call__(self, url: str):
handler(url)


class _OAuth2TokenCache(metaclass=abc.ABCMeta):
class OAuth2TokenCache(metaclass=abc.ABCMeta):
"""
Abstract class for OAuth token cache, inherit from this class to implement your own token cache.
"""
Expand All @@ -216,7 +216,7 @@ def store_token_to_cache(self, host: str, token: str) -> None:
pass


class _OAuth2TokenInMemoryCache(_OAuth2TokenCache):
class _OAuth2TokenInMemoryCache(OAuth2TokenCache):
"""
In-memory token cache implementation. The token is stored per host, so multiple clients can share the same cache.
"""
Expand All @@ -231,7 +231,7 @@ def store_token_to_cache(self, host: str, token: str) -> None:
self._cache[host] = token


class _OAuth2KeyRingTokenCache(_OAuth2TokenCache):
class _OAuth2KeyRingTokenCache(OAuth2TokenCache):
"""
Keyring Token Cache implementation
"""
Expand Down Expand Up @@ -272,10 +272,9 @@ class _OAuth2TokenBearer(AuthBase):
MAX_OAUTH_ATTEMPTS = 5
_BEARER_PREFIX = re.compile(r"bearer", flags=re.IGNORECASE)

def __init__(self, redirect_auth_url_handler: Callable[[str], None]):
def __init__(self, redirect_auth_url_handler: Callable[[str], None], custom_cache: Optional[OAuth2TokenCache]):
self._redirect_auth_url = redirect_auth_url_handler
keyring_cache = _OAuth2KeyRingTokenCache()
self._token_cache = keyring_cache if keyring_cache.is_keyring_available() else _OAuth2TokenInMemoryCache()
self._token_cache = self._setup_cache(custom_cache)
self._token_lock = threading.Lock()
self._inside_oauth_attempt_lock = threading.Lock()
self._inside_oauth_attempt_blocker = threading.Event()
Expand All @@ -291,6 +290,17 @@ def __call__(self, r):

return r

def _setup_cache(self, custom_cache):
    """
    Pick the token cache to use.

    :param custom_cache: user supplied cache, or ``None`` to use a built-in one.
    :return: ``custom_cache`` when given; otherwise the keyring cache if it reports
        itself available, else a new in-memory cache.
    :raises exceptions.TrinoAuthError: if ``custom_cache`` is not an ``OAuth2TokenCache`` instance.
    """
    if custom_cache is None:
        # No user supplied cache: prefer the keyring, fall back to memory.
        keyring = _OAuth2KeyRingTokenCache()
        return keyring if keyring.is_keyring_available() else _OAuth2TokenInMemoryCache()

    if isinstance(custom_cache, OAuth2TokenCache):
        return custom_cache

    raise exceptions.TrinoAuthError(
        "Custom cache does not implement `trino.auth.OAuth2TokenCache` interface"
    )

def _authenticate(self, response, **kwargs):
if not 400 <= response.status_code < 500:
return response
Expand Down Expand Up @@ -396,9 +406,9 @@ class OAuth2Authentication(Authentication):
def __init__(self, redirect_auth_url_handler=CompositeRedirectHandler([
    WebBrowserRedirectHandler(),
    ConsoleRedirectHandler()
]), cache: Optional[OAuth2TokenCache] = None):
    """
    :param redirect_auth_url_handler: callable invoked with the authentication URL; defaults to
        a composite of the web browser and console redirect handlers.
    :param cache: optional token cache, forwarded to the token bearer as ``custom_cache``;
        ``None`` leaves the choice of cache to the bearer.
    """
    self._redirect_auth_url = redirect_auth_url_handler
    # Build the bearer once, handing it the (possibly None) user supplied cache.
    self._bearer = _OAuth2TokenBearer(self._redirect_auth_url, custom_cache=cache)

def set_http_session(self, http_session):
http_session.auth = self._bearer
Expand Down