Skip to content

Commit

Permalink
[CZID-8370] tests for nested type queries (#36)
Browse files Browse the repository at this point in the history
* working on tests for nested queries.

* Fix hanging session problems.

* Add unit test for nested queries

* Fix lint.
  • Loading branch information
jgadling authored Aug 25, 2023
1 parent c0bd742 commit 759e4b5
Show file tree
Hide file tree
Showing 4 changed files with 131 additions and 10 deletions.
8 changes: 5 additions & 3 deletions entities/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,15 @@ rm-pycache: ## remove all __pycache__ files (run if encountering issues with pyc
find . -name '__pycache__' | xargs rm -rf

### DOCKER LOCAL DEV #########################################
.PHONY: local-init
local-init:
$(docker_compose) up -d
.PHONY: local-setup
local-setup: ## Load db tables and seed data
while [ -z "$$($(docker_compose) exec -T postgres psql $(LOCAL_DB_CONN_STRING) -c 'select 1')" ]; do echo "waiting for db to start..."; sleep 1; done;
$(docker_compose) run entities alembic upgrade head
$(MAKE) local-seed

.PHONY: local-init
local-init: local-start local-setup ## Setup a working local-dev environment

.PHONY: debugger
debugger: ## Attach to the gql service (useful for pdb)
docker attach $$($(docker_compose) ps | grep entities | cut -d ' ' -f 1 | head -n 1)
Expand Down
4 changes: 3 additions & 1 deletion entities/api/core/gql_loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,14 +128,16 @@ async def load_fn(keys: list[Tuple]) -> list[Any]:
order_by: list[tuple[ColumnElement[Any], ...]] = []
if relationship.order_by:
order_by = [relationship.order_by]
db_session = self.engine.session()
rows = await get_entities(
related_model,
self.engine.session(),
db_session,
self.cerbos_client,
self.principal,
filters, # type: ignore
order_by,
)
await db_session.close()

def group_by_remote_key(row: Any) -> Tuple:
if not relationship.local_remote_pairs:
Expand Down
10 changes: 4 additions & 6 deletions entities/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@

import pytest
import pytest_asyncio
from api.core.deps import get_db_session, require_auth_principal
from api.main import get_app, get_context
from api.core.deps import get_db_session, require_auth_principal, get_auth_principal, get_engine
from api.main import get_app
from database.connect import AsyncDB, SyncDB, init_async_db, init_sync_db
from database.models.base import Base
from fastapi import FastAPI
Expand Down Expand Up @@ -96,13 +96,11 @@ async def patched_session() -> typing.AsyncGenerator[AsyncSession, None]:
finally:
await session.close()

def patched_context():
return {}

api = get_app()
api.dependency_overrides[get_engine] = lambda: async_db
api.dependency_overrides[get_db_session] = patched_session
api.dependency_overrides[get_context] = patched_context
api.dependency_overrides[require_auth_principal] = patched_authprincipal
api.dependency_overrides[get_auth_principal] = patched_authprincipal
return api


Expand Down
119 changes: 119 additions & 0 deletions entities/tests/test_nested_queries.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
"""
Tests for nested queries + authorization
"""

import pytest
from httpx import AsyncClient
from database.connect import SyncDB
import json
from collections import defaultdict
from test_infra import factories as fa


async def get_gql_results(
http_client: AsyncClient, query: str, project_ids: list[int], user_id: int, query_name: str = "MyQuery"
):
request = {"operationName": query_name, "query": query}
headers = {
"content-type": "application/json",
"accept": "application/json",
"member_projects": json.dumps(project_ids),
"user_id": str(user_id),
}
result = await http_client.post(
"/graphql",
json=request,
headers=headers,
)
output = result.json()
return output["data"]


@pytest.mark.asyncio
async def test_nested_query(
sync_db: SyncDB,
http_client: AsyncClient,
):
# For now, use the hardcoded user_id for tests
user1_id = 111
user2_id = 222
user3_id = 222
project1_id = 888
project2_id = 999

# Create mock data
with sync_db.session() as session:
fa.SessionStorage.set_session(session)
# create some samples with multiple SequencingReads
sa1 = fa.SampleFactory(owner_user_id=user1_id, collection_id=project1_id)
sa2 = fa.SampleFactory(owner_user_id=user3_id, collection_id=project1_id)
sa3 = fa.SampleFactory(owner_user_id=user2_id, collection_id=project2_id)

seq1 = fa.SequencingReadFactory.create_batch(
3, sample=sa1, owner_user_id=sa1.owner_user_id, collection_id=sa1.collection_id
)
seq2 = fa.SequencingReadFactory.create_batch(
2,
sample=sa2,
owner_user_id=sa2.owner_user_id,
collection_id=sa2.collection_id,
)
seq3 = fa.SequencingReadFactory.create_batch(
2,
sample=sa3,
owner_user_id=sa3.owner_user_id,
collection_id=sa3.collection_id,
)

# Fetch samples and nested sequencing reads AND nested samples again!
query = """
query MyQuery {
samples {
id
name
ownerUserId
collectionId
sequencingReads {
edges {
node {
collectionId
ownerUserId
sequence
nucleotide
sample {
id
ownerUserId
collectionId
name
}
}
}
}
}
}
"""

# Make sure user1 can only see samples from project1
results = await get_gql_results(http_client, query, [project1_id], user1_id)
expected_samples_by_owner = {
user1_id: 1,
user2_id: 1,
user3_id: 1,
}
expected_sequences_by_owner = {
user1_id: len(seq1),
user2_id: len(seq3),
user3_id: len(seq2),
}
actual_samples_by_owner: dict[int, int] = defaultdict(int)
actual_sequences_by_owner: dict[int, int] = defaultdict(int)
for sample in results["samples"]:
assert sample["collectionId"] == project1_id
actual_samples_by_owner[sample["ownerUserId"]] += 1
actual_sequences_by_owner[sample["ownerUserId"]] = len(sample["sequencingReads"]["edges"])
assert sample["sequencingReads"]["edges"][0]["node"]["sample"]["id"] == sample["id"]

for userid in expected_sequences_by_owner:
assert actual_sequences_by_owner[userid] == expected_sequences_by_owner[userid]
for userid in expected_samples_by_owner:
assert actual_samples_by_owner[userid] == expected_samples_by_owner[userid]

0 comments on commit 759e4b5

Please sign in to comment.