Skip to content

Commit

Permalink
refactor all tests to async (#125)
Browse files Browse the repository at this point in the history
  • Loading branch information
drf7 authored Dec 18, 2024
1 parent 55027c1 commit b82ad62
Show file tree
Hide file tree
Showing 17 changed files with 157 additions and 134 deletions.
Binary file removed nua/e2e/assets/financial-new-kb.arrow
Binary file not shown.
Binary file added nua/e2e/assets/financial-news-kb.arrow
Binary file not shown.
Binary file modified nua/e2e/assets/legal-text-kb.arrow
Binary file not shown.
3 changes: 1 addition & 2 deletions nua/e2e/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class Tokens:
"aws-us-east-2-1.nuclia.cloud",
],
)
def nua_config(request):
async def nua_config(request):
if (
os.environ.get("TEST_ENV") == "stage" and "stashify.cloud" not in request.param # noqa
): # noqa
Expand All @@ -71,6 +71,5 @@ def nua_config(request):
client_id = nuclia_auth.nua(token.nua_key)
assert client_id
nuclia_auth._config.set_default_nua(client_id)

yield request.param
reset_config_file()
117 changes: 67 additions & 50 deletions nua/e2e/regional/test_da_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@
from conftest import TOKENS
from regional.utils import define_path
from typing import Callable, AsyncGenerator
from httpx import AsyncClient
import asyncio
import aiofiles
from nucliadb_protos.writer_pb2 import BrokerMessage
from dataclasses import dataclass
import base64
from typing import Optional
import aiohttp


@dataclass
Expand All @@ -38,79 +38,89 @@ class TestInput:
validate_output: Callable[[BrokerMessage], None]


@pytest.fixture
def httpx_client() -> Callable[[str, str], AsyncGenerator[AsyncClient, None]]:
async def create_httpx_client(
@pytest.fixture(scope="session")
def aiohttp_client() -> (
Callable[[str, str], AsyncGenerator[aiohttp.ClientSession, None]]
):
async def create_aiohttp_client(
base_url: str,
nua_key: Optional[str] = None,
pat_key: Optional[str] = None,
timeout: int = 5,
) -> AsyncGenerator[AsyncClient, None]:
client = AsyncClient()
async with AsyncClient(
base_url=base_url,
headers={"X-NUCLIA-NUAKEY": f"Bearer {nua_key}"}
timeout: int = 30,
) -> AsyncGenerator[aiohttp.ClientSession, None]:
headers = (
{"X-NUCLIA-NUAKEY": f"Bearer {nua_key}"}
if nua_key
else {"Authorization": f"Bearer {pat_key}"},
timeout=timeout,
) as client:
yield client
else {"Authorization": f"Bearer {pat_key}"}
)
timeout_config = aiohttp.ClientTimeout(total=timeout)

async with aiohttp.ClientSession(
base_url=base_url,
headers=headers,
timeout=timeout_config,
) as session:
yield session

return create_httpx_client
return create_aiohttp_client


async def create_nua_key(
client: AsyncClient, account_id: str, title: str
client: aiohttp.ClientSession, account_id: str, title: str
) -> tuple[str, str]:
body = {
"title": title,
"contact": "temporal key, safe to delete",
}
resp = await client.post(f"/api/v1/account/{account_id}/nua_clients", json=body)
assert resp.status_code == 201, resp.text
nua_response = resp.json()
assert resp.status == 201, await resp.text()
nua_response = await resp.json()
return nua_response["client_id"], nua_response["token"]


async def delete_nua_key(client: AsyncClient, account_id: str, nua_client_id: str):
async def delete_nua_key(
client: aiohttp.ClientSession, account_id: str, nua_client_id: str
):
resp = await client.delete(
f"/api/v1/account/{account_id}/nua_client/{nua_client_id}"
)
assert resp.status_code == 204, resp.text
assert resp.status == 204, await resp.text()


def task_done(task_request: dict) -> bool:
return task_request["failed"] or task_request["completed"]


async def create_dataset(client: AsyncClient) -> str:
async def create_dataset(client: aiohttp.ClientSession) -> str:
dataset_body = {
"name": "e2e-test-dataset",
"filter": {"labels": []},
"type": "FIELD_STREAMING",
}
resp = await client.post("/api/v1/datasets", json=dataset_body)
assert resp.status_code == 201, resp.text
return resp.json()["id"]
assert resp.status == 201, await resp.text()
return (await resp.json())["id"]


async def delete_dataset(client: AsyncClient, dataset_id: str):
async def delete_dataset(client: aiohttp.ClientSession, dataset_id: str):
resp = await client.delete(f"/api/v1/dataset/{dataset_id}")
assert resp.status_code == 204, resp.text
assert resp.status == 204, await resp.text()


async def push_data_to_dataset(client: AsyncClient, dataset_id: str, filename: str):
async def push_data_to_dataset(
client: aiohttp.ClientSession, dataset_id: str, filename: str
):
async with aiofiles.open(define_path(filename), "rb") as f:
content = await f.read()
resp = await client.put(
f"/api/v1/dataset/{dataset_id}/partition/1",
data=content,
)
assert resp.status_code == 204, resp.text
assert resp.status == 204, await resp.text()


async def start_task(
client: AsyncClient,
client: aiohttp.ClientSession,
dataset_id: str,
task_name: str,
parameters: PARAMETERS_TYPING,
Expand All @@ -119,22 +129,22 @@ async def start_task(
f"/api/v1/dataset/{dataset_id}/task/start",
json=TaskStart(name=task_name, parameters=parameters).model_dump(),
)
assert resp.status_code == 200, resp.text
return resp.json()["id"]
assert resp.status == 200, await resp.text()
return (await resp.json())["id"]


async def stop_task(client: AsyncClient, dataset_id: str, task_id: str):
async def stop_task(client: aiohttp.ClientSession, dataset_id: str, task_id: str):
resp = await client.post(f"/api/v1/dataset/{dataset_id}/task/{task_id}/stop")
assert resp.status_code == 200, resp.text
assert resp.status == 200, await resp.text()


async def delete_task(client: AsyncClient, dataset_id: str, task_id: str):
async def delete_task(client: aiohttp.ClientSession, dataset_id: str, task_id: str):
resp = await client.delete(f"/api/v1/dataset/{dataset_id}/task/{task_id}")
assert resp.status_code == 200, resp.text
assert resp.status == 200, await resp.text()


async def wait_for_task_completion(
client: AsyncClient,
client: aiohttp.ClientSession,
dataset_id: str,
task_id: str,
max_duration: int = 2300,
Expand All @@ -150,24 +160,24 @@ async def wait_for_task_completion(
resp = await client.get(
f"/api/v1/dataset/{dataset_id}/task/{task_id}/inspect",
)
assert resp.status_code == 200, resp.text
task_request = resp.json()
assert resp.status == 200, await resp.text()
task_request = await resp.json()
if task_done(task_request):
return task_request

await asyncio.sleep(20)


async def validate_task_output(
client: AsyncClient, validation: Callable[[BrokerMessage], None]
client: aiohttp.ClientSession, validation: Callable[[BrokerMessage], None]
):
max_retries = 5
for _ in range(max_retries):
resp = await client.get(
"/api/v1/processing/pull", params={"from_cursor": 0, "limit": 1}
)
assert resp.status_code == 200, resp.text
pull_response = resp.json()
assert resp.status == 200, await resp.text()
pull_response = await resp.json()
if pull_response["payloads"]:
assert len(pull_response["payloads"]) == 1
message = BrokerMessage()
Expand Down Expand Up @@ -291,7 +301,7 @@ def validate_labeler_output_text_block(msg: BrokerMessage):

DA_TEST_INPUTS: list[TestInput] = [
TestInput(
filename="financial-new-kb.arrow",
filename="financial-news-kb.arrow",
task_name=TaskName.LABELER,
parameters=DataAugmentation(
name="e2e-test-labeler",
Expand Down Expand Up @@ -435,7 +445,7 @@ def validate_labeler_output_text_block(msg: BrokerMessage):
validate_output=validate_synthetic_questions_output,
),
TestInput(
filename="financial-new-kb.arrow",
filename="financial-news-kb.arrow",
task_name=TaskName.LABELER,
parameters=DataAugmentation(
name="e2e-test-labeler-text-block",
Expand Down Expand Up @@ -499,10 +509,12 @@ def validate_labeler_output_text_block(msg: BrokerMessage):


@pytest.fixture
async def nua_key(nua_config: str, httpx_client: AsyncGenerator[AsyncClient, None]):
async def tmp_nua_key(
nua_config: str, aiohttp_client: AsyncGenerator[aiohttp.ClientSession, None]
) -> AsyncGenerator[str, None]:
account_id = TOKENS[nua_config].account_id
pat_client_generator = httpx_client(
base_url=f"https://{nua_config}", pat_key=TOKENS[nua_config].pat_key, timeout=5
pat_client_generator = aiohttp_client(
base_url=f"https://{nua_config}", pat_key=TOKENS[nua_config].pat_key, timeout=30
)
pat_client = await anext(pat_client_generator)
nua_client_id, nua_key = await create_nua_key(
Expand All @@ -524,20 +536,20 @@ async def nua_key(nua_config: str, httpx_client: AsyncGenerator[AsyncClient, Non
)
async def test_da_agent_tasks(
nua_config: str,
httpx_client: AsyncGenerator[AsyncClient, None],
nua_key: str,
aiohttp_client: AsyncGenerator[aiohttp.ClientSession, None],
tmp_nua_key: str,
test_input: TestInput,
):
dataset_id = None
task_id = None
start_time = asyncio.get_event_loop().time()
try:
nua_client_generator = httpx_client(
base_url=f"https://{nua_config}", nua_key=nua_key, timeout=30
nua_client_generator = aiohttp_client(
base_url=f"https://{nua_config}", nua_key=tmp_nua_key, timeout=30
)
nua_client = await anext(nua_client_generator)

dataset_id = await create_dataset(client=nua_client)
print(f"{test_input.parameters.name} dataset_id: {dataset_id}")
await push_data_to_dataset(
client=nua_client, dataset_id=dataset_id, filename=test_input.filename
)
Expand Down Expand Up @@ -569,3 +581,8 @@ async def test_da_agent_tasks(
client=nua_client, dataset_id=dataset_id, task_id=task_id
)
await delete_dataset(client=nua_client, dataset_id=dataset_id)
end_time = asyncio.get_event_loop().time()
elapsed_time = end_time - start_time
print(
f"Test {test_input.parameters.name} completed in {elapsed_time:.2f} seconds."
)
10 changes: 6 additions & 4 deletions nua/e2e/regional/test_llm_chat.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
from nuclia.lib.nua_responses import ChatModel, UserPrompt
from nuclia.sdk.predict import NucliaPredict
from nuclia.sdk.predict import AsyncNucliaPredict

from regional.models import ALL_LLMS
import pytest


def test_llm_chat(nua_config):
@pytest.mark.asyncio_cooperative
async def test_llm_chat(nua_config):
# Validate that other features such as
# * citations
# * custom prompts
# * reranking (TODO once supported by the SDK)
np = NucliaPredict()
np = AsyncNucliaPredict()
chat_model = ChatModel(
question="Which is the CEO of Nuclia?",
retrieval=False,
Expand All @@ -24,7 +26,7 @@ def test_llm_chat(nua_config):
},
citations=True,
)
generated = np.generate(
generated = await np.generate(
text=chat_model,
model=ALL_LLMS[0],
)
Expand Down
12 changes: 0 additions & 12 deletions nua/e2e/regional/test_llm_citation.py

This file was deleted.

15 changes: 8 additions & 7 deletions nua/e2e/regional/test_llm_config.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,25 @@
import pytest
from nuclia.exceptions import NuaAPIException
from nuclia.lib.nua_responses import LearningConfigurationCreation
from nuclia.sdk.predict import NucliaPredict
from nuclia.sdk.predict import AsyncNucliaPredict


def test_llm_config_nua(nua_config):
np = NucliaPredict()
@pytest.mark.asyncio_cooperative
async def test_llm_config_nua(nua_config):
np = AsyncNucliaPredict()

try:
np.del_config("kbid")
await np.del_config("kbid")
except NuaAPIException:
pass

with pytest.raises(NuaAPIException):
config = np.config("kbid")
config = await np.config("kbid")

lcc = LearningConfigurationCreation()
np.set_config("kbid", lcc)
await np.set_config("kbid", lcc)

config = np.config("kbid")
config = await np.config("kbid")

assert config.resource_labelers_models is None
assert config.ner_model == "multilingual"
Expand Down
9 changes: 5 additions & 4 deletions nua/e2e/regional/test_llm_generate.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import pytest
from nuclia.sdk.predict import NucliaPredict
from nuclia.sdk.predict import AsyncNucliaPredict

from regional.models import ALL_LLMS


@pytest.mark.asyncio_cooperative
@pytest.mark.parametrize("model", ALL_LLMS)
def test_llm_generate(nua_config, model):
np = NucliaPredict()
generated = np.generate("Which is the capital of Catalonia?", model=model)
async def test_llm_generate(nua_config, model):
np = AsyncNucliaPredict()
generated = await np.generate("Which is the capital of Catalonia?", model=model)

Check failure on line 11 in nua/e2e/regional/test_llm_generate.py

View workflow job for this annotation

GitHub Actions / JUnit Test Report

test_llm_generate.test_llm_generate[europe-1.stashify.cloud-chatgpt-vision]

pydantic_core._pydantic_core.ValidationError: 1 validation error for GenerativeChunk chunk Field required [type=missing, input_value={'detail': 'Unknown LLM e... Unknown API exception'}, input_type=dict] For further information visit https://errors.pydantic.dev/2.10/v/missing
Raw output
nua_config = 'europe-1.stashify.cloud', model = 'chatgpt-vision'

    @pytest.mark.asyncio_cooperative
    @pytest.mark.parametrize("model", ALL_LLMS)
    async def test_llm_generate(nua_config, model):
        np = AsyncNucliaPredict()
>       generated = await np.generate("Which is the capital of Catalonia?", model=model)

nua/e2e/regional/test_llm_generate.py:11: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
/opt/hostedtoolcache/Python/3.12.8/x64/lib/python3.12/site-packages/nuclia/decorators.py:136: in async_wrapper_checkout_nua
    return await func(*args, **kwargs)
/opt/hostedtoolcache/Python/3.12.8/x64/lib/python3.12/site-packages/nuclia/sdk/predict.py:215: in generate
    return await nc.generate(body, model)
/opt/hostedtoolcache/Python/3.12.8/x64/lib/python3.12/site-packages/nuclia/lib/nua.py:523: in generate
    async for chunk in self._stream(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <nuclia.lib.nua.AsyncNuaClient object at 0x7fed082e0ec0>, method = 'POST'
url = 'https://europe-1.stashify.cloud/api/v1/predict/chat?model=chatgpt-vision'
payload = {'chat_history': [], 'citations': False, 'context': [], 'generative_model': None, ...}
timeout = 300

    async def _stream(
        self,
        method: str,
        url: str,
        payload: Optional[dict[Any, Any]] = None,
        timeout: int = 60,
    ) -> AsyncIterator[GenerativeChunk]:
        async with self.stream_client.stream(
            method,
            url,
            json=payload,
            timeout=timeout,
        ) as response:
            async for json_body in response.aiter_lines():
>               yield GenerativeChunk.model_validate_json(json_body)  # type: ignore
E               pydantic_core._pydantic_core.ValidationError: 1 validation error for GenerativeChunk
E               chunk
E                 Field required [type=missing, input_value={'detail': 'Unknown LLM e... Unknown API exception'}, input_type=dict]
E                   For further information visit https://errors.pydantic.dev/2.10/v/missing

/opt/hostedtoolcache/Python/3.12.8/x64/lib/python3.12/site-packages/nuclia/lib/nua.py:433: ValidationError
assert "Barcelona" in generated.answer
9 changes: 5 additions & 4 deletions nua/e2e/regional/test_llm_json.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest
from nuclia.lib.nua_responses import ChatModel, UserPrompt
from nuclia.sdk.predict import NucliaPredict
from nuclia.sdk.predict import AsyncNucliaPredict

from regional.models import LLM_WITH_JSON_OUTPUT_SUPPORT

Expand Down Expand Up @@ -34,10 +34,11 @@
TEXT = """"Many football players have existed. Messi is by far the greatest. Messi was born in Rosario, 24th of June 1987"""


@pytest.mark.asyncio_cooperative
@pytest.mark.parametrize("model_name", LLM_WITH_JSON_OUTPUT_SUPPORT)
def test_llm_json(nua_config, model_name):
np = NucliaPredict()
results = np.generate(
async def test_llm_json(nua_config, model_name):
np = AsyncNucliaPredict()
results = await np.generate(
text=ChatModel(
question="",
retrieval=False,
Expand Down
Loading

0 comments on commit b82ad62

Please sign in to comment.