Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Guard against NoneType in mutation-sequence logic #73

Open
wants to merge 1 commit into
base: enhanced_hmm
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 105 additions & 20 deletions advntr/hmm_alignment.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,33 @@
from advntr.models import load_unique_vntrs_data
from advntr.models import load_unique_vntrs_data
from advntr.hmm_utils import get_repeating_unit_state_count

from collections import defaultdict, OrderedDict
from glob import glob


def get_first_match_state_position(visited_states):
    """
    Return the numeric index portion of the first match state ('M') from visited_states.

    The state name is split on '_'; the first field has its leading letter
    dropped and the remainder is parsed as an integer.

    :param visited_states: List of HMM states visited during alignment
    :return: The integer index from the first match state, or None if no
        state in visited_states starts with 'M'
    :raises ValueError: If the index field of the first match state is not an integer
    """
    for state in visited_states:
        if state.startswith("M"):
            # First '_'-separated field minus its leading 'M' is the index.
            return int(state.split("_")[0][1:])
    # No match state on the path: state the None result explicitly instead of
    # relying on the implicit fall-through return.
    return None


def get_aligned_read(ru_index, ru_sequence, seq, visited_states, mutation):
"""
Return an aligned read (as a string) for the given repeat unit (RU) index, RU sequence,
raw read sequence, visited states, and the associated mutation.

:param ru_index: The index of the repeat unit to align
:param ru_sequence: The reference repeat unit sequence
:param seq: The full read sequence
:param visited_states: List of HMM states visited during alignment
:param mutation: String describing the mutation
:return: A string representing the aligned region of the read
"""
seq_index = 0
aligned_seq = ""

Expand Down Expand Up @@ -56,14 +72,37 @@ def get_aligned_read(ru_index, ru_sequence, seq, visited_states, mutation):


def is_multiple_mutation(mutation):
    """
    Check if the mutation string includes multiple events, delimited by '&'.

    :param mutation: The mutation string (e.g., 'D12_2' or 'D12_2&I12_2_A_LEN1')
    :return: True if the string contains '&', otherwise False
    """
    # The membership test already yields a bool; no conditional wrapper needed.
    return "&" in mutation


def is_matching_state(state):
    """
    Check if a given state is a matching or insertion state.

    :param state: The HMM state string
    :return: True if the state name starts with 'M' or 'I', otherwise False
    """
    # str.startswith accepts a tuple of prefixes and returns a bool directly.
    return state.startswith(("M", "I"))


def get_emitted_basepair_from_visited_states(state, visited_states, sequence):
"""
Return the base that a state emits. For 'I' or 'M' states, this function
finds where the state appears among the visited states, counting only
those that emit bases, and uses that count to index into the sequence.

:param state: The HMM state (must start with 'I' or 'M')
:param visited_states: List of visited states in the alignment path
:param sequence: The read sequence
:return: The emitted base
:raises ValueError: If called for a 'D' state
"""
if state.startswith("D"):
raise ValueError("Deletion state doesn't emit base")
base_pair_idx = 0
Expand All @@ -76,6 +115,14 @@ def get_emitted_basepair_from_visited_states(state, visited_states, sequence):


def get_modified_base_count_for_reference(detected_mutation):
    """
    Return the integer number of inserted bases (as taken from the mutation string).

    The count is read from the digits that follow the last 'LEN' marker, so
    lengths of ten or more (e.g. '..._LEN12') are parsed in full rather than
    only their final digit. If the string does not end in 'LEN<digits>', the
    last character is parsed instead, which is what the previous
    implementation did for every input.

    :param detected_mutation: The mutation string (e.g., 'I12_2_A_LEN3')
    :return: The integer count of inserted bases
    :raises ValueError: If neither the 'LEN' suffix nor the last character is an integer
    :raises IndexError: If detected_mutation is empty
    """
    # IX_X_LENX: take everything after the final 'LEN' marker.
    _, marker, length_text = detected_mutation.rpartition("LEN")
    if marker and length_text.isdigit():
        return int(length_text)
    # No trailing LEN<digits>: keep the original last-character behaviour.
    return int(detected_mutation[-1])
# deletion_count = detected_mutation.count("D")
Expand All @@ -85,6 +132,17 @@ def get_modified_base_count_for_reference(detected_mutation):


def generate_aln(advntr_logfile, output_mutations=None, out_folder="", reference_vntr_db=None, ref_vntr_dict=None):
"""
Parse an advntr logfile, find specified mutations (or all if none given),
and generate an alignment file (.aln) showing references and reads for each mutation.

:param advntr_logfile: The path to the .log file produced by advntr
:param output_mutations: A set of mutations to focus on (or None for all)
:param out_folder: Where to place the .aln output
:param reference_vntr_db: Path to a SQLite DB containing reference VNTR data (optional if ref_vntr_dict is given)
:param ref_vntr_dict: A dictionary mapping VNTR IDs to VNTR objects (optional if reference_vntr_db is given)
:raises ValueError: If no reference input is given, or if there's no intersection between target mutations
"""
if reference_vntr_db is None and ref_vntr_dict is None:
raise ValueError("Either reference DB or dictionary should be given")
if reference_vntr_db is not None:
Expand Down Expand Up @@ -113,22 +171,31 @@ def generate_aln(advntr_logfile, output_mutations=None, out_folder="", reference
out.write("Target mutations {}\n".format(target_mutations))

# Find related sequences
from collections import defaultdict
vid_read_length = defaultdict(int)
vid_to_aln_info = defaultdict(lambda: defaultdict(list))

vid = None # <-- ensure 'vid' is initialized

with open(advntr_logfile, "r") as f:
for line in f:
if "INFO:Using read length" in line: # INFO:Using read length [read_length]
read_length = int(line.split(" ")[-1])
vid_read_length[vid] = read_length
if vid is not None:
vid_read_length[vid] = read_length

if "DEBUG:finding" in line: # DEBUG:finding repeat count from alignment file for [vid]
vid = int(line.split(" ")[-1])
reference_vntr = ref_vntr_dict[vid]
patterns = reference_vntr.get_repeat_segments()
pattern_clusters = [[pattern] * patterns.count(pattern) for pattern in sorted(list(set(patterns)))]

if "DEBUG:ReadName:" in line: # DEBUG:ReadName:[str]
read_name = line[30 + 9:].strip()

if "DEBUG:Read:" in line: # DEBUG:Read:[sequence]
sequence = line[30 + 5:].strip()

if "DEBUG:VisitedStates:" in line: # DEBUG:VisitedStates:['state1', ...]
visited = line[line.index('[') + 1:-2]
split = visited.split(', ')
Expand Down Expand Up @@ -193,24 +260,22 @@ def generate_aln(advntr_logfile, output_mutations=None, out_folder="", reference
continue
if ru_state_count['partial_start']['I'] != ru_state_count['partial_start']['D']:
if current_state.startswith('I'):
current_state += '_' + get_emitted_basepair_from_visited_states(current_state,
visited_states,
sequence)
current_state += '_' + get_emitted_basepair_from_visited_states(
current_state, visited_states, sequence)
mutation_count_temp[current_state] = mutation_count_temp.get(current_state, 0) + 1
continue

# Reads ending with a partially observed repeat unit
if current_repeat >= fully_observed_ru_count:
if current_repeat is not None and current_repeat >= fully_observed_ru_count:
if 'partial_end' in ru_state_count:
if ru_state_count['partial_end']['M'] < 5:
continue
if ru_state_count['partial_end']['S'] >= 4:
continue
if ru_state_count['partial_end']['I'] != ru_state_count['partial_end']['D']:
if current_state.startswith('I'):
current_state += '_' + get_emitted_basepair_from_visited_states(current_state,
visited_states,
sequence)
current_state += '_' + get_emitted_basepair_from_visited_states(
current_state, visited_states, sequence)
mutation_count_temp[current_state] = mutation_count_temp.get(current_state, 0) + 1
continue

Expand Down Expand Up @@ -240,8 +305,8 @@ def generate_aln(advntr_logfile, output_mutations=None, out_folder="", reference

# TODO If there are run of insertions, the sequence should be different
if current_state.startswith('I'):
current_state += '_' + get_emitted_basepair_from_visited_states(current_state, visited_states,
sequence)
current_state += '_' + get_emitted_basepair_from_visited_states(
current_state, visited_states, sequence)

mutation_count_temp[current_state] = mutation_count_temp.get(current_state, 0) + 1

Expand All @@ -258,7 +323,7 @@ def generate_aln(advntr_logfile, output_mutations=None, out_folder="", reference
if temp_mutation in target_mutations:
vid_to_aln_info[vid][temp_mutation].append((sequence, visited_states, read_name))
else:
sorted_temp_mutations = mutation_count_temp.items()
sorted_temp_mutations = list(mutation_count_temp.items())
prev_mutation = sorted_temp_mutations[0][0]
mutation_sequence = prev_mutation
if prev_mutation.startswith("I"):
Expand All @@ -282,8 +347,13 @@ def generate_aln(advntr_logfile, output_mutations=None, out_folder="", reference
if temp_mutation.startswith("D"):
# Case 1: D(i-1), D(i),
# In this case, the deletion is connected to the previous mutation sequence and skip
if prev_mutation_index + 1 == current_mutation_index and prev_hmm_index == current_hmm_index: # Only possible with D(i-1)
mutation_sequence += '&' + temp_mutation
if (prev_mutation_index + 1 == current_mutation_index
and prev_hmm_index == current_hmm_index):
# <-- FIX: handle None mutation_sequence
if mutation_sequence is None:
mutation_sequence = temp_mutation
else:
mutation_sequence += '&' + temp_mutation
# Case 2: I/D(j), D(i), j < i-1
# In this case, they are not connected (This should be rare, two separated deletions in a RU)
else:
Expand All @@ -296,7 +366,8 @@ def generate_aln(advntr_logfile, output_mutations=None, out_folder="", reference

if temp_mutation.startswith("I"):
# Case 3: D(i), I(i)
if prev_mutation_index == current_mutation_index and prev_hmm_index == current_hmm_index: # Only possible with D(i)
if (prev_mutation_index == current_mutation_index
and prev_hmm_index == current_hmm_index):
# Add the insertion and done
mutation_sequence += "&{}_LEN{}".format(temp_mutation,
mutation_count_temp[temp_mutation])
Expand Down Expand Up @@ -356,23 +427,37 @@ def generate_aln(advntr_logfile, output_mutations=None, out_folder="", reference
if mutation.startswith("I"):
mutation_position = int(mutation.split("_")[0][1:])
break
ru_sequence = ru_sequence[:mutation_position] + "-" * modified_base_count + ru_sequence[mutation_position:]
ru_sequence = (ru_sequence[:mutation_position]
+ "-" * modified_base_count
+ ru_sequence[mutation_position:])
else:
if detected_mutation.startswith("I"):
mutation_position = int(detected_mutation.split("_")[0][1:])
modified_base_count = get_modified_base_count_for_reference(detected_mutation)
ru_sequence = ru_sequence[:mutation_position] + "-" * modified_base_count + ru_sequence[mutation_position:]
ru_sequence = (ru_sequence[:mutation_position]
+ "-" * modified_base_count
+ ru_sequence[mutation_position:])

out.write("Ref RU: {}{}".format(ru_sequence, "\n"))
out.write("Ref RU: {}{}\n".format(ru_sequence, ""))

seq_state_readname_list = vid_to_aln_info[vid][detected_mutation]
for i, seq_state_readname in enumerate(seq_state_readname_list):
aligned_read = get_aligned_read(ru_index, ru_sequence, seq_state_readname[0], seq_state_readname[1], detected_mutation)
aligned_read = get_aligned_read(ru_index,
ru_sequence,
seq_state_readname[0],
seq_state_readname[1],
detected_mutation)
read_name = seq_state_readname[-1]
out.write("Read{:>2}: {}\t{}\n".format(i, aligned_read, read_name))


def get_samples_having_muc1_insertion(input_files):
"""
Return a list of samples (input files) containing a specific MUC1 insertion mutation.

:param input_files: List of advntr .log files
:return: A list of file names where the insertion 'I22_2_G_LEN1' was detected
"""
selected_samples = []
for input_file in input_files:
with open(input_file, "r") as f:
Expand Down
Loading