Update multi-chain permutation and permutation unittest #406
Thank you so much for expanding the docstrings in multi_chain_permutation.py and for fixing the tests.
There is a lot of non-trivial code associated with determining the best permutations, and the docstrings will go a long way towards making the code more approachable.
I have some suggestions to help further improve the clarity of the code, but overall, really nice work!
@@ -88,7 +103,7 @@ def get_optimal_transform(
     return r, x


-def get_least_asym_entity_or_longest_length(batch, input_asym_id):
+def get_least_asym_entity_or_longest_length(batch:dict, input_asym_id:list)->Tuple[torch.Tensor, List[torch.Tensor]]:
nit: please add 1 space between the argument and the type.
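For illustration, the signature with PEP 8 spacing (a space after each colon and spaces around `->`) might look like the sketch below. The return annotation is quoted only so that the snippet runs without importing torch, and the body is a placeholder rather than the real implementation.

```python
from typing import List, Tuple


def get_least_asym_entity_or_longest_length(
    batch: dict, input_asym_id: list
) -> "Tuple[torch.Tensor, List[torch.Tensor]]":  # quoted: torch is not imported in this sketch
    """Placeholder body; the real logic lives in multi_chain_permutation.py."""
    raise NotImplementedError
```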
pred_ca_pos: predicted positions of c-alpha atoms from the results of model.forward()
pred_ca_mask: a boolean tensor that masks pred_ca_pos
true_ca_poses: a list of tensors, corresponding to the c-alpha positions of the ground truth structure. e.g. If there are 5 chains, this list will have a length of 5
true_ca_masks: a list of tensors, corresponding to the masks of c-alpha positions of the ground truth structure. If there are 5 chains, this list will have a length of 5
Could you explain what relationship (if any) there is between true_ca_masks and pred_ca_mask? Is this an indication of which residues between chains are expected to align?
If you think this is sufficiently defined elsewhere in the multimer codebase, then maybe a simple addition will suffice here.
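One way to make the docstring concrete is to spell out how each mask pairs with its coordinates. The sketch below is only an illustration, with plain Python lists standing in for tensors. The helper name `select_resolved` and the toy values are made up, and it does not claim to describe how multi_chain_permutation.py actually relates true_ca_masks to pred_ca_mask.

```python
from typing import List, Sequence, Tuple

Coord = Tuple[float, float, float]


def select_resolved(ca_pos: Sequence[Coord], ca_mask: Sequence[bool]) -> List[Coord]:
    """Keep only the C-alpha positions whose mask entry is True."""
    if len(ca_pos) != len(ca_mask):
        raise ValueError("positions and mask must have the same length")
    return [pos for pos, keep in zip(ca_pos, ca_mask) if keep]


# Toy ground truth with 2 chains: one list entry per chain, as the docstring describes.
true_ca_poses = [[(0.0, 0.0, 0.0), (1.0, 0.0, 0.0)], [(5.0, 0.0, 0.0)]]
true_ca_masks = [[True, False], [True]]

# Apply each chain's mask to that chain's coordinates.
resolved_per_chain = [
    select_resolved(pos, mask) for pos, mask in zip(true_ca_poses, true_ca_masks)
]
```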
@jnwei Many thanks for your suggestions and reviews :D I've updated the PR in the new commit
Thanks for updating the docstrings Dingquan! Just a few more minor comments
-def calculate_input_mask(true_ca_masks, anchor_gt_idx, anchor_gt_residue,
-                         asym_mask, pred_ca_mask):
+def calculate_input_mask(true_ca_masks: List[torch.Tensor], anchor_gt_idx: torch.Tensor,
+                         anchor_gt_residue: list,
nit: The docstring says the type is a Tensor. Is this a list or a Tensor?
-                                asym_mask,
-                                pred_ca_pos):
+def calculate_optimal_transform(true_ca_poses: List[torch.Tensor],
+                                anchor_gt_idx: int, anchor_gt_residue: list,
Same thing here, is anchor_gt_residue a list or a tensor?
tests/test_permutation.py
Outdated
fake_input_features['all_atom_mask'] = pad_features(true_atom_mask, nres_pad=nres_pad, pad_dim=1)

# NOTE
# batch: simulates ground_truth features
nit: replace "gonna" with "going to"
tests/test_permutation.py
Outdated
batch)
print(f"##### aligns is {aligns}")
possible_outcome = [[(0, 1), (1, 0), (2, 3), (3, 4), (4, 2)], [(0, 0), (1, 1), (2, 3), (3, 4), (4, 2)]]
Just a reminder: a comment explaining why you expect the given possible_outcome, and why wrong_outcome is bad, would be very helpful.
To help explain the examples, you could even break up the examples into different variables with acceptable / not acceptable cases. For example:
chain_a_permuted = [(0, 1), (1, 0), (2, 2), (3, 3), (4, 4)]
chain_b_permuted = [(0, 0), (1, 1), (2, 3), (3, 4), (4, 2)]
chain_a_and_b_permuted = [(0, 1), (1, 0), (2, 3), (3, 4), (4, 2)]
no_permutation = [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4)]
possible_outcome = [chain_a_permuted, chain_b_permuted]
wrong_outcome = [chain_a_and_b_permuted, no_permutation]
Although in this example, I still don't understand why chain_a_and_b_permuted would be under wrong_outcome.
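A hedged sketch of how the test could document its expectations: name each candidate alignment and check it against the entity grouping. It assumes, based only on the tuples above, that chains 0 and 1 are copies of one entity and chains 2 to 4 are copies of another, and that each tuple pairs a chain index with its matched counterpart. `ENTITY_OF_CHAIN`, `is_valid_alignment`, and `cross_entity_swap` are hypothetical names, not part of tests/test_permutation.py.

```python
from typing import Dict, List, Tuple

Alignment = List[Tuple[int, int]]

# Assumed layout: chains 0-1 are copies of entity "A", chains 2-4 of entity "B".
ENTITY_OF_CHAIN: Dict[int, str] = {0: "A", 1: "A", 2: "B", 3: "B", 4: "B"}


def is_valid_alignment(aligns: Alignment, entity_of_chain: Dict[int, str]) -> bool:
    """Return True if aligns is a one-to-one mapping that never crosses entities."""
    firsts = [a for a, _ in aligns]
    seconds = [b for _, b in aligns]
    chains = sorted(entity_of_chain)
    # Every chain must appear exactly once on each side of the mapping.
    if sorted(firsts) != chains or sorted(seconds) != chains:
        return False
    # A chain may only be matched with a copy of the same entity.
    return all(entity_of_chain[a] == entity_of_chain[b] for a, b in aligns)


chain_a_permuted = [(0, 1), (1, 0), (2, 2), (3, 3), (4, 4)]
chain_b_permuted = [(0, 0), (1, 1), (2, 3), (3, 4), (4, 2)]
# Invalid under the assumed layout: an "A" chain is matched with a "B" chain.
cross_entity_swap = [(0, 2), (1, 1), (2, 0), (3, 3), (4, 4)]
```

A check like this only confirms that an alignment is structurally plausible. Which of the valid alignments is actually expected for the fake test data still needs the explanatory comment requested above.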
Fixed a small typo in permutation unit test docstring
Thanks for the additions to the docstring Dingquan! The explanation for the tests is much clearer now.
Hi @christinaflo and @jnwei,
Sorry, I forgot to update the unit tests for multi-chain permutation after the major updates to the functions. I have now added the steps needed to prepare fake test data for these functions, and all 3 tests run successfully. Hope it helps.
BTW, I'm now adding type annotations to the functions in multi_chain_permutation.py and fixing some comments in that file as well. These can go into a separate PR if you prefer.
Dingquan