# -*- coding: utf-8 -*-
# Copyright (C) 2016-2023 PyThaiNLP Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re

import pandas as pd
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


class WangChanGLM:
    def __init__(self):
        # Matches any run of characters outside the Thai block (ก-๙);
        # used to flag vocabulary entries that contain no Thai text.
        self.exclude_pattern = re.compile(r"[^ก-๙]+")
        self.PROMPT_DICT = {
            "prompt_input": (
                "<context>: {input}\n<human>: {instruction}\n<bot>: "
            ),
            "prompt_no_input": (
                "<human>: {instruction}\n<bot>: "
            ),
        }

    def is_exclude(self, text):
        return bool(self.exclude_pattern.search(text))

    def load_model(
        self,
        model_path,
        return_dict=True,
        load_in_8bit=False,
        device_map="auto",
        torch_dtype=torch.float16,
        offload_folder="./",
        low_cpu_mem_usage=True,
        **kwargs,
    ):
        self.model_path = model_path
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_path,
            return_dict=return_dict,
            load_in_8bit=load_in_8bit,
            device_map=device_map,
            torch_dtype=torch_dtype,
            offload_folder=offload_folder,
            low_cpu_mem_usage=low_cpu_mem_usage,
            **kwargs,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
        # Build a table of the tokenizer vocabulary and mark entries that
        # contain no Thai characters; their ids can then be suppressed
        # during generation.
        self.df = pd.DataFrame(
            self.tokenizer.vocab.items(), columns=["text", "idx"]
        )
        self.df["is_exclude"] = self.df.text.map(self.is_exclude)
        self.exclude_ids = self.df[self.df.is_exclude].idx.tolist()

    def gen_instruct(
        self,
        text,
        max_new_tokens=512,
        top_p=0.95,
        temperature=0.9,
        top_k=50,
        no_repeat_ngram_size=2,
        typical_p=1.0,
        thai_only=True,  # suppress non-Thai tokens at the start of generation
    ):
        batch = self.tokenizer(text, return_tensors="pt")
        with torch.cuda.amp.autocast():  # use torch.cpu.amp.autocast() on CPU
            if thai_only:
                output_tokens = self.model.generate(
                    input_ids=batch["input_ids"],
                    max_new_tokens=max_new_tokens,  # 512
                    begin_suppress_tokens=self.exclude_ids,
                    no_repeat_ngram_size=no_repeat_ngram_size,
                    # oasst k50
                    top_k=top_k,
                    top_p=top_p,  # 0.95
                    typical_p=typical_p,
                    temperature=temperature,  # 0.9
                )
            else:
                output_tokens = self.model.generate(
                    input_ids=batch["input_ids"],
                    max_new_tokens=max_new_tokens,  # 512
                    no_repeat_ngram_size=no_repeat_ngram_size,
                    # oasst k50
                    top_k=top_k,
                    top_p=top_p,  # 0.95
                    typical_p=typical_p,
                    temperature=temperature,  # 0.9
                )
        # Return only the newly generated tokens, skipping the prompt.
        return self.tokenizer.decode(
            output_tokens[0][len(batch["input_ids"][0]):],
            skip_special_tokens=True,
        )

    def instruct_generate(
        self,
        instruct: str,
        context: str = None,
        max_gen_len=512,
        temperature: float = 0.9,
        top_p: float = 0.95,
        top_k=50,
        no_repeat_ngram_size=2,
        typical_p=1.0,
    ):
        # Pick the prompt template based on whether a context passage is given.
        if context is None or context == "":
            prompt = self.PROMPT_DICT["prompt_no_input"].format_map(
                {"instruction": instruct, "input": ""}
            )
        else:
            prompt = self.PROMPT_DICT["prompt_input"].format_map(
                {"instruction": instruct, "input": context}
            )
        result = self.gen_instruct(
            prompt,
            max_new_tokens=max_gen_len,
            top_p=top_p,
            top_k=top_k,
            temperature=temperature,
            no_repeat_ngram_size=no_repeat_ngram_size,
            typical_p=typical_p,
        )
        return result
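
# ---------------------------------------------------------------------------
# Usage sketch. Everything below is illustrative and commented out: the model
# path is an assumption (any decoder-only checkpoint that loads through
# AutoModelForCausalLM should work), and all generation settings fall back to
# the defaults defined above.
#
#     model = WangChanGLM()
#     model.load_model("pythainlp/wangchanglm-7.5B-sft-enth-sharded")
#
#     # Instruction only ("Write a poem about the sea"); uses the
#     # "prompt_no_input" template:
#     print(model.instruct_generate("แต่งกลอนเกี่ยวกับทะเล"))
#
#     # Instruction plus a context passage; uses the "prompt_input" template:
#     print(model.instruct_generate(
#         "สรุปข้อความนี้",  # "Summarize this text"
#         context="...",
#     ))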