Skip to content

Commit

Permalink
KV cache module optimization and bug fix for rdma. (#1984)
Browse files Browse the repository at this point in the history
Signed-off-by: vegetableysm <[email protected]>
  • Loading branch information
vegetableysm authored Aug 15, 2024
1 parent bbdb41e commit db64188
Show file tree
Hide file tree
Showing 9 changed files with 64 additions and 9 deletions.
4 changes: 4 additions & 0 deletions modules/llm-cache/ds/vineyard_file.cc
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,7 @@ std::shared_ptr<Object> VineyardFileBuilder::SealAndPersist(
ObjectMeta blob_meta;
if (ipc_client.Connected()) {
std::shared_ptr<Object> object;
writer_->Shrink(ipc_client, writer_->size());
writer_->Seal(ipc_client, object);
blob_meta = object->meta();
ipc_client.Persist(blob_meta.GetId());
Expand All @@ -282,6 +283,7 @@ std::shared_ptr<Object> VineyardFileBuilder::SealAndPersist(
}
vineyardFile->meta_.AddMember("buffer", blob_meta);
vineyardFile->meta_.AddKeyValue("path", path_);
vineyardFile->meta_.AddKeyValue("size", Size());
vineyardFile->meta_.SetTypeName(type_name<VineyardFile>());

auto access_time = std::chrono::system_clock::now().time_since_epoch();
Expand Down Expand Up @@ -312,6 +314,7 @@ std::vector<std::shared_ptr<Object>> VineyardFileBuilder::BatchedSealAndPersist(
if (ipc_client.Connected()) {
for (auto builder : builders) {
std::shared_ptr<Object> object;
builder->writer_->Shrink(ipc_client, builder->writer_->size());
builder->writer_->Seal(ipc_client, object);
blob_metas.push_back(object->meta());
}
Expand All @@ -334,6 +337,7 @@ std::vector<std::shared_ptr<Object>> VineyardFileBuilder::BatchedSealAndPersist(
}
vineyard_file->meta_.AddMember("buffer", blob_metas[i]);
vineyard_file->meta_.AddKeyValue("path", builders[i]->path_);
vineyard_file->meta_.AddKeyValue("size", builders[i]->Size());
vineyard_file->meta_.SetTypeName(type_name<VineyardFile>());

auto access_time = std::chrono::system_clock::now().time_since_epoch();
Expand Down
18 changes: 18 additions & 0 deletions modules/llm-cache/storage/file_storage.cc
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,12 @@ Status FileStorage::Update(
" should be multiple of batch size " +
std::to_string(chunkSize) + "!");
}
if (tokenList.size() > MAX_CACHE_TOKEN_LENGTH) {
LOG(WARNING)
<< "The token list size is larger than the maximum cache token "
"length. This token list will be ignored!";
return Status::OK();
}

std::vector<std::string> pathList;
std::set<std::string> createFileSet;
Expand Down Expand Up @@ -281,6 +287,12 @@ Status FileStorage::Update(
" should be multiple of batch size " +
std::to_string(chunkSize) + "!");
}
if (tokenList.size() > MAX_CACHE_TOKEN_LENGTH) {
LOG(WARNING)
<< "The token list size is larger than the maximum cache token "
"length. This token list will be ignored!";
return Status::OK();
}

std::vector<std::string> pathList;
std::set<std::string> createFileSet;
Expand Down Expand Up @@ -427,6 +439,12 @@ Status FileStorage::BatchedUpdate(
" should be multiple of batch size " +
std::to_string(chunkSize) + "!");
}
if (tokenList.size() > MAX_CACHE_TOKEN_LENGTH) {
LOG(WARNING)
<< "The token list size is larger than the maximum cache token "
"length. This token list will be ignored!";
return Status::OK();
}

std::vector<std::string> pathList;
std::set<std::string> createFileSet;
Expand Down
2 changes: 2 additions & 0 deletions modules/llm-cache/storage/file_storage.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ limitations under the License.
#define SECOND_TO_MICROSECOND 1000000
#define SECOND_TO_NANOSECOND 1000000000

#define MAX_CACHE_TOKEN_LENGTH 65536

namespace vineyard {

struct FileDescriptor {};
Expand Down
4 changes: 2 additions & 2 deletions modules/llm-cache/storage/vineyard_file_storage.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ class VineyardFileStorage : public FileStorage {
this->globalGCInterval = std::chrono::seconds(globalGCInterval);
this->globalFileTTL = std::chrono::seconds(globalTTL);
this->enableGlobalGC = enableGlobalGC;
this->max_file_size_ =
tensorNBytes * 2 * layer * chunkSize + 65536 * sizeof(int);
this->max_file_size_ = tensorNBytes * 2 * layer * chunkSize +
MAX_CACHE_TOKEN_LENGTH * sizeof(int);
}

~VineyardFileStorage() = default;
Expand Down
2 changes: 1 addition & 1 deletion src/client/rpc_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ Status RPCClient::RDMAReleaseMemInfo(RegisterMemInfo& remote_info) {
void* buffer;
RETURN_ON_ERROR(this->rdma_client_->GetTXFreeMsgBuffer(buffer));
VineyardMsg* msg = reinterpret_cast<VineyardMsg*>(buffer);
msg->type = VINEYARD_RELEASE_MEM;
msg->type = VINEYARD_MSG_RELEASE_MEM;
msg->remoteMemInfo.remote_address = (uint64_t) remote_info.address;
msg->remoteMemInfo.len = remote_info.size;
msg->remoteMemInfo.mr_desc = remote_info.mr_desc;
Expand Down
3 changes: 2 additions & 1 deletion src/common/rdma/util.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,11 @@ namespace vineyard {
#endif // defined(__linux__)

enum VINEYARD_MSG_OPT {
VINEYARD_MSG_EMPTY = 1,
VINEYARD_MSG_CONNECT,
VINEYARD_MSG_EXCHANGE_KEY,
VINEYARD_MSG_REQUEST_MEM,
VINEYARD_RELEASE_MEM,
VINEYARD_MSG_RELEASE_MEM,
VINEYARD_MSG_CLOSE,
};

Expand Down
36 changes: 32 additions & 4 deletions src/server/async/rpc_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -245,11 +245,11 @@ void RPCServer::doVineyardRequestMemory(VineyardRecvContext* recv_context,
send_msg->remoteMemInfo.len = 0;

VineyardSendContext* send_context = new VineyardSendContext();
memset(&send_context->attr, 0, sizeof(send_context->attr));
memset(send_context, 0, sizeof(VineyardSendContext));
send_context->attr.msg_buffer = msg;
rdma_server_->Send(recv_context->rdma_conn_id,
recv_context->attr.msg_buffer, sizeof(VineyardMsg),
recv_context);
send_context);
return;
}

Expand All @@ -272,7 +272,7 @@ void RPCServer::doVineyardRequestMemory(VineyardRecvContext* recv_context,
send_msg->remoteMemInfo.mr_desc = remote_request_mem_info.mr_desc;

VineyardSendContext* send_context = new VineyardSendContext();
memset(&send_context->attr, 0, sizeof(send_context->attr));
memset(send_context, 0, sizeof(VineyardSendContext));
send_context->attr.msg_buffer = msg;

std::lock_guard<std::recursive_mutex> scope_lock(this->rdma_mutex_);
Expand Down Expand Up @@ -358,6 +358,24 @@ void RPCServer::doPrepareRecv(uint64_t rdma_conn_id) {
rdma_server_->Recv(rdma_conn_id, msg, sizeof(VineyardMsg), context);
}

void RPCServer::doNothing(VineyardRecvContext* recv_context) {
void* msg = nullptr;
rdma_server_->GetTXFreeMsgBuffer(msg);
VineyardMsg* send_msg = reinterpret_cast<VineyardMsg*>(msg);
send_msg->type = VINEYARD_MSG_REQUEST_MEM;

send_msg->remoteMemInfo.remote_address = 0;
send_msg->remoteMemInfo.key = 0;
send_msg->remoteMemInfo.len = 0;
send_msg->remoteMemInfo.mr_desc = 0;

VineyardSendContext* send_context = new VineyardSendContext();
memset(send_context, 0, sizeof(VineyardSendContext));
send_context->attr.msg_buffer = msg;
rdma_server_->Send(recv_context->rdma_conn_id, msg, sizeof(VineyardMsg),
send_context);
}

void RPCServer::doRDMARecv() {
while (1) {
void* context = nullptr;
Expand Down Expand Up @@ -415,7 +433,7 @@ void RPCServer::doRDMARecv() {
rdma_server_->Recv(
recv_context->rdma_conn_id, reinterpret_cast<void*>(recv_msg),
sizeof(VineyardMsg), reinterpret_cast<void*>(recv_context));
} else if (recv_msg->type == VINEYARD_RELEASE_MEM) {
} else if (recv_msg->type == VINEYARD_MSG_RELEASE_MEM) {
boost::asio::post(
vs_ptr_->GetIOContext(), [this, recv_context_tmp, recv_msg_tmp] {
doVineyardReleaseMemory(recv_context_tmp, recv_msg_tmp);
Expand All @@ -425,6 +443,16 @@ void RPCServer::doRDMARecv() {
rdma_server_->Recv(
recv_context->rdma_conn_id, reinterpret_cast<void*>(recv_msg),
sizeof(VineyardMsg), reinterpret_cast<void*>(recv_context));
} else if (recv_msg->type == VINEYARD_MSG_EMPTY) {
boost::asio::post(vs_ptr_->GetIOContext(),
[this, recv_context_tmp, recv_msg_tmp] {
doNothing(recv_context_tmp);
delete recv_msg_tmp;
delete recv_context_tmp;
});
rdma_server_->Recv(
recv_context->rdma_conn_id, reinterpret_cast<void*>(recv_msg),
sizeof(VineyardMsg), reinterpret_cast<void*>(recv_context));
} else {
LOG(ERROR) << "Unknown message type: " << recv_msg->type;
rdma_server_->Recv(
Expand Down
2 changes: 2 additions & 0 deletions src/server/async/rpc_server.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ class RPCServer : public SocketServer,

void doPrepareRecv(uint64_t rdma_conn_id);

void doNothing(VineyardRecvContext* recv_context);

const json rpc_spec_;
asio::ip::tcp::acceptor acceptor_;
asio::ip::tcp::socket socket_;
Expand Down
2 changes: 1 addition & 1 deletion src/server/util/remote.cc
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ Status RemoteClient::RDMAReleaseMemInfo(RegisterMemInfo& remote_info) {
void* buffer;
this->rdma_client_->GetTXFreeMsgBuffer(buffer);
VineyardMsg* msg = reinterpret_cast<VineyardMsg*>(buffer);
msg->type = VINEYARD_RELEASE_MEM;
msg->type = VINEYARD_MSG_RELEASE_MEM;
msg->remoteMemInfo.remote_address = (uint64_t) remote_info.address;
msg->remoteMemInfo.len = remote_info.size;
VLOG(100) << "Send remote addr: "
Expand Down

0 comments on commit db64188

Please sign in to comment.