Skip to content

Commit

Permalink
fix: Check for active subscription to determine billing version (#12330)
Browse files Browse the repository at this point in the history
  • Loading branch information
benjackwhite authored Oct 19, 2022
1 parent 8f57182 commit 683a4aa
Show file tree
Hide file tree
Showing 8 changed files with 49 additions and 25 deletions.
29 changes: 18 additions & 11 deletions ee/api/billing.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,12 +131,12 @@ def list(self, request: HttpRequest, *args: Any, **kwargs: Any) -> Response:

# If on Cloud and we have the property billing - return 404 as we always use legacy billing it it exists
if hasattr(org, "billing"):
if org.billing: # type: ignore
raise NotFound("Billing V2 is not enabled for this organization")
if org.billing.stripe_subscription_id: # type: ignore
raise NotFound("Billing V1 is active for this organization")

response: Dict[str, Any] = {}

if license:
if org and license and license.is_v2_license:
response["license"] = {"plan": license.plan}
billing_service_token = build_billing_token(license, str(org.id))

Expand Down Expand Up @@ -164,11 +164,12 @@ def list(self, request: HttpRequest, *args: Any, **kwargs: Any) -> Response:
# The default response is used if there is no subscription
if not response.get("products"):
products = self._get_products()
calculated_usage = get_cached_current_usage(org)
calculated_usage = get_cached_current_usage(org) if org else None

for product in products:
if product["type"] in calculated_usage:
product["current_usage"] = calculated_usage[product["type"]]
if calculated_usage is not None:
for product in products:
if product["type"] in calculated_usage:
product["current_usage"] = calculated_usage[product["type"]]
response["products"] = products

return Response(response)
Expand All @@ -178,7 +179,7 @@ def patch(self, request: Request, *args: Any, **kwargs: Any) -> Response:
license = License.objects.first_valid()
if not license:
raise Exception("There is no license configured for this instance yet.")
org = self._get_org()
org = self._get_org_required()

billing_service_token = build_billing_token(license, str(org.id))

Expand All @@ -201,7 +202,7 @@ def patch(self, request: Request, *args: Any, **kwargs: Any) -> Response:
@action(methods=["GET"], detail=False)
def activation(self, request: Request, *args: Any, **kwargs: Any) -> HttpResponse:
license = License.objects.first_valid()
organization = self._get_org()
organization = self._get_org_required()

redirect_uri = f"{settings.SITE_URL or request.headers.get('Host')}/organization/billing"
url = f"{BILLING_SERVICE_URL}/activation?redirect_uri={redirect_uri}&organization_name={organization.name}"
Expand All @@ -221,7 +222,8 @@ def license(self, request: Request, *args: Any, **kwargs: Any) -> HttpResponse:
"A valid license key already exists. This must be removed before a new one can be added."
)

organization = self._get_org()
organization = self._get_org_required()

serializer = LicenseKeySerializer(data=request.data)
if not serializer.is_valid():
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
Expand All @@ -243,9 +245,14 @@ def license(self, request: Request, *args: Any, **kwargs: Any) -> HttpResponse:
self._update_license_details(license, data["license"])
return Response({"success": True})

def _get_org(self) -> Organization:
def _get_org(self) -> Optional[Organization]:
org = None if self.request.user.is_anonymous else self.request.user.organization

return org

def _get_org_required(self) -> Organization:
org = self._get_org()

if not org:
raise Exception("You cannot setup billing without an organization configured.")

Expand Down
4 changes: 2 additions & 2 deletions ee/api/test/test_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ def test_actions_does_not_nplus1(self):
action = Action.objects.create(team=self.team, name=f"action_{i}")
action.tagged_items.create(tag=tag)

# django_session + user + team + organizationmembership + organization + action + taggeditem + actionstep + cloud license check
with self.assertNumQueries(9):
# django_session + user + team + organizationmembership + organization + action + taggeditem + actionstep
with self.assertNumQueries(8):
response = self.client.get(f"/api/projects/{self.team.id}/actions")
self.assertEqual(response.json()["results"][0]["tags"][0], "tag")
self.assertEqual(response.status_code, status.HTTP_200_OK)
Expand Down
4 changes: 4 additions & 0 deletions ee/models/license.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,10 @@ class License(models.Model):
def available_features(self) -> List[AvailableFeature]:
return self.PLANS.get(self.plan, [])

@property
def is_v2_license(self) -> bool:
return self.key and len(self.key.split("::")) == 2

__repr__ = sane_repr("key", "plan", "valid_until")


Expand Down
2 changes: 1 addition & 1 deletion ee/tasks/send_license_usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def send_license_usage():
return

# New type of license key for billing-v2
if "::" in license.key:
if license.is_v2_license:
return

try:
Expand Down
22 changes: 15 additions & 7 deletions posthog/api/test/test_preflight.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from django.utils import timezone
from rest_framework import status

from posthog.cloud_utils import TEST_clear_cloud_cache
from posthog.models.instance_setting import set_instance_setting
from posthog.models.organization import Organization, OrganizationInvite
from posthog.test.base import APIBaseTest
Expand Down Expand Up @@ -60,6 +61,11 @@ def preflight_authenticated_dict(self, options={}):

return self.preflight_dict(preflight)

def settings_with_cloud_cache_reset(self, **kwargs):
TEST_clear_cloud_cache()

return self.settings(**kwargs)

def test_preflight_request_unauthenticated(self):
"""
For security purposes, the information contained in an unauthenticated preflight request is minimal.
Expand All @@ -72,7 +78,7 @@ def test_preflight_request_unauthenticated(self):
self.assertEqual(response.json(), self.preflight_dict())

def test_preflight_request(self):
with self.settings(
with self.settings_with_cloud_cache_reset(
MULTI_TENANCY=False,
INSTANCE_PREFERENCES=self.instance_preferences(debug_queries=True),
OBJECT_STORAGE_ENABLED=False,
Expand All @@ -89,7 +95,7 @@ def test_preflight_request(self):
def test_preflight_request_with_object_storage_available(self, patched_s3_client):
patched_s3_client.head_bucket.return_value = True

with self.settings(
with self.settings_with_cloud_cache_reset(
MULTI_TENANCY=False,
INSTANCE_PREFERENCES=self.instance_preferences(debug_queries=True),
OBJECT_STORAGE_ENABLED=True,
Expand All @@ -109,7 +115,7 @@ def test_cloud_preflight_request_unauthenticated(self):

self.client.logout() # make sure it works anonymously

with self.settings(MULTI_TENANCY=True, OBJECT_STORAGE_ENABLED=False):
with self.settings_with_cloud_cache_reset(MULTI_TENANCY=True, OBJECT_STORAGE_ENABLED=False):
response = self.client.get("/_preflight/")
self.assertEqual(response.status_code, status.HTTP_200_OK)

Expand Down Expand Up @@ -154,7 +160,7 @@ def test_cloud_preflight_request(self):
def test_cloud_preflight_request_with_social_auth_providers(self):
set_instance_setting("EMAIL_HOST", "localhost")

with self.settings(
with self.settings_with_cloud_cache_reset(
SOCIAL_AUTH_GOOGLE_OAUTH2_KEY="test_key",
SOCIAL_AUTH_GOOGLE_OAUTH2_SECRET="test_secret",
MULTI_TENANCY=True,
Expand Down Expand Up @@ -187,7 +193,7 @@ def test_cloud_preflight_request_with_social_auth_providers(self):
def test_demo(self):
self.client.logout() # make sure it works anonymously

with self.settings(DEMO=True, OBJECT_STORAGE_ENABLED=False):
with self.settings_with_cloud_cache_reset(DEMO=True, OBJECT_STORAGE_ENABLED=False):
response = self.client.get("/_preflight/")

self.assertEqual(response.status_code, status.HTTP_200_OK)
Expand Down Expand Up @@ -223,7 +229,7 @@ def test_can_create_org_in_fresh_instance(self):
@pytest.mark.skip_on_multitenancy
def test_can_create_org_with_multi_org(self):
# First with no license
with self.settings(MULTI_ORG_ENABLED=True):
with self.settings_with_cloud_cache_reset(MULTI_ORG_ENABLED=True):
response = self.client.get("/_preflight/")
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.json()["can_create_org"], False)
Expand All @@ -236,7 +242,7 @@ def test_can_create_org_with_multi_org(self):
super(LicenseManager, cast(LicenseManager, License.objects)).create(
key="key_123", plan="enterprise", valid_until=timezone.datetime(2038, 1, 19, 3, 14, 7)
)
with self.settings(MULTI_ORG_ENABLED=True):
with self.settings_with_cloud_cache_reset(MULTI_ORG_ENABLED=True):
response = self.client.get("/_preflight/")
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.json()["can_create_org"], True)
Expand All @@ -252,6 +258,8 @@ def test_cloud_preflight_based_on_license(self):
key="key::123", plan="cloud", valid_until=timezone.datetime(2038, 1, 19, 3, 14, 7)
)

TEST_clear_cloud_cache()

response = self.client.get("/_preflight/")
assert response.status_code == status.HTTP_200_OK
assert response.json()["realm"] == "cloud"
Expand Down
8 changes: 7 additions & 1 deletion posthog/cloud_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
def is_cloud():
global is_cloud_cached

if not settings.TEST and isinstance(is_cloud_cached, bool):
if isinstance(is_cloud_cached, bool):
return is_cloud_cached

try:
Expand All @@ -22,3 +22,9 @@ def is_cloud():
# TRICKY - The license table may not exist if a migration is running
except (ImportError, ProgrammingError):
return False


# NOTE: This is purely for testing purposes
def TEST_clear_cloud_cache():
global is_cloud_cached
is_cloud_cached = None
3 changes: 1 addition & 2 deletions posthog/test/test_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,7 @@ def test_trust_all_proxies(self):
class TestAutoProjectMiddleware(APIBaseTest):
# How many queries are made in the base app
# On Cloud there's an additional multi_tenancy_organizationbilling query
IS_CLOUD_QUERIES = 6 # Checks to is_cloud hit the DB in TEST mode
BASE_APP_NUM_QUERIES = IS_CLOUD_QUERIES + (39 if not settings.MULTI_TENANCY else 40)
BASE_APP_NUM_QUERIES = 39 if not settings.MULTI_TENANCY else 40

second_team: Team

Expand Down
2 changes: 1 addition & 1 deletion pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ env =
DEBUG=1
TEST=1
DJANGO_SETTINGS_MODULE = posthog.settings
addopts = -p no:warnings -p no:randomly --reuse-db
addopts = -p no:warnings --reuse-db

markers =
ee
Expand Down

0 comments on commit 683a4aa

Please sign in to comment.