Skip to content

Commit

Permalink
improve the benchmark test of vineyard llm kv cache (#1816)
Browse files Browse the repository at this point in the history
After the benchmark test, we can get the following result.

```
Token list size is 17792Total Update time is 2.22029s Total Query time is 0.646123s Average update time is 8013.38token/s Average query time is 27536.5token/s
```

The query time includes (querying the kv tensor ptr from vineyard) + (the
memcpy from the kv tensor ptr to the user's buffer).

Signed-off-by: Ye Cao <[email protected]>
  • Loading branch information
dashanji authored May 21, 2024
1 parent a403b12 commit c8de264
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 28 deletions.
2 changes: 2 additions & 0 deletions .cspell.json
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@
"ldbc",
"leaseid",
"leasekeepalive",
"LLMKV",
"libboost",
"libclang",
"libgrape",
Expand Down Expand Up @@ -289,6 +290,7 @@
"Succ",
"thirdparty",
"thiserror",
"TENSORBYTES",
"Timepoint",
"toctree",
"TORCHELASTIC",
Expand Down
72 changes: 44 additions & 28 deletions modules/llm-cache/tests/kv_state_cache_benchmark_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,29 +28,32 @@ limitations under the License.
using namespace vineyard; // NOLINT(build/namespaces)

constexpr int TENSORBYTES = 800;
constexpr int CAPACITY = 1000;
constexpr int CAPACITY = 8000;
constexpr int LAYER = 64;
constexpr int BLOCK_SIZE = 100;

std::shared_ptr<KVStateCacheManager> manager;
VineyardCacheConfig config(TENSORBYTES, CAPACITY, LAYER, BLOCK_SIZE, 3);
VineyardCacheConfig config(TENSORBYTES, CAPACITY, LAYER, BLOCK_SIZE, 300);
Client client;

// Connect the global vineyard `client` to the server at `socket` and
// build the global KVStateCacheManager from the file-level `config`.
//
// Aborts (via VINEYARD_CHECK_OK) if either the connection or the manager
// construction fails — the benchmark cannot run without both.
//
// Takes the socket path by const reference to avoid an unnecessary
// std::string copy (clang-tidy: performance-unnecessary-value-param);
// call sites are unaffected.
void init(const std::string& socket) {
  VINEYARD_CHECK_OK(client.Connect(socket));
  VINEYARD_CHECK_OK(KVStateCacheManager::Make(client, manager, config));
}

// Generate a list of unique random token ids for the benchmark.
//
// The list length is itself random in [1, max_length]; each token id is
// drawn uniformly from [1, 10000] and duplicates are rejected via the
// unordered_set, so the returned list contains no repeated tokens.
//
// Fix: the token-id pool has only 10000 distinct values, so the requested
// length is clamped to the pool size — otherwise the rejection-sampling
// loop below could spin forever when max_length > 10000.
std::vector<int> generate_unique_tokens(size_t max_length) {
  constexpr int kMinToken = 1;
  constexpr int kMaxToken = 10000;
  constexpr size_t kPoolSize =
      static_cast<size_t>(kMaxToken - kMinToken + 1);

  std::random_device rd;
  std::mt19937 gen(rd());
  std::uniform_int_distribution<> dist(kMinToken, kMaxToken);

  // Random length in [1, max_length], but never more unique tokens than
  // the pool can supply.
  size_t length = dist(gen) % max_length + 1;
  if (length > kPoolSize) {
    length = kPoolSize;
  }

  std::unordered_set<int> unique_tokens;
  unique_tokens.reserve(length);  // avoid rehashing during sampling
  while (unique_tokens.size() < length) {
    unique_tokens.insert(dist(gen));
  }
  return std::vector<int>(unique_tokens.begin(), unique_tokens.end());
}

Expand Down Expand Up @@ -78,46 +81,59 @@ void benchmark_inference(std::vector<std::vector<int>>& tokens) {
double token_list_size = 0;
std::chrono::duration<double> update_duration(0);
std::chrono::duration<double> query_duration(0);
double total_update_duration = 0;
double total_query_duration = 0;

std::vector<int> inference_tokens;
std::map<int, std::pair<LLMKV, LLMKV>> kv_state_list;
void* key_state = malloc(TENSORBYTES);
void* value_state = malloc(TENSORBYTES);

for (size_t i = 0; i < tokens.size(); ++i) {
std::vector<int> inference_tokens;
inference_tokens.clear();
for (size_t j = 0; j < tokens[i].size(); ++j) {
start = std::chrono::steady_clock::now();
kv_state = generate_kv_state(tokens[i][j]);
Status status = manager->Query(inference_tokens, tokens[i][j], kv_state);
start = std::chrono::steady_clock::now();
Status status = manager->Update(inference_tokens, tokens[i][j], kv_state);
if (!status.ok()) {
VLOG(100) << "KV state is not in the cache.";
}
end = std::chrono::steady_clock::now();
query_duration += end - start;
update_duration += end - start;
inference_tokens.push_back(tokens[i][j]);
token_list_size++;
}
}

if (kv_state.size() == 0) {
start = std::chrono::steady_clock::now();
Status status =
manager->Update(inference_tokens, tokens[i][j], kv_state);
if (!status.ok()) {
// Not a error. May be the cache is full.
VLOG(100) << "Put kv state into cache failed.";
// query time
for (size_t i = 0; i < tokens.size(); ++i) {
inference_tokens.clear();
kv_state_list.clear();
for (size_t j = 0; j < tokens[i].size(); ++j) {
start = std::chrono::steady_clock::now();
Status status =
manager->Query(inference_tokens, tokens[i][j], kv_state_list);
if (!status.ok()) {
VLOG(100) << "KV state is not in the cache.";
}
for (auto& kv : kv_state_list) {
for (int currentLayer = 0; currentLayer < LAYER; currentLayer++) {
memcpy(key_state, kv.second.first.data, kv.second.first.length);
memcpy(value_state, kv.second.second.data, kv.second.second.length);
}
end = std::chrono::steady_clock::now();
update_duration += end - start;
}
end = std::chrono::steady_clock::now();
query_duration += end - start;
inference_tokens.push_back(tokens[i][j]);
token_list_size++;
}
total_update_duration += update_duration.count();
total_query_duration += query_duration.count();
}

LOG(INFO) << "Token list size is " << token_list_size
<< "Total Update time is " << total_update_duration << "s "
<< "Total Query time is " << total_query_duration << "s "
<< "Total Update time is " << update_duration.count() << "s "
<< "Total Query time is " << query_duration.count() << "s "
<< "Average update time is "
<< token_list_size / total_update_duration << "token/s "
<< token_list_size / update_duration.count() << "token/s "
<< "Average query time is "
<< token_list_size / total_query_duration << "token/s ";
<< token_list_size / query_duration.count() << "token/s ";
}

int main(int argc, char** argv) {
Expand Down Expand Up @@ -151,7 +167,7 @@ int main(int argc, char** argv) {
const size_t num_lists = 10;
std::vector<std::vector<int>> all_token_lists;
for (size_t i = 0; i < num_lists; ++i) {
all_token_lists.push_back(generate_random_tokens(2000));
all_token_lists.push_back(generate_unique_tokens(2000));
}

benchmark_inference(all_token_lists);
Expand Down

0 comments on commit c8de264

Please sign in to comment.