From 8ed24f7442e6ef374996ce2f0123b4080af3258f Mon Sep 17 00:00:00 2001 From: Alexander Jipa Date: Tue, 19 Mar 2024 11:08:35 -0400 Subject: [PATCH] feat: add support for PyTorch Lightning --- graphstorm-lightning/.gitignore | 89 ++++ graphstorm-lightning/README.md | 3 + graphstorm-lightning/examples/README.md | 7 + graphstorm-lightning/examples/node_gnn.ipynb | 379 ++++++++++++++++++ .../graphstorm_lightning/__init__.py | 3 + .../datamodule/__init__.py | 1 + .../datamodule/node_gnn.py | 139 +++++++ .../graphstorm_lightning/module/__init__.py | 1 + .../graphstorm_lightning/module/node_gnn.py | 132 ++++++ .../graphstorm_lightning/utils.py | 127 ++++++ graphstorm-lightning/pyproject.toml | 59 +++ 11 files changed, 940 insertions(+) create mode 100644 graphstorm-lightning/.gitignore create mode 100644 graphstorm-lightning/README.md create mode 100644 graphstorm-lightning/examples/README.md create mode 100644 graphstorm-lightning/examples/node_gnn.ipynb create mode 100644 graphstorm-lightning/graphstorm_lightning/__init__.py create mode 100644 graphstorm-lightning/graphstorm_lightning/datamodule/__init__.py create mode 100644 graphstorm-lightning/graphstorm_lightning/datamodule/node_gnn.py create mode 100644 graphstorm-lightning/graphstorm_lightning/module/__init__.py create mode 100644 graphstorm-lightning/graphstorm_lightning/module/node_gnn.py create mode 100644 graphstorm-lightning/graphstorm_lightning/utils.py create mode 100644 graphstorm-lightning/pyproject.toml diff --git a/graphstorm-lightning/.gitignore b/graphstorm-lightning/.gitignore new file mode 100644 index 0000000000..43bd28a707 --- /dev/null +++ b/graphstorm-lightning/.gitignore @@ -0,0 +1,89 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Sphinx documentation +docs/_build/ + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +.python-version + + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +poetry.lock + + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# VSCode +**/.vscode/ + +# pre-commit +.pre-commit-config.yaml diff --git a/graphstorm-lightning/README.md b/graphstorm-lightning/README.md new file mode 100644 index 0000000000..4042dc2a7e --- /dev/null +++ b/graphstorm-lightning/README.md @@ -0,0 +1,3 @@ +# GraphStorm Support for PyTorch Lightning + +Provides essential wrappers for GraphStorm constructs to run training using [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/). diff --git a/graphstorm-lightning/examples/README.md b/graphstorm-lightning/examples/README.md new file mode 100644 index 0000000000..7f1d1e1094 --- /dev/null +++ b/graphstorm-lightning/examples/README.md @@ -0,0 +1,7 @@ +# Setup + +1. Run `poetry install --with dev` +1. Run `poetry run python -m ipykernel install --user --name=graphstorm_lightning` +1. Run `poetry run jupyter notebook` +1. Open an example notebook +1. In `Kernel -> Change Kernel` choose `graphstorm_lightning` diff --git a/graphstorm-lightning/examples/node_gnn.ipynb b/graphstorm-lightning/examples/node_gnn.ipynb new file mode 100644 index 0000000000..7d596aaac9 --- /dev/null +++ b/graphstorm-lightning/examples/node_gnn.ipynb @@ -0,0 +1,379 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "9f7fa7f2-8b5c-4690-8f5b-800dee3f8e86", + "metadata": {}, + "source": [ + "# Graphstorm PyTorch Lightning Demonstration - Node Classification\n", + "\n", + "In this notebook, we'll demonstrate how to use Graphstorm with [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable) for Node Classification.\n", + "\n", + "---\n", + "\n", + "## Setup \n", + "\n", + "Please follow the README.md in graphstorm-lightning/examples." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "af273d51-441d-4bd9-933e-146d7a6c9eff", + "metadata": {}, + "outputs": [], + "source": [ + "import yaml\n", + "import graphstorm as gs\n", + "import graphstorm_lightning as gsl\n", + "import pytorch_lightning as pl\n", + "import requests\n", + "from pathlib import Path" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "a631c359-b1f9-4fed-88fb-1117b41c421f", + "metadata": {}, + "outputs": [], + "source": [ + "num_nodes = 1" + ] + }, + { + "cell_type": "markdown", + "id": "6f8402f3-8979-4e11-986b-1020d3eba061", + "metadata": {}, + "source": [ + "---\n", + "\n", + "## Data Preparation\n", + "\n", + "In this notebook we'll create ACM graph dataset following [this guide](https://graphstorm.readthedocs.io/en/latest/tutorials/own-data.html)." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "fce80697-7bab-4241-b2e0-aed483958de9", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Namespace(download_path='/tmp/ACM.mat', dataset_name='acm', output_type='raw', output_path='/tmp/acm_raw')\n", + "Graph(num_nodes={'author': 17431, 'paper': 12499, 'subject': 73},\n", + " num_edges={('author', 'writing', 'paper'): 37055, ('paper', 'cited', 'paper'): 30789, ('paper', 'citing', 'paper'): 30789, ('paper', 'is-about', 'subject'): 12499, ('paper', 'written-by', 'author'): 37055, ('subject', 'has', 'paper'): 12499},\n", + " metagraph=[('author', 'paper', 'writing'), ('paper', 'paper', 'cited'), ('paper', 'paper', 'citing'), ('paper', 'subject', 'is-about'), ('paper', 'author', 'written-by'), ('subject', 'paper', 'has')])\n", + "\n", + " Number of classes: 14\n", + "\n", + " Paper node labels: torch.Size([12499])\n", + "\n", + " ('paper', 'citing', 'paper') edge labels:30789\n", + "Saving ACM data to /tmp/acm.dgl ......\n", + "/tmp/acm.dgl saved.\n", + "Saving ACM node text to /tmp/acm_text.pkl ......\n", + "/tmp/acm_text.pkl saved.\n", + "author nodes have: Index(['node_id', 'feat'], dtype='object') columns ......\n", + "paper nodes have: Index(['node_id', 'label', 'feat'], dtype='object') columns ......\n", + "subject nodes have: Index(['node_id', 'feat'], dtype='object') columns ......\n", + "Saved author node data to /tmp/acm_raw/nodes/author.parquet.\n", + "Saved paper node data to /tmp/acm_raw/nodes/paper.parquet.\n", + "Saved subject node data to /tmp/acm_raw/nodes/subject.parquet.\n", + "Saved ('author', 'writing', 'paper') edge data to /tmp/acm_raw/edges/author_writing_paper.parquet\n", + "Saved ('paper', 'cited', 'paper') edge data to /tmp/acm_raw/edges/paper_cited_paper.parquet\n", + "Saved ('paper', 'citing', 'paper') edge data to /tmp/acm_raw/edges/paper_citing_paper.parquet\n", + "Saved ('paper', 'is-about', 'subject') edge data to /tmp/acm_raw/edges/paper_is-about_subject.parquet\n", + "Saved ('paper', 'written-by', 'author') edge data to /tmp/acm_raw/edges/paper_written-by_author.parquet\n", + "Saved ('subject', 'has', 'paper') edge data to /tmp/acm_raw/edges/subject_has_paper.parquet\n" + ] + } + ], + "source": [ + "acm_raw = Path(\"/tmp/acm_raw\")\n", + "if not acm_raw.exists():\n", + " acm_raw.mkdir(parents=True)\n", + " \n", + " # get dataset creation script\n", + " url = \"https://raw.githubusercontent.com/awslabs/graphstorm/main/examples/acm_data.py\"\n", + " acm_data = acm_raw / \"acm_data.py\"\n", + " response = requests.get(url)\n", + " assert response.status_code == 200\n", + " with open(acm_data, \"wb\") as f:\n", + " f.write(response.content)\n", + "\n", + " !python {acm_data} --output-path {acm_raw}" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "97de6ba9-576a-45a5-bf02-b5c9dfbb29cb", + "metadata": {}, + "outputs": [], + "source": [ + "acm_gs = \"/tmp/acm_gs\"\n", + "if not Path(acm_gs).exists():\n", + " !python -m graphstorm.gconstruct.construct_graph \\\n", + " --conf-file {acm_raw}/config.json \\\n", + " --output-dir {acm_gs} \\\n", + " --num-parts {num_nodes} \\\n", + " --graph-name acm" + ] + }, + { + "cell_type": "markdown", + "id": "b4fa1233-844e-47c3-aee1-a4cb7402aef4", + "metadata": {}, + "source": [ + "---\n", + "\n", + "## Model Training" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "c876825c-0810-42ad-a522-fa28cd71b8db", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Trainer will use only 1 of 4 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=4)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.\n", + "GPU available: True (cuda), used: True\n", + "TPU available: False, using: 0 TPU cores\n", + "IPU available: False, using: 0 IPUs\n", + "HPU available: False, using: 0 HPUs\n", + "/root/.cache/pypoetry/virtualenvs/graphstorm-lightning-5me8nBHW-py3.10/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:75: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default\n" + ] + } + ], + "source": [ + "# Works in Jupyter, Colab and Kaggle!\n", + "trainer = pl.Trainer(accelerator=\"auto\", devices=\"auto\", max_epochs=2)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "892318da-050b-423c-b2e7-8b9bdd415998", + "metadata": {}, + "outputs": [], + "source": [ + "config = yaml.safe_load(f\"\"\"\n", + " gsf:\n", + " basic:\n", + " graph_name: acm\n", + " part_config: /tmp/acm_nc_{num_nodes}p/acm.json\n", + " model_encoder_type: rgcn\n", + " gnn:\n", + " fanout: \"15,10\"\n", + " num_layers: 2\n", + " hidden_size: 128\n", + " use_mini_batch_infer: false\n", + " input:\n", + " restore_model_path: null\n", + " output:\n", + " save_model_path: null\n", + " save_embed_path: null\n", + " hyperparam:\n", + " dropout: 0.5\n", + " lr: 0.001\n", + " num_epochs: 10\n", + " batch_size: 1024\n", + " wd_l2norm: 0\n", + " rgcn:\n", + " num_bases: -1\n", + " use_self_loop: true\n", + " sparse_optimizer_lr: 1e-2\n", + " use_node_embeddings: false\n", + " node_classification:\n", + " node_feat_name:\n", + " - paper:feat\n", + " - author:feat\n", + " - subject:feat\n", + " target_ntype: paper\n", + " label_field: label\n", + " multilabel: false\n", + " num_classes: 40\n", + " eval_metric:\n", + " - accuracy\n", + "\"\"\")" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "21fa3903-22b6-4c7b-8c9a-ed104f1f1410", + "metadata": {}, + "outputs": [], + "source": [ + "datamodule = gsl.datamodule.GSgnnNodeTrainDataModule(trainer=trainer, config=config, graph_data_uri=acm_gs)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "2dff0714-f93e-49a8-bea5-b7947e21ee9e", + "metadata": {}, + "outputs": [], + "source": [ + "model = gsl.module.GSgnnNodeModel(datamodule=datamodule, config=config)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "03ad74a4-6fe2-4914-8bc1-a354ed4f092d", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/root/.cache/pypoetry/virtualenvs/graphstorm-lightning-5me8nBHW-py3.10/lib/python3.10/site-packages/dgl/distributed/dist_context.py:248: net_type is deprecated and will be removed in future release.\n", + "INFO:root:Start to load partition from /tmp/acm_nc_1p/part0/graph.dgl which is 5055161 bytes. It may take non-trivial time for large partition.\n", + "INFO:root:Finished loading partition from /tmp/acm_nc_1p/part0/graph.dgl.\n", + "INFO:root:Finished loading node data.\n", + "INFO:root:Finished loading edge data.\n", + "INFO:root:part 0, train: 9999, val: 1249, test: 1249\n", + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Initialize the distributed services with graphbolt: False\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + " | Name | Type | Params\n", + "-----------------------------------------\n", + "0 | model | GSgnnNodeModel | 333 K \n", + "-----------------------------------------\n", + "333 K Trainable params\n", + "0 Non-trainable params\n", + "333 K Total params\n", + "1.332 Total estimated model params size (MB)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Sanity Checking: | …" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:root:[Rank 0] dist_inference: finishes 0 iterations.\n", + "INFO:root:[Rank 0] dist_inference: finishes 0 iterations.\n", + "/root/.cache/pypoetry/virtualenvs/graphstorm-lightning-5me8nBHW-py3.10/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py:298: The number of training batches (10) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "e64dc622c02d499b9b4c184472287701", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Training: | …" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validation: | …" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:root:[Rank 0] dist_inference: finishes 0 iterations.\n", + "INFO:root:[Rank 0] dist_inference: finishes 0 iterations.\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validation: | …" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:root:[Rank 0] dist_inference: finishes 0 iterations.\n", + "INFO:root:[Rank 0] dist_inference: finishes 0 iterations.\n", + "`Trainer.fit` stopped: `max_epochs=2` reached.\n" + ] + } + ], + "source": [ + "trainer.fit(model=model, datamodule=datamodule)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "graphstorm_lightning", + "language": "python", + "name": "graphstorm_lightning" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.8" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/graphstorm-lightning/graphstorm_lightning/__init__.py b/graphstorm-lightning/graphstorm_lightning/__init__.py new file mode 100644 index 0000000000..46238a5b16 --- /dev/null +++ b/graphstorm-lightning/graphstorm_lightning/__init__.py @@ -0,0 +1,3 @@ +from . import datamodule +from . import module +from . import utils diff --git a/graphstorm-lightning/graphstorm_lightning/datamodule/__init__.py b/graphstorm-lightning/graphstorm_lightning/datamodule/__init__.py new file mode 100644 index 0000000000..9d8bb0c473 --- /dev/null +++ b/graphstorm-lightning/graphstorm_lightning/datamodule/__init__.py @@ -0,0 +1 @@ +from .node_gnn import GSgnnNodeTrainDataModule diff --git a/graphstorm-lightning/graphstorm_lightning/datamodule/node_gnn.py b/graphstorm-lightning/graphstorm_lightning/datamodule/node_gnn.py new file mode 100644 index 0000000000..925fd13ec1 --- /dev/null +++ b/graphstorm-lightning/graphstorm_lightning/datamodule/node_gnn.py @@ -0,0 +1,139 @@ +from typing import Any, Dict, Optional + +import graphstorm as gs +import graphstorm_lightning as gsl +import pytorch_lightning as pl +import torch +import contextlib + + +class GSgnnNodeTrainDataModule(pl.LightningDataModule): + def __init__( + self, + trainer: pl.Trainer, + config: Dict[str, Any], + graph_data_uri: Optional[str] = None, + ): + super().__init__() + self.save_hyperparameters(ignore=["trainer"]) + self.trainer = trainer + + def _device(self) -> torch.device: + return self.trainer.strategy.root_device + + def prepare_data(self) -> None: + gsl.utils.load_data(self.trainer, self.hparams.config, self.hparams.graph_data_uri) + + def setup(self, stage: str) -> None: + ip_config = gsl.utils.initialize_dgl(self.trainer, self.hparams.config) + with ip_config or contextlib.nullcontext(): + if ip_config: + self.hparams.config["gsf"]["basic"]["ip_config"] = ip_config.name + self.config = config = gsl.utils.get_config(self.trainer, self.hparams.config) + self.gnn = gs.dataloading.GSgnnNodeTrainData( + config.graph_name, + config.part_config, + train_ntypes=config.target_ntype, + eval_ntypes=config.eval_target_ntype, + node_feat_field=config.node_feat_name, + label_field=config.label_field, + lm_feat_ntypes=gs.utils.get_lm_ntypes(config.node_lm_configs), + ) + + def _dataloader(self, target_idxs: Dict[str, torch.Tensor]) -> gs.dataloading.GSgnnNodeDataLoader: + return gs.dataloading.GSgnnNodeDataLoader( + self.gnn, + target_idxs, + fanout=self.hparams.fanout, + batch_size=self.hparams.batch_size, + device=self._device(), + train_task=self.hparams.train_task, + construct_feat_ntype=self.hparams.construct_feat_ntype, + construct_feat_fanout=self.hparams.construct_feat_fanout, + ) + + def transfer_batch_to_device(self, batch: Any, device: torch.device, dataloader_idx: int): + data = self.gnn + + # from source + input_nodes, seeds, blocks = batch + if not isinstance(input_nodes, dict): + assert len(data.g.ntypes) == 1 + input_nodes = {data.g.ntypes[0]: input_nodes} + # make sure input_nodes and seeds are on CPU, since they get converted to NumPy by DGL internally + input_feats = data.get_node_feats(input_nodes, device) + lbl = data.get_labels(seeds, device) + blocks = pl.utilities.move_data_to_device(blocks, device) + return (input_nodes, lbl, blocks, input_feats) + + def train_dataloader(self) -> gs.dataloading.GSgnnNodeDataLoader: # type: ignore + train_data = self.gnn + config = self.config + # from https://github.com/awslabs/graphstorm/blob/main/python/graphstorm/run/gsgnn_np/gsgnn_np.py + dataloader = None + if config.use_pseudolabel: + # Use nodes not in train_idxs as unlabeled node sets + unlabeled_idxs = train_data.get_unlabeled_idxs() + # semi-supervised loader + dataloader = gs.dataloading.GSgnnNodeSemiSupDataLoader( + train_data, + train_data.train_idxs, + unlabeled_idxs, + fanout=config.fanout, + batch_size=config.batch_size, + device=self._device(), + train_task=True, + construct_feat_ntype=config.construct_feat_ntype, + construct_feat_fanout=config.construct_feat_fanout, + ) + else: + dataloader = gs.dataloading.GSgnnNodeDataLoader( + train_data, + train_data.train_idxs, + fanout=config.fanout, + batch_size=config.batch_size, + device=self._device(), + train_task=True, + construct_feat_ntype=config.construct_feat_ntype, + construct_feat_fanout=config.construct_feat_fanout, + ) + return dataloader + + def val_dataloader(self) -> Optional[gs.dataloading.GSgnnNodeDataLoader]: # type: ignore + train_data = self.gnn + config = self.config + # from https://github.com/awslabs/graphstorm/blob/main/python/graphstorm/run/gsgnn_np/gsgnn_np.py + test_dataloader = None + if len(train_data.test_idxs) > 0: + fanout = config.eval_fanout if config.use_mini_batch_infer else [] + test_dataloader = gs.dataloading.GSgnnNodeDataLoader( + train_data, + train_data.test_idxs, + fanout=fanout, + batch_size=config.eval_batch_size, + device=self._device(), + train_task=False, + construct_feat_ntype=config.construct_feat_ntype, + construct_feat_fanout=config.construct_feat_fanout, + ) + return test_dataloader + + def test_dataloader(self) -> Optional[gs.dataloading.GSgnnNodeDataLoader]: # type: ignore + train_data = self.gnn + config = self.config + # from https://github.com/awslabs/graphstorm/blob/main/python/graphstorm/run/gsgnn_np/gsgnn_np.py + test_dataloader = None + if len(train_data.test_idxs) > 0: + # we don't need fanout for full-graph inference + fanout = config.eval_fanout if config.use_mini_batch_infer else [] + test_dataloader = gs.dataloading.GSgnnNodeDataLoader( + train_data, + train_data.test_idxs, + fanout=fanout, + batch_size=config.eval_batch_size, + device=self._device(), + train_task=False, + construct_feat_ntype=config.construct_feat_ntype, + construct_feat_fanout=config.construct_feat_fanout, + ) + return test_dataloader diff --git a/graphstorm-lightning/graphstorm_lightning/module/__init__.py b/graphstorm-lightning/graphstorm_lightning/module/__init__.py new file mode 100644 index 0000000000..03d50f35b3 --- /dev/null +++ b/graphstorm-lightning/graphstorm_lightning/module/__init__.py @@ -0,0 +1 @@ +from .node_gnn import GSgnnNodeModel diff --git a/graphstorm-lightning/graphstorm_lightning/module/node_gnn.py b/graphstorm-lightning/graphstorm_lightning/module/node_gnn.py new file mode 100644 index 0000000000..37e0bbacb0 --- /dev/null +++ b/graphstorm-lightning/graphstorm_lightning/module/node_gnn.py @@ -0,0 +1,132 @@ +from typing import Any, Dict, Optional + +import pytorch_lightning as pl +import torch +import graphstorm_lightning as gsl +from graphstorm import create_builtin_node_gnn_model +from graphstorm.model import GSgnnNodeModelInterface +from graphstorm.model.gnn import GSOptimizer +from graphstorm.run.gsgnn_np.gsgnn_np import get_evaluator +from graphstorm.trainer.np_trainer import do_full_graph_inference +from typing_extensions import override + + +class GSgnnNodeModel(pl.LightningModule): + def __init__( + self, + datamodule: pl.LightningDataModule, + config: Dict[str, Any], + max_grad_norm: Optional[float] = None, + ) -> None: + super().__init__() + self.automatic_optimization = False + self.save_hyperparameters(ignore=["data"]) + self.datamodule = datamodule + self.model: GSgnnNodeModelInterface + self.optimizer: GSOptimizer + self.val_preds = {} + self.val_labels = {} + + @override + def configure_model(self) -> None: + if hasattr(self, "model"): + return + config = gsl.utils.get_config(self.trainer, self.hparams.config) + self.val_fanout = config.eval_fanout + train_data = self.datamodule.gnn + self.model = create_builtin_node_gnn_model(train_data.g, config, train_task=True) + self.evaluator = get_evaluator(config) + + @override + def training_step(self, batch: Any, batch_idx: int) -> torch.Tensor: + + # from source + input_nodes, lbl, blocks, input_feats = batch + # torch.distributed might move this tensor to GPU + input_nodes = pl.utilities.move_data_to_device(input_nodes, "cpu") + loss = self.model(blocks, input_feats, None, lbl, input_nodes) + self.manual_backward(loss) + + # clip gradient + if self.hparams.max_grad_norm is not None: + torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.hparams.max_grad_norm) + + # accumulate gradients of N batches + if (batch_idx + 1) % self.trainer.accumulate_grad_batches == 0: + self.optimizer.step() + self.optimizer.zero_grad() + + self.log("train_loss", loss, prog_bar=True) + return loss + + @override + def on_validation_epoch_start(self) -> None: + model = self.model + data = self.datamodule.gnn + + # from source + self.val_emb = do_full_graph_inference(model, data, fanout=self.val_fanout) + + @override + def validation_step(self, batch: Any, batch_idx: int) -> None: + model = self.model + device = self.device + labels = self.val_labels + preds = self.val_preds + emb = self.val_emb + + # from source + input_nodes, lbl, _, _ = batch + # torch.distributed might move this tensor to GPU + input_nodes = pl.utilities.move_data_to_device(input_nodes, "cpu") + for ntype, in_nodes in input_nodes.items(): + if isinstance(model.decoder, torch.nn.ModuleDict): + assert ntype in model.decoder, f"Node type {ntype} not in decoder" + decoder = model.decoder[ntype] + else: + decoder = model.decoder + pred = decoder.predict(emb[ntype][in_nodes].to(device)) + if ntype in preds: + preds[ntype].append(pred.cpu()) + else: + preds[ntype] = [pred.cpu()] + if ntype in labels: + labels[ntype].append(lbl[ntype]) + else: + labels[ntype] = [lbl[ntype]] + + @override + def on_validation_epoch_end(self) -> None: + self.val_emb = None + total_steps = self.trainer.global_step + preds = self.val_preds + labels = self.val_labels + + # from source + val_pred = {} + val_label = {} + for ntype, ntype_pred in preds.items(): + val_pred[ntype] = torch.cat(ntype_pred) + for ntype, ntype_label in labels.items(): + val_label[ntype] = torch.cat(ntype_label) + labels.clear() + preds.clear() + + # TODO(wlcong) we only support node prediction on one node type for evaluation now + assert len(val_label) == 1, "We only support prediction on one node type for now." + ntype = list(val_label.keys())[0] + # We need to have val and label (test and test label) data in GPU + # when backend is nccl, as we need to use nccl.all_reduce to exchange + # data between GPUs + val_pred = val_pred[ntype] + val_label = val_label[ntype] + val_score, test_score = self.evaluator.evaluate(val_pred, None, val_label, None, total_steps) + for metric in self.evaluator.metric: + self.log(f"val_{metric}", val_score[metric]) + best_val_score = self.evaluator.best_val_score + self.log(f"best_val_{metric}", best_val_score[metric]) + + @override + def configure_optimizers(self) -> pl.utilities.types.OptimizerLRScheduler: + self.optimizer = self.model.create_optimizer() + return [] diff --git a/graphstorm-lightning/graphstorm_lightning/utils.py b/graphstorm-lightning/graphstorm_lightning/utils.py new file mode 100644 index 0000000000..6658148989 --- /dev/null +++ b/graphstorm-lightning/graphstorm_lightning/utils.py @@ -0,0 +1,127 @@ +import atexit +import os +import signal +import socket +import tempfile +from multiprocessing import Process +from pathlib import Path +from typing import IO, Any, Dict, Optional, List + +import dgl +import graphstorm as gs +import pyarrow as pa +import pytorch_lightning as pl +import torch +import yaml + + +def get_config(trainer: pl.Trainer, cfg: Dict[str, Any]) -> gs.config.GSConfig: + with tempfile.NamedTemporaryFile(prefix="gs_", mode="w", suffix=".yaml") as yaml_file: + yaml.safe_dump(cfg, yaml_file) + gs_config_args = [ + "--cf", + yaml_file.name, + "--local-rank", + str(trainer.local_rank), + ] + gs_parser = gs.config.get_argument_parser() + gs_args = gs_parser.parse_args(gs_config_args) + config = gs.config.GSConfig(gs_args) + return config + + +def load_data(trainer: pl.Trainer, config: Dict[str, Any], from_: str): + to = get_part_config(config) + fs, path = pa.fs.FileSystem.from_uri(from_) # type: ignore[call-arg,arg-type] + num_devices = trainer.num_devices + rank = trainer.global_rank + local_ranks = set(range(rank, rank + num_devices)) + dir = fs.get_file_info(path) + if dir.type != pa.fs.FileType.Directory: + raise ValueError("Directory path must be provided") + local_root = Path(to) + local_root = local_root.parent if local_root.name else local_root + local_root.mkdir(exist_ok=True) + selector = pa.fs.FileSelector(path, recursive=False) + for file in fs.get_file_info(selector): + file_path = Path(file.path).name + if file_path.startswith("part"): + rank = int(file_path.replace("part", "")) + if rank not in local_ranks: + continue + local_file = local_root / file_path + if file.type == pa.fs.FileType.Directory: + local_file.mkdir(exist_ok=True) + pa.fs.copy_files(source=file.path, destination=local_file.as_uri(), source_filesystem=fs) + + +def get_ip_tensor() -> torch.ByteTensor: + hostname = socket.gethostname() + ip = socket.gethostbyname(hostname) + return torch.ByteTensor(list(ip.encode("ascii"))) + + +def get_part_config(config: Dict[str, Any]) -> str: + return config.get("gsf", {}).get("basic", {}).get("part_config") + + +def start_server(role: str, server_id: int, num_server: int, num_client: int, conf_path: str, ip_config: str) -> None: + os.environ["DGL_ROLE"] = role + os.environ["DGL_DIST_MODE"] = "distributed" + os.environ["DGL_SERVER_ID"] = str(server_id) + os.environ["DGL_NUM_SERVER"] = str(num_server) + os.environ["DGL_NUM_CLIENT"] = str(num_client) + os.environ["DGL_CONF_PATH"] = conf_path + os.environ["DGL_IP_CONFIG"] = ip_config + dgl.distributed.initialize(ip_config, net_type="socket") + + +def ip_addresses(trainer: pl.Trainer) -> Optional[List[str]]: + if not torch.distributed.is_initialized(): + return None + gloo = torch.distributed.new_group(backend="gloo") + ip_tensor = get_ip_tensor() + max_length = max(15, len(ip_tensor)) # assuming IPv4 and fixed xxx.xxx.xxx.xxx format + ip_tensor_padded = torch.cat((ip_tensor, torch.zeros(max_length - len(ip_tensor), dtype=torch.uint8))) + gathered_ip_tensors = [torch.zeros(max_length, dtype=torch.uint8) for _ in range(trainer.world_size)] + torch.distributed.all_gather(gathered_ip_tensors, ip_tensor_padded, group=gloo) + gathered_ips = sorted( + set(tensor.numpy().tobytes().decode("ascii").rstrip("\x00") for tensor in gathered_ip_tensors) + ) + torch.distributed.destroy_process_group(gloo) + return gathered_ips + + +def prepare_data(trainer: pl.Trainer, config: Dict[str, Any], graph_data_uri: Optional[str]) -> None: + part_config = get_part_config(config) + if part_config and graph_data_uri: + load_data(from_=graph_data_uri, to=part_config, num_devices=trainer.num_devices, rank=trainer.global_rank) + + +def initialize_dgl(trainer: pl.Trainer, config: Dict[str, Any]) -> Optional[IO[Any]]: + gathered_ips = ip_addresses(trainer) + if not gathered_ips: + dgl.distributed.initialize(None, net_type="socket") + return None + ip_config = tempfile.NamedTemporaryFile(prefix="ip_config", mode="w", suffix=".txt") + ip_config.write("\n".join(gathered_ips)) + ip_config.flush() + + id = trainer.global_rank + num_servers = trainer.num_nodes + num_clients = trainer.world_size + part_config = get_part_config(config) + if trainer.local_rank == 0: + dgl_process = Process( + target=start_server, args=("server", id, num_servers, num_clients, part_config, ip_config.name) + ) + dgl_process.start() + start_server("client", id, num_servers, num_clients, part_config, ip_config.name) + + def teardown(*args: Any) -> None: + dgl_process.terminate() + + signal.signal(signal.SIGTERM, teardown) # terminate signal + signal.signal(signal.SIGINT, teardown) # keyboard interrupt + atexit.register(teardown) + return ip_config diff --git a/graphstorm-lightning/pyproject.toml b/graphstorm-lightning/pyproject.toml new file mode 100644 index 0000000000..3170a7a032 --- /dev/null +++ b/graphstorm-lightning/pyproject.toml @@ -0,0 +1,59 @@ +[tool.poetry] +name = "graphstorm_lightning" +version = "0.0.1" +description = "PyTorch Lightning support for GraphStorm" +readme = "README.md" +packages = [{include = "graphstorm_lightning"}] +authors = [ + "Amazon AI Graph ML team" +] + +[tool.poetry.dependencies] +python = ">=3.8" +pytorch-lightning = ">=2.2.1" +graphstorm = "*" + +[tool.poetry.group.dev] +optional = true + +[tool.poetry.group.dev.dependencies] +pytest = ">=7.4.0" +mock = ">=5.0.2" +coverage = ">=7.0.0" +sphinx = ">=6.0.0" +mypy = ">=1.0.0" +types-psutil = "^5.9.5.15" +black = ">=24.2.0" +pre-commit = "^3.3.3" +types-mock = "^5.1.0.1" +pylint = "~2.17.5" +jupyter = "*" +ipykernel = "*" +torch = "==2.1.0" +dgl = {url="https://data.dgl.ai/wheels/cu121/dgl-2.1.0%2Bcu121-cp310-cp310-manylinux1_x86_64.whl"} +dglgo = "*" + +[project] +requires-python = ">=3.8" +classifiers = [ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: Apache Software License", + "Operating System :: POSIX", +] + +[tool.mypy] +check_untyped_defs = true + +[tool.pytest] +filterwarnings = "ignore::DeprecationWarning" + +[build-system] +requires = ["poetry-core>=1.0.8"] +build-backend = "poetry.core.masonry.api" + +[tool.black] +line-length = 100 +target-version = ['py39'] + +[virtualenvs] +in-project = true