Skip to content

Commit

Permalink
Update authorized_clients handling to partially restore backwards-c…
Browse files Browse the repository at this point in the history
…ompat.
  • Loading branch information
NeonDaniel committed Nov 5, 2024
1 parent 84e31a6 commit 2cbfad0
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 20 deletions.
2 changes: 2 additions & 0 deletions neon_hana/app/routers/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,13 @@

@util_route.get("/client_ip", response_class=PlainTextResponse)
async def api_client_ip(request: Request) -> str:
# Validation will fail, but this increments the rate-limiting
client_manager.validate_auth("", request.client.host)
return request.client.host


@util_route.get("/headers")
async def api_headers(request: Request):
# Validation will fail, but this increments the rate-limiting
client_manager.validate_auth("", request.client.host)
return request.headers
48 changes: 32 additions & 16 deletions neon_hana/auth/client_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,9 @@ def __init__(self, config: dict,
mq_connector: Optional[MQServiceManager] = None):
self.rate_limiter = TokenThrottler(cost=1, storage=RuntimeStorage())

# TODO: Is `authorized_clients` useful to track?
# Keep a dict of `client_id` to auth tokens that have authenticated to
# this instance
self.authorized_clients: Dict[str, HanaToken] = dict()
self._authorized_clients: Dict[str, AuthenticationResponse] = dict()
self._access_token_lifetime = config.get("access_token_ttl", 3600 * 24)
self._refresh_token_lifetime = config.get("refresh_token_ttl",
3600 * 24 * 90)
Expand All @@ -79,6 +78,16 @@ def __init__(self, config: dict,
self._stream_check_lock = Lock()
self._mq_connector = mq_connector

@property
def authorized_clients(self) -> Dict[str, AuthenticationResponse]:
"""
Dict of `client_id` to `AuthenticationResponse` objects for clients
known by this instance. NOTE: Refresh tokens are not reliably stored
here and should never be retrievable after generation for security.
"""
# TODO: Is `authorized_clients` useful to track?
return self._authorized_clients

def _create_tokens(self,
user_id: str,
client_id: str,
Expand All @@ -92,9 +101,9 @@ def _create_tokens(self,
expiration_timestamp = creation_timestamp + self._access_token_lifetime
refresh_expiration_timestamp = creation_timestamp + self._refresh_token_lifetime
permissions = permissions or PermissionsConfig(core=AccessRoles.GUEST,
diana=AccessRoles.GUEST,
node=AccessRoles.GUEST,
llm=AccessRoles.GUEST)
diana=AccessRoles.GUEST,
node=AccessRoles.GUEST,
llm=AccessRoles.GUEST)
token_name = token_name or kwargs.get("name") or \
datetime.fromtimestamp(creation_timestamp).isoformat()
access_token_data = HanaToken(iss=self._jwt_issuer,
Expand Down Expand Up @@ -174,9 +183,9 @@ def check_auth_request(self, client_id: str, username: str,
@param origin_ip: Origin IP address of request
@return: response tokens, permissions, and other metadata
"""
# if client_id in self.authorized_clients:
# print(f"Using cached client: {self.authorized_clients[client_id]}")
# return self.authorized_clients[client_id]
if client_id in self.authorized_clients:
print(f"Using cached client: {self.authorized_clients[client_id]}")
return self.authorized_clients[client_id]

ratelimit_id = f"auth{origin_ip}"
if not self.rate_limiter.get_all_buckets(ratelimit_id):
Expand Down Expand Up @@ -208,13 +217,15 @@ def check_auth_request(self, client_id: str, username: str,
"token_name": token_name,
"last_refresh_timestamp": create_time}
access, refresh, config = self._create_tokens(**encode_data)
self.authorized_clients[client_id] = config

auth_response = AuthenticationResponse(username=user.username,
client_id=client_id,
access_token=access,
refresh_token=refresh,
expiration=config.refresh_expiration_timestamp)
self.authorized_clients[client_id] = auth_response
self._add_token_to_userdb(user, config)
return AuthenticationResponse(username=user.username,
client_id=client_id,
access_token=access,
refresh_token=refresh,
expiration=config.refresh_expiration_timestamp)
return auth_response

def check_refresh_request(self, access_token: str, refresh_token: str,
client_id: str) -> AuthenticationResponse:
Expand Down Expand Up @@ -263,11 +274,14 @@ def check_refresh_request(self, access_token: str, refresh_token: str,
else:
username = token_data.sub
access, refresh, config = self._create_tokens(**encode_data)
return AuthenticationResponse(username=username,

auth_response = AuthenticationResponse(username=username,
client_id=client_id,
access_token=access,
refresh_token=refresh,
expiration=config.refresh_expiration_timestamp)
self._authorized_clients[client_id] = auth_response
return auth_response

def _add_token_to_userdb(self, user: User, new_token: TokenConfig):
if self._mq_connector is None:
Expand Down Expand Up @@ -310,7 +324,9 @@ def validate_auth(self, token: str, origin_ip: str) -> bool:
if auth.exp < time():
self.authorized_clients.pop(auth.client_id, None)
return False
self.authorized_clients[auth.client_id] = auth
self.authorized_clients[auth.client_id] = AuthenticationResponse(
username=auth.sub, client_id=auth.client_id, access_token=token,
refresh_token="", expiration=auth.exp)
return True
except DecodeError:
# Invalid token supplied
Expand Down
8 changes: 4 additions & 4 deletions tests/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,16 +47,16 @@ def test_check_auth_request(self):

# Check simple auth
auth_resp_1 = self.client_manager.check_auth_request(**request_1)
# self.assertEqual(self.client_manager.authorized_clients[client_1],
# auth_resp_1.access_token)
self.assertEqual(self.client_manager.authorized_clients[client_1],
auth_resp_1)
self.assertEqual(auth_resp_1.username, 'guest')
self.assertEqual(auth_resp_1.client_id, client_1)

# Check auth from different client
auth_resp_2 = self.client_manager.check_auth_request(**request_2)
self.assertNotEquals(auth_resp_1, auth_resp_2)
# self.assertEqual(self.client_manager.authorized_clients[client_2],
# auth_resp_2.access_token)
self.assertEqual(self.client_manager.authorized_clients[client_2],
auth_resp_2)
self.assertEqual(auth_resp_2.username, 'guest')
self.assertEqual(auth_resp_2.client_id, client_2)

Expand Down

0 comments on commit 2cbfad0

Please sign in to comment.