
Commit ce1eeb2

Add wangchanglm to pythainlp.generate
1 parent 59c81ee commit ce1eeb2

File tree

2 files changed: +125 −2 lines changed


pythainlp/generate/decoder_model.py renamed to pythainlp/chat/core.py

Lines changed: 1 addition & 2 deletions
@@ -11,5 +11,4 @@
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
-# limitations under the License.
-# WIP
+# limitations under the License.

pythainlp/generate/wangchanglm.py

Lines changed: 124 additions & 0 deletions
@@ -0,0 +1,124 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2016-2023 PyThaiNLP Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re

import pandas as pd
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


class WangChanGLM:
    def __init__(self):
        # Matches any character outside the Thai Unicode range ก-๙.
        self.exclude_pattern = re.compile(r'[^ก-๙]+')
        self.PROMPT_DICT = {
            "prompt_input": (
                "<context>: {input}\n<human>: {instruction}\n<bot>: "
            ),
            "prompt_no_input": (
                "<human>: {instruction}\n<bot>: "
            ),
        }

    def is_exclude(self, text: str) -> bool:
        # True if the token text contains any non-Thai character.
        return bool(self.exclude_pattern.search(text))

    def load_model(
        self,
        model_path: str,
        return_dict: bool = True,
        load_in_8bit: bool = False,
        device_map: str = "auto",
        torch_dtype=torch.float16,
        offload_folder: str = "./",
        low_cpu_mem_usage: bool = True,
        **kwargs
    ):
        self.model_path = model_path
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_path,
            return_dict=return_dict,
            load_in_8bit=load_in_8bit,
            device_map=device_map,
            torch_dtype=torch_dtype,
            offload_folder=offload_folder,
            low_cpu_mem_usage=low_cpu_mem_usage,
            **kwargs
        )
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
        # Collect ids of vocabulary tokens that contain non-Thai characters
        # so they can be suppressed when Thai-only output is requested.
        self.df = pd.DataFrame(
            self.tokenizer.vocab.items(), columns=['text', 'idx']
        )
        self.df['is_exclude'] = self.df.text.map(self.is_exclude)
        self.exclude_ids = self.df[self.df.is_exclude].idx.tolist()

    def gen_instruct(
        self,
        text: str,
        max_new_tokens: int = 512,
        top_p: float = 0.95,
        temperature: float = 0.9,
        top_k: int = 50,
        no_repeat_ngram_size: int = 2,
        typical_p: float = 1.,
        Thai: str = "Yes",  # "Yes" suppresses tokens with non-Thai characters
    ):
        batch = self.tokenizer(text, return_tensors="pt")
        with torch.cuda.amp.autocast():  # use torch.cpu.amp.autocast() on CPU
            if Thai == "Yes":
                output_tokens = self.model.generate(
                    input_ids=batch["input_ids"],
                    max_new_tokens=max_new_tokens,  # 512
                    begin_suppress_tokens=self.exclude_ids,
                    no_repeat_ngram_size=no_repeat_ngram_size,
                    # oasst k50
                    top_k=top_k,
                    top_p=top_p,  # 0.95
                    typical_p=typical_p,
                    temperature=temperature,  # 0.9
                )
            else:
                output_tokens = self.model.generate(
                    input_ids=batch["input_ids"],
                    max_new_tokens=max_new_tokens,  # 512
                    no_repeat_ngram_size=no_repeat_ngram_size,
                    # oasst k50
                    top_k=top_k,
                    top_p=top_p,  # 0.95
                    typical_p=typical_p,
                    temperature=temperature,  # 0.9
                )
        # Return only the newly generated text, without the prompt tokens.
        return self.tokenizer.decode(
            output_tokens[0][len(batch["input_ids"][0]):],
            skip_special_tokens=True,
        )

    def instruct_generate(
        self,
        instruct: str,
        context: str = None,
        max_gen_len: int = 512,
        temperature: float = 0.9,
        top_p: float = 0.95,
        top_k: int = 50,
        no_repeat_ngram_size: int = 2,
        typical_p: float = 1.,
    ):
        if context is None or context == "":
            prompt = self.PROMPT_DICT['prompt_no_input'].format_map(
                {'instruction': instruct, 'input': ''}
            )
        else:
            prompt = self.PROMPT_DICT['prompt_input'].format_map(
                {'instruction': instruct, 'input': context}
            )
        result = self.gen_instruct(
            prompt,
            max_new_tokens=max_gen_len,
            top_p=top_p,
            top_k=top_k,
            temperature=temperature,
            no_repeat_ngram_size=no_repeat_ngram_size,
            typical_p=typical_p,
        )
        return result
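
For context, a minimal usage sketch of the class this commit adds (not part of the diff). It assumes an installed build of this branch, so the module is importable as pythainlp.generate.wangchanglm, and it assumes a WangChanGLM checkpoint is reachable; the checkpoint name below is a placeholder assumption, not something specified in this commit.

# Minimal usage sketch, assuming this branch is installed and a
# WangChanGLM checkpoint is available on the Hugging Face Hub or locally.
from pythainlp.generate.wangchanglm import WangChanGLM

model = WangChanGLM()
# Loads a multi-GB causal LM; a GPU (or 8-bit loading) is recommended.
model.load_model(
    model_path="pythainlp/wangchanglm-7.5B-sft-en-sharded",  # assumed checkpoint name
    device_map="auto",
)

# Instruction only -> formatted with the "prompt_no_input" template.
print(model.instruct_generate(instruct="เล่านิทานให้ฟังหน่อย"))  # "Tell me a story."

# Instruction plus context -> formatted with the "prompt_input" template.
print(
    model.instruct_generate(
        instruct="Summarize the passage in one sentence.",
        context="PyThaiNLP is a Thai natural language processing library for Python.",
        max_gen_len=256,
        temperature=0.7,
    )
)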
