-
Notifications
You must be signed in to change notification settings - Fork 190
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Improve colabfold MSAs to include unpaired MSA hits (#213)
* Typo in logging message * Call colabfold mmseqs server once for paired MSAs, and once for unpaired MSAs * Update comment * Fix issue with monomers and constructing source databases * Avoid using string literals * Address PR feedback
- Loading branch information
Showing
2 changed files
with
94 additions
and
28 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,8 +16,9 @@ | |
from tqdm import tqdm | ||
|
||
from chai_lab import __version__ | ||
from chai_lab.data.parsing.fasta import read_fasta | ||
from chai_lab.data.parsing.fasta import Fasta, read_fasta | ||
from chai_lab.data.parsing.msas.aligned_pqt import expected_basename, hash_sequence | ||
from chai_lab.data.parsing.msas.data_source import MSADataSource | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
@@ -26,7 +27,7 @@ | |
) | ||
|
||
|
||
# N.B. this code is copied from https://github.com/sokrypton/ColabFold | ||
# N.B. this function (and this function only) is copied from https://github.com/sokrypton/ColabFold | ||
# and follows the license in that repository | ||
@typing.no_type_check # Original ColabFold code was not well typed | ||
def _run_mmseqs2( | ||
|
@@ -41,6 +42,7 @@ def _run_mmseqs2( | |
host_url="https://api.colabfold.com", | ||
user_agent: str = "", | ||
) -> list[str] | tuple[list[str], list[str]]: | ||
"""Return a block of a3m lines for each of the input sequences in x.""" | ||
submission_endpoint = "ticket/pair" if use_pairing else "ticket/msa" | ||
|
||
headers = {} | ||
|
@@ -342,19 +344,26 @@ def download(ID, path): | |
return (a3m_lines, template_paths) if use_templates else a3m_lines | ||
|
||
|
||
def _is_padding_msa_row(sequence: str) -> bool: | ||
"""Check if the given MSA sequence is a padding sequence.""" | ||
seq_chars = set(sequence) | ||
return len(seq_chars) == 1 and seq_chars.pop() == "-" | ||
|
||
|
||
def generate_colabfold_msas( | ||
protein_seqs: list[str], | ||
msa_dir: Path, | ||
msa_server_url: str, | ||
write_a3m_to_msa_dir: bool = False, # Useful for manual inspection + debugging | ||
): | ||
""" | ||
Generate MSAs using the ColabFold (https://github.com/sokrypton/ColabFold) | ||
server. No-op if no protein sequences are given. | ||
N.B. the MSAs in our technical report were generated using jackhmmer, not | ||
N.B.: | ||
- the MSAs in our technical report were generated using jackhmmer, not | ||
ColabFold, so we would expect some difference in results. | ||
This implementation also relies on ColabFold's chain pairing algorithm | ||
- this implementation relies on ColabFold's chain pairing algorithm | ||
rather than using Chai-1's own algorithm, which could also lead to | ||
differences in results. | ||
|
@@ -369,52 +378,108 @@ def generate_colabfold_msas( | |
with tempfile.TemporaryDirectory() as tmp_dir_path: | ||
tmp_dir = Path(tmp_dir_path) | ||
|
||
mmseqs_paired_dir = tmp_dir / "mmseqs_paired" | ||
mmseqs_paired_dir.mkdir() | ||
|
||
mmseqs_dir = tmp_dir / "mmseqs" | ||
mmseqs_dir.mkdir() | ||
|
||
a3ms_dir = tmp_dir / "a3ms" | ||
a3ms_dir = (tmp_dir if not write_a3m_to_msa_dir else msa_dir) / "a3ms" | ||
a3ms_dir.mkdir() | ||
|
||
# Generate MSAs for each protein chain | ||
logger.info(f"Running MSA generation for {len(protein_seqs)} protein sequences") | ||
msas = _run_mmseqs2( | ||
|
||
# In paired mode, mmseqs2 returns paired a3ms where all a3ms have the same number of rows | ||
# and each row is already paired to have the same species. As such, we use the row index i | ||
# as the pairing key for each sequence, so long as it isn't a padding sequence (all "-") | ||
paired_msas: list[str] | ||
if len(protein_seqs) > 1: | ||
paired_msas = _run_mmseqs2( | ||
protein_seqs, | ||
mmseqs_paired_dir, | ||
use_pairing=True, | ||
host_url=msa_server_url, | ||
user_agent=f"chai-lab/{__version__} [email protected]", | ||
) | ||
else: | ||
# If we only have a single protein chain, there are no paired MSAs by definition | ||
paired_msas = [""] * len(protein_seqs) | ||
|
||
# MSAs without pairing logic attached; may include sequences not contained in the paired MSA | ||
# Needs a second call as the colabfold server returns either paired or unpaired, not both | ||
per_chain_msas = _run_mmseqs2( | ||
protein_seqs, | ||
mmseqs_dir, | ||
# N.B. we can set this to False to disable pairing | ||
use_pairing=len(protein_seqs) > 1, | ||
use_pairing=False, | ||
host_url=msa_server_url, | ||
user_agent=f"chai-lab/{__version__} [email protected]", | ||
) | ||
assert isinstance(msas, list) | ||
|
||
# Process the MSAs into our internal format | ||
for protein_seq, msa in zip(protein_seqs, msas, strict=True): | ||
# Write out an A3M file | ||
a3m_path = a3ms_dir / f"{hash_sequence(protein_seq.upper())}.a3m" | ||
a3m_path.write_text(msa) | ||
|
||
# Convert the A3M file into aligned parquet files | ||
msa_fasta = read_fasta(a3m_path) | ||
headers, msa_seqs = zip(*msa_fasta) | ||
for protein_seq, pair_msa, single_msa in zip( | ||
protein_seqs, paired_msas, per_chain_msas, strict=True | ||
): | ||
# Write out an A3M file for both | ||
hkey = hash_sequence(protein_seq.upper()) | ||
pair_a3m_path = a3ms_dir / f"{hkey}.pair.a3m" | ||
pair_a3m_path.write_text(pair_msa) | ||
single_a3m_path = a3ms_dir / f"{hkey}.single.a3m" | ||
single_a3m_path.write_text(single_msa) | ||
|
||
## Convert the A3M file into aligned parquet files | ||
# Set the pairing key to the row index i of each sequence; skip over sequences that have | ||
# been inserted as padding, as our internal pairing logic will match on pairing key. | ||
paired_fasta: list[tuple[str, str, str]] = [ | ||
(str(pairkey), record.header, record.sequence) | ||
for pairkey, record in enumerate(read_fasta(pair_a3m_path)) | ||
if not _is_padding_msa_row(record.sequence) | ||
] | ||
pairing_key, paired_headers, paired_msa_seqs = ( | ||
zip(*paired_fasta) if paired_fasta else ((), (), ()) | ||
) | ||
unique_paired_msa_seqs = set(paired_msa_seqs) | ||
|
||
# Non-paired MSA sequences that weren't already covered in the paired MSA; skip the query row (index 0) | ||
single_fasta: list[Fasta] = [ | ||
record | ||
for i, record in enumerate(read_fasta(single_a3m_path)) | ||
if ( | ||
i > 0 | ||
and not _is_padding_msa_row(record.sequence) | ||
and record.sequence not in unique_paired_msa_seqs | ||
) | ||
] | ||
single_headers = [record.header for record in single_fasta] | ||
single_msa_seqs = [record.sequence for record in single_fasta] | ||
# Create null pairing keys for each of the entries in the single MSA seq | ||
single_null_pair_keys = [""] * len(single_msa_seqs) | ||
|
||
# This shouldn't have much of an effect on the model, but we make | ||
# a best effort to synthesize a source database anyway | ||
# NOTE we already dropped the query row from the single MSAs so no need to slice | ||
source_databases = ["query"] + [ | ||
"uniref90" if h.startswith("UniRef") else "bfd_uniclust" | ||
for h in headers[1:] | ||
( | ||
MSADataSource.UNIREF90.value | ||
if h.startswith("UniRef") | ||
else MSADataSource.BFD_UNICLUST.value | ||
) | ||
for h in (list(paired_headers) + single_headers)[1:] | ||
] | ||
|
||
# Combine information across paired and single hits | ||
all_sequences = list(paired_msa_seqs) + single_msa_seqs | ||
all_pairing_keys = list(pairing_key) + single_null_pair_keys | ||
assert ( | ||
len(all_sequences) == len(all_pairing_keys) == len(source_databases) | ||
), f"Mismatched lengths: {len(all_sequences)=} {len(all_pairing_keys)=} {len(source_databases)=}" | ||
|
||
# Map the MSAs to our internal format | ||
aligned_df = pd.DataFrame( | ||
data=dict( | ||
sequence=msa_seqs, | ||
sequence=all_sequences, | ||
source_database=source_databases, | ||
# ColabFold does not return taxonomies from its API, so we | ||
# can't rely on our internal chain pairing logic. As an | ||
# alternative, we could disable ColabFold pairing and rely | ||
# on a mapping from sequence ~> taxonomy, which would allow | ||
# us to use our internal pairing logic. | ||
pairing_key="", | ||
pairing_key=all_pairing_keys, | ||
comment="", | ||
), | ||
) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters