
Commit bf2b8ba

kv-cache : fix split_equal handling in unified implementation
ggml-ci
1 parent 89a184f

File tree

3 files changed (+122, -65 lines)
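Note on the two splitting strategies involved: split_simple cuts the batch into contiguous chunks of at most n_ubatch tokens in batch order, while split_equal builds ubatches in which every participating sequence contributes the same number of tokens. A toy model of the difference, assuming a one-token-per-sequence-per-round simplification (this is not the llama_sbatch implementation; every type and helper below is invented for illustration):

#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <map>
#include <vector>

struct toy_token { int pos; int seq_id; };

using toy_ubatch = std::vector<toy_token>;

// simple split: contiguous chunks of at most n_ubatch tokens, in batch order;
// a chunk may mix sequences unevenly
std::vector<toy_ubatch> split_simple(const std::vector<toy_token> & batch, size_t n_ubatch) {
    std::vector<toy_ubatch> out;
    for (size_t i = 0; i < batch.size(); i += n_ubatch) {
        out.emplace_back(batch.begin() + i, batch.begin() + std::min(i + n_ubatch, batch.size()));
    }
    return out;
}

// equal split (simplified): round r takes the r-th token of every sequence, so
// each ubatch holds the same number of tokens from each sequence it contains
std::vector<toy_ubatch> split_equal(const std::vector<toy_token> & batch) {
    std::map<int, std::vector<toy_token>> per_seq;
    for (const auto & t : batch) per_seq[t.seq_id].push_back(t);

    size_t rounds = 0;
    for (const auto & kv : per_seq) rounds = std::max(rounds, kv.second.size());

    std::vector<toy_ubatch> out;
    for (size_t r = 0; r < rounds; ++r) {
        toy_ubatch ub;
        for (const auto & kv : per_seq) {
            if (r < kv.second.size()) ub.push_back(kv.second[r]);
        }
        out.push_back(std::move(ub));
    }
    return out;
}

int main() {
    const std::vector<toy_token> batch = { {0,0}, {1,0}, {2,0}, {0,1}, {1,1} };
    printf("simple: %zu ubatches, equal: %zu ubatches\n",
           split_simple(batch, 2).size(), split_equal(batch).size());
}

The fix below makes the unified caches fall back to the equal split when the simple split cannot be placed.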

src/llama-context.cpp

Lines changed: 2 additions & 0 deletions
@@ -877,6 +877,8 @@ int llama_context::encode(llama_batch & inp_batch) {
         memcpy(cross.v_embd.data(), embd, ggml_nbytes(t_embd));

         // remember the sequence ids used during the encoding - needed for cross attention later
+        // TODO: the sequence indexing here is likely not correct in the general case
+        //       probably works only for split_simple
         cross.seq_ids_enc.resize(n_tokens);
         for (int32_t i = 0; i < n_tokens; i++) {
             cross.seq_ids_enc[i].clear();

src/llama-kv-cache-unified-iswa.cpp

Lines changed: 52 additions & 19 deletions
@@ -98,33 +98,66 @@ llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
 llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) {
     GGML_UNUSED(embd_pooled);

-    // TODO: if we fail with split_simple, we should attempt different splitting strategies
-    //       but to do that properly, we first have to refactor the batches to be more flexible
+    // first try simple split
+    do {
+        auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);

-    auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);
+        std::vector<llama_ubatch> ubatches;

-    std::vector<llama_ubatch> ubatches;
+        while (sbatch.n_tokens > 0) {
+            auto ubatch = sbatch.split_simple(n_ubatch);

-    while (sbatch.n_tokens > 0) {
-        auto ubatch = sbatch.split_simple(n_ubatch);
+            ubatches.push_back(ubatch);
+        }

-        ubatches.push_back(ubatch);
-    }
+        auto heads_base = kv_base->prepare(ubatches);
+        if (heads_base.empty()) {
+            break;
+        }

-    auto heads_base = kv_base->prepare(ubatches);
-    if (heads_base.empty()) {
-        return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
-    }
+        auto heads_swa = kv_swa->prepare(ubatches);
+        if (heads_swa.empty()) {
+            break;
+        }

-    auto heads_swa = kv_swa->prepare(ubatches);
-    if (heads_swa.empty()) {
-        return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
-    }
+        assert(heads_base.size() == heads_swa.size());
+
+        return std::make_unique<llama_kv_cache_unified_iswa_state>(
+                this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
+    } while (false);
+
+    // if it fails, try equal split
+    do {
+        auto sbatch = llama_sbatch(batch, hparams.n_embd, false, logits_all);
+
+        std::vector<llama_ubatch> ubatches;

-    assert(heads_base.size() == heads_swa.size());
+        while (sbatch.n_tokens > 0) {
+            auto ubatch = sbatch.split_equal(n_ubatch);
+
+            ubatches.push_back(ubatch);
+        }
+
+        auto heads_base = kv_base->prepare(ubatches);
+        if (heads_base.empty()) {
+            break;
+        }
+
+        auto heads_swa = kv_swa->prepare(ubatches);
+        if (heads_swa.empty()) {
+            break;
+        }
+
+        assert(heads_base.size() == heads_swa.size());
+
+        return std::make_unique<llama_kv_cache_unified_iswa_state>(
+                this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
+    } while (false);
+
+    // TODO: if we fail again, we should attempt different splitting strategies
+    //       but to do that properly, we first have to refactor the batches to be more flexible

-    return std::make_unique<llama_kv_cache_unified_iswa_state>(
-            this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
+    return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
 }

 llama_memory_state_ptr llama_kv_cache_unified_iswa::init_full() {
src/llama-kv-cache-unified.cpp

Lines changed: 68 additions & 46 deletions
@@ -314,20 +314,24 @@ llama_memory_state_ptr llama_kv_cache_unified::init_batch(
         bool logits_all) {
     GGML_UNUSED(embd_pooled);

-    auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);
+    do {
+        auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);

-    std::vector<llama_ubatch> ubatches;
-    while (sbatch.n_tokens > 0) {
-        ubatches.push_back(sbatch.split_simple(n_ubatch));
-    }
+        std::vector<llama_ubatch> ubatches;
+        while (sbatch.n_tokens > 0) {
+            ubatches.push_back(sbatch.split_simple(n_ubatch));
+        }

-    auto heads = prepare(ubatches);
-    if (heads.empty()) {
-        return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
-    }
+        auto heads = prepare(ubatches);
+        if (heads.empty()) {
+            break;
+        }

-    return std::make_unique<llama_kv_cache_unified_state>(
-            this, std::move(sbatch), std::move(heads), std::move(ubatches));
+        return std::make_unique<llama_kv_cache_unified_state>(
+                this, std::move(sbatch), std::move(heads), std::move(ubatches));
+    } while (false);
+
+    return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
 }

 llama_memory_state_ptr llama_kv_cache_unified::init_full() {
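The do { ... } while (false) wrapper used in both files is the standard early-exit idiom: break jumps out of the block past the success return, and control falls through to a single shared failure path. A minimal illustration of the idiom on its own (try_prepare is a made-up example, not llama.cpp API):

#include <cstdio>
#include <optional>

std::optional<int> try_prepare(bool ok_a, bool ok_b) {
    do {
        if (!ok_a) break; // first stage failed -> fall through to failure path
        if (!ok_b) break; // second stage failed -> fall through to failure path

        return 42;        // success: return from inside the block
    } while (false);

    return std::nullopt;  // single shared failure exit
}

int main() {
    printf("%d\n", try_prepare(true, true ).value_or(-1)); // prints 42
    printf("%d\n", try_prepare(true, false).value_or(-1)); // prints -1
}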
@@ -521,7 +525,6 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
     }

     if (debug > 0) {
-        LLAMA_LOG_CONT("\n");
         LLAMA_LOG_DEBUG("%s: n = %5d, used = %5d, head = %5d, size = %5d, n_swa = %5d\n", __func__, cells.used_max_p1(), cells.get_used(), head, get_size(), n_swa);

         if ((debug == 2 && n_swa > 0) || debug > 2) {
@@ -530,7 +533,13 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
             if (cells.is_empty(i)) {
                 ss += '.';
             } else {
-                ss += std::to_string(cells.seq_get(i));
+                assert(cells.seq_count(i) >= 1);
+
+                if (cells.seq_count(i) == 1) {
+                    ss += std::to_string(cells.seq_get(i));
+                } else {
+                    ss += 'M';
+                }
             }
             if (i%256 == 255) {
                 ss += " *";
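The debug string now distinguishes three cell states instead of two: '.' for an empty cell, the sequence id when exactly one sequence occupies the cell, and 'M' when several do. A standalone sketch of the same rule over an invented toy_cell type:

#include <cstdio>
#include <set>
#include <string>
#include <vector>

struct toy_cell { std::set<int> seq_ids; };

std::string render(const std::vector<toy_cell> & cells) {
    std::string ss;
    for (const auto & c : cells) {
        if (c.seq_ids.empty()) {
            ss += '.';                                // empty cell
        } else if (c.seq_ids.size() == 1) {
            ss += std::to_string(*c.seq_ids.begin()); // single sequence: its id
        } else {
            ss += 'M';                                // multiple sequences share the cell
        }
    }
    return ss;
}

int main() {
    const std::vector<toy_cell> cells = { {}, {{0}}, {{1}}, {{0, 1}} };
    printf("%s\n", render(cells).c_str()); // prints ".01M"
}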
@@ -636,29 +645,39 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
 }

 void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch) {
+    if (debug > 0) {
+        LLAMA_LOG_DEBUG("%s: ubatch info:\n", __func__);
+        LLAMA_LOG_DEBUG("%s:   n_tokens = %d, equal_seqs = %d\n", __func__, ubatch.n_tokens, ubatch.equal_seqs);
+        LLAMA_LOG_DEBUG("%s:   n_seq_tokens = %d, n_seqs = %d\n", __func__, ubatch.n_seq_tokens, ubatch.n_seqs);
+    }
+
     // keep track of the max sequence position that we would overwrite with this ubatch
     // for non-SWA cache, this would be always empty
     llama_seq_id seq_pos_max_rm[LLAMA_MAX_PARALLEL_SEQUENCES];
     for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
         seq_pos_max_rm[s] = -1;
     }

-    for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
-        if (!cells.is_empty(head_cur + i)) {
-            assert(cells.seq_count(head_cur + i) == 1);
+    for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
+        for (uint32_t j = 0; j < ubatch.n_seq_tokens; ++j) {
+            const uint32_t idx = s*ubatch.n_seq_tokens + j;

-            const llama_seq_id seq_id = cells.seq_get(head_cur + i);
-            const llama_pos    pos    = cells.pos_get(head_cur + i);
+            if (!cells.is_empty(head_cur + idx)) {
+                assert(cells.seq_count(head_cur + idx) == 1);

-            seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
+                const llama_seq_id seq_id = cells.seq_get(head_cur + idx);
+                const llama_pos    pos    = cells.pos_get(head_cur + idx);

-            cells.rm(head_cur + i);
-        }
+                seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
+
+                cells.rm(head_cur + idx);
+            }

-        cells.pos_set(head_cur + i, ubatch.pos[i]);
+            cells.pos_set(head_cur + idx, ubatch.pos[idx]);

-        for (int32_t j = 0; j < ubatch.n_seq_id[i]; j++) {
-            cells.seq_add(head_cur + i, ubatch.seq_id[i][j]);
+            for (int32_t i = 0; i < ubatch.n_seq_id[s]; i++) {
+                cells.seq_add(head_cur + idx, ubatch.seq_id[s][i]);
+            }
         }
     }
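The rewritten loop addresses token slots through idx = s*ubatch.n_seq_tokens + j. When n_tokens == n_seqs * n_seq_tokens (the layout split_equal produces), this sweeps the same flat range [0, n_tokens) as the old single loop; the substantive change is that the per-sequence metadata (n_seq_id, seq_id) is now indexed by s rather than by token. A quick self-check of the index map:

#include <cassert>

int main() {
    const unsigned n_seqs       = 3; // example values
    const unsigned n_seq_tokens = 4;

    unsigned expected = 0;
    for (unsigned s = 0; s < n_seqs; ++s) {
        for (unsigned j = 0; j < n_seq_tokens; ++j) {
            const unsigned idx = s*n_seq_tokens + j; // token slot, as in apply_ubatch
            assert(idx == expected++);               // visits 0..n_tokens-1 exactly once
        }
    }
}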

@@ -677,7 +696,6 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
             seq_rm(s, cells.seq_pos_min(s), seq_pos_max_rm[s] + 1);
         }
     }
-
     // move the head at the end of the slot
     head = head_cur + ubatch.n_tokens;
 }
@@ -774,14 +792,14 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
 }

 void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
-    const int64_t n_tokens     = ubatch->n_tokens;
-    const int64_t n_seq_tokens = ubatch->n_seq_tokens;
-    const int64_t n_seqs       = ubatch->n_seqs;
+    const uint32_t n_tokens     = ubatch->n_tokens;
+    const uint32_t n_seq_tokens = ubatch->n_seq_tokens;
+    const uint32_t n_seqs       = ubatch->n_seqs;

     GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
     float * data = (float *) dst->data;

-    const auto n_kv = dst->ne[0];
+    const int64_t n_kv = dst->ne[0];

     // Use only the previous KV cells of the correct sequence for each token of the ubatch.
     // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
@@ -795,12 +813,14 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
     //   xxxxx-----
     //   xxxxx-----
     // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
-    for (int h = 0; h < 1; ++h) {
-        for (int s = 0; s < n_seqs; ++s) {
+    for (uint32_t h = 0; h < 1; ++h) {
+        for (uint32_t s = 0; s < n_seqs; ++s) {
             const llama_seq_id seq_id = ubatch->seq_id[s][0];

-            for (int j = 0; j < n_seq_tokens; ++j) {
-                const llama_pos p1 = ubatch->pos[s*n_seq_tokens + j];
+            for (uint32_t j = 0; j < n_seq_tokens; ++j) {
+                const uint32_t idx = s*n_seq_tokens + j;
+
+                const llama_pos p1 = ubatch->pos[idx];

                 for (uint32_t i = 0; i < n_kv; ++i) {
                     float f = 0.0f;
@@ -830,16 +850,16 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
                         f = -INFINITY;
                     }

-                    data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
+                    data[h*(n_kv*n_tokens) + idx*n_kv + i] = f;
                 }
             }
         }

         // mask padded tokens
         if (data) {
-            for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
-                for (uint32_t j = 0; j < n_kv; ++j) {
-                    data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
+            for (uint32_t j = n_tokens; j < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++j) {
+                for (uint32_t i = 0; i < n_kv; ++i) {
+                    data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
                 }
             }
         }
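With idx = s*n_seq_tokens + j, the new mask write data[h*(n_kv*n_tokens) + idx*n_kv + i] targets the same element as the old s/j-based expression, since idx*n_kv == s*(n_kv*n_seq_tokens) + j*n_kv. A small standalone check of that identity (example sizes only):

#include <cassert>

int main() {
    const long long n_kv = 7, n_seq_tokens = 4, n_seqs = 3; // example sizes

    for (long long s = 0; s < n_seqs; ++s) {
        for (long long j = 0; j < n_seq_tokens; ++j) {
            const long long idx = s*n_seq_tokens + j;
            // old offset == new offset for every (s, j)
            assert(s*(n_kv*n_seq_tokens) + j*n_kv == idx*n_kv);
        }
    }
}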
@@ -1490,9 +1510,11 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
         seq_rm(dest_seq_id, -1, -1);

         llama_sbatch sbatch;
-        llama_ubatch batch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
+        llama_ubatch ubatch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false);

-        batch.n_tokens = cell_count;
+        ubatch.n_tokens     = cell_count;
+        ubatch.n_seq_tokens = cell_count;
+        ubatch.n_seqs       = 1;

         for (uint32_t i = 0; i < cell_count; ++i) {
             llama_pos pos;
@@ -1512,27 +1534,27 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
                 io.read_to(&seq_id, sizeof(seq_id));
             }

-            batch.pos[i] = pos;
-            batch.n_seq_id[i] = n_seq_id;
-            batch.seq_id[i] = &dest_seq_id;
+            ubatch.pos[i]      = pos;
+            ubatch.n_seq_id[i] = n_seq_id;
+            ubatch.seq_id[i]   = &dest_seq_id;
         }

-        const auto head_cur = find_slot(batch);
+        const auto head_cur = find_slot(ubatch);
         if (head_cur < 0) {
             LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
             return false;
         }

-        apply_ubatch(head_cur, batch);
+        apply_ubatch(head_cur, ubatch);

         // keep the head at the old position because we will read the KV data into it in state_read_data()
         head = head_cur;

         // DEBUG CHECK: head_cur should be our first cell, head_cur + cell_count - 1 should be our last cell (verify seq_id and pos values)
         // Assume that this is one contiguous block of cells
         GGML_ASSERT(head_cur + cell_count <= cells.size());
-        GGML_ASSERT(cells.pos_get(head_cur) == batch.pos[0]);
-        GGML_ASSERT(cells.pos_get(head_cur + cell_count - 1) == batch.pos[cell_count - 1]);
+        GGML_ASSERT(cells.pos_get(head_cur) == ubatch.pos[0]);
+        GGML_ASSERT(cells.pos_get(head_cur + cell_count - 1) == ubatch.pos[cell_count - 1]);
         GGML_ASSERT(cells.seq_has(head_cur, dest_seq_id));
         GGML_ASSERT(cells.seq_has(head_cur + cell_count - 1, dest_seq_id));
     } else {
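state_read_meta now fills in n_seq_tokens and n_seqs in addition to n_tokens, because apply_ubatch iterates over sequences and per-sequence tokens rather than over a flat token index. For a single restored sequence the natural parameterization is one sequence owning all cell_count tokens, which keeps the new loops covering exactly the right range. A sketch of the invariant being relied on (cell_count is an example value):

#include <cassert>

int main() {
    const unsigned cell_count = 64;   // arbitrary example value

    // how the restored ubatch is parameterized in this commit
    const unsigned n_tokens     = cell_count;
    const unsigned n_seq_tokens = cell_count;
    const unsigned n_seqs       = 1;

    // apply_ubatch's s/j loops then visit exactly n_tokens slots
    assert(n_seqs * n_seq_tokens == n_tokens);
}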
