Skip to content

Commit

Permalink
Update auth tests to address code changes
Browse files Browse the repository at this point in the history
Fix error in token generation logic
Update exception handling for proper JWTs
  • Loading branch information
NeonDaniel committed Nov 5, 2024
1 parent e5b3f7d commit b113af1
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 103 deletions.
26 changes: 19 additions & 7 deletions neon_hana/auth/client_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from typing import Dict, Optional
from fastapi import Request, HTTPException
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from jwt import DecodeError
from jwt import DecodeError, ExpiredSignatureError
from ovos_utils import LOG
from token_throttler import TokenThrottler, TokenBucket
from token_throttler.storage import RuntimeStorage
Expand Down Expand Up @@ -84,9 +84,11 @@ def _create_tokens(self,
client_id: str,
token_name: Optional[str] = None,
permissions: Optional[PermissionsConfig] = None,
**kwargs) -> (HanaToken, HanaToken, TokenConfig):
**kwargs) -> (str, str, TokenConfig):
token_id = str(uuid4())
creation_timestamp = round(time())
# Subtract a second from creation so the token may be used immediately
# upon return
creation_timestamp = round(time()) - 1
expiration_timestamp = creation_timestamp + self._access_token_lifetime
refresh_expiration_timestamp = creation_timestamp + self._refresh_token_lifetime
permissions = permissions or PermissionsConfig(core=AccessRoles.GUEST,
Expand All @@ -111,7 +113,10 @@ def _create_tokens(self,
client_id=client_id,
roles=permissions.to_roles(),
purpose="refresh")

access_token = jwt.encode(access_token_data.model_dump(),
self._access_secret, self._jwt_algo)
refresh_token = jwt.encode(refresh_token_data.model_dump(),
self._refresh_secret, self._jwt_algo)
token_config = TokenConfig(token_name=token_name,
token_id=token_id,
user_id=user_id,
Expand All @@ -120,7 +125,7 @@ def _create_tokens(self,
refresh_expiration_timestamp=refresh_expiration_timestamp,
creation_timestamp=creation_timestamp,
last_refresh_timestamp=creation_timestamp)
return access_token_data, refresh_token_data, token_config
return access_token, refresh_token, token_config

def check_connect_stream(self) -> bool:
"""
Expand Down Expand Up @@ -220,10 +225,14 @@ def check_refresh_request(self, access_token: str, refresh_token: str,
self._jwt_algo))
token_data = HanaToken(**jwt.decode(access_token,
self._access_secret,
self._jwt_algo))
self._jwt_algo,
leeway=self._refresh_token_lifetime))
except DecodeError:
raise HTTPException(status_code=400,
detail="Invalid refresh token supplied")
detail="Invalid token supplied")
except ExpiredSignatureError:
raise HTTPException(status_code=401,
detail="Refresh token is expired")
if refresh_data.jti != token_data.jti + ".refresh":
raise HTTPException(status_code=403,
detail="Refresh and access token mismatch")
Expand Down Expand Up @@ -306,6 +315,9 @@ def validate_auth(self, token: str, origin_ip: str) -> bool:
except DecodeError:
# Invalid token supplied
pass
except ExpiredSignatureError:
# Expired token
pass
return False


Expand Down
129 changes: 33 additions & 96 deletions tests/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import unittest
from time import time
from time import time, sleep
from uuid import uuid4

from fastapi import HTTPException
Expand All @@ -34,7 +34,7 @@
class TestClientManager(unittest.TestCase):
from neon_hana.auth.client_manager import ClientManager
client_manager = ClientManager({"access_token_secret": "a800445648142061fc238d1f84e96200da87f4f9f784108ac90db8b4391b117b",
"refresh_token_secret": "a800445648142061fc238d1f84e96200da87f4f9f784108ac90db8b4391b117b",
"refresh_token_secret": "a800445648142061fc238d1f84e96200da87f4f9f784108ac90db8b4391ba800",
"disable_auth": False})

def test_check_auth_request(self):
Expand Down Expand Up @@ -67,20 +67,24 @@ def test_check_auth_request(self):
self.client_manager.check_auth_request(**request_2))

def test_validate_auth(self):
# Test valid client
valid_client = str(uuid4())
invalid_client = str(uuid4())
auth_response = self.client_manager.check_auth_request(
username="valid", client_id=valid_client)['access_token']

username="valid", client_id=valid_client).access_token
self.assertTrue(self.client_manager.validate_auth(auth_response,
"127.0.0.1"))

# Unauthenticated client fails
invalid_client = str(uuid4())
self.assertFalse(self.client_manager.validate_auth(invalid_client,
"127.0.0.1"))
# TODO: Update token data
expired_token = self.client_manager._create_tokens(
{"client_id": invalid_client, "username": "test",
"password": "test", "expire": time(),
"permissions": {}})['access_token']
# Test expired token fails auth
self.client_manager._access_token_lifetime = 1
self.client_manager._refresh_token_lifetime = 1
expired_token, _, _ = self.client_manager._create_tokens(
user_id=str(uuid4()),
client_id=str(uuid4()))
sleep(1)
self.assertFalse(self.client_manager.validate_auth(expired_token,
"127.0.0.1"))

Expand All @@ -93,118 +97,51 @@ def test_validate_auth(self):

def test_check_refresh_request(self):
valid_client = str(uuid4())
# TODO: Update token data
tokens = self.client_manager._create_tokens({"client_id": valid_client,
"username": "test",
"password": "test",
"expire": time(),
"permissions": {}})
self.assertEqual(tokens['client_id'], valid_client)
self.client_manager._access_token_lifetime = 60
self.client_manager._refresh_token_lifetime = 3600
access, refresh, config = self.client_manager._create_tokens(
user_id=str(uuid4()), client_id=valid_client)
access2, refresh2, config2 = self.client_manager._create_tokens(
user_id=str(uuid4()), client_id=str(uuid4()))
self.assertEqual(config.client_id, valid_client)

# Test invalid refresh token
with self.assertRaises(HTTPException) as e:
self.client_manager.check_refresh_request(tokens['access_token'],
valid_client,
self.client_manager.check_refresh_request(access, access,
valid_client)
self.assertEqual(e.exception.status_code, 400)

# Test incorrect access token
with self.assertRaises(HTTPException) as e:
self.client_manager.check_refresh_request(tokens['refresh_token'],
tokens['refresh_token'],
self.client_manager.check_refresh_request(access2, refresh,
valid_client)
self.assertEqual(e.exception.status_code, 403)

# Test invalid client_id
with self.assertRaises(HTTPException) as e:
self.client_manager.check_refresh_request(tokens['access_token'],
tokens['refresh_token'],
self.client_manager.check_refresh_request(access, refresh,
str(uuid4()))
self.assertEqual(e.exception.status_code, 403)

# Test valid refresh
valid_refresh = self.client_manager.check_refresh_request(
tokens['access_token'], tokens['refresh_token'],
tokens['client_id'])
self.assertEqual(valid_refresh['client_id'], tokens['client_id'])
self.assertNotEqual(valid_refresh['access_token'],
tokens['access_token'])
self.assertNotEqual(valid_refresh['refresh_token'],
tokens['refresh_token'])
access, refresh, config.client_id)
self.assertEqual(valid_refresh.client_id, config.client_id)
self.assertNotEqual(valid_refresh.access_token, access)
self.assertNotEqual(valid_refresh.refresh_token, refresh)

# Test expired refresh token
real_refresh = self.client_manager._refresh_token_lifetime
self.client_manager._refresh_token_lifetime = 0
# TODO: Update token data
tokens = self.client_manager._create_tokens({"client_id": valid_client,
"username": "test",
"password": "test",
"expire": time(),
"permissions": {}})

access, refresh, config = self.client_manager._create_tokens(
user_id=str(uuid4()), client_id=valid_client)
with self.assertRaises(HTTPException) as e:
self.client_manager.check_refresh_request(tokens['access_token'],
tokens['refresh_token'],
tokens['client_id'])
self.client_manager.check_refresh_request(access, refresh,
config.client_id)
self.assertEqual(e.exception.status_code, 401)
self.client_manager._refresh_token_lifetime = real_refresh

def test_get_permissions(self):
from neon_hana.auth.permissions import ClientPermissions

node_user = "node_test"
rest_user = "rest_user"
self.client_manager._node_username = node_user
self.client_manager._node_password = node_user

rest_resp = self.client_manager.check_auth_request(rest_user, rest_user)
node_resp = self.client_manager.check_auth_request(node_user, node_user,
node_user)
node_fail = self.client_manager.check_auth_request("node_fail",
node_user, rest_user)

rest_cid = rest_resp['client_id']
node_cid = node_resp['client_id']
fail_cid = node_fail['client_id']

permissive = ClientPermissions(True, True, True)
no_node = ClientPermissions(True, True, False)
no_perms = ClientPermissions(False, False, False)

# Auth disabled, returns all True
self.client_manager._disable_auth = True
self.assertEqual(self.client_manager.get_permissions(rest_cid),
permissive)
self.assertEqual(self.client_manager.get_permissions(node_cid),
permissive)
self.assertEqual(self.client_manager.get_permissions(rest_cid),
permissive)
self.assertEqual(self.client_manager.get_permissions(fail_cid),
permissive)
self.assertEqual(self.client_manager.get_permissions("fake_user"),
permissive)

# Auth enabled
self.client_manager._disable_auth = False
self.assertEqual(self.client_manager.get_permissions(rest_cid), no_node)
self.assertEqual(self.client_manager.get_permissions(node_cid),
permissive)
self.assertEqual(self.client_manager.get_permissions(fail_cid), no_node)
self.assertEqual(self.client_manager.get_permissions("fake_user"),
no_perms)

def test_client_permissions(self):
from neon_hana.auth.permissions import ClientPermissions
default_perms = ClientPermissions()
restricted_perms = ClientPermissions(False, False, False)
permissive_perms = ClientPermissions(True, True, True)
self.assertIsInstance(default_perms.as_dict(), dict)
for v in default_perms.as_dict().values():
self.assertIsInstance(v, bool)
self.assertIsInstance(restricted_perms.as_dict(), dict)
self.assertFalse(any([v for v in restricted_perms.as_dict().values()]))
self.assertIsInstance(permissive_perms.as_dict(), dict)
self.assertTrue(all([v for v in permissive_perms.as_dict().values()]))

def test_stream_connections(self):
# Test configured maximum
self.client_manager._max_streaming_clients = 1
Expand Down

0 comments on commit b113af1

Please sign in to comment.