From a57171e9d54e456aef2c3dfe12e53eedca732972 Mon Sep 17 00:00:00 2001
From: Jhen
Date: Tue, 19 Dec 2023 08:20:16 +0800
Subject: [PATCH] feat: sync llama.cpp

---
 cpp/common.cpp             |   13 +
 cpp/common.h               |   11 -
 cpp/ggml-alloc.h           |    2 +-
 cpp/ggml-metal-llama.metal | 1679 ++++++++++++++++++++++++++++++++----
 cpp/ggml-metal.m           |  589 +++++++++++--
 cpp/ggml-quants.c          |    4 +-
 cpp/ggml.c                 |  701 +++++++++++----
 cpp/ggml.h                 |   59 +-
 cpp/llama.cpp              |  723 ++++++++++++----
 cpp/llama.h                |    3 +-
 cpp/log.h                  |    8 +-
 llama.cpp                  |    2 +-
 scripts/bootstrap.sh       |    1 +
 scripts/common.cpp.patch   |   24 +-
 scripts/common.h.patch     |   20 +
 scripts/ggml-metal.m.patch |    6 +-
 scripts/llama.cpp.patch    |    8 +-
 17 files changed, 3257 insertions(+), 596 deletions(-)
 create mode 100644 scripts/common.h.patch

diff --git a/cpp/common.cpp b/cpp/common.cpp
index dcde2be..f8f9e02 100644
--- a/cpp/common.cpp
+++ b/cpp/common.cpp
@@ -42,6 +42,12 @@
 #pragma warning(disable: 4244 4267) // possible loss of data
 #endif
 
+// build info
+int LLAMA_BUILD_NUMBER = 0;
+char const *LLAMA_COMMIT = "unknown";
+char const *LLAMA_COMPILER = "unknown";
+char const *LLAMA_BUILD_TARGET = "unknown";
+
 int32_t get_num_physical_cores() {
 #ifdef __linux__
 // enumerate the set of thread siblings, num entries is num cores
@@ -656,6 +662,10 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
 } else if (arg == "-h" || arg == "--help") {
 return false;
+ } else if (arg == "--version") {
+ fprintf(stderr, "version: %d (%s)\n", LLAMA_BUILD_NUMBER, LLAMA_COMMIT);
+ fprintf(stderr, "built with %s for %s\n", LLAMA_COMPILER, LLAMA_BUILD_TARGET);
+ exit(0);
 } else if (arg == "--random-prompt") {
 params.random_prompt = true;
 } else if (arg == "--in-prefix-bos") {
@@ -794,6 +804,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
 printf("\n");
 printf("options:\n");
 printf(" -h, --help show this help message and exit\n");
+ printf(" --version show version and build info\n");
 printf(" -i, --interactive run in interactive mode\n");
 printf(" --interactive-first run in interactive mode and wait for input right away\n");
 printf(" -ins, --instruct run in instruction mode (use with Alpaca models)\n");
@@ -1385,6 +1396,8 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
 const std::string & timestamp, const std::vector & prompt_tokens, const char * model_desc) {
 const llama_sampling_params & sparams = params.sparams;
 
+ fprintf(stream, "build_commit: %s\n", LLAMA_COMMIT);
+ fprintf(stream, "build_number: %d\n", LLAMA_BUILD_NUMBER);
 fprintf(stream, "cpu_has_arm_fma: %s\n", lm_ggml_cpu_has_arm_fma() ? "true" : "false");
 fprintf(stream, "cpu_has_avx: %s\n", lm_ggml_cpu_has_avx() ? "true" : "false");
 fprintf(stream, "cpu_has_avx2: %s\n", lm_ggml_cpu_has_avx2() ? "true" : "false");
diff --git a/cpp/common.h b/cpp/common.h
index e87ce11..254df73 100644
--- a/cpp/common.h
+++ b/cpp/common.h
@@ -26,17 +26,6 @@
 #define die(msg) do { fputs("error: " msg "\n", stderr); exit(1); } while (0)
 #define die_fmt(fmt, ...)
do { fprintf(stderr, "error: " fmt "\n", __VA_ARGS__); exit(1); } while (0) -#define print_build_info() do { \ - fprintf(stderr, "%s: build = %d (%s)\n", __func__, LLAMA_BUILD_NUMBER, LLAMA_COMMIT); \ - fprintf(stderr, "%s: built with %s for %s\n", __func__, LLAMA_COMPILER, LLAMA_BUILD_TARGET); \ -} while(0) - -// build info -extern int LLAMA_BUILD_NUMBER; -extern char const *LLAMA_COMMIT; -extern char const *LLAMA_COMPILER; -extern char const *LLAMA_BUILD_TARGET; - // // CLI argument parsing // diff --git a/cpp/ggml-alloc.h b/cpp/ggml-alloc.h index 8f5f880..fa3cd61 100644 --- a/cpp/ggml-alloc.h +++ b/cpp/ggml-alloc.h @@ -43,7 +43,7 @@ LM_GGML_API size_t lm_ggml_allocr_alloc_graph(lm_ggml_allocr_t alloc, struct lm_ // ggml-backend v2 API // -// Seperate tensor and graph allocator objects +// Separate tensor and graph allocator objects // This is necessary for multi-backend allocation because the graph allocator needs to use multiple tensor allocators // The original API is kept as a wrapper around the new API diff --git a/cpp/ggml-metal-llama.metal b/cpp/ggml-metal-llama.metal index 2f8ea22..d5b54e1 100644 --- a/cpp/ggml-metal-llama.metal +++ b/cpp/ggml-metal-llama.metal @@ -79,6 +79,7 @@ kernel void kernel_add( constant int64_t & nb1, constant int64_t & nb2, constant int64_t & nb3, + constant int64_t & offs, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]]) { @@ -90,9 +91,9 @@ kernel void kernel_add( const int64_t i12 = i02 % ne12; const int64_t i11 = i01 % ne11; - device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01; + device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs; device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11; - device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1; + device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + offs; for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { const int i10 = i0 % ne10; @@ -204,7 +205,7 @@ kernel void kernel_add_row( device const float4 * src0, device const float4 * src1, device float4 * dst, - constant int64_t & nb [[buffer(27)]], + constant int64_t & nb [[buffer(28)]], uint tpig[[thread_position_in_grid]]) { dst[tpig] = src0[tpig] + src1[tpig % nb]; } @@ -213,7 +214,7 @@ kernel void kernel_mul_row( device const float4 * src0, device const float4 * src1, device float4 * dst, - constant int64_t & nb [[buffer(27)]], + constant int64_t & nb [[buffer(28)]], uint tpig[[thread_position_in_grid]]) { dst[tpig] = src0[tpig] * src1[tpig % nb]; } @@ -222,7 +223,7 @@ kernel void kernel_div_row( device const float4 * src0, device const float4 * src1, device float4 * dst, - constant int64_t & nb [[buffer(27)]], + constant int64_t & nb [[buffer(28)]], uint tpig[[thread_position_in_grid]]) { dst[tpig] = src0[tpig] / src1[tpig % nb]; } @@ -243,19 +244,53 @@ kernel void kernel_scale_4( dst[tpig] = src0[tpig] * scale; } -kernel void kernel_silu( - device const float4 * src0, - device float4 * dst, +kernel void kernel_relu( + device const float * src0, + device float * dst, uint tpig[[thread_position_in_grid]]) { - device const float4 & x = src0[tpig]; - dst[tpig] = x / (1.0f + exp(-x)); + dst[tpig] = max(0.0f, src0[tpig]); } -kernel void kernel_relu( +kernel void kernel_tanh( device const float * src0, device float * dst, uint tpig[[thread_position_in_grid]]) { - dst[tpig] = max(0.0f, src0[tpig]); + device const float & x = src0[tpig]; + dst[tpig] = precise::tanh(x); +} + +constant float GELU_COEF_A = 0.044715f; 
+constant float GELU_QUICK_COEF = -1.702f; +constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; + +kernel void kernel_gelu( + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + device const float4 & x = src0[tpig]; + + // BEWARE !!! + // Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs! + // This was observed with Falcon 7B and 40B models + // + dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); +} + +kernel void kernel_gelu_quick( + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + device const float4 & x = src0[tpig]; + + dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x))); +} + +kernel void kernel_silu( + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + device const float4 & x = src0[tpig]; + dst[tpig] = x / (1.0f + exp(-x)); } kernel void kernel_sqr( @@ -313,22 +348,6 @@ kernel void kernel_sum_rows( dst_row[0] = row_sum; } -constant float GELU_COEF_A = 0.044715f; -constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; - -kernel void kernel_gelu( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - device const float4 & x = src0[tpig]; - - // BEWARE !!! - // Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs! - // This was observed with Falcon 7B and 40B models - // - dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); -} - kernel void kernel_soft_max( device const float * src0, device const float * src1, @@ -347,9 +366,9 @@ kernel void kernel_soft_max( const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01; const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01); - device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - device const float * pmask = src1 ? src1 + i01*ne00 : nullptr; - device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + device const float * pmask = src1 != src0 ? src1 + i01*ne00 : nullptr; + device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; // parallel max float lmax = -INFINITY; @@ -385,7 +404,12 @@ kernel void kernel_soft_max( pdst[i00] = exp_psrc0; } + // This barrier fixes a failing test + // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335 + threadgroup_barrier(mem_flags::mem_none); + float sum = simd_sum(lsum); + if (ntg > N_SIMDWIDTH) { if (sgitg == 0) { buf[tiisg] = 0.0f; @@ -428,9 +452,9 @@ kernel void kernel_soft_max_4( const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01; const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01); - device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); - device const float4 * pmask = src1 ? (device const float4 *)(src1 + i01*ne00) : nullptr; - device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + device const float4 * pmask = src1 != src0 ? 
(device const float4 *)(src1 + i01*ne00) : nullptr; + device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); // parallel max float4 lmax4 = -INFINITY; @@ -468,7 +492,13 @@ kernel void kernel_soft_max_4( } const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3]; + + // This barrier fixes a failing test + // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335 + threadgroup_barrier(mem_flags::mem_none); + float sum = simd_sum(lsum); + if (ntg > N_SIMDWIDTH) { if (sgitg == 0) { buf[tiisg] = 0.0f; @@ -639,6 +669,94 @@ kernel void kernel_rms_norm( } } +kernel void kernel_group_norm( + device const float * src0, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int32_t & n_groups, + constant float & eps, + threadgroup float * buf [[threadgroup(0)]], + uint tgpig[[threadgroup_position_in_grid]], + uint tpitg[[thread_position_in_threadgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint ntg[[threads_per_threadgroup]]) { + const int64_t ne = ne00*ne01*ne02; + const int64_t gs = ne00*ne01*((ne02 + n_groups - 1) / n_groups); + + int start = tgpig * gs; + int end = start + gs; + + start += tpitg; + + if (end >= ne) { + end = ne; + } + + float tmp = 0.0f; // partial sum for thread in warp + + for (int j = start; j < end; j += ntg) { + tmp += src0[j]; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + tmp = simd_sum(tmp); + if (ntg > N_SIMDWIDTH) { + if (sgitg == 0) { + buf[tiisg] = 0.0f; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + buf[sgitg] = tmp; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + tmp = buf[tiisg]; + tmp = simd_sum(tmp); + } + + const float mean = tmp / gs; + tmp = 0.0f; + + for (int j = start; j < end; j += ntg) { + float xi = src0[j] - mean; + dst[j] = xi; + tmp += xi * xi; + } + + tmp = simd_sum(tmp); + if (ntg > N_SIMDWIDTH) { + if (sgitg == 0) { + buf[tiisg] = 0.0f; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + buf[sgitg] = tmp; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + tmp = buf[tiisg]; + tmp = simd_sum(tmp); + } + + const float variance = tmp / gs; + const float scale = 1.0f/sqrt(variance + eps); + for (int j = start; j < end; j += ntg) { + dst[j] *= scale; + } +} + // function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i]) // il indicates where the q4 quants begin (0 or QK4_0/4) // we assume that the yl's have been multiplied with the appropriate scale factor @@ -731,7 +849,7 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre // giard against the number of rows not being divisible by // N_DST, so this is another explicit assumption of the implementation. 
template -void mul_vec_q_n_f32( +void mul_vec_q_n_f32_impl( device const void * src0, device const float * src1, device float * dst, @@ -813,7 +931,7 @@ kernel void kernel_mul_mv_q4_0_f32( uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg); + mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg); } kernel void kernel_mul_mv_q4_1_f32( @@ -832,7 +950,7 @@ kernel void kernel_mul_mv_q4_1_f32( uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg); + mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg); } kernel void kernel_mul_mv_q5_0_f32( @@ -851,7 +969,7 @@ kernel void kernel_mul_mv_q5_0_f32( uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg); + mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg); } kernel void kernel_mul_mv_q5_1_f32( @@ -870,28 +988,28 @@ kernel void kernel_mul_mv_q5_1_f32( uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg); + mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg); } #define NB_Q8_0 8 -kernel void kernel_mul_mv_q8_0_f32( +void kernel_mul_mv_q8_0_f32_impl( device const void * src0, device const float * src1, device float * dst, constant int64_t & ne00, - constant int64_t & ne01[[buffer(4)]], - constant int64_t & ne02[[buffer(5)]], - constant int64_t & ne10[[buffer(9)]], - constant int64_t & ne12[[buffer(11)]], - constant int64_t & ne0 [[buffer(15)]], - constant int64_t & ne1 [[buffer(16)]], - constant uint & r2 [[buffer(17)]], - constant uint & r3 [[buffer(18)]], + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne10, + constant int64_t & ne12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { const int nr = N_DST; const int nsg = N_SIMDGROUP; const int nw = N_SIMDWIDTH; @@ -945,9 +1063,29 @@ kernel void kernel_mul_mv_q8_0_f32( } } +[[host_name("kernel_mul_mv_q8_0_f32")]] +kernel void kernel_mul_mv_q8_0_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne10, + constant int64_t & ne12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2 [[buffer(17)]], + constant uint & r3 [[buffer(18)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg); +} + #define N_F32_F32 4 -kernel void 
kernel_mul_mv_f32_f32( +void kernel_mul_mv_f32_f32_impl( device const char * src0, device const char * src1, device float * dst, @@ -965,8 +1103,8 @@ kernel void kernel_mul_mv_f32_f32( constant uint64_t & nb12, constant int64_t & ne0, constant int64_t & ne1, - constant uint & r2 [[buffer(17)]], - constant uint & r3 [[buffer(18)]], + constant uint & r2, + constant uint & r3, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]]) { @@ -1025,6 +1163,32 @@ kernel void kernel_mul_mv_f32_f32( } } +[[host_name("kernel_mul_mv_f32_f32")]] +kernel void kernel_mul_mv_f32_f32( + device const char * src0, + device const char * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2 [[buffer(17)]], + constant uint & r3 [[buffer(18)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]]) { + kernel_mul_mv_f32_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg); +} + #define N_F16_F16 4 kernel void kernel_mul_mv_f16_f16( @@ -1105,7 +1269,7 @@ kernel void kernel_mul_mv_f16_f16( } } -kernel void kernel_mul_mv_f16_f32_1row( +void kernel_mul_mv_f16_f32_1row_impl( device const char * src0, device const char * src1, device float * dst, @@ -1123,8 +1287,8 @@ kernel void kernel_mul_mv_f16_f32_1row( constant uint64_t & nb12, constant int64_t & ne0, constant int64_t & ne1, - constant uint & r2 [[buffer(17)]], - constant uint & r3 [[buffer(18)]], + constant uint & r2, + constant uint & r3, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]]) { @@ -1161,12 +1325,37 @@ kernel void kernel_mul_mv_f16_f32_1row( dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; } } +} +[[host_name("kernel_mul_mv_f16_f32_1row")]] +kernel void kernel_mul_mv_f16_f32_1row( + device const char * src0, + device const char * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2 [[buffer(17)]], + constant uint & r3 [[buffer(18)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]]) { + kernel_mul_mv_f16_f32_1row_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg); } #define N_F16_F32 4 -kernel void kernel_mul_mv_f16_f32( +void kernel_mul_mv_f16_f32_impl( device const char * src0, device const char * src1, device float * dst, @@ -1184,8 +1373,8 @@ kernel void kernel_mul_mv_f16_f32( constant uint64_t & nb12, constant int64_t & ne0, constant int64_t & ne1, - constant uint & r2 [[buffer(17)]], - constant uint & r3 [[buffer(18)]], + constant uint & r2, + constant uint & r3, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]]) { @@ -1244,6 +1433,32 @@ kernel void kernel_mul_mv_f16_f32( } } 
+[[host_name("kernel_mul_mv_f16_f32")]] +kernel void kernel_mul_mv_f16_f32( + device const char * src0, + device const char * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2 [[buffer(17)]], + constant uint & r3 [[buffer(18)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]]) { + kernel_mul_mv_f16_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg); +} + // Assumes row size (ne00) is a multiple of 4 kernel void kernel_mul_mv_f16_f32_l4( device const char * src0, @@ -1487,8 +1702,9 @@ kernel void kernel_rope( dst_data[1] = x0*sin_theta + x1*cos_theta; } } else { - for (int64_t ib = 0; ib < ne0/n_dims; ++ib) { - for (int64_t ic = 2*tiitg; ic < n_dims; ic += 2*tptg.x) { + for (int64_t ic = 2*tiitg; ic < ne0; ic += 2*tptg.x) { + if (ic < n_dims) { + const int64_t ib = 0; // simplified from `(ib * n_dims + ic) * inv_ndims` const float cur_rot = inv_ndims*ic - ib; @@ -1507,6 +1723,14 @@ kernel void kernel_rope( dst_data[0] = x0*cos_theta - x1*sin_theta; dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta; + } else { + const int64_t i0 = ic; + + device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; } } } @@ -1548,21 +1772,112 @@ kernel void kernel_im2col_f16( } } -// bitonic sort implementation following the CUDA kernels as reference -typedef void (argsort_t)( - device const float * x, - device int32_t * dst, - constant int64_t & ncols, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]]); - -template -kernel void kernel_argsort_f32_i32( - device const float * x, - device int32_t * dst, - constant int64_t & ncols, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]]) { +kernel void kernel_upscale_f32( + device const char * src0, + device char * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + constant int32_t & sf, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + + const int64_t i3 = tgpig.z; + const int64_t i2 = tgpig.y; + const int64_t i1 = tgpig.x; + + const int64_t i03 = i3; + const int64_t i02 = i2; + const int64_t i01 = i1/sf; + + device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01); + device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1); + + for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { + dst_ptr[i0] = src0_ptr[i0/sf]; + } +} + +kernel void kernel_pad_f32( + device const char * src0, + 
device char * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + + const int64_t i3 = tgpig.z; + const int64_t i2 = tgpig.y; + const int64_t i1 = tgpig.x; + + const int64_t i03 = i3; + const int64_t i02 = i2; + const int64_t i01 = i1; + + device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01); + device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1); + + if (i1 < ne01 && i2 < ne02 && i3 < ne03) { + for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { + if (i0 < ne00) { + dst_ptr[i0] = src0_ptr[i0]; + } else { + dst_ptr[i0] = 0.0f; + } + } + + return; + } + + for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { + dst_ptr[i0] = 0.0f; + } +} + +// bitonic sort implementation following the CUDA kernels as reference +typedef void (argsort_t)( + device const float * x, + device int32_t * dst, + constant int64_t & ncols, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]]); + +template +kernel void kernel_argsort_f32_i32( + device const float * x, + device int32_t * dst, + constant int64_t & ncols, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]]) { // bitonic sort int col = tpitg[0]; int row = tgpig[1]; @@ -1600,9 +1915,17 @@ kernel void kernel_argsort_f32_i32( template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32; template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32; +kernel void kernel_leaky_relu_f32( + device const float * src0, + device float * dst, + constant float & slope, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] > 0.0f ? 
src0[tpig] : src0[tpig] * slope; +} + kernel void kernel_cpy_f16_f16( - device const half * src0, - device half * dst, + device const half * src0, + device half * dst, constant int64_t & ne00, constant int64_t & ne01, constant int64_t & ne02, @@ -1641,6 +1964,47 @@ kernel void kernel_cpy_f16_f16( } } +kernel void kernel_cpy_f16_f32( + device const half * src0, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig[2]; + const int64_t i02 = tgpig[1]; + const int64_t i01 = tgpig[0]; + + const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + const int64_t i3 = n / (ne2*ne1*ne0); + const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); + + device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { + device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + dst_data[i00] = src[0]; + } +} + kernel void kernel_cpy_f32_f16( device const float * src0, device half * dst, @@ -1917,9 +2281,9 @@ kernel void kernel_cpy_f32_q4_1( } kernel void kernel_concat( - device const char * src0, - device const char * src1, - device char * dst, + device const char * src0, + device const char * src1, + device char * dst, constant int64_t & ne00, constant int64_t & ne01, constant int64_t & ne02, @@ -1956,7 +2320,7 @@ kernel void kernel_concat( const int64_t i12 = i02 % ne12; const int64_t i11 = i01 % ne11; - device const char * src0_ptr = src0 + i03 * nb03 + i02 * nb02 + i01 * nb01 + tpitg.x*nb00; + device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00; device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10; device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0; @@ -2064,19 +2428,19 @@ static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) { //====================================== dot products ========================= -kernel void kernel_mul_mv_q2_K_f32( +void kernel_mul_mv_q2_K_f32_impl( device const void * src0, device const float * src1, device float * dst, constant int64_t & ne00, - constant int64_t & ne01[[buffer(4)]], - constant int64_t & ne02[[buffer(5)]], - constant int64_t & ne10[[buffer(9)]], - constant int64_t & ne12[[buffer(11)]], - constant int64_t & ne0 [[buffer(15)]], - constant int64_t & ne1 [[buffer(16)]], - constant uint & r2 [[buffer(17)]], - constant uint & r3 [[buffer(18)]], + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne10, + constant int64_t & ne12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -2214,8 +2578,8 @@ kernel void 
kernel_mul_mv_q2_K_f32( } } -#if QK_K == 256 -kernel void kernel_mul_mv_q3_K_f32( +[[host_name("kernel_mul_mv_q2_K_f32")]] +kernel void kernel_mul_mv_q2_K_f32( device const void * src0, device const float * src1, device float * dst, @@ -2229,8 +2593,29 @@ kernel void kernel_mul_mv_q3_K_f32( constant uint & r2 [[buffer(17)]], constant uint & r3 [[buffer(18)]], uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg); +} + +#if QK_K == 256 +void kernel_mul_mv_q3_K_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne10, + constant int64_t & ne12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { const int nb = ne00/QK_K; @@ -2373,19 +2758,19 @@ kernel void kernel_mul_mv_q3_K_f32( } } #else -kernel void kernel_mul_mv_q3_K_f32( +void kernel_mul_mv_q3_K_f32_impl( device const void * src0, device const float * src1, device float * dst, constant int64_t & ne00, - constant int64_t & ne01[[buffer(4)]], - constant int64_t & ne02[[buffer(5)]], - constant int64_t & ne10[[buffer(9)]], - constant int64_t & ne12[[buffer(11)]], - constant int64_t & ne0 [[buffer(15)]], - constant int64_t & ne1 [[buffer(16)]], - constant uint & r2 [[buffer(17)]], - constant uint & r3 [[buffer(18)]], + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne10, + constant int64_t & ne12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -2450,20 +2835,41 @@ kernel void kernel_mul_mv_q3_K_f32( } #endif +[[host_name("kernel_mul_mv_q3_K_f32")]] +kernel void kernel_mul_mv_q3_K_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01[[buffer(4)]], + constant int64_t & ne02[[buffer(5)]], + constant int64_t & ne10[[buffer(9)]], + constant int64_t & ne12[[buffer(11)]], + constant int64_t & ne0 [[buffer(15)]], + constant int64_t & ne1 [[buffer(16)]], + constant uint & r2 [[buffer(17)]], + constant uint & r3 [[buffer(18)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg); +} + #if QK_K == 256 -kernel void kernel_mul_mv_q4_K_f32( +void kernel_mul_mv_q4_K_f32_impl( device const void * src0, device const float * src1, device float * dst, constant int64_t & ne00, - constant int64_t & ne01 [[buffer(4)]], - constant int64_t & ne02 [[buffer(5)]], - constant int64_t & ne10 [[buffer(9)]], - constant int64_t & ne12 [[buffer(11)]], - constant int64_t & ne0 [[buffer(15)]], - constant int64_t & ne1 [[buffer(16)]], - constant uint & r2 [[buffer(17)]], - constant uint & r3 [[buffer(18)]], + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t 
& ne10, + constant int64_t & ne12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -2564,19 +2970,19 @@ kernel void kernel_mul_mv_q4_K_f32( } } #else -kernel void kernel_mul_mv_q4_K_f32( +void kernel_mul_mv_q4_K_f32_impl( device const void * src0, device const float * src1, device float * dst, constant int64_t & ne00, - constant int64_t & ne01[[buffer(4)]], - constant int64_t & ne02[[buffer(5)]], - constant int64_t & ne10[[buffer(9)]], - constant int64_t & ne12[[buffer(11)]], - constant int64_t & ne0 [[buffer(15)]], - constant int64_t & ne1 [[buffer(16)]], - constant uint & r2 [[buffer(17)]], - constant uint & r3 [[buffer(18)]], + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne10, + constant int64_t & ne12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -2660,7 +3066,8 @@ kernel void kernel_mul_mv_q4_K_f32( } #endif -kernel void kernel_mul_mv_q5_K_f32( +[[host_name("kernel_mul_mv_q4_K_f32")]] +kernel void kernel_mul_mv_q4_K_f32( device const void * src0, device const float * src1, device float * dst, @@ -2677,6 +3084,26 @@ kernel void kernel_mul_mv_q5_K_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { + kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg); +} + +void kernel_mul_mv_q5_K_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne10, + constant int64_t & ne12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + const int nb = ne00/QK_K; const int64_t r0 = tgpig.x; @@ -2836,10 +3263,10 @@ kernel void kernel_mul_mv_q5_K_f32( dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot; } } - } -kernel void kernel_mul_mv_q6_K_f32( +[[host_name("kernel_mul_mv_q5_K_f32")]] +kernel void kernel_mul_mv_q5_K_f32( device const void * src0, device const float * src1, device float * dst, @@ -2853,21 +3280,41 @@ kernel void kernel_mul_mv_q6_K_f32( constant uint & r2 [[buffer(17)]], constant uint & r3 [[buffer(18)]], uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - const uint8_t kmask1 = 0x03; - const uint8_t kmask2 = 0x0C; - const uint8_t kmask3 = 0x30; - const uint8_t kmask4 = 0xC0; - - const int nb = ne00/QK_K; - - const int64_t r0 = tgpig.x; - const int64_t r1 = tgpig.y; - const int im = tgpig.z; + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { - const int row = 2 * r0 + sgitg; + kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg); +} + +void kernel_mul_mv_q6_K_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne10, + constant int64_t & ne12, + constant 
int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + const uint8_t kmask1 = 0x03; + const uint8_t kmask2 = 0x0C; + const uint8_t kmask3 = 0x30; + const uint8_t kmask4 = 0xC0; + + const int nb = ne00/QK_K; + + const int64_t r0 = tgpig.x; + const int64_t r1 = tgpig.y; + const int im = tgpig.z; + + const int row = 2 * r0 + sgitg; const uint i12 = im%ne12; const uint i13 = im/ne12; @@ -2945,6 +3392,27 @@ kernel void kernel_mul_mv_q6_K_f32( } } +[[host_name("kernel_mul_mv_q6_K_f32")]] +kernel void kernel_mul_mv_q6_K_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01[[buffer(4)]], + constant int64_t & ne02[[buffer(5)]], + constant int64_t & ne10[[buffer(9)]], + constant int64_t & ne12[[buffer(11)]], + constant int64_t & ne0 [[buffer(15)]], + constant int64_t & ne1 [[buffer(16)]], + constant uint & r2 [[buffer(17)]], + constant uint & r3 [[buffer(18)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg); +} + //============================= templates and their specializations ============================= // NOTE: this is not dequantizing - we are simply fitting the template @@ -3062,10 +3530,10 @@ void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg template void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) { - const half d = xb->d; - const half min = xb->dmin; + const float d = xb->d; + const float min = xb->dmin; device const uint8_t * q = (device const uint8_t *)xb->qs; - half dl, ml; + float dl, ml; uint8_t sc = xb->scales[il]; #if QK_K == 256 @@ -3135,10 +3603,10 @@ void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg q = q + (il/4) * 32 + 16 * (il&1); il = il & 3; const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales); - const half d = il < 2 ? xb->d : xb->d / 16.h; - const half min = xb->dmin; - const half dl = d * sc[0]; - const half ml = min * sc[1]; + const float d = il < 2 ? xb->d : xb->d / 16.h; + const float min = xb->dmin; + const float dl = d * sc[0]; + const float ml = min * sc[1]; #else q = q + 16 * (il&1); device const uint8_t * s = xb->scales; @@ -3165,13 +3633,13 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg uint8_t ul = 1 << (il/2); il = il & 3; const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales); - const half d = il < 2 ? xb->d : xb->d / 16.h; - const half min = xb->dmin; - const half dl = d * sc[0]; - const half ml = min * sc[1]; + const float d = il < 2 ? xb->d : xb->d / 16.h; + const float min = xb->dmin; + const float dl = d * sc[0]; + const float ml = min * sc[1]; - const ushort mask = il<2 ? 0x0F : 0xF0; - const half qh_val = il<2 ? 16.h : 256.h; + const ushort mask = il<2 ? 0x0F : 0xF0; + const float qh_val = il<2 ? 16.f : 256.f; for (int i = 0; i < 16; ++i) { reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? 
qh_val : 0)) - ml; } @@ -3219,22 +3687,90 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg template kernel void kernel_get_rows( device const void * src0, - device const int * src1, + device const char * src1, device float * dst, constant int64_t & ne00, constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant uint64_t & nb10, + constant uint64_t & nb11, constant uint64_t & nb1, - uint tgpig[[threadgroup_position_in_grid]], + constant uint64_t & nb2, + uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], - uint tptg[[threads_per_threadgroup]]) { - const int i = tgpig; - const int r = ((device int32_t *) src1)[i]; + uint3 tptg [[threads_per_threadgroup]]) { + //const int64_t i = tgpig; + //const int64_t r = ((device int32_t *) src1)[i]; + + const int64_t i10 = tgpig.x; + const int64_t i11 = tgpig.y; + + const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0]; - for (int ind = tiitg; ind < ne00/16; ind += tptg) { + const int64_t i02 = i11; + + for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) { float4x4 temp; dequantize_func( - ((device const block_q *) ((device char *) src0 + r*nb01)) + ind/nl, ind%nl, temp); - *(((device float4x4 *) ((device char *) dst + i*nb1)) + ind) = temp; + ((device const block_q *) ((device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp); + *(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp; + } +} + +kernel void kernel_get_rows_f32( + device const void * src0, + device const char * src1, + device float * dst, + constant int64_t & ne00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb1, + constant uint64_t & nb2, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint3 tptg [[threads_per_threadgroup]]) { + const int64_t i10 = tgpig.x; + const int64_t i11 = tgpig.y; + + const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0]; + + const int64_t i02 = i11; + + for (int ind = tiitg; ind < ne00; ind += tptg.x) { + ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] = + ((device float *) ((device char *) src0 + r*nb01 + i02*nb02))[ind]; + } +} + +kernel void kernel_get_rows_f16( + device const void * src0, + device const char * src1, + device float * dst, + constant int64_t & ne00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb1, + constant uint64_t & nb2, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint3 tptg [[threads_per_threadgroup]]) { + const int64_t i10 = tgpig.x; + const int64_t i11 = tgpig.y; + + const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0]; + + const int64_t i02 = i11; + + for (int ind = tiitg; ind < ne00; ind += tptg.x) { + ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] = + ((device half *) ((device char *) src0 + r*nb01 + i02*nb02))[ind]; } } @@ -3426,19 +3962,22 @@ kernel void kernel_mul_mm(device const uchar * src0, template kernel void kernel_mul_mm_id( - device const int32_t * ids, + device const uchar * ids, device const uchar * src1, - device float * dst, + device uchar * dst, + constant int64_t & nbi1, constant int64_t & ne00, constant int64_t & ne02, 
constant int64_t & nb01, constant int64_t & nb02, constant int64_t & ne12, + constant int64_t & ne13, constant int64_t & nb10, constant int64_t & nb11, constant int64_t & nb12, constant int64_t & ne0, constant int64_t & ne1, + constant int64_t & nb1, constant uint & r2, constant uint & r3, constant int & idx, @@ -3456,10 +3995,16 @@ kernel void kernel_mul_mm_id( uint sgitg[[simdgroup_index_in_threadgroup]]) { device const uchar * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; + const int64_t bid = tgpig.z/(ne12*ne13); + + tgpig.z = tgpig.z%(ne12*ne13); + + const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + kernel_mul_mm_impl( - src0[ids[idx]], - src1, - dst, + src0[id], + src1 + bid*nb11, + (device float *) (dst + bid*nb1), ne00, ne02, nb01, @@ -3484,17 +4029,26 @@ kernel void kernel_mul_mm_id( #define QK_NL 4 #endif +// +// get rows +// + typedef void (get_rows_t)( device const void * src0, - device const int * src1, + device const char * src1, device float * dst, constant int64_t & ne00, constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant uint64_t & nb10, + constant uint64_t & nb11, constant uint64_t & nb1, - uint, uint, uint); + constant uint64_t & nb2, + uint3, uint, uint3); -template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows; +//template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows; +//template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows; template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows; template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows; template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows; @@ -3506,6 +4060,10 @@ template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows; template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows; +// +// matrix-matrix multiplication +// + typedef void (mat_mm_t)( device const uchar * src0, device const uchar * src1, @@ -3538,20 +4096,27 @@ template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm; +// +// indirect matrix-matrix multiplication +// + typedef void (mat_mm_id_t)( - device const int32_t * ids, + device const uchar * ids, device const uchar * src1, - device float * dst, + device uchar * dst, + constant int64_t & nbi1, constant int64_t & ne00, constant int64_t & ne02, constant int64_t & nb01, constant int64_t & nb02, constant int64_t & ne12, + constant int64_t & ne13, constant int64_t & nb10, constant int64_t & nb11, constant int64_t & nb12, constant int64_t & ne0, constant int64_t & ne1, + constant int64_t & nb1, constant uint & r2, constant uint & r3, constant int & idx, @@ -3578,3 +4143,775 @@ template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mu template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; + +// +// matrix-vector multiplication +// + +[[host_name("kernel_mul_mv_id_f32_f32")]] +kernel void kernel_mul_mv_id_f32_f32( + device const char * ids, + 
device const char * src1, + device uchar * dst, + constant int64_t & nbi1, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & nb1, + constant uint & r2, + constant uint & r3, + constant int & idx, + device const char * src00, + device const char * src01, + device const char * src02, + device const char * src03, + device const char * src04, + device const char * src05, + device const char * src06, + device const char * src07, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; + + const int64_t bid = tgpig.z/(ne12*ne13); + + tgpig.z = tgpig.z%(ne12*ne13); + + const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + + kernel_mul_mv_f32_f32_impl( + src0[id], + src1 + bid*nb11, + (device float *) (dst + bid*nb1), + ne00, + ne01, + ne02, + nb00, + nb01, + nb02, + ne10, + ne11, + ne12, + nb10, + nb11, + nb12, + ne0, + ne1, + r2, + r3, + tgpig, + tiisg); +} + +[[host_name("kernel_mul_mv_id_f16_f32")]] +kernel void kernel_mul_mv_id_f16_f32( + device const char * ids, + device const char * src1, + device uchar * dst, + constant int64_t & nbi1, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & nb1, + constant uint & r2, + constant uint & r3, + constant int & idx, + device const char * src00, + device const char * src01, + device const char * src02, + device const char * src03, + device const char * src04, + device const char * src05, + device const char * src06, + device const char * src07, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; + + const int64_t bid = tgpig.z/(ne12*ne13); + + tgpig.z = tgpig.z%(ne12*ne13); + + const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + + kernel_mul_mv_f16_f32_impl( + src0[id], + src1 + bid*nb11, + (device float *) (dst + bid*nb1), + ne00, + ne01, + ne02, + nb00, + nb01, + nb02, + ne10, + ne11, + ne12, + nb10, + nb11, + nb12, + ne0, + ne1, + r2, + r3, + tgpig, + tiisg); +} + +[[host_name("kernel_mul_mv_id_q8_0_f32")]] +kernel void kernel_mul_mv_id_q8_0_f32( + device const char * ids, + device const char * src1, + device uchar * dst, + constant int64_t & nbi1, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + 
constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & nb1, + constant uint & r2, + constant uint & r3, + constant int & idx, + device const char * src00, + device const char * src01, + device const char * src02, + device const char * src03, + device const char * src04, + device const char * src05, + device const char * src06, + device const char * src07, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; + + const int64_t bid = tgpig.z/(ne12*ne13); + + tgpig.z = tgpig.z%(ne12*ne13); + + const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + + kernel_mul_mv_q8_0_f32_impl( + src0[id], + (device const float *) (src1 + bid*nb11), + (device float *) ( dst + bid*nb1), + ne00, + ne01, + ne02, + ne10, + ne12, + ne0, + ne1, + r2, + r3, + tgpig, + tiisg, + sgitg); +} + +[[host_name("kernel_mul_mv_id_q4_0_f32")]] +kernel void kernel_mul_mv_id_q4_0_f32( + device const char * ids, + device const char * src1, + device uchar * dst, + constant int64_t & nbi1, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & nb1, + constant uint & r2, + constant uint & r3, + constant int & idx, + device const char * src00, + device const char * src01, + device const char * src02, + device const char * src03, + device const char * src04, + device const char * src05, + device const char * src06, + device const char * src07, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; + + const int64_t bid = tgpig.z/(ne12*ne13); + + tgpig.z = tgpig.z%(ne12*ne13); + + const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + + mul_vec_q_n_f32_impl( + src0[id], + (device const float *) (src1 + bid*nb11), + (device float *) ( dst + bid*nb1), + ne00, + ne01, + ne02, + ne10, + ne12, + ne0, + ne1, + r2, + r3, + tgpig, + tiisg, + sgitg); +} + +[[host_name("kernel_mul_mv_id_q4_1_f32")]] +kernel void kernel_mul_mv_id_q4_1_f32( + device const char * ids, + device const char * src1, + device uchar * dst, + constant int64_t & nbi1, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & nb1, + constant uint & r2, + constant uint & r3, + constant int & idx, + device const char * src00, + device const char * src01, + device const char * src02, + device const char * src03, + device const char * src04, + device const char * src05, + device const char * src06, + device const char * 
src07, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; + + const int64_t bid = tgpig.z/(ne12*ne13); + + tgpig.z = tgpig.z%(ne12*ne13); + + const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + + mul_vec_q_n_f32_impl( + src0[id], + (device const float *) (src1 + bid*nb11), + (device float *) ( dst + bid*nb1), + ne00, + ne01, + ne02, + ne10, + ne12, + ne0, + ne1, + r2, + r3, + tgpig, + tiisg, + sgitg); +} + +[[host_name("kernel_mul_mv_id_q5_0_f32")]] +kernel void kernel_mul_mv_id_q5_0_f32( + device const char * ids, + device const char * src1, + device uchar * dst, + constant int64_t & nbi1, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & nb1, + constant uint & r2, + constant uint & r3, + constant int & idx, + device const char * src00, + device const char * src01, + device const char * src02, + device const char * src03, + device const char * src04, + device const char * src05, + device const char * src06, + device const char * src07, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; + + const int64_t bid = tgpig.z/(ne12*ne13); + + tgpig.z = tgpig.z%(ne12*ne13); + + const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + + mul_vec_q_n_f32_impl( + src0[id], + (device const float *) (src1 + bid*nb11), + (device float *) ( dst + bid*nb1), + ne00, + ne01, + ne02, + ne10, + ne12, + ne0, + ne1, + r2, + r3, + tgpig, + tiisg, + sgitg); +} + +[[host_name("kernel_mul_mv_id_q5_1_f32")]] +kernel void kernel_mul_mv_id_q5_1_f32( + device const char * ids, + device const char * src1, + device uchar * dst, + constant int64_t & nbi1, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & nb1, + constant uint & r2, + constant uint & r3, + constant int & idx, + device const char * src00, + device const char * src01, + device const char * src02, + device const char * src03, + device const char * src04, + device const char * src05, + device const char * src06, + device const char * src07, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; + + const int64_t bid = tgpig.z/(ne12*ne13); + + tgpig.z = tgpig.z%(ne12*ne13); + + const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + + 
mul_vec_q_n_f32_impl( + src0[id], + (device const float *) (src1 + bid*nb11), + (device float *) ( dst + bid*nb1), + ne00, + ne01, + ne02, + ne10, + ne12, + ne0, + ne1, + r2, + r3, + tgpig, + tiisg, + sgitg); +} + +[[host_name("kernel_mul_mv_id_q2_K_f32")]] +kernel void kernel_mul_mv_id_q2_K_f32( + device const char * ids, + device const char * src1, + device uchar * dst, + constant int64_t & nbi1, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & nb1, + constant uint & r2, + constant uint & r3, + constant int & idx, + device const char * src00, + device const char * src01, + device const char * src02, + device const char * src03, + device const char * src04, + device const char * src05, + device const char * src06, + device const char * src07, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; + + const int64_t bid = tgpig.z/(ne12*ne13); + + tgpig.z = tgpig.z%(ne12*ne13); + + const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + + kernel_mul_mv_q2_K_f32_impl( + src0[id], + (device const float *) (src1 + bid*nb11), + (device float *) ( dst + bid*nb1), + ne00, + ne01, + ne02, + ne10, + ne12, + ne0, + ne1, + r2, + r3, + tgpig, + tiisg, + sgitg); +} + +[[host_name("kernel_mul_mv_id_q3_K_f32")]] +kernel void kernel_mul_mv_id_q3_K_f32( + device const char * ids, + device const char * src1, + device uchar * dst, + constant int64_t & nbi1, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & nb1, + constant uint & r2, + constant uint & r3, + constant int & idx, + device const char * src00, + device const char * src01, + device const char * src02, + device const char * src03, + device const char * src04, + device const char * src05, + device const char * src06, + device const char * src07, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; + + const int64_t bid = tgpig.z/(ne12*ne13); + + tgpig.z = tgpig.z%(ne12*ne13); + + const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + + kernel_mul_mv_q3_K_f32_impl( + src0[id], + (device const float *) (src1 + bid*nb11), + (device float *) ( dst + bid*nb1), + ne00, + ne01, + ne02, + ne10, + ne12, + ne0, + ne1, + r2, + r3, + tgpig, + tiisg, + sgitg); +} + +[[host_name("kernel_mul_mv_id_q4_K_f32")]] +kernel void kernel_mul_mv_id_q4_K_f32( + device const char * ids, + device const char * src1, + device uchar * dst, + constant int64_t & nbi1, + constant 
int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & nb1, + constant uint & r2, + constant uint & r3, + constant int & idx, + device const char * src00, + device const char * src01, + device const char * src02, + device const char * src03, + device const char * src04, + device const char * src05, + device const char * src06, + device const char * src07, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; + + const int64_t bid = tgpig.z/(ne12*ne13); + + tgpig.z = tgpig.z%(ne12*ne13); + + const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + + kernel_mul_mv_q4_K_f32_impl( + src0[id], + (device const float *) (src1 + bid*nb11), + (device float *) ( dst + bid*nb1), + ne00, + ne01, + ne02, + ne10, + ne12, + ne0, + ne1, + r2, + r3, + tgpig, + tiisg, + sgitg); +} + +[[host_name("kernel_mul_mv_id_q5_K_f32")]] +kernel void kernel_mul_mv_id_q5_K_f32( + device const char * ids, + device const char * src1, + device uchar * dst, + constant int64_t & nbi1, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & nb1, + constant uint & r2, + constant uint & r3, + constant int & idx, + device const char * src00, + device const char * src01, + device const char * src02, + device const char * src03, + device const char * src04, + device const char * src05, + device const char * src06, + device const char * src07, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; + + const int64_t bid = tgpig.z/(ne12*ne13); + + tgpig.z = tgpig.z%(ne12*ne13); + + const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + + kernel_mul_mv_q5_K_f32_impl( + src0[id], + (device const float *) (src1 + bid*nb11), + (device float *) ( dst + bid*nb1), + ne00, + ne01, + ne02, + ne10, + ne12, + ne0, + ne1, + r2, + r3, + tgpig, + tiisg, + sgitg); +} + +[[host_name("kernel_mul_mv_id_q6_K_f32")]] +kernel void kernel_mul_mv_id_q6_K_f32( + device const char * ids, + device const char * src1, + device uchar * dst, + constant int64_t & nbi1, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & 
nb1, + constant uint & r2, + constant uint & r3, + constant int & idx, + device const char * src00, + device const char * src01, + device const char * src02, + device const char * src03, + device const char * src04, + device const char * src05, + device const char * src06, + device const char * src07, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; + + const int64_t bid = tgpig.z/(ne12*ne13); + + tgpig.z = tgpig.z%(ne12*ne13); + + const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + + kernel_mul_mv_q6_K_f32_impl( + src0[id], + (device const float *) (src1 + bid*nb11), + (device float *) ( dst + bid*nb1), + ne00, + ne01, + ne02, + ne10, + ne12, + ne0, + ne1, + r2, + r3, + tgpig, + tiisg, + sgitg); +} diff --git a/cpp/ggml-metal.m b/cpp/ggml-metal.m index 32c1019..d4c6e0d 100644 --- a/cpp/ggml-metal.m +++ b/cpp/ggml-metal.m @@ -66,9 +66,11 @@ LM_GGML_METAL_DECL_KERNEL(div_row); LM_GGML_METAL_DECL_KERNEL(scale); LM_GGML_METAL_DECL_KERNEL(scale_4); - LM_GGML_METAL_DECL_KERNEL(silu); + LM_GGML_METAL_DECL_KERNEL(tanh); LM_GGML_METAL_DECL_KERNEL(relu); LM_GGML_METAL_DECL_KERNEL(gelu); + LM_GGML_METAL_DECL_KERNEL(gelu_quick); + LM_GGML_METAL_DECL_KERNEL(silu); LM_GGML_METAL_DECL_KERNEL(soft_max); LM_GGML_METAL_DECL_KERNEL(soft_max_4); LM_GGML_METAL_DECL_KERNEL(diag_mask_inf); @@ -86,6 +88,7 @@ LM_GGML_METAL_DECL_KERNEL(get_rows_q5_K); LM_GGML_METAL_DECL_KERNEL(get_rows_q6_K); LM_GGML_METAL_DECL_KERNEL(rms_norm); + LM_GGML_METAL_DECL_KERNEL(group_norm); LM_GGML_METAL_DECL_KERNEL(norm); LM_GGML_METAL_DECL_KERNEL(mul_mv_f32_f32); LM_GGML_METAL_DECL_KERNEL(mul_mv_f16_f16); @@ -102,6 +105,21 @@ LM_GGML_METAL_DECL_KERNEL(mul_mv_q4_K_f32); LM_GGML_METAL_DECL_KERNEL(mul_mv_q5_K_f32); LM_GGML_METAL_DECL_KERNEL(mul_mv_q6_K_f32); + LM_GGML_METAL_DECL_KERNEL(mul_mv_id_f32_f32); + //LM_GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f16); + LM_GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32); + //LM_GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32_1row); + //LM_GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32_l4); + LM_GGML_METAL_DECL_KERNEL(mul_mv_id_q4_0_f32); + LM_GGML_METAL_DECL_KERNEL(mul_mv_id_q4_1_f32); + LM_GGML_METAL_DECL_KERNEL(mul_mv_id_q5_0_f32); + LM_GGML_METAL_DECL_KERNEL(mul_mv_id_q5_1_f32); + LM_GGML_METAL_DECL_KERNEL(mul_mv_id_q8_0_f32); + LM_GGML_METAL_DECL_KERNEL(mul_mv_id_q2_K_f32); + LM_GGML_METAL_DECL_KERNEL(mul_mv_id_q3_K_f32); + LM_GGML_METAL_DECL_KERNEL(mul_mv_id_q4_K_f32); + LM_GGML_METAL_DECL_KERNEL(mul_mv_id_q5_K_f32); + LM_GGML_METAL_DECL_KERNEL(mul_mv_id_q6_K_f32); LM_GGML_METAL_DECL_KERNEL(mul_mm_f32_f32); LM_GGML_METAL_DECL_KERNEL(mul_mm_f16_f32); LM_GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32); @@ -130,8 +148,11 @@ LM_GGML_METAL_DECL_KERNEL(rope_f16); LM_GGML_METAL_DECL_KERNEL(alibi_f32); LM_GGML_METAL_DECL_KERNEL(im2col_f16); + LM_GGML_METAL_DECL_KERNEL(upscale_f32); + LM_GGML_METAL_DECL_KERNEL(pad_f32); LM_GGML_METAL_DECL_KERNEL(argsort_f32_i32_asc); LM_GGML_METAL_DECL_KERNEL(argsort_f32_i32_desc); + LM_GGML_METAL_DECL_KERNEL(leaky_relu_f32); LM_GGML_METAL_DECL_KERNEL(cpy_f32_f16); LM_GGML_METAL_DECL_KERNEL(cpy_f32_f32); LM_GGML_METAL_DECL_KERNEL(cpy_f32_q8_0); @@ -140,6 +161,7 @@ //LM_GGML_METAL_DECL_KERNEL(cpy_f32_q5_0); //LM_GGML_METAL_DECL_KERNEL(cpy_f32_q5_1); LM_GGML_METAL_DECL_KERNEL(cpy_f16_f16); + LM_GGML_METAL_DECL_KERNEL(cpy_f16_f32); LM_GGML_METAL_DECL_KERNEL(concat); 
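Before that selection, every one of these kernels also unpacks tgpig.z into a batch index bid (used to step through the rows of the ids tensor via nbi1) and the usual ne12*ne13 broadcast index. A minimal standalone sketch of that div/mod unpacking, with arbitrary example extents:

    #include <stdint.h>
    #include <stdio.h>

    // Sketch of the tgpig.z unpacking used by the mul_mv_id kernels; values are illustrative.
    int main(void) {
        const int64_t ne12 = 2, ne13 = 3;           // assumed broadcast extents
        for (int64_t z = 0; z < 2*ne12*ne13; ++z) {
            const int64_t bid  = z/(ne12*ne13);     // which row of the ids tensor
            const int64_t rest = z%(ne12*ne13);     // index within the regular ne12*ne13 grid
            printf("z=%2lld -> bid=%lld rest=%lld\n",
                   (long long) z, (long long) bid, (long long) rest);
        }
        return 0;
    }
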
LM_GGML_METAL_DECL_KERNEL(sqr); LM_GGML_METAL_DECL_KERNEL(sum_rows); @@ -177,6 +199,8 @@ static void lm_ggml_metal_log(enum lm_ggml_log_level level, const char * format, lm_ggml_metal_log_callback(level, buffer, lm_ggml_metal_log_user_data); } else { char* buffer2 = malloc(len+1); + va_end(args); + va_start(args, format); vsnprintf(buffer2, len+1, format, args); buffer2[len] = 0; lm_ggml_metal_log_callback(level, buffer2, lm_ggml_metal_log_user_data); @@ -316,9 +340,11 @@ static void lm_ggml_metal_log(enum lm_ggml_log_level level, const char * format, LM_GGML_METAL_ADD_KERNEL(div_row); LM_GGML_METAL_ADD_KERNEL(scale); LM_GGML_METAL_ADD_KERNEL(scale_4); - LM_GGML_METAL_ADD_KERNEL(silu); + LM_GGML_METAL_ADD_KERNEL(tanh); LM_GGML_METAL_ADD_KERNEL(relu); LM_GGML_METAL_ADD_KERNEL(gelu); + LM_GGML_METAL_ADD_KERNEL(gelu_quick); + LM_GGML_METAL_ADD_KERNEL(silu); LM_GGML_METAL_ADD_KERNEL(soft_max); LM_GGML_METAL_ADD_KERNEL(soft_max_4); LM_GGML_METAL_ADD_KERNEL(diag_mask_inf); @@ -336,6 +362,7 @@ static void lm_ggml_metal_log(enum lm_ggml_log_level level, const char * format, LM_GGML_METAL_ADD_KERNEL(get_rows_q5_K); LM_GGML_METAL_ADD_KERNEL(get_rows_q6_K); LM_GGML_METAL_ADD_KERNEL(rms_norm); + LM_GGML_METAL_ADD_KERNEL(group_norm); LM_GGML_METAL_ADD_KERNEL(norm); LM_GGML_METAL_ADD_KERNEL(mul_mv_f32_f32); LM_GGML_METAL_ADD_KERNEL(mul_mv_f16_f16); @@ -352,6 +379,21 @@ static void lm_ggml_metal_log(enum lm_ggml_log_level level, const char * format, LM_GGML_METAL_ADD_KERNEL(mul_mv_q4_K_f32); LM_GGML_METAL_ADD_KERNEL(mul_mv_q5_K_f32); LM_GGML_METAL_ADD_KERNEL(mul_mv_q6_K_f32); + LM_GGML_METAL_ADD_KERNEL(mul_mv_id_f32_f32); + //LM_GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f16); + LM_GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32); + //LM_GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32_1row); + //LM_GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32_l4); + LM_GGML_METAL_ADD_KERNEL(mul_mv_id_q4_0_f32); + LM_GGML_METAL_ADD_KERNEL(mul_mv_id_q4_1_f32); + LM_GGML_METAL_ADD_KERNEL(mul_mv_id_q5_0_f32); + LM_GGML_METAL_ADD_KERNEL(mul_mv_id_q5_1_f32); + LM_GGML_METAL_ADD_KERNEL(mul_mv_id_q8_0_f32); + LM_GGML_METAL_ADD_KERNEL(mul_mv_id_q2_K_f32); + LM_GGML_METAL_ADD_KERNEL(mul_mv_id_q3_K_f32); + LM_GGML_METAL_ADD_KERNEL(mul_mv_id_q4_K_f32); + LM_GGML_METAL_ADD_KERNEL(mul_mv_id_q5_K_f32); + LM_GGML_METAL_ADD_KERNEL(mul_mv_id_q6_K_f32); if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) { LM_GGML_METAL_ADD_KERNEL(mul_mm_f32_f32); LM_GGML_METAL_ADD_KERNEL(mul_mm_f16_f32); @@ -382,8 +424,11 @@ static void lm_ggml_metal_log(enum lm_ggml_log_level level, const char * format, LM_GGML_METAL_ADD_KERNEL(rope_f16); LM_GGML_METAL_ADD_KERNEL(alibi_f32); LM_GGML_METAL_ADD_KERNEL(im2col_f16); + LM_GGML_METAL_ADD_KERNEL(upscale_f32); + LM_GGML_METAL_ADD_KERNEL(pad_f32); LM_GGML_METAL_ADD_KERNEL(argsort_f32_i32_asc); LM_GGML_METAL_ADD_KERNEL(argsort_f32_i32_desc); + LM_GGML_METAL_ADD_KERNEL(leaky_relu_f32); LM_GGML_METAL_ADD_KERNEL(cpy_f32_f16); LM_GGML_METAL_ADD_KERNEL(cpy_f32_f32); LM_GGML_METAL_ADD_KERNEL(cpy_f32_q8_0); @@ -392,6 +437,7 @@ static void lm_ggml_metal_log(enum lm_ggml_log_level level, const char * format, //LM_GGML_METAL_ADD_KERNEL(cpy_f32_q5_0); //LM_GGML_METAL_ADD_KERNEL(cpy_f32_q5_1); LM_GGML_METAL_ADD_KERNEL(cpy_f16_f16); + LM_GGML_METAL_ADD_KERNEL(cpy_f16_f32); LM_GGML_METAL_ADD_KERNEL(concat); LM_GGML_METAL_ADD_KERNEL(sqr); LM_GGML_METAL_ADD_KERNEL(sum_rows); @@ -416,9 +462,11 @@ void lm_ggml_metal_free(struct lm_ggml_metal_context * ctx) { LM_GGML_METAL_DEL_KERNEL(div_row); LM_GGML_METAL_DEL_KERNEL(scale); 
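The two lines added to lm_ggml_metal_log above fix a genuine bug: once vsnprintf has consumed a va_list, reusing it without re-initialization is undefined behavior, so the patch calls va_end/va_start before formatting a second time. The equivalent general-purpose pattern keeps a second copy with va_copy; a self-contained sketch (hypothetical helper name) of measuring and then formatting with one:

    #include <stdarg.h>
    #include <stdio.h>
    #include <stdlib.h>

    // Standalone illustration, not from the patch: the patch restarts the list with
    // va_end/va_start; va_copy is the portable way to keep an unconsumed second copy.
    static char * vformat_alloc(const char * fmt, va_list args) {
        va_list args2;
        va_copy(args2, args);                           // fresh copy for the second pass
        const int len = vsnprintf(NULL, 0, fmt, args);  // first pass consumes 'args'
        if (len < 0) {
            va_end(args2);
            return NULL;
        }
        char * buf = malloc((size_t) len + 1);
        if (buf != NULL) {
            vsnprintf(buf, (size_t) len + 1, fmt, args2); // second pass uses the copy
        }
        va_end(args2);
        return buf;
    }
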
LM_GGML_METAL_DEL_KERNEL(scale_4); - LM_GGML_METAL_DEL_KERNEL(silu); + LM_GGML_METAL_DEL_KERNEL(tanh); LM_GGML_METAL_DEL_KERNEL(relu); LM_GGML_METAL_DEL_KERNEL(gelu); + LM_GGML_METAL_DEL_KERNEL(gelu_quick); + LM_GGML_METAL_DEL_KERNEL(silu); LM_GGML_METAL_DEL_KERNEL(soft_max); LM_GGML_METAL_DEL_KERNEL(soft_max_4); LM_GGML_METAL_DEL_KERNEL(diag_mask_inf); @@ -436,6 +484,7 @@ void lm_ggml_metal_free(struct lm_ggml_metal_context * ctx) { LM_GGML_METAL_DEL_KERNEL(get_rows_q5_K); LM_GGML_METAL_DEL_KERNEL(get_rows_q6_K); LM_GGML_METAL_DEL_KERNEL(rms_norm); + LM_GGML_METAL_DEL_KERNEL(group_norm); LM_GGML_METAL_DEL_KERNEL(norm); LM_GGML_METAL_DEL_KERNEL(mul_mv_f32_f32); LM_GGML_METAL_DEL_KERNEL(mul_mv_f16_f16); @@ -452,6 +501,21 @@ void lm_ggml_metal_free(struct lm_ggml_metal_context * ctx) { LM_GGML_METAL_DEL_KERNEL(mul_mv_q4_K_f32); LM_GGML_METAL_DEL_KERNEL(mul_mv_q5_K_f32); LM_GGML_METAL_DEL_KERNEL(mul_mv_q6_K_f32); + LM_GGML_METAL_DEL_KERNEL(mul_mv_id_f32_f32); + //LM_GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f16); + LM_GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32); + //LM_GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32_1row); + //LM_GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32_l4); + LM_GGML_METAL_DEL_KERNEL(mul_mv_id_q4_0_f32); + LM_GGML_METAL_DEL_KERNEL(mul_mv_id_q4_1_f32); + LM_GGML_METAL_DEL_KERNEL(mul_mv_id_q5_0_f32); + LM_GGML_METAL_DEL_KERNEL(mul_mv_id_q5_1_f32); + LM_GGML_METAL_DEL_KERNEL(mul_mv_id_q8_0_f32); + LM_GGML_METAL_DEL_KERNEL(mul_mv_id_q2_K_f32); + LM_GGML_METAL_DEL_KERNEL(mul_mv_id_q3_K_f32); + LM_GGML_METAL_DEL_KERNEL(mul_mv_id_q4_K_f32); + LM_GGML_METAL_DEL_KERNEL(mul_mv_id_q5_K_f32); + LM_GGML_METAL_DEL_KERNEL(mul_mv_id_q6_K_f32); if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) { LM_GGML_METAL_DEL_KERNEL(mul_mm_f32_f32); LM_GGML_METAL_DEL_KERNEL(mul_mm_f16_f32); @@ -482,8 +546,11 @@ void lm_ggml_metal_free(struct lm_ggml_metal_context * ctx) { LM_GGML_METAL_DEL_KERNEL(rope_f16); LM_GGML_METAL_DEL_KERNEL(alibi_f32); LM_GGML_METAL_DEL_KERNEL(im2col_f16); + LM_GGML_METAL_DEL_KERNEL(upscale_f32); + LM_GGML_METAL_DEL_KERNEL(pad_f32); LM_GGML_METAL_DEL_KERNEL(argsort_f32_i32_asc); LM_GGML_METAL_DEL_KERNEL(argsort_f32_i32_desc); + LM_GGML_METAL_DEL_KERNEL(leaky_relu_f32); LM_GGML_METAL_DEL_KERNEL(cpy_f32_f16); LM_GGML_METAL_DEL_KERNEL(cpy_f32_f32); LM_GGML_METAL_DEL_KERNEL(cpy_f32_q8_0); @@ -492,6 +559,7 @@ void lm_ggml_metal_free(struct lm_ggml_metal_context * ctx) { //LM_GGML_METAL_DEL_KERNEL(cpy_f32_q5_0); //LM_GGML_METAL_DEL_KERNEL(cpy_f32_q5_1); LM_GGML_METAL_DEL_KERNEL(cpy_f16_f16); + LM_GGML_METAL_DEL_KERNEL(cpy_f16_f32); LM_GGML_METAL_DEL_KERNEL(concat); LM_GGML_METAL_DEL_KERNEL(sqr); LM_GGML_METAL_DEL_KERNEL(sum_rows); @@ -793,9 +861,11 @@ static bool lm_ggml_metal_supports_op(const struct lm_ggml_tensor * op) { switch (op->op) { case LM_GGML_OP_UNARY: switch (lm_ggml_get_unary_op(op)) { - case LM_GGML_UNARY_OP_SILU: + case LM_GGML_UNARY_OP_TANH: case LM_GGML_UNARY_OP_RELU: case LM_GGML_UNARY_OP_GELU: + case LM_GGML_UNARY_OP_GELU_QUICK: + case LM_GGML_UNARY_OP_SILU: return true; default: return false; @@ -807,6 +877,7 @@ static bool lm_ggml_metal_supports_op(const struct lm_ggml_tensor * op) { case LM_GGML_OP_PERMUTE: case LM_GGML_OP_CONCAT: case LM_GGML_OP_ADD: + case LM_GGML_OP_ACC: case LM_GGML_OP_MUL: case LM_GGML_OP_DIV: case LM_GGML_OP_SCALE: @@ -814,21 +885,50 @@ static bool lm_ggml_metal_supports_op(const struct lm_ggml_tensor * op) { case LM_GGML_OP_SUM_ROWS: case LM_GGML_OP_SOFT_MAX: case LM_GGML_OP_RMS_NORM: + case LM_GGML_OP_GROUP_NORM: case LM_GGML_OP_NORM: case 
LM_GGML_OP_ALIBI: case LM_GGML_OP_ROPE: case LM_GGML_OP_IM2COL: + case LM_GGML_OP_UPSCALE: + case LM_GGML_OP_PAD: case LM_GGML_OP_ARGSORT: - case LM_GGML_OP_DUP: - case LM_GGML_OP_CPY: - case LM_GGML_OP_CONT: + case LM_GGML_OP_LEAKY_RELU: case LM_GGML_OP_MUL_MAT: case LM_GGML_OP_MUL_MAT_ID: return true; + case LM_GGML_OP_CPY: + case LM_GGML_OP_DUP: + case LM_GGML_OP_CONT: + { + switch (op->src[0]->type) { + case LM_GGML_TYPE_F32: + switch (op->type) { + case LM_GGML_TYPE_F16: + case LM_GGML_TYPE_F32: + case LM_GGML_TYPE_Q8_0: + case LM_GGML_TYPE_Q4_0: + case LM_GGML_TYPE_Q4_1: + return true; + default: + return false; + } + case LM_GGML_TYPE_F16: + switch (op->type) { + case LM_GGML_TYPE_F16: + case LM_GGML_TYPE_F32: + return true; + default: + return false; + } + default: + return false; + }; + } case LM_GGML_OP_DIAG_MASK_INF: case LM_GGML_OP_GET_ROWS: { - return op->ne[0] % 4 == 0; + return op->ne[3] == 1; } default: return false; @@ -904,7 +1004,10 @@ void lm_ggml_metal_graph_compute( } break; } - LM_GGML_ASSERT(lm_ggml_metal_supports_op(dst)); + if (!lm_ggml_metal_supports_op(dst)) { + LM_GGML_METAL_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, lm_ggml_op_desc(dst)); + LM_GGML_ASSERT(!"unsupported op"); + } const int64_t ne00 = src0 ? src0->ne[0] : 0; const int64_t ne01 = src0 ? src0->ne[1] : 0; @@ -1001,34 +1104,39 @@ void lm_ggml_metal_graph_compute( case LM_GGML_OP_MUL: case LM_GGML_OP_DIV: { - LM_GGML_ASSERT(lm_ggml_is_contiguous(src0)); - LM_GGML_ASSERT(lm_ggml_is_contiguous(src1)); + const size_t offs = 0; bool bcast_row = false; int64_t nb = ne00; - if (lm_ggml_nelements(src1) == ne10 && ne00 % 4 == 0) { + id pipeline = nil; + + if (lm_ggml_nelements(src1) == ne10 && lm_ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) { + LM_GGML_ASSERT(lm_ggml_is_contiguous(src0)); + // src1 is a row LM_GGML_ASSERT(ne11 == 1); nb = ne00 / 4; switch (dst->op) { - case LM_GGML_OP_ADD: [encoder setComputePipelineState:ctx->pipeline_add_row]; break; - case LM_GGML_OP_MUL: [encoder setComputePipelineState:ctx->pipeline_mul_row]; break; - case LM_GGML_OP_DIV: [encoder setComputePipelineState:ctx->pipeline_div_row]; break; + case LM_GGML_OP_ADD: pipeline = ctx->pipeline_add_row; break; + case LM_GGML_OP_MUL: pipeline = ctx->pipeline_mul_row; break; + case LM_GGML_OP_DIV: pipeline = ctx->pipeline_div_row; break; default: LM_GGML_ASSERT(false); } bcast_row = true; } else { switch (dst->op) { - case LM_GGML_OP_ADD: [encoder setComputePipelineState:ctx->pipeline_add]; break; - case LM_GGML_OP_MUL: [encoder setComputePipelineState:ctx->pipeline_mul]; break; - case LM_GGML_OP_DIV: [encoder setComputePipelineState:ctx->pipeline_div]; break; + case LM_GGML_OP_ADD: pipeline = ctx->pipeline_add; break; + case LM_GGML_OP_MUL: pipeline = ctx->pipeline_mul; break; + case LM_GGML_OP_DIV: pipeline = ctx->pipeline_div; break; default: LM_GGML_ASSERT(false); } } + + [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; @@ -1056,18 +1164,99 @@ void lm_ggml_metal_graph_compute( [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24]; [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25]; [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26]; - [encoder setBytes:&nb length:sizeof(nb) atIndex:27]; + [encoder setBytes:&offs length:sizeof(offs) atIndex:27]; + [encoder setBytes:&nb length:sizeof(nb) atIndex:28]; if (bcast_row) { const int64_t n = 
lm_ggml_nelements(dst)/4; [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } else { - const int nth = MIN(1024, ne0); + const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0); [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } } break; + case LM_GGML_OP_ACC: + { + LM_GGML_ASSERT(src0t == LM_GGML_TYPE_F32); + LM_GGML_ASSERT(src1t == LM_GGML_TYPE_F32); + LM_GGML_ASSERT(dstt == LM_GGML_TYPE_F32); + + LM_GGML_ASSERT(lm_ggml_is_contiguous(src0)); + LM_GGML_ASSERT(lm_ggml_is_contiguous(src1)); + + const size_t pnb1 = ((int32_t *) dst->op_params)[0]; + const size_t pnb2 = ((int32_t *) dst->op_params)[1]; + const size_t pnb3 = ((int32_t *) dst->op_params)[2]; + const size_t offs = ((int32_t *) dst->op_params)[3]; + + const bool inplace = (bool) ((int32_t *) dst->op_params)[4]; + + if (!inplace) { + // run a separete kernel to cpy src->dst + // not sure how to avoid this + // TODO: make a simpler cpy_bytes kernel + + const int nth = MIN(1024, ne00); + + [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; + [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5]; + [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8]; + [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9]; + [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10]; + [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11]; + [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12]; + [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13]; + [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14]; + [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15]; + [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; + [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; + + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } + + [encoder setComputePipelineState:ctx->pipeline_add]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; + [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7]; + [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:8]; + [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:9]; + [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:10]; + [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11]; + [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12]; + [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13]; + [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17]; + [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18]; + [encoder setBytes:&ne0 
length:sizeof(ne0) atIndex:19]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20]; + [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21]; + [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22]; + [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23]; + [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:24]; + [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:25]; + [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26]; + [encoder setBytes:&offs length:sizeof(offs) atIndex:27]; + + const int nth = MIN(1024, ne0); + + [encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; case LM_GGML_OP_SCALE: { LM_GGML_ASSERT(lm_ggml_is_contiguous(src0)); @@ -1091,16 +1280,15 @@ void lm_ggml_metal_graph_compute( } break; case LM_GGML_OP_UNARY: switch (lm_ggml_get_unary_op(gf->nodes[i])) { - case LM_GGML_UNARY_OP_SILU: + case LM_GGML_UNARY_OP_TANH: { - [encoder setComputePipelineState:ctx->pipeline_silu]; + [encoder setComputePipelineState:ctx->pipeline_tanh]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; const int64_t n = lm_ggml_nelements(dst); - LM_GGML_ASSERT(n % 4 == 0); - [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; case LM_GGML_UNARY_OP_RELU: { @@ -1121,6 +1309,28 @@ void lm_ggml_metal_graph_compute( const int64_t n = lm_ggml_nelements(dst); LM_GGML_ASSERT(n % 4 == 0); + [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case LM_GGML_UNARY_OP_GELU_QUICK: + { + [encoder setComputePipelineState:ctx->pipeline_gelu_quick]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = lm_ggml_nelements(dst); + LM_GGML_ASSERT(n % 4 == 0); + + [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case LM_GGML_UNARY_OP_SILU: + { + [encoder setComputePipelineState:ctx->pipeline_silu]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = lm_ggml_nelements(dst); + LM_GGML_ASSERT(n % 4 == 0); + [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; default: @@ -1193,7 +1403,11 @@ void lm_ggml_metal_graph_compute( const float scale = ((float *) dst->op_params)[0]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + if (id_src1) { + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + } else { + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + } [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; @@ -1444,7 +1658,7 @@ void lm_ggml_metal_graph_compute( else if (src0t == LM_GGML_TYPE_Q6_K) { [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else { - int64_t ny = (ne11 + nrows - 1)/nrows; + const int64_t ny = (ne11 + nrows - 1)/nrows; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } } @@ -1456,7 +1670,7 @@ void lm_ggml_metal_graph_compute( LM_GGML_ASSERT(src0t == 
LM_GGML_TYPE_I32); - const int n_as = ne00; + const int n_as = ((int32_t *) dst->op_params)[1]; // TODO: make this more general LM_GGML_ASSERT(n_as <= 8); @@ -1488,14 +1702,22 @@ void lm_ggml_metal_graph_compute( // find the break-even point where the matrix-matrix kernel becomes more efficient compared // to the matrix-vector kernel - int ne11_mm_min = 0; + int ne11_mm_min = 1; const int idx = ((int32_t *) dst->op_params)[0]; + // batch size + LM_GGML_ASSERT(ne01 == ne11); + + const int64_t _ne1 = 1; // kernel_mul_mm_impl needs a reference in constant memory + // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel - if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && - ne11 > ne11_mm_min) { + // !!! + // TODO: for now, always use mat-vec kernels until we figure out how to improve the + // indirect matrix multiplication + // !!! + if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && _ne1 > ne11_mm_min) { switch (src2->type) { case LM_GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f32_f32]; break; case LM_GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f16_f32]; break; @@ -1514,19 +1736,22 @@ void lm_ggml_metal_graph_compute( [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:3]; - [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:4]; - [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:5]; - [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:6]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12]; - [encoder setBytes:&r2 length:sizeof(r2) atIndex:13]; - [encoder setBytes:&r3 length:sizeof(r3) atIndex:14]; - [encoder setBytes:&idx length:sizeof(idx) atIndex:15]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3]; + [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4]; + [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:5]; + [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6]; + [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:7]; + [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:8]; + [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:9]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13]; + [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:14]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15]; + [encoder setBytes:&r2 length:sizeof(r2) atIndex:16]; + [encoder setBytes:&r3 length:sizeof(r3) atIndex:17]; + [encoder setBytes:&idx length:sizeof(idx) atIndex:18]; // TODO: how to make this an array? 
read Metal docs for (int j = 0; j < n_as; ++j) { struct lm_ggml_tensor * src_cur = dst->src[2 + j]; @@ -1534,11 +1759,157 @@ void lm_ggml_metal_graph_compute( size_t offs_src_cur = 0; id id_src_cur = lm_ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur); - [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:16 + j]; + [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:19 + j]; } [encoder setThreadgroupMemoryLength:8192 atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne21 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; + + // TODO: processing one row at a time (ne11 -> 1) is not efficient + [encoder dispatchThreadgroups:MTLSizeMake( (_ne1 + 31)/32, (ne21 + 63)/64, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; + } else { + int nth0 = 32; + int nth1 = 1; + int nrows = 1; + //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); + + // use custom matrix x vector kernel + switch (src2t) { + case LM_GGML_TYPE_F32: + { + LM_GGML_ASSERT(src1t == LM_GGML_TYPE_F32); + [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_f32_f32]; + } break; + case LM_GGML_TYPE_F16: + { + LM_GGML_ASSERT(src1t == LM_GGML_TYPE_F32); + nth0 = 32; + nth1 = 1; + [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_f16_f32]; + } break; + case LM_GGML_TYPE_Q4_0: + { + nth0 = 8; + nth1 = 8; + [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_0_f32]; + } break; + case LM_GGML_TYPE_Q4_1: + { + nth0 = 8; + nth1 = 8; + [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_1_f32]; + } break; + case LM_GGML_TYPE_Q5_0: + { + nth0 = 8; + nth1 = 8; + [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_0_f32]; + } break; + case LM_GGML_TYPE_Q5_1: + { + nth0 = 8; + nth1 = 8; + [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_1_f32]; + } break; + case LM_GGML_TYPE_Q8_0: + { + nth0 = 8; + nth1 = 8; + [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q8_0_f32]; + } break; + case LM_GGML_TYPE_Q2_K: + { + nth0 = 2; + nth1 = 32; + [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q2_K_f32]; + } break; + case LM_GGML_TYPE_Q3_K: + { + nth0 = 2; + nth1 = 32; + [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q3_K_f32]; + } break; + case LM_GGML_TYPE_Q4_K: + { + nth0 = 4; //1; + nth1 = 8; //32; + [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_K_f32]; + } break; + case LM_GGML_TYPE_Q5_K: + { + nth0 = 2; + nth1 = 32; + [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_K_f32]; + } break; + case LM_GGML_TYPE_Q6_K: + { + nth0 = 2; + nth1 = 32; + [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q6_K_f32]; + } break; + default: + { + LM_GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t); + LM_GGML_ASSERT(false && "not implemented"); + } + }; + + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3]; + [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4]; + [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5]; + [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:6]; + [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:7]; + [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:8]; + [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:9]; + [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10]; + [encoder setBytes:&_ne1 length:sizeof(_ne1) 
atIndex:11]; + [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12]; + [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17]; + [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:18]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19]; + [encoder setBytes:&r2 length:sizeof(r2) atIndex:20]; + [encoder setBytes:&r3 length:sizeof(r3) atIndex:21]; + [encoder setBytes:&idx length:sizeof(idx) atIndex:22]; + // TODO: how to make this an array? read Metal docs + for (int j = 0; j < n_as; ++j) { + struct lm_ggml_tensor * src_cur = dst->src[2 + j]; + + size_t offs_src_cur = 0; + id id_src_cur = lm_ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur); + + [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:23 + j]; + } + + if (src2t == LM_GGML_TYPE_Q4_0 || src2t == LM_GGML_TYPE_Q4_1 || + src2t == LM_GGML_TYPE_Q5_0 || src2t == LM_GGML_TYPE_Q5_1 || src2t == LM_GGML_TYPE_Q8_0 || + src2t == LM_GGML_TYPE_Q2_K) { // || src2t == LM_GGML_TYPE_Q4_K) { + [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + else if (src2t == LM_GGML_TYPE_Q4_K) { + [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + else if (src2t == LM_GGML_TYPE_Q3_K) { +#ifdef LM_GGML_QKK_64 + [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; +#else + [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; +#endif + } + else if (src2t == LM_GGML_TYPE_Q5_K) { + [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + else if (src2t == LM_GGML_TYPE_Q6_K) { + [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } else { + const int64_t ny = (_ne1 + nrows - 1)/nrows; + [encoder dispatchThreadgroups:MTLSizeMake(ne21, ny, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } } } break; case LM_GGML_OP_GET_ROWS: @@ -1559,16 +1930,19 @@ void lm_ggml_metal_graph_compute( default: LM_GGML_ASSERT(false && "not implemented"); } - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3]; [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4]; - [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:5]; - - const int64_t n = lm_ggml_nelements(src1); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5]; + [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6]; + [encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7]; + [encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8]; + [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:9]; + [encoder setBytes:&nb2 
length:sizeof(uint64_t) atIndex:10]; + + [encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)]; } break; case LM_GGML_OP_RMS_NORM: { @@ -1595,6 +1969,38 @@ void lm_ggml_metal_graph_compute( [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } break; + case LM_GGML_OP_GROUP_NORM: + { + LM_GGML_ASSERT(ne00 % 4 == 0); + + //float eps; + //memcpy(&eps, dst->op_params, sizeof(float)); + + const float eps = 1e-6f; // TODO: temporarily hardcoded + + const int32_t n_groups = ((int32_t *) dst->op_params)[0]; + + int nth = 32; // SIMD width + + //while (nth < ne00/4 && nth < 1024) { + // nth *= 2; + //} + + [encoder setComputePipelineState:ctx->pipeline_group_norm]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; + [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:5]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:6]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:7]; + [encoder setBytes:&n_groups length:sizeof( int32_t) atIndex:8]; + [encoder setBytes:&eps length:sizeof( float) atIndex:9]; + [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; + + [encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; case LM_GGML_OP_NORM: { float eps; @@ -1764,6 +2170,65 @@ void lm_ggml_metal_graph_compute( [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)]; } break; + case LM_GGML_OP_UPSCALE: + { + LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_F32); + + const int sf = dst->op_params[0]; + + [encoder setComputePipelineState:ctx->pipeline_upscale_f32]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; + [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; + [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11]; + [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12]; + [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13]; + [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15]; + [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16]; + [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17]; + [encoder setBytes:&sf length:sizeof(sf) atIndex:18]; + + const int nth = MIN(1024, ne0); + + [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case LM_GGML_OP_PAD: + { + LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_F32); + + [encoder setComputePipelineState:ctx->pipeline_pad_f32]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof(ne01) 
atIndex:3]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; + [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; + [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11]; + [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12]; + [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13]; + [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15]; + [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16]; + [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17]; + + const int nth = MIN(1024, ne0); + + [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; case LM_GGML_OP_ARGSORT: { LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_F32); @@ -1785,6 +2250,22 @@ void lm_ggml_metal_graph_compute( [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00, 1, 1)]; } break; + case LM_GGML_OP_LEAKY_RELU: + { + LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_F32); + + float slope; + memcpy(&slope, dst->op_params, sizeof(float)); + + [encoder setComputePipelineState:ctx->pipeline_leaky_relu_f32]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&slope length:sizeof(slope) atIndex:2]; + + const int64_t n = lm_ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; case LM_GGML_OP_DUP: case LM_GGML_OP_CPY: case LM_GGML_OP_CONT: @@ -1813,7 +2294,7 @@ void lm_ggml_metal_graph_compute( { switch (dstt) { case LM_GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f16]; break; - case LM_GGML_TYPE_F32: LM_GGML_ASSERT(false && "cpy_f16_f32 not implemented"); break; + case LM_GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f32]; break; default: LM_GGML_ASSERT(false && "not implemented"); }; } break; diff --git a/cpp/ggml-quants.c b/cpp/ggml-quants.c index 30bebbf..260fca0 100644 --- a/cpp/ggml-quants.c +++ b/cpp/ggml-quants.c @@ -3114,7 +3114,7 @@ void lm_ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * res size_t vl = __riscv_vsetvl_e8m1(qk/2); - // These tempory registers are for masking and shift operations + // These temporary registers are for masking and shift operations vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl); vuint32m2_t vt_2 = __riscv_vsll_vv_u32m2(__riscv_vmv_v_x_u32m2(1, vl), vt_1, vl); @@ -4757,7 +4757,7 @@ void lm_ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * res vl = 16; - // retreive lane to multiply with scale + // retrieve lane to multiply with scale vint32m2_t aux0_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 0), (scale[0]), vl); vint32m2_t aux0_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 1), (scale[1]), vl); vint32m2_t aux1_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 0), (scale[2]), vl); diff --git a/cpp/ggml.c b/cpp/ggml.c index 2dda9e7..dbc6be8 100644 --- a/cpp/ggml.c +++ b/cpp/ggml.c @@ -1,4 +1,4 @@ -#define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnigns on Windows +#define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnings on Windows #define 
_USE_MATH_DEFINES // For M_PI on MSVC #include "ggml-impl.h" @@ -33,7 +33,7 @@ // we should just be careful :) #pragma warning(disable: 4244 4267) -// disable POSIX deprecation warnigns +// disable POSIX deprecation warnings // these functions are never going away, anyway #pragma warning(disable: 4996) #endif @@ -1395,7 +1395,7 @@ inline static void lm_ggml_vec_step_f32 (const int n, float * y, const float * x inline static void lm_ggml_vec_tanh_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = tanhf(x[i]); } inline static void lm_ggml_vec_elu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : expf(x[i])-1; } inline static void lm_ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; } -inline static void lm_ggml_vec_leaky_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.1f*x[i]; } +inline static void lm_ggml_vec_leaky_relu_f32 (const int n, float * y, const float * x, const float ns) { for (int i = 0; i < n; ++i) y[i] = ((x[i] > 0.f) ? x[i] : 0.f) + ns * ((x[i] < 0.0f) ? x[i] : 0.f); } static const float GELU_COEF_A = 0.044715f; static const float GELU_QUICK_COEF = -1.702f; @@ -1623,7 +1623,9 @@ static const char * LM_GGML_OP_NAME[LM_GGML_OP_COUNT] = { "POOL_1D", "POOL_2D", "UPSCALE", + "PAD", "ARGSORT", + "LEAKY_RELU", "FLASH_ATTN", "FLASH_FF", @@ -1650,7 +1652,7 @@ static const char * LM_GGML_OP_NAME[LM_GGML_OP_COUNT] = { "CROSS_ENTROPY_LOSS_BACK", }; -static_assert(LM_GGML_OP_COUNT == 70, "LM_GGML_OP_COUNT != 70"); +static_assert(LM_GGML_OP_COUNT == 72, "LM_GGML_OP_COUNT != 72"); static const char * LM_GGML_OP_SYMBOL[LM_GGML_OP_COUNT] = { "none", @@ -1707,7 +1709,9 @@ static const char * LM_GGML_OP_SYMBOL[LM_GGML_OP_COUNT] = { "pool_1d(x)", "pool_2d(x)", "upscale(x)", + "pad(x)", "argsort(x)", + "leaky_relu(x)", "flash_attn(x)", "flash_ff(x)", @@ -1734,7 +1738,7 @@ static const char * LM_GGML_OP_SYMBOL[LM_GGML_OP_COUNT] = { "cross_entropy_loss_back(x,y)", }; -static_assert(LM_GGML_OP_COUNT == 70, "LM_GGML_OP_COUNT != 70"); +static_assert(LM_GGML_OP_COUNT == 72, "LM_GGML_OP_COUNT != 72"); static_assert(LM_GGML_OP_POOL_COUNT == 2, "LM_GGML_OP_POOL_COUNT != 2"); @@ -1750,17 +1754,16 @@ static const char * LM_GGML_UNARY_OP_NAME[LM_GGML_UNARY_OP_COUNT] = { "GELU", "GELU_QUICK", "SILU", - "LEAKY", }; -static_assert(LM_GGML_UNARY_OP_COUNT == 11, "LM_GGML_UNARY_OP_COUNT != 11"); +static_assert(LM_GGML_UNARY_OP_COUNT == 10, "LM_GGML_UNARY_OP_COUNT != 10"); static_assert(sizeof(struct lm_ggml_object)%LM_GGML_MEM_ALIGN == 0, "lm_ggml_object size must be a multiple of LM_GGML_MEM_ALIGN"); static_assert(sizeof(struct lm_ggml_tensor)%LM_GGML_MEM_ALIGN == 0, "lm_ggml_tensor size must be a multiple of LM_GGML_MEM_ALIGN"); // WARN: -// Mis-confguration can lead to problem that's hard to reason about: +// Mis-configuration can lead to problem that's hard to reason about: // * At best it crash or talks nosense. // * At worst it talks slightly difference but hard to perceive. 
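The reworked lm_ggml_vec_leaky_relu_f32 above takes the negative slope as a parameter instead of the previously hard-coded 0.1f, computing max(0,x) + ns*min(0,x) per element. A tiny self-contained check of that formula, with arbitrarily chosen inputs:

    #include <stdio.h>

    // Mirrors the per-element formula of the patched lm_ggml_vec_leaky_relu_f32;
    // standalone example, not part of the patch.
    static float leaky_relu(float x, float ns) {
        return ((x > 0.0f) ? x : 0.0f) + ns*((x < 0.0f) ? x : 0.0f);
    }

    int main(void) {
        const float xs[4] = { 2.0f, 0.5f, 0.0f, -4.0f };
        for (int i = 0; i < 4; ++i) {
            // with ns = 0.1f: 2.0 -> 2.00, 0.5 -> 0.50, 0.0 -> 0.00, -4.0 -> -0.40
            printf("% .1f -> % .2f\n", xs[i], leaky_relu(xs[i], 0.1f));
        }
        return 0;
    }
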
// @@ -1994,12 +1997,6 @@ size_t lm_ggml_nbytes_pad(const struct lm_ggml_tensor * tensor) { return LM_GGML_PAD(lm_ggml_nbytes(tensor), LM_GGML_MEM_ALIGN); } -size_t lm_ggml_nbytes_split(const struct lm_ggml_tensor * tensor, int nrows_split) { - static_assert(LM_GGML_MAX_DIMS == 4, "LM_GGML_MAX_DIMS is not 4 - update this function"); - - return (nrows_split*tensor->ne[0]*lm_ggml_type_size(tensor->type))/lm_ggml_blck_size(tensor->type); -} - int lm_ggml_blck_size(enum lm_ggml_type type) { return type_traits[type].blck_size; } @@ -2008,8 +2005,13 @@ size_t lm_ggml_type_size(enum lm_ggml_type type) { return type_traits[type].type_size; } -float lm_ggml_type_sizef(enum lm_ggml_type type) { - return ((float)(type_traits[type].type_size))/type_traits[type].blck_size; +size_t lm_ggml_row_size(enum lm_ggml_type type, int64_t ne) { + assert(ne % lm_ggml_blck_size(type) == 0); + return lm_ggml_type_size(type)*ne/lm_ggml_blck_size(type); +} + +double lm_ggml_type_sizef(enum lm_ggml_type type) { + return ((double)(type_traits[type].type_size))/type_traits[type].blck_size; } const char * lm_ggml_type_name(enum lm_ggml_type type) { @@ -2046,24 +2048,37 @@ size_t lm_ggml_element_size(const struct lm_ggml_tensor * tensor) { return lm_ggml_type_size(tensor->type); } -static inline bool lm_ggml_is_scalar(const struct lm_ggml_tensor * tensor) { +bool lm_ggml_is_scalar(const struct lm_ggml_tensor * tensor) { static_assert(LM_GGML_MAX_DIMS == 4, "LM_GGML_MAX_DIMS is not 4 - update this function"); return tensor->ne[0] == 1 && tensor->ne[1] == 1 && tensor->ne[2] == 1 && tensor->ne[3] == 1; } -static inline bool lm_ggml_is_vector(const struct lm_ggml_tensor * tensor) { +bool lm_ggml_is_vector(const struct lm_ggml_tensor * tensor) { static_assert(LM_GGML_MAX_DIMS == 4, "LM_GGML_MAX_DIMS is not 4 - update this function"); return tensor->ne[1] == 1 && tensor->ne[2] == 1 && tensor->ne[3] == 1; } -static inline bool lm_ggml_is_matrix(const struct lm_ggml_tensor * tensor) { +bool lm_ggml_is_matrix(const struct lm_ggml_tensor * tensor) { static_assert(LM_GGML_MAX_DIMS == 4, "LM_GGML_MAX_DIMS is not 4 - update this function"); return tensor->ne[2] == 1 && tensor->ne[3] == 1; } +bool lm_ggml_is_3d(const struct lm_ggml_tensor * tensor) { + return tensor->ne[3] == 1; +} + +int lm_ggml_n_dims(const struct lm_ggml_tensor * tensor) { + for (int i = LM_GGML_MAX_DIMS - 1; i >= 1; --i) { + if (tensor->ne[i] > 1) { + return i + 1; + } + } + return 1; +} + static inline bool lm_ggml_can_mul_mat(const struct lm_ggml_tensor * t0, const struct lm_ggml_tensor * t1) { static_assert(LM_GGML_MAX_DIMS == 4, "LM_GGML_MAX_DIMS is not 4 - update this function"); @@ -2470,7 +2485,7 @@ static struct lm_ggml_tensor * lm_ggml_new_tensor_impl( view_src = view_src->view_src; } - size_t data_size = lm_ggml_type_size(type)*(ne[0]/lm_ggml_blck_size(type)); + size_t data_size = lm_ggml_row_size(type, ne[0]); for (int i = 1; i < n_dims; i++) { data_size *= ne[i]; } @@ -2513,7 +2528,6 @@ static struct lm_ggml_tensor * lm_ggml_new_tensor_impl( /*.type =*/ type, /*.backend =*/ LM_GGML_BACKEND_CPU, /*.buffer =*/ NULL, - /*.n_dims =*/ n_dims, /*.ne =*/ { 1, 1, 1, 1 }, /*.nb =*/ { 0, 0, 0, 0 }, /*.op =*/ LM_GGML_OP_NONE, @@ -2620,7 +2634,7 @@ struct lm_ggml_tensor * lm_ggml_new_f32(struct lm_ggml_context * ctx, float valu } struct lm_ggml_tensor * lm_ggml_dup_tensor(struct lm_ggml_context * ctx, const struct lm_ggml_tensor * src) { - return lm_ggml_new_tensor(ctx, src->type, src->n_dims, src->ne); + return lm_ggml_new_tensor(ctx, src->type, LM_GGML_MAX_DIMS, 
src->ne); } static void lm_ggml_set_op_params(struct lm_ggml_tensor * tensor, const void * params, size_t params_size) { @@ -3069,7 +3083,7 @@ struct lm_ggml_tensor * lm_ggml_format_name(struct lm_ggml_tensor * tensor, cons struct lm_ggml_tensor * lm_ggml_view_tensor( struct lm_ggml_context * ctx, struct lm_ggml_tensor * src) { - struct lm_ggml_tensor * result = lm_ggml_new_tensor_impl(ctx, src->type, src->n_dims, src->ne, src, 0); + struct lm_ggml_tensor * result = lm_ggml_new_tensor_impl(ctx, src->type, LM_GGML_MAX_DIMS, src->ne, src, 0); lm_ggml_format_name(result, "%s (view)", src->name); for (int i = 0; i < LM_GGML_MAX_DIMS; i++) { @@ -3227,10 +3241,10 @@ static struct lm_ggml_tensor * lm_ggml_add_cast_impl( is_node = true; } - struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, type, a->n_dims, a->ne); + struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, type, LM_GGML_MAX_DIMS, a->ne); result->op = LM_GGML_OP_ADD; - result->grad = is_node ? lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, a->n_dims, a->ne) : NULL; + result->grad = is_node ? lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, LM_GGML_MAX_DIMS, a->ne) : NULL; result->src[0] = a; result->src[1] = b; @@ -3599,12 +3613,12 @@ struct lm_ggml_tensor * lm_ggml_sum_rows( is_node = true; } - int64_t ne[4] = {1,1,1,1}; - for (int i=1; in_dims; ++i) { + int64_t ne[LM_GGML_MAX_DIMS] = { 1 }; + for (int i = 1; i < LM_GGML_MAX_DIMS; ++i) { ne[i] = a->ne[i]; } - struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, a->type, a->n_dims, ne); + struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, a->type, LM_GGML_MAX_DIMS, ne); result->op = LM_GGML_OP_SUM_ROWS; result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; @@ -3625,8 +3639,8 @@ struct lm_ggml_tensor * lm_ggml_mean( is_node = true; } - int64_t ne[LM_GGML_MAX_DIMS] = { 1, a->ne[1], a->ne[2], a->ne[3] }; - struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, a->n_dims, ne); + int64_t ne[4] = { 1, a->ne[1], a->ne[2], a->ne[3] }; + struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, 4, ne); result->op = LM_GGML_OP_MEAN; result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; @@ -3648,8 +3662,7 @@ struct lm_ggml_tensor * lm_ggml_argmax( is_node = true; } - int64_t ne[LM_GGML_MAX_DIMS] = { a->ne[1], 1, 1, 1 }; - struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_I32, a->n_dims, ne); + struct lm_ggml_tensor * result = lm_ggml_new_tensor_1d(ctx, LM_GGML_TYPE_I32, a->ne[1]); result->op = LM_GGML_OP_ARGMAX; result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; @@ -3672,7 +3685,7 @@ struct lm_ggml_tensor * lm_ggml_repeat( is_node = true; } - struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, a->type, b->n_dims, b->ne); + struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, a->type, LM_GGML_MAX_DIMS, b->ne); result->op = LM_GGML_OP_REPEAT; result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; @@ -3699,7 +3712,7 @@ struct lm_ggml_tensor * lm_ggml_repeat_back( return a; } - struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, a->type, b->n_dims, b->ne); + struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, a->type, LM_GGML_MAX_DIMS, b->ne); result->op = LM_GGML_OP_REPEAT_BACK; result->grad = is_node ? 
lm_ggml_dup_tensor(ctx, result) : NULL; @@ -3830,12 +3843,25 @@ struct lm_ggml_tensor * lm_ggml_relu_inplace( return lm_ggml_unary_inplace(ctx, a, LM_GGML_UNARY_OP_RELU); } -// lm_ggml_leaky +// lm_ggml_leaky_relu -struct lm_ggml_tensor * lm_ggml_leaky( +struct lm_ggml_tensor * lm_ggml_leaky_relu( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a) { - return lm_ggml_unary(ctx, a, LM_GGML_UNARY_OP_LEAKY); + struct lm_ggml_tensor * a, float negative_slope, bool inplace) { + bool is_node = false; + + if (!inplace && (a->grad)) { + is_node = true; + } + + struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); + lm_ggml_set_op_params(result, &negative_slope, sizeof(negative_slope)); + + result->op = LM_GGML_OP_LEAKY_RELU; + result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; } // lm_ggml_gelu @@ -4022,8 +4048,9 @@ static struct lm_ggml_tensor * lm_ggml_group_norm_impl( struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); - result->op = LM_GGML_OP_GROUP_NORM; result->op_params[0] = n_groups; + + result->op = LM_GGML_OP_GROUP_NORM; result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; result->src[1] = NULL; // TODO: maybe store epsilon here? @@ -4061,7 +4088,7 @@ struct lm_ggml_tensor * lm_ggml_mul_mat( } const int64_t ne[4] = { a->ne[1], b->ne[1], b->ne[2], b->ne[3] }; - struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, MAX(a->n_dims, b->n_dims), ne); + struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, 4, ne); result->op = LM_GGML_OP_MUL_MAT; result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; @@ -4071,21 +4098,30 @@ struct lm_ggml_tensor * lm_ggml_mul_mat( return result; } +void lm_ggml_mul_mat_set_prec( + struct lm_ggml_tensor * a, + enum lm_ggml_prec prec) { + const int32_t prec_i32 = (int32_t) prec; + + lm_ggml_set_op_params_i32(a, 0, prec_i32); +} + // lm_ggml_mul_mat_id struct lm_ggml_tensor * lm_ggml_mul_mat_id( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * as[], + struct lm_ggml_tensor * const as[], + int n_as, struct lm_ggml_tensor * ids, int id, struct lm_ggml_tensor * b) { - int64_t n_as = ids->ne[0]; - LM_GGML_ASSERT(ids->type == LM_GGML_TYPE_I32); - LM_GGML_ASSERT(lm_ggml_is_vector(ids)); + LM_GGML_ASSERT(ids->ne[2] == 1 && ids->ne[3] == 1); + LM_GGML_ASSERT(ids->ne[1] == b->ne[1]); + LM_GGML_ASSERT(ids->ne[2] == b->ne[2] && ids->ne[3] == b->ne[3]); LM_GGML_ASSERT(n_as > 0 && n_as <= LM_GGML_MAX_SRC - 2); - LM_GGML_ASSERT(id >= 0 && id < n_as); + LM_GGML_ASSERT(id >= 0 && id < ids->ne[0]); bool is_node = false; @@ -4094,16 +4130,17 @@ struct lm_ggml_tensor * lm_ggml_mul_mat_id( } const int64_t ne[4] = { as[0]->ne[1], b->ne[1], b->ne[2], b->ne[3] }; - struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, MAX(as[0]->n_dims, b->n_dims), ne); + struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, 4, ne); lm_ggml_set_op_params_i32(result, 0, id); + lm_ggml_set_op_params_i32(result, 1, n_as); result->op = LM_GGML_OP_MUL_MAT_ID; result->grad = is_node ? 
lm_ggml_dup_tensor(ctx, result) : NULL; result->src[0] = ids; result->src[1] = b; - for (int64_t i = 0; i < n_as; i++) { + for (int i = 0; i < n_as; i++) { struct lm_ggml_tensor * a = as[i]; LM_GGML_ASSERT(lm_ggml_are_same_shape(as[0], a)); LM_GGML_ASSERT(lm_ggml_can_mul_mat(a, b)); @@ -4131,7 +4168,7 @@ struct lm_ggml_tensor * lm_ggml_out_prod( // a is broadcastable to b for ne[2] and ne[3] -> use b->ne[2] and b->ne[3] const int64_t ne[4] = { a->ne[0], b->ne[0], b->ne[2], b->ne[3] }; - struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, MAX(a->n_dims, b->n_dims), ne); + struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, 4, ne); result->op = LM_GGML_OP_OUT_PROD; result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; @@ -4416,7 +4453,7 @@ struct lm_ggml_tensor * lm_ggml_reshape( //LM_GGML_ASSERT(false); } - struct lm_ggml_tensor * result = lm_ggml_new_tensor_impl(ctx, a->type, b->n_dims, b->ne, a, 0); + struct lm_ggml_tensor * result = lm_ggml_new_tensor_impl(ctx, a->type, LM_GGML_MAX_DIMS, b->ne, a, 0); lm_ggml_format_name(result, "%s (reshaped)", a->name); result->op = LM_GGML_OP_RESHAPE; @@ -4731,7 +4768,9 @@ struct lm_ggml_tensor * lm_ggml_get_rows( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, struct lm_ggml_tensor * b) { - LM_GGML_ASSERT(lm_ggml_is_matrix(a) && lm_ggml_is_vector(b) && b->type == LM_GGML_TYPE_I32); + LM_GGML_ASSERT(a->ne[2] == b->ne[1]); + LM_GGML_ASSERT(b->ne[3] == 1); + LM_GGML_ASSERT(b->type == LM_GGML_TYPE_I32); bool is_node = false; @@ -4741,7 +4780,7 @@ struct lm_ggml_tensor * lm_ggml_get_rows( // TODO: implement non F32 return //struct lm_ggml_tensor * result = lm_ggml_new_tensor_2d(ctx, a->type, a->ne[0], b->ne[0]); - struct lm_ggml_tensor * result = lm_ggml_new_tensor_2d(ctx, LM_GGML_TYPE_F32, a->ne[0], b->ne[0]); + struct lm_ggml_tensor * result = lm_ggml_new_tensor_4d(ctx, LM_GGML_TYPE_F32, a->ne[0], b->ne[0], b->ne[1], b->ne[2]); result->op = LM_GGML_OP_GET_ROWS; result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; @@ -4792,7 +4831,7 @@ struct lm_ggml_tensor * lm_ggml_diag( } const int64_t ne[4] = { a->ne[0], a->ne[0], a->ne[2], a->ne[3] }; - struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, a->type, MAX(a->n_dims, 2), ne); + struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, a->type, 4, ne); result->op = LM_GGML_OP_DIAG; result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; @@ -5439,7 +5478,7 @@ struct lm_ggml_tensor * lm_ggml_pool_1d( is_node = true; } - const int64_t ne[3] = { + const int64_t ne[2] = { lm_ggml_calc_pool_output_size(a->ne[0], k0, s0, p0), a->ne[1], }; @@ -5519,6 +5558,30 @@ static struct lm_ggml_tensor * lm_ggml_upscale_impl( return result; } +struct lm_ggml_tensor * lm_ggml_pad( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + int p0, int p1, int p2, int p3) { + bool is_node = false; + + if (a->grad) { + LM_GGML_ASSERT(false); // TODO: implement backward + is_node = true; + } + + struct lm_ggml_tensor * result = lm_ggml_new_tensor_4d(ctx, a->type, + a->ne[0] + p0, + a->ne[1] + p1, + a->ne[2] + p2, + a->ne[3] + p3); + + result->op = LM_GGML_OP_PAD; + result->grad = is_node ? 
lm_ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + struct lm_ggml_tensor * lm_ggml_upscale( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, @@ -5534,7 +5597,7 @@ struct lm_ggml_tensor * lm_ggml_argsort( enum lm_ggml_sort_order order) { bool is_node = false; - struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_I32, a->n_dims, a->ne); + struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_I32, LM_GGML_MAX_DIMS, a->ne); lm_ggml_set_op_params_i32(result, 0, (int32_t) order); @@ -5581,7 +5644,7 @@ struct lm_ggml_tensor * lm_ggml_flash_attn( } //struct lm_ggml_tensor * result = lm_ggml_dup_tensor(ctx, q); - struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, q->n_dims, q->ne); + struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, LM_GGML_MAX_DIMS, q->ne); int32_t t = masked ? 1 : 0; lm_ggml_set_op_params(result, &t, sizeof(t)); @@ -5614,7 +5677,7 @@ struct lm_ggml_tensor * lm_ggml_flash_ff( } //struct lm_ggml_tensor * result = lm_ggml_dup_tensor(ctx, a); - struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, a->n_dims, a->ne); + struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, LM_GGML_MAX_DIMS, a->ne); result->op = LM_GGML_OP_FLASH_FF; result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; @@ -5730,7 +5793,6 @@ struct lm_ggml_tensor * lm_ggml_win_part( const int np = npx*npy; const int64_t ne[4] = { a->ne[0], w, w, np, }; - struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, 4, ne); int32_t params[] = { npx, npy, w }; @@ -7520,7 +7582,7 @@ static void lm_ggml_compute_forward_acc_f32( LM_GGML_ASSERT(lm_ggml_is_contiguous(dst) && lm_ggml_is_contiguous(src0)); // view src0 and dst with these strides and data offset inbytes during acc - // nb0 is implicitely element_size because src0 and dst are contiguous + // nb0 is implicitly element_size because src0 and dst are contiguous size_t nb1 = ((int32_t *) dst->op_params)[0]; size_t nb2 = ((int32_t *) dst->op_params)[1]; size_t nb3 = ((int32_t *) dst->op_params)[2]; @@ -7716,6 +7778,8 @@ static void lm_ggml_compute_forward_mul_f32( #ifdef LM_GGML_USE_CLBLAST if (src1->backend == LM_GGML_BACKEND_GPU) { + // TODO: OpenCL kernel support full broadcast + LM_GGML_ASSERT(lm_ggml_can_repeat_rows(src1, src0)); if (ith == 0) { lm_ggml_cl_mul(src0, src1, dst); } @@ -8981,10 +9045,9 @@ static void lm_ggml_compute_forward_silu( } break; } } +// lm_ggml_compute_forward_leaky_relu -// lm_ggml_compute_forward_leaky - -static void lm_ggml_compute_forward_leaky_f32( +static void lm_ggml_compute_forward_leaky_relu_f32( const struct lm_ggml_compute_params * params, const struct lm_ggml_tensor * src0, struct lm_ggml_tensor * dst) { @@ -8998,24 +9061,27 @@ static void lm_ggml_compute_forward_leaky_f32( const int n = lm_ggml_nrows(src0); const int nc = src0->ne[0]; + float negative_slope; + memcpy(&negative_slope, dst->op_params, sizeof(float)); + assert(dst->nb[0] == sizeof(float)); assert(src0->nb[0] == sizeof(float)); for (int i = 0; i < n; i++) { - lm_ggml_vec_leaky_f32(nc, + lm_ggml_vec_leaky_relu_f32(nc, (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1]))); + (float *) ((char *) src0->data + i*(src0->nb[1])), negative_slope); } } -static void lm_ggml_compute_forward_leaky( +static void lm_ggml_compute_forward_leaky_relu( const struct lm_ggml_compute_params * params, const struct lm_ggml_tensor * src0, struct 
lm_ggml_tensor * dst) { switch (src0->type) { case LM_GGML_TYPE_F32: { - lm_ggml_compute_forward_leaky_f32(params, src0, dst); + lm_ggml_compute_forward_leaky_relu_f32(params, src0, dst); } break; default: { @@ -9110,6 +9176,8 @@ static void lm_ggml_compute_forward_norm_f32( float eps; memcpy(&eps, dst->op_params, sizeof(float)); + LM_GGML_ASSERT(eps > 0.0f); + // TODO: optimize for (int64_t i03 = 0; i03 < ne03; i03++) { for (int64_t i02 = 0; i02 < ne02; i02++) { @@ -9179,6 +9247,8 @@ static void lm_ggml_compute_forward_rms_norm_f32( float eps; memcpy(&eps, dst->op_params, sizeof(float)); + LM_GGML_ASSERT(eps > 0.0f); + // TODO: optimize for (int64_t i03 = 0; i03 < ne03; i03++) { for (int64_t i02 = 0; i02 < ne02; i02++) { @@ -9504,8 +9574,11 @@ static bool lm_ggml_compute_forward_mul_mat_use_blas( const int64_t ne0 = dst->ne[0]; const int64_t ne1 = dst->ne[1]; + // NOTE: with LM_GGML_OP_MUL_MAT_ID we don't want to go through the BLAS branch because it will dequantize (to_float) + // all the experts for each batch element and the processing would become incredibly slow // TODO: find the optimal values for these - if (lm_ggml_is_contiguous(src0) && + if (dst->op != LM_GGML_OP_MUL_MAT_ID && + lm_ggml_is_contiguous(src0) && lm_ggml_is_contiguous(src1) && //src0->type == LM_GGML_TYPE_F32 && src1->type == LM_GGML_TYPE_F32 && @@ -9593,8 +9666,7 @@ static void lm_ggml_compute_forward_mul_mat( const void * x = (char *) src0->data + i02*nb02 + i03*nb03; const float * y = (float *) ((char *) src1->data + i12*nb12 + i13*nb13); - - float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3); + float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3); if (type != LM_GGML_TYPE_F32) { float * const wdata = params->wdata; @@ -9611,10 +9683,10 @@ static void lm_ggml_compute_forward_mul_mat( } cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, - ne11, ne01, ne10, - 1.0f, y, ne10, - x, ne00, - 0.0f, d, ne01); + ne1, ne01, ne10, + 1.0f, y, ne10, + x, ne00, + 0.0f, d, ne01); } } @@ -9627,9 +9699,10 @@ static void lm_ggml_compute_forward_mul_mat( if (params->type == LM_GGML_TASK_INIT) { if (src1->type != vec_dot_type) { char * wdata = params->wdata; - const size_t row_size = ne10*lm_ggml_type_size(vec_dot_type)/lm_ggml_blck_size(vec_dot_type); + const size_t row_size = lm_ggml_row_size(vec_dot_type, ne10); assert(params->wsize >= ne11*ne12*ne13*row_size); + assert(src1->type == LM_GGML_TYPE_F32); for (int64_t i13 = 0; i13 < ne13; ++i13) { for (int64_t i12 = 0; i12 < ne12; ++i12) { @@ -9649,10 +9722,10 @@ static void lm_ggml_compute_forward_mul_mat( } const void * wdata = (src1->type == vec_dot_type) ? 
src1->data : params->wdata; - const size_t row_size = ne10*lm_ggml_type_size(vec_dot_type)/lm_ggml_blck_size(vec_dot_type); + const size_t row_size = lm_ggml_row_size(vec_dot_type, ne10); - const int64_t nr0 = ne01; // src0 rows - const int64_t nr1 = ne11*ne12*ne13; // src1 rows + const int64_t nr0 = ne01; // src0 rows + const int64_t nr1 = ne1*ne12*ne13; // src1 rows //printf("nr0 = %lld, nr1 = %lld\n", nr0, nr1); @@ -9694,9 +9767,9 @@ static void lm_ggml_compute_forward_mul_mat( for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) { for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) { for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) { - const int64_t i13 = (ir1/(ne12*ne11)); - const int64_t i12 = (ir1 - i13*ne12*ne11)/ne11; - const int64_t i11 = (ir1 - i13*ne12*ne11 - i12*ne11); + const int64_t i13 = (ir1/(ne12*ne1)); + const int64_t i12 = (ir1 - i13*ne12*ne1)/ne1; + const int64_t i11 = (ir1 - i13*ne12*ne1 - i12*ne1); // broadcast src0 into src1 const int64_t i03 = i13/r3; @@ -9736,20 +9809,191 @@ static void lm_ggml_compute_forward_mul_mat( static void lm_ggml_compute_forward_mul_mat_id( const struct lm_ggml_compute_params * params, + const struct lm_ggml_tensor * ids, + const struct lm_ggml_tensor * src1, struct lm_ggml_tensor * dst) { - const struct lm_ggml_tensor * ids = dst->src[0]; - const struct lm_ggml_tensor * src1 = dst->src[1]; + const struct lm_ggml_tensor * src0 = dst->src[2]; // only for LM_GGML_TENSOR_BINARY_OP_LOCALS + + LM_GGML_TENSOR_BINARY_OP_LOCALS + + const int ith = params->ith; + const int nth = params->nth; + + const enum lm_ggml_type type = src0->type; + + const bool src1_cont = lm_ggml_is_contiguous(src1); + + lm_ggml_vec_dot_t const vec_dot = type_traits[type].vec_dot; + enum lm_ggml_type const vec_dot_type = type_traits[type].vec_dot_type; + lm_ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float; + + LM_GGML_ASSERT(ne0 == ne01); + LM_GGML_ASSERT(ne1 == ne11); + LM_GGML_ASSERT(ne2 == ne12); + LM_GGML_ASSERT(ne3 == ne13); + + // we don't support permuted src0 or src1 + LM_GGML_ASSERT(nb00 == lm_ggml_type_size(type)); + LM_GGML_ASSERT(nb10 == lm_ggml_type_size(src1->type)); + + // dst cannot be transposed or permuted + LM_GGML_ASSERT(nb0 == sizeof(float)); + LM_GGML_ASSERT(nb0 <= nb1); + LM_GGML_ASSERT(nb1 <= nb2); + LM_GGML_ASSERT(nb2 <= nb3); + + // broadcast factors + const int64_t r2 = ne12/ne02; + const int64_t r3 = ne13/ne03; + + // row groups + const int id = lm_ggml_get_op_params_i32(dst, 0); + const int n_as = lm_ggml_get_op_params_i32(dst, 1); + + char * wdata_src1_end = (src1->type == vec_dot_type) ? 
+ (char *) params->wdata : + (char *) params->wdata + LM_GGML_PAD(lm_ggml_row_size(vec_dot_type, lm_ggml_nelements(src1)), sizeof(int64_t)); + + int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as] + int64_t * matrix_rows = matrix_row_counts + n_as; // [n_as][ne11] + + #define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ne11 + (i1)] + + if (params->type == LM_GGML_TASK_INIT) { + char * wdata = params->wdata; + if (src1->type != vec_dot_type) { + const size_t row_size = lm_ggml_row_size(vec_dot_type, ne10); + + assert(params->wsize >= ne11*ne12*ne13*row_size); + assert(src1->type == LM_GGML_TYPE_F32); + + for (int64_t i13 = 0; i13 < ne13; ++i13) { + for (int64_t i12 = 0; i12 < ne12; ++i12) { + for (int64_t i11 = 0; i11 < ne11; ++i11) { + from_float_to_vec_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10); + wdata += row_size; + } + } + } + } + + // initialize matrix_row_counts + LM_GGML_ASSERT(wdata == wdata_src1_end); + memset(matrix_row_counts, 0, n_as*sizeof(int64_t)); + + // group rows by src0 matrix + for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) { + const int32_t row_id = *(const int32_t *) ((const char *) ids->data + i01*ids->nb[1] + id*ids->nb[0]); + + LM_GGML_ASSERT(row_id >= 0 && row_id < n_as); + MMID_MATRIX_ROW(row_id, matrix_row_counts[row_id]) = i01; + matrix_row_counts[row_id] += 1; + } + + return; + } + + if (params->type == LM_GGML_TASK_FINALIZE) { + return; + } + + // compute each matrix multiplication in sequence + for (int cur_a = 0; cur_a < n_as; ++cur_a) { + const int64_t cne1 = matrix_row_counts[cur_a]; + + if (cne1 == 0) { + continue; + } + + const struct lm_ggml_tensor * src0_cur = dst->src[cur_a + 2]; + + const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata; + const size_t row_size = lm_ggml_row_size(vec_dot_type, ne10); - const int id = lm_ggml_get_op_params_i32(dst, 0); + const int64_t nr0 = ne01; // src0 rows + const int64_t nr1 = cne1*ne12*ne13; // src1 rows - const int a_id = ((int32_t *)ids->data)[id]; + //printf("nr0 = %lld, nr1 = %lld\n", nr0, nr1); - LM_GGML_ASSERT(a_id >= 0 && a_id < ids->ne[0]); + // distribute the thread work across the inner or outer loop based on which one is larger - const struct lm_ggml_tensor * src0 = dst->src[a_id + 2]; + const int64_t nth0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows + const int64_t nth1 = nr0 > nr1 ? 
1 : nth; // parallelize by src1 rows - lm_ggml_compute_forward_mul_mat(params, src0, src1, dst); + const int64_t ith0 = ith % nth0; + const int64_t ith1 = ith / nth0; + + const int64_t dr0 = (nr0 + nth0 - 1)/nth0; + const int64_t dr1 = (nr1 + nth1 - 1)/nth1; + + const int64_t ir010 = dr0*ith0; + const int64_t ir011 = MIN(ir010 + dr0, nr0); + + const int64_t ir110 = dr1*ith1; + const int64_t ir111 = MIN(ir110 + dr1, nr1); + + //printf("ir010 = %6lld, ir011 = %6lld, ir110 = %6lld, ir111 = %6lld\n", ir010, ir011, ir110, ir111); + + // threads with no work simply yield (not sure if it helps) + if (ir010 >= ir011 || ir110 >= ir111) { + sched_yield(); + continue; + } + + assert(ne12 % ne02 == 0); + assert(ne13 % ne03 == 0); + + // block-tiling attempt + const int64_t blck_0 = 16; + const int64_t blck_1 = 16; + + // attempt to reduce false-sharing (does not seem to make a difference) + float tmp[16]; + + for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) { + for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) { + for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) { + const int64_t i13 = (ir1/(ne12*cne1)); // Note: currently, src1 is always a matrix + const int64_t i12 = (ir1 - i13*ne12*cne1)/cne1; + const int64_t _i11 = (ir1 - i13*ne12*cne1 - i12*cne1); + const int64_t i11 = MMID_MATRIX_ROW(cur_a, _i11); + + // broadcast src0 into src1 + const int64_t i03 = i13/r3; + const int64_t i02 = i12/r2; + + const int64_t i1 = i11; + const int64_t i2 = i12; + const int64_t i3 = i13; + + const char * src0_row = (const char *) src0_cur->data + (0 + i02*nb02 + i03*nb03); + + // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides + // if it is, then we have either copied the data to params->wdata and made it contiguous or we are using + // the original src1 data pointer, so we should index using the indices directly + // TODO: this is a bit of a hack, we should probably have a better way to handle this + const char * src1_col = (const char *) wdata + + (src1_cont || src1->type != vec_dot_type + ? 
(i11 + i12*ne11 + i13*ne12*ne11)*row_size + : (i11*nb11 + i12*nb12 + i13*nb13)); + + float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3)); + + //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) { + // vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col); + //} + + for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) { + vec_dot(ne00, &tmp[ir0 - iir0], src0_row + ir0*nb01, src1_col); + } + memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float)); + } + } + } + } + + #undef MMID_MATRIX_ROW } // lm_ggml_compute_forward_out_prod @@ -10161,7 +10405,7 @@ static void lm_ggml_compute_forward_set_f32( LM_GGML_ASSERT(lm_ggml_is_contiguous(dst) && lm_ggml_is_contiguous(src0)); // view src0 and dst with these strides and data offset inbytes during set - // nb0 is implicitely element_size because src0 and dst are contiguous + // nb0 is implicitly element_size because src0 and dst are contiguous size_t nb1 = ((int32_t *) dst->op_params)[0]; size_t nb2 = ((int32_t *) dst->op_params)[1]; size_t nb3 = ((int32_t *) dst->op_params)[2]; @@ -10325,21 +10569,30 @@ static void lm_ggml_compute_forward_get_rows_q( return; } - const int nc = src0->ne[0]; - const int nr = lm_ggml_nelements(src1); + LM_GGML_TENSOR_BINARY_OP_LOCALS + + const int64_t nc = ne00; + const int64_t nr = lm_ggml_nelements(src1); LM_GGML_UNUSED(nr); + const enum lm_ggml_type type = src0->type; lm_ggml_to_float_t const dequantize_row_q = type_traits[type].to_float; - assert( dst->ne[0] == nc); - assert( dst->ne[1] == nr); - assert(src0->nb[0] == lm_ggml_type_size(type)); + assert(ne0 == nc); + assert(ne02 == ne11); + assert(nb00 == lm_ggml_type_size(type)); + assert(lm_ggml_nrows(dst) == nr); - for (int i = 0; i < nr; ++i) { - const int r = ((int32_t *) src1->data)[i]; + // TODO: multi-thread + for (int64_t i12 = 0; i12 < ne12; ++i12) { + for (int64_t i11 = 0; i11 < ne11; ++i11) { + for (int64_t i10 = 0; i10 < ne10; ++i10) { + const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12); - dequantize_row_q( - (const void *) ((char *) src0->data + r*src0->nb[1]), - (float *) ((char *) dst->data + i*dst->nb[1]), nc); + dequantize_row_q( + (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03), + (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc); + } + } } } @@ -10354,19 +10607,26 @@ static void lm_ggml_compute_forward_get_rows_f16( return; } - const int nc = src0->ne[0]; - const int nr = lm_ggml_nelements(src1); + LM_GGML_TENSOR_BINARY_OP_LOCALS - assert( dst->ne[0] == nc); - assert( dst->ne[1] == nr); - assert(src0->nb[0] == sizeof(lm_ggml_fp16_t)); + const int64_t nc = ne00; + const int64_t nr = lm_ggml_nelements(src1); LM_GGML_UNUSED(nr); - for (int i = 0; i < nr; ++i) { - const int r = ((int32_t *) src1->data)[i]; + assert(ne0 == nc); + assert(ne02 == ne11); + assert(nb00 == sizeof(lm_ggml_fp16_t)); + assert(lm_ggml_nrows(dst) == nr); - for (int j = 0; j < nc; ++j) { - lm_ggml_fp16_t v = ((lm_ggml_fp16_t *) ((char *) src0->data + r*src0->nb[1]))[j]; - ((float *) ((char *) dst->data + i*dst->nb[1]))[j] = LM_GGML_FP16_TO_FP32(v); + // TODO: multi-thread + for (int64_t i12 = 0; i12 < ne12; ++i12) { + for (int64_t i11 = 0; i11 < ne11; ++i11) { + for (int64_t i10 = 0; i10 < ne10; ++i10) { + const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12); + + lm_ggml_fp16_to_fp32_row( + (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03), + (float *) ((char *) 
dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc); + } } } } @@ -10382,19 +10642,27 @@ static void lm_ggml_compute_forward_get_rows_f32( return; } - const int nc = src0->ne[0]; - const int nr = lm_ggml_nelements(src1); + LM_GGML_TENSOR_BINARY_OP_LOCALS - assert( dst->ne[0] == nc); - assert( dst->ne[1] == nr); - assert(src0->nb[0] == sizeof(float)); + const int64_t nc = ne00; + const int64_t nr = lm_ggml_nelements(src1); LM_GGML_UNUSED(nr); - for (int i = 0; i < nr; ++i) { - const int r = ((int32_t *) src1->data)[i]; + assert(ne0 == nc); + assert(ne02 == ne11); + assert(nb00 == sizeof(float)); + assert(lm_ggml_nrows(dst) == nr); - lm_ggml_vec_cpy_f32(nc, - (float *) ((char *) dst->data + i*dst->nb[1]), - (float *) ((char *) src0->data + r*src0->nb[1])); + // TODO: multi-thread + for (int64_t i12 = 0; i12 < ne12; ++i12) { + for (int64_t i11 = 0; i11 < ne11; ++i11) { + for (int64_t i10 = 0; i10 < ne10; ++i10) { + const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12); + + lm_ggml_vec_cpy_f32(nc, + (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), + (float *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03)); + } + } } } @@ -11306,10 +11574,13 @@ static void lm_ggml_compute_forward_rope_f32( } } else { // TODO: this might be wrong for ne0 != n_dims - need double check - // ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28 + // it seems we have to rope just the first n_dims elements and do nothing with the rest + // ref: https://github.com/ml-explore/mlx/blob/dc2edc762c797e3b8de50b1dad4dc0a131691033/benchmarks/python/llama_jax_bench.py#L11-L26 theta_base *= freq_scale; - for (int64_t ib = 0; ib < ne0/n_dims; ++ib) { - for (int64_t ic = 0; ic < n_dims; ic += 2) { + for (int64_t ic = 0; ic < ne0; ic += 2) { + if (ic < n_dims) { + const int64_t ib = 0; + // simplified from `(ib * n_dims + ic) * inv_ndims` float cur_rot = inv_ndims * ic - ib; @@ -11332,6 +11603,14 @@ static void lm_ggml_compute_forward_rope_f32( dst_data[0] = x0*cos_theta - x1*sin_theta; dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta; + } else { + const int64_t i0 = ic; + + const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; } } } @@ -11459,10 +11738,13 @@ static void lm_ggml_compute_forward_rope_f16( } } else { // TODO: this might be wrong for ne0 != n_dims - need double check - // ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28 + // it seems we have to rope just the first n_dims elements and do nothing with the rest + // ref: https://github.com/ml-explore/mlx/blob/dc2edc762c797e3b8de50b1dad4dc0a131691033/benchmarks/python/llama_jax_bench.py#L11-L26 theta_base *= freq_scale; - for (int64_t ib = 0; ib < ne0/n_dims; ++ib) { - for (int64_t ic = 0; ic < n_dims; ic += 2) { + for (int64_t ic = 0; ic < ne0; ic += 2) { + if (ic < n_dims) { + const int64_t ib = 0; + // simplified from `(ib * n_dims + ic) * inv_ndims` float cur_rot = inv_ndims * ic - ib; @@ -11485,6 +11767,14 @@ static void lm_ggml_compute_forward_rope_f16( dst_data[0] = LM_GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta); dst_data[n_dims/2] = LM_GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta); + } else { + const int64_t i0 = ic; + + const lm_ggml_fp16_t * const src = (lm_ggml_fp16_t 
*)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + lm_ggml_fp16_t * dst_data = (lm_ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; } } } @@ -12114,6 +12404,7 @@ static void lm_ggml_compute_forward_upscale_f32( LM_GGML_ASSERT(src0->nb[0] == sizeof(float)); const int ith = params->ith; + const int nth = params->nth; LM_GGML_TENSOR_UNARY_OP_LOCALS @@ -12121,16 +12412,17 @@ static void lm_ggml_compute_forward_upscale_f32( // TODO: optimize - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = ith; i02 < ne02; i02++) { - for (int m = 0; m < dst->ne[1]; m++) { - int i01 = m / scale_factor; - for (int n = 0; n < dst->ne[0]; n++) { - int i00 = n / scale_factor; - - const float * x = (float *)((char *) src0->data + i00 * nb00 +i01 * nb01 + i02 * nb02 + i03 * nb03); + for (int64_t i3 = 0; i3 < ne3; i3++) { + const int64_t i03 = i3; + for (int64_t i2 = ith; i2 < ne2; i2 += nth) { + const int64_t i02 = i2; + for (int64_t i1 = 0; i1 < ne1; i1++) { + const int64_t i01 = i1 / scale_factor; + for (int64_t i0 = 0; i0 < ne0; i0++) { + const int64_t i00 = i0 / scale_factor; - float * y = (float *)((char *) dst->data + n * dst->nb[0] + m * dst->nb[1] + i02 * dst->nb[2] + i03 * dst->nb[3]); + const float * x = (float *)((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + float * y = (float *)((char *) dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3); *y = *x; } @@ -12155,6 +12447,64 @@ static void lm_ggml_compute_forward_upscale( } } +// lm_ggml_compute_forward_pad + +static void lm_ggml_compute_forward_pad_f32( + const struct lm_ggml_compute_params * params, + const struct lm_ggml_tensor * src0, + struct lm_ggml_tensor * dst) { + + if (params->type == LM_GGML_TASK_INIT || params->type == LM_GGML_TASK_FINALIZE) { + return; + } + + LM_GGML_ASSERT(src0->nb[0] == sizeof(float)); + LM_GGML_ASSERT( dst->nb[0] == sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + LM_GGML_TENSOR_UNARY_OP_LOCALS + + float * dst_ptr = (float *) dst->data; + + // TODO: optimize + + for (int64_t i2 = 0; i2 < ne2; ++i2) { + for (int64_t i1 = ith; i1 < ne1; i1 += nth) { + for (int64_t i0 = 0; i0 < ne0; ++i0) { + for (int64_t i3 = 0; i3 < ne3; ++i3) { + const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0; + + const float * src_ptr = (const float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + + if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) { + dst_ptr[dst_idx] = *src_ptr; + } else { + dst_ptr[dst_idx] = 0; + } + } + } + } + } +} + +static void lm_ggml_compute_forward_pad( + const struct lm_ggml_compute_params * params, + const struct lm_ggml_tensor * src0, + struct lm_ggml_tensor * dst) { + switch (src0->type) { + case LM_GGML_TYPE_F32: + { + lm_ggml_compute_forward_pad_f32(params, src0, dst); + } break; + default: + { + LM_GGML_ASSERT(false); + } break; + } +} + // lm_ggml_compute_forward_argsort static void lm_ggml_compute_forward_argsort_f32( @@ -13362,10 +13712,6 @@ static void lm_ggml_compute_forward_unary( { lm_ggml_compute_forward_silu(params, src0, dst); } break; - case LM_GGML_UNARY_OP_LEAKY: - { - lm_ggml_compute_forward_leaky(params, src0, dst); - } break; default: { LM_GGML_ASSERT(false); @@ -14041,7 +14387,7 @@ static void lm_ggml_compute_forward(struct lm_ggml_compute_params * params, stru } break; case LM_GGML_OP_MUL_MAT_ID: { - lm_ggml_compute_forward_mul_mat_id(params, tensor); + lm_ggml_compute_forward_mul_mat_id(params, tensor->src[0], 
tensor->src[1], tensor); } break; case LM_GGML_OP_OUT_PROD: { @@ -14147,10 +14493,18 @@ static void lm_ggml_compute_forward(struct lm_ggml_compute_params * params, stru { lm_ggml_compute_forward_upscale(params, tensor->src[0], tensor); } break; + case LM_GGML_OP_PAD: + { + lm_ggml_compute_forward_pad(params, tensor->src[0], tensor); + } break; case LM_GGML_OP_ARGSORT: { lm_ggml_compute_forward_argsort(params, tensor->src[0], tensor); } break; + case LM_GGML_OP_LEAKY_RELU: + { + lm_ggml_compute_forward_leaky_relu(params, tensor->src[0], tensor); + } break; case LM_GGML_OP_FLASH_ATTN: { const int32_t t = lm_ggml_get_op_params_i32(tensor, 0); @@ -14405,7 +14759,7 @@ static struct lm_ggml_tensor * lm_ggml_recompute_graph_node( return replacements->vals[i]; } - struct lm_ggml_tensor * clone = lm_ggml_new_tensor(ctx, node->type, node->n_dims, node->ne); + struct lm_ggml_tensor * clone = lm_ggml_new_tensor(ctx, node->type, LM_GGML_MAX_DIMS, node->ne); // insert clone into replacements LM_GGML_ASSERT(replacements->set.keys[i] == NULL); // assert that we don't overwrite @@ -14475,7 +14829,7 @@ void lm_ggml_build_backward_gradient_checkpointing( // insert new tensors recomputing src, reusing already made replacements, // remember replacements: remember new tensors with mapping from corresponding gf nodes // recurse for input tensors, - // unless (i.e. terminating when) input tensors are replacments (like checkpoints) + // unless (i.e. terminating when) input tensors are replacements (like checkpoints) node->src[k] = lm_ggml_recompute_graph_node(ctx, gf, replacements, node->src[k]); } // insert rewritten backward node with replacements made into resulting backward graph gb @@ -15143,10 +15497,18 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm { LM_GGML_ASSERT(false); // TODO: not implemented } break; + case LM_GGML_OP_PAD: + { + LM_GGML_ASSERT(false); // TODO: not implemented + } break; case LM_GGML_OP_ARGSORT: { LM_GGML_ASSERT(false); // TODO: not implemented } break; + case LM_GGML_OP_LEAKY_RELU: + { + LM_GGML_ASSERT(false); // TODO: not implemented + } break; case LM_GGML_OP_FLASH_ATTN: { struct lm_ggml_tensor * flash_grad = NULL; @@ -15752,6 +16114,7 @@ static int lm_ggml_get_n_tasks(struct lm_ggml_tensor * node, int n_threads) { case LM_GGML_OP_ARGMAX: case LM_GGML_OP_REPEAT: case LM_GGML_OP_REPEAT_BACK: + case LM_GGML_OP_LEAKY_RELU: { n_tasks = 1; } break; @@ -15764,7 +16127,6 @@ static int lm_ggml_get_n_tasks(struct lm_ggml_tensor * node, int n_threads) { case LM_GGML_UNARY_OP_TANH: case LM_GGML_UNARY_OP_ELU: case LM_GGML_UNARY_OP_RELU: - case LM_GGML_UNARY_OP_LEAKY: { n_tasks = 1; } break; @@ -15821,7 +16183,6 @@ static int lm_ggml_get_n_tasks(struct lm_ggml_tensor * node, int n_threads) { } break; case LM_GGML_OP_MUL_MAT_ID: { - // FIXME: blas n_tasks = n_threads; } break; case LM_GGML_OP_OUT_PROD: @@ -15883,6 +16244,10 @@ static int lm_ggml_get_n_tasks(struct lm_ggml_tensor * node, int n_threads) { { n_tasks = n_threads; } break; + case LM_GGML_OP_PAD: + { + n_tasks = n_threads; + } break; case LM_GGML_OP_ARGSORT: { n_tasks = n_threads; @@ -16146,25 +16511,21 @@ struct lm_ggml_cplan lm_ggml_graph_plan(struct lm_ggml_cgraph * cgraph, int n_th } else #endif if (node->src[1]->type != vec_dot_type) { - cur = lm_ggml_type_size(vec_dot_type)*lm_ggml_nelements(node->src[1])/lm_ggml_blck_size(vec_dot_type); + cur = lm_ggml_row_size(vec_dot_type, lm_ggml_nelements(node->src[1])); } } break; case LM_GGML_OP_MUL_MAT_ID: { - const struct lm_ggml_tensor * a = 
node->src[2]; - const struct lm_ggml_tensor * b = node->src[1]; - const enum lm_ggml_type vec_dot_type = type_traits[a->type].vec_dot_type; -#if defined(LM_GGML_USE_ACCELERATE) || defined(LM_GGML_USE_OPENBLAS) - if (lm_ggml_compute_forward_mul_mat_use_blas(a, b, node)) { - if (a->type != LM_GGML_TYPE_F32) { - // here we need memory just for single 2D matrix from src0 - cur = lm_ggml_type_size(LM_GGML_TYPE_F32)*(a->ne[0]*a->ne[1]); - } - } else -#endif - if (b->type != vec_dot_type) { - cur = lm_ggml_type_size(vec_dot_type)*lm_ggml_nelements(b)/lm_ggml_blck_size(vec_dot_type); + const struct lm_ggml_tensor * src0 = node->src[2]; + const struct lm_ggml_tensor * src1 = node->src[1]; + const enum lm_ggml_type vec_dot_type = type_traits[src0->type].vec_dot_type; + if (src1->type != vec_dot_type) { + cur = lm_ggml_row_size(vec_dot_type, lm_ggml_nelements(src1)); } + const int n_as = lm_ggml_get_op_params_i32(node, 1); + cur = LM_GGML_PAD(cur, sizeof(int64_t)); // align + cur += n_as * sizeof(int64_t); // matrix_row_counts + cur += n_as * src1->ne[1] * sizeof(int64_t); // matrix_rows } break; case LM_GGML_OP_OUT_PROD: { @@ -16394,7 +16755,7 @@ static void lm_ggml_graph_export_leaf(const struct lm_ggml_tensor * tensor, FILE fprintf(fout, "%-6s %-12s %8d %" PRId64 " %" PRId64 " %" PRId64 " %" PRId64 " %16zu %16zu %16zu %16zu %16p %32s\n", lm_ggml_type_name(tensor->type), lm_ggml_op_name (tensor->op), - tensor->n_dims, + lm_ggml_n_dims(tensor), ne[0], ne[1], ne[2], ne[3], nb[0], nb[1], nb[2], nb[3], tensor->data, @@ -16409,7 +16770,7 @@ static void lm_ggml_graph_export_node(const struct lm_ggml_tensor * tensor, cons arg, lm_ggml_type_name(tensor->type), lm_ggml_op_name (tensor->op), - tensor->n_dims, + lm_ggml_n_dims(tensor), ne[0], ne[1], ne[2], ne[3], nb[0], nb[1], nb[2], nb[3], tensor->data, @@ -16499,11 +16860,9 @@ void lm_ggml_graph_export(const struct lm_ggml_cgraph * cgraph, const char * fna const uint32_t type = tensor->type; const uint32_t op = tensor->op; - const uint32_t n_dims = tensor->n_dims; fwrite(&type, sizeof(uint32_t), 1, fout); fwrite(&op, sizeof(uint32_t), 1, fout); - fwrite(&n_dims, sizeof(uint32_t), 1, fout); for (int j = 0; j < LM_GGML_MAX_DIMS; ++j) { const uint64_t ne = tensor->ne[j]; @@ -16533,11 +16892,9 @@ void lm_ggml_graph_export(const struct lm_ggml_cgraph * cgraph, const char * fna const uint32_t type = tensor->type; const uint32_t op = tensor->op; - const uint32_t n_dims = tensor->n_dims; fwrite(&type, sizeof(uint32_t), 1, fout); fwrite(&op, sizeof(uint32_t), 1, fout); - fwrite(&n_dims, sizeof(uint32_t), 1, fout); for (int j = 0; j < LM_GGML_MAX_DIMS; ++j) { const uint64_t ne = tensor->ne[j]; @@ -16709,12 +17066,10 @@ struct lm_ggml_cgraph * lm_ggml_graph_import(const char * fname, struct lm_ggml_ { uint32_t type; uint32_t op; - uint32_t n_dims; for (uint32_t i = 0; i < n_leafs; ++i) { type = *(const uint32_t *) ptr; ptr += sizeof(type); op = *(const uint32_t *) ptr; ptr += sizeof(op); - n_dims = *(const uint32_t *) ptr; ptr += sizeof(n_dims); int64_t ne[LM_GGML_MAX_DIMS]; size_t nb[LM_GGML_MAX_DIMS]; @@ -16730,7 +17085,7 @@ struct lm_ggml_cgraph * lm_ggml_graph_import(const char * fname, struct lm_ggml_ nb[j] = nb_cur; } - struct lm_ggml_tensor * tensor = lm_ggml_new_tensor(*ctx_eval, (enum lm_ggml_type) type, n_dims, ne); + struct lm_ggml_tensor * tensor = lm_ggml_new_tensor(*ctx_eval, (enum lm_ggml_type) type, LM_GGML_MAX_DIMS, ne); tensor->op = (enum lm_ggml_op) op; @@ -16747,7 +17102,7 @@ struct lm_ggml_cgraph * lm_ggml_graph_import(const char * fname, struct 
lm_ggml_ ptr += lm_ggml_nbytes(tensor); - fprintf(stderr, "%s: loaded leaf %d: '%16s', %3d dims, %9zu bytes\n", __func__, i, tensor->name, n_dims, lm_ggml_nbytes(tensor)); + fprintf(stderr, "%s: loaded leaf %d: '%16s', %9zu bytes\n", __func__, i, tensor->name, lm_ggml_nbytes(tensor)); } } @@ -16757,12 +17112,10 @@ struct lm_ggml_cgraph * lm_ggml_graph_import(const char * fname, struct lm_ggml_ { uint32_t type; uint32_t op; - uint32_t n_dims; for (uint32_t i = 0; i < n_nodes; ++i) { type = *(const uint32_t *) ptr; ptr += sizeof(type); op = *(const uint32_t *) ptr; ptr += sizeof(op); - n_dims = *(const uint32_t *) ptr; ptr += sizeof(n_dims); enum lm_ggml_op eop = (enum lm_ggml_op) op; @@ -16833,7 +17186,7 @@ struct lm_ggml_cgraph * lm_ggml_graph_import(const char * fname, struct lm_ggml_ } break; default: { - tensor = lm_ggml_new_tensor(*ctx_eval, (enum lm_ggml_type) type, n_dims, ne); + tensor = lm_ggml_new_tensor(*ctx_eval, (enum lm_ggml_type) type, LM_GGML_MAX_DIMS, ne); tensor->op = eop; } break; @@ -16852,7 +17205,7 @@ struct lm_ggml_cgraph * lm_ggml_graph_import(const char * fname, struct lm_ggml_ result->nodes[i] = tensor; - fprintf(stderr, "%s: loaded node %d: '%16s', %3d dims, %9zu bytes\n", __func__, i, tensor->name, n_dims, lm_ggml_nbytes(tensor)); + fprintf(stderr, "%s: loaded node %d: '%16s', %9zu bytes\n", __func__, i, tensor->name, lm_ggml_nbytes(tensor)); } } } @@ -16990,7 +17343,7 @@ void lm_ggml_graph_dump_dot(const struct lm_ggml_cgraph * gb, const struct lm_gg fprintf(fp, "(%s)|", lm_ggml_type_name(node->type)); } - if (node->n_dims == 2) { + if (lm_ggml_is_matrix(node)) { fprintf(fp, "%d [%" PRId64 ", %" PRId64 "] | %s", i, node->ne[0], node->ne[1], lm_ggml_op_symbol(node->op)); } else { fprintf(fp, "%d [%" PRId64 ", %" PRId64 ", %" PRId64 "] | %s", i, node->ne[0], node->ne[1], node->ne[2], lm_ggml_op_symbol(node->op)); @@ -17257,7 +17610,7 @@ static enum lm_ggml_opt_result lm_ggml_opt_adam( int64_t i = 0; for (int p = 0; p < np; ++p) { const int64_t ne = lm_ggml_nelements(ps[p]); - const float p_decay = ((ps[p]->n_dims >= decay_min_ndim) ? decay : 0.0f) * sched; + const float p_decay = ((lm_ggml_n_dims(ps[p]) >= decay_min_ndim) ? 
decay : 0.0f) * sched; for (int64_t j = 0; j < ne; ++j) { float x = lm_ggml_get_f32_1d(ps[p], j); float g_ = g[i]*gnorm; @@ -18531,7 +18884,7 @@ struct lm_gguf_context * lm_gguf_init_from_file(const char * fname, struct lm_gg return NULL; } - const size_t size_cur = (ne*lm_ggml_type_size(info->type))/lm_ggml_blck_size(info->type); + const size_t size_cur = lm_ggml_row_size(info->type, ne); ctx->size += LM_GGML_PAD(size_cur, ctx->alignment); } @@ -19035,8 +19388,8 @@ void lm_gguf_add_tensor( ctx->infos[idx].ne[i] = 1; } - ctx->infos[idx].n_dims = tensor->n_dims; - for (int i = 0; i < tensor->n_dims; i++) { + ctx->infos[idx].n_dims = lm_ggml_n_dims(tensor); + for (uint32_t i = 0; i < ctx->infos[idx].n_dims; i++) { ctx->infos[idx].ne[i] = tensor->ne[i]; } diff --git a/cpp/ggml.h b/cpp/ggml.h index 2000b8f..e9f15cb 100644 --- a/cpp/ggml.h +++ b/cpp/ggml.h @@ -215,9 +215,9 @@ #define LM_GGML_QNT_VERSION_FACTOR 1000 // do not change this #define LM_GGML_MAX_DIMS 4 -#define LM_GGML_MAX_PARAMS 1024 +#define LM_GGML_MAX_PARAMS 2048 #define LM_GGML_MAX_CONTEXTS 64 -#define LM_GGML_MAX_SRC 6 +#define LM_GGML_MAX_SRC 10 #define LM_GGML_MAX_NAME 64 #define LM_GGML_MAX_OP_PARAMS 64 #define LM_GGML_DEFAULT_N_THREADS 4 @@ -343,6 +343,12 @@ extern "C" { LM_GGML_TYPE_COUNT, }; + // precision + enum lm_ggml_prec { + LM_GGML_PREC_DEFAULT, + LM_GGML_PREC_F32, + }; + enum lm_ggml_backend_type { LM_GGML_BACKEND_CPU = 0, LM_GGML_BACKEND_GPU = 10, @@ -423,7 +429,9 @@ extern "C" { LM_GGML_OP_POOL_1D, LM_GGML_OP_POOL_2D, LM_GGML_OP_UPSCALE, // nearest interpolate + LM_GGML_OP_PAD, LM_GGML_OP_ARGSORT, + LM_GGML_OP_LEAKY_RELU, LM_GGML_OP_FLASH_ATTN, LM_GGML_OP_FLASH_FF, @@ -463,7 +471,6 @@ extern "C" { LM_GGML_UNARY_OP_GELU, LM_GGML_UNARY_OP_GELU_QUICK, LM_GGML_UNARY_OP_SILU, - LM_GGML_UNARY_OP_LEAKY, LM_GGML_UNARY_OP_COUNT, }; @@ -501,7 +508,6 @@ extern "C" { struct lm_ggml_backend_buffer * buffer; - int n_dims; int64_t ne[LM_GGML_MAX_DIMS]; // number of elements size_t nb[LM_GGML_MAX_DIMS]; // stride in bytes: // nb[0] = lm_ggml_type_size(type) @@ -533,7 +539,7 @@ extern "C" { void * extra; // extra things e.g. 
for ggml-cuda.cu - char padding[12]; + char padding[8]; }; static const size_t LM_GGML_TENSOR_SIZE = sizeof(struct lm_ggml_tensor); @@ -638,11 +644,14 @@ extern "C" { LM_GGML_API int64_t lm_ggml_nrows (const struct lm_ggml_tensor * tensor); LM_GGML_API size_t lm_ggml_nbytes (const struct lm_ggml_tensor * tensor); LM_GGML_API size_t lm_ggml_nbytes_pad (const struct lm_ggml_tensor * tensor); // same as lm_ggml_nbytes() but padded to LM_GGML_MEM_ALIGN - LM_GGML_API size_t lm_ggml_nbytes_split(const struct lm_ggml_tensor * tensor, int nrows_split); - LM_GGML_API int lm_ggml_blck_size (enum lm_ggml_type type); - LM_GGML_API size_t lm_ggml_type_size (enum lm_ggml_type type); // size in bytes for all elements in a block - LM_GGML_API float lm_ggml_type_sizef(enum lm_ggml_type type); // lm_ggml_type_size()/lm_ggml_blck_size() as float + LM_GGML_API int lm_ggml_blck_size(enum lm_ggml_type type); + LM_GGML_API size_t lm_ggml_type_size(enum lm_ggml_type type); // size in bytes for all elements in a block + LM_GGML_API size_t lm_ggml_row_size (enum lm_ggml_type type, int64_t ne); // size in bytes for all elements in a row + + LM_GGML_DEPRECATED( + LM_GGML_API double lm_ggml_type_sizef(enum lm_ggml_type type), // lm_ggml_type_size()/lm_ggml_blck_size() as float + "use lm_ggml_row_size() instead"); LM_GGML_API const char * lm_ggml_type_name(enum lm_ggml_type type); LM_GGML_API const char * lm_ggml_op_name (enum lm_ggml_op op); @@ -661,6 +670,11 @@ extern "C" { LM_GGML_API bool lm_ggml_is_transposed(const struct lm_ggml_tensor * tensor); LM_GGML_API bool lm_ggml_is_contiguous(const struct lm_ggml_tensor * tensor); LM_GGML_API bool lm_ggml_is_permuted (const struct lm_ggml_tensor * tensor); + LM_GGML_API bool lm_ggml_is_scalar (const struct lm_ggml_tensor * tensor); + LM_GGML_API bool lm_ggml_is_vector (const struct lm_ggml_tensor * tensor); + LM_GGML_API bool lm_ggml_is_matrix (const struct lm_ggml_tensor * tensor); + LM_GGML_API bool lm_ggml_is_3d (const struct lm_ggml_tensor * tensor); + LM_GGML_API int lm_ggml_n_dims (const struct lm_ggml_tensor * tensor); // returns 1 for scalars LM_GGML_API bool lm_ggml_are_same_shape(const struct lm_ggml_tensor * t0, const struct lm_ggml_tensor * t1); @@ -793,6 +807,9 @@ extern "C" { struct lm_ggml_tensor * a, struct lm_ggml_tensor * b); + // dst = a + // view(dst, nb1, nb2, nb3, offset) += b + // return dst LM_GGML_API struct lm_ggml_tensor * lm_ggml_acc( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, @@ -957,15 +974,14 @@ extern "C" { struct lm_ggml_context * ctx, struct lm_ggml_tensor * a); - LM_GGML_API struct lm_ggml_tensor * lm_ggml_leaky( + LM_GGML_API struct lm_ggml_tensor * lm_ggml_leaky_relu( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a); + struct lm_ggml_tensor * a, float negative_slope, bool inplace); LM_GGML_API struct lm_ggml_tensor * lm_ggml_relu_inplace( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a); - // TODO: double-check this computation is correct LM_GGML_API struct lm_ggml_tensor * lm_ggml_gelu( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a); @@ -1047,11 +1063,18 @@ extern "C" { struct lm_ggml_tensor * a, struct lm_ggml_tensor * b); + // change the precision of a matrix multiplication + // set to LM_GGML_PREC_F32 for higher precision (useful for phi-2) + LM_GGML_API void lm_ggml_mul_mat_set_prec( + struct lm_ggml_tensor * a, + enum lm_ggml_prec prec); + // indirect matrix multiplication // lm_ggml_mul_mat_id(ctx, as, ids, id, b) ~= lm_ggml_mul_mat(as[ids[id]], b) LM_GGML_API struct lm_ggml_tensor * 
lm_ggml_mul_mat_id( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * as[], + struct lm_ggml_tensor * const as[], + int n_as, struct lm_ggml_tensor * ids, int id, struct lm_ggml_tensor * b); @@ -1263,6 +1286,7 @@ extern "C" { struct lm_ggml_context * ctx, struct lm_ggml_tensor * a); + // supports 3D: a->ne[2] == b->ne[1] LM_GGML_API struct lm_ggml_tensor * lm_ggml_get_rows( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, @@ -1549,6 +1573,15 @@ extern "C" { struct lm_ggml_tensor * a, int scale_factor); + // pad each dimension with zeros: [x, ..., x] -> [x, ..., x, 0, ..., 0] + LM_GGML_API struct lm_ggml_tensor * lm_ggml_pad( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + int p0, + int p1, + int p2, + int p3); + // sort rows enum lm_ggml_sort_order { LM_GGML_SORT_ASC, diff --git a/cpp/llama.cpp b/cpp/llama.cpp index f1de553..db1151f 100644 --- a/cpp/llama.cpp +++ b/cpp/llama.cpp @@ -91,7 +91,8 @@ #define LLAMA_ATTRIBUTE_FORMAT(...) #endif -#define LLAMA_MAX_NODES 8192 +#define LLAMA_MAX_NODES 8192 +#define LLAMA_MAX_EXPERTS 8 // // logging @@ -205,6 +206,7 @@ enum llm_arch { LLM_ARCH_BLOOM, LLM_ARCH_STABLELM, LLM_ARCH_QWEN, + LLM_ARCH_PHI2, LLM_ARCH_UNKNOWN, }; @@ -222,6 +224,7 @@ static std::map LLM_ARCH_NAMES = { { LLM_ARCH_BLOOM, "bloom" }, { LLM_ARCH_STABLELM, "stablelm" }, { LLM_ARCH_QWEN, "qwen" }, + { LLM_ARCH_PHI2, "phi2" }, }; enum llm_kv { @@ -242,6 +245,8 @@ enum llm_kv { LLM_KV_FEED_FORWARD_LENGTH, LLM_KV_USE_PARALLEL_RESIDUAL, LLM_KV_TENSOR_DATA_LAYOUT, + LLM_KV_EXPERT_COUNT, + LLM_KV_EXPERT_USED_COUNT, LLM_KV_ATTENTION_HEAD_COUNT, LLM_KV_ATTENTION_HEAD_COUNT_KV, @@ -292,6 +297,8 @@ static std::map LLM_KV_NAMES = { { LLM_KV_FEED_FORWARD_LENGTH, "%s.feed_forward_length" }, { LLM_KV_USE_PARALLEL_RESIDUAL, "%s.use_parallel_residual" }, { LLM_KV_TENSOR_DATA_LAYOUT, "%s.tensor_data_layout" }, + { LLM_KV_EXPERT_COUNT, "%s.expert_count" }, + { LLM_KV_EXPERT_USED_COUNT, "%s.expert_used_count" }, { LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" }, { LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" }, @@ -349,10 +356,14 @@ enum llm_tensor { LLM_TENSOR_ATTN_NORM, LLM_TENSOR_ATTN_NORM_2, LLM_TENSOR_ATTN_ROT_EMBD, + LLM_TENSOR_FFN_GATE_INP, + LLM_TENSOR_FFN_NORM, LLM_TENSOR_FFN_GATE, LLM_TENSOR_FFN_DOWN, LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_DOWN_EXP, + LLM_TENSOR_FFN_GATE_EXP, + LLM_TENSOR_FFN_UP_EXP, LLM_TENSOR_ATTN_Q_NORM, LLM_TENSOR_ATTN_K_NORM, }; @@ -371,10 +382,14 @@ static std::map> LLM_TENSOR_NAMES = { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" }, + { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" }, + { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" }, }, }, { @@ -548,6 +563,19 @@ static std::map> LLM_TENSOR_NAMES = { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, + { + LLM_ARCH_PHI2, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, { 
LLM_ARCH_UNKNOWN, @@ -596,6 +624,10 @@ struct LLM_TN { std::string operator()(llm_tensor tensor, const std::string & suffix, int bid) const { return ::format(LLM_TENSOR_NAMES[arch].at(tensor).c_str(), bid) + "." + suffix; } + + std::string operator()(llm_tensor tensor, const std::string & suffix, int bid, int xid) const { + return ::format(LLM_TENSOR_NAMES[arch].at(tensor).c_str(), bid, xid) + "." + suffix; + } }; // @@ -1175,6 +1207,8 @@ struct llama_hparams { uint32_t n_layer; uint32_t n_rot; uint32_t n_ff; + uint32_t n_expert = 0; + uint32_t n_expert_used = 0; float f_norm_eps; float f_norm_rms_eps; @@ -1189,15 +1223,18 @@ struct llama_hparams { float f_max_alibi_bias; bool operator!=(const llama_hparams & other) const { - if (this->vocab_only != other.vocab_only) return true; - if (this->n_vocab != other.n_vocab) return true; - if (this->n_ctx_train != other.n_ctx_train) return true; - if (this->n_embd != other.n_embd) return true; - if (this->n_head != other.n_head) return true; - if (this->n_head_kv != other.n_head_kv) return true; - if (this->n_layer != other.n_layer) return true; - if (this->n_rot != other.n_rot) return true; - if (this->n_ff != other.n_ff) return true; + if (this->vocab_only != other.vocab_only) return true; + if (this->n_vocab != other.n_vocab) return true; + if (this->n_ctx_train != other.n_ctx_train) return true; + if (this->n_embd != other.n_embd) return true; + if (this->n_head != other.n_head) return true; + if (this->n_head_kv != other.n_head_kv) return true; + if (this->n_layer != other.n_layer) return true; + if (this->n_rot != other.n_rot) return true; + if (this->n_ff != other.n_ff) return true; + if (this->n_expert != other.n_expert) return true; + if (this->n_expert_used != other.n_expert_used) return true; + if (this->rope_finetuned != other.rope_finetuned) return true; if (this->n_yarn_orig_ctx != other.n_yarn_orig_ctx) return true; @@ -1279,6 +1316,12 @@ struct llama_layer { struct lm_ggml_tensor * ffn_down; // w2 struct lm_ggml_tensor * ffn_up; // w3 + // ff MoE + struct lm_ggml_tensor * ffn_gate_inp; + struct lm_ggml_tensor * ffn_gate_exp[LLAMA_MAX_EXPERTS]; + struct lm_ggml_tensor * ffn_down_exp[LLAMA_MAX_EXPERTS]; + struct lm_ggml_tensor * ffn_up_exp [LLAMA_MAX_EXPERTS]; + // ff bias struct lm_ggml_tensor * ffn_down_b; // b2 struct lm_ggml_tensor * ffn_up_b; // b3 @@ -1403,6 +1446,7 @@ struct llama_model { struct lm_ggml_tensor * output_norm; struct lm_ggml_tensor * output_norm_b; struct lm_ggml_tensor * output; + struct lm_ggml_tensor * output_b; std::vector layers; @@ -1488,6 +1532,10 @@ struct llama_context { // decode output (2-dimensional array: [n_tokens][n_vocab]) std::vector logits; +#ifndef NDEBUG + // guard against access to unset logits + std::vector logits_valid; +#endif bool logits_all = false; // input embedding (1-dimensional array: [n_embd]) @@ -1538,7 +1586,7 @@ static bool llama_kv_cache_init( cache.cells.clear(); cache.cells.resize(n_ctx); - cache.buf.resize(n_elements*(lm_ggml_type_sizef(ktype) + lm_ggml_type_sizef(vtype)) + 2u*n_layer*lm_ggml_tensor_overhead()); + cache.buf.resize(lm_ggml_row_size(ktype, n_elements) + lm_ggml_row_size(vtype, n_elements) + 2u*n_layer*lm_ggml_tensor_overhead()); memset(cache.buf.data, 0, cache.buf.size); struct lm_ggml_init_params params; @@ -1916,7 +1964,7 @@ namespace GGUFMeta { target = override->bool_value; return true; } - return true; + return false; } template @@ -2376,25 +2424,25 @@ static std::string llama_model_ftype_name(llama_ftype ftype) { switch (ftype) { case 
LLAMA_FTYPE_ALL_F32: return "all F32"; - case LLAMA_FTYPE_MOSTLY_F16: return "mostly F16"; - case LLAMA_FTYPE_MOSTLY_Q4_0: return "mostly Q4_0"; - case LLAMA_FTYPE_MOSTLY_Q4_1: return "mostly Q4_1"; + case LLAMA_FTYPE_MOSTLY_F16: return "F16"; + case LLAMA_FTYPE_MOSTLY_Q4_0: return "Q4_0"; + case LLAMA_FTYPE_MOSTLY_Q4_1: return "Q4_1"; case LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16: - return "mostly Q4_1, some F16"; - case LLAMA_FTYPE_MOSTLY_Q5_0: return "mostly Q5_0"; - case LLAMA_FTYPE_MOSTLY_Q5_1: return "mostly Q5_1"; - case LLAMA_FTYPE_MOSTLY_Q8_0: return "mostly Q8_0"; + return "Q4_1, some F16"; + case LLAMA_FTYPE_MOSTLY_Q5_0: return "Q5_0"; + case LLAMA_FTYPE_MOSTLY_Q5_1: return "Q5_1"; + case LLAMA_FTYPE_MOSTLY_Q8_0: return "Q8_0"; // K-quants - case LLAMA_FTYPE_MOSTLY_Q2_K: return "mostly Q2_K"; - case LLAMA_FTYPE_MOSTLY_Q3_K_S: return "mostly Q3_K - Small"; - case LLAMA_FTYPE_MOSTLY_Q3_K_M: return "mostly Q3_K - Medium"; - case LLAMA_FTYPE_MOSTLY_Q3_K_L: return "mostly Q3_K - Large"; - case LLAMA_FTYPE_MOSTLY_Q4_K_S: return "mostly Q4_K - Small"; - case LLAMA_FTYPE_MOSTLY_Q4_K_M: return "mostly Q4_K - Medium"; - case LLAMA_FTYPE_MOSTLY_Q5_K_S: return "mostly Q5_K - Small"; - case LLAMA_FTYPE_MOSTLY_Q5_K_M: return "mostly Q5_K - Medium"; - case LLAMA_FTYPE_MOSTLY_Q6_K: return "mostly Q6_K"; + case LLAMA_FTYPE_MOSTLY_Q2_K: return "Q2_K"; + case LLAMA_FTYPE_MOSTLY_Q3_K_S: return "Q3_K - Small"; + case LLAMA_FTYPE_MOSTLY_Q3_K_M: return "Q3_K - Medium"; + case LLAMA_FTYPE_MOSTLY_Q3_K_L: return "Q3_K - Large"; + case LLAMA_FTYPE_MOSTLY_Q4_K_S: return "Q4_K - Small"; + case LLAMA_FTYPE_MOSTLY_Q4_K_M: return "Q4_K - Medium"; + case LLAMA_FTYPE_MOSTLY_Q5_K_S: return "Q5_K - Small"; + case LLAMA_FTYPE_MOSTLY_Q5_K_M: return "Q5_K - Medium"; + case LLAMA_FTYPE_MOSTLY_Q6_K: return "Q6_K"; default: return "unknown, may not work"; } @@ -2451,6 +2499,16 @@ static void llm_load_hparams( ml.get_key (LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff); ml.get_key (LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head); ml.get_key (LLM_KV_BLOCK_COUNT, hparams.n_layer); + ml.get_key (LLM_KV_EXPERT_COUNT, hparams.n_expert, false); + ml.get_key (LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used, false); + + LM_GGML_ASSERT(hparams.n_expert <= LLAMA_MAX_EXPERTS); + LM_GGML_ASSERT(hparams.n_expert_used <= hparams.n_expert); + if (hparams.n_expert > 0) { + LM_GGML_ASSERT(hparams.n_expert_used > 0); + } else { + LM_GGML_ASSERT(hparams.n_expert_used == 0); + } // n_head_kv is optional, default to n_head hparams.n_head_kv = hparams.n_head; @@ -2502,6 +2560,7 @@ static void llm_load_hparams( ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { + case 22: model.type = e_model::MODEL_1B; break; case 26: model.type = e_model::MODEL_3B; break; case 32: model.type = e_model::MODEL_7B; break; case 40: model.type = e_model::MODEL_13B; break; @@ -2603,6 +2662,15 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_PHI2: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + + switch (hparams.n_layer) { + case 32: model.type = e_model::MODEL_3B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; default: (void)0; } @@ -2769,7 +2837,7 @@ static void llm_load_vocab( // The assumption is, since special tokens aren't meant to be exposed to end user, they are designed // to be unmatchable by the tokenizer, therefore tokens from the vocab, which are unmatchable by the tokenizer // are special tokens. 
- // From testing, this appears to corelate 1:1 with special tokens. + // From testing, this appears to correlate 1:1 with special tokens. // // Counting special tokens and verifying in only one direction @@ -2882,6 +2950,8 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) { LLAMA_LOG_INFO("%s: f_clamp_kqv = %.1e\n", __func__, hparams.f_clamp_kqv); LLAMA_LOG_INFO("%s: f_max_alibi_bias = %.1e\n", __func__, hparams.f_max_alibi_bias); LLAMA_LOG_INFO("%s: n_ff = %u\n", __func__, hparams.n_ff); + LLAMA_LOG_INFO("%s: n_expert = %u\n", __func__, hparams.n_expert); + LLAMA_LOG_INFO("%s: n_expert_used = %u\n", __func__, hparams.n_expert_used); LLAMA_LOG_INFO("%s: rope scaling = %s\n", __func__, rope_scaling_type.c_str()); LLAMA_LOG_INFO("%s: freq_base_train = %.1f\n", __func__, hparams.rope_freq_base_train); LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train); @@ -2953,7 +3023,7 @@ static void llm_load_tensors( (void) main_gpu; - enum lm_ggml_backend_type llama_backend_offload = LM_GGML_BACKEND_CPU; + enum lm_ggml_backend_type llama_backend_offload = LM_GGML_BACKEND_CPU; enum lm_ggml_backend_type llama_backend_offload_split = LM_GGML_BACKEND_CPU; #ifdef LM_GGML_USE_CUBLAS @@ -3036,9 +3106,26 @@ static void llm_load_tensors( layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend); - layer.ffn_gate = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, backend_split); - layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split); - layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); + layer.ffn_gate_inp = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd}, backend, false); + + if (layer.ffn_gate_inp == nullptr) { + LM_GGML_ASSERT(hparams.n_expert == 0); + LM_GGML_ASSERT(hparams.n_expert_used == 0); + + layer.ffn_gate = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, backend_split); + layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split); + layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); + } else { + LM_GGML_ASSERT(hparams.n_expert > 0); + LM_GGML_ASSERT(hparams.n_expert_used > 0); + + // MoE branch + for (uint32_t x = 0; x < hparams.n_expert; ++x) { + layer.ffn_gate_exp[x] = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_GATE_EXP, "weight", i, x), {n_embd, n_ff}, backend_split); + layer.ffn_down_exp[x] = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN_EXP, "weight", i, x), { n_ff, n_embd}, backend_split); + layer.ffn_up_exp[x] = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP_EXP, "weight", i, x), {n_embd, n_ff}, backend_split); + } + } if (backend == LM_GGML_BACKEND_GPU) { vram_weights += @@ -3048,8 +3135,18 @@ static void llm_load_tensors( (layer.bk ? lm_ggml_nbytes(layer.bk) : 0) + (layer.bv ? lm_ggml_nbytes(layer.bv) : 0) + (layer.bo ? 
lm_ggml_nbytes(layer.bo) : 0) + - lm_ggml_nbytes(layer.ffn_norm) + lm_ggml_nbytes(layer.ffn_gate) + - lm_ggml_nbytes(layer.ffn_down) + lm_ggml_nbytes(layer.ffn_up); + lm_ggml_nbytes(layer.ffn_norm); + + if (layer.ffn_gate_inp == nullptr) { + vram_weights += + lm_ggml_nbytes(layer.ffn_gate) + lm_ggml_nbytes(layer.ffn_down) + lm_ggml_nbytes(layer.ffn_up); + } else { + vram_weights += lm_ggml_nbytes(layer.ffn_gate_inp); + for (uint32_t x = 0; x < hparams.n_expert; ++x) { + vram_weights += + lm_ggml_nbytes(layer.ffn_gate_exp[x]) + lm_ggml_nbytes(layer.ffn_down_exp[x]) + lm_ggml_nbytes(layer.ffn_up_exp[x]); + } + } } } } break; @@ -3569,7 +3666,73 @@ static void llm_load_tensors( } } } break; + case LLM_ARCH_PHI2: + { + model.tok_embd = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, LM_GGML_BACKEND_CPU); + + // output + { + lm_ggml_backend_type backend_norm; + lm_ggml_backend_type backend_output; + + if (n_gpu_layers > int(n_layer)) { + backend_norm = llama_backend_offload; + backend_output = llama_backend_offload; + } else { + backend_norm = LM_GGML_BACKEND_CPU; + backend_output = LM_GGML_BACKEND_CPU; + } + + model.output_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, backend_norm); + model.output_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, backend_norm); + model.output = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, backend_output); + model.output_b = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "bias"), {n_vocab}, backend_output); + + if (backend_norm == LM_GGML_BACKEND_GPU) { + vram_weights += lm_ggml_nbytes(model.output_norm); + vram_weights += lm_ggml_nbytes(model.output_norm_b); + vram_weights += lm_ggml_nbytes(model.output); + vram_weights += lm_ggml_nbytes(model.output_b); + } + } + + const uint32_t n_ff = hparams.n_ff; + + const int i_gpu_start = n_layer - n_gpu_layers; + + model.layers.resize(n_layer); + for (uint32_t i = 0; i < n_layer; ++i) { + const lm_ggml_backend_type backend = int(i) < i_gpu_start ? LM_GGML_BACKEND_CPU : llama_backend_offload; // NOLINT + const lm_ggml_backend_type backend_split = int(i) < i_gpu_start ? 
LM_GGML_BACKEND_CPU : llama_backend_offload_split; // NOLINT + + auto & layer = model.layers[i]; + + layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend); + layer.attn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, backend); + + layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, backend_split); + layer.bqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, backend); + + layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split); + layer.bo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, backend); + + layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, backend_split); + layer.ffn_down_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, backend); + + layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); + layer.ffn_up_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, backend); + + if (backend == LM_GGML_BACKEND_GPU) { + vram_weights += + lm_ggml_nbytes(layer.attn_norm) + lm_ggml_nbytes(layer.attn_norm_b) + + lm_ggml_nbytes(layer.wqkv) + lm_ggml_nbytes(layer.bqkv) + + lm_ggml_nbytes(layer.wo) + lm_ggml_nbytes(layer.bo) + + lm_ggml_nbytes(layer.ffn_up) + lm_ggml_nbytes(layer.ffn_up_b) + + lm_ggml_nbytes(layer.ffn_down) + lm_ggml_nbytes(layer.ffn_down_b); + } + } + } break; default: throw std::runtime_error("unknown architecture"); } @@ -3766,8 +3929,8 @@ static void llm_build_k_shift( lm_ggml_rope_custom_inplace(ctx, lm_ggml_view_3d(ctx, kv.k_l[il], n_embd_head, n_head_kv, n_ctx, - lm_ggml_type_sizef(kv.k_l[il]->type)*n_embd_head, - lm_ggml_type_sizef(kv.k_l[il]->type)*n_embd_gqa, + lm_ggml_row_size(kv.k_l[il]->type, n_embd_head), + lm_ggml_row_size(kv.k_l[il]->type, n_embd_gqa), 0), K_shift, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); @@ -3796,7 +3959,7 @@ static void llm_build_kv_store( cb(v_cur_t, "v_cur_t", il); struct lm_ggml_tensor * k_cache_view = lm_ggml_view_1d(ctx, kv.k_l[il], n_tokens*n_embd_gqa, - (lm_ggml_type_sizef(kv.k_l[il]->type)*n_embd_gqa)*kv_head); + (lm_ggml_row_size(kv.k_l[il]->type, n_embd_gqa))*kv_head); cb(k_cache_view, "k_cache_view", il); struct lm_ggml_tensor * v_cache_view = lm_ggml_view_2d(ctx, kv.v_l[il], n_tokens, n_embd_gqa, @@ -3930,6 +4093,7 @@ static struct lm_ggml_tensor * llm_build_ffn( // if max_alibi_bias > 0 then apply ALiBi static struct lm_ggml_tensor * llm_build_kqv( struct lm_ggml_context * ctx, + const llama_model & model, const llama_hparams & hparams, const llama_kv_cache & kv, struct lm_ggml_tensor * wo, @@ -3941,6 +4105,7 @@ static struct lm_ggml_tensor * llm_build_kqv( int32_t n_tokens, int32_t n_kv, float max_alibi_bias, + float scale, const llm_build_cb & cb, int il) { const int64_t n_embd = hparams.n_embd; @@ -3955,14 +4120,20 @@ static struct lm_ggml_tensor * llm_build_kqv( struct lm_ggml_tensor * k = lm_ggml_view_3d(ctx, kv.k_l[il], n_embd_head, n_kv, n_head_kv, - lm_ggml_type_sizef(kv.k_l[il]->type)*n_embd_gqa, - lm_ggml_type_sizef(kv.k_l[il]->type)*n_embd_head, + lm_ggml_row_size(kv.k_l[il]->type, n_embd_gqa), + lm_ggml_row_size(kv.k_l[il]->type, n_embd_head), 0); cb(k, "k", il); struct lm_ggml_tensor * kq = lm_ggml_mul_mat(ctx, k, q); cb(kq, "kq", il); + if (model.arch == LLM_ARCH_PHI2) { + // for this arch, we need to 
perform the KQ multiplication with F32 precision, otherwise we get NaNs + // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847 + lm_ggml_mul_mat_set_prec(kq, LM_GGML_PREC_F32); + } + if (max_alibi_bias > 0.0f) { // temporary branch until we figure out how to handle lm_ggml_alibi through lm_ggml_add kq = lm_ggml_scale(ctx, kq, kq_scale); @@ -3982,7 +4153,7 @@ static struct lm_ggml_tensor * llm_build_kqv( kq = lm_ggml_soft_max(ctx, kq); cb(kq, "kq_soft_max", il); } else { - kq = lm_ggml_soft_max_ext(ctx, kq, kq_mask, 1.0f/sqrtf(float(n_embd_head))); + kq = lm_ggml_soft_max_ext(ctx, kq, kq_mask, scale); cb(kq, "kq_soft_max_ext", il); } @@ -4030,6 +4201,8 @@ struct llm_build_context { const int64_t n_head_kv; const int64_t n_embd_head; const int64_t n_embd_gqa; + const int64_t n_expert; + const int64_t n_expert_used; const float freq_base; const float freq_scale; @@ -4071,6 +4244,8 @@ struct llm_build_context { n_head_kv (hparams.n_head_kv), n_embd_head (hparams.n_embd_head()), n_embd_gqa (hparams.n_embd_gqa()), + n_expert (hparams.n_expert), + n_expert_used (hparams.n_expert_used), freq_base (cparams.rope_freq_base), freq_scale (cparams.rope_freq_scale), ext_factor (cparams.yarn_ext_factor), @@ -4185,9 +4360,9 @@ struct llm_build_context { llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il); - cur = llm_build_kqv(ctx0, hparams, kv_self, + cur = llm_build_kqv(ctx0, model, hparams, kv_self, model.layers[il].wo, model.layers[il].bo, - Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, cb, il); + Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il); cb(cur, "kqv_out", il); } @@ -4195,7 +4370,7 @@ struct llm_build_context { cb(ffn_inp, "ffn_inp", il); // feed-forward network - { + if (model.layers[il].ffn_gate_inp == nullptr) { cur = llm_build_norm(ctx0, ffn_inp, hparams, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, cb, il); @@ -4207,6 +4382,69 @@ struct llm_build_context { model.layers[il].ffn_down, NULL, LLM_FFN_SILU, LLM_FFN_PAR, cb, il); cb(cur, "ffn_out", il); + } else { + // MoE branch + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + + lm_ggml_tensor * logits = lm_ggml_mul_mat(ctx0, model.layers[il].ffn_gate_inp, cur); // [n_tokens, num_experts] + cb(logits, "ffn_moe_logits", il); + + lm_ggml_tensor * probs = lm_ggml_soft_max(ctx0, logits); // [n_tokens, num_experts] + cb(probs, "ffn_moe_probs", il); + + // select experts + lm_ggml_tensor * selected_experts = lm_ggml_top_k(ctx0, probs, n_expert_used); // [n_tokens, num_experts_per_tok] + cb(selected_experts->src[0], "ffn_moe_argsort", il); + + lm_ggml_tensor * weights = lm_ggml_get_rows(ctx0, + lm_ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts); + cb(weights, "ffn_moe_weights", il); + + weights = lm_ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens); // [n_tokens, num_experts_per_tok] + + lm_ggml_tensor * weights_sum = lm_ggml_sum_rows(ctx0, weights); + cb(weights_sum, "ffn_moe_weights_sum", il); + + weights = lm_ggml_div(ctx0, weights, weights_sum); // [n_tokens, num_experts_per_tok] + cb(weights, "ffn_moe_weights_norm", il); + + // compute expert outputs + lm_ggml_tensor * moe_out = nullptr; + + for (int i = 0; i < n_expert_used; ++i) { + lm_ggml_tensor * cur_expert; + + lm_ggml_tensor * cur_up = lm_ggml_mul_mat_id(ctx0, model.layers[il].ffn_up_exp, n_expert, selected_experts, i, cur); + cb(cur_up, "ffn_moe_up", 
il); + + lm_ggml_tensor * cur_gate = lm_ggml_mul_mat_id(ctx0, model.layers[il].ffn_gate_exp, n_expert, selected_experts, i, cur); + cb(cur_gate, "ffn_moe_gate", il); + + cur_gate = lm_ggml_silu(ctx0, cur_gate); + cb(cur_gate, "ffn_moe_silu", il); + + cur_expert = lm_ggml_mul(ctx0, cur_up, cur_gate); // [n_tokens, n_embd] + cb(cur_expert, "ffn_moe_gate_par", il); + + cur_expert = lm_ggml_mul_mat_id(ctx0, model.layers[il].ffn_down_exp, n_expert, selected_experts, i, cur_expert); // [n_tokens, n_embd] + cb(cur_expert, "ffn_moe_down", il); + + cur_expert = lm_ggml_mul(ctx0, cur_expert, + lm_ggml_view_2d(ctx0, weights, 1, n_tokens, weights->nb[1], i*weights->nb[0])); + cb(cur_expert, "ffn_moe_weighted", il); + + if (i == 0) { + moe_out = cur_expert; + } else { + moe_out = lm_ggml_add(ctx0, moe_out, cur_expert); + cb(moe_out, "ffn_moe_out", il); + } + } + + cur = moe_out; } cur = lm_ggml_add(ctx0, cur, ffn_inp); @@ -4305,9 +4543,9 @@ struct llm_build_context { // apply ALiBi for 13B model const float max_alibi_bias = model.type == MODEL_13B ? 8.0f : -1.0f; - cur = llm_build_kqv(ctx0, hparams, kv_self, + cur = llm_build_kqv(ctx0, model, hparams, kv_self, model.layers[il].wo, NULL, - Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, max_alibi_bias, cb, il); + Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, max_alibi_bias, 1.0f/sqrtf(float(n_embd_head)), cb, il); cb(cur, "kqv_out", il); } @@ -4429,9 +4667,9 @@ struct llm_build_context { llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il); - cur = llm_build_kqv(ctx0, hparams, kv_self, + cur = llm_build_kqv(ctx0, model, hparams, kv_self, model.layers[il].wo, NULL, - Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, cb, il); + Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il); cb(cur, "kqv_out", il); } @@ -4529,9 +4767,9 @@ struct llm_build_context { llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il); - cur = llm_build_kqv(ctx0, hparams, kv_self, + cur = llm_build_kqv(ctx0, model, hparams, kv_self, model.layers[il].wo, model.layers[il].bo, - Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, cb, il); + Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il); cb(cur, "kqv_out", il); } @@ -4738,9 +4976,9 @@ struct llm_build_context { llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il); // TODO: not tested, could be broken - cur = llm_build_kqv(ctx0, hparams, kv_self, + cur = llm_build_kqv(ctx0, model, hparams, kv_self, model.layers[il].wo, model.layers[il].bo, - Q, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, cb, il); + Q, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il); cb(cur, "kqv_out", il); } @@ -4829,9 +5067,9 @@ struct llm_build_context { llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il); - cur = llm_build_kqv(ctx0, hparams, kv_self, + cur = llm_build_kqv(ctx0, model, hparams, kv_self, model.layers[il].wo, NULL, - Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, 8.0f, cb, il); + Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, 8.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il); cb(cur, "kqv_out", il); } @@ -4926,9 +5164,9 @@ struct llm_build_context { llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il); - cur = llm_build_kqv(ctx0, hparams, kv_self, + cur = llm_build_kqv(ctx0, model, hparams, kv_self, 
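Numerically, the ffn_moe_* graph introduced above is standard top-k expert routing: softmax the router logits, keep the n_expert_used largest probabilities, renormalize them, and take the weighted sum of the selected experts' FFN outputs. A plain C++ sketch of that math for a single token; the expert_ffn callback is a toy stand-in for the mul_mat_id up/gate/down + SiLU evaluation in the real graph:

    #include <algorithm>
    #include <cmath>
    #include <functional>
    #include <numeric>
    #include <vector>

    // Route one token through a mixture of experts.
    // logits:     router scores, one per expert        (ffn_moe_logits)
    // n_used:     experts evaluated per token          (n_expert_used)
    // expert_ffn: evaluates expert e on the input
    std::vector<float> moe_forward(const std::vector<float> & logits, int n_used,
                                   const std::vector<float> & x,
                                   const std::function<std::vector<float>(int, const std::vector<float> &)> & expert_ffn) {
        const int n_expert = (int) logits.size();

        // softmax over the router logits               (ffn_moe_probs)
        std::vector<float> probs(n_expert);
        const float m = *std::max_element(logits.begin(), logits.end());
        float sum = 0.0f;
        for (int e = 0; e < n_expert; ++e) { probs[e] = std::exp(logits[e] - m); sum += probs[e]; }
        for (float & p : probs) { p /= sum; }

        // pick the n_used most probable experts        (ffn_moe_argsort + top_k)
        std::vector<int> idx(n_expert);
        std::iota(idx.begin(), idx.end(), 0);
        std::partial_sort(idx.begin(), idx.begin() + n_used, idx.end(),
                          [&](int a, int b) { return probs[a] > probs[b]; });

        // renormalize the selected weights to sum to 1 (ffn_moe_weights_norm)
        float wsum = 0.0f;
        for (int i = 0; i < n_used; ++i) { wsum += probs[idx[i]]; }

        // weighted sum of the experts' outputs         (ffn_moe_out)
        std::vector<float> out(x.size(), 0.0f);
        for (int i = 0; i < n_used; ++i) {
            const std::vector<float> y = expert_ffn(idx[i], x);
            for (size_t d = 0; d < out.size(); ++d) { out[d] += (probs[idx[i]] / wsum) * y[d]; }
        }
        return out;
    }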
model.layers[il].wo, model.layers[il].bo, - Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, 8.0f, cb, il); + Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, 8.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il); cb(cur, "kqv_out", il); } @@ -5020,9 +5258,9 @@ struct llm_build_context { llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il); - cur = llm_build_kqv(ctx0, hparams, kv_self, + cur = llm_build_kqv(ctx0, model, hparams, kv_self, model.layers[il].wo, NULL, - Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, hparams.f_max_alibi_bias, cb, il); + Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, hparams.f_max_alibi_bias, 1.0f/sqrtf(float(n_embd_head)), cb, il); cb(cur, "kqv_out", il); } @@ -5133,9 +5371,9 @@ struct llm_build_context { llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il); - cur = llm_build_kqv(ctx0, hparams, kv_self, + cur = llm_build_kqv(ctx0, model, hparams, kv_self, model.layers[il].wo, NULL, - Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, cb, il); + Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il); cb(cur, "kqv_out", il); } @@ -5192,15 +5430,15 @@ struct llm_build_context { cb(inpL, "inp_embd", -1); // inp_pos - contains the positions - struct lm_ggml_tensor * inp_pos= lm_ggml_new_tensor_1d(ctx0, LM_GGML_TYPE_I32, n_tokens); + struct lm_ggml_tensor * inp_pos = lm_ggml_new_tensor_1d(ctx0, LM_GGML_TYPE_I32, n_tokens); cb(inp_pos, "inp_pos", -1); // KQ_scale - struct lm_ggml_tensor * KQ_scale= lm_ggml_new_tensor_1d(ctx0, LM_GGML_TYPE_F32, 1); + struct lm_ggml_tensor * KQ_scale = lm_ggml_new_tensor_1d(ctx0, LM_GGML_TYPE_F32, 1); cb(KQ_scale, "KQ_scale", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct lm_ggml_tensor * KQ_mask= lm_ggml_new_tensor_3d(ctx0, LM_GGML_TYPE_F32, n_kv, n_tokens, 1); + struct lm_ggml_tensor * KQ_mask = lm_ggml_new_tensor_3d(ctx0, LM_GGML_TYPE_F32, n_kv, n_tokens, 1); cb(KQ_mask, "KQ_mask", -1); // shift the entire K-cache if needed @@ -5250,9 +5488,9 @@ struct llm_build_context { llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il); - cur = llm_build_kqv(ctx0, hparams, kv_self, + cur = llm_build_kqv(ctx0, model, hparams, kv_self, model.layers[il].wo, NULL, - Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, cb, il); + Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il); cb(cur, "kqv_out", il); } @@ -5294,6 +5532,122 @@ struct llm_build_context { lm_ggml_build_forward_expand(gf, cur); + return gf; + } + struct lm_ggml_cgraph * build_phi2() { + struct lm_ggml_cgraph * gf = lm_ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + + struct lm_ggml_tensor * cur; + struct lm_ggml_tensor * attn_norm_output; + struct lm_ggml_tensor * ffn_output; + struct lm_ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb); + cb(inpL, "inp_embd", -1); + + // inp_pos - contains the positions + struct lm_ggml_tensor * inp_pos = lm_ggml_new_tensor_1d(ctx0, LM_GGML_TYPE_I32, n_tokens); + cb(inp_pos, "inp_pos", -1); + + // Q_scale + struct lm_ggml_tensor * Q_scale = lm_ggml_new_tensor_1d(ctx0, LM_GGML_TYPE_F32, 1); + cb(Q_scale, "Q_scale", -1); + + // KQ_scale + struct lm_ggml_tensor * KQ_scale = lm_ggml_new_tensor_1d(ctx0, LM_GGML_TYPE_F32, 1); + cb(KQ_scale, "KQ_scale", -1); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct lm_ggml_tensor * KQ_mask = 
lm_ggml_new_tensor_3d(ctx0, LM_GGML_TYPE_F32, n_kv, n_tokens, 1); + cb(KQ_mask, "KQ_mask", -1); + + // shift the entire K-cache if needed + if (do_rope_shift) { + llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, n_embd_head, freq_base, freq_scale, cb); + } + + for (int il = 0; il < n_layer; ++il) { + attn_norm_output = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, + model.layers[il].attn_norm_b, + LLM_NORM, cb, il); + cb(attn_norm_output, "attn_norm", il); + + // self-attention + { + cur = lm_ggml_mul_mat(ctx0, model.layers[il].wqkv, attn_norm_output); + cb(cur, "wqkv", il); + + cur = lm_ggml_add(ctx0, cur, model.layers[il].bqkv); + cb(cur, "bqkv", il); + + struct lm_ggml_tensor * Qcur = lm_ggml_cont(ctx0, lm_ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); + struct lm_ggml_tensor * Kcur = lm_ggml_cont(ctx0, lm_ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); + struct lm_ggml_tensor * Vcur = lm_ggml_cont(ctx0, lm_ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + Qcur = lm_ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = lm_ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = lm_ggml_rope_custom( + ctx0, Qcur, inp_pos, hparams.n_rot, 2, 0, n_orig_ctx, + freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Qcur, "Qcur", il); + + Qcur = lm_ggml_scale(ctx0, Qcur, Q_scale); + cb(Qcur, "Qcur", il); + + Kcur = lm_ggml_rope_custom( + ctx0, Kcur, inp_pos, hparams.n_rot, 2, 0, n_orig_ctx, + freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Kcur, "Kcur", il); + + llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il); + + cur = llm_build_kqv(ctx0, model, hparams, kv_self, + model.layers[il].wo, model.layers[il].bo, + Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f, cb, il); + cb(cur, "kqv_out", il); + } + + // FF + { + ffn_output = llm_build_ffn(ctx0, attn_norm_output, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, + NULL, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, + LLM_FFN_GELU, LLM_FFN_SEQ, cb, il); + cb(ffn_output, "ffn_out", il); + } + + cur = lm_ggml_add(ctx0, cur, ffn_output); + cb(cur, "l_out", il); + + cur = lm_ggml_add(ctx0, cur, inpL); + cb(cur, "l_out", il); + + inpL = cur; + } + + cur = llm_build_norm(ctx0, inpL, hparams, + model.output_norm, + model.output_norm_b, + LLM_NORM, cb, -1); + cb(cur, "result_norm", -1); + + cur = lm_ggml_mul_mat(ctx0, model.output, cur); + cb(cur, "result_output_no_bias", -1); + + cur = lm_ggml_add(ctx0, cur, model.output_b); + cb(cur, "result_output", -1); + + lm_ggml_build_forward_expand(gf, cur); + return gf; } }; @@ -5309,7 +5663,7 @@ enum llm_offload_func_e { OFFLOAD_FUNC_FRC, // force offload OFFLOAD_FUNC_KQV, OFFLOAD_FUNC_NR, - OFFLOAD_FUNC_EMB, + OFFLOAD_FUNC_EMB, // embeddings OFFLOAD_FUNC_OUT, }; @@ -5394,6 +5748,7 @@ static const std::unordered_map k_offload_map { "pos_embd", OFFLOAD_FUNC_NR }, { "inp_pos", OFFLOAD_FUNC_FRC }, // this is often used for KQ ops (e.g. 
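Two details in build_phi2 above fit together: Qcur is multiplied by Q_scale (set to 1/sqrt(n_embd_head) when the graph inputs are allocated) before the K*Q product, and llm_build_kqv is then called with a softmax scale of 1.0f, since the 1/sqrt(d) factor has already been applied. Pre-scaling Q keeps the F32 K*Q matmul in a safer numerical range; mathematically the two placements are equivalent, as this small self-contained check illustrates:

    #include <cassert>
    #include <cmath>
    #include <cstdio>
    #include <vector>

    float dot(const std::vector<float> & a, const std::vector<float> & b) {
        float s = 0.0f;
        for (size_t i = 0; i < a.size(); ++i) { s += a[i] * b[i]; }
        return s;
    }

    int main() {
        const std::vector<float> q = { 0.3f, -1.2f,  2.0f, 0.5f };
        const std::vector<float> k = { 1.1f,  0.4f, -0.7f, 0.9f };
        const float scale = 1.0f / std::sqrt((float) q.size());   // 1/sqrt(n_embd_head)

        // default path: scale the raw K.Q score inside soft_max_ext
        const float score_a = scale * dot(q, k);

        // phi2 path: scale Q first, then pass scale = 1.0f to the softmax
        std::vector<float> q_scaled = q;
        for (float & v : q_scaled) { v *= scale; }
        const float score_b = 1.0f * dot(q_scaled, k);

        assert(std::fabs(score_a - score_b) < 1e-6f);
        printf("kq score: %f == %f\n", score_a, score_b);
        return 0;
    }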
rope) + { "Q_scale", OFFLOAD_FUNC_FRC }, { "KQ_scale", OFFLOAD_FUNC_FRC }, { "KQ_mask", OFFLOAD_FUNC_FRC }, { "K_shift", OFFLOAD_FUNC_FRC }, @@ -5461,9 +5816,24 @@ static const std::unordered_map k_offload_map { "ffn_relu", OFFLOAD_FUNC }, { "ffn_sqr(relu)", OFFLOAD_FUNC }, + { "ffn_moe_logits", OFFLOAD_FUNC }, + { "ffn_moe_probs", OFFLOAD_FUNC }, + { "ffn_moe_argsort", OFFLOAD_FUNC }, + { "ffn_moe_weights", OFFLOAD_FUNC }, + { "ffn_moe_weights_sum", OFFLOAD_FUNC }, + { "ffn_moe_weights_norm", OFFLOAD_FUNC }, + { "ffn_moe_weighted", OFFLOAD_FUNC }, + { "ffn_moe_up", OFFLOAD_FUNC }, + { "ffn_moe_gate", OFFLOAD_FUNC }, + { "ffn_moe_silu", OFFLOAD_FUNC }, + { "ffn_moe_gate_par", OFFLOAD_FUNC }, + { "ffn_moe_down", OFFLOAD_FUNC }, + { "ffn_moe_out", OFFLOAD_FUNC }, + { "l_out", OFFLOAD_FUNC }, { "result_norm", OFFLOAD_FUNC_EMB }, + { "result_output_no_bias", OFFLOAD_FUNC_EMB }, { "result_output", OFFLOAD_FUNC_OUT }, }; @@ -5481,6 +5851,7 @@ static struct lm_ggml_cgraph * llama_build_graph( bool alloc_inp_tokens = false; bool alloc_inp_embd = false; bool alloc_inp_pos = false; + bool alloc_inp_Q_scale = false; bool alloc_inp_KQ_scale = false; bool alloc_inp_KQ_mask = false; bool alloc_inp_K_shift = false; @@ -5548,7 +5919,7 @@ static struct lm_ggml_cgraph * llama_build_graph( alloc_inp_pos = true; } - if (!alloc_inp_KQ_scale && strcmp(name, "KQ_scale") == 0) { + if (!alloc_inp_Q_scale && strcmp(name, "Q_scale") == 0) { lm_ggml_allocr_alloc(lctx.alloc, cur); if (!lm_ggml_allocr_is_measure(lctx.alloc)) { @@ -5556,6 +5927,23 @@ static struct lm_ggml_cgraph * llama_build_graph( lm_ggml_set_f32(cur, 1.0f/sqrtf(float(n_embd_head))); } + alloc_inp_Q_scale = true; + } + + if (!alloc_inp_KQ_scale && strcmp(name, "KQ_scale") == 0) { + lm_ggml_allocr_alloc(lctx.alloc, cur); + + if (!lm_ggml_allocr_is_measure(lctx.alloc)) { + const int64_t n_embd_head = model.hparams.n_embd_head(); + if (model.arch == LLM_ARCH_PHI2) { + // with phi2, we scale the Q to avoid precision issues + // ref: https://github.com/ml-explore/mlx-examples/blob/08e862336ade809bc37d1035f94b359e7d1a5152/phi2/phi2.py#L64-L66 + lm_ggml_set_f32(cur, 1.0f); + } else { + lm_ggml_set_f32(cur, 1.0f/sqrtf(float(n_embd_head))); + } + } + alloc_inp_KQ_scale = true; } @@ -5780,6 +6168,10 @@ static struct lm_ggml_cgraph * llama_build_graph( { result = llm.build_qwen(); } break; + case LLM_ARCH_PHI2: + { + result = llm.build_phi2(); + } break; default: LM_GGML_ASSERT(false); } @@ -5857,7 +6249,7 @@ static int llama_decode_internal( const int64_t n_embd = hparams.n_embd; const int64_t n_vocab = hparams.n_vocab; - // helpers for smoother batch API transistion + // helpers for smoother batch API transition // after deprecating the llama_eval calls, these will be removed std::vector pos; @@ -5913,12 +6305,16 @@ static int llama_decode_internal( lm_ggml_allocr_alloc_graph(lctx.alloc, gf); - struct lm_ggml_tensor * res = gf->nodes[gf->n_nodes - 1]; - struct lm_ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 2]; - - LM_GGML_ASSERT(strcmp(res->name, "result_output") == 0); - LM_GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0); + // the output is always the last tensor in the graph + struct lm_ggml_tensor * res = gf->nodes[gf->n_nodes - 1]; + LM_GGML_ASSERT(strcmp(res->name, "result_output") == 0); + // the embeddings could be the second to last tensor, or the third to last tensor + struct lm_ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 2]; + if (strcmp(embeddings->name, "result_norm") != 0) { + embeddings = gf->nodes[gf->n_nodes - 3]; + 
LM_GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0); + } #ifdef LM_GGML_USE_CUBLAS for (int i = 0; i < gf->n_leafs; i++) { @@ -6013,6 +6409,14 @@ static int llama_decode_internal( { auto & logits_out = lctx.logits; +#ifndef NDEBUG + auto & logits_valid = lctx.logits_valid; + logits_valid.clear(); + logits_valid.resize(n_tokens); + + logits_out.clear(); +#endif + if (batch.logits) { logits_out.resize(n_vocab * n_tokens); for (uint32_t i = 0; i < n_tokens; i++) { @@ -6020,13 +6424,22 @@ static int llama_decode_internal( continue; } memcpy(logits_out.data() + (n_vocab*i), (float *) lm_ggml_get_data(res) + (n_vocab*i), sizeof(float)*n_vocab); +#ifndef NDEBUG + logits_valid[i] = true; +#endif } } else if (lctx.logits_all) { logits_out.resize(n_vocab * n_tokens); memcpy(logits_out.data(), (float *) lm_ggml_get_data(res), sizeof(float)*n_vocab*n_tokens); +#ifndef NDEBUG + std::fill(logits_valid.begin(), logits_valid.end(), true); +#endif } else { logits_out.resize(n_vocab); memcpy(logits_out.data(), (float *) lm_ggml_get_data(res) + (n_vocab*(n_tokens - 1)), sizeof(float)*n_vocab); +#ifndef NDEBUG + logits_valid[0] = true; +#endif } } @@ -6636,12 +7049,12 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list< // loop over the text while (true) { - // find the first occurence of a given special token in this fragment + // find the first occurrence of a given special token in this fragment // passing offset argument only limit the "search area" but match coordinates // are still relative to the source full raw_text auto match = raw_text->find(special_token, raw_text_base_offset); - // no occurences found, stop processing this fragment for a given special token + // no occurrences found, stop processing this fragment for a given special token if (match == std::string::npos) break; // check if match is within bounds of offset <-> length @@ -7840,7 +8253,7 @@ struct llama_beam_search_data { } // Min-heaps are used to efficiently collect the top-k elements (k=n_beams). - // The repetative patterns below reflect the 2 stages of heaps: + // The repetitive patterns below reflect the 2 stages of heaps: // * Gather elements until the vector is full, then call std::make_heap() on it. // * If the heap is full and a new element is found that should be included, pop the // least element to the back(), replace it with the new, then push it into the heap. 
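The #ifndef NDEBUG blocks above add cheap bookkeeping: lctx.logits_valid records which rows of the logits buffer were actually written for this batch, so the new assert in llama_get_logits_ith (further down in this patch) can catch reads of positions whose logits were never requested. A stripped-down sketch of the pattern, assuming a hypothetical context struct rather than llama_context:

    #include <algorithm>
    #include <cassert>
    #include <cstdint>
    #include <vector>

    struct Ctx {
        int n_vocab = 32000;
        std::vector<float> logits;       // n_vocab floats per token with logits
    #ifndef NDEBUG
        std::vector<bool>  logits_valid; // one flag per token in the batch
    #endif
    };

    // Copy logits out of the graph result for the tokens that requested them.
    void store_logits(Ctx & ctx, const float * res, const std::vector<int8_t> & want, int n_tokens) {
    #ifndef NDEBUG
        ctx.logits_valid.assign(n_tokens, false);
    #endif
        ctx.logits.assign((size_t) ctx.n_vocab * n_tokens, 0.0f);
        for (int i = 0; i < n_tokens; ++i) {
            if (!want[i]) {
                continue;
            }
            std::copy(res + (size_t) ctx.n_vocab * i, res + (size_t) ctx.n_vocab * (i + 1),
                      ctx.logits.begin() + (size_t) ctx.n_vocab * i);
    #ifndef NDEBUG
            ctx.logits_valid[i] = true;
    #endif
        }
    }

    const float * get_logits_ith(const Ctx & ctx, int i) {
        assert(ctx.logits_valid.at(i));  // debug builds reject rows that were never written
        return ctx.logits.data() + (size_t) ctx.n_vocab * i;
    }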
@@ -8078,11 +8491,9 @@ static void llama_convert_tensor_internal( workers.clear(); } -static lm_ggml_type get_k_quant_type( - quantize_state_internal & qs, - lm_ggml_type new_type, const lm_ggml_tensor * tensor, llama_ftype ftype -) { +static lm_ggml_type get_k_quant_type(quantize_state_internal & qs, lm_ggml_type new_type, const lm_ggml_tensor * tensor, llama_ftype ftype) { const std::string name = lm_ggml_get_name(tensor); + // TODO: avoid hardcoded tensor names - use the TN_* constants const llm_arch arch = qs.model.arch; const auto tn = LLM_TN(arch); @@ -8116,7 +8527,18 @@ static lm_ggml_type get_k_quant_type( // nearly negligible increase in model size by quantizing this tensor with more bits: if (new_type == LM_GGML_TYPE_Q3_K || new_type == LM_GGML_TYPE_Q4_K) new_type = LM_GGML_TYPE_Q5_K; } + if (qs.model.hparams.n_expert == 8) { + // for the 8-expert model, bumping this to Q8_0 trades just ~128MB + // TODO: explore better strategies + new_type = LM_GGML_TYPE_Q8_0; + } ++qs.i_attention_wv; + } else if (name.find("attn_k.weight") != std::string::npos) { + if (qs.model.hparams.n_expert == 8) { + // for the 8-expert model, bumping this to Q8_0 trades just ~128MB + // TODO: explore better strategies + new_type = LM_GGML_TYPE_Q8_0; + } } else if (name.find("ffn_down.weight") != std::string::npos) { if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = LM_GGML_TYPE_Q3_K; else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) { @@ -8325,10 +8747,13 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s bool quantize = name.rfind("weight") == name.size() - 6; // ends with 'weight'? // quantize only 2D tensors - quantize &= (tensor->n_dims == 2); + quantize &= (lm_ggml_n_dims(tensor) == 2); quantize &= params->quantize_output_tensor || name != "output.weight"; quantize &= !params->only_copy; + // do not quantize expert gating tensors + quantize &= name.find("ffn_gate_inp.weight") == std::string::npos; + enum lm_ggml_type new_type; void * new_data; size_t new_size; @@ -8477,53 +8902,60 @@ static int llama_apply_lora_from_file_internal( const int64_t t_start_lora_us = lm_ggml_time_us(); - auto fin = std::ifstream(path_lora, std::ios::binary); - if (!fin) { - LLAMA_LOG_ERROR("%s: failed to open '%s'\n", __func__, path_lora); - return 1; - } + llama_file fin(path_lora, "rb"); // verify magic and version { - uint32_t magic; - fin.read((char *) &magic, sizeof(magic)); - uint32_t format_version; - fin.read((char *) &format_version, sizeof(format_version)); + uint32_t magic = fin.read_u32(); + if (magic != LLAMA_FILE_MAGIC_GGLA) { + LLAMA_LOG_ERROR("%s: bad file magic\n", __func__); + return 1; + } + uint32_t format_version = fin.read_u32(); if (format_version != 1) { LLAMA_LOG_ERROR("%s: unsupported file version\n", __func__ ); return 1; } } - int32_t lora_r; - int32_t lora_alpha; - fin.read((char *) &lora_r, sizeof(lora_r)); - fin.read((char *) &lora_alpha, sizeof(lora_alpha)); + int32_t lora_r = fin.read_u32(); + int32_t lora_alpha = fin.read_u32(); float scaling = scale * (float)lora_alpha / (float)lora_r; LLAMA_LOG_INFO("%s: r = %d, alpha = %d, scaling = %.2f\n", __func__, lora_r, lora_alpha, scaling); + // create a name -> tensor map of the model to accelerate lookups + // find the max tensor size to estimate the required temporary buffer size + size_t max_tensor_size = 0; + std::unordered_map model_tensors; + for (const auto & kv : model.tensors_by_name) { + model_tensors.insert(kv); + size_t f32_size = lm_ggml_nelements(kv.second) * sizeof(float); + max_tensor_size = 
std::max(max_tensor_size, f32_size); + } + // create a temporary ggml context to store the lora tensors - // todo: calculate size from biggest possible tensor - std::vector lora_buf(1024ull * 1024ull * 1024ull); + // TODO: use ggml-alloc + size_t lora_ctx_size = max_tensor_size * 3; + LLAMA_LOG_INFO("%s: allocating %.f MB for lora temporary buffer\n", __func__, lora_ctx_size / 1024.0 / 1024.0); + std::vector lora_buf(lora_ctx_size); + struct lm_ggml_init_params params; params.mem_size = lora_buf.size(); params.mem_buffer = lora_buf.data(); params.no_alloc = false; - lm_ggml_context * lora_ctx = lm_ggml_init(params); - std::unordered_map lora_tensors; + using unique_context = std::unique_ptr; - // create a name -> tensor map of the model to accelerate lookups - std::unordered_map model_tensors; - for (const auto & kv : model.tensors_by_name) { - model_tensors.insert(kv); - } + unique_context lora_ctx(nullptr, lm_ggml_free); + lora_ctx.reset(lm_ggml_init(params)); + std::unordered_map lora_tensors; // load base model std::unique_ptr ml; - lm_ggml_context * base_ctx = NULL; + + unique_context base_ctx(nullptr, lm_ggml_free); std::vector base_buf; if (path_base_model) { LLAMA_LOG_INFO("%s: loading base model from '%s'\n", __func__, path_base_model); @@ -8532,6 +8964,7 @@ static int llama_apply_lora_from_file_internal( size_t ctx_size; size_t mmapped_size; ml->calc_sizes(ctx_size, mmapped_size); + base_buf.resize(ctx_size); lm_ggml_init_params base_params; @@ -8539,9 +8972,9 @@ static int llama_apply_lora_from_file_internal( base_params.mem_buffer = base_buf.data(); base_params.no_alloc = ml->use_mmap; - base_ctx = lm_ggml_init(base_params); + base_ctx.reset(lm_ggml_init(base_params)); - // maybe this should in llama_model_loader + // maybe this should be in llama_model_loader if (ml->use_mmap) { ml->mapping.reset(new llama_mmap(&ml->file, /* prefetch */ 0, lm_ggml_is_numa())); } @@ -8554,27 +8987,35 @@ static int llama_apply_lora_from_file_internal( std::vector work_buffer; while (true) { + if (fin.tell() == fin.size) { + // eof + break; + } + int32_t n_dims; - int32_t length; + int32_t name_len; int32_t ftype; - fin.read(reinterpret_cast(&n_dims), sizeof(n_dims)); - fin.read(reinterpret_cast(&length), sizeof(length)); - fin.read(reinterpret_cast(&ftype), sizeof(ftype)); - if (fin.eof()) { - break; + fin.read_raw(&n_dims, sizeof(n_dims)); + fin.read_raw(&name_len, sizeof(name_len)); + fin.read_raw(&ftype, sizeof(ftype)); + + if (n_dims != 1 && n_dims != 2) { + LLAMA_LOG_ERROR("%s: unsupported tensor dimension %d\n", __func__, n_dims); + return 1; } int32_t ne[2] = { 1, 1 }; for (int i = 0; i < n_dims; ++i) { - fin.read(reinterpret_cast(&ne[i]), sizeof(ne[i])); + fin.read_raw(&ne[i], sizeof(ne[i])); } std::string name; { + LM_GGML_ASSERT(name_len <= 1024); char buf[1024]; - fin.read(buf, length); - name = std::string(buf, length); + fin.read_raw(buf, name_len); + name = std::string(buf, name_len); } // check for lora suffix and get the type of tensor @@ -8588,7 +9029,7 @@ static int llama_apply_lora_from_file_internal( std::string lora_type = name.substr(pos + lora_suffix.length()); std::string base_name = name; base_name.erase(pos); - // LLAMA_LOG_INFO("%s: %s => %s (lora type %s) \n", __func__, name.c_str(),base_name.c_str(), lora_type.c_str()); + // LLAMA_LOG_INFO("%s: %s => %s (lora type %s) \n", __func__, name.c_str(), base_name.c_str(), lora_type.c_str()); if (model_tensors.find(base_name) == model_tensors.end()) { LLAMA_LOG_ERROR("%s: unknown tensor '%s' in lora adapter\n", __func__, 
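The rewritten LoRA loader above reads the adapter through llama_file instead of std::ifstream, but the on-disk layout it expects is unchanged and is visible in the code: a 'ggla' magic and version, the LoRA r and alpha values, then a sequence of tensor records (dimension count, name length, ftype, the dimensions, the name, and 32-byte-aligned data). A standalone reader sketch of that layout using plain std::ifstream; field names follow the patch, the record struct is illustrative only:

    #include <cstdint>
    #include <fstream>
    #include <string>
    #include <vector>

    struct LoraTensorRecord {
        std::string    name;
        int32_t        ne[2] = { 1, 1 };
        int32_t        ftype = 0;
        std::streamoff data_offset = 0;   // 32-byte aligned start of the tensor data
    };

    template <typename T>
    bool read_pod(std::ifstream & f, T & v) {
        return bool(f.read(reinterpret_cast<char *>(&v), sizeof(v)));
    }

    bool read_lora_header(std::ifstream & f, int32_t & r, int32_t & alpha) {
        uint32_t magic = 0, version = 0;
        if (!read_pod(f, magic) || magic != 0x67676c61u) return false;   // LLAMA_FILE_MAGIC_GGLA
        if (!read_pod(f, version) || version != 1)       return false;
        return read_pod(f, r) && read_pod(f, alpha);     // scaling = scale * alpha / r
    }

    bool read_lora_record(std::ifstream & f, LoraTensorRecord & rec) {
        int32_t n_dims = 0, name_len = 0;
        if (!read_pod(f, n_dims) || !read_pod(f, name_len) || !read_pod(f, rec.ftype)) return false;
        if (n_dims != 1 && n_dims != 2) return false;    // same restriction as the loader
        for (int i = 0; i < n_dims; ++i) { read_pod(f, rec.ne[i]); }
        std::vector<char> buf(name_len);
        f.read(buf.data(), name_len);
        rec.name.assign(buf.data(), name_len);
        rec.data_offset = (std::streamoff(f.tellg()) + 31) & -32;   // align to 32 bytes
        f.seekg(rec.data_offset);                                   // tensor data follows
        return bool(f);
    }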
name.data()); @@ -8607,22 +9048,15 @@ static int llama_apply_lora_from_file_internal( return false; } } - lm_ggml_tensor * lora_tensor; - if (n_dims == 2) { - lora_tensor = lm_ggml_new_tensor_2d(lora_ctx, wtype, ne[0], ne[1]); - } - else { - LLAMA_LOG_ERROR("%s: unsupported tensor dimension %d\n", __func__, n_dims); - return 1; - } - lm_ggml_set_name(lora_tensor, "lora_tensor"); + lm_ggml_tensor * lora_tensor = lm_ggml_new_tensor_2d(lora_ctx.get(), wtype, ne[0], ne[1]); + lm_ggml_set_name(lora_tensor, name.c_str()); // load tensor data - size_t offset = fin.tellg(); + size_t offset = fin.tell(); size_t tensor_data_size = lm_ggml_nbytes(lora_tensor); offset = (offset + 31) & -32; - fin.seekg(offset); - fin.read((char*)lora_tensor->data, tensor_data_size); + fin.seek(offset, SEEK_SET); + fin.read_raw(lora_tensor->data, tensor_data_size); lora_tensors[name] = lora_tensor; @@ -8652,13 +9086,11 @@ static int llama_apply_lora_from_file_internal( // load from base model if (lm_gguf_find_tensor(ctx_gguf, base_name.c_str()) < 0) { - // TODO: throw LLAMA_LOG_ERROR("%s: error: tensor '%s' not found in base model\n", __func__, base_name.c_str()); return 1; } - // TODO: not tested!! maybe not working! - base_t = ml->create_tensor(base_ctx, base_name, { (uint32_t)dest_t->ne[0], (uint32_t)dest_t->ne[1] }, LM_GGML_BACKEND_CPU); + base_t = ml->create_tensor(base_ctx.get(), base_name, { dest_t->ne[0], dest_t->ne[1] }, LM_GGML_BACKEND_CPU); ml->load_data_for(base_t); } else { base_t = dest_t; @@ -8687,43 +9119,45 @@ static int llama_apply_lora_from_file_internal( } // w = w + BA*s - lm_ggml_tensor * BA = lm_ggml_mul_mat(lora_ctx, loraA, loraB); + lm_ggml_tensor * BA = lm_ggml_mul_mat(lora_ctx.get(), loraA, loraB); offload_func(BA); lm_ggml_set_name(BA, "BA"); if (scaling != 1.0f) { - lm_ggml_tensor * scale_tensor = lm_ggml_new_f32(lora_ctx, scaling); + lm_ggml_tensor * scale_tensor = lm_ggml_new_f32(lora_ctx.get(), scaling); lm_ggml_set_name(scale_tensor, "scale_tensor"); - BA = lm_ggml_scale_inplace(lora_ctx, BA, scale_tensor); + BA = lm_ggml_scale_inplace(lora_ctx.get(), BA, scale_tensor); offload_func(BA); lm_ggml_set_name(BA, "BA_scaled"); } lm_ggml_tensor * r; if (base_t == dest_t) { - r = lm_ggml_add_inplace(lora_ctx, dest_t, BA); + r = lm_ggml_add_inplace(lora_ctx.get(), dest_t, BA); offload_func_force_inplace(r); lm_ggml_set_name(r, "r_add_inplace"); } else { - r = lm_ggml_add(lora_ctx, base_t, BA); + r = lm_ggml_add(lora_ctx.get(), base_t, BA); offload_func(r); lm_ggml_set_name(r, "r_add"); - r = lm_ggml_cpy(lora_ctx, r, dest_t); + r = lm_ggml_cpy(lora_ctx.get(), r, dest_t); offload_func(r); lm_ggml_set_name(r, "r_cpy"); } - struct lm_ggml_cgraph * gf = lm_ggml_new_graph(lora_ctx); + struct lm_ggml_cgraph * gf = lm_ggml_new_graph(lora_ctx.get()); lm_ggml_build_forward_expand(gf, r); lm_ggml_graph_compute_helper(work_buffer, gf, n_threads); + // the tensors in the adapter must be sorted such that loraA and loraB of the same tensor are next to each other + LM_GGML_ASSERT(lora_tensors.size() == 2); + // we won't need these tensors again, reset the context to save memory - lm_ggml_free(lora_ctx); - lora_ctx = lm_ggml_init(params); + lora_ctx.reset(lm_ggml_init(params)); lora_tensors.clear(); n_tensors++; @@ -8733,12 +9167,6 @@ static int llama_apply_lora_from_file_internal( } } - // TODO: this should be in a destructor, it will leak on failure - lm_ggml_free(lora_ctx); - if (base_ctx) { - lm_ggml_free(base_ctx); - } - const int64_t t_lora_us = lm_ggml_time_us() - t_start_lora_us; LLAMA_LOG_INFO(" done 
(%.2f ms)\n", t_lora_us / 1000.0); @@ -9903,6 +10331,7 @@ float * llama_get_logits(struct llama_context * ctx) { } float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) { + assert(ctx->logits_valid.at(i)); return ctx->logits.data() + i*ctx->model.hparams.n_vocab; } diff --git a/cpp/llama.h b/cpp/llama.h index df8ff8c..d574f5d 100644 --- a/cpp/llama.h +++ b/cpp/llama.h @@ -39,6 +39,7 @@ #define LLAMA_MAX_RNG_STATE (64*1024) +#define LLAMA_FILE_MAGIC_GGLA 0x67676c61u // 'ggla' #define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn' #define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN @@ -216,7 +217,7 @@ extern "C" { // Keep the booleans together to avoid misalignment during copy-by-value. bool mul_mat_q; // if true, use experimental mul_mat_q kernels (DEPRECATED - always true) - bool logits_all; // the llama_eval() call computes all logits, not just the last one + bool logits_all; // the llama_eval() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead) bool embedding; // embedding mode only bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU }; diff --git a/cpp/log.h b/cpp/log.h index 374404e..5b25135 100644 --- a/cpp/log.h +++ b/cpp/log.h @@ -61,13 +61,13 @@ // #define LOG_TARGET stderr // #include "log.h" // -// The log target can also be redirected to a diffrent function +// The log target can also be redirected to a different function // like so: // -// #define LOG_TARGET log_handler_diffrent() +// #define LOG_TARGET log_handler_different() // #include "log.h" // -// FILE* log_handler_diffrent() +// FILE* log_handler_different() // { // return stderr; // } @@ -434,7 +434,7 @@ inline FILE *log_handler2_impl(bool change = false, LogTriState append = LogTriS // Disables logs entirely at runtime. // Makes LOG() and LOG_TEE() produce no output, -// untill enabled back. +// until enabled back. #define log_disable() log_disable_impl() // INTERNAL, DO NOT USE diff --git a/llama.cpp b/llama.cpp index 8a7b2fa..a7aee47 160000 --- a/llama.cpp +++ b/llama.cpp @@ -1 +1 @@ -Subproject commit 8a7b2fa528f130631a5f43648481596ab320ed5a +Subproject commit a7aee47b98e45539d491071b25778b833b77e387 diff --git a/scripts/bootstrap.sh b/scripts/bootstrap.sh index 8d2a965..15a8467 100755 --- a/scripts/bootstrap.sh +++ b/scripts/bootstrap.sh @@ -70,6 +70,7 @@ echo "Replacement completed successfully!" 
yarn example # Apply patch +patch -p0 -d ./cpp < ./scripts/common.h.patch patch -p0 -d ./cpp < ./scripts/common.cpp.patch patch -p0 -d ./cpp < ./scripts/log.h.patch patch -p0 -d ./cpp < ./scripts/llama.cpp.patch diff --git a/scripts/common.cpp.patch b/scripts/common.cpp.patch index b1bfd37..168fb35 100644 --- a/scripts/common.cpp.patch +++ b/scripts/common.cpp.patch @@ -1,11 +1,15 @@ ---- common.cpp.orig 2023-12-12 10:50:18 -+++ common.cpp 2023-12-12 10:50:19 -@@ -1385,8 +1385,6 @@ - const std::string & timestamp, const std::vector & prompt_tokens, const char * model_desc) { - const llama_sampling_params & sparams = params.sparams; +--- common.cpp.orig 2023-12-19 08:18:55 ++++ common.cpp 2023-12-19 08:18:26 +@@ -41,6 +41,12 @@ + #if defined(_MSC_VER) + #pragma warning(disable: 4244 4267) // possible loss of data + #endif ++ ++// build info ++int LLAMA_BUILD_NUMBER = 0; ++char const *LLAMA_COMMIT = "unknown"; ++char const *LLAMA_COMPILER = "unknown"; ++char const *LLAMA_BUILD_TARGET = "unknown"; -- fprintf(stream, "build_commit: %s\n", LLAMA_COMMIT); -- fprintf(stream, "build_number: %d\n", LLAMA_BUILD_NUMBER); - fprintf(stream, "cpu_has_arm_fma: %s\n", lm_ggml_cpu_has_arm_fma() ? "true" : "false"); - fprintf(stream, "cpu_has_avx: %s\n", lm_ggml_cpu_has_avx() ? "true" : "false"); - fprintf(stream, "cpu_has_avx2: %s\n", lm_ggml_cpu_has_avx2() ? "true" : "false"); + int32_t get_num_physical_cores() { + #ifdef __linux__ diff --git a/scripts/common.h.patch b/scripts/common.h.patch new file mode 100644 index 0000000..013695f --- /dev/null +++ b/scripts/common.h.patch @@ -0,0 +1,20 @@ +--- common.h.orig 2023-12-19 08:16:26 ++++ common.h 2023-12-19 08:17:16 +@@ -26,17 +26,6 @@ + #define die(msg) do { fputs("error: " msg "\n", stderr); exit(1); } while (0) + #define die_fmt(fmt, ...) do { fprintf(stderr, "error: " fmt "\n", __VA_ARGS__); exit(1); } while (0) + +-#define print_build_info() do { \ +- fprintf(stderr, "%s: build = %d (%s)\n", __func__, LLAMA_BUILD_NUMBER, LLAMA_COMMIT); \ +- fprintf(stderr, "%s: built with %s for %s\n", __func__, LLAMA_COMPILER, LLAMA_BUILD_TARGET); \ +-} while(0) +- +-// build info +-extern int LLAMA_BUILD_NUMBER; +-extern char const *LLAMA_COMMIT; +-extern char const *LLAMA_COMPILER; +-extern char const *LLAMA_BUILD_TARGET; +- + // + // CLI argument parsing + // diff --git a/scripts/ggml-metal.m.patch b/scripts/ggml-metal.m.patch index 368b05e..caeaa8a 100644 --- a/scripts/ggml-metal.m.patch +++ b/scripts/ggml-metal.m.patch @@ -1,6 +1,6 @@ ---- ggml-metal.m.orig 2023-12-12 10:46:04 -+++ ggml-metal.m 2023-12-12 10:46:43 -@@ -241,7 +241,7 @@ +--- ggml-metal.m.orig 2023-12-19 07:48:34 ++++ ggml-metal.m 2023-12-19 07:48:35 +@@ -265,7 +265,7 @@ if (ggmlMetalPathResources) { sourcePath = [ggmlMetalPathResources stringByAppendingPathComponent:@"ggml-metal.metal"]; } else { diff --git a/scripts/llama.cpp.patch b/scripts/llama.cpp.patch index f88e643..36d3c5d 100644 --- a/scripts/llama.cpp.patch +++ b/scripts/llama.cpp.patch @@ -1,6 +1,6 @@ ---- llama.cpp.orig 2023-12-12 10:46:04 -+++ llama.cpp 2023-12-12 10:46:05 -@@ -105,6 +105,17 @@ +--- llama.cpp.orig 2023-12-19 07:48:34 ++++ llama.cpp 2023-12-19 07:48:35 +@@ -106,6 +106,17 @@ #define LLAMA_LOG_WARN(...) llama_log_internal(LM_GGML_LOG_LEVEL_WARN , __VA_ARGS__) #define LLAMA_LOG_ERROR(...) llama_log_internal(LM_GGML_LOG_LEVEL_ERROR, __VA_ARGS__) @@ -18,7 +18,7 @@ // // helpers // -@@ -863,16 +874,16 @@ +@@ -895,16 +906,16 @@ if (prefetch > 0) { // Advise the kernel to preload the mapped memory
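One recurring change earlier in this patch is the switch from lm_ggml_type_sizef(type)*n to lm_ggml_row_size(type, n) when computing KV-cache view strides. For block-quantized types the per-element size is fractional, so the byte count has to be computed per whole block; lm_ggml_row_size does that with integer arithmetic instead of multiplying out a float "bytes per element". A rough illustration with made-up block parameters (the real block sizes and byte counts come from ggml's type traits):

    #include <cassert>
    #include <cstddef>
    #include <cstdio>

    // Hypothetical block-quantized type: 32 elements packed into 18 bytes per block.
    constexpr size_t kBlockSize  = 32;
    constexpr size_t kBlockBytes = 18;

    // Bytes occupied by n contiguous elements of this type -- the shape of
    // computation lm_ggml_row_size(type, n) performs.
    size_t row_size(size_t n_elements) {
        assert(n_elements % kBlockSize == 0);    // rows must be a whole number of blocks
        return (n_elements / kBlockSize) * kBlockBytes;
    }

    int main() {
        const size_t n_embd_gqa = 4096;
        // e.g. the stride of one KV-cache row, as used in the llm_build_k_shift views
        printf("row bytes for %zu elements: %zu\n", n_embd_gqa, row_size(n_embd_gqa));
        return 0;
    }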