generated from cheshire-cat-ai/plugin-template
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcatcloak.py
146 lines (117 loc) · 5.82 KB
/
catcloak.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
from cat.mad_hatter.decorators import hook
from cat.factory.auth_handler import AuthHandlerConfig
from cat.factory.custom_auth_handler import BaseAuthHandler
from cat.auth.permissions import (
AuthPermission, AuthResource, AuthUserInfo, get_base_permissions
)
from cat.log import log
from pydantic import ConfigDict, Field
from typing import List, Type, Dict, Any, Literal
from keycloak import KeycloakOpenID
from cachetools import TTLCache
from time import time
@hook(priority=0)
def factory_allowed_auth_handlers(allowed: List[AuthHandlerConfig], cat) -> List:
    """Register the Keycloak auth handler config with the Cat.

    Args:
        allowed: Auth handler config classes already offered by the Cat;
            this list is extended in place.
        cat: The Cat instance passed to the hook (unused here).

    Returns:
        The same ``allowed`` list object, now also containing
        ``KeycloakAuthHandlerConfig``.
    """
    handlers = allowed
    handlers.extend((KeycloakAuthHandlerConfig,))
    return handlers
class KeycloakAuthHandler(BaseAuthHandler):
    """Auth handler that validates Keycloak-issued JWTs and maps them to Cat users.

    Decoded tokens and the ``AuthUserInfo`` built from them are cached for at
    most 300 seconds (``TTLCache(ttl=300)``) and never past the token's own
    ``exp`` claim. Keycloak roles are translated into Cat permissions through
    ``permission_mapping``. API-key authentication is not supported.
    """

    # user_mapping keys that are not copied into AuthUserInfo.extra.
    _RESERVED_MAPPING_KEYS = frozenset({"id", "name", "roles"})

    def __init__(self, **config):
        """Create the Keycloak client and the caches from the handler settings.

        Args:
            **config: Handler settings. ``server_url``, ``client_id``, ``realm``
                and ``client_secret`` are required; ``user_mapping`` and
                ``permission_mapping`` fall back to empty dicts.

        Raises:
            KeyError: If a required Keycloak setting is missing.
        """
        self.user_mapping = config.get("user_mapping") or {}
        self.permission_mapping = config.get("permission_mapping") or {}
        self.keycloak_openid = KeycloakOpenID(
            server_url=config["server_url"],
            client_id=config["client_id"],
            realm_name=config["realm"],
            client_secret_key=config["client_secret"],
        )
        # Sorted tuple of Keycloak roles -> merged Cat permissions for that role set.
        self.kc_permissions = {}
        # token -> (decoded token claims, expiration timestamp from the "exp" claim)
        self.token_cache = TTLCache(maxsize=1000, ttl=300)
        # token -> AuthUserInfo built from that token
        self.user_info_cache = TTLCache(maxsize=1000, ttl=300)

    def authorize_user_from_jwt(
        self, token: str, auth_resource: AuthResource, auth_permission: AuthPermission
    ) -> AuthUserInfo | None:
        """Authenticate a Keycloak JWT and check the requested permission.

        Args:
            token: The raw JWT presented by the client.
            auth_resource: The Cat resource being accessed.
            auth_permission: The permission required on that resource.

        Returns:
            The user's ``AuthUserInfo`` when the token is valid and the user is
            allowed, otherwise ``None``. Any error while decoding or mapping the
            token is logged and yields ``None``.
        """
        try:
            cached_user = self._get_cached_user(token)
            if cached_user is not None:
                # Same decision rule as for a freshly decoded token; no need to
                # hit Keycloak again just because the permission check failed.
                return self._authorize(cached_user, auth_resource, auth_permission)

            token_info = self.keycloak_openid.decode_token(token)
            expiration = token_info["exp"]
            self.token_cache[token] = (token_info, expiration)

            user_info = self.map_user_data(token_info)
            self.map_permissions(token_info, user_info)
            if not self.permission_mapping:
                # No role mapping configured: fall back to the Cat's base permissions.
                user_info.permissions = get_base_permissions()
            log.debug(f"User info: {user_info}")
            self.user_info_cache[token] = user_info

            return self._authorize(user_info, auth_resource, auth_permission)
        except Exception as e:
            log.error(f"Error processing token: {e}")
            return None

    def _get_cached_user(self, token: str) -> AuthUserInfo | None:
        """Return the cached user for a token that has not expired yet, else ``None``."""
        cached = self.token_cache.get(token)
        if cached is None:
            return None
        _, expiration = cached
        if time() >= expiration:
            # Token expired before its TTL entry did; drop it so it is re-validated.
            self.token_cache.pop(token, None)
            self.user_info_cache.pop(token, None)
            return None
        return self.user_info_cache.get(token)

    def _authorize(
        self, user_info: AuthUserInfo, auth_resource: AuthResource, auth_permission: AuthPermission
    ) -> AuthUserInfo | None:
        """Return ``user_info`` if access is granted, else ``None``.

        Without a ``permission_mapping`` every user holding a valid token is
        accepted; otherwise the mapped permissions are checked.
        """
        if not self.permission_mapping:
            return user_info
        if self.has_permission(user_info, auth_resource, auth_permission):
            return user_info
        return None

    def authorize_user_from_key(
        self, protocol: Literal["http", "websocket"], user_id: str, api_key: str, auth_resource: AuthResource,
        auth_permission: AuthPermission
    ) -> AuthUserInfo | None:
        """Reject API-key authentication, which this handler does not support."""
        log.warning("KeycloakAuthHandler does not support API keys.")
        return None

    def map_user_data(self, token_info: Dict[str, Any]) -> AuthUserInfo:
        """Build an ``AuthUserInfo`` from decoded token claims.

        ``id`` and ``name`` come from the ``user_mapping`` paths (defaulting to
        ``sub`` and ``preferred_username``). Every other ``user_mapping`` entry
        except ``roles`` is copied into ``extra``.
        """
        extra = {
            key: self.get_nested_value(token_info, path)
            for key, path in self.user_mapping.items()
            if key not in self._RESERVED_MAPPING_KEYS
        }
        return AuthUserInfo(
            id=self.get_nested_value(token_info, self.user_mapping.get("id", "sub")),
            name=self.get_nested_value(token_info, self.user_mapping.get("name", "preferred_username")),
            extra=extra,
        )

    def map_permissions(self, token_info: Dict[str, Any], user_info: AuthUserInfo):
        """Translate the token's Keycloak roles into Cat permissions on ``user_info``.

        Roles are read from the ``user_mapping["roles"]`` path (default
        ``realm_access.roles``). Results are memoized per distinct role set.
        """
        roles_path = self.user_mapping.get("roles", "realm_access.roles")
        kc_roles = self.get_nested_value(token_info, roles_path) or []
        if isinstance(kc_roles, str):
            # A single role claim must not be iterated character by character.
            kc_roles = [kc_roles]
        roles_key = tuple(sorted(set(kc_roles)))

        permissions = self.kc_permissions.get(roles_key)
        if permissions is None:
            permissions = self._merge_role_permissions(roles_key)
            self.kc_permissions[roles_key] = permissions

        # Give each user its own copy so mutations cannot corrupt the shared cache entry.
        user_info.permissions = {resource: list(perms) for resource, perms in permissions.items()}

    def _merge_role_permissions(self, roles) -> Dict[str, List[str]]:
        """Union the ``permission_mapping`` entries of all given roles, per resource."""
        merged: Dict[str, set] = {}
        for role in roles:
            role_perms = self.permission_mapping.get(role)
            if not role_perms:
                continue
            for resource, perms in role_perms.items():
                if isinstance(perms, str):
                    # Accept a single permission given as a plain string.
                    perms = [perms]
                merged.setdefault(resource, set()).update(perms)
        # Sort for a deterministic order independent of set iteration.
        return {resource: sorted(perms) for resource, perms in merged.items()}

    def has_permission(self, user_info: AuthUserInfo, auth_resource: AuthResource,
                       auth_permission: AuthPermission) -> bool:
        """Return whether ``user_info`` holds ``auth_permission`` on ``auth_resource``."""
        user_permissions = (user_info.permissions or {}).get(auth_resource.value, [])
        if auth_permission.value not in user_permissions:
            log.error(
                f"User {user_info.id} does not have permission to access {auth_resource.value} with {auth_permission.value}")
            return False
        return True

    @staticmethod
    def get_nested_value(data: Dict[str, Any], path: str) -> Any:
        """Resolve a dot-separated ``path`` (e.g. ``"realm_access.roles"``) in ``data``.

        Returns ``None`` when any segment is missing or an intermediate value is
        not a dict, instead of returning a partially resolved value.
        """
        current: Any = data
        for key in path.split("."):
            if not isinstance(current, dict):
                return None
            current = current.get(key)
        return current
class KeycloakAuthHandlerConfig(AuthHandlerConfig):
    """Admin-configurable settings for ``KeycloakAuthHandler``.

    All fields are required. ``user_mapping`` maps user-model fields to
    dot-separated claim paths. ``permission_mapping`` maps Keycloak role names
    to Cat permissions.
    """

    _pyclass: Type = KeycloakAuthHandler

    server_url: str = Field(..., description="The URL of the Keycloak server.")
    realm: str = Field(..., description="The realm to use.")
    client_id: str = Field(..., description="The client ID to use.")
    client_secret: str = Field(..., description="The client secret to use.")
    # Pydantic v2 deprecates arbitrary keyword arguments on Field (they were
    # folded into json_schema_extra under the key "extra"). Passing the same
    # payload explicitly keeps the identical schema entry without the warning.
    user_mapping: Dict[str, str] = Field(
        ...,
        description="The mapping of user data from the token to the user model.",
        json_schema_extra={"extra": {"type": "TextArea"}},
    )
    permission_mapping: Dict[str, Any] = Field(
        ...,
        description="The mapping of Keycloak roles to Cat permissions.",
        json_schema_extra={"extra": {"type": "TextArea"}},
    )

    model_config = ConfigDict(
        json_schema_extra={
            "humanReadableName": "Keycloak Auth Handler",
            "description": "Delegate auth to a Keycloak instance."
        }
    )