From 8958858e748259b5dc685e79c561dd04653d1842 Mon Sep 17 00:00:00 2001 From: David Lougheed Date: Wed, 18 Oct 2023 12:12:41 -0400 Subject: [PATCH] begin work adding authz to discovery endpoints + asyncifying --- chord_metadata_service/authz/counts.py | 61 ++- chord_metadata_service/authz/queries.py | 58 +++ chord_metadata_service/patients/api_views.py | 96 +++-- chord_metadata_service/restapi/api_views.py | 426 ++++++++++++------- chord_metadata_service/restapi/utils.py | 155 ++++--- 5 files changed, 520 insertions(+), 276 deletions(-) create mode 100644 chord_metadata_service/authz/queries.py diff --git a/chord_metadata_service/authz/counts.py b/chord_metadata_service/authz/counts.py index 5b97db681..e856824f8 100644 --- a/chord_metadata_service/authz/counts.py +++ b/chord_metadata_service/authz/counts.py @@ -1,11 +1,8 @@ from django.http import HttpRequest +from typing import overload -from .constants import ( - PERMISSION_QUERY_DATA, - PERMISSION_QUERY_PROJECT_LEVEL_COUNTS, - PERMISSION_QUERY_DATASET_LEVEL_COUNTS, -) -from .middleware import authz_middleware +from .constants import PERMISSION_QUERY_PROJECT_LEVEL_COUNTS, PERMISSION_QUERY_DATASET_LEVEL_COUNTS +from .queries import query_permission, can_query_data from .utils import create_resource @@ -22,26 +19,48 @@ def get_counts_permission(dataset_level: bool) -> str: return PERMISSION_QUERY_PROJECT_LEVEL_COUNTS # We don't have a node-level counts permission -async def can_see_counts(request: HttpRequest, resource: dict) -> bool: - return await authz_middleware.async_authz_post(request, "/policy/evaluate", { - "requested_resource": resource, - "required_permissions": [get_counts_permission(resource.get("dataset") is not None)], - })["result"] or ( - # If we don't have a count permission, we may still have a query:data permission (no cascade) - await authz_middleware.async_authz_post(request, "/policy/evaluate", { - "requested_resource": resource, - "required_permissions": [PERMISSION_QUERY_DATA], - })["result"] 
+@overload +async def can_see_counts(request: HttpRequest, resource: dict, dataset_level: bool) -> bool: + ... + + +@overload +async def can_see_counts(request: HttpRequest, resource: list[dict], dataset_level: bool) -> tuple[bool, ...]: + ... + + +async def can_see_counts( + request: HttpRequest, resource: dict | list[dict], dataset_level: bool +) -> bool | tuple[bool, ...]: + # First, check if we have counts permission on either the project or dataset level, depending on the resource. + # If we don't have a count permission, we may still have a query:data permission (no cascade) which gives us these + # for free. + + return ( + await query_permission(request, resource, get_counts_permission(dataset_level)) + or await can_query_data(request, resource) # or-shortcut means this only runs if it needs to be checked. ) async def has_counts_permission_for_data_types( - request: HttpRequest, project: str, dataset: str, data_types: list[str] + request: HttpRequest, project: str | None, dataset: str | None, data_types: list[str] ) -> list[bool]: - has_permission: bool = await can_see_counts(request, create_resource(project, dataset, None)) + dataset_level: bool = dataset is not None + + has_permission: bool = await can_see_counts( + request, create_resource(project, dataset, None), dataset_level) return [ - # Either we have permission for all (saves many calls) or we have for a specific data type - has_permission or (await can_see_counts(request, create_resource(project, dataset, dt_id))) - for dt_id in data_types + # Either we have permission for all (saves many calls via or-shortcutting) or we have for a specific data type: + has_permission or await can_see_counts(request, create_resource(project, dataset, dt_id), dataset_level) + for dt_id in data_types ] + + +async def has_counts_permission_for_data_types_bulk_resources( + request: HttpRequest, + resource_tuples: tuple[tuple[str | None, str | None], ...], + data_types: list[str], + dataset_level: 
bool, +): + pass # TODO diff --git a/chord_metadata_service/authz/queries.py b/chord_metadata_service/authz/queries.py new file mode 100644 index 000000000..89bbd1d07 --- /dev/null +++ b/chord_metadata_service/authz/queries.py @@ -0,0 +1,58 @@ +from django.http import HttpRequest +from typing import overload + +from .constants import PERMISSION_QUERY_DATA +from .middleware import authz_middleware +from .utils import create_resource + +__all__ = [ + "query_permission", + "can_query_data", + "has_query_data_permission_for_data_types", +] + + +@overload +async def query_permission(request: HttpRequest, resource: dict, permission: str) -> bool: + ... + + +@overload +async def query_permission(request: HttpRequest, resource: list[dict], permission: str) -> tuple[bool, ...]: + ... + + +async def query_permission( + request: HttpRequest, resource: dict | list[dict], permission: str +) -> bool | tuple[bool, ...]: + res = (await authz_middleware.async_authz_post(request, "/policy/evaluate", { + "requested_resource": resource, + "required_permissions": [permission], + }))["result"] + # A single resource dict yields a bool; a list of resources yields a tuple of bools (matches overloads). + return tuple(res) if isinstance(resource, list) else res + + +@overload +async def can_query_data(request: HttpRequest, resource: dict) -> bool: + ... + + +@overload +async def can_query_data(request: HttpRequest, resource: list[dict]) -> tuple[bool, ...]: + ... 
+ + +async def can_query_data(request: HttpRequest, resource: dict | list[dict]) -> bool | tuple[bool, ...]: + return await query_permission(request, resource, PERMISSION_QUERY_DATA) + + +async def has_query_data_permission_for_data_types( + request: HttpRequest, project: str | None, dataset: str | None, data_types: list[str] +) -> list[bool]: + has_permission: bool = await can_query_data(request, create_resource(project, dataset, None)) + return [ + # Either we have permission for all (saves many calls) or we have for a specific data type + has_permission or (await can_query_data(request, create_resource(project, dataset, dt_id))) + for dt_id in data_types + ] diff --git a/chord_metadata_service/patients/api_views.py b/chord_metadata_service/patients/api_views.py index 200f8066d..e02483a2b 100644 --- a/chord_metadata_service/patients/api_views.py +++ b/chord_metadata_service/patients/api_views.py @@ -1,8 +1,8 @@ import re +from asgiref.sync import async_to_sync from datetime import datetime - -from rest_framework import viewsets, filters, mixins, serializers +from rest_framework import filters, mixins, serializers, status, viewsets from rest_framework.decorators import action from rest_framework.response import Response from rest_framework.settings import api_settings @@ -22,6 +22,7 @@ from .serializers import IndividualSerializer from .models import Individual from .filters import IndividualFilter +from chord_metadata_service.authz.middleware import authz_middleware from chord_metadata_service.logger import logger from chord_metadata_service.phenopackets.api_views import BIOSAMPLE_PREFETCH, PHENOPACKET_PREFETCH from chord_metadata_service.phenopackets.models import Phenopacket @@ -35,10 +36,11 @@ ) from chord_metadata_service.restapi.pagination import LargeResultsSetPagination, BatchResultsSetPagination from chord_metadata_service.restapi.utils import ( + get_threshold, get_field_options, filter_queryset_field_value, biosample_tissue_stats, - 
experiment_type_stats + experiment_type_stats, ) from chord_metadata_service.restapi.negociation import FormatInPostContentNegotiation @@ -168,14 +170,16 @@ class PublicListIndividuals(APIView): View to return only count of all individuals after filtering. """ - def filter_queryset(self, queryset): + async def filter_queryset(self, queryset, can_query_data: bool): # Check query parameters validity qp = self.request.query_params - if len(qp) > settings.CONFIG_PUBLIC["rules"]["max_query_parameters"]: + config_public = settings.CONFIG_PUBLIC + + if not can_query_data and len(qp) > config_public["rules"]["max_query_parameters"]: raise ValidationError(f"Wrong number of fields: {len(qp)}") - search_conf = settings.CONFIG_PUBLIC["search"] - field_conf = settings.CONFIG_PUBLIC["fields"] + search_conf = config_public["search"] + field_conf = config_public["fields"] queryable_fields = { f"{f}": field_conf[f] for section in search_conf for f in section["fields"] } @@ -185,7 +189,7 @@ def filter_queryset(self, queryset): raise ValidationError(f"Unsupported field used in query: {field}") field_props = queryable_fields[field] - options = get_field_options(field_props) + options = await get_field_options(field_props, low_counts_censored=not can_query_data) if value not in options \ and not ( # case-insensitive search on categories @@ -204,38 +208,46 @@ def filter_queryset(self, queryset): return queryset - def get(self, request, *args, **kwargs): + # TODO: should be project-scoped + + @async_to_sync + async def get(self, request, *_args, **_kwargs): if not settings.CONFIG_PUBLIC: - return Response(settings.NO_PUBLIC_DATA_AVAILABLE) + authz_middleware.mark_authz_done(request) + return Response(settings.NO_PUBLIC_DATA_AVAILABLE, status=status.HTTP_404_NOT_FOUND) + + # TODO: permissions base_qs = Individual.objects.all() try: - filtered_qs = self.filter_queryset(base_qs) + filtered_qs = await self.filter_queryset(base_qs) except ValidationError as e: - return 
Response(errors.bad_request_error( - *(e.error_list if hasattr(e, "error_list") else e.error_dict.items()), - )) + return Response( + errors.bad_request_error(*(e.error_list if hasattr(e, "error_list") else e.error_dict.items())), + status=status.HTTP_400_BAD_REQUEST, + ) - qct = filtered_qs.count() + qct = await filtered_qs.count() - if qct <= (threshold := settings.CONFIG_PUBLIC["rules"]["count_threshold"]): + if qct <= (threshold := get_threshold()): # TODO: permissions + authz_middleware.mark_authz_done(request) logger.info( f"Public individuals endpoint recieved query params {request.query_params} which resulted in " f"sub-threshold count: {qct} <= {threshold}") return Response(settings.INSUFFICIENT_DATA_AVAILABLE) - tissues_count, sampled_tissues = biosample_tissue_stats(filtered_qs) - experiments_count, experiment_types = experiment_type_stats(filtered_qs) + tissues_count, sampled_tissues = await biosample_tissue_stats(filtered_qs) + experiments_count, experiment_types = await experiment_type_stats(filtered_qs) return Response({ "count": qct, "biosamples": { "count": tissues_count, - "sampled_tissue": sampled_tissues + "sampled_tissue": sampled_tissues, }, "experiments": { "count": experiments_count, - "experiment_type": experiment_types + "experiment_type": experiment_types, } }) @@ -245,14 +257,16 @@ class BeaconListIndividuals(APIView): View to return lists of individuals filtered using search terms from katsu's config.json. Uncensored equivalent of PublicListIndividuals. 
""" - def filter_queryset(self, queryset): + async def filter_queryset(self, queryset, can_query_data: bool): # Check query parameters validity qp = self.request.query_params - if len(qp) > settings.CONFIG_PUBLIC["rules"]["max_query_parameters"]: + config_public = settings.CONFIG_PUBLIC + + if not can_query_data and len(qp) > config_public["rules"]["max_query_parameters"]: raise ValidationError(f"Wrong number of fields: {len(qp)}") - search_conf = settings.CONFIG_PUBLIC["search"] - field_conf = settings.CONFIG_PUBLIC["fields"] + search_conf = config_public["search"] + field_conf = config_public["fields"] queryable_fields = { f: field_conf[f] for section in search_conf for f in section["fields"] } @@ -262,7 +276,7 @@ def filter_queryset(self, queryset): raise ValidationError(f"Unsupported field used in query: {field}") field_props = queryable_fields[field] - options = get_field_options(field_props) + options = await get_field_options(field_props, low_counts_censored=not can_query_data) if value not in options \ and not ( # case-insensitive search on categories @@ -281,28 +295,40 @@ def filter_queryset(self, queryset): return queryset - def get(self, request, *args, **kwargs): + @async_to_sync + async def get(self, request, *_args, **_kwargs): if not settings.CONFIG_PUBLIC: - return Response(settings.NO_PUBLIC_DATA_AVAILABLE, status=404) + authz_middleware.mark_authz_done(request) + return Response(settings.NO_PUBLIC_DATA_AVAILABLE, status=status.HTTP_404_NOT_FOUND) + + # Steps for permissions + # - Obtain all datasets + # - Do a bulk request to authz for permissions to see counts for the data types for each... 
base_qs = Individual.objects.all() + + # TODO: permissions + try: - filtered_qs = self.filter_queryset(base_qs) + filtered_qs = await self.filter_queryset(base_qs) except ValidationError as e: - return Response(errors.bad_request_error( - *(e.error_list if hasattr(e, "error_list") else e.error_dict.items())), status=400) + authz_middleware.mark_authz_done(request) + return Response( + errors.bad_request_error(*(e.error_list if hasattr(e, "error_list") else e.error_dict.items())), + status=status.HTTP_400_BAD_REQUEST, + ) - tissues_count, sampled_tissues = biosample_tissue_stats(filtered_qs) - experiments_count, experiment_types = experiment_type_stats(filtered_qs) + tissues_count, sampled_tissues = await biosample_tissue_stats(filtered_qs) + experiments_count, experiment_types = await experiment_type_stats(filtered_qs) return Response({ - "matches": filtered_qs.values_list("id", flat=True), + "matches": await filtered_qs.values_list("id", flat=True), "biosamples": { "count": tissues_count, - "sampled_tissue": sampled_tissues + "sampled_tissue": sampled_tissues, }, "experiments": { "count": experiments_count, - "experiment_type": experiment_types + "experiment_type": experiment_types, } }) diff --git a/chord_metadata_service/restapi/api_views.py b/chord_metadata_service/restapi/api_views.py index ee5a4ff0a..d43891b72 100644 --- a/chord_metadata_service/restapi/api_views.py +++ b/chord_metadata_service/restapi/api_views.py @@ -1,45 +1,51 @@ +import asyncio import json import logging from collections import Counter +from bento_lib.responses import errors from django.conf import settings +from django.http import HttpRequest from django.views.decorators.cache import cache_page -from rest_framework.permissions import AllowAny -from rest_framework.response import Response +from drf_spectacular.utils import extend_schema, inline_serializer +from rest_framework import serializers, status from rest_framework.decorators import api_view, permission_classes +from 
rest_framework.response import Response +from typing import TypedDict +from chord_metadata_service.authz.counts import has_counts_permission_for_data_types +from chord_metadata_service.authz.middleware import authz_middleware +from chord_metadata_service.authz.permissions import OverrideOrSuperUserOnly, BentoAllowAny +from chord_metadata_service.authz.queries import has_query_data_permission_for_data_types +from chord_metadata_service.chord import models as chord_models +from chord_metadata_service.chord.data_types import DATA_TYPE_PHENOPACKET, DATA_TYPE_EXPERIMENT +from chord_metadata_service.experiments import models as experiments_models +from chord_metadata_service.logger import logger +from chord_metadata_service.mcode import models as mcode_models +from chord_metadata_service.mcode.api_views import MCODEPACKET_PREFETCH, MCODEPACKET_SELECT +from chord_metadata_service.metadata.service_info import SERVICE_INFO +from chord_metadata_service.patients import models as patients_models +from chord_metadata_service.phenopackets import models as pheno_models +from chord_metadata_service.restapi.models import SchemaType from chord_metadata_service.restapi.utils import ( get_age_numeric_binned, + get_public_model_name_and_field_path, get_field_options, stats_for_field, queryset_stats_for_field, get_categorical_stats, get_date_stats, - get_range_stats + get_range_stats, + PUBLIC_MODEL_NAMES_TO_MODEL, + PUBLIC_MODEL_NAMES_TO_DATA_TYPE, BinWithValue, ) -from chord_metadata_service.authz.permissions import OverrideOrSuperUserOnly -from chord_metadata_service.metadata.service_info import SERVICE_INFO -from chord_metadata_service.chord import models as chord_models -from chord_metadata_service.phenopackets import models as pheno_models -from chord_metadata_service.mcode import models as mcode_models -from chord_metadata_service.patients import models as patients_models -from chord_metadata_service.experiments import models as experiments_models -from 
chord_metadata_service.mcode.api_views import MCODEPACKET_PREFETCH, MCODEPACKET_SELECT -from chord_metadata_service.restapi.models import SchemaType -from drf_spectacular.utils import extend_schema, inline_serializer -from rest_framework import serializers - -from chord_metadata_service.chord import data_types as dt - -logger = logging.getLogger("restapi_api_views") -logger.setLevel(logging.INFO) OVERVIEW_AGE_BIN_SIZE = 10 @api_view() -@permission_classes([AllowAny]) +@permission_classes([BentoAllowAny]) def service_info(_request): """ get: @@ -63,36 +69,38 @@ def service_info(_request): ) @api_view(["GET"]) @permission_classes([OverrideOrSuperUserOnly]) -def overview(_request): +async def overview(_request): """ get: Overview of all Phenopackets in the database """ - phenopackets_count = pheno_models.Phenopacket.objects.all().count() - biosamples_count = pheno_models.Biosample.objects.all().count() - individuals_count = patients_models.Individual.objects.all().count() - experiments_count = experiments_models.Experiment.objects.all().count() - experiment_results_count = experiments_models.ExperimentResult.objects.all().count() - instruments_count = experiments_models.Instrument.objects.all().count() - phenotypic_features_count = pheno_models.PhenotypicFeature.objects.all().distinct('pftype').count() + + # TODO: parallel + phenopackets_count = await pheno_models.Phenopacket.objects.all().count() + biosamples_count = await pheno_models.Biosample.objects.all().count() + individuals_count = await patients_models.Individual.objects.all().count() + experiments_count = await experiments_models.Experiment.objects.all().count() + experiment_results_count = await experiments_models.ExperimentResult.objects.all().count() + instruments_count = await experiments_models.Instrument.objects.all().count() + phenotypic_features_count = await pheno_models.PhenotypicFeature.objects.all().distinct('pftype').count() # Sex related fields stats are precomputed here and post processed 
later # to include missing values inferred from the schema - individuals_sex = stats_for_field(patients_models.Individual, "sex") - individuals_k_sex = stats_for_field(patients_models.Individual, "karyotypic_sex") + individuals_sex = await stats_for_field(patients_models.Individual, "sex") + individuals_k_sex = await stats_for_field(patients_models.Individual, "karyotypic_sex") - diseases_stats = stats_for_field(pheno_models.Phenopacket, "diseases__term__label") + diseases_stats = await stats_for_field(pheno_models.Phenopacket, "diseases__term__label") diseases_count = len(diseases_stats) - individuals_age = get_age_numeric_binned(patients_models.Individual.objects.all(), OVERVIEW_AGE_BIN_SIZE) + individuals_age = await get_age_numeric_binned(patients_models.Individual.objects.all(), OVERVIEW_AGE_BIN_SIZE) - r = { + return Response({ "phenopackets": phenopackets_count, "data_type_specific": { "biosamples": { "count": biosamples_count, - "taxonomy": stats_for_field(pheno_models.Biosample, "taxonomy__label"), - "sampled_tissue": stats_for_field(pheno_models.Biosample, "sampled_tissue__label"), + "taxonomy": await stats_for_field(pheno_models.Biosample, "taxonomy__label"), + "sampled_tissue": await stats_for_field(pheno_models.Biosample, "sampled_tissue__label"), }, "diseases": { # count is a number of unique disease terms (not all diseases in the database) @@ -105,41 +113,39 @@ def overview(_request): "karyotypic_sex": { k: individuals_k_sex.get(k, 0) for k in (s[0] for s in pheno_models.Individual.KARYOTYPIC_SEX) }, - "taxonomy": stats_for_field(patients_models.Individual, "taxonomy__label"), + "taxonomy": await stats_for_field(patients_models.Individual, "taxonomy__label"), "age": individuals_age, - "ethnicity": stats_for_field(patients_models.Individual, "ethnicity"), + "ethnicity": await stats_for_field(patients_models.Individual, "ethnicity"), }, "phenotypic_features": { # count is a number of unique phenotypic feature types (not all pfs in the database) 
"count": phenotypic_features_count, - "type": stats_for_field(pheno_models.PhenotypicFeature, "pftype__label") + "type": await stats_for_field(pheno_models.PhenotypicFeature, "pftype__label") }, "experiments": { "count": experiments_count, - "study_type": stats_for_field(experiments_models.Experiment, "study_type"), - "experiment_type": stats_for_field(experiments_models.Experiment, "experiment_type"), - "molecule": stats_for_field(experiments_models.Experiment, "molecule"), - "library_strategy": stats_for_field(experiments_models.Experiment, "library_strategy"), - "library_source": stats_for_field(experiments_models.Experiment, "library_source"), - "library_selection": stats_for_field(experiments_models.Experiment, "library_selection"), - "library_layout": stats_for_field(experiments_models.Experiment, "library_layout"), - "extraction_protocol": stats_for_field(experiments_models.Experiment, "extraction_protocol"), + "study_type": await stats_for_field(experiments_models.Experiment, "study_type"), + "experiment_type": await stats_for_field(experiments_models.Experiment, "experiment_type"), + "molecule": await stats_for_field(experiments_models.Experiment, "molecule"), + "library_strategy": await stats_for_field(experiments_models.Experiment, "library_strategy"), + "library_source": await stats_for_field(experiments_models.Experiment, "library_source"), + "library_selection": await stats_for_field(experiments_models.Experiment, "library_selection"), + "library_layout": await stats_for_field(experiments_models.Experiment, "library_layout"), + "extraction_protocol": await stats_for_field(experiments_models.Experiment, "extraction_protocol"), }, "experiment_results": { "count": experiment_results_count, - "file_format": stats_for_field(experiments_models.ExperimentResult, "file_format"), - "data_output_type": stats_for_field(experiments_models.ExperimentResult, "data_output_type"), - "usage": stats_for_field(experiments_models.ExperimentResult, "usage") + 
"file_format": await stats_for_field(experiments_models.ExperimentResult, "file_format"), + "data_output_type": await stats_for_field(experiments_models.ExperimentResult, "data_output_type"), + "usage": await stats_for_field(experiments_models.ExperimentResult, "usage") }, "instruments": { "count": instruments_count, - "platform": stats_for_field(experiments_models.Experiment, "instrument__platform"), - "model": stats_for_field(experiments_models.Experiment, "instrument__model") + "platform": await stats_for_field(experiments_models.Experiment, "instrument__platform"), + "model": await stats_for_field(experiments_models.Experiment, "instrument__model") }, } - } - - return Response(r) + }) @api_view(["GET"]) @@ -155,7 +161,7 @@ def extra_properties_schema_types(_request): @api_view(["GET", "POST"]) @permission_classes([OverrideOrSuperUserOnly]) -def search_overview(request): +async def search_overview(request): """ get+post: Overview statistics of a list of patients (associated with a search result) @@ -165,44 +171,52 @@ def search_overview(request): individual_id = request.GET.getlist("id") if request.method == "GET" else request.data.get("id", []) queryset = patients_models.Individual.objects.all().filter(id__in=individual_id) + # TODO: filter to only individuals where we have project/dataset-level access? or can we at least pass a dataset + # in too to make this less annoying... 
+ individuals_count = len(individual_id) - biosamples_count = queryset.values("phenopackets__biosamples__id").exclude( - phenopackets__biosamples__id__isnull=True).count() + biosamples_count = await ( + queryset + .values("phenopackets__biosamples__id") + .exclude(phenopackets__biosamples__id__isnull=True) + .count() + ) # Sex related fields stats are precomputed here and post processed later # to include missing values inferred from the schema - individuals_sex = queryset_stats_for_field(queryset, "sex") + individuals_sex = await queryset_stats_for_field(queryset, "sex") # several obvious approaches to experiment counts give incorrect answers - experiment_types = queryset_stats_for_field(queryset, "phenopackets__biosamples__experiment__experiment_type") + experiment_types = await queryset_stats_for_field( + queryset, "phenopackets__biosamples__experiment__experiment_type") experiments_count = sum(experiment_types.values()) - r = { + return Response({ "biosamples": { "count": biosamples_count, - "sampled_tissue": queryset_stats_for_field(queryset, "phenopackets__biosamples__sampled_tissue__label"), - "histological_diagnosis": queryset_stats_for_field( + "sampled_tissue": await queryset_stats_for_field( + queryset, "phenopackets__biosamples__sampled_tissue__label"), + "histological_diagnosis": await queryset_stats_for_field( queryset, "phenopackets__biosamples__histological_diagnosis__label" ), }, "diseases": { - "term": queryset_stats_for_field(queryset, "phenopackets__diseases__term__label"), + "term": await queryset_stats_for_field(queryset, "phenopackets__diseases__term__label"), }, "individuals": { "count": individuals_count, "sex": {k: individuals_sex.get(k, 0) for k in (s[0] for s in pheno_models.Individual.SEX)}, - "age": get_age_numeric_binned(queryset, OVERVIEW_AGE_BIN_SIZE), + "age": await get_age_numeric_binned(queryset, OVERVIEW_AGE_BIN_SIZE), }, "phenotypic_features": { - "type": queryset_stats_for_field(queryset, 
"phenopackets__phenotypic_features__pftype__label") + "type": await queryset_stats_for_field(queryset, "phenopackets__phenotypic_features__pftype__label"), }, "experiments": { "count": experiments_count, "experiment_type": experiment_types, }, - } - return Response(r) + }) @extend_schema( @@ -214,7 +228,7 @@ def search_overview(request): 'mcodepackets': serializers.IntegerField(), 'data_type_specific': serializers.JSONField(), } - ) + ), } ) # Cache page for the requested url for 2 hours @@ -297,156 +311,252 @@ def mcode_overview(_request): }) +class DiscoveryPermissionsDict(TypedDict): + counts: bool + data: bool + + +DataTypeDiscoveryPermissions = dict[str, DiscoveryPermissionsDict] + + +async def get_data_type_discovery_permissions( + request: HttpRequest, data_types: list[str] +) -> DataTypeDiscoveryPermissions: + # For all of these required data types, figure out if we have: + # a) full-response query:data permissions, and + # b) count-level permissions (at the project level) - will also re-check the query:data permissions currently :( + + query_data_perms, counts_perms = await asyncio.gather( + has_query_data_permission_for_data_types(request, None, None, data_types), + has_counts_permission_for_data_types(request, None, None, data_types), + ) + + # Collect these permissions, organized by data type, in a dictionary, so we can query them later: + return { + dt: { + "counts": c_perm, + "data": qd_perm, + } + for dt, qd_perm, c_perm in zip( + data_types, # List of data type IDs + query_data_perms, # query:data permissions for each data type + counts_perms, # query:project_level_counts permissions for each data type + ) + } + + +async def get_public_data_type_permissions(request: HttpRequest) -> DataTypeDiscoveryPermissions: + return await get_data_type_discovery_permissions( + request, + + # Collect all data types that we need permissions for to give various parts of the public overview response. 
+ # - individuals & biosamples are in the 'phenopacket' data type, experiments are in the 'experiment' data type + list(set(PUBLIC_MODEL_NAMES_TO_DATA_TYPE.values())) + ) + + @extend_schema( description="Public search fields with their configuration", responses={ - 200: inline_serializer( + status.HTTP_200_OK: inline_serializer( name='public_search_fields_response', - fields={ - 'sections': serializers.JSONField(), - } - ) + fields={'sections': serializers.JSONField()} + ), + status.HTTP_404_NOT_FOUND: inline_serializer( + name='public_search_fields_not_configured', + fields={'message': serializers.CharField()}, + ), } ) @api_view(["GET"]) -@permission_classes([AllowAny]) -def public_search_fields(_request): +async def public_search_fields(request: HttpRequest): """ get: Return public search fields with their configuration """ - if not settings.CONFIG_PUBLIC: - return Response(settings.NO_PUBLIC_FIELDS_CONFIGURED) - search_conf = settings.CONFIG_PUBLIC["search"] - field_conf = settings.CONFIG_PUBLIC["fields"] + # TODO: should be project-scoped + + config_public = settings.CONFIG_PUBLIC + + if not config_public: + authz_middleware.mark_authz_done(request) + return Response(settings.NO_PUBLIC_FIELDS_CONFIGURED, status=status.HTTP_404_NOT_FOUND) + + # Access (counts/data) permissions by Bento data type + dt_permissions = await get_public_data_type_permissions(request) + + field_conf = config_public["fields"] + # Note: the array is wrapped in a dictionary structure to help with JSON # processing by some services. 
- r = { - "sections": [ - { - **section, - "fields": [ - { - **field_conf[f], - "id": f, - "options": get_field_options(field_conf[f]) - } for f in section["fields"] - ] - } for section in search_conf - ] - } - return Response(r) + + async def _get_field_response(field) -> dict: + field_props = field_conf[field] + field_perms = get_count_and_query_data_permissions_for_field(dt_permissions, field_props) + + return { + **field_props, + "id": field, + "options": await get_field_options(field_props, low_counts_censored=not field_perms["data"]), + } + + async def _get_section_response(section) -> dict: + return { + **section, + "fields": await asyncio.gather(*map(_get_field_response, section["fields"])), + } + + return Response({ + "sections": await asyncio.gather(*map(_get_section_response, config_public["search"])), + }) + + +def get_count_and_query_data_permissions_for_field( + dt_permissions: DataTypeDiscoveryPermissions, field_props: dict +) -> DiscoveryPermissionsDict: + public_model_name, _ = get_public_model_name_and_field_path(field_props["mapping"]) + field_bento_data_type = PUBLIC_MODEL_NAMES_TO_DATA_TYPE[public_model_name] + return dt_permissions[field_bento_data_type] @extend_schema( description="Overview of all public data in the database", responses={ - 200: inline_serializer( + status.HTTP_200_OK: inline_serializer( name='public_overview_response', - fields={ - 'datasets': serializers.CharField(), - } - ) + fields={'datasets': serializers.CharField()} + ), + status.HTTP_404_NOT_FOUND: inline_serializer( + name='public_overview_not_available', + fields={'message': serializers.CharField()}, + ), } ) -@api_view(["GET"]) -@permission_classes([AllowAny]) -def public_overview(_request): +@api_view(["GET"]) # Don't use BentoAllowAny, we want to be more careful of cases here. 
+async def public_overview(request: HttpRequest): """ get: Overview of all public data in the database """ - if not settings.CONFIG_PUBLIC: - return Response(settings.NO_PUBLIC_DATA_AVAILABLE) + config_public = settings.CONFIG_PUBLIC - # Predefined counts - individuals_count = patients_models.Individual.objects.all().count() - biosamples_count = pheno_models.Biosample.objects.all().count() - experiments_count = experiments_models.Experiment.objects.all().count() + if not config_public: + authz_middleware.mark_authz_done(request) + return Response(settings.NO_PUBLIC_DATA_AVAILABLE, status=status.HTTP_404_NOT_FOUND) + + # TODO: public overviews SHOULD be project-scoped at least. + + # Access (counts/data) permissions by Bento data type + dt_permissions = await get_public_data_type_permissions(request) - # Early return when there is not enough data - if individuals_count < settings.CONFIG_PUBLIC["rules"]["count_threshold"]: - return Response(settings.INSUFFICIENT_DATA_AVAILABLE) + # If we don't have AT LEAST one count permission, assume we're not supposed to see this page and return forbidden. + if not any(dpd["counts"] for dpd in dt_permissions.values()): + authz_middleware.mark_authz_done(request) + return Response(errors.forbidden_error, status=status.HTTP_403_FORBIDDEN) + + # Predefined counts + async def _counts_for_model_name(mn: str) -> tuple[str, int]: + return mn, await PUBLIC_MODEL_NAMES_TO_MODEL[mn].objects.all().count() + counts = dict(await asyncio.gather(*map(_counts_for_model_name, PUBLIC_MODEL_NAMES_TO_MODEL))) # Get the rules config rules_config = settings.CONFIG_PUBLIC["rules"] + count_threshold = rules_config["count_threshold"] + + # Set counts to 0 if they're under the count threshold, and we don't have full data access permissions for the + # data type corresponding to the model. 
+ for public_model_name in counts: + data_type = PUBLIC_MODEL_NAMES_TO_DATA_TYPE[public_model_name] + if counts[public_model_name] < count_threshold and not dt_permissions[data_type]["data"]: + logger.info(f"Public overview: {public_model_name} count is below count threshold") + counts[public_model_name] = 0 response = { - "layout": settings.CONFIG_PUBLIC["overview"], + "layout": config_public["overview"], "fields": {}, "counts": { - "individuals": individuals_count, - "biosamples": biosamples_count, - "experiments": experiments_count + **({ + "individuals": counts["individual"], + "biosamples": counts["biosample"], + } if dt_permissions[DATA_TYPE_PHENOPACKET]["counts"] else {}), + **({ + "experiments": counts["experiment"], + } if dt_permissions[DATA_TYPE_EXPERIMENT]["counts"] else {}), }, "max_query_parameters": rules_config["max_query_parameters"], - "count_threshold": rules_config["count_threshold"], + "count_threshold": count_threshold, } - # Parse the public config to gather data for each field defined in the - # overview - fields = [chart["field"] for section in settings.CONFIG_PUBLIC["overview"] for chart in section["charts"]] + # Parse the public config to gather data for each field defined in the overview - for field in fields: - field_props = settings.CONFIG_PUBLIC["fields"][field] - if field_props["datatype"] == "string": - stats = get_categorical_stats(field_props) + fields = [chart["field"] for section in config_public["overview"] for chart in section["charts"]] + field_conf = config_public["fields"] + + async def _get_field_response(field_props: dict) -> dict: + field_perms = get_count_and_query_data_permissions_for_field(dt_permissions, field_props) + + # Permissions incorporation: only censor small cell counts when we don't have query:data access + stats: list[BinWithValue] | None + if not field_perms["counts"]: + stats = None + elif field_props["datatype"] == "string": + stats = await get_categorical_stats(field_props, low_counts_censored=not 
field_perms["data"]) elif field_props["datatype"] == "number": - stats = get_range_stats(field_props) + stats = await get_range_stats(field_props, low_counts_censored=not field_perms["data"]) elif field_props["datatype"] == "date": - stats = get_date_stats(field_props) + stats = await get_date_stats(field_props, low_counts_censored=not field_perms["data"]) else: raise NotImplementedError() - response["fields"][field] = { + return { **field_props, "id": field, - "data": stats + **({"data": stats} if stats is not None else {}), } + # Parallel async collection of field responses for public overview + field_responses = await asyncio.gather(*(_get_field_response(field_conf[field]) for field in fields)) + + for field, field_res in zip(fields, field_responses): + response["fields"][field] = field_res + return Response(response) @api_view(["GET"]) -@permission_classes([AllowAny]) -def public_dataset(_request): +@permission_classes([BentoAllowAny]) +async def public_dataset(_request): """ get: Properties of the datasets """ - if not settings.CONFIG_PUBLIC: - return Response(settings.NO_PUBLIC_DATA_AVAILABLE) - - # Datasets provenance metadata - datasets = chord_models.Dataset.objects.values( - "title", "description", "contact_info", - "dates", "stored_in", "spatial_coverage", - "types", "privacy", "distributions", - "dimensions", "primary_publications", "citations", - "produced_by", "creators", "licenses", - "acknowledges", "keywords", "version", "dats_file", - "extra_properties", "identifier" - ) + # For now, we don't have any permissions checks for this. + # In the future, we could introduce a view:dataset permission or something. 
- # convert dats_file json content to dict - datasets = [ - { - **d, - "dats_file": json.loads(d["dats_file"]) if d["dats_file"] else None - } for d in datasets] + if not settings.CONFIG_PUBLIC: + return Response(settings.NO_PUBLIC_DATA_AVAILABLE, status=status.HTTP_404_NOT_FOUND) return Response({ - "datasets": datasets + "datasets": [ + { + **d, + # convert dats_file json content to dict + "dats_file": json.loads(d["dats_file"]) if d["dats_file"] else None, + } + async for d in ( + # Datasets provenance metadata: + chord_models.Dataset.objects.values( + "title", "description", "contact_info", + "dates", "stored_in", "spatial_coverage", + "types", "privacy", "distributions", + "dimensions", "primary_publications", "citations", + "produced_by", "creators", "licenses", + "acknowledges", "keywords", "version", "dats_file", + "extra_properties", "identifier" + ) + ) + ] }) - - -DT_QUERYSETS = { - dt.DATA_TYPE_EXPERIMENT: experiments_models.Experiment.objects.all(), - dt.DATA_TYPE_EXPERIMENT_RESULT: experiments_models.ExperimentResult.objects.all(), - dt.DATA_TYPE_MCODEPACKET: mcode_models.MCodePacket.objects.all(), - dt.DATA_TYPE_PHENOPACKET: pheno_models.Phenopacket.objects.all(), - # dt.DATA_TYPE_READSET: None, -} diff --git a/chord_metadata_service/restapi/utils.py b/chord_metadata_service/restapi/utils.py index 28be837a9..81011f2cc 100644 --- a/chord_metadata_service/restapi/utils.py +++ b/chord_metadata_service/restapi/utils.py @@ -6,12 +6,14 @@ from collections import defaultdict, Counter from calendar import month_abbr from decimal import Decimal, ROUND_HALF_EVEN -from typing import Any, Optional, Type, TypedDict, Mapping, Generator +from typing import Any, Type, TypedDict, Mapping, Generator from django.db.models import Count, F, Func, IntegerField, CharField, Case, Model, When, Value from django.db.models.functions import Cast from django.conf import settings +from chord_metadata_service.chord.data_types import DATA_TYPE_PHENOPACKET, DATA_TYPE_EXPERIMENT 
+from chord_metadata_service.patients import models as patient_models from chord_metadata_service.phenopackets import models as pheno_models from chord_metadata_service.experiments import models as experiments_models from chord_metadata_service.logger import logger @@ -19,10 +21,16 @@ LENGTH_Y_M = 4 + 1 + 2 # dates stored as yyyy-mm-dd -MODEL_NAMES_TO_MODEL: dict[str, Type[Model]] = { - "individual": pheno_models.Individual, - "experiment": experiments_models.Experiment, +PUBLIC_MODEL_NAMES_TO_MODEL: dict[str, Type[Model]] = { + "individual": patient_models.Individual, "biosample": pheno_models.Biosample, + "experiment": experiments_models.Experiment, +} + +PUBLIC_MODEL_NAMES_TO_DATA_TYPE = { + "individual": DATA_TYPE_PHENOPACKET, + "biosample": DATA_TYPE_PHENOPACKET, + "experiment": DATA_TYPE_EXPERIMENT, } @@ -31,12 +39,12 @@ class BinWithValue(TypedDict): value: int -def get_threshold() -> int: +def get_threshold(low_counts_censored: bool) -> int: """ Gets the maximum count threshold for hiding censored data (i.e., rounding to 0). This is a function to prevent settings errors if not running/importing this file in a Django context. 
""" - return settings.CONFIG_PUBLIC["rules"]["count_threshold"] + return settings.CONFIG_PUBLIC["rules"]["count_threshold"] if low_counts_censored else 0 def camel_case_field_names(string) -> str: @@ -112,9 +120,9 @@ def _round_decimal_two_places(d: float) -> Decimal: return Decimal(d).quantize(Decimal("0.01"), rounding=ROUND_HALF_EVEN) -def time_element_to_years(time_element: dict, unit: str = "years") -> tuple[Optional[Decimal], Optional[str]]: - time_value: Optional[Decimal] = None - time_unit: Optional[str] = None +def time_element_to_years(time_element: dict, unit: str = "years") -> tuple[Decimal | None, str | None]: + time_value: Decimal | None = None + time_unit: str | None = None if "age" in time_element: return iso_duration_to_years(time_element["age"], unit=unit) elif "age_range" in time_element: @@ -125,7 +133,7 @@ def time_element_to_years(time_element: dict, unit: str = "years") -> tuple[Opti return time_value, time_unit -def iso_duration_to_years(iso_age_duration: str | dict, unit: str = "years") -> tuple[Optional[Decimal], Optional[str]]: +def iso_duration_to_years(iso_age_duration: str | dict, unit: str = "years") -> tuple[Decimal | None, str | None]: """ This function takes ISO8601 Duration string in the format e.g 'P20Y6M4D' and converts it to years. 
""" @@ -184,8 +192,8 @@ def custom_binning_generator(field_props: dict) -> Generator[tuple[int, int, str """ c = field_props["config"] - minimum: Optional[int] = int(c["minimum"]) if "minimum" in c else None - maximum: Optional[int] = int(c["maximum"]) if "maximum" in c else None + minimum: int | None = int(c["minimum"]) if "minimum" in c else None + maximum: int | None = int(c["maximum"]) if "maximum" in c else None bins: list[int] = [int(value) for value in c["bins"]] # check prerequisites @@ -275,6 +283,11 @@ def monthly_generator(start: str, end: str) -> tuple[int, int]: yield year, month +def get_public_model_name_and_field_path(field_id: str) -> tuple[str, tuple[str, ...]]: + model_name, *field_path = field_id.split("/") + return model_name, tuple(field_path) + + def get_model_and_field(field_id: str) -> tuple[any, str]: """ Parses a path-like string representing an ORM such as "individual/extra_properties/date_of_consent" @@ -284,9 +297,9 @@ def get_model_and_field(field_id: str) -> tuple[any, str]: field for this object. """ - model_name, *field_path = field_id.split("/") + model_name, field_path = get_public_model_name_and_field_path(field_id) - model: Optional[Type[Model]] = MODEL_NAMES_TO_MODEL.get(model_name) + model: Type[Model] | None = PUBLIC_MODEL_NAMES_TO_MODEL.get(model_name) if model is None: msg = f"Accessing field on model {model_name} not implemented" raise NotImplementedError(msg) @@ -295,16 +308,15 @@ def get_model_and_field(field_id: str) -> tuple[any, str]: return model, field_name -def stats_for_field(model, field: str, add_missing=False) -> Mapping[str, int]: +async def stats_for_field(model, field: str, add_missing=False) -> Mapping[str, int]: """ Computes counts of distinct values for a given field. 
Mainly applicable to char fields representing categories """ - queryset = model.objects.all() - return queryset_stats_for_field(queryset, field, add_missing) + return await queryset_stats_for_field(model.objects.all(), field, add_missing) -def queryset_stats_for_field(queryset, field: str, add_missing=False) -> Mapping[str, int]: +async def queryset_stats_for_field(queryset, field: str, add_missing=False) -> Mapping[str, int]: """ Computes counts of distinct values for a queryset. """ @@ -318,7 +330,7 @@ def queryset_stats_for_field(queryset, field: str, add_missing=False) -> Mapping stats: dict[str, int] = {} - for item in annotated_queryset: + async for item in annotated_queryset: key = item[field] if key is None: num_missing = item["total"] @@ -335,7 +347,7 @@ def queryset_stats_for_field(queryset, field: str, add_missing=False) -> Mapping return stats -def get_field_bins(query_set, field, bin_size): +async def get_field_bins(query_set, field, bin_size): # computes a new column "binned" by substracting the modulo by bin size to # the value which requires binning (e.g. 28 => 28 - 28 % 10 = 20) # cast to integer to avoid numbers such as 60.00 if that was a decimal, @@ -346,11 +358,11 @@ def get_field_bins(query_set, field, bin_size): IntegerField() ) ).values("binned").annotate(total=Count("binned")) - stats = {item["binned"]: item["total"] for item in query_set} + stats = {item["binned"]: item["total"] async for item in query_set} return stats -def compute_binned_ages(individual_queryset, bin_size: int) -> list[int]: +async def compute_binned_ages(individual_queryset, bin_size: int) -> list[int]: """ When age_numeric field is not available, use this function to process the age field in its various formats. 
@@ -363,7 +375,7 @@ def compute_binned_ages(individual_queryset, bin_size: int) -> list[int]: a = individual_queryset.filter(age_numeric__isnull=True).values('time_at_last_encounter') binned_ages = [] - for r in a.iterator(): # reduce memory footprint (no caching) + async for r in a.iterator(): # reduce memory footprint (no caching) if r["time_at_last_encounter"] is None: continue age = parse_individual_age(r["time_at_last_encounter"]) @@ -372,33 +384,36 @@ def compute_binned_ages(individual_queryset, bin_size: int) -> list[int]: return binned_ages -def get_age_numeric_binned(individual_queryset, bin_size: int) -> dict: +async def get_age_numeric_binned(individual_queryset, bin_size: int) -> dict: """ age_numeric is computed at ingestion time of phenopackets. On some instances it might be unavailable and as a fallback must be computed from the age JSON field which has two alternate formats (hence more complex and slower to process) """ - individuals_age = get_field_bins(individual_queryset, "age_numeric", bin_size) + individuals_age = await get_field_bins(individual_queryset, "age_numeric", bin_size) if None not in individuals_age: return individuals_age del individuals_age[None] individuals_age = Counter(individuals_age) individuals_age.update( - compute_binned_ages(individual_queryset, bin_size) # single update instead of creating iterables in a loop + # single update instead of creating iterables in a loop + await compute_binned_ages(individual_queryset, bin_size) ) return individuals_age -def get_categorical_stats(field_props: dict) -> list[BinWithValue]: +async def get_categorical_stats(field_props: dict, low_counts_censored: bool) -> list[BinWithValue]: """ Fetches statistics for a given categorical field and apply privacy policies """ + model, field_name = get_model_and_field(field_props["mapping"]) - stats = stats_for_field(model, field_name, add_missing=True) + + stats: Mapping[str, int] = await stats_for_field(model, field_name, add_missing=True) # 
Enforce values order from config and apply policies - labels: Optional[list[str]] = field_props["config"].get("enum") + labels: list[str] | None = field_props["config"].get("enum") derived_labels: bool = labels is None # Special case: for some fields, values are based on what's present in the @@ -414,7 +429,7 @@ def get_categorical_stats(field_props: dict) -> list[BinWithValue]: key=lambda x: x.lower() ) - threshold = get_threshold() + threshold = get_threshold(low_counts_censored) bins: list[BinWithValue] = [] for category in labels: @@ -438,7 +453,7 @@ def get_categorical_stats(field_props: dict) -> list[BinWithValue]: return bins -def get_date_stats(field_props: dict) -> list[BinWithValue]: +async def get_date_stats(field_props: dict, low_counts_censored: bool = True) -> list[BinWithValue]: """ Fetches statistics for a given date field, fill the gaps in the date range and apply privacy policies. @@ -459,7 +474,7 @@ def get_date_stats(field_props: dict) -> list[BinWithValue]: raise NotImplementedError(msg) # Note: lexical sort works on ISO dates - query_set = ( + query_set = await ( model.objects.all() .values(field_name) .order_by(field_name) @@ -467,8 +482,8 @@ def get_date_stats(field_props: dict) -> list[BinWithValue]: ) stats = defaultdict(int) - start: Optional[str] = None - end: Optional[str] = None + start: str | None = None + end: str | None = None # Key the counts on yyyy-mm combination (aggregate same month counts) for item in query_set: key = "missing" if item[field_name] is None else item[field_name][:LENGTH_Y_M] @@ -484,7 +499,7 @@ def get_date_stats(field_props: dict) -> list[BinWithValue]: start = key # All the bins between start and end date must be represented - threshold = get_threshold() + threshold = get_threshold(low_counts_censored) bins: list[BinWithValue] = [] if start: # at least one month for year, month in monthly_generator(start, end or start): @@ -503,7 +518,7 @@ def get_date_stats(field_props: dict) -> list[BinWithValue]: return 
bins -def get_month_date_range(field_props: dict) -> tuple[Optional[str], Optional[str]]: +def get_month_date_range(field_props: dict) -> tuple[str | None, str | None]: """ Get start date and end date from the database Note that dates within a JSON are stored as strings, not instances of datetime. @@ -541,16 +556,19 @@ def get_month_date_range(field_props: dict) -> tuple[Optional[str], Optional[str return start, end -def get_range_stats(field_props: dict) -> list[BinWithValue]: +async def get_range_stats(field_props: dict, low_counts_censored: bool = True) -> list[BinWithValue]: model, field = get_model_and_field(field_props["mapping"]) # Generate a list of When conditions that return a label for the given bin. # This is equivalent to an SQL CASE statement. - whens = [When( - **{f"{field}__gte": floor} if floor is not None else {}, - **{f"{field}__lt": ceil} if ceil is not None else {}, - then=Value(label) - ) for floor, ceil, label in labelled_range_generator(field_props)] + whens = [ + When( + **{f"{field}__gte": floor} if floor is not None else {}, + **{f"{field}__lt": ceil} if ceil is not None else {}, + then=Value(label), + ) + for floor, ceil, label in labelled_range_generator(field_props) + ] query_set = ( model.objects @@ -558,9 +576,10 @@ def get_range_stats(field_props: dict) -> list[BinWithValue]: .annotate(total=Count("label")) ) - threshold = get_threshold() # Maximum number of entries needed to round a count down to 0 (censored discovery) + # Maximum number of entries needed to round a count from its true value down to 0 (censored discovery) + threshold = get_threshold(low_counts_censored) stats: dict[str, int] = dict() - for item in query_set: + async for item in query_set: key = item["label"] stats[key] = item["total"] if item["total"] > threshold else 0 @@ -575,7 +594,7 @@ def get_range_stats(field_props: dict) -> list[BinWithValue]: return bins -def get_field_options(field_props: dict) -> list[Any]: +async def get_field_options(field_props: dict, 
low_counts_censored: bool) -> list[Any]: """ Given properties for a public field, return the list of authorized options for querying this field. @@ -587,7 +606,7 @@ def get_field_options(field_props: dict) -> list[Any]: # We must be careful here not to leak 'small cell' values as options # - e.g., if there are three individuals with sex=UNKNOWN_SEX, this # should be treated as if the field isn't in the database at all. - options = get_distinct_field_values(field_props) + options = await get_distinct_field_values(field_props, low_counts_censored) elif field_props["datatype"] == "number": options = [label for floor, ceil, label in labelled_range_generator(field_props)] elif field_props["datatype"] == "date": @@ -601,16 +620,23 @@ def get_field_options(field_props: dict) -> list[Any]: return options -def get_distinct_field_values(field_props: dict) -> list[Any]: +async def get_distinct_field_values(field_props: dict, low_counts_censored: bool) -> list[Any]: # We must be careful here not to leak 'small cell' values as options # - e.g., if there are three individuals with sex=UNKNOWN_SEX, this # should be treated as if the field isn't in the database at all. 
model, field = get_model_and_field(field_props["mapping"]) - threshold = get_threshold() - - values_with_counts = model.objects.values_list(field).annotate(count=Count(field)) - return [val for val, count in values_with_counts if count > threshold] + threshold = get_threshold(low_counts_censored) + + return [ + val + async for val, count in ( + model.objects + .values_list(field) + .annotate(count=Count(field)) + ) + if count > threshold + ] def filter_queryset_field_value(qs, field_props, value: str): @@ -661,29 +687,34 @@ def filter_queryset_field_value(qs, field_props, value: str): return qs.filter(**condition) -def experiment_type_stats(queryset): +async def experiment_type_stats(queryset): """ returns count and bento_public format list of stats for experiment type note that queryset_stats_for_field() does not count "missing" correctly when the field has multiple foreign keys """ - e_types = queryset.values(label=F("phenopackets__biosamples__experiment__experiment_type")).annotate( - value=Count("phenopackets__biosamples__experiment", distinct=True)) - return bento_public_format_count_and_stats_list(e_types) + return await bento_public_format_count_and_stats_list( + queryset + .values(label=F("phenopackets__biosamples__experiment__experiment_type")) + .annotate(value=Count("phenopackets__biosamples__experiment", distinct=True)) + ) -def biosample_tissue_stats(queryset): +async def biosample_tissue_stats(queryset): """ returns count and bento_public format list of stats for biosample sampled_tissue """ - b_tissue = queryset.values(label=F("phenopackets__biosamples__sampled_tissue__label")).annotate( - value=Count("phenopackets__biosamples", distinct=True)) - return bento_public_format_count_and_stats_list(b_tissue) + return await bento_public_format_count_and_stats_list( + queryset + .values(label=F("phenopackets__biosamples__sampled_tissue__label")) + .annotate(value=Count("phenopackets__biosamples", distinct=True)) + ) -def 
bento_public_format_count_and_stats_list(annotated_queryset) -> tuple[int, list[BinWithValue]]: +async def bento_public_format_count_and_stats_list(annotated_queryset) -> tuple[int, list[BinWithValue]]: stats_list: list[BinWithValue] = [] - total = 0 - for q in annotated_queryset: + total: int = 0 + + async for q in annotated_queryset: label = q["label"] value = int(q["value"]) total += value