Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve the query API of llm cache and use vector<uint8_t> as payload object. #1797

Merged
merged 1 commit into from
Mar 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions modules/basic/ds/dataframe.cc
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,12 @@ const std::shared_ptr<arrow::RecordBatch> DataFrame::AsBatch(bool copy) const {
} else if (auto tensor =
std::dynamic_pointer_cast<Tensor<std::string>>(df_col)) {
num_rows = tensor->shape()[0];
} else if (auto tensor =
std::dynamic_pointer_cast<Tensor<uint8_t>>(df_col)) {
num_rows = tensor->shape()[0];
dashanji marked this conversation as resolved.
Show resolved Hide resolved
} else if (auto tensor =
std::dynamic_pointer_cast<Tensor<int8_t>>(df_col)) {
num_rows = tensor->shape()[0];
}

std::vector<std::shared_ptr<arrow::Buffer>> buffer{
Expand Down
55 changes: 28 additions & 27 deletions modules/llm-cache/ds/kv_state_cache.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@ void KVStateCache::Resolve() {
// 1. construct the radix tree
this->rootTree = RadixTree::Deserialize(
base64_decode(this->meta_.GetKeyValue<std::string>("radix_tree")));
// raxShow(this->rootTree->GetRootTree());
if (VLOG_IS_ON(100)) {
VLOG(100) << raxShow(this->rootTree->GetRootTree());
}

dashanji marked this conversation as resolved.
Show resolved Hide resolved
// 2. construct the kvStateCacheBlockBuilder list
size_t numBlocks = this->meta_.GetKeyValue<size_t>("numBlocks");
Expand All @@ -57,24 +59,24 @@ void KVStateCache::Resolve() {
}

// 3. construct the member field
this->dimension = this->meta_.GetKeyValue<int>("dimension");
this->tensorBytes = this->meta_.GetKeyValue<int>("tensorBytes");
this->version = this->meta_.GetKeyValue<uint64_t>("version");
this->layer = this->meta_.GetKeyValue<int>("layer");
VLOG(100) << "construct the member field success, with dimension:"
<< this->dimension << " version:" << this->version
VLOG(100) << "construct the member field success, with tensorBytes:"
<< this->tensorBytes << " version:" << this->version
<< " layer:" << this->layer;
}

KVStateCache::~KVStateCache() {}

KVStateCacheBuilder::KVStateCacheBuilder(Client& client, int dimension,
KVStateCacheBuilder::KVStateCacheBuilder(Client& client, int tensorBytes,
int cacheCapacity, int layer,
int blockSize) {
this->dimension = dimension;
this->tensorBytes = tensorBytes;
this->version = 0;
this->layer = layer;
KVStateCacheBlockBuilder* builder =
new KVStateCacheBlockBuilder(client, this->dimension, layer, blockSize);
new KVStateCacheBlockBuilder(client, this->tensorBytes, layer, blockSize);

this->rootTree = std::make_shared<RadixTree>(cacheCapacity);

Expand All @@ -90,7 +92,7 @@ KVStateCacheBuilder::KVStateCacheBuilder(Client& client, int dimension,

KVStateCacheBuilder::KVStateCacheBuilder(Client& client,
std::shared_ptr<KVStateCache> cache) {
this->dimension = cache->GetDimension();
this->tensorBytes = cache->GetTensorBytes();
this->version = cache->GetVersion();
this->layer = cache->GetLayer();
// 1. create block builder from block
Expand All @@ -114,11 +116,11 @@ KVStateCacheBuilder::KVStateCacheBuilder(Client& client,

KVStateCacheBlockBuilder* KVStateCacheBuilder::Split(
Client& client, KVStateCacheBlockBuilder* kvStateCacheBlockBuilder,
std::vector<std::shared_ptr<NodeData>> nodeDataList) {
std::vector<std::shared_ptr<NodeData>>& nodeDataList) {
// Split the tree if the list of kvState is full.
VINEYARD_ASSERT(nodeDataList.size() > 0);
KVStateCacheBlockBuilder* childKVStateCacheBlockBuilder =
new KVStateCacheBlockBuilder(client, this->dimension, this->layer,
new KVStateCacheBlockBuilder(client, this->tensorBytes, this->layer,
kvStateCacheBlockBuilder->GetBlockSize());
for (size_t i = 0; i < nodeDataList.size(); i++) {
OffsetData* data =
Expand All @@ -138,10 +140,9 @@ KVStateCacheBlockBuilder* KVStateCacheBuilder::Split(
return childKVStateCacheBlockBuilder;
}

void KVStateCacheBuilder::Update(Client& client,
const std::vector<int>& tokenList,
int nextToken,
const KV_STATE_WITH_LAYER& kvState) {
void KVStateCacheBuilder::Update(
Client& client, const std::vector<int>& tokenList, int nextToken,
const std::map<int, std::pair<LLMKV, LLMKV>>& kvState) {
std::vector<int> tokenListCopy = tokenList;
tokenListCopy.push_back(nextToken);

Expand Down Expand Up @@ -199,9 +200,9 @@ void KVStateCacheBuilder::Update(Client& client,
<< " bitmap:" << kvStateCacheBlockBuilder->GetBitmapStr();
}

int KVStateCacheBuilder::Query(Client& client,
const std::vector<int>& tokenList, int token,
KV_STATE_WITH_LAYER& kvState) {
int KVStateCacheBuilder::Query(
Client& client, const std::vector<int>& tokenList, int token,
std::map<int, std::pair<LLMKV, LLMKV>>& kvState) {
std::vector<int> tokenListCopy = tokenList;
tokenListCopy.push_back(token);

Expand Down Expand Up @@ -275,23 +276,23 @@ void KVStateCacheBuilder::Merge(Client& client,
for (auto it = insertTokenList.begin(); it != insertTokenList.end(); ++it) {
std::vector<int> tokenList =
std::vector<int>((*it).begin(), (*it).end() - 1);
KV_STATE_WITH_LAYER kvState;
std::map<int, std::pair<LLMKV, LLMKV>> kvState;
for (int currentLayer = 0; currentLayer < this->layer; currentLayer++) {
K_STATE key_state;
V_STATE value_state;
key_state.data = malloc(this->dimension * sizeof(double));
key_state.length = this->dimension * sizeof(double);
value_state.data = malloc(this->dimension * sizeof(double));
value_state.length = this->dimension * sizeof(double);
LLMKV key_state;
LLMKV value_state;
key_state.data = malloc(this->tensorBytes);
key_state.length = this->tensorBytes;
value_state.data = malloc(this->tensorBytes);
value_state.length = this->tensorBytes;

kvState.insert(
std::make_pair(currentLayer, std::make_pair(key_state, value_state)));
}
globalCacheBuilder->Query(client, tokenList, (*it).back(), kvState);
this->Update(client, tokenList, (*it).back(), kvState);
for (int currentLayer = 0; currentLayer < this->layer; currentLayer++) {
K_STATE key_state = kvState[currentLayer].first;
V_STATE value_state = kvState[currentLayer].second;
LLMKV key_state = kvState[currentLayer].first;
LLMKV value_state = kvState[currentLayer].second;
free(key_state.data);
free(value_state.data);
}
Expand All @@ -309,7 +310,7 @@ std::shared_ptr<Object> KVStateCacheBuilder::_Seal(Client& client) {
std::shared_ptr<KVStateCache> kvStateCache = std::make_shared<KVStateCache>();

// 1. store the member variables to cache object meta
kvStateCache->meta_.AddKeyValue("dimension", this->dimension);
kvStateCache->meta_.AddKeyValue("tensorBytes", this->tensorBytes);
kvStateCache->meta_.AddKeyValue("version", this->version);
kvStateCache->meta_.AddKeyValue("layer", this->layer);

Expand Down
20 changes: 11 additions & 9 deletions modules/llm-cache/ds/kv_state_cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ limitations under the License.

#include <map>
#include <memory>
#include <utility>
#include <vector>

#include "client/client.h"
Expand All @@ -40,7 +41,7 @@ class KVStateCache : public vineyard::Registered<KVStateCache> {
private:
std::vector<std::shared_ptr<KVStateCacheBlock>> kvStateCacheBlockList;
std::shared_ptr<RadixTree> rootTree;
int dimension;
int tensorBytes;
int cacheCapacity;
int layer;
uint64_t version;
Expand All @@ -56,11 +57,11 @@ class KVStateCache : public vineyard::Registered<KVStateCache> {
void Resolve();

// for test
std::vector<std::shared_ptr<KVStateCacheBlock>> GetKVStateCacheBlockList() {
std::vector<std::shared_ptr<KVStateCacheBlock>>& GetKVStateCacheBlockList() {
return this->kvStateCacheBlockList;
}

int GetDimension() { return this->dimension; }
int GetTensorBytes() { return this->tensorBytes; }

int GetCacheCapacity() { return this->cacheCapacity; }

Expand All @@ -77,25 +78,26 @@ class KVStateCache : public vineyard::Registered<KVStateCache> {

class KVStateCacheBuilder : public vineyard::ObjectBuilder {
std::shared_ptr<RadixTree> rootTree;
int dimension;
int tensorBytes;
int layer;
uint64_t version;

public:
KVStateCacheBuilder(Client& client, int dimension, int cacheCapacity,
KVStateCacheBuilder(Client& client, int tensorBytes, int cacheCapacity,
int layer, int blockSize = DEFAULT_BLOCK_SIZE);

KVStateCacheBuilder(Client& client, std::shared_ptr<KVStateCache> cache);

KVStateCacheBlockBuilder* Split(
Client& client, KVStateCacheBlockBuilder* kvStateCacheBlockBuilder,
std::vector<std::shared_ptr<NodeData>> nodeDataList);
std::vector<std::shared_ptr<NodeData>>& nodeDataList);

void Update(Client& client, const std::vector<int>& token_list,
int next_token, const KV_STATE_WITH_LAYER& kv_state);
int next_token,
const std::map<int, std::pair<LLMKV, LLMKV>>& kv_state);

int Query(Client& client, const std::vector<int>& token_list, int token,
KV_STATE_WITH_LAYER& kv_state);
std::map<int, std::pair<LLMKV, LLMKV>>& kv_state);

void Delete(std::shared_ptr<NodeData> evicted_node);

Expand All @@ -109,7 +111,7 @@ class KVStateCacheBuilder : public vineyard::ObjectBuilder {

std::shared_ptr<Object> _Seal(Client& client) override;

uint64_t GetDimension() { return this->dimension; }
uint64_t GetTensorBytes() { return this->tensorBytes; }

std::shared_ptr<RadixTree> GetRootTree() { return this->rootTree; }

Expand Down
Loading
Loading