-
Notifications
You must be signed in to change notification settings - Fork 337
/
Copy pathrinna_gpt2.py
88 lines (74 loc) · 2.63 KB
/
rinna_gpt2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import time
import sys
import os
import numpy
from utils_rinna_gpt2 import *
import ailia
sys.path.append('../../util')
from arg_utils import get_base_parser, update_parser # noqa: E402
from model_utils import check_and_download_models # noqa: E402
# logger
from logging import getLogger # noqa: E402
logger = getLogger(__name__)
# ======================
# Argument Parser Config
# ======================

# Default Japanese prompt: "The answer to the ultimate question of life,
# the universe, and everything is"
DEFAULT_TEXT = '生命、宇宙、そして万物についての究極の疑問の答えは'

parser = get_base_parser('rinna-gpt2 text generation', None, None)
# overwrite the base --input option: this model takes a text prompt
parser.add_argument(
    '--input', '-i', default=DEFAULT_TEXT,
    help='input prompt text'
)
parser.add_argument(
    # type=int so a command-line value arrives as int, matching the
    # int default (callers previously relied on int(args.outlength))
    '--outlength', '-o', type=int, default=50,
    help='number of tokens to generate'
)
parser.add_argument(
    '--onnx',
    action='store_true',
    help='By default, the ailia SDK is used, but with this option, you can switch to using ONNX Runtime'
)
parser.add_argument(
    '--disable_ailia_tokenizer',
    action='store_true',
    help='disable ailia tokenizer.'
)
args = update_parser(parser, check_input_type=False)
# ======================
# PARAMETERS
# ======================

# Optimized ONNX graph + ailia prototxt for rinna/japanese-gpt2-small
WEIGHT_PATH = "japanese-gpt2-small.opt.onnx"
MODEL_PATH = "japanese-gpt2-small.opt.onnx.prototxt"
# Public mirror the model files are downloaded from on first run
REMOTE_PATH = "https://storage.googleapis.com/ailia-models/rinna_gpt2/"
# ======================
# Main function
# ======================
def _load_model():
    # Select the inference backend requested on the command line.
    if args.onnx:
        import onnxruntime
        return onnxruntime.InferenceSession(WEIGHT_PATH)
    logger.info("This model requires multiple input shape, so running on CPU")
    # env_id pinned to 0 (CPU) on purpose — args.env_id deliberately unused,
    # per the log message above.
    return ailia.Net(MODEL_PATH, WEIGHT_PATH, env_id=0)


def _load_tokenizer():
    # Pick the HuggingFace tokenizer or the bundled ailia tokenizer.
    if args.disable_ailia_tokenizer:
        from transformers import T5Tokenizer
        return T5Tokenizer.from_pretrained("rinna/japanese-gpt2-small")
    from ailia_tokenizer import T5Tokenizer
    return T5Tokenizer.from_pretrained("./tokenizer/")


def main():
    """Generate text from the input prompt with rinna GPT-2 and log it."""
    model = _load_model()
    tokenizer = _load_tokenizer()

    logger.info("Input : "+args.input)

    # inference
    if not args.benchmark:
        output = generate_text(tokenizer, model, args.input, int(args.outlength), args.onnx)
    else:
        logger.info('BENCHMARK mode')
        # Five timed runs; the last run's output is the one reported below.
        for _ in range(5):
            start = int(round(time.time() * 1000))
            output = generate_text(tokenizer, model, args.input, int(args.outlength), args.onnx)
            end = int(round(time.time() * 1000))
            logger.info("\tailia processing time {} ms".format(end - start))

    logger.info("output : "+output)
    logger.info('Script finished successfully.')
if __name__ == "__main__":
    # model files check and download (fetched from REMOTE_PATH on first run)
    check_and_download_models(WEIGHT_PATH, MODEL_PATH, REMOTE_PATH)
    main()