-
Notifications
You must be signed in to change notification settings - Fork 14
/
client_sample.cc
123 lines (100 loc) · 4.11 KB
/
client_sample.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
#include <grpc++/grpc++.h>

#include <chrono>
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "absl/flags/flag.h"
#include "absl/flags/parse.h"

#include "llm.grpc.pb.h"
using namespace grpc;
using grpc::Channel;
using grpc::ClientContext;
using grpc::Status;
using namespace std::chrono;
using namespace ppl::llm;
ABSL_FLAG(std::string, target, "localhost:50052", "Server address");
class GenerationClient {
public:
GenerationClient(std::shared_ptr<Channel> channel) : stub_(proto::LLMService::NewStub(channel)) {}
int Generation(const std::vector<std::string>& prompts) {
// Data we are sending to the server.
ClientContext context;
proto::BatchedRequest req_list;
std::unordered_map<int, std::string> rsp_stream_store;
for (size_t i = 0; i < prompts.size(); i++) {
// request
auto* req = req_list.add_req();
req->set_id(i);
req->set_prompt(prompts[i]);
auto* choosing_parameter = req->mutable_choosing_parameters();
choosing_parameter->set_do_sample(false);
choosing_parameter->set_temperature(1.f);
choosing_parameter->set_repetition_penalty(1.f);
choosing_parameter->set_presence_penalty(0.f);
choosing_parameter->set_frequency_penalty(0.f);
auto* stopping_parameters = req->mutable_stopping_parameters();
stopping_parameters->set_max_new_tokens(16);
stopping_parameters->set_ignore_eos_token(false);
rsp_stream_store[i] = "";
}
// response
proto::BatchedResponse batched_rsp;
std::unique_ptr<ClientReader<proto::BatchedResponse>> reader(stub_->Generation(&context, req_list));
// stream chat
auto start = system_clock::now();
auto first_fill_time = system_clock::now();
bool is_first_fill = true;
while (reader->Read(&batched_rsp)) {
if (is_first_fill) {
first_fill_time = system_clock::now();
is_first_fill = false;
}
for (const auto& rsp : batched_rsp.rsp()) {
int tid = rsp.id();
std::string rsp_stream = rsp.generated();
rsp_stream_store[tid] += rsp_stream;
}
}
auto end = system_clock::now();
std::cout << "------------------------------" << std::endl;
std::cout << "--------- Answer -------------" << std::endl;
std::cout << "------------------------------" << std::endl;
for (auto rsp : rsp_stream_store) {
std::cout << rsp.second << std::endl;
std::cout << "--------------------" << std::endl;
}
auto first_till_duration = duration_cast<std::chrono::milliseconds>(first_fill_time - start);
auto duration = duration_cast<std::chrono::milliseconds>(end - start);
std::cout << "first fill: " << first_till_duration.count() << " ms" << std::endl;
std::cout << "total: " << duration.count() << " ms" << std::endl;
Status status = reader->Finish();
if (status.ok()) {
std::cout << "Generation rpc succeeded." << std::endl;
} else {
std::cerr << "Generation rpc failed." << std::endl;
return -1;
}
return 0;
}
private:
std::unique_ptr<proto::LLMService::Stub> stub_;
};
int main(int argc, char** argv) {
if (argc < 2) {
std::cerr << "usage: " << argv[0] << " host:port" << std::endl;
return -1;
}
const std::string target_str = argv[1];
GenerationClient generator(grpc::CreateChannel(target_str, grpc::InsecureChannelCredentials()));
const std::string prompt = "Building a website can be done in 10 simple steps:\n";
const std::vector<std::string> prompts(3, prompt);
std::cout << "------------------------------" << std::endl;
std::cout << "--------- Question -------------" << std::endl;
std::cout << "------------------------------" << std::endl;
for (auto& str : prompts) {
std::cout << str << std::endl;
}
generator.Generation(prompts);
return 0;
}