Skip to content

Commit

Permalink
🎨 Make source_artifact_of and source_dataframe_of private (#80)
Browse files Browse the repository at this point in the history
  • Loading branch information
sunnyosun authored Aug 12, 2024
1 parent 025fdb7 commit b2ef550
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 45 deletions.
121 changes: 83 additions & 38 deletions bionty/core/_add_ontology.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,25 @@
from typing import Iterable, List, Optional, Set, Tuple, Type, Union

import pandas as pd
from lamin_utils import logger
from lnschema_core.models import Record

from bionty.models import BioRecord, Source


def get_all_ancestors(df: pd.DataFrame, ontology_ids: Iterable[str]) -> Set[str]:
ancestors = set()

def get_parents(onto_id: str) -> None:
stack = list(ontology_ids)
while stack:
onto_id = stack.pop()
try:
parents = df.at[onto_id, "parents"]
for parent in parents:
if parent not in ancestors:
ancestors.add(parent)
get_parents(parent)
stack.append(parent)
except KeyError:
print(f"Warning: Ontology ID {onto_id} not found in DataFrame")

for onto_id in ontology_ids:
get_parents(onto_id)

logger.warning(f"ontology ID {onto_id} not found in DataFrame")
return ancestors


Expand All @@ -33,16 +31,15 @@ def prepare_dataframe(df: pd.DataFrame) -> pd.DataFrame:


def get_new_ontology_ids(
registry: Type[BioRecord], ontology_ids: Iterable[str], df_all: pd.DataFrame
registry: Type[BioRecord], ontology_ids: Iterable[str], df: pd.DataFrame
) -> Tuple[Set[str], Set[str]]:
parents_ids = get_all_ancestors(df_all, ontology_ids)
ontology_ids = set(ontology_ids) | parents_ids
all_ontology_ids = set(ontology_ids) | get_all_ancestors(df, ontology_ids)
existing_ontology_ids = set(
registry.filter(ontology_id__in=ontology_ids).values_list(
registry.filter(ontology_id__in=all_ontology_ids).values_list(
"ontology_id", flat=True
)
)
return (ontology_ids - existing_ontology_ids), ontology_ids
return (all_ontology_ids - existing_ontology_ids), all_ontology_ids


def create_records(
Expand Down Expand Up @@ -70,45 +67,75 @@ def create_records(
def create_link_records(
registry: Type[BioRecord], df: pd.DataFrame, records: List[Record]
) -> List[Record]:
"""Create link records.
Args:
registry: The model class of the records.
df: The DataFrame with all ontology IDs and their parents.
records: All records of the ontology.
"""
source = records[0].source
linkorm = registry.parents.through
link_records = []
registry_name_lower = registry.__name__.lower()

# Create a dictionary for quick lookups
record_dict = {r.ontology_id: r for r in records if r.source_id == source.id}

for child_id, parents_ids in df["parents"].items():
if len(parents_ids) == 0:
continue
child_record = next(
(r for r in records if r.ontology_id == child_id and r.source == source),
None,
)
child_record = record_dict.get(child_id)
if not child_record:
continue
for parent_id in parents_ids:
parent_record = next(
(
r
for r in records
if r.ontology_id == parent_id and r.source == source
),
None,
)
parent_record = record_dict.get(parent_id)
if parent_record:
link_records.append(
linkorm(
**{
f"from_{registry.__name__.lower()}": child_record,
f"to_{registry.__name__.lower()}": parent_record,
f"from_{registry_name_lower}": child_record,
f"to_{registry_name_lower}": parent_record,
}
)
)
return link_records


def check_source_in_db(
registry: Type[BioRecord],
source: Source,
update: bool = False,
n_all: int = None,
n_in_db: int = None,
) -> None:
if n_all is None:
n_all = registry.public(source=source).df().shape[0]

if n_in_db is None:
# all records of the source in the database
n_in_db = registry.filter(source=source).count()
if n_in_db >= n_all:
# make sure in_db is set to True if all records are in the database
source.in_db = True
source.save()
if not update:
logger.warning(
f"{registry.__name__} records from source ({source.name}, {source.version}) are already in the database!\n → pass `update=True` to update the records"
)
return
else:
source.in_db = False
source.save()


def add_ontology_from_df(
registry: Type[BioRecord],
ontology_ids: Optional[List[str]] = None,
organism: Union[str, Record, None] = None,
source: Optional[Source] = None,
ignore_conflicts: bool = True,
update: bool = False,
):
import lamindb as ln

Expand All @@ -117,32 +144,48 @@ def add_ontology_from_df(
public = registry.public(organism=organism, source=source)
df = prepare_dataframe(public.df())

# TODO: consider StaticReference
source_record = get_source_record(public) # type:ignore

if ontology_ids is None:
df_new = df
df_all = df
df_new = df_all = df
else:
new_ontology_ids, all_ontology_ids = get_new_ontology_ids(
registry, ontology_ids, df
)
df_new = df[df.index.isin(new_ontology_ids)]
df_all = df[df.index.isin(all_ontology_ids)]

# TODO: consider StaticReference
source_record = get_source_record(public) # type:ignore
# do not create records from obsolete terms
records = [
r
for r in create_records(registry, df_new, source_record)
if not r.name.startswith("obsolete")
]
registry.objects.bulk_create(records, ignore_conflicts=ignore_conflicts)
df_all = df_all[~df_all["name"].str.startswith("obsolete")]

n_all = df_all.shape[0]
if n_all == 0:
raise ValueError("No valid records to add!")

# all records of the source in the database
all_records = registry.filter(source=source_record).all()
n_in_db = all_records.count()

check_source_in_db(
registry=registry,
source=source_record,
update=update,
n_all=n_all,
n_in_db=n_in_db,
)

if source_record.in_db and not update:
return

# do not create records from obsolete terms
records = create_records(registry, df_new, source_record)
registry.objects.bulk_create(records, ignore_conflicts=ignore_conflicts)

link_records = create_link_records(registry, df_all, all_records)
ln.save(link_records, ignore_conflicts=ignore_conflicts)

if ontology_ids is None and len(records) > 0:
if ontology_ids is None:
source_record.in_db = True
source_record.save()

Expand All @@ -152,6 +195,7 @@ def add_ontology(
organism: Union[str, Record, None] = None,
source: Optional[Source] = None,
ignore_conflicts: bool = True,
update: bool = False,
):
registry = records[0]._meta.model
source = source or records[0].source
Expand All @@ -164,4 +208,5 @@ def add_ontology(
organism=organism,
source=source,
ignore_conflicts=ignore_conflicts,
update=update,
)
32 changes: 32 additions & 0 deletions bionty/migrations/0036_alter_source_artifacts_and_more.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Generated by Django 5.2 on 2024-08-09 10:13

import django.db.models.deletion
from django.db import migrations, models


class Migration(migrations.Migration):
dependencies = [
("bionty", "0035_alter_protein_gene_symbol"),
("lnschema_core", "0060_alter_artifact__actions"),
]

operations = [
migrations.AlterField(
model_name="source",
name="artifacts",
field=models.ManyToManyField(
related_name="_source_artifact_of", to="lnschema_core.artifact"
),
),
migrations.AlterField(
model_name="source",
name="dataframe_artifact",
field=models.ForeignKey(
default=None,
null=True,
on_delete=django.db.models.deletion.PROTECT,
related_name="_source_dataframe_of",
to="lnschema_core.artifact",
),
),
]
15 changes: 8 additions & 7 deletions bionty/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ def import_from_source(
ontology_ids: list[str] | None = None,
organism: str | Record | None = None,
ignore_conflicts: bool = True,
update: bool = False,
):
"""Bulk save records from a dataframe.
Expand All @@ -160,15 +161,16 @@ def import_from_source(
Examples:
>>> bionty.CellType.import_from_source()
"""
if hasattr(cls, "ontology_id"):
from .core._add_ontology import add_ontology_from_df
from .core._add_ontology import add_ontology_from_df, check_source_in_db

if hasattr(cls, "ontology_id"):
add_ontology_from_df(
registry=cls,
ontology_ids=ontology_ids,
organism=organism,
source=source,
ignore_conflicts=ignore_conflicts,
update=update,
)
else:
import lamindb as ln
Expand All @@ -193,9 +195,8 @@ def import_from_source(
)
ln.save(records, ignore_conflicts=ignore_conflicts)

if ontology_ids is None and len(records) > 0:
source_record.in_db = True
source_record.save()
# make sure source.in_db is correctly set based on the DB content
check_source_in_db(registry=cls, source=source_record, update=update)

@classmethod
def public(
Expand Down Expand Up @@ -1413,11 +1414,11 @@ class Meta(BioRecord.Meta, TracksRun.Meta, TracksUpdates.Meta):
source_website: str | None = models.TextField(null=True, default=None)
"""Website of the source."""
dataframe_artifact: Artifact = models.ForeignKey(
Artifact, PROTECT, null=True, default=None, related_name="source_dataframe_of"
Artifact, PROTECT, null=True, default=None, related_name="_source_dataframe_of"
)
"""Dataframe artifact that corresponds to this source."""
artifacts: Artifact = models.ManyToManyField(
Artifact, related_name="source_artifact_of"
Artifact, related_name="_source_artifact_of"
)
"""Additional files that correspond to this source."""

Expand Down

0 comments on commit b2ef550

Please sign in to comment.