Commit 7bdb22d

feat: sync llama.cpp

jhen0409 committed Feb 23, 2024
1 parent 40bea05; commit 7bdb22d
Showing 8 changed files with 142 additions and 55 deletions.
cpp/ggml-impl.h: 27 changes (20 additions & 7 deletions)
@@ -53,11 +53,23 @@ extern "C" {
 //
 #include <arm_neon.h>
 
-#define LM_GGML_COMPUTE_FP16_TO_FP32(x) ((float) (x))
-#define LM_GGML_COMPUTE_FP32_TO_FP16(x) (x)
+#define LM_GGML_COMPUTE_FP16_TO_FP32(x) lm_ggml_compute_fp16_to_fp32(x)
+#define LM_GGML_COMPUTE_FP32_TO_FP16(x) lm_ggml_compute_fp32_to_fp16(x)
+
+#define LM_GGML_FP16_TO_FP32(x) lm_ggml_compute_fp16_to_fp32(x)
+
+static inline float lm_ggml_compute_fp16_to_fp32(lm_ggml_fp16_t h) {
+    __fp16 tmp;
+    memcpy(&tmp, &h, sizeof(lm_ggml_fp16_t));
+    return (float)tmp;
+}
 
-#define LM_GGML_FP16_TO_FP32(x) ((float) (x))
-#define LM_GGML_FP32_TO_FP16(x) (x)
+static inline lm_ggml_fp16_t lm_ggml_compute_fp32_to_fp16(float f) {
+    lm_ggml_fp16_t res;
+    __fp16 tmp = f;
+    memcpy(&res, &tmp, sizeof(lm_ggml_fp16_t));
+    return res;
+}
 
 #else

@@ -214,17 +226,18 @@ extern float lm_ggml_table_f32_f16[1 << 16];
 // On ARM NEON, it's quicker to directly convert x -> x instead of calling into lm_ggml_lookup_fp16_to_fp32,
 // so we define LM_GGML_FP16_TO_FP32 and LM_GGML_FP32_TO_FP16 elsewhere for NEON.
 // This is also true for POWER9.
-#if !defined(LM_GGML_FP16_TO_FP32) || !defined(LM_GGML_FP32_TO_FP16)
-
+#if !defined(LM_GGML_FP16_TO_FP32)
 inline static float lm_ggml_lookup_fp16_to_fp32(lm_ggml_fp16_t f) {
     uint16_t s;
     memcpy(&s, &f, sizeof(uint16_t));
     return lm_ggml_table_f32_f16[s];
 }
 
 #define LM_GGML_FP16_TO_FP32(x) lm_ggml_lookup_fp16_to_fp32(x)
-#define LM_GGML_FP32_TO_FP16(x) LM_GGML_COMPUTE_FP32_TO_FP16(x)
+#endif
 
+#if !defined(LM_GGML_FP32_TO_FP16)
+#define LM_GGML_FP32_TO_FP16(x) LM_GGML_COMPUTE_FP32_TO_FP16(x)
 #endif
 
 #define LM_GGML_HASHTABLE_FULL ((size_t)-1)
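Note on the ggml-impl.h change: `lm_ggml_fp16_t` is now plain `uint16_t` storage on every platform (see the cpp/ggml.h diff below), so the old ARM NEON macros, which relied on the type being a native `__fp16`, would silently convert integer values instead of decoding half-precision bits. The replacement helpers round-trip through `__fp16` with `memcpy`. A minimal standalone sketch of the same pattern, assuming an AArch64 compiler where `__fp16` is available (the names below are illustrative stand-ins, not part of this commit):

    // Sketch: fp16 <-> fp32 round-trip via __fp16, mirroring the new helpers.
    #include <stdint.h>
    #include <stdio.h>
    #include <string.h>

    typedef uint16_t fp16_bits;                // stand-in for lm_ggml_fp16_t

    static inline float fp16_to_fp32(fp16_bits h) {
        __fp16 tmp;
        memcpy(&tmp, &h, sizeof(h));           // reinterpret the stored 16 bits
        return (float)tmp;                     // hardware widening conversion
    }

    static inline fp16_bits fp32_to_fp16(float f) {
        fp16_bits res;
        __fp16 tmp = (__fp16)f;                // hardware narrowing conversion
        memcpy(&res, &tmp, sizeof(res));
        return res;
    }

    int main(void) {
        fp16_bits h = fp32_to_fp16(1.5f);
        // prints bits=0x3e00 value=1.500000
        printf("bits=0x%04x value=%f\n", (unsigned)h, fp16_to_fp32(h));
        return 0;
    }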
cpp/ggml-quants.c: 65 changes (45 additions & 20 deletions)
@@ -438,6 +438,30 @@ inline static lm_ggml_int8x16x4_t lm_ggml_vld1q_s8_x4(const int8_t * ptr) {
     return res;
 }
 
+// NOTE: not tested
+inline static int8x16_t lm_ggml_vqtbl1q_s8(int8x16_t a, uint8x16_t b) {
+    int8x16_t res;
+
+    res[ 0] = a[b[ 0]];
+    res[ 1] = a[b[ 1]];
+    res[ 2] = a[b[ 2]];
+    res[ 3] = a[b[ 3]];
+    res[ 4] = a[b[ 4]];
+    res[ 5] = a[b[ 5]];
+    res[ 6] = a[b[ 6]];
+    res[ 7] = a[b[ 7]];
+    res[ 8] = a[b[ 8]];
+    res[ 9] = a[b[ 9]];
+    res[10] = a[b[10]];
+    res[11] = a[b[11]];
+    res[12] = a[b[12]];
+    res[13] = a[b[13]];
+    res[14] = a[b[14]];
+    res[15] = a[b[15]];
+
+    return res;
+}
+
 #else
 
 #define lm_ggml_int16x8x2_t int16x8x2_t
@@ -451,6 +475,7 @@ inline static lm_ggml_int8x16x4_t lm_ggml_vld1q_s8_x4(const int8_t * ptr) {
 #define lm_ggml_vld1q_u8_x4 vld1q_u8_x4
 #define lm_ggml_vld1q_s8_x2 vld1q_s8_x2
 #define lm_ggml_vld1q_s8_x4 vld1q_s8_x4
+#define lm_ggml_vqtbl1q_s8 vqtbl1q_s8
 
 #endif
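Note on `lm_ggml_vqtbl1q_s8`: on targets without the `vqtbl1q_s8` intrinsic (the fallback branch above), the polyfill emulates the NEON byte table lookup one lane at a time. One caveat: the hardware instruction writes 0 for any index greater than 15, while the fallback indexes `a[b[i]]` unchecked. That is safe for its caller in this diff, `lm_ggml_vec_dot_iq4_nl_q8_0`, which only passes 4-bit nibble indices. A scalar model of the full hardware contract, as a hypothetical reference (not part of this commit):

    // Sketch: scalar model of the NEON vqtbl1q_s8 semantics. Hardware zeroes
    // out-of-range lanes; the polyfill above omits that check because its
    // inputs are 4-bit nibbles.
    #include <stdint.h>

    static void vqtbl1q_s8_model(const int8_t table[16], const uint8_t idx[16],
                                 int8_t out[16]) {
        for (int i = 0; i < 16; ++i) {
            out[i] = (idx[i] < 16) ? table[idx[i]] : 0;
        }
    }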

@@ -5629,8 +5654,8 @@ void lm_ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void
 
     for (int i = 0; i < nb; ++i) {
 
-        const float d = y[i].d * (float)x[i].d;
-        const float dmin = -y[i].d * (float)x[i].dmin;
+        const float d = y[i].d * LM_GGML_FP16_TO_FP32(x[i].d);
+        const float dmin = -y[i].d * LM_GGML_FP16_TO_FP32(x[i].dmin);
 
         const uint8_t * restrict q2 = x[i].qs;
         const int8_t * restrict q8 = y[i].qs;
@@ -5779,8 +5804,8 @@ void lm_ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void
 
     for (int i = 0; i < nb; ++i) {
 
-        const float d = y[i].d * (float)x[i].d;
-        const float dmin = -y[i].d * (float)x[i].dmin;
+        const float d = y[i].d * LM_GGML_FP16_TO_FP32(x[i].d);
+        const float dmin = -y[i].d * LM_GGML_FP16_TO_FP32(x[i].dmin);
 
         const uint8_t * restrict q2 = x[i].qs;
         const int8_t * restrict q8 = y[i].qs;
@@ -6433,7 +6458,7 @@ void lm_ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void
 
         int32_t isum = -4*(scales[0] * y[i].bsums[0] + scales[2] * y[i].bsums[1] + scales[1] * y[i].bsums[2] + scales[3] * y[i].bsums[3]);
 
-        const float d = y[i].d * (float)x[i].d;
+        const float d = y[i].d * LM_GGML_FP16_TO_FP32(x[i].d);
 
         const uint8x16_t htmp = vcombine_u8(hbits, vshr_n_u8(hbits, 1));
         q3h.val[0] = vandq_u8(mh, vshlq_n_u8(htmp, 2));
@@ -6635,7 +6660,7 @@ void lm_ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void
 
         int32_t isum = -4*(scales[0] * y[i].bsums[0] + scales[2] * y[i].bsums[1] + scales[1] * y[i].bsums[2] + scales[3] * y[i].bsums[3]);
 
-        const float d = y[i].d * (float)x[i].d;
+        const float d = y[i].d * LM_GGML_FP16_TO_FP32(x[i].d);
 
         vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
 
@@ -7138,9 +7163,9 @@ void lm_ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void
         aux16[1] = (a[0] >> 4) & 0x0f0f;
 
         const int32_t summi = scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]);
-        sum_mins += y[i].d * (float)x[i].d[1] * summi;
+        sum_mins += y[i].d * LM_GGML_FP16_TO_FP32(x[i].d[1]) * summi;
 
-        const float d = y[i].d * (float)x[i].d[0];
+        const float d = y[i].d * LM_GGML_FP16_TO_FP32(x[i].d[0]);
 
         const lm_ggml_uint8x16x2_t q4bits = lm_ggml_vld1q_u8_x2(q4);
 
@@ -7798,7 +7823,7 @@ void lm_ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void
 
     for (int i = 0; i < nb; ++i) {
 
-        const float d = y[i].d * (float)x[i].d;
+        const float d = y[i].d * LM_GGML_FP16_TO_FP32(x[i].d);
         const int8_t * sc = x[i].scales;
 
         const uint8_t * restrict q5 = x[i].qs;
@@ -7940,7 +7965,7 @@ void lm_ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void
 
     for (int i = 0; i < nb; ++i) {
 
-        const float d = y[i].d * (float)x[i].d;
+        const float d = y[i].d * LM_GGML_FP16_TO_FP32(x[i].d);
         const int8_t * sc = x[i].scales;
 
         const uint8_t * restrict q5 = x[i].qs;
@@ -8508,7 +8533,7 @@ void lm_ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void
 
     for (int i = 0; i < nb; ++i) {
 
-        const float d_all = (float)x[i].d;
+        const float d_all = LM_GGML_FP16_TO_FP32(x[i].d);
 
         const uint8_t * restrict q6 = x[i].ql;
         const uint8_t * restrict qh = x[i].qh;
@@ -8679,7 +8704,7 @@ void lm_ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void
 
     for (int i = 0; i < nb; ++i) {
 
-        const float d_all = (float)x[i].d;
+        const float d_all = LM_GGML_FP16_TO_FP32(x[i].d);
 
         const uint8_t * restrict q6 = x[i].ql;
         const uint8_t * restrict qh = x[i].qh;
@@ -9333,7 +9358,7 @@ void lm_ggml_vec_dot_iq1_s_q8_K (int n, float * LM_GGML_RESTRICT s, size_t bs,
     uint16_t gindex[8];
     uint16x8x2_t vindex;
     int8x16x4_t q1b;
-    int8x16x4_t q8b;
+    lm_ggml_int8x16x4_t q8b;
     uint16x8x4_t scales;
     int32x4x2_t sumi;
     int32x4x2_t dotq;
@@ -9498,24 +9523,24 @@ void lm_ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const voi
     float sumf = 0;
 
     for (int ib = 0; ib < nb; ib += 2) {
-
         q4bits.val[0] = vld1q_u8(x[ib+0].qs);
         q4bits.val[1] = vld1q_u8(x[ib+1].qs);
         q8b.val[0] = vld1q_s8(y[ib+0].qs);
         q8b.val[1] = vld1q_s8(y[ib+0].qs + 16);
         q8b.val[2] = vld1q_s8(y[ib+1].qs);
         q8b.val[3] = vld1q_s8(y[ib+1].qs + 16);
 
-        q4b.val[0] = vqtbl1q_s8(values, vandq_u8(q4bits.val[0], m4b));
-        q4b.val[1] = vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[0], 4));
-        q4b.val[2] = vqtbl1q_s8(values, vandq_u8(q4bits.val[1], m4b));
-        q4b.val[3] = vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[1], 4));
+        q4b.val[0] = lm_ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[0], m4b));
+        q4b.val[1] = lm_ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[0], 4));
+        q4b.val[2] = lm_ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[1], m4b));
+        q4b.val[3] = lm_ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[1], 4));
 
         prod_1 = lm_ggml_vdotq_s32(lm_ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[0], q8b.val[0]), q4b.val[1], q8b.val[1]);
         prod_2 = lm_ggml_vdotq_s32(lm_ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[2], q8b.val[2]), q4b.val[3], q8b.val[3]);
 
-        sumf += (float)x[ib+0].d * (float)y[ib+0].d * vaddvq_s32(prod_1) + (float)x[ib+1].d * (float)y[ib+1].d * vaddvq_s32(prod_2);
-
+        sumf +=
+            LM_GGML_FP16_TO_FP32(x[ib+0].d) * LM_GGML_FP16_TO_FP32(y[ib+0].d) * vaddvq_s32(prod_1) +
+            LM_GGML_FP16_TO_FP32(x[ib+1].d) * LM_GGML_FP16_TO_FP32(y[ib+1].d) * vaddvq_s32(prod_2);
     }
 
     *s = sumf;
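Note on the `(float)` to `LM_GGML_FP16_TO_FP32` substitutions above: the block scales `x[i].d`/`x[i].dmin` are stored as `lm_ggml_fp16_t`. While that type was `__fp16` on ARM, a plain cast decoded the value correctly; with the type unified to `uint16_t`, the same cast would convert the integer bit pattern instead. For example, fp16 1.0 is stored as 0x3C00, which is 15360 as an integer. A small sketch of the failure mode, assuming an AArch64 compiler for the `__fp16` decode:

    // Sketch: why (float)x[i].d breaks once the storage type is uint16_t.
    #include <stdint.h>
    #include <stdio.h>
    #include <string.h>

    int main(void) {
        uint16_t d_bits = 0x3C00;        // IEEE-754 binary16 bit pattern for 1.0
        float wrong = (float)d_bits;     // 15360.0f: integer-to-float conversion
        __fp16 tmp;
        memcpy(&tmp, &d_bits, sizeof(tmp));
        float right = (float)tmp;        // 1.0f: what LM_GGML_FP16_TO_FP32 computes
        printf("wrong=%f right=%f\n", wrong, right);
        return 0;
    }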
cpp/ggml.c: 6 changes (3 additions & 3 deletions)
@@ -323,7 +323,7 @@ float lm_ggml_table_f32_f16[1 << 16];
 // note: do not use these inside ggml.c
 // these are meant to be used via the ggml.h API
 float lm_ggml_fp16_to_fp32(lm_ggml_fp16_t x) {
-    return (float) LM_GGML_FP16_TO_FP32(x);
+    return LM_GGML_FP16_TO_FP32(x);
 }
 
 lm_ggml_fp16_t lm_ggml_fp32_to_fp16(float x) {
@@ -798,7 +798,7 @@ inline static float vaddvq_f32(float32x4_t v) {
 #define LM_GGML_F16x8 float16x8_t
 #define LM_GGML_F16x8_ZERO vdupq_n_f16(0.0f)
 #define LM_GGML_F16x8_SET1(x) vdupq_n_f16(x)
-#define LM_GGML_F16x8_LOAD vld1q_f16
+#define LM_GGML_F16x8_LOAD(x) vld1q_f16((const __fp16 *)(x))
 #define LM_GGML_F16x8_STORE vst1q_f16
 #define LM_GGML_F16x8_FMA(a, b, c) vfmaq_f16(a, b, c)
 #define LM_GGML_F16x8_ADD vaddq_f16
@@ -841,7 +841,7 @@ inline static float vaddvq_f32(float32x4_t v) {
 #define LM_GGML_F32Cx4 float32x4_t
 #define LM_GGML_F32Cx4_ZERO vdupq_n_f32(0.0f)
 #define LM_GGML_F32Cx4_SET1(x) vdupq_n_f32(x)
-#define LM_GGML_F32Cx4_LOAD(x) vcvt_f32_f16(vld1_f16(x))
+#define LM_GGML_F32Cx4_LOAD(x) vcvt_f32_f16(vld1_f16((const __fp16 *)(x)))
 #define LM_GGML_F32Cx4_STORE(x, y) vst1_f16(x, vcvt_f16_f32(y))
 #define LM_GGML_F32Cx4_FMA(a, b, c) vfmaq_f32(a, b, c)
 #define LM_GGML_F32Cx4_ADD vaddq_f32
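Note on the load macros: `vld1q_f16` and `vld1_f16` expect a pointer to the native half type, but ggml buffers are now addressed through `lm_ggml_fp16_t *` (i.e. `uint16_t *`), so the macros cast at the boundary; `LM_GGML_F16x8_LOAD` also becomes function-like so the cast can apply to its argument. A sketch of the pattern, assuming an AArch64 target with NEON fp16 storage support (`load_f16x8` is an illustrative stand-in):

    // Sketch of the cast the load macros now perform.
    #include <arm_neon.h>
    #include <stdint.h>

    typedef uint16_t fp16_bits;        // stand-in for lm_ggml_fp16_t

    static inline float16x8_t load_f16x8(const fp16_bits *p) {
        // The buffer holds raw 16-bit words; reinterpret for the NEON load.
        return vld1q_f16((const __fp16 *)p);
    }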
cpp/ggml.h: 6 changes (0 additions & 6 deletions)
@@ -315,13 +315,7 @@
 extern "C" {
 #endif
 
-#if defined(__ARM_NEON) && defined(__CUDACC__)
-typedef half lm_ggml_fp16_t;
-#elif defined(__ARM_NEON) && !defined(_MSC_VER)
-typedef __fp16 lm_ggml_fp16_t;
-#else
 typedef uint16_t lm_ggml_fp16_t;
-#endif
 
 // convert FP16 <-> FP32
 LM_GGML_API float lm_ggml_fp16_to_fp32(lm_ggml_fp16_t x);
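Note on the ggml.h change: with the conditional typedefs removed, `lm_ggml_fp16_t` is `uint16_t` on every platform, and callers convert through the public API rather than casting (the fp32-to-fp16 counterpart is declared just past the shown context; its definition appears in the cpp/ggml.c diff above). A short usage sketch:

    // Sketch: converting through the public ggml.h API declared above.
    #include <stdio.h>
    #include "ggml.h"   // cpp/ggml.h in this repo; adjust the include path as needed

    int main(void) {
        lm_ggml_fp16_t h = lm_ggml_fp32_to_fp16(0.25f);   // defined in cpp/ggml.c
        printf("%f\n", lm_ggml_fp16_to_fp32(h));          // prints 0.250000
        return 0;
    }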