Skip to content

Commit

Permalink
parse_a3m_fast (dptech-corp#28)
Browse files Browse the repository at this point in the history
* parse_a3m_fast

* fix typo

* rewrite

* advance

* accel make msa feats

* change default fast

* fix

Co-authored-by: ziyao <[email protected]>
Co-authored-by: Ziyao Li <[email protected]>
  • Loading branch information
3 people authored Nov 22, 2022
1 parent 1cfcbfa commit 55a68f9
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 7 deletions.
40 changes: 39 additions & 1 deletion unifold/msa/parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def parse_stockholm(stockholm_string: str) -> Msa:
)


def parse_a3m(a3m_string: str) -> Msa:
def parse_a3m(a3m_string: str, fast: bool = True) -> Msa:
"""Parses sequences and deletion matrix from a3m format alignment.
Args:
Expand All @@ -174,6 +174,8 @@ def parse_a3m(a3m_string: str) -> Msa:
the aligned sequence i at residue position j.
* A list of descriptions, one per sequence, from the a3m file.
"""
if fast:
return parse_a3m_fast(a3m_string)
sequences, descriptions = parse_fasta(a3m_string)
deletion_matrix = []
for msa_sequence in sequences:
Expand All @@ -196,6 +198,42 @@ def parse_a3m(a3m_string: str) -> Msa:
descriptions=descriptions,
)

src_table = string.ascii_uppercase + "-"
dest_table = "".join([','] * len(src_table))

map_trans = str.maketrans(src_table, dest_table)
deletion_table = str.maketrans("", "", string.ascii_lowercase)

def parse_a3m_fast(a3m_string: str) -> Msa:
"""Parses sequences and deletion matrix from a3m format alignment.
Args:
a3m_string: The string contents of a a3m file. The first sequence in the
file should be the query sequence.
Returns:
A tuple of:
* A list of sequences that have been aligned to the query. These
might contain duplicates.
* The deletion matrix for the alignment as a list of lists. The element
at `deletion_matrix[i][j]` is the number of residues deleted from
the aligned sequence i at residue position j.
* A list of descriptions, one per sequence, from the a3m file.
"""
sequences, descriptions = parse_fasta(a3m_string)

aligned_sequences = [s.translate(deletion_table) for s in sequences]

comma_seqs = [s.translate(map_trans) for s in sequences]
deletion_matrix = [
[len(x) for x in s.split(",")] for s in comma_seqs
]

return Msa(
sequences=aligned_sequences,
deletion_matrix=deletion_matrix,
descriptions=descriptions,
)

def _convert_sto_seq_to_a3m(
query_non_gaps: Sequence[bool], sto_seq: str
Expand Down
19 changes: 13 additions & 6 deletions unifold/msa/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ def make_sequence_features(
features["sequence"] = np.array([sequence.encode("utf-8")], dtype=np.object_)
return features

hhblits_map_trans = str.maketrans(
''.join(residue_constants.HHBLITS_AA_TO_ID.keys()),
''.join(chr(v) for v in residue_constants.HHBLITS_AA_TO_ID.values())
)

def make_msa_features(msas: Sequence[parsers.Msa]) -> FeatureDict:
"""Constructs a feature dict of MSA features."""
Expand All @@ -67,7 +71,9 @@ def make_msa_features(msas: Sequence[parsers.Msa]) -> FeatureDict:
continue
seen_sequences.add(sequence)
int_msa.append(
[residue_constants.HHBLITS_AA_TO_ID[res] for res in sequence]
np.frombuffer(
(bytes(sequence.translate(hhblits_map_trans), "ascii")), dtype=np.uint8
)
)
deletion_matrix.append(msa.deletion_matrix[sequence_index])
identifiers = msa_identifiers.get_identifiers(
Expand All @@ -77,11 +83,12 @@ def make_msa_features(msas: Sequence[parsers.Msa]) -> FeatureDict:

num_res = len(msas[0].sequences[0])
num_alignments = len(int_msa)
features = {}
features["deletion_matrix_int"] = np.array(deletion_matrix, dtype=np.int32)
features["msa"] = np.array(int_msa, dtype=np.int32)
features["num_alignments"] = np.array([num_alignments] * num_res, dtype=np.int32)
features["msa_species_identifiers"] = np.array(species_ids, dtype=np.object_)
features = {
"deletion_matrix_int": np.array(deletion_matrix, dtype=np.int32),
"msa": np.stack(int_msa),
"num_alignments": np.array([num_alignments] * num_res, dtype=np.int32),
"msa_species_identifiers": np.array(species_ids, dtype=np.object_),
}
return features


Expand Down

0 comments on commit 55a68f9

Please sign in to comment.