From db64188c662299da7278422f10be4a73ae1e9f71 Mon Sep 17 00:00:00 2001 From: vegetableysm <108774481+vegetableysm@users.noreply.github.com> Date: Thu, 15 Aug 2024 10:58:03 +0800 Subject: [PATCH] KV cache module optimization and bug fix for rdma. (#1984) Signed-off-by: vegetableysm --- modules/llm-cache/ds/vineyard_file.cc | 4 +++ modules/llm-cache/storage/file_storage.cc | 18 ++++++++++ modules/llm-cache/storage/file_storage.h | 2 ++ .../llm-cache/storage/vineyard_file_storage.h | 4 +-- src/client/rpc_client.cc | 2 +- src/common/rdma/util.h | 3 +- src/server/async/rpc_server.cc | 36 ++++++++++++++++--- src/server/async/rpc_server.h | 2 ++ src/server/util/remote.cc | 2 +- 9 files changed, 64 insertions(+), 9 deletions(-) diff --git a/modules/llm-cache/ds/vineyard_file.cc b/modules/llm-cache/ds/vineyard_file.cc index c2b4a7ab5..47de9f9b7 100644 --- a/modules/llm-cache/ds/vineyard_file.cc +++ b/modules/llm-cache/ds/vineyard_file.cc @@ -273,6 +273,7 @@ std::shared_ptr VineyardFileBuilder::SealAndPersist( ObjectMeta blob_meta; if (ipc_client.Connected()) { std::shared_ptr object; + writer_->Shrink(ipc_client, writer_->size()); writer_->Seal(ipc_client, object); blob_meta = object->meta(); ipc_client.Persist(blob_meta.GetId()); @@ -282,6 +283,7 @@ std::shared_ptr VineyardFileBuilder::SealAndPersist( } vineyardFile->meta_.AddMember("buffer", blob_meta); vineyardFile->meta_.AddKeyValue("path", path_); + vineyardFile->meta_.AddKeyValue("size", Size()); vineyardFile->meta_.SetTypeName(type_name()); auto access_time = std::chrono::system_clock::now().time_since_epoch(); @@ -312,6 +314,7 @@ std::vector> VineyardFileBuilder::BatchedSealAndPersist( if (ipc_client.Connected()) { for (auto builder : builders) { std::shared_ptr object; + builder->writer_->Shrink(ipc_client, builder->writer_->size()); builder->writer_->Seal(ipc_client, object); blob_metas.push_back(object->meta()); } @@ -334,6 +337,7 @@ std::vector> VineyardFileBuilder::BatchedSealAndPersist( } vineyard_file->meta_.AddMember("buffer", blob_metas[i]); vineyard_file->meta_.AddKeyValue("path", builders[i]->path_); + vineyard_file->meta_.AddKeyValue("size", builders[i]->Size()); vineyard_file->meta_.SetTypeName(type_name()); auto access_time = std::chrono::system_clock::now().time_since_epoch(); diff --git a/modules/llm-cache/storage/file_storage.cc b/modules/llm-cache/storage/file_storage.cc index 2ce9b570a..28168705f 100644 --- a/modules/llm-cache/storage/file_storage.cc +++ b/modules/llm-cache/storage/file_storage.cc @@ -95,6 +95,12 @@ Status FileStorage::Update( " should be multiple of batch size " + std::to_string(chunkSize) + "!"); } + if (tokenList.size() > MAX_CACHE_TOKEN_LENGTH) { + LOG(WARNING) + << "The token list size is larger than the maximum cache token " + "length. This token list will be ignored!"; + return Status::OK(); + } std::vector pathList; std::set createFileSet; @@ -281,6 +287,12 @@ Status FileStorage::Update( " should be multiple of batch size " + std::to_string(chunkSize) + "!"); } + if (tokenList.size() > MAX_CACHE_TOKEN_LENGTH) { + LOG(WARNING) + << "The token list size is larger than the maximum cache token " + "length. This token list will be ignored!"; + return Status::OK(); + } std::vector pathList; std::set createFileSet; @@ -427,6 +439,12 @@ Status FileStorage::BatchedUpdate( " should be multiple of batch size " + std::to_string(chunkSize) + "!"); } + if (tokenList.size() > MAX_CACHE_TOKEN_LENGTH) { + LOG(WARNING) + << "The token list size is larger than the maximum cache token " + "length. This token list will be ignored!"; + return Status::OK(); + } std::vector pathList; std::set createFileSet; diff --git a/modules/llm-cache/storage/file_storage.h b/modules/llm-cache/storage/file_storage.h index dd74f42ef..1a8cc89cf 100644 --- a/modules/llm-cache/storage/file_storage.h +++ b/modules/llm-cache/storage/file_storage.h @@ -34,6 +34,8 @@ limitations under the License. #define SECOND_TO_MICROSECOND 1000000 #define SECOND_TO_NANOSECOND 1000000000 +#define MAX_CACHE_TOKEN_LENGTH 65536 + namespace vineyard { struct FileDescriptor {}; diff --git a/modules/llm-cache/storage/vineyard_file_storage.h b/modules/llm-cache/storage/vineyard_file_storage.h index 5911c5c6d..b967cbc8a 100644 --- a/modules/llm-cache/storage/vineyard_file_storage.h +++ b/modules/llm-cache/storage/vineyard_file_storage.h @@ -62,8 +62,8 @@ class VineyardFileStorage : public FileStorage { this->globalGCInterval = std::chrono::seconds(globalGCInterval); this->globalFileTTL = std::chrono::seconds(globalTTL); this->enableGlobalGC = enableGlobalGC; - this->max_file_size_ = - tensorNBytes * 2 * layer * chunkSize + 65536 * sizeof(int); + this->max_file_size_ = tensorNBytes * 2 * layer * chunkSize + + MAX_CACHE_TOKEN_LENGTH * sizeof(int); } ~VineyardFileStorage() = default; diff --git a/src/client/rpc_client.cc b/src/client/rpc_client.cc index 8ca751cc7..0ed089a8a 100644 --- a/src/client/rpc_client.cc +++ b/src/client/rpc_client.cc @@ -258,7 +258,7 @@ Status RPCClient::RDMAReleaseMemInfo(RegisterMemInfo& remote_info) { void* buffer; RETURN_ON_ERROR(this->rdma_client_->GetTXFreeMsgBuffer(buffer)); VineyardMsg* msg = reinterpret_cast(buffer); - msg->type = VINEYARD_RELEASE_MEM; + msg->type = VINEYARD_MSG_RELEASE_MEM; msg->remoteMemInfo.remote_address = (uint64_t) remote_info.address; msg->remoteMemInfo.len = remote_info.size; msg->remoteMemInfo.mr_desc = remote_info.mr_desc; diff --git a/src/common/rdma/util.h b/src/common/rdma/util.h index 6a3ef20ba..f6247a3ac 100644 --- a/src/common/rdma/util.h +++ b/src/common/rdma/util.h @@ -51,10 +51,11 @@ namespace vineyard { #endif // defined(__linux__) enum VINEYARD_MSG_OPT { + VINEYARD_MSG_EMPTY = 1, VINEYARD_MSG_CONNECT, VINEYARD_MSG_EXCHANGE_KEY, VINEYARD_MSG_REQUEST_MEM, - VINEYARD_RELEASE_MEM, + VINEYARD_MSG_RELEASE_MEM, VINEYARD_MSG_CLOSE, }; diff --git a/src/server/async/rpc_server.cc b/src/server/async/rpc_server.cc index 90fe1fbe3..9b4af00d2 100644 --- a/src/server/async/rpc_server.cc +++ b/src/server/async/rpc_server.cc @@ -245,11 +245,11 @@ void RPCServer::doVineyardRequestMemory(VineyardRecvContext* recv_context, send_msg->remoteMemInfo.len = 0; VineyardSendContext* send_context = new VineyardSendContext(); - memset(&send_context->attr, 0, sizeof(send_context->attr)); + memset(send_context, 0, sizeof(VineyardSendContext)); send_context->attr.msg_buffer = msg; rdma_server_->Send(recv_context->rdma_conn_id, recv_context->attr.msg_buffer, sizeof(VineyardMsg), - recv_context); + send_context); return; } @@ -272,7 +272,7 @@ void RPCServer::doVineyardRequestMemory(VineyardRecvContext* recv_context, send_msg->remoteMemInfo.mr_desc = remote_request_mem_info.mr_desc; VineyardSendContext* send_context = new VineyardSendContext(); - memset(&send_context->attr, 0, sizeof(send_context->attr)); + memset(send_context, 0, sizeof(VineyardSendContext)); send_context->attr.msg_buffer = msg; std::lock_guard scope_lock(this->rdma_mutex_); @@ -358,6 +358,24 @@ void RPCServer::doPrepareRecv(uint64_t rdma_conn_id) { rdma_server_->Recv(rdma_conn_id, msg, sizeof(VineyardMsg), context); } +void RPCServer::doNothing(VineyardRecvContext* recv_context) { + void* msg = nullptr; + rdma_server_->GetTXFreeMsgBuffer(msg); + VineyardMsg* send_msg = reinterpret_cast(msg); + send_msg->type = VINEYARD_MSG_REQUEST_MEM; + + send_msg->remoteMemInfo.remote_address = 0; + send_msg->remoteMemInfo.key = 0; + send_msg->remoteMemInfo.len = 0; + send_msg->remoteMemInfo.mr_desc = 0; + + VineyardSendContext* send_context = new VineyardSendContext(); + memset(send_context, 0, sizeof(VineyardSendContext)); + send_context->attr.msg_buffer = msg; + rdma_server_->Send(recv_context->rdma_conn_id, msg, sizeof(VineyardMsg), + send_context); +} + void RPCServer::doRDMARecv() { while (1) { void* context = nullptr; @@ -415,7 +433,7 @@ void RPCServer::doRDMARecv() { rdma_server_->Recv( recv_context->rdma_conn_id, reinterpret_cast(recv_msg), sizeof(VineyardMsg), reinterpret_cast(recv_context)); - } else if (recv_msg->type == VINEYARD_RELEASE_MEM) { + } else if (recv_msg->type == VINEYARD_MSG_RELEASE_MEM) { boost::asio::post( vs_ptr_->GetIOContext(), [this, recv_context_tmp, recv_msg_tmp] { doVineyardReleaseMemory(recv_context_tmp, recv_msg_tmp); @@ -425,6 +443,16 @@ void RPCServer::doRDMARecv() { rdma_server_->Recv( recv_context->rdma_conn_id, reinterpret_cast(recv_msg), sizeof(VineyardMsg), reinterpret_cast(recv_context)); + } else if (recv_msg->type == VINEYARD_MSG_EMPTY) { + boost::asio::post(vs_ptr_->GetIOContext(), + [this, recv_context_tmp, recv_msg_tmp] { + doNothing(recv_context_tmp); + delete recv_msg_tmp; + delete recv_context_tmp; + }); + rdma_server_->Recv( + recv_context->rdma_conn_id, reinterpret_cast(recv_msg), + sizeof(VineyardMsg), reinterpret_cast(recv_context)); } else { LOG(ERROR) << "Unknown message type: " << recv_msg->type; rdma_server_->Recv( diff --git a/src/server/async/rpc_server.h b/src/server/async/rpc_server.h index 1254acf93..567fbb56d 100644 --- a/src/server/async/rpc_server.h +++ b/src/server/async/rpc_server.h @@ -85,6 +85,8 @@ class RPCServer : public SocketServer, void doPrepareRecv(uint64_t rdma_conn_id); + void doNothing(VineyardRecvContext* recv_context); + const json rpc_spec_; asio::ip::tcp::acceptor acceptor_; asio::ip::tcp::socket socket_; diff --git a/src/server/util/remote.cc b/src/server/util/remote.cc index 295df3c8b..ee69f0059 100644 --- a/src/server/util/remote.cc +++ b/src/server/util/remote.cc @@ -144,7 +144,7 @@ Status RemoteClient::RDMAReleaseMemInfo(RegisterMemInfo& remote_info) { void* buffer; this->rdma_client_->GetTXFreeMsgBuffer(buffer); VineyardMsg* msg = reinterpret_cast(buffer); - msg->type = VINEYARD_RELEASE_MEM; + msg->type = VINEYARD_MSG_RELEASE_MEM; msg->remoteMemInfo.remote_address = (uint64_t) remote_info.address; msg->remoteMemInfo.len = remote_info.size; VLOG(100) << "Send remote addr: "