Skip to content

Commit

Permalink
Merge pull request #618 from scidsg/auto-testing-of-migrations
Browse files Browse the repository at this point in the history
Auto testing of migrations
  • Loading branch information
micahflee authored Oct 7, 2024
2 parents 8a06f46 + e8952a9 commit 74aede1
Show file tree
Hide file tree
Showing 9 changed files with 802 additions and 66 deletions.
6 changes: 2 additions & 4 deletions hushline/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,12 @@

from flask import Flask, flash, redirect, request, session, url_for
from flask.cli import AppGroup
from flask_migrate import Migrate
from jinja2 import StrictUndefined
from werkzeug.middleware.proxy_fix import ProxyFix
from werkzeug.wrappers.response import Response

from . import admin, premium, routes, settings
from .db import db
from .db import db, migrate
from .model import HostOrganization, Tier, User
from .version import __version__

Expand Down Expand Up @@ -78,9 +77,8 @@ def create_app() -> Flask:
app.logger.info("Development environment detected, enabling jinja2.StrictUndefined")
app.jinja_env.undefined = StrictUndefined

# Run migrations
db.init_app(app)
Migrate(app, db)
migrate.init_app(app, db)

# Initialize Stripe
if app.config["STRIPE_SECRET_KEY"]:
Expand Down
2 changes: 2 additions & 0 deletions hushline/db.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from flask_migrate import Migrate
from flask_sqlalchemy import SQLAlchemy
from sqlalchemy import MetaData

Expand All @@ -12,3 +13,4 @@
)

db = SQLAlchemy(metadata=metadata)
migrate = Migrate()
76 changes: 32 additions & 44 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,63 +45,51 @@ build-backend = "poetry.core.masonry.api"

[tool.pytest.ini_options]
pythonpath = "."
filterwarnings = [
# passlib unmaintianed, see: https://github.com/scidsg/hushline/issues/553
"ignore:.*'crypt' is deprecated.*:DeprecationWarning",
]

[tool.ruff]
line-length = 100
indent-width = 4

[tool.ruff.lint]
select = [
# pycodestyle errors
"E",
# pyflakes
"F",
# isort
"I",
# flake8-gettext
"INT",
# flake8-pie
"PIE",
# pylint
"PL",
# flake8-pytest-style
"PT",
# flake8-pyi
"PYI",
# flake8-return
"RET",
# flake8-bandit
"S",
# flake8-simplify
"SIM",
# pyupgrade
"UP",
# pycodestyle warnings
"W",
# Unused noqa directive
"RUF100",
"E", # pycodestyle errors
"F", # pyflakes
"I", # isort
"INT", # flake8-gettext
"PIE", # flake8-pie
"PL", # pylint
"PT", # flake8-pytest-style
"PYI", # flake8-pyi
"RET", # flake8-return
"S", # flake8-bandit
"SIM", # flake8-simplify
"UP", # pyupgrade
"W", # pycodestyle warnings
"RUF100", # Unused noqa directive
]
ignore = [
# https://docs.astral.sh/ruff/rules/too-many-statements/
"PLR0915",
# https://docs.astral.sh/ruff/rules/too-many-return-statements/
"PLR0911",
# https://docs.astral.sh/ruff/rules/too-many-branches/
"PLR0912",
"PLR0911", # too-many-return-statements
"PLR0912", # too-many-branches
"PLR0915", # too-many-statements
]

[tool.ruff.lint.per-file-ignores]
"migrations/versions/*.py" = [
"I001", # unsorted-imports
"S608", # hardcoded-sql-expression
]
"tests/*.py" = [
# https://docs.astral.sh/ruff/rules/assert/
"S101",
"S105", # hardcoded password
# https://docs.astral.sh/ruff/rules/magic-value-comparison/
"PLR2004",
"PLR2004", # magic-value-comparison
"S101", # assert
"S105", # hardcoded-password-string
"S311", # suspicious-non-cryptographic-random-usage
]
"migrations/versions/*.py" = [
# https://docs.astral.sh/ruff/rules/unsorted-imports/
"I001",
"S608", # sql injection via string-based query construction
"tests/migrations/*.py" = [
"S608", # hardcoded-sql-expression
]

[tool.mypy]
Expand All @@ -111,4 +99,4 @@ no_implicit_optional = true
disallow_untyped_defs = true
disallow_incomplete_defs = true
warn_unused_configs = true
exclude = '^migrations/env\.py$|^migrations/versions/.*\.py$'
exclude = "^migrations/env\\.py$"
48 changes: 30 additions & 18 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def pytest_addoption(parser: Parser) -> None:


def random_name(size: int) -> str:
return "".join([random.choice(string.ascii_lowercase) for _ in range(size)]) # noqa: S311
return "".join([random.choice(string.ascii_lowercase) for _ in range(size)])


@contextmanager
Expand Down Expand Up @@ -111,24 +111,48 @@ def init_db_via_alembic(db_uri: str) -> None:
db.session.connection().connection.invalidate() # type: ignore


def populate_db(session: Session) -> None:
"""Populate the DB with common objects required for the app to function at all"""
free_tier = Tier(name="Free", monthly_amount=0)
business_tier = Tier(name="Business", monthly_amount=2000)
business_tier.stripe_product_id = "prod_123"
business_tier.stripe_price_id = "price_123"
session.add(free_tier)
session.add(business_tier)
session.commit()


@pytest.fixture()
def database(_db_template: None) -> str:
"""A clean Postgres database from the template with DDLs applied"""
def database(request: pytest.FixtureRequest, _db_template: None) -> str:
db_name = random_name(16)
conn_str = CONN_FMT_STR.format(database="hushline")
engine = create_engine(conn_str)
engine = engine.execution_options(isolation_level="AUTOCOMMIT")

session = sessionmaker(bind=engine)()

sql = text(f"CREATE DATABASE {db_name} WITH TEMPLATE {TEMPLATE_DB_NAME}")
session.execute(sql)
if request.module.__name__ == "test_migrations":
# don't use the template when testing migrations. we want a blank db
session.execute(text(f"CREATE DATABASE {db_name}"))
else:
session.execute(text(f"CREATE DATABASE {db_name} WITH TEMPLATE {TEMPLATE_DB_NAME}"))

# aggressively terminate all connections
session.close()
session.connection().connection.invalidate()
engine.dispose()

if request.module.__name__ != "test_migrations":
conn_str = CONN_FMT_STR.format(database=db_name)
engine = create_engine(conn_str)
session = sessionmaker(bind=engine)()

populate_db(session)

# aggressively terminate all connections
session.close()
session.connection().connection.invalidate()
engine.dispose()

print(f"Postgres DB: {db_name}, template: {TEMPLATE_DB_NAME}") # to help with debugging tests

return db_name
Expand All @@ -154,18 +178,6 @@ def app(database: str) -> Generator[Flask, None, None]:
app.config["PREFERRED_URL_SCHEME"] = "http"

with app.app_context():
db.create_all()

# Create the default tiers
# (this happens in the migrations, but migrations don't run in the tests)
free_tier = Tier(name="Free", monthly_amount=0)
business_tier = Tier(name="Business", monthly_amount=2000)
business_tier.stripe_product_id = "prod_123"
business_tier.stripe_price_id = "price_123"
db.session.add(free_tier)
db.session.add(business_tier)
db.session.commit()

yield app


Expand Down
53 changes: 53 additions & 0 deletions tests/helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import random
import string
from typing import Any, Callable, Mapping, Optional, Sequence, Tuple, TypeVar

T = TypeVar("T")


def one_of(xs: Sequence[T], predicate: Callable[[T], bool]) -> T:
matches = [x for x in xs if predicate(x)]
match len(matches):
case 1:
return matches[0]
case 0:
raise ValueError("No matches")
case _:
raise ValueError(f"Too many matches: {matches}")


def random_bool() -> bool:
return bool(random.getrandbits(1))


def random_optional_bool() -> Optional[bool]:
if random_bool():
return None
return random_bool()


def random_string(length: int) -> str:
return "".join(random.choice(string.ascii_lowercase) for _ in range(length))


def random_optional_string(length: int) -> Optional[str]:
if random_bool():
return None
return random_string(length)


def format_param_dict(params: Mapping[str, Any]) -> Tuple[str, str]:
return (", ".join(params.keys()), ", ".join(f":{x}" for x in params))


class Missing:
def __eq__(self, other: object) -> bool:
return False

def __ne__(self, other: object) -> bool:
return True


# ridiculous formatting because `ruff` won't allow `not (x == y)`
assert (Missing() == Missing()) ^ bool("x")
assert Missing() != Missing()
Loading

0 comments on commit 74aede1

Please sign in to comment.