Skip to content

Commit

Permalink
✨ New product-bound login: user must have access to target product 🚨 (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
pcrespov authored Jan 28, 2024
1 parent 3693767 commit 4fa4917
Show file tree
Hide file tree
Showing 38 changed files with 823 additions and 384 deletions.
4 changes: 2 additions & 2 deletions api/specs/web-server/_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@
from models_library.generics import Envelope
from pydantic import BaseModel, Field, confloat
from simcore_service_webserver._meta import API_VTAG
from simcore_service_webserver.login.handlers_2fa import Resend2faBody
from simcore_service_webserver.login.handlers_auth import (
from simcore_service_webserver.login._auth_handlers import (
LoginBody,
LoginNextPage,
LoginTwoFactorAuthBody,
LogoutBody,
)
from simcore_service_webserver.login.handlers_2fa import Resend2faBody
from simcore_service_webserver.login.handlers_change import (
ChangeEmailBody,
ChangePasswordBody,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ def _compute_hash(password: str) -> str:
return hashlib.sha224(password.encode("ascii")).hexdigest()


DEFAULT_PASSWORD = "password-with-at-least-12-characters"
_DEFAULT_HASH = _compute_hash(DEFAULT_PASSWORD)
DEFAULT_TEST_PASSWORD = "password-with-at-least-12-characters" # noqa: S105
_DEFAULT_HASH = _compute_hash(DEFAULT_TEST_PASSWORD)


def random_user(
Expand Down
91 changes: 71 additions & 20 deletions packages/pytest-simcore/src/pytest_simcore/helpers/utils_login.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
import re
from datetime import datetime
from typing import TypedDict
from typing import Any, TypedDict

from aiohttp import web
from aiohttp.test_utils import TestClient
from models_library.users import UserID
from simcore_service_webserver.db.models import UserRole, UserStatus
from simcore_service_webserver.groups.api import auto_add_user_to_product_group
from simcore_service_webserver.login._constants import MSG_LOGGED_IN
from simcore_service_webserver.login._registration import create_invitation_token
from simcore_service_webserver.login.storage import AsyncpgStorage, get_plugin_storage
from simcore_service_webserver.products.api import list_products
from simcore_service_webserver.security.api import clean_auth_policy_cache
from yarl import URL

from .rawdata_fakers import DEFAULT_FAKER, DEFAULT_PASSWORD, random_user
from .rawdata_fakers import DEFAULT_FAKER, DEFAULT_TEST_PASSWORD, random_user
from .utils_assert import assert_status


Expand Down Expand Up @@ -55,31 +58,73 @@ def parse_link(text):
return URL(link).path


async def _insert_fake_user(db: AsyncpgStorage, data=None) -> UserInfoDict:
"""Creates a fake user and inserts it in the users table in the database"""
async def _create_user(app: web.Application, data=None) -> UserInfoDict:
db: AsyncpgStorage = get_plugin_storage(app)

# create
data = data or {}
data.setdefault(
"password", DEFAULT_PASSWORD
) # Password must be at least 12 characters long
data.setdefault("status", UserStatus.ACTIVE.name)
data.setdefault("role", UserRole.USER.name)
params = random_user(**data)
data.setdefault("password", DEFAULT_TEST_PASSWORD)
user = await db.create_user(random_user(**data))

# get
user = await db.get_user({"id": user["id"]})
assert "first_name" in user
assert "last_name" in user

# adds extras
extras = {"raw_password": data["password"]}

return UserInfoDict(
**{
key: user[key]
for key in [
"id",
"name",
"email",
"primary_gid",
"status",
"role",
"created_at",
"password_hash",
"first_name",
"last_name",
]
},
**extras,
)


async def _register_user_in_default_product(app: web.Application, user_id: UserID):
products = list_products(app)
assert products
product_name = products[0].name

return await auto_add_user_to_product_group(app, user_id, product_name=product_name)


user = await db.create_user(params)
user["raw_password"] = data["password"]
user.setdefault("first_name", None)
user.setdefault("last_name", None)
async def _create_account(
app: web.Application,
user_data: dict[str, Any] | None = None,
) -> UserInfoDict:
# users, groups in db
user = await _create_user(app, user_data)
# user has default product
await _register_user_in_default_product(app, user_id=user["id"])
return user


async def log_client_in(
client: TestClient, user_data=None, *, enable_check=True
client: TestClient,
user_data: dict[str, Any] | None = None,
*,
enable_check=True,
) -> UserInfoDict:
# creates user directly in db
assert client.app
db: AsyncpgStorage = get_plugin_storage(client.app)

user = await _insert_fake_user(db, user_data)
# create account
user = await _create_account(client.app, user_data=user_data)

# login
url = client.app.router["auth_login"].url_for()
Expand All @@ -98,14 +143,19 @@ async def log_client_in(


class NewUser:
def __init__(self, params=None, app: web.Application | None = None):
def __init__(
self,
params: dict[str, Any] | None = None,
app: web.Application | None = None,
):
self.params = params
self.user = None
assert app
self.db = get_plugin_storage(app)
self.app = app

async def __aenter__(self) -> UserInfoDict:
self.user = await _insert_fake_user(self.db, self.params)
self.user = await _create_account(self.app, self.params)
return self.user

async def __aexit__(self, *args):
Expand All @@ -117,6 +167,7 @@ def __init__(self, client: TestClient, params=None, *, check_if_succeeds=True):
super().__init__(params, client.app)
self.client = client
self.enable_check = check_if_succeeds
assert self.client.app

async def __aenter__(self) -> UserInfoDict:
self.user = await log_client_in(
Expand All @@ -125,6 +176,7 @@ async def __aenter__(self) -> UserInfoDict:
return self.user

async def __aexit__(self, *args):
assert self.client.app
# NOTE: cache key is based on an email. If the email is
# reused during the test, then it creates quite some noise
await clean_auth_policy_cache(self.client.app)
Expand All @@ -151,8 +203,7 @@ def __init__(
async def __aenter__(self) -> "NewInvitation":
# creates host user
assert self.client.app
db: AsyncpgStorage = get_plugin_storage(self.client.app)
self.user = await _insert_fake_user(db, self.params)
self.user = await _create_user(self.client.app, self.params)

self.confirmation = await create_invitation_token(
self.db,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from simcore_postgres_database.utils_users import UsersRepo

from ..db.plugin import get_database_engine
from ..groups.api import is_user_by_email_in_group
from ..products.api import Product
from ..security.api import check_password, encrypt_password
from ._constants import MSG_UNKNOWN_EMAIL, MSG_WRONG_PASSWORD
Expand All @@ -25,7 +26,7 @@ async def create_user(
email: str,
password: str,
status: UserStatus,
expires_at: datetime | None
expires_at: datetime | None,
) -> dict:

async with get_database_engine(app).acquire() as conn:
Expand All @@ -39,7 +40,7 @@ async def create_user(
return dict(user.items())


async def check_authorized_user_or_raise(
async def check_authorized_user_credentials_or_raise(
user: dict,
password: str,
product: Product,
Expand All @@ -56,5 +57,23 @@ async def check_authorized_user_or_raise(
raise web.HTTPUnauthorized(
reason=MSG_WRONG_PASSWORD, content_type=MIMETYPE_APPLICATION_JSON
)

return user


async def check_authorized_user_in_product_or_raise(
app: web.Application,
*,
user: dict,
product: Product,
) -> None:
"""Checks whether user is registered in this product"""
email = user.get("email", "").lower()
product_group_id = product.group_id
assert product_group_id is not None # nosec

if product_group_id is not None and not await is_user_by_email_in_group(
app, user_email=email, group_id=product_group_id
):
raise web.HTTPUnauthorized(
reason=MSG_UNKNOWN_EMAIL, content_type=MIMETYPE_APPLICATION_JSON
)
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,11 @@
mask_phone_number,
send_sms_code,
)
from ._auth_api import check_authorized_user_or_raise, get_user_by_email
from ._auth_api import (
check_authorized_user_credentials_or_raise,
check_authorized_user_in_product_or_raise,
get_user_by_email,
)
from ._constants import (
CODE_2FA_CODE_REQUIRED,
CODE_PHONE_NUMBER_REQUIRED,
Expand Down Expand Up @@ -92,18 +96,23 @@ async def login(request: web.Request):
)
login_ = await parse_request_body_as(LoginBody, request)

user = await check_authorized_user_or_raise(
# auth user and has access to product
user = await check_authorized_user_credentials_or_raise(
user=await get_user_by_email(request.app, email=login_.email),
password=login_.password.get_secret_value(),
product=product,
)
await check_authorized_user_in_product_or_raise(
request.app, user=user, product=product
)

# Some roles have login privileges
skip_2fa: bool = UserRole(user["role"]) == UserRole.TESTER
if skip_2fa or not settings.LOGIN_2FA_REQUIRED:
return await login_granted_response(request, user=user)

# no phone
# 2FA login (continuation)
# check phone
if not user["phone"]:
return envelope_response(
# LoginNextPage
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from ._constants import MSG_LOGGED_IN
from .utils import flash_response

log = logging.getLogger(__name__)
_logger = logging.getLogger(__name__)


async def login_granted_response(
Expand All @@ -29,7 +29,7 @@ async def login_granted_response(
user_id = user.get("id")

with log_context(
log,
_logger,
logging.INFO,
"login of user_id=%s with %s",
f"{user_id}",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@
import inspect

from aiohttp import web
from aiohttp_security.api import check_authorized
from servicelib.aiohttp.typing_extension import HandlerAnyReturn
from servicelib.request_keys import RQT_USERID_KEY

from ..products.api import get_product_name
from ..security.api import AuthContextDict, check_user_authorized, check_user_permission


def login_required(handler: HandlerAnyReturn) -> HandlerAnyReturn:
"""Decorator that restrict access only for authorized users
"""Decorator that restrict access only for authorized users with permissions to access a given product
- User is considered authorized if check_authorized(request) raises no exception
- If authorized, it injects user_id in request[RQT_USERID_KEY]
Expand Down Expand Up @@ -42,12 +44,23 @@ async def get_foo(request: web.Request):
async def _wrapper(request: web.Request):
"""
Raises:
HTTPUnauthorized: if request authorization check fails
HTTPUnauthorized: if unauthorized user
HTTPForbidden: if user not allowed in product
"""
# WARNING: note that check_authorized is patched in some tests.
# Careful when changing the function signature
request[RQT_USERID_KEY] = await check_authorized(request)
user_id = await check_user_authorized(request)

await check_user_permission(
request,
"product",
context=AuthContextDict(
product_name=get_product_name(request),
authorized_uid=user_id,
),
)

request[RQT_USERID_KEY] = user_id
return await handler(request)

return _wrapper
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ async def register(request: web.Request):
# get authorized user or create new
user = await _auth_api.get_user_by_email(request.app, email=registration.email)
if user:
await _auth_api.check_authorized_user_or_raise(
await _auth_api.check_authorized_user_credentials_or_raise(
user,
password=registration.password.get_secret_value(),
product=product,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@
from ..redis import setup_redis
from ..rest.plugin import setup_rest
from . import (
_auth_handlers,
_registration_handlers,
handlers_2fa,
handlers_auth,
handlers_change,
handlers_confirmation,
handlers_registration,
Expand Down Expand Up @@ -157,7 +157,7 @@ def setup_login(app: web.Application):

# routes

app.router.add_routes(handlers_auth.routes)
app.router.add_routes(_auth_handlers.routes)
app.router.add_routes(handlers_confirmation.routes)
app.router.add_routes(handlers_registration.routes)
app.router.add_routes(_registration_handlers.routes)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ async def discover_product_middleware(request: web.Request, handler: Handler):
if (
request.path.startswith(f"/{API_VTAG}")
or request.path == "/static-frontend-data.json"
or request.path == "/socket.io/"
):
product_name = (
_discover_product_by_request_header(request)
Expand All @@ -69,10 +70,8 @@ async def discover_product_middleware(request: web.Request, handler: Handler):

request[RQ_PRODUCT_KEY] = product_name

assert ( # nosec
request.get(RQ_PRODUCT_KEY) is not None
or request.path == "/socket.io/"
or request.path.startswith("/dev/doc")
assert request.get(RQ_PRODUCT_KEY) is not None or request.path.startswith( # nosec
"/dev/doc"
)

return await handler(request)
Loading

0 comments on commit 4fa4917

Please sign in to comment.