diff --git a/neon_hana/app/routers/auth.py b/neon_hana/app/routers/auth.py index 605db9b..9660953 100644 --- a/neon_hana/app/routers/auth.py +++ b/neon_hana/app/routers/auth.py @@ -24,11 +24,12 @@ # NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -from fastapi import APIRouter, Request +from fastapi import APIRouter, Request, Depends +from neon_data_models.models.user.database import PermissionsConfig +from neon_data_models.models.user import User -from neon_hana.app.dependencies import client_manager +from neon_hana.app.dependencies import client_manager, jwt_bearer from neon_hana.schema.auth_requests import * -from neon_data_models.models.user import User auth_route = APIRouter(prefix="/auth", tags=["authentication"]) @@ -51,3 +52,8 @@ async def register_user(register_request: RegistrationRequest, request: Request) -> User: return client_manager.check_registration_request(**dict(register_request), origin_ip=request.client.host) + + +@auth_route.post("/permissions") +async def check_permissions(token: str = Depends(jwt_bearer)) -> PermissionsConfig: + return client_manager.get_token_permissions(token) diff --git a/neon_hana/auth/client_manager.py b/neon_hana/auth/client_manager.py index e3bdb64..9e81ded 100644 --- a/neon_hana/auth/client_manager.py +++ b/neon_hana/auth/client_manager.py @@ -344,11 +344,20 @@ def get_token_data(self, token: str) -> HanaToken: """ Extract the user_id from a JWT string @param token: JWT to parse - @retrun: user_id associated with token + @return: user_id associated with token """ return HanaToken(**jwt.decode(token, self._access_secret, self._jwt_algo)) + def get_token_permissions(self, token: str) -> PermissionsConfig: + """ + Get a PermissionsConfig object from a JWT Token + @param token: JWT to parse + @return: PermissionsConfig object representing the token permissions + """ + roles = self.get_token_data(token).roles + return PermissionsConfig.from_roles(roles) + def validate_auth(self, token: str, origin_ip: str) -> bool: ratelimit_id = f"{origin_ip}-total" if not self.rate_limiter.get_all_buckets(ratelimit_id): diff --git a/tests/test_auth.py b/tests/test_auth.py index b5270d1..cfbc877 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -29,6 +29,10 @@ from uuid import uuid4 from fastapi import HTTPException +from jwt import DecodeError + +from neon_data_models.enum import AccessRoles +from neon_data_models.models.user.database import PermissionsConfig class TestClientManager(unittest.TestCase): @@ -173,3 +177,23 @@ def test_stream_connections(self): self.client_manager._max_streaming_clients = False self.assertTrue(self.client_manager.check_connect_stream()) self.assertEqual(self.client_manager._connected_streams, 5) + + def test_get_token_permissions(self): + permissions = PermissionsConfig(core=AccessRoles.USER, + diana=AccessRoles.USER, + node=AccessRoles.USER, + llm=AccessRoles.USER, + users=AccessRoles.USER) + valid_token, _, _ = self.client_manager._create_tokens("test_user", + "test_client", + "test_name", + permissions) + + # Valid token decodes + self.assertEqual(self.client_manager.get_token_permissions(valid_token), + permissions) + + # Invalid token raises exception + with self.assertRaises(DecodeError): + self.client_manager.get_token_permissions("invalid_token_string") +