switch to concurrent futures
SwiftSeal committed Jan 13, 2025
1 parent 7bb48d8 commit df47c5a
Showing 1 changed file with 16 additions and 16 deletions.
32 changes: 16 additions & 16 deletions resistify/nlrexpress.py
@@ -4,7 +4,8 @@
 import pickle
 import os
 import tempfile
-from multiprocessing import cpu_count, get_context
+from multiprocessing import cpu_count
+import concurrent.futures
 from threadpoolctl import threadpool_limits
 import shutil
 import warnings
@@ -128,7 +129,6 @@ def parse_jackhmmer(file, iteration=False):
 
 
 def nlrexpress(sequences, search_type, chunk_size, threads):
-    total_sequences = len(sequences)
     if threads is None:
         try:
             threads = len(os.sched_getaffinity(0))
@@ -147,28 +147,28 @@ def nlrexpress(sequences, search_type, chunk_size, threads):
         sequences[i : i + chunk_size] for i in range(0, len(sequences), chunk_size)
     ]
 
-    args = [(batch, jackhmmer_db.name, models) for batch in batches]
-
-    results = []
     logger.info("Running NLRexpress - this could take a while...")
 
-    total_iterations = len(args)
-    iteration = 0
+    iterations = 0
+    total_iterations = len(batches)
+    results = []
-    # Need to use spawn otherwise the subprocesses will hang
-    with get_context("spawn").Pool(-(-threads // 2)) as pool:
-        for result in pool.imap(nlrexpress_subprocess, args):
-            results.append(result)
-            iteration += 1
-            log_percentage(iteration, total_iterations)
-
-    sequences = [seq for batch in results for seq in batch]
-
-    if len(sequences) != total_sequences:
+    with concurrent.futures.ProcessPoolExecutor(max_workers = -(-threads // 2)) as executor:
+        futures = [executor.submit(nlrexpress_subprocess, (batch, jackhmmer_db.name, models)) for batch in batches]
+
+        for future in concurrent.futures.as_completed(futures):
+            iterations += 1
+            log_percentage(iterations, total_iterations)
+            for sequence in future.result():
+                results.append(sequence)
+
+    if len(results) != len(sequences):
         logger.critical(
             "Sequences dropped during NLRexpress - this should not happen and must be reported"
         )
 
-    return sequences
+    return results
 
 
 def load_models(search_type):
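For readers unfamiliar with the pattern adopted in this commit: concurrent.futures.ProcessPoolExecutor combined with as_completed yields each batch's result as soon as its worker process finishes, which is what drives the incremental log_percentage calls in the new code. Below is a minimal, self-contained sketch of the same pattern; process_batch, run_all, and the toy input are illustrative placeholders, not resistify code.

# Sketch of the ProcessPoolExecutor + as_completed pattern used in the commit above.
# process_batch stands in for nlrexpress_subprocess; the batches here are toy data.
import concurrent.futures


def process_batch(batch):
    # Placeholder per-batch work that returns a list of results.
    return [item * 2 for item in batch]


def run_all(batches, threads):
    results = []
    total = len(batches)
    done = 0
    # -(-threads // 2) is ceiling division: use roughly half the available threads as workers.
    with concurrent.futures.ProcessPoolExecutor(max_workers=-(-threads // 2)) as executor:
        futures = [executor.submit(process_batch, batch) for batch in batches]
        # as_completed returns futures in completion order, so progress can be reported
        # as each batch finishes rather than in submission order.
        for future in concurrent.futures.as_completed(futures):
            done += 1
            print(f"{done}/{total} batches complete")
            results.extend(future.result())
    return results


if __name__ == "__main__":
    print(run_all([[1, 2], [3, 4], [5, 6]], threads=4))

Because as_completed reports work in completion order, the progress log keeps moving even when batches take uneven amounts of time, which matches how the new loop in nlrexpress updates log_percentage.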
