From df96b5860f4f1a02cec85eaf78a18d51198ff132 Mon Sep 17 00:00:00 2001 From: Geoffrey Yu Date: Thu, 15 Feb 2024 18:04:03 +0100 Subject: [PATCH 01/12] make sure padded asym_id won't affect permutation steps --- openfold/utils/multi_chain_permutation.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/openfold/utils/multi_chain_permutation.py b/openfold/utils/multi_chain_permutation.py index c950862a..4a044c18 100644 --- a/openfold/utils/multi_chain_permutation.py +++ b/openfold/utils/multi_chain_permutation.py @@ -105,13 +105,12 @@ def get_least_asym_entity_or_longest_length(batch, input_asym_id): anchor_pred_asym_ids: list(Tensor(int)) a list of all possible pred anchor candidates """ entity_2_asym_list = get_entity_2_asym_list(batch) - unique_entity_ids = torch.unique(batch["entity_id"]) + unique_entity_ids = [i for i in torch.unique(batch["entity_id"]) if i !=0]# if entity_id is 0, that means this entity_id comes from padding entity_asym_count = {} entity_length = {} for entity_id in unique_entity_ids: asym_ids = torch.unique(batch["asym_id"][batch["entity_id"] == entity_id]) - # Make sure some asym IDs associated with ground truth entity ID exist in cropped prediction asym_ids_in_pred = [a for a in asym_ids if a in input_asym_id] if not asym_ids_in_pred: @@ -122,10 +121,8 @@ def get_least_asym_entity_or_longest_length(batch, input_asym_id): # Calculate entity length entity_mask = (batch["entity_id"] == entity_id) entity_length[int(entity_id)] = entity_mask.sum().item() - min_asym_count = min(entity_asym_count.values()) least_asym_entities = [entity for entity, count in entity_asym_count.items() if count == min_asym_count] - # If multiple entities have the least asym_id count, return those with the longest length if len(least_asym_entities) > 1: max_length = max([entity_length[entity] for entity in least_asym_entities]) @@ -140,7 +137,6 @@ def get_least_asym_entity_or_longest_length(batch, input_asym_id): anchor_gt_asym_id = random.choice(entity_2_asym_list[least_asym_entities]) anchor_pred_asym_ids = [asym_id for asym_id in entity_2_asym_list[least_asym_entities] if asym_id in input_asym_id] - return anchor_gt_asym_id, anchor_pred_asym_ids @@ -160,6 +156,7 @@ def greedy_align( used = [False for _ in range(len(true_ca_poses))] align = [] unique_asym_ids = [i for i in torch.unique(batch["asym_id"]) if i != 0] + for cur_asym_id in unique_asym_ids: i = int(cur_asym_id - 1) asym_mask = batch["asym_id"] == cur_asym_id @@ -349,6 +346,7 @@ def compute_permutation_alignment(out, features, ground_truth): # First select anchors from predicted structures and ground truths anchor_gt_asym, anchor_pred_asym_ids = get_least_asym_entity_or_longest_length(ground_truth, features['asym_id']) + entity_2_asym_list = get_entity_2_asym_list(ground_truth) labels = split_ground_truth_labels(ground_truth) assert isinstance(labels, list) From d74b09cc142ea26e2c84a275184fc32ab279392f Mon Sep 17 00:00:00 2001 From: Geoffrey Yu Date: Thu, 15 Feb 2024 18:20:22 +0100 Subject: [PATCH 02/12] fixed bugs in unittests for multi-chain permutation. now working on extra subtests --- tests/test_permutation.py | 50 ++++++++++++++++++++++++++------------- 1 file changed, 33 insertions(+), 17 deletions(-) diff --git a/tests/test_permutation.py b/tests/test_permutation.py index d0db977a..fe7a609b 100644 --- a/tests/test_permutation.py +++ b/tests/test_permutation.py @@ -55,7 +55,8 @@ def setUp(self): self.sym_id = self.asym_id self.entity_id = torch.tensor([[1] * (self.chain_a_num_res * 2) + [2] * (self.chain_b_num_res * 3)], device=device) - + + # @unittest.skip("skip for now") def test_1_selecting_anchors(self): batch = { 'asym_id': self.asym_id, @@ -75,6 +76,7 @@ def test_1_selecting_anchors(self): self.assertEqual(anchor_pred_asym, expected_anchors & anchor_pred_asym) self.assertEqual(set(), anchor_pred_asym & expected_non_anchors) + # @unittest.skip("skip for now") def test_2_permutation_pentamer(self): batch = { 'asym_id': self.asym_id, @@ -111,26 +113,25 @@ def test_2_permutation_pentamer(self): batch['all_atom_positions'] = true_atom_position batch['all_atom_mask'] = true_atom_mask - aligns, _ = compute_permutation_alignment(out, batch, + aligns, per_asym_residue_index = compute_permutation_alignment(out, batch, batch) - print(f"##### aligns is {aligns}") possible_outcome = [[(0, 1), (1, 0), (2, 3), (3, 4), (4, 2)], [(0, 0), (1, 1), (2, 3), (3, 4), (4, 2)]] wrong_outcome = [[(0, 1), (1, 0), (2, 4), (3, 2), (4, 3)], [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4)]] self.assertIn(aligns, possible_outcome) self.assertNotIn(aligns, wrong_outcome) - @unittest.skip("Test needs to be fixed post-refactor") + # @unittest.skip("Test needs to be fixed post-refactor") def test_3_merge_labels(self): nres_pad = 325 - 57 # suppose the cropping size is 325 batch = { - 'asym_id': pad_features(self.asym_id, nres_pad, pad_dim=1), - 'sym_id': pad_features(self.sym_id, nres_pad, pad_dim=1), - 'entity_id': pad_features(self.entity_id, nres_pad, pad_dim=1), - 'aatype': torch.randint(21, size=(1, 325)), + 'asym_id': self.asym_id, + 'sym_id': self.sym_id, + 'entity_id': self.entity_id, + 'aatype': torch.randint(21, size=(1, 57)), 'seq_length': torch.tensor([57]) } - batch['asym_id'] = batch['asym_id'].reshape(1, 325) - batch["residue_index"] = pad_features(torch.tensor(self.residue_index).reshape(1, 57), nres_pad, pad_dim=1) + batch['asym_id'] = batch['asym_id'].reshape(1, 57) + batch["residue_index"] = torch.tensor([self.residue_index]) # create fake ground truth atom positions chain_a1_pos = torch.randint(15, (self.chain_a_num_res, 3 * 37), dtype=torch.float).reshape(1, self.chain_a_num_res, 37, 3) @@ -155,15 +156,30 @@ def test_3_merge_labels(self): torch.ones((1, self.chain_b_num_res, 37)), torch.ones((1, self.chain_b_num_res, 37)), torch.ones((1, self.chain_b_num_res, 37))), dim=1) - batch['all_atom_positions'] = pad_features(true_atom_position, nres_pad, pad_dim=1) - batch['all_atom_mask'] = pad_features(true_atom_mask, nres_pad=nres_pad, pad_dim=1) - # tensor_to_cuda = lambda t: t.to('cuda') - # ground_truth = tensor_tree_map(tensor_to_cuda,ground_truth) + batch['all_atom_positions'] = true_atom_position + batch['all_atom_mask'] = true_atom_mask + + # Below create a fake_input_features + fake_input_features = { + 'asym_id': pad_features(self.asym_id, nres_pad, pad_dim=1), + 'sym_id': pad_features(self.sym_id, nres_pad, pad_dim=1), + 'entity_id': pad_features(self.entity_id, nres_pad, pad_dim=1), + 'aatype': torch.randint(21, size=(1, 325)), + 'seq_length': torch.tensor([57]) + } + fake_input_features['asym_id'] = fake_input_features['asym_id'].reshape(1, 325) + fake_input_features["residue_index"] = pad_features(torch.tensor(self.residue_index).reshape(1, 57), nres_pad, pad_dim=1) + fake_input_features['all_atom_positions'] = pad_features(true_atom_position, nres_pad, pad_dim=1) + fake_input_features['all_atom_mask'] = pad_features(true_atom_mask, nres_pad=nres_pad, pad_dim=1) + + # NOTE + # batch: simulates ground_truth features + # fake_input_features: simulates the data that gonna be used as input for model.forward(fake_input_features) + # out: simulates the output of model.forward(fake_input_features) aligns, per_asym_residue_index = compute_permutation_alignment(out, - batch, + fake_input_features, batch) - print(f"##### aligns is {aligns}") labels = split_ground_truth_labels(batch) labels = merge_labels(per_asym_residue_index, labels, aligns, @@ -173,5 +189,5 @@ def test_3_merge_labels(self): expected_permutated_gt_pos = torch.cat((chain_a2_pos, chain_a1_pos, chain_b2_pos, chain_b3_pos, chain_b1_pos), dim=1) - expected_permutated_gt_pos = pad_features(expected_permutated_gt_pos, nres_pad, pad_dim=1) + # expected_permutated_gt_pos = pad_features(expected_permutated_gt_pos, nres_pad, pad_dim=1) self.assertTrue(torch.equal(labels['all_atom_positions'], expected_permutated_gt_pos)) From aa18a56b3ba57cf0074081c37e9ceb9f1698fc97 Mon Sep 17 00:00:00 2001 From: Geoffrey Yu Date: Thu, 15 Feb 2024 18:24:12 +0100 Subject: [PATCH 03/12] remove unnecessary lines --- tests/test_permutation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_permutation.py b/tests/test_permutation.py index fe7a609b..ea1d2918 100644 --- a/tests/test_permutation.py +++ b/tests/test_permutation.py @@ -189,5 +189,5 @@ def test_3_merge_labels(self): expected_permutated_gt_pos = torch.cat((chain_a2_pos, chain_a1_pos, chain_b2_pos, chain_b3_pos, chain_b1_pos), dim=1) - # expected_permutated_gt_pos = pad_features(expected_permutated_gt_pos, nres_pad, pad_dim=1) + self.assertTrue(torch.equal(labels['all_atom_positions'], expected_permutated_gt_pos)) From 2c565664dc56871f28fdc7fbf20e79bb0bd74679 Mon Sep 17 00:00:00 2001 From: Geoffrey Yu Date: Thu, 15 Feb 2024 18:28:38 +0100 Subject: [PATCH 04/12] restore to the verison on main --- openfold/utils/multi_chain_permutation.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/openfold/utils/multi_chain_permutation.py b/openfold/utils/multi_chain_permutation.py index 4a044c18..c950862a 100644 --- a/openfold/utils/multi_chain_permutation.py +++ b/openfold/utils/multi_chain_permutation.py @@ -105,12 +105,13 @@ def get_least_asym_entity_or_longest_length(batch, input_asym_id): anchor_pred_asym_ids: list(Tensor(int)) a list of all possible pred anchor candidates """ entity_2_asym_list = get_entity_2_asym_list(batch) - unique_entity_ids = [i for i in torch.unique(batch["entity_id"]) if i !=0]# if entity_id is 0, that means this entity_id comes from padding + unique_entity_ids = torch.unique(batch["entity_id"]) entity_asym_count = {} entity_length = {} for entity_id in unique_entity_ids: asym_ids = torch.unique(batch["asym_id"][batch["entity_id"] == entity_id]) + # Make sure some asym IDs associated with ground truth entity ID exist in cropped prediction asym_ids_in_pred = [a for a in asym_ids if a in input_asym_id] if not asym_ids_in_pred: @@ -121,8 +122,10 @@ def get_least_asym_entity_or_longest_length(batch, input_asym_id): # Calculate entity length entity_mask = (batch["entity_id"] == entity_id) entity_length[int(entity_id)] = entity_mask.sum().item() + min_asym_count = min(entity_asym_count.values()) least_asym_entities = [entity for entity, count in entity_asym_count.items() if count == min_asym_count] + # If multiple entities have the least asym_id count, return those with the longest length if len(least_asym_entities) > 1: max_length = max([entity_length[entity] for entity in least_asym_entities]) @@ -137,6 +140,7 @@ def get_least_asym_entity_or_longest_length(batch, input_asym_id): anchor_gt_asym_id = random.choice(entity_2_asym_list[least_asym_entities]) anchor_pred_asym_ids = [asym_id for asym_id in entity_2_asym_list[least_asym_entities] if asym_id in input_asym_id] + return anchor_gt_asym_id, anchor_pred_asym_ids @@ -156,7 +160,6 @@ def greedy_align( used = [False for _ in range(len(true_ca_poses))] align = [] unique_asym_ids = [i for i in torch.unique(batch["asym_id"]) if i != 0] - for cur_asym_id in unique_asym_ids: i = int(cur_asym_id - 1) asym_mask = batch["asym_id"] == cur_asym_id @@ -346,7 +349,6 @@ def compute_permutation_alignment(out, features, ground_truth): # First select anchors from predicted structures and ground truths anchor_gt_asym, anchor_pred_asym_ids = get_least_asym_entity_or_longest_length(ground_truth, features['asym_id']) - entity_2_asym_list = get_entity_2_asym_list(ground_truth) labels = split_ground_truth_labels(ground_truth) assert isinstance(labels, list) From 7df201e57d6c11ec70e568ddee23f8122ed8444c Mon Sep 17 00:00:00 2001 From: Geoffrey Yu Date: Fri, 16 Feb 2024 17:44:49 +0100 Subject: [PATCH 05/12] added typing hints and fixed some comments --- openfold/utils/multi_chain_permutation.py | 184 ++++++++++++++++------ 1 file changed, 136 insertions(+), 48 deletions(-) diff --git a/openfold/utils/multi_chain_permutation.py b/openfold/utils/multi_chain_permutation.py index c950862a..3d60380c 100644 --- a/openfold/utils/multi_chain_permutation.py +++ b/openfold/utils/multi_chain_permutation.py @@ -1,7 +1,7 @@ import logging import random import torch - +from typing import Tuple, List,Dict from openfold.np import residue_constants as rc logger = logging.getLogger(__name__) @@ -13,6 +13,17 @@ def compute_rmsd( atom_mask: torch.Tensor = None, eps: float = 1e-6, ) -> torch.Tensor: + """ + Function to calculate RMSD between predicted and ground truth atom position + + Args: + true_atom_pos: a [nres*3] tensor + pred_atom_pos: a [nres*3] tensor + atom_mask: a [1*nres] tensor + + Return: + RMSD value between true and predicted atom positions + """ sq_diff = torch.square(true_atom_pos - pred_atom_pos).sum(dim=-1, keepdim=False) if atom_mask is not None: sq_diff = torch.masked_select(sq_diff, atom_mask.to(sq_diff.device)) @@ -21,7 +32,7 @@ def compute_rmsd( return torch.sqrt(msd + eps) # prevent sqrt 0 -def kabsch_rotation(P, Q): +def kabsch_rotation(P:torch.Tensor, Q:torch.Tensor) -> torch.Tensor: """ Calculate the best rotation that minimises the RMSD between P and Q. @@ -29,11 +40,11 @@ def kabsch_rotation(P, Q): https://en.wikipedia.org/wiki/Kabsch_algorithm Args: - P: [N * 3] Nres is the number of atoms and each row corresponds to the atom's x,y,z coordinates - Q: [N * 3] the same dimension as P + P: [N * 3] Nres is the number of atoms and each row corresponds to the atom's x,y,z coordinates + Q: [N * 3] the same dimension as P return: - A 3*3 rotation matrix + one 3*3 rotation matrix """ assert P.shape == torch.Size([Q.shape[0], Q.shape[1]]) @@ -54,11 +65,15 @@ def get_optimal_transform( src_atoms: torch.Tensor, tgt_atoms: torch.Tensor, mask: torch.Tensor = None, -): +) -> Tuple[torch.Tensor, torch.Tensor]: """ - src_atoms: predicted CA positions, shape:[num_res,3] - tgt_atoms: ground-truth CA positions, shape:[num_res,3] - mask: a vector of boolean values, shape:[num_res] + A function that obtain the transformation that optimally align + src_atoms with tgt_atoms + + Args: + src_atoms: predicted CA positions, shape:[num_res,3] + tgt_atoms: ground-truth CA positions, shape:[num_res,3] + mask: a vector of boolean values, shape:[num_res] """ assert src_atoms.shape == tgt_atoms.shape, (src_atoms.shape, tgt_atoms.shape) assert src_atoms.shape[-1] == 3 @@ -88,7 +103,7 @@ def get_optimal_transform( return r, x -def get_least_asym_entity_or_longest_length(batch, input_asym_id): +def get_least_asym_entity_or_longest_length(batch:dict, input_asym_id:list)->Tuple[torch.Tensor, List[torch.Tensor]]: """ First check how many subunit(s) one sequence has. Select the subunit that is less common, e.g. if the protein was AABBB then select one of the A as anchor @@ -97,12 +112,12 @@ def get_least_asym_entity_or_longest_length(batch, input_asym_id): then choose one of the corresponding subunits as anchor Args: - batch: in this function batch is the full ground truth features - input_asym_id: A list of asym_ids that are in the cropped input features + batch: in this function batch is the full ground truth features + input_asym_id: A list of asym_ids that are in the cropped input features Return: - anchor_gt_asym_id: Tensor(int) selected ground truth asym_id - anchor_pred_asym_ids: list(Tensor(int)) a list of all possible pred anchor candidates + anchor_gt_asym_id: Tensor(int) selected ground truth asym_id + anchor_pred_asym_ids: list(Tensor(int)) a list of all possible pred anchor candidates """ entity_2_asym_list = get_entity_2_asym_list(batch) unique_entity_ids = torch.unique(batch["entity_id"]) @@ -145,17 +160,29 @@ def get_least_asym_entity_or_longest_length(batch, input_asym_id): def greedy_align( - batch, - per_asym_residue_index, - entity_2_asym_list, - pred_ca_pos, - pred_ca_mask, - true_ca_poses, - true_ca_masks, -): + batch:dict, + per_asym_residue_index:dict, + entity_2_asym_list:dict, + pred_ca_pos:torch.Tensor, + pred_ca_mask:torch.Tensor, + true_ca_poses:list, + true_ca_masks:list +) -> List[Tuple[int,int]]: """ Implement Algorithm 4 in the Supplementary Information of AlphaFold-Multimer paper: Evans,R et al., 2022 Protein complex prediction with AlphaFold-Multimer, bioRxiv 2021.10.04.463034; doi: https://doi.org/10.1101/2021.10.04.463034 + + Args: + batch: a dictionary of ground truth features + per_asym_residue_index: a dictionary recording which residues belong to which aysm_id + entity_2_asym_list: a dictionary recording which asym_id(s) belong to which entity_id + pred_ca_pos: predicted positions of c-alpha atoms from the results of model.forward() + pred_ca_mask: a boolean tensor that masks pred_ca_pos + true_ca_poses: a list of tensors, corresponding to the c-alpha positions of the ground truth structure. e.g. If there are 5 chains, this list will have a length of 5 + true_ca_masks: a list of tensors, corresponding to the masks of c-alpha positions of the ground truth structure. If there are 5 chains, this list will have a length of 5 + + Return: + A list of tuple(int,int) that provides instructions of how the ground truth chains should be permuated """ used = [False for _ in range(len(true_ca_poses))] align = [] @@ -189,21 +216,38 @@ def greedy_align( return align -def pad_features(feature_tensor, nres_pad, pad_dim): - """Pad input feature tensor""" +def pad_features(feature_tensor:torch.Tensor, nres_pad:int, pad_dim:int) -> torch.Tensor: + """ + Pad input feature tensor. Padding values will be 0 and put behind the true feature values + + Args: + feature_tensor: A feature tensor + nres_pad: number of residues to add + pad_dim: along which dimension of the feature_tensor to pad + + Returns: + a padded feature tensor + """ pad_shape = list(feature_tensor.shape) pad_shape[pad_dim] = nres_pad padding_tensor = feature_tensor.new_zeros(pad_shape, device=feature_tensor.device) return torch.concat((feature_tensor, padding_tensor), dim=pad_dim) -def merge_labels(per_asym_residue_index, labels, align, original_nres): +def merge_labels(per_asym_residue_index:Dict[int,List[int]], + labels:dict, align:List[Tuple[int, int]], + original_nres:int) -> Dict[str,torch.Tensor]: """ Merge ground truth labels according to the permutation results - labels: list of original ground truth feats - align: list of tuples, each entry specify the corresponding label of the asym. + Args: + per_asym_residue_index: a dictionary recording which residues belong to which aysm_id + labels: list of original ground truth feats e.g. if there're 5 chains, labels will have a length of 5 + align: list of tuples, each entry specify the corresponding label of the asym. + original_nres: int, corresponding to the number of residues specified by crop_size in config.py + Returns: + A new dictionary of permuated ground truth features modified based on UniFold: https://github.com/dptech-corp/Uni-Fold/blob/b1c89a2cebd4e4ee4c47b4e443f92beeb9138fbb/unifold/losses/chain_align.py#L176C1-L176C1 """ @@ -230,13 +274,13 @@ def merge_labels(per_asym_residue_index, labels, align, original_nres): return outs -def split_ground_truth_labels(gt_features): +def split_ground_truth_labels(gt_features:dict) -> List[Dict]: """ Splits ground truth features according to chains Returns: - a list of feature dictionaries with only necessary ground truth features - required to finish multi-chain permutation + a list of feature dictionaries with only necessary ground truth features + required to finish multi-chain permutation """ unique_asym_ids, asym_id_counts = torch.unique(gt_features["asym_id"], sorted=True, return_counts=True) n_res = gt_features["asym_id"].shape[-1] @@ -251,7 +295,16 @@ def split_dim(shape): return labels -def get_per_asym_residue_index(features): +def get_per_asym_residue_index(features: dict) -> Dict[int,list]: + """ + A function that retrieve which residues belong to which asym_id + + Args: + features: a dictionary that contains input features after cropping + + Returns: + A dictionary that records which region of the sequence belongs to which asym_id + """ unique_asym_ids = [i for i in torch.unique(features["asym_id"]) if i != 0] per_asym_residue_index = {} for cur_asym_id in unique_asym_ids: @@ -261,7 +314,7 @@ def get_per_asym_residue_index(features): return per_asym_residue_index -def get_entity_2_asym_list(batch): +def get_entity_2_asym_list(batch: dict) -> Dict[int,list]: """ Generates a dictionary mapping unique entity IDs to lists of unique asymmetry IDs (asym_id) for each entity. @@ -281,14 +334,16 @@ def get_entity_2_asym_list(batch): return entity_2_asym_list -def calculate_input_mask(true_ca_masks, anchor_gt_idx, anchor_gt_residue, - asym_mask, pred_ca_mask): +def calculate_input_mask(true_ca_masks:List[torch.Tensor], anchor_gt_idx:torch.Tensor, + anchor_gt_residue:list, + asym_mask:torch.Tensor, pred_ca_mask:torch.Tensor) -> torch.Tensor: """ Calculate an input mask for downstream optimal transformation computation Args: - true_ca_masks (Tensor): ca mask from ground truth. - anchor_gt_idx (Tensor): The index of selected ground truth anchor. + true_ca_masks: list of masks from ground truth chains. + anchor_gt_idx (Tensor): a tensor with one integer in it. The index of selected ground truth anchor. + anchor_gt_residue:a list of residue indexes that belongs to the selected ground truth anchor asym_mask (Tensor): Boolean tensor indicating which regions are selected predicted anchor. pred_ca_mask (Tensor): ca mask from predicted structure. @@ -303,11 +358,26 @@ def calculate_input_mask(true_ca_masks, anchor_gt_idx, anchor_gt_residue, return input_mask -def calculate_optimal_transform(true_ca_poses, - anchor_gt_idx, anchor_gt_residue, - true_ca_masks, pred_ca_mask, - asym_mask, - pred_ca_pos): +def calculate_optimal_transform(true_ca_poses:List[torch.Tensor], + anchor_gt_idx:int, anchor_gt_residue:list, + true_ca_masks:List[torch.Tensor], pred_ca_mask:torch.Tensor, + asym_mask:torch.Tensor, + pred_ca_pos:torch.Tensor): + + """ + Takes selected anchor ground truth c-alpha positions and + selected predicted anchor c-alpha position then calculate the optimal rotation matrix + to align ground-truth anchor and predicted anchor + + Args: + true_ca_poses: a list of tensors, corresponding to the c-alpha positions of the ground truth structure. e.g. If there are 5 chains, this list will have a length of 5 + anchor_gt_idx (Tensor): a tensor with one integer in it. The index of selected ground truth anchor. + anchor_gt_residue:a list of residue indexes that belongs to the selected ground truth anchor + true_ca_masks: list of masks from ground truth chains e.g. it will be length=5 if there are 5 chains in ground truth structure + pred_ca_mask: A boolean tensor corresponds to the mask to mask the predicted features + asym_mask: A boolean tensor that mask out other elements in a tensor if they do not belong to a this asym_id + pred_ca_pos: a [nres*3] tensor of predicted c-alpha atom positions + """ input_mask = calculate_input_mask(true_ca_masks, anchor_gt_idx, anchor_gt_residue, @@ -326,13 +396,25 @@ def calculate_optimal_transform(true_ca_poses, return r, x -def compute_permutation_alignment(out, features, ground_truth): +def compute_permutation_alignment(out:Dict[str,torch.Tensor], + features:Dict[str,torch.Tensor], + ground_truth:List[Dict[str, torch.Tensor]]) -> Tuple[List[Tuple[int,int]], Dict[int,List[int]]]: """ - A class method that first permutate chains in ground truth first + Permutates chains in ground truth first before calculating the loss. + Args: + out: a dictionary of output tensors from model.forward() + features: a dictionary of feature tensors that are used as input for model.forward() + ground_truth: a list of dictionaries of features corresponding to chains in ground truth structure e.g. it will be a length of 5 if there are 5 chains in ground truth structure + + Returns: + best_align: a list of tuple(int,int) that instructs how ground truth chains should be permutated + per_asym_residue_index: per_asym_residue_index: a dictionary recording which residues belong to which aysm_id Details are described in Section 7.3 in the Supplementary of AlphaFold-Multimer paper: https://www.biorxiv.org/content/10.1101/2021.10.04.463034v2 + + """ unique_asym_ids = set(torch.unique(features['asym_id']).tolist()) unique_asym_ids.discard(0) # Remove padding asym_id @@ -397,13 +479,19 @@ def compute_permutation_alignment(out, features, ground_truth): return best_align, per_asym_residue_index -def multi_chain_permutation_align(out, features, ground_truth): - """Compute multi-chain permutation alignment. +def multi_chain_permutation_align(out:Dict[str,torch.Tensor], + features:Dict[str,torch.Tensor], + ground_truth:List[Dict[str, torch.Tensor]])->Dict[str,torch.Tensor]: + """ + Compute multi-chain permutation alignment. Args: - out: The output of model.forward() - features: Input features - ground_truth: Ground truth features + out: a dictionary of output tensors from model.forward() + features: a dictionary of feature tensors that are used as input for model.forward() + ground_truth: a list of dictionaries of features corresponding to chains in ground truth structure e.g. it will be a length of 5 if there are 5 chains in ground truth structure + + Returns: + features: a dictionary with updated ground truth feature tensors, ready for downstream loss calculations. """ labels = split_ground_truth_labels(ground_truth) From 170d9c55605f47088db5df05e868c6baf96bd4d7 Mon Sep 17 00:00:00 2001 From: Geoffrey Yu Date: Tue, 20 Feb 2024 14:01:41 +0100 Subject: [PATCH 06/12] make sure no padded features are going to be selected as anchors --- openfold/utils/multi_chain_permutation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/openfold/utils/multi_chain_permutation.py b/openfold/utils/multi_chain_permutation.py index 3d60380c..c9cf2d0d 100644 --- a/openfold/utils/multi_chain_permutation.py +++ b/openfold/utils/multi_chain_permutation.py @@ -120,7 +120,7 @@ def get_least_asym_entity_or_longest_length(batch:dict, input_asym_id:list)->Tup anchor_pred_asym_ids: list(Tensor(int)) a list of all possible pred anchor candidates """ entity_2_asym_list = get_entity_2_asym_list(batch) - unique_entity_ids = torch.unique(batch["entity_id"]) + unique_entity_ids = [i for i in torch.unique(batch["entity_id"]) if i !=0]# if entity_id is 0, that means this entity_id comes from padding entity_asym_count = {} entity_length = {} From 8dfe77e6db2b845fd215ee70053d9d8b19f77147 Mon Sep 17 00:00:00 2001 From: Geoffrey Yu Date: Thu, 21 Mar 2024 14:17:24 +0100 Subject: [PATCH 07/12] fixed typing errors; added more comments --- openfold/utils/multi_chain_permutation.py | 94 ++++++++++++++--------- 1 file changed, 59 insertions(+), 35 deletions(-) diff --git a/openfold/utils/multi_chain_permutation.py b/openfold/utils/multi_chain_permutation.py index c9cf2d0d..2887daf8 100644 --- a/openfold/utils/multi_chain_permutation.py +++ b/openfold/utils/multi_chain_permutation.py @@ -1,7 +1,7 @@ import logging import random import torch -from typing import Tuple, List,Dict +from typing import Tuple, List, Dict from openfold.np import residue_constants as rc logger = logging.getLogger(__name__) @@ -74,6 +74,11 @@ def get_optimal_transform( src_atoms: predicted CA positions, shape:[num_res,3] tgt_atoms: ground-truth CA positions, shape:[num_res,3] mask: a vector of boolean values, shape:[num_res] + + Returns: + a rotation matrix that record the optimal rotation + that will best align selected anchor prediction to selected anchor truth + a matrix records how the atoms should be shifted after applying r i.e. optimal alignment requires 1) rotate 2) shift the positions """ assert src_atoms.shape == tgt_atoms.shape, (src_atoms.shape, tgt_atoms.shape) assert src_atoms.shape[-1] == 3 @@ -103,7 +108,7 @@ def get_optimal_transform( return r, x -def get_least_asym_entity_or_longest_length(batch:dict, input_asym_id:list)->Tuple[torch.Tensor, List[torch.Tensor]]: +def get_least_asym_entity_or_longest_length(batch: dict, input_asym_id: list) -> Tuple[torch.Tensor, List[torch.Tensor]]: """ First check how many subunit(s) one sequence has. Select the subunit that is less common, e.g. if the protein was AABBB then select one of the A as anchor @@ -160,14 +165,14 @@ def get_least_asym_entity_or_longest_length(batch:dict, input_asym_id:list)->Tup def greedy_align( - batch:dict, - per_asym_residue_index:dict, - entity_2_asym_list:dict, - pred_ca_pos:torch.Tensor, - pred_ca_mask:torch.Tensor, - true_ca_poses:list, - true_ca_masks:list -) -> List[Tuple[int,int]]: + batch: dict, + per_asym_residue_index: dict, + entity_2_asym_list: dict, + pred_ca_pos: torch.Tensor, + pred_ca_mask: torch.Tensor, + true_ca_poses: list, + true_ca_masks: list +) -> List[Tuple[int, int]]: """ Implement Algorithm 4 in the Supplementary Information of AlphaFold-Multimer paper: Evans,R et al., 2022 Protein complex prediction with AlphaFold-Multimer, bioRxiv 2021.10.04.463034; doi: https://doi.org/10.1101/2021.10.04.463034 @@ -216,7 +221,7 @@ def greedy_align( return align -def pad_features(feature_tensor:torch.Tensor, nres_pad:int, pad_dim:int) -> torch.Tensor: +def pad_features(feature_tensor: torch.Tensor, nres_pad: int, pad_dim: int) -> torch.Tensor: """ Pad input feature tensor. Padding values will be 0 and put behind the true feature values @@ -234,9 +239,9 @@ def pad_features(feature_tensor:torch.Tensor, nres_pad:int, pad_dim:int) -> torc return torch.concat((feature_tensor, padding_tensor), dim=pad_dim) -def merge_labels(per_asym_residue_index:Dict[int,List[int]], - labels:dict, align:List[Tuple[int, int]], - original_nres:int) -> Dict[str,torch.Tensor]: +def merge_labels(per_asym_residue_index: Dict[int,List[int]], + labels: List[Dict], align: List[Tuple[int, int]], + original_nres: int) -> Dict[str, torch.Tensor]: """ Merge ground truth labels according to the permutation results @@ -274,13 +279,20 @@ def merge_labels(per_asym_residue_index:Dict[int,List[int]], return outs -def split_ground_truth_labels(gt_features:dict) -> List[Dict]: +def split_ground_truth_labels(gt_features: dict) -> List[Dict]: """ Splits ground truth features according to chains + Args: + gt_features: A dictionary within a the PyTorch DataSet iteration, which returns by the upstream DataLoader.iter() method + In the DataLoader pipeline, all tensors belonging to all the ground truth changes are concatenated so it stays the same as monomer data input format/pipeline, + thus, this function is needed to 1) detect the number of chains i.e. unique(asym_id) + 2) split the concatenated tensors back to individual ones that correspond to individual asym_ids + Returns: a list of feature dictionaries with only necessary ground truth features - required to finish multi-chain permutation + required to finish multi-chain permutation, e.g. it will be a list of 5 elements if there + are 5 chains in total. """ unique_asym_ids, asym_id_counts = torch.unique(gt_features["asym_id"], sorted=True, return_counts=True) n_res = gt_features["asym_id"].shape[-1] @@ -295,7 +307,7 @@ def split_dim(shape): return labels -def get_per_asym_residue_index(features: dict) -> Dict[int,list]: +def get_per_asym_residue_index(features: dict) -> Dict[int, list]: """ A function that retrieve which residues belong to which asym_id @@ -314,7 +326,7 @@ def get_per_asym_residue_index(features: dict) -> Dict[int,list]: return per_asym_residue_index -def get_entity_2_asym_list(batch: dict) -> Dict[int,list]: +def get_entity_2_asym_list(batch: dict) -> Dict[int, list]: """ Generates a dictionary mapping unique entity IDs to lists of unique asymmetry IDs (asym_id) for each entity. @@ -334,9 +346,9 @@ def get_entity_2_asym_list(batch: dict) -> Dict[int,list]: return entity_2_asym_list -def calculate_input_mask(true_ca_masks:List[torch.Tensor], anchor_gt_idx:torch.Tensor, - anchor_gt_residue:list, - asym_mask:torch.Tensor, pred_ca_mask:torch.Tensor) -> torch.Tensor: +def calculate_input_mask(true_ca_masks: List[torch.Tensor], anchor_gt_idx: torch.Tensor, + anchor_gt_residue: list, + asym_mask: torch.Tensor, pred_ca_mask: torch.Tensor) -> torch.Tensor: """ Calculate an input mask for downstream optimal transformation computation @@ -358,11 +370,11 @@ def calculate_input_mask(true_ca_masks:List[torch.Tensor], anchor_gt_idx:torch.T return input_mask -def calculate_optimal_transform(true_ca_poses:List[torch.Tensor], - anchor_gt_idx:int, anchor_gt_residue:list, - true_ca_masks:List[torch.Tensor], pred_ca_mask:torch.Tensor, - asym_mask:torch.Tensor, - pred_ca_pos:torch.Tensor): +def calculate_optimal_transform(true_ca_poses: List[torch.Tensor], + anchor_gt_idx: int, anchor_gt_residue: list, + true_ca_masks: List[torch.Tensor], pred_ca_mask: torch.Tensor, + asym_mask: torch.Tensor, + pred_ca_pos: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ Takes selected anchor ground truth c-alpha positions and @@ -377,6 +389,18 @@ def calculate_optimal_transform(true_ca_poses:List[torch.Tensor], pred_ca_mask: A boolean tensor corresponds to the mask to mask the predicted features asym_mask: A boolean tensor that mask out other elements in a tensor if they do not belong to a this asym_id pred_ca_pos: a [nres*3] tensor of predicted c-alpha atom positions + + Process: + 1) select an achor chain from ground truth, denoted by anchor_gt_idx, and + an chor chain from the predicted structure. Both anchor_gt and anchor_pred have exactly the same sequence + 2) obtain the C-alpha positions corresponding to the selected anchor_gt, done be slicing the true_ca_pose according to anchor_gt_residue + 3) calculate the optimal transformation that can best align the C-alpha atoms of anchor_pred to those of anchor_gt, + done by Kabsch algorithm: source https://en.wikipedia.org/wiki/Kabsch_algorithm + + Returns: + a rotation matrix that record the optimal rotation + that will best align selected anchor prediction to selected anchor truth + a matrix records how the atoms should be shifted after applying r i.e. optimal alignment requires 1) rotate 2) shift the positions """ input_mask = calculate_input_mask(true_ca_masks, anchor_gt_idx, @@ -396,11 +420,11 @@ def calculate_optimal_transform(true_ca_poses:List[torch.Tensor], return r, x -def compute_permutation_alignment(out:Dict[str,torch.Tensor], - features:Dict[str,torch.Tensor], - ground_truth:List[Dict[str, torch.Tensor]]) -> Tuple[List[Tuple[int,int]], Dict[int,List[int]]]: +def compute_permutation_alignment(out: Dict[str,torch.Tensor], + features: Dict[str,torch.Tensor], + ground_truth: List[Dict[str, torch.Tensor]]) -> Tuple[List[Tuple[int, int]], Dict[int, List[int]]]: """ - Permutates chains in ground truth first + A method that permutes chains in ground truth before calculating the loss. Args: @@ -409,8 +433,8 @@ def compute_permutation_alignment(out:Dict[str,torch.Tensor], ground_truth: a list of dictionaries of features corresponding to chains in ground truth structure e.g. it will be a length of 5 if there are 5 chains in ground truth structure Returns: - best_align: a list of tuple(int,int) that instructs how ground truth chains should be permutated - per_asym_residue_index: per_asym_residue_index: a dictionary recording which residues belong to which aysm_id + a list of tuple(int,int) that instructs how ground truth chains should be permutated + a dictionary recording which residues belong to which aysm_id Details are described in Section 7.3 in the Supplementary of AlphaFold-Multimer paper: https://www.biorxiv.org/content/10.1101/2021.10.04.463034v2 @@ -479,9 +503,9 @@ def compute_permutation_alignment(out:Dict[str,torch.Tensor], return best_align, per_asym_residue_index -def multi_chain_permutation_align(out:Dict[str,torch.Tensor], - features:Dict[str,torch.Tensor], - ground_truth:List[Dict[str, torch.Tensor]])->Dict[str,torch.Tensor]: +def multi_chain_permutation_align(out: Dict[str, torch.Tensor], + features: Dict[str, torch.Tensor], + ground_truth: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]: """ Compute multi-chain permutation alignment. From 2dbc8c0eeb66addb0fb17e92d6b209382c6de713 Mon Sep 17 00:00:00 2001 From: Geoffrey Yu Date: Thu, 21 Mar 2024 16:53:39 +0100 Subject: [PATCH 08/12] added comments --- openfold/utils/multi_chain_permutation.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/openfold/utils/multi_chain_permutation.py b/openfold/utils/multi_chain_permutation.py index 2887daf8..7984618d 100644 --- a/openfold/utils/multi_chain_permutation.py +++ b/openfold/utils/multi_chain_permutation.py @@ -424,8 +424,11 @@ def compute_permutation_alignment(out: Dict[str,torch.Tensor], features: Dict[str,torch.Tensor], ground_truth: List[Dict[str, torch.Tensor]]) -> Tuple[List[Tuple[int, int]], Dict[int, List[int]]]: """ - A method that permutes chains in ground truth - before calculating the loss. + A method that permutes chains in ground truth before calculating the loss + because the mapping between the predicted and ground-truth will become arbitrary. + The model cannot be assumed to predict chains in the same order as the ground truth. + Thus, this function pick the optimal permutaion of predicted chains that best matches the ground truth, + by minimising the RMSD. Args: out: a dictionary of output tensors from model.forward() From 61191bff81243f5246b6c20cbbe56bc35a1a983b Mon Sep 17 00:00:00 2001 From: Dingquan Yu Date: Fri, 10 May 2024 16:00:30 +0200 Subject: [PATCH 09/12] update comments;fixed typos --- openfold/utils/multi_chain_permutation.py | 34 +++++++++++++---------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/openfold/utils/multi_chain_permutation.py b/openfold/utils/multi_chain_permutation.py index 7984618d..c92c8a7a 100644 --- a/openfold/utils/multi_chain_permutation.py +++ b/openfold/utils/multi_chain_permutation.py @@ -32,7 +32,7 @@ def compute_rmsd( return torch.sqrt(msd + eps) # prevent sqrt 0 -def kabsch_rotation(P:torch.Tensor, Q:torch.Tensor) -> torch.Tensor: +def kabsch_rotation(P: torch.Tensor, Q: torch.Tensor) -> torch.Tensor: """ Calculate the best rotation that minimises the RMSD between P and Q. @@ -44,7 +44,7 @@ def kabsch_rotation(P:torch.Tensor, Q:torch.Tensor) -> torch.Tensor: Q: [N * 3] the same dimension as P return: - one 3*3 rotation matrix + one 3*3 rotation matrix that best aligns the sorce and target atoms """ assert P.shape == torch.Size([Q.shape[0], Q.shape[1]]) @@ -187,9 +187,16 @@ def greedy_align( true_ca_masks: a list of tensors, corresponding to the masks of c-alpha positions of the ground truth structure. If there are 5 chains, this list will have a length of 5 Return: - A list of tuple(int,int) that provides instructions of how the ground truth chains should be permuated + A list of tuple(int,int) that provides instructions of how the ground truth chains should be permuated + e.g. if 3 chains in the imput model have the same sequences, an example return would be: + [(0,2),(1,1),(2,0)], meaning the 1st chain in the predicted structure should be aligned to the 3rd chain in the ground truth, + and the 2nd chain in the predicted structure is ok to stay with the 2nd chain in the ground truth. + + Note: the tuples in the returned list begin with 0 indexing but aym_id begins with 1. The reason why tuples in the return are 0-indexing + is that at the stage of loss calculation, the ground truth atom positions: true_ca_poses, are already split up into a list of matrices. + Hence, now this function needs to return tuples that provide the index to select from the list: true_ca_poses, and list index starts from 0. """ - used = [False for _ in range(len(true_ca_poses))] + used = [False for _ in range(len(true_ca_poses))] # a list the keeps recording whether a ground truth chain has been used or not align = [] unique_asym_ids = [i for i in torch.unique(batch["asym_id"]) if i != 0] for cur_asym_id in unique_asym_ids: @@ -326,22 +333,22 @@ def get_per_asym_residue_index(features: dict) -> Dict[int, list]: return per_asym_residue_index -def get_entity_2_asym_list(batch: dict) -> Dict[int, list]: +def get_entity_2_asym_list(features: dict) -> Dict[int, list]: """ Generates a dictionary mapping unique entity IDs to lists of unique asymmetry IDs (asym_id) for each entity. Args: - batch (dict): A dictionary containing data batches, including "entity_id" and "asym_id" tensors. + features (dict): A dictionary containing data features, including "entity_id" and "asym_id" tensors. Returns: entity_2_asym_list (dict): A dictionary where keys are unique entity IDs, and values are lists of unique asymmetry IDs associated with each entity. """ entity_2_asym_list = {} - unique_entity_ids = torch.unique(batch["entity_id"]) + unique_entity_ids = torch.unique(features["entity_id"]) for cur_ent_id in unique_entity_ids: - ent_mask = batch["entity_id"] == cur_ent_id - cur_asym_id = torch.unique(batch["asym_id"][ent_mask]) + ent_mask = features["entity_id"] == cur_ent_id + cur_asym_id = torch.unique(features["asym_id"][ent_mask]) entity_2_asym_list[int(cur_ent_id)] = cur_asym_id return entity_2_asym_list @@ -428,7 +435,10 @@ def compute_permutation_alignment(out: Dict[str,torch.Tensor], because the mapping between the predicted and ground-truth will become arbitrary. The model cannot be assumed to predict chains in the same order as the ground truth. Thus, this function pick the optimal permutaion of predicted chains that best matches the ground truth, - by minimising the RMSD. + by minimising the RMSD i.e. the best permutation of ground truth chains is selected based on which permutation has the lowest RMSD calculation + + Details are described in Section 7.3 in the Supplementary of AlphaFold-Multimer paper: + https://www.biorxiv.org/content/10.1101/2021.10.04.463034v2 Args: out: a dictionary of output tensors from model.forward() @@ -438,10 +448,6 @@ def compute_permutation_alignment(out: Dict[str,torch.Tensor], Returns: a list of tuple(int,int) that instructs how ground truth chains should be permutated a dictionary recording which residues belong to which aysm_id - Details are described in Section 7.3 in the Supplementary of AlphaFold-Multimer paper: - https://www.biorxiv.org/content/10.1101/2021.10.04.463034v2 - - """ unique_asym_ids = set(torch.unique(features['asym_id']).tolist()) unique_asym_ids.discard(0) # Remove padding asym_id From 5f782370fc4c505c291eabf44dc8c7b0ca262293 Mon Sep 17 00:00:00 2001 From: Dingquan Yu Date: Fri, 10 May 2024 17:06:05 +0200 Subject: [PATCH 10/12] Update tests and comments --- tests/test_permutation.py | 137 +++++++++++++++++++++++++++----------- 1 file changed, 99 insertions(+), 38 deletions(-) diff --git a/tests/test_permutation.py b/tests/test_permutation.py index ea1d2918..6f64567f 100644 --- a/tests/test_permutation.py +++ b/tests/test_permutation.py @@ -48,15 +48,15 @@ def setUp(self): self.chain_a_num_res = 9 self.chain_b_num_res = 13 # below create default fake ground truth structures for a hetero-pentamer A2B3 - self.residue_index = list(range(self.chain_a_num_res)) * 2 + list(range(self.chain_b_num_res)) * 3 + self.residue_index = list( + range(self.chain_a_num_res)) * 2 + list(range(self.chain_b_num_res)) * 3 self.num_res = self.chain_a_num_res * 2 + self.chain_b_num_res * 3 self.asym_id = torch.tensor([[1] * self.chain_a_num_res + [2] * self.chain_a_num_res + [ 3] * self.chain_b_num_res + [4] * self.chain_b_num_res + [5] * self.chain_b_num_res], device=device) self.sym_id = self.asym_id self.entity_id = torch.tensor([[1] * (self.chain_a_num_res * 2) + [2] * (self.chain_b_num_res * 3)], device=device) - - # @unittest.skip("skip for now") + def test_1_selecting_anchors(self): batch = { 'asym_id': self.asym_id, @@ -64,20 +64,44 @@ def test_1_selecting_anchors(self): 'entity_id': self.entity_id, 'seq_length': torch.tensor([57]) } - anchor_gt_asym, anchor_pred_asym = get_least_asym_entity_or_longest_length(batch, batch['asym_id']) + anchor_gt_asym, anchor_pred_asym = get_least_asym_entity_or_longest_length( + batch, batch['asym_id']) anchor_gt_asym = int(anchor_gt_asym) anchor_pred_asym = {int(i) for i in anchor_pred_asym} expected_anchors = {1, 2} expected_non_anchors = {3, 4, 5} - self.assertIn(anchor_gt_asym, expected_anchors) + self.assertIn(anchor_gt_asym, expected_anchors) self.assertNotIn(anchor_gt_asym, expected_non_anchors) # Check that predicted anchors are within expected anchor set self.assertEqual(anchor_pred_asym, expected_anchors & anchor_pred_asym) - self.assertEqual(set(), anchor_pred_asym & expected_non_anchors) + self.assertEqual(set(), anchor_pred_asym & expected_non_anchors) - # @unittest.skip("skip for now") def test_2_permutation_pentamer(self): + """ + Test the permutation results on a pentamer A2B3, in which protein A has 9 residues + and protein B has 13 residues. + + Expected outputs: + Only protein A should be selected as an anchor thus, in the output list, either [(0,1), (1,0)] or [(0,0), (1,1)] are allowed + The 3 chains from protein B should ALWAYS be aligned in a way that predicted b1 to be aligned with ground truth b1, pred b2 to ground truth b2 + as shown below: + + predicted structure: a2 - a1 - b2 - b3 - b1 + indexes in the predicted list: 0 1 2 3 4 + + ground truth structure: a1 - a2 - b1 - b2 - b3 + indexes in the ground truth list: 0 1 2 3 4 + + then the 2 protein A chains are free to be aligned by either order, thus either [(0,1),(1,0)] or [(0,0),(1,1)] is valid. + + However, the 3 protein B chains should be strictly aligned in the following order: + [(2,3), (3,4), (4,1)], regardless of how protein A chains are aligned. + + Therefore, the only 2 correct permutations are : + [(0, 1), (1, 0), (2, 3), (3, 4), (4, 2)] and + [(0, 0), (1, 1), (2, 3), (3, 4), (4, 2)] + """ batch = { 'asym_id': self.asym_id, 'sym_id': self.sym_id, @@ -87,7 +111,7 @@ def test_2_permutation_pentamer(self): } batch['asym_id'] = batch['asym_id'].reshape(1, self.num_res) batch["residue_index"] = torch.tensor([self.residue_index]) - # create fake ground truth atom positions + # create fake ground truth atom positions chain_a1_pos = torch.randint(15, (self.chain_a_num_res, 3 * 37), dtype=torch.float).reshape(1, self.chain_a_num_res, 37, 3) chain_a2_pos = torch.matmul(chain_a1_pos, self.rotation_matrix_x) + 10 @@ -95,16 +119,22 @@ def test_2_permutation_pentamer(self): chain_b1_pos = torch.randint(low=15, high=30, size=(self.chain_b_num_res, 3 * 37), dtype=torch.float).reshape(1, self.chain_b_num_res, 37, 3) chain_b2_pos = torch.matmul(chain_b1_pos, self.rotation_matrix_y) + 10 - chain_b3_pos = torch.matmul(torch.matmul(chain_b1_pos, self.rotation_matrix_z), self.rotation_matrix_x) + 30 - # Below permutate predicted chain positions - pred_atom_position = torch.cat((chain_a2_pos, chain_a1_pos, chain_b2_pos, chain_b3_pos, chain_b1_pos), dim=1) + chain_b3_pos = torch.matmul(torch.matmul( + chain_b1_pos, self.rotation_matrix_z), self.rotation_matrix_x) + 30 + # Below permutate predicted chain positions + # here the b2 chain from the ground truth is deliberately put in b1 chain's position, and predicted b3 chain to b2's position + # and predicted b1 chain to b3's position + pred_atom_position = torch.cat( + (chain_a2_pos, chain_a1_pos, chain_b2_pos, chain_b3_pos, chain_b1_pos), dim=1) + pred_atom_mask = torch.ones((1, self.num_res, 37)) out = { 'final_atom_positions': pred_atom_position, 'final_atom_mask': pred_atom_mask } - true_atom_position = torch.cat((chain_a1_pos, chain_a2_pos, chain_b1_pos, chain_b2_pos, chain_b3_pos), dim=1) + true_atom_position = torch.cat( + (chain_a1_pos, chain_a2_pos, chain_b1_pos, chain_b2_pos, chain_b3_pos), dim=1) true_atom_mask = torch.cat((torch.ones((1, self.chain_a_num_res, 37)), torch.ones((1, self.chain_a_num_res, 37)), torch.ones((1, self.chain_b_num_res, 37)), @@ -114,13 +144,34 @@ def test_2_permutation_pentamer(self): batch['all_atom_mask'] = true_atom_mask aligns, per_asym_residue_index = compute_permutation_alignment(out, batch, - batch) - possible_outcome = [[(0, 1), (1, 0), (2, 3), (3, 4), (4, 2)], [(0, 0), (1, 1), (2, 3), (3, 4), (4, 2)]] - wrong_outcome = [[(0, 1), (1, 0), (2, 4), (3, 2), (4, 3)], [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4)]] - self.assertIn(aligns, possible_outcome) - self.assertNotIn(aligns, wrong_outcome) + batch) + + expected_asym_residue_index = { + 1: torch.tensor(list(range(self.chain_a_num_res))), + 2: torch.tensor(list(range(self.chain_a_num_res))), + 3: torch.tensor(list(range(self.chain_b_num_res))), + 4: torch.tensor(list(range(self.chain_b_num_res))), + 5: torch.tensor(list(range(self.chain_b_num_res))) + } + chain_a_permutated_chain_b_permutated = [ + (0, 1), (1, 0), (2, 3), (3, 4), (4, 2)] + chain_a_not_permutated_chain_b_permutated = [ + (0, 0), (1, 1), (2, 3), (3, 4), (4, 2)] + chain_a_permutated_chain_b_not_permuated = [ + (0, 1), (1, 0), (2, 2), (3, 3), (4, 4)] + chain_a_not_permutated_chain_b_not_permuated = [ + (0, 0), (1, 1), (2, 2), (3, 3), (4, 4)] + + # test on the permutation alignments + self.assertIn(aligns, [chain_a_permutated_chain_b_permutated, + chain_a_not_permutated_chain_b_permutated]) + self.assertNotIn(aligns, [chain_a_permutated_chain_b_not_permuated, + chain_a_not_permutated_chain_b_not_permuated]) + + # test on the per_aysm_residue_index + for k, v in expected_asym_residue_index.items(): + self.assertTrue(torch.equal(v, per_asym_residue_index[k])) - # @unittest.skip("Test needs to be fixed post-refactor") def test_3_merge_labels(self): nres_pad = 325 - 57 # suppose the cropping size is 325 batch = { @@ -132,7 +183,7 @@ def test_3_merge_labels(self): } batch['asym_id'] = batch['asym_id'].reshape(1, 57) batch["residue_index"] = torch.tensor([self.residue_index]) - # create fake ground truth atom positions + # create fake ground truth atom positions chain_a1_pos = torch.randint(15, (self.chain_a_num_res, 3 * 37), dtype=torch.float).reshape(1, self.chain_a_num_res, 37, 3) chain_a2_pos = torch.matmul(chain_a1_pos, self.rotation_matrix_x) + 10 @@ -140,42 +191,50 @@ def test_3_merge_labels(self): chain_b1_pos = torch.randint(low=15, high=30, size=(self.chain_b_num_res, 3 * 37), dtype=torch.float).reshape(1, self.chain_b_num_res, 37, 3) chain_b2_pos = torch.matmul(chain_b1_pos, self.rotation_matrix_y) + 10 - chain_b3_pos = torch.matmul(torch.matmul(chain_b1_pos, self.rotation_matrix_z), self.rotation_matrix_x) + 30 - # Below permutate predicted chain positions - pred_atom_position = torch.cat((chain_a2_pos, chain_a1_pos, chain_b2_pos, chain_b3_pos, chain_b1_pos), dim=1) + chain_b3_pos = torch.matmul(torch.matmul( + chain_b1_pos, self.rotation_matrix_z), self.rotation_matrix_x) + 30 + # Below permutate predicted chain positions + pred_atom_position = torch.cat( + (chain_a2_pos, chain_a1_pos, chain_b2_pos, chain_b3_pos, chain_b1_pos), dim=1) pred_atom_mask = torch.ones((1, self.num_res, 37)) - pred_atom_position = pad_features(pred_atom_position, nres_pad, pad_dim=1) + pred_atom_position = pad_features( + pred_atom_position, nres_pad, pad_dim=1) pred_atom_mask = pad_features(pred_atom_mask, nres_pad, pad_dim=1) out = { 'final_atom_positions': pred_atom_position, 'final_atom_mask': pred_atom_mask } - true_atom_position = torch.cat((chain_a1_pos, chain_a2_pos, chain_b1_pos, chain_b2_pos, chain_b3_pos), dim=1) + true_atom_position = torch.cat( + (chain_a1_pos, chain_a2_pos, chain_b1_pos, chain_b2_pos, chain_b3_pos), dim=1) true_atom_mask = torch.cat((torch.ones((1, self.chain_a_num_res, 37)), torch.ones((1, self.chain_a_num_res, 37)), torch.ones((1, self.chain_b_num_res, 37)), torch.ones((1, self.chain_b_num_res, 37)), torch.ones((1, self.chain_b_num_res, 37))), dim=1) - batch['all_atom_positions'] = true_atom_position - batch['all_atom_mask'] = true_atom_mask + batch['all_atom_positions'] = true_atom_position + batch['all_atom_mask'] = true_atom_mask - # Below create a fake_input_features - fake_input_features = { + # Below create a fake_input_features + fake_input_features = { 'asym_id': pad_features(self.asym_id, nres_pad, pad_dim=1), 'sym_id': pad_features(self.sym_id, nres_pad, pad_dim=1), 'entity_id': pad_features(self.entity_id, nres_pad, pad_dim=1), 'aatype': torch.randint(21, size=(1, 325)), 'seq_length': torch.tensor([57]) } - fake_input_features['asym_id'] = fake_input_features['asym_id'].reshape(1, 325) - fake_input_features["residue_index"] = pad_features(torch.tensor(self.residue_index).reshape(1, 57), nres_pad, pad_dim=1) - fake_input_features['all_atom_positions'] = pad_features(true_atom_position, nres_pad, pad_dim=1) - fake_input_features['all_atom_mask'] = pad_features(true_atom_mask, nres_pad=nres_pad, pad_dim=1) - + fake_input_features['asym_id'] = fake_input_features['asym_id'].reshape( + 1, 325) + fake_input_features["residue_index"] = pad_features( + torch.tensor(self.residue_index).reshape(1, 57), nres_pad, pad_dim=1) + fake_input_features['all_atom_positions'] = pad_features( + true_atom_position, nres_pad, pad_dim=1) + fake_input_features['all_atom_mask'] = pad_features( + true_atom_mask, nres_pad=nres_pad, pad_dim=1) + # NOTE - # batch: simulates ground_truth features - # fake_input_features: simulates the data that gonna be used as input for model.forward(fake_input_features) + # batch: simulates ground_truth features + # fake_input_features: simulates the data that are going be used as input for model.forward(fake_input_features) # out: simulates the output of model.forward(fake_input_features) aligns, per_asym_residue_index = compute_permutation_alignment(out, fake_input_features, @@ -185,9 +244,11 @@ def test_3_merge_labels(self): labels = merge_labels(per_asym_residue_index, labels, aligns, original_nres=batch['aatype'].shape[-1]) - self.assertTrue(torch.equal(labels['residue_index'], batch['residue_index'])) + self.assertTrue(torch.equal( + labels['residue_index'], batch['residue_index'])) expected_permutated_gt_pos = torch.cat((chain_a2_pos, chain_a1_pos, chain_b2_pos, chain_b3_pos, chain_b1_pos), dim=1) - - self.assertTrue(torch.equal(labels['all_atom_positions'], expected_permutated_gt_pos)) + + self.assertTrue(torch.equal( + labels['all_atom_positions'], expected_permutated_gt_pos)) From 15113dcb251b781ca4888432161668ff0c7b4bf6 Mon Sep 17 00:00:00 2001 From: Dingquan Yu Date: Fri, 10 May 2024 17:18:49 +0200 Subject: [PATCH 11/12] fixed typing error of anchor_gt_residue --- openfold/utils/multi_chain_permutation.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/openfold/utils/multi_chain_permutation.py b/openfold/utils/multi_chain_permutation.py index c92c8a7a..5f0b62ce 100644 --- a/openfold/utils/multi_chain_permutation.py +++ b/openfold/utils/multi_chain_permutation.py @@ -314,7 +314,7 @@ def split_dim(shape): return labels -def get_per_asym_residue_index(features: dict) -> Dict[int, list]: +def get_per_asym_residue_index(features: dict) -> Dict[int, torch.Tensor]: """ A function that retrieve which residues belong to which asym_id @@ -354,7 +354,7 @@ def get_entity_2_asym_list(features: dict) -> Dict[int, list]: def calculate_input_mask(true_ca_masks: List[torch.Tensor], anchor_gt_idx: torch.Tensor, - anchor_gt_residue: list, + anchor_gt_residue: torch.Tensor, asym_mask: torch.Tensor, pred_ca_mask: torch.Tensor) -> torch.Tensor: """ Calculate an input mask for downstream optimal transformation computation @@ -362,7 +362,7 @@ def calculate_input_mask(true_ca_masks: List[torch.Tensor], anchor_gt_idx: torch Args: true_ca_masks: list of masks from ground truth chains. anchor_gt_idx (Tensor): a tensor with one integer in it. The index of selected ground truth anchor. - anchor_gt_residue:a list of residue indexes that belongs to the selected ground truth anchor + anchor_gt_residue:a 1D vector tensor of residue indexes that belongs to the selected ground truth anchor asym_mask (Tensor): Boolean tensor indicating which regions are selected predicted anchor. pred_ca_mask (Tensor): ca mask from predicted structure. @@ -378,7 +378,7 @@ def calculate_input_mask(true_ca_masks: List[torch.Tensor], anchor_gt_idx: torch def calculate_optimal_transform(true_ca_poses: List[torch.Tensor], - anchor_gt_idx: int, anchor_gt_residue: list, + anchor_gt_idx: int, anchor_gt_residue: torch.Tensor, true_ca_masks: List[torch.Tensor], pred_ca_mask: torch.Tensor, asym_mask: torch.Tensor, pred_ca_pos: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: @@ -391,7 +391,7 @@ def calculate_optimal_transform(true_ca_poses: List[torch.Tensor], Args: true_ca_poses: a list of tensors, corresponding to the c-alpha positions of the ground truth structure. e.g. If there are 5 chains, this list will have a length of 5 anchor_gt_idx (Tensor): a tensor with one integer in it. The index of selected ground truth anchor. - anchor_gt_residue:a list of residue indexes that belongs to the selected ground truth anchor + anchor_gt_residue:a 1D vector tensor of residue indexes that belongs to the selected ground truth anchor true_ca_masks: list of masks from ground truth chains e.g. it will be length=5 if there are 5 chains in ground truth structure pred_ca_mask: A boolean tensor corresponds to the mask to mask the predicted features asym_mask: A boolean tensor that mask out other elements in a tensor if they do not belong to a this asym_id From 55c293ca0bcd882d57379f25214a26aca169f2fc Mon Sep 17 00:00:00 2001 From: Jennifer Wei <97625454+jnwei@users.noreply.github.com> Date: Sat, 11 May 2024 15:37:12 +0700 Subject: [PATCH 12/12] Update test_permutation.py Fixed a small typo in permutation unit test docstring --- tests/test_permutation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_permutation.py b/tests/test_permutation.py index 6f64567f..bf3fbf37 100644 --- a/tests/test_permutation.py +++ b/tests/test_permutation.py @@ -96,7 +96,7 @@ def test_2_permutation_pentamer(self): then the 2 protein A chains are free to be aligned by either order, thus either [(0,1),(1,0)] or [(0,0),(1,1)] is valid. However, the 3 protein B chains should be strictly aligned in the following order: - [(2,3), (3,4), (4,1)], regardless of how protein A chains are aligned. + [(2,3), (3,4), (4,2)], regardless of how protein A chains are aligned. Therefore, the only 2 correct permutations are : [(0, 1), (1, 0), (2, 3), (3, 4), (4, 2)] and