From 8f767784b7207dc17ec0566ebeff773e6a01c36d Mon Sep 17 00:00:00 2001 From: vegetableysm <108774481+vegetableysm@users.noreply.github.com> Date: Mon, 22 Jan 2024 14:15:13 +0800 Subject: [PATCH 01/20] Support kv state cache for llm using vineyard (#1720) - support kv state cache on local machine - implement cache object / cache block object - support api of query and update for cache access - using radix tree to store token list as key - seal cache object to vineyard - serialize / deserialize radix tree - build / resolve cache object Working on #1708 and #1712. Signed-off-by: vegetableysm Signed-off-by: Ye Cao Co-authored-by: Ye Cao Co-authored-by: Ye Cao <952129620@qq.com> --- .gitmodules | 3 + CMakeLists.txt | 6 + modules/kv-state-cache/CMakeLists.txt | 27 + modules/kv-state-cache/README.rst | 28 + modules/kv-state-cache/ds/kv_state_cache.cc | 255 ++ modules/kv-state-cache/ds/kv_state_cache.h | 111 + .../kv-state-cache/ds/kv_state_cache_block.cc | 227 ++ .../kv-state-cache/ds/kv_state_cache_block.h | 166 ++ modules/kv-state-cache/lz4 | 1 + .../kv-state-cache/radix-tree/radix-tree.h | 473 ++++ modules/kv-state-cache/radix-tree/radix.cc | 2181 +++++++++++++++++ modules/kv-state-cache/radix-tree/radix.h | 227 ++ .../kv-state-cache/radix-tree/rax_malloc.h | 43 + .../kv-state-cache/strategy/LRU_strategy.cc | 143 ++ .../kv-state-cache/strategy/LRU_strategy.h | 67 + .../kv-state-cache/strategy/cache_strategy.h | 31 + .../utils/kv_state_cache_utils.cc | 274 +++ .../utils/kv_state_cache_utils.h | 33 + src/client/client.cc | 27 + src/client/client.h | 27 + src/client/client_base.h | 19 + src/client/rpc_client.h | 24 + src/common/util/protocols.cc | 67 +- src/common/util/protocols.h | 24 + src/server/async/socket_server.cc | 44 + src/server/async/socket_server.h | 3 + src/server/server/vineyard_server.cc | 34 + src/server/server/vineyard_server.h | 5 + src/server/services/etcd_meta_service.cc | 42 + src/server/services/etcd_meta_service.h | 4 + 
src/server/services/local_meta_service.h | 12 + src/server/services/meta_service.h | 5 + src/server/services/redis_meta_service.h | 12 + test/distributed_lock_test.cc | 66 + test/kv_state_cache_object_test.cc | 174 ++ test/kv_state_cache_test.cc | 110 + 36 files changed, 4994 insertions(+), 1 deletion(-) create mode 100644 modules/kv-state-cache/CMakeLists.txt create mode 100644 modules/kv-state-cache/README.rst create mode 100644 modules/kv-state-cache/ds/kv_state_cache.cc create mode 100644 modules/kv-state-cache/ds/kv_state_cache.h create mode 100644 modules/kv-state-cache/ds/kv_state_cache_block.cc create mode 100644 modules/kv-state-cache/ds/kv_state_cache_block.h create mode 160000 modules/kv-state-cache/lz4 create mode 100644 modules/kv-state-cache/radix-tree/radix-tree.h create mode 100644 modules/kv-state-cache/radix-tree/radix.cc create mode 100644 modules/kv-state-cache/radix-tree/radix.h create mode 100644 modules/kv-state-cache/radix-tree/rax_malloc.h create mode 100644 modules/kv-state-cache/strategy/LRU_strategy.cc create mode 100644 modules/kv-state-cache/strategy/LRU_strategy.h create mode 100644 modules/kv-state-cache/strategy/cache_strategy.h create mode 100644 modules/kv-state-cache/utils/kv_state_cache_utils.cc create mode 100644 modules/kv-state-cache/utils/kv_state_cache_utils.h create mode 100644 test/distributed_lock_test.cc create mode 100644 test/kv_state_cache_object_test.cc create mode 100644 test/kv_state_cache_test.cc diff --git a/.gitmodules b/.gitmodules index f326d1eb..0c08193f 100644 --- a/.gitmodules +++ b/.gitmodules @@ -49,3 +49,6 @@ [submodule "modules/graph/thirdparty/powturbo"] path = modules/graph/thirdparty/powturbo url = https://github.com/powturbo/TurboPFor-Integer-Compression.git +[submodule "modules/kv-state-cache/lz4"] + path = modules/kv-state-cache/lz4 + url = https://github.com/lz4/lz4.git diff --git a/CMakeLists.txt b/CMakeLists.txt index b1afa939..269d65a0 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -62,6 
+62,7 @@ option(BUILD_VINEYARD_MALLOC_OVERRIDE "Using the client-side malloc to override option(BUILD_VINEYARD_FUSE "Enable vineyard's fuse support" OFF) option(BUILD_VINEYARD_FUSE_PARQUET "Enable vineyard's fuse parquet support" OFF) option(BUILD_VINEYARD_HOSSEINMOEIN_DATAFRAME "Enable hosseinmoein dataframe support" OFF) +option(BUILD_VINEYARD_KV_STATE_CACHE "Enable kv-state cache support" ON) option(BUILD_VINEYARD_TESTS "Generate make targets for vineyard tests" ON) option(BUILD_VINEYARD_TESTS_ALL "Include make targets for vineyard tests to ALL" OFF) @@ -932,6 +933,11 @@ if(BUILD_VINEYARD_HOSSEINMOEIN_DATAFRAME) list(APPEND VINEYARD_INSTALL_LIBS vineyard_hosseinmoein_dataframe) endif() +if(BUILD_VINEYARD_KV_STATE_CACHE) + add_subdirectory(modules/kv-state-cache) + list(APPEND VINEYARD_INSTALL_LIBS vineyard_kv_state_cache) +endif() + if(BUILD_VINEYARD_TESTS) add_subdirectory(test) endif() diff --git a/modules/kv-state-cache/CMakeLists.txt b/modules/kv-state-cache/CMakeLists.txt new file mode 100644 index 00000000..ca96b72e --- /dev/null +++ b/modules/kv-state-cache/CMakeLists.txt @@ -0,0 +1,27 @@ +set(LZ4_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/lz4") + +add_custom_target(build_lz4 + COMMAND make -C ${LZ4_SOURCE_DIR} + WORKING_DIRECTORY ${LZ4_SOURCE_DIR}) + +file(GLOB LZ4_LIBRARIES "${LZ4_SOURCE_DIR}/lib/*.so") + +file(GLOB VINEYARD_KV_STATE_CACHE_SRCS "${CMAKE_CURRENT_SOURCE_DIR}" + "ds/*.cc" + "ds/*.h" + "radix-tree/*.cc" + "radix-tree/*.h" + "utils/*.cc" + "utils/*.h" + "strategy/*.cc" + "strategy/*.h" + "lz4/lib/*.h" +) + +add_library(vineyard_kv_state_cache ${VINEYARD_KV_STATE_CACHE_SRCS}) +target_link_libraries(vineyard_kv_state_cache PUBLIC vineyard_client vineyard_basic ${LZ4_LIBRARIES}) + +install_export_vineyard_target(vineyard_kv_state_cache) +install_vineyard_headers("${CMAKE_CURRENT_SOURCE_DIR}/utils/") +install_vineyard_headers("${CMAKE_CURRENT_SOURCE_DIR}/ds/") +install_vineyard_headers("${CMAKE_CURRENT_SOURCE_DIR}/radix-tree/") diff --git 
a/modules/kv-state-cache/README.rst b/modules/kv-state-cache/README.rst new file mode 100644 index 00000000..4acc1a20 --- /dev/null +++ b/modules/kv-state-cache/README.rst @@ -0,0 +1,28 @@ +KV-state cache on Vineyard +============================= + +Run test +-------- + +Build vineyard and vineyard test + +.. code:: bash + mkdir build + cd build + cmake .. -DBUILD_VINEYARD_TESTS=ON -DCMAKE_BUILD_TYPE=Debug + make build_lz4 + make -j$(nproc) + make vineyard_tests -j$(nproc) + +Start vineyard server + +.. code:: bash + cd build + ./bin/vineyardd --socket=/tmp/vineyard_test.sock # make sure the env VINEYARD_IPC_SOCKET is set properly + +Run test + +.. code:: bash + cd build + export VINEYARD_IPC_SOCKET=/tmp/vineyard_test.sock + ./bin/kv_state_cache_test diff --git a/modules/kv-state-cache/ds/kv_state_cache.cc b/modules/kv-state-cache/ds/kv_state_cache.cc new file mode 100644 index 00000000..c4df7d00 --- /dev/null +++ b/modules/kv-state-cache/ds/kv_state_cache.cc @@ -0,0 +1,255 @@ +/** Copyright 2020-2023 Alibaba Group Holding Limited. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +#include + +#include "client/client.h" +#include "common/util/base64.h" +#include "common/util/logging.h" +#include "common/util/status.h" +#include "kv-state-cache/radix-tree/radix-tree.h" +#include "kv_state_cache.h" + +namespace vineyard { + +void KVStateCache::Construct(const ObjectMeta& meta) { + Object::Construct(meta); + Resolve(); +} + +void KVStateCache::Resolve() { + LOG(INFO) << "Resolve"; + std::string typeName = type_name(); + + VINEYARD_ASSERT(this->meta_.GetTypeName() == typeName, + "Expect typename '" + typeName + "', but got '" + + this->meta_.GetTypeName() + "'"); + + // 1. construct the kv_state_cache_block_builder + this->kv_state_cache_block = std::dynamic_pointer_cast( + this->meta_.GetMember("root_kv_state_cache_block")); + // 2. construct the radix tree + this->root_tree = RadixTree::Deserialize( + base64_decode(this->meta_.GetKeyValue("radix_tree"))); + // 3. construct the member field + this->dimension = this->meta_.GetKeyValue("dimension"); +} + +KVStateCache::~KVStateCache() { + // TBD +} + +KVStateCacheBuilder::KVStateCacheBuilder(Client& client, int dimension, + int cache_capacity) { + this->dimension = dimension; + this->version = 0; + this->kv_state_cache_block_builder = + std::make_shared(client, this->dimension); + + this->root_tree = std::make_shared(cache_capacity); + this->root_tree->SetCustomData(this->kv_state_cache_block_builder.get(), + sizeof(KVStateCacheBlockBuilder)); +} + +KVStateCacheBuilder::KVStateCacheBuilder(Client& client, + std::shared_ptr cache) { + // TBD + this->dimension = cache->GetDemension(); + this->version = cache->GetVersion(); + this->kv_state_cache_block_builder = + std::make_shared(client, + cache->GetKVStateCacheBlock()); + + this->root_tree = cache->GetRootTree(); + this->root_tree->SetCustomData(this->kv_state_cache_block_builder.get(), + sizeof(KVStateCacheBlockBuilder)); +} + +KVStateCacheBlockBuilder* KVStateCacheBuilder::Split( + Client& client, KVStateCacheBlockBuilder* 
kv_state_cache_block_builder, + std::vector> node_with_tree_attri_list) { + // Split the tree if the list of kv_state is full. + VINEYARD_ASSERT(node_with_tree_attri_list.size() > 0); + KVStateCacheBlockBuilder* child_kv_state_cache_block_builder = + new KVStateCacheBlockBuilder(client, this->dimension); + for (size_t i = 0; i < node_with_tree_attri_list.size(); i++) { + offset_data* data = + (offset_data*) node_with_tree_attri_list[i]->get_node()->get_data(); + int index = data->offset; + + // Transfer the data from this builder to the child builder. + const std::shared_ptr> k_builder = + kv_state_cache_block_builder->getKBuilder(); + const std::shared_ptr> v_builder = + kv_state_cache_block_builder->getVBuilder(); + offset_data* new_offset_data = new offset_data(); + child_kv_state_cache_block_builder->Update( + k_builder->data() + index * this->dimension, + v_builder->data() + index * this->dimension, + this->dimension * sizeof(double), new_offset_data); + node_with_tree_attri_list[i]->get_node()->set_data(new_offset_data, + sizeof(offset_data)); + // Clear the bitmap. + kv_state_cache_block_builder->DeleteKVCache(index); + } + kv_state_cache_block_builder->SetChildKVStateCacheBlockBuilder( + child_kv_state_cache_block_builder); + return child_kv_state_cache_block_builder; +} + +void KVStateCacheBuilder::Update(Client& client, + const std::vector& token_list, + int next_token, + const KV_STATE_WITH_LAYER& kv_state) { + LOG(INFO) << "update"; + std::vector token_list_copy = token_list; + token_list_copy.push_back(next_token); + + // Create a empty node of tokens from radix tree. 
+ std::shared_ptr evicted_node = nullptr; + std::shared_ptr node_with_tree_attri = + this->root_tree->Insert(token_list_copy, evicted_node); + if (node_with_tree_attri == nullptr) { + LOG(INFO) << "insert failed"; + return; + } + std::shared_ptr sub_tree = node_with_tree_attri->get_tree(); + KVStateCacheBlockBuilder* kv_state_cache_block_builder = + (KVStateCacheBlockBuilder*) sub_tree->GetCustomData(); + if (evicted_node != nullptr) { + offset_data* data = (offset_data*) evicted_node->get_node()->get_data(); + KVStateCacheBlockBuilder* builder = + (KVStateCacheBlockBuilder*) evicted_node->get_tree()->GetCustomData(); + builder->DeleteKVCache(data->offset); + + delete (offset_data*) evicted_node->get_node()->get_data(); + } + + // TBD + // Use lock to protect the kv_state_cache_builder + // kv_state_cache_builder->Lock(); + + if (kv_state_cache_block_builder->IsFull()) { + /** + * If the kv-state cache of the tree is full, triggle split. Delete the + * empty node from the radix tree and split the tree. Then, kv-state cache + * split according to the new tree. + */ + std::shared_ptr evicted_node = nullptr; + this->root_tree->Delete(token_list_copy, evicted_node); + std::shared_ptr new_tree = sub_tree->Split(token_list_copy); + + std::vector> node_with_tree_attri_list = + RadixTree::TraverseTreeWithoutSubTree(new_tree); + KVStateCacheBlockBuilder* new_kv_state_cache_block_builder = + Split(client, kv_state_cache_block_builder, node_with_tree_attri_list); + new_tree->SetCustomData(new_kv_state_cache_block_builder, + sizeof(KVStateCacheBlockBuilder)); + + // kv_state_cache_builder->UnLock(); + Update(client, token_list, next_token, kv_state); + } else { + // Update the kv-state cache. 
+ offset_data* data = new offset_data(); + kv_state_cache_block_builder->Update(kv_state, data); + std::shared_ptr node = node_with_tree_attri->get_node(); + node->set_data(data, sizeof(offset_data)); + } + + LOG(INFO) << "bitmap:" << kv_state_cache_block_builder->GetBitmapStr(); +} + +static std::shared_ptr node; + +KV_STATE_WITH_LAYER KVStateCacheBuilder::Query( + Client& client, const std::vector& token_list, int token) { + std::vector token_list_copy = token_list; + token_list_copy.push_back(token); + + KV_STATE_WITH_LAYER kv_state; + std::shared_ptr node_with_tree_attri = + this->root_tree->Query(token_list_copy); + /**/ + if (node_with_tree_attri != nullptr) { + offset_data* data = + (offset_data*) node_with_tree_attri->get_node()->get_data(); + int offset = data->offset; + + KVStateCacheBlockBuilder* kv_state_cache_block_builder = + (KVStateCacheBlockBuilder*) node_with_tree_attri->get_tree() + ->GetCustomData(); + // kv_state_cache_builder->Lock(); + + kv_state_cache_block_builder->Query(client, offset, kv_state); + // kv_state_cache_builder->UnLock(); + node = node_with_tree_attri->get_node(); + } + LOG(INFO) << "query success"; + return kv_state; +} + +std::shared_ptr KVStateCacheBuilder::Merge( + Client& client, std::shared_ptr kv_state_cache) { + // TBD + if (kv_state_cache == nullptr) { + return nullptr; + } + // VINEYARD_ASSERT(false); + return nullptr; +} + +Status KVStateCacheBuilder::Build(Client& client) { + // TBD + return Status::OK(); +} + +std::shared_ptr KVStateCacheBuilder::_Seal(Client& client) { + this->Build(client); + + std::shared_ptr kv_state_cache = + std::make_shared(); + + // 1. store the member variables to cache object meta + kv_state_cache->meta_.AddKeyValue("dimension", this->dimension); + + // 2. seal all the kv_state_cache_block + // 3. put cache_block_object_id to cache object meta + kv_state_cache->meta_.AddMember( + "root_kv_state_cache_block", + this->kv_state_cache_block_builder->_Seal(client)); + + // 4. 
put the serialized sequence radix tree to cache object meta + kv_state_cache->meta_.AddKeyValue( + "radix_tree", base64_encode(this->root_tree->Serialize())); + + // 5. put the object type to the meta + kv_state_cache->meta_.SetTypeName(type_name()); + + VINEYARD_CHECK_OK( + client.CreateMetaData(kv_state_cache->meta_, kv_state_cache->id_)); + LOG(INFO) << "KVStateCacheBuilder::_Seal: " << kv_state_cache->id_; + return kv_state_cache; +} + +KVStateCacheBuilder::~KVStateCacheBuilder() { + // TBD + std::vector> node_with_tree_attri_list = + RadixTree::TraverseTreeWithoutSubTree(this->root_tree); + for (size_t i = 0; i < node_with_tree_attri_list.size(); i++) { + delete (offset_data*) node_with_tree_attri_list[i]->get_node()->get_data(); + } +} + +} // namespace vineyard diff --git a/modules/kv-state-cache/ds/kv_state_cache.h b/modules/kv-state-cache/ds/kv_state_cache.h new file mode 100644 index 00000000..f1e58e12 --- /dev/null +++ b/modules/kv-state-cache/ds/kv_state_cache.h @@ -0,0 +1,111 @@ +/** Copyright 2020-2023 Alibaba Group Holding Limited. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +#include +#include + +#include "client/client.h" +#include "common/util/logging.h" +#include "common/util/status.h" +#include "kv-state-cache/radix-tree/radix-tree.h" +#include "kv-state-cache/strategy/LRU_strategy.h" +#include "kv_state_cache_block.h" + +#ifndef MODULES_KV_STATE_CACHE_H_ +#define MODULES_KV_STATE_CACHE_H_ + +namespace vineyard { + +class KVStateCache : public vineyard::Registered { + private: + std::shared_ptr kv_state_cache_block; + std::shared_ptr root_tree; + int dimension; + int cache_capacity; + uint64_t version; + + public: + static std::unique_ptr Create() __attribute__((used)) { + return std::static_pointer_cast( + std::unique_ptr{new KVStateCache()}); + } + + void Construct(const ObjectMeta& meta) override; + + void Resolve(); + + // for test + std::shared_ptr GetKVStateCacheBlock() { + return this->kv_state_cache_block; + } + + int GetDemension() { return this->dimension; } + + int GetCacheCapacity() { return this->cache_capacity; } + + uint64_t GetVersion() { return this->version; } + + std::shared_ptr GetRootTree() { return this->root_tree; } + + ~KVStateCache(); + + friend class KVStateCacheBuilder; +}; + +class KVStateCacheBuilder : public vineyard::ObjectBuilder { + std::shared_ptr kv_state_cache_block_builder; + std::shared_ptr root_tree; + int dimension; + uint64_t version; + + public: + KVStateCacheBuilder(Client& client, int dimension, int cache_capacity); + + KVStateCacheBuilder(Client& client, std::shared_ptr cache); + + KVStateCacheBlockBuilder* Split( + Client& client, KVStateCacheBlockBuilder* kv_state_cache_block_builder, + std::vector> + node_with_tree_attri_list); + + void Update(Client& client, const std::vector& token_list, + int next_token, const KV_STATE_WITH_LAYER& kv_state); + + KV_STATE_WITH_LAYER Query(Client& client, const std::vector& token_list, + int token); + + std::shared_ptr Merge( + Client& client, std::shared_ptr kv_state_cache); + + uint64_t GetVersion() { return this->version; } + + Status 
Build(Client& client) override; + + std::shared_ptr _Seal(Client& client) override; + + std::shared_ptr GetKVStateCacheBlockBuilder() { + return this->kv_state_cache_block_builder; + } + + uint64_t GetDemension() { return this->dimension; } + + std::shared_ptr GetRootTree() { return this->root_tree; } + + ~KVStateCacheBuilder(); +}; + +} // namespace vineyard + +#endif \ No newline at end of file diff --git a/modules/kv-state-cache/ds/kv_state_cache_block.cc b/modules/kv-state-cache/ds/kv_state_cache_block.cc new file mode 100644 index 00000000..64079ded --- /dev/null +++ b/modules/kv-state-cache/ds/kv_state_cache_block.cc @@ -0,0 +1,227 @@ +/** Copyright 2020-2023 Alibaba Group Holding Limited. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include "kv_state_cache_block.h" +#include "client/client.h" +#include "common/util/logging.h" + +namespace vineyard { + +// this function will be removed in the future +std::string KVStateCacheBlock::GetBitmapStr() { + std::string result; + const int bits = 8 * sizeof(unsigned long long); + for (int i = bits - 1; i >= 0; --i) { + result += ((this->bitmap >> i) & 1) ? '1' : '0'; + } + return result; +} + +std::string KVStateCacheBlockBuilder::GetBitmapStr() { + std::string result; + const int bits = 8 * sizeof(unsigned long long); + for (int i = bits - 1; i >= 0; --i) { + result += ((this->bitmap >> i) & 1) ? 
'1' : '0'; + } + return result; +} + +void KVStateCacheBlock::Construct(const ObjectMeta& meta) { + Object::Construct(meta); + + std::string typeName = type_name(); + + VINEYARD_ASSERT(meta.GetTypeName() == typeName, + "Expect typename '" + typeName + "', but got '" + + meta.GetTypeName() + "'"); + + // TBD + // 1. construct the k_builder and v_builder + this->k_tensor = std::dynamic_pointer_cast>( + this->meta_.GetMember("k_builder")); + this->v_tensor = std::dynamic_pointer_cast>( + this->meta_.GetMember("v_builder")); + // 2. construct the child kv_state_cache_block_builder + int child_num = this->meta_.GetKeyValue("child_num"); + for (int i = 0; i < child_num; ++i) { + std::shared_ptr child_kv_state_cache_block_builder = + std::dynamic_pointer_cast(this->meta_.GetMember( + "child_kv_state_cache_block_" + std::to_string(i))); + this->child_kv_state_cache_block_list.push_back( + child_kv_state_cache_block_builder); + } + // 3. construct the member field + this->bitmap = this->meta_.GetKeyValue("bitmap"); + this->dimension = this->meta_.GetKeyValue("dimension"); +} + +KVStateCacheBlockBuilder::KVStateCacheBlockBuilder(Client& client, + int dimension) { + pthread_spin_init(&(this->spin_lock), 0); + this->bitmap = UINT64_MAX; + std::vector shape = {LIST_SIZE, dimension}; + this->k_builder = std::make_shared>(client, shape); + this->v_builder = std::make_shared>(client, shape); + this->dimension = dimension; +} + +KVStateCacheBlockBuilder::KVStateCacheBlockBuilder( + Client& client, std::shared_ptr kv_state_cache_block) { + pthread_spin_init(&(this->spin_lock), 0); + this->bitmap = kv_state_cache_block->bitmap; + this->dimension = kv_state_cache_block->dimension; + std::vector shape = {LIST_SIZE, dimension}; + this->k_builder = std::make_shared>(client, shape); + this->v_builder = std::make_shared>(client, shape); + + // transfer the data from kv_state_cache to this builder + memcpy(this->k_builder->data(), kv_state_cache_block->k_tensor->data(), + LIST_SIZE * 
this->dimension * sizeof(double)); + memcpy(this->v_builder->data(), kv_state_cache_block->v_tensor->data(), + LIST_SIZE * this->dimension * sizeof(double)); + for (size_t i = 0; + i < kv_state_cache_block->child_kv_state_cache_block_list.size(); ++i) { + this->child_kv_state_cache_builder_list.push_back( + new KVStateCacheBlockBuilder( + client, kv_state_cache_block->child_kv_state_cache_block_list[i])); + } +} + +// current we do not consider the layer. +Status KVStateCacheBlockBuilder::Query(Client& client, int index, + KV_STATE_WITH_LAYER& kv_state) { + std::vector k_state; + std::vector v_state; + + for (int i = 0; i < this->dimension; ++i) { + k_state.push_back(((double*) k_builder->data())[index * dimension + i]); + } + + for (int i = 0; i < this->dimension; ++i) { + v_state.push_back(((double*) v_builder->data())[index * dimension + i]); + } + + kv_state.insert(std::make_pair(1, std::make_pair(k_state, v_state))); + return Status::OK(); +} + +int KVStateCacheBlockBuilder::FindEmptySlot() { + int index = ffsll(this->bitmap) - 1; + VINEYARD_ASSERT(index >= 0 && index < LIST_SIZE); + return index; +} + +bool KVStateCacheBlockBuilder::IsFull() { + int index = ffsll(this->bitmap) - 1; + return index < 0 || index >= LIST_SIZE; +} + +void KVStateCacheBlockBuilder::Update(const KV_STATE_WITH_LAYER& kv_state, + offset_data* data) { + int index = this->FindEmptySlot(); + LOG(INFO) << "index:" << index; + std::vector k_state = (kv_state.find(1)->second).first; + std::vector v_state = (kv_state.find(1)->second).second; + VINEYARD_ASSERT(k_state.size() == (size_t) this->dimension); + VINEYARD_ASSERT(v_state.size() == (size_t) this->dimension); + + double* key_data = (double*) k_builder->data(); + double* value_data = (double*) v_builder->data(); + for (int i = 0; i < this->dimension; ++i) { + key_data[index * this->dimension + i] = k_state[i]; + } + for (int i = 0; i < this->dimension; ++i) { + value_data[index * this->dimension + i] = v_state[i]; + } + data->offset = 
index; + + LOG(INFO) << "before:" << this->bitmap; + ACQUIRE_BIT_RESOURCE(this->bitmap, index); + LOG(INFO) << "after:" << this->bitmap; +} + +void KVStateCacheBlockBuilder::Update(double* k_data, double* v_data, + unsigned long data_length, + offset_data* data) { + int index = FindEmptySlot(); + double* key_data = (double*) k_builder->data(); + double* value_data = (double*) v_builder->data(); + VINEYARD_ASSERT((unsigned long) this->dimension == data_length); + for (unsigned long i = 0; i < data_length; ++i) { + key_data[index * this->dimension + i] = k_data[i]; + } + for (unsigned long i = 0; i < data_length; ++i) { + value_data[index * this->dimension + i] = v_data[i]; + } + data->offset = index; + + ACQUIRE_BIT_RESOURCE(this->bitmap, index); + LOG(INFO) << "bitmap:" << this->GetBitmapStr(); +} + +void KVStateCacheBlockBuilder::SetChildKVStateCacheBlockBuilder( + KVStateCacheBlockBuilder* child_kv_state_cache_builder) { + this->child_kv_state_cache_builder_list.push_back( + child_kv_state_cache_builder); +} + +Status KVStateCacheBlockBuilder::Build(Client& client) { + // TBD craete vineyard object + // pthread_spin_lock(&(this->spin_lock)); + // ObjectMeta meta; + // meta.SetTypeName(type_name()); + // meta.AddKeyValue("bitmap", this->bitmap); + // for (int i = 0; i < LIST_SIZE; ++i) { + // // TBD + // // create tensor meta + // } + // // TBD check the status + // client.CreateMetaData(meta, id); + // pthread_spin_unlock(&(this->spin_lock)); + return Status::OK(); +} + +std::shared_ptr KVStateCacheBlockBuilder::_Seal(Client& client) { + this->Build(client); + // pthread_spin_lock(&(this->spin_lock)); + // pthread_spin_unlock(&(this->spin_lock)); + + std::shared_ptr kv_state_cache_block = + std::make_shared(); + + // TBD + // 1. seal k_builder and v_builder + kv_state_cache_block->meta_.AddMember("k_builder", k_builder->Seal(client)); + kv_state_cache_block->meta_.AddMember("v_builder", v_builder->Seal(client)); + // 2. 
seal child kv_state_cache_block_builder + for (size_t i = 0; i < this->child_kv_state_cache_builder_list.size(); ++i) { + kv_state_cache_block->meta_.AddMember( + "child_kv_state_cache_block_" + std::to_string(i), + this->child_kv_state_cache_builder_list[i]->_Seal(client)); + } + kv_state_cache_block->meta_.AddKeyValue( + "child_num", this->child_kv_state_cache_builder_list.size()); + // 3. store the member field to meta + kv_state_cache_block->meta_.AddKeyValue("bitmap", this->bitmap); + kv_state_cache_block->meta_.AddKeyValue("dimension", this->dimension); + // 4. set the object type to meta + kv_state_cache_block->meta_.SetTypeName(type_name()); + + VINEYARD_CHECK_OK(client.CreateMetaData(kv_state_cache_block->meta_, + kv_state_cache_block->id_)); + return kv_state_cache_block; +} + +} // namespace vineyard diff --git a/modules/kv-state-cache/ds/kv_state_cache_block.h b/modules/kv-state-cache/ds/kv_state_cache_block.h new file mode 100644 index 00000000..222c2ba4 --- /dev/null +++ b/modules/kv-state-cache/ds/kv_state_cache_block.h @@ -0,0 +1,166 @@ +/** Copyright 2020-2023 Alibaba Group Holding Limited. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +#ifndef MODULES_KV_STATE_CACHE_BLOCK_H_ +#define MODULES_KV_STATE_CACHE_BLOCK_H_ + +#include +#include +#include +#include + +#include "basic/ds/tensor.h" +#include "client/ds/blob.h" +#include "client/ds/i_object.h" + +typedef std::map, std::vector>> + KV_STATE_WITH_LAYER; +typedef std::vector< + std::map, std::vector>>> + LIST_KV_STATE_WITH_LAYER; + +// Set the bit to 1, which means the resource is not being used +#define FREE_BIT_RESOURCE(value, bit) ((value) |= (((uint64_t) 1) << (bit))) + +// Set the bit to 0, which means the resource is being used +#define ACQUIRE_BIT_RESOURCE(value, bit) \ + ((value) &= (~(((uint64_t) 1) << (bit)))) + +struct offset_data { + short offset; +}; + +namespace vineyard { + +#define LIST_SIZE 64 + +/** + * @brief KVStateCacheBlock is a cache for kv-cache of LLM. When a new prompt + * comes, LLM can query KVStateCacheBlock to get the state of the kv-cache to + * avoid caclulating the kv-cache again if the new prompt is similar to the + * previous one. + * + * KVStateCacheBlock is stored in vineyard as a vineyard object which contains a + * radix tree. The token sequence is the key of the radix tree and the value + * point out the offset of the kv-cache in the tensor list. + * + * KVStateCacheBlock can be shared by multiple machines. 
+ */ + +class KVStateCacheBlock : public vineyard::Registered { + private: + std::shared_ptr> k_tensor; + std::shared_ptr> v_tensor; + std::vector> + child_kv_state_cache_block_list; + uint64_t bitmap; + ObjectID id; + int dimension; + + public: + static std::unique_ptr Create() __attribute__((used)) { + return std::static_pointer_cast( + std::unique_ptr{new KVStateCacheBlock()}); + } + + void Construct(const ObjectMeta& meta) override; + + std::string GetBitmapStr(); + + uint64_t GetDimension() { return this->dimension; } + + uint64_t GetBitmap() { return this->bitmap; } + + std::shared_ptr> GetKTensor() { return this->k_tensor; } + + std::shared_ptr> GetVTensor() { return this->v_tensor; } + + friend class KVStateCacheBlockBuilder; +}; + +class KVStateCacheBlockBuilder : public ObjectBuilder { + private: + std::shared_ptr> k_builder; + std::shared_ptr> v_builder; + std::vector child_kv_state_cache_builder_list; + // TBD + // support more than 64 kv-state cache slots + uint64_t bitmap; + pthread_spinlock_t spin_lock; + int dimension; + + int FindEmptySlot(); + + public: + KVStateCacheBlockBuilder(Client& client, int dimension); + + KVStateCacheBlockBuilder( + Client& client, std::shared_ptr kv_state_cache_block); + + /** + * @brief Update the kv-state using next token. + * + * @param client The vineyard client. + * @param kv_state The kv-state of the prompt. A LLM inference can contain + * multiple kv-states for each layer. + */ + void Update(const KV_STATE_WITH_LAYER& kv_state, offset_data* data); + + void Update(double* k_data, double* v_data, unsigned long data_length, + offset_data* data); + + /** + * @brief Query the kv-state using the whole token list. + * + * @param token_list The token list of the prompt. + * @param token The token of the prompt. + * @param kv_state The kv-state of the prompt returned by radix-tree. If the + * kv-state is not found, the data of kv-state is invalid. 
+ */ + Status Query(Client& client, int index, KV_STATE_WITH_LAYER& kv_state); + + bool IsFull(); + + Status Build(Client& client) override; + + std::shared_ptr _Seal(Client& client) override; + + void Lock() { pthread_spin_lock(&(this->spin_lock)); } + + void UnLock() { pthread_spin_unlock(&(this->spin_lock)); } + + const std::shared_ptr> getKBuilder() { + return k_builder; + } + + const std::shared_ptr> getVBuilder() { + return v_builder; + } + + void DeleteKVCache(int bit) { FREE_BIT_RESOURCE(this->bitmap, bit); } + + void SetChildKVStateCacheBlockBuilder( + KVStateCacheBlockBuilder* child_kv_state_cache_builder); + + std::string GetBitmapStr(); + + uint64_t GetBitmap() { return this->bitmap; } + + uint64_t GetDimension() { return this->dimension; } +}; + +} // namespace vineyard + +#endif \ No newline at end of file diff --git a/modules/kv-state-cache/lz4 b/modules/kv-state-cache/lz4 new file mode 160000 index 00000000..4cf83dd1 --- /dev/null +++ b/modules/kv-state-cache/lz4 @@ -0,0 +1 @@ +Subproject commit 4cf83dd1952898e2a8c5fcd689ce459c53f22ff0 diff --git a/modules/kv-state-cache/radix-tree/radix-tree.h b/modules/kv-state-cache/radix-tree/radix-tree.h new file mode 100644 index 00000000..ce46360e --- /dev/null +++ b/modules/kv-state-cache/radix-tree/radix-tree.h @@ -0,0 +1,473 @@ +/** Copyright 2020-2023 Alibaba Group Holding Limited. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +#ifndef RADIX_TREE_H +#define RADIX_TREE_H + +#include "radix.h" + +#include "common/util/base64.h" +#include "common/util/logging.h" +#include "kv-state-cache/strategy/LRU_strategy.h" +#include "lz4.h" + +#include +#include +#include +#include + +using namespace vineyard; + +typedef struct nodeData { + int data_length; + void* data; + std::shared_ptr cache_node; +} nodeData; + +class Node { + private: + nodeData* data; + raxNode* node; + + public: + Node(raxNode* node) { + this->data = (nodeData*) raxGetData(node); + this->node = node; + } + + Node(nodeData* data) { + this->data = data; + this->node = NULL; + } + + void set_data(void* data, int data_length) { + if (this->node == NULL) { + LOG(INFO) << "set data failed, node is null"; + return; + } + this->data->data = data; + this->data->data_length = data_length; + raxSetData(this->node, this->data); + } + + void set_cache_node(std::shared_ptr cache_node) { + if (this->node == NULL) { + LOG(INFO) << "set data failed, node is null"; + return; + } + this->data->cache_node = cache_node; + raxSetData(this->node, this->data); + } + + void* get_data() { return this->data->data; } + + int get_data_length() { return this->data->data_length; } + + std::shared_ptr get_cache_node() { + return this->data->cache_node; + } +}; + +class RadixTree; + +class NodeWithTreeAttri { + private: + std::shared_ptr node; + std::shared_ptr belong_to; + + public: + NodeWithTreeAttri(std::shared_ptr node, + std::shared_ptr belong_to) { + this->node = node; + this->belong_to = belong_to; + } + + std::shared_ptr get_node() { return node; } + + std::shared_ptr get_tree() { return belong_to; } +}; + +class RadixTree : public std::enable_shared_from_this { + private: + void* custom_data; + int custom_data_length; + // the whole radix tree for prefix match + rax* tree; + // the sub tree for mapping a vineyard object + // rax* sub_tree; + LRUStrategy* lru_strategy; + + public: + RadixTree(int cache_capacity) { + LOG(INFO) << "init radix 
tree"; + this->tree = raxNew(); + // this->sub_tree = this->tree; + this->custom_data = NULL; + this->custom_data_length = 0; + lru_strategy = new LRUStrategy(cache_capacity); + } + + RadixTree(void* custom_data, int custom_data_length, int cache_capacity) { + LOG(INFO) << "init radix tree with custom data"; + this->tree = raxNew(); + // this->sub_tree = this->tree; + this->custom_data = custom_data; + this->custom_data_length = custom_data_length; + this->lru_strategy = new LRUStrategy(cache_capacity); + } + + ~RadixTree() { + // TBD + // free all the node and the whole tree. + } + + std::shared_ptr Insert( + std::vector tokens, + std::shared_ptr evicted_node) { + // insert the token vector to the radix tree. NOTE(review): evicted_node is taken by value, so the node assigned by Delete() below never reaches the caller. + int* insert_tokens_array = tokens.data(); + size_t insert_tokens_array_len = tokens.size(); + nodeData* dummy_data = new nodeData(); + nodeData* old_data; + raxNode* dataNode = NULL; + int retval = raxInsertAndReturnDataNode( + this->tree, insert_tokens_array, insert_tokens_array_len, dummy_data, + (void**) &dataNode, (void**) &old_data); + if (dataNode == NULL) { + LOG(INFO) << "insert failed"; + return NULL; + } + LOG(INFO) << "insert success"; + + if (retval == 0) { + // (retval == 0 ) means the token vector already exists in the radix tree + // remove the token vector from the lru cache as it will be inserted again + std::shared_ptr node = std::make_shared(old_data); + std::shared_ptr cache_node = node->get_cache_node(); + lru_strategy->Remove(cache_node); + delete old_data; + } + + // refresh the lru cache + std::vector evicted_tokens; + std::shared_ptr cache_node = + lru_strategy->InsertToHeader(tokens, evicted_tokens); + if (cache_node == nullptr) { + LOG(INFO) << "WTF?"; + } + dummy_data->cache_node = cache_node; + raxSetData(dataNode, dummy_data); + if (evicted_tokens.size() > 0) { + this->Delete(evicted_tokens, evicted_node); + } + + return 
std::make_shared(std::make_shared(dataNode), + shared_from_this()); + } + + void Delete(std::vector tokens, +
std::shared_ptr& evicted_node) { + // remove the token vector from the radix tree. NOTE(review): on success evicted_node wraps old_data, which is deleted right below, so its data must not be dereferenced by the caller. + int* delete_tokens_array = tokens.data(); + size_t delete_tokens_array_len = tokens.size(); + + nodeData* old_data; + int retval = raxRemove(this->tree, delete_tokens_array, + delete_tokens_array_len, (void**) &old_data); + if (retval == 1) { + LOG(INFO) << "remove success"; + std::shared_ptr node = std::make_shared(old_data); + evicted_node = + std::make_shared(node, shared_from_this()); + delete old_data; + } else { + LOG(INFO) << "remove failed"; + } + } + + std::shared_ptr Query(std::vector key) { + LOG(INFO) << "Query"; + int* tokens = key.data(); + size_t tokens_len = key.size(); + + LOG(INFO) << "Query with tokens_len:" << tokens_len; + if (this->tree == nullptr) { + LOG(INFO) << "WTF!"; + return NULL; + } + + raxNode* dataNode = + raxFindAndReturnDataNode(this->tree, tokens, tokens_len); + if (dataNode == NULL) { + LOG(INFO) << "get failed"; + return NULL; + } + LOG(INFO) << "get success"; + + // refresh the lru cache + std::shared_ptr node = std::make_shared(dataNode); + std::shared_ptr cache_node = node->get_cache_node(); + lru_strategy->MoveToHead(cache_node); + + return std::make_shared(node, shared_from_this()); + } + + std::string Serialize() { + std::vector> token_list; + std::vector data_list; + raxSerialize(this->tree, token_list, data_list); + + std::map, bool> cache_node_map; + std::shared_ptr current_node = + this->lru_strategy->GetHeader(); + + // the string format is: + // [token list]|[raw data bytes]\n + // E.g + // tokens | data + // 1,2|0800000008000000xxxx + std::string serialized_str; + while (current_node != nullptr) { + cache_node_map[current_node] = true; + auto it = std::lower_bound(token_list.begin(), token_list.end(), + current_node->tokens); + + if (it != token_list.end() && *it == current_node->tokens) { + // get the 
index of the token vector via binary search + int index = std::distance(token_list.begin(), it); + for (size_t i = 0; i < 
(*it).size(); i++) { + serialized_str += std::to_string((*it)[i]); + if (i < (*it).size() - 1) { + serialized_str += ","; + } + } + // serialized_str += "|" + std::to_string(index) + "|"; + serialized_str += "|"; + + // append the raw data bytes (not hex-encoded). NOTE(review): a payload that contains '\n' or has length 0 cannot be parsed back by Deserialize(). + char* bytes = (char*) ((nodeData*) data_list[index])->data; + std::ostringstream oss; + + for (int i = 0; i < ((nodeData*) data_list[index])->data_length; ++i) { + oss << bytes[i]; + } + serialized_str += oss.str() + "\n"; + } else { + throw std::runtime_error("The token vector is not in the radix tree"); + } + current_node = current_node->next; + } + + // use LZ4 to compress the serialized string + const char* const src = serialized_str.c_str(); + const int src_size = serialized_str.size(); + const int max_dst_size = LZ4_compressBound(src_size); + char* compressed_data = new char[max_dst_size]; + + const int compressed_data_size = + LZ4_compress_default(src, compressed_data, src_size, max_dst_size); + if (compressed_data_size <= 0) { + LOG(INFO) << "A 0 or negative result from LZ4_compress_default() " + "indicates a failure trying to compress the data. "; + delete[] compressed_data; + throw std::runtime_error( + "Failed to compress the serialized radix tree."); + } + + if (compressed_data_size > 0) { + LOG(INFO) << "We successfully compressed some data! Ratio: " + << ((float) compressed_data_size / src_size); + } + + // compressed_data = + // (char*) realloc(compressed_data, (size_t) compressed_data_size); + if (compressed_data == NULL) { + LOG(INFO) << "Failed to re-alloc memory for compressed_data. 
Sad :("; + } + + std::string compressed_str = + std::string(compressed_data, compressed_data_size); + std::string result = + std::string((char*) &src_size, sizeof(int)) + compressed_str; + delete[] compressed_data; + return result; + } + + static std::shared_ptr Deserialize(std::string data) { + // use LZ4 to decompress the serialized string + if (data.size() < sizeof(int)) { + throw std::runtime_error("Serialized radix tree is truncated."); + } + int src_size = *(int*) data.c_str(); + data.erase(0, sizeof(int)); + char* const decompress_buffer = new char[src_size]; + + const int decompressed_size = LZ4_decompress_safe( + data.c_str(), decompress_buffer, data.size(), src_size); + if (decompressed_size < 0) { + LOG(INFO) << "A negative result from LZ4_decompress_safe indicates a " + "failure trying to decompress the data. See exit code " + "(echo $?) for value returned."; + delete[] decompress_buffer; + throw std::runtime_error("Failed to decompress the radix tree."); + } + LOG(INFO) << "We successfully decompressed some data!"; + // if (decompressed_size != data.size()) { + // LOG(INFO) << "Decompressed data is different from original! 
\n"; + // } + data = std::string(decompress_buffer, decompressed_size); + delete[] decompress_buffer; + + std::vector> token_list; + std::vector data_list; + std::vector data_size_list; + std::istringstream iss(data); + std::string line; + + while (std::getline(iss, line)) { + std::istringstream lineStream(line); + std::string tokenListPart, dataPart; + + if (!std::getline(lineStream, tokenListPart, '|')) { + throw std::runtime_error( + "Invalid serialized string format in key part."); + } + if (!std::getline(lineStream, dataPart)) { + throw std::runtime_error( + "Invalid serialized string format in data part."); + } + + std::istringstream keyStream(tokenListPart); + std::string token; + std::vector keys; + while (std::getline(keyStream, token, ',')) { + keys.push_back(std::stoi(token)); + } + + // size_t dataSize = dataPart.length() / 2; + size_t dataSize = dataPart.length(); + data_size_list.push_back(dataSize); + + // This pointer will be freed by upper layer. Because this data + // is created by upper layer. Here just recover it from serialized + // string. 
+ char* data = new char[dataSize]; + std::istringstream dataStream(dataPart); + // for (size_t i = 0; i < dataSize; ++i) { + // // Temporary buffer to store two hexadecimal chars + null + // terminator char hex[3] = {}; + // // Read two characters for one byte + // if (!dataStream.read(hex, 2)) { + // delete[] data; + // LOG(INFO) << "Invalid data format."; + // throw std::runtime_error("Invalid data format."); + // } + // // Convert the two hex characters to one byte + // unsigned int byte; + // std::istringstream hexStream(hex); + // if (!(hexStream >> std::hex >> byte)) { + // delete[] data; + // LOG(INFO) << "Invalid data format."; + // throw std::runtime_error("Invalid data format."); + // } + // reinterpret_cast(data)[i] = static_cast(byte); + // } + if (!dataStream.read(data, dataSize)) { + delete[] data; + throw std::runtime_error("Invalid data."); + } + + token_list.push_back(keys); + data_list.push_back(data); + } + + // NOTE(review): the capacity of the rebuilt tree is hard-coded to 10 below; + // the original capacity is not part of the serialized form, so larger + // trees presumably evict entries while being re-inserted. 
+ std::shared_ptr radix_tree = std::make_shared(10); + // nodeData* dummy_data = new nodeData(); + // rax *root = raxNew(); + // for (int i = token_list.size()-1; i >= 0; i--) { + // LOG(INFO) << "insert token list:"; + // for (int j = 0; j < token_list[i].size(); j++) { + // LOG(INFO) << token_list[i][j]; + // } + // if (raxInsert(root, token_list[i].data(), token_list[i].size(), + // dummy_data, NULL) != 1) { + // LOG(INFO) << "Insert failed"; + // return NULL; + // } + // std::vector evicted_tokens; + // std::shared_ptr cache_node = + // radix_tree->lru_strategy->InsertToHeader(token_list[i], + // evicted_tokens); + // if (cache_node == nullptr) { + // LOG(INFO) << "WTF?"; + // } + // dummy_data->cache_node = cache_node; + // } + for (int i = token_list.size() - 1; i >= 0; i--) { + std::shared_ptr evicted_node; + std::shared_ptr node = + radix_tree->Insert(token_list[i], evicted_node); + node->get_node()->set_data(data_list[i], data_size_list[i]); + } + return radix_tree; + } + + std::shared_ptr Split(std::vector tokens) { + nodeData* dummy_data = new nodeData(); + raxNode* sub_tree_root_node = + raxSplit(this->tree, tokens.data(), tokens.size(), dummy_data); + + // TBD + // if the sub_tree is null, delete this pointer. NOTE(review): sub_rax below is never attached to sub_tree nor freed, and sub_tree->tree aliases this->tree. + std::shared_ptr sub_tree = + std::make_shared(this->lru_strategy->GetCapacity()); + sub_tree->tree = this->tree; + rax* sub_rax = raxNew(); + sub_rax->head = sub_tree_root_node; + return sub_tree; + } + + // Get child node list from this tree. 
+ static std::vector> + TraverseTreeWithoutSubTree(std::shared_ptr radix_tree) { + std::vector> nodes; + if (radix_tree == NULL) { + LOG(INFO) << "traverse failed"; + return nodes; + } + + std::vector> dataNodeList; + raxNode* headNode = radix_tree->tree->head; + raxTraverseSubTree(headNode, dataNodeList); + for (size_t i = 0; i < dataNodeList.size(); i++) { + nodes.push_back(std::make_shared( + std::make_shared(dataNodeList[i].get()), radix_tree)); + } + return nodes; + } + + void* GetCustomData() { return custom_data; } + + void SetCustomData(void* custom_data, int custom_data_length) { + this->custom_data = custom_data; + this->custom_data_length = custom_data_length; + } +}; + +#endif diff --git a/modules/kv-state-cache/radix-tree/radix.cc b/modules/kv-state-cache/radix-tree/radix.cc new file mode 100644 index 00000000..2a6eaf5f --- /dev/null +++ b/modules/kv-state-cache/radix-tree/radix.cc @@ -0,0 +1,2181 @@ +/* Rax -- A radix tree implementation. + * + * Version 1.2 -- 7 February 2019 + * + * Copyright (c) 2017-2019, Salvatore Sanfilippo + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * * Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of Redis nor the names of its contributors may be used + * to endorse or promote products derived from this software without + * specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + */ + +#include +#include +#include +#include +#include +#include +#include "radix.h" + +#ifndef RAX_MALLOC_INCLUDE +#define RAX_MALLOC_INCLUDE "rax_malloc.h" +#endif + +#include RAX_MALLOC_INCLUDE + +/* This is a special pointer that is guaranteed to never have the same value + * of a radix tree node. It's used in order to report "not found" error without + * requiring the function to have multiple return values. */ +void *raxNotFound = (void*)"rax-not-found-pointer"; + +/* -------------------------------- Debugging ------------------------------ */ + +void raxDebugShowNode(const char *msg, raxNode *n); + +/* Turn debugging messages on/off by compiling with RAX_DEBUG_MSG macro on. + * When RAX_DEBUG_MSG is defined by default Rax operations will emit a lot + * of debugging info to the standard output, however you can still turn + * debugging on/off in order to enable it only when you suspect there is an + * operation causing a bug using the function raxSetDebugMsg(). */ +#ifdef RAX_DEBUG_MSG +#define debugf(...) 
\ + if (raxDebugMsg) { \ + printf("%s:%s:%d:\t", __FILE__, __FUNCTION__, __LINE__); \ + printf(__VA_ARGS__); \ + fflush(stdout); \ + } + +#define debugnode(msg,n) raxDebugShowNode(msg,n) +#else +#define debugf(...) +#define debugnode(msg,n) +#endif + +/* By default log debug info if RAX_DEBUG_MSG is defined. */ +static int raxDebugMsg = 1; + +/* When debug messages are enabled, turn them on/off dynamically. By + * default they are enabled. Set the state to 0 to disable, and 1 to + * re-enable. */ +void raxSetDebugMsg(int onoff) { + raxDebugMsg = onoff; +} + +/* ------------------------- raxStack functions -------------------------- + * The raxStack is a simple stack of pointers that is capable of switching + * from using a stack-allocated array to dynamic heap once a given number of + * items are reached. It is used in order to retain the list of parent nodes + * while walking the radix tree in order to implement certain operations that + * need to navigate the tree upward. + * ------------------------------------------------------------------------- */ + +/* Initialize the stack. */ +static inline void raxStackInit(raxStack *ts) { + ts->stack = ts->static_items; + ts->items = 0; + ts->maxitems = RAX_STACK_STATIC_ITEMS; + ts->oom = 0; +} + +/* Push an item into the stack, returns 1 on success, 0 on out of memory. 
*/ +static inline int raxStackPush(raxStack *ts, void *ptr) { + if (ts->items == ts->maxitems) { + if (ts->stack == ts->static_items) { + ts->stack = (void **)rax_malloc(sizeof(void*)*ts->maxitems*2); + if (ts->stack == NULL) { + ts->stack = ts->static_items; + ts->oom = 1; + errno = ENOMEM; + return 0; + } + memcpy(ts->stack,ts->static_items,sizeof(void*)*ts->maxitems); + } else { + void **newalloc = (void **)rax_realloc(ts->stack,sizeof(void*)*ts->maxitems*2); + if (newalloc == NULL) { + ts->oom = 1; + errno = ENOMEM; + return 0; + } + ts->stack = newalloc; + } + ts->maxitems *= 2; + } + ts->stack[ts->items] = ptr; + ts->items++; + return 1; +} + +/* Pop an item from the stack, the function returns NULL if there are no + * items to pop. */ +static inline void *raxStackPop(raxStack *ts) { + if (ts->items == 0) return NULL; + ts->items--; + return ts->stack[ts->items]; +} + +/* Return the stack item at the top of the stack without actually consuming + * it. */ +static inline void *raxStackPeek(raxStack *ts) { + if (ts->items == 0) return NULL; + return ts->stack[ts->items-1]; +} + +/* Free the stack in case we used heap allocation. */ +static inline void raxStackFree(raxStack *ts) { + if (ts->stack != ts->static_items) rax_free(ts->stack); +} + +/* Add 'num' to the numnodes counter of every node currently saved in the stack. */ +void raxStackAddNumNodes(raxStack *stack, int num) { + for (int i=0; i < stack->items; i++) { + raxNode *node = (raxNode *)stack->stack[i]; + node->numnodes+=(num); + } +} +/* ---------------------------------------------------------------------------- + * Radix tree implementation + * --------------------------------------------------------------------------*/ + +/* Return the padding needed in the tokens section of a node having size + * 'nodesize'. The padding is needed to store the child pointers to aligned + * addresses. Note that we add 4 to the node size because the node has a four + * bytes header. 
NOTE(review): confirm the header is still four bytes now that raxNode also carries numnodes (radix.h is not visible here). 
*/ +#define raxPadding(nodesize) ((sizeof(void*) - (((nodesize) * sizeof(int) + 4) % sizeof(void*))) & (sizeof(void*)-1)) + +/* Return the pointer to the last child pointer in a node. For the compressed + * nodes this is the only child pointer. */ +#define raxNodeLastChildPtr(n) ((raxNode**) ( \ + ((char*)(n)) + \ + raxNodeCurrentLength(n) - \ + sizeof(raxNode*) - \ + (((n)->iskey && !(n)->isnull) ? sizeof(void*) : 0) \ +)) + +/* Return the pointer to the first child pointer. */ +#define raxNodeFirstChildPtr(n) ((raxNode**) ( \ + (char*)((n)->data) + \ + ((n)->size * sizeof(int)) + \ + raxPadding((n)->size))) + +/* Return the current total size of the node. Note that the third line + * computes the padding after the list of tokens, needed in order to + * save pointers to aligned addresses. */ +#define raxNodeCurrentLength(n) ( \ + sizeof(raxNode) + \ + (n)->size * sizeof(int) + \ + raxPadding((n)->size) + \ + ((n)->iscompr ? sizeof(raxNode*) : sizeof(raxNode*) * (n)->size) + \ + (((n)->iskey && !(n)->isnull) * sizeof(void*)) \ +) + +/* Allocate a new non compressed node with the specified number of children. + * If 'datafield' is true, the allocation is made large enough to hold the + * associated data pointer. The new node starts with numnodes set to 1. + * Returns the new node pointer. On out of memory NULL is returned. */ +raxNode *raxNewNode(size_t children, int datafield) { + size_t nodesize = sizeof(raxNode) + children * sizeof(int) + + raxPadding(children) + sizeof(raxNode*) * children; + if (datafield) nodesize += sizeof(void*); + raxNode *node = (raxNode *)rax_malloc(nodesize); + if (node == NULL) return NULL; + node->iskey = 0; + node->isnull = 0; + node->iscompr = 0; + node->numnodes = 1; + node->size = children; + return node; +} + +/* Allocate a new rax and return its pointer. On out of memory the function + * returns NULL. 
*/ +rax *raxNew(void) { + rax *rax = (struct rax *)rax_malloc(sizeof(*rax)); + if (rax == NULL) return NULL; + rax->numele = 0; + rax->numnodes = 1; + rax->head = raxNewNode(0,0); + if (rax->head == NULL) { + rax_free(rax); + return NULL; + } else { + return rax; + } +} + +/* Realloc the node to make room for the auxiliary data pointer in order + * to store an item in that node. When 'data' is NULL the node is returned unchanged. On out of memory NULL is returned. */ +raxNode *raxReallocForData(raxNode *n, void *data) { + if (data == NULL) return n; /* No reallocation needed, setting isnull=1 */ + size_t curlen = raxNodeCurrentLength(n); + return (raxNode *)rax_realloc(n,curlen+sizeof(void*)); +} + +/* Mark the node as a key and set its auxiliary data to the specified pointer. A NULL 'data' is recorded through the isnull flag and nothing is stored. The node must already have room for the pointer, see raxReallocForData(). */ +void raxSetData(raxNode *n, void *data) { + n->iskey = 1; + if (data != NULL) { + n->isnull = 0; + void **ndata = (void**) + ((char*)n+raxNodeCurrentLength(n)-sizeof(void*)); + memcpy(ndata,&data,sizeof(data)); + } else { + n->isnull = 1; + } +} + +/* Get the node auxiliary data, or NULL when the isnull flag is set. Only isnull is checked here, so 'n' is expected to be a key node. */ +void *raxGetData(raxNode *n) { + if (n->isnull) return NULL; + void **ndata =(void**)((char*)n+raxNodeCurrentLength(n)-sizeof(void*)); + void *data; + memcpy(&data,ndata,sizeof(data)); + return data; +} + +/* Add a new child to the node 'n' representing the token 'c' and return + * its new pointer, as well as the child pointer by reference. Additionally + * '***parentlink' is populated with the raxNode pointer-to-pointer of where + * the new child was stored, which is useful for the caller to replace the + * child pointer if it gets reallocated. 
+ * + * On success the new parent node pointer is returned (it may change because + * of the realloc, so the caller should discard 'n' and use the new value). + * On out of memory NULL is returned, and the old node is still valid. 
*/ +raxNode *raxAddChild(raxNode *n, int c, raxNode **childptr, raxNode ***parentlink) { + assert(n->iscompr == 0); + + size_t curlen = raxNodeCurrentLength(n); + n->size++; + size_t newlen = raxNodeCurrentLength(n); + n->size--; /* For now restore the original size. We'll update it only on + success at the end. */ + + /* Alloc the new child we will link to 'n'. */ + raxNode *child = raxNewNode(0,0); + if (child == NULL) return NULL; + + int parent_numnodes = n->numnodes; + /* Make space in the original node. */ + raxNode *newn = (raxNode *)rax_realloc(n,newlen); + if (newn == NULL) { + rax_free(child); + return NULL; + } + n = newn; + + /* After the reallocation, we have up to 8/16 (depending on the system + * pointer size, and the required node padding) bytes at the end, that is, + * the additional token in the 'data' section, plus one pointer to the new + * child, plus the padding needed in order to store addresses into aligned + * locations. + * + * So if we start with the following node, having [1,2,4,5] edges. + * + * Note: + * - We assume 4 bytes pointer for simplicity. + * - Each space below corresponds to four byte + * + * [HDR*][1,2][4,5][1ptr][2ptr][4ptr][5ptr]|AUXP| + * + * After the reallocation we need: 4 byte for the new edge token + * plus 4 bytes for a new child pointer (assuming 32 bit machine). + * However after adding 4 byte to the edge tokens, the header + the edge + * tokens are no longer aligned, so we also need 4 bytes of padding. + * In total the reallocation will add 4+4 bytes = 8 bytes: + * + * (Blank bytes are represented by ".") + * + * [HDR*][1,2][4,5][1ptr][2ptr][4ptr][5ptr]|AUXP|[....][....] + * + * Let's find where to insert the new child in order to make sure + * it is inserted in-place lexicographically. Assuming we are adding + * a child token 3 in our case pos will be = 2 after the end of the following + * loop. 
*/ + int pos; + for (pos = 0; pos < n->size; pos++) { + if (n->data[pos] > c) break; + } + + /* Now, if present, move auxiliary data pointer at the end + * so that we can mess with the other data without overwriting it. + * We will obtain something like that: + * + * [HDR*][1,2][4,5][1ptr][2ptr][4ptr][5ptr][....][....]|AUXP| + */ + char *src, *dst; + if (n->iskey && !n->isnull) { + src = ((char*)n+curlen-sizeof(void*)); + dst = ((char*)n+newlen-sizeof(void*)); + memmove(dst,src,sizeof(void*)); + } + + /* Compute the "shift", that is, how many bytes we need to move the + * pointers section forward because of the addition of the new child + * byte in the string section. Note that if we had no padding, that + * would be always "4", since we are adding a single token in the string + * section of the node (where now there is [1,2,4,5] basically). + * + * However we have padding, so it could be zero, or up to 8. + * + * Another way to think at the shift is, how many bytes we need to + * move child pointers forward *other than* the obvious sizeof(void*) + * needed for the additional pointer itself. */ + size_t shift = newlen - curlen - sizeof(void*); + + /* We said we are adding a node with edge token 3. The insertion + * point is between token 2 and token 4, so the 'pos' variable value is + * the index of the first child pointer that we need to move forward + * to make space for our new pointer. + * + * To start, move all the child pointers after the insertion point + * of shift+sizeof(pointer) bytes on the right, to obtain: + * + * [HDR*][1,2][4,5][1ptr][2ptr][....][....][4ptr][5ptr]|AUXP| + */ + src = (char *)n->data+(n->size)*sizeof(int)+ + raxPadding(n->size)+ + sizeof(raxNode*)*pos; + memmove(src+shift+sizeof(raxNode*),src,sizeof(raxNode*)*(n->size-pos)); + + /* Move the pointers to the left of the insertion position as well. Often + * we don't need to do anything if there was already some padding to use. 
In + * that case the final destination of the pointers will be the same, however + * in our example there was no pre-existing padding, so we added 4 byte + * plus 4 bytes of padding. After the next memmove() things will look + * like that: + * + * [HDR*][1,2][4,5][....][1ptr][2ptr][....][4ptr][5ptr]|AUXP| + */ + if (shift) { + src = (char*) raxNodeFirstChildPtr(n); + memmove(src+shift,src,sizeof(raxNode*)*pos); + } + + /* Now make the space for the additional token in the data section, + * but also move the pointers before the insertion point to the right + * by shift bytes, in order to obtain the following: + * + * [HDR*][1,2][.4][5.][1ptr][2ptr][....][4ptr][5ptr]|AUXP| + */ + src = (char*)n->data+pos*sizeof(int); + memmove(src+sizeof(int),src,(n->size-pos)*sizeof(int)); + + /* We can now set the token and its child node pointer to get: + * + * [HDR*][1,2][3,4][5.][1ptr][2ptr][....][4ptr][5ptr]|AUXP| + * [HDR*][1,2][3,4][5.][1ptr][2ptr][3ptr][4ptr][5ptr]|AUXP| + */ + n->data[pos] = c; + n->numnodes = parent_numnodes + 1; + n->size++; + src = (char*) raxNodeFirstChildPtr(n); + raxNode **childfield = (raxNode**)(src+sizeof(raxNode*)*pos); + memcpy(childfield,&child,sizeof(child)); + *childptr = child; + *parentlink = childfield; + return n; +} + +/* Turn the node 'n', that must be a node without any children, into a + * compressed node representing a set of nodes linked one after the other + * and having exactly one child each. The node can be a key or not: this + * property and the associated value if any will be preserved. + * + * The function also returns a child node, since the last node of the + * compressed chain cannot be part of the chain: it has zero children while + * we can only compress inner nodes with exactly one child each. */ +raxNode *raxCompressNode(raxNode *n, int *s, size_t len, raxNode **child) { + assert(n->size == 0 && n->iscompr == 0); + void *data = NULL; /* Initialized only to avoid warnings. 
*/ + size_t newsize; + + debugf("Compress node: %zu tokens\n", len); + + /* Allocate the child to link to this node. */ + *child = raxNewNode(0,0); + if (*child == NULL) return NULL; + + /* Make space in the parent node. */ + // update + newsize = sizeof(raxNode)+len*sizeof(int)+raxPadding(len)+sizeof(raxNode*); + if (n->iskey) { + data = raxGetData(n); /* To restore it later. */ + if (!n->isnull) newsize += sizeof(void*); + } + raxNode *newn = (raxNode *)rax_realloc(n,newsize); + if (newn == NULL) { + rax_free(*child); + return NULL; + } + n = newn; + + n->iscompr = 1; + n->numnodes++; + n->size = len; + memcpy(n->data,s,len * sizeof(int)); + if (n->iskey) raxSetData(n,data); + raxNode **childfield = raxNodeLastChildPtr(n); + memcpy(childfield,child,sizeof(*child)); + return n; +} + +/* Low level function that walks the tree looking for the token list + * s of 'len' tokens. The function returns the number of tokens + * of the key that was possible to process: if the returned integer + * is the same as 'len', then it means that the node corresponding to the + * token list was found (however it may not be a key in case the node->iskey is + * zero or if simply we stopped in the middle of a compressed node, so that + * 'splitpos' is non zero). + * + * Otherwise if the returned integer is not the same as 'len', there was an + * early stop during the tree walk because of a token mismatch. + * + * The node where the search ended (because the full token list was processed + * or because there was an early stop) is returned by reference as + * '*stopnode' if the passed pointer is not NULL. This node link in the + * parent's node is returned as '*plink' if not NULL. Finally, if the + * search stopped in a compressed node, '*splitpos' returns the index + * inside the compressed node where the search ended. This is useful to + * know where to split the node for insertion. 
+ * + * Note that when we stop in the middle of a compressed node with + * a perfect match, this function will return a length equal to the + * 'len' argument (all the key matched), and will return a *splitpos which is + * always positive (that will represent the index of the character immediately + * *after* the last match in the current compressed node). + * + * When instead we stop at a compressed node and *splitpos is zero, it + * means that the current node represents the key (that is, none of the + * compressed node tokens list are needed to represent the key, just all + * its parents nodes). */ +static inline size_t raxLowWalk(rax *rax, const int *s, size_t len, raxNode **stopnode, raxNode ***plink, int *splitpos, raxStack *ts) { + raxNode *h = rax->head; + raxNode **parentlink = &rax->head; + + size_t i = 0; /* Position in the string. */ + size_t j = 0; /* Position in the node children (or bytes if compressed).*/ + while(h->size && i < len) { + debugnode("Lookup current node",h); + int *v = h->data; + + if (h->iscompr) { + for (j = 0; j < h->size && i < len; j++, i++) { + if (v[j] != s[i]) { + break; + } + } + if (j != h->size) { + break; + } + } else { + /* Even when h->size is large, linear scan provides good + * performances compared to other approaches that are in theory + * more sounding, like performing a binary search. */ + for (j = 0; j < h->size; j++) { + if (v[j] == s[i]) { + break; + } + } + if (j == h->size) { + break; + } + i++; + } + + if (ts) raxStackPush(ts,h); /* Save stack of parent nodes. */ + raxNode **children = raxNodeFirstChildPtr(h); + if (h->iscompr) j = 0; /* Compressed node only child is at index 0. */ + memcpy(&h,children+j,sizeof(h)); + parentlink = children+j; + j = 0; /* If the new node is compressed and we do not + iterate again (since i == l) set the split + position to 0 to signal this node represents + the searched key. 
*/ + } + debugnode("Lookup stop node is",h); + if (stopnode) *stopnode = h; + if (plink) *plink = parentlink; + if (splitpos && h->iscompr) *splitpos = j; + return i; +} + +int handleOutOfMemory(rax *rax, raxNode *h, int *s, size_t len, void **old){ + /* This code path handles out of memory after part of the sub-tree was + * already modified. Set the node as a key, and then remove it. However we + * do that only if the node is a terminal node, otherwise if the OOM + * happened reallocating a node in the middle, we don't need to free + * anything. */ + if (h->size == 0) { + h->isnull = 1; + h->iskey = 1; + rax->numele++; /* Compensate the next remove. */ + assert(raxRemove(rax,s,len,NULL) != 0); + } + errno = ENOMEM; + return 0; +} + +/* Insert the token list 's' of size 'len', setting as auxiliary data + * the pointer 'data'. If the element is already present, the associated + * data is updated (only if 'overwrite' is set to 1), and 0 is returned, + * otherwise the element is inserted and 1 is returned. On out of memory the + * function returns 0 as well but sets errno to ENOMEM, otherwise errno will + * be set to 0. + */ +int raxGenericInsert(rax *rax, int *s, size_t len, void *data, void **old, int overwrite, void **dataNode) { + size_t i; + int j = 0; /* Split position. If raxLowWalk() stops in a compressed + node, the index 'j' represents the char we stopped within the + compressed node, that is, the position where to split the + node for insertion. */ + raxNode *h, **parentlink; + + raxStack lowWalkStack, splitStack; + raxStackInit(&lowWalkStack); + raxStackInit(&splitStack); + int all_added_node = 0; + + i = raxLowWalk(rax,s,len,&h,&parentlink,&j,&lowWalkStack); + debugf("######## after raxLowWalk##########"); + /* If i == len we walked following the whole string. If we are not + * in the middle of a compressed node, the string is either already + * inserted or this middle node is currently not a key, but can represent + * our key. 
We have just to reallocate the node and make space for the + * data pointer. */ + if (i == len && (!h->iscompr || j == 0 /* not in the middle if j is 0 */)) { + debugf("### Insert: node representing key exists\n"); + /* Make space for the value pointer if needed. */ + if (!h->iskey || (h->isnull && overwrite)) { + h = raxReallocForData(h,data); + if (h) memcpy(parentlink,&h,sizeof(h)); + } + if (h == NULL) { + errno = ENOMEM; + return 0; + } + + /* Update the existing key if there is already one. */ + if (h->iskey) { + if (old) *old = raxGetData(h); + if (overwrite) raxSetData(h,data); + *dataNode = h; + errno = 0; + return 0; /* Element already exists. */ + } + + /* Otherwise set the node as a key. Note that raxSetData() + * will set h->iskey. */ + raxSetData(h,data); + rax->numele++; + *dataNode = h; + return 1; /* Element inserted. */ + } + + /* If the node we stopped at is a compressed node, we need to + * split it before to continue. + * + * Splitting a compressed node have a few possible cases. + * Imagine that the node 'h' we are currently at is a compressed + * node contaning the token list [1,2,3,4,5,6,7] (it means that it represents + * nodes 1 -> 2 -> 3 -> 4 -> 5 -> 6 -> 7 with the only child + * pointer of this node pointing at the 7 node, because remember that + * we have tokens at the edges of the graph, not inside the nodes + * themselves. + * + * In order to show a real case imagine our node to also point to + * another compressed node, that finally points at the node without + * children, representing 'O': + * + * [1,2,3,4,5,6,7] -> [2,4,5] -> [] + * + * When inserting we may face the following cases. Note that all the cases + * require the insertion of a non compressed node with exactly two + * children, except for the last case which just requires splitting a + * compressed node. + * + * 1) Inserting [1,2,3,6,7,8] + * + * |4| -> [5,6,7] -> [2,4,5] -> [] + * [1,2,3] -> |-| + * |6| -> (... continue algo ...) 
[7,8] -> [] + * + * 2) Inserting [1,2,3,4,5,6,8] + * + * |7| -> [2,4,5] -> [] + * [1,2,3,4,5,6] -> |-| + * |8| -> (... continue algo ...) [] + * + * 3) Inserting [1,3,4] (Like case 1, but set iscompr = 0 into original node) + * + * |2| -> [3,4,5,6,7] -> [2,4,5] -> [] + * |1| -> |-| + * |3| -> (... continue algo ...) |4| -> [] + * + * 4) Inserting [2,3,4] + * + * |1| -> [2,3,4,5,6,7] -> [2,4,5] -> [] + * |-| + * |2| -> (... continue algo ...) [3,4] -> [] + * + * 5) Inserting [1,2,3] + * + * [1,2,3] -> [4,5,6,7] -> [2,4,5] -> [] + * + * The final algorithm for insertion covering all the above cases is as + * follows. + * + * ============================= ALGO 1 ============================= + * + * For the above cases 1 to 4, that is, all cases where we stopped in + * the middle of a compressed node for a character mismatch, do: + * + * Let $SPLITPOS be the zero-based index at which, in the + * compressed node array of characters, we found the mismatching + * character. For example if the node contains [1,2,3,4,5,6,7] and we add + * [1,2,3,6,7,8] the $SPLITPOS is 3, that is, the index at which the + * mismatching token is found. + * + * 1. Save the current compressed node $NEXT pointer (the pointer to the + * child element, that is always present in compressed nodes). + * + * 2. Create "split node" having as child the non common token + * at the compressed node. The other non common token (at the key) + * will be added later as we continue the normal insertion algorithm + * at step "6". + * + * 3a. IF $SPLITPOS == 0: + * Replace the old node with the split node, by copying the auxiliary + * data if any. Fix parent's reference. Free old node eventually + * (we still need its data for the next steps of the algorithm). + * + * 3b. IF $SPLITPOS != 0: + * Trim the compressed node (reallocating it as well) in order to + * contain $splitpos tokens. Change chilid pointer in order to link + * to the split node. 
If new compressed node len is just 1, set + * iscompr to 0 (layout is the same). Fix parent's reference. + * + * 4a. IF the postfix len (the length of the remaining token list of the + * original compressed node after the split token) is non zero, + * create a "postfix node". If the postfix node has just one token + * set iscompr to 0, otherwise iscompr to 1. Set the postfix node + * child pointer to $NEXT. + * + * 4b. IF the postfix len is zero, just use $NEXT as postfix pointer. + * + * 5. Set child[0] of split node to postfix node. + * + * 6. Set the split node as the current node, set current index at child[1] + * and continue insertion algorithm as usually. + * + * ============================= ALGO 2 ============================= + * + * For case 5, that is, if we stopped in the middle of a compressed + * node but no mismatch was found, do: + * + * Let $SPLITPOS be the zero-based index at which, in the + * compressed node array of token list, we stopped iterating because + * there were no more keys character to match. So in the example of + * the node [1,2,3,4,5,6,7], addig the token list [1,2,3,4], the $SPLITPOS is 4. + * + * 1. Save the current compressed node $NEXT pointer (the pointer to the + * child element, that is always present in compressed nodes). + * + * 2. Create a "postfix node" containing all the tokens from $SPLITPOS + * to the end. Use $NEXT as the postfix node child pointer. + * If the postfix node length is 1, set iscompr to 0. + * Set the node as a key with the associated value of the new + * inserted key. + * + * 3. Trim the current node to contain the first $SPLITPOS tokens. + * As usually if the new node length is just 1, set iscompr to 0. + * Take the iskey / associated value as it was in the orignal node. + * Fix the parent's reference. + * + * 4. Set the postfix node as the only child pointer of the trimmed + * node created at step 1. 
+ */ + + /* ------------------------- ALGORITHM 1 --------------------------- */ + if (h->iscompr && i != len) { + debugf("ALGO 1: Stopped at compressed node (%p)\n",(void*)h); + debugf("Still to insert: "); + debugf("Splitting at %d: '%d'\n", j, ((int*)h->data)[j]); + debugf("Other key is: %d ", s[i]); + + /* 1: Save next pointer. */ + raxNode **childfield = raxNodeLastChildPtr(h); + raxNode *next; + memcpy(&next,childfield,sizeof(next)); + debugf("Next is %p\n", (void*)next); + debugf("iskey %d\n", h->iskey); + if (h->iskey) { + debugf("key value is %p\n", raxGetData(h)); + } + debugf("get next data %p\n", raxGetData(next)); + + /* Set the length of the additional nodes we will need. */ + size_t trimmedlen = j; + size_t postfixlen = h->size - j - 1; + int split_node_is_key = !trimmedlen && h->iskey && !h->isnull; + size_t nodesize; + + /* 2: Create the split node. Also allocate the other nodes we'll need + * ASAP, so that it will be simpler to handle OOM. */ + raxNode *splitnode = raxNewNode(1, split_node_is_key); + raxNode *trimmed = NULL; + raxNode *postfix = NULL; + + if (trimmedlen) { + nodesize = sizeof(raxNode)+trimmedlen * sizeof(int)+raxPadding(trimmedlen)+ + sizeof(raxNode*); + if (h->iskey && !h->isnull) nodesize += sizeof(void*); + trimmed = (raxNode *)rax_malloc(nodesize); + } + + if (postfixlen) { + nodesize = sizeof(raxNode)+postfixlen * sizeof(int)+raxPadding(postfixlen)+ + sizeof(raxNode*); + postfix = (raxNode *)rax_malloc(nodesize); + } + + /* OOM? Abort now that the tree is untouched. */ + if (splitnode == NULL || + (trimmedlen && trimmed == NULL) || + (postfixlen && postfix == NULL)) + { + rax_free(splitnode); + rax_free(trimmed); + rax_free(postfix); + errno = ENOMEM; + return 0; + } + splitnode->data[0] = h->data[j]; + splitnode->numnodes = h->numnodes; + debugf("split node is %p, data is %d\n", &splitnode, splitnode->data[0]); + + if (j == 0) { + /* 3a: Replace the old node with the split node. 
*/ + if (h->iskey) { + void *ndata = raxGetData(h); + raxSetData(splitnode,ndata); + } + memcpy(parentlink,&splitnode,sizeof(splitnode)); + } else { + /* 3b: Trim the compressed node. */ + trimmed->size = j; + memcpy(trimmed->data,h->data,j*sizeof(int)); + trimmed->iscompr = j > 1 ? 1 : 0; + trimmed->numnodes = h->numnodes; + trimmed->iskey = h->iskey; + trimmed->isnull = h->isnull; + if (h->iskey && !h->isnull) { + void *ndata = raxGetData(h); + raxSetData(trimmed,ndata); + } + raxNode **cp = raxNodeLastChildPtr(trimmed); + memcpy(cp,&splitnode,sizeof(splitnode)); + memcpy(parentlink,&trimmed,sizeof(trimmed)); + parentlink = cp; /* Set parentlink to splitnode parent. */ + all_added_node++; + rax->numnodes++; + } + + /* 4: Create the postfix node: what remains of the original + * compressed node after the split. */ + if (postfixlen) { + /* 4a: create a postfix node. */ + postfix->iskey = 0; + postfix->isnull = 0; + postfix->numnodes = h->numnodes; + postfix->size = postfixlen; + postfix->iscompr = postfixlen > 1; + memcpy(postfix->data,h->data+j+1,postfixlen*sizeof(int)); + raxNode **cp = raxNodeLastChildPtr(postfix); + memcpy(cp,&next,sizeof(next)); + all_added_node++; + rax->numnodes++; + } else { + /* 4b: just use next as postfix node. */ + postfix = next; + postfix->numnodes = next->numnodes; + } + + /* 5: Set splitnode first child as the postfix node. */ + raxNode **splitchild = raxNodeLastChildPtr(splitnode); + memcpy(splitchild,&postfix,sizeof(postfix)); + + /* 6. Continue insertion: this will cause the splitnode to + * get a new child (the non common token at the currently + * inserted token list). 
*/ + rax_free(h); + h = splitnode; + if (trimmedlen != 0 && postfixlen == 0) { + trimmed->numnodes++; + raxStackPush(&splitStack, trimmed); + } else if (trimmedlen == 0 && postfixlen != 0) { + splitnode->numnodes++; + raxStackPush(&splitStack, splitnode); + } else { + trimmed->numnodes+=2; + splitnode->numnodes++; + raxStackPush(&splitStack, trimmed); + raxStackPush(&splitStack, splitnode); + } + } else if (h->iscompr && i == len) { + /* ------------------------- ALGORITHM 2 --------------------------- */ + debugf("ALGO 2: Stopped at compressed node %d (%p) j = %d\n", + ((int*)h->data)[j], (void*)h, j); + + /* Allocate postfix & trimmed nodes ASAP to fail for OOM gracefully. */ + size_t postfixlen = h->size - j; + size_t nodesize = sizeof(raxNode)+postfixlen*sizeof(int)+raxPadding(postfixlen)+ + sizeof(raxNode*); + if (data != NULL) nodesize += sizeof(void*); + raxNode *postfix = (raxNode *)rax_malloc(nodesize); + + nodesize = sizeof(raxNode)+j*sizeof(int)+raxPadding(j)+sizeof(raxNode*); + if (h->iskey && !h->isnull) nodesize += sizeof(void*); + raxNode *trimmed = (raxNode *)rax_malloc(nodesize); + + if (postfix == NULL || trimmed == NULL) { + rax_free(postfix); + rax_free(trimmed); + errno = ENOMEM; + return 0; + } + + /* 1: Save next pointer. */ + raxNode **childfield = raxNodeLastChildPtr(h); + raxNode *next; + memcpy(&next,childfield,sizeof(next)); + + /* 2: Create the postfix node. */ + postfix->size = postfixlen; + postfix->iscompr = postfixlen > 1; + postfix->numnodes = h->numnodes; + postfix->iskey = 1; + postfix->isnull = 0; + memcpy(postfix->data,h->data+j,postfixlen*sizeof(int)); + raxSetData(postfix,data); + *dataNode = postfix; + raxNode **cp = raxNodeLastChildPtr(postfix); + memcpy(cp,&next,sizeof(next)); + rax->numnodes++; + all_added_node++; + + /* 3: Trim the compressed node. 
*/ + trimmed->size = j; + trimmed->iscompr = j > 1; + trimmed->numnodes = h->numnodes+1; + trimmed->iskey = 0; + trimmed->isnull = 0; + memcpy(trimmed->data,h->data,j*sizeof(int)); + memcpy(parentlink,&trimmed,sizeof(trimmed)); + if (h->iskey) { + void *aux = raxGetData(h); + raxSetData(trimmed,aux); + } + + /* Fix the trimmed node child pointer to point to + * the postfix node. */ + cp = raxNodeLastChildPtr(trimmed); + memcpy(cp,&postfix,sizeof(postfix)); + + raxStackAddNumNodes(&lowWalkStack, all_added_node); + raxStackFree(&lowWalkStack); + /* Finish! We don't need to continue with the insertion + * algorithm for ALGO 2. The key is already inserted. */ + rax->numele++; + rax_free(h); + return 1; /* Key inserted. */ + } + + raxNode *prev_node = NULL; + int insert_new_node = 0; + /* We walked the radix tree as far as we could, but still there are left + * tokens in our string. We need to insert the missing nodes. */ + while(i < len) { + raxNode *child; + // error + + /* If this node is going to have a single child, and there + * are other tokens, so that that would result in a chain + * of single-childed nodes, turn it into a compressed node. 
*/ + if (h->size == 0 && len-i > 1) { + debugf("Inserting compressed node: [%d, \n", s[i]); + size_t comprsize = len-i; + if (comprsize > RAX_NODE_MAX_SIZE) + comprsize = RAX_NODE_MAX_SIZE; + raxNode *newh = raxCompressNode(h,s+i,comprsize,&child); + if (newh == NULL) { + return handleOutOfMemory(rax, h, (int *)s, i, old); + } + h = newh; + memcpy(parentlink,&h,sizeof(h)); + parentlink = raxNodeLastChildPtr(h); + i += comprsize; + } else { + debugf("Inserting normal node %d\n", s[i]); + raxNode **new_parentlink; + raxNode *newh = raxAddChild(h,s[i],&child,&new_parentlink); + if (newh == NULL) { + return handleOutOfMemory(rax, h, (int *)s, i, old); + } + h = newh; + memcpy(parentlink,&h,sizeof(h)); + parentlink = new_parentlink; + i++; + } + if (prev_node != NULL) { + prev_node->numnodes++; + } else { + prev_node = h; + } + all_added_node++; + insert_new_node++; + rax->numnodes++; + h = child; + } + raxStackAddNumNodes(&lowWalkStack, all_added_node); + raxStackAddNumNodes(&splitStack, insert_new_node); + raxStackFree(&lowWalkStack); + raxStackFree(&splitStack); + raxNode *newh = raxReallocForData(h,data); + if (newh == NULL) { + return handleOutOfMemory(rax, h, (int *)s, i, old); + } + h = newh; + if (!h->iskey) rax->numele++; + raxSetData(h,data); + memcpy(parentlink,&h,sizeof(h)); + *dataNode = h; + return 1; /* Element inserted. */ +} + +/* Overwriting insert. Just a wrapper for raxGenericInsert() that will + * update the element if there is already one for the same key. */ +int raxInsert(rax *rax, int *s, size_t len, void *data, void **old) { + void *dataNode = NULL; + return raxGenericInsert(rax,s,len,data,old,1,&dataNode); +} + +/* Non overwriting insert function: this if an element with the same key + * exists, the value is not updated and the function returns 0. + * This is a just a wrapper for raxGenericInsert(). 
*/ +int raxTryInsert(rax *rax, int *s, size_t len, void *data, void **old) { + void *dataNode = NULL; + return raxGenericInsert(rax,s,len,data,old,0,&dataNode); +} + +/* +Overwriting insert. Return the raxNode that contains the key. +*/ +int raxInsertAndReturnDataNode(rax *rax, int *s, size_t len, void *data, void **node, void **old) { + return raxGenericInsert(rax,s,len,data,old,1, node); +} + +/* Find a key in the rax, returns raxNotFound special void pointer value + * if the item was not found, otherwise the value associated with the + * item is returned. */ +void *raxFind(rax *rax, int *s, size_t len) { + raxNode *h; + + //debugf("### Lookup: %d\n", (int)len, s); + int splitpos = 0; + size_t i = raxLowWalk(rax,s,len,&h,NULL,&splitpos,NULL); + if (i != len || (h->iscompr && splitpos != 0) || !h->iskey) + return raxNotFound; + return raxGetData(h); +} + +/* Find a key in the rax, returns the stack +*/ +raxStack raxFindWithStack(rax *rax, int *s, size_t len) { + raxNode *h; + + raxStack ts; + raxStackInit(&ts); + //debugf("### Lookup: %.*s\n", (int)len, s); + int splitpos = 0; + size_t i = raxLowWalk(rax,s,len,&h,NULL,&splitpos,&ts); + if (i != len || (h->iscompr && splitpos != 0) || !h->iskey) + return ts; + return ts; +} + +/* +Find a key in the rax, returns the raxNode that contains the key. +*/ +raxNode *raxFindAndReturnDataNode(rax *rax, int *s, size_t len) { + raxNode *h; + + //debugf("### Lookup: %.*s\n", (int)len, s); + int splitpos = 0; + size_t i = raxLowWalk(rax,s,len,&h,NULL,&splitpos,NULL); + if (i != len || (h->iscompr && splitpos != 0) || !h->iskey) + return NULL; + return h; +} + +/* Return the memory address where the 'parent' node stores the specified + * 'child' pointer, so that the caller can update the pointer with another + * one if needed. The function assumes it will find a match, otherwise the + * operation is an undefined behavior (it will continue scanning the + * memory without any bound checking). 
*/ +raxNode **raxFindParentLink(raxNode *parent, raxNode *child) { + raxNode **cp = raxNodeFirstChildPtr(parent); + raxNode *c; + while(1) { + memcpy(&c,cp,sizeof(c)); + if (c == child) break; + cp++; + } + return cp; +} + +/* Low level child removal from node. The new node pointer (after the child + * removal) is returned. Note that this function does not fix the pointer + * of the parent node in its parent, so this task is up to the caller. + * The function never fails for out of memory. */ +raxNode *raxRemoveChild(raxNode *parent, raxNode *child) { + debugnode("raxRemoveChild before", parent); + /* If parent is a compressed node (having a single child, as for definition + * of the data structure), the removal of the child consists into turning + * it into a normal node without children. */ + if (parent->iscompr) { + void *data = NULL; + if (parent->iskey) data = raxGetData(parent); + parent->isnull = 0; + parent->iscompr = 0; + parent->size = 0; + if (parent->iskey) raxSetData(parent,data); + debugnode("raxRemoveChild after", parent); + return parent; + } + + /* Otherwise we need to scan for the child pointer and memmove() + * accordingly. + * + * 1. To start we seek the first element in both the children + * pointers and edge bytes in the node. */ + raxNode **cp = raxNodeFirstChildPtr(parent); + raxNode **c = cp; + int *e = parent->data; + + /* 2. Search the child pointer to remove inside the array of children + * pointers. */ + while(1) { + raxNode *aux; + memcpy(&aux,c,sizeof(aux)); + if (aux == child) break; + c++; + e++; + } + + /* 3. Remove the edge and the pointer by memmoving the remaining children + * pointer and edge bytes one position before. */ + int taillen = parent->size - (e - parent->data) - 1; + memmove(e,(int*)e+1,taillen*sizeof(int)); + + /* Compute the shift, that is the amount of bytes we should move our + * child pointers to the left, since the removal of one edge token + * and the corresponding padding change, may change the layout. 
+ * We just check if in the old version of the node there was at the + * end just a single token and all padding: in that case removing one token + * will remove a whole sizeof(void*) word. */ + size_t shift = ((parent->size * sizeof(int)+4) % sizeof(void*)) == 4 ? sizeof(void *) : 0; + + /* Move the children pointers before the deletion point. */ + if (shift) + memmove(((char*)cp)-shift,cp,(parent->size-taillen-1)*sizeof(raxNode**)); + + /* Move the remaining "tail" pointers at the right position as well. */ + size_t valuelen = (parent->iskey && !parent->isnull) ? sizeof(void*) : 0; + memmove(((char*)c)-shift,c+1,taillen*sizeof(raxNode**)+valuelen); + + /* 4. Update size. */ + parent->size--; + + /* realloc the node according to the theoretical memory usage, to free + * data if we are over-allocating right now. */ + raxNode *newnode = (raxNode *)rax_realloc(parent,raxNodeCurrentLength(parent)); + if (newnode) { + debugnode("raxRemoveChild after", newnode); + } + /* Note: if rax_realloc() fails we just return the old address, which + * is valid. */ + return newnode ? newnode : parent; +} + +/* Remove the specified item. Returns 1 if the item was found and + * deleted, 0 otherwise. */ +int raxRemove(rax *rax, int *s, size_t len, void **old) { + raxNode *h; + raxStack ts; + + debugf("### Delete: "); + raxStackInit(&ts); + int splitpos = 0; + int all_added_node = 0; + size_t i = raxLowWalk(rax,s,len,&h,NULL,&splitpos,&ts); + if (i != len || (h->iscompr && splitpos != 0) || !h->iskey) { + raxStackFree(&ts); + return 0; + } + if (old) *old = raxGetData(h); + h->iskey = 0; + rax->numele--; + + /* If this node has no children, the deletion needs to reclaim the + * no longer used nodes. This is an iterative process that needs to + * walk the three upward, deleting all the nodes with just one child + * that are not keys, until the head of the rax is reached or the first + * node with more than one child is found. 
*/ + + int trycompress = 0; /* Will be set to 1 if we should try to optimize the + tree resulting from the deletion. */ + + if (h->size == 0) { + debugf("Key deleted in node without children. Cleanup needed.\n"); + raxNode *child = NULL; + int added_nodes = 0; + while(h != rax->head) { + child = h; + debugf("Freeing child %p, ", (void*)child); + debugf(" key:%d\n", child->iskey) + rax_free(child); + rax->numnodes--; + added_nodes--; + h = (raxNode *)raxStackPop(&ts); + h->numnodes += added_nodes; + /* If this node has more then one child, or actually holds + * a key, stop here. */ + if (h->iskey || (!h->iscompr && h->size != 1)) break; + } + raxStackAddNumNodes(&ts, added_nodes); + if (child) { + debugf("Unlinking child %p from parent %p\n", + (void*)child, (void*)h); + raxNode *newNode = (raxNode *)raxRemoveChild(h,child); + if (newNode != h) { + raxNode *parent = (raxNode *)raxStackPeek(&ts); + raxNode **parentlink; + if (parent == NULL) { + parentlink = &rax->head; + } else { + parentlink = raxFindParentLink(parent,h); + } + memcpy(parentlink,&newNode,sizeof(newNode)); + } + + /* If after the removal the node has just a single child + * and is not a key, we need to try to compress it. */ + if (newNode->size == 1 && newNode->iskey == 0) { + trycompress = 1; + h = newNode; + } + } + } else if (h->size == 1) { + /* If the node had just one child, after the removal of the key + * further compression with adjacent nodes is pontentially possible. */ + trycompress = 1; + } + + /* Don't try node compression if our nodes pointers stack is not + * complete because of OOM while executing raxLowWalk() */ + if (trycompress && ts.oom) trycompress = 0; + + /* Recompression: if trycompress is true, 'h' points to a radix tree node + * that changed in a way that could allow to compress nodes in this + * sub-branch. 
Compressed nodes represent chains of nodes that are not + * keys and have a single child, so there are two deletion events that + * may alter the tree so that further compression is needed: + * + * 1) A node with a single child was a key and now no longer is a key. + * 2) A node with two children now has just one child. + * + * We try to navigate upward till there are other nodes that can be + * compressed, when we reach the upper node which is not a key and has + * a single child, we scan the chain of children to collect the + * compressable part of the tree, and replace the current node with the + * new one, fixing the child pointer to reference the first non + * compressable node. + * + * Example of case "1". A tree stores the keys [1,2,3] = 1 and + * [1,2,3,4,5] = 2: + * + * + * [1,2,3] -> [4,5] -> [] (2) + * (1) + * + * After the removal of [1,2,3] the tree can be compressed as: + * + * [1,2,3,4,5] -> [] (2) + * + * + * Example of case "2". A tree stores the keys [1,2,3,4,5] = 1 and + * [1,2,3,5,6] = 2: + * + * |4| -> [5] -> [] (1) + * [1,2,3] -> |-| + * |5| -> [6] -> [] (2) + * + * After the removal of [1,2,3,5,6] the resulting tree is: + * + * [1,2,3] -> |4| -> [5] -> [] (1) + * + * That can be compressed into: + * + * [1,2,3,5,6] -> [] (1) + */ + if (trycompress) { + debugf("After removing %d len: %d:\n", s[0], (int)len); + debugnode("Compression may be needed",h); + debugf("Seek start node\n"); + + /* Try to reach the upper node that is compressible. + * At the end of the loop 'h' will point to the first node we + * can try to compress and 'parent' to its parent. */ + raxNode *parent; + while(1) { + parent = (raxNode *)raxStackPop(&ts); + if (!parent || parent->iskey || + (!parent->iscompr && parent->size != 1)) break; + h = parent; + debugnode("Going up to",h); + } + raxNode *start = h; /* Compression starting node. */ + int start_num_nodes = start->numnodes; + + /* Scan chain of nodes we can compress. 
*/ + size_t comprsize = h->size; + int nodes = 1; + while(h->size != 0) { + raxNode **cp = raxNodeLastChildPtr(h); + memcpy(&h,cp,sizeof(h)); + if (h->iskey || (!h->iscompr && h->size != 1)) break; + /* Stop here if going to the next node would result into + * a compressed node larger than h->size can hold. */ + if (comprsize + h->size > RAX_NODE_MAX_SIZE) break; + nodes++; + comprsize += h->size; + } + if (nodes > 1) { + /* If we can compress, create the new node and populate it. */ + size_t nodesize = + sizeof(raxNode)+comprsize*sizeof(int)+raxPadding(comprsize)+sizeof(raxNode*); + raxNode *newNode = (raxNode *)rax_malloc(nodesize); + /* An out of memory here just means we cannot optimize this + * node, but the tree is left in a consistent state. */ + if (newNode == NULL) { + raxStackFree(&ts); + return 1; + } + newNode->iskey = 0; + newNode->isnull = 0; + newNode->iscompr = 1; + newNode->size = comprsize; + newNode->numnodes = h->numnodes+1; + all_added_node++; + rax->numnodes++; + + /* Scan again, this time to populate the new node content and + * to fix the new node child pointer. At the same time we free + * all the nodes that we'll no longer use. */ + comprsize = 0; + h = start; + while(h->size != 0) { + memcpy((int*)newNode->data+comprsize,h->data,h->size*sizeof(int)); + comprsize += h->size; + raxNode **cp = raxNodeLastChildPtr(h); + raxNode *tofree = h; + memcpy(&h,cp,sizeof(h)); + rax_free(tofree); rax->numnodes--; + all_added_node--; + if (h->iskey || (!h->iscompr && h->size != 1)) break; + } + newNode->numnodes = start_num_nodes+all_added_node; + debugnode("New node",new); + + /* Now 'h' points to the first node that we still need to use, + * so our new node child pointer will point to it. */ + raxNode **cp = (raxNode **)raxNodeLastChildPtr(newNode); + memcpy(cp,&h,sizeof(h)); + + /* Fix parent link. 
*/ + if (parent) { + parent->numnodes+=all_added_node; + raxNode **parentlink = raxFindParentLink(parent,start); + memcpy(parentlink,&newNode,sizeof(newNode)); + } else { + rax->head = newNode; + } + + debugf("Compressed %d nodes, %d total bytes\n", + nodes, (int)comprsize); + } + } + raxStackAddNumNodes(&ts, all_added_node); + raxStackFree(&ts); + return 1; +} + +/* This is the core of raxFree(): performs a depth-first scan of the + * tree and releases all the nodes found. */ +void raxRecursiveFree(rax *rax, raxNode *n, void (*free_callback)(void*)) { + debugnode("free traversing",n); + int numchildren = n->iscompr ? 1 : n->size; + raxNode **cp = raxNodeLastChildPtr(n); + while(numchildren--) { + raxNode *child; + memcpy(&child,cp,sizeof(child)); + raxRecursiveFree(rax,child,free_callback); + cp--; + } + debugnode("free depth-first",n); + if (free_callback && n->iskey && !n->isnull) + free_callback(raxGetData(n)); + rax_free(n); + rax->numnodes--; +} + +/* Free a whole radix tree, calling the specified callback in order to + * free the auxiliary data. */ +void raxFreeWithCallback(rax *rax, void (*free_callback)(void*)) { + raxRecursiveFree(rax,rax->head,free_callback); + assert(rax->numnodes == 0); + rax_free(rax); +} + +/* Free a whole radix tree. */ +void raxFree(rax *rax) { + raxFreeWithCallback(rax,NULL); +} + +/* ------------------------------- Iterator --------------------------------- */ + +/* Initialize a Rax iterator. This call should be performed a single time + * to initialize the iterator, and must be followed by a raxSeek() call, + * otherwise the raxPrev()/raxNext() functions will just return EOF. */ +void raxStart(raxIterator *it, rax *rt) { + it->flags = RAX_ITER_EOF; /* No crash if the iterator is not seeked. 
*/ + it->rt = rt; + it->key_len = 0; + it->key = it->key_static_tokens; + it->key_max = RAX_ITER_STATIC_LEN; + it->data = NULL; + it->node_cb = NULL; + raxStackInit(&it->stack); +} + +/* Append token at the current key string of the iterator 'it'. This + * is a low level function used to implement the iterator, not callable by + * the user. Returns 0 on out of memory, otherwise 1 is returned. */ +int raxIteratorAddToken(raxIterator *it, int *s, size_t len) { + if (it->key_max < it->key_len+len) { + int *old = (it->key == it->key_static_tokens) ? NULL : + it->key; + size_t new_max = (it->key_len+len)*2; + // update + it->key = (int *)rax_realloc(old,new_max * sizeof(int)); + if (it->key == NULL) { + it->key = (!old) ? it->key_static_tokens : old; + errno = ENOMEM; + return 0; + } + // update + if (old == NULL) memcpy(it->key,it->key_static_tokens,it->key_len * sizeof(int)); + it->key_max = new_max; + } + /* Use memmove since there could be an overlap between 's' and + * it->key when we use the current key in order to re-seek. */ + // update + memmove(it->key+it->key_len,s,len * sizeof(int)); + it->key_len += len; + return 1; +} + +/* Remove the specified number of chars from the right of the current + * iterator key. */ +void raxIteratorDelChars(raxIterator *it, size_t count) { + it->key_len -= count; +} + +/* Do an iteration step towards the next element. At the end of the step the + * iterator key will represent the (new) current key. If it is not possible + * to step in the specified direction since there are no longer elements, the + * iterator is flagged with RAX_ITER_EOF. + * + * If 'noup' is true the function starts directly scanning for the next + * lexicographically smaller children, and the current node is already assumed + * to be the parent of the last key node, so the first operation to go back to + * the parent will be skipped. 
This option is used by raxSeek() when
 * implementing seeking a non existing element with the ">" or "<" options:
 * the starting node is not a key in that particular case, so we start the scan
 * from a node that does not represent the key set.
 *
 * The function returns 1 on success or 0 on out of memory. */
int raxIteratorNextStep(raxIterator *it, int noup) {
    if (it->flags & RAX_ITER_EOF) {
        return 1;
    } else if (it->flags & RAX_ITER_JUST_SEEKED) {
        /* The element the seek landed on is itself the first result. */
        it->flags &= ~RAX_ITER_JUST_SEEKED;
        return 1;
    }

    /* Save key len, stack items and the node where we are currently
     * so that on iterator EOF we can restore the current key and state. */
    size_t orig_key_len = it->key_len;
    size_t orig_stack_items = it->stack.items;
    raxNode *orig_node = it->node;

    while(1) {
        int children = it->node->iscompr ? 1 : it->node->size;
        if (!noup && children) {
            debugf("GO DEEPER\n");
            /* Seek the lexicographically smaller key in this subtree, which
             * is the first one found always going towards the first child
             * of every successive node. */
            if (!raxStackPush(&it->stack,it->node)) return 0;
            raxNode **cp = raxNodeFirstChildPtr(it->node);
            if (!raxIteratorAddToken(it,it->node->data,
                it->node->iscompr ? it->node->size : 1)) return 0;
            memcpy(&it->node,cp,sizeof(it->node));
            /* Call the node callback if any, and replace the node pointer
             * if the callback returns true. */
            if (it->node_cb && it->node_cb(&it->node))
                memcpy(cp,&it->node,sizeof(it->node));
            /* For "next" step, stop every time we find a key along the
             * way, since the key is lexicographically smaller compared to
             * what follows in the sub-children. */
            if (it->node->iskey) {
                it->data = raxGetData(it->node);
                return 1;
            }
        } else {
            /* If we finished exploring the previous sub-tree, switch to the
             * new one: go upper until a node is found where there are
             * children representing keys lexicographically greater than the
             * current key. */
            while(1) {
                int old_noup = noup;

                /* Already on head? Can't go up, iteration finished. */
                if (!noup && it->node == it->rt->head) {
                    it->flags |= RAX_ITER_EOF;
                    it->stack.items = orig_stack_items;
                    it->key_len = orig_key_len;
                    it->node = orig_node;
                    return 1;
                }
                /* If there are no children at the current node, try parent's
                 * next child. */
                int prevchild = it->key[it->key_len-1];
                if (!noup) {
                    it->node = (raxNode *)raxStackPop(&it->stack);
                } else {
                    noup = 0;
                }
                /* Adjust the current key to represent the node we are
                 * at. */
                int todel = it->node->iscompr ? it->node->size : 1;
                raxIteratorDelChars(it,todel);

                /* Try visiting the next child if there was at least one
                 * additional child. */
                if (!it->node->iscompr && it->node->size > (old_noup ? 0 : 1)) {
                    raxNode **cp = raxNodeFirstChildPtr(it->node);
                    int i = 0;
                    while (i < it->node->size) {
                        debugf("SCAN NEXT %c\n", it->node->data[i]);
                        if (it->node->data[i] > prevchild) break;
                        i++;
                        cp++;
                    }
                    if (i != it->node->size) {
                        debugf("SCAN found a new node\n");
                        /* Report OOM like the descent branch above does:
                         * otherwise the node would be pushed on the stack
                         * while the key is missing its token. */
                        if (!raxIteratorAddToken(it,it->node->data+i,1)) return 0;
                        if (!raxStackPush(&it->stack,it->node)) return 0;
                        memcpy(&it->node,cp,sizeof(it->node));
                        /* Call the node callback if any, and replace the node
                         * pointer if the callback returns true. */
                        if (it->node_cb && it->node_cb(&it->node))
                            memcpy(cp,&it->node,sizeof(it->node));
                        if (it->node->iskey) {
                            it->data = raxGetData(it->node);
                            return 1;
                        }
                        break;
                    }
                }
            }
        }
    }
}

/* Seek the greatest key in the subtree at the current node. Return 0 on
 * out of memory, otherwise 1. This is a helper function for different
 * iteration functions below.
*/ +int raxSeekGreatest(raxIterator *it) { + while(it->node->size) { + if (it->node->iscompr) { + if (!raxIteratorAddToken(it,it->node->data, + it->node->size)) return 0; + } else { + if (!raxIteratorAddToken(it,it->node->data+it->node->size-1,1)) + return 0; + } + raxNode **cp = raxNodeLastChildPtr(it->node); + if (!raxStackPush(&it->stack,it->node)) return 0; + memcpy(&it->node,cp,sizeof(it->node)); + } + return 1; +} + +/* Like raxIteratorNextStep() but implements an iteration step moving + * to the lexicographically previous element. The 'noup' option has a similar + * effect to the one of raxIteratorNextStep(). */ +int raxIteratorPrevStep(raxIterator *it, int noup) { + if (it->flags & RAX_ITER_EOF) { + return 1; + } else if (it->flags & RAX_ITER_JUST_SEEKED) { + it->flags &= ~RAX_ITER_JUST_SEEKED; + return 1; + } + + /* Save key len, stack items and the node where we are currently + * so that on iterator EOF we can restore the current key and state. */ + size_t orig_key_len = it->key_len; + size_t orig_stack_items = it->stack.items; + raxNode *orig_node = it->node; + + while(1) { + int old_noup = noup; + + /* Already on head? Can't go up, iteration finished. */ + if (!noup && it->node == it->rt->head) { + it->flags |= RAX_ITER_EOF; + it->stack.items = orig_stack_items; + it->key_len = orig_key_len; + it->node = orig_node; + return 1; + } + + unsigned char prevchild = it->key[it->key_len-1]; + if (!noup) { + it->node = (raxNode *)raxStackPop(&it->stack); + } else { + noup = 0; + } + + /* Adjust the current key to represent the node we are + * at. */ + int todel = it->node->iscompr ? it->node->size : 1; + raxIteratorDelChars(it,todel); + + /* Try visiting the prev child if there is at least one + * child. */ + if (!it->node->iscompr && it->node->size > (old_noup ? 
0 : 1)) { + raxNode **cp = raxNodeLastChildPtr(it->node); + int i = it->node->size-1; + while (i >= 0) { + debugf("SCAN PREV %c\n", it->node->data[i]); + if (it->node->data[i] < prevchild) break; + i--; + cp--; + } + /* If we found a new subtree to explore in this node, + * go deeper following all the last children in order to + * find the key lexicographically greater. */ + if (i != -1) { + debugf("SCAN found a new node\n"); + /* Enter the node we just found. */ + if (!raxIteratorAddToken(it,it->node->data+i,1)) return 0; + if (!raxStackPush(&it->stack,it->node)) return 0; + memcpy(&it->node,cp,sizeof(it->node)); + /* Seek sub-tree max. */ + if (!raxSeekGreatest(it)) return 0; + } + } + + /* Return the key: this could be the key we found scanning a new + * subtree, or if we did not find a new subtree to explore here, + * before giving up with this node, check if it's a key itself. */ + if (it->node->iskey) { + it->data = raxGetData(it->node); + return 1; + } + } +} + +/* Seek an iterator at the specified element. + * Return 0 if the seek failed for syntax error or out of memory. Otherwise + * 1 is returned. When 0 is returned for out of memory, errno is set to + * the ENOMEM value. */ +int raxSeek(raxIterator *it, const char *op, int *ele, size_t len) { + int eq = 0, lt = 0, gt = 0, first = 0, last = 0; + + it->stack.items = 0; /* Just resetting. Intialized by raxStart(). */ + it->flags |= RAX_ITER_JUST_SEEKED; + it->flags &= ~RAX_ITER_EOF; + it->key_len = 0; + it->node = NULL; + + /* Set flags according to the operator used to perform the seek. */ + if (op[0] == '>') { + gt = 1; + if (op[1] == '=') eq = 1; + } else if (op[0] == '<') { + lt = 1; + if (op[1] == '=') eq = 1; + } else if (op[0] == '=') { + eq = 1; + } else if (op[0] == '^') { + first = 1; + } else if (op[0] == '$') { + last = 1; + } else { + errno = 0; + return 0; /* Error. */ + } + + /* If there are no elements, set the EOF condition immediately and + * return. 
*/ + if (it->rt->numele == 0) { + it->flags |= RAX_ITER_EOF; + return 1; + } + + if (first) { + /* Seeking the first key greater or equal to the empty string + * is equivalent to seeking the smaller key available. */ + return raxSeek(it,">=",NULL,0); + } + + if (last) { + /* Find the greatest key taking always the last child till a + * final node is found. */ + it->node = it->rt->head; + if (!raxSeekGreatest(it)) return 0; + assert(it->node->iskey); + it->data = raxGetData(it->node); + return 1; + } + + /* We need to seek the specified key. What we do here is to actually + * perform a lookup, and later invoke the prev/next key code that + * we already use for iteration. */ + int splitpos = 0; + size_t i = raxLowWalk(it->rt,ele,len,&it->node,NULL,&splitpos,&it->stack); + + /* Return OOM on incomplete stack info. */ + if (it->stack.oom) return 0; + + if (eq && i == len && (!it->node->iscompr || splitpos == 0) && + it->node->iskey) + { + /* We found our node, since the key matches and we have an + * "equal" condition. */ + if (!raxIteratorAddToken(it,ele,len)) return 0; /* OOM. */ + it->data = raxGetData(it->node); + } else if (lt || gt) { + /* Exact key not found or eq flag not set. We have to set as current + * key the one represented by the node we stopped at, and perform + * a next/prev operation to seek. To reconstruct the key at this node + * we start from the parent and go to the current node, accumulating + * the characters found along the way. 
*/ + if (!raxStackPush(&it->stack,it->node)) return 0; + for (size_t j = 1; j < it->stack.items; j++) { + raxNode *parent = (raxNode *)it->stack.stack[j-1]; + raxNode *child = (raxNode *)it->stack.stack[j]; + if (parent->iscompr) { + if (!raxIteratorAddToken(it,parent->data,parent->size)) + return 0; + } else { + raxNode **cp = raxNodeFirstChildPtr(parent); + int *p = parent->data; + while(1) { + raxNode *aux; + memcpy(&aux,cp,sizeof(aux)); + if (aux == child) break; + cp++; + p++; + } + if (!raxIteratorAddToken(it,p,1)) return 0; + } + } + raxStackPop(&it->stack); + + /* We need to set the iterator in the correct state to call next/prev + * step in order to seek the desired element. */ + debugf("After initial seek: i=%d len=%d key=%.*s\n", + (int)i, (int)len, (int)it->key_len, it->key); + if (i != len && !it->node->iscompr) { + /* If we stopped in the middle of a normal node because of a + * mismatch, add the mismatching character to the current key + * and call the iterator with the 'noup' flag so that it will try + * to seek the next/prev child in the current node directly based + * on the mismatching character. */ + if (!raxIteratorAddToken(it,ele+i,1)) return 0; + debugf("Seek normal node on mismatch: %.*s\n", + (int)it->key_len, (char*)it->key); + + it->flags &= ~RAX_ITER_JUST_SEEKED; + if (lt && !raxIteratorPrevStep(it,1)) return 0; + if (gt && !raxIteratorNextStep(it,1)) return 0; + it->flags |= RAX_ITER_JUST_SEEKED; /* Ignore next call. */ + } else if (i != len && it->node->iscompr) { + debugf("Compressed mismatch: %.*s\n", + (int)it->key_len, (char*)it->key); + /* In case of a mismatch within a compressed node. */ + int nodechar = it->node->data[splitpos]; + int keychar = ele[i]; + it->flags &= ~RAX_ITER_JUST_SEEKED; + if (gt) { + /* If the key the compressed node represents is greater + * than our seek element, continue forward, otherwise set the + * state in order to go back to the next sub-tree. 
*/ + if (nodechar > keychar) { + if (!raxIteratorNextStep(it,0)) return 0; + } else { + if (!raxIteratorAddToken(it,it->node->data,it->node->size)) + return 0; + if (!raxIteratorNextStep(it,1)) return 0; + } + } + if (lt) { + /* If the key the compressed node represents is smaller + * than our seek element, seek the greater key in this + * subtree, otherwise set the state in order to go back to + * the previous sub-tree. */ + if (nodechar < keychar) { + if (!raxSeekGreatest(it)) return 0; + it->data = raxGetData(it->node); + } else { + if (!raxIteratorAddToken(it,it->node->data,it->node->size)) + return 0; + if (!raxIteratorPrevStep(it,1)) return 0; + } + } + it->flags |= RAX_ITER_JUST_SEEKED; /* Ignore next call. */ + } else { + debugf("No mismatch: %.*s\n", + (int)it->key_len, (char*)it->key); + /* If there was no mismatch we are into a node representing the + * key, (but which is not a key or the seek operator does not + * include 'eq'), or we stopped in the middle of a compressed node + * after processing all the key. Continue iterating as this was + * a legitimate key we stopped at. */ + it->flags &= ~RAX_ITER_JUST_SEEKED; + if (it->node->iscompr && it->node->iskey && splitpos && lt) { + /* If we stopped in the middle of a compressed node with + * perfect match, and the condition is to seek a key "<" than + * the specified one, then if this node is a key it already + * represents our match. For instance we may have nodes: + * + * "f" -> "oobar" = 1 -> "" = 2 + * + * Representing keys "f" = 1, "foobar" = 2. A seek for + * the key < "foo" will stop in the middle of the "oobar" + * node, but will be our match, representing the key "f". + * + * So in that case, we don't seek backward. */ + it->data = raxGetData(it->node); + } else { + if (gt && !raxIteratorNextStep(it,0)) return 0; + if (lt && !raxIteratorPrevStep(it,0)) return 0; + } + it->flags |= RAX_ITER_JUST_SEEKED; /* Ignore next call. 
*/ + } + } else { + /* If we are here just eq was set but no match was found. */ + it->flags |= RAX_ITER_EOF; + return 1; + } + return 1; +} + +/* Go to the next element in the scope of the iterator 'it'. + * If EOF (or out of memory) is reached, 0 is returned, otherwise 1 is + * returned. In case 0 is returned because of OOM, errno is set to ENOMEM. */ +int raxNext(raxIterator *it) { + if (!raxIteratorNextStep(it,0)) { + errno = ENOMEM; + return 0; + } + if (it->flags & RAX_ITER_EOF) { + errno = 0; + return 0; + } + return 1; +} + +/* Go to the previous element in the scope of the iterator 'it'. + * If EOF (or out of memory) is reached, 0 is returned, otherwise 1 is + * returned. In case 0 is returned because of OOM, errno is set to ENOMEM. */ +int raxPrev(raxIterator *it) { + if (!raxIteratorPrevStep(it,0)) { + errno = ENOMEM; + return 0; + } + if (it->flags & RAX_ITER_EOF) { + errno = 0; + return 0; + } + return 1; +} + +/* Perform a random walk starting in the current position of the iterator. + * Return 0 if the tree is empty or on out of memory. Otherwise 1 is returned + * and the iterator is set to the node reached after doing a random walk + * of 'steps' steps. If the 'steps' argument is 0, the random walk is performed + * using a random number of steps between 1 and two times the logarithm of + * the number of elements. + * + * NOTE: if you use this function to generate random elements from the radix + * tree, expect a disappointing distribution. A random walk produces good + * random elements if the tree is not sparse, however in the case of a radix + * tree certain keys will be reported much more often than others. At least + * this function should be able to expore every possible element eventually. 
*/ +int raxRandomWalk(raxIterator *it, size_t steps) { + if (it->rt->numele == 0) { + it->flags |= RAX_ITER_EOF; + return 0; + } + + if (steps == 0) { + size_t fle = 1+floor(log(it->rt->numele)); + fle *= 2; + steps = 1 + rand() % fle; + } + + raxNode *n = it->node; + while(steps > 0 || !n->iskey) { + int numchildren = n->iscompr ? 1 : n->size; + int r = rand() % (numchildren+(n != it->rt->head)); + + if (r == numchildren) { + /* Go up to parent. */ + n = (raxNode *)raxStackPop(&it->stack); + int todel = n->iscompr ? n->size : 1; + raxIteratorDelChars(it,todel); + } else { + /* Select a random child. */ + if (n->iscompr) { + if (!raxIteratorAddToken(it,n->data,n->size)) return 0; + } else { + if (!raxIteratorAddToken(it,n->data+r,1)) return 0; + } + raxNode **cp = raxNodeFirstChildPtr(n)+r; + if (!raxStackPush(&it->stack,n)) return 0; + memcpy(&n,cp,sizeof(n)); + } + if (n->iskey) steps--; + } + it->node = n; + it->data = raxGetData(it->node); + return 1; +} + +/* Compare the key currently pointed by the iterator to the specified + * key according to the specified operator. Returns 1 if the comparison is + * true, otherwise 0 is returned. */ +int raxCompare(raxIterator *iter, const char *op, int *key, size_t key_len) { + int eq = 0, lt = 0, gt = 0; + + if (op[0] == '=' || op[1] == '=') eq = 1; + if (op[0] == '>') gt = 1; + else if (op[0] == '<') lt = 1; + else if (op[1] != '=') return 0; /* Syntax error. */ + + size_t minlen = key_len < iter->key_len ? key_len : iter->key_len; + // update + int cmp = memcmp(iter->key,key,minlen * sizeof(int)); + + /* Handle == */ + if (lt == 0 && gt == 0) return cmp == 0 && key_len == iter->key_len; + + /* Handle >, >=, <, <= */ + if (cmp == 0) { + /* Same prefix: longer wins. */ + if (eq && key_len == iter->key_len) return 1; + else if (lt) return iter->key_len < key_len; + else if (gt) return iter->key_len > key_len; + else return 0; /* Avoid warning, just 'eq' is handled before. */ + } else if (cmp > 0) { + return gt ? 
1 : 0; + } else /* (cmp < 0) */ { + return lt ? 1 : 0; + } +} + +/* Free the iterator. */ +void raxStop(raxIterator *it) { + if (it->key != it->key_static_tokens) rax_free(it->key); + raxStackFree(&it->stack); +} + +/* Return if the iterator is in an EOF state. This happens when raxSeek() + * failed to seek an appropriate element, so that raxNext() or raxPrev() + * will return zero, or when an EOF condition was reached while iterating + * with raxNext() and raxPrev(). */ +int raxEOF(raxIterator *it) { + return it->flags & RAX_ITER_EOF; +} + +/* Return the number of elements inside the radix tree. */ +uint64_t raxSize(rax *rax) { + return rax->numele; +} + +/* ----------------------------- Introspection ------------------------------ */ + +/* This function is mostly used for debugging and learning purposes. + * It shows an ASCII representation of a tree on standard output, outling + * all the nodes and the contained keys. + * + * The representation is as follow: + * + * [1,2,3,4] (compressed node) + * [5,6] (normal node with three children) + * [5,6]=0x12345678 (node is a key, pointing to value 0x12345678) + * [] (a normal empty node) + * + * Children are represented in new idented lines, each children prefixed by + * the "`-(x)" string, where "x" is the edge byte. + * + * [1,2,3] + * `-(1) [1,1,1,1] + * `-(2) [3,4] + * `-(3) [] + * + * However when a node has a single child the following representation + * is used instead: + * + * [1,2] -> [1,2,3,4] -> [] + */ + +/* The actual implementation of raxShow(). */ +void raxRecursiveShow(int level, int lpad, raxNode *n) { + char s = n->iscompr ? '"' : '['; + char e = n->iscompr ? '"' : ']'; + + int numchars = printf("%c", s); + for (int i = 0; i < n->size; i++) { + numchars += printf("%d ", n->data[i]); + } + numchars += printf("%c %d ", e, n->numnodes); + + if (n->iskey) { + numchars += printf("=%p",raxGetData(n)); + } + + int numchildren = n->iscompr ? 
1 : n->size; + /* Note that 7 and 4 magic constants are the string length + * of " `-(x) " and " -> " respectively. */ + if (level) { + lpad += (numchildren > 1) ? 7 : 4; + if (numchildren == 1) lpad += numchars; + } + raxNode **cp = raxNodeFirstChildPtr(n); + for (int i = 0; i < numchildren; i++) { + char *branch = " `-(%d) "; + if (numchildren > 1) { + printf("\n"); + for (int j = 0; j < lpad; j++) putchar(' '); + printf(branch,n->data[i]); + } else { + printf(" -> "); + } + raxNode *child; + memcpy(&child,cp,sizeof(child)); + raxRecursiveShow(level+1,lpad,child); + cp++; + } +} + + +/* Show a tree, as outlined in the comment above. */ +void raxShow(rax *rax) { + raxRecursiveShow(0,0,rax->head); + putchar('\n'); +} + +/* Used by debugnode() macro to show info about a given node. */ +void raxDebugShowNode(const char *msg, raxNode *n) { + if (raxDebugMsg == 0) return; + printf("%s: %p [ ",msg, (void*)n); + + for (int i=0;i<(int)n->size;i++) { + printf("%d,", *(int*)(n->data+i)); + } + printf("],key:%d size:%d children:",n->iskey, n->size); + int numcld = n->iscompr ? 1 : n->size; + raxNode **cldptr = raxNodeLastChildPtr(n) - (numcld-1); + while(numcld--) { + raxNode *child; + memcpy(&child,cldptr,sizeof(child)); + cldptr++; + printf("%p ", (void*)child); + } + printf("\n"); + fflush(stdout); +} + +/* Touch all the nodes of a tree returning a check sum. This is useful + * in order to make Valgrind detect if there is something wrong while + * reading the data structure. + * + * This function was used in order to identify Rax bugs after a big refactoring + * using this technique: + * + * 1. The rax-test is executed using Valgrind, adding a printf() so that for + * the fuzz tester we see what iteration in the loop we are in. + * 2. After every modification of the radix tree made by the fuzz tester + * in rax-test.c, we add a call to raxTouch(). + * 3. Now as soon as an operation will corrupt the tree, raxTouch() will + * detect it (via Valgrind) immediately. 
We can add more calls to narrow + * the state. + * 4. At this point a good idea is to enable Rax debugging messages immediately + * before the moment the tree is corrupted, to see what happens. + */ +unsigned long raxTouch(raxNode *n) { + debugf("Touching %p\n", (void*)n); + unsigned long sum = 0; + if (n->iskey) { + sum += (unsigned long)raxGetData(n); + } + + int numchildren = n->iscompr ? 1 : n->size; + raxNode **cp = raxNodeFirstChildPtr(n); + int count = 0; + for (int i = 0; i < numchildren; i++) { + if (numchildren > 1) { + sum += (long)n->data[i]; + } + raxNode *child; + memcpy(&child,cp,sizeof(child)); + if (child == (void*)0x65d1760) count++; + if (count > 1) exit(1); + sum += raxTouch(child); + cp++; + } + return sum; +} + +/* +* Traverse the tree and collect all the nodes that contain data. +* these nodes are stored in the dataNodeList +*/ +void raxTraverse(raxNode *n, std::vector> &dataNodeList) { + if (n->iskey) { + dataNodeList.push_back(std::shared_ptr(n, [](raxNode*){})); + } + + int numchildren = n->iscompr ? 1 : n->size; + raxNode **cp = raxNodeFirstChildPtr(n); + for (int i = 0; i < numchildren; i++) { + raxNode *child; + memcpy(&child,cp,sizeof(child)); + raxTraverse(child, dataNodeList); + cp++; + } +} + +/* +* Split the tree into two sub trees, and return the root node of the new sub tree +* +* Input a token list, and split the tree into two sub trees via the token list +* It will find the node that nearest N/2 but not more than N/2, and split the tree +* into two sub trees. If there is no node that has N/2 children, it will split the +* tree from the root node. 
+* +*/ +raxNode *raxSplit(rax *rax, int *s, size_t len, void *data){ + int retval = raxInsert(rax, s, len, data, NULL); + if (retval == 0 && errno != 0) { + return NULL; + } + raxNode *childNode = NULL; + raxNode *splitNode = NULL; + raxStack stack = raxFindWithStack(rax, s, len); + int items = stack.items; + while (items > 0) { + raxNode *node = (raxNode *)raxStackPop(&stack); + if (node->numnodes >= (uint32_t)RAX_NODE_MAX_SIZE/2 || node->issubtree) { + splitNode = childNode; + raxStackPush(&stack, node); + break; + } + childNode = node; + items--; + } + // if the splitNode is NULL, it means that the tree only has one node + if (splitNode == NULL) { + return rax->head; + } + + raxStackAddNumNodes(&stack, -(int)(splitNode->numnodes)); + raxStackFree(&stack); + + splitNode->issubtree = 1; + + return splitNode; +} + +/* +* Traverse the subtree and return all the nodes that contain data under the subtree +* these nodes are stored in the dataNodeList +*/ +void raxTraverseSubTree(raxNode *n, std::vector> &dataNodeList) { + if (n->iskey) { + dataNodeList.push_back(std::shared_ptr(n, [](raxNode*){})); + } + + int numchildren = n->iscompr ? 
1 : n->size; + raxNode **cp = raxNodeFirstChildPtr(n); + for (int i = 0; i < numchildren; i++) { + raxNode *child; + memcpy(&child,cp,sizeof(child)); + if (!child->issubtree) { + raxTraverseSubTree(child, dataNodeList); + } + cp++; + } +} + +void raxSerialize(rax *root, std::vector> &tokenList, std::vector &dataList) { + raxIterator iter; + raxStart(&iter, root); + raxSeek(&iter, "^", NULL, 0); + while (raxNext(&iter)) { + std::vector token; + for (int i = 0; i < iter.key_len; i++) { + token.push_back(iter.key[i]); + } + tokenList.push_back(token); + dataList.push_back(iter.data); + } + raxStop(&iter); +} diff --git a/modules/kv-state-cache/radix-tree/radix.h b/modules/kv-state-cache/radix-tree/radix.h new file mode 100644 index 00000000..81682912 --- /dev/null +++ b/modules/kv-state-cache/radix-tree/radix.h @@ -0,0 +1,227 @@ +/* Rax -- A radix tree implementation. + * + * Copyright (c) 2017-2018, Salvatore Sanfilippo + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * * Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of Redis nor the names of its contributors may be used + * to endorse or promote products derived from this software without + * specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + */ + +#ifndef RADIX_H +#define RADIX_H + +#include +#include +#include +#include + +/* Representation of a radix tree as implemented in this file, that contains + * the token lists [1, 2, 3], [1, 2, 3, 4, 5, 6] and [1, 2, 3, 6, 7, 8] after + * the insertion of each token list. When the node represents a key inside + * the radix tree, we write it between [], otherwise it is written between (). + * + * This is the vanilla representation: + * + * (1) [] + * \ + * (2) [1] + * \ + * (3) [1,3] + * \ + * [4 6] [1,2,3] + * / \ + * [1,2,3,4] (5) (7) [1,2,3,6] + * / \ + * [1,2,3,4,5] (6) (8) [1,2,3,6,7] + * / \ + *[1,2,3,4,5,6] [] [] [1,2,3,6,7,8] + * + * However, this implementation implements a very common optimization where + * successive nodes having a single child are "compressed" into the node + * itself as a list of tokens, each representing a next-level child, + * and only the link to the node representing the last token node is + * provided inside the representation. So the above representation is turned + * into: + * + * ([1,2,3]) [] + * | + * [4 6] [1,2,3] + * / \ + * [1,2,3,4] ([5,6]) ([7,8]) [1,2,3,6] + * / \ + * [1,2,3,4,5,6] [] [] [1,2,3,6,7,8] + * + * However this optimization makes the implementation a bit more complex. 
+ * For instance if a token list [1,1,2] is added in the above radix tree, a + * "node splitting" operation is needed, since the [1,2,3] prefix is no longer + * composed of nodes having a single child one after the other. This is the + * above tree and the resulting node splitting after this event happens: + * + * + * (1) [] + * / \ + * [1] ([1,2) ([2,3]) [1] + * \ + * [4 6] [1,2,3] + * / \ + * [1,2,3,4] ([5,6]) ([7,8]) [1,2,3,6] + * / \ + * [1,2,3,4,5,6] [] [] [1,2,3,6,7,8] + * + * + * Similarly after deletion, if a new chain of nodes having a single child + * is created (the chain must also not include nodes that represent keys), + * it must be compressed back into a single node. + * + */ + +#define RAX_NODE_MAX_SIZE 1024 /* raxSplit() compares a node's numnodes against half of this value. */ +typedef struct raxNode { + uint32_t iskey:1; /* Does this node contain a key? */ + uint32_t isnull:1; /* Associated value is NULL (don't store it). */ + uint32_t iscompr:1; /* Node is compressed. */ + uint32_t issubtree:1; /* Node is the root node of a sub tree */ + uint32_t size:28; /* Number of children, or compressed token list len. */ + uint32_t numnodes; /* Number of the child nodes */ + /* Data layout is as follows: + * + * If node is not compressed we have 'size' ints, one for each child + * token, and 'size' raxNode pointers, pointing to each child node. + * Note how the token is not stored in the children but in the + * edge of the parents: + * + * [header iscompr=0][1,2,3][1-ptr][2-ptr][3-ptr](value-ptr?) + * + * if node is compressed (iscompr bit is 1) the node has 1 child. + * In that case the 'size' tokens stored immediately at + * the start of the data section, represent a sequence of successive + * nodes linked one after the other, for which only the last one in + * the sequence is actually represented as a node, and pointed to by + * the current compressed node. + * + * [header iscompr=1][1,2,3][3-ptr](value-ptr?)
+ * + * Both compressed and not compressed nodes can represent a key + * with associated data in the radix tree at any level (not just terminal + * nodes). + * + * If the node has an associated key (iskey=1) and is not NULL + * (isnull=0), then after the raxNode pointers pointing to the + * children, an additional value pointer is present (as you can see + * in the representation above as "value-ptr" field). + */ + int data[]; +} raxNode; + +typedef struct rax { + raxNode *head; /* Root node: iteration stops going up when it is reached. */ + uint64_t numele; /* Number of elements, as returned by raxSize(). */ + uint64_t numnodes; +} rax; + +/* Stack data structure used by raxLowWalk() in order to, optionally, return + * a list of parent nodes to the caller. The nodes do not have a "parent" + * field for space concerns, so we use the auxiliary stack when needed. */ +#define RAX_STACK_STATIC_ITEMS 32 +typedef struct raxStack { + void **stack; /* Points to static_items or a heap allocated array. */ + size_t items, maxitems; /* Number of items contained and total space. */ + /* Up to RAX_STACK_STATIC_ITEMS items we avoid to allocate on the heap + * and use this static array of pointers instead. */ + void *static_items[RAX_STACK_STATIC_ITEMS]; + int oom; /* True if pushing into this stack failed for OOM at some point. */ +} raxStack; + +/* Optional callback used for iterators and be notified on each rax node, + * including nodes not representing keys. If the callback returns true + * the callback changed the node pointer in the iterator structure, and the + * iterator implementation will have to replace the pointer in the radix tree + * internals. This allows the callback to reallocate the node to perform + * very special operations, normally not needed by normal applications. + * + * This callback is used to perform very low level analysis of the radix tree + * structure, scanning each possible node (but the root node), or in order to + * reallocate the nodes to reduce the allocation fragmentation (this is the + * Redis application for this callback).
+ * + * This is currently only supported in forward iterations (raxNext) */ +typedef int (*raxNodeCallback)(raxNode **noderef); + +/* Radix tree iterator state is encapsulated into this data structure. */ +#define RAX_ITER_STATIC_LEN 128 +#define RAX_ITER_JUST_SEEKED (1<<0) /* Iterator was just seeked. Return current + element for the first iteration and + clear the flag. */ +#define RAX_ITER_EOF (1<<1) /* End of iteration reached. */ +#define RAX_ITER_SAFE (1<<2) /* Safe iterator, allows operations while + iterating. But it is slower. */ +typedef struct raxIterator { + int flags; /* Bitwise OR of the RAX_ITER_* flags above. */ + rax *rt; /* Radix tree we are iterating. */ + int *key; /* The current key, as a list of int tokens. */ + void *data; /* Data associated to this key. */ + size_t key_len; /* Current key length, in tokens. */ + size_t key_max; /* Max key len (in tokens) the current key buffer can hold. */ + int key_static_tokens[RAX_ITER_STATIC_LEN]; /* Inline key buffer, used until the key outgrows it. */ + raxNode *node; /* Current node. Only for unsafe iteration. */ + raxStack stack; /* Stack used for unsafe iteration. */ + raxNodeCallback node_cb; /* Optional node callback. Normally set to NULL. */ +} raxIterator; + +/* A special pointer returned for not found items. */ +extern void *raxNotFound; + +/* Exported API.
*/ +rax *raxNew(void); +int raxInsert(rax *rax, int *s, size_t len, void *data, void **old); +int raxTryInsert(rax *rax, int *s, size_t len, void *data, void **old); +int raxInsertAndReturnDataNode(rax *rax, int *s, size_t len, void *data, void **node, void **old); +int raxRemove(rax *rax, int *s, size_t len, void **old); +void *raxFind(rax *rax, int *s, size_t len); +raxNode *raxFindAndReturnDataNode(rax *rax, int *s, size_t len); +void raxFree(rax *rax); +void raxFreeWithCallback(rax *rax, void (*free_callback)(void*)); +void raxStart(raxIterator *it, rax *rt); +int raxSeek(raxIterator *it, const char *op, int *ele, size_t len); +int raxNext(raxIterator *it); +int raxPrev(raxIterator *it); +int raxRandomWalk(raxIterator *it, size_t steps); +int raxCompare(raxIterator *iter, const char *op, int *key, size_t key_len); +void raxStop(raxIterator *it); +int raxEOF(raxIterator *it); +void raxShow(rax *rax); +uint64_t raxSize(rax *rax); +unsigned long raxTouch(raxNode *n); +void raxSetDebugMsg(int onoff); +void raxTraverse(raxNode *rax, std::vector> &dataNodeList); +void raxTraverseSubTree(raxNode *n, std::vector> &dataNodeList); +raxNode *raxSplit(rax *rax, int *s, size_t len, void *data); +void raxSerialize(rax *root, std::vector> &tokenList, std::vector &dataList); + +/* Internal API. May be used by the node callback in order to access rax nodes + * in a low level way, so this function is exported as well. */ +void raxSetData(raxNode *n, void *data); +void *raxGetData(raxNode *n); + +#endif diff --git a/modules/kv-state-cache/radix-tree/rax_malloc.h b/modules/kv-state-cache/radix-tree/rax_malloc.h new file mode 100644 index 00000000..e9d5d5d7 --- /dev/null +++ b/modules/kv-state-cache/radix-tree/rax_malloc.h @@ -0,0 +1,43 @@ +/* Rax -- A radix tree implementation. + * + * Copyright (c) 2017, Salvatore Sanfilippo + * All rights reserved. 
+ * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * * Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of Redis nor the names of its contributors may be used + * to endorse or promote products derived from this software without + * specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + */ + +/* Allocator selection. + * + * This file is used in order to change the Rax allocator at compile time. + * Just define the following defines to what you want to use. Also add + * the include of your alternate allocator if needed (not needed in order + * to use the default libc allocator). 
*/ + +#ifndef RAX_ALLOC_H +#define RAX_ALLOC_H +#define rax_malloc malloc +#define rax_realloc realloc +#define rax_free free +#endif diff --git a/modules/kv-state-cache/strategy/LRU_strategy.cc b/modules/kv-state-cache/strategy/LRU_strategy.cc new file mode 100644 index 00000000..bf6b646e --- /dev/null +++ b/modules/kv-state-cache/strategy/LRU_strategy.cc @@ -0,0 +1,143 @@ +/** Copyright 2020-2023 Alibaba Group Holding Limited. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include "LRU_strategy.h" +#include "common/util/logging.h" + +namespace vineyard { + +void PrintTokenList(std::vector& vector) { + std::string tokens_str = ""; + for (size_t i = 0; i < vector.size(); ++i) { + tokens_str += std::to_string(vector[i]); + } + LOG(INFO) << tokens_str; +} + +LRUStrategy::LRUStrategy(int capacity) { + this->capacity = capacity; + this->header = this->tail = nullptr; + this->current_size = 0; +} + +LRUStrategy::LRUStrategy(const std::vector>& cache_list, + int capacity) { + // TBD +} + +std::shared_ptr LRUStrategy::InsertToHeader( + const std::vector& tokens, std::vector& evicted_tokens) { + if (current_size == capacity) { + std::shared_ptr remove_node = Remove(); + evicted_tokens = remove_node->tokens; + } + + std::shared_ptr cache_node = std::make_shared(); + cache_node->tokens = tokens; + + if (header == nullptr) { + header = cache_node; + tail = cache_node; + } else { + cache_node->next = header; + header->prev = cache_node; + header = cache_node; + } + + current_size++; + return 
cache_node; +} + +void LRUStrategy::MoveToHead(std::shared_ptr cache_node) { + if (cache_node == header) { + return; + } + + if (cache_node == tail) { + tail = tail->prev; + tail->next = nullptr; + } else { + cache_node->prev->next = cache_node->next; + cache_node->next->prev = cache_node->prev; + } + + cache_node->next = header; + header->prev = cache_node; + header = cache_node; + cache_node->prev = nullptr; +} + +std::shared_ptr LRUStrategy::Remove() { + std::shared_ptr cache_node = tail; + if (tail->prev != nullptr) { + tail->prev->next = nullptr; + tail = tail->prev; + } else { + header = nullptr; + tail = nullptr; + } + current_size--; + + LOG(INFO) << "Remove token:"; + PrintTokenList(cache_node->tokens); + return cache_node; +} + +void LRUStrategy::Remove(std::shared_ptr cache_node) { + if (cache_node == header) { + header = header->next; + header->prev = nullptr; + } else if (cache_node == tail) { + tail = tail->prev; + tail->next = nullptr; + } else { + cache_node->prev->next = cache_node->next; + cache_node->next->prev = cache_node->prev; + } + current_size--; +} + +std::shared_ptr LRUStrategy::GetHeader() { return header; } + +// void LRUStrategy::Remove(const std::vector& prefix, int token) { +// std::vector tokens = prefix; +// tokens.push_back(token); + +// std::shared_ptr node_with_tree_attri = +// radix_tree->Query(tokens); +// if (node_with_tree_attri == nullptr) { +// return; +// } + +// std::shared_ptr cache_node = +// std::static_pointer_cast( +// node_with_tree_attri->get_node()->get_data()); +// if (cache_node == header) { +// header = header->next; +// header->prev = nullptr; +// } else if (cache_node == tail) { +// tail = tail->prev; +// tail->next = nullptr; +// } else { +// cache_node->prev->next = cache_node->next; +// cache_node->next->prev = cache_node->prev; +// } +// current_size--; +// radix_tree->Delete(tokens); +// } + +// LRUStrategy::~LRUStrategy() { delete radix_tree; } + +} // namespace vineyard diff --git 
a/modules/kv-state-cache/strategy/LRU_strategy.h b/modules/kv-state-cache/strategy/LRU_strategy.h new file mode 100644 index 00000000..58c13b05 --- /dev/null +++ b/modules/kv-state-cache/strategy/LRU_strategy.h @@ -0,0 +1,67 @@ +/** Copyright 2020-2023 Alibaba Group Holding Limited. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include +#include + +#include "cache_strategy.h" + +#ifndef MODULES_LRU_STRATEGY_H_ +#define MODULES_LRU_STRATEGY_H_ + +namespace vineyard { + +struct LRUCacheNode { + std::shared_ptr next; + std::shared_ptr prev; + std::vector tokens; +}; + +class LRUStrategy : public CacheStrategy { + private: + int current_size; + + std::shared_ptr header; + + std::shared_ptr tail; + + LRUStrategy(); + + std::shared_ptr Remove(); + + ~LRUStrategy(); + + public: + LRUStrategy(int capacity); + + LRUStrategy(const std::vector>& cache_list, int capacity); + + void MoveToHead(std::shared_ptr cache_node); + + std::shared_ptr InsertToHeader( + const std::vector& tokens, std::vector& evicted_tokens); + + void Remove(std::shared_ptr cache_node); + + std::shared_ptr GetHeader(); + + int GetCapacity() { return capacity; } + // for distributed sync + // void Remove(const std::vector& prefix, int token); +}; + +} // namespace vineyard + +#endif \ No newline at end of file diff --git a/modules/kv-state-cache/strategy/cache_strategy.h b/modules/kv-state-cache/strategy/cache_strategy.h new file mode 100644 index 00000000..36596cd1 --- /dev/null +++ 
b/modules/kv-state-cache/strategy/cache_strategy.h @@ -0,0 +1,31 @@ +/** Copyright 2020-2023 Alibaba Group Holding Limited. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include +#include + +#ifndef MODULES_CACHE_STRATEGY_H_ +#define MODULES_CACHE_STRATEGY_H_ + +namespace vineyard { + +class CacheStrategy { + protected: + int capacity; +}; + +} // namespace vineyard + +#endif \ No newline at end of file diff --git a/modules/kv-state-cache/utils/kv_state_cache_utils.cc b/modules/kv-state-cache/utils/kv_state_cache_utils.cc new file mode 100644 index 00000000..ddefce59 --- /dev/null +++ b/modules/kv-state-cache/utils/kv_state_cache_utils.cc @@ -0,0 +1,274 @@ +/** Copyright 2020-2023 Alibaba Group Holding Limited. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +#include + +#include "client/client.h" +#include "common/util/logging.h" +#include "kv-state-cache/ds/kv_state_cache.h" + +using namespace vineyard; + +static Client client; +static std::shared_ptr kv_state_cache_builder = nullptr; +static std::string llm_cache_sync_lock = "llm_cache_sync_lock"; +static std::string llm_cache_object_name = "llm_cache_object"; +static std::thread* sync_thread; +static bool exit_flag = false; +static pthread_mutex_t sync_mutex; + +#ifndef SYNC_INTERVAL +#define SYNC_INTERVAL 3 +#endif + +void threadFunc(); + +void signalHandler(int signum) { + /* + * TBD + * Avoid dead lock if the client is down when the lock is acquired. + * Use lease to prevent dead lock in the future. + */ + std::cout << "Interrupt signal (" << signum << ") received.\n"; + exit_flag = true; + sync_thread->join(); + exit(signum); +} + +void initKVStateCache(int dimension = 10, int cache_capacity = 10) { + if (kv_state_cache_builder == nullptr) { + std::string socket = std::string(getenv("VINEYARD_IPC_SOCKET")); + LOG(INFO) << "socket:" << socket; + client.Connect(socket); + LOG(INFO) << "conneted"; + + pthread_mutex_init(&sync_mutex, NULL); + // TBD + // try to get cache object + std::string actural_key; + bool result; + while (1) { + client.TryAcquireLock(llm_cache_sync_lock, result, actural_key); + if (!result) { + LOG(INFO) << "failed to gain the lock, wait for next time"; + sleep(1); + continue; + } else { + break; + } + } + + // // sync global cache object with vineyard + ObjectID global_kv_state_cache_id; + Status status = + client.GetName(llm_cache_object_name, global_kv_state_cache_id); + if (status.ok()) { + // if success, pull the cache object + std::shared_ptr global_kv_state_cache = + std::dynamic_pointer_cast( + client.GetObject(global_kv_state_cache_id)); + // TBD cache stragety + kv_state_cache_builder = + std::make_shared(client, global_kv_state_cache); + } else { + // if failed, create a new cache object + LOG(INFO) << "failed to get the 
cache object, create a new one"; + kv_state_cache_builder = std::make_shared( + client, dimension, cache_capacity); + } + + // // release the lock + client.TryReleaseLock(actural_key, result); + VINEYARD_ASSERT(result == true); + + sync_thread = new std::thread(threadFunc); + + signal(SIGINT, signalHandler); + // TBD + // use lease to prevent the deadlock if the client is down + } +} + +void updateInternal(const std::vector& token_list, int next_token, + const KV_STATE_WITH_LAYER& kv_state) { + kv_state_cache_builder->Update(client, token_list, next_token, kv_state); +} + +void update(const std::vector& token_list, int next_token, + const KV_STATE_WITH_LAYER& kv_state) { + LOG(INFO) << "update"; + if (pthread_mutex_trylock(&sync_mutex)) { + return; + } + + updateInternal(token_list, next_token, kv_state); + + pthread_mutex_unlock(&sync_mutex); +} + +void update(const std::vector& token_list, + const LIST_KV_STATE_WITH_LAYER& kv_state) { + if (pthread_mutex_trylock(&sync_mutex)) { + return; + } + std::vector token_list_copy; + for (size_t i = 0; i < token_list.size(); i++) { + updateInternal(token_list_copy, token_list[i], kv_state[i]); + token_list_copy.push_back(token_list[i]); + } + pthread_mutex_unlock(&sync_mutex); +} + +KV_STATE_WITH_LAYER queryInternal(const std::vector& token_list, + int token) { + return kv_state_cache_builder->Query(client, token_list, token); +} + +KV_STATE_WITH_LAYER query(const std::vector& token_list, int token) { + LOG(INFO) << "query"; + KV_STATE_WITH_LAYER result; + if (pthread_mutex_trylock(&sync_mutex)) { + return result; + } + + result = queryInternal(token_list, token); + LOG(INFO) << "unlock"; + pthread_mutex_unlock(&sync_mutex); + LOG(INFO) << "query end"; + + return result; +} + +LIST_KV_STATE_WITH_LAYER query(const std::vector& token_list) { + LIST_KV_STATE_WITH_LAYER list_kv_state; + if (pthread_mutex_trylock(&sync_mutex)) { + return list_kv_state; + } + + std::vector token_list_copy; + for (size_t i = 0; i < 
token_list.size(); i++) { + KV_STATE_WITH_LAYER kv_state = + queryInternal(token_list_copy, token_list[i]); + list_kv_state.push_back(kv_state); + token_list_copy.push_back(token_list[i]); + } + + pthread_mutex_unlock(&sync_mutex); + return list_kv_state; +} + +void sync() { + LOG(INFO) << "sync"; + + // 1. gain the lock + std::string actural_key; + bool result; + client.TryAcquireLock(llm_cache_sync_lock, result, actural_key); + if (!result) { + LOG(INFO) << "failed to gain the lock, wait for next time"; + return; + } + // 2. pull the cache object + ObjectID global_kv_state_cache_id; + std::vector delete_list; + + std::shared_ptr global_kv_state_cache = nullptr; + Status status = + client.GetName(llm_cache_object_name, global_kv_state_cache_id); + if (status.ok()) { + delete_list.push_back(global_kv_state_cache_id); + global_kv_state_cache = std::dynamic_pointer_cast( + client.GetObject(global_kv_state_cache_id)); + } + + // 3. merge the cache object + std::shared_ptr merged_kv_state_cache_builder = + kv_state_cache_builder->Merge(client, global_kv_state_cache); + if (merged_kv_state_cache_builder == nullptr) { + merged_kv_state_cache_builder = kv_state_cache_builder; + } + + // 4. push the cache object + std::shared_ptr kv_state_cache = + merged_kv_state_cache_builder->_Seal(client); + client.Persist(kv_state_cache->id()); + + // 5. put the name of the new cache object to the meta server + LOG(INFO) << "stage 5"; + client.DropName(llm_cache_object_name); + status = client.PutName(kv_state_cache->id(), llm_cache_object_name); + if (status.ok()) { + LOG(INFO) << "put name success"; + } else { + LOG(INFO) << "put name failed with status:" + status.ToString(); + } + + LOG(INFO) << "stage 6"; + // 6. delete old cache object + client.DelData(delete_list); + + LOG(INFO) << "stage 7"; + // 7. 
create a global cache object replica + // TBD cache stragety + std::dynamic_pointer_cast(kv_state_cache)->Resolve(); + kv_state_cache_builder = std::make_shared( + client, std::dynamic_pointer_cast(kv_state_cache)); + + LOG(INFO) << "stage 8"; + // 8. release the lock + client.TryReleaseLock(actural_key, result); + VINEYARD_ASSERT(result == true); + + // TBD + // use lease to prevent the deadlock if the client is down +} + +void threadFunc() { + while (1) { + sleep(SYNC_INTERVAL); + if (exit_flag) { + break; + } + LOG(INFO) << "Try sync"; + pthread_mutex_lock(&sync_mutex); + sync(); + pthread_mutex_unlock(&sync_mutex); + // break; + } +} + +/* + a. vineyardd with global cache object | sealed + b. client get the object replica + c. client update replica + d. client seal the local object and try to push object to server (modified + sealed object and global cache version) â…°. if success + 1. vineyardd modify global object meta + 2. client reconstruct the local object replica + 3. goto c + â…±. if failed + 1. client pull the global object + 2. merge the object with local cache (e.g. create a new child_cache_object + and merge) + 3. goto d +*/ +/* node with attr node cache node data node + addr: 0x7e6fec007be0 0x7e6fec0077e0 0x7e6fec001100 0 + addr: 0x7e6fec007be0 0x7e6fec0077e0 0x7e6fec001100 0x7e6fec009ea0 + + + + 0x5654cf368c60 0x5654cf368f40 0x7e6fec001100 0x7e6fec009ea0 +*/ \ No newline at end of file diff --git a/modules/kv-state-cache/utils/kv_state_cache_utils.h b/modules/kv-state-cache/utils/kv_state_cache_utils.h new file mode 100644 index 00000000..be346312 --- /dev/null +++ b/modules/kv-state-cache/utils/kv_state_cache_utils.h @@ -0,0 +1,33 @@ +/** Copyright 2020-2023 Alibaba Group Holding Limited. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include "kv-state-cache/ds/kv_state_cache.h" + +#ifndef MODULES_KV_STATE_CACHE_UTILS_H_ +#define MODULES_KV_STATE_CACHE_UTILS_H_ + +void initKVStateCache(int dimension = 10, int cache_capacity = 10); + +void update(const std::vector& token_list, int next_token, + const KV_STATE_WITH_LAYER& kv_state); + +void update(const std::vector& token_list, + const LIST_KV_STATE_WITH_LAYER& kv_state); + +KV_STATE_WITH_LAYER query(const std::vector& token_list, int token); + +LIST_KV_STATE_WITH_LAYER query(const std::vector& token_list); + +#endif \ No newline at end of file diff --git a/src/client/client.cc b/src/client/client.cc index 67dcc258..b949e8f7 100644 --- a/src/client/client.cc +++ b/src/client/client.cc @@ -1156,6 +1156,33 @@ Status Client::IsSpilled(ObjectID const& id, bool& is_spilled) { return Status::OK(); } +Status Client::TryAcquireLock(std::string key, bool& result, + std::string& actural_key) { + ENSURE_CONNECTED(this); + + std::string message_out; + WriteTryAcquireLockRequest(key, message_out); + VINEYARD_CHECK_OK(doWrite(message_out)); + + json message_in; + VINEYARD_CHECK_OK(doRead(message_in)); + VINEYARD_CHECK_OK(ReadTryAcquireLockReply(message_in, result, actural_key)); + return Status::OK(); +} + +Status Client::TryReleaseLock(std::string key, bool& result) { + ENSURE_CONNECTED(this); + + std::string message_out; + WriteTryReleaseLockRequest(key, message_out); + VINEYARD_CHECK_OK(doWrite(message_out)); + + json message_in; + VINEYARD_CHECK_OK(doRead(message_in)); + VINEYARD_CHECK_OK(ReadTryReleaseLockReply(message_in, result)); + return Status::OK(); 
+} + PlasmaClient::~PlasmaClient() {} // dummy implementation diff --git a/src/client/client.h b/src/client/client.h index e6f13b72..1b9d8eba 100644 --- a/src/client/client.h +++ b/src/client/client.h @@ -831,6 +831,24 @@ class Client final : public BasicIPCClient, */ Status GetGPUBuffer(const ObjectID id, const bool unsafe, std::shared_ptr& buffer); + /** + * @brief Try to acquire a distributed lock. + * + * @param key The key of the lock. + * + * @return Status that indicates whether the lock process succeeds. + */ + Status TryAcquireLock(std::string key, bool& result, + std::string& actural_key); + + /** + * @brief Try to release a distributed lock. + * + * @param key The key of the lock. + * + * @return Status that indicates whether the unlock process succeeds. + */ + Status TryReleaseLock(std::string key, bool& result); protected: /** @@ -1023,6 +1041,15 @@ class PlasmaClient final */ Status Delete(PlasmaID const& id); + Status TryAcquireLock(std::string key, bool& result, + std::string& actural_key) { + return Status::NotImplemented(); + } + + Status TryReleaseLock(std::string key, bool& result) { + return Status::NotImplemented(); + } + protected: /** * @brief Required by `UsageTracker`. When reference count reaches zero, send diff --git a/src/client/client_base.h b/src/client/client_base.h index a3c28eda..8e1748e3 100644 --- a/src/client/client_base.h +++ b/src/client/client_base.h @@ -602,6 +602,25 @@ class ClientBase { */ const std::string& Version() const { return server_version_; } + /** + * @brief Try to acquire a distributed lock. + * + * @param key The key of the lock. + * + * @return Status that indicates whether the lock process succeeds. + */ + virtual Status TryAcquireLock(std::string key, bool& result, + std::string& actural_key) = 0; + + /** + * @brief Try to release a distributed lock. + * + * @param key The key of the lock. + * + * @return Status that indicates whether the unlock process succeeds. 
+ */ + virtual Status TryReleaseLock(std::string key, bool& result) = 0; + /** * @brief Issue a debug request. * diff --git a/src/client/rpc_client.h b/src/client/rpc_client.h index f9cc00eb..88b90ea3 100644 --- a/src/client/rpc_client.h +++ b/src/client/rpc_client.h @@ -359,6 +359,30 @@ class RPCClient final : public ClientBase { Status GetRemoteBlobs( std::set const& ids, const bool unsafe, std::map>& remote_blobs); + /** + * @brief Try to acquire a distributed lock. + * + * @param key The key of the lock. + * + * @return Status that indicates whether the lock process succeeds. + */ + Status TryAcquireLock(std::string key, bool& result, + std::string& actural_key) { + // TBD + return Status::NotImplemented("TryAcquireLock is not implemented yet."); + } + + /** + * @brief Try to release a distributed lock. + * + * @param key The key of the lock. + * + * @return Status that indicates whether the unlock process succeeds. + */ + Status TryReleaseLock(std::string key, bool& result) { + // TBD + return Status::NotImplemented("TryAcquireLock is not implemented yet."); + } private: InstanceID remote_instance_id_; diff --git a/src/common/util/protocols.cc b/src/common/util/protocols.cc index 0d16c68a..5c692c0d 100644 --- a/src/common/util/protocols.cc +++ b/src/common/util/protocols.cc @@ -206,6 +206,12 @@ const std::string command_t::SHALLOW_COPY_REPLY = "shallow_copy_reply"; const std::string command_t::DEBUG_REQUEST = "debug_command"; const std::string command_t::DEBUG_REPLY = "debug_reply"; +// distributed lock +const std::string command_t::ACQUIRE_LOCK_REQUEST = "acquire_lock_request"; +const std::string command_t::ACQUIRE_LOCK_REPLY = "acquire_lock_reply"; +const std::string command_t::RELEASE_LOCK_REQUEST = "release_lock_request"; +const std::string command_t::RELEASE_LOCK_REPLY = "release_lock_reply"; + void WriteErrorReply(Status const& status, std::string& msg) { encode_msg(status.ToJSON(), msg); } @@ -2133,7 +2139,8 @@ void WriteInstanceStatusReply(const json& 
meta, std::string& msg) { Status ReadInstanceStatusReply(const json& root, json& meta) { CHECK_IPC_ERROR(root, command_t::INSTANCE_STATUS_REPLY); - meta = root["meta"]; + meta = root["meta"].get(); + ; return Status::OK(); } @@ -2258,4 +2265,62 @@ Status ReadDebugReply(const json& root, json& result) { return Status::OK(); } +void WriteTryAcquireLockRequest(const std::string& key, std::string& msg) { + json root; + root["type"] = command_t::ACQUIRE_LOCK_REQUEST; + root["key"] = key; + encode_msg(root, msg); +} + +Status ReadTryAcquireLockRequest(const json& root, std::string& key) { + CHECK_IPC_ERROR(root, command_t::ACQUIRE_LOCK_REQUEST); + key = root["key"].get(); + return Status::OK(); +} + +void WriteTryAcquireLockReply(const bool result, const std::string actual_key, + std::string& msg) { + json root; + root["type"] = command_t::ACQUIRE_LOCK_REPLY; + root["key"] = actual_key; + root["result"] = result; + encode_msg(root, msg); +} + +Status ReadTryAcquireLockReply(const json& root, bool& result, + std::string& key) { + CHECK_IPC_ERROR(root, command_t::ACQUIRE_LOCK_REPLY); + result = root["result"].get(); + key = root["key"].get(); + return Status::OK(); +} + +void WriteTryReleaseLockRequest(const std::string& key, std::string& msg) { + json root; + root["type"] = command_t::RELEASE_LOCK_REQUEST; + root["key"] = key; + encode_msg(root, msg); +} + +Status ReadTryReleaseLockRequest(const json& root, std::string& key) { + CHECK_IPC_ERROR(root, command_t::RELEASE_LOCK_REQUEST); + key = root["key"].get(); + ; + return Status::OK(); +} + +void WriteTryReleaseLockReply(const bool result, std::string& msg) { + json root; + root["type"] = command_t::RELEASE_LOCK_REPLY; + root["result"] = result; + encode_msg(root, msg); +} + +Status ReadTryReleaseLockReply(const json& root, bool& result) { + CHECK_IPC_ERROR(root, command_t::RELEASE_LOCK_REPLY); + result = root["result"].get(); + ; + return Status::OK(); +} + } // namespace vineyard diff --git 
a/src/common/util/protocols.h b/src/common/util/protocols.h index d8a89bea..37fc581f 100644 --- a/src/common/util/protocols.h +++ b/src/common/util/protocols.h @@ -171,6 +171,12 @@ struct command_t { static const std::string SHALLOW_COPY_REPLY; static const std::string DEBUG_REQUEST; static const std::string DEBUG_REPLY; + + // distributed lock + static const std::string ACQUIRE_LOCK_REQUEST; + static const std::string ACQUIRE_LOCK_REPLY; + static const std::string RELEASE_LOCK_REQUEST; + static const std::string RELEASE_LOCK_REPLY; }; enum class StoreType { @@ -820,6 +826,24 @@ void WriteDebugReply(const json& result, std::string& msg); Status ReadDebugReply(const json& root, json& result); +void WriteTryAcquireLockRequest(const std::string& key, std::string& msg); + +Status ReadTryAcquireLockRequest(const json& root, std::string& key); + +void WriteTryAcquireLockReply(const bool result, const std::string actual_key, + std::string& msg); + +Status ReadTryAcquireLockReply(const json& root, bool& result, + std::string& key); + +void WriteTryReleaseLockRequest(const std::string& key, std::string& msg); + +Status ReadTryReleaseLockRequest(const json& root, std::string& key); + +void WriteTryReleaseLockReply(const bool result, std::string& msg); + +Status ReadTryReleaseLockReply(const json& root, bool& result); + } // namespace vineyard #endif // SRC_COMMON_UTIL_PROTOCOLS_H_ diff --git a/src/server/async/socket_server.cc b/src/server/async/socket_server.cc index a0664a78..36b58ac1 100644 --- a/src/server/async/socket_server.cc +++ b/src/server/async/socket_server.cc @@ -359,6 +359,10 @@ bool SocketConnection::processMessage(const std::string& message_in) { return doShallowCopy(root); } else if (cmd == command_t::DEBUG_REQUEST) { return doDebug(root); + } else if (cmd == command_t::ACQUIRE_LOCK_REQUEST) { + return doAcquireLock(root); + } else if (cmd == command_t::RELEASE_LOCK_REQUEST) { + return doReleaseLock(root); } else { RESPONSE_ON_ERROR(Status::Invalid("Got 
unexpected command: " + cmd)); return false; @@ -1760,6 +1764,46 @@ bool SocketConnection::doDebug(const json& root) { return false; } +bool SocketConnection::doAcquireLock(const json& root) { + auto self(shared_from_this()); + std::string key; + TRY_READ_REQUEST(ReadTryAcquireLockRequest, root, key); + + RESPONSE_ON_ERROR(server_ptr_->TryAcquireLock( + key, [self](const Status& status, bool result, std::string actual_key) { + std::string message_out; + if (status.ok()) { + WriteTryAcquireLockReply(result, actual_key, message_out); + } else { + VLOG(100) << "Error: " << status.ToString(); + WriteErrorReply(status, message_out); + } + self->doWrite(message_out); + return Status::OK(); + })); + return false; +} + +bool SocketConnection::doReleaseLock(const json& root) { + auto self(shared_from_this()); + std::string key; + TRY_READ_REQUEST(ReadTryReleaseLockRequest, root, key); + + RESPONSE_ON_ERROR(server_ptr_->TryReleaseLock( + key, [self](const Status& status, bool result) { + std::string message_out; + if (status.ok()) { + WriteTryReleaseLockReply(result, message_out); + } else { + VLOG(100) << "Error: " << status.ToString(); + WriteErrorReply(status, message_out); + } + self->doWrite(message_out); + return Status::OK(); + })); + return false; +} + void SocketConnection::doWrite(const std::string& buf) { std::string to_send; size_t length = buf.size(); diff --git a/src/server/async/socket_server.h b/src/server/async/socket_server.h index d712df15..a5291582 100644 --- a/src/server/async/socket_server.h +++ b/src/server/async/socket_server.h @@ -142,6 +142,9 @@ class SocketConnection : public std::enable_shared_from_this { bool doDebug(json const& root); + bool doAcquireLock(json const& root); + bool doReleaseLock(json const& root); + protected: template Status MoveBuffers(std::map mapping, diff --git a/src/server/server/vineyard_server.cc b/src/server/server/vineyard_server.cc index 043c30b7..063f4ee2 100644 --- a/src/server/server/vineyard_server.cc +++ 
b/src/server/server/vineyard_server.cc @@ -1054,6 +1054,40 @@ Status VineyardServer::MigrateObject(const ObjectID object_id, return Status::OK(); } +Status VineyardServer::TryAcquireLock(std::string& key, + callback_t callback) { + ENSURE_VINEYARDD_READY(); + auto self(shared_from_this()); + meta_service_ptr_->TryAcquireLock( + key, [self, callback](const Status& status, bool result, + std::string actural_key) { + if (status.ok()) { + LOG(INFO) << "No error occurred. Gain lock:" << result; + return callback(status, result, actural_key); + } else { + return callback(status, result, actural_key); + } + }); + + return Status::OK(); +} + +Status VineyardServer::TryReleaseLock(std::string& key, + callback_t callback) { + ENSURE_VINEYARDD_READY(); + auto self(shared_from_this()); + meta_service_ptr_->TryReleaseLock( + key, [self, callback](const Status& status, bool result) { + if (status.ok()) { + LOG(INFO) << "No error occurred. Release lock:" << result; + return callback(status, result); + } else { + return status; + } + }); + return Status::OK(); +} + Status VineyardServer::LabelObjects(const ObjectID object_id, const std::vector& keys, const std::vector& values, diff --git a/src/server/server/vineyard_server.h b/src/server/server/vineyard_server.h index c9e93663..a51a8323 100644 --- a/src/server/server/vineyard_server.h +++ b/src/server/server/vineyard_server.h @@ -193,6 +193,11 @@ class VineyardServer : public std::enable_shared_from_this { Status Verify(const std::string& username, const std::string& password, callback_t<> callback); + Status TryAcquireLock(std::string& key, + callback_t callback); + + Status TryReleaseLock(std::string& key, callback_t callback); + inline SessionID session_id() const { return session_id_; } inline InstanceID instance_id() { return instance_id_; } inline std::string instance_name() { return instance_name_; } diff --git a/src/server/services/etcd_meta_service.cc b/src/server/services/etcd_meta_service.cc index 161ed9fc..76e71e20 
100644 --- a/src/server/services/etcd_meta_service.cc +++ b/src/server/services/etcd_meta_service.cc @@ -151,6 +151,48 @@ void EtcdMetaService::Stop() { } } +void EtcdMetaService::TryAcquireLock( + std::string key, callback_t callback_after_try_lock) { + LOG(INFO) << "TryAcquireLock, key:" << key; + auto self(shared_from_base()); + + etcd_->lock(prefix_ + key) + .then([self, callback_after_try_lock]( + pplx::task const& resp_task) { + auto const& resp = resp_task.get(); + if (resp.is_ok()) { + LOG(INFO) << "lock success! key is :" + resp.lock_key(); + self->server_ptr_->GetMetaContext().post( + boost::bind(callback_after_try_lock, Status::OK(), true, + resp.lock_key().substr(self->prefix_.size()))); + } else { + LOG(INFO) << "lock falied!"; + self->server_ptr_->GetMetaContext().post( + boost::bind(callback_after_try_lock, Status::OK(), false, "")); + } + }); +} + +void EtcdMetaService::TryReleaseLock( + std::string key, callback_t callback_after_try_unlock) { + auto self(shared_from_base()); + + etcd_->unlock(prefix_ + key) + .then([self, callback_after_try_unlock]( + pplx::task const& resp_task) { + auto const& resp = resp_task.get(); + if (resp.is_ok()) { + LOG(INFO) << "unlock success!"; + self->server_ptr_->GetMetaContext().post( + boost::bind(callback_after_try_unlock, Status::OK(), true)); + } else { + LOG(INFO) << "unlock failed!"; + self->server_ptr_->GetMetaContext().post( + boost::bind(callback_after_try_unlock, Status::OK(), true)); + } + }); +} + void EtcdMetaService::requestLock( std::string lock_name, callback_t> callback_after_locked) { diff --git a/src/server/services/etcd_meta_service.h b/src/server/services/etcd_meta_service.h index 8fa1403d..4b4624b8 100644 --- a/src/server/services/etcd_meta_service.h +++ b/src/server/services/etcd_meta_service.h @@ -127,6 +127,10 @@ class EtcdMetaService : public IMetaService { ~EtcdMetaService() override {} + void TryAcquireLock(std::string key, callback_t); + + void TryReleaseLock(std::string key, 
callback_t); + protected: explicit EtcdMetaService(std::shared_ptr& server_ptr) : IMetaService(server_ptr), diff --git a/src/server/services/local_meta_service.h b/src/server/services/local_meta_service.h index be6c2fd1..50136a64 100644 --- a/src/server/services/local_meta_service.h +++ b/src/server/services/local_meta_service.h @@ -50,6 +50,18 @@ class LocalMetaService : public IMetaService { ~LocalMetaService() override {} + void TryAcquireLock(std::string key, + callback_t callback_after_try_locked) { + // TBD + assert(false); + } + + void TryReleaseLock(std::string key, + callback_t callback_after_try_unlocked) { + // TBD + assert(false); + } + protected: explicit LocalMetaService(std::shared_ptr& server_ptr) : IMetaService(server_ptr) {} diff --git a/src/server/services/meta_service.h b/src/server/services/meta_service.h index 294bcb35..419888e6 100644 --- a/src/server/services/meta_service.h +++ b/src/server/services/meta_service.h @@ -142,6 +142,11 @@ class IMetaService : public std::enable_shared_from_this { bool stopped() const { return this->stopped_.load(); } + virtual void TryAcquireLock(std::string key, + callback_t callback) = 0; + + virtual void TryReleaseLock(std::string key, callback_t callback) = 0; + private: void registerToEtcd(); diff --git a/src/server/services/redis_meta_service.h b/src/server/services/redis_meta_service.h index b3d3b65f..d1fa6749 100644 --- a/src/server/services/redis_meta_service.h +++ b/src/server/services/redis_meta_service.h @@ -177,6 +177,18 @@ class RedisMetaService : public IMetaService { inline void Stop() override; ~RedisMetaService() override {} + void TryAcquireLock(std::string key, + callback_t callback_after_try_locked) { + // TBD + assert(false); + } + + void TryReleaseLock(std::string key, + callback_t callback_after_try_unlocked) { + // TBD + assert(false); + } + protected: explicit RedisMetaService(std::shared_ptr& server_ptr) : IMetaService(server_ptr), diff --git a/test/distributed_lock_test.cc 
b/test/distributed_lock_test.cc new file mode 100644 index 00000000..ef81bf2a --- /dev/null +++ b/test/distributed_lock_test.cc @@ -0,0 +1,66 @@ +/** Copyright 2020-2023 Alibaba Group Holding Limited. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include +#include + +#include "client/client.h" +#include "common/util/logging.h" + +using namespace vineyard; + +int numThreads = 5; + +static int count = 0; + +void test(int i) { + std::string socket = std::string(getenv("VINEYARD_IPC_SOCKET")); + Client client; + client.Connect(socket); + + bool result; + std::string actural_key_of_lock; + + LOG(INFO) << "Thread: " << i << " try to acquire lock: test"; + client.TryAcquireLock("test", result, actural_key_of_lock); + LOG(INFO) << "Thread: " << i + << " acquire Lock: " << (result == true ? "success" : "fail") + << ", key is :" + actural_key_of_lock; + + if (result) { + count++; + LOG(INFO) << "count: " << count; + + sleep(3); + + LOG(INFO) << "Thread: " << i << " try to release lock: test"; + client.TryReleaseLock(actural_key_of_lock, result); + LOG(INFO) << "Thread: " << i + << " release Lock: " << (result == true ? 
"success" : "fail"); + } +} + +int main() { + std::thread threads[numThreads]; + for (int i = 0; i < numThreads; i++) { + threads[i] = std::thread(test, i); + } + + for (int i = 0; i < numThreads; i++) { + threads[i].join(); + } + + return 0; +} \ No newline at end of file diff --git a/test/kv_state_cache_object_test.cc b/test/kv_state_cache_object_test.cc new file mode 100644 index 00000000..60bdd3e0 --- /dev/null +++ b/test/kv_state_cache_object_test.cc @@ -0,0 +1,174 @@ +/** Copyright 2020-2023 Alibaba Group Holding Limited. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +#include + +#include "basic/ds/tensor.h" +#include "common/util/logging.h" +#include "kv-state-cache/ds/kv_state_cache.h" + +using namespace vineyard; + +std::vector tokens; +RadixTree* radix_tree; +std::vector> k_state_list; +std::vector> v_state_list; +std::vector> nodes_with_tree_attri_list; + +#define DIMENSION 10 +#define TOKEN_NUM 10 +#define CACHE_CAPACITY 10 + +void prepareData(KVStateCacheBuilder* kv_state_cache_builder) { + radix_tree = new RadixTree(10); + radix_tree->SetCustomData(kv_state_cache_builder, + sizeof(KVStateCacheBuilder)); + + for (int i = 0; i < TOKEN_NUM; i++) { + tokens.push_back(i); + } + + LOG(INFO) << "stage 1"; + for (int i = 0; i < TOKEN_NUM; i++) { + std::vector key_state; + for (int j = 0; j < DIMENSION; ++j) { + key_state.push_back(((double) (j)) * 0.1 + (double) i); + } + k_state_list.push_back(key_state); + } + + LOG(INFO) << "stage 2"; + for (int i = 0; i < TOKEN_NUM; i++) { + std::vector value_state; + for (int j = 0; j < DIMENSION; ++j) { + value_state.push_back(((double) (j)) * 0.1 + (double) i); + } + v_state_list.push_back(value_state); + } +} + +void updateTest(Client& client, KVStateCacheBuilder* builder) { + std::vector prefix; + + for (size_t i = 0; i < tokens.size(); ++i) { + KV_STATE_WITH_LAYER kv_state; + kv_state.insert( + std::make_pair(1, std::make_pair(k_state_list[i], v_state_list[i]))); + LOG(INFO) << "update test"; + builder->Update(client, prefix, tokens[i], kv_state); + prefix.push_back(tokens[i]); + } +} + +void queryTest(Client& client, KVStateCacheBuilder* builder) { + std::vector prefix; + KV_STATE_WITH_LAYER kv_state; + + for (int i = 0; i < TOKEN_NUM; i++) { + kv_state = builder->Query(client, prefix, tokens[i]); + std::vector key_state = kv_state[1].first; + std::vector value_state = kv_state[1].second; + + VINEYARD_ASSERT( + key_state.size() == (size_t) DIMENSION, + "Expected key_state.size() == " + std::to_string(DIMENSION) + + ", but got + key_state.size() == " + + 
std::to_string(key_state.size())); + VINEYARD_ASSERT( + value_state.size() == (size_t) DIMENSION, + "Expected value_state.size() == " + std::to_string(DIMENSION) + + ", but got + value_state.size() == " + + std::to_string(value_state.size())); + for (int j = 0; j < DIMENSION; ++j) { + VINEYARD_ASSERT(key_state[j] == k_state_list[i][j], + "Expected key_state[" + std::to_string(j) + + "] == " + std::to_string(k_state_list[i][j]) + + ", but got + key_state[" + std::to_string(j) + + "] == " + std::to_string(key_state[j])); + VINEYARD_ASSERT(value_state[j] == v_state_list[i][j], + "Expected value_state[" + std::to_string(j) + + "] == " + std::to_string(v_state_list[i][j]) + + ", but got + value_state[" + std::to_string(j) + + "] == " + std::to_string(value_state[j])); + } + prefix.push_back(tokens[i]); + } +} + +void sealAndConstructTest(Client& client, KVStateCacheBuilder* builder) { + ObjectID id = builder->_Seal(client)->id(); + std::shared_ptr kv_state_cache = + std::dynamic_pointer_cast(client.GetObject(id)); + std::shared_ptr kv_state_cache_block = + kv_state_cache->GetKVStateCacheBlock(); + std::shared_ptr kv_state_cache_block_builder = + builder->GetKVStateCacheBlockBuilder(); + + // compare kv_state_cache_block and kv_state_cache_block_builder + VINEYARD_ASSERT(kv_state_cache_block->GetDimension() == + kv_state_cache_block_builder->GetDimension()); + + VINEYARD_ASSERT(kv_state_cache_block->GetBitmap() == + kv_state_cache_block_builder->GetBitmap()); + + LOG(INFO) << "Bitmap:"; + LOG(INFO) << kv_state_cache_block_builder->GetBitmapStr(); + LOG(INFO) << kv_state_cache_block->GetBitmapStr(); + + const std::shared_ptr> k_tensor_builder = + kv_state_cache_block_builder->getKBuilder(); + const std::shared_ptr> v_tensor_builder = + kv_state_cache_block_builder->getVBuilder(); + + std::shared_ptr> k_tensor = + kv_state_cache_block->GetKTensor(); + std::shared_ptr> v_tensor = + kv_state_cache_block->GetVTensor(); + + for (int i = 0; i < TOKEN_NUM; i++) { + for (int j = 
0; j < DIMENSION; j++) { + VINEYARD_ASSERT(k_tensor->data()[i * DIMENSION + j] == + k_tensor_builder->data()[i * DIMENSION + j]); + VINEYARD_ASSERT(v_tensor->data()[i * DIMENSION + j] == + v_tensor_builder->data()[i * DIMENSION + j]); + } + } +} + +void splitTest(Client& client, KVStateCacheBuilder* builder) {} + +int main() { + std::string socket = std::string(getenv("VINEYARD_IPC_SOCKET")); + Client client; + client.Connect(socket); + + LOG(INFO) << "Build kv state cache"; + KVStateCacheBuilder* kv_state_cache_builder = + new KVStateCacheBuilder(client, DIMENSION, CACHE_CAPACITY); + + LOG(INFO) << "Prepare data"; + prepareData(kv_state_cache_builder); + + LOG(INFO) << "Test update"; + updateTest(client, kv_state_cache_builder); + + LOG(INFO) << "Test query"; + queryTest(client, kv_state_cache_builder); + + LOG(INFO) << "Test seal and construct"; + sealAndConstructTest(client, kv_state_cache_builder); + + return 0; +} \ No newline at end of file diff --git a/test/kv_state_cache_test.cc b/test/kv_state_cache_test.cc new file mode 100644 index 00000000..d2b6790a --- /dev/null +++ b/test/kv_state_cache_test.cc @@ -0,0 +1,110 @@ +/** Copyright 2020-2023 Alibaba Group Holding Limited. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +#include +#include +#include + +#include "common/util/logging.h" +#include "kv-state-cache/utils/kv_state_cache_utils.h" + +using namespace vineyard; + +#define DEMENSION 10 + +void init() { initKVStateCache(DEMENSION); } + +void print_current_tokens(const std::vector& prefix, int next_token) { + std::string tokens_str = ""; + for (size_t i = 0; i < prefix.size(); ++i) { + tokens_str += std::to_string(prefix[i]); + } + tokens_str += std::to_string(next_token); + LOG(INFO) << "Current tokens: " + tokens_str; + LOG(INFO) << tokens_str; +} + +void print_kv_state( + const std::map, std::vector>>& + kv_state) { + LOG(INFO) << "kv_state: "; + for (auto iter = kv_state.begin(); iter != kv_state.end(); ++iter) { + std::string key_state_str = ""; + std::string value_state_str = ""; + for (int i = 0; i < DEMENSION; ++i) { + key_state_str += std::to_string(iter->second.first[i]) + " "; + value_state_str += std::to_string(iter->second.second[i]) + " "; + } + LOG(INFO) << "key_state: " << key_state_str; + LOG(INFO) << "value_state: " << value_state_str; + } +} + +// we do not consider the layer. 
+std::map, std::vector>> +generate_kv_state(int token) { + std::vector key_state; + std::vector value_state; + for (int i = 0; i < DEMENSION; ++i) { + key_state.push_back(((double) token) / DEMENSION * (i + 1)); + value_state.push_back(((double) token) / DEMENSION * (i + 1) * 2); + } + + std::map, std::vector>> kv_state; + kv_state.insert(std::make_pair(1, std::make_pair(key_state, value_state))); + return kv_state; +} + +void inference(std::vector tokens, bool block = false) { + LOG(INFO) << "inference"; + std::vector inference_tokens; + std::map, std::vector>> kv_state; + + for (size_t i = 0; i < tokens.size(); ++i) { + kv_state = query(inference_tokens, tokens[i]); + if (kv_state.size() == 0) { + LOG(INFO) << "======================================"; + LOG(INFO) << "Can not find the kv_state from cache:"; + print_current_tokens(inference_tokens, tokens[i]); + LOG(INFO) << "Generate the kv_state and update the cache."; + kv_state = generate_kv_state(tokens[i]); + update(inference_tokens, tokens[i], kv_state); + print_kv_state(kv_state); + LOG(INFO) << "======================================"; + } else { + LOG(INFO) << "--------------------------------------"; + LOG(INFO) << "Find the kv_state from cache:"; + print_current_tokens(inference_tokens, tokens[i]); + print_kv_state(kv_state); + LOG(INFO) << "--------------------------------------"; + } + inference_tokens.push_back(tokens[i]); + } +} + +int main() { + init(); + std::vector round_1_tokens = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}; + // std::vector round_2_tokens = {1, 2, 3, 4, 5, 7, 8, 9, 10}; + inference(round_1_tokens); + inference(round_1_tokens); + sleep(5); + // inference(round_2_tokens); + // inference(round_2_tokens); + inference(round_1_tokens, true); + while (1) + ; + return 0; +} \ No newline at end of file From d2e6803c1a35208915693000e84a7c2e173c51c9 Mon Sep 17 00:00:00 2001 From: Ye Cao Date: Fri, 26 Jan 2024 14:37:47 +0800 Subject: [PATCH 02/20] Improve the split logic and implement the 
serialization and deserialization of radix subtree. (#1730) Signed-off-by: Ye Cao --- modules/kv-state-cache/ds/kv_state_cache.cc | 41 +- .../kv-state-cache/ds/kv_state_cache_block.h | 2 +- .../kv-state-cache/radix-tree/radix-tree.h | 371 +++++++++++++----- modules/kv-state-cache/radix-tree/radix.cc | 281 ++++++++++++- modules/kv-state-cache/radix-tree/radix.h | 202 +++++----- .../utils/kv_state_cache_utils.cc | 2 +- 6 files changed, 693 insertions(+), 206 deletions(-) diff --git a/modules/kv-state-cache/ds/kv_state_cache.cc b/modules/kv-state-cache/ds/kv_state_cache.cc index c4df7d00..af0e4522 100644 --- a/modules/kv-state-cache/ds/kv_state_cache.cc +++ b/modules/kv-state-cache/ds/kv_state_cache.cc @@ -20,6 +20,7 @@ limitations under the License. #include "common/util/logging.h" #include "common/util/status.h" #include "kv-state-cache/radix-tree/radix-tree.h" +#include "kv-state-cache/radix-tree/radix.h" #include "kv_state_cache.h" namespace vineyard { @@ -43,8 +44,11 @@ void KVStateCache::Resolve() { // 2. construct the radix tree this->root_tree = RadixTree::Deserialize( base64_decode(this->meta_.GetKeyValue("radix_tree"))); + LOG(INFO) << "Resolve RadixTree success" << std::endl; + raxShow(this->root_tree->GetTree()); // 3. 
construct the member field this->dimension = this->meta_.GetKeyValue("dimension"); + LOG(INFO) << "construct the member field success" << std::endl; } KVStateCache::~KVStateCache() { @@ -85,27 +89,39 @@ KVStateCacheBlockBuilder* KVStateCacheBuilder::Split( KVStateCacheBlockBuilder* child_kv_state_cache_block_builder = new KVStateCacheBlockBuilder(client, this->dimension); for (size_t i = 0; i < node_with_tree_attri_list.size(); i++) { + LOG(INFO) << "transfer node:" << i; + LOG(INFO) << "node:" << node_with_tree_attri_list[i]->get_node(); + LOG(INFO) << "data:" << node_with_tree_attri_list[i]->get_node()->get_data(); offset_data* data = (offset_data*) node_with_tree_attri_list[i]->get_node()->get_data(); + if (data == nullptr) + continue; int index = data->offset; + LOG(INFO) << "stage 0"; // Transfer the data from this builder to the child builder. const std::shared_ptr> k_builder = kv_state_cache_block_builder->getKBuilder(); const std::shared_ptr> v_builder = kv_state_cache_block_builder->getVBuilder(); + LOG(INFO) << "stage 0.5"; offset_data* new_offset_data = new offset_data(); child_kv_state_cache_block_builder->Update( k_builder->data() + index * this->dimension, v_builder->data() + index * this->dimension, - this->dimension * sizeof(double), new_offset_data); + this->dimension, new_offset_data); + LOG(INFO) << "stage 1"; node_with_tree_attri_list[i]->get_node()->set_data(new_offset_data, sizeof(offset_data)); // Clear the bitmap. 
+ LOG(INFO) << "stage 2, index:" << index; kv_state_cache_block_builder->DeleteKVCache(index); + LOG(INFO) << "bitmap:" << kv_state_cache_block_builder->GetBitmapStr(); } + LOG(INFO) << "stage 3"; kv_state_cache_block_builder->SetChildKVStateCacheBlockBuilder( child_kv_state_cache_block_builder); + LOG(INFO) << "stage 4"; return child_kv_state_cache_block_builder; } @@ -125,16 +141,23 @@ void KVStateCacheBuilder::Update(Client& client, LOG(INFO) << "insert failed"; return; } - std::shared_ptr sub_tree = node_with_tree_attri->get_tree(); + LOG(INFO) << "stage 1"; + std::shared_ptr tree = node_with_tree_attri->get_tree(); + LOG(INFO) << "stage 2"; KVStateCacheBlockBuilder* kv_state_cache_block_builder = - (KVStateCacheBlockBuilder*) sub_tree->GetCustomData(); + (KVStateCacheBlockBuilder*) tree->GetCustomData(); + LOG(INFO) << "stage 3"; if (evicted_node != nullptr) { + LOG(INFO) << "stage 4"; offset_data* data = (offset_data*) evicted_node->get_node()->get_data(); + LOG(INFO) << "stage 5"; KVStateCacheBlockBuilder* builder = (KVStateCacheBlockBuilder*) evicted_node->get_tree()->GetCustomData(); + LOG(INFO) << "stage 6"; builder->DeleteKVCache(data->offset); - delete (offset_data*) evicted_node->get_node()->get_data(); + if ((offset_data*) evicted_node->get_node()->get_data() != nullptr) + delete (offset_data*) evicted_node->get_node()->get_data(); } // TBD @@ -147,25 +170,33 @@ void KVStateCacheBuilder::Update(Client& client, * empty node from the radix tree and split the tree. Then, kv-state cache * split according to the new tree. 
*/ + LOG(INFO) << "triggle splits"; std::shared_ptr evicted_node = nullptr; this->root_tree->Delete(token_list_copy, evicted_node); - std::shared_ptr new_tree = sub_tree->Split(token_list_copy); + std::shared_ptr new_tree = tree->Split(token_list_copy); + LOG(INFO) << "tree split success"; std::vector> node_with_tree_attri_list = RadixTree::TraverseTreeWithoutSubTree(new_tree); KVStateCacheBlockBuilder* new_kv_state_cache_block_builder = Split(client, kv_state_cache_block_builder, node_with_tree_attri_list); new_tree->SetCustomData(new_kv_state_cache_block_builder, sizeof(KVStateCacheBlockBuilder)); + LOG(INFO) << "block split success"; // kv_state_cache_builder->UnLock(); Update(client, token_list, next_token, kv_state); } else { // Update the kv-state cache. + LOG(INFO) << "update kv-state cache"; offset_data* data = new offset_data(); + LOG(INFO) << "stage 7"; kv_state_cache_block_builder->Update(kv_state, data); + LOG(INFO) << "stage 8"; std::shared_ptr node = node_with_tree_attri->get_node(); + LOG(INFO) << "stage 9"; node->set_data(data, sizeof(offset_data)); + LOG(INFO) << "stage 10"; } LOG(INFO) << "bitmap:" << kv_state_cache_block_builder->GetBitmapStr(); diff --git a/modules/kv-state-cache/ds/kv_state_cache_block.h b/modules/kv-state-cache/ds/kv_state_cache_block.h index 222c2ba4..7758ee19 100644 --- a/modules/kv-state-cache/ds/kv_state_cache_block.h +++ b/modules/kv-state-cache/ds/kv_state_cache_block.h @@ -163,4 +163,4 @@ class KVStateCacheBlockBuilder : public ObjectBuilder { } // namespace vineyard -#endif \ No newline at end of file +#endif diff --git a/modules/kv-state-cache/radix-tree/radix-tree.h b/modules/kv-state-cache/radix-tree/radix-tree.h index ce46360e..f418cfb3 100644 --- a/modules/kv-state-cache/radix-tree/radix-tree.h +++ b/modules/kv-state-cache/radix-tree/radix-tree.h @@ -30,10 +30,15 @@ limitations under the License. 
using namespace vineyard; +typedef struct customData { + int data_length; + void* data; +} customData; + typedef struct nodeData { int data_length; void* data; - std::shared_ptr cache_node; + //std::shared_ptr cache_node; } nodeData; class Node { @@ -62,22 +67,22 @@ class Node { raxSetData(this->node, this->data); } - void set_cache_node(std::shared_ptr cache_node) { - if (this->node == NULL) { - LOG(INFO) << "set data failed, node is null"; - return; - } - this->data->cache_node = cache_node; - raxSetData(this->node, this->data); - } + //void set_cache_node(std::shared_ptr cache_node) { + // if (this->node == NULL) { + // LOG(INFO) << "set data failed, node is null"; + // return; + // } + // this->data->cache_node = cache_node; + // raxSetData(this->node, this->data); + //} void* get_data() { return this->data->data; } int get_data_length() { return this->data->data_length; } - std::shared_ptr get_cache_node() { - return this->data->cache_node; - } + //std::shared_ptr get_cache_node() { + // return this->data->cache_node; + //} }; class RadixTree; @@ -101,8 +106,6 @@ class NodeWithTreeAttri { class RadixTree : public std::enable_shared_from_this { private: - void* custom_data; - int custom_data_length; // the whole radix tree for prefix match rax* tree; // the sub tree for mapping a vineyard object @@ -113,24 +116,47 @@ class RadixTree : public std::enable_shared_from_this { RadixTree(int cache_capacity) { LOG(INFO) << "init radix tree"; this->tree = raxNew(); + if (!raxIsSubtreeAllocated(this->tree->head)) { + raxNode *new_root = raxReallocForSubtreeCustomData(this->tree->head); + this->tree->head = new_root; + raxSetSubtree(this->tree->head); + raxSetSubtreeAllocated(this->tree->head); + } + // this->sub_tree = this->tree; + lru_strategy = new LRUStrategy(cache_capacity); + } + + RadixTree(rax* rax_tree, int cache_capacity) { + LOG(INFO) << "init radix tree"; + this->tree = rax_tree; // this->sub_tree = this->tree; - this->custom_data = NULL; - 
this->custom_data_length = 0; lru_strategy = new LRUStrategy(cache_capacity); } RadixTree(void* custom_data, int custom_data_length, int cache_capacity) { LOG(INFO) << "init radix tree with custom data"; this->tree = raxNew(); + if (!raxIsSubtreeAllocated(this->tree->head)) { + raxNode *new_root = raxReallocForSubtreeCustomData(this->tree->head); + this->tree->head = new_root; + raxSetSubtree(this->tree->head); + raxSetSubtreeAllocated(this->tree->head); + } // this->sub_tree = this->tree; - this->custom_data = custom_data; - this->custom_data_length = custom_data_length; this->lru_strategy = new LRUStrategy(cache_capacity); } ~RadixTree() { - // TBD - // free all the node and the whole tree. + raxFreeWithCallback(this->tree, [](raxNode *n) { + if (n->iskey && !n->isnull) { + nodeData* nodedata = (nodeData*) raxGetData(n); + delete nodedata; + } + if (n->issubtree && n->iscustomallocated && !n->iscustomnull) { + customData* customdata = (customData*) raxGetCustomData(n); + delete customdata; + } + }); } std::shared_ptr Insert( @@ -146,32 +172,37 @@ class RadixTree : public std::enable_shared_from_this { this->tree, insert_tokens_array, insert_tokens_array_len, dummy_data, (void**) &dataNode, (void**) &old_data); if (dataNode == NULL) { - LOG(INFO) << "insert failed"; + throw std::runtime_error("Insert token list failed"); return NULL; } LOG(INFO) << "insert success"; - if (retval == 0) { + //if (retval == 0) { // (retval == 0 ) means the token vector already exists in the radix tree // remove the token vector from the lru cache as it will be inserted again - std::shared_ptr node = std::make_shared(old_data); - std::shared_ptr cache_node = node->get_cache_node(); - lru_strategy->Remove(cache_node); - delete old_data; - } + // std::shared_ptr node = std::make_shared(old_data); + // LOG(INFO) << "delete cache_node"; + // std::shared_ptr cache_node = node->get_cache_node(); + // LOG(INFO) << "delete cache_node 1"; + // lru_strategy->Remove(cache_node); + // LOG(INFO) 
<< "delete cache_node 2"; + // delete old_data; + //} + //LOG(INFO) << "delete cache_node success"; // refresh the lru cache - std::vector evicted_tokens; - std::shared_ptr cache_node = - lru_strategy->InsertToHeader(tokens, evicted_tokens); - if (cache_node == nullptr) { - LOG(INFO) << "WTF?"; - } - dummy_data->cache_node = cache_node; + //std::vector evicted_tokens; + //std::shared_ptr cache_node = + // lru_strategy->InsertToHeader(tokens, evicted_tokens); + //if (cache_node == nullptr) { + // LOG(INFO) << "WTF?"; + //} + //dummy_data->cache_node = cache_node; raxSetData(dataNode, dummy_data); - if (evicted_tokens.size() > 0) { - this->Delete(evicted_tokens, evicted_node); - } + //if (evicted_tokens.size() > 0) { + // this->Delete(evicted_tokens, evicted_node); + //} + //LOG(INFO) << "refresh cache_node success"; return std::make_shared(std::make_shared(dataNode), shared_from_this()); @@ -218,27 +249,44 @@ class RadixTree : public std::enable_shared_from_this { // refresh the lru cache std::shared_ptr node = std::make_shared(dataNode); - std::shared_ptr cache_node = node->get_cache_node(); - lru_strategy->MoveToHead(cache_node); + //std::shared_ptr cache_node = node->get_cache_node(); + //lru_strategy->MoveToHead(cache_node); return std::make_shared(node, shared_from_this()); } std::string Serialize() { + LOG(INFO) << "Serialize......"; std::vector> token_list; std::vector data_list; - raxSerialize(this->tree, token_list, data_list); + std::vector timestamp_list; + std::vector> sub_tree_token_list; + std::vector sub_tree_data_list; + raxSerialize(this->tree, token_list, data_list, timestamp_list, &sub_tree_token_list, + &sub_tree_data_list); - std::map, bool> cache_node_map; - std::shared_ptr current_node = - this->lru_strategy->GetHeader(); + raxShow(this->tree); + //std::map, bool> cache_node_map; + //std::shared_ptr current_node = + // this->lru_strategy->GetHeader(); // the string format is: - // [token list] [data hex string]\n + // [token 
list]|[timestamp]|[data hex string]\n + // ... + // [token list]|[timestamp]|[data hex string]\n + // \t\n + // [subtree token list]|[timestamp]|[custom data string]\n + // ... + // [subtree token list]|[timestamp]|[custom data string]\n // E.g // tokens | data - // 1,2|0800000008000000xxxx + // 1|0000000001|0800000008000000xxxx\n + // 1,2|0000000002|0800000008000000xxxx\n + // 1,2,3|0000000002|0800000008000000xxxx\n + // \t\n + // 1,2|0000000003|0800000008000000xxxx\n std::string serialized_str; + /* while (current_node != nullptr) { cache_node_map[current_node] = true; auto it = std::lower_bound(token_list.begin(), token_list.end(), @@ -269,6 +317,56 @@ class RadixTree : public std::enable_shared_from_this { } current_node = current_node->next; } + */ + + if (token_list.size() != data_list.size()) { + throw std::runtime_error("The size of token list and data list is not equal"); + } + for (size_t index = 0; index < token_list.size(); index++) { + for (size_t j = 0; j < token_list[index].size(); j++) { + serialized_str += std::to_string(token_list[index][j]); + if (j < token_list[index].size() - 1) { + serialized_str += ","; + } + } + serialized_str += "|"; + + // convert timestamp(uint64) to hex string + uint64_t timestamp = timestamp_list[index]; + std::ostringstream timestamp_oss; + timestamp_oss << std::hex << timestamp; + + serialized_str += timestamp_oss.str() + "|"; + + // convert data to hex string + char* bytes = (char*) ((nodeData*) data_list[index])->data; + std::ostringstream data_oss; + + for (size_t i = 0; i < ((nodeData*)data_list[index])->data_length; i++) { + data_oss << std::hex << std::setw(2) << std::setfill('0') << static_cast(static_cast(bytes[i])); + } + serialized_str += data_oss.str() + "\n"; + } + + serialized_str += "\t\n"; + + for (size_t index = 0; index < sub_tree_token_list.size(); index++) { + for (size_t j = 0; j < sub_tree_token_list[index].size(); j++) { + serialized_str += std::to_string(sub_tree_token_list[index][j]); + if (j < 
sub_tree_token_list[index].size() - 1) { + serialized_str += ","; + } + } + serialized_str += "|"; + // convert custom data to hex string + char* bytes = (char*) ((customData*) sub_tree_data_list[index])->data; + std::ostringstream data_oss; + + for (size_t i = 0; i < ((customData*)sub_tree_data_list[index])->data_length; ++i) { + data_oss << std::hex << std::setw(2) << std::setfill('0') << static_cast(static_cast(bytes[i])); + } + serialized_str += data_oss.str() + "\n"; + } // use LZ4 to compress the serialized string const char* const src = serialized_str.c_str(); @@ -306,6 +404,7 @@ class RadixTree : public std::enable_shared_from_this { } static std::shared_ptr Deserialize(std::string data) { + LOG(INFO) << "Deserialize......"; // use LZ4 to decompress the serialized string int src_size = *(int*) data.c_str(); data.erase(0, sizeof(int)); @@ -332,17 +431,34 @@ class RadixTree : public std::enable_shared_from_this { std::vector> token_list; std::vector data_list; - std::vector data_size_list; + std::vector data_size_list; + std::vector timestamp_list; + std::vector> sub_tree_token_list; + std::vector sub_tree_data_list; + std::vector sub_tree_data_size_list; std::istringstream iss(data); std::string line; + bool isMainTree = true; while (std::getline(iss, line)) { + if (!line.empty() && line[0] == '\t') { + isMainTree = false; + line.pop_back(); + continue; + } + LOG(INFO) << "data line:" << line << std::endl; std::istringstream lineStream(line); - std::string tokenListPart, dataPart; + std::string tokenListPart, timestampPart, dataPart; if (!std::getline(lineStream, tokenListPart, '|')) { throw std::runtime_error( - "Invalid serialized string format in key part."); + "Invalid serialized string format in token list part."); + } + if (isMainTree) { + if (!std::getline(lineStream, timestampPart, '|')) { + throw std::runtime_error( + "Invalid serialized string format in timestamp part."); + } } if (!std::getline(lineStream, dataPart)) { throw std::runtime_error( @@ 
-356,42 +472,55 @@ class RadixTree : public std::enable_shared_from_this { keys.push_back(std::stoi(token)); } - // size_t dataSize = dataPart.length() / 2; - size_t dataSize = dataPart.length(); - data_size_list.push_back(dataSize); + uint64_t timestamp; + if (isMainTree) { + std::istringstream timestampStream(timestampPart); + if (!(timestampStream >> std::hex >> timestamp)) { + LOG(INFO) << "Invalid timestamp format."; + throw std::runtime_error("Invalid timestamp format."); + } + } + + size_t dataSize = dataPart.length() / 2; // Each byte is represented by two hex characters + if (isMainTree) { + data_size_list.push_back(dataSize); + } else { + sub_tree_data_size_list.push_back(dataSize); + } // This pointer will be freed by upper layer. Because this data // is created by upper layer. Here just recover it from serialized // string. char* data = new char[dataSize]; std::istringstream dataStream(dataPart); - // for (size_t i = 0; i < dataSize; ++i) { - // // Temporary buffer to store two hexadecimal chars + null - // terminator char hex[3] = {}; - // // Read two characters for one byte - // if (!dataStream.read(hex, 2)) { - // delete[] data; - // LOG(INFO) << "Invalid data format."; - // throw std::runtime_error("Invalid data format."); - // } - // // Convert the two hex characters to one byte - // unsigned int byte; - // std::istringstream hexStream(hex); - // if (!(hexStream >> std::hex >> byte)) { - // delete[] data; - // LOG(INFO) << "Invalid data format."; - // throw std::runtime_error("Invalid data format."); - // } - // reinterpret_cast(data)[i] = static_cast(byte); - // } - if (!dataStream.read(data, dataSize)) { - delete[] data; - LOG(INFO) << "Invalid data."; + for (size_t i = 0; i < dataSize; ++i) { + // Temporary buffer to store two hexadecimal chars + null + char hex[3] = {}; + // Read two characters for one byte + if (!dataStream.read(hex, 2)) { + delete[] data; + LOG(INFO) << "Invalid data format."; + throw std::runtime_error("Invalid data 
format."); + } + // Convert the two hex characters to one byte + unsigned int byte; + std::istringstream hexStream(hex); + if (!(hexStream >> std::hex >> byte)) { + delete[] data; + LOG(INFO) << "Invalid data format."; + throw std::runtime_error("Invalid data format."); + } + reinterpret_cast(data)[i] = static_cast(byte); + } + if (isMainTree) { + token_list.push_back(keys); + timestamp_list.push_back(timestamp); + data_list.push_back(data); + } else { + sub_tree_token_list.push_back(keys); + sub_tree_data_list.push_back(data); } - - token_list.push_back(keys); - data_list.push_back(data); } // This pointer will be freed by upper layer. Because this data @@ -419,12 +548,58 @@ class RadixTree : public std::enable_shared_from_this { // } // dummy_data->cache_node = cache_node; // } - for (int i = token_list.size() - 1; i >= 0; i--) { - std::shared_ptr evicted_node; - std::shared_ptr node = - radix_tree->Insert(token_list[i], evicted_node); + + + for (int i = 0; i < token_list.size(); i++) { + int* insert_tokens_array = token_list[i].data(); + size_t insert_tokens_array_len = token_list[i].size(); + nodeData* data = new nodeData(); + raxNode* dataNode = NULL; + int retval = raxInsertAndReturnDataNode( + radix_tree->tree, insert_tokens_array, insert_tokens_array_len, data, + (void**) &dataNode, NULL); + if (dataNode == NULL) { + throw std::runtime_error("Insert token list failed"); + } + dataNode->timestamp = timestamp_list[i]; + std::shared_ptr node = std::make_shared(std::make_shared(dataNode), + radix_tree); node->get_node()->set_data(data_list[i], data_size_list[i]); } + LOG(INFO) << "start to insert sub tree token list" << std::endl; + for (int i = 0; i < sub_tree_token_list.size(); i++) { + for (int j = 0; j < sub_tree_token_list[i].size(); j++) { + LOG(INFO) << sub_tree_token_list[i][j]; + } + raxNode *parentlink = NULL; + raxNode *node = NULL; + raxNode *newNode = NULL; + raxFindNodeWithParent(radix_tree->tree, sub_tree_token_list[i].data(), + 
sub_tree_token_list[i].size(), (void **)&node,(void **)&parentlink); + if (node == NULL) { + throw std::runtime_error("Unable to find the root node of the sub tree"); + return NULL; + } + if (parentlink == NULL) { + parentlink = radix_tree->tree->head; + } + if (!raxIsSubtree(node) || !raxIsSubtreeAllocated(node)) { + newNode = raxReallocForSubtreeCustomData(node); + raxSetSubtree(node); + raxSetSubtreeAllocated(node); + } else { + newNode = node; + } + memcpy(parentlink,&newNode,sizeof(newNode)); + customData* data = new customData(); + data->data = sub_tree_data_list[i]; + data->data_length = sub_tree_data_size_list[i]; + if (raxIsSubtreeAllocated(newNode)) { + raxSetCustomData(newNode, data); + raxSetSubtreeNotNull(newNode); + } + } + LOG(INFO) << "Deserialize success"; return radix_tree; } @@ -435,11 +610,10 @@ class RadixTree : public std::enable_shared_from_this { // TBD // if the sub_tree is null, delete this pointer. - std::shared_ptr sub_tree = - std::make_shared(this->lru_strategy->GetCapacity()); - sub_tree->tree = this->tree; rax* sub_rax = raxNew(); sub_rax->head = sub_tree_root_node; + std::shared_ptr sub_tree = + std::make_shared(sub_rax, this->lru_strategy->GetCapacity()); return sub_tree; } @@ -452,21 +626,40 @@ class RadixTree : public std::enable_shared_from_this { return nodes; } - std::vector> dataNodeList; + std::vector dataNodeList; raxNode* headNode = radix_tree->tree->head; raxTraverseSubTree(headNode, dataNodeList); for (size_t i = 0; i < dataNodeList.size(); i++) { nodes.push_back(std::make_shared( - std::make_shared(dataNodeList[i].get()), radix_tree)); + std::make_shared(dataNodeList[i]), radix_tree)); } return nodes; } - void* GetCustomData() { return custom_data; } + rax* GetTree() {return this->tree;} + void* GetCustomData() { + if (!raxIsSubtreeAllocated(this->tree->head) || raxIsSubtreeCustomDataNull(this->tree->head)) { + throw std::runtime_error("Subtree is not allocated or custom data is null"); + return NULL; + } + customData 
*custome_data = (customData *)raxGetCustomData(this->tree->head); + if (custome_data == NULL) { + throw std::runtime_error("Custom data is null"); + return NULL; + } + return (void *)custome_data->data; + } void SetCustomData(void* custom_data, int custom_data_length) { - this->custom_data = custom_data; - this->custom_data_length = custom_data_length; + customData* data = new customData(); + data->data = custom_data; + data->data_length = custom_data_length; + if (raxIsSubtreeAllocated(this->tree->head)) { + raxSetCustomData(this->tree->head, data); + raxSetSubtreeNotNull(this->tree->head); + return; + } + throw std::runtime_error("The custome data of subtree is not allocated"); } }; diff --git a/modules/kv-state-cache/radix-tree/radix.cc b/modules/kv-state-cache/radix-tree/radix.cc index 2a6eaf5f..d3366454 100644 --- a/modules/kv-state-cache/radix-tree/radix.cc +++ b/modules/kv-state-cache/radix-tree/radix.cc @@ -44,6 +44,15 @@ #include RAX_MALLOC_INCLUDE +#include +#include "common/util/logging.h" +using namespace vineyard; +typedef struct nodeData1 { + int data_length; + void* data; + void* cache_node; +} nodeData1; + /* This is a special pointer that is guaranteed to never have the same value * of a radix tree node. It's used in order to report "not found" error without * requiring the function to have multiple return values. 
*/ @@ -202,6 +211,10 @@ raxNode *raxNewNode(size_t children, int datafield) { node->iskey = 0; node->isnull = 0; node->iscompr = 0; + node->issubtree = 0; + node->iscustomnull = 1; + node->iscustomallocated = 0; + node->timestamp = 0; node->numnodes = 1; node->size = children; return node; @@ -253,6 +266,45 @@ void *raxGetData(raxNode *n) { return data; } +/* +* Reallocate the node to make room for custom data +*/ +raxNode *raxReallocForSubtreeCustomData(raxNode *n) { + size_t curlen = raxNodeCurrentLength(n); + raxNode *newNode = (raxNode *)rax_realloc(n,curlen+sizeof(void*)); + if (newNode == NULL) { + printf("can't realloc new memory\n"); + return NULL; + } + newNode->iscustomallocated = 1; + return newNode; +} + +/* +* Set the custom data for the root of sub-tree +*/ +void raxSetCustomData(raxNode *n, void *data) { + // wait for the custom data to be allocated + if (n->iscustomallocated==0) { + return; + } + void **ndata = (void**) + ((char*)n+raxNodeCurrentLength(n)); + memcpy(ndata,&data,sizeof(data)); + n->iscustomnull = 0; +} + +/* +* Get the custom data for the root of sub-tree +*/ +void *raxGetCustomData(raxNode *n) { + if (n->iscustomallocated==0 || n->iscustomnull==1) return NULL; + void **ndata =(void**)((char*)n+raxNodeCurrentLength(n)); + void *data; + memcpy(&data,ndata,sizeof(data)); + return data; +} + /* Add a new child to the node 'n' representing the token 'c' and return * its new pointer, as well as the child pointer by reference. Additionally * '***parentlink' is populated with the raxNode pointer-to-pointer of where @@ -271,6 +323,14 @@ raxNode *raxAddChild(raxNode *n, int c, raxNode **childptr, raxNode ***parentlin n->size--; /* For now restore the orignal size. We'll update it only on success at the end. */ + // store the extra data pointer of subtree + void *customData; + bool isSubtree = false; + if (n->issubtree) { + isSubtree = true; + customData = raxGetCustomData(n); + } + /* Alloc the new child we will link to 'n'. 
*/ raxNode *child = raxNewNode(0,0); if (child == NULL) return NULL; @@ -389,6 +449,18 @@ raxNode *raxAddChild(raxNode *n, int c, raxNode **childptr, raxNode ***parentlin n->data[pos] = c; n->numnodes = parent_numnodes + 1; n->size++; + n->issubtree = 0; + n->iscustomnull = 1; + n->iscustomallocated = 0; + if (isSubtree) { + n->issubtree = 1; + size_t curlen = raxNodeCurrentLength(n); + raxNode *newNode = (raxNode *)rax_realloc(n,curlen+sizeof(void*)); + n = newNode; + n->iscustomnull = 1; + n->iscustomallocated = 1; + raxSetCustomData(n, customData); + } src = (char*) raxNodeFirstChildPtr(n); raxNode **childfield = (raxNode**)(src+sizeof(raxNode*)*pos); memcpy(childfield,&child,sizeof(child)); @@ -567,6 +639,7 @@ int raxGenericInsert(rax *rax, int *s, size_t len, void *data, void **old, int o debugf("### Insert: node representing key exists\n"); /* Make space for the value pointer if needed. */ if (!h->iskey || (h->isnull && overwrite)) { + printf("#############raxReallocForData1 ############\n"); h = raxReallocForData(h,data); if (h) memcpy(parentlink,&h,sizeof(h)); } @@ -961,11 +1034,24 @@ int raxGenericInsert(rax *rax, int *s, size_t len, void *data, void **old, int o raxStackAddNumNodes(&splitStack, insert_new_node); raxStackFree(&lowWalkStack); raxStackFree(&splitStack); + void *customData; + bool isSubtree = false; + if (h->issubtree && h->iscustomallocated && !h->iscustomnull) { + isSubtree = true; + customData = raxGetCustomData(h); + } raxNode *newh = raxReallocForData(h,data); + printf("#############raxReallocForData2 ############\n"); if (newh == NULL) { return handleOutOfMemory(rax, h, (int *)s, i, old); } h = newh; + if (isSubtree) { + raxNode *newNode = raxReallocForSubtreeCustomData(h); + newNode->issubtree = 1; + raxSetCustomData(newNode, customData); + h = newNode; + } if (!h->iskey) rax->numele++; raxSetData(h,data); memcpy(parentlink,&h,sizeof(h)); @@ -1009,7 +1095,8 @@ void *raxFind(rax *rax, int *s, size_t len) { return raxGetData(h); } -/* 
Find a key in the rax, returns the stack +/* +** Find a key in the rax, returns the stack */ raxStack raxFindWithStack(rax *rax, int *s, size_t len) { raxNode *h; @@ -1025,7 +1112,7 @@ raxStack raxFindWithStack(rax *rax, int *s, size_t len) { } /* -Find a key in the rax, returns the raxNode that contains the key. +** Find a key in the rax, returns the raxNode that contains the key. */ raxNode *raxFindAndReturnDataNode(rax *rax, int *s, size_t len) { raxNode *h; @@ -1038,6 +1125,22 @@ raxNode *raxFindAndReturnDataNode(rax *rax, int *s, size_t len) { return h; } +/* +** Find a key in the rax, returns the current node and its parent link. +*/ +int raxFindNodeWithParent(rax *rax, int *s, size_t len, void **node, void **parent) { + raxNode *h; + + int splitpos = 0; + raxNode **parentlink; + size_t i = raxLowWalk(rax,s,len,&h,&parentlink,&splitpos,NULL); + if (i != len || (h->iscompr && splitpos != 0) || !h->iskey) + return 0; + *parent = (void *)parentlink; + *node = (void *)h; + return 1; +} + /* Return the memory address where the 'parent' node stores the specified * 'child' pointer, so that the caller can update the pointer with another * one if needed. The function assumes it will find a match, otherwise the @@ -1074,6 +1177,19 @@ raxNode *raxRemoveChild(raxNode *parent, raxNode *child) { return parent; } + /* + * + * 0. Before remove the child, we need to store the custom + * data if the current node is the root node of subtree + * + */ + void *customData; + bool isSubtree = false; + if (parent->issubtree) { + isSubtree = true; + customData = raxGetCustomData(parent); + } + /* Otherwise we need to scan for the child pointer and memmove() * accordingly. * @@ -1119,7 +1235,19 @@ raxNode *raxRemoveChild(raxNode *parent, raxNode *child) { /* realloc the node according to the theoretical memory usage, to free * data if we are over-allocating right now. 
*/ - raxNode *newnode = (raxNode *)rax_realloc(parent,raxNodeCurrentLength(parent)); + raxNode *newnode; + if (isSubtree) { + newnode = (raxNode *)rax_realloc(parent,raxNodeCurrentLength(parent)+sizeof(void*)); + newnode->iscustomnull = 1; + newnode->iscustomallocated = 1; + newnode->issubtree = 1; + raxSetCustomData(newnode, customData); + } else { + newnode = (raxNode *)rax_realloc(parent,raxNodeCurrentLength(parent)); + newnode->iscustomnull = 1; + newnode->iscustomallocated = 0; + newnode->issubtree = 0; + } if (newnode) { debugnode("raxRemoveChild after", newnode); } @@ -1343,7 +1471,7 @@ int raxRemove(rax *rax, int *s, size_t len, void **old) { /* This is the core of raxFree(): performs a depth-first scan of the * tree and releases all the nodes found. */ -void raxRecursiveFree(rax *rax, raxNode *n, void (*free_callback)(void*)) { +void raxRecursiveFree(rax *rax, raxNode *n, void (*free_callback)(raxNode *)) { debugnode("free traversing",n); int numchildren = n->iscompr ? 1 : n->size; raxNode **cp = raxNodeLastChildPtr(n); @@ -1354,15 +1482,17 @@ void raxRecursiveFree(rax *rax, raxNode *n, void (*free_callback)(void*)) { cp--; } debugnode("free depth-first",n); - if (free_callback && n->iskey && !n->isnull) - free_callback(raxGetData(n)); + // if n is a key node, we need to free the data + // if n is a subtree, we need to free the custom data + if (free_callback && ((n->iskey && !n->isnull) || (n->issubtree && n->iscustomallocated && !n->iscustomnull))) + free_callback(n); rax_free(n); rax->numnodes--; } /* Free a whole radix tree, calling the specified callback in order to * free the auxiliary data. 
*/ -void raxFreeWithCallback(rax *rax, void (*free_callback)(void*)) { +void raxFreeWithCallback(rax *rax, void (*free_callback)(raxNode *)) { raxRecursiveFree(rax,rax->head,free_callback); assert(rax->numnodes == 0); rax_free(rax); @@ -1383,6 +1513,9 @@ void raxStart(raxIterator *it, rax *rt) { it->rt = rt; it->key_len = 0; it->key = it->key_static_tokens; + it->add_to_subtree_list = false; + it->subtree_list = NULL; + it->subtree_data_list = NULL; it->key_max = RAX_ITER_STATIC_LEN; it->data = NULL; it->node_cb = NULL; @@ -1461,6 +1594,20 @@ int raxIteratorNextStep(raxIterator *it, int noup) { raxNode **cp = raxNodeFirstChildPtr(it->node); if (!raxIteratorAddToken(it,it->node->data, it->node->iscompr ? it->node->size : 1)) return 0; + if (it->node->issubtree && it->add_to_subtree_list && it->subtree_list != NULL && + it->subtree_data_list != NULL) { + std::cout << "first find subtree list is:" << std::endl; + std::vector token; + for (int i = 0; i < it->key_len; i++) { + token.push_back(it->key[i]); + } + (*it->subtree_list).push_back(token); + void *data = raxGetCustomData(it->node); + if (data == NULL) { + throw std::runtime_error("custom data is null"); + } + (*it->subtree_data_list).push_back(data); + } memcpy(&it->node,cp,sizeof(it->node)); /* Call the node callback if any, and replace the node pointer * if the callback returns true. 
*/ @@ -1517,6 +1664,20 @@ int raxIteratorNextStep(raxIterator *it, int noup) { debugf("SCAN found a new node\n"); raxIteratorAddToken(it,it->node->data+i,1); if (!raxStackPush(&it->stack,it->node)) return 0; + if (it->node->issubtree && it->add_to_subtree_list && it->subtree_list != NULL && + it->subtree_data_list != NULL) { + std::cout << "second find subtree list is:" << std::endl; + std::vector token; + for (int i = 0; i < it->key_len; i++) { + token.push_back(it->key[i]); + } + (*it->subtree_list).push_back(token); + void *data = raxGetCustomData(it->node); + if (data == NULL) { + throw std::runtime_error("custom data is null"); + } + (*it->subtree_data_list).push_back(data); + } memcpy(&it->node,cp,sizeof(it->node)); /* Call the node callback if any, and replace the node * pointer if the callback returns true. */ @@ -1987,6 +2148,10 @@ void raxRecursiveShow(int level, int lpad, raxNode *n) { } numchars += printf("%c %d ", e, n->numnodes); + if (n->issubtree) { + numchars += printf("# "); + printf(" %p ", n); + } if (n->iskey) { numchars += printf("=%p",raxGetData(n)); } @@ -2103,6 +2268,57 @@ void raxTraverse(raxNode *n, std::vector> &dataNodeList } } +/* +* Set a node as a subtree root node +*/ +void raxSetSubtree(raxNode *node) { + node->issubtree = 1; +} + +/* +* Set the subtree root node as allocated custom data +*/ +void raxSetSubtreeAllocated(raxNode *node) { + node->iscustomallocated = 1; +} + +/* +* Set the subtree root node as null custom data +*/ +void raxSetSubtreeNotNull(raxNode *node) { + node->iscustomnull = 0; +} + +/* +* Check if a node is a subtree root node +*/ +bool raxIsSubtree(raxNode *node) { + if (node->issubtree) { + return true; + } + return false; +} + +/* +* Check if the subtree has been allocated custom data +*/ +bool raxIsSubtreeAllocated(raxNode *node) { + if (node->iscustomallocated) { + return true; + } + return false; +} + +/* +* Check if the custom data of the subtree is null +*/ +bool raxIsSubtreeCustomDataNull(raxNode *node) 
{ + if (node->iscustomnull) { + return true; + } + return false; +} + /* * Split the tree into two sub trees, and return the root node of the new sub tree * @@ -2112,18 +2328,28 @@ void raxTraverse(raxNode *n, std::vector> &dataNodeList * tree from the root node. * */ -raxNode *raxSplit(rax *rax, int *s, size_t len, void *data){ - int retval = raxInsert(rax, s, len, data, NULL); - if (retval == 0 && errno != 0) { - return NULL; - } +raxNode *raxSplit(rax *rax, int *s, size_t len, void *data) { raxNode *childNode = NULL; raxNode *splitNode = NULL; raxStack stack = raxFindWithStack(rax, s, len); int items = stack.items; + int subtreeNumNodes = 0; + // find the latest subtree root node + int index = stack.items - 1; + while (index >= 0) { + raxNode *node = (raxNode *)stack.stack[index]; + if (node->issubtree) { + subtreeNumNodes = node->numnodes; + break; + } + index--; + } + + + // find the node that has N/2 children while (items > 0) { raxNode *node = (raxNode *)raxStackPop(&stack); - if (node->numnodes >= (uint32_t)RAX_NODE_MAX_SIZE/2 || node->issubtree) { + if (node->numnodes >= (uint32_t)subtreeNumNodes/2 || node->issubtree) { splitNode = childNode; raxStackPush(&stack, node); break; @@ -2136,21 +2362,31 @@ raxNode *raxSplit(rax *rax, int *s, size_t len, void *data){ return rax->head; } - raxStackAddNumNodes(&stack, -(int)(splitNode->numnodes)); - raxStackFree(&stack); + raxNode *parent = (raxNode *)raxStackPeek(&stack); + raxNode **parentlink; + if (parent == NULL) { + parentlink = &rax->head; + } else { + parentlink = raxFindParentLink(parent,splitNode); + } + raxNode *newNode = raxReallocForSubtreeCustomData(splitNode); + raxSetSubtree(newNode); + memcpy(parentlink,&newNode,sizeof(newNode)); - splitNode->issubtree = 1; + raxStackAddNumNodes(&stack, -(int)(newNode->numnodes)); + raxStackFree(&stack); - return splitNode; + return newNode; } /* * Traverse the subtree and return all the nodes that contain data under the subtree * these nodes are stored in the 
dataNodeList */ -void raxTraverseSubTree(raxNode *n, std::vector> &dataNodeList) { + +void raxTraverseSubTree(raxNode *n, std::vector &dataNodeList) { if (n->iskey) { - dataNodeList.push_back(std::shared_ptr(n, [](raxNode*){})); + dataNodeList.push_back(n); } int numchildren = n->iscompr ? 1 : n->size; @@ -2165,9 +2401,13 @@ void raxTraverseSubTree(raxNode *n, std::vector> &dataN } } -void raxSerialize(rax *root, std::vector> &tokenList, std::vector &dataList) { +void raxSerialize(rax *root, std::vector> &tokenList, std::vector &dataList, std::vector ×tampList, + std::vector> *subtreeList, std::vector *subtreeDataList) { raxIterator iter; raxStart(&iter, root); + iter.add_to_subtree_list = 1; + iter.subtree_list = subtreeList; + iter.subtree_data_list = subtreeDataList; raxSeek(&iter, "^", NULL, 0); while (raxNext(&iter)) { std::vector token; @@ -2176,6 +2416,7 @@ void raxSerialize(rax *root, std::vector> &tokenList, std::vect } tokenList.push_back(token); dataList.push_back(iter.data); + timestampList.push_back(iter.node->timestamp); } raxStop(&iter); } diff --git a/modules/kv-state-cache/radix-tree/radix.h b/modules/kv-state-cache/radix-tree/radix.h index 81682912..790ca134 100644 --- a/modules/kv-state-cache/radix-tree/radix.h +++ b/modules/kv-state-cache/radix-tree/radix.h @@ -31,14 +31,14 @@ #ifndef RADIX_H #define RADIX_H -#include #include +#include #include #include /* Representation of a radix tree as implemented in this file, that contains - * the token lists [1, 2, 3], [1, 2, 3, 4, 5, 6] and [1, 2, 3, 6, 7, 8] after - * the insertion of each token list. When the node represents a key inside + * the token lists [1, 2, 3], [1, 2, 3, 4, 5, 6] and [1, 2, 3, 6, 7, 8] after + * the insertion of each token list. When the node represents a key inside * the radix tree, we write it between [], otherwise it is written between (). 
* * This is the vanilla representation: @@ -88,7 +88,7 @@ * [1,2,3,4] ([5,6]) ([7,8]) [1,2,3,6] * / \ * [1,2,3,4,5,6] [] [] [1,2,3,6,7,8] - * + * * * Similarly after deletion, if a new chain of nodes having a single child * is created (the chain must also not include nodes that represent keys), @@ -98,46 +98,50 @@ #define RAX_NODE_MAX_SIZE 1024 typedef struct raxNode { - uint32_t iskey:1; /* Does this node contain a key? */ - uint32_t isnull:1; /* Associated value is NULL (don't store it). */ - uint32_t iscompr:1; /* Node is compressed. */ - uint32_t issubtree:1; /* Node is the root node of a sub tree */ - uint32_t size:28; /* Number of children, or compressed string len. */ - uint32_t numnodes; /* Number of the child nodes */ - /* Data layout is as follows: - * - * If node is not compressed we have 'size' bytes, one for each children - * token, and 'size' raxNode pointers, point to each child node. - * Note how the character is not stored in the children but in the - * edge of the parents: - * - * [header iscompr=0][1,2,3][1-ptr][2-ptr][3-ptr](value-ptr?) - * - * if node is compressed (iscompr bit is 1) the node has 1 children. - * In that case the 'size' bytes of the string stored immediately at - * the start of the data section, represent a sequence of successive - * nodes linked one after the other, for which only the last one in - * the sequence is actually represented as a node, and pointed to by - * the current compressed node. - * - * [header iscompr=1][1,2,3][3-ptr](value-ptr?) - * - * Both compressed and not compressed nodes can represent a key - * with associated data in the radix tree at any level (not just terminal - * nodes). - * - * If the node has an associated key (iskey=1) and is not NULL - * (isnull=0), then after the raxNode pointers poiting to the - * children, an additional value pointer is present (as you can see - * in the representation above as "value-ptr" field). - */ - int data[]; + uint32_t iskey : 1; /* Does this node contain a key? 
*/ + uint32_t isnull : 1; /* Associated value is NULL (don't store it). */ + uint32_t iscustomnull : 1; /* Associated custome value is NULL */ + uint32_t iscustomallocated : 1; /* The memory to store custom value is allocated */ + uint32_t iscompr : 1; /* Node is compressed. */ + uint32_t issubtree : 1; /* Node is the root node of a sub tree */ + uint32_t size : 26; /* Number of children, or compressed string len. */ + uint32_t numnodes; /* Number of the child nodes */ + uint32_t numele; /* Number of elements inside this node. */ + uint64_t timestamp; /* Timestamps of the node */ + /* Data layout is as follows: + * + * If node is not compressed we have 'size' bytes, one for each children + * token, and 'size' raxNode pointers, point to each child node. + * Note how the character is not stored in the children but in the + * edge of the parents: + * + * [header iscompr=0][1,2,3][1-ptr][2-ptr][3-ptr](value-ptr?) + * + * if node is compressed (iscompr bit is 1) the node has 1 children. + * In that case the 'size' bytes of the string stored immediately at + * the start of the data section, represent a sequence of successive + * nodes linked one after the other, for which only the last one in + * the sequence is actually represented as a node, and pointed to by + * the current compressed node. + * + * [header iscompr=1][1,2,3][3-ptr](value-ptr?) + * + * Both compressed and not compressed nodes can represent a key + * with associated data in the radix tree at any level (not just terminal + * nodes). + * + * If the node has an associated key (iskey=1) and is not NULL + * (isnull=0), then after the raxNode pointers poiting to the + * children, an additional value pointer is present (as you can see + * in the representation above as "value-ptr" field). 
+ */ + int data[]; } raxNode; typedef struct rax { - raxNode *head; - uint64_t numele; - uint64_t numnodes; + raxNode* head; + uint64_t numele; + uint64_t numnodes; } rax; /* Stack data structure used by raxLowWalk() in order to, optionally, return @@ -145,12 +149,12 @@ typedef struct rax { * field for space concerns, so we use the auxiliary stack when needed. */ #define RAX_STACK_STATIC_ITEMS 32 typedef struct raxStack { - void **stack; /* Points to static_items or an heap allocated array. */ - size_t items, maxitems; /* Number of items contained and total space. */ - /* Up to RAXSTACK_STACK_ITEMS items we avoid to allocate on the heap - * and use this static array of pointers instead. */ - void *static_items[RAX_STACK_STATIC_ITEMS]; - int oom; /* True if pushing into this stack failed for OOM at some point. */ + void** stack; /* Points to static_items or an heap allocated array. */ + size_t items, maxitems; /* Number of items contained and total space. */ + /* Up to RAXSTACK_STACK_ITEMS items we avoid to allocate on the heap + * and use this static array of pointers instead. */ + void* static_items[RAX_STACK_STATIC_ITEMS]; + int oom; /* True if pushing into this stack failed for OOM at some point. */ } raxStack; /* Optional callback used for iterators and be notified on each rax node, @@ -166,62 +170,80 @@ typedef struct raxStack { * Redis application for this callback). * * This is currently only supported in forward iterations (raxNext) */ -typedef int (*raxNodeCallback)(raxNode **noderef); +typedef int (*raxNodeCallback)(raxNode** noderef); /* Radix tree iterator state is encapsulated into this data structure. */ #define RAX_ITER_STATIC_LEN 128 -#define RAX_ITER_JUST_SEEKED (1<<0) /* Iterator was just seeked. Return current - element for the first iteration and - clear the flag. */ -#define RAX_ITER_EOF (1<<1) /* End of iteration reached. */ -#define RAX_ITER_SAFE (1<<2) /* Safe iterator, allows operations while - iterating. But it is slower. 
*/ +#define RAX_ITER_JUST_SEEKED \ + (1 << 0) /* Iterator was just seeked. Return current \ + element for the first iteration and \ + clear the flag. */ +#define RAX_ITER_EOF (1 << 1) /* End of iteration reached. */ +#define RAX_ITER_SAFE \ + (1 << 2) /* Safe iterator, allows operations while \ + iterating. But it is slower. */ typedef struct raxIterator { - int flags; - rax *rt; /* Radix tree we are iterating. */ - int *key; /* The current string. */ - void *data; /* Data associated to this key. */ - size_t key_len; /* Current key length. */ - size_t key_max; /* Max key len the current key buffer can hold. */ - int key_static_tokens[RAX_ITER_STATIC_LEN]; - raxNode *node; /* Current node. Only for unsafe iteration. */ - raxStack stack; /* Stack used for unsafe iteration. */ - raxNodeCallback node_cb; /* Optional node callback. Normally set to NULL. */ + int flags; + rax* rt; /* Radix tree we are iterating. */ + int* key; /* The current string. */ + void* data; /* Data associated to this key. */ + size_t key_len; /* Current key length. */ + size_t key_max; /* Max key len the current key buffer can hold. */ + int key_static_tokens[RAX_ITER_STATIC_LEN]; + bool add_to_subtree_list; /* Whether to add the current node to the subtree list. */ + std::vector> *subtree_list; /* List of subtrees. */ + std::vector *subtree_data_list; /* List of subtrees' data. */ + raxNode* node; /* Current node. Only for unsafe iteration. */ + raxStack stack; /* Stack used for unsafe iteration. */ + raxNodeCallback node_cb; /* Optional node callback. Normally set to NULL. */ } raxIterator; /* A special pointer returned for not found items. */ -extern void *raxNotFound; +extern void* raxNotFound; /* Exported API. 
*/ -rax *raxNew(void); -int raxInsert(rax *rax, int *s, size_t len, void *data, void **old); -int raxTryInsert(rax *rax, int *s, size_t len, void *data, void **old); -int raxInsertAndReturnDataNode(rax *rax, int *s, size_t len, void *data, void **node, void **old); -int raxRemove(rax *rax, int *s, size_t len, void **old); -void *raxFind(rax *rax, int *s, size_t len); -raxNode *raxFindAndReturnDataNode(rax *rax, int *s, size_t len); -void raxFree(rax *rax); -void raxFreeWithCallback(rax *rax, void (*free_callback)(void*)); -void raxStart(raxIterator *it, rax *rt); -int raxSeek(raxIterator *it, const char *op, int *ele, size_t len); -int raxNext(raxIterator *it); -int raxPrev(raxIterator *it); -int raxRandomWalk(raxIterator *it, size_t steps); -int raxCompare(raxIterator *iter, const char *op, int *key, size_t key_len); -void raxStop(raxIterator *it); -int raxEOF(raxIterator *it); -void raxShow(rax *rax); -uint64_t raxSize(rax *rax); -unsigned long raxTouch(raxNode *n); +rax* raxNew(void); +int raxInsert(rax* rax, int* s, size_t len, void* data, void** old); +int raxTryInsert(rax* rax, int* s, size_t len, void* data, void** old); +int raxInsertAndReturnDataNode(rax* rax, int* s, size_t len, void* data, + void** node, void** old); +int raxRemove(rax* rax, int* s, size_t len, void** old); +void* raxFind(rax* rax, int* s, size_t len); +raxNode* raxFindAndReturnDataNode(rax* rax, int* s, size_t len); +void raxSetSubtree(raxNode *n); +void raxSetSubtreeAllocated(raxNode *node); +void raxSetSubtreeNotNull(raxNode *node); +bool raxIsSubtree(raxNode *node); +bool raxIsSubtreeAllocated(raxNode *node); +bool raxIsSubtreeCustomDataNull(raxNode *node); +int raxFindNodeWithParent(rax *rax, int *s, size_t len, void **node, void **parent); +void raxFree(rax* rax); +void raxFreeWithCallback(rax *rax, void (*free_callback)(raxNode *)); +void raxStart(raxIterator* it, rax* rt); +int raxSeek(raxIterator* it, const char* op, int* ele, size_t len); +int raxNext(raxIterator* it); +int 
raxPrev(raxIterator* it); +int raxRandomWalk(raxIterator* it, size_t steps); +int raxCompare(raxIterator* iter, const char* op, int* key, size_t key_len); +void raxStop(raxIterator* it); +int raxEOF(raxIterator* it); +void raxShow(rax* rax); +uint64_t raxSize(rax* rax); +raxNode *raxReallocForSubtreeCustomData(raxNode *n); +void raxSetCustomData(raxNode *n, void *data); +void *raxGetCustomData(raxNode *n); +unsigned long raxTouch(raxNode* n); void raxSetDebugMsg(int onoff); -void raxTraverse(raxNode *rax, std::vector> &dataNodeList); -void raxTraverseSubTree(raxNode *n, std::vector> &dataNodeList); +void raxTraverse(raxNode* rax, + std::vector>& dataNodeList); +void raxTraverseSubTree(raxNode* n, std::vector &dataNodeList); raxNode *raxSplit(rax *rax, int *s, size_t len, void *data); -void raxSerialize(rax *root, std::vector> &tokenList, std::vector &dataList); +void raxSerialize(rax* root, std::vector>& tokenList, std::vector& dataList, std::vector ×tampsList, + std::vector> *subtreeList, std::vector *subtreeNodeList); /* Internal API. May be used by the node callback in order to access rax nodes * in a low level way, so this function is exported as well. 
*/ -void raxSetData(raxNode *n, void *data); -void *raxGetData(raxNode *n); +void raxSetData(raxNode* n, void* data); +void* raxGetData(raxNode* n); #endif diff --git a/modules/kv-state-cache/utils/kv_state_cache_utils.cc b/modules/kv-state-cache/utils/kv_state_cache_utils.cc index ddefce59..c43fc0db 100644 --- a/modules/kv-state-cache/utils/kv_state_cache_utils.cc +++ b/modules/kv-state-cache/utils/kv_state_cache_utils.cc @@ -271,4 +271,4 @@ void threadFunc() { 0x5654cf368c60 0x5654cf368f40 0x7e6fec001100 0x7e6fec009ea0 -*/ \ No newline at end of file +*/ From 79f68de23e410d5a2e3297091aef63db6d3318aa Mon Sep 17 00:00:00 2001 From: vegetableysm <108774481+vegetableysm@users.noreply.github.com> Date: Fri, 26 Jan 2024 18:55:10 +0800 Subject: [PATCH 03/20] Refactor code of radix split subtree and fix bug of radix split. (#1735) Signed-off-by: vegetableysm --- modules/kv-state-cache/ds/kv_state_cache.cc | 6 +- .../kv-state-cache/ds/kv_state_cache_block.h | 2 +- .../kv-state-cache/radix-tree/radix-tree.h | 108 ++++++--------- modules/kv-state-cache/radix-tree/radix.cc | 128 ++++-------------- modules/kv-state-cache/radix-tree/radix.h | 8 +- test/kv_state_cache_test.cc | 1 - 6 files changed, 74 insertions(+), 179 deletions(-) diff --git a/modules/kv-state-cache/ds/kv_state_cache.cc b/modules/kv-state-cache/ds/kv_state_cache.cc index af0e4522..6d18fb10 100644 --- a/modules/kv-state-cache/ds/kv_state_cache.cc +++ b/modules/kv-state-cache/ds/kv_state_cache.cc @@ -64,7 +64,7 @@ KVStateCacheBuilder::KVStateCacheBuilder(Client& client, int dimension, this->root_tree = std::make_shared(cache_capacity); this->root_tree->SetCustomData(this->kv_state_cache_block_builder.get(), - sizeof(KVStateCacheBlockBuilder)); + sizeof(this->kv_state_cache_block_builder.get())); } KVStateCacheBuilder::KVStateCacheBuilder(Client& client, @@ -78,7 +78,7 @@ KVStateCacheBuilder::KVStateCacheBuilder(Client& client, this->root_tree = cache->GetRootTree(); 
this->root_tree->SetCustomData(this->kv_state_cache_block_builder.get(), - sizeof(KVStateCacheBlockBuilder)); + sizeof(this->kv_state_cache_block_builder.get())); } KVStateCacheBlockBuilder* KVStateCacheBuilder::Split( @@ -181,7 +181,7 @@ void KVStateCacheBuilder::Update(Client& client, KVStateCacheBlockBuilder* new_kv_state_cache_block_builder = Split(client, kv_state_cache_block_builder, node_with_tree_attri_list); new_tree->SetCustomData(new_kv_state_cache_block_builder, - sizeof(KVStateCacheBlockBuilder)); + sizeof(new_kv_state_cache_block_builder)); LOG(INFO) << "block split success"; // kv_state_cache_builder->UnLock(); diff --git a/modules/kv-state-cache/ds/kv_state_cache_block.h b/modules/kv-state-cache/ds/kv_state_cache_block.h index 7758ee19..e6e172fd 100644 --- a/modules/kv-state-cache/ds/kv_state_cache_block.h +++ b/modules/kv-state-cache/ds/kv_state_cache_block.h @@ -44,7 +44,7 @@ struct offset_data { namespace vineyard { -#define LIST_SIZE 64 +#define LIST_SIZE 5 /** * @brief KVStateCacheBlock is a cache for kv-cache of LLM. 
When a new prompt diff --git a/modules/kv-state-cache/radix-tree/radix-tree.h b/modules/kv-state-cache/radix-tree/radix-tree.h index f418cfb3..0c60eb47 100644 --- a/modules/kv-state-cache/radix-tree/radix-tree.h +++ b/modules/kv-state-cache/radix-tree/radix-tree.h @@ -116,13 +116,7 @@ class RadixTree : public std::enable_shared_from_this { RadixTree(int cache_capacity) { LOG(INFO) << "init radix tree"; this->tree = raxNew(); - if (!raxIsSubtreeAllocated(this->tree->head)) { - raxNode *new_root = raxReallocForSubtreeCustomData(this->tree->head); - this->tree->head = new_root; - raxSetSubtree(this->tree->head); - raxSetSubtreeAllocated(this->tree->head); - } - // this->sub_tree = this->tree; + this->tree->head->issubtree = true; lru_strategy = new LRUStrategy(cache_capacity); } @@ -130,33 +124,32 @@ class RadixTree : public std::enable_shared_from_this { LOG(INFO) << "init radix tree"; this->tree = rax_tree; // this->sub_tree = this->tree; + this->tree->head->issubtree = true; lru_strategy = new LRUStrategy(cache_capacity); } RadixTree(void* custom_data, int custom_data_length, int cache_capacity) { LOG(INFO) << "init radix tree with custom data"; this->tree = raxNew(); - if (!raxIsSubtreeAllocated(this->tree->head)) { - raxNode *new_root = raxReallocForSubtreeCustomData(this->tree->head); - this->tree->head = new_root; - raxSetSubtree(this->tree->head); - raxSetSubtreeAllocated(this->tree->head); - } - // this->sub_tree = this->tree; + this->tree->head->issubtree = true; + customData* custom_data_struct = new customData(); + custom_data_struct->data = custom_data; + custom_data_struct->data_length = custom_data_length; + raxSetCustomData(this->tree->head, custom_data_struct); this->lru_strategy = new LRUStrategy(cache_capacity); } ~RadixTree() { - raxFreeWithCallback(this->tree, [](raxNode *n) { - if (n->iskey && !n->isnull) { - nodeData* nodedata = (nodeData*) raxGetData(n); - delete nodedata; - } - if (n->issubtree && n->iscustomallocated && !n->iscustomnull) { - 
customData* customdata = (customData*) raxGetCustomData(n); - delete customdata; - } - }); + // raxFreeWithCallback(this->tree, [](raxNode *n) { + // if (n->iskey && !n->isnull) { + // nodeData* nodedata = (nodeData*) raxGetData(n); + // delete nodedata; + // } + // if (n->issubtree && n->iscustomallocated && !n->iscustomnull) { + // customData* customdata = (customData*) raxGetCustomData(n); + // delete customdata; + // } + // }); } std::shared_ptr Insert( @@ -257,6 +250,7 @@ class RadixTree : public std::enable_shared_from_this { std::string Serialize() { LOG(INFO) << "Serialize......"; + raxShow(this->tree); std::vector> token_list; std::vector data_list; std::vector timestamp_list; @@ -350,6 +344,7 @@ class RadixTree : public std::enable_shared_from_this { serialized_str += "\t\n"; + LOG(INFO) << "sub tree token list size:" << sub_tree_token_list.size(); for (size_t index = 0; index < sub_tree_token_list.size(); index++) { for (size_t j = 0; j < sub_tree_token_list[index].size(); j++) { serialized_str += std::to_string(sub_tree_token_list[index][j]); @@ -362,11 +357,15 @@ class RadixTree : public std::enable_shared_from_this { char* bytes = (char*) ((customData*) sub_tree_data_list[index])->data; std::ostringstream data_oss; + LOG(INFO) << "data length:" << ((customData*)sub_tree_data_list[index])->data_length; for (size_t i = 0; i < ((customData*)sub_tree_data_list[index])->data_length; ++i) { data_oss << std::hex << std::setw(2) << std::setfill('0') << static_cast(static_cast(bytes[i])); } + LOG(INFO) << "data:" << ((customData*)sub_tree_data_list[index])->data; + LOG(INFO) << "data oss:" << data_oss.str(); serialized_str += data_oss.str() + "\n"; } + LOG(INFO) << "serialized_str:" << serialized_str; // use LZ4 to compress the serialized string const char* const src = serialized_str.c_str(); @@ -492,6 +491,7 @@ class RadixTree : public std::enable_shared_from_this { // is created by upper layer. Here just recover it from serialized // string. 
char* data = new char[dataSize]; + LOG(INFO) << "data size:" << dataSize; std::istringstream dataStream(dataPart); for (size_t i = 0; i < dataSize; ++i) { // Temporary buffer to store two hexadecimal chars + null @@ -571,33 +571,21 @@ class RadixTree : public std::enable_shared_from_this { for (int j = 0; j < sub_tree_token_list[i].size(); j++) { LOG(INFO) << sub_tree_token_list[i][j]; } - raxNode *parentlink = NULL; - raxNode *node = NULL; - raxNode *newNode = NULL; - raxFindNodeWithParent(radix_tree->tree, sub_tree_token_list[i].data(), - sub_tree_token_list[i].size(), (void **)&node,(void **)&parentlink); - if (node == NULL) { - throw std::runtime_error("Unable to find the root node of the sub tree"); - return NULL; - } - if (parentlink == NULL) { - parentlink = radix_tree->tree->head; - } - if (!raxIsSubtree(node) || !raxIsSubtreeAllocated(node)) { - newNode = raxReallocForSubtreeCustomData(node); - raxSetSubtree(node); - raxSetSubtreeAllocated(node); - } else { - newNode = node; - } - memcpy(parentlink,&newNode,sizeof(newNode)); + + raxNode* node = nullptr; + LOG(INFO) << "stage 1"; + VINEYARD_ASSERT(radix_tree->tree != nullptr); + raxFindNode(radix_tree->tree, sub_tree_token_list[i].data(), + sub_tree_token_list[i].size(), (void **)&node); + VINEYARD_ASSERT(node != nullptr); + LOG(INFO) << "stage 2"; customData* data = new customData(); data->data = sub_tree_data_list[i]; data->data_length = sub_tree_data_size_list[i]; - if (raxIsSubtreeAllocated(newNode)) { - raxSetCustomData(newNode, data); - raxSetSubtreeNotNull(newNode); - } + + LOG(INFO) << "stage 3"; + node->issubtree = true; + raxSetCustomData(node, data); } LOG(INFO) << "Deserialize success"; return radix_tree; @@ -638,28 +626,20 @@ class RadixTree : public std::enable_shared_from_this { rax* GetTree() {return this->tree;} void* GetCustomData() { - if (!raxIsSubtreeAllocated(this->tree->head) || raxIsSubtreeCustomDataNull(this->tree->head)) { - throw std::runtime_error("Subtree is not allocated or 
custom data is null"); - return NULL; - } - customData *custome_data = (customData *)raxGetCustomData(this->tree->head); - if (custome_data == NULL) { - throw std::runtime_error("Custom data is null"); - return NULL; - } - return (void *)custome_data->data; + LOG(INFO) << "tree:" << this->tree << " tree node:" << this->tree->head; + VINEYARD_ASSERT(tree->head->custom_data != nullptr); + LOG(INFO) << "custom data:" << ((customData *)tree->head->custom_data)->data; + return ((customData *)tree->head->custom_data)->data; } void SetCustomData(void* custom_data, int custom_data_length) { customData* data = new customData(); data->data = custom_data; + LOG(INFO) << "custom data:" << data->data; data->data_length = custom_data_length; - if (raxIsSubtreeAllocated(this->tree->head)) { - raxSetCustomData(this->tree->head, data); - raxSetSubtreeNotNull(this->tree->head); - return; - } - throw std::runtime_error("The custome data of subtree is not allocated"); + LOG(INFO) << "custom data length:" << data->data_length; + LOG(INFO) << "tree:" << this->tree << " tree node:" << this->tree->head; + raxSetCustomData(this->tree->head, data); } }; diff --git a/modules/kv-state-cache/radix-tree/radix.cc b/modules/kv-state-cache/radix-tree/radix.cc index d3366454..87067fbc 100644 --- a/modules/kv-state-cache/radix-tree/radix.cc +++ b/modules/kv-state-cache/radix-tree/radix.cc @@ -212,8 +212,7 @@ raxNode *raxNewNode(size_t children, int datafield) { node->isnull = 0; node->iscompr = 0; node->issubtree = 0; - node->iscustomnull = 1; - node->iscustomallocated = 0; + node->custom_data = nullptr; node->timestamp = 0; node->numnodes = 1; node->size = children; @@ -266,43 +265,19 @@ void *raxGetData(raxNode *n) { return data; } -/* -* Reallocate the node to make room for custom data -*/ -raxNode *raxReallocForSubtreeCustomData(raxNode *n) { - size_t curlen = raxNodeCurrentLength(n); - raxNode *newNode = (raxNode *)rax_realloc(n,curlen+sizeof(void*)); - if (newNode == NULL) { - printf("can't 
realloc new memory\n"); - return NULL; - } - newNode->iscustomallocated = 1; - return newNode; -} /* * Set the custom data for the root of sub-tree */ void raxSetCustomData(raxNode *n, void *data) { - // wait for the custom data to be allocated - if (n->iscustomallocated==0) { - return; - } - void **ndata = (void**) - ((char*)n+raxNodeCurrentLength(n)); - memcpy(ndata,&data,sizeof(data)); - n->iscustomnull = 0; + n->custom_data = data; } /* * Get the custom data for the root of sub-tree */ void *raxGetCustomData(raxNode *n) { - if (n->iscustomallocated==0 || n->iscustomnull==1) return NULL; - void **ndata =(void**)((char*)n+raxNodeCurrentLength(n)); - void *data; - memcpy(&data,ndata,sizeof(data)); - return data; + return n->custom_data; } /* Add a new child to the node 'n' representing the token 'c' and return @@ -449,18 +424,6 @@ raxNode *raxAddChild(raxNode *n, int c, raxNode **childptr, raxNode ***parentlin n->data[pos] = c; n->numnodes = parent_numnodes + 1; n->size++; - n->issubtree = 0; - n->iscustomnull = 1; - n->iscustomallocated = 0; - if (isSubtree) { - n->issubtree = 1; - size_t curlen = raxNodeCurrentLength(n); - raxNode *newNode = (raxNode *)rax_realloc(n,curlen+sizeof(void*)); - n = newNode; - n->iscustomnull = 1; - n->iscustomallocated = 1; - raxSetCustomData(n, customData); - } src = (char*) raxNodeFirstChildPtr(n); raxNode **childfield = (raxNode**)(src+sizeof(raxNode*)*pos); memcpy(childfield,&child,sizeof(child)); @@ -984,6 +947,7 @@ int raxGenericInsert(rax *rax, int *s, size_t len, void *data, void **old, int o return 1; /* Key inserted. 
*/ } + LOG(INFO) << "custom2:" << h->custom_data; raxNode *prev_node = NULL; int insert_new_node = 0; /* We walked the radix tree as far as we could, but still there are left @@ -1034,24 +998,12 @@ int raxGenericInsert(rax *rax, int *s, size_t len, void *data, void **old, int o raxStackAddNumNodes(&splitStack, insert_new_node); raxStackFree(&lowWalkStack); raxStackFree(&splitStack); - void *customData; - bool isSubtree = false; - if (h->issubtree && h->iscustomallocated && !h->iscustomnull) { - isSubtree = true; - customData = raxGetCustomData(h); - } raxNode *newh = raxReallocForData(h,data); printf("#############raxReallocForData2 ############\n"); if (newh == NULL) { return handleOutOfMemory(rax, h, (int *)s, i, old); } h = newh; - if (isSubtree) { - raxNode *newNode = raxReallocForSubtreeCustomData(h); - newNode->issubtree = 1; - raxSetCustomData(newNode, customData); - h = newNode; - } if (!h->iskey) rax->numele++; raxSetData(h,data); memcpy(parentlink,&h,sizeof(h)); @@ -1125,6 +1077,20 @@ raxNode *raxFindAndReturnDataNode(rax *rax, int *s, size_t len) { return h; } +int raxFindNode(rax *rax, int *s, size_t len, void **node) { + raxNode *h; + raxStack ts; + raxStackInit(&ts); + + int splitpos = 0; + size_t i = raxLowWalk(rax,s,len, &h, nullptr, nullptr, &ts); + if (i != len || (h->iscompr && splitpos != 0) || !h->iskey) + return 0; + *node = raxStackPeek(&ts); + + return 1; +} + /* ** Find a key in the rax, returns the current node and its parent link. */ @@ -1236,18 +1202,7 @@ raxNode *raxRemoveChild(raxNode *parent, raxNode *child) { /* realloc the node according to the theoretical memory usage, to free * data if we are over-allocating right now. 
*/ raxNode *newnode; - if (isSubtree) { - newnode = (raxNode *)rax_realloc(parent,raxNodeCurrentLength(parent)+sizeof(void*)); - newnode->iscustomnull = 1; - newnode->iscustomallocated = 1; - newnode->issubtree = 1; - raxSetCustomData(newnode, customData); - } else { - newnode = (raxNode *)rax_realloc(parent,raxNodeCurrentLength(parent)); - newnode->iscustomnull = 1; - newnode->iscustomallocated = 0; - newnode->issubtree = 0; - } + newnode = (raxNode *)rax_realloc(parent,raxNodeCurrentLength(parent)); if (newnode) { debugnode("raxRemoveChild after", newnode); } @@ -1484,7 +1439,7 @@ void raxRecursiveFree(rax *rax, raxNode *n, void (*free_callback)(raxNode *)) { debugnode("free depth-first",n); // if n is a key node, we need to free the data // if n is a subtree, we need to free the custom data - if (free_callback && ((n->iskey && !n->isnull) || (n->issubtree && n->iscustomallocated && !n->iscustomnull))) + if (free_callback && ((n->iskey && !n->isnull))) free_callback(n); rax_free(n); rax->numnodes--; @@ -2155,6 +2110,7 @@ void raxRecursiveShow(int level, int lpad, raxNode *n) { if (n->iskey) { numchars += printf("=%p",raxGetData(n)); } + numchars += printf(" time:%ld, data:%p, is_sub_tree:%d", n->timestamp, n->custom_data, n->issubtree); int numchildren = n->iscompr ? 
1 : n->size; /* Note that 7 and 4 magic constants are the string length @@ -2275,20 +2231,6 @@ void raxSetSubtree(raxNode *node) { node->issubtree = 1; } -/* -* Set the subtree root node as allocated custom data -*/ -void raxSetSubtreeAllocated(raxNode *node) { - node->iscustomallocated = 1; -} - -/* -* Set the subtree root node as null custom data -*/ -void raxSetSubtreeNotNull(raxNode *node) { - node->iscustomnull = 0; -} - /* * Check if a node is a subtree root node */ @@ -2299,26 +2241,6 @@ bool raxIsSubtree(raxNode *node) { return false; } -/* -* Check if the subtree has been allocated custom data -*/ -bool raxIsSubtreeAllocated(raxNode *node) { - if (node->iscustomallocated) { - return true; - } - return false; -} - -/* -* Check if the custom data of the subtree is null -*/ -bool raxIsSubtreeCustomDataNull(raxNode *node) { - if (node->iscustomnull) { - return true; - } - return false; -} - /* * Split the tree into two sub trees, and return the root node of the new sub tree * @@ -2369,14 +2291,12 @@ raxNode *raxSplit(rax *rax, int *s, size_t len, void *data) { } else { parentlink = raxFindParentLink(parent,splitNode); } - raxNode *newNode = raxReallocForSubtreeCustomData(splitNode); - raxSetSubtree(newNode); - memcpy(parentlink,&newNode,sizeof(newNode)); + raxSetSubtree(splitNode); - raxStackAddNumNodes(&stack, -(int)(newNode->numnodes)); + raxStackAddNumNodes(&stack, -(int)(splitNode->numnodes)); raxStackFree(&stack); - return newNode; + return splitNode; } /* diff --git a/modules/kv-state-cache/radix-tree/radix.h b/modules/kv-state-cache/radix-tree/radix.h index 790ca134..331936d2 100644 --- a/modules/kv-state-cache/radix-tree/radix.h +++ b/modules/kv-state-cache/radix-tree/radix.h @@ -100,14 +100,13 @@ typedef struct raxNode { uint32_t iskey : 1; /* Does this node contain a key? */ uint32_t isnull : 1; /* Associated value is NULL (don't store it). 
*/ - uint32_t iscustomnull : 1; /* Associated custome value is NULL */ - uint32_t iscustomallocated : 1; /* The memory to store custom value is allocated */ uint32_t iscompr : 1; /* Node is compressed. */ uint32_t issubtree : 1; /* Node is the root node of a sub tree */ uint32_t size : 26; /* Number of children, or compressed string len. */ uint32_t numnodes; /* Number of the child nodes */ uint32_t numele; /* Number of elements inside this node. */ uint64_t timestamp; /* Timestamps of the node */ + void *custom_data; /* Data layout is as follows: * * If node is not compressed we have 'size' bytes, one for each children @@ -213,9 +212,6 @@ raxNode* raxFindAndReturnDataNode(rax* rax, int* s, size_t len); void raxSetSubtree(raxNode *n); void raxSetSubtreeAllocated(raxNode *node); void raxSetSubtreeNotNull(raxNode *node); -bool raxIsSubtree(raxNode *node); -bool raxIsSubtreeAllocated(raxNode *node); -bool raxIsSubtreeCustomDataNull(raxNode *node); int raxFindNodeWithParent(rax *rax, int *s, size_t len, void **node, void **parent); void raxFree(rax* rax); void raxFreeWithCallback(rax *rax, void (*free_callback)(raxNode *)); @@ -229,7 +225,6 @@ void raxStop(raxIterator* it); int raxEOF(raxIterator* it); void raxShow(rax* rax); uint64_t raxSize(rax* rax); -raxNode *raxReallocForSubtreeCustomData(raxNode *n); void raxSetCustomData(raxNode *n, void *data); void *raxGetCustomData(raxNode *n); unsigned long raxTouch(raxNode* n); @@ -245,5 +240,6 @@ void raxSerialize(rax* root, std::vector>& tokenList, std::vect * in a low level way, so this function is exported as well. 
*/ void raxSetData(raxNode* n, void* data); void* raxGetData(raxNode* n); +int raxFindNode(rax *rax, int *s, size_t len, void **node); #endif diff --git a/test/kv_state_cache_test.cc b/test/kv_state_cache_test.cc index d2b6790a..68d5a2d1 100644 --- a/test/kv_state_cache_test.cc +++ b/test/kv_state_cache_test.cc @@ -99,7 +99,6 @@ int main() { std::vector round_1_tokens = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}; // std::vector round_2_tokens = {1, 2, 3, 4, 5, 7, 8, 9, 10}; inference(round_1_tokens); - inference(round_1_tokens); sleep(5); // inference(round_2_tokens); // inference(round_2_tokens); From b5704662c40223ecfc2a777fd3d6bae5e9016760 Mon Sep 17 00:00:00 2001 From: vegetableysm <108774481+vegetableysm@users.noreply.github.com> Date: Thu, 1 Feb 2024 10:34:22 +0800 Subject: [PATCH 04/20] Support merge tree / merge object / bug fix. (#1738) - Add timestamp for node - Support merge two radix tree - Support merge cache object - Refactor the framework of cache object, cache block object and radix tree - Fix error of setting sub_tree flag, as well as some derived bugs caused by the error - Format the code to an unified style Signed-off-by: vegetableysm --- modules/kv-state-cache/ds/kv_state_cache.cc | 371 ++++++---- modules/kv-state-cache/ds/kv_state_cache.h | 44 +- .../kv-state-cache/ds/kv_state_cache_block.cc | 174 ++--- .../kv-state-cache/ds/kv_state_cache_block.h | 45 +- .../kv-state-cache/radix-tree/radix-tree.cc | 536 +++++++++++++++ .../kv-state-cache/radix-tree/radix-tree.h | 626 ++--------------- modules/kv-state-cache/radix-tree/radix.cc | 649 ++++++++++++++++-- modules/kv-state-cache/radix-tree/radix.h | 23 +- .../kv-state-cache/strategy/LRU_strategy.cc | 41 +- .../kv-state-cache/strategy/LRU_strategy.h | 4 +- .../utils/kv_state_cache_utils.cc | 70 +- .../utils/kv_state_cache_utils.h | 14 +- test/kv_state_cache_object_test.cc | 278 ++++---- test/kv_state_cache_test.cc | 35 +- test/rax_diff_test.cc | 101 +++ 15 files changed, 1833 insertions(+), 1178 deletions(-) 
create mode 100644 modules/kv-state-cache/radix-tree/radix-tree.cc create mode 100644 test/rax_diff_test.cc diff --git a/modules/kv-state-cache/ds/kv_state_cache.cc b/modules/kv-state-cache/ds/kv_state_cache.cc index 6d18fb10..43beec92 100644 --- a/modules/kv-state-cache/ds/kv_state_cache.cc +++ b/modules/kv-state-cache/ds/kv_state_cache.cc @@ -38,14 +38,25 @@ void KVStateCache::Resolve() { "Expect typename '" + typeName + "', but got '" + this->meta_.GetTypeName() + "'"); - // 1. construct the kv_state_cache_block_builder - this->kv_state_cache_block = std::dynamic_pointer_cast( - this->meta_.GetMember("root_kv_state_cache_block")); - // 2. construct the radix tree - this->root_tree = RadixTree::Deserialize( + // 1. construct the radix tree + this->rootTree = RadixTree::Deserialize( base64_decode(this->meta_.GetKeyValue("radix_tree"))); LOG(INFO) << "Resolve RadixTree success" << std::endl; - raxShow(this->root_tree->GetTree()); + raxShow(this->rootTree->GetRootTree()); + + // 2. construct the kvStateCacheBlockBuilder list + size_t numBlocks = this->meta_.GetKeyValue("numBlocks"); + LOG(INFO) << "num blocks:" << numBlocks; + for (size_t i = 0; i < numBlocks; i++) { + std::shared_ptr kvStateCacheBlockObject = this->meta_.GetMember( + "kv_state_cache_block_builder_" + std::to_string(i)); + this->kvStateCacheBlockList.push_back( + std::dynamic_pointer_cast(kvStateCacheBlockObject)); + this->kvStateCacheBlockMap[kvStateCacheBlockObject->id()] = + std::dynamic_pointer_cast(kvStateCacheBlockObject); + LOG(INFO) << "kvStateCacheBlockObject:" << kvStateCacheBlockObject->id(); + } + // 3. 
construct the member field this->dimension = this->meta_.GetKeyValue("dimension"); LOG(INFO) << "construct the member field success" << std::endl; @@ -56,15 +67,26 @@ KVStateCache::~KVStateCache() { } KVStateCacheBuilder::KVStateCacheBuilder(Client& client, int dimension, - int cache_capacity) { + int cacheCapacity) { this->dimension = dimension; this->version = 0; - this->kv_state_cache_block_builder = - std::make_shared(client, this->dimension); + KVStateCacheBlockBuilder* builder = + new KVStateCacheBlockBuilder(client, this->dimension); + + this->rootTree = std::make_shared(cacheCapacity); + + TreeData* treeData = new TreeData(); + treeData->kvStateCacheBlockBuilder = builder; + treeData->isPtr = true; - this->root_tree = std::make_shared(cache_capacity); - this->root_tree->SetCustomData(this->kv_state_cache_block_builder.get(), - sizeof(this->kv_state_cache_block_builder.get())); + std::shared_ptr rootTreeHeader = this->rootTree->GetRootNode(); + rootTreeHeader->treeData->data = treeData; + rootTreeHeader->treeData->dataLength = sizeof(TreeData); + this->rootTree->SetSubtreeData(treeData, sizeof(TreeData)); + LOG(INFO) << "set builder:" << builder + << " to tree:" << this->rootTree->GetRootTree()->head; + LOG(INFO) << "data:" << treeData + << " custom data:" << rootTreeHeader->treeData; } KVStateCacheBuilder::KVStateCacheBuilder(Client& client, @@ -72,173 +94,201 @@ KVStateCacheBuilder::KVStateCacheBuilder(Client& client, // TBD this->dimension = cache->GetDemension(); this->version = cache->GetVersion(); - this->kv_state_cache_block_builder = - std::make_shared(client, - cache->GetKVStateCacheBlock()); - - this->root_tree = cache->GetRootTree(); - this->root_tree->SetCustomData(this->kv_state_cache_block_builder.get(), - sizeof(this->kv_state_cache_block_builder.get())); + // 1. 
create block builder from block + std::map> kvStateCacheBlockMap = + cache->kvStateCacheBlockMap; + this->rootTree = cache->GetRootTree(); + std::set subTreeData = cache->rootTree->GetSubTreeDataSet(); + + for (auto iter = subTreeData.begin(); iter != subTreeData.end(); ++iter) { + TreeData* treeData = (TreeData*) ((DataWrapper*) *iter)->data; + LOG(INFO) << "tree data:" << treeData; + VINEYARD_ASSERT(treeData->isPtr == false); + LOG(INFO) << "id:" << treeData->builderObjectID; + std::shared_ptr kvStateCacheBlock = + kvStateCacheBlockMap[treeData->builderObjectID]; + KVStateCacheBlockBuilder* kvStateCacheBlockBuilder = + new KVStateCacheBlockBuilder(client, kvStateCacheBlock); + + treeData->kvStateCacheBlockBuilder = kvStateCacheBlockBuilder; + treeData->isPtr = true; + } } KVStateCacheBlockBuilder* KVStateCacheBuilder::Split( - Client& client, KVStateCacheBlockBuilder* kv_state_cache_block_builder, - std::vector> node_with_tree_attri_list) { - // Split the tree if the list of kv_state is full. - VINEYARD_ASSERT(node_with_tree_attri_list.size() > 0); - KVStateCacheBlockBuilder* child_kv_state_cache_block_builder = + Client& client, KVStateCacheBlockBuilder* kvStateCacheBlockBuilder, + std::vector> nodeDataList) { + LOG(INFO) << "split"; + // Split the tree if the list of kvState is full. 
+ VINEYARD_ASSERT(nodeDataList.size() > 0); + KVStateCacheBlockBuilder* childKVStateCacheBlockBuilder = new KVStateCacheBlockBuilder(client, this->dimension); - for (size_t i = 0; i < node_with_tree_attri_list.size(); i++) { - LOG(INFO) << "transfer node:" << i; - LOG(INFO) << "node:" << node_with_tree_attri_list[i]->get_node(); - LOG(INFO) << "data:" << node_with_tree_attri_list[i]->get_node()->get_data(); - offset_data* data = - (offset_data*) node_with_tree_attri_list[i]->get_node()->get_data(); + for (size_t i = 0; i < nodeDataList.size(); i++) { + OffsetData* data = (OffsetData*) nodeDataList[i]->nodeData->data; if (data == nullptr) continue; int index = data->offset; - LOG(INFO) << "stage 0"; // Transfer the data from this builder to the child builder. - const std::shared_ptr> k_builder = - kv_state_cache_block_builder->getKBuilder(); - const std::shared_ptr> v_builder = - kv_state_cache_block_builder->getVBuilder(); - LOG(INFO) << "stage 0.5"; - offset_data* new_offset_data = new offset_data(); - child_kv_state_cache_block_builder->Update( - k_builder->data() + index * this->dimension, - v_builder->data() + index * this->dimension, + const std::shared_ptr> keyStateTensorBuilder = + kvStateCacheBlockBuilder->GetKeyStateBuilder(); + const std::shared_ptr> valueStateTensorBuilder = + kvStateCacheBlockBuilder->GetValueStateBuilder(); + OffsetData* new_offset_data = new OffsetData(); + childKVStateCacheBlockBuilder->Update( + keyStateTensorBuilder->data() + index * this->dimension, + valueStateTensorBuilder->data() + index * this->dimension, this->dimension, new_offset_data); - LOG(INFO) << "stage 1"; - node_with_tree_attri_list[i]->get_node()->set_data(new_offset_data, - sizeof(offset_data)); + nodeDataList[i]->nodeData->data = new_offset_data; + nodeDataList[i]->nodeData->dataLength = sizeof(OffsetData); // Clear the bitmap. 
- LOG(INFO) << "stage 2, index:" << index; - kv_state_cache_block_builder->DeleteKVCache(index); - LOG(INFO) << "bitmap:" << kv_state_cache_block_builder->GetBitmapStr(); + kvStateCacheBlockBuilder->DeleteKVCache(index); } - LOG(INFO) << "stage 3"; - kv_state_cache_block_builder->SetChildKVStateCacheBlockBuilder( - child_kv_state_cache_block_builder); - LOG(INFO) << "stage 4"; - return child_kv_state_cache_block_builder; + LOG(INFO) << "builder:" << kvStateCacheBlockBuilder + << " bitmap:" << kvStateCacheBlockBuilder->GetBitmapStr(); + LOG(INFO) << "child_builder:" << childKVStateCacheBlockBuilder + << " bitmap:" << childKVStateCacheBlockBuilder->GetBitmapStr(); + return childKVStateCacheBlockBuilder; } void KVStateCacheBuilder::Update(Client& client, - const std::vector& token_list, - int next_token, - const KV_STATE_WITH_LAYER& kv_state) { + const std::vector& tokenList, + int nextToken, + const KV_STATE_WITH_LAYER& kvState) { LOG(INFO) << "update"; - std::vector token_list_copy = token_list; - token_list_copy.push_back(next_token); + std::vector tokenListCopy = tokenList; + tokenListCopy.push_back(nextToken); // Create a empty node of tokens from radix tree. 
- std::shared_ptr evicted_node = nullptr; - std::shared_ptr node_with_tree_attri = - this->root_tree->Insert(token_list_copy, evicted_node); - if (node_with_tree_attri == nullptr) { + std::shared_ptr evictedNodeData = nullptr; + std::shared_ptr nodeData = + this->rootTree->Insert(tokenListCopy, evictedNodeData); + if (nodeData == nullptr) { LOG(INFO) << "insert failed"; return; } - LOG(INFO) << "stage 1"; - std::shared_ptr tree = node_with_tree_attri->get_tree(); - LOG(INFO) << "stage 2"; - KVStateCacheBlockBuilder* kv_state_cache_block_builder = - (KVStateCacheBlockBuilder*) tree->GetCustomData(); - LOG(INFO) << "stage 3"; - if (evicted_node != nullptr) { - LOG(INFO) << "stage 4"; - offset_data* data = (offset_data*) evicted_node->get_node()->get_data(); - LOG(INFO) << "stage 5"; - KVStateCacheBlockBuilder* builder = - (KVStateCacheBlockBuilder*) evicted_node->get_tree()->GetCustomData(); - LOG(INFO) << "stage 6"; - builder->DeleteKVCache(data->offset); - - if ((offset_data*) evicted_node->get_node()->get_data() != nullptr) - delete (offset_data*) evicted_node->get_node()->get_data(); + LOG(INFO) << "insert end"; + KVStateCacheBlockBuilder* kvStateCacheBlockBuilder = + (KVStateCacheBlockBuilder*) ((TreeData*) nodeData->treeData->data) + ->kvStateCacheBlockBuilder; + LOG(INFO) << "try to delete"; + if (evictedNodeData != nullptr) { + Delete(evictedNodeData); } // TBD // Use lock to protect the kv_state_cache_builder - // kv_state_cache_builder->Lock(); - - if (kv_state_cache_block_builder->IsFull()) { + LOG(INFO) << "data:" << nodeData->treeData->data + << " custom data:" << nodeData->treeData; + LOG(INFO) << "kvStateCacheBlockBuilder:" << kvStateCacheBlockBuilder; + if (kvStateCacheBlockBuilder->IsFull()) { /** * If the kv-state cache of the tree is full, triggle split. Delete the * empty node from the radix tree and split the tree. Then, kv-state cache * split according to the new tree. 
*/ LOG(INFO) << "triggle splits"; - std::shared_ptr evicted_node = nullptr; - this->root_tree->Delete(token_list_copy, evicted_node); - std::shared_ptr new_tree = tree->Split(token_list_copy); - - LOG(INFO) << "tree split success"; - std::vector> node_with_tree_attri_list = - RadixTree::TraverseTreeWithoutSubTree(new_tree); - KVStateCacheBlockBuilder* new_kv_state_cache_block_builder = - Split(client, kv_state_cache_block_builder, node_with_tree_attri_list); - new_tree->SetCustomData(new_kv_state_cache_block_builder, - sizeof(new_kv_state_cache_block_builder)); + std::shared_ptr evictedNodeData = nullptr; + this->rootTree->Delete(tokenListCopy, evictedNodeData); + + std::shared_ptr subTreeHeader; + std::vector> nodeDataList = + rootTree->Split(tokenListCopy, subTreeHeader); + KVStateCacheBlockBuilder* newKVStateCacheBlockBuilder = + Split(client, kvStateCacheBlockBuilder, nodeDataList); + + TreeData* newTreeData = new TreeData(); + newTreeData->kvStateCacheBlockBuilder = newKVStateCacheBlockBuilder; + newTreeData->isPtr = true; + + subTreeHeader->treeData->data = newTreeData; + subTreeHeader->treeData->dataLength = sizeof(TreeData); + rootTree->SetSubtreeData(newTreeData, sizeof(TreeData)); LOG(INFO) << "block split success"; // kv_state_cache_builder->UnLock(); - Update(client, token_list, next_token, kv_state); + Update(client, tokenList, nextToken, kvState); } else { // Update the kv-state cache. 
- LOG(INFO) << "update kv-state cache"; - offset_data* data = new offset_data(); - LOG(INFO) << "stage 7"; - kv_state_cache_block_builder->Update(kv_state, data); - LOG(INFO) << "stage 8"; - std::shared_ptr node = node_with_tree_attri->get_node(); - LOG(INFO) << "stage 9"; - node->set_data(data, sizeof(offset_data)); - LOG(INFO) << "stage 10"; + OffsetData* data = new OffsetData(); + kvStateCacheBlockBuilder->Update(kvState, data); + nodeData->nodeData->data = data; + nodeData->nodeData->dataLength = sizeof(OffsetData); } - LOG(INFO) << "bitmap:" << kv_state_cache_block_builder->GetBitmapStr(); + LOG(INFO) << "builder:" << kvStateCacheBlockBuilder + << " bitmap:" << kvStateCacheBlockBuilder->GetBitmapStr(); } -static std::shared_ptr node; +static std::shared_ptr node; KV_STATE_WITH_LAYER KVStateCacheBuilder::Query( - Client& client, const std::vector& token_list, int token) { - std::vector token_list_copy = token_list; - token_list_copy.push_back(token); - - KV_STATE_WITH_LAYER kv_state; - std::shared_ptr node_with_tree_attri = - this->root_tree->Query(token_list_copy); - /**/ - if (node_with_tree_attri != nullptr) { - offset_data* data = - (offset_data*) node_with_tree_attri->get_node()->get_data(); + Client& client, const std::vector& tokenList, int token) { + std::vector tokenListCopy = tokenList; + tokenListCopy.push_back(token); + + KV_STATE_WITH_LAYER kvState; + std::shared_ptr nodeData = this->rootTree->Query(tokenListCopy); + + if (nodeData != nullptr) { + OffsetData* data = (OffsetData*) nodeData->nodeData->data; int offset = data->offset; - KVStateCacheBlockBuilder* kv_state_cache_block_builder = - (KVStateCacheBlockBuilder*) node_with_tree_attri->get_tree() - ->GetCustomData(); - // kv_state_cache_builder->Lock(); + KVStateCacheBlockBuilder* kvStateCacheBlockBuilder = + (KVStateCacheBlockBuilder*) ((TreeData*) nodeData->treeData->data) + ->kvStateCacheBlockBuilder; - kv_state_cache_block_builder->Query(client, offset, kv_state); - // 
kv_state_cache_builder->UnLock(); - node = node_with_tree_attri->get_node(); + LOG(INFO) << "offset:" << offset; + LOG(INFO) << "kvStateCacheBlockBuilder:" << kvStateCacheBlockBuilder; + kvStateCacheBlockBuilder->Query(client, offset, kvState); } - LOG(INFO) << "query success"; - return kv_state; + return kvState; +} + +void KVStateCacheBuilder::Delete(std::shared_ptr evictedNodeData) { + LOG(INFO) << "stage1"; + KVStateCacheBlockBuilder* kvStateCacheBlockBuilder = + (KVStateCacheBlockBuilder*) ((TreeData*) evictedNodeData->treeData->data) + ->kvStateCacheBlockBuilder; + LOG(INFO) << "stage2, builder:" << kvStateCacheBlockBuilder; + OffsetData* data = (OffsetData*) evictedNodeData->nodeData->data; + LOG(INFO) << "stage3"; + kvStateCacheBlockBuilder->DeleteKVCache(data->offset); + LOG(INFO) << "stage4"; + delete data; } -std::shared_ptr KVStateCacheBuilder::Merge( - Client& client, std::shared_ptr kv_state_cache) { +void KVStateCacheBuilder::Merge(Client& client, + std::shared_ptr kvStateCache) { // TBD - if (kv_state_cache == nullptr) { - return nullptr; + if (kvStateCache == nullptr) { + return; + } + std::shared_ptr globalCacheBuilder = + std::make_shared(client, kvStateCache); + std::shared_ptr globalCacheTree = kvStateCache->GetRootTree(); + + std::set> insertTokenList; + std::vector> evicted_token_list; + mergeTree(this->rootTree->GetRootTree(), globalCacheTree->GetRootTree(), + evicted_token_list, insertTokenList, + this->rootTree->GetCacheCapacity()); + + for (size_t i = 0; i < evicted_token_list.size(); i++) { + std::vector tokenList = evicted_token_list[i]; + std::shared_ptr evictedNodeData; + this->rootTree->Delete(tokenList, evictedNodeData); + Delete(evictedNodeData); } - // VINEYARD_ASSERT(false); - return nullptr; + + for (auto it = insertTokenList.begin(); it != insertTokenList.end(); ++it) { + std::vector tokenList = *it; + KV_STATE_WITH_LAYER kvState = globalCacheBuilder->Query( + client, std::vector(tokenList.begin(), tokenList.end() - 1), + 
tokenList.back()); + this->Update(client, tokenList, tokenList[tokenList.size() - 1], kvState); + } + return; } Status KVStateCacheBuilder::Build(Client& client) { @@ -247,40 +297,61 @@ Status KVStateCacheBuilder::Build(Client& client) { } std::shared_ptr KVStateCacheBuilder::_Seal(Client& client) { + LOG(INFO) << "cache seal"; this->Build(client); - std::shared_ptr kv_state_cache = - std::make_shared(); + std::shared_ptr kvStateCache = std::make_shared(); // 1. store the member variables to cache object meta - kv_state_cache->meta_.AddKeyValue("dimension", this->dimension); + kvStateCache->meta_.AddKeyValue("dimension", this->dimension); + + // 2. seal all the block and put object id to cache object and + // change the tree data from pointer to object id + + int count = 0; + LOG(INFO) << "count:" << count; + std::set subTreeDataSet = rootTree->GetSubTreeDataSet(); + for (auto iter = subTreeDataSet.begin(); iter != subTreeDataSet.end(); + ++iter) { + TreeData* treeData = (TreeData*) ((DataWrapper*) *iter)->data; + VINEYARD_ASSERT(treeData != nullptr); + VINEYARD_ASSERT(treeData->isPtr == true); + + KVStateCacheBlockBuilder* kvStateCacheBlockBuilder = + (KVStateCacheBlockBuilder*) treeData->kvStateCacheBlockBuilder; + LOG(INFO) << "builder:" << kvStateCacheBlockBuilder; + std::shared_ptr kvStateCacheBlock = + kvStateCacheBlockBuilder->_Seal(client); + kvStateCache->meta_.AddMember( + "kv_state_cache_block_builder_" + std::to_string(count), + kvStateCacheBlock); + treeData->builderObjectID = kvStateCacheBlock->id(); + treeData->isPtr = false; + count++; + } - // 2. seal all the kv_state_cache_block - // 3. put cache_block_object_id to cache object meta - kv_state_cache->meta_.AddMember( - "root_kv_state_cache_block", - this->kv_state_cache_block_builder->_Seal(client)); + kvStateCache->meta_.AddKeyValue("numBlocks", count); - // 4. 
put the serialized sequence radix tree to cache object meta - kv_state_cache->meta_.AddKeyValue( - "radix_tree", base64_encode(this->root_tree->Serialize())); + // 3. put the serialized sequence radix tree to cache object meta + kvStateCache->meta_.AddKeyValue("radix_tree", + base64_encode(this->rootTree->Serialize())); - // 5. put the object type to the meta - kv_state_cache->meta_.SetTypeName(type_name()); + // 4. put the object type to the meta + kvStateCache->meta_.SetTypeName(type_name()); VINEYARD_CHECK_OK( - client.CreateMetaData(kv_state_cache->meta_, kv_state_cache->id_)); - LOG(INFO) << "KVStateCacheBuilder::_Seal: " << kv_state_cache->id_; - return kv_state_cache; + client.CreateMetaData(kvStateCache->meta_, kvStateCache->id_)); + LOG(INFO) << "KVStateCacheBuilder::_Seal: " << kvStateCache->id_; + return kvStateCache; } KVStateCacheBuilder::~KVStateCacheBuilder() { // TBD - std::vector> node_with_tree_attri_list = - RadixTree::TraverseTreeWithoutSubTree(this->root_tree); - for (size_t i = 0; i < node_with_tree_attri_list.size(); i++) { - delete (offset_data*) node_with_tree_attri_list[i]->get_node()->get_data(); - } + // std::vector> nodeDataList = + // RadixTree::TraverseTreeWithoutSubTree(this->rootTree); + // for (size_t i = 0; i < nodeDataList.size(); i++) { + // delete (OffsetData*) nodeDataList[i]->get_node()->get_data(); + // } } } // namespace vineyard diff --git a/modules/kv-state-cache/ds/kv_state_cache.h b/modules/kv-state-cache/ds/kv_state_cache.h index f1e58e12..d4498281 100644 --- a/modules/kv-state-cache/ds/kv_state_cache.h +++ b/modules/kv-state-cache/ds/kv_state_cache.h @@ -28,12 +28,21 @@ limitations under the License. 
namespace vineyard { +struct TreeData { + union { + void* kvStateCacheBlockBuilder; + uint64_t builderObjectID; + }; + bool isPtr = true; +}; + class KVStateCache : public vineyard::Registered { private: - std::shared_ptr kv_state_cache_block; - std::shared_ptr root_tree; + std::vector> kvStateCacheBlockList; + std::map> kvStateCacheBlockMap; + std::shared_ptr rootTree; int dimension; - int cache_capacity; + int cacheCapacity; uint64_t version; public: @@ -47,17 +56,17 @@ class KVStateCache : public vineyard::Registered { void Resolve(); // for test - std::shared_ptr GetKVStateCacheBlock() { - return this->kv_state_cache_block; + std::vector> GetKVStateCacheBlockList() { + return this->kvStateCacheBlockList; } int GetDemension() { return this->dimension; } - int GetCacheCapacity() { return this->cache_capacity; } + int GetCacheCapacity() { return this->cacheCapacity; } uint64_t GetVersion() { return this->version; } - std::shared_ptr GetRootTree() { return this->root_tree; } + std::shared_ptr GetRootTree() { return this->rootTree; } ~KVStateCache(); @@ -65,20 +74,18 @@ class KVStateCache : public vineyard::Registered { }; class KVStateCacheBuilder : public vineyard::ObjectBuilder { - std::shared_ptr kv_state_cache_block_builder; - std::shared_ptr root_tree; + std::shared_ptr rootTree; int dimension; uint64_t version; public: - KVStateCacheBuilder(Client& client, int dimension, int cache_capacity); + KVStateCacheBuilder(Client& client, int dimension, int cacheCapacity); KVStateCacheBuilder(Client& client, std::shared_ptr cache); KVStateCacheBlockBuilder* Split( - Client& client, KVStateCacheBlockBuilder* kv_state_cache_block_builder, - std::vector> - node_with_tree_attri_list); + Client& client, KVStateCacheBlockBuilder* kvStateCacheBlockBuilder, + std::vector> nodeDataList); void Update(Client& client, const std::vector& token_list, int next_token, const KV_STATE_WITH_LAYER& kv_state); @@ -86,8 +93,9 @@ class KVStateCacheBuilder : public vineyard::ObjectBuilder { 
KV_STATE_WITH_LAYER Query(Client& client, const std::vector& token_list, int token); - std::shared_ptr Merge( - Client& client, std::shared_ptr kv_state_cache); + void Delete(std::shared_ptr evicted_node); + + void Merge(Client& client, std::shared_ptr kv_state_cache); uint64_t GetVersion() { return this->version; } @@ -95,13 +103,9 @@ class KVStateCacheBuilder : public vineyard::ObjectBuilder { std::shared_ptr _Seal(Client& client) override; - std::shared_ptr GetKVStateCacheBlockBuilder() { - return this->kv_state_cache_block_builder; - } - uint64_t GetDemension() { return this->dimension; } - std::shared_ptr GetRootTree() { return this->root_tree; } + std::shared_ptr GetRootTree() { return this->rootTree; } ~KVStateCacheBuilder(); }; diff --git a/modules/kv-state-cache/ds/kv_state_cache_block.cc b/modules/kv-state-cache/ds/kv_state_cache_block.cc index 64079ded..2b49af15 100644 --- a/modules/kv-state-cache/ds/kv_state_cache_block.cc +++ b/modules/kv-state-cache/ds/kv_state_cache_block.cc @@ -48,72 +48,64 @@ void KVStateCacheBlock::Construct(const ObjectMeta& meta) { meta.GetTypeName() + "'"); // TBD - // 1. construct the k_builder and v_builder - this->k_tensor = std::dynamic_pointer_cast>( - this->meta_.GetMember("k_builder")); - this->v_tensor = std::dynamic_pointer_cast>( - this->meta_.GetMember("v_builder")); - // 2. construct the child kv_state_cache_block_builder - int child_num = this->meta_.GetKeyValue("child_num"); - for (int i = 0; i < child_num; ++i) { - std::shared_ptr child_kv_state_cache_block_builder = - std::dynamic_pointer_cast(this->meta_.GetMember( - "child_kv_state_cache_block_" + std::to_string(i))); - this->child_kv_state_cache_block_list.push_back( - child_kv_state_cache_block_builder); - } - // 3. construct the member field + // 1. 
construct the keyStateTensorBuilder and valueStateTensorBuilder + this->keyStateTensor = std::dynamic_pointer_cast>( + this->meta_.GetMember("keyStateTensorBuilder")); + this->valueStateTensor = std::dynamic_pointer_cast>( + this->meta_.GetMember("valueStateTensorBuilder")); + // 2. construct the member field this->bitmap = this->meta_.GetKeyValue("bitmap"); this->dimension = this->meta_.GetKeyValue("dimension"); } KVStateCacheBlockBuilder::KVStateCacheBlockBuilder(Client& client, int dimension) { - pthread_spin_init(&(this->spin_lock), 0); this->bitmap = UINT64_MAX; std::vector shape = {LIST_SIZE, dimension}; - this->k_builder = std::make_shared>(client, shape); - this->v_builder = std::make_shared>(client, shape); + this->keyStateTensorBuilder = + std::make_shared>(client, shape); + this->valueStateTensorBuilder = + std::make_shared>(client, shape); this->dimension = dimension; } KVStateCacheBlockBuilder::KVStateCacheBlockBuilder( - Client& client, std::shared_ptr kv_state_cache_block) { - pthread_spin_init(&(this->spin_lock), 0); - this->bitmap = kv_state_cache_block->bitmap; - this->dimension = kv_state_cache_block->dimension; + Client& client, std::shared_ptr kvStateCacheBlock) { + this->bitmap = kvStateCacheBlock->bitmap; + this->dimension = kvStateCacheBlock->dimension; std::vector shape = {LIST_SIZE, dimension}; - this->k_builder = std::make_shared>(client, shape); - this->v_builder = std::make_shared>(client, shape); + this->keyStateTensorBuilder = + std::make_shared>(client, shape); + this->valueStateTensorBuilder = + std::make_shared>(client, shape); // transfer the data from kv_state_cache to this builder - memcpy(this->k_builder->data(), kv_state_cache_block->k_tensor->data(), + memcpy(this->keyStateTensorBuilder->data(), + kvStateCacheBlock->keyStateTensor->data(), LIST_SIZE * this->dimension * sizeof(double)); - memcpy(this->v_builder->data(), kv_state_cache_block->v_tensor->data(), + memcpy(this->valueStateTensorBuilder->data(), + 
kvStateCacheBlock->valueStateTensor->data(), LIST_SIZE * this->dimension * sizeof(double)); - for (size_t i = 0; - i < kv_state_cache_block->child_kv_state_cache_block_list.size(); ++i) { - this->child_kv_state_cache_builder_list.push_back( - new KVStateCacheBlockBuilder( - client, kv_state_cache_block->child_kv_state_cache_block_list[i])); - } } // current we do not consider the layer. Status KVStateCacheBlockBuilder::Query(Client& client, int index, - KV_STATE_WITH_LAYER& kv_state) { - std::vector k_state; - std::vector v_state; + KV_STATE_WITH_LAYER& kvState) { + std::vector keyStateVector; + std::vector valueStateVector; for (int i = 0; i < this->dimension; ++i) { - k_state.push_back(((double*) k_builder->data())[index * dimension + i]); + keyStateVector.push_back( + ((double*) keyStateTensorBuilder->data())[index * dimension + i]); } for (int i = 0; i < this->dimension; ++i) { - v_state.push_back(((double*) v_builder->data())[index * dimension + i]); + valueStateVector.push_back( + ((double*) valueStateTensorBuilder->data())[index * dimension + i]); } - kv_state.insert(std::make_pair(1, std::make_pair(k_state, v_state))); + kvState.insert( + std::make_pair(1, std::make_pair(keyStateVector, valueStateVector))); return Status::OK(); } @@ -128,100 +120,70 @@ bool KVStateCacheBlockBuilder::IsFull() { return index < 0 || index >= LIST_SIZE; } -void KVStateCacheBlockBuilder::Update(const KV_STATE_WITH_LAYER& kv_state, - offset_data* data) { +void KVStateCacheBlockBuilder::Update(const KV_STATE_WITH_LAYER& kvState, + OffsetData* data) { int index = this->FindEmptySlot(); LOG(INFO) << "index:" << index; - std::vector k_state = (kv_state.find(1)->second).first; - std::vector v_state = (kv_state.find(1)->second).second; - VINEYARD_ASSERT(k_state.size() == (size_t) this->dimension); - VINEYARD_ASSERT(v_state.size() == (size_t) this->dimension); + std::vector keyStateVector = (kvState.find(1)->second).first; + std::vector valueStateVector = 
(kvState.find(1)->second).second; + VINEYARD_ASSERT(keyStateVector.size() == (size_t) this->dimension); + VINEYARD_ASSERT(valueStateVector.size() == (size_t) this->dimension); - double* key_data = (double*) k_builder->data(); - double* value_data = (double*) v_builder->data(); + double* keyData = (double*) keyStateTensorBuilder->data(); + double* valueData = (double*) valueStateTensorBuilder->data(); for (int i = 0; i < this->dimension; ++i) { - key_data[index * this->dimension + i] = k_state[i]; + keyData[index * this->dimension + i] = keyStateVector[i]; } for (int i = 0; i < this->dimension; ++i) { - value_data[index * this->dimension + i] = v_state[i]; + valueData[index * this->dimension + i] = valueStateVector[i]; } data->offset = index; - LOG(INFO) << "before:" << this->bitmap; ACQUIRE_BIT_RESOURCE(this->bitmap, index); - LOG(INFO) << "after:" << this->bitmap; } -void KVStateCacheBlockBuilder::Update(double* k_data, double* v_data, - unsigned long data_length, - offset_data* data) { +void KVStateCacheBlockBuilder::Update(double* keyState, double* valueState, + unsigned long dataLength, + OffsetData* data) { int index = FindEmptySlot(); - double* key_data = (double*) k_builder->data(); - double* value_data = (double*) v_builder->data(); - VINEYARD_ASSERT((unsigned long) this->dimension == data_length); - for (unsigned long i = 0; i < data_length; ++i) { - key_data[index * this->dimension + i] = k_data[i]; + double* keyData = (double*) keyStateTensorBuilder->data(); + double* valueData = (double*) valueStateTensorBuilder->data(); + VINEYARD_ASSERT((unsigned long) this->dimension == dataLength); + for (unsigned long i = 0; i < dataLength; ++i) { + keyData[index * this->dimension + i] = keyState[i]; } - for (unsigned long i = 0; i < data_length; ++i) { - value_data[index * this->dimension + i] = v_data[i]; + for (unsigned long i = 0; i < dataLength; ++i) { + valueData[index * this->dimension + i] = valueState[i]; } data->offset = index; 
ACQUIRE_BIT_RESOURCE(this->bitmap, index); - LOG(INFO) << "bitmap:" << this->GetBitmapStr(); -} - -void KVStateCacheBlockBuilder::SetChildKVStateCacheBlockBuilder( - KVStateCacheBlockBuilder* child_kv_state_cache_builder) { - this->child_kv_state_cache_builder_list.push_back( - child_kv_state_cache_builder); } -Status KVStateCacheBlockBuilder::Build(Client& client) { - // TBD craete vineyard object - // pthread_spin_lock(&(this->spin_lock)); - // ObjectMeta meta; - // meta.SetTypeName(type_name()); - // meta.AddKeyValue("bitmap", this->bitmap); - // for (int i = 0; i < LIST_SIZE; ++i) { - // // TBD - // // create tensor meta - // } - // // TBD check the status - // client.CreateMetaData(meta, id); - // pthread_spin_unlock(&(this->spin_lock)); - return Status::OK(); -} +Status KVStateCacheBlockBuilder::Build(Client& client) { return Status::OK(); } std::shared_ptr KVStateCacheBlockBuilder::_Seal(Client& client) { + LOG(INFO) << "block seal:" << this; this->Build(client); - // pthread_spin_lock(&(this->spin_lock)); - // pthread_spin_unlock(&(this->spin_lock)); - std::shared_ptr kv_state_cache_block = + std::shared_ptr kvStateCacheBlock = std::make_shared(); - // TBD - // 1. seal k_builder and v_builder - kv_state_cache_block->meta_.AddMember("k_builder", k_builder->Seal(client)); - kv_state_cache_block->meta_.AddMember("v_builder", v_builder->Seal(client)); - // 2. seal child kv_state_cache_block_builder - for (size_t i = 0; i < this->child_kv_state_cache_builder_list.size(); ++i) { - kv_state_cache_block->meta_.AddMember( - "child_kv_state_cache_block_" + std::to_string(i), - this->child_kv_state_cache_builder_list[i]->_Seal(client)); - } - kv_state_cache_block->meta_.AddKeyValue( - "child_num", this->child_kv_state_cache_builder_list.size()); - // 3. store the member field to meta - kv_state_cache_block->meta_.AddKeyValue("bitmap", this->bitmap); - kv_state_cache_block->meta_.AddKeyValue("dimension", this->dimension); - // 4. 
set the object type to meta - kv_state_cache_block->meta_.SetTypeName(type_name()); - - VINEYARD_CHECK_OK(client.CreateMetaData(kv_state_cache_block->meta_, - kv_state_cache_block->id_)); - return kv_state_cache_block; + // 1. seal keyStateTensorBuilder and valueStateTensorBuilder + kvStateCacheBlock->meta_.AddMember("keyStateTensorBuilder", + keyStateTensorBuilder->Seal(client)); + kvStateCacheBlock->meta_.AddMember("valueStateTensorBuilder", + valueStateTensorBuilder->Seal(client)); + + // 2. store the member field to meta + kvStateCacheBlock->meta_.AddKeyValue("bitmap", this->bitmap); + kvStateCacheBlock->meta_.AddKeyValue("dimension", this->dimension); + // 3. set the object type to meta + kvStateCacheBlock->meta_.SetTypeName(type_name()); + + VINEYARD_CHECK_OK( + client.CreateMetaData(kvStateCacheBlock->meta_, kvStateCacheBlock->id_)); + return kvStateCacheBlock; } } // namespace vineyard diff --git a/modules/kv-state-cache/ds/kv_state_cache_block.h b/modules/kv-state-cache/ds/kv_state_cache_block.h index e6e172fd..667cd467 100644 --- a/modules/kv-state-cache/ds/kv_state_cache_block.h +++ b/modules/kv-state-cache/ds/kv_state_cache_block.h @@ -24,6 +24,7 @@ limitations under the License. 
#include "basic/ds/tensor.h" #include "client/ds/blob.h" #include "client/ds/i_object.h" +#include "kv-state-cache/radix-tree/radix-tree.h" typedef std::map, std::vector>> KV_STATE_WITH_LAYER; @@ -38,10 +39,9 @@ typedef std::vector< #define ACQUIRE_BIT_RESOURCE(value, bit) \ ((value) &= (~(((uint64_t) 1) << (bit)))) -struct offset_data { +struct OffsetData { short offset; }; - namespace vineyard { #define LIST_SIZE 5 @@ -61,10 +61,8 @@ namespace vineyard { class KVStateCacheBlock : public vineyard::Registered { private: - std::shared_ptr> k_tensor; - std::shared_ptr> v_tensor; - std::vector> - child_kv_state_cache_block_list; + std::shared_ptr> keyStateTensor; + std::shared_ptr> valueStateTensor; uint64_t bitmap; ObjectID id; int dimension; @@ -83,22 +81,24 @@ class KVStateCacheBlock : public vineyard::Registered { uint64_t GetBitmap() { return this->bitmap; } - std::shared_ptr> GetKTensor() { return this->k_tensor; } + std::shared_ptr> GetKeyTensor() { + return this->keyStateTensor; + } - std::shared_ptr> GetVTensor() { return this->v_tensor; } + std::shared_ptr> GetValueTensor() { + return this->valueStateTensor; + } friend class KVStateCacheBlockBuilder; }; class KVStateCacheBlockBuilder : public ObjectBuilder { private: - std::shared_ptr> k_builder; - std::shared_ptr> v_builder; - std::vector child_kv_state_cache_builder_list; + std::shared_ptr> keyStateTensorBuilder; + std::shared_ptr> valueStateTensorBuilder; // TBD // support more than 64 kv-state cache slots uint64_t bitmap; - pthread_spinlock_t spin_lock; int dimension; int FindEmptySlot(); @@ -116,10 +116,10 @@ class KVStateCacheBlockBuilder : public ObjectBuilder { * @param kv_state The kv-state of the prompt. A LLM inference can contain * multiple kv-states for each layer. 
*/ - void Update(const KV_STATE_WITH_LAYER& kv_state, offset_data* data); + void Update(const KV_STATE_WITH_LAYER& kv_state, OffsetData* data); - void Update(double* k_data, double* v_data, unsigned long data_length, - offset_data* data); + void Update(double* keyState, double* valueState, unsigned long dataLength, + OffsetData* data); /** * @brief Query the kv-state using the whole token list. @@ -137,23 +137,16 @@ class KVStateCacheBlockBuilder : public ObjectBuilder { std::shared_ptr _Seal(Client& client) override; - void Lock() { pthread_spin_lock(&(this->spin_lock)); } - - void UnLock() { pthread_spin_unlock(&(this->spin_lock)); } - - const std::shared_ptr> getKBuilder() { - return k_builder; + const std::shared_ptr> GetKeyStateBuilder() { + return keyStateTensorBuilder; } - const std::shared_ptr> getVBuilder() { - return v_builder; + const std::shared_ptr> GetValueStateBuilder() { + return valueStateTensorBuilder; } void DeleteKVCache(int bit) { FREE_BIT_RESOURCE(this->bitmap, bit); } - void SetChildKVStateCacheBlockBuilder( - KVStateCacheBlockBuilder* child_kv_state_cache_builder); - std::string GetBitmapStr(); uint64_t GetBitmap() { return this->bitmap; } diff --git a/modules/kv-state-cache/radix-tree/radix-tree.cc b/modules/kv-state-cache/radix-tree/radix-tree.cc new file mode 100644 index 00000000..3442febe --- /dev/null +++ b/modules/kv-state-cache/radix-tree/radix-tree.cc @@ -0,0 +1,536 @@ +/** Copyright 2020-2023 Alibaba Group Holding Limited. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include "radix-tree.h" +#include "common/util/base64.h" +#include "common/util/logging.h" +#include "common/util/status.h" + +using namespace vineyard; + +RadixTree::RadixTree(int cacheCapacity) { + this->tree = raxNew(); + this->cacheCapacity = cacheCapacity; + this->nodeCount = 0; + + // prepare root node + std::vector rootToken = {INT32_MAX}; + std::shared_ptr evictedNode; + this->InsertInternal(rootToken, evictedNode); + raxShow(this->tree); + raxNode* dataNode = raxFindAndReturnDataNode(this->tree, rootToken.data(), + rootToken.size(), NULL, false); + DataWrapper* data = new DataWrapper(); + data->data = nullptr; + data->dataLength = 0; + dataNode->custom_data = data; + dataNode->issubtree = true; + this->rootToken = rootToken; +} + +RadixTree::~RadixTree() { + // TBD + // raxFreeWithCallback(this->tree, [](raxNode *n) { + // if (n->iskey && !n->isnull) { + // nodeData* nodedata = (nodeData*) raxGetData(n); + // delete nodedata; + // } + // if (n->issubtree && n->iscustomallocated && !n->iscustomnull) { + // customData* customdata = (customData*) raxGetCustomData(n); + // delete customdata; + // } + // }); +} + +std::shared_ptr RadixTree::Insert( + std::vector tokens, std::shared_ptr& evictedNode) { + tokens.insert(tokens.begin(), INT32_MAX); + return InsertInternal(tokens, evictedNode); +} + +void RadixTree::Delete(std::vector tokens, + std::shared_ptr& evictedNode) { + tokens.insert(tokens.begin(), INT32_MAX); + DeleteInternal(tokens, evictedNode); +} + +std::shared_ptr RadixTree::Query(std::vector key) { + key.insert(key.begin(), INT32_MAX); + return QueryInternal(key); +} + +std::vector> RadixTree::Split( + std::vector tokens, std::shared_ptr& header) { + tokens.insert(tokens.begin(), INT32_MAX); + return SplitInternal(tokens, header); +} + +std::shared_ptr RadixTree::InsertInternal( + std::vector tokens, std::shared_ptr& evictedNode) { + // get the 
sub vector of the tokens + std::vector rootToken = + std::vector(tokens.begin(), tokens.end() - 1); + if (rootToken.size() > 0 && QueryInternal(rootToken) == nullptr) { + return nullptr; + } + + // insert the token vector to the radix tree + int* insertTokensArray = tokens.data(); + size_t insertTokensArrayLen = tokens.size(); + DataWrapper* dummyData = new DataWrapper(); + DataWrapper* oldData; + raxNode* dataNode = NULL; + int retval = raxInsertAndReturnDataNode( + this->tree, insertTokensArray, insertTokensArrayLen, dummyData, + (void**) &dataNode, (void**) &oldData); + if (dataNode == NULL) { + throw std::runtime_error("Insert token list failed"); + return NULL; + } + if (retval == 1) { + LOG(INFO) << "node count++:" << this->nodeCount; + nodeCount++; + } + + raxShow(this->tree); + if (this->nodeCount > this->cacheCapacity) { + LOG(INFO) << "cache capacity is full, evict the last recent node"; + LOG(INFO) << "cache capacity:" << this->cacheCapacity + << " node count:" << this->nodeCount; + // evict the last recent node (the node with the largest lru index- + std::vector evictedTokensVector; + raxFindLastRecentNode(this->tree->head, evictedTokensVector); + std::string evicted_str = ""; + for (size_t i = 0; i < evictedTokensVector.size(); i++) { + evicted_str += std::to_string(evictedTokensVector[i]); + } + this->DeleteInternal(evictedTokensVector, evictedNode); + } + + raxNode* subTreeNode = nullptr; + dataNode = raxFindAndReturnDataNode( + this->tree, insertTokensArray, insertTokensArrayLen, &subTreeNode, false); + LOG(INFO) << "sub tree node:" << subTreeNode << " data node:" << dataNode; + /** + * if the data node is null, it means the evicted node is the same node as + * the inserted node. 
+ */ + if (dataNode == NULL) { + LOG(INFO) << "get failed"; + return NULL; + } + + if (subTreeNode == nullptr) { + return std::make_shared(dummyData, nullptr); + } + return std::make_shared(dummyData, + (DataWrapper*) subTreeNode->custom_data); +} + +void RadixTree::DeleteInternal(std::vector tokens, + std::shared_ptr& evictedNode) { + // remove the token vector from the radix tree + // TBD + // If the evicted node is the root node of sub tree, recycle the tree. + int* deleteTokensArray = tokens.data(); + size_t deleteTokensArrayLen = tokens.size(); + + DataWrapper* oldData; + raxNode* subTreeNode; + std::vector pre; + raxFindAndReturnDataNode(this->tree, deleteTokensArray, deleteTokensArrayLen, + &subTreeNode, false); + int retval = raxRemove(this->tree, deleteTokensArray, deleteTokensArrayLen, + (void**) &oldData, &subTreeNode); + if (retval == 1) { + evictedNode = std::make_shared( + oldData, (DataWrapper*) subTreeNode->custom_data); + nodeCount--; + } else { + LOG(INFO) << "remove failed"; + } +} + +std::shared_ptr RadixTree::QueryInternal(std::vector key) { + LOG(INFO) << "Query"; + int* tokens = key.data(); + size_t tokensLen = key.size(); + + if (this->tree == nullptr) { + return NULL; + } + + raxNode* subTreeNode; + raxNode* dataNode = + raxFindAndReturnDataNode(this->tree, tokens, tokensLen, &subTreeNode); + LOG(INFO) << "query subtree node:" << subTreeNode; + if (dataNode == NULL) { + LOG(INFO) << "get failed"; + return NULL; + } + + return std::make_shared((DataWrapper*) raxGetData(dataNode), + (DataWrapper*) subTreeNode->custom_data); +} + +std::string RadixTree::Serialize() { + LOG(INFO) << "Serialize......"; + raxShow(this->tree); + std::vector> tokenList; + std::vector dataList; + std::vector timestampList; + std::vector> subTreeTokenList; + std::vector subTreeDataList; + raxSerialize(this->tree, tokenList, dataList, timestampList, + &subTreeTokenList, &subTreeDataList); + + raxShow(this->tree); + std::string serializedStr; + + if (tokenList.size() 
!= dataList.size()) { + throw std::runtime_error( + "The size of token list and data list is not equal"); + } + for (size_t index = 0; index < tokenList.size(); index++) { + for (size_t j = 0; j < tokenList[index].size(); j++) { + serializedStr += std::to_string(tokenList[index][j]); + if (j < tokenList[index].size() - 1) { + serializedStr += ","; + } + } + serializedStr += "|"; + + // convert timestamp(uint64) to hex string + uint64_t timestamp = timestampList[index]; + std::ostringstream timestampOSS; + timestampOSS << std::hex << timestamp; + + serializedStr += timestampOSS.str() + "|"; + + // convert data to hex string + char* bytes = (char*) ((DataWrapper*) dataList[index])->data; + std::ostringstream dataOSS; + + for (int i = 0; i < ((DataWrapper*) dataList[index])->dataLength; i++) { + dataOSS << std::hex << std::setw(2) << std::setfill('0') + << static_cast(static_cast(bytes[i])); + } + serializedStr += dataOSS.str() + "\n"; + } + + serializedStr += "\t\n"; + + LOG(INFO) << "sub tree token list size:" << subTreeTokenList.size(); + for (size_t index = 0; index < subTreeTokenList.size(); index++) { + for (size_t j = 0; j < subTreeTokenList[index].size(); j++) { + serializedStr += std::to_string(subTreeTokenList[index][j]); + if (j < subTreeTokenList[index].size() - 1) { + serializedStr += ","; + } + } + serializedStr += "|"; + + // convert custom data to hex string + char* bytes = (char*) ((DataWrapper*) subTreeDataList[index])->data; + std::ostringstream dataOSS; + + LOG(INFO) << "data lengtÏ€h:" + << ((DataWrapper*) subTreeDataList[index])->dataLength; + for (int i = 0; i < ((DataWrapper*) subTreeDataList[index])->dataLength; + ++i) { + dataOSS << std::hex << std::setw(2) << std::setfill('0') + << static_cast(static_cast(bytes[i])); + } + LOG(INFO) << "data:" << ((DataWrapper*) subTreeDataList[index])->data; + LOG(INFO) << "data oss:" << dataOSS.str(); + serializedStr += dataOSS.str() + "\n"; + } + LOG(INFO) << "serializedStr:" << serializedStr; + + // use 
LZ4 to compress the serialized string + const char* const src = serializedStr.c_str(); + const int srcSize = serializedStr.size(); + const int maxDstSize = LZ4_compressBound(srcSize); + char* compressedData = new char[maxDstSize]; + if (compressedData == NULL) { + LOG(INFO) << "Failed to allocate memory for *compressedData."; + } + + const int compressedDataSize = + LZ4_compress_default(src, compressedData, srcSize, maxDstSize); + if (compressedDataSize <= 0) { + LOG(INFO) << "A 0 or negative result from LZ4_compress_default() " + "indicates a failure trying to compress the data. "; + } + + if (compressedDataSize > 0) { + LOG(INFO) << "We successfully compressed some data! Ratio: " + << ((float) compressedDataSize / srcSize); + } + + if (compressedData == NULL) { + LOG(INFO) << "Failed to re-alloc memory for compressedData. Sad :("; + } + + std::string compressedStr = std::string(compressedData, compressedDataSize); + std::string result = std::string((char*) &srcSize, sizeof(int)) + + std::string((char*) &cacheCapacity, sizeof(int)) + + compressedStr; + delete[] compressedData; + return result; +} + +std::shared_ptr RadixTree::Deserialize(std::string data) { + LOG(INFO) << "Deserialize......"; + // use LZ4 to decompress the serialized string + int srcSize = *(int*) data.c_str(); + data.erase(0, sizeof(int)); + int cacheCapacity = *(int*) data.c_str(); + data.erase(0, sizeof(int)); + char* const decompressBuffer = new char[srcSize]; + if (decompressBuffer == NULL) { + LOG(INFO) << "Failed to allocate memory for *decompressBuffer."; + } + + const int decompressedSize = + LZ4_decompress_safe(data.c_str(), decompressBuffer, data.size(), srcSize); + if (decompressedSize < 0) { + LOG(INFO) << "A negative result from LZ4_decompress_safe indicates a " + "failure trying to decompress the data. See exit code " + "(echo $?) 
for value returned."; + } + if (decompressedSize >= 0) { + LOG(INFO) << "We successfully decompressed some data!"; + } + // if (decompressedSize != data.size()) { + // LOG(INFO) << "Decompressed data is different from original! \n"; + // } + data = std::string(decompressBuffer, decompressedSize); + delete[] decompressBuffer; + + std::vector> tokenList; + std::vector dataList; + std::vector dataSizeList; + std::vector timestampList; + std::vector> subTreeTokenList; + std::vector subTreeDataList; + std::vector subTreeDataSizeList; + std::istringstream iss(data); + std::string line; + bool isMainTree = true; + + while (std::getline(iss, line)) { + if (!line.empty() && line[0] == '\t') { + isMainTree = false; + line.pop_back(); + continue; + } + LOG(INFO) << "data line:" << line << std::endl; + std::istringstream lineStream(line); + std::string tokenListPart, timestampPart, dataPart; + + if (!std::getline(lineStream, tokenListPart, '|')) { + throw std::runtime_error( + "Invalid serialized string format in token list part."); + } + if (isMainTree) { + if (!std::getline(lineStream, timestampPart, '|')) { + throw std::runtime_error( + "Invalid serialized string format in timestamp part."); + } + } + if (!std::getline(lineStream, dataPart)) { + LOG(INFO) << "data length is 0"; + } + + std::istringstream keyStream(tokenListPart); + std::string token; + std::vector keys; + while (std::getline(keyStream, token, ',')) { + keys.push_back(std::stoi(token)); + } + + uint64_t timestamp; + if (isMainTree) { + std::istringstream timestampStream(timestampPart); + if (!(timestampStream >> std::hex >> timestamp)) { + LOG(INFO) << "Invalid timestamp format."; + throw std::runtime_error("Invalid timestamp format."); + } + } + + size_t dataSize = dataPart.length() / + 2; // Each byte is represented by two hex characters + if (isMainTree) { + dataSizeList.push_back(dataSize); + } else { + subTreeDataSizeList.push_back(dataSize); + } + // This pointer will be freed by upper layer. 
Because this data + // is created by upper layer. Here just recover it from serialized + // string. + char* data = nullptr; + LOG(INFO) << "data size:" << dataSize; + if (dataSize != 0) { + data = new char[dataSize]; + std::istringstream dataStream(dataPart); + for (size_t i = 0; i < dataSize; ++i) { + // Temporary buffer to store two hexadecimal chars + null + char hex[3] = {}; + // Read two characters for one byte + if (!dataStream.read(hex, 2)) { + delete[] data; + LOG(INFO) << "Invalid data format."; + throw std::runtime_error("Invalid data format."); + } + // Convert the two hex characters to one byte + unsigned int byte; + std::istringstream hexStream(hex); + if (!(hexStream >> std::hex >> byte)) { + delete[] data; + LOG(INFO) << "Invalid data format."; + throw std::runtime_error("Invalid data format."); + } + reinterpret_cast(data)[i] = + static_cast(byte); + } + } + if (isMainTree) { + tokenList.push_back(keys); + timestampList.push_back(timestamp); + dataList.push_back(data); + } else { + subTreeTokenList.push_back(keys); + subTreeDataList.push_back(data); + } + } + + // This pointer will be freed by upper layer. Because this data + // is created by upper layer. Here just recover it from serialized + // string. 
+ std::shared_ptr radixTree = + std::make_shared(cacheCapacity); + radixTree->nodeCount = tokenList.size(); + + raxShow(radixTree->tree); + for (size_t i = 0; i < tokenList.size(); i++) { + std::string token_str = ""; + for (size_t j = 0; j < tokenList[i].size(); j++) { + token_str += std::to_string(tokenList[i][j]); + } + LOG(INFO) << "token:" << token_str; + int* insertTokensArray = tokenList[i].data(); + size_t insertTokensArrayLen = tokenList[i].size(); + DataWrapper* data = new DataWrapper(); + data->data = dataList[i]; + data->dataLength = dataSizeList[i]; + raxNode* dataNode = NULL; + + // TBD + // check retval + raxInsertAndReturnDataNode(radixTree->tree, insertTokensArray, + insertTokensArrayLen, data, (void**) &dataNode, + NULL); + + if (dataNode == NULL) { + throw std::runtime_error("Insert token list failed"); + } + dataNode->timestamp = timestampList[i]; + } + LOG(INFO) << "start to insert sub tree token list" << std::endl; + for (size_t i = 0; i < subTreeTokenList.size(); i++) { + for (size_t j = 0; j < subTreeTokenList[i].size(); j++) { + LOG(INFO) << subTreeTokenList[i][j]; + } + + raxNode* node = nullptr; + LOG(INFO) << "stage 1"; + VINEYARD_ASSERT(radixTree->tree != nullptr); + + // TBD refator this code. + node = raxFindAndReturnDataNode(radixTree->tree, subTreeTokenList[i].data(), + subTreeTokenList[i].size(), NULL, false); + VINEYARD_ASSERT(node != nullptr); + LOG(INFO) << "stage 2"; + DataWrapper* data = new DataWrapper(); + data->data = subTreeDataList[i]; + LOG(INFO) << subTreeDataList[i]; + data->dataLength = subTreeDataSizeList[i]; + + LOG(INFO) << "stage 3"; + node->issubtree = true; + raxSetCustomData(node, data); + + // TBD + // refactor this code. 
+ radixTree->subTreeDataSet.insert(data); + } + LOG(INFO) << "Deserialize success"; + return radixTree; +} + +std::vector> RadixTree::SplitInternal( + std::vector tokens, std::shared_ptr& header) { + std::vector rootToken; + DataWrapper* dummyData = new DataWrapper(); + raxNode* subTreeRootNode = + raxSplit(this->tree, tokens.data(), tokens.size(), dummyData, rootToken); + + raxShow(this->tree); + subTreeRootNode->issubtree = true; + DataWrapper* treeData = new DataWrapper(); + treeData->data = nullptr; + treeData->dataLength = 0; + subTreeRootNode->custom_data = treeData; + header = std::make_shared( + (DataWrapper*) raxGetData(subTreeRootNode), treeData); + return TraverseTreeWithoutSubTree(subTreeRootNode); +} + +// Get child node list from this tree. +std::vector> RadixTree::TraverseTreeWithoutSubTree( + raxNode* headNode) { + std::vector> nodes; + if (headNode == NULL) { + LOG(INFO) << "traverse failed"; + return nodes; + } + + std::vector dataNodeList; + std::vector pre_tmp; + raxTraverseSubTree(headNode, dataNodeList); + LOG(INFO) << "data node list:" << dataNodeList.size(); + for (size_t i = 0; i < dataNodeList.size(); i++) { + nodes.push_back(std::make_shared( + (DataWrapper*) raxGetData(dataNodeList[i]), + (DataWrapper*) dataNodeList[i]->custom_data)); + } + return nodes; +} + +void RadixTree::SetSubtreeData(void* data, int dataLength) { + LOG(INFO) << "set subtree data"; + DataWrapper* dataWrapper = new DataWrapper(); + dataWrapper->data = data; + dataWrapper->dataLength = dataLength; + subTreeDataSet.insert(dataWrapper); +} + +std::shared_ptr RadixTree::GetRootNode() { + raxNode* node = raxFindAndReturnDataNode(this->tree, rootToken.data(), + rootToken.size(), NULL); + return std::make_shared((DataWrapper*) raxGetData(node), + (DataWrapper*) node->custom_data); +} \ No newline at end of file diff --git a/modules/kv-state-cache/radix-tree/radix-tree.h b/modules/kv-state-cache/radix-tree/radix-tree.h index 0c60eb47..bcea8e67 100644 --- 
a/modules/kv-state-cache/radix-tree/radix-tree.h +++ b/modules/kv-state-cache/radix-tree/radix-tree.h @@ -20,627 +20,83 @@ limitations under the License. #include "common/util/base64.h" #include "common/util/logging.h" -#include "kv-state-cache/strategy/LRU_strategy.h" #include "lz4.h" #include #include #include +#include #include using namespace vineyard; -typedef struct customData { - int data_length; +struct DataWrapper { void* data; -} customData; - -typedef struct nodeData { - int data_length; - void* data; - //std::shared_ptr cache_node; -} nodeData; - -class Node { - private: - nodeData* data; - raxNode* node; - - public: - Node(raxNode* node) { - this->data = (nodeData*) raxGetData(node); - this->node = node; - } - - Node(nodeData* data) { - this->data = data; - this->node = NULL; - } - - void set_data(void* data, int data_length) { - if (this->node == NULL) { - LOG(INFO) << "set data failed, node is null"; - return; - } - this->data->data = data; - this->data->data_length = data_length; - raxSetData(this->node, this->data); - } - - //void set_cache_node(std::shared_ptr cache_node) { - // if (this->node == NULL) { - // LOG(INFO) << "set data failed, node is null"; - // return; - // } - // this->data->cache_node = cache_node; - // raxSetData(this->node, this->data); - //} - - void* get_data() { return this->data->data; } - - int get_data_length() { return this->data->data_length; } - - //std::shared_ptr get_cache_node() { - // return this->data->cache_node; - //} + int dataLength; }; -class RadixTree; +struct NodeData { + DataWrapper* nodeData; + DataWrapper* treeData; -class NodeWithTreeAttri { - private: - std::shared_ptr node; - std::shared_ptr belong_to; - - public: - NodeWithTreeAttri(std::shared_ptr node, - std::shared_ptr belong_to) { - this->node = node; - this->belong_to = belong_to; + NodeData(DataWrapper* nodeData, DataWrapper* treeData) { + this->nodeData = nodeData; + this->treeData = treeData; } - - std::shared_ptr get_node() { return node; } - 
- std::shared_ptr get_tree() { return belong_to; } }; class RadixTree : public std::enable_shared_from_this { - private: - // the whole radix tree for prefix match - rax* tree; - // the sub tree for mapping a vineyard object - // rax* sub_tree; - LRUStrategy* lru_strategy; - public: - RadixTree(int cache_capacity) { - LOG(INFO) << "init radix tree"; - this->tree = raxNew(); - this->tree->head->issubtree = true; - lru_strategy = new LRUStrategy(cache_capacity); - } - - RadixTree(rax* rax_tree, int cache_capacity) { - LOG(INFO) << "init radix tree"; - this->tree = rax_tree; - // this->sub_tree = this->tree; - this->tree->head->issubtree = true; - lru_strategy = new LRUStrategy(cache_capacity); - } - - RadixTree(void* custom_data, int custom_data_length, int cache_capacity) { - LOG(INFO) << "init radix tree with custom data"; - this->tree = raxNew(); - this->tree->head->issubtree = true; - customData* custom_data_struct = new customData(); - custom_data_struct->data = custom_data; - custom_data_struct->data_length = custom_data_length; - raxSetCustomData(this->tree->head, custom_data_struct); - this->lru_strategy = new LRUStrategy(cache_capacity); - } - - ~RadixTree() { - // raxFreeWithCallback(this->tree, [](raxNode *n) { - // if (n->iskey && !n->isnull) { - // nodeData* nodedata = (nodeData*) raxGetData(n); - // delete nodedata; - // } - // if (n->issubtree && n->iscustomallocated && !n->iscustomnull) { - // customData* customdata = (customData*) raxGetCustomData(n); - // delete customdata; - // } - // }); - } - - std::shared_ptr Insert( - std::vector tokens, - std::shared_ptr evicted_node) { - // insert the token vector to the radix tree - int* insert_tokens_array = tokens.data(); - size_t insert_tokens_array_len = tokens.size(); - nodeData* dummy_data = new nodeData(); - nodeData* old_data; - raxNode* dataNode = NULL; - int retval = raxInsertAndReturnDataNode( - this->tree, insert_tokens_array, insert_tokens_array_len, dummy_data, - (void**) &dataNode, (void**) 
&old_data); - if (dataNode == NULL) { - throw std::runtime_error("Insert token list failed"); - return NULL; - } - LOG(INFO) << "insert success"; - - //if (retval == 0) { - // (retval == 0 ) means the token vector already exists in the radix tree - // remove the token vector from the lru cache as it will be inserted again - // std::shared_ptr node = std::make_shared(old_data); - // LOG(INFO) << "delete cache_node"; - // std::shared_ptr cache_node = node->get_cache_node(); - // LOG(INFO) << "delete cache_node 1"; - // lru_strategy->Remove(cache_node); - // LOG(INFO) << "delete cache_node 2"; - // delete old_data; - //} - //LOG(INFO) << "delete cache_node success"; - - // refresh the lru cache - //std::vector evicted_tokens; - //std::shared_ptr cache_node = - // lru_strategy->InsertToHeader(tokens, evicted_tokens); - //if (cache_node == nullptr) { - // LOG(INFO) << "WTF?"; - //} - //dummy_data->cache_node = cache_node; - raxSetData(dataNode, dummy_data); - //if (evicted_tokens.size() > 0) { - // this->Delete(evicted_tokens, evicted_node); - //} - //LOG(INFO) << "refresh cache_node success"; - - return std::make_shared(std::make_shared(dataNode), - shared_from_this()); - } - - void Delete(std::vector tokens, - std::shared_ptr& evicted_node) { - // remove the token vector from the radix tree - int* delete_tokens_array = tokens.data(); - size_t delete_tokens_array_len = tokens.size(); - - nodeData* old_data; - int retval = raxRemove(this->tree, delete_tokens_array, - delete_tokens_array_len, (void**) &old_data); - if (retval == 1) { - LOG(INFO) << "remove success"; - std::shared_ptr node = std::make_shared(old_data); - evicted_node = - std::make_shared(node, shared_from_this()); - delete old_data; - } else { - LOG(INFO) << "remove failed"; - } - } - - std::shared_ptr Query(std::vector key) { - LOG(INFO) << "Query"; - int* tokens = key.data(); - size_t tokens_len = key.size(); - - LOG(INFO) << "Query with tokens_len:" << tokens_len; - if (this->tree == nullptr) { - 
LOG(INFO) << "WTF!"; - return NULL; - } - - raxNode* dataNode = - raxFindAndReturnDataNode(this->tree, tokens, tokens_len); - if (dataNode == NULL) { - LOG(INFO) << "get failed"; - return NULL; - } - LOG(INFO) << "get success"; - - // refresh the lru cache - std::shared_ptr node = std::make_shared(dataNode); - //std::shared_ptr cache_node = node->get_cache_node(); - //lru_strategy->MoveToHead(cache_node); - - return std::make_shared(node, shared_from_this()); - } - - std::string Serialize() { - LOG(INFO) << "Serialize......"; - raxShow(this->tree); - std::vector> token_list; - std::vector data_list; - std::vector timestamp_list; - std::vector> sub_tree_token_list; - std::vector sub_tree_data_list; - raxSerialize(this->tree, token_list, data_list, timestamp_list, &sub_tree_token_list, - &sub_tree_data_list); - - raxShow(this->tree); - //std::map, bool> cache_node_map; - //std::shared_ptr current_node = - // this->lru_strategy->GetHeader(); - - // the string format is: - // [token list]|[timestamp]|[data hex string]\n - // ... - // [token list]|[timestamp]|[data hex string]\n - // \t\n - // [subtree token list]|[timestamp]|[custom data string]\n - // ... 
- // [subtree token list]|[timestamp]|[custom data string]\n - // E.g - // tokens | data - // 1|0000000001|0800000008000000xxxx\n - // 1,2|0000000002|0800000008000000xxxx\n - // 1,2,3|0000000002|0800000008000000xxxx\n - // \t\n - // 1,2|0000000003|0800000008000000xxxx\n - std::string serialized_str; - /* - while (current_node != nullptr) { - cache_node_map[current_node] = true; - auto it = std::lower_bound(token_list.begin(), token_list.end(), - current_node->tokens); - - if (it != token_list.end() && *it == current_node->tokens) { - // get the index of the token vector via binary search - int index = std::distance(token_list.begin(), it); - for (size_t i = 0; i < (*it).size(); i++) { - serialized_str += std::to_string((*it)[i]); - if (i < (*it).size() - 1) { - serialized_str += ","; - } - } - // serialized_str += "|" + std::to_string(index) + "|"; - serialized_str += "|"; - - // convert data to hex string - char* bytes = (char*) ((nodeData*) data_list[index])->data; - std::ostringstream oss; - - for (int i = 0; i < ((nodeData*) data_list[index])->data_length; ++i) { - oss << bytes[i]; - } - serialized_str += oss.str() + "\n"; - } else { - throw std::runtime_error("The token vector is not in the radix tree"); - } - current_node = current_node->next; - } - */ - - if (token_list.size() != data_list.size()) { - throw std::runtime_error("The size of token list and data list is not equal"); - } - for (size_t index = 0; index < token_list.size(); index++) { - for (size_t j = 0; j < token_list[index].size(); j++) { - serialized_str += std::to_string(token_list[index][j]); - if (j < token_list[index].size() - 1) { - serialized_str += ","; - } - } - serialized_str += "|"; - - // convert timestamp(uint64) to hex string - uint64_t timestamp = timestamp_list[index]; - std::ostringstream timestamp_oss; - timestamp_oss << std::hex << timestamp; - - serialized_str += timestamp_oss.str() + "|"; - - // convert data to hex string - char* bytes = (char*) ((nodeData*) 
data_list[index])->data; - std::ostringstream data_oss; - - for (size_t i = 0; i < ((nodeData*)data_list[index])->data_length; i++) { - data_oss << std::hex << std::setw(2) << std::setfill('0') << static_cast(static_cast(bytes[i])); - } - serialized_str += data_oss.str() + "\n"; - } - - serialized_str += "\t\n"; - - LOG(INFO) << "sub tree token list size:" << sub_tree_token_list.size(); - for (size_t index = 0; index < sub_tree_token_list.size(); index++) { - for (size_t j = 0; j < sub_tree_token_list[index].size(); j++) { - serialized_str += std::to_string(sub_tree_token_list[index][j]); - if (j < sub_tree_token_list[index].size() - 1) { - serialized_str += ","; - } - } - serialized_str += "|"; - // convert custom data to hex string - char* bytes = (char*) ((customData*) sub_tree_data_list[index])->data; - std::ostringstream data_oss; - - LOG(INFO) << "data length:" << ((customData*)sub_tree_data_list[index])->data_length; - for (size_t i = 0; i < ((customData*)sub_tree_data_list[index])->data_length; ++i) { - data_oss << std::hex << std::setw(2) << std::setfill('0') << static_cast(static_cast(bytes[i])); - } - LOG(INFO) << "data:" << ((customData*)sub_tree_data_list[index])->data; - LOG(INFO) << "data oss:" << data_oss.str(); - serialized_str += data_oss.str() + "\n"; - } - LOG(INFO) << "serialized_str:" << serialized_str; - - // use LZ4 to compress the serialized string - const char* const src = serialized_str.c_str(); - const int src_size = serialized_str.size(); - const int max_dst_size = LZ4_compressBound(src_size); - char* compressed_data = new char[max_dst_size]; - if (compressed_data == NULL) { - LOG(INFO) << "Failed to allocate memory for *compressed_data."; - } - - const int compressed_data_size = - LZ4_compress_default(src, compressed_data, src_size, max_dst_size); - if (compressed_data_size <= 0) { - LOG(INFO) << "A 0 or negative result from LZ4_compress_default() " - "indicates a failure trying to compress the data. 
"; - } - - if (compressed_data_size > 0) { - LOG(INFO) << "We successfully compressed some data! Ratio: " - << ((float) compressed_data_size / src_size); - } - - // compressed_data = - // (char*) realloc(compressed_data, (size_t) compressed_data_size); - if (compressed_data == NULL) { - LOG(INFO) << "Failed to re-alloc memory for compressed_data. Sad :("; - } - - std::string compressed_str = - std::string(compressed_data, compressed_data_size); - std::string result = - std::string((char*) &src_size, sizeof(int)) + compressed_str; - delete[] compressed_data; - return result; - } - - static std::shared_ptr Deserialize(std::string data) { - LOG(INFO) << "Deserialize......"; - // use LZ4 to decompress the serialized string - int src_size = *(int*) data.c_str(); - data.erase(0, sizeof(int)); - char* const decompress_buffer = new char[src_size]; - if (decompress_buffer == NULL) { - LOG(INFO) << "Failed to allocate memory for *decompress_buffer."; - } - - const int decompressed_size = LZ4_decompress_safe( - data.c_str(), decompress_buffer, data.size(), src_size); - if (decompressed_size < 0) { - LOG(INFO) << "A negative result from LZ4_decompress_safe indicates a " - "failure trying to decompress the data. See exit code " - "(echo $?) for value returned."; - } - if (decompressed_size >= 0) { - LOG(INFO) << "We successfully decompressed some data!"; - } - // if (decompressed_size != data.size()) { - // LOG(INFO) << "Decompressed data is different from original! 
\n"; - // } - data = std::string(decompress_buffer, decompressed_size); - delete[] decompress_buffer; + rax* tree; + int cacheCapacity; + int nodeCount; + std::set subTreeDataSet; + std::vector rootToken; - std::vector> token_list; - std::vector data_list; - std::vector data_size_list; - std::vector timestamp_list; - std::vector> sub_tree_token_list; - std::vector sub_tree_data_list; - std::vector sub_tree_data_size_list; - std::istringstream iss(data); - std::string line; - bool isMainTree = true; + private: + std::shared_ptr InsertInternal( + std::vector tokens, std::shared_ptr& evictedNode); - while (std::getline(iss, line)) { - if (!line.empty() && line[0] == '\t') { - isMainTree = false; - line.pop_back(); - continue; - } - LOG(INFO) << "data line:" << line << std::endl; - std::istringstream lineStream(line); - std::string tokenListPart, timestampPart, dataPart; + void DeleteInternal(std::vector tokens, + std::shared_ptr& evictedNode); - if (!std::getline(lineStream, tokenListPart, '|')) { - throw std::runtime_error( - "Invalid serialized string format in token list part."); - } - if (isMainTree) { - if (!std::getline(lineStream, timestampPart, '|')) { - throw std::runtime_error( - "Invalid serialized string format in timestamp part."); - } - } - if (!std::getline(lineStream, dataPart)) { - throw std::runtime_error( - "Invalid serialized string format in data part."); - } + std::shared_ptr QueryInternal(std::vector key); - std::istringstream keyStream(tokenListPart); - std::string token; - std::vector keys; - while (std::getline(keyStream, token, ',')) { - keys.push_back(std::stoi(token)); - } + std::vector> SplitInternal( + std::vector tokens, std::shared_ptr& header); - uint64_t timestamp; - if (isMainTree) { - std::istringstream timestampStream(timestampPart); - if (!(timestampStream >> std::hex >> timestamp)) { - LOG(INFO) << "Invalid timestamp format."; - throw std::runtime_error("Invalid timestamp format."); - } - } + public: + RadixTree(int 
cacheCapacity); + ~RadixTree(); - size_t dataSize = dataPart.length() / 2; // Each byte is represented by two hex characters - if (isMainTree) { - data_size_list.push_back(dataSize); - } else { - sub_tree_data_size_list.push_back(dataSize); - } - // This pointer will be freed by upper layer. Because this data - // is created by upper layer. Here just recover it from serialized - // string. - char* data = new char[dataSize]; - LOG(INFO) << "data size:" << dataSize; - std::istringstream dataStream(dataPart); - for (size_t i = 0; i < dataSize; ++i) { - // Temporary buffer to store two hexadecimal chars + null - char hex[3] = {}; - // Read two characters for one byte - if (!dataStream.read(hex, 2)) { - delete[] data; - LOG(INFO) << "Invalid data format."; - throw std::runtime_error("Invalid data format."); - } - // Convert the two hex characters to one byte - unsigned int byte; - std::istringstream hexStream(hex); - if (!(hexStream >> std::hex >> byte)) { - delete[] data; - LOG(INFO) << "Invalid data format."; - throw std::runtime_error("Invalid data format."); - } - reinterpret_cast(data)[i] = static_cast(byte); - } - if (isMainTree) { - token_list.push_back(keys); - timestamp_list.push_back(timestamp); - data_list.push_back(data); - } else { - sub_tree_token_list.push_back(keys); - sub_tree_data_list.push_back(data); - } - } + std::shared_ptr Insert(std::vector tokens, + std::shared_ptr& evictedNode); - // This pointer will be freed by upper layer. Because this data - // is created by upper layer. Here just recover it from serialized - // string. 
- std::shared_ptr radix_tree = std::make_shared(10); - // nodeData* dummy_data = new nodeData(); - // rax *root = raxNew(); - // for (int i = token_list.size()-1; i >= 0; i--) { - // LOG(INFO) << "insert token list:"; - // for (int j = 0; j < token_list[i].size(); j++) { - // LOG(INFO) << token_list[i][j]; - // } - // if (raxInsert(root, token_list[i].data(), token_list[i].size(), - // dummy_data, NULL) != 1) { - // LOG(INFO) << "Insert failed"; - // return NULL; - // } - // std::vector evicted_tokens; - // std::shared_ptr cache_node = - // radix_tree->lru_strategy->InsertToHeader(token_list[i], - // evicted_tokens); - // if (cache_node == nullptr) { - // LOG(INFO) << "WTF?"; - // } - // dummy_data->cache_node = cache_node; - // } + void Delete(std::vector tokens, std::shared_ptr& evictedNode); + std::shared_ptr Query(std::vector key); - for (int i = 0; i < token_list.size(); i++) { - int* insert_tokens_array = token_list[i].data(); - size_t insert_tokens_array_len = token_list[i].size(); - nodeData* data = new nodeData(); - raxNode* dataNode = NULL; - int retval = raxInsertAndReturnDataNode( - radix_tree->tree, insert_tokens_array, insert_tokens_array_len, data, - (void**) &dataNode, NULL); - if (dataNode == NULL) { - throw std::runtime_error("Insert token list failed"); - } - dataNode->timestamp = timestamp_list[i]; - std::shared_ptr node = std::make_shared(std::make_shared(dataNode), - radix_tree); - node->get_node()->set_data(data_list[i], data_size_list[i]); - } - LOG(INFO) << "start to insert sub tree token list" << std::endl; - for (int i = 0; i < sub_tree_token_list.size(); i++) { - for (int j = 0; j < sub_tree_token_list[i].size(); j++) { - LOG(INFO) << sub_tree_token_list[i][j]; - } + std::vector> Split( + std::vector tokens, std::shared_ptr& header); - raxNode* node = nullptr; - LOG(INFO) << "stage 1"; - VINEYARD_ASSERT(radix_tree->tree != nullptr); - raxFindNode(radix_tree->tree, sub_tree_token_list[i].data(), - sub_tree_token_list[i].size(), (void 
**)&node); - VINEYARD_ASSERT(node != nullptr); - LOG(INFO) << "stage 2"; - customData* data = new customData(); - data->data = sub_tree_data_list[i]; - data->data_length = sub_tree_data_size_list[i]; + std::string Serialize(); - LOG(INFO) << "stage 3"; - node->issubtree = true; - raxSetCustomData(node, data); - } - LOG(INFO) << "Deserialize success"; - return radix_tree; - } + static std::shared_ptr Deserialize(std::string data); - std::shared_ptr Split(std::vector tokens) { - nodeData* dummy_data = new nodeData(); - raxNode* sub_tree_root_node = - raxSplit(this->tree, tokens.data(), tokens.size(), dummy_data); + // Get child node list from this tree. + static std::vector> TraverseTreeWithoutSubTree( + raxNode* headNode); - // TBD - // if the sub_tree is null, delete this pointer. - rax* sub_rax = raxNew(); - sub_rax->head = sub_tree_root_node; - std::shared_ptr sub_tree = - std::make_shared(sub_rax, this->lru_strategy->GetCapacity()); - return sub_tree; - } + void SetSubtreeData(void* data, int dataLength); - // Get child node list from this tree. 
- static std::vector> - TraverseTreeWithoutSubTree(std::shared_ptr radix_tree) { - std::vector> nodes; - if (radix_tree == NULL) { - LOG(INFO) << "traverse failed"; - return nodes; - } + rax* GetRootTree() { return this->tree; } - std::vector dataNodeList; - raxNode* headNode = radix_tree->tree->head; - raxTraverseSubTree(headNode, dataNodeList); - for (size_t i = 0; i < dataNodeList.size(); i++) { - nodes.push_back(std::make_shared( - std::make_shared(dataNodeList[i]), radix_tree)); - } - return nodes; - } + int GetCacheCapacity() { return cacheCapacity; } - rax* GetTree() {return this->tree;} - void* GetCustomData() { - LOG(INFO) << "tree:" << this->tree << " tree node:" << this->tree->head; - VINEYARD_ASSERT(tree->head->custom_data != nullptr); - LOG(INFO) << "custom data:" << ((customData *)tree->head->custom_data)->data; - return ((customData *)tree->head->custom_data)->data; - } + std::set GetSubTreeDataSet() { return subTreeDataSet; } - void SetCustomData(void* custom_data, int custom_data_length) { - customData* data = new customData(); - data->data = custom_data; - LOG(INFO) << "custom data:" << data->data; - data->data_length = custom_data_length; - LOG(INFO) << "custom data length:" << data->data_length; - LOG(INFO) << "tree:" << this->tree << " tree node:" << this->tree->head; - raxSetCustomData(this->tree->head, data); - } + std::shared_ptr GetRootNode(); }; #endif diff --git a/modules/kv-state-cache/radix-tree/radix.cc b/modules/kv-state-cache/radix-tree/radix.cc index 87067fbc..428fa819 100644 --- a/modules/kv-state-cache/radix-tree/radix.cc +++ b/modules/kv-state-cache/radix-tree/radix.cc @@ -36,6 +36,8 @@ #include #include #include +#include + #include "radix.h" #ifndef RAX_MALLOC_INCLUDE @@ -47,11 +49,6 @@ #include #include "common/util/logging.h" using namespace vineyard; -typedef struct nodeData1 { - int data_length; - void* data; - void* cache_node; -} nodeData1; /* This is a special pointer that is guaranteed to never have the same value * of 
a radix tree node. It's used in order to report "not found" error without @@ -157,7 +154,7 @@ static inline void raxStackFree(raxStack *ts) { /* Add the number of nodes in the stack to each node. */ void raxStackAddNumNodes(raxStack *stack, int num) { - for (int i=0; iitems; i++) { + for (size_t i=0; iitems; i++) { raxNode *node = (raxNode *)stack->stack[i]; node->numnodes+=(num); } @@ -216,6 +213,7 @@ raxNode *raxNewNode(size_t children, int datafield) { node->timestamp = 0; node->numnodes = 1; node->size = children; + node->timestamp = 0; return node; } @@ -298,16 +296,9 @@ raxNode *raxAddChild(raxNode *n, int c, raxNode **childptr, raxNode ***parentlin n->size--; /* For now restore the orignal size. We'll update it only on success at the end. */ - // store the extra data pointer of subtree - void *customData; - bool isSubtree = false; - if (n->issubtree) { - isSubtree = true; - customData = raxGetCustomData(n); - } - /* Alloc the new child we will link to 'n'. */ raxNode *child = raxNewNode(0,0); + child->timestamp = n->timestamp; if (child == NULL) return NULL; int parent_numnodes = n->numnodes; @@ -449,6 +440,7 @@ raxNode *raxCompressNode(raxNode *n, int *s, size_t len, raxNode **child) { /* Allocate the child to link to this node. */ *child = raxNewNode(0,0); + (*child)->timestamp = n->timestamp; if (*child == NULL) return NULL; /* Make space in the parent node. */ @@ -504,12 +496,21 @@ raxNode *raxCompressNode(raxNode *n, int *s, size_t len, raxNode **child) { * means that the current node represents the key (that is, none of the * compressed node tokens list are needed to represent the key, just all * its parents nodes). 
*/ -static inline size_t raxLowWalk(rax *rax, const int *s, size_t len, raxNode **stopnode, raxNode ***plink, int *splitpos, raxStack *ts) { +static inline size_t raxLowWalk(rax *rax, const int *s, size_t len, raxNode **stopnode, raxNode ***plink, int *splitpos, raxStack *ts, bool set_timestamp = true) { raxNode *h = rax->head; raxNode **parentlink = &rax->head; size_t i = 0; /* Position in the string. */ size_t j = 0; /* Position in the node children (or bytes if compressed).*/ + + // std::chrono::milliseconds ms = std::chrono::duration_cast< std::chrono::milliseconds >( + // std::chrono::system_clock::now().time_since_epoch()); + // int64_t timestamp = ms.count(); + auto now = std::chrono::high_resolution_clock::now(); + + auto micros = std::chrono::time_point_cast(now).time_since_epoch().count(); + int64_t timestamp = micros; + while(h->size && i < len) { debugnode("Lookup current node",h); int *v = h->data; @@ -538,6 +539,9 @@ static inline size_t raxLowWalk(rax *rax, const int *s, size_t len, raxNode **st i++; } + /* Save timestamp. */ + h->timestamp = timestamp; + if (ts) raxStackPush(ts,h); /* Save stack of parent nodes. */ raxNode **children = raxNodeFirstChildPtr(h); if (h->iscompr) j = 0; /* Compressed node only child is at index 0. */ @@ -548,6 +552,9 @@ static inline size_t raxLowWalk(rax *rax, const int *s, size_t len, raxNode **st position to 0 to signal this node represents the searched key. */ } + if (set_timestamp) { + h->timestamp = timestamp; + } debugnode("Lookup stop node is",h); if (stopnode) *stopnode = h; if (plink) *plink = parentlink; @@ -578,7 +585,7 @@ int handleOutOfMemory(rax *rax, raxNode *h, int *s, size_t len, void **old){ * function returns 0 as well but sets errno to ENOMEM, otherwise errno will * be set to 0. 
*/ -int raxGenericInsert(rax *rax, int *s, size_t len, void *data, void **old, int overwrite, void **dataNode) { +int raxGenericInsert(rax *rax, int *s, size_t len, void *data, void **old, int overwrite, void **dataNode, bool set_timestamp = true) { size_t i; int j = 0; /* Split position. If raxLowWalk() stops in a compressed node, the index 'j' represents the char we stopped within the @@ -591,7 +598,7 @@ int raxGenericInsert(rax *rax, int *s, size_t len, void *data, void **old, int o raxStackInit(&splitStack); int all_added_node = 0; - i = raxLowWalk(rax,s,len,&h,&parentlink,&j,&lowWalkStack); + i = raxLowWalk(rax,s,len,&h,&parentlink,&j,&lowWalkStack, set_timestamp); debugf("######## after raxLowWalk##########"); /* If i == len we walked following the whole string. If we are not * in the middle of a compressed node, the string is either already @@ -779,6 +786,7 @@ int raxGenericInsert(rax *rax, int *s, size_t len, void *data, void **old, int o /* 2: Create the split node. Also allocate the other nodes we'll need * ASAP, so that it will be simpler to handle OOM. 
*/ raxNode *splitnode = raxNewNode(1, split_node_is_key); + splitnode->timestamp = h->timestamp; raxNode *trimmed = NULL; raxNode *postfix = NULL; @@ -825,6 +833,7 @@ int raxGenericInsert(rax *rax, int *s, size_t len, void *data, void **old, int o trimmed->numnodes = h->numnodes; trimmed->iskey = h->iskey; trimmed->isnull = h->isnull; + trimmed->timestamp = h->timestamp; if (h->iskey && !h->isnull) { void *ndata = raxGetData(h); raxSetData(trimmed,ndata); @@ -912,6 +921,7 @@ int raxGenericInsert(rax *rax, int *s, size_t len, void *data, void **old, int o postfix->numnodes = h->numnodes; postfix->iskey = 1; postfix->isnull = 0; + postfix->timestamp = h->timestamp; memcpy(postfix->data,h->data+j,postfixlen*sizeof(int)); raxSetData(postfix,data); *dataNode = postfix; @@ -926,6 +936,7 @@ int raxGenericInsert(rax *rax, int *s, size_t len, void *data, void **old, int o trimmed->numnodes = h->numnodes+1; trimmed->iskey = 0; trimmed->isnull = 0; + trimmed->timestamp = h->timestamp; memcpy(trimmed->data,h->data,j*sizeof(int)); memcpy(parentlink,&trimmed,sizeof(trimmed)); if (h->iskey) { @@ -947,7 +958,6 @@ int raxGenericInsert(rax *rax, int *s, size_t len, void *data, void **old, int o return 1; /* Key inserted. */ } - LOG(INFO) << "custom2:" << h->custom_data; raxNode *prev_node = NULL; int insert_new_node = 0; /* We walked the radix tree as far as we could, but still there are left @@ -999,7 +1009,7 @@ int raxGenericInsert(rax *rax, int *s, size_t len, void *data, void **old, int o raxStackFree(&lowWalkStack); raxStackFree(&splitStack); raxNode *newh = raxReallocForData(h,data); - printf("#############raxReallocForData2 ############\n"); + // printf("#############raxReallocForData2 ############\n"); if (newh == NULL) { return handleOutOfMemory(rax, h, (int *)s, i, old); } @@ -1013,9 +1023,9 @@ int raxGenericInsert(rax *rax, int *s, size_t len, void *data, void **old, int o /* Overwriting insert. 
Just a wrapper for raxGenericInsert() that will * update the element if there is already one for the same key. */ -int raxInsert(rax *rax, int *s, size_t len, void *data, void **old) { +int raxInsert(rax *rax, int *s, size_t len, void *data, void **old, bool set_timestamp) { void *dataNode = NULL; - return raxGenericInsert(rax,s,len,data,old,1,&dataNode); + return raxGenericInsert(rax,s,len,data,old,1,&dataNode, set_timestamp); } /* Non overwriting insert function: this if an element with the same key @@ -1066,14 +1076,24 @@ raxStack raxFindWithStack(rax *rax, int *s, size_t len) { /* ** Find a key in the rax, returns the raxNode that contains the key. */ -raxNode *raxFindAndReturnDataNode(rax *rax, int *s, size_t len) { +raxNode *raxFindAndReturnDataNode(rax *rax, int *s, size_t len, raxNode** sub_tree_node, bool set_timestamp) { raxNode *h; + raxStack ts; + raxStackInit(&ts); //debugf("### Lookup: %.*s\n", (int)len, s); int splitpos = 0; - size_t i = raxLowWalk(rax,s,len,&h,NULL,&splitpos,NULL); + size_t i = raxLowWalk(rax,s,len,&h,NULL,&splitpos,&ts,set_timestamp); if (i != len || (h->iscompr && splitpos != 0) || !h->iskey) return NULL; + raxNode *tmp = h; + while(tmp != nullptr && tmp->issubtree == false) { + tmp = (raxNode *)raxStackPop(&ts); + } + if (tmp != nullptr && sub_tree_node != nullptr) { + *sub_tree_node = tmp; + } + return h; } @@ -1149,12 +1169,6 @@ raxNode *raxRemoveChild(raxNode *parent, raxNode *child) { * data if the current node is the root node of subtree * */ - void *customData; - bool isSubtree = false; - if (parent->issubtree) { - isSubtree = true; - customData = raxGetCustomData(parent); - } /* Otherwise we need to scan for the child pointer and memmove() * accordingly. @@ -1213,7 +1227,7 @@ raxNode *raxRemoveChild(raxNode *parent, raxNode *child) { /* Remove the specified item. Returns 1 if the item was found and * deleted, 0 otherwise. 
*/ -int raxRemove(rax *rax, int *s, size_t len, void **old) { +int raxRemove(rax *rax, int *s, size_t len, void **old, raxNode** sub_tree_node, bool set_timestamp) { raxNode *h; raxStack ts; @@ -1221,11 +1235,21 @@ int raxRemove(rax *rax, int *s, size_t len, void **old) { raxStackInit(&ts); int splitpos = 0; int all_added_node = 0; - size_t i = raxLowWalk(rax,s,len,&h,NULL,&splitpos,&ts); + size_t i = raxLowWalk(rax,s,len,&h,NULL,&splitpos,&ts, set_timestamp); if (i != len || (h->iscompr && splitpos != 0) || !h->iskey) { raxStackFree(&ts); return 0; } + if (sub_tree_node != NULL) { + for (int i = ts.items - 1; i >= 0; i--) { + if (((raxNode *)ts.stack[i])->issubtree == true) { + *sub_tree_node = (raxNode *)ts.stack[i]; + break; + } + } + } + + if (old) *old = raxGetData(h); h->iskey = 0; rax->numele--; @@ -1380,6 +1404,7 @@ int raxRemove(rax *rax, int *s, size_t len, void **old) { newNode->iscompr = 1; newNode->size = comprsize; newNode->numnodes = h->numnodes+1; + newNode->timestamp = h->timestamp; all_added_node++; rax->numnodes++; @@ -1553,9 +1578,12 @@ int raxIteratorNextStep(raxIterator *it, int noup) { it->subtree_data_list != NULL) { std::cout << "first find subtree list is:" << std::endl; std::vector token; - for (int i = 0; i < it->key_len; i++) { + std::string token_str; + for (size_t i = 0; i < it->key_len - 1; i++) { token.push_back(it->key[i]); + token_str += std::to_string(it->key[i]) + " "; } + LOG(INFO) << "list is:" << token_str; (*it->subtree_list).push_back(token); void *data = raxGetCustomData(it->node); if (data == NULL) { @@ -1591,6 +1619,23 @@ int raxIteratorNextStep(raxIterator *it, int noup) { it->node = orig_node; return 1; } + if (it->node->iskey && it->node->size == 0 && it->node->issubtree && it->add_to_subtree_list && it->subtree_list != NULL && + it->subtree_data_list != NULL) { + // data node is sub tree + std::vector token; + std::string token_str; + for (size_t i = 0; i < it->key_len - 1; i++) { + token.push_back(it->key[i]); + 
token_str += std::to_string(it->key[i]) + " "; + } + LOG(INFO) << "sub tree is:" << token_str; + (*it->subtree_list).push_back(token); + void *data = raxGetCustomData(it->node); + if (data == NULL) { + throw std::runtime_error("custom data is null"); + } + (*it->subtree_data_list).push_back(data); + } /* If there are no children at the current node, try parent's * next child. */ int prevchild = it->key[it->key_len-1]; @@ -1619,20 +1664,20 @@ int raxIteratorNextStep(raxIterator *it, int noup) { debugf("SCAN found a new node\n"); raxIteratorAddToken(it,it->node->data+i,1); if (!raxStackPush(&it->stack,it->node)) return 0; - if (it->node->issubtree && it->add_to_subtree_list && it->subtree_list != NULL && - it->subtree_data_list != NULL) { - std::cout << "second find subtree list is:" << std::endl; - std::vector token; - for (int i = 0; i < it->key_len; i++) { - token.push_back(it->key[i]); - } - (*it->subtree_list).push_back(token); - void *data = raxGetCustomData(it->node); - if (data == NULL) { - throw std::runtime_error("custom data is null"); - } - (*it->subtree_data_list).push_back(data); - } + // if (it->node->issubtree && it->add_to_subtree_list && it->subtree_list != NULL && + // it->subtree_data_list != NULL) { + // std::cout << "second find subtree list is:" << std::endl; + // std::vector token; + // for (size_t i = 0; i < it->key_len; i++) { + // token.push_back(it->key[i]); + // } + // (*it->subtree_list).push_back(token); + // void *data = raxGetCustomData(it->node); + // if (data == NULL) { + // throw std::runtime_error("custom data is null"); + // } + // (*it->subtree_data_list).push_back(data); + // } memcpy(&it->node,cp,sizeof(it->node)); /* Call the node callback if any, and replace the node * pointer if the callback returns true. 
*/ @@ -2110,7 +2155,7 @@ void raxRecursiveShow(int level, int lpad, raxNode *n) { if (n->iskey) { numchars += printf("=%p",raxGetData(n)); } - numchars += printf(" time:%ld, data:%p, is_sub_tree:%d", n->timestamp, n->custom_data, n->issubtree); + numchars += printf(" node:%p time:%ld, data:%p, is_sub_tree:%d", n, n->timestamp, n->custom_data, n->issubtree); int numchildren = n->iscompr ? 1 : n->size; /* Note that 7 and 4 magic constants are the string length @@ -2121,7 +2166,7 @@ void raxRecursiveShow(int level, int lpad, raxNode *n) { } raxNode **cp = raxNodeFirstChildPtr(n); for (int i = 0; i < numchildren; i++) { - char *branch = " `-(%d) "; + const char *branch = " `-(%d) "; if (numchildren > 1) { printf("\n"); for (int j = 0; j < lpad; j++) putchar(' '); @@ -2250,7 +2295,7 @@ bool raxIsSubtree(raxNode *node) { * tree from the root node. * */ -raxNode *raxSplit(rax *rax, int *s, size_t len, void *data) { +raxNode *raxSplit(rax *rax, int *s, size_t len, void *data, std::vector& token) { raxNode *childNode = NULL; raxNode *splitNode = NULL; raxStack stack = raxFindWithStack(rax, s, len); @@ -2271,26 +2316,37 @@ raxNode *raxSplit(rax *rax, int *s, size_t len, void *data) { // find the node that has N/2 children while (items > 0) { raxNode *node = (raxNode *)raxStackPop(&stack); - if (node->numnodes >= (uint32_t)subtreeNumNodes/2 || node->issubtree) { + if (node->numnodes > (uint32_t)subtreeNumNodes/2 || node->issubtree) { splitNode = childNode; raxStackPush(&stack, node); break; } - childNode = node; + childNode = node; items--; } + + raxIterator iter; + raxStart(&iter, rax); + raxSeek(&iter, "^", NULL, 0); + while (raxNext(&iter)) { + if (iter.node == splitNode) { + for (size_t i = 0; i < iter.key_len; i++) { + token.push_back(iter.key[i]); + } + } + } + std::string token_str; + for (size_t i = 0; i < token.size(); i++) { + token_str += std::to_string(token[i]); + token_str += " "; + } + LOG(INFO) << "split token: " << token_str; + // if the splitNode is NULL, it 
means that the tree only has one node if (splitNode == NULL) { return rax->head; } - raxNode *parent = (raxNode *)raxStackPeek(&stack); - raxNode **parentlink; - if (parent == NULL) { - parentlink = &rax->head; - } else { - parentlink = raxFindParentLink(parent,splitNode); - } raxSetSubtree(splitNode); raxStackAddNumNodes(&stack, -(int)(splitNode->numnodes)); @@ -2331,7 +2387,7 @@ void raxSerialize(rax *root, std::vector> &tokenList, std::vect raxSeek(&iter, "^", NULL, 0); while (raxNext(&iter)) { std::vector token; - for (int i = 0; i < iter.key_len; i++) { + for (size_t i = 0; i < iter.key_len; i++) { token.push_back(iter.key[i]); } tokenList.push_back(token); @@ -2340,3 +2396,478 @@ void raxSerialize(rax *root, std::vector> &tokenList, std::vect } raxStop(&iter); } + +void raxFindLastRecentNode(raxNode *node, std::vector& key) { + raxNode** childList = raxNodeFirstChildPtr(node); + + // node must have a key. + // assert(node->iskey == 1); + int numChildren = node->iscompr ? 1 : node->size; + if (numChildren == 0) { + // has no children, return + return; + } + + raxNode *chossenChild = childList[0]; + int choosenChildIndex = 0; + for (int i = 1; i < numChildren; i++) { + if (childList[i]->timestamp != 0 && childList[i]->timestamp <= chossenChild->timestamp) { + if (childList[i]->timestamp == chossenChild->timestamp && childList[i]->numnodes > chossenChild->numnodes) { + chossenChild = childList[i]; + choosenChildIndex = i; + } + // chossenChild = childList[i]; + // choosenChildIndex = i; + } + } + + if (node->iscompr) { + for (int i = 0; i < node->size; i++) { + key.push_back(node->data[i]); + } + } else { + key.push_back(node->data[choosenChildIndex]); + } + + raxFindLastRecentNode(chossenChild, key); +} + +bool compareKey(int *first_key, int *second_key, int first_key_len, int second_key_len) { + if (first_key_len != second_key_len) { + printf("length not equal, %d : %d\n", first_key_len, second_key_len); + return false; + } + for (int i = 0; i < first_key_len; 
i++) { + if (first_key[i] != second_key[i]) { + printf("key not equal, %d : %d\n", first_key[i], second_key[i]); + return false; + } + } + return true; +} + +// bool compare(raxNode *a, raxNode *b) { +// return a->timestamp > b->timestamp; +// } + +// void sortNode(raxNode **node, int size) { +// std::sort(node, node + size, compare); +// } + +// void mergeTree(rax* first_tree, rax* second_tree, std::vector>& evicted_tokens, std::map, void*>& insert_tokens, int max_node) { +// raxNode* first_tree_node = first_tree->head; +// raxNode* second_tree_node = second_tree->head; + +// std::queue first_tree_queue; +// std::queue second_tree_queue; + +// first_tree_queue.push(first_tree_node); +// second_tree_queue.push(second_tree_node); + +// rax* tree = raxNew(); + +// int nodeCount = 0; + +// while((!first_tree_queue.empty()) && (!second_tree_queue.empty())) { +// int first_tree_rax_node_list_size = first_tree_queue.size(); +// int second_tree_rax_node_list_size = second_tree_queue.size(); +// raxNode** first_tree_rax_node_list = (raxNode**)malloc(sizeof(raxNode*) * first_tree_rax_node_list_size); +// raxNode** second_tree_rax_node_list = (raxNode**)malloc(sizeof(raxNode*) * second_tree_rax_node_list_size); + +// for (int i = 0; i < first_tree_queue.size(); i++) { +// first_tree_rax_node_list[i] = first_tree_queue.front(); +// first_tree_queue.pop(); +// } + +// for (int i = 0; i < second_tree_queue.size(); i++) { +// second_tree_rax_node_list[i] = second_tree_queue.front(); +// second_tree_queue.pop(); +// } + +// sortNode(first_tree_rax_node_list, first_tree_queue.size()); +// sortNode(second_tree_rax_node_list, second_tree_queue.size()); + +// int first_tree_index = 0; +// int second_tree_index = 0; + +// while(first_tree_index < first_tree_rax_node_list_size && second_tree_index < second_tree_rax_node_list_size && nodeCount < max_node) { +// if (first_tree_rax_node_list[first_tree_index]->timestamp > second_tree_rax_node_list[second_tree_index]->timestamp) { +// // 
choose first_tree_rax_node_list[first_tree_index] +// if (raxFind(tree, first_tree_rax_node_list[first_tree_index]->data, first_tree_rax_node_list[first_tree_index]->size) == NULL) { +// raxInsert(tree, first_tree_rax_node_list[first_tree_index]->data, first_tree_rax_node_list[first_tree_index]->size, first_tree_rax_node_list[first_tree_index]->data, NULL); +// nodeCount++; +// } else { +// std::vector token = std::vector(first_tree_rax_node_list[first_tree_index]->data, first_tree_rax_node_list[first_tree_index]->data + first_tree_rax_node_list[first_tree_index]->size); +// insert_tokens.erase(token); +// } +// first_tree_index++; +// } +// } +// } +// } + +bool compare(raxIterator a, raxIterator b) { + if (a.key_len == b.key_len) { + return a.node->timestamp > b.node->timestamp; + } + return a.key_len < b.key_len; +} + +void sortNode(std::vector &ite_list) { + std::sort(ite_list.begin(), ite_list.end(), compare); +} + +void printVector(int* v, int size) { + printf("token:\n"); + for (int i = 0; i < size; i++) { + printf("%d " , v[i]); + } + printf("\n"); +} + +void freeVector(std::vector &ite_list) { + for (size_t i = 0; i < ite_list.size(); i++) { + free(ite_list[i].key); + } +} + +void mergeTree(rax* first_tree, rax* second_tree, + std::vector>& evicted_tokens, + std::set>& insert_tokens, int max_node) { + printf("merge tree!\n"); + LOG(INFO) << "==============tree 1===================="; + raxShow(first_tree); + LOG(INFO) << "==============tree 2===================="; + raxShow(second_tree); + raxIterator first_tree_iter; + raxIterator second_tree_iter; + rax* tmp = raxNew(); + + raxStart(&first_tree_iter, first_tree); + raxStart(&second_tree_iter, second_tree); + raxSeek(&first_tree_iter, "^", NULL, 0); + raxSeek(&second_tree_iter, "^", NULL, 0); + + std::vector first_tree_iter_list; + std::vector second_tree_iter_list; + while (raxNext(&first_tree_iter)) { + raxIterator tmp_iter = first_tree_iter; + tmp_iter.key = (int*) malloc(sizeof(int) * 
first_tree_iter.key_len); + memcpy(tmp_iter.key, first_tree_iter.key, + sizeof(int) * first_tree_iter.key_len); + first_tree_iter_list.push_back(tmp_iter); + } + + while (raxNext(&second_tree_iter)) { + raxIterator tmp_iter = second_tree_iter; + tmp_iter.key = (int*) malloc(sizeof(int) * second_tree_iter.key_len); + memcpy(tmp_iter.key, second_tree_iter.key, + sizeof(int) * second_tree_iter.key_len); + second_tree_iter_list.push_back(tmp_iter); + } + + for (size_t i = 0; i < first_tree_iter_list.size(); i++) { + printVector(first_tree_iter_list[i].key, first_tree_iter_list[i].key_len); + } + + for (size_t i = 0; i < second_tree_iter_list.size(); i++) { + printVector(second_tree_iter_list[i].key, second_tree_iter_list[i].key_len); + } + + // Sort by the length of the key, or timestamp if the keys have the same length. + sortNode(first_tree_iter_list); + sortNode(second_tree_iter_list); + + size_t first_tree_index = 0; + size_t second_tree_index = 0; + int nodeCount = 0; + + /** + * We use two structures to store the nodes choosen from the second tree + * and the nodes evicted from the first tree. + */ + while (nodeCount < max_node) { + if (first_tree_index == first_tree_iter_list.size() || + second_tree_index == second_tree_iter_list.size()) { + break; + } + printf("nodeCount: %d\n", nodeCount); + + /** + * If the key is the same, use the larger timestamp to refresh + * the timestamp of the key in the first tree. If the key is not + * the same, choose the key with the larger timestamp. 
+ */ + if (compareKey(first_tree_iter_list[first_tree_index].key, + second_tree_iter_list[second_tree_index].key, + first_tree_iter_list[first_tree_index].key_len, + second_tree_iter_list[second_tree_index].key_len)) { + // same key + printf("same key\n"); + first_tree_iter_list[first_tree_index].node->timestamp = + std::max(first_tree_iter_list[first_tree_index].node->timestamp, + second_tree_iter_list[second_tree_index].node->timestamp); + + raxInsert(tmp, first_tree_iter_list[first_tree_index].key, + first_tree_iter_list[first_tree_index].key_len, + first_tree_iter_list[first_tree_index].data, NULL); + first_tree_iter_list[first_tree_index].node->timestamp = + std::max(first_tree_iter_list[first_tree_index].node->timestamp, + second_tree_iter_list[second_tree_index].node->timestamp); + first_tree_index++; + second_tree_index++; + nodeCount++; + } else if (first_tree_iter_list[first_tree_index].node->timestamp > + second_tree_iter_list[second_tree_index].node->timestamp) { + /** + * Choose first tree node. + * If the key is in the record tree, it means that there exist a same key in + * the second tree and has been choosen in the past. So we just need to remove + * the key from the insert_tokens and update the timestamp of the key in the + * first tree. + * If the key is not in the record tree, it means that the key has not been + * choosen in the past. So we need to insert the key into the record tree. 
+ */ + printf("chosse first key %ld : %ld\n", + first_tree_iter_list[first_tree_index].node->timestamp, + second_tree_iter_list[second_tree_index].node->timestamp); + if (raxFind(tmp, first_tree_iter_list[first_tree_index].key, + first_tree_iter_list[first_tree_index].key_len) == raxNotFound) { + raxInsert(tmp, first_tree_iter_list[first_tree_index].key, + first_tree_iter_list[first_tree_index].key_len, + first_tree_iter_list[first_tree_index].data, NULL); + nodeCount++; + } else { + std::vector token = std::vector(first_tree_iter_list[first_tree_index].key, + first_tree_iter_list[first_tree_index].key + + first_tree_iter_list[first_tree_index].key_len); + insert_tokens.erase(token); + raxNode* node = raxFindAndReturnDataNode(second_tree, first_tree_iter_list[first_tree_index].key, + first_tree_iter_list[first_tree_index].key_len, NULL, false); + first_tree_iter_list[first_tree_index].node->timestamp = node->timestamp; + } + first_tree_index++; + } else if (first_tree_iter_list[first_tree_index].node->timestamp < + second_tree_iter_list[second_tree_index].node->timestamp) { + /** + * Choose second tree node. + * If the key is in the record tree, it means that there exist a same key in + * the first tree and has been choosen in the past. So we need do nothing. + * If the key is not in the record tree, it means that the key has not been + * choosen in the past. So we need to insert the key into the record tree. + * and insert the key into the insert_tokens. 
+ */ + printf("chosse second key %ld : %ld\n", + first_tree_iter_list[first_tree_index].node->timestamp, + second_tree_iter_list[second_tree_index].node->timestamp); + // choose second key + if (raxFind(tmp, second_tree_iter_list[second_tree_index].key, + second_tree_iter_list[second_tree_index].key_len) == raxNotFound) { + std::vector insert_token( + second_tree_iter_list[second_tree_index].key, + second_tree_iter_list[second_tree_index].key + + second_tree_iter_list[second_tree_index].key_len); + insert_tokens.insert(insert_token); + + raxInsert(tmp, second_tree_iter_list[second_tree_index].key, + second_tree_iter_list[second_tree_index].key_len, + second_tree_iter_list[second_tree_index].data, NULL); + nodeCount++; + } + second_tree_index++; + } else { + /** + * If the key is not same and the timestamp is the same, we choose the key + * with the smaller children number. + */ + if (first_tree_iter_list[first_tree_index].node->numnodes <= + second_tree_iter_list[second_tree_index].node->numnodes) { + printf("chosse first key %ld : %ld\n", + first_tree_iter_list[first_tree_index].node->timestamp, + second_tree_iter_list[second_tree_index].node->timestamp); + // choose first key + if (raxFind(tmp, first_tree_iter_list[first_tree_index].key, + first_tree_iter_list[first_tree_index].key_len) == raxNotFound) { + raxInsert(tmp, first_tree_iter_list[first_tree_index].key, + first_tree_iter_list[first_tree_index].key_len, + first_tree_iter_list[first_tree_index].data, NULL); + nodeCount++; + } else { + std::vector token = std::vector(first_tree_iter_list[first_tree_index].key, + first_tree_iter_list[first_tree_index].key + + first_tree_iter_list[first_tree_index].key_len); + insert_tokens.erase(token); + raxNode* node = raxFindAndReturnDataNode(second_tree, + first_tree_iter_list[first_tree_index].key, + first_tree_iter_list[first_tree_index].key_len, + NULL, + false); + first_tree_iter_list[first_tree_index].node->timestamp = node->timestamp; + } + first_tree_index++; + 
} else { + printf("chosse second key %ld : %ld\n", + first_tree_iter_list[first_tree_index].node->timestamp, + second_tree_iter_list[second_tree_index].node->timestamp); + // choose second key + if (raxFind(tmp, second_tree_iter_list[second_tree_index].key, + second_tree_iter_list[second_tree_index].key_len) == raxNotFound) { + std::vector insert_token( + second_tree_iter_list[second_tree_index].key, + second_tree_iter_list[second_tree_index].key + + second_tree_iter_list[second_tree_index].key_len); + insert_tokens.insert(insert_token); + + raxInsert(tmp, second_tree_iter_list[second_tree_index].key, + second_tree_iter_list[second_tree_index].key_len, + second_tree_iter_list[second_tree_index].data, NULL); + nodeCount++; + } + second_tree_index++; + } + } + } + + if (nodeCount == max_node) { + printf("insert evicted tokens\n"); + int evicted_node_count = 0; + while (first_tree_index < first_tree_iter_list.size()) { + if (raxFind(tmp, first_tree_iter_list[first_tree_index].key, + first_tree_iter_list[first_tree_index].key_len) == + raxNotFound) { + std::vector evicted_token( + first_tree_iter_list[first_tree_index].key, + first_tree_iter_list[first_tree_index].key + + first_tree_iter_list[first_tree_index].key_len); + evicted_tokens.push_back(evicted_token); + } else { + std::vector token( + first_tree_iter_list[first_tree_index].key, + first_tree_iter_list[first_tree_index].key + + first_tree_iter_list[first_tree_index].key_len); + insert_tokens.erase(token); + raxNode* node = raxFindAndReturnDataNode(second_tree, + first_tree_iter_list[first_tree_index].key, + first_tree_iter_list[first_tree_index].key_len, + NULL, + false); + first_tree_iter_list[first_tree_index].node->timestamp = node->timestamp; + } + evicted_node_count++; + first_tree_index++; + } + printf("evicted_node_count: %d\n", evicted_node_count); + freeVector(first_tree_iter_list); + freeVector(second_tree_iter_list); + raxFree(tmp); + return; + } + + // first_ret and second_ret both are not 0 is the 
case that nodeCount == + // max_node + if (first_tree_index >= first_tree_iter_list.size() && + second_tree_index >= second_tree_iter_list.size()) { + // both tree are empty + freeVector(first_tree_iter_list); + freeVector(second_tree_iter_list); + raxFree(tmp); + return; + } else if (first_tree_index >= first_tree_iter_list.size()) { + // first tree is empty + while (second_tree_index < second_tree_iter_list.size() && + nodeCount < max_node) { + if (raxFind(tmp, second_tree_iter_list[second_tree_index].key, + second_tree_iter_list[second_tree_index].key_len) == raxNotFound) { + std::vector insert_token( + second_tree_iter_list[second_tree_index].key, + second_tree_iter_list[second_tree_index].key + + second_tree_iter_list[second_tree_index].key_len); + insert_tokens.insert(insert_token); + + raxInsert(tmp, second_tree_iter_list[second_tree_index].key, + second_tree_iter_list[second_tree_index].key_len, + second_tree_iter_list[second_tree_index].data, NULL); + + nodeCount++; + } + second_tree_index++; + } + } else if (second_tree_index >= second_tree_iter_list.size()) { + // second tree is empty + raxShow(tmp); + printf("nodeCount:%d\n", nodeCount); + printf("first_tree_index:%ld\n", first_tree_index); + while (first_tree_index < first_tree_iter_list.size() && + nodeCount < max_node) { + if (raxFind(tmp, first_tree_iter_list[first_tree_index].key, + first_tree_iter_list[first_tree_index].key_len) == + raxNotFound) { + nodeCount++; + } else { + std::vector token( + first_tree_iter_list[first_tree_index].key, + first_tree_iter_list[first_tree_index].key + + first_tree_iter_list[first_tree_index].key_len); + insert_tokens.erase(token); + raxNode* node = raxFindAndReturnDataNode(second_tree, + first_tree_iter_list[first_tree_index].key, + first_tree_iter_list[first_tree_index].key_len, + NULL, + false); + first_tree_iter_list[first_tree_index].node->timestamp = node->timestamp; + } + first_tree_index++; + } + while (first_tree_index < first_tree_iter_list.size()) { + 
if (raxFind(tmp, first_tree_iter_list[first_tree_index].key, + first_tree_iter_list[first_tree_index].key_len) == + raxNotFound) { + std::vector evicted_token( + first_tree_iter_list[first_tree_index].key, + first_tree_iter_list[first_tree_index].key + + first_tree_iter_list[first_tree_index].key_len); + evicted_tokens.push_back(evicted_token); + } else { + std::vector token( + first_tree_iter_list[first_tree_index].key, + first_tree_iter_list[first_tree_index].key + + first_tree_iter_list[first_tree_index].key_len); + insert_tokens.erase(token); + raxNode* node = raxFindAndReturnDataNode(second_tree, + first_tree_iter_list[first_tree_index].key, + first_tree_iter_list[first_tree_index].key_len, + NULL, + false); + first_tree_iter_list[first_tree_index].node->timestamp = node->timestamp; + } + first_tree_index++; + } + } + freeVector(first_tree_iter_list); + freeVector(second_tree_iter_list); + raxFree(tmp); +} + +void testIteRax(rax *tree) { + raxIterator iter; + raxStart(&iter, tree); + raxSeek(&iter, "^", NULL, 0); + while (raxNext(&iter)) { + printf("key: "); + for (size_t i = 0; i < iter.key_len; i++) { + printf("%d ", iter.key[i]); + } + printf("\n"); + // printf("data: %p\n", iter.data); + } + raxStop(&iter); +} + +raxNode* raxGetFirstChildPtr(raxNode* node) { + return raxGetFirstChildPtr(node); +} + +// 1 2 3 +// query subtree node:0x55f87076f760 +// I0129 16:44:25.626318 280948 kv_state_cache.cc:223] offset:0 +// I0129 16:44:25.626322 280948 kv_state_cache.cc:224] kvStateCacheBlockBuilder:0x55f870767bc0 \ No newline at end of file diff --git a/modules/kv-state-cache/radix-tree/radix.h b/modules/kv-state-cache/radix-tree/radix.h index 331936d2..25321aa7 100644 --- a/modules/kv-state-cache/radix-tree/radix.h +++ b/modules/kv-state-cache/radix-tree/radix.h @@ -35,6 +35,11 @@ #include #include #include +#include +#include +#include +#include +#include /* Representation of a radix tree as implemented in this file, that contains * the token lists [1, 2, 3], [1, 
2, 3, 4, 5, 6] and [1, 2, 3, 6, 7, 8] after @@ -139,6 +144,7 @@ typedef struct raxNode { typedef struct rax { raxNode* head; + raxNode* headDataNode; uint64_t numele; uint64_t numnodes; } rax; @@ -202,13 +208,13 @@ extern void* raxNotFound; /* Exported API. */ rax* raxNew(void); -int raxInsert(rax* rax, int* s, size_t len, void* data, void** old); +int raxInsert(rax *rax, int *s, size_t len, void *data, void **old, bool set_timestamp = true); int raxTryInsert(rax* rax, int* s, size_t len, void* data, void** old); int raxInsertAndReturnDataNode(rax* rax, int* s, size_t len, void* data, void** node, void** old); -int raxRemove(rax* rax, int* s, size_t len, void** old); +int raxRemove(rax* rax, int* s, size_t len, void** old, raxNode** sub_tree_node = NULL, bool set_timestamp = true); void* raxFind(rax* rax, int* s, size_t len); -raxNode* raxFindAndReturnDataNode(rax* rax, int* s, size_t len); +raxNode* raxFindAndReturnDataNode(rax* rax, int* s, size_t len, raxNode** sub_tree_node = NULL, bool set_timestamp = true); void raxSetSubtree(raxNode *n); void raxSetSubtreeAllocated(raxNode *node); void raxSetSubtreeNotNull(raxNode *node); @@ -232,14 +238,17 @@ void raxSetDebugMsg(int onoff); void raxTraverse(raxNode* rax, std::vector>& dataNodeList); void raxTraverseSubTree(raxNode* n, std::vector &dataNodeList); -raxNode *raxSplit(rax *rax, int *s, size_t len, void *data); +raxNode *raxSplit(rax *rax, int *s, size_t len, void *data, std::vector& key); void raxSerialize(rax* root, std::vector>& tokenList, std::vector& dataList, std::vector ×tampsList, std::vector> *subtreeList, std::vector *subtreeNodeList); /* Internal API. May be used by the node callback in order to access rax nodes * in a low level way, so this function is exported as well. 
*/ -void raxSetData(raxNode* n, void* data); -void* raxGetData(raxNode* n); +void raxSetData(raxNode *n, void *data); +void *raxGetData(raxNode *n); int raxFindNode(rax *rax, int *s, size_t len, void **node); - +void raxFindLastRecentNode(raxNode *node, std::vector& key); +void mergeTree(rax* first_tree, rax* second_tree, std::vector>& evicted_tokens, std::set>& insert_tokens, int max_node); +void testIteRax(rax *tree); +raxNode* raxGetFirstChildPtr(raxNode* node); #endif diff --git a/modules/kv-state-cache/strategy/LRU_strategy.cc b/modules/kv-state-cache/strategy/LRU_strategy.cc index bf6b646e..87d4cdee 100644 --- a/modules/kv-state-cache/strategy/LRU_strategy.cc +++ b/modules/kv-state-cache/strategy/LRU_strategy.cc @@ -26,6 +26,16 @@ void PrintTokenList(std::vector& vector) { LOG(INFO) << tokens_str; } +void LRUStrategy::PrintLRUList() { + LOG(INFO) << "List:"; + std::shared_ptr node = header; + while (node != nullptr) { + PrintTokenList(node->tokens); + LOG(INFO) << "->"; + node = node->next; + } +} + LRUStrategy::LRUStrategy(int capacity) { this->capacity = capacity; this->header = this->tail = nullptr; @@ -40,7 +50,7 @@ LRUStrategy::LRUStrategy(const std::vector>& cache_list, std::shared_ptr LRUStrategy::InsertToHeader( const std::vector& tokens, std::vector& evicted_tokens) { if (current_size == capacity) { - std::shared_ptr remove_node = Remove(); + std::shared_ptr remove_node = tail; // Remove(); evicted_tokens = remove_node->tokens; } @@ -111,33 +121,4 @@ void LRUStrategy::Remove(std::shared_ptr cache_node) { std::shared_ptr LRUStrategy::GetHeader() { return header; } -// void LRUStrategy::Remove(const std::vector& prefix, int token) { -// std::vector tokens = prefix; -// tokens.push_back(token); - -// std::shared_ptr node_with_tree_attri = -// radix_tree->Query(tokens); -// if (node_with_tree_attri == nullptr) { -// return; -// } - -// std::shared_ptr cache_node = -// std::static_pointer_cast( -// node_with_tree_attri->get_node()->get_data()); -// if 
(cache_node == header) { -// header = header->next; -// header->prev = nullptr; -// } else if (cache_node == tail) { -// tail = tail->prev; -// tail->next = nullptr; -// } else { -// cache_node->prev->next = cache_node->next; -// cache_node->next->prev = cache_node->prev; -// } -// current_size--; -// radix_tree->Delete(tokens); -// } - -// LRUStrategy::~LRUStrategy() { delete radix_tree; } - } // namespace vineyard diff --git a/modules/kv-state-cache/strategy/LRU_strategy.h b/modules/kv-state-cache/strategy/LRU_strategy.h index 58c13b05..fcfc519e 100644 --- a/modules/kv-state-cache/strategy/LRU_strategy.h +++ b/modules/kv-state-cache/strategy/LRU_strategy.h @@ -58,8 +58,8 @@ class LRUStrategy : public CacheStrategy { std::shared_ptr GetHeader(); int GetCapacity() { return capacity; } - // for distributed sync - // void Remove(const std::vector& prefix, int token); + + void PrintLRUList(); }; } // namespace vineyard diff --git a/modules/kv-state-cache/utils/kv_state_cache_utils.cc b/modules/kv-state-cache/utils/kv_state_cache_utils.cc index c43fc0db..47638e55 100644 --- a/modules/kv-state-cache/utils/kv_state_cache_utils.cc +++ b/modules/kv-state-cache/utils/kv_state_cache_utils.cc @@ -47,7 +47,7 @@ void signalHandler(int signum) { exit(signum); } -void initKVStateCache(int dimension = 10, int cache_capacity = 10) { +void InitKVStateCache(int dimension = 10, int cacheCapacity = 10) { if (kv_state_cache_builder == nullptr) { std::string socket = std::string(getenv("VINEYARD_IPC_SOCKET")); LOG(INFO) << "socket:" << socket; @@ -86,7 +86,7 @@ void initKVStateCache(int dimension = 10, int cache_capacity = 10) { // if failed, create a new cache object LOG(INFO) << "failed to get the cache object, create a new one"; kv_state_cache_builder = std::make_shared( - client, dimension, cache_capacity); + client, dimension, cacheCapacity); } // // release the lock @@ -101,68 +101,65 @@ void initKVStateCache(int dimension = 10, int cache_capacity = 10) { } } -void 
updateInternal(const std::vector& token_list, int next_token, - const KV_STATE_WITH_LAYER& kv_state) { - kv_state_cache_builder->Update(client, token_list, next_token, kv_state); +void updateInternal(const std::vector& tokenList, int nextToken, + const KV_STATE_WITH_LAYER& kvState) { + kv_state_cache_builder->Update(client, tokenList, nextToken, kvState); } -void update(const std::vector& token_list, int next_token, - const KV_STATE_WITH_LAYER& kv_state) { - LOG(INFO) << "update"; +void Update(const std::vector& tokenList, int nextToken, + const KV_STATE_WITH_LAYER& kvState) { + LOG(INFO) << "Update"; if (pthread_mutex_trylock(&sync_mutex)) { return; } - updateInternal(token_list, next_token, kv_state); + updateInternal(tokenList, nextToken, kvState); pthread_mutex_unlock(&sync_mutex); } -void update(const std::vector& token_list, - const LIST_KV_STATE_WITH_LAYER& kv_state) { +void Update(const std::vector& tokenList, + const LIST_KV_STATE_WITH_LAYER& kvState) { if (pthread_mutex_trylock(&sync_mutex)) { return; } std::vector token_list_copy; - for (size_t i = 0; i < token_list.size(); i++) { - updateInternal(token_list_copy, token_list[i], kv_state[i]); - token_list_copy.push_back(token_list[i]); + for (size_t i = 0; i < tokenList.size(); i++) { + updateInternal(token_list_copy, tokenList[i], kvState[i]); + token_list_copy.push_back(tokenList[i]); } pthread_mutex_unlock(&sync_mutex); } -KV_STATE_WITH_LAYER queryInternal(const std::vector& token_list, +KV_STATE_WITH_LAYER queryInternal(const std::vector& tokenList, int token) { - return kv_state_cache_builder->Query(client, token_list, token); + return kv_state_cache_builder->Query(client, tokenList, token); } -KV_STATE_WITH_LAYER query(const std::vector& token_list, int token) { - LOG(INFO) << "query"; +KV_STATE_WITH_LAYER Query(const std::vector& tokenList, int token) { + LOG(INFO) << "Query"; KV_STATE_WITH_LAYER result; if (pthread_mutex_trylock(&sync_mutex)) { return result; } - result = 
queryInternal(token_list, token); - LOG(INFO) << "unlock"; + result = queryInternal(tokenList, token); pthread_mutex_unlock(&sync_mutex); - LOG(INFO) << "query end"; return result; } -LIST_KV_STATE_WITH_LAYER query(const std::vector& token_list) { +LIST_KV_STATE_WITH_LAYER Query(const std::vector& tokenList) { LIST_KV_STATE_WITH_LAYER list_kv_state; if (pthread_mutex_trylock(&sync_mutex)) { return list_kv_state; } std::vector token_list_copy; - for (size_t i = 0; i < token_list.size(); i++) { - KV_STATE_WITH_LAYER kv_state = - queryInternal(token_list_copy, token_list[i]); - list_kv_state.push_back(kv_state); - token_list_copy.push_back(token_list[i]); + for (size_t i = 0; i < tokenList.size(); i++) { + KV_STATE_WITH_LAYER kvState = queryInternal(token_list_copy, tokenList[i]); + list_kv_state.push_back(kvState); + token_list_copy.push_back(tokenList[i]); } pthread_mutex_unlock(&sync_mutex); @@ -194,15 +191,11 @@ void sync() { } // 3. merge the cache object - std::shared_ptr merged_kv_state_cache_builder = - kv_state_cache_builder->Merge(client, global_kv_state_cache); - if (merged_kv_state_cache_builder == nullptr) { - merged_kv_state_cache_builder = kv_state_cache_builder; - } + kv_state_cache_builder->Merge(client, global_kv_state_cache); // 4. push the cache object std::shared_ptr kv_state_cache = - merged_kv_state_cache_builder->_Seal(client); + kv_state_cache_builder->_Seal(client); client.Persist(kv_state_cache->id()); // 5. put the name of the new cache object to the meta server @@ -245,14 +238,13 @@ void threadFunc() { pthread_mutex_lock(&sync_mutex); sync(); pthread_mutex_unlock(&sync_mutex); - // break; } } /* a. vineyardd with global cache object | sealed b. client get the object replica - c. client update replica + c. client Update replica d. client seal the local object and try to push object to server (modified sealed object and global cache version) â…°. if success 1. 
vineyardd modify global object meta @@ -264,11 +256,3 @@ void threadFunc() { and merge) 3. goto d */ -/* node with attr node cache node data node - addr: 0x7e6fec007be0 0x7e6fec0077e0 0x7e6fec001100 0 - addr: 0x7e6fec007be0 0x7e6fec0077e0 0x7e6fec001100 0x7e6fec009ea0 - - - - 0x5654cf368c60 0x5654cf368f40 0x7e6fec001100 0x7e6fec009ea0 -*/ diff --git a/modules/kv-state-cache/utils/kv_state_cache_utils.h b/modules/kv-state-cache/utils/kv_state_cache_utils.h index be346312..fe930cb1 100644 --- a/modules/kv-state-cache/utils/kv_state_cache_utils.h +++ b/modules/kv-state-cache/utils/kv_state_cache_utils.h @@ -18,16 +18,16 @@ limitations under the License. #ifndef MODULES_KV_STATE_CACHE_UTILS_H_ #define MODULES_KV_STATE_CACHE_UTILS_H_ -void initKVStateCache(int dimension = 10, int cache_capacity = 10); +void InitKVStateCache(int dimension = 10, int cacheCapacity = 10); -void update(const std::vector& token_list, int next_token, - const KV_STATE_WITH_LAYER& kv_state); +void Update(const std::vector& tokenList, int nextToken, + const KV_STATE_WITH_LAYER& kvState); -void update(const std::vector& token_list, - const LIST_KV_STATE_WITH_LAYER& kv_state); +void Update(const std::vector& tokenList, + const LIST_KV_STATE_WITH_LAYER& kvState); -KV_STATE_WITH_LAYER query(const std::vector& token_list, int token); +KV_STATE_WITH_LAYER Query(const std::vector& tokenList, int token); -LIST_KV_STATE_WITH_LAYER query(const std::vector& token_list); +LIST_KV_STATE_WITH_LAYER Query(const std::vector& tokenList); #endif \ No newline at end of file diff --git a/test/kv_state_cache_object_test.cc b/test/kv_state_cache_object_test.cc index 60bdd3e0..433fe5e0 100644 --- a/test/kv_state_cache_object_test.cc +++ b/test/kv_state_cache_object_test.cc @@ -21,154 +21,160 @@ limitations under the License. 
using namespace vineyard; -std::vector tokens; -RadixTree* radix_tree; -std::vector> k_state_list; -std::vector> v_state_list; -std::vector> nodes_with_tree_attri_list; - -#define DIMENSION 10 -#define TOKEN_NUM 10 -#define CACHE_CAPACITY 10 - -void prepareData(KVStateCacheBuilder* kv_state_cache_builder) { - radix_tree = new RadixTree(10); - radix_tree->SetCustomData(kv_state_cache_builder, - sizeof(KVStateCacheBuilder)); - - for (int i = 0; i < TOKEN_NUM; i++) { - tokens.push_back(i); - } - - LOG(INFO) << "stage 1"; - for (int i = 0; i < TOKEN_NUM; i++) { - std::vector key_state; - for (int j = 0; j < DIMENSION; ++j) { - key_state.push_back(((double) (j)) * 0.1 + (double) i); - } - k_state_list.push_back(key_state); - } - - LOG(INFO) << "stage 2"; - for (int i = 0; i < TOKEN_NUM; i++) { - std::vector value_state; - for (int j = 0; j < DIMENSION; ++j) { - value_state.push_back(((double) (j)) * 0.1 + (double) i); - } - v_state_list.push_back(value_state); - } -} - -void updateTest(Client& client, KVStateCacheBuilder* builder) { - std::vector prefix; - - for (size_t i = 0; i < tokens.size(); ++i) { - KV_STATE_WITH_LAYER kv_state; - kv_state.insert( - std::make_pair(1, std::make_pair(k_state_list[i], v_state_list[i]))); - LOG(INFO) << "update test"; - builder->Update(client, prefix, tokens[i], kv_state); - prefix.push_back(tokens[i]); - } -} - -void queryTest(Client& client, KVStateCacheBuilder* builder) { - std::vector prefix; - KV_STATE_WITH_LAYER kv_state; - - for (int i = 0; i < TOKEN_NUM; i++) { - kv_state = builder->Query(client, prefix, tokens[i]); - std::vector key_state = kv_state[1].first; - std::vector value_state = kv_state[1].second; - - VINEYARD_ASSERT( - key_state.size() == (size_t) DIMENSION, - "Expected key_state.size() == " + std::to_string(DIMENSION) + - ", but got + key_state.size() == " + - std::to_string(key_state.size())); - VINEYARD_ASSERT( - value_state.size() == (size_t) DIMENSION, - "Expected value_state.size() == " + 
std::to_string(DIMENSION) + - ", but got + value_state.size() == " + - std::to_string(value_state.size())); - for (int j = 0; j < DIMENSION; ++j) { - VINEYARD_ASSERT(key_state[j] == k_state_list[i][j], - "Expected key_state[" + std::to_string(j) + - "] == " + std::to_string(k_state_list[i][j]) + - ", but got + key_state[" + std::to_string(j) + - "] == " + std::to_string(key_state[j])); - VINEYARD_ASSERT(value_state[j] == v_state_list[i][j], - "Expected value_state[" + std::to_string(j) + - "] == " + std::to_string(v_state_list[i][j]) + - ", but got + value_state[" + std::to_string(j) + - "] == " + std::to_string(value_state[j])); - } - prefix.push_back(tokens[i]); - } -} +// std::vector tokens; +// RadixTree* radix_tree; +// std::vector> k_state_list; +// std::vector> v_state_list; +// std::vector> nodes_with_tree_attri_list; + +// #define DIMENSION 10 +// #define TOKEN_NUM 10 +// #define CACHE_CAPACITY 10 + +// void prepareData(KVStateCacheBuilder* kv_state_cache_builder) { +// radix_tree = new RadixTree(10); +// radix_tree->SetCustomData(kv_state_cache_builder, +// sizeof(KVStateCacheBuilder)); + +// for (int i = 0; i < TOKEN_NUM; i++) { +// tokens.push_back(i); +// } + +// LOG(INFO) << "stage 1"; +// for (int i = 0; i < TOKEN_NUM; i++) { +// std::vector key_state; +// for (int j = 0; j < DIMENSION; ++j) { +// key_state.push_back(((double) (j)) * 0.1 + (double) i); +// } +// k_state_list.push_back(key_state); +// } + +// LOG(INFO) << "stage 2"; +// for (int i = 0; i < TOKEN_NUM; i++) { +// std::vector value_state; +// for (int j = 0; j < DIMENSION; ++j) { +// value_state.push_back(((double) (j)) * 0.1 + (double) i); +// } +// v_state_list.push_back(value_state); +// } +// } + +// void updateTest(Client& client, KVStateCacheBuilder* builder) { +// std::vector prefix; + +// for (size_t i = 0; i < tokens.size(); ++i) { +// KV_STATE_WITH_LAYER kv_state; +// kv_state.insert( +// std::make_pair(1, std::make_pair(k_state_list[i], v_state_list[i]))); +// LOG(INFO) << 
"update test"; +// builder->Update(client, prefix, tokens[i], kv_state); +// prefix.push_back(tokens[i]); +// } +// } + +// void queryTest(Client& client, KVStateCacheBuilder* builder) { +// std::vector prefix; +// KV_STATE_WITH_LAYER kv_state; + +// for (int i = 0; i < TOKEN_NUM; i++) { +// kv_state = builder->Query(client, prefix, tokens[i]); +// std::vector key_state = kv_state[1].first; +// std::vector value_state = kv_state[1].second; + +// VINEYARD_ASSERT( +// key_state.size() == (size_t) DIMENSION, +// "Expected key_state.size() == " + std::to_string(DIMENSION) + +// ", but got + key_state.size() == " + +// std::to_string(key_state.size())); +// VINEYARD_ASSERT( +// value_state.size() == (size_t) DIMENSION, +// "Expected value_state.size() == " + std::to_string(DIMENSION) + +// ", but got + value_state.size() == " + +// std::to_string(value_state.size())); +// for (int j = 0; j < DIMENSION; ++j) { +// VINEYARD_ASSERT(key_state[j] == k_state_list[i][j], +// "Expected key_state[" + std::to_string(j) + +// "] == " + std::to_string(k_state_list[i][j]) + +// ", but got + key_state[" + std::to_string(j) + +// "] == " + std::to_string(key_state[j])); +// VINEYARD_ASSERT(value_state[j] == v_state_list[i][j], +// "Expected value_state[" + std::to_string(j) + +// "] == " + std::to_string(v_state_list[i][j]) + +// ", but got + value_state[" + std::to_string(j) + +// "] == " + std::to_string(value_state[j])); +// } +// prefix.push_back(tokens[i]); +// } +// } void sealAndConstructTest(Client& client, KVStateCacheBuilder* builder) { - ObjectID id = builder->_Seal(client)->id(); - std::shared_ptr kv_state_cache = - std::dynamic_pointer_cast(client.GetObject(id)); - std::shared_ptr kv_state_cache_block = - kv_state_cache->GetKVStateCacheBlock(); - std::shared_ptr kv_state_cache_block_builder = - builder->GetKVStateCacheBlockBuilder(); - - // compare kv_state_cache_block and kv_state_cache_block_builder - VINEYARD_ASSERT(kv_state_cache_block->GetDimension() == - 
kv_state_cache_block_builder->GetDimension()); - - VINEYARD_ASSERT(kv_state_cache_block->GetBitmap() == - kv_state_cache_block_builder->GetBitmap()); - - LOG(INFO) << "Bitmap:"; - LOG(INFO) << kv_state_cache_block_builder->GetBitmapStr(); - LOG(INFO) << kv_state_cache_block->GetBitmapStr(); - - const std::shared_ptr> k_tensor_builder = - kv_state_cache_block_builder->getKBuilder(); - const std::shared_ptr> v_tensor_builder = - kv_state_cache_block_builder->getVBuilder(); - - std::shared_ptr> k_tensor = - kv_state_cache_block->GetKTensor(); - std::shared_ptr> v_tensor = - kv_state_cache_block->GetVTensor(); - - for (int i = 0; i < TOKEN_NUM; i++) { - for (int j = 0; j < DIMENSION; j++) { - VINEYARD_ASSERT(k_tensor->data()[i * DIMENSION + j] == - k_tensor_builder->data()[i * DIMENSION + j]); - VINEYARD_ASSERT(v_tensor->data()[i * DIMENSION + j] == - v_tensor_builder->data()[i * DIMENSION + j]); - } - } + // ObjectID id = builder->_Seal(client)->id(); + // std::shared_ptr kv_state_cache = + // std::dynamic_pointer_cast(client.GetObject(id)); + // std::vector> kv_state_cache_block_list = + // kv_state_cache->GetKVStateCacheBlockList(); + // std::vector kv_state_cache_block_builder_list = + // builder->GetKVStateCacheBlockBuilderList(); + // for (int i = 0; i < kv_state_cache_block_list.size(); i++) { + // std::shared_ptr kv_state_cache_block = + // kv_state_cache_block_list[i]; + // KVStateCacheBlockBuilder* kv_state_cache_block_builder = + // kv_state_cache_block_builder_list[i]; + + // // compare kv_state_cache_block and kv_state_cache_block_builder + // VINEYARD_ASSERT(kv_state_cache_block->GetDimension() == + // kv_state_cache_block_builder->GetDimension()); + + // VINEYARD_ASSERT(kv_state_cache_block->GetBitmap() == + // kv_state_cache_block_builder->GetBitmap()); + + // LOG(INFO) << "Bitmap:"; + // LOG(INFO) << kv_state_cache_block_builder->GetBitmapStr(); + // LOG(INFO) << kv_state_cache_block->GetBitmapStr(); + + // const std::shared_ptr> k_tensor_builder = + 
// kv_state_cache_block_builder->getKBuilder(); + // const std::shared_ptr> v_tensor_builder = + // kv_state_cache_block_builder->getVBuilder(); + + // std::shared_ptr> k_tensor = + // kv_state_cache_block->GetKTensor(); + // std::shared_ptr> v_tensor = + // kv_state_cache_block->GetVTensor(); + + // for (int i = 0; i < TOKEN_NUM; i++) { + // for (int j = 0; j < DIMENSION; j++) { + // VINEYARD_ASSERT(k_tensor->data()[i * DIMENSION + j] == + // k_tensor_builder->data()[i * DIMENSION + j]); + // VINEYARD_ASSERT(v_tensor->data()[i * DIMENSION + j] == + // v_tensor_builder->data()[i * DIMENSION + j]); + // } + // } + // } } void splitTest(Client& client, KVStateCacheBuilder* builder) {} int main() { - std::string socket = std::string(getenv("VINEYARD_IPC_SOCKET")); - Client client; - client.Connect(socket); + // std::string socket = std::string(getenv("VINEYARD_IPC_SOCKET")); + // Client client; + // client.Connect(socket); - LOG(INFO) << "Build kv state cache"; - KVStateCacheBuilder* kv_state_cache_builder = - new KVStateCacheBuilder(client, DIMENSION, CACHE_CAPACITY); + // LOG(INFO) << "Build kv state cache"; + // KVStateCacheBuilder* kv_state_cache_builder = + // new KVStateCacheBuilder(client, DIMENSION, CACHE_CAPACITY); - LOG(INFO) << "Prepare data"; - prepareData(kv_state_cache_builder); + // LOG(INFO) << "Prepare data"; + // prepareData(kv_state_cache_builder); - LOG(INFO) << "Test update"; - updateTest(client, kv_state_cache_builder); + // LOG(INFO) << "Test update"; + // updateTest(client, kv_state_cache_builder); - LOG(INFO) << "Test query"; - queryTest(client, kv_state_cache_builder); + // LOG(INFO) << "Test query"; + // queryTest(client, kv_state_cache_builder); - LOG(INFO) << "Test seal and construct"; - sealAndConstructTest(client, kv_state_cache_builder); + // LOG(INFO) << "Test seal and construct"; + // sealAndConstructTest(client, kv_state_cache_builder); return 0; } \ No newline at end of file diff --git a/test/kv_state_cache_test.cc 
b/test/kv_state_cache_test.cc index 68d5a2d1..988f6417 100644 --- a/test/kv_state_cache_test.cc +++ b/test/kv_state_cache_test.cc @@ -13,9 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. */ +#include #include #include #include +#include "kv-state-cache/radix-tree/radix.h" #include "common/util/logging.h" #include "kv-state-cache/utils/kv_state_cache_utils.h" @@ -23,8 +25,9 @@ limitations under the License. using namespace vineyard; #define DEMENSION 10 +#define CAPACITY 100 -void init() { initKVStateCache(DEMENSION); } +void init() { InitKVStateCache(DEMENSION, CAPACITY); } void print_current_tokens(const std::vector& prefix, int next_token) { std::string tokens_str = ""; @@ -33,7 +36,6 @@ void print_current_tokens(const std::vector& prefix, int next_token) { } tokens_str += std::to_string(next_token); LOG(INFO) << "Current tokens: " + tokens_str; - LOG(INFO) << tokens_str; } void print_kv_state( @@ -73,15 +75,15 @@ void inference(std::vector tokens, bool block = false) { std::map, std::vector>> kv_state; for (size_t i = 0; i < tokens.size(); ++i) { - kv_state = query(inference_tokens, tokens[i]); + kv_state = Query(inference_tokens, tokens[i]); if (kv_state.size() == 0) { LOG(INFO) << "======================================"; LOG(INFO) << "Can not find the kv_state from cache:"; print_current_tokens(inference_tokens, tokens[i]); LOG(INFO) << "Generate the kv_state and update the cache."; kv_state = generate_kv_state(tokens[i]); - update(inference_tokens, tokens[i], kv_state); print_kv_state(kv_state); + Update(inference_tokens, tokens[i], kv_state); LOG(INFO) << "======================================"; } else { LOG(INFO) << "--------------------------------------"; @@ -97,12 +99,31 @@ void inference(std::vector tokens, bool block = false) { int main() { init(); std::vector round_1_tokens = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}; - // std::vector round_2_tokens = {1, 2, 3, 4, 5, 7, 8, 9, 10}; + std::vector 
round_2_tokens = {1, 2, 3, 4, 5, 7, 8, + 9, 10, 11, 12, 13, 14}; + std::vector round_3_tokens = {1, 2, 3, 9, 10, 11, 12, 13, 14}; + // std::vector round_3_tokens = {1, 2, 3, 4, 5, 6, 7}; + // std::vector round_1_tokens = {1, 2}; + // std::vector round_2_tokens = {1, 3}; + // std::vector round_3_tokens = {1, 3, 4}; + // std::vector round_4_tokens = {1, 3, 5}; + // std::vector round_5_tokens = {1, 1}; inference(round_1_tokens); + // inference(round_1_tokens); + inference(round_2_tokens); sleep(5); + inference(round_3_tokens); + + inference(round_1_tokens); + inference(round_2_tokens); + inference(round_3_tokens); + // inference(round_3_tokens); + // inference(round_3_tokens); + // inference(round_4_tokens); + // inference(round_5_tokens); + // sleep(5); // inference(round_2_tokens); - // inference(round_2_tokens); - inference(round_1_tokens, true); + // inference(round_1_tokens, true); while (1) ; return 0; diff --git a/test/rax_diff_test.cc b/test/rax_diff_test.cc new file mode 100644 index 00000000..e86731d1 --- /dev/null +++ b/test/rax_diff_test.cc @@ -0,0 +1,101 @@ +/** Copyright 2020-2023 Alibaba Group Holding Limited. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +#include +#include +#include +#include "kv-state-cache/radix-tree/radix.h" + +int key_1[] = {1, 2}; +int key_2[] = {1, 3}; +int key_3[] = {1, 4}; +int key_4[] = {1, 3, 1}; +int key_5[] = {1, 3, 2}; + +void insert(rax* rt, int* key, int len) { + for (int i = 1; i <= len; i++) { + raxInsert(rt, key, i, NULL, NULL); + } +} + +int main(int argc, char** argv) { + rax* rt_1 = raxNew(); + rax* rt_2 = raxNew(); + + int max_node = argc > 1 ? atoi(argv[1]) : 3; + + // raxInsert(rt_2, key_1, 2, NULL, NULL); + // raxInsert(rt_2, key_2, 2, NULL, NULL); + + // raxInsert(rt_1, key_3, 2, NULL, NULL); + // raxInsert(rt_1, key_4, 3, NULL, NULL); + // raxInsert(rt_1, key_5, 3, NULL, NULL); + + insert(rt_1, key_3, 2); + insert(rt_1, key_4, 3); + + sleep(1); + + insert(rt_2, key_1, 2); + insert(rt_2, key_2, 2); + + sleep(1); + + insert(rt_1, key_5, 3); + + raxShow(rt_1); + printf("==============================\n"); + raxShow(rt_2); + printf("==============================\n"); + + testIteRax(rt_1); + printf("==============================\n"); + testIteRax(rt_2); + printf("==============================\n"); + + std::vector> evicted_tokens; + std::set> insert_tokens; + mergeTree(rt_1, rt_2, evicted_tokens, insert_tokens, max_node); + + printf("evicted_tokens:\n"); + for (size_t i = 0; i < evicted_tokens.size(); i++) { + for (size_t j = 0; j < evicted_tokens[i].size(); j++) { + printf("%d ", evicted_tokens[i][j]); + } + printf("\n"); + } + for (size_t i = 0; i < evicted_tokens.size(); i++) { + // void* tree_data; + raxRemove(rt_1, evicted_tokens[i].data(), evicted_tokens[i].size(), NULL, + NULL, false); + } + + for (auto it = insert_tokens.begin(); it != insert_tokens.end(); it++) { + raxInsert(rt_1, const_cast(it->data()), it->size(), NULL, NULL, + false); + } + + raxShow(rt_1); + printf("==============================\n"); + raxShow(rt_2); + printf("==============================\n"); + + testIteRax(rt_1); + printf("==============================\n"); + testIteRax(rt_2); + 
printf("==============================\n"); + + return 0; +} \ No newline at end of file From 32ad9f1658de991da44e0b2c0cd9f84ace4d4311 Mon Sep 17 00:00:00 2001 From: vegetableysm <108774481+vegetableysm@users.noreply.github.com> Date: Fri, 2 Feb 2024 11:11:05 +0800 Subject: [PATCH 05/20] Support merge cache object with radix tree. (#1741) - Support merge cache object with radix tree. - Use cache version to control merge process. Signed-off-by: vegetableysm --- modules/kv-state-cache/ds/kv_state_cache.cc | 31 ++-- modules/kv-state-cache/ds/kv_state_cache.h | 2 + .../kv-state-cache/radix-tree/radix-tree.cc | 31 +++- .../kv-state-cache/radix-tree/radix-tree.h | 7 +- modules/kv-state-cache/radix-tree/radix.cc | 60 ------- .../utils/kv_state_cache_utils.cc | 167 +++++++++--------- .../utils/kv_state_cache_utils.h | 2 + test/kv_state_cache_test.cc | 8 +- test/kv_state_cache_test_2.cc | 129 ++++++++++++++ 9 files changed, 276 insertions(+), 161 deletions(-) create mode 100644 test/kv_state_cache_test_2.cc diff --git a/modules/kv-state-cache/ds/kv_state_cache.cc b/modules/kv-state-cache/ds/kv_state_cache.cc index 43beec92..3437b223 100644 --- a/modules/kv-state-cache/ds/kv_state_cache.cc +++ b/modules/kv-state-cache/ds/kv_state_cache.cc @@ -59,6 +59,7 @@ void KVStateCache::Resolve() { // 3. 
construct the member field this->dimension = this->meta_.GetKeyValue("dimension"); + this->version = this->meta_.GetKeyValue("version"); LOG(INFO) << "construct the member field success" << std::endl; } @@ -260,34 +261,43 @@ void KVStateCacheBuilder::Delete(std::shared_ptr evictedNodeData) { void KVStateCacheBuilder::Merge(Client& client, std::shared_ptr kvStateCache) { - // TBD if (kvStateCache == nullptr) { return; } + std::shared_ptr globalCacheBuilder = std::make_shared(client, kvStateCache); std::shared_ptr globalCacheTree = kvStateCache->GetRootTree(); std::set> insertTokenList; std::vector> evicted_token_list; - mergeTree(this->rootTree->GetRootTree(), globalCacheTree->GetRootTree(), - evicted_token_list, insertTokenList, - this->rootTree->GetCacheCapacity()); + RadixTree::MergeTree(this->rootTree, globalCacheTree, evicted_token_list, + insertTokenList); + LOG(INFO) << "insert token list size:" << insertTokenList.size() + << " evicted token list size:" << evicted_token_list.size(); for (size_t i = 0; i < evicted_token_list.size(); i++) { - std::vector tokenList = evicted_token_list[i]; + std::vector tokenList = + evicted_token_list[evicted_token_list.size() - i - 1]; std::shared_ptr evictedNodeData; this->rootTree->Delete(tokenList, evictedNodeData); Delete(evictedNodeData); } + /** + * Set use lexicographical order to insert the token list, so the insert token + * list is sorted and will not cause insert failed.(Radix tree will reject a + * insert operation if the prefix of the insert token list is not in the + * tree.) 
+ */ for (auto it = insertTokenList.begin(); it != insertTokenList.end(); ++it) { - std::vector tokenList = *it; - KV_STATE_WITH_LAYER kvState = globalCacheBuilder->Query( - client, std::vector(tokenList.begin(), tokenList.end() - 1), - tokenList.back()); - this->Update(client, tokenList, tokenList[tokenList.size() - 1], kvState); + std::vector tokenList = + std::vector((*it).begin(), (*it).end() - 1); + KV_STATE_WITH_LAYER kvState = + globalCacheBuilder->Query(client, tokenList, (*it).back()); + this->Update(client, tokenList, (*it).back(), kvState); } + this->version = globalCacheBuilder->GetVersion(); return; } @@ -304,6 +314,7 @@ std::shared_ptr KVStateCacheBuilder::_Seal(Client& client) { // 1. store the member variables to cache object meta kvStateCache->meta_.AddKeyValue("dimension", this->dimension); + kvStateCache->meta_.AddKeyValue("version", this->version); // 2. seal all the block and put object id to cache object and // change the tree data from pointer to object id diff --git a/modules/kv-state-cache/ds/kv_state_cache.h b/modules/kv-state-cache/ds/kv_state_cache.h index d4498281..ca4f4b70 100644 --- a/modules/kv-state-cache/ds/kv_state_cache.h +++ b/modules/kv-state-cache/ds/kv_state_cache.h @@ -99,6 +99,8 @@ class KVStateCacheBuilder : public vineyard::ObjectBuilder { uint64_t GetVersion() { return this->version; } + void UpdateVersion() { this->version++; } + Status Build(Client& client) override; std::shared_ptr _Seal(Client& client) override; diff --git a/modules/kv-state-cache/radix-tree/radix-tree.cc b/modules/kv-state-cache/radix-tree/radix-tree.cc index 3442febe..215b9fcf 100644 --- a/modules/kv-state-cache/radix-tree/radix-tree.cc +++ b/modules/kv-state-cache/radix-tree/radix-tree.cc @@ -22,7 +22,8 @@ using namespace vineyard; RadixTree::RadixTree(int cacheCapacity) { this->tree = raxNew(); - this->cacheCapacity = cacheCapacity; + // add one to the cache capacity because we insert a root node to the tree. 
+ this->cacheCapacity = cacheCapacity + 1; this->nodeCount = 0; // prepare root node @@ -150,8 +151,9 @@ void RadixTree::DeleteInternal(std::vector tokens, DataWrapper* oldData; raxNode* subTreeNode; std::vector pre; - raxFindAndReturnDataNode(this->tree, deleteTokensArray, deleteTokensArrayLen, - &subTreeNode, false); + // raxFindAndReturnDataNode(this->tree, deleteTokensArray, + // deleteTokensArrayLen, + // &subTreeNode, false); int retval = raxRemove(this->tree, deleteTokensArray, deleteTokensArrayLen, (void**) &oldData, &subTreeNode); if (retval == 1) { @@ -285,6 +287,7 @@ std::string RadixTree::Serialize() { } std::string compressedStr = std::string(compressedData, compressedDataSize); + int cacheCapacity = this->cacheCapacity - 1; std::string result = std::string((char*) &srcSize, sizeof(int)) + std::string((char*) &cacheCapacity, sizeof(int)) + compressedStr; @@ -533,4 +536,26 @@ std::shared_ptr RadixTree::GetRootNode() { rootToken.size(), NULL); return std::make_shared((DataWrapper*) raxGetData(node), (DataWrapper*) node->custom_data); +} + +void RadixTree::MergeTree(std::shared_ptr tree_1, + std::shared_ptr tree_2, + std::vector>& evicted_tokens, + std::set>& insert_tokens) { + std::set> insert_tokens_set; + std::vector> evicted_tokens_vec; + mergeTree(tree_1->tree, tree_2->tree, evicted_tokens_vec, insert_tokens_set, + tree_1->cacheCapacity); + for (size_t i = 0; i < evicted_tokens_vec.size(); i++) { + VINEYARD_ASSERT(evicted_tokens_vec[i][0] == INT32_MAX); + std::vector tmp(evicted_tokens_vec[i].begin() + 1, + evicted_tokens_vec[i].end()); + evicted_tokens.push_back(tmp); + } + + for (auto& vec : insert_tokens_set) { + VINEYARD_ASSERT(vec[0] == INT32_MAX); + std::vector tmp(vec.begin() + 1, vec.end()); + insert_tokens.insert(tmp); + } } \ No newline at end of file diff --git a/modules/kv-state-cache/radix-tree/radix-tree.h b/modules/kv-state-cache/radix-tree/radix-tree.h index bcea8e67..78a7d562 100644 --- 
a/modules/kv-state-cache/radix-tree/radix-tree.h +++ b/modules/kv-state-cache/radix-tree/radix-tree.h @@ -92,11 +92,16 @@ class RadixTree : public std::enable_shared_from_this { rax* GetRootTree() { return this->tree; } - int GetCacheCapacity() { return cacheCapacity; } + int GetCacheCapacity() { return cacheCapacity - 1; } std::set GetSubTreeDataSet() { return subTreeDataSet; } std::shared_ptr GetRootNode(); + + static void MergeTree(std::shared_ptr tree_1, + std::shared_ptr tree_2, + std::vector>& evicted_tokens, + std::set>& insert_tokens); }; #endif diff --git a/modules/kv-state-cache/radix-tree/radix.cc b/modules/kv-state-cache/radix-tree/radix.cc index 428fa819..590d7513 100644 --- a/modules/kv-state-cache/radix-tree/radix.cc +++ b/modules/kv-state-cache/radix-tree/radix.cc @@ -2446,66 +2446,6 @@ bool compareKey(int *first_key, int *second_key, int first_key_len, int second_k return true; } -// bool compare(raxNode *a, raxNode *b) { -// return a->timestamp > b->timestamp; -// } - -// void sortNode(raxNode **node, int size) { -// std::sort(node, node + size, compare); -// } - -// void mergeTree(rax* first_tree, rax* second_tree, std::vector>& evicted_tokens, std::map, void*>& insert_tokens, int max_node) { -// raxNode* first_tree_node = first_tree->head; -// raxNode* second_tree_node = second_tree->head; - -// std::queue first_tree_queue; -// std::queue second_tree_queue; - -// first_tree_queue.push(first_tree_node); -// second_tree_queue.push(second_tree_node); - -// rax* tree = raxNew(); - -// int nodeCount = 0; - -// while((!first_tree_queue.empty()) && (!second_tree_queue.empty())) { -// int first_tree_rax_node_list_size = first_tree_queue.size(); -// int second_tree_rax_node_list_size = second_tree_queue.size(); -// raxNode** first_tree_rax_node_list = (raxNode**)malloc(sizeof(raxNode*) * first_tree_rax_node_list_size); -// raxNode** second_tree_rax_node_list = (raxNode**)malloc(sizeof(raxNode*) * second_tree_rax_node_list_size); - -// for (int i = 0; i < 
first_tree_queue.size(); i++) { -// first_tree_rax_node_list[i] = first_tree_queue.front(); -// first_tree_queue.pop(); -// } - -// for (int i = 0; i < second_tree_queue.size(); i++) { -// second_tree_rax_node_list[i] = second_tree_queue.front(); -// second_tree_queue.pop(); -// } - -// sortNode(first_tree_rax_node_list, first_tree_queue.size()); -// sortNode(second_tree_rax_node_list, second_tree_queue.size()); - -// int first_tree_index = 0; -// int second_tree_index = 0; - -// while(first_tree_index < first_tree_rax_node_list_size && second_tree_index < second_tree_rax_node_list_size && nodeCount < max_node) { -// if (first_tree_rax_node_list[first_tree_index]->timestamp > second_tree_rax_node_list[second_tree_index]->timestamp) { -// // choose first_tree_rax_node_list[first_tree_index] -// if (raxFind(tree, first_tree_rax_node_list[first_tree_index]->data, first_tree_rax_node_list[first_tree_index]->size) == NULL) { -// raxInsert(tree, first_tree_rax_node_list[first_tree_index]->data, first_tree_rax_node_list[first_tree_index]->size, first_tree_rax_node_list[first_tree_index]->data, NULL); -// nodeCount++; -// } else { -// std::vector token = std::vector(first_tree_rax_node_list[first_tree_index]->data, first_tree_rax_node_list[first_tree_index]->data + first_tree_rax_node_list[first_tree_index]->size); -// insert_tokens.erase(token); -// } -// first_tree_index++; -// } -// } -// } -// } - bool compare(raxIterator a, raxIterator b) { if (a.key_len == b.key_len) { return a.node->timestamp > b.node->timestamp; diff --git a/modules/kv-state-cache/utils/kv_state_cache_utils.cc b/modules/kv-state-cache/utils/kv_state_cache_utils.cc index 47638e55..7fba29e2 100644 --- a/modules/kv-state-cache/utils/kv_state_cache_utils.cc +++ b/modules/kv-state-cache/utils/kv_state_cache_utils.cc @@ -18,21 +18,30 @@ limitations under the License. 
#include "client/client.h" #include "common/util/logging.h" #include "kv-state-cache/ds/kv_state_cache.h" +#include "kv_state_cache_utils.h" using namespace vineyard; static Client client; -static std::shared_ptr kv_state_cache_builder = nullptr; -static std::string llm_cache_sync_lock = "llm_cache_sync_lock"; -static std::string llm_cache_object_name = "llm_cache_object"; -static std::thread* sync_thread; -static bool exit_flag = false; -static pthread_mutex_t sync_mutex; +static std::shared_ptr kvStateCacheBuilder = nullptr; +static std::string llmCacheSyncLock = "llmCacheSyncLock"; +static std::string llmCacheObjectName = "llm_cache_object"; +static std::thread* syncThread; +static bool exitFlag = false; +static pthread_mutex_t syncMutex; #ifndef SYNC_INTERVAL #define SYNC_INTERVAL 3 #endif +// for test +void Delete(std::vector token) { + std::shared_ptr evictedNode; + kvStateCacheBuilder->GetRootTree()->Delete(token, evictedNode); + kvStateCacheBuilder->Delete(evictedNode); + raxShow(kvStateCacheBuilder->GetRootTree()->tree); +} + void threadFunc(); void signalHandler(int signum) { @@ -42,25 +51,25 @@ void signalHandler(int signum) { * Use lease to prevent dead lock in the future. 
*/ std::cout << "Interrupt signal (" << signum << ") received.\n"; - exit_flag = true; - sync_thread->join(); + exitFlag = true; + syncThread->join(); exit(signum); } -void InitKVStateCache(int dimension = 10, int cacheCapacity = 10) { - if (kv_state_cache_builder == nullptr) { +void InitKVStateCache(int dimension, int cacheCapacity) { + if (kvStateCacheBuilder == nullptr) { std::string socket = std::string(getenv("VINEYARD_IPC_SOCKET")); LOG(INFO) << "socket:" << socket; client.Connect(socket); LOG(INFO) << "conneted"; - pthread_mutex_init(&sync_mutex, NULL); + pthread_mutex_init(&syncMutex, NULL); // TBD // try to get cache object - std::string actural_key; + std::string acturalKey; bool result; while (1) { - client.TryAcquireLock(llm_cache_sync_lock, result, actural_key); + client.TryAcquireLock(llmCacheSyncLock, result, acturalKey); if (!result) { LOG(INFO) << "failed to gain the lock, wait for next time"; sleep(1); @@ -71,29 +80,28 @@ void InitKVStateCache(int dimension = 10, int cacheCapacity = 10) { } // // sync global cache object with vineyard - ObjectID global_kv_state_cache_id; - Status status = - client.GetName(llm_cache_object_name, global_kv_state_cache_id); + ObjectID globalKVStateCacheID; + Status status = client.GetName(llmCacheObjectName, globalKVStateCacheID); if (status.ok()) { // if success, pull the cache object - std::shared_ptr global_kv_state_cache = + std::shared_ptr globalKVStateCache = std::dynamic_pointer_cast( - client.GetObject(global_kv_state_cache_id)); + client.GetObject(globalKVStateCacheID)); // TBD cache stragety - kv_state_cache_builder = - std::make_shared(client, global_kv_state_cache); + kvStateCacheBuilder = + std::make_shared(client, globalKVStateCache); } else { // if failed, create a new cache object LOG(INFO) << "failed to get the cache object, create a new one"; - kv_state_cache_builder = std::make_shared( + kvStateCacheBuilder = std::make_shared( client, dimension, cacheCapacity); } // // release the lock - 
client.TryReleaseLock(actural_key, result); + client.TryReleaseLock(acturalKey, result); VINEYARD_ASSERT(result == true); - sync_thread = new std::thread(threadFunc); + syncThread = new std::thread(threadFunc); signal(SIGINT, signalHandler); // TBD @@ -103,105 +111,113 @@ void InitKVStateCache(int dimension = 10, int cacheCapacity = 10) { void updateInternal(const std::vector& tokenList, int nextToken, const KV_STATE_WITH_LAYER& kvState) { - kv_state_cache_builder->Update(client, tokenList, nextToken, kvState); + kvStateCacheBuilder->Update(client, tokenList, nextToken, kvState); } void Update(const std::vector& tokenList, int nextToken, const KV_STATE_WITH_LAYER& kvState) { LOG(INFO) << "Update"; - if (pthread_mutex_trylock(&sync_mutex)) { + if (pthread_mutex_trylock(&syncMutex)) { return; } updateInternal(tokenList, nextToken, kvState); - pthread_mutex_unlock(&sync_mutex); + pthread_mutex_unlock(&syncMutex); } void Update(const std::vector& tokenList, const LIST_KV_STATE_WITH_LAYER& kvState) { - if (pthread_mutex_trylock(&sync_mutex)) { + if (pthread_mutex_trylock(&syncMutex)) { return; } - std::vector token_list_copy; + std::vector tokenListCopy; for (size_t i = 0; i < tokenList.size(); i++) { - updateInternal(token_list_copy, tokenList[i], kvState[i]); - token_list_copy.push_back(tokenList[i]); + updateInternal(tokenListCopy, tokenList[i], kvState[i]); + tokenListCopy.push_back(tokenList[i]); } - pthread_mutex_unlock(&sync_mutex); + pthread_mutex_unlock(&syncMutex); } KV_STATE_WITH_LAYER queryInternal(const std::vector& tokenList, int token) { - return kv_state_cache_builder->Query(client, tokenList, token); + return kvStateCacheBuilder->Query(client, tokenList, token); } KV_STATE_WITH_LAYER Query(const std::vector& tokenList, int token) { LOG(INFO) << "Query"; KV_STATE_WITH_LAYER result; - if (pthread_mutex_trylock(&sync_mutex)) { + if (pthread_mutex_trylock(&syncMutex)) { return result; } result = queryInternal(tokenList, token); - 
pthread_mutex_unlock(&sync_mutex); + pthread_mutex_unlock(&syncMutex); return result; } LIST_KV_STATE_WITH_LAYER Query(const std::vector& tokenList) { - LIST_KV_STATE_WITH_LAYER list_kv_state; - if (pthread_mutex_trylock(&sync_mutex)) { - return list_kv_state; + LIST_KV_STATE_WITH_LAYER listKVState; + if (pthread_mutex_trylock(&syncMutex)) { + return listKVState; } - std::vector token_list_copy; + std::vector tokenListCopy; for (size_t i = 0; i < tokenList.size(); i++) { - KV_STATE_WITH_LAYER kvState = queryInternal(token_list_copy, tokenList[i]); - list_kv_state.push_back(kvState); - token_list_copy.push_back(tokenList[i]); + KV_STATE_WITH_LAYER kvState = queryInternal(tokenListCopy, tokenList[i]); + listKVState.push_back(kvState); + tokenListCopy.push_back(tokenList[i]); } - pthread_mutex_unlock(&sync_mutex); - return list_kv_state; + pthread_mutex_unlock(&syncMutex); + return listKVState; } void sync() { LOG(INFO) << "sync"; // 1. gain the lock - std::string actural_key; + std::string acturalKey; bool result; - client.TryAcquireLock(llm_cache_sync_lock, result, actural_key); + client.TryAcquireLock(llmCacheSyncLock, result, acturalKey); if (!result) { LOG(INFO) << "failed to gain the lock, wait for next time"; return; } // 2. pull the cache object - ObjectID global_kv_state_cache_id; - std::vector delete_list; + ObjectID globalKVStateCacheID; + std::vector deleteList; - std::shared_ptr global_kv_state_cache = nullptr; - Status status = - client.GetName(llm_cache_object_name, global_kv_state_cache_id); + std::shared_ptr globalKVStateCache = nullptr; + Status status = client.GetName(llmCacheObjectName, globalKVStateCacheID); if (status.ok()) { - delete_list.push_back(global_kv_state_cache_id); - global_kv_state_cache = std::dynamic_pointer_cast( - client.GetObject(global_kv_state_cache_id)); + deleteList.push_back(globalKVStateCacheID); + globalKVStateCache = std::dynamic_pointer_cast( + client.GetObject(globalKVStateCacheID)); } // 3. 
merge the cache object - kv_state_cache_builder->Merge(client, global_kv_state_cache); + // only the global cache object with higher version will be merged + LOG(INFO) << "Current builder version:" << kvStateCacheBuilder->GetVersion() + << " global version:" + << (globalKVStateCache == nullptr + ? "null" + : std::to_string(globalKVStateCache->GetVersion())); + if (globalKVStateCache != nullptr && + kvStateCacheBuilder->GetVersion() < globalKVStateCache->GetVersion()) { + kvStateCacheBuilder->Merge(client, globalKVStateCache); + } + kvStateCacheBuilder->UpdateVersion(); // 4. push the cache object - std::shared_ptr kv_state_cache = - kv_state_cache_builder->_Seal(client); - client.Persist(kv_state_cache->id()); + std::shared_ptr kvStateCache = kvStateCacheBuilder->_Seal(client); + client.Persist(kvStateCache->id()); // 5. put the name of the new cache object to the meta server LOG(INFO) << "stage 5"; - client.DropName(llm_cache_object_name); - status = client.PutName(kv_state_cache->id(), llm_cache_object_name); + client.DropName(llmCacheObjectName); + status = client.PutName(kvStateCache->id(), llmCacheObjectName); if (status.ok()) { LOG(INFO) << "put name success"; } else { @@ -210,18 +226,17 @@ void sync() { LOG(INFO) << "stage 6"; // 6. delete old cache object - client.DelData(delete_list); + client.DelData(deleteList); LOG(INFO) << "stage 7"; // 7. create a global cache object replica - // TBD cache stragety - std::dynamic_pointer_cast(kv_state_cache)->Resolve(); - kv_state_cache_builder = std::make_shared( - client, std::dynamic_pointer_cast(kv_state_cache)); + std::dynamic_pointer_cast(kvStateCache)->Resolve(); + kvStateCacheBuilder = std::make_shared( + client, std::dynamic_pointer_cast(kvStateCache)); LOG(INFO) << "stage 8"; // 8. 
release the lock - client.TryReleaseLock(actural_key, result); + client.TryReleaseLock(acturalKey, result); VINEYARD_ASSERT(result == true); // TBD @@ -231,28 +246,12 @@ void sync() { void threadFunc() { while (1) { sleep(SYNC_INTERVAL); - if (exit_flag) { + if (exitFlag) { break; } LOG(INFO) << "Try sync"; - pthread_mutex_lock(&sync_mutex); + pthread_mutex_lock(&syncMutex); sync(); - pthread_mutex_unlock(&sync_mutex); + pthread_mutex_unlock(&syncMutex); } } - -/* - a. vineyardd with global cache object | sealed - b. client get the object replica - c. client Update replica - d. client seal the local object and try to push object to server (modified - sealed object and global cache version) â…°. if success - 1. vineyardd modify global object meta - 2. client reconstruct the local object replica - 3. goto c - â…±. if failed - 1. client pull the global object - 2. merge the object with local cache (e.g. create a new child_cache_object - and merge) - 3. goto d -*/ diff --git a/modules/kv-state-cache/utils/kv_state_cache_utils.h b/modules/kv-state-cache/utils/kv_state_cache_utils.h index fe930cb1..fa498c1e 100644 --- a/modules/kv-state-cache/utils/kv_state_cache_utils.h +++ b/modules/kv-state-cache/utils/kv_state_cache_utils.h @@ -30,4 +30,6 @@ KV_STATE_WITH_LAYER Query(const std::vector& tokenList, int token); LIST_KV_STATE_WITH_LAYER Query(const std::vector& tokenList); +void Delete(std::vector token); + #endif \ No newline at end of file diff --git a/test/kv_state_cache_test.cc b/test/kv_state_cache_test.cc index 988f6417..5782589f 100644 --- a/test/kv_state_cache_test.cc +++ b/test/kv_state_cache_test.cc @@ -25,7 +25,7 @@ limitations under the License. 
using namespace vineyard; #define DEMENSION 10 -#define CAPACITY 100 +#define CAPACITY 20 void init() { InitKVStateCache(DEMENSION, CAPACITY); } @@ -102,6 +102,9 @@ int main() { std::vector round_2_tokens = {1, 2, 3, 4, 5, 7, 8, 9, 10, 11, 12, 13, 14}; std::vector round_3_tokens = {1, 2, 3, 9, 10, 11, 12, 13, 14}; + // total 24 node + // tree 1 : 18 node + // tree 2 : 16 node // std::vector round_3_tokens = {1, 2, 3, 4, 5, 6, 7}; // std::vector round_1_tokens = {1, 2}; // std::vector round_2_tokens = {1, 3}; @@ -109,13 +112,12 @@ int main() { // std::vector round_4_tokens = {1, 3, 5}; // std::vector round_5_tokens = {1, 1}; inference(round_1_tokens); - // inference(round_1_tokens); inference(round_2_tokens); sleep(5); - inference(round_3_tokens); inference(round_1_tokens); inference(round_2_tokens); + // sleep(5); inference(round_3_tokens); // inference(round_3_tokens); // inference(round_3_tokens); diff --git a/test/kv_state_cache_test_2.cc b/test/kv_state_cache_test_2.cc new file mode 100644 index 00000000..28f8bdda --- /dev/null +++ b/test/kv_state_cache_test_2.cc @@ -0,0 +1,129 @@ +/** Copyright 2020-2023 Alibaba Group Holding Limited. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +#include +#include +#include +#include +#include "kv-state-cache/radix-tree/radix.h" + +#include "common/util/logging.h" +#include "kv-state-cache/utils/kv_state_cache_utils.h" + +using namespace vineyard; + +#define DEMENSION 10 +#define CAPACITY 20 + +void init() { InitKVStateCache(DEMENSION, CAPACITY); } + +void print_current_tokens(const std::vector& prefix, int next_token) { + std::string tokens_str = ""; + for (size_t i = 0; i < prefix.size(); ++i) { + tokens_str += std::to_string(prefix[i]); + } + tokens_str += std::to_string(next_token); + LOG(INFO) << "Current tokens: " + tokens_str; +} + +void print_kv_state( + const std::map, std::vector>>& + kv_state) { + LOG(INFO) << "kv_state: "; + for (auto iter = kv_state.begin(); iter != kv_state.end(); ++iter) { + std::string key_state_str = ""; + std::string value_state_str = ""; + for (int i = 0; i < DEMENSION; ++i) { + key_state_str += std::to_string(iter->second.first[i]) + " "; + value_state_str += std::to_string(iter->second.second[i]) + " "; + } + LOG(INFO) << "key_state: " << key_state_str; + LOG(INFO) << "value_state: " << value_state_str; + } +} + +// we do not consider the layer. 
+std::map, std::vector>> +generate_kv_state(int token) { + std::vector key_state; + std::vector value_state; + for (int i = 0; i < DEMENSION; ++i) { + key_state.push_back(((double) token) / DEMENSION * (i + 1)); + value_state.push_back(((double) token) / DEMENSION * (i + 1) * 2); + } + + std::map, std::vector>> kv_state; + kv_state.insert(std::make_pair(1, std::make_pair(key_state, value_state))); + return kv_state; +} + +void inference(std::vector tokens, bool block = false) { + LOG(INFO) << "inference"; + std::vector inference_tokens; + std::map, std::vector>> kv_state; + + for (size_t i = 0; i < tokens.size(); ++i) { + kv_state = Query(inference_tokens, tokens[i]); + if (kv_state.size() == 0) { + LOG(INFO) << "======================================"; + LOG(INFO) << "Can not find the kv_state from cache:"; + print_current_tokens(inference_tokens, tokens[i]); + LOG(INFO) << "Generate the kv_state and update the cache."; + kv_state = generate_kv_state(tokens[i]); + print_kv_state(kv_state); + Update(inference_tokens, tokens[i], kv_state); + LOG(INFO) << "======================================"; + } else { + LOG(INFO) << "--------------------------------------"; + LOG(INFO) << "Find the kv_state from cache:"; + print_current_tokens(inference_tokens, tokens[i]); + print_kv_state(kv_state); + LOG(INFO) << "--------------------------------------"; + } + inference_tokens.push_back(tokens[i]); + } +} + +int main() { + init(); + std::vector round_1_tokens = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}; + std::vector round_2_tokens = {1, 2, 3, 4, 5, 7, 8, + 9, 10, 11, 12, 13, 14}; + std::vector round_3_tokens = {1, 2, 3, 9, 10, 11, 12, 13, 14}; + // std::vector round_3_tokens = {1, 2, 3, 4, 5, 6, 7}; + // std::vector round_1_tokens = {1, 2}; + // std::vector round_2_tokens = {1, 3}; + // std::vector round_3_tokens = {1, 3, 4}; + // std::vector round_4_tokens = {1, 3, 5}; + // std::vector round_5_tokens = {1, 1}; + inference(round_1_tokens); + // inference(round_1_tokens); + 
inference(round_3_tokens); + sleep(5); + inference(round_1_tokens); + inference(round_3_tokens); + // inference(round_2_tokens); + + // inference(round_3_tokens); + // inference(round_3_tokens); + // inference(round_4_tokens); + // inference(round_5_tokens); + // sleep(5); + // inference(round_2_tokens); + // inference(round_1_tokens, true); + while (1) + ; + return 0; +} \ No newline at end of file From 024dc1e998c9bc11af346fc9b265940cb4ce0e98 Mon Sep 17 00:00:00 2001 From: Ye Cao Date: Sun, 4 Feb 2024 11:36:19 +0800 Subject: [PATCH 06/20] Replace lz4 with zstd (#1744) What do these changes do? ------------------------- As titled. Related issue number -------------------- Fixes #1743 Signed-off-by: Ye Cao --- .gitmodules | 3 - modules/kv-state-cache/CMakeLists.txt | 11 +--- modules/kv-state-cache/README.rst | 1 - modules/kv-state-cache/lz4 | 1 - .../kv-state-cache/radix-tree/radix-tree.cc | 63 ++++++------------- .../kv-state-cache/radix-tree/radix-tree.h | 1 - 6 files changed, 19 insertions(+), 61 deletions(-) delete mode 160000 modules/kv-state-cache/lz4 diff --git a/.gitmodules b/.gitmodules index 0c08193f..f326d1eb 100644 --- a/.gitmodules +++ b/.gitmodules @@ -49,6 +49,3 @@ [submodule "modules/graph/thirdparty/powturbo"] path = modules/graph/thirdparty/powturbo url = https://github.com/powturbo/TurboPFor-Integer-Compression.git -[submodule "modules/kv-state-cache/lz4"] - path = modules/kv-state-cache/lz4 - url = https://github.com/lz4/lz4.git diff --git a/modules/kv-state-cache/CMakeLists.txt b/modules/kv-state-cache/CMakeLists.txt index ca96b72e..82de5f67 100644 --- a/modules/kv-state-cache/CMakeLists.txt +++ b/modules/kv-state-cache/CMakeLists.txt @@ -1,11 +1,3 @@ -set(LZ4_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/lz4") - -add_custom_target(build_lz4 - COMMAND make -C ${LZ4_SOURCE_DIR} - WORKING_DIRECTORY ${LZ4_SOURCE_DIR}) - -file(GLOB LZ4_LIBRARIES "${LZ4_SOURCE_DIR}/lib/*.so") - file(GLOB VINEYARD_KV_STATE_CACHE_SRCS "${CMAKE_CURRENT_SOURCE_DIR}" "ds/*.cc" 
"ds/*.h" @@ -15,11 +7,10 @@ file(GLOB VINEYARD_KV_STATE_CACHE_SRCS "${CMAKE_CURRENT_SOURCE_DIR}" "utils/*.h" "strategy/*.cc" "strategy/*.h" - "lz4/lib/*.h" ) add_library(vineyard_kv_state_cache ${VINEYARD_KV_STATE_CACHE_SRCS}) -target_link_libraries(vineyard_kv_state_cache PUBLIC vineyard_client vineyard_basic ${LZ4_LIBRARIES}) +target_link_libraries(vineyard_kv_state_cache PUBLIC vineyard_client vineyard_basic) install_export_vineyard_target(vineyard_kv_state_cache) install_vineyard_headers("${CMAKE_CURRENT_SOURCE_DIR}/utils/") diff --git a/modules/kv-state-cache/README.rst b/modules/kv-state-cache/README.rst index 4acc1a20..c5e48f97 100644 --- a/modules/kv-state-cache/README.rst +++ b/modules/kv-state-cache/README.rst @@ -10,7 +10,6 @@ Build vineyard and vineyard test mkdir build cd build cmake .. -DBUILD_VINEYARD_TESTS=ON -DCMAKE_BUILD_TYPE=Debug - make build_lz4 make -j$(nproc) make vineyard_tests -j$(nproc) diff --git a/modules/kv-state-cache/lz4 b/modules/kv-state-cache/lz4 deleted file mode 160000 index 4cf83dd1..00000000 --- a/modules/kv-state-cache/lz4 +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 4cf83dd1952898e2a8c5fcd689ce459c53f22ff0 diff --git a/modules/kv-state-cache/radix-tree/radix-tree.cc b/modules/kv-state-cache/radix-tree/radix-tree.cc index 215b9fcf..c90063c5 100644 --- a/modules/kv-state-cache/radix-tree/radix-tree.cc +++ b/modules/kv-state-cache/radix-tree/radix-tree.cc @@ -14,10 +14,13 @@ limitations under the License. 
*/ #include "radix-tree.h" + #include "common/util/base64.h" #include "common/util/logging.h" #include "common/util/status.h" +#include "zstd/lib/zstd.h" + using namespace vineyard; RadixTree::RadixTree(int cacheCapacity) { @@ -261,37 +264,20 @@ std::string RadixTree::Serialize() { } LOG(INFO) << "serializedStr:" << serializedStr; - // use LZ4 to compress the serialized string - const char* const src = serializedStr.c_str(); - const int srcSize = serializedStr.size(); - const int maxDstSize = LZ4_compressBound(srcSize); - char* compressedData = new char[maxDstSize]; - if (compressedData == NULL) { - LOG(INFO) << "Failed to allocate memory for *compressedData."; - } - - const int compressedDataSize = - LZ4_compress_default(src, compressedData, srcSize, maxDstSize); - if (compressedDataSize <= 0) { - LOG(INFO) << "A 0 or negative result from LZ4_compress_default() " - "indicates a failure trying to compress the data. "; + // use ZSTD to compress the serialized string + size_t srcSize = serializedStr.size(); + std::string compressedStr(srcSize, '\0'); + int compressedSize = ZSTD_compress((void *)(compressedStr.c_str()), compressedStr.length(), + serializedStr.c_str(), srcSize, 3); + if (ZSTD_isError(compressedSize)) { + LOG(ERROR) << "ZSTD compression failed: " << ZSTD_getErrorName(compressedSize); } - - if (compressedDataSize > 0) { - LOG(INFO) << "We successfully compressed some data! Ratio: " - << ((float) compressedDataSize / srcSize); - } - - if (compressedData == NULL) { - LOG(INFO) << "Failed to re-alloc memory for compressedData. 
Sad :("; - } - - std::string compressedStr = std::string(compressedData, compressedDataSize); int cacheCapacity = this->cacheCapacity - 1; + std::string result = std::string((char*) &srcSize, sizeof(int)) + std::string((char*) &cacheCapacity, sizeof(int)) + compressedStr; - delete[] compressedData; + return result; } @@ -302,26 +288,13 @@ std::shared_ptr RadixTree::Deserialize(std::string data) { data.erase(0, sizeof(int)); int cacheCapacity = *(int*) data.c_str(); data.erase(0, sizeof(int)); - char* const decompressBuffer = new char[srcSize]; - if (decompressBuffer == NULL) { - LOG(INFO) << "Failed to allocate memory for *decompressBuffer."; - } - - const int decompressedSize = - LZ4_decompress_safe(data.c_str(), decompressBuffer, data.size(), srcSize); - if (decompressedSize < 0) { - LOG(INFO) << "A negative result from LZ4_decompress_safe indicates a " - "failure trying to decompress the data. See exit code " - "(echo $?) for value returned."; - } - if (decompressedSize >= 0) { - LOG(INFO) << "We successfully decompressed some data!"; + std::string decompressedStr(srcSize, '\0'); + int decompressedSize = ZSTD_decompress((void *)(decompressedStr.c_str()), decompressedStr.size(), + data.c_str(), srcSize); + if (ZSTD_isError(decompressedSize)) { + LOG(ERROR) << "ZSTD decompression failed: " << ZSTD_getErrorName(decompressedSize); } - // if (decompressedSize != data.size()) { - // LOG(INFO) << "Decompressed data is different from original! \n"; - // } - data = std::string(decompressBuffer, decompressedSize); - delete[] decompressBuffer; + data = decompressedStr.substr(0, decompressedSize); std::vector> tokenList; std::vector dataList; diff --git a/modules/kv-state-cache/radix-tree/radix-tree.h b/modules/kv-state-cache/radix-tree/radix-tree.h index 78a7d562..ce2d441c 100644 --- a/modules/kv-state-cache/radix-tree/radix-tree.h +++ b/modules/kv-state-cache/radix-tree/radix-tree.h @@ -20,7 +20,6 @@ limitations under the License. 
#include "common/util/base64.h" #include "common/util/logging.h" -#include "lz4.h" #include #include From 03d6325f48826ee40072241903dbcae0afb01be0 Mon Sep 17 00:00:00 2001 From: vegetableysm <108774481+vegetableysm@users.noreply.github.com> Date: Mon, 5 Feb 2024 13:15:47 +0800 Subject: [PATCH 07/20] Recycle resources of kv state cache object. (#1748) Fixes #1732 Signed-off-by: vegetableysm --- modules/kv-state-cache/ds/kv_state_cache.cc | 57 ++++++-- .../kv-state-cache/radix-tree/radix-tree.cc | 135 +++++++++++++----- .../kv-state-cache/radix-tree/radix-tree.h | 14 +- modules/kv-state-cache/radix-tree/radix.cc | 48 +++++-- modules/kv-state-cache/radix-tree/radix.h | 6 +- test/kv_state_cache_test.cc | 23 ++- test/rax_diff_test.cc | 2 +- 7 files changed, 201 insertions(+), 84 deletions(-) diff --git a/modules/kv-state-cache/ds/kv_state_cache.cc b/modules/kv-state-cache/ds/kv_state_cache.cc index 3437b223..25ebbb19 100644 --- a/modules/kv-state-cache/ds/kv_state_cache.cc +++ b/modules/kv-state-cache/ds/kv_state_cache.cc @@ -83,7 +83,7 @@ KVStateCacheBuilder::KVStateCacheBuilder(Client& client, int dimension, std::shared_ptr rootTreeHeader = this->rootTree->GetRootNode(); rootTreeHeader->treeData->data = treeData; rootTreeHeader->treeData->dataLength = sizeof(TreeData); - this->rootTree->SetSubtreeData(treeData, sizeof(TreeData)); + this->rootTree->SetSubtreeData(treeData); LOG(INFO) << "set builder:" << builder << " to tree:" << this->rootTree->GetRootTree()->head; LOG(INFO) << "data:" << treeData @@ -102,7 +102,7 @@ KVStateCacheBuilder::KVStateCacheBuilder(Client& client, std::set subTreeData = cache->rootTree->GetSubTreeDataSet(); for (auto iter = subTreeData.begin(); iter != subTreeData.end(); ++iter) { - TreeData* treeData = (TreeData*) ((DataWrapper*) *iter)->data; + TreeData* treeData = (TreeData*) (*iter); LOG(INFO) << "tree data:" << treeData; VINEYARD_ASSERT(treeData->isPtr == false); LOG(INFO) << "id:" << treeData->builderObjectID; @@ -135,13 +135,12 @@ 
KVStateCacheBlockBuilder* KVStateCacheBuilder::Split( kvStateCacheBlockBuilder->GetKeyStateBuilder(); const std::shared_ptr> valueStateTensorBuilder = kvStateCacheBlockBuilder->GetValueStateBuilder(); - OffsetData* new_offset_data = new OffsetData(); + OffsetData new_offset_data; childKVStateCacheBlockBuilder->Update( keyStateTensorBuilder->data() + index * this->dimension, valueStateTensorBuilder->data() + index * this->dimension, - this->dimension, new_offset_data); - nodeDataList[i]->nodeData->data = new_offset_data; - nodeDataList[i]->nodeData->dataLength = sizeof(OffsetData); + this->dimension, &new_offset_data); + data->offset = new_offset_data.offset; // Clear the bitmap. kvStateCacheBlockBuilder->DeleteKVCache(index); } @@ -177,6 +176,13 @@ void KVStateCacheBuilder::Update(Client& client, Delete(evictedNodeData); } + // if (evictedNodeData->treeData != nullptr && evictedNodeData->nodeData != + // nullptr) { + // if (evictedNodeData->nodeData->data != nullptr) { + // delete (TreeData*) evictedNodeData->nodeData->data; + // } + // } + // TBD // Use lock to protect the kv_state_cache_builder LOG(INFO) << "data:" << nodeData->treeData->data @@ -204,7 +210,7 @@ void KVStateCacheBuilder::Update(Client& client, subTreeHeader->treeData->data = newTreeData; subTreeHeader->treeData->dataLength = sizeof(TreeData); - rootTree->SetSubtreeData(newTreeData, sizeof(TreeData)); + rootTree->SetSubtreeData(newTreeData); LOG(INFO) << "block split success"; // kv_state_cache_builder->UnLock(); @@ -257,6 +263,15 @@ void KVStateCacheBuilder::Delete(std::shared_ptr evictedNodeData) { kvStateCacheBlockBuilder->DeleteKVCache(data->offset); LOG(INFO) << "stage4"; delete data; + // TBD + // Refactor this code. 
The data should be deleted by the RadixTree + // delete (DataWrapper*) evictedNodeData->nodeData; + LOG(INFO) << "tree data:" << evictedNodeData->treeData->data; + if (evictedNodeData->cleanTreeData) { + LOG(INFO) << "erase"; + this->rootTree->GetSubTreeDataSet().erase(evictedNodeData->treeData->data); + } + evictedNodeData->RecycleSource(); } void KVStateCacheBuilder::Merge(Client& client, @@ -320,11 +335,10 @@ std::shared_ptr KVStateCacheBuilder::_Seal(Client& client) { // change the tree data from pointer to object id int count = 0; - LOG(INFO) << "count:" << count; std::set subTreeDataSet = rootTree->GetSubTreeDataSet(); for (auto iter = subTreeDataSet.begin(); iter != subTreeDataSet.end(); ++iter) { - TreeData* treeData = (TreeData*) ((DataWrapper*) *iter)->data; + TreeData* treeData = (TreeData*) (*iter); VINEYARD_ASSERT(treeData != nullptr); VINEYARD_ASSERT(treeData->isPtr == true); @@ -357,12 +371,25 @@ std::shared_ptr KVStateCacheBuilder::_Seal(Client& client) { } KVStateCacheBuilder::~KVStateCacheBuilder() { - // TBD - // std::vector> nodeDataList = - // RadixTree::TraverseTreeWithoutSubTree(this->rootTree); - // for (size_t i = 0; i < nodeDataList.size(); i++) { - // delete (OffsetData*) nodeDataList[i]->get_node()->get_data(); - // } + LOG(INFO) << "KVStateCacheBuilder::~KVStateCacheBuilder"; + // get all subtree data and node data + std::set subTreeDataSet = rootTree->GetSubTreeDataSet(); + std::set nodeDataSet = rootTree->GetAllNodeData(); + // 2. 
delete all subtree data and node data + for (auto iter = subTreeDataSet.begin(); iter != subTreeDataSet.end(); + ++iter) { + TreeData* treeData = (TreeData*) (*iter); + if (treeData->isPtr == true) { + delete (KVStateCacheBlockBuilder*) treeData->kvStateCacheBlockBuilder; + delete treeData; + } + } + for (auto iter = nodeDataSet.begin(); iter != nodeDataSet.end(); ++iter) { + OffsetData* data = (OffsetData*) (*iter); + if (data != nullptr) { + delete data; + } + } } } // namespace vineyard diff --git a/modules/kv-state-cache/radix-tree/radix-tree.cc b/modules/kv-state-cache/radix-tree/radix-tree.cc index c90063c5..1c2dda09 100644 --- a/modules/kv-state-cache/radix-tree/radix-tree.cc +++ b/modules/kv-state-cache/radix-tree/radix-tree.cc @@ -40,22 +40,23 @@ RadixTree::RadixTree(int cacheCapacity) { data->data = nullptr; data->dataLength = 0; dataNode->custom_data = data; + LOG(INFO) << "root data wrapper:" << data; dataNode->issubtree = true; this->rootToken = rootToken; } RadixTree::~RadixTree() { - // TBD - // raxFreeWithCallback(this->tree, [](raxNode *n) { - // if (n->iskey && !n->isnull) { - // nodeData* nodedata = (nodeData*) raxGetData(n); - // delete nodedata; - // } - // if (n->issubtree && n->iscustomallocated && !n->iscustomnull) { - // customData* customdata = (customData*) raxGetCustomData(n); - // delete customdata; - // } - // }); + LOG(INFO) << "~RadixTree"; + raxShow(this->tree); + + raxNode* dataNode = raxFindAndReturnDataNode(this->tree, rootToken.data(), + rootToken.size(), NULL, false); + if (dataNode != nullptr) { + delete (DataWrapper*) dataNode->custom_data; + delete (DataWrapper*) raxGetData(dataNode); + } + + raxFree(this->tree); } std::shared_ptr RadixTree::Insert( @@ -154,15 +155,22 @@ void RadixTree::DeleteInternal(std::vector tokens, DataWrapper* oldData; raxNode* subTreeNode; std::vector pre; - // raxFindAndReturnDataNode(this->tree, deleteTokensArray, - // deleteTokensArrayLen, - // &subTreeNode, false); + raxNode* dataNode = 
raxFindAndReturnDataNode( + this->tree, deleteTokensArray, deleteTokensArrayLen, &subTreeNode, false); + bool nodeIsSubTree = false; + if (dataNode != nullptr && dataNode->issubtree) { + nodeIsSubTree = true; + } int retval = raxRemove(this->tree, deleteTokensArray, deleteTokensArrayLen, - (void**) &oldData, &subTreeNode); + (void**) &oldData); if (retval == 1) { evictedNode = std::make_shared( oldData, (DataWrapper*) subTreeNode->custom_data); nodeCount--; + if (nodeIsSubTree) { + // subTreeDataSet.erase(subTreeNode->custom_data); + evictedNode->cleanTreeData = true; + } } else { LOG(INFO) << "remove failed"; } @@ -224,6 +232,15 @@ std::string RadixTree::Serialize() { serializedStr += timestampOSS.str() + "|"; + raxNode* node = + raxFindAndReturnDataNode(this->tree, tokenList[index].data(), + tokenList[index].size(), NULL, false); + uint32_t numNodes = node->numnodes; + std::ostringstream subTreeSizeOSS; + subTreeSizeOSS << std::hex << numNodes; + + serializedStr += subTreeSizeOSS.str() + "|"; + // convert data to hex string char* bytes = (char*) ((DataWrapper*) dataList[index])->data; std::ostringstream dataOSS; @@ -251,7 +268,7 @@ std::string RadixTree::Serialize() { char* bytes = (char*) ((DataWrapper*) subTreeDataList[index])->data; std::ostringstream dataOSS; - LOG(INFO) << "data lengtÏ€h:" + LOG(INFO) << "data length:" << ((DataWrapper*) subTreeDataList[index])->dataLength; for (int i = 0; i < ((DataWrapper*) subTreeDataList[index])->dataLength; ++i) { @@ -267,17 +284,21 @@ std::string RadixTree::Serialize() { // use ZSTD to compress the serialized string size_t srcSize = serializedStr.size(); std::string compressedStr(srcSize, '\0'); - int compressedSize = ZSTD_compress((void *)(compressedStr.c_str()), compressedStr.length(), - serializedStr.c_str(), srcSize, 3); + int compressedSize = + ZSTD_compress((void*) (compressedStr.c_str()), compressedStr.length(), + serializedStr.c_str(), srcSize, 3); if (ZSTD_isError(compressedSize)) { - LOG(ERROR) << "ZSTD 
compression failed: " << ZSTD_getErrorName(compressedSize); + LOG(ERROR) << "ZSTD compression failed: " + << ZSTD_getErrorName(compressedSize); } int cacheCapacity = this->cacheCapacity - 1; - std::string result = std::string((char*) &srcSize, sizeof(int)) + - std::string((char*) &cacheCapacity, sizeof(int)) + - compressedStr; - + std::string result = + std::string((char*) &srcSize, sizeof(int)) + + std::string((char*) &cacheCapacity, sizeof(int)) + + std::string((char*) &(this->tree->head->numnodes), sizeof(uint32_t)) + + compressedStr; + return result; } @@ -288,11 +309,15 @@ std::shared_ptr RadixTree::Deserialize(std::string data) { data.erase(0, sizeof(int)); int cacheCapacity = *(int*) data.c_str(); data.erase(0, sizeof(int)); + int rootNumNodes = *(uint32_t*) data.c_str(); + data.erase(0, sizeof(uint32_t)); std::string decompressedStr(srcSize, '\0'); - int decompressedSize = ZSTD_decompress((void *)(decompressedStr.c_str()), decompressedStr.size(), - data.c_str(), srcSize); + int decompressedSize = + ZSTD_decompress((void*) (decompressedStr.c_str()), decompressedStr.size(), + data.c_str(), srcSize); if (ZSTD_isError(decompressedSize)) { - LOG(ERROR) << "ZSTD decompression failed: " << ZSTD_getErrorName(decompressedSize); + LOG(ERROR) << "ZSTD decompression failed: " + << ZSTD_getErrorName(decompressedSize); } data = decompressedStr.substr(0, decompressedSize); @@ -303,6 +328,7 @@ std::shared_ptr RadixTree::Deserialize(std::string data) { std::vector> subTreeTokenList; std::vector subTreeDataList; std::vector subTreeDataSizeList; + std::vector subTreeSizeList; std::istringstream iss(data); std::string line; bool isMainTree = true; @@ -315,7 +341,7 @@ std::shared_ptr RadixTree::Deserialize(std::string data) { } LOG(INFO) << "data line:" << line << std::endl; std::istringstream lineStream(line); - std::string tokenListPart, timestampPart, dataPart; + std::string tokenListPart, timestampPart, dataPart, subTreeSizePart; if (!std::getline(lineStream, tokenListPart, 
'|')) { throw std::runtime_error( @@ -326,6 +352,10 @@ std::shared_ptr RadixTree::Deserialize(std::string data) { throw std::runtime_error( "Invalid serialized string format in timestamp part."); } + if (!std::getline(lineStream, subTreeSizePart, '|')) { + throw std::runtime_error( + "Invalid serialized string format in sub tree size part."); + } } if (!std::getline(lineStream, dataPart)) { LOG(INFO) << "data length is 0"; @@ -345,6 +375,15 @@ std::shared_ptr RadixTree::Deserialize(std::string data) { LOG(INFO) << "Invalid timestamp format."; throw std::runtime_error("Invalid timestamp format."); } + + std::istringstream subTreeSizeStream(subTreeSizePart); + uint32_t subTreeSize; + if (!(subTreeSizeStream >> std::hex >> subTreeSize)) { + LOG(INFO) << "Invalid sub tree size format."; + throw std::runtime_error("Invalid sub tree size format."); + } + LOG(INFO) << "Deserialize sub tree size:" << subTreeSize; + subTreeSizeList.push_back(subTreeSize); } size_t dataSize = dataPart.length() / @@ -425,6 +464,16 @@ std::shared_ptr RadixTree::Deserialize(std::string data) { } dataNode->timestamp = timestampList[i]; } + + for (size_t i = 0; i < tokenList.size(); i++) { + raxNode* node = raxFindAndReturnDataNode( + radixTree->tree, tokenList[i].data(), tokenList[i].size(), NULL, false); + LOG(INFO) << "node:" << node << " sub tree node num:" << subTreeSizeList[i]; + node->numnodes = subTreeSizeList[i]; + } + radixTree->tree->head->numnodes = rootNumNodes; + raxShow(radixTree->tree); + LOG(INFO) << "start to insert sub tree token list" << std::endl; for (size_t i = 0; i < subTreeTokenList.size(); i++) { for (size_t j = 0; j < subTreeTokenList[i].size(); j++) { @@ -449,20 +498,18 @@ std::shared_ptr RadixTree::Deserialize(std::string data) { node->issubtree = true; raxSetCustomData(node, data); - // TBD - // refactor this code. 
- radixTree->subTreeDataSet.insert(data); + radixTree->subTreeDataSet.insert(subTreeDataList[i]); } LOG(INFO) << "Deserialize success"; + raxShow(radixTree->tree); return radixTree; } std::vector> RadixTree::SplitInternal( std::vector tokens, std::shared_ptr& header) { std::vector rootToken; - DataWrapper* dummyData = new DataWrapper(); raxNode* subTreeRootNode = - raxSplit(this->tree, tokens.data(), tokens.size(), dummyData, rootToken); + raxSplit(this->tree, tokens.data(), tokens.size(), rootToken); raxShow(this->tree); subTreeRootNode->issubtree = true; @@ -496,12 +543,9 @@ std::vector> RadixTree::TraverseTreeWithoutSubTree( return nodes; } -void RadixTree::SetSubtreeData(void* data, int dataLength) { - LOG(INFO) << "set subtree data"; - DataWrapper* dataWrapper = new DataWrapper(); - dataWrapper->data = data; - dataWrapper->dataLength = dataLength; - subTreeDataSet.insert(dataWrapper); +void RadixTree::SetSubtreeData(void* data) { + LOG(INFO) << "set subtree data:" << data; + subTreeDataSet.insert(data); } std::shared_ptr RadixTree::GetRootNode() { @@ -531,4 +575,19 @@ void RadixTree::MergeTree(std::shared_ptr tree_1, std::vector tmp(vec.begin() + 1, vec.end()); insert_tokens.insert(tmp); } +} + +std::set RadixTree::GetAllNodeData() { + raxIterator iter; + raxStart(&iter, this->tree); + raxSeek(&iter, "^", NULL, 0); + std::set nodeDataSet; + while (raxNext(&iter)) { + raxNode* node = iter.node; + if (node->isnull) { + continue; + } + nodeDataSet.insert(((DataWrapper*) raxGetData(node))->data); + } + return nodeDataSet; } \ No newline at end of file diff --git a/modules/kv-state-cache/radix-tree/radix-tree.h b/modules/kv-state-cache/radix-tree/radix-tree.h index ce2d441c..139f64e3 100644 --- a/modules/kv-state-cache/radix-tree/radix-tree.h +++ b/modules/kv-state-cache/radix-tree/radix-tree.h @@ -37,11 +37,21 @@ struct DataWrapper { struct NodeData { DataWrapper* nodeData; DataWrapper* treeData; + bool cleanTreeData = false; NodeData(DataWrapper* nodeData, 
DataWrapper* treeData) { this->nodeData = nodeData; this->treeData = treeData; } + + void RecycleSource() { + if (this->nodeData != nullptr) { + delete this->nodeData; + } + if (cleanTreeData && this->treeData != nullptr) { + delete this->treeData; + } + } }; class RadixTree : public std::enable_shared_from_this { @@ -87,7 +97,7 @@ class RadixTree : public std::enable_shared_from_this { static std::vector> TraverseTreeWithoutSubTree( raxNode* headNode); - void SetSubtreeData(void* data, int dataLength); + void SetSubtreeData(void* data); rax* GetRootTree() { return this->tree; } @@ -101,6 +111,8 @@ class RadixTree : public std::enable_shared_from_this { std::shared_ptr tree_2, std::vector>& evicted_tokens, std::set>& insert_tokens); + + std::set GetAllNodeData(); }; #endif diff --git a/modules/kv-state-cache/radix-tree/radix.cc b/modules/kv-state-cache/radix-tree/radix.cc index 590d7513..1373fb43 100644 --- a/modules/kv-state-cache/radix-tree/radix.cc +++ b/modules/kv-state-cache/radix-tree/radix.cc @@ -155,10 +155,14 @@ static inline void raxStackFree(raxStack *ts) { /* Add the number of nodes in the stack to each node. 
*/ void raxStackAddNumNodes(raxStack *stack, int num) { for (size_t i=0; iitems; i++) { - raxNode *node = (raxNode *)stack->stack[i]; + raxNode *node = (raxNode *)stack->stack[stack->items - i - 1]; node->numnodes+=(num); + if (node->issubtree) { + break; + } } } + /* ---------------------------------------------------------------------------- * Radix tree implementation * --------------------------------------------------------------------------*/ @@ -1097,6 +1101,25 @@ raxNode *raxFindAndReturnDataNode(rax *rax, int *s, size_t len, raxNode** sub_tr return h; } +// raxNode *raxSetSubTreeAndReturnDataNode(rax *rax, int *s, size_t len) { +// raxNode *h; + +// raxStack ts; +// raxStackInit(&ts); +// //debugf("### Lookup: %.*s\n", (int)len, s); +// int splitpos = 0; +// size_t i = raxLowWalk(rax,s,len,&h,NULL,&splitpos,&ts,false); +// if (i != len || (h->iscompr && splitpos != 0) || !h->iskey) +// return NULL; + +// if (h!= nullptr) { +// h->issubtree = true; +// raxStackAddNumNodes(&ts, -h->numnodes); +// } + +// return h; +// } + int raxFindNode(rax *rax, int *s, size_t len, void **node) { raxNode *h; raxStack ts; @@ -1227,7 +1250,7 @@ raxNode *raxRemoveChild(raxNode *parent, raxNode *child) { /* Remove the specified item. Returns 1 if the item was found and * deleted, 0 otherwise. 
*/ -int raxRemove(rax *rax, int *s, size_t len, void **old, raxNode** sub_tree_node, bool set_timestamp) { +int raxRemove(rax *rax, int *s, size_t len, void **old, bool set_timestamp) { raxNode *h; raxStack ts; @@ -1240,14 +1263,6 @@ int raxRemove(rax *rax, int *s, size_t len, void **old, raxNode** sub_tree_node, raxStackFree(&ts); return 0; } - if (sub_tree_node != NULL) { - for (int i = ts.items - 1; i >= 0; i--) { - if (((raxNode *)ts.stack[i])->issubtree == true) { - *sub_tree_node = (raxNode *)ts.stack[i]; - break; - } - } - } if (old) *old = raxGetData(h); @@ -2137,6 +2152,11 @@ uint64_t raxSize(rax *rax) { * [1,2] -> [1,2,3,4] -> [] */ +struct datawrapper { + void *data; + int length; +}; + /* The actual implementation of raxShow(). */ void raxRecursiveShow(int level, int lpad, raxNode *n) { char s = n->iscompr ? '"' : '['; @@ -2156,6 +2176,9 @@ void raxRecursiveShow(int level, int lpad, raxNode *n) { numchars += printf("=%p",raxGetData(n)); } numchars += printf(" node:%p time:%ld, data:%p, is_sub_tree:%d", n, n->timestamp, n->custom_data, n->issubtree); + if (n->issubtree && n->custom_data != NULL) { + numchars += printf(" cus data:%p" , ((datawrapper *)(n->custom_data))->data); + } int numchildren = n->iscompr ? 1 : n->size; /* Note that 7 and 4 magic constants are the string length @@ -2184,6 +2207,7 @@ void raxRecursiveShow(int level, int lpad, raxNode *n) { /* Show a tree, as outlined in the comment above. */ void raxShow(rax *rax) { + printf("rax numnode:%lu\n", rax->numele); raxRecursiveShow(0,0,rax->head); putchar('\n'); } @@ -2295,7 +2319,7 @@ bool raxIsSubtree(raxNode *node) { * tree from the root node. 
* */ -raxNode *raxSplit(rax *rax, int *s, size_t len, void *data, std::vector& token) { +raxNode *raxSplit(rax *rax, int *s, size_t len, std::vector& token) { raxNode *childNode = NULL; raxNode *splitNode = NULL; raxStack stack = raxFindWithStack(rax, s, len); @@ -2349,7 +2373,7 @@ raxNode *raxSplit(rax *rax, int *s, size_t len, void *data, std::vector& to raxSetSubtree(splitNode); - raxStackAddNumNodes(&stack, -(int)(splitNode->numnodes)); + raxStackAddNumNodes(&stack, -(int)(splitNode->numnodes)); raxStackFree(&stack); return splitNode; diff --git a/modules/kv-state-cache/radix-tree/radix.h b/modules/kv-state-cache/radix-tree/radix.h index 25321aa7..57da727c 100644 --- a/modules/kv-state-cache/radix-tree/radix.h +++ b/modules/kv-state-cache/radix-tree/radix.h @@ -111,6 +111,7 @@ typedef struct raxNode { uint32_t numnodes; /* Number of the child nodes */ uint32_t numele; /* Number of elements inside this node. */ uint64_t timestamp; /* Timestamps of the node */ + uint32_t sub_tree_size; /* Number of nodes in the sub tree */ void *custom_data; /* Data layout is as follows: * @@ -212,7 +213,7 @@ int raxInsert(rax *rax, int *s, size_t len, void *data, void **old, bool set_tim int raxTryInsert(rax* rax, int* s, size_t len, void* data, void** old); int raxInsertAndReturnDataNode(rax* rax, int* s, size_t len, void* data, void** node, void** old); -int raxRemove(rax* rax, int* s, size_t len, void** old, raxNode** sub_tree_node = NULL, bool set_timestamp = true); +int raxRemove(rax* rax, int* s, size_t len, void** old, bool set_timestamp = true); void* raxFind(rax* rax, int* s, size_t len); raxNode* raxFindAndReturnDataNode(rax* rax, int* s, size_t len, raxNode** sub_tree_node = NULL, bool set_timestamp = true); void raxSetSubtree(raxNode *n); @@ -238,7 +239,7 @@ void raxSetDebugMsg(int onoff); void raxTraverse(raxNode* rax, std::vector>& dataNodeList); void raxTraverseSubTree(raxNode* n, std::vector &dataNodeList); -raxNode *raxSplit(rax *rax, int *s, size_t len, void 
*data, std::vector& key); +raxNode *raxSplit(rax *rax, int *s, size_t len, std::vector& key); void raxSerialize(rax* root, std::vector>& tokenList, std::vector& dataList, std::vector ×tampsList, std::vector> *subtreeList, std::vector *subtreeNodeList); @@ -251,4 +252,5 @@ void raxFindLastRecentNode(raxNode *node, std::vector& key); void mergeTree(rax* first_tree, rax* second_tree, std::vector>& evicted_tokens, std::set>& insert_tokens, int max_node); void testIteRax(rax *tree); raxNode* raxGetFirstChildPtr(raxNode* node); +// raxNode *raxSetSubTreeAndReturnDataNode(rax *rax, int *s, size_t len); #endif diff --git a/test/kv_state_cache_test.cc b/test/kv_state_cache_test.cc index 5782589f..a1f4ed21 100644 --- a/test/kv_state_cache_test.cc +++ b/test/kv_state_cache_test.cc @@ -102,30 +102,23 @@ int main() { std::vector round_2_tokens = {1, 2, 3, 4, 5, 7, 8, 9, 10, 11, 12, 13, 14}; std::vector round_3_tokens = {1, 2, 3, 9, 10, 11, 12, 13, 14}; - // total 24 node - // tree 1 : 18 node - // tree 2 : 16 node - // std::vector round_3_tokens = {1, 2, 3, 4, 5, 6, 7}; - // std::vector round_1_tokens = {1, 2}; - // std::vector round_2_tokens = {1, 3}; - // std::vector round_3_tokens = {1, 3, 4}; - // std::vector round_4_tokens = {1, 3, 5}; - // std::vector round_5_tokens = {1, 1}; + std::vector round_4_tokens = {1, 2, 3, 4, 5, 6}; + inference(round_1_tokens); inference(round_2_tokens); sleep(5); inference(round_1_tokens); inference(round_2_tokens); - // sleep(5); inference(round_3_tokens); - // inference(round_3_tokens); - // inference(round_3_tokens); // inference(round_4_tokens); - // inference(round_5_tokens); // sleep(5); - // inference(round_2_tokens); - // inference(round_1_tokens, true); + // Delete(std::vector(round_4_tokens.begin(), round_4_tokens.begin() + + // 6)); Delete(std::vector(round_4_tokens.begin(), round_4_tokens.begin() + // + 5)); Delete(std::vector(round_4_tokens.begin(), + // round_4_tokens.begin() + 4)); + // Delete(std::vector(round_4_tokens.begin(), 
round_4_tokens.begin() + + // 3)); while (1) ; return 0; diff --git a/test/rax_diff_test.cc b/test/rax_diff_test.cc index e86731d1..ce1dcf1b 100644 --- a/test/rax_diff_test.cc +++ b/test/rax_diff_test.cc @@ -79,7 +79,7 @@ int main(int argc, char** argv) { for (size_t i = 0; i < evicted_tokens.size(); i++) { // void* tree_data; raxRemove(rt_1, evicted_tokens[i].data(), evicted_tokens[i].size(), NULL, - NULL, false); + false); } for (auto it = insert_tokens.begin(); it != insert_tokens.end(); it++) { From 5065f3b47b8f77600a7a816fb79071b26917364c Mon Sep 17 00:00:00 2001 From: Ye Cao Date: Mon, 5 Feb 2024 19:14:40 +0800 Subject: [PATCH 08/20] Add the unit test for the radix tree. (#1749) Fixes #1747 Signed-off-by: Ye Cao --- modules/kv-state-cache/radix-tree/radix.cc | 2 +- test/kv_state_cache_radix_tree_test.cc | 189 +++++++++++++++++++++ 2 files changed, 190 insertions(+), 1 deletion(-) create mode 100644 test/kv_state_cache_radix_tree_test.cc diff --git a/modules/kv-state-cache/radix-tree/radix.cc b/modules/kv-state-cache/radix-tree/radix.cc index 1373fb43..8e5df06c 100644 --- a/modules/kv-state-cache/radix-tree/radix.cc +++ b/modules/kv-state-cache/radix-tree/radix.cc @@ -2336,8 +2336,8 @@ raxNode *raxSplit(rax *rax, int *s, size_t len, std::vector& token) { index--; } - // find the node that has N/2 children + childNode = (raxNode *)stack.stack[stack.items - 1]; while (items > 0) { raxNode *node = (raxNode *)raxStackPop(&stack); if (node->numnodes > (uint32_t)subtreeNumNodes/2 || node->issubtree) { diff --git a/test/kv_state_cache_radix_tree_test.cc b/test/kv_state_cache_radix_tree_test.cc new file mode 100644 index 00000000..c6faa61a --- /dev/null +++ b/test/kv_state_cache_radix_tree_test.cc @@ -0,0 +1,189 @@ +/** Copyright 2020-2023 Alibaba Group Holding Limited. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include +#include +#include +#include +#include "kv-state-cache/radix-tree/radix.h" + +#include "common/util/logging.h" +#include "kv-state-cache/utils/kv_state_cache_utils.h" + +using namespace vineyard; + +void print_tokens(const std::vector& tokens) { + std::string tokens_str = ""; + for (size_t i = 0; i < tokens.size(); ++i) { + tokens_str += std::to_string(tokens[i]); + } + LOG(INFO) << "Current tokens: " + tokens_str; +} + +void radix_tree_insert_test() { + std::shared_ptr radix_tree = std::make_shared(10); + + /* insert a token list*/ + std::vector tokens; + std::shared_ptr node_data; + for (int i = 0; i < 10; i++) { + tokens.push_back(i); + VINEYARD_ASSERT(radix_tree->Insert(tokens, node_data) != NULL); + VINEYARD_ASSERT(node_data == NULL); + } + + /* insert new token and check whether the old token is evicted */ + tokens.clear(); + for (int i = 1; i < 10; i++) { + tokens.push_back(i); + VINEYARD_ASSERT(radix_tree->Insert(tokens, node_data) != NULL); + VINEYARD_ASSERT(node_data != NULL); + } + + /* insert a token that prefix is not in the radix tree */ + tokens.clear(); + for (int i = 10; i > 0; i--) { + tokens.push_back(i); + } + VINEYARD_ASSERT(radix_tree->Insert(tokens, node_data) == NULL); +} + +void radix_tree_delete_test(){ + std::shared_ptr radix_tree = std::make_shared(10); + + /* insert a token list*/ + std::vector tokens; + std::shared_ptr node_data; + for (int i = 0; i < 10; i++) { + tokens.push_back(i); + VINEYARD_ASSERT(radix_tree->Insert(tokens, node_data) != NULL); + VINEYARD_ASSERT(node_data == NULL); + } + + /* delete a token list*/ + + 
tokens.clear(); + node_data = NULL; + for (int i = 0; i < 5; i++) { + tokens.push_back(i); + } + radix_tree->Delete(tokens, node_data); + VINEYARD_ASSERT(node_data != NULL); + + /* delete a token list that is not in the radix tree */ + tokens.clear(); + node_data = NULL; + for (int i = 10; i > 0; i--) { + tokens.push_back(i); + } + radix_tree->Delete(tokens, node_data); + VINEYARD_ASSERT(node_data == NULL); +} + +void radix_tree_query_test() { + std::shared_ptr radix_tree = std::make_shared(10); + + /* insert a token list*/ + std::vector tokens; + std::shared_ptr node_data; + for (int i = 0; i < 10; i++) { + tokens.push_back(i); + VINEYARD_ASSERT(radix_tree->Insert(tokens, node_data) != NULL); + VINEYARD_ASSERT(node_data == NULL); + } + + /* query a token list*/ + tokens.clear(); + for (int i = 0; i < 5; i++) { + tokens.push_back(i); + } + VINEYARD_ASSERT(radix_tree->Query(tokens) != NULL); + + /* query a token list that is not in the radix tree */ + tokens.clear(); + for (int i = 10; i > 0; i--) { + tokens.push_back(i); + } + VINEYARD_ASSERT(radix_tree->Query(tokens) == NULL); +} + +void radix_tree_serialize_and_deserailize() { + std::shared_ptr radix_tree = std::make_shared(10); + + /* insert a token list*/ + std::vector tokens; + std::shared_ptr node_data; + for (int i = 0; i < 10; i++) { + tokens.push_back(i); + VINEYARD_ASSERT(radix_tree->Insert(tokens, node_data) != NULL); + VINEYARD_ASSERT(node_data == NULL); + } + + /* serialize radix tree */ + std::string serialized_radix_tree = radix_tree->Serialize(); + + /* deserialize radix tree */ + std::shared_ptr deserialized_radix_tree = radix_tree->Deserialize(serialized_radix_tree); + + /* query to check whether all token list exist */ + tokens.clear(); + for (int i = 0; i < 10; i++) { + tokens.push_back(i); + print_tokens(tokens); + VINEYARD_ASSERT(deserialized_radix_tree->Query(tokens) != NULL); + } +} + +void radix_tree_split() { + std::shared_ptr radix_tree = std::make_shared(20); + + 
raxShow(radix_tree->tree); + /* insert a token list*/ + std::vector tokens; + std::shared_ptr node_data; + for (int i = 0; i < 10; i++) { + tokens.push_back(i); + VINEYARD_ASSERT(radix_tree->Insert(tokens, node_data) != NULL); + VINEYARD_ASSERT(node_data == NULL); + } + + /* split a token list*/ + tokens.clear(); + for (int i = 0; i < 5; i++) { + tokens.push_back(i); + } + print_tokens(tokens); + std::shared_ptr subTreeHeader; + std::vector> node_data_list = radix_tree->Split(tokens, subTreeHeader); + VINEYARD_ASSERT(node_data_list.size() == 7); +} + +int main() { + LOG(INFO) << "Start to test radix tree insert..."; + radix_tree_insert_test(); + LOG(INFO) << "Finish radix tree insert test!"; + LOG(INFO) << "Start to test radix tree delete..."; + radix_tree_delete_test(); + LOG(INFO) << "Finish radix tree delete test!"; + LOG(INFO) << "Start to test radix tree query..."; + radix_tree_query_test(); + LOG(INFO) << "Finish radix tree query test!"; + LOG(INFO) << "Start to test radix tree serialize and deserialize..."; + radix_tree_serialize_and_deserailize(); + LOG(INFO) << "Finish radix tree serialize and deserialize test!"; + LOG(INFO) << "Start to test radix tree split..."; + radix_tree_split(); + LOG(INFO) << "Finish radix tree split test!"; +} \ No newline at end of file From e7fafc8dd50babc352b4a6ed97261c775c9cc259 Mon Sep 17 00:00:00 2001 From: Ye Cao Date: Mon, 5 Feb 2024 23:46:24 +0800 Subject: [PATCH 09/20] Fix the bug in the radix split logic. 
(#1751) Signed-off-by: Ye Cao --- modules/kv-state-cache/radix-tree/radix.cc | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/modules/kv-state-cache/radix-tree/radix.cc b/modules/kv-state-cache/radix-tree/radix.cc index 8e5df06c..48a115bc 100644 --- a/modules/kv-state-cache/radix-tree/radix.cc +++ b/modules/kv-state-cache/radix-tree/radix.cc @@ -2337,7 +2337,6 @@ raxNode *raxSplit(rax *rax, int *s, size_t len, std::vector& token) { } // find the node that has N/2 children - childNode = (raxNode *)stack.stack[stack.items - 1]; while (items > 0) { raxNode *node = (raxNode *)raxStackPop(&stack); if (node->numnodes > (uint32_t)subtreeNumNodes/2 || node->issubtree) { @@ -2368,7 +2367,12 @@ raxNode *raxSplit(rax *rax, int *s, size_t len, std::vector& token) { // if the splitNode is NULL, it means that the tree only has one node if (splitNode == NULL) { - return rax->head; + // if the stack is not empty, it means that top of the stack is the target node + if (stack.items > 0) { + return (raxNode *)raxStackPop(&stack); + } else { + return rax->head; + } } raxSetSubtree(splitNode); From 17be7a94b5271b7eb469058d8241f68d4dac7954 Mon Sep 17 00:00:00 2001 From: vegetableysm <108774481+vegetableysm@users.noreply.github.com> Date: Thu, 22 Feb 2024 19:37:28 +0800 Subject: [PATCH 10/20] Support layer for kv state cache (#1766) Fixes #1742 Signed-off-by: vegetableysm --- modules/kv-state-cache/ds/kv_state_cache.cc | 24 +- modules/kv-state-cache/ds/kv_state_cache.h | 9 +- .../kv-state-cache/ds/kv_state_cache_block.cc | 173 ++++++----- .../kv-state-cache/ds/kv_state_cache_block.h | 53 +++- .../utils/kv_state_cache_utils.cc | 16 +- .../utils/kv_state_cache_utils.h | 3 +- test/kv_state_cache_radix_tree_test.cc | 276 +++++++++--------- test/kv_state_cache_test.cc | 26 +- test/kv_state_cache_test_2.cc | 26 +- 9 files changed, 347 insertions(+), 259 deletions(-) diff --git a/modules/kv-state-cache/ds/kv_state_cache.cc b/modules/kv-state-cache/ds/kv_state_cache.cc 
index 25ebbb19..bccf17cc 100644 --- a/modules/kv-state-cache/ds/kv_state_cache.cc +++ b/modules/kv-state-cache/ds/kv_state_cache.cc @@ -60,6 +60,7 @@ void KVStateCache::Resolve() { // 3. construct the member field this->dimension = this->meta_.GetKeyValue("dimension"); this->version = this->meta_.GetKeyValue("version"); + this->layer = this->meta_.GetKeyValue("layer"); LOG(INFO) << "construct the member field success" << std::endl; } @@ -68,11 +69,12 @@ KVStateCache::~KVStateCache() { } KVStateCacheBuilder::KVStateCacheBuilder(Client& client, int dimension, - int cacheCapacity) { + int cacheCapacity, int layer) { this->dimension = dimension; this->version = 0; + this->layer = layer; KVStateCacheBlockBuilder* builder = - new KVStateCacheBlockBuilder(client, this->dimension); + new KVStateCacheBlockBuilder(client, this->dimension, layer); this->rootTree = std::make_shared(cacheCapacity); @@ -95,6 +97,7 @@ KVStateCacheBuilder::KVStateCacheBuilder(Client& client, // TBD this->dimension = cache->GetDemension(); this->version = cache->GetVersion(); + this->layer = cache->GetLayer(); // 1. create block builder from block std::map> kvStateCacheBlockMap = cache->kvStateCacheBlockMap; @@ -123,7 +126,7 @@ KVStateCacheBlockBuilder* KVStateCacheBuilder::Split( // Split the tree if the list of kvState is full. VINEYARD_ASSERT(nodeDataList.size() > 0); KVStateCacheBlockBuilder* childKVStateCacheBlockBuilder = - new KVStateCacheBlockBuilder(client, this->dimension); + new KVStateCacheBlockBuilder(client, this->dimension, this->layer); for (size_t i = 0; i < nodeDataList.size(); i++) { OffsetData* data = (OffsetData*) nodeDataList[i]->nodeData->data; if (data == nullptr) @@ -131,18 +134,8 @@ KVStateCacheBlockBuilder* KVStateCacheBuilder::Split( int index = data->offset; // Transfer the data from this builder to the child builder. 
- const std::shared_ptr> keyStateTensorBuilder = - kvStateCacheBlockBuilder->GetKeyStateBuilder(); - const std::shared_ptr> valueStateTensorBuilder = - kvStateCacheBlockBuilder->GetValueStateBuilder(); - OffsetData new_offset_data; - childKVStateCacheBlockBuilder->Update( - keyStateTensorBuilder->data() + index * this->dimension, - valueStateTensorBuilder->data() + index * this->dimension, - this->dimension, &new_offset_data); - data->offset = new_offset_data.offset; - // Clear the bitmap. - kvStateCacheBlockBuilder->DeleteKVCache(index); + data->offset = + kvStateCacheBlockBuilder->Split(childKVStateCacheBlockBuilder, index); } LOG(INFO) << "builder:" << kvStateCacheBlockBuilder << " bitmap:" << kvStateCacheBlockBuilder->GetBitmapStr(); @@ -330,6 +323,7 @@ std::shared_ptr KVStateCacheBuilder::_Seal(Client& client) { // 1. store the member variables to cache object meta kvStateCache->meta_.AddKeyValue("dimension", this->dimension); kvStateCache->meta_.AddKeyValue("version", this->version); + kvStateCache->meta_.AddKeyValue("layer", this->layer); // 2. 
seal all the block and put object id to cache object and // change the tree data from pointer to object id diff --git a/modules/kv-state-cache/ds/kv_state_cache.h b/modules/kv-state-cache/ds/kv_state_cache.h index ca4f4b70..fbdffb7b 100644 --- a/modules/kv-state-cache/ds/kv_state_cache.h +++ b/modules/kv-state-cache/ds/kv_state_cache.h @@ -43,6 +43,7 @@ class KVStateCache : public vineyard::Registered { std::shared_ptr rootTree; int dimension; int cacheCapacity; + int layer; uint64_t version; public: @@ -68,6 +69,8 @@ class KVStateCache : public vineyard::Registered { std::shared_ptr GetRootTree() { return this->rootTree; } + int GetLayer() { return this->layer; } + ~KVStateCache(); friend class KVStateCacheBuilder; @@ -76,10 +79,12 @@ class KVStateCache : public vineyard::Registered { class KVStateCacheBuilder : public vineyard::ObjectBuilder { std::shared_ptr rootTree; int dimension; + int layer = 1; uint64_t version; public: - KVStateCacheBuilder(Client& client, int dimension, int cacheCapacity); + KVStateCacheBuilder(Client& client, int dimension, int cacheCapacity, + int layer); KVStateCacheBuilder(Client& client, std::shared_ptr cache); @@ -109,6 +114,8 @@ class KVStateCacheBuilder : public vineyard::ObjectBuilder { std::shared_ptr GetRootTree() { return this->rootTree; } + int GetLayer() { return this->layer; } + ~KVStateCacheBuilder(); }; diff --git a/modules/kv-state-cache/ds/kv_state_cache_block.cc b/modules/kv-state-cache/ds/kv_state_cache_block.cc index 2b49af15..097fbc8d 100644 --- a/modules/kv-state-cache/ds/kv_state_cache_block.cc +++ b/modules/kv-state-cache/ds/kv_state_cache_block.cc @@ -49,63 +49,79 @@ void KVStateCacheBlock::Construct(const ObjectMeta& meta) { // TBD // 1. 
construct the keyStateTensorBuilder and valueStateTensorBuilder - this->keyStateTensor = std::dynamic_pointer_cast>( - this->meta_.GetMember("keyStateTensorBuilder")); - this->valueStateTensor = std::dynamic_pointer_cast>( - this->meta_.GetMember("valueStateTensorBuilder")); + this->layer = this->meta_.GetKeyValue("layer"); + for (int currentLayer = 0; currentLayer < this->layer; currentLayer++) { + this->keyStateTensorList.push_back( + std::dynamic_pointer_cast>(this->meta_.GetMember( + "keyStateTensorBuilder_" + std::to_string(currentLayer)))); + this->valueStateTensorList.push_back( + std::dynamic_pointer_cast>(this->meta_.GetMember( + "valueStateTensorBuilder_" + std::to_string(currentLayer)))); + } // 2. construct the member field this->bitmap = this->meta_.GetKeyValue("bitmap"); this->dimension = this->meta_.GetKeyValue("dimension"); } KVStateCacheBlockBuilder::KVStateCacheBlockBuilder(Client& client, - int dimension) { + int dimension, int layer) { this->bitmap = UINT64_MAX; std::vector shape = {LIST_SIZE, dimension}; - this->keyStateTensorBuilder = - std::make_shared>(client, shape); - this->valueStateTensorBuilder = - std::make_shared>(client, shape); + for (int i = 0; i < layer; i++) { + this->keyStateTensorBuilderList.push_back( + std::make_shared>(client, shape)); + this->valueStateTensorBuilderList.push_back( + std::make_shared>(client, shape)); + } this->dimension = dimension; + this->layer = layer; } KVStateCacheBlockBuilder::KVStateCacheBlockBuilder( Client& client, std::shared_ptr kvStateCacheBlock) { this->bitmap = kvStateCacheBlock->bitmap; this->dimension = kvStateCacheBlock->dimension; + this->layer = kvStateCacheBlock->layer; std::vector shape = {LIST_SIZE, dimension}; - this->keyStateTensorBuilder = - std::make_shared>(client, shape); - this->valueStateTensorBuilder = - std::make_shared>(client, shape); - - // transfer the data from kv_state_cache to this builder - memcpy(this->keyStateTensorBuilder->data(), - 
kvStateCacheBlock->keyStateTensor->data(), - LIST_SIZE * this->dimension * sizeof(double)); - memcpy(this->valueStateTensorBuilder->data(), - kvStateCacheBlock->valueStateTensor->data(), - LIST_SIZE * this->dimension * sizeof(double)); + for (int currentLayer = 0; currentLayer < this->layer; currentLayer++) { + this->keyStateTensorBuilderList.push_back( + std::make_shared>(client, shape)); + this->valueStateTensorBuilderList.push_back( + std::make_shared>(client, shape)); + } + + for (int currentLayer = 0; currentLayer < this->layer; currentLayer++) { + memcpy(this->keyStateTensorBuilderList[currentLayer]->data(), + kvStateCacheBlock->keyStateTensorList[currentLayer]->data(), + LIST_SIZE * this->dimension * sizeof(double)); + memcpy(this->valueStateTensorBuilderList[currentLayer]->data(), + kvStateCacheBlock->valueStateTensorList[currentLayer]->data(), + LIST_SIZE * this->dimension * sizeof(double)); + } } // current we do not consider the layer. Status KVStateCacheBlockBuilder::Query(Client& client, int index, KV_STATE_WITH_LAYER& kvState) { - std::vector keyStateVector; - std::vector valueStateVector; - - for (int i = 0; i < this->dimension; ++i) { - keyStateVector.push_back( - ((double*) keyStateTensorBuilder->data())[index * dimension + i]); - } - - for (int i = 0; i < this->dimension; ++i) { - valueStateVector.push_back( - ((double*) valueStateTensorBuilder->data())[index * dimension + i]); + for (int currentLayer = 0; currentLayer < this->layer; currentLayer++) { + std::vector keyStateVector; + std::vector valueStateVector; + + for (int i = 0; i < this->dimension; ++i) { + keyStateVector.push_back( + ((double*) keyStateTensorBuilderList[currentLayer] + ->data())[index * dimension + i]); + } + + for (int i = 0; i < this->dimension; ++i) { + valueStateVector.push_back( + ((double*) valueStateTensorBuilderList[currentLayer] + ->data())[index * dimension + i]); + } + + kvState.insert(std::make_pair( + currentLayer, std::make_pair(keyStateVector, 
valueStateVector))); } - - kvState.insert( - std::make_pair(1, std::make_pair(keyStateVector, valueStateVector))); return Status::OK(); } @@ -124,40 +140,60 @@ void KVStateCacheBlockBuilder::Update(const KV_STATE_WITH_LAYER& kvState, OffsetData* data) { int index = this->FindEmptySlot(); LOG(INFO) << "index:" << index; - std::vector keyStateVector = (kvState.find(1)->second).first; - std::vector valueStateVector = (kvState.find(1)->second).second; - VINEYARD_ASSERT(keyStateVector.size() == (size_t) this->dimension); - VINEYARD_ASSERT(valueStateVector.size() == (size_t) this->dimension); - - double* keyData = (double*) keyStateTensorBuilder->data(); - double* valueData = (double*) valueStateTensorBuilder->data(); - for (int i = 0; i < this->dimension; ++i) { - keyData[index * this->dimension + i] = keyStateVector[i]; - } - for (int i = 0; i < this->dimension; ++i) { - valueData[index * this->dimension + i] = valueStateVector[i]; + LOG(INFO) << "layer:" << layer; + for (int currentLayer = 0; currentLayer < this->layer; currentLayer++) { + std::vector keyStateVector = + (kvState.find(currentLayer)->second).first; + std::vector valueStateVector = + (kvState.find(currentLayer)->second).second; + LOG(INFO) << "vector size:" << keyStateVector.size() << " " + << valueStateVector.size() << " demension" << this->dimension; + VINEYARD_ASSERT(keyStateVector.size() == (size_t) this->dimension); + VINEYARD_ASSERT(valueStateVector.size() == (size_t) this->dimension); + + double* keyData = (double*) keyStateTensorBuilderList[currentLayer]->data(); + double* valueData = + (double*) valueStateTensorBuilderList[currentLayer]->data(); + memcpy(keyData + index * this->dimension, keyStateVector.data(), + this->dimension * sizeof(double)); + memcpy(valueData + index * this->dimension, valueStateVector.data(), + this->dimension * sizeof(double)); } data->offset = index; ACQUIRE_BIT_RESOURCE(this->bitmap, index); } -void KVStateCacheBlockBuilder::Update(double* keyState, double* 
valueState, - unsigned long dataLength, - OffsetData* data) { - int index = FindEmptySlot(); - double* keyData = (double*) keyStateTensorBuilder->data(); - double* valueData = (double*) valueStateTensorBuilder->data(); - VINEYARD_ASSERT((unsigned long) this->dimension == dataLength); - for (unsigned long i = 0; i < dataLength; ++i) { - keyData[index * this->dimension + i] = keyState[i]; - } - for (unsigned long i = 0; i < dataLength; ++i) { - valueData[index * this->dimension + i] = valueState[i]; +short KVStateCacheBlockBuilder::Split(KVStateCacheBlockBuilder* child, + int index) { + // TBD + VINEYARD_ASSERT(this->layer == child->layer); + int childIndex = child->FindEmptySlot(); + for (int currentLayer = 0; currentLayer < this->layer; currentLayer++) { + std::shared_ptr> keyStateTensorBuilder = + keyStateTensorBuilderList[currentLayer]; + std::shared_ptr> valueStateTensorBuilder = + valueStateTensorBuilderList[currentLayer]; + std::shared_ptr> childKeyStateTensorBuilder = + child->keyStateTensorBuilderList[currentLayer]; + std::shared_ptr> childValueStateTensorBuilder = + child->valueStateTensorBuilderList[currentLayer]; + + double* keyState = + (double*) keyStateTensorBuilder->data() + index * this->dimension; + double* valueState = + (double*) valueStateTensorBuilder->data() + index * this->dimension; + double* childKeyState = (double*) childKeyStateTensorBuilder->data() + + childIndex * this->dimension; + double* childValueState = (double*) childValueStateTensorBuilder->data() + + childIndex * this->dimension; + + memcpy(childKeyState, keyState, this->dimension * sizeof(double)); + memcpy(childValueState, valueState, this->dimension * sizeof(double)); } - data->offset = index; - - ACQUIRE_BIT_RESOURCE(this->bitmap, index); + ACQUIRE_BIT_RESOURCE(child->bitmap, childIndex); + FREE_BIT_RESOURCE(this->bitmap, index); + return childIndex; } Status KVStateCacheBlockBuilder::Build(Client& client) { return Status::OK(); } @@ -170,14 +206,19 @@ std::shared_ptr 
KVStateCacheBlockBuilder::_Seal(Client& client) { std::make_shared(); // 1. seal keyStateTensorBuilder and valueStateTensorBuilder - kvStateCacheBlock->meta_.AddMember("keyStateTensorBuilder", - keyStateTensorBuilder->Seal(client)); - kvStateCacheBlock->meta_.AddMember("valueStateTensorBuilder", - valueStateTensorBuilder->Seal(client)); + for (int currentLayer = 0; currentLayer < this->layer; currentLayer++) { + kvStateCacheBlock->meta_.AddMember( + "keyStateTensorBuilder_" + std::to_string(currentLayer), + keyStateTensorBuilderList[currentLayer]->Seal(client)); + kvStateCacheBlock->meta_.AddMember( + "valueStateTensorBuilder_" + std::to_string(currentLayer), + valueStateTensorBuilderList[currentLayer]->Seal(client)); + } // 2. store the member field to meta kvStateCacheBlock->meta_.AddKeyValue("bitmap", this->bitmap); kvStateCacheBlock->meta_.AddKeyValue("dimension", this->dimension); + kvStateCacheBlock->meta_.AddKeyValue("layer", this->layer); // 3. set the object type to meta kvStateCacheBlock->meta_.SetTypeName(type_name()); diff --git a/modules/kv-state-cache/ds/kv_state_cache_block.h b/modules/kv-state-cache/ds/kv_state_cache_block.h index 667cd467..8602dd40 100644 --- a/modules/kv-state-cache/ds/kv_state_cache_block.h +++ b/modules/kv-state-cache/ds/kv_state_cache_block.h @@ -31,6 +31,10 @@ typedef std::map, std::vector>> typedef std::vector< std::map, std::vector>>> LIST_KV_STATE_WITH_LAYER; +typedef std::vector, std::vector>> + KV_STATE; +typedef std::vector, std::vector>> + LIST_KV_STATE; // Set the bit to 1, which means the resource is not being used #define FREE_BIT_RESOURCE(value, bit) ((value) |= (((uint64_t) 1) << (bit))) @@ -61,10 +65,11 @@ namespace vineyard { class KVStateCacheBlock : public vineyard::Registered { private: - std::shared_ptr> keyStateTensor; - std::shared_ptr> valueStateTensor; + std::vector>> keyStateTensorList; + std::vector>> valueStateTensorList; uint64_t bitmap; ObjectID id; + int layer; int dimension; public: @@ -81,12 
+86,20 @@ class KVStateCacheBlock : public vineyard::Registered { uint64_t GetBitmap() { return this->bitmap; } - std::shared_ptr> GetKeyTensor() { - return this->keyStateTensor; + std::shared_ptr> GetKeyTensor(int layer) { + return this->keyStateTensorList[layer]; } - std::shared_ptr> GetValueTensor() { - return this->valueStateTensor; + std::shared_ptr> GetValueTensor(int layer) { + return this->valueStateTensorList[layer]; + } + + std::vector>> GetKeyTensorList() { + return this->keyStateTensorList; + } + + std::vector>> GetValueTensorList() { + return this->valueStateTensorList; } friend class KVStateCacheBlockBuilder; @@ -94,17 +107,19 @@ class KVStateCacheBlock : public vineyard::Registered { class KVStateCacheBlockBuilder : public ObjectBuilder { private: - std::shared_ptr> keyStateTensorBuilder; - std::shared_ptr> valueStateTensorBuilder; + std::vector>> keyStateTensorBuilderList; + std::vector>> + valueStateTensorBuilderList; // TBD // support more than 64 kv-state cache slots uint64_t bitmap; int dimension; + int layer; int FindEmptySlot(); public: - KVStateCacheBlockBuilder(Client& client, int dimension); + KVStateCacheBlockBuilder(Client& client, int dimension, int layer); KVStateCacheBlockBuilder( Client& client, std::shared_ptr kv_state_cache_block); @@ -137,12 +152,24 @@ class KVStateCacheBlockBuilder : public ObjectBuilder { std::shared_ptr _Seal(Client& client) override; - const std::shared_ptr> GetKeyStateBuilder() { - return keyStateTensorBuilder; + short Split(KVStateCacheBlockBuilder* child, int index); + + const std::shared_ptr> GetKeyStateBuilder(int layer) { + return keyStateTensorBuilderList[layer]; + } + + const std::shared_ptr> GetValueStateBuilder(int layer) { + return valueStateTensorBuilderList[layer]; + } + + const std::vector>> + GetKeyStateBuilderList() { + return keyStateTensorBuilderList; } - const std::shared_ptr> GetValueStateBuilder() { - return valueStateTensorBuilder; + const std::vector>> + GetValueStateBuilderList() { + 
return valueStateTensorBuilderList; } void DeleteKVCache(int bit) { FREE_BIT_RESOURCE(this->bitmap, bit); } diff --git a/modules/kv-state-cache/utils/kv_state_cache_utils.cc b/modules/kv-state-cache/utils/kv_state_cache_utils.cc index 7fba29e2..06b6d96e 100644 --- a/modules/kv-state-cache/utils/kv_state_cache_utils.cc +++ b/modules/kv-state-cache/utils/kv_state_cache_utils.cc @@ -56,7 +56,7 @@ void signalHandler(int signum) { exit(signum); } -void InitKVStateCache(int dimension, int cacheCapacity) { +void InitKVStateCache(int dimension, int cacheCapacity, int layer) { if (kvStateCacheBuilder == nullptr) { std::string socket = std::string(getenv("VINEYARD_IPC_SOCKET")); LOG(INFO) << "socket:" << socket; @@ -94,7 +94,7 @@ void InitKVStateCache(int dimension, int cacheCapacity) { // if failed, create a new cache object LOG(INFO) << "failed to get the cache object, create a new one"; kvStateCacheBuilder = std::make_shared( - client, dimension, cacheCapacity); + client, dimension, cacheCapacity, layer); } // // release the lock @@ -109,7 +109,7 @@ void InitKVStateCache(int dimension, int cacheCapacity) { } } -void updateInternal(const std::vector& tokenList, int nextToken, +void UpdateInternal(const std::vector& tokenList, int nextToken, const KV_STATE_WITH_LAYER& kvState) { kvStateCacheBuilder->Update(client, tokenList, nextToken, kvState); } @@ -121,7 +121,7 @@ void Update(const std::vector& tokenList, int nextToken, return; } - updateInternal(tokenList, nextToken, kvState); + UpdateInternal(tokenList, nextToken, kvState); pthread_mutex_unlock(&syncMutex); } @@ -133,13 +133,13 @@ void Update(const std::vector& tokenList, } std::vector tokenListCopy; for (size_t i = 0; i < tokenList.size(); i++) { - updateInternal(tokenListCopy, tokenList[i], kvState[i]); + UpdateInternal(tokenListCopy, tokenList[i], kvState[i]); tokenListCopy.push_back(tokenList[i]); } pthread_mutex_unlock(&syncMutex); } -KV_STATE_WITH_LAYER queryInternal(const std::vector& tokenList, 
+KV_STATE_WITH_LAYER QueryInternal(const std::vector& tokenList, int token) { return kvStateCacheBuilder->Query(client, tokenList, token); } @@ -151,7 +151,7 @@ KV_STATE_WITH_LAYER Query(const std::vector& tokenList, int token) { return result; } - result = queryInternal(tokenList, token); + result = QueryInternal(tokenList, token); pthread_mutex_unlock(&syncMutex); return result; @@ -165,7 +165,7 @@ LIST_KV_STATE_WITH_LAYER Query(const std::vector& tokenList) { std::vector tokenListCopy; for (size_t i = 0; i < tokenList.size(); i++) { - KV_STATE_WITH_LAYER kvState = queryInternal(tokenListCopy, tokenList[i]); + KV_STATE_WITH_LAYER kvState = QueryInternal(tokenListCopy, tokenList[i]); listKVState.push_back(kvState); tokenListCopy.push_back(tokenList[i]); } diff --git a/modules/kv-state-cache/utils/kv_state_cache_utils.h b/modules/kv-state-cache/utils/kv_state_cache_utils.h index fa498c1e..4110ae96 100644 --- a/modules/kv-state-cache/utils/kv_state_cache_utils.h +++ b/modules/kv-state-cache/utils/kv_state_cache_utils.h @@ -18,7 +18,8 @@ limitations under the License. 
#ifndef MODULES_KV_STATE_CACHE_UTILS_H_ #define MODULES_KV_STATE_CACHE_UTILS_H_ -void InitKVStateCache(int dimension = 10, int cacheCapacity = 10); +void InitKVStateCache(int dimension = 10, int cacheCapacity = 10, + int layer = 1); void Update(const std::vector& tokenList, int nextToken, const KV_STATE_WITH_LAYER& kvState); diff --git a/test/kv_state_cache_radix_tree_test.cc b/test/kv_state_cache_radix_tree_test.cc index c6faa61a..2c06f8c5 100644 --- a/test/kv_state_cache_radix_tree_test.cc +++ b/test/kv_state_cache_radix_tree_test.cc @@ -33,157 +33,159 @@ void print_tokens(const std::vector& tokens) { } void radix_tree_insert_test() { - std::shared_ptr radix_tree = std::make_shared(10); - - /* insert a token list*/ - std::vector tokens; - std::shared_ptr node_data; - for (int i = 0; i < 10; i++) { - tokens.push_back(i); - VINEYARD_ASSERT(radix_tree->Insert(tokens, node_data) != NULL); - VINEYARD_ASSERT(node_data == NULL); - } - - /* insert new token and check whether the old token is evicted */ - tokens.clear(); - for (int i = 1; i < 10; i++) { - tokens.push_back(i); - VINEYARD_ASSERT(radix_tree->Insert(tokens, node_data) != NULL); - VINEYARD_ASSERT(node_data != NULL); - } - - /* insert a token that prefix is not in the radix tree */ - tokens.clear(); - for (int i = 10; i > 0; i--) { - tokens.push_back(i); - } - VINEYARD_ASSERT(radix_tree->Insert(tokens, node_data) == NULL); -} + std::shared_ptr radix_tree = std::make_shared(10); + + /* insert a token list*/ + std::vector tokens; + std::shared_ptr node_data; + for (int i = 0; i < 10; i++) { + tokens.push_back(i); + VINEYARD_ASSERT(radix_tree->Insert(tokens, node_data) != NULL); + VINEYARD_ASSERT(node_data == NULL); + } -void radix_tree_delete_test(){ - std::shared_ptr radix_tree = std::make_shared(10); - - /* insert a token list*/ - std::vector tokens; - std::shared_ptr node_data; - for (int i = 0; i < 10; i++) { - tokens.push_back(i); - VINEYARD_ASSERT(radix_tree->Insert(tokens, node_data) != NULL); - 
VINEYARD_ASSERT(node_data == NULL); - } - - /* delete a token list*/ - - tokens.clear(); - node_data = NULL; - for (int i = 0; i < 5; i++) { - tokens.push_back(i); - } - radix_tree->Delete(tokens, node_data); + /* insert new token and check whether the old token is evicted */ + tokens.clear(); + for (int i = 1; i < 10; i++) { + tokens.push_back(i); + VINEYARD_ASSERT(radix_tree->Insert(tokens, node_data) != NULL); VINEYARD_ASSERT(node_data != NULL); + } + + /* insert a token that prefix is not in the radix tree */ + tokens.clear(); + for (int i = 10; i > 0; i--) { + tokens.push_back(i); + } + VINEYARD_ASSERT(radix_tree->Insert(tokens, node_data) == NULL); +} + +void radix_tree_delete_test() { + std::shared_ptr radix_tree = std::make_shared(10); - /* delete a token list that is not in the radix tree */ - tokens.clear(); - node_data = NULL; - for (int i = 10; i > 0; i--) { - tokens.push_back(i); - } - radix_tree->Delete(tokens, node_data); + /* insert a token list*/ + std::vector tokens; + std::shared_ptr node_data; + for (int i = 0; i < 10; i++) { + tokens.push_back(i); + VINEYARD_ASSERT(radix_tree->Insert(tokens, node_data) != NULL); VINEYARD_ASSERT(node_data == NULL); + } + + /* delete a token list*/ + + tokens.clear(); + node_data = NULL; + for (int i = 0; i < 5; i++) { + tokens.push_back(i); + } + radix_tree->Delete(tokens, node_data); + VINEYARD_ASSERT(node_data != NULL); + + /* delete a token list that is not in the radix tree */ + tokens.clear(); + node_data = NULL; + for (int i = 10; i > 0; i--) { + tokens.push_back(i); + } + radix_tree->Delete(tokens, node_data); + VINEYARD_ASSERT(node_data == NULL); } void radix_tree_query_test() { - std::shared_ptr radix_tree = std::make_shared(10); - - /* insert a token list*/ - std::vector tokens; - std::shared_ptr node_data; - for (int i = 0; i < 10; i++) { - tokens.push_back(i); - VINEYARD_ASSERT(radix_tree->Insert(tokens, node_data) != NULL); - VINEYARD_ASSERT(node_data == NULL); - } - - /* query a token list*/ - 
tokens.clear(); - for (int i = 0; i < 5; i++) { - tokens.push_back(i); - } - VINEYARD_ASSERT(radix_tree->Query(tokens) != NULL); - - /* query a token list that is not in the radix tree */ - tokens.clear(); - for (int i = 10; i > 0; i--) { - tokens.push_back(i); - } - VINEYARD_ASSERT(radix_tree->Query(tokens) == NULL); + std::shared_ptr radix_tree = std::make_shared(10); + + /* insert a token list*/ + std::vector tokens; + std::shared_ptr node_data; + for (int i = 0; i < 10; i++) { + tokens.push_back(i); + VINEYARD_ASSERT(radix_tree->Insert(tokens, node_data) != NULL); + VINEYARD_ASSERT(node_data == NULL); + } + + /* query a token list*/ + tokens.clear(); + for (int i = 0; i < 5; i++) { + tokens.push_back(i); + } + VINEYARD_ASSERT(radix_tree->Query(tokens) != NULL); + + /* query a token list that is not in the radix tree */ + tokens.clear(); + for (int i = 10; i > 0; i--) { + tokens.push_back(i); + } + VINEYARD_ASSERT(radix_tree->Query(tokens) == NULL); } void radix_tree_serialize_and_deserailize() { - std::shared_ptr radix_tree = std::make_shared(10); - - /* insert a token list*/ - std::vector tokens; - std::shared_ptr node_data; - for (int i = 0; i < 10; i++) { - tokens.push_back(i); - VINEYARD_ASSERT(radix_tree->Insert(tokens, node_data) != NULL); - VINEYARD_ASSERT(node_data == NULL); - } - - /* serialize radix tree */ - std::string serialized_radix_tree = radix_tree->Serialize(); - - /* deserialize radix tree */ - std::shared_ptr deserialized_radix_tree = radix_tree->Deserialize(serialized_radix_tree); - - /* query to check whether all token list exist */ - tokens.clear(); - for (int i = 0; i < 10; i++) { - tokens.push_back(i); - print_tokens(tokens); - VINEYARD_ASSERT(deserialized_radix_tree->Query(tokens) != NULL); - } + std::shared_ptr radix_tree = std::make_shared(10); + + /* insert a token list*/ + std::vector tokens; + std::shared_ptr node_data; + for (int i = 0; i < 10; i++) { + tokens.push_back(i); + VINEYARD_ASSERT(radix_tree->Insert(tokens, node_data) 
!= NULL); + VINEYARD_ASSERT(node_data == NULL); + } + + /* serialize radix tree */ + std::string serialized_radix_tree = radix_tree->Serialize(); + + /* deserialize radix tree */ + std::shared_ptr deserialized_radix_tree = + radix_tree->Deserialize(serialized_radix_tree); + + /* query to check whether all token list exist */ + tokens.clear(); + for (int i = 0; i < 10; i++) { + tokens.push_back(i); + print_tokens(tokens); + VINEYARD_ASSERT(deserialized_radix_tree->Query(tokens) != NULL); + } } void radix_tree_split() { - std::shared_ptr radix_tree = std::make_shared(20); - - raxShow(radix_tree->tree); - /* insert a token list*/ - std::vector tokens; - std::shared_ptr node_data; - for (int i = 0; i < 10; i++) { - tokens.push_back(i); - VINEYARD_ASSERT(radix_tree->Insert(tokens, node_data) != NULL); - VINEYARD_ASSERT(node_data == NULL); - } - - /* split a token list*/ - tokens.clear(); - for (int i = 0; i < 5; i++) { - tokens.push_back(i); - } - print_tokens(tokens); - std::shared_ptr subTreeHeader; - std::vector> node_data_list = radix_tree->Split(tokens, subTreeHeader); - VINEYARD_ASSERT(node_data_list.size() == 7); + std::shared_ptr radix_tree = std::make_shared(20); + + raxShow(radix_tree->tree); + /* insert a token list*/ + std::vector tokens; + std::shared_ptr node_data; + for (int i = 0; i < 10; i++) { + tokens.push_back(i); + VINEYARD_ASSERT(radix_tree->Insert(tokens, node_data) != NULL); + VINEYARD_ASSERT(node_data == NULL); + } + + /* split a token list*/ + tokens.clear(); + for (int i = 0; i < 5; i++) { + tokens.push_back(i); + } + print_tokens(tokens); + std::shared_ptr subTreeHeader; + std::vector> node_data_list = + radix_tree->Split(tokens, subTreeHeader); + VINEYARD_ASSERT(node_data_list.size() == 7); } int main() { - LOG(INFO) << "Start to test radix tree insert..."; - radix_tree_insert_test(); - LOG(INFO) << "Finish radix tree insert test!"; - LOG(INFO) << "Start to test radix tree delete..."; - radix_tree_delete_test(); - LOG(INFO) << "Finish radix 
tree delete test!"; - LOG(INFO) << "Start to test radix tree query..."; - radix_tree_query_test(); - LOG(INFO) << "Finish radix tree query test!"; - LOG(INFO) << "Start to test radix tree serialize and deserialize..."; - radix_tree_serialize_and_deserailize(); - LOG(INFO) << "Finish radix tree serialize and deserialize test!"; - LOG(INFO) << "Start to test radix tree split..."; - radix_tree_split(); - LOG(INFO) << "Finish radix tree split test!"; + LOG(INFO) << "Start to test radix tree insert..."; + radix_tree_insert_test(); + LOG(INFO) << "Finish radix tree insert test!"; + LOG(INFO) << "Start to test radix tree delete..."; + radix_tree_delete_test(); + LOG(INFO) << "Finish radix tree delete test!"; + LOG(INFO) << "Start to test radix tree query..."; + radix_tree_query_test(); + LOG(INFO) << "Finish radix tree query test!"; + LOG(INFO) << "Start to test radix tree serialize and deserialize..."; + radix_tree_serialize_and_deserailize(); + LOG(INFO) << "Finish radix tree serialize and deserialize test!"; + LOG(INFO) << "Start to test radix tree split..."; + radix_tree_split(); + LOG(INFO) << "Finish radix tree split test!"; } \ No newline at end of file diff --git a/test/kv_state_cache_test.cc b/test/kv_state_cache_test.cc index a1f4ed21..81332148 100644 --- a/test/kv_state_cache_test.cc +++ b/test/kv_state_cache_test.cc @@ -26,8 +26,9 @@ using namespace vineyard; #define DEMENSION 10 #define CAPACITY 20 +#define LAYER 3 -void init() { InitKVStateCache(DEMENSION, CAPACITY); } +void init() { InitKVStateCache(DEMENSION, CAPACITY, LAYER); } void print_current_tokens(const std::vector& prefix, int next_token) { std::string tokens_str = ""; @@ -49,23 +50,30 @@ void print_kv_state( key_state_str += std::to_string(iter->second.first[i]) + " "; value_state_str += std::to_string(iter->second.second[i]) + " "; } + LOG(INFO) << "layer " << iter->first << ":"; LOG(INFO) << "key_state: " << key_state_str; LOG(INFO) << "value_state: " << value_state_str; + LOG(INFO) << 
"---------------------"; } } // we do not consider the layer. std::map, std::vector>> generate_kv_state(int token) { - std::vector key_state; - std::vector value_state; - for (int i = 0; i < DEMENSION; ++i) { - key_state.push_back(((double) token) / DEMENSION * (i + 1)); - value_state.push_back(((double) token) / DEMENSION * (i + 1) * 2); - } - std::map, std::vector>> kv_state; - kv_state.insert(std::make_pair(1, std::make_pair(key_state, value_state))); + for (int currentLayer = 0; currentLayer < LAYER; currentLayer++) { + std::vector key_state; + std::vector value_state; + for (int i = 0; i < DEMENSION; ++i) { + key_state.push_back(((double) token) / DEMENSION * (i + 1) + + currentLayer * 10); + value_state.push_back(((double) token) / DEMENSION * (i + 1) * 2 + + currentLayer * 10); + } + + kv_state.insert( + std::make_pair(currentLayer, std::make_pair(key_state, value_state))); + } return kv_state; } diff --git a/test/kv_state_cache_test_2.cc b/test/kv_state_cache_test_2.cc index 28f8bdda..4bc29157 100644 --- a/test/kv_state_cache_test_2.cc +++ b/test/kv_state_cache_test_2.cc @@ -26,8 +26,9 @@ using namespace vineyard; #define DEMENSION 10 #define CAPACITY 20 +#define LAYER 3 -void init() { InitKVStateCache(DEMENSION, CAPACITY); } +void init() { InitKVStateCache(DEMENSION, CAPACITY, LAYER); } void print_current_tokens(const std::vector& prefix, int next_token) { std::string tokens_str = ""; @@ -49,23 +50,30 @@ void print_kv_state( key_state_str += std::to_string(iter->second.first[i]) + " "; value_state_str += std::to_string(iter->second.second[i]) + " "; } + LOG(INFO) << "layer " << iter->first << ":"; LOG(INFO) << "key_state: " << key_state_str; LOG(INFO) << "value_state: " << value_state_str; + LOG(INFO) << "---------------------"; } } // we do not consider the layer. 
std::map, std::vector>> generate_kv_state(int token) { - std::vector key_state; - std::vector value_state; - for (int i = 0; i < DEMENSION; ++i) { - key_state.push_back(((double) token) / DEMENSION * (i + 1)); - value_state.push_back(((double) token) / DEMENSION * (i + 1) * 2); - } - std::map, std::vector>> kv_state; - kv_state.insert(std::make_pair(1, std::make_pair(key_state, value_state))); + for (int currentLayer = 0; currentLayer < LAYER; currentLayer++) { + std::vector key_state; + std::vector value_state; + for (int i = 0; i < DEMENSION; ++i) { + key_state.push_back(((double) token) / DEMENSION * (i + 1) + + currentLayer * 10); + value_state.push_back(((double) token) / DEMENSION * (i + 1) * 2 + + currentLayer * 10); + } + + kv_state.insert( + std::make_pair(currentLayer, std::make_pair(key_state, value_state))); + } return kv_state; } From 26cef28d6013afdb26bdcae6aed3c15b46070594 Mon Sep 17 00:00:00 2001 From: vegetableysm <108774481+vegetableysm@users.noreply.github.com> Date: Mon, 26 Feb 2024 14:36:59 +0800 Subject: [PATCH 11/20] Support to store more than 64 entries for kv cache block (#1769) Signed-off-by: vegetableysm --- modules/kv-state-cache/ds/kv_state_cache.cc | 14 +-- modules/kv-state-cache/ds/kv_state_cache.h | 8 +- .../kv-state-cache/ds/kv_state_cache_block.cc | 87 ++++++++++++++----- .../kv-state-cache/ds/kv_state_cache_block.h | 31 +++++-- .../utils/kv_state_cache_utils.cc | 6 +- .../utils/kv_state_cache_utils.h | 4 +- test/kv_state_cache_test.cc | 21 +++-- test/kv_state_cache_test_2.cc | 12 +-- 8 files changed, 125 insertions(+), 58 deletions(-) diff --git a/modules/kv-state-cache/ds/kv_state_cache.cc b/modules/kv-state-cache/ds/kv_state_cache.cc index bccf17cc..dbbfd4c3 100644 --- a/modules/kv-state-cache/ds/kv_state_cache.cc +++ b/modules/kv-state-cache/ds/kv_state_cache.cc @@ -69,12 +69,13 @@ KVStateCache::~KVStateCache() { } KVStateCacheBuilder::KVStateCacheBuilder(Client& client, int dimension, - int cacheCapacity, int layer) { + int 
cacheCapacity, int layer, + int blockSize) { this->dimension = dimension; this->version = 0; this->layer = layer; KVStateCacheBlockBuilder* builder = - new KVStateCacheBlockBuilder(client, this->dimension, layer); + new KVStateCacheBlockBuilder(client, this->dimension, layer, blockSize); this->rootTree = std::make_shared(cacheCapacity); @@ -95,7 +96,7 @@ KVStateCacheBuilder::KVStateCacheBuilder(Client& client, int dimension, KVStateCacheBuilder::KVStateCacheBuilder(Client& client, std::shared_ptr cache) { // TBD - this->dimension = cache->GetDemension(); + this->dimension = cache->GetDimension(); this->version = cache->GetVersion(); this->layer = cache->GetLayer(); // 1. create block builder from block @@ -126,7 +127,8 @@ KVStateCacheBlockBuilder* KVStateCacheBuilder::Split( // Split the tree if the list of kvState is full. VINEYARD_ASSERT(nodeDataList.size() > 0); KVStateCacheBlockBuilder* childKVStateCacheBlockBuilder = - new KVStateCacheBlockBuilder(client, this->dimension, this->layer); + new KVStateCacheBlockBuilder(client, this->dimension, this->layer, + kvStateCacheBlockBuilder->GetBlockSize()); for (size_t i = 0; i < nodeDataList.size(); i++) { OffsetData* data = (OffsetData*) nodeDataList[i]->nodeData->data; if (data == nullptr) @@ -183,11 +185,11 @@ void KVStateCacheBuilder::Update(Client& client, LOG(INFO) << "kvStateCacheBlockBuilder:" << kvStateCacheBlockBuilder; if (kvStateCacheBlockBuilder->IsFull()) { /** - * If the kv-state cache of the tree is full, triggle split. Delete the + * If the kv-state cache of the tree is full, trigger split. Delete the * empty node from the radix tree and split the tree. Then, kv-state cache * split according to the new tree. 
*/ - LOG(INFO) << "triggle splits"; + LOG(INFO) << "trigger splits"; std::shared_ptr evictedNodeData = nullptr; this->rootTree->Delete(tokenListCopy, evictedNodeData); diff --git a/modules/kv-state-cache/ds/kv_state_cache.h b/modules/kv-state-cache/ds/kv_state_cache.h index fbdffb7b..f378aef3 100644 --- a/modules/kv-state-cache/ds/kv_state_cache.h +++ b/modules/kv-state-cache/ds/kv_state_cache.h @@ -61,7 +61,7 @@ class KVStateCache : public vineyard::Registered { return this->kvStateCacheBlockList; } - int GetDemension() { return this->dimension; } + int GetDimension() { return this->dimension; } int GetCacheCapacity() { return this->cacheCapacity; } @@ -79,12 +79,12 @@ class KVStateCache : public vineyard::Registered { class KVStateCacheBuilder : public vineyard::ObjectBuilder { std::shared_ptr rootTree; int dimension; - int layer = 1; + int layer; uint64_t version; public: KVStateCacheBuilder(Client& client, int dimension, int cacheCapacity, - int layer); + int layer, int blockSize = DEFAULT_BLOCK_SIZE); KVStateCacheBuilder(Client& client, std::shared_ptr cache); @@ -110,7 +110,7 @@ class KVStateCacheBuilder : public vineyard::ObjectBuilder { std::shared_ptr _Seal(Client& client) override; - uint64_t GetDemension() { return this->dimension; } + uint64_t GetDimension() { return this->dimension; } std::shared_ptr GetRootTree() { return this->rootTree; } diff --git a/modules/kv-state-cache/ds/kv_state_cache_block.cc b/modules/kv-state-cache/ds/kv_state_cache_block.cc index 097fbc8d..397e3b04 100644 --- a/modules/kv-state-cache/ds/kv_state_cache_block.cc +++ b/modules/kv-state-cache/ds/kv_state_cache_block.cc @@ -23,8 +23,10 @@ namespace vineyard { std::string KVStateCacheBlock::GetBitmapStr() { std::string result; const int bits = 8 * sizeof(unsigned long long); - for (int i = bits - 1; i >= 0; --i) { - result += ((this->bitmap >> i) & 1) ? 
'1' : '0'; + for (int i = 0; i < this->bitmapSize; i++) { + for (int j = bits - 1; j >= 0; --j) { + result += (((this->bitmap[i]) >> j) & 1) ? '1' : '0'; + } } return result; } @@ -32,8 +34,10 @@ std::string KVStateCacheBlock::GetBitmapStr() { std::string KVStateCacheBlockBuilder::GetBitmapStr() { std::string result; const int bits = 8 * sizeof(unsigned long long); - for (int i = bits - 1; i >= 0; --i) { - result += ((this->bitmap >> i) & 1) ? '1' : '0'; + for (int i = 0; i < this->bitmapSize; i++) { + for (int j = bits - 1; j >= 0; --j) { + result += (((this->bitmap[i]) >> j) & 1) ? '1' : '0'; + } } return result; } @@ -59,14 +63,27 @@ void KVStateCacheBlock::Construct(const ObjectMeta& meta) { "valueStateTensorBuilder_" + std::to_string(currentLayer)))); } // 2. construct the member field - this->bitmap = this->meta_.GetKeyValue("bitmap"); + this->bitmapSize = this->meta_.GetKeyValue("bitmap_size"); + LOG(INFO) << "construct bitmap size:" << this->bitmapSize; + this->bitmap = (uint64_t*) malloc(this->bitmapSize * sizeof(uint64_t)); + for (int i = 0; i < this->bitmapSize; i++) { + this->bitmap[i] = + this->meta_.GetKeyValue("bitmap_" + std::to_string(i)); + } this->dimension = this->meta_.GetKeyValue("dimension"); + this->blockSize = this->meta_.GetKeyValue("block_size"); } +KVStateCacheBlock::~KVStateCacheBlock() { free(this->bitmap); } + KVStateCacheBlockBuilder::KVStateCacheBlockBuilder(Client& client, - int dimension, int layer) { - this->bitmap = UINT64_MAX; - std::vector shape = {LIST_SIZE, dimension}; + int dimension, int layer, + int blockSize) { + this->blockSize = blockSize; + this->bitmapSize = (blockSize + 63) / 64; + this->bitmap = (uint64_t*) malloc(this->bitmapSize * sizeof(uint64_t)); + memset((void*) this->bitmap, UINT8_MAX, this->bitmapSize * sizeof(uint64_t)); + std::vector shape = {(int64_t)(blockSize), dimension}; for (int i = 0; i < layer; i++) { this->keyStateTensorBuilderList.push_back( std::make_shared>(client, shape)); @@ -79,10 +96,17 @@ 
KVStateCacheBlockBuilder::KVStateCacheBlockBuilder(Client& client, KVStateCacheBlockBuilder::KVStateCacheBlockBuilder( Client& client, std::shared_ptr kvStateCacheBlock) { - this->bitmap = kvStateCacheBlock->bitmap; + this->bitmapSize = kvStateCacheBlock->bitmapSize; + this->blockSize = kvStateCacheBlock->blockSize; + LOG(INFO) << "create builder from block object, bitmap size:" + << this->bitmapSize << " block size:" << blockSize; + this->bitmap = (uint64_t*) malloc(this->bitmapSize * sizeof(uint64_t)); + for (int i = 0; i < this->bitmapSize; i++) { + this->bitmap[i] = kvStateCacheBlock->bitmap[i]; + } this->dimension = kvStateCacheBlock->dimension; this->layer = kvStateCacheBlock->layer; - std::vector shape = {LIST_SIZE, dimension}; + std::vector shape = {(int64_t)(blockSize), dimension}; for (int currentLayer = 0; currentLayer < this->layer; currentLayer++) { this->keyStateTensorBuilderList.push_back( std::make_shared>(client, shape)); @@ -93,10 +117,10 @@ KVStateCacheBlockBuilder::KVStateCacheBlockBuilder( for (int currentLayer = 0; currentLayer < this->layer; currentLayer++) { memcpy(this->keyStateTensorBuilderList[currentLayer]->data(), kvStateCacheBlock->keyStateTensorList[currentLayer]->data(), - LIST_SIZE * this->dimension * sizeof(double)); + (int64_t)(blockSize) * this->dimension * sizeof(double)); memcpy(this->valueStateTensorBuilderList[currentLayer]->data(), kvStateCacheBlock->valueStateTensorList[currentLayer]->data(), - LIST_SIZE * this->dimension * sizeof(double)); + (int64_t)(blockSize) * this->dimension * sizeof(double)); } } @@ -126,14 +150,24 @@ Status KVStateCacheBlockBuilder::Query(Client& client, int index, } int KVStateCacheBlockBuilder::FindEmptySlot() { - int index = ffsll(this->bitmap) - 1; - VINEYARD_ASSERT(index >= 0 && index < LIST_SIZE); - return index; + for (int i = 0; i < this->bitmapSize; i++) { + if (this->bitmap[i] != 0) { + int index = ffsll(this->bitmap[i]) - 1; + return index + i * 64; + } + } + return -1; } bool 
KVStateCacheBlockBuilder::IsFull() { - int index = ffsll(this->bitmap) - 1; - return index < 0 || index >= LIST_SIZE; + int left = this->blockSize; + for (int i = 0; i < this->bitmapSize; i++) { + if (this->bitmap[i] != 0 && ffsll(this->bitmap[i]) - 1 < left) { + return false; + } + left -= sizeof(unsigned long long) * 8; + } + return true; } void KVStateCacheBlockBuilder::Update(const KV_STATE_WITH_LAYER& kvState, @@ -147,7 +181,7 @@ void KVStateCacheBlockBuilder::Update(const KV_STATE_WITH_LAYER& kvState, std::vector valueStateVector = (kvState.find(currentLayer)->second).second; LOG(INFO) << "vector size:" << keyStateVector.size() << " " - << valueStateVector.size() << " demension" << this->dimension; + << valueStateVector.size() << " dimension" << this->dimension; VINEYARD_ASSERT(keyStateVector.size() == (size_t) this->dimension); VINEYARD_ASSERT(valueStateVector.size() == (size_t) this->dimension); @@ -161,7 +195,7 @@ void KVStateCacheBlockBuilder::Update(const KV_STATE_WITH_LAYER& kvState, } data->offset = index; - ACQUIRE_BIT_RESOURCE(this->bitmap, index); + ACQUIRE_BIT_RESOURCE(this->bitmap[index / 64], index % 64); } short KVStateCacheBlockBuilder::Split(KVStateCacheBlockBuilder* child, @@ -191,8 +225,8 @@ short KVStateCacheBlockBuilder::Split(KVStateCacheBlockBuilder* child, memcpy(childKeyState, keyState, this->dimension * sizeof(double)); memcpy(childValueState, valueState, this->dimension * sizeof(double)); } - ACQUIRE_BIT_RESOURCE(child->bitmap, childIndex); - FREE_BIT_RESOURCE(this->bitmap, index); + ACQUIRE_BIT_RESOURCE(child->bitmap[childIndex / 64], childIndex % 64); + FREE_BIT_RESOURCE(this->bitmap[index / 64], index % 64); return childIndex; } @@ -216,7 +250,14 @@ std::shared_ptr KVStateCacheBlockBuilder::_Seal(Client& client) { } // 2. 
store the member field to meta - kvStateCacheBlock->meta_.AddKeyValue("bitmap", this->bitmap); + kvStateCacheBlock->meta_.AddKeyValue("bitmap_size", this->bitmapSize); + for (int i = 0; i < this->bitmapSize; i++) { + kvStateCacheBlock->meta_.AddKeyValue("bitmap_" + std::to_string(i), + this->bitmap[i]); + } + LOG(INFO) << "seal bitmap:" << this->GetBitmapStr(); + + kvStateCacheBlock->meta_.AddKeyValue("block_size", this->blockSize); kvStateCacheBlock->meta_.AddKeyValue("dimension", this->dimension); kvStateCacheBlock->meta_.AddKeyValue("layer", this->layer); // 3. set the object type to meta @@ -227,4 +268,6 @@ std::shared_ptr KVStateCacheBlockBuilder::_Seal(Client& client) { return kvStateCacheBlock; } +KVStateCacheBlockBuilder::~KVStateCacheBlockBuilder() { free(this->bitmap); } + } // namespace vineyard diff --git a/modules/kv-state-cache/ds/kv_state_cache_block.h b/modules/kv-state-cache/ds/kv_state_cache_block.h index 8602dd40..3bb20d3a 100644 --- a/modules/kv-state-cache/ds/kv_state_cache_block.h +++ b/modules/kv-state-cache/ds/kv_state_cache_block.h @@ -48,12 +48,12 @@ struct OffsetData { }; namespace vineyard { -#define LIST_SIZE 5 +#define DEFAULT_BLOCK_SIZE 64 /** * @brief KVStateCacheBlock is a cache for kv-cache of LLM. When a new prompt * comes, LLM can query KVStateCacheBlock to get the state of the kv-cache to - * avoid caclulating the kv-cache again if the new prompt is similar to the + * avoid calculate the kv-cache again if the new prompt is similar to the * previous one. 
* * KVStateCacheBlock is stored in vineyard as a vineyard object which contains a @@ -67,7 +67,9 @@ class KVStateCacheBlock : public vineyard::Registered { private: std::vector>> keyStateTensorList; std::vector>> valueStateTensorList; - uint64_t bitmap; + uint64_t* bitmap; + int blockSize; + int bitmapSize; ObjectID id; int layer; int dimension; @@ -84,7 +86,9 @@ class KVStateCacheBlock : public vineyard::Registered { uint64_t GetDimension() { return this->dimension; } - uint64_t GetBitmap() { return this->bitmap; } + uint64_t* GetBitmap() { return this->bitmap; } + + int GetBlockSize() { return this->blockSize; } std::shared_ptr> GetKeyTensor(int layer) { return this->keyStateTensorList[layer]; @@ -102,6 +106,8 @@ class KVStateCacheBlock : public vineyard::Registered { return this->valueStateTensorList; } + ~KVStateCacheBlock(); + friend class KVStateCacheBlockBuilder; }; @@ -112,14 +118,17 @@ class KVStateCacheBlockBuilder : public ObjectBuilder { valueStateTensorBuilderList; // TBD // support more than 64 kv-state cache slots - uint64_t bitmap; + uint64_t* bitmap; + int blockSize; + int bitmapSize; int dimension; int layer; int FindEmptySlot(); public: - KVStateCacheBlockBuilder(Client& client, int dimension, int layer); + KVStateCacheBlockBuilder(Client& client, int dimension, int layer, + int blockSize); KVStateCacheBlockBuilder( Client& client, std::shared_ptr kv_state_cache_block); @@ -172,13 +181,19 @@ class KVStateCacheBlockBuilder : public ObjectBuilder { return valueStateTensorBuilderList; } - void DeleteKVCache(int bit) { FREE_BIT_RESOURCE(this->bitmap, bit); } + void DeleteKVCache(int bit) { + FREE_BIT_RESOURCE(this->bitmap[bit / 64], bit % 64); + } std::string GetBitmapStr(); - uint64_t GetBitmap() { return this->bitmap; } + uint64_t* GetBitmap() { return this->bitmap; } uint64_t GetDimension() { return this->dimension; } + + int GetBlockSize() { return this->blockSize; } + + ~KVStateCacheBlockBuilder(); }; } // namespace vineyard diff --git 
a/modules/kv-state-cache/utils/kv_state_cache_utils.cc b/modules/kv-state-cache/utils/kv_state_cache_utils.cc index 06b6d96e..ffaa9ade 100644 --- a/modules/kv-state-cache/utils/kv_state_cache_utils.cc +++ b/modules/kv-state-cache/utils/kv_state_cache_utils.cc @@ -56,7 +56,8 @@ void signalHandler(int signum) { exit(signum); } -void InitKVStateCache(int dimension, int cacheCapacity, int layer) { +void InitKVStateCache(int dimension, int cacheCapacity, int layer, + int blockSize) { if (kvStateCacheBuilder == nullptr) { std::string socket = std::string(getenv("VINEYARD_IPC_SOCKET")); LOG(INFO) << "socket:" << socket; @@ -87,14 +88,13 @@ void InitKVStateCache(int dimension, int cacheCapacity, int layer) { std::shared_ptr globalKVStateCache = std::dynamic_pointer_cast( client.GetObject(globalKVStateCacheID)); - // TBD cache stragety kvStateCacheBuilder = std::make_shared(client, globalKVStateCache); } else { // if failed, create a new cache object LOG(INFO) << "failed to get the cache object, create a new one"; kvStateCacheBuilder = std::make_shared( - client, dimension, cacheCapacity, layer); + client, dimension, cacheCapacity, layer, blockSize); } // // release the lock diff --git a/modules/kv-state-cache/utils/kv_state_cache_utils.h b/modules/kv-state-cache/utils/kv_state_cache_utils.h index 4110ae96..4df60be8 100644 --- a/modules/kv-state-cache/utils/kv_state_cache_utils.h +++ b/modules/kv-state-cache/utils/kv_state_cache_utils.h @@ -18,8 +18,8 @@ limitations under the License. 
#ifndef MODULES_KV_STATE_CACHE_UTILS_H_ #define MODULES_KV_STATE_CACHE_UTILS_H_ -void InitKVStateCache(int dimension = 10, int cacheCapacity = 10, - int layer = 1); +void InitKVStateCache(int dimension = 10, int cacheCapacity = 10, int layer = 1, + int blockSize = 5); void Update(const std::vector& tokenList, int nextToken, const KV_STATE_WITH_LAYER& kvState); diff --git a/test/kv_state_cache_test.cc b/test/kv_state_cache_test.cc index 81332148..fa59f48b 100644 --- a/test/kv_state_cache_test.cc +++ b/test/kv_state_cache_test.cc @@ -24,11 +24,12 @@ limitations under the License. using namespace vineyard; -#define DEMENSION 10 +#define DIMENSION 10 #define CAPACITY 20 #define LAYER 3 +#define BLOCK_SIZE 5 -void init() { InitKVStateCache(DEMENSION, CAPACITY, LAYER); } +void init() { InitKVStateCache(DIMENSION, CAPACITY, LAYER, BLOCK_SIZE); } void print_current_tokens(const std::vector& prefix, int next_token) { std::string tokens_str = ""; @@ -46,7 +47,7 @@ void print_kv_state( for (auto iter = kv_state.begin(); iter != kv_state.end(); ++iter) { std::string key_state_str = ""; std::string value_state_str = ""; - for (int i = 0; i < DEMENSION; ++i) { + for (int i = 0; i < DIMENSION; ++i) { key_state_str += std::to_string(iter->second.first[i]) + " "; value_state_str += std::to_string(iter->second.second[i]) + " "; } @@ -64,10 +65,10 @@ generate_kv_state(int token) { for (int currentLayer = 0; currentLayer < LAYER; currentLayer++) { std::vector key_state; std::vector value_state; - for (int i = 0; i < DEMENSION; ++i) { - key_state.push_back(((double) token) / DEMENSION * (i + 1) + + for (int i = 0; i < DIMENSION; ++i) { + key_state.push_back(((double) token) / DIMENSION * (i + 1) + currentLayer * 10); - value_state.push_back(((double) token) / DEMENSION * (i + 1) * 2 + + value_state.push_back(((double) token) / DIMENSION * (i + 1) * 2 + currentLayer * 10); } @@ -106,7 +107,11 @@ void inference(std::vector tokens, bool block = false) { int main() { init(); - std::vector 
round_1_tokens = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}; + std::vector round_1_tokens = { + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, + 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, + 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69}; std::vector round_2_tokens = {1, 2, 3, 4, 5, 7, 8, 9, 10, 11, 12, 13, 14}; std::vector round_3_tokens = {1, 2, 3, 9, 10, 11, 12, 13, 14}; @@ -119,6 +124,8 @@ int main() { inference(round_1_tokens); inference(round_2_tokens); inference(round_3_tokens); + sleep(5); + inference(round_3_tokens); // inference(round_4_tokens); // sleep(5); // Delete(std::vector(round_4_tokens.begin(), round_4_tokens.begin() + diff --git a/test/kv_state_cache_test_2.cc b/test/kv_state_cache_test_2.cc index 4bc29157..db7ab2ed 100644 --- a/test/kv_state_cache_test_2.cc +++ b/test/kv_state_cache_test_2.cc @@ -24,11 +24,11 @@ limitations under the License. using namespace vineyard; -#define DEMENSION 10 +#define DIMENSION 10 #define CAPACITY 20 #define LAYER 3 -void init() { InitKVStateCache(DEMENSION, CAPACITY, LAYER); } +void init() { InitKVStateCache(DIMENSION, CAPACITY, LAYER); } void print_current_tokens(const std::vector& prefix, int next_token) { std::string tokens_str = ""; @@ -46,7 +46,7 @@ void print_kv_state( for (auto iter = kv_state.begin(); iter != kv_state.end(); ++iter) { std::string key_state_str = ""; std::string value_state_str = ""; - for (int i = 0; i < DEMENSION; ++i) { + for (int i = 0; i < DIMENSION; ++i) { key_state_str += std::to_string(iter->second.first[i]) + " "; value_state_str += std::to_string(iter->second.second[i]) + " "; } @@ -64,10 +64,10 @@ generate_kv_state(int token) { for (int currentLayer = 0; currentLayer < LAYER; currentLayer++) { std::vector key_state; std::vector value_state; - for (int i = 0; i < DEMENSION; ++i) { - key_state.push_back(((double) token) / DEMENSION * (i + 1) + + for (int i = 0; i 
< DIMENSION; ++i) { + key_state.push_back(((double) token) / DIMENSION * (i + 1) + currentLayer * 10); - value_state.push_back(((double) token) / DEMENSION * (i + 1) * 2 + + value_state.push_back(((double) token) / DIMENSION * (i + 1) * 2 + currentLayer * 10); } From 6ecf15c6b18d507b673a308b7fe9d72d1a583196 Mon Sep 17 00:00:00 2001 From: vegetableysm <108774481+vegetableysm@users.noreply.github.com> Date: Tue, 27 Feb 2024 19:57:03 +0800 Subject: [PATCH 12/20] Add ci test for kv state cache and fix bug (#1772) Fixes #1773 Signed-off-by: vegetableysm --- .github/workflows/build-test.yml | 10 + CMakeLists.txt | 2 +- modules/kv-state-cache/ds/kv_state_cache.cc | 111 ++++------- modules/kv-state-cache/ds/kv_state_cache.h | 11 +- .../kv-state-cache/ds/kv_state_cache_block.cc | 94 +++++---- .../kv-state-cache/ds/kv_state_cache_block.h | 17 +- .../kv-state-cache/radix-tree/radix-tree.cc | 10 +- .../kv-state-cache/radix-tree/radix-tree.h | 2 + modules/kv-state-cache/radix-tree/radix.cc | 97 ++++++---- .../kv-state-cache/strategy/LRU_strategy.cc | 124 ------------ .../kv-state-cache/strategy/LRU_strategy.h | 67 ------- .../kv-state-cache/strategy/cache_strategy.h | 31 --- .../utils/kv_state_cache_utils.cc | 70 +++---- .../utils/kv_state_cache_utils.h | 21 +- src/common/util/protocols.cc | 5 +- src/server/server/vineyard_server.cc | 6 +- src/server/services/etcd_meta_service.cc | 2 +- test/distributed_lock_test.cc | 66 ------- test/kv_state_cache_multi_test.cc | 94 +++++++++ test/kv_state_cache_object_test.cc | 180 ------------------ test/kv_state_cache_radix_tree_test.cc | 4 +- test/kv_state_cache_test.cc | 168 +++++++++++----- test/kv_state_cache_test_2.cc | 137 ------------- test/rax_diff_test.cc | 101 ---------- test/runner.py | 45 +++++ 25 files changed, 514 insertions(+), 961 deletions(-) delete mode 100644 modules/kv-state-cache/strategy/LRU_strategy.cc delete mode 100644 modules/kv-state-cache/strategy/LRU_strategy.h delete mode 100644 
modules/kv-state-cache/strategy/cache_strategy.h delete mode 100644 test/distributed_lock_test.cc create mode 100644 test/kv_state_cache_multi_test.cc delete mode 100644 test/kv_state_cache_object_test.cc delete mode 100644 test/kv_state_cache_test_2.cc delete mode 100644 test/rax_diff_test.cc diff --git a/.github/workflows/build-test.yml b/.github/workflows/build-test.yml index 4c1427b4..29ab598e 100644 --- a/.github/workflows/build-test.yml +++ b/.github/workflows/build-test.yml @@ -329,6 +329,16 @@ jobs: if: false uses: mxschmitt/action-tmate@v3 + - name: Run llm tests + run: | + export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib:/usr/local/lib64:/usr/local/lib/x86_64-linux-gnu + export VINEYARD_DATA_DIR=`pwd`/gstest + export TMPDIR="${TMPDIR:-$(dirname $(mktemp))}" + + rm -rf default.etcd + rm -rf /dev/shm/etcd* + python3 test/runner.py $RUNNER_ARGS --with-llm + - name: Run cpp tests run: | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib:/usr/local/lib64:/usr/local/lib/x86_64-linux-gnu diff --git a/CMakeLists.txt b/CMakeLists.txt index 269d65a0..951376cd 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -981,7 +981,7 @@ endfunction() file_glob_recurse(FILES_NEED_FORMAT DIRECTORIES "src" "modules" "python" "test" "benchmark" PATTERNS ".*\\.(cc|cpp|h|hpp|vineyard-mod)$" - EXCLUDE_PATTERNS ".*\\.vineyard.h$" + EXCLUDE_PATTERNS "(.*\\.vineyard.h$)|(modules/kv-state-cache/radix-tree/ra.*)" ) # the `memcpy.h` is borrowed from external project diff --git a/modules/kv-state-cache/ds/kv_state_cache.cc b/modules/kv-state-cache/ds/kv_state_cache.cc index dbbfd4c3..98792c03 100644 --- a/modules/kv-state-cache/ds/kv_state_cache.cc +++ b/modules/kv-state-cache/ds/kv_state_cache.cc @@ -14,14 +14,17 @@ limitations under the License. 
*/ #include +#include +#include +#include #include "client/client.h" #include "common/util/base64.h" #include "common/util/logging.h" #include "common/util/status.h" +#include "kv-state-cache/ds/kv_state_cache.h" #include "kv-state-cache/radix-tree/radix-tree.h" #include "kv-state-cache/radix-tree/radix.h" -#include "kv_state_cache.h" namespace vineyard { @@ -31,7 +34,6 @@ void KVStateCache::Construct(const ObjectMeta& meta) { } void KVStateCache::Resolve() { - LOG(INFO) << "Resolve"; std::string typeName = type_name(); VINEYARD_ASSERT(this->meta_.GetTypeName() == typeName, @@ -41,27 +43,24 @@ void KVStateCache::Resolve() { // 1. construct the radix tree this->rootTree = RadixTree::Deserialize( base64_decode(this->meta_.GetKeyValue("radix_tree"))); - LOG(INFO) << "Resolve RadixTree success" << std::endl; raxShow(this->rootTree->GetRootTree()); // 2. construct the kvStateCacheBlockBuilder list size_t numBlocks = this->meta_.GetKeyValue("numBlocks"); - LOG(INFO) << "num blocks:" << numBlocks; for (size_t i = 0; i < numBlocks; i++) { std::shared_ptr kvStateCacheBlockObject = this->meta_.GetMember( "kv_state_cache_block_builder_" + std::to_string(i)); this->kvStateCacheBlockList.push_back( std::dynamic_pointer_cast(kvStateCacheBlockObject)); - this->kvStateCacheBlockMap[kvStateCacheBlockObject->id()] = - std::dynamic_pointer_cast(kvStateCacheBlockObject); - LOG(INFO) << "kvStateCacheBlockObject:" << kvStateCacheBlockObject->id(); } // 3. 
construct the member field this->dimension = this->meta_.GetKeyValue("dimension"); this->version = this->meta_.GetKeyValue("version"); this->layer = this->meta_.GetKeyValue("layer"); - LOG(INFO) << "construct the member field success" << std::endl; + VLOG(100) << "construct the member field success, with dimension:" + << this->dimension << " version:" << this->version + << " layer:" << this->layer; } KVStateCache::~KVStateCache() { @@ -87,31 +86,24 @@ KVStateCacheBuilder::KVStateCacheBuilder(Client& client, int dimension, rootTreeHeader->treeData->data = treeData; rootTreeHeader->treeData->dataLength = sizeof(TreeData); this->rootTree->SetSubtreeData(treeData); - LOG(INFO) << "set builder:" << builder - << " to tree:" << this->rootTree->GetRootTree()->head; - LOG(INFO) << "data:" << treeData - << " custom data:" << rootTreeHeader->treeData; } KVStateCacheBuilder::KVStateCacheBuilder(Client& client, std::shared_ptr cache) { - // TBD this->dimension = cache->GetDimension(); this->version = cache->GetVersion(); this->layer = cache->GetLayer(); // 1. 
create block builder from block - std::map> kvStateCacheBlockMap = - cache->kvStateCacheBlockMap; + std::vector> kvStateCacheBlockList = + cache->GetKVStateCacheBlockList(); this->rootTree = cache->GetRootTree(); std::set subTreeData = cache->rootTree->GetSubTreeDataSet(); for (auto iter = subTreeData.begin(); iter != subTreeData.end(); ++iter) { - TreeData* treeData = (TreeData*) (*iter); - LOG(INFO) << "tree data:" << treeData; + TreeData* treeData = reinterpret_cast(*iter); VINEYARD_ASSERT(treeData->isPtr == false); - LOG(INFO) << "id:" << treeData->builderObjectID; std::shared_ptr kvStateCacheBlock = - kvStateCacheBlockMap[treeData->builderObjectID]; + kvStateCacheBlockList[treeData->builderObjectID]; KVStateCacheBlockBuilder* kvStateCacheBlockBuilder = new KVStateCacheBlockBuilder(client, kvStateCacheBlock); @@ -123,14 +115,14 @@ KVStateCacheBuilder::KVStateCacheBuilder(Client& client, KVStateCacheBlockBuilder* KVStateCacheBuilder::Split( Client& client, KVStateCacheBlockBuilder* kvStateCacheBlockBuilder, std::vector> nodeDataList) { - LOG(INFO) << "split"; // Split the tree if the list of kvState is full. 
VINEYARD_ASSERT(nodeDataList.size() > 0); KVStateCacheBlockBuilder* childKVStateCacheBlockBuilder = new KVStateCacheBlockBuilder(client, this->dimension, this->layer, kvStateCacheBlockBuilder->GetBlockSize()); for (size_t i = 0; i < nodeDataList.size(); i++) { - OffsetData* data = (OffsetData*) nodeDataList[i]->nodeData->data; + OffsetData* data = + reinterpret_cast(nodeDataList[i]->nodeData->data); if (data == nullptr) continue; int index = data->offset; @@ -139,9 +131,9 @@ KVStateCacheBlockBuilder* KVStateCacheBuilder::Split( data->offset = kvStateCacheBlockBuilder->Split(childKVStateCacheBlockBuilder, index); } - LOG(INFO) << "builder:" << kvStateCacheBlockBuilder + VLOG(100) << "builder:" << kvStateCacheBlockBuilder << " bitmap:" << kvStateCacheBlockBuilder->GetBitmapStr(); - LOG(INFO) << "child_builder:" << childKVStateCacheBlockBuilder + VLOG(100) << "child_builder:" << childKVStateCacheBlockBuilder << " bitmap:" << childKVStateCacheBlockBuilder->GetBitmapStr(); return childKVStateCacheBlockBuilder; } @@ -150,7 +142,6 @@ void KVStateCacheBuilder::Update(Client& client, const std::vector& tokenList, int nextToken, const KV_STATE_WITH_LAYER& kvState) { - LOG(INFO) << "update"; std::vector tokenListCopy = tokenList; tokenListCopy.push_back(nextToken); @@ -162,34 +153,21 @@ void KVStateCacheBuilder::Update(Client& client, LOG(INFO) << "insert failed"; return; } - LOG(INFO) << "insert end"; KVStateCacheBlockBuilder* kvStateCacheBlockBuilder = - (KVStateCacheBlockBuilder*) ((TreeData*) nodeData->treeData->data) - ->kvStateCacheBlockBuilder; - LOG(INFO) << "try to delete"; + reinterpret_cast( + (reinterpret_cast(nodeData->treeData->data)) + ->kvStateCacheBlockBuilder); if (evictedNodeData != nullptr) { Delete(evictedNodeData); } - // if (evictedNodeData->treeData != nullptr && evictedNodeData->nodeData != - // nullptr) { - // if (evictedNodeData->nodeData->data != nullptr) { - // delete (TreeData*) evictedNodeData->nodeData->data; - // } - // } - - // TBD - // Use 
lock to protect the kv_state_cache_builder - LOG(INFO) << "data:" << nodeData->treeData->data - << " custom data:" << nodeData->treeData; - LOG(INFO) << "kvStateCacheBlockBuilder:" << kvStateCacheBlockBuilder; if (kvStateCacheBlockBuilder->IsFull()) { /** * If the kv-state cache of the tree is full, trigger split. Delete the * empty node from the radix tree and split the tree. Then, kv-state cache * split according to the new tree. */ - LOG(INFO) << "trigger splits"; + VLOG(100) << "trigger splits"; std::shared_ptr evictedNodeData = nullptr; this->rootTree->Delete(tokenListCopy, evictedNodeData); @@ -206,7 +184,7 @@ void KVStateCacheBuilder::Update(Client& client, subTreeHeader->treeData->data = newTreeData; subTreeHeader->treeData->dataLength = sizeof(TreeData); rootTree->SetSubtreeData(newTreeData); - LOG(INFO) << "block split success"; + VLOG(100) << "block split success"; // kv_state_cache_builder->UnLock(); Update(client, tokenList, nextToken, kvState); @@ -218,7 +196,7 @@ void KVStateCacheBuilder::Update(Client& client, nodeData->nodeData->dataLength = sizeof(OffsetData); } - LOG(INFO) << "builder:" << kvStateCacheBlockBuilder + VLOG(100) << "builder:" << kvStateCacheBlockBuilder << " bitmap:" << kvStateCacheBlockBuilder->GetBitmapStr(); } @@ -233,38 +211,34 @@ KV_STATE_WITH_LAYER KVStateCacheBuilder::Query( std::shared_ptr nodeData = this->rootTree->Query(tokenListCopy); if (nodeData != nullptr) { - OffsetData* data = (OffsetData*) nodeData->nodeData->data; + OffsetData* data = reinterpret_cast(nodeData->nodeData->data); int offset = data->offset; KVStateCacheBlockBuilder* kvStateCacheBlockBuilder = - (KVStateCacheBlockBuilder*) ((TreeData*) nodeData->treeData->data) - ->kvStateCacheBlockBuilder; + reinterpret_cast( + (reinterpret_cast(nodeData->treeData->data)) + ->kvStateCacheBlockBuilder); - LOG(INFO) << "offset:" << offset; - LOG(INFO) << "kvStateCacheBlockBuilder:" << kvStateCacheBlockBuilder; kvStateCacheBlockBuilder->Query(client, offset, kvState); } 
return kvState; } void KVStateCacheBuilder::Delete(std::shared_ptr evictedNodeData) { - LOG(INFO) << "stage1"; + TreeData* treeData = + reinterpret_cast(evictedNodeData->treeData->data); KVStateCacheBlockBuilder* kvStateCacheBlockBuilder = - (KVStateCacheBlockBuilder*) ((TreeData*) evictedNodeData->treeData->data) - ->kvStateCacheBlockBuilder; - LOG(INFO) << "stage2, builder:" << kvStateCacheBlockBuilder; - OffsetData* data = (OffsetData*) evictedNodeData->nodeData->data; - LOG(INFO) << "stage3"; + reinterpret_cast( + treeData->kvStateCacheBlockBuilder); + OffsetData* data = + reinterpret_cast(evictedNodeData->nodeData->data); kvStateCacheBlockBuilder->DeleteKVCache(data->offset); - LOG(INFO) << "stage4"; delete data; // TBD // Refactor this code. The data should be deleted by the RadixTree // delete (DataWrapper*) evictedNodeData->nodeData; - LOG(INFO) << "tree data:" << evictedNodeData->treeData->data; if (evictedNodeData->cleanTreeData) { - LOG(INFO) << "erase"; - this->rootTree->GetSubTreeDataSet().erase(evictedNodeData->treeData->data); + this->rootTree->ClearSubtreeData(treeData); } evictedNodeData->RecycleSource(); } @@ -284,7 +258,7 @@ void KVStateCacheBuilder::Merge(Client& client, RadixTree::MergeTree(this->rootTree, globalCacheTree, evicted_token_list, insertTokenList); - LOG(INFO) << "insert token list size:" << insertTokenList.size() + VLOG(100) << "insert token list size:" << insertTokenList.size() << " evicted token list size:" << evicted_token_list.size(); for (size_t i = 0; i < evicted_token_list.size(); i++) { std::vector tokenList = @@ -317,7 +291,6 @@ Status KVStateCacheBuilder::Build(Client& client) { } std::shared_ptr KVStateCacheBuilder::_Seal(Client& client) { - LOG(INFO) << "cache seal"; this->Build(client); std::shared_ptr kvStateCache = std::make_shared(); @@ -334,19 +307,19 @@ std::shared_ptr KVStateCacheBuilder::_Seal(Client& client) { std::set subTreeDataSet = rootTree->GetSubTreeDataSet(); for (auto iter = subTreeDataSet.begin(); iter 
!= subTreeDataSet.end(); ++iter) { - TreeData* treeData = (TreeData*) (*iter); + TreeData* treeData = reinterpret_cast(*iter); VINEYARD_ASSERT(treeData != nullptr); VINEYARD_ASSERT(treeData->isPtr == true); KVStateCacheBlockBuilder* kvStateCacheBlockBuilder = - (KVStateCacheBlockBuilder*) treeData->kvStateCacheBlockBuilder; - LOG(INFO) << "builder:" << kvStateCacheBlockBuilder; + reinterpret_cast( + treeData->kvStateCacheBlockBuilder); std::shared_ptr kvStateCacheBlock = kvStateCacheBlockBuilder->_Seal(client); kvStateCache->meta_.AddMember( "kv_state_cache_block_builder_" + std::to_string(count), kvStateCacheBlock); - treeData->builderObjectID = kvStateCacheBlock->id(); + treeData->builderObjectID = count; treeData->isPtr = false; count++; } @@ -362,26 +335,26 @@ std::shared_ptr KVStateCacheBuilder::_Seal(Client& client) { VINEYARD_CHECK_OK( client.CreateMetaData(kvStateCache->meta_, kvStateCache->id_)); - LOG(INFO) << "KVStateCacheBuilder::_Seal: " << kvStateCache->id_; + VLOG(100) << "KVStateCacheBuilder::_Seal: " << kvStateCache->id_; return kvStateCache; } KVStateCacheBuilder::~KVStateCacheBuilder() { - LOG(INFO) << "KVStateCacheBuilder::~KVStateCacheBuilder"; // get all subtree data and node data std::set subTreeDataSet = rootTree->GetSubTreeDataSet(); std::set nodeDataSet = rootTree->GetAllNodeData(); // 2. 
delete all subtree data and node data for (auto iter = subTreeDataSet.begin(); iter != subTreeDataSet.end(); ++iter) { - TreeData* treeData = (TreeData*) (*iter); + TreeData* treeData = reinterpret_cast(*iter); if (treeData->isPtr == true) { - delete (KVStateCacheBlockBuilder*) treeData->kvStateCacheBlockBuilder; + delete reinterpret_cast( + treeData->kvStateCacheBlockBuilder); delete treeData; } } for (auto iter = nodeDataSet.begin(); iter != nodeDataSet.end(); ++iter) { - OffsetData* data = (OffsetData*) (*iter); + OffsetData* data = reinterpret_cast(*iter); if (data != nullptr) { delete data; } diff --git a/modules/kv-state-cache/ds/kv_state_cache.h b/modules/kv-state-cache/ds/kv_state_cache.h index f378aef3..7cb80f20 100644 --- a/modules/kv-state-cache/ds/kv_state_cache.h +++ b/modules/kv-state-cache/ds/kv_state_cache.h @@ -14,17 +14,17 @@ limitations under the License. */ #include +#include #include #include "client/client.h" #include "common/util/logging.h" #include "common/util/status.h" +#include "kv-state-cache/ds/kv_state_cache_block.h" #include "kv-state-cache/radix-tree/radix-tree.h" -#include "kv-state-cache/strategy/LRU_strategy.h" -#include "kv_state_cache_block.h" -#ifndef MODULES_KV_STATE_CACHE_H_ -#define MODULES_KV_STATE_CACHE_H_ +#ifndef MODULES_KV_STATE_CACHE_DS_KV_STATE_CACHE_H_ +#define MODULES_KV_STATE_CACHE_DS_KV_STATE_CACHE_H_ namespace vineyard { @@ -39,7 +39,6 @@ struct TreeData { class KVStateCache : public vineyard::Registered { private: std::vector> kvStateCacheBlockList; - std::map> kvStateCacheBlockMap; std::shared_ptr rootTree; int dimension; int cacheCapacity; @@ -121,4 +120,4 @@ class KVStateCacheBuilder : public vineyard::ObjectBuilder { } // namespace vineyard -#endif \ No newline at end of file +#endif // MODULES_KV_STATE_CACHE_DS_KV_STATE_CACHE_H_ diff --git a/modules/kv-state-cache/ds/kv_state_cache_block.cc b/modules/kv-state-cache/ds/kv_state_cache_block.cc index 397e3b04..1441d196 100644 --- 
a/modules/kv-state-cache/ds/kv_state_cache_block.cc +++ b/modules/kv-state-cache/ds/kv_state_cache_block.cc @@ -13,16 +13,20 @@ See the License for the specific language governing permissions and limitations under the License. */ -#include "kv_state_cache_block.h" +#include +#include +#include + #include "client/client.h" #include "common/util/logging.h" +#include "kv-state-cache/ds/kv_state_cache_block.h" namespace vineyard { // this function will be removed in the future std::string KVStateCacheBlock::GetBitmapStr() { std::string result; - const int bits = 8 * sizeof(unsigned long long); + const int bits = 8 * sizeof(uint64_t); for (int i = 0; i < this->bitmapSize; i++) { for (int j = bits - 1; j >= 0; --j) { result += (((this->bitmap[i]) >> j) & 1) ? '1' : '0'; @@ -33,7 +37,7 @@ std::string KVStateCacheBlock::GetBitmapStr() { std::string KVStateCacheBlockBuilder::GetBitmapStr() { std::string result; - const int bits = 8 * sizeof(unsigned long long); + const int bits = 8 * sizeof(uint64_t); for (int i = 0; i < this->bitmapSize; i++) { for (int j = bits - 1; j >= 0; --j) { result += (((this->bitmap[i]) >> j) & 1) ? '1' : '0'; @@ -64,8 +68,8 @@ void KVStateCacheBlock::Construct(const ObjectMeta& meta) { } // 2. 
construct the member field this->bitmapSize = this->meta_.GetKeyValue("bitmap_size"); - LOG(INFO) << "construct bitmap size:" << this->bitmapSize; - this->bitmap = (uint64_t*) malloc(this->bitmapSize * sizeof(uint64_t)); + VLOG(100) << "construct bitmap size:" << this->bitmapSize; + this->bitmap = new uint64_t[this->bitmapSize]; for (int i = 0; i < this->bitmapSize; i++) { this->bitmap[i] = this->meta_.GetKeyValue("bitmap_" + std::to_string(i)); @@ -74,15 +78,15 @@ void KVStateCacheBlock::Construct(const ObjectMeta& meta) { this->blockSize = this->meta_.GetKeyValue("block_size"); } -KVStateCacheBlock::~KVStateCacheBlock() { free(this->bitmap); } +KVStateCacheBlock::~KVStateCacheBlock() { delete[] this->bitmap; } KVStateCacheBlockBuilder::KVStateCacheBlockBuilder(Client& client, int dimension, int layer, int blockSize) { this->blockSize = blockSize; this->bitmapSize = (blockSize + 63) / 64; - this->bitmap = (uint64_t*) malloc(this->bitmapSize * sizeof(uint64_t)); - memset((void*) this->bitmap, UINT8_MAX, this->bitmapSize * sizeof(uint64_t)); + this->bitmap = new uint64_t[this->bitmapSize]; + memset(this->bitmap, UINT8_MAX, this->bitmapSize * sizeof(uint64_t)); std::vector shape = {(int64_t)(blockSize), dimension}; for (int i = 0; i < layer; i++) { this->keyStateTensorBuilderList.push_back( @@ -98,9 +102,9 @@ KVStateCacheBlockBuilder::KVStateCacheBlockBuilder( Client& client, std::shared_ptr kvStateCacheBlock) { this->bitmapSize = kvStateCacheBlock->bitmapSize; this->blockSize = kvStateCacheBlock->blockSize; - LOG(INFO) << "create builder from block object, bitmap size:" + VLOG(100) << "create builder from block object, bitmap size:" << this->bitmapSize << " block size:" << blockSize; - this->bitmap = (uint64_t*) malloc(this->bitmapSize * sizeof(uint64_t)); + this->bitmap = new uint64_t[this->bitmapSize]; for (int i = 0; i < this->bitmapSize; i++) { this->bitmap[i] = kvStateCacheBlock->bitmap[i]; } @@ -132,15 +136,13 @@ Status KVStateCacheBlockBuilder::Query(Client& 
client, int index, std::vector valueStateVector; for (int i = 0; i < this->dimension; ++i) { - keyStateVector.push_back( - ((double*) keyStateTensorBuilderList[currentLayer] - ->data())[index * dimension + i]); + keyStateVector.push_back((keyStateTensorBuilderList[currentLayer] + ->data())[index * dimension + i]); } for (int i = 0; i < this->dimension; ++i) { - valueStateVector.push_back( - ((double*) valueStateTensorBuilderList[currentLayer] - ->data())[index * dimension + i]); + valueStateVector.push_back((valueStateTensorBuilderList[currentLayer] + ->data())[index * dimension + i]); } kvState.insert(std::make_pair( @@ -165,7 +167,7 @@ bool KVStateCacheBlockBuilder::IsFull() { if (this->bitmap[i] != 0 && ffsll(this->bitmap[i]) - 1 < left) { return false; } - left -= sizeof(unsigned long long) * 8; + left -= sizeof(uint64_t) * 8; } return true; } @@ -173,21 +175,16 @@ bool KVStateCacheBlockBuilder::IsFull() { void KVStateCacheBlockBuilder::Update(const KV_STATE_WITH_LAYER& kvState, OffsetData* data) { int index = this->FindEmptySlot(); - LOG(INFO) << "index:" << index; - LOG(INFO) << "layer:" << layer; for (int currentLayer = 0; currentLayer < this->layer; currentLayer++) { std::vector keyStateVector = (kvState.find(currentLayer)->second).first; std::vector valueStateVector = (kvState.find(currentLayer)->second).second; - LOG(INFO) << "vector size:" << keyStateVector.size() << " " - << valueStateVector.size() << " dimension" << this->dimension; VINEYARD_ASSERT(keyStateVector.size() == (size_t) this->dimension); VINEYARD_ASSERT(valueStateVector.size() == (size_t) this->dimension); - double* keyData = (double*) keyStateTensorBuilderList[currentLayer]->data(); - double* valueData = - (double*) valueStateTensorBuilderList[currentLayer]->data(); + double* keyData = keyStateTensorBuilderList[currentLayer]->data(); + double* valueData = valueStateTensorBuilderList[currentLayer]->data(); memcpy(keyData + index * this->dimension, keyStateVector.data(), this->dimension * 
sizeof(double)); memcpy(valueData + index * this->dimension, valueStateVector.data(), @@ -198,8 +195,8 @@ void KVStateCacheBlockBuilder::Update(const KV_STATE_WITH_LAYER& kvState, ACQUIRE_BIT_RESOURCE(this->bitmap[index / 64], index % 64); } -short KVStateCacheBlockBuilder::Split(KVStateCacheBlockBuilder* child, - int index) { +int16_t KVStateCacheBlockBuilder::Split(KVStateCacheBlockBuilder* child, + int index) { // TBD VINEYARD_ASSERT(this->layer == child->layer); int childIndex = child->FindEmptySlot(); @@ -213,14 +210,13 @@ short KVStateCacheBlockBuilder::Split(KVStateCacheBlockBuilder* child, std::shared_ptr> childValueStateTensorBuilder = child->valueStateTensorBuilderList[currentLayer]; - double* keyState = - (double*) keyStateTensorBuilder->data() + index * this->dimension; + double* keyState = keyStateTensorBuilder->data() + index * this->dimension; double* valueState = - (double*) valueStateTensorBuilder->data() + index * this->dimension; - double* childKeyState = (double*) childKeyStateTensorBuilder->data() + - childIndex * this->dimension; - double* childValueState = (double*) childValueStateTensorBuilder->data() + - childIndex * this->dimension; + valueStateTensorBuilder->data() + index * this->dimension; + double* childKeyState = + childKeyStateTensorBuilder->data() + childIndex * this->dimension; + double* childValueState = + childValueStateTensorBuilder->data() + childIndex * this->dimension; memcpy(childKeyState, keyState, this->dimension * sizeof(double)); memcpy(childValueState, valueState, this->dimension * sizeof(double)); @@ -233,7 +229,6 @@ short KVStateCacheBlockBuilder::Split(KVStateCacheBlockBuilder* child, Status KVStateCacheBlockBuilder::Build(Client& client) { return Status::OK(); } std::shared_ptr KVStateCacheBlockBuilder::_Seal(Client& client) { - LOG(INFO) << "block seal:" << this; this->Build(client); std::shared_ptr kvStateCacheBlock = @@ -255,7 +250,6 @@ std::shared_ptr KVStateCacheBlockBuilder::_Seal(Client& client) { 
kvStateCacheBlock->meta_.AddKeyValue("bitmap_" + std::to_string(i), this->bitmap[i]); } - LOG(INFO) << "seal bitmap:" << this->GetBitmapStr(); kvStateCacheBlock->meta_.AddKeyValue("block_size", this->blockSize); kvStateCacheBlock->meta_.AddKeyValue("dimension", this->dimension); @@ -268,6 +262,34 @@ std::shared_ptr KVStateCacheBlockBuilder::_Seal(Client& client) { return kvStateCacheBlock; } -KVStateCacheBlockBuilder::~KVStateCacheBlockBuilder() { free(this->bitmap); } +void KVStateCacheBlockBuilder::PrintKVStateCacheBlock() { + LOG(INFO) << "builder:" << this; + for (int i = 0; i < this->blockSize; i++) { + LOG(INFO) << "index:" << i << " bitmap:" << this->GetBitmapStr(); + } + + for (int currentLayer = 0; currentLayer < this->layer; currentLayer++) { + LOG(INFO) << "layer:" << currentLayer; + for (int i = 0; i < this->blockSize; i++) { + LOG(INFO) << "index:" << i; + std::string keyState = ""; + std::string valueState = ""; + for (int j = 0; j < this->dimension; j++) { + keyState += std::to_string((keyStateTensorBuilderList[currentLayer] + ->data())[i * dimension + j]) + + " "; + valueState += std::to_string((valueStateTensorBuilderList[currentLayer] + ->data())[i * dimension + j]) + + " "; + } + LOG(INFO) << "keyState:" << keyState; + LOG(INFO) << "valueState:" << valueState; + } + } + + LOG(INFO) << "=========================="; +} + +KVStateCacheBlockBuilder::~KVStateCacheBlockBuilder() { delete[] this->bitmap; } } // namespace vineyard diff --git a/modules/kv-state-cache/ds/kv_state_cache_block.h b/modules/kv-state-cache/ds/kv_state_cache_block.h index 3bb20d3a..4870201f 100644 --- a/modules/kv-state-cache/ds/kv_state_cache_block.h +++ b/modules/kv-state-cache/ds/kv_state_cache_block.h @@ -13,12 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ -#ifndef MODULES_KV_STATE_CACHE_BLOCK_H_ -#define MODULES_KV_STATE_CACHE_BLOCK_H_ +#ifndef MODULES_KV_STATE_CACHE_DS_KV_STATE_CACHE_BLOCK_H_ +#define MODULES_KV_STATE_CACHE_DS_KV_STATE_CACHE_BLOCK_H_ #include #include #include +#include +#include +#include #include #include "basic/ds/tensor.h" @@ -44,7 +47,7 @@ typedef std::vector, std::vector>> ((value) &= (~(((uint64_t) 1) << (bit)))) struct OffsetData { - short offset; + int16_t offset; }; namespace vineyard { @@ -142,7 +145,7 @@ class KVStateCacheBlockBuilder : public ObjectBuilder { */ void Update(const KV_STATE_WITH_LAYER& kv_state, OffsetData* data); - void Update(double* keyState, double* valueState, unsigned long dataLength, + void Update(double* keyState, double* valueState, uint64_t dataLength, OffsetData* data); /** @@ -161,7 +164,7 @@ class KVStateCacheBlockBuilder : public ObjectBuilder { std::shared_ptr _Seal(Client& client) override; - short Split(KVStateCacheBlockBuilder* child, int index); + int16_t Split(KVStateCacheBlockBuilder* child, int index); const std::shared_ptr> GetKeyStateBuilder(int layer) { return keyStateTensorBuilderList[layer]; @@ -193,9 +196,11 @@ class KVStateCacheBlockBuilder : public ObjectBuilder { int GetBlockSize() { return this->blockSize; } + void PrintKVStateCacheBlock(); + ~KVStateCacheBlockBuilder(); }; } // namespace vineyard -#endif +#endif // MODULES_KV_STATE_CACHE_DS_KV_STATE_CACHE_BLOCK_H_ diff --git a/modules/kv-state-cache/radix-tree/radix-tree.cc b/modules/kv-state-cache/radix-tree/radix-tree.cc index 1c2dda09..c528f44c 100644 --- a/modules/kv-state-cache/radix-tree/radix-tree.cc +++ b/modules/kv-state-cache/radix-tree/radix-tree.cc @@ -168,7 +168,6 @@ void RadixTree::DeleteInternal(std::vector tokens, oldData, (DataWrapper*) subTreeNode->custom_data); nodeCount--; if (nodeIsSubTree) { - // subTreeDataSet.erase(subTreeNode->custom_data); evictedNode->cleanTreeData = true; } } else { @@ -498,7 +497,7 @@ std::shared_ptr RadixTree::Deserialize(std::string data) 
{ node->issubtree = true; raxSetCustomData(node, data); - radixTree->subTreeDataSet.insert(subTreeDataList[i]); + radixTree->SetSubtreeData(subTreeDataList[i]); } LOG(INFO) << "Deserialize success"; raxShow(radixTree->tree); @@ -544,10 +543,15 @@ std::vector> RadixTree::TraverseTreeWithoutSubTree( } void RadixTree::SetSubtreeData(void* data) { - LOG(INFO) << "set subtree data:" << data; + VLOG(100) << "set subtree data:" << data; subTreeDataSet.insert(data); } +void RadixTree::ClearSubtreeData(void* data) { + VLOG(100) << "clear subtree data:" << data; + subTreeDataSet.erase(data); +} + std::shared_ptr RadixTree::GetRootNode() { raxNode* node = raxFindAndReturnDataNode(this->tree, rootToken.data(), rootToken.size(), NULL); diff --git a/modules/kv-state-cache/radix-tree/radix-tree.h b/modules/kv-state-cache/radix-tree/radix-tree.h index 139f64e3..8a2d5a95 100644 --- a/modules/kv-state-cache/radix-tree/radix-tree.h +++ b/modules/kv-state-cache/radix-tree/radix-tree.h @@ -99,6 +99,8 @@ class RadixTree : public std::enable_shared_from_this { void SetSubtreeData(void* data); + void ClearSubtreeData(void* data); + rax* GetRootTree() { return this->tree; } int GetCacheCapacity() { return cacheCapacity - 1; } diff --git a/modules/kv-state-cache/radix-tree/radix.cc b/modules/kv-state-cache/radix-tree/radix.cc index 48a115bc..5d9172e0 100644 --- a/modules/kv-state-cache/radix-tree/radix.cc +++ b/modules/kv-state-cache/radix-tree/radix.cc @@ -1589,23 +1589,25 @@ int raxIteratorNextStep(raxIterator *it, int noup) { raxNode **cp = raxNodeFirstChildPtr(it->node); if (!raxIteratorAddToken(it,it->node->data, it->node->iscompr ? 
it->node->size : 1)) return 0; - if (it->node->issubtree && it->add_to_subtree_list && it->subtree_list != NULL && - it->subtree_data_list != NULL) { - std::cout << "first find subtree list is:" << std::endl; - std::vector token; - std::string token_str; - for (size_t i = 0; i < it->key_len - 1; i++) { - token.push_back(it->key[i]); - token_str += std::to_string(it->key[i]) + " "; - } - LOG(INFO) << "list is:" << token_str; - (*it->subtree_list).push_back(token); - void *data = raxGetCustomData(it->node); - if (data == NULL) { - throw std::runtime_error("custom data is null"); - } - (*it->subtree_data_list).push_back(data); - } + // TBD + // refactor this code with raxSerial() + // if (it->node->issubtree && it->add_to_subtree_list && it->subtree_list != NULL && + // it->subtree_data_list != NULL) { + // std::cout << "first find subtree list is:" << std::endl; + // std::vector token; + // std::string token_str; + // for (size_t i = 0; i < it->key_len - 1; i++) { + // token.push_back(it->key[i]); + // token_str += std::to_string(it->key[i]) + " "; + // } + // LOG(INFO) << "list is:" << token_str; + // (*it->subtree_list).push_back(token); + // void *data = raxGetCustomData(it->node); + // if (data == NULL) { + // throw std::runtime_error("custom data is null"); + // } + // (*it->subtree_data_list).push_back(data); + // } memcpy(&it->node,cp,sizeof(it->node)); /* Call the node callback if any, and replace the node pointer * if the callback returns true. 
*/ @@ -1634,23 +1636,26 @@ int raxIteratorNextStep(raxIterator *it, int noup) { it->node = orig_node; return 1; } - if (it->node->iskey && it->node->size == 0 && it->node->issubtree && it->add_to_subtree_list && it->subtree_list != NULL && - it->subtree_data_list != NULL) { - // data node is sub tree - std::vector token; - std::string token_str; - for (size_t i = 0; i < it->key_len - 1; i++) { - token.push_back(it->key[i]); - token_str += std::to_string(it->key[i]) + " "; - } - LOG(INFO) << "sub tree is:" << token_str; - (*it->subtree_list).push_back(token); - void *data = raxGetCustomData(it->node); - if (data == NULL) { - throw std::runtime_error("custom data is null"); - } - (*it->subtree_data_list).push_back(data); - } + // TBD + // refactor this code with raxSerial() + // if (it->node->iskey && it->node->size == 0 && it->node->issubtree && it->add_to_subtree_list && it->subtree_list != NULL && + // it->subtree_data_list != NULL) { + // // data node is sub tree + // std::vector token; + // std::string token_str; + // for (size_t i = 0; i < it->key_len - 1; i++) { + // token.push_back(it->key[i]); + // token_str += std::to_string(it->key[i]) + " "; + // } + // LOG(INFO) << "sub tree is:" << token_str; + // (*it->subtree_list).push_back(token); + // void *data = raxGetCustomData(it->node); + // if (data == NULL) { + // throw std::runtime_error("custom data is null"); + // } + // (*it->subtree_data_list).push_back(data); + // } + /* If there are no children at the current node, try parent's * next child. */ int prevchild = it->key[it->key_len-1]; @@ -2152,11 +2157,18 @@ uint64_t raxSize(rax *rax) { * [1,2] -> [1,2,3,4] -> [] */ -struct datawrapper { +struct DebugDatawrapper { void *data; int length; }; +struct DebugTreeData { + union { + void* kvStateCacheBlockBuilder; + uint64_t builderObjectID; + }; + bool isPtr = true; +}; /* The actual implementation of raxShow(). */ void raxRecursiveShow(int level, int lpad, raxNode *n) { char s = n->iscompr ? 
'"' : '['; @@ -2177,7 +2189,15 @@ void raxRecursiveShow(int level, int lpad, raxNode *n) { } numchars += printf(" node:%p time:%ld, data:%p, is_sub_tree:%d", n, n->timestamp, n->custom_data, n->issubtree); if (n->issubtree && n->custom_data != NULL) { - numchars += printf(" cus data:%p" , ((datawrapper *)(n->custom_data))->data); + numchars += printf(" cus data:%p" , ((DebugDatawrapper *)(n->custom_data))->data); + DebugTreeData *data = (DebugTreeData *)((DebugDatawrapper *)(n->custom_data))->data; + if (data) { + if (data->isPtr) { + numchars += printf(" builder ptr:%p", data->kvStateCacheBlockBuilder); + } else { + numchars += printf(" builder id:%lu", data->builderObjectID); + } + } } int numchildren = n->iscompr ? 1 : n->size; @@ -2421,6 +2441,11 @@ void raxSerialize(rax *root, std::vector> &tokenList, std::vect tokenList.push_back(token); dataList.push_back(iter.data); timestampList.push_back(iter.node->timestamp); + raxNode *data = raxFindAndReturnDataNode(root, iter.key, iter.key_len, nullptr, false); + if (data->issubtree && subtreeList != nullptr) { + subtreeList->push_back(token); + subtreeDataList->push_back(data->custom_data); + } } raxStop(&iter); } diff --git a/modules/kv-state-cache/strategy/LRU_strategy.cc b/modules/kv-state-cache/strategy/LRU_strategy.cc deleted file mode 100644 index 87d4cdee..00000000 --- a/modules/kv-state-cache/strategy/LRU_strategy.cc +++ /dev/null @@ -1,124 +0,0 @@ -/** Copyright 2020-2023 Alibaba Group Holding Limited. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. -*/ - -#include "LRU_strategy.h" -#include "common/util/logging.h" - -namespace vineyard { - -void PrintTokenList(std::vector& vector) { - std::string tokens_str = ""; - for (size_t i = 0; i < vector.size(); ++i) { - tokens_str += std::to_string(vector[i]); - } - LOG(INFO) << tokens_str; -} - -void LRUStrategy::PrintLRUList() { - LOG(INFO) << "List:"; - std::shared_ptr node = header; - while (node != nullptr) { - PrintTokenList(node->tokens); - LOG(INFO) << "->"; - node = node->next; - } -} - -LRUStrategy::LRUStrategy(int capacity) { - this->capacity = capacity; - this->header = this->tail = nullptr; - this->current_size = 0; -} - -LRUStrategy::LRUStrategy(const std::vector>& cache_list, - int capacity) { - // TBD -} - -std::shared_ptr LRUStrategy::InsertToHeader( - const std::vector& tokens, std::vector& evicted_tokens) { - if (current_size == capacity) { - std::shared_ptr remove_node = tail; // Remove(); - evicted_tokens = remove_node->tokens; - } - - std::shared_ptr cache_node = std::make_shared(); - cache_node->tokens = tokens; - - if (header == nullptr) { - header = cache_node; - tail = cache_node; - } else { - cache_node->next = header; - header->prev = cache_node; - header = cache_node; - } - - current_size++; - return cache_node; -} - -void LRUStrategy::MoveToHead(std::shared_ptr cache_node) { - if (cache_node == header) { - return; - } - - if (cache_node == tail) { - tail = tail->prev; - tail->next = nullptr; - } else { - cache_node->prev->next = cache_node->next; - cache_node->next->prev = cache_node->prev; - } - - cache_node->next = header; - header->prev = cache_node; - header = cache_node; - cache_node->prev = nullptr; -} - -std::shared_ptr LRUStrategy::Remove() { - std::shared_ptr cache_node = tail; - if (tail->prev != nullptr) { - tail->prev->next = nullptr; - tail = tail->prev; - } else { - header = nullptr; - tail = nullptr; - } - current_size--; - 
- LOG(INFO) << "Remove token:"; - PrintTokenList(cache_node->tokens); - return cache_node; -} - -void LRUStrategy::Remove(std::shared_ptr cache_node) { - if (cache_node == header) { - header = header->next; - header->prev = nullptr; - } else if (cache_node == tail) { - tail = tail->prev; - tail->next = nullptr; - } else { - cache_node->prev->next = cache_node->next; - cache_node->next->prev = cache_node->prev; - } - current_size--; -} - -std::shared_ptr LRUStrategy::GetHeader() { return header; } - -} // namespace vineyard diff --git a/modules/kv-state-cache/strategy/LRU_strategy.h b/modules/kv-state-cache/strategy/LRU_strategy.h deleted file mode 100644 index fcfc519e..00000000 --- a/modules/kv-state-cache/strategy/LRU_strategy.h +++ /dev/null @@ -1,67 +0,0 @@ -/** Copyright 2020-2023 Alibaba Group Holding Limited. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-*/ - -#include -#include - -#include "cache_strategy.h" - -#ifndef MODULES_LRU_STRATEGY_H_ -#define MODULES_LRU_STRATEGY_H_ - -namespace vineyard { - -struct LRUCacheNode { - std::shared_ptr next; - std::shared_ptr prev; - std::vector tokens; -}; - -class LRUStrategy : public CacheStrategy { - private: - int current_size; - - std::shared_ptr header; - - std::shared_ptr tail; - - LRUStrategy(); - - std::shared_ptr Remove(); - - ~LRUStrategy(); - - public: - LRUStrategy(int capacity); - - LRUStrategy(const std::vector>& cache_list, int capacity); - - void MoveToHead(std::shared_ptr cache_node); - - std::shared_ptr InsertToHeader( - const std::vector& tokens, std::vector& evicted_tokens); - - void Remove(std::shared_ptr cache_node); - - std::shared_ptr GetHeader(); - - int GetCapacity() { return capacity; } - - void PrintLRUList(); -}; - -} // namespace vineyard - -#endif \ No newline at end of file diff --git a/modules/kv-state-cache/strategy/cache_strategy.h b/modules/kv-state-cache/strategy/cache_strategy.h deleted file mode 100644 index 36596cd1..00000000 --- a/modules/kv-state-cache/strategy/cache_strategy.h +++ /dev/null @@ -1,31 +0,0 @@ -/** Copyright 2020-2023 Alibaba Group Holding Limited. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-*/ - -#include -#include - -#ifndef MODULES_CACHE_STRATEGY_H_ -#define MODULES_CACHE_STRATEGY_H_ - -namespace vineyard { - -class CacheStrategy { - protected: - int capacity; -}; - -} // namespace vineyard - -#endif \ No newline at end of file diff --git a/modules/kv-state-cache/utils/kv_state_cache_utils.cc b/modules/kv-state-cache/utils/kv_state_cache_utils.cc index ffaa9ade..34a47ef7 100644 --- a/modules/kv-state-cache/utils/kv_state_cache_utils.cc +++ b/modules/kv-state-cache/utils/kv_state_cache_utils.cc @@ -14,13 +14,16 @@ limitations under the License. */ #include +#include +#include +#include #include "client/client.h" #include "common/util/logging.h" #include "kv-state-cache/ds/kv_state_cache.h" -#include "kv_state_cache_utils.h" +#include "kv-state-cache/utils/kv_state_cache_utils.h" -using namespace vineyard; +namespace vineyard { static Client client; static std::shared_ptr kvStateCacheBuilder = nullptr; @@ -50,29 +53,31 @@ void signalHandler(int signum) { * Avoid dead lock if the client is down when the lock is acquired. * Use lease to prevent dead lock in the future. 
*/ - std::cout << "Interrupt signal (" << signum << ") received.\n"; + LOG(INFO) << "Interrupt signal (" << signum << ") received.\n"; + CloseKVStateCache(); + exit(signum); +} + +void CloseKVStateCache() { exitFlag = true; syncThread->join(); - exit(signum); } void InitKVStateCache(int dimension, int cacheCapacity, int layer, - int blockSize) { + int blockSize, std::string socket) { if (kvStateCacheBuilder == nullptr) { - std::string socket = std::string(getenv("VINEYARD_IPC_SOCKET")); - LOG(INFO) << "socket:" << socket; + VLOG(100) << "socket:" << socket; client.Connect(socket); - LOG(INFO) << "conneted"; pthread_mutex_init(&syncMutex, NULL); // TBD // try to get cache object - std::string acturalKey; + std::string actualKey; bool result; while (1) { - client.TryAcquireLock(llmCacheSyncLock, result, acturalKey); + client.TryAcquireLock(llmCacheSyncLock, result, actualKey); if (!result) { - LOG(INFO) << "failed to gain the lock, wait for next time"; + VLOG(100) << "failed to gain the lock, wait for next time."; sleep(1); continue; } else { @@ -80,25 +85,25 @@ void InitKVStateCache(int dimension, int cacheCapacity, int layer, } } - // // sync global cache object with vineyard + // sync global cache object with vineyard ObjectID globalKVStateCacheID; Status status = client.GetName(llmCacheObjectName, globalKVStateCacheID); if (status.ok()) { // if success, pull the cache object std::shared_ptr globalKVStateCache = std::dynamic_pointer_cast( - client.GetObject(globalKVStateCacheID)); + client.FetchAndGetObject(globalKVStateCacheID)); kvStateCacheBuilder = std::make_shared(client, globalKVStateCache); } else { // if failed, create a new cache object - LOG(INFO) << "failed to get the cache object, create a new one"; + VLOG(100) << "failed to get the cache object, create a new one."; kvStateCacheBuilder = std::make_shared( client, dimension, cacheCapacity, layer, blockSize); } // // release the lock - client.TryReleaseLock(acturalKey, result); + 
client.TryReleaseLock(actualKey, result); VINEYARD_ASSERT(result == true); syncThread = new std::thread(threadFunc); @@ -116,7 +121,6 @@ void UpdateInternal(const std::vector& tokenList, int nextToken, void Update(const std::vector& tokenList, int nextToken, const KV_STATE_WITH_LAYER& kvState) { - LOG(INFO) << "Update"; if (pthread_mutex_trylock(&syncMutex)) { return; } @@ -145,7 +149,6 @@ KV_STATE_WITH_LAYER QueryInternal(const std::vector& tokenList, } KV_STATE_WITH_LAYER Query(const std::vector& tokenList, int token) { - LOG(INFO) << "Query"; KV_STATE_WITH_LAYER result; if (pthread_mutex_trylock(&syncMutex)) { return result; @@ -175,12 +178,12 @@ LIST_KV_STATE_WITH_LAYER Query(const std::vector& tokenList) { } void sync() { - LOG(INFO) << "sync"; + LOG(INFO) << "Try sync."; // 1. gain the lock - std::string acturalKey; + std::string actualKey; bool result; - client.TryAcquireLock(llmCacheSyncLock, result, acturalKey); + client.TryAcquireLock(llmCacheSyncLock, result, actualKey); if (!result) { LOG(INFO) << "failed to gain the lock, wait for next time"; return; @@ -194,12 +197,12 @@ void sync() { if (status.ok()) { deleteList.push_back(globalKVStateCacheID); globalKVStateCache = std::dynamic_pointer_cast( - client.GetObject(globalKVStateCacheID)); + client.FetchAndGetObject(globalKVStateCacheID)); } // 3. merge the cache object // only the global cache object with higher version will be merged - LOG(INFO) << "Current builder version:" << kvStateCacheBuilder->GetVersion() + VLOG(100) << "Current builder version:" << kvStateCacheBuilder->GetVersion() << " global version:" << (globalKVStateCache == nullptr ? "null" @@ -215,29 +218,29 @@ void sync() { client.Persist(kvStateCache->id()); // 5. 
put the name of the new cache object to the meta server - LOG(INFO) << "stage 5"; client.DropName(llmCacheObjectName); status = client.PutName(kvStateCache->id(), llmCacheObjectName); - if (status.ok()) { - LOG(INFO) << "put name success"; - } else { - LOG(INFO) << "put name failed with status:" + status.ToString(); + if (!status.ok()) { + throw std::runtime_error("Put cache object name failed."); } - LOG(INFO) << "stage 6"; // 6. delete old cache object client.DelData(deleteList); - LOG(INFO) << "stage 7"; // 7. create a global cache object replica std::dynamic_pointer_cast(kvStateCache)->Resolve(); kvStateCacheBuilder = std::make_shared( client, std::dynamic_pointer_cast(kvStateCache)); - LOG(INFO) << "stage 8"; // 8. release the lock - client.TryReleaseLock(acturalKey, result); - VINEYARD_ASSERT(result == true); + while (1) { + LOG(INFO) << "stage 7"; + client.TryReleaseLock(actualKey, result); + if (result == true) { + break; + } + sleep(1); + } // TBD // use lease to prevent the deadlock if the client is down @@ -249,9 +252,10 @@ void threadFunc() { if (exitFlag) { break; } - LOG(INFO) << "Try sync"; pthread_mutex_lock(&syncMutex); sync(); pthread_mutex_unlock(&syncMutex); } } + +} // namespace vineyard diff --git a/modules/kv-state-cache/utils/kv_state_cache_utils.h b/modules/kv-state-cache/utils/kv_state_cache_utils.h index 4df60be8..a40a12d8 100644 --- a/modules/kv-state-cache/utils/kv_state_cache_utils.h +++ b/modules/kv-state-cache/utils/kv_state_cache_utils.h @@ -13,13 +13,20 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ +#include +#include + #include "kv-state-cache/ds/kv_state_cache.h" -#ifndef MODULES_KV_STATE_CACHE_UTILS_H_ -#define MODULES_KV_STATE_CACHE_UTILS_H_ +#ifndef MODULES_KV_STATE_CACHE_UTILS_KV_STATE_CACHE_UTILS_H_ +#define MODULES_KV_STATE_CACHE_UTILS_KV_STATE_CACHE_UTILS_H_ + +namespace vineyard { -void InitKVStateCache(int dimension = 10, int cacheCapacity = 10, int layer = 1, - int blockSize = 5); +void InitKVStateCache( + int dimension = 10, int cacheCapacity = 10, int layer = 1, + int blockSize = 5, + std::string socket = std::string(getenv("VINEYARD_IPC_SOCKET"))); void Update(const std::vector& tokenList, int nextToken, const KV_STATE_WITH_LAYER& kvState); @@ -33,4 +40,8 @@ LIST_KV_STATE_WITH_LAYER Query(const std::vector& tokenList); void Delete(std::vector token); -#endif \ No newline at end of file +void CloseKVStateCache(); + +} // namespace vineyard + +#endif // MODULES_KV_STATE_CACHE_UTILS_KV_STATE_CACHE_UTILS_H_ diff --git a/src/common/util/protocols.cc b/src/common/util/protocols.cc index 5c692c0d..c59b726d 100644 --- a/src/common/util/protocols.cc +++ b/src/common/util/protocols.cc @@ -2139,8 +2139,7 @@ void WriteInstanceStatusReply(const json& meta, std::string& msg) { Status ReadInstanceStatusReply(const json& root, json& meta) { CHECK_IPC_ERROR(root, command_t::INSTANCE_STATUS_REPLY); - meta = root["meta"].get(); - ; + meta = root["meta"]; return Status::OK(); } @@ -2305,7 +2304,6 @@ void WriteTryReleaseLockRequest(const std::string& key, std::string& msg) { Status ReadTryReleaseLockRequest(const json& root, std::string& key) { CHECK_IPC_ERROR(root, command_t::RELEASE_LOCK_REQUEST); key = root["key"].get(); - ; return Status::OK(); } @@ -2319,7 +2317,6 @@ void WriteTryReleaseLockReply(const bool result, std::string& msg) { Status ReadTryReleaseLockReply(const json& root, bool& result) { CHECK_IPC_ERROR(root, command_t::RELEASE_LOCK_REPLY); result = root["result"].get(); - ; return Status::OK(); } diff --git a/src/server/server/vineyard_server.cc 
b/src/server/server/vineyard_server.cc index 063f4ee2..e49ee899 100644 --- a/src/server/server/vineyard_server.cc +++ b/src/server/server/vineyard_server.cc @@ -1060,12 +1060,12 @@ Status VineyardServer::TryAcquireLock(std::string& key, auto self(shared_from_this()); meta_service_ptr_->TryAcquireLock( key, [self, callback](const Status& status, bool result, - std::string actural_key) { + std::string actual_key) { if (status.ok()) { LOG(INFO) << "No error occurred. Gain lock:" << result; - return callback(status, result, actural_key); + return callback(status, result, actual_key); } else { - return callback(status, result, actural_key); + return callback(status, result, actual_key); } }); diff --git a/src/server/services/etcd_meta_service.cc b/src/server/services/etcd_meta_service.cc index 76e71e20..c2a662ca 100644 --- a/src/server/services/etcd_meta_service.cc +++ b/src/server/services/etcd_meta_service.cc @@ -166,7 +166,7 @@ void EtcdMetaService::TryAcquireLock( boost::bind(callback_after_try_lock, Status::OK(), true, resp.lock_key().substr(self->prefix_.size()))); } else { - LOG(INFO) << "lock falied!"; + LOG(INFO) << "lock failed!"; self->server_ptr_->GetMetaContext().post( boost::bind(callback_after_try_lock, Status::OK(), false, "")); } diff --git a/test/distributed_lock_test.cc b/test/distributed_lock_test.cc deleted file mode 100644 index ef81bf2a..00000000 --- a/test/distributed_lock_test.cc +++ /dev/null @@ -1,66 +0,0 @@ -/** Copyright 2020-2023 Alibaba Group Holding Limited. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. -*/ - -#include -#include - -#include "client/client.h" -#include "common/util/logging.h" - -using namespace vineyard; - -int numThreads = 5; - -static int count = 0; - -void test(int i) { - std::string socket = std::string(getenv("VINEYARD_IPC_SOCKET")); - Client client; - client.Connect(socket); - - bool result; - std::string actural_key_of_lock; - - LOG(INFO) << "Thread: " << i << " try to acquire lock: test"; - client.TryAcquireLock("test", result, actural_key_of_lock); - LOG(INFO) << "Thread: " << i - << " acquire Lock: " << (result == true ? "success" : "fail") - << ", key is :" + actural_key_of_lock; - - if (result) { - count++; - LOG(INFO) << "count: " << count; - - sleep(3); - - LOG(INFO) << "Thread: " << i << " try to release lock: test"; - client.TryReleaseLock(actural_key_of_lock, result); - LOG(INFO) << "Thread: " << i - << " release Lock: " << (result == true ? "success" : "fail"); - } -} - -int main() { - std::thread threads[numThreads]; - for (int i = 0; i < numThreads; i++) { - threads[i] = std::thread(test, i); - } - - for (int i = 0; i < numThreads; i++) { - threads[i].join(); - } - - return 0; -} \ No newline at end of file diff --git a/test/kv_state_cache_multi_test.cc b/test/kv_state_cache_multi_test.cc new file mode 100644 index 00000000..52d4ae42 --- /dev/null +++ b/test/kv_state_cache_multi_test.cc @@ -0,0 +1,94 @@ +/** Copyright 2020-2023 Alibaba Group Holding Limited. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include +#include +#include +#include +#include +#include + +#include "common/util/logging.h" + +char process_name[] = "kv_state_cache_test"; +char arg_0[] = "-s"; +char token_sequence_1[] = "1"; +char token_sequence_2[] = "2"; +char token_sequence_3[] = "3"; +char token_sequence_4[] = "4"; + +const char* program = "./build/bin/kv_state_cache_test"; + +pid_t create_subprocess(char* argv[]) { + pid_t pid = fork(); + if (pid == -1) { + std::cerr << "Failed to fork()" << std::endl; + exit(1); + } else if (pid > 0) { + return pid; + } else { + execv(program, argv); + std::cerr << "Failed to exec()" << std::endl; + exit(1); + } +} + +int main(int argc, char** argv) { + std::string sockets[2]; + std::string rpc_endpoint; + for (int i = 1; i < argc; i++) { + if (strcmp(argv[i], "--vineyard-endpoint") == 0) { + rpc_endpoint = std::string(argv[i + 1]); + } else if (strcmp(argv[i], "--vineyard-ipc-sockets") == 0) { + sockets[0] = std::string(argv[i + 1]); + sockets[1] = std::string(argv[i + 2]); + } + } + + char* socket_str[2]; + socket_str[0] = + (char*) malloc(sockets[0].length() + 1); // NOLINT(readability/casting) + socket_str[1] = + (char*) malloc(sockets[1].length() + 1); // NOLINT(readability/casting) + memset(socket_str[0], 0, sockets[0].length() + 1); + memset(socket_str[1], 0, sockets[1].length() + 1); + strcpy(socket_str[0], sockets[0].c_str()); // NOLINT(runtime/printf) + strcpy(socket_str[1], sockets[1].c_str()); // NOLINT(runtime/printf) + + char* args_1[] = {process_name, socket_str[0], arg_0, + token_sequence_1, token_sequence_2, token_sequence_3, + token_sequence_4, nullptr}; + char* args_2[] = {process_name, socket_str[1], arg_0, + token_sequence_1, token_sequence_2, token_sequence_3, + nullptr}; + + std::vector pids; + pids.push_back(create_subprocess(args_1)); + pids.push_back(create_subprocess(args_2)); + for (size_t i = 0; i < pids.size(); i++) { 
+ int status; + waitpid(pids[i], &status, 0); + if (WIFEXITED(status) && WEXITSTATUS(status) != 0) { + free(socket_str[0]); + free(socket_str[1]); + return 1; + } + } + + free(socket_str[0]); + free(socket_str[1]); + + return 0; +} diff --git a/test/kv_state_cache_object_test.cc b/test/kv_state_cache_object_test.cc deleted file mode 100644 index 433fe5e0..00000000 --- a/test/kv_state_cache_object_test.cc +++ /dev/null @@ -1,180 +0,0 @@ -/** Copyright 2020-2023 Alibaba Group Holding Limited. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-*/ - -#include - -#include "basic/ds/tensor.h" -#include "common/util/logging.h" -#include "kv-state-cache/ds/kv_state_cache.h" - -using namespace vineyard; - -// std::vector tokens; -// RadixTree* radix_tree; -// std::vector> k_state_list; -// std::vector> v_state_list; -// std::vector> nodes_with_tree_attri_list; - -// #define DIMENSION 10 -// #define TOKEN_NUM 10 -// #define CACHE_CAPACITY 10 - -// void prepareData(KVStateCacheBuilder* kv_state_cache_builder) { -// radix_tree = new RadixTree(10); -// radix_tree->SetCustomData(kv_state_cache_builder, -// sizeof(KVStateCacheBuilder)); - -// for (int i = 0; i < TOKEN_NUM; i++) { -// tokens.push_back(i); -// } - -// LOG(INFO) << "stage 1"; -// for (int i = 0; i < TOKEN_NUM; i++) { -// std::vector key_state; -// for (int j = 0; j < DIMENSION; ++j) { -// key_state.push_back(((double) (j)) * 0.1 + (double) i); -// } -// k_state_list.push_back(key_state); -// } - -// LOG(INFO) << "stage 2"; -// for (int i = 0; i < TOKEN_NUM; i++) { -// std::vector value_state; -// for (int j = 0; j < DIMENSION; ++j) { -// value_state.push_back(((double) (j)) * 0.1 + (double) i); -// } -// v_state_list.push_back(value_state); -// } -// } - -// void updateTest(Client& client, KVStateCacheBuilder* builder) { -// std::vector prefix; - -// for (size_t i = 0; i < tokens.size(); ++i) { -// KV_STATE_WITH_LAYER kv_state; -// kv_state.insert( -// std::make_pair(1, std::make_pair(k_state_list[i], v_state_list[i]))); -// LOG(INFO) << "update test"; -// builder->Update(client, prefix, tokens[i], kv_state); -// prefix.push_back(tokens[i]); -// } -// } - -// void queryTest(Client& client, KVStateCacheBuilder* builder) { -// std::vector prefix; -// KV_STATE_WITH_LAYER kv_state; - -// for (int i = 0; i < TOKEN_NUM; i++) { -// kv_state = builder->Query(client, prefix, tokens[i]); -// std::vector key_state = kv_state[1].first; -// std::vector value_state = kv_state[1].second; - -// VINEYARD_ASSERT( -// key_state.size() == (size_t) DIMENSION, -// 
"Expected key_state.size() == " + std::to_string(DIMENSION) + -// ", but got + key_state.size() == " + -// std::to_string(key_state.size())); -// VINEYARD_ASSERT( -// value_state.size() == (size_t) DIMENSION, -// "Expected value_state.size() == " + std::to_string(DIMENSION) + -// ", but got + value_state.size() == " + -// std::to_string(value_state.size())); -// for (int j = 0; j < DIMENSION; ++j) { -// VINEYARD_ASSERT(key_state[j] == k_state_list[i][j], -// "Expected key_state[" + std::to_string(j) + -// "] == " + std::to_string(k_state_list[i][j]) + -// ", but got + key_state[" + std::to_string(j) + -// "] == " + std::to_string(key_state[j])); -// VINEYARD_ASSERT(value_state[j] == v_state_list[i][j], -// "Expected value_state[" + std::to_string(j) + -// "] == " + std::to_string(v_state_list[i][j]) + -// ", but got + value_state[" + std::to_string(j) + -// "] == " + std::to_string(value_state[j])); -// } -// prefix.push_back(tokens[i]); -// } -// } - -void sealAndConstructTest(Client& client, KVStateCacheBuilder* builder) { - // ObjectID id = builder->_Seal(client)->id(); - // std::shared_ptr kv_state_cache = - // std::dynamic_pointer_cast(client.GetObject(id)); - // std::vector> kv_state_cache_block_list = - // kv_state_cache->GetKVStateCacheBlockList(); - // std::vector kv_state_cache_block_builder_list = - // builder->GetKVStateCacheBlockBuilderList(); - // for (int i = 0; i < kv_state_cache_block_list.size(); i++) { - // std::shared_ptr kv_state_cache_block = - // kv_state_cache_block_list[i]; - // KVStateCacheBlockBuilder* kv_state_cache_block_builder = - // kv_state_cache_block_builder_list[i]; - - // // compare kv_state_cache_block and kv_state_cache_block_builder - // VINEYARD_ASSERT(kv_state_cache_block->GetDimension() == - // kv_state_cache_block_builder->GetDimension()); - - // VINEYARD_ASSERT(kv_state_cache_block->GetBitmap() == - // kv_state_cache_block_builder->GetBitmap()); - - // LOG(INFO) << "Bitmap:"; - // LOG(INFO) << 
kv_state_cache_block_builder->GetBitmapStr(); - // LOG(INFO) << kv_state_cache_block->GetBitmapStr(); - - // const std::shared_ptr> k_tensor_builder = - // kv_state_cache_block_builder->getKBuilder(); - // const std::shared_ptr> v_tensor_builder = - // kv_state_cache_block_builder->getVBuilder(); - - // std::shared_ptr> k_tensor = - // kv_state_cache_block->GetKTensor(); - // std::shared_ptr> v_tensor = - // kv_state_cache_block->GetVTensor(); - - // for (int i = 0; i < TOKEN_NUM; i++) { - // for (int j = 0; j < DIMENSION; j++) { - // VINEYARD_ASSERT(k_tensor->data()[i * DIMENSION + j] == - // k_tensor_builder->data()[i * DIMENSION + j]); - // VINEYARD_ASSERT(v_tensor->data()[i * DIMENSION + j] == - // v_tensor_builder->data()[i * DIMENSION + j]); - // } - // } - // } -} - -void splitTest(Client& client, KVStateCacheBuilder* builder) {} - -int main() { - // std::string socket = std::string(getenv("VINEYARD_IPC_SOCKET")); - // Client client; - // client.Connect(socket); - - // LOG(INFO) << "Build kv state cache"; - // KVStateCacheBuilder* kv_state_cache_builder = - // new KVStateCacheBuilder(client, DIMENSION, CACHE_CAPACITY); - - // LOG(INFO) << "Prepare data"; - // prepareData(kv_state_cache_builder); - - // LOG(INFO) << "Test update"; - // updateTest(client, kv_state_cache_builder); - - // LOG(INFO) << "Test query"; - // queryTest(client, kv_state_cache_builder); - - // LOG(INFO) << "Test seal and construct"; - // sealAndConstructTest(client, kv_state_cache_builder); - - return 0; -} \ No newline at end of file diff --git a/test/kv_state_cache_radix_tree_test.cc b/test/kv_state_cache_radix_tree_test.cc index 2c06f8c5..76fd7f30 100644 --- a/test/kv_state_cache_radix_tree_test.cc +++ b/test/kv_state_cache_radix_tree_test.cc @@ -22,7 +22,7 @@ limitations under the License. 
#include "common/util/logging.h" #include "kv-state-cache/utils/kv_state_cache_utils.h" -using namespace vineyard; +using namespace vineyard; // NOLINT(build/namespaces) void print_tokens(const std::vector& tokens) { std::string tokens_str = ""; @@ -188,4 +188,4 @@ int main() { LOG(INFO) << "Start to test radix tree split..."; radix_tree_split(); LOG(INFO) << "Finish radix tree split test!"; -} \ No newline at end of file +} diff --git a/test/kv_state_cache_test.cc b/test/kv_state_cache_test.cc index fa59f48b..58c9c721 100644 --- a/test/kv_state_cache_test.cc +++ b/test/kv_state_cache_test.cc @@ -22,19 +22,33 @@ limitations under the License. #include "common/util/logging.h" #include "kv-state-cache/utils/kv_state_cache_utils.h" -using namespace vineyard; - -#define DIMENSION 10 -#define CAPACITY 20 -#define LAYER 3 -#define BLOCK_SIZE 5 - -void init() { InitKVStateCache(DIMENSION, CAPACITY, LAYER, BLOCK_SIZE); } +using namespace vineyard; // NOLINT(build/namespaces) + +int dimension = 10; +int capacity = 20; +int layer = 3; +int block_size = 5; + +std::vector round_1_tokens = { + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, + 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, + 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69}; +std::vector round_2_tokens = {1, 2, 3, 4, 5, 7, 8, 9, 10, 11, 12, 13, 14}; +std::vector round_3_tokens = {1, 2, 3, 9, 10, 11, 12, 13, 14}; +std::vector round_4_tokens = {1, 2, 3, 4, 5, 6}; + +std::vector> tokens_list; + +void init(int dimension, int capacity, int layer, int block_size, + std::string socket) { + InitKVStateCache(dimension, capacity, layer, block_size, socket); +} void print_current_tokens(const std::vector& prefix, int next_token) { std::string tokens_str = ""; for (size_t i = 0; i < prefix.size(); ++i) { - tokens_str += std::to_string(prefix[i]); + tokens_str += std::to_string(prefix[i]) + " "; } 
tokens_str += std::to_string(next_token); LOG(INFO) << "Current tokens: " + tokens_str; @@ -47,7 +61,7 @@ void print_kv_state( for (auto iter = kv_state.begin(); iter != kv_state.end(); ++iter) { std::string key_state_str = ""; std::string value_state_str = ""; - for (int i = 0; i < DIMENSION; ++i) { + for (int i = 0; i < dimension; ++i) { key_state_str += std::to_string(iter->second.first[i]) + " "; value_state_str += std::to_string(iter->second.second[i]) + " "; } @@ -62,13 +76,14 @@ void print_kv_state( std::map, std::vector>> generate_kv_state(int token) { std::map, std::vector>> kv_state; - for (int currentLayer = 0; currentLayer < LAYER; currentLayer++) { + for (int currentLayer = 0; currentLayer < layer; currentLayer++) { std::vector key_state; std::vector value_state; - for (int i = 0; i < DIMENSION; ++i) { - key_state.push_back(((double) token) / DIMENSION * (i + 1) + + for (int i = 0; i < dimension; ++i) { + key_state.push_back((static_cast(token)) / dimension * (i + 1) + currentLayer * 10); - value_state.push_back(((double) token) / DIMENSION * (i + 1) * 2 + + value_state.push_back((static_cast(token)) / dimension * (i + 1) * + 2 + currentLayer * 10); } @@ -78,63 +93,116 @@ generate_kv_state(int token) { return kv_state; } +void check_kv_state( + const std::map, std::vector>>& + kv_state, + int& token) { + VINEYARD_ASSERT(kv_state.size() == (size_t) layer); + for (auto iter = kv_state.begin(); iter != kv_state.end(); ++iter) { + VINEYARD_ASSERT(iter->second.first.size() == (size_t) dimension); + VINEYARD_ASSERT(iter->second.second.size() == (size_t) dimension); + for (int i = 0; i < dimension; ++i) { + if (iter->second.first[i] != + (static_cast(token)) / dimension * (i + 1) + + iter->first * 10) { + LOG(INFO) << "token:" << token << " dimension" << dimension + << " layer:" << iter->first; + LOG(INFO) << "key_state[" << i << "]: " << iter->second.first[i] + << ". 
But is should be " + << (static_cast(token)) / dimension * (i + 1) + + iter->first * 10; + throw std::runtime_error("key_state error!"); + } + if (iter->second.second[i] != + (static_cast(token)) / dimension * (i + 1) * 2 + + iter->first * 10) { + LOG(INFO) << "token:" << token << " dimension" << dimension + << " layer:" << iter->first; + LOG(INFO) << "value_state[" << i << "]: " << iter->second.second[i] + << ". But is should be " + << (static_cast(token)) / dimension * (i + 1) * 2 + + iter->first * 10; + throw std::runtime_error("value_state error!"); + } + } + } +} + void inference(std::vector tokens, bool block = false) { - LOG(INFO) << "inference"; std::vector inference_tokens; std::map, std::vector>> kv_state; for (size_t i = 0; i < tokens.size(); ++i) { kv_state = Query(inference_tokens, tokens[i]); if (kv_state.size() == 0) { - LOG(INFO) << "======================================"; LOG(INFO) << "Can not find the kv_state from cache:"; print_current_tokens(inference_tokens, tokens[i]); LOG(INFO) << "Generate the kv_state and update the cache."; kv_state = generate_kv_state(tokens[i]); print_kv_state(kv_state); Update(inference_tokens, tokens[i], kv_state); - LOG(INFO) << "======================================"; } else { - LOG(INFO) << "--------------------------------------"; LOG(INFO) << "Find the kv_state from cache:"; print_current_tokens(inference_tokens, tokens[i]); - print_kv_state(kv_state); - LOG(INFO) << "--------------------------------------"; + check_kv_state(kv_state, tokens[i]); } + LOG(INFO) << "--------------------------------------"; inference_tokens.push_back(tokens[i]); } } -int main() { - init(); - std::vector round_1_tokens = { - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, - 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, - 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, - 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69}; - std::vector round_2_tokens = {1, 2, 
3, 4, 5, 7, 8, - 9, 10, 11, 12, 13, 14}; - std::vector round_3_tokens = {1, 2, 3, 9, 10, 11, 12, 13, 14}; - std::vector round_4_tokens = {1, 2, 3, 4, 5, 6}; - - inference(round_1_tokens); - inference(round_2_tokens); - sleep(5); +int main(int argc, char** argv) { + if (argc < 2) { + printf("usage ./kv_state_cache_test "); + return 1; + } + std::string ipc_socket = std::string(argv[1]); + + for (int i = 2; i < argc; i++) { + if (strcmp(argv[i], "-d") == 0) { + dimension = atoi(argv[i + 1]); + } else if (strcmp(argv[i], "-c") == 0) { + capacity = atoi(argv[i + 1]); + } else if (strcmp(argv[i], "-l") == 0) { + layer = atoi(argv[i + 1]); + } else if (strcmp(argv[i], "-b") == 0) { + block_size = atoi(argv[i + 1]); + } + if (strcmp(argv[i], "-s") == 0) { + for (int j = i + 1; j < argc; j++) { + if (strcmp(argv[j], "1") == 0) { + tokens_list.push_back(round_1_tokens); + } else if (strcmp(argv[j], "2") == 0) { + tokens_list.push_back(round_2_tokens); + } else if (strcmp(argv[j], "3") == 0) { + tokens_list.push_back(round_3_tokens); + } else if (strcmp(argv[j], "4") == 0) { + tokens_list.push_back(round_4_tokens); + } else { + break; + } + } + } + } + + LOG(INFO) << "Test KVStateCache with dimension: " << dimension + << ", capacity: " << capacity << ", layer: " << layer + << ", block_size: " << block_size << "."; + + init(dimension, capacity, layer, block_size, ipc_socket); + + for (size_t i = 0; i < tokens_list.size(); i++) { + inference(tokens_list[i]); + } - inference(round_1_tokens); - inference(round_2_tokens); - inference(round_3_tokens); sleep(5); - inference(round_3_tokens); - // inference(round_4_tokens); - // sleep(5); - // Delete(std::vector(round_4_tokens.begin(), round_4_tokens.begin() + - // 6)); Delete(std::vector(round_4_tokens.begin(), round_4_tokens.begin() - // + 5)); Delete(std::vector(round_4_tokens.begin(), - // round_4_tokens.begin() + 4)); - // Delete(std::vector(round_4_tokens.begin(), round_4_tokens.begin() + - // 3)); - while (1) - ; + + for 
(size_t i = 0; i < tokens_list.size(); i++) { + inference(tokens_list[i]); + } + + LOG(INFO) << "inference end"; + CloseKVStateCache(); + LOG(INFO) << "Passed KVStateCache tests..."; return 0; -} \ No newline at end of file +} diff --git a/test/kv_state_cache_test_2.cc b/test/kv_state_cache_test_2.cc deleted file mode 100644 index db7ab2ed..00000000 --- a/test/kv_state_cache_test_2.cc +++ /dev/null @@ -1,137 +0,0 @@ -/** Copyright 2020-2023 Alibaba Group Holding Limited. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-*/ - -#include -#include -#include -#include -#include "kv-state-cache/radix-tree/radix.h" - -#include "common/util/logging.h" -#include "kv-state-cache/utils/kv_state_cache_utils.h" - -using namespace vineyard; - -#define DIMENSION 10 -#define CAPACITY 20 -#define LAYER 3 - -void init() { InitKVStateCache(DIMENSION, CAPACITY, LAYER); } - -void print_current_tokens(const std::vector& prefix, int next_token) { - std::string tokens_str = ""; - for (size_t i = 0; i < prefix.size(); ++i) { - tokens_str += std::to_string(prefix[i]); - } - tokens_str += std::to_string(next_token); - LOG(INFO) << "Current tokens: " + tokens_str; -} - -void print_kv_state( - const std::map, std::vector>>& - kv_state) { - LOG(INFO) << "kv_state: "; - for (auto iter = kv_state.begin(); iter != kv_state.end(); ++iter) { - std::string key_state_str = ""; - std::string value_state_str = ""; - for (int i = 0; i < DIMENSION; ++i) { - key_state_str += std::to_string(iter->second.first[i]) + " "; - value_state_str += std::to_string(iter->second.second[i]) + " "; - } - LOG(INFO) << "layer " << iter->first << ":"; - LOG(INFO) << "key_state: " << key_state_str; - LOG(INFO) << "value_state: " << value_state_str; - LOG(INFO) << "---------------------"; - } -} - -// we do not consider the layer. 
-std::map, std::vector>> -generate_kv_state(int token) { - std::map, std::vector>> kv_state; - for (int currentLayer = 0; currentLayer < LAYER; currentLayer++) { - std::vector key_state; - std::vector value_state; - for (int i = 0; i < DIMENSION; ++i) { - key_state.push_back(((double) token) / DIMENSION * (i + 1) + - currentLayer * 10); - value_state.push_back(((double) token) / DIMENSION * (i + 1) * 2 + - currentLayer * 10); - } - - kv_state.insert( - std::make_pair(currentLayer, std::make_pair(key_state, value_state))); - } - return kv_state; -} - -void inference(std::vector tokens, bool block = false) { - LOG(INFO) << "inference"; - std::vector inference_tokens; - std::map, std::vector>> kv_state; - - for (size_t i = 0; i < tokens.size(); ++i) { - kv_state = Query(inference_tokens, tokens[i]); - if (kv_state.size() == 0) { - LOG(INFO) << "======================================"; - LOG(INFO) << "Can not find the kv_state from cache:"; - print_current_tokens(inference_tokens, tokens[i]); - LOG(INFO) << "Generate the kv_state and update the cache."; - kv_state = generate_kv_state(tokens[i]); - print_kv_state(kv_state); - Update(inference_tokens, tokens[i], kv_state); - LOG(INFO) << "======================================"; - } else { - LOG(INFO) << "--------------------------------------"; - LOG(INFO) << "Find the kv_state from cache:"; - print_current_tokens(inference_tokens, tokens[i]); - print_kv_state(kv_state); - LOG(INFO) << "--------------------------------------"; - } - inference_tokens.push_back(tokens[i]); - } -} - -int main() { - init(); - std::vector round_1_tokens = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}; - std::vector round_2_tokens = {1, 2, 3, 4, 5, 7, 8, - 9, 10, 11, 12, 13, 14}; - std::vector round_3_tokens = {1, 2, 3, 9, 10, 11, 12, 13, 14}; - // std::vector round_3_tokens = {1, 2, 3, 4, 5, 6, 7}; - // std::vector round_1_tokens = {1, 2}; - // std::vector round_2_tokens = {1, 3}; - // std::vector round_3_tokens = {1, 3, 4}; - // std::vector 
round_4_tokens = {1, 3, 5}; - // std::vector round_5_tokens = {1, 1}; - inference(round_1_tokens); - // inference(round_1_tokens); - inference(round_3_tokens); - sleep(5); - inference(round_1_tokens); - inference(round_3_tokens); - // inference(round_2_tokens); - - // inference(round_3_tokens); - // inference(round_3_tokens); - // inference(round_4_tokens); - // inference(round_5_tokens); - // sleep(5); - // inference(round_2_tokens); - // inference(round_1_tokens, true); - while (1) - ; - return 0; -} \ No newline at end of file diff --git a/test/rax_diff_test.cc b/test/rax_diff_test.cc deleted file mode 100644 index ce1dcf1b..00000000 --- a/test/rax_diff_test.cc +++ /dev/null @@ -1,101 +0,0 @@ -/** Copyright 2020-2023 Alibaba Group Holding Limited. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -#include -#include -#include -#include "kv-state-cache/radix-tree/radix.h" - -int key_1[] = {1, 2}; -int key_2[] = {1, 3}; -int key_3[] = {1, 4}; -int key_4[] = {1, 3, 1}; -int key_5[] = {1, 3, 2}; - -void insert(rax* rt, int* key, int len) { - for (int i = 1; i <= len; i++) { - raxInsert(rt, key, i, NULL, NULL); - } -} - -int main(int argc, char** argv) { - rax* rt_1 = raxNew(); - rax* rt_2 = raxNew(); - - int max_node = argc > 1 ? 
atoi(argv[1]) : 3; - - // raxInsert(rt_2, key_1, 2, NULL, NULL); - // raxInsert(rt_2, key_2, 2, NULL, NULL); - - // raxInsert(rt_1, key_3, 2, NULL, NULL); - // raxInsert(rt_1, key_4, 3, NULL, NULL); - // raxInsert(rt_1, key_5, 3, NULL, NULL); - - insert(rt_1, key_3, 2); - insert(rt_1, key_4, 3); - - sleep(1); - - insert(rt_2, key_1, 2); - insert(rt_2, key_2, 2); - - sleep(1); - - insert(rt_1, key_5, 3); - - raxShow(rt_1); - printf("==============================\n"); - raxShow(rt_2); - printf("==============================\n"); - - testIteRax(rt_1); - printf("==============================\n"); - testIteRax(rt_2); - printf("==============================\n"); - - std::vector> evicted_tokens; - std::set> insert_tokens; - mergeTree(rt_1, rt_2, evicted_tokens, insert_tokens, max_node); - - printf("evicted_tokens:\n"); - for (size_t i = 0; i < evicted_tokens.size(); i++) { - for (size_t j = 0; j < evicted_tokens[i].size(); j++) { - printf("%d ", evicted_tokens[i][j]); - } - printf("\n"); - } - for (size_t i = 0; i < evicted_tokens.size(); i++) { - // void* tree_data; - raxRemove(rt_1, evicted_tokens[i].data(), evicted_tokens[i].size(), NULL, - false); - } - - for (auto it = insert_tokens.begin(); it != insert_tokens.end(); it++) { - raxInsert(rt_1, const_cast(it->data()), it->size(), NULL, NULL, - false); - } - - raxShow(rt_1); - printf("==============================\n"); - raxShow(rt_2); - printf("==============================\n"); - - testIteRax(rt_1); - printf("==============================\n"); - testIteRax(rt_2); - printf("==============================\n"); - - return 0; -} \ No newline at end of file diff --git a/test/runner.py b/test/runner.py index 94c65826..f0a7a7ee 100755 --- a/test/runner.py +++ b/test/runner.py @@ -474,6 +474,7 @@ def run_vineyard_cpp_tests(meta, allocator, endpoints, tests): run_test(tests, 'tensor_test') run_test(tests, 'typename_test') run_test(tests, 'version_test') + run_test(tests, 'kv_state_cache_radix_tree_test') def 
run_vineyard_spill_tests(meta, allocator, endpoints, tests): @@ -690,6 +691,35 @@ def run_scale_in_out_tests(meta, allocator, endpoints, instance_size=4): time.sleep(5) +def run_llm_tests(meta, allocator, endpoints): + meta_prefix = 'vineyard_test_%s' % time.time() + metadata_settings = make_metadata_settings(meta, endpoints, meta_prefix) + + instance_size = 2 + with start_multiple_vineyardd( + metadata_settings, + ['--allocator', allocator], + default_ipc_socket=VINEYARD_CI_IPC_SOCKET, + instance_size=instance_size, + nowait=False, + ) as instances: # noqa: F841, pylint: disable=unused-variable + vineyard_ipc_socket_1 = '%s.%d' % (VINEYARD_CI_IPC_SOCKET, 0) + vineyard_ipc_socket_2 = '%s.%d' % (VINEYARD_CI_IPC_SOCKET, 1) + + rpc_socket_port = instances[0][1] + subprocess.check_call( + [ + './build/bin/kv_state_cache_multi_test', + '--vineyard-endpoint', + 'localhost:%s' % rpc_socket_port, + '--vineyard-ipc-sockets', + vineyard_ipc_socket_1, + vineyard_ipc_socket_2, + ], + cwd=os.path.join(os.path.dirname(os.path.abspath(__file__)), '..'), + ) + + def run_python_deploy_tests(meta, allocator, endpoints, test_args, with_migration): meta_prefix = 'vineyard_test_%s' % time.time() metadata_settings = make_metadata_settings(meta, endpoints, meta_prefix) @@ -862,6 +892,12 @@ def parse_sys_args(): default=False, help='Whether to run deployment and scaling in/out tests', ) + arg_parser.add_argument( + '--with-llm', + action='store_true', + default=False, + help='Whether to run llm tests', + ) arg_parser.add_argument( '--with-migration', action='store_true', @@ -972,6 +1008,14 @@ def execute_tests(args): with start_metadata_engine(args.meta) as (_, endpoints): run_fuse_test(args.meta, args.allocator, endpoints, python_test_args) + if args.with_llm: + with start_metadata_engine(args.meta) as (_, endpoints): + run_llm_tests( + args.meta, + args.allocator, + endpoints, + ) + def main(): parser, args = parse_sys_args() @@ -987,6 +1031,7 @@ def main(): or args.with_deployment or 
args.with_io or args.with_fuse + or args.with_llm ): print( 'Error: \n\tat least one of of --with-{cpp,graph,python,io,fuse} needs ' From 44d7c0085d76ee6bd95b5f27dcdd1963e6e98ddb Mon Sep 17 00:00:00 2001 From: Ye Cao Date: Wed, 28 Feb 2024 11:59:23 +0800 Subject: [PATCH 13/20] Disable code formatting of radix tree and fix a bug of serialization and deserialization (#1771) Signed-off-by: Ye Cao --- CMakeLists.txt | 2 +- .../kv-state-cache/radix-tree/radix-tree.cc | 110 ++++++++++-------- .../kv-state-cache/radix-tree/radix-tree.h | 19 +-- modules/kv-state-cache/radix-tree/radix.h | 69 ++++++----- .../kv-state-cache/radix-tree/rax_malloc.h | 7 +- 5 files changed, 118 insertions(+), 89 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 951376cd..10998b31 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -981,7 +981,7 @@ endfunction() file_glob_recurse(FILES_NEED_FORMAT DIRECTORIES "src" "modules" "python" "test" "benchmark" PATTERNS ".*\\.(cc|cpp|h|hpp|vineyard-mod)$" - EXCLUDE_PATTERNS "(.*\\.vineyard.h$)|(modules/kv-state-cache/radix-tree/ra.*)" + EXCLUDE_PATTERNS "(.*\\.vineyard.h$)|(.*modules/kv-state-cache/radix-tree/radix\.(cc|h)$)" ) # the `memcpy.h` is borrowed from external project diff --git a/modules/kv-state-cache/radix-tree/radix-tree.cc b/modules/kv-state-cache/radix-tree/radix-tree.cc index c528f44c..45b8f1e8 100644 --- a/modules/kv-state-cache/radix-tree/radix-tree.cc +++ b/modules/kv-state-cache/radix-tree/radix-tree.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -#include "radix-tree.h" +#include "kv-state-cache/radix-tree/radix-tree.h" #include "common/util/base64.h" #include "common/util/logging.h" @@ -21,7 +21,7 @@ limitations under the License. 
#include "zstd/lib/zstd.h" -using namespace vineyard; +using namespace vineyard; // NOLINT(build/namespaces) RadixTree::RadixTree(int cacheCapacity) { this->tree = raxNew(); @@ -52,8 +52,8 @@ RadixTree::~RadixTree() { raxNode* dataNode = raxFindAndReturnDataNode(this->tree, rootToken.data(), rootToken.size(), NULL, false); if (dataNode != nullptr) { - delete (DataWrapper*) dataNode->custom_data; - delete (DataWrapper*) raxGetData(dataNode); + delete reinterpret_cast(dataNode->custom_data); + delete reinterpret_cast(raxGetData(dataNode)); } raxFree(this->tree); @@ -99,7 +99,7 @@ std::shared_ptr RadixTree::InsertInternal( raxNode* dataNode = NULL; int retval = raxInsertAndReturnDataNode( this->tree, insertTokensArray, insertTokensArrayLen, dummyData, - (void**) &dataNode, (void**) &oldData); + reinterpret_cast(&dataNode), reinterpret_cast(&oldData)); if (dataNode == NULL) { throw std::runtime_error("Insert token list failed"); return NULL; @@ -140,8 +140,8 @@ std::shared_ptr RadixTree::InsertInternal( if (subTreeNode == nullptr) { return std::make_shared(dummyData, nullptr); } - return std::make_shared(dummyData, - (DataWrapper*) subTreeNode->custom_data); + return std::make_shared( + dummyData, reinterpret_cast(subTreeNode->custom_data)); } void RadixTree::DeleteInternal(std::vector tokens, @@ -162,10 +162,10 @@ void RadixTree::DeleteInternal(std::vector tokens, nodeIsSubTree = true; } int retval = raxRemove(this->tree, deleteTokensArray, deleteTokensArrayLen, - (void**) &oldData); + reinterpret_cast(&oldData)); if (retval == 1) { evictedNode = std::make_shared( - oldData, (DataWrapper*) subTreeNode->custom_data); + oldData, reinterpret_cast(subTreeNode->custom_data)); nodeCount--; if (nodeIsSubTree) { evictedNode->cleanTreeData = true; @@ -193,8 +193,9 @@ std::shared_ptr RadixTree::QueryInternal(std::vector key) { return NULL; } - return std::make_shared((DataWrapper*) raxGetData(dataNode), - (DataWrapper*) subTreeNode->custom_data); + return std::make_shared( + 
reinterpret_cast(raxGetData(dataNode)), + reinterpret_cast(subTreeNode->custom_data)); } std::string RadixTree::Serialize() { @@ -241,10 +242,13 @@ std::string RadixTree::Serialize() { serializedStr += subTreeSizeOSS.str() + "|"; // convert data to hex string - char* bytes = (char*) ((DataWrapper*) dataList[index])->data; + char* bytes = reinterpret_cast( + (reinterpret_cast(dataList[index]))->data); std::ostringstream dataOSS; - for (int i = 0; i < ((DataWrapper*) dataList[index])->dataLength; i++) { + for (int i = 0; + i < (reinterpret_cast(dataList[index]))->dataLength; + i++) { dataOSS << std::hex << std::setw(2) << std::setfill('0') << static_cast(static_cast(bytes[i])); } @@ -264,17 +268,22 @@ std::string RadixTree::Serialize() { serializedStr += "|"; // convert custom data to hex string - char* bytes = (char*) ((DataWrapper*) subTreeDataList[index])->data; + char* bytes = reinterpret_cast( + (reinterpret_cast(subTreeDataList[index]))->data); std::ostringstream dataOSS; - LOG(INFO) << "data length:" - << ((DataWrapper*) subTreeDataList[index])->dataLength; - for (int i = 0; i < ((DataWrapper*) subTreeDataList[index])->dataLength; + LOG(INFO) + << "data length:" + << (reinterpret_cast(subTreeDataList[index]))->dataLength; + for (int i = 0; + i < + (reinterpret_cast(subTreeDataList[index]))->dataLength; ++i) { dataOSS << std::hex << std::setw(2) << std::setfill('0') << static_cast(static_cast(bytes[i])); } - LOG(INFO) << "data:" << ((DataWrapper*) subTreeDataList[index])->data; + LOG(INFO) << "data:" + << (reinterpret_cast(subTreeDataList[index]))->data; LOG(INFO) << "data oss:" << dataOSS.str(); serializedStr += dataOSS.str() + "\n"; } @@ -282,10 +291,11 @@ std::string RadixTree::Serialize() { // use ZSTD to compress the serialized string size_t srcSize = serializedStr.size(); - std::string compressedStr(srcSize, '\0'); - int compressedSize = - ZSTD_compress((void*) (compressedStr.c_str()), compressedStr.length(), - serializedStr.c_str(), srcSize, 3); + size_t 
dstSize = ZSTD_compressBound(srcSize); + std::string compressedStr(dstSize + 1, '\0'); + LOG(INFO) << "src size:" << srcSize << " dst size:" << dstSize; + int compressedSize = ZSTD_compress(compressedStr.data(), compressedStr.size(), + serializedStr.c_str(), srcSize, 3); if (ZSTD_isError(compressedSize)) { LOG(ERROR) << "ZSTD compression failed: " << ZSTD_getErrorName(compressedSize); @@ -293,9 +303,10 @@ std::string RadixTree::Serialize() { int cacheCapacity = this->cacheCapacity - 1; std::string result = - std::string((char*) &srcSize, sizeof(int)) + - std::string((char*) &cacheCapacity, sizeof(int)) + - std::string((char*) &(this->tree->head->numnodes), sizeof(uint32_t)) + + std::string(reinterpret_cast(&compressedSize), sizeof(int)) + + std::string(reinterpret_cast(&cacheCapacity), sizeof(int)) + + std::string(reinterpret_cast(&(this->tree->head->numnodes)), + sizeof(uint32_t)) + compressedStr; return result; @@ -304,16 +315,23 @@ std::string RadixTree::Serialize() { std::shared_ptr RadixTree::Deserialize(std::string data) { LOG(INFO) << "Deserialize......"; // use LZ4 to decompress the serialized string - int srcSize = *(int*) data.c_str(); + int compressedSize = *reinterpret_cast(data.data()); data.erase(0, sizeof(int)); - int cacheCapacity = *(int*) data.c_str(); + int cacheCapacity = *reinterpret_cast(data.data()); data.erase(0, sizeof(int)); - int rootNumNodes = *(uint32_t*) data.c_str(); + int rootNumNodes = *reinterpret_cast(data.data()); data.erase(0, sizeof(uint32_t)); - std::string decompressedStr(srcSize, '\0'); + int ds = ZSTD_getFrameContentSize(data.c_str(), data.size()); + if (ds == ZSTD_CONTENTSIZE_ERROR) { + LOG(ERROR) << "Error: not a valid compressed frame"; + } else if (ds == ZSTD_CONTENTSIZE_UNKNOWN) { + LOG(ERROR) + << "Error: original size unknown. 
Use streaming decompression instead."; + } + + std::string decompressedStr(ds + 1, '\0'); int decompressedSize = - ZSTD_decompress((void*) (decompressedStr.c_str()), decompressedStr.size(), - data.c_str(), srcSize); + ZSTD_decompress(decompressedStr.data(), ds, data.c_str(), compressedSize); if (ZSTD_isError(decompressedSize)) { LOG(ERROR) << "ZSTD decompression failed: " << ZSTD_getErrorName(decompressedSize); @@ -338,7 +356,6 @@ std::shared_ptr RadixTree::Deserialize(std::string data) { line.pop_back(); continue; } - LOG(INFO) << "data line:" << line << std::endl; std::istringstream lineStream(line); std::string tokenListPart, timestampPart, dataPart, subTreeSizePart; @@ -357,7 +374,7 @@ std::shared_ptr RadixTree::Deserialize(std::string data) { } } if (!std::getline(lineStream, dataPart)) { - LOG(INFO) << "data length is 0"; + LOG(ERROR) << "data length is 0"; } std::istringstream keyStream(tokenListPart); @@ -371,17 +388,14 @@ std::shared_ptr RadixTree::Deserialize(std::string data) { if (isMainTree) { std::istringstream timestampStream(timestampPart); if (!(timestampStream >> std::hex >> timestamp)) { - LOG(INFO) << "Invalid timestamp format."; - throw std::runtime_error("Invalid timestamp format."); + LOG(ERROR) << "Invalid timestamp format."; } std::istringstream subTreeSizeStream(subTreeSizePart); uint32_t subTreeSize; if (!(subTreeSizeStream >> std::hex >> subTreeSize)) { - LOG(INFO) << "Invalid sub tree size format."; - throw std::runtime_error("Invalid sub tree size format."); + LOG(ERROR) << "Invalid sub tree size format."; } - LOG(INFO) << "Deserialize sub tree size:" << subTreeSize; subTreeSizeList.push_back(subTreeSize); } @@ -455,8 +469,8 @@ std::shared_ptr RadixTree::Deserialize(std::string data) { // TBD // check retval raxInsertAndReturnDataNode(radixTree->tree, insertTokensArray, - insertTokensArrayLen, data, (void**) &dataNode, - NULL); + insertTokensArrayLen, data, + reinterpret_cast(&dataNode), NULL); if (dataNode == NULL) { throw 
std::runtime_error("Insert token list failed"); @@ -517,7 +531,7 @@ std::vector> RadixTree::SplitInternal( treeData->dataLength = 0; subTreeRootNode->custom_data = treeData; header = std::make_shared( - (DataWrapper*) raxGetData(subTreeRootNode), treeData); + reinterpret_cast(raxGetData(subTreeRootNode)), treeData); return TraverseTreeWithoutSubTree(subTreeRootNode); } @@ -536,8 +550,8 @@ std::vector> RadixTree::TraverseTreeWithoutSubTree( LOG(INFO) << "data node list:" << dataNodeList.size(); for (size_t i = 0; i < dataNodeList.size(); i++) { nodes.push_back(std::make_shared( - (DataWrapper*) raxGetData(dataNodeList[i]), - (DataWrapper*) dataNodeList[i]->custom_data)); + reinterpret_cast(raxGetData(dataNodeList[i])), + reinterpret_cast(dataNodeList[i]->custom_data))); } return nodes; } @@ -555,8 +569,9 @@ void RadixTree::ClearSubtreeData(void* data) { std::shared_ptr RadixTree::GetRootNode() { raxNode* node = raxFindAndReturnDataNode(this->tree, rootToken.data(), rootToken.size(), NULL); - return std::make_shared((DataWrapper*) raxGetData(node), - (DataWrapper*) node->custom_data); + return std::make_shared( + reinterpret_cast(raxGetData(node)), + reinterpret_cast(node->custom_data)); } void RadixTree::MergeTree(std::shared_ptr tree_1, @@ -591,7 +606,8 @@ std::set RadixTree::GetAllNodeData() { if (node->isnull) { continue; } - nodeDataSet.insert(((DataWrapper*) raxGetData(node))->data); + nodeDataSet.insert( + (reinterpret_cast(raxGetData(node)))->data); } return nodeDataSet; -} \ No newline at end of file +} diff --git a/modules/kv-state-cache/radix-tree/radix-tree.h b/modules/kv-state-cache/radix-tree/radix-tree.h index 8a2d5a95..23774ed0 100644 --- a/modules/kv-state-cache/radix-tree/radix-tree.h +++ b/modules/kv-state-cache/radix-tree/radix-tree.h @@ -13,21 +13,22 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ -#ifndef RADIX_TREE_H -#define RADIX_TREE_H +#ifndef MODULES_KV_STATE_CACHE_RADIX_TREE_RADIX_TREE_H_ +#define MODULES_KV_STATE_CACHE_RADIX_TREE_RADIX_TREE_H_ -#include "radix.h" - -#include "common/util/base64.h" -#include "common/util/logging.h" +#include "kv-state-cache/radix-tree/radix.h" #include #include #include #include +#include #include -using namespace vineyard; +#include "common/util/base64.h" +#include "common/util/logging.h" + +using namespace vineyard; // NOLINT(build/namespaces) struct DataWrapper { void* data; @@ -75,7 +76,7 @@ class RadixTree : public std::enable_shared_from_this { std::vector tokens, std::shared_ptr& header); public: - RadixTree(int cacheCapacity); + RadixTree(int cacheCapacity); // NOLINT(runtime/explicit) ~RadixTree(); @@ -117,4 +118,4 @@ class RadixTree : public std::enable_shared_from_this { std::set GetAllNodeData(); }; -#endif +#endif // MODULES_KV_STATE_CACHE_RADIX_TREE_RADIX_TREE_H_" diff --git a/modules/kv-state-cache/radix-tree/radix.h b/modules/kv-state-cache/radix-tree/radix.h index 57da727c..cbfebf2b 100644 --- a/modules/kv-state-cache/radix-tree/radix.h +++ b/modules/kv-state-cache/radix-tree/radix.h @@ -33,13 +33,13 @@ #include #include -#include -#include -#include -#include #include +#include +#include +#include #include #include +#include /* Representation of a radix tree as implemented in this file, that contains * the token lists [1, 2, 3], [1, 2, 3, 4, 5, 6] and [1, 2, 3, 6, 7, 8] after @@ -110,9 +110,9 @@ typedef struct raxNode { uint32_t size : 26; /* Number of children, or compressed string len. */ uint32_t numnodes; /* Number of the child nodes */ uint32_t numele; /* Number of elements inside this node. 
*/ - uint64_t timestamp; /* Timestamps of the node */ + uint64_t timestamp; /* Timestamps of the node */ uint32_t sub_tree_size; /* Number of nodes in the sub tree */ - void *custom_data; + void* custom_data; /* Data layout is as follows: * * If node is not compressed we have 'size' bytes, one for each children @@ -196,9 +196,10 @@ typedef struct raxIterator { size_t key_len; /* Current key length. */ size_t key_max; /* Max key len the current key buffer can hold. */ int key_static_tokens[RAX_ITER_STATIC_LEN]; - bool add_to_subtree_list; /* Whether to add the current node to the subtree list. */ - std::vector> *subtree_list; /* List of subtrees. */ - std::vector *subtree_data_list; /* List of subtrees' data. */ + bool add_to_subtree_list; /* Whether to add the current node to the subtree + list. */ + std::vector>* subtree_list; /* List of subtrees. */ + std::vector* subtree_data_list; /* List of subtrees' data. */ raxNode* node; /* Current node. Only for unsafe iteration. */ raxStack stack; /* Stack used for unsafe iteration. */ raxNodeCallback node_cb; /* Optional node callback. Normally set to NULL. */ @@ -209,19 +210,24 @@ extern void* raxNotFound; /* Exported API. 
*/ rax* raxNew(void); -int raxInsert(rax *rax, int *s, size_t len, void *data, void **old, bool set_timestamp = true); +int raxInsert(rax* rax, int* s, size_t len, void* data, void** old, + bool set_timestamp = true); int raxTryInsert(rax* rax, int* s, size_t len, void* data, void** old); int raxInsertAndReturnDataNode(rax* rax, int* s, size_t len, void* data, void** node, void** old); -int raxRemove(rax* rax, int* s, size_t len, void** old, bool set_timestamp = true); +int raxRemove(rax* rax, int* s, size_t len, void** old, + bool set_timestamp = true); void* raxFind(rax* rax, int* s, size_t len); -raxNode* raxFindAndReturnDataNode(rax* rax, int* s, size_t len, raxNode** sub_tree_node = NULL, bool set_timestamp = true); -void raxSetSubtree(raxNode *n); -void raxSetSubtreeAllocated(raxNode *node); -void raxSetSubtreeNotNull(raxNode *node); -int raxFindNodeWithParent(rax *rax, int *s, size_t len, void **node, void **parent); +raxNode* raxFindAndReturnDataNode(rax* rax, int* s, size_t len, + raxNode** sub_tree_node = NULL, + bool set_timestamp = true); +void raxSetSubtree(raxNode* n); +void raxSetSubtreeAllocated(raxNode* node); +void raxSetSubtreeNotNull(raxNode* node); +int raxFindNodeWithParent(rax* rax, int* s, size_t len, void** node, + void** parent); void raxFree(rax* rax); -void raxFreeWithCallback(rax *rax, void (*free_callback)(raxNode *)); +void raxFreeWithCallback(rax* rax, void (*free_callback)(raxNode*)); void raxStart(raxIterator* it, rax* rt); int raxSeek(raxIterator* it, const char* op, int* ele, size_t len); int raxNext(raxIterator* it); @@ -232,25 +238,30 @@ void raxStop(raxIterator* it); int raxEOF(raxIterator* it); void raxShow(rax* rax); uint64_t raxSize(rax* rax); -void raxSetCustomData(raxNode *n, void *data); -void *raxGetCustomData(raxNode *n); +void raxSetCustomData(raxNode* n, void* data); +void* raxGetCustomData(raxNode* n); unsigned long raxTouch(raxNode* n); void raxSetDebugMsg(int onoff); void raxTraverse(raxNode* rax, std::vector>& 
dataNodeList); -void raxTraverseSubTree(raxNode* n, std::vector &dataNodeList); -raxNode *raxSplit(rax *rax, int *s, size_t len, std::vector& key); -void raxSerialize(rax* root, std::vector>& tokenList, std::vector& dataList, std::vector ×tampsList, - std::vector> *subtreeList, std::vector *subtreeNodeList); +void raxTraverseSubTree(raxNode* n, std::vector& dataNodeList); +raxNode* raxSplit(rax* rax, int* s, size_t len, std::vector& key); +void raxSerialize(rax* root, std::vector>& tokenList, + std::vector& dataList, + std::vector& timestampsList, + std::vector>* subtreeList, + std::vector* subtreeNodeList); /* Internal API. May be used by the node callback in order to access rax nodes * in a low level way, so this function is exported as well. */ -void raxSetData(raxNode *n, void *data); -void *raxGetData(raxNode *n); -int raxFindNode(rax *rax, int *s, size_t len, void **node); -void raxFindLastRecentNode(raxNode *node, std::vector& key); -void mergeTree(rax* first_tree, rax* second_tree, std::vector>& evicted_tokens, std::set>& insert_tokens, int max_node); -void testIteRax(rax *tree); +void raxSetData(raxNode* n, void* data); +void* raxGetData(raxNode* n); +int raxFindNode(rax* rax, int* s, size_t len, void** node); +void raxFindLastRecentNode(raxNode* node, std::vector& key); +void mergeTree(rax* first_tree, rax* second_tree, + std::vector>& evicted_tokens, + std::set>& insert_tokens, int max_node); +void testIteRax(rax* tree); raxNode* raxGetFirstChildPtr(raxNode* node); // raxNode *raxSetSubTreeAndReturnDataNode(rax *rax, int *s, size_t len); #endif diff --git a/modules/kv-state-cache/radix-tree/rax_malloc.h b/modules/kv-state-cache/radix-tree/rax_malloc.h index e9d5d5d7..c62af522 100644 --- a/modules/kv-state-cache/radix-tree/rax_malloc.h +++ b/modules/kv-state-cache/radix-tree/rax_malloc.h @@ -35,9 +35,10 @@ * the include of your alternate allocator if needed (not needed in order * to use the default libc allocator). 
*/ -#ifndef RAX_ALLOC_H -#define RAX_ALLOC_H +#ifndef MODULES_KV_STATE_CACHE_RADIX_TREE_RAX_MALLOC_H_ +#define MODULES_KV_STATE_CACHE_RADIX_TREE_RAX_MALLOC_H_ #define rax_malloc malloc #define rax_realloc realloc #define rax_free free -#endif + +#endif // MODULES_KV_STATE_CACHE_RADIX_TREE_RAX_MALLOC_H_ From 9f6a64ed7c12008114917a1a7ce1f5bbf78dc268 Mon Sep 17 00:00:00 2001 From: Ye Cao Date: Wed, 28 Feb 2024 13:56:15 +0800 Subject: [PATCH 14/20] Add a benchmark test for kv state cache (#1774) Fixes #1770 Signed-off-by: Ye Cao --- modules/kv-state-cache/ds/kv_state_cache.cc | 2 +- .../kv-state-cache/radix-tree/radix-tree.cc | 84 +++++---- modules/kv-state-cache/radix-tree/radix.cc | 13 +- .../utils/kv_state_cache_utils.cc | 1 - test/kv_state_cache_benchmark_test.cc | 160 ++++++++++++++++++ 5 files changed, 208 insertions(+), 52 deletions(-) create mode 100644 test/kv_state_cache_benchmark_test.cc diff --git a/modules/kv-state-cache/ds/kv_state_cache.cc b/modules/kv-state-cache/ds/kv_state_cache.cc index 98792c03..d6f20f66 100644 --- a/modules/kv-state-cache/ds/kv_state_cache.cc +++ b/modules/kv-state-cache/ds/kv_state_cache.cc @@ -150,7 +150,7 @@ void KVStateCacheBuilder::Update(Client& client, std::shared_ptr nodeData = this->rootTree->Insert(tokenListCopy, evictedNodeData); if (nodeData == nullptr) { - LOG(INFO) << "insert failed"; + LOG(ERROR) << "insert failed"; return; } KVStateCacheBlockBuilder* kvStateCacheBlockBuilder = diff --git a/modules/kv-state-cache/radix-tree/radix-tree.cc b/modules/kv-state-cache/radix-tree/radix-tree.cc index 45b8f1e8..da549557 100644 --- a/modules/kv-state-cache/radix-tree/radix-tree.cc +++ b/modules/kv-state-cache/radix-tree/radix-tree.cc @@ -33,21 +33,21 @@ RadixTree::RadixTree(int cacheCapacity) { std::vector rootToken = {INT32_MAX}; std::shared_ptr evictedNode; this->InsertInternal(rootToken, evictedNode); - raxShow(this->tree); + // raxShow(this->tree); raxNode* dataNode = raxFindAndReturnDataNode(this->tree, rootToken.data(), 
rootToken.size(), NULL, false); DataWrapper* data = new DataWrapper(); data->data = nullptr; data->dataLength = 0; dataNode->custom_data = data; - LOG(INFO) << "root data wrapper:" << data; + VLOG(100) << "root data wrapper:" << data; dataNode->issubtree = true; this->rootToken = rootToken; } RadixTree::~RadixTree() { - LOG(INFO) << "~RadixTree"; - raxShow(this->tree); + VLOG(100) << "~RadixTree"; + // raxShow(this->tree); raxNode* dataNode = raxFindAndReturnDataNode(this->tree, rootToken.data(), rootToken.size(), NULL, false); @@ -105,14 +105,14 @@ std::shared_ptr RadixTree::InsertInternal( return NULL; } if (retval == 1) { - LOG(INFO) << "node count++:" << this->nodeCount; + VLOG(100) << "node count++:" << this->nodeCount; nodeCount++; } - raxShow(this->tree); + // raxShow(this->tree); if (this->nodeCount > this->cacheCapacity) { - LOG(INFO) << "cache capacity is full, evict the last recent node"; - LOG(INFO) << "cache capacity:" << this->cacheCapacity + VLOG(100) << "cache capacity is full, evict the last recent node"; + VLOG(100) << "cache capacity:" << this->cacheCapacity << " node count:" << this->nodeCount; // evict the last recent node (the node with the largest lru index- std::vector evictedTokensVector; @@ -127,13 +127,12 @@ std::shared_ptr RadixTree::InsertInternal( raxNode* subTreeNode = nullptr; dataNode = raxFindAndReturnDataNode( this->tree, insertTokensArray, insertTokensArrayLen, &subTreeNode, false); - LOG(INFO) << "sub tree node:" << subTreeNode << " data node:" << dataNode; + VLOG(100) << "sub tree node:" << subTreeNode << " data node:" << dataNode; /** * if the data node is null, it means the evicted node is the same node as * the inserted node. 
*/ if (dataNode == NULL) { - LOG(INFO) << "get failed"; return NULL; } @@ -171,12 +170,12 @@ void RadixTree::DeleteInternal(std::vector tokens, evictedNode->cleanTreeData = true; } } else { - LOG(INFO) << "remove failed"; + LOG(ERROR) << "remove failed"; } } std::shared_ptr RadixTree::QueryInternal(std::vector key) { - LOG(INFO) << "Query"; + VLOG(100) << "Query"; int* tokens = key.data(); size_t tokensLen = key.size(); @@ -187,9 +186,8 @@ std::shared_ptr RadixTree::QueryInternal(std::vector key) { raxNode* subTreeNode; raxNode* dataNode = raxFindAndReturnDataNode(this->tree, tokens, tokensLen, &subTreeNode); - LOG(INFO) << "query subtree node:" << subTreeNode; + VLOG(100) << "query subtree node:" << subTreeNode; if (dataNode == NULL) { - LOG(INFO) << "get failed"; return NULL; } @@ -199,8 +197,8 @@ std::shared_ptr RadixTree::QueryInternal(std::vector key) { } std::string RadixTree::Serialize() { - LOG(INFO) << "Serialize......"; - raxShow(this->tree); + VLOG(100) << "Serialize......"; + // raxShow(this->tree); std::vector> tokenList; std::vector dataList; std::vector timestampList; @@ -209,7 +207,7 @@ std::string RadixTree::Serialize() { raxSerialize(this->tree, tokenList, dataList, timestampList, &subTreeTokenList, &subTreeDataList); - raxShow(this->tree); + // raxShow(this->tree); std::string serializedStr; if (tokenList.size() != dataList.size()) { @@ -257,7 +255,7 @@ std::string RadixTree::Serialize() { serializedStr += "\t\n"; - LOG(INFO) << "sub tree token list size:" << subTreeTokenList.size(); + VLOG(100) << "sub tree token list size:" << subTreeTokenList.size(); for (size_t index = 0; index < subTreeTokenList.size(); index++) { for (size_t j = 0; j < subTreeTokenList[index].size(); j++) { serializedStr += std::to_string(subTreeTokenList[index][j]); @@ -272,7 +270,7 @@ std::string RadixTree::Serialize() { (reinterpret_cast(subTreeDataList[index]))->data); std::ostringstream dataOSS; - LOG(INFO) + VLOG(100) << "data length:" << 
(reinterpret_cast(subTreeDataList[index]))->dataLength; for (int i = 0; @@ -282,12 +280,12 @@ std::string RadixTree::Serialize() { dataOSS << std::hex << std::setw(2) << std::setfill('0') << static_cast(static_cast(bytes[i])); } - LOG(INFO) << "data:" + VLOG(100) << "data:" << (reinterpret_cast(subTreeDataList[index]))->data; - LOG(INFO) << "data oss:" << dataOSS.str(); + VLOG(100) << "data oss:" << dataOSS.str(); serializedStr += dataOSS.str() + "\n"; } - LOG(INFO) << "serializedStr:" << serializedStr; + VLOG(100) << "serializedStr:" << serializedStr; // use ZSTD to compress the serialized string size_t srcSize = serializedStr.size(); @@ -313,7 +311,7 @@ std::string RadixTree::Serialize() { } std::shared_ptr RadixTree::Deserialize(std::string data) { - LOG(INFO) << "Deserialize......"; + VLOG(100) << "Deserialize......"; // use LZ4 to decompress the serialized string int compressedSize = *reinterpret_cast(data.data()); data.erase(0, sizeof(int)); @@ -356,6 +354,7 @@ std::shared_ptr RadixTree::Deserialize(std::string data) { line.pop_back(); continue; } + VLOG(100) << "data line:" << line << std::endl; std::istringstream lineStream(line); std::string tokenListPart, timestampPart, dataPart, subTreeSizePart; @@ -396,6 +395,7 @@ std::shared_ptr RadixTree::Deserialize(std::string data) { if (!(subTreeSizeStream >> std::hex >> subTreeSize)) { LOG(ERROR) << "Invalid sub tree size format."; } + VLOG(100) << "Deserialize sub tree size:" << subTreeSize; subTreeSizeList.push_back(subTreeSize); } @@ -410,7 +410,7 @@ std::shared_ptr RadixTree::Deserialize(std::string data) { // is created by upper layer. Here just recover it from serialized // string. 
char* data = nullptr; - LOG(INFO) << "data size:" << dataSize; + VLOG(100) << "data size:" << dataSize; if (dataSize != 0) { data = new char[dataSize]; std::istringstream dataStream(dataPart); @@ -420,16 +420,14 @@ std::shared_ptr RadixTree::Deserialize(std::string data) { // Read two characters for one byte if (!dataStream.read(hex, 2)) { delete[] data; - LOG(INFO) << "Invalid data format."; - throw std::runtime_error("Invalid data format."); + LOG(ERROR) << "Invalid data format."; } // Convert the two hex characters to one byte unsigned int byte; std::istringstream hexStream(hex); if (!(hexStream >> std::hex >> byte)) { delete[] data; - LOG(INFO) << "Invalid data format."; - throw std::runtime_error("Invalid data format."); + LOG(ERROR) << "Invalid data format."; } reinterpret_cast(data)[i] = static_cast(byte); @@ -452,13 +450,13 @@ std::shared_ptr RadixTree::Deserialize(std::string data) { std::make_shared(cacheCapacity); radixTree->nodeCount = tokenList.size(); - raxShow(radixTree->tree); + // raxShow(radixTree->tree); for (size_t i = 0; i < tokenList.size(); i++) { std::string token_str = ""; for (size_t j = 0; j < tokenList[i].size(); j++) { token_str += std::to_string(tokenList[i][j]); } - LOG(INFO) << "token:" << token_str; + VLOG(100) << "token:" << token_str; int* insertTokensArray = tokenList[i].data(); size_t insertTokensArrayLen = tokenList[i].size(); DataWrapper* data = new DataWrapper(); @@ -481,40 +479,40 @@ std::shared_ptr RadixTree::Deserialize(std::string data) { for (size_t i = 0; i < tokenList.size(); i++) { raxNode* node = raxFindAndReturnDataNode( radixTree->tree, tokenList[i].data(), tokenList[i].size(), NULL, false); - LOG(INFO) << "node:" << node << " sub tree node num:" << subTreeSizeList[i]; + VLOG(100) << "node:" << node << " sub tree node num:" << subTreeSizeList[i]; node->numnodes = subTreeSizeList[i]; } radixTree->tree->head->numnodes = rootNumNodes; - raxShow(radixTree->tree); + // raxShow(radixTree->tree); - LOG(INFO) << "start to 
insert sub tree token list" << std::endl; + VLOG(100) << "start to insert sub tree token list" << std::endl; for (size_t i = 0; i < subTreeTokenList.size(); i++) { for (size_t j = 0; j < subTreeTokenList[i].size(); j++) { - LOG(INFO) << subTreeTokenList[i][j]; + VLOG(100) << subTreeTokenList[i][j]; } raxNode* node = nullptr; - LOG(INFO) << "stage 1"; + VLOG(100) << "stage 1"; VINEYARD_ASSERT(radixTree->tree != nullptr); // TBD refator this code. node = raxFindAndReturnDataNode(radixTree->tree, subTreeTokenList[i].data(), subTreeTokenList[i].size(), NULL, false); VINEYARD_ASSERT(node != nullptr); - LOG(INFO) << "stage 2"; + VLOG(100) << "stage 2"; DataWrapper* data = new DataWrapper(); data->data = subTreeDataList[i]; - LOG(INFO) << subTreeDataList[i]; + VLOG(100) << subTreeDataList[i]; data->dataLength = subTreeDataSizeList[i]; - LOG(INFO) << "stage 3"; + VLOG(100) << "stage 3"; node->issubtree = true; raxSetCustomData(node, data); radixTree->SetSubtreeData(subTreeDataList[i]); } - LOG(INFO) << "Deserialize success"; - raxShow(radixTree->tree); + VLOG(100) << "Deserialize success"; + // raxShow(radixTree->tree); return radixTree; } @@ -524,7 +522,7 @@ std::vector> RadixTree::SplitInternal( raxNode* subTreeRootNode = raxSplit(this->tree, tokens.data(), tokens.size(), rootToken); - raxShow(this->tree); + // raxShow(this->tree); subTreeRootNode->issubtree = true; DataWrapper* treeData = new DataWrapper(); treeData->data = nullptr; @@ -540,14 +538,14 @@ std::vector> RadixTree::TraverseTreeWithoutSubTree( raxNode* headNode) { std::vector> nodes; if (headNode == NULL) { - LOG(INFO) << "traverse failed"; + VLOG(100) << "traverse failed"; return nodes; } std::vector dataNodeList; std::vector pre_tmp; raxTraverseSubTree(headNode, dataNodeList); - LOG(INFO) << "data node list:" << dataNodeList.size(); + VLOG(100) << "data node list:" << dataNodeList.size(); for (size_t i = 0; i < dataNodeList.size(); i++) { nodes.push_back(std::make_shared( 
reinterpret_cast(raxGetData(dataNodeList[i])), diff --git a/modules/kv-state-cache/radix-tree/radix.cc b/modules/kv-state-cache/radix-tree/radix.cc index 5d9172e0..975bbeeb 100644 --- a/modules/kv-state-cache/radix-tree/radix.cc +++ b/modules/kv-state-cache/radix-tree/radix.cc @@ -1686,7 +1686,6 @@ int raxIteratorNextStep(raxIterator *it, int noup) { if (!raxStackPush(&it->stack,it->node)) return 0; // if (it->node->issubtree && it->add_to_subtree_list && it->subtree_list != NULL && // it->subtree_data_list != NULL) { - // std::cout << "second find subtree list is:" << std::endl; // std::vector token; // for (size_t i = 0; i < it->key_len; i++) { // token.push_back(it->key[i]); @@ -2383,7 +2382,7 @@ raxNode *raxSplit(rax *rax, int *s, size_t len, std::vector& token) { token_str += std::to_string(token[i]); token_str += " "; } - LOG(INFO) << "split token: " << token_str; + VLOG(100) << "split token: " << token_str; // if the splitNode is NULL, it means that the tree only has one node if (splitNode == NULL) { @@ -2528,10 +2527,10 @@ void mergeTree(rax* first_tree, rax* second_tree, std::vector>& evicted_tokens, std::set>& insert_tokens, int max_node) { printf("merge tree!\n"); - LOG(INFO) << "==============tree 1===================="; - raxShow(first_tree); - LOG(INFO) << "==============tree 2===================="; - raxShow(second_tree); + VLOG(100) << "==============tree 1===================="; + //raxShow(first_tree); + VLOG(100) << "==============tree 2===================="; + //raxShow(second_tree); raxIterator first_tree_iter; raxIterator second_tree_iter; rax* tmp = raxNew(); @@ -2787,7 +2786,7 @@ void mergeTree(rax* first_tree, rax* second_tree, } } else if (second_tree_index >= second_tree_iter_list.size()) { // second tree is empty - raxShow(tmp); + //raxShow(tmp); printf("nodeCount:%d\n", nodeCount); printf("first_tree_index:%ld\n", first_tree_index); while (first_tree_index < first_tree_iter_list.size() && diff --git 
a/modules/kv-state-cache/utils/kv_state_cache_utils.cc b/modules/kv-state-cache/utils/kv_state_cache_utils.cc index 34a47ef7..a0d4c861 100644 --- a/modules/kv-state-cache/utils/kv_state_cache_utils.cc +++ b/modules/kv-state-cache/utils/kv_state_cache_utils.cc @@ -234,7 +234,6 @@ void sync() { // 8. release the lock while (1) { - LOG(INFO) << "stage 7"; client.TryReleaseLock(actualKey, result); if (result == true) { break; diff --git a/test/kv_state_cache_benchmark_test.cc b/test/kv_state_cache_benchmark_test.cc new file mode 100644 index 00000000..6bb5aea2 --- /dev/null +++ b/test/kv_state_cache_benchmark_test.cc @@ -0,0 +1,160 @@ +/** Copyright 2020-2023 Alibaba Group Holding Limited. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +#include +#include +#include +#include +#include +#include "arrow/api.h" +#include "arrow/io/api.h" + +#include "basic/stream/byte_stream.h" +#include "basic/stream/dataframe_stream.h" +#include "basic/stream/recordbatch_stream.h" +#include "client/client.h" +#include "client/ds/object_meta.h" +#include "common/util/logging.h" + +#include "kv-state-cache/utils/kv_state_cache_utils.h" + +using namespace vineyard; // NOLINT(build/namespaces) + +#define DIMENSION 100 +#define CAPACITY 1000 +#define LAYER 64 +#define BLOCK_SIZE 100 + +void init() { InitKVStateCache(DIMENSION, CAPACITY, LAYER, BLOCK_SIZE); } + +std::vector generate_random_tokens(size_t max_length) { + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution<> dist(1, 10000); + + size_t length = dist(gen) % max_length + 1; + std::vector tokens(length); + for (size_t i = 0; i < length; ++i) { + tokens[i] = dist(gen); + } + return tokens; +} + +std::map, std::vector>> +generate_kv_state(int token) { + std::map, std::vector>> kv_state; + for (int currentLayer = 0; currentLayer < LAYER; currentLayer++) { + std::vector key_state; + std::vector value_state; + for (int i = 0; i < DIMENSION; ++i) { + key_state.push_back((static_cast(token)) / DIMENSION * (i + 1) + + currentLayer * 10); + value_state.push_back((static_cast(token)) / DIMENSION * (i + 1) * + 2 + + currentLayer * 10); + } + + kv_state.insert( + std::make_pair(currentLayer, std::make_pair(key_state, value_state))); + } + return kv_state; +} + +// test the performance of Query and Update function +void benchmark_inference(std::vector>& tokens) { + LOG(INFO) << "inference for benchmark"; + std::map, std::vector>> kv_state; + + std::chrono::steady_clock::time_point start, end; + double token_list_size = 0; + std::chrono::duration update_duration(0); + std::chrono::duration query_duration(0); + double total_update_duration = 0; + double total_query_duration = 0; + + for (size_t i = 0; i < tokens.size(); ++i) { + std::vector 
inference_tokens; + for (size_t j = 0; j < tokens[i].size(); ++j) { + start = std::chrono::steady_clock::now(); + kv_state = Query(inference_tokens, tokens[i][j]); + end = std::chrono::steady_clock::now(); + query_duration += end - start; + + if (kv_state.size() == 0) { + kv_state = generate_kv_state(tokens[i][j]); + + start = std::chrono::steady_clock::now(); + Update(inference_tokens, tokens[i][j], kv_state); + end = std::chrono::steady_clock::now(); + update_duration += end - start; + } + inference_tokens.push_back(tokens[i][j]); + token_list_size++; + } + total_update_duration += update_duration.count(); + total_query_duration += query_duration.count(); + } + + LOG(INFO) << "Token list size is " << token_list_size + << "Total Update time is " << total_update_duration << "s " + << "Total Query time is " << total_query_duration << "s " + << "Average update time is " + << token_list_size / total_update_duration << "token/s " + << "Average query time is " + << token_list_size / total_query_duration << "token/s "; +} + +int main(int argc, char** argv) { + if (argc < 2) { + printf("usage ./kv_state_cache_benchmark "); + return 1; + } + std::string ipc_socket = std::string(argv[1]); + + init(); + + std::atomic inference_done(false); + + std::thread memory_monitor([&]() { + Client client; + size_t max_memory_usage = 0; + VINEYARD_CHECK_OK(client.Connect(ipc_socket)); + while (!inference_done) { + sleep(1); + std::shared_ptr status; + VINEYARD_CHECK_OK(client.InstanceStatus(status)); + LOG(INFO) << "status->memory_usage is:" << status->memory_usage; + if (status->memory_usage > max_memory_usage) { + max_memory_usage = status->memory_usage; + } + } + LOG(INFO) << "Max memory usage is " << max_memory_usage; + }); + + std::thread inference([&]() { + const size_t num_lists = 10; + std::vector> all_token_lists; + for (size_t i = 0; i < num_lists; ++i) { + all_token_lists.push_back(generate_random_tokens(2000)); + } + + benchmark_inference(all_token_lists); + inference_done = 
true; + }); + + memory_monitor.join(); + inference.join(); + return 0; +} From 7bdffc67812084f35947de59139e920b92ba0a39 Mon Sep 17 00:00:00 2001 From: Ye Cao Date: Wed, 28 Feb 2024 15:20:55 +0800 Subject: [PATCH 15/20] Add the rax linsence and fix a warning (#1778) Signed-off-by: Ye Cao --- LICENSE | 15 +++++++++++++++ NOTICE.txt | 4 ++++ README.rst | 1 + modules/kv-state-cache/radix-tree/radix-tree.cc | 2 +- 4 files changed, 21 insertions(+), 1 deletion(-) diff --git a/LICENSE b/LICENSE index ab3ae45d..1836516f 100644 --- a/LICENSE +++ b/LICENSE @@ -1181,3 +1181,18 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +------------------------------------------------------------------------------- + +The files modules/kv-state-cache/radix-tree/{radix.cc, radix.h, rax_malloc} is referred from project antirez/rax, +which has the following license: + +Copyright (c) 2017, Salvatore Sanfilippo +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/NOTICE.txt b/NOTICE.txt index f8186a78..c326975c 100644 --- a/NOTICE.txt +++ b/NOTICE.txt @@ -55,3 +55,7 @@ This product includes software from the ClickHouse project This product includes software from the BBHash project * Copyright (c) 2015 Guillaume Rizk * https://github.com/rizkg/BBHash + +This product includes software from the rax project (BSD, 2-clause) + * Copyright (c) 2017-2019, Salvatore Sanfilippo + * https://github.com/antirez/rax diff --git a/README.rst b/README.rst index 956ace05..a0297818 100644 --- a/README.rst +++ b/README.rst @@ -298,6 +298,7 @@ We thank the following excellent open-source projects: - `skywalking-swck `_ A kubernetes operator for the Apache Skywalking. - `wyhash `_, C++ wrapper around wyhash and wyrand. - `BBHash `_, a fast, minimal-memory perfect hash function. +- `rax `_, an ANSI C radix tree implementation. 
License ------- diff --git a/modules/kv-state-cache/radix-tree/radix-tree.cc b/modules/kv-state-cache/radix-tree/radix-tree.cc index da549557..a8d0a6dd 100644 --- a/modules/kv-state-cache/radix-tree/radix-tree.cc +++ b/modules/kv-state-cache/radix-tree/radix-tree.cc @@ -319,7 +319,7 @@ std::shared_ptr RadixTree::Deserialize(std::string data) { data.erase(0, sizeof(int)); int rootNumNodes = *reinterpret_cast(data.data()); data.erase(0, sizeof(uint32_t)); - int ds = ZSTD_getFrameContentSize(data.c_str(), data.size()); + unsigned long long ds = ZSTD_getFrameContentSize(data.c_str(), data.size()); if (ds == ZSTD_CONTENTSIZE_ERROR) { LOG(ERROR) << "Error: not a valid compressed frame"; } else if (ds == ZSTD_CONTENTSIZE_UNKNOWN) { From 5c12567699cc9d76bf6c0ac93298746d700c2c2a Mon Sep 17 00:00:00 2001 From: vegetableysm Date: Wed, 28 Feb 2024 16:14:17 +0800 Subject: [PATCH 16/20] Clean code. Signed-off-by: vegetableysm --- CMakeLists.txt | 10 +++---- LICENSE | 2 +- modules/kv-state-cache/CMakeLists.txt | 18 ------------- modules/llm-cache/CMakeLists.txt | 16 +++++++++++ .../{kv-state-cache => llm-cache}/README.rst | 0 .../ds/kv_state_cache.cc | 17 ++++-------- .../ds/kv_state_cache.h | 10 +++---- .../ds/kv_state_cache_block.cc | 2 +- .../ds/kv_state_cache_block.h | 27 +++++++++---------- .../radix-tree/radix-tree.cc | 4 +-- .../radix-tree/radix-tree.h | 8 +++--- .../radix-tree/radix.cc | 0 .../radix-tree/radix.h | 0 .../radix-tree/rax_malloc.h | 6 ++--- .../utils/kv_state_cache_utils.cc | 4 +-- .../utils/kv_state_cache_utils.h | 8 +++--- src/server/services/etcd_meta_service.cc | 7 +---- src/server/services/local_meta_service.h | 12 ++++----- src/server/services/redis_meta_service.h | 12 ++++----- test/kv_state_cache_benchmark_test.cc | 2 +- test/kv_state_cache_radix_tree_test.cc | 4 +-- test/kv_state_cache_test.cc | 4 +-- 22 files changed, 79 insertions(+), 94 deletions(-) delete mode 100644 modules/kv-state-cache/CMakeLists.txt create mode 100644 
modules/llm-cache/CMakeLists.txt rename modules/{kv-state-cache => llm-cache}/README.rst (100%) rename modules/{kv-state-cache => llm-cache}/ds/kv_state_cache.cc (97%) rename modules/{kv-state-cache => llm-cache}/ds/kv_state_cache.h (92%) rename modules/{kv-state-cache => llm-cache}/ds/kv_state_cache_block.cc (99%) rename modules/{kv-state-cache => llm-cache}/ds/kv_state_cache_block.h (90%) rename modules/{kv-state-cache => llm-cache}/radix-tree/radix-tree.cc (99%) rename modules/{kv-state-cache => llm-cache}/radix-tree/radix-tree.h (93%) rename modules/{kv-state-cache => llm-cache}/radix-tree/radix.cc (100%) rename modules/{kv-state-cache => llm-cache}/radix-tree/radix.h (100%) rename modules/{kv-state-cache => llm-cache}/radix-tree/rax_malloc.h (92%) rename modules/{kv-state-cache => llm-cache}/utils/kv_state_cache_utils.cc (98%) rename modules/{kv-state-cache => llm-cache}/utils/kv_state_cache_utils.h (84%) diff --git a/CMakeLists.txt b/CMakeLists.txt index 10998b31..2dc42d64 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -62,7 +62,7 @@ option(BUILD_VINEYARD_MALLOC_OVERRIDE "Using the client-side malloc to override option(BUILD_VINEYARD_FUSE "Enable vineyard's fuse support" OFF) option(BUILD_VINEYARD_FUSE_PARQUET "Enable vineyard's fuse parquet support" OFF) option(BUILD_VINEYARD_HOSSEINMOEIN_DATAFRAME "Enable hosseinmoein dataframe support" OFF) -option(BUILD_VINEYARD_KV_STATE_CACHE "Enable kv-state cache support" ON) +option(BUILD_VINEYARD_LLM_CACHE "Enable kv-state cache support" ON) option(BUILD_VINEYARD_TESTS "Generate make targets for vineyard tests" ON) option(BUILD_VINEYARD_TESTS_ALL "Include make targets for vineyard tests to ALL" OFF) @@ -933,9 +933,9 @@ if(BUILD_VINEYARD_HOSSEINMOEIN_DATAFRAME) list(APPEND VINEYARD_INSTALL_LIBS vineyard_hosseinmoein_dataframe) endif() -if(BUILD_VINEYARD_KV_STATE_CACHE) - add_subdirectory(modules/kv-state-cache) - list(APPEND VINEYARD_INSTALL_LIBS vineyard_kv_state_cache) +if(BUILD_VINEYARD_LLM_CACHE) + 
add_subdirectory(modules/llm-cache) + list(APPEND VINEYARD_INSTALL_LIBS vineyard_llm_cache) endif() if(BUILD_VINEYARD_TESTS) @@ -981,7 +981,7 @@ endfunction() file_glob_recurse(FILES_NEED_FORMAT DIRECTORIES "src" "modules" "python" "test" "benchmark" PATTERNS ".*\\.(cc|cpp|h|hpp|vineyard-mod)$" - EXCLUDE_PATTERNS "(.*\\.vineyard.h$)|(.*modules/kv-state-cache/radix-tree/radix\.(cc|h)$)" + EXCLUDE_PATTERNS "(.*\\.vineyard.h$)|(.*modules/llm-cache/radix-tree/radix\.(cc|h)$)" ) # the `memcpy.h` is borrowed from external project diff --git a/LICENSE b/LICENSE index 1836516f..3ff17ec4 100644 --- a/LICENSE +++ b/LICENSE @@ -1184,7 +1184,7 @@ SOFTWARE. ------------------------------------------------------------------------------- -The files modules/kv-state-cache/radix-tree/{radix.cc, radix.h, rax_malloc} is referred from project antirez/rax, +The files modules/llm-cache/radix-tree/{radix.cc, radix.h, rax_malloc} is referred from project antirez/rax, which has the following license: Copyright (c) 2017, Salvatore Sanfilippo diff --git a/modules/kv-state-cache/CMakeLists.txt b/modules/kv-state-cache/CMakeLists.txt deleted file mode 100644 index 82de5f67..00000000 --- a/modules/kv-state-cache/CMakeLists.txt +++ /dev/null @@ -1,18 +0,0 @@ -file(GLOB VINEYARD_KV_STATE_CACHE_SRCS "${CMAKE_CURRENT_SOURCE_DIR}" - "ds/*.cc" - "ds/*.h" - "radix-tree/*.cc" - "radix-tree/*.h" - "utils/*.cc" - "utils/*.h" - "strategy/*.cc" - "strategy/*.h" -) - -add_library(vineyard_kv_state_cache ${VINEYARD_KV_STATE_CACHE_SRCS}) -target_link_libraries(vineyard_kv_state_cache PUBLIC vineyard_client vineyard_basic) - -install_export_vineyard_target(vineyard_kv_state_cache) -install_vineyard_headers("${CMAKE_CURRENT_SOURCE_DIR}/utils/") -install_vineyard_headers("${CMAKE_CURRENT_SOURCE_DIR}/ds/") -install_vineyard_headers("${CMAKE_CURRENT_SOURCE_DIR}/radix-tree/") diff --git a/modules/llm-cache/CMakeLists.txt b/modules/llm-cache/CMakeLists.txt new file mode 100644 index 00000000..7c428565 --- /dev/null 
+++ b/modules/llm-cache/CMakeLists.txt @@ -0,0 +1,16 @@ +file(GLOB VINEYARD_LLM_CACHE_SRCS "${CMAKE_CURRENT_SOURCE_DIR}" + "ds/*.cc" + "ds/*.h" + "radix-tree/*.cc" + "radix-tree/*.h" + "utils/*.cc" + "utils/*.h" +) + +add_library(vineyard_llm_cache ${VINEYARD_LLM_CACHE_SRCS}) +target_link_libraries(vineyard_llm_cache PUBLIC vineyard_client vineyard_basic) + +install_export_vineyard_target(vineyard_llm_cache) +install_vineyard_headers("${CMAKE_CURRENT_SOURCE_DIR}/utils/") +install_vineyard_headers("${CMAKE_CURRENT_SOURCE_DIR}/ds/") +install_vineyard_headers("${CMAKE_CURRENT_SOURCE_DIR}/radix-tree/") diff --git a/modules/kv-state-cache/README.rst b/modules/llm-cache/README.rst similarity index 100% rename from modules/kv-state-cache/README.rst rename to modules/llm-cache/README.rst diff --git a/modules/kv-state-cache/ds/kv_state_cache.cc b/modules/llm-cache/ds/kv_state_cache.cc similarity index 97% rename from modules/kv-state-cache/ds/kv_state_cache.cc rename to modules/llm-cache/ds/kv_state_cache.cc index d6f20f66..c040dc5e 100644 --- a/modules/kv-state-cache/ds/kv_state_cache.cc +++ b/modules/llm-cache/ds/kv_state_cache.cc @@ -22,9 +22,9 @@ limitations under the License. 
#include "common/util/base64.h" #include "common/util/logging.h" #include "common/util/status.h" -#include "kv-state-cache/ds/kv_state_cache.h" -#include "kv-state-cache/radix-tree/radix-tree.h" -#include "kv-state-cache/radix-tree/radix.h" +#include "llm-cache/ds/kv_state_cache.h" +#include "llm-cache/radix-tree/radix-tree.h" +#include "llm-cache/radix-tree/radix.h" namespace vineyard { @@ -63,9 +63,7 @@ void KVStateCache::Resolve() { << " layer:" << this->layer; } -KVStateCache::~KVStateCache() { - // TBD -} +KVStateCache::~KVStateCache() {} KVStateCacheBuilder::KVStateCacheBuilder(Client& client, int dimension, int cacheCapacity, int layer, @@ -200,8 +198,6 @@ void KVStateCacheBuilder::Update(Client& client, << " bitmap:" << kvStateCacheBlockBuilder->GetBitmapStr(); } -static std::shared_ptr node; - KV_STATE_WITH_LAYER KVStateCacheBuilder::Query( Client& client, const std::vector& tokenList, int token) { std::vector tokenListCopy = tokenList; @@ -285,10 +281,7 @@ void KVStateCacheBuilder::Merge(Client& client, return; } -Status KVStateCacheBuilder::Build(Client& client) { - // TBD - return Status::OK(); -} +Status KVStateCacheBuilder::Build(Client& client) { return Status::OK(); } std::shared_ptr KVStateCacheBuilder::_Seal(Client& client) { this->Build(client); diff --git a/modules/kv-state-cache/ds/kv_state_cache.h b/modules/llm-cache/ds/kv_state_cache.h similarity index 92% rename from modules/kv-state-cache/ds/kv_state_cache.h rename to modules/llm-cache/ds/kv_state_cache.h index 7cb80f20..075a1f61 100644 --- a/modules/kv-state-cache/ds/kv_state_cache.h +++ b/modules/llm-cache/ds/kv_state_cache.h @@ -20,11 +20,11 @@ limitations under the License. 
#include "client/client.h" #include "common/util/logging.h" #include "common/util/status.h" -#include "kv-state-cache/ds/kv_state_cache_block.h" -#include "kv-state-cache/radix-tree/radix-tree.h" +#include "llm-cache/ds/kv_state_cache_block.h" +#include "llm-cache/radix-tree/radix-tree.h" -#ifndef MODULES_KV_STATE_CACHE_DS_KV_STATE_CACHE_H_ -#define MODULES_KV_STATE_CACHE_DS_KV_STATE_CACHE_H_ +#ifndef MODULES_LLM_CACHE_DS_KV_STATE_CACHE_H_ +#define MODULES_LLM_CACHE_DS_KV_STATE_CACHE_H_ namespace vineyard { @@ -120,4 +120,4 @@ class KVStateCacheBuilder : public vineyard::ObjectBuilder { } // namespace vineyard -#endif // MODULES_KV_STATE_CACHE_DS_KV_STATE_CACHE_H_ +#endif // MODULES_LLM_CACHE_DS_KV_STATE_CACHE_H_ diff --git a/modules/kv-state-cache/ds/kv_state_cache_block.cc b/modules/llm-cache/ds/kv_state_cache_block.cc similarity index 99% rename from modules/kv-state-cache/ds/kv_state_cache_block.cc rename to modules/llm-cache/ds/kv_state_cache_block.cc index 1441d196..0b251df9 100644 --- a/modules/kv-state-cache/ds/kv_state_cache_block.cc +++ b/modules/llm-cache/ds/kv_state_cache_block.cc @@ -19,7 +19,7 @@ limitations under the License. #include "client/client.h" #include "common/util/logging.h" -#include "kv-state-cache/ds/kv_state_cache_block.h" +#include "llm-cache/ds/kv_state_cache_block.h" namespace vineyard { diff --git a/modules/kv-state-cache/ds/kv_state_cache_block.h b/modules/llm-cache/ds/kv_state_cache_block.h similarity index 90% rename from modules/kv-state-cache/ds/kv_state_cache_block.h rename to modules/llm-cache/ds/kv_state_cache_block.h index 4870201f..81747d73 100644 --- a/modules/kv-state-cache/ds/kv_state_cache_block.h +++ b/modules/llm-cache/ds/kv_state_cache_block.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ -#ifndef MODULES_KV_STATE_CACHE_DS_KV_STATE_CACHE_BLOCK_H_ -#define MODULES_KV_STATE_CACHE_DS_KV_STATE_CACHE_BLOCK_H_ +#ifndef MODULES_LLM_CACHE_DS_KV_STATE_CACHE_BLOCK_H_ +#define MODULES_LLM_CACHE_DS_KV_STATE_CACHE_BLOCK_H_ #include #include @@ -27,17 +27,16 @@ limitations under the License. #include "basic/ds/tensor.h" #include "client/ds/blob.h" #include "client/ds/i_object.h" -#include "kv-state-cache/radix-tree/radix-tree.h" - -typedef std::map, std::vector>> - KV_STATE_WITH_LAYER; -typedef std::vector< - std::map, std::vector>>> - LIST_KV_STATE_WITH_LAYER; -typedef std::vector, std::vector>> - KV_STATE; -typedef std::vector, std::vector>> - LIST_KV_STATE; +#include "llm-cache/radix-tree/radix-tree.h" + +using KV_STATE_WITH_LAYER = + std::map, std::vector>>; +using LIST_KV_STATE_WITH_LAYER = std::vector< + std::map, std::vector>>>; +using KV_STATE = + std::vector, std::vector>>; +using LIST_KV_STATE = + std::vector, std::vector>>; // Set the bit to 1, which means the resource is not being used #define FREE_BIT_RESOURCE(value, bit) ((value) |= (((uint64_t) 1) << (bit))) @@ -203,4 +202,4 @@ class KVStateCacheBlockBuilder : public ObjectBuilder { } // namespace vineyard -#endif // MODULES_KV_STATE_CACHE_DS_KV_STATE_CACHE_BLOCK_H_ +#endif // MODULES_LLM_CACHE_DS_KV_STATE_CACHE_BLOCK_H_ diff --git a/modules/kv-state-cache/radix-tree/radix-tree.cc b/modules/llm-cache/radix-tree/radix-tree.cc similarity index 99% rename from modules/kv-state-cache/radix-tree/radix-tree.cc rename to modules/llm-cache/radix-tree/radix-tree.cc index a8d0a6dd..6fb1e2d7 100644 --- a/modules/kv-state-cache/radix-tree/radix-tree.cc +++ b/modules/llm-cache/radix-tree/radix-tree.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ -#include "kv-state-cache/radix-tree/radix-tree.h" +#include "llm-cache/radix-tree/radix-tree.h" #include "common/util/base64.h" #include "common/util/logging.h" @@ -319,7 +319,7 @@ std::shared_ptr RadixTree::Deserialize(std::string data) { data.erase(0, sizeof(int)); int rootNumNodes = *reinterpret_cast(data.data()); data.erase(0, sizeof(uint32_t)); - unsigned long long ds = ZSTD_getFrameContentSize(data.c_str(), data.size()); + uint64_t ds = ZSTD_getFrameContentSize(data.c_str(), data.size()); if (ds == ZSTD_CONTENTSIZE_ERROR) { LOG(ERROR) << "Error: not a valid compressed frame"; } else if (ds == ZSTD_CONTENTSIZE_UNKNOWN) { diff --git a/modules/kv-state-cache/radix-tree/radix-tree.h b/modules/llm-cache/radix-tree/radix-tree.h similarity index 93% rename from modules/kv-state-cache/radix-tree/radix-tree.h rename to modules/llm-cache/radix-tree/radix-tree.h index 23774ed0..28958234 100644 --- a/modules/kv-state-cache/radix-tree/radix-tree.h +++ b/modules/llm-cache/radix-tree/radix-tree.h @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ -#ifndef MODULES_KV_STATE_CACHE_RADIX_TREE_RADIX_TREE_H_ -#define MODULES_KV_STATE_CACHE_RADIX_TREE_RADIX_TREE_H_ +#ifndef MODULES_LLM_CACHE_RADIX_TREE_RADIX_TREE_H_ +#define MODULES_LLM_CACHE_RADIX_TREE_RADIX_TREE_H_ -#include "kv-state-cache/radix-tree/radix.h" +#include "llm-cache/radix-tree/radix.h" #include #include @@ -118,4 +118,4 @@ class RadixTree : public std::enable_shared_from_this { std::set GetAllNodeData(); }; -#endif // MODULES_KV_STATE_CACHE_RADIX_TREE_RADIX_TREE_H_" +#endif // MODULES_LLM_CACHE_RADIX_TREE_RADIX_TREE_H_ diff --git a/modules/kv-state-cache/radix-tree/radix.cc b/modules/llm-cache/radix-tree/radix.cc similarity index 100% rename from modules/kv-state-cache/radix-tree/radix.cc rename to modules/llm-cache/radix-tree/radix.cc diff --git a/modules/kv-state-cache/radix-tree/radix.h b/modules/llm-cache/radix-tree/radix.h similarity index 100% rename from modules/kv-state-cache/radix-tree/radix.h rename to modules/llm-cache/radix-tree/radix.h diff --git a/modules/kv-state-cache/radix-tree/rax_malloc.h b/modules/llm-cache/radix-tree/rax_malloc.h similarity index 92% rename from modules/kv-state-cache/radix-tree/rax_malloc.h rename to modules/llm-cache/radix-tree/rax_malloc.h index c62af522..fdd2430b 100644 --- a/modules/kv-state-cache/radix-tree/rax_malloc.h +++ b/modules/llm-cache/radix-tree/rax_malloc.h @@ -35,10 +35,10 @@ * the include of your alternate allocator if needed (not needed in order * to use the default libc allocator). 
*/ -#ifndef MODULES_KV_STATE_CACHE_RADIX_TREE_RAX_MALLOC_H_ -#define MODULES_KV_STATE_CACHE_RADIX_TREE_RAX_MALLOC_H_ +#ifndef MODULES_LLM_CACHE_RADIX_TREE_RAX_MALLOC_H_ +#define MODULES_LLM_CACHE_RADIX_TREE_RAX_MALLOC_H_ #define rax_malloc malloc #define rax_realloc realloc #define rax_free free -#endif // MODULES_KV_STATE_CACHE_RADIX_TREE_RAX_MALLOC_H_ +#endif // MODULES_LLM_CACHE_RADIX_TREE_RAX_MALLOC_H_ diff --git a/modules/kv-state-cache/utils/kv_state_cache_utils.cc b/modules/llm-cache/utils/kv_state_cache_utils.cc similarity index 98% rename from modules/kv-state-cache/utils/kv_state_cache_utils.cc rename to modules/llm-cache/utils/kv_state_cache_utils.cc index a0d4c861..90713b3a 100644 --- a/modules/kv-state-cache/utils/kv_state_cache_utils.cc +++ b/modules/llm-cache/utils/kv_state_cache_utils.cc @@ -20,8 +20,8 @@ limitations under the License. #include "client/client.h" #include "common/util/logging.h" -#include "kv-state-cache/ds/kv_state_cache.h" -#include "kv-state-cache/utils/kv_state_cache_utils.h" +#include "llm-cache/ds/kv_state_cache.h" +#include "llm-cache/utils/kv_state_cache_utils.h" namespace vineyard { diff --git a/modules/kv-state-cache/utils/kv_state_cache_utils.h b/modules/llm-cache/utils/kv_state_cache_utils.h similarity index 84% rename from modules/kv-state-cache/utils/kv_state_cache_utils.h rename to modules/llm-cache/utils/kv_state_cache_utils.h index a40a12d8..21b9d873 100644 --- a/modules/kv-state-cache/utils/kv_state_cache_utils.h +++ b/modules/llm-cache/utils/kv_state_cache_utils.h @@ -16,10 +16,10 @@ limitations under the License. 
#include #include -#include "kv-state-cache/ds/kv_state_cache.h" +#include "llm-cache/ds/kv_state_cache.h" -#ifndef MODULES_KV_STATE_CACHE_UTILS_KV_STATE_CACHE_UTILS_H_ -#define MODULES_KV_STATE_CACHE_UTILS_KV_STATE_CACHE_UTILS_H_ +#ifndef MODULES_LLM_CACHE_UTILS_KV_STATE_CACHE_UTILS_H_ +#define MODULES_LLM_CACHE_UTILS_KV_STATE_CACHE_UTILS_H_ namespace vineyard { @@ -44,4 +44,4 @@ void CloseKVStateCache(); } // namespace vineyard -#endif // MODULES_KV_STATE_CACHE_UTILS_KV_STATE_CACHE_UTILS_H_ +#endif // MODULES_LLM_CACHE_UTILS_KV_STATE_CACHE_UTILS_H_ diff --git a/src/server/services/etcd_meta_service.cc b/src/server/services/etcd_meta_service.cc index c2a662ca..c45b5f79 100644 --- a/src/server/services/etcd_meta_service.cc +++ b/src/server/services/etcd_meta_service.cc @@ -153,7 +153,6 @@ void EtcdMetaService::Stop() { void EtcdMetaService::TryAcquireLock( std::string key, callback_t callback_after_try_lock) { - LOG(INFO) << "TryAcquireLock, key:" << key; auto self(shared_from_base()); etcd_->lock(prefix_ + key) @@ -161,12 +160,10 @@ void EtcdMetaService::TryAcquireLock( pplx::task const& resp_task) { auto const& resp = resp_task.get(); if (resp.is_ok()) { - LOG(INFO) << "lock success! 
key is :" + resp.lock_key(); self->server_ptr_->GetMetaContext().post( boost::bind(callback_after_try_lock, Status::OK(), true, resp.lock_key().substr(self->prefix_.size()))); } else { - LOG(INFO) << "lock failed!"; self->server_ptr_->GetMetaContext().post( boost::bind(callback_after_try_lock, Status::OK(), false, "")); } @@ -182,13 +179,11 @@ void EtcdMetaService::TryReleaseLock( pplx::task const& resp_task) { auto const& resp = resp_task.get(); if (resp.is_ok()) { - LOG(INFO) << "unlock success!"; self->server_ptr_->GetMetaContext().post( boost::bind(callback_after_try_unlock, Status::OK(), true)); } else { - LOG(INFO) << "unlock failed!"; self->server_ptr_->GetMetaContext().post( - boost::bind(callback_after_try_unlock, Status::OK(), true)); + boost::bind(callback_after_try_unlock, Status::OK(), false)); } }); } diff --git a/src/server/services/local_meta_service.h b/src/server/services/local_meta_service.h index 50136a64..6491741d 100644 --- a/src/server/services/local_meta_service.h +++ b/src/server/services/local_meta_service.h @@ -51,15 +51,15 @@ class LocalMetaService : public IMetaService { ~LocalMetaService() override {} void TryAcquireLock(std::string key, - callback_t callback_after_try_locked) { - // TBD - assert(false); + callback_t callback_after_try_lock) { + server_ptr_->GetMetaContext().post(boost::bind( + callback_after_try_lock, Status::NotImplemented(), false, "")); } void TryReleaseLock(std::string key, - callback_t callback_after_try_unlocked) { - // TBD - assert(false); + callback_t callback_after_try_unlock) { + server_ptr_->GetMetaContext().post(boost::bind( + callback_after_try_unlock, Status::NotImplemented(), false)); } protected: diff --git a/src/server/services/redis_meta_service.h b/src/server/services/redis_meta_service.h index d1fa6749..91adc5e6 100644 --- a/src/server/services/redis_meta_service.h +++ b/src/server/services/redis_meta_service.h @@ -178,15 +178,15 @@ class RedisMetaService : public IMetaService { ~RedisMetaService() 
override {} void TryAcquireLock(std::string key, - callback_t callback_after_try_locked) { - // TBD - assert(false); + callback_t callback_after_try_lock) { + server_ptr_->GetMetaContext().post(boost::bind( + callback_after_try_lock, Status::NotImplemented(), false, "")); } void TryReleaseLock(std::string key, - callback_t callback_after_try_unlocked) { - // TBD - assert(false); + callback_t callback_after_try_unlock) { + server_ptr_->GetMetaContext().post(boost::bind( + callback_after_try_unlock, Status::NotImplemented(), false)); } protected: diff --git a/test/kv_state_cache_benchmark_test.cc b/test/kv_state_cache_benchmark_test.cc index 6bb5aea2..6f574054 100644 --- a/test/kv_state_cache_benchmark_test.cc +++ b/test/kv_state_cache_benchmark_test.cc @@ -28,7 +28,7 @@ limitations under the License. #include "client/ds/object_meta.h" #include "common/util/logging.h" -#include "kv-state-cache/utils/kv_state_cache_utils.h" +#include "llm-cache/utils/kv_state_cache_utils.h" using namespace vineyard; // NOLINT(build/namespaces) diff --git a/test/kv_state_cache_radix_tree_test.cc b/test/kv_state_cache_radix_tree_test.cc index 76fd7f30..10f1de44 100644 --- a/test/kv_state_cache_radix_tree_test.cc +++ b/test/kv_state_cache_radix_tree_test.cc @@ -17,10 +17,10 @@ limitations under the License. #include #include #include -#include "kv-state-cache/radix-tree/radix.h" +#include "llm-cache/radix-tree/radix.h" #include "common/util/logging.h" -#include "kv-state-cache/utils/kv_state_cache_utils.h" +#include "llm-cache/utils/kv_state_cache_utils.h" using namespace vineyard; // NOLINT(build/namespaces) diff --git a/test/kv_state_cache_test.cc b/test/kv_state_cache_test.cc index 58c9c721..abc4b83d 100644 --- a/test/kv_state_cache_test.cc +++ b/test/kv_state_cache_test.cc @@ -17,10 +17,10 @@ limitations under the License.
#include #include #include -#include "kv-state-cache/radix-tree/radix.h" +#include "llm-cache/radix-tree/radix.h" #include "common/util/logging.h" -#include "kv-state-cache/utils/kv_state_cache_utils.h" +#include "llm-cache/utils/kv_state_cache_utils.h" using namespace vineyard; // NOLINT(build/namespaces) From 24c241e5f04b5fb704a8e51780d9143e52359388 Mon Sep 17 00:00:00 2001 From: vegetableysm Date: Wed, 28 Feb 2024 16:46:55 +0800 Subject: [PATCH 17/20] Move llm tests to llm-cache directory Signed-off-by: vegetableysm --- modules/llm-cache/CMakeLists.txt | 2 ++ modules/llm-cache/tests/CMakeLists.txt | 26 +++++++++++++++++++ .../tests}/kv_state_cache_benchmark_test.cc | 0 .../tests}/kv_state_cache_multi_test.cc | 0 .../tests}/kv_state_cache_radix_tree_test.cc | 0 .../llm-cache/tests}/kv_state_cache_test.cc | 0 6 files changed, 28 insertions(+) create mode 100644 modules/llm-cache/tests/CMakeLists.txt rename {test => modules/llm-cache/tests}/kv_state_cache_benchmark_test.cc (100%) rename {test => modules/llm-cache/tests}/kv_state_cache_multi_test.cc (100%) rename {test => modules/llm-cache/tests}/kv_state_cache_radix_tree_test.cc (100%) rename {test => modules/llm-cache/tests}/kv_state_cache_test.cc (100%) diff --git a/modules/llm-cache/CMakeLists.txt b/modules/llm-cache/CMakeLists.txt index 7c428565..911865fe 100644 --- a/modules/llm-cache/CMakeLists.txt +++ b/modules/llm-cache/CMakeLists.txt @@ -14,3 +14,5 @@ install_export_vineyard_target(vineyard_llm_cache) install_vineyard_headers("${CMAKE_CURRENT_SOURCE_DIR}/utils/") install_vineyard_headers("${CMAKE_CURRENT_SOURCE_DIR}/ds/") install_vineyard_headers("${CMAKE_CURRENT_SOURCE_DIR}/radix-tree/") + +add_subdirectory(tests) diff --git a/modules/llm-cache/tests/CMakeLists.txt b/modules/llm-cache/tests/CMakeLists.txt new file mode 100644 index 00000000..8896f24c --- /dev/null +++ b/modules/llm-cache/tests/CMakeLists.txt @@ -0,0 +1,26 @@ +enable_testing() + +macro(add_test_case testname testfile) + 
add_executable(${testname} ${testfile}) + + target_include_directories(${testname} PRIVATE ${GLOG_INCLUDE_DIRS}) + target_link_libraries(${testname} PRIVATE vineyard_llm_cache) + + add_test(${testname} ${testname}) + add_dependencies(vineyard_tests ${testname}) +endmacro() + +file(GLOB LLM_TEST_FILES RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "./*.cc") + +foreach(testfile ${LLM_TEST_FILES}) + string(REGEX MATCH "^(.*)\\.[^.]*$" dummy ${testfile}) + set(testname ${CMAKE_MATCH_1}) + + if(${testname} STREQUAL "gpumalloc_test" AND NOT USE_CUDA) + continue() + endif() + + message(STATUS "Found unit_test - " ${testname}) + add_test_case(${testname} ${testfile}) +endforeach() + diff --git a/test/kv_state_cache_benchmark_test.cc b/modules/llm-cache/tests/kv_state_cache_benchmark_test.cc similarity index 100% rename from test/kv_state_cache_benchmark_test.cc rename to modules/llm-cache/tests/kv_state_cache_benchmark_test.cc diff --git a/test/kv_state_cache_multi_test.cc b/modules/llm-cache/tests/kv_state_cache_multi_test.cc similarity index 100% rename from test/kv_state_cache_multi_test.cc rename to modules/llm-cache/tests/kv_state_cache_multi_test.cc diff --git a/test/kv_state_cache_radix_tree_test.cc b/modules/llm-cache/tests/kv_state_cache_radix_tree_test.cc similarity index 100% rename from test/kv_state_cache_radix_tree_test.cc rename to modules/llm-cache/tests/kv_state_cache_radix_tree_test.cc diff --git a/test/kv_state_cache_test.cc b/modules/llm-cache/tests/kv_state_cache_test.cc similarity index 100% rename from test/kv_state_cache_test.cc rename to modules/llm-cache/tests/kv_state_cache_test.cc From 61a77c78b2f69ff0b3628d2d5f533258c308ba13 Mon Sep 17 00:00:00 2001 From: vegetableysm Date: Thu, 29 Feb 2024 12:00:42 +0800 Subject: [PATCH 18/20] Use nano second as timestamp Signed-off-by: vegetableysm --- modules/llm-cache/radix-tree/radix-tree.cc | 2 +- modules/llm-cache/radix-tree/radix.cc | 16 +++++++++++----- 2 files changed, 12 insertions(+), 6 deletions(-) diff 
--git a/modules/llm-cache/radix-tree/radix-tree.cc b/modules/llm-cache/radix-tree/radix-tree.cc index 6fb1e2d7..62e213e4 100644 --- a/modules/llm-cache/radix-tree/radix-tree.cc +++ b/modules/llm-cache/radix-tree/radix-tree.cc @@ -161,7 +161,7 @@ void RadixTree::DeleteInternal(std::vector tokens, nodeIsSubTree = true; } int retval = raxRemove(this->tree, deleteTokensArray, deleteTokensArrayLen, - reinterpret_cast(&oldData)); + reinterpret_cast(&oldData), false); if (retval == 1) { evictedNode = std::make_shared( oldData, reinterpret_cast(subTreeNode->custom_data)); diff --git a/modules/llm-cache/radix-tree/radix.cc b/modules/llm-cache/radix-tree/radix.cc index 975bbeeb..6c5658c2 100644 --- a/modules/llm-cache/radix-tree/radix.cc +++ b/modules/llm-cache/radix-tree/radix.cc @@ -512,8 +512,9 @@ static inline size_t raxLowWalk(rax *rax, const int *s, size_t len, raxNode **st // int64_t timestamp = ms.count(); auto now = std::chrono::high_resolution_clock::now(); - auto micros = std::chrono::time_point_cast(now).time_since_epoch().count(); - int64_t timestamp = micros; + // auto micros = std::chrono::time_point_cast(now).time_since_epoch().count(); + auto nanos = std::chrono::duration_cast(now.time_since_epoch()).count(); + int64_t timestamp = nanos; while(h->size && i < len) { debugnode("Lookup current node",h); @@ -2186,7 +2187,7 @@ void raxRecursiveShow(int level, int lpad, raxNode *n) { if (n->iskey) { numchars += printf("=%p",raxGetData(n)); } - numchars += printf(" node:%p time:%ld, data:%p, is_sub_tree:%d", n, n->timestamp, n->custom_data, n->issubtree); + numchars += printf(" node:%p time:%ld, data:%p, is_sub_tree:%d is_compr:%d", n, n->timestamp, n->custom_data, n->issubtree, n->iscompr); if (n->issubtree && n->custom_data != NULL) { numchars += printf(" cus data:%p" , ((DebugDatawrapper *)(n->custom_data))->data); DebugTreeData *data = (DebugTreeData *)((DebugDatawrapper *)(n->custom_data))->data; @@ -2463,13 +2464,18 @@ void raxFindLastRecentNode(raxNode 
*node, std::vector& key) { raxNode *chossenChild = childList[0]; int choosenChildIndex = 0; for (int i = 1; i < numChildren; i++) { - if (childList[i]->timestamp != 0 && childList[i]->timestamp <= chossenChild->timestamp) { - if (childList[i]->timestamp == chossenChild->timestamp && childList[i]->numnodes > chossenChild->numnodes) { + if (childList[i]->timestamp == chossenChild->timestamp) { + if (childList[i]->numnodes > chossenChild->numnodes) { + LOG(INFO) << "childList[i]->numnodes > chossenChild->numnodes"; + LOG(INFO) << "node1:" << childList[i] << " node:2" << chossenChild; chossenChild = childList[i]; choosenChildIndex = i; } // chossenChild = childList[i]; // choosenChildIndex = i; + } else if (childList[i]->timestamp < chossenChild->timestamp) { + chossenChild = childList[i]; + choosenChildIndex = i; } } From 2acff417d2d32a3a0955494a2243504fdc73533b Mon Sep 17 00:00:00 2001 From: vegetableysm Date: Thu, 29 Feb 2024 19:45:37 +0800 Subject: [PATCH 19/20] Refactor kv state cache utils api. Fix bug of persist retry. 
Signed-off-by: vegetableysm --- modules/llm-cache/CMakeLists.txt | 3 - modules/llm-cache/ds/kv_state_cache.cc | 4 +- .../llm-cache/ds/kv_state_cache_manager.cc | 265 ++++++++++++++++++ modules/llm-cache/ds/kv_state_cache_manager.h | 76 +++++ modules/llm-cache/radix-tree/radix-tree.cc | 4 +- .../tests/kv_state_cache_benchmark_test.cc | 14 +- .../tests/kv_state_cache_radix_tree_test.cc | 2 +- .../llm-cache/tests/kv_state_cache_test.cc | 13 +- .../llm-cache/utils/kv_state_cache_utils.cc | 260 ----------------- .../llm-cache/utils/kv_state_cache_utils.h | 47 ---- src/server/async/socket_server.cc | 37 +-- src/server/server/vineyard_server.cc | 2 - 12 files changed, 383 insertions(+), 344 deletions(-) create mode 100644 modules/llm-cache/ds/kv_state_cache_manager.cc create mode 100644 modules/llm-cache/ds/kv_state_cache_manager.h delete mode 100644 modules/llm-cache/utils/kv_state_cache_utils.cc delete mode 100644 modules/llm-cache/utils/kv_state_cache_utils.h diff --git a/modules/llm-cache/CMakeLists.txt b/modules/llm-cache/CMakeLists.txt index 911865fe..afbd347a 100644 --- a/modules/llm-cache/CMakeLists.txt +++ b/modules/llm-cache/CMakeLists.txt @@ -3,15 +3,12 @@ file(GLOB VINEYARD_LLM_CACHE_SRCS "${CMAKE_CURRENT_SOURCE_DIR}" "ds/*.h" "radix-tree/*.cc" "radix-tree/*.h" - "utils/*.cc" - "utils/*.h" ) add_library(vineyard_llm_cache ${VINEYARD_LLM_CACHE_SRCS}) target_link_libraries(vineyard_llm_cache PUBLIC vineyard_client vineyard_basic) install_export_vineyard_target(vineyard_llm_cache) -install_vineyard_headers("${CMAKE_CURRENT_SOURCE_DIR}/utils/") install_vineyard_headers("${CMAKE_CURRENT_SOURCE_DIR}/ds/") install_vineyard_headers("${CMAKE_CURRENT_SOURCE_DIR}/radix-tree/") diff --git a/modules/llm-cache/ds/kv_state_cache.cc b/modules/llm-cache/ds/kv_state_cache.cc index c040dc5e..dc4df8b8 100644 --- a/modules/llm-cache/ds/kv_state_cache.cc +++ b/modules/llm-cache/ds/kv_state_cache.cc @@ -43,7 +43,7 @@ void KVStateCache::Resolve() { // 1. 
construct the radix tree this->rootTree = RadixTree::Deserialize( base64_decode(this->meta_.GetKeyValue("radix_tree"))); - raxShow(this->rootTree->GetRootTree()); + // raxShow(this->rootTree->GetRootTree()); // 2. construct the kvStateCacheBlockBuilder list size_t numBlocks = this->meta_.GetKeyValue("numBlocks"); @@ -148,7 +148,6 @@ void KVStateCacheBuilder::Update(Client& client, std::shared_ptr nodeData = this->rootTree->Insert(tokenListCopy, evictedNodeData); if (nodeData == nullptr) { - LOG(ERROR) << "insert failed"; return; } KVStateCacheBlockBuilder* kvStateCacheBlockBuilder = @@ -235,6 +234,7 @@ void KVStateCacheBuilder::Delete(std::shared_ptr evictedNodeData) { // delete (DataWrapper*) evictedNodeData->nodeData; if (evictedNodeData->cleanTreeData) { this->rootTree->ClearSubtreeData(treeData); + delete kvStateCacheBlockBuilder; } evictedNodeData->RecycleSource(); } diff --git a/modules/llm-cache/ds/kv_state_cache_manager.cc b/modules/llm-cache/ds/kv_state_cache_manager.cc new file mode 100644 index 00000000..551317e7 --- /dev/null +++ b/modules/llm-cache/ds/kv_state_cache_manager.cc @@ -0,0 +1,265 @@ +/** Copyright 2020-2023 Alibaba Group Holding Limited. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +#include +#include +#include +#include + +#include "client/client.h" +#include "common/util/logging.h" +#include "llm-cache/ds/kv_state_cache.h" +#include "llm-cache/ds/kv_state_cache_manager.h" + +namespace vineyard { + +KVStateCacheManager::KVStateCacheManager(int dimension, int cacheCapacity, + int layer, int blockSize, + int syncInterval, std::string socket) { + this->syncInterval = syncInterval; + VLOG(100) << "socket:" << socket; + client.Connect(socket); + + // TBD + // try to get cache object + std::string actualKey; + bool result; + while (1) { + client.TryAcquireLock(llmCacheSyncLock, result, actualKey); + if (!result) { + VLOG(100) << "failed to gain the lock, wait for next time."; + sleep(1); + continue; + } else { + break; + } + } + + // sync global cache object with vineyard + ObjectID globalKVStateCacheID; + Status status = client.GetName(llmCacheObjectName, globalKVStateCacheID); + if (status.ok()) { + // if success, pull the cache object + std::shared_ptr globalKVStateCache = + std::dynamic_pointer_cast( + client.FetchAndGetObject(globalKVStateCacheID)); + kvStateCacheBuilder = + std::make_shared(client, globalKVStateCache); + } else { + // if failed, create a new cache object + VLOG(100) << "failed to get the cache object, create a new one."; + kvStateCacheBuilder = std::make_shared( + client, dimension, cacheCapacity, layer, blockSize); + } + + // release the lock + client.TryReleaseLock(actualKey, result); + VINEYARD_ASSERT(result == true); + + // syncThread = new std::thread(threadFunc); + syncThread = new std::thread(SyncThreadFunc, this); + + // TBD + // use lease to prevent the deadlock if the client is down +} + +void KVStateCacheManager::UpdateInternal(const std::vector& tokenList, + int nextToken, + const KV_STATE_WITH_LAYER& kvState) { + kvStateCacheBuilder->Update(client, tokenList, nextToken, kvState); +} + +KV_STATE_WITH_LAYER KVStateCacheManager::QueryInternal( + const std::vector& tokenList, int token) { + return 
kvStateCacheBuilder->Query(client, tokenList, token); +} + +void KVStateCacheManager::Update(const std::vector& tokenList, + int nextToken, + const KV_STATE_WITH_LAYER& kvState) { + if (!syncMutex.try_lock()) { + return; + } + + UpdateInternal(tokenList, nextToken, kvState); + + syncMutex.unlock(); +} + +void KVStateCacheManager::Update(const std::vector& tokenList, + const LIST_KV_STATE_WITH_LAYER& kvState) { + if (!syncMutex.try_lock()) { + return; + } + + std::vector tokenListCopy; + for (size_t i = 0; i < tokenList.size(); i++) { + UpdateInternal(tokenListCopy, tokenList[i], kvState[i]); + tokenListCopy.push_back(tokenList[i]); + } + + syncMutex.unlock(); +} + +KV_STATE_WITH_LAYER KVStateCacheManager::Query( + const std::vector& tokenList, int token) { + KV_STATE_WITH_LAYER result; + + if (!syncMutex.try_lock()) { + return result; + } + + result = QueryInternal(tokenList, token); + syncMutex.unlock(); + + return result; +} + +LIST_KV_STATE_WITH_LAYER KVStateCacheManager::Query( + const std::vector& tokenList) { + LIST_KV_STATE_WITH_LAYER listKVState; + if (!syncMutex.try_lock()) { + return listKVState; + } + + std::vector tokenListCopy; + for (size_t i = 0; i < tokenList.size(); i++) { + KV_STATE_WITH_LAYER kvState = QueryInternal(tokenListCopy, tokenList[i]); + listKVState.push_back(kvState); + tokenListCopy.push_back(tokenList[i]); + } + + syncMutex.unlock(); + return listKVState; +} + +KVStateCacheManager::~KVStateCacheManager() { + LOG(INFO) << "Wait for sync thread to exit."; + { + std::lock_guard lock(exitMutex); + exitFlag = true; + } + cv.notify_one(); + syncThread->join(); + delete syncThread; + LOG(INFO) << "KVStateCacheManager exit."; +} + +// This function is used for testing +void KVStateCacheManager::Delete(std::vector token) { + std::shared_ptr evictedNode; + kvStateCacheBuilder->GetRootTree()->Delete(token, evictedNode); + kvStateCacheBuilder->Delete(evictedNode); + raxShow(kvStateCacheBuilder->GetRootTree()->tree); +} + +void 
KVStateCacheManager::Sync() { + // 1. gain the lock + std::string actualKey; + bool result; + client.TryAcquireLock(llmCacheSyncLock, result, actualKey); + if (!result) { + LOG(INFO) << "failed to gain the lock, wait for next time"; + return; + } + // 2. pull the cache object + ObjectID globalKVStateCacheID; + std::vector deleteList; + + std::shared_ptr globalKVStateCache = nullptr; + Status status = client.GetName(llmCacheObjectName, globalKVStateCacheID); + if (status.ok()) { + deleteList.push_back(globalKVStateCacheID); + globalKVStateCache = std::dynamic_pointer_cast( + client.FetchAndGetObject(globalKVStateCacheID)); + } + + // 3. merge the cache object + // only the global cache object with higher version will be merged + VLOG(100) << "Current builder version:" << kvStateCacheBuilder->GetVersion() + << " global version:" + << (globalKVStateCache == nullptr + ? "null" + : std::to_string(globalKVStateCache->GetVersion())); + if (globalKVStateCache != nullptr && + kvStateCacheBuilder->GetVersion() < globalKVStateCache->GetVersion()) { + kvStateCacheBuilder->Merge(client, globalKVStateCache); + } + kvStateCacheBuilder->UpdateVersion(); + + // 4. push the cache object + std::shared_ptr kvStateCache = kvStateCacheBuilder->_Seal(client); + client.Persist(kvStateCache->id()); + + // 5. put the name of the new cache object to the meta server + client.DropName(llmCacheObjectName); + status = client.PutName(kvStateCache->id(), llmCacheObjectName); + if (!status.ok()) { + throw std::runtime_error("Put cache object name failed."); + } + + // 6. delete old cache object + client.DelData(deleteList, true, true); + + // 7. create a global cache object replica + std::dynamic_pointer_cast(kvStateCache)->Resolve(); + kvStateCacheBuilder = std::make_shared( + client, std::dynamic_pointer_cast(kvStateCache)); + + // 8. 
release the lock + while (1) { + client.TryReleaseLock(actualKey, result); + if (result == true) { + break; + } + sleep(1); + } + + // TBD + // use lease to prevent the deadlock if the client is down +} + +void KVStateCacheManager::SyncThreadFunc(KVStateCacheManager* manager) { + uint64_t last_time = std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count(); + while (1) { + std::unique_lock lock(manager->exitMutex); + if (manager->cv.wait_for( + lock, std::chrono::seconds(manager->syncInterval), + [&manager, &last_time] { + uint64_t current_time = + std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count(); + return manager->exitFlag || + static_cast(current_time - last_time) >= + manager->syncInterval; + })) { + if (manager->exitFlag) { + break; + } + manager->syncMutex.lock(); + manager->Sync(); + manager->syncMutex.unlock(); + last_time = std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count(); + } + } + LOG(INFO) << "Sync thread exit."; +} + +} // namespace vineyard diff --git a/modules/llm-cache/ds/kv_state_cache_manager.h b/modules/llm-cache/ds/kv_state_cache_manager.h new file mode 100644 index 00000000..238ae31a --- /dev/null +++ b/modules/llm-cache/ds/kv_state_cache_manager.h @@ -0,0 +1,76 @@ +/** Copyright 2020-2023 Alibaba Group Holding Limited. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +#include +#include +#include +#include +#include + +#include "llm-cache/ds/kv_state_cache.h" + +#ifndef MODULES_LLM_CACHE_DS_KV_STATE_CACHE_MANAGER_H_ +#define MODULES_LLM_CACHE_DS_KV_STATE_CACHE_MANAGER_H_ + +namespace vineyard { + +class KVStateCacheManager { + private: + Client client; + std::shared_ptr kvStateCacheBuilder = nullptr; + std::string llmCacheSyncLock = "llmCacheSyncLock"; + std::string llmCacheObjectName = "llm_cache_object"; + std::thread* syncThread; + std::mutex syncMutex; + int syncInterval; + bool exitFlag = false; + std::condition_variable cv; + std::mutex exitMutex; + + public: + KVStateCacheManager( + int dimension = 10, int cacheCapacity = 10, int layer = 1, + int blockSize = 5, int syncInterval = 3, + std::string socket = std::string(getenv("VINEYARD_IPC_SOCKET"))); + + void Update(const std::vector& tokenList, int nextToken, + const KV_STATE_WITH_LAYER& kvState); + + void Update(const std::vector& tokenList, + const LIST_KV_STATE_WITH_LAYER& kvState); + + KV_STATE_WITH_LAYER Query(const std::vector& tokenList, int token); + + LIST_KV_STATE_WITH_LAYER Query(const std::vector& tokenList); + + ~KVStateCacheManager(); + + private: + void UpdateInternal(const std::vector& tokenList, int nextToken, + const KV_STATE_WITH_LAYER& kvState); + + KV_STATE_WITH_LAYER QueryInternal(const std::vector& tokenList, + int token); + + void Delete(std::vector token); + + static void SyncThreadFunc(KVStateCacheManager* manager); + + void Sync(); +}; + +} // namespace vineyard + +#endif // MODULES_LLM_CACHE_DS_KV_STATE_CACHE_MANAGER_H_ diff --git a/modules/llm-cache/radix-tree/radix-tree.cc b/modules/llm-cache/radix-tree/radix-tree.cc index 62e213e4..60f14326 100644 --- a/modules/llm-cache/radix-tree/radix-tree.cc +++ b/modules/llm-cache/radix-tree/radix-tree.cc @@ -291,7 +291,7 @@ std::string RadixTree::Serialize() { size_t srcSize = serializedStr.size(); size_t dstSize = ZSTD_compressBound(srcSize); std::string compressedStr(dstSize + 1, '\0'); - 
LOG(INFO) << "src size:" << srcSize << " dst size:" << dstSize; + VLOG(100) << "src size:" << srcSize << " dst size:" << dstSize; int compressedSize = ZSTD_compress(compressedStr.data(), compressedStr.size(), serializedStr.c_str(), srcSize, 3); if (ZSTD_isError(compressedSize)) { @@ -373,7 +373,7 @@ std::shared_ptr RadixTree::Deserialize(std::string data) { } } if (!std::getline(lineStream, dataPart)) { - LOG(ERROR) << "data length is 0"; + VLOG(100) << "data length is 0"; } std::istringstream keyStream(tokenListPart); diff --git a/modules/llm-cache/tests/kv_state_cache_benchmark_test.cc b/modules/llm-cache/tests/kv_state_cache_benchmark_test.cc index 6f574054..aba51d54 100644 --- a/modules/llm-cache/tests/kv_state_cache_benchmark_test.cc +++ b/modules/llm-cache/tests/kv_state_cache_benchmark_test.cc @@ -28,7 +28,7 @@ limitations under the License. #include "client/ds/object_meta.h" #include "common/util/logging.h" -#include "llm-cache/utils/kv_state_cache_utils.h" +#include "llm-cache/ds/kv_state_cache_manager.h" using namespace vineyard; // NOLINT(build/namespaces) @@ -37,7 +37,12 @@ using namespace vineyard; // NOLINT(build/namespaces) #define LAYER 64 #define BLOCK_SIZE 100 -void init() { InitKVStateCache(DIMENSION, CAPACITY, LAYER, BLOCK_SIZE); } +KVStateCacheManager* manager; + +void init() { + manager = + new KVStateCacheManager(DIMENSION, CAPACITY, LAYER, DEFAULT_BLOCK_SIZE); +} std::vector generate_random_tokens(size_t max_length) { std::random_device rd; @@ -88,7 +93,7 @@ void benchmark_inference(std::vector>& tokens) { std::vector inference_tokens; for (size_t j = 0; j < tokens[i].size(); ++j) { start = std::chrono::steady_clock::now(); - kv_state = Query(inference_tokens, tokens[i][j]); + kv_state = manager->Query(inference_tokens, tokens[i][j]); end = std::chrono::steady_clock::now(); query_duration += end - start; @@ -96,7 +101,7 @@ void benchmark_inference(std::vector>& tokens) { kv_state = generate_kv_state(tokens[i][j]); start = 
std::chrono::steady_clock::now(); - Update(inference_tokens, tokens[i][j], kv_state); + manager->Update(inference_tokens, tokens[i][j], kv_state); end = std::chrono::steady_clock::now(); update_duration += end - start; } @@ -156,5 +161,6 @@ int main(int argc, char** argv) { memory_monitor.join(); inference.join(); + delete manager; return 0; } diff --git a/modules/llm-cache/tests/kv_state_cache_radix_tree_test.cc b/modules/llm-cache/tests/kv_state_cache_radix_tree_test.cc index 10f1de44..6e656b3f 100644 --- a/modules/llm-cache/tests/kv_state_cache_radix_tree_test.cc +++ b/modules/llm-cache/tests/kv_state_cache_radix_tree_test.cc @@ -20,7 +20,7 @@ limitations under the License. #include "llm-cache/radix-tree/radix.h" #include "common/util/logging.h" -#include "llm-cache/utils/kv_state_cache_utils.h" +#include "llm-cache/ds/kv_state_cache_manager.h" using namespace vineyard; // NOLINT(build/namespaces) diff --git a/modules/llm-cache/tests/kv_state_cache_test.cc b/modules/llm-cache/tests/kv_state_cache_test.cc index abc4b83d..9debcd06 100644 --- a/modules/llm-cache/tests/kv_state_cache_test.cc +++ b/modules/llm-cache/tests/kv_state_cache_test.cc @@ -20,7 +20,7 @@ limitations under the License. 
#include "llm-cache/radix-tree/radix.h" #include "common/util/logging.h" -#include "llm-cache/utils/kv_state_cache_utils.h" +#include "llm-cache/ds/kv_state_cache_manager.h" using namespace vineyard; // NOLINT(build/namespaces) @@ -40,9 +40,12 @@ std::vector round_4_tokens = {1, 2, 3, 4, 5, 6}; std::vector> tokens_list; +KVStateCacheManager* kv_state_cache_manager; + void init(int dimension, int capacity, int layer, int block_size, std::string socket) { - InitKVStateCache(dimension, capacity, layer, block_size, socket); + kv_state_cache_manager = new KVStateCacheManager(dimension, capacity, layer, + block_size, 3, socket); } void print_current_tokens(const std::vector& prefix, int next_token) { @@ -133,14 +136,14 @@ void inference(std::vector tokens, bool block = false) { std::map, std::vector>> kv_state; for (size_t i = 0; i < tokens.size(); ++i) { - kv_state = Query(inference_tokens, tokens[i]); + kv_state = kv_state_cache_manager->Query(inference_tokens, tokens[i]); if (kv_state.size() == 0) { LOG(INFO) << "Can not find the kv_state from cache:"; print_current_tokens(inference_tokens, tokens[i]); LOG(INFO) << "Generate the kv_state and update the cache."; kv_state = generate_kv_state(tokens[i]); print_kv_state(kv_state); - Update(inference_tokens, tokens[i], kv_state); + kv_state_cache_manager->Update(inference_tokens, tokens[i], kv_state); } else { LOG(INFO) << "Find the kv_state from cache:"; print_current_tokens(inference_tokens, tokens[i]); @@ -202,7 +205,7 @@ int main(int argc, char** argv) { } LOG(INFO) << "inference end"; - CloseKVStateCache(); + delete kv_state_cache_manager; LOG(INFO) << "Passed KVStateCache tests..."; return 0; } diff --git a/modules/llm-cache/utils/kv_state_cache_utils.cc b/modules/llm-cache/utils/kv_state_cache_utils.cc deleted file mode 100644 index 90713b3a..00000000 --- a/modules/llm-cache/utils/kv_state_cache_utils.cc +++ /dev/null @@ -1,260 +0,0 @@ -/** Copyright 2020-2023 Alibaba Group Holding Limited. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -#include -#include -#include -#include - -#include "client/client.h" -#include "common/util/logging.h" -#include "llm-cache/ds/kv_state_cache.h" -#include "llm-cache/utils/kv_state_cache_utils.h" - -namespace vineyard { - -static Client client; -static std::shared_ptr kvStateCacheBuilder = nullptr; -static std::string llmCacheSyncLock = "llmCacheSyncLock"; -static std::string llmCacheObjectName = "llm_cache_object"; -static std::thread* syncThread; -static bool exitFlag = false; -static pthread_mutex_t syncMutex; - -#ifndef SYNC_INTERVAL -#define SYNC_INTERVAL 3 -#endif - -// for test -void Delete(std::vector token) { - std::shared_ptr evictedNode; - kvStateCacheBuilder->GetRootTree()->Delete(token, evictedNode); - kvStateCacheBuilder->Delete(evictedNode); - raxShow(kvStateCacheBuilder->GetRootTree()->tree); -} - -void threadFunc(); - -void signalHandler(int signum) { - /* - * TBD - * Avoid dead lock if the client is down when the lock is acquired. - * Use lease to prevent dead lock in the future. 
- */ - LOG(INFO) << "Interrupt signal (" << signum << ") received.\n"; - CloseKVStateCache(); - exit(signum); -} - -void CloseKVStateCache() { - exitFlag = true; - syncThread->join(); -} - -void InitKVStateCache(int dimension, int cacheCapacity, int layer, - int blockSize, std::string socket) { - if (kvStateCacheBuilder == nullptr) { - VLOG(100) << "socket:" << socket; - client.Connect(socket); - - pthread_mutex_init(&syncMutex, NULL); - // TBD - // try to get cache object - std::string actualKey; - bool result; - while (1) { - client.TryAcquireLock(llmCacheSyncLock, result, actualKey); - if (!result) { - VLOG(100) << "failed to gain the lock, wait for next time."; - sleep(1); - continue; - } else { - break; - } - } - - // sync global cache object with vineyard - ObjectID globalKVStateCacheID; - Status status = client.GetName(llmCacheObjectName, globalKVStateCacheID); - if (status.ok()) { - // if success, pull the cache object - std::shared_ptr globalKVStateCache = - std::dynamic_pointer_cast( - client.FetchAndGetObject(globalKVStateCacheID)); - kvStateCacheBuilder = - std::make_shared(client, globalKVStateCache); - } else { - // if failed, create a new cache object - VLOG(100) << "failed to get the cache object, create a new one."; - kvStateCacheBuilder = std::make_shared( - client, dimension, cacheCapacity, layer, blockSize); - } - - // // release the lock - client.TryReleaseLock(actualKey, result); - VINEYARD_ASSERT(result == true); - - syncThread = new std::thread(threadFunc); - - signal(SIGINT, signalHandler); - // TBD - // use lease to prevent the deadlock if the client is down - } -} - -void UpdateInternal(const std::vector& tokenList, int nextToken, - const KV_STATE_WITH_LAYER& kvState) { - kvStateCacheBuilder->Update(client, tokenList, nextToken, kvState); -} - -void Update(const std::vector& tokenList, int nextToken, - const KV_STATE_WITH_LAYER& kvState) { - if (pthread_mutex_trylock(&syncMutex)) { - return; - } - - UpdateInternal(tokenList, nextToken, 
kvState); - - pthread_mutex_unlock(&syncMutex); -} - -void Update(const std::vector& tokenList, - const LIST_KV_STATE_WITH_LAYER& kvState) { - if (pthread_mutex_trylock(&syncMutex)) { - return; - } - std::vector tokenListCopy; - for (size_t i = 0; i < tokenList.size(); i++) { - UpdateInternal(tokenListCopy, tokenList[i], kvState[i]); - tokenListCopy.push_back(tokenList[i]); - } - pthread_mutex_unlock(&syncMutex); -} - -KV_STATE_WITH_LAYER QueryInternal(const std::vector& tokenList, - int token) { - return kvStateCacheBuilder->Query(client, tokenList, token); -} - -KV_STATE_WITH_LAYER Query(const std::vector& tokenList, int token) { - KV_STATE_WITH_LAYER result; - if (pthread_mutex_trylock(&syncMutex)) { - return result; - } - - result = QueryInternal(tokenList, token); - pthread_mutex_unlock(&syncMutex); - - return result; -} - -LIST_KV_STATE_WITH_LAYER Query(const std::vector& tokenList) { - LIST_KV_STATE_WITH_LAYER listKVState; - if (pthread_mutex_trylock(&syncMutex)) { - return listKVState; - } - - std::vector tokenListCopy; - for (size_t i = 0; i < tokenList.size(); i++) { - KV_STATE_WITH_LAYER kvState = QueryInternal(tokenListCopy, tokenList[i]); - listKVState.push_back(kvState); - tokenListCopy.push_back(tokenList[i]); - } - - pthread_mutex_unlock(&syncMutex); - return listKVState; -} - -void sync() { - LOG(INFO) << "Try sync."; - - // 1. gain the lock - std::string actualKey; - bool result; - client.TryAcquireLock(llmCacheSyncLock, result, actualKey); - if (!result) { - LOG(INFO) << "failed to gain the lock, wait for next time"; - return; - } - // 2. pull the cache object - ObjectID globalKVStateCacheID; - std::vector deleteList; - - std::shared_ptr globalKVStateCache = nullptr; - Status status = client.GetName(llmCacheObjectName, globalKVStateCacheID); - if (status.ok()) { - deleteList.push_back(globalKVStateCacheID); - globalKVStateCache = std::dynamic_pointer_cast( - client.FetchAndGetObject(globalKVStateCacheID)); - } - - // 3. 
merge the cache object - // only the global cache object with higher version will be merged - VLOG(100) << "Current builder version:" << kvStateCacheBuilder->GetVersion() - << " global version:" - << (globalKVStateCache == nullptr - ? "null" - : std::to_string(globalKVStateCache->GetVersion())); - if (globalKVStateCache != nullptr && - kvStateCacheBuilder->GetVersion() < globalKVStateCache->GetVersion()) { - kvStateCacheBuilder->Merge(client, globalKVStateCache); - } - kvStateCacheBuilder->UpdateVersion(); - - // 4. push the cache object - std::shared_ptr kvStateCache = kvStateCacheBuilder->_Seal(client); - client.Persist(kvStateCache->id()); - - // 5. put the name of the new cache object to the meta server - client.DropName(llmCacheObjectName); - status = client.PutName(kvStateCache->id(), llmCacheObjectName); - if (!status.ok()) { - throw std::runtime_error("Put cache object name failed."); - } - - // 6. delete old cache object - client.DelData(deleteList); - - // 7. create a global cache object replica - std::dynamic_pointer_cast(kvStateCache)->Resolve(); - kvStateCacheBuilder = std::make_shared( - client, std::dynamic_pointer_cast(kvStateCache)); - - // 8. release the lock - while (1) { - client.TryReleaseLock(actualKey, result); - if (result == true) { - break; - } - sleep(1); - } - - // TBD - // use lease to prevent the deadlock if the client is down -} - -void threadFunc() { - while (1) { - sleep(SYNC_INTERVAL); - if (exitFlag) { - break; - } - pthread_mutex_lock(&syncMutex); - sync(); - pthread_mutex_unlock(&syncMutex); - } -} - -} // namespace vineyard diff --git a/modules/llm-cache/utils/kv_state_cache_utils.h b/modules/llm-cache/utils/kv_state_cache_utils.h deleted file mode 100644 index 21b9d873..00000000 --- a/modules/llm-cache/utils/kv_state_cache_utils.h +++ /dev/null @@ -1,47 +0,0 @@ -/** Copyright 2020-2023 Alibaba Group Holding Limited. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -#include -#include - -#include "llm-cache/ds/kv_state_cache.h" - -#ifndef MODULES_LLM_CACHE_UTILS_KV_STATE_CACHE_UTILS_H_ -#define MODULES_LLM_CACHE_UTILS_KV_STATE_CACHE_UTILS_H_ - -namespace vineyard { - -void InitKVStateCache( - int dimension = 10, int cacheCapacity = 10, int layer = 1, - int blockSize = 5, - std::string socket = std::string(getenv("VINEYARD_IPC_SOCKET"))); - -void Update(const std::vector& tokenList, int nextToken, - const KV_STATE_WITH_LAYER& kvState); - -void Update(const std::vector& tokenList, - const LIST_KV_STATE_WITH_LAYER& kvState); - -KV_STATE_WITH_LAYER Query(const std::vector& tokenList, int token); - -LIST_KV_STATE_WITH_LAYER Query(const std::vector& tokenList); - -void Delete(std::vector token); - -void CloseKVStateCache(); - -} // namespace vineyard - -#endif // MODULES_LLM_CACHE_UTILS_KV_STATE_CACHE_UTILS_H_ diff --git a/src/server/async/socket_server.cc b/src/server/async/socket_server.cc index 36b58ac1..46d257b8 100644 --- a/src/server/async/socket_server.cc +++ b/src/server/async/socket_server.cc @@ -1107,24 +1107,25 @@ bool SocketConnection::doPersist(const json& root) { auto self(shared_from_this()); ObjectID id; TRY_READ_REQUEST(ReadPersistRequest, root, id); - RESPONSE_ON_ERROR(server_ptr_->Persist(id, [self, id](const Status& status) { - std::string message_out; - if (status.ok()) { - WritePersistReply(message_out); - self->doWrite(message_out); - } else if (status.IsEtcdError()) { - // retry 
on etcd error: reprocess the message - VLOG(100) << "Warning: " - << "Retry persist on etcd error: " << status.ToString(); - self->server_ptr_->GetIOContext().post( - [self, id]() { self->doPersist(id); }); - } else { - VLOG(100) << "Error: " << status.ToString(); - WriteErrorReply(status, message_out); - self->doWrite(message_out); - } - return Status::OK(); - })); + RESPONSE_ON_ERROR( + server_ptr_->Persist(id, [self, id, root](const Status& status) { + std::string message_out; + if (status.ok()) { + WritePersistReply(message_out); + self->doWrite(message_out); + } else if (status.IsEtcdError()) { + // retry on etcd error: reprocess the message + VLOG(100) << "Warning: " + << "Retry persist on etcd error: " << status.ToString(); + self->server_ptr_->GetIOContext().post( + [self, id, root]() { self->doPersist(root); }); + } else { + VLOG(100) << "Error: " << status.ToString(); + WriteErrorReply(status, message_out); + self->doWrite(message_out); + } + return Status::OK(); + })); return false; } diff --git a/src/server/server/vineyard_server.cc b/src/server/server/vineyard_server.cc index e49ee899..ed18935c 100644 --- a/src/server/server/vineyard_server.cc +++ b/src/server/server/vineyard_server.cc @@ -1062,7 +1062,6 @@ Status VineyardServer::TryAcquireLock(std::string& key, key, [self, callback](const Status& status, bool result, std::string actual_key) { if (status.ok()) { - LOG(INFO) << "No error occurred. Gain lock:" << result; return callback(status, result, actual_key); } else { return callback(status, result, actual_key); @@ -1079,7 +1078,6 @@ Status VineyardServer::TryReleaseLock(std::string& key, meta_service_ptr_->TryReleaseLock( key, [self, callback](const Status& status, bool result) { if (status.ok()) { - LOG(INFO) << "No error occurred. 
Release lock:" << result; return callback(status, result); } else { return status; From 3fb23762e1b1dadf768e18bf28e4f2063e0ebe43 Mon Sep 17 00:00:00 2001 From: Ye Cao Date: Thu, 29 Feb 2024 21:24:19 +0800 Subject: [PATCH 20/20] Move the radix tree to the thirdparty and update the type of kv_state (#1779) Signed-off-by: Ye Cao --- CMakeLists.txt | 2 +- LICENSE | 2 +- modules/llm-cache/CMakeLists.txt | 2 + modules/llm-cache/ds/kv_state_cache.cc | 36 ++++++-- modules/llm-cache/ds/kv_state_cache.h | 4 +- modules/llm-cache/ds/kv_state_cache_block.cc | 44 ++++------ modules/llm-cache/ds/kv_state_cache_block.h | 23 +++-- .../llm-cache/ds/kv_state_cache_manager.cc | 28 +++--- modules/llm-cache/ds/kv_state_cache_manager.h | 10 ++- modules/llm-cache/radix-tree/radix-tree.cc | 11 +-- modules/llm-cache/radix-tree/radix-tree.h | 2 +- .../tests/kv_state_cache_benchmark_test.cc | 32 +++---- .../tests/kv_state_cache_radix_tree_test.cc | 6 +- .../llm-cache/tests/kv_state_cache_test.cc | 85 ++++++++++++------- .../radix-tree => thirdparty/rax}/radix.cc | 70 ++++++++------- .../radix-tree => thirdparty/rax}/radix.h | 2 +- .../rax}/rax_malloc.h | 0 17 files changed, 193 insertions(+), 166 deletions(-) rename {modules/llm-cache/radix-tree => thirdparty/rax}/radix.cc (98%) rename {modules/llm-cache/radix-tree => thirdparty/rax}/radix.h (99%) rename {modules/llm-cache/radix-tree => thirdparty/rax}/rax_malloc.h (100%) diff --git a/CMakeLists.txt b/CMakeLists.txt index 2dc42d64..4a99e8b1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -981,7 +981,7 @@ endfunction() file_glob_recurse(FILES_NEED_FORMAT DIRECTORIES "src" "modules" "python" "test" "benchmark" PATTERNS ".*\\.(cc|cpp|h|hpp|vineyard-mod)$" - EXCLUDE_PATTERNS "(.*\\.vineyard.h$)|(.*modules/llm-cache/radix-tree/radix\.(cc|h)$)" + EXCLUDE_PATTERNS "(.*\\.vineyard.h$)" ) # the `memcpy.h` is borrowed from external project diff --git a/LICENSE b/LICENSE index 3ff17ec4..483e7f5e 100644 --- a/LICENSE +++ b/LICENSE @@ -1184,7 +1184,7 
@@ SOFTWARE. ------------------------------------------------------------------------------- -The files modules/llm-cache/radix-tree/{radix.cc, radix.h, rax_malloc} is referred from project antirez/rax, +The files thirdparty/rax/{radix.cc, radix.h, rax_malloc} is referred from project antirez/rax, which has the following license: Copyright (c) 2017, Salvatore Sanfilippo diff --git a/modules/llm-cache/CMakeLists.txt b/modules/llm-cache/CMakeLists.txt index afbd347a..e5b9efa2 100644 --- a/modules/llm-cache/CMakeLists.txt +++ b/modules/llm-cache/CMakeLists.txt @@ -3,6 +3,8 @@ file(GLOB VINEYARD_LLM_CACHE_SRCS "${CMAKE_CURRENT_SOURCE_DIR}" "ds/*.h" "radix-tree/*.cc" "radix-tree/*.h" + "${PROJECT_SOURCE_DIR}/thirdparty/rax/*.cc" + "${PROJECT_SOURCE_DIR}/thirdparty/rax/*.h" ) add_library(vineyard_llm_cache ${VINEYARD_LLM_CACHE_SRCS}) diff --git a/modules/llm-cache/ds/kv_state_cache.cc b/modules/llm-cache/ds/kv_state_cache.cc index dc4df8b8..4c1615f8 100644 --- a/modules/llm-cache/ds/kv_state_cache.cc +++ b/modules/llm-cache/ds/kv_state_cache.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include #include +#include #include "client/client.h" #include "common/util/base64.h" @@ -24,7 +25,8 @@ limitations under the License. 
#include "common/util/status.h" #include "llm-cache/ds/kv_state_cache.h" #include "llm-cache/radix-tree/radix-tree.h" -#include "llm-cache/radix-tree/radix.h" + +#include "rax/radix.h" namespace vineyard { @@ -197,12 +199,12 @@ void KVStateCacheBuilder::Update(Client& client, << " bitmap:" << kvStateCacheBlockBuilder->GetBitmapStr(); } -KV_STATE_WITH_LAYER KVStateCacheBuilder::Query( - Client& client, const std::vector& tokenList, int token) { +int KVStateCacheBuilder::Query(Client& client, + const std::vector& tokenList, int token, + KV_STATE_WITH_LAYER& kvState) { std::vector tokenListCopy = tokenList; tokenListCopy.push_back(token); - KV_STATE_WITH_LAYER kvState; std::shared_ptr nodeData = this->rootTree->Query(tokenListCopy); if (nodeData != nullptr) { @@ -214,9 +216,9 @@ KV_STATE_WITH_LAYER KVStateCacheBuilder::Query( (reinterpret_cast(nodeData->treeData->data)) ->kvStateCacheBlockBuilder); - kvStateCacheBlockBuilder->Query(client, offset, kvState); + return kvStateCacheBlockBuilder->Query(client, offset, kvState); } - return kvState; + return -1; } void KVStateCacheBuilder::Delete(std::shared_ptr evictedNodeData) { @@ -273,10 +275,28 @@ void KVStateCacheBuilder::Merge(Client& client, for (auto it = insertTokenList.begin(); it != insertTokenList.end(); ++it) { std::vector tokenList = std::vector((*it).begin(), (*it).end() - 1); - KV_STATE_WITH_LAYER kvState = - globalCacheBuilder->Query(client, tokenList, (*it).back()); + KV_STATE_WITH_LAYER kvState; + for (int currentLayer = 0; currentLayer < this->layer; currentLayer++) { + K_STATE key_state; + V_STATE value_state; + key_state.data = malloc(this->dimension * sizeof(double)); + key_state.length = this->dimension * sizeof(double); + value_state.data = malloc(this->dimension * sizeof(double)); + value_state.length = this->dimension * sizeof(double); + + kvState.insert( + std::make_pair(currentLayer, std::make_pair(key_state, value_state))); + } + globalCacheBuilder->Query(client, tokenList, (*it).back(), 
kvState); this->Update(client, tokenList, (*it).back(), kvState); + for (int currentLayer = 0; currentLayer < this->layer; currentLayer++) { + K_STATE key_state = kvState[currentLayer].first; + V_STATE value_state = kvState[currentLayer].second; + free(key_state.data); + free(value_state.data); + } } + this->version = globalCacheBuilder->GetVersion(); return; } diff --git a/modules/llm-cache/ds/kv_state_cache.h b/modules/llm-cache/ds/kv_state_cache.h index 075a1f61..82e6a76c 100644 --- a/modules/llm-cache/ds/kv_state_cache.h +++ b/modules/llm-cache/ds/kv_state_cache.h @@ -94,8 +94,8 @@ class KVStateCacheBuilder : public vineyard::ObjectBuilder { void Update(Client& client, const std::vector& token_list, int next_token, const KV_STATE_WITH_LAYER& kv_state); - KV_STATE_WITH_LAYER Query(Client& client, const std::vector& token_list, - int token); + int Query(Client& client, const std::vector& token_list, int token, + KV_STATE_WITH_LAYER& kv_state); void Delete(std::shared_ptr evicted_node); diff --git a/modules/llm-cache/ds/kv_state_cache_block.cc b/modules/llm-cache/ds/kv_state_cache_block.cc index 0b251df9..17477143 100644 --- a/modules/llm-cache/ds/kv_state_cache_block.cc +++ b/modules/llm-cache/ds/kv_state_cache_block.cc @@ -129,26 +129,18 @@ KVStateCacheBlockBuilder::KVStateCacheBlockBuilder( } // current we do not consider the layer. 
-Status KVStateCacheBlockBuilder::Query(Client& client, int index, - KV_STATE_WITH_LAYER& kvState) { +int KVStateCacheBlockBuilder::Query(Client& client, int index, + KV_STATE_WITH_LAYER& kvState) { for (int currentLayer = 0; currentLayer < this->layer; currentLayer++) { - std::vector keyStateVector; - std::vector valueStateVector; - - for (int i = 0; i < this->dimension; ++i) { - keyStateVector.push_back((keyStateTensorBuilderList[currentLayer] - ->data())[index * dimension + i]); - } - - for (int i = 0; i < this->dimension; ++i) { - valueStateVector.push_back((valueStateTensorBuilderList[currentLayer] - ->data())[index * dimension + i]); - } - - kvState.insert(std::make_pair( - currentLayer, std::make_pair(keyStateVector, valueStateVector))); + memcpy((kvState.find(currentLayer)->second).first.data, + keyStateTensorBuilderList[currentLayer]->data() + index * dimension, + dimension * sizeof(double)); + memcpy( + (kvState.find(currentLayer)->second).second.data, + valueStateTensorBuilderList[currentLayer]->data() + index * dimension, + dimension * sizeof(double)); } - return Status::OK(); + return 0; } int KVStateCacheBlockBuilder::FindEmptySlot() { @@ -176,18 +168,18 @@ void KVStateCacheBlockBuilder::Update(const KV_STATE_WITH_LAYER& kvState, OffsetData* data) { int index = this->FindEmptySlot(); for (int currentLayer = 0; currentLayer < this->layer; currentLayer++) { - std::vector keyStateVector = - (kvState.find(currentLayer)->second).first; - std::vector valueStateVector = - (kvState.find(currentLayer)->second).second; - VINEYARD_ASSERT(keyStateVector.size() == (size_t) this->dimension); - VINEYARD_ASSERT(valueStateVector.size() == (size_t) this->dimension); + K_STATE keyState = (kvState.find(currentLayer)->second).first; + V_STATE valueState = (kvState.find(currentLayer)->second).second; + VINEYARD_ASSERT(keyState.length == + (size_t) this->dimension * sizeof(double)); + VINEYARD_ASSERT(valueState.length == + (size_t) this->dimension * sizeof(double)); double* 
keyData = keyStateTensorBuilderList[currentLayer]->data(); double* valueData = valueStateTensorBuilderList[currentLayer]->data(); - memcpy(keyData + index * this->dimension, keyStateVector.data(), + memcpy(keyData + index * this->dimension, keyState.data, this->dimension * sizeof(double)); - memcpy(valueData + index * this->dimension, valueStateVector.data(), + memcpy(valueData + index * this->dimension, valueState.data, this->dimension * sizeof(double)); } data->offset = index; diff --git a/modules/llm-cache/ds/kv_state_cache_block.h b/modules/llm-cache/ds/kv_state_cache_block.h index 81747d73..5e0a7262 100644 --- a/modules/llm-cache/ds/kv_state_cache_block.h +++ b/modules/llm-cache/ds/kv_state_cache_block.h @@ -29,14 +29,19 @@ limitations under the License. #include "client/ds/i_object.h" #include "llm-cache/radix-tree/radix-tree.h" -using KV_STATE_WITH_LAYER = - std::map, std::vector>>; -using LIST_KV_STATE_WITH_LAYER = std::vector< - std::map, std::vector>>>; -using KV_STATE = - std::vector, std::vector>>; -using LIST_KV_STATE = - std::vector, std::vector>>; +struct State { + void* data; + size_t length; +}; + +using K_STATE = State; +using V_STATE = State; + +using KV_STATE_WITH_LAYER = std::map>; +using LIST_KV_STATE_WITH_LAYER = + std::vector>>; +using KV_STATE = std::vector>; +using LIST_KV_STATE = std::vector>; // Set the bit to 1, which means the resource is not being used #define FREE_BIT_RESOURCE(value, bit) ((value) |= (((uint64_t) 1) << (bit))) @@ -155,7 +160,7 @@ class KVStateCacheBlockBuilder : public ObjectBuilder { * @param kv_state The kv-state of the prompt returned by radix-tree. If the * kv-state is not found, the data of kv-state is invalid. 
*/ - Status Query(Client& client, int index, KV_STATE_WITH_LAYER& kv_state); + int Query(Client& client, int index, KV_STATE_WITH_LAYER& kv_state); bool IsFull(); diff --git a/modules/llm-cache/ds/kv_state_cache_manager.cc b/modules/llm-cache/ds/kv_state_cache_manager.cc index 551317e7..b13770df 100644 --- a/modules/llm-cache/ds/kv_state_cache_manager.cc +++ b/modules/llm-cache/ds/kv_state_cache_manager.cc @@ -81,9 +81,10 @@ void KVStateCacheManager::UpdateInternal(const std::vector& tokenList, kvStateCacheBuilder->Update(client, tokenList, nextToken, kvState); } -KV_STATE_WITH_LAYER KVStateCacheManager::QueryInternal( - const std::vector& tokenList, int token) { - return kvStateCacheBuilder->Query(client, tokenList, token); +int KVStateCacheManager::QueryInternal(const std::vector& tokenList, + int token, + KV_STATE_WITH_LAYER& kvState) { + return kvStateCacheBuilder->Query(client, tokenList, token, kvState); } void KVStateCacheManager::Update(const std::vector& tokenList, @@ -113,36 +114,35 @@ void KVStateCacheManager::Update(const std::vector& tokenList, syncMutex.unlock(); } -KV_STATE_WITH_LAYER KVStateCacheManager::Query( - const std::vector& tokenList, int token) { - KV_STATE_WITH_LAYER result; +int KVStateCacheManager::Query(const std::vector& tokenList, int token, + KV_STATE_WITH_LAYER& kvState) { + int result = -1; if (!syncMutex.try_lock()) { return result; } - result = QueryInternal(tokenList, token); + result = QueryInternal(tokenList, token, kvState); syncMutex.unlock(); return result; } -LIST_KV_STATE_WITH_LAYER KVStateCacheManager::Query( - const std::vector& tokenList) { - LIST_KV_STATE_WITH_LAYER listKVState; +int KVStateCacheManager::Query(const std::vector& tokenList, + LIST_KV_STATE_WITH_LAYER& listKVState) { + int result = -1; if (!syncMutex.try_lock()) { - return listKVState; + return result; } std::vector tokenListCopy; for (size_t i = 0; i < tokenList.size(); i++) { - KV_STATE_WITH_LAYER kvState = QueryInternal(tokenListCopy, tokenList[i]); 
- listKVState.push_back(kvState); + result = QueryInternal(tokenListCopy, tokenList[i], listKVState[i]); tokenListCopy.push_back(tokenList[i]); } syncMutex.unlock(); - return listKVState; + return result; } KVStateCacheManager::~KVStateCacheManager() { diff --git a/modules/llm-cache/ds/kv_state_cache_manager.h b/modules/llm-cache/ds/kv_state_cache_manager.h index 238ae31a..408cac8a 100644 --- a/modules/llm-cache/ds/kv_state_cache_manager.h +++ b/modules/llm-cache/ds/kv_state_cache_manager.h @@ -51,9 +51,11 @@ class KVStateCacheManager { void Update(const std::vector& tokenList, const LIST_KV_STATE_WITH_LAYER& kvState); - KV_STATE_WITH_LAYER Query(const std::vector& tokenList, int token); + int Query(const std::vector& tokenList, int token, + KV_STATE_WITH_LAYER& kvState); - LIST_KV_STATE_WITH_LAYER Query(const std::vector& tokenList); + int Query(const std::vector& tokenList, + LIST_KV_STATE_WITH_LAYER& listKVState); ~KVStateCacheManager(); @@ -61,8 +63,8 @@ class KVStateCacheManager { void UpdateInternal(const std::vector& tokenList, int nextToken, const KV_STATE_WITH_LAYER& kvState); - KV_STATE_WITH_LAYER QueryInternal(const std::vector& tokenList, - int token); + int QueryInternal(const std::vector& tokenList, int token, + KV_STATE_WITH_LAYER& kvState); void Delete(std::vector token); diff --git a/modules/llm-cache/radix-tree/radix-tree.cc b/modules/llm-cache/radix-tree/radix-tree.cc index 60f14326..93e37a79 100644 --- a/modules/llm-cache/radix-tree/radix-tree.cc +++ b/modules/llm-cache/radix-tree/radix-tree.cc @@ -359,17 +359,14 @@ std::shared_ptr RadixTree::Deserialize(std::string data) { std::string tokenListPart, timestampPart, dataPart, subTreeSizePart; if (!std::getline(lineStream, tokenListPart, '|')) { - throw std::runtime_error( - "Invalid serialized string format in token list part."); + LOG(ERROR) << "Invalid serialized string format in token list part."; } if (isMainTree) { if (!std::getline(lineStream, timestampPart, '|')) { - throw 
std::runtime_error( - "Invalid serialized string format in timestamp part."); + LOG(ERROR) << "Invalid serialized string format in timestamp part."; } if (!std::getline(lineStream, subTreeSizePart, '|')) { - throw std::runtime_error( - "Invalid serialized string format in sub tree size part."); + LOG(ERROR) << "Invalid serialized string format in sub tree size part."; } } if (!std::getline(lineStream, dataPart)) { @@ -471,7 +468,7 @@ std::shared_ptr RadixTree::Deserialize(std::string data) { reinterpret_cast(&dataNode), NULL); if (dataNode == NULL) { - throw std::runtime_error("Insert token list failed"); + LOG(ERROR) << "Insert token list failed"; } dataNode->timestamp = timestampList[i]; } diff --git a/modules/llm-cache/radix-tree/radix-tree.h b/modules/llm-cache/radix-tree/radix-tree.h index 28958234..211b48f9 100644 --- a/modules/llm-cache/radix-tree/radix-tree.h +++ b/modules/llm-cache/radix-tree/radix-tree.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef MODULES_LLM_CACHE_RADIX_TREE_RADIX_TREE_H_ #define MODULES_LLM_CACHE_RADIX_TREE_RADIX_TREE_H_ -#include "llm-cache/radix-tree/radix.h" +#include "rax/radix.h" #include #include diff --git a/modules/llm-cache/tests/kv_state_cache_benchmark_test.cc b/modules/llm-cache/tests/kv_state_cache_benchmark_test.cc index aba51d54..feb1166b 100644 --- a/modules/llm-cache/tests/kv_state_cache_benchmark_test.cc +++ b/modules/llm-cache/tests/kv_state_cache_benchmark_test.cc @@ -18,12 +18,7 @@ limitations under the License. 
#include #include #include -#include "arrow/api.h" -#include "arrow/io/api.h" -#include "basic/stream/byte_stream.h" -#include "basic/stream/dataframe_stream.h" -#include "basic/stream/recordbatch_stream.h" #include "client/client.h" #include "client/ds/object_meta.h" #include "common/util/logging.h" @@ -57,19 +52,15 @@ std::vector generate_random_tokens(size_t max_length) { return tokens; } -std::map, std::vector>> -generate_kv_state(int token) { - std::map, std::vector>> kv_state; +std::map> generate_kv_state(int token) { + std::map> kv_state; for (int currentLayer = 0; currentLayer < LAYER; currentLayer++) { - std::vector key_state; - std::vector value_state; - for (int i = 0; i < DIMENSION; ++i) { - key_state.push_back((static_cast(token)) / DIMENSION * (i + 1) + - currentLayer * 10); - value_state.push_back((static_cast(token)) / DIMENSION * (i + 1) * - 2 + - currentLayer * 10); - } + K_STATE key_state; + V_STATE value_state; + key_state.data = malloc(DIMENSION * sizeof(double)); + key_state.length = DIMENSION * sizeof(double); + value_state.data = malloc(DIMENSION * sizeof(double)); + value_state.length = DIMENSION * sizeof(double); kv_state.insert( std::make_pair(currentLayer, std::make_pair(key_state, value_state))); @@ -80,7 +71,7 @@ generate_kv_state(int token) { // test the performance of Query and Update function void benchmark_inference(std::vector>& tokens) { LOG(INFO) << "inference for benchmark"; - std::map, std::vector>> kv_state; + std::map> kv_state; std::chrono::steady_clock::time_point start, end; double token_list_size = 0; @@ -93,13 +84,12 @@ void benchmark_inference(std::vector>& tokens) { std::vector inference_tokens; for (size_t j = 0; j < tokens[i].size(); ++j) { start = std::chrono::steady_clock::now(); - kv_state = manager->Query(inference_tokens, tokens[i][j]); + kv_state = generate_kv_state(tokens[i][j]); + manager->Query(inference_tokens, tokens[i][j], kv_state); end = std::chrono::steady_clock::now(); query_duration += end - start; 
if (kv_state.size() == 0) { - kv_state = generate_kv_state(tokens[i][j]); - start = std::chrono::steady_clock::now(); manager->Update(inference_tokens, tokens[i][j], kv_state); end = std::chrono::steady_clock::now(); diff --git a/modules/llm-cache/tests/kv_state_cache_radix_tree_test.cc b/modules/llm-cache/tests/kv_state_cache_radix_tree_test.cc index 6e656b3f..d3f4ae78 100644 --- a/modules/llm-cache/tests/kv_state_cache_radix_tree_test.cc +++ b/modules/llm-cache/tests/kv_state_cache_radix_tree_test.cc @@ -17,7 +17,7 @@ limitations under the License. #include #include #include -#include "llm-cache/radix-tree/radix.h" +#include "rax/radix.h" #include "common/util/logging.h" #include "llm-cache/ds/kv_state_cache_manager.h" @@ -119,7 +119,7 @@ void radix_tree_query_test() { VINEYARD_ASSERT(radix_tree->Query(tokens) == NULL); } -void radix_tree_serialize_and_deserailize() { +void radix_tree_serialize_and_deserialize() { std::shared_ptr radix_tree = std::make_shared(10); /* insert a token list*/ @@ -183,7 +183,7 @@ int main() { radix_tree_query_test(); LOG(INFO) << "Finish radix tree query test!"; LOG(INFO) << "Start to test radix tree serialize and deserialize..."; - radix_tree_serialize_and_deserailize(); + radix_tree_serialize_and_deserialize(); LOG(INFO) << "Finish radix tree serialize and deserialize test!"; LOG(INFO) << "Start to test radix tree split..."; radix_tree_split(); diff --git a/modules/llm-cache/tests/kv_state_cache_test.cc b/modules/llm-cache/tests/kv_state_cache_test.cc index 9debcd06..e2d1e98a 100644 --- a/modules/llm-cache/tests/kv_state_cache_test.cc +++ b/modules/llm-cache/tests/kv_state_cache_test.cc @@ -17,7 +17,7 @@ limitations under the License. 
#include #include #include -#include "llm-cache/radix-tree/radix.h" +#include "rax/radix.h" #include "common/util/logging.h" #include "llm-cache/ds/kv_state_cache_manager.h" @@ -58,15 +58,20 @@ void print_current_tokens(const std::vector& prefix, int next_token) { } void print_kv_state( - const std::map, std::vector>>& - kv_state) { + const std::map>& kv_state) { LOG(INFO) << "kv_state: "; for (auto iter = kv_state.begin(); iter != kv_state.end(); ++iter) { std::string key_state_str = ""; std::string value_state_str = ""; for (int i = 0; i < dimension; ++i) { - key_state_str += std::to_string(iter->second.first[i]) + " "; - value_state_str += std::to_string(iter->second.second[i]) + " "; + key_state_str += + std::to_string( + (reinterpret_cast(iter->second.first.data))[i]) + + " "; + value_state_str += + std::to_string( + (reinterpret_cast(iter->second.second.data))[i]) + + " "; } LOG(INFO) << "layer " << iter->first << ":"; LOG(INFO) << "key_state: " << key_state_str; @@ -76,19 +81,16 @@ void print_kv_state( } // we do not consider the layer. 
-std::map, std::vector>> -generate_kv_state(int token) { - std::map, std::vector>> kv_state; +std::map> generate_kv_state() { + std::map> kv_state; for (int currentLayer = 0; currentLayer < layer; currentLayer++) { - std::vector key_state; - std::vector value_state; - for (int i = 0; i < dimension; ++i) { - key_state.push_back((static_cast(token)) / dimension * (i + 1) + - currentLayer * 10); - value_state.push_back((static_cast(token)) / dimension * (i + 1) * - 2 + - currentLayer * 10); - } + K_STATE key_state; + V_STATE value_state; + key_state.data = malloc(dimension * sizeof(double)); + value_state.data = malloc(dimension * sizeof(double)); + + key_state.length = dimension * sizeof(double); + value_state.length = dimension * sizeof(double); kv_state.insert( std::make_pair(currentLayer, std::make_pair(key_state, value_state))); @@ -96,32 +98,50 @@ generate_kv_state(int token) { return kv_state; } -void check_kv_state( - const std::map, std::vector>>& - kv_state, - int& token) { +void update_kv_state(std::map>& kvState, + int token) { + for (int currentLayer = 0; currentLayer < layer; currentLayer++) { + K_STATE key_state = kvState[currentLayer].first; + V_STATE value_state = kvState[currentLayer].second; + for (int i = 0; i < dimension; ++i) { + (reinterpret_cast(key_state.data))[i] = + (static_cast(token)) / dimension * (i + 1) + + currentLayer * 10; + (reinterpret_cast(value_state.data))[i] = + (static_cast(token)) / dimension * (i + 1) * 2 + + currentLayer * 10; + } + } +} + +void check_kv_state(const std::map>& kv_state, + int& token) { VINEYARD_ASSERT(kv_state.size() == (size_t) layer); for (auto iter = kv_state.begin(); iter != kv_state.end(); ++iter) { - VINEYARD_ASSERT(iter->second.first.size() == (size_t) dimension); - VINEYARD_ASSERT(iter->second.second.size() == (size_t) dimension); + VINEYARD_ASSERT(iter->second.first.length == + (size_t) dimension * sizeof(double)); + VINEYARD_ASSERT(iter->second.second.length == + (size_t) dimension * 
sizeof(double)); for (int i = 0; i < dimension; ++i) { - if (iter->second.first[i] != + if ((reinterpret_cast(iter->second.first.data))[i] != (static_cast(token)) / dimension * (i + 1) + iter->first * 10) { LOG(INFO) << "token:" << token << " dimension" << dimension << " layer:" << iter->first; - LOG(INFO) << "key_state[" << i << "]: " << iter->second.first[i] + LOG(INFO) << "key_state[" << i << "]: " + << (reinterpret_cast(iter->second.first.data))[i] << ". But is should be " << (static_cast(token)) / dimension * (i + 1) + iter->first * 10; throw std::runtime_error("key_state error!"); } - if (iter->second.second[i] != + if ((reinterpret_cast(iter->second.second.data))[i] != (static_cast(token)) / dimension * (i + 1) * 2 + iter->first * 10) { LOG(INFO) << "token:" << token << " dimension" << dimension << " layer:" << iter->first; - LOG(INFO) << "value_state[" << i << "]: " << iter->second.second[i] + LOG(INFO) << "value_state[" << i << "]: " + << (reinterpret_cast(iter->second.second.data))[i] << ". 
But is should be " << (static_cast(token)) / dimension * (i + 1) * 2 + iter->first * 10; @@ -133,15 +153,16 @@ void check_kv_state( void inference(std::vector tokens, bool block = false) { std::vector inference_tokens; - std::map, std::vector>> kv_state; - + std::map> kv_state; + kv_state = generate_kv_state(); for (size_t i = 0; i < tokens.size(); ++i) { - kv_state = kv_state_cache_manager->Query(inference_tokens, tokens[i]); - if (kv_state.size() == 0) { + int result = + kv_state_cache_manager->Query(inference_tokens, tokens[i], kv_state); + if (result != 0) { LOG(INFO) << "Can not find the kv_state from cache:"; print_current_tokens(inference_tokens, tokens[i]); LOG(INFO) << "Generate the kv_state and update the cache."; - kv_state = generate_kv_state(tokens[i]); + update_kv_state(kv_state, tokens[i]); print_kv_state(kv_state); kv_state_cache_manager->Update(inference_tokens, tokens[i], kv_state); } else { diff --git a/modules/llm-cache/radix-tree/radix.cc b/thirdparty/rax/radix.cc similarity index 98% rename from modules/llm-cache/radix-tree/radix.cc rename to thirdparty/rax/radix.cc index 6c5658c2..82739625 100644 --- a/modules/llm-cache/radix-tree/radix.cc +++ b/thirdparty/rax/radix.cc @@ -297,7 +297,7 @@ raxNode *raxAddChild(raxNode *n, int c, raxNode **childptr, raxNode ***parentlin size_t curlen = raxNodeCurrentLength(n); n->size++; size_t newlen = raxNodeCurrentLength(n); - n->size--; /* For now restore the orignal size. We'll update it only on + n->size--; /* For now restore the original size. We'll update it only on success at the end. */ /* Alloc the new child we will link to 'n'. */ @@ -645,7 +645,7 @@ int raxGenericInsert(rax *rax, int *s, size_t len, void *data, void **old, int o * * Splitting a compressed node have a few possible cases. 
* Imagine that the node 'h' we are currently at is a compressed - * node contaning the token list [1,2,3,4,5,6,7] (it means that it represents + * node containing the token list [1,2,3,4,5,6,7] (it means that it represents * nodes 1 -> 2 -> 3 -> 4 -> 5 -> 6 -> 7 with the only child * pointer of this node pointing at the 7 node, because remember that * we have tokens at the edges of the graph, not inside the nodes @@ -744,7 +744,7 @@ int raxGenericInsert(rax *rax, int *s, size_t len, void *data, void **old, int o * Let $SPLITPOS be the zero-based index at which, in the * compressed node array of token list, we stopped iterating because * there were no more keys character to match. So in the example of - * the node [1,2,3,4,5,6,7], addig the token list [1,2,3,4], the $SPLITPOS is 4. + * the node [1,2,3,4,5,6,7], adding the token list [1,2,3,4], the $SPLITPOS is 4. * * 1. Save the current compressed node $NEXT pointer (the pointer to the * child element, that is always present in compressed nodes). @@ -757,7 +757,7 @@ int raxGenericInsert(rax *rax, int *s, size_t len, void *data, void **old, int o * * 3. Trim the current node to contain the first $SPLITPOS tokens. * As usually if the new node length is just 1, set iscompr to 0. - * Take the iskey / associated value as it was in the orignal node. + * Take the iskey / associated value as it was in the original node. * Fix the parent's reference. * * 4. 
Set the postfix node as the only child pointer of the trimmed @@ -1341,9 +1341,9 @@ int raxRemove(rax *rax, int *s, size_t len, void **old, bool set_timestamp) { * We try to navigate upward till there are other nodes that can be * compressed, when we reach the upper node which is not a key and has * a single child, we scan the chain of children to collect the - * compressable part of the tree, and replace the current node with the + * compressible part of the tree, and replace the current node with the * new one, fixing the child pointer to reference the first non - * compressable node. + * compressible node. * * Example of case "1". A tree stores the keys [1,2,3] = 1 and * [1,2,3,4,5] = 2: @@ -1505,7 +1505,7 @@ void raxFree(rax *rax) { * to initialize the iterator, and must be followed by a raxSeek() call, * otherwise the raxPrev()/raxNext() functions will just return EOF. */ void raxStart(raxIterator *it, rax *rt) { - it->flags = RAX_ITER_EOF; /* No crash if the iterator is not seeked. */ + it->flags = RAX_ITER_EOF; /* No crash if the iterator is not sought. */ it->rt = rt; it->key_len = 0; it->key = it->key_static_tokens; @@ -1584,7 +1584,7 @@ int raxIteratorNextStep(raxIterator *it, int noup) { if (!noup && children) { debugf("GO DEEPER\n"); /* Seek the lexicographically smaller key in this subtree, which - * is the first one found always going torwards the first child + * is the first one found always going towards the first child * of every successive node. */ if (!raxStackPush(&it->stack,it->node)) return 0; raxNode **cp = raxNodeFirstChildPtr(it->node); @@ -1817,7 +1817,7 @@ int raxIteratorPrevStep(raxIterator *it, int noup) { int raxSeek(raxIterator *it, const char *op, int *ele, size_t len) { int eq = 0, lt = 0, gt = 0, first = 0, last = 0; - it->stack.items = 0; /* Just resetting. Intialized by raxStart(). */ + it->stack.items = 0; /* Just resetting. Initialized by raxStart(). 
*/ it->flags |= RAX_ITER_JUST_SEEKED; it->flags &= ~RAX_ITER_EOF; it->key_len = 0; @@ -2046,9 +2046,9 @@ int raxRandomWalk(raxIterator *it, size_t steps) { } if (steps == 0) { - size_t fle = 1+floor(log(it->rt->numele)); - fle *= 2; - steps = 1 + rand() % fle; + size_t file = 1+floor(log(it->rt->numele)); + file *= 2; + steps = 1 + rand() % file; } raxNode *n = it->node; @@ -2461,21 +2461,19 @@ void raxFindLastRecentNode(raxNode *node, std::vector& key) { return; } - raxNode *chossenChild = childList[0]; - int choosenChildIndex = 0; + raxNode *chosenChild = childList[0]; + int chosenChildIndex = 0; for (int i = 1; i < numChildren; i++) { - if (childList[i]->timestamp == chossenChild->timestamp) { - if (childList[i]->numnodes > chossenChild->numnodes) { - LOG(INFO) << "childList[i]->numnodes > chossenChild->numnodes"; - LOG(INFO) << "node1:" << childList[i] << " node:2" << chossenChild; - chossenChild = childList[i]; - choosenChildIndex = i; + if (childList[i]->timestamp == chosenChild->timestamp) { + if (childList[i]->numnodes > chosenChild->numnodes) { + VLOG(100) << "childList[i]->numnodes > chossenChild->numnodes"; + VLOG(100) << "node1:" << childList[i] << " node:2" << chosenChild; + chosenChild = childList[i]; + chosenChildIndex = i; } - // chossenChild = childList[i]; - // choosenChildIndex = i; - } else if (childList[i]->timestamp < chossenChild->timestamp) { - chossenChild = childList[i]; - choosenChildIndex = i; + } else if (childList[i]->timestamp < chosenChild->timestamp) { + chosenChild = childList[i]; + chosenChildIndex = i; } } @@ -2484,10 +2482,10 @@ void raxFindLastRecentNode(raxNode *node, std::vector& key) { key.push_back(node->data[i]); } } else { - key.push_back(node->data[choosenChildIndex]); + key.push_back(node->data[chosenChildIndex]); } - raxFindLastRecentNode(chossenChild, key); + raxFindLastRecentNode(chosenChild, key); } bool compareKey(int *first_key, int *second_key, int first_key_len, int second_key_len) { @@ -2581,7 +2579,7 @@ void 
mergeTree(rax* first_tree, rax* second_tree, int nodeCount = 0; /** - * We use two structures to store the nodes choosen from the second tree + * We use two structures to store the nodes chosen from the second tree * and the nodes evicted from the first tree. */ while (nodeCount < max_node) { @@ -2620,13 +2618,13 @@ void mergeTree(rax* first_tree, rax* second_tree, /** * Choose first tree node. * If the key is in the record tree, it means that there exist a same key in - * the second tree and has been choosen in the past. So we just need to remove + * the second tree and has been chosen in the past. So we just need to remove * the key from the insert_tokens and update the timestamp of the key in the * first tree. * If the key is not in the record tree, it means that the key has not been - * choosen in the past. So we need to insert the key into the record tree. + * chosen in the past. So we need to insert the key into the record tree. */ - printf("chosse first key %ld : %ld\n", + printf("choose first key %ld : %ld\n", first_tree_iter_list[first_tree_index].node->timestamp, second_tree_iter_list[second_tree_index].node->timestamp); if (raxFind(tmp, first_tree_iter_list[first_tree_index].key, @@ -2650,12 +2648,12 @@ void mergeTree(rax* first_tree, rax* second_tree, /** * Choose second tree node. * If the key is in the record tree, it means that there exist a same key in - * the first tree and has been choosen in the past. So we need do nothing. + * the first tree and has been chosen in the past. So we need do nothing. * If the key is not in the record tree, it means that the key has not been - * choosen in the past. So we need to insert the key into the record tree. + * chosen in the past. So we need to insert the key into the record tree. * and insert the key into the insert_tokens. 
*/ - printf("chosse second key %ld : %ld\n", + printf("choose second key %ld : %ld\n", first_tree_iter_list[first_tree_index].node->timestamp, second_tree_iter_list[second_tree_index].node->timestamp); // choose second key @@ -2680,7 +2678,7 @@ void mergeTree(rax* first_tree, rax* second_tree, */ if (first_tree_iter_list[first_tree_index].node->numnodes <= second_tree_iter_list[second_tree_index].node->numnodes) { - printf("chosse first key %ld : %ld\n", + printf("choose first key %ld : %ld\n", first_tree_iter_list[first_tree_index].node->timestamp, second_tree_iter_list[second_tree_index].node->timestamp); // choose first key @@ -2704,7 +2702,7 @@ void mergeTree(rax* first_tree, rax* second_tree, } first_tree_index++; } else { - printf("chosse second key %ld : %ld\n", + printf("choose second key %ld : %ld\n", first_tree_iter_list[first_tree_index].node->timestamp, second_tree_iter_list[second_tree_index].node->timestamp); // choose second key diff --git a/modules/llm-cache/radix-tree/radix.h b/thirdparty/rax/radix.h similarity index 99% rename from modules/llm-cache/radix-tree/radix.h rename to thirdparty/rax/radix.h index cbfebf2b..3590fb77 100644 --- a/modules/llm-cache/radix-tree/radix.h +++ b/thirdparty/rax/radix.h @@ -136,7 +136,7 @@ typedef struct raxNode { * nodes). * * If the node has an associated key (iskey=1) and is not NULL - * (isnull=0), then after the raxNode pointers poiting to the + * (isnull=0), then after the raxNode pointers pointing to the * children, an additional value pointer is present (as you can see * in the representation above as "value-ptr" field). */ diff --git a/modules/llm-cache/radix-tree/rax_malloc.h b/thirdparty/rax/rax_malloc.h similarity index 100% rename from modules/llm-cache/radix-tree/rax_malloc.h rename to thirdparty/rax/rax_malloc.h