Skip to content

Commit

Permalink
Add code for AlphaFold-Multimer.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 407076987
  • Loading branch information
Augustin-Zidek committed Nov 2, 2021
1 parent 1d43aaf commit 0be2b30
Show file tree
Hide file tree
Showing 48 changed files with 8,337 additions and 1,193 deletions.
213 changes: 168 additions & 45 deletions README.md

Large diffs are not rendered by default.

37 changes: 25 additions & 12 deletions alphafold/common/confidence.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,18 +111,23 @@ def compute_predicted_aligned_error(
def predicted_tm_score(
logits: np.ndarray,
breaks: np.ndarray,
residue_weights: Optional[np.ndarray] = None) -> np.ndarray:
"""Computes predicted TM alignment score.
residue_weights: Optional[np.ndarray] = None,
asym_id: Optional[np.ndarray] = None,
interface: bool = False) -> np.ndarray:
"""Computes predicted TM alignment or predicted interface TM alignment score.
Args:
logits: [num_res, num_res, num_bins] the logits output from
PredictedAlignedErrorHead.
breaks: [num_bins] the error bins.
residue_weights: [num_res] the per residue weights to use for the
expectation.
asym_id: [num_res] the asymmetric unit ID - the chain ID. Only needed for
ipTM calculation, i.e. when interface=True.
interface: If True, interface predicted TM score is computed.
Returns:
ptm_score: the predicted TM alignment score.
ptm_score: The predicted TM alignment or the predicted iTM score.
"""

# residue_weights has to be in [0, 1], but can be floating-point, i.e. the
Expand All @@ -132,24 +137,32 @@ def predicted_tm_score(

bin_centers = _calculate_bin_centers(breaks)

num_res = np.sum(residue_weights)
num_res = int(np.sum(residue_weights))
# Clip num_res to avoid negative/undefined d0.
clipped_num_res = max(num_res, 19)

# Compute d_0(num_res) as defined by TM-score, eqn. (5) in
# http://zhanglab.ccmb.med.umich.edu/papers/2004_3.pdf
# Yang & Skolnick "Scoring function for automated
# assessment of protein structure template quality" 2004
# Compute d_0(num_res) as defined by TM-score, eqn. (5) in Yang & Skolnick
# "Scoring function for automated assessment of protein structure template
# quality", 2004: http://zhanglab.ccmb.med.umich.edu/papers/2004_3.pdf
d0 = 1.24 * (clipped_num_res - 15) ** (1./3) - 1.8

# Convert logits to probs
# Convert logits to probs.
probs = scipy.special.softmax(logits, axis=-1)

# TM-Score term for every bin
# TM-Score term for every bin.
tm_per_bin = 1. / (1 + np.square(bin_centers) / np.square(d0))
# E_distances tm(distance)
# E_distances tm(distance).
predicted_tm_term = np.sum(probs * tm_per_bin, axis=-1)

normed_residue_mask = residue_weights / (1e-8 + residue_weights.sum())
pair_mask = np.ones(shape=(num_res, num_res), dtype=bool)
if interface:
pair_mask *= asym_id[:, None] != asym_id[None, :]

predicted_tm_term *= pair_mask

pair_residue_weights = pair_mask * (
residue_weights[None, :] * residue_weights[:, None])
normed_residue_mask = pair_residue_weights / (1e-8 + np.sum(
pair_residue_weights, axis=-1, keepdims=True))
per_alignment = np.sum(predicted_tm_term * normed_residue_mask, axis=-1)
return np.asarray(per_alignment[(per_alignment * residue_weights).argmax()])
155 changes: 102 additions & 53 deletions alphafold/common/protein.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@
FeatureDict = Mapping[str, np.ndarray]
ModelOutput = Mapping[str, Any] # Is a nested dict.

# Complete sequence of chain IDs supported by the PDB format.
PDB_CHAIN_IDS = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789'
PDB_MAX_CHAINS = len(PDB_CHAIN_IDS) # := 62.


@dataclasses.dataclass(frozen=True)
class Protein:
Expand All @@ -43,11 +47,21 @@ class Protein:
# Residue index as used in PDB. It is not necessarily continuous or 0-indexed.
residue_index: np.ndarray # [num_res]

# 0-indexed number corresponding to the chain in the protein that this residue
# belongs to.
chain_index: np.ndarray # [num_res]

# B-factors, or temperature factors, of each residue (in sq. angstroms units),
# representing the displacement of the residue from its ground truth mean
# value.
b_factors: np.ndarray # [num_res, num_atom_type]

def __post_init__(self):
if len(np.unique(self.chain_index)) > PDB_MAX_CHAINS:
raise ValueError(
f'Cannot build an instance with more than {PDB_MAX_CHAINS} chains '
'because these cannot be written to PDB format.')


def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
"""Takes a PDB string and constructs a Protein object.
Expand All @@ -57,9 +71,8 @@ def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
Args:
pdb_str: The contents of the pdb file
chain_id: If None, then the pdb file must contain a single chain (which
will be parsed). If chain_id is specified (e.g. A), then only that chain
is parsed.
chain_id: If chain_id is specified (e.g. A), then only that chain
is parsed. Otherwise all chains are parsed.
Returns:
A new `Protein` parsed from the pdb contents.
Expand All @@ -73,57 +86,63 @@ def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
f'Only single model PDBs are supported. Found {len(models)} models.')
model = models[0]

if chain_id is not None:
chain = model[chain_id]
else:
chains = list(model.get_chains())
if len(chains) != 1:
raise ValueError(
'Only single chain PDBs are supported when chain_id not specified. '
f'Found {len(chains)} chains.')
else:
chain = chains[0]

atom_positions = []
aatype = []
atom_mask = []
residue_index = []
chain_ids = []
b_factors = []

for res in chain:
if res.id[2] != ' ':
raise ValueError(
f'PDB contains an insertion code at chain {chain.id} and residue '
f'index {res.id[1]}. These are not supported.')
res_shortname = residue_constants.restype_3to1.get(res.resname, 'X')
restype_idx = residue_constants.restype_order.get(
res_shortname, residue_constants.restype_num)
pos = np.zeros((residue_constants.atom_type_num, 3))
mask = np.zeros((residue_constants.atom_type_num,))
res_b_factors = np.zeros((residue_constants.atom_type_num,))
for atom in res:
if atom.name not in residue_constants.atom_types:
continue
pos[residue_constants.atom_order[atom.name]] = atom.coord
mask[residue_constants.atom_order[atom.name]] = 1.
res_b_factors[residue_constants.atom_order[atom.name]] = atom.bfactor
if np.sum(mask) < 0.5:
# If no known atom positions are reported for the residue then skip it.
for chain in model:
if chain_id is not None and chain.id != chain_id:
continue
aatype.append(restype_idx)
atom_positions.append(pos)
atom_mask.append(mask)
residue_index.append(res.id[1])
b_factors.append(res_b_factors)
for res in chain:
if res.id[2] != ' ':
raise ValueError(
f'PDB contains an insertion code at chain {chain.id} and residue '
f'index {res.id[1]}. These are not supported.')
res_shortname = residue_constants.restype_3to1.get(res.resname, 'X')
restype_idx = residue_constants.restype_order.get(
res_shortname, residue_constants.restype_num)
pos = np.zeros((residue_constants.atom_type_num, 3))
mask = np.zeros((residue_constants.atom_type_num,))
res_b_factors = np.zeros((residue_constants.atom_type_num,))
for atom in res:
if atom.name not in residue_constants.atom_types:
continue
pos[residue_constants.atom_order[atom.name]] = atom.coord
mask[residue_constants.atom_order[atom.name]] = 1.
res_b_factors[residue_constants.atom_order[atom.name]] = atom.bfactor
if np.sum(mask) < 0.5:
# If no known atom positions are reported for the residue then skip it.
continue
aatype.append(restype_idx)
atom_positions.append(pos)
atom_mask.append(mask)
residue_index.append(res.id[1])
chain_ids.append(chain.id)
b_factors.append(res_b_factors)

# Chain IDs are usually characters so map these to ints.
unique_chain_ids = np.unique(chain_ids)
chain_id_mapping = {cid: n for n, cid in enumerate(unique_chain_ids)}
chain_index = np.array([chain_id_mapping[cid] for cid in chain_ids])

return Protein(
atom_positions=np.array(atom_positions),
atom_mask=np.array(atom_mask),
aatype=np.array(aatype),
residue_index=np.array(residue_index),
chain_index=chain_index,
b_factors=np.array(b_factors))


def _chain_end(atom_index, end_resname, chain_name, residue_index) -> str:
chain_end = 'TER'
return (f'{chain_end:<6}{atom_index:>5} {end_resname:>3} '
f'{chain_name:>1}{residue_index:>4}')


def to_pdb(prot: Protein) -> str:
"""Converts a `Protein` instance to a PDB string.
Expand All @@ -143,16 +162,33 @@ def to_pdb(prot: Protein) -> str:
aatype = prot.aatype
atom_positions = prot.atom_positions
residue_index = prot.residue_index.astype(np.int32)
chain_index = prot.chain_index.astype(np.int32)
b_factors = prot.b_factors

if np.any(aatype > residue_constants.restype_num):
raise ValueError('Invalid aatypes.')

# Construct a mapping from chain integer indices to chain ID strings.
chain_ids = {}
for i in np.unique(chain_index): # np.unique gives sorted output.
if i >= PDB_MAX_CHAINS:
raise ValueError(
f'The PDB format supports at most {PDB_MAX_CHAINS} chains.')
chain_ids[i] = PDB_CHAIN_IDS[i]

pdb_lines.append('MODEL 1')
atom_index = 1
chain_id = 'A'
last_chain_index = chain_index[0]
# Add all atom sites.
for i in range(aatype.shape[0]):
# Close the previous chain if in a multichain PDB.
if last_chain_index != chain_index[i]:
pdb_lines.append(_chain_end(
atom_index, res_1to3(aatype[i - 1]), chain_ids[chain_index[i - 1]],
residue_index[i - 1]))
last_chain_index = chain_index[i]
atom_index += 1 # Atom index increases at the TER symbol.

res_name_3 = res_1to3(aatype[i])
for atom_name, pos, mask, b_factor in zip(
atom_types, atom_positions[i], atom_mask[i], b_factors[i]):
Expand All @@ -168,25 +204,23 @@ def to_pdb(prot: Protein) -> str:
charge = ''
# PDB is a columnar format, every space matters here!
atom_line = (f'{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}'
f'{res_name_3:>3} {chain_id:>1}'
f'{res_name_3:>3} {chain_ids[chain_index[i]]:>1}'
f'{residue_index[i]:>4}{insertion_code:>1} '
f'{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}'
f'{occupancy:>6.2f}{b_factor:>6.2f} '
f'{element:>2}{charge:>2}')
pdb_lines.append(atom_line)
atom_index += 1

# Close the chain.
chain_end = 'TER'
chain_termination_line = (
f'{chain_end:<6}{atom_index:>5} {res_1to3(aatype[-1]):>3} '
f'{chain_id:>1}{residue_index[-1]:>4}')
pdb_lines.append(chain_termination_line)
# Close the final chain.
pdb_lines.append(_chain_end(atom_index, res_1to3(aatype[-1]),
chain_ids[chain_index[-1]], residue_index[-1]))
pdb_lines.append('ENDMDL')

pdb_lines.append('END')
pdb_lines.append('')
return '\n'.join(pdb_lines)

# Pad all lines to 80 characters.
pdb_lines = [line.ljust(80) for line in pdb_lines]
return '\n'.join(pdb_lines) + '\n' # Add terminating newline.


def ideal_atom_mask(prot: Protein) -> np.ndarray:
Expand All @@ -205,25 +239,40 @@ def ideal_atom_mask(prot: Protein) -> np.ndarray:
return residue_constants.STANDARD_ATOM_MASK[prot.aatype]


def from_prediction(features: FeatureDict, result: ModelOutput,
b_factors: Optional[np.ndarray] = None) -> Protein:
def from_prediction(
features: FeatureDict,
result: ModelOutput,
b_factors: Optional[np.ndarray] = None,
remove_leading_feature_dimension: bool = True) -> Protein:
"""Assembles a protein from a prediction.
Args:
features: Dictionary holding model inputs.
result: Dictionary holding model outputs.
b_factors: (Optional) B-factors to use for the protein.
remove_leading_feature_dimension: Whether to remove the leading dimension
of the `features` values.
Returns:
A protein instance.
"""
fold_output = result['structure_module']

def _maybe_remove_leading_dim(arr: np.ndarray) -> np.ndarray:
return arr[0] if remove_leading_feature_dimension else arr

if 'asym_id' in features:
chain_index = _maybe_remove_leading_dim(features['asym_id'])
else:
chain_index = np.zeros_like(_maybe_remove_leading_dim(features['aatype']))

if b_factors is None:
b_factors = np.zeros_like(fold_output['final_atom_mask'])

return Protein(
aatype=features['aatype'][0],
aatype=_maybe_remove_leading_dim(features['aatype']),
atom_positions=fold_output['final_atom_positions'],
atom_mask=fold_output['final_atom_mask'],
residue_index=features['residue_index'][0] + 1,
residue_index=_maybe_remove_leading_dim(features['residue_index']) + 1,
chain_index=chain_index,
b_factors=b_factors)
Loading

1 comment on commit 0be2b30

@chrisroat
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you indicate if pdb_seqres_database_path and uniprot_database_path, needed for multimer, should be on local SSD?

(I know for performance that the bfd_database_path and uniclust30_database_path should be on fast disk)

Please sign in to comment.