Commit

Add files via upload

first commit

yangjianxin1 authored Jun 15, 2021
1 parent aac84ea commit fe9eee3
Showing 13 changed files with 31,193 additions and 0 deletions.
319 changes: 319 additions & 0 deletions README.md

Large diffs are not rendered by default.

33 changes: 33 additions & 0 deletions config/cpm-medium.json
@@ -0,0 +1,33 @@
{
"activation_function": "gelu_new",
"architectures": [
"GPT2LMHeadModel"
],
"attn_pdrop": 0.1,
"bos_token_id": 1,
"embd_pdrop": 0.1,
"eos_token_id": 2,
"initializer_range": 0.02,
"layer_norm_epsilon": 1e-05,
"model_type": "gpt2",
"n_ctx": 1024,
"n_embd": 1024,
"n_head": 16,
"n_layer": 24,
"n_positions": 1024,
"n_special": 0,
"predict_special_tokens": true,
"resid_pdrop": 0.1,
"summary_activation": null,
"summary_first_dropout": 0.1,
"summary_proj_to_labels": true,
"summary_type": "cls_index",
"summary_use_proj": true,
"task_specific_params": {
"text-generation": {
"do_sample": true,
"max_length": 50
}
},
"vocab_size": 30000
}
31 changes: 31 additions & 0 deletions config/cpm-small.json
@@ -0,0 +1,31 @@
{
"activation_function": "gelu_new",
"architectures": [
"GPT2LMHeadModel"
],
"attn_pdrop": 0.1,
"bos_token_id": 50256,
"embd_pdrop": 0.1,
"eos_token_id": 50256,
"initializer_range": 0.02,
"layer_norm_epsilon": 1e-05,
"model_type": "gpt2",
"n_ctx": 1024,
"n_embd": 768,
"n_head": 12,
"n_layer": 12,
"n_positions": 1024,
"resid_pdrop": 0.1,
"summary_activation": null,
"summary_first_dropout": 0.1,
"summary_proj_to_labels": true,
"summary_type": "cls_index",
"summary_use_proj": true,
"task_specific_params": {
"text-generation": {
"do_sample": true,
"max_length": 50
}
},
"vocab_size": 30000
}
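
Both config files follow the Hugging Face GPT2Config schema, so either one can be used to build an untrained model of the matching size. A minimal sketch, assuming it is run from the repository root:

from transformers import GPT2Config, GPT2LMHeadModel

# Build an untrained CPM-small model (12 layers, 768 hidden units, 30,000-token vocab)
config = GPT2Config.from_json_file("config/cpm-small.json")
model = GPT2LMHeadModel(config)
print(model.num_parameters())  # on the order of 100M parameters for the small config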
109 changes: 109 additions & 0 deletions data_parallel.py
@@ -0,0 +1,109 @@

from torch.nn.parallel import DataParallel
import torch
from torch.nn.parallel._functions import Scatter
from torch.nn.parallel.parallel_apply import parallel_apply

def scatter(inputs, target_gpus, chunk_sizes, dim=0):
r"""
Slices tensors into approximately equal chunks and
distributes them across given GPUs. Duplicates
references to objects that are not tensors.
"""
def scatter_map(obj):
if isinstance(obj, torch.Tensor):
try:
return Scatter.apply(target_gpus, chunk_sizes, dim, obj)
            except Exception:
print('obj', obj.size())
print('dim', dim)
print('chunk_sizes', chunk_sizes)
quit()
if isinstance(obj, tuple) and len(obj) > 0:
return list(zip(*map(scatter_map, obj)))
if isinstance(obj, list) and len(obj) > 0:
return list(map(list, zip(*map(scatter_map, obj))))
if isinstance(obj, dict) and len(obj) > 0:
return list(map(type(obj), zip(*map(scatter_map, obj.items()))))
        return [obj for _ in target_gpus]  # replicate non-tensor objects, one reference per GPU

# After scatter_map is called, a scatter_map cell will exist. This cell
# has a reference to the actual function scatter_map, which has references
# to a closure that has a reference to the scatter_map cell (because the
# fn is recursive). To avoid this reference cycle, we set the function to
# None, clearing the cell
try:
return scatter_map(inputs)
finally:
scatter_map = None

def scatter_kwargs(inputs, kwargs, target_gpus, chunk_sizes, dim=0):
r"""Scatter with support for kwargs dictionary"""
inputs = scatter(inputs, target_gpus, chunk_sizes, dim) if inputs else []
kwargs = scatter(kwargs, target_gpus, chunk_sizes, dim) if kwargs else []
if len(inputs) < len(kwargs):
inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
elif len(kwargs) < len(inputs):
kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
inputs = tuple(inputs)
kwargs = tuple(kwargs)
return inputs, kwargs

class BalancedDataParallel(DataParallel):
def __init__(self, gpu0_bsz, *args, **kwargs):
self.gpu0_bsz = gpu0_bsz
super().__init__(*args, **kwargs)

def forward(self, *inputs, **kwargs):
if not self.device_ids:
return self.module(*inputs, **kwargs)
if self.gpu0_bsz == 0:
device_ids = self.device_ids[1:]
else:
device_ids = self.device_ids
inputs, kwargs = self.scatter(inputs, kwargs, device_ids)

if len(self.device_ids) == 1:
return self.module(*inputs[0], **kwargs[0])
if self.gpu0_bsz == 0:
replicas = self.replicate(self.module, self.device_ids)
else:
replicas = self.replicate(self.module, self.device_ids[:len(inputs)])

if self.gpu0_bsz == 0:
replicas = replicas[1:]

outputs = self.parallel_apply(replicas, device_ids, inputs, kwargs)
return self.gather(outputs, self.output_device)

def parallel_apply(self, replicas, device_ids, inputs, kwargs):
return parallel_apply(replicas, inputs, kwargs, device_ids[:len(inputs)])

    def scatter(self, inputs, kwargs, device_ids):
        bsz = inputs[0].size(self.dim)
        num_dev = len(self.device_ids)
        gpu0_bsz = self.gpu0_bsz
        if num_dev <= 1:
            # The balanced split below needs at least two devices; fall back to the default scatter
            return super().scatter(inputs, kwargs, device_ids)
        bsz_unit = (bsz - gpu0_bsz) // (num_dev - 1)
if gpu0_bsz < bsz_unit:
chunk_sizes = [gpu0_bsz] + [bsz_unit] * (num_dev - 1)
delta = bsz - sum(chunk_sizes)
for i in range(delta):
chunk_sizes[i + 1] += 1
if gpu0_bsz == 0:
chunk_sizes = chunk_sizes[1:]
else:
return super().scatter(inputs, kwargs, device_ids)

return scatter_kwargs(inputs, kwargs, device_ids, chunk_sizes, dim=self.dim)
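
BalancedDataParallel extends torch.nn.DataParallel so that GPU 0, which also holds the gathered outputs, can take a smaller share of each batch. A hedged usage sketch; the stand-in model, batch size, and three-GPU setup below are placeholder assumptions:

import torch
import torch.nn as nn

# Assume 3 visible GPUs; put only 8 of the 32 samples on GPU 0
model = nn.Linear(1024, 1024).cuda()  # stand-in for the GPT-2 model
model = BalancedDataParallel(8, model, dim=0)  # gpu0_bsz=8
x = torch.randn(32, 1024).cuda()
y = model(x)  # scattered as chunks of [8, 12, 12] across the 3 GPUs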

21 changes: 21 additions & 0 deletions dataset.py
@@ -0,0 +1,21 @@
from torch.utils.data import Dataset
import torch


class CPMDataset(Dataset):
    """
    Wraps pre-tokenized samples (lists of token ids) and truncates each one to max_len.
    """

def __init__(self, input_list, max_len):
self.input_list = input_list
self.max_len = max_len

def __getitem__(self, index):
input_ids = self.input_list[index]
input_ids = input_ids[:self.max_len]
input_ids = torch.tensor(input_ids, dtype=torch.long)
return input_ids

def __len__(self):
return len(self.input_list)
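
Since __getitem__ returns variable-length tensors, a DataLoader over this dataset needs a collate function that pads each batch. A minimal sketch; the padding value is an assumption and should come from the CPM tokenizer in practice:

from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

def collate_fn(batch):
    # Pad every sample in the batch to the length of the longest one
    return pad_sequence(batch, batch_first=True, padding_value=0)  # assumed pad id; use tokenizer.pad_token_id in practice

dataset = CPMDataset(input_list=[[11, 12, 13], [21, 22]], max_len=1024)
loader = DataLoader(dataset, batch_size=2, collate_fn=collate_fn)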
124 changes: 124 additions & 0 deletions generate.py
@@ -0,0 +1,124 @@
import torch
import torch.nn.functional as F
import os
import argparse
from tqdm import trange
from transformers import GPT2LMHeadModel, GPT2Config, CpmTokenizer
from utils import top_k_top_p_filtering, set_logger
from os.path import join, exists


def generate_next_token(input_ids):
    """
    Generate the next token for the given context.
    """
    outputs = model(input_ids=input_ids)
    logits = outputs.logits
    # next_token_logits holds the prediction scores at the last position, i.e. the model's distribution over the next token
    next_token_logits = logits[0, -1, :]
    next_token_logits = next_token_logits / args.temperature
    # Set the logit of <unk> to -inf so the model can never predict the [UNK] token
    next_token_logits[unk_id] = -float('Inf')
    filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=args.topk, top_p=args.topp)
    # torch.multinomial samples num_samples elements from the candidates without replacement; higher weights are more likely to be drawn; the element indices are returned
    next_token_id = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
    return next_token_id


def generate(max_len):
    # Tokenize the title and the context
    title_ids = tokenizer.encode(title, add_special_tokens=False)
    context_ids = tokenizer.encode(context, add_special_tokens=False)
    input_ids = title_ids + [sep_id] + context_ids
    cur_len = len(input_ids)
    last_token_id = input_ids[-1]  # last token of the content generated so far
    input_ids = torch.tensor([input_ids], dtype=torch.long, device=device)

    while True:
        next_token_id = generate_next_token(input_ids)
        input_ids = torch.cat((input_ids, next_token_id.unsqueeze(0)), dim=1)
        cur_len += 1
        word = tokenizer.convert_ids_to_tokens(next_token_id.item())
        # Stop once max_len is exceeded and a line break has just been produced
        if cur_len >= max_len and last_token_id == 8 and next_token_id == 3:
            break
        # Stop once max_len is exceeded and a punctuation mark is produced
        if cur_len >= max_len and word in [".", "。", "!", "!", "?", "?", ",", ","]:
            break
        # Stop at the end-of-document token
        if next_token_id == eod_id:
            break
        last_token_id = next_token_id.item()  # keep last_token_id in sync for the line-break check above
    result = tokenizer.decode(input_ids.squeeze(0))
    return result


if __name__ == '__main__':
    # Argument parsing
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', default='0', type=str, required=False, help='device used for generation')
    parser.add_argument('--temperature', default=1, type=float, required=False, help='sampling temperature')
    parser.add_argument('--topk', default=0, type=int, required=False, help='top-k sampling: keep only the k highest-probability tokens')
    parser.add_argument('--topp', default=0.85, type=float, required=False, help='top-p (nucleus) sampling: cumulative probability threshold')
    parser.add_argument('--repetition_penalty', default=1.0, type=float, required=False, help='repetition penalty')
    parser.add_argument('--max_len', default=200, type=int, required=False, help='maximum generation length')
    parser.add_argument('--log_path', default='log/generate.log', type=str, required=False, help='path to the log file')
    parser.add_argument('--no_cuda', action='store_true', help='run inference without a GPU')
    parser.add_argument('--model_path', type=str, default='model/zuowen_epoch40', help='path to the model directory')
    # parser.add_argument('--title', type=str, default='徜徉在书籍的阳光世界', help='essay title')
    # parser.add_argument('--context', type=str, default='一本书是一个人的眼睛,它可以让你看到另一个世界的奇妙', help='essay opening text')
    parser.add_argument('--title', type=str, default='家乡的四季', help='essay title')
    parser.add_argument('--context', type=str, default='家乡的四季,最美不过了', help='essay opening text')
    args = parser.parse_args()

    os.environ["CUDA_VISIBLE_DEVICES"] = args.device  # select which GPUs the program may use
    args.cuda = torch.cuda.is_available() and not args.no_cuda  # use the GPU only when one is available and not disabled
    device = 'cuda:0' if args.cuda else 'cpu'
# device = 'cpu'

    # Create the logger
logger = set_logger(args.log_path)

    # Initialize the tokenizer
tokenizer = CpmTokenizer(vocab_file="vocab/chinese_vocab.model")
    eod_id = tokenizer.convert_tokens_to_ids("<eod>")  # end-of-document token
sep_id = tokenizer.sep_token_id
unk_id = tokenizer.unk_token_id

    # Load the model
model = GPT2LMHeadModel.from_pretrained(args.model_path)
model.eval()
model = model.to(device)

title = args.title
context = args.context
logger.info("title:{}".format(title))
logger.info("context:{}".format(context))

    # Start generation
result = generate(args.max_len)
result = result.split("<sep>")[1]
logger.info("result:{}\n".format(result))

    # Interactive generation loop via the console
    # print('Generation started; press CTRL + Z to exit')
    # while True:
    #     try:
    #         # Read the title and context from the user
    #         title = input("Essay title: ")
    #         context = input("Opening sentence: ")
    #
    #         logger.info("title:{}".format(title))
    #         logger.info("context:{}".format(context))
    #
    #         # Start generation
    #         result = generate(args.max_len)
    #         result = result.split("<sep>")[1]
    #         logger.info("result:{}\n".format(result))
    #         break
    #
    #     except KeyboardInterrupt:
    #         break
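
utils.py is not rendered in this commit view. The imported top_k_top_p_filtering presumably follows the widely circulated Hugging Face example; a sketch of that reference implementation for the 1-D logits tensor used here:

import torch
import torch.nn.functional as F

def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    if top_k > 0:
        # Remove every token whose logit is below the k-th largest logit
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # Remove tokens once the cumulative probability exceeds top_p,
        # shifting the mask right so the most likely token is always kept
        sorted_indices_to_remove = cumulative_probs > top_p
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value
    return logits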

