Feature/data augmentation #52 (Draft)

vidvath7 wants to merge 36 commits into dev from feature/data-augmentation

Commits (36)
7cb9cb3
Testing rdkit functions
vidvath7 Aug 15, 2024
1f0daa4
Added notebook For Testing
vidvath7 Aug 15, 2024
a9f5765
Added changes for new smiles variant generation
vidvath7 Aug 15, 2024
d84efad
changes-tqdm,randomized rootedAtAtom
vidvath7 Aug 27, 2024
cba0562
Changed directory and added yml config
vidvath7 Sep 5, 2024
165606d
Code optimization & Batch processing changes
vidvath7 Sep 6, 2024
d422948
Reverted changes to original
vidvath7 Sep 13, 2024
db4d16e
Changes in rootedAtAtom,added smiles variation config
vidvath7 Sep 26, 2024
5895d23
Changed the directory for augmented data files
vidvath7 Sep 28, 2024
f303dc7
Changed the file names for augmented data
vidvath7 Oct 3, 2024
32966e1
Added new config file for data augmentation
vidvath7 Oct 18, 2024
72bb33f
Created new class for data augmentation
vidvath7 Oct 18, 2024
5916420
Add new folder
vidvath7 Oct 24, 2024
5c48e4d
Added directory for augmented directory for splitting
vidvath7 Nov 3, 2024
99d3696
Removed changes made for testing with subset
vidvath7 Nov 3, 2024
4cf83d1
removed import
vidvath7 Nov 4, 2024
43360bc
changes for lightning error
vidvath7 Nov 12, 2024
cef2ea2
Changes in yml to ChEBI100
vidvath7 Nov 15, 2024
dff800b
Removed whitespaces
vidvath7 Nov 16, 2024
131ea90
Changed augmented path to ChEBI100
vidvath7 Nov 16, 2024
7056187
Changes for creating splits.csv,added lines for debugging
vidvath7 Nov 28, 2024
419d603
Added new file for Evaluation
vidvath7 Dec 5, 2024
c04a595
Corrected directory
vidvath7 Dec 5, 2024
a52c635
Changes for evaluation-set splits_file_path
vidvath7 Dec 5, 2024
6ba3921
Corrected splits_file_path
vidvath7 Dec 5, 2024
6e9e098
Update eval.py
vidvath7 Dec 6, 2024
b38cce7
printing batch size
vidvath7 Dec 6, 2024
bd4dced
Merge branch 'feature/data-augmentation' of https://github.com/ChEB-A…
vidvath7 Dec 6, 2024
58ef43e
changing batch size
vidvath7 Dec 6, 2024
ae43441
Changes for Augmentation after splitting
vidvath7 Dec 7, 2024
ecaeef7
Changes for evaluation
vidvath7 Dec 11, 2024
f4a34d2
Changes for evaluation- splits file changed
vidvath7 Dec 12, 2024
8850f3a
Generating SMILES based on no. of variations
vidvath7 Dec 12, 2024
83d3db8
Removed print statement
vidvath7 Dec 12, 2024
c38157e
Changes for testing- changed checkpoint
vidvath7 Dec 18, 2024
5c4f7b3
Checkpoint name correction
vidvath7 Dec 18, 2024
2 changes: 1 addition & 1 deletion chebai/models/base.py
@@ -46,7 +46,7 @@ def __init__(
         super().__init__()
         self.criterion = criterion
         self.save_hyperparameters(
-            ignore=["criterion", "train_metrics", "val_metrics", "test_metrics"]
+            ignore=["criterion", "train_metrics", "val_metrics", "test_metrics", "_class_path"]
         )
         self.out_dim = out_dim
         if optimizer_kwargs:
6 changes: 3 additions & 3 deletions chebai/molecule.py
@@ -66,7 +66,7 @@ class Molecule:
     max_number_of_parents = 7

     def __init__(
-        self, smile: str, logp: Optional[float] = None, contract_rings: bool = False
+        self, smile: str, logp: Optional[float] = None, contract_rings: bool = False
     ):
         """
         Initializes a Molecule object.
@@ -400,8 +400,8 @@ def num_of_features() -> int:
             int: Total number of features.
         """
         return (
-            Molecule.max_number_of_parents * Molecule.num_bond_features()
-            + Molecule.num_atom_features()
+            Molecule.max_number_of_parents * Molecule.num_bond_features()
+            + Molecule.num_atom_features()
         )

     @staticmethod
400 changes: 349 additions & 51 deletions chebai/preprocessing/datasets/chebi.py

Large diffs are not rendered by default.
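The bulk of the augmentation logic sits in this unrendered 400-line diff. Judging from the commit messages ("Added changes for new smiles variant generation", "randomized rootedAtAtom", "Generating SMILES based on no. of variations") and the num_smiles_variations option in the configs below, the variant generation plausibly follows the RDKit pattern sketched here. The function name, signature, and de-duplication step are illustrative assumptions, not the PR's actual code.

# Hypothetical sketch of randomized SMILES-variant generation, inferred from the
# commit messages; not taken from the unrendered chebi.py diff.
import random
from typing import List

from rdkit import Chem


def generate_smiles_variants(smiles: str, num_variations: int = 5) -> List[str]:
    """Return up to num_variations alternative SMILES strings for one molecule."""
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:  # unparsable input: nothing to augment
        return []
    variants = set()
    for _ in range(num_variations):
        # Re-rooting the SMILES traversal at a random atom with canonicalisation
        # disabled yields a different but chemically equivalent string.
        root = random.randrange(mol.GetNumAtoms())
        variants.add(Chem.MolToSmiles(mol, rootedAtAtom=root, canonical=False))
    variants.discard(smiles)  # drop variants identical to the original
    return list(variants)


print(generate_smiles_variants("CC(=O)Oc1ccccc1C(=O)O"))  # e.g. aspirin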

6 changes: 6 additions & 0 deletions chebai/result/utils.py
@@ -94,6 +94,9 @@ def evaluate_model(
     Returns:
         Tensors with predictions and labels.
     """
+    print("Start of evaluate_model")
+    batch_size = 5
+    print("batch_size: ", batch_size)
     model.eval()
     collate = data_module.reader.COLLATOR()

@@ -157,6 +160,7 @@ def evaluate_model(
             torch.cat(labels_list),
             os.path.join(buffer_dir, f"labels{save_ind:03d}.pt"),
         )
+    print("End of evaluate_model")


 def load_results_from_buffer(
@@ -172,6 +176,7 @@
     Returns:
         Tensors with predictions and labels.
     """
+    print("Start of load_results_from_buffer")
     preds_list = []
     labels_list = []

@@ -208,6 +213,7 @@
     else:
         test_labels = None

+    print("End of load_results_from_buffer")
     return test_preds, test_labels

4 changes: 4 additions & 0 deletions configs/data/chebi50.yml
@@ -1 +1,5 @@
 class_path: chebai.preprocessing.datasets.chebi.ChEBIOver50
+init_args:
+  aug_data: True
+  augment_data_batch_size: 5000
+  num_smiles_variations: 5
5 changes: 5 additions & 0 deletions configs/data/chebi_augmentation.yml
@@ -0,0 +1,5 @@
+class_path: chebai.preprocessing.datasets.chebi.ChEBIOver100
+init_args:
+  aug_data: True
+  augment_data_batch_size: 5000
+  num_smiles_variations: 5
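Since these are LightningCLI-style configs, class_path plus init_args translates directly into a constructor call. A rough Python equivalent of chebi_augmentation.yml is sketched below; the keyword names come from the config itself, but what the data module does with them internally is an assumption.

# Rough constructor-call equivalent of configs/data/chebi_augmentation.yml;
# the argument names are taken from the config, their exact semantics inside
# the PR's data module (batched augmentation, variants per SMILES) are assumed.
from chebai.preprocessing.datasets.chebi import ChEBIOver100

data_module = ChEBIOver100(
    aug_data=True,                 # turn SMILES augmentation on
    augment_data_batch_size=5000,  # presumably molecules per augmentation batch
    num_smiles_variations=5,       # presumably variants generated per input SMILES
)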
74 changes: 74 additions & 0 deletions eval.py
import pandas as pd

from chebai.result.utils import (
    evaluate_model,
    load_results_from_buffer,
)
from chebai.result.classification import print_metrics
from chebai.models.electra import Electra
from chebai.preprocessing.datasets.chebi import ChEBIOver50, ChEBIOver100
import os
import tqdm
import torch
import pickle

DEVICE = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
print(DEVICE)


# Specify paths and parameters
checkpoint_name = "best_epoch=31_val_loss=0.0204_val_macro-f1=0.7655_val_micro-f1=0.9246"
print("checkpoint_name : ", checkpoint_name)
checkpoint_path = os.path.join(
    "logs/wandb/run-20241212_003611-8yohluv6/files/checkpoints", f"{checkpoint_name}.ckpt"
)
print("checkpoint_path : ", checkpoint_path)
kind = "test"  # Change to "train" or "validation" as needed
buffer_dir = os.path.join("results_buffer", checkpoint_name, kind)
print("buffer_dir : ", buffer_dir)
batch_size = 10  # Set batch size

# Load the data module and point it at the augmented splits
data_module = ChEBIOver100(chebi_version=231)
data_module.splits_file_path = "data/chebi_v231/ChEBI100/processed/augmented_splits.csv"
model_class = Electra

# Evaluate the model; results are stored in buffer_dir
model = model_class.load_from_checkpoint(checkpoint_path)
if buffer_dir is None:
    preds, labels = evaluate_model(
        model,
        data_module,
        buffer_dir=buffer_dir,
        # No need to provide this parameter for ChEBI datasets; the "kind" parameter should be provided instead
        # filename=data_module.processed_file_names_dict[kind],
        batch_size=batch_size,
        kind=kind,
    )
else:
    evaluate_model(
        model,
        data_module,
        buffer_dir=buffer_dir,
        # No need to provide this parameter for ChEBI datasets; the "kind" parameter should be provided instead
        # filename=data_module.processed_file_names_dict[kind],
        batch_size=batch_size,
        kind=kind,
    )
    # Load the buffered predictions and labels back from disk
    preds, labels = load_results_from_buffer(buffer_dir, device=DEVICE)


# Load class labels from classes.txt
with open(os.path.join(data_module.processed_dir_main, "classes.txt"), "r") as f:
    classes = [line.strip() for line in f.readlines()]


# Output the relevant metrics
print_metrics(
    preds,
    labels.to(torch.int),
    DEVICE,
    classes=classes,
    markdown_output=False,
    top_k=10,
)
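For readers unfamiliar with the buffer format: the tensors returned by load_results_from_buffer are concatenations of the per-batch files written by evaluate_model. A minimal sanity check before calling print_metrics might look like the following sketch; the (num_samples, num_classes) shape is an assumption based on how the script uses classes.txt, not something stated in the PR.

# Hypothetical sanity check; assumes preds and labels are (num_samples, num_classes)
# tensors produced by load_results_from_buffer above.
assert preds is not None and labels is not None, "no results found in buffer_dir"
assert preds.shape == labels.shape, (preds.shape, labels.shape)
print(f"Evaluated {preds.shape[0]} samples across {preds.shape[1]} classes")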
Empty file added logs/.gitkeep
Empty file.