Skip to content

Commit

Permalink
added migration testing for existing migrations
Browse files Browse the repository at this point in the history
  • Loading branch information
brassy-endomorph committed Oct 5, 2024
1 parent 5ca1c7a commit 2ea86e8
Show file tree
Hide file tree
Showing 7 changed files with 680 additions and 52 deletions.
6 changes: 2 additions & 4 deletions hushline/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,13 @@

from flask import Flask, flash, redirect, request, session, url_for
from flask.cli import AppGroup
from flask_migrate import Migrate
from jinja2 import StrictUndefined
from sqlalchemy.exc import ProgrammingError
from werkzeug.middleware.proxy_fix import ProxyFix
from werkzeug.wrappers.response import Response

from . import admin, premium, routes, settings
from .db import db
from .db import db, migrate
from .model import HostOrganization, Tier, User
from .version import __version__

Expand Down Expand Up @@ -79,9 +78,8 @@ def create_app() -> Flask:
app.logger.info("Development environment detected, enabling jinja2.StrictUndefined")
app.jinja_env.undefined = StrictUndefined

# Run migrations
db.init_app(app)
Migrate(app, db)
migrate.init_app(app, db)

# Initialize Stripe
if app.config["STRIPE_SECRET_KEY"]:
Expand Down
2 changes: 2 additions & 0 deletions hushline/db.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from flask_migrate import Migrate
from flask_sqlalchemy import SQLAlchemy
from sqlalchemy import MetaData

Expand All @@ -12,3 +13,4 @@
)

db = SQLAlchemy(metadata=metadata)
migrate = Migrate()
76 changes: 32 additions & 44 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,63 +45,51 @@ build-backend = "poetry.core.masonry.api"

[tool.pytest.ini_options]
pythonpath = "."
filterwarnings = [
# passlib unmaintianed, see: https://github.com/scidsg/hushline/issues/553
"ignore:.*'crypt' is deprecated.*:DeprecationWarning",
]

[tool.ruff]
line-length = 100
indent-width = 4

[tool.ruff.lint]
select = [
# pycodestyle errors
"E",
# pyflakes
"F",
# isort
"I",
# flake8-gettext
"INT",
# flake8-pie
"PIE",
# pylint
"PL",
# flake8-pytest-style
"PT",
# flake8-pyi
"PYI",
# flake8-return
"RET",
# flake8-bandit
"S",
# flake8-simplify
"SIM",
# pyupgrade
"UP",
# pycodestyle warnings
"W",
# Unused noqa directive
"RUF100",
"E", # pycodestyle errors
"F", # pyflakes
"I", # isort
"INT", # flake8-gettext
"PIE", # flake8-pie
"PL", # pylint
"PT", # flake8-pytest-style
"PYI", # flake8-pyi
"RET", # flake8-return
"S", # flake8-bandit
"SIM", # flake8-simplify
"UP", # pyupgrade
"W", # pycodestyle warnings
"RUF100", # Unused noqa directive
]
ignore = [
# https://docs.astral.sh/ruff/rules/too-many-statements/
"PLR0915",
# https://docs.astral.sh/ruff/rules/too-many-return-statements/
"PLR0911",
# https://docs.astral.sh/ruff/rules/too-many-branches/
"PLR0912",
"PLR0911", # too-many-return-statements
"PLR0912", # too-many-branches
"PLR0915", # too-many-statements
]

[tool.ruff.lint.per-file-ignores]
"migrations/versions/*.py" = [
"I001", # unsorted-imports
"S608", # hardcoded-sql-expression
]
"tests/*.py" = [
# https://docs.astral.sh/ruff/rules/assert/
"S101",
"S105", # hardcoded password
# https://docs.astral.sh/ruff/rules/magic-value-comparison/
"PLR2004",
"PLR2004", # magic-value-comparison
"S101", # assert
"S105", # hardcoded-password-string
"S311", # suspicious-non-cryptographic-random-usage
]
"migrations/versions/*.py" = [
# https://docs.astral.sh/ruff/rules/unsorted-imports/
"I001",
"S608", # sql injection via string-based query construction
"tests/migrations/*.py" = [
"S608", # hardcoded-sql-expression
]

[tool.mypy]
Expand All @@ -111,4 +99,4 @@ no_implicit_optional = true
disallow_untyped_defs = true
disallow_incomplete_defs = true
warn_unused_configs = true
exclude = '^migrations/env\.py$|^migrations/versions/.*\.py$'
exclude = "^migrations/env\\.py$"
11 changes: 7 additions & 4 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def pytest_addoption(parser: Parser) -> None:


def random_name(size: int) -> str:
return "".join([random.choice(string.ascii_lowercase) for _ in range(size)]) # noqa: S311
return "".join([random.choice(string.ascii_lowercase) for _ in range(size)])


@contextmanager
Expand Down Expand Up @@ -112,16 +112,19 @@ def init_db_via_alembic(db_uri: str) -> None:


@pytest.fixture()
def database(_db_template: None) -> str:
"""A clean Postgres database from the template with DDLs applied"""
def database(request: pytest.FixtureRequest, _db_template: None) -> str:
db_name = random_name(16)
conn_str = CONN_FMT_STR.format(database="hushline")
engine = create_engine(conn_str)
engine = engine.execution_options(isolation_level="AUTOCOMMIT")

session = sessionmaker(bind=engine)()

sql = text(f"CREATE DATABASE {db_name} WITH TEMPLATE {TEMPLATE_DB_NAME}")
if request.module.__name__ == "test_migrations":
# don't use the template when testing migrations. we want a blank db
sql = text(f"CREATE DATABASE {db_name}")
else:
sql = text(f"CREATE DATABASE {db_name} WITH TEMPLATE {TEMPLATE_DB_NAME}")
session.execute(sql)

# aggressively terminate all connections
Expand Down
52 changes: 52 additions & 0 deletions tests/helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import random
import string
from typing import Any, Callable, Mapping, Optional, Sequence, Tuple, TypeVar

T = TypeVar("T")


def one_of(xs: Sequence[T], predicate: Callable[[T], bool]) -> T:
matches = [x for x in xs if predicate(x)]
match len(matches):
case 1:
return matches[0]
case 0:
raise ValueError("No matches")
case _:
raise ValueError(f"Too many matches: {matches}")


def random_bool() -> bool:
return bool(random.getrandbits(1))


def random_optional_bool() -> Optional[bool]:
if random_bool():
return None
return random_bool()


def random_string(length: int) -> str:
return "".join(random.choice(string.ascii_lowercase) for _ in range(length))


def random_optional_string(length: int) -> Optional[str]:
if random_bool():
return None
return random_string(length)


def format_param_dict(params: Mapping[str, Any]) -> Tuple[str, str]:
return (", ".join(params.keys()), ", ".join(f":{x}" for x in params))


class Missing:
def __eq__(self, other: object) -> bool:
return False

def __ne__(self, other: object) -> bool:
return True


assert Missing() != Missing()
assert Missing() != Missing()
Loading

0 comments on commit 2ea86e8

Please sign in to comment.