diff --git a/src/argilla/server/apis/v0/handlers/users.py b/src/argilla/server/apis/v0/handlers/users.py index 2c7db9af9d..e27585c6ee 100644 --- a/src/argilla/server/apis/v0/handlers/users.py +++ b/src/argilla/server/apis/v0/handlers/users.py @@ -19,7 +19,7 @@ from fastapi import APIRouter, Depends, HTTPException, Request, Security, status from sqlalchemy.ext.asyncio import AsyncSession -from argilla.server import models +from argilla.server import models, telemetry from argilla.server.contexts import accounts from argilla.server.database import get_async_db from argilla.server.errors import EntityAlreadyExistsError, EntityNotFoundError @@ -27,7 +27,6 @@ from argilla.server.pydantic_v1 import parse_obj_as from argilla.server.security import auth from argilla.server.security.model import User, UserCreate -from argilla.utils import telemetry router = APIRouter(tags=["users"]) diff --git a/src/argilla/server/apis/v1/handlers/datasets/datasets.py b/src/argilla/server/apis/v1/handlers/datasets/datasets.py index c76640dd55..09eabc8ee9 100644 --- a/src/argilla/server/apis/v1/handlers/datasets/datasets.py +++ b/src/argilla/server/apis/v1/handlers/datasets/datasets.py @@ -48,7 +48,7 @@ get_search_engine, ) from argilla.server.security import auth -from argilla.utils.telemetry import TelemetryClient, get_telemetry_client +from argilla.server.telemetry import TelemetryClient, get_telemetry_client CREATE_DATASET_VECTOR_SETTINGS_MAX_COUNT = 5 diff --git a/src/argilla/server/apis/v1/handlers/datasets/records.py b/src/argilla/server/apis/v1/handlers/datasets/records.py index d3eed8d33f..3a3d318515 100644 --- a/src/argilla/server/apis/v1/handlers/datasets/records.py +++ b/src/argilla/server/apis/v1/handlers/datasets/records.py @@ -96,8 +96,8 @@ TermsFilter as SearchEngineTermsFilter, ) from argilla.server.security import auth +from argilla.server.telemetry import TelemetryClient, get_telemetry_client from argilla.server.utils import parse_query_param, parse_uuids -from argilla.utils.telemetry import TelemetryClient, get_telemetry_client LIST_DATASET_RECORDS_LIMIT_DEFAULT = 50 LIST_DATASET_RECORDS_LIMIT_LE = 1000 @@ -632,7 +632,6 @@ async def search_dataset_records( *, db: AsyncSession = Depends(get_async_db), search_engine: SearchEngine = Depends(get_search_engine), - telemetry_client: TelemetryClient = Depends(get_telemetry_client), dataset_id: UUID, body: SearchRecordsQuery, metadata: MetadataQueryParams = Depends(), diff --git a/src/argilla/server/errors/api_errors.py b/src/argilla/server/errors/api_errors.py index c967972015..52aff076ee 100644 --- a/src/argilla/server/errors/api_errors.py +++ b/src/argilla/server/errors/api_errors.py @@ -17,6 +17,7 @@ from fastapi import HTTPException, Request from fastapi.exception_handlers import http_exception_handler +from argilla.server import telemetry from argilla.server.errors.adapter import exception_to_argilla_error from argilla.server.errors.base_errors import ( EntityAlreadyExistsError, @@ -25,7 +26,6 @@ ServerError, ) from argilla.server.pydantic_v1 import BaseModel -from argilla.utils import telemetry class ErrorDetail(BaseModel): diff --git a/src/argilla/server/services/storage/service.py b/src/argilla/server/services/storage/service.py index ebd29709d2..4f08e4f86e 100644 --- a/src/argilla/server/services/storage/service.py +++ b/src/argilla/server/services/storage/service.py @@ -17,6 +17,7 @@ from fastapi import Depends +from argilla.server import telemetry from argilla.server.commons.config import TasksFactory from argilla.server.commons.models import TaskStatus from argilla.server.daos.backend.base import WrongLogDataError @@ -27,7 +28,6 @@ from argilla.server.services.datasets import ServiceDataset from argilla.server.services.search.model import ServiceBaseRecordsQuery from argilla.server.services.tasks.commons import ServiceRecord -from argilla.utils import telemetry @dataclasses.dataclass diff --git a/src/argilla/server/telemetry.py b/src/argilla/server/telemetry.py new file mode 100644 index 0000000000..92a44432bd --- /dev/null +++ b/src/argilla/server/telemetry.py @@ -0,0 +1,106 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import dataclasses +import logging +import platform +import uuid +from typing import Any, Dict, Optional + +from fastapi import Request + +from argilla.server.commons.models import TaskType +from argilla.server.settings import settings + +try: + from analytics import Client # This import works only for version 2.2.0 +except (ImportError, ModuleNotFoundError): + # TODO: show some warning info + settings.enable_telemetry = False + Client = None + + +_LOGGER = logging.getLogger(__name__) + + +@dataclasses.dataclass +class TelemetryClient: + enable_telemetry: dataclasses.InitVar[bool] = settings.enable_telemetry + disable_send: dataclasses.InitVar[bool] = False + api_key: dataclasses.InitVar[str] = settings.telemetry_key + host: dataclasses.InitVar[str] = "https://api.segment.io" + + _server_id: Optional[uuid.UUID] = dataclasses.field(init=False, default=None) + + @property + def server_id(self) -> uuid.UUID: + return self._server_id + + def __post_init__(self, enable_telemetry: bool, disable_send: bool, api_key: str, host: str): + from argilla import __version__ + + self.client = None + if enable_telemetry: + try: + self.client = Client(write_key=api_key, gzip=True, host=host, send=not disable_send, max_retries=10) + except Exception as err: + _LOGGER.warning(f"Cannot initialize telemetry. Error: {err}. Disabling...") + + self._server_id = uuid.UUID(int=uuid.getnode()) + self._system_info = { + "system": platform.system(), + "machine": platform.machine(), + "platform": platform.platform(), + "python_version": platform.python_version(), + "sys_version": platform.version(), + "version": __version__, + } + + def track_data(self, action: str, data: Dict[str, Any], include_system_info: bool = True): + if not self.client: + return + + event_data = data.copy() + self.client.track( + user_id=str(self._server_id), + event=action, + properties=event_data, + context=self._system_info if include_system_info else {}, + ) + + +_CLIENT = TelemetryClient() + + +def _process_request_info(request: Request): + return {header: request.headers.get(header) for header in ["user-agent", "accept-language"]} + + +async def track_bulk(task: TaskType, records: int): + _CLIENT.track_data(action="LogRecordsRequested", data={"task": task, "records": records}) + + +async def track_login(request: Request, username: str): + _CLIENT.track_data( + action="UserInfoRequested", + data={ + "is_default_user": username == "argilla", + "user_hash": str(uuid.uuid5(namespace=_CLIENT.server_id, name=username)), + **_process_request_info(request), + }, + ) + + +def get_telemetry_client() -> TelemetryClient: + return _CLIENT diff --git a/src/argilla/utils/telemetry.py b/src/argilla/utils/telemetry.py index 7e7a228ace..bf27a9fc58 100644 --- a/src/argilla/utils/telemetry.py +++ b/src/argilla/utils/telemetry.py @@ -20,12 +20,6 @@ from argilla.pydantic_v1 import BaseSettings -if TYPE_CHECKING: - from fastapi import Request - - from argilla.server.commons.models import TaskType - - _DEFAULT_TELEMETRY_KEY = "C6FkcaoCbt78rACAgvyBxGBcMB3dM3nn" @@ -64,11 +58,11 @@ class TelemetryClient: api_key: dataclasses.InitVar[str] = telemetry_settings.telemetry_key host: dataclasses.InitVar[str] = "https://api.segment.io" - _server_id: Optional[uuid.UUID] = dataclasses.field(init=False, default=None) + _machine_id: Optional[uuid.UUID] = dataclasses.field(init=False, default=None) @property - def server_id(self) -> uuid.UUID: - return self._server_id + def machine_id(self) -> uuid.UUID: + return self._machine_id def __post_init__(self, enable_telemetry: bool, disable_send: bool, api_key: str, host: str): from argilla import __version__ @@ -80,7 +74,7 @@ def __post_init__(self, enable_telemetry: bool, disable_send: bool, api_key: str except Exception as err: _LOGGER.warning(f"Cannot initialize telemetry. Error: {err}. Disabling...") - self._server_id = uuid.UUID(int=uuid.getnode()) + self._machine_id = uuid.UUID(int=uuid.getnode()) self._system_info = { "system": platform.system(), "machine": platform.machine(), @@ -96,7 +90,7 @@ def track_data(self, action: str, data: Dict[str, Any], include_system_info: boo event_data = data.copy() self.client.track( - user_id=str(self._server_id), + user_id=str(self._machine_id), event=action, properties=event_data, context=self._system_info if include_system_info else {}, @@ -106,25 +100,6 @@ def track_data(self, action: str, data: Dict[str, Any], include_system_info: boo _CLIENT = TelemetryClient() -def _process_request_info(request: "Request"): - return {header: request.headers.get(header) for header in ["user-agent", "accept-language"]} - - -async def track_bulk(task: "TaskType", records: int): - _CLIENT.track_data(action="LogRecordsRequested", data={"task": task, "records": records}) - - -async def track_login(request: "Request", username: str): - _CLIENT.track_data( - action="UserInfoRequested", - data={ - "is_default_user": username == "argilla", - "user_hash": str(uuid.uuid5(namespace=_CLIENT.server_id, name=username)), - **_process_request_info(request), - }, - ) - - def get_current_filename() -> Optional[str]: """Returns the filename of the current file. diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 03fdf7f463..82e391ed4d 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -14,6 +14,7 @@ import asyncio import contextlib import tempfile +import uuid from typing import TYPE_CHECKING, AsyncGenerator, Dict, Generator import httpx @@ -28,12 +29,12 @@ from argilla.client.sdk.users import api as users_api from argilla.client.singleton import ArgillaSingleton from argilla.datasets import configure_dataset +from argilla.server import telemetry as server_telemetry from argilla.server.cli.database.migrate import migrate_db from argilla.server.database import get_async_db from argilla.server.models import User, UserRole, Workspace from argilla.server.settings import settings -from argilla.utils import telemetry -from argilla.utils.telemetry import TelemetryClient +from argilla.utils import telemetry as client_telemetry from fastapi.testclient import TestClient from sqlalchemy import create_engine from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine @@ -199,15 +200,28 @@ def argilla_auth_header(argilla_user: User) -> Dict[str, str]: @pytest.fixture(autouse=True) -def test_telemetry(mocker: "MockerFixture") -> "MagicMock": - telemetry._CLIENT = TelemetryClient(disable_send=True) +def server_telemetry_client(mocker: "MockerFixture") -> "MagicMock": + mock_telemetry = mocker.Mock(server_telemetry.TelemetryClient) + mock_telemetry.server_id = uuid.uuid4() - return mocker.spy(telemetry._CLIENT, "track_data") + server_telemetry._CLIENT = mock_telemetry + return server_telemetry._CLIENT + + +@pytest.fixture(autouse=True) +def client_telemetry_client(mocker: "MockerFixture") -> "MagicMock": + mock_telemetry = mocker.Mock(client_telemetry.TelemetryClient) + mock_telemetry.machine_id = uuid.uuid4() + + client_telemetry._CLIENT = mock_telemetry + return client_telemetry._CLIENT @pytest.mark.parametrize("client", [True], indirect=True) @pytest.fixture(autouse=True) -def using_test_client_from_argilla_python_client(monkeypatch, test_telemetry: "MagicMock", client: "TestClient"): +def using_test_client_from_argilla_python_client( + monkeypatch, server_telemetry_client, client_telemetry_client, client: "TestClient" +): real_whoami = users_api.whoami def whoami_mocked(*args, **kwargs): diff --git a/tests/unit/server/api/v0/test_text_classification.py b/tests/unit/server/api/v0/test_text_classification.py index 2694ba6bd4..63f0be281d 100644 --- a/tests/unit/server/api/v0/test_text_classification.py +++ b/tests/unit/server/api/v0/test_text_classification.py @@ -224,7 +224,7 @@ async def test_create_records_for_text_classification(async_client: "AsyncClient "words": {"data": 1}, } - test_telemetry.assert_called() + test_telemetry.track_data.assert_called() @pytest.mark.asyncio diff --git a/tests/unit/server/api/v1/test_datasets.py b/tests/unit/server/api/v1/test_datasets.py index e6afb432bc..c836cec73d 100644 --- a/tests/unit/server/api/v1/test_datasets.py +++ b/tests/unit/server/api/v1/test_datasets.py @@ -2212,7 +2212,7 @@ async def test_create_dataset_records( records = (await db.execute(select(Record))).scalars().all() mock_search_engine.index_records.assert_called_once_with(dataset, records) - test_telemetry.assert_called_once_with( + test_telemetry.track_data.assert_called_once_with( action="DatasetRecordsCreated", data={"records": len(records_json["items"])} ) @@ -2927,7 +2927,7 @@ async def test_create_dataset_records_as_admin( records = (await db.execute(select(Record))).scalars().all() mock_search_engine.index_records.assert_called_once_with(dataset, records) - test_telemetry.assert_called_once_with( + test_telemetry.track_data.assert_called_once_with( action="DatasetRecordsCreated", data={"records": len(records_json["items"])} ) @@ -4929,7 +4929,7 @@ async def test_publish_dataset( response_body = response.json() assert response_body["status"] == "ready" - test_telemetry.assert_called_once_with(action="PublishedDataset", data={"questions": ["rating"]}) + test_telemetry.track_data.assert_called_once_with(action="PublishedDataset", data={"questions": ["rating"]}) mock_search_engine.create_index.assert_called_once_with(dataset) async def test_publish_dataset_with_error_on_index_creation( diff --git a/tests/unit/server/commons/test_telemetry.py b/tests/unit/server/commons/test_telemetry.py index 5d41cfb2e2..97c0587975 100644 --- a/tests/unit/server/commons/test_telemetry.py +++ b/tests/unit/server/commons/test_telemetry.py @@ -16,9 +16,9 @@ from unittest.mock import MagicMock import pytest +from argilla.server import telemetry from argilla.server.commons.models import TaskType -from argilla.utils import telemetry -from argilla.utils.telemetry import TelemetryClient, get_telemetry_client +from argilla.server.telemetry import TelemetryClient, get_telemetry_client from fastapi import Request mock_request = Request(scope={"type": "http", "headers": {}}) @@ -41,7 +41,7 @@ async def test_track_login(test_telemetry: MagicMock): "user-agent": None, "user_hash": str(uuid.uuid5(current_server_id, name="argilla")), } - test_telemetry.assert_called_once_with("UserInfoRequested", expected_event_data) + test_telemetry.track_data.assert_called_once_with(action="UserInfoRequested", data=expected_event_data) @pytest.mark.asyncio @@ -49,4 +49,6 @@ async def test_track_bulk(test_telemetry): task, records = TaskType.token_classification, 100 await telemetry.track_bulk(task=task, records=records) - test_telemetry.assert_called_once_with("LogRecordsRequested", {"task": task, "records": records}) + test_telemetry.track_data.assert_called_once_with( + action="LogRecordsRequested", data={"task": task, "records": records} + ) diff --git a/tests/unit/server/conftest.py b/tests/unit/server/conftest.py index 8537008721..41132c4599 100644 --- a/tests/unit/server/conftest.py +++ b/tests/unit/server/conftest.py @@ -13,10 +13,12 @@ # limitations under the License. import contextlib +import uuid from typing import TYPE_CHECKING, Dict, Generator import pytest import pytest_asyncio +from argilla.server import telemetry from argilla.server.constants import API_KEY_HEADER_NAME, DEFAULT_API_KEY from argilla.server.daos.backend import GenericElasticEngineBackend from argilla.server.daos.datasets import DatasetsDAO @@ -24,10 +26,8 @@ from argilla.server.database import get_async_db from argilla.server.models import User, UserRole, Workspace from argilla.server.search_engine import SearchEngine, get_search_engine -from argilla.server.services.datasets import DatasetsService from argilla.server.settings import settings -from argilla.utils import telemetry -from argilla.utils.telemetry import TelemetryClient +from argilla.server.telemetry import TelemetryClient from httpx import AsyncClient from opensearchpy import OpenSearch @@ -105,9 +105,11 @@ async def override_get_search_engine(): @pytest.fixture(autouse=True) def test_telemetry(mocker: "MockerFixture") -> "MagicMock": - telemetry._CLIENT = TelemetryClient(disable_send=True) + mock_telemetry = mocker.Mock(TelemetryClient) + mock_telemetry.server_id = uuid.uuid4() - return mocker.spy(telemetry._CLIENT, "track_data") + telemetry._CLIENT = mock_telemetry + return telemetry._CLIENT @pytest.fixture(scope="session") diff --git a/tests/unit/server/errors/test_api_errors.py b/tests/unit/server/errors/test_api_errors.py index a5cf91ffe0..a709bdb62c 100644 --- a/tests/unit/server/errors/test_api_errors.py +++ b/tests/unit/server/errors/test_api_errors.py @@ -72,4 +72,4 @@ class TestAPIErrorHandler: async def test_track_error(self, test_telemetry, error, expected_event): await APIErrorHandler.track_error(error, request=mock_request) - test_telemetry.assert_called_once_with("ServerErrorFound", expected_event) + test_telemetry.track_data.assert_called_once_with(action="ServerErrorFound", data=expected_event)