diff --git a/entities/Makefile b/entities/Makefile index 0af95f62..a8a3cde1 100644 --- a/entities/Makefile +++ b/entities/Makefile @@ -90,7 +90,7 @@ local-gqlschema: ## Export this app's GQL schema. .PHONY: local-token local-token: ## Copy an auth token for this local dev env to the system clipboard - TOKEN=$$($(docker_compose) run entities ./cli/gqlcli.py auth generate-token 111 --project 444:admin --expiration 99999); echo $$TOKEN | pbcopy; echo $$TOKEN + TOKEN=$$($(docker_compose) run entities ./cli/gqlcli.py auth generate-token 111 --project 444:admin --expiration 99999); echo '{"Authorization":"Bearer '$$TOKEN'"}' | tee >(pbcopy) .PHONY: fix-poetry-lock fix-poetry-lock: ## Fix poetry lockfile after merge conflict & repairing pyproject.toml diff --git a/entities/api/core/gql_loaders.py b/entities/api/core/gql_loaders.py index 0889648f..5630f424 100644 --- a/entities/api/core/gql_loaders.py +++ b/entities/api/core/gql_loaders.py @@ -1,30 +1,20 @@ +import uuid +import typing +import strawberry +import database.models as db from collections import defaultdict from typing import Any, Mapping, Tuple, Optional - -from database.models import Base -from sqlalchemy import tuple_ +from sqlalchemy import ColumnElement, ColumnExpressionArgument, tuple_ from sqlalchemy.orm import RelationshipProperty +from sqlalchemy.ext.asyncio import AsyncSession from strawberry.dataloader import DataLoader from cerbos.sdk.client import CerbosClient -from cerbos.sdk.model import Principal, ResourceDesc -from thirdparty.cerbos_sqlalchemy.query import get_query - +from cerbos.sdk.model import Principal, Resource, ResourceDesc from fastapi import Depends - -import typing - -import database.models as db -import strawberry -from sqlalchemy.ext.asyncio import AsyncSession -import uuid - -from api.core.deps import ( - require_auth_principal, - get_cerbos_client, - get_db_session, -) +from database.models import Base +from thirdparty.cerbos_sqlalchemy.query import get_query +from api.core.deps import require_auth_principal, get_cerbos_client, get_db_session from api.core.strawberry_extensions import DependencyExtension -from sqlalchemy import ColumnExpressionArgument, ColumnElement async def get_entities( @@ -32,8 +22,8 @@ async def get_entities( session: AsyncSession, cerbos_client: CerbosClient, principal: Principal, - filters: Optional[list[ColumnExpressionArgument]], - order_by: Optional[list[tuple[ColumnElement[Any], ...]]], + filters: Optional[list[ColumnExpressionArgument]] = [], + order_by: Optional[list[tuple[ColumnElement[Any], ...]]] = [], ): rd = ResourceDesc(model.__tablename__) plan = cerbos_client.plan_resources("view", principal, rd) @@ -54,6 +44,40 @@ async def get_entities( return result.scalars().all() +# Returns function that helps create entities +async def create_entity( + principal: Principal = Depends(require_auth_principal), + cerbos_client: CerbosClient = Depends(get_cerbos_client), + session: AsyncSession = Depends(get_db_session, use_cache=False), +): + async def create(entity_model, gql_type, params): + # Validate that user can create entity in this collection + attr = {"collection_id": params.get("collection_id")} + resource = Resource(id="NEW_ID", kind=entity_model.__tablename__, attr=attr) + if not cerbos_client.is_allowed("create", principal, resource): + raise Exception("Unauthorized") + + # TODO: User must have permissions to the sample + + # Save to DB + params["owner_user_id"] = int(principal.id) + new_entity = entity_model(**params) + session.add(new_entity) + await session.commit() + + # Return GQL object to client (FIXME: is there a better way to convert `new_entity` to `gql_type`?) + params = { + **params, + "id": new_entity.entity_id, + "type": new_entity.type, + "producing_run_id": new_entity.producing_run_id, + "entity_id": new_entity.entity_id, + } + return gql_type(**params) + + return create + + class EntityLoader: """ Creates DataLoader instances on-the-fly for SQLAlchemy relationships diff --git a/entities/api/main.py b/entities/api/main.py index 65c19196..bdc7d858 100644 --- a/entities/api/main.py +++ b/entities/api/main.py @@ -1,18 +1,18 @@ +import uuid import typing - -import database.models as db import strawberry import uvicorn +import database.models as db from cerbos.sdk.client import CerbosClient from cerbos.sdk.model import Principal -from database.connect import AsyncDB from fastapi import Depends, FastAPI from strawberry.fastapi import GraphQLRouter +from database.connect import AsyncDB from thirdparty.strawberry_sqlalchemy_mapper import StrawberrySQLAlchemyMapper - +from api.core.gql_loaders import EntityLoader, get_base_loader, create_entity from api.core.deps import get_auth_principal, get_cerbos_client, get_engine -from api.core.gql_loaders import EntityLoader, get_base_loader from api.core.settings import APISettings +from api.core.strawberry_extensions import DependencyExtension ###################### # Strawberry-GraphQL # @@ -36,12 +36,63 @@ class SequencingRead: pass +# -------------------- +# Queries +# -------------------- + + @strawberry.type class Query: samples: typing.List[Sample] = get_base_loader(db.Sample, Sample) sequencing_reads: typing.List[SequencingRead] = get_base_loader(db.SequencingRead, SequencingRead) +# -------------------- +# Mutations +# -------------------- + + +@strawberry.type +class Mutation: + @strawberry.mutation(extensions=[DependencyExtension()]) + async def create_sample( + self, + name: str, + location: str, + collection_id: int, + create_entity: typing.Callable = Depends(create_entity), + ) -> Sample: + if not name or not location: + raise Exception("Fields cannot be empty") + params = dict(name=name, location=location, collection_id=collection_id) + return await create_entity(entity_model=db.Sample, gql_type=Sample, params=params) + + # FIXME: add auth in gql_loaders.py + @strawberry.mutation(extensions=[DependencyExtension()]) + async def create_sequencing_read( + self, + nucleotide: str, + sequence: str, + protocol: str, + sample_id: uuid.UUID, + collection_id: int, + create_entity: typing.Callable = Depends(create_entity), + ) -> SequencingRead: + params = dict( + nucleotide=nucleotide, + sequence=sequence, + protocol=protocol, + sample_id=sample_id, + collection_id=collection_id, + ) + return await create_entity(entity_model=db.SequencingRead, gql_type=SequencingRead, params=params) + + +# -------------------- +# Initialize app +# -------------------- + + def get_context( engine: AsyncDB = Depends(get_engine), cerbos_client: CerbosClient = Depends(get_cerbos_client), @@ -62,8 +113,7 @@ def get_context( # start server with strawberry server app schema = strawberry.Schema( query=Query, - # mutation=Mutation, - # extensions=extensions, + mutation=Mutation, types=additional_types, ) diff --git a/entities/api/schema.graphql b/entities/api/schema.graphql index d215f80f..bfa67b2f 100644 --- a/entities/api/schema.graphql +++ b/entities/api/schema.graphql @@ -1,30 +1,53 @@ +interface EntityInterface { + id: UUID! + type: String! + producingRunId: Int + ownerUserId: Int! + collectionId: Int! +} + +type Mutation { + createSample(name: String!, location: String!, collectionId: Int!): Sample! + createSequencingRead(nucleotide: String!, sequence: String!, protocol: String!, sampleId: UUID!, collectionId: Int!): SequencingRead! +} + type Query { - getSample(id: ID!): Sample! - getAllSamples: [Sample!]! - getSequencingRead(id: ID!): SequencingRead! - getAllSequencingReads: [SequencingRead!]! + samples(id: UUID = null): [Sample!]! + sequencingReads(id: UUID = null): [SequencingRead!]! } type Sample { - id: Int! + id: UUID! type: String! producingRunId: Int - ownerUserId: Int - entityId: Int! + ownerUserId: Int! + collectionId: Int! + entityId: UUID! name: String! location: String! - sequencingReads: [SequencingRead!]! + sequencingReads: SequencingReadConnection! } type SequencingRead { - id: Int! + id: UUID! type: String! producingRunId: Int - ownerUserId: Int - entityId: Int! + ownerUserId: Int! + collectionId: Int! + entityId: UUID! nucleotide: String! sequence: String! protocol: String! - sampleId: Int! + sampleId: UUID! sample: Sample! } + +type SequencingReadConnection { + edges: [SequencingReadEdge!]! +} + +type SequencingReadEdge { + node: SequencingRead! +} + +scalar UUID diff --git a/entities/api/schema.json b/entities/api/schema.json index 591bae1b..e3adc714 100644 --- a/entities/api/schema.json +++ b/entities/api/schema.json @@ -90,7 +90,9 @@ "name": "specifiedBy" } ], - "mutationType": null, + "mutationType": { + "name": "Mutation" + }, "queryType": { "name": "Query" }, @@ -592,6 +594,148 @@ "name": "Query", "possibleTypes": null }, + { + "enumValues": null, + "fields": [ + { + "args": [ + { + "defaultValue": null, + "name": "name", + "type": { + "kind": "NON_NULL", + "name": null, + "ofType": { + "kind": "SCALAR", + "name": "String", + "ofType": null + } + } + }, + { + "defaultValue": null, + "name": "location", + "type": { + "kind": "NON_NULL", + "name": null, + "ofType": { + "kind": "SCALAR", + "name": "String", + "ofType": null + } + } + }, + { + "defaultValue": null, + "name": "collectionId", + "type": { + "kind": "NON_NULL", + "name": null, + "ofType": { + "kind": "SCALAR", + "name": "Int", + "ofType": null + } + } + } + ], + "name": "createSample", + "type": { + "kind": "NON_NULL", + "name": null, + "ofType": { + "kind": "OBJECT", + "name": "Sample", + "ofType": null + } + } + }, + { + "args": [ + { + "defaultValue": null, + "name": "nucleotide", + "type": { + "kind": "NON_NULL", + "name": null, + "ofType": { + "kind": "SCALAR", + "name": "String", + "ofType": null + } + } + }, + { + "defaultValue": null, + "name": "sequence", + "type": { + "kind": "NON_NULL", + "name": null, + "ofType": { + "kind": "SCALAR", + "name": "String", + "ofType": null + } + } + }, + { + "defaultValue": null, + "name": "protocol", + "type": { + "kind": "NON_NULL", + "name": null, + "ofType": { + "kind": "SCALAR", + "name": "String", + "ofType": null + } + } + }, + { + "defaultValue": null, + "name": "sampleId", + "type": { + "kind": "NON_NULL", + "name": null, + "ofType": { + "kind": "SCALAR", + "name": "UUID", + "ofType": null + } + } + }, + { + "defaultValue": null, + "name": "collectionId", + "type": { + "kind": "NON_NULL", + "name": null, + "ofType": { + "kind": "SCALAR", + "name": "Int", + "ofType": null + } + } + } + ], + "name": "createSequencingRead", + "type": { + "kind": "NON_NULL", + "name": null, + "ofType": { + "kind": "OBJECT", + "name": "SequencingRead", + "ofType": null + } + } + } + ], + "inputFields": null, + "interfaces": [], + "kind": "OBJECT", + "name": "Mutation", + "possibleTypes": null + }, { "enumValues": null, "fields": null, diff --git a/entities/cerbos/policies/sample.yaml b/entities/cerbos/policies/sample.yaml index 84f6ebca..230403dc 100644 --- a/entities/cerbos/policies/sample.yaml +++ b/entities/cerbos/policies/sample.yaml @@ -6,7 +6,7 @@ resourcePolicy: - common_roles resource: "sample" rules: - - actions: ['*'] + - actions: ['view', 'create'] effect: EFFECT_ALLOW derivedRoles: - project_member diff --git a/entities/cli/gql_schema.py b/entities/cli/gql_schema.py index 5c6ce057..6e8a2601 100644 --- a/entities/cli/gql_schema.py +++ b/entities/cli/gql_schema.py @@ -28,13 +28,7 @@ class UUID(sgqlc.types.Scalar): ######################################################################## class EntityInterface(sgqlc.types.Interface): __schema__ = gql_schema - __field_names__ = ( - "id", - "type", - "producing_run_id", - "owner_user_id", - "collection_id", - ) + __field_names__ = ("id", "type", "producing_run_id", "owner_user_id", "collection_id") id = sgqlc.types.Field(sgqlc.types.non_null(UUID), graphql_name="id") type = sgqlc.types.Field(sgqlc.types.non_null(String), graphql_name="type") producing_run_id = sgqlc.types.Field(Int, graphql_name="producingRunId") @@ -42,6 +36,41 @@ class EntityInterface(sgqlc.types.Interface): collection_id = sgqlc.types.Field(sgqlc.types.non_null(Int), graphql_name="collectionId") +class Mutation(sgqlc.types.Type): + __schema__ = gql_schema + __field_names__ = ("create_sample", "create_sequencing_read") + create_sample = sgqlc.types.Field( + sgqlc.types.non_null("Sample"), + graphql_name="createSample", + args=sgqlc.types.ArgDict( + ( + ("name", sgqlc.types.Arg(sgqlc.types.non_null(String), graphql_name="name", default=None)), + ("location", sgqlc.types.Arg(sgqlc.types.non_null(String), graphql_name="location", default=None)), + ( + "collection_id", + sgqlc.types.Arg(sgqlc.types.non_null(Int), graphql_name="collectionId", default=None), + ), + ) + ), + ) + create_sequencing_read = sgqlc.types.Field( + sgqlc.types.non_null("SequencingRead"), + graphql_name="createSequencingRead", + args=sgqlc.types.ArgDict( + ( + ("nucleotide", sgqlc.types.Arg(sgqlc.types.non_null(String), graphql_name="nucleotide", default=None)), + ("sequence", sgqlc.types.Arg(sgqlc.types.non_null(String), graphql_name="sequence", default=None)), + ("protocol", sgqlc.types.Arg(sgqlc.types.non_null(String), graphql_name="protocol", default=None)), + ("sample_id", sgqlc.types.Arg(sgqlc.types.non_null(UUID), graphql_name="sampleId", default=None)), + ( + "collection_id", + sgqlc.types.Arg(sgqlc.types.non_null(Int), graphql_name="collectionId", default=None), + ), + ) + ), + ) + + class Query(sgqlc.types.Type): __schema__ = gql_schema __field_names__ = ("samples", "sequencing_reads") @@ -115,8 +144,7 @@ class SequencingReadConnection(sgqlc.types.Type): __schema__ = gql_schema __field_names__ = ("edges",) edges = sgqlc.types.Field( - sgqlc.types.non_null(sgqlc.types.list_of(sgqlc.types.non_null("SequencingReadEdge"))), - graphql_name="edges", + sgqlc.types.non_null(sgqlc.types.list_of(sgqlc.types.non_null("SequencingReadEdge"))), graphql_name="edges" ) @@ -134,5 +162,5 @@ class SequencingReadEdge(sgqlc.types.Type): # Schema Entry Points ######################################################################## gql_schema.query_type = Query -gql_schema.mutation_type = None +gql_schema.mutation_type = Mutation gql_schema.subscription_type = None diff --git a/entities/database/models/base.py b/entities/database/models/base.py index 2fbe8232..774b9f36 100644 --- a/entities/database/models/base.py +++ b/entities/database/models/base.py @@ -29,7 +29,7 @@ class Entity(Base): # sequencing_read, etc) type: Mapped[str] - # Example attributes for every entity (TODO: revisit nullable columns later) + # Attributes for each entity producing_run_id = Column(Integer, nullable=True) owner_user_id = Column(Integer, nullable=False) collection_id = Column(Integer, nullable=False) diff --git a/entities/tests/test_examples_gql.py b/entities/tests/test_examples_gql.py index c035f5b1..cf165de0 100644 --- a/entities/tests/test_examples_gql.py +++ b/entities/tests/test_examples_gql.py @@ -22,24 +22,9 @@ async def test_graphql_query( # Create mock data with sync_db.session() as session: fa.SessionStorage.set_session(session) - fa.SampleFactory.create_batch( - 2, - location="San Francisco, CA", - owner_user_id=user_id, - collection_id=project_id, - ) - fa.SampleFactory.create_batch( - 6, - location="Mountain View, CA", - owner_user_id=user_id, - collection_id=project_id, - ) - fa.SampleFactory.create_batch( - 4, - location="Phoenix, AZ", - owner_user_id=secondary_user_id, - collection_id=9999, - ) + fa.SampleFactory.create_batch(2, location="San Francisco, CA", owner_user_id=user_id, collection_id=project_id) + fa.SampleFactory.create_batch(6, location="Mountain View, CA", owner_user_id=user_id, collection_id=project_id) + fa.SampleFactory.create_batch(4, location="Phoenix, AZ", owner_user_id=secondary_user_id, collection_id=9999) # Fetch all samples query = """ @@ -57,12 +42,41 @@ async def test_graphql_query( "member_projects": json.dumps([project_id]), "user_id": str(user_id), } - result = await http_client.post( - "/graphql", - json=request, - headers=headers, - ) + result = await http_client.post("/graphql", json=request, headers=headers) output = result.json() assert output["data"]["samples"][0]["location"] == "San Francisco, CA" assert output["data"]["samples"][-1]["location"] == "Mountain View, CA" assert len(output["data"]["samples"]) == 8 + + +# Validate that can only create samples in collections the user has access to +@pytest.mark.asyncio +async def test_graphql_create_sample( + http_client: AsyncClient, +): + project_id_allowed = 123 + project_id_not_allowed = 456 + query = """ + mutation CreateASample { + createSample(name: "Test Sample", location: "San Francisco, CA", collectionId: 123) { + id, + location + } + } + """ + request = {"operationName": "CreateASample", "query": query} + + for project_id in [project_id_allowed, project_id_not_allowed]: + headers = { + "content-type": "application/json", + "accept": "application/json", + "member_projects": json.dumps([project_id]), + "user_id": "111", + } + result = await http_client.post("/graphql", json=request, headers=headers) + output = result.json() + if project_id == project_id_allowed: + assert output["data"]["createSample"]["location"] == "San Francisco, CA" + else: + assert output["data"] is None + assert output["errors"][0]["message"] == "Unauthorized"