Skip to content

Commit

Permalink
optimize return values for frontend calls (#389)
Browse files Browse the repository at this point in the history
  • Loading branch information
diego-escobedo authored Dec 16, 2022
1 parent 7b39d6c commit 7848241
Show file tree
Hide file tree
Showing 12 changed files with 173 additions and 247 deletions.
1 change: 1 addition & 0 deletions backend/metering_billing/permissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def has_permission(self, request, view):
org = request.organization
if org is None and request.user.is_authenticated:
org = request.user.organization
request.organization = org
return org is not None

def has_object_permission(self, request, view, obj):
Expand Down
76 changes: 57 additions & 19 deletions backend/metering_billing/serializers/model_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,19 +304,17 @@ class Meta(SubscriptionCustomerSummarySerializer.Meta):
class CustomerWithRevenueSerializer(serializers.ModelSerializer):
class Meta:
model = Customer
fields = ("customer_id", "total_amount_due", "next_amount_due")
fields = (
"customer_id",
"total_amount_due",
)

total_amount_due = serializers.SerializerMethodField()
next_amount_due = serializers.SerializerMethodField()

def get_total_amount_due(self, obj) -> float:
total_amount_due = float(self.context.get("total_amount_due"))
return total_amount_due

def get_next_amount_due(self, obj) -> float:
next_amount_due = float(self.context.get("next_amount_due"))
return next_amount_due


class CustomerSerializer(serializers.ModelSerializer):
class Meta:
Expand Down Expand Up @@ -1441,6 +1439,30 @@ class Meta(SubscriptionRecordSerializer.Meta):
)


class LightweightPlanVersionSerializer(PlanVersionSerializer):
class Meta(PlanVersionSerializer.Meta):
model = PlanVersion
fields = ("plan_id", "plan_name", "version_id")

plan_name = serializers.CharField(read_only=True, source="plan.plan_name")
plan_id = serializers.CharField(read_only=True, source="plan.plan_id")


class LightweightSubscriptionRecordSerializer(SubscriptionRecordSerializer):
class Meta(SubscriptionRecordSerializer.Meta):
model = SubscriptionRecord
fields = tuple(
set(SubscriptionRecordSerializer.Meta.fields).union(set(["plan_detail"]))
)

plan_detail = LightweightPlanVersionSerializer(
source="billing_plan", read_only=True
)
subscription_filters = SubscriptionCategoricalFilterSerializer(
source="filters", many=True, read_only=True
)


class SubscriptionSerializer(serializers.ModelSerializer):
class Meta:
model = Subscription
Expand All @@ -1458,9 +1480,14 @@ class Meta:
customer = ShortCustomerSerializer(read_only=True)
plans = serializers.SerializerMethodField()

def get_plans(self, obj) -> SubscriptionRecordDetailSerializer(many=True):
sub_records = obj.get_subscription_records()
data = SubscriptionRecordDetailSerializer(sub_records, many=True).data
def get_plans(self, obj) -> LightweightSubscriptionRecordSerializer(many=True):
sub_records = obj.get_subscription_records().prefetch_related(
"billing_plan",
"filters",
"billing_plan__plan",
"billing_plan__pricing_unit",
)
data = LightweightSubscriptionRecordSerializer(sub_records, many=True).data
return data


Expand Down Expand Up @@ -1869,6 +1896,20 @@ class Meta:
line_items = InvoiceLineItemSerializer(many=True, read_only=True)


class LightweightInvoiceSerializer(InvoiceSerializer):
class Meta(InvoiceSerializer.Meta):
fields = tuple(
set(InvoiceSerializer.Meta.fields)
- set(
[
"line_items",
"customer",
"subscription",
]
)
)


class InvoiceListFilterSerializer(serializers.Serializer):
customer_id = serializers.CharField(required=False)
payment_status = serializers.MultipleChoiceField(
Expand Down Expand Up @@ -1986,7 +2027,6 @@ class Meta:
"customer_name",
"invoices",
"total_amount_due",
"next_amount_due",
"subscription",
"integrations",
"default_currency",
Expand All @@ -1995,7 +2035,6 @@ class Meta:
subscription = serializers.SerializerMethodField(allow_null=True)
invoices = serializers.SerializerMethodField()
total_amount_due = serializers.SerializerMethodField()
next_amount_due = serializers.SerializerMethodField()
default_currency = PricingUnitSerializer()

def get_subscription(self, obj) -> SubscriptionSerializer:
Expand All @@ -2005,15 +2044,14 @@ def get_subscription(self, obj) -> SubscriptionSerializer:
else:
return SubscriptionSerializer(sub_obj).data

def get_invoices(self, obj) -> InvoiceSerializer(many=True):
timeline = self.context.get("invoices")
timeline = InvoiceSerializer(timeline, many=True).data
def get_invoices(self, obj) -> LightweightInvoiceSerializer(many=True):
timeline = obj.invoices.filter(
~Q(payment_status=INVOICE_STATUS.DRAFT),
organization=self.context.get("organization"),
).order_by("-issue_date")
timeline = LightweightInvoiceSerializer(timeline, many=True).data
return timeline

def get_total_amount_due(self, obj) -> float:
total_amount_due = float(self.context.get("total_amount_due"))
total_amount_due = float(obj.get_outstanding_revenue())
return total_amount_due

def get_next_amount_due(self, obj) -> float:
next_amount_due = float(self.context.get("next_amount_due"))
return next_amount_due
46 changes: 15 additions & 31 deletions backend/metering_billing/tests/test_billable_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
PlanVersion,
PriceTier,
SubscriptionRecord,
User,
)
from metering_billing.utils import now_utc
from metering_billing.utils.enums import (
Expand All @@ -30,6 +31,7 @@
METRIC_TYPE,
NUMERIC_FILTER_OPERATORS,
PLAN_DURATION,
PLAN_STATUS,
PLAN_VERSION_STATUS,
PRICE_TIER_TYPE,
USAGE_CALC_GRANULARITY,
Expand Down Expand Up @@ -158,34 +160,6 @@ def test_session_auth_can_create_billable_metric_nonempty_before(
== num_billable_metrics + 1
)

def test_user_org_and_api_key_different_reject_creation(
self,
billable_metric_test_common_setup,
insert_billable_metric_payload,
get_billable_metrics_in_org,
):
# covers user_org_and_api_key_org_different = True
num_billable_metrics = 3
setup_dict = billable_metric_test_common_setup(
num_billable_metrics=num_billable_metrics,
auth_method="both",
user_org_and_api_key_org_different=True,
)

response = setup_dict["client"].post(
reverse("metric-list"),
data=json.dumps(insert_billable_metric_payload, cls=DjangoJSONEncoder),
content_type="application/json",
)

assert response.status_code == status.HTTP_401_UNAUTHORIZED
assert (
len(get_billable_metrics_in_org(setup_dict["org"])) == num_billable_metrics
)
assert (
len(get_billable_metrics_in_org(setup_dict["org2"])) == num_billable_metrics
)

def test_billable_metric_exists_reject_creation(
self,
billable_metric_test_common_setup,
Expand Down Expand Up @@ -261,6 +235,7 @@ def test_cant_archive_with_active_plan_version(
description="test_plan for testing",
flat_rate=30.0,
plan=plan,
status=PLAN_VERSION_STATUS.ACTIVE,
)
plan.display_version = billing_plan
plan.save()
Expand Down Expand Up @@ -290,6 +265,15 @@ def test_cant_archive_with_active_plan_version(
setup_dict["billing_plan"] = billing_plan

payload = {"status": METRIC_STATUS.ARCHIVED}
assert billing_plan.status == PLAN_VERSION_STATUS.ACTIVE
assert billing_plan.plan.status == PLAN_STATUS.ACTIVE
assert billing_plan.plan_components.count() == 3
all_pcs = billing_plan.plan_components.all()
assert (
all_pcs[0].billable_metric == metric_set[0]
or all_pcs[1].billable_metric == metric_set[0]
or all_pcs[2].billable_metric == metric_set[0]
)
response = setup_dict["client"].patch(
reverse("metric-detail", kwargs={"metric_id": metric_set[0].metric_id}),
data=json.dumps(payload, cls=DjangoJSONEncoder),
Expand Down Expand Up @@ -1310,7 +1294,7 @@ def test_proration_and_metric_granularity_sub_day(
num_billable_metrics = 0
setup_dict = billable_metric_test_common_setup(
num_billable_metrics=num_billable_metrics,
auth_method="session_auth",
auth_method="api_key",
user_org_and_api_key_org_different=False,
)
billable_metric = Metric.objects.create(
Expand Down Expand Up @@ -1424,7 +1408,7 @@ def test_metric_granularity_daily_proration_smaller_than_day(
num_billable_metrics = 0
setup_dict = billable_metric_test_common_setup(
num_billable_metrics=num_billable_metrics,
auth_method="session_auth",
auth_method="api_key",
user_org_and_api_key_org_different=False,
)
billable_metric = Metric.objects.create(
Expand Down Expand Up @@ -1531,7 +1515,7 @@ def test_metric_granularity_greater_than_daily_proration_smaller_than_day(
num_billable_metrics = 0
setup_dict = billable_metric_test_common_setup(
num_billable_metrics=num_billable_metrics,
auth_method="session_auth",
auth_method="api_key",
user_org_and_api_key_org_different=False,
)
plan = setup_dict["plan"]
Expand Down
37 changes: 0 additions & 37 deletions backend/metering_billing/tests/test_customer.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,22 +97,6 @@ def test_session_auth_can_access_customers_multiple(
assert response.status_code == status.HTTP_200_OK
assert len(response.data) == num_customers

def test_user_org_and_api_key_different_reject_access(
self, customer_test_common_setup
):
# covers user_org_and_api_key_org_different = true
num_customers = 3
setup_dict = customer_test_common_setup(
num_customers=num_customers,
auth_method="both",
user_org_and_api_key_org_different=True,
)

payload = {}
response = setup_dict["client"].get(reverse("customer-list"), payload)

assert response.status_code == status.HTTP_401_UNAUTHORIZED


@pytest.fixture
def insert_customer_payload():
Expand Down Expand Up @@ -183,27 +167,6 @@ def test_session_auth_can_create_customer_nonempty_before(
assert len(response.data) > 0
assert len(get_customers_in_org(setup_dict["org"])) == num_customers + 1

def test_user_org_and_api_key_different_reject_creation(
self, customer_test_common_setup, insert_customer_payload, get_customers_in_org
):
# covers user_org_and_api_key_org_different = True
num_customers = 3
setup_dict = customer_test_common_setup(
num_customers=num_customers,
auth_method="both",
user_org_and_api_key_org_different=True,
)

response = setup_dict["client"].post(
reverse("customer-list"),
data=json.dumps(insert_customer_payload, cls=DjangoJSONEncoder),
content_type="application/json",
)

assert response.status_code == status.HTTP_401_UNAUTHORIZED
assert len(get_customers_in_org(setup_dict["org"])) == num_customers
assert len(get_customers_in_org(setup_dict["org2"])) == num_customers

def test_customer_id_already_exists_within_org_reject_creation(
self, customer_test_common_setup, insert_customer_payload, get_customers_in_org
):
Expand Down
6 changes: 3 additions & 3 deletions backend/metering_billing/tests/test_draft_invoices.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def do_draft_invoice_test_common_setup(*, auth_method):
@pytest.mark.django_db(transaction=True)
class TestGenerateInvoice:
def test_generate_invoice(self, draft_invoice_test_common_setup):
setup_dict = draft_invoice_test_common_setup(auth_method="session_auth")
setup_dict = draft_invoice_test_common_setup(auth_method="api_key")

active_subscriptions = Subscription.objects.active().filter(
organization=setup_dict["org"],
Expand All @@ -140,7 +140,7 @@ def test_generate_invoice(self, draft_invoice_test_common_setup):
).count()
payload = {"customer_id": setup_dict["customer"].customer_id}
response = setup_dict["client"].get(reverse("draft_invoice"), payload)

print(response.data)
assert response.status_code == status.HTTP_200_OK
after_active_subscriptions = Subscription.objects.active().filter(
organization=setup_dict["org"],
Expand All @@ -156,7 +156,7 @@ def test_generate_invoice(self, draft_invoice_test_common_setup):
def test_generate_invoice_with_price_adjustments(
self, draft_invoice_test_common_setup
):
setup_dict = draft_invoice_test_common_setup(auth_method="session_auth")
setup_dict = draft_invoice_test_common_setup(auth_method="api_key")

active_subscriptions = Subscription.objects.active().filter(
organization=setup_dict["org"],
Expand Down
8 changes: 3 additions & 5 deletions backend/metering_billing/tests/test_integrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def test_stripe_end_to_end(self, integration_test_common_setup):
)

# now lets test out the replace now
Subscription.objects.filter(
SubscriptionRecord.objects.filter(
organization=setup_dict["org"],
start_date__gte=now_utc(),
billing_plan=setup_dict["plan"].display_version,
Expand All @@ -237,12 +237,10 @@ def test_stripe_end_to_end(self, integration_test_common_setup):
time.sleep(10)
stripe_sub = stripe.Subscription.retrieve(stripe_sub.id)
assert (
Subscription.objects.active()
.filter(
SubscriptionRecord.objects.filter(
organization=setup_dict["org"],
billing_plan=setup_dict["plan"].display_version,
)
.count()
).count()
== 1
)
assert stripe_sub.status == "canceled"
Expand Down
20 changes: 0 additions & 20 deletions backend/metering_billing/tests/test_subscription.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,26 +202,6 @@ def test_session_auth_can_create_subscription_nonempty_before(
== num_subscription_records_before + 1
)

def test_user_org_and_api_key_different_reject_creation(
self, subscription_test_common_setup, get_subscriptions_in_org
):
# covers user_org_and_api_key_org_different = True
num_subscriptions = 1
setup_dict = subscription_test_common_setup(
num_subscriptions=num_subscriptions,
auth_method="both",
user_org_and_api_key_org_different=True,
)

response = setup_dict["client"].post(
reverse("subscription-plans"),
data=json.dumps(setup_dict["payload"], cls=DjangoJSONEncoder),
content_type="application/json",
)

assert response.status_code == status.HTTP_401_UNAUTHORIZED
assert len(get_subscriptions_in_org(setup_dict["org"])) == num_subscriptions


@pytest.mark.django_db(transaction=True)
class TestUpdateSub:
Expand Down
Loading

0 comments on commit 7848241

Please sign in to comment.