Skip to content

Commit

Permalink
Cleanup cif code; properly write ligands as HETATM (#189)
Browse files Browse the repository at this point in the history
* Simplifiy PDBAtom code

* Better ligand handling

* Fix how ligand atoms are determined

* Remove deprecated code, add types
  • Loading branch information
wukevin authored Nov 28, 2024
1 parent 55c732e commit 2527c38
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 38 deletions.
12 changes: 7 additions & 5 deletions chai_lab/data/io/cif_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@
import gemmi
import modelcif
import torch
from ihm import ChemComp, DNAChemComp, LPeptideChemComp, RNAChemComp
from ihm import ChemComp, DNAChemComp, LPeptideChemComp, NonPolymerChemComp, RNAChemComp
from modelcif import Assembly, AsymUnit, Entity, dumper, model
from torch import Tensor

from chai_lab.data.io.pdb_utils import (
PDBAtom,
PDBContext,
entity_to_pdb_atoms,
get_pdb_chain_name,
Expand Down Expand Up @@ -91,7 +92,7 @@ def _to_chem_component(res_name_3: str, entity_type: int):
match entity_type:
case EntityType.LIGAND.value:
code = res_name_3
return ChemComp(res_name_3, code, code_canonical=code)
return NonPolymerChemComp(res_name_3)
case EntityType.PROTEIN.value:
code = restype_3to1.get(res_name_3, res_name_3)
one_letter_code = gemmi.find_tabulated_residue(res_name_3).one_letter_code
Expand Down Expand Up @@ -142,7 +143,7 @@ def _make_chain(record: dict) -> AsymUnit:

chains_map = {r["asym_id"]: _make_chain(r) for r in records}

pdb_atoms: list[list] = entity_to_pdb_atoms(context)
pdb_atoms: list[list[PDBAtom]] = entity_to_pdb_atoms(context)

_assembly = Assembly(chains_map.values(), name="Assembly 1")

Expand All @@ -158,7 +159,7 @@ def get_atoms(self):
x=a.pos[0],
y=a.pos[1],
z=a.pos[2],
het=False,
het=a.het,
biso=a.b_factor,
occupancy=1.00,
)
Expand Down Expand Up @@ -188,7 +189,8 @@ def get_atoms(self):
model_group = model.ModelGroup([_model], name="pred")
system.model_groups.append(model_group)

dumper.write(open(outpath, "w"), systems=[system])
with open(outpath, "w") as sink:
dumper.write(sink, systems=[system])


def outputs_to_cif(
Expand Down
45 changes: 12 additions & 33 deletions chai_lab/data/io/pdb_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import logging
import string
from collections import defaultdict
from dataclasses import asdict, dataclass
from dataclasses import asdict, dataclass, replace
from functools import cached_property
from pathlib import Path

Expand Down Expand Up @@ -36,7 +36,7 @@ def get_pdb_chain_name(asym_id: int) -> str:

@dataclass(frozen=True)
class PDBAtom:
record_type: str
het: bool
atom_index: int
atom_name: str
alt_loc: str
Expand All @@ -55,8 +55,9 @@ def __str__(
self,
):
# currently this works only for single-char chain tags
record_type = "HETATM" if self.het else "ATOM"
atom_line = (
f"{self.record_type:<6}{self.atom_index:>5} {self.atom_name:<4}{self.alt_loc:>1}"
f"{record_type:<6}{self.atom_index:>5} {self.atom_name:<4}{self.alt_loc:>1}"
f"{self.res_name_3:>3} {self.chain_tag:>1}"
f"{self.residue_index:>4}{self.insertion_code:>1} "
f"{self.pos[0]:>8.3f}{self.pos[1]:>8.3f}{self.pos[2]:>8.3f}"
Expand All @@ -65,24 +66,6 @@ def __str__(
)
return atom_line

def rename(self, atom_name: str) -> "PDBAtom":
return PDBAtom(
self.record_type,
self.atom_index,
atom_name,
self.alt_loc,
self.res_name_3,
self.chain_tag,
self.asym_id,
self.residue_index,
self.insertion_code,
self.pos,
self.occupancy,
self.b_factor,
self.element,
self.charge,
)


def write_pdb(chain_atoms: list[list[PDBAtom]], out_path: str):
with open(out_path, "w") as f:
Expand Down Expand Up @@ -118,21 +101,14 @@ class PDBContext:
def token_res_names_to_string(self) -> list[str]:
return [tensorcode_to_string(x) for x in self.token_residue_names.cpu()]

@property
def is_ligand(self) -> bool:
return self.is_entity(EntityType.LIGAND)

def is_entity(self, ety: EntityType) -> bool:
return self.token_entity_type[0].item() == ety.value

def get_chain_entity_type(self, asym_id: int) -> int:
mask = self.token_asym_id == asym_id
assert mask.sum() > 0
e_type = self.token_entity_type[mask][0].item()
assert isinstance(e_type, int)
return e_type

def get_pdb_atoms(self):
def get_pdb_atoms(self) -> list[PDBAtom]:
# warning: calling this on cuda tensors is extremely slow
atom_asym_id = self.token_asym_id[self.atom_token_index]
# atom level attributes
Expand All @@ -148,15 +124,18 @@ def get_pdb_atoms(self):
_atomic_num_to_element(int(x.item())) for x in self.atom_ref_element
]

pdb_atoms = []
pdb_atoms: list[PDBAtom] = []
num_atoms = self.atom_coords.shape[0]
for atom_index in range(num_atoms):
if not self.atom_exists_mask[atom_index].item():
# skip missing atoms
continue

token_index = self.atom_token_index[atom_index]
atom = PDBAtom(
record_type="ATOM" if not self.is_ligand else "HETATM",
het=(
self.token_entity_type[token_index].item()
== EntityType.LIGAND.value
),
atom_index=atom_index,
atom_name=atom_names[atom_index],
alt_loc="",
Expand Down Expand Up @@ -206,7 +185,7 @@ def rename_ligand_atoms(atoms: list[PDBAtom]) -> list[PDBAtom]:
idx = atom_type_counter.get(atom.element, 1)
atom_type_counter[atom.element] = idx + 1
base_name = atom.atom_name
renumbered_atoms.append(atom.rename(f"{base_name}_{idx}"))
renumbered_atoms.append(replace(atom, atom_name=f"{base_name}_{idx}"))
return renumbered_atoms


Expand Down

0 comments on commit 2527c38

Please sign in to comment.