Skip to content

Commit

Permalink
server/customer: harden the rules for linking customers with users
Browse files Browse the repository at this point in the history
  • Loading branch information
frankie567 committed Jan 2, 2025
1 parent 886e814 commit 2700bf3
Show file tree
Hide file tree
Showing 8 changed files with 104 additions and 24 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
"""Make UserCustomer.customer_id unique
Revision ID: 58d5e316549f
Revises: 6a6e872cbea5
Create Date: 2025-01-02 16:56:16.433658
"""

import sqlalchemy as sa
from alembic import op

# Polar Custom Imports

# revision identifiers, used by Alembic.
revision = "58d5e316549f"
down_revision = "6a6e872cbea5"
branch_labels: tuple[str] | None = None
depends_on: tuple[str] | None = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_constraint(
"user_customers_user_id_customer_id_key", "user_customers", type_="unique"
)
op.create_unique_constraint(
op.f("user_customers_customer_id_key"), "user_customers", ["customer_id"]
)
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_constraint(
op.f("user_customers_customer_id_key"), "user_customers", type_="unique"
)
op.create_unique_constraint(
"user_customers_user_id_customer_id_key",
"user_customers",
["user_id", "customer_id"],
)
# ### end Alembic commands ###
6 changes: 5 additions & 1 deletion server/polar/checkout/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -1665,7 +1665,11 @@ async def _create_or_update_customer(
await session.flush()

if is_direct_user(auth_subject):
await customer_service.link_user(session, customer, auth_subject.subject)
user = auth_subject.subject
if user.email_verified and user.email.lower() == customer.email.lower():
await customer_service.link_user(
session, customer, auth_subject.subject
)

return customer

Expand Down
2 changes: 1 addition & 1 deletion server/polar/customer/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ async def link_user(
user_id=user.id, customer_id=customer.id
)
insert_statement = insert_statement.on_conflict_do_nothing(
index_elements=["user_id", "customer_id"]
index_elements=["customer_id"]
)
await session.execute(insert_statement)

Expand Down
8 changes: 5 additions & 3 deletions server/polar/models/user_customer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from uuid import UUID

from sqlalchemy import ForeignKey, UniqueConstraint, Uuid
from sqlalchemy import ForeignKey, Uuid
from sqlalchemy.orm import Mapped, declared_attr, mapped_column, relationship

from polar.kit.db.models.base import RecordModel
Expand All @@ -11,13 +11,15 @@

class UserCustomer(RecordModel):
__tablename__ = "user_customers"
__table_args__ = (UniqueConstraint("user_id", "customer_id"),)

user_id: Mapped[UUID] = mapped_column(
Uuid, ForeignKey("users.id", ondelete="cascade"), nullable=False
)
customer_id: Mapped[UUID] = mapped_column(
Uuid, ForeignKey("customers.id", ondelete="cascade"), nullable=False
Uuid,
ForeignKey("customers.id", ondelete="cascade"),
nullable=False,
unique=True, # A customer can only be associated with one user
)

@declared_attr
Expand Down
4 changes: 2 additions & 2 deletions server/polar/user/service/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,9 @@ async def link_customers(self, session: AsyncSession, user: User) -> None:
Customer.id,
func.uuid_generate_v4(),
func.now(),
).where(Customer.email == user.email),
).where(func.lower(Customer.email) == user.email.lower()),
)
.on_conflict_do_nothing(index_elements=["user_id", "customer_id"])
.on_conflict_do_nothing(index_elements=["customer_id"])
)
await session.execute(statement)

Expand Down
3 changes: 2 additions & 1 deletion server/polar/user/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,5 @@ async def user_on_after_signup(
if user is None:
raise UserDoesNotExist(user_id)

await user_service.link_customers(session, user)
if user.email_verified:
await user_service.link_customers(session, user)
38 changes: 37 additions & 1 deletion server/tests/checkout/test_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -2219,7 +2219,7 @@ async def test_valid_stripe_existing_customer_email(
stripe_service_mock.update_customer.assert_called_once()

@pytest.mark.auth(AuthSubjectFixture(subject="user_second"))
async def test_link_customer_to_authenticated_user(
async def test_link_customer_to_authenticated_user_different_email(
self,
stripe_service_mock: MagicMock,
session: AsyncSession,
Expand Down Expand Up @@ -2248,6 +2248,42 @@ async def test_link_customer_to_authenticated_user(
),
)

assert checkout.customer is not None
linked_customer = await customer_service.get_by_id_and_user(
session, checkout.customer.id, auth_subject.subject
)
assert linked_customer is None

@pytest.mark.auth(AuthSubjectFixture(subject="user_second"))
async def test_link_customer_to_authenticated_same_email(
self,
stripe_service_mock: MagicMock,
session: AsyncSession,
locker: Locker,
auth_subject: AuthSubject[User],
checkout_one_time_fixed: Checkout,
) -> None:
stripe_service_mock.create_customer.return_value = SimpleNamespace(
id="STRIPE_CUSTOMER_ID"
)
stripe_service_mock.create_payment_intent.return_value = SimpleNamespace(
client_secret="CLIENT_SECRET", status="succeeded"
)
checkout = await checkout_service.confirm(
session,
locker,
auth_subject,
checkout_one_time_fixed,
CheckoutConfirmStripe.model_validate(
{
"confirmation_token_id": "CONFIRMATION_TOKEN_ID",
"customer_name": "Customer Name",
"customer_email": auth_subject.subject.email,
"customer_billing_address": {"country": "FR"},
}
),
)

assert checkout.customer is not None
linked_customer = await customer_service.get_by_id_and_user(
session, checkout.customer.id, auth_subject.subject
Expand Down
25 changes: 10 additions & 15 deletions server/tests/fixtures/random_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,19 +289,15 @@ async def user_github_oauth(
return await create_user_github_oauth(save_fixture, user)


@pytest_asyncio.fixture
async def user(
save_fixture: SaveFixture,
) -> User:
return await create_user(save_fixture)


async def create_user(
save_fixture: SaveFixture, stripe_customer_id: str | None = None
save_fixture: SaveFixture,
stripe_customer_id: str | None = None,
email_verified: bool = True,
) -> User:
user = User(
id=uuid.uuid4(),
email=rstr("test") + "@example.com",
email_verified=email_verified,
avatar_url="https://avatars.githubusercontent.com/u/47952?v=4",
oauth_accounts=[],
stripe_customer_id=stripe_customer_id,
Expand All @@ -310,15 +306,14 @@ async def create_user(
return user


@pytest_asyncio.fixture
async def user(save_fixture: SaveFixture) -> User:
return await create_user(save_fixture)


@pytest_asyncio.fixture
async def user_second(save_fixture: SaveFixture) -> User:
user = User(
id=uuid.uuid4(),
email=rstr("test") + "@example.com",
avatar_url="https://avatars.githubusercontent.com/u/47952?v=4",
)
await save_fixture(user)
return user
return await create_user(save_fixture)


@pytest_asyncio.fixture
Expand Down

0 comments on commit 2700bf3

Please sign in to comment.