Commit 2dba414
Merge branch 'main' into dev
yirongjie authored Aug 6, 2024
2 parents 08453c9 + a1b451e commit 2dba414
Showing 4 changed files with 153 additions and 102 deletions.
6 changes: 3 additions & 3 deletions src/backends/cpu/compute/GEMM_AArch64.cpp
@@ -2123,7 +2123,7 @@ void _mllm_gemm_q4_0_4x4_q8_0_bias(int n, float *__restrict s, size_t bs, const
     for (int x = 0; x < nc / ncols_interleaved; x++) {
         const block_q4_0x4 *b_ptr = (const block_q4_0x4 *)vx + (x * nb);
         for (int m = 0; m < 4; m++) {
-            for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = bias_ptr[(y * 4 + m) * bs + x * ncols_interleaved + j];
+            for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = bias_ptr[x * ncols_interleaved + j];
         }
         for (int l = 0; l < nb; l++) {
             for (int k = 0; k < (qk / (2 * blocklen)); k++) {
@@ -3047,7 +3047,7 @@ void _mllm_gemm_q4_0_4x8_q8_0_bias(int n, float *__restrict s, size_t bs, const
     for (int x = 0; x < nc / ncols_interleaved; x++) {
         const block_q4_0x4 *b_ptr = (const block_q4_0x4 *)vx + (x * nb);
         for (int m = 0; m < 4; m++) {
-            for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = bias_ptr[(y * 4 + m) * bs + x * ncols_interleaved + j];
+            for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = bias_ptr[x * ncols_interleaved + j];
         }
         for (int l = 0; l < nb; l++) {
             for (int k = 0; k < (qk / (2 * blocklen)); k++) {
@@ -4009,7 +4009,7 @@ void _mllm_gemm_q4_0_8x8_q8_0_bias(int n, float *__restrict s, size_t bs, const
     for (int x = 0; x < nc / ncols_interleaved; x++) {
         const block_q4_0x8 *b_ptr = (const block_q4_0x8 *)vx + (x * nb);
         for (int m = 0; m < 4; m++) {
-            for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = bias_ptr[(y * 4 + m) * bs + x * ncols_interleaved + j];
+            for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = bias_ptr[x * ncols_interleaved + j];
         }
         for (int l = 0; l < nb; l++) {
             for (int k = 0; k < (qk / (2 * blocklen)); k++) {
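
All three AArch64 kernels receive the same one-line fix: the accumulator tile is now seeded with bias_ptr[x * ncols_interleaved + j], which depends only on the output column, where the old index (y * 4 + m) * bs + x * ncols_interleaved + j also varied with the output row. A linear-layer bias is one value per output column, broadcast across every output row, so the column-only index is the correct one. Below is a minimal standalone sketch of the seeding pattern (illustrative names and sizes, not the kernel itself):

#include <cstdio>

int main() {
    const int ncols_interleaved = 4;                 // columns per interleaved tile
    const int nc = 8;                                // total output columns
    float bias[nc] = {1, 2, 3, 4, 5, 6, 7, 8};       // one value per output column
    float sumf[4][ncols_interleaved];                // 4-row accumulator tile

    for (int x = 0; x < nc / ncols_interleaved; x++) {  // one tile of columns at a time
        for (int m = 0; m < 4; m++)                     // 4 output rows in the tile
            for (int j = 0; j < ncols_interleaved; j++)
                // same bias for every row m: the index depends only on the column
                sumf[m][j] = bias[x * ncols_interleaved + j];
        // ... the quantized dot products would be accumulated into sumf here ...
        for (int m = 0; m < 4; m++)
            for (int j = 0; j < ncols_interleaved; j++)
                printf("tile %d row %d col %d seeded with %.0f\n", x, m, j, sumf[m][j]);
    }
    return 0;
}

Seeding the accumulators with the bias instead of zero folds the bias add into the GEMM itself, at no extra cost per output element.
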
22 changes: 7 additions & 15 deletions src/backends/cpu/compute/Matmul.cpp
@@ -191,7 +191,7 @@ ErrorCode mat_mul(Tensor *src0, Tensor *src1, Tensor *dst, bool support_bias, Te
     auto src0_type_size = type_size(src0->dtype());
     auto src0_blck_size = blck_size(src0->dtype());
 #ifdef LLAMAFILE_SGEMM
-    if (check_llamafile_sgemm(N, M, K / blck_size(src0->dtype()), src1->dtype(), src0->dtype(), dst->dtype()) && !support_bias) {
+    if (check_llamafile_sgemm(N, M, K / blck_size(src0->dtype()), src1->dtype(), src0->dtype(), dst->dtype())) {
         const int ld_src1 = src1->sequence_skip_dim();
         const int ld_src0 = src0->sequence_skip_dim();
         const int ld_dst = dst->sequence_skip_dim();
@@ -210,7 +210,9 @@ ErrorCode mat_mul(Tensor *src0, Tensor *src1, Tensor *dst, bool support_bias, Te
                     id, thread_count,
                     src1->dtype(),
                     src0->dtype(),
-                    dst->dtype());
+                    dst->dtype(),
+                    /*bias=*/support_bias ? bias->hostPtr<float>() : nullptr,
+                    /*BiasType=*/support_bias ? bias->dtype() : DataType::MLLM_TYPE_F32);
             }
         }
     }
@@ -276,19 +278,9 @@ ErrorCode mat_mul(Tensor *src0, Tensor *src1, Tensor *dst, bool support_bias, Te
                     id, thread_count,
                     src1->dtype(),
                     src0->dtype(),
-                    dst->dtype());
-            }
-        }
-    }
-    if (support_bias) {
-#pragma omp parallel for collapse(4) num_threads(thread_count)
-        for (int b = 0; b < dst->batch(); b++) {
-            for (int h = 0; h < dst->head(); h++) {
-                for (int m = 0; m < M; m++) {
-                    for (int n = 0; n < N; n++) {
-                        *dst->ptrAt<float>(b, h, m, n) += bias->dataAt<float>(0, 0, 0, n);
-                    }
-                }
+                    dst->dtype(),
+                    /*bias=*/support_bias ? bias->hostPtr<float>() + bias->offset(b, h, 0, 0) : nullptr,
+                    /*BiasType=*/support_bias ? bias->dtype() : DataType::MLLM_TYPE_F32);
             }
         }
     }
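
Both call sites follow the same pattern: instead of gating the LLAMAFILE_SGEMM fast path on !support_bias and applying the bias afterwards in a separate collapse(4) OpenMP pass over the finished output, the bias pointer and its dtype are now threaded into llamafile_sgemm (nullptr and a default of MLLM_TYPE_F32 when there is no bias), so the add can happen in the kernel's epilogue while each output tile is still in cache. A simplified sketch of a GEMM with this kind of nullable fused-bias parameter (signatures invented for illustration; the real llamafile_sgemm takes many more arguments):

#include <cstdio>
#include <vector>

// C[m][n] = sum_k A[m][k] * B[k][n], plus bias[n] when bias is non-null.
void gemm_bias(int M, int N, int K,
               const float *A, const float *B, float *C,
               const float *bias /* nullptr => no bias */) {
    for (int m = 0; m < M; m++) {
        for (int n = 0; n < N; n++) {
            float acc = bias ? bias[n] : 0.0f;   // fused epilogue: seed with bias
            for (int k = 0; k < K; k++)
                acc += A[m * K + k] * B[k * N + n];
            C[m * N + n] = acc;                  // one pass over C, not two
        }
    }
}

int main() {
    std::vector<float> A = {1, 2, 3, 4};   // 2x2 row-major
    std::vector<float> B = {1, 0, 0, 1};   // 2x2 identity
    std::vector<float> bias = {10, 20};    // one value per output column
    std::vector<float> C(4);
    gemm_bias(2, 2, 2, A.data(), B.data(), C.data(), bias.data());
    for (float v : C) printf("%.0f ", v);  // prints: 11 22 13 24
    printf("\n");
    return 0;
}

Fusing the bias this way also removes a full read-modify-write pass over the output tensor, which is exactly what the deleted OpenMP loop was doing.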