From b113af1434552da118eefeaf7a69b65c4b6dcbe9 Mon Sep 17 00:00:00 2001 From: Daniel McKnight Date: Mon, 4 Nov 2024 18:58:34 -0800 Subject: [PATCH] Update auth tests to address code changes Fix error in token generation logic Update exception handling for proper JWTs --- neon_hana/auth/client_manager.py | 26 +++++-- tests/test_auth.py | 129 ++++++++----------------------- 2 files changed, 52 insertions(+), 103 deletions(-) diff --git a/neon_hana/auth/client_manager.py b/neon_hana/auth/client_manager.py index 9dcf22e..286f431 100644 --- a/neon_hana/auth/client_manager.py +++ b/neon_hana/auth/client_manager.py @@ -33,7 +33,7 @@ from typing import Dict, Optional from fastapi import Request, HTTPException from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials -from jwt import DecodeError +from jwt import DecodeError, ExpiredSignatureError from ovos_utils import LOG from token_throttler import TokenThrottler, TokenBucket from token_throttler.storage import RuntimeStorage @@ -84,9 +84,11 @@ def _create_tokens(self, client_id: str, token_name: Optional[str] = None, permissions: Optional[PermissionsConfig] = None, - **kwargs) -> (HanaToken, HanaToken, TokenConfig): + **kwargs) -> (str, str, TokenConfig): token_id = str(uuid4()) - creation_timestamp = round(time()) + # Subtract a second from creation so the token may be used immediately + # upon return + creation_timestamp = round(time()) - 1 expiration_timestamp = creation_timestamp + self._access_token_lifetime refresh_expiration_timestamp = creation_timestamp + self._refresh_token_lifetime permissions = permissions or PermissionsConfig(core=AccessRoles.GUEST, @@ -111,7 +113,10 @@ def _create_tokens(self, client_id=client_id, roles=permissions.to_roles(), purpose="refresh") - + access_token = jwt.encode(access_token_data.model_dump(), + self._access_secret, self._jwt_algo) + refresh_token = jwt.encode(refresh_token_data.model_dump(), + self._refresh_secret, self._jwt_algo) token_config = TokenConfig(token_name=token_name, token_id=token_id, user_id=user_id, @@ -120,7 +125,7 @@ def _create_tokens(self, refresh_expiration_timestamp=refresh_expiration_timestamp, creation_timestamp=creation_timestamp, last_refresh_timestamp=creation_timestamp) - return access_token_data, refresh_token_data, token_config + return access_token, refresh_token, token_config def check_connect_stream(self) -> bool: """ @@ -220,10 +225,14 @@ def check_refresh_request(self, access_token: str, refresh_token: str, self._jwt_algo)) token_data = HanaToken(**jwt.decode(access_token, self._access_secret, - self._jwt_algo)) + self._jwt_algo, + leeway=self._refresh_token_lifetime)) except DecodeError: raise HTTPException(status_code=400, - detail="Invalid refresh token supplied") + detail="Invalid token supplied") + except ExpiredSignatureError: + raise HTTPException(status_code=401, + detail="Refresh token is expired") if refresh_data.jti != token_data.jti + ".refresh": raise HTTPException(status_code=403, detail="Refresh and access token mismatch") @@ -306,6 +315,9 @@ def validate_auth(self, token: str, origin_ip: str) -> bool: except DecodeError: # Invalid token supplied pass + except ExpiredSignatureError: + # Expired token + pass return False diff --git a/tests/test_auth.py b/tests/test_auth.py index 73b0ae8..655c73d 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -25,7 +25,7 @@ # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import unittest -from time import time +from time import time, sleep from uuid import uuid4 from fastapi import HTTPException @@ -34,7 +34,7 @@ class TestClientManager(unittest.TestCase): from neon_hana.auth.client_manager import ClientManager client_manager = ClientManager({"access_token_secret": "a800445648142061fc238d1f84e96200da87f4f9f784108ac90db8b4391b117b", - "refresh_token_secret": "a800445648142061fc238d1f84e96200da87f4f9f784108ac90db8b4391b117b", + "refresh_token_secret": "a800445648142061fc238d1f84e96200da87f4f9f784108ac90db8b4391ba800", "disable_auth": False}) def test_check_auth_request(self): @@ -67,20 +67,24 @@ def test_check_auth_request(self): self.client_manager.check_auth_request(**request_2)) def test_validate_auth(self): + # Test valid client valid_client = str(uuid4()) - invalid_client = str(uuid4()) auth_response = self.client_manager.check_auth_request( - username="valid", client_id=valid_client)['access_token'] - + username="valid", client_id=valid_client).access_token self.assertTrue(self.client_manager.validate_auth(auth_response, "127.0.0.1")) + + # Unauthenticated client fails + invalid_client = str(uuid4()) self.assertFalse(self.client_manager.validate_auth(invalid_client, "127.0.0.1")) - # TODO: Update token data - expired_token = self.client_manager._create_tokens( - {"client_id": invalid_client, "username": "test", - "password": "test", "expire": time(), - "permissions": {}})['access_token'] + # Test expired token fails auth + self.client_manager._access_token_lifetime = 1 + self.client_manager._refresh_token_lifetime = 1 + expired_token, _, _ = self.client_manager._create_tokens( + user_id=str(uuid4()), + client_id=str(uuid4())) + sleep(1) self.assertFalse(self.client_manager.validate_auth(expired_token, "127.0.0.1")) @@ -93,118 +97,51 @@ def test_validate_auth(self): def test_check_refresh_request(self): valid_client = str(uuid4()) - # TODO: Update token data - tokens = self.client_manager._create_tokens({"client_id": valid_client, - "username": "test", - "password": "test", - "expire": time(), - "permissions": {}}) - self.assertEqual(tokens['client_id'], valid_client) + self.client_manager._access_token_lifetime = 60 + self.client_manager._refresh_token_lifetime = 3600 + access, refresh, config = self.client_manager._create_tokens( + user_id=str(uuid4()), client_id=valid_client) + access2, refresh2, config2 = self.client_manager._create_tokens( + user_id=str(uuid4()), client_id=str(uuid4())) + self.assertEqual(config.client_id, valid_client) # Test invalid refresh token with self.assertRaises(HTTPException) as e: - self.client_manager.check_refresh_request(tokens['access_token'], - valid_client, + self.client_manager.check_refresh_request(access, access, valid_client) self.assertEqual(e.exception.status_code, 400) # Test incorrect access token with self.assertRaises(HTTPException) as e: - self.client_manager.check_refresh_request(tokens['refresh_token'], - tokens['refresh_token'], + self.client_manager.check_refresh_request(access2, refresh, valid_client) self.assertEqual(e.exception.status_code, 403) # Test invalid client_id with self.assertRaises(HTTPException) as e: - self.client_manager.check_refresh_request(tokens['access_token'], - tokens['refresh_token'], + self.client_manager.check_refresh_request(access, refresh, str(uuid4())) self.assertEqual(e.exception.status_code, 403) # Test valid refresh valid_refresh = self.client_manager.check_refresh_request( - tokens['access_token'], tokens['refresh_token'], - tokens['client_id']) - self.assertEqual(valid_refresh['client_id'], tokens['client_id']) - self.assertNotEqual(valid_refresh['access_token'], - tokens['access_token']) - self.assertNotEqual(valid_refresh['refresh_token'], - tokens['refresh_token']) + access, refresh, config.client_id) + self.assertEqual(valid_refresh.client_id, config.client_id) + self.assertNotEqual(valid_refresh.access_token, access) + self.assertNotEqual(valid_refresh.refresh_token, refresh) # Test expired refresh token real_refresh = self.client_manager._refresh_token_lifetime self.client_manager._refresh_token_lifetime = 0 - # TODO: Update token data - tokens = self.client_manager._create_tokens({"client_id": valid_client, - "username": "test", - "password": "test", - "expire": time(), - "permissions": {}}) + + access, refresh, config = self.client_manager._create_tokens( + user_id=str(uuid4()), client_id=valid_client) with self.assertRaises(HTTPException) as e: - self.client_manager.check_refresh_request(tokens['access_token'], - tokens['refresh_token'], - tokens['client_id']) + self.client_manager.check_refresh_request(access, refresh, + config.client_id) self.assertEqual(e.exception.status_code, 401) self.client_manager._refresh_token_lifetime = real_refresh - def test_get_permissions(self): - from neon_hana.auth.permissions import ClientPermissions - - node_user = "node_test" - rest_user = "rest_user" - self.client_manager._node_username = node_user - self.client_manager._node_password = node_user - - rest_resp = self.client_manager.check_auth_request(rest_user, rest_user) - node_resp = self.client_manager.check_auth_request(node_user, node_user, - node_user) - node_fail = self.client_manager.check_auth_request("node_fail", - node_user, rest_user) - - rest_cid = rest_resp['client_id'] - node_cid = node_resp['client_id'] - fail_cid = node_fail['client_id'] - - permissive = ClientPermissions(True, True, True) - no_node = ClientPermissions(True, True, False) - no_perms = ClientPermissions(False, False, False) - - # Auth disabled, returns all True - self.client_manager._disable_auth = True - self.assertEqual(self.client_manager.get_permissions(rest_cid), - permissive) - self.assertEqual(self.client_manager.get_permissions(node_cid), - permissive) - self.assertEqual(self.client_manager.get_permissions(rest_cid), - permissive) - self.assertEqual(self.client_manager.get_permissions(fail_cid), - permissive) - self.assertEqual(self.client_manager.get_permissions("fake_user"), - permissive) - - # Auth enabled - self.client_manager._disable_auth = False - self.assertEqual(self.client_manager.get_permissions(rest_cid), no_node) - self.assertEqual(self.client_manager.get_permissions(node_cid), - permissive) - self.assertEqual(self.client_manager.get_permissions(fail_cid), no_node) - self.assertEqual(self.client_manager.get_permissions("fake_user"), - no_perms) - - def test_client_permissions(self): - from neon_hana.auth.permissions import ClientPermissions - default_perms = ClientPermissions() - restricted_perms = ClientPermissions(False, False, False) - permissive_perms = ClientPermissions(True, True, True) - self.assertIsInstance(default_perms.as_dict(), dict) - for v in default_perms.as_dict().values(): - self.assertIsInstance(v, bool) - self.assertIsInstance(restricted_perms.as_dict(), dict) - self.assertFalse(any([v for v in restricted_perms.as_dict().values()])) - self.assertIsInstance(permissive_perms.as_dict(), dict) - self.assertTrue(all([v for v in permissive_perms.as_dict().values()])) - def test_stream_connections(self): # Test configured maximum self.client_manager._max_streaming_clients = 1