Skip to content

Commit

Permalink
chore: move cohort calculation to longrunning queue (PostHog#28116)
Browse files Browse the repository at this point in the history
  • Loading branch information
aspicer authored Feb 1, 2025
1 parent 1253d96 commit e66bcab
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 16 deletions.
8 changes: 4 additions & 4 deletions posthog/api/cohort.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@
calculate_cohort_from_list,
insert_cohort_from_feature_flag,
insert_cohort_from_insight_filter,
update_cohort,
increment_version_and_enqueue_calculate_cohort,
insert_cohort_from_query,
)
from posthog.utils import format_query_params_absolute_url
Expand Down Expand Up @@ -161,7 +161,7 @@ def create(self, validated_data: dict, *args: Any, **kwargs: Any) -> Cohort:
elif cohort.query is not None:
raise ValidationError("Cannot create a dynamic cohort with a query. Set is_static to true.")
else:
update_cohort(cohort, initiating_user=request.user)
increment_version_and_enqueue_calculate_cohort(cohort, initiating_user=request.user)

report_user_action(request.user, "cohort created", cohort.get_analytics_metadata())
return cohort
Expand Down Expand Up @@ -274,9 +274,9 @@ def update(self, cohort: Cohort, validated_data: dict, *args: Any, **kwargs: Any
if request.FILES.get("csv"):
self._calculate_static_by_csv(request.FILES["csv"], cohort)
else:
update_cohort(cohort, initiating_user=request.user)
increment_version_and_enqueue_calculate_cohort(cohort, initiating_user=request.user)
else:
update_cohort(cohort, initiating_user=request.user)
increment_version_and_enqueue_calculate_cohort(cohort, initiating_user=request.user)

report_user_action(
request.user,
Expand Down
9 changes: 5 additions & 4 deletions posthog/tasks/calculate_cohort.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from posthog.models.cohort import get_and_update_pending_version
from posthog.models.cohort.util import clear_stale_cohortpeople, get_static_cohort_size
from posthog.models.user import User
from posthog.tasks.utils import CeleryQueue

COHORT_RECALCULATIONS_BACKLOG_GAUGE = Gauge(
"cohort_recalculations_backlog",
Expand All @@ -35,7 +36,7 @@
MAX_AGE_MINUTES = 15


def calculate_cohorts(parallel_count: int) -> None:
def enqueue_cohorts_to_calculate(parallel_count: int) -> None:
"""
Calculates maximum N cohorts in parallel.
Expand Down Expand Up @@ -68,7 +69,7 @@ def calculate_cohorts(parallel_count: int) -> None:
.order_by(F("last_calculation").asc(nulls_first=True))[0:parallel_count]
):
cohort = Cohort.objects.filter(pk=cohort.pk).get()
update_cohort(cohort, initiating_user=None)
increment_version_and_enqueue_calculate_cohort(cohort, initiating_user=None)

# update gauge
backlog = (
Expand All @@ -84,7 +85,7 @@ def calculate_cohorts(parallel_count: int) -> None:
COHORT_RECALCULATIONS_BACKLOG_GAUGE.set(backlog)


def update_cohort(cohort: Cohort, *, initiating_user: Optional[User]) -> None:
def increment_version_and_enqueue_calculate_cohort(cohort: Cohort, *, initiating_user: Optional[User]) -> None:
pending_version = get_and_update_pending_version(cohort)
calculate_cohort_ch.delay(cohort.id, pending_version, initiating_user.id if initiating_user else None)

Expand All @@ -95,7 +96,7 @@ def clear_stale_cohort(cohort_id: int, before_version: int) -> None:
clear_stale_cohortpeople(cohort, before_version)


@shared_task(ignore_result=True, max_retries=2)
@shared_task(ignore_result=True, max_retries=2, queue=CeleryQueue.LONG_RUNNING.value)
def calculate_cohort_ch(cohort_id: int, pending_version: int, initiating_user_id: Optional[int] = None) -> None:
cohort: Cohort = Cohort.objects.get(pk=cohort_id)

Expand Down
6 changes: 3 additions & 3 deletions posthog/tasks/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,11 +565,11 @@ def monitoring_check_clickhouse_schema_drift() -> None:
check_clickhouse_schema_drift()


@shared_task(ignore_result=True, queue=CeleryQueue.LONG_RUNNING.value)
@shared_task(ignore_result=True)
def calculate_cohort(parallel_count: int) -> None:
from posthog.tasks.calculate_cohort import calculate_cohorts
from posthog.tasks.calculate_cohort import enqueue_cohorts_to_calculate

calculate_cohorts(parallel_count)
enqueue_cohorts_to_calculate(parallel_count)


class Polling:
Expand Down
10 changes: 5 additions & 5 deletions posthog/tasks/test/test_calculate_cohort.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from posthog.models.cohort import Cohort
from posthog.models.person import Person
from posthog.tasks.calculate_cohort import calculate_cohort_from_list, calculate_cohorts, MAX_AGE_MINUTES
from posthog.tasks.calculate_cohort import calculate_cohort_from_list, enqueue_cohorts_to_calculate, MAX_AGE_MINUTES
from posthog.test.base import APIBaseTest


Expand Down Expand Up @@ -67,8 +67,8 @@ def test_create_trends_cohort(self, _calculate_cohort_from_list: MagicMock) -> N
people = Person.objects.filter(cohort__id=cohort.pk)
self.assertEqual(people.count(), 1)

@patch("posthog.tasks.calculate_cohort.update_cohort")
def test_exponential_backoff(self, patch_update_cohort: MagicMock) -> None:
@patch("posthog.tasks.calculate_cohort.increment_version_and_enqueue_calculate_cohort")
def test_exponential_backoff(self, patch_increment_version_and_enqueue_calculate_cohort: MagicMock) -> None:
# Exponential backoff
Cohort.objects.create(
last_calculation=timezone.now() - relativedelta(minutes=MAX_AGE_MINUTES + 1),
Expand All @@ -88,7 +88,7 @@ def test_exponential_backoff(self, patch_update_cohort: MagicMock) -> None:
errors_calculating=1,
team_id=self.team.pk,
)
calculate_cohorts(5)
self.assertEqual(patch_update_cohort.call_count, 2)
enqueue_cohorts_to_calculate(5)
self.assertEqual(patch_increment_version_and_enqueue_calculate_cohort.call_count, 2)

return TestCalculateCohort

0 comments on commit e66bcab

Please sign in to comment.