Skip to content

Commit

Permalink
server/user: link existing customers by email on signup
Browse files Browse the repository at this point in the history
Fix #4646
  • Loading branch information
frankie567 committed Dec 16, 2024
1 parent 0f38839 commit 2acabde
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 3 deletions.
28 changes: 25 additions & 3 deletions server/polar/user/service/user.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
from uuid import UUID

import structlog
from sqlalchemy import func
from sqlalchemy import func, literal, select
from sqlalchemy.dialects.postgresql import insert

from polar.account.service import account as account_service
from polar.authz.service import AccessType, Authz
from polar.exceptions import PolarError
from polar.kit.services import ResourceService
from polar.logging import Logger
from polar.models import OAuthAccount, User
from polar.models import Customer, OAuthAccount, User, UserCustomer
from polar.models.user import OAuthPlatform
from polar.postgres import AsyncSession, sql
from polar.user.schemas.user import UserSignupAttribution
Expand Down Expand Up @@ -100,7 +101,7 @@ async def create_by_email(
)

session.add(user)
await session.commit()
await session.flush()

log.info("user.create", user_id=user.id, email=email)

Expand All @@ -122,5 +123,26 @@ async def set_account(
await session.commit()
return user

async def link_customers(self, session: AsyncSession, user: User) -> None:
statement = (
insert(UserCustomer)
.from_select(
[
UserCustomer.user_id,
UserCustomer.customer_id,
UserCustomer.id,
UserCustomer.created_at,
],
select(
literal(user.id),
Customer.id,
func.uuid_generate_v4(),
func.now(),
).where(Customer.email == user.email),
)
.on_conflict_do_nothing(index_elements=["user_id", "customer_id"])
)
await session.execute(statement)


user = UserService(User)
2 changes: 2 additions & 0 deletions server/polar/user/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,5 @@ async def user_on_after_signup(

if user is None:
raise UserDoesNotExist(user_id)

await user_service.link_customers(session, user)
1 change: 1 addition & 0 deletions server/tests/fixtures/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ async def initialize_test_database(worker_id: str) -> AsyncIterator[None]:

async with engine.begin() as conn:
await conn.execute(text("CREATE EXTENSION IF NOT EXISTS citext"))
await conn.execute(text('CREATE EXTENSION IF NOT EXISTS "uuid-ossp"'))
await conn.run_sync(Model.metadata.create_all)
await engine.dispose()

Expand Down
51 changes: 51 additions & 0 deletions server/tests/user/service/test_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import pytest
from sqlalchemy import select

from polar.models import Organization, UserCustomer
from polar.postgres import AsyncSession
from polar.user.service.user import user as user_service
from tests.fixtures.database import SaveFixture
from tests.fixtures.random_objects import (
create_customer,
create_organization,
create_user,
)


@pytest.mark.asyncio
@pytest.mark.skip_db_asserts
async def test_link_customers(
save_fixture: SaveFixture,
session: AsyncSession,
organization: Organization,
organization_second: Organization,
) -> None:
user = await create_user(save_fixture)

customer1 = await create_customer(
save_fixture, organization=organization, email=user.email
)
user_customer1 = UserCustomer(user=user, customer=customer1)
await save_fixture(user_customer1)

customer2 = await create_customer(
save_fixture, organization=organization_second, email=user.email
)

organization_third = await create_organization(save_fixture)
customer3 = await create_customer(
save_fixture, organization=organization_third, email=user.email
)

await user_service.link_customers(session, user)

user_customer_statement = select(UserCustomer).where(
UserCustomer.user_id == user.id
)
result = await session.execute(user_customer_statement)
user_customers = result.scalars().all()

assert len(user_customers) == 3
assert user_customers[0].customer_id == customer1.id
assert user_customers[1].customer_id == customer2.id
assert user_customers[2].customer_id == customer3.id

0 comments on commit 2acabde

Please sign in to comment.