Skip to content

Commit

Permalink
ruff
Browse files Browse the repository at this point in the history
  • Loading branch information
karinazad committed Jan 28, 2025
1 parent 3921d30 commit dfb35cd
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 13 deletions.
16 changes: 9 additions & 7 deletions src/lobster/model/_lobster_fold.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,37 +187,39 @@ def configure_optimizers(self):
return {"optimizer": optimizer, "lr_scheduler": scheduler}

def predict_fv(self, fv_heavy, fv_light):
linker = 'G' * 25
linker = "G" * 25
homodimer_sequence = fv_heavy + linker + fv_light

tokenized_homodimer = self.tokenizer([homodimer_sequence], return_tensors="pt", add_special_tokens=False)

with torch.no_grad():
position_ids = torch.arange(len(homodimer_sequence), dtype=torch.long)
position_ids[len(fv_heavy) + len(linker):] += 512
tokenized_homodimer['position_ids'] = position_ids.unsqueeze(0)
position_ids[len(fv_heavy) + len(linker) :] += 512
tokenized_homodimer["position_ids"] = position_ids.unsqueeze(0)

tokenized_homodimer = {key: tensor.cuda() for key, tensor in tokenized_homodimer.items()}
with torch.no_grad():
output = self.model(**tokenized_homodimer)

linker_mask = torch.tensor([1] * len(fv_heavy) + [0] * len(linker) + [1] * len(fv_light))[None, :, None]
output['atom37_atom_exists'] = output['atom37_atom_exists'] * linker_mask.to(output['atom37_atom_exists'].device)
output["atom37_atom_exists"] = output["atom37_atom_exists"] * linker_mask.to(
output["atom37_atom_exists"].device
)

pdb_file = self.model.output_to_pdb(output)[0]

# Split the PDB content into lines and modify chain identifiers and residue numbers
pdb_lines = pdb_file.splitlines()
modified_pdb_lines = []
chain_id = 'H' # Start with chain H for fv_heavy
chain_id = "H" # Start with chain H for fv_heavy
current_residue_num_offset = 0
last_residue_num = 0

for line in pdb_lines:
if line.startswith("ATOM") or line.startswith("HETATM"):
res_seq_num = int(line[22:26].strip())
if res_seq_num > len(fv_heavy):
chain_id = 'L' # Switch to chain L for fv_light
chain_id = "L" # Switch to chain L for fv_light
if current_residue_num_offset == 0:
# Calculate offset for light chain to start at 1
current_residue_num_offset = res_seq_num - 1
Expand All @@ -226,7 +228,7 @@ def predict_fv(self, fv_heavy, fv_light):
new_res_seq_num = res_seq_num - current_residue_num_offset
new_res_seq_num_str = f"{new_res_seq_num:>4}"
last_residue_num = new_res_seq_num # Keep track of the last residue number

# Modify the original line with the correct chain and residue number
modified_line = line[:21] + chain_id + new_res_seq_num_str + line[26:]
modified_pdb_lines.append(modified_line)
Expand Down
17 changes: 11 additions & 6 deletions tests/lobster/model/test__lobsterfold.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
import os
from io import StringIO

import pytest
import torch
from Bio.PDB import PDBParser, Superimposer
from lobster.data import PDBDataModule
from lobster.extern.openfold_utils import backbone_loss
from lobster.model import LobsterPLMFold
from lobster.transforms import StructureFeaturizer
from torch import Size, Tensor
from Bio.PDB import PDBParser, Superimposer
from io import StringIO


torch.backends.cuda.matmul.allow_tf32 = True

Expand All @@ -22,12 +21,14 @@ def max_length():
@pytest.fixture
def example_fv():
fv_heavy = "VKLLEQSGAEVKKPGASVKVSCKASGYSFTSYGLHWVRQAPGQRLEWMGWISAGTGNTKYSQKFRGRVTFTRDTSATTAYMGLSSLRPEDTAVYYCARDPYGGGKSEFDYWGQGTLVTVSS"
fv_light = "ELVMTQSPSSLSASVGDRVNIACRASQGISSALAWYQQKPGKAPRLLIYDASNLESGVPSRFSGSGSGTDFTLTISSLQPEDFAIYYCQQFNSYPLTFGGGTKVEIKRTV"
fv_light = (
"ELVMTQSPSSLSASVGDRVNIACRASQGISSALAWYQQKPGKAPRLLIYDASNLESGVPSRFSGSGSGTDFTLTISSLQPEDFAIYYCQQFNSYPLTFGGGTKVEIKRTV"
)
return (fv_heavy, fv_light)


@pytest.fixture
def example_fv(scope="session"):
def example_fv_pdb(scope="session"):
return os.path.join(os.path.dirname(__file__), "../../../test_data/fv.pdb")


Expand Down Expand Up @@ -76,7 +77,11 @@ def test_dataloader_tokenizer(self, model):

@pytest.mark.skip(reason="fwd pass too slow")
def test_predict_fv(self, model, example_fv):
pdb_string = model.predict_fv(example_fv[0],example_fv[1])
pdb_string = model.predict_fv(example_fv[0], example_fv[1])

# NOTE from zadorozk: ruff checks were failing because ground_truth_file was not defined
# TODO FIXME
ground_truth_file = None

# Parse the input PDB string
parser = PDBParser(QUIET=True)
Expand Down

0 comments on commit dfb35cd

Please sign in to comment.