Skip to content

Commit

Permalink
Merge pull request #85 from xiaoruiDong/minor_fixes
Browse files Browse the repository at this point in the history
A few updates and fixes
  • Loading branch information
xiaoruiDong committed Mar 15, 2024
2 parents c0c7242 + 9dd1869 commit 603849f
Show file tree
Hide file tree
Showing 9 changed files with 713 additions and 390 deletions.
561 changes: 561 additions & 0 deletions ipython/Conformer Generation Workflow.ipynb

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions ipython/Generate Atommapped SMILES.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -838,7 +838,7 @@
"source": [
"new_rxn = Reaction(r_complex, p_complex)\n",
"display(new_rxn)\n",
"print(rxn.to_smiles())"
"print(new_rxn.to_smiles())"
]
},
{
Expand Down Expand Up @@ -870,7 +870,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.12"
"version": "3.9.18"
},
"vscode": {
"interpreter": {
Expand Down
327 changes: 0 additions & 327 deletions ipython/stochastic_conf_pipeline_LP.ipynb

This file was deleted.

99 changes: 60 additions & 39 deletions rdmc/conformer_generation/embedders.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,27 +5,27 @@
Modules for providing initial guess geometries
"""

from pathlib import Path

from rdmc import RDKitMol
import os.path as osp
import yaml

import numpy as np
from time import time

try:
import torch
from torch_geometric.data import Batch
except ImportError:
pass

from .utils import *

# GeoMol relevant imports
try:
import torch
from rdmc.external.GeoMol.model.model import GeoMol
from rdmc.external.GeoMol.model.featurization import featurize_mol_from_smiles
from rdmc.external.GeoMol.model.inference import construct_conformers
except ImportError:
from geomol.model import GeoMol
from geomol.featurization import featurize_mol_from_smiles, from_data_list
from geomol.inference import construct_conformers
from geomol.utils import model_path as geomol_model_path
import yaml # only used to load GeoMol parameters
except ImportError as e:
GeoMol = None
print(e)
print("No GeoMol installation detected. Skipping import...")
print("Please install the GeoMol fork at https://github.com/xiaoruiDong/GeoMol")


class ConfGenEmbedder:
Expand Down Expand Up @@ -58,8 +58,10 @@ def update_mol(self, smiles: str):
# Copy the graph but remove conformers
self.mol = self.mol.Copy(quickCopy=True)

def embed_conformers(self,
n_conformers: int):
def embed_conformers(
self,
n_conformers: int
):
"""
Embed conformers according to the molecule graph.
Expand All @@ -71,10 +73,11 @@ def embed_conformers(self,
"""
raise NotImplementedError

def update_stats(self,
n_trials: int,
time: float = 0.
) -> dict:
def update_stats(
self,
n_trials: int,
time: float = 0.
) -> dict:
"""
Update the statistics of the conformer generation.
Expand All @@ -88,10 +91,12 @@ def update_stats(self,
n_success = self.mol.GetNumConformers()
self.n_success = n_success
self.percent_success = n_success / n_trials * 100
stats = {"iter": self.iter,
"time": time,
"n_success": self.n_success,
"percent_success": self.percent_success}
stats = {
"iter": self.iter,
"time": time,
"n_success": self.n_success,
"percent_success": self.percent_success
}
self.stats.append(stats)
return stats

Expand All @@ -104,9 +109,11 @@ def write_mol_data(self):
"""
return mol_to_dict(self.mol, copy=False, iter=self.iter)

def __call__(self,
smiles: str,
n_conformers: int):
def __call__(
self,
smiles: str,
n_conformers: int
):
"""
Embed conformers according to the molecule graph.
Expand Down Expand Up @@ -137,37 +144,51 @@ class GeoMolEmbedder(ConfGenEmbedder):
Embed conformers using GeoMol.
Args:
trained_model_dir (str): Directory of the trained model.
trained_model_dir (str, optional): Directory of the trained model. If not provided, the models distributed with the package will be used.
dataset (str, optional): Dataset used for training. Defaults to ``"drugs"``.
temp_schedule (str, optional): Temperature schedule. Defaults to ``"linear"``.
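track_stats (bool, optional): Whether to track the statistics of the conformer generation. Defaults to ``False``.
device (str, optional): Device on which the model weights are loaded and inference is run. Defaults to ``"cpu"``.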
"""

def __init__(self,
trained_model_dir: str,
dataset: str = "drugs",
temp_schedule: str = "linear",
track_stats: bool = False):
def __init__(
self,
trained_model_dir: str = None,
dataset: str = "drugs",
temp_schedule: str = "linear",
track_stats: bool = False,
device: str = 'cpu',
):
if GeoMol is None:
raise ImportError("No GeoMol installation detected. Please install the GeoMol fork at https://github.com/xiaoruiDong/GeoMol.")
super(GeoMolEmbedder, self).__init__(track_stats)

# TODO: add option of pre-pruning geometries using alpha values
# TODO: inverstigate option of changing "temperature" each iteration to sample diverse geometries
# TODO: investigate option of changing "temperature" each iteration to sample diverse geometries
self.device = device

with open(osp.join(trained_model_dir, "model_parameters.yml")) as f:
trained_model_dir = geomol_model_path / dataset if trained_model_dir is None else Path(trained_model_dir)
with open(trained_model_dir / "model_parameters.yml") as f:
model_parameters = yaml.full_load(f)
model = GeoMol(**model_parameters)

state_dict = torch.load(osp.join(trained_model_dir, "best_model.pt"), map_location=torch.device('cpu'))
state_dict = torch.load(trained_model_dir / "best_model.pt", map_location=torch.device(device))
model.load_state_dict(state_dict, strict=True)
model.to(self.device)
model.eval()
self.model = model
self.tg_data = None
self.std = model_parameters["hyperparams"]["random_vec_std"]
self.temp_schedule = temp_schedule
self.dataset = dataset

def embed_conformers(self,
n_conformers: int):
def to(self, device: str):
# Move the embedder to another device.
# Args:
#     device (str): Target device; it is forwarded unchanged to ``self.model.to``.
# Remember the device so that later code reading ``self.device`` uses the same
# device as the model.
self.device = device
self.model.to(device)

def embed_conformers(
self,
n_conformers: int
):
"""
Embed conformers according to the molecule graph.
Expand All @@ -186,11 +207,11 @@ def embed_conformers(self,
# featurize data and run GeoMol
if self.tg_data is None:
self.tg_data = featurize_mol_from_smiles(self.smiles, dataset=self.dataset)
data = Batch.from_data_list([self.tg_data]) # need to run this bc of dumb internal GeoMol processing
data = from_data_list([self.tg_data]).to(self.device) # need to run this bc of dumb internal GeoMol processing
self.model(data, inference=True, n_model_confs=n_conformers)

# process predictions
model_coords = construct_conformers(data, self.model).double().cpu().detach().numpy()
model_coords = construct_conformers(data, self.model, self.device).double().cpu().detach().numpy()
split_model_coords = np.split(model_coords, n_conformers, axis=1)

# package in mol and return
Expand Down
2 changes: 1 addition & 1 deletion rdmc/conformer_generation/optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def optimize_conformers(self,
positions = opt_mol.GetPositions()
conf = new_mol.GetConformer(id=c_id)
conf.SetPositions(positions)
energy = float(opt_mol.GetProp('total energy / Eh')) # * HARTREE_TO_KCAL_MOL # kcal/mol (TODO: check)
energy = props['total energy']
mol_data[c_id].update({"positions": positions, # issues if not all opts succeeded?
"conf": conf,
"energy": energy})
Expand Down
23 changes: 22 additions & 1 deletion rdmc/external/logparser/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
try:
from ipywidgets import interact, IntSlider, Dropdown, FloatLogSlider
except ImportError:
pass
interact = None


class BaseLog(object):
Expand Down Expand Up @@ -270,6 +270,12 @@ def get_scf_energies(self,
# sub2 stores the energies for subsequent jobs e.g., multiple sps
if 'opt' in self.job_type or 'scan' in self.job_type:
sub1 = scf_energies[:num_opt_geoms][self.get_converged_geom_idx()]
elif 'irc' in self.job_type:
# If corrector steps are taken and the job fails because a corrector step fails,
# there is one more energy value than there are geometries
sub1 = scf_energies[: len(self.cclib_results.optstatus)][
self.get_converged_geom_idx()
]
else:
sub1 = scf_energies[self.get_converged_geom_idx()]
if 'scan' not in self.job_type:
Expand Down Expand Up @@ -700,6 +706,9 @@ def interact_opt(self,
Returns:
interact
"""
if interact is None:
raise ImportError('interact is not installed. Please install it by `pip install ipywidgets`.')

mol = self.get_mol(converged=False, sanitize=sanitize, backend=backend)
xyzs = self.get_xyzs(converged=False)
sdfs = [mol.ToMolBlock(confId=i) for i in range(mol.GetNumConformers())]
Expand Down Expand Up @@ -838,6 +847,9 @@ def view_freq(self,
Returns:
interact
"""
if interact is None:
raise ImportError('interact is not installed. Please install it by `pip install ipywidgets`.')

xyz = self.get_xyzs(converged=True)[0]
lines = xyz.splitlines()
vib_xyz_list = lines[0:2]
Expand All @@ -851,6 +863,9 @@ def interact_freq(self):
"""
Create an IPython interactive widget to investigate the frequency calculation.
"""
if interact is None:
raise ImportError('interact is not installed. Please install it by `pip install ipywidgets`.')

dropdown = Dropdown(
options=self.freqs,
value=self.freqs[0],
Expand Down Expand Up @@ -1018,6 +1033,9 @@ def interact_irc(self,
Returns:
interact
"""
if interact is None:
raise ImportError('interact is not installed. Please install it by `pip install ipywidgets`.')

mol = self._process_irc_mol(sanitize=sanitize, converged=converged, backend=backend, bothway=bothway)
sdfs = [mol.ToMolBlock(confId=i) for i in range(mol.GetNumConformers())]
xyzs = self.get_xyzs(converged=converged)
Expand Down Expand Up @@ -1188,6 +1206,9 @@ def interact_scan(self,
Returns:
interact
"""
if interact is None:
raise ImportError('interact is not installed. Please install it by `pip install ipywidgets`.')

mol = self._process_scan_mol(align_scan=align_scan,
align_frag_idx=align_frag_idx,
sanitize=sanitize,
Expand Down
2 changes: 1 addition & 1 deletion rdmc/external/logparser/gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def _update_schemes(self):
scheme_str = ''.join(line.strip('\n')[1:] for line in scheme_lines[1:])

try:
self._schemes = scheme_to_dict(scheme_str)
self._schemes = scheme_to_dict(scheme_str.lower())
except Exception as e:
print(f'Calculation scheme parser encounters a problem. \nGot: {e}\n'
f'Feel free to raise an issue about this error at RDMC\'s Github Repo.')
Expand Down
Loading

0 comments on commit 603849f

Please sign in to comment.