[CZID-8454] Add tests for file queries (#40)
* Rearrange test layout.

* Add test coverage for file queries.
jgadling authored Sep 1, 2023
1 parent d7dee2f commit 9472537
Showing 18 changed files with 656 additions and 187 deletions.
89 changes: 89 additions & 0 deletions entities/api/conftest.py
@@ -0,0 +1,89 @@
import json
import typing
from typing import Optional

import pytest_asyncio
from cerbos.sdk.model import Principal
from database.connect import AsyncDB
from fastapi import FastAPI
from httpx import AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession
from starlette.requests import Request

from api.core.deps import get_auth_principal, get_db_session, get_engine, require_auth_principal
from api.main import get_app


class GQLTestClient:
def __init__(self, http_client: AsyncClient):
self.http_client = http_client

# Utility function for making GQL HTTP queries
async def query(
self,
query: str,
user_id: Optional[int] = None,
member_projects: Optional[list[int]] = None,
admin_projects: Optional[list[int]] = None,
):
if not user_id:
user_id = 111
if not admin_projects:
admin_projects = []
if not member_projects:
member_projects = []
gql_headers = {
"content-type": "application/json",
"accept": "application/json",
"user_id": str(user_id),
"member_projects": json.dumps(member_projects),
"admin_projects": json.dumps(admin_projects),
}
result = await self.http_client.post("/graphql", json={"query": query}, headers=gql_headers)
return result.json()


@pytest_asyncio.fixture()
async def gql_client(http_client: AsyncClient) -> GQLTestClient:
client = GQLTestClient(http_client)
return client


async def patched_authprincipal(request: Request) -> Principal:
user_id = request.headers.get("user_id")
if not user_id:
raise Exception("user_id not found in request headers")
principal = Principal(
user_id,
roles=["user"],
attr={
"user_id": int(user_id),
"member_projects": json.loads(request.headers.get("member_projects", "[]")),
"admin_projects": json.loads(request.headers.get("admin_projects", "[]")),
},
)
return principal


@pytest_asyncio.fixture()
async def api(
async_db: AsyncDB,
) -> FastAPI:
async def patched_session() -> typing.AsyncGenerator[AsyncSession, None]:
session = async_db.session()
try:
yield session
finally:
await session.close()

    api = get_app()
    # Swap the real engine, DB session, and auth principal dependencies for
    # test doubles so requests run against the test database and fake headers.
    api.dependency_overrides[get_engine] = lambda: async_db
    api.dependency_overrides[get_db_session] = patched_session
    api.dependency_overrides[require_auth_principal] = patched_authprincipal
    api.dependency_overrides[get_auth_principal] = patched_authprincipal
return api


@pytest_asyncio.fixture()
async def http_client(api: FastAPI) -> AsyncClient:
return AsyncClient(app=api, base_url="http://test")
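For orientation, here is a minimal sketch of how these fixtures might be exercised by the new file-query tests. The test name, the asyncio marker, and the assertions are illustrative assumptions, not code from this commit.

import pytest


@pytest.mark.asyncio
async def test_files_query_returns_a_list(gql_client: GQLTestClient) -> None:
    # `files` is the loader exposed on the Query type in entities/api/main.py.
    query = """
        query GetFiles {
            files {
                id
                path
                status
            }
        }
    """
    result = await gql_client.query(query, user_id=111, member_projects=[123])
    # Against an empty test database this should come back as an empty list,
    # not an error payload.
    assert "errors" not in result
    assert isinstance(result["data"]["files"], list)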
2 changes: 1 addition & 1 deletion entities/api/core/gql_loaders.py
@@ -257,7 +257,7 @@ def generate_strawberry_arguments(action, sql_model, gql_type):
continue

# Get GQL field
-        field = gql_type._type_definition.get_field(sql_column)
+        field = gql_type.__strawberry_definition__.get_field(sql_column)
if field:
# When updating an entity, only entity ID is required
is_optional_field = action == CERBOS_ACTION_UPDATE and field.name != "entity_id"
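The one-line change above appears to track Strawberry's rename of the type metadata attribute: newer strawberry-graphql releases expose it as __strawberry_definition__, with _type_definition kept only as a deprecated alias. A standalone sketch of that lookup pattern, using a throwaway ExampleType rather than this repo's mapped models:

import strawberry


@strawberry.type
class ExampleType:
    name: str
    size: int


# Mirrors the lookup in generate_strawberry_arguments above: fetch a field's
# definition by column name via __strawberry_definition__.
field = ExampleType.__strawberry_definition__.get_field("name")
if field:
    print(field.name, field.type)  # e.g. "name <class 'str'>"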
14 changes: 4 additions & 10 deletions entities/api/main.py
@@ -32,23 +32,18 @@ class EntityInterface:
pass


-@strawberry_sqlalchemy_mapper.type(db.Entity)
-class Entity:
+@strawberry_sqlalchemy_mapper.type(db.File)
+class File:
    pass


@strawberry_sqlalchemy_mapper.type(db.Sample)
-class Sample:
+class Sample(EntityInterface):
    pass


@strawberry_sqlalchemy_mapper.type(db.SequencingRead)
-class SequencingRead:
-    pass
-
-
-@strawberry_sqlalchemy_mapper.type(db.File)
-class File:
+class SequencingRead(EntityInterface):
pass


@@ -59,7 +54,6 @@ class File:

@strawberry.type
class Query:
-    entity: typing.List[Sample] = get_base_loader(db.Entity, EntityInterface)
samples: typing.List[Sample] = get_base_loader(db.Sample, Sample)
sequencing_reads: typing.List[SequencingRead] = get_base_loader(db.SequencingRead, SequencingRead)
files: typing.List[File] = get_file_loader(db.File, File)
32 changes: 29 additions & 3 deletions entities/api/schema.graphql
@@ -1,3 +1,11 @@
type Entity implements EntityInterface {
id: UUID!
type: String!
producingRunId: Int
ownerUserId: Int!
collectionId: Int!
}

interface EntityInterface {
id: UUID!
type: String!
@@ -6,17 +14,33 @@ interface EntityInterface {
collectionId: Int!
}

type File {
id: UUID!
entityId: UUID
entityFieldName: String
status: String!
protocol: String!
namespace: String!
path: String!
fileFormat: String!
compressionType: String!
size: Int!
entity: EntityInterface
}

type Mutation {
createSample(name: String!, location: String!, collectionId: Int!): Sample!
-  createSequencingRead(nucleotide: String!, sequence: String!, protocol: String!, sampleId: UUID!, collectionId: Int!): SequencingRead!
+  createSequencingRead(nucleotide: String!, sequence: String!, protocol: String!, sequenceFileId: UUID, sampleId: UUID!, collectionId: Int!): SequencingRead!
updateSample(entityId: UUID!, name: String!, location: String!): Sample!
}

type Query {
samples(id: UUID = null): [Sample!]!
sequencingReads(id: UUID = null): [SequencingRead!]!
files(id: UUID = null): [File!]!
}

-type Sample {
+type Sample implements EntityInterface {
id: UUID!
type: String!
producingRunId: Int
@@ -28,7 +52,7 @@ type Sample {
sequencingReads: SequencingReadConnection!
}

-type SequencingRead {
+type SequencingRead implements EntityInterface {
id: UUID!
type: String!
producingRunId: Int
@@ -38,7 +62,9 @@ type SequencingRead {
nucleotide: String!
sequence: String!
protocol: String!
sequenceFileId: UUID
sampleId: UUID!
sequenceFile: File
sample: Sample!
}
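A hedged example of a query shape the updated schema now supports, traversing the new SequencingRead.sequenceFile and File.entity links. It is written as a pytest-style sketch against the GQLTestClient fixture above; the test name is illustrative and no seeded data is assumed beyond an empty result being valid.

import pytest


FILE_RELATIONSHIP_QUERY = """
    query FileRelationships {
        sequencingReads {
            id
            sequenceFile {
                path
                fileFormat
                entity {
                    id
                    type
                }
            }
        }
    }
"""


@pytest.mark.asyncio
async def test_sequence_file_traversal(gql_client: GQLTestClient) -> None:
    result = await gql_client.query(FILE_RELATIONSHIP_QUERY)
    # The nested selections are valid per schema.graphql; with no seeded rows
    # the list is simply empty rather than an error.
    assert "errors" not in result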
