Skip to content

Commit

Permalink
Support storing more than 64 entries per KV cache block.
Browse files Browse the repository at this point in the history
Signed-off-by: vegetableysm <[email protected]>
  • Loading branch information
vegetableysm committed Feb 26, 2024
1 parent 779348e commit df8825d
Show file tree
Hide file tree
Showing 7 changed files with 107 additions and 40 deletions.
8 changes: 5 additions & 3 deletions modules/kv-state-cache/ds/kv_state_cache.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,13 @@ KVStateCache::~KVStateCache() {
}

KVStateCacheBuilder::KVStateCacheBuilder(Client& client, int dimension,
int cacheCapacity, int layer) {
int cacheCapacity, int layer,
int blockSize) {
this->dimension = dimension;
this->version = 0;
this->layer = layer;
KVStateCacheBlockBuilder* builder =
new KVStateCacheBlockBuilder(client, this->dimension, layer);
new KVStateCacheBlockBuilder(client, this->dimension, layer, blockSize);

this->rootTree = std::make_shared<RadixTree>(cacheCapacity);

Expand Down Expand Up @@ -126,7 +127,8 @@ KVStateCacheBlockBuilder* KVStateCacheBuilder::Split(
// Split the tree if the list of kvState is full.
VINEYARD_ASSERT(nodeDataList.size() > 0);
KVStateCacheBlockBuilder* childKVStateCacheBlockBuilder =
new KVStateCacheBlockBuilder(client, this->dimension, this->layer);
new KVStateCacheBlockBuilder(client, this->dimension, this->layer,
kvStateCacheBlockBuilder->GetBlockSize());
for (size_t i = 0; i < nodeDataList.size(); i++) {
OffsetData* data = (OffsetData*) nodeDataList[i]->nodeData->data;
if (data == nullptr)
Expand Down
4 changes: 2 additions & 2 deletions modules/kv-state-cache/ds/kv_state_cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,12 @@ class KVStateCache : public vineyard::Registered<KVStateCache> {
class KVStateCacheBuilder : public vineyard::ObjectBuilder {
std::shared_ptr<RadixTree> rootTree;
int dimension;
int layer = 1;
int layer;
uint64_t version;

public:
KVStateCacheBuilder(Client& client, int dimension, int cacheCapacity,
int layer);
int layer, int blockSize = DEFAULT_BLOCK_SIZE);

KVStateCacheBuilder(Client& client, std::shared_ptr<KVStateCache> cache);

Expand Down
85 changes: 64 additions & 21 deletions modules/kv-state-cache/ds/kv_state_cache_block.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,21 @@ namespace vineyard {
std::string KVStateCacheBlock::GetBitmapStr() {
std::string result;
const int bits = 8 * sizeof(unsigned long long);
for (int i = bits - 1; i >= 0; --i) {
result += ((this->bitmap >> i) & 1) ? '1' : '0';
for (int i = 0; i < this->bitmapSize; i++) {
for (int j = bits - 1; j >= 0; --j) {
result += (((this->bitmap[i]) >> j) & 1) ? '1' : '0';
}
}
return result;
}

std::string KVStateCacheBlockBuilder::GetBitmapStr() {
std::string result;
const int bits = 8 * sizeof(unsigned long long);
for (int i = bits - 1; i >= 0; --i) {
result += ((this->bitmap >> i) & 1) ? '1' : '0';
for (int i = 0; i < this->bitmapSize; i++) {
for (int j = bits - 1; j >= 0; --j) {
result += (((this->bitmap[i]) >> j) & 1) ? '1' : '0';
}
}
return result;
}
Expand All @@ -59,14 +63,27 @@ void KVStateCacheBlock::Construct(const ObjectMeta& meta) {
"valueStateTensorBuilder_" + std::to_string(currentLayer))));
}
// 2. construct the member field
this->bitmap = this->meta_.GetKeyValue<unsigned long long>("bitmap");
this->bitmapSize = this->meta_.GetKeyValue<int>("bitmap_size");
LOG(INFO) << "construct bitmap size:" << this->bitmapSize;
this->bitmap = (uint64_t*) malloc(this->bitmapSize * sizeof(uint64_t));
for (int i = 0; i < this->bitmapSize; i++) {
this->bitmap[i] =
this->meta_.GetKeyValue<uint64_t>("bitmap_" + std::to_string(i));
}
this->dimension = this->meta_.GetKeyValue<int>("dimension");
this->blockSize = this->meta_.GetKeyValue<int>("block_size");
}

// Destructor: returns the bitmap word array, which Construct() obtained via
// malloc, to the C allocator.
KVStateCacheBlock::~KVStateCacheBlock() {
free(bitmap);
}

KVStateCacheBlockBuilder::KVStateCacheBlockBuilder(Client& client,
int dimension, int layer) {
this->bitmap = UINT64_MAX;
std::vector<int64_t> shape = {LIST_SIZE, dimension};
int dimension, int layer,
int blockSize) {
this->blockSize = blockSize;
this->bitmapSize = (blockSize + 63) / 64;
this->bitmap = (uint64_t*) malloc(this->bitmapSize * sizeof(uint64_t));
memset((void*) this->bitmap, UINT8_MAX, this->bitmapSize * sizeof(uint64_t));
std::vector<int64_t> shape = {(int64_t)(blockSize), dimension};
for (int i = 0; i < layer; i++) {
this->keyStateTensorBuilderList.push_back(
std::make_shared<TensorBuilder<double>>(client, shape));
Expand All @@ -79,10 +96,17 @@ KVStateCacheBlockBuilder::KVStateCacheBlockBuilder(Client& client,

KVStateCacheBlockBuilder::KVStateCacheBlockBuilder(
Client& client, std::shared_ptr<KVStateCacheBlock> kvStateCacheBlock) {
this->bitmap = kvStateCacheBlock->bitmap;
this->bitmapSize = kvStateCacheBlock->bitmapSize;
this->blockSize = kvStateCacheBlock->blockSize;
LOG(INFO) << "create builder from block object, bitmap size:"
<< this->bitmapSize << " block size:" << blockSize;
this->bitmap = (uint64_t*) malloc(this->bitmapSize * sizeof(uint64_t));
for (int i = 0; i < this->bitmapSize; i++) {
this->bitmap[i] = kvStateCacheBlock->bitmap[i];
}
this->dimension = kvStateCacheBlock->dimension;
this->layer = kvStateCacheBlock->layer;
std::vector<int64_t> shape = {LIST_SIZE, dimension};
std::vector<int64_t> shape = {(int64_t)(blockSize), dimension};
for (int currentLayer = 0; currentLayer < this->layer; currentLayer++) {
this->keyStateTensorBuilderList.push_back(
std::make_shared<TensorBuilder<double>>(client, shape));
Expand All @@ -93,10 +117,10 @@ KVStateCacheBlockBuilder::KVStateCacheBlockBuilder(
for (int currentLayer = 0; currentLayer < this->layer; currentLayer++) {
memcpy(this->keyStateTensorBuilderList[currentLayer]->data(),
kvStateCacheBlock->keyStateTensorList[currentLayer]->data(),
LIST_SIZE * this->dimension * sizeof(double));
(int64_t)(blockSize) * this->dimension * sizeof(double));
memcpy(this->valueStateTensorBuilderList[currentLayer]->data(),
kvStateCacheBlock->valueStateTensorList[currentLayer]->data(),
LIST_SIZE * this->dimension * sizeof(double));
(int64_t)(blockSize) * this->dimension * sizeof(double));
}
}

Expand Down Expand Up @@ -126,14 +150,24 @@ Status KVStateCacheBlockBuilder::Query(Client& client, int index,
}

int KVStateCacheBlockBuilder::FindEmptySlot() {
int index = ffsll(this->bitmap) - 1;
VINEYARD_ASSERT(index >= 0 && index < LIST_SIZE);
return index;
for (int i = 0; i < this->bitmapSize; i++) {
if (this->bitmap[i] != 0) {
int index = ffsll(this->bitmap[i]) - 1;
return index + i * 64;
}
}
return -1;
}

bool KVStateCacheBlockBuilder::IsFull() {
int index = ffsll(this->bitmap) - 1;
return index < 0 || index >= LIST_SIZE;
int left = this->blockSize;
for (int i = 0; i < this->bitmapSize; i++) {
if (this->bitmap[i] != 0 && ffsll(this->bitmap[i]) - 1 < left) {
return false;
}
left -= sizeof(unsigned long long) * 8;
}
return true;
}

void KVStateCacheBlockBuilder::Update(const KV_STATE_WITH_LAYER& kvState,
Expand Down Expand Up @@ -161,7 +195,7 @@ void KVStateCacheBlockBuilder::Update(const KV_STATE_WITH_LAYER& kvState,
}
data->offset = index;

ACQUIRE_BIT_RESOURCE(this->bitmap, index);
ACQUIRE_BIT_RESOURCE(this->bitmap[index / 64], index % 64);
}

short KVStateCacheBlockBuilder::Split(KVStateCacheBlockBuilder* child,
Expand Down Expand Up @@ -191,8 +225,8 @@ short KVStateCacheBlockBuilder::Split(KVStateCacheBlockBuilder* child,
memcpy(childKeyState, keyState, this->dimension * sizeof(double));
memcpy(childValueState, valueState, this->dimension * sizeof(double));
}
ACQUIRE_BIT_RESOURCE(child->bitmap, childIndex);
FREE_BIT_RESOURCE(this->bitmap, index);
ACQUIRE_BIT_RESOURCE(child->bitmap[childIndex / 64], childIndex % 64);
FREE_BIT_RESOURCE(this->bitmap[index / 64], index % 64);
return childIndex;
}

Expand All @@ -216,7 +250,14 @@ std::shared_ptr<Object> KVStateCacheBlockBuilder::_Seal(Client& client) {
}

// 2. store the member field to meta
kvStateCacheBlock->meta_.AddKeyValue("bitmap", this->bitmap);
kvStateCacheBlock->meta_.AddKeyValue("bitmap_size", this->bitmapSize);
for (int i = 0; i < this->bitmapSize; i++) {
kvStateCacheBlock->meta_.AddKeyValue("bitmap_" + std::to_string(i),
this->bitmap[i]);
}
LOG(INFO) << "seal bitmap:" << this->GetBitmapStr();

kvStateCacheBlock->meta_.AddKeyValue("block_size", this->blockSize);
kvStateCacheBlock->meta_.AddKeyValue("dimension", this->dimension);
kvStateCacheBlock->meta_.AddKeyValue("layer", this->layer);
// 3. set the object type to meta
Expand All @@ -227,4 +268,6 @@ std::shared_ptr<Object> KVStateCacheBlockBuilder::_Seal(Client& client) {
return kvStateCacheBlock;
}

// Destructor: the builder's constructors malloc the bitmap word array, so it
// is handed back to the C allocator here.
KVStateCacheBlockBuilder::~KVStateCacheBlockBuilder() {
free(bitmap);
}

} // namespace vineyard
29 changes: 22 additions & 7 deletions modules/kv-state-cache/ds/kv_state_cache_block.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ struct OffsetData {
};
namespace vineyard {

#define LIST_SIZE 5
#define DEFAULT_BLOCK_SIZE 64

/**
* @brief KVStateCacheBlock is a cache for kv-cache of LLM. When a new prompt
Expand All @@ -67,7 +67,9 @@ class KVStateCacheBlock : public vineyard::Registered<KVStateCacheBlock> {
private:
std::vector<std::shared_ptr<Tensor<double>>> keyStateTensorList;
std::vector<std::shared_ptr<Tensor<double>>> valueStateTensorList;
uint64_t bitmap;
uint64_t* bitmap;
int blockSize;
int bitmapSize;
ObjectID id;
int layer;
int dimension;
Expand All @@ -84,7 +86,9 @@ class KVStateCacheBlock : public vineyard::Registered<KVStateCacheBlock> {

uint64_t GetDimension() { return this->dimension; }

uint64_t GetBitmap() { return this->bitmap; }
uint64_t* GetBitmap() { return this->bitmap; }

int GetBlockSize() { return this->blockSize; }

std::shared_ptr<const Tensor<double>> GetKeyTensor(int layer) {
return this->keyStateTensorList[layer];
Expand All @@ -102,6 +106,8 @@ class KVStateCacheBlock : public vineyard::Registered<KVStateCacheBlock> {
return this->valueStateTensorList;
}

~KVStateCacheBlock();

friend class KVStateCacheBlockBuilder;
};

Expand All @@ -112,14 +118,17 @@ class KVStateCacheBlockBuilder : public ObjectBuilder {
valueStateTensorBuilderList;
// TBD
// support more than 64 kv-state cache slots
uint64_t bitmap;
uint64_t* bitmap;
int blockSize;
int bitmapSize;
int dimension;
int layer;

int FindEmptySlot();

public:
KVStateCacheBlockBuilder(Client& client, int dimension, int layer);
KVStateCacheBlockBuilder(Client& client, int dimension, int layer,
int blockSize);

KVStateCacheBlockBuilder(
Client& client, std::shared_ptr<KVStateCacheBlock> kv_state_cache_block);
Expand Down Expand Up @@ -172,13 +181,19 @@ class KVStateCacheBlockBuilder : public ObjectBuilder {
return valueStateTensorBuilderList;
}

void DeleteKVCache(int bit) { FREE_BIT_RESOURCE(this->bitmap, bit); }
void DeleteKVCache(int bit) {
FREE_BIT_RESOURCE(this->bitmap[bit / 64], bit % 64);
}

std::string GetBitmapStr();

uint64_t GetBitmap() { return this->bitmap; }
uint64_t* GetBitmap() { return this->bitmap; }

uint64_t GetDimension() { return this->dimension; }

int GetBlockSize() { return this->blockSize; }

~KVStateCacheBlockBuilder();
};

} // namespace vineyard
Expand Down
6 changes: 3 additions & 3 deletions modules/kv-state-cache/utils/kv_state_cache_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ void signalHandler(int signum) {
exit(signum);
}

void InitKVStateCache(int dimension, int cacheCapacity, int layer) {
void InitKVStateCache(int dimension, int cacheCapacity, int layer,
int blockSize) {
if (kvStateCacheBuilder == nullptr) {
std::string socket = std::string(getenv("VINEYARD_IPC_SOCKET"));
LOG(INFO) << "socket:" << socket;
Expand Down Expand Up @@ -87,14 +88,13 @@ void InitKVStateCache(int dimension, int cacheCapacity, int layer) {
std::shared_ptr<KVStateCache> globalKVStateCache =
std::dynamic_pointer_cast<KVStateCache>(
client.GetObject(globalKVStateCacheID));
// TBD cache strategy
kvStateCacheBuilder =
std::make_shared<KVStateCacheBuilder>(client, globalKVStateCache);
} else {
// if failed, create a new cache object
LOG(INFO) << "failed to get the cache object, create a new one";
kvStateCacheBuilder = std::make_shared<KVStateCacheBuilder>(
client, dimension, cacheCapacity, layer);
client, dimension, cacheCapacity, layer, blockSize);
}

// // release the lock
Expand Down
4 changes: 2 additions & 2 deletions modules/kv-state-cache/utils/kv_state_cache_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ limitations under the License.
#ifndef MODULES_KV_STATE_CACHE_UTILS_H_
#define MODULES_KV_STATE_CACHE_UTILS_H_

void InitKVStateCache(int dimension = 10, int cacheCapacity = 10,
int layer = 1);
void InitKVStateCache(int dimension = 10, int cacheCapacity = 10, int layer = 1,
int blockSize = 5);

void Update(const std::vector<int>& tokenList, int nextToken,
const KV_STATE_WITH_LAYER& kvState);
Expand Down
11 changes: 9 additions & 2 deletions test/kv_state_cache_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,9 @@ using namespace vineyard;
#define DEMENSION 10
#define CAPACITY 20
#define LAYER 3
#define BLOCK_SIZE 5

void init() { InitKVStateCache(DEMENSION, CAPACITY, LAYER); }
void init() { InitKVStateCache(DEMENSION, CAPACITY, LAYER, BLOCK_SIZE); }

void print_current_tokens(const std::vector<int>& prefix, int next_token) {
std::string tokens_str = "";
Expand Down Expand Up @@ -106,7 +107,11 @@ void inference(std::vector<int> tokens, bool block = false) {

int main() {
init();
std::vector<int> round_1_tokens = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
std::vector<int> round_1_tokens = {
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36,
37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54,
55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69};
std::vector<int> round_2_tokens = {1, 2, 3, 4, 5, 7, 8,
9, 10, 11, 12, 13, 14};
std::vector<int> round_3_tokens = {1, 2, 3, 9, 10, 11, 12, 13, 14};
Expand All @@ -119,6 +124,8 @@ int main() {
inference(round_1_tokens);
inference(round_2_tokens);
inference(round_3_tokens);
sleep(5);
inference(round_3_tokens);
// inference(round_4_tokens);
// sleep(5);
// Delete(std::vector<int>(round_4_tokens.begin(), round_4_tokens.begin() +
Expand Down

0 comments on commit df8825d

Please sign in to comment.