Skip to content

Commit

Permalink
feat: bulk pricing (#715)
Browse files Browse the repository at this point in the history
Co-authored-by: Soham Parekh <[email protected]>
Co-authored-by: mnida <[email protected]>
Co-authored-by: Diego Escobedo <[email protected]>
  • Loading branch information
4 people authored Mar 26, 2023
1 parent c3492f2 commit 01b39a5
Show file tree
Hide file tree
Showing 10 changed files with 358 additions and 62 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Generated by Django 4.0.5 on 2023-03-26 00:33

from django.db import migrations, models
import django.db.models.expressions


class Migration(migrations.Migration):

dependencies = [
('metering_billing', '0241_alter_historicalorganization_payment_grace_period_and_more'),
]

operations = [
migrations.AddField(
model_name='plancomponent',
name='bulk_pricing_enabled',
field=models.BooleanField(default=False),
),
migrations.AddConstraint(
model_name='pricetier',
constraint=models.CheckConstraint(check=models.Q(('range_end__gte', django.db.models.expressions.F('range_start')), ('range_end__isnull', True), _connector='OR'), name='price_tier_type_valid'),
),
migrations.AddConstraint(
model_name='pricetier',
constraint=models.UniqueConstraint(fields=('organization', 'plan_component', 'range_start'), name='unique_price_tier'),
),
]
99 changes: 81 additions & 18 deletions backend/metering_billing/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,15 @@
from django.db.models.constraints import CheckConstraint, UniqueConstraint
from django.db.models.functions import Cast, Coalesce
from django.utils.translation import gettext_lazy as _
from rest_framework_api_key.models import AbstractAPIKey
from simple_history.models import HistoricalRecords
from svix.api import ApplicationIn, EndpointIn, EndpointSecretRotateIn, EndpointUpdate
from svix.internal.openapi_client.models.http_error import HttpError
from svix.internal.openapi_client.models.http_validation_error import (
HTTPValidationError,
)
from timezone_field import TimeZoneField

from metering_billing.exceptions.exceptions import (
ExternalConnectionFailure,
NotEditable,
Expand Down Expand Up @@ -72,14 +81,6 @@
WEBHOOK_TRIGGER_EVENTS,
)
from metering_billing.webhooks import invoice_paid_webhook, usage_alert_webhook
from rest_framework_api_key.models import AbstractAPIKey
from simple_history.models import HistoricalRecords
from svix.api import ApplicationIn, EndpointIn, EndpointSecretRotateIn, EndpointUpdate
from svix.internal.openapi_client.models.http_error import HttpError
from svix.internal.openapi_client.models.http_validation_error import (
HTTPValidationError,
)
from timezone_field import TimeZoneField

logger = logging.getLogger("django.server")
META = settings.META
Expand Down Expand Up @@ -1449,7 +1450,48 @@ class BatchRoundingType(models.IntegerChoices):
null=True,
)

def calculate_revenue(self, usage: float, prev_tier_end=False):
class Meta:
constraints = [
models.CheckConstraint(
check=Q(range_end__gte=F("range_start")) | Q(range_end__isnull=True),
name="price_tier_type_valid",
),
models.UniqueConstraint(
fields=["organization", "plan_component", "range_start"],
name="unique_price_tier",
),
]

def save(self, *args, **kwargs):
new = self._state.adding is True
ranges = [
(tier["range_start"], tier["range_end"])
for tier in self.plan_component.tiers.order_by("range_start").values(
"range_start", "range_end"
)
]
if new:
ranges = sorted(
ranges + [(self.range_start, self.range_end)], key=lambda x: x[0]
)
for i, (start, end) in enumerate(ranges):
if i == 0:
if start != 0:
raise ValidationError("First tier must start at 0")
else:
diff = start - ranges[i - 1][1]
if diff != Decimal(0) and diff != Decimal(1):
raise ValidationError(
"Tier ranges must be continuous or separated by 1"
)
if i != len(ranges) - 1:
if end is None:
raise ValidationError("Only last tier can be open ended")
super().save(*args, **kwargs)

def calculate_revenue(
self, usage: float, prev_tier_end=False, bulk_pricing_enabled=False
):
# if division_factor is None:
# division_factor = len(usage_dict)
revenue = 0
Expand All @@ -1458,21 +1500,37 @@ def calculate_revenue(self, usage: float, prev_tier_end=False):
)
# for usage in usage_dict.values():
usage = convert_to_decimal(usage)
usage_in_range = (
self.range_start <= usage
if discontinuous_range
else self.range_start < usage or self.range_start == 0
)

if (
bulk_pricing_enabled
and self.range_end is not None
and self.range_end <= usage
):
return revenue

if bulk_pricing_enabled:
usage_in_range = self.range_start <= usage
else:
usage_in_range = (
self.range_start <= usage
if discontinuous_range
else self.range_start < usage or self.range_start == 0
)
if usage_in_range:
if self.type == PriceTier.PriceTierType.FLAT:
revenue += self.cost_per_batch
elif self.type == PriceTier.PriceTierType.PER_UNIT:
if self.range_end is not None:
return revenue

if self.type == PriceTier.PriceTierType.PER_UNIT:
if bulk_pricing_enabled:
billable_units = usage
elif self.range_end is not None:
billable_units = min(
usage - self.range_start, self.range_end - self.range_start
)
else:
billable_units = usage - self.range_start

if discontinuous_range:
billable_units += 1
billable_batches = billable_units / self.metric_units_per_batch
Expand Down Expand Up @@ -1574,6 +1632,7 @@ class IntervalLengthType(models.IntegerChoices):
related_name="component",
null=True,
)
bulk_pricing_enabled = models.BooleanField(default=False)

def __str__(self):
return str(self.billable_metric)
Expand Down Expand Up @@ -1720,10 +1779,14 @@ def tier_rating_function(self, usage_qty):
# this is for determining whether this is a continuous or discontinuous range
prev_tier_end = tiers[i - 1].range_end
tier_revenue = tier.calculate_revenue(
usage_qty, prev_tier_end=prev_tier_end
usage_qty,
prev_tier_end=prev_tier_end,
bulk_pricing_enabled=self.bulk_pricing_enabled,
)
else:
tier_revenue = tier.calculate_revenue(usage_qty)
tier_revenue = tier.calculate_revenue(
usage_qty, bulk_pricing_enabled=self.bulk_pricing_enabled
)
revenue += tier_revenue
revenue = convert_to_decimal(revenue)
return revenue
Expand Down
9 changes: 6 additions & 3 deletions backend/metering_billing/serializers/model_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -978,6 +978,7 @@ class Meta:
"reset_interval_unit",
"reset_interval_count",
"prepaid_charge",
"bulk_pricing_enabled",
)
extra_kwargs = {
"metric_id": {"required": True, "write_only": True},
Expand All @@ -987,6 +988,7 @@ class Meta:
"reset_interval_unit": {"required": False},
"reset_interval_count": {"required": False},
"prepaid_charge": {"required": False},
"bulk_pricing_enabled": {"required": False, "default": False},
}

metric_id = SlugRelatedFieldWithOrganization(
Expand Down Expand Up @@ -1019,9 +1021,10 @@ def validate(self, data):
x["range_end"] for x in tiers_sorted[:-1]
), "All tiers must have an end, last one is the only one allowed to have open end"
for i, tier in enumerate(tiers_sorted[:-1]):
assert tiers_sorted[i + 1]["range_start"] - tier[
"range_end"
] <= Decimal(1), "All tiers must be contiguous"
diff = tiers_sorted[i + 1]["range_start"] - tier["range_end"]
assert diff == Decimal(1) or diff == Decimal(
0
), "Tier ranges must be continuous or separated by 1"
except AssertionError as e:
raise serializers.ValidationError(str(e))
data["invoicing_interval_unit"] = PlanComponent.convert_length_label_to_value(
Expand Down
185 changes: 185 additions & 0 deletions backend/metering_billing/tests/test_components.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
import itertools
from datetime import timedelta
from decimal import Decimal

import pytest
from model_bakery import baker
from rest_framework.test import APIClient

from metering_billing.models import Event, Metric, PlanComponent, PriceTier
from metering_billing.utils import now_utc


@pytest.fixture
def components_test_common_setup(
generate_org_and_api_key,
add_users_to_org,
api_client_with_api_key_auth,
add_customers_to_org,
add_product_to_org,
add_plan_to_product,
add_plan_version_to_plan,
add_subscription_record_to_org,
):
def do_components_test_common_setup(*, auth_method):
setup_dict = {}
# set up organizations and api keys
org, key = generate_org_and_api_key()
org2, key2 = generate_org_and_api_key()
setup_dict = {
"org": org,
"key": key,
"org2": org2,
"key2": key2,
}
org.subscription_filter_keys = ["email"]
org.save()
# set up the client with the appropriate api key spec
if auth_method == "api_key":
client = api_client_with_api_key_auth(key)
elif auth_method == "session_auth":
client = APIClient()
(user,) = add_users_to_org(org, n=1)
client.force_authenticate(user=user)
setup_dict["user"] = user
else:
client = api_client_with_api_key_auth(key)
(user,) = add_users_to_org(org, n=1)
client.force_authenticate(user=user)
setup_dict["user"] = user
setup_dict["client"] = client
(customer,) = add_customers_to_org(org, n=1)
setup_dict["customer"] = customer
event_properties = (
{"num_characters": 350, "peak_bandwith": 65},
{"num_characters": 125, "peak_bandwith": 148},
{"num_characters": 543, "peak_bandwith": 16},
)
baker.make(
Event,
organization=org,
event_name="email_sent",
time_created=now_utc() - timedelta(days=1),
properties=itertools.cycle(event_properties),
_quantity=3,
)
metric_set = baker.make(
Metric,
billable_metric_name=itertools.cycle(
["Email Character Count", "Peak Bandwith", "Email Count"]
),
organization=org,
event_name="email_sent",
property_name=itertools.cycle(["num_characters", "peak_bandwith", ""]),
usage_aggregation_type=itertools.cycle(["sum", "max", "count"]),
_quantity=3,
)
for metric in metric_set:
metric.provision_materialized_views()
setup_dict["metrics"] = metric_set
product = add_product_to_org(org)
plan = add_plan_to_product(product)
plan_version = add_plan_version_to_plan(plan)
for i, (fmu, cpb, mupb) in enumerate(
zip([50, 0, 1], [0.01, 0.05, 2], [1, 1, 1])
):
pc = PlanComponent.objects.create(
plan_version=plan_version,
billable_metric=metric_set[i],
)
start = 0
if fmu > 0:
PriceTier.objects.create(
plan_component=pc,
type=PriceTier.PriceTierType.FREE,
range_start=0,
range_end=fmu,
)
start = fmu
PriceTier.objects.create(
plan_component=pc,
type=PriceTier.PriceTierType.PER_UNIT,
range_start=start,
cost_per_batch=cpb,
metric_units_per_batch=mupb,
)
setup_dict["billing_plan"] = plan_version
subscription_record = add_subscription_record_to_org(
org, plan_version, customer, now_utc() - timedelta(days=3)
)
setup_dict["subscription_record"] = subscription_record

return setup_dict

return do_components_test_common_setup


@pytest.mark.django_db(transaction=True)
class TestBulkPricing:
def test_bulk_pricing(self, components_test_common_setup):
setup_dict = components_test_common_setup(auth_method="api_key")
metric = setup_dict["metrics"][0]
component = setup_dict["billing_plan"].plan_components.get(
billable_metric=metric
)
assert component.tiers.count() == 2
assert component.tiers.filter(type=PriceTier.PriceTierType.FREE).count() == 1
assert (
component.tiers.filter(type=PriceTier.PriceTierType.PER_UNIT).count() == 1
)

revenue_no_bulk = component.tier_rating_function(100)
# we have 50 free, then 1 cent per unit
assert revenue_no_bulk == Decimal("0.50")

# now we convert the component to bulk pricing
component.bulk_pricing_enabled = True
component.save()
revenue_bulk = component.tier_rating_function(100)
# everything charged at 1 cent per unit
assert revenue_bulk == Decimal("1.00")

def test_bulk_pricing_edge_case(self, components_test_common_setup):
setup_dict = components_test_common_setup(auth_method="api_key")
metric = setup_dict["metrics"][0]
component = setup_dict["billing_plan"].plan_components.get(
billable_metric=metric
)
assert component.tiers.count() == 2
assert component.tiers.filter(type=PriceTier.PriceTierType.FREE).count() == 1
assert (
component.tiers.filter(type=PriceTier.PriceTierType.PER_UNIT).count() == 1
)

# first, lets add 2 more tiers to the component so we can test this edge case
old_last_pt = component.tiers.order_by("range_start").last()
old_last_pt.range_end = 100
old_last_pt.save()
PriceTier.objects.create(
plan_component=component,
type=PriceTier.PriceTierType.PER_UNIT,
range_start=100,
range_end=200,
cost_per_batch=0.05,
metric_units_per_batch=1,
)
PriceTier.objects.create(
plan_component=component,
type=PriceTier.PriceTierType.PER_UNIT,
range_start=200,
range_end=300,
cost_per_batch=0.10,
metric_units_per_batch=1,
)
# this means 0-50 free, 50-100 1 cent per unit, 100-200 5 cents per unit, 200-300 10 cents per unit

revenue_no_bulk = component.tier_rating_function(200)
# we have 50 free, then 50 at 1 cent per unit, then 100 at 5 cents per unit, for a total of 5.50
assert revenue_no_bulk == Decimal("5.50")

# now we convert the component to bulk pricing
component.bulk_pricing_enabled = True
component.save()
revenue_bulk = component.tier_rating_function(200)
# everything charged at 10 cents per unit
assert revenue_bulk == Decimal("20.00")
Loading

0 comments on commit 01b39a5

Please sign in to comment.