Skip to content

Commit

Permalink
Fix response codes, labels, add user binding and user logout.
Browse files Browse the repository at this point in the history
  • Loading branch information
GrahamDumpleton committed Aug 25, 2024
1 parent 7366a46 commit 0d8ec41
Show file tree
Hide file tree
Showing 10 changed files with 100 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ spec:
password:
type: string
minLength: 8
user:
type: string
roles:
type: array
items:
Expand Down
17 changes: 15 additions & 2 deletions lookup-service/service/caches/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,32 @@ class ClientConfig:

name: str
uid: str
issue: int
password: str
user: str
tenants: List[str]
roles: List[str]

@property
def identity(self) -> str:
"""Return the identity of the client."""

return f"client@educates:{self.uid}#{self.issue}"

def revoke_tokens(self) -> None:
"""Revoke all tokens issued to the client."""

self.issue += 1

def check_password(self, password: str) -> bool:
"""Checks the password provided against the client's password."""

return self.password == password

def validate_identity(self, uid: str) -> bool:
def validate_identity(self, identity: str) -> bool:
"""Validate the identity provided against the client's identity."""

return self.uid == uid
return self.identity == identity

def has_required_role(self, *roles: str) -> Set:
"""Check if the client has any of the roles provided. We return back a
Expand Down
4 changes: 2 additions & 2 deletions lookup-service/service/caches/clusters.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@ class ClusterConfig:

name: str
uid: str
labels: Dict[str, str]
labels: List[Dict[str, str]]
kubeconfig: Dict[str, Any]
portals: Dict[str, "TrainingPortal"]

def __init__(
self, name: str, uid: str, labels: Dict[str, str], kubeconfig: Dict[str, Any]
self, name: str, uid: str, labels: List[Dict[str, str]], kubeconfig: Dict[str, Any]
):
self.name = name
self.uid = uid
Expand Down
4 changes: 2 additions & 2 deletions lookup-service/service/caches/databases.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def get_client(self, name: str) -> "ClientConfig":
return self.clients.get(name)

def authenticate_client(self, name: str, password: str) -> str | None:
"""Validate a client's credentials. Returning the uid of the client if
"""Validate a client's credentials. Returning the the client if
the credentials are valid."""

client = self.get_client(name)
Expand All @@ -51,7 +51,7 @@ def authenticate_client(self, name: str, password: str) -> str | None:
return

if client.check_password(password):
return client.uid
return client


@dataclass
Expand Down
4 changes: 2 additions & 2 deletions lookup-service/service/caches/environments.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class WorkshopEnvironment:
workshop: str
title: str
description: str
labels: Dict[str, str]
labels: List[Dict[str, str]]
capacity: int
reserved: int
allocated: int
Expand All @@ -43,7 +43,7 @@ def __init__(
workshop: str,
title: str,
description: str,
labels: Dict[str, str],
labels: List[Dict[str, str]],
capacity: int,
reserved: int,
allocated: int,
Expand Down
4 changes: 2 additions & 2 deletions lookup-service/service/caches/portals.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class TrainingPortal:
name: str
uid: str
generation: int
labels: Dict[Tuple[str, str], str]
labels: List[Dict[str, str]]
url: str
credentials: PortalCredentials
phase: str
Expand All @@ -49,7 +49,7 @@ def __init__(
name: str,
uid: str,
generation: int,
labels: Dict[str, str],
labels: List[Dict[str, str]],
url: str,
credentials: PortalCredentials,
phase: str,
Expand Down
3 changes: 3 additions & 0 deletions lookup-service/service/handlers/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def clientconfigs_update(

client_uid = xgetattr(meta, "uid")
client_password = xgetattr(spec, "client.password")
client_user = xgetattr(spec, "user")
client_tenants = xgetattr(spec, "tenants", [])
client_roles = xgetattr(spec, "roles", [])

Expand All @@ -42,7 +43,9 @@ def clientconfigs_update(
ClientConfig(
name=client_name,
uid=client_uid,
issue=1,
password=client_password,
user=client_user,
tenants=client_tenants,
roles=client_roles,
)
Expand Down
23 changes: 13 additions & 10 deletions lookup-service/service/handlers/clusters.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def clusterconfigs_update(
ClusterConfig(
name=name,
uid=uid,
labels=xgetattr(spec, "labels", {}),
labels=xgetattr(spec, "labels", []),
kubeconfig=kubeconfig,
)
)
Expand All @@ -164,7 +164,7 @@ def clusterconfigs_update(
generation,
)

cluster_config.labels = xgetattr(spec, "labels", {})
cluster_config.labels = xgetattr(spec, "labels", [])
cluster_config.kubeconfig = kubeconfig


Expand Down Expand Up @@ -250,7 +250,7 @@ async def trainingportals_event(event: kopf.RawEvent, **_):
name=portal_name,
uid=portal_uid,
generation=xgetattr(metadata, "generation"),
labels=xgetattr(spec, "portal.labels", {}),
labels=xgetattr(spec, "portal.labels", []),
url=xgetattr(status, "educates.url"),
phase=xgetattr(status, "educates.phase"),
credentials=credentials,
Expand All @@ -271,7 +271,7 @@ async def trainingportals_event(event: kopf.RawEvent, **_):

portal_state.uid = portal_uid
portal_state.generation = xgetattr(metadata, "generation")
portal_state.labels = xgetattr(spec, "portal.labels", {})
portal_state.labels = xgetattr(spec, "portal.labels", [])
portal_state.url = xgetattr(status, "educates.url")
portal_state.phase = xgetattr(status, "educates.phase")
portal_state.credentials = credentials
Expand Down Expand Up @@ -358,7 +358,7 @@ async def workshopenvironments_event(event: kopf.RawEvent, **_):
name=portal_name,
uid=portal_uid,
generation=0,
labels={},
labels=[],
url="",
phase="Unknown",
credentials=PortalCredentials(
Expand Down Expand Up @@ -393,7 +393,7 @@ async def workshopenvironments_event(event: kopf.RawEvent, **_):
workshop=workshop_name,
title=xgetattr(workshop_spec, "title"),
description=xgetattr(workshop_spec, "description"),
labels=xgetattr(workshop_spec, "labels", {}),
labels=xgetattr(workshop_spec, "labels", []),
capacity=xgetattr(status, "educates.capacity", 0),
reserved=xgetattr(status, "educates.reserved", 0),
allocated=0,
Expand All @@ -416,7 +416,7 @@ async def workshopenvironments_event(event: kopf.RawEvent, **_):
environment_state.description = xgetattr(
workshop_spec, "description"
)
environment_state.labels = xgetattr(workshop_spec, "labels", {})
environment_state.labels = xgetattr(workshop_spec, "labels", [])

environment_state.phase = xgetattr(status, "educates.phase")

Expand All @@ -442,6 +442,7 @@ async def workshopsessions_event(event: kopf.RawEvent, **_):

body = xgetattr(event, "object", {})
metadata = xgetattr(body, "metadata", {})
spec = xgetattr(body, "spec", {})
status = xgetattr(body, "status", {})

portal_name = xgetattr(metadata, "labels", {}).get(
Expand All @@ -458,6 +459,8 @@ async def workshopsessions_event(event: kopf.RawEvent, **_):
"training.educates.dev/environment.uid"
)

workshop_name = xgetattr(spec, "workshop.name")

session_name = xgetattr(metadata, "name")

with synchronized(self.cluster_config):
Expand Down Expand Up @@ -531,7 +534,7 @@ async def workshopsessions_event(event: kopf.RawEvent, **_):
name=portal_name,
uid=portal_uid,
generation=0,
labels={},
labels=[],
url="",
phase="Unknown",
credentials=PortalCredentials(
Expand Down Expand Up @@ -561,10 +564,10 @@ async def workshopsessions_event(event: kopf.RawEvent, **_):
name=environment_name,
uid=environment_uid,
generation=0,
workshop="",
workshop=workshop_name,
title="",
description="",
labels={},
labels=[],
capacity=0,
reserved=0,
allocated=0,
Expand Down
56 changes: 48 additions & 8 deletions lookup-service/service/routes/authnz.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@
from aiohttp import web

from ..config import jwt_token_secret
from ..caches.clients import ClientConfig

TOKEN_EXPIRATION = 72 # Expiration in hours.


def generate_client_token(username: str, uid: str) -> dict:
def generate_login_response(client: ClientConfig) -> dict:
"""Generate a JWT token for the client. The token will be set to expire and
will need to be renewed. The token will contain the username and the unique
identifier for the client."""
Expand All @@ -25,7 +26,7 @@ def generate_client_token(username: str, uid: str) -> dict:
)

jwt_token = jwt.encode(
{"sub": username, "jti": uid, "exp": expires_at},
{"sub": client.name, "jti": client.identity, "exp": expires_at},
jwt_token_secret(),
algorithm="HS256",
)
Expand All @@ -34,6 +35,8 @@ def generate_client_token(username: str, uid: str) -> dict:
"access_token": jwt_token,
"token_type": "Bearer",
"expires_at": expires_at,
"roles": client.roles,
"tenants": client.tenants,
}


Expand Down Expand Up @@ -77,7 +80,7 @@ async def jwt_token_middleware(
except jwt.ExpiredSignatureError:
return web.Response(text="JWT token has expired", status=401)
except jwt.InvalidTokenError:
return web.Response(text="JWT token is invalid", status=400)
return web.Response(text="JWT token is invalid", status=401)

# Store the decoded token in the request object for later use.

Expand Down Expand Up @@ -152,7 +155,7 @@ async def wrapper(request: web.Request) -> web.Response:
return decorator


async def api_login_handler(request: web.Request) -> web.Response:
async def api_auth_login(request: web.Request) -> web.Response:
"""Login handler for accessing the web application. Validates the username
and password provided in the request and returns a JWT token if the
credentials are valid."""
Expand All @@ -175,22 +178,59 @@ async def api_login_handler(request: web.Request) -> web.Response:
service_state = request.app["service_state"]
client_database = service_state.client_database

uid = client_database.authenticate_client(username, password)
client = client_database.authenticate_client(username, password)

if not uid:
if not client:
return web.Response(text="Invalid username/password", status=401)

# Generate a JWT token for the user and return it. The response is
# bundle with the token type and expiration time so they can be used
# by the client without needing to parse the actual JWT token.

token = generate_client_token(username, uid)
token = generate_login_response(client)

return web.json_response(token)


async def api_auth_logout(request: web.Request) -> web.Response:
"""Logout handler for the web application. The client will be logged out
and the JWT token will be invalidated."""

# Check if the decoded JWT token is present in the request object.

if "jwt_token" not in request:
return web.Response(text="JWT token not supplied", status=400)

decoded_token = request["jwt_token"]

# Check the client database for the client by the name of the client
# taken from the JWT token subject. Then check if the identity of the
# client is still the same as the one recorded in the JWT token.

service_state = request.app["service_state"]
client_database = service_state.client_database

client = client_database.get_client(decoded_token["sub"])

if not client:
return web.Response(text="Client details not found", status=401)

if not client.validate_identity(decoded_token["jti"]):
return web.Response(text="Client identity does not match", status=401)

# Revoke the tokens issued to the client.

client.revoke_tokens()

return web.json_response({})

# Set up the middleware and routes for the authentication and authorization.

middlewares = [jwt_token_middleware]

routes = [web.post("/login", api_login_handler)]
routes = [
web.post("/login", api_auth_login),
web.post("/auth/login", api_auth_login),
web.post("/auth/logout", api_auth_logout),
web.get("/auth/verify", login_required(lambda r: web.json_response({}))),
]
Loading

0 comments on commit 0d8ec41

Please sign in to comment.