diff --git a/.github/workflows/build-test.yml b/.github/workflows/build-test.yml index 4c1427b4..29ab598e 100644 --- a/.github/workflows/build-test.yml +++ b/.github/workflows/build-test.yml @@ -329,6 +329,16 @@ jobs: if: false uses: mxschmitt/action-tmate@v3 + - name: Run llm tests + run: | + export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib:/usr/local/lib64:/usr/local/lib/x86_64-linux-gnu + export VINEYARD_DATA_DIR=`pwd`/gstest + export TMPDIR="${TMPDIR:-$(dirname $(mktemp))}" + + rm -rf default.etcd + rm -rf /dev/shm/etcd* + python3 test/runner.py $RUNNER_ARGS --with-llm + - name: Run cpp tests run: | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib:/usr/local/lib64:/usr/local/lib/x86_64-linux-gnu diff --git a/CMakeLists.txt b/CMakeLists.txt index b1afa939..4a99e8b1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -62,6 +62,7 @@ option(BUILD_VINEYARD_MALLOC_OVERRIDE "Using the client-side malloc to override option(BUILD_VINEYARD_FUSE "Enable vineyard's fuse support" OFF) option(BUILD_VINEYARD_FUSE_PARQUET "Enable vineyard's fuse parquet support" OFF) option(BUILD_VINEYARD_HOSSEINMOEIN_DATAFRAME "Enable hosseinmoein dataframe support" OFF) +option(BUILD_VINEYARD_LLM_CACHE "Enable kv-state cache support" ON) option(BUILD_VINEYARD_TESTS "Generate make targets for vineyard tests" ON) option(BUILD_VINEYARD_TESTS_ALL "Include make targets for vineyard tests to ALL" OFF) @@ -932,6 +933,11 @@ if(BUILD_VINEYARD_HOSSEINMOEIN_DATAFRAME) list(APPEND VINEYARD_INSTALL_LIBS vineyard_hosseinmoein_dataframe) endif() +if(BUILD_VINEYARD_LLM_CACHE) + add_subdirectory(modules/llm-cache) + list(APPEND VINEYARD_INSTALL_LIBS vineyard_llm_cache) +endif() + if(BUILD_VINEYARD_TESTS) add_subdirectory(test) endif() @@ -975,7 +981,7 @@ endfunction() file_glob_recurse(FILES_NEED_FORMAT DIRECTORIES "src" "modules" "python" "test" "benchmark" PATTERNS ".*\\.(cc|cpp|h|hpp|vineyard-mod)$" - EXCLUDE_PATTERNS ".*\\.vineyard.h$" + EXCLUDE_PATTERNS "(.*\\.vineyard.h$)" ) # the 
`memcpy.h` is borrowed from external project diff --git a/LICENSE b/LICENSE index ab3ae45d..483e7f5e 100644 --- a/LICENSE +++ b/LICENSE @@ -1181,3 +1181,18 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +------------------------------------------------------------------------------- + +The files thirdparty/rax/{radix.cc, radix.h, rax_malloc} are derived from the antirez/rax project, +which has the following license: + +Copyright (c) 2017, Salvatore Sanfilippo +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
diff --git a/NOTICE.txt b/NOTICE.txt index f8186a78..c326975c 100644 --- a/NOTICE.txt +++ b/NOTICE.txt @@ -55,3 +55,7 @@ This product includes software from the ClickHouse project This product includes software from the BBHash project * Copyright (c) 2015 Guillaume Rizk * https://github.com/rizkg/BBHash + +This product includes software from the rax project (BSD, 2-clause) + * Copyright (c) 2017-2019, Salvatore Sanfilippo + * https://github.com/antirez/rax diff --git a/README.rst b/README.rst index 956ace05..a0297818 100644 --- a/README.rst +++ b/README.rst @@ -298,6 +298,7 @@ We thank the following excellent open-source projects: - `skywalking-swck `_ A kubernetes operator for the Apache Skywalking. - `wyhash `_, C++ wrapper around wyhash and wyrand. - `BBHash `_, a fast, minimal-memory perfect hash function. +- `rax `_, an ANSI C radix tree implementation. License ------- diff --git a/modules/llm-cache/CMakeLists.txt b/modules/llm-cache/CMakeLists.txt new file mode 100644 index 00000000..e5b9efa2 --- /dev/null +++ b/modules/llm-cache/CMakeLists.txt @@ -0,0 +1,21 @@ +# Sources of the llm-cache module together with the vendored rax radix tree. +file(GLOB VINEYARD_LLM_CACHE_SRCS + "${CMAKE_CURRENT_SOURCE_DIR}/ds/*.cc" + "${CMAKE_CURRENT_SOURCE_DIR}/ds/*.h" + "${CMAKE_CURRENT_SOURCE_DIR}/radix-tree/*.cc" + "${CMAKE_CURRENT_SOURCE_DIR}/radix-tree/*.h" + "${PROJECT_SOURCE_DIR}/thirdparty/rax/*.cc" + "${PROJECT_SOURCE_DIR}/thirdparty/rax/*.h" +) + +add_library(vineyard_llm_cache ${VINEYARD_LLM_CACHE_SRCS}) +target_link_libraries(vineyard_llm_cache PUBLIC vineyard_client vineyard_basic) + +install_export_vineyard_target(vineyard_llm_cache) +install_vineyard_headers("${CMAKE_CURRENT_SOURCE_DIR}/ds/") +install_vineyard_headers("${CMAKE_CURRENT_SOURCE_DIR}/radix-tree/") + +# Build the module tests only when BUILD_VINEYARD_TESTS is enabled. +if(BUILD_VINEYARD_TESTS) + add_subdirectory(tests) +endif() diff --git a/modules/llm-cache/README.rst b/modules/llm-cache/README.rst new file mode 100644 index 00000000..c5e48f97 --- /dev/null +++ b/modules/llm-cache/README.rst @@ -0,0 +1,27 @@ +KV-state cache on Vineyard +============================= + +Run test +-------- + +Build vineyard and vineyard test + +.. 
code:: bash + mkdir build + cd build + cmake .. -DBUILD_VINEYARD_TESTS=ON -DCMAKE_BUILD_TYPE=Debug + make -j$(nproc) + make vineyard_tests -j$(nproc) + +Start vineyard server + +.. code:: bash + cd build + ./bin/vineyardd --socket=/tmp/vineyard_test.sock # make sure the env VINEYARD_IPC_SOCKET is set properly + +Run test + +.. code:: bash + cd build + export VINEYARD_IPC_SOCKET=/tmp/vineyard_test.sock + ./bin/kv_state_cache_test diff --git a/modules/llm-cache/ds/kv_state_cache.cc b/modules/llm-cache/ds/kv_state_cache.cc new file mode 100644 index 00000000..4c1615f8 --- /dev/null +++ b/modules/llm-cache/ds/kv_state_cache.cc @@ -0,0 +1,377 @@ +/** Copyright 2020-2023 Alibaba Group Holding Limited. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include +#include +#include +#include +#include + +#include "client/client.h" +#include "common/util/base64.h" +#include "common/util/logging.h" +#include "common/util/status.h" +#include "llm-cache/ds/kv_state_cache.h" +#include "llm-cache/radix-tree/radix-tree.h" + +#include "rax/radix.h" + +namespace vineyard { + +void KVStateCache::Construct(const ObjectMeta& meta) { + Object::Construct(meta); + Resolve(); +} + +void KVStateCache::Resolve() { + std::string typeName = type_name(); + + VINEYARD_ASSERT(this->meta_.GetTypeName() == typeName, + "Expect typename '" + typeName + "', but got '" + + this->meta_.GetTypeName() + "'"); + + // 1. 
construct the radix tree + this->rootTree = RadixTree::Deserialize( + base64_decode(this->meta_.GetKeyValue("radix_tree"))); + // raxShow(this->rootTree->GetRootTree()); + + // 2. construct the kvStateCacheBlockBuilder list + size_t numBlocks = this->meta_.GetKeyValue("numBlocks"); + for (size_t i = 0; i < numBlocks; i++) { + std::shared_ptr kvStateCacheBlockObject = this->meta_.GetMember( + "kv_state_cache_block_builder_" + std::to_string(i)); + this->kvStateCacheBlockList.push_back( + std::dynamic_pointer_cast(kvStateCacheBlockObject)); + } + + // 3. construct the member field + this->dimension = this->meta_.GetKeyValue("dimension"); + this->version = this->meta_.GetKeyValue("version"); + this->layer = this->meta_.GetKeyValue("layer"); + VLOG(100) << "construct the member field success, with dimension:" + << this->dimension << " version:" << this->version + << " layer:" << this->layer; +} + +KVStateCache::~KVStateCache() {} + +KVStateCacheBuilder::KVStateCacheBuilder(Client& client, int dimension, + int cacheCapacity, int layer, + int blockSize) { + this->dimension = dimension; + this->version = 0; + this->layer = layer; + KVStateCacheBlockBuilder* builder = + new KVStateCacheBlockBuilder(client, this->dimension, layer, blockSize); + + this->rootTree = std::make_shared(cacheCapacity); + + TreeData* treeData = new TreeData(); + treeData->kvStateCacheBlockBuilder = builder; + treeData->isPtr = true; + + std::shared_ptr rootTreeHeader = this->rootTree->GetRootNode(); + rootTreeHeader->treeData->data = treeData; + rootTreeHeader->treeData->dataLength = sizeof(TreeData); + this->rootTree->SetSubtreeData(treeData); +} + +KVStateCacheBuilder::KVStateCacheBuilder(Client& client, + std::shared_ptr cache) { + this->dimension = cache->GetDimension(); + this->version = cache->GetVersion(); + this->layer = cache->GetLayer(); + // 1. 
create block builder from block + std::vector> kvStateCacheBlockList = + cache->GetKVStateCacheBlockList(); + this->rootTree = cache->GetRootTree(); + std::set subTreeData = cache->rootTree->GetSubTreeDataSet(); + + for (auto iter = subTreeData.begin(); iter != subTreeData.end(); ++iter) { + TreeData* treeData = reinterpret_cast(*iter); + VINEYARD_ASSERT(treeData->isPtr == false); + std::shared_ptr kvStateCacheBlock = + kvStateCacheBlockList[treeData->builderObjectID]; + KVStateCacheBlockBuilder* kvStateCacheBlockBuilder = + new KVStateCacheBlockBuilder(client, kvStateCacheBlock); + + treeData->kvStateCacheBlockBuilder = kvStateCacheBlockBuilder; + treeData->isPtr = true; + } +} + +KVStateCacheBlockBuilder* KVStateCacheBuilder::Split( + Client& client, KVStateCacheBlockBuilder* kvStateCacheBlockBuilder, + std::vector> nodeDataList) { + // Split the tree if the list of kvState is full. + VINEYARD_ASSERT(nodeDataList.size() > 0); + KVStateCacheBlockBuilder* childKVStateCacheBlockBuilder = + new KVStateCacheBlockBuilder(client, this->dimension, this->layer, + kvStateCacheBlockBuilder->GetBlockSize()); + for (size_t i = 0; i < nodeDataList.size(); i++) { + OffsetData* data = + reinterpret_cast(nodeDataList[i]->nodeData->data); + if (data == nullptr) + continue; + int index = data->offset; + + // Transfer the data from this builder to the child builder. 
+ data->offset = + kvStateCacheBlockBuilder->Split(childKVStateCacheBlockBuilder, index); + } + VLOG(100) << "builder:" << kvStateCacheBlockBuilder + << " bitmap:" << kvStateCacheBlockBuilder->GetBitmapStr(); + VLOG(100) << "child_builder:" << childKVStateCacheBlockBuilder + << " bitmap:" << childKVStateCacheBlockBuilder->GetBitmapStr(); + return childKVStateCacheBlockBuilder; +} + +void KVStateCacheBuilder::Update(Client& client, + const std::vector& tokenList, + int nextToken, + const KV_STATE_WITH_LAYER& kvState) { + std::vector tokenListCopy = tokenList; + tokenListCopy.push_back(nextToken); + + // Create a empty node of tokens from radix tree. + std::shared_ptr evictedNodeData = nullptr; + std::shared_ptr nodeData = + this->rootTree->Insert(tokenListCopy, evictedNodeData); + if (nodeData == nullptr) { + return; + } + KVStateCacheBlockBuilder* kvStateCacheBlockBuilder = + reinterpret_cast( + (reinterpret_cast(nodeData->treeData->data)) + ->kvStateCacheBlockBuilder); + if (evictedNodeData != nullptr) { + Delete(evictedNodeData); + } + + if (kvStateCacheBlockBuilder->IsFull()) { + /** + * If the kv-state cache of the tree is full, trigger split. Delete the + * empty node from the radix tree and split the tree. Then, kv-state cache + * split according to the new tree. 
+ */ + VLOG(100) << "trigger splits"; + std::shared_ptr evictedNodeData = nullptr; + this->rootTree->Delete(tokenListCopy, evictedNodeData); + + std::shared_ptr subTreeHeader; + std::vector> nodeDataList = + rootTree->Split(tokenListCopy, subTreeHeader); + KVStateCacheBlockBuilder* newKVStateCacheBlockBuilder = + Split(client, kvStateCacheBlockBuilder, nodeDataList); + + TreeData* newTreeData = new TreeData(); + newTreeData->kvStateCacheBlockBuilder = newKVStateCacheBlockBuilder; + newTreeData->isPtr = true; + + subTreeHeader->treeData->data = newTreeData; + subTreeHeader->treeData->dataLength = sizeof(TreeData); + rootTree->SetSubtreeData(newTreeData); + VLOG(100) << "block split success"; + + // kv_state_cache_builder->UnLock(); + Update(client, tokenList, nextToken, kvState); + } else { + // Update the kv-state cache. + OffsetData* data = new OffsetData(); + kvStateCacheBlockBuilder->Update(kvState, data); + nodeData->nodeData->data = data; + nodeData->nodeData->dataLength = sizeof(OffsetData); + } + + VLOG(100) << "builder:" << kvStateCacheBlockBuilder + << " bitmap:" << kvStateCacheBlockBuilder->GetBitmapStr(); +} + +int KVStateCacheBuilder::Query(Client& client, + const std::vector& tokenList, int token, + KV_STATE_WITH_LAYER& kvState) { + std::vector tokenListCopy = tokenList; + tokenListCopy.push_back(token); + + std::shared_ptr nodeData = this->rootTree->Query(tokenListCopy); + + if (nodeData != nullptr) { + OffsetData* data = reinterpret_cast(nodeData->nodeData->data); + int offset = data->offset; + + KVStateCacheBlockBuilder* kvStateCacheBlockBuilder = + reinterpret_cast( + (reinterpret_cast(nodeData->treeData->data)) + ->kvStateCacheBlockBuilder); + + return kvStateCacheBlockBuilder->Query(client, offset, kvState); + } + return -1; +} + +void KVStateCacheBuilder::Delete(std::shared_ptr evictedNodeData) { + TreeData* treeData = + reinterpret_cast(evictedNodeData->treeData->data); + KVStateCacheBlockBuilder* kvStateCacheBlockBuilder = + reinterpret_cast( 
+ treeData->kvStateCacheBlockBuilder); + OffsetData* data = + reinterpret_cast(evictedNodeData->nodeData->data); + kvStateCacheBlockBuilder->DeleteKVCache(data->offset); + delete data; + // TBD + // Refactor this code. The data should be deleted by the RadixTree + // delete (DataWrapper*) evictedNodeData->nodeData; + if (evictedNodeData->cleanTreeData) { + this->rootTree->ClearSubtreeData(treeData); + delete kvStateCacheBlockBuilder; + } + evictedNodeData->RecycleSource(); +} + +void KVStateCacheBuilder::Merge(Client& client, + std::shared_ptr kvStateCache) { + if (kvStateCache == nullptr) { + return; + } + + std::shared_ptr globalCacheBuilder = + std::make_shared(client, kvStateCache); + std::shared_ptr globalCacheTree = kvStateCache->GetRootTree(); + + std::set> insertTokenList; + std::vector> evicted_token_list; + RadixTree::MergeTree(this->rootTree, globalCacheTree, evicted_token_list, + insertTokenList); + + VLOG(100) << "insert token list size:" << insertTokenList.size() + << " evicted token list size:" << evicted_token_list.size(); + for (size_t i = 0; i < evicted_token_list.size(); i++) { + std::vector tokenList = + evicted_token_list[evicted_token_list.size() - i - 1]; + std::shared_ptr evictedNodeData; + this->rootTree->Delete(tokenList, evictedNodeData); + Delete(evictedNodeData); + } + + /** + * Set use lexicographical order to insert the token list, so the insert token + * list is sorted and will not cause insert failed.(Radix tree will reject a + * insert operation if the prefix of the insert token list is not in the + * tree.) 
+ */ + for (auto it = insertTokenList.begin(); it != insertTokenList.end(); ++it) { + std::vector tokenList = + std::vector((*it).begin(), (*it).end() - 1); + KV_STATE_WITH_LAYER kvState; + for (int currentLayer = 0; currentLayer < this->layer; currentLayer++) { + K_STATE key_state; + V_STATE value_state; + key_state.data = malloc(this->dimension * sizeof(double)); + key_state.length = this->dimension * sizeof(double); + value_state.data = malloc(this->dimension * sizeof(double)); + value_state.length = this->dimension * sizeof(double); + + kvState.insert( + std::make_pair(currentLayer, std::make_pair(key_state, value_state))); + } + globalCacheBuilder->Query(client, tokenList, (*it).back(), kvState); + this->Update(client, tokenList, (*it).back(), kvState); + for (int currentLayer = 0; currentLayer < this->layer; currentLayer++) { + K_STATE key_state = kvState[currentLayer].first; + V_STATE value_state = kvState[currentLayer].second; + free(key_state.data); + free(value_state.data); + } + } + + this->version = globalCacheBuilder->GetVersion(); + return; +} + +Status KVStateCacheBuilder::Build(Client& client) { return Status::OK(); } + +std::shared_ptr KVStateCacheBuilder::_Seal(Client& client) { + this->Build(client); + + std::shared_ptr kvStateCache = std::make_shared(); + + // 1. store the member variables to cache object meta + kvStateCache->meta_.AddKeyValue("dimension", this->dimension); + kvStateCache->meta_.AddKeyValue("version", this->version); + kvStateCache->meta_.AddKeyValue("layer", this->layer); + + // 2. 
seal all the block and put object id to cache object and + // change the tree data from pointer to object id + + int count = 0; + std::set subTreeDataSet = rootTree->GetSubTreeDataSet(); + for (auto iter = subTreeDataSet.begin(); iter != subTreeDataSet.end(); + ++iter) { + TreeData* treeData = reinterpret_cast(*iter); + VINEYARD_ASSERT(treeData != nullptr); + VINEYARD_ASSERT(treeData->isPtr == true); + + KVStateCacheBlockBuilder* kvStateCacheBlockBuilder = + reinterpret_cast( + treeData->kvStateCacheBlockBuilder); + std::shared_ptr kvStateCacheBlock = + kvStateCacheBlockBuilder->_Seal(client); + kvStateCache->meta_.AddMember( + "kv_state_cache_block_builder_" + std::to_string(count), + kvStateCacheBlock); + treeData->builderObjectID = count; + treeData->isPtr = false; + count++; + } + + kvStateCache->meta_.AddKeyValue("numBlocks", count); + + // 3. put the serialized sequence radix tree to cache object meta + kvStateCache->meta_.AddKeyValue("radix_tree", + base64_encode(this->rootTree->Serialize())); + + // 4. put the object type to the meta + kvStateCache->meta_.SetTypeName(type_name()); + + VINEYARD_CHECK_OK( + client.CreateMetaData(kvStateCache->meta_, kvStateCache->id_)); + VLOG(100) << "KVStateCacheBuilder::_Seal: " << kvStateCache->id_; + return kvStateCache; +} + +KVStateCacheBuilder::~KVStateCacheBuilder() { + // get all subtree data and node data + std::set subTreeDataSet = rootTree->GetSubTreeDataSet(); + std::set nodeDataSet = rootTree->GetAllNodeData(); + // 2. 
delete all subtree data and node data + for (auto iter = subTreeDataSet.begin(); iter != subTreeDataSet.end(); + ++iter) { + TreeData* treeData = reinterpret_cast(*iter); + if (treeData->isPtr == true) { + delete reinterpret_cast( + treeData->kvStateCacheBlockBuilder); + delete treeData; + } + } + for (auto iter = nodeDataSet.begin(); iter != nodeDataSet.end(); ++iter) { + OffsetData* data = reinterpret_cast(*iter); + if (data != nullptr) { + delete data; + } + } +} + +} // namespace vineyard diff --git a/modules/llm-cache/ds/kv_state_cache.h b/modules/llm-cache/ds/kv_state_cache.h new file mode 100644 index 00000000..82e6a76c --- /dev/null +++ b/modules/llm-cache/ds/kv_state_cache.h @@ -0,0 +1,123 @@ +/** Copyright 2020-2023 Alibaba Group Holding Limited. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +#include +#include +#include + +#include "client/client.h" +#include "common/util/logging.h" +#include "common/util/status.h" +#include "llm-cache/ds/kv_state_cache_block.h" +#include "llm-cache/radix-tree/radix-tree.h" + +#ifndef MODULES_LLM_CACHE_DS_KV_STATE_CACHE_H_ +#define MODULES_LLM_CACHE_DS_KV_STATE_CACHE_H_ + +namespace vineyard { + +struct TreeData { + union { + void* kvStateCacheBlockBuilder; + uint64_t builderObjectID; + }; + bool isPtr = true; +}; + +class KVStateCache : public vineyard::Registered { + private: + std::vector> kvStateCacheBlockList; + std::shared_ptr rootTree; + int dimension; + int cacheCapacity; + int layer; + uint64_t version; + + public: + static std::unique_ptr Create() __attribute__((used)) { + return std::static_pointer_cast( + std::unique_ptr{new KVStateCache()}); + } + + void Construct(const ObjectMeta& meta) override; + + void Resolve(); + + // for test + std::vector> GetKVStateCacheBlockList() { + return this->kvStateCacheBlockList; + } + + int GetDimension() { return this->dimension; } + + int GetCacheCapacity() { return this->cacheCapacity; } + + uint64_t GetVersion() { return this->version; } + + std::shared_ptr GetRootTree() { return this->rootTree; } + + int GetLayer() { return this->layer; } + + ~KVStateCache(); + + friend class KVStateCacheBuilder; +}; + +class KVStateCacheBuilder : public vineyard::ObjectBuilder { + std::shared_ptr rootTree; + int dimension; + int layer; + uint64_t version; + + public: + KVStateCacheBuilder(Client& client, int dimension, int cacheCapacity, + int layer, int blockSize = DEFAULT_BLOCK_SIZE); + + KVStateCacheBuilder(Client& client, std::shared_ptr cache); + + KVStateCacheBlockBuilder* Split( + Client& client, KVStateCacheBlockBuilder* kvStateCacheBlockBuilder, + std::vector> nodeDataList); + + void Update(Client& client, const std::vector& token_list, + int next_token, const KV_STATE_WITH_LAYER& kv_state); + + int Query(Client& client, const std::vector& token_list, int token, + 
KV_STATE_WITH_LAYER& kv_state); + + void Delete(std::shared_ptr evicted_node); + + void Merge(Client& client, std::shared_ptr kv_state_cache); + + uint64_t GetVersion() { return this->version; } + + void UpdateVersion() { this->version++; } + + Status Build(Client& client) override; + + std::shared_ptr _Seal(Client& client) override; + + uint64_t GetDimension() { return this->dimension; } + + std::shared_ptr GetRootTree() { return this->rootTree; } + + int GetLayer() { return this->layer; } + + ~KVStateCacheBuilder(); +}; + +} // namespace vineyard + +#endif // MODULES_LLM_CACHE_DS_KV_STATE_CACHE_H_ diff --git a/modules/llm-cache/ds/kv_state_cache_block.cc b/modules/llm-cache/ds/kv_state_cache_block.cc new file mode 100644 index 00000000..17477143 --- /dev/null +++ b/modules/llm-cache/ds/kv_state_cache_block.cc @@ -0,0 +1,287 @@ +/** Copyright 2020-2023 Alibaba Group Holding Limited. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include +#include +#include + +#include "client/client.h" +#include "common/util/logging.h" +#include "llm-cache/ds/kv_state_cache_block.h" + +namespace vineyard { + +// this function will be removed in the future +std::string KVStateCacheBlock::GetBitmapStr() { + std::string result; + const int bits = 8 * sizeof(uint64_t); + for (int i = 0; i < this->bitmapSize; i++) { + for (int j = bits - 1; j >= 0; --j) { + result += (((this->bitmap[i]) >> j) & 1) ? 
'1' : '0'; + } + } + return result; +} + +std::string KVStateCacheBlockBuilder::GetBitmapStr() { + std::string result; + const int bits = 8 * sizeof(uint64_t); + for (int i = 0; i < this->bitmapSize; i++) { + for (int j = bits - 1; j >= 0; --j) { + result += (((this->bitmap[i]) >> j) & 1) ? '1' : '0'; + } + } + return result; +} + +void KVStateCacheBlock::Construct(const ObjectMeta& meta) { + Object::Construct(meta); + + std::string typeName = type_name(); + + VINEYARD_ASSERT(meta.GetTypeName() == typeName, + "Expect typename '" + typeName + "', but got '" + + meta.GetTypeName() + "'"); + + // TBD + // 1. construct the keyStateTensorBuilder and valueStateTensorBuilder + this->layer = this->meta_.GetKeyValue("layer"); + for (int currentLayer = 0; currentLayer < this->layer; currentLayer++) { + this->keyStateTensorList.push_back( + std::dynamic_pointer_cast>(this->meta_.GetMember( + "keyStateTensorBuilder_" + std::to_string(currentLayer)))); + this->valueStateTensorList.push_back( + std::dynamic_pointer_cast>(this->meta_.GetMember( + "valueStateTensorBuilder_" + std::to_string(currentLayer)))); + } + // 2. 
construct the member field + this->bitmapSize = this->meta_.GetKeyValue("bitmap_size"); + VLOG(100) << "construct bitmap size:" << this->bitmapSize; + this->bitmap = new uint64_t[this->bitmapSize]; + for (int i = 0; i < this->bitmapSize; i++) { + this->bitmap[i] = + this->meta_.GetKeyValue("bitmap_" + std::to_string(i)); + } + this->dimension = this->meta_.GetKeyValue("dimension"); + this->blockSize = this->meta_.GetKeyValue("block_size"); +} +// bitmap is allocated with new[] in Construct(), so release it with delete[]. +KVStateCacheBlock::~KVStateCacheBlock() { delete[] this->bitmap; } + +KVStateCacheBlockBuilder::KVStateCacheBlockBuilder(Client& client, + int dimension, int layer, + int blockSize) { + this->blockSize = blockSize; + this->bitmapSize = (blockSize + 63) / 64; + this->bitmap = new uint64_t[this->bitmapSize]; + memset(this->bitmap, UINT8_MAX, this->bitmapSize * sizeof(uint64_t)); + std::vector shape = {(int64_t)(blockSize), dimension}; + for (int i = 0; i < layer; i++) { + this->keyStateTensorBuilderList.push_back( + std::make_shared>(client, shape)); + this->valueStateTensorBuilderList.push_back( + std::make_shared>(client, shape)); + } + this->dimension = dimension; + this->layer = layer; +} + +KVStateCacheBlockBuilder::KVStateCacheBlockBuilder( + Client& client, std::shared_ptr kvStateCacheBlock) { + this->bitmapSize = kvStateCacheBlock->bitmapSize; + this->blockSize = kvStateCacheBlock->blockSize; + VLOG(100) << "create builder from block object, bitmap size:" + << this->bitmapSize << " block size:" << blockSize; + this->bitmap = new uint64_t[this->bitmapSize]; + for (int i = 0; i < this->bitmapSize; i++) { + this->bitmap[i] = kvStateCacheBlock->bitmap[i]; + } + this->dimension = kvStateCacheBlock->dimension; + this->layer = kvStateCacheBlock->layer; + std::vector shape = {(int64_t)(blockSize), dimension}; + for (int currentLayer = 0; currentLayer < this->layer; currentLayer++) { + this->keyStateTensorBuilderList.push_back( + std::make_shared>(client, shape)); + this->valueStateTensorBuilderList.push_back( + 
std::make_shared>(client, shape)); + } + + for (int currentLayer = 0; currentLayer < this->layer; currentLayer++) { + memcpy(this->keyStateTensorBuilderList[currentLayer]->data(), + kvStateCacheBlock->keyStateTensorList[currentLayer]->data(), + (int64_t)(blockSize) * this->dimension * sizeof(double)); + memcpy(this->valueStateTensorBuilderList[currentLayer]->data(), + kvStateCacheBlock->valueStateTensorList[currentLayer]->data(), + (int64_t)(blockSize) * this->dimension * sizeof(double)); + } +} + +// current we do not consider the layer. +int KVStateCacheBlockBuilder::Query(Client& client, int index, + KV_STATE_WITH_LAYER& kvState) { + for (int currentLayer = 0; currentLayer < this->layer; currentLayer++) { + memcpy((kvState.find(currentLayer)->second).first.data, + keyStateTensorBuilderList[currentLayer]->data() + index * dimension, + dimension * sizeof(double)); + memcpy( + (kvState.find(currentLayer)->second).second.data, + valueStateTensorBuilderList[currentLayer]->data() + index * dimension, + dimension * sizeof(double)); + } + return 0; +} + +int KVStateCacheBlockBuilder::FindEmptySlot() { + for (int i = 0; i < this->bitmapSize; i++) { + if (this->bitmap[i] != 0) { + int index = ffsll(this->bitmap[i]) - 1; + return index + i * 64; + } + } + return -1; +} + +bool KVStateCacheBlockBuilder::IsFull() { + int left = this->blockSize; + for (int i = 0; i < this->bitmapSize; i++) { + if (this->bitmap[i] != 0 && ffsll(this->bitmap[i]) - 1 < left) { + return false; + } + left -= sizeof(uint64_t) * 8; + } + return true; +} + +void KVStateCacheBlockBuilder::Update(const KV_STATE_WITH_LAYER& kvState, + OffsetData* data) { + int index = this->FindEmptySlot(); + for (int currentLayer = 0; currentLayer < this->layer; currentLayer++) { + K_STATE keyState = (kvState.find(currentLayer)->second).first; + V_STATE valueState = (kvState.find(currentLayer)->second).second; + VINEYARD_ASSERT(keyState.length == + (size_t) this->dimension * sizeof(double)); + 
VINEYARD_ASSERT(valueState.length == + (size_t) this->dimension * sizeof(double)); + + double* keyData = keyStateTensorBuilderList[currentLayer]->data(); + double* valueData = valueStateTensorBuilderList[currentLayer]->data(); + memcpy(keyData + index * this->dimension, keyState.data, + this->dimension * sizeof(double)); + memcpy(valueData + index * this->dimension, valueState.data, + this->dimension * sizeof(double)); + } + data->offset = index; + + ACQUIRE_BIT_RESOURCE(this->bitmap[index / 64], index % 64); +} + +int16_t KVStateCacheBlockBuilder::Split(KVStateCacheBlockBuilder* child, + int index) { + // TBD + VINEYARD_ASSERT(this->layer == child->layer); + int childIndex = child->FindEmptySlot(); + for (int currentLayer = 0; currentLayer < this->layer; currentLayer++) { + std::shared_ptr> keyStateTensorBuilder = + keyStateTensorBuilderList[currentLayer]; + std::shared_ptr> valueStateTensorBuilder = + valueStateTensorBuilderList[currentLayer]; + std::shared_ptr> childKeyStateTensorBuilder = + child->keyStateTensorBuilderList[currentLayer]; + std::shared_ptr> childValueStateTensorBuilder = + child->valueStateTensorBuilderList[currentLayer]; + + double* keyState = keyStateTensorBuilder->data() + index * this->dimension; + double* valueState = + valueStateTensorBuilder->data() + index * this->dimension; + double* childKeyState = + childKeyStateTensorBuilder->data() + childIndex * this->dimension; + double* childValueState = + childValueStateTensorBuilder->data() + childIndex * this->dimension; + + memcpy(childKeyState, keyState, this->dimension * sizeof(double)); + memcpy(childValueState, valueState, this->dimension * sizeof(double)); + } + ACQUIRE_BIT_RESOURCE(child->bitmap[childIndex / 64], childIndex % 64); + FREE_BIT_RESOURCE(this->bitmap[index / 64], index % 64); + return childIndex; +} + +Status KVStateCacheBlockBuilder::Build(Client& client) { return Status::OK(); } + +std::shared_ptr KVStateCacheBlockBuilder::_Seal(Client& client) { + this->Build(client); + 
+ std::shared_ptr kvStateCacheBlock = + std::make_shared(); + + // 1. seal keyStateTensorBuilder and valueStateTensorBuilder + for (int currentLayer = 0; currentLayer < this->layer; currentLayer++) { + kvStateCacheBlock->meta_.AddMember( + "keyStateTensorBuilder_" + std::to_string(currentLayer), + keyStateTensorBuilderList[currentLayer]->Seal(client)); + kvStateCacheBlock->meta_.AddMember( + "valueStateTensorBuilder_" + std::to_string(currentLayer), + valueStateTensorBuilderList[currentLayer]->Seal(client)); + } + + // 2. store the member field to meta + kvStateCacheBlock->meta_.AddKeyValue("bitmap_size", this->bitmapSize); + for (int i = 0; i < this->bitmapSize; i++) { + kvStateCacheBlock->meta_.AddKeyValue("bitmap_" + std::to_string(i), + this->bitmap[i]); + } + + kvStateCacheBlock->meta_.AddKeyValue("block_size", this->blockSize); + kvStateCacheBlock->meta_.AddKeyValue("dimension", this->dimension); + kvStateCacheBlock->meta_.AddKeyValue("layer", this->layer); + // 3. set the object type to meta + kvStateCacheBlock->meta_.SetTypeName(type_name()); + + VINEYARD_CHECK_OK( + client.CreateMetaData(kvStateCacheBlock->meta_, kvStateCacheBlock->id_)); + return kvStateCacheBlock; +} + +void KVStateCacheBlockBuilder::PrintKVStateCacheBlock() { + LOG(INFO) << "builder:" << this; + for (int i = 0; i < this->blockSize; i++) { + LOG(INFO) << "index:" << i << " bitmap:" << this->GetBitmapStr(); + } + + for (int currentLayer = 0; currentLayer < this->layer; currentLayer++) { + LOG(INFO) << "layer:" << currentLayer; + for (int i = 0; i < this->blockSize; i++) { + LOG(INFO) << "index:" << i; + std::string keyState = ""; + std::string valueState = ""; + for (int j = 0; j < this->dimension; j++) { + keyState += std::to_string((keyStateTensorBuilderList[currentLayer] + ->data())[i * dimension + j]) + + " "; + valueState += std::to_string((valueStateTensorBuilderList[currentLayer] + ->data())[i * dimension + j]) + + " "; + } + LOG(INFO) << "keyState:" << keyState; + LOG(INFO) << 
"valueState:" << valueState; + } + } + + LOG(INFO) << "=========================="; +} +// bitmap is allocated with new[] in both constructors; release with delete[]. +KVStateCacheBlockBuilder::~KVStateCacheBlockBuilder() { delete[] this->bitmap; } + +} // namespace vineyard diff --git a/modules/llm-cache/ds/kv_state_cache_block.h b/modules/llm-cache/ds/kv_state_cache_block.h new file mode 100644 index 00000000..5e0a7262 --- /dev/null +++ b/modules/llm-cache/ds/kv_state_cache_block.h @@ -0,0 +1,210 @@ +/** Copyright 2020-2023 Alibaba Group Holding Limited. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +#ifndef MODULES_LLM_CACHE_DS_KV_STATE_CACHE_BLOCK_H_ +#define MODULES_LLM_CACHE_DS_KV_STATE_CACHE_BLOCK_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "basic/ds/tensor.h" +#include "client/ds/blob.h" +#include "client/ds/i_object.h" +#include "llm-cache/radix-tree/radix-tree.h" + +struct State { + void* data; + size_t length; +}; + +using K_STATE = State; +using V_STATE = State; + +using KV_STATE_WITH_LAYER = std::map>; +using LIST_KV_STATE_WITH_LAYER = + std::vector>>; +using KV_STATE = std::vector>; +using LIST_KV_STATE = std::vector>; + +// Set the bit to 1, which means the resource is not being used +#define FREE_BIT_RESOURCE(value, bit) ((value) |= (((uint64_t) 1) << (bit))) + +// Set the bit to 0, which means the resource is being used +#define ACQUIRE_BIT_RESOURCE(value, bit) \ + ((value) &= (~(((uint64_t) 1) << (bit)))) + +struct OffsetData { + int16_t offset; +}; +namespace vineyard { + +#define DEFAULT_BLOCK_SIZE 64 + +/** + * @brief KVStateCacheBlock is a cache for kv-cache of LLM. When a new prompt + * comes, LLM can query KVStateCacheBlock to get the state of the kv-cache to + * avoid calculate the kv-cache again if the new prompt is similar to the + * previous one. + * + * KVStateCacheBlock is stored in vineyard as a vineyard object which contains a + * radix tree. The token sequence is the key of the radix tree and the value + * point out the offset of the kv-cache in the tensor list. + * + * KVStateCacheBlock can be shared by multiple machines. 
+ */ + +class KVStateCacheBlock : public vineyard::Registered { + private: + std::vector>> keyStateTensorList; + std::vector>> valueStateTensorList; + uint64_t* bitmap; + int blockSize; + int bitmapSize; + ObjectID id; + int layer; + int dimension; + + public: + static std::unique_ptr Create() __attribute__((used)) { + return std::static_pointer_cast( + std::unique_ptr{new KVStateCacheBlock()}); + } + + void Construct(const ObjectMeta& meta) override; + + std::string GetBitmapStr(); + + uint64_t GetDimension() { return this->dimension; } + + uint64_t* GetBitmap() { return this->bitmap; } + + int GetBlockSize() { return this->blockSize; } + + std::shared_ptr> GetKeyTensor(int layer) { + return this->keyStateTensorList[layer]; + } + + std::shared_ptr> GetValueTensor(int layer) { + return this->valueStateTensorList[layer]; + } + + std::vector>> GetKeyTensorList() { + return this->keyStateTensorList; + } + + std::vector>> GetValueTensorList() { + return this->valueStateTensorList; + } + + ~KVStateCacheBlock(); + + friend class KVStateCacheBlockBuilder; +}; + +class KVStateCacheBlockBuilder : public ObjectBuilder { + private: + std::vector>> keyStateTensorBuilderList; + std::vector>> + valueStateTensorBuilderList; + // TBD + // support more than 64 kv-state cache slots + uint64_t* bitmap; + int blockSize; + int bitmapSize; + int dimension; + int layer; + + int FindEmptySlot(); + + public: + KVStateCacheBlockBuilder(Client& client, int dimension, int layer, + int blockSize); + + KVStateCacheBlockBuilder( + Client& client, std::shared_ptr kv_state_cache_block); + + /** + * @brief Update the kv-state using next token. + * + * @param client The vineyard client. + * @param kv_state The kv-state of the prompt. A LLM inference can contain + * multiple kv-states for each layer. 
+ */ + void Update(const KV_STATE_WITH_LAYER& kv_state, OffsetData* data); + + void Update(double* keyState, double* valueState, uint64_t dataLength, + OffsetData* data); + + /** + * @brief Query the kv-state using the whole token list. + * + * @param token_list The token list of the prompt. + * @param token The token of the prompt. + * @param kv_state The kv-state of the prompt returned by radix-tree. If the + * kv-state is not found, the data of kv-state is invalid. + */ + int Query(Client& client, int index, KV_STATE_WITH_LAYER& kv_state); + + bool IsFull(); + + Status Build(Client& client) override; + + std::shared_ptr _Seal(Client& client) override; + + int16_t Split(KVStateCacheBlockBuilder* child, int index); + + const std::shared_ptr> GetKeyStateBuilder(int layer) { + return keyStateTensorBuilderList[layer]; + } + + const std::shared_ptr> GetValueStateBuilder(int layer) { + return valueStateTensorBuilderList[layer]; + } + + const std::vector>> + GetKeyStateBuilderList() { + return keyStateTensorBuilderList; + } + + const std::vector>> + GetValueStateBuilderList() { + return valueStateTensorBuilderList; + } + + void DeleteKVCache(int bit) { + FREE_BIT_RESOURCE(this->bitmap[bit / 64], bit % 64); + } + + std::string GetBitmapStr(); + + uint64_t* GetBitmap() { return this->bitmap; } + + uint64_t GetDimension() { return this->dimension; } + + int GetBlockSize() { return this->blockSize; } + + void PrintKVStateCacheBlock(); + + ~KVStateCacheBlockBuilder(); +}; + +} // namespace vineyard + +#endif // MODULES_LLM_CACHE_DS_KV_STATE_CACHE_BLOCK_H_ diff --git a/modules/llm-cache/ds/kv_state_cache_manager.cc b/modules/llm-cache/ds/kv_state_cache_manager.cc new file mode 100644 index 00000000..b13770df --- /dev/null +++ b/modules/llm-cache/ds/kv_state_cache_manager.cc @@ -0,0 +1,265 @@ +/** Copyright 2020-2023 Alibaba Group Holding Limited. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include +#include +#include +#include + +#include "client/client.h" +#include "common/util/logging.h" +#include "llm-cache/ds/kv_state_cache.h" +#include "llm-cache/ds/kv_state_cache_manager.h" + +namespace vineyard { + +KVStateCacheManager::KVStateCacheManager(int dimension, int cacheCapacity, + int layer, int blockSize, + int syncInterval, std::string socket) { + this->syncInterval = syncInterval; + VLOG(100) << "socket:" << socket; + client.Connect(socket); + + // TBD + // try to get cache object + std::string actualKey; + bool result; + while (1) { + client.TryAcquireLock(llmCacheSyncLock, result, actualKey); + if (!result) { + VLOG(100) << "failed to gain the lock, wait for next time."; + sleep(1); + continue; + } else { + break; + } + } + + // sync global cache object with vineyard + ObjectID globalKVStateCacheID; + Status status = client.GetName(llmCacheObjectName, globalKVStateCacheID); + if (status.ok()) { + // if success, pull the cache object + std::shared_ptr globalKVStateCache = + std::dynamic_pointer_cast( + client.FetchAndGetObject(globalKVStateCacheID)); + kvStateCacheBuilder = + std::make_shared(client, globalKVStateCache); + } else { + // if failed, create a new cache object + VLOG(100) << "failed to get the cache object, create a new one."; + kvStateCacheBuilder = std::make_shared( + client, dimension, cacheCapacity, layer, blockSize); + } + + // release the lock + client.TryReleaseLock(actualKey, result); + 
VINEYARD_ASSERT(result == true); + + // syncThread = new std::thread(threadFunc); + syncThread = new std::thread(SyncThreadFunc, this); + + // TBD + // use lease to prevent the deadlock if the client is down +} + +void KVStateCacheManager::UpdateInternal(const std::vector& tokenList, + int nextToken, + const KV_STATE_WITH_LAYER& kvState) { + kvStateCacheBuilder->Update(client, tokenList, nextToken, kvState); +} + +int KVStateCacheManager::QueryInternal(const std::vector& tokenList, + int token, + KV_STATE_WITH_LAYER& kvState) { + return kvStateCacheBuilder->Query(client, tokenList, token, kvState); +} + +void KVStateCacheManager::Update(const std::vector& tokenList, + int nextToken, + const KV_STATE_WITH_LAYER& kvState) { + if (!syncMutex.try_lock()) { + return; + } + + UpdateInternal(tokenList, nextToken, kvState); + + syncMutex.unlock(); +} + +void KVStateCacheManager::Update(const std::vector& tokenList, + const LIST_KV_STATE_WITH_LAYER& kvState) { + if (!syncMutex.try_lock()) { + return; + } + + std::vector tokenListCopy; + for (size_t i = 0; i < tokenList.size(); i++) { + UpdateInternal(tokenListCopy, tokenList[i], kvState[i]); + tokenListCopy.push_back(tokenList[i]); + } + + syncMutex.unlock(); +} + +int KVStateCacheManager::Query(const std::vector& tokenList, int token, + KV_STATE_WITH_LAYER& kvState) { + int result = -1; + + if (!syncMutex.try_lock()) { + return result; + } + + result = QueryInternal(tokenList, token, kvState); + syncMutex.unlock(); + + return result; +} + +int KVStateCacheManager::Query(const std::vector& tokenList, + LIST_KV_STATE_WITH_LAYER& listKVState) { + int result = -1; + if (!syncMutex.try_lock()) { + return result; + } + + std::vector tokenListCopy; + for (size_t i = 0; i < tokenList.size(); i++) { + result = QueryInternal(tokenListCopy, tokenList[i], listKVState[i]); + tokenListCopy.push_back(tokenList[i]); + } + + syncMutex.unlock(); + return result; +} + +KVStateCacheManager::~KVStateCacheManager() { + LOG(INFO) << "Wait for 
sync thread to exit."; + { + std::lock_guard lock(exitMutex); + exitFlag = true; + } + cv.notify_one(); + syncThread->join(); + delete syncThread; + LOG(INFO) << "KVStateCacheManager exit."; +} + +// This function is used for testing +void KVStateCacheManager::Delete(std::vector token) { + std::shared_ptr evictedNode; + kvStateCacheBuilder->GetRootTree()->Delete(token, evictedNode); + kvStateCacheBuilder->Delete(evictedNode); + raxShow(kvStateCacheBuilder->GetRootTree()->tree); +} + +void KVStateCacheManager::Sync() { + // 1. gain the lock + std::string actualKey; + bool result; + client.TryAcquireLock(llmCacheSyncLock, result, actualKey); + if (!result) { + LOG(INFO) << "failed to gain the lock, wait for next time"; + return; + } + // 2. pull the cache object + ObjectID globalKVStateCacheID; + std::vector deleteList; + + std::shared_ptr globalKVStateCache = nullptr; + Status status = client.GetName(llmCacheObjectName, globalKVStateCacheID); + if (status.ok()) { + deleteList.push_back(globalKVStateCacheID); + globalKVStateCache = std::dynamic_pointer_cast( + client.FetchAndGetObject(globalKVStateCacheID)); + } + + // 3. merge the cache object + // only the global cache object with higher version will be merged + VLOG(100) << "Current builder version:" << kvStateCacheBuilder->GetVersion() + << " global version:" + << (globalKVStateCache == nullptr + ? "null" + : std::to_string(globalKVStateCache->GetVersion())); + if (globalKVStateCache != nullptr && + kvStateCacheBuilder->GetVersion() < globalKVStateCache->GetVersion()) { + kvStateCacheBuilder->Merge(client, globalKVStateCache); + } + kvStateCacheBuilder->UpdateVersion(); + + // 4. push the cache object + std::shared_ptr kvStateCache = kvStateCacheBuilder->_Seal(client); + client.Persist(kvStateCache->id()); + + // 5. 
put the name of the new cache object to the meta server + client.DropName(llmCacheObjectName); + status = client.PutName(kvStateCache->id(), llmCacheObjectName); + if (!status.ok()) { + throw std::runtime_error("Put cache object name failed."); + } + + // 6. delete old cache object + client.DelData(deleteList, true, true); + + // 7. create a global cache object replica + std::dynamic_pointer_cast(kvStateCache)->Resolve(); + kvStateCacheBuilder = std::make_shared( + client, std::dynamic_pointer_cast(kvStateCache)); + + // 8. release the lock + while (1) { + client.TryReleaseLock(actualKey, result); + if (result == true) { + break; + } + sleep(1); + } + + // TBD + // use lease to prevent the deadlock if the client is down +} + +void KVStateCacheManager::SyncThreadFunc(KVStateCacheManager* manager) { + uint64_t last_time = std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count(); + while (1) { + std::unique_lock lock(manager->exitMutex); + if (manager->cv.wait_for( + lock, std::chrono::seconds(manager->syncInterval), + [&manager, &last_time] { + uint64_t current_time = + std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count(); + return manager->exitFlag || + static_cast(current_time - last_time) >= + manager->syncInterval; + })) { + if (manager->exitFlag) { + break; + } + manager->syncMutex.lock(); + manager->Sync(); + manager->syncMutex.unlock(); + last_time = std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count(); + } + } + LOG(INFO) << "Sync thread exit."; +} + +} // namespace vineyard diff --git a/modules/llm-cache/ds/kv_state_cache_manager.h b/modules/llm-cache/ds/kv_state_cache_manager.h new file mode 100644 index 00000000..408cac8a --- /dev/null +++ b/modules/llm-cache/ds/kv_state_cache_manager.h @@ -0,0 +1,78 @@ +/** Copyright 2020-2023 Alibaba Group Holding Limited. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include +#include +#include +#include +#include + +#include "llm-cache/ds/kv_state_cache.h" + +#ifndef MODULES_LLM_CACHE_DS_KV_STATE_CACHE_MANAGER_H_ +#define MODULES_LLM_CACHE_DS_KV_STATE_CACHE_MANAGER_H_ + +namespace vineyard { + +class KVStateCacheManager { + private: + Client client; + std::shared_ptr kvStateCacheBuilder = nullptr; + std::string llmCacheSyncLock = "llmCacheSyncLock"; + std::string llmCacheObjectName = "llm_cache_object"; + std::thread* syncThread; + std::mutex syncMutex; + int syncInterval; + bool exitFlag = false; + std::condition_variable cv; + std::mutex exitMutex; + + public: + KVStateCacheManager( + int dimension = 10, int cacheCapacity = 10, int layer = 1, + int blockSize = 5, int syncInterval = 3, + std::string socket = std::string(getenv("VINEYARD_IPC_SOCKET"))); + + void Update(const std::vector& tokenList, int nextToken, + const KV_STATE_WITH_LAYER& kvState); + + void Update(const std::vector& tokenList, + const LIST_KV_STATE_WITH_LAYER& kvState); + + int Query(const std::vector& tokenList, int token, + KV_STATE_WITH_LAYER& kvState); + + int Query(const std::vector& tokenList, + LIST_KV_STATE_WITH_LAYER& listKVState); + + ~KVStateCacheManager(); + + private: + void UpdateInternal(const std::vector& tokenList, int nextToken, + const KV_STATE_WITH_LAYER& kvState); + + int QueryInternal(const std::vector& tokenList, int token, + KV_STATE_WITH_LAYER& kvState); + + void Delete(std::vector token); + + static 
void SyncThreadFunc(KVStateCacheManager* manager); + + void Sync(); +}; + +} // namespace vineyard + +#endif // MODULES_LLM_CACHE_DS_KV_STATE_CACHE_MANAGER_H_ diff --git a/modules/llm-cache/radix-tree/radix-tree.cc b/modules/llm-cache/radix-tree/radix-tree.cc new file mode 100644 index 00000000..93e37a79 --- /dev/null +++ b/modules/llm-cache/radix-tree/radix-tree.cc @@ -0,0 +1,608 @@ +/** Copyright 2020-2023 Alibaba Group Holding Limited. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include "llm-cache/radix-tree/radix-tree.h" + +#include "common/util/base64.h" +#include "common/util/logging.h" +#include "common/util/status.h" + +#include "zstd/lib/zstd.h" + +using namespace vineyard; // NOLINT(build/namespaces) + +RadixTree::RadixTree(int cacheCapacity) { + this->tree = raxNew(); + // add one to the cache capacity because we insert a root node to the tree. 
+ this->cacheCapacity = cacheCapacity + 1; + this->nodeCount = 0; + + // prepare root node + std::vector rootToken = {INT32_MAX}; + std::shared_ptr evictedNode; + this->InsertInternal(rootToken, evictedNode); + // raxShow(this->tree); + raxNode* dataNode = raxFindAndReturnDataNode(this->tree, rootToken.data(), + rootToken.size(), NULL, false); + DataWrapper* data = new DataWrapper(); + data->data = nullptr; + data->dataLength = 0; + dataNode->custom_data = data; + VLOG(100) << "root data wrapper:" << data; + dataNode->issubtree = true; + this->rootToken = rootToken; +} + +RadixTree::~RadixTree() { + VLOG(100) << "~RadixTree"; + // raxShow(this->tree); + + raxNode* dataNode = raxFindAndReturnDataNode(this->tree, rootToken.data(), + rootToken.size(), NULL, false); + if (dataNode != nullptr) { + delete reinterpret_cast(dataNode->custom_data); + delete reinterpret_cast(raxGetData(dataNode)); + } + + raxFree(this->tree); +} + +std::shared_ptr RadixTree::Insert( + std::vector tokens, std::shared_ptr& evictedNode) { + tokens.insert(tokens.begin(), INT32_MAX); + return InsertInternal(tokens, evictedNode); +} + +void RadixTree::Delete(std::vector tokens, + std::shared_ptr& evictedNode) { + tokens.insert(tokens.begin(), INT32_MAX); + DeleteInternal(tokens, evictedNode); +} + +std::shared_ptr RadixTree::Query(std::vector key) { + key.insert(key.begin(), INT32_MAX); + return QueryInternal(key); +} + +std::vector> RadixTree::Split( + std::vector tokens, std::shared_ptr& header) { + tokens.insert(tokens.begin(), INT32_MAX); + return SplitInternal(tokens, header); +} + +std::shared_ptr RadixTree::InsertInternal( + std::vector tokens, std::shared_ptr& evictedNode) { + // get the sub vector of the tokens + std::vector rootToken = + std::vector(tokens.begin(), tokens.end() - 1); + if (rootToken.size() > 0 && QueryInternal(rootToken) == nullptr) { + return nullptr; + } + + // insert the token vector to the radix tree + int* insertTokensArray = tokens.data(); + size_t 
insertTokensArrayLen = tokens.size(); + DataWrapper* dummyData = new DataWrapper(); + DataWrapper* oldData; + raxNode* dataNode = NULL; + int retval = raxInsertAndReturnDataNode( + this->tree, insertTokensArray, insertTokensArrayLen, dummyData, + reinterpret_cast(&dataNode), reinterpret_cast(&oldData)); + if (dataNode == NULL) { + throw std::runtime_error("Insert token list failed"); + return NULL; + } + if (retval == 1) { + VLOG(100) << "node count++:" << this->nodeCount; + nodeCount++; + } + + // raxShow(this->tree); + if (this->nodeCount > this->cacheCapacity) { + VLOG(100) << "cache capacity is full, evict the last recent node"; + VLOG(100) << "cache capacity:" << this->cacheCapacity + << " node count:" << this->nodeCount; + // evict the last recent node (the node with the largest lru index- + std::vector evictedTokensVector; + raxFindLastRecentNode(this->tree->head, evictedTokensVector); + std::string evicted_str = ""; + for (size_t i = 0; i < evictedTokensVector.size(); i++) { + evicted_str += std::to_string(evictedTokensVector[i]); + } + this->DeleteInternal(evictedTokensVector, evictedNode); + } + + raxNode* subTreeNode = nullptr; + dataNode = raxFindAndReturnDataNode( + this->tree, insertTokensArray, insertTokensArrayLen, &subTreeNode, false); + VLOG(100) << "sub tree node:" << subTreeNode << " data node:" << dataNode; + /** + * if the data node is null, it means the evicted node is the same node as + * the inserted node. + */ + if (dataNode == NULL) { + return NULL; + } + + if (subTreeNode == nullptr) { + return std::make_shared(dummyData, nullptr); + } + return std::make_shared( + dummyData, reinterpret_cast(subTreeNode->custom_data)); +} + +void RadixTree::DeleteInternal(std::vector tokens, + std::shared_ptr& evictedNode) { + // remove the token vector from the radix tree + // TBD + // If the evicted node is the root node of sub tree, recycle the tree. 
+ int* deleteTokensArray = tokens.data(); + size_t deleteTokensArrayLen = tokens.size(); + + DataWrapper* oldData; + raxNode* subTreeNode; + std::vector pre; + raxNode* dataNode = raxFindAndReturnDataNode( + this->tree, deleteTokensArray, deleteTokensArrayLen, &subTreeNode, false); + bool nodeIsSubTree = false; + if (dataNode != nullptr && dataNode->issubtree) { + nodeIsSubTree = true; + } + int retval = raxRemove(this->tree, deleteTokensArray, deleteTokensArrayLen, + reinterpret_cast(&oldData), false); + if (retval == 1) { + evictedNode = std::make_shared( + oldData, reinterpret_cast(subTreeNode->custom_data)); + nodeCount--; + if (nodeIsSubTree) { + evictedNode->cleanTreeData = true; + } + } else { + LOG(ERROR) << "remove failed"; + } +} + +std::shared_ptr RadixTree::QueryInternal(std::vector key) { + VLOG(100) << "Query"; + int* tokens = key.data(); + size_t tokensLen = key.size(); + + if (this->tree == nullptr) { + return NULL; + } + + raxNode* subTreeNode; + raxNode* dataNode = + raxFindAndReturnDataNode(this->tree, tokens, tokensLen, &subTreeNode); + VLOG(100) << "query subtree node:" << subTreeNode; + if (dataNode == NULL) { + return NULL; + } + + return std::make_shared( + reinterpret_cast(raxGetData(dataNode)), + reinterpret_cast(subTreeNode->custom_data)); +} + +std::string RadixTree::Serialize() { + VLOG(100) << "Serialize......"; + // raxShow(this->tree); + std::vector> tokenList; + std::vector dataList; + std::vector timestampList; + std::vector> subTreeTokenList; + std::vector subTreeDataList; + raxSerialize(this->tree, tokenList, dataList, timestampList, + &subTreeTokenList, &subTreeDataList); + + // raxShow(this->tree); + std::string serializedStr; + + if (tokenList.size() != dataList.size()) { + throw std::runtime_error( + "The size of token list and data list is not equal"); + } + for (size_t index = 0; index < tokenList.size(); index++) { + for (size_t j = 0; j < tokenList[index].size(); j++) { + serializedStr += 
std::to_string(tokenList[index][j]); + if (j < tokenList[index].size() - 1) { + serializedStr += ","; + } + } + serializedStr += "|"; + + // convert timestamp(uint64) to hex string + uint64_t timestamp = timestampList[index]; + std::ostringstream timestampOSS; + timestampOSS << std::hex << timestamp; + + serializedStr += timestampOSS.str() + "|"; + + raxNode* node = + raxFindAndReturnDataNode(this->tree, tokenList[index].data(), + tokenList[index].size(), NULL, false); + uint32_t numNodes = node->numnodes; + std::ostringstream subTreeSizeOSS; + subTreeSizeOSS << std::hex << numNodes; + + serializedStr += subTreeSizeOSS.str() + "|"; + + // convert data to hex string + char* bytes = reinterpret_cast( + (reinterpret_cast(dataList[index]))->data); + std::ostringstream dataOSS; + + for (int i = 0; + i < (reinterpret_cast(dataList[index]))->dataLength; + i++) { + dataOSS << std::hex << std::setw(2) << std::setfill('0') + << static_cast(static_cast(bytes[i])); + } + serializedStr += dataOSS.str() + "\n"; + } + + serializedStr += "\t\n"; + + VLOG(100) << "sub tree token list size:" << subTreeTokenList.size(); + for (size_t index = 0; index < subTreeTokenList.size(); index++) { + for (size_t j = 0; j < subTreeTokenList[index].size(); j++) { + serializedStr += std::to_string(subTreeTokenList[index][j]); + if (j < subTreeTokenList[index].size() - 1) { + serializedStr += ","; + } + } + serializedStr += "|"; + + // convert custom data to hex string + char* bytes = reinterpret_cast( + (reinterpret_cast(subTreeDataList[index]))->data); + std::ostringstream dataOSS; + + VLOG(100) + << "data length:" + << (reinterpret_cast(subTreeDataList[index]))->dataLength; + for (int i = 0; + i < + (reinterpret_cast(subTreeDataList[index]))->dataLength; + ++i) { + dataOSS << std::hex << std::setw(2) << std::setfill('0') + << static_cast(static_cast(bytes[i])); + } + VLOG(100) << "data:" + << (reinterpret_cast(subTreeDataList[index]))->data; + VLOG(100) << "data oss:" << dataOSS.str(); + 
serializedStr += dataOSS.str() + "\n"; + } + VLOG(100) << "serializedStr:" << serializedStr; + + // use ZSTD to compress the serialized string + size_t srcSize = serializedStr.size(); + size_t dstSize = ZSTD_compressBound(srcSize); + std::string compressedStr(dstSize + 1, '\0'); + VLOG(100) << "src size:" << srcSize << " dst size:" << dstSize; + int compressedSize = ZSTD_compress(compressedStr.data(), compressedStr.size(), + serializedStr.c_str(), srcSize, 3); + if (ZSTD_isError(compressedSize)) { + LOG(ERROR) << "ZSTD compression failed: " + << ZSTD_getErrorName(compressedSize); + } + int cacheCapacity = this->cacheCapacity - 1; + + std::string result = + std::string(reinterpret_cast(&compressedSize), sizeof(int)) + + std::string(reinterpret_cast(&cacheCapacity), sizeof(int)) + + std::string(reinterpret_cast(&(this->tree->head->numnodes)), + sizeof(uint32_t)) + + compressedStr; + + return result; +} + +std::shared_ptr RadixTree::Deserialize(std::string data) { + VLOG(100) << "Deserialize......"; + // use LZ4 to decompress the serialized string + int compressedSize = *reinterpret_cast(data.data()); + data.erase(0, sizeof(int)); + int cacheCapacity = *reinterpret_cast(data.data()); + data.erase(0, sizeof(int)); + int rootNumNodes = *reinterpret_cast(data.data()); + data.erase(0, sizeof(uint32_t)); + uint64_t ds = ZSTD_getFrameContentSize(data.c_str(), data.size()); + if (ds == ZSTD_CONTENTSIZE_ERROR) { + LOG(ERROR) << "Error: not a valid compressed frame"; + } else if (ds == ZSTD_CONTENTSIZE_UNKNOWN) { + LOG(ERROR) + << "Error: original size unknown. 
Use streaming decompression instead."; + } + + std::string decompressedStr(ds + 1, '\0'); + int decompressedSize = + ZSTD_decompress(decompressedStr.data(), ds, data.c_str(), compressedSize); + if (ZSTD_isError(decompressedSize)) { + LOG(ERROR) << "ZSTD decompression failed: " + << ZSTD_getErrorName(decompressedSize); + } + data = decompressedStr.substr(0, decompressedSize); + + std::vector> tokenList; + std::vector dataList; + std::vector dataSizeList; + std::vector timestampList; + std::vector> subTreeTokenList; + std::vector subTreeDataList; + std::vector subTreeDataSizeList; + std::vector subTreeSizeList; + std::istringstream iss(data); + std::string line; + bool isMainTree = true; + + while (std::getline(iss, line)) { + if (!line.empty() && line[0] == '\t') { + isMainTree = false; + line.pop_back(); + continue; + } + VLOG(100) << "data line:" << line << std::endl; + std::istringstream lineStream(line); + std::string tokenListPart, timestampPart, dataPart, subTreeSizePart; + + if (!std::getline(lineStream, tokenListPart, '|')) { + LOG(ERROR) << "Invalid serialized string format in token list part."; + } + if (isMainTree) { + if (!std::getline(lineStream, timestampPart, '|')) { + LOG(ERROR) << "Invalid serialized string format in timestamp part."; + } + if (!std::getline(lineStream, subTreeSizePart, '|')) { + LOG(ERROR) << "Invalid serialized string format in sub tree size part."; + } + } + if (!std::getline(lineStream, dataPart)) { + VLOG(100) << "data length is 0"; + } + + std::istringstream keyStream(tokenListPart); + std::string token; + std::vector keys; + while (std::getline(keyStream, token, ',')) { + keys.push_back(std::stoi(token)); + } + + uint64_t timestamp; + if (isMainTree) { + std::istringstream timestampStream(timestampPart); + if (!(timestampStream >> std::hex >> timestamp)) { + LOG(ERROR) << "Invalid timestamp format."; + } + + std::istringstream subTreeSizeStream(subTreeSizePart); + uint32_t subTreeSize; + if (!(subTreeSizeStream >> std::hex >> 
subTreeSize)) { + LOG(ERROR) << "Invalid sub tree size format."; + } + VLOG(100) << "Deserialize sub tree size:" << subTreeSize; + subTreeSizeList.push_back(subTreeSize); + } + + size_t dataSize = dataPart.length() / + 2; // Each byte is represented by two hex characters + if (isMainTree) { + dataSizeList.push_back(dataSize); + } else { + subTreeDataSizeList.push_back(dataSize); + } + // This pointer will be freed by upper layer. Because this data + // is created by upper layer. Here just recover it from serialized + // string. + char* data = nullptr; + VLOG(100) << "data size:" << dataSize; + if (dataSize != 0) { + data = new char[dataSize]; + std::istringstream dataStream(dataPart); + for (size_t i = 0; i < dataSize; ++i) { + // Temporary buffer to store two hexadecimal chars + null + char hex[3] = {}; + // Read two characters for one byte + if (!dataStream.read(hex, 2)) { + delete[] data; + LOG(ERROR) << "Invalid data format."; + } + // Convert the two hex characters to one byte + unsigned int byte; + std::istringstream hexStream(hex); + if (!(hexStream >> std::hex >> byte)) { + delete[] data; + LOG(ERROR) << "Invalid data format."; + } + reinterpret_cast(data)[i] = + static_cast(byte); + } + } + if (isMainTree) { + tokenList.push_back(keys); + timestampList.push_back(timestamp); + dataList.push_back(data); + } else { + subTreeTokenList.push_back(keys); + subTreeDataList.push_back(data); + } + } + + // This pointer will be freed by upper layer. Because this data + // is created by upper layer. Here just recover it from serialized + // string. 
+ std::shared_ptr radixTree = + std::make_shared(cacheCapacity); + radixTree->nodeCount = tokenList.size(); + + // raxShow(radixTree->tree); + for (size_t i = 0; i < tokenList.size(); i++) { + std::string token_str = ""; + for (size_t j = 0; j < tokenList[i].size(); j++) { + token_str += std::to_string(tokenList[i][j]); + } + VLOG(100) << "token:" << token_str; + int* insertTokensArray = tokenList[i].data(); + size_t insertTokensArrayLen = tokenList[i].size(); + DataWrapper* data = new DataWrapper(); + data->data = dataList[i]; + data->dataLength = dataSizeList[i]; + raxNode* dataNode = NULL; + + // TBD + // check retval + raxInsertAndReturnDataNode(radixTree->tree, insertTokensArray, + insertTokensArrayLen, data, + reinterpret_cast(&dataNode), NULL); + + if (dataNode == NULL) { + LOG(ERROR) << "Insert token list failed"; + } + dataNode->timestamp = timestampList[i]; + } + + for (size_t i = 0; i < tokenList.size(); i++) { + raxNode* node = raxFindAndReturnDataNode( + radixTree->tree, tokenList[i].data(), tokenList[i].size(), NULL, false); + VLOG(100) << "node:" << node << " sub tree node num:" << subTreeSizeList[i]; + node->numnodes = subTreeSizeList[i]; + } + radixTree->tree->head->numnodes = rootNumNodes; + // raxShow(radixTree->tree); + + VLOG(100) << "start to insert sub tree token list" << std::endl; + for (size_t i = 0; i < subTreeTokenList.size(); i++) { + for (size_t j = 0; j < subTreeTokenList[i].size(); j++) { + VLOG(100) << subTreeTokenList[i][j]; + } + + raxNode* node = nullptr; + VLOG(100) << "stage 1"; + VINEYARD_ASSERT(radixTree->tree != nullptr); + + // TBD refator this code. 
+ node = raxFindAndReturnDataNode(radixTree->tree, subTreeTokenList[i].data(), + subTreeTokenList[i].size(), NULL, false); + VINEYARD_ASSERT(node != nullptr); + VLOG(100) << "stage 2"; + DataWrapper* data = new DataWrapper(); + data->data = subTreeDataList[i]; + VLOG(100) << subTreeDataList[i]; + data->dataLength = subTreeDataSizeList[i]; + + VLOG(100) << "stage 3"; + node->issubtree = true; + raxSetCustomData(node, data); + + radixTree->SetSubtreeData(subTreeDataList[i]); + } + VLOG(100) << "Deserialize success"; + // raxShow(radixTree->tree); + return radixTree; +} + +std::vector> RadixTree::SplitInternal( + std::vector tokens, std::shared_ptr& header) { + std::vector rootToken; + raxNode* subTreeRootNode = + raxSplit(this->tree, tokens.data(), tokens.size(), rootToken); + + // raxShow(this->tree); + subTreeRootNode->issubtree = true; + DataWrapper* treeData = new DataWrapper(); + treeData->data = nullptr; + treeData->dataLength = 0; + subTreeRootNode->custom_data = treeData; + header = std::make_shared( + reinterpret_cast(raxGetData(subTreeRootNode)), treeData); + return TraverseTreeWithoutSubTree(subTreeRootNode); +} + +// Get child node list from this tree. 
+std::vector> RadixTree::TraverseTreeWithoutSubTree(
+    raxNode* headNode) {
+  std::vector> nodes;
+  if (headNode == NULL) {
+    VLOG(100) << "traverse failed";
+    return nodes;
+  }
+
+  std::vector dataNodeList;
+  raxTraverseSubTree(headNode, dataNodeList);
+  VLOG(100) << "data node list:" << dataNodeList.size();
+  for (size_t i = 0; i < dataNodeList.size(); i++) {
+    nodes.push_back(std::make_shared(
+        reinterpret_cast(raxGetData(dataNodeList[i])),
+        reinterpret_cast(dataNodeList[i]->custom_data)));
+  }
+  return nodes;
+}
+
+void RadixTree::SetSubtreeData(void* data) {
+  VLOG(100) << "set subtree data:" << data;
+  subTreeDataSet.insert(data);
+}
+
+void RadixTree::ClearSubtreeData(void* data) {
+  VLOG(100) << "clear subtree data:" << data;
+  subTreeDataSet.erase(data);
+}
+
+std::shared_ptr RadixTree::GetRootNode() {
+  raxNode* node = raxFindAndReturnDataNode(this->tree, rootToken.data(),
+                                           rootToken.size(), NULL);
+  return std::make_shared(  // NOTE(review): node is used without a NULL check
+      reinterpret_cast(raxGetData(node)),
+      reinterpret_cast(node->custom_data));
+}
+
+void RadixTree::MergeTree(std::shared_ptr tree_1,
+                          std::shared_ptr tree_2,
+                          std::vector>& evicted_tokens,
+                          std::set>& insert_tokens) {
+  std::set> insert_tokens_set;
+  std::vector> evicted_tokens_vec;
+  mergeTree(tree_1->tree, tree_2->tree, evicted_tokens_vec, insert_tokens_set,
+            tree_1->cacheCapacity);
+  for (size_t i = 0; i < evicted_tokens_vec.size(); i++) {
+    VINEYARD_ASSERT(evicted_tokens_vec[i][0] == INT32_MAX);
+    std::vector tmp(evicted_tokens_vec[i].begin() + 1,  // drop the sentinel
+                    evicted_tokens_vec[i].end());
+    evicted_tokens.push_back(tmp);
+  }
+
+  for (auto& vec : insert_tokens_set) {
+    VINEYARD_ASSERT(vec[0] == INT32_MAX);
+    std::vector tmp(vec.begin() + 1, vec.end());  // drop the sentinel
+    insert_tokens.insert(tmp);
+  }
+}
+
+std::set RadixTree::GetAllNodeData() {
+  raxIterator iter;
+  raxStart(&iter, this->tree);
+  raxSeek(&iter, "^", NULL, 0);
+  std::set nodeDataSet;
+  while (raxNext(&iter)) {
+    raxNode* node = iter.node;
+    if (node->isnull) {
+      continue;
+    }
+    nodeDataSet.insert(
+        (reinterpret_cast(raxGetData(node)))->data);
+  }
+  raxStop(&iter);  // release any key buffer the iterator allocated
+  return nodeDataSet;
+}
diff --git a/modules/llm-cache/radix-tree/radix-tree.h b/modules/llm-cache/radix-tree/radix-tree.h
new file mode 100644
index 00000000..211b48f9
--- /dev/null
+++ b/modules/llm-cache/radix-tree/radix-tree.h
@@ -0,0 +1,127 @@
+/** Copyright 2020-2023 Alibaba Group Holding Limited.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+#ifndef MODULES_LLM_CACHE_RADIX_TREE_RADIX_TREE_H_
+#define MODULES_LLM_CACHE_RADIX_TREE_RADIX_TREE_H_
+
+#include "rax/radix.h"
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "common/util/base64.h"
+#include "common/util/logging.h"
+
+using namespace vineyard;  // NOLINT(build/namespaces)
+
+// A raw payload pointer together with its length.
+struct DataWrapper {
+  void* data;
+  int dataLength;
+};
+
+// Bundles the payload of a radix-tree node (`nodeData`) with the payload
+// stored in the node's custom data (`treeData`).
+struct NodeData {
+  DataWrapper* nodeData;
+  DataWrapper* treeData;
+  bool cleanTreeData = false;
+
+  NodeData(DataWrapper* nodeData, DataWrapper* treeData) {
+    this->nodeData = nodeData;
+    this->treeData = treeData;
+  }
+
+  // Frees nodeData, and treeData too when cleanTreeData is set. The pointers
+  // are cleared afterwards so that a second call cannot double-free them.
+  void RecycleSource() {
+    delete this->nodeData;  // deleting a null pointer is a no-op
+    this->nodeData = nullptr;
+    if (cleanTreeData) {
+      delete this->treeData;
+      this->treeData = nullptr;
+    }
+  }
+};
+
+// Wraps a rax radix tree whose keys are token lists.
+class RadixTree : public std::enable_shared_from_this {
+ public:
+  rax* tree;
+  int cacheCapacity;
+  int nodeCount;
+  std::set subTreeDataSet;
+  std::vector rootToken;
+
+ private:
+  std::shared_ptr InsertInternal(
+      std::vector tokens, std::shared_ptr& evictedNode);
+
+  void DeleteInternal(std::vector tokens,
+                      std::shared_ptr& evictedNode);
+
+  std::shared_ptr QueryInternal(std::vector key);
+
+  std::vector> SplitInternal(
+      std::vector tokens, std::shared_ptr& header);
+
+ public:
+  RadixTree(int cacheCapacity);  // NOLINT(runtime/explicit)
+
+  ~RadixTree();
+
+  std::shared_ptr Insert(std::vector tokens,
+                         std::shared_ptr& evictedNode);
+
+  void Delete(std::vector tokens, std::shared_ptr& evictedNode);
+
+  std::shared_ptr Query(std::vector key);
+
+  std::vector> Split(
+      std::vector tokens, std::shared_ptr& header);
+
+  std::string Serialize();
+
+  static std::shared_ptr Deserialize(std::string data);
+
+  // Get child node list from this tree.
+  static std::vector> TraverseTreeWithoutSubTree(
+      raxNode* headNode);
+
+  void SetSubtreeData(void* data);
+
+  void ClearSubtreeData(void* data);
+
+  rax* GetRootTree() { return this->tree; }
+
+  int GetCacheCapacity() { return cacheCapacity - 1; }
+
+  std::set GetSubTreeDataSet() { return subTreeDataSet; }
+
+  std::shared_ptr GetRootNode();
+
+  static void MergeTree(std::shared_ptr tree_1,
+                        std::shared_ptr tree_2,
+                        std::vector>& evicted_tokens,
+                        std::set>& insert_tokens);
+
+  std::set GetAllNodeData();
+};
+
+#endif  // MODULES_LLM_CACHE_RADIX_TREE_RADIX_TREE_H_
diff --git a/modules/llm-cache/tests/CMakeLists.txt b/modules/llm-cache/tests/CMakeLists.txt
new file mode 100644
index 00000000..8896f24c
--- /dev/null
+++ b/modules/llm-cache/tests/CMakeLists.txt
@@ -0,0 +1,30 @@
+enable_testing()
+
+# add_test_case(<testname> <testfile>)
+#
+# Builds <testfile> into the executable <testname>, links it against
+# vineyard_llm_cache and registers it with CTest.
+macro(add_test_case testname testfile)
+  add_executable(${testname} ${testfile})
+
+  target_include_directories(${testname} PRIVATE ${GLOG_INCLUDE_DIRS})
+  target_link_libraries(${testname} PRIVATE vineyard_llm_cache)
+
+  add_test(NAME ${testname} COMMAND ${testname})
+  add_dependencies(vineyard_tests ${testname})
+endmacro()
+
+file(GLOB LLM_TEST_FILES RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "./*.cc")
+
+foreach(testfile ${LLM_TEST_FILES})
+  string(REGEX MATCH "^(.*)\\.[^.]*$" dummy ${testfile})
+  set(testname ${CMAKE_MATCH_1})
+
+  if("${testname}" STREQUAL "gpumalloc_test" AND NOT USE_CUDA)
+    continue()
+  endif()
+
+  message(STATUS "Found unit_test - " ${testname})
+  add_test_case(${testname} ${testfile})
+endforeach()
+
diff --git a/modules/llm-cache/tests/kv_state_cache_benchmark_test.cc b/modules/llm-cache/tests/kv_state_cache_benchmark_test.cc
new file mode 100644
index 00000000..feb1166b
--- /dev/null
+++ b/modules/llm-cache/tests/kv_state_cache_benchmark_test.cc
@@ -0,0 +1,156 @@
+/** Copyright 2020-2023 Alibaba Group Holding Limited.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+#include
+#include
+#include
+#include
+#include
+
+#include "client/client.h"
+#include "client/ds/object_meta.h"
+#include "common/util/logging.h"
+
+#include "llm-cache/ds/kv_state_cache_manager.h"
+
+using namespace vineyard;  // NOLINT(build/namespaces)
+
+#define DIMENSION 100
+#define CAPACITY 1000
+#define LAYER 64
+#define BLOCK_SIZE 100  // NOTE(review): unused; init() uses DEFAULT_BLOCK_SIZE
+
+KVStateCacheManager* manager;
+
+void init() {
+  manager =
+      new KVStateCacheManager(DIMENSION, CAPACITY, LAYER, DEFAULT_BLOCK_SIZE);
+}
+
+std::vector generate_random_tokens(size_t max_length) {
+  std::random_device rd;
+  std::mt19937 gen(rd());
+  std::uniform_int_distribution<> dist(1, 10000);
+
+  size_t length = dist(gen) % max_length + 1;
+  std::vector tokens(length);
+  for (size_t i = 0; i < length; ++i) {
+    tokens[i] = dist(gen);
+  }
+  return tokens;
+}
+
+std::map> generate_kv_state(int token) {
+  std::map> kv_state;
+  for (int currentLayer = 0; currentLayer < LAYER; currentLayer++) {
+    K_STATE key_state;
+    V_STATE value_state;
+    key_state.data = malloc(DIMENSION * sizeof(double));
+    key_state.length = DIMENSION * sizeof(double);
+    value_state.data = malloc(DIMENSION * sizeof(double));
+    value_state.length = DIMENSION * sizeof(double);
+
+    kv_state.insert(
+        std::make_pair(currentLayer, std::make_pair(key_state, value_state)));
+  }
+  return kv_state;
+}
+
+// Replays each token list through Query/Update and logs the time spent.
+void benchmark_inference(std::vector>& tokens) {
+  LOG(INFO) << "inference for benchmark";
+  std::map> kv_state;
+
+  std::chrono::steady_clock::time_point start, end;
+  double token_list_size = 0;
+  std::chrono::duration update_duration(0);
+  std::chrono::duration query_duration(0);
+  double total_update_duration = 0;
+  double total_query_duration = 0;
+
+  for (size_t i = 0; i < tokens.size(); ++i) {
+    std::vector inference_tokens;
+    for (size_t j = 0; j < tokens[i].size(); ++j) {
+      start = std::chrono::steady_clock::now();
+      // NOTE(review): these buffers are malloc'ed per token and never freed.
+      kv_state = generate_kv_state(tokens[i][j]);
+      manager->Query(inference_tokens, tokens[i][j], kv_state);
+      end = std::chrono::steady_clock::now();
+      query_duration += end - start;
+
+      if (kv_state.size() == 0) {  // NOTE(review): Query's result is ignored
+        start = std::chrono::steady_clock::now();
+        manager->Update(inference_tokens, tokens[i][j], kv_state);
+        end = std::chrono::steady_clock::now();
+        update_duration += end - start;
+      }
+      inference_tokens.push_back(tokens[i][j]);
+      token_list_size++;
+    }
+  }
+  total_update_duration = update_duration.count();  // already cumulative
+  total_query_duration = query_duration.count();
+  LOG(INFO) << "Token list size is " << token_list_size
+            << ", total update time is " << total_update_duration << "s"
+            << ", total query time is " << total_query_duration << "s"
+            << ", update throughput is "
+            << token_list_size / total_update_duration << " token/s"
+            << ", query throughput is "
+            << token_list_size / total_query_duration << " token/s";
+}
+
+int main(int argc, char** argv) {
+  if (argc < 2) {
+    printf("usage ./kv_state_cache_benchmark ");
+    return 1;
+  }
+  std::string ipc_socket = std::string(argv[1]);
+
+  init();
+
+  std::atomic inference_done(false);
+
+  std::thread memory_monitor([&]() {
+    Client client;
+    size_t max_memory_usage = 0;
+    VINEYARD_CHECK_OK(client.Connect(ipc_socket));
+    while (!inference_done) {
+      sleep(1);
+      std::shared_ptr status;
+      VINEYARD_CHECK_OK(client.InstanceStatus(status));
+      LOG(INFO) << "status->memory_usage is:" << status->memory_usage;
+      if (status->memory_usage > max_memory_usage) {
+        max_memory_usage = status->memory_usage;
+      }
+    }
+    LOG(INFO) << "Max memory usage is " << max_memory_usage;
+  });
+
+  std::thread inference([&]() {
+    const size_t num_lists = 10;
+    std::vector> all_token_lists;
+    for (size_t i = 0; i < num_lists; ++i) {
+      all_token_lists.push_back(generate_random_tokens(2000));
+    }
+
+    benchmark_inference(all_token_lists);
+    inference_done = true;
+  });
+
+  memory_monitor.join();
+  inference.join();
+  delete manager;
+  return 0;
+}
diff --git a/modules/llm-cache/tests/kv_state_cache_multi_test.cc b/modules/llm-cache/tests/kv_state_cache_multi_test.cc
new file mode 100644
index 00000000..52d4ae42
--- /dev/null
+++ b/modules/llm-cache/tests/kv_state_cache_multi_test.cc
@@ -0,0 +1,100 @@
+/** Copyright 2020-2023 Alibaba Group Holding Limited.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "common/util/logging.h"
+
+char process_name[] = "kv_state_cache_test";
+char arg_0[] = "-s";
+char token_sequence_1[] = "1";
+char token_sequence_2[] = "2";
+char token_sequence_3[] = "3";
+char token_sequence_4[] = "4";
+
+const char* program = "./build/bin/kv_state_cache_test";
+
+// Forks a child that execs `program` with `argv`; returns the child's pid.
+pid_t create_subprocess(char* argv[]) {
+  pid_t pid = fork();
+  if (pid == -1) {
+    std::cerr << "Failed to fork()" << std::endl;
+    exit(1);
+  } else if (pid > 0) {
+    return pid;
+  } else {
+    execv(program, argv);
+    std::cerr << "Failed to exec()" << std::endl;
+    exit(1);
+  }
+}
+
+int main(int argc, char** argv) {
+  std::string sockets[2];
+  std::string rpc_endpoint;
+  // Only read a flag's values when they are actually present in argv.
+  for (int i = 1; i < argc; i++) {
+    if (strcmp(argv[i], "--vineyard-endpoint") == 0 && i + 1 < argc) {
+      rpc_endpoint = std::string(argv[i + 1]);
+    } else if (strcmp(argv[i], "--vineyard-ipc-sockets") == 0 &&
+               i + 2 < argc) {
+      sockets[0] = std::string(argv[i + 1]);
+      sockets[1] = std::string(argv[i + 2]);
+    }
+  }
+
+  // execv() takes mutable char* arguments, so copy the socket paths.
+  char* socket_str[2];
+  socket_str[0] =
+      (char*) malloc(sockets[0].length() + 1);  // NOLINT(readability/casting)
+  socket_str[1] =
+      (char*) malloc(sockets[1].length() + 1);  // NOLINT(readability/casting)
+  memset(socket_str[0], 0, sockets[0].length() + 1);
+  memset(socket_str[1], 0, sockets[1].length() + 1);
+  strcpy(socket_str[0], sockets[0].c_str());  // NOLINT(runtime/printf)
+  strcpy(socket_str[1], sockets[1].c_str());  // NOLINT(runtime/printf)
+
+  char* args_1[] = {process_name,     socket_str[0],    arg_0,
+                    token_sequence_1, token_sequence_2, token_sequence_3,
+                    token_sequence_4, nullptr};
+  char* args_2[] = {process_name,     socket_str[1],    arg_0,
+                    token_sequence_1, token_sequence_2, token_sequence_3,
+                    nullptr};
+
+  std::vector pids;
+  pids.push_back(create_subprocess(args_1));
+  pids.push_back(create_subprocess(args_2));
+
+  // Reap every child before returning; a child that exits non-zero or is
+  // killed by a signal fails the whole test.
+  int exit_code = 0;
+  for (size_t i = 0; i < pids.size(); i++) {
+    int status = 0;
+    if (waitpid(pids[i], &status, 0) == -1 || !WIFEXITED(status) ||
+        WEXITSTATUS(status) != 0) {
+      exit_code = 1;
+    }
+  }
+
+  free(socket_str[0]);
+  free(socket_str[1]);
+
+  return exit_code;
+}
diff --git a/modules/llm-cache/tests/kv_state_cache_radix_tree_test.cc b/modules/llm-cache/tests/kv_state_cache_radix_tree_test.cc
new file mode 100644
index 00000000..d3f4ae78
--- /dev/null
+++ b/modules/llm-cache/tests/kv_state_cache_radix_tree_test.cc
@@ -0,0 +1,191 @@
+/** Copyright 2020-2023 Alibaba Group Holding Limited.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+#include
+#include
+#include
+#include
+#include "rax/radix.h"
+
+#include "common/util/logging.h"
+#include "llm-cache/ds/kv_state_cache_manager.h"
+
+using namespace vineyard;  // NOLINT(build/namespaces)
+
+void print_tokens(const std::vector& tokens) {
+  std::string tokens_str = "";
+  for (size_t i = 0; i < tokens.size(); ++i) {
+    tokens_str += std::to_string(tokens[i]);  // no separator between tokens
+  }
+  LOG(INFO) << "Current tokens: " + tokens_str;
+}
+
+void radix_tree_insert_test() {
+  std::shared_ptr radix_tree = std::make_shared(10);
+
+  /* insert the prefixes [0], [0, 1], ..., [0..9]; no node may be evicted */
+  std::vector tokens;
+  std::shared_ptr node_data;
+  for (int i = 0; i < 10; i++) {
+    tokens.push_back(i);
+    VINEYARD_ASSERT(radix_tree->Insert(tokens, node_data) != NULL);
+    VINEYARD_ASSERT(node_data == NULL);
+  }
+
+  /* insert new token lists and check that an old node is evicted each time */
+  tokens.clear();
+  for (int i = 1; i < 10; i++) {
+    tokens.push_back(i);
+    VINEYARD_ASSERT(radix_tree->Insert(tokens, node_data) != NULL);
+    VINEYARD_ASSERT(node_data != NULL);
+  }
+
+  /* inserting a token list whose prefix is not in the radix tree must fail */
+  tokens.clear();
+  for (int i = 10; i > 0; i--) {
+    tokens.push_back(i);
+  }
+  VINEYARD_ASSERT(radix_tree->Insert(tokens, node_data) == NULL);
+}
+
+void radix_tree_delete_test() {
+  std::shared_ptr radix_tree = std::make_shared(10);
+
+  /* insert the prefixes [0], [0, 1], ..., [0..9]; no node may be evicted */
+  std::vector tokens;
+  std::shared_ptr node_data;
+  for (int i = 0; i < 10; i++) {
+    tokens.push_back(i);
+    VINEYARD_ASSERT(radix_tree->Insert(tokens, node_data) != NULL);
+    VINEYARD_ASSERT(node_data == NULL);
+  }
+
+  /* delete the token list [0..4]; the removed node must be reported back */
+
+  tokens.clear();
+  node_data = NULL;
+  for (int i = 0; i < 5; i++) {
+    tokens.push_back(i);
+  }
+  radix_tree->Delete(tokens, node_data);
+  VINEYARD_ASSERT(node_data != NULL);
+
+  /* deleting a token list that is not in the radix tree reports nothing */
+  tokens.clear();
+  node_data = NULL;
+  for (int i = 10; i > 0; i--) {
+    tokens.push_back(i);
+  }
+  radix_tree->Delete(tokens, node_data);
+  VINEYARD_ASSERT(node_data == NULL);
+}
+
+void radix_tree_query_test() {
+  std::shared_ptr radix_tree = std::make_shared(10);
+
+  /* insert the prefixes [0], [0, 1], ..., [0..9]; no node may be evicted */
+  std::vector tokens;
+  std::shared_ptr node_data;
+  for (int i = 0; i < 10; i++) {
+    tokens.push_back(i);
+    VINEYARD_ASSERT(radix_tree->Insert(tokens, node_data) != NULL);
+    VINEYARD_ASSERT(node_data == NULL);
+  }
+
+  /* query the token list [0..4], which was inserted above */
+  tokens.clear();
+  for (int i = 0; i < 5; i++) {
+    tokens.push_back(i);
+  }
+  VINEYARD_ASSERT(radix_tree->Query(tokens) != NULL);
+
+  /* querying a token list that is not in the radix tree returns NULL */
+  tokens.clear();
+  for (int i = 10; i > 0; i--) {
+    tokens.push_back(i);
+  }
+  VINEYARD_ASSERT(radix_tree->Query(tokens) == NULL);
+}
+
+void radix_tree_serialize_and_deserialize() {
+  std::shared_ptr radix_tree = std::make_shared(10);
+
+  /* insert the prefixes [0], [0, 1], ..., [0..9]; no node may be evicted */
+  std::vector tokens;
+  std::shared_ptr node_data;
+  for (int i = 0; i < 10; i++) {
+    tokens.push_back(i);
+    VINEYARD_ASSERT(radix_tree->Insert(tokens, node_data) != NULL);
+    VINEYARD_ASSERT(node_data == NULL);
+  }
+
+  /* serialize radix tree */
+  std::string serialized_radix_tree = radix_tree->Serialize();
+
+  /* deserialize radix tree (Deserialize is a static member) */
+  std::shared_ptr deserialized_radix_tree =
+      radix_tree->Deserialize(serialized_radix_tree);
+
+  /* every inserted prefix must be found in the deserialized tree */
+  tokens.clear();
+  for (int i = 0; i < 10; i++) {
+    tokens.push_back(i);
+    print_tokens(tokens);
+    VINEYARD_ASSERT(deserialized_radix_tree->Query(tokens) != NULL);
+  }
+}
+
+void radix_tree_split() {
+  std::shared_ptr radix_tree = std::make_shared(20);
+
+  raxShow(radix_tree->tree);
+  /* insert the prefixes [0], [0, 1], ..., [0..9]; no node may be evicted */
+  std::vector tokens;
+  std::shared_ptr node_data;
+  for (int i = 0; i < 10; i++) {
+    tokens.push_back(i);
+    VINEYARD_ASSERT(radix_tree->Insert(tokens, node_data) != NULL);
+    VINEYARD_ASSERT(node_data == NULL);
+  }
+
+  /* split the tree at the token list [0..4] */
+  tokens.clear();
+  for (int i = 0; i < 5; i++) {
+    tokens.push_back(i);
+  }
+  print_tokens(tokens);
+  std::shared_ptr subTreeHeader;
+  std::vector> node_data_list =
+      radix_tree->Split(tokens, subTreeHeader);
+  VINEYARD_ASSERT(node_data_list.size() == 7);
+}
+
+int main() {
+  LOG(INFO) << "Start to test radix tree insert...";
+  radix_tree_insert_test();
+  LOG(INFO) << "Finish radix tree insert test!";
+  LOG(INFO) << "Start to test radix tree delete...";
+  radix_tree_delete_test();
+  LOG(INFO) << "Finish radix tree delete test!";
+  LOG(INFO) << "Start to test radix tree query...";
+  radix_tree_query_test();
+  LOG(INFO) << "Finish radix tree query test!";
+  LOG(INFO) << "Start to test radix tree serialize and deserialize...";
+  radix_tree_serialize_and_deserialize();
+  LOG(INFO) << "Finish radix tree serialize and deserialize test!";
+  LOG(INFO) << "Start to test radix tree split...";
+  radix_tree_split();
+  LOG(INFO) << "Finish radix tree split test!";
+}
diff --git a/modules/llm-cache/tests/kv_state_cache_test.cc b/modules/llm-cache/tests/kv_state_cache_test.cc
new file mode 100644
index 00000000..e2d1e98a
--- /dev/null
+++ b/modules/llm-cache/tests/kv_state_cache_test.cc
@@ -0,0 +1,232 @@
+/** Copyright 2020-2023 Alibaba Group Holding Limited.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+#include
+#include
+#include
+#include
+#include "rax/radix.h"
+
+#include "common/util/logging.h"
+#include "llm-cache/ds/kv_state_cache_manager.h"
+
+using namespace vineyard;  // NOLINT(build/namespaces)
+
+int dimension = 10;
+int capacity = 20;
+int layer = 3;
+int block_size = 5;
+
+std::vector round_1_tokens = {
+    1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13, 14, 15, 16, 17, 18,
+    19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36,
+    37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54,
+    55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69};
+std::vector round_2_tokens = {1, 2, 3, 4, 5, 7, 8, 9, 10, 11, 12, 13, 14};
+std::vector round_3_tokens = {1, 2, 3, 9, 10, 11, 12, 13, 14};
+std::vector round_4_tokens = {1, 2, 3, 4, 5, 6};
+
+std::vector> tokens_list;
+
+KVStateCacheManager* kv_state_cache_manager;
+
+void init(int dimension, int capacity, int layer, int block_size,
+          std::string socket) {
+  kv_state_cache_manager = new KVStateCacheManager(dimension, capacity, layer,
+                                                   block_size, 3, socket);
+}
+
+void print_current_tokens(const std::vector& prefix, int next_token) {
+  std::string tokens_str = "";
+  for (size_t i = 0; i < prefix.size(); ++i) {
+    tokens_str += std::to_string(prefix[i]) + " ";
+  }
+  tokens_str += std::to_string(next_token);
+  LOG(INFO) << "Current tokens: " + tokens_str;
+}
+
+void print_kv_state(
+    const std::map>& kv_state) {
+  LOG(INFO) << "kv_state: ";
+  for (auto iter = kv_state.begin(); iter != kv_state.end(); ++iter) {
+    std::string key_state_str = "";
+    std::string value_state_str = "";
+    for (int i = 0; i < dimension; ++i) {
+      key_state_str +=
+          std::to_string(
+              (reinterpret_cast(iter->second.first.data))[i]) +
+          " ";
+      value_state_str +=
+          std::to_string(
+              (reinterpret_cast(iter->second.second.data))[i]) +
+          " ";
+    }
+    LOG(INFO) << "layer " << iter->first << ":";
+    LOG(INFO) << "key_state: " << key_state_str;
+    LOG(INFO) << "value_state: " << value_state_str;
+    LOG(INFO) << "---------------------";
+  }
+}
+
+// Allocates uninitialised key/value buffers per layer; this test never frees them.
+std::map> generate_kv_state() {
+  std::map> kv_state;
+  for (int currentLayer = 0; currentLayer < layer; currentLayer++) {
+    K_STATE key_state;
+    V_STATE value_state;
+    key_state.data = malloc(dimension * sizeof(double));
+    value_state.data = malloc(dimension * sizeof(double));
+
+    key_state.length = dimension * sizeof(double);
+    value_state.length = dimension * sizeof(double);
+
+    kv_state.insert(
+        std::make_pair(currentLayer, std::make_pair(key_state, value_state)));
+  }
+  return kv_state;
+}
+
+void update_kv_state(std::map>& kvState,
+                     int token) {
+  for (int currentLayer = 0; currentLayer < layer; currentLayer++) {
+    K_STATE key_state = kvState[currentLayer].first;
+    V_STATE value_state = kvState[currentLayer].second;
+    for (int i = 0; i < dimension; ++i) {
+      (reinterpret_cast(key_state.data))[i] =
+          (static_cast(token)) / dimension * (i + 1) +
+          currentLayer * 10;
+      (reinterpret_cast(value_state.data))[i] =
+          (static_cast(token)) / dimension * (i + 1) * 2 +
+          currentLayer * 10;
+    }
+  }
+}
+
+void check_kv_state(const std::map>& kv_state,
+                    int& token) {
+  VINEYARD_ASSERT(kv_state.size() == (size_t) layer);
+  for (auto iter = kv_state.begin(); iter != kv_state.end(); ++iter) {
+    VINEYARD_ASSERT(iter->second.first.length ==
+                    (size_t) dimension * sizeof(double));
+    VINEYARD_ASSERT(iter->second.second.length ==
+                    (size_t) dimension * sizeof(double));
+    for (int i = 0; i < dimension; ++i) {
+      if ((reinterpret_cast(iter->second.first.data))[i] !=
+          (static_cast(token)) / dimension * (i + 1) +
+              iter->first * 10) {
+        LOG(INFO) << "token:" << token << " dimension" << dimension
+                  << " layer:" << iter->first;
+        LOG(INFO) << "key_state[" << i << "]: "
+                  << (reinterpret_cast(iter->second.first.data))[i]
+                  << ". But it should be "
+                  << (static_cast(token)) / dimension * (i + 1) +
+                         iter->first * 10;
+        throw std::runtime_error("key_state error!");
+      }
+      if ((reinterpret_cast(iter->second.second.data))[i] !=
+          (static_cast(token)) / dimension * (i + 1) * 2 +
+              iter->first * 10) {
+        LOG(INFO) << "token:" << token << " dimension" << dimension
+                  << " layer:" << iter->first;
+        LOG(INFO) << "value_state[" << i << "]: "
+                  << (reinterpret_cast(iter->second.second.data))[i]
+                  << ". But it should be "
+                  << (static_cast(token)) / dimension * (i + 1) * 2 +
+                         iter->first * 10;
+        throw std::runtime_error("value_state error!");
+      }
+    }
+  }
+}
+
+void inference(std::vector tokens, bool block = false) {
+  std::vector inference_tokens;
+  std::map> kv_state;
+  kv_state = generate_kv_state();
+  for (size_t i = 0; i < tokens.size(); ++i) {
+    int result =
+        kv_state_cache_manager->Query(inference_tokens, tokens[i], kv_state);
+    if (result != 0) {  // non-zero: the prefix + token is not cached yet
+      LOG(INFO) << "Can not find the kv_state from cache:";
+      print_current_tokens(inference_tokens, tokens[i]);
+      LOG(INFO) << "Generate the kv_state and update the cache.";
+      update_kv_state(kv_state, tokens[i]);
+      print_kv_state(kv_state);
+      kv_state_cache_manager->Update(inference_tokens, tokens[i], kv_state);
+    } else {
+      LOG(INFO) << "Find the kv_state from cache:";
+      print_current_tokens(inference_tokens, tokens[i]);
+      check_kv_state(kv_state, tokens[i]);
+    }
+    LOG(INFO) << "--------------------------------------";
+    inference_tokens.push_back(tokens[i]);
+  }
+}
+
+int main(int argc, char** argv) {
+  if (argc < 2) {
+    printf("usage ./kv_state_cache_test ");
+    return 1;
+  }
+  std::string ipc_socket = std::string(argv[1]);
+
+  for (int i = 2; i < argc; i++) {
+    if (strcmp(argv[i], "-d") == 0 && i + 1 < argc) {
+      dimension = atoi(argv[i + 1]);
+    } else if (strcmp(argv[i], "-c") == 0 && i + 1 < argc) {
+      capacity = atoi(argv[i + 1]);
+    } else if (strcmp(argv[i], "-l") == 0 && i + 1 < argc) {
+      layer = atoi(argv[i + 1]);
+    } else if (strcmp(argv[i], "-b") == 0 && i + 1 < argc) {
+      block_size = atoi(argv[i + 1]);
+    }
+    if (strcmp(argv[i], "-s") == 0) {
+      for (int j = i + 1; j < argc; j++) {
+        if (strcmp(argv[j], "1") == 0) {
+          tokens_list.push_back(round_1_tokens);
+        } else if (strcmp(argv[j], "2") == 0) {
+          tokens_list.push_back(round_2_tokens);
+        } else if (strcmp(argv[j], "3") == 0) {
+          tokens_list.push_back(round_3_tokens);
+        } else if (strcmp(argv[j], "4") == 0) {
+          tokens_list.push_back(round_4_tokens);
+        } else {
+          break;
+        }
+      }
+    }
+  }
+
+  LOG(INFO) << "Test KVStateCache with dimension: " << dimension
+            << ", capacity: " << capacity << ", layer: " << layer
+            << ", block_size: " << block_size << ".";
+
+  init(dimension, capacity, layer, block_size, ipc_socket);
+
+  for (size_t i = 0; i < tokens_list.size(); i++) {
+    inference(tokens_list[i]);
+  }
+
+  sleep(5);
+
+  for (size_t i = 0; i < tokens_list.size(); i++) {
+    inference(tokens_list[i]);
+  }
+
+  LOG(INFO) << "inference end";
+  delete kv_state_cache_manager;
+  LOG(INFO) << "Passed KVStateCache tests...";
+  return 0;
+}
diff --git a/src/client/client.cc b/src/client/client.cc
index 67dcc258..b949e8f7 100644
--- a/src/client/client.cc
+++ b/src/client/client.cc
@@ -1156,6 +1156,33 @@ Status Client::IsSpilled(ObjectID const& id, bool& is_spilled) {
   return Status::OK();
 }
 
+Status Client::TryAcquireLock(std::string key, bool& result,
+                              std::string& actural_key) {
+  ENSURE_CONNECTED(this);
+
+  std::string message_out;
+  WriteTryAcquireLockRequest(key, message_out);
+  RETURN_ON_ERROR(doWrite(message_out));  // report I/O errors to the caller
+
+  json message_in;
+  RETURN_ON_ERROR(doRead(message_in));
+  RETURN_ON_ERROR(ReadTryAcquireLockReply(message_in, result, actural_key));
+  return Status::OK();
+}
+
+Status Client::TryReleaseLock(std::string key, bool& result) {
+  ENSURE_CONNECTED(this);
+
+  std::string message_out;
+  WriteTryReleaseLockRequest(key, message_out);
+  RETURN_ON_ERROR(doWrite(message_out));  // report I/O errors to the caller
+
+  json message_in;
+  RETURN_ON_ERROR(doRead(message_in));
+  RETURN_ON_ERROR(ReadTryReleaseLockReply(message_in, result));
+  return Status::OK();
+}
+
 PlasmaClient::~PlasmaClient() {}
 
 // dummy implementation
diff --git a/src/client/client.h b/src/client/client.h
index e6f13b72..1b9d8eba 100644
--- a/src/client/client.h
+++ b/src/client/client.h
@@ -831,6 +831,24 @@ class Client final : public BasicIPCClient,
    */
   Status GetGPUBuffer(const ObjectID id, const bool unsafe,
                       std::shared_ptr& buffer);
+  /**
+   * @brief Try to acquire a distributed lock.
+   *
+   * @param key The key of the lock.
+   *
+   * @return Status that indicates whether the lock process succeeds.
+   */
+  Status TryAcquireLock(std::string key, bool& result,
+                        std::string& actural_key);
+
+  /**
+   * @brief Try to release a distributed lock.
+   *
+   * @param key The key of the lock.
+   *
+   * @return Status that indicates whether the unlock process succeeds.
+   */
+  Status TryReleaseLock(std::string key, bool& result);
 
  protected:
   /**
@@ -1023,6 +1041,15 @@ class PlasmaClient final
    */
   Status Delete(PlasmaID const& id);
 
+  Status TryAcquireLock(std::string key, bool& result,
+                        std::string& actural_key) {
+    return Status::NotImplemented();
+  }
+
+  Status TryReleaseLock(std::string key, bool& result) {
+    return Status::NotImplemented();
+  }
+
  protected:
   /**
    * @brief Required by `UsageTracker`. When reference count reaches zero, send
diff --git a/src/client/client_base.h b/src/client/client_base.h
index a3c28eda..8e1748e3 100644
--- a/src/client/client_base.h
+++ b/src/client/client_base.h
@@ -602,6 +602,25 @@ class ClientBase {
    */
   const std::string& Version() const { return server_version_; }
 
+  /**
+   * @brief Try to acquire a distributed lock.
+   *
+   * @param key The key of the lock.
+   *
+   * @return Status that indicates whether the lock process succeeds.
+   */
+  virtual Status TryAcquireLock(std::string key, bool& result,
+                                std::string& actural_key) = 0;
+
+  /**
+   * @brief Try to release a distributed lock.
+   *
+   * @param key The key of the lock.
+   *
+   * @return Status that indicates whether the unlock process succeeds.
+ */ + virtual Status TryReleaseLock(std::string key, bool& result) = 0; + /** * @brief Issue a debug request. * diff --git a/src/client/rpc_client.h b/src/client/rpc_client.h index f9cc00eb..88b90ea3 100644 --- a/src/client/rpc_client.h +++ b/src/client/rpc_client.h @@ -359,6 +359,30 @@ class RPCClient final : public ClientBase { Status GetRemoteBlobs( std::set const& ids, const bool unsafe, std::map>& remote_blobs); + /** + * @brief Try to acquire a distributed lock. + * + * @param key The key of the lock. + * + * @return Status that indicates whether the lock process succeeds. + */ + Status TryAcquireLock(std::string key, bool& result, + std::string& actural_key) { + // TBD + return Status::NotImplemented("TryAcquireLock is not implemented yet."); + } + + /** + * @brief Try to release a distributed lock. + * + * @param key The key of the lock. + * + * @return Status that indicates whether the unlock process succeeds. + */ + Status TryReleaseLock(std::string key, bool& result) { + // TBD + return Status::NotImplemented("TryAcquireLock is not implemented yet."); + } private: InstanceID remote_instance_id_; diff --git a/src/common/util/protocols.cc b/src/common/util/protocols.cc index 0d16c68a..c59b726d 100644 --- a/src/common/util/protocols.cc +++ b/src/common/util/protocols.cc @@ -206,6 +206,12 @@ const std::string command_t::SHALLOW_COPY_REPLY = "shallow_copy_reply"; const std::string command_t::DEBUG_REQUEST = "debug_command"; const std::string command_t::DEBUG_REPLY = "debug_reply"; +// distributed lock +const std::string command_t::ACQUIRE_LOCK_REQUEST = "acquire_lock_request"; +const std::string command_t::ACQUIRE_LOCK_REPLY = "acquire_lock_reply"; +const std::string command_t::RELEASE_LOCK_REQUEST = "release_lock_request"; +const std::string command_t::RELEASE_LOCK_REPLY = "release_lock_reply"; + void WriteErrorReply(Status const& status, std::string& msg) { encode_msg(status.ToJSON(), msg); } @@ -2258,4 +2264,60 @@ Status ReadDebugReply(const json& root, 
json& result) {
   return Status::OK();
 }
 
+void WriteTryAcquireLockRequest(const std::string& key, std::string& msg) {
+  json root;
+  root["type"] = command_t::ACQUIRE_LOCK_REQUEST;
+  root["key"] = key;  // lock name requested by the client
+  encode_msg(root, msg);
+}
+
+Status ReadTryAcquireLockRequest(const json& root, std::string& key) {
+  CHECK_IPC_ERROR(root, command_t::ACQUIRE_LOCK_REQUEST);
+  key = root["key"].get<std::string>();
+  return Status::OK();
+}
+
+void WriteTryAcquireLockReply(const bool result, const std::string actual_key,
+                              std::string& msg) {
+  json root;
+  root["type"] = command_t::ACQUIRE_LOCK_REPLY;
+  root["key"] = actual_key;  // lock key as reported by the meta service
+  root["result"] = result;   // outcome flag of the acquire attempt
+  encode_msg(root, msg);
+}
+
+Status ReadTryAcquireLockReply(const json& root, bool& result,
+                               std::string& key) {
+  CHECK_IPC_ERROR(root, command_t::ACQUIRE_LOCK_REPLY);
+  result = root["result"].get<bool>();
+  key = root["key"].get<std::string>();
+  return Status::OK();
+}
+
+void WriteTryReleaseLockRequest(const std::string& key, std::string& msg) {
+  json root;
+  root["type"] = command_t::RELEASE_LOCK_REQUEST;
+  root["key"] = key;  // lock key to release
+  encode_msg(root, msg);
+}
+
+Status ReadTryReleaseLockRequest(const json& root, std::string& key) {
+  CHECK_IPC_ERROR(root, command_t::RELEASE_LOCK_REQUEST);
+  key = root["key"].get<std::string>();
+  return Status::OK();
+}
+
+void WriteTryReleaseLockReply(const bool result, std::string& msg) {
+  json root;
+  root["type"] = command_t::RELEASE_LOCK_REPLY;
+  root["result"] = result;  // outcome flag of the release attempt
+  encode_msg(root, msg);
+}
+
+Status ReadTryReleaseLockReply(const json& root, bool& result) {
+  CHECK_IPC_ERROR(root, command_t::RELEASE_LOCK_REPLY);
+  result = root["result"].get<bool>();
+  return Status::OK();
+}
+
 }  // namespace vineyard
diff --git a/src/common/util/protocols.h b/src/common/util/protocols.h
index d8a89bea..37fc581f 100644
--- a/src/common/util/protocols.h
+++ b/src/common/util/protocols.h
@@ -171,6 +171,12 @@ struct command_t {
   static const std::string SHALLOW_COPY_REPLY;
   static const std::string DEBUG_REQUEST;
   static const std::string DEBUG_REPLY;
+
+  // distributed lock
+  static const std::string ACQUIRE_LOCK_REQUEST;
+  static const std::string ACQUIRE_LOCK_REPLY;
+  static const std::string RELEASE_LOCK_REQUEST;
+  static const std::string RELEASE_LOCK_REPLY;
 };
 
 enum class StoreType {
@@ -820,6 +826,24 @@ void WriteDebugReply(const json& result, std::string& msg);
 
 Status ReadDebugReply(const json& root, json& result);
 
+void WriteTryAcquireLockRequest(const std::string& key, std::string& msg);
+
+Status ReadTryAcquireLockRequest(const json& root, std::string& key);
+
+void WriteTryAcquireLockReply(const bool result, const std::string actual_key,
+                              std::string& msg);
+
+Status ReadTryAcquireLockReply(const json& root, bool& result,
+                               std::string& key);
+
+void WriteTryReleaseLockRequest(const std::string& key, std::string& msg);
+
+Status ReadTryReleaseLockRequest(const json& root, std::string& key);
+
+void WriteTryReleaseLockReply(const bool result, std::string& msg);
+
+Status ReadTryReleaseLockReply(const json& root, bool& result);
+
 }  // namespace vineyard
 
 #endif  // SRC_COMMON_UTIL_PROTOCOLS_H_
diff --git a/src/server/async/socket_server.cc b/src/server/async/socket_server.cc
index a0664a78..46d257b8 100644
--- a/src/server/async/socket_server.cc
+++ b/src/server/async/socket_server.cc
@@ -359,6 +359,10 @@ bool SocketConnection::processMessage(const std::string& message_in) {
     return doShallowCopy(root);
   } else if (cmd == command_t::DEBUG_REQUEST) {
     return doDebug(root);
+  } else if (cmd == command_t::ACQUIRE_LOCK_REQUEST) {
+    return doAcquireLock(root);
+  } else if (cmd == command_t::RELEASE_LOCK_REQUEST) {
+    return doReleaseLock(root);
   } else {
     RESPONSE_ON_ERROR(Status::Invalid("Got unexpected command: " + cmd));
     return false;
@@ -1103,24 +1107,25 @@ bool SocketConnection::doPersist(const json& root) {
   auto self(shared_from_this());
   ObjectID id;
   TRY_READ_REQUEST(ReadPersistRequest, root, id);
-  RESPONSE_ON_ERROR(server_ptr_->Persist(id, [self, id](const Status& status) {
-    std::string message_out;
-    if (status.ok()) {
-      WritePersistReply(message_out);
-      self->doWrite(message_out);
-    } else if (status.IsEtcdError()) {
-      // retry on etcd error: reprocess the message
-      VLOG(100) << "Warning: "
-                << "Retry persist on etcd error: " << status.ToString();
-      self->server_ptr_->GetIOContext().post(
-          [self, id]() { self->doPersist(id); });
-    } else {
-      VLOG(100) << "Error: " << status.ToString();
-      WriteErrorReply(status, message_out);
-      self->doWrite(message_out);
-    }
-    return Status::OK();
-  }));
+  RESPONSE_ON_ERROR(
+      server_ptr_->Persist(id, [self, id, root](const Status& status) {
+        std::string message_out;
+        if (status.ok()) {
+          WritePersistReply(message_out);
+          self->doWrite(message_out);
+        } else if (status.IsEtcdError()) {
+          // retry on etcd error: reprocess the message
+          VLOG(100) << "Warning: "
+                    << "Retry persist on etcd error: " << status.ToString();
+          self->server_ptr_->GetIOContext().post(
+              [self, id, root]() { self->doPersist(root); });
+        } else {
+          VLOG(100) << "Error: " << status.ToString();
+          WriteErrorReply(status, message_out);
+          self->doWrite(message_out);
+        }
+        return Status::OK();
+      }));
   return false;
 }
 
@@ -1760,6 +1765,48 @@ bool SocketConnection::doDebug(const json& root) {
   return false;
 }
 
+// Handle an acquire-lock request; the reply is written from the callback.
+bool SocketConnection::doAcquireLock(const json& root) {
+  auto self(shared_from_this());
+  std::string key;
+  TRY_READ_REQUEST(ReadTryAcquireLockRequest, root, key);
+
+  RESPONSE_ON_ERROR(server_ptr_->TryAcquireLock(
+      key, [self](const Status& status, bool result, std::string actual_key) {
+        std::string message_out;
+        if (status.ok()) {
+          WriteTryAcquireLockReply(result, actual_key, message_out);
+        } else {
+          VLOG(100) << "Error: " << status.ToString();
+          WriteErrorReply(status, message_out);
+        }
+        self->doWrite(message_out);
+        return Status::OK();
+      }));
+  return false;
+}
+
+// Handle a release-lock request; the reply is written from the callback.
+bool SocketConnection::doReleaseLock(const json& root) {
+  auto self(shared_from_this());
+  std::string key;
+  TRY_READ_REQUEST(ReadTryReleaseLockRequest, root, key);
+
+  RESPONSE_ON_ERROR(server_ptr_->TryReleaseLock(
+      key, [self](const Status& status, bool result) {
+        std::string message_out;
+        if (status.ok()) {
+          WriteTryReleaseLockReply(result, message_out);
+        } else {
+          VLOG(100) << "Error: " << status.ToString();
+          WriteErrorReply(status, message_out);
+        }
+        self->doWrite(message_out);
+        return Status::OK();
+      }));
+  return false;
+}
+
 void SocketConnection::doWrite(const std::string& buf) {
   std::string to_send;
   size_t length = buf.size();
diff --git a/src/server/async/socket_server.h b/src/server/async/socket_server.h
index d712df15..a5291582 100644
--- a/src/server/async/socket_server.h
+++ b/src/server/async/socket_server.h
@@ -142,6 +142,9 @@ class SocketConnection : public std::enable_shared_from_this {
 
   bool doDebug(json const& root);
 
+  bool doAcquireLock(json const& root);
+  bool doReleaseLock(json const& root);
+
  protected:
   template
   Status MoveBuffers(std::map mapping,
diff --git a/src/server/server/vineyard_server.cc b/src/server/server/vineyard_server.cc
index 043c30b7..ed18935c 100644
--- a/src/server/server/vineyard_server.cc
+++ b/src/server/server/vineyard_server.cc
@@ -1054,6 +1054,38 @@ Status VineyardServer::MigrateObject(const ObjectID object_id,
   return Status::OK();
 }
 
+// Ask the metadata service to take the lock `key`; the outcome is delivered
+// asynchronously through `callback`.
+Status VineyardServer::TryAcquireLock(std::string& key,
+                                      callback_t<bool, std::string> callback) {
+  ENSURE_VINEYARDD_READY();
+  auto self(shared_from_this());
+  meta_service_ptr_->TryAcquireLock(
+      key, [self, callback](const Status& status, bool result,
+                            std::string actual_key) {
+        // Forward the outcome unchanged, success or failure, so that the
+        // caller's callback always runs.
+        return callback(status, result, actual_key);
+      });
+
+  return Status::OK();
+}
+
+// Ask the metadata service to release the lock `key`; the outcome is
+// delivered asynchronously through `callback`.
+Status VineyardServer::TryReleaseLock(std::string& key,
+                                      callback_t<bool> callback) {
+  ENSURE_VINEYARDD_READY();
+  auto self(shared_from_this());
+  meta_service_ptr_->TryReleaseLock(
+      key, [self, callback](const Status& status, bool result) {
+        // Forward failures too: returning `status` without invoking the
+        // callback means no reply is ever written for the request.
+        return callback(status, result);
+      });
+  return Status::OK();
+}
+
 Status VineyardServer::LabelObjects(const ObjectID object_id,
                                     const std::vector& keys,
                                     const std::vector& values,
diff --git a/src/server/server/vineyard_server.h b/src/server/server/vineyard_server.h
index c9e93663..a51a8323 100644
--- a/src/server/server/vineyard_server.h
+++ b/src/server/server/vineyard_server.h
@@ -193,6 +193,11 @@ class VineyardServer : public std::enable_shared_from_this {
   Status Verify(const std::string& username, const std::string& password,
                 callback_t<> callback);
 
+  Status TryAcquireLock(std::string& key,
+                        callback_t<bool, std::string> callback);
+
+  Status TryReleaseLock(std::string& key, callback_t<bool> callback);
+
   inline SessionID session_id() const { return session_id_; }
   inline InstanceID instance_id() { return instance_id_; }
   inline std::string instance_name() { return instance_name_; }
diff --git a/src/server/services/etcd_meta_service.cc b/src/server/services/etcd_meta_service.cc
index 161ed9fc..c45b5f79 100644
--- a/src/server/services/etcd_meta_service.cc
+++ b/src/server/services/etcd_meta_service.cc
@@ -151,6 +151,49 @@ void EtcdMetaService::Stop() {
   }
 }
 
+// Take the etcd lock `prefix_ + key`. The callback is posted to the meta
+// context with Status::OK() in every case: `true` plus the lock key (prefix
+// stripped) on success, `false` plus an empty key when etcd reports an error.
+void EtcdMetaService::TryAcquireLock(
+    std::string key, callback_t<bool, std::string> callback_after_try_lock) {
+  auto self(shared_from_base<EtcdMetaService>());
+
+  etcd_->lock(prefix_ + key)
+      .then([self, callback_after_try_lock](
+                pplx::task<etcd::Response> const& resp_task) {
+        auto const& resp = resp_task.get();
+        if (resp.is_ok()) {
+          self->server_ptr_->GetMetaContext().post(
+              boost::bind(callback_after_try_lock, Status::OK(), true,
+                          resp.lock_key().substr(self->prefix_.size())));
+        } else {
+          self->server_ptr_->GetMetaContext().post(
+              boost::bind(callback_after_try_lock, Status::OK(), false, ""));
+        }
+      });
+}
+
+// Release the etcd lock `prefix_ + key`. The callback is posted to the meta
+// context with Status::OK() in every case; its flag mirrors whether the etcd
+// response was ok.
+void EtcdMetaService::TryReleaseLock(
+    std::string key, callback_t<bool> callback_after_try_unlock) {
+  auto self(shared_from_base<EtcdMetaService>());
+
+  etcd_->unlock(prefix_ + key)
+      .then([self, callback_after_try_unlock](
+                pplx::task<etcd::Response> const& resp_task) {
+        auto const& resp = resp_task.get();
+        if (resp.is_ok()) {
+          self->server_ptr_->GetMetaContext().post(
+              boost::bind(callback_after_try_unlock, Status::OK(), true));
+        } else {
+          self->server_ptr_->GetMetaContext().post(
+              boost::bind(callback_after_try_unlock, Status::OK(), false));
+        }
+      });
+}
+
 void EtcdMetaService::requestLock(
     std::string lock_name,
     callback_t> callback_after_locked) {
diff --git a/src/server/services/etcd_meta_service.h b/src/server/services/etcd_meta_service.h
index 8fa1403d..4b4624b8 100644
--- a/src/server/services/etcd_meta_service.h
+++ b/src/server/services/etcd_meta_service.h
@@ -127,6 +127,10 @@ class EtcdMetaService : public IMetaService {
 
   ~EtcdMetaService() override {}
 
+  void TryAcquireLock(std::string key, callback_t<bool, std::string>);
+
+  void TryReleaseLock(std::string key, callback_t<bool>);
+
  protected:
   explicit EtcdMetaService(std::shared_ptr& server_ptr)
       : IMetaService(server_ptr),
diff --git a/src/server/services/local_meta_service.h b/src/server/services/local_meta_service.h
index be6c2fd1..6491741d 100644
--- a/src/server/services/local_meta_service.h
+++ b/src/server/services/local_meta_service.h
@@ -50,6 +50,18 @@ class LocalMetaService : public IMetaService {
 
   ~LocalMetaService() override {}
 
+  void TryAcquireLock(std::string key,
+                      callback_t<bool, std::string> callback_after_try_lock) {
+    server_ptr_->GetMetaContext().post(boost::bind(
+        callback_after_try_lock, Status::NotImplemented(), false, ""));
+  }
+
+  void TryReleaseLock(std::string key,
+                      callback_t<bool> callback_after_try_unlock) {
+    server_ptr_->GetMetaContext().post(boost::bind(
+        callback_after_try_unlock, Status::NotImplemented(), false));
+  }
+
  protected:
   explicit LocalMetaService(std::shared_ptr& server_ptr)
       : IMetaService(server_ptr) {}
diff --git a/src/server/services/meta_service.h b/src/server/services/meta_service.h
index 294bcb35..419888e6 100644
--- a/src/server/services/meta_service.h
+++ b/src/server/services/meta_service.h
@@ -142,6 +142,11 @@ class IMetaService : public std::enable_shared_from_this {
 
   bool stopped() const { return
this->stopped_.load(); }
 
+  virtual void TryAcquireLock(std::string key,
+                              callback_t<bool, std::string> callback) = 0;
+
+  virtual void TryReleaseLock(std::string key, callback_t<bool> callback) = 0;
+
  private:
   void registerToEtcd();
 
diff --git a/src/server/services/redis_meta_service.h b/src/server/services/redis_meta_service.h
index b3d3b65f..91adc5e6 100644
--- a/src/server/services/redis_meta_service.h
+++ b/src/server/services/redis_meta_service.h
@@ -177,6 +177,18 @@ class RedisMetaService : public IMetaService {
   inline void Stop() override;
 
   ~RedisMetaService() override {}
 
+  void TryAcquireLock(std::string key,
+                      callback_t<bool, std::string> callback_after_try_lock) {
+    server_ptr_->GetMetaContext().post(boost::bind(
+        callback_after_try_lock, Status::NotImplemented(), false, ""));
+  }
+
+  void TryReleaseLock(std::string key,
+                      callback_t<bool> callback_after_try_unlock) {
+    server_ptr_->GetMetaContext().post(boost::bind(
+        callback_after_try_unlock, Status::NotImplemented(), false));
+  }
+
  protected:
   explicit RedisMetaService(std::shared_ptr& server_ptr)
       : IMetaService(server_ptr),
diff --git a/test/runner.py b/test/runner.py
index 94c65826..f0a7a7ee 100755
--- a/test/runner.py
+++ b/test/runner.py
@@ -474,6 +474,7 @@ def run_vineyard_cpp_tests(meta, allocator, endpoints, tests):
         run_test(tests, 'tensor_test')
         run_test(tests, 'typename_test')
         run_test(tests, 'version_test')
+        run_test(tests, 'kv_state_cache_radix_tree_test')
 
 
 def run_vineyard_spill_tests(meta, allocator, endpoints, tests):
@@ -690,6 +691,35 @@ def run_scale_in_out_tests(meta, allocator, endpoints, instance_size=4):
         time.sleep(5)
 
 
+def run_llm_tests(meta, allocator, endpoints):
+    """Run kv_state_cache_multi_test against two vineyardd instances."""
+    meta_prefix = 'vineyard_test_%s' % time.time()
+    metadata_settings = make_metadata_settings(meta, endpoints, meta_prefix)
+    instance_size = 2
+    with start_multiple_vineyardd(
+        metadata_settings,
+        ['--allocator', allocator],
+        default_ipc_socket=VINEYARD_CI_IPC_SOCKET,
+        instance_size=instance_size,
+        nowait=False,
+    ) as instances:
+        vineyard_ipc_socket_1 = '%s.%d' % (VINEYARD_CI_IPC_SOCKET, 0)
+        vineyard_ipc_socket_2 = '%s.%d' % (VINEYARD_CI_IPC_SOCKET, 1)
+
+        rpc_socket_port = instances[0][1]
+        subprocess.check_call(
+            [
+                './build/bin/kv_state_cache_multi_test',
+                '--vineyard-endpoint',
+                'localhost:%s' % rpc_socket_port,
+                '--vineyard-ipc-sockets',
+                vineyard_ipc_socket_1,
+                vineyard_ipc_socket_2,
+            ],
+            cwd=os.path.join(os.path.dirname(os.path.abspath(__file__)), '..'),
+        )
+
+
 def run_python_deploy_tests(meta, allocator, endpoints, test_args, with_migration):
     meta_prefix = 'vineyard_test_%s' % time.time()
     metadata_settings = make_metadata_settings(meta, endpoints, meta_prefix)
@@ -862,6 +892,12 @@ def parse_sys_args():
         default=False,
         help='Whether to run deployment and scaling in/out tests',
     )
+    arg_parser.add_argument(
+        '--with-llm',
+        action='store_true',
+        default=False,
+        help='Whether to run llm tests',
+    )
     arg_parser.add_argument(
         '--with-migration',
         action='store_true',
@@ -972,6 +1008,14 @@ def execute_tests(args):
         with start_metadata_engine(args.meta) as (_, endpoints):
             run_fuse_test(args.meta, args.allocator, endpoints, python_test_args)
 
+    if args.with_llm:
+        with start_metadata_engine(args.meta) as (_, endpoints):
+            run_llm_tests(
+                args.meta,
+                args.allocator,
+                endpoints,
+            )
+
 
 def main():
     parser, args = parse_sys_args()
@@ -987,6 +1031,7 @@ def main():
         or args.with_deployment
         or args.with_io
         or args.with_fuse
+        or args.with_llm
     ):
         print(
             'Error: \n\tat least one of of --with-{cpp,graph,python,io,fuse} needs '
diff --git a/thirdparty/rax/radix.cc b/thirdparty/rax/radix.cc
new file mode 100644
index 00000000..82739625
--- /dev/null
+++ b/thirdparty/rax/radix.cc
@@ -0,0 +1,2869 @@
+/* Rax -- A radix tree implementation.
+ *
+ * Version 1.2 -- 7 February 2019
+ *
+ * Copyright (c) 2017-2019, Salvatore Sanfilippo
+ * All rights reserved.
+ * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * * Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of Redis nor the names of its contributors may be used + * to endorse or promote products derived from this software without + * specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + */ + +#include +#include +#include +#include +#include +#include +#include + +#include "radix.h" + +#ifndef RAX_MALLOC_INCLUDE +#define RAX_MALLOC_INCLUDE "rax_malloc.h" +#endif + +#include RAX_MALLOC_INCLUDE + +#include +#include "common/util/logging.h" +using namespace vineyard; + +/* This is a special pointer that is guaranteed to never have the same value + * of a radix tree node. 
It's used in order to report "not found" error without
+ * requiring the function to have multiple return values. */
+void *raxNotFound = (void*)"rax-not-found-pointer";
+
+/* -------------------------------- Debugging ------------------------------ */
+
+void raxDebugShowNode(const char *msg, raxNode *n);
+
+/* Turn debugging messages on/off by compiling with RAX_DEBUG_MSG macro on.
+ * When RAX_DEBUG_MSG is defined by default Rax operations will emit a lot
+ * of debugging info to the standard output, however you can still turn
+ * debugging on/off in order to enable it only when you suspect there is an
+ * operation causing a bug using the function raxSetDebugMsg(). */
+#ifdef RAX_DEBUG_MSG
+#define debugf(...)                                                            \
+    if (raxDebugMsg) {                                                         \
+        printf("%s:%s:%d:\t", __FILE__, __FUNCTION__, __LINE__);               \
+        printf(__VA_ARGS__);                                                   \
+        fflush(stdout);                                                        \
+    }
+
+#define debugnode(msg,n) raxDebugShowNode(msg,n)
+#else
+#define debugf(...)
+#define debugnode(msg,n)
+#endif
+
+/* Runtime switch tested by debugf(); it starts as 1, i.e. logging enabled. */
+static int raxDebugMsg = 1;
+
+/* When debug messages are compiled in, turn them on/off dynamically. By
+ * default they are enabled. Pass 0 to disable; any non-zero value
+ * (conventionally 1) re-enables them, since debugf() tests for non-zero. */
+void raxSetDebugMsg(int onoff) {
+    raxDebugMsg = onoff;
+}
+
+/* ------------------------- raxStack functions --------------------------
+ * The raxStack is a simple stack of pointers that is capable of switching
+ * from using a stack-allocated array to dynamic heap once a given number of
+ * items are reached. It is used in order to retain the list of parent nodes
+ * while walking the radix tree in order to implement certain operations that
+ * need to navigate the tree upward.
+ * ------------------------------------------------------------------------- */
+
+/* Initialize the stack.
*/
+static inline void raxStackInit(raxStack *ts) {
+    ts->stack = ts->static_items; /* start on the embedded array */
+    ts->items = 0;
+    ts->maxitems = RAX_STACK_STATIC_ITEMS;
+    ts->oom = 0;
+}
+
+/* Push an item (capacity doubles when full); returns 1, or 0 on out of memory. */
+static inline int raxStackPush(raxStack *ts, void *ptr) {
+    if (ts->items == ts->maxitems) {
+        if (ts->stack == ts->static_items) { /* first overflow: go to the heap */
+            ts->stack = (void **)rax_malloc(sizeof(void*)*ts->maxitems*2);
+            if (ts->stack == NULL) {
+                ts->stack = ts->static_items; /* keep the old items reachable */
+                ts->oom = 1;
+                errno = ENOMEM;
+                return 0;
+            }
+            memcpy(ts->stack,ts->static_items,sizeof(void*)*ts->maxitems);
+        } else { /* already on the heap: grow the existing allocation */
+            void **newalloc = (void **)rax_realloc(ts->stack,sizeof(void*)*ts->maxitems*2);
+            if (newalloc == NULL) {
+                ts->oom = 1;
+                errno = ENOMEM;
+                return 0;
+            }
+            ts->stack = newalloc;
+        }
+        ts->maxitems *= 2;
+    }
+    ts->stack[ts->items] = ptr;
+    ts->items++;
+    return 1;
+}
+
+/* Pop an item from the stack, the function returns NULL if there are no
+ * items to pop. */
+static inline void *raxStackPop(raxStack *ts) {
+    if (ts->items == 0) return NULL;
+    ts->items--;
+    return ts->stack[ts->items];
+}
+
+/* Return the stack item at the top of the stack without actually consuming
+ * it. Returns NULL when the stack is empty. */
+static inline void *raxStackPeek(raxStack *ts) {
+    if (ts->items == 0) return NULL;
+    return ts->stack[ts->items-1];
+}
+
+/* Free the stack in case we used heap allocation. */
+static inline void raxStackFree(raxStack *ts) {
+    if (ts->stack != ts->static_items) rax_free(ts->stack);
+}
+
+/* Add the number of nodes in the stack to each node.
*/
+void raxStackAddNumNodes(raxStack *stack, int num) {
+    for (size_t i=0; i<stack->items; i++) {
+        raxNode *node = (raxNode *)stack->stack[stack->items - i - 1]; /* top first */
+        node->numnodes+=(num);
+        if (node->issubtree) { /* stop once a node flagged issubtree is updated */
+            break;
+        }
+    }
+}
+
+/* ----------------------------------------------------------------------------
+ * Radix tree implementation
+ * --------------------------------------------------------------------------*/
+
+/* Return the padding needed after the 'nodesize' int tokens of a node so
+ * that the child pointers that follow are stored at aligned addresses. The
+ * "+ 4" accounts for a four bytes node header. NOTE(review): this node type
+ * also carries custom_data/timestamp/numnodes -- confirm 4 is still right. */
+#define raxPadding(nodesize) ((sizeof(void*) - (((nodesize) * sizeof(int) + 4) % sizeof(void*))) & (sizeof(void*)-1))
+
+/* Return the pointer to the last child pointer in a node. For the compressed
+ * nodes this is the only child pointer. */
+#define raxNodeLastChildPtr(n) ((raxNode**) ( \
+    ((char*)(n)) + \
+    raxNodeCurrentLength(n) - \
+    sizeof(raxNode*) - \
+    (((n)->iskey && !(n)->isnull) ? sizeof(void*) : 0) \
+))
+
+/* Return the pointer to the first child pointer. */
+#define raxNodeFirstChildPtr(n) ((raxNode**) ( \
+    (char*)((n)->data) + \
+    ((n)->size * sizeof(int)) + \
+    raxPadding((n)->size)))
+
+/* Return the current total size of the node. Note that the second line
+ * computes the padding after the list of tokens, needed in order to
+ * save pointers to aligned addresses. */
+#define raxNodeCurrentLength(n) ( \
+    sizeof(raxNode) + \
+    (n)->size * sizeof(int) + \
+    raxPadding((n)->size) + \
+    ((n)->iscompr ? sizeof(raxNode*) : sizeof(raxNode*) * (n)->size) + \
+    (((n)->iskey && !(n)->isnull) * sizeof(void*)) \
+)
+
+/* Allocate a new non compressed node with the specified number of children.
+ * If datafield is true, the allocation is made large enough to hold the
+ * associated data pointer.
+ * Returns the new node pointer. On out of memory NULL is returned.
*/
+raxNode *raxNewNode(size_t children, int datafield) {
+    size_t nodesize = sizeof(raxNode) + children * sizeof(int) +
+                      raxPadding(children) + sizeof(raxNode*) * children;
+    if (datafield) nodesize += sizeof(void*);
+    raxNode *node = (raxNode *)rax_malloc(nodesize);
+    if (node == NULL) return NULL;
+    /* A fresh node is neither a key nor a sub-tree root and counts itself. */
+    node->iskey = 0;
+    node->isnull = 0;
+    node->iscompr = 0;
+    node->issubtree = 0;
+    node->custom_data = nullptr;
+    node->timestamp = 0;
+    node->numnodes = 1;
+    node->size = children;
+    return node;
+}
+
+/* Allocate a new rax and return its pointer. On out of memory the function
+ * returns NULL. */
+rax *raxNew(void) {
+    rax *rax = (struct rax *)rax_malloc(sizeof(*rax));
+    if (rax == NULL) return NULL;
+    rax->numele = 0;
+    rax->numnodes = 1;
+    rax->head = raxNewNode(0,0);
+    if (rax->head == NULL) {
+        rax_free(rax);
+        return NULL;
+    } else {
+        return rax;
+    }
+}
+
+/* realloc the node to make room for auxiliary data in order
+ * to store an item in that node. On out of memory NULL is returned. */
+raxNode *raxReallocForData(raxNode *n, void *data) {
+    if (data == NULL) return n; /* No reallocation needed, setting isnull=1 */
+    size_t curlen = raxNodeCurrentLength(n);
+    return (raxNode *)rax_realloc(n,curlen+sizeof(void*));
+}
+
+/* Set the node auxiliary data to the specified pointer. */
+void raxSetData(raxNode *n, void *data) {
+    n->iskey = 1;
+    if (data != NULL) {
+        n->isnull = 0;
+        void **ndata = (void**)
+            ((char*)n+raxNodeCurrentLength(n)-sizeof(void*));
+        memcpy(ndata,&data,sizeof(data));
+    } else {
+        n->isnull = 1;
+    }
+}
+
+/* Get the node auxiliary data.
*/
+void *raxGetData(raxNode *n) {
+    if (n->isnull) return NULL;
+    void **ndata =(void**)((char*)n+raxNodeCurrentLength(n)-sizeof(void*));
+    void *data;
+    memcpy(&data,ndata,sizeof(data));
+    return data;
+}
+
+
+/*
+* Set the custom data for the root of sub-tree
+*/
+void raxSetCustomData(raxNode *n, void *data) {
+    n->custom_data = data;
+}
+
+/*
+* Get the custom data for the root of sub-tree
+*/
+void *raxGetCustomData(raxNode *n) {
+    return n->custom_data;
+}
+
+/* Add a new child to the node 'n' representing the token 'c' and return
+ * its new pointer, as well as the child pointer by reference. Additionally
+ * '***parentlink' is populated with the raxNode pointer-to-pointer of where
+ * the new child was stored, which is useful for the caller to replace the
+ * child pointer if it gets reallocated.
+ *
+ * On success the new parent node pointer is returned (it may change because
+ * of the realloc, so the caller should discard 'n' and use the new value).
+ * On out of memory NULL is returned, and the old node is still valid. */
+raxNode *raxAddChild(raxNode *n, int c, raxNode **childptr, raxNode ***parentlink) {
+    assert(n->iscompr == 0);
+
+    size_t curlen = raxNodeCurrentLength(n);
+    n->size++;
+    size_t newlen = raxNodeCurrentLength(n);
+    n->size--; /* For now restore the original size. We'll update it only on
+                  success at the end. */
+
+    /* Alloc the new child we will link to 'n'. */
+    raxNode *child = raxNewNode(0,0);
+    if (child == NULL) return NULL; /* check before touching the child */
+    child->timestamp = n->timestamp;
+
+    int parent_numnodes = n->numnodes;
+    /* Make space in the original node. */
+    raxNode *newn = (raxNode *)rax_realloc(n,newlen);
+    if (newn == NULL) {
+        rax_free(child);
+        return NULL;
+    }
+    n = newn;
+
+    /* After the reallocation, we have up to 8/16 (depending on the system
+     * pointer size, and the required node padding) bytes at the end, that is,
+     * the additional token in the 'data' section, plus one pointer to the new
+     * child, plus the padding needed in order to store addresses into aligned
+     * locations.
+     *
+     * So if we start with the following node, having [1,2,4,5] edges.
+     *
+     * Note:
+     * - We assume 4 bytes pointer for simplicity.
+     * - Each space below corresponds to four byte
+     *
+     * [HDR*][1,2][4,5][1ptr][2ptr][4ptr][5ptr]|AUXP|
+     *
+     * After the reallocation we need: 4 byte for the new edge token
+     * plus 4 bytes for a new child pointer (assuming 32 bit machine).
+     * However after adding 4 byte to the edge tokens, the header + the edge
+     * tokens are no longer aligned, so we also need 4 bytes of padding.
+     * In total the reallocation will add 4+4 bytes = 8 bytes:
+     *
+     * (Blank bytes are represented by ".")
+     *
+     * [HDR*][1,2][4,5][1ptr][2ptr][4ptr][5ptr]|AUXP|[....][....]
+     *
+     * Let's find where to insert the new child in order to make sure
+     * it is inserted in-place lexicographically. Assuming we are adding
+     * a child token 3 in our case pos will be = 2 after the end of the following
+     * loop. */
+    int pos;
+    for (pos = 0; pos < n->size; pos++) {
+        if (n->data[pos] > c) break;
+    }
+
+    /* Now, if present, move auxiliary data pointer at the end
+     * so that we can mess with the other data without overwriting it.
+     * We will obtain something like that:
+     *
+     * [HDR*][1,2][4,5][1ptr][2ptr][4ptr][5ptr][....][....]|AUXP|
+     */
+    char *src, *dst;
+    if (n->iskey && !n->isnull) {
+        src = ((char*)n+curlen-sizeof(void*));
+        dst = ((char*)n+newlen-sizeof(void*));
+        memmove(dst,src,sizeof(void*));
+    }
+
+    /* Compute the "shift", that is, how many bytes we need to move the
+     * pointers section forward because of the addition of the new child
+     * byte in the string section. Note that if we had no padding, that
+     * would be always "4", since we are adding a single token in the string
+     * section of the node (where now there is [1,2,4,5] basically).
+     *
+     * However we have padding, so it could be zero, or up to 8.
+     *
+     * Another way to think at the shift is, how many bytes we need to
+     * move child pointers forward *other than* the obvious sizeof(void*)
+     * needed for the additional pointer itself. */
+    size_t shift = newlen - curlen - sizeof(void*);
+
+    /* We said we are adding a node with edge token 3. The insertion
+     * point is between token 2 and token 4, so the 'pos' variable value is
+     * the index of the first child pointer that we need to move forward
+     * to make space for our new pointer.
+     *
+     * To start, move all the child pointers after the insertion point
+     * of shift+sizeof(pointer) bytes on the right, to obtain:
+     *
+     * [HDR*][1,2][4,5][1ptr][2ptr][....][....][4ptr][5ptr]|AUXP|
+     */
+    src = (char *)n->data+(n->size)*sizeof(int)+
+          raxPadding(n->size)+
+          sizeof(raxNode*)*pos;
+    memmove(src+shift+sizeof(raxNode*),src,sizeof(raxNode*)*(n->size-pos));
+
+    /* Move the pointers to the left of the insertion position as well. Often
+     * we don't need to do anything if there was already some padding to use. In
+     * that case the final destination of the pointers will be the same, however
+     * in our example there was no pre-existing padding, so we added 4 byte
+     * plus 4 bytes of padding. After the next memmove() things will look
+     * like that:
+     *
+     * [HDR*][1,2][4,5][....][1ptr][2ptr][....][4ptr][5ptr]|AUXP|
+     */
+    if (shift) {
+        src = (char*) raxNodeFirstChildPtr(n);
+        memmove(src+shift,src,sizeof(raxNode*)*pos);
+    }
+
+    /* Now make the space for the additional char in the data section,
+     * but also move the pointers before the insertion point to the right
+     * by shift bytes, in order to obtain the following:
+     *
+     * [HDR*][1,2][.4][5.][1ptr][2ptr][....][4ptr][5ptr]|AUXP|
+     */
+    src = (char*)n->data+pos*sizeof(int);
+    memmove(src+sizeof(int),src,(n->size-pos)*sizeof(int)); /* one token slot */
+
+    /* We can now set the character and its child node pointer to get:
+     *
+     * [HDR*][1,2][3,4][5.][1ptr][2ptr][....][4ptr][5ptr]|AUXP|
+     * [HDR*][1,2][3,4][5.][1ptr][2ptr][3ptr][4ptr][5ptr]|AUXP|
+     */
+    n->data[pos] = c;
+    n->numnodes = parent_numnodes + 1;
+    n->size++;
+    src = (char*) raxNodeFirstChildPtr(n);
+    raxNode **childfield = (raxNode**)(src+sizeof(raxNode*)*pos);
+    memcpy(childfield,&child,sizeof(child));
+    *childptr = child;
+    *parentlink = childfield;
+    return n;
+}
+
+/* Turn the node 'n', that must be a node without any children, into a
+ * compressed node representing a set of nodes linked one after the other
+ * and having exactly one child each. The node can be a key or not: this
+ * property and the associated value if any will be preserved.
+ *
+ * The function also returns a child node, since the last node of the
+ * compressed chain cannot be part of the chain: it has zero children while
+ * we can only compress inner nodes with exactly one child each. */
+raxNode *raxCompressNode(raxNode *n, int *s, size_t len, raxNode **child) {
+    assert(n->size == 0 && n->iscompr == 0);
+    void *data = NULL; /* Initialized only to avoid warnings. */
+    size_t newsize;
+
+    debugf("Compress node: %zu tokens\n", len);
+
+    /* Allocate the child to link to this node. */
+    *child = raxNewNode(0,0);
+    if (*child == NULL) return NULL; /* check before touching the child */
+    (*child)->timestamp = n->timestamp;
+
+    /* Make space in the parent node: header, 'len' int tokens, padding and
+     * the single child pointer of a compressed node. */
+    newsize = sizeof(raxNode)+len*sizeof(int)+raxPadding(len)+sizeof(raxNode*);
+    if (n->iskey) {
+        data = raxGetData(n); /* To restore it later. */
+        if (!n->isnull) newsize += sizeof(void*);
+    }
+    raxNode *newn = (raxNode *)rax_realloc(n,newsize);
+    if (newn == NULL) {
+        rax_free(*child);
+        return NULL;
+    }
+    n = newn;
+
+    n->iscompr = 1;
+    n->numnodes++;
+    n->size = len;
+    memcpy(n->data,s,len * sizeof(int));
+    if (n->iskey) raxSetData(n,data);
+    raxNode **childfield = raxNodeLastChildPtr(n);
+    memcpy(childfield,child,sizeof(*child));
+    return n;
+}
+
+/* Low level function that walks the tree looking for the token list
+ * s of 'len' bytes. The function returns the number of tokens
+ * of the key that was possible to process: if the returned integer
+ * is the same as 'len', then it means that the node corresponding to the
+ * token list was found (however it may not be a key in case the node->iskey is
+ * zero or if simply we stopped in the middle of a compressed node, so that
+ * 'splitpos' is non zero).
+ *
+ * Otherwise if the returned integer is not the same as 'len', there was an
+ * early stop during the tree walk because of a token mismatch.
+ *
+ * The node where the search ended (because the full token list was processed
+ * or because there was an early stop) is returned by reference as
+ * '*stopnode' if the passed pointer is not NULL. This node link in the
+ * parent's node is returned as '*plink' if not NULL. Finally, if the
+ * search stopped in a compressed node, '*splitpos' returns the index
+ * inside the compressed node where the search ended. This is useful to
+ * know where to split the node for insertion.
+ * + * Note that when we stop in the middle of a compressed node with + * a perfect match, this function will return a length equal to the + * 'len' argument (all the key matched), and will return a *splitpos which is + * always positive (that will represent the index of the character immediately + * *after* the last match in the current compressed node). + * + * When instead we stop at a compressed node and *splitpos is zero, it + * means that the current node represents the key (that is, none of the + * compressed node tokens list are needed to represent the key, just all + * its parents nodes). */ +static inline size_t raxLowWalk(rax *rax, const int *s, size_t len, raxNode **stopnode, raxNode ***plink, int *splitpos, raxStack *ts, bool set_timestamp = true) { + raxNode *h = rax->head; + raxNode **parentlink = &rax->head; + + size_t i = 0; /* Position in the string. */ + size_t j = 0; /* Position in the node children (or bytes if compressed).*/ + + // std::chrono::milliseconds ms = std::chrono::duration_cast< std::chrono::milliseconds >( + // std::chrono::system_clock::now().time_since_epoch()); + // int64_t timestamp = ms.count(); + auto now = std::chrono::high_resolution_clock::now(); + + // auto micros = std::chrono::time_point_cast(now).time_since_epoch().count(); + auto nanos = std::chrono::duration_cast(now.time_since_epoch()).count(); + int64_t timestamp = nanos; + + while(h->size && i < len) { + debugnode("Lookup current node",h); + int *v = h->data; + + if (h->iscompr) { + for (j = 0; j < h->size && i < len; j++, i++) { + if (v[j] != s[i]) { + break; + } + } + if (j != h->size) { + break; + } + } else { + /* Even when h->size is large, linear scan provides good + * performances compared to other approaches that are in theory + * more sounding, like performing a binary search. */ + for (j = 0; j < h->size; j++) { + if (v[j] == s[i]) { + break; + } + } + if (j == h->size) { + break; + } + i++; + } + + /* Save timestamp. 
*/ + h->timestamp = timestamp; + + if (ts) raxStackPush(ts,h); /* Save stack of parent nodes. */ + raxNode **children = raxNodeFirstChildPtr(h); + if (h->iscompr) j = 0; /* Compressed node only child is at index 0. */ + memcpy(&h,children+j,sizeof(h)); + parentlink = children+j; + j = 0; /* If the new node is compressed and we do not + iterate again (since i == l) set the split + position to 0 to signal this node represents + the searched key. */ + } + if (set_timestamp) { + h->timestamp = timestamp; + } + debugnode("Lookup stop node is",h); + if (stopnode) *stopnode = h; + if (plink) *plink = parentlink; + if (splitpos && h->iscompr) *splitpos = j; + return i; +} + +int handleOutOfMemory(rax *rax, raxNode *h, int *s, size_t len, void **old){ + /* This code path handles out of memory after part of the sub-tree was + * already modified. Set the node as a key, and then remove it. However we + * do that only if the node is a terminal node, otherwise if the OOM + * happened reallocating a node in the middle, we don't need to free + * anything. */ + if (h->size == 0) { + h->isnull = 1; + h->iskey = 1; + rax->numele++; /* Compensate the next remove. */ + assert(raxRemove(rax,s,len,NULL) != 0); + } + errno = ENOMEM; + return 0; +} + +/* Insert the token list 's' of size 'len', setting as auxiliary data + * the pointer 'data'. If the element is already present, the associated + * data is updated (only if 'overwrite' is set to 1), and 0 is returned, + * otherwise the element is inserted and 1 is returned. On out of memory the + * function returns 0 as well but sets errno to ENOMEM, otherwise errno will + * be set to 0. + */ +int raxGenericInsert(rax *rax, int *s, size_t len, void *data, void **old, int overwrite, void **dataNode, bool set_timestamp = true) { + size_t i; + int j = 0; /* Split position. 
If raxLowWalk() stops in a compressed + node, the index 'j' represents the char we stopped within the + compressed node, that is, the position where to split the + node for insertion. */ + raxNode *h, **parentlink; + + raxStack lowWalkStack, splitStack; + raxStackInit(&lowWalkStack); + raxStackInit(&splitStack); + int all_added_node = 0; + + i = raxLowWalk(rax,s,len,&h,&parentlink,&j,&lowWalkStack, set_timestamp); + debugf("######## after raxLowWalk##########"); + /* If i == len we walked following the whole string. If we are not + * in the middle of a compressed node, the string is either already + * inserted or this middle node is currently not a key, but can represent + * our key. We have just to reallocate the node and make space for the + * data pointer. */ + if (i == len && (!h->iscompr || j == 0 /* not in the middle if j is 0 */)) { + debugf("### Insert: node representing key exists\n"); + /* Make space for the value pointer if needed. */ + if (!h->iskey || (h->isnull && overwrite)) { + printf("#############raxReallocForData1 ############\n"); + h = raxReallocForData(h,data); + if (h) memcpy(parentlink,&h,sizeof(h)); + } + if (h == NULL) { + errno = ENOMEM; + return 0; + } + + /* Update the existing key if there is already one. */ + if (h->iskey) { + if (old) *old = raxGetData(h); + if (overwrite) raxSetData(h,data); + *dataNode = h; + errno = 0; + return 0; /* Element already exists. */ + } + + /* Otherwise set the node as a key. Note that raxSetData() + * will set h->iskey. */ + raxSetData(h,data); + rax->numele++; + *dataNode = h; + return 1; /* Element inserted. */ + } + + /* If the node we stopped at is a compressed node, we need to + * split it before to continue. + * + * Splitting a compressed node have a few possible cases. 
+ * Imagine that the node 'h' we are currently at is a compressed + * node containing the token list [1,2,3,4,5,6,7] (it means that it represents + * nodes 1 -> 2 -> 3 -> 4 -> 5 -> 6 -> 7 with the only child + * pointer of this node pointing at the 7 node, because remember that + * we have tokens at the edges of the graph, not inside the nodes + * themselves. + * + * In order to show a real case imagine our node to also point to + * another compressed node, that finally points at the node without + * children, representing 'O': + * + * [1,2,3,4,5,6,7] -> [2,4,5] -> [] + * + * When inserting we may face the following cases. Note that all the cases + * require the insertion of a non compressed node with exactly two + * children, except for the last case which just requires splitting a + * compressed node. + * + * 1) Inserting [1,2,3,6,7,8] + * + * |4| -> [5,6,7] -> [2,4,5] -> [] + * [1,2,3] -> |-| + * |6| -> (... continue algo ...) [7,8] -> [] + * + * 2) Inserting [1,2,3,4,5,6,8] + * + * |7| -> [2,4,5] -> [] + * [1,2,3,4,5,6] -> |-| + * |8| -> (... continue algo ...) [] + * + * 3) Inserting [1,3,4] (Like case 1, but set iscompr = 0 into original node) + * + * |2| -> [3,4,5,6,7] -> [2,4,5] -> [] + * |1| -> |-| + * |3| -> (... continue algo ...) |4| -> [] + * + * 4) Inserting [2,3,4] + * + * |1| -> [2,3,4,5,6,7] -> [2,4,5] -> [] + * |-| + * |2| -> (... continue algo ...) [3,4] -> [] + * + * 5) Inserting [1,2,3] + * + * [1,2,3] -> [4,5,6,7] -> [2,4,5] -> [] + * + * The final algorithm for insertion covering all the above cases is as + * follows. + * + * ============================= ALGO 1 ============================= + * + * For the above cases 1 to 4, that is, all cases where we stopped in + * the middle of a compressed node for a character mismatch, do: + * + * Let $SPLITPOS be the zero-based index at which, in the + * compressed node array of characters, we found the mismatching + * character. 
For example if the node contains [1,2,3,4,5,6,7] and we add + * [1,2,3,6,7,8] the $SPLITPOS is 3, that is, the index at which the + * mismatching token is found. + * + * 1. Save the current compressed node $NEXT pointer (the pointer to the + * child element, that is always present in compressed nodes). + * + * 2. Create "split node" having as child the non common token + * at the compressed node. The other non common token (at the key) + * will be added later as we continue the normal insertion algorithm + * at step "6". + * + * 3a. IF $SPLITPOS == 0: + * Replace the old node with the split node, by copying the auxiliary + * data if any. Fix parent's reference. Free old node eventually + * (we still need its data for the next steps of the algorithm). + * + * 3b. IF $SPLITPOS != 0: + * Trim the compressed node (reallocating it as well) in order to + * contain $splitpos tokens. Change chilid pointer in order to link + * to the split node. If new compressed node len is just 1, set + * iscompr to 0 (layout is the same). Fix parent's reference. + * + * 4a. IF the postfix len (the length of the remaining token list of the + * original compressed node after the split token) is non zero, + * create a "postfix node". If the postfix node has just one token + * set iscompr to 0, otherwise iscompr to 1. Set the postfix node + * child pointer to $NEXT. + * + * 4b. IF the postfix len is zero, just use $NEXT as postfix pointer. + * + * 5. Set child[0] of split node to postfix node. + * + * 6. Set the split node as the current node, set current index at child[1] + * and continue insertion algorithm as usually. + * + * ============================= ALGO 2 ============================= + * + * For case 5, that is, if we stopped in the middle of a compressed + * node but no mismatch was found, do: + * + * Let $SPLITPOS be the zero-based index at which, in the + * compressed node array of token list, we stopped iterating because + * there were no more keys character to match. 
So in the example of + * the node [1,2,3,4,5,6,7], adding the token list [1,2,3,4], the $SPLITPOS is 4. + * + * 1. Save the current compressed node $NEXT pointer (the pointer to the + * child element, that is always present in compressed nodes). + * + * 2. Create a "postfix node" containing all the tokens from $SPLITPOS + * to the end. Use $NEXT as the postfix node child pointer. + * If the postfix node length is 1, set iscompr to 0. + * Set the node as a key with the associated value of the new + * inserted key. + * + * 3. Trim the current node to contain the first $SPLITPOS tokens. + * As usually if the new node length is just 1, set iscompr to 0. + * Take the iskey / associated value as it was in the original node. + * Fix the parent's reference. + * + * 4. Set the postfix node as the only child pointer of the trimmed + * node created at step 1. + */ + + /* ------------------------- ALGORITHM 1 --------------------------- */ + if (h->iscompr && i != len) { + debugf("ALGO 1: Stopped at compressed node (%p)\n",(void*)h); + debugf("Still to insert: "); + debugf("Splitting at %d: '%d'\n", j, ((int*)h->data)[j]); + debugf("Other key is: %d ", s[i]); + + /* 1: Save next pointer. */ + raxNode **childfield = raxNodeLastChildPtr(h); + raxNode *next; + memcpy(&next,childfield,sizeof(next)); + debugf("Next is %p\n", (void*)next); + debugf("iskey %d\n", h->iskey); + if (h->iskey) { + debugf("key value is %p\n", raxGetData(h)); + } + debugf("get next data %p\n", raxGetData(next)); + + /* Set the length of the additional nodes we will need. */ + size_t trimmedlen = j; + size_t postfixlen = h->size - j - 1; + int split_node_is_key = !trimmedlen && h->iskey && !h->isnull; + size_t nodesize; + + /* 2: Create the split node. Also allocate the other nodes we'll need + * ASAP, so that it will be simpler to handle OOM. 
*/ + raxNode *splitnode = raxNewNode(1, split_node_is_key); + splitnode->timestamp = h->timestamp; + raxNode *trimmed = NULL; + raxNode *postfix = NULL; + + if (trimmedlen) { + nodesize = sizeof(raxNode)+trimmedlen * sizeof(int)+raxPadding(trimmedlen)+ + sizeof(raxNode*); + if (h->iskey && !h->isnull) nodesize += sizeof(void*); + trimmed = (raxNode *)rax_malloc(nodesize); + } + + if (postfixlen) { + nodesize = sizeof(raxNode)+postfixlen * sizeof(int)+raxPadding(postfixlen)+ + sizeof(raxNode*); + postfix = (raxNode *)rax_malloc(nodesize); + } + + /* OOM? Abort now that the tree is untouched. */ + if (splitnode == NULL || + (trimmedlen && trimmed == NULL) || + (postfixlen && postfix == NULL)) + { + rax_free(splitnode); + rax_free(trimmed); + rax_free(postfix); + errno = ENOMEM; + return 0; + } + splitnode->data[0] = h->data[j]; + splitnode->numnodes = h->numnodes; + debugf("split node is %p, data is %d\n", &splitnode, splitnode->data[0]); + + if (j == 0) { + /* 3a: Replace the old node with the split node. */ + if (h->iskey) { + void *ndata = raxGetData(h); + raxSetData(splitnode,ndata); + } + memcpy(parentlink,&splitnode,sizeof(splitnode)); + } else { + /* 3b: Trim the compressed node. */ + trimmed->size = j; + memcpy(trimmed->data,h->data,j*sizeof(int)); + trimmed->iscompr = j > 1 ? 1 : 0; + trimmed->numnodes = h->numnodes; + trimmed->iskey = h->iskey; + trimmed->isnull = h->isnull; + trimmed->timestamp = h->timestamp; + if (h->iskey && !h->isnull) { + void *ndata = raxGetData(h); + raxSetData(trimmed,ndata); + } + raxNode **cp = raxNodeLastChildPtr(trimmed); + memcpy(cp,&splitnode,sizeof(splitnode)); + memcpy(parentlink,&trimmed,sizeof(trimmed)); + parentlink = cp; /* Set parentlink to splitnode parent. */ + all_added_node++; + rax->numnodes++; + } + + /* 4: Create the postfix node: what remains of the original + * compressed node after the split. */ + if (postfixlen) { + /* 4a: create a postfix node. 
*/ + postfix->iskey = 0; + postfix->isnull = 0; + postfix->numnodes = h->numnodes; + postfix->size = postfixlen; + postfix->iscompr = postfixlen > 1; + memcpy(postfix->data,h->data+j+1,postfixlen*sizeof(int)); + raxNode **cp = raxNodeLastChildPtr(postfix); + memcpy(cp,&next,sizeof(next)); + all_added_node++; + rax->numnodes++; + } else { + /* 4b: just use next as postfix node. */ + postfix = next; + postfix->numnodes = next->numnodes; + } + + /* 5: Set splitnode first child as the postfix node. */ + raxNode **splitchild = raxNodeLastChildPtr(splitnode); + memcpy(splitchild,&postfix,sizeof(postfix)); + + /* 6. Continue insertion: this will cause the splitnode to + * get a new child (the non common token at the currently + * inserted token list). */ + rax_free(h); + h = splitnode; + if (trimmedlen != 0 && postfixlen == 0) { + trimmed->numnodes++; + raxStackPush(&splitStack, trimmed); + } else if (trimmedlen == 0 && postfixlen != 0) { + splitnode->numnodes++; + raxStackPush(&splitStack, splitnode); + } else { + trimmed->numnodes+=2; + splitnode->numnodes++; + raxStackPush(&splitStack, trimmed); + raxStackPush(&splitStack, splitnode); + } + } else if (h->iscompr && i == len) { + /* ------------------------- ALGORITHM 2 --------------------------- */ + debugf("ALGO 2: Stopped at compressed node %d (%p) j = %d\n", + ((int*)h->data)[j], (void*)h, j); + + /* Allocate postfix & trimmed nodes ASAP to fail for OOM gracefully. 
*/ + size_t postfixlen = h->size - j; + size_t nodesize = sizeof(raxNode)+postfixlen*sizeof(int)+raxPadding(postfixlen)+ + sizeof(raxNode*); + if (data != NULL) nodesize += sizeof(void*); + raxNode *postfix = (raxNode *)rax_malloc(nodesize); + + nodesize = sizeof(raxNode)+j*sizeof(int)+raxPadding(j)+sizeof(raxNode*); + if (h->iskey && !h->isnull) nodesize += sizeof(void*); + raxNode *trimmed = (raxNode *)rax_malloc(nodesize); + + if (postfix == NULL || trimmed == NULL) { + rax_free(postfix); + rax_free(trimmed); + errno = ENOMEM; + return 0; + } + + /* 1: Save next pointer. */ + raxNode **childfield = raxNodeLastChildPtr(h); + raxNode *next; + memcpy(&next,childfield,sizeof(next)); + + /* 2: Create the postfix node. */ + postfix->size = postfixlen; + postfix->iscompr = postfixlen > 1; + postfix->numnodes = h->numnodes; + postfix->iskey = 1; + postfix->isnull = 0; + postfix->timestamp = h->timestamp; + memcpy(postfix->data,h->data+j,postfixlen*sizeof(int)); + raxSetData(postfix,data); + *dataNode = postfix; + raxNode **cp = raxNodeLastChildPtr(postfix); + memcpy(cp,&next,sizeof(next)); + rax->numnodes++; + all_added_node++; + + /* 3: Trim the compressed node. */ + trimmed->size = j; + trimmed->iscompr = j > 1; + trimmed->numnodes = h->numnodes+1; + trimmed->iskey = 0; + trimmed->isnull = 0; + trimmed->timestamp = h->timestamp; + memcpy(trimmed->data,h->data,j*sizeof(int)); + memcpy(parentlink,&trimmed,sizeof(trimmed)); + if (h->iskey) { + void *aux = raxGetData(h); + raxSetData(trimmed,aux); + } + + /* Fix the trimmed node child pointer to point to + * the postfix node. */ + cp = raxNodeLastChildPtr(trimmed); + memcpy(cp,&postfix,sizeof(postfix)); + + raxStackAddNumNodes(&lowWalkStack, all_added_node); + raxStackFree(&lowWalkStack); + /* Finish! We don't need to continue with the insertion + * algorithm for ALGO 2. The key is already inserted. */ + rax->numele++; + rax_free(h); + return 1; /* Key inserted. 
*/ + } + + raxNode *prev_node = NULL; + int insert_new_node = 0; + /* We walked the radix tree as far as we could, but still there are left + * tokens in our string. We need to insert the missing nodes. */ + while(i < len) { + raxNode *child; + // error + + /* If this node is going to have a single child, and there + * are other tokens, so that that would result in a chain + * of single-childed nodes, turn it into a compressed node. */ + if (h->size == 0 && len-i > 1) { + debugf("Inserting compressed node: [%d, \n", s[i]); + size_t comprsize = len-i; + if (comprsize > RAX_NODE_MAX_SIZE) + comprsize = RAX_NODE_MAX_SIZE; + raxNode *newh = raxCompressNode(h,s+i,comprsize,&child); + if (newh == NULL) { + return handleOutOfMemory(rax, h, (int *)s, i, old); + } + h = newh; + memcpy(parentlink,&h,sizeof(h)); + parentlink = raxNodeLastChildPtr(h); + i += comprsize; + } else { + debugf("Inserting normal node %d\n", s[i]); + raxNode **new_parentlink; + raxNode *newh = raxAddChild(h,s[i],&child,&new_parentlink); + if (newh == NULL) { + return handleOutOfMemory(rax, h, (int *)s, i, old); + } + h = newh; + memcpy(parentlink,&h,sizeof(h)); + parentlink = new_parentlink; + i++; + } + if (prev_node != NULL) { + prev_node->numnodes++; + } else { + prev_node = h; + } + all_added_node++; + insert_new_node++; + rax->numnodes++; + h = child; + } + raxStackAddNumNodes(&lowWalkStack, all_added_node); + raxStackAddNumNodes(&splitStack, insert_new_node); + raxStackFree(&lowWalkStack); + raxStackFree(&splitStack); + raxNode *newh = raxReallocForData(h,data); + // printf("#############raxReallocForData2 ############\n"); + if (newh == NULL) { + return handleOutOfMemory(rax, h, (int *)s, i, old); + } + h = newh; + if (!h->iskey) rax->numele++; + raxSetData(h,data); + memcpy(parentlink,&h,sizeof(h)); + *dataNode = h; + return 1; /* Element inserted. */ +} + +/* Overwriting insert. Just a wrapper for raxGenericInsert() that will + * update the element if there is already one for the same key. 
*/ +int raxInsert(rax *rax, int *s, size_t len, void *data, void **old, bool set_timestamp) { + void *dataNode = NULL; + return raxGenericInsert(rax,s,len,data,old,1,&dataNode, set_timestamp); +} + +/* Non overwriting insert function: this if an element with the same key + * exists, the value is not updated and the function returns 0. + * This is a just a wrapper for raxGenericInsert(). */ +int raxTryInsert(rax *rax, int *s, size_t len, void *data, void **old) { + void *dataNode = NULL; + return raxGenericInsert(rax,s,len,data,old,0,&dataNode); +} + +/* +Overwriting insert. Return the raxNode that contains the key. +*/ +int raxInsertAndReturnDataNode(rax *rax, int *s, size_t len, void *data, void **node, void **old) { + return raxGenericInsert(rax,s,len,data,old,1, node); +} + +/* Find a key in the rax, returns raxNotFound special void pointer value + * if the item was not found, otherwise the value associated with the + * item is returned. */ +void *raxFind(rax *rax, int *s, size_t len) { + raxNode *h; + + //debugf("### Lookup: %d\n", (int)len, s); + int splitpos = 0; + size_t i = raxLowWalk(rax,s,len,&h,NULL,&splitpos,NULL); + if (i != len || (h->iscompr && splitpos != 0) || !h->iskey) + return raxNotFound; + return raxGetData(h); +} + +/* +** Find a key in the rax, returns the stack +*/ +raxStack raxFindWithStack(rax *rax, int *s, size_t len) { + raxNode *h; + + raxStack ts; + raxStackInit(&ts); + //debugf("### Lookup: %.*s\n", (int)len, s); + int splitpos = 0; + size_t i = raxLowWalk(rax,s,len,&h,NULL,&splitpos,&ts); + if (i != len || (h->iscompr && splitpos != 0) || !h->iskey) + return ts; + return ts; +} + +/* +** Find a key in the rax, returns the raxNode that contains the key. 
+*/ +raxNode *raxFindAndReturnDataNode(rax *rax, int *s, size_t len, raxNode** sub_tree_node, bool set_timestamp) { + raxNode *h; + + raxStack ts; + raxStackInit(&ts); + //debugf("### Lookup: %.*s\n", (int)len, s); + int splitpos = 0; + size_t i = raxLowWalk(rax,s,len,&h,NULL,&splitpos,&ts,set_timestamp); + if (i != len || (h->iscompr && splitpos != 0) || !h->iskey) + return NULL; + raxNode *tmp = h; + while(tmp != nullptr && tmp->issubtree == false) { + tmp = (raxNode *)raxStackPop(&ts); + } + if (tmp != nullptr && sub_tree_node != nullptr) { + *sub_tree_node = tmp; + } + + return h; +} + +// raxNode *raxSetSubTreeAndReturnDataNode(rax *rax, int *s, size_t len) { +// raxNode *h; + +// raxStack ts; +// raxStackInit(&ts); +// //debugf("### Lookup: %.*s\n", (int)len, s); +// int splitpos = 0; +// size_t i = raxLowWalk(rax,s,len,&h,NULL,&splitpos,&ts,false); +// if (i != len || (h->iscompr && splitpos != 0) || !h->iskey) +// return NULL; + +// if (h!= nullptr) { +// h->issubtree = true; +// raxStackAddNumNodes(&ts, -h->numnodes); +// } + +// return h; +// } + +int raxFindNode(rax *rax, int *s, size_t len, void **node) { + raxNode *h; + raxStack ts; + raxStackInit(&ts); + + int splitpos = 0; + size_t i = raxLowWalk(rax,s,len, &h, nullptr, nullptr, &ts); + if (i != len || (h->iscompr && splitpos != 0) || !h->iskey) + return 0; + *node = raxStackPeek(&ts); + + return 1; +} + +/* +** Find a key in the rax, returns the current node and its parent link. +*/ +int raxFindNodeWithParent(rax *rax, int *s, size_t len, void **node, void **parent) { + raxNode *h; + + int splitpos = 0; + raxNode **parentlink; + size_t i = raxLowWalk(rax,s,len,&h,&parentlink,&splitpos,NULL); + if (i != len || (h->iscompr && splitpos != 0) || !h->iskey) + return 0; + *parent = (void *)parentlink; + *node = (void *)h; + return 1; +} + +/* Return the memory address where the 'parent' node stores the specified + * 'child' pointer, so that the caller can update the pointer with another + * one if needed. 
The function assumes it will find a match, otherwise the + * operation is an undefined behavior (it will continue scanning the + * memory without any bound checking). */ +raxNode **raxFindParentLink(raxNode *parent, raxNode *child) { + raxNode **cp = raxNodeFirstChildPtr(parent); + raxNode *c; + while(1) { + memcpy(&c,cp,sizeof(c)); + if (c == child) break; + cp++; + } + return cp; +} + +/* Low level child removal from node. The new node pointer (after the child + * removal) is returned. Note that this function does not fix the pointer + * of the parent node in its parent, so this task is up to the caller. + * The function never fails for out of memory. */ +raxNode *raxRemoveChild(raxNode *parent, raxNode *child) { + debugnode("raxRemoveChild before", parent); + /* If parent is a compressed node (having a single child, as for definition + * of the data structure), the removal of the child consists into turning + * it into a normal node without children. */ + if (parent->iscompr) { + void *data = NULL; + if (parent->iskey) data = raxGetData(parent); + parent->isnull = 0; + parent->iscompr = 0; + parent->size = 0; + if (parent->iskey) raxSetData(parent,data); + debugnode("raxRemoveChild after", parent); + return parent; + } + + /* + * + * 0. Before remove the child, we need to store the custom + * data if the current node is the root node of subtree + * + */ + + /* Otherwise we need to scan for the child pointer and memmove() + * accordingly. + * + * 1. To start we seek the first element in both the children + * pointers and edge bytes in the node. */ + raxNode **cp = raxNodeFirstChildPtr(parent); + raxNode **c = cp; + int *e = parent->data; + + /* 2. Search the child pointer to remove inside the array of children + * pointers. */ + while(1) { + raxNode *aux; + memcpy(&aux,c,sizeof(aux)); + if (aux == child) break; + c++; + e++; + } + + /* 3. Remove the edge and the pointer by memmoving the remaining children + * pointer and edge bytes one position before. 
*/ + int taillen = parent->size - (e - parent->data) - 1; + memmove(e,(int*)e+1,taillen*sizeof(int)); + + /* Compute the shift, that is the amount of bytes we should move our + * child pointers to the left, since the removal of one edge token + * and the corresponding padding change, may change the layout. + * We just check if in the old version of the node there was at the + * end just a single token and all padding: in that case removing one token + * will remove a whole sizeof(void*) word. */ + size_t shift = ((parent->size * sizeof(int)+4) % sizeof(void*)) == 4 ? sizeof(void *) : 0; + + /* Move the children pointers before the deletion point. */ + if (shift) + memmove(((char*)cp)-shift,cp,(parent->size-taillen-1)*sizeof(raxNode**)); + + /* Move the remaining "tail" pointers at the right position as well. */ + size_t valuelen = (parent->iskey && !parent->isnull) ? sizeof(void*) : 0; + memmove(((char*)c)-shift,c+1,taillen*sizeof(raxNode**)+valuelen); + + /* 4. Update size. */ + parent->size--; + + /* realloc the node according to the theoretical memory usage, to free + * data if we are over-allocating right now. */ + raxNode *newnode; + newnode = (raxNode *)rax_realloc(parent,raxNodeCurrentLength(parent)); + if (newnode) { + debugnode("raxRemoveChild after", newnode); + } + /* Note: if rax_realloc() fails we just return the old address, which + * is valid. */ + return newnode ? newnode : parent; +} + +/* Remove the specified item. Returns 1 if the item was found and + * deleted, 0 otherwise. 
*/ +int raxRemove(rax *rax, int *s, size_t len, void **old, bool set_timestamp) { + raxNode *h; + raxStack ts; + + debugf("### Delete: "); + raxStackInit(&ts); + int splitpos = 0; + int all_added_node = 0; + size_t i = raxLowWalk(rax,s,len,&h,NULL,&splitpos,&ts, set_timestamp); + if (i != len || (h->iscompr && splitpos != 0) || !h->iskey) { + raxStackFree(&ts); + return 0; + } + + + if (old) *old = raxGetData(h); + h->iskey = 0; + rax->numele--; + + /* If this node has no children, the deletion needs to reclaim the + * no longer used nodes. This is an iterative process that needs to + * walk the three upward, deleting all the nodes with just one child + * that are not keys, until the head of the rax is reached or the first + * node with more than one child is found. */ + + int trycompress = 0; /* Will be set to 1 if we should try to optimize the + tree resulting from the deletion. */ + + if (h->size == 0) { + debugf("Key deleted in node without children. Cleanup needed.\n"); + raxNode *child = NULL; + int added_nodes = 0; + while(h != rax->head) { + child = h; + debugf("Freeing child %p, ", (void*)child); + debugf(" key:%d\n", child->iskey) + rax_free(child); + rax->numnodes--; + added_nodes--; + h = (raxNode *)raxStackPop(&ts); + h->numnodes += added_nodes; + /* If this node has more then one child, or actually holds + * a key, stop here. */ + if (h->iskey || (!h->iscompr && h->size != 1)) break; + } + raxStackAddNumNodes(&ts, added_nodes); + if (child) { + debugf("Unlinking child %p from parent %p\n", + (void*)child, (void*)h); + raxNode *newNode = (raxNode *)raxRemoveChild(h,child); + if (newNode != h) { + raxNode *parent = (raxNode *)raxStackPeek(&ts); + raxNode **parentlink; + if (parent == NULL) { + parentlink = &rax->head; + } else { + parentlink = raxFindParentLink(parent,h); + } + memcpy(parentlink,&newNode,sizeof(newNode)); + } + + /* If after the removal the node has just a single child + * and is not a key, we need to try to compress it. 
*/ + if (newNode->size == 1 && newNode->iskey == 0) { + trycompress = 1; + h = newNode; + } + } + } else if (h->size == 1) { + /* If the node had just one child, after the removal of the key + * further compression with adjacent nodes is pontentially possible. */ + trycompress = 1; + } + + /* Don't try node compression if our nodes pointers stack is not + * complete because of OOM while executing raxLowWalk() */ + if (trycompress && ts.oom) trycompress = 0; + + /* Recompression: if trycompress is true, 'h' points to a radix tree node + * that changed in a way that could allow to compress nodes in this + * sub-branch. Compressed nodes represent chains of nodes that are not + * keys and have a single child, so there are two deletion events that + * may alter the tree so that further compression is needed: + * + * 1) A node with a single child was a key and now no longer is a key. + * 2) A node with two children now has just one child. + * + * We try to navigate upward till there are other nodes that can be + * compressed, when we reach the upper node which is not a key and has + * a single child, we scan the chain of children to collect the + * compressible part of the tree, and replace the current node with the + * new one, fixing the child pointer to reference the first non + * compressible node. + * + * Example of case "1". A tree stores the keys [1,2,3] = 1 and + * [1,2,3,4,5] = 2: + * + * + * [1,2,3] -> [4,5] -> [] (2) + * (1) + * + * After the removal of [1,2,3] the tree can be compressed as: + * + * [1,2,3,4,5] -> [] (2) + * + * + * Example of case "2". 
A tree stores the keys [1,2,3,4,5] = 1 and + * [1,2,3,5,6] = 2: + * + * |4| -> [5] -> [] (1) + * [1,2,3] -> |-| + * |5| -> [6] -> [] (2) + * + * After the removal of [1,2,3,5,6] the resulting tree is: + * + * [1,2,3] -> |4| -> [5] -> [] (1) + * + * That can be compressed into: + * + * [1,2,3,5,6] -> [] (1) + */ + if (trycompress) { + debugf("After removing %d len: %d:\n", s[0], (int)len); + debugnode("Compression may be needed",h); + debugf("Seek start node\n"); + + /* Try to reach the upper node that is compressible. + * At the end of the loop 'h' will point to the first node we + * can try to compress and 'parent' to its parent. */ + raxNode *parent; + while(1) { + parent = (raxNode *)raxStackPop(&ts); + if (!parent || parent->iskey || + (!parent->iscompr && parent->size != 1)) break; + h = parent; + debugnode("Going up to",h); + } + raxNode *start = h; /* Compression starting node. */ + int start_num_nodes = start->numnodes; + + /* Scan chain of nodes we can compress. */ + size_t comprsize = h->size; + int nodes = 1; + while(h->size != 0) { + raxNode **cp = raxNodeLastChildPtr(h); + memcpy(&h,cp,sizeof(h)); + if (h->iskey || (!h->iscompr && h->size != 1)) break; + /* Stop here if going to the next node would result into + * a compressed node larger than h->size can hold. */ + if (comprsize + h->size > RAX_NODE_MAX_SIZE) break; + nodes++; + comprsize += h->size; + } + if (nodes > 1) { + /* If we can compress, create the new node and populate it. */ + size_t nodesize = + sizeof(raxNode)+comprsize*sizeof(int)+raxPadding(comprsize)+sizeof(raxNode*); + raxNode *newNode = (raxNode *)rax_malloc(nodesize); + /* An out of memory here just means we cannot optimize this + * node, but the tree is left in a consistent state. 
*/ + if (newNode == NULL) { + raxStackFree(&ts); + return 1; + } + newNode->iskey = 0; + newNode->isnull = 0; + newNode->iscompr = 1; + newNode->size = comprsize; + newNode->numnodes = h->numnodes+1; + newNode->timestamp = h->timestamp; + all_added_node++; + rax->numnodes++; + + /* Scan again, this time to populate the new node content and + * to fix the new node child pointer. At the same time we free + * all the nodes that we'll no longer use. */ + comprsize = 0; + h = start; + while(h->size != 0) { + memcpy((int*)newNode->data+comprsize,h->data,h->size*sizeof(int)); + comprsize += h->size; + raxNode **cp = raxNodeLastChildPtr(h); + raxNode *tofree = h; + memcpy(&h,cp,sizeof(h)); + rax_free(tofree); rax->numnodes--; + all_added_node--; + if (h->iskey || (!h->iscompr && h->size != 1)) break; + } + newNode->numnodes = start_num_nodes+all_added_node; + debugnode("New node",new); + + /* Now 'h' points to the first node that we still need to use, + * so our new node child pointer will point to it. */ + raxNode **cp = (raxNode **)raxNodeLastChildPtr(newNode); + memcpy(cp,&h,sizeof(h)); + + /* Fix parent link. */ + if (parent) { + parent->numnodes+=all_added_node; + raxNode **parentlink = raxFindParentLink(parent,start); + memcpy(parentlink,&newNode,sizeof(newNode)); + } else { + rax->head = newNode; + } + + debugf("Compressed %d nodes, %d total bytes\n", + nodes, (int)comprsize); + } + } + raxStackAddNumNodes(&ts, all_added_node); + raxStackFree(&ts); + return 1; +} + +/* This is the core of raxFree(): performs a depth-first scan of the + * tree and releases all the nodes found. */ +void raxRecursiveFree(rax *rax, raxNode *n, void (*free_callback)(raxNode *)) { + debugnode("free traversing",n); + int numchildren = n->iscompr ? 
1 : n->size; + raxNode **cp = raxNodeLastChildPtr(n); + while(numchildren--) { + raxNode *child; + memcpy(&child,cp,sizeof(child)); + raxRecursiveFree(rax,child,free_callback); + cp--; + } + debugnode("free depth-first",n); + // if n is a key node, we need to free the data + // if n is a subtree, we need to free the custom data + if (free_callback && ((n->iskey && !n->isnull))) + free_callback(n); + rax_free(n); + rax->numnodes--; +} + +/* Free a whole radix tree, calling the specified callback in order to + * free the auxiliary data. */ +void raxFreeWithCallback(rax *rax, void (*free_callback)(raxNode *)) { + raxRecursiveFree(rax,rax->head,free_callback); + assert(rax->numnodes == 0); + rax_free(rax); +} + +/* Free a whole radix tree. */ +void raxFree(rax *rax) { + raxFreeWithCallback(rax,NULL); +} + +/* ------------------------------- Iterator --------------------------------- */ + +/* Initialize a Rax iterator. This call should be performed a single time + * to initialize the iterator, and must be followed by a raxSeek() call, + * otherwise the raxPrev()/raxNext() functions will just return EOF. */ +void raxStart(raxIterator *it, rax *rt) { + it->flags = RAX_ITER_EOF; /* No crash if the iterator is not sought. */ + it->rt = rt; + it->key_len = 0; + it->key = it->key_static_tokens; + it->add_to_subtree_list = false; + it->subtree_list = NULL; + it->subtree_data_list = NULL; + it->key_max = RAX_ITER_STATIC_LEN; + it->data = NULL; + it->node_cb = NULL; + raxStackInit(&it->stack); +} + +/* Append token at the current key string of the iterator 'it'. This + * is a low level function used to implement the iterator, not callable by + * the user. Returns 0 on out of memory, otherwise 1 is returned. */ +int raxIteratorAddToken(raxIterator *it, int *s, size_t len) { + if (it->key_max < it->key_len+len) { + int *old = (it->key == it->key_static_tokens) ? 
NULL : + it->key; + size_t new_max = (it->key_len+len)*2; + // update + it->key = (int *)rax_realloc(old,new_max * sizeof(int)); + if (it->key == NULL) { + it->key = (!old) ? it->key_static_tokens : old; + errno = ENOMEM; + return 0; + } + // update + if (old == NULL) memcpy(it->key,it->key_static_tokens,it->key_len * sizeof(int)); + it->key_max = new_max; + } + /* Use memmove since there could be an overlap between 's' and + * it->key when we use the current key in order to re-seek. */ + // update + memmove(it->key+it->key_len,s,len * sizeof(int)); + it->key_len += len; + return 1; +} + +/* Remove the specified number of chars from the right of the current + * iterator key. */ +void raxIteratorDelChars(raxIterator *it, size_t count) { + it->key_len -= count; +} + +/* Do an iteration step towards the next element. At the end of the step the + * iterator key will represent the (new) current key. If it is not possible + * to step in the specified direction since there are no longer elements, the + * iterator is flagged with RAX_ITER_EOF. + * + * If 'noup' is true the function starts directly scanning for the next + * lexicographically smaller children, and the current node is already assumed + * to be the parent of the last key node, so the first operation to go back to + * the parent will be skipped. This option is used by raxSeek() when + * implementing seeking a non existing element with the ">" or "<" options: + * the starting node is not a key in that particular case, so we start the scan + * from a node that does not represent the key set. + * + * The function returns 1 on success or 0 on out of memory. */ +int raxIteratorNextStep(raxIterator *it, int noup) { + if (it->flags & RAX_ITER_EOF) { + return 1; + } else if (it->flags & RAX_ITER_JUST_SEEKED) { + it->flags &= ~RAX_ITER_JUST_SEEKED; + return 1; + } + + /* Save key len, stack items and the node where we are currently + * so that on iterator EOF we can restore the current key and state. 
*/ + size_t orig_key_len = it->key_len; + size_t orig_stack_items = it->stack.items; + raxNode *orig_node = it->node; + + while(1) { + int children = it->node->iscompr ? 1 : it->node->size; + if (!noup && children) { + debugf("GO DEEPER\n"); + /* Seek the lexicographically smaller key in this subtree, which + * is the first one found always going towards the first child + * of every successive node. */ + if (!raxStackPush(&it->stack,it->node)) return 0; + raxNode **cp = raxNodeFirstChildPtr(it->node); + if (!raxIteratorAddToken(it,it->node->data, + it->node->iscompr ? it->node->size : 1)) return 0; + // TBD + // refactor this code with raxSerial() + // if (it->node->issubtree && it->add_to_subtree_list && it->subtree_list != NULL && + // it->subtree_data_list != NULL) { + // std::cout << "first find subtree list is:" << std::endl; + // std::vector token; + // std::string token_str; + // for (size_t i = 0; i < it->key_len - 1; i++) { + // token.push_back(it->key[i]); + // token_str += std::to_string(it->key[i]) + " "; + // } + // LOG(INFO) << "list is:" << token_str; + // (*it->subtree_list).push_back(token); + // void *data = raxGetCustomData(it->node); + // if (data == NULL) { + // throw std::runtime_error("custom data is null"); + // } + // (*it->subtree_data_list).push_back(data); + // } + memcpy(&it->node,cp,sizeof(it->node)); + /* Call the node callback if any, and replace the node pointer + * if the callback returns true. */ + if (it->node_cb && it->node_cb(&it->node)) + memcpy(cp,&it->node,sizeof(it->node)); + /* For "next" step, stop every time we find a key along the + * way, since the key is lexicograhically smaller compared to + * what follows in the sub-children. 
*/ + if (it->node->iskey) { + it->data = raxGetData(it->node); + return 1; + } + } else { + /* If we finished exporing the previous sub-tree, switch to the + * new one: go upper until a node is found where there are + * children representing keys lexicographically greater than the + * current key. */ + while(1) { + int old_noup = noup; + + /* Already on head? Can't go up, iteration finished. */ + if (!noup && it->node == it->rt->head) { + it->flags |= RAX_ITER_EOF; + it->stack.items = orig_stack_items; + it->key_len = orig_key_len; + it->node = orig_node; + return 1; + } + // TBD + // refactor this code with raxSerial() + // if (it->node->iskey && it->node->size == 0 && it->node->issubtree && it->add_to_subtree_list && it->subtree_list != NULL && + // it->subtree_data_list != NULL) { + // // data node is sub tree + // std::vector token; + // std::string token_str; + // for (size_t i = 0; i < it->key_len - 1; i++) { + // token.push_back(it->key[i]); + // token_str += std::to_string(it->key[i]) + " "; + // } + // LOG(INFO) << "sub tree is:" << token_str; + // (*it->subtree_list).push_back(token); + // void *data = raxGetCustomData(it->node); + // if (data == NULL) { + // throw std::runtime_error("custom data is null"); + // } + // (*it->subtree_data_list).push_back(data); + // } + + /* If there are no children at the current node, try parent's + * next child. */ + int prevchild = it->key[it->key_len-1]; + if (!noup) { + it->node = (raxNode *)raxStackPop(&it->stack); + } else { + noup = 0; + } + /* Adjust the current key to represent the node we are + * at. */ + int todel = it->node->iscompr ? it->node->size : 1; + raxIteratorDelChars(it,todel); + + /* Try visiting the next child if there was at least one + * additional child. */ + if (!it->node->iscompr && it->node->size > (old_noup ? 
0 : 1)) { + raxNode **cp = raxNodeFirstChildPtr(it->node); + int i = 0; + while (i < it->node->size) { + debugf("SCAN NEXT %c\n", it->node->data[i]); + if (it->node->data[i] > prevchild) break; + i++; + cp++; + } + if (i != it->node->size) { + debugf("SCAN found a new node\n"); + raxIteratorAddToken(it,it->node->data+i,1); + if (!raxStackPush(&it->stack,it->node)) return 0; + // if (it->node->issubtree && it->add_to_subtree_list && it->subtree_list != NULL && + // it->subtree_data_list != NULL) { + // std::vector token; + // for (size_t i = 0; i < it->key_len; i++) { + // token.push_back(it->key[i]); + // } + // (*it->subtree_list).push_back(token); + // void *data = raxGetCustomData(it->node); + // if (data == NULL) { + // throw std::runtime_error("custom data is null"); + // } + // (*it->subtree_data_list).push_back(data); + // } + memcpy(&it->node,cp,sizeof(it->node)); + /* Call the node callback if any, and replace the node + * pointer if the callback returns true. */ + if (it->node_cb && it->node_cb(&it->node)) + memcpy(cp,&it->node,sizeof(it->node)); + if (it->node->iskey) { + it->data = raxGetData(it->node); + return 1; + } + break; + } + } + } + } + } +} + +/* Seek the greatest key in the subtree at the current node. Return 0 on + * out of memory, otherwise 1. This is an helper function for different + * iteration functions below. */ +int raxSeekGreatest(raxIterator *it) { + while(it->node->size) { + if (it->node->iscompr) { + if (!raxIteratorAddToken(it,it->node->data, + it->node->size)) return 0; + } else { + if (!raxIteratorAddToken(it,it->node->data+it->node->size-1,1)) + return 0; + } + raxNode **cp = raxNodeLastChildPtr(it->node); + if (!raxStackPush(&it->stack,it->node)) return 0; + memcpy(&it->node,cp,sizeof(it->node)); + } + return 1; +} + +/* Like raxIteratorNextStep() but implements an iteration step moving + * to the lexicographically previous element. The 'noup' option has a similar + * effect to the one of raxIteratorNextStep(). 
*/ +int raxIteratorPrevStep(raxIterator *it, int noup) { + if (it->flags & RAX_ITER_EOF) { + return 1; + } else if (it->flags & RAX_ITER_JUST_SEEKED) { + it->flags &= ~RAX_ITER_JUST_SEEKED; + return 1; + } + + /* Save key len, stack items and the node where we are currently + * so that on iterator EOF we can restore the current key and state. */ + size_t orig_key_len = it->key_len; + size_t orig_stack_items = it->stack.items; + raxNode *orig_node = it->node; + + while(1) { + int old_noup = noup; + + /* Already on head? Can't go up, iteration finished. */ + if (!noup && it->node == it->rt->head) { + it->flags |= RAX_ITER_EOF; + it->stack.items = orig_stack_items; + it->key_len = orig_key_len; + it->node = orig_node; + return 1; + } + + unsigned char prevchild = it->key[it->key_len-1]; + if (!noup) { + it->node = (raxNode *)raxStackPop(&it->stack); + } else { + noup = 0; + } + + /* Adjust the current key to represent the node we are + * at. */ + int todel = it->node->iscompr ? it->node->size : 1; + raxIteratorDelChars(it,todel); + + /* Try visiting the prev child if there is at least one + * child. */ + if (!it->node->iscompr && it->node->size > (old_noup ? 0 : 1)) { + raxNode **cp = raxNodeLastChildPtr(it->node); + int i = it->node->size-1; + while (i >= 0) { + debugf("SCAN PREV %c\n", it->node->data[i]); + if (it->node->data[i] < prevchild) break; + i--; + cp--; + } + /* If we found a new subtree to explore in this node, + * go deeper following all the last children in order to + * find the key lexicographically greater. */ + if (i != -1) { + debugf("SCAN found a new node\n"); + /* Enter the node we just found. */ + if (!raxIteratorAddToken(it,it->node->data+i,1)) return 0; + if (!raxStackPush(&it->stack,it->node)) return 0; + memcpy(&it->node,cp,sizeof(it->node)); + /* Seek sub-tree max. 
*/ + if (!raxSeekGreatest(it)) return 0; + } + } + + /* Return the key: this could be the key we found scanning a new + * subtree, or if we did not find a new subtree to explore here, + * before giving up with this node, check if it's a key itself. */ + if (it->node->iskey) { + it->data = raxGetData(it->node); + return 1; + } + } +} + +/* Seek an iterator at the specified element. + * Return 0 if the seek failed for syntax error or out of memory. Otherwise + * 1 is returned. When 0 is returned for out of memory, errno is set to + * the ENOMEM value. */ +int raxSeek(raxIterator *it, const char *op, int *ele, size_t len) { + int eq = 0, lt = 0, gt = 0, first = 0, last = 0; + + it->stack.items = 0; /* Just resetting. Initialized by raxStart(). */ + it->flags |= RAX_ITER_JUST_SEEKED; + it->flags &= ~RAX_ITER_EOF; + it->key_len = 0; + it->node = NULL; + + /* Set flags according to the operator used to perform the seek. */ + if (op[0] == '>') { + gt = 1; + if (op[1] == '=') eq = 1; + } else if (op[0] == '<') { + lt = 1; + if (op[1] == '=') eq = 1; + } else if (op[0] == '=') { + eq = 1; + } else if (op[0] == '^') { + first = 1; + } else if (op[0] == '$') { + last = 1; + } else { + errno = 0; + return 0; /* Error. */ + } + + /* If there are no elements, set the EOF condition immediately and + * return. */ + if (it->rt->numele == 0) { + it->flags |= RAX_ITER_EOF; + return 1; + } + + if (first) { + /* Seeking the first key greater or equal to the empty string + * is equivalent to seeking the smaller key available. */ + return raxSeek(it,">=",NULL,0); + } + + if (last) { + /* Find the greatest key taking always the last child till a + * final node is found. */ + it->node = it->rt->head; + if (!raxSeekGreatest(it)) return 0; + assert(it->node->iskey); + it->data = raxGetData(it->node); + return 1; + } + + /* We need to seek the specified key. What we do here is to actually + * perform a lookup, and later invoke the prev/next key code that + * we already use for iteration. 
*/ + int splitpos = 0; + size_t i = raxLowWalk(it->rt,ele,len,&it->node,NULL,&splitpos,&it->stack); + + /* Return OOM on incomplete stack info. */ + if (it->stack.oom) return 0; + + if (eq && i == len && (!it->node->iscompr || splitpos == 0) && + it->node->iskey) + { + /* We found our node, since the key matches and we have an + * "equal" condition. */ + if (!raxIteratorAddToken(it,ele,len)) return 0; /* OOM. */ + it->data = raxGetData(it->node); + } else if (lt || gt) { + /* Exact key not found or eq flag not set. We have to set as current + * key the one represented by the node we stopped at, and perform + * a next/prev operation to seek. To reconstruct the key at this node + * we start from the parent and go to the current node, accumulating + * the characters found along the way. */ + if (!raxStackPush(&it->stack,it->node)) return 0; + for (size_t j = 1; j < it->stack.items; j++) { + raxNode *parent = (raxNode *)it->stack.stack[j-1]; + raxNode *child = (raxNode *)it->stack.stack[j]; + if (parent->iscompr) { + if (!raxIteratorAddToken(it,parent->data,parent->size)) + return 0; + } else { + raxNode **cp = raxNodeFirstChildPtr(parent); + int *p = parent->data; + while(1) { + raxNode *aux; + memcpy(&aux,cp,sizeof(aux)); + if (aux == child) break; + cp++; + p++; + } + if (!raxIteratorAddToken(it,p,1)) return 0; + } + } + raxStackPop(&it->stack); + + /* We need to set the iterator in the correct state to call next/prev + * step in order to seek the desired element. */ + debugf("After initial seek: i=%d len=%d key=%.*s\n", + (int)i, (int)len, (int)it->key_len, it->key); + if (i != len && !it->node->iscompr) { + /* If we stopped in the middle of a normal node because of a + * mismatch, add the mismatching character to the current key + * and call the iterator with the 'noup' flag so that it will try + * to seek the next/prev child in the current node directly based + * on the mismatching character. 
*/ + if (!raxIteratorAddToken(it,ele+i,1)) return 0; + debugf("Seek normal node on mismatch: %.*s\n", + (int)it->key_len, (char*)it->key); + + it->flags &= ~RAX_ITER_JUST_SEEKED; + if (lt && !raxIteratorPrevStep(it,1)) return 0; + if (gt && !raxIteratorNextStep(it,1)) return 0; + it->flags |= RAX_ITER_JUST_SEEKED; /* Ignore next call. */ + } else if (i != len && it->node->iscompr) { + debugf("Compressed mismatch: %.*s\n", + (int)it->key_len, (char*)it->key); + /* In case of a mismatch within a compressed node. */ + int nodechar = it->node->data[splitpos]; + int keychar = ele[i]; + it->flags &= ~RAX_ITER_JUST_SEEKED; + if (gt) { + /* If the key the compressed node represents is greater + * than our seek element, continue forward, otherwise set the + * state in order to go back to the next sub-tree. */ + if (nodechar > keychar) { + if (!raxIteratorNextStep(it,0)) return 0; + } else { + if (!raxIteratorAddToken(it,it->node->data,it->node->size)) + return 0; + if (!raxIteratorNextStep(it,1)) return 0; + } + } + if (lt) { + /* If the key the compressed node represents is smaller + * than our seek element, seek the greater key in this + * subtree, otherwise set the state in order to go back to + * the previous sub-tree. */ + if (nodechar < keychar) { + if (!raxSeekGreatest(it)) return 0; + it->data = raxGetData(it->node); + } else { + if (!raxIteratorAddToken(it,it->node->data,it->node->size)) + return 0; + if (!raxIteratorPrevStep(it,1)) return 0; + } + } + it->flags |= RAX_ITER_JUST_SEEKED; /* Ignore next call. */ + } else { + debugf("No mismatch: %.*s\n", + (int)it->key_len, (char*)it->key); + /* If there was no mismatch we are into a node representing the + * key, (but which is not a key or the seek operator does not + * include 'eq'), or we stopped in the middle of a compressed node + * after processing all the key. Continue iterating as this was + * a legitimate key we stopped at. 
*/ + it->flags &= ~RAX_ITER_JUST_SEEKED; + if (it->node->iscompr && it->node->iskey && splitpos && lt) { + /* If we stopped in the middle of a compressed node with + * perfect match, and the condition is to seek a key "<" than + * the specified one, then if this node is a key it already + * represents our match. For instance we may have nodes: + * + * "f" -> "oobar" = 1 -> "" = 2 + * + * Representing keys "f" = 1, "foobar" = 2. A seek for + * the key < "foo" will stop in the middle of the "oobar" + * node, but will be our match, representing the key "f". + * + * So in that case, we don't seek backward. */ + it->data = raxGetData(it->node); + } else { + if (gt && !raxIteratorNextStep(it,0)) return 0; + if (lt && !raxIteratorPrevStep(it,0)) return 0; + } + it->flags |= RAX_ITER_JUST_SEEKED; /* Ignore next call. */ + } + } else { + /* If we are here just eq was set but no match was found. */ + it->flags |= RAX_ITER_EOF; + return 1; + } + return 1; +} + +/* Go to the next element in the scope of the iterator 'it'. + * If EOF (or out of memory) is reached, 0 is returned, otherwise 1 is + * returned. In case 0 is returned because of OOM, errno is set to ENOMEM. */ +int raxNext(raxIterator *it) { + if (!raxIteratorNextStep(it,0)) { + errno = ENOMEM; + return 0; + } + if (it->flags & RAX_ITER_EOF) { + errno = 0; + return 0; + } + return 1; +} + +/* Go to the previous element in the scope of the iterator 'it'. + * If EOF (or out of memory) is reached, 0 is returned, otherwise 1 is + * returned. In case 0 is returned because of OOM, errno is set to ENOMEM. */ +int raxPrev(raxIterator *it) { + if (!raxIteratorPrevStep(it,0)) { + errno = ENOMEM; + return 0; + } + if (it->flags & RAX_ITER_EOF) { + errno = 0; + return 0; + } + return 1; +} + +/* Perform a random walk starting in the current position of the iterator. + * Return 0 if the tree is empty or on out of memory. 
Otherwise 1 is returned
 * and the iterator is set to the node reached after doing a random walk
 * of 'steps' steps. If the 'steps' argument is 0, the random walk is performed
 * using a random number of steps between 1 and two times the logarithm of
 * the number of elements.
 *
 * NOTE: if you use this function to generate random elements from the radix
 * tree, expect a disappointing distribution. A random walk produces good
 * random elements if the tree is not sparse, however in the case of a radix
 * tree certain keys will be reported much more often than others. At least
 * this function should be able to explore every possible element eventually. */
int raxRandomWalk(raxIterator *it, size_t steps) {
    /* Nothing to walk over: flag EOF on the iterator and report failure. */
    if (it->rt->numele == 0) {
        it->flags |= RAX_ITER_EOF;
        return 0;
    }

    /* A zero step count means "choose one for me": draw it from
     * [1, 2*(1+floor(log(numele)))]. */
    if (steps == 0) {
        size_t bound = 1+floor(log(it->rt->numele));
        bound *= 2;
        steps = 1 + rand() % bound;
    }

    raxNode *cur = it->node;
    for (;;) {
        /* Stop only once the step budget is used up AND we sit on a key. */
        if (steps == 0 && cur->iskey) break;

        int fanout = cur->iscompr ? 1 : cur->size;
        /* Every node but the head gets one extra slot meaning "go up". */
        int pick = rand() % (fanout+(cur != it->rt->head));

        if (pick != fanout) {
            /* Descend into the picked child, extending the iterator key
             * with the whole compressed run or with the single edge token. */
            int *edge = cur->iscompr ? cur->data : cur->data+pick;
            size_t edge_len = cur->iscompr ? cur->size : 1;
            if (!raxIteratorAddToken(it,edge,edge_len)) return 0;
            raxNode **slot = raxNodeFirstChildPtr(cur)+pick;
            if (!raxStackPush(&it->stack,cur)) return 0;
            memcpy(&cur,slot,sizeof(cur));
        } else {
            /* Climb to the parent and drop the tokens of the edge that
             * led from it down to where we were. */
            cur = (raxNode *)raxStackPop(&it->stack);
            int todel = cur->iscompr ? cur->size : 1;
            raxIteratorDelChars(it,todel);
        }

        /* Only arrivals on key nodes consume a step. */
        if (cur->iskey) steps--;
    }

    it->node = cur;
    it->data = raxGetData(it->node);
    return 1;
}

/* Compare the key currently pointed by the iterator to the specified
 * key according to the specified operator. Returns 1 if the comparison is
 * true, otherwise 0 is returned.
*/ +int raxCompare(raxIterator *iter, const char *op, int *key, size_t key_len) { + int eq = 0, lt = 0, gt = 0; + + if (op[0] == '=' || op[1] == '=') eq = 1; + if (op[0] == '>') gt = 1; + else if (op[0] == '<') lt = 1; + else if (op[1] != '=') return 0; /* Syntax error. */ + + size_t minlen = key_len < iter->key_len ? key_len : iter->key_len; + // update + int cmp = memcmp(iter->key,key,minlen * sizeof(int)); + + /* Handle == */ + if (lt == 0 && gt == 0) return cmp == 0 && key_len == iter->key_len; + + /* Handle >, >=, <, <= */ + if (cmp == 0) { + /* Same prefix: longer wins. */ + if (eq && key_len == iter->key_len) return 1; + else if (lt) return iter->key_len < key_len; + else if (gt) return iter->key_len > key_len; + else return 0; /* Avoid warning, just 'eq' is handled before. */ + } else if (cmp > 0) { + return gt ? 1 : 0; + } else /* (cmp < 0) */ { + return lt ? 1 : 0; + } +} + +/* Free the iterator. */ +void raxStop(raxIterator *it) { + if (it->key != it->key_static_tokens) rax_free(it->key); + raxStackFree(&it->stack); +} + +/* Return if the iterator is in an EOF state. This happens when raxSeek() + * failed to seek an appropriate element, so that raxNext() or raxPrev() + * will return zero, or when an EOF condition was reached while iterating + * with raxNext() and raxPrev(). */ +int raxEOF(raxIterator *it) { + return it->flags & RAX_ITER_EOF; +} + +/* Return the number of elements inside the radix tree. */ +uint64_t raxSize(rax *rax) { + return rax->numele; +} + +/* ----------------------------- Introspection ------------------------------ */ + +/* This function is mostly used for debugging and learning purposes. + * It shows an ASCII representation of a tree on standard output, outling + * all the nodes and the contained keys. 
+ * + * The representation is as follow: + * + * [1,2,3,4] (compressed node) + * [5,6] (normal node with three children) + * [5,6]=0x12345678 (node is a key, pointing to value 0x12345678) + * [] (a normal empty node) + * + * Children are represented in new idented lines, each children prefixed by + * the "`-(x)" string, where "x" is the edge byte. + * + * [1,2,3] + * `-(1) [1,1,1,1] + * `-(2) [3,4] + * `-(3) [] + * + * However when a node has a single child the following representation + * is used instead: + * + * [1,2] -> [1,2,3,4] -> [] + */ + +struct DebugDatawrapper { + void *data; + int length; +}; + +struct DebugTreeData { + union { + void* kvStateCacheBlockBuilder; + uint64_t builderObjectID; + }; + bool isPtr = true; +}; +/* The actual implementation of raxShow(). */ +void raxRecursiveShow(int level, int lpad, raxNode *n) { + char s = n->iscompr ? '"' : '['; + char e = n->iscompr ? '"' : ']'; + + int numchars = printf("%c", s); + for (int i = 0; i < n->size; i++) { + numchars += printf("%d ", n->data[i]); + } + numchars += printf("%c %d ", e, n->numnodes); + + if (n->issubtree) { + numchars += printf("# "); + printf(" %p ", n); + } + if (n->iskey) { + numchars += printf("=%p",raxGetData(n)); + } + numchars += printf(" node:%p time:%ld, data:%p, is_sub_tree:%d is_compr:%d", n, n->timestamp, n->custom_data, n->issubtree, n->iscompr); + if (n->issubtree && n->custom_data != NULL) { + numchars += printf(" cus data:%p" , ((DebugDatawrapper *)(n->custom_data))->data); + DebugTreeData *data = (DebugTreeData *)((DebugDatawrapper *)(n->custom_data))->data; + if (data) { + if (data->isPtr) { + numchars += printf(" builder ptr:%p", data->kvStateCacheBlockBuilder); + } else { + numchars += printf(" builder id:%lu", data->builderObjectID); + } + } + } + + int numchildren = n->iscompr ? 1 : n->size; + /* Note that 7 and 4 magic constants are the string length + * of " `-(x) " and " -> " respectively. */ + if (level) { + lpad += (numchildren > 1) ? 
7 : 4; + if (numchildren == 1) lpad += numchars; + } + raxNode **cp = raxNodeFirstChildPtr(n); + for (int i = 0; i < numchildren; i++) { + const char *branch = " `-(%d) "; + if (numchildren > 1) { + printf("\n"); + for (int j = 0; j < lpad; j++) putchar(' '); + printf(branch,n->data[i]); + } else { + printf(" -> "); + } + raxNode *child; + memcpy(&child,cp,sizeof(child)); + raxRecursiveShow(level+1,lpad,child); + cp++; + } +} + + +/* Show a tree, as outlined in the comment above. */ +void raxShow(rax *rax) { + printf("rax numnode:%lu\n", rax->numele); + raxRecursiveShow(0,0,rax->head); + putchar('\n'); +} + +/* Used by debugnode() macro to show info about a given node. */ +void raxDebugShowNode(const char *msg, raxNode *n) { + if (raxDebugMsg == 0) return; + printf("%s: %p [ ",msg, (void*)n); + + for (int i=0;i<(int)n->size;i++) { + printf("%d,", *(int*)(n->data+i)); + } + printf("],key:%d size:%d children:",n->iskey, n->size); + int numcld = n->iscompr ? 1 : n->size; + raxNode **cldptr = raxNodeLastChildPtr(n) - (numcld-1); + while(numcld--) { + raxNode *child; + memcpy(&child,cldptr,sizeof(child)); + cldptr++; + printf("%p ", (void*)child); + } + printf("\n"); + fflush(stdout); +} + +/* Touch all the nodes of a tree returning a check sum. This is useful + * in order to make Valgrind detect if there is something wrong while + * reading the data structure. + * + * This function was used in order to identify Rax bugs after a big refactoring + * using this technique: + * + * 1. The rax-test is executed using Valgrind, adding a printf() so that for + * the fuzz tester we see what iteration in the loop we are in. + * 2. After every modification of the radix tree made by the fuzz tester + * in rax-test.c, we add a call to raxTouch(). + * 3. Now as soon as an operation will corrupt the tree, raxTouch() will + * detect it (via Valgrind) immediately. We can add more calls to narrow + * the state. + * 4. 
At this point a good idea is to enable Rax debugging messages immediately + * before the moment the tree is corrupted, to see what happens. + */ +unsigned long raxTouch(raxNode *n) { + debugf("Touching %p\n", (void*)n); + unsigned long sum = 0; + if (n->iskey) { + sum += (unsigned long)raxGetData(n); + } + + int numchildren = n->iscompr ? 1 : n->size; + raxNode **cp = raxNodeFirstChildPtr(n); + int count = 0; + for (int i = 0; i < numchildren; i++) { + if (numchildren > 1) { + sum += (long)n->data[i]; + } + raxNode *child; + memcpy(&child,cp,sizeof(child)); + if (child == (void*)0x65d1760) count++; + if (count > 1) exit(1); + sum += raxTouch(child); + cp++; + } + return sum; +} + +/* +* Traverse the tree and collect all the nodes that contain data. +* these nodes are stored in the dataNodeList +*/ +void raxTraverse(raxNode *n, std::vector> &dataNodeList) { + if (n->iskey) { + dataNodeList.push_back(std::shared_ptr(n, [](raxNode*){})); + } + + int numchildren = n->iscompr ? 1 : n->size; + raxNode **cp = raxNodeFirstChildPtr(n); + for (int i = 0; i < numchildren; i++) { + raxNode *child; + memcpy(&child,cp,sizeof(child)); + raxTraverse(child, dataNodeList); + cp++; + } +} + +/* +* Set a node as a subtree root node +*/ +void raxSetSubtree(raxNode *node) { + node->issubtree = 1; +} + +/* +* Check if a node is a subtree root node +*/ +bool raxIsSubtree(raxNode *node) { + if (node->issubtree) { + return true; + } + return false; +} + +/* +* Split the tree into two sub trees, and return the root node of the new sub tree +* +* Input a token list, and split the tree into two sub trees via the token list +* It will find the node that nearest N/2 but not more than N/2, and split the tree +* into two sub trees. If there is no node that has N/2 children, it will split the +* tree from the root node. 
+* +*/ +raxNode *raxSplit(rax *rax, int *s, size_t len, std::vector& token) { + raxNode *childNode = NULL; + raxNode *splitNode = NULL; + raxStack stack = raxFindWithStack(rax, s, len); + int items = stack.items; + int subtreeNumNodes = 0; + // find the latest subtree root node + int index = stack.items - 1; + while (index >= 0) { + raxNode *node = (raxNode *)stack.stack[index]; + if (node->issubtree) { + subtreeNumNodes = node->numnodes; + break; + } + index--; + } + + // find the node that has N/2 children + while (items > 0) { + raxNode *node = (raxNode *)raxStackPop(&stack); + if (node->numnodes > (uint32_t)subtreeNumNodes/2 || node->issubtree) { + splitNode = childNode; + raxStackPush(&stack, node); + break; + } + childNode = node; + items--; + } + + raxIterator iter; + raxStart(&iter, rax); + raxSeek(&iter, "^", NULL, 0); + while (raxNext(&iter)) { + if (iter.node == splitNode) { + for (size_t i = 0; i < iter.key_len; i++) { + token.push_back(iter.key[i]); + } + } + } + std::string token_str; + for (size_t i = 0; i < token.size(); i++) { + token_str += std::to_string(token[i]); + token_str += " "; + } + VLOG(100) << "split token: " << token_str; + + // if the splitNode is NULL, it means that the tree only has one node + if (splitNode == NULL) { + // if the stack is not empty, it means that top of the stack is the target node + if (stack.items > 0) { + return (raxNode *)raxStackPop(&stack); + } else { + return rax->head; + } + } + + raxSetSubtree(splitNode); + + raxStackAddNumNodes(&stack, -(int)(splitNode->numnodes)); + raxStackFree(&stack); + + return splitNode; +} + +/* +* Traverse the subtree and return all the nodes that contain data under the subtree +* these nodes are stored in the dataNodeList +*/ + +void raxTraverseSubTree(raxNode *n, std::vector &dataNodeList) { + if (n->iskey) { + dataNodeList.push_back(n); + } + + int numchildren = n->iscompr ? 
1 : n->size; + raxNode **cp = raxNodeFirstChildPtr(n); + for (int i = 0; i < numchildren; i++) { + raxNode *child; + memcpy(&child,cp,sizeof(child)); + if (!child->issubtree) { + raxTraverseSubTree(child, dataNodeList); + } + cp++; + } +} + +void raxSerialize(rax *root, std::vector> &tokenList, std::vector &dataList, std::vector ×tampList, + std::vector> *subtreeList, std::vector *subtreeDataList) { + raxIterator iter; + raxStart(&iter, root); + iter.add_to_subtree_list = 1; + iter.subtree_list = subtreeList; + iter.subtree_data_list = subtreeDataList; + raxSeek(&iter, "^", NULL, 0); + while (raxNext(&iter)) { + std::vector token; + for (size_t i = 0; i < iter.key_len; i++) { + token.push_back(iter.key[i]); + } + tokenList.push_back(token); + dataList.push_back(iter.data); + timestampList.push_back(iter.node->timestamp); + raxNode *data = raxFindAndReturnDataNode(root, iter.key, iter.key_len, nullptr, false); + if (data->issubtree && subtreeList != nullptr) { + subtreeList->push_back(token); + subtreeDataList->push_back(data->custom_data); + } + } + raxStop(&iter); +} + +void raxFindLastRecentNode(raxNode *node, std::vector& key) { + raxNode** childList = raxNodeFirstChildPtr(node); + + // node must have a key. + // assert(node->iskey == 1); + int numChildren = node->iscompr ? 
1 : node->size; + if (numChildren == 0) { + // has no children, return + return; + } + + raxNode *chosenChild = childList[0]; + int chosenChildIndex = 0; + for (int i = 1; i < numChildren; i++) { + if (childList[i]->timestamp == chosenChild->timestamp) { + if (childList[i]->numnodes > chosenChild->numnodes) { + VLOG(100) << "childList[i]->numnodes > chossenChild->numnodes"; + VLOG(100) << "node1:" << childList[i] << " node:2" << chosenChild; + chosenChild = childList[i]; + chosenChildIndex = i; + } + } else if (childList[i]->timestamp < chosenChild->timestamp) { + chosenChild = childList[i]; + chosenChildIndex = i; + } + } + + if (node->iscompr) { + for (int i = 0; i < node->size; i++) { + key.push_back(node->data[i]); + } + } else { + key.push_back(node->data[chosenChildIndex]); + } + + raxFindLastRecentNode(chosenChild, key); +} + +bool compareKey(int *first_key, int *second_key, int first_key_len, int second_key_len) { + if (first_key_len != second_key_len) { + printf("length not equal, %d : %d\n", first_key_len, second_key_len); + return false; + } + for (int i = 0; i < first_key_len; i++) { + if (first_key[i] != second_key[i]) { + printf("key not equal, %d : %d\n", first_key[i], second_key[i]); + return false; + } + } + return true; +} + +bool compare(raxIterator a, raxIterator b) { + if (a.key_len == b.key_len) { + return a.node->timestamp > b.node->timestamp; + } + return a.key_len < b.key_len; +} + +void sortNode(std::vector &ite_list) { + std::sort(ite_list.begin(), ite_list.end(), compare); +} + +void printVector(int* v, int size) { + printf("token:\n"); + for (int i = 0; i < size; i++) { + printf("%d " , v[i]); + } + printf("\n"); +} + +void freeVector(std::vector &ite_list) { + for (size_t i = 0; i < ite_list.size(); i++) { + free(ite_list[i].key); + } +} + +void mergeTree(rax* first_tree, rax* second_tree, + std::vector>& evicted_tokens, + std::set>& insert_tokens, int max_node) { + printf("merge tree!\n"); + VLOG(100) << "==============tree 
1===================="; + //raxShow(first_tree); + VLOG(100) << "==============tree 2===================="; + //raxShow(second_tree); + raxIterator first_tree_iter; + raxIterator second_tree_iter; + rax* tmp = raxNew(); + + raxStart(&first_tree_iter, first_tree); + raxStart(&second_tree_iter, second_tree); + raxSeek(&first_tree_iter, "^", NULL, 0); + raxSeek(&second_tree_iter, "^", NULL, 0); + + std::vector first_tree_iter_list; + std::vector second_tree_iter_list; + while (raxNext(&first_tree_iter)) { + raxIterator tmp_iter = first_tree_iter; + tmp_iter.key = (int*) malloc(sizeof(int) * first_tree_iter.key_len); + memcpy(tmp_iter.key, first_tree_iter.key, + sizeof(int) * first_tree_iter.key_len); + first_tree_iter_list.push_back(tmp_iter); + } + + while (raxNext(&second_tree_iter)) { + raxIterator tmp_iter = second_tree_iter; + tmp_iter.key = (int*) malloc(sizeof(int) * second_tree_iter.key_len); + memcpy(tmp_iter.key, second_tree_iter.key, + sizeof(int) * second_tree_iter.key_len); + second_tree_iter_list.push_back(tmp_iter); + } + + for (size_t i = 0; i < first_tree_iter_list.size(); i++) { + printVector(first_tree_iter_list[i].key, first_tree_iter_list[i].key_len); + } + + for (size_t i = 0; i < second_tree_iter_list.size(); i++) { + printVector(second_tree_iter_list[i].key, second_tree_iter_list[i].key_len); + } + + // Sort by the length of the key, or timestamp if the keys have the same length. + sortNode(first_tree_iter_list); + sortNode(second_tree_iter_list); + + size_t first_tree_index = 0; + size_t second_tree_index = 0; + int nodeCount = 0; + + /** + * We use two structures to store the nodes chosen from the second tree + * and the nodes evicted from the first tree. 
+ */ + while (nodeCount < max_node) { + if (first_tree_index == first_tree_iter_list.size() || + second_tree_index == second_tree_iter_list.size()) { + break; + } + printf("nodeCount: %d\n", nodeCount); + + /** + * If the key is the same, use the larger timestamp to refresh + * the timestamp of the key in the first tree. If the key is not + * the same, choose the key with the larger timestamp. + */ + if (compareKey(first_tree_iter_list[first_tree_index].key, + second_tree_iter_list[second_tree_index].key, + first_tree_iter_list[first_tree_index].key_len, + second_tree_iter_list[second_tree_index].key_len)) { + // same key + printf("same key\n"); + first_tree_iter_list[first_tree_index].node->timestamp = + std::max(first_tree_iter_list[first_tree_index].node->timestamp, + second_tree_iter_list[second_tree_index].node->timestamp); + + raxInsert(tmp, first_tree_iter_list[first_tree_index].key, + first_tree_iter_list[first_tree_index].key_len, + first_tree_iter_list[first_tree_index].data, NULL); + first_tree_iter_list[first_tree_index].node->timestamp = + std::max(first_tree_iter_list[first_tree_index].node->timestamp, + second_tree_iter_list[second_tree_index].node->timestamp); + first_tree_index++; + second_tree_index++; + nodeCount++; + } else if (first_tree_iter_list[first_tree_index].node->timestamp > + second_tree_iter_list[second_tree_index].node->timestamp) { + /** + * Choose first tree node. + * If the key is in the record tree, it means that there exist a same key in + * the second tree and has been chosen in the past. So we just need to remove + * the key from the insert_tokens and update the timestamp of the key in the + * first tree. + * If the key is not in the record tree, it means that the key has not been + * chosen in the past. So we need to insert the key into the record tree. 
+ */ + printf("choose first key %ld : %ld\n", + first_tree_iter_list[first_tree_index].node->timestamp, + second_tree_iter_list[second_tree_index].node->timestamp); + if (raxFind(tmp, first_tree_iter_list[first_tree_index].key, + first_tree_iter_list[first_tree_index].key_len) == raxNotFound) { + raxInsert(tmp, first_tree_iter_list[first_tree_index].key, + first_tree_iter_list[first_tree_index].key_len, + first_tree_iter_list[first_tree_index].data, NULL); + nodeCount++; + } else { + std::vector token = std::vector(first_tree_iter_list[first_tree_index].key, + first_tree_iter_list[first_tree_index].key + + first_tree_iter_list[first_tree_index].key_len); + insert_tokens.erase(token); + raxNode* node = raxFindAndReturnDataNode(second_tree, first_tree_iter_list[first_tree_index].key, + first_tree_iter_list[first_tree_index].key_len, NULL, false); + first_tree_iter_list[first_tree_index].node->timestamp = node->timestamp; + } + first_tree_index++; + } else if (first_tree_iter_list[first_tree_index].node->timestamp < + second_tree_iter_list[second_tree_index].node->timestamp) { + /** + * Choose second tree node. + * If the key is in the record tree, it means that there exist a same key in + * the first tree and has been chosen in the past. So we need do nothing. + * If the key is not in the record tree, it means that the key has not been + * chosen in the past. So we need to insert the key into the record tree. + * and insert the key into the insert_tokens. 
+ */ + printf("choose second key %ld : %ld\n", + first_tree_iter_list[first_tree_index].node->timestamp, + second_tree_iter_list[second_tree_index].node->timestamp); + // choose second key + if (raxFind(tmp, second_tree_iter_list[second_tree_index].key, + second_tree_iter_list[second_tree_index].key_len) == raxNotFound) { + std::vector insert_token( + second_tree_iter_list[second_tree_index].key, + second_tree_iter_list[second_tree_index].key + + second_tree_iter_list[second_tree_index].key_len); + insert_tokens.insert(insert_token); + + raxInsert(tmp, second_tree_iter_list[second_tree_index].key, + second_tree_iter_list[second_tree_index].key_len, + second_tree_iter_list[second_tree_index].data, NULL); + nodeCount++; + } + second_tree_index++; + } else { + /** + * If the key is not same and the timestamp is the same, we choose the key + * with the smaller children number. + */ + if (first_tree_iter_list[first_tree_index].node->numnodes <= + second_tree_iter_list[second_tree_index].node->numnodes) { + printf("choose first key %ld : %ld\n", + first_tree_iter_list[first_tree_index].node->timestamp, + second_tree_iter_list[second_tree_index].node->timestamp); + // choose first key + if (raxFind(tmp, first_tree_iter_list[first_tree_index].key, + first_tree_iter_list[first_tree_index].key_len) == raxNotFound) { + raxInsert(tmp, first_tree_iter_list[first_tree_index].key, + first_tree_iter_list[first_tree_index].key_len, + first_tree_iter_list[first_tree_index].data, NULL); + nodeCount++; + } else { + std::vector token = std::vector(first_tree_iter_list[first_tree_index].key, + first_tree_iter_list[first_tree_index].key + + first_tree_iter_list[first_tree_index].key_len); + insert_tokens.erase(token); + raxNode* node = raxFindAndReturnDataNode(second_tree, + first_tree_iter_list[first_tree_index].key, + first_tree_iter_list[first_tree_index].key_len, + NULL, + false); + first_tree_iter_list[first_tree_index].node->timestamp = node->timestamp; + } + first_tree_index++; + 
} else { + printf("choose second key %ld : %ld\n", + first_tree_iter_list[first_tree_index].node->timestamp, + second_tree_iter_list[second_tree_index].node->timestamp); + // choose second key + if (raxFind(tmp, second_tree_iter_list[second_tree_index].key, + second_tree_iter_list[second_tree_index].key_len) == raxNotFound) { + std::vector insert_token( + second_tree_iter_list[second_tree_index].key, + second_tree_iter_list[second_tree_index].key + + second_tree_iter_list[second_tree_index].key_len); + insert_tokens.insert(insert_token); + + raxInsert(tmp, second_tree_iter_list[second_tree_index].key, + second_tree_iter_list[second_tree_index].key_len, + second_tree_iter_list[second_tree_index].data, NULL); + nodeCount++; + } + second_tree_index++; + } + } + } + + if (nodeCount == max_node) { + printf("insert evicted tokens\n"); + int evicted_node_count = 0; + while (first_tree_index < first_tree_iter_list.size()) { + if (raxFind(tmp, first_tree_iter_list[first_tree_index].key, + first_tree_iter_list[first_tree_index].key_len) == + raxNotFound) { + std::vector evicted_token( + first_tree_iter_list[first_tree_index].key, + first_tree_iter_list[first_tree_index].key + + first_tree_iter_list[first_tree_index].key_len); + evicted_tokens.push_back(evicted_token); + } else { + std::vector token( + first_tree_iter_list[first_tree_index].key, + first_tree_iter_list[first_tree_index].key + + first_tree_iter_list[first_tree_index].key_len); + insert_tokens.erase(token); + raxNode* node = raxFindAndReturnDataNode(second_tree, + first_tree_iter_list[first_tree_index].key, + first_tree_iter_list[first_tree_index].key_len, + NULL, + false); + first_tree_iter_list[first_tree_index].node->timestamp = node->timestamp; + } + evicted_node_count++; + first_tree_index++; + } + printf("evicted_node_count: %d\n", evicted_node_count); + freeVector(first_tree_iter_list); + freeVector(second_tree_iter_list); + raxFree(tmp); + return; + } + + // first_ret and second_ret both are not 0 is the 
case that nodeCount == + // max_node + if (first_tree_index >= first_tree_iter_list.size() && + second_tree_index >= second_tree_iter_list.size()) { + // both tree are empty + freeVector(first_tree_iter_list); + freeVector(second_tree_iter_list); + raxFree(tmp); + return; + } else if (first_tree_index >= first_tree_iter_list.size()) { + // first tree is empty + while (second_tree_index < second_tree_iter_list.size() && + nodeCount < max_node) { + if (raxFind(tmp, second_tree_iter_list[second_tree_index].key, + second_tree_iter_list[second_tree_index].key_len) == raxNotFound) { + std::vector insert_token( + second_tree_iter_list[second_tree_index].key, + second_tree_iter_list[second_tree_index].key + + second_tree_iter_list[second_tree_index].key_len); + insert_tokens.insert(insert_token); + + raxInsert(tmp, second_tree_iter_list[second_tree_index].key, + second_tree_iter_list[second_tree_index].key_len, + second_tree_iter_list[second_tree_index].data, NULL); + + nodeCount++; + } + second_tree_index++; + } + } else if (second_tree_index >= second_tree_iter_list.size()) { + // second tree is empty + //raxShow(tmp); + printf("nodeCount:%d\n", nodeCount); + printf("first_tree_index:%ld\n", first_tree_index); + while (first_tree_index < first_tree_iter_list.size() && + nodeCount < max_node) { + if (raxFind(tmp, first_tree_iter_list[first_tree_index].key, + first_tree_iter_list[first_tree_index].key_len) == + raxNotFound) { + nodeCount++; + } else { + std::vector token( + first_tree_iter_list[first_tree_index].key, + first_tree_iter_list[first_tree_index].key + + first_tree_iter_list[first_tree_index].key_len); + insert_tokens.erase(token); + raxNode* node = raxFindAndReturnDataNode(second_tree, + first_tree_iter_list[first_tree_index].key, + first_tree_iter_list[first_tree_index].key_len, + NULL, + false); + first_tree_iter_list[first_tree_index].node->timestamp = node->timestamp; + } + first_tree_index++; + } + while (first_tree_index < first_tree_iter_list.size()) { + 
if (raxFind(tmp, first_tree_iter_list[first_tree_index].key, + first_tree_iter_list[first_tree_index].key_len) == + raxNotFound) { + std::vector evicted_token( + first_tree_iter_list[first_tree_index].key, + first_tree_iter_list[first_tree_index].key + + first_tree_iter_list[first_tree_index].key_len); + evicted_tokens.push_back(evicted_token); + } else { + std::vector token( + first_tree_iter_list[first_tree_index].key, + first_tree_iter_list[first_tree_index].key + + first_tree_iter_list[first_tree_index].key_len); + insert_tokens.erase(token); + raxNode* node = raxFindAndReturnDataNode(second_tree, + first_tree_iter_list[first_tree_index].key, + first_tree_iter_list[first_tree_index].key_len, + NULL, + false); + first_tree_iter_list[first_tree_index].node->timestamp = node->timestamp; + } + first_tree_index++; + } + } + freeVector(first_tree_iter_list); + freeVector(second_tree_iter_list); + raxFree(tmp); +} + +void testIteRax(rax *tree) { + raxIterator iter; + raxStart(&iter, tree); + raxSeek(&iter, "^", NULL, 0); + while (raxNext(&iter)) { + printf("key: "); + for (size_t i = 0; i < iter.key_len; i++) { + printf("%d ", iter.key[i]); + } + printf("\n"); + // printf("data: %p\n", iter.data); + } + raxStop(&iter); +} + +raxNode* raxGetFirstChildPtr(raxNode* node) { + return raxGetFirstChildPtr(node); +} + +// 1 2 3 +// query subtree node:0x55f87076f760 +// I0129 16:44:25.626318 280948 kv_state_cache.cc:223] offset:0 +// I0129 16:44:25.626322 280948 kv_state_cache.cc:224] kvStateCacheBlockBuilder:0x55f870767bc0 \ No newline at end of file diff --git a/thirdparty/rax/radix.h b/thirdparty/rax/radix.h new file mode 100644 index 00000000..3590fb77 --- /dev/null +++ b/thirdparty/rax/radix.h @@ -0,0 +1,267 @@ +/* Rax -- A radix tree implementation. + * + * Copyright (c) 2017-2018, Salvatore Sanfilippo + * All rights reserved. 
+ * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * * Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of Redis nor the names of its contributors may be used + * to endorse or promote products derived from this software without + * specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + */ + +#ifndef RADIX_H +#define RADIX_H + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +/* Representation of a radix tree as implemented in this file, that contains + * the token lists [1, 2, 3], [1, 2, 3, 4, 5, 6] and [1, 2, 3, 6, 7, 8] after + * the insertion of each token list. When the node represents a key inside + * the radix tree, we write it between [], otherwise it is written between (). 
+ * + * This is the vanilla representation: + * + * (1) [] + * \ + * (2) [1] + * \ + * (3) [1,3] + * \ + * [4 6] [1,2,3] + * / \ + * [1,2,3,4] (5) (7) [1,2,3,6] + * / \ + * [1,2,3,4,5] (6) (8) [1,2,3,6,7] + * / \ + *[1,2,3,4,5,6] [] [] [1,2,3,6,7,8] + * + * However, this implementation implements a very common optimization where + * successive nodes having a single child are "compressed" into the node + * itself as a list of tokens, each representing a next-level child, + * and only the link to the node representing the last token node is + * provided inside the representation. So the above representation is turned + * into: + * + * ([1,2,3]) [] + * | + * [4 6] [1,2,3] + * / \ + * [1,2,3,4] ([5,6]) ([7,8]) [1,2,3,6] + * / \ + * [1,2,3,4,5,6] [] [] [1,2,3,6,7,8] + * + * However this optimization makes the implementation a bit more complex. + * For instance if a token list [1,1,2] is added in the above radix tree, a + * "node splitting" operation is needed, since the [1,2,3] prefix is no longer + * composed of nodes having a single child one after the other. This is the + * above tree and the resulting node splitting after this event happens: + * + * + * (1) [] + * / \ + * [1] ([1,2) ([2,3]) [1] + * \ + * [4 6] [1,2,3] + * / \ + * [1,2,3,4] ([5,6]) ([7,8]) [1,2,3,6] + * / \ + * [1,2,3,4,5,6] [] [] [1,2,3,6,7,8] + * + * + * Similarly after deletion, if a new chain of nodes having a single child + * is created (the chain must also not include nodes that represent keys), + * it must be compressed back into a single node. + * + */ + +#define RAX_NODE_MAX_SIZE 1024 +typedef struct raxNode { + uint32_t iskey : 1; /* Does this node contain a key? */ + uint32_t isnull : 1; /* Associated value is NULL (don't store it). */ + uint32_t iscompr : 1; /* Node is compressed. */ + uint32_t issubtree : 1; /* Node is the root node of a sub tree */ + uint32_t size : 26; /* Number of children, or compressed string len. 
*/ + uint32_t numnodes; /* Number of the child nodes */ + uint32_t numele; /* Number of elements inside this node. */ + uint64_t timestamp; /* Timestamps of the node */ + uint32_t sub_tree_size; /* Number of nodes in the sub tree */ + void* custom_data; + /* Data layout is as follows: + * + * If node is not compressed we have 'size' bytes, one for each children + * token, and 'size' raxNode pointers, point to each child node. + * Note how the character is not stored in the children but in the + * edge of the parents: + * + * [header iscompr=0][1,2,3][1-ptr][2-ptr][3-ptr](value-ptr?) + * + * if node is compressed (iscompr bit is 1) the node has 1 children. + * In that case the 'size' bytes of the string stored immediately at + * the start of the data section, represent a sequence of successive + * nodes linked one after the other, for which only the last one in + * the sequence is actually represented as a node, and pointed to by + * the current compressed node. + * + * [header iscompr=1][1,2,3][3-ptr](value-ptr?) + * + * Both compressed and not compressed nodes can represent a key + * with associated data in the radix tree at any level (not just terminal + * nodes). + * + * If the node has an associated key (iskey=1) and is not NULL + * (isnull=0), then after the raxNode pointers pointing to the + * children, an additional value pointer is present (as you can see + * in the representation above as "value-ptr" field). + */ + int data[]; +} raxNode; + +typedef struct rax { + raxNode* head; + raxNode* headDataNode; + uint64_t numele; + uint64_t numnodes; +} rax; + +/* Stack data structure used by raxLowWalk() in order to, optionally, return + * a list of parent nodes to the caller. The nodes do not have a "parent" + * field for space concerns, so we use the auxiliary stack when needed. */ +#define RAX_STACK_STATIC_ITEMS 32 +typedef struct raxStack { + void** stack; /* Points to static_items or an heap allocated array. 
*/ + size_t items, maxitems; /* Number of items contained and total space. */ + /* Up to RAXSTACK_STACK_ITEMS items we avoid to allocate on the heap + * and use this static array of pointers instead. */ + void* static_items[RAX_STACK_STATIC_ITEMS]; + int oom; /* True if pushing into this stack failed for OOM at some point. */ +} raxStack; + +/* Optional callback used for iterators and be notified on each rax node, + * including nodes not representing keys. If the callback returns true + * the callback changed the node pointer in the iterator structure, and the + * iterator implementation will have to replace the pointer in the radix tree + * internals. This allows the callback to reallocate the node to perform + * very special operations, normally not needed by normal applications. + * + * This callback is used to perform very low level analysis of the radix tree + * structure, scanning each possible node (but the root node), or in order to + * reallocate the nodes to reduce the allocation fragmentation (this is the + * Redis application for this callback). + * + * This is currently only supported in forward iterations (raxNext) */ +typedef int (*raxNodeCallback)(raxNode** noderef); + +/* Radix tree iterator state is encapsulated into this data structure. */ +#define RAX_ITER_STATIC_LEN 128 +#define RAX_ITER_JUST_SEEKED \ + (1 << 0) /* Iterator was just seeked. Return current \ + element for the first iteration and \ + clear the flag. */ +#define RAX_ITER_EOF (1 << 1) /* End of iteration reached. */ +#define RAX_ITER_SAFE \ + (1 << 2) /* Safe iterator, allows operations while \ + iterating. But it is slower. */ +typedef struct raxIterator { + int flags; + rax* rt; /* Radix tree we are iterating. */ + int* key; /* The current string. */ + void* data; /* Data associated to this key. */ + size_t key_len; /* Current key length. */ + size_t key_max; /* Max key len the current key buffer can hold. 
*/ + int key_static_tokens[RAX_ITER_STATIC_LEN]; + bool add_to_subtree_list; /* Whether to add the current node to the subtree + list. */ + std::vector>* subtree_list; /* List of subtrees. */ + std::vector* subtree_data_list; /* List of subtrees' data. */ + raxNode* node; /* Current node. Only for unsafe iteration. */ + raxStack stack; /* Stack used for unsafe iteration. */ + raxNodeCallback node_cb; /* Optional node callback. Normally set to NULL. */ +} raxIterator; + +/* A special pointer returned for not found items. */ +extern void* raxNotFound; + +/* Exported API. */ +rax* raxNew(void); +int raxInsert(rax* rax, int* s, size_t len, void* data, void** old, + bool set_timestamp = true); +int raxTryInsert(rax* rax, int* s, size_t len, void* data, void** old); +int raxInsertAndReturnDataNode(rax* rax, int* s, size_t len, void* data, + void** node, void** old); +int raxRemove(rax* rax, int* s, size_t len, void** old, + bool set_timestamp = true); +void* raxFind(rax* rax, int* s, size_t len); +raxNode* raxFindAndReturnDataNode(rax* rax, int* s, size_t len, + raxNode** sub_tree_node = NULL, + bool set_timestamp = true); +void raxSetSubtree(raxNode* n); +void raxSetSubtreeAllocated(raxNode* node); +void raxSetSubtreeNotNull(raxNode* node); +int raxFindNodeWithParent(rax* rax, int* s, size_t len, void** node, + void** parent); +void raxFree(rax* rax); +void raxFreeWithCallback(rax* rax, void (*free_callback)(raxNode*)); +void raxStart(raxIterator* it, rax* rt); +int raxSeek(raxIterator* it, const char* op, int* ele, size_t len); +int raxNext(raxIterator* it); +int raxPrev(raxIterator* it); +int raxRandomWalk(raxIterator* it, size_t steps); +int raxCompare(raxIterator* iter, const char* op, int* key, size_t key_len); +void raxStop(raxIterator* it); +int raxEOF(raxIterator* it); +void raxShow(rax* rax); +uint64_t raxSize(rax* rax); +void raxSetCustomData(raxNode* n, void* data); +void* raxGetCustomData(raxNode* n); +unsigned long raxTouch(raxNode* n); +void 
raxSetDebugMsg(int onoff); +void raxTraverse(raxNode* rax, + std::vector>& dataNodeList); +void raxTraverseSubTree(raxNode* n, std::vector& dataNodeList); +raxNode* raxSplit(rax* rax, int* s, size_t len, std::vector& key); +void raxSerialize(rax* root, std::vector>& tokenList, + std::vector& dataList, + std::vector& timestampsList, + std::vector>* subtreeList, + std::vector* subtreeNodeList); + +/* Internal API. May be used by the node callback in order to access rax nodes + * in a low level way, so this function is exported as well. */ +void raxSetData(raxNode* n, void* data); +void* raxGetData(raxNode* n); +int raxFindNode(rax* rax, int* s, size_t len, void** node); +void raxFindLastRecentNode(raxNode* node, std::vector& key); +void mergeTree(rax* first_tree, rax* second_tree, + std::vector>& evicted_tokens, + std::set>& insert_tokens, int max_node); +void testIteRax(rax* tree); +raxNode* raxGetFirstChildPtr(raxNode* node); +// raxNode *raxSetSubTreeAndReturnDataNode(rax *rax, int *s, size_t len); +#endif diff --git a/thirdparty/rax/rax_malloc.h b/thirdparty/rax/rax_malloc.h new file mode 100644 index 00000000..fdd2430b --- /dev/null +++ b/thirdparty/rax/rax_malloc.h @@ -0,0 +1,44 @@ +/* Rax -- A radix tree implementation. + * + * Copyright (c) 2017, Salvatore Sanfilippo + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * * Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. 
+ * * Neither the name of Redis nor the names of its contributors may be used + * to endorse or promote products derived from this software without + * specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + */ + +/* Allocator selection. + * + * This file is used in order to change the Rax allocator at compile time. + * Just define the following defines to what you want to use. Also add + * the include of your alternate allocator if needed (not needed in order + * to use the default libc allocator). */ + +#ifndef MODULES_LLM_CACHE_RADIX_TREE_RAX_MALLOC_H_ +#define MODULES_LLM_CACHE_RADIX_TREE_RAX_MALLOC_H_ +#define rax_malloc malloc +#define rax_realloc realloc +#define rax_free free + +#endif // MODULES_LLM_CACHE_RADIX_TREE_RAX_MALLOC_H_