Skip to content

Commit

Permalink
refactor: estimate partial offloading
Browse files (browse the repository at this point in the history)
Signed-off-by: thxCode <[email protected]>
  • Loading branch information
thxCode committed Feb 11, 2025
1 parent a0e9c88 commit 6728851
Showing 1 changed file with 8 additions and 10 deletions.
18 changes: 8 additions & 10 deletions file_estimate__llamacpp.go
Original file line number | Diff line number | Diff line change
Expand Up @@ -568,6 +568,12 @@ func (gf *GGUFFile) estimateLLaMACppRunInModel(o *_GGUFRunEstimateOptions, a *GG
}
default:
loadAttnInc, offloadAttnInc := uint64(0), uint64(0)
{
rs := o.LMCCacheKeyType.RowSizeOf([]uint64{uint64(a.AttentionKeyLength), nKV, a.AttentionHeadCountKV})
loadAttnInc = rs // k-?
rs = o.LMCCacheValueType.RowSizeOf([]uint64{uint64(a.AttentionValueLength), nKV, a.AttentionHeadCountKV})
loadAttnInc += rs // v-?
}
if o.FlashAttention {
// https://github.com/ggerganov/llama.cpp/blob/172c8256840ffd882ab9992ecedbb587d9b21f15/llama.cpp#L7387.
offloadAttnInc = GGMLTypeF16.RowSizeOf([]uint64{nKV, nTokens})
Expand Down Expand Up @@ -597,32 +603,24 @@ func (gf *GGUFFile) estimateLLaMACppRunInModel(o *_GGUFRunEstimateOptions, a *GG
case strings.HasSuffix(l.Name, ".attn_q.weight"):
rs = GGMLTypeF32.RowSizeOf([]uint64{l.Dimensions[0], nTokens})
offloadAttnInc += rs * 2 // Qcur.
loadAttnInc = rs // Vcur.
rs = GGMLTypeF32.RowSizeOf([]uint64{nKV, nTokens, a.AttentionHeadCount})
offloadAttnInc += rs // kq.
if !zeroOffload && !fullOffload {
rs = o.LMCCacheKeyType.RowSizeOf([]uint64{uint64(a.AttentionKeyLength), nKV, a.AttentionHeadCountKV})
offloadAttnInc += rs * 2 // k-?, v-?.
offloadAttnInc += loadAttnInc
}
case strings.HasSuffix(l.Name, ".attn_qkv.weight"):
rs = GGMLTypeF32.RowSizeOf([]uint64{l.Dimensions[0], nTokens})
offloadAttnInc += rs * 2 // Qcur.
loadAttnInc = rs // Vcur.
rs = GGMLTypeF32.RowSizeOf([]uint64{nKV, nTokens, a.AttentionHeadCount})
offloadAttnInc += rs // kq.
rs = GGMLTypeF32.RowSizeOf([]uint64{a.EmbeddingLength, a.EmbeddingLength * 3})
offloadAttnInc += rs // wqkv.
if !zeroOffload && !fullOffload {
rs = o.LMCCacheKeyType.RowSizeOf([]uint64{uint64(a.AttentionKeyLength), nKV, a.AttentionHeadCountKV})
offloadAttnInc += rs * 2 // k-?, v-?.
offloadAttnInc += loadAttnInc
}
case strings.HasSuffix(l.Name, ".attn_q_b.weight"):
rs = GGMLTypeF32.RowSizeOf([]uint64{l.Dimensions[l.NDimensions-1], nTokens})
offloadAttnInc += rs * 2 // q-?
rs = o.LMCCacheKeyType.RowSizeOf([]uint64{uint64(a.AttentionKeyLength), nKV, a.AttentionHeadCountKV})
loadAttnInc = rs // k-?
rs = o.LMCCacheValueType.RowSizeOf([]uint64{uint64(a.AttentionValueLength), nKV, a.AttentionHeadCountKV})
loadAttnInc += rs // v-?
rs = GGMLTypeF32.RowSizeOf([]uint64{nKV, nTokens, a.AttentionHeadCount})
offloadAttnInc += rs // kq.
}
Expand Down

0 comments on commit 6728851

Please sign in to comment.