This repository contains an integrated module for Graph Neural Networks (GNNs), along with example scripts. All models are implemented with PyG (https://pyg.org/). The module is designed to accommodate data representations and formats specific to your task; however, some lines in the code will require modification to ensure a smooth run. Please read the following instructions to understand the purpose of each file and make the necessary changes for successful training.
This file (`data_utils.py`) contains data utility scripts for handling your dataset. Store your data in h5 files, each containing node features, an adjacency matrix (optional), a sparse edge index, and a graph-level label (optional). A snippet for saving data in this format is provided below:
```python
import os
import h5py
import numpy as np

with h5py.File(os.path.join(output_folder, comb, '{}.h5'.format(comb)), 'w') as f:
    f.create_dataset('X', data=feature_matrix, compression='gzip', compression_opts=9)  # node features
    f.create_dataset('A', data=A, compression='gzip', compression_opts=9)  # adjacency matrix (optional)
    f.create_dataset('eI', data=edge_index, compression='gzip', compression_opts=9)  # sparse edge index
    f.create_dataset('y', data=np.array(label, dtype=int).reshape((1, 1)), compression='gzip', compression_opts=9)  # graph-level label
```
The module utilizes the `Dataset` object from the PyG library, a popular choice for working with graph data in PyTorch. It is essential to modify this file to preprocess and load your specific graph data in the desired format. Follow the comments in the script to understand its functionality and adapt it to your dataset.
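As an illustration only, here is a minimal sketch of a `Dataset` subclass reading one graph per h5 file in the layout above. The class name and constructor arguments are hypothetical, and a recent PyG version (where `download()`/`process()` are only invoked if overridden) is assumed; this is not the repository's actual implementation.

```python
import os
import h5py
import torch
from torch_geometric.data import Data, Dataset

class H5GraphDataset(Dataset):
    """Hypothetical Dataset: one graph per h5 file, keys as in the snippet above."""
    def __init__(self, root, file_names):
        self.file_names = file_names  # h5 file names located under `root`
        super().__init__(root)

    def len(self):
        return len(self.file_names)

    def get(self, idx):
        with h5py.File(os.path.join(self.root, self.file_names[idx]), 'r') as f:
            x = torch.from_numpy(f['X'][:]).float()           # node features
            edge_index = torch.from_numpy(f['eI'][:]).long()  # shape [2, num_edges]
            y = torch.from_numpy(f['y'][:]).long().view(-1)   # graph-level label
        return Data(x=x, edge_index=edge_index, y=y)
```

Instances of such a class can then be batched with `torch_geometric.loader.DataLoader`.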
In `models.py`, you will find implementations of various GNN models, including:
- Graph Convolutional Network (GCN)
- Graph Isomorphism Network (GIN)
- Graph Attention Network (GAT)
- Message Passing Neural Network (MPNN)
- Graph Equivariant Network (GEN)
**Note:** the models here are designed for graph-level output. If you wish to run a node- or edge-level task, you need to change the pooling and fully connected (fc) layers accordingly. The file also includes a Jumping Knowledge (JK) layer and other global pooling methods, which can be useful components for certain GNN architectures. Review this file to understand the structure and architecture of each model and select the one that best suits your task.
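For orientation, here is a minimal sketch of a graph-level GCN with global mean pooling. It is a simplified stand-in, not the exact architecture in `models.py`; the class name and dimensions are placeholders, and the JK layer and other pooling variants are omitted for brevity.

```python
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool

class GraphLevelGCN(torch.nn.Module):
    """Two GCN layers, then pool node embeddings to one vector per graph."""
    def __init__(self, in_dim, hidden_dim, num_classes):
        super().__init__()
        self.conv1 = GCNConv(in_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.fc = torch.nn.Linear(hidden_dim, num_classes)

    def forward(self, x, edge_index, batch):
        h = F.relu(self.conv1(x, edge_index))
        h = F.relu(self.conv2(h, edge_index))
        hg = global_mean_pool(h, batch)  # [num_graphs, hidden_dim]
        return self.fc(hg)               # for a node-level task, skip pooling and classify h directly
```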
This is the main training script (`train.py`), responsible for training the GNN model on your dataset with the chosen architecture. During training, it outputs the following metrics:
- Training loss
- Validation loss
- Validation accuracy
- Validation confusion matrix
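As an illustration of what those validation metrics involve, here is a hedged sketch of an evaluation pass; the loader, model, criterion, and device are assumed to exist, and the actual script may compute these differently.

```python
import torch
from sklearn.metrics import confusion_matrix

@torch.no_grad()
def evaluate(model, loader, criterion, device):
    """Return mean validation loss, accuracy, and confusion matrix."""
    model.eval()
    losses, y_true, y_pred = [], [], []
    for data in loader:  # a PyG DataLoader yields batched Data objects
        data = data.to(device)
        out = model(data.x, data.edge_index, data.batch)
        losses.append(criterion(out, data.y).item())
        y_true += data.y.tolist()
        y_pred += out.argmax(dim=1).tolist()
    acc = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
    return sum(losses) / len(losses), acc, confusion_matrix(y_true, y_pred)
```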
To train the model, you need to provide specific command-line arguments as follows:
```bash
python train.py --input_folder .... --labels .... --model model_name --output_folder ./results
```
Replace the placeholders:
- `....` after `--input_folder` with the path where your preprocessed graph data is stored.
- `....` after `--labels` with the path to the dataframe containing the instance file names and their labels.
- `model_name` with the name of the GNN model you want to use: "GCN", "GIN", "GAT", "MPNN", or "GEN".
- `./results` with the output folder path where the training results will be saved. Feel free to change this path to a different directory if needed.
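For example, a hypothetical invocation might look like this (the paths below are placeholders, not files shipped with this repository):

```bash
# hypothetical paths -- substitute your own
python train.py --input_folder ./data/processed --labels ./data/labels.csv --model GCN --output_folder ./results
```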
Before running the training script, ensure you have installed the required dependencies, including PyTorch, PyG, and any other libraries mentioned in the code.
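For instance, the core packages assumed by the snippets in this README can typically be installed with pip; exact versions and any additional libraries depend on the code, so check its imports.

```bash
# core dependencies used by the examples above
pip install torch torch-geometric h5py numpy scikit-learn
```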
- Review the provided files and comments carefully to understand their functionality and purpose.
- Modify `data_utils.py` to preprocess and load your graph data in the required format.
- Select the GNN model from `models.py` that best fits your task and dataset.
- Adjust the hyperparameters and settings in `params.json` to suit your dataset and training preferences (a hypothetical example is sketched after this list).
- Run the training script (`train.py`) with the necessary command-line arguments to start the training process.
- Follow good coding practices and document any significant changes you make to the code.
- If you encounter any issues or have further questions, refer to the original documentation or open an issue in the repository for support.
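For reference, a purely hypothetical `params.json` might look like the following; every key here is illustrative, and the actual keys are defined by this repository's file, so check it before editing.

```json
{
  "hidden_dim": 64,
  "num_layers": 3,
  "dropout": 0.5,
  "learning_rate": 0.001,
  "batch_size": 32,
  "epochs": 100
}
```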
Happy training! If you need further assistance or have any questions, feel free to reach out.