From 672885119e40a4be05e03fdef8c9a363a883edef Mon Sep 17 00:00:00 2001
From: thxCode
Date: Tue, 11 Feb 2025 22:12:58 +0800
Subject: [PATCH] refactor: estimate partial offloading

Signed-off-by: thxCode
---
 file_estimate__llamacpp.go | 18 ++++++++----------
 1 file changed, 8 insertions(+), 10 deletions(-)

diff --git a/file_estimate__llamacpp.go b/file_estimate__llamacpp.go
index 2e5c6d3..3bd26ad 100644
--- a/file_estimate__llamacpp.go
+++ b/file_estimate__llamacpp.go
@@ -568,6 +568,12 @@ func (gf *GGUFFile) estimateLLaMACppRunInModel(o *_GGUFRunEstimateOptions, a *GG
 			}
 		default:
 			loadAttnInc, offloadAttnInc := uint64(0), uint64(0)
+			{
+				rs := o.LMCCacheKeyType.RowSizeOf([]uint64{uint64(a.AttentionKeyLength), nKV, a.AttentionHeadCountKV})
+				loadAttnInc = rs // k-?
+				rs = o.LMCCacheValueType.RowSizeOf([]uint64{uint64(a.AttentionValueLength), nKV, a.AttentionHeadCountKV})
+				loadAttnInc += rs // v-?
+			}
 			if o.FlashAttention {
 				// https://github.com/ggerganov/llama.cpp/blob/172c8256840ffd882ab9992ecedbb587d9b21f15/llama.cpp#L7387.
 				offloadAttnInc = GGMLTypeF16.RowSizeOf([]uint64{nKV, nTokens})
@@ -597,32 +603,24 @@ func (gf *GGUFFile) estimateLLaMACppRunInModel(o *_GGUFRunEstimateOptions, a *GG
 				case strings.HasSuffix(l.Name, ".attn_q.weight"):
 					rs = GGMLTypeF32.RowSizeOf([]uint64{l.Dimensions[0], nTokens})
 					offloadAttnInc += rs * 2 // Qcur.
-					loadAttnInc = rs // Vcur.
 					rs = GGMLTypeF32.RowSizeOf([]uint64{nKV, nTokens, a.AttentionHeadCount})
 					offloadAttnInc += rs // kq.
 					if !zeroOffload && !fullOffload {
-						rs = o.LMCCacheKeyType.RowSizeOf([]uint64{uint64(a.AttentionKeyLength), nKV, a.AttentionHeadCountKV})
-						offloadAttnInc += rs * 2 // k-?, v-?.
+						offloadAttnInc += loadAttnInc
 					}
 				case strings.HasSuffix(l.Name, ".attn_qkv.weight"):
 					rs = GGMLTypeF32.RowSizeOf([]uint64{l.Dimensions[0], nTokens})
 					offloadAttnInc += rs * 2 // Qcur.
-					loadAttnInc = rs // Vcur.
 					rs = GGMLTypeF32.RowSizeOf([]uint64{nKV, nTokens, a.AttentionHeadCount})
 					offloadAttnInc += rs // kq.
 					rs = GGMLTypeF32.RowSizeOf([]uint64{a.EmbeddingLength, a.EmbeddingLength * 3})
 					offloadAttnInc += rs // wqkv.
 					if !zeroOffload && !fullOffload {
-						rs = o.LMCCacheKeyType.RowSizeOf([]uint64{uint64(a.AttentionKeyLength), nKV, a.AttentionHeadCountKV})
-						offloadAttnInc += rs * 2 // k-?, v-?.
+						offloadAttnInc += loadAttnInc
 					}
 				case strings.HasSuffix(l.Name, ".attn_q_b.weight"):
 					rs = GGMLTypeF32.RowSizeOf([]uint64{l.Dimensions[l.NDimensions-1], nTokens})
 					offloadAttnInc += rs * 2 // q-?
-					rs = o.LMCCacheKeyType.RowSizeOf([]uint64{uint64(a.AttentionKeyLength), nKV, a.AttentionHeadCountKV})
-					loadAttnInc = rs // k-?
-					rs = o.LMCCacheValueType.RowSizeOf([]uint64{uint64(a.AttentionValueLength), nKV, a.AttentionHeadCountKV})
-					loadAttnInc += rs // v-?
 					rs = GGMLTypeF32.RowSizeOf([]uint64{nKV, nTokens, a.AttentionHeadCount})
 					offloadAttnInc += rs // kq.
 				}
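
The refactor hoists the k/v cache row-size computation ahead of the per-tensor switch so every branch reuses the shared loadAttnInc instead of recomputing it. Below is a minimal standalone sketch of that pattern, assuming an F16 KV cache and illustrative shapes; rowSizeBytes and all constants are hypothetical stand-ins for gguf-parser-go's GGMLType.RowSizeOf and the real GGUF metadata, not the library's API.

package main

import "fmt"

// rowSizeBytes is a simplified stand-in for GGMLType.RowSizeOf: it multiplies
// the dimensions by a fixed bytes-per-element and ignores block-quantized types.
func rowSizeBytes(bytesPerElem uint64, dims ...uint64) uint64 {
	n := bytesPerElem
	for _, d := range dims {
		n *= d
	}
	return n
}

func main() {
	// Illustrative shapes only; the real values come from the GGUF metadata
	// (a.AttentionKeyLength, a.AttentionValueLength, nKV, a.AttentionHeadCountKV).
	const (
		keyLen   uint64 = 128
		valLen   uint64 = 128
		nKV      uint64 = 4096
		headsKV  uint64 = 8
		f16Bytes uint64 = 2 // assuming an F16 KV cache
	)

	// Computed once, as in the hoisted block of the patch: the k/v cache
	// row sizes are identical for every branch of the per-tensor switch.
	loadAttnInc := rowSizeBytes(f16Bytes, keyLen, nKV, headsKV) // k-?
	loadAttnInc += rowSizeBytes(f16Bytes, valLen, nKV, headsKV) // v-?

	// Each branch that previously recomputed the two row sizes now only
	// adds the shared value when the model is partially offloaded.
	zeroOffload, fullOffload := false, false
	offloadAttnInc := uint64(0)
	if !zeroOffload && !fullOffload {
		offloadAttnInc += loadAttnInc
	}

	fmt.Printf("loadAttnInc=%d bytes, offloadAttnInc=%d bytes\n", loadAttnInc, offloadAttnInc)
}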