From 9d50270005b188cdb1c234c0607c77fbb07aed21 Mon Sep 17 00:00:00 2001 From: SwiftSeal Date: Thu, 16 Jan 2025 00:25:18 +0000 Subject: [PATCH] fix logger --- assets/resistify.ipynb | 2 +- resistify/coconat.py | 8 +++----- resistify/nlrexpress.py | 8 +++----- resistify/tmbed.py | 13 +++---------- resistify/utility.py | 27 ++++++++++++++++++++------- 5 files changed, 30 insertions(+), 28 deletions(-) diff --git a/assets/resistify.ipynb b/assets/resistify.ipynb index 8d09f7f..77ffc36 100644 --- a/assets/resistify.ipynb +++ b/assets/resistify.ipynb @@ -10,7 +10,7 @@ "Google Colab is quite limited in the resources that it gives, so it will be a lot slower than running locally.\n", "\n", "In this example, we'll be downloading proteins from the potato genome and taking a random selection of ~1000.\n", - "If you'd like to use your own sequences, you can use the bar on the left to upload a file of your choice, modify the code to use your respective fasta file.\n", + "If you'd like to use your own sequences, you can use the bar on the left to upload a file of your choice, and modify the code to use your respective fasta file.\n", "\n", "To get started, use the bar at the top to start a session.\n", "You'll want to select a GPU instance so that the PRR pipeline runs quickly!\n", diff --git a/resistify/coconat.py b/resistify/coconat.py index c2107ce..8bcd4c6 100644 --- a/resistify/coconat.py +++ b/resistify/coconat.py @@ -10,7 +10,7 @@ import tempfile import warnings from resistify._loguru import logger -from resistify.utility import log_percentage +from resistify.utility import ProgressLogger logging.getLogger("transformers").setLevel(logging.CRITICAL) @@ -176,8 +176,7 @@ def coconat(sequences, models_path: str): output_file = os.path.join(temp_dir.name, "out") prefix_path = os.path.join(temp_dir.name, "crf") - total_iterations = len(sequences) - iteration = 0 + progress_logger = ProgressLogger(len(sequences)) for sequence in sequences: logger.debug(f"Processing {sequence.id}...") @@ -244,7 +243,6 @@ def coconat(sequences, models_path: str): probability_matrix = np.loadtxt(f"{prefix_path}_0") cc_probability = 1 - probability_matrix[:, 0] sequence.cc_probs = cc_probability - iteration += 1 - log_percentage(iteration, total_iterations) + progress_logger.update() return sequences diff --git a/resistify/nlrexpress.py b/resistify/nlrexpress.py index 5acc054..af0e692 100644 --- a/resistify/nlrexpress.py +++ b/resistify/nlrexpress.py @@ -9,7 +9,7 @@ from threadpoolctl import threadpool_limits import shutil import warnings -from resistify.utility import log_percentage +from resistify.utility import ProgressLogger from resistify._loguru import logger # Version 1.3 of sklearn introduced InconsistentVersionWarning, fall back to UserWarning if not available @@ -149,8 +149,7 @@ def nlrexpress(sequences, search_type, chunk_size, threads): logger.info("Running NLRexpress - this could take a while...") - iterations = 0 - total_iterations = len(batches) + progress_logger = ProgressLogger(len(batches)) results = [] with concurrent.futures.ProcessPoolExecutor( max_workers=-(-threads // 2) @@ -161,8 +160,7 @@ def nlrexpress(sequences, search_type, chunk_size, threads): ] for future in concurrent.futures.as_completed(futures): - iterations += 1 - log_percentage(iterations, total_iterations) + progress_logger.update() for sequence in future.result(): results.append(sequence) diff --git a/resistify/tmbed.py b/resistify/tmbed.py index 4485708..173f9d3 100644 --- a/resistify/tmbed.py +++ b/resistify/tmbed.py @@ -20,7 +20,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from resistify.utility import log_percentage +from resistify.utility import ProgressLogger import logging import warnings from resistify._loguru import logger @@ -452,8 +452,7 @@ def tmbed(sequences, models_path): "o": "outside", } - total_iterations = len(sequences) - iteration = 0 + progress_logger = ProgressLogger(len(sequences)) for sequence in sequences: try: logger.debug(f"Predicting transmembrane domains for {sequence.id}...") @@ -464,11 +463,6 @@ def tmbed(sequences, models_path): ) continue - # CPU alternative, implement fallback? - # encoder.to_cpu() - # torch.cuda.empty_cache() - # embeddings = encoder.embed(sequences) - embedding = embedding.to(device=device) embedding = embedding.to(dtype=torch.float32) @@ -511,7 +505,6 @@ def tmbed(sequences, models_path): len(sequence.seq), ) - iteration += 1 - log_percentage(iteration, total_iterations) + progress_logger.update() return sequences diff --git a/resistify/utility.py b/resistify/utility.py index 66b6ca7..c031550 100644 --- a/resistify/utility.py +++ b/resistify/utility.py @@ -9,13 +9,26 @@ from resistify._loguru import logger -def log_percentage(n, total): - if total < 10: - logger.info(f"{n} of {total} complete") - else: - percent_complete = int((n / total) * 100) - if percent_complete % 10 == 0: - logger.info(f"{percent_complete}% complete") +class ProgressLogger: + def __init__(self, total_count): + self.total_count = total_count + self.current_count = 0 + self.last_reported_percent = -1 # Initialize with an invalid percentage + + def update(self): + self.current_count += 1 + if self.total_count < 10: + # For small totals, report as "n of total" + logger.info(f"{self.current_count} of {self.total_count} complete") + else: + # Calculate percentage + percent_complete = int((self.current_count / self.total_count) * 100) + if ( + percent_complete % 10 == 0 + and percent_complete > self.last_reported_percent + ): + logger.info(f"{percent_complete}% complete") + self.last_reported_percent = percent_complete def create_output_directory(outdir):