From 9aa114ae39fc704b12bf8390f3d922d2d9d42a3b Mon Sep 17 00:00:00 2001 From: VoidHaruhi Date: Sat, 1 Feb 2025 13:50:42 +0000 Subject: [PATCH 1/2] add llaga --- examples/llaga/eval.sh | 27 ++ examples/llaga/llaga_eval.py | 402 ++++++++++++++++++++++ examples/llaga/llaga_trainer.py | 332 ++++++++++++++++++ gammagl/models/llaga.py | 582 ++++++++++++++++++++++++++++++++ gammagl/utils/builder.py | 160 +++++++++ gammagl/utils/conversation.py | 405 ++++++++++++++++++++++ gammagl/utils/gfm_utils.py | 53 +++ 7 files changed, 1961 insertions(+) create mode 100644 examples/llaga/eval.sh create mode 100644 examples/llaga/llaga_eval.py create mode 100644 examples/llaga/llaga_trainer.py create mode 100644 gammagl/models/llaga.py create mode 100644 gammagl/utils/builder.py create mode 100644 gammagl/utils/conversation.py create mode 100644 gammagl/utils/gfm_utils.py diff --git a/examples/llaga/eval.sh b/examples/llaga/eval.sh new file mode 100644 index 00000000..211a781a --- /dev/null +++ b/examples/llaga/eval.sh @@ -0,0 +1,27 @@ +#!/bin/bash + +export PYTHONPATH=$(dirname $(dirname $(realpath $0))):$PYTHONPATH +export CUDA_VISIBLE_DEVICES=0 +model_path="/local/yy3/vicuna-7b-v1.5-16k" +model_base="/local/yy3/llaga-vicuna-7b-simteg-ND-general_model-2-layer-mlp-projector" #meta-llama/Llama-2-7b-hf +mode="v1" # use 'llaga_llama_2' for llama and "v1" for others +dataset="arxiv" #test dataset +task="nc" #test task +emb="simteg" +use_hop=2 +sample_size=10 +template="ND" # or ND +output_path="llaga/test.txt" + +python llaga/llaga_trainer.py \ +--model_path ${model_path} \ +--model_base ${model_base} \ +--conv_mode ${mode} \ +--dataset ${dataset} \ +--pretrained_embedding_type ${emb} \ +--use_hop ${use_hop} \ +--sample_neighbor_size ${sample_size} \ +--answers_file ${output_path} \ +--task ${task} \ +--cache_dir ../../checkpoint \ +--template ${template} \ No newline at end of file diff --git a/examples/llaga/llaga_eval.py b/examples/llaga/llaga_eval.py new file mode 100644 index 00000000..e6787ae4 --- /dev/null +++ b/examples/llaga/llaga_eval.py @@ -0,0 +1,402 @@ +import random + +import torch +import json +import argparse +from sentence_transformers import SentenceTransformer +import torch.nn.functional as F +import numpy as np +from sklearn.metrics import f1_score +from sklearn.metrics import roc_auc_score +import os +root_path = "/home/yy3/graph-text-align/LLaGA" + + +def sbert(model_type, device): + model = SentenceTransformer(model_type, device=device) + return model + +def get_sbert_embedding(model_type, texts, device): + if model_type == 'sbert': + model_type = 'all-MiniLM-L6-v2' + sbert_model = sbert(model_type, f'cuda:{device}') + sbert_embeds = sbert_model.encode(texts, batch_size=8, show_progress_bar=True) + return torch.tensor(sbert_embeds) + +def eval_arxiv_nd(res_path): + data=torch.load(os.path.join(root_path,"dataset/ogbn-arxiv/processed_data.pt" )) + labels=data.label_texts + short_labels = [l[0:5] for l in labels] + ys=data.y.numpy().tolist() + + titles = data.title + + all_sample=0 + short_correct=0 + all_correct=0 + gt=[] + out=[] + with open(res_path, 'r') as f: + for line in f: + all_sample+=1 + res = json.loads(line) + ans = res["text"] + id=res["question_id"] + y=ys[id] + short_label = short_labels[y] + label=labels[y] + if label.strip() in ans.strip(): + all_correct+=1 + if short_label in ans: + short_correct+=1 + out.append(ans) + gt.append(f"This is a paper in {label} domain, it's about {titles[id]}.") + short_acc = short_correct/all_sample + all_acc = all_correct / all_sample + 
print(f"Test samples: {all_sample}\nshort_correct: {short_correct}\nshort_acc: {short_acc:.4f}\nall_correct: {all_correct}\nall_acc: {all_acc:.4f}") + gt_embedding = get_sbert_embedding("sbert", gt, 0) + out_embedding = get_sbert_embedding("sbert", out, 0) + gt_embedding=F.normalize(gt_embedding, p=2, eps=1e-6, dim=1) + out_embedding=F.normalize(out_embedding, p=2, eps=1e-6, dim=1) + predict_sim=(gt_embedding*out_embedding).sum(1).mean().item() + gt_sim_matrix=torch.mm(gt_embedding, gt_embedding.transpose(0, 1)).detach().cpu() + n=gt_sim_matrix.shape[0] + gt_sim_matrix[torch.eye(n, dtype=torch.bool)]=0 + gt_sim=(gt_sim_matrix.sum()/(n*(n-1))).item() + print(f"Predict similarity {predict_sim: .4f}, Pairwise similarity: {gt_sim: .4f}") + + +def eval_arxiv_nc(res_path): + data=torch.load(os.path.join(root_path,"dataset/ogbn-arxiv/processed_data.pt" )) + labels=data.label_texts + short_labels = [l[0:5] for l in labels] + ys=data.y.numpy().tolist() + + all_sample=0 + overall_correct=0 + strict_correct=0 + error=[] + with open(res_path, 'r') as f: + for line in f: + all_sample+=1 + res = json.loads(line) + ans = res["text"] + y=ys[res["question_id"]] + short_label = short_labels[y] + label=labels[y] + if label.lower().strip() == ans.lower().strip(): + strict_correct+=1 + overall_correct+=1 + elif short_label.lower() in ans.lower() and sum([la.lower() in ans.lower() for la in short_labels])==1: + overall_correct+=1 + else: + error.append((ans, label)) + if args.sample > 0 and all_sample >= args.sample: + break + overall_acc = overall_correct/all_sample + strict_acc = strict_correct / all_sample + print(f"Test samples: {all_sample}\nstrict_acc: {strict_acc:.4f}\noverall_acc: {overall_acc:.4f}") + + +def eval_lp(res_path): + all_sample=0 + correct = 0 + with open(res_path, 'r') as f: + for line in f: + res = json.loads(line) + ans = res["text"].strip() + label=res["gt"].strip() + all_sample += 1 + if ("yes" in ans and "yes" in label) or ("yes" not in ans and "no" in label): + correct += 1 + if args.sample > 0 and all_sample >= args.sample: + break + acc = correct / all_sample + print(f"Test samples: {all_sample}\ncorrect: {correct}\n acc: {acc:.4f}") + +def eval_lprank(res_path): + all_sample=0 + correct = 0 + y_true = [] + y_pred=[] + with open(res_path, 'r') as f: + for line in f: + res = json.loads(line) + logit = res["logit"] + score = torch.softmax(torch.tensor(logit[:2]), dim=-1)[0].item() + # score = logit[0] + label=res["gt"].strip() + if label == "yes": + y_true.append(1) + else: + y_true.append(0) + y_pred.append(score) + auc = roc_auc_score(y_true, y_pred) + y_pred = torch.tensor(y_pred) + y_true = torch.tensor(y_true) + acc = ((y_pred>0.5)==y_true).sum()/y_pred.shape[0] + + print(f"AUC: {auc:.4f}") + print(f"ACC: {acc:.4f}") + y_pos=y_pred[y_true==1] + y_neg=y_pred[y_true==0] + y_neg_sort, _ = torch.sort(y_neg) + for n in [10,50,100,200,500,1000]: + if n > y_neg_sort.shape[0]: + break + th = y_neg_sort[-n] + h = (y_pos>th).sum()/y_pos.shape[0] + print(f"Hits@{n}: {h:.4f}") + +# here +def eval_products_nc(res_path): + eval_set = set() + data=torch.load(os.path.join(root_path,"dataset/ogbn-products/processed_data.pt" )) + labels=data.label_names + ys=data.y.numpy().tolist() + + all_sample=0 + strict_correct=0 + overall_correct=0 + with open(res_path, 'r') as f: + for line in f: + if args.sample > 0 and all_sample >= args.sample: + break + all_sample+=1 + res = json.loads(line) + if res['question_id'] in eval_set: + print(f"{res['question_id']} repeat!!") + return + 
eval_set.add(res['question_id']) + ans = res["text"].strip() + y=ys[res["question_id"]][0] + label=labels[y].strip() + if label.lower()==ans.lower(): + strict_correct+=1 + overall_correct+=1 + elif label.lower() in ans.lower() and sum([l.lower() in ans.lower() for l in labels])<=2: + overall_correct += 1 + + overall_acc = overall_correct / all_sample + strict_acc = strict_correct / all_sample + print(f"Test samples: {all_sample}\nstrict_acc: {strict_acc:.4f}\noverall_acc: {overall_acc:.4f}") + +def eval_products_nd(res_path): + eval_set = set() + data=torch.load(os.path.join(root_path,"dataset/ogbn-products/processed_data.pt" )) + labels=data.label_names + ys=data.y.numpy().tolist() + + all_sample=0 + all_correct=0 + gt = [] + out = [] + with open(res_path, 'r') as f: + for line in f: + if args.sample > 0 and all_sample >= args.sample: + break + all_sample+=1 + res = json.loads(line) + if res['question_id'] in eval_set: + print(f"{res['question_id']} repeat!!") + eval_set.add(res['question_id']) + ans = res["text"].strip() + y=ys[res["question_id"]][0] + label=labels[y].strip() + if label.lower() in ans.lower(): + all_correct+=1 + desc = data.raw_texts[res['question_id']] + assistant_prompt = f"This is an amazon product which can be categorized as {label}. It can be described as {desc}" + gt.append(assistant_prompt) + out.append(ans) + all_acc = all_correct / all_sample + print(f"Test samples: {all_sample}acc: {all_acc:.4f}") + + gt_embedding = get_sbert_embedding("sbert", gt, 0) + out_embedding = get_sbert_embedding("sbert", out, 0) + gt_embedding = F.normalize(gt_embedding, p=2, eps=1e-6, dim=1) + out_embedding = F.normalize(out_embedding, p=2, eps=1e-6, dim=1) + predict_sim = (gt_embedding * out_embedding).sum(1).mean().item() + gt_sim_matrix = torch.mm(gt_embedding, gt_embedding.transpose(0, 1)).detach().cpu() + n = gt_sim_matrix.shape[0] + gt_sim_matrix[torch.eye(n, dtype=torch.bool)] = 0 + gt_sim = (gt_sim_matrix.sum() / (n * (n - 1))).item() + print(f"Predict similarity {predict_sim: .4f}, Pairwise similarity: {gt_sim: .4f}") + + +def eval_pubmed_nc(res_path): + data=torch.load(os.path.join(root_path,"dataset/pubmed/processed_data.pt" )) + labels=data.label_texts + short_labels = [l[18:] for l in labels] + ys=data.y.numpy().tolist() + + all_sample=0 + strict_correct=0 + overall_correct=0 + with open(res_path, 'r') as f: + for line in f: + all_sample+=1 + res = json.loads(line) + ans = res["text"] + y=ys[res["question_id"]] + short_label = short_labels[y] + label=labels[y] + if ans.lower().strip() == label.lower().strip(): + strict_correct+=1 + overall_correct+=1 + elif short_label.lower().strip() in ans.lower().strip() and sum([la.lower().strip() in ans.lower().strip() for la in short_labels]) == 1: + overall_correct += 1 + if args.sample > 0 and all_sample >= args.sample: + break + + overall_acc = overall_correct / all_sample + strict_acc = strict_correct / all_sample + print(f"Test samples: {all_sample}\nstrict_acc: {strict_acc:.4f}\noverall_acc: {overall_acc:.4f}") + + +def eval_pubmed_nd(res_path): + data = torch.load(os.path.join(root_path,"dataset/pubmed/processed_data.pt" )) + labels = data.label_texts + short_labels = [l[18:] for l in labels] + ys = data.y.numpy().tolist() + + titles = data.title + abs = data.abs + + all_sample=0 + short_correct=0 + all_correct=0 + gt=[] + out=[] + with open(res_path, 'r') as f: + for line in f: + all_sample+=1 + res = json.loads(line) + ans = res["text"] + id=res["question_id"] + y=ys[id] + short_label = short_labels[y] + label=labels[y] + 
if label.strip() in ans.strip(): + all_correct+=1 + if short_label in ans: + short_correct+=1 + out.append(ans) + gt.append(f"This is a paper in {label} domain, it's about {titles[id]}.") + short_acc = short_correct/all_sample + all_acc = all_correct / all_sample + print(f"Test samples: {all_sample}\nshort_correct: {short_correct}\nshort_acc: {short_acc:.4f}\nall_correct: {all_correct}\nall_acc: {all_acc:.4f}") + gt_embedding = get_sbert_embedding("sbert", gt, 0) + out_embedding = get_sbert_embedding("sbert", out, 0) + gt_embedding=F.normalize(gt_embedding, p=2, eps=1e-6, dim=1) + out_embedding=F.normalize(out_embedding, p=2, eps=1e-6, dim=1) + predict_sim=(gt_embedding*out_embedding).sum(1).mean().item() + gt_sim_matrix=torch.mm(gt_embedding, gt_embedding.transpose(0, 1)).detach().cpu() + n=gt_sim_matrix.shape[0] + gt_sim_matrix[torch.eye(n, dtype=torch.bool)]=0 + gt_sim=(gt_sim_matrix.sum()/(n*(n-1))).item() + print(f"Predict similarity {predict_sim: .4f}, Pairwise similarity: {gt_sim: .4f}") + + +def eval_cora_nc(res_path): + data=torch.load(os.path.join(root_path,"dataset/cora/processed_data.pt" )) + labels=data.label_texts + short_labels = [l.split('_')[0] for l in labels] + ys=data.y.numpy().tolist() + + all_sample=0 + correct=0 + with open(res_path, 'r') as f: + for line in f: + all_sample+=1 + res = json.loads(line) + ans = res["text"] + y=ys[res["question_id"]] + label=labels[y] + short_label=short_labels[y] + if short_label.strip().lower() in ans.strip().lower() and sum([l.strip().lower() in ans.strip().lower() for l in short_labels])==1: + correct+=1 + acc=correct/all_sample + print(f"Test samples: {all_sample}\nacc: {acc:.4f}") + + + +def eval_cora_nd(res_path): + data = torch.load(os.path.join(root_path,"dataset/cora/processed_data.pt" )) + labels = data.label_texts + ys = data.y.numpy().tolist() + + titles = data.title + all_sample=0 + short_correct=0 + all_correct=0 + gt=[] + out=[] + with open(res_path, 'r') as f: + for line in f: + all_sample+=1 + res = json.loads(line) + ans = res["text"] + id=res["question_id"] + y=ys[id] + label=labels[y] + if label.strip() in ans.strip(): + all_correct+=1 + short_correct+=1 + out.append(ans) + gt.append(f"This is a paper in {label} domain, it's about {titles[id]}.") + short_acc = short_correct/all_sample + all_acc = all_correct / all_sample + print(f"Test samples: {all_sample}\nshort_correct: {short_correct}\nshort_acc: {short_acc:.4f}\nall_correct: {all_correct}\nall_acc: {all_acc:.4f}") + gt_embedding = get_sbert_embedding("sbert", gt, 0) + out_embedding = get_sbert_embedding("sbert", out, 0) + gt_embedding=F.normalize(gt_embedding, p=2, eps=1e-6, dim=1) + out_embedding=F.normalize(out_embedding, p=2, eps=1e-6, dim=1) + predict_sim=(gt_embedding*out_embedding).sum(1).mean().item() + gt_sim_matrix=torch.mm(gt_embedding, gt_embedding.transpose(0, 1)).detach().cpu() + n=gt_sim_matrix.shape[0] + gt_sim_matrix[torch.eye(n, dtype=torch.bool)]=0 + gt_sim=(gt_sim_matrix.sum()/(n*(n-1))).item() + print(f"Predict similarity {predict_sim: .4f}, Pairwise similarity: {gt_sim: .4f}") + + + + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument("--res_path", type=str, default="./results/llaga-opt-2.7b-v1-simteg_all_origin_tape_multihop-laplacian_-1-2-10-linear-only-train-pretrain_acc1_nc_test_nc.jsonl") + parser.add_argument("--task", type=str, default="nc") + parser.add_argument("--dataset", type=str, default="arxiv") + parser.add_argument("--sample", type=int, default=-1) + args = parser.parse_args() + + 
func_dict = { + "arxiv":{ + "nc": eval_arxiv_nc, + "nd": eval_arxiv_nd, + "lp": eval_lp, + "lprank": eval_lprank + }, + "products": { + "nc": eval_products_nc, + "nd": eval_products_nd, + "lp": eval_lp, + "lprank": eval_lprank + }, + "pubmed": { + "nc": eval_pubmed_nc, + "nd": eval_pubmed_nd, + "lp": eval_lp, + "lprank": eval_lprank + }, + "cora": { + "nc": eval_cora_nc, + "nd": eval_cora_nd, + "lp": eval_lp, + "lprank": eval_lprank + }, + } + + func=func_dict[args.dataset][args.task] + func(args.res_path) \ No newline at end of file diff --git a/examples/llaga/llaga_trainer.py b/examples/llaga/llaga_trainer.py new file mode 100644 index 00000000..db826585 --- /dev/null +++ b/examples/llaga/llaga_trainer.py @@ -0,0 +1,332 @@ +import os +# os.environ['CUDA_VISIBLE_DEVICES'] = '0' +# os.environ['TL_BACKEND'] = 'torch' +os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' +# 0:Output all; 1:Filter out INFO; 2:Filter out INFO and WARNING; 3:Filter out INFO, WARNING, and ERROR + +import argparse + +import os +import torch +import json +from tqdm import tqdm +import shortuuid + +from gammagl.utils.gfm_utils import GRAPH_TOKEN_INDEX, DEFAULT_GRAPH_TOKEN, DEFAULT_GRAPH_PAD_ID, DEFAULT_GRAPH_START_TOKEN, DEFAULT_GRAPH_END_TOKEN +from gammagl.utils.conversation import conv_templates, SeparatorStyle +from gammagl.utils.builder import load_pretrained_model + +from gammagl.utils.gfm_utils import disable_torch_init, tokenizer_graph_token, get_model_name_from_path + +from gammagl.utils import k_hop_subgraph, degree, remove_self_loops, add_self_loops +from gammagl.layers.conv import MessagePassing +import math + +SMALL_DATASETS=["pubmed", "cora"] + +class MP(MessagePassing): + def __init__(self): + super().__init__(aggr='add') # "Add" aggregation (Step 5). + def message(self, x_j, norm): + return norm.view(-1, 1) * x_j + +def split_list(lst, n): + """Split a list into n (roughly) equal-sized chunks""" + chunk_size = math.ceil(len(lst) / n) # integer division + return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)] + + +# def get_chunk(lst, n, k): +# chunks = split_list(lst, n) +# return chunks[k] + +def load_pretrain_embedding_graph(data_dir, pretrained_embedding_type): + if pretrained_embedding_type == "simteg": + simteg_sbert = torch.load(os.path.join(data_dir, "simteg_sbert_x.pt")) + simteg_roberta = torch.load(os.path.join(data_dir, "simteg_roberta_x.pt")) + simteg_e5 = torch.load(os.path.join(data_dir, "simteg_e5_x.pt")) + pretrained_emb = torch.concat([simteg_sbert, simteg_roberta, simteg_e5], dim=-1) + else: + pretrained_emb = torch.load(os.path.join(data_dir, f"{pretrained_embedding_type}_x.pt")) + return pretrained_emb + +def load_pretrain_embedding_hop(data_dir, pretrained_embedding_type, hop, mask): + if pretrained_embedding_type == "simteg": + simteg_sbert=[torch.load(os.path.join(data_dir, f"simteg_sbert_x.pt"))[mask]] + [torch.load(os.path.join(data_dir, f"simteg_sbert_{i}hop_x.pt"))[mask] for i in range(1, hop + 1)] + simteg_roberta = [torch.load(os.path.join(data_dir, f"simteg_roberta_x.pt"))[mask]] + [torch.load(os.path.join(data_dir, f"simteg_roberta_{i}hop_x.pt"))[mask] for i in range(1, hop + 1)] + simteg_e5 = [torch.load(os.path.join(data_dir, f"simteg_e5_x.pt"))[mask]] + [torch.load(os.path.join(data_dir, f"simteg_e5_{i}hop_x.pt"))[mask] for i in range(1, hop + 1)] + pretrained_embs = [torch.cat([simteg_sbert[i], simteg_roberta[i], simteg_e5[i]], dim=-1) for i in range(hop + 1)] + else: + pretrained_embs = [torch.load(os.path.join(data_dir, 
f"{pretrained_embedding_type}_x.pt"))[mask]]+ [torch.load(os.path.join(data_dir, f"{pretrained_embedding_type}_{i}hop_x.pt"))[mask] for i in range(1, hop+1)] + + return pretrained_embs + +def load_pretrain_embedding_hop_lp(data_dir, pretrained_embedding_type, hop): + mask = torch.load(os.path.join(data_dir, f"no_test_link_mask.pt")) + if pretrained_embedding_type == "simteg": + simteg_sbert=[torch.load(os.path.join(data_dir, f"simteg_sbert_x.pt"))[mask]] + [torch.load(os.path.join(data_dir, f"simteg_sbert_{i}hop_x_notestlink.pt")) for i in range(1, hop + 1)] + simteg_roberta = [torch.load(os.path.join(data_dir, f"simteg_roberta_x.pt"))[mask]] + [torch.load(os.path.join(data_dir, f"simteg_roberta_{i}hop_x_notestlink.pt")) for i in range(1, hop + 1)] + simteg_e5 = [torch.load(os.path.join(data_dir, f"simteg_e5_x.pt"))[mask]] + [torch.load(os.path.join(data_dir, f"simteg_e5_{i}hop_x_notestlink.pt")) for i in range(1, hop + 1)] + pretrained_embs = [torch.cat([simteg_sbert[i], simteg_roberta[i], simteg_e5[i]], dim=-1) for i in range(hop + 1)] + else: + pretrained_embs = [torch.load(os.path.join(data_dir, f"{pretrained_embedding_type}_x.pt"))[mask]]+ [torch.load(os.path.join(data_dir, f"{pretrained_embedding_type}_{i}hop_x_notestlink.pt")) for i in range(1, hop+1)] + + return pretrained_embs, mask + +def eval_model(args): + # Model + disable_torch_init() + + model_path = os.path.expanduser(args.model_path) + model_name = get_model_name_from_path(model_path) + print(f"Loaded from {model_path}. Model Base: {args.model_base}") + tokenizer, model, context_len = load_pretrained_model(model_path, args.model_base, model_name, + cache_dir=args.cache_dir) + model = model.to(torch.float16).cuda() + # data_dir=os.path.expanduser(args.data_dir) + if args.dataset == "arxiv": + data_dir = "dataset/ogbn-arxiv" + elif args.dataset == "products": + data_dir = "dataset/ogbn-products" + elif args.dataset == "pubmed": + data_dir = "dataset/pubmed" + elif args.dataset == "cora": + data_dir = "dataset/cora" + else: + print(f"{args.dataset} not exists") + raise ValueError + data_dir = os.path.join("/home/yy3/graph-text-align/LLaGA", data_dir) + if args.task in ["nc", "nd", "nda", "nctext"]: + if args.template == "HO": + prompt_file = os.path.join(data_dir, f"sampled_2_10_test.jsonl") + else: + prompt_file = os.path.join(data_dir, f"sampled_{args.use_hop}_{args.sample_neighbor_size}_test.jsonl") + data_path = os.path.join(data_dir, f"processed_data.pt") + elif args.task in ["lp"]: + if args.template == "HO": + prompt_file = os.path.join(data_dir, f"edge_sampled_2_10_only_test.jsonl") + else: + prompt_file = os.path.join(data_dir, f"edge_sampled_{args.use_hop}_{args.sample_neighbor_size}_only_test.jsonl") + data_path = os.path.join(data_dir, f"processed_data.pt") + else: + raise ValueError + + data = torch.load(data_path) + print(f"Load from {prompt_file}\n") + lines = open(prompt_file, "r").readlines() + + if args.start >= 0: + if args.end < 0: + args.end = len(lines) + lines = lines[args.start:args.end] + elif args.end > 0: + lines = lines[:args.end] + + answers_file = os.path.expanduser(args.answers_file) + os.makedirs(os.path.dirname(answers_file), exist_ok=True) + # FIXME + # if "tmp" not in args.answers_file and os.path.exists(answers_file): + # line_number = len(open(answers_file, 'r').readlines()) + # print(f"{args.answers_file} already exists! 
it has {line_number} lines!!") + # if line_number >= len(lines): + # return + # lines = lines[line_number:] + # ans_file = open(answers_file, "a") + # else: + ans_file = open(answers_file, "w") + + questions = [json.loads(q) for q in lines] + + index = None + if args.template == "ND": + pretrained_emb = load_pretrain_embedding_graph(data_dir, args.pretrained_embedding_type) + structure_emb = torch.load( + f"/home/yy3/graph-text-align/LLaGA/dataset/laplacian_{args.use_hop}_{args.sample_neighbor_size}.pt") + + elif args.template == "HO": + n = data.num_nodes + if args.dataset in SMALL_DATASETS and args.task == "lp": + pretrained_emb = load_pretrain_embedding_graph(data_dir, args.pretrained_embedding_type) + elif args.task == "lp": + # for small dataset, we remove test link during testing + # for large dataset, remove test link and compute embedding may be more memory- and time-consuming , we precompute the embedding + pretrained_emb, mask = load_pretrain_embedding_hop_lp(data_dir, args.pretrained_embedding_type,args.use_hop) + index = torch.full([n], fill_value=n + 1, dtype=torch.long) + test_index = torch.arange(mask.sum()) + index[mask] = test_index + else: + mask = torch.full([n], fill_value=False, dtype=torch.bool) + for q in questions: + idx = q["id"] + if "lp" in args.task: + assert len(idx) == 2 + mask[idx[0]] = True + mask[idx[1]] = True + elif args.task in ["nc", "nd", "nctext"]: + assert isinstance(idx, int) + mask[idx] = True + pretrained_emb = load_pretrain_embedding_hop(data_dir, args.pretrained_embedding_type, args.use_hop, mask) + index = torch.full([n], fill_value=n + 1, dtype=torch.long) + test_index = torch.arange(mask.sum()) + index[mask] = test_index + structure_emb = None + else: + raise ValueError + + + for line in tqdm(questions): + idx = line["id"] + if args.task in ["nd", "nda"]: + qs=f"Please briefly describe the center node of {DEFAULT_GRAPH_TOKEN}." + elif args.task == "nc": + if args.dataset == "products": + qs = f"Given a node-centered graph: {DEFAULT_GRAPH_TOKEN}, where nodes represent products sold in Amazon, and edges between products indicate they are purchased together. We need to classify the center node into 47 classes: Home & Kitchen, Health & Personal Care, Beauty, Sports & Outdoors, Books, Patio, Lawn & Garden, Toys & Games, CDs & Vinyl, Cell Phones & Accessories, Grocery & Gourmet Food, Arts, Crafts & Sewing, Clothing, Shoes & Jewelry, Electronics, Movies & TV, Software, Video Games, Automotive, Pet Supplies, Office Products, Industrial & Scientific, Musical Instruments, Tools & Home Improvement, Magazine Subscriptions, Baby Products, label 25, Appliances, Kitchen & Dining, Collectibles & Fine Art, All Beauty, Luxury Beauty, Amazon Fashion, Computers, All Electronics, Purchase Circles, MP3 Players & Accessories, Gift Cards, Office & School Supplies, Home Improvement, Camera & Photo, GPS & Navigation, Digital Music, Car Electronics, Baby, Kindle Store, Buy a Kindle, Furniture & Décor, #508510, please tell me which class the center node belongs to?" + else: + qs = line["conversations"][0]['value'] + elif args.task == "nctext": + text = data.raw_texts[line['id']] + text = text[:2000] + if args.dataset == "arxiv": + qs = f"Given a node-centered graph: {DEFAULT_GRAPH_TOKEN}, where nodes represent papers and edges represent co-citations, the node feature of center node is {text}. 
We need to classify the center node into 40 classes: cs.NA(Numerical Analysis), cs.MM(Multimedia), cs.LO(Logic in Computer Science), cs.CY(Computers and Society), cs.CR(Cryptography and Security), cs.DC(Distributed, Parallel, and Cluster Computing), cs.HC(Human-Computer Interaction), cs.CE(Computational Engineering, Finance, and Science), cs.NI(Networking and Internet Architecture), cs.CC(Computational Complexity), cs.AI(Artificial Intelligence), cs.MA(Multiagent Systems), cs.GL(General Literature), cs.NE(Neural and Evolutionary Computing), cs.SC(Symbolic Computation), cs.AR(Hardware Architecture), cs.CV(Computer Vision and Pattern Recognition), cs.GR(Graphics), cs.ET(Emerging Technologies), cs.SY(Systems and Control), cs.CG(Computational Geometry), cs.OH(Other Computer Science), cs.PL(Programming Languages), cs.SE(Software Engineering), cs.LG(Machine Learning), cs.SD(Sound), cs.SI(Social and Information Networks), cs.RO(Robotics), cs.IT(Information Theory), cs.PF(Performance), cs.CL(Computational Complexity), cs.IR(Information Retrieval), cs.MS(Mathematical Software), cs.FL(Formal Languages and Automata Theory), cs.DS(Data Structures and Algorithms), cs.OS(Operating Systems), cs.GT(Computer Science and Game Theory), cs.DB(Databases), cs.DL(Digital Libraries), cs.DM(Discrete Mathematics), please tell me which class the center node belongs to? Direct tell me the class name." + elif args.dataset == "products": + qs = f"Given a node-centered graph: {DEFAULT_GRAPH_TOKEN}, where nodes represent products sold in Amazon, and edges between products indicate they are purchased together, the node feature of center node is {text}. We need to classify the center node into 47 classes: Home & Kitchen, Health & Personal Care, Beauty, Sports & Outdoors, Books, Patio, Lawn & Garden, Toys & Games, CDs & Vinyl, Cell Phones & Accessories, Grocery & Gourmet Food, Arts, Crafts & Sewing, Clothing, Shoes & Jewelry, Electronics, Movies & TV, Software, Video Games, Automotive, Pet Supplies, Office Products, Industrial & Scientific, Musical Instruments, Tools & Home Improvement, Magazine Subscriptions, Baby Products, label 25, Appliances, Kitchen & Dining, Collectibles & Fine Art, All Beauty, Luxury Beauty, Amazon Fashion, Computers, All Electronics, Purchase Circles, MP3 Players & Accessories, Gift Cards, Office & School Supplies, Home Improvement, Camera & Photo, GPS & Navigation, Digital Music, Car Electronics, Baby, Kindle Store, Buy a Kindle, Furniture & Décor, #508510, please tell me which class the center node belongs to? Direct tell me the class name." + elif args.dataset == "pubmed": + qs = f"Given a node-centered graph: {DEFAULT_GRAPH_TOKEN}, where nodes represent papers about Diabetes and edges represent co-citations, the node feature of center node is {text}. We need to classify the center node into 3 classes: Diabetes Mellitus Experimental, Diabetes Mellitus Type1, Diabetes Mellitus Type2, please tell me which class the center node belongs to? Direct tell me the class name." + elif args.dataset == "cora": + qs = f"Given a node-centered graph: {DEFAULT_GRAPH_TOKEN}, where nodes represent papers and edges represent co-citations, the node feature of center node is {text}. We need to classify the center node into 7 classes: Case_Based, Genetic_Algorithms, Neural_Networks, Probabilistic_Methods, Reinforcement_Learning, Rule_Learning, Theory, please tell me which class the center node belongs to? Direct tell me the class name." 
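+            # Note: each prompt embeds DEFAULT_GRAPH_TOKEN as a placeholder; tokenizer_graph_token
+            # later replaces it with GRAPH_TOKEN_INDEX so the projected graph embeddings can be
+            # spliced into those positions inside the model.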
+ else: + raise ValueError + elif args.task == "lp": + qs=f"Given two node-centered subgraphs: {DEFAULT_GRAPH_TOKEN} and {DEFAULT_GRAPH_TOKEN}, we need to predict whether these two nodes connect with each other. Please tell me whether two center nodes in the subgraphs should connect to each other." + else: + print(f"NOT SUPPORT {args.task}!!!") + raise ValueError + cur_prompt = qs + + conv = conv_templates[args.conv_mode].copy() + conv.append_message(conv.roles[0], qs) + conv.append_message(conv.roles[1], None) + prompt = conv.get_prompt() + + input_ids = tokenizer_graph_token(prompt, tokenizer, GRAPH_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda() + if not isinstance(line['graph'][0], list): + line['graph'] = [line['graph']] + if args.template == "ND": + graph = torch.LongTensor(line['graph']) + mask = graph != DEFAULT_GRAPH_PAD_ID + masked_graph_emb = pretrained_emb[graph[mask]] + s, n, d = graph.shape[0], graph.shape[1], masked_graph_emb.shape[1] + graph_emb = torch.zeros((s, n, d)) + graph_emb[mask] = masked_graph_emb + if structure_emb is not None: + graph_emb = torch.cat([graph_emb, structure_emb.unsqueeze(0).expand(s, -1, -1)], dim=-1) + elif args.template == "HO": + # for small dataset, we remove test link during testing + # for large dataset, remove test link and compute embedding may be more memory- and time-consuming , we precompute the embedding + + if args.dataset in SMALL_DATASETS and args.task == "lp": + mp = MP() + center_nodes = [] + for g in range(len(line['graph'])): + center_id = line['graph'][g][0] + line['graph'][g] = [center_id] * (args.use_hop + 1) + center_nodes.append(center_id) + graph = torch.LongTensor(line['graph']) + center_id = graph[:, 0] + graph_embs = [pretrained_emb[center_id].cuda()] + subset, edge_index, mapping, edge_mask = k_hop_subgraph(center_nodes, args.use_hop, data.edge_index, + relabel_nodes=True) + local_edge_mask = ((edge_index[0] == mapping[0]) & (edge_index[1] == mapping[1])) | ( + (edge_index[0] == mapping[1]) & (edge_index[1] == mapping[0])) + edge_index = edge_index[:, ~local_edge_mask] + local_x = pretrained_emb[subset].cuda() + n = subset.shape[0] + edge_index, _ = remove_self_loops(edge_index) + edge_index, _ = add_self_loops(edge_index) + edge_index = edge_index.cuda() + row, col = edge_index + deg = degree(col, n, dtype=pretrained_emb.dtype) + deg_inv_sqrt = deg.pow(-0.5) + deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0 + norm = deg_inv_sqrt[row] * deg_inv_sqrt[col] + # local_x = pretrained_emb + for _ in range(args.use_hop): + local_x = mp.propagate(edge_index, x=local_x, norm=norm) + graph_embs.append(local_x[mapping]) + graph_emb = torch.stack(graph_embs, dim=1) + else: + + for g in range(len(line['graph'])): + center_id = line['graph'][g][0] + line['graph'][g] = [center_id]*(args.use_hop+1) + graph = torch.LongTensor(line['graph']) + center_id = graph[:, 0] + graph_emb = torch.stack([emb[index[center_id]] for emb in pretrained_emb], dim=1) + else: + raise ValueError + + + stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 + + # try: + with torch.inference_mode(): + output_ids = model.generate( + input_ids, + graph_emb=graph_emb.half().cuda(), + graph=graph.cuda(), + do_sample=True, + temperature=args.temperature, + top_p=args.top_p, + num_beams=args.num_beams, + # no_repeat_ngram_size=3, + max_new_tokens=1024, + use_cache=True) + + input_token_len = input_ids.shape[1] + n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item() + if n_diff_input_output > 0: + print(f'[Warning] 
{n_diff_input_output} output_ids are not the same as the input_ids') + outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0] + outputs = outputs.strip() + if outputs.endswith(stop_str): + outputs = outputs[:-len(stop_str)] + outputs = outputs.strip() + # except Exception as e: + # print(f"!!!!!!Error!!!!! {e}") + # outputs="" + + ans_id = shortuuid.uuid() + ans_file.write(json.dumps({"question_id": idx, + "prompt": cur_prompt, + "graph": line['graph'], + "text": outputs, + "gt":line["conversations"][1]['value'], + "answer_id": ans_id}) + "\n") + ans_file.flush() + ans_file.close() + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model_path", type=str, default="facebook/opt-350m") + parser.add_argument("--model_base", type=str, default=None) + # parser.add_argument("--data_dir", type=str, default=None) + parser.add_argument("--pretrained_embedding_type", type=str, default="sbert") + parser.add_argument("--use_hop", type=int, default=2) + parser.add_argument("--sample_neighbor_size", type=int, default=5) + parser.add_argument("--answers_file", type=str, default="answer.jsonl") + parser.add_argument("--conv_mode", type=str, default="v1") + parser.add_argument("--temperature", type=float, default=0.2) + parser.add_argument("--top_p", type=float, default=None) + parser.add_argument("--num_beams", type=int, default=1) + parser.add_argument("--prompt", type=str, default=None) + parser.add_argument("--start", type=int, default=-1) + parser.add_argument("--end", type=int, default=-1) + parser.add_argument("--test_path", type=str, default=None) + parser.add_argument("--mm_use_graph_start_end",default=False, action="store_true") + parser.add_argument("--task", type=str, default="nc") + parser.add_argument("--dataset", type=str, default="arxiv") + parser.add_argument("--cache_dir", type=str, default="../../checkpoint") + parser.add_argument("--template", type=str, default="ND") + args = parser.parse_args() + + eval_model(args) diff --git a/gammagl/models/llaga.py b/gammagl/models/llaga.py new file mode 100644 index 00000000..994eeebd --- /dev/null +++ b/gammagl/models/llaga.py @@ -0,0 +1,582 @@ +# Copyright 2023 Haotian Liu +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
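+
+# LLaGA building blocks for GammaGL: build_graph_projector maps pretrained node
+# embeddings into the language model's embedding space; LlagaMetaModel and
+# LlagaMetaForCausalLM splice the projected graph features into the token-embedding
+# sequence at graph-token positions; the LLaMA-based classes below are registered
+# under the "llaga" model type.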
+ + +from abc import ABC, abstractmethod + +import tensorlayerx as tlx +import tensorlayerx.nn as nn +import torch +import re + +from gammagl.utils.gfm_utils import GRAPH_TOKEN_INDEX, DEFAULT_GRAPH_TOKEN, DEFAULT_GRAPH_PAD_ID, DEFAULT_GRAPH_START_TOKEN, DEFAULT_GRAPH_END_TOKEN, IGNORE_INDEX + +import math + +def build_graph_projector(config, delay_load=False, **kwargs): + projector_type = getattr(config, 'mm_projector_type', 'linear') + + hidden_dim = getattr(config, 'word_embed_proj_dim', getattr(config, 'hidden_size', 'linear')) + + if projector_type == 'linear': + return nn.Linear(in_features=config.mm_hidden_size, out_features=hidden_dim) + mlp_gelu_match = re.match(r'^(\d+)-layer-mlp$', projector_type) + if mlp_gelu_match: + mlp_depth = int(mlp_gelu_match.group(1)) + modules = [torch.nn.Linear(in_features=config.mm_hidden_size, out_features=hidden_dim)] + for _ in range(1, mlp_depth): + modules.append(torch.nn.GELU()) + modules.append(torch.nn.Linear(in_features=hidden_dim, out_features=hidden_dim)) + return nn.Sequential(*modules) + else: + raise ValueError(f'Unknown projector type: {projector_type}') + + + +class LlagaMetaModel: + + def __init__(self, config): + super(LlagaMetaModel, self).__init__(config) + + if hasattr(config, "mm_hidden_size"): + self.mm_projector = build_graph_projector(config) + if hasattr(config, "mm_use_graph_special_token") and getattr(config, 'mm_use_graph_special_token', False): + self.special_token_emb = self.build_special_tokens() + + + def initialize_graph_modules(self, model_args, fsdp=None): + pretrain_mm_mlp_adapter = getattr(model_args, 'pretrain_mm_mlp_adapter', None) + + self.config.use_mm_proj = True + self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear') + self.config.mm_hidden_size = getattr(model_args, 'mm_hidden_size') + + + self.mm_projector = build_graph_projector(self.config) + if hasattr(self.config, "mm_use_graph_special_token") and getattr(self.config, 'mm_use_graph_special_token', False): + self.special_token_emb = self.build_special_tokens() + + # TODO: implement model load in ggl + if pretrain_mm_mlp_adapter is not None: + mm_projector_weights = tlx.load(pretrain_mm_mlp_adapter, map_location='cpu') + def get_w(weights, keyword): + return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k} + + self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector')) + + def build_special_tokens(self): + if hasattr(self.config, "mm_use_graph_special_token") and getattr(self.config, 'mm_use_graph_special_token', False): + num_token=self.config.use_hop+2 + input_embeddings = self.get_input_embeddings().weight.data + input_embeddings_avg = input_embeddings.mean(dim=0, keepdim=True).unsqueeze(1).detach() + special_token_emb=torch.nn.Parameter(data=input_embeddings_avg.repeat(num_token, 1, 1), requires_grad=True) + return special_token_emb + return None + +class LlagaMetaForCausalLM(ABC): + + @abstractmethod + def get_model(self): + pass + + def encode_graphs(self, graph, graph_emb): + graph_features = self.get_model().mm_projector(graph_emb) + graph_features[graph==DEFAULT_GRAPH_PAD_ID] = 0. 
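+        # Rows of graph_features are node embeddings projected into the language model's
+        # hidden space by mm_projector; slots filled with DEFAULT_GRAPH_PAD_ID are zeroed
+        # out above so padded neighbors contribute nothing downstream.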
+ return graph_features + + def inject_special_token(self, graph_emb): + use_hop=self.config.use_hop + sample_size = self.config.sample_neighbor_size + assert graph_emb.shape[-2] == int((sample_size ** (use_hop + 1) - 1) / (sample_size - 1)) + assert self.model.special_token_emb.shape[0] == use_hop + 2 + new_graph_emb = [] + new_graph_emb.append(self.model.special_token_emb[0]) + cur=0 + for i in range(use_hop+1): + cur_size = sample_size**i + new_graph_emb.append(graph_emb[cur:cur+cur_size]) + cur+=cur_size + new_graph_emb.append(self.model.special_token_emb[i+1]) + new_graph_emb = tlx.concat(new_graph_emb, axis=0) + return new_graph_emb + + def prepare_inputs_labels_for_multimodal( + self, input_ids, attention_mask, past_key_values, labels, graphs, graph_emb + ): + if past_key_values is not None and graphs is not None and input_ids.shape[1] == 1: + attention_mask = tlx.ones((attention_mask.shape[0], past_key_values[-1][-1].shape[-2] + 1), + dtype=attention_mask.dtype, device=attention_mask.device) + return input_ids, attention_mask, past_key_values, None, labels + + graph_features = self.encode_graphs(graphs, graph_emb) + + new_input_embeds = [] + new_labels = [] if labels is not None else None + cur_graph_idx = 0 + for batch_idx, cur_input_ids in enumerate(input_ids): + if (cur_input_ids == GRAPH_TOKEN_INDEX).sum() == 0: + # multimodal LLM, but the current sample is not multimodal + # FIXME: this is a hacky fix, for deepspeed zero3 to work + half_len = cur_input_ids.shape[0] // 2 + cur_graph_features = graph_features[cur_graph_idx] + cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids[:half_len]) + cur_input_embeds_2 = self.get_model().embed_tokens(cur_input_ids[half_len:]) + cur_input_embeds = tlx.concat([cur_input_embeds_1, cur_graph_features[0:0], cur_input_embeds_2], axis=0) + new_input_embeds.append(cur_input_embeds) + if labels is not None: + new_labels.append(labels[batch_idx]) + cur_graph_idx += 1 + continue + graph_token_indices = (cur_input_ids==GRAPH_TOKEN_INDEX).nonzero().squeeze(dim=0) + cur_new_input_embeds = [] + if labels is not None: + cur_labels = labels[batch_idx] + cur_new_labels = [] + assert cur_labels.shape == cur_input_ids.shape + while graph_token_indices.numel() > 0: # 分段处理graph token,把graph feature插入到对应位置,拼接成新的input_embeds + cur_graph_features = graph_features[cur_graph_idx] + if hasattr(self.config, "mm_use_graph_special_token") and getattr(self.config, 'mm_use_graph_special_token', False): + cur_graph_features = self.inject_special_token(cur_graph_features) + + graph_token_start = graph_token_indices[0] + if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_graph_start_end', False): + cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[:graph_token_start-1]).detach()) + cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[graph_token_start-1:graph_token_start])) + cur_new_input_embeds.append(cur_graph_features) + cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[graph_token_start+1:graph_token_start+2])) + if labels is not None: + cur_new_labels.append(cur_labels[:graph_token_start]) + cur_new_labels.append(torch.full((cur_graph_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype)) + cur_new_labels.append(cur_labels[graph_token_start:graph_token_start+1]) + cur_labels = cur_labels[graph_token_start+2:] + else: + cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[:graph_token_start])) + 
cur_new_input_embeds.append(cur_graph_features) + if labels is not None: + cur_new_labels.append(cur_labels[:graph_token_start]) + cur_new_labels.append(torch.full((cur_graph_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype)) + cur_labels = cur_labels[graph_token_start+1:] + cur_graph_idx += 1 + if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_graph_start_end', False): + cur_input_ids = cur_input_ids[graph_token_start+2:] + else: + cur_input_ids = cur_input_ids[graph_token_start+1:] + graph_token_indices = (cur_input_ids==GRAPH_TOKEN_INDEX).nonzero().squeeze(dim=0) + if cur_input_ids.numel() > 0: + if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_graph_start_end', False): + cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids).detach()) + else: + cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids)) + if labels is not None: + cur_new_labels.append(cur_labels) + cur_new_input_embeds = [x.to(device=self.device) for x in cur_new_input_embeds] + cur_new_input_embeds = tlx.concat(cur_new_input_embeds, axis=0) + new_input_embeds.append(cur_new_input_embeds) + if labels is not None: + cur_new_labels = tlx.concat(cur_new_labels, axis=0) + new_labels.append(cur_new_labels) + + if any(x.shape != new_input_embeds[0].shape for x in new_input_embeds): + max_len = max(x.shape[0] for x in new_input_embeds) + + new_input_embeds_align = [] + for cur_new_embed in new_input_embeds: + cur_new_embed = tlx.concat((cur_new_embed, tlx.zeros((max_len - cur_new_embed.shape[0], cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)), axis=0) + new_input_embeds_align.append(cur_new_embed) + new_input_embeds = tlx.stack(new_input_embeds_align, axis=0) + + if labels is not None: + new_labels_align = [] + _new_labels = new_labels + for cur_new_label in new_labels: + cur_new_label = tlx.concat((cur_new_label, torch.full((max_len - cur_new_label.shape[0],), IGNORE_INDEX, dtype=cur_new_label.dtype, device=cur_new_label.device)), axis=0) + new_labels_align.append(cur_new_label) + new_labels = tlx.stack(new_labels_align, axis=0) + + if attention_mask is not None: + new_attention_mask = [] + for cur_attention_mask, cur_new_labels, cur_new_labels_align in zip(attention_mask, _new_labels, new_labels): + new_attn_mask_pad_left = torch.full((cur_new_labels.shape[0] - labels.shape[1],), True, dtype=attention_mask.dtype, device=attention_mask.device) + new_attn_mask_pad_right = torch.full((cur_new_labels_align.shape[0] - cur_new_labels.shape[0],), False, dtype=attention_mask.dtype, device=attention_mask.device) + cur_new_attention_mask = tlx.concat((new_attn_mask_pad_left, cur_attention_mask, new_attn_mask_pad_right), axis=0) + new_attention_mask.append(cur_new_attention_mask) + attention_mask = tlx.stack(new_attention_mask) + assert attention_mask.shape == new_labels.shape + else: + new_input_embeds = tlx.stack(new_input_embeds, axis=0) + if labels is not None: + new_labels = tlx.stack(new_labels, axis=0) + + if attention_mask is not None: + new_attn_mask_pad_left = torch.full((attention_mask.shape[0], new_input_embeds.shape[1] - input_ids.shape[1]), True, dtype=attention_mask.dtype, device=attention_mask.device) + attention_mask = tlx.concat((new_attn_mask_pad_left, attention_mask), axis=1) + assert attention_mask.shape == new_input_embeds.shape[:2] + + return None, attention_mask, past_key_values, new_input_embeds, new_labels + + + def 
prepare_inputs_labels_for_multimodal_with_pad_mask( + self, input_ids, attention_mask, past_key_values, labels, graphs, graph_emb + ): + if past_key_values is not None and graphs is not None and input_ids.shape[1] == 1: + attention_mask = tlx.ones((attention_mask.shape[0], past_key_values[-1][-1].shape[-2] + 1), + dtype=attention_mask.dtype, device=attention_mask.device) + return input_ids, attention_mask, past_key_values, None, labels + + graph_features = self.encode_graphs(graphs, graph_emb) + + new_input_embeds = [] + new_labels = [] if labels is not None else None + new_attention_masks = [] + cur_graph_idx = 0 + for batch_idx, cur_input_ids in enumerate(input_ids): + cur_attention_mask = attention_mask[batch_idx] + if (cur_input_ids == GRAPH_TOKEN_INDEX).sum() == 0: + # multimodal LLM, but the current sample is not multimodal + # FIXME: this is a hacky fix, for deepspeed zero3 to work + half_len = cur_input_ids.shape[0] // 2 + cur_graph_features = graph_features[cur_graph_idx] + cur_graph = graphs[cur_graph_idx] + cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids[:half_len]) + cur_input_embeds_2 = self.get_model().embed_tokens(cur_input_ids[half_len:]) + cur_input_embeds = tlx.concat([cur_input_embeds_1, cur_graph_features[0:0], cur_input_embeds_2], dim=0) + new_input_embeds.append(cur_input_embeds) + if labels is not None: + new_labels.append(labels[batch_idx]) + cur_graph_idx += 1 + continue + graph_token_indices = (cur_input_ids==GRAPH_TOKEN_INDEX).nonzero().squeeze(dim=0) + cur_new_input_embeds = [] + cur_attn_masks=[] + if labels is not None: + cur_labels = labels[batch_idx] + cur_new_labels = [] + assert cur_labels.shape == cur_input_ids.shape + while graph_token_indices.numel() > 0: + cur_graph_features = graph_features[cur_graph_idx] + cur_graph = graphs[cur_graph_idx] + cur_graph_mask = (cur_graph != DEFAULT_GRAPH_PAD_ID) + if hasattr(self.config, "mm_use_graph_special_token") and getattr(self.config, 'mm_use_graph_special_token', False): + cur_graph_features = self.inject_special_token(cur_graph_features) + + graph_token_start = graph_token_indices[0] + if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_graph_start_end', False): + cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[:graph_token_start-1]).detach()) + cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[graph_token_start-1:graph_token_start])) + cur_new_input_embeds.append(cur_graph_features) + cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[graph_token_start+1:graph_token_start+2])) + cur_attn_masks.append(cur_attention_mask[:graph_token_start]) + cur_attn_masks.append(cur_graph_mask) + cur_attn_masks.append(cur_attention_mask[graph_token_start+1:graph_token_start+2]) + if labels is not None: + cur_new_labels.append(cur_labels[:graph_token_start]) + cur_new_labels.append(torch.full((cur_graph_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype)) + cur_new_labels.append(cur_labels[graph_token_start:graph_token_start+1]) + cur_labels = cur_labels[graph_token_start+2:] + else: + cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[:graph_token_start])) + cur_new_input_embeds.append(cur_graph_features) + cur_attn_masks.append(cur_attention_mask[:graph_token_start]) + cur_attn_masks.append(cur_graph_mask) + if labels is not None: + cur_new_labels.append(cur_labels[:graph_token_start]) + cur_new_labels.append(torch.full((cur_graph_features.shape[0],), 
IGNORE_INDEX, device=labels.device, dtype=labels.dtype)) + cur_labels = cur_labels[graph_token_start+1:] + + cur_graph_idx += 1 + if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_graph_start_end', False): + cur_input_ids = cur_input_ids[graph_token_start+2:] + cur_attention_mask = cur_attention_mask[graph_token_start+2:] + else: + cur_input_ids = cur_input_ids[graph_token_start+1:] + cur_attention_mask = cur_attention_mask[graph_token_start + 1:] + graph_token_indices = (cur_input_ids==GRAPH_TOKEN_INDEX).nonzero().squeeze(dim=0) + if cur_input_ids.numel() > 0: + if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_graph_start_end', False): + cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids).detach()) + else: + cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids)) + if labels is not None: + cur_new_labels.append(cur_labels) + cur_attn_masks.append(cur_attention_mask) + cur_new_input_embeds = [x.to(device=self.device) for x in cur_new_input_embeds] + cur_new_input_embeds = tlx.concat(cur_new_input_embeds, dim=0) + cur_attn_masks = [x.to(device=self.device) for x in cur_attn_masks] + cur_attn_masks = tlx.concat(cur_attn_masks, dim=0) + new_input_embeds.append(cur_new_input_embeds) + new_attention_masks.append(cur_attn_masks) + if labels is not None: + cur_new_labels = tlx.concat(cur_new_labels, dim=0) + new_labels.append(cur_new_labels) + + if any(x.shape != new_input_embeds[0].shape for x in new_input_embeds): + max_len = max(x.shape[0] for x in new_input_embeds) + + new_input_embeds_align = [] + for cur_new_embed in new_input_embeds: + cur_new_embed = tlx.concat((cur_new_embed, tlx.zeros((max_len - cur_new_embed.shape[0], cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0) + new_input_embeds_align.append(cur_new_embed) + new_input_embeds = tlx.stack(new_input_embeds_align, axis=0) + + if labels is not None: + new_labels_align = [] + _new_labels = new_labels + for cur_new_label in new_labels: + cur_new_label = tlx.concat((cur_new_label, torch.full((max_len - cur_new_label.shape[0],), IGNORE_INDEX, dtype=cur_new_label.dtype, device=cur_new_label.device)), axis=0) + new_labels_align.append(cur_new_label) + new_labels = tlx.stack(new_labels_align, axis=0) + + if attention_mask is not None: + new_attention_mask = [] + for cur_attention_mask, cur_new_labels, cur_new_labels_align in zip(new_attention_masks, _new_labels, new_labels): + assert cur_attention_mask.shape == cur_new_labels.shape + # new_attn_mask_pad_left = torch.full((cur_new_labels.shape[0] - labels.shape[1],), True, dtype=attention_mask.dtype, device=attention_mask.device) + new_attn_mask_pad_right = torch.full((cur_new_labels_align.shape[0] - cur_new_labels.shape[0],), False, dtype=attention_mask.dtype, device=attention_mask.device) + cur_new_attention_mask = tlx.concat((cur_attention_mask, new_attn_mask_pad_right), axis=0) + new_attention_mask.append(cur_new_attention_mask) + attention_mask = tlx.stack(new_attention_mask, axis=0) + assert attention_mask.shape == new_labels.shape + + else: + new_input_embeds = tlx.stack(new_input_embeds, axis=0) + if labels is not None: + new_labels = tlx.stack(new_labels, axis=0) + + attention_mask = tlx.stack(new_attention_masks, axis=0) + assert attention_mask.shape == new_input_embeds.shape[:2] + # if attention_mask is not None: + # new_attn_mask_pad_left = torch.full((attention_mask.shape[0], new_input_embeds.shape[1] - input_ids.shape[1]), 
True, dtype=attention_mask.dtype, device=attention_mask.device) + # attention_mask = tlx.concat((new_attn_mask_pad_left, attention_mask), dim=1) + # assert attention_mask.shape == new_input_embeds.shape[:2] + + return None, attention_mask, past_key_values, new_input_embeds, new_labels + + def initialize_graph_tokenizer(self, model_args, tokenizer): + + if model_args.mm_use_graph_start_end: + num_new_tokens = tokenizer.add_tokens([DEFAULT_GRAPH_START_TOKEN, DEFAULT_GRAPH_END_TOKEN], special_tokens=True) + self.resize_token_embeddings(len(tokenizer)) + + if num_new_tokens > 0: + input_embeddings = self.get_input_embeddings().weight.data + output_embeddings = self.get_output_embeddings().weight.data + + input_embeddings_avg = input_embeddings[:-num_new_tokens].mean( + dim=0, keepdim=True) + output_embeddings_avg = output_embeddings[:-num_new_tokens].mean( + dim=0, keepdim=True) + + input_embeddings[-num_new_tokens:] = input_embeddings_avg + output_embeddings[-num_new_tokens:] = output_embeddings_avg + + if model_args.tune_mm_mlp_adapter: + for p in self.get_input_embeddings().parameters(): + p.requires_grad = True + for p in self.get_output_embeddings().parameters(): + p.requires_grad = False + + if model_args.pretrain_mm_mlp_adapter: + mm_projector_weights = tlx.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu') + embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight'] + assert num_new_tokens == 2 + if input_embeddings.shape == embed_tokens_weight.shape: + input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:] + elif embed_tokens_weight.shape[0] == num_new_tokens: + input_embeddings[-num_new_tokens:] = embed_tokens_weight + else: + raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.") + +# Copyright 2023 Haotian Liu +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
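+
+# LLaMA integration: LlagaLlamaModel and LlagaLlamaForCausalLM combine the meta classes
+# above with HuggingFace's LlamaModel/LlamaForCausalLM, and the AutoConfig /
+# AutoModelForCausalLM registrations at the end of this file expose them as model
+# type "llaga".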
+ + +from typing import List, Optional, Tuple, Union + + + +from transformers import AutoConfig, AutoModelForCausalLM, \ + LlamaConfig, LlamaModel, LlamaForCausalLM + +from transformers.modeling_outputs import CausalLMOutputWithPast + + +class LlagaConfig(LlamaConfig): + model_type = "llaga" + + +class LlagaLlamaModel(LlagaMetaModel, LlamaModel): + config_class = LlagaConfig + + def __init__(self, config: LlamaConfig): + super(LlagaLlamaModel, self).__init__(config) + + +class LlagaLlamaForCausalLM(LlamaForCausalLM, LlagaMetaForCausalLM): + config_class = LlagaConfig + + def __init__(self, config): + super(LlamaForCausalLM, self).__init__(config) + self.model = LlagaLlamaModel(config) + + self.lm_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_model(self): + return self.model + + def forward( + self, + input_ids: tlx = None, + attention_mask = None, + past_key_values = None, + inputs_embeds = None, + labels = None, + use_cache = None, + output_attentions = None, + output_hidden_states = None, + graph = None, + graph_emb = None, + return_dict = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + # output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_attentions = True + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, graph, graph_emb) + + + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = tlx.losses.binary_cross_entropy(ignore_index=IGNORE_INDEX) + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model/pipeline parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + def forward_no_loss( + self, + input_ids = None, + attention_mask = None, + past_key_values = None, + inputs_embeds = None, + labels = None, + use_cache = None, + output_attentions = None, + output_hidden_states = None, + graph = None, + graph_emb = None, + return_dict = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + # output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_attentions = True + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if 
return_dict is not None else self.config.use_return_dict + + input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, graph, graph_emb) + + + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict + ) + + return CausalLMOutputWithPast( + loss=0, + logits=None, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs + ): + if past_key_values: + input_ids = input_ids[:, -1:] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + "graph": kwargs.get("graph", None), + "graph_emb": kwargs.get("graph_emb", None), + } + ) + return model_inputs + +AutoConfig.register("llaga", LlagaConfig) +AutoModelForCausalLM.register(LlagaConfig, LlagaLlamaForCausalLM) diff --git a/gammagl/utils/builder.py b/gammagl/utils/builder.py new file mode 100644 index 00000000..16a6cfec --- /dev/null +++ b/gammagl/utils/builder.py @@ -0,0 +1,160 @@ +# Copyright 2023 Haotian Liu +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os +import warnings + +from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig +import torch + +from gammagl.models.llaga import * +from huggingface_hub import hf_hub_download + + + + +def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda", cache_dir="../../checkpoint"): + kwargs = {"device_map": device_map} + + if load_8bit: + kwargs['load_in_8bit'] = True + elif load_4bit: + kwargs['load_in_4bit'] = True + kwargs['quantization_config'] = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4' + ) + else: + kwargs['torch_dtype'] = torch.float16 + + if 'llaga' in model_name.lower(): + # Load LLaGA model + if 'lora' in model_name.lower() and model_base is None: + warnings.warn('There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. 
Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged.') + if 'lora' in model_name.lower() and model_base is not None: + lora_cfg_pretrained = AutoConfig.from_pretrained(model_path) + tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) + print('Loading LLaGA from base model...') + model = LlagaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, cache_dir=cache_dir, **kwargs) + token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features + if model.lm_head.weight.shape[0] != token_num: + model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype)) + model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype)) + + print('Loading additional LLaGA weights...') + if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')): + non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu') + else: + # this is probably from HF Hub + from huggingface_hub import hf_hub_download + def load_from_hf(repo_id, filename, subfolder=None): + cache_file = hf_hub_download( + repo_id=repo_id, + filename=filename, + subfolder=subfolder) + return torch.load(cache_file, map_location='cpu') + non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin') + non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()} + if any(k.startswith('model.model.') for k in non_lora_trainables): + non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()} + model.load_state_dict(non_lora_trainables, strict=False) + + from peft import PeftModel + print('Loading LoRA weights...') + model = PeftModel.from_pretrained(model, model_path) + print('Merging LoRA weights...') + model = model.merge_and_unload() + print('Model is loaded...') + elif model_base is not None: + # this may be mm projector only + print('Loading LLaGA from base model...') + # if 'mpt' in model_name.lower(): + # if not os.path.isfile(os.path.join(model_path, 'configuration_mpt.py')): + # shutil.copyfile(os.path.join(model_base, 'configuration_mpt.py'), os.path.join(model_path, 'configuration_mpt.py')) + # tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True) + # cfg_pretrained = AutoConfig.from_pretrained(model_path, trust_remote_code=True) + # model = LlavaMPTForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs) + # else: + # tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) + # cfg_pretrained = AutoConfig.from_pretrained(model_path) + # model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs) + if 'opt' in model_base: + tokenizer = AutoTokenizer.from_pretrained(model_base) + cfg_pretrained = AutoConfig.from_pretrained(model_path) + model = LlagaOPTForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, + cache_dir=cache_dir, + **kwargs) + else: + tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) + cfg_pretrained = AutoConfig.from_pretrained(model_path) + model = LlagaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, cache_dir=cache_dir, + **kwargs) + # model.get_model().initialize_graph_modules(cfg_pretrained) + if 
os.path.exists(os.path.join(model_path, 'mm_projector.bin')): + mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu') + print("Load from local path") + else: + from huggingface_hub import hf_hub_download + model_path_hf = hf_hub_download(repo_id=model_path, filename='mm_projector.bin') + mm_projector_weights = torch.load(model_path_hf, map_location='cpu') + print("Load from huggingface") + mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()} + model.load_state_dict(mm_projector_weights, strict=False) + else: + # if 'mpt' in model_name.lower(): + # tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True) + # model = LlavaMPTForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs) + # else: + # tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) + # model = LlavaLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs) + tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) + model = LlagaLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs) + else: + # Load language model + if model_base is not None: + # PEFT model + from peft import PeftModel + tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) + model = AutoModelForCausalLM.from_pretrained(model_base, torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map="auto", cache_dir=cache_dir) + print(f"Loading LoRA weights from {model_path}") + model = PeftModel.from_pretrained(model, model_path) + print(f"Merging weights") + model = model.merge_and_unload() + print('Convert to FP16...') + model.to(torch.float16) + else: + use_fast = False + if 'mpt' in model_name.lower(): + tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True) + model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs) + else: + tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) + model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, cache_dir=cache_dir, **kwargs) + + + if 'llaga' in model_name.lower(): + mm_use_graph_start_end = getattr(model.config, "mm_use_graph_start_end", False) + if mm_use_graph_start_end: + tokenizer.add_tokens([DEFAULT_GRAPH_START_TOKEN, DEFAULT_GRAPH_END_TOKEN], special_tokens=True) + model.resize_token_embeddings(len(tokenizer)) + + if hasattr(model.config, "max_sequence_length"): + context_len = model.config.max_sequence_length + else: + context_len = 2048 + + return tokenizer, model, context_len diff --git a/gammagl/utils/conversation.py b/gammagl/utils/conversation.py new file mode 100644 index 00000000..e088a9db --- /dev/null +++ b/gammagl/utils/conversation.py @@ -0,0 +1,405 @@ +import dataclasses +from enum import auto, Enum +from typing import List, Tuple + + +class SeparatorStyle(Enum): + """Different separator style.""" + SINGLE = auto() + TWO = auto() + MPT = auto() + PLAIN = auto() + LLAMA_2 = auto() + + +@dataclasses.dataclass +class Conversation: + """A class that keeps all conversation history.""" + system: str + roles: List[str] + messages: List[List[str]] + offset: int + sep_style: SeparatorStyle = SeparatorStyle.SINGLE + sep: str = "###" + sep2: str = None + version: str = "Unknown" + + skip_next: bool = False + + def get_prompt(self): + messages = self.messages + if len(messages) > 0 and type(messages[0][1]) is tuple: + messages = self.messages.copy() + init_role, init_msg = messages[0].copy() + init_msg = 
init_msg[0].replace("", "").strip() + if 'mmtag' in self.version: + messages[0] = (init_role, init_msg) + messages.insert(0, (self.roles[0], "")) + messages.insert(1, (self.roles[1], "Received.")) + else: + messages[0] = (init_role, "\n" + init_msg) + + if self.sep_style == SeparatorStyle.SINGLE: + ret = self.system + self.sep + for role, message in messages: + if message: + if type(message) is tuple: + message, _, _ = message + ret += role + ": " + message + self.sep + else: + ret += role + ":" + elif self.sep_style == SeparatorStyle.TWO: + seps = [self.sep, self.sep2] + ret = self.system + seps[0] + for i, (role, message) in enumerate(messages): + if message: + if type(message) is tuple: + message, _, _ = message + ret += role + ": " + message + seps[i % 2] + else: + ret += role + ":" + elif self.sep_style == SeparatorStyle.MPT: + ret = self.system + self.sep + for role, message in messages: + if message: + if type(message) is tuple: + message, _, _ = message + ret += role + message + self.sep + else: + ret += role + elif self.sep_style == SeparatorStyle.LLAMA_2: + wrap_sys = lambda msg: f"<>\n{msg}\n<>\n\n" + wrap_inst = lambda msg: f"[INST] {msg} [/INST]" + ret = "" + + for i, (role, message) in enumerate(messages): + if i == 0: + assert message, "first message should not be none" + assert role == self.roles[0], "first message should come from user" + if message: + if type(message) is tuple: + message, _, _ = message + if i == 0: message = wrap_sys(self.system) + message + if i % 2 == 0: + message = wrap_inst(message) + ret += self.sep + message + else: + ret += " " + message + " " + self.sep2 + else: + ret += "" + ret = ret.lstrip(self.sep) + elif self.sep_style == SeparatorStyle.PLAIN: + seps = [self.sep, self.sep2] + ret = self.system + for i, (role, message) in enumerate(messages): + if message: + if type(message) is tuple: + message, _, _ = message + ret += message + seps[i % 2] + else: + ret += "" + else: + raise ValueError(f"Invalid style: {self.sep_style}") + + return ret + + def append_message(self, role, message): + self.messages.append([role, message]) + # + # def get_images(self, return_pil=False): + # images = [] + # for i, (role, msg) in enumerate(self.messages[self.offset:]): + # if i % 2 == 0: + # if type(msg) is tuple: + # import base64 + # from io import BytesIO + # from PIL import Image + # msg, image, image_process_mode = msg + # if image_process_mode == "Pad": + # def expand2square(pil_img, background_color=(122, 116, 104)): + # width, height = pil_img.size + # if width == height: + # return pil_img + # elif width > height: + # result = Image.new(pil_img.mode, (width, width), background_color) + # result.paste(pil_img, (0, (width - height) // 2)) + # return result + # else: + # result = Image.new(pil_img.mode, (height, height), background_color) + # result.paste(pil_img, ((height - width) // 2, 0)) + # return result + # image = expand2square(image) + # elif image_process_mode in ["Default", "Crop"]: + # pass + # elif image_process_mode == "Resize": + # image = image.resize((336, 336)) + # else: + # raise ValueError(f"Invalid image_process_mode: {image_process_mode}") + # max_hw, min_hw = max(image.size), min(image.size) + # aspect_ratio = max_hw / min_hw + # max_len, min_len = 800, 400 + # shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw)) + # longest_edge = int(shortest_edge * aspect_ratio) + # W, H = image.size + # if longest_edge != max(image.size): + # if H > W: + # H, W = longest_edge, shortest_edge + # else: + # H, W = shortest_edge, 
longest_edge + # image = image.resize((W, H)) + # if return_pil: + # images.append(image) + # else: + # buffered = BytesIO() + # image.save(buffered, format="PNG") + # img_b64_str = base64.b64encode(buffered.getvalue()).decode() + # images.append(img_b64_str) + # return images + + # def to_gradio_chatbot(self): + # ret = [] + # for i, (role, msg) in enumerate(self.messages[self.offset:]): + # if i % 2 == 0: + # if type(msg) is tuple: + # import base64 + # from io import BytesIO + # msg, image, image_process_mode = msg + # max_hw, min_hw = max(image.size), min(image.size) + # aspect_ratio = max_hw / min_hw + # max_len, min_len = 800, 400 + # shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw)) + # longest_edge = int(shortest_edge * aspect_ratio) + # W, H = image.size + # if H > W: + # H, W = longest_edge, shortest_edge + # else: + # H, W = shortest_edge, longest_edge + # image = image.resize((W, H)) + # buffered = BytesIO() + # image.save(buffered, format="JPEG") + # img_b64_str = base64.b64encode(buffered.getvalue()).decode() + # img_str = f'user upload image' + # msg = img_str + msg.replace('', '').strip() + # ret.append([msg, None]) + # else: + # ret.append([msg, None]) + # else: + # ret[-1][-1] = msg + # return ret + + def copy(self): + return Conversation( + system=self.system, + roles=self.roles, + messages=[[x, y] for x, y in self.messages], + offset=self.offset, + sep_style=self.sep_style, + sep=self.sep, + sep2=self.sep2, + version=self.version) + + def dict(self): + # if len(self.get_images()) > 0: + # return { + # "system": self.system, + # "roles": self.roles, + # "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages], + # "offset": self.offset, + # "sep": self.sep, + # "sep2": self.sep2, + # } + return { + "system": self.system, + "roles": self.roles, + "messages": self.messages, + "offset": self.offset, + "sep": self.sep, + "sep2": self.sep2, + } + + +conv_vicuna_v0 = Conversation( + system="A chat between a curious human and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the human's questions.", + roles=("Human", "Assistant"), + messages=( + ("Human", "What are the key differences between renewable and non-renewable energy sources?"), + ("Assistant", + "Renewable energy sources are those that can be replenished naturally in a relatively " + "short amount of time, such as solar, wind, hydro, geothermal, and biomass. " + "Non-renewable energy sources, on the other hand, are finite and will eventually be " + "depleted, such as coal, oil, and natural gas. Here are some key differences between " + "renewable and non-renewable energy sources:\n" + "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable " + "energy sources are finite and will eventually run out.\n" + "2. Environmental impact: Renewable energy sources have a much lower environmental impact " + "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, " + "and other negative effects.\n" + "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically " + "have lower operational costs than non-renewable sources.\n" + "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote " + "locations than non-renewable sources.\n" + "5. 
Flexibility: Renewable energy sources are often more flexible and can be adapted to different " + "situations and needs, while non-renewable sources are more rigid and inflexible.\n" + "6. Sustainability: Renewable energy sources are more sustainable over the long term, while " + "non-renewable sources are not, and their depletion can lead to economic and social instability.\n") + ), + offset=2, + sep_style=SeparatorStyle.SINGLE, + sep="###", +) + +conv_vicuna_v1 = Conversation( + system="A chat between a curious user and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the user's questions.", + roles=("USER", "ASSISTANT"), + version="v1", + messages=(), + offset=0, + sep_style=SeparatorStyle.TWO, + sep=" ", + sep2="", +) + +conv_llama_2 = Conversation( + system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. + +If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""", + roles=("USER", "ASSISTANT"), + version="llama_v2", + messages=(), + offset=0, + sep_style=SeparatorStyle.LLAMA_2, + sep="", + sep2="", +) + +conv_llava_llama_2 = Conversation( + system="You are a helpful language and vision assistant. " + "You are able to understand the visual content that the user provides, " + "and assist the user with a variety of tasks using natural language.", + roles=("USER", "ASSISTANT"), + version="llama_v2", + messages=(), + offset=0, + sep_style=SeparatorStyle.LLAMA_2, + sep="", + sep2="", +) + +conv_llaga_llama_2 = Conversation( + system="You are a helpful language and graph assistant. " + "You are able to understand the graph content that the user provides, " + "and assist the user with a variety of tasks using natural language.", + roles=("USER", "ASSISTANT"), + version="llama_v2", + messages=(), + offset=0, + sep_style=SeparatorStyle.LLAMA_2, + sep="", + sep2="", +) + +# conv_mpt = Conversation( +# system="""<|im_start|>system +# A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""", +# roles=("<|im_start|>user\n", "<|im_start|>assistant\n"), +# version="mpt", +# messages=(), +# offset=0, +# sep_style=SeparatorStyle.MPT, +# sep="<|im_end|>", +# ) + +conv_mpt = Conversation( + system="""<|im_start|>system +A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""", + roles=("<|im_start|>user\n", "<|im_start|>assistant\n"), + version="mpt", + messages=(), + offset=0, + sep_style=SeparatorStyle.MPT, + sep="<|im_end|>", +) + +conv_llava_plain = Conversation( + system="", + roles=("", ""), + messages=( + ), + offset=0, + sep_style=SeparatorStyle.PLAIN, + sep="", +) + +conv_llava_v0 = Conversation( + system="A chat between a curious human and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the human's questions.", + roles=("Human", "Assistant"), + messages=( + ), + offset=0, + sep_style=SeparatorStyle.SINGLE, + sep="###", +) + +conv_llava_v0_mmtag = Conversation( + system="A chat between a curious user and an artificial intelligence assistant. 
" + "The assistant is able to understand the graph content that the user provides, and assist the user with a variety of tasks using natural language." + "The graph content will be provided with the following format: graph content.", + roles=("Human", "Assistant"), + messages=( + ), + offset=0, + sep_style=SeparatorStyle.SINGLE, + sep="###", + version="v0_mmtag", +) + +conv_llava_v1 = Conversation( + system="A chat between a curious human and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the human's questions.", + roles=("USER", "ASSISTANT"), + version="v1", + messages=(), + offset=0, + sep_style=SeparatorStyle.TWO, + sep=" ", + sep2="", +) + +conv_llava_v1_mmtag = Conversation( + system="A chat between a curious user and an artificial intelligence assistant. " + "The assistant is able to understand the graph content that the user provides, and assist the user with a variety of tasks using natural language." + "The graph content will be provided with the following format: graph content.", + roles=("USER", "ASSISTANT"), + messages=(), + offset=0, + sep_style=SeparatorStyle.TWO, + sep=" ", + sep2="", + version="v1_mmtag", +) + +default_conversation = conv_vicuna_v0 +conv_templates = { + "default": conv_vicuna_v0, + "v0": conv_vicuna_v0, + "v1": conv_vicuna_v1, + "vicuna_v1": conv_vicuna_v1, + "llama_2": conv_llama_2, + + "plain": conv_llava_plain, + "v0_plain": conv_llava_plain, + "llava_v0": conv_llava_v0, + "v0_mmtag": conv_llava_v0_mmtag, + "llava_v1": conv_llava_v1, + "v1_mmtag": conv_llava_v1_mmtag, + "llava_llama_2": conv_llava_llama_2, + "llaga_llama_2": conv_llaga_llama_2, + "mpt": conv_mpt, +} + + +if __name__ == "__main__": + print(default_conversation.get_prompt()) diff --git a/gammagl/utils/gfm_utils.py b/gammagl/utils/gfm_utils.py new file mode 100644 index 00000000..9cfad948 --- /dev/null +++ b/gammagl/utils/gfm_utils.py @@ -0,0 +1,53 @@ +import torch +# from .multimodal_encoder.builder import build_vision_tower +# from .multimodal_projector.builder import build_vision_projector + +CONTROLLER_HEART_BEAT_EXPIRATION = 30 +WORKER_HEART_BEAT_INTERVAL = 15 + +LOGDIR = "." 
+ +# Model Constants +IGNORE_INDEX = -100 +GRAPH_TOKEN_INDEX = -200 +DEFAULT_GRAPH_TOKEN = "" +DEFAULT_GRAPH_START_TOKEN = "" +DEFAULT_GRAPH_END_TOKEN = "" +DEFAULT_GRAPH_PAD_ID = -500 + +def disable_torch_init(): + import torch + setattr(torch.nn.Linear, "reset_parameters", lambda self: None) + setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) + import tensorlayerx as tlx + setattr(tlx.nn.Linear, "reset_parameters", lambda self: None) + setattr(tlx.nn.LayerNorm, "reset_parameters", lambda self: None) + +def get_model_name_from_path(model_path): + model_path = model_path.strip("/") + model_paths = model_path.split("/") + if model_paths[-1].startswith('checkpoint-'): + return model_paths[-2] + "_" + model_paths[-1] + else: + return model_paths[-1] + +def tokenizer_graph_token(prompt, tokenizer, graph_token_index=GRAPH_TOKEN_INDEX, return_tensors=None): + prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split(DEFAULT_GRAPH_TOKEN)] + + def insert_separator(X, sep): + return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1] + + input_ids = [] + offset = 0 + if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id: + offset = 1 + input_ids.append(prompt_chunks[0][0]) + + for x in insert_separator(prompt_chunks, [graph_token_index] * (offset + 1)): + input_ids.extend(x[offset:]) + + if return_tensors is not None: + if return_tensors == 'pt': + return torch.tensor(input_ids, dtype=torch.long) + raise ValueError(f'Unsupported tensor type: {return_tensors}') + return input_ids \ No newline at end of file From 68fd8583536246633c0f08fa60eed906d9230584 Mon Sep 17 00:00:00 2001 From: VoidHaruhi Date: Wed, 5 Feb 2025 12:55:09 +0000 Subject: [PATCH 2/2] add graphgpt --- examples/graphgpt/README.md | 31 + examples/graphgpt/eval.sh | 13 + examples/graphgpt/graphgpt_eval.py | 106 +++ examples/graphgpt/graphgpt_trainer.py | 232 +++++++ examples/llaga/README.md | 31 + examples/llaga/eval.sh | 4 +- gammagl/models/graphgpt.py | 888 ++++++++++++++++++++++++++ gammagl/models/llaga.py | 9 + gammagl/models/simple_tokenizer.py | 132 ++++ gammagl/utils/conversation.py | 14 + gammagl/utils/gfm_utils.py | 30 +- 11 files changed, 1487 insertions(+), 3 deletions(-) create mode 100644 examples/graphgpt/README.md create mode 100644 examples/graphgpt/eval.sh create mode 100644 examples/graphgpt/graphgpt_eval.py create mode 100644 examples/graphgpt/graphgpt_trainer.py create mode 100644 examples/llaga/README.md create mode 100644 gammagl/models/graphgpt.py create mode 100644 gammagl/models/simple_tokenizer.py diff --git a/examples/graphgpt/README.md b/examples/graphgpt/README.md new file mode 100644 index 00000000..fc2d94d7 --- /dev/null +++ b/examples/graphgpt/README.md @@ -0,0 +1,31 @@ +# GraphGPT: Graph Instruction Tuning for Large Language Models +* Paper link: http://arxiv.org/abs/2310.13023 +* Author's code repo: https://github.com/HKUDS/GraphGPT + +# How to Run + +* First, follow the original repo to install all required packages; + +* Then download all required datasets and pretrained checkpoints, and fill their path into corresponding values in eval.sh + +# Dataset Statics +| Dataset | # Nodes | # Edges | # Classes | +| :-------: | :-------: | :------: | :------: | +| Cora | 25,120 | 182,280 | 70 | +| PubMed | 19,717 | 44,338 | 3 | +| ogb-arxiv | 169,343 | 1,166,243 | 40 | + +# Files Description +* graphgpt_trainer.py: the trainer of graphgpt, inference stage +* graphgpt_eval.py: run this to evaluate + +# Results 
+```bash +# run inference +TL_BACKEND="torch" nohup bash examples/graphgpt/eval.sh > log/test_graphgpt.out & +# run evaluation +python examples/graphgpt/graphgpt_eval.py --dataset cora +``` +| Dataset | Paper | Our(torch) | +| :-------: | :-------: | :------: | +| Cora | 0.1501 | 0.1451 | \ No newline at end of file diff --git a/examples/graphgpt/eval.sh b/examples/graphgpt/eval.sh new file mode 100644 index 00000000..ce1b6229 --- /dev/null +++ b/examples/graphgpt/eval.sh @@ -0,0 +1,13 @@ +export PYTHONPATH=$(dirname $(dirname $(realpath $0))):$PYTHONPATH +# to fill in the following path to extract projector for the second tuning stage! +output_model=/local/yy3/graphgpt/GraphGPT-7B-mix-all # path to the pre-trained model checkpoint +datapath=/local/yy3/graphgpt/data/eval/cora_test_instruct_std.json # path to the instruction datset +graph_data_path=/local/yy3/graphgpt/data/graph_data_all.pt # path to the graph data +res_path=./output_stage_2_cora_nc # path to save the results +start_id=0 +end_id=20000 # total number of instructions to test +num_gpus=1 + +export CUDA_VISIBLE_DEVICES=2 # specify the GPU id + +python ./examples/graphgpt/graphgpt_trainer.py --model-name ${output_model} --prompting_file ${datapath} --graph_data_path ${graph_data_path} --output_res_path ${res_path} --start_id ${start_id} --end_id ${end_id} --num_gpus ${num_gpus} \ No newline at end of file diff --git a/examples/graphgpt/graphgpt_eval.py b/examples/graphgpt/graphgpt_eval.py new file mode 100644 index 00000000..1b3f26bd --- /dev/null +++ b/examples/graphgpt/graphgpt_eval.py @@ -0,0 +1,106 @@ +import json +import os.path as osp +import os +import torch as th +import re +import pandas as pd +from tqdm import tqdm +from sklearn.metrics import classification_report + +import argparse +parser = argparse.ArgumentParser() +parser.add_argument('--dataset', type=str, default='arxiv') +args = parser.parse_args() + +label_to_idx = { + "cora":{"databases, object oriented": 29, "operating systems, memory management": 59, "data structures algorithms and theory, quantum computing": 24, "artificial intelligence, planning": 13, "artificial intelligence, knowledge representation": 4, "artificial intelligence, data mining": 1, "artificial intelligence, vision and pattern recognition": 17, "artificial intelligence, machine learning, case-based": 5, "artificial intelligence, agents": 0, "artificial intelligence, machine learning, probabilistic methods": 8, "encryption and compression, security": 36, "operating systems, distributed": 57, "human computer interaction, interface design": 46, "artificial intelligence, machine learning, genetic algorithms": 6, "human computer interaction, graphics and virtual reality": 45, "artificial intelligence, machine learning, rule learning": 10, "programming, functional": 63, "programming, object oriented": 67, "encryption and compression, encryption": 35, "databases, performance": 30, "networking, protocols": 54, "data structures algorithms and theory, randomized": 25, "data structures algorithms and theory, formal languages": 20, "data structures algorithms and theory, parallel": 23, "programming, software development": 69, "programming, compiler design": 61, "artificial intelligence, machine learning, theory": 11, "artificial intelligence, machine learning, neural networks": 7, "programming, logic": 66, "databases, relational": 32, "information retrieval, retrieval": 52, "programming, debugging": 62, "networking, wireless": 56, "artificial intelligence, theorem proving": 16, "databases, temporal": 
33, "encryption and compression, compression": 34, "information retrieval, filtering": 51, "data structures algorithms and theory, computational complexity": 18, "programming, garbage collection": 64, "artificial intelligence, machine learning, reinforcement learning": 9, "human computer interaction, multimedia": 47, "hardware and architecture, vlsi": 43, "artificial intelligence, nlp": 12, "hardware and architecture, microprogramming": 42, "operating systems, fault tolerance": 58, "programming, java": 65, "operating systems, realtime": 60, "human computer interaction, cooperative": 44, "artificial intelligence, speech": 15, "databases, deductive": 28, "artificial intelligence, robotics": 14, "data structures algorithms and theory, logic": 22, "networking, routing": 55, "hardware and architecture, logic design": 40, "hardware and architecture, distributed architectures": 37, "data structures algorithms and theory, hashing": 21, "programming, semantics": 68, "artificial intelligence, games and search": 3, "databases, concurrency": 27, "data structures algorithms and theory, sorting": 26, "human computer interaction, wearable computers": 48, "information retrieval, digital library": 49, "artificial intelligence, expert systems": 2, "information retrieval, extraction": 50, "data structures algorithms and theory, computational geometry": 19, "databases, query evaluation": 31, "networking, internet": 53, "hardware and architecture, memory structures": 41, "hardware and architecture, high performance computing": 38, "hardware and architecture, input output and storage": 39}, + "pubmed":{"Experimentally induced diabetes": 0, "Type 2 diabetes": 2, "Type 1 diabetes": 1} +} + + + +data_list = [] +folder = 'output_stage_2_{}_nc'.format(args.dataset) +for filename in os.listdir(folder): + if filename.endswith('.json'): + file_path = os.path.join(folder, filename) + with open(file_path, 'r') as f: + data = json.load(f) + data_list.extend(data) + +print(data_list[1]) + +graph_data = th.load('/local/yy3/graphgpt/data/graph_data_all.pt')[args.dataset] +labels = graph_data.y + +def cal_map(): + label_dict = {} + if args.dataset == "arxiv": + df = pd.read_csv(os.path.expanduser('~/datasets/OGB/ogbn_arxiv/mapping/labelidx2arxivcategeory.csv.gz'), compression='gzip') + for index, line in df.iterrows(): + lb = line['arxiv category'].split(' ')[-1] + lb_new = 'cs.' 
+ lb.upper() + label_dict[lb_new] = line['label idx'] + else: + label_dict = label_to_idx[args.dataset] + return label_dict + +class_map = cal_map() + +inverse_class_map = {} +for lb, lb_id in class_map.items(): + inverse_class_map[lb_id] = lb + + +pattern = r"cs\.[A-Z]{2}" + + +topk = 3 + +correct = 0 +total = len(data_list) + +trues = [] +preds = [] + +for instruct_item in tqdm(data_list): + nid = instruct_item['node_idx'] + gpt_res = instruct_item['res'] + + + true_y = labels[nid] + + pred_y = [] + if args.dataset == "arxiv": + matches = list(set(re.findall(pattern, gpt_res))) # pred + sorted_matches = sorted(matches, key=lambda x: gpt_res.index(x)) + for m in sorted_matches: + try: + pred_y.append(class_map[m]) + except: + pass + try: + # print(sorted_matches) + preds.append(pred_y[0]) + except: + preds.append(-1) + else: + for lb, lb_id in class_map.items(): + if lb in gpt_res: + pred_y.append(lb_id) + try: + # print(sorted_matches) + preds.append(pred_y[0]) + except: + preds.append(-1) + trues.append(true_y.item()) + res_tmp = 1 if true_y in pred_y[:topk] else 0 + correct = correct + 1 if true_y in pred_y[:topk] else correct + +acc = correct / total + +print("Accuracy:", acc) + +report = classification_report(trues, preds, digits=6) + +print(report) \ No newline at end of file diff --git a/examples/graphgpt/graphgpt_trainer.py b/examples/graphgpt/graphgpt_trainer.py new file mode 100644 index 00000000..ed1f298c --- /dev/null +++ b/examples/graphgpt/graphgpt_trainer.py @@ -0,0 +1,232 @@ +import argparse +from transformers import AutoTokenizer, AutoModelForCausalLM +import torch +import os +from gammagl.utils.conversation import conv_templates, SeparatorStyle +from gammagl.utils.gfm_utils import disable_torch_init, KeywordsStoppingCriteria +from gammagl.utils.gfm_utils import DEFAULT_G_END_TOKEN, DEFAULT_G_START_TOKEN, DEFAULT_GRAPH_PATCH_TOKEN, DEFAULT_GRAPH_TOKEN, GRAPH_TOKEN_INDEX +from transformers import CLIPVisionModel, CLIPImageProcessor, StoppingCriteria +from gammagl.models.graphgpt import * + +from torch_geometric.data import Data +import json +import copy +from tqdm import tqdm +import json +import os.path as osp +import ray + +os.environ['TL_BACKEND'] = 'torch' + +def load_graph(instruct_item, graph_data_path): + graph_data_all = torch.load(graph_data_path) + graph_dict = instruct_item['graph'] + graph_edge_index = torch.Tensor(copy.deepcopy(graph_dict['edge_index'])).long() + graph_node_list = copy.deepcopy(graph_dict['node_list']) + target_node = copy.deepcopy(graph_dict['node_idx']) + graph_type = copy.deepcopy(instruct_item['id']).split('_')[0] + graph_node_rep = graph_data_all[graph_type].x[graph_node_list] ## + + cur_token_len = len(graph_node_rep) # FIXME: 14 is hardcoded patch size + + graph_ret = Data(graph_node = graph_node_rep, edge_index=graph_edge_index, target_node = torch.tensor([target_node])) + + return { + 'graph_data': graph_ret, + 'graph_token_len': cur_token_len + } + + +def load_prompting_file(file_path): + with open(file_path, 'r') as f: + data = json.load(f) + return data + +# def prepare_query(instruct_item): + + +def run_eval(args, num_gpus): + # split question file into num_gpus files + prompt_file = load_prompting_file(args.prompting_file) + args.end_id = min(args.end_id, len(prompt_file)) + prompt_file = prompt_file[args.start_id:args.end_id] + chunk_size = len(prompt_file) // num_gpus + ans_handles = [] + split_list = list(range(args.start_id, args.end_id, chunk_size)) + idx_list = list(range(0, len(prompt_file), chunk_size)) + if 
len(split_list) == num_gpus: + split_list.append(args.end_id) + idx_list.append(len(prompt_file)) + elif len(split_list) == num_gpus + 1: + split_list[-1] = args.end_id + idx_list[-1] = len(prompt_file) + else: + raise ValueError('error in the number of list') + + if osp.exists(args.output_res_path) is False: + os.mkdir(args.output_res_path) + + for idx in range(len(idx_list) - 1): + start_idx = idx_list[idx] + end_idx = idx_list[idx + 1] + + start_split = split_list[idx] + end_split = split_list[idx + 1] + ans_handles.append( + eval_model.remote( + args, prompt_file[start_idx:end_idx], start_split, end_split + ) + ) + + ans_jsons = [] + for ans_handle in ans_handles: + ans_jsons.extend(ray.get(ans_handle)) + + # with open(args.output_res_path, "w") as ans_file: + # for line in ans_jsons: + # ans_file.write(json.dumps(line) + "\n") + + +@ray.remote(num_gpus=1) +@torch.inference_mode() +def eval_model(args, prompt_file, start_idx, end_idx): + # load prompting file + # prompt_file = load_prompting_file(args.prompting_file) + + + # Model + disable_torch_init() + # model_name = os.path.expanduser(args.model_name) + print('start loading') + tokenizer = AutoTokenizer.from_pretrained(args.model_name) + print('finish loading') + + print('start loading') + model = GraphLlamaForCausalLM.from_pretrained(args.model_name, torch_dtype=torch.float16, use_cache=True, low_cpu_mem_usage=True).cuda() + print('finish loading') + + use_graph_start_end = getattr(model.config, "use_graph_start_end", False) + tokenizer.add_tokens([DEFAULT_GRAPH_PATCH_TOKEN], special_tokens=True) + if use_graph_start_end: + tokenizer.add_tokens([DEFAULT_G_START_TOKEN, DEFAULT_G_END_TOKEN], special_tokens=True) + + graph_tower = model.get_model().graph_tower + + # TODO: add graph tower + # if graph_tower.device.type == 'meta': + # print('meta') + clip_graph, args_graph= load_model_pretrained(CLIP, model.config.pretrain_graph_model_path) + graph_tower = graph_transformer(args_graph) + graph_tower = transfer_param_tograph(clip_graph, graph_tower) + + model.get_model().graph_tower = graph_tower.cuda() + # else: + # print('other') + # print(next(graph_tower.parameters()).dtype) + graph_tower.to(device='cuda', dtype=torch.float16) + graph_config = graph_tower.config + graph_config.graph_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_GRAPH_PATCH_TOKEN])[0] + graph_config.use_graph_start_end = use_graph_start_end + if use_graph_start_end: + graph_config.graph_start_token, graph_config.graph_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_G_START_TOKEN, DEFAULT_G_END_TOKEN]) + # TODO: add graph token len + + res_data = [] + print(f'total: {len(prompt_file)}') + for idx, instruct_item in tqdm(enumerate(prompt_file)): + # instruct_item = prompt_file[0] + # if idx >= 3: + # break + graph_dict = load_graph(instruct_item, args.graph_data_path) + graph_token_len = graph_dict['graph_token_len'] + graph_data = graph_dict['graph_data'] + + qs = instruct_item["conversations"][0]["value"] + # if use_graph_start_end: + # qs = qs + '\n' + DEFAULT_G_START_TOKEN + DEFAULT_GRAPH_PATCH_TOKEN * graph_token_len + DEFAULT_G_END_TOKEN + # else: + # qs = qs + '\n' + DEFAULT_GRAPH_PATCH_TOKEN * graph_token_len + + replace_token = DEFAULT_GRAPH_PATCH_TOKEN * graph_token_len + replace_token = DEFAULT_G_START_TOKEN + replace_token + DEFAULT_G_END_TOKEN + qs = qs.replace(DEFAULT_GRAPH_TOKEN, replace_token) + + # if "v1" in args.model_name.lower(): + # conv_mode = "graphchat_v1" + # else: + # raise ValueError('Don\'t support this model') + conv_mode = 
"graphchat_v1" + + if args.conv_mode is not None and conv_mode != args.conv_mode: + print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode)) + else: + args.conv_mode = conv_mode + + conv = conv_templates[args.conv_mode].copy() + conv.append_message(conv.roles[0], qs) + conv.append_message(conv.roles[1], None) + prompt = conv.get_prompt() + inputs = tokenizer([prompt]) + + + + input_ids = torch.as_tensor(inputs.input_ids).cuda() + + stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 + keywords = [stop_str] + stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) + + graph_data.graph_node = graph_data.graph_node.to(torch.float16) + # graph_data.edge_index = graph_data.edge_index.to(torch.float16) + + with torch.inference_mode(): + output_ids = model.generate( + input_ids, + graph_data=graph_data.cuda(), + do_sample=True, + temperature=0.2, + max_new_tokens=1024, + stopping_criteria=[stopping_criteria]) + + input_token_len = input_ids.shape[1] + n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item() + if n_diff_input_output > 0: + print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids') + outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0] + outputs = outputs.strip() + if outputs.endswith(stop_str): + outputs = outputs[:-len(stop_str)] + outputs = outputs.strip() + # print(outputs) + + res_data.append({"id": instruct_item["id"], "node_idx": instruct_item["graph"]["node_idx"], "res": outputs}.copy()) + with open(osp.join(args.output_res_path, 'arxiv_test_res_{}_{}.json'.format(start_idx, end_idx)), "w") as fout: + json.dump(res_data, fout, indent=4) + return res_data + # with open(args.output_res_path, "w") as fout: + # json.dump(res_data, fout, indent=4) + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model-name", type=str, default="facebook/opt-350m") + # parser.add_argument("--image-file", type=str, required=True) + # parser.add_argument("--query", type=str, required=True) + parser.add_argument("--prompting_file", type=str, default=None) + parser.add_argument("--conv-mode", type=str, default=None) + parser.add_argument("--graph_data_path", type=str, default=None) + + parser.add_argument("--output_res_path", type=str, default=None) + parser.add_argument("--num_gpus", type=int, default=4) + + parser.add_argument("--start_id", type=int, default=0) + parser.add_argument("--end_id", type=int, default=20567) + + args = parser.parse_args() + + # eval_model(args) + + ray.init() + run_eval(args, args.num_gpus) + + +# protobuf 4.22.3 \ No newline at end of file diff --git a/examples/llaga/README.md b/examples/llaga/README.md new file mode 100644 index 00000000..bb5e027a --- /dev/null +++ b/examples/llaga/README.md @@ -0,0 +1,31 @@ +# LLaGA: Large Language and Graph Assistant +* Paper link: http://arxiv.org/abs/2402.08170 +* Author's code repo: https://github.com/VITA-Group/LLaGA + +# How to Run + +* First, follow the original repo to install all required packages; + +* Then download all required datasets and pretrained checkpoints, and fill their path into corresponding values in eval.sh + +# Dataset Statics +| Dataset | # Nodes | # Edges | # Classes | +| :-------: | :-------: | :------: | :------: | +| Cora | 2,708 | 5,429 | 7 | +| PubMed | 19,717 | 44,338 | 3 | +| Arxiv | 169,343 | 1,166,243 | 40 | + +# Files Description 
+* llaga_trainer.py: the trainer of LLaGA, inference stage +* llaga_eval.py: run this to evaluate + +# Results +```bash +# run inference +TL_BACKEND="torch" nohup bash examples/llaga/eval.sh > log/test_llaga.out & +# run evaluation +python examples/llaga/llaga_eval.py --dataset cora --task nc --res_path examples/llaga/test.txt # "output_path" you specified in eval.sh +``` +| Dataset | Paper | Our(torch) | +| :-------: | :-------: | :------: | +| Cora | 0.8782 | 0.8727 | \ No newline at end of file diff --git a/examples/llaga/eval.sh b/examples/llaga/eval.sh index 211a781a..de0f028c 100644 --- a/examples/llaga/eval.sh +++ b/examples/llaga/eval.sh @@ -2,8 +2,8 @@ export PYTHONPATH=$(dirname $(dirname $(realpath $0))):$PYTHONPATH export CUDA_VISIBLE_DEVICES=0 -model_path="/local/yy3/vicuna-7b-v1.5-16k" -model_base="/local/yy3/llaga-vicuna-7b-simteg-ND-general_model-2-layer-mlp-projector" #meta-llama/Llama-2-7b-hf +model_base="/local/yy3/vicuna-7b-v1.5-16k" # path to base model (LLM) +model_path="/local/yy3/llaga-vicuna-7b-simteg-ND-general_model-2-layer-mlp-projector" # path to checkpoint of LLaGA mode="v1" # use 'llaga_llama_2' for llama and "v1" for others dataset="arxiv" #test dataset task="nc" #test task diff --git a/gammagl/models/graphgpt.py b/gammagl/models/graphgpt.py new file mode 100644 index 00000000..48f62191 --- /dev/null +++ b/gammagl/models/graphgpt.py @@ -0,0 +1,888 @@ +# Copyright 2023 Haotian Liu +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
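The module below defines the GraphGPT model classes. The way `graphgpt_trainer.py` above assembles them at inference time can be condensed into a short sketch; the checkpoint path is a placeholder, and `GraphLlamaForCausalLM` is defined further down in this file:

```python
import torch
from transformers import AutoTokenizer
from gammagl.models.graphgpt import (
    CLIP, GraphLlamaForCausalLM, graph_transformer,
    load_model_pretrained, transfer_param_tograph,
)

model_name = "/path/to/GraphGPT-7B-mix-all"   # placeholder checkpoint path
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = GraphLlamaForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.float16, use_cache=True, low_cpu_mem_usage=True
).cuda()

# Rebuild the pretrained graph encoder from the CLIP-style checkpoint and attach it as graph_tower.
clip_graph, args_graph = load_model_pretrained(CLIP, model.config.pretrain_graph_model_path)
graph_tower = transfer_param_tograph(clip_graph, graph_transformer(args_graph))
model.get_model().graph_tower = graph_tower.cuda().to(dtype=torch.float16)
```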
+ + +from typing import List, Optional, Tuple, Union + + +from transformers import AutoConfig, AutoModelForCausalLM, \ + LlamaConfig, LlamaModel, LlamaForCausalLM, \ + CLIPVisionModel, CLIPImageProcessor + +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.configuration_utils import PretrainedConfig +from torch_geometric.data import Data +import json +import os.path as osp +import glob +from gammagl.models.simple_tokenizer import SimpleTokenizer as _Tokenizer +import tensorlayerx as tlx +from collections import OrderedDict + +from gammagl.layers.conv import MessagePassing +from gammagl.mpops import segment_sum +from gammagl.utils import add_self_loops +from gammagl.utils.gfm_utils import DEFAULT_G_END_TOKEN, DEFAULT_G_START_TOKEN, DEFAULT_GRAPH_PATCH_TOKEN, DEFAULT_GRAPH_TOKEN, GRAPH_TOKEN_INDEX + +import torch.nn as nn +import torch.nn.functional as F +import torch +import math + +_tokenizer = _Tokenizer("/home/yy3/graph-text-align/GraphGPT/graphgpt/model/graph_layers/bpe_simple_vocab_16e6.txt.gz") # the path of this file, should be found in GraphGPT project + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x): + orig_type = x.dtype + ret = super().forward(x.type(tlx.float32)) + return ret.type(orig_type) + + +class QuickGELU(nn.Module): + def forward(self, x): + return x * tlx.sigmoid(1.702 * x) + + +class ResidualAttentionBlock(nn.Module): + def __init__(self, d_model: int, n_head: int, attn_mask = None): + super().__init__() + + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ln_1 = LayerNorm(d_model) + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, d_model * 4)), + ("gelu", QuickGELU()), + ("c_proj", nn.Linear(d_model * 4, d_model)) + ])) + self.ln_2 = LayerNorm(d_model) + self.attn_mask = attn_mask + + def attention(self, x): + self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None + return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] + + def forward(self, x): + x = x + self.attention(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + +class Transformer(nn.Module): + def __init__(self, width: int, layers: int, heads: int, attn_mask = None): + super().__init__() + self.width = width + self.layers = layers + self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) + + def forward(self, x): + return self.resblocks(x) + + +class GNN(MessagePassing): + def __init__(self, args, **kwargs): + super(GNN, self).__init__(aggr='add', **kwargs) + self.config = PretrainedConfig() + self.vars = nn.ParameterList() + + w = nn.Parameter(tlx.ones([args.gnn_hid, args.gnn_input])) + nn.init.xavier_uniform_(w) + self.vars.append(w) + self.vars.append(nn.Parameter(torch.zeros(args.gnn_hid))) + + w = nn.Parameter(tlx.ones([args.gnn_output, args.gnn_hid])) + nn.init.xavier_uniform_(w) + self.vars.append(w) + self.vars.append(nn.Parameter(torch.zeros(args.gnn_output))) + + @staticmethod + def norm(edge_index, num_nodes, improved=False, dtype=None): + edge_weight = tlx.ones((edge_index.size(1),), dtype=dtype, + device=edge_index.device) + + fill_value = 1.0 if not improved else 2.0 + edge_index, edge_weight = add_self_loops( + edge_index, edge_weight, fill_value, num_nodes) + + row, col = edge_index + deg = segment_sum(edge_weight, row, num_nodes) + deg_inv_sqrt = deg.pow(-0.5) + deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0 + + 
return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col] + + def forward(self, g, vars=None): + device = self.parameters()[0].device + g = g.to(device) + + edge_index = g.edge_index + x = g.graph_node + if vars is None: + vars = self.vars + improved = False + + w, b = vars[0], vars[1] + edge_index, norm = self.norm(edge_index, x.size(self.node_dim), improved, x.dtype) + x = self.propagate(edge_index, x=x, norm=norm) + w = w.to(x.device) + b = b.to(x.device) + x = F.linear(x, w, b) + x = F.leaky_relu(x) + + w, b = vars[2], vars[3] + edge_index, norm = self.norm(edge_index, x.size(self.node_dim), improved, x.dtype) + x = self.propagate(edge_index, x=x, norm=norm) + w = w.to(x.device) + b = b.to(x.device) + x = F.linear(x, w, b) + + return x + + def parameters(self): + return self.vars + + + +def Mv2SameDevice(var_list): + for vid in range(1, len(var_list)): + var_list[vid] = var_list[vid].to(var_list[0].device) + return var_list + +class CLIP(nn.Module): + def __init__(self, + args + ): + super().__init__() + + self.context_length = args.context_length + self.args = args + self.edge_coef = args.edge_coef + + if args.gnn_type == 'gcn': + self.gnn = GNN(args) + elif args.gnn_type == 'gt': + self.gnn = graph_transformer(args) + self.transformer = Transformer( + width=args.transformer_width, + layers=args.transformer_layers, + heads=args.transformer_heads, + attn_mask=self.build_attention_mask() + ) + + self.vocab_size = args.vocab_size + self.token_embedding = nn.Embedding(args.vocab_size, + args.transformer_width) # the embedding for all possible tokens + self.positional_embedding = nn.Parameter(torch.empty(self.context_length, args.transformer_width)) + self.ln_final = LayerNorm(args.transformer_width) + + self.text_projection = nn.Parameter(torch.empty(args.transformer_width, args.embed_dim)) + # self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + + if args.gnn_type == 'gcn': + self.dtype = self.gnn.vars[0].dtype + elif args.gnn_type == 'gt': + self.dtype = self.gnn.W_pos.dtype + + self.optim = torch.optim.Adam([{'params': self.token_embedding.weight}, + {'params': self.positional_embedding}, + {'params': self.transformer.parameters()}, + {'params': self.text_projection}, + {'params': self.gnn.parameters()} + ], lr=args.lr) + + self.initialize_parameters() + + def initialize_parameters(self): + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.positional_embedding, std=0.01) + + proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) + attn_std = self.transformer.width ** -0.5 + fc_std = (2 * self.transformer.width) ** -0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = tlx.ops.constant(float("-inf"), shape=[self.context_length, self.context_length]) + mask.triu_(1) # zero out the lower diagonal + return mask + + def encode_image(self, idx_train, g): + embs = self.gnn(g) + idx_train = idx_train.to(embs.device) + idx_train = idx_train + train_embs = embs[idx_train] + return train_embs 
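    # NOTE: despite the CLIP-style name, encode_image() above is the graph branch:
    # it runs self.gnn over the sampled (sub)graph and gathers the embeddings of the
    # nodes in idx_train. encode_text() below embeds token sequences with the text
    # Transformer, and forward() normalizes and contrasts the two sets of features.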
+ + def encode_text(self, text): + x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] + + x = x + self.positional_embedding.type(self.dtype) + x = x.permute(1, 0, + 2) # NLD -> LND, batch_size * context_length *emb_dim -> context_length * batch_size *emb_dim + x = self.transformer(x) + x = x.permute(1, 0, + 2) # LND -> NLD, context_length * batch_size *emb_dim -> batch_size * context_length *emb_dim + x = self.ln_final(x).type(self.dtype) + # x.shape = [batch_size, n_ctx, transformer.width] + # take features from the eot (end of token) embedding (eot_token is the highest number in each sequence) + # so there is node need to shorten the context length + x = x[tlx.arange(x.shape[0]), text.argmax(dim=-1)] # + x = x @ self.text_projection + return x + + def forward(self, g, s_n, t_n, s_n_text, t_n_text, training=True): + + s_image_features = self.encode_image(s_n, g) + + s_text_features = self.encode_text(s_n_text) + + t_text_features = self.encode_text(t_n_text) + t_text_features = t_text_features.reshape(s_image_features.shape[0], self.args.neigh_num, self.args.gnn_output) + t_text_features = tlx.mean(t_text_features, dim=1, keepdim=False) + # normalized features + s_image_features = s_image_features / s_image_features.norm(dim=-1, keepdim=True) + s_text_features = s_text_features / s_text_features.norm(dim=-1, keepdim=True) + t_text_features = t_text_features / t_text_features.norm(dim=-1, keepdim=True) + + # cosine similarity as logits + + labels = tlx.arange(s_image_features.shape[0]).cuda() + + # logit_scale = self.logit_scale.exp() # the temporature hyperparameter + # logit_scale, s_image_features, s_text_features = Mv2SameDevice([logit_scale, s_image_features, s_text_features]) + # logits = logit_scale * s_image_features @ s_text_features.t() + # loss_i = tlx.cross_entropy(logits, labels) + # loss_t = tlx.cross_entropy(logits.T, labels) + # node_loss = (loss_i + loss_t) / 2 + + # logit_scale, s_image_features, t_text_features = Mv2SameDevice([logit_scale, s_image_features, t_text_features]) + # logits = logit_scale * s_image_features @ t_text_features.t() + # loss_i = tlx.cross_entropy(logits, labels) + # loss_t = tlx.cross_entropy(logits.T, labels) + # gt_loss = (loss_i + loss_t)/2 + + # logit_scale, s_text_features, t_text_features = Mv2SameDevice([logit_scale, s_text_features, t_text_features]) + # logits = logit_scale * s_text_features @ t_text_features.t() + # loss_i = tlx.cross_entropy(logits, labels) + # loss_t = tlx.cross_entropy(logits.T, labels) + # tt_loss = (loss_i + loss_t)/2 + + + + # shape = [global_batch_size, global_batch_size] + # return all_loss + return s_image_features, s_text_features, t_text_features, labels + + +def tokenize(texts: Union[str, List[str]], context_length: int = 128, truncate: bool = True): + + """ + Returns the tokenized representation of given input string(s) + + Parameters + ---------- + texts : Union[str, List[str]] + An input string or a list of input strings to tokenize + + context_length : int + The context length to use; all CLIP models use 77 as the context length + + truncate: bool + Whether to truncate the text in case its encoding is longer than the context length + + Returns + ------- + A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] + """ + if isinstance(texts, str): + texts = [texts] + + sot_token = _tokenizer.encoder["<|startoftext|>"] + eot_token = _tokenizer.encoder["<|endoftext|>"] + all_tokens = [[sot_token] + _tokenizer.encode(text) + 
[eot_token] for text in texts] + result = torch.zeros(len(all_tokens), context_length, dtype=tlx.int64) + + for i, tokens in enumerate(all_tokens): + if len(tokens) > context_length: + if truncate: + tokens = tokens[:context_length] + tokens[-1] = eot_token + else: + raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") + result[i, :len(tokens)] = tokens + + return result + + +class GraphLlamaConfig(LlamaConfig): + model_type = "GraphLlama" + +class GraphPretrainConfig: + def __init__(self, dictionary): + for key, value in dictionary.items(): + setattr(self, key, value) + +def load_model_pretrained(model_name, pretrain_model_path): + # load conig json + + assert osp.exists(osp.join(pretrain_model_path, 'config.json')), 'config.json missing' + with open(osp.join(pretrain_model_path, 'config.json'), 'r') as f: + config_dict = json.load(f) + args = GraphPretrainConfig(config_dict) + model = model_name(args) + pkl_files = glob.glob(osp.join(pretrain_model_path, '*.pkl')) + state_dict = torch.load(pkl_files[0]) + # print(state_dict.keys()) + if 'logit_scale' in state_dict.keys(): + state_dict.pop('logit_scale') + print('loading graph pre train model') + model.load_state_dict(state_dict) + + + return model, args +def transfer_param_tograph(clip_graph, gnn): + + print(clip_graph) + gnn_state_dict = clip_graph.gnn.state_dict() + gnn.load_state_dict(gnn_state_dict) + return gnn + + +init = nn.init.xavier_uniform_ +uniformInit = nn.init.uniform + +def PositionalEncoding(q_len, d_model, normalize=True): + pe = torch.zeros(q_len, d_model) + position = torch.arange(0, q_len).unsqueeze(1) + div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model)) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + if normalize: + pe = pe - pe.mean() + pe = pe / (pe.std() * 10) + return pe + + +def pos_encoding(pe, learn_pe, nvar, d_model): + # Positional encoding + if pe == None: + W_pos = torch.empty((nvar, d_model)) # pe = None and learn_pe = False can be used to measure impact of pe + nn.init.uniform_(W_pos, -0.02, 0.02) + learn_pe = False + elif pe == 'zero': + W_pos = torch.empty((nvar, 1)) + nn.init.uniform_(W_pos, -0.02, 0.02) + elif pe == 'zeros': + W_pos = torch.empty((nvar, d_model)) + nn.init.uniform_(W_pos, -0.02, 0.02) + elif pe == 'normal' or pe == 'gauss': + W_pos = torch.zeros((nvar, 1)) + torch.nn.init.normal_(W_pos, mean=0.0, std=0.1) + elif pe == 'uniform': + W_pos = torch.zeros((nvar, 1)) + nn.init.uniform_(W_pos, a=0.0, b=0.1) + elif pe == 'sincos': W_pos = PositionalEncoding(nvar, d_model, normalize=True) + else: raise ValueError(f"{pe} is not a valid pe (positional encoder. 
Available types: 'gauss'=='normal', \ + 'zeros', 'zero', uniform', 'sincos', None.)") + return nn.Parameter(W_pos, requires_grad=learn_pe) + + +class graph_transformer(nn.Module): + def __init__(self, args): + super(graph_transformer, self).__init__() + self.config = PretrainedConfig() + self.gtLayers = nn.Sequential(*[GTLayer(args) for i in range(args.gt_layers)]) + + self.W_pos = pos_encoding('zeros', True, 1, args.att_d_model) + + self.W_P = nn.Linear(args.gnn_input, args.att_d_model) + self.dropout = nn.Dropout(0.1) + self.inverW_P = nn.Linear(args.att_d_model, args.gnn_output) + self.args = args + + def forward(self, g): + # Adj: sp adj + # x: bs * n * d_model * num_patch + + # print(edge_index) + device = self.parameters().__next__().device + g = g.to(device) + + x = g.graph_node + + # x, W_P_weight, W_P_bias= Mv2Samedevice([x, self.W_P.weight, self.W_P.bias]) + # self.W_P.weight = nn.Parameter(W_P_weight.to(x.dtype)) + # self.W_P.bias = nn.Parameter(W_P_bias.to(x.dtype)) + # print(self.W_P.dtype, x.dtype) + z = self.W_P(x) + if self.args.if_pos: + embeds = self.dropout(z + self.W_pos) + else: + embeds = self.dropout(z) + for gt in self.gtLayers: + embeds = gt(g, embeds) # bs * num_patch * n * d_model + # embeds, inverW_P_weight, inverW_P_bias = Mv2Samedevice([embeds, self.inverW_P.weight, self.inverW_P.bias]) + # self.inverW_P.weight = nn.Parameter(inverW_P_weight.to(embeds.dtype)) + # self.inverW_P.bias = nn.Parameter(inverW_P_bias.to(embeds.dtype)) + ret = self.inverW_P(embeds) + return ret +def Mv2Samedevice(vars): + return [var.to(vars[0].device) for var in vars] + +class GTLayer(nn.Module): + def __init__(self, args): + super(GTLayer, self).__init__() + self.qTrans = nn.Parameter(init(torch.empty(args.att_d_model, args.att_d_model))) + self.kTrans = nn.Parameter(init(torch.empty(args.att_d_model, args.att_d_model))) + self.vTrans = nn.Parameter(init(torch.empty(args.att_d_model, args.att_d_model))) + if args.att_norm: + self.norm = nn.LayerNorm(args.att_d_model, eps=1e-6) + self.args = args + + + + def forward(self, g, embeds): + # Adj: adj + # x: n * d_model + rows, cols = g.edge_index + nvar, _ = embeds.shape + # print(rows) + # print(cols) + + rowEmbeds = embeds[rows, :] + colEmbeds = embeds[cols, :] + evar, _ = rowEmbeds.shape + + # rowEmbeds, qTrans, kTrans, vTrans = Mv2Samedevice([rowEmbeds, self.qTrans, self.kTrans, self.vTrans]) + # self.qTrans = nn.Parameter(qTrans.to(rowEmbeds.dtype)) + # self.kTrans = nn.Parameter(kTrans.to(rowEmbeds.dtype)) + # self.vTrans = nn.Parameter(vTrans.to(rowEmbeds.dtype)) + qEmbeds = (rowEmbeds @ self.qTrans).view([evar, self.args.head, self.args.att_d_model // self.args.head]) + kEmbeds = (colEmbeds @ self.kTrans).view([evar, self.args.head, self.args.att_d_model // self.args.head]) + vEmbeds = (colEmbeds @ self.vTrans).view([evar, self.args.head, self.args.att_d_model // self.args.head]) + + att = torch.einsum('ehd, ehd -> eh', qEmbeds, kEmbeds) + att = torch.clamp(att, -10.0, 10.0) + expAtt = torch.exp(att) + + tem = torch.zeros([nvar, self.args.head]).to(expAtt.device, dtype=expAtt.dtype) + # print(tem.device, expAtt.device, rows.device) + rows = rows.to(expAtt.device) + attNorm = (tem.index_add_(0, rows, expAtt))[rows, :] + att = expAtt / (attNorm + 1e-8) # bleh + + resEmbeds = torch.einsum('eh, ehd -> ehd', att, vEmbeds).view([evar, self.args.att_d_model]) + tem = torch.zeros([nvar, self.args.att_d_model]).to(resEmbeds.device, dtype=resEmbeds.dtype) + rows = rows.to(resEmbeds.device) + tem = tem.to(resEmbeds.dtype) + resEmbeds = 
tem.index_add_(0, rows, resEmbeds) # nd + resEmbeds = resEmbeds + embeds + if self.args.att_norm: + # resEmbeds, norm_weight, norm_bias = Mv2Samedevice([resEmbeds, self.norm.weight, self.norm.bias]) + # self.norm.weight = nn.Parameter(norm_weight.to(resEmbeds.dtype)) + # self.norm.bias = nn.Parameter(norm_bias.to(resEmbeds.dtype)) + resEmbeds = self.norm(resEmbeds) + + return resEmbeds + +class GraphLlamaModel(LlamaModel): + config_class = GraphLlamaConfig + + def __init__(self, config: LlamaConfig): + super(GraphLlamaModel, self).__init__(config) + + if hasattr(config, "graph_tower"): + # HACK: for FSDP + # self.vision_tower = [CLIPVisionModel.from_pretrained(config.graph_tower)] + # self.arxiv_projector = nn.Linear(config.graph_hidden_size, config.hidden_size) + clip_graph, args= load_model_pretrained(CLIP, config.pretrain_graph_model_path) + self.graph_tower = graph_transformer(args) + self.graph_tower = transfer_param_tograph(clip_graph, self.graph_tower) + + + + # self.vision_tower = CLIPVisionModel.from_pretrained(config.mm_vision_tower) + + if hasattr(config, "use_graph_proj"): + self.graph_projector = nn.Linear(in_features=config.graph_hidden_size, out_features=config.hidden_size) + + def get_graph_tower(self): + graph_tower = getattr(self, 'graph_tower', None) + if type(graph_tower) is list: + graph_tower = graph_tower[0] + return graph_tower + + def initialize_graph_modules(self, graph_tower, graph_select_layer, + pretrain_graph_mlp_adapter=None, fsdp=None): # TODO: modify this function + self.config.graph_tower = graph_tower + + + if not hasattr(self, 'graph_tower'): + clip_graph, args= load_model_pretrained(CLIP, self.config.pretrain_graph_model_path) + graph_tower = graph_transformer(args) + graph_tower = transfer_param_tograph(clip_graph, graph_tower) + else: + graph_tower = self.graph_tower + graph_tower.requires_grad_(False) + + if fsdp is not None and len(fsdp) > 0: + self.graph_tower = [graph_tower] + else: + self.graph_tower = graph_tower + + + + self.config.use_graph_proj = True + self.config.graph_select_layer = graph_select_layer + + if not hasattr(self, 'graph_projector'): + self.graph_projector = nn.Linear(in_features=self.config.graph_hidden_size, out_features=self.config.hidden_size) + + if pretrain_graph_mlp_adapter is not None: + graph_projector_weights = torch.load(pretrain_graph_mlp_adapter, map_location='cpu') + self.graph_projector.load_state_dict({k.split('.')[-1]: v for k, v in graph_projector_weights.items()}) + + def forward( + self, + input_ids = None, + attention_mask = None, + past_key_values = None, + inputs_embeds = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + # graph_node_reps: Optional[torch.FloatTensor] = None, + # edge_index_reps: Optional[torch.FloatTensor] = None, + graph_data: Optional[Data] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + + # HACK: replace back original embeddings for LLaVA pretraining + orig_embeds_params = getattr(self, 'orig_embeds_params', None) + # if orig_embeds_params is not None: + # orig_embeds_params = orig_embeds_params[0] + # with torch.no_grad(): + # self.get_input_embeddings().weight.data[:-2] = orig_embeds_params[:-2].data + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + graph_tower = self.get_graph_tower() + # graph_data = graph_data[0] + if graph_tower is not None and (input_ids.shape[1] != 1 or self.training) and graph_data is not None: + # 
TODO: this is a modified multimodal LLM -- Haotian Liu + with torch.no_grad(): + if type(graph_data) is list: + # variable length images + graph_node_features = [] + if type(graph_data[0]) is Data: + for g in graph_data: + # print(g) + node_forward_out = graph_tower(g) + graph_node_features.append(node_forward_out) + elif type(graph_data[0]) is dict: + for g_dict in graph_data: + node_forward_out_1 = graph_tower(g_dict['graph_1']) + node_forward_out_2 = graph_tower(g_dict['graph_2']) + graph_node_features.append(node_forward_out_1) + graph_node_features.append(node_forward_out_2) + else: + raise ValueError(f'graph_node_reps is expected to be a list but got {type(graph_data)}') + if type(graph_data) is list: + # if type(graph_node_features[0]) is not dict: + graph_node_features = [self.graph_projector(node_feature) for node_feature in graph_node_features] + # else: + # graph_node_features = [{'graph_1': self.graph_projector(node_feature['graph_1']), 'graph_2': self.graph_projector(node_feature['graph_2'])} for node_feature in graph_node_features] + else: + raise ValueError(f'graph_node_reps is expected to be a list but got {type(graph_data)}') + dummy_graph_features = torch.zeros(256, 128, device=inputs_embeds.device, dtype=inputs_embeds.dtype) + dummy_graph_features = self.graph_projector(dummy_graph_features) + + new_input_embeds = [] + cur_graph_idx = 0 + for cur_input_ids, cur_input_embeds in zip(input_ids, inputs_embeds): + if (cur_input_ids == graph_tower.config.graph_patch_token).sum() == 0: + # multimodal LLM, but the current sample is not multimodal + cur_input_embeds = cur_input_embeds + (0. * dummy_graph_features).sum() + new_input_embeds.append(cur_input_embeds) + cur_graph_idx += 1 + continue + if graph_tower.config.use_graph_start_end: + cur_graph_features = graph_node_features[cur_graph_idx] + num_patches = cur_graph_features.shape[0] + if (cur_input_ids == graph_tower.config.graph_start_token).sum() != (cur_input_ids == graph_tower.config.graph_end_token).sum(): + raise ValueError("The number of graph start tokens and graph end tokens should be the same.") + graph_start_tokens = (cur_input_ids == graph_tower.config.graph_start_token).nonzero().squeeze(dim=0) + # print(graph_start_tokens) + for graph_start_token_pos in graph_start_tokens: + cur_graph_features = graph_node_features[cur_graph_idx].to(device=cur_input_embeds.device) + num_patches = cur_graph_features.shape[0] + if cur_input_ids[graph_start_token_pos + num_patches + 1] != graph_tower.config.graph_end_token: + raise ValueError("The graph end token should follow the graph start token.") + if orig_embeds_params is not None: + cur_new_input_embeds = tlx.concat((cur_input_embeds[:graph_start_token_pos].detach(), cur_input_embeds[graph_start_token_pos:graph_start_token_pos+1], cur_graph_features, cur_input_embeds[graph_start_token_pos + num_patches + 1:graph_start_token_pos + num_patches + 2], cur_input_embeds[graph_start_token_pos + num_patches + 2:].detach()), axis=0) + else: + cur_new_input_embeds = tlx.concat((cur_input_embeds[:graph_start_token_pos+1], cur_graph_features, cur_input_embeds[graph_start_token_pos + num_patches + 1:]), axis=0) + cur_graph_idx += 1 + new_input_embeds.append(cur_new_input_embeds) + else: + cur_graph_features = graph_node_features[cur_graph_idx] + num_patches = cur_graph_features.shape[0] + if (cur_input_ids == graph_tower.config.graph_patch_token).sum() != num_patches: + raise ValueError("The number of graph patch tokens should be the same as the number of graph patches.") + 
masked_indices = (cur_input_ids == graph_tower.config.graph_patch_token).nonzero().squeeze(dim=0) + mask_index_start = masked_indices[0] + if (masked_indices != tlx.arange(mask_index_start, mask_index_start+num_patches, device=masked_indices.device, dtype=masked_indices.dtype)).any(): + raise ValueError("The graph patch tokens should be consecutive.") + if orig_embeds_params is not None: + cur_new_input_embeds = tlx.concat((cur_input_embeds[:mask_index_start].detach(), cur_graph_features, cur_input_embeds[mask_index_start+num_patches:].detach()), axis=0) + else: + cur_new_input_embeds = tlx.concat((cur_input_embeds[:mask_index_start], cur_graph_features, cur_input_embeds[mask_index_start+num_patches:]), axis=0) + new_input_embeds.append(cur_new_input_embeds) + cur_graph_idx += 1 + + # print(cur_graph_idx) + # print(len(graph_node_features)) + assert cur_graph_idx == len(graph_node_features) + inputs_embeds = tlx.stack(new_input_embeds, axis=0) + + return super(GraphLlamaModel, self).forward( + input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values, + inputs_embeds=inputs_embeds, use_cache=use_cache, + output_attentions=output_attentions, output_hidden_states=output_hidden_states, + return_dict=return_dict + ) + + +class GraphLlamaForCausalLM(LlamaForCausalLM): + r"""GraphGPT model from the + `"GraphGPT: Graph Instruction Tuning for Large Language Models" + `_ paper. Based on LlamaForCausalLM with graph encoder implemented. + + Parameters + ---------- + config: GraphLlamaConfig + Defined as in the config.json file in the model directory + """ + config_class = GraphLlamaConfig + + def __init__(self, config): + super(LlamaForCausalLM, self).__init__(config) + self.model = GraphLlamaModel(config) + + self.lm_head = nn.Linear(in_features=config.hidden_size, out_features=config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_model(self): + return self.model + + def get_graph_tower(self): + return self.get_model().get_graph_tower() + + def get_vision_tower(self): + model = self.get_model() + graph_tower = model.graph_tower + if type(graph_tower) is list: + graph_tower = graph_tower[0] + return graph_tower + + def forward( + self, + input_ids = None, + attention_mask = None, + past_key_values = None, + inputs_embeds = None, + labels = None, + use_cache = None, + output_attentions = None, + output_hidden_states: Optional[bool] = None, + # graph_node_reps: Optional[torch.FloatTensor] = None, + # edge_index_reps: Optional[torch.FloatTensor] = None, + graph_data: Optional[Data] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + # graph_node_reps=graph_node_reps, + # edge_index_reps=edge_index_reps + graph_data = graph_data + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + + loss 
= None
+ if labels is not None:
+ # Shift so that tokens < n predict n
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
+ shift_labels = shift_labels.view(-1)
+ # Enable model/pipeline parallelism
+ shift_labels = shift_labels.to(shift_logits.device)
+ # token-level cross entropy over the vocabulary (binary_cross_entropy cannot take class-index labels)
+ loss_fct = nn.CrossEntropyLoss()
+ loss = loss_fct(shift_logits, shift_labels)
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return (loss,) + output if loss is not None else output
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def prepare_inputs_for_generation(
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
+ ):
+ if past_key_values:
+ input_ids = input_ids[:, -1:]
+
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+ if inputs_embeds is not None and past_key_values is None:
+ model_inputs = {"inputs_embeds": inputs_embeds}
+ else:
+ model_inputs = {"input_ids": input_ids}
+
+ model_inputs.update(
+ {
+ "past_key_values": past_key_values,
+ "use_cache": kwargs.get("use_cache"),
+ "attention_mask": attention_mask,
+ "graph_data": [kwargs.get("graph_data", None)],
+ # "edge_index_reps": kwargs.get("edge_index_reps", None),
+ }
+ )
+ return model_inputs
+
+ def initialize_graph_tokenizer(self, use_graph_start_end, tokenizer, device,
+ tune_graph_mlp_adapter=False, pretrain_graph_mlp_adapter=None):
+ vision_config = self.get_graph_tower().config
+ vision_config.use_graph_start_end = use_graph_start_end
+ tokenizer.add_tokens([DEFAULT_GRAPH_PATCH_TOKEN], special_tokens=True)
+ self.resize_token_embeddings(len(tokenizer))
+
+ if use_graph_start_end:
+ num_new_tokens = tokenizer.add_tokens([DEFAULT_G_START_TOKEN, DEFAULT_G_END_TOKEN], special_tokens=True)
+ self.resize_token_embeddings(len(tokenizer))
+ vision_config.graph_start_token, vision_config.graph_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_G_START_TOKEN, DEFAULT_G_END_TOKEN])
+
+ if num_new_tokens > 0:
+ input_embeddings = self.get_input_embeddings().weight.data
+ output_embeddings = self.get_output_embeddings().weight.data
+
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
+ dim=0, keepdim=True)
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
+ dim=0, keepdim=True)
+
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
+
+ if tune_graph_mlp_adapter:
+ self.get_model().orig_embeds_params = [self.get_input_embeddings().weight.data.clone().to(device=device)]
+ for p in self.get_input_embeddings().parameters():
+ p.requires_grad = True
+ for p in self.get_output_embeddings().parameters():
+ p.requires_grad = False
+
+ if pretrain_graph_mlp_adapter:
+ mm_projector_weights = torch.load(pretrain_graph_mlp_adapter, map_location='cpu')
+ embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
+ assert num_new_tokens == 2
+ if input_embeddings.shape == embed_tokens_weight.shape:
+ input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
+ elif embed_tokens_weight.shape[0] == num_new_tokens:
+ input_embeddings[-num_new_tokens:] = embed_tokens_weight
+ else:
+ raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. 
Numer of new tokens: {num_new_tokens}.") + + vision_config.graph_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_GRAPH_PATCH_TOKEN])[0] + def forward_no_loss( + self, + input_ids = None, + attention_mask = None, + past_key_values = None, + inputs_embeds = None, + labels = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + # graph_node_reps: Optional[torch.FloatTensor] = None, + # edge_index_reps: Optional[torch.FloatTensor] = None, + graph_data: Optional[Data] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + # graph_node_reps=graph_node_reps, + # edge_index_reps=edge_index_reps + graph_data = graph_data + ) + + + return CausalLMOutputWithPast( + loss=0, + logits=None, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + +AutoConfig.register("GraphLlama", GraphLlamaConfig) +AutoModelForCausalLM.register(GraphLlamaConfig, GraphLlamaForCausalLM) diff --git a/gammagl/models/llaga.py b/gammagl/models/llaga.py index 994eeebd..c613e0bd 100644 --- a/gammagl/models/llaga.py +++ b/gammagl/models/llaga.py @@ -434,6 +434,15 @@ def __init__(self, config: LlamaConfig): class LlagaLlamaForCausalLM(LlamaForCausalLM, LlagaMetaForCausalLM): + r"""LLaGA model from the + `"LLaGA: Large Language and Graph Assistant" + `_ paper. Based on LlamaForCausalLM with additional graph encoder implemented. + + Parameters + ---------- + config: LlagaConfig + Defined in the config.json file in the model directory + """ config_class = LlagaConfig def __init__(self, config): diff --git a/gammagl/models/simple_tokenizer.py b/gammagl/models/simple_tokenizer.py new file mode 100644 index 00000000..0a66286b --- /dev/null +++ b/gammagl/models/simple_tokenizer.py @@ -0,0 +1,132 @@ +import gzip +import html +import os +from functools import lru_cache + +import ftfy +import regex as re + + +@lru_cache() +def default_bpe(): + return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a signficant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. 
+ """
+ bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
+ cs = bs[:]
+ n = 0
+ for b in range(2**8):
+ if b not in bs:
+ bs.append(b)
+ cs.append(2**8+n)
+ n += 1
+ cs = [chr(n) for n in cs]
+ return dict(zip(bs, cs))
+
+
+def get_pairs(word):
+ """Return set of symbol pairs in a word.
+ Word is represented as tuple of symbols (symbols being variable-length strings).
+ """
+ pairs = set()
+ prev_char = word[0]
+ for char in word[1:]:
+ pairs.add((prev_char, char))
+ prev_char = char
+ return pairs
+
+
+def basic_clean(text):
+ text = ftfy.fix_text(text)
+ text = html.unescape(html.unescape(text))
+ return text.strip()
+
+
+def whitespace_clean(text):
+ text = re.sub(r'\s+', ' ', text)
+ text = text.strip()
+ return text
+
+
+class SimpleTokenizer(object):
+ def __init__(self, bpe_path: str = default_bpe()):
+ self.byte_encoder = bytes_to_unicode()
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
+ merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
+ merges = merges[1:49152-256-2+1]
+ merges = [tuple(merge.split()) for merge in merges]
+ vocab = list(bytes_to_unicode().values())
+ vocab = vocab + [v+'</w>' for v in vocab]
+ for merge in merges:
+ vocab.append(''.join(merge))
+ vocab.extend(['<|startoftext|>', '<|endoftext|>'])
+ self.encoder = dict(zip(vocab, range(len(vocab))))
+ self.decoder = {v: k for k, v in self.encoder.items()}
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
+ self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
+ self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
+
+ def bpe(self, token):
+ if token in self.cache:
+ return self.cache[token]
+ word = tuple(token[:-1]) + ( token[-1] + '</w>',)
+ pairs = get_pairs(word)
+
+ if not pairs:
+ return token+'</w>'
+
+ while True:
+ bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
+ if bigram not in self.bpe_ranks:
+ break
+ first, second = bigram
+ new_word = []
+ i = 0
+ while i < len(word):
+ try:
+ j = word.index(first, i)
+ new_word.extend(word[i:j])
+ i = j
+ except:
+ new_word.extend(word[i:])
+ break
+
+ if word[i] == first and i < len(word)-1 and word[i+1] == second:
+ new_word.append(first+second)
+ i += 2
+ else:
+ new_word.append(word[i])
+ i += 1
+ new_word = tuple(new_word)
+ word = new_word
+ if len(word) == 1:
+ break
+ else:
+ pairs = get_pairs(word)
+ word = ' '.join(word)
+ self.cache[token] = word
+ return word
+
+ def encode(self, text):
+ bpe_tokens = []
+ text = whitespace_clean(basic_clean(text)).lower()
+ for token in re.findall(self.pat, text):
+ token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
+ bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
+ return bpe_tokens
+
+ def decode(self, tokens):
+ text = ''.join([self.decoder[token] for token in tokens])
+ text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
+ return text
diff --git a/gammagl/utils/conversation.py b/gammagl/utils/conversation.py
index e088a9db..81653190 100644
--- a/gammagl/utils/conversation.py
+++ b/gammagl/utils/conversation.py
@@ -381,6 +381,19 @@ def dict(self):
 version="v1_mmtag",
 )
 
+conv_graphchat_v1 = Conversation(
+ system="You are GraphGPT, a large language and graph-structural assistant trained by HKUDS Lab."
+ "You are able to understand the graph structures that the user provides, and assist the user with a variety of tasks using natural language." + "Follow the instructions carefully and explain your answers in detail.", + roles=("USER", "ASSISTANT"), + version="v1", + messages=(), + offset=0, + sep_style=SeparatorStyle.TWO, + sep=" ", + sep2="", +) + default_conversation = conv_vicuna_v0 conv_templates = { "default": conv_vicuna_v0, @@ -397,6 +410,7 @@ def dict(self): "v1_mmtag": conv_llava_v1_mmtag, "llava_llama_2": conv_llava_llama_2, "llaga_llama_2": conv_llaga_llama_2, + "graphchat_v1": conv_graphchat_v1, "mpt": conv_mpt, } diff --git a/gammagl/utils/gfm_utils.py b/gammagl/utils/gfm_utils.py index 9cfad948..0317a2bc 100644 --- a/gammagl/utils/gfm_utils.py +++ b/gammagl/utils/gfm_utils.py @@ -1,4 +1,6 @@ import torch + +from transformers import AutoConfig, StoppingCriteria # from .multimodal_encoder.builder import build_vision_tower # from .multimodal_projector.builder import build_vision_projector @@ -14,6 +16,10 @@ DEFAULT_GRAPH_START_TOKEN = "" DEFAULT_GRAPH_END_TOKEN = "" DEFAULT_GRAPH_PAD_ID = -500 +DEFAULT_GRAPH_PATCH_TOKEN = "" +DEFAULT_G_START_TOKEN = "" +DEFAULT_G_END_TOKEN = "" + def disable_torch_init(): import torch @@ -50,4 +56,26 @@ def insert_separator(X, sep): if return_tensors == 'pt': return torch.tensor(input_ids, dtype=torch.long) raise ValueError(f'Unsupported tensor type: {return_tensors}') - return input_ids \ No newline at end of file + return input_ids + +class KeywordsStoppingCriteria(StoppingCriteria): + def __init__(self, keywords, tokenizer, input_ids): + self.keywords = keywords + self.keyword_ids = [tokenizer(keyword).input_ids for keyword in keywords] + self.keyword_ids = [keyword_id[0] for keyword_id in self.keyword_ids if type(keyword_id) is list and len(keyword_id) == 1] + self.tokenizer = tokenizer + self.start_len = None + self.input_ids = input_ids + + def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: + if self.start_len is None: + self.start_len = self.input_ids.shape[1] + else: + for keyword_id in self.keyword_ids: + if output_ids[0, -1] == keyword_id: + return True + outputs = self.tokenizer.batch_decode(output_ids[:, self.start_len:], skip_special_tokens=True)[0] + for keyword in self.keywords: + if keyword in outputs: + return True + return False \ No newline at end of file