
Commit e9fadd6
Add batched update and batched query API.
Signed-off-by: vegetableysm <[email protected]>
vegetableysm committed Aug 7, 2024
1 parent 4de2cf1 commit e9fadd6
Showing 6 changed files with 284 additions and 26 deletions.
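The commit does not include a usage example, so here is a hedged sketch of how the two new batched entry points might be driven from caller code. The `storage` handle, the include path, the `vineyard` namespace, and the sizes are assumptions for illustration; the method signatures and the `LLMKV`, `Status`, and `RETURN_ON_ERROR` names come from the diff below. Note that `BatchedUpdate` expects the token list length to be a multiple of the storage's chunk size, and `BatchedQuery` expects the layer count and tensor size to match the storage configuration.

```cpp
// Minimal usage sketch of the new batched calls (not part of this commit).
// Assumptions: an already-constructed IStorage backend, the vineyard
// namespace, and this include path.
#include <cstdint>
#include <memory>
#include <utility>
#include <vector>

#include "llm-cache/storage/storage.h"

using vineyard::IStorage;
using vineyard::LLMKV;
using vineyard::Status;

Status BatchedRoundTrip(const std::shared_ptr<IStorage>& storage,
                        const std::vector<int>& tokens, int layers,
                        size_t tensorNBytes) {
  // One (K, V) pair per layer for every token; keep the backing buffers in a
  // flat vector that is reserved up front so their pointers stay valid.
  std::vector<std::vector<uint8_t>> buffers;
  buffers.reserve(tokens.size() * layers * 2);
  std::vector<std::vector<std::pair<LLMKV, LLMKV>>> kvCacheList(tokens.size());
  for (size_t t = 0; t < tokens.size(); ++t) {
    for (int l = 0; l < layers; ++l) {
      buffers.emplace_back(tensorNBytes);
      LLMKV k;
      k.data = buffers.back().data();
      k.length = tensorNBytes;
      buffers.emplace_back(tensorNBytes);
      LLMKV v;
      v.data = buffers.back().data();
      v.length = tensorNBytes;
      kvCacheList[t].emplace_back(k, v);
    }
  }

  // Write all chunks in one batch; `updated` reports how many tokens landed.
  size_t updated = 0;
  RETURN_ON_ERROR(storage->BatchedUpdate(tokens, kvCacheList, updated));

  // Read them back in one batch; `matched` reports how many tokens were found.
  size_t matched = 0;
  RETURN_ON_ERROR(storage->BatchedQuery(tokens, kvCacheList, matched));
  return Status::OK();
}
```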
74 changes: 56 additions & 18 deletions modules/llm-cache/ds/vineyard_file.cc
@@ -37,18 +37,16 @@ void VineyardFile::Construct(const ObjectMeta& meta) {
}
this->path_ = meta_.GetKeyValue("path");
this->access_time_ = meta_.GetKeyValue<uint64_t>("access_time");
ObjectID blob_id = meta_.GetMember("buffer")->id();
if (meta.GetClient()->IsIPC()) {
std::shared_ptr<Blob> blob;
Client* client = reinterpret_cast<Client*>(meta.GetClient());
client->GetBlob(blob_id, blob);
buffer_ = blob->Buffer();
} else {
meta.GetBuffer(blob_id, buffer_);
}
ObjectMeta blob_meta;
meta_.GetMemberMeta("buffer", blob_meta);
ObjectID blob_id = blob_meta.GetId();
meta.GetBuffer(blob_id, buffer_);
}

Status VineyardFile::Read(void* buffer, size_t size, size_t offset) {
if (buffer == nullptr) {
return Status::Invalid("Buffer is nullptr");
}
if (static_cast<int64_t>(offset + size) > buffer_->size()) {
return Status::Invalid("Read out of range");
}
@@ -69,7 +67,7 @@ Status VineyardFile::Make(std::shared_ptr<VineyardFile>& file,
return Status::IOError("File " + path + " does not exist.");
}
// RETURN_ON_ERROR(ipc_client.GetObject(file_id, object));
ipc_client.GetMetaData(file_id, meta);
ipc_client.GetMetaData(file_id, meta, true);
if (meta.GetInstanceId() == ipc_client.instance_id()) {
object = ipc_client.GetObject(file_id);
file = std::dynamic_pointer_cast<VineyardFile>(object);
@@ -89,7 +87,7 @@ Status VineyardFile::Make(std::shared_ptr<VineyardFile>& file,

std::map<InstanceID, json> cluster_info;
rpc_client.ClusterInfo(cluster_info);
if (object_meta.GetInstanceId() == rpc_client.instance_id()) {
if (object_meta.GetInstanceId() == rpc_client.remote_instance_id()) {
object = rpc_client.GetObject(file_id);
} else {
std::string rpc_endpoint =
@@ -121,15 +119,31 @@ Status VineyardFile::BatchedGetObjects(
std::unordered_map<ObjectID, std::shared_ptr<VineyardFile>>& id_to_files) {
std::map<InstanceID, json> cluster_info;
rpc_client.ClusterInfo(cluster_info);
for (const auto& instance_to_meta : instance_to_metas) {
for (auto& instance_to_meta : instance_to_metas) {
std::vector<std::shared_ptr<Object>> file_objects;
if (client.Connected() && instance_to_meta.first == client.instance_id()) {
std::vector<ObjectID> ids(instance_to_meta.second.size());
for (size_t i = 0; i < instance_to_meta.second.size(); ++i) {
ids[i] = instance_to_meta.second[i].GetId();
}
instance_to_meta.second.clear();
client.GetMetaData(ids, instance_to_meta.second, false);
file_objects = client.GetObjects(instance_to_meta.second);
} else {
if (rpc_client.instance_id() == instance_to_meta.first) {
if (rpc_client.remote_instance_id() == instance_to_meta.first) {
std::vector<ObjectID> ids(instance_to_meta.second.size());
for (size_t i = 0; i < instance_to_meta.second.size(); ++i) {
ids[i] = instance_to_meta.second[i].GetId();
}
instance_to_meta.second.clear();
rpc_client.GetMetaData(ids, instance_to_meta.second, false);
RETURN_ON_ERROR(rpc_client.BatchedGetObjects(instance_to_meta.second,
file_objects));
} else {
std::vector<ObjectID> ids(instance_to_meta.second.size());
for (size_t i = 0; i < instance_to_meta.second.size(); ++i) {
ids[i] = instance_to_meta.second[i].GetId();
}
std::string rpc_endpoint =
cluster_info[instance_to_meta.first]["rpc_endpoint"]
.get<std::string>();
@@ -139,6 +153,13 @@ Status VineyardFile::BatchedGetObjects(
RPCClient remote_rpc_client;
RETURN_ON_ERROR(
remote_rpc_client.Connect(rpc_endpoint, "", "", rdma_endpoint));

/*
* GetMetaData will not populate buffers that were not created by the
* calling rpc_client, so we need to fetch the metadata again through the
* remote client.
*/
instance_to_meta.second.clear();
remote_rpc_client.GetMetaData(ids, instance_to_meta.second, false);
RETURN_ON_ERROR(remote_rpc_client.BatchedGetObjects(
instance_to_meta.second, file_objects));
}
@@ -284,9 +305,17 @@ std::vector<std::shared_ptr<Object>> VineyardFileBuilder::BatchedSealAndPersist(
}

for (size_t i = 0; i < blob_metas.size(); i++) {
rpc_client.Persist(blob_metas[i].GetId());
std::shared_ptr<VineyardFile> vineyard_file =
std::make_shared<VineyardFile>();
if (ipc_client.Connected()) {
ipc_client.Persist(blob_metas[i].GetId());
// vineyard_file->meta_.SetBuffer(blob_metas[i].GetId(),
// builders[i]->writer_->Buffer());
} else {
rpc_client.Persist(blob_metas[i].GetId());
// vineyard_file->meta_.SetBuffer(blob_metas[i].GetId(),
// builders[i]->remote_writer_->Buffer());
}
vineyard_file->meta_.AddMember("buffer", blob_metas[i]);
vineyard_file->meta_.AddKeyValue("path", builders[i]->path_);
vineyard_file->meta_.SetTypeName(type_name<VineyardFile>());
@@ -296,10 +325,19 @@ std::vector<std::shared_ptr<Object>> VineyardFileBuilder::BatchedSealAndPersist(
"access_time",
std::chrono::duration_cast<std::chrono::nanoseconds>(access_time)
.count());
VINEYARD_CHECK_OK(
rpc_client.CreateMetaData(vineyard_file->meta_, vineyard_file->id_));
rpc_client.Persist(vineyard_file->id_);
Status status = rpc_client.PutName(vineyard_file->id_, builders[i]->path_);
if (ipc_client.Connected()) {
VINEYARD_CHECK_OK(
ipc_client.CreateMetaData(vineyard_file->meta_, vineyard_file->id_));
VINEYARD_CHECK_OK(ipc_client.Persist(vineyard_file->id_));
Status status =
ipc_client.PutName(vineyard_file->id_, builders[i]->path_);
} else {
VINEYARD_CHECK_OK(
rpc_client.CreateMetaData(vineyard_file->meta_, vineyard_file->id_));
VINEYARD_CHECK_OK(rpc_client.Persist(vineyard_file->id_));
Status status =
rpc_client.PutName(vineyard_file->id_, builders[i]->path_);
}
}

return vineyard_file_objects;
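The heart of the vineyard_file.cc change is the routing decision inside `BatchedGetObjects`: object metadata is grouped per instance and fetched over the cheapest available channel. The sketch below is a compact restatement of that decision, only mirroring the branches visible in the diff; it is not part of the commit, and the integral `InstanceID` alias is an assumption for the sketch.

```cpp
// Simplified restatement of the per-instance dispatch in BatchedGetObjects
// (illustration only, not part of the commit).
#include <cstdint>

using InstanceID = uint64_t;  // assumption for this sketch

enum class FetchRoute {
  kLocalIPC,   // objects live on this instance and an IPC client is connected
  kDirectRPC,  // objects live on the instance the rpc_client already talks to
  kRemoteRPC,  // objects live elsewhere: connect a fresh RPCClient using the
               // rpc/rdma endpoints from the cluster info, refetch the
               // metadata, then issue BatchedGetObjects against it
};

FetchRoute ChooseRoute(bool ipc_connected, InstanceID local_instance,
                       InstanceID rpc_remote_instance, InstanceID target) {
  if (ipc_connected && target == local_instance) {
    return FetchRoute::kLocalIPC;
  }
  if (target == rpc_remote_instance) {
    return FetchRoute::kDirectRPC;
  }
  return FetchRoute::kRemoteRPC;
}
```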
193 changes: 193 additions & 0 deletions modules/llm-cache/storage/file_storage.cc
@@ -415,6 +415,129 @@ Status FileStorage::Update(
return Status::NotImplemented();
}

Status FileStorage::BatchedUpdate(
const std::vector<int>& tokenList,
const std::vector<std::vector<std::pair<LLMKV, LLMKV>>>& kvCacheList,
size_t& updated) {
if (this->exitFlag) {
return Status::Invalid("The file storage has been closed!");
}
if (tokenList.size() % chunkSize != 0) {
return Status::Invalid("Tokens size " + std::to_string(tokenList.size()) +
" should be multiple of batch size " +
std::to_string(chunkSize) + "!");
}

std::vector<std::string> pathList;
std::set<std::string> createFileSet;
std::mutex createFileSetMutex;
RETURN_ON_ERROR(hasher->computePathForTokens(tokenList, chunkSize,
hashChunkSize, pathList));
if (pathList.size() == 0) {
return Status::OK();
}

std::vector<std::shared_ptr<FileDescriptor>> read_fd_list;
RETURN_ON_ERROR(BatchedOpen(pathList, read_fd_list, FileOperationType::READ));

auto read_fn = [this, &read_fd_list, &tokenList](int i) -> Status {
int tokenLength = (i + 1) * chunkSize;
RETURN_ON_ERROR(Read(read_fd_list[i], &tokenLength, sizeof(int)));
std::vector<int> tokens;
tokens.resize(tokenLength);
RETURN_ON_ERROR(
Read(read_fd_list[i], tokens.data(), tokenLength * sizeof(int)));
if (!CompareTokenList(tokenList, tokens, tokenLength)) {
// Token lists do not match
VINEYARD_DISCARD(Close(read_fd_list[i]));
return Status::ObjectExists("File exists for another token sequence");
}
// This chunk already exists on disk; skip rewriting it
VINEYARD_DISCARD(Close(read_fd_list[i]));
return Status::OK();
};

int lower_bound = 0;
if (read_fd_list.size() > 0) {
parallel::ThreadGroup tg(std::min(read_fd_list.size(), (size_t) 1));
std::vector<parallel::ThreadGroup::tid_t> tids(read_fd_list.size());
for (size_t i = 0; i < read_fd_list.size(); ++i) {
tids[i] = tg.AddTask(read_fn, i);
}
std::vector<Status> taskResults(read_fd_list.size(), Status::OK());
for (size_t i = 0; i < read_fd_list.size(); ++i) {
taskResults[i] = tg.TaskResult(tids[i]);
}

for (size_t i = 0; i < taskResults.size(); i++) {
if (taskResults[i].ok()) {
lower_bound += 1;
} else {
// File exists for another token sequence
break;
}
}
}

BatchedClose(read_fd_list);

std::vector<std::shared_ptr<FileDescriptor>> write_fd_list;
std::vector<std::string> left_path(pathList.begin() + lower_bound,
pathList.end());
RETURN_ON_ERROR(
BatchedOpen(left_path, write_fd_list, FileOperationType::WRITE));
auto fn = [this, &write_fd_list, &tokenList, &kvCacheList,
lower_bound](int i) -> Status {
int tokenLength = (i + 1 + lower_bound) * chunkSize;

RETURN_ON_ERROR(Write(write_fd_list[i], &tokenLength, sizeof(int)));
RETURN_ON_ERROR(
Write(write_fd_list[i], tokenList.data(), tokenLength * sizeof(int)));
for (int currentTokenIndex = (i + lower_bound) * chunkSize;
currentTokenIndex < (i + lower_bound + 1) * chunkSize;
currentTokenIndex++) {
for (int currentLayer = 0; currentLayer < layer; currentLayer++) {
const LLMKV& k = kvCacheList[currentTokenIndex][currentLayer].first;
const LLMKV& v = kvCacheList[currentTokenIndex][currentLayer].second;
RETURN_ON_ERROR(Write(write_fd_list[i], k.data, k.length));
RETURN_ON_ERROR(Write(write_fd_list[i], v.data, k.length));
}
}

VINEYARD_DISCARD(Flush(write_fd_list[i]));
return Status::OK();
};

if (write_fd_list.size() > 0) {
parallel::ThreadGroup tg_write(std::min(write_fd_list.size(), (size_t) 1));
std::vector<parallel::ThreadGroup::tid_t> tids_write(write_fd_list.size());
for (size_t i = 0; i < write_fd_list.size(); ++i) {
tids_write[i] = tg_write.AddTask(fn, i);
}
std::vector<Status> taskResults_write(write_fd_list.size(), Status::OK());
for (size_t i = 0; i < write_fd_list.size(); ++i) {
taskResults_write[i] = tg_write.TaskResult(tids_write[i]);
}

size_t upper_bound = 0;
for (size_t i = 0; i < write_fd_list.size(); i++) {
if (taskResults_write[i].ok()) {
upper_bound += 1;
} else {
break;
}
}

for (size_t i = upper_bound; i < write_fd_list.size(); i++) {
VINEYARD_SUPPRESS(Delete(this->rootPath + pathList[i + lower_bound]));
}
updated = upper_bound * chunkSize;

RETURN_ON_ERROR(BatchedClose(write_fd_list));
}
return Status::OK();
}

/**
* @brief Query the kv state with the given token list in the file storage.
*
Expand Down Expand Up @@ -654,6 +777,76 @@ Status FileStorage::Query(
return Status::OK();
}

Status FileStorage::BatchedQuery(
const std::vector<int>& tokenList,
std::vector<std::vector<std::pair<LLMKV, LLMKV>>>& kvCacheList,
size_t& matched) {
if (this->exitFlag) {
return Status::Invalid("The file storage has been closed!");
}

std::vector<std::string> paths;
RETURN_ON_ERROR(
hasher->computePathForTokens(tokenList, chunkSize, hashChunkSize, paths));

std::vector<std::shared_ptr<FileDescriptor>> read_fd_list;
RETURN_ON_ERROR(BatchedOpen(paths, read_fd_list, FileOperationType::READ));
auto read_fn = [this, &read_fd_list, &tokenList, &kvCacheList](
size_t i, size_t matched_start) -> Status {
int tokenLength = 0;
RETURN_ON_ERROR(Read(read_fd_list[i], &tokenLength, sizeof(int)));
std::vector<int> blockTokenList(tokenLength, 0);
RETURN_ON_ERROR(Read(read_fd_list[i], blockTokenList.data(),
tokenLength * sizeof(int)));

if (!CompareTokenList(tokenList, blockTokenList, tokenLength)) {
VINEYARD_DISCARD(Close(read_fd_list[i]));
return Status::ObjectNotExists("Token mismatch");
}

for (int j = 0; j < chunkSize; j++) {
if (matched_start + j >= tokenList.size() ||
matched_start + j >= kvCacheList.size()) {
break;
}
auto& kvState = kvCacheList[matched_start + j];
for (int currentLayer = 0; currentLayer < layer; currentLayer++) {
RETURN_ON_ASSERT(static_cast<int>(kvState.size()) == layer,
"The size of kvState is not equal to layer");
LLMKV& k = kvState[currentLayer].first;
LLMKV& v = kvState[currentLayer].second;
RETURN_ON_ASSERT(
k.length == tensorNBytes && v.length == tensorNBytes,
"The size of kv tensor doesn't match with the tensorNBytes");
RETURN_ON_ERROR(Read(read_fd_list[i], k.data, k.length));
RETURN_ON_ERROR(Read(read_fd_list[i], v.data, v.length));
}
}
VINEYARD_DISCARD(Close(read_fd_list[i]));
return Status::OK();
};

parallel::ThreadGroup tg(std::min(read_fd_list.size(), (size_t) 1));
std::vector<parallel::ThreadGroup::tid_t> tids(read_fd_list.size());
for (size_t i = 0; i < read_fd_list.size(); ++i) {
tids[i] = tg.AddTask(read_fn, i, i * chunkSize);
}
std::vector<Status> taskResults(read_fd_list.size(), Status::OK());
for (size_t i = 0; i < read_fd_list.size(); ++i) {
taskResults[i] = tg.TaskResult(tids[i]);
}

matched = 0;
for (size_t i = 0; i < read_fd_list.size(); i++) {
if (taskResults[i].ok()) {
matched += chunkSize;
} else {
break;
}
}
return Status::OK();
}

bool FileStorage::CompareTokenList(const std::vector<int>& tokenList1,
const std::vector<int>& tokenList2,
size_t length) {
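`BatchedUpdate` and `BatchedQuery` agree on a per-chunk file layout: an `int` token count, the token prefix itself, and then the K and V tensors for the chunk's tokens, layer by layer. The standalone sketch below walks that layout with plain iostreams; the file name, chunk size, layer count, and tensor size are assumptions used only for illustration.

```cpp
// Offline reader sketch for one chunk file as written by BatchedUpdate.
// Layout: [int tokenLength][tokenLength ints: the token prefix]
//         [for each of chunkSize tokens, for each layer: K bytes, V bytes].
#include <cstddef>
#include <fstream>
#include <iostream>
#include <vector>

int main() {
  const int chunkSize = 16;          // assumption: storage chunk size
  const int layer = 2;               // assumption: number of layers
  const size_t tensorNBytes = 4096;  // assumption: per-tensor byte size

  std::ifstream in("chunk.bin", std::ios::binary);  // hypothetical file name
  if (!in) {
    std::cerr << "cannot open chunk file" << std::endl;
    return 1;
  }

  int tokenLength = 0;
  in.read(reinterpret_cast<char*>(&tokenLength), sizeof(int));
  std::vector<int> prefix(tokenLength);
  in.read(reinterpret_cast<char*>(prefix.data()), tokenLength * sizeof(int));

  // The remainder holds K then V for every (token, layer) pair in this chunk.
  std::vector<char> k(tensorNBytes), v(tensorNBytes);
  for (int t = 0; t < chunkSize && in; ++t) {
    for (int l = 0; l < layer && in; ++l) {
      in.read(k.data(), k.size());
      in.read(v.data(), v.size());
    }
  }
  std::cout << "token prefix length: " << tokenLength << std::endl;
  return 0;
}
```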
9 changes: 9 additions & 0 deletions modules/llm-cache/storage/file_storage.h
@@ -145,6 +145,11 @@ class FileStorage : public IStorage,
const std::vector<std::vector<std::pair<LLMKV, LLMKV>>>& kvCacheList,
size_t& updated) override;

Status BatchedUpdate(
const std::vector<int>& tokenList,
const std::vector<std::vector<std::pair<LLMKV, LLMKV>>>& kvCacheList,
size_t& updated) override;

Status Query(const std::vector<int>& tokenList,
std::vector<std::vector<std::pair<LLMKV, LLMKV>>>& kvCacheList,
size_t& matched) override;
@@ -156,6 +161,10 @@
const std::vector<int>& tokenList,
std::vector<std::vector<std::pair<LLMKV, LLMKV>>>& kvCacheList,
size_t& matched) override;
Status BatchedQuery(
const std::vector<int>& tokenList,
std::vector<std::vector<std::pair<LLMKV, LLMKV>>>& kvCacheList,
size_t& matched) override;

void CloseCache() override;

19 changes: 19 additions & 0 deletions modules/llm-cache/storage/storage.h
@@ -43,6 +43,18 @@ class IStorage {
const std::vector<std::vector<std::pair<LLMKV, LLMKV>>>& kvCacheList,
size_t& updated) = 0;

/*
* BatchedUpdate updates multiple kv cache entries in one batch. It opens
* and closes all of the backing files in batches to reduce the overhead of
* network IO.
*/
virtual Status BatchedUpdate(
const std::vector<int>& tokenList,
const std::vector<std::vector<std::pair<LLMKV, LLMKV>>>& kvCacheList,
size_t& updated) {
return Status::NotImplemented();
}

virtual Status Query(
const std::vector<int>& tokenList,
std::vector<std::vector<std::pair<LLMKV, LLMKV>>>& kvCacheList,
@@ -56,6 +68,13 @@
std::vector<std::vector<std::pair<LLMKV, LLMKV>>>& kvCacheList,
size_t& matched) = 0;

virtual Status BatchedQuery(
const std::vector<int>& tokenList,
std::vector<std::vector<std::pair<LLMKV, LLMKV>>>& kvCacheList,
size_t& matched) {
return Status::NotImplemented();
}

virtual void CloseCache() = 0;

virtual void StartGlobalGCThread() {}
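Both batched methods get `Status::NotImplemented()` defaults in `IStorage`, so existing backends keep compiling and callers can opportunistically prefer the batched path. A hedged sketch of that calling pattern follows; it assumes the `vineyard` namespace, the include path, and a `Status::IsNotImplemented()` predicate (if that helper does not exist, compare the status code instead). Only the `IStorage` method names come from the diff.

```cpp
// Prefer the batched query and fall back to the per-call Query when the
// backend still has the NotImplemented default (illustration only).
#include <memory>
#include <utility>
#include <vector>

#include "llm-cache/storage/storage.h"

using vineyard::IStorage;
using vineyard::LLMKV;
using vineyard::Status;

Status QueryPreferBatched(
    const std::shared_ptr<IStorage>& storage,
    const std::vector<int>& tokenList,
    std::vector<std::vector<std::pair<LLMKV, LLMKV>>>& kvCacheList,
    size_t& matched) {
  Status s = storage->BatchedQuery(tokenList, kvCacheList, matched);
  if (!s.IsNotImplemented()) {
    return s;  // the batched path handled it (success or a real error)
  }
  // The backend has not overridden BatchedQuery; use the existing API.
  return storage->Query(tokenList, kvCacheList, matched);
}
```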