forked from ztxz16/fastllm
-
Notifications
You must be signed in to change notification settings - Fork 1
/
main.cpp
107 lines (95 loc) · 3.26 KB
/
main.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
#include <cstdio>
#include <cstring>
#include <iostream>
#include <getopt.h>
#include "moss.h"
#include "chatglm.h"
// Command-line configuration for the demo runner, filled in by ParseArgs.
struct RunConfig {
    std::string model = "chatglm"; // model type: "chatglm" or "moss"
    std::string path;              // path to the model weights file
    int threads = 4;               // number of threads used for inference
};
// Long-option table for getopt_long_only; each entry maps a long name to the
// same character handled by the short-option switch in ParseArgs.
// The all-zero entry terminates the array, as required by getopt.
static struct option long_options[] = {
    {"help", no_argument, nullptr, 'h'},
    {"model", required_argument, nullptr, 'm'},
    {"path", required_argument, nullptr, 'p'},
    {"threads", required_argument, nullptr, 't'},
    {nullptr, 0, nullptr, 0},
};
// Print command-line usage to stdout, one option per line.
// (Help text is user-facing and intentionally kept in the original language.)
void Usage() {
    static const char *const kUsageLines[] = {
        "Usage:",
        "[-h|--help]: 显示帮助",
        "<-m|--model> <args>: 模型类型,默认为chatglm, 可以设置为chatglm, moss",
        "<-p|--path> <args>: 模型文件的路径",
        "<-t|--threads> <args>: 使用的线程数量",
    };
    for (const char *line : kUsageLines) {
        std::cout << line << std::endl;
    }
}
// Parse command-line arguments into `config`.
// -h/--help prints usage and exits 0; an unknown option prints usage and
// exits -1. Recognized options: -m/--model, -p/--path, -t/--threads.
//
// Fixes relative to the original:
//  * opt_string was "h:m:p:t:": the trailing ':' after 'h' made a bare `-h`
//    an "option requires an argument" error (falling into `default`, exit -1)
//    instead of printing help. 'h' takes no argument, so it gets no ':'.
//  * option values are now read from `optarg` (set by getopt for each
//    argument-taking option) instead of argv[optind - 1], which for
//    `--model=moss` style arguments yields the raw token "--model=moss"
//    rather than the value "moss".
void ParseArgs(int argc, char **argv, RunConfig &config) {
    int opt;
    int option_index = 0;
    const char *opt_string = "hm:p:t:"; // 'h' no arg; m/p/t require a value
    while ((opt = getopt_long_only(argc, argv, opt_string, long_options, &option_index)) != -1) {
        switch (opt) {
            case 'h':
                Usage();
                exit(0);
            case 'm':
                config.model = optarg;
                break;
            case 'p':
                config.path = optarg;
                break;
            case 't':
                config.threads = atoi(optarg);
                break;
            default: // unrecognized option or missing argument
                Usage();
                exit(-1);
        }
    }
}
// Entry point: parse CLI options, set the global inference thread count,
// then run an interactive chat loop for the selected model type.
int main(int argc, char **argv) {
    RunConfig config;
    ParseArgs(argc, argv, config);
    fastllm::SetThreads(config.threads); // global thread count for inference
    if (config.model == "moss") {
        fastllm::MOSSModel moss;
        moss.LoadFromFile(config.path);
        // Single-turn loop: typing "stop" exits.
        while (true) {
            printf("用户: "); // "User: " prompt (runtime string, verbatim)
            std::string input;
            std::getline(std::cin, input);
            if (input == "stop") {
                break;
            }
            // Wrap the user text in MOSS's expected chat template.
            // NOTE(review): `ret` is never printed here — presumably
            // Response() streams tokens to stdout itself; confirm in moss.h.
            std::string ret = moss.Response("You are an AI assistant whose name is MOSS. <|Human|>: " + input + "<eoh>");
        }
    } else if (config.model == "chatglm") {
        fastllm::ChatGLMModel chatGlm;
        chatGlm.LoadFromFile(config.path);
        int round = 0;       // conversation round counter
        std::string history; // accumulated multi-round prompt
        while (true) {
            printf("用户: ");
            std::string input;
            std::getline(std::cin, input);
            if (input == "stop") { // "stop" exits the program
                break;
            }
            if (input == "reset") { // "reset" clears the conversation state
                history = "";
                round = 0;
                continue;
            }
            // Append this turn in the "[Round N]\n问:<q>" prompt format
            // (问/答 = question/answer markers ChatGLM was trained on).
            history += ("[Round " + std::to_string(round++) + "]\n问:" + input);
            // After ++, round == 1 on the first turn, so the first prompt is
            // the bare input; later turns send the full accumulated history.
            auto prompt = round > 1 ? history : input;
            printf("ChatGLM: ");
            std::string ret = chatGlm.Response(prompt);
            // Record the answer so the next turn's prompt carries full context.
            history += ("\n答:" + ret + "\n");
        }
        //chatGlm.weight.SaveLowBitModel("/root/chatglm-6b-int4.bin", 4);
    } else {
        // Unknown model name: show usage and fail.
        Usage();
        exit(-1);
    }
    return 0;
}