Skip to content

Commit

Permalink
fix logger
Browse files Browse the repository at this point in the history
  • Loading branch information
SwiftSeal committed Jan 16, 2025
1 parent c792058 commit 9d50270
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 28 deletions.
2 changes: 1 addition & 1 deletion assets/resistify.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
"Google Colab is quite limited in the resources that it gives, so it will be a lot slower than running locally.\n",
"\n",
"In this example, we'll be downloading proteins from the potato genome and taking a random selection of ~1000.\n",
"If you'd like to use your own sequences, you can use the bar on the left to upload a file of your choice, modify the code to use your respective fasta file.\n",
"If you'd like to use your own sequences, you can use the bar on the left to upload a file of your choice, and modify the code to use your respective fasta file.\n",
"\n",
"To get started, use the bar at the top to start a session.\n",
"You'll want to select a GPU instance so that the PRR pipeline runs quickly!\n",
Expand Down
8 changes: 3 additions & 5 deletions resistify/coconat.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import tempfile
import warnings
from resistify._loguru import logger
from resistify.utility import log_percentage
from resistify.utility import ProgressLogger

logging.getLogger("transformers").setLevel(logging.CRITICAL)

Expand Down Expand Up @@ -176,8 +176,7 @@ def coconat(sequences, models_path: str):
output_file = os.path.join(temp_dir.name, "out")
prefix_path = os.path.join(temp_dir.name, "crf")

total_iterations = len(sequences)
iteration = 0
progress_logger = ProgressLogger(len(sequences))
for sequence in sequences:
logger.debug(f"Processing {sequence.id}...")

Expand Down Expand Up @@ -244,7 +243,6 @@ def coconat(sequences, models_path: str):
probability_matrix = np.loadtxt(f"{prefix_path}_0")
cc_probability = 1 - probability_matrix[:, 0]
sequence.cc_probs = cc_probability
iteration += 1
log_percentage(iteration, total_iterations)
progress_logger.update()

return sequences
8 changes: 3 additions & 5 deletions resistify/nlrexpress.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from threadpoolctl import threadpool_limits
import shutil
import warnings
from resistify.utility import log_percentage
from resistify.utility import ProgressLogger
from resistify._loguru import logger

# Version 1.3 of sklearn introduced InconsistentVersionWarning, fall back to UserWarning if not available
Expand Down Expand Up @@ -149,8 +149,7 @@ def nlrexpress(sequences, search_type, chunk_size, threads):

logger.info("Running NLRexpress - this could take a while...")

iterations = 0
total_iterations = len(batches)
progress_logger = ProgressLogger(len(batches))
results = []
with concurrent.futures.ProcessPoolExecutor(
max_workers=-(-threads // 2)
Expand All @@ -161,8 +160,7 @@ def nlrexpress(sequences, search_type, chunk_size, threads):
]

for future in concurrent.futures.as_completed(futures):
iterations += 1
log_percentage(iterations, total_iterations)
progress_logger.update()
for sequence in future.result():
results.append(sequence)

Expand Down
13 changes: 3 additions & 10 deletions resistify/tmbed.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from resistify.utility import log_percentage
from resistify.utility import ProgressLogger
import logging
import warnings
from resistify._loguru import logger
Expand Down Expand Up @@ -452,8 +452,7 @@ def tmbed(sequences, models_path):
"o": "outside",
}

total_iterations = len(sequences)
iteration = 0
progress_logger = ProgressLogger(len(sequences))
for sequence in sequences:
try:
logger.debug(f"Predicting transmembrane domains for {sequence.id}...")
Expand All @@ -464,11 +463,6 @@ def tmbed(sequences, models_path):
)
continue

# CPU alternative, implement fallback?
# encoder.to_cpu()
# torch.cuda.empty_cache()
# embeddings = encoder.embed(sequences)

embedding = embedding.to(device=device)
embedding = embedding.to(dtype=torch.float32)

Expand Down Expand Up @@ -511,7 +505,6 @@ def tmbed(sequences, models_path):
len(sequence.seq),
)

iteration += 1
log_percentage(iteration, total_iterations)
progress_logger.update()

return sequences
27 changes: 20 additions & 7 deletions resistify/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,26 @@
from resistify._loguru import logger


def log_percentage(n, total):
    """Log how far through ``total`` items the caller is after item ``n``.

    Totals below 10 are reported as "n of total complete" on every call.
    Otherwise the truncated integer percentage is logged, but only on calls
    where that percentage is a multiple of ten. The function keeps no state
    between calls, so the same percentage may be logged more than once.
    """
    if total < 10:
        logger.info(f"{n} of {total} complete")
        return

    percent_complete = int((n / total) * 100)
    # A non-zero remainder means this call is not on a 10% boundary.
    if percent_complete % 10:
        return
    logger.info(f"{percent_complete}% complete")
class ProgressLogger:
    """Count completed work items and log progress via ``logger.info``.

    Totals below 10 are reported item by item ("n of total complete").
    Larger totals are reported as a percentage, once per 10% band: a message
    is emitted the first time the completed percentage enters a new band.
    This holds even when no item count lands exactly on a multiple of ten
    (with a total of 13 the percentage is never exactly 10, 20, 40, ...).
    """

    def __init__(self, total_count):
        """Start tracking ``total_count`` expected calls to ``update``."""
        self.total_count = total_count
        self.current_count = 0
        # -1 floor-divides to band -1, below every real band, so the first
        # update of a percentage-mode run always reports.
        self.last_reported_percent = -1

    def update(self):
        """Record one completed item and log progress if a report is due."""
        self.current_count += 1
        if self.total_count < 10:
            # For small totals, report as "n of total".
            logger.info(f"{self.current_count} of {self.total_count} complete")
            return

        # Integer arithmetic gives the exact floored percentage; the float
        # form can truncate one too low (29 / 100 * 100 == 28.999...).
        percent_complete = self.current_count * 100 // self.total_count
        # Compare 10% bands rather than testing "% 10 == 0", so a milestone
        # that is jumped over is still reported exactly once.
        if percent_complete // 10 > self.last_reported_percent // 10:
            logger.info(f"{percent_complete}% complete")
            self.last_reported_percent = percent_complete


def create_output_directory(outdir):
Expand Down

0 comments on commit 9d50270

Please sign in to comment.