Skip to content

Commit

Permalink
Format code.
Browse files (browse the repository at this point in the history)
Fix bug in radix tree and tests.

Signed-off-by: vegetableysm <[email protected]>
  • Loading branch information
vegetableysm committed Mar 6, 2024
1 parent b425df8 commit 71741ef
Show file tree
Hide file tree
Showing 10 changed files with 62 additions and 60 deletions.
17 changes: 8 additions & 9 deletions modules/llm-cache/ds/kv_state_cache.cc
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,8 @@ Status KVStateCacheBuilder::Split(
childKVStateCacheBlockBuilder =
new KVStateCacheBlockBuilder(client, this->tensorBytes, this->layer,
kvStateCacheBlockBuilder->GetBlockSize());
RETURN_ON_ASSERT(childKVStateCacheBlockBuilder != nullptr,
"Not enough memory for new block builder.");
VINEYARD_ASSERT(childKVStateCacheBlockBuilder != nullptr,
"Not enough memory for new block builder.");

for (size_t i = 0; i < nodeDataList.size(); i++) {
OffsetData* data =
Expand All @@ -154,10 +154,9 @@ Status KVStateCacheBuilder::Split(
return Status::OK();
}

Status KVStateCacheBuilder::Update(Client& client,
const std::vector<int>& tokenList,
int nextToken,
const std::map<int, std::pair<LLMKV, LLMKV>>& kvState) {
Status KVStateCacheBuilder::Update(
Client& client, const std::vector<int>& tokenList, int nextToken,
const std::map<int, std::pair<LLMKV, LLMKV>>& kvState) {
std::vector<int> tokenListCopy = tokenList;
tokenListCopy.push_back(nextToken);

Expand Down Expand Up @@ -223,9 +222,9 @@ Status KVStateCacheBuilder::Update(Client& client,
return Status::OK();
}

Status KVStateCacheBuilder::Query(Client& client,
const std::vector<int>& tokenList, int token,
std::map<int, std::pair<LLMKV, LLMKV>>& kvState) {
Status KVStateCacheBuilder::Query(
Client& client, const std::vector<int>& tokenList, int token,
std::map<int, std::pair<LLMKV, LLMKV>>& kvState) {
std::vector<int> tokenListCopy = tokenList;
tokenListCopy.push_back(token);

Expand Down
3 changes: 2 additions & 1 deletion modules/llm-cache/ds/kv_state_cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,8 @@ class KVStateCacheBuilder : public vineyard::ObjectBuilder {
KVStateCacheBlockBuilder*& childKVStateCacheBlockBuilder);

Status Update(Client& client, const std::vector<int>& token_list,
int next_token, const std::map<int, std::pair<LLMKV, LLMKV>>& kv_state);
int next_token,
const std::map<int, std::pair<LLMKV, LLMKV>>& kv_state);

Status Query(Client& client, const std::vector<int>& token_list, int token,
std::map<int, std::pair<LLMKV, LLMKV>>& kv_state);
Expand Down
14 changes: 7 additions & 7 deletions modules/llm-cache/ds/kv_state_cache_block.cc
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,9 @@ KVStateCacheBlockBuilder::KVStateCacheBlockBuilder(
}
}

Status KVStateCacheBlockBuilder::Query(Client& client, int index,
std::map<int, std::pair<LLMKV, LLMKV>>& kvState) {
Status KVStateCacheBlockBuilder::Query(
Client& client, int index,
std::map<int, std::pair<LLMKV, LLMKV>>& kvState) {
RETURN_ON_ASSERT((index >= 0 && index < this->blockSize),
"Index out of range: " + std::to_string(index));
for (int currentLayer = 0; currentLayer < this->layer; currentLayer++) {
Expand Down Expand Up @@ -167,18 +168,17 @@ bool KVStateCacheBlockBuilder::IsFull() {
return true;
}

Status KVStateCacheBlockBuilder::Update(const std::map<int, std::pair<LLMKV, LLMKV>>& kvState,
OffsetData* data) {
Status KVStateCacheBlockBuilder::Update(
const std::map<int, std::pair<LLMKV, LLMKV>>& kvState, OffsetData* data) {
int index = this->FindEmptySlot();
RETURN_ON_ASSERT((index >= 0 && index < this->blockSize),
"Index out of range: " + std::to_string(index));

for (int currentLayer = 0; currentLayer < this->layer; currentLayer++) {
LLMKV keyState = (kvState.find(currentLayer)->second).first;
LLMKV valueState = (kvState.find(currentLayer)->second).second;
RETURN_ON_ASSERT(
(keyState.length == (size_t) this->tensorBytes &&
valueState.length == (size_t) this->tensorBytes));
RETURN_ON_ASSERT((keyState.length == (size_t) this->tensorBytes &&
valueState.length == (size_t) this->tensorBytes));

uint8_t* keyData = keyStateTensorBuilderList[currentLayer]->data();
uint8_t* valueData = valueStateTensorBuilderList[currentLayer]->data();
Expand Down
6 changes: 4 additions & 2 deletions modules/llm-cache/ds/kv_state_cache_block.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,8 @@ class KVStateCacheBlockBuilder : public ObjectBuilder {
* @param kv_state The kv-state of the prompt. A LLM inference can contain
* multiple kv-states for each layer.
*/
Status Update(const std::map<int, std::pair<LLMKV, LLMKV>>& kv_state, OffsetData* data);
Status Update(const std::map<int, std::pair<LLMKV, LLMKV>>& kv_state,
OffsetData* data);

/**
* @brief Query the kv-state using the whole token list.
Expand All @@ -149,7 +150,8 @@ class KVStateCacheBlockBuilder : public ObjectBuilder {
* @param kv_state The kv-state of the prompt returned by radix-tree. If the
* kv-state is not found, the data of kv-state is invalid.
*/
Status Query(Client& client, int index, std::map<int, std::pair<LLMKV, LLMKV>>& kv_state);
Status Query(Client& client, int index,
std::map<int, std::pair<LLMKV, LLMKV>>& kv_state);

bool IsFull();

Expand Down
33 changes: 18 additions & 15 deletions modules/llm-cache/ds/kv_state_cache_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,21 +83,21 @@ Status KVStateCacheManager::Make(Client& client,
return Status::OK();
}

Status KVStateCacheManager::UpdateInternal(const std::vector<int>& tokenList,
int nextToken,
const std::map<int, std::pair<LLMKV, LLMKV>>& kvState) {
Status KVStateCacheManager::UpdateInternal(
const std::vector<int>& tokenList, int nextToken,
const std::map<int, std::pair<LLMKV, LLMKV>>& kvState) {
return kvStateCacheBuilder->Update(client, tokenList, nextToken, kvState);
}

Status KVStateCacheManager::QueryInternal(const std::vector<int>& tokenList,
int token,
std::map<int, std::pair<LLMKV, LLMKV>>& kvState) {
Status KVStateCacheManager::QueryInternal(
const std::vector<int>& tokenList, int token,
std::map<int, std::pair<LLMKV, LLMKV>>& kvState) {
return kvStateCacheBuilder->Query(client, tokenList, token, kvState);
}

Status KVStateCacheManager::Update(const std::vector<int>& tokenList,
int nextToken,
const std::map<int, std::pair<LLMKV, LLMKV>>& kvState) {
Status KVStateCacheManager::Update(
const std::vector<int>& tokenList, int nextToken,
const std::map<int, std::pair<LLMKV, LLMKV>>& kvState) {
Status result =
Status::Invalid("Query cache failed: can not gain the cache lock.");

Expand All @@ -111,8 +111,9 @@ Status KVStateCacheManager::Update(const std::vector<int>& tokenList,
return result;
}

Status KVStateCacheManager::Update(const std::vector<int>& tokenList,
const std::vector<std::map<int, std::pair<LLMKV, LLMKV>>>& kvState) {
Status KVStateCacheManager::Update(
const std::vector<int>& tokenList,
const std::vector<std::map<int, std::pair<LLMKV, LLMKV>>>& kvState) {
Status result =
Status::Invalid("Update cache failed: can not gain the cache lock.");
if (!syncMutex.try_lock()) {
Expand All @@ -132,8 +133,9 @@ Status KVStateCacheManager::Update(const std::vector<int>& tokenList,
return result;
}

Status KVStateCacheManager::Query(const std::vector<int>& tokenList, int token,
std::map<int, std::pair<LLMKV, LLMKV>>& kvState) {
Status KVStateCacheManager::Query(
const std::vector<int>& tokenList, int token,
std::map<int, std::pair<LLMKV, LLMKV>>& kvState) {
Status result =
Status::Invalid("Query cache failed: can not gain the cache lock.");

Expand All @@ -147,8 +149,9 @@ Status KVStateCacheManager::Query(const std::vector<int>& tokenList, int token,
return result;
}

Status KVStateCacheManager::Query(const std::vector<int>& tokenList,
std::vector<std::map<int, std::pair<LLMKV, LLMKV>>>& listKVState) {
Status KVStateCacheManager::Query(
const std::vector<int>& tokenList,
std::vector<std::map<int, std::pair<LLMKV, LLMKV>>>& listKVState) {
Status result =
Status::Invalid("Query cache failed: can not gain the cache lock.");
if (!syncMutex.try_lock()) {
Expand Down
10 changes: 6 additions & 4 deletions modules/llm-cache/ds/kv_state_cache_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,16 +56,18 @@ class KVStateCacheManager {
std::string llmCacheObjectName = "llm_cache_object");

Status Update(const std::vector<int>& tokenList, int nextToken,
const std::map<int, std::pair<LLMKV, LLMKV>>& kvState);
const std::map<int, std::pair<LLMKV, LLMKV>>& kvState);

Status Update( const std::vector<int>& tokenList,
Status Update(
const std::vector<int>& tokenList,
const std::vector<std::map<int, std::pair<LLMKV, LLMKV>>>& kvState);

Status Query(const std::vector<int>& tokenList, int token,
std::map<int, std::pair<LLMKV, LLMKV>>& kvState);

Status Query(const std::vector<int>& tokenList,
std::vector<std::map<int, std::pair<LLMKV, LLMKV>>>& listKVState);
Status Query(
const std::vector<int>& tokenList,
std::vector<std::map<int, std::pair<LLMKV, LLMKV>>>& listKVState);

~KVStateCacheManager();

Expand Down
8 changes: 4 additions & 4 deletions modules/llm-cache/radix-tree/radix-tree.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,24 +64,24 @@ RadixTree::~RadixTree() {
}

std::shared_ptr<NodeData> RadixTree::Insert(
std::vector<int>& tokens, std::shared_ptr<NodeData>& evictedNode) {
std::vector<int> tokens, std::shared_ptr<NodeData>& evictedNode) {
tokens.insert(tokens.begin(), INT32_MAX);
return InsertInternal(tokens, evictedNode);
}

void RadixTree::Delete(std::vector<int>& tokens,
void RadixTree::Delete(std::vector<int> tokens,
std::shared_ptr<NodeData>& evictedNode) {
tokens.insert(tokens.begin(), INT32_MAX);
DeleteInternal(tokens, evictedNode);
}

std::shared_ptr<NodeData> RadixTree::Query(std::vector<int>& key) {
std::shared_ptr<NodeData> RadixTree::Query(std::vector<int> key) {
key.insert(key.begin(), INT32_MAX);
return QueryInternal(key);
}

std::vector<std::shared_ptr<NodeData>> RadixTree::Split(
std::vector<int>& tokens, std::shared_ptr<NodeData>& header) {
std::vector<int> tokens, std::shared_ptr<NodeData>& header) {
tokens.insert(tokens.begin(), INT32_MAX);
return SplitInternal(tokens, header);
}
Expand Down
8 changes: 4 additions & 4 deletions modules/llm-cache/radix-tree/radix-tree.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,15 +80,15 @@ class RadixTree : public std::enable_shared_from_this<RadixTree> {

~RadixTree();

std::shared_ptr<NodeData> Insert(std::vector<int>& tokens,
std::shared_ptr<NodeData> Insert(std::vector<int> tokens,
std::shared_ptr<NodeData>& evictedNode);

void Delete(std::vector<int>& tokens, std::shared_ptr<NodeData>& evictedNode);
void Delete(std::vector<int> tokens, std::shared_ptr<NodeData>& evictedNode);

std::shared_ptr<NodeData> Query(std::vector<int>& key);
std::shared_ptr<NodeData> Query(std::vector<int> key);

std::vector<std::shared_ptr<NodeData>> Split(
std::vector<int>& tokens, std::shared_ptr<NodeData>& header);
std::vector<int> tokens, std::shared_ptr<NodeData>& header);

std::string Serialize();

Expand Down
5 changes: 3 additions & 2 deletions modules/llm-cache/tests/kv_state_cache_multi_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ limitations under the License.

#include "common/util/logging.h"

constexpr char* program = "./build/bin/kv_state_cache_test";
constexpr const char* program = "./build/bin/kv_state_cache_test";

pid_t create_subprocess(char* argv[]) {
pid_t pid = fork();
Expand Down Expand Up @@ -80,9 +80,10 @@ int main(int argc, char** argv) {
for (size_t i = 0; i < pids.size(); i++) {
int status;
waitpid(pids[i], &status, 0);
if (WIFEXITED(status) && WEXITSTATUS(status) != 0) {
if ((!WIFEXITED(status)) || WEXITSTATUS(status) != 0) {
free(socket_str[0]);
free(socket_str[1]);
LOG(INFO) << "child error!";
return 1;
}
}
Expand Down
18 changes: 6 additions & 12 deletions modules/llm-cache/tests/kv_state_cache_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,9 @@ std::map<int, std::pair<LLMKV, LLMKV>> generate_kv_state(int token) {

for (int i = 0; i < tensorBytes; ++i) {
(reinterpret_cast<uint8_t*>(key_state.data))[i] =
(static_cast<uint8_t>(token)) / tensorBytes * (i + 1) +
currentLayer * 10;
(static_cast<uint8_t>(token)) + i + currentLayer;
(reinterpret_cast<uint8_t*>(value_state.data))[i] =
(static_cast<uint8_t>(token)) / tensorBytes * (i + 1) * 2 +
currentLayer * 10;
(static_cast<uint8_t>(token)) + i + currentLayer;
}
kv_state.insert(
std::make_pair(currentLayer, std::make_pair(key_state, value_state)));
Expand All @@ -106,27 +104,23 @@ void check_kv_state(const std::map<int, std::pair<LLMKV, LLMKV>>& kv_state,
VINEYARD_ASSERT(iter->second.second.length == (size_t) tensorBytes);
for (int i = 0; i < tensorBytes; ++i) {
if ((reinterpret_cast<uint8_t*>(iter->second.first.data))[i] !=
(static_cast<uint8_t>(token)) / tensorBytes * (i + 1) +
iter->first * 10) {
(static_cast<uint8_t>(token)) + i + iter->first) {
LOG(INFO) << "token:" << token << " tensorBytes" << tensorBytes
<< " layer:" << iter->first;
LOG(INFO) << "key_state[" << i << "]: "
<< (reinterpret_cast<uint8_t*>(iter->second.first.data))[i]
<< ". But is should be "
<< (static_cast<uint8_t>(token)) / tensorBytes * (i + 1) +
iter->first * 10;
<< (static_cast<uint8_t>(token)) + i + iter->first;
throw std::runtime_error("key_state error!");
}
if (reinterpret_cast<uint8_t*>(iter->second.second.data)[i] !=
(static_cast<uint8_t>(token)) / tensorBytes * (i + 1) * 2 +
iter->first * 10) {
(static_cast<uint8_t>(token)) + i + iter->first) {
LOG(INFO) << "token:" << token << " tensorBytes" << tensorBytes
<< " layer:" << iter->first;
LOG(INFO) << "value_state[" << i << "]: "
<< (reinterpret_cast<uint8_t*>(iter->second.second.data))[i]
<< ". But is should be "
<< (static_cast<uint8_t>(token)) / tensorBytes * (i + 1) * 2 +
iter->first * 10;
<< (static_cast<uint8_t>(token)) + i + iter->first * 10;
throw std::runtime_error("value_state error!");
}
}
Expand Down

0 comments on commit 71741ef

Please sign in to comment.