diff --git a/backend/metering_billing/permissions.py b/backend/metering_billing/permissions.py index 331bdb24e..5875776c8 100644 --- a/backend/metering_billing/permissions.py +++ b/backend/metering_billing/permissions.py @@ -28,6 +28,7 @@ def has_permission(self, request, view): org = request.organization if org is None and request.user.is_authenticated: org = request.user.organization + request.organization = org return org is not None def has_object_permission(self, request, view, obj): diff --git a/backend/metering_billing/serializers/model_serializers.py b/backend/metering_billing/serializers/model_serializers.py index 6eee10dcc..86a5998a4 100644 --- a/backend/metering_billing/serializers/model_serializers.py +++ b/backend/metering_billing/serializers/model_serializers.py @@ -304,19 +304,17 @@ class Meta(SubscriptionCustomerSummarySerializer.Meta): class CustomerWithRevenueSerializer(serializers.ModelSerializer): class Meta: model = Customer - fields = ("customer_id", "total_amount_due", "next_amount_due") + fields = ( + "customer_id", + "total_amount_due", + ) total_amount_due = serializers.SerializerMethodField() - next_amount_due = serializers.SerializerMethodField() def get_total_amount_due(self, obj) -> float: total_amount_due = float(self.context.get("total_amount_due")) return total_amount_due - def get_next_amount_due(self, obj) -> float: - next_amount_due = float(self.context.get("next_amount_due")) - return next_amount_due - class CustomerSerializer(serializers.ModelSerializer): class Meta: @@ -1441,6 +1439,30 @@ class Meta(SubscriptionRecordSerializer.Meta): ) +class LightweightPlanVersionSerializer(PlanVersionSerializer): + class Meta(PlanVersionSerializer.Meta): + model = PlanVersion + fields = ("plan_id", "plan_name", "version_id") + + plan_name = serializers.CharField(read_only=True, source="plan.plan_name") + plan_id = serializers.CharField(read_only=True, source="plan.plan_id") + + +class LightweightSubscriptionRecordSerializer(SubscriptionRecordSerializer): + class Meta(SubscriptionRecordSerializer.Meta): + model = SubscriptionRecord + fields = tuple( + set(SubscriptionRecordSerializer.Meta.fields).union(set(["plan_detail"])) + ) + + plan_detail = LightweightPlanVersionSerializer( + source="billing_plan", read_only=True + ) + subscription_filters = SubscriptionCategoricalFilterSerializer( + source="filters", many=True, read_only=True + ) + + class SubscriptionSerializer(serializers.ModelSerializer): class Meta: model = Subscription @@ -1458,9 +1480,14 @@ class Meta: customer = ShortCustomerSerializer(read_only=True) plans = serializers.SerializerMethodField() - def get_plans(self, obj) -> SubscriptionRecordDetailSerializer(many=True): - sub_records = obj.get_subscription_records() - data = SubscriptionRecordDetailSerializer(sub_records, many=True).data + def get_plans(self, obj) -> LightweightSubscriptionRecordSerializer(many=True): + sub_records = obj.get_subscription_records().prefetch_related( + "billing_plan", + "filters", + "billing_plan__plan", + "billing_plan__pricing_unit", + ) + data = LightweightSubscriptionRecordSerializer(sub_records, many=True).data return data @@ -1869,6 +1896,20 @@ class Meta: line_items = InvoiceLineItemSerializer(many=True, read_only=True) +class LightweightInvoiceSerializer(InvoiceSerializer): + class Meta(InvoiceSerializer.Meta): + fields = tuple( + set(InvoiceSerializer.Meta.fields) + - set( + [ + "line_items", + "customer", + "subscription", + ] + ) + ) + + class InvoiceListFilterSerializer(serializers.Serializer): customer_id = serializers.CharField(required=False) payment_status = serializers.MultipleChoiceField( @@ -1986,7 +2027,6 @@ class Meta: "customer_name", "invoices", "total_amount_due", - "next_amount_due", "subscription", "integrations", "default_currency", @@ -1995,7 +2035,6 @@ class Meta: subscription = serializers.SerializerMethodField(allow_null=True) invoices = serializers.SerializerMethodField() total_amount_due = serializers.SerializerMethodField() - next_amount_due = serializers.SerializerMethodField() default_currency = PricingUnitSerializer() def get_subscription(self, obj) -> SubscriptionSerializer: @@ -2005,15 +2044,14 @@ def get_subscription(self, obj) -> SubscriptionSerializer: else: return SubscriptionSerializer(sub_obj).data - def get_invoices(self, obj) -> InvoiceSerializer(many=True): - timeline = self.context.get("invoices") - timeline = InvoiceSerializer(timeline, many=True).data + def get_invoices(self, obj) -> LightweightInvoiceSerializer(many=True): + timeline = obj.invoices.filter( + ~Q(payment_status=INVOICE_STATUS.DRAFT), + organization=self.context.get("organization"), + ).order_by("-issue_date") + timeline = LightweightInvoiceSerializer(timeline, many=True).data return timeline def get_total_amount_due(self, obj) -> float: - total_amount_due = float(self.context.get("total_amount_due")) + total_amount_due = float(obj.get_outstanding_revenue()) return total_amount_due - - def get_next_amount_due(self, obj) -> float: - next_amount_due = float(self.context.get("next_amount_due")) - return next_amount_due diff --git a/backend/metering_billing/tests/test_billable_metric.py b/backend/metering_billing/tests/test_billable_metric.py index 289de81fa..b93e0bc82 100644 --- a/backend/metering_billing/tests/test_billable_metric.py +++ b/backend/metering_billing/tests/test_billable_metric.py @@ -19,6 +19,7 @@ PlanVersion, PriceTier, SubscriptionRecord, + User, ) from metering_billing.utils import now_utc from metering_billing.utils.enums import ( @@ -30,6 +31,7 @@ METRIC_TYPE, NUMERIC_FILTER_OPERATORS, PLAN_DURATION, + PLAN_STATUS, PLAN_VERSION_STATUS, PRICE_TIER_TYPE, USAGE_CALC_GRANULARITY, @@ -158,34 +160,6 @@ def test_session_auth_can_create_billable_metric_nonempty_before( == num_billable_metrics + 1 ) - def test_user_org_and_api_key_different_reject_creation( - self, - billable_metric_test_common_setup, - insert_billable_metric_payload, - get_billable_metrics_in_org, - ): - # covers user_org_and_api_key_org_different = True - num_billable_metrics = 3 - setup_dict = billable_metric_test_common_setup( - num_billable_metrics=num_billable_metrics, - auth_method="both", - user_org_and_api_key_org_different=True, - ) - - response = setup_dict["client"].post( - reverse("metric-list"), - data=json.dumps(insert_billable_metric_payload, cls=DjangoJSONEncoder), - content_type="application/json", - ) - - assert response.status_code == status.HTTP_401_UNAUTHORIZED - assert ( - len(get_billable_metrics_in_org(setup_dict["org"])) == num_billable_metrics - ) - assert ( - len(get_billable_metrics_in_org(setup_dict["org2"])) == num_billable_metrics - ) - def test_billable_metric_exists_reject_creation( self, billable_metric_test_common_setup, @@ -261,6 +235,7 @@ def test_cant_archive_with_active_plan_version( description="test_plan for testing", flat_rate=30.0, plan=plan, + status=PLAN_VERSION_STATUS.ACTIVE, ) plan.display_version = billing_plan plan.save() @@ -290,6 +265,15 @@ def test_cant_archive_with_active_plan_version( setup_dict["billing_plan"] = billing_plan payload = {"status": METRIC_STATUS.ARCHIVED} + assert billing_plan.status == PLAN_VERSION_STATUS.ACTIVE + assert billing_plan.plan.status == PLAN_STATUS.ACTIVE + assert billing_plan.plan_components.count() == 3 + all_pcs = billing_plan.plan_components.all() + assert ( + all_pcs[0].billable_metric == metric_set[0] + or all_pcs[1].billable_metric == metric_set[0] + or all_pcs[2].billable_metric == metric_set[0] + ) response = setup_dict["client"].patch( reverse("metric-detail", kwargs={"metric_id": metric_set[0].metric_id}), data=json.dumps(payload, cls=DjangoJSONEncoder), @@ -1310,7 +1294,7 @@ def test_proration_and_metric_granularity_sub_day( num_billable_metrics = 0 setup_dict = billable_metric_test_common_setup( num_billable_metrics=num_billable_metrics, - auth_method="session_auth", + auth_method="api_key", user_org_and_api_key_org_different=False, ) billable_metric = Metric.objects.create( @@ -1424,7 +1408,7 @@ def test_metric_granularity_daily_proration_smaller_than_day( num_billable_metrics = 0 setup_dict = billable_metric_test_common_setup( num_billable_metrics=num_billable_metrics, - auth_method="session_auth", + auth_method="api_key", user_org_and_api_key_org_different=False, ) billable_metric = Metric.objects.create( @@ -1531,7 +1515,7 @@ def test_metric_granularity_greater_than_daily_proration_smaller_than_day( num_billable_metrics = 0 setup_dict = billable_metric_test_common_setup( num_billable_metrics=num_billable_metrics, - auth_method="session_auth", + auth_method="api_key", user_org_and_api_key_org_different=False, ) plan = setup_dict["plan"] diff --git a/backend/metering_billing/tests/test_customer.py b/backend/metering_billing/tests/test_customer.py index 8e34d0761..47baffa56 100644 --- a/backend/metering_billing/tests/test_customer.py +++ b/backend/metering_billing/tests/test_customer.py @@ -97,22 +97,6 @@ def test_session_auth_can_access_customers_multiple( assert response.status_code == status.HTTP_200_OK assert len(response.data) == num_customers - def test_user_org_and_api_key_different_reject_access( - self, customer_test_common_setup - ): - # covers user_org_and_api_key_org_different = true - num_customers = 3 - setup_dict = customer_test_common_setup( - num_customers=num_customers, - auth_method="both", - user_org_and_api_key_org_different=True, - ) - - payload = {} - response = setup_dict["client"].get(reverse("customer-list"), payload) - - assert response.status_code == status.HTTP_401_UNAUTHORIZED - @pytest.fixture def insert_customer_payload(): @@ -183,27 +167,6 @@ def test_session_auth_can_create_customer_nonempty_before( assert len(response.data) > 0 assert len(get_customers_in_org(setup_dict["org"])) == num_customers + 1 - def test_user_org_and_api_key_different_reject_creation( - self, customer_test_common_setup, insert_customer_payload, get_customers_in_org - ): - # covers user_org_and_api_key_org_different = True - num_customers = 3 - setup_dict = customer_test_common_setup( - num_customers=num_customers, - auth_method="both", - user_org_and_api_key_org_different=True, - ) - - response = setup_dict["client"].post( - reverse("customer-list"), - data=json.dumps(insert_customer_payload, cls=DjangoJSONEncoder), - content_type="application/json", - ) - - assert response.status_code == status.HTTP_401_UNAUTHORIZED - assert len(get_customers_in_org(setup_dict["org"])) == num_customers - assert len(get_customers_in_org(setup_dict["org2"])) == num_customers - def test_customer_id_already_exists_within_org_reject_creation( self, customer_test_common_setup, insert_customer_payload, get_customers_in_org ): diff --git a/backend/metering_billing/tests/test_draft_invoices.py b/backend/metering_billing/tests/test_draft_invoices.py index d154e1aa9..2583f22d8 100644 --- a/backend/metering_billing/tests/test_draft_invoices.py +++ b/backend/metering_billing/tests/test_draft_invoices.py @@ -127,7 +127,7 @@ def do_draft_invoice_test_common_setup(*, auth_method): @pytest.mark.django_db(transaction=True) class TestGenerateInvoice: def test_generate_invoice(self, draft_invoice_test_common_setup): - setup_dict = draft_invoice_test_common_setup(auth_method="session_auth") + setup_dict = draft_invoice_test_common_setup(auth_method="api_key") active_subscriptions = Subscription.objects.active().filter( organization=setup_dict["org"], @@ -140,7 +140,7 @@ def test_generate_invoice(self, draft_invoice_test_common_setup): ).count() payload = {"customer_id": setup_dict["customer"].customer_id} response = setup_dict["client"].get(reverse("draft_invoice"), payload) - + print(response.data) assert response.status_code == status.HTTP_200_OK after_active_subscriptions = Subscription.objects.active().filter( organization=setup_dict["org"], @@ -156,7 +156,7 @@ def test_generate_invoice(self, draft_invoice_test_common_setup): def test_generate_invoice_with_price_adjustments( self, draft_invoice_test_common_setup ): - setup_dict = draft_invoice_test_common_setup(auth_method="session_auth") + setup_dict = draft_invoice_test_common_setup(auth_method="api_key") active_subscriptions = Subscription.objects.active().filter( organization=setup_dict["org"], diff --git a/backend/metering_billing/tests/test_integrations.py b/backend/metering_billing/tests/test_integrations.py index 7b8533e84..e89ee0e2c 100644 --- a/backend/metering_billing/tests/test_integrations.py +++ b/backend/metering_billing/tests/test_integrations.py @@ -228,7 +228,7 @@ def test_stripe_end_to_end(self, integration_test_common_setup): ) # now lets test out the replace now - Subscription.objects.filter( + SubscriptionRecord.objects.filter( organization=setup_dict["org"], start_date__gte=now_utc(), billing_plan=setup_dict["plan"].display_version, @@ -237,12 +237,10 @@ def test_stripe_end_to_end(self, integration_test_common_setup): time.sleep(10) stripe_sub = stripe.Subscription.retrieve(stripe_sub.id) assert ( - Subscription.objects.active() - .filter( + SubscriptionRecord.objects.filter( organization=setup_dict["org"], billing_plan=setup_dict["plan"].display_version, - ) - .count() + ).count() == 1 ) assert stripe_sub.status == "canceled" diff --git a/backend/metering_billing/tests/test_subscription.py b/backend/metering_billing/tests/test_subscription.py index 891159024..e6603bcb4 100644 --- a/backend/metering_billing/tests/test_subscription.py +++ b/backend/metering_billing/tests/test_subscription.py @@ -202,26 +202,6 @@ def test_session_auth_can_create_subscription_nonempty_before( == num_subscription_records_before + 1 ) - def test_user_org_and_api_key_different_reject_creation( - self, subscription_test_common_setup, get_subscriptions_in_org - ): - # covers user_org_and_api_key_org_different = True - num_subscriptions = 1 - setup_dict = subscription_test_common_setup( - num_subscriptions=num_subscriptions, - auth_method="both", - user_org_and_api_key_org_different=True, - ) - - response = setup_dict["client"].post( - reverse("subscription-plans"), - data=json.dumps(setup_dict["payload"], cls=DjangoJSONEncoder), - content_type="application/json", - ) - - assert response.status_code == status.HTTP_401_UNAUTHORIZED - assert len(get_subscriptions_in_org(setup_dict["org"])) == num_subscriptions - @pytest.mark.django_db(transaction=True) class TestUpdateSub: diff --git a/backend/metering_billing/views/model_views.py b/backend/metering_billing/views/model_views.py index 3f11a9434..ab4e742ee 100644 --- a/backend/metering_billing/views/model_views.py +++ b/backend/metering_billing/views/model_views.py @@ -140,17 +140,17 @@ class APITokenViewSet( lookup_field = "prefix" def get_queryset(self): - organization = parse_organization(self.request) + organization = self.request.organization return APIToken.objects.filter(organization=organization) def get_serializer_context(self): context = super(APITokenViewSet, self).get_serializer_context() - organization = parse_organization(self.request) + organization = self.request.organization context.update({"organization": organization}) return context def perform_create(self, serializer): - organization = parse_organization(self.request) + organization = self.request.organization api_key, key = serializer.save(organization=organization) return api_key, key @@ -236,18 +236,18 @@ class WebhookViewSet(PermissionPolicyMixin, viewsets.ModelViewSet): } def get_queryset(self): - organization = parse_organization(self.request) + organization = self.request.organization return WebhookEndpoint.objects.filter(organization=organization) def get_serializer_context(self): context = super(WebhookViewSet, self).get_serializer_context() - organization = parse_organization(self.request) + organization = self.request.organization context.update({"organization": organization}) return context def perform_create(self, serializer): try: - serializer.save(organization=parse_organization(self.request)) + serializer.save(organization=self.request.organization) except ValueError as e: raise ServerError(e) except IntegrityError as e: @@ -269,7 +269,7 @@ def dispatch(self, request, *args, **kwargs): username = self.request.user.username except: username = None - organization = parse_organization(self.request) + organization = self.request.organization posthog.capture( POSTHOG_PERSON if POSTHOG_PERSON @@ -307,7 +307,7 @@ class EventViewSet( def get_queryset(self): now = now_utc() - organization = parse_organization(self.request) + organization = self.request.organization return ( super() .get_queryset() @@ -316,7 +316,7 @@ def get_queryset(self): def get_serializer_context(self): context = super(EventViewSet, self).get_serializer_context() - organization = parse_organization(self.request) + organization = self.request.organization context.update({"organization": organization}) return context @@ -331,17 +331,17 @@ class UserViewSet(PermissionPolicyMixin, viewsets.ModelViewSet): http_method_names = ["get", "post", "head"] def get_queryset(self): - organization = parse_organization(self.request) + organization = self.request.organization return User.objects.filter(organization=organization) def get_serializer_context(self): context = super(UserViewSet, self).get_serializer_context() - organization = parse_organization(self.request) + organization = self.request.organization context.update({"organization": organization}) return context def perform_create(self, serializer): - serializer.save(organization=parse_organization(self.request)) + serializer.save(organization=self.request.organization) class CustomerViewSet(PermissionPolicyMixin, viewsets.ModelViewSet): @@ -353,23 +353,10 @@ class CustomerViewSet(PermissionPolicyMixin, viewsets.ModelViewSet): http_method_names = ["get", "post", "head", "patch"] def get_queryset(self): - organization = parse_organization(self.request) + organization = self.request.organization qs = Customer.objects.filter(organization=organization) if self.action == "retrieve": - qs = qs.prefetch_related( - Prefetch( - "subscription_records", - queryset=SubscriptionRecord.objects.filter( - organization=organization, - status=SUBSCRIPTION_STATUS.ACTIVE, - ), - ), - Prefetch( - "subscription_records__billing_plan", - queryset=PlanVersion.objects.filter(organization=organization), - to_attr="billing_plans", - ), - ) + qs = qs.prefetch_related("subscriptions", "invoices") return qs def get_serializer_class(self): @@ -381,7 +368,7 @@ def get_serializer_class(self): def perform_create(self, serializer): try: - serializer.save(organization=parse_organization(self.request)) + serializer.save(organization=self.request.organization) except IntegrityError as e: cause = e.__cause__ if "unique_email" in str(cause): @@ -392,24 +379,7 @@ def perform_create(self, serializer): def get_serializer_context(self): context = super(CustomerViewSet, self).get_serializer_context() - organization = parse_organization(self.request) - context.update({"organization": organization}) - if self.action == "retrieve": - customer = self.get_object() - total_amount_due = customer.get_outstanding_revenue() - next_amount_due = customer.get_active_sub_drafts_revenue() - invoices = Invoice.objects.filter( - ~Q(payment_status=INVOICE_STATUS.DRAFT), - organization=organization, - customer=customer, - ).order_by("-issue_date") - context.update( - { - "total_amount_due": total_amount_due, - "invoices": invoices, - "next_amount_due": next_amount_due, - } - ) + context.update({"organization": self.request.organization}) return context def dispatch(self, request, *args, **kwargs): @@ -419,7 +389,7 @@ def dispatch(self, request, *args, **kwargs): username = self.request.user.username except: username = None - organization = parse_organization(self.request) + organization = self.request.organization or self.request.user.organization posthog.capture( POSTHOG_PERSON if POSTHOG_PERSON @@ -445,7 +415,7 @@ class MetricViewSet(PermissionPolicyMixin, viewsets.ModelViewSet): } def get_queryset(self): - organization = parse_organization(self.request) + organization = self.request.organization return Metric.objects.filter( organization=organization, status=METRIC_STATUS.ACTIVE ) @@ -457,7 +427,7 @@ def get_serializer_class(self): def get_serializer_context(self): context = super(MetricViewSet, self).get_serializer_context() - organization = parse_organization(self.request) + organization = self.request.organization context.update({"organization": organization}) return context @@ -468,7 +438,7 @@ def dispatch(self, request, *args, **kwargs): username = self.request.user.username except: username = None - organization = parse_organization(self.request) + organization = self.request.organization posthog.capture( POSTHOG_PERSON if POSTHOG_PERSON @@ -482,7 +452,7 @@ def dispatch(self, request, *args, **kwargs): def perform_create(self, serializer): try: - instance = serializer.save(organization=parse_organization(self.request)) + instance = serializer.save(organization=self.request.organization) except IntegrityError as e: cause = e.__cause__ if "unique_org_metric_id" in str(cause): @@ -511,12 +481,12 @@ class FeatureViewSet(PermissionPolicyMixin, viewsets.ModelViewSet): } def get_queryset(self): - organization = parse_organization(self.request) + organization = self.request.organization return Feature.objects.filter(organization=organization) def get_serializer_context(self): context = super(FeatureViewSet, self).get_serializer_context() - organization = parse_organization(self.request) + organization = self.request.organization context.update({"organization": organization}) return context @@ -527,7 +497,7 @@ def dispatch(self, request, *args, **kwargs): username = self.request.user.username except: username = None - organization = parse_organization(self.request) + organization = self.request.organization posthog.capture( POSTHOG_PERSON if POSTHOG_PERSON @@ -540,7 +510,7 @@ def dispatch(self, request, *args, **kwargs): return response def perform_create(self, serializer): - serializer.save(organization=parse_organization(self.request)) + serializer.save(organization=self.request.organization) class PlanVersionViewSet(PermissionPolicyMixin, viewsets.ModelViewSet): @@ -566,7 +536,7 @@ def get_serializer_class(self): return PlanVersionSerializer def get_queryset(self): - organization = parse_organization(self.request) + organization = self.request.organization qs = PlanVersion.objects.filter( organization=organization, ) @@ -574,7 +544,7 @@ def get_queryset(self): def get_serializer_context(self): context = super(PlanVersionViewSet, self).get_serializer_context() - organization = parse_organization(self.request) + organization = self.request.organization if self.request.user.is_authenticated: user = self.request.user else: @@ -589,7 +559,7 @@ def dispatch(self, request, *args, **kwargs): username = self.request.user.username except: username = None - organization = parse_organization(self.request) + organization = self.request.organization posthog.capture( POSTHOG_PERSON if POSTHOG_PERSON @@ -607,7 +577,7 @@ def perform_create(self, serializer): else: user = None instance = serializer.save( - organization=parse_organization(self.request), created_by=user + organization=self.request.organization, created_by=user ) # if user: # action.send( @@ -655,7 +625,7 @@ class PlanViewSet(PermissionPolicyMixin, viewsets.ModelViewSet): } def get_queryset(self): - organization = parse_organization(self.request) + organization = self.request.organization qs = Plan.objects.filter(organization=organization, status=PLAN_STATUS.ACTIVE) if self.action == "retrieve": qs = qs.prefetch_related( @@ -683,7 +653,7 @@ def dispatch(self, request, *args, **kwargs): username = self.request.user.username except: username = None - organization = parse_organization(self.request) + organization = self.request.organization posthog.capture( POSTHOG_PERSON if POSTHOG_PERSON @@ -704,7 +674,7 @@ def get_serializer_class(self): def get_serializer_context(self): context = super(PlanViewSet, self).get_serializer_context() - organization = parse_organization(self.request) + organization = self.request.organization if self.request.user.is_authenticated: user = self.request.user else: @@ -718,7 +688,7 @@ def perform_create(self, serializer): else: user = None instance = serializer.save( - organization=parse_organization(self.request), created_by=user + organization=self.request.organization, created_by=user ) # if user: # action.send( @@ -751,7 +721,7 @@ class SubscriptionViewSet(PermissionPolicyMixin, viewsets.ModelViewSet): def get_serializer_context(self): context = super(SubscriptionViewSet, self).get_serializer_context() - organization = parse_organization(self.request) + organization = self.request.organization context.update({"organization": organization}) return context @@ -766,7 +736,7 @@ def get_serializer_class(self): return SubscriptionSerializer def get_queryset(self): - organization = parse_organization(self.request) + organization = self.request.organization # need for: list, update_plans, cancel_plans if self.action == "list": args = [] @@ -813,7 +783,7 @@ def get_queryset(self): args.append(Q(customer=serializer.validated_data["customer"])) if serializer.validated_data.get("plan_id"): args.append(Q(billing_plan__plan=serializer.validated_data["plan"])) - organization = parse_organization(self.request) + organization = self.request.organization args.append(Q(organization=organization)) qs = ( SubscriptionRecord.objects.filter(*args) @@ -869,7 +839,7 @@ def update(self, request, *args, **kwargs): @action(detail=False, methods=["post"]) def plans(self, request, *args, **kwargs): # run checks to make sure it's valid - organization = parse_organization(self.request) + organization = self.request.organization serializer = self.get_serializer(data=request.data) serializer.is_valid(raise_exception=True) plan_name = serializer.validated_data["billing_plan"].plan.plan_name @@ -955,7 +925,7 @@ def plans(self, request, *args, **kwargs): ) def cancel_plans(self, request, *args, **kwargs): qs = self.get_queryset() - organization = parse_organization(self.request) + organization = self.request.organization serializer = self.get_serializer(data=self.request.query_params) serializer.is_valid(raise_exception=True) flat_fee_behavior = serializer.validated_data["flat_fee_behavior"] @@ -1000,7 +970,7 @@ def cancel_plans(self, request, *args, **kwargs): ) def update_plans(self, request, *args, **kwargs): qs = self.get_queryset() - organization = parse_organization(self.request) + organization = self.request.organization original_qs = list(copy.copy(qs).values_list("pk", flat=True)) if qs.count() == 0: raise NotFoundException("Subscription matching the given filters not found") @@ -1090,7 +1060,7 @@ def dispatch(self, request, *args, **kwargs): username = self.request.user.username except: username = None - organization = parse_organization(self.request) + organization = self.request.organization posthog.capture( POSTHOG_PERSON if POSTHOG_PERSON @@ -1127,7 +1097,7 @@ class InvoiceViewSet(PermissionPolicyMixin, viewsets.ModelViewSet): def get_queryset(self): args = [ ~Q(payment_status=INVOICE_STATUS.DRAFT), - Q(organization=parse_organization(self.request)), + Q(organization=self.request.organization), ] if self.action == "list": args = [] @@ -1150,7 +1120,7 @@ def get_serializer_class(self): def get_serializer_context(self): context = super(InvoiceViewSet, self).get_serializer_context() - organization = parse_organization(self.request) + organization = self.request.organization context.update({"organization": organization}) return context @@ -1161,7 +1131,7 @@ def dispatch(self, request, *args, **kwargs): username = self.request.user.username except: username = None - organization = parse_organization(self.request) + organization = self.request.organization posthog.capture( POSTHOG_PERSON if POSTHOG_PERSON @@ -1205,11 +1175,11 @@ def get_serializer_class(self): return BacktestCreateSerializer def get_queryset(self): - organization = parse_organization(self.request) + organization = self.request.organization return Backtest.objects.filter(organization=organization) def perform_create(self, serializer): - backtest_obj = serializer.save(organization=parse_organization(self.request)) + backtest_obj = serializer.save(organization=self.request.organization) bt_id = backtest_obj.backtest_id run_backtest.delay(bt_id) @@ -1220,7 +1190,7 @@ def dispatch(self, request, *args, **kwargs): username = self.request.user.username except: username = None - organization = parse_organization(self.request) + organization = self.request.organization posthog.capture( POSTHOG_PERSON if POSTHOG_PERSON @@ -1234,7 +1204,7 @@ def dispatch(self, request, *args, **kwargs): def get_serializer_context(self): context = super(BacktestViewSet, self).get_serializer_context() - organization = parse_organization(self.request) + organization = self.request.organization context.update({"organization": organization}) return context @@ -1253,11 +1223,11 @@ class ProductViewSet(viewsets.ModelViewSet): ] def get_queryset(self): - organization = parse_organization(self.request) + organization = self.request.organization return Product.objects.filter(organization=organization) def perform_create(self, serializer): - serializer.save(organization=parse_organization(self.request)) + serializer.save(organization=self.request.organization) def dispatch(self, request, *args, **kwargs): response = super().dispatch(request, *args, **kwargs) @@ -1266,7 +1236,7 @@ def dispatch(self, request, *args, **kwargs): username = self.request.user.username except: username = None - organization = parse_organization(self.request) + organization = self.request.organization posthog.capture( POSTHOG_PERSON if POSTHOG_PERSON @@ -1280,7 +1250,7 @@ def dispatch(self, request, *args, **kwargs): def get_serializer_context(self): context = super(ProductViewSet, self).get_serializer_context() - organization = parse_organization(self.request) + organization = self.request.organization context.update({"organization": organization}) return context @@ -1304,7 +1274,7 @@ class ActionViewSet(mixins.ListModelMixin, viewsets.GenericViewSet): ] def get_queryset(self): - organization = parse_organization(self.request) + organization = self.request.organization return ( super() .get_queryset() @@ -1329,7 +1299,7 @@ class ExternalPlanLinkViewSet(viewsets.ModelViewSet): http_method_names = ["post", "head", "delete"] def get_queryset(self): - filter_kwargs = {"organization": parse_organization(self.request)} + filter_kwargs = {"organization": self.request.organization} source = self.request.query_params.get("source") if source: filter_kwargs["source"] = source @@ -1342,7 +1312,7 @@ def dispatch(self, request, *args, **kwargs): username = self.request.user.username except: username = None - organization = parse_organization(self.request) + organization = self.request.organization posthog.capture( POSTHOG_PERSON if POSTHOG_PERSON @@ -1355,11 +1325,11 @@ def dispatch(self, request, *args, **kwargs): return response def perform_create(self, serializer): - serializer.save(organization=parse_organization(self.request)) + serializer.save(organization=self.request.organization) def get_serializer_context(self): context = super(ExternalPlanLinkViewSet, self).get_serializer_context() - organization = parse_organization(self.request) + organization = self.request.organization context.update({"organization": organization}) return context @@ -1388,7 +1358,7 @@ class OrganizationSettingViewSet(viewsets.ModelViewSet): lookup_field = "setting_id" def get_queryset(self): - filter_kwargs = {"organization": parse_organization(self.request)} + filter_kwargs = {"organization": self.request.organization} setting_name = self.request.query_params.get("setting_name") if setting_name: filter_kwargs["setting_name"] = setting_name @@ -1410,15 +1380,15 @@ class PricingUnitViewSet( http_method_names = ["get", "post", "head"] def get_queryset(self): - organization = parse_organization(self.request) + organization = self.request.organization return PricingUnit.objects.filter(organization=organization) def perform_create(self, serializer): - serializer.save(organization=parse_organization(self.request)) + serializer.save(organization=self.request.organization) def get_serializer_context(self): context = super(PricingUnitViewSet, self).get_serializer_context() - organization = parse_organization(self.request) + organization = self.request.organization context.update({"organization": organization}) return context @@ -1442,7 +1412,7 @@ class OrganizationViewSet( lookup_field = "organization_id" def get_queryset(self): - organization = parse_organization(self.request) + organization = self.request.organization return Organization.objects.filter(pk=organization.pk) def get_object(self): @@ -1457,7 +1427,7 @@ def get_serializer_class(self): def get_serializer_context(self): context = super(OrganizationViewSet, self).get_serializer_context() - organization = parse_organization(self.request) + organization = self.request.organization context.update({"organization": organization}) return context @@ -1484,7 +1454,7 @@ class CustomerBalanceAdjustmentViewSet( lookup_field = "adjustment_id" def get_queryset(self): - filter_kwargs = {"organization": parse_organization(self.request)} + filter_kwargs = {"organization": self.request.organization} customer_id = self.request.query_params.get("customer_id") if customer_id: filter_kwargs["customer__customer_id"] = customer_id @@ -1492,12 +1462,12 @@ def get_queryset(self): def get_serializer_context(self): context = super(CustomerBalanceAdjustmentViewSet, self).get_serializer_context() - organization = parse_organization(self.request) + organization = self.request.organization context.update({"organization": organization}) return context def perform_create(self, serializer): - serializer.save(organization=parse_organization(self.request)) + serializer.save(organization=self.request.organization) @extend_schema( parameters=[ diff --git a/backend/metering_billing/views/organization_views.py b/backend/metering_billing/views/organization_views.py index 912c0484a..599551d7c 100644 --- a/backend/metering_billing/views/organization_views.py +++ b/backend/metering_billing/views/organization_views.py @@ -14,21 +14,6 @@ POSTHOG_PERSON = settings.POSTHOG_PERSON DEFAULT_FROM_EMAIL = settings.DEFAULT_FROM_EMAIL -# class OrganizationView(APIView): -# permission_classes = [IsAuthenticated] - -# def get(self, request, format=None): -# """ -# Get the current settings for the organization. -# """ -# organization = parse_organization(request) -# OrganizationSerializer(organization).data -# team_members = organization.org_users.all().values_list("email", flat=True) -# return Response( -# {"organization": organization.company_name, "team_members": team_members}, -# status=status.HTTP_200_OK, -# ) - class InviteView(APIView): permission_classes = [IsAuthenticated & ValidOrganization] @@ -36,7 +21,7 @@ class InviteView(APIView): def post(self, request, *args, **kwargs): email = request.data.get("email", None) user = request.user - organization = parse_organization(request) + organization = request.organization token_object, created = OrganizationInviteToken.objects.get_or_create( organization=organization, email=email, defaults={"user": user} diff --git a/backend/metering_billing/views/payment_provider_views.py b/backend/metering_billing/views/payment_provider_views.py index eba86100f..2f5aa806d 100644 --- a/backend/metering_billing/views/payment_provider_views.py +++ b/backend/metering_billing/views/payment_provider_views.py @@ -20,7 +20,7 @@ class PaymentProviderView(APIView): responses={200: SinglePaymentProviderSerializer(many=True)}, ) def get(self, request, format=None): - organization = parse_organization(request) + organization = request.organization response = [] for payment_processor_name, pp_obj in PAYMENT_PROVIDER_MAP.items(): pp_response = { @@ -38,7 +38,7 @@ def get(self, request, format=None): responses={200: PaymentProviderPostResponseSerializer}, ) def post(self, request, format=None): - organization = parse_organization(request) + organization = request.organization # parse outer level request serializer = PaymentProviderPostRequestSerializer(data=request.data) serializer.is_valid(raise_exception=True) diff --git a/backend/metering_billing/views/views.py b/backend/metering_billing/views/views.py index ff5d9d671..349a747d1 100644 --- a/backend/metering_billing/views/views.py +++ b/backend/metering_billing/views/views.py @@ -53,7 +53,7 @@ def get(self, request, format=None): """ Returns the revenue for an organization in a given time period. """ - organization = parse_organization(request) + organization = request.organization serializer = PeriodComparisonRequestSerializer(data=request.query_params) serializer.is_valid(raise_exception=True) p1_start, p1_end, p2_start, p2_end = [ @@ -135,7 +135,8 @@ def get(self, request, format=None): """ Returns the revenue for an organization in a given time period. """ - organization = parse_organization(request) + organization = request.organization + organization = request.organization serializer = CostAnalysisRequestSerializer(data=request.query_params) serializer.is_valid(raise_exception=True) start_date, end_date, customer_id = [ @@ -244,7 +245,7 @@ class PeriodSubscriptionsView(APIView): responses={200: PeriodSubscriptionsResponseSerializer}, ) def get(self, request, format=None): - organization = parse_organization(request) + organization = request.organization serializer = PeriodComparisonRequestSerializer(data=request.query_params) serializer.is_valid(raise_exception=True) p1_start, p1_end, p2_start, p2_end = [ @@ -297,7 +298,7 @@ def get(self, request, format=None): """ Return current usage for a customer during a given billing period. """ - organization = parse_organization(request) + organization = request.organization serializer = PeriodMetricUsageRequestSerializer(data=request.query_params) serializer.is_valid(raise_exception=True) q_start, q_end, top_n = [ @@ -404,7 +405,7 @@ def get(self, request, format=None): """ Revokes the current API key and returns a new one. """ - organization = parse_organization(request) + organization = request.organization tk = APIToken.objects.filter(organization=organization).first() if tk: cache.delete(tk.prefix) @@ -433,7 +434,7 @@ def get(self, request, format=None): """ Get the current settings for the organization. """ - organization = parse_organization(request) + organization = request.organization return Response( {"organization": organization.company_name}, status=status.HTTP_200_OK ) @@ -449,7 +450,7 @@ def get(self, request, format=None): """ Get the current settings for the organization. """ - organization = parse_organization(request) + organization = request.organization customers = Customer.objects.filter(organization=organization).prefetch_related( Prefetch( "subscription_records", @@ -479,17 +480,15 @@ def get(self, request, format=None): """ Return current usage for a customer during a given billing period. """ - organization = parse_organization(request) + organization = request.organization customers = Customer.objects.filter(organization=organization) cust = [] for customer in customers: total_amount_due = customer.get_outstanding_revenue() - next_amount_due = customer.get_active_sub_drafts_revenue() serializer = CustomerWithRevenueSerializer( customer, context={ "total_amount_due": total_amount_due, - "next_amount_due": next_amount_due, }, ) cust.append(serializer.data) @@ -513,7 +512,7 @@ def get(self, request, format=None): """ Pagination-enabled endpoint for retrieving an organization's event stream. """ - organization = parse_organization(request) + organization = request.organization serializer = DraftInvoiceRequestSerializer(data=request.query_params) serializer.is_valid(raise_exception=True) try: @@ -807,7 +806,7 @@ class ImportCustomersView(APIView): }, ) def post(self, request, format=None): - organization = parse_organization(request) + organization = request.organization source = request.data["source"] if source not in [choice[0] for choice in PAYMENT_PROVIDERS.choices]: raise ExternalConnectionInvalid(f"Invalid source: {source}") @@ -853,7 +852,7 @@ class ImportPaymentObjectsView(APIView): }, ) def post(self, request, format=None): - organization = parse_organization(request) + organization = request.organization source = request.data["source"] if source not in [choice[0] for choice in PAYMENT_PROVIDERS.choices]: raise ExternalConnectionInvalid(f"Invalid source: {source}") @@ -901,7 +900,7 @@ class TransferSubscriptionsView(APIView): }, ) def post(self, request, format=None): - organization = parse_organization(request) + organization = request.organization source = request.data["source"] if source not in [choice[0] for choice in PAYMENT_PROVIDERS.choices]: raise ExternalConnectionInvalid(f"Invalid source: {source}") @@ -943,7 +942,7 @@ class ExperimentalToActiveView(APIView): }, ) def post(self, request, format=None): - organization = parse_organization(request) + organization = request.organization serializer = ExperimentalToActiveRequestSerializer(data=request.data) serializer.is_valid(raise_exception=True) billing_plan = serializer.validated_data["version_id"] @@ -994,7 +993,7 @@ class PlansByNumCustomersView(APIView): }, ) def get(self, request, format=None): - organization = parse_organization(request) + organization = request.organization plans = ( SubscriptionRecord.objects.filter( organization=organization, status=SUBSCRIPTION_STATUS.ACTIVE @@ -1048,7 +1047,7 @@ class CustomerBatchCreateView(APIView): }, ) def post(self, request, format=None): - organization = parse_organization(request) + organization = request.organization serializer = CustomerSerializer( data=request.data["customers"], many=True, @@ -1140,7 +1139,7 @@ class ConfirmIdemsReceivedView(APIView): }, ) def post(self, request, format=None): - organization = parse_organization(request) + organization = request.organization if request.data.get("idempotency_ids") is None: return Response( { diff --git a/frontend/src/components/Customers/CustomerInfo.tsx b/frontend/src/components/Customers/CustomerInfo.tsx index f81dd4278..c58cb8b9c 100644 --- a/frontend/src/components/Customers/CustomerInfo.tsx +++ b/frontend/src/components/Customers/CustomerInfo.tsx @@ -1,7 +1,9 @@ // @ts-ignore import React, { FC, useEffect } from "react"; import { Column } from "@ant-design/plots"; +import { useQueryClient } from "react-query"; import { Select, Tag } from "antd"; +import { DraftInvoiceType, LineItem } from "../../types/invoice-type"; // @ts-ignore import dayjs from "dayjs"; import LoadingSpinner from "../LoadingSpinner"; @@ -14,6 +16,12 @@ const CustomerInfoView: FC = ({ data, cost_data, onDateChange }) => { [] ); const [isEdit, setIsEdit] = React.useState(false); + const queryClient = useQueryClient(); + + let invoiceData: DraftInvoiceType | undefined = queryClient.getQueryData([ + "draft_invoice", + data.customer_id, + ]); const updateCustomer = useMutation( (obj: { customer_id: string; default_currency_code: string }) => @@ -155,7 +163,7 @@ const CustomerInfoView: FC = ({ data, cost_data, onDateChange }) => {

Amount Due On Next Invoice: {data?.default_currency?.symbol} - {data.next_amount_due.toFixed(2)} + {invoiceData?.invoices[0].cost_due.toFixed(2)}