Skip to content

Commit

Permalink
added test for stripe tables
Browse files Browse the repository at this point in the history
  • Loading branch information
brassy-endomorph committed Oct 5, 2024
1 parent 2166520 commit 01e111d
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 16 deletions.
42 changes: 26 additions & 16 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,27 +111,49 @@ def init_db_via_alembic(db_uri: str) -> None:
db.session.connection().connection.invalidate() # type: ignore


def populate_db(session: Session) -> None:
"""Populate the DB with common objects required for the app to function at all"""
free_tier = Tier(name="Free", monthly_amount=0)
business_tier = Tier(name="Business", monthly_amount=2000)
business_tier.stripe_product_id = "prod_123"
business_tier.stripe_price_id = "price_123"
session.add(free_tier)
session.add(business_tier)
session.commit()


@pytest.fixture()
def database(request: pytest.FixtureRequest, _db_template: None) -> str:
db_name = random_name(16)
conn_str = CONN_FMT_STR.format(database="hushline")
engine = create_engine(conn_str)
engine = engine.execution_options(isolation_level="AUTOCOMMIT")

session = sessionmaker(bind=engine)()

if request.module.__name__ == "test_migrations":
# don't use the template when testing migrations. we want a blank db
sql = text(f"CREATE DATABASE {db_name}")
session.execute(text(f"CREATE DATABASE {db_name}"))
else:
sql = text(f"CREATE DATABASE {db_name} WITH TEMPLATE {TEMPLATE_DB_NAME}")
session.execute(sql)
session.execute(text(f"CREATE DATABASE {db_name} WITH TEMPLATE {TEMPLATE_DB_NAME}"))

# aggressively terminate all connections
session.close()
session.connection().connection.invalidate()
engine.dispose()

if request.module.__name__ != "test_migrations":
conn_str = CONN_FMT_STR.format(database=db_name)
engine = create_engine(conn_str)
engine = engine.execution_options(isolation_level="AUTOCOMMIT")
session = sessionmaker(bind=engine)()

populate_db(session)

# aggressively terminate all connections
session.close()
session.connection().connection.invalidate()
engine.dispose()

print(f"Postgres DB: {db_name}, template: {TEMPLATE_DB_NAME}") # to help with debugging tests

return db_name
Expand All @@ -157,18 +179,6 @@ def app(database: str) -> Generator[Flask, None, None]:
app.config["PREFERRED_URL_SCHEME"] = "http"

with app.app_context():
db.create_all()

# Create the default tiers
# (this happens in the migrations, but migrations don't run in the tests)
free_tier = Tier(name="Free", monthly_amount=0)
business_tier = Tier(name="Business", monthly_amount=2000)
business_tier.stripe_product_id = "prod_123"
business_tier.stripe_price_id = "price_123"
db.session.add(free_tier)
db.session.add(business_tier)
db.session.commit()

yield app


Expand Down
75 changes: 75 additions & 0 deletions tests/migrations/revision_be0744a5679f.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from typing import Any, Dict

from hushline.db import db

from ..helpers import ( # type: ignore[misc]
format_param_dict,
random_bool,
)


class UpgradeTester:
def __init__(self) -> None:
self.old_users_count = 10

def load_data(self) -> None:
for user_idx in range(self.old_users_count):
db.session.execute(
db.text(
"""
INSERT INTO users (password_hash, is_admin)
VALUES ('$scrypt$', :is_admin)
"""
),
dict(is_admin=random_bool()),
)

db.session.commit()

def check_upgrade(self) -> None:
new_user_count = db.session.execute(db.text("SELECT count(*) FROM users")).scalar()
# just make sure nothing weird happened where users got dropped
assert new_user_count == self.old_users_count


class DowngradeTester:
def __init__(self) -> None:
self.new_user_count = 10

def load_data(self) -> None:
for i in range(1, self.new_user_count + 1):
params: Dict[str, Any] = {
"id": i,
"name": f"tier_{i}",
"monthly_amount": i * 100,
"stripe_product_id": f"prod_{i}",
"stripe_price_id": f"price_{i}",
}

columns, param_args = format_param_dict(params)
db.session.execute(
db.text(f"INSERT INTO tiers ({columns}) VALUES ({param_args})"), params
)

params = {
"id": i,
"is_admin": False,
"password_hash": "$scrypt$",
"tier_id": i,
"stripe_customer_id": f"cust_{i}",
"stripe_subscription_id": f"sub_{i}",
"stripe_subscription_cancel_at_period_end": False,
}

columns, param_args = format_param_dict(params)
db.session.execute(
db.text(f"INSERT INTO users ({columns}) VALUES ({param_args})"),
params,
)

db.session.commit()

def check_downgrade(self) -> None:
old_user_count = db.session.execute(db.text("SELECT count(*) FROM users")).scalar()
# just make sure nothing weird happened where users got dropped
assert old_user_count == self.new_user_count

0 comments on commit 01e111d

Please sign in to comment.