Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CZID-8458] Endpoint to validate files and mark them successfully uploaded #62

Merged
merged 9 commits into from
Sep 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions entities/.dockerignore → .dockerignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,11 @@
**/__pycache__
**/*.egg-info
**/.cache/
**/.vscode/

.tox
.coverage
.pytest_cache
.mypy_cache
**/.tox
**/.coverage
**/.pytest_cache/
**/.mypy_cache/
**/.ruff_cache/
**/.vscode/
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,4 @@ clean:
$(MAKE) -C entities local-clean
$(MAKE) -C workflows local-clean
docker compose down
rm .moto_recording
rm -f .moto_recording
4 changes: 2 additions & 2 deletions bin/init_moto.sh
100644 → 100755
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
#!/bin/bash
#!/usr/bin/bash

# Script to initialize moto server; runs inside the motoserver container

# Launch moto server
moto_server --host 0.0.0.0 --port $MOTO_PORT &

# Initialize data once server is ready
sleep 1 && curl -X POST "http://localhost:${MOTO_PORT}/moto-api/recorder/replay-recording"
sleep 1 && curl -X POST "http://motoserver.czidnet:${MOTO_PORT}/moto-api/recorder/replay-recording"

# Go back to moto server
wait
8 changes: 6 additions & 2 deletions bin/seed_moto.sh
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,9 @@ export AWS_SECRET_ACCESS_KEY=test
export AWS_REGION=us-west-2

# Create dev bucket but don't error if it already exists
bucket=local-bucket
$aws s3api head-bucket --bucket $bucket 2>/dev/null || $aws s3 mb s3://$bucket
bucket_1=local-bucket
bucket_2=remote-bucket
$aws s3api head-bucket --bucket $bucket_1 2>/dev/null || $aws s3 mb s3://$bucket_1
$aws s3api head-bucket --bucket $bucket_2 2>/dev/null || $aws s3 mb s3://$bucket_2
$aws s3 cp entities/test_infra/fixtures/test1.fastq s3://$bucket_1/anything/back/among/population.wav
$aws s3 cp entities/test_infra/fixtures/test1.fastq s3://$bucket_2/remember/offer/radio/result.webm
9 changes: 6 additions & 3 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,19 @@ services:
# To use it from the CLI: aws --endpoint-url=http://localhost:4000 s3 ls
# To reset all services without restarting the container: curl -X POST http://localhost:4000/moto-api/reset
motoserver:
image: motoserver/moto:4.2.2
image: motoserver/moto:4.2.3
ports:
- "4000:4000"
environment:
- MOTO_PORT=4000
- MOTO_ENABLE_RECORDING=True
- S3_IGNORE_SUBDOMAIN_BUCKETNAME=True
- MOTO_S3_CUSTOM_ENDPOINTS=http://motoserver.czidnet:4000
volumes:
- .moto_recording:/moto/moto_recording
- ./bin/init_moto.sh:/moto/init_moto.sh
entrypoint: ["bash", "/moto/init_moto.sh"]
- ./bin:/moto/bin
entrypoint: []
command: ["/moto/bin/init_moto.sh"]

networks:
default:
Expand Down
27 changes: 26 additions & 1 deletion entities/api/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,24 @@
import typing
from typing import Optional

import boto3
import pytest_asyncio
from cerbos.sdk.model import Principal
from platformics.database.connect import AsyncDB
from fastapi import FastAPI
from httpx import AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession
from starlette.requests import Request
from moto import mock_s3
from mypy_boto3_s3.client import S3Client

from platformics.api.core.deps import get_auth_principal, get_db_session, get_engine, require_auth_principal
from platformics.api.core.deps import (
get_auth_principal,
get_db_session,
get_engine,
require_auth_principal,
get_s3_client,
)
from api.main import get_app


Expand Down Expand Up @@ -43,6 +52,21 @@ async def query(
return result.json()


@pytest_asyncio.fixture()
async def moto_client() -> typing.AsyncGenerator[S3Client, None]:
    """Yield an S3 client backed by an in-process moto S3 mock.

    Creates the two buckets the suite expects (local-bucket / remote-bucket)
    before yielding. The mock is stopped in a ``finally`` so state cannot
    leak between tests even when the consuming test raises.
    """
    mocks3 = mock_s3()
    mocks3.start()
    try:
        res = boto3.resource("s3")
        res.create_bucket(Bucket="local-bucket")
        res.create_bucket(Bucket="remote-bucket")
        yield boto3.client("s3")
    finally:
        # Guaranteed teardown; the original skipped stop() if the test failed.
        mocks3.stop()


async def patched_s3_client() -> typing.AsyncGenerator[S3Client, None]:
    """Dependency override yielding a plain boto3 S3 client (no moto setup)."""
    client = boto3.client("s3")
    yield client


@pytest_asyncio.fixture()
async def gql_client(http_client: AsyncClient) -> GQLTestClient:
client = GQLTestClient(http_client)
Expand Down Expand Up @@ -81,6 +105,7 @@ async def patched_session() -> typing.AsyncGenerator[AsyncSession, None]:
api.dependency_overrides[get_db_session] = patched_session
api.dependency_overrides[require_auth_principal] = patched_authprincipal
api.dependency_overrides[get_auth_principal] = patched_authprincipal
api.dependency_overrides[get_s3_client] = patched_s3_client
return api


Expand Down
36 changes: 35 additions & 1 deletion entities/api/files.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,20 @@
import typing
import database.models as db
import strawberry
import uuid
from fastapi import Depends
from mypy_boto3_s3.client import S3Client
from platformics.api.core.deps import get_s3_client
from platformics.api.core.strawberry_extensions import DependencyExtension
from api.strawberry import strawberry_sqlalchemy_mapper

from cerbos.sdk.client import CerbosClient
from cerbos.sdk.model import Principal
from platformics.api.core.deps import get_cerbos_client, get_db_session, require_auth_principal
from sqlalchemy.ext.asyncio import AsyncSession
from platformics.security.authorization import CerbosAction, get_resource_query
from files.format_handlers import get_validator


@strawberry.type
class SignedURL:
Expand All @@ -26,9 +34,35 @@ def download_link(
) -> typing.Optional[SignedURL]:
if not self.path: # type: ignore
return None
key = self.path # type: ignore
key = self.path.lstrip("/") # type: ignore
bucket_name = self.namespace # type: ignore
url = s3_client.generate_presigned_url(
ClientMethod="get_object", Params={"Bucket": bucket_name, "Key": key}, ExpiresIn=expiration
)
return SignedURL(url=url, protocol="https", method="get", expiration=expiration)


@strawberry.mutation(extensions=[DependencyExtension()])
async def mark_upload_complete(
    file_id: uuid.UUID,
    principal: Principal = Depends(require_auth_principal),
    cerbos_client: CerbosClient = Depends(get_cerbos_client),
    session: AsyncSession = Depends(get_db_session, use_cache=False),
    s3_client: S3Client = Depends(get_s3_client),
) -> db.File:
    """Validate an uploaded S3 object and record the outcome on its File row.

    Looks up the File (scoped by Cerbos UPDATE permissions), runs the
    format-specific validator against the object, then sets status to
    SUCCESS (and records the size) or FAILED.
    """
    query = get_resource_query(principal, cerbos_client, CerbosAction.UPDATE, db.File)
    query = query.filter(db.File.id == file_id)
    # one_or_none() makes the not-found branch reachable; .one() would raise
    # NoResultFound first and the check below was dead code.
    file = (await session.execute(query)).scalars().one_or_none()
    if not file:
        raise Exception("NOT FOUND!")  # TODO: How do we raise sane errors in our api?

    validator = get_validator(file.file_format)
    try:
        file_size = validator.validate(s3_client, file.namespace, file.path.lstrip("/"))
    except Exception:  # noqa — any validation error marks the file FAILED
        # Narrowed from a bare except so KeyboardInterrupt/SystemExit propagate.
        file.status = db.FileStatus.FAILED
    else:
        file.status = db.FileStatus.SUCCESS
        file.size = file_size

    return file
3 changes: 2 additions & 1 deletion entities/api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from platformics.database.connect import AsyncDB
from strawberry.fastapi import GraphQLRouter
from api.strawberry import strawberry_sqlalchemy_mapper
from api.files import File
from api.files import File, mark_upload_complete

######################
# Strawberry-GraphQL #
Expand Down Expand Up @@ -77,6 +77,7 @@ class Mutation:

# file_stuff
get_upload_url: Sample = get_base_updater(db.Sample, Sample) # type: ignore
mark_upload_complete: File = mark_upload_complete


# --------------------
Expand Down
82 changes: 82 additions & 0 deletions entities/api/tests/test_file_writes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import os
import pytest
from api.conftest import GQLTestClient
from platformics.database.connect import SyncDB
from test_infra import factories as fa
from mypy_boto3_s3.client import S3Client
from database.models import File
import sqlalchemy as sa


# Test that we can mark a file upload as complete
@pytest.mark.asyncio
async def test_file_validation(
    sync_db: SyncDB,
    gql_client: GQLTestClient,
    moto_client: S3Client,
) -> None:
    """A valid fastq upload is marked SUCCESS with the correct byte size."""
    user1_id = 12345
    project1_id = 123

    # Create mock data
    with sync_db.session() as session:
        fa.SessionStorage.set_session(session)
        fa.SequencingReadFactory.create(owner_user_id=user1_id, collection_id=project1_id)
        fa.FileFactory.update_file_ids()
        session.commit()
        file = session.execute(sa.select(File)).scalars().one()

    valid_fastq_file = "test_infra/fixtures/test1.fastq"
    # Read via a context manager: the original passed an open file object
    # straight to put_object and leaked the handle.
    with open(valid_fastq_file, "rb") as fh:
        moto_client.put_object(Bucket=file.namespace, Key=file.path.lstrip("/"), Body=fh.read())

    # Mark upload complete
    query = f"""
        mutation MyMutation {{
            markUploadComplete(fileId: "{file.id}") {{
                id
                namespace
                size
                status
            }}
        }}
    """
    res = await gql_client.query(query, member_projects=[project1_id])
    fileinfo = res["data"]["markUploadComplete"]
    assert fileinfo["status"] == "SUCCESS"
    assert fileinfo["size"] == os.stat(valid_fastq_file).st_size


# Test that invalid fastq's don't work
@pytest.mark.asyncio
async def test_invalid_fastq(
    sync_db: SyncDB,
    gql_client: GQLTestClient,
    moto_client: S3Client,
) -> None:
    """Uploading content that is not fastq must leave the file FAILED."""
    user1_id = 12345
    project1_id = 123

    # Seed a SequencingRead (and its File rows) owned by our test user.
    with sync_db.session() as session:
        fa.SessionStorage.set_session(session)
        fa.SequencingReadFactory.create(owner_user_id=user1_id, collection_id=project1_id)
        fa.FileFactory.update_file_ids()
        session.commit()
        file = session.execute(sa.select(File)).scalars().one()

    # Upload deliberately bogus content to the file's S3 location.
    bogus_key = file.path.lstrip("/")
    moto_client.put_object(Bucket=file.namespace, Key=bogus_key, Body="this is not a fastq file")

    # Ask the API to finalize the upload; validation should fail.
    query = f"""
        mutation MyMutation {{
            markUploadComplete(fileId: "{file.id}") {{
                id
                namespace
                size
                status
            }}
        }}
    """
    response = await gql_client.query(query, member_projects=[project1_id])
    file_info = response["data"]["markUploadComplete"]
    assert file_info["status"] == "FAILED"
4 changes: 2 additions & 2 deletions entities/docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ services:
- DEFAULT_UPLOAD_BUCKET=local-bucket
- BOTO_ENDPOINT_URL=http://motoserver.czidnet:4000
- AWS_REGION=us-west-2
- AWS_ACCESS_KEY_ID=ACCESS_ID
- AWS_SECRET_ACCESS_KEY=ACCESS_KEY
- AWS_ACCESS_KEY_ID=test
- AWS_SECRET_ACCESS_KEY=test
# TODO - these are keypairs for testing only! Do not use in prod!!
- JWK_PUBLIC_KEY_FILE=/czid-platformics/entities/test_infra/fixtures/public_key.pem
- JWK_PRIVATE_KEY_FILE=/czid-platformics/entities/test_infra/fixtures/private_key.pem
Expand Down
40 changes: 40 additions & 0 deletions entities/files/format_handlers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from abc import abstractmethod
from mypy_boto3_s3.client import S3Client
from Bio import SeqIO
from io import StringIO
from typing import Protocol


class FileFormatHandler(Protocol):
    """Structural interface for per-file-format handlers.

    Implementations validate an uploaded S3 object and (eventually) convert
    it to other formats.
    """

    @classmethod
    @abstractmethod
    def validate(cls, client: S3Client, bucket: str, file_path: str) -> int:
        # Returns the object's size in bytes on success; raises on invalid content.
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def convert_to(cls, client: S3Client, bucket: str, file_path: str, format: dict) -> str:
        # Conversion contract is not finalized; current handlers stub this out.
        raise NotImplementedError


class FastqHandler(FileFormatHandler):
    """Validation (and future conversion) for fastq files."""

    @classmethod
    def validate(cls, client: S3Client, bucket: str, file_path: str) -> int:
        """Cheaply validate a fastq object; return its full size in bytes.

        Overly simplistic: only the first ~1MB is parsed, so content that
        goes bad later in the stream still passes.
        """
        data = client.get_object(Bucket=bucket, Key=file_path, Range="bytes=0-1000000")["Body"].read()
        records = 0
        for _ in SeqIO.parse(StringIO(data.decode("ascii")), "fastq"):
            records += 1
        # Raise explicitly: `assert` is stripped under `python -O`, which
        # would silently disable validation.
        if records == 0:
            raise ValueError(f"{file_path} contains no valid fastq records")
        return client.head_object(Bucket=bucket, Key=file_path)["ContentLength"]

    @classmethod
    def convert_to(cls, client: S3Client, bucket: str, file_path: str, format: dict) -> str:
        """Conversion is not implemented yet; returns an empty string."""
        return ""


def get_validator(format: str) -> "type[FileFormatHandler]":
    """Return the handler class registered for *format*.

    Raises:
        ValueError: if no handler exists for the given format.
    """
    if format == "fastq":
        return FastqHandler
    # Include the offending value so failed uploads are debuggable;
    # ValueError is still caught by existing `except Exception` callers.
    raise ValueError(f"Unknown file format: {format!r}")
Loading