Skip to content

Commit

Permalink
Improve colabfold MSAs to include unpaired MSA hits (#213)
Browse files Browse the repository at this point in the history
* Typo in logging message

* Call colabfold mmseqs server once for paired MSAs, and once for unpaired MSAs

* Update comment

* Fix issue with monomers and constructing source databases

* Avoid using string literals

* Address PR feedback
  • Loading branch information
wukevin authored Dec 5, 2024
1 parent e7b6624 commit 6061e41
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 28 deletions.
119 changes: 92 additions & 27 deletions chai_lab/data/dataset/msas/colabfold.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@
from tqdm import tqdm

from chai_lab import __version__
from chai_lab.data.parsing.fasta import read_fasta
from chai_lab.data.parsing.fasta import Fasta, read_fasta
from chai_lab.data.parsing.msas.aligned_pqt import expected_basename, hash_sequence
from chai_lab.data.parsing.msas.data_source import MSADataSource

logger = logging.getLogger(__name__)

Expand All @@ -26,7 +27,7 @@
)


# N.B. this code is copied from https://github.com/sokrypton/ColabFold
# N.B. this function (and this function only) is copied from https://github.com/sokrypton/ColabFold
# and follows the license in that repository
@typing.no_type_check # Original ColabFold code was not well typed
def _run_mmseqs2(
Expand All @@ -41,6 +42,7 @@ def _run_mmseqs2(
host_url="https://api.colabfold.com",
user_agent: str = "",
) -> list[str] | tuple[list[str], list[str]]:
"""Return a block of a3m lines for each of the input sequences in x."""
submission_endpoint = "ticket/pair" if use_pairing else "ticket/msa"

headers = {}
Expand Down Expand Up @@ -342,19 +344,26 @@ def download(ID, path):
return (a3m_lines, template_paths) if use_templates else a3m_lines


def _is_padding_msa_row(sequence: str) -> bool:
"""Check if the given MSA sequence is a a padding sequence."""
seq_chars = set(sequence)
return len(seq_chars) == 1 and seq_chars.pop() == "-"


def generate_colabfold_msas(
protein_seqs: list[str],
msa_dir: Path,
msa_server_url: str,
write_a3m_to_msa_dir: bool = False, # Useful for manual inspection + debugging
):
"""
Generate MSAs using the ColabFold (https://github.com/sokrypton/ColabFold)
server. No-op if no protein sequences are given.
N.B. the MSAs in our technical report were generated using jackhmmer, not
N.B.:
- the MSAs in our technical report were generated using jackhmmer, not
ColabFold, so we would expect some difference in results.
This implementation also relies on ColabFold's chain pairing algorithm
- this implementation relies on ColabFold's chain pairing algorithm
rather than using Chai-1's own algorithm, which could also lead to
differences in results.
Expand All @@ -369,52 +378,108 @@ def generate_colabfold_msas(
with tempfile.TemporaryDirectory() as tmp_dir_path:
tmp_dir = Path(tmp_dir_path)

mmseqs_paired_dir = tmp_dir / "mmseqs_paired"
mmseqs_paired_dir.mkdir()

mmseqs_dir = tmp_dir / "mmseqs"
mmseqs_dir.mkdir()

a3ms_dir = tmp_dir / "a3ms"
a3ms_dir = (tmp_dir if not write_a3m_to_msa_dir else msa_dir) / "a3ms"
a3ms_dir.mkdir()

# Generate MSAs for each protein chain
logger.info(f"Running MSA generation for {len(protein_seqs)} protein sequences")
msas = _run_mmseqs2(

# In paired mode, mmseqs2 returns paired a3ms where all a3ms have the same number of rows
# and each row is already paired to have the same species. As such, we insert pairing key
# as the i-th index of the sequence so long as it isn't a padding sequence (all -)
paired_msas: list[str]
if len(protein_seqs) > 1:
paired_msas = _run_mmseqs2(
protein_seqs,
mmseqs_paired_dir,
use_pairing=True,
host_url=msa_server_url,
user_agent=f"chai-lab/{__version__} [email protected]",
)
else:
# If we only have a single protein chain, there are no paired MSAs by definition
paired_msas = [""] * len(protein_seqs)

# MSAs without pairing logic attached; may include sequences not contained in the paired MSA
# Needs a second call as the colabfold server returns either paired or unpaired, not both
per_chain_msas = _run_mmseqs2(
protein_seqs,
mmseqs_dir,
# N.B. we can set this to False to disable pairing
use_pairing=len(protein_seqs) > 1,
use_pairing=False,
host_url=msa_server_url,
user_agent=f"chai-lab/{__version__} [email protected]",
)
assert isinstance(msas, list)

# Process the MSAs into our internal format
for protein_seq, msa in zip(protein_seqs, msas, strict=True):
# Write out an A3M file
a3m_path = a3ms_dir / f"{hash_sequence(protein_seq.upper())}.a3m"
a3m_path.write_text(msa)

# Convert the A3M file into aligned parquet files
msa_fasta = read_fasta(a3m_path)
headers, msa_seqs = zip(*msa_fasta)
for protein_seq, pair_msa, single_msa in zip(
protein_seqs, paired_msas, per_chain_msas, strict=True
):
# Write out an A3M file for both
hkey = hash_sequence(protein_seq.upper())
pair_a3m_path = a3ms_dir / f"{hkey}.pair.a3m"
pair_a3m_path.write_text(pair_msa)
single_a3m_path = a3ms_dir / f"{hkey}.single.a3m"
single_a3m_path.write_text(single_msa)

## Convert the A3M file into aligned parquet files
# Set the pairing key as the ith-index in the sequences, skip over sequences that have
# been inserted as padding as our internal pairing logic will match on pairing key.
paired_fasta: list[tuple[str, str, str]] = [
(str(pairkey), record.header, record.sequence)
for pairkey, record in enumerate(read_fasta(pair_a3m_path))
if not _is_padding_msa_row(record.sequence)
]
pairing_key, paired_headers, paired_msa_seqs = (
zip(*paired_fasta) if paired_fasta else ((), (), ())
)
unique_paired_msa_seqs = set(paired_msa_seqs)

# Non-paired MSA sequences that weren't already covered in the paired MSA; skip header
single_fasta: list[Fasta] = [
record
for i, record in enumerate(read_fasta(single_a3m_path))
if (
i > 0
and not _is_padding_msa_row(record.sequence)
and record.sequence not in unique_paired_msa_seqs
)
]
single_headers = [record.header for record in single_fasta]
single_msa_seqs = [record.sequence for record in single_fasta]
# Create null pairing keys for each of the entries in the single MSA seq
single_null_pair_keys = [""] * len(single_msa_seqs)

# This shouldn't have much of an effect on the model, but we make
# a best effort to synthesize a source database anyway
# NOTE we already dropped the query row from the single MSAs so no need to slice
source_databases = ["query"] + [
"uniref90" if h.startswith("UniRef") else "bfd_uniclust"
for h in headers[1:]
(
MSADataSource.UNIREF90.value
if h.startswith("UniRef")
else MSADataSource.BFD_UNICLUST.value
)
for h in (list(paired_headers) + single_headers)[1:]
]

# Combine information across paired and single hits
all_sequences = list(paired_msa_seqs) + single_msa_seqs
all_pairing_keys = list(pairing_key) + single_null_pair_keys
assert (
len(all_sequences) == len(all_pairing_keys) == len(source_databases)
), f"Mismatched lengths: {len(all_sequences)=} {len(all_pairing_keys)=} {len(source_databases)=}"

# Map the MSAs to our internal format
aligned_df = pd.DataFrame(
data=dict(
sequence=msa_seqs,
sequence=all_sequences,
source_database=source_databases,
# ColabFold does not return taxonomies from its API, so we
# can't rely on our internal chain pairing logic. As an
# alternative, we could disable ColabFold pairing and rely
# on a mapping from sequence ~> taxonomy, which would allow
# us to use our internal pairing logic.
pairing_key="",
pairing_key=all_pairing_keys,
comment="",
),
)
Expand Down
3 changes: 2 additions & 1 deletion chai_lab/data/dataset/msas/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
_UKEY_FOR_QUERY = (-999, -999)

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


def merge_main_msas_by_chain(msas: list[MSAContext]) -> MSAContext:
Expand Down Expand Up @@ -120,7 +121,7 @@ def pair_and_merge_msas(msas: list[MSAContext]) -> MSAContext:
selected_msa = msa.take_rows_with_padding(all_rowids)

logger.info(
f"Loaded (paired in includes query sequence): "
f"Loaded (paired includes query sequence): "
f"{n_paired_msa=} {n_unpaired_msa=} out of {msa.depth=} "
)

Expand Down

0 comments on commit 6061e41

Please sign in to comment.