from transformers import AutoTokenizer
from ctypes_bindings import *
from model_configs import model_configs
import threading
import time
import sys
import os
MODEL_PATH = "./models"
# Create a dict of various model configs, and then check the ./models directory if any exist
# This will become the content of the model selector's drop-down menu
def available_models():
    if not os.path.exists(MODEL_PATH):
        os.mkdir(MODEL_PATH)
    # Initialize the dict of available models as empty
    rkllm_model_files = {}
    # Populate the dictionary with found models and their base configurations
    for family, config in model_configs.items():
        for model, details in config["models"].items():
            filename = details["filename"]
            if os.path.exists(os.path.join(MODEL_PATH, filename)):
                rkllm_model_files[model] = {
                    "name": model,
                    "family": family,
                    "filename": filename,
                    "config": config["base_config"],
                }
    return rkllm_model_files
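# Illustrative shape of the returned dict (the model name and filename below are
# hypothetical; real entries come from model_configs and the files on disk):
#   {
#       "Llama-3.2-3B": {
#           "name": "Llama-3.2-3B",
#           "family": "llama",
#           "filename": "Llama-3.2-3B.rkllm",
#           "config": {...that family's base_config...},
#       },
#   }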
# Global state shared with the callback below: streamed text chunks, the last
# reported run state, and any incomplete multi-byte UTF-8 sequence carried over
# between callbacks
global_text = []
global_state = -1
split_byte_data = bytes(b"")
# Define the callback function
def callback_impl(result, userdata, state):
    global global_text, global_state, split_byte_data
    if state == LLMCallState.RKLLM_RUN_FINISH:
        global_state = state
        print("\n")
        sys.stdout.flush()
    elif state == LLMCallState.RKLLM_RUN_ERROR:
        global_state = state
        print("run error")
        sys.stdout.flush()
    elif state == LLMCallState.RKLLM_RUN_GET_LAST_HIDDEN_LAYER:
        '''
        When using the GET_LAST_HIDDEN_LAYER mode, the callback returns the memory pointer last_hidden_layer,
        the number of tokens num_tokens, and the size of the hidden layer embd_size.
        With these three parameters, the data can be copied out of last_hidden_layer.
        Note: the data must be retrieved during the current callback; if it is not read in time,
        the pointer is released by the next callback.
        '''
        if result.contents.last_hidden_layer.embd_size != 0 and result.contents.last_hidden_layer.num_tokens != 0:
            data_size = result.contents.last_hidden_layer.embd_size * result.contents.last_hidden_layer.num_tokens * ctypes.sizeof(ctypes.c_float)
            print(f"data_size: {data_size}")
            global_text.append(f"data_size: {data_size}\n")
            output_path = os.getcwd() + "/last_hidden_layer.bin"
            with open(output_path, "wb") as outFile:
                data = ctypes.cast(result.contents.last_hidden_layer.hidden_states, ctypes.POINTER(ctypes.c_float))
                float_array_type = ctypes.c_float * (data_size // ctypes.sizeof(ctypes.c_float))
                float_array = float_array_type.from_address(ctypes.addressof(data.contents))
                outFile.write(bytearray(float_array))
            print(f"Data saved to {output_path} successfully!")
            global_text.append(f"Data saved to {output_path} successfully!")
        else:
            print("Invalid hidden layer data.")
            global_text.append("Invalid hidden layer data.")
        global_state = state
        time.sleep(0.05)
        sys.stdout.flush()
    else:
        # Save the output token text and the RKLLM running state
        global_state = state
        # The token text can end mid-way through a multi-byte UTF-8 character;
        # if decoding fails, buffer the bytes and retry on the next callback
        try:
            if result.contents.text is not None:
                output = (split_byte_data + result.contents.text).decode('utf-8')
                global_text.append(output)
                print(output, end='')
                split_byte_data = bytes(b"")
        except UnicodeDecodeError:
            split_byte_data += result.contents.text
        sys.stdout.flush()
# Connect the callback function between the Python side and the C++ side
callback_type = ctypes.CFUNCTYPE(None, ctypes.POINTER(RKLLMResult), ctypes.c_void_p, ctypes.c_int)
callback = callback_type(callback_impl)
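# Note: `callback` must stay referenced for as long as the C library may invoke it
# (module scope here); if the CFUNCTYPE wrapper were garbage-collected, the native
# callback pointer would dangle.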
class RKLLMLoaderClass:
    def __init__(self, model="", qtype="w8a8", opt="1", hybrid_quant="1.0"):
        self.qtype = qtype
        self.opt = opt
        self.model = model
        self.hybrid_quant = hybrid_quant
        if self.model == "":
            print("No models loaded yet!")
        else:
            self.available_models = available_models()
            self.model = self.available_models[model]
            self.family = self.model["family"]
            self.model_path = os.path.join(MODEL_PATH, self.model["filename"])
            self.base_config = self.model["config"]
            self.model_name = self.model["name"]
            self.st_model_id = self.base_config["st_model_id"]
            self.system_prompt = self.base_config["system_prompt"]
            # Populate the RKLLM runtime parameters from the model's base config
            self.rkllm_param = RKLLMParam()
            self.rkllm_param.model_path = bytes(self.model_path, 'utf-8')
            self.rkllm_param.max_context_len = self.base_config["max_context_len"]
            self.rkllm_param.max_new_tokens = self.base_config["max_new_tokens"]
            self.rkllm_param.skip_special_token = True
            self.rkllm_param.top_k = self.base_config["top_k"]
            self.rkllm_param.top_p = self.base_config["top_p"]
            # self.rkllm_param.min_p = 0.1
            self.rkllm_param.temperature = self.base_config["temperature"]
            self.rkllm_param.repeat_penalty = self.base_config["repeat_penalty"]
            self.rkllm_param.frequency_penalty = self.base_config["frequency_penalty"]
            self.rkllm_param.presence_penalty = 0.0
            self.rkllm_param.mirostat = 0
            self.rkllm_param.mirostat_tau = 5.0
            self.rkllm_param.mirostat_eta = 0.1
            self.rkllm_param.is_async = False
            self.rkllm_param.img_start = "<image>".encode('utf-8')
            self.rkllm_param.img_end = "</image>".encode('utf-8')
            self.rkllm_param.img_content = "<unk>".encode('utf-8')
            self.rkllm_param.extend_param.base_domain_id = 0
            # Bind the RKLLM C API entry points and initialize the model handle
            self.handle = RKLLM_Handle_t()
            self.rkllm_init = rkllm_lib.rkllm_init
            self.rkllm_init.argtypes = [ctypes.POINTER(RKLLM_Handle_t), ctypes.POINTER(RKLLMParam), callback_type]
            self.rkllm_init.restype = ctypes.c_int
            self.rkllm_init(self.handle, self.rkllm_param, callback)
            self.rkllm_run = rkllm_lib.rkllm_run
            self.rkllm_run.argtypes = [RKLLM_Handle_t, ctypes.POINTER(RKLLMInput), ctypes.POINTER(RKLLMInferParam), ctypes.c_void_p]
            self.rkllm_run.restype = ctypes.c_int
            self.rkllm_abort = rkllm_lib.rkllm_abort
            self.rkllm_abort.argtypes = [RKLLM_Handle_t]
            self.rkllm_abort.restype = ctypes.c_int
            self.rkllm_destroy = rkllm_lib.rkllm_destroy
            self.rkllm_destroy.argtypes = [RKLLM_Handle_t]
            self.rkllm_destroy.restype = ctypes.c_int
    # Record the user's input prompt
    def get_user_input(self, user_message, history):
        history = history + [[user_message, None]]
        return "", history

    def tokens_to_ctypes_array(self, tokens, ctype):
        # Convert a Python list of token ids to a ctypes array
        # (the tokenizer outputs a plain Python list).
        return (ctype * len(tokens))(*tokens)
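    # Example: tokens_to_ctypes_array([1, 2, 3], ctypes.c_int) is equivalent to
    # (ctypes.c_int * 3)(1, 2, 3), the contiguous int array assigned to
    # token_input.input_ids in run() below (the token ids here are placeholders).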
    # Run inference
    def run(self, prompt):
        # rkllm_run blocks until generation completes, so get_RKLLM_output calls
        # this method from a worker thread and streams the callback output
        self.rkllm_infer_params = RKLLMInferParam()
        ctypes.memset(ctypes.byref(self.rkllm_infer_params), 0, ctypes.sizeof(RKLLMInferParam))
        self.rkllm_infer_params.mode = RKLLMInferMode.RKLLM_INFER_GENERATE
        self.rkllm_input = RKLLMInput()
        self.rkllm_input.input_mode = RKLLMInputMode.RKLLM_INPUT_TOKEN
        self.rkllm_input.input_data.token_input.input_ids = self.tokens_to_ctypes_array(prompt, ctypes.c_int)
        self.rkllm_input.input_data.token_input.n_tokens = ctypes.c_ulong(len(prompt))
        self.rkllm_run(self.handle, ctypes.byref(self.rkllm_input), ctypes.byref(self.rkllm_infer_params), None)
        return
    # Release RKLLM object from memory
    def release(self):
        self.rkllm_abort(self.handle)
        self.rkllm_destroy(self.handle)
    # Retrieve the output from the RKLLM model and print it in a streaming manner
    def get_RKLLM_output(self, message, history):
        # Link global variables to retrieve the output information from the callback function
        global global_text, global_state
        global_text = []
        global_state = -1
        user_prompt = {"role": "user", "content": message}
        history.append(user_prompt)
        # Some models (e.g. Gemma 2) do not support a system prompt
        if self.system_prompt == "":
            prompt = [user_prompt]
        else:
            prompt = [
                {"role": "system", "content": self.system_prompt},
                user_prompt
            ]
        # print(prompt)
        # Build the chat-formatted token ids with the matching Hugging Face tokenizer
        tokenizer = AutoTokenizer.from_pretrained(self.st_model_id, trust_remote_code=True)
        prompt = tokenizer.apply_chat_template(prompt, tokenize=True, add_generation_prompt=True)
        # response = {"role": "assistant", "content": "Loading..."}
        response = {"role": "assistant", "content": ""}
        history.append(response)
        # Run inference on a worker thread and stream partial output as it arrives
        model_thread = threading.Thread(target=self.run, args=(prompt,))
        model_thread.start()
        model_thread_finished = False
        while not model_thread_finished:
            while len(global_text) > 0:
                response["content"] += global_text.pop(0)
                # Marco-o1: escape reasoning tags so they render literally in the chat UI
                response["content"] = str(response["content"]).replace("<Thought>", "\\<Thought\\>")
                response["content"] = str(response["content"]).replace("</Thought>", "\\<\\/Thought\\>")
                response["content"] = str(response["content"]).replace("<Output>", "\\<Output\\>")
                response["content"] = str(response["content"]).replace("</Output>", "\\<\\/Output\\>")
                time.sleep(0.005)
                # Gradio automatically pushes the result returned by the yield statement when calling the then method
                yield response
            model_thread.join(timeout=0.005)
            model_thread_finished = not model_thread.is_alive()
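# --- Usage sketch (illustrative only; assumes at least one converted .rkllm file
# is present under ./models and that the RKLLM runtime library loaded by
# ctypes_bindings is available on this system) ---
if __name__ == "__main__":
    models = available_models()
    if not models:
        print("No .rkllm models found in ./models")
    else:
        model_name = next(iter(models))  # pick the first detected model
        rkllm_model = RKLLMLoaderClass(model=model_name)
        chat_history = []
        # The callback already echoes tokens to stdout, so simply drain the generator
        for _ in rkllm_model.get_RKLLM_output("Hello, who are you?", chat_history):
            pass
        rkllm_model.release()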