improvements to otu update and otu batch-update #133

Merged
merged 9 commits on Mar 26, 2025
ref_builder/otu/update.py — 199 changes: 98 additions & 101 deletions
@@ -1,4 +1,5 @@
import datetime
+from abc import ABC, abstractmethod
from collections import defaultdict
from collections.abc import Collection, Iterable, Iterator
from pathlib import Path
@@ -41,6 +42,15 @@
"""Assists in reading and writing the fetched accessions by taxid index from file."""


+class BaseBatchRecordGetter(ABC):
+    """An abstract class with a .get_records() method."""
+
+    @abstractmethod
+    def get_records(self, taxid: int) -> list[NCBIGenbank]:
+        """Return Genbank records corresponding to the given Taxonomy ID."""
+        return NotImplemented


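Note on the new abstraction: BaseBatchRecordGetter is the contract that both getter classes added below implement. A minimal sketch of a third implementation, useful in tests, assuming only the interface above (StaticRecordGetter and its records_by_taxid mapping are hypothetical, not part of this PR):

class StaticRecordGetter(BaseBatchRecordGetter):
    """Serve records from a fixed in-memory mapping (hypothetical example)."""

    def __init__(self, records_by_taxid: dict[int, list[NCBIGenbank]]):
        self.records_by_taxid = records_by_taxid

    def get_records(self, taxid: int) -> list[NCBIGenbank]:
        # Return a copy so callers cannot mutate the backing index.
        return list(self.records_by_taxid.get(taxid, []))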
def auto_update_otu(
    repo: Repo,
    otu: RepoOTU,
@@ -89,6 +99,48 @@ def auto_update_otu(
    return repo.get_otu(otu.id)


+class PrecachedRecordStore(BaseBatchRecordGetter):
+    """Retrieves records from an indexed dictionary of records and
+    a batch fetch index set at initialization.
+    """
+
+    def __init__(
+        self,
+        batch_fetch_index: dict[int, set[str]],
+        record_index: dict[str, NCBIGenbank],
+    ):
+        self.batch_fetch_index = batch_fetch_index
+        self.record_index = record_index
+
+    def get_records(self, taxid: int) -> list[NCBIGenbank]:
+        accessions = self.batch_fetch_index.get(taxid, [])
+
+        otu_records = [
+            record
+            for accession in accessions
+            if (record := self.record_index.get(accession)) is not None
+        ]
+
+        return otu_records
+
+
+class RecordFetcher(BaseBatchRecordGetter):
+    """Retrieves records from NCBI Nucleotide based on a batch fetch index
+    set at initialization.
+    """
+
+    def __init__(
+        self, batch_fetch_index: dict[int, set[str]], ignore_cache: bool = False
+    ):
+        self.batch_fetch_index = batch_fetch_index
+        self.ncbi = NCBIClient(ignore_cache)
+
+    def get_records(self, taxid: int) -> list[NCBIGenbank]:
+        accessions = self.batch_fetch_index.get(taxid, [])
+
+        return self.ncbi.fetch_genbank_records(accessions)

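Because both classes satisfy BaseBatchRecordGetter, batch_update_repo (below) can select a strategy once and the update loop only ever calls .get_records(taxid). A hedged sketch of that selection (_make_record_getter is a hypothetical name, not in this PR):

def _make_record_getter(
    batch_fetch_index: dict[int, set[str]],
    record_index: dict[str, NCBIGenbank] | None,
) -> BaseBatchRecordGetter:
    # Prefer the precached store when a record index was built up front;
    # otherwise fall back to live NCBI fetches per taxid.
    if record_index is not None:
        return PrecachedRecordStore(batch_fetch_index, record_index)
    return RecordFetcher(batch_fetch_index)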
def batch_update_repo(
    repo: Repo,
    start_date: datetime.date | None = None,
@@ -126,14 +178,14 @@ def batch_update_repo(
        else:
            otu_iterator = repo.iter_otus()

-        taxid_new_accession_index = batch_fetch_new_accessions(
+        batch_fetch_index = batch_fetch_new_accessions(
            otu_iterator,
            modification_date_start=start_date,
            ignore_cache=ignore_cache,
        )

        fetch_index_cache_path = _cache_fetch_index(
-            taxid_new_accession_index, repo.path / ".cache"
+            batch_fetch_index, repo.path / ".cache"
        )

        repo_logger.info("Fetch index cached", fetch_index_path=fetch_index_cache_path)
@@ -143,109 +195,75 @@
"Loading fetch index...", fetch_index_path=str(fetch_index_path)
)

taxid_new_accession_index = _load_fetch_index(fetch_index_path)
batch_fetch_index = _load_fetch_index(fetch_index_path)

if not taxid_new_accession_index:
if not batch_fetch_index:
logger.info("OTUs are up to date.")

return updated_otu_ids

fetch_set = {
accession
for otu_accessions in taxid_new_accession_index.values()
for accession in otu_accessions
}

logger.info(
"Batch fetch index contains potential new accessions.",
otu_count=len(taxid_new_accession_index),
otu_count=len(batch_fetch_index),
)

if precache_records:
fetch_set = {
accession
for otu_accessions in batch_fetch_index.values()
for accession in otu_accessions
}

logger.info("Precaching records...", accession_count=len(fetch_set))

indexed_records = batch_fetch_new_records(
record_index_by_accession = batch_fetch_new_records(
fetch_set,
chunk_size=chunk_size,
ignore_cache=ignore_cache,
)

if not indexed_records:
if not record_index_by_accession:
logger.info("No valid accessions found.")
return updated_otu_ids

logger.info(
"Checking new records against OTUs.",
otu_count=len(taxid_new_accession_index),
record_getter = PrecachedRecordStore(
batch_fetch_index, record_index_by_accession
)

-        for taxid, accessions in taxid_new_accession_index.items():
-            if (otu_id := repo.get_otu_id_by_taxid(taxid)) is None:
-                logger.debug("No corresponding OTU found in this repo", taxid=taxid)
-                continue
-
-            if skip_recently_updated and not _otu_is_cooled(
-                repo,
-                otu_id,
-                timestamp_current=operation_run_timestamp,
-            ):
-                logger.info(
-                    "This OTU was updated recently. Skipping...",
-                    cooldown=UPDATE_COOLDOWN_INTERVAL_IN_DAYS,
-                    otu_id=str(otu_id),
-                    taxid=str(taxid),
-                )
-                continue
-
-            otu_records = [
-                record
-                for accession in accessions
-                if (record := indexed_records.get(accession)) is not None
-            ]
-            if otu_records:
-                isolate_ids = _process_records_into_otu(
-                    repo, repo.get_otu(otu_id), otu_records
-                )
-
-                if isolate_ids:
-                    updated_otu_ids.add(otu_id)
-
-            repo.write_otu_update_history_entry(otu_id)
-
    else:
-        ncbi = NCBIClient(ignore_cache)
+        record_getter = RecordFetcher(batch_fetch_index)

-        for taxid, accessions in taxid_new_accession_index.items():
-            if (otu_id := repo.get_otu_id_by_taxid(taxid)) is None:
-                logger.debug("No corresponding OTU found in this repo", taxid=taxid)
-                continue
+    for taxid, accessions in batch_fetch_index.items():
+        if (otu_id := repo.get_otu_id_by_taxid(taxid)) is None:
+            logger.debug("No corresponding OTU found in this repo", taxid=taxid)
+            continue

-            if skip_recently_updated and not _otu_is_cooled(
-                repo,
-                otu_id,
-                timestamp_current=operation_run_timestamp,
-            ):
-                logger.info(
-                    "This OTU was updated recently. Skipping...",
-                    cooldown=UPDATE_COOLDOWN_INTERVAL_IN_DAYS,
-                    otu_id=str(otu_id),
-                    taxid=str(taxid),
-                )
-                continue
+        if skip_recently_updated and not _otu_is_cooled(
+            repo,
+            otu_id,
+            timestamp_current=operation_run_timestamp,
+        ):
+            logger.info(
+                "This OTU was updated recently. Skipping...",
+                cooldown=UPDATE_COOLDOWN_INTERVAL_IN_DAYS,
+                otu_id=str(otu_id),
+                taxid=str(taxid),
+            )
+            continue

-            otu_records = ncbi.fetch_genbank_records(accessions)
+        otu_records = record_getter.get_records(taxid)

-            if otu_records:
-                isolate_ids = _process_records_into_otu(
-                    repo, repo.get_otu(otu_id), otu_records
-                )
+        if otu_records:
+            isolate_ids = promote_and_update_otu_from_records(
+                repo, repo.get_otu(otu_id), otu_records
+            )

-                if isolate_ids:
-                    updated_otu_ids.add(otu_id)
+            if isolate_ids:
+                updated_otu_ids.add(otu_id)

-            repo.write_otu_update_history_entry(otu_id)
+        repo.write_otu_update_history_entry(otu_id)

-    repo_logger.info("Batch update complete.")
+    repo_logger.info("Batch update complete.", new_isolate_count=len(updated_otu_ids))

    return updated_otu_ids

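The batch fetch index maps a taxonomy ID to the set of accessions that may be new for that OTU. A minimal sketch of how such an index could round-trip through the cache, assuming a JSON layout; the real _cache_fetch_index and _load_fetch_index helpers live elsewhere in this module and may differ:

import json
from pathlib import Path

def _sketch_cache_fetch_index(index: dict[int, set[str]], path: Path) -> None:
    # JSON has neither int keys nor sets, so stringify keys and sort sets.
    path.write_text(
        json.dumps({str(taxid): sorted(accs) for taxid, accs in index.items()})
    )

def _sketch_load_fetch_index(path: Path) -> dict[int, set[str]]:
    # Invert the conversions done on write.
    return {
        int(taxid): set(accs) for taxid, accs in json.loads(path.read_text()).items()
    }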
@@ -418,20 +436,18 @@ def update_isolate_from_records(
    return isolate


-def _process_records_into_otu(
+def promote_and_update_otu_from_records(
    repo: Repo,
    otu: RepoOTU,
    records: list[NCBIGenbank],
):
    """Promote new RefSeq accessions and add new isolates."""
-    genbank_records, refseq_records = [], []

    for record in records:
        if record.refseq:
            refseq_records.append(record)

-        else:
-            genbank_records.append(record)

    if promote_otu_accessions_from_records(
        repo,
        otu=repo.get_otu(otu.id),
@@ -442,7 +458,7 @@ def _process_records_into_otu(
    new_isolate_ids = update_otu_with_records(
        repo,
        otu=otu,
-        records=genbank_records,
+        records=records,
    )

    repo.get_otu(otu.id)
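The rename from _process_records_into_otu makes the two-phase behaviour explicit: RefSeq records first promote existing accessions, then the records are offered as new isolates. A hedged usage sketch (repo, otu, and some_accessions are placeholders, not from this PR):

ncbi = NCBIClient(False)  # placeholder client setup; False = use the cache

records = ncbi.fetch_genbank_records(some_accessions)

if records:
    new_isolate_ids = promote_and_update_otu_from_records(repo, otu, records)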
@@ -472,7 +488,7 @@ def update_otu_with_accessions(
    records = ncbi.fetch_genbank_records(accessions)

    if records:
-        return _process_records_into_otu(repo, otu, records)
+        return promote_and_update_otu_from_records(repo, otu, records)


def update_otu_with_records(
@@ -637,25 +653,6 @@ def iter_fetch_list(
        yield fetch_list[iterator * page_size : (iterator + 1) * page_size]

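The visible yield shows iter_fetch_list paginating a fetch list. A hedged sketch of the equivalent chunking, assuming a (fetch_list, page_size) signature:

def _sketch_iter_fetch_list(fetch_list: list[str], page_size: int):
    # Yield consecutive pages of at most page_size accessions.
    for start in range(0, len(fetch_list), page_size):
        yield fetch_list[start : start + page_size]

# e.g. list(_sketch_iter_fetch_list(["A1", "A2", "A3"], 2)) -> [["A1", "A2"], ["A3"]]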

-def _bin_refseq_records(
-    records: list[NCBIGenbank],
-) -> tuple[list[NCBIGenbank], list[NCBIGenbank]]:
-    """Return a list of GenBank records as two lists, RefSeq and non-RefSeq."""
-    refseq_records = []
-    non_refseq_records = []
-
-    for record in records:
-        if record.refseq:
-            refseq_records.append(record)
-        else:
-            non_refseq_records.append(record)
-
-    if len(refseq_records) + len(non_refseq_records) != len(records):
-        raise ValueError("Invalid total number of records")
-
-    return refseq_records, non_refseq_records

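With this helper deleted, promote_and_update_otu_from_records now filters RefSeq records inline (see the binning loop earlier in this diff) and passes the full record list on to update_otu_with_records. The surviving half of the old split, as a one-liner:

refseq_records = [record for record in records if record.refseq]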

def _generate_datestamp_filename():
    """Get the current UTC date and return it as a filename-safe string."""
    timestamp = arrow.utcnow().naive
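The rest of _generate_datestamp_filename is collapsed in this view. A hedged sketch of a filename-safe UTC datestamp built from the arrow call above (the real helper's format string is not visible here):

def _sketch_datestamp_filename() -> str:
    # e.g. "20250326"; contains no characters that are unsafe in filenames.
    return arrow.utcnow().naive.strftime("%Y%m%d")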