Skip to content

Commit

Permalink
feat: add parameter guards
Browse files Browse the repository at this point in the history
  • Loading branch information
ctran88 committed Dec 9, 2024
1 parent a22485f commit 4671f8c
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 12 deletions.
17 changes: 10 additions & 7 deletions passageidentity/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@

from typing import TYPE_CHECKING

import jwt
import jwt.algorithms
import jwt as pyjwt

from passageidentity.errors import PassageError
from passageidentity.helper import fetch_app
Expand All @@ -26,7 +25,7 @@ def __init__(self, app_id: str, request_headers: dict[str, str]) -> None:
"""Initialize the Auth class with the app ID and request headers."""
self.app_id = app_id
self.request_headers = request_headers
self.jwks = jwt.PyJWKClient(
self.jwks = pyjwt.PyJWKClient(
f"https://auth.passage.id/v1/apps/{self.app_id}/.well-known/jwks.json",
# must set a user agent to avoid 403 from CF
headers={"User-Agent": "passageidentity/python"},
Expand All @@ -35,13 +34,17 @@ def __init__(self, app_id: str, request_headers: dict[str, str]) -> None:

self.magic_links_api = MagicLinksApi()

def validate_jwt(self, token: str) -> str:
def validate_jwt(self, jwt: str) -> str:
"""Verify the JWT and return the user ID for the authenticated user, or throw a PassageError."""
if not jwt:
msg = "jwt is required."
raise ValueError(msg)

try:
kid = jwt.get_unverified_header(token)["kid"]
kid = pyjwt.get_unverified_header(jwt)["kid"]
public_key = self.jwks.get_signing_key(kid)
claims = jwt.decode(
token,
claims = pyjwt.decode(
jwt,
public_key,
audience=[self.app_id] if self.app["hosted"] else self.app["auth_origin"],
algorithms=["RS256"],
Expand Down
12 changes: 9 additions & 3 deletions passageidentity/passage.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from __future__ import annotations

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, cast

import typing_extensions

Expand All @@ -19,6 +19,7 @@
)
from .openapi_client.models import (
CreateMagicLinkRequest,
CreateUserRequest,
MagicLinkType,
)

Expand All @@ -27,7 +28,6 @@

from .openapi_client.models import (
AppInfo,
CreateUserRequest,
UpdateUserRequest,
UserInfo,
WebAuthnDevices,
Expand Down Expand Up @@ -243,4 +243,10 @@ def createUser( # noqa: N802
msg = "either phone or email must be provided to create the user"
raise PassageError(msg)

return self.user.create(userAttributes)
user_args = (
cast(CreateUserRequest, CreateUserRequest.from_dict(userAttributes))
if isinstance(userAttributes, dict)
else userAttributes
)

return self.user.create(user_args)
44 changes: 44 additions & 0 deletions passageidentity/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ def __init__(self, app_id: str, request_headers: dict[str, str]) -> None:

def get(self, user_id: str) -> PassageUser:
"""Get a user's object using their user ID."""
if not user_id:
msg = "user_id is required."
raise ValueError(msg)

try:
return self.users_api.get_user(self.app_id, user_id, _headers=self.request_headers).user
except ApiException as e:
Expand All @@ -49,6 +53,10 @@ def get(self, user_id: str) -> PassageUser:

def get_by_identifier(self, identifier: str) -> PassageUser:
"""Get a user's object using their user identifier."""
if not identifier:
msg = "identifier is required."
raise ValueError(msg)

try:
users = self.users_api.list_paginated_users(
self.app_id,
Expand All @@ -68,6 +76,10 @@ def get_by_identifier(self, identifier: str) -> PassageUser:

def activate(self, user_id: str) -> PassageUser:
"""Activate a user using their user ID."""
if not user_id:
msg = "user_id is required."
raise ValueError(msg)

try:
return self.users_api.activate_user(self.app_id, user_id, _headers=self.request_headers).user
except ApiException as e:
Expand All @@ -76,6 +88,10 @@ def activate(self, user_id: str) -> PassageUser:

def deactivate(self, user_id: str) -> PassageUser:
"""Deactivate a user using their user ID."""
if not user_id:
msg = "user_id is required."
raise ValueError(msg)

try:
return self.users_api.deactivate_user(self.app_id, user_id, _headers=self.request_headers).user
except ApiException as e:
Expand All @@ -84,6 +100,10 @@ def deactivate(self, user_id: str) -> PassageUser:

def update(self, user_id: str, args: UpdateUserArgs) -> PassageUser:
"""Update a user."""
if not user_id:
msg = "user_id is required."
raise ValueError(msg)

try:
return self.users_api.update_user(self.app_id, user_id, args, _headers=self.request_headers).user
except ApiException as e:
Expand All @@ -92,6 +112,10 @@ def update(self, user_id: str, args: UpdateUserArgs) -> PassageUser:

def create(self, args: CreateUserArgs) -> PassageUser:
"""Create a user."""
if not args.email and not args.phone:
msg = "At least one of args.email or args.phone is required."
raise ValueError(msg)

try:
return self.users_api.create_user(self.app_id, args, _headers=self.request_headers).user
except ApiException as e:
Expand All @@ -100,6 +124,10 @@ def create(self, args: CreateUserArgs) -> PassageUser:

def delete(self, user_id: str) -> None:
"""Delete a user using their user ID."""
if not user_id:
msg = "user_id is required."
raise ValueError(msg)

try:
self.users_api.delete_user(self.app_id, user_id, _headers=self.request_headers)
except ApiException as e:
Expand All @@ -108,6 +136,10 @@ def delete(self, user_id: str) -> None:

def list_devices(self, user_id: str) -> list[WebAuthnDevices]:
"""Get a user's devices using their user ID."""
if not user_id:
msg = "user_id is required."
raise ValueError(msg)

try:
return self.user_devices_api.list_user_devices(self.app_id, user_id, _headers=self.request_headers).devices
except ApiException as e:
Expand All @@ -116,6 +148,14 @@ def list_devices(self, user_id: str) -> list[WebAuthnDevices]:

def revoke_device(self, user_id: str, device_id: str) -> None:
"""Revoke a user's device using their user ID and the device ID."""
if not user_id:
msg = "user_id is required."
raise ValueError(msg)

if not device_id:
msg = "device_id is required."
raise ValueError(msg)

try:
self.user_devices_api.delete_user_devices(self.app_id, user_id, device_id, _headers=self.request_headers)
except ApiException as e:
Expand All @@ -124,6 +164,10 @@ def revoke_device(self, user_id: str, device_id: str) -> None:

def revoke_refresh_tokens(self, user_id: str) -> None:
"""Revokes all of a user's Refresh Tokens using their User ID."""
if not user_id:
msg = "user_id is required."
raise ValueError(msg)

try:
self.tokens_api.revoke_user_refresh_tokens(self.app_id, user_id, _headers=self.request_headers)
except ApiException as e:
Expand Down
4 changes: 2 additions & 2 deletions tests/user_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@ def test_get_by_identifier_valid_upper_case() -> None:
psg = Passage(PASSAGE_APP_ID, PASSAGE_API_KEY)

email = f.email()
new_user = cast(UserInfo, psg.user.create({"email": email})) # type: ignore[arg-type]
new_user = cast(UserInfo, psg.createUser({"email": email})) # type: ignore[arg-type]
assert new_user.email == email

user_by_identifier = cast(UserInfo, psg.user.get_by_identifier(email.upper()))
user_by_identifier = cast(UserInfo, psg.getUserByIdentifier(email.upper()))
assert user_by_identifier.id == new_user.id

user = cast(UserInfo, psg.user.get(new_user.id))
Expand Down

0 comments on commit 4671f8c

Please sign in to comment.