Merge pull request #39 from SwiftSeal/logging
Update logging system
SwiftSeal authored Jan 8, 2025
2 parents 91257fd + b750b02 commit 3e4514a
Showing 11 changed files with 306 additions and 398 deletions.
5 changes: 2 additions & 3 deletions pyproject.toml
@@ -4,12 +4,11 @@ build-backend = "hatchling.build"
 
 [project]
 name = "resistify"
-version = "1.0.1"
+version = "1.1.0"
 dependencies = [
     "scikit-learn>=0.24.2",
     "numpy",
-    "rich",
-    "rich-argparse",
+    "loguru",
     "torch",
     "fair-esm",
     "transformers",
2 changes: 1 addition & 1 deletion resistify/__version__.py
@@ -1 +1 @@
-__version__ = "1.0.1"
+__version__ = "1.1.0"
31 changes: 31 additions & 0 deletions resistify/_loguru.py
@@ -0,0 +1,31 @@
+import os
+import sys
+from loguru import logger as _logger
+
+
+class LoggerDelegator:
+    def __init__(self, logger):
+        self._logger = logger
+        self._configure_logger()
+
+    def _configure_logger(self):
+        # Remove default handler
+        self._logger.remove()
+
+        # Add a new handler with a default format and dynamic level
+        log_level = os.getenv("LOG_LEVEL", "INFO")
+        self._logger.add(
+            sink=sys.stdout,
+            format="[{time:HH:mm:ss}] <level>{level: <8}</level> {message}",
+            level=log_level,
+        )
+
+    def update_level(self, level):
+        os.environ["LOG_LEVEL"] = level  # Set globally for spawned processes
+        self._configure_logger()  # Reconfigure logger in the main process
+
+    def __getattr__(self, attr):
+        return getattr(self._logger, attr)
+
+
+logger = LoggerDelegator(_logger)
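
Why route the log level through an environment variable instead of the logger object? With the multiprocessing "spawn" start method, a worker process re-imports resistify._loguru and builds a fresh LoggerDelegator, so only state that survives the re-import (like LOG_LEVEL) carries over. A minimal sketch of that behaviour (the worker scaffolding is illustrative, not part of this commit):

    import multiprocessing as mp

    def worker():
        # The child re-imports resistify._loguru, so _configure_logger() runs
        # again and reads LOG_LEVEL from the environment inherited at spawn.
        from resistify._loguru import logger
        logger.debug("visible because the parent set LOG_LEVEL=DEBUG")

    if __name__ == "__main__":
        from resistify._loguru import logger
        logger.update_level("DEBUG")  # sets os.environ["LOG_LEVEL"] for children
        ctx = mp.get_context("spawn")
        p = ctx.Process(target=worker)
        p.start()
        p.join()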
45 changes: 22 additions & 23 deletions resistify/annotations.py
@@ -1,6 +1,4 @@
-import logging
-
-log = logging.getLogger(__name__)
+from resistify._loguru import logger
 
 nlr_classifications = ["RNL", "CNL", "TNL", "RN", "CN", "TN", "NL", "N"]
 
@@ -190,9 +188,9 @@ def add_annotation(self, domain, source, start, end, evalue=None, score=None):
         start = int(start)
         end = int(end)
         if start > end:
-            log.error(f"Invalid annotation coordinates for {self.id}")
+            logger.error(f"Invalid annotation coordinates for {self.id}")
             return
-        log.debug(f"Adding annotation {domain} to {self.id} from {start} to {end}")
+        logger.debug(f"Adding annotation {domain} to {self.id} from {start} to {end}")
         self.annotations.append(Annotation(domain, start, end, evalue, score, source))
         self.annotations.sort(key=lambda x: x.start)
 
@@ -223,13 +221,11 @@ def identify_cc_domains(self):
             else:
                 # If we were in a dipping region and now the condition is false, record the region
                 if start is not None:
-                    log.debug(f"Adding CC domain in {self.id} from {start} to {end}")
                     self.add_annotation("CC", "coconat", start, end)
                     start = None  # Reset start for the next region
 
         # If we ended in a dip region, capture the final one
         if start is not None:
-            log.debug(f"Adding CC domain in {self.id} from {start} to {end}")
             self.add_annotation("CC", "coconat", start, end)
 
     def identify_lrr_domains(self, lrr_gap, lrr_length):
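
The hunk above drops the per-region debug messages but keeps the two-phase run detection in identify_cc_domains: open a region while the probability condition holds, record it once the condition fails, and flush any region still open when the loop ends. A minimal standalone sketch of that pattern (the 0.5 threshold and plain-list input are assumptions; the real condition sits in lines folded out of this diff):

    def contiguous_regions(probs, threshold=0.5):
        """Return (start, end) index pairs of runs where probs >= threshold."""
        regions, start, end = [], None, None
        for i, p in enumerate(probs):
            if p >= threshold:
                if start is None:
                    start = i
                end = i
            elif start is not None:
                regions.append((start, end))
                start = None
        if start is not None:
            regions.append((start, end))
        return regions

    assert contiguous_regions([0.1, 0.9, 0.8, 0.2, 0.7]) == [(1, 2), (4, 4)]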
@@ -272,7 +268,9 @@ def classify_nlr(self):
                 collapsed_domain_string.append(domain)
         collapsed_domain_string = "".join(collapsed_domain_string)
 
-        log.debug(f"Collapsed domain string for {self.id}: {collapsed_domain_string}")
+        logger.debug(
+            f"Collapsed domain string for {self.id}: {collapsed_domain_string}"
+        )
 
         # Absolutely mawkit, but catch RC collapsed string which will occur when coconat is applied to rpw8
         collapsed_domain_string = collapsed_domain_string.replace("RC", "R")
@@ -281,6 +279,7 @@
         for classification in nlr_classifications:
             if classification in collapsed_domain_string:
                 self.classification = classification
+                self.type = "NLR"
                 break
 
         # scavenge for missed classifications with motifs
@@ -291,25 +290,28 @@
             if annotation.domain == "NB-ARC":
                 nbarc_start = annotation.start
                 break
-        for motif in self.motifs["extEDVID"]:
-            if motif.position < nbarc_start:
-                self.add_annotation(
-                    "CC",
-                    "nlrexpress",
-                    motif.position,
-                    motif.position + 1,
-                )
-                self.classification = "C" + self.classification
-                continue
-
+        CC_motifs = [
+            motif
+            for motif in self.motifs["extEDVID"]
+            if motif.position < nbarc_start
+        ]
         TIR_motifs = [
             item
             for motif in TIR_MOTIFS
             for item in self.motifs[motif]
             if item.position < nbarc_start
         ]
-        # TIR motifs are pretty conserved, seems okay to take 1 as sufficient evidence
-        if len(TIR_motifs) > 0:
+
+        if len(CC_motifs) > 0:
+            self.add_annotation(
+                "CC",
+                "nlrexpress",
+                CC_motifs[0].position,
+                CC_motifs[-1].position,
+            )
+            self.classification = "C" + self.classification
+        elif len(TIR_motifs) > 0:
             TIR_motifs.sort(key=lambda x: x.position)
             self.add_annotation(
                 "TIR",
@@ -319,9 +321,6 @@ def classify_nlr(self):
             )
             self.classification = "T" + self.classification
 
-        if self.classification in nlr_classifications:
-            self.type = "NLR"
-
     def is_rlp(self, extracellular_length=50):
         tm_detected = False
         n_terminal_length = 0
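
Taken together, the classify_nlr hunks move the type = "NLR" assignment to the moment a classification is found, and replace the per-motif CC annotations with a single span running from the first to the last CC motif upstream of the NB-ARC domain. The substring matching itself is unchanged: nlr_classifications is ordered most-specific first, so the first hit wins. A standalone sketch of that ordering (the classify helper is illustrative, not from the codebase):

    nlr_classifications = ["RNL", "CNL", "TNL", "RN", "CN", "TN", "NL", "N"]

    def classify(collapsed_domain_string):
        # First match wins, so "CNL" is reported before the weaker "NL" or "N".
        for classification in nlr_classifications:
            if classification in collapsed_domain_string:
                return classification
        return None

    assert classify("CNL") == "CNL"  # CC-NB-ARC-LRR architecture
    assert classify("NL") == "NL"    # no N-terminal domain detected
    assert classify("RNL") == "RNL"  # RPW8-type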
123 changes: 54 additions & 69 deletions resistify/coconat.py
@@ -9,8 +9,9 @@
 import os
 import tempfile
 import warnings
+from resistify._loguru import logger
+from resistify.utility import log_percentage
 
-log = logging.getLogger(__name__)
 logging.getLogger("transformers").setLevel(logging.CRITICAL)
 
 warnings.filterwarnings("ignore", category=FutureWarning)
@@ -72,9 +73,9 @@ class EmbeddingProcessor:
     def __init__(self, models_path):
         # Initialize devices and models only once
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        if self.device == "cpu":
-            log.warning(
-                "GPU not available or detected, running on CPU. This will be slow..."
+        if self.device == torch.device("cpu"):
+            logger.warning(
+                "GPU not available or detected, running on CPU. This will be slower..."
             )
 
         if models_path is not None:
@@ -105,7 +106,7 @@ def __init__(self, models_path):
         self.esm_model.eval()
         self.batch_converter = self.esm_alphabet.get_batch_converter()
 
-    def process_prot_t5_embedding(self, sequences):
+    def process_prot_t5_embedding(self, sequence, length):
         """
         Compute ProtT5 embeddings for given sequences.
@@ -115,13 +116,10 @@ def process_prot_t5_embedding(self, sequences):
         Returns:
             list: List of numpy embeddings
         """
-        lengths = [len(sequence) for sequence in sequences]
-        sequences = [
-            " ".join(list(re.sub(r"[UZOB]", "X", sequence))) for sequence in sequences
-        ]
+        sequence = [" ".join(list(re.sub(r"[UZOB]", "X", sequence)))]
 
         ids = self.prot_t5_tokenizer.batch_encode_plus(
-            sequences, add_special_tokens=True, padding="longest"
+            sequence, add_special_tokens=True, padding="longest"
         )
         input_ids = torch.tensor(ids["input_ids"]).to(self.device)
         attention_mask = torch.tensor(ids["attention_mask"]).to(self.device)
@@ -132,15 +130,16 @@ def process_prot_t5_embedding(self, sequences):
         )
 
         embeddings = [
-            embedding_repr.last_hidden_state[i, : lengths[i]].detach().cpu().numpy()
-            for i in range(len(sequences))
+            embedding_repr.last_hidden_state[0, :length].detach().cpu().numpy()
         ]
 
         return embeddings
 
-    def process_esm_embedding(self, chunk_ids, chunk_seqs):
+    def process_esm_embedding(self, seq_id, seq):
         batch_labels, batch_strs, batch_tokens = self.batch_converter(
-            list(zip(chunk_ids, chunk_seqs))
+            [
+                (seq_id, seq),
+            ]
         )
 
         batch_lens = (batch_tokens != self.esm_alphabet.padding_idx).sum(1)
@@ -153,8 +152,7 @@ def process_esm_embedding(self, chunk_ids, chunk_seqs):
         token_representations = results["representations"][33]
 
         embeddings = [
-            token_representations[i, 1 : tokens_len - 1].detach().cpu().numpy()
-            for i, tokens_len in enumerate(batch_lens)
+            token_representations[0, 1 : batch_lens[0] - 1].detach().cpu().numpy()
         ]
 
         return embeddings
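
Both embedding paths now handle one sequence at a time instead of a five-sequence chunk. The ProtT5 preprocessing is unchanged: rare residues (U, Z, O, B) are mapped to X and the sequence is space-separated, because the tokenizer treats each residue as a separate word. A quick sketch with a made-up sequence:

    import re

    seq = "MKVZLLB"
    spaced = " ".join(list(re.sub(r"[UZOB]", "X", seq)))
    assert spaced == "M K V X L L X"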
@@ -165,64 +163,60 @@ def coconat(sequences, models_path: str):
     biocrf_path = os.path.join(os.path.dirname(__file__), "bin", "biocrf-static")
     crf_model = os.path.join(os.path.dirname(__file__), "data", "crfModel")
 
-    log.debug("Loading embedder...")
+    logger.debug("Loading embedder...")
     embedding_processor = EmbeddingProcessor(models_path)
-    log.debug("Loading registers model...")
+    logger.debug("Loading registers model...")
     checkpoint = torch.load(registers_model)
     register_model = MMModelLSTM()
     register_model.load_state_dict(checkpoint["state_dict"])
     register_model.eval()
 
-    for chunk_start in range(0, len(sequences), 5):
-        log.debug("Processing batch...")
-
-        chunk_ids, chunk_seqs, chunk_lengths = [], [], []
-        for sequence in sequences[chunk_start : chunk_start + 5]:
-            n_terminal_seq = sequence.nterminal_sequence
-            if n_terminal_seq is None:
-                log.debug(f"{sequence.id} has no N-terminus, skipping...")
-                continue
-            elif len(n_terminal_seq) < 5:
-                log.debug(f"{sequence.id} N-terminus too short for CoCoNat")
-                continue
-            elif len(n_terminal_seq) >= 1022:
-                log.warning(
-                    f"{sequence.id} N-terminus quite long, errors might occur..."
-                )
-
-            chunk_ids.append(sequence.id)
-            chunk_seqs.append(n_terminal_seq)
-            chunk_lengths.append(len(n_terminal_seq))
-
-        if len(chunk_ids) == 0:
+    temp_dir = tempfile.TemporaryDirectory()
+    prediction_file = os.path.join(temp_dir.name, "predictions")
+    output_file = os.path.join(temp_dir.name, "out")
+    prefix_path = os.path.join(temp_dir.name, "crf")
+
+    total_iterations = len(sequences)
+    iteration = 0
+    for sequence in sequences:
+        logger.debug(f"Processing {sequence.id}...")
+
+        nterminal_seq = sequence.nterminal_sequence
+        if nterminal_seq is None:
+            logger.debug(f"{sequence.id} has no N-terminus, skipping...")
+            continue
+        nterminal_len = len(nterminal_seq)
+        if nterminal_len < 5:
+            logger.debug(f"{sequence.id} N-terminus too short for CoCoNat")
+            continue
+        elif nterminal_len >= 1022:
+            logger.warning(
+                f"{sequence.id} N-terminus quite long, errors might occur..."
+            )
 
-        prot_t5_embeddings = embedding_processor.process_prot_t5_embedding(chunk_seqs)
+        prot_t5_embeddings = embedding_processor.process_prot_t5_embedding(
+            nterminal_seq, nterminal_len
+        )
         esm_embeddings = embedding_processor.process_esm_embedding(
-            chunk_ids, chunk_seqs
+            sequence.id, nterminal_seq
         )
 
+        logger.debug("Merging embeddings")
         merged = [
-            torch.from_numpy(np.hstack((prot_t5_embeddings[i], esm_embeddings[i])))
-            for i in range(len(chunk_ids))
+            torch.from_numpy(np.hstack((prot_t5_embeddings[0], esm_embeddings[0])))
         ]
 
         merged = torch.nn.utils.rnn.pad_sequence(merged, batch_first=True)
 
-        prediction = register_model(merged, chunk_lengths).detach().cpu().numpy()
-
-        temp_dir = tempfile.TemporaryDirectory()
-        prediction_file = os.path.join(temp_dir.name, "predictions")
-        output_file = os.path.join(temp_dir.name, "out")
-        prefix_path = os.path.join(temp_dir.name, "crf")
+        prediction = register_model(merged, [nterminal_len]).detach().cpu().numpy()
 
         with open(prediction_file, "w") as outfile:
-            for i in range(prediction.shape[0]):
-                for j in range(chunk_lengths[i]):
-                    prediction_values = " ".join([str(x) for x in prediction[i, j]])
-                    outfile.write(f"{prediction_values} i\n")
-                outfile.write("\n")
+            for i in range(nterminal_len):
+                prediction_values = " ".join([str(x) for x in prediction[0, i]])
+                outfile.write(f"{prediction_values} i\n")
+            outfile.write("\n")
 
+        logger.debug("Running biocrf")
         subprocess.run(
             [
                 f"{biocrf_path}",
@@ -244,19 +238,10 @@ def coconat(sequences, models_path: str):
             stderr=subprocess.PIPE,
         )
 
-        cc_probabilities = {}
-
-        for i, sequence_id in enumerate(chunk_ids):
-            log.debug(
-                f"Loading crf probabilities for {sequence_id} in {prefix_path}_{i}"
-            )
-            probability_matrix = np.loadtxt(f"{prefix_path}_{i}")
-            # extract first column
-            cc_probability = 1 - probability_matrix[:, 0]
-            cc_probabilities[sequence_id] = cc_probability
-
-        for sequence in sequences:
-            if sequence.id in cc_probabilities.keys():
-                sequence.cc_probs = cc_probabilities[sequence.id]
+        probability_matrix = np.loadtxt(f"{prefix_path}_0")
+        cc_probability = 1 - probability_matrix[:, 0]
+        sequence.cc_probs = cc_probability
+        iteration += 1
+        log_percentage(iteration, total_iterations)
 
     return sequences
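
The CRF output handling is per-sequence as well: biocrf now writes a single marginal-probability matrix (suffix _0), and since its first column appears to hold the probability of the non-coiled-coil state, one minus that column gives the per-residue coiled-coil probability stored on the sequence. A sketch with a fabricated four-residue, four-state matrix (the values and column layout are assumptions):

    import numpy as np

    probability_matrix = np.array([
        [0.9, 0.05, 0.03, 0.02],  # mostly background
        [0.2, 0.60, 0.10, 0.10],  # likely coiled-coil
        [0.1, 0.70, 0.10, 0.10],
        [0.8, 0.10, 0.05, 0.05],
    ])
    cc_probability = 1 - probability_matrix[:, 0]
    print(cc_probability)  # [0.1 0.8 0.9 0.2]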