diff --git a/ee/billing/billing_manager.py b/ee/billing/billing_manager.py index 08cd72ce19037..33a68604c7ecf 100644 --- a/ee/billing/billing_manager.py +++ b/ee/billing/billing_manager.py @@ -1,3 +1,5 @@ +from django.conf import settings +from django.db.models import F from datetime import datetime, timedelta from enum import Enum from typing import Any, Optional, cast @@ -5,19 +7,17 @@ import jwt import requests import structlog -from django.conf import settings -from django.db.models import F from django.utils import timezone -from requests import JSONDecodeError -from rest_framework.exceptions import NotAuthenticated from sentry_sdk import capture_message +from requests import JSONDecodeError # type: ignore[attr-defined] +from rest_framework.exceptions import NotAuthenticated +from posthog.exceptions_capture import capture_exception from ee.billing.billing_types import BillingStatus from ee.billing.quota_limiting import set_org_usage_summary, update_org_billing_quotas from ee.models import License from ee.settings import BILLING_SERVICE_URL from posthog.cloud_utils import get_cached_instance_license -from posthog.exceptions_capture import capture_exception from posthog.models import Organization from posthog.models.organization import OrganizationMembership, OrganizationUsageInfo from posthog.models.user import User diff --git a/posthog/api/test/test_exports.py b/posthog/api/test/test_exports.py index 4651ee2037082..bba55310e5397 100644 --- a/posthog/api/test/test_exports.py +++ b/posthog/api/test/test_exports.py @@ -1,7 +1,6 @@ -from datetime import datetime, timedelta from typing import Optional from unittest.mock import patch - +from datetime import datetime, timedelta import celery import requests.exceptions from boto3 import resource @@ -398,7 +397,7 @@ def requests_side_effect(*args, **kwargs): def raise_for_status(): if 400 <= response.status_code < 600: - raise requests.exceptions.HTTPError(response=response) # type: ignore[arg-type] + raise requests.exceptions.HTTPError(response=response) response.raise_for_status = raise_for_status # type: ignore[attr-defined] return response @@ -502,7 +501,7 @@ def requests_side_effect(*args, **kwargs): def raise_for_status(): if 400 <= response.status_code < 600: - raise requests.exceptions.HTTPError(response=response) # type: ignore[arg-type] + raise requests.exceptions.HTTPError(response=response) response.raise_for_status = raise_for_status # type: ignore[attr-defined] return response diff --git a/posthog/api/utils.py b/posthog/api/utils.py index 7ac30f8eb82d3..52e4cf925cc48 100644 --- a/posthog/api/utils.py +++ b/posthog/api/utils.py @@ -1,28 +1,30 @@ import json +from django.http import HttpRequest +from rest_framework.decorators import action as drf_action +from functools import wraps +from posthog.api.documentation import extend_schema import re import socket import urllib.parse from enum import Enum, auto -from functools import wraps from ipaddress import ip_address -from typing import Any, Literal, Optional, Union from urllib.parse import urlparse + +from requests.adapters import HTTPAdapter +from typing import Literal, Optional, Union, Any + +from rest_framework.fields import Field +from urllib3 import HTTPSConnectionPool, HTTPConnectionPool, PoolManager from uuid import UUID import structlog from django.core.exceptions import RequestDataTooBig from django.db.models import QuerySet -from django.http import HttpRequest from prometheus_client import Counter -from requests.adapters import HTTPAdapter -from rest_framework import 
request, serializers, status -from rest_framework.decorators import action as drf_action +from rest_framework import request, status, serializers from rest_framework.exceptions import ValidationError -from rest_framework.fields import Field from statshog.defaults.django import statsd -from urllib3 import HTTPConnectionPool, HTTPSConnectionPool, PoolManager -from posthog.api.documentation import extend_schema from posthog.constants import EventDefinitionType from posthog.exceptions import ( RequestParsingError, @@ -363,13 +365,13 @@ def raise_if_connected_to_private_ip(conn): class PublicIPOnlyHTTPConnectionPool(HTTPConnectionPool): def _validate_conn(self, conn): raise_if_connected_to_private_ip(conn) - super()._validate_conn(conn) # type: ignore[misc] + super()._validate_conn(conn) class PublicIPOnlyHTTPSConnectionPool(HTTPSConnectionPool): def _validate_conn(self, conn): raise_if_connected_to_private_ip(conn) - super()._validate_conn(conn) # type: ignore[misc] + super()._validate_conn(conn) class PublicIPOnlyHttpAdapter(HTTPAdapter): @@ -388,7 +390,7 @@ def init_poolmanager(self, connections, maxsize, block=False, **pool_kwargs): block=block, **pool_kwargs, ) - self.poolmanager.pool_classes_by_scheme = { # type: ignore[attr-defined] + self.poolmanager.pool_classes_by_scheme = { "http": PublicIPOnlyHTTPConnectionPool, "https": PublicIPOnlyHTTPSConnectionPool, } diff --git a/posthog/tasks/exports/test/test_csv_exporter.py b/posthog/tasks/exports/test/test_csv_exporter.py index caa4f8689b583..8d3cf5f182f27 100644 --- a/posthog/tasks/exports/test/test_csv_exporter.py +++ b/posthog/tasks/exports/test/test_csv_exporter.py @@ -1,20 +1,19 @@ from datetime import datetime -from io import BytesIO from typing import Any, Optional from unittest import mock -from unittest.mock import ANY, MagicMock, Mock, patch +from unittest.mock import MagicMock, Mock, patch, ANY +from dateutil.relativedelta import relativedelta +from freezegun import freeze_time +from openpyxl import load_workbook +from io import BytesIO import pytest from boto3 import resource from botocore.client import Config -from dateutil.relativedelta import relativedelta from django.test import override_settings from django.utils.timezone import now -from freezegun import freeze_time -from openpyxl import load_workbook from requests.exceptions import HTTPError -from posthog.hogql.constants import CSV_EXPORT_BREAKDOWN_LIMIT_INITIAL from posthog.models import ExportedAsset from posthog.models.utils import UUIDT from posthog.settings import ( @@ -28,10 +27,11 @@ from posthog.tasks.exports import csv_exporter from posthog.tasks.exports.csv_exporter import ( UnexpectedEmptyJsonResponse, - _convert_response_to_csv_data, add_query_params, + _convert_response_to_csv_data, ) -from posthog.test.base import APIBaseTest, _create_event, _create_person, flush_persons_and_events +from posthog.hogql.constants import CSV_EXPORT_BREAKDOWN_LIMIT_INITIAL +from posthog.test.base import APIBaseTest, _create_event, flush_persons_and_events, _create_person from posthog.test.test_journeys import journeys_for from posthog.utils import absolute_uri @@ -330,7 +330,7 @@ def test_failing_export_api_is_reported(self, _mock_logger: MagicMock) -> None: def test_failing_export_api_is_reported_query_size_exceeded(self, _mock_logger: MagicMock) -> None: with patch("posthog.tasks.exports.csv_exporter.make_api_call") as patched_make_api_call: exported_asset = self._create_asset() - mock_error = HTTPError("Query size exceeded") # type: ignore[call-arg] + mock_error = 
HTTPError("Query size exceeded") mock_error.response = Mock() mock_error.response.text = "Query size exceeded" patched_make_api_call.side_effect = mock_error diff --git a/posthog/temporal/data_imports/external_data_job.py b/posthog/temporal/data_imports/external_data_job.py index 4344afdf632e3..b9da1d015c02f 100644 --- a/posthog/temporal/data_imports/external_data_job.py +++ b/posthog/temporal/data_imports/external_data_job.py @@ -3,33 +3,32 @@ import json import re -import posthoganalytics from django.db import close_old_connections +import posthoganalytics from temporalio import activity, exceptions, workflow from temporalio.common import RetryPolicy + # TODO: remove dependency from posthog.temporal.common.base import PostHogWorkflow -from posthog.temporal.common.logger import bind_temporal_worker_logger_sync from posthog.temporal.data_imports.workflow_activities.check_billing_limits import ( CheckBillingLimitsActivityInputs, check_billing_limits_activity, ) -from posthog.temporal.data_imports.workflow_activities.create_job_model import ( - CreateExternalDataJobModelActivityInputs, - create_external_data_job_model_activity, -) -from posthog.temporal.data_imports.workflow_activities.import_data_sync import ( - ImportDataActivityInputs, - import_data_activity_sync, -) +from posthog.temporal.data_imports.workflow_activities.import_data_sync import import_data_activity_sync from posthog.temporal.data_imports.workflow_activities.sync_new_schemas import ( SyncNewSchemasActivityInputs, sync_new_schemas_activity, ) from posthog.temporal.utils import ExternalDataWorkflowInputs +from posthog.temporal.data_imports.workflow_activities.create_job_model import ( + CreateExternalDataJobModelActivityInputs, + create_external_data_job_model_activity, +) +from posthog.temporal.data_imports.workflow_activities.import_data_sync import ImportDataActivityInputs from posthog.utils import get_machine_id from posthog.warehouse.data_load.source_templates import create_warehouse_templates_for_source + from posthog.warehouse.external_data_source.jobs import ( update_external_job_status, ) @@ -37,6 +36,7 @@ ExternalDataJob, ExternalDataSource, ) +from posthog.temporal.common.logger import bind_temporal_worker_logger_sync from posthog.warehouse.models.external_data_schema import update_should_sync Any_Source_Errors: list[str] = ["Could not establish session to SSH gateway"] @@ -68,10 +68,6 @@ "No primary key defined for table", "Access denied for user", ], - ExternalDataSource.Type.SALESFORCE: [ - "400 Client Error: Bad Request for url", - "403 Client Error: Forbidden for url", - ], ExternalDataSource.Type.SNOWFLAKE: [ "This account has been marked for decommission", "404 Not Found", diff --git a/posthog/temporal/data_imports/pipelines/pipeline/test/test_pipeline_utils.py b/posthog/temporal/data_imports/pipelines/pipeline/test/test_pipeline_utils.py index f9ab6a1807f63..ba93086583936 100644 --- a/posthog/temporal/data_imports/pipelines/pipeline/test/test_pipeline_utils.py +++ b/posthog/temporal/data_imports/pipelines/pipeline/test/test_pipeline_utils.py @@ -1,12 +1,9 @@ +from ipaddress import IPv4Address, IPv6Address +from dateutil import parser import decimal import uuid -from ipaddress import IPv4Address, IPv6Address - import pyarrow as pa -import pytest -from dateutil import parser - -from posthog.temporal.data_imports.pipelines.pipeline.utils import _get_max_decimal_type, table_from_py_list +from posthog.temporal.data_imports.pipelines.pipeline.utils import table_from_py_list def test_table_from_py_list_uuid(): 
@@ -225,26 +222,6 @@ def test_table_from_py_list_with_schema_and_too_small_decimal_type(): assert table.schema.equals(expected_schema) -@pytest.mark.parametrize( - "decimals,expected", - [ - ([decimal.Decimal("1")], pa.decimal128(2, 1)), - ([decimal.Decimal("1.001112")], pa.decimal128(7, 6)), - ([decimal.Decimal("0.001112")], pa.decimal128(6, 6)), - ([decimal.Decimal("1.0100000")], pa.decimal128(8, 7)), - # That is 1 followed by 37 zeroes to go over the pa.Decimal128 precision limit of 38. - ([decimal.Decimal("10000000000000000000000000000000000000.1")], pa.decimal256(39, 1)), - ], -) -def test_get_max_decimal_type_returns_correct_decimal_type( - decimals: list[decimal.Decimal], - expected: pa.Decimal128Type | pa.Decimal256Type, -): - """Test whether expected PyArrow decimal type variant is returned.""" - result = _get_max_decimal_type(decimals) - assert result == expected - - def test_table_from_py_list_with_ipv4_address(): table = table_from_py_list([{"column": IPv4Address("127.0.0.1")}]) diff --git a/posthog/temporal/data_imports/pipelines/pipeline/utils.py b/posthog/temporal/data_imports/pipelines/pipeline/utils.py index 435fe51c42bc6..3dcb8ae153e38 100644 --- a/posthog/temporal/data_imports/pipelines/pipeline/utils.py +++ b/posthog/temporal/data_imports/pipelines/pipeline/utils.py @@ -1,25 +1,25 @@ import decimal from ipaddress import IPv4Address, IPv6Address import json +from collections.abc import Sequence import math -import uuid -from collections.abc import Hashable, Iterator, Sequence from typing import Any, Optional - -import deltalake as deltalake -import numpy as np +from collections.abc import Hashable +from collections.abc import Iterator +from dateutil import parser +import uuid import orjson +import numpy as np import pandas as pd import pyarrow as pa import pyarrow.compute as pc -from dateutil import parser -from django.db.models import F -from dlt.common.data_types.typing import TDataType from dlt.common.libs.deltalake import ensure_delta_compatible_arrow_schema -from dlt.common.normalizers.naming.snake_case import NamingConvention from dlt.sources import DltResource - +import deltalake as deltalake +from django.db.models import F from posthog.temporal.common.logger import FilteringBoundLogger +from dlt.common.data_types.typing import TDataType +from dlt.common.normalizers.naming.snake_case import NamingConvention from posthog.temporal.data_imports.pipelines.pipeline.typings import SourceResponse from posthog.warehouse.models import ExternalDataJob, ExternalDataSchema @@ -329,29 +329,15 @@ def build_pyarrow_decimal_type(precision: int, scale: int) -> pa.Decimal128Type def _get_max_decimal_type(values: list[decimal.Decimal]) -> pa.Decimal128Type | pa.Decimal256Type: - """Determine maximum precision and scale from all `decimal.Decimal` values. - - Returns: - A `pa.Decimal128Type` or `pa.Decimal256Type` with enough precision and - scale to hold all `values`. 
- """ max_precision = 1 max_scale = 0 for value in values: - _, digits, exponent = value.as_tuple() + sign, digits, exponent = value.as_tuple() if not isinstance(exponent, int): continue - - # This implementation accounts for leading zeroes being excluded from digits - # It is based on Arrow, see: - # https://github.com/apache/arrow/blob/main/python/pyarrow/src/arrow/python/decimal.cc#L75 - if exponent < 0: - precision = max(len(digits), -exponent) - scale = -exponent - else: - precision = len(digits) + exponent - scale = 0 + precision = len(digits) + scale = -exponent if exponent < 0 else 0 max_precision = max(precision, max_precision) max_scale = max(scale, max_scale) diff --git a/posthog/temporal/data_imports/pipelines/salesforce/__init__.py b/posthog/temporal/data_imports/pipelines/salesforce/__init__.py index 1d5518e3dd3fd..65bfdc928a5ba 100644 --- a/posthog/temporal/data_imports/pipelines/salesforce/__init__.py +++ b/posthog/temporal/data_imports/pipelines/salesforce/__init__.py @@ -1,16 +1,13 @@ -import re from datetime import datetime from typing import Any, Optional -from urllib.parse import urlencode - import dlt -from dlt.sources.helpers.requests import Request, Response +from urllib.parse import urlencode from dlt.sources.helpers.rest_client.paginators import BasePaginator - -from posthog.temporal.common.logger import get_internal_logger +from dlt.sources.helpers.requests import Response, Request from posthog.temporal.data_imports.pipelines.rest_source import RESTAPIConfig, rest_api_resources from posthog.temporal.data_imports.pipelines.rest_source.typing import EndpointResource from posthog.temporal.data_imports.pipelines.salesforce.auth import SalseforceAuth +import re # Note: When pulling all fields, salesforce requires a 200 limit. We circumvent the pagination by using Id ordering. 
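The note above is what the paginator below relies on: Salesforce caps SELECT FIELDS(ALL) queries at 200 rows, so rather than an offset cursor, each page is ordered by Id and the follow-up query filters on Id greater than the last record seen. A minimal standalone sketch of that keyset loop, where run_soql is a hypothetical helper standing in for the authenticated query call and is not part of this diff:

    def fetch_all(run_soql, model_name: str) -> list[dict]:
        # Keyset pagination: resume after the last seen Id instead of paging
        # by offset, since each FIELDS(ALL) query returns at most 200 rows.
        records: list[dict] = []
        last_id = None
        while True:
            where = f"WHERE Id > '{last_id}' " if last_id else ""
            page = run_soql(f"SELECT FIELDS(ALL) FROM {model_name} {where}ORDER BY Id ASC LIMIT 200")
            if not page:  # an empty page means the previous one was the last
                return records
            records.extend(page)
            last_id = page[-1]["Id"]
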
@@ -313,36 +310,19 @@ def __init__(self, instance_url, is_incremental: bool):
         super().__init__()
         self.instance_url = instance_url
         self.is_incremental = is_incremental
-        self.logger = get_internal_logger()
-
-    def __repr__(self):
-        pairs = (
-            f"{attr}={repr(getattr(self, attr))}"
-            for attr in ("is_incremental", "_has_next_page", "_model_name", "_last_record_id")
-        )
-        return f"<{type(self).__name__} {' '.join(pairs)}>"
 
     def update_state(self, response: Response, data: Optional[list[Any]] = None) -> None:
         res = response.json()
 
+        self._next_page = None
+
         if not res or not res["records"]:
             self._has_next_page = False
-            self.logger.debug(
-                "No more Salesforce pages", instance_url=self.instance_url, is_incremental=self.is_incremental
-            )
             return
 
         last_record = res["records"][-1]
         model_name = res["records"][0]["attributes"]["type"]
 
-        self.logger.debug(
-            "More Salesforce pages required",
-            instance_url=self.instance_url,
-            is_incremental=self.is_incremental,
-            model_name=model_name,
-            last_record_id=last_record["Id"],
-        )
-
         self._has_next_page = True
         self._last_record_id = last_record["Id"]
         self._model_name = model_name
@@ -352,29 +332,12 @@ def update_request(self, request: Request) -> None:
             # Cludge: Need to get initial value for date filter
             query = request.params.get("q", "")
             date_match = re.search(r"SystemModstamp >= (\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}.+?)\s", query)
-
-            self.logger.debug(
-                "Constructing incremental query",
-                instance_url=self.instance_url,
-                is_incremental=self.is_incremental,
-                model_name=self._model_name,
-                last_record_id=self._last_record_id,
-            )
-
             if date_match:
                 date_filter = date_match.group(1)
                 query = f"SELECT FIELDS(ALL) FROM {self._model_name} WHERE Id > '{self._last_record_id}' AND SystemModstamp >= {date_filter} ORDER BY Id ASC LIMIT 200"
             else:
                 raise ValueError("No date filter found in initial query. Incremental loading requires a date filter.")
         else:
-            self.logger.debug(
-                "Constructing non-incremental query",
-                instance_url=self.instance_url,
-                is_incremental=self.is_incremental,
-                model_name=self._model_name,
-                last_record_id=self._last_record_id,
-            )
-
             query = f"SELECT FIELDS(ALL) FROM {self._model_name} WHERE Id > '{self._last_record_id}' ORDER BY Id ASC LIMIT 200"
 
         _next_page = f"/services/data/v61.0/query" + "?" + urlencode({"q": query})
diff --git a/posthog/temporal/data_imports/pipelines/salesforce/auth.py b/posthog/temporal/data_imports/pipelines/salesforce/auth.py
index 873e1952f4a3b..95f267516b26c 100644
--- a/posthog/temporal/data_imports/pipelines/salesforce/auth.py
+++ b/posthog/temporal/data_imports/pipelines/salesforce/auth.py
@@ -25,40 +25,6 @@ def obtain_token(self) -> None:
         self.token_expiry = pendulum.now().add(hours=1)
 
 
-class SalesforceAuthRequestError(Exception):
-    """Exception to capture errors when an auth request fails."""
-
-    def __init__(self, error_message: str, response: requests.Response):
-        self.response = response
-        super().__init__(error_message)
-
-    @classmethod
-    def raise_from_response(cls, response: requests.Response) -> None:
-        """Raise a `SalesforceAuthRequestError` from a failed response.
-
-        If the response did not fail, nothing is raised or returned.
- """ - if 400 <= response.status_code < 500: - error_message = f"{response.status_code} Client Error: {response.reason}: " - - elif 500 <= response.status_code < 600: - error_message = f"{response.status_code} Server Error: {response.reason}: " - else: - return - - try: - error_description = response.json()["error_description"] - except requests.exceptions.JSONDecodeError: - if response.text: - error_message += response.text - else: - error_message += "No additional error details" - else: - error_message += error_description - - raise cls(error_message, response=response) - - def salesforce_refresh_access_token(refresh_token: str) -> str: res = requests.post( "https://login.salesforce.com/services/oauth2/token", @@ -70,7 +36,9 @@ def salesforce_refresh_access_token(refresh_token: str) -> str: }, ) - SalesforceAuthRequestError.raise_from_response(res) + if res.status_code != 200: + err_message = res.json()["error_description"] + raise Exception(err_message) return res.json()["access_token"] @@ -87,7 +55,9 @@ def get_salesforce_access_token_from_code(code: str, redirect_uri: str) -> tuple }, ) - SalesforceAuthRequestError.raise_from_response(res) + if res.status_code != 200: + err_message = res.json()["error_description"] + raise Exception(err_message) payload = res.json() diff --git a/posthog/temporal/data_imports/pipelines/salesforce/test/test_auth.py b/posthog/temporal/data_imports/pipelines/salesforce/test/test_auth.py deleted file mode 100644 index 55263657e93c7..0000000000000 --- a/posthog/temporal/data_imports/pipelines/salesforce/test/test_auth.py +++ /dev/null @@ -1,94 +0,0 @@ -import unittest.mock -import requests -import pytest -import json - -from posthog.temporal.data_imports.pipelines.salesforce import auth - - -def test_salesforce_refresh_access_token_raises_on_client_failure(): - """Test whether an exception is raised when failing with a client error.""" - status_code = 400 - error_description = "Bad client!" - - response = requests.Response() - response.status_code = status_code - response._content = json.dumps({"error_description": error_description}).encode("utf-8") - - with ( - unittest.mock.patch( - "posthog.temporal.data_imports.pipelines.salesforce.auth.requests.post", return_value=response - ), - pytest.raises(auth.SalesforceAuthRequestError) as exc, - ): - _ = auth.salesforce_refresh_access_token("something") - - assert exc.value.response == response - assert "Client Error" in str(exc.value) - assert error_description in str(exc.value) - - -def test_salesforce_refresh_access_token_raises_on_server_failure(): - """Test whether an exception is raised when failing with a server error.""" - status_code = 500 - response_body = "something went terribly wrong" - - response = requests.Response() - response.status_code = status_code - response._content = response_body.encode("utf-8") - - with ( - unittest.mock.patch( - "posthog.temporal.data_imports.pipelines.salesforce.auth.requests.post", return_value=response - ), - pytest.raises(auth.SalesforceAuthRequestError) as exc, - ): - _ = auth.salesforce_refresh_access_token("something") - - assert exc.value.response == response - assert "Server Error" in str(exc.value) - assert response_body in str(exc.value) - - -def test_get_salesforce_access_token_from_code_raises_on_client_failure(): - """Test whether an exception is raised when failing with a client error.""" - status_code = 400 - error_description = "Bad client!" 
-
-    response = requests.Response()
-    response.status_code = status_code
-    response._content = json.dumps({"error_description": error_description}).encode("utf-8")
-
-    with (
-        unittest.mock.patch(
-            "posthog.temporal.data_imports.pipelines.salesforce.auth.requests.post", return_value=response
-        ),
-        pytest.raises(auth.SalesforceAuthRequestError) as exc,
-    ):
-        _ = auth.get_salesforce_access_token_from_code("something", "something")
-
-    assert exc.value.response == response
-    assert "Client Error" in str(exc.value)
-    assert error_description in str(exc.value)
-
-
-def test_get_salesforce_access_token_from_code_raises_on_server_failure():
-    """Test whether an exception is raised when failing with a server error."""
-    status_code = 500
-    response_body = "something went terribly wrong"
-
-    response = requests.Response()
-    response.status_code = status_code
-    response._content = response_body.encode("utf-8")
-
-    with (
-        unittest.mock.patch(
-            "posthog.temporal.data_imports.pipelines.salesforce.auth.requests.post", return_value=response
-        ),
-        pytest.raises(auth.SalesforceAuthRequestError) as exc,
-    ):
-        _ = auth.get_salesforce_access_token_from_code("something", "something")
-
-    assert exc.value.response == response
-    assert "Server Error" in str(exc.value)
-    assert response_body in str(exc.value)
diff --git a/requirements-dev.in b/requirements-dev.in
index 56f40f8051b9b..20e5f1a31644b 100644
--- a/requirements-dev.in
+++ b/requirements-dev.in
@@ -37,7 +37,7 @@ types-python-dateutil>=2.8.3
 types-pytz==2023.3
 types-redis==4.3.20
 types-retry==0.9.9.4
-types-requests==2.31.0.6 # >= 2.31.0.7 versions require urllib3>=2, which is incompatible with our dependencies
+types-requests==2.26.1
 types-tzlocal~=5.1.0.1
 parameterized==0.9.0
 pyarrow==18.1.0
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 6c1eaa84c834a..d61cd881eb4bd 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -177,10 +177,6 @@ googleapis-common-protos==1.60.0
     # via
     #   -c requirements.txt
     #   opentelemetry-exporter-otlp-proto-grpc
-greenlet==3.1.1
-    # via
-    #   -c requirements.txt
-    #   sqlalchemy
 grpcio==1.63.2
     # via
     #   -c requirements.txt
@@ -690,7 +686,7 @@ types-pyyaml==6.0.1
     # via
     #   responses
 types-redis==4.3.20
     # via -r requirements-dev.in
-types-requests==2.31.0.6
+types-requests==2.26.1
     # via
     #   -r requirements-dev.in
     #   djangorestframework-stubs
@@ -702,8 +698,6 @@ types-toml==0.10.8.20240310
     # via inline-snapshot
 types-tzlocal==5.1.0.1
     # via -r requirements-dev.in
-types-urllib3==1.26.25.14
-    # via types-requests
 typing-extensions==4.12.2
     # via
     #   -c requirements.txt
diff --git a/requirements.txt b/requirements.txt
index 5939af842606e..c8ddb79036113 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -332,8 +332,6 @@ graphql-core==3.2.5
     #   graphql-relay
 graphql-relay==3.2.0
     # via graphene
-greenlet==3.1.1
-    # via sqlalchemy
 grpcio==1.63.2
     # via
     #   -r requirements.in
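
As background for the _get_max_decimal_type hunk in posthog/temporal/data_imports/pipelines/pipeline/utils.py above: the computation rests on decimal.Decimal.as_tuple(), which returns (sign, digits, exponent), and leading fractional zeroes are excluded from digits. A small illustrative snippet, separate from the diff, showing what the retained logic computes:

    import decimal

    # as_tuple() examples, with precision/scale derived the way the kept code does.
    for text in ("1", "1.001112", "0.001112", "1.0100000"):
        _, digits, exponent = decimal.Decimal(text).as_tuple()
        precision = len(digits)
        scale = -exponent if exponent < 0 else 0
        # "0.001112" yields digits=(1, 1, 1, 2) and exponent=-6, i.e.
        # precision=4 with scale=6, because the leading zeroes are not digits.
        print(f"{text}: digits={digits} exponent={exponent} -> precision={precision}, scale={scale}")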