Issue with mol-BBBP Dataset Splitting #487

Open
mxqmxqmxq opened this issue Aug 18, 2024 · 0 comments

mxqmxqmxq commented Aug 18, 2024

I noticed that you provide a pre-split BBBP dataset in a ZIP file, using a scaffold split with an 8:1:1 ratio. However, when I re-split the dataset myself and ran my model, the results from my custom split were significantly better than those from your provided split. Could you confirm whether the dataset was indeed split using the standard scaffold method? Also, is the source code for the splitting process available? I would appreciate a prompt response.
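
Here is the code I used for re-splitting: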

import os
import os.path as osp
import random
from collections import defaultdict

import pandas as pd
import torch
from torch_geometric.data import InMemoryDataset, Data
from rdkit import Chem
from rdkit.Chem.Scaffolds import MurckoScaffold
from ogb.utils import smiles2graph
from sklearn.model_selection import train_test_split
from tqdm import tqdm

class CustomMoleculeDataset(InMemoryDataset):
    def __init__(self, name, seed, root='dataset', transform=None, pre_transform=None, meta_dict=None):
        self.name = name
        self.dir_name = '_'.join(name.split('-'))
        self.original_root = root
        self.root = osp.join(root, self.dir_name)
        self.meta_info = None
        self.seed = seed
        if 'ZINC' in self.name:
            print(f"The dataset name {name} contains 'ZINC'. Assigning specific meta_dict.")
            meta_dict = {
                'data type': 'mol',
                'num tasks': 1,
                'download_name': '',
                'eval metric': 'rocauc',
                'version': 1,
                'add_inverse_edge': 'True',
                'has_edge_attr': 'True',
                'binary': 'False',
                'url': '',
                'additional node files': 'None',
                'additional edge files': 'None',
                'split': 'scaffold',
                'task type': 'binary classification',
                'has_node_attr': 'True',
                'num classes': 2
            }
        
        if meta_dict is None:
            master = pd.read_csv('./dataset/master.csv', index_col=0, keep_default_na=False)
            if self.name not in master:
                error_mssg = f'Invalid dataset name {self.name}.\n'
                error_mssg += 'Available datasets are as follows:\n'
                error_mssg += '\n'.join(master.keys())
                raise ValueError(error_mssg)
            self.meta_info = master[self.name].to_dict()
        else:
            self.meta_info = meta_dict
        
        self.num_tasks = int(self.meta_info['num tasks'])
        self.eval_metric = self.meta_info['eval metric']
        self.task_type = self.meta_info['task type']
        self.__num_classes__ = int(self.meta_info['num classes'])
        self.binary = self.meta_info['binary'] == 'True'
        self.split_type = self.meta_info['split']
        
        super(CustomMoleculeDataset, self).__init__(self.root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])
    
    @property
    def raw_file_names(self):
        return [f'{self.name}.csv']

    @property
    def processed_file_names(self):
        return 'data.pt'

    def download(self):
        pass  # Implement download logic if needed

    def process(self):
        data_list = []

        df = pd.read_csv(osp.join(self.raw_dir, f'{self.name}.csv'))
        for _, row in tqdm(df.iterrows(), total=len(df), desc="Processing molecules"):
            smiles = row['smiles']
            if 'ZINC' in self.name:
                y = None  # no labels for the unsupervised task
            else:
                y = row['label']
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                continue

            graph = smiles2graph(smiles)
            x = torch.tensor(graph['node_feat'], dtype=torch.long)
            edge_index = torch.tensor(graph['edge_index'], dtype=torch.long)
            edge_attr = torch.tensor(graph['edge_feat'], dtype=torch.long)
            data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y)
            data.smiles = smiles
            data_list.append(data)

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])

    def get_idx_split(self, split_type=None):
        if split_type is None:
            split_type = self.meta_info['split']
        
        path = osp.join(self.root, 'split', split_type)

        # short-cut if split_dict.pt exists
        if os.path.isfile(os.path.join(path, 'split_dict.pt')):
            return torch.load(os.path.join(path, 'split_dict.pt'))

        if all(os.path.isfile(osp.join(path, f'{split}.csv.gz')) for split in ['train', 'valid', 'test']):
            train_idx = pd.read_csv(osp.join(path, 'train.csv.gz'), compression='gzip', header=None).values.T[0]
            valid_idx = pd.read_csv(osp.join(path, 'valid.csv.gz'), compression='gzip', header=None).values.T[0]
            test_idx = pd.read_csv(osp.join(path, 'test.csv.gz'), compression='gzip', header=None).values.T[0]
        else:
            # No pre-split files found; generate the split automatically
            print("No pre-split files found. Generating new splits...")
            train_idx, valid_idx, test_idx = self._generate_splits()

            # Save the generated split
            os.makedirs(path, exist_ok=True)
            pd.DataFrame(train_idx).to_csv(osp.join(path, 'train.csv.gz'), index=False, header=False, compression='gzip')
            pd.DataFrame(valid_idx).to_csv(osp.join(path, 'valid.csv.gz'), index=False, header=False, compression='gzip')
            pd.DataFrame(test_idx).to_csv(osp.join(path, 'test.csv.gz'), index=False, header=False, compression='gzip')

        return {'train': torch.tensor(train_idx, dtype=torch.long), 
                'valid': torch.tensor(valid_idx, dtype=torch.long), 
                'test': torch.tensor(test_idx, dtype=torch.long)}

    def _generate_splits(self, scaffold_split=True, valid_size=0.1, test_size=0.1, balanced=True):
        if scaffold_split:
            print(f"Using scaffold split (seed={self.seed})...")
            all_smiles = [data.smiles for data in self]
            scaffolds = defaultdict(list)

            for i, smiles in enumerate(all_smiles):
                mol = Chem.MolFromSmiles(smiles)
                scaffold = MurckoScaffold.MurckoScaffoldSmiles(mol=mol, includeChirality=False)
                scaffolds[scaffold].append(i)

            scaffold_sets = list(scaffolds.values())

            if balanced:
                big_scaffolds, small_scaffolds = [], []
                for scaffold_set in scaffold_sets:
                    if len(scaffold_set) > valid_size * len(all_smiles) / 2 or len(scaffold_set) > test_size * len(all_smiles) / 2:
                        big_scaffolds.append(scaffold_set)
                    else:
                        small_scaffolds.append(scaffold_set)
                random.seed(self.seed)
                random.shuffle(big_scaffolds)
                random.shuffle(small_scaffolds)
                scaffold_sets = big_scaffolds + small_scaffolds
            else:
                random.seed(self.seed)
                random.shuffle(scaffold_sets)
            train_idx, valid_idx, test_idx = [], [], []
            for scaffold_set in scaffold_sets:
                if len(train_idx) < (1 - valid_size - test_size) * len(all_smiles):
                    train_idx += scaffold_set
                elif len(valid_idx) < valid_size * len(all_smiles):
                    valid_idx += scaffold_set
                else:
                    test_idx += scaffold_set

            return train_idx, valid_idx, test_idx
        else:
            print("Using random split...")
            indices = list(range(len(self)))
            train_idx, temp_idx = train_test_split(indices, test_size=valid_size + test_size, random_state=self.seed)
            valid_idx, test_idx = train_test_split(temp_idx, test_size=test_size / (valid_size + test_size), random_state=self.seed)

            return train_idx, valid_idx, test_idx

    @property
    def num_classes(self):
        return self.__num_classes__

AUC with the provided split: 0.72; AUC with my re-split: 0.84.
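
For reference, this is roughly how I compared the two splits (a sketch, not the maintainers' code). PygGraphPropPredDataset and get_idx_split() are the public OGB API; the CustomMoleculeDataset name and seed below are hypothetical arguments for my class above, and the index spaces only line up if process() drops no molecules:

from ogb.graphproppred import PygGraphPropPredDataset

# OGB's shipped scaffold split for BBBP
ogb_dataset = PygGraphPropPredDataset(name='ogbg-molbbbp', root='dataset')
ogb_split = ogb_dataset.get_idx_split()

# My re-split (hypothetical name/seed; see the class above)
custom_dataset = CustomMoleculeDataset(name='mol-BBBP', seed=42)
custom_split = custom_dataset.get_idx_split()

for part in ['train', 'valid', 'test']:
    ogb_idx = set(ogb_split[part].tolist())
    custom_idx = set(custom_split[part].tolist())
    print(f'{part}: ogb={len(ogb_idx)}, custom={len(custom_idx)}, '
          f'overlap={len(ogb_idx & custom_idx)}')

My understanding is that the "standard" scaffold split (DeepChem/MoleculeNet style) is deterministic: scaffold sets are sorted by size, largest first, before being assigned to train/valid/test. My code instead shuffles the scaffold sets with a seed (a "balanced" scaffold split), which tends to produce easier splits; that alone might explain part of the gap.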
