@@ -314,20 +314,24 @@ llama_memory_state_ptr llama_kv_cache_unified::init_batch(
                            bool logits_all) {
     GGML_UNUSED(embd_pooled);
 
-    auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);
+    do {
+        auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);
 
-    std::vector<llama_ubatch> ubatches;
-    while (sbatch.n_tokens > 0) {
-        ubatches.push_back(sbatch.split_simple(n_ubatch));
-    }
+        std::vector<llama_ubatch> ubatches;
+        while (sbatch.n_tokens > 0) {
+            ubatches.push_back(sbatch.split_simple(n_ubatch));
+        }
 
-    auto heads = prepare(ubatches);
-    if (heads.empty()) {
-        return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
-    }
+        auto heads = prepare(ubatches);
+        if (heads.empty()) {
+            break;
+        }
 
-    return std::make_unique<llama_kv_cache_unified_state>(
-            this, std::move(sbatch), std::move(heads), std::move(ubatches));
+        return std::make_unique<llama_kv_cache_unified_state>(
+                this, std::move(sbatch), std::move(heads), std::move(ubatches));
+    } while (false);
+
+    return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
 }
 
 llama_memory_state_ptr llama_kv_cache_unified::init_full() {
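The init_batch() rewrite above wraps batch splitting and slot preparation in a do { ... } while (false) block so that any failure can simply break to one shared return of the failed-prepare state instead of constructing that state on every error path. Below is a minimal standalone sketch of the same control-flow pattern; the result type, init_example() and the step flags are illustrative placeholders, not llama.cpp APIs.

#include <memory>
#include <string>

// Hypothetical result type standing in for llama_kv_cache_unified_state.
struct result {
    bool        ok;
    std::string info;
};

// Two boolean "step" flags stand in for the batch-split and prepare() steps.
std::unique_ptr<result> init_example(bool split_ok, bool prepare_ok) {
    do {
        if (!split_ok) {
            break; // bail out to the shared failure return below
        }

        if (!prepare_ok) {
            break;
        }

        // happy path: construct and return the successful state directly
        return std::make_unique<result>(result{true, "prepared"});
    } while (false);

    // single failure exit, analogous to returning LLAMA_MEMORY_STATUS_FAILED_PREPARE
    return std::make_unique<result>(result{false, "failed to prepare"});
}

Keeping the failure return as the only statement after the loop means the success path never needs a status flag.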
@@ -521,7 +525,6 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
     }
 
     if (debug > 0) {
-        LLAMA_LOG_CONT("\n");
         LLAMA_LOG_DEBUG("%s: n = %5d, used = %5d, head = %5d, size = %5d, n_swa = %5d\n", __func__, cells.used_max_p1(), cells.get_used(), head, get_size(), n_swa);
 
         if ((debug == 2 && n_swa > 0) || debug > 2) {
@@ -530,7 +533,13 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
                 if (cells.is_empty(i)) {
                     ss += '.';
                 } else {
-                    ss += std::to_string(cells.seq_get(i));
+                    assert(cells.seq_count(i) >= 1);
+
+                    if (cells.seq_count(i) == 1) {
+                        ss += std::to_string(cells.seq_get(i));
+                    } else {
+                        ss += 'M';
+                    }
                 }
                 if (i%256 == 255) {
                     ss += " *";
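The debug dump of the cache now distinguishes cells referenced by more than one sequence: '.' marks an empty cell, a single owner is printed as its sequence id, and shared cells are printed as 'M'. A small self-contained sketch of that encoding over a toy per-cell bookkeeping struct (cell and render_cells are illustrative stand-ins, not the cache's real types):

#include <string>
#include <vector>

// Illustrative stand-in for one KV cell: how many sequences use it and the id of the first one.
struct cell {
    int seq_count;
    int seq_id;
};

std::string render_cells(const std::vector<cell> & cells) {
    std::string ss;
    for (const cell & c : cells) {
        if (c.seq_count == 0) {
            ss += '.';                       // empty cell
        } else if (c.seq_count == 1) {
            ss += std::to_string(c.seq_id);  // owned by exactly one sequence
        } else {
            ss += 'M';                       // shared by multiple sequences
        }
    }
    return ss;
}

// render_cells({{0, 0}, {1, 3}, {2, 1}}) == ".3M"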
@@ -636,29 +645,39 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
 }
 
 void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch) {
+    if (debug > 0) {
+        LLAMA_LOG_DEBUG("%s: ubatch info:\n", __func__);
+        LLAMA_LOG_DEBUG("%s: n_tokens = %d, equal_seqs = %d\n", __func__, ubatch.n_tokens, ubatch.equal_seqs);
+        LLAMA_LOG_DEBUG("%s: n_seq_tokens = %d, n_seqs = %d\n", __func__, ubatch.n_seq_tokens, ubatch.n_seqs);
+    }
+
     // keep track of the max sequence position that we would overwrite with this ubatch
     // for non-SWA cache, this would be always empty
     llama_seq_id seq_pos_max_rm[LLAMA_MAX_PARALLEL_SEQUENCES];
     for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
         seq_pos_max_rm[s] = -1;
     }
 
-    for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
-        if (!cells.is_empty(head_cur + i)) {
-            assert(cells.seq_count(head_cur + i) == 1);
+    for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
+        for (uint32_t j = 0; j < ubatch.n_seq_tokens; ++j) {
+            const uint32_t idx = s*ubatch.n_seq_tokens + j;
 
-            const llama_seq_id seq_id = cells.seq_get(head_cur + i);
-            const llama_pos pos = cells.pos_get(head_cur + i);
+            if (!cells.is_empty(head_cur + idx)) {
+                assert(cells.seq_count(head_cur + idx) == 1);
 
-            seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
+                const llama_seq_id seq_id = cells.seq_get(head_cur + idx);
+                const llama_pos pos = cells.pos_get(head_cur + idx);
 
-            cells.rm(head_cur + i);
-        }
+                seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
+
+                cells.rm(head_cur + idx);
+            }
 
-        cells.pos_set(head_cur + i, ubatch.pos[i]);
+            cells.pos_set(head_cur + idx, ubatch.pos[idx]);
 
-        for (int32_t j = 0; j < ubatch.n_seq_id[i]; j++) {
-            cells.seq_add(head_cur + i, ubatch.seq_id[i][j]);
+            for (int32_t i = 0; i < ubatch.n_seq_id[s]; i++) {
+                cells.seq_add(head_cur + idx, ubatch.seq_id[s][i]);
+            }
         }
     }
 
@@ -677,7 +696,6 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
             seq_rm(s, cells.seq_pos_min(s), seq_pos_max_rm[s] + 1);
         }
     }
-
     // move the head at the end of the slot
     head = head_cur + ubatch.n_tokens;
 }
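apply_ubatch() now iterates the ubatch per sequence and per token within a sequence, addressing cells through the flattened offset idx = s*ubatch.n_seq_tokens + j; per-token arrays such as pos are indexed by idx, while per-sequence arrays such as n_seq_id and seq_id are indexed by s. A toy sketch of that indexing under an assumed row-major [n_seqs x n_seq_tokens] layout (all names and sizes here are illustrative):

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    const uint32_t n_seqs       = 2; // one row per sequence
    const uint32_t n_seq_tokens = 3; // tokens per sequence

    // per-token data laid out row-major, like the ubatch position array
    std::vector<int32_t> pos(n_seqs*n_seq_tokens);

    for (uint32_t s = 0; s < n_seqs; ++s) {
        for (uint32_t j = 0; j < n_seq_tokens; ++j) {
            const uint32_t idx = s*n_seq_tokens + j; // flattened offset used for cells and pos[]

            pos[idx] = 100*(int32_t) s + (int32_t) j; // e.g. sequence s starts at position 100*s

            printf("seq %u, token %u -> flat index %u, pos %d\n",
                   (unsigned) s, (unsigned) j, (unsigned) idx, (int) pos[idx]);
        }
    }

    return 0;
}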
@@ -774,14 +792,14 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
 }
 
 void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
-    const int64_t n_tokens     = ubatch->n_tokens;
-    const int64_t n_seq_tokens = ubatch->n_seq_tokens;
-    const int64_t n_seqs       = ubatch->n_seqs;
+    const uint32_t n_tokens     = ubatch->n_tokens;
+    const uint32_t n_seq_tokens = ubatch->n_seq_tokens;
+    const uint32_t n_seqs       = ubatch->n_seqs;
 
     GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
     float * data = (float *) dst->data;
 
-    const auto n_kv = dst->ne[0];
+    const int64_t n_kv = dst->ne[0];
 
     // Use only the previous KV cells of the correct sequence for each token of the ubatch.
     // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
@@ -795,12 +813,14 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
     //   xxxxx-----
     //   xxxxx-----
     // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
-    for (int h = 0; h < 1; ++h) {
-        for (int s = 0; s < n_seqs; ++s) {
+    for (uint32_t h = 0; h < 1; ++h) {
+        for (uint32_t s = 0; s < n_seqs; ++s) {
             const llama_seq_id seq_id = ubatch->seq_id[s][0];
 
-            for (int j = 0; j < n_seq_tokens; ++j) {
-                const llama_pos p1 = ubatch->pos[s*n_seq_tokens + j];
+            for (uint32_t j = 0; j < n_seq_tokens; ++j) {
+                const uint32_t idx = s*n_seq_tokens + j;
+
+                const llama_pos p1 = ubatch->pos[idx];
 
                 for (uint32_t i = 0; i < n_kv; ++i) {
                     float f = 0.0f;
@@ -830,16 +850,16 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
                         f = -INFINITY;
                     }
 
-                    data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
+                    data[h*(n_kv*n_tokens) + idx*n_kv + i] = f;
                 }
             }
         }
 
         // mask padded tokens
         if (data) {
-            for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
-                for (uint32_t j = 0; j < n_kv; ++j) {
-                    data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
+            for (uint32_t j = n_tokens; j < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++j) {
+                for (uint32_t i = 0; i < n_kv; ++i) {
+                    data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
                 }
             }
         }
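With the flattened token index, each mask entry is addressed as data[h*(n_kv*n_tokens) + idx*n_kv + i]: one row of n_kv values per ubatch token, with the rows between n_tokens and GGML_PAD(n_tokens, GGML_KQ_MASK_PAD) fully masked. A freestanding sketch of that addressing follows; a simple causal rule replaces the cache's sequence/position checks, and pad_to() plus the sizes are illustrative stand-ins for GGML_PAD/GGML_KQ_MASK_PAD and the real tensor shape.

#include <cmath>
#include <cstdint>
#include <vector>

// Local stand-in for GGML_PAD(x, n): round x up to a multiple of n.
static uint32_t pad_to(uint32_t x, uint32_t n) {
    return ((x + n - 1)/n)*n;
}

int main() {
    const uint32_t n_kv     = 8;  // KV cells visible to this ubatch
    const uint32_t n_tokens = 5;  // tokens in the ubatch
    const uint32_t n_rows   = pad_to(n_tokens, 32); // 32 stands in for GGML_KQ_MASK_PAD

    std::vector<float> data(n_kv*n_rows, 0.0f);

    const uint32_t h = 0; // single mask "head", as in set_input_kq_mask

    // one row per token: mask out the cells this token must not attend to
    for (uint32_t idx = 0; idx < n_tokens; ++idx) {
        for (uint32_t i = 0; i < n_kv; ++i) {
            const bool visible = i <= idx; // toy causal rule instead of seq/pos checks
            data[h*(n_kv*n_tokens) + idx*n_kv + i] = visible ? 0.0f : -INFINITY;
        }
    }

    // padding rows past n_tokens are fully masked
    for (uint32_t j = n_tokens; j < n_rows; ++j) {
        for (uint32_t i = 0; i < n_kv; ++i) {
            data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
        }
    }

    return 0;
}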
@@ -1490,9 +1510,11 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
         seq_rm(dest_seq_id, -1, -1);
 
         llama_sbatch sbatch;
-        llama_ubatch batch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
+        llama_ubatch ubatch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
 
-        batch.n_tokens = cell_count;
+        ubatch.n_tokens     = cell_count;
+        ubatch.n_seq_tokens = cell_count;
+        ubatch.n_seqs       = 1;
 
         for (uint32_t i = 0; i < cell_count; ++i) {
             llama_pos pos;
@@ -1512,27 +1534,27 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
                 io.read_to(&seq_id, sizeof(seq_id));
             }
 
-            batch.pos[i]      = pos;
-            batch.n_seq_id[i] = n_seq_id;
-            batch.seq_id[i]   = &dest_seq_id;
+            ubatch.pos[i]      = pos;
+            ubatch.n_seq_id[i] = n_seq_id;
+            ubatch.seq_id[i]   = &dest_seq_id;
         }
 
-        const auto head_cur = find_slot(batch);
+        const auto head_cur = find_slot(ubatch);
         if (head_cur < 0) {
             LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
             return false;
         }
 
-        apply_ubatch(head_cur, batch);
+        apply_ubatch(head_cur, ubatch);
 
         // keep the head at the old position because we will read the KV data into it in state_read_data()
         head = head_cur;
 
         // DEBUG CHECK: head_cur should be our first cell, head_cur + cell_count - 1 should be our last cell (verify seq_id and pos values)
         // Assume that this is one contiguous block of cells
         GGML_ASSERT(head_cur + cell_count <= cells.size());
-        GGML_ASSERT(cells.pos_get(head_cur) == batch.pos[0]);
-        GGML_ASSERT(cells.pos_get(head_cur + cell_count - 1) == batch.pos[cell_count - 1]);
+        GGML_ASSERT(cells.pos_get(head_cur) == ubatch.pos[0]);
+        GGML_ASSERT(cells.pos_get(head_cur + cell_count - 1) == ubatch.pos[cell_count - 1]);
         GGML_ASSERT(cells.seq_has(head_cur, dest_seq_id));
         GGML_ASSERT(cells.seq_has(head_cur + cell_count - 1, dest_seq_id));
     } else {
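In the state-restore path, the reserved ubatch now also sets n_seq_tokens and n_seqs, because apply_ubatch() indexes per sequence: a restored block of cell_count cells is described as a single sequence of cell_count tokens. A minimal sketch of that shape, using a stripped-down stand-in for the llama_ubatch fields touched here (mini_ubatch and make_restore_ubatch are illustrative, not real APIs):

#include <cstdint>
#include <vector>

// Stripped-down stand-in for the llama_ubatch fields set during restore.
struct mini_ubatch {
    uint32_t n_tokens;
    uint32_t n_seq_tokens;
    uint32_t n_seqs;
    std::vector<int32_t> pos; // one position per restored cell
};

// Restoring cell_count cells for one destination sequence:
// the whole range is described as a single sequence of cell_count tokens.
mini_ubatch make_restore_ubatch(uint32_t cell_count) {
    mini_ubatch ub;
    ub.n_tokens     = cell_count;
    ub.n_seq_tokens = cell_count;
    ub.n_seqs       = 1;
    ub.pos.resize(cell_count);
    return ub;
}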