Skip to content

Commit

Permalink
stop using sdk config file for nua tests (#127)
Browse files Browse the repository at this point in the history
  • Loading branch information
drf7 authored Dec 20, 2024
1 parent 126dadb commit 9ccd52c
Show file tree
Hide file tree
Showing 13 changed files with 85 additions and 67 deletions.
30 changes: 12 additions & 18 deletions nua/e2e/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@

import nuclia
import pytest
from nuclia.config import reset_config_file, set_config_file
from nuclia.sdk.auth import NucliaAuth
from nuclia.lib.nua import AsyncNuaClient
from dataclasses import dataclass
from typing import AsyncGenerator

logger = logging.getLogger("e2e")

Expand All @@ -22,17 +23,17 @@ class Tokens:
PROD_ACCOUNT_ID = "5cec111b-ea23-4b0c-a82a-d1a666dd1fd2"

TOKENS: dict[str, Tokens] = {
"europe-1.stashify.cloud": Tokens(
"https://europe-1.stashify.cloud": Tokens(
nua_key=os.environ.get("TEST_EUROPE1_STASHIFY_NUA"),
pat_key=os.environ.get("STAGE_PERMAMENT_ACCOUNT_OWNER_PAT_TOKEN"),
account_id=os.environ.get("TEST_EUROPE1_STASHIFY_ACCOUNT", STAGE_ACCOUNT_ID),
),
"europe-1.nuclia.cloud": Tokens(
"https://europe-1.nuclia.cloud": Tokens(
nua_key=os.environ.get("TEST_EUROPE1_NUCLIA_NUA"),
pat_key=os.environ.get("PROD_PERMAMENT_ACCOUNT_OWNER_PAT_TOKEN"),
account_id=PROD_ACCOUNT_ID,
),
"aws-us-east-2-1.nuclia.cloud": Tokens(
"https://aws-us-east-2-1.nuclia.cloud": Tokens(
nua_key=os.environ.get("TEST_AWS_US_EAST_2_1_NUCLIA_NUA"),
pat_key=os.environ.get("PROD_PERMAMENT_ACCOUNT_OWNER_PAT_TOKEN"),
account_id=PROD_ACCOUNT_ID,
Expand All @@ -43,12 +44,12 @@ class Tokens:
@pytest.fixture(
scope="function",
params=[
"europe-1.stashify.cloud",
"europe-1.nuclia.cloud",
"aws-us-east-2-1.nuclia.cloud",
"https://europe-1.stashify.cloud",
"https://europe-1.nuclia.cloud",
"https://aws-us-east-2-1.nuclia.cloud",
],
)
async def nua_config(request):
async def nua_config(request) -> AsyncGenerator[AsyncNuaClient, None]:
if (
os.environ.get("TEST_ENV") == "stage" and "stashify.cloud" not in request.param # noqa
): # noqa
Expand All @@ -63,13 +64,6 @@ async def nua_config(request):
assert token.pat_key
assert token.account_id

with tempfile.NamedTemporaryFile() as temp_file:
temp_file.write(b"{}")
temp_file.flush()
set_config_file(temp_file.name)
nuclia_auth = NucliaAuth()
client_id = nuclia_auth.nua(token.nua_key)
assert client_id
nuclia_auth._config.set_default_nua(client_id)
yield request.param
reset_config_file()
yield AsyncNuaClient(
region=request.param, account=token.account_id, token=token.nua_key
)
16 changes: 10 additions & 6 deletions nua/e2e/regional/test_da_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import base64
from typing import Optional
import aiohttp
from nuclia.lib.nua import AsyncNuaClient


@dataclass
Expand Down Expand Up @@ -510,17 +511,20 @@ def validate_labeler_output_text_block(msg: BrokerMessage):

@pytest.fixture
async def tmp_nua_key(
nua_config: str, aiohttp_client: AsyncGenerator[aiohttp.ClientSession, None]
nua_config: AsyncNuaClient,
aiohttp_client: AsyncGenerator[aiohttp.ClientSession, None],
) -> AsyncGenerator[str, None]:
account_id = TOKENS[nua_config].account_id
account_id = TOKENS[nua_config.region].account_id
pat_client_generator = aiohttp_client(
base_url=f"https://{nua_config}", pat_key=TOKENS[nua_config].pat_key, timeout=30
base_url=nua_config.url,
pat_key=TOKENS[nua_config.url].pat_key,
timeout=30,
)
pat_client = await anext(pat_client_generator)
nua_client_id, nua_key = await create_nua_key(
client=pat_client,
account_id=account_id,
title=f"E2E DA AGENTS - {nua_config}",
title=f"E2E DA AGENTS - {nua_config.region}",
)
try:
yield nua_key
Expand All @@ -535,7 +539,7 @@ async def tmp_nua_key(
"test_input", DA_TEST_INPUTS, ids=lambda test_input: test_input.parameters.name
)
async def test_da_agent_tasks(
nua_config: str,
nua_config: AsyncNuaClient,
aiohttp_client: AsyncGenerator[aiohttp.ClientSession, None],
tmp_nua_key: str,
test_input: TestInput,
Expand All @@ -545,7 +549,7 @@ async def test_da_agent_tasks(
start_time = asyncio.get_event_loop().time()
try:
nua_client_generator = aiohttp_client(
base_url=f"https://{nua_config}", nua_key=tmp_nua_key, timeout=30
base_url=nua_config.url, nua_key=tmp_nua_key, timeout=30
)
nua_client = await anext(nua_client_generator)

Expand Down
5 changes: 3 additions & 2 deletions nua/e2e/regional/test_llm_chat.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from nuclia.lib.nua_responses import ChatModel, UserPrompt
from nuclia.sdk.predict import AsyncNucliaPredict

from nuclia.lib.nua import AsyncNuaClient
from regional.models import ALL_LLMS
import pytest


@pytest.mark.asyncio_cooperative
async def test_llm_chat(nua_config):
async def test_llm_chat(nua_config: AsyncNuaClient):
# Validate that other features such as
# * citations
# * custom prompts
Expand All @@ -29,6 +29,7 @@ async def test_llm_chat(nua_config):
generated = await np.generate(
text=chat_model,
model=ALL_LLMS[0],
nc=nua_config,
)
# Check that system + user prompt worked
assert generated.answer.startswith("ITALIAN")
Expand Down
11 changes: 6 additions & 5 deletions nua/e2e/regional/test_llm_config.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,26 @@
import pytest
from nuclia.exceptions import NuaAPIException
from nuclia.lib.nua_responses import LearningConfigurationCreation
from nuclia.lib.nua import AsyncNuaClient
from nuclia.sdk.predict import AsyncNucliaPredict


@pytest.mark.asyncio_cooperative
async def test_llm_config_nua(nua_config):
async def test_llm_config_nua(nua_config: AsyncNuaClient):
np = AsyncNucliaPredict()

try:
await np.del_config("kbid")
await np.del_config("kbid", nc=nua_config)
except NuaAPIException:
pass

with pytest.raises(NuaAPIException):
config = await np.config("kbid")
config = await np.config("kbid", nc=nua_config)

lcc = LearningConfigurationCreation()
await np.set_config("kbid", lcc)
await np.set_config("kbid", lcc, nc=nua_config)

config = await np.config("kbid")
config = await np.config("kbid", nc=nua_config)

assert config.resource_labelers_models is None
assert config.ner_model == "multilingual"
Expand Down
8 changes: 5 additions & 3 deletions nua/e2e/regional/test_llm_generate.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import pytest
from nuclia.sdk.predict import AsyncNucliaPredict

from nuclia.lib.nua import AsyncNuaClient
from regional.models import ALL_LLMS


@pytest.mark.asyncio_cooperative
@pytest.mark.parametrize("model", ALL_LLMS)
async def test_llm_generate(nua_config, model):
async def test_llm_generate(nua_config: AsyncNuaClient, model):
np = AsyncNucliaPredict()
generated = await np.generate("Which is the capital of Catalonia?", model=model)
generated = await np.generate(
"Which is the capital of Catalonia?", model=model, nc=nua_config
)
assert "Barcelona" in generated.answer
5 changes: 3 additions & 2 deletions nua/e2e/regional/test_llm_json.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest
from nuclia.lib.nua_responses import ChatModel, UserPrompt
from nuclia.sdk.predict import AsyncNucliaPredict

from nuclia.lib.nua import AsyncNuaClient
from regional.models import LLM_WITH_JSON_OUTPUT_SUPPORT

SCHEMA = {
Expand Down Expand Up @@ -36,7 +36,7 @@

@pytest.mark.asyncio_cooperative
@pytest.mark.parametrize("model_name", LLM_WITH_JSON_OUTPUT_SUPPORT)
async def test_llm_json(nua_config, model_name):
async def test_llm_json(nua_config: AsyncNuaClient, model_name):
np = AsyncNucliaPredict()
results = await np.generate(
text=ChatModel(
Expand All @@ -47,5 +47,6 @@ async def test_llm_json(nua_config, model_name):
json_schema=SCHEMA,
),
model=model_name,
nc=nua_config,
)
assert "SPORTS" in results.object["document_type"]
5 changes: 3 additions & 2 deletions nua/e2e/regional/test_llm_rag.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import pytest
from nuclia.sdk.predict import AsyncNucliaPredict

from nuclia.lib.nua import AsyncNuaClient
from regional.models import ALL_LLMS


@pytest.mark.asyncio_cooperative
@pytest.mark.parametrize("model", ALL_LLMS)
async def test_llm_rag(nua_config, model):
async def test_llm_rag(nua_config: AsyncNuaClient, model):
np = AsyncNucliaPredict()
generated = await np.rag(
question="Which is the CEO of Nuclia?",
Expand All @@ -15,6 +15,7 @@ async def test_llm_rag(nua_config, model):
"Eudald Camprubí is CEO at the same company as Ramon Navarro",
],
model=model,
nc=nua_config,
)

assert "Eudald" in generated.answer
9 changes: 5 additions & 4 deletions nua/e2e/regional/test_llm_schema.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
from nuclia.sdk.predict import AsyncNucliaPredict
import pytest
from nuclia.lib.nua import AsyncNuaClient


@pytest.mark.asyncio_cooperative
async def test_llm_schema_nua(nua_config):
async def test_llm_schema_nua(nua_config: AsyncNuaClient):
np = AsyncNucliaPredict()
config = await np.schema()
config = await np.schema(nc=nua_config)

assert len(config.ner_model.options) == 1
assert len(config.generative_model.options) >= 5


@pytest.mark.asyncio_cooperative
async def test_llm_schema_kbid(nua_config):
async def test_llm_schema_kbid(nua_config: AsyncNuaClient):
np = AsyncNucliaPredict()
config = await np.schema("fake_kbid")
config = await np.schema("fake_kbid", nc=nua_config)
assert len(config.ner_model.options) == 1
assert len(config.generative_model.options) >= 5
13 changes: 7 additions & 6 deletions nua/e2e/regional/test_llm_summarize.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from nuclia.sdk.predict import AsyncNucliaPredict
import pytest
from nuclia.lib.nua import AsyncNuaClient

DATA = {
"barcelona": "Barcelona (pronunciat en català central, [bərsəˈlonə]) és una ciutat i metròpoli a la costa mediterrània de la península Ibèrica. És la capital de Catalunya,[1] així com de la comarca del Barcelonès i de la província de Barcelona, i la segona ciutat en població i pes econòmic de la península Ibèrica,[2][3] després de Madrid. El municipi creix sobre una plana encaixada entre la serralada Litoral, el mar Mediterrani, el riu Besòs i la muntanya de Montjuïc. La ciutat acull les seus de les institucions d'autogovern més importants de la Generalitat de Catalunya: el Parlament de Catalunya, el President i el Govern de la Generalitat. Pel fet d'haver estat capital del Comtat de Barcelona, rep sovint el sobrenom de Ciutat Comtal. També, com que ha estat la ciutat més important del Principat de Catalunya des d'època medieval, rep sovint el sobrenom o títol de cap i casal.[4]", # noqa
Expand Down Expand Up @@ -27,25 +28,25 @@


@pytest.mark.asyncio_cooperative
async def test_summarize_chatgpt(nua_config):
async def test_summarize_chatgpt(nua_config: AsyncNuaClient):
np = AsyncNucliaPredict()
embed = await np.summarize(DATA, model="chatgpt4o")
embed = await np.summarize(DATA, model="chatgpt4o", nc=nua_config)
assert "Manresa" in embed.summary
assert "Barcelona" in embed.summary


@pytest.mark.asyncio_cooperative
async def test_summarize_azure_chatgpt(nua_config):
async def test_summarize_azure_chatgpt(nua_config: AsyncNuaClient):
np = AsyncNucliaPredict()
embed = await np.summarize(DATA, model="chatgpt-azure-4o")
embed = await np.summarize(DATA, model="chatgpt-azure-4o", nc=nua_config)
assert "Manresa" in embed.summary
assert "Barcelona" in embed.summary


@pytest.mark.asyncio_cooperative
async def test_summarize_claude(nua_config):
async def test_summarize_claude(nua_config: AsyncNuaClient):
np = AsyncNucliaPredict()
embed = await np.summarize(DATA_COFFEE, model="claude-3-fast")
embed = await np.summarize(DATA_COFFEE, model="claude-3-fast", nc=nua_config)
# changed to partial summaries since anthropic is not consistent in the global summary at all
assert "flat white" in embed.resources["Flat white"].summary.lower()
assert "macchiato" in embed.resources["Macchiato"].summary.lower()
24 changes: 14 additions & 10 deletions nua/e2e/regional/test_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@

from regional.models import ALL_ENCODERS, ALL_LLMS
from nuclia_models.predict.remi import RemiRequest
from nuclia.lib.nua import AsyncNuaClient


@pytest.mark.asyncio_cooperative
@pytest.mark.parametrize("model", ALL_ENCODERS.keys())
async def test_predict_sentence(nua_config, model):
async def test_predict_sentence(nua_config: AsyncNuaClient, model):
np = AsyncNucliaPredict()
embed = await np.sentence(text="This is my text", model=model)
embed = await np.sentence(text="This is my text", model=model, nc=nua_config)
assert embed.time > 0
# Deprecated field (data)
assert len(embed.data) == ALL_ENCODERS[model]
Expand All @@ -20,9 +21,9 @@ async def test_predict_sentence(nua_config, model):


@pytest.mark.asyncio_cooperative
async def test_predict_query(nua_config):
async def test_predict_query(nua_config: AsyncNuaClient):
np = AsyncNucliaPredict()
embed = await np.query(text="I love Barcelona")
embed = await np.query(text="I love Barcelona", nc=nua_config)
# Semantic
assert embed.semantic_threshold > 0
assert len(embed.sentence.data) > 128
Expand All @@ -39,9 +40,9 @@ async def test_predict_query(nua_config):


@pytest.mark.asyncio_cooperative
async def test_predict_tokens(nua_config):
async def test_predict_tokens(nua_config: AsyncNuaClient):
np = AsyncNucliaPredict()
embed = await np.tokens(text="I love Barcelona")
embed = await np.tokens(text="I love Barcelona", nc=nua_config)
assert embed.tokens[0].text == "Barcelona"
assert embed.tokens[0].start == 7
assert embed.tokens[0].end == 16
Expand All @@ -50,16 +51,18 @@ async def test_predict_tokens(nua_config):

@pytest.mark.asyncio_cooperative
@pytest.mark.parametrize("model", ALL_LLMS)
async def test_predict_rephrase(nua_config, model):
async def test_predict_rephrase(nua_config: AsyncNuaClient, model):
# Check that rephrase is working for all models
np = AsyncNucliaPredict()
# TODO: Test that custom rephrase prompt works once SDK supports it
rephrased = await np.rephrase(question="Barcelona best coffe", model=model)
rephrased = await np.rephrase(
question="Barcelona best coffe", model=model, nc=nua_config
)
assert rephrased != "Barcelona best coffe" and rephrased != ""


@pytest.mark.asyncio_cooperative
async def test_predict_remi(nua_config):
async def test_predict_remi(nua_config: AsyncNuaClient):
np = AsyncNucliaPredict()
results = await np.remi(
RemiRequest(
Expand All @@ -70,7 +73,8 @@ async def test_predict_remi(nua_config):
"Paris is the capital of France.",
"Berlin is the capital of Germany.",
],
)
),
nc=nua_config,
)
assert results.answer_relevance.score >= 4

Expand Down
Loading

0 comments on commit 9ccd52c

Please sign in to comment.