Skip to content

Commit

Permalink
Merge branch 'main' into edge_weighting
Browse files Browse the repository at this point in the history
  • Loading branch information
jyaacoub authored Sep 20, 2023
2 parents 990bb09 + 9060969 commit e34b842
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 10 deletions.
1 change: 1 addition & 0 deletions playground.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
# t_chain=t_chain)
# mt_seq == chain_mt.getSequence()


#%%
from src.data_processing.downloaders import Downloader

Expand Down
3 changes: 1 addition & 2 deletions src/data_processing/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,6 @@ def __init__(self, save_root:str, data_root:str, aln_dir:str,
"please create subset before initialization."
self.subset = subset

print(save_root)
super(BaseDataset, self).__init__(save_root, *args, **kwargs)
self.load()

Expand Down Expand Up @@ -760,7 +759,7 @@ def pre_process(self):
pdb_wt = row['mut.wt_pdb']
pdb_mt = row['mut.mt_pdb']
t_chain = row['affin.chain']

# Getting sequence from pdb file:
missing_wt = pdb_wt == 'NO'
pdb = pdb_mt if missing_wt else pdb_wt
Expand Down
14 changes: 6 additions & 8 deletions src/utils/residue.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def t_chain(self, chain_ID:str):
@property
def sequence(self) -> str:
return self.getSequence()

def getSequence(self) -> str:
"""
camelCase to mimic ProDy Chain class
Expand Down Expand Up @@ -197,7 +197,7 @@ def getCoords(self) -> np.array:
coords.append(res["CB"])
self._coords = np.array(coords)
return self._coords

@staticmethod
def align_coords(c1:np.array,c2:np.array) -> tuple[np.array, np.array]:
"""Aligns the given two 3D coordinate sets"""
Expand Down Expand Up @@ -234,7 +234,6 @@ def TM_score(self, template:'Chain'):
# compute the distance for each pair of atoms
di = np.sum((c1 - c2) ** 2, 1) # sum along first axis
return np.sum(1 / (1 + (di / d0) ** 2)) / L


def get_mutated_seq(self, muts:list[str], reversed:bool=False) -> tuple[str, str]:
"""
Expand Down Expand Up @@ -324,7 +323,6 @@ def _pdb_get_chains(pdb_file: str, model:int=1) -> OrderedDict:
alt_loc = line[16] # some atoms can have multiple alternate locations, one for each protein conformation.
res_name = line[17:20].strip()
if res_name == 'UNK': continue # WARNING: unknown residues are skipped

if atm_type not in ['CA', 'CB']: continue
icode = line[26].strip() # dumb icode because residues will sometimes share the same res num
# (https://www.wwpdb.org/documentation/file-format-content/format33/sect9.html)
Expand All @@ -339,7 +337,7 @@ def _pdb_get_chains(pdb_file: str, model:int=1) -> OrderedDict:
# Only keep first alt_loc
if atm_type in chains[curr_chain].get(res_key, {}) and bool(alt_loc.strip()):
continue

assert atm_type not in chains[curr_chain].get(res_key, {}), \
f"Duplicate {atm_type} for residue {res_key} in {pdb_file}"

Expand All @@ -348,6 +346,7 @@ def _pdb_get_chains(pdb_file: str, model:int=1) -> OrderedDict:
chains[curr_chain].setdefault(res_key, OrderedDict())[atm_type] = np.array([x,y,z])

# Saving residue name
res_name = line[17:20].strip()
assert ("name" not in chains[curr_chain].get(res_key, {})) or \
(chains[curr_chain][res_key]["name"] == res_name), \
f"Inconsistent residue name for residue {res_key} in {pdb_file}"
Expand Down Expand Up @@ -394,7 +393,7 @@ def buildHessian(self, cutoff:int=15., g:float=1.0):
kirchhoff[i, i] = kirchhoff[i, i] + g
kirchhoff[j, j] = kirchhoff[j, j] + g
return hessian

def get_contact_map(self, display=False, title="Residue Contact Map") -> np.array:
"""
Returns the residue contact map for that structure.
Expand Down Expand Up @@ -432,5 +431,4 @@ def get_contact_map(self, display=False, title="Residue Contact Map") -> np.arra
plt.title(title)
plt.show()

return pairwise_distances

return pairwise_distances

0 comments on commit e34b842

Please sign in to comment.