Skip to content

Commit

Permalink
Iniital support for file metadata. (#38)
Browse files Browse the repository at this point in the history
  • Loading branch information
jgadling authored Aug 29, 2023
1 parent d8f4f76 commit c37ad46
Show file tree
Hide file tree
Showing 11 changed files with 268 additions and 34 deletions.
77 changes: 63 additions & 14 deletions entities/api/core/gql_loaders.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,23 @@
import uuid
import typing
import strawberry
import database.models as db
import uuid
from collections import defaultdict
from typing import Any, Mapping, Tuple, Optional
from sqlalchemy import ColumnElement, ColumnExpressionArgument, tuple_
from sqlalchemy.orm import RelationshipProperty
from sqlalchemy.ext.asyncio import AsyncSession
from strawberry.type import StrawberryType
from strawberry.dataloader import DataLoader
from strawberry.arguments import StrawberryArgument
from typing import Any, Mapping, Optional, Tuple

import database.models as db
import strawberry
from api.core.deps import get_cerbos_client, get_db_session, require_auth_principal
from api.core.strawberry_extensions import DependencyExtension
from cerbos.sdk.client import CerbosClient
from cerbos.sdk.model import Principal, Resource, ResourceDesc
from fastapi import Depends
from database.models import Base
from fastapi import Depends
from sqlalchemy import ColumnElement, ColumnExpressionArgument, tuple_
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import RelationshipProperty
from strawberry.arguments import StrawberryArgument
from strawberry.dataloader import DataLoader
from strawberry.type import StrawberryType
from thirdparty.cerbos_sqlalchemy.query import get_query
from api.core.deps import require_auth_principal, get_cerbos_client, get_db_session
from api.core.strawberry_extensions import DependencyExtension

CERBOS_ACTION_VIEW = "view"
CERBOS_ACTION_CREATE = "create"
Expand Down Expand Up @@ -50,6 +51,33 @@ async def get_entities(
return result.scalars().all()


async def get_files(
model: db.File,
session: AsyncSession,
cerbos_client: CerbosClient,
principal: Principal,
filters: Optional[list[ColumnExpressionArgument]] = [],
order_by: Optional[list[tuple[ColumnElement[Any], ...]]] = [],
):
rd = ResourceDesc(model.__tablename__)
plan = cerbos_client.plan_resources(CERBOS_ACTION_VIEW, principal, rd)
query = get_query(
plan,
model,
{
"request.resource.attr.owner_user_id": db.Entity.owner_user_id,
"request.resource.attr.collection_id": db.Entity.collection_id,
},
[(db.Entity, model.entity_id == db.Entity.id)],
)
if filters:
query = query.filter(*filters)
if order_by:
query = query.order_by(*order_by)
result = await session.execute(query)
return result.scalars().all()


class EntityLoader:
"""
Creates DataLoader instances on-the-fly for SQLAlchemy relationships
Expand All @@ -72,6 +100,11 @@ def loader_for(self, relationship: RelationshipProperty) -> DataLoader:
except KeyError:
related_model = relationship.entity.entity

if type(related_model) == db.File:
load_method = get_files
else:
load_method = get_entities

async def load_fn(keys: list[Tuple]) -> list[Any]:
if not relationship.local_remote_pairs:
raise Exception("invalid relationship")
Expand All @@ -80,7 +113,7 @@ async def load_fn(keys: list[Tuple]) -> list[Any]:
if relationship.order_by:
order_by = [relationship.order_by]
db_session = self.engine.session()
rows = await get_entities(
rows = await load_method(
related_model,
db_session,
self.cerbos_client,
Expand Down Expand Up @@ -125,6 +158,22 @@ async def resolve_entity(
return resolve_entity


def get_file_loader(sql_model, gql_type):
@strawberry.field(extensions=[DependencyExtension()])
async def resolve_file(
id: typing.Optional[uuid.UUID] = None,
session: AsyncSession = Depends(get_db_session, use_cache=False),
cerbos_client: CerbosClient = Depends(get_cerbos_client),
principal: Principal = Depends(require_auth_principal),
) -> list[Base]:
filters = []
if id:
filters.append(sql_model.entity_id == id)
return await get_files(sql_model, session, cerbos_client, principal, filters, [])

return resolve_file


def get_base_creator(sql_model, gql_type):
@strawberry.mutation(extensions=[DependencyExtension()])
async def create(
Expand Down
26 changes: 23 additions & 3 deletions entities/api/main.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,23 @@
import typing

import database.models as db
import strawberry
import uvicorn
import database.models as db
from cerbos.sdk.client import CerbosClient
from cerbos.sdk.model import Principal
from database.connect import AsyncDB
from fastapi import Depends, FastAPI
from strawberry.fastapi import GraphQLRouter
from database.connect import AsyncDB
from thirdparty.strawberry_sqlalchemy_mapper import StrawberrySQLAlchemyMapper
from api.core.gql_loaders import EntityLoader, get_base_loader, get_base_updater, get_base_creator

from api.core.deps import get_auth_principal, get_cerbos_client, get_engine
from api.core.gql_loaders import (
EntityLoader,
get_base_creator,
get_base_loader,
get_base_updater,
get_file_loader,
)
from api.core.settings import APISettings

######################
Expand All @@ -24,6 +32,11 @@ class EntityInterface:
pass


@strawberry_sqlalchemy_mapper.type(db.Entity)
class Entity:
pass


@strawberry_sqlalchemy_mapper.type(db.Sample)
class Sample:
pass
Expand All @@ -34,15 +47,22 @@ class SequencingRead:
pass


@strawberry_sqlalchemy_mapper.type(db.File)
class File:
pass


# --------------------
# Queries
# --------------------


@strawberry.type
class Query:
entity: typing.List[Sample] = get_base_loader(db.Entity, EntityInterface)
samples: typing.List[Sample] = get_base_loader(db.Sample, Sample)
sequencing_reads: typing.List[SequencingRead] = get_base_loader(db.SequencingRead, SequencingRead)
files: typing.List[File] = get_file_loader(db.File, File)


# --------------------
Expand Down
20 changes: 20 additions & 0 deletions entities/cerbos/policies/file.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# yaml-language-server: $schema=https://api.cerbos.dev/latest/cerbos/policy/v1/Policy.schema.json
apiVersion: api.cerbos.dev/v1
resourcePolicy:
version: "default"
importDerivedRoles:
- common_roles
resource: "file"
rules:
- actions: ['*']
effect: EFFECT_ALLOW
derivedRoles:
- project_member

- actions: ['download']
effect: EFFECT_ALLOW
derivedRoles:
- owner
schemas:
principalSchema:
ref: cerbos:///principal.json
1 change: 1 addition & 0 deletions entities/database/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,6 @@

from database.models.base import Base, meta, Entity # noqa: F401
from database.models.samples import Sample, SequencingRead # noqa: F401
from database.models.files import File # noqa: F401

# configure_mappers()
19 changes: 13 additions & 6 deletions entities/database/models/base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
from sqlalchemy.orm import DeclarativeBase, Mapped
from sqlalchemy import MetaData, Column, Integer
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
from sqlalchemy import MetaData, Column, Integer, String
import uuid6
import uuid
from sqlalchemy.dialects.postgresql import UUID
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from database.models.files import File
else:
File = "File"

meta = MetaData(
naming_convention={
Expand All @@ -27,9 +33,10 @@ class Entity(Base):

# The "type" field distinguishes between subclasses (e.g. sample,
# sequencing_read, etc)
type: Mapped[str]
type: Mapped[str] = mapped_column(String, nullable=False)

# Attributes for each entity
producing_run_id = Column(Integer, nullable=True)
owner_user_id = Column(Integer, nullable=False)
collection_id = Column(Integer, nullable=False)
producing_run_id: Mapped[uuid.UUID] = mapped_column(Integer, nullable=True)
owner_user_id: Mapped[int] = mapped_column(Integer, nullable=False)
collection_id: Mapped[int] = mapped_column(Integer, nullable=False)
files: Mapped[list[File]] = relationship(File, back_populates="entity", foreign_keys="File.entity_id")
31 changes: 31 additions & 0 deletions entities/database/models/files.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import uuid

import uuid6
from database.models.base import Base, Entity
from sqlalchemy import Column, ForeignKey, Integer, String
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import mapped_column, Mapped, relationship


class File(Base):
__tablename__ = "file"

id: Column[uuid.UUID] = Column(UUID(as_uuid=True), primary_key=True, default=uuid6.uuid7)

# TODO - the relationship between Entities and Files is currently being
# configured in both directions: entities have {fieldname}_file_id fields,
# *and* files have {entity_id, field_name} fields to map back to
# entities. We'll probably deprecate one side of this relationship in
# the future, but I'm not sure yet which one is going to prove to be
# more useful.
entity_id = mapped_column(ForeignKey("entity.id"))
entity_field_name = Column(String)
entity: Mapped[Entity] = relationship(Entity, back_populates="files", foreign_keys=entity_id)

status = mapped_column(String, nullable=False)
protocol = mapped_column(String, nullable=False)
namespace = Column(String, nullable=False)
path = Column(String, nullable=False)
file_format = Column(String, nullable=False)
compression_type = Column(String, nullable=False)
size = Column(Integer, nullable=False)
27 changes: 18 additions & 9 deletions entities/database/models/samples.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,23 @@
from database.models.base import Entity
from sqlalchemy import Column, ForeignKey, String
from sqlalchemy import ForeignKey, String
from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.dialects.postgresql import UUID
from typing import TYPE_CHECKING
import uuid

if TYPE_CHECKING:
from database.models.files import File
else:
File = "File"


class Sample(Entity):
__tablename__ = "sample"
__mapper_args__ = {"polymorphic_identity": __tablename__}

entity_id = mapped_column(ForeignKey("entity.id"), primary_key=True)
name = Column(String, nullable=False)
location = Column(String, nullable=False)
entity_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("entity.id"), primary_key=True)
name: Mapped[str] = mapped_column(String, nullable=False)
location: Mapped[str] = mapped_column(String, nullable=False)

sequencing_reads: Mapped[list["SequencingRead"]] = relationship(
"SequencingRead",
Expand All @@ -23,10 +30,12 @@ class SequencingRead(Entity):
__tablename__ = "sequencing_read"
__mapper_args__ = {"polymorphic_identity": __tablename__}

entity_id = mapped_column(ForeignKey("entity.id"), primary_key=True)
nucleotide = Column(String, nullable=False)
sequence = Column(String, nullable=False)
protocol = Column(String, nullable=False)
entity_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("entity.id"), primary_key=True)
nucleotide: Mapped[str] = mapped_column(String, nullable=False)
sequence: Mapped[str] = mapped_column(String, nullable=False)
protocol: Mapped[str] = mapped_column(String, nullable=False)
sequence_file_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("file.id"), nullable=True)
sequence_file: Mapped[File] = relationship("File", foreign_keys=sequence_file_id)

sample_id = Column(UUID, ForeignKey("sample.entity_id"), nullable=False)
sample_id: Mapped[uuid.UUID] = mapped_column(UUID, ForeignKey("sample.entity_id"), nullable=False)
sample: Mapped[Sample] = relationship("Sample", back_populates="sequencing_reads", foreign_keys=sample_id)
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
"""create files table
Create Date: 2023-08-28 20:17:44.359982
"""
import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision = "20230828_131743"
down_revision = "20230809_181634"
branch_labels = None
depends_on = None


def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"file",
sa.Column("id", sa.UUID(), nullable=False),
sa.Column("entity_id", sa.UUID(), nullable=True),
sa.Column("entity_field_name", sa.String(), nullable=True),
sa.Column("protocol", sa.String(), nullable=False),
sa.Column("namespace", sa.String(), nullable=False),
sa.Column("path", sa.String(), nullable=False),
sa.Column("file_format", sa.String(), nullable=False),
sa.Column("compression_type", sa.String(), nullable=False),
sa.Column("size", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(["entity_id"], ["entity.id"], name=op.f("fk_file_entity_id_entity")),
sa.PrimaryKeyConstraint("id", name=op.f("pk_file")),
)
op.add_column("sequencing_read", sa.Column("sequence_file_id", sa.UUID(), nullable=True))
op.create_foreign_key(
op.f("fk_sequencing_read_sequence_file_id_file"), "sequencing_read", "file", ["sequence_file_id"], ["id"]
)
# ### end Alembic commands ###


def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_constraint(op.f("fk_sequencing_read_sequence_file_id_file"), "sequencing_read", type_="foreignkey")
op.drop_column("sequencing_read", "sequence_file_id")
op.drop_table("file")
# ### end Alembic commands ###
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
"""add file status
Create Date: 2023-08-28 21:12:41.448938
"""
import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision = "20230828_141240"
down_revision = "20230828_131743"
branch_labels = None
depends_on = None


def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("file", sa.Column("status", sa.String(), nullable=False))
# ### end Alembic commands ###


def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("file", "status")
# ### end Alembic commands ###
Loading

0 comments on commit c37ad46

Please sign in to comment.