Model: Implement GraphGPT and LLaGA #232

Merged · merged 4 commits · Feb 20, 2025
31 changes: 31 additions & 0 deletions examples/graphgpt/README.md
@@ -0,0 +1,31 @@
# GraphGPT: Graph Instruction Tuning for Large Language Models
* Paper link: http://arxiv.org/abs/2310.13023
* Author's code repo: https://github.com/HKUDS/GraphGPT

# How to Run

* First, follow the original repo to install all required packages.

* Then download the required datasets and pretrained checkpoints, and fill their paths into the corresponding variables in eval.sh, as sketched below.
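For reference, a minimal sketch of the variables eval.sh expects; the paths are placeholders for your local copies:

```bash
output_model=/path/to/GraphGPT-7B-mix-all        # pre-trained model checkpoint
datapath=/path/to/cora_test_instruct_std.json    # instruction dataset
graph_data_path=/path/to/graph_data_all.pt       # graph data
res_path=./output_stage_2_cora_nc                # where results are written
```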

# Dataset Statistics
| Dataset | # Nodes | # Edges | # Classes |
| :-------: | :-------: | :------: | :------: |
| Cora | 25,120 | 182,280 | 70 |
| PubMed | 19,717 | 44,338 | 3 |
| ogb-arxiv | 169,343 | 1,166,243 | 40 |

# File Descriptions
* graphgpt_trainer.py: the GraphGPT trainer (inference stage)
* graphgpt_eval.py: run this script to evaluate the inference results

# Results
```bash
# run inference
TL_BACKEND="torch" nohup bash examples/graphgpt/eval.sh > log/test_graphgpt.out &
# run evaluation
python examples/graphgpt/graphgpt_eval.py --dataset cora
```
| Dataset | Paper | Ours (torch) |
| :-------: | :-------: | :------: |
| Cora | 0.1501 | 0.1451 |
13 changes: 13 additions & 0 deletions examples/graphgpt/eval.sh
@@ -0,0 +1,13 @@
export PYTHONPATH=$(dirname $(dirname $(realpath $0))):$PYTHONPATH
# fill in the following paths before running!
output_model=/local/yy3/graphgpt/GraphGPT-7B-mix-all # path to the pre-trained model checkpoint
datapath=/local/yy3/graphgpt/data/eval/cora_test_instruct_std.json # path to the instruction dataset
graph_data_path=/local/yy3/graphgpt/data/graph_data_all.pt # path to the graph data
res_path=./output_stage_2_cora_nc # path to save the results
start_id=0
end_id=20000 # end index of the instruction range to test (capped to the dataset size in run_eval)
num_gpus=1

export CUDA_VISIBLE_DEVICES=2 # specify the GPU id

python ./examples/graphgpt/graphgpt_trainer.py --model-name ${output_model} --prompting_file ${datapath} --graph_data_path ${graph_data_path} --output_res_path ${res_path} --start_id ${start_id} --end_id ${end_id} --num_gpus ${num_gpus}
106 changes: 106 additions & 0 deletions examples/graphgpt/graphgpt_eval.py
@@ -0,0 +1,106 @@
import json
import os.path as osp
import os
import torch as th
import re
import pandas as pd
from tqdm import tqdm
from sklearn.metrics import classification_report

import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str, default='arxiv')
args = parser.parse_args()

label_to_idx = {
"cora":{"databases, object oriented": 29, "operating systems, memory management": 59, "data structures algorithms and theory, quantum computing": 24, "artificial intelligence, planning": 13, "artificial intelligence, knowledge representation": 4, "artificial intelligence, data mining": 1, "artificial intelligence, vision and pattern recognition": 17, "artificial intelligence, machine learning, case-based": 5, "artificial intelligence, agents": 0, "artificial intelligence, machine learning, probabilistic methods": 8, "encryption and compression, security": 36, "operating systems, distributed": 57, "human computer interaction, interface design": 46, "artificial intelligence, machine learning, genetic algorithms": 6, "human computer interaction, graphics and virtual reality": 45, "artificial intelligence, machine learning, rule learning": 10, "programming, functional": 63, "programming, object oriented": 67, "encryption and compression, encryption": 35, "databases, performance": 30, "networking, protocols": 54, "data structures algorithms and theory, randomized": 25, "data structures algorithms and theory, formal languages": 20, "data structures algorithms and theory, parallel": 23, "programming, software development": 69, "programming, compiler design": 61, "artificial intelligence, machine learning, theory": 11, "artificial intelligence, machine learning, neural networks": 7, "programming, logic": 66, "databases, relational": 32, "information retrieval, retrieval": 52, "programming, debugging": 62, "networking, wireless": 56, "artificial intelligence, theorem proving": 16, "databases, temporal": 33, "encryption and compression, compression": 34, "information retrieval, filtering": 51, "data structures algorithms and theory, computational complexity": 18, "programming, garbage collection": 64, "artificial intelligence, machine learning, reinforcement learning": 9, "human computer interaction, multimedia": 47, "hardware and architecture, vlsi": 43, "artificial intelligence, nlp": 12, "hardware and architecture, microprogramming": 42, "operating systems, fault tolerance": 58, "programming, java": 65, "operating systems, realtime": 60, "human computer interaction, cooperative": 44, "artificial intelligence, speech": 15, "databases, deductive": 28, "artificial intelligence, robotics": 14, "data structures algorithms and theory, logic": 22, "networking, routing": 55, "hardware and architecture, logic design": 40, "hardware and architecture, distributed architectures": 37, "data structures algorithms and theory, hashing": 21, "programming, semantics": 68, "artificial intelligence, games and search": 3, "databases, concurrency": 27, "data structures algorithms and theory, sorting": 26, "human computer interaction, wearable computers": 48, "information retrieval, digital library": 49, "artificial intelligence, expert systems": 2, "information retrieval, extraction": 50, "data structures algorithms and theory, computational geometry": 19, "databases, query evaluation": 31, "networking, internet": 53, "hardware and architecture, memory structures": 41, "hardware and architecture, high performance computing": 38, "hardware and architecture, input output and storage": 39},
"pubmed":{"Experimentally induced diabetes": 0, "Type 2 diabetes": 2, "Type 1 diabetes": 1}
}
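# label_to_idx maps the verbalized class names that appear in the instruction data
# to class indices for cora and pubmed; arxiv labels are instead derived from the
# OGB label-mapping file in cal_map() below.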



data_list = []
folder = 'output_stage_2_{}_nc'.format(args.dataset)
for filename in os.listdir(folder):
    if filename.endswith('.json'):
        file_path = os.path.join(folder, filename)
        with open(file_path, 'r') as f:
            data = json.load(f)
        data_list.extend(data)

print(data_list[1])

graph_data = th.load('/local/yy3/graphgpt/data/graph_data_all.pt')[args.dataset]
labels = graph_data.y

def cal_map():
    label_dict = {}
    if args.dataset == "arxiv":
        df = pd.read_csv(os.path.expanduser('~/datasets/OGB/ogbn_arxiv/mapping/labelidx2arxivcategeory.csv.gz'), compression='gzip')
        for index, line in df.iterrows():
            lb = line['arxiv category'].split(' ')[-1]
            lb_new = 'cs.' + lb.upper()
            label_dict[lb_new] = line['label idx']
    else:
        label_dict = label_to_idx[args.dataset]
    return label_dict
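# class_map maps a label string as it may appear in the model's free-text answer
# to a class index: for arxiv, entries like {'cs.AI': 0, ...} (illustrative values);
# for cora/pubmed, the long topic strings from label_to_idx above.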

class_map = cal_map()

inverse_class_map = {}
for lb, lb_id in class_map.items():
    inverse_class_map[lb_id] = lb


pattern = r"cs\.[A-Z]{2}"


topk = 3

correct = 0
total = len(data_list)

trues = []
preds = []

for instruct_item in tqdm(data_list):
    nid = instruct_item['node_idx']
    gpt_res = instruct_item['res']

    true_y = labels[nid]

    pred_y = []
    if args.dataset == "arxiv":
        # extract every distinct arxiv tag (e.g. cs.AI) from the response,
        # ordered by first occurrence
        matches = list(set(re.findall(pattern, gpt_res)))
        sorted_matches = sorted(matches, key=lambda x: gpt_res.index(x))
        for m in sorted_matches:
            try:
                pred_y.append(class_map[m])
            except KeyError:
                pass
        preds.append(pred_y[0] if pred_y else -1)
    else:
        # for cora/pubmed, scan for label strings mentioned in the response
        for lb, lb_id in class_map.items():
            if lb in gpt_res:
                pred_y.append(lb_id)
        preds.append(pred_y[0] if pred_y else -1)
    trues.append(true_y.item())
    correct = correct + 1 if true_y in pred_y[:topk] else correct
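# Sketch of the top-k rule above: with topk = 3 and a response parsed into
# pred_y = [12, 4, 7, 30], the item counts as correct iff the true label index
# is one of the first three entries [12, 4, 7].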

acc = correct / total

print("Accuracy:", acc)

report = classification_report(trues, preds, digits=6)

print(report)
232 changes: 232 additions & 0 deletions examples/graphgpt/graphgpt_trainer.py
@@ -0,0 +1,232 @@
import argparse
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os
from gammagl.utils.conversation import conv_templates, SeparatorStyle
from gammagl.utils.gfm_utils import disable_torch_init, KeywordsStoppingCriteria
from gammagl.utils.gfm_utils import DEFAULT_G_END_TOKEN, DEFAULT_G_START_TOKEN, DEFAULT_GRAPH_PATCH_TOKEN, DEFAULT_GRAPH_TOKEN, GRAPH_TOKEN_INDEX
from transformers import CLIPVisionModel, CLIPImageProcessor, StoppingCriteria
from gammagl.models.graphgpt import *

from torch_geometric.data import Data
import json
import copy
from tqdm import tqdm
import os.path as osp
import ray

os.environ['TL_BACKEND'] = 'torch'

def load_graph(instruct_item, graph_data_path):
    graph_data_all = torch.load(graph_data_path)
    graph_dict = instruct_item['graph']
    graph_edge_index = torch.Tensor(copy.deepcopy(graph_dict['edge_index'])).long()
    graph_node_list = copy.deepcopy(graph_dict['node_list'])
    target_node = copy.deepcopy(graph_dict['node_idx'])
    graph_type = copy.deepcopy(instruct_item['id']).split('_')[0]
    graph_node_rep = graph_data_all[graph_type].x[graph_node_list]

    cur_token_len = len(graph_node_rep)  # one graph token per node in the sampled subgraph

    graph_ret = Data(graph_node=graph_node_rep, edge_index=graph_edge_index, target_node=torch.tensor([target_node]))

    return {
        'graph_data': graph_ret,
        'graph_token_len': cur_token_len
    }
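# For reference, a sketch of one instruction record in the prompting file,
# inferred from the fields accessed above (not an authoritative schema):
# {
#     "id": "cora_...",
#     "graph": {"node_idx": 2, "node_list": [2, 8, 15], "edge_index": [[0, 0, 1], [1, 2, 2]]},
#     "conversations": [{"from": "human", "value": "... <graph> ..."}, ...]
# }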


def load_prompting_file(file_path):
    with open(file_path, 'r') as f:
        data = json.load(f)
    return data


def run_eval(args, num_gpus):
    # split the instruction file into num_gpus chunks, one per worker
    prompt_file = load_prompting_file(args.prompting_file)
    args.end_id = min(args.end_id, len(prompt_file))
    prompt_file = prompt_file[args.start_id:args.end_id]
    chunk_size = len(prompt_file) // num_gpus
    ans_handles = []
    split_list = list(range(args.start_id, args.end_id, chunk_size))
    idx_list = list(range(0, len(prompt_file), chunk_size))
    if len(split_list) == num_gpus:
        split_list.append(args.end_id)
        idx_list.append(len(prompt_file))
    elif len(split_list) == num_gpus + 1:
        split_list[-1] = args.end_id
        idx_list[-1] = len(prompt_file)
    else:
        raise ValueError('unexpected number of chunks when splitting the prompt file')

    if osp.exists(args.output_res_path) is False:
        os.mkdir(args.output_res_path)

    for idx in range(len(idx_list) - 1):
        start_idx = idx_list[idx]
        end_idx = idx_list[idx + 1]

        start_split = split_list[idx]
        end_split = split_list[idx + 1]
        ans_handles.append(
            eval_model.remote(
                args, prompt_file[start_idx:end_idx], start_split, end_split
            )
        )

    ans_jsons = []
    for ans_handle in ans_handles:
        ans_jsons.extend(ray.get(ans_handle))
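# Chunking sketch (assumed numbers): with start_id=0, end_id=100 and num_gpus=2,
# chunk_size is 50, so one worker gets prompt_file[0:50] (split ids 0..50) and the
# other prompt_file[50:100] (split ids 50..100); each writes its own result file.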


@ray.remote(num_gpus=1)
@torch.inference_mode()
def eval_model(args, prompt_file, start_idx, end_idx):
    # Model
    disable_torch_init()
    print('start loading tokenizer')
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    print('finish loading tokenizer')

    print('start loading model')
    model = GraphLlamaForCausalLM.from_pretrained(args.model_name, torch_dtype=torch.float16, use_cache=True, low_cpu_mem_usage=True).cuda()
    print('finish loading model')

    use_graph_start_end = getattr(model.config, "use_graph_start_end", False)
    tokenizer.add_tokens([DEFAULT_GRAPH_PATCH_TOKEN], special_tokens=True)
    if use_graph_start_end:
        tokenizer.add_tokens([DEFAULT_G_START_TOKEN, DEFAULT_G_END_TOKEN], special_tokens=True)

    # rebuild the graph tower from the pretrained graph encoder checkpoint
    clip_graph, args_graph = load_model_pretrained(CLIP, model.config.pretrain_graph_model_path)
    graph_tower = graph_transformer(args_graph)
    graph_tower = transfer_param_tograph(clip_graph, graph_tower)

    model.get_model().graph_tower = graph_tower.cuda()
    graph_tower.to(device='cuda', dtype=torch.float16)
    graph_config = graph_tower.config
    graph_config.graph_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_GRAPH_PATCH_TOKEN])[0]
    graph_config.use_graph_start_end = use_graph_start_end
    if use_graph_start_end:
        graph_config.graph_start_token, graph_config.graph_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_G_START_TOKEN, DEFAULT_G_END_TOKEN])

    res_data = []
    print(f'total: {len(prompt_file)}')
    for idx, instruct_item in tqdm(enumerate(prompt_file)):
        graph_dict = load_graph(instruct_item, args.graph_data_path)
        graph_token_len = graph_dict['graph_token_len']
        graph_data = graph_dict['graph_data']

        qs = instruct_item["conversations"][0]["value"]

        # replace the single <graph> placeholder with one patch token per node,
        # wrapped in graph start/end tokens
        replace_token = DEFAULT_GRAPH_PATCH_TOKEN * graph_token_len
        replace_token = DEFAULT_G_START_TOKEN + replace_token + DEFAULT_G_END_TOKEN
        qs = qs.replace(DEFAULT_GRAPH_TOKEN, replace_token)

        conv_mode = "graphchat_v1"

        if args.conv_mode is not None and conv_mode != args.conv_mode:
            print('[WARNING] the auto-inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode))
        else:
            args.conv_mode = conv_mode

        conv = conv_templates[args.conv_mode].copy()
        conv.append_message(conv.roles[0], qs)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()
        inputs = tokenizer([prompt])

        input_ids = torch.as_tensor(inputs.input_ids).cuda()

        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
        keywords = [stop_str]
        stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)

        graph_data.graph_node = graph_data.graph_node.to(torch.float16)

        with torch.inference_mode():
            output_ids = model.generate(
                input_ids,
                graph_data=graph_data.cuda(),
                do_sample=True,
                temperature=0.2,
                max_new_tokens=1024,
                stopping_criteria=[stopping_criteria])

        input_token_len = input_ids.shape[1]
        n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
        if n_diff_input_output > 0:
            print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
        outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
        outputs = outputs.strip()
        if outputs.endswith(stop_str):
            outputs = outputs[:-len(stop_str)]
        outputs = outputs.strip()

        res_data.append({"id": instruct_item["id"], "node_idx": instruct_item["graph"]["node_idx"], "res": outputs}.copy())
        with open(osp.join(args.output_res_path, 'arxiv_test_res_{}_{}.json'.format(start_idx, end_idx)), "w") as fout:
            json.dump(res_data, fout, indent=4)
    return res_data

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
    parser.add_argument("--prompting_file", type=str, default=None)
    parser.add_argument("--conv-mode", type=str, default=None)
    parser.add_argument("--graph_data_path", type=str, default=None)

    parser.add_argument("--output_res_path", type=str, default=None)
    parser.add_argument("--num_gpus", type=int, default=4)

    parser.add_argument("--start_id", type=int, default=0)
    parser.add_argument("--end_id", type=int, default=20567)

    args = parser.parse_args()

    ray.init()
    run_eval(args, args.num_gpus)

# requires protobuf 4.22.3