Skip to content

Commit

Permalink
Add and use upload_documents to minimize repeated code in repeated …
Browse files Browse the repository at this point in the history
…tasks
  • Loading branch information
smokestacklightnin committed Feb 1, 2025
1 parent 565ccab commit 3f0000f
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 62 deletions.
74 changes: 12 additions & 62 deletions tests/deploy/api/test_endpoints.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
import contextlib

from ragna.deploy import Config
from tests.deploy.api.utils import upload_documents
from tests.deploy.utils import make_api_client


def test_get_documents(tmp_local_root):
config = Config(local_root=tmp_local_root)

needs_more_of = ["reverb", "cowbell"]

document_root = config.local_root / "documents"
Expand All @@ -21,37 +19,16 @@ def test_get_documents(tmp_local_root):
with make_api_client(
config=Config(), ignore_unavailable_components=False
) as client:
documents = (
client.post(
"/api/documents",
json=[{"name": document_path.name} for document_path in document_paths],
)
.raise_for_status()
.json()
)

with contextlib.ExitStack() as stack:
files = [
stack.enter_context(open(document_path, "rb"))
for document_path in document_paths
]
client.put(
"/api/documents",
files=[
("documents", (document["id"], file))
for document, file in zip(documents, files)
],
)

documents = upload_documents(client=client, document_paths=document_paths)
response = client.get("/api/documents").raise_for_status()

# Sort the items in case they are retrieved in different orders
def sorting_key(d):
return d["id"]
# Sort the items in case they are retrieved in different orders
def sorting_key(d):
return d["id"]

assert sorted(documents, key=sorting_key) == sorted(
response.json(), key=sorting_key
)
assert sorted(documents, key=sorting_key) == sorted(
response.json(), key=sorting_key
)


def test_get_document(tmp_local_root):
Expand All @@ -66,24 +43,10 @@ def test_get_document(tmp_local_root):
with make_api_client(
config=Config(), ignore_unavailable_components=False
) as client:
document = (
client.post(
"/api/documents",
json=[{"name": document_path.name}],
)
.raise_for_status()
.json()[0]
)

with open(document_path, "rb") as file:
client.put(
"/api/documents",
files=[("documents", (document["id"], file))],
)

document = upload_documents(client=client, document_paths=[document_path])[0]
response = client.get(f"/api/documents/{document['id']}").raise_for_status()

assert document == response.json()
assert document == response.json()


def test_get_document_content(tmp_local_root):
Expand All @@ -98,24 +61,11 @@ def test_get_document_content(tmp_local_root):
with make_api_client(
config=Config(), ignore_unavailable_components=False
) as client:
document = (
client.post(
"/api/documents",
json=[{"name": document_path.name}],
)
.raise_for_status()
.json()[0]
)

with open(document_path, "rb") as file:
client.put(
"/api/documents",
files=[("documents", (document["id"], file))],
)
document = upload_documents(client=client, document_paths=[document_path])[0]

with client.stream(
"GET", f"/api/documents/{document['id']}/content"
) as response:
received_lines = list(response.iter_lines())

assert received_lines == ["Needs more reverb"]
assert received_lines == ["Needs more reverb"]
27 changes: 27 additions & 0 deletions tests/deploy/api/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import contextlib


def upload_documents(*, client, document_paths):
documents = (
client.post(
"/api/documents",
json=[{"name": document_path.name} for document_path in document_paths],
)
.raise_for_status()
.json()
)

with contextlib.ExitStack() as stack:
files = [
stack.enter_context(open(document_path, "rb"))
for document_path in document_paths
]
client.put(
"/api/documents",
files=[
("documents", (document["id"], file))
for document, file in zip(documents, files)
],
)

return documents

0 comments on commit 3f0000f

Please sign in to comment.