Ensuring notebook is up to date and correct (#9)
* DEV: rollback to old rdkit

* FIX: update link to the pca file, add cell numbering

* FIX: update pca_fname

* DEV: rollback torch

* DEV: rollback torch (proper version)

* DEV: downgrade numpy, upgrade torch in reqs

* DEV: np version not updated in setup.py

* DEV: update Docking cells

* DOCS: fix arxiv badge

* FIX: load descriptors when loading checkpoints

* FIX: do not overwrite ds descriptors

* DEV: add testing mode to dataset loading

* DEV: specify return for test mode

* DEV: allow forcing chars/max_len

* DEV: add translation, fix tokenization for AL

* DEV: move scope of regex definition

* FIX: finish updating the notebook, add more detailed instructions in readme
anmorgunov authored Jul 23, 2024
1 parent ecf19ec commit 9076472
Showing 9 changed files with 217 additions and 90 deletions.
138 changes: 82 additions & 56 deletions ChemSpaceAL.ipynb
@@ -96,9 +96,9 @@
"source": [
"# @title Specify (base) path for storing results (Cell 2)\n",
"# @markdown make sure your path ends with a \"/\"\n",
"base_path = \"/content/drive/MyDrive/ChemSpaceAL-runs/\" # @param {type:\"string\"}\n",
"\n",
"from google.colab import drive\n",
"\n",
"base_path = \"/content/drive/MyDrive/ChemSpaceAL-runs/\" # @param {type:\"string\"}\n",
"drive.mount('/content/drive', force_remount=True)"
]
},
@@ -159,7 +159,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": null,
"metadata": {
"cellView": "form",
"colab": {
@@ -168,53 +168,37 @@
"id": "8vVlLnpXxzAM",
"outputId": "e4c615a1-52b0-4873-b9e5-5cd6f457583f"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" % Total % Received % Xferd Average Speed Time Time Time Current\n",
" Dload Upload Total Spent Left Speed\n",
"100 85.3M 100 85.3M 0 0 33.1M 0 0:00:02 0:00:02 --:--:-- 33.1M\n",
" % Total % Received % Xferd Average Speed Time Time Time Current\n",
" Dload Upload Total Spent Left Speed\n",
"100 4589k 100 4589k 0 0 10.1M 0 --:--:-- --:--:-- --:--:-- 10.1M\n",
" % Total % Received % Xferd Average Speed Time Time Time Current\n",
" Dload Upload Total Spent Left Speed\n",
"100 1357 100 1357 0 0 6374 0 --:--:-- --:--:-- --:--:-- 6370\n",
" % Total % Received % Xferd Average Speed Time Time Time Current\n",
" Dload Upload Total Spent Left Speed\n",
"100 24.9M 100 24.9M 0 0 29.5M 0 --:--:-- --:--:-- --:--:-- 29.5M\n",
" % Total % Received % Xferd Average Speed Time Time Time Current\n",
" Dload Upload Total Spent Left Speed\n",
"100 208k 100 208k 0 0 777k 0 --:--:-- --:--:-- --:--:-- 779k\n"
]
}
],
"outputs": [],
"source": [
"#@title Download (if you want) dataset/weights (Cell 4)\n",
"#@markdown note these files will be placed into appropriate folders created above\n",
"downloadDataset = True # @param {type:\"boolean\"}\n",
"downloadModelWeights = True # @param {type:\"boolean\"}\n",
"downloadPCAweights = True # @param {type:\"boolean\"}\n",
"script = '''#!/bin/bash\n",
"'''\n",
"# @title Download (if you want) dataset/weights (Cell 4)\n",
"# @markdown note these files will be placed into appropriate folders created above\n",
"downloadDataset = True # @param {type:\"boolean\"}\n",
"downloadModelWeights = True # @param {type:\"boolean\"}\n",
"downloadPCAweights = True # @param {type:\"boolean\"}\n",
"downloadTargets = True # @param {type:\"boolean\"}\n",
"script = \"\"\"#!/bin/bash\n",
"\"\"\"\n",
"remote_source = \"https://files.ischemist.com/ChemSpaceAL/publication_runs/\"\n",
"if downloadDataset:\n",
" f1 = \"1_Pretraining/datasets/combined_train.csv.gz\"\n",
" f2 = \"1_Pretraining/datasets/combined_valid.csv.gz\"\n",
" script += f\"curl -o {base_path}{f1} {remote_source}{f1}\\n\"\n",
" script += f\"curl -o {base_path}{f2} {remote_source}{f2}\\n\"\n",
" f1 = \"1_Pretraining/datasets/combined_train.csv.gz\"\n",
" f2 = \"1_Pretraining/datasets/combined_valid.csv.gz\"\n",
" script += f\"curl -o {base_path}{f1} {remote_source}{f1}\\n\"\n",
" script += f\"curl -o {base_path}{f2} {remote_source}{f2}\\n\"\n",
"if downloadModelWeights:\n",
" f1 = \"1_Pretraining/datasets_descriptors/combined_train.yaml\"\n",
" f2 = \"1_Pretraining/model_weights/model7_al0_ch1.pt\"\n",
" script += f\"curl -o {base_path}{f1} {remote_source}{f1}\\n\"\n",
" script += f\"curl -o {base_path}{f2} {remote_source}{f2}\\n\"\n",
" f1 = \"1_Pretraining/datasets_descriptors/combined_train.yaml\"\n",
" f2 = \"1_Pretraining/model_weights/model7_al0_ch1.pt\"\n",
" script += f\"curl -o {base_path}{f1} {remote_source}{f1}\\n\"\n",
" script += f\"curl -o {base_path}{f2} {remote_source}{f2}\\n\"\n",
"if downloadPCAweights:\n",
" f1 = \"3_Sampling/pca_weights/scaler_pca_combined_n120_v2.pkl\"\n",
" script += f\"curl -o {base_path}{f1} {remote_source}{f1}\\n\"\n",
" f1 = \"3_Sampling/pca_weights/scaler_pca_combined_n120_v2.pkl\"\n",
" script += f\"curl -o {base_path}{f1} {remote_source}{f1}\\n\"\n",
"if downloadTargets:\n",
" f1 = \"4_Scoring/binding_targets/HNH_processed.pdb\"\n",
" f2 = \"4_Scoring/binding_targets/1iep_processed.pdb\"\n",
" script += f\"curl -o {base_path}{f1} {remote_source}{f1}\\n\"\n",
" script += f\"curl -o {base_path}{f2} {remote_source}{f2}\\n\"\n",
"with open(\"fetch.bash\", \"w\") as f:\n",
" f.write(script)\n",
" f.write(script)\n",
"!bash fetch.bash"
]
},
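Note: as a minimal sketch (illustration only, not part of the commit), with the default base_path above, each enabled download block in Cell 4 appends curl lines of the following shape to fetch.bash:

    base_path = "/content/drive/MyDrive/ChemSpaceAL-runs/"
    remote_source = "https://files.ischemist.com/ChemSpaceAL/publication_runs/"
    f1 = "1_Pretraining/datasets/combined_train.csv.gz"
    print(f"curl -o {base_path}{f1} {remote_source}{f1}")
    # prints: curl -o /content/drive/MyDrive/ChemSpaceAL-runs/1_Pretraining/datasets/combined_train.csv.gz https://files.ischemist.com/ChemSpaceAL/publication_runs/1_Pretraining/datasets/combined_train.csv.gz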
@@ -277,7 +261,7 @@
" base_path=base_path,\n",
" cycle_prefix=\"model0\",\n",
" cycle_suffix=\"ch1\",\n",
" al_iteration=1, # use 0 for pretraining\n",
" al_iteration=0, # use 0 for pretraining\n",
" training_fname=\"combined_train.csv.gz\",\n",
" validation_fname=\"combined_valid.csv.gz\",\n",
" slice_data=None,\n",
@@ -527,6 +511,7 @@
"source": [
"# Cell 13\n",
"mols = Sampling.project_into_pca_space(config)\n",
"# n_iter below specifies how many times K-Means is run with different starting points\n",
"Sampling.cluster_and_sample(mols=mols, config=config, n_iter=1)"
]
},
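Note: the restart behavior behind n_iter can be pictured with this standalone sketch (scikit-learn is used here purely for illustration; ChemSpaceAL's internal clustering may differ):

    import numpy as np
    from sklearn.cluster import KMeans

    def best_of_n_kmeans(X: np.ndarray, n_clusters: int, n_iter: int) -> KMeans:
        # Run K-Means n_iter times from different random starts and keep the
        # fit with the lowest inertia (sum of squared distances to centroids).
        best = None
        for seed in range(n_iter):
            km = KMeans(n_clusters=n_clusters, n_init=1, random_state=seed).fit(X)
            if best is None or km.inertia_ < best.inertia_:
                best = km
        return best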
@@ -551,15 +536,18 @@
"%%capture\n",
"#@title Install Docking Software (DiffDock) (Cell 14)\n",
"#@markdown diffdock is pretty heavy and has a lot of dependencies, so we only install it when we need it (and we don't during pretraining, for example)\n",
"\n",
"import torch\n",
"\n",
"print(torch.__version__)\n",
"!pip uninstall torch-scatter torch-sparse torch-geometric torch-cluster --y\n",
"!pip install torch-scatter -f https://data.pyg.org/whl/torch-{torch.__version__}.html\n",
"!pip install torch-sparse -f https://data.pyg.org/whl/torch-{torch.__version__}.html\n",
"!pip install torch-cluster -f https://data.pyg.org/whl/torch-{torch.__version__}.html\n",
"!pip install git+https://github.com/pyg-team/pytorch_geometric.git --quiet\n",
"\n",
"try:\n",
" import torch_geometric\n",
"except ModuleNotFoundError:\n",
" !pip uninstall torch-scatter torch-sparse torch-geometric torch-cluster --y\n",
" !pip install torch-scatter -f https://data.pyg.org/whl/torch-{torch.__version__}.html --quiet\n",
" !pip install torch-sparse -f https://data.pyg.org/whl/torch-{torch.__version__}.html --quiet\n",
" !pip install torch-cluster -f https://data.pyg.org/whl/torch-{torch.__version__}.html --quiet\n",
" !pip install git+https://github.com/pyg-team/pytorch_geometric.git --quiet \n",
"\n",
"try:\n",
" import biopandas\n",
Expand All @@ -569,7 +557,6 @@
" !pip install scipy==1.7.3 --quiet\n",
" !pip install networkx==2.6.3 --quiet\n",
" !pip install biopython==1.79 --quiet\n",
" !pip install rdkit-pypi==2022.03.5 --quiet\n",
" !pip install e3nn==0.5.0 --quiet\n",
" !pip install spyrmsd==0.5.2 --quiet\n",
" !pip install pandas==1.5.3 --quiet\n",
@@ -596,7 +583,45 @@
"import os\n",
"import pandas as pd\n",
"from tqdm import tqdm\n",
"from ChemSpaceAL.Docking import get_top_poses"
"\n",
"def get_top_poses(ligands_csv: str, protein_pdb_path: str, save_pose_path: str):\n",
" data = pd.read_csv(ligands_csv)\n",
" os.makedirs(save_pose_path, exist_ok=True)\n",
"\n",
" os.environ['HOME'] = 'esm/model_weights'\n",
" os.environ['PYTHONPATH'] = f'{os.environ.get(\"PYTHONPATH\", \"\")}:/content/DiffDock/esm'\n",
" pbar = tqdm(range(len(data)), total=len(data))\n",
" for i in pbar: # change 1 to len(data) for processing all ligands\n",
" # print(str((i / len(data)) * 100)[:5], ' %')\n",
" smiles = data['smiles'][i]\n",
" rdkit_mol = Chem.MolFromSmiles(smiles)\n",
"\n",
" if rdkit_mol is not None:\n",
" with open('/content/input_protein_ligand.csv', 'w') as out:\n",
" out.write('protein_path,ligand\\n')\n",
" out.write(f'{protein_pdb_path},{smiles}\\n')\n",
"\n",
" # Clear out old results if running multiple times\n",
" shutil.rmtree('/content/DiffDock/results', ignore_errors=True)\n",
"\n",
" # ESM Embedding Preparation\n",
" os.chdir('/content/DiffDock')\n",
" !python /content/DiffDock/datasets/esm_embedding_preparation.py --protein_ligand_csv /content/input_protein_ligand.csv --out_file /content/DiffDock/data/prepared_for_esm.fasta\n",
"\n",
" # ESM Extraction\n",
" !python /content/DiffDock/esm/scripts/extract.py esm2_t33_650M_UR50D /content/DiffDock/data/prepared_for_esm.fasta /content/DiffDock/data/esm2_output --repr_layers 33 --include per_tok --truncation_seq_length 30000\n",
"\n",
" # Inference\n",
" !python /content/DiffDock/inference.py --protein_ligand_csv /content/input_protein_ligand.csv --out_dir /content/DiffDock/results/user_predictions_small --inference_steps 20 --samples_per_complex 10 --batch_size 6\n",
"\n",
" # Move results\n",
" for root, dirs, files in os.walk('/content/DiffDock/results/user_predictions_small'):\n",
" for file in files:\n",
" if file.startswith('rank1_confidence'):\n",
" shutil.move(\n",
" os.path.join(root, file),\n",
" os.path.join(save_pose_path, f\"complex{i}.sdf\"),\n",
" )"
]
},
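Note: a typical call to the function above would look like this sketch — the ligands CSV name and the pose output folder are hypothetical, while the target PDB is one of the files fetched in Cell 4:

    get_top_poses(
        ligands_csv=base_path + "3_Sampling/sampled_mols.csv",  # hypothetical; must contain a 'smiles' column
        protein_pdb_path=base_path + "4_Scoring/binding_targets/1iep_processed.pdb",
        save_pose_path=base_path + "4_Scoring/poses/",  # hypothetical output folder
    )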
{
@@ -757,7 +782,7 @@
"source": [
"# Cell 19\n",
"config.set_active_learning_parameters(\n",
" selection_mode=\"threshold\", probability_mode=\"linear\", threshold=11, training_size=10\n",
" selection_mode=\"threshold\", probability_mode=\"linear\", threshold=11, training_size=10_000\n",
")"
]
},
Expand All @@ -770,7 +795,8 @@
"outputs": [],
"source": [
"# Cell 20\n",
"ALConstruction.construct_al_training_set(config=config, do_sampling=True)"
"ALConstruction.construct_al_training_set(config=config, do_sampling=True)\n",
"ALConstruction.translate_dataset_for_al(config)"
]
},
{
@@ -825,7 +851,7 @@
],
"source": [
"# Cell 22\n",
"config.set_training_parameters(mode=\"Active Learning\", epochs=1)"
"config.set_training_parameters(mode=\"Active Learning\", epochs=10)"
]
},
{
40 changes: 40 additions & 0 deletions ChemSpaceAL/ALConstruction.py
@@ -1,9 +1,12 @@
import pandas as pd
import numpy as np
import pickle
from tqdm import tqdm
from rdkit import Chem, RDLogger

from ChemSpaceAL.Configuration import Config
from typing import Union, Dict, List, Callable, Optional, cast
RDLogger.DisableLog("rdApp.*")

Number = Union[float, int]

@@ -226,3 +229,40 @@ def construct_al_training_set(config: Config, do_sampling: bool = True) -> pd.Da
combined = pd.DataFrame(keyToData)
combined.to_csv(config.cycle_temp_params["path_to_al_training_set"])
return combined

def get_mol(smile_string:str):
mol = Chem.MolFromSmiles(smile_string)
if mol is None:
return None
try:
Chem.SanitizeMol(mol)
except ValueError:
return None
return mol

def fill_translation_table(config):
smile_df = pd.read_csv(config.cycle_temp_params["completions_fname"])
rdkit_to_predicted = {}
pbar = tqdm(smile_df['smiles'], total=len(smile_df))
for completion in pbar:
if completion[0] == '!' and completion[1] == '~':
completion = '!' + completion[2:]
if '~' not in completion: continue
mol_string = completion[1:completion.index('~')]
mol = get_mol(mol_string)
if mol is None: continue
canonic_smile = Chem.MolToSmiles(mol)
rdkit_to_predicted[canonic_smile] = mol_string
return rdkit_to_predicted

def translate_dataset_for_al(config):
rdkit_to_predicted = fill_translation_table(config)
al_path = config.cycle_temp_params["path_to_al_training_set"]
sampled = pd.read_csv(al_path)
translated = []
for mol in sampled['smiles'].values:
mol = rdkit_to_predicted[mol]
translated.append(mol)
transl_path = al_path.split('.csv')[0]+"_translated.csv"
pd.DataFrame({"smiles": translated}).to_csv(transl_path)
return transl_path
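Note: to illustrate the round trip these helpers implement (example values are hypothetical): a completion such as "!CC(=O)O~<<<" is sliced between the start token '!' and the end token '~', canonicalized with RDKit, and the canonical form is mapped back to the exact string the model generated:

    table = fill_translation_table(config)            # {canonical_smiles: generated_string}
    completion = "!CC(=O)O~<<<"
    mol_string = completion[1:completion.index("~")]  # "CC(=O)O"
    canonical = Chem.MolToSmiles(get_mol(mol_string)) # canonical RDKit spelling
    generated = table.get(canonical)                  # model's original spelling, if seen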
4 changes: 4 additions & 0 deletions ChemSpaceAL/Configuration.py
@@ -670,6 +670,10 @@ def set_active_learning_parameters(
self.al_train_path
+ f"{self.cycle_prefix}_al{self.al_iteration}_{self.cycle_suffix}.csv"
)
self.cycle_temp_params["path_to_al_translated_set"] = (
self.al_train_path
+ f"{self.cycle_prefix}_al{self.al_iteration}_{self.cycle_suffix}_translated.csv"
)
if self.verbose:
message = (
"--- The following AL training set construction parameters were set:"
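Note: with the values used in the notebook (cycle_prefix="model0", al_iteration=0, cycle_suffix="ch1"), the new config key resolves to a file name as this sketch shows:

    cycle_prefix, al_iteration, cycle_suffix = "model0", 0, "ch1"
    print(f"{cycle_prefix}_al{al_iteration}_{cycle_suffix}_translated.csv")
    # prints: model0_al0_ch1_translated.csv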
45 changes: 18 additions & 27 deletions ChemSpaceAL/Dataset.py
@@ -116,8 +116,6 @@ def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
def load_data(
config: Configuration.Config,
mode: str,
forced_block_size: Optional[int] = None,
forced_vocab: Optional[List[str]] = None,
):
"""
Load data based on the provided configuration dictionary.
@@ -154,45 +152,39 @@ def load_data(
# Handle data loading for 'Active Learning' mode
elif mode == "Active Learning":
assert (
al_path := config.cycle_temp_params["path_to_al_training_set"]
al_path := config.cycle_temp_params["path_to_al_translated_set"]
) is not None, (
f"The name of the AL training set (al_train_fname) was not initialized"
)
cur_iter = f"al{config.al_iteration}"
prev_iter = f"al{config.al_iteration - 1}"
al_path = al_path.replace(cur_iter, prev_iter)
print(f"Will load AL training set from", config.rel_path(al_path))
print("Will load AL training set from", config.rel_path(al_path))
al_data = pd.read_csv(al_path)
smiles_iterators = [al_data[config.smiles_key].values]
# desc_path = config.al_desc_path + al_fname.split(".")[0] + ".yaml"
desc_path = (
config.pretrain_desc_path + config.training_fname.split(".")[0] + ".yaml"
)
else:
raise KeyError(
f"Only 'pretraining' and 'active learning' modes are currently supported"
)

regex = re.compile(config.regex_pattern)
char_set = {"!", "~", "<"} # start, end, padding tokens respectively

max_len = 0
for smiles in smiles_iterators:
for smile in smiles:
chars = regex.findall(smile.strip())
max_len = max(max_len, len(chars))
char_set.update(chars)

chars = sorted(list(char_set))
max_len += 1
if mode == "Pretraining":
char_set = {"!", "~", "<"} # start, end, padding tokens respectively

if forced_block_size:
assert mode == "Active Learning", "Cannot force a block size in pretraining"
max_len = forced_block_size
max_len = 0
for smiles in smiles_iterators:
for smile in smiles:
chars = regex.findall(smile.strip())
max_len = max(max_len, len(chars))
char_set.update(chars)

if forced_vocab:
assert mode == "Active Learning", "Cannot force a vocabulary in pretraining"
chars = sorted(list(forced_vocab))
chars = sorted(list(char_set))
max_len += 1
elif mode == "Active Learning":
with open(config.model_config.generation_params["desc_path"], 'r') as f:
data = yaml.load(f, Loader=yaml.FullLoader)
max_len = data["block_size"]
chars = sorted(list(data["stoi"].keys()))

datasets = []
for smiles in smiles_iterators:
@@ -210,9 +202,8 @@
)
datasets.append(dataset)

datasets[0].export_descriptors(desc_path)

if mode == "Active Learning":
return datasets[0]
else:
datasets[0].export_descriptors(desc_path)
return datasets
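Note: in Active Learning mode the vocabulary and block size now come from the pretraining descriptors YAML instead of being recomputed; a sketch of the keys that branch reads (field values are illustrative, not the real descriptor file):

    import yaml

    sample = yaml.safe_load("""
    block_size: 133
    stoi:
      "!": 0
      "#": 1
      "~": 2
    """)
    max_len = sample["block_size"]         # forced to match pretraining
    chars = sorted(sample["stoi"].keys())  # pretraining vocabulary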
4 changes: 2 additions & 2 deletions ChemSpaceAL/Generation.py
@@ -5,15 +5,15 @@
import pandas as pd
import numpy as np
import rdkit
from rdkit import Chem
from rdkit import Chem, RDLogger
from rdkit.Chem import Fragments

from ChemSpaceAL.Model import GPT
from ChemSpaceAL.Configuration import Config, AdmetDict
from ChemSpaceAL.Dataset import SMILESDataset

from typing import Set, List, Callable, Union, Optional, Any, Dict

RDLogger.DisableLog("rdApp.*")

@torch.no_grad()
def sample(
(Diffs for the remaining changed files were not loaded.)