Skip to content

Commit

Permalink
Merge pull request #207 from bab2min/dev/optim_layout
Browse files Browse the repository at this point in the history
Move `ll` and `gamma` into `Node` structure
  • Loading branch information
bab2min authored Jan 5, 2025
2 parents 63efcad + 2e56b16 commit 600630c
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 36 deletions.
4 changes: 1 addition & 3 deletions include/kiwi/Knlm.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ namespace kiwi
KeyType num_nexts = 0;
DiffType lower = 0;
uint32_t next_offset = 0;
float ll = 0, gamma = 0;
};

class KnLangModelBase
Expand All @@ -56,9 +57,6 @@ namespace kiwi
virtual ptrdiff_t getLowerNode(ptrdiff_t node_idx) const = 0;

virtual size_t nonLeafNodeSize() const = 0;
virtual size_t llSize() const = 0;
virtual const float* getLLBuf() const = 0;
virtual const float* getGammaBuf() const = 0;
virtual const void* getExtraBuf() const = 0;

static std::unique_ptr<KnLangModelBase> create(utils::MemoryObject&& mem, ArchType archType = ArchType::none);
Expand Down
51 changes: 18 additions & 33 deletions src/Knlm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ namespace kiwi
{
static constexpr size_t serialAlignment = 16;


using QCode = qe::QCode<0, 2, 8, 16>;

template<size_t bits>
Expand Down Expand Up @@ -98,11 +97,8 @@ namespace kiwi
std::unique_ptr<DiffType[]> all_value_data;
size_t num_non_leaf_nodes = 0;
DiffType* value_data = nullptr;
const float* ll_data = nullptr;
const float* gamma_data = nullptr;
const KeyType* htx_data = nullptr;
const void* extra_buf = nullptr;
Vector<float> restored_floats;
float unk_ll = 0;
ptrdiff_t bos_node_idx = 0;

Expand Down Expand Up @@ -193,7 +189,9 @@ namespace kiwi
}

// restore ll & gamma data
Vector<float> restored_leaf_ll;
Vector<float> restored_leaf_ll, restored_floats;
const float* ll_data = nullptr;
const float* gamma_data = nullptr;
const float* leaf_ll_data = nullptr;
if (quantized)
{
Expand Down Expand Up @@ -262,6 +260,8 @@ namespace kiwi
}
node.num_nexts = node_sizes[i];
node.next_offset = next_offset;
node.ll = ll_data[non_leaf_idx];
node.gamma = gamma_data[non_leaf_idx];
next_offset += node_sizes[i];
key_ranges.emplace_back(std::array<size_t, 3>{ non_leaf_idx, (size_t)node.next_offset, (size_t)(node.next_offset + node.num_nexts) });
non_leaf_idx++;
Expand Down Expand Up @@ -343,14 +343,14 @@ namespace kiwi
node->num_nexts, next, v
))
{
return gamma_data[node_idx] + getLL(node_idx + node->lower, next);
return node->gamma + getLL(node_idx + node->lower, next);
}
}

// non-leaf node
if (v > 0)
{
return ll_data[node_idx + v];
return node_data[node_idx + v].ll;
}
// leaf node
else
Expand Down Expand Up @@ -396,7 +396,7 @@ namespace kiwi
node->num_nexts, next, v
))
{
acc += gamma_data[node_idx];
acc += node->gamma;
node_idx += node->lower;
PREFETCH_T0(&key_data[node_data[node_idx].next_offset]);
continue;
Expand All @@ -407,7 +407,7 @@ namespace kiwi
if (v > 0)
{
node_idx += v;
return acc + ll_data[node_idx];
return acc + node_data[node_idx].ll;
}
// leaf node
else
Expand Down Expand Up @@ -456,16 +456,6 @@ namespace kiwi
return bos_node_idx;
}

// Returns the current value of the `ll_data` member unchanged.
// The member is declared with a nullptr default, so callers may receive
// nullptr if nothing has assigned it.
// NOTE(review): lifetime/ownership of the pointed-to floats is not visible
// here — confirm against the code that assigns `ll_data`.
const float* getLLBuf() const final
{
return ll_data;
}

// Returns the current value of the `gamma_data` member unchanged.
// The member is declared with a nullptr default, so callers may receive
// nullptr if nothing has assigned it.
// NOTE(review): lifetime/ownership of the pointed-to floats is not visible
// here — confirm against the code that assigns `gamma_data`.
const float* getGammaBuf() const final
{
return gamma_data;
}

const void* getExtraBuf() const final
{
return extra_buf;
Expand All @@ -481,11 +471,6 @@ namespace kiwi
return num_non_leaf_nodes;
}

// Returns the distance, counted in `float` elements, from `ll_data` to
// `gamma_data` (pointer subtraction converted to size_t).
// NOTE(review): this is only well-defined when both pointers refer into the
// same array with `gamma_data` at or after `ll_data`; presumably the gamma
// values are stored directly behind the ll values — confirm against the
// loader. If both members are still nullptr the result is 0.
size_t llSize() const final
{
return gamma_data - ll_data;
}

std::vector<float> allNextLL(ptrdiff_t node_idx) const final
{
std::vector<float> ret(getHeader().vocab_size, -INFINITY);
Expand All @@ -500,14 +485,14 @@ namespace kiwi
}
else
{
ret[keys[i]] = ll_data[node_idx + values[i]];
ret[keys[i]] = node_data[node_idx + values[i]].ll;
}
}

float acc = 0;
while (node->lower)
{
acc += gamma_data[node - &node_data[0]];
acc += node->gamma;
node += node->lower;
keys = &key_data[node->next_offset];
values = &value_data[node->next_offset];
Expand All @@ -520,7 +505,7 @@ namespace kiwi
}
else
{
ret[keys[i]] = acc + ll_data[node - &node_data[0] + values[i]];
ret[keys[i]] = acc + node[values[i]].ll;
}
}
}
Expand Down Expand Up @@ -550,7 +535,7 @@ namespace kiwi
}
else
{
ret[k] = ll_data[node_idx + v];
ret[k] = node_data[node_idx + v].ll;
}

if (htx_data)
Expand Down Expand Up @@ -590,7 +575,7 @@ namespace kiwi
float acc = 0;
while (node->lower)
{
acc += gamma_data[node - &node_data[0]];
acc += node->gamma;
node += node->lower;
keys = &key_data[node->next_offset];
values = &value_data[node->next_offset];
Expand All @@ -605,7 +590,7 @@ namespace kiwi
}
else
{
ret[k] = acc + ll_data[node - &node_data[0] + v];
ret[k] = acc + node[v].ll;
}

if (htx_data)
Expand Down Expand Up @@ -667,15 +652,15 @@ namespace kiwi
}
else
{
buf.emplace_back(ll_data[node_idx + values[i]], (KeyOut)keys[i]);
buf.emplace_back(node_data[node_idx + values[i]].ll, (KeyOut)keys[i]);
}
}
std::make_heap(buf.begin(), buf.end());

float acc = 0;
while (node->num_nexts < top_n && node->lower)
{
acc += gamma_data[node - &node_data[0]];
acc += node->gamma;
node += node->lower;
keys = &key_data[node->next_offset];
values = &value_data[node->next_offset];
Expand All @@ -687,7 +672,7 @@ namespace kiwi
}
else
{
buf.emplace_back(acc + ll_data[node - &node_data[0] + values[i]], (KeyOut)keys[i]);
buf.emplace_back(acc + node[values[i]].ll, (KeyOut)keys[i]);
}
std::push_heap(buf.begin(), buf.end());
}
Expand Down

0 comments on commit 600630c

Please sign in to comment.