Skip to content

Commit

Permalink
created attribute subclasses for use across BIA packages (#294)
Browse files Browse the repository at this point in the history
* created attribute subclasses for use across BIA packages

* removed to_pascal
  • Loading branch information
sherwoodf authored Feb 5, 2025
1 parent b6d7a74 commit fec017b
Show file tree
Hide file tree
Showing 7 changed files with 246 additions and 22 deletions.
24 changes: 14 additions & 10 deletions bia-export/bia_export/website_export/studies/retrieve.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@
retrieve_object,
)
from pathlib import Path
from pydantic import ValidationError
from pydantic.alias_generators import to_snake
from .models import StudyCLIContext, CacheUse
from bia_shared_datamodels import semantic_models, bia_data_model
from bia_shared_datamodels import bia_data_model, attribute_models
from bia_integrator_api import models as api_models
import json
from typing import List, Type
Expand Down Expand Up @@ -266,14 +267,17 @@ def retrieve_detail_objects(
}

for attribute in dataset.attribute:
if attribute.name in attribute_name_type_map:
for uuid in attribute.value[attribute.name]:
# retrieve_object handles whether to retrieve from file or from api
api_object = retrieve_object(
uuid, attribute_name_type_map[attribute.name], context
)
detail_fields[attribute_name_type_map[attribute.name]].append(
api_object
)
try:
attribute_models.DatasetAssociatedUUIDAttribute.model_validate(attribute.model_dump())
except ValidationError:
continue
for uuid in attribute.value[attribute.name]:
# retrieve_object handles whether to retrieve from file or from api
api_object = retrieve_object(
uuid, attribute_name_type_map[attribute.name], context
)
detail_fields[attribute_name_type_map[attribute.name]].append(
api_object
)

return detail_fields
5 changes: 3 additions & 2 deletions bia-ingest/bia_ingest/biostudies/generic_conversion_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from ..config import settings, api_client
from ..cli_logging import IngestionResult, log_failed_model_creation
import bia_integrator_api.models as api_models
from bia_shared_datamodels import attribute_models

logger = logging.getLogger("__main__." + __name__)

Expand Down Expand Up @@ -64,9 +65,9 @@ def get_generic_section_as_dict(
key_mapping: List[Tuple[str, str, Union[str, None, List]]],
mapped_object: Optional[BaseModel] = None,
valdiation_error_tracking: Optional[IngestionResult] = None,
) -> Dict[str, Dict[str, str|List[str]] | BaseModel]:
) -> Dict[str, Dict[str, str | List[str]] | BaseModel]:
"""
Map biostudies.Submission objects to dict or an object
Map biostudies.Submission objects to dict or an object
"""
if type(root) is Submission:
sections = find_sections_recursive(root.section, section_name, [])
Expand Down
19 changes: 13 additions & 6 deletions bia-ingest/bia_ingest/biostudies/v4/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import List, Optional
from uuid import UUID

from bia_ingest.bia_object_creation_utils import dict_to_api_model, dicts_to_api_models
from bia_ingest.bia_object_creation_utils import dicts_to_api_models
from bia_ingest.persistence_strategy import PersistenceStrategy


Expand All @@ -16,7 +16,7 @@
find_sections_recursive,
)

from bia_shared_datamodels import bia_data_model, semantic_models
from bia_shared_datamodels import bia_data_model, semantic_models, attribute_models
from bia_shared_datamodels.uuid_creation import create_dataset_uuid

logger = logging.getLogger("__main__." + __name__)
Expand Down Expand Up @@ -49,6 +49,7 @@ def get_dataset(

return datasets


def get_dataset_dict_from_study_component(
submission: Submission,
study_uuid: UUID,
Expand All @@ -73,7 +74,7 @@ def get_dataset_dict_from_study_component(
attribute_list = []
if len(associations) > 0:
attribute_list.append(
semantic_models.Attribute.model_validate(
attribute_models.DatasetAssociationAttribute.model_validate(
{
"provenance": semantic_models.AttributeProvenance("bia_ingest"),
"name": "associations",
Expand All @@ -99,7 +100,7 @@ def get_dataset_dict_from_study_component(
"correlation_method": correlation_method_list,
"example_image_uri": [],
"version": 0,
"attribute": attribute_list,
"attribute": attribute_list
}

model_dicts.append(model_dict)
Expand Down Expand Up @@ -230,7 +231,10 @@ def get_uuid_attribute_from_associations(
}
)

return [semantic_models.Attribute.model_validate(x) for x in attribute_dicts]
return [
attribute_models.DatasetAssociatedUUIDAttribute.model_validate(x)
for x in attribute_dicts
]


def store_annotation_method_in_attribute(
Expand All @@ -254,7 +258,10 @@ def store_annotation_method_in_attribute(
raise RuntimeError(
"Dataset cannot find Annotation Method that should have been created"
)
return attribute_dicts
return [
attribute_models.DatasetAssociatedUUIDAttribute.model_validate(x)
for x in attribute_dicts
]


def get_image_analysis_method_from_associations(
Expand Down
20 changes: 16 additions & 4 deletions bia-ingest/bia_ingest/cli_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class IngestionResult(CLIResult):
ProcessingVersion: BioStudiesProcessingVersion = (
BioStudiesProcessingVersion.FALLBACK
)

Study_CreationCount: int = Field(default=0)
Study_ValidationErrorCount: int = Field(default=0)
Contributor_CreationCount: int = Field(default=0)
Expand Down Expand Up @@ -64,6 +64,12 @@ class IngestionResult(CLIResult):
AnnotationMethod_CreationCount: int = Field(default=0)
AnnotationMethod_ValidationErrorCount: int = Field(default=0)

DatasetAssociationAttribute_CreationCount: int = Field(default=0)
DatasetAssociationAttribute_ValidationErrorCount: int = Field(default=0)

DatasetAssociatedUUIDAttribute_CreationCount: int = Field(default=0)
DatasetAssociatedUUIDAttribute_ValidationErrorCount: int = Field(default=0)

ImageAnalysisMethod_CreationCount: int = Field(default=0)
ImageAnalysisMethod_ValidationErrorCount: int = Field(default=0)

Expand Down Expand Up @@ -119,11 +125,11 @@ def tabulate_ingestion_errors(

if result.Uncaught_Exception:
error_message += f"Uncaught exception: {result.Uncaught_Exception}"

warning_message = ""
if result.Warning:
warning_message = result.Warning

if (error_message == "") & (warning_message == ""):
status = Text("Success")
status.stylize("green")
Expand All @@ -145,7 +151,13 @@ def tabulate_ingestion_errors(
error_message = Text(error_message)
error_message.stylize("red")

row_info = [accession_id_key, result.ProcessingVersion, status, error_message, warning_message]
row_info = [
accession_id_key,
result.ProcessingVersion,
status,
error_message,
warning_message,
]

if include_object_count:
for header in headers[5:]:
Expand Down
109 changes: 109 additions & 0 deletions bia-shared-datamodels/src/bia_shared_datamodels/attribute_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
from __future__ import annotations

from typing import List, Optional, Any
from typing_extensions import Self

from pydantic import BaseModel, Field, model_validator, field_validator, ValidationError
from .semantic_models import Attribute, AttributeProvenance
from uuid import UUID

# For shared models that are placed inside Attribute objects
# The API does not enforce what is created inside an Attribute object's value (beyond it being a dictionary).
# This allows us to have additional freeform information from submitters.
# The models here are for use by BIA internal code packages that require a shared source of truth.


class SubAttributeMixin(BaseModel):
    """
    Mixin that lets a specialised attribute model compare equal to a plain Attribute.
    """

    def __eq__(self, other):
        # Anything whose class is not exactly Attribute is handled by the
        # standard BaseModel comparison.
        if other.__class__ != Attribute:
            return BaseModel.__eq__(self, other)

        # Rebuild this object as a plain Attribute from its dumped data, then
        # let the Attribute comparison decide.
        as_plain_attribute = Attribute.model_validate(self.model_dump())
        return as_plain_attribute.__eq__(other)


class DatasetAssociationValue(BaseModel):
    # One entry of the "associations" list held in a DatasetAssociationAttribute value.
    # Every field accepts None, but none has a default: each key must be
    # supplied explicitly (the Ellipsis marks the field as required).
    image_analysis: Optional[str] = Field(...)
    image_correlation: Optional[str] = Field(...)
    biosample: Optional[str] = Field(...)
    image_acquisition: Optional[str] = Field(...)
    specimen: Optional[str] = Field(...)


class DatasetAssociationAttribute(Attribute, SubAttributeMixin):
    """
    Model for storing user provided Associations from biostudies in an Attribute on a dataset.

    Validation requires the provenance to be bia_ingest, the name to be
    "associations", and the value to be a dictionary whose only key is
    "associations", mapping to a list of DatasetAssociationValue entries.
    """

    value: dict[str, list[DatasetAssociationValue]] = Field()

    @field_validator("provenance", mode="after")
    @classmethod
    def attribute_provenance(cls, value: AttributeProvenance) -> AttributeProvenance:
        """Accept only the bia_ingest provenance."""
        if value == AttributeProvenance.bia_ingest:
            return value
        raise ValueError(
            f"Provenance for this type of attribute must be {AttributeProvenance.bia_ingest}"
        )

    @field_validator("name", mode="after")
    @classmethod
    def attribute_name(cls, value: str) -> str:
        """Accept only the fixed attribute name 'associations'."""
        if value == "associations":
            return value
        raise ValueError(
            "name field for this type of attribute must be 'associations'"
        )

    @field_validator("value", mode="after")
    @classmethod
    def attribute_value(cls, value: dict) -> dict:
        """Require the value dictionary to hold the single key 'associations'."""
        # The key-count check runs first, so a dictionary with the wrong
        # number of keys always reports the count problem.
        if len(value) != 1:
            raise ValueError("Value dictionary should have exactly one 1 key")
        if "associations" not in value:
            raise ValueError('The value dictionary key must be "associations"')
        return value


class DatasetAssociatedUUIDAttribute(Attribute, SubAttributeMixin):
    """
    Model for storing uuid of objects linked to a dataset, for use in code downstream of ingest.

    Validation requires the provenance to be bia_ingest, the name to be one of
    a fixed set of "*_uuid" names, and the value to be a dictionary whose only
    key equals the attribute name, mapping to a list of strings.
    """

    value: dict[str, list[str]] = Field()

    @field_validator("provenance", mode="after")
    @classmethod
    def attribute_provenance(cls, value: AttributeProvenance) -> AttributeProvenance:
        """Accept only the bia_ingest provenance."""
        if value != AttributeProvenance.bia_ingest:
            raise ValueError(
                f"Provenance for this type of attribute must be {AttributeProvenance.bia_ingest}"
            )
        return value

    @field_validator("name", mode="after")
    @classmethod
    def validate_attribute_name(cls, value: str) -> str:
        """Accept only the names recognised for associated-uuid attributes.

        This is a field validator: it receives and returns the name string,
        not the model instance.
        """
        valid_associations = [
            "image_acquisition_protocol_uuid",
            "specimen_imaging_preparation_protocol_uuid",
            "bio_sample_uuid",
            "annotation_method_uuid",
            "protocol_uuid",
        ]
        if value not in valid_associations:
            raise ValueError(
                f"Name for this type of attribute must be one of: {valid_associations}"
            )

        return value

    @model_validator(mode="after")
    def validate_attribute_value_key(self) -> Self:
        """Require the value dictionary to hold exactly one key, equal to the name.

        Declared as a model validator because it compares two fields
        (name and value) of the same instance.
        """
        if len(self.value) != 1:
            raise ValueError("Value dictionary should have exactly one 1 key")
        if self.name not in self.value:
            raise ValueError(f"The key for this type of attribute must be {self.name}")
        return self
32 changes: 32 additions & 0 deletions bia-shared-datamodels/src/bia_shared_datamodels/mock_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,3 +497,35 @@ def get_attribute_dict(completeness=Completeness.COMPLETE) -> dict:
"value": {},
}
return attribute


def get_dataset_associated_uuid_attribute(completeness=Completeness.COMPLETE) -> dict:
    """
    Return a mock attribute dict named "protocol_uuid" that references the mock protocol.

    The completeness argument is accepted but not used: the same dictionary is
    returned for every completeness level.
    """
    protocol_uuid = str(get_protocol_dict()["uuid"])
    return {
        "provenance": semantic_models.AttributeProvenance.bia_ingest,
        "name": "protocol_uuid",
        "value": {"protocol_uuid": [protocol_uuid]},
    }


def get_dataset_associatation_attribute(completeness=Completeness.COMPLETE) -> dict:
    """
    Return a mock "associations" attribute dict.

    For Completeness.COMPLETE the associations list holds one fully populated
    entry; for any other completeness level the list is empty.
    """
    association_entries = []
    if completeness == Completeness.COMPLETE:
        association_entries.append(
            {
                "image_analysis": "image_analysis_title",
                "image_correlation": "image_correlation_title",
                "biosample": "biosample_title",
                # NOTE(review): this reuses "biosample_title"; possibly meant
                # to be an image acquisition title - confirm before changing.
                "image_acquisition": "biosample_title",
                "specimen": "specimen_title",
            }
        )

    return {
        "provenance": semantic_models.AttributeProvenance.bia_ingest,
        "name": "associations",
        "value": {"associations": association_entries},
    }
59 changes: 59 additions & 0 deletions bia-shared-datamodels/test/test_attribute_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import pytest
from pydantic import ValidationError
from bia_shared_datamodels import (
semantic_models,
mock_objects,
attribute_models,
)
from typing import Callable


@pytest.mark.parametrize(
    ("expected_model_type", "dict_creation_func"),
    (
        (
            attribute_models.DatasetAssociatedUUIDAttribute,
            mock_objects.get_dataset_associated_uuid_attribute,
        ),
        (
            attribute_models.DatasetAssociationAttribute,
            mock_objects.get_dataset_associatation_attribute,
        ),
    ),
)
def test_sub_attribute_models(
    expected_model_type: type[semantic_models.Attribute],
    dict_creation_func: Callable[[mock_objects.Completeness], dict],
):
    """
    Check that each specialised attribute model agrees with the base Attribute.

    For both completeness levels, the same mock dictionary is validated as a
    base Attribute and as the specialised model; the two must dump to the same
    data and compare equal. A generic attribute dictionary must be rejected by
    the specialised model.
    """
    model_completeness_list = [
        mock_objects.Completeness.COMPLETE,
        mock_objects.Completeness.MINIMAL,
    ]

    for model_completion in model_completeness_list:
        model_dict = dict_creation_func(model_completion)

        attribute_model = semantic_models.Attribute.model_validate(model_dict)
        sub_attribute_model = expected_model_type.model_validate(model_dict)

        # We have modified the __eq__ function, so it is good to check:
        assert sub_attribute_model == sub_attribute_model

        assert attribute_model.model_dump() == sub_attribute_model.model_dump()

        assert attribute_model == sub_attribute_model

        assert attribute_model == semantic_models.Attribute.model_validate(
            sub_attribute_model
        )

        # Check basic attribute doesn't pass validation.
        # pytest.raises fails the test when nothing is raised, and unlike
        # `assert False` it is not stripped when Python runs with -O.
        with pytest.raises(ValidationError):
            expected_model_type.model_validate(
                mock_objects.get_attribute_dict(model_completion)
            )

0 comments on commit fec017b

Please sign in to comment.