From 8bda8194a00da4aa0efd76646993d6a70e3902c4 Mon Sep 17 00:00:00 2001 From: brassy endomorph Date: Wed, 25 Sep 2024 07:56:19 +0000 Subject: [PATCH] added migration testing for existing migrations --- hushline/__init__.py | 6 +- hushline/db.py | 2 + migrations/env.py | 7 +- pyproject.toml | 76 ++-- tests/conftest.py | 11 +- tests/helpers.py | 52 +++ tests/migrations/revision_46aedec8fd9b.py | 497 ++++++++++++++++++++++ tests/test_migrations.py | 86 ++++ 8 files changed, 679 insertions(+), 58 deletions(-) create mode 100644 tests/helpers.py create mode 100644 tests/migrations/revision_46aedec8fd9b.py create mode 100644 tests/test_migrations.py diff --git a/hushline/__init__.py b/hushline/__init__.py index 912a5571..519d9fd9 100644 --- a/hushline/__init__.py +++ b/hushline/__init__.py @@ -4,13 +4,12 @@ from typing import Any from flask import Flask, flash, redirect, request, session, url_for -from flask_migrate import Migrate from jinja2 import StrictUndefined from werkzeug.middleware.proxy_fix import ProxyFix from werkzeug.wrappers.response import Response from . import admin, routes, settings -from .db import db +from .db import db, migrate from .model import User from .version import __version__ @@ -70,9 +69,8 @@ def create_app() -> Flask: app.logger.info("Development environment detected, enabling jinja2.StrictUndefined") app.jinja_env.undefined = StrictUndefined - # Run migrations db.init_app(app) - Migrate(app, db) + migrate.init_app(app, db) routes.init_app(app) for module in [admin, settings]: diff --git a/hushline/db.py b/hushline/db.py index bc079ed2..0b8e290d 100644 --- a/hushline/db.py +++ b/hushline/db.py @@ -1,3 +1,4 @@ +from flask_migrate import Migrate from flask_sqlalchemy import SQLAlchemy from sqlalchemy import MetaData @@ -12,3 +13,4 @@ ) db = SQLAlchemy(metadata=metadata) +migrate = Migrate() diff --git a/migrations/env.py b/migrations/env.py index c4c00993..8995f659 100644 --- a/migrations/env.py +++ b/migrations/env.py @@ -15,12 +15,7 @@ def get_engine(): - try: - # this works with Flask-SQLAlchemy<3 and Alchemical - return current_app.extensions["migrate"].db.get_engine() - except (TypeError, AttributeError): - # this works with Flask-SQLAlchemy>=3 - return current_app.extensions["migrate"].db.engine + return current_app.extensions["migrate"].db.engine def get_engine_url(): diff --git a/pyproject.toml b/pyproject.toml index 96047116..519f4ccf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,6 +42,10 @@ build-backend = "poetry.core.masonry.api" [tool.pytest.ini_options] pythonpath = "." +filterwarnings = [ + # passlib unmaintianed, see: https://github.com/scidsg/hushline/issues/553 + "ignore:.*'crypt' is deprecated.*:DeprecationWarning", +] [tool.ruff] line-length = 100 @@ -49,56 +53,40 @@ indent-width = 4 [tool.ruff.lint] select = [ - # pycodestyle errors - "E", - # pyflakes - "F", - # isort - "I", - # flake8-gettext - "INT", - # flake8-pie - "PIE", - # pylint - "PL", - # flake8-pytest-style - "PT", - # flake8-pyi - "PYI", - # flake8-return - "RET", - # flake8-bandit - "S", - # flake8-simplify - "SIM", - # pyupgrade - "UP", - # pycodestyle warnings - "W", - # Unused noqa directive - "RUF100", + "E", # pycodestyle errors + "F", # pyflakes + "I", # isort + "INT", # flake8-gettext + "PIE", # flake8-pie + "PL", # pylint + "PT", # flake8-pytest-style + "PYI", # flake8-pyi + "RET", # flake8-return + "S", # flake8-bandit + "SIM", # flake8-simplify + "UP", # pyupgrade + "W", # pycodestyle warnings + "RUF100", # Unused noqa directive ] ignore = [ - # https://docs.astral.sh/ruff/rules/too-many-statements/ - "PLR0915", - # https://docs.astral.sh/ruff/rules/too-many-return-statements/ - "PLR0911", - # https://docs.astral.sh/ruff/rules/too-many-branches/ - "PLR0912", + "PLR0911", # too-many-return-statements + "PLR0912", # too-many-branches + "PLR0915", # too-many-statements ] [tool.ruff.lint.per-file-ignores] +"migrations/versions/*.py" = [ + "I001", # unsorted-imports + "S608", # hardcoded-sql-expression +] "tests/*.py" = [ - # https://docs.astral.sh/ruff/rules/assert/ - "S101", - "S105", # hardcoded password - # https://docs.astral.sh/ruff/rules/magic-value-comparison/ - "PLR2004", + "PLR2004", # magic-value-comparison + "S101", # assert + "S105", # hardcoded-password-string + "S311", # suspicious-non-cryptographic-random-usage ] -"migrations/versions/*.py" = [ - # https://docs.astral.sh/ruff/rules/unsorted-imports/ - "I001", - "S608", # sql injection via string-based query construction +"tests/migrations/*.py" = [ + "S608", # hardcoded-sql-expression ] [tool.mypy] @@ -108,4 +96,4 @@ no_implicit_optional = true disallow_untyped_defs = true disallow_incomplete_defs = true warn_unused_configs = true -exclude = '^migrations/env\.py$|^migrations/versions/.*\.py$' +exclude = "^migrations/env\\.py$" diff --git a/tests/conftest.py b/tests/conftest.py index 8f4ee069..e1d6fdf5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -39,7 +39,7 @@ def pytest_addoption(parser: Parser) -> None: def random_name(size: int) -> str: - return "".join([random.choice(string.ascii_lowercase) for _ in range(size)]) # noqa: S311 + return "".join([random.choice(string.ascii_lowercase) for _ in range(size)]) @contextmanager @@ -112,8 +112,7 @@ def init_db_via_alembic(db_uri: str) -> None: @pytest.fixture() -def database(_db_template: None) -> str: - """A clean Postgres database from the template with DDLs applied""" +def database(request: pytest.FixtureRequest, _db_template: None) -> str: db_name = random_name(16) conn_str = CONN_FMT_STR.format(database="hushline") engine = create_engine(conn_str) @@ -121,7 +120,11 @@ def database(_db_template: None) -> str: session = sessionmaker(bind=engine)() - sql = text(f"CREATE DATABASE {db_name} WITH TEMPLATE {TEMPLATE_DB_NAME}") + if request.module.__name__ == "test_migrations": + # don't use the template when testing migrations. we want a blank db + sql = text(f"CREATE DATABASE {db_name}") + else: + sql = text(f"CREATE DATABASE {db_name} WITH TEMPLATE {TEMPLATE_DB_NAME}") session.execute(sql) # aggressively terminate all connections diff --git a/tests/helpers.py b/tests/helpers.py new file mode 100644 index 00000000..61c8f503 --- /dev/null +++ b/tests/helpers.py @@ -0,0 +1,52 @@ +import random +import string +from typing import Any, Callable, Mapping, Optional, Sequence, Tuple, TypeVar + +T = TypeVar("T") + + +def one_of(xs: Sequence[T], predicate: Callable[[T], bool]) -> T: + matches = [x for x in xs if predicate(x)] + match len(matches): + case 1: + return matches[0] + case 0: + raise ValueError("No matches") + case _: + raise ValueError(f"Too many matches: {matches}") + + +def random_bool() -> bool: + return bool(random.getrandbits(1)) + + +def random_optional_bool() -> Optional[bool]: + if random_bool(): + return None + return random_bool() + + +def random_string(length: int) -> str: + return "".join(random.choice(string.ascii_lowercase) for _ in range(length)) + + +def random_optional_string(length: int) -> Optional[str]: + if random_bool(): + return None + return random_string(length) + + +def format_param_dict(params: Mapping[str, Any]) -> Tuple[str, str]: + return (", ".join(params.keys()), ", ".join(f":{x}" for x in params)) + + +class Missing: + def __eq__(self, other: object) -> bool: + return False + + def __ne__(self, other: object) -> bool: + return True + + +assert Missing() != Missing() +assert Missing() != Missing() diff --git a/tests/migrations/revision_46aedec8fd9b.py b/tests/migrations/revision_46aedec8fd9b.py new file mode 100644 index 00000000..05dc0bcc --- /dev/null +++ b/tests/migrations/revision_46aedec8fd9b.py @@ -0,0 +1,497 @@ +import random +from collections import defaultdict +from dataclasses import dataclass +from typing import Any, Dict, List, Optional + +from hushline.db import db + +from ..helpers import ( + Missing, + format_param_dict, + one_of, + random_bool, + random_optional_bool, + random_optional_string, + random_string, +) + + +@dataclass(frozen=True) +class OldSecondaryUser: + id: int + user_id: int + username: str + display_name: Optional[str] + + +@dataclass(frozen=True) +class OldMessage: + id: int + user_id: int + secondary_user_id: Optional[int] + content: str + + +@dataclass(frozen=True) +class OldUser: + id: int + primary_username: str + display_name: Optional[str] + bio: Optional[str] + show_in_directory: Optional[bool] + is_admin: bool + is_verified: bool + + secondary_usernames: List[OldSecondaryUser] + messages: List[OldMessage] + + extra_field_label1: Optional[str] = None + extra_field_label2: Optional[str] = None + extra_field_label3: Optional[str] = None + extra_field_label4: Optional[str] = None + + extra_field_value1: Optional[str] = None + extra_field_value2: Optional[str] = None + extra_field_value3: Optional[str] = None + extra_field_value4: Optional[str] = None + + extra_field_verified1: Optional[bool] = None + extra_field_verified2: Optional[bool] = None + extra_field_verified3: Optional[bool] = None + extra_field_verified4: Optional[bool] = None + + +@dataclass(frozen=True) +class NewMessage: + id: int + username_id: int + content: str + + +@dataclass(frozen=True) +class NewUsername: + id: int + user_id: int + username: str + display_name: Optional[str] + bio: Optional[str] + show_in_directory: Optional[bool] + is_primary: bool + is_verified: bool + + messages: List[NewMessage] + + extra_field_label1: Optional[str] = None + extra_field_label2: Optional[str] = None + extra_field_label3: Optional[str] = None + extra_field_label4: Optional[str] = None + + extra_field_value1: Optional[str] = None + extra_field_value2: Optional[str] = None + extra_field_value3: Optional[str] = None + extra_field_value4: Optional[str] = None + + extra_field_verified1: Optional[bool] = None + extra_field_verified2: Optional[bool] = None + extra_field_verified3: Optional[bool] = None + extra_field_verified4: Optional[bool] = None + + +@dataclass(frozen=True) +class NewUser: + id: int + is_admin: bool + + usernames: List[NewUsername] + + +class UpgradeTester: + def __init__(self) -> None: + self.old_users: List[OldUser] = [] + + def load_data(self) -> None: + for user_idx in range(12): + user_params: Dict[str, Any] = { + "primary_username": f"user_{random_string(10)}", + "display_name": random_optional_string(10), + "is_verified": random_bool(), + "is_admin": random_bool(), + "show_in_directory": random_bool(), + "bio": random_optional_string(10), + } + + for i in range(1, 5): + if random_bool(): + user_params[f"extra_field_label{i}"] = random_string(10) + user_params[f"extra_field_value{i}"] = random_string(10) + user_params[f"extra_field_verified{i}"] = random_optional_bool() + + columns, param_args = format_param_dict(user_params) + result = list( + db.session.execute( + db.text( + f""" + INSERT INTO users (password_hash, {columns}) + VALUES ('$scrypt$', {param_args}) + RETURNING id + """ + ), + user_params, + ), + ) + + user_id = result[0][0] + secondary_usernames = [] + messages = [] + + # make 0, 1, or 2 secondary usernames + for second_idx in range(user_idx % 3): + secondary_params: Dict[str, Any] = { + "user_id": user_id, + "username": random_string(10), + "display_name": random_optional_string(10), + } + columns, param_args = format_param_dict(secondary_params) + result = list( + db.session.execute( + db.text( + f""" + INSERT INTO secondary_usernames ({columns}) + VALUES ({param_args}) + RETURNING id + """ + ), + secondary_params, + ) + ) + secondary_usernames.append(OldSecondaryUser(id=result[0][0], **secondary_params)) + + for _ in range(10): + msg_params: Dict[str, Any] = { + "content": random_string(10), + "secondary_user_id": random.choice(secondary_usernames).id + if secondary_usernames and random_bool() + else None, + } + columns, param_args = format_param_dict(msg_params) + result = list( + db.session.execute( + db.text( + f""" + INSERT INTO message (user_id, {columns}) + VALUES (:user_id, {param_args}) + RETURNING id + """ + ), + params=dict(user_id=user_id, **msg_params), + ) + ) + messages.append(OldMessage(id=result[0][0], user_id=user_id, **msg_params)) + + db.session.commit() + self.old_users.append( + OldUser( + id=user_id, + secondary_usernames=secondary_usernames, + messages=messages, + **user_params, + ) + ) + + assert self.old_users # sensiblity check + + def check_upgrade(self) -> None: + messages_by_username_id = defaultdict(list) + results = db.session.execute(db.text("SELECT * FROM message")) + for result in results: + result_dict = result._asdict() + msg = NewMessage(**result_dict) + messages_by_username_id[msg.username_id].append(msg) + + usernames_by_user_id = defaultdict(list) + results = db.session.execute(db.text("SELECT * FROM usernames")) + for result in results: + result_dict = result._asdict() + username = NewUsername( + **result_dict, messages=messages_by_username_id[result_dict["id"]] + ) + usernames_by_user_id[username.user_id].append(username) + + new_users = [] + results = db.session.execute(db.text("SELECT id, is_admin FROM users")) + for result in results: + result_dict = result._asdict() + user_id = result_dict["id"] + + messages = [] + for username in usernames_by_user_id[user_id]: + messages.extend(messages_by_username_id[username.id]) + + user = NewUser( + id=user_id, + is_admin=result_dict["is_admin"], + usernames=usernames_by_user_id[user_id], + ) + new_users.append(user) + + # sensible quick checks first: + # users equal + assert len(new_users) == len(self.old_users) + # usernames = users + secondaries + assert sum(len(x) for x in usernames_by_user_id.values()) == len(self.old_users) + sum( + len(x.secondary_usernames) for x in self.old_users + ) + # messages equal + assert sum(len(y.messages) for x in new_users for y in x.usernames) == sum( + len(x.messages) for x in self.old_users + ) + + for old_user in self.old_users: + new_user = one_of(new_users, lambda x: x.id == old_user.id) + new_username = one_of(usernames_by_user_id[old_user.id], lambda x: x.is_primary) + assert new_username.user_id == old_user.id + assert new_username.username == old_user.primary_username + + attrs = ["bio", "is_verified", "show_in_directory", "display_name"] + for i in range(1, 5): + attrs.append(f"extra_field_label{i}") + attrs.append(f"extra_field_value{i}") + attrs.append(f"extra_field_verified{i}") + + for attr in attrs: + assert getattr(new_username, attr, Missing()) == getattr(old_user, attr, Missing()) + + # check that all secondary usernames transferred + new_secondaries = [x for x in usernames_by_user_id[old_user.id] if not x.is_primary] + for old_secondary in old_user.secondary_usernames: + new_secondary = one_of( + new_secondaries, + lambda x: x.username == old_secondary.username and not x.is_primary, + ) + + for attr in ["username", "display_name"]: + assert getattr(new_secondary, attr, Missing()) == getattr( + old_secondary, attr, Missing() + ) + + # check that all messages updated correctly + for old_message in old_user.messages: + new_message_matches = [ + y for x in new_user.usernames for y in x.messages if y.id == old_message.id + ] + assert len(new_message_matches) == 1 + new_message = new_message_matches[0] + + assert new_message.username_id in [x.id for x in new_user.usernames] + assert new_message.content == old_message.content + if old_message.secondary_user_id: + assert one_of( + old_user.secondary_usernames, + lambda x: x.id == old_message.secondary_user_id, + ) + + +class DowngradeTester: + def __init__(self) -> None: + self.new_users: List[NewUser] = [] + + def load_data(self) -> None: + for user_idx in range(12): + usernames = [] + + user_params: Dict[str, Any] = { + "is_admin": random_bool(), + } + + columns, param_args = format_param_dict(user_params) + result = list( + db.session.execute( + db.text( + f""" + INSERT INTO users (password_hash, {columns}) + VALUES ('$scrypt$', {param_args}) + RETURNING id + """ + ), + user_params, + ), + ) + + user_id = result[0][0] + + # make 1 primary and 0, 1, or 2 aliases + for username_idx in range(user_idx % 3 + 1): + messages = [] + + username_params: Dict[str, Any] = { + "user_id": user_id, + "username": random_string(20), + "display_name": random_optional_string(10), + "is_primary": username_idx == 0, + "is_verified": random_bool(), + "show_in_directory": random_bool(), + "bio": random_optional_string(10), + } + + for i in range(1, 5): + if random_bool(): + username_params[f"extra_field_label{i}"] = random_string(10) + username_params[f"extra_field_value{i}"] = random_string(10) + username_params[f"extra_field_verified{i}"] = random_optional_bool() + + columns, param_args = format_param_dict(username_params) + result = list( + db.session.execute( + db.text( + f""" + INSERT INTO usernames ({columns}) + VALUES ({param_args}) + RETURNING id + """ + ), + username_params, + ), + ) + + username_params.pop("user_id") + username_id = result[0][0] + + for _ in range(5): + msg_params: Dict[str, Any] = { + "username_id": username_id, + "content": random_string(10), + } + + columns, param_args = format_param_dict(msg_params) + result = list( + db.session.execute( + db.text( + f""" + INSERT INTO message ({columns}) + VALUES ({param_args}) + RETURNING id + """ + ), + msg_params, + ) + ) + + messages.append( + NewMessage( + id=result[0][0], + username_id=username_id, + content=msg_params["content"], + ) + ) + + usernames.append( + NewUsername( + id=username_id, + user_id=user_id, + **username_params, + messages=messages, + ) + ) + + self.new_users.append( + NewUser( + id=user_id, + is_admin=user_params["is_admin"], + usernames=usernames, + ), + ) + + db.session.commit() + + def check_downgrade(self) -> None: + assert self.new_users + + old_secondaries_by_user_id = defaultdict(list) + results = db.session.execute(db.text("SELECT * FROM secondary_usernames")) + for result in results: + result_dict = result._asdict() + old_secondaries_by_user_id[result_dict["user_id"]].append( + OldSecondaryUser(**result_dict) + ) + + old_messages_by_user_id = defaultdict(list) + results = db.session.execute(db.text("SELECT * FROM message")) + for result in results: + result_dict = result._asdict() + old_messages_by_user_id[result_dict["user_id"]].append(OldMessage(**result_dict)) + + old_users = [] + results = db.session.execute(db.text("SELECT * from users")) + skip_keys = ["password_hash", "totp_secret", "email", "pgp_key"] + for result in results: + result_dict = { + k: v + for k, v in result._asdict().items() + if k not in skip_keys and not k.startswith("smtp_") + } + old_users.append( + OldUser( + secondary_usernames=old_secondaries_by_user_id[result_dict["id"]], + messages=old_messages_by_user_id[result_dict["id"]], + **result_dict, + ) + ) + + # sensibility checks first: + # users equal + assert len(old_users) == len(self.new_users) + # users + secondaries = usernames + assert len(old_users) + sum(len(x.secondary_usernames) for x in old_users) == sum( + len(x.usernames) for x in self.new_users + ) + # messages equal + assert sum(len(x.messages) for x in old_users) == sum( + len(y.messages) for x in self.new_users for y in x.usernames + ) + + for new_user in self.new_users: + old_user = one_of(old_users, lambda x: x.id == new_user.id) + + for new_username in new_user.usernames: + if new_username.is_primary: + # only primary usernames retain their fields + attrs = ["bio", "is_verified", "show_in_directory", "display_name"] + for i in range(1, 5): + attrs.append(f"extra_field_label{i}") + attrs.append(f"extra_field_value{i}") + attrs.append(f"extra_field_verified{i}") + + for attr in attrs: + assert getattr(old_user, attr, Missing()) == getattr( + new_username, attr, Missing() + ) + else: + # only secondary usernames will have a match in the downgraded + # secondary_usernames table + old_secondary = one_of( + old_user.secondary_usernames, lambda x: x.username == new_username.username + ) + assert old_secondary.user_id == new_username.user_id + assert old_secondary.display_name == new_username.display_name + + for new_msg in new_username.messages: + old_msg = one_of(old_user.messages, lambda x: x.id == new_msg.id) + assert old_msg.content == new_msg.content + assert old_msg.user_id == old_user.id + + if new_username.is_primary: + assert old_msg.secondary_user_id is None + else: + # inserts back into the secondary_usernames table aren't deterministic, + # so we can't rely on user_id's being equal. only usernames. + assert ( + len( + [ + y.username + for x in self.new_users + for y in x.usernames + if y.username == old_secondary.username + ] + ) + == 1 + ) diff --git a/tests/test_migrations.py b/tests/test_migrations.py new file mode 100644 index 00000000..3e794437 --- /dev/null +++ b/tests/test_migrations.py @@ -0,0 +1,86 @@ +""" +This module dynamically generates test cases from the revisions directory. +To create new test modules, look at the "revision_tests" directory for examples. +""" + +from pathlib import Path +from typing import Sequence + +import pytest +from alembic import command +from alembic.script import ScriptDirectory +from flask import Flask + +from hushline.db import db, migrate + +REVISIONS_ROOT = Path(__file__).parent.parent / "migrations" +assert REVISIONS_ROOT.exists() +assert REVISIONS_ROOT.is_dir() + + +FIRST_TESTABLE_REVISION = "46aedec8fd9b" +SKIPPABLE_REVISIONS = [ + "5ffe5a5c8e9a", # only renames indices and tables, no data changed +] + + +def list_revisions() -> Sequence[str]: + script_dir = ScriptDirectory(REVISIONS_ROOT) + revisions = list(script_dir.walk_revisions()) + revisions.reverse() + return [x.module.revision for x in revisions] + + +def list_testable_revisions() -> Sequence[str]: + idx = ALL_REVISIONS.index(FIRST_TESTABLE_REVISION) + assert idx >= 0 + return [rev for rev in ALL_REVISIONS[idx:] if rev not in SKIPPABLE_REVISIONS] + + +ALL_REVISIONS: Sequence[str] = list_revisions() +TESTABLE_REVISIONS: Sequence[str] = list_testable_revisions() + +del list_revisions, list_testable_revisions + + +def test_linear_revision_history(app: Flask) -> None: + script_dir = ScriptDirectory.from_config(migrate.get_config()) + + bases = script_dir.get_bases() + assert len(bases) == 1, f"Multiple bases found: {bases}" + assert bases[0] == ALL_REVISIONS[0] + + heads = script_dir.get_heads() + assert len(heads) == 1, f"Multiple heads found: {heads}" + assert heads[0] == ALL_REVISIONS[-1] + + +@pytest.mark.parametrize("revision", TESTABLE_REVISIONS) +def test_upgrade_with_data(revision: str, app: Flask) -> None: + previous_revision = ALL_REVISIONS[ALL_REVISIONS.index(revision) - 2] + cfg = migrate.get_config() + command.upgrade(cfg, previous_revision) + + mod = __import__(f"tests.migrations.revision_{revision}", fromlist=["UpgradeTester"]) + upgrade_tester = mod.UpgradeTester() + + upgrade_tester.load_data() + db.session.close() + + command.upgrade(cfg, revision) + upgrade_tester.check_upgrade() + + +@pytest.mark.parametrize("revision", TESTABLE_REVISIONS) +def test_downgrade_with_data(revision: str, app: Flask) -> None: + cfg = migrate.get_config() + command.upgrade(cfg, revision) + + mod = __import__(f"tests.migrations.revision_{revision}", fromlist=["DowngradeTester"]) + downgrade_tester = mod.DowngradeTester() + + downgrade_tester.load_data() + db.session.close() + + command.downgrade(cfg, "-1") + downgrade_tester.check_downgrade()