Skip to content

Commit

Permalink
tests: Adapt tests to changes
Browse files Browse the repository at this point in the history
  • Loading branch information
frascuchon committed Dec 22, 2023
1 parent dce3cfb commit be6612a
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from datetime import datetime
from unittest.mock import call
from uuid import UUID, uuid4

import pytest
from argilla._constants import API_KEY_HEADER_NAME
from argilla.server.enums import ResponseStatus
from argilla.server.models import Response, User
from argilla.server.search_engine import SearchEngine
from httpx import AsyncClient
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession

from argilla._constants import API_KEY_HEADER_NAME
from argilla.server.enums import ResponseStatus
from argilla.server.models import Response, User
from argilla.server.search_engine import SearchEngine
from argilla.server.use_cases.responses.upsert_responses_in_bulk import UpsertResponsesInBulkUseCase
from tests.factories import (
AnnotatorFactory,
DatasetFactory,
Expand Down Expand Up @@ -386,3 +387,58 @@ async def test_too_many_responses(self, async_client: AsyncClient, owner_auth_he
)

assert resp.status_code == 422

@pytest.mark.skipif(reason="Profiling is not active", condition=not bool(os.getenv("TEST_PROFILING", None)))
async def test_create_responses_in_bulk_profiling(self, db: "AsyncSession", elasticsearch_config: dict):
from pyinstrument import Profiler
from tests.factories import OwnerFactory, TextFieldFactory
from argilla.server.search_engine import ElasticSearchEngine
from argilla.server.schemas.v1.responses import DraftResponseUpsert

async def refresh_dataset(dataset):
await dataset.awaitable_attrs.fields
await dataset.awaitable_attrs.questions
await dataset.awaitable_attrs.metadata_properties
await dataset.awaitable_attrs.vectors_settings

async def refresh_records(records):
for record in records:
await record.awaitable_attrs.suggestions
await record.awaitable_attrs.responses
await record.awaitable_attrs.vectors

dataset = await DatasetFactory.create()
user = await OwnerFactory.create()

await RatingQuestionFactory.create(name="prompt-quality", required=True, dataset=dataset)
await TextFieldFactory.create(name="text", required=True, dataset=dataset)
await TextFieldFactory.create(name="sentiment", required=True, dataset=dataset)

records = await RecordFactory.create_batch(dataset=dataset, size=500)

engine = ElasticSearchEngine(config=elasticsearch_config, number_of_replicas=0, number_of_shards=1)

await refresh_dataset(dataset)
await refresh_records(records)

await engine.create_index(dataset)
await engine.index_records(dataset, records)

profiler = Profiler()

responses = [
DraftResponseUpsert.parse_obj(
{
"values": {"prompt-quality": {"value": 10}},
"record_id": record.id,
"status": "draft",
}
)
for record in records
]
use_case = UpsertResponsesInBulkUseCase(db, engine)
with profiler:
bulk_items = await use_case.execute(responses, user)
await use_case.execute([bulk_item.item for bulk_item in bulk_items], user)

profiler.open_in_browser()
8 changes: 0 additions & 8 deletions tests/unit/server/search_engine/test_commons.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,14 +284,6 @@ class TestBaseElasticAndOpenSearchEngine:
"""

# TODO: Use other public method to detect the error
async def test_get_index_or_raise(self, search_engine: BaseElasticAndOpenSearchEngine):
dataset = await DatasetFactory.create()
with pytest.raises(
ValueError, match=f"Cannot access to index for dataset {dataset.id}: the specified index does not exist"
):
await search_engine._get_index_or_raise(dataset)

async def test_create_index_for_dataset(
self, search_engine: BaseElasticAndOpenSearchEngine, db: "AsyncSession", opensearch: OpenSearch
):
Expand Down

0 comments on commit be6612a

Please sign in to comment.