Skip to content

Commit

Permalink
Simplify block manager (#812)
Browse files Browse the repository at this point in the history
* simplify block manager

* fix lint
  • Loading branch information
lzhangzz authored Dec 11, 2023
1 parent 2d5f5b3 commit a54b16a
Show file tree
Hide file tree
Showing 5 changed files with 179 additions and 140 deletions.
144 changes: 83 additions & 61 deletions src/turbomind/models/llama/BlockManager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "src/turbomind/utils/cuda_utils.h"
#include "src/turbomind/utils/debug_utils.h"
#include "src/turbomind/utils/logger.h"
#include "src/turbomind/utils/string_utils.h"
#include <algorithm>
#include <iterator>
#include <stdexcept>
Expand Down Expand Up @@ -70,7 +71,6 @@ bool BlockManager::Malloc()
for (int i = 0; i < chunk_size; ++i, ptr += block_size_) {
auto& block = blocks_.emplace_back();
block.use_count = 0;
block.ref_count = 0;
block.id = (int)blocks_.size() - 1;
block.timestamp = 0;
block.data = ptr;
Expand All @@ -91,47 +91,54 @@ size_t BlockManager::GetBlockCount(size_t block_size, double ratio)

void BlockManager::Move(std::vector<int>& src, const std::vector<int>& delta, std::vector<int>& dst)
{
FT_CHECK(src.size() >= delta.size());
std::vector<int> src1(src.size() - delta.size());
std::set_difference(src.begin(), src.end(), delta.begin(), delta.end(), src1.begin());
{
auto end = std::set_difference(src.begin(), src.end(), delta.begin(), delta.end(), src1.begin());
FT_CHECK(end == src1.end());
}
src.swap(src1);

std::vector<int> dst1(dst.size() + delta.size());
std::set_union(dst.begin(), dst.end(), delta.begin(), delta.end(), dst1.begin());
{
auto end = std::set_union(dst.begin(), dst.end(), delta.begin(), delta.end(), dst1.begin());
FT_CHECK(end == dst1.end());
}
dst.swap(dst1);
}

std::vector<const Block*> BlockManager::Allocate(int count)
auto BlockManager::Allocate(int count) -> std::pair<BlockIds, UniqueIds>
{
while (free_ids_.size() < count) {
if (!Malloc()) {
throw std::runtime_error("out of memory");
}
}

std::vector<const Block*> ret;

std::vector<int> idxs(count);
BlockIds block_ids(count);
UniqueIds unique_ids(count);

for (int i = 0; i < count; ++i) {
int idx = free_ids_[i];
idxs[i] = idx;
auto& block = blocks_[idx];
FT_CHECK(is_free(block));
block.ref_count = 1;
block.use_count = 1;
block.unique_id = unique_id_++;
ret.push_back(&block);
int idx = free_ids_[i];
auto& b = blocks_[idx];
FT_CHECK(is_free(b)); // pre-condition: uc == 0 && ts == 0
b.use_count = 1;
b.unique_id = unique_id_++;
FT_CHECK(is_active(b)); // post-condition
block_ids[i] = idx;
unique_ids[i] = b.unique_id;
}

Move(free_ids_, idxs, active_ids_);
Move(free_ids_, block_ids, active_ids_);

dbg(free_ids_, active_ids_);

return ret;
return {block_ids, unique_ids};
}

void BlockManager::Evict(int count)
{
FT_CHECK(count <= cached_ids_.size());
std::vector<int> idxs(cached_ids_);
// get first `count` cached ids according to timestamp
std::nth_element(idxs.begin(), idxs.begin() + count, idxs.end(), [&](int i, int j) {
Expand All @@ -146,89 +153,104 @@ void BlockManager::Evict(int count)
for (const auto& idx : idxs) {
auto& b = blocks_[idx];
FT_CHECK(is_cached(b));
b.ref_count = 0;
b.unique_id = 0;
b.timestamp = 0;
FT_CHECK(is_free(b));
}

Move(cached_ids_, idxs, free_ids_);

dbg(cached_ids_, free_ids_);
}

int BlockManager::Free(const std::vector<const Block*>& bs)
void BlockManager::Free(BlockIds ids)
{
std::vector<int> idxs;
std::sort(ids.begin(), ids.end());

for (const auto& p : bs) {
auto& b = blocks_[p->id];
FT_CHECK(is_cached(b));
if (--b.ref_count == 0) {
b.unique_id = 0;
b.timestamp = 0;
idxs.push_back(b.id);
}
for (const auto& i : ids) {
auto& b = blocks_[i];
FT_CHECK(is_cached(b)); // uc == 0 && ts != 0
b.unique_id = 0;
b.timestamp = 0;
FT_CHECK(is_free(b));
}

std::sort(idxs.begin(), idxs.end());

Move(cached_ids_, idxs, free_ids_);

dbg(cached_ids_, free_ids_);

return idxs.size();
Move(cached_ids_, ids, free_ids_);
}

int BlockManager::Unlock(const std::vector<const Block*>& bs)
int BlockManager::Unlock(const BlockIds& ids)
{
std::vector<int> idxs;

for (const auto& p : bs) {
auto& block = blocks_[p->id];
FT_CHECK(is_active(block));
if (--block.use_count == 0) {
idxs.push_back(block.id);
BlockIds unlock;
unlock.reserve(ids.size());

for (const auto& i : ids) {
auto& b = blocks_[i];
FT_CHECK(is_active(b)); // pre-condition: uc > 0
if (--b.use_count == 0) {
unlock.push_back(b.id);
FT_CHECK(is_cached(b)); // post-condition
}
}

std::sort(idxs.begin(), idxs.end());
std::sort(unlock.begin(), unlock.end());

Move(active_ids_, idxs, cached_ids_);
Move(active_ids_, unlock, cached_ids_);

dbg(active_ids_, cached_ids_);

return idxs.size();
return unlock.size();
}

int BlockManager::Lock(const std::vector<const Block*>& bs)
int BlockManager::Lock(const BlockIds& ids)
{
std::vector<int> idxs;
BlockIds lock;
lock.reserve(ids.size());

for (const auto& p : bs) {
auto& block = blocks_[p->id];
FT_CHECK(is_cached(block));
if (++block.use_count == 1) {
idxs.push_back(p->id);
for (const auto& i : ids) {
auto& b = blocks_[i];
FT_CHECK(is_cached(b));
if (++b.use_count == 1) {
lock.push_back(i);
FT_CHECK(is_active(b));
}
}

std::sort(idxs.begin(), idxs.end());
std::sort(lock.begin(), lock.end());

Move(cached_ids_, idxs, active_ids_);
Move(cached_ids_, lock, active_ids_);

// dbg(cached_ids_, active_ids_);

return idxs.size();
return lock.size();
}

void BlockManager::Touch(const std::vector<const Block*>& bs)
void BlockManager::Touch(const BlockIds& ids)
{
std::for_each(bs.crbegin(), bs.crend(), [this](const Block* p) {
FT_CHECK(is_active(*p));
const_cast<Block*>(p)->timestamp = timestamp_++;
std::for_each(ids.crbegin(), ids.crend(), [this](int i) {
FT_CHECK(is_active(blocks_[i]));
blocks_[i].timestamp = timestamp_++;
});
}

int BlockManager::Verify(const std::vector<int>& block_ids, const std::vector<uint64_t>& unique_ids)
{
FT_CHECK(block_ids.size() == unique_ids.size());
int valid = block_ids.size();
for (int i = 0; i < block_ids.size(); ++i) {
if (unique_id(block_ids[i]) != unique_ids[i]) {
valid = i;
break;
}
}
int miss = 0;
for (int i = valid; i < block_ids.size(); ++i) {
miss += (unique_id(block_ids[i]) != unique_ids[i]);
}
// All later blocks should have been invalidated
FT_CHECK_WITH_INFO(miss == (int)block_ids.size() - valid,
fmtstr("count = %d, valid = %d, miss = %d", (int)block_ids.size(), valid, miss));
return valid;
}

Snapshot BlockManager::TakeSnapshot()
{
std::vector<int> use_count(blocks_.size());
Expand Down
49 changes: 35 additions & 14 deletions src/turbomind/models/llama/BlockManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <iterator>
#include <numeric>
#include <queue>
#include <sstream>
#include <unordered_map>
#include <vector>

Expand All @@ -22,28 +23,37 @@ namespace turbomind {

struct Block {
int id; // fixed linear id in the pool
int ref_count; // all sequences referencing the block
int use_count; // active sequences using the block
uint64_t unique_id; // unique for every block allocation
uint64_t timestamp;
void* data;

friend std::ostream& operator<<(std::ostream& os, const Block& block);
friend std::string to_string(const Block& b)
{
std::stringstream ss;
ss << b;
return ss.str();
}
};

using BlockIds = std::vector<int>;
using UniqueIds = std::vector<uint64_t>;

inline bool is_active(const Block& block)
{
return block.ref_count > 0 && block.use_count > 0;
// timestamp may be 0 for newly allocated block that has not been written
return block.use_count > 0;
}

inline bool is_cached(const Block& block)
{
return block.ref_count > 0 && block.use_count == 0;
return block.use_count == 0 && block.timestamp != 0;
}

inline bool is_free(const Block& block)
{
return block.ref_count == 0 && block.use_count == 0 && block.timestamp == 0;
return block.use_count == 0 && block.timestamp == 0;
}

struct Snapshot {
Expand All @@ -60,22 +70,24 @@ class BlockManager {
~BlockManager();

// free -> active (use_count = 1, fresh unique_id assigned)
[[nodiscard]] std::vector<const Block*> Allocate(int count);
[[nodiscard]] std::pair<BlockIds, UniqueIds> Allocate(int count);

// cached -> active (use_count += 1)
[[maybe_unused]] int Lock(const std::vector<const Block*>& bs);
[[maybe_unused]] int Lock(const BlockIds& ids);

// active -> cached (use_count -= 1)
[[maybe_unused]] int Unlock(const std::vector<const Block*>& bs);
[[maybe_unused]] int Unlock(const BlockIds& ids);

// cached -> free (unique_id = 0, timestamp = 0)
void Evict(int count);

// cached -> free for the given block ids (unique_id = 0, timestamp = 0)
[[maybe_unused]] int Free(const std::vector<const Block*>& bs);
void Free(BlockIds bs);

// increase timestamp in reversed order
void Touch(const std::vector<const Block*>& bs);
void Touch(const BlockIds& bs);

[[nodiscard]] int Verify(const BlockIds& block_ids, const UniqueIds& unique_ids);

Snapshot TakeSnapshot();

Expand All @@ -99,13 +111,23 @@ class BlockManager {
return (max_block_count_ - blocks_.size()) + free_ids_.size();
}

// Mutable access to the pool entry at linear index `idx`.
// Indexing is unchecked; `idx` must be a valid index into `blocks_`.
Block& block(int idx)
{
    auto& entry = blocks_[idx];
    return entry;
}

int unique_id(int idx)
{
return blocks_[idx].unique_id;
}

friend std::ostream& operator<<(std::ostream& os, const BlockManager&);

private:
static size_t GetBlockCount(size_t block_size, double ratio);

// move indices between sets
static void Move(std::vector<int>& src, const std::vector<int>& delta, std::vector<int>& dst);
static void Move(BlockIds& src, const BlockIds& delta, BlockIds& dst);

// allocate a chunk of blocks
bool Malloc();
Expand All @@ -118,13 +140,12 @@ class BlockManager {

std::vector<void*> chunks_;

std::vector<int> active_ids_;
std::vector<int> cached_ids_;
std::vector<int> free_ids_;
BlockIds active_ids_;
BlockIds cached_ids_;
BlockIds free_ids_;

std::vector<Block> blocks_; // < 100k

// uint64_t unique_id_{1UL << 63};
uint64_t unique_id_{1};
uint64_t timestamp_{1};
};
Expand Down
8 changes: 4 additions & 4 deletions src/turbomind/models/llama/LlamaBatch.cc
Original file line number Diff line number Diff line change
Expand Up @@ -505,11 +505,11 @@ void LlamaBatch<T>::Initialize(GenerationState& g)
FT_CHECK_WITH_INFO(h_cu_block_counts_[i + 1] <= sequence_manager_->max_block_count(),
std::to_string(h_cu_block_counts_[i + 1]));

k_ptrs = std::transform(seq.blocks.cbegin(), seq.blocks.cend(), k_ptrs, [&](auto p) {
return reinterpret_cast<uintptr_t>(sequence_manager_->OffsetKey(p->data));
k_ptrs = std::transform(seq.blocks.cbegin(), seq.blocks.cend(), k_ptrs, [&](int block_id) {
return reinterpret_cast<uintptr_t>(sequence_manager_->GetKeyPtr(block_id));
});
v_ptrs = std::transform(seq.blocks.cbegin(), seq.blocks.cend(), v_ptrs, [&](auto p) {
return reinterpret_cast<uintptr_t>(sequence_manager_->OffsetVal(p->data));
v_ptrs = std::transform(seq.blocks.cbegin(), seq.blocks.cend(), v_ptrs, [&](int block_id) {
return reinterpret_cast<uintptr_t>(sequence_manager_->GetValPtr(block_id));
});
}

Expand Down
Loading

0 comments on commit a54b16a

Please sign in to comment.