diff --git a/entities/api/core/gql_loaders.py b/entities/api/core/gql_loaders.py
index 353ec45d..94364468 100644
--- a/entities/api/core/gql_loaders.py
+++ b/entities/api/core/gql_loaders.py
@@ -1,22 +1,23 @@
-import uuid
 import typing
-import strawberry
-import database.models as db
+import uuid
 from collections import defaultdict
-from typing import Any, Mapping, Tuple, Optional
-from sqlalchemy import ColumnElement, ColumnExpressionArgument, tuple_
-from sqlalchemy.orm import RelationshipProperty
-from sqlalchemy.ext.asyncio import AsyncSession
-from strawberry.type import StrawberryType
-from strawberry.dataloader import DataLoader
-from strawberry.arguments import StrawberryArgument
+from typing import Any, Mapping, Optional, Tuple
+
+import database.models as db
+import strawberry
+from api.core.deps import get_cerbos_client, get_db_session, require_auth_principal
+from api.core.strawberry_extensions import DependencyExtension
 from cerbos.sdk.client import CerbosClient
 from cerbos.sdk.model import Principal, Resource, ResourceDesc
-from fastapi import Depends
 from database.models import Base
+from fastapi import Depends
+from sqlalchemy import ColumnElement, ColumnExpressionArgument, tuple_
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy.orm import RelationshipProperty
+from strawberry.arguments import StrawberryArgument
+from strawberry.dataloader import DataLoader
+from strawberry.type import StrawberryType
 from thirdparty.cerbos_sqlalchemy.query import get_query
-from api.core.deps import require_auth_principal, get_cerbos_client, get_db_session
-from api.core.strawberry_extensions import DependencyExtension
 
 CERBOS_ACTION_VIEW = "view"
 CERBOS_ACTION_CREATE = "create"
@@ -50,6 +51,39 @@ async def get_entities(
     return result.scalars().all()
 
 
+async def get_files(
+    model: db.File,
+    session: AsyncSession,
+    cerbos_client: CerbosClient,
+    principal: Principal,
+    filters: Optional[list[ColumnExpressionArgument]] = None,
+    order_by: Optional[list[tuple[ColumnElement[Any], ...]]] = None,
+):
+    """
+    Fetch the File rows that `principal` is allowed to view.
+
+    File has no owner_user_id/collection_id columns, so the Cerbos query plan
+    is mapped onto the owning Entity, joined via File.entity_id.
+    """
+    rd = ResourceDesc(model.__tablename__)
+    plan = cerbos_client.plan_resources(CERBOS_ACTION_VIEW, principal, rd)
+    query = get_query(
+        plan,
+        model,
+        {
+            "request.resource.attr.owner_user_id": db.Entity.owner_user_id,
+            "request.resource.attr.collection_id": db.Entity.collection_id,
+        },
+        [(db.Entity, model.entity_id == db.Entity.id)],
+    )
+    if filters:
+        query = query.filter(*filters)
+    if order_by:
+        query = query.order_by(*order_by)
+    result = await session.execute(query)
+    return result.scalars().all()
+
+
 class EntityLoader:
     """
     Creates DataLoader instances on-the-fly for SQLAlchemy relationships
@@ -72,6 +106,11 @@ def loader_for(self, relationship: RelationshipProperty) -> DataLoader:
         except KeyError:
             related_model = relationship.entity.entity
 
+            if related_model is db.File:  # type(cls) is the metaclass, never db.File
+                load_method = get_files
+            else:
+                load_method = get_entities
+
             async def load_fn(keys: list[Tuple]) -> list[Any]:
                 if not relationship.local_remote_pairs:
                     raise Exception("invalid relationship")
@@ -80,7 +119,7 @@ async def load_fn(keys: list[Tuple]) -> list[Any]:
                 if relationship.order_by:
                     order_by = [relationship.order_by]
                 db_session = self.engine.session()
-                rows = await get_entities(
+                rows = await load_method(
                     related_model,
                     db_session,
                     self.cerbos_client,
@@ -125,6 +164,24 @@ async def resolve_entity(
     return resolve_entity
 
 
+def get_file_loader(sql_model, gql_type):
+    """Build a strawberry field listing viewable Files, optionally filtered by entity id."""
+
+    @strawberry.field(extensions=[DependencyExtension()])
+    async def resolve_file(
+        id: typing.Optional[uuid.UUID] = None,
+        session: AsyncSession = Depends(get_db_session, use_cache=False),
+        cerbos_client: CerbosClient = Depends(get_cerbos_client),
+        principal: Principal = Depends(require_auth_principal),
+    ) -> list[Base]:
+        filters = []
+        if id:
+            filters.append(sql_model.entity_id == id)
+        return await get_files(sql_model, session, cerbos_client, principal, filters, [])
+
+    return resolve_file
+
+
 def get_base_creator(sql_model, gql_type):
     @strawberry.mutation(extensions=[DependencyExtension()])
     async def create(
diff --git a/entities/api/main.py b/entities/api/main.py
index a0bbe189..f0b30be2 100644
--- a/entities/api/main.py
+++ b/entities/api/main.py
@@ -1,15 +1,23 @@
 import typing
+
+import database.models as db
 import strawberry
 import uvicorn
-import database.models as db
 from cerbos.sdk.client import CerbosClient
 from cerbos.sdk.model import Principal
+from database.connect import AsyncDB
 from fastapi import Depends, FastAPI
 from strawberry.fastapi import GraphQLRouter
-from database.connect import AsyncDB
 from thirdparty.strawberry_sqlalchemy_mapper import StrawberrySQLAlchemyMapper
-from api.core.gql_loaders import EntityLoader, get_base_loader, get_base_updater, get_base_creator
+
 from api.core.deps import get_auth_principal, get_cerbos_client, get_engine
+from api.core.gql_loaders import (
+    EntityLoader,
+    get_base_creator,
+    get_base_loader,
+    get_base_updater,
+    get_file_loader,
+)
 from api.core.settings import APISettings
 
 ######################
@@ -24,6 +32,11 @@ class EntityInterface:
     pass
 
 
+@strawberry_sqlalchemy_mapper.type(db.Entity)
+class Entity:
+    pass
+
+
 @strawberry_sqlalchemy_mapper.type(db.Sample)
 class Sample:
     pass
@@ -34,6 +47,11 @@ class SequencingRead:
     pass
 
 
+@strawberry_sqlalchemy_mapper.type(db.File)
+class File:
+    pass
+
+
 # --------------------
 # Queries
 # --------------------
@@ -41,8 +59,10 @@ class SequencingRead:
 
 @strawberry.type
 class Query:
+    entity: typing.List[Sample] = get_base_loader(db.Entity, EntityInterface)
     samples: typing.List[Sample] = get_base_loader(db.Sample, Sample)
     sequencing_reads: typing.List[SequencingRead] = get_base_loader(db.SequencingRead, SequencingRead)
+    files: typing.List[File] = get_file_loader(db.File, File)
 
 
 # --------------------
diff --git a/entities/cerbos/policies/file.yaml b/entities/cerbos/policies/file.yaml
new file mode 100644
index 00000000..bf31abf6
--- /dev/null
+++ b/entities/cerbos/policies/file.yaml
@@ -0,0 +1,20 @@
+# yaml-language-server: $schema=https://api.cerbos.dev/latest/cerbos/policy/v1/Policy.schema.json
+apiVersion: api.cerbos.dev/v1
+resourcePolicy:
+  version: "default"
+  importDerivedRoles:
+    - common_roles
+  resource: "file"
+  rules:
+    - actions: ['*']
+      effect: EFFECT_ALLOW
+      derivedRoles:
+        - project_member
+
+    - actions: ['download']
+      effect: EFFECT_ALLOW
+      derivedRoles:
+        - owner
+  schemas:
+    principalSchema:
+      ref: cerbos:///principal.json
diff --git a/entities/database/models/__init__.py b/entities/database/models/__init__.py
index ff4263e2..5563f260 100644
--- a/entities/database/models/__init__.py
+++ b/entities/database/models/__init__.py
@@ -3,5 +3,6 @@
 
 from database.models.base import Base, meta, Entity  # noqa: F401
 from database.models.samples import Sample, SequencingRead  # noqa: F401
+from database.models.files import File  # noqa: F401
 
 # configure_mappers()
diff --git a/entities/database/models/base.py b/entities/database/models/base.py
index 774b9f36..0e91c575 100644
--- a/entities/database/models/base.py
+++ b/entities/database/models/base.py
@@ -1,8 +1,14 @@
-from sqlalchemy.orm import DeclarativeBase, Mapped
-from sqlalchemy import MetaData, Column, Integer
+from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
+from sqlalchemy import MetaData, Column, Integer, String
 import uuid6
 import uuid
 from sqlalchemy.dialects.postgresql import UUID
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from database.models.files import File
+else:
+    File = "File"
 
 meta = MetaData(
     naming_convention={
@@ -27,9 +33,10 @@ class Entity(Base):
 
     # The "type" field distinguishes between subclasses (e.g. sample,
     # sequencing_read, etc)
-    type: Mapped[str]
+    type: Mapped[str] = mapped_column(String, nullable=False)
 
     # Attributes for each entity
-    producing_run_id = Column(Integer, nullable=True)
-    owner_user_id = Column(Integer, nullable=False)
-    collection_id = Column(Integer, nullable=False)
+    producing_run_id: Mapped[int] = mapped_column(Integer, nullable=True)
+    owner_user_id: Mapped[int] = mapped_column(Integer, nullable=False)
+    collection_id: Mapped[int] = mapped_column(Integer, nullable=False)
+    files: Mapped[list[File]] = relationship(File, back_populates="entity", foreign_keys="File.entity_id")
diff --git a/entities/database/models/files.py b/entities/database/models/files.py
new file mode 100644
index 00000000..8f7a0ff4
--- /dev/null
+++ b/entities/database/models/files.py
@@ -0,0 +1,31 @@
+import uuid
+
+import uuid6
+from database.models.base import Base, Entity
+from sqlalchemy import Column, ForeignKey, Integer, String
+from sqlalchemy.dialects.postgresql import UUID
+from sqlalchemy.orm import mapped_column, Mapped, relationship
+
+
+class File(Base):
+    __tablename__ = "file"
+
+    id: Column[uuid.UUID] = Column(UUID(as_uuid=True), primary_key=True, default=uuid6.uuid7)
+
+    # TODO - the relationship between Entities and Files is currently being
+    # configured in both directions: entities have {fieldname}_file_id fields,
+    # *and* files have {entity_id, field_name} fields to map back to
+    # entities. We'll probably deprecate one side of this relationship in
+    # the future, but I'm not sure yet which one is going to prove to be
+    # more useful.
+    entity_id = mapped_column(ForeignKey("entity.id"))
+    entity_field_name = Column(String)
+    entity: Mapped[Entity] = relationship(Entity, back_populates="files", foreign_keys=entity_id)
+
+    status = mapped_column(String, nullable=False)
+    protocol = mapped_column(String, nullable=False)
+    namespace = Column(String, nullable=False)
+    path = Column(String, nullable=False)
+    file_format = Column(String, nullable=False)
+    compression_type = Column(String, nullable=False)
+    size = Column(Integer, nullable=False)
diff --git a/entities/database/models/samples.py b/entities/database/models/samples.py
index b64c6108..5e33ac9e 100644
--- a/entities/database/models/samples.py
+++ b/entities/database/models/samples.py
@@ -1,16 +1,23 @@
 from database.models.base import Entity
-from sqlalchemy import Column, ForeignKey, String
+from sqlalchemy import ForeignKey, String
 from sqlalchemy.orm import Mapped, mapped_column, relationship
 from sqlalchemy.dialects.postgresql import UUID
+from typing import TYPE_CHECKING
+import uuid
+
+if TYPE_CHECKING:
+    from database.models.files import File
+else:
+    File = "File"
 
 
 class Sample(Entity):
     __tablename__ = "sample"
     __mapper_args__ = {"polymorphic_identity": __tablename__}
 
-    entity_id = mapped_column(ForeignKey("entity.id"), primary_key=True)
-    name = Column(String, nullable=False)
-    location = Column(String, nullable=False)
+    entity_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("entity.id"), primary_key=True)
+    name: Mapped[str] = mapped_column(String, nullable=False)
+    location: Mapped[str] = mapped_column(String, nullable=False)
 
     sequencing_reads: Mapped[list["SequencingRead"]] = relationship(
         "SequencingRead",
@@ -23,10 +30,12 @@ class SequencingRead(Entity):
     __tablename__ = "sequencing_read"
     __mapper_args__ = {"polymorphic_identity": __tablename__}
 
-    entity_id = mapped_column(ForeignKey("entity.id"), primary_key=True)
-    nucleotide = Column(String, nullable=False)
-    sequence = Column(String, nullable=False)
-    protocol = Column(String, nullable=False)
+    entity_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("entity.id"), primary_key=True)
+    nucleotide: Mapped[str] = mapped_column(String, nullable=False)
+    sequence: Mapped[str] = mapped_column(String, nullable=False)
+    protocol: Mapped[str] = mapped_column(String, nullable=False)
+    sequence_file_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("file.id"), nullable=True)
+    sequence_file: Mapped[File] = relationship("File", foreign_keys=sequence_file_id)
 
-    sample_id = Column(UUID, ForeignKey("sample.entity_id"), nullable=False)
+    sample_id: Mapped[uuid.UUID] = mapped_column(UUID, ForeignKey("sample.entity_id"), nullable=False)
     sample: Mapped[Sample] = relationship("Sample", back_populates="sequencing_reads", foreign_keys=sample_id)
diff --git a/entities/database_migrations/versions/20230828_131743_create_files_table.py b/entities/database_migrations/versions/20230828_131743_create_files_table.py
new file mode 100644
index 00000000..2c8200eb
--- /dev/null
+++ b/entities/database_migrations/versions/20230828_131743_create_files_table.py
@@ -0,0 +1,44 @@
+"""create files table
+
+Create Date: 2023-08-28 20:17:44.359982
+
+"""
+import sqlalchemy as sa
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision = "20230828_131743"
+down_revision = "20230809_181634"
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.create_table(
+        "file",
+        sa.Column("id", sa.UUID(), nullable=False),
+        sa.Column("entity_id", sa.UUID(), nullable=True),
+        sa.Column("entity_field_name", sa.String(), nullable=True),
+        sa.Column("protocol", sa.String(), nullable=False),
+        sa.Column("namespace", sa.String(), nullable=False),
+        sa.Column("path", sa.String(), nullable=False),
+        sa.Column("file_format", sa.String(), nullable=False),
+        sa.Column("compression_type", sa.String(), nullable=False),
+        sa.Column("size", sa.Integer(), nullable=False),
+        sa.ForeignKeyConstraint(["entity_id"], ["entity.id"], name=op.f("fk_file_entity_id_entity")),
+        sa.PrimaryKeyConstraint("id", name=op.f("pk_file")),
+    )
+    op.add_column("sequencing_read", sa.Column("sequence_file_id", sa.UUID(), nullable=True))
+    op.create_foreign_key(
+        op.f("fk_sequencing_read_sequence_file_id_file"), "sequencing_read", "file", ["sequence_file_id"], ["id"]
+    )
+    # ### end Alembic commands ###
+
+
+def downgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.drop_constraint(op.f("fk_sequencing_read_sequence_file_id_file"), "sequencing_read", type_="foreignkey")
+    op.drop_column("sequencing_read", "sequence_file_id")
+    op.drop_table("file")
+    # ### end Alembic commands ###
diff --git a/entities/database_migrations/versions/20230828_141240_add_file_status.py b/entities/database_migrations/versions/20230828_141240_add_file_status.py
new file mode 100644
index 00000000..bd2b1483
--- /dev/null
+++ b/entities/database_migrations/versions/20230828_141240_add_file_status.py
@@ -0,0 +1,25 @@
+"""add file status
+
+Create Date: 2023-08-28 21:12:41.448938
+
+"""
+import sqlalchemy as sa
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision = "20230828_141240"
+down_revision = "20230828_131743"
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.add_column("file", sa.Column("status", sa.String(), nullable=False))
+    # ### end Alembic commands ###
+
+
+def downgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.drop_column("file", "status")
+    # ### end Alembic commands ###
diff --git a/entities/test_infra/factories.py b/entities/test_infra/factories.py
index 5366504f..9d9a105f 100644
--- a/entities/test_infra/factories.py
+++ b/entities/test_infra/factories.py
@@ -1,10 +1,9 @@
 import factory
+from database.models import File, Sample, SequencingRead
 from factory import Faker, fuzzy
 from faker_biology.bioseq import Bioseq
 from faker_biology.physiology import Organ
 
-from database.models import Sample, SequencingRead
-
 Faker.add_provider(Bioseq)
 Faker.add_provider(Organ)
 
@@ -34,6 +33,27 @@ class Meta:
         sqlalchemy_session = None  # workaround for a bug in factoryboy
 
 
+class FileFactory(factory.alchemy.SQLAlchemyModelFactory):
+    class Meta:
+        sqlalchemy_session_factory = SessionStorage.get_session
+        sqlalchemy_session_persistence = "commit"
+        sqlalchemy_session = None  # workaround for a bug in factoryboy
+        model = File
+        # What fields do we try to match to existing db rows to determine whether we
+        # should create a new row or not?
+        sqlalchemy_get_or_create = ("namespace", "path")
+        # exclude = ("sample_organ",)
+
+    status = fuzzy.FuzzyChoice(["awaiting_upload", "error", "success"])
+    protocol = fuzzy.FuzzyChoice(["S3", "GCP"])
+    namespace = fuzzy.FuzzyChoice(["bucket_1", "bucket_2"])
+    # path = factory.LazyAttribute(lambda o: {factory.Faker("file_path", depth=3, extension=o.file_format)})
+    path = factory.Faker("file_path", depth=3)
+    file_format = fuzzy.FuzzyChoice(["fasta", "fastq", "bam"])
+    compression_type = fuzzy.FuzzyChoice(["gz", "bz2", "xz"])
+    size = fuzzy.FuzzyInteger(1024, 1024 * 1024 * 1024)  # Between 1k and 1G
+
+
 class SampleFactory(CommonFactory):
     class Meta:
         sqlalchemy_session = None  # workaround for a bug in factoryboy
@@ -73,3 +93,10 @@ class Meta:
     sequence = fuzzy.FuzzyText(length=100, chars="ACTG")
     # sequence = factory.Faker('dna', length=100)
     protocol = fuzzy.FuzzyChoice(["TARGETED", "MNGS", "MSSPE"])
+
+    sequencing_read_file = factory.RelatedFactory(
+        FileFactory,
+        factory_related_name="entity",
+        entity_field_name="sequence_file",
+        file_format="fastq",
+    )
diff --git a/entities/thirdparty/strawberry_sqlalchemy_mapper/mapper.py b/entities/thirdparty/strawberry_sqlalchemy_mapper/mapper.py
index 9937508f..361ca469 100644
--- a/entities/thirdparty/strawberry_sqlalchemy_mapper/mapper.py
+++ b/entities/thirdparty/strawberry_sqlalchemy_mapper/mapper.py
@@ -602,6 +602,7 @@ def convert(type_: Any) -> Any:
 
             if make_interface:
                 mapped_type = strawberry.interface(type_)
+                self.mapped_interfaces[type_.__name__] = mapped_type
             elif use_federation:
                 mapped_type = strawberry.federation.type(type_)
             else: