Skip to content

Commit

Permalink
remove langchain dependency
Browse files Browse the repository at this point in the history
  • Loading branch information
geemi725 committed Dec 7, 2023
1 parent 89d7626 commit 2e7a69a
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 21 deletions.
25 changes: 15 additions & 10 deletions exmol/exmol.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,8 @@
from rdkit.Chem.Draw import MolToImage as mol2img, DrawMorganBit # type: ignore
from rdkit.Chem import rdchem # type: ignore
from rdkit.DataStructs.cDataStructs import BulkTanimotoSimilarity, TanimotoSimilarity # type: ignore
import langchain.llms as llms
import langchain.prompts as prompts

import openai
from . import stoned
from .plot_utils import _mol_images, _image_scatter, _bit2atoms
from .data import *
Expand Down Expand Up @@ -392,6 +391,7 @@ def _check_alphabet_consistency(
alphabet_symbols = _alphabet_to_elements(set(alphabet_symbols))
# find all elements in smiles (Upper alpha or upper alpha followed by lower alpha)
smiles_symbols = set(re.findall(r"[A-Z][a-z]?", smiles))

if check and not smiles_symbols.issubset(alphabet_symbols):
# show which symbols are not in alphabet
raise ValueError(
Expand Down Expand Up @@ -1410,7 +1410,7 @@ def merge_text_explains(
def text_explain_generate(
text_explanations: List[Tuple[str, float]],
property_name: str,
llm: Optional[llms.BaseLLM] = None,
llm_model: str = 'gpt-4',
single: bool = True,
) -> str:
"""Insert text explanations into template, and generate explanation.
Expand All @@ -1430,14 +1430,19 @@ def text_explain_generate(
for x in text_explanations
]
)
prompt_template = prompts.PromptTemplate(
input_variables=["property", "text"],
template=_single_prompt if single else _multi_prompt,
)

prompt_template = _single_prompt if single else _multi_prompt
prompt = prompt_template.format(property=property_name, text=text)
if llm is None:
llm = llms.OpenAI(temperature=0.05)
return llm(prompt)

messages = [{"role": "system", "content": "Your goal is to explain which molecular features are important to its properties based on the given text."},
{"role": "user", "content": prompt}]
response = openai.ChatCompletion.create(
model=llm_model,
messages=messages,
temperature=0.05,
)

return response.choices[0].message["content"]


def text_explain(
Expand Down
11 changes: 7 additions & 4 deletions paper2_LIME/RF-lime.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,14 @@
"import numpy as np\n",
"import mordred, mordred.descriptors\n",
"from mordred import HydrogenBond, Polarizability\n",
"from mordred import SLogP, AcidBase, BertzCT, Aromatic, BondCount, AtomCount\n",
"from mordred import SLogP, AcidBase, Aromatic, BondCount, AtomCount\n",
"from mordred import Calculator\n",
"\n",
"import exmol as exmol\n",
"from rdkit.Chem.Draw import rdDepictor\n",
"import os\n",
"from sklearn.ensemble import RandomForestRegressor\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.metrics import roc_auc_score, plot_roc_curve\n",
"\n",
"os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
"rdDepictor.SetPreferCoordGen(True)\n",
Expand All @@ -50,6 +49,9 @@
"soldata = pd.read_csv(\n",
" \"https://github.com/whitead/dmol-book/raw/main/data/curated-solubility-dataset.csv\"\n",
")\n",
"# Drop rows whose SMILES string contains the character 'P'\n",
"soldata = soldata[soldata[\"SMILES\"].str.contains(\"P\") == False]\n",
"\n",
"features_start_at = list(soldata.columns).index(\"MolWt\")"
]
},
Expand Down Expand Up @@ -97,7 +99,8 @@
"outputs": [],
"source": [
"raw_features = np.array(raw_features)\n",
"labels = soldata[\"Solubility\"]"
"labels = soldata[\"Solubility\"]\n",
"print(len(labels)==len(molecules))"
]
},
{
Expand Down Expand Up @@ -197,7 +200,7 @@
"metadata": {},
"outputs": [],
"source": [
"smi = soldata.SMILES[1500]\n",
"smi = soldata.SMILES[150]\n",
"stoned_kwargs = {\n",
" \"num_samples\": 2000,\n",
" \"alphabet\": exmol.get_basic_alphabet(),\n",
Expand Down
9 changes: 2 additions & 7 deletions paper2_LIME/Solubility-RNN.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,15 @@
"source": [
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"from matplotlib.patches import Rectangle, FancyBboxPatch\n",
"from matplotlib.offsetbox import AnnotationBbox\n",
"import seaborn as sns\n",
"import skunk\n",
"import matplotlib as mpl\n",
"import numpy as np\n",
"import tensorflow as tf\n",
"import selfies as sf\n",
"import exmol\n",
"from dataclasses import dataclass\n",
"from rdkit.Chem.Draw import rdDepictor, MolsToGridImage\n",
"from rdkit.Chem import MolFromSmiles, MACCSkeys\n",
"from rdkit.Chem import MolFromSmiles\n",
"import random\n",
"\n",
"\n",
"rdDepictor.SetPreferCoordGen(True)\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib.font_manager as font_manager\n",
Expand Down Expand Up @@ -66,6 +60,7 @@
"soldata = pd.read_csv(\n",
" \"https://github.com/whitead/dmol-book/raw/main/data/curated-solubility-dataset.csv\"\n",
")\n",
"\n",
"features_start_at = list(soldata.columns).index(\"MolWt\")\n",
"np.random.seed(0)\n",
"random.seed(0)"
Expand Down
1 change: 1 addition & 0 deletions paper2_LIME/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ jupyter
seaborn
pandas
tensorflow>=2.4

0 comments on commit 2e7a69a

Please sign in to comment.