diff --git a/backend/metering_billing/permissions.py b/backend/metering_billing/permissions.py
index 331bdb24e..5875776c8 100644
--- a/backend/metering_billing/permissions.py
+++ b/backend/metering_billing/permissions.py
@@ -28,6 +28,7 @@ def has_permission(self, request, view):
org = request.organization
if org is None and request.user.is_authenticated:
org = request.user.organization
+ request.organization = org
return org is not None
def has_object_permission(self, request, view, obj):
diff --git a/backend/metering_billing/serializers/model_serializers.py b/backend/metering_billing/serializers/model_serializers.py
index 6eee10dcc..86a5998a4 100644
--- a/backend/metering_billing/serializers/model_serializers.py
+++ b/backend/metering_billing/serializers/model_serializers.py
@@ -304,19 +304,17 @@ class Meta(SubscriptionCustomerSummarySerializer.Meta):
class CustomerWithRevenueSerializer(serializers.ModelSerializer):
class Meta:
model = Customer
- fields = ("customer_id", "total_amount_due", "next_amount_due")
+ fields = (
+ "customer_id",
+ "total_amount_due",
+ )
total_amount_due = serializers.SerializerMethodField()
- next_amount_due = serializers.SerializerMethodField()
def get_total_amount_due(self, obj) -> float:
total_amount_due = float(self.context.get("total_amount_due"))
return total_amount_due
- def get_next_amount_due(self, obj) -> float:
- next_amount_due = float(self.context.get("next_amount_due"))
- return next_amount_due
-
class CustomerSerializer(serializers.ModelSerializer):
class Meta:
@@ -1441,6 +1439,30 @@ class Meta(SubscriptionRecordSerializer.Meta):
)
+class LightweightPlanVersionSerializer(PlanVersionSerializer):
+ class Meta(PlanVersionSerializer.Meta):
+ model = PlanVersion
+ fields = ("plan_id", "plan_name", "version_id")
+
+ plan_name = serializers.CharField(read_only=True, source="plan.plan_name")
+ plan_id = serializers.CharField(read_only=True, source="plan.plan_id")
+
+
+class LightweightSubscriptionRecordSerializer(SubscriptionRecordSerializer):
+ class Meta(SubscriptionRecordSerializer.Meta):
+ model = SubscriptionRecord
+ fields = tuple(
+ set(SubscriptionRecordSerializer.Meta.fields).union(set(["plan_detail"]))
+ )
+
+ plan_detail = LightweightPlanVersionSerializer(
+ source="billing_plan", read_only=True
+ )
+ subscription_filters = SubscriptionCategoricalFilterSerializer(
+ source="filters", many=True, read_only=True
+ )
+
+
class SubscriptionSerializer(serializers.ModelSerializer):
class Meta:
model = Subscription
@@ -1458,9 +1480,14 @@ class Meta:
customer = ShortCustomerSerializer(read_only=True)
plans = serializers.SerializerMethodField()
- def get_plans(self, obj) -> SubscriptionRecordDetailSerializer(many=True):
- sub_records = obj.get_subscription_records()
- data = SubscriptionRecordDetailSerializer(sub_records, many=True).data
+ def get_plans(self, obj) -> LightweightSubscriptionRecordSerializer(many=True):
+ sub_records = obj.get_subscription_records().prefetch_related(
+ "billing_plan",
+ "filters",
+ "billing_plan__plan",
+ "billing_plan__pricing_unit",
+ )
+ data = LightweightSubscriptionRecordSerializer(sub_records, many=True).data
return data
@@ -1869,6 +1896,20 @@ class Meta:
line_items = InvoiceLineItemSerializer(many=True, read_only=True)
+class LightweightInvoiceSerializer(InvoiceSerializer):
+ class Meta(InvoiceSerializer.Meta):
+ fields = tuple(
+ set(InvoiceSerializer.Meta.fields)
+ - set(
+ [
+ "line_items",
+ "customer",
+ "subscription",
+ ]
+ )
+ )
+
+
class InvoiceListFilterSerializer(serializers.Serializer):
customer_id = serializers.CharField(required=False)
payment_status = serializers.MultipleChoiceField(
@@ -1986,7 +2027,6 @@ class Meta:
"customer_name",
"invoices",
"total_amount_due",
- "next_amount_due",
"subscription",
"integrations",
"default_currency",
@@ -1995,7 +2035,6 @@ class Meta:
subscription = serializers.SerializerMethodField(allow_null=True)
invoices = serializers.SerializerMethodField()
total_amount_due = serializers.SerializerMethodField()
- next_amount_due = serializers.SerializerMethodField()
default_currency = PricingUnitSerializer()
def get_subscription(self, obj) -> SubscriptionSerializer:
@@ -2005,15 +2044,14 @@ def get_subscription(self, obj) -> SubscriptionSerializer:
else:
return SubscriptionSerializer(sub_obj).data
- def get_invoices(self, obj) -> InvoiceSerializer(many=True):
- timeline = self.context.get("invoices")
- timeline = InvoiceSerializer(timeline, many=True).data
+ def get_invoices(self, obj) -> LightweightInvoiceSerializer(many=True):
+ timeline = obj.invoices.filter(
+ ~Q(payment_status=INVOICE_STATUS.DRAFT),
+ organization=self.context.get("organization"),
+ ).order_by("-issue_date")
+ timeline = LightweightInvoiceSerializer(timeline, many=True).data
return timeline
def get_total_amount_due(self, obj) -> float:
- total_amount_due = float(self.context.get("total_amount_due"))
+ total_amount_due = float(obj.get_outstanding_revenue())
return total_amount_due
-
- def get_next_amount_due(self, obj) -> float:
- next_amount_due = float(self.context.get("next_amount_due"))
- return next_amount_due
diff --git a/backend/metering_billing/tests/test_billable_metric.py b/backend/metering_billing/tests/test_billable_metric.py
index 289de81fa..b93e0bc82 100644
--- a/backend/metering_billing/tests/test_billable_metric.py
+++ b/backend/metering_billing/tests/test_billable_metric.py
@@ -19,6 +19,7 @@
PlanVersion,
PriceTier,
SubscriptionRecord,
+ User,
)
from metering_billing.utils import now_utc
from metering_billing.utils.enums import (
@@ -30,6 +31,7 @@
METRIC_TYPE,
NUMERIC_FILTER_OPERATORS,
PLAN_DURATION,
+ PLAN_STATUS,
PLAN_VERSION_STATUS,
PRICE_TIER_TYPE,
USAGE_CALC_GRANULARITY,
@@ -158,34 +160,6 @@ def test_session_auth_can_create_billable_metric_nonempty_before(
== num_billable_metrics + 1
)
- def test_user_org_and_api_key_different_reject_creation(
- self,
- billable_metric_test_common_setup,
- insert_billable_metric_payload,
- get_billable_metrics_in_org,
- ):
- # covers user_org_and_api_key_org_different = True
- num_billable_metrics = 3
- setup_dict = billable_metric_test_common_setup(
- num_billable_metrics=num_billable_metrics,
- auth_method="both",
- user_org_and_api_key_org_different=True,
- )
-
- response = setup_dict["client"].post(
- reverse("metric-list"),
- data=json.dumps(insert_billable_metric_payload, cls=DjangoJSONEncoder),
- content_type="application/json",
- )
-
- assert response.status_code == status.HTTP_401_UNAUTHORIZED
- assert (
- len(get_billable_metrics_in_org(setup_dict["org"])) == num_billable_metrics
- )
- assert (
- len(get_billable_metrics_in_org(setup_dict["org2"])) == num_billable_metrics
- )
-
def test_billable_metric_exists_reject_creation(
self,
billable_metric_test_common_setup,
@@ -261,6 +235,7 @@ def test_cant_archive_with_active_plan_version(
description="test_plan for testing",
flat_rate=30.0,
plan=plan,
+ status=PLAN_VERSION_STATUS.ACTIVE,
)
plan.display_version = billing_plan
plan.save()
@@ -290,6 +265,15 @@ def test_cant_archive_with_active_plan_version(
setup_dict["billing_plan"] = billing_plan
payload = {"status": METRIC_STATUS.ARCHIVED}
+ assert billing_plan.status == PLAN_VERSION_STATUS.ACTIVE
+ assert billing_plan.plan.status == PLAN_STATUS.ACTIVE
+ assert billing_plan.plan_components.count() == 3
+ all_pcs = billing_plan.plan_components.all()
+ assert (
+ all_pcs[0].billable_metric == metric_set[0]
+ or all_pcs[1].billable_metric == metric_set[0]
+ or all_pcs[2].billable_metric == metric_set[0]
+ )
response = setup_dict["client"].patch(
reverse("metric-detail", kwargs={"metric_id": metric_set[0].metric_id}),
data=json.dumps(payload, cls=DjangoJSONEncoder),
@@ -1310,7 +1294,7 @@ def test_proration_and_metric_granularity_sub_day(
num_billable_metrics = 0
setup_dict = billable_metric_test_common_setup(
num_billable_metrics=num_billable_metrics,
- auth_method="session_auth",
+ auth_method="api_key",
user_org_and_api_key_org_different=False,
)
billable_metric = Metric.objects.create(
@@ -1424,7 +1408,7 @@ def test_metric_granularity_daily_proration_smaller_than_day(
num_billable_metrics = 0
setup_dict = billable_metric_test_common_setup(
num_billable_metrics=num_billable_metrics,
- auth_method="session_auth",
+ auth_method="api_key",
user_org_and_api_key_org_different=False,
)
billable_metric = Metric.objects.create(
@@ -1531,7 +1515,7 @@ def test_metric_granularity_greater_than_daily_proration_smaller_than_day(
num_billable_metrics = 0
setup_dict = billable_metric_test_common_setup(
num_billable_metrics=num_billable_metrics,
- auth_method="session_auth",
+ auth_method="api_key",
user_org_and_api_key_org_different=False,
)
plan = setup_dict["plan"]
diff --git a/backend/metering_billing/tests/test_customer.py b/backend/metering_billing/tests/test_customer.py
index 8e34d0761..47baffa56 100644
--- a/backend/metering_billing/tests/test_customer.py
+++ b/backend/metering_billing/tests/test_customer.py
@@ -97,22 +97,6 @@ def test_session_auth_can_access_customers_multiple(
assert response.status_code == status.HTTP_200_OK
assert len(response.data) == num_customers
- def test_user_org_and_api_key_different_reject_access(
- self, customer_test_common_setup
- ):
- # covers user_org_and_api_key_org_different = true
- num_customers = 3
- setup_dict = customer_test_common_setup(
- num_customers=num_customers,
- auth_method="both",
- user_org_and_api_key_org_different=True,
- )
-
- payload = {}
- response = setup_dict["client"].get(reverse("customer-list"), payload)
-
- assert response.status_code == status.HTTP_401_UNAUTHORIZED
-
@pytest.fixture
def insert_customer_payload():
@@ -183,27 +167,6 @@ def test_session_auth_can_create_customer_nonempty_before(
assert len(response.data) > 0
assert len(get_customers_in_org(setup_dict["org"])) == num_customers + 1
- def test_user_org_and_api_key_different_reject_creation(
- self, customer_test_common_setup, insert_customer_payload, get_customers_in_org
- ):
- # covers user_org_and_api_key_org_different = True
- num_customers = 3
- setup_dict = customer_test_common_setup(
- num_customers=num_customers,
- auth_method="both",
- user_org_and_api_key_org_different=True,
- )
-
- response = setup_dict["client"].post(
- reverse("customer-list"),
- data=json.dumps(insert_customer_payload, cls=DjangoJSONEncoder),
- content_type="application/json",
- )
-
- assert response.status_code == status.HTTP_401_UNAUTHORIZED
- assert len(get_customers_in_org(setup_dict["org"])) == num_customers
- assert len(get_customers_in_org(setup_dict["org2"])) == num_customers
-
def test_customer_id_already_exists_within_org_reject_creation(
self, customer_test_common_setup, insert_customer_payload, get_customers_in_org
):
diff --git a/backend/metering_billing/tests/test_draft_invoices.py b/backend/metering_billing/tests/test_draft_invoices.py
index d154e1aa9..2583f22d8 100644
--- a/backend/metering_billing/tests/test_draft_invoices.py
+++ b/backend/metering_billing/tests/test_draft_invoices.py
@@ -127,7 +127,7 @@ def do_draft_invoice_test_common_setup(*, auth_method):
@pytest.mark.django_db(transaction=True)
class TestGenerateInvoice:
def test_generate_invoice(self, draft_invoice_test_common_setup):
- setup_dict = draft_invoice_test_common_setup(auth_method="session_auth")
+ setup_dict = draft_invoice_test_common_setup(auth_method="api_key")
active_subscriptions = Subscription.objects.active().filter(
organization=setup_dict["org"],
@@ -140,7 +140,7 @@ def test_generate_invoice(self, draft_invoice_test_common_setup):
).count()
payload = {"customer_id": setup_dict["customer"].customer_id}
response = setup_dict["client"].get(reverse("draft_invoice"), payload)
-
+ print(response.data)
assert response.status_code == status.HTTP_200_OK
after_active_subscriptions = Subscription.objects.active().filter(
organization=setup_dict["org"],
@@ -156,7 +156,7 @@ def test_generate_invoice(self, draft_invoice_test_common_setup):
def test_generate_invoice_with_price_adjustments(
self, draft_invoice_test_common_setup
):
- setup_dict = draft_invoice_test_common_setup(auth_method="session_auth")
+ setup_dict = draft_invoice_test_common_setup(auth_method="api_key")
active_subscriptions = Subscription.objects.active().filter(
organization=setup_dict["org"],
diff --git a/backend/metering_billing/tests/test_integrations.py b/backend/metering_billing/tests/test_integrations.py
index 7b8533e84..e89ee0e2c 100644
--- a/backend/metering_billing/tests/test_integrations.py
+++ b/backend/metering_billing/tests/test_integrations.py
@@ -228,7 +228,7 @@ def test_stripe_end_to_end(self, integration_test_common_setup):
)
# now lets test out the replace now
- Subscription.objects.filter(
+ SubscriptionRecord.objects.filter(
organization=setup_dict["org"],
start_date__gte=now_utc(),
billing_plan=setup_dict["plan"].display_version,
@@ -237,12 +237,10 @@ def test_stripe_end_to_end(self, integration_test_common_setup):
time.sleep(10)
stripe_sub = stripe.Subscription.retrieve(stripe_sub.id)
assert (
- Subscription.objects.active()
- .filter(
+ SubscriptionRecord.objects.filter(
organization=setup_dict["org"],
billing_plan=setup_dict["plan"].display_version,
- )
- .count()
+ ).count()
== 1
)
assert stripe_sub.status == "canceled"
diff --git a/backend/metering_billing/tests/test_subscription.py b/backend/metering_billing/tests/test_subscription.py
index 891159024..e6603bcb4 100644
--- a/backend/metering_billing/tests/test_subscription.py
+++ b/backend/metering_billing/tests/test_subscription.py
@@ -202,26 +202,6 @@ def test_session_auth_can_create_subscription_nonempty_before(
== num_subscription_records_before + 1
)
- def test_user_org_and_api_key_different_reject_creation(
- self, subscription_test_common_setup, get_subscriptions_in_org
- ):
- # covers user_org_and_api_key_org_different = True
- num_subscriptions = 1
- setup_dict = subscription_test_common_setup(
- num_subscriptions=num_subscriptions,
- auth_method="both",
- user_org_and_api_key_org_different=True,
- )
-
- response = setup_dict["client"].post(
- reverse("subscription-plans"),
- data=json.dumps(setup_dict["payload"], cls=DjangoJSONEncoder),
- content_type="application/json",
- )
-
- assert response.status_code == status.HTTP_401_UNAUTHORIZED
- assert len(get_subscriptions_in_org(setup_dict["org"])) == num_subscriptions
-
@pytest.mark.django_db(transaction=True)
class TestUpdateSub:
diff --git a/backend/metering_billing/views/model_views.py b/backend/metering_billing/views/model_views.py
index 3f11a9434..ab4e742ee 100644
--- a/backend/metering_billing/views/model_views.py
+++ b/backend/metering_billing/views/model_views.py
@@ -140,17 +140,17 @@ class APITokenViewSet(
lookup_field = "prefix"
def get_queryset(self):
- organization = parse_organization(self.request)
+ organization = self.request.organization
return APIToken.objects.filter(organization=organization)
def get_serializer_context(self):
context = super(APITokenViewSet, self).get_serializer_context()
- organization = parse_organization(self.request)
+ organization = self.request.organization
context.update({"organization": organization})
return context
def perform_create(self, serializer):
- organization = parse_organization(self.request)
+ organization = self.request.organization
api_key, key = serializer.save(organization=organization)
return api_key, key
@@ -236,18 +236,18 @@ class WebhookViewSet(PermissionPolicyMixin, viewsets.ModelViewSet):
}
def get_queryset(self):
- organization = parse_organization(self.request)
+ organization = self.request.organization
return WebhookEndpoint.objects.filter(organization=organization)
def get_serializer_context(self):
context = super(WebhookViewSet, self).get_serializer_context()
- organization = parse_organization(self.request)
+ organization = self.request.organization
context.update({"organization": organization})
return context
def perform_create(self, serializer):
try:
- serializer.save(organization=parse_organization(self.request))
+ serializer.save(organization=self.request.organization)
except ValueError as e:
raise ServerError(e)
except IntegrityError as e:
@@ -269,7 +269,7 @@ def dispatch(self, request, *args, **kwargs):
username = self.request.user.username
except:
username = None
- organization = parse_organization(self.request)
+ organization = self.request.organization
posthog.capture(
POSTHOG_PERSON
if POSTHOG_PERSON
@@ -307,7 +307,7 @@ class EventViewSet(
def get_queryset(self):
now = now_utc()
- organization = parse_organization(self.request)
+ organization = self.request.organization
return (
super()
.get_queryset()
@@ -316,7 +316,7 @@ def get_queryset(self):
def get_serializer_context(self):
context = super(EventViewSet, self).get_serializer_context()
- organization = parse_organization(self.request)
+ organization = self.request.organization
context.update({"organization": organization})
return context
@@ -331,17 +331,17 @@ class UserViewSet(PermissionPolicyMixin, viewsets.ModelViewSet):
http_method_names = ["get", "post", "head"]
def get_queryset(self):
- organization = parse_organization(self.request)
+ organization = self.request.organization
return User.objects.filter(organization=organization)
def get_serializer_context(self):
context = super(UserViewSet, self).get_serializer_context()
- organization = parse_organization(self.request)
+ organization = self.request.organization
context.update({"organization": organization})
return context
def perform_create(self, serializer):
- serializer.save(organization=parse_organization(self.request))
+ serializer.save(organization=self.request.organization)
class CustomerViewSet(PermissionPolicyMixin, viewsets.ModelViewSet):
@@ -353,23 +353,10 @@ class CustomerViewSet(PermissionPolicyMixin, viewsets.ModelViewSet):
http_method_names = ["get", "post", "head", "patch"]
def get_queryset(self):
- organization = parse_organization(self.request)
+ organization = self.request.organization
qs = Customer.objects.filter(organization=organization)
if self.action == "retrieve":
- qs = qs.prefetch_related(
- Prefetch(
- "subscription_records",
- queryset=SubscriptionRecord.objects.filter(
- organization=organization,
- status=SUBSCRIPTION_STATUS.ACTIVE,
- ),
- ),
- Prefetch(
- "subscription_records__billing_plan",
- queryset=PlanVersion.objects.filter(organization=organization),
- to_attr="billing_plans",
- ),
- )
+ qs = qs.prefetch_related("subscriptions", "invoices")
return qs
def get_serializer_class(self):
@@ -381,7 +368,7 @@ def get_serializer_class(self):
def perform_create(self, serializer):
try:
- serializer.save(organization=parse_organization(self.request))
+ serializer.save(organization=self.request.organization)
except IntegrityError as e:
cause = e.__cause__
if "unique_email" in str(cause):
@@ -392,24 +379,7 @@ def perform_create(self, serializer):
def get_serializer_context(self):
context = super(CustomerViewSet, self).get_serializer_context()
- organization = parse_organization(self.request)
- context.update({"organization": organization})
- if self.action == "retrieve":
- customer = self.get_object()
- total_amount_due = customer.get_outstanding_revenue()
- next_amount_due = customer.get_active_sub_drafts_revenue()
- invoices = Invoice.objects.filter(
- ~Q(payment_status=INVOICE_STATUS.DRAFT),
- organization=organization,
- customer=customer,
- ).order_by("-issue_date")
- context.update(
- {
- "total_amount_due": total_amount_due,
- "invoices": invoices,
- "next_amount_due": next_amount_due,
- }
- )
+ context.update({"organization": self.request.organization})
return context
def dispatch(self, request, *args, **kwargs):
@@ -419,7 +389,7 @@ def dispatch(self, request, *args, **kwargs):
username = self.request.user.username
except:
username = None
- organization = parse_organization(self.request)
+ organization = self.request.organization or self.request.user.organization
posthog.capture(
POSTHOG_PERSON
if POSTHOG_PERSON
@@ -445,7 +415,7 @@ class MetricViewSet(PermissionPolicyMixin, viewsets.ModelViewSet):
}
def get_queryset(self):
- organization = parse_organization(self.request)
+ organization = self.request.organization
return Metric.objects.filter(
organization=organization, status=METRIC_STATUS.ACTIVE
)
@@ -457,7 +427,7 @@ def get_serializer_class(self):
def get_serializer_context(self):
context = super(MetricViewSet, self).get_serializer_context()
- organization = parse_organization(self.request)
+ organization = self.request.organization
context.update({"organization": organization})
return context
@@ -468,7 +438,7 @@ def dispatch(self, request, *args, **kwargs):
username = self.request.user.username
except:
username = None
- organization = parse_organization(self.request)
+ organization = self.request.organization
posthog.capture(
POSTHOG_PERSON
if POSTHOG_PERSON
@@ -482,7 +452,7 @@ def dispatch(self, request, *args, **kwargs):
def perform_create(self, serializer):
try:
- instance = serializer.save(organization=parse_organization(self.request))
+ instance = serializer.save(organization=self.request.organization)
except IntegrityError as e:
cause = e.__cause__
if "unique_org_metric_id" in str(cause):
@@ -511,12 +481,12 @@ class FeatureViewSet(PermissionPolicyMixin, viewsets.ModelViewSet):
}
def get_queryset(self):
- organization = parse_organization(self.request)
+ organization = self.request.organization
return Feature.objects.filter(organization=organization)
def get_serializer_context(self):
context = super(FeatureViewSet, self).get_serializer_context()
- organization = parse_organization(self.request)
+ organization = self.request.organization
context.update({"organization": organization})
return context
@@ -527,7 +497,7 @@ def dispatch(self, request, *args, **kwargs):
username = self.request.user.username
except:
username = None
- organization = parse_organization(self.request)
+ organization = self.request.organization
posthog.capture(
POSTHOG_PERSON
if POSTHOG_PERSON
@@ -540,7 +510,7 @@ def dispatch(self, request, *args, **kwargs):
return response
def perform_create(self, serializer):
- serializer.save(organization=parse_organization(self.request))
+ serializer.save(organization=self.request.organization)
class PlanVersionViewSet(PermissionPolicyMixin, viewsets.ModelViewSet):
@@ -566,7 +536,7 @@ def get_serializer_class(self):
return PlanVersionSerializer
def get_queryset(self):
- organization = parse_organization(self.request)
+ organization = self.request.organization
qs = PlanVersion.objects.filter(
organization=organization,
)
@@ -574,7 +544,7 @@ def get_queryset(self):
def get_serializer_context(self):
context = super(PlanVersionViewSet, self).get_serializer_context()
- organization = parse_organization(self.request)
+ organization = self.request.organization
if self.request.user.is_authenticated:
user = self.request.user
else:
@@ -589,7 +559,7 @@ def dispatch(self, request, *args, **kwargs):
username = self.request.user.username
except:
username = None
- organization = parse_organization(self.request)
+ organization = self.request.organization
posthog.capture(
POSTHOG_PERSON
if POSTHOG_PERSON
@@ -607,7 +577,7 @@ def perform_create(self, serializer):
else:
user = None
instance = serializer.save(
- organization=parse_organization(self.request), created_by=user
+ organization=self.request.organization, created_by=user
)
# if user:
# action.send(
@@ -655,7 +625,7 @@ class PlanViewSet(PermissionPolicyMixin, viewsets.ModelViewSet):
}
def get_queryset(self):
- organization = parse_organization(self.request)
+ organization = self.request.organization
qs = Plan.objects.filter(organization=organization, status=PLAN_STATUS.ACTIVE)
if self.action == "retrieve":
qs = qs.prefetch_related(
@@ -683,7 +653,7 @@ def dispatch(self, request, *args, **kwargs):
username = self.request.user.username
except:
username = None
- organization = parse_organization(self.request)
+ organization = self.request.organization
posthog.capture(
POSTHOG_PERSON
if POSTHOG_PERSON
@@ -704,7 +674,7 @@ def get_serializer_class(self):
def get_serializer_context(self):
context = super(PlanViewSet, self).get_serializer_context()
- organization = parse_organization(self.request)
+ organization = self.request.organization
if self.request.user.is_authenticated:
user = self.request.user
else:
@@ -718,7 +688,7 @@ def perform_create(self, serializer):
else:
user = None
instance = serializer.save(
- organization=parse_organization(self.request), created_by=user
+ organization=self.request.organization, created_by=user
)
# if user:
# action.send(
@@ -751,7 +721,7 @@ class SubscriptionViewSet(PermissionPolicyMixin, viewsets.ModelViewSet):
def get_serializer_context(self):
context = super(SubscriptionViewSet, self).get_serializer_context()
- organization = parse_organization(self.request)
+ organization = self.request.organization
context.update({"organization": organization})
return context
@@ -766,7 +736,7 @@ def get_serializer_class(self):
return SubscriptionSerializer
def get_queryset(self):
- organization = parse_organization(self.request)
+ organization = self.request.organization
# need for: list, update_plans, cancel_plans
if self.action == "list":
args = []
@@ -813,7 +783,7 @@ def get_queryset(self):
args.append(Q(customer=serializer.validated_data["customer"]))
if serializer.validated_data.get("plan_id"):
args.append(Q(billing_plan__plan=serializer.validated_data["plan"]))
- organization = parse_organization(self.request)
+ organization = self.request.organization
args.append(Q(organization=organization))
qs = (
SubscriptionRecord.objects.filter(*args)
@@ -869,7 +839,7 @@ def update(self, request, *args, **kwargs):
@action(detail=False, methods=["post"])
def plans(self, request, *args, **kwargs):
# run checks to make sure it's valid
- organization = parse_organization(self.request)
+ organization = self.request.organization
serializer = self.get_serializer(data=request.data)
serializer.is_valid(raise_exception=True)
plan_name = serializer.validated_data["billing_plan"].plan.plan_name
@@ -955,7 +925,7 @@ def plans(self, request, *args, **kwargs):
)
def cancel_plans(self, request, *args, **kwargs):
qs = self.get_queryset()
- organization = parse_organization(self.request)
+ organization = self.request.organization
serializer = self.get_serializer(data=self.request.query_params)
serializer.is_valid(raise_exception=True)
flat_fee_behavior = serializer.validated_data["flat_fee_behavior"]
@@ -1000,7 +970,7 @@ def cancel_plans(self, request, *args, **kwargs):
)
def update_plans(self, request, *args, **kwargs):
qs = self.get_queryset()
- organization = parse_organization(self.request)
+ organization = self.request.organization
original_qs = list(copy.copy(qs).values_list("pk", flat=True))
if qs.count() == 0:
raise NotFoundException("Subscription matching the given filters not found")
@@ -1090,7 +1060,7 @@ def dispatch(self, request, *args, **kwargs):
username = self.request.user.username
except:
username = None
- organization = parse_organization(self.request)
+ organization = self.request.organization
posthog.capture(
POSTHOG_PERSON
if POSTHOG_PERSON
@@ -1127,7 +1097,7 @@ class InvoiceViewSet(PermissionPolicyMixin, viewsets.ModelViewSet):
def get_queryset(self):
args = [
~Q(payment_status=INVOICE_STATUS.DRAFT),
- Q(organization=parse_organization(self.request)),
+ Q(organization=self.request.organization),
]
if self.action == "list":
args = []
@@ -1150,7 +1120,7 @@ def get_serializer_class(self):
def get_serializer_context(self):
context = super(InvoiceViewSet, self).get_serializer_context()
- organization = parse_organization(self.request)
+ organization = self.request.organization
context.update({"organization": organization})
return context
@@ -1161,7 +1131,7 @@ def dispatch(self, request, *args, **kwargs):
username = self.request.user.username
except:
username = None
- organization = parse_organization(self.request)
+ organization = self.request.organization
posthog.capture(
POSTHOG_PERSON
if POSTHOG_PERSON
@@ -1205,11 +1175,11 @@ def get_serializer_class(self):
return BacktestCreateSerializer
def get_queryset(self):
- organization = parse_organization(self.request)
+ organization = self.request.organization
return Backtest.objects.filter(organization=organization)
def perform_create(self, serializer):
- backtest_obj = serializer.save(organization=parse_organization(self.request))
+ backtest_obj = serializer.save(organization=self.request.organization)
bt_id = backtest_obj.backtest_id
run_backtest.delay(bt_id)
@@ -1220,7 +1190,7 @@ def dispatch(self, request, *args, **kwargs):
username = self.request.user.username
except:
username = None
- organization = parse_organization(self.request)
+ organization = self.request.organization
posthog.capture(
POSTHOG_PERSON
if POSTHOG_PERSON
@@ -1234,7 +1204,7 @@ def dispatch(self, request, *args, **kwargs):
def get_serializer_context(self):
context = super(BacktestViewSet, self).get_serializer_context()
- organization = parse_organization(self.request)
+ organization = self.request.organization
context.update({"organization": organization})
return context
@@ -1253,11 +1223,11 @@ class ProductViewSet(viewsets.ModelViewSet):
]
def get_queryset(self):
- organization = parse_organization(self.request)
+ organization = self.request.organization
return Product.objects.filter(organization=organization)
def perform_create(self, serializer):
- serializer.save(organization=parse_organization(self.request))
+ serializer.save(organization=self.request.organization)
def dispatch(self, request, *args, **kwargs):
response = super().dispatch(request, *args, **kwargs)
@@ -1266,7 +1236,7 @@ def dispatch(self, request, *args, **kwargs):
username = self.request.user.username
except:
username = None
- organization = parse_organization(self.request)
+ organization = self.request.organization
posthog.capture(
POSTHOG_PERSON
if POSTHOG_PERSON
@@ -1280,7 +1250,7 @@ def dispatch(self, request, *args, **kwargs):
def get_serializer_context(self):
context = super(ProductViewSet, self).get_serializer_context()
- organization = parse_organization(self.request)
+ organization = self.request.organization
context.update({"organization": organization})
return context
@@ -1304,7 +1274,7 @@ class ActionViewSet(mixins.ListModelMixin, viewsets.GenericViewSet):
]
def get_queryset(self):
- organization = parse_organization(self.request)
+ organization = self.request.organization
return (
super()
.get_queryset()
@@ -1329,7 +1299,7 @@ class ExternalPlanLinkViewSet(viewsets.ModelViewSet):
http_method_names = ["post", "head", "delete"]
def get_queryset(self):
- filter_kwargs = {"organization": parse_organization(self.request)}
+ filter_kwargs = {"organization": self.request.organization}
source = self.request.query_params.get("source")
if source:
filter_kwargs["source"] = source
@@ -1342,7 +1312,7 @@ def dispatch(self, request, *args, **kwargs):
username = self.request.user.username
except:
username = None
- organization = parse_organization(self.request)
+ organization = self.request.organization
posthog.capture(
POSTHOG_PERSON
if POSTHOG_PERSON
@@ -1355,11 +1325,11 @@ def dispatch(self, request, *args, **kwargs):
return response
def perform_create(self, serializer):
- serializer.save(organization=parse_organization(self.request))
+ serializer.save(organization=self.request.organization)
def get_serializer_context(self):
context = super(ExternalPlanLinkViewSet, self).get_serializer_context()
- organization = parse_organization(self.request)
+ organization = self.request.organization
context.update({"organization": organization})
return context
@@ -1388,7 +1358,7 @@ class OrganizationSettingViewSet(viewsets.ModelViewSet):
lookup_field = "setting_id"
def get_queryset(self):
- filter_kwargs = {"organization": parse_organization(self.request)}
+ filter_kwargs = {"organization": self.request.organization}
setting_name = self.request.query_params.get("setting_name")
if setting_name:
filter_kwargs["setting_name"] = setting_name
@@ -1410,15 +1380,15 @@ class PricingUnitViewSet(
http_method_names = ["get", "post", "head"]
def get_queryset(self):
- organization = parse_organization(self.request)
+ organization = self.request.organization
return PricingUnit.objects.filter(organization=organization)
def perform_create(self, serializer):
- serializer.save(organization=parse_organization(self.request))
+ serializer.save(organization=self.request.organization)
def get_serializer_context(self):
context = super(PricingUnitViewSet, self).get_serializer_context()
- organization = parse_organization(self.request)
+ organization = self.request.organization
context.update({"organization": organization})
return context
@@ -1442,7 +1412,7 @@ class OrganizationViewSet(
lookup_field = "organization_id"
def get_queryset(self):
- organization = parse_organization(self.request)
+ organization = self.request.organization
return Organization.objects.filter(pk=organization.pk)
def get_object(self):
@@ -1457,7 +1427,7 @@ def get_serializer_class(self):
def get_serializer_context(self):
context = super(OrganizationViewSet, self).get_serializer_context()
- organization = parse_organization(self.request)
+ organization = self.request.organization
context.update({"organization": organization})
return context
@@ -1484,7 +1454,7 @@ class CustomerBalanceAdjustmentViewSet(
lookup_field = "adjustment_id"
def get_queryset(self):
- filter_kwargs = {"organization": parse_organization(self.request)}
+ filter_kwargs = {"organization": self.request.organization}
customer_id = self.request.query_params.get("customer_id")
if customer_id:
filter_kwargs["customer__customer_id"] = customer_id
@@ -1492,12 +1462,12 @@ def get_queryset(self):
def get_serializer_context(self):
context = super(CustomerBalanceAdjustmentViewSet, self).get_serializer_context()
- organization = parse_organization(self.request)
+ organization = self.request.organization
context.update({"organization": organization})
return context
def perform_create(self, serializer):
- serializer.save(organization=parse_organization(self.request))
+ serializer.save(organization=self.request.organization)
@extend_schema(
parameters=[
diff --git a/backend/metering_billing/views/organization_views.py b/backend/metering_billing/views/organization_views.py
index 912c0484a..599551d7c 100644
--- a/backend/metering_billing/views/organization_views.py
+++ b/backend/metering_billing/views/organization_views.py
@@ -14,21 +14,6 @@
POSTHOG_PERSON = settings.POSTHOG_PERSON
DEFAULT_FROM_EMAIL = settings.DEFAULT_FROM_EMAIL
-# class OrganizationView(APIView):
-# permission_classes = [IsAuthenticated]
-
-# def get(self, request, format=None):
-# """
-# Get the current settings for the organization.
-# """
-# organization = parse_organization(request)
-# OrganizationSerializer(organization).data
-# team_members = organization.org_users.all().values_list("email", flat=True)
-# return Response(
-# {"organization": organization.company_name, "team_members": team_members},
-# status=status.HTTP_200_OK,
-# )
-
class InviteView(APIView):
permission_classes = [IsAuthenticated & ValidOrganization]
@@ -36,7 +21,7 @@ class InviteView(APIView):
def post(self, request, *args, **kwargs):
email = request.data.get("email", None)
user = request.user
- organization = parse_organization(request)
+ organization = request.organization
token_object, created = OrganizationInviteToken.objects.get_or_create(
organization=organization, email=email, defaults={"user": user}
diff --git a/backend/metering_billing/views/payment_provider_views.py b/backend/metering_billing/views/payment_provider_views.py
index eba86100f..2f5aa806d 100644
--- a/backend/metering_billing/views/payment_provider_views.py
+++ b/backend/metering_billing/views/payment_provider_views.py
@@ -20,7 +20,7 @@ class PaymentProviderView(APIView):
responses={200: SinglePaymentProviderSerializer(many=True)},
)
def get(self, request, format=None):
- organization = parse_organization(request)
+ organization = request.organization
response = []
for payment_processor_name, pp_obj in PAYMENT_PROVIDER_MAP.items():
pp_response = {
@@ -38,7 +38,7 @@ def get(self, request, format=None):
responses={200: PaymentProviderPostResponseSerializer},
)
def post(self, request, format=None):
- organization = parse_organization(request)
+ organization = request.organization
# parse outer level request
serializer = PaymentProviderPostRequestSerializer(data=request.data)
serializer.is_valid(raise_exception=True)
diff --git a/backend/metering_billing/views/views.py b/backend/metering_billing/views/views.py
index ff5d9d671..349a747d1 100644
--- a/backend/metering_billing/views/views.py
+++ b/backend/metering_billing/views/views.py
@@ -53,7 +53,7 @@ def get(self, request, format=None):
"""
Returns the revenue for an organization in a given time period.
"""
- organization = parse_organization(request)
+ organization = request.organization
serializer = PeriodComparisonRequestSerializer(data=request.query_params)
serializer.is_valid(raise_exception=True)
p1_start, p1_end, p2_start, p2_end = [
@@ -135,7 +135,8 @@ def get(self, request, format=None):
"""
Returns the revenue for an organization in a given time period.
"""
- organization = parse_organization(request)
+ organization = request.organization
+ organization = request.organization
serializer = CostAnalysisRequestSerializer(data=request.query_params)
serializer.is_valid(raise_exception=True)
start_date, end_date, customer_id = [
@@ -244,7 +245,7 @@ class PeriodSubscriptionsView(APIView):
responses={200: PeriodSubscriptionsResponseSerializer},
)
def get(self, request, format=None):
- organization = parse_organization(request)
+ organization = request.organization
serializer = PeriodComparisonRequestSerializer(data=request.query_params)
serializer.is_valid(raise_exception=True)
p1_start, p1_end, p2_start, p2_end = [
@@ -297,7 +298,7 @@ def get(self, request, format=None):
"""
Return current usage for a customer during a given billing period.
"""
- organization = parse_organization(request)
+ organization = request.organization
serializer = PeriodMetricUsageRequestSerializer(data=request.query_params)
serializer.is_valid(raise_exception=True)
q_start, q_end, top_n = [
@@ -404,7 +405,7 @@ def get(self, request, format=None):
"""
Revokes the current API key and returns a new one.
"""
- organization = parse_organization(request)
+ organization = request.organization
tk = APIToken.objects.filter(organization=organization).first()
if tk:
cache.delete(tk.prefix)
@@ -433,7 +434,7 @@ def get(self, request, format=None):
"""
Get the current settings for the organization.
"""
- organization = parse_organization(request)
+ organization = request.organization
return Response(
{"organization": organization.company_name}, status=status.HTTP_200_OK
)
@@ -449,7 +450,7 @@ def get(self, request, format=None):
"""
Get the current settings for the organization.
"""
- organization = parse_organization(request)
+ organization = request.organization
customers = Customer.objects.filter(organization=organization).prefetch_related(
Prefetch(
"subscription_records",
@@ -479,17 +480,15 @@ def get(self, request, format=None):
"""
Return current usage for a customer during a given billing period.
"""
- organization = parse_organization(request)
+ organization = request.organization
customers = Customer.objects.filter(organization=organization)
cust = []
for customer in customers:
total_amount_due = customer.get_outstanding_revenue()
- next_amount_due = customer.get_active_sub_drafts_revenue()
serializer = CustomerWithRevenueSerializer(
customer,
context={
"total_amount_due": total_amount_due,
- "next_amount_due": next_amount_due,
},
)
cust.append(serializer.data)
@@ -513,7 +512,7 @@ def get(self, request, format=None):
"""
Pagination-enabled endpoint for retrieving an organization's event stream.
"""
- organization = parse_organization(request)
+ organization = request.organization
serializer = DraftInvoiceRequestSerializer(data=request.query_params)
serializer.is_valid(raise_exception=True)
try:
@@ -807,7 +806,7 @@ class ImportCustomersView(APIView):
},
)
def post(self, request, format=None):
- organization = parse_organization(request)
+ organization = request.organization
source = request.data["source"]
if source not in [choice[0] for choice in PAYMENT_PROVIDERS.choices]:
raise ExternalConnectionInvalid(f"Invalid source: {source}")
@@ -853,7 +852,7 @@ class ImportPaymentObjectsView(APIView):
},
)
def post(self, request, format=None):
- organization = parse_organization(request)
+ organization = request.organization
source = request.data["source"]
if source not in [choice[0] for choice in PAYMENT_PROVIDERS.choices]:
raise ExternalConnectionInvalid(f"Invalid source: {source}")
@@ -901,7 +900,7 @@ class TransferSubscriptionsView(APIView):
},
)
def post(self, request, format=None):
- organization = parse_organization(request)
+ organization = request.organization
source = request.data["source"]
if source not in [choice[0] for choice in PAYMENT_PROVIDERS.choices]:
raise ExternalConnectionInvalid(f"Invalid source: {source}")
@@ -943,7 +942,7 @@ class ExperimentalToActiveView(APIView):
},
)
def post(self, request, format=None):
- organization = parse_organization(request)
+ organization = request.organization
serializer = ExperimentalToActiveRequestSerializer(data=request.data)
serializer.is_valid(raise_exception=True)
billing_plan = serializer.validated_data["version_id"]
@@ -994,7 +993,7 @@ class PlansByNumCustomersView(APIView):
},
)
def get(self, request, format=None):
- organization = parse_organization(request)
+ organization = request.organization
plans = (
SubscriptionRecord.objects.filter(
organization=organization, status=SUBSCRIPTION_STATUS.ACTIVE
@@ -1048,7 +1047,7 @@ class CustomerBatchCreateView(APIView):
},
)
def post(self, request, format=None):
- organization = parse_organization(request)
+ organization = request.organization
serializer = CustomerSerializer(
data=request.data["customers"],
many=True,
@@ -1140,7 +1139,7 @@ class ConfirmIdemsReceivedView(APIView):
},
)
def post(self, request, format=None):
- organization = parse_organization(request)
+ organization = request.organization
if request.data.get("idempotency_ids") is None:
return Response(
{
diff --git a/frontend/src/components/Customers/CustomerInfo.tsx b/frontend/src/components/Customers/CustomerInfo.tsx
index f81dd4278..c58cb8b9c 100644
--- a/frontend/src/components/Customers/CustomerInfo.tsx
+++ b/frontend/src/components/Customers/CustomerInfo.tsx
@@ -1,7 +1,9 @@
// @ts-ignore
import React, { FC, useEffect } from "react";
import { Column } from "@ant-design/plots";
+import { useQueryClient } from "react-query";
import { Select, Tag } from "antd";
+import { DraftInvoiceType, LineItem } from "../../types/invoice-type";
// @ts-ignore
import dayjs from "dayjs";
import LoadingSpinner from "../LoadingSpinner";
@@ -14,6 +16,12 @@ const CustomerInfoView: FC
Amount Due On Next Invoice: {data?.default_currency?.symbol} - {data.next_amount_due.toFixed(2)} + {invoiceData?.invoices[0].cost_due.toFixed(2)}