From 6c8d3e881733a9ed137634c6109956af36250490 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Tue, 17 Dec 2024 15:35:14 +0100 Subject: [PATCH] server/checkout: allow to pass customer_metadata that'll be copied over to the created customer --- ...-17-1756_add_checkout_customer_metadata.py | 43 ++++++++++++++ server/polar/checkout/schemas.py | 15 +++++ server/polar/checkout/service.py | 5 ++ server/polar/kit/metadata.py | 36 +++++++----- server/polar/models/checkout.py | 3 +- server/tests/checkout/test_service.py | 58 +++++++++++++++++++ server/tests/fixtures/random_objects.py | 4 ++ 7 files changed, 148 insertions(+), 16 deletions(-) create mode 100644 server/migrations/versions/2024-12-17-1756_add_checkout_customer_metadata.py diff --git a/server/migrations/versions/2024-12-17-1756_add_checkout_customer_metadata.py b/server/migrations/versions/2024-12-17-1756_add_checkout_customer_metadata.py new file mode 100644 index 0000000000..9050a759dd --- /dev/null +++ b/server/migrations/versions/2024-12-17-1756_add_checkout_customer_metadata.py @@ -0,0 +1,43 @@ +"""Add Checkout.customer_metadata + +Revision ID: cb9906114207 +Revises: a3e70f4c4e1e +Create Date: 2024-12-17 17:56:12.495724 + +""" + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +# Polar Custom Imports + +# revision identifiers, used by Alembic. +revision = "cb9906114207" +down_revision = "a3e70f4c4e1e" +branch_labels: tuple[str] | None = None +depends_on: tuple[str] | None = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column( + "checkouts", + sa.Column( + "customer_metadata", postgresql.JSONB(astext_type=sa.Text()), nullable=True + ), + ) + + op.execute( + "UPDATE checkouts SET customer_metadata = '{}' WHERE customer_metadata IS NULL" + ) + + op.alter_column("checkouts", "customer_metadata", nullable=False) + + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("checkouts", "customer_metadata") + # ### end Alembic commands ### diff --git a/server/polar/checkout/schemas.py b/server/polar/checkout/schemas.py index a6896e96c1..73811fc2b1 100644 --- a/server/polar/checkout/schemas.py +++ b/server/polar/checkout/schemas.py @@ -26,6 +26,8 @@ from polar.enums import PaymentProcessor from polar.kit.address import Address from polar.kit.metadata import ( + METADATA_DESCRIPTION, + MetadataField, MetadataInputMixin, MetadataOutputMixin, OptionalMetadataInputMixin, @@ -101,6 +103,12 @@ "If you apply a discount through `discount_id`, it'll still be applied, " "but the customer won't be able to change it." ) +_customer_metadata_description = METADATA_DESCRIPTION.format( + heading=( + "Key-value object allowing you to store additional information " + "that'll be copied to the created customer." + ) +) class CheckoutCreateBase(CustomFieldDataInputMixin, MetadataInputMixin, Schema): @@ -134,6 +142,9 @@ class CheckoutCreateBase(CustomFieldDataInputMixin, MetadataInputMixin, Schema): customer_ip_address: CustomerIPAddress | None = None customer_billing_address: CustomerBillingAddress | None = None customer_tax_id: Annotated[str | None, EmptyStrToNoneValidator] = None + customer_metadata: MetadataField = Field( + default_factory=dict, description=_customer_metadata_description + ) subscription_id: UUID4 | None = Field( default=None, description=( @@ -217,6 +228,9 @@ class CheckoutUpdate(OptionalMetadataInputMixin, CheckoutUpdateBase): default=None, description=_allow_discount_codes_description ) customer_ip_address: CustomerIPAddress | None = None + customer_metadata: MetadataField | None = Field( + default=None, description=_customer_metadata_description + ) success_url: SuccessURL = None embed_origin: EmbedOrigin = None @@ -406,6 +420,7 @@ class Checkout(MetadataOutputMixin, CheckoutBase): discount: CheckoutDiscount | None subscription_id: UUID4 | None attached_custom_fields: list[AttachedCustomField] + customer_metadata: dict[str, str | int | bool] class CheckoutPublic(CheckoutBase): diff --git a/server/polar/checkout/service.py b/server/polar/checkout/service.py index 39fc01ee77..b3909aac4d 100644 --- a/server/polar/checkout/service.py +++ b/server/polar/checkout/service.py @@ -1599,6 +1599,7 @@ async def _create_or_update_customer( billing_address=checkout.customer_billing_address, tax_id=checkout.customer_tax_id, organization=checkout.organization, + user_metadata={}, ) stripe_customer_id = customer.stripe_customer_id @@ -1628,6 +1629,10 @@ async def _create_or_update_customer( **update_params, ) customer.stripe_customer_id = stripe_customer_id + customer.user_metadata = { + **customer.user_metadata, + **checkout.customer_metadata, + } session.add(customer) await session.flush() diff --git a/server/polar/kit/metadata.py b/server/polar/kit/metadata.py index ebb27dfc50..ba385e2d67 100644 --- a/server/polar/kit/metadata.py +++ b/server/polar/kit/metadata.py @@ -5,11 +5,13 @@ from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.orm import Mapped, mapped_column +MetadataColumn = Annotated[ + dict[str, Any], mapped_column(JSONB, nullable=False, default=dict) +] + class MetadataMixin: - user_metadata: Mapped[dict[str, Any]] = mapped_column( - JSONB, nullable=False, default=dict - ) + user_metadata: Mapped[MetadataColumn] _MAXIMUM_KEYS = 50 @@ -28,9 +30,10 @@ class MetadataMixin: ), ] _MetadataValue = _MetadataValueString | int | bool -_description = inspect.cleandoc( + +METADATA_DESCRIPTION = inspect.cleandoc( f""" - Key-value object allowing you to store additional information. + {{heading}} The key must be a string with a maximum length of **{_MAXIMUM_KEY_LENGTH} characters**. The value must be either: @@ -42,23 +45,26 @@ class MetadataMixin: You can store up to **{_MAXIMUM_KEYS} key-value pairs**. """ ) +_description = METADATA_DESCRIPTION.format( + heading="Key-value object allowing you to store additional information." +) + + +MetadataField = Annotated[ + dict[_MetadataKey, _MetadataValue], + Field(max_length=_MAXIMUM_KEYS, description=_description), +] class MetadataInputMixin(BaseModel): - metadata: dict[_MetadataKey, _MetadataValue] = Field( - default_factory=dict, - max_length=_MAXIMUM_KEYS, - description=_description, - serialization_alias="user_metadata", + metadata: MetadataField = Field( + default_factory=dict, serialization_alias="user_metadata" ) class OptionalMetadataInputMixin(BaseModel): - metadata: dict[_MetadataKey, _MetadataValue] | None = Field( - default=None, - max_length=_MAXIMUM_KEYS, - description=_description, - serialization_alias="user_metadata", + metadata: MetadataField | None = Field( + default=None, serialization_alias="user_metadata" ) diff --git a/server/polar/models/checkout.py b/server/polar/models/checkout.py index b682152117..163be14c8d 100644 --- a/server/polar/models/checkout.py +++ b/server/polar/models/checkout.py @@ -25,7 +25,7 @@ from polar.enums import PaymentProcessor from polar.kit.address import Address, AddressType from polar.kit.db.models import RecordModel -from polar.kit.metadata import MetadataMixin +from polar.kit.metadata import MetadataColumn, MetadataMixin from polar.kit.tax import TaxID, TaxIDType from polar.kit.utils import utc_now @@ -135,6 +135,7 @@ def customer(cls) -> Mapped[Customer | None]: customer_tax_id: Mapped[TaxID | None] = mapped_column( TaxIDType, nullable=True, default=None ) + customer_metadata: Mapped[MetadataColumn] subscription_id: Mapped[UUID | None] = mapped_column( Uuid, ForeignKey("subscriptions.id", ondelete="set null"), nullable=True diff --git a/server/tests/checkout/test_service.py b/server/tests/checkout/test_service.py index e6daf74c03..4cadbb6be4 100644 --- a/server/tests/checkout/test_service.py +++ b/server/tests/checkout/test_service.py @@ -1037,6 +1037,29 @@ async def test_valid_customer( assert checkout.customer_billing_address == customer.billing_address assert checkout.customer_tax_id == customer.tax_id + @pytest.mark.auth( + AuthSubjectFixture(subject="user"), + AuthSubjectFixture(subject="organization"), + ) + async def test_customer_metadata( + self, + session: AsyncSession, + auth_subject: AuthSubject[User | Organization], + product_one_time: Product, + user_organization: UserOrganization, + ) -> None: + checkout = await checkout_service.create( + session, + CheckoutProductCreate( + payment_processor=PaymentProcessor.stripe, + product_id=product_one_time.id, + customer_metadata={"key": "value"}, + ), + auth_subject, + ) + + assert checkout.customer_metadata == {"key": "value"} + @pytest.mark.asyncio @pytest.mark.skip_db_asserts @@ -1696,6 +1719,21 @@ async def test_valid_metadata( assert checkout.user_metadata == {"key": "value"} + async def test_valid_customer_metadata( + self, + session: AsyncSession, + checkout_one_time_free: Checkout, + ) -> None: + checkout = await checkout_service.update( + session, + checkout_one_time_free, + CheckoutUpdate( + customer_metadata={"key": "value"}, + ), + ) + + assert checkout.customer_metadata == {"key": "value"} + @pytest.mark.parametrize( "custom_field_data", ( @@ -1952,6 +1990,7 @@ async def test_calculate_tax_error( ) async def test_valid_stripe( self, + save_fixture: SaveFixture, customer_billing_address: dict[str, str], expected_tax_metadata: dict[str, str], stripe_service_mock: MagicMock, @@ -1960,6 +1999,9 @@ async def test_valid_stripe( auth_subject: AuthSubject[Anonymous], checkout_one_time_fixed: Checkout, ) -> None: + checkout_one_time_fixed.customer_metadata = {"key": "value"} + await save_fixture(checkout_one_time_fixed) + stripe_service_mock.create_customer.return_value = SimpleNamespace( id="STRIPE_CUSTOMER_ID" ) @@ -1997,6 +2039,9 @@ async def test_valid_stripe( **expected_tax_metadata, } + assert checkout.customer is not None + assert checkout.customer.user_metadata == {"key": "value"} + assert checkout.customer_session_token is not None customer_session = await customer_session_service.get_by_token( session, checkout.customer_session_token @@ -2125,9 +2170,11 @@ async def test_valid_stripe_existing_customer( save_fixture, organization=organization, stripe_customer_id="CHECKOUT_CUSTOMER_ID", + user_metadata={"key": "value"}, ) checkout_one_time_fixed.customer = customer checkout_one_time_fixed.customer_email = customer.email + checkout_one_time_fixed.customer_metadata = {"key": "updated", "key2": "value2"} await save_fixture(checkout_one_time_fixed) stripe_service_mock.create_payment_intent.return_value = SimpleNamespace( @@ -2151,8 +2198,12 @@ async def test_valid_stripe_existing_customer( assert checkout.status == CheckoutStatus.confirmed stripe_service_mock.update_customer.assert_called_once() + assert checkout.customer is not None + assert checkout.customer.user_metadata == {"key": "updated", "key2": "value2"} + async def test_valid_stripe_existing_customer_email( self, + save_fixture: SaveFixture, stripe_service_mock: MagicMock, session: AsyncSession, locker: Locker, @@ -2160,6 +2211,11 @@ async def test_valid_stripe_existing_customer_email( checkout_one_time_fixed: Checkout, customer: Customer, ) -> None: + customer.user_metadata = {"key": "value"} + await save_fixture(customer) + + checkout_one_time_fixed.customer_metadata = {"key": "updated", "key2": "value2"} + stripe_service_mock.create_payment_intent.return_value = SimpleNamespace( client_secret="CLIENT_SECRET", status="succeeded" ) @@ -2180,7 +2236,9 @@ async def test_valid_stripe_existing_customer_email( ) assert checkout.status == CheckoutStatus.confirmed + assert checkout.customer is not None assert checkout.customer == customer + assert checkout.customer.user_metadata == {"key": "updated", "key2": "value2"} stripe_service_mock.update_customer.assert_called_once() @pytest.mark.auth(AuthSubjectFixture(subject="user_second")) diff --git a/server/tests/fixtures/random_objects.py b/server/tests/fixtures/random_objects.py index 2fa79a6f05..e728b8fa4d 100644 --- a/server/tests/fixtures/random_objects.py +++ b/server/tests/fixtures/random_objects.py @@ -841,6 +841,7 @@ async def create_customer( email_verified: bool = False, name: str = "Customer", stripe_customer_id: str = "STRIPE_CUSTOMER_ID", + user_metadata: dict[str, Any] = {}, ) -> Customer: customer = Customer( email=email, @@ -848,6 +849,7 @@ async def create_customer( name=name, stripe_customer_id=stripe_customer_id, organization=organization, + user_metadata=user_metadata, ) await save_fixture(customer) return customer @@ -1130,6 +1132,7 @@ async def create_checkout( expires_at: datetime | None = None, client_secret: str | None = None, user_metadata: dict[str, Any] = {}, + customer_metadata: dict[str, Any] = {}, payment_processor_metadata: dict[str, Any] = {}, amount: int | None = None, tax_amount: int | None = None, @@ -1156,6 +1159,7 @@ async def create_checkout( "CHECKOUT_CLIENT_SECRET", ), user_metadata=user_metadata, + customer_metadata=customer_metadata, payment_processor_metadata=payment_processor_metadata, amount=amount, tax_amount=tax_amount,