Skip to content

Commit

Permalink
refactor: task item review the telemetry client for both client and s…
Browse files Browse the repository at this point in the history
…erver side (#4480)

<!-- Thanks for your contribution! As part of our Community Growers
initiative 🌱, we're donating Justdiggit bunds in your name to reforest
sub-Saharan Africa. To claim your Community Growers certificate, please
contact David Berenstein in our Slack community or fill in this form
https://tally.so/r/n9XrxK once your PR has been merged. -->

# Description

This PR separates telemetry clients for server and client modules.

Closes #4479

**Type of change**

(Please delete options that are not relevant. Remember to title the PR
according to the type of change)

- [ ] New feature (non-breaking change which adds functionality)
- [X] Refactor (change restructuring the codebase without changing
functionality)
- [ ] Improvement (change adding some improvement to an existing
functionality)

**How Has This Been Tested**

(Please describe the tests that you ran to verify your changes. And
ideally, reference `tests`)

**Checklist**

- [ ] I added relevant documentation
- [X] I followed the style guidelines of this project
- [X] I did a self-review of my code
- [ ] I made corresponding changes to the documentation
- [X] My changes generate no new warnings
- [X] I have added tests that prove my fix is effective or that my
feature works
- [ ] I filled out [the contributor form](https://tally.so/r/n9XrxK)
(see text above)
- [] I have added relevant notes to the `CHANGELOG.md` file (See
https://keepachangelog.com/)

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
frascuchon and pre-commit-ci[bot] authored Jan 12, 2024
1 parent 0b91c8f commit 4290369
Show file tree
Hide file tree
Showing 13 changed files with 154 additions and 57 deletions.
3 changes: 1 addition & 2 deletions src/argilla/server/apis/v0/handlers/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,14 @@
from fastapi import APIRouter, Depends, HTTPException, Request, Security, status
from sqlalchemy.ext.asyncio import AsyncSession

from argilla.server import models
from argilla.server import models, telemetry
from argilla.server.contexts import accounts
from argilla.server.database import get_async_db
from argilla.server.errors import EntityAlreadyExistsError, EntityNotFoundError
from argilla.server.policies import UserPolicy, authorize
from argilla.server.pydantic_v1 import parse_obj_as
from argilla.server.security import auth
from argilla.server.security.model import User, UserCreate
from argilla.utils import telemetry

router = APIRouter(tags=["users"])

Expand Down
2 changes: 1 addition & 1 deletion src/argilla/server/apis/v1/handlers/datasets/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
get_search_engine,
)
from argilla.server.security import auth
from argilla.utils.telemetry import TelemetryClient, get_telemetry_client
from argilla.server.telemetry import TelemetryClient, get_telemetry_client

CREATE_DATASET_VECTOR_SETTINGS_MAX_COUNT = 5

Expand Down
3 changes: 1 addition & 2 deletions src/argilla/server/apis/v1/handlers/datasets/records.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,8 @@
TermsFilter as SearchEngineTermsFilter,
)
from argilla.server.security import auth
from argilla.server.telemetry import TelemetryClient, get_telemetry_client
from argilla.server.utils import parse_query_param, parse_uuids
from argilla.utils.telemetry import TelemetryClient, get_telemetry_client

LIST_DATASET_RECORDS_LIMIT_DEFAULT = 50
LIST_DATASET_RECORDS_LIMIT_LE = 1000
Expand Down Expand Up @@ -632,7 +632,6 @@ async def search_dataset_records(
*,
db: AsyncSession = Depends(get_async_db),
search_engine: SearchEngine = Depends(get_search_engine),
telemetry_client: TelemetryClient = Depends(get_telemetry_client),
dataset_id: UUID,
body: SearchRecordsQuery,
metadata: MetadataQueryParams = Depends(),
Expand Down
2 changes: 1 addition & 1 deletion src/argilla/server/errors/api_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from fastapi import HTTPException, Request
from fastapi.exception_handlers import http_exception_handler

from argilla.server import telemetry
from argilla.server.errors.adapter import exception_to_argilla_error
from argilla.server.errors.base_errors import (
EntityAlreadyExistsError,
Expand All @@ -25,7 +26,6 @@
ServerError,
)
from argilla.server.pydantic_v1 import BaseModel
from argilla.utils import telemetry


class ErrorDetail(BaseModel):
Expand Down
2 changes: 1 addition & 1 deletion src/argilla/server/services/storage/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from fastapi import Depends

from argilla.server import telemetry
from argilla.server.commons.config import TasksFactory
from argilla.server.commons.models import TaskStatus
from argilla.server.daos.backend.base import WrongLogDataError
Expand All @@ -27,7 +28,6 @@
from argilla.server.services.datasets import ServiceDataset
from argilla.server.services.search.model import ServiceBaseRecordsQuery
from argilla.server.services.tasks.commons import ServiceRecord
from argilla.utils import telemetry


@dataclasses.dataclass
Expand Down
106 changes: 106 additions & 0 deletions src/argilla/server/telemetry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# Copyright 2021-present, the Recognai S.L. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import dataclasses
import logging
import platform
import uuid
from typing import Any, Dict, Optional

from fastapi import Request

from argilla.server.commons.models import TaskType
from argilla.server.settings import settings

try:
from analytics import Client # This import works only for version 2.2.0
except (ImportError, ModuleNotFoundError):
# TODO: show some warning info
settings.enable_telemetry = False
Client = None


_LOGGER = logging.getLogger(__name__)


@dataclasses.dataclass
class TelemetryClient:
enable_telemetry: dataclasses.InitVar[bool] = settings.enable_telemetry
disable_send: dataclasses.InitVar[bool] = False
api_key: dataclasses.InitVar[str] = settings.telemetry_key
host: dataclasses.InitVar[str] = "https://api.segment.io"

_server_id: Optional[uuid.UUID] = dataclasses.field(init=False, default=None)

@property
def server_id(self) -> uuid.UUID:
return self._server_id

def __post_init__(self, enable_telemetry: bool, disable_send: bool, api_key: str, host: str):
from argilla import __version__

self.client = None
if enable_telemetry:
try:
self.client = Client(write_key=api_key, gzip=True, host=host, send=not disable_send, max_retries=10)
except Exception as err:
_LOGGER.warning(f"Cannot initialize telemetry. Error: {err}. Disabling...")

self._server_id = uuid.UUID(int=uuid.getnode())
self._system_info = {
"system": platform.system(),
"machine": platform.machine(),
"platform": platform.platform(),
"python_version": platform.python_version(),
"sys_version": platform.version(),
"version": __version__,
}

def track_data(self, action: str, data: Dict[str, Any], include_system_info: bool = True):
if not self.client:
return

event_data = data.copy()
self.client.track(
user_id=str(self._server_id),
event=action,
properties=event_data,
context=self._system_info if include_system_info else {},
)


_CLIENT = TelemetryClient()


def _process_request_info(request: Request):
return {header: request.headers.get(header) for header in ["user-agent", "accept-language"]}


async def track_bulk(task: TaskType, records: int):
_CLIENT.track_data(action="LogRecordsRequested", data={"task": task, "records": records})


async def track_login(request: Request, username: str):
_CLIENT.track_data(
action="UserInfoRequested",
data={
"is_default_user": username == "argilla",
"user_hash": str(uuid.uuid5(namespace=_CLIENT.server_id, name=username)),
**_process_request_info(request),
},
)


def get_telemetry_client() -> TelemetryClient:
return _CLIENT
35 changes: 5 additions & 30 deletions src/argilla/utils/telemetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,6 @@

from argilla.pydantic_v1 import BaseSettings

if TYPE_CHECKING:
from fastapi import Request

from argilla.server.commons.models import TaskType


_DEFAULT_TELEMETRY_KEY = "C6FkcaoCbt78rACAgvyBxGBcMB3dM3nn"


Expand Down Expand Up @@ -64,11 +58,11 @@ class TelemetryClient:
api_key: dataclasses.InitVar[str] = telemetry_settings.telemetry_key
host: dataclasses.InitVar[str] = "https://api.segment.io"

_server_id: Optional[uuid.UUID] = dataclasses.field(init=False, default=None)
_machine_id: Optional[uuid.UUID] = dataclasses.field(init=False, default=None)

@property
def server_id(self) -> uuid.UUID:
return self._server_id
def machine_id(self) -> uuid.UUID:
return self._machine_id

def __post_init__(self, enable_telemetry: bool, disable_send: bool, api_key: str, host: str):
from argilla import __version__
Expand All @@ -80,7 +74,7 @@ def __post_init__(self, enable_telemetry: bool, disable_send: bool, api_key: str
except Exception as err:
_LOGGER.warning(f"Cannot initialize telemetry. Error: {err}. Disabling...")

self._server_id = uuid.UUID(int=uuid.getnode())
self._machine_id = uuid.UUID(int=uuid.getnode())
self._system_info = {
"system": platform.system(),
"machine": platform.machine(),
Expand All @@ -96,7 +90,7 @@ def track_data(self, action: str, data: Dict[str, Any], include_system_info: boo

event_data = data.copy()
self.client.track(
user_id=str(self._server_id),
user_id=str(self._machine_id),
event=action,
properties=event_data,
context=self._system_info if include_system_info else {},
Expand All @@ -106,25 +100,6 @@ def track_data(self, action: str, data: Dict[str, Any], include_system_info: boo
_CLIENT = TelemetryClient()


def _process_request_info(request: "Request"):
return {header: request.headers.get(header) for header in ["user-agent", "accept-language"]}


async def track_bulk(task: "TaskType", records: int):
_CLIENT.track_data(action="LogRecordsRequested", data={"task": task, "records": records})


async def track_login(request: "Request", username: str):
_CLIENT.track_data(
action="UserInfoRequested",
data={
"is_default_user": username == "argilla",
"user_hash": str(uuid.uuid5(namespace=_CLIENT.server_id, name=username)),
**_process_request_info(request),
},
)


def get_current_filename() -> Optional[str]:
"""Returns the filename of the current file.
Expand Down
26 changes: 20 additions & 6 deletions tests/integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import asyncio
import contextlib
import tempfile
import uuid
from typing import TYPE_CHECKING, AsyncGenerator, Dict, Generator

import httpx
Expand All @@ -28,12 +29,12 @@
from argilla.client.sdk.users import api as users_api
from argilla.client.singleton import ArgillaSingleton
from argilla.datasets import configure_dataset
from argilla.server import telemetry as server_telemetry
from argilla.server.cli.database.migrate import migrate_db
from argilla.server.database import get_async_db
from argilla.server.models import User, UserRole, Workspace
from argilla.server.settings import settings
from argilla.utils import telemetry
from argilla.utils.telemetry import TelemetryClient
from argilla.utils import telemetry as client_telemetry
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
Expand Down Expand Up @@ -199,15 +200,28 @@ def argilla_auth_header(argilla_user: User) -> Dict[str, str]:


@pytest.fixture(autouse=True)
def test_telemetry(mocker: "MockerFixture") -> "MagicMock":
telemetry._CLIENT = TelemetryClient(disable_send=True)
def server_telemetry_client(mocker: "MockerFixture") -> "MagicMock":
mock_telemetry = mocker.Mock(server_telemetry.TelemetryClient)
mock_telemetry.server_id = uuid.uuid4()

return mocker.spy(telemetry._CLIENT, "track_data")
server_telemetry._CLIENT = mock_telemetry
return server_telemetry._CLIENT


@pytest.fixture(autouse=True)
def client_telemetry_client(mocker: "MockerFixture") -> "MagicMock":
mock_telemetry = mocker.Mock(client_telemetry.TelemetryClient)
mock_telemetry.machine_id = uuid.uuid4()

client_telemetry._CLIENT = mock_telemetry
return client_telemetry._CLIENT


@pytest.mark.parametrize("client", [True], indirect=True)
@pytest.fixture(autouse=True)
def using_test_client_from_argilla_python_client(monkeypatch, test_telemetry: "MagicMock", client: "TestClient"):
def using_test_client_from_argilla_python_client(
monkeypatch, server_telemetry_client, client_telemetry_client, client: "TestClient"
):
real_whoami = users_api.whoami

def whoami_mocked(*args, **kwargs):
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/server/api/v0/test_text_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ async def test_create_records_for_text_classification(async_client: "AsyncClient
"words": {"data": 1},
}

test_telemetry.assert_called()
test_telemetry.track_data.assert_called()


@pytest.mark.asyncio
Expand Down
6 changes: 3 additions & 3 deletions tests/unit/server/api/v1/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2212,7 +2212,7 @@ async def test_create_dataset_records(
records = (await db.execute(select(Record))).scalars().all()
mock_search_engine.index_records.assert_called_once_with(dataset, records)

test_telemetry.assert_called_once_with(
test_telemetry.track_data.assert_called_once_with(
action="DatasetRecordsCreated", data={"records": len(records_json["items"])}
)

Expand Down Expand Up @@ -2927,7 +2927,7 @@ async def test_create_dataset_records_as_admin(
records = (await db.execute(select(Record))).scalars().all()
mock_search_engine.index_records.assert_called_once_with(dataset, records)

test_telemetry.assert_called_once_with(
test_telemetry.track_data.assert_called_once_with(
action="DatasetRecordsCreated", data={"records": len(records_json["items"])}
)

Expand Down Expand Up @@ -4929,7 +4929,7 @@ async def test_publish_dataset(
response_body = response.json()
assert response_body["status"] == "ready"

test_telemetry.assert_called_once_with(action="PublishedDataset", data={"questions": ["rating"]})
test_telemetry.track_data.assert_called_once_with(action="PublishedDataset", data={"questions": ["rating"]})
mock_search_engine.create_index.assert_called_once_with(dataset)

async def test_publish_dataset_with_error_on_index_creation(
Expand Down
10 changes: 6 additions & 4 deletions tests/unit/server/commons/test_telemetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
from unittest.mock import MagicMock

import pytest
from argilla.server import telemetry
from argilla.server.commons.models import TaskType
from argilla.utils import telemetry
from argilla.utils.telemetry import TelemetryClient, get_telemetry_client
from argilla.server.telemetry import TelemetryClient, get_telemetry_client
from fastapi import Request

mock_request = Request(scope={"type": "http", "headers": {}})
Expand All @@ -41,12 +41,14 @@ async def test_track_login(test_telemetry: MagicMock):
"user-agent": None,
"user_hash": str(uuid.uuid5(current_server_id, name="argilla")),
}
test_telemetry.assert_called_once_with("UserInfoRequested", expected_event_data)
test_telemetry.track_data.assert_called_once_with(action="UserInfoRequested", data=expected_event_data)


@pytest.mark.asyncio
async def test_track_bulk(test_telemetry):
task, records = TaskType.token_classification, 100

await telemetry.track_bulk(task=task, records=records)
test_telemetry.assert_called_once_with("LogRecordsRequested", {"task": task, "records": records})
test_telemetry.track_data.assert_called_once_with(
action="LogRecordsRequested", data={"task": task, "records": records}
)
Loading

0 comments on commit 4290369

Please sign in to comment.