Skip to content

Commit

Permalink
feat: updates create_magic_link method to use stronger typed paramete…
Browse files Browse the repository at this point in the history
…rs (#114)

Co-authored-by: Bert Ramirez <[email protected]>
  • Loading branch information
ctran88 and bertrmz authored Dec 11, 2024
1 parent 5153486 commit 83cd628
Show file tree
Hide file tree
Showing 9 changed files with 192 additions and 47 deletions.
3 changes: 3 additions & 0 deletions passageidentity/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
"""Initializes the Passage identity package."""

from .errors import PassageError
from .models import MagicLinkArgs, MagicLinkOptions
from .passage import Passage

__all__ = [
"MagicLinkArgs",
"MagicLinkOptions",
"Passage",
"PassageError",
]
67 changes: 38 additions & 29 deletions passageidentity/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,21 @@

from typing import TYPE_CHECKING

import jwt
import jwt.algorithms
import jwt as pyjwt

from passageidentity.errors import PassageError
from passageidentity.helper import fetch_app
from passageidentity.models.magic_link_args import MagicLinkWithEmailArgs, MagicLinkWithPhoneArgs, MagicLinkWithUserArgs
from passageidentity.openapi_client.api.magic_links_api import MagicLinksApi
from passageidentity.openapi_client.exceptions import ApiException
from passageidentity.openapi_client.models.create_magic_link_request import CreateMagicLinkRequest
from passageidentity.openapi_client.models.magic_link_channel import MagicLinkChannel

if TYPE_CHECKING:
from passageidentity.models.magic_link_args import MagicLinkArgs
from passageidentity.models.magic_link_options import MagicLinkOptions
from passageidentity.openapi_client.models.magic_link import MagicLink

CreateMagicLinkArgs = CreateMagicLinkRequest


class Auth:
"""Auth class for handling operations to authenticate and validate JWTs."""
Expand All @@ -26,7 +27,7 @@ def __init__(self, app_id: str, request_headers: dict[str, str]) -> None:
"""Initialize the Auth class with the app ID and request headers."""
self.app_id = app_id
self.request_headers = request_headers
self.jwks = jwt.PyJWKClient(
self.jwks = pyjwt.PyJWKClient(
f"https://auth.passage.id/v1/apps/{self.app_id}/.well-known/jwks.json",
# must set a user agent to avoid 403 from CF
headers={"User-Agent": "passageidentity/python"},
Expand All @@ -35,13 +36,17 @@ def __init__(self, app_id: str, request_headers: dict[str, str]) -> None:

self.magic_links_api = MagicLinksApi()

def validate_jwt(self, token: str) -> str:
def validate_jwt(self, jwt: str) -> str:
"""Verify the JWT and return the user ID for the authenticated user, or throw a PassageError."""
if not jwt:
msg = "jwt is required."
raise ValueError(msg)

try:
kid = jwt.get_unverified_header(token)["kid"]
kid = pyjwt.get_unverified_header(jwt)["kid"]
public_key = self.jwks.get_signing_key(kid)
claims = jwt.decode(
token,
claims = pyjwt.decode(
jwt,
public_key,
audience=[self.app_id] if self.app["hosted"] else self.app["auth_origin"],
algorithms=["RS256"],
Expand All @@ -52,31 +57,35 @@ def validate_jwt(self, token: str) -> str:
msg = f"JWT is not valid: {e}"
raise PassageError(msg) from e

def create_magic_link(self, args: CreateMagicLinkArgs) -> MagicLink:
def create_magic_link(self, args: MagicLinkArgs, options: MagicLinkOptions | None = None) -> MagicLink:
"""Create a Magic Link for your app."""
magic_link_req = {}
args_dict = args.to_dict() if isinstance(args, CreateMagicLinkRequest) else args

magic_link_req["user_id"] = args_dict.get("user_id") or ""
magic_link_req["email"] = args_dict.get("email") or ""
magic_link_req["phone"] = args_dict.get("phone") or ""

magic_link_req["language"] = args_dict.get("language") or ""
magic_link_req["magic_link_path"] = args_dict.get("magic_link_path") or ""
magic_link_req["redirect_url"] = args_dict.get("redirect_url") or ""
magic_link_req["send"] = args_dict.get("send") or False
magic_link_req["ttl"] = args_dict.get("ttl") or 0
magic_link_req["type"] = args_dict.get("type") or "login"

if args_dict.get("email"):
magic_link_req["channel"] = args_dict.get("channel") or "email"
elif args_dict.get("phone"):
magic_link_req["channel"] = args_dict.get("channel") or "phone"
payload = CreateMagicLinkRequest()
payload.type = args.type
payload.send = args.send

if isinstance(args, MagicLinkWithEmailArgs):
payload.email = args.email
payload.channel = MagicLinkChannel.EMAIL
elif isinstance(args, MagicLinkWithPhoneArgs):
payload.phone = args.phone
payload.channel = MagicLinkChannel.PHONE
elif isinstance(args, MagicLinkWithUserArgs):
payload.user_id = args.user_id
payload.channel = args.channel
else:
msg = "args must be an instance of MagicLinkArgs"
raise TypeError(msg)

if options:
payload.language = options.language
payload.magic_link_path = options.magic_link_path
payload.redirect_url = options.redirect_url
payload.ttl = options.ttl

try:
return self.magic_links_api.create_magic_link(
self.app_id,
magic_link_req, # type: ignore[arg-type]
payload,
_headers=self.request_headers,
).magic_link
except ApiException as e:
Expand Down
2 changes: 2 additions & 0 deletions passageidentity/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,5 @@
)
from passageidentity.models.update_passkey_auth_method import UpdatePasskeysAuthMethod
from passageidentity.models.update_otp_auth_method import UpdateOtpAuthMethod
from passageidentity.models.magic_link_args import MagicLinkArgs
from passageidentity.models.magic_link_options import MagicLinkOptions
35 changes: 35 additions & 0 deletions passageidentity/models/magic_link_args.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
"""Defines required arguments for creating a Magic Link."""

from typing import Union

from passageidentity.openapi_client.models.magic_link_channel import MagicLinkChannel
from passageidentity.openapi_client.models.magic_link_type import MagicLinkType


class MagicLinkArgsBase:
"""Base class for MagicLinkArgs."""

type: MagicLinkType
send: bool


class MagicLinkWithEmailArgs(MagicLinkArgsBase):
"""Arguments for creating a Magic Link with an email."""

email: str


class MagicLinkWithPhoneArgs(MagicLinkArgsBase):
"""Arguments for creating a Magic Link with a phone number."""

phone: str


class MagicLinkWithUserArgs(MagicLinkArgsBase):
"""Arguments for creating a Magic Link with a user ID."""

user_id: str
channel: MagicLinkChannel


MagicLinkArgs = Union[MagicLinkWithEmailArgs, MagicLinkWithPhoneArgs, MagicLinkWithUserArgs]
12 changes: 12 additions & 0 deletions passageidentity/models/magic_link_options.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
"""Defines options for creating a Magic Link."""

from __future__ import annotations


class MagicLinkOptions:
"""Options for creating a Magic Link."""

language: str | None
magic_link_path: str | None
redirect_url: str | None
ttl: int | None
49 changes: 43 additions & 6 deletions passageidentity/passage.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,32 @@

from __future__ import annotations

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, cast

import typing_extensions

from passageidentity.auth import Auth
from passageidentity.errors import PassageError
from passageidentity.helper import get_auth_token_from_request
from passageidentity.models.magic_link_args import MagicLinkWithEmailArgs, MagicLinkWithPhoneArgs, MagicLinkWithUserArgs
from passageidentity.models.magic_link_options import MagicLinkOptions
from passageidentity.openapi_client.models.magic_link_channel import MagicLinkChannel
from passageidentity.user import User

from .openapi_client.api import (
AppsApi,
)
from .openapi_client.models import (
CreateMagicLinkRequest,
CreateUserRequest,
MagicLinkType,
)

if TYPE_CHECKING:
from requests.sessions import Request

from .openapi_client.models import (
AppInfo,
CreateMagicLinkRequest,
CreateUserRequest,
MagicLinkType,
UpdateUserRequest,
UserInfo,
WebAuthnDevices,
Expand Down Expand Up @@ -104,7 +109,33 @@ def createMagicLink( # noqa: N802
msg = "No Passage API key provided."
raise PassageError(msg)

return self.auth.create_magic_link(magicLinkAttributes) # type: ignore[attr-defined]
magic_link_attrs_dict = (
magicLinkAttributes.to_dict()
if isinstance(magicLinkAttributes, CreateMagicLinkRequest)
else magicLinkAttributes
)

if "email" in magic_link_attrs_dict:
args = MagicLinkWithEmailArgs()
args.email = magic_link_attrs_dict["email"]
elif "phone" in magic_link_attrs_dict:
args = MagicLinkWithPhoneArgs()
args.phone = magic_link_attrs_dict["phone"]
elif "user_id" in magic_link_attrs_dict:
args = MagicLinkWithUserArgs()
args.user_id = magic_link_attrs_dict["user_id"]
args.channel = magic_link_attrs_dict.get("channel") or MagicLinkChannel.EMAIL

args.send = magic_link_attrs_dict.get("send") or False
args.type = magic_link_attrs_dict.get("type") or MagicLinkType.LOGIN

options = MagicLinkOptions()
options.language = magic_link_attrs_dict.get("language")
options.magic_link_path = magic_link_attrs_dict.get("magic_link_path")
options.redirect_url = magic_link_attrs_dict.get("redirect_url")
options.ttl = magic_link_attrs_dict.get("ttl")

return self.auth.create_magic_link(args, options) # type: ignore[attr-defined]

@typing_extensions.deprecated("Passage.getApp() will be removed without replacement.")
def getApp(self) -> AppInfo | PassageError: # noqa: N802
Expand Down Expand Up @@ -212,4 +243,10 @@ def createUser( # noqa: N802
msg = "either phone or email must be provided to create the user"
raise PassageError(msg)

return self.user.create(userAttributes)
user_args = (
cast(CreateUserRequest, CreateUserRequest.from_dict(userAttributes))
if isinstance(userAttributes, dict)
else userAttributes
)

return self.user.create(user_args)
48 changes: 46 additions & 2 deletions passageidentity/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ def __init__(self, app_id: str, request_headers: dict[str, str]) -> None:

def get(self, user_id: str) -> PassageUser:
"""Get a user's object using their user ID."""
if not user_id:
msg = "user_id is required."
raise ValueError(msg)

try:
return self.users_api.get_user(self.app_id, user_id, _headers=self.request_headers).user
except ApiException as e:
Expand All @@ -49,6 +53,10 @@ def get(self, user_id: str) -> PassageUser:

def get_by_identifier(self, identifier: str) -> PassageUser:
"""Get a user's object using their user identifier."""
if not identifier:
msg = "identifier is required."
raise ValueError(msg)

try:
users = self.users_api.list_paginated_users(
self.app_id,
Expand All @@ -68,6 +76,10 @@ def get_by_identifier(self, identifier: str) -> PassageUser:

def activate(self, user_id: str) -> PassageUser:
"""Activate a user using their user ID."""
if not user_id:
msg = "user_id is required."
raise ValueError(msg)

try:
return self.users_api.activate_user(self.app_id, user_id, _headers=self.request_headers).user
except ApiException as e:
Expand All @@ -76,22 +88,34 @@ def activate(self, user_id: str) -> PassageUser:

def deactivate(self, user_id: str) -> PassageUser:
"""Deactivate a user using their user ID."""
if not user_id:
msg = "user_id is required."
raise ValueError(msg)

try:
return self.users_api.deactivate_user(self.app_id, user_id, _headers=self.request_headers).user
except ApiException as e:
msg = "Could not deactivate user"
raise PassageError.from_response_error(e, msg) from e

def update(self, user_id: str, args: UpdateUserArgs) -> PassageUser:
def update(self, user_id: str, options: UpdateUserArgs) -> PassageUser:
"""Update a user."""
if not user_id:
msg = "user_id is required."
raise ValueError(msg)

try:
return self.users_api.update_user(self.app_id, user_id, args, _headers=self.request_headers).user
return self.users_api.update_user(self.app_id, user_id, options, _headers=self.request_headers).user
except ApiException as e:
msg = "Could not update user"
raise PassageError.from_response_error(e, msg) from e

def create(self, args: CreateUserArgs) -> PassageUser:
"""Create a user."""
if not args.email and not args.phone:
msg = "At least one of args.email or args.phone is required."
raise ValueError(msg)

try:
return self.users_api.create_user(self.app_id, args, _headers=self.request_headers).user
except ApiException as e:
Expand All @@ -100,6 +124,10 @@ def create(self, args: CreateUserArgs) -> PassageUser:

def delete(self, user_id: str) -> None:
"""Delete a user using their user ID."""
if not user_id:
msg = "user_id is required."
raise ValueError(msg)

try:
self.users_api.delete_user(self.app_id, user_id, _headers=self.request_headers)
except ApiException as e:
Expand All @@ -108,6 +136,10 @@ def delete(self, user_id: str) -> None:

def list_devices(self, user_id: str) -> list[WebAuthnDevices]:
"""Get a user's devices using their user ID."""
if not user_id:
msg = "user_id is required."
raise ValueError(msg)

try:
return self.user_devices_api.list_user_devices(self.app_id, user_id, _headers=self.request_headers).devices
except ApiException as e:
Expand All @@ -116,6 +148,14 @@ def list_devices(self, user_id: str) -> list[WebAuthnDevices]:

def revoke_device(self, user_id: str, device_id: str) -> None:
"""Revoke a user's device using their user ID and the device ID."""
if not user_id:
msg = "user_id is required."
raise ValueError(msg)

if not device_id:
msg = "device_id is required."
raise ValueError(msg)

try:
self.user_devices_api.delete_user_devices(self.app_id, user_id, device_id, _headers=self.request_headers)
except ApiException as e:
Expand All @@ -124,6 +164,10 @@ def revoke_device(self, user_id: str, device_id: str) -> None:

def revoke_refresh_tokens(self, user_id: str) -> None:
"""Revokes all of a user's Refresh Tokens using their User ID."""
if not user_id:
msg = "user_id is required."
raise ValueError(msg)

try:
self.tokens_api.revoke_user_refresh_tokens(self.app_id, user_id, _headers=self.request_headers)
except ApiException as e:
Expand Down
Loading

0 comments on commit 83cd628

Please sign in to comment.