From df47c5abf7c4d58f0656f4b6dbf613ada2f29bbe Mon Sep 17 00:00:00 2001 From: SwiftSeal Date: Mon, 13 Jan 2025 23:35:55 +0000 Subject: [PATCH] switch to concurrent futures --- resistify/nlrexpress.py | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/resistify/nlrexpress.py b/resistify/nlrexpress.py index 462727d..ebe70eb 100644 --- a/resistify/nlrexpress.py +++ b/resistify/nlrexpress.py @@ -4,7 +4,8 @@ import pickle import os import tempfile -from multiprocessing import cpu_count, get_context +from multiprocessing import cpu_count +import concurrent.futures from threadpoolctl import threadpool_limits import shutil import warnings @@ -128,7 +129,6 @@ def parse_jackhmmer(file, iteration=False): def nlrexpress(sequences, search_type, chunk_size, threads): - total_sequences = len(sequences) if threads is None: try: threads = len(os.sched_getaffinity(0)) @@ -147,28 +147,28 @@ def nlrexpress(sequences, search_type, chunk_size, threads): sequences[i : i + chunk_size] for i in range(0, len(sequences), chunk_size) ] - args = [(batch, jackhmmer_db.name, models) for batch in batches] - results = [] logger.info("Running NLRexpress - this could take a while...") - total_iterations = len(args) - iteration = 0 + iterations = 0 + total_iterations = len(batches) + results = [] # Need to use spawn otherwise the subprocesses will hang - with get_context("spawn").Pool(-(-threads // 2)) as pool: - for result in pool.imap(nlrexpress_subprocess, args): - results.append(result) - iteration += 1 - log_percentage(iteration, total_iterations) - - sequences = [seq for batch in results for seq in batch] - - if len(sequences) != total_sequences: + with concurrent.futures.ProcessPoolExecutor(max_workers = -(-threads // 2)) as executor: + futures = [executor.submit(nlrexpress_subprocess, (batch, jackhmmer_db.name, models)) for batch in batches] + + for future in concurrent.futures.as_completed(futures): + iterations += 1 + log_percentage(iterations, total_iterations) + for sequence in future.result(): + results.append(sequence) + + if len(results) != len(sequences): logger.critical( "Sequences dropped during NLRexpress - this should not happen and must be reported" ) - return sequences + return results def load_models(search_type):