From c717198ffa479a7aea147bc12470f64f7a30ec95 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?= Date: Thu, 12 Jun 2025 17:39:56 +0200 Subject: [PATCH 01/17] implement unary REGLU/GEGLU/SWIGLU cpu ops --- ggml/include/ggml.h | 15 ++ ggml/src/ggml-cpu/ggml-cpu.c | 3 + ggml/src/ggml-cpu/ops.cpp | 333 +++++++++++++++++++++++++++++++++++ ggml/src/ggml-cpu/vec.cpp | 24 +++ ggml/src/ggml-cpu/vec.h | 54 ++++++ ggml/src/ggml.c | 56 +++++- src/llama-graph.cpp | 29 +-- src/llama-graph.h | 1 + 8 files changed, 493 insertions(+), 22 deletions(-) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 1a57f1cd75a31..efbf7a84e91fe 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -537,6 +537,9 @@ extern "C" { GGML_UNARY_OP_HARDSIGMOID, GGML_UNARY_OP_EXP, GGML_UNARY_OP_GELU_ERF, + GGML_UNARY_OP_REGLU, + GGML_UNARY_OP_GEGLU, + GGML_UNARY_OP_SWIGLU, GGML_UNARY_OP_COUNT, }; @@ -1085,6 +1088,18 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a); + GGML_API struct ggml_tensor * ggml_reglu( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_geglu( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_swiglu( + struct ggml_context * ctx, + struct ggml_tensor * a); + // normalize along rows GGML_API struct ggml_tensor * ggml_norm( struct ggml_context * ctx, diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index ff28bf98bc7df..6a9a97e738902 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -2209,6 +2209,9 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_UNARY_OP_GELU_ERF: case GGML_UNARY_OP_GELU_QUICK: case GGML_UNARY_OP_SILU: + case GGML_UNARY_OP_REGLU: + case GGML_UNARY_OP_GEGLU: + case GGML_UNARY_OP_SWIGLU: { n_tasks = n_threads; } break; diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 08facb6d03d5e..f7cee3e04c766 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -3194,6 +3194,327 @@ void ggml_compute_forward_silu_back( } } +// ggml_compute_forward_reglu + +static void ggml_compute_forward_reglu_f32( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + + GGML_ASSERT(ggml_is_contiguous_1(src0)); + GGML_ASSERT(ggml_is_contiguous_1(dst)); + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = dst->ne[0]; + const int nr = ggml_nrows(src0); + + GGML_ASSERT(src0->ne[0] / 2 == nc); + GGML_ASSERT(ggml_nrows(dst) == nr); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + ggml_vec_reglu_f32(nc, + (float *) ((char *) dst->data + i1*( dst->nb[1])), + (float *) ((char *) src0->data + i1*(src0->nb[1]))); + +#ifndef NDEBUG + for (int k = 0; k < nc; k++) { + const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + GGML_UNUSED(x); + assert(!isnan(x)); + assert(!isinf(x)); + } +#endif + } +} + +static void ggml_compute_forward_reglu_f16( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + + GGML_ASSERT(ggml_is_contiguous_1(src0)); + GGML_ASSERT(ggml_is_contiguous_1(dst)); + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = dst->ne[0]; + const int nr = ggml_nrows(src0); + + GGML_ASSERT(src0->ne[0] / 2 == nc); + 
GGML_ASSERT(ggml_nrows(dst) == nr); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + ggml_vec_reglu_f16(nc, + (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])), + (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1]))); + +#ifndef NDEBUG + for (int k = 0; k < nc; k++) { + const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + const float v = GGML_FP16_TO_FP32(x); + GGML_UNUSED(v); + assert(!isnan(v)); + assert(!isinf(v)); + } +#endif + } +} + +static void ggml_compute_forward_reglu( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_reglu_f32(params, dst); + } break; + case GGML_TYPE_F16: + { + ggml_compute_forward_reglu_f16(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +// ggml_compute_forward_geglu + +static void ggml_compute_forward_geglu_f32( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + + GGML_ASSERT(ggml_is_contiguous_1(src0)); + GGML_ASSERT(ggml_is_contiguous_1(dst)); + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = dst->ne[0]; + const int nr = ggml_nrows(src0); + + GGML_ASSERT(src0->ne[0] / 2 == nc); + GGML_ASSERT(ggml_nrows(dst) == nr); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + ggml_vec_geglu_f32(nc, + (float *) ((char *) dst->data + i1*( dst->nb[1])), + (float *) ((char *) src0->data + i1*(src0->nb[1]))); + +#ifndef NDEBUG + for (int k = 0; k < nc; k++) { + const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + GGML_UNUSED(x); + assert(!isnan(x)); + assert(!isinf(x)); + } +#endif + } +} + +static void ggml_compute_forward_geglu_f16( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + + GGML_ASSERT(ggml_is_contiguous_1(src0)); + GGML_ASSERT(ggml_is_contiguous_1(dst)); + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = dst->ne[0]; + const int nr = ggml_nrows(src0); + + GGML_ASSERT(src0->ne[0] / 2 == nc); + GGML_ASSERT(ggml_nrows(dst) == nr); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + ggml_vec_geglu_f16(nc, + (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])), + (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1]))); + +#ifndef NDEBUG + for (int k = 0; k < nc; k++) { + const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + const float v = GGML_FP16_TO_FP32(x); + GGML_UNUSED(v); + assert(!isnan(v)); + assert(!isinf(v)); + } +#endif + } +} + +static void ggml_compute_forward_geglu( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_geglu_f32(params, dst); + } break; + case GGML_TYPE_F16: + { + ggml_compute_forward_geglu_f16(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +// ggml_compute_forward_swiglu + +static void 
ggml_compute_forward_swiglu_f32( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + + GGML_ASSERT(ggml_is_contiguous_1(src0)); + GGML_ASSERT(ggml_is_contiguous_1(dst)); + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = dst->ne[0]; + const int nr = ggml_nrows(src0); + + GGML_ASSERT(src0->ne[0] / 2 == nc); + GGML_ASSERT(ggml_nrows(dst) == nr); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + ggml_vec_swiglu_f32(nc, + (float *) ((char *) dst->data + i1*( dst->nb[1])), + (float *) ((char *) src0->data + i1*(src0->nb[1]))); + +#ifndef NDEBUG + for (int k = 0; k < nc; k++) { + const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + GGML_UNUSED(x); + assert(!isnan(x)); + assert(!isinf(x)); + } +#endif + } +} + +static void ggml_compute_forward_swiglu_f16( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + + GGML_ASSERT(ggml_is_contiguous_1(src0)); + GGML_ASSERT(ggml_is_contiguous_1(dst)); + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = dst->ne[0]; + const int nr = ggml_nrows(src0); + + GGML_ASSERT(src0->ne[0] / 2 == nc); + GGML_ASSERT(ggml_nrows(dst) == nr); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + ggml_vec_swiglu_f16(nc, + (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])), + (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1]))); + +#ifndef NDEBUG + for (int k = 0; k < nc; k++) { + const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + const float v = GGML_FP16_TO_FP32(x); + GGML_UNUSED(v); + assert(!isnan(v)); + assert(!isinf(v)); + } +#endif + } +} + +static void ggml_compute_forward_swiglu( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_swiglu_f32(params, dst); + } break; + case GGML_TYPE_F16: + { + ggml_compute_forward_swiglu_f16(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + // ggml_compute_forward_norm static void ggml_compute_forward_norm_f32( @@ -7920,6 +8241,18 @@ void ggml_compute_forward_unary( { ggml_compute_forward_exp(params, dst); } break; + case GGML_UNARY_OP_REGLU: + { + ggml_compute_forward_reglu(params, dst); + } break; + case GGML_UNARY_OP_GEGLU: + { + ggml_compute_forward_geglu(params, dst); + } break; + case GGML_UNARY_OP_SWIGLU: + { + ggml_compute_forward_swiglu(params, dst); + } break; default: { GGML_ABORT("fatal error"); diff --git a/ggml/src/ggml-cpu/vec.cpp b/ggml/src/ggml-cpu/vec.cpp index f7614568ea388..bfb2d5d361512 100644 --- a/ggml/src/ggml-cpu/vec.cpp +++ b/ggml/src/ggml-cpu/vec.cpp @@ -254,6 +254,30 @@ void ggml_vec_silu_f32(const int n, float * y, const float * x) { } } +void ggml_vec_swiglu_f32(const int n, float * y, const float * x) { + int i = 0; +#if defined(__AVX512F__) && defined(__AVX512DQ__) + for (; i + 15 < n; i += 16) { + _mm512_storeu_ps(y + i, _mm512_mul_ps(ggml_v_silu(_mm512_loadu_ps(x + i)), _mm512_loadu_ps(x + i + n))); + } +#elif defined(__AVX2__) && defined(__FMA__) + for (; i + 7 < n; i += 8) { + _mm256_storeu_ps(y + i, 
_mm256_mul_ps(ggml_v_silu(_mm256_loadu_ps(x + i)), _mm256_loadu_ps(x + i + n))); + } +#elif defined(__SSE2__) + for (; i + 3 < n; i += 4) { + _mm_storeu_ps(y + i, _mm_mul_ps(ggml_v_silu(_mm_loadu_ps(x + i)), _mm_loadu_ps(x + i + n))); + } +#elif defined(__ARM_NEON) && defined(__aarch64__) + for (; i + 3 < n; i += 4) { + vst1q_f32(y + i, vmulq_f32(ggml_v_silu(vld1q_f32(x + i)), vld1q_f32(x + i + n))); + } +#endif + for (; i < n; ++i) { + y[i] = ggml_silu_f32(x[i]) * x[i + n]; + } +} + ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max) { int i = 0; ggml_float sum = 0; diff --git a/ggml/src/ggml-cpu/vec.h b/ggml/src/ggml-cpu/vec.h index 09dbade2179fb..48d13a60d0563 100644 --- a/ggml/src/ggml-cpu/vec.h +++ b/ggml/src/ggml-cpu/vec.h @@ -905,6 +905,60 @@ inline static void ggml_vec_silu_backward_f16(const int n, ggml_fp16_t * dx, con } } +inline static void ggml_vec_reglu_f32 (const int n, float * y, const float * x) { + for (int i = 0; i < n; ++i) { + y[i] = (x[i] > 0.f) ? x[i] * x[i + n] : 0.f; + } +} + +inline static void ggml_vec_reglu_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { + for (int i = 0; i < n; ++i) { + float v = GGML_FP16_TO_FP32(x[i]); + y[i] = GGML_FP32_TO_FP16((v > 0.f) ? v * GGML_FP16_TO_FP32(x[i + n]) : 0.f); + } +} + +#ifdef GGML_GELU_FP16 +inline static void ggml_vec_geglu_f32(const int n, float * y, const float * x) { + uint16_t t; + for (int i = 0; i < n; ++i) { + if (x[i] <= -10.0f) { + y[i] = 0.0f; + } else if (x[i] >= 10.0f) { + y[i] = x[i] * x[i + n]; + } else { + ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]); + memcpy(&t, &fp16, sizeof(uint16_t)); + y[i] = GGML_FP16_TO_FP32(ggml_table_gelu_f16[t]) * x[i + n]; + } + } +} +#else +inline static void ggml_vec_geglu_f32(const int n, float * y, const float * x) { + for (int i = 0; i < n; ++i) { + y[i] = ggml_gelu_f32(x[i]) * x[i + n]; + } +} +#endif + +inline static void ggml_vec_geglu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { + const uint16_t * i16 = (const uint16_t *) x; + for (int i = 0; i < n; ++i) { + float g = GGML_FP16_TO_FP32(x[i + n]); + y[i] = ggml_table_gelu_f16[i16[i]] * g; + } +} + +void ggml_vec_swiglu_f32(const int n, float * y, const float * x); + +inline static void ggml_vec_swiglu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { + for (int i = 0; i < n; ++i) { + float v = GGML_FP16_TO_FP32(x[i]); + float g = GGML_FP16_TO_FP32(x[i + n]); + y[i] = GGML_FP32_TO_FP16((v/(1.0f + expf(-v))) * g); + } +} + inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) { #ifndef GGML_USE_ACCELERATE ggml_float sum = 0.0; diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 196b7b8f3e2ae..d79e33d0e0afd 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -1107,9 +1107,12 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = { "HARDSIGMOID", "EXP", "GELU_ERF", + "REGLU", + "GEGLU", + "SWIGLU", }; -static_assert(GGML_UNARY_OP_COUNT == 15, "GGML_UNARY_OP_COUNT != 15"); +static_assert(GGML_UNARY_OP_COUNT == 18, "GGML_UNARY_OP_COUNT != 18"); static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN"); @@ -2616,6 +2619,57 @@ struct ggml_tensor * ggml_exp_inplace( return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_EXP); } +// ggml_reglu + +struct ggml_tensor * ggml_reglu( + struct ggml_context * ctx, + struct ggml_tensor * a) { + GGML_ASSERT(ggml_is_contiguous_1(a)); + + struct ggml_tensor * result = ggml_new_tensor_2d(ctx, a->type, a->ne[0] / 2, a->ne[1]); + + 
ggml_set_op_params_i32(result, 0, (int32_t) GGML_UNARY_OP_REGLU); + + result->op = GGML_OP_UNARY; + result->src[0] = a; + + return result; +} + +// ggml_geglu + +struct ggml_tensor * ggml_geglu( + struct ggml_context * ctx, + struct ggml_tensor * a) { + GGML_ASSERT(ggml_is_contiguous_1(a)); + + struct ggml_tensor * result = ggml_new_tensor_2d(ctx, a->type, a->ne[0] / 2, a->ne[1]); + + ggml_set_op_params_i32(result, 0, (int32_t) GGML_UNARY_OP_GEGLU); + + result->op = GGML_OP_UNARY; + result->src[0] = a; + + return result; +} + +// ggml_swiglu + +struct ggml_tensor * ggml_swiglu( + struct ggml_context * ctx, + struct ggml_tensor * a) { + GGML_ASSERT(ggml_is_contiguous_1(a)); + + struct ggml_tensor * result = ggml_new_tensor_2d(ctx, a->type, a->ne[0] / 2, a->ne[1]); + + ggml_set_op_params_i32(result, 0, (int32_t) GGML_UNARY_OP_SWIGLU); + + result->op = GGML_OP_UNARY; + result->src[0] = a; + + return result; +} + // ggml_norm static struct ggml_tensor * ggml_norm_impl( diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index e74c9ff53b05a..75420f277d92c 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -632,32 +632,19 @@ ggml_tensor * llm_graph_context::build_ffn( } break; case LLM_FFN_SWIGLU: { - // Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf - int64_t split_point = cur->ne[0] / 2; - // TODO: these conts should not be needed, see https://github.com/ggml-org/llama.cpp/pull/14090#discussion_r2137437217 - ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0)); - ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur))); - - x0 = ggml_silu(ctx0, x0); - cb(cur, "ffn_silu", il); - - cur = ggml_mul(ctx0, x0, x1); - cb(cur, "ffn_mul", il); + cur = ggml_swiglu(ctx0, cur); + cb(cur, "ffn_swiglu", il); } break; case LLM_FFN_GEGLU: { - // Split into two equal parts - int64_t split_point = cur->ne[0] / 2; - // TODO: these conts should not be needed, see https://github.com/ggml-org/llama.cpp/pull/14090#discussion_r2137437217 - ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0)); - ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur))); - - x0 = ggml_gelu(ctx0, x0); - cb(x0, "ffn_gelu", il); - - cur = ggml_mul(ctx0, x0, x1); + cur = ggml_geglu(ctx0, cur); cb(cur, "ffn_geglu", il); } break; + case LLM_FFN_REGLU: + { + cur = ggml_reglu(ctx0, cur); + cb(cur, "ffn_reglu", il); + } break; } if (gate && type_gate == LLM_FFN_PAR) { diff --git a/src/llama-graph.h b/src/llama-graph.h index 88fb77f1ddc9a..e8725a917610b 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -37,6 +37,7 @@ enum llm_ffn_op_type { LLM_FFN_RELU_SQR, LLM_FFN_SWIGLU, LLM_FFN_GEGLU, + LLM_FFN_REGLU, }; enum llm_ffn_gate_type { From 92943e7e5a50cc556a57c2ce92429c0221f9267b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?= Date: Thu, 12 Jun 2025 23:05:51 +0200 Subject: [PATCH 02/17] relax constraints --- ggml/src/ggml-cpu/ops.cpp | 36 ++++++++++++++++++------------------ 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index f7cee3e04c766..af60f2d58dd62 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -3208,11 +3208,11 @@ static void ggml_compute_forward_reglu_f32( const int ith = params->ith; const int nth = 
params->nth; - const int nc = dst->ne[0]; + const int nc = src0->ne[0] / 2; const int nr = ggml_nrows(src0); - GGML_ASSERT(src0->ne[0] / 2 == nc); - GGML_ASSERT(ggml_nrows(dst) == nr); + GGML_ASSERT(dst->ne[0] >= nc); + GGML_ASSERT(ggml_nrows(dst) >= nr); // rows per thread const int dr = (nr + nth - 1)/nth; @@ -3249,11 +3249,11 @@ static void ggml_compute_forward_reglu_f16( const int ith = params->ith; const int nth = params->nth; - const int nc = dst->ne[0]; + const int nc = src0->ne[0] / 2; const int nr = ggml_nrows(src0); - GGML_ASSERT(src0->ne[0] / 2 == nc); - GGML_ASSERT(ggml_nrows(dst) == nr); + GGML_ASSERT(dst->ne[0] >= nc); + GGML_ASSERT(ggml_nrows(dst) >= nr); // rows per thread const int dr = (nr + nth - 1)/nth; @@ -3315,11 +3315,11 @@ static void ggml_compute_forward_geglu_f32( const int ith = params->ith; const int nth = params->nth; - const int nc = dst->ne[0]; + const int nc = src0->ne[0] / 2; const int nr = ggml_nrows(src0); - GGML_ASSERT(src0->ne[0] / 2 == nc); - GGML_ASSERT(ggml_nrows(dst) == nr); + GGML_ASSERT(dst->ne[0] >= nc); + GGML_ASSERT(ggml_nrows(dst) >= nr); // rows per thread const int dr = (nr + nth - 1)/nth; @@ -3356,11 +3356,11 @@ static void ggml_compute_forward_geglu_f16( const int ith = params->ith; const int nth = params->nth; - const int nc = dst->ne[0]; + const int nc = src0->ne[0] / 2; const int nr = ggml_nrows(src0); - GGML_ASSERT(src0->ne[0] / 2 == nc); - GGML_ASSERT(ggml_nrows(dst) == nr); + GGML_ASSERT(dst->ne[0] >= nc); + GGML_ASSERT(ggml_nrows(dst) >= nr); // rows per thread const int dr = (nr + nth - 1)/nth; @@ -3422,11 +3422,11 @@ static void ggml_compute_forward_swiglu_f32( const int ith = params->ith; const int nth = params->nth; - const int nc = dst->ne[0]; + const int nc = src0->ne[0] / 2; const int nr = ggml_nrows(src0); - GGML_ASSERT(src0->ne[0] / 2 == nc); - GGML_ASSERT(ggml_nrows(dst) == nr); + GGML_ASSERT(dst->ne[0] >= nc); + GGML_ASSERT(ggml_nrows(dst) >= nr); // rows per thread const int dr = (nr + nth - 1)/nth; @@ -3463,11 +3463,11 @@ static void ggml_compute_forward_swiglu_f16( const int ith = params->ith; const int nth = params->nth; - const int nc = dst->ne[0]; + const int nc = src0->ne[0] / 2; const int nr = ggml_nrows(src0); - GGML_ASSERT(src0->ne[0] / 2 == nc); - GGML_ASSERT(ggml_nrows(dst) == nr); + GGML_ASSERT(dst->ne[0] >= nc); + GGML_ASSERT(ggml_nrows(dst) >= nr); // rows per thread const int dr = (nr + nth - 1)/nth; From 319c6cbe150c9639e94999c1c9be20491ce23ef1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?= Date: Fri, 13 Jun 2025 00:51:53 +0200 Subject: [PATCH 03/17] duplicate shape of source --- ggml/src/ggml.c | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index d79e33d0e0afd..db83cce93ba68 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -2626,7 +2626,8 @@ struct ggml_tensor * ggml_reglu( struct ggml_tensor * a) { GGML_ASSERT(ggml_is_contiguous_1(a)); - struct ggml_tensor * result = ggml_new_tensor_2d(ctx, a->type, a->ne[0] / 2, a->ne[1]); + int64_t ne[GGML_MAX_DIMS] = { a->ne[0] / 2 }; for (int i = 1; i < GGML_MAX_DIMS; i++) ne[i] = a->ne[i]; + struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, GGML_MAX_DIMS, ne, NULL, 0); ggml_set_op_params_i32(result, 0, (int32_t) GGML_UNARY_OP_REGLU); @@ -2643,7 +2644,8 @@ struct ggml_tensor * ggml_geglu( struct ggml_tensor * a) { GGML_ASSERT(ggml_is_contiguous_1(a)); - struct ggml_tensor * result = ggml_new_tensor_2d(ctx, a->type, a->ne[0] / 2, a->ne[1]); + int64_t 
ne[GGML_MAX_DIMS] = { a->ne[0] / 2 }; for (int i = 1; i < GGML_MAX_DIMS; i++) ne[i] = a->ne[i]; + struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, GGML_MAX_DIMS, ne, NULL, 0); ggml_set_op_params_i32(result, 0, (int32_t) GGML_UNARY_OP_GEGLU); @@ -2660,7 +2662,8 @@ struct ggml_tensor * ggml_swiglu( struct ggml_tensor * a) { GGML_ASSERT(ggml_is_contiguous_1(a)); - struct ggml_tensor * result = ggml_new_tensor_2d(ctx, a->type, a->ne[0] / 2, a->ne[1]); + int64_t ne[GGML_MAX_DIMS] = { a->ne[0] / 2 }; for (int i = 1; i < GGML_MAX_DIMS; i++) ne[i] = a->ne[i]; + struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, GGML_MAX_DIMS, ne, NULL, 0); ggml_set_op_params_i32(result, 0, (int32_t) GGML_UNARY_OP_SWIGLU); From 6fe7e070a7290a2bf329faa146c0e470687b3955 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?= Date: Fri, 13 Jun 2025 01:04:59 +0200 Subject: [PATCH 04/17] fix ggml_vec_geglu_f16 --- ggml/src/ggml-cpu/vec.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-cpu/vec.h b/ggml/src/ggml-cpu/vec.h index 48d13a60d0563..178629e994216 100644 --- a/ggml/src/ggml-cpu/vec.h +++ b/ggml/src/ggml-cpu/vec.h @@ -945,7 +945,7 @@ inline static void ggml_vec_geglu_f16(const int n, ggml_fp16_t * y, const ggml_f const uint16_t * i16 = (const uint16_t *) x; for (int i = 0; i < n; ++i) { float g = GGML_FP16_TO_FP32(x[i + n]); - y[i] = ggml_table_gelu_f16[i16[i]] * g; + y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(ggml_table_gelu_f16[i16[i]]) * g); } } From 7e075bea610aa32d3a4aa5a3c687e4a5d685c7e7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?= Date: Fri, 13 Jun 2025 01:07:49 +0200 Subject: [PATCH 05/17] special case gated ops --- tests/test-backend-ops.cpp | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 509a4b35f57cb..1b3d4440677ab 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1072,7 +1072,16 @@ struct test_unary : public test_case { ggml_set_name(a, "a"); } - ggml_tensor * out = ggml_unary(ctx, a, op); + ggml_tensor * out; + if (op == GGML_UNARY_OP_REGLU) { + out = ggml_reglu(ctx, a); + } else if (op == GGML_UNARY_OP_GEGLU) { + out = ggml_geglu(ctx, a); + } else if (op == GGML_UNARY_OP_SWIGLU) { + out = ggml_swiglu(ctx, a); + } else { + out = ggml_unary(ctx, a, op); + } ggml_set_name(out, "out"); return out; From 1acd12111de3d7d6475cbb23644520cde87c0d0d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?= Date: Fri, 13 Jun 2025 01:11:57 +0200 Subject: [PATCH 06/17] implement unary REGLU/GEGLU/SWIGLU cuda ops --- ggml/src/ggml-cuda/ggml-cuda.cu | 13 ++++++++ ggml/src/ggml-cuda/unary.cu | 56 +++++++++++++++++++++++++++++++++ ggml/src/ggml-cuda/unary.cuh | 6 ++++ 3 files changed, 75 insertions(+) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 0bd2904e1c9d1..8187ddb11101c 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2216,6 +2216,15 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_UNARY_OP_EXP: ggml_cuda_op_exp(ctx, dst); break; + case GGML_UNARY_OP_REGLU: + ggml_cuda_op_reglu(ctx, dst); + break; + case GGML_UNARY_OP_GEGLU: + ggml_cuda_op_geglu(ctx, dst); + break; + case GGML_UNARY_OP_SWIGLU: + ggml_cuda_op_swiglu(ctx, dst); + break; default: return false; } @@ -2987,6 +2996,10 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const 
g
         case GGML_UNARY_OP_TANH:
         case GGML_UNARY_OP_EXP:
             return ggml_is_contiguous(op->src[0]);
+        case GGML_UNARY_OP_REGLU:
+        case GGML_UNARY_OP_GEGLU:
+        case GGML_UNARY_OP_SWIGLU:
+            return ggml_is_contiguous_1(op->src[0]);
         default:
             return false;
     }
diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu
index 2c0375fbe3cf6..c98564a31f6a7 100644
--- a/ggml/src/ggml-cuda/unary.cu
+++ b/ggml/src/ggml-cuda/unary.cu
@@ -196,6 +196,62 @@ void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     ggml_cuda_op_unary<op_log>(ctx, dst);
 }
 
+/* gated ops */
+
+template <float (*op)(float), typename T>
+static __global__ void unary_gated_op_kernel(const T * x, T * dst, const int k, const int n, const int o) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+
+    // perform base op on first half of row and multiply with gate in second half
+    const int j = (i / n) * o + (i % n);
+    dst[i] = (T)(op((float)x[j]) * (float)x[j + n]);
+}
+
+template <float (*op)(float), typename T>
+static void unary_gated_cuda(const T * x, T * dst, const int k, const int n, const int o, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_NEG_BLOCK_SIZE - 1) / CUDA_NEG_BLOCK_SIZE;
+    unary_gated_op_kernel<op><<<num_blocks, CUDA_NEG_BLOCK_SIZE, 0, stream>>>(x, dst, k, n, o);
+}
+
+template <float (*op)(float)>
+void ggml_cuda_op_unary_gated(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const void * src0_d = src0->data;
+    void * dst_d = dst->data;
+    const int nc = src0->ne[0] / 2;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
+    GGML_ASSERT(src0->type == dst->type);
+    GGML_ASSERT(dst->ne[0] >= nc);
+    GGML_ASSERT(ggml_nrows(dst) >= ggml_nrows(src0));
+
+    if (src0->type == GGML_TYPE_F16) {
+        unary_gated_cuda<op>((const half *)src0_d, (half *)dst_d, ggml_nelements(dst), nc, src0->nb[1] / sizeof(half), stream);
+    } else {
+        unary_gated_cuda<op>((const float *)src0_d, (float *)dst_d, ggml_nelements(dst), nc, src0->nb[1] / sizeof(float), stream);
+    }
+}
+
+void ggml_cuda_op_reglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_unary_gated<op_relu>(ctx, dst);
+}
+
+void ggml_cuda_op_geglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_unary_gated<op_gelu>(ctx, dst);
+}
+
+void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_unary_gated<op_silu>(ctx, dst);
+}
+
 /* silu_back */
 
 static __device__ __forceinline__ float op_silu_back(float grad, float x) {
diff --git a/ggml/src/ggml-cuda/unary.cuh b/ggml/src/ggml-cuda/unary.cuh
index 6686fc17e9193..d4533d24e25bc 100644
--- a/ggml/src/ggml-cuda/unary.cuh
+++ b/ggml/src/ggml-cuda/unary.cuh
@@ -57,3 +57,9 @@ void ggml_cuda_op_sin(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_reglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_geglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

From 5c5819630b0f49df453936b20cdf57206779c168 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?=
Date: Fri, 13 Jun 2025 09:00:30 +0200
Subject: [PATCH 07/17] tighten constraints again

---
 ggml/src/ggml-cpu/ops.cpp   | 24 ++++++++++++------------
 ggml/src/ggml-cuda/unary.cu |  4 ++--
 2 files changed, 14 
insertions(+), 14 deletions(-) diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index af60f2d58dd62..53096cccc5c55 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -3211,8 +3211,8 @@ static void ggml_compute_forward_reglu_f32( const int nc = src0->ne[0] / 2; const int nr = ggml_nrows(src0); - GGML_ASSERT(dst->ne[0] >= nc); - GGML_ASSERT(ggml_nrows(dst) >= nr); + GGML_ASSERT(dst->ne[0] == nc); + GGML_ASSERT(ggml_nrows(dst) == nr); // rows per thread const int dr = (nr + nth - 1)/nth; @@ -3252,8 +3252,8 @@ static void ggml_compute_forward_reglu_f16( const int nc = src0->ne[0] / 2; const int nr = ggml_nrows(src0); - GGML_ASSERT(dst->ne[0] >= nc); - GGML_ASSERT(ggml_nrows(dst) >= nr); + GGML_ASSERT(dst->ne[0] == nc); + GGML_ASSERT(ggml_nrows(dst) == nr); // rows per thread const int dr = (nr + nth - 1)/nth; @@ -3318,8 +3318,8 @@ static void ggml_compute_forward_geglu_f32( const int nc = src0->ne[0] / 2; const int nr = ggml_nrows(src0); - GGML_ASSERT(dst->ne[0] >= nc); - GGML_ASSERT(ggml_nrows(dst) >= nr); + GGML_ASSERT(dst->ne[0] == nc); + GGML_ASSERT(ggml_nrows(dst) == nr); // rows per thread const int dr = (nr + nth - 1)/nth; @@ -3359,8 +3359,8 @@ static void ggml_compute_forward_geglu_f16( const int nc = src0->ne[0] / 2; const int nr = ggml_nrows(src0); - GGML_ASSERT(dst->ne[0] >= nc); - GGML_ASSERT(ggml_nrows(dst) >= nr); + GGML_ASSERT(dst->ne[0] == nc); + GGML_ASSERT(ggml_nrows(dst) == nr); // rows per thread const int dr = (nr + nth - 1)/nth; @@ -3425,8 +3425,8 @@ static void ggml_compute_forward_swiglu_f32( const int nc = src0->ne[0] / 2; const int nr = ggml_nrows(src0); - GGML_ASSERT(dst->ne[0] >= nc); - GGML_ASSERT(ggml_nrows(dst) >= nr); + GGML_ASSERT(dst->ne[0] == nc); + GGML_ASSERT(ggml_nrows(dst) == nr); // rows per thread const int dr = (nr + nth - 1)/nth; @@ -3466,8 +3466,8 @@ static void ggml_compute_forward_swiglu_f16( const int nc = src0->ne[0] / 2; const int nr = ggml_nrows(src0); - GGML_ASSERT(dst->ne[0] >= nc); - GGML_ASSERT(ggml_nrows(dst) >= nr); + GGML_ASSERT(dst->ne[0] == nc); + GGML_ASSERT(ggml_nrows(dst) == nr); // rows per thread const int dr = (nr + nth - 1)/nth; diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu index c98564a31f6a7..77ef8154578b5 100644 --- a/ggml/src/ggml-cuda/unary.cu +++ b/ggml/src/ggml-cuda/unary.cu @@ -230,8 +230,8 @@ void ggml_cuda_op_unary_gated(ggml_backend_cuda_context & ctx, ggml_tensor * dst GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); GGML_ASSERT(src0->type == dst->type); - GGML_ASSERT(dst->ne[0] >= nc); - GGML_ASSERT(ggml_nrows(dst) >= ggml_nrows(src0)); + GGML_ASSERT(dst->ne[0] == nc); + GGML_ASSERT(ggml_nrows(dst) == ggml_nrows(src0)); if (src0->type == GGML_TYPE_F16) { unary_gated_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(dst), nc, src0->nb[1] / sizeof(half), stream); From f4be71e3c049f822b3030c458b2462055e439393 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?= Date: Fri, 13 Jun 2025 10:14:32 +0200 Subject: [PATCH 08/17] refactor into GGML_GLU_OP --- ggml/include/ggml.h | 23 +++++++++-- ggml/src/ggml-cpu/ggml-cpu.c | 19 +++++++-- ggml/src/ggml-cpu/ops.cpp | 22 ++++++++-- ggml/src/ggml-cpu/ops.h | 1 + ggml/src/ggml-cuda/ggml-cuda.cu | 24 ++++++++--- ggml/src/ggml.c | 72 +++++++++++++++++++-------------- tests/test-backend-ops.cpp | 66 +++++++++++++++++++++++++----- 7 files changed, 172 insertions(+), 55 deletions(-) diff --git 
a/ggml/include/ggml.h b/ggml/include/ggml.h index efbf7a84e91fe..7e30cb931fd84 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -518,6 +518,8 @@ extern "C" { GGML_OP_CROSS_ENTROPY_LOSS_BACK, GGML_OP_OPT_STEP_ADAMW, + GGML_OP_GLU, + GGML_OP_COUNT, }; @@ -537,13 +539,18 @@ extern "C" { GGML_UNARY_OP_HARDSIGMOID, GGML_UNARY_OP_EXP, GGML_UNARY_OP_GELU_ERF, - GGML_UNARY_OP_REGLU, - GGML_UNARY_OP_GEGLU, - GGML_UNARY_OP_SWIGLU, GGML_UNARY_OP_COUNT, }; + enum ggml_glu_op { + GGML_GLU_OP_REGLU, + GGML_GLU_OP_GEGLU, + GGML_GLU_OP_SWIGLU, + + GGML_GLU_OP_COUNT, + }; + enum ggml_object_type { GGML_OBJECT_TYPE_TENSOR, GGML_OBJECT_TYPE_GRAPH, @@ -659,6 +666,7 @@ extern "C" { GGML_API const char * ggml_op_symbol(enum ggml_op op); GGML_API const char * ggml_unary_op_name(enum ggml_unary_op op); + GGML_API const char * ggml_glu_op_name(enum ggml_glu_op op); GGML_API const char * ggml_op_desc(const struct ggml_tensor * t); // unary or op name GGML_API size_t ggml_element_size(const struct ggml_tensor * tensor); @@ -760,6 +768,7 @@ extern "C" { GGML_API void ggml_unravel_index(const struct ggml_tensor * tensor, int64_t i, int64_t * i0, int64_t * i1, int64_t * i2, int64_t * i3); GGML_API enum ggml_unary_op ggml_get_unary_op(const struct ggml_tensor * tensor); + GGML_API enum ggml_glu_op ggml_get_glu_op(const struct ggml_tensor * tensor); GGML_API void * ggml_get_data (const struct ggml_tensor * tensor); GGML_API float * ggml_get_data_f32(const struct ggml_tensor * tensor); @@ -1088,6 +1097,14 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a); + // gated linear unit ops + // A: n columns, r rows, + // result is n / 2 columns, r rows, + GGML_API struct ggml_tensor * ggml_glu( + struct ggml_context * ctx, + struct ggml_tensor * a, + enum ggml_glu_op op); + GGML_API struct ggml_tensor * ggml_reglu( struct ggml_context * ctx, struct ggml_tensor * a); diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 6a9a97e738902..e8de7fb2895d7 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -2006,6 +2006,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_unary(params, tensor); } break; + case GGML_OP_GLU: + { + ggml_compute_forward_glu(params, tensor); + } break; case GGML_OP_GET_REL_POS: { ggml_compute_forward_get_rel_pos(params, tensor); @@ -2209,9 +2213,18 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_UNARY_OP_GELU_ERF: case GGML_UNARY_OP_GELU_QUICK: case GGML_UNARY_OP_SILU: - case GGML_UNARY_OP_REGLU: - case GGML_UNARY_OP_GEGLU: - case GGML_UNARY_OP_SWIGLU: + { + n_tasks = n_threads; + } break; + default: + GGML_ABORT("fatal error"); + } + break; + case GGML_OP_GLU: + switch (ggml_get_glu_op(node)) { + case GGML_GLU_OP_REGLU: + case GGML_GLU_OP_GEGLU: + case GGML_GLU_OP_SWIGLU: { n_tasks = n_threads; } break; diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 53096cccc5c55..e9a1d82e38deb 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -8241,15 +8241,31 @@ void ggml_compute_forward_unary( { ggml_compute_forward_exp(params, dst); } break; - case GGML_UNARY_OP_REGLU: + default: + { + GGML_ABORT("fatal error"); + } + } +} + +//ggml_compute_forward_glu + +void ggml_compute_forward_glu( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_glu_op op = ggml_get_glu_op(dst); + + switch (op) { + case GGML_GLU_OP_REGLU: { ggml_compute_forward_reglu(params, dst); } break; - case 
GGML_UNARY_OP_GEGLU: + case GGML_GLU_OP_GEGLU: { ggml_compute_forward_geglu(params, dst); } break; - case GGML_UNARY_OP_SWIGLU: + case GGML_GLU_OP_SWIGLU: { ggml_compute_forward_swiglu(params, dst); } break; diff --git a/ggml/src/ggml-cpu/ops.h b/ggml/src/ggml-cpu/ops.h index dc081b9e66397..bd75bc43cd4cb 100644 --- a/ggml/src/ggml-cpu/ops.h +++ b/ggml/src/ggml-cpu/ops.h @@ -92,6 +92,7 @@ void ggml_compute_forward_ssm_scan(const struct ggml_compute_params * params, st void ggml_compute_forward_win_part(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_win_unpart(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_unary(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_glu(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_get_rel_pos(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_add_rel_pos(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_rwkv_wkv6(const struct ggml_compute_params * params, struct ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 8187ddb11101c..8c3baf0c31a9d 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2216,13 +2216,19 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_UNARY_OP_EXP: ggml_cuda_op_exp(ctx, dst); break; - case GGML_UNARY_OP_REGLU: + default: + return false; + } + break; + case GGML_OP_GLU: + switch (ggml_get_glu_op(dst)) { + case GGML_GLU_OP_REGLU: ggml_cuda_op_reglu(ctx, dst); break; - case GGML_UNARY_OP_GEGLU: + case GGML_GLU_OP_GEGLU: ggml_cuda_op_geglu(ctx, dst); break; - case GGML_UNARY_OP_SWIGLU: + case GGML_GLU_OP_SWIGLU: ggml_cuda_op_swiglu(ctx, dst); break; default: @@ -2996,9 +3002,15 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_UNARY_OP_TANH: case GGML_UNARY_OP_EXP: return ggml_is_contiguous(op->src[0]); - case GGML_UNARY_OP_REGLU: - case GGML_UNARY_OP_GEGLU: - case GGML_UNARY_OP_SWIGLU: + default: + return false; + } + break; + case GGML_OP_GLU: + switch (ggml_get_glu_op(op)) { + case GGML_GLU_OP_REGLU: + case GGML_GLU_OP_GEGLU: + case GGML_GLU_OP_SWIGLU: return ggml_is_contiguous_1(op->src[0]); default: return false; diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index db83cce93ba68..83f5108457699 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -989,9 +989,11 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "CROSS_ENTROPY_LOSS", "CROSS_ENTROPY_LOSS_BACK", "OPT_STEP_ADAMW", + + "GLU", }; -static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82"); +static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -1084,9 +1086,11 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "cross_entropy_loss(x,y)", "cross_entropy_loss_back(x,y)", "adamw(x)", + + "glu(x)", }; -static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82"); +static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -1107,12 +1111,18 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = { "HARDSIGMOID", "EXP", "GELU_ERF", +}; + +static_assert(GGML_UNARY_OP_COUNT == 15, "GGML_UNARY_OP_COUNT != 15"); + + +static const char * GGML_GLU_OP_NAME[GGML_GLU_OP_COUNT] = { 
"REGLU", "GEGLU", "SWIGLU", }; -static_assert(GGML_UNARY_OP_COUNT == 18, "GGML_UNARY_OP_COUNT != 18"); +static_assert(GGML_GLU_OP_COUNT == 3, "GGML_GLU_OP_COUNT != 3"); static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN"); @@ -1217,11 +1227,19 @@ const char * ggml_unary_op_name(enum ggml_unary_op op) { return GGML_UNARY_OP_NAME[op]; } +const char * ggml_glu_op_name(enum ggml_glu_op op) { + return GGML_GLU_OP_NAME[op]; +} + const char * ggml_op_desc(const struct ggml_tensor * t) { if (t->op == GGML_OP_UNARY) { enum ggml_unary_op uop = ggml_get_unary_op(t); return ggml_unary_op_name(uop); } + if (t->op == GGML_OP_GLU) { + enum ggml_glu_op gop = ggml_get_glu_op(t); + return ggml_glu_op_name(gop); + } return ggml_op_name(t->op); } @@ -1740,6 +1758,11 @@ enum ggml_unary_op ggml_get_unary_op(const struct ggml_tensor * tensor) { return (enum ggml_unary_op) ggml_get_op_params_i32(tensor, 0); } +enum ggml_glu_op ggml_get_glu_op(const struct ggml_tensor * tensor) { + GGML_ASSERT(tensor->op == GGML_OP_GLU); + return (enum ggml_glu_op) ggml_get_op_params_i32(tensor, 0); +} + const char * ggml_get_name(const struct ggml_tensor * tensor) { return tensor->name; } @@ -2619,40 +2642,39 @@ struct ggml_tensor * ggml_exp_inplace( return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_EXP); } -// ggml_reglu +// ggml_glu -struct ggml_tensor * ggml_reglu( +struct ggml_tensor * ggml_glu( struct ggml_context * ctx, - struct ggml_tensor * a) { + struct ggml_tensor * a, + enum ggml_glu_op op) { GGML_ASSERT(ggml_is_contiguous_1(a)); int64_t ne[GGML_MAX_DIMS] = { a->ne[0] / 2 }; for (int i = 1; i < GGML_MAX_DIMS; i++) ne[i] = a->ne[i]; struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, GGML_MAX_DIMS, ne, NULL, 0); - ggml_set_op_params_i32(result, 0, (int32_t) GGML_UNARY_OP_REGLU); + ggml_set_op_params_i32(result, 0, (int32_t) op); - result->op = GGML_OP_UNARY; + result->op = GGML_OP_GLU; result->src[0] = a; return result; } -// ggml_geglu +// ggml_reglu -struct ggml_tensor * ggml_geglu( +struct ggml_tensor * ggml_reglu( struct ggml_context * ctx, struct ggml_tensor * a) { - GGML_ASSERT(ggml_is_contiguous_1(a)); - - int64_t ne[GGML_MAX_DIMS] = { a->ne[0] / 2 }; for (int i = 1; i < GGML_MAX_DIMS; i++) ne[i] = a->ne[i]; - struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, GGML_MAX_DIMS, ne, NULL, 0); - - ggml_set_op_params_i32(result, 0, (int32_t) GGML_UNARY_OP_GEGLU); + return ggml_glu(ctx, a, GGML_GLU_OP_REGLU); +} - result->op = GGML_OP_UNARY; - result->src[0] = a; +// ggml_geglu - return result; +struct ggml_tensor * ggml_geglu( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_glu(ctx, a, GGML_GLU_OP_GEGLU); } // ggml_swiglu @@ -2660,17 +2682,7 @@ struct ggml_tensor * ggml_geglu( struct ggml_tensor * ggml_swiglu( struct ggml_context * ctx, struct ggml_tensor * a) { - GGML_ASSERT(ggml_is_contiguous_1(a)); - - int64_t ne[GGML_MAX_DIMS] = { a->ne[0] / 2 }; for (int i = 1; i < GGML_MAX_DIMS; i++) ne[i] = a->ne[i]; - struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, GGML_MAX_DIMS, ne, NULL, 0); - - ggml_set_op_params_i32(result, 0, (int32_t) GGML_UNARY_OP_SWIGLU); - - result->op = GGML_OP_UNARY; - result->src[0] = a; - - return result; + return ggml_glu(ctx, a, GGML_GLU_OP_SWIGLU); } // ggml_norm diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 1b3d4440677ab..27093875f8cb9 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1072,16 +1072,7 @@ struct 
test_unary : public test_case {
             ggml_set_name(a, "a");
         }
 
-        ggml_tensor * out;
-        if (op == GGML_UNARY_OP_REGLU) {
-            out = ggml_reglu(ctx, a);
-        } else if (op == GGML_UNARY_OP_GEGLU) {
-            out = ggml_geglu(ctx, a);
-        } else if (op == GGML_UNARY_OP_SWIGLU) {
-            out = ggml_swiglu(ctx, a);
-        } else {
-            out = ggml_unary(ctx, a, op);
-        }
+        ggml_tensor * out = ggml_unary(ctx, a, op);
         ggml_set_name(out, "out");
 
         return out;
@@ -1113,6 +1104,51 @@ struct test_unary : public test_case {
 
 };
 
+// GGML_OP_GLU
+struct test_glu : public test_case {
+    const ggml_glu_op op;
+    const ggml_type type;
+    const std::array<int64_t, 4> ne_a;
+    int v; // view (1 : non-contiguous a)
+
+    std::string vars() override {
+        return VARS_TO_STR3(type, ne_a, v);
+    }
+
+    test_glu(ggml_glu_op op,
+            ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne_a = {128, 2, 2, 2},
+            int v = 0)
+        : op(op), type(type), ne_a(ne_a), v(v) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a;
+        if (v & 1) {
+            auto ne = ne_a; ne[0] *= 3;
+            a = ggml_new_tensor(ctx, type, 4, ne.data());
+            ggml_set_name(a, "a");
+
+            a = ggml_view_4d(ctx, a, ne_a[0], ne_a[1], ne_a[2], ne_a[3], a->nb[1], a->nb[2], a->nb[3], 0);
+            ggml_set_name(a, "view_of_a");
+        } else {
+            a = ggml_new_tensor(ctx, type, 4, ne_a.data());
+            ggml_set_name(a, "a");
+        }
+
+        ggml_tensor * out = ggml_glu(ctx, a, op);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            // test extended range of values to check for NaNs in GELU
+            init_tensor_uniform(t, -150.f, 150.f);
+        }
+    }
+};
+
 // GGML_OP_GET_ROWS
 struct test_get_rows : public test_case {
     const ggml_type type;
@@ -3940,6 +3976,16 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         }
     }
 
+    // glu ops
+    for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) {
+        for (int v : {0, 1}) {
+            for (int op = 0; op < GGML_GLU_OP_COUNT; op++) {
+                test_cases.emplace_back(new test_glu((ggml_glu_op) op, type, { 128, 2, 2, 2 }, v));
+                test_cases.emplace_back(new test_glu((ggml_glu_op) op, type, { 5, 7, 11, 13 }, v));
+            }
+        }
+    }
+
     test_cases.emplace_back(new test_get_rows(GGML_TYPE_F32, 1, 8, 2, 1, false));
     for (ggml_type type : all_types) {
         for (int b : {1, 7}) {

From 564861d45069d2080d309f8d9b902435840be70d Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Fri, 13 Jun 2025 16:12:25 +0300
Subject: [PATCH 09/17] metal : add glu kernels

ggml-ci
---
 ggml/src/ggml-metal/ggml-metal-impl.h |  6 +++
 ggml/src/ggml-metal/ggml-metal.m      | 53 +++++++++++++++++++++++-
 ggml/src/ggml-metal/ggml-metal.metal  | 58 +++++++++++++++++++++++++++
 3 files changed, 116 insertions(+), 1 deletion(-)

diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h
index 17eab976f3ad1..ec9069c52a27c 100644
--- a/ggml/src/ggml-metal/ggml-metal-impl.h
+++ b/ggml/src/ggml-metal/ggml-metal-impl.h
@@ -422,6 +422,12 @@ typedef struct {
     int32_t  KHW; // KH * KW, pre-computed on CPU to save GPU resources
 } ggml_metal_kargs_im2col;
 
+typedef struct{
+    int32_t  ne00;
+    uint64_t nb01;
+    uint64_t nb1;
+} ggml_metal_kargs_glu;
+
 typedef struct {
     int64_t  ne00;
     int64_t  ne01;
diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
index bc93bc633a49b..186a32576d388 100644
--- a/ggml/src/ggml-metal/ggml-metal.m
+++ b/ggml/src/ggml-metal/ggml-metal.m
@@ -497,6 +497,9 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
     GGML_METAL_KERNEL_TYPE_SIN,
     GGML_METAL_KERNEL_TYPE_COS,
     GGML_METAL_KERNEL_TYPE_NEG,
+    GGML_METAL_KERNEL_TYPE_REGLU,
+    GGML_METAL_KERNEL_TYPE_GEGLU,
+    GGML_METAL_KERNEL_TYPE_SWIGLU,
     GGML_METAL_KERNEL_TYPE_SUM_ROWS,
     GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
     GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
@@ -1453,6 +1456,9 @@ @implementation GGMLMetalClass
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG, neg, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REGLU, reglu, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU, geglu, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU, swiglu, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
@@ -1626,6 +1632,15 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
             default:
                 return false;
         }
+        case GGML_OP_GLU:
+            switch (ggml_get_glu_op(op)) {
+                case GGML_GLU_OP_REGLU:
+                case GGML_GLU_OP_GEGLU:
+                case GGML_GLU_OP_SWIGLU:
+                    return ggml_is_contiguous_1(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
+                default:
+                    return false;
+            }
         case GGML_OP_NONE:
         case GGML_OP_RESHAPE:
         case GGML_OP_VIEW:
@@ -2343,6 +2358,43 @@ static bool ggml_metal_encode_node(
                         GGML_ABORT("fatal error");
                 }
             } break;
+        case GGML_OP_GLU:
+            {
+                GGML_ASSERT(ggml_is_contiguous_1(src0));
+
+                id<MTLComputePipelineState> pipeline = nil;
+
+                switch (ggml_get_glu_op(node)) {
+                    case GGML_GLU_OP_REGLU:
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REGLU].pipeline;
+                        break;
+                    case GGML_GLU_OP_GEGLU:
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GEGLU].pipeline;
+                        break;
+                    case GGML_GLU_OP_SWIGLU:
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU].pipeline;
+                        break;
+                    default:
+                        GGML_ABORT("fatal error");
+                }
+
+                ggml_metal_kargs_glu args = {
+                    /*.ne00 =*/ ne00,
+                    /*.nb01 =*/ nb01,
+                    /*.nb1  =*/ nb1,
+                };
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                [encoder setBytes:&args length:sizeof(args) atIndex:2];
+
+                const int64_t nrows = ggml_nrows(src0);
+
+                const int32_t nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00/2);
+
+                [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+            } break;
         case GGML_OP_SQR:
             {
                 GGML_ASSERT(ggml_is_contiguous(src0));
@@ -2405,7 +2457,6 @@ static bool ggml_metal_encode_node(
 
                 id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
 
-
                 ggml_metal_kargs_sum_rows args = {
                     /*.ne00 =*/ ne00,
                     /*.ne01 =*/ ne01,
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index 5d7760217f826..973c142d216e1 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -993,6 +993,64 @@ kernel void kernel_neg(
     dst[tpig] = -src0[tpig];
 }
 
+kernel void kernel_reglu(
+        device const char * src0,
+        device       char * dst,
+        constant ggml_metal_kargs_glu & args,
+        uint tgpig[[threadgroup_position_in_grid]],
+        uint tpitg[[thread_position_in_threadgroup]],
+        uint   ntg[[threads_per_threadgroup]]) {
+    device const float * src_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01);
+    device       float * dst_row = (device       float *) ((device       char *) dst  + tgpig*args.nb1);
+
+    for (int i00 = tpitg; i00 < args.ne00/2; i00 += ntg) {
+        const float x0 = src_row[i00];
+        const float x1 = src_row[i00 + args.ne00/2];
+
+        dst_row[i00] = x0*x1*(x0 > 0.0f);
+    }
+}
+
+kernel void kernel_geglu(
+        device const char * src0,
+        device       char * dst,
+        constant ggml_metal_kargs_glu & args,
+        uint tgpig[[threadgroup_position_in_grid]],
+        uint tpitg[[thread_position_in_threadgroup]],
+        uint   ntg[[threads_per_threadgroup]]) {
+    device const float * src_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01);
+    device       float * dst_row = (device       float *) ((device       char *) dst  + tgpig*args.nb1);
+
+    for (int i00 = tpitg; i00 < args.ne00/2; i00 += ntg) {
+        const float x0 = src_row[i00];
+        const float x1 = src_row[i00 + args.ne00/2];
+
+        const float gelu = 0.5f*x0*(1.0f + precise::tanh(SQRT_2_OVER_PI*x0*(1.0f + GELU_COEF_A*x0*x0)));
+
+        dst_row[i00] = gelu*x1;
+    }
+}
+
+kernel void kernel_swiglu(
+        device const char * src0,
+        device       char * dst,
+        constant ggml_metal_kargs_glu & args,
+        uint tgpig[[threadgroup_position_in_grid]],
+        uint tpitg[[thread_position_in_threadgroup]],
+        uint   ntg[[threads_per_threadgroup]]) {
+    device const float * src_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01);
+    device       float * dst_row = (device       float *) ((device       char *) dst  + tgpig*args.nb1);
+
+    for (int i00 = tpitg; i00 < args.ne00/2; i00 += ntg) {
+        const float x0 = src_row[i00];
+        const float x1 = src_row[i00 + args.ne00/2];
+
+        const float silu = x0 / (1.0f + exp(-x0));
+
+        dst_row[i00] = silu*x1;
+    }
+}
+
 kernel void kernel_sum_rows(
         device const float * src0,
         device       float * dst,

From 4b7d4dd23d2dcc2e938d3c055ba8825925e9dee8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?=
Date: Fri, 13 Jun 2025 16:10:03 +0200
Subject: [PATCH 10/17] add CUDA_GLU_BLOCK_SIZE [no ci]

---
 ggml/src/ggml-cuda/unary.cu  | 4 ++--
 ggml/src/ggml-cuda/unary.cuh | 1 +
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu
index 77ef8154578b5..bb048ba4bfcf6 100644
--- a/ggml/src/ggml-cuda/unary.cu
+++ b/ggml/src/ggml-cuda/unary.cu
@@ -213,8 +213,8 @@ static __global__ void unary_gated_op_kernel(const T * x, T * dst, const int k,
 
 template <float (*op)(float), typename T>
 static void unary_gated_cuda(const T * x, T * dst, const int k, const int n, const int o, cudaStream_t stream) {
-    const int num_blocks = (k + CUDA_NEG_BLOCK_SIZE - 1) / CUDA_NEG_BLOCK_SIZE;
-    unary_gated_op_kernel<op><<<num_blocks, CUDA_NEG_BLOCK_SIZE, 0, stream>>>(x, dst, k, n, o);
+    const int num_blocks = (k + CUDA_GLU_BLOCK_SIZE - 1) / CUDA_GLU_BLOCK_SIZE;
+    unary_gated_op_kernel<op><<<num_blocks, CUDA_GLU_BLOCK_SIZE, 0, stream>>>(x, dst, k, n, o);
 }
 
 template <float (*op)(float)>
diff --git a/ggml/src/ggml-cuda/unary.cuh b/ggml/src/ggml-cuda/unary.cuh
index d4533d24e25bc..9094f1d0bad37 100644
--- a/ggml/src/ggml-cuda/unary.cuh
+++ b/ggml/src/ggml-cuda/unary.cuh
@@ -15,6 +15,7 @@
 #define CUDA_SQRT_BLOCK_SIZE 256
 #define CUDA_SIN_BLOCK_SIZE 256
 #define CUDA_COS_BLOCK_SIZE 256
+#define CUDA_GLU_BLOCK_SIZE 256
 
 void ggml_cuda_op_abs(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 

From d1d3f4f2aa1f06d35ec6734884e78b5be73e6431 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?=
Date: Fri, 13 Jun 2025 16:34:23 +0200
Subject: [PATCH 11/17] more constraints and use 64bit ints

ggml-ci
---
 ggml/src/ggml-cuda/unary.cu | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu
index bb048ba4bfcf6..8dd70bc7a43fc 100644
--- a/ggml/src/ggml-cuda/unary.cu
+++ b/ggml/src/ggml-cuda/unary.cu
@@ -199,21 +199,21 @@ void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 /* gated ops */
 
 template <float (*op)(float), typename T>
-static __global__ void unary_gated_op_kernel(const T * x, T * dst, const int k, const int n, const int o) {
-    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+static __global__ void unary_gated_op_kernel(const T * x, T * dst, const int64_t k, const int64_t n, const int64_t o) {
+    const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i >= k) {
         return;
     }
 
     // perform base op on first half of row and multiply with gate in second half
-    const int j = (i / n) * o + (i % n);
+    const int64_t j = (i / n) * o + (i % n);
     dst[i] = (T)(op((float)x[j]) * (float)x[j + n]);
 }
 
 template <float (*op)(float), typename T>
-static void unary_gated_cuda(const T * x, T * dst, const int k, const int n, const int o, cudaStream_t stream) {
-    const int num_blocks = (k + CUDA_GLU_BLOCK_SIZE - 1) / CUDA_GLU_BLOCK_SIZE;
+static void unary_gated_cuda(const T * x, T * dst, const int64_t k, const int64_t n, const int64_t o, cudaStream_t stream) {
+    const int64_t num_blocks = (k + CUDA_GLU_BLOCK_SIZE - 1) / CUDA_GLU_BLOCK_SIZE;
     unary_gated_op_kernel<op><<<num_blocks, CUDA_GLU_BLOCK_SIZE, 0, stream>>>(x, dst, k, n, o);
 }
 
@@ -222,10 +222,12 @@ void ggml_cuda_op_unary_gated(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     const ggml_tensor * src0 = dst->src[0];
     const void * src0_d = src0->data;
     void * dst_d = dst->data;
-    const int nc = src0->ne[0] / 2;
+    const int64_t nc = src0->ne[0] / 2;
     cudaStream_t stream = ctx.stream();
 
     GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(src0->nb[0] == ggml_element_size(src0));
+    GGML_ASSERT(ggml_is_contiguous(dst));
 
     GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
     GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
     GGML_ASSERT(src0->type == dst->type);

From e3d2b20a8430554d8e3d923514f06484b857bf3f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?=
Date: Fri, 13 Jun 2025 17:11:01 +0200
Subject: [PATCH 12/17] 64bit multiplication [no ci]

---
 ggml/src/ggml-cuda/unary.cu | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu
index 8dd70bc7a43fc..31177a099b1c0 100644
--- a/ggml/src/ggml-cuda/unary.cu
+++ b/ggml/src/ggml-cuda/unary.cu
@@ -200,7 +200,7 @@ void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 
 template <float (*op)(float), typename T>
 static __global__ void unary_gated_op_kernel(const T * x, T * dst, const int64_t k, const int64_t n, const int64_t o) {
-    const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
+    const int64_t i = int64_t(blockDim.x)*blockIdx.x + threadIdx.x;
 
     if (i >= k) {
         return;

From 39eba35c22c4e15f7e8babcb440e67be1e2f5b21 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?=
Date: Fri, 13 Jun 2025 22:48:53 +0200
Subject: [PATCH 13/17] implement swapped variants (cpu/cuda)

---
 ggml/include/ggml.h         | 16 +++++++++++++++-
 ggml/src/ggml-cpu/ops.cpp   | 30 ++++++++++++++++++++++------
 ggml/src/ggml-cpu/vec.cpp   | 12 ++++++------
 ggml/src/ggml-cpu/vec.h     | 32 ++++++++++++++++----------------
 ggml/src/ggml-cuda/unary.cu | 28 ++++++++++++++++++++++------
 ggml/src/ggml.c             | 28 ++++++++++++++++++++++++----
 tests/test-backend-ops.cpp  | 16 ++++++++++------
 7 files changed, 117 insertions(+), 45 deletions(-)

diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index 7e30cb931fd84..40ff1c187a831 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -1100,23 +1100,37 @@ extern "C" {
     // gated linear unit ops
     // A: n columns, r rows,
     // result is n / 2 columns, r rows,
+    // expects gate in second half of row, unless swapped is true
     GGML_API 
struct ggml_tensor * ggml_glu( struct ggml_context * ctx, struct ggml_tensor * a, - enum ggml_glu_op op); + enum ggml_glu_op op, + bool swapped); GGML_API struct ggml_tensor * ggml_reglu( struct ggml_context * ctx, struct ggml_tensor * a); + GGML_API struct ggml_tensor * ggml_reglu_swapped( + struct ggml_context * ctx, + struct ggml_tensor * a); + GGML_API struct ggml_tensor * ggml_geglu( struct ggml_context * ctx, struct ggml_tensor * a); + GGML_API struct ggml_tensor * ggml_geglu_swapped( + struct ggml_context * ctx, + struct ggml_tensor * a); + GGML_API struct ggml_tensor * ggml_swiglu( struct ggml_context * ctx, struct ggml_tensor * a); + GGML_API struct ggml_tensor * ggml_swiglu_swapped( + struct ggml_context * ctx, + struct ggml_tensor * a); + // normalize along rows GGML_API struct ggml_tensor * ggml_norm( struct ggml_context * ctx, diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index e9a1d82e38deb..8c88bf2e7b880 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -3214,6 +3214,8 @@ static void ggml_compute_forward_reglu_f32( GGML_ASSERT(dst->ne[0] == nc); GGML_ASSERT(ggml_nrows(dst) == nr); + const int32_t swapped = ggml_get_op_params_i32(dst, 1); + // rows per thread const int dr = (nr + nth - 1)/nth; @@ -3224,7 +3226,8 @@ static void ggml_compute_forward_reglu_f32( for (int i1 = ir0; i1 < ir1; i1++) { ggml_vec_reglu_f32(nc, (float *) ((char *) dst->data + i1*( dst->nb[1])), - (float *) ((char *) src0->data + i1*(src0->nb[1]))); + (float *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? nc : 0), + (float *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? 0 : nc)); #ifndef NDEBUG for (int k = 0; k < nc; k++) { @@ -3255,6 +3258,8 @@ static void ggml_compute_forward_reglu_f16( GGML_ASSERT(dst->ne[0] == nc); GGML_ASSERT(ggml_nrows(dst) == nr); + const int32_t swapped = ggml_get_op_params_i32(dst, 1); + // rows per thread const int dr = (nr + nth - 1)/nth; @@ -3265,7 +3270,8 @@ static void ggml_compute_forward_reglu_f16( for (int i1 = ir0; i1 < ir1; i1++) { ggml_vec_reglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])), - (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1]))); + (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? nc : 0), + (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? 0 : nc)); #ifndef NDEBUG for (int k = 0; k < nc; k++) { @@ -3321,6 +3327,8 @@ static void ggml_compute_forward_geglu_f32( GGML_ASSERT(dst->ne[0] == nc); GGML_ASSERT(ggml_nrows(dst) == nr); + const int32_t swapped = ggml_get_op_params_i32(dst, 1); + // rows per thread const int dr = (nr + nth - 1)/nth; @@ -3331,7 +3339,8 @@ static void ggml_compute_forward_geglu_f32( for (int i1 = ir0; i1 < ir1; i1++) { ggml_vec_geglu_f32(nc, (float *) ((char *) dst->data + i1*( dst->nb[1])), - (float *) ((char *) src0->data + i1*(src0->nb[1]))); + (float *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? nc : 0), + (float *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? 
0 : nc)); #ifndef NDEBUG for (int k = 0; k < nc; k++) { @@ -3362,6 +3371,8 @@ static void ggml_compute_forward_geglu_f16( GGML_ASSERT(dst->ne[0] == nc); GGML_ASSERT(ggml_nrows(dst) == nr); + const int32_t swapped = ggml_get_op_params_i32(dst, 1); + // rows per thread const int dr = (nr + nth - 1)/nth; @@ -3372,7 +3383,8 @@ static void ggml_compute_forward_geglu_f16( for (int i1 = ir0; i1 < ir1; i1++) { ggml_vec_geglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])), - (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1]))); + (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? nc : 0), + (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? 0 : nc)); #ifndef NDEBUG for (int k = 0; k < nc; k++) { @@ -3428,6 +3440,8 @@ static void ggml_compute_forward_swiglu_f32( GGML_ASSERT(dst->ne[0] == nc); GGML_ASSERT(ggml_nrows(dst) == nr); + const int32_t swapped = ggml_get_op_params_i32(dst, 1); + // rows per thread const int dr = (nr + nth - 1)/nth; @@ -3438,7 +3452,8 @@ static void ggml_compute_forward_swiglu_f32( for (int i1 = ir0; i1 < ir1; i1++) { ggml_vec_swiglu_f32(nc, (float *) ((char *) dst->data + i1*( dst->nb[1])), - (float *) ((char *) src0->data + i1*(src0->nb[1]))); + (float *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? nc : 0), + (float *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? 0 : nc)); #ifndef NDEBUG for (int k = 0; k < nc; k++) { @@ -3469,6 +3484,8 @@ static void ggml_compute_forward_swiglu_f16( GGML_ASSERT(dst->ne[0] == nc); GGML_ASSERT(ggml_nrows(dst) == nr); + const int32_t swapped = ggml_get_op_params_i32(dst, 1); + // rows per thread const int dr = (nr + nth - 1)/nth; @@ -3479,7 +3496,8 @@ static void ggml_compute_forward_swiglu_f16( for (int i1 = ir0; i1 < ir1; i1++) { ggml_vec_swiglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])), - (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1]))); + (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? nc : 0), + (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? 
0 : nc)); #ifndef NDEBUG for (int k = 0; k < nc; k++) { diff --git a/ggml/src/ggml-cpu/vec.cpp b/ggml/src/ggml-cpu/vec.cpp index bfb2d5d361512..1956f78e4e743 100644 --- a/ggml/src/ggml-cpu/vec.cpp +++ b/ggml/src/ggml-cpu/vec.cpp @@ -254,27 +254,27 @@ void ggml_vec_silu_f32(const int n, float * y, const float * x) { } } -void ggml_vec_swiglu_f32(const int n, float * y, const float * x) { +void ggml_vec_swiglu_f32(const int n, float * y, const float * x, const float * g) { int i = 0; #if defined(__AVX512F__) && defined(__AVX512DQ__) for (; i + 15 < n; i += 16) { - _mm512_storeu_ps(y + i, _mm512_mul_ps(ggml_v_silu(_mm512_loadu_ps(x + i)), _mm512_loadu_ps(x + i + n))); + _mm512_storeu_ps(y + i, _mm512_mul_ps(ggml_v_silu(_mm512_loadu_ps(x + i)), _mm512_loadu_ps(g + i))); } #elif defined(__AVX2__) && defined(__FMA__) for (; i + 7 < n; i += 8) { - _mm256_storeu_ps(y + i, _mm256_mul_ps(ggml_v_silu(_mm256_loadu_ps(x + i)), _mm256_loadu_ps(x + i + n))); + _mm256_storeu_ps(y + i, _mm256_mul_ps(ggml_v_silu(_mm256_loadu_ps(x + i)), _mm256_loadu_ps(g + i))); } #elif defined(__SSE2__) for (; i + 3 < n; i += 4) { - _mm_storeu_ps(y + i, _mm_mul_ps(ggml_v_silu(_mm_loadu_ps(x + i)), _mm_loadu_ps(x + i + n))); + _mm_storeu_ps(y + i, _mm_mul_ps(ggml_v_silu(_mm_loadu_ps(x + i)), _mm_loadu_ps(g + i))); } #elif defined(__ARM_NEON) && defined(__aarch64__) for (; i + 3 < n; i += 4) { - vst1q_f32(y + i, vmulq_f32(ggml_v_silu(vld1q_f32(x + i)), vld1q_f32(x + i + n))); + vst1q_f32(y + i, vmulq_f32(ggml_v_silu(vld1q_f32(x + i)), vld1q_f32(g + i))); } #endif for (; i < n; ++i) { - y[i] = ggml_silu_f32(x[i]) * x[i + n]; + y[i] = ggml_silu_f32(x[i]) * g[i]; } } diff --git a/ggml/src/ggml-cpu/vec.h b/ggml/src/ggml-cpu/vec.h index 178629e994216..f9113a0b17953 100644 --- a/ggml/src/ggml-cpu/vec.h +++ b/ggml/src/ggml-cpu/vec.h @@ -905,57 +905,57 @@ inline static void ggml_vec_silu_backward_f16(const int n, ggml_fp16_t * dx, con } } -inline static void ggml_vec_reglu_f32 (const int n, float * y, const float * x) { +inline static void ggml_vec_reglu_f32 (const int n, float * y, const float * x, const float * g) { for (int i = 0; i < n; ++i) { - y[i] = (x[i] > 0.f) ? x[i] * x[i + n] : 0.f; + y[i] = (x[i] > 0.f) ? x[i] * g[i] : 0.f; } } -inline static void ggml_vec_reglu_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { +inline static void ggml_vec_reglu_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) { for (int i = 0; i < n; ++i) { float v = GGML_FP16_TO_FP32(x[i]); - y[i] = GGML_FP32_TO_FP16((v > 0.f) ? v * GGML_FP16_TO_FP32(x[i + n]) : 0.f); + y[i] = GGML_FP32_TO_FP16((v > 0.f) ? 
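// ggml_vec_swiglu_f32 above keeps its SIMD fast paths (16 floats per step on
// AVX512, 8 on AVX2+FMA, 4 on SSE2 or NEON, scalar tail for the remainder);
// taking the gate as a separate pointer g rather than the hard-coded
// x[i + n] offset is what lets the same routine later serve split up/gate
// tensors as well as the packed layout.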
v * GGML_FP16_TO_FP32(g[i]) : 0.f); } } #ifdef GGML_GELU_FP16 -inline static void ggml_vec_geglu_f32(const int n, float * y, const float * x) { +inline static void ggml_vec_geglu_f32(const int n, float * y, const float * x, const float * g) { uint16_t t; for (int i = 0; i < n; ++i) { if (x[i] <= -10.0f) { y[i] = 0.0f; } else if (x[i] >= 10.0f) { - y[i] = x[i] * x[i + n]; + y[i] = x[i] * g[i]; } else { ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]); memcpy(&t, &fp16, sizeof(uint16_t)); - y[i] = GGML_FP16_TO_FP32(ggml_table_gelu_f16[t]) * x[i + n]; + y[i] = GGML_FP16_TO_FP32(ggml_table_gelu_f16[t]) * g[i]; } } } #else -inline static void ggml_vec_geglu_f32(const int n, float * y, const float * x) { +inline static void ggml_vec_geglu_f32(const int n, float * y, const float * x, const float * g) { for (int i = 0; i < n; ++i) { - y[i] = ggml_gelu_f32(x[i]) * x[i + n]; + y[i] = ggml_gelu_f32(x[i]) * g[i]; } } #endif -inline static void ggml_vec_geglu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { +inline static void ggml_vec_geglu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) { const uint16_t * i16 = (const uint16_t *) x; for (int i = 0; i < n; ++i) { - float g = GGML_FP16_TO_FP32(x[i + n]); - y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(ggml_table_gelu_f16[i16[i]]) * g); + float v = GGML_FP16_TO_FP32(g[i]); + y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(ggml_table_gelu_f16[i16[i]]) * v); } } -void ggml_vec_swiglu_f32(const int n, float * y, const float * x); +void ggml_vec_swiglu_f32(const int n, float * y, const float * x, const float * g); -inline static void ggml_vec_swiglu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { +inline static void ggml_vec_swiglu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) { for (int i = 0; i < n; ++i) { float v = GGML_FP16_TO_FP32(x[i]); - float g = GGML_FP16_TO_FP32(x[i + n]); - y[i] = GGML_FP32_TO_FP16((v/(1.0f + expf(-v))) * g); + float w = GGML_FP16_TO_FP32(g[i]); + y[i] = GGML_FP32_TO_FP16((v/(1.0f + expf(-v))) * w); } } diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu index 31177a099b1c0..caab84d525dd7 100644 --- a/ggml/src/ggml-cuda/unary.cu +++ b/ggml/src/ggml-cuda/unary.cu @@ -199,7 +199,7 @@ void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { /* gated ops */ template -static __global__ void unary_gated_op_kernel(const T * x, T * dst, const int64_t k, const int64_t n, const int64_t o) { +static __global__ void unary_gated_op_kernel(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o) { const int64_t i = int64_t(blockDim.x)*blockIdx.x + threadIdx.x; if (i >= k) { @@ -208,13 +208,13 @@ static __global__ void unary_gated_op_kernel(const T * x, T * dst, const int64_t // perform base op on first half of row and multiply with gate in second half const int64_t j = (i / n) * o + (i % n); - dst[i] = (T)(op((float)x[j]) * (float)x[j + n]); + dst[i] = (T)(op((float)x[j]) * (float)g[j]); } template -static void unary_gated_cuda(const T * x, T * dst, const int64_t k, const int64_t n, const int64_t o, cudaStream_t stream) { +static void unary_gated_cuda(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o, cudaStream_t stream) { const int64_t num_blocks = (k + CUDA_GLU_BLOCK_SIZE - 1) / CUDA_GLU_BLOCK_SIZE; - unary_gated_op_kernel<<>>(x, dst, k, n, o); + unary_gated_op_kernel<<>>(x, g, dst, k, n, o); } template @@ -235,10 +235,26 @@ void 
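// On the GGML_GELU_FP16 path above, gelu(x) is a table lookup: the input is
// rounded to f16, its bit pattern is moved into t via memcpy (a well-defined
// bit cast), and t indexes the precomputed ggml_table_gelu_f16 (one entry per
// f16 bit pattern); the |x| >= 10 branches short-circuit the saturated
// regions where gelu(x) is ~0 or ~x.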
ggml_cuda_op_unary_gated(ggml_backend_cuda_context & ctx, ggml_tensor * dst GGML_ASSERT(dst->ne[0] == nc); GGML_ASSERT(ggml_nrows(dst) == ggml_nrows(src0)); + const int32_t swapped = ((const int32_t *) dst->op_params)[1]; + if (src0->type == GGML_TYPE_F16) { - unary_gated_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(dst), nc, src0->nb[1] / sizeof(half), stream); + unary_gated_cuda( + (const half *)src0_d + (swapped ? nc : 0), + (const half *)src0_d + (swapped ? 0 : nc), + (half *)dst_d, + ggml_nelements(dst), + nc, + src0->nb[1] / sizeof(half), + stream); } else { - unary_gated_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(dst), nc, src0->nb[1] / sizeof(float), stream); + unary_gated_cuda( + (const float *)src0_d + (swapped ? nc : 0), + (const float *)src0_d + (swapped ? 0 : nc), + (float *)dst_d, + ggml_nelements(dst), + nc, + src0->nb[1] / sizeof(float), + stream); } } diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 83f5108457699..2ae4e511b543b 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -2647,13 +2647,15 @@ struct ggml_tensor * ggml_exp_inplace( struct ggml_tensor * ggml_glu( struct ggml_context * ctx, struct ggml_tensor * a, - enum ggml_glu_op op) { + enum ggml_glu_op op, + bool swapped) { GGML_ASSERT(ggml_is_contiguous_1(a)); int64_t ne[GGML_MAX_DIMS] = { a->ne[0] / 2 }; for (int i = 1; i < GGML_MAX_DIMS; i++) ne[i] = a->ne[i]; struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, GGML_MAX_DIMS, ne, NULL, 0); ggml_set_op_params_i32(result, 0, (int32_t) op); + ggml_set_op_params_i32(result, 1, (int32_t) swapped); result->op = GGML_OP_GLU; result->src[0] = a; @@ -2666,7 +2668,13 @@ struct ggml_tensor * ggml_glu( struct ggml_tensor * ggml_reglu( struct ggml_context * ctx, struct ggml_tensor * a) { - return ggml_glu(ctx, a, GGML_GLU_OP_REGLU); + return ggml_glu(ctx, a, GGML_GLU_OP_REGLU, false); +} + +struct ggml_tensor * ggml_reglu_swapped( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_glu(ctx, a, GGML_GLU_OP_REGLU, true); } // ggml_geglu @@ -2674,7 +2682,13 @@ struct ggml_tensor * ggml_reglu( struct ggml_tensor * ggml_geglu( struct ggml_context * ctx, struct ggml_tensor * a) { - return ggml_glu(ctx, a, GGML_GLU_OP_GEGLU); + return ggml_glu(ctx, a, GGML_GLU_OP_GEGLU, false); +} + +struct ggml_tensor * ggml_geglu_swapped( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_glu(ctx, a, GGML_GLU_OP_GEGLU, true); } // ggml_swiglu @@ -2682,7 +2696,13 @@ struct ggml_tensor * ggml_geglu( struct ggml_tensor * ggml_swiglu( struct ggml_context * ctx, struct ggml_tensor * a) { - return ggml_glu(ctx, a, GGML_GLU_OP_SWIGLU); + return ggml_glu(ctx, a, GGML_GLU_OP_SWIGLU, false); +} + +struct ggml_tensor * ggml_swiglu_swapped( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_glu(ctx, a, GGML_GLU_OP_SWIGLU, true); } // ggml_norm diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 27093875f8cb9..0278df1a98d66 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1110,16 +1110,18 @@ struct test_glu : public test_case { const ggml_type type; const std::array ne_a; int v; // view (1 : non-contiguous a) + bool swapped; std::string vars() override { - return VARS_TO_STR3(type, ne_a, v); + return VARS_TO_STR4(type, ne_a, v, swapped); } test_glu(ggml_glu_op op, ggml_type type = GGML_TYPE_F32, std::array ne_a = {128, 2, 2, 2}, - int v = 0) - : op(op), type(type), ne_a(ne_a), v(v) {} + int v = 0, + bool swapped = false) + : op(op), type(type), 
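// Minimal usage sketch for the wrappers above (hypothetical tensor names;
// assumes model.ffn_up_gate projects inp to a packed [2*n_ff, n_tokens]
// activation):
//
//     struct ggml_tensor * cur = ggml_mul_mat(ctx, model.ffn_up_gate, inp);
//     cur = ggml_swiglu(ctx, cur);                 // gate in second half of row
//     // or: cur = ggml_swiglu_swapped(ctx, cur);  // for gate-first weights
//
// Either form yields an [n_ff, n_tokens] result.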
ne_a(ne_a), v(v), swapped(swapped) {} ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * a; @@ -1135,7 +1137,7 @@ struct test_glu : public test_case { ggml_set_name(a, "a"); } - ggml_tensor * out = ggml_glu(ctx, a, op); + ggml_tensor * out = ggml_glu(ctx, a, op, swapped); ggml_set_name(out, "out"); return out; @@ -3980,8 +3982,10 @@ static std::vector> make_test_cases_eval() { for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) { for (int v : {0, 1}) { for (int op = 0; op < GGML_GLU_OP_COUNT; op++) { - test_cases.emplace_back(new test_glu((ggml_glu_op) op, type, { 128, 2, 2, 2 }, v)); - test_cases.emplace_back(new test_glu((ggml_glu_op) op, type, { 5, 7, 11, 13 }, v)); + for (bool swapped : {false, true}) { + test_cases.emplace_back(new test_glu((ggml_glu_op) op, type, { 128, 2, 2, 2 }, v, swapped)); + test_cases.emplace_back(new test_glu((ggml_glu_op) op, type, { 5, 7, 11, 13 }, v, swapped)); + } } } } From 98a50196c1688af10f698eae66fa52f1eac52249 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?= Date: Fri, 13 Jun 2025 23:08:18 +0200 Subject: [PATCH 14/17] update comment [no ci] ggml-ci --- ggml/src/ggml-cuda/unary.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu index caab84d525dd7..c991c1d700174 100644 --- a/ggml/src/ggml-cuda/unary.cu +++ b/ggml/src/ggml-cuda/unary.cu @@ -206,7 +206,7 @@ static __global__ void unary_gated_op_kernel(const T * x, const T * g, T * dst, return; } - // perform base op on first half of row and multiply with gate in second half + // perform base op on half of the row and multiply with gate in other half const int64_t j = (i / n) * o + (i % n); dst[i] = (T)(op((float)x[j]) * (float)g[j]); } From 8dc1d9f8adfe48c53a7e0d47664277f51e68597c Mon Sep 17 00:00:00 2001 From: 0cc4m Date: Sat, 14 Jun 2025 10:06:55 +0000 Subject: [PATCH 15/17] Vulkan: Add GLU ops and shaders --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 93 ++++++++++++++++++- .../src/ggml-vulkan/vulkan-shaders/geglu.comp | 43 +++++++++ .../src/ggml-vulkan/vulkan-shaders/reglu.comp | 36 +++++++ .../ggml-vulkan/vulkan-shaders/swiglu.comp | 38 ++++++++ .../vulkan-shaders/vulkan-shaders-gen.cpp | 7 ++ 5 files changed, 215 insertions(+), 2 deletions(-) create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 32d6407441535..ee59f3a59957e 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -431,6 +431,10 @@ struct vk_device_struct { vk_pipeline pipeline_tanh[2]; vk_pipeline pipeline_sigmoid[2]; + vk_pipeline pipeline_geglu[2]; + vk_pipeline pipeline_reglu[2]; + vk_pipeline pipeline_swiglu[2]; + vk_pipeline pipeline_leaky_relu_f32; vk_pipeline pipeline_silu_back_f32; vk_pipeline pipeline_diag_mask_inf_f32; @@ -2728,6 +2732,15 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_UNARY(sigmoid) #undef CREATE_UNARY +#define CREATE_GLU(name) \ + ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); \ + ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { 
device->subgroup_size }, 1); + + CREATE_GLU(geglu) + CREATE_GLU(reglu) + CREATE_GLU(swiglu) +#undef CREATE_GLU + ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_silu_back_f32, "silu_back_f32", silu_back_f32_len, silu_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); @@ -6415,6 +6428,24 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const break; } return nullptr; + case GGML_OP_GLU: + if ((src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) || + (dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16) || + (src0->type != dst->type)) { + return nullptr; + } + + switch (ggml_get_glu_op(dst)) { + case GGML_GLU_OP_GEGLU: + return ctx->device->pipeline_geglu[dst->type == GGML_TYPE_F16]; + case GGML_GLU_OP_REGLU: + return ctx->device->pipeline_reglu[dst->type == GGML_TYPE_F16]; + case GGML_GLU_OP_SWIGLU: + return ctx->device->pipeline_swiglu[dst->type == GGML_TYPE_F16]; + default: + break; + } + return nullptr; case GGML_OP_DIAG_MASK_INF: if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { return ctx->device->pipeline_diag_mask_inf_f32; @@ -6791,6 +6822,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co case GGML_OP_SOFT_MAX_BACK: case GGML_OP_SUM_ROWS: case GGML_OP_ARGMAX: + case GGML_OP_GLU: { const uint32_t nr = ggml_nrows(src0); if (nr > 262144) { @@ -7507,6 +7539,14 @@ static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, con ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun); } +static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + GGML_ASSERT(src0->ne[0] / 2 == dst->ne[0]); + + const uint32_t swapped = (uint32_t)dst->op_params[1]; + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_GLU, { (uint32_t)src0->ne[0], swapped, 0.0f, 0.0f }, dryrun); +} + static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { int32_t * op_params = (int32_t *)dst->op_params; ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_DIAG_MASK_INF, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0] }, dryrun); @@ -8718,6 +8758,16 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod return false; } break; + case GGML_OP_GLU: + switch (ggml_get_glu_op(node)) { + case GGML_GLU_OP_GEGLU: + case GGML_GLU_OP_REGLU: + case GGML_GLU_OP_SWIGLU: + break; + default: + return false; + } + break; case GGML_OP_REPEAT: case GGML_OP_REPEAT_BACK: case GGML_OP_GET_ROWS: @@ -8810,6 +8860,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_RMS_NORM_BACK: case GGML_OP_L2_NORM: case GGML_OP_UNARY: + case GGML_OP_GLU: case GGML_OP_DIAG_MASK_INF: case GGML_OP_SOFT_MAX: case GGML_OP_SOFT_MAX_BACK: @@ -8947,6 +8998,17 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod return false; } break; + case GGML_OP_GLU: + switch (ggml_get_glu_op(node)) { + case GGML_GLU_OP_GEGLU: + case GGML_GLU_OP_REGLU: + case GGML_GLU_OP_SWIGLU: + ggml_vk_glu(ctx, compute_ctx, src0, node, dryrun); + break; + default: + return false; + } + 
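// Push-constant wiring for ggml_vk_glu above: p.KX carries the full input row
// width src0->ne[0] (the shaders derive offset = p.KX / 2 from it) and p.KY
// is reused as the swapped flag (KY > 0 means gate-first). The 262144/512
// constants in the shaders' row calculation mirror the nr > 262144 workgroup
// re-tiling in ggml_vk_op_f32, since 262144 = 512*512.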
break; case GGML_OP_DIAG_MASK_INF: ggml_vk_diag_mask_inf(ctx, compute_ctx, src0, node, dryrun); @@ -9072,8 +9134,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod if (!ok) { if (node->op == GGML_OP_UNARY) { std::cerr << __func__ << ": error: op not supported UNARY " << node->name << " (" << ggml_unary_op_name(static_cast(node->op_params[0])) << ")" << std::endl; - } - else { + } else if (node->op == GGML_OP_GLU) { + std::cerr << __func__ << ": error: op not supported GLU " << node->name << " (" << ggml_glu_op_name(static_cast(node->op_params[0])) << ")" << std::endl; + } else { std::cerr << __func__ << ": error: op not supported " << node->name << " (" << ggml_op_name(node->op) << ")" << std::endl; } } @@ -9152,6 +9215,17 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * return false; } break; + case GGML_OP_GLU: + switch (ggml_get_glu_op(tensor)) { + case GGML_GLU_OP_GEGLU: + case GGML_GLU_OP_REGLU: + case GGML_GLU_OP_SWIGLU: + buf = tensor->buffer; + break; + default: + return false; + } + break; case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: case GGML_OP_FLASH_ATTN_EXT: @@ -9923,6 +9997,19 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm return false; } break; + case GGML_OP_GLU: + switch (ggml_get_glu_op(op)) { + case GGML_GLU_OP_GEGLU: + case GGML_GLU_OP_REGLU: + case GGML_GLU_OP_SWIGLU: + return ggml_is_contiguous(op->src[0]) && + (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && + (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && + (op->src[0]->type == op->type); + default: + return false; + } + break; case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: { @@ -10637,6 +10724,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl; GGML_ABORT("fatal error"); } + } else if (tensor->op == GGML_OP_GLU) { + tensor_clone = ggml_glu(ggml_ctx, src_clone[0], (ggml_glu_op) tensor->op_params[0], tensor->op_params[1]); } else if (tensor->op == GGML_OP_CPY || tensor->op == GGML_OP_DUP) { if (src1 == nullptr) { tensor_clone = ggml_dup(ggml_ctx, src_clone[0]); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp b/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp new file mode 100644 index 0000000000000..e58ac59d9a860 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp @@ -0,0 +1,43 @@ +#version 450 + +#include "generic_head.comp" +#include "types.comp" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +layout (constant_id = 0) const uint BLOCK_SIZE = 32; + +void main() { + const float GELU_COEF_A = 0.044715f; + const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; + + const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; + const uint col = gl_LocalInvocationID.x; + + const uint offset = p.KX / 2; + + const bool swapped = p.KY > 0; + + if (!swapped) { + for (uint i = col; i < offset; i += BLOCK_SIZE) { + const uint idx = row * p.KX + i; + + const float xi = float(data_a[idx]); + const float val = SQRT_2_OVER_PI*xi*(1.0f + GELU_COEF_A*xi*xi); + data_d[row * offset + i] = D_TYPE(0.5f*xi*(2.0f - 2.0f / (exp(2 * val) + 1)) * float(data_a[idx + offset])); + } + } else { + for (uint i = col; i < offset; i += 
BLOCK_SIZE) { + const uint idx = row * p.KX + i; + + const float xi = float(data_a[idx + offset]); + const float val = SQRT_2_OVER_PI*xi*(1.0f + GELU_COEF_A*xi*xi); + data_d[row * offset + i] = D_TYPE(0.5f*xi*(2.0f - 2.0f / (exp(2 * val) + 1)) * float(data_a[idx])); + } + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp b/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp new file mode 100644 index 0000000000000..034481a1f17ef --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp @@ -0,0 +1,36 @@ +#version 450 + +#include "generic_head.comp" +#include "types.comp" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +layout (constant_id = 0) const uint BLOCK_SIZE = 32; + +void main() { + const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; + const uint col = gl_LocalInvocationID.x; + + const uint offset = p.KX / 2; + + const bool swapped = p.KY > 0; + + if (!swapped) { + for (uint i = col; i < offset; i += BLOCK_SIZE) { + const uint idx = row * p.KX + i; + + data_d[row * offset + i] = D_TYPE(max(float(data_a[idx]), 0.0f) * float(data_a[idx + offset])); + } + } else { + for (uint i = col; i < offset; i += BLOCK_SIZE) { + const uint idx = row * p.KX + i; + + data_d[row * offset + i] = D_TYPE(max(float(data_a[idx + offset]), 0.0f) * float(data_a[idx])); + } + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp b/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp new file mode 100644 index 0000000000000..e75c1d38aa1ea --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp @@ -0,0 +1,38 @@ +#version 450 + +#include "generic_head.comp" +#include "types.comp" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +layout (constant_id = 0) const uint BLOCK_SIZE = 32; + +void main() { + const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; + const uint col = gl_LocalInvocationID.x; + + const uint offset = p.KX / 2; + + const bool swapped = p.KY > 0; + + if (!swapped) { + for (uint i = col; i < offset; i += BLOCK_SIZE) { + const uint idx = row * p.KX + i; + + const float xi = float(data_a[idx]); + data_d[row * offset + i] = D_TYPE(xi / (1.0f + exp(-xi)) * float(data_a[idx + offset])); + } + } else { + for (uint i = col; i < offset; i += BLOCK_SIZE) { + const uint idx = row * p.KX + i; + + const float xi = float(data_a[idx + offset]); + data_d[row * offset + i] = D_TYPE(xi / (1.0f + exp(-xi)) * float(data_a[idx])); + } + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index c63345ec8b4b6..259b647317332 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -585,6 +585,13 @@ void process_shaders() { string_to_spv("sigmoid_f16", "sigmoid.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); string_to_spv("sigmoid_f32", "sigmoid.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("geglu_f16", "geglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("geglu_f32", "geglu.comp", 
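// The geglu.comp arithmetic above is the usual tanh GELU in disguise: with
// val = sqrt(2/pi) * x * (1 + 0.044715 * x^2), the identity
// tanh(v) = 1 - 2 / (exp(2v) + 1) turns 0.5*x*(1 + tanh(val)) into the
// 0.5*x*(2 - 2/(exp(2*val) + 1)) form the shader uses, avoiding a tanh()
// call in GLSL.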
{{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("reglu_f16", "reglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("reglu_f32", "reglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("swiglu_f16", "swiglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("swiglu_f32", "swiglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); From 95e4be03c05271b8ea2cd270bbb281779a41afb2 Mon Sep 17 00:00:00 2001 From: Akarshan Date: Sat, 14 Jun 2025 18:34:21 +0530 Subject: [PATCH 16/17] SYCL: Implement fused kernel GEGLU, SWIGLU and REGLU for single up+gate --- ggml/src/ggml-sycl/element_wise.cpp | 221 ++++++++++++++++++++++++++++ ggml/src/ggml-sycl/element_wise.hpp | 8 + ggml/src/ggml-sycl/ggml-sycl.cpp | 25 ++++ 3 files changed, 254 insertions(+) diff --git a/ggml/src/ggml-sycl/element_wise.cpp b/ggml/src/ggml-sycl/element_wise.cpp index 5b7c4f0b4f003..7e6b48db7002b 100644 --- a/ggml/src/ggml-sycl/element_wise.cpp +++ b/ggml/src/ggml-sycl/element_wise.cpp @@ -1,6 +1,9 @@ #include "common.hpp" +#include "ggml-sycl/presets.hpp" #include "ggml.h" #include "element_wise.hpp" +#include +#include static void acc_f32(const float * x, const float * y, float * dst, const int ne, const int ne10, const int ne11, const int ne12, @@ -324,6 +327,34 @@ static void clamp(const T * x, T * dst, const float min, const float max, const dst[i] = x[i] < static_cast(min) ? static_cast(min) : (x[i] > static_cast(max) ? static_cast(max) : x[i]); } +// Fused GLU kernels +template +static void gated_op_fused_geglu(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o, const sycl::nd_item<1> &item_ct1) { + for (auto i = item_ct1.get_global_id(0); i < k; i += item_ct1.get_global_range(0)) { + const int64_t j = ((i / n) * o) + (i % n); + const T x_val = x[j]; + const T gelu_val = x_val * (static_cast(1.0f) / (static_cast(1.0f) + sycl::native::exp(GELU_QUICK_COEF * x_val))); + + dst[i] = gelu_val * g[j]; + } +} + +template +static void gated_op_fused_reglu(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o, const sycl::nd_item<1> &item_ct1) { + for (auto i = item_ct1.get_global_id(0); i < k; i += item_ct1.get_global_range(0)) { + const int64_t j = ((i / n) * o) + (i % n); + dst[i] = sycl::max((x[j]), static_cast(0)) * g[j]; + } +} + +template +static void gated_op_fused_swiglu(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o, const sycl::nd_item<1> &item_ct1) { + for (auto i = item_ct1.get_global_id(0); i < k; i += item_ct1.get_global_range(0)) { + const int64_t j = ((i / n) * o) + (i % n); + dst[i] = (x[j] / (static_cast(1) + sycl::native::exp(-x[j]))) * g[j]; + } +} + static void acc_f32_sycl(const float *x, const float *y, float *dst, const int n_elements, const int ne10, const int ne11, const int ne12, const int nb1, const int nb2, @@ -649,6 +680,33 @@ static void clamp_sycl(const T *x, T *dst, const float min, }); } +template +static void geglu_sycl(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o, queue_ptr main_stream) { + const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE); + main_stream->parallel_for( + sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), 
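// Note on the SYCL launchers above: reglu_sycl and swiglu_sycl compute
// num_blocks from SYCL_RELU/SYCL_SILU_BLOCK_SIZE but size their nd_range with
// SYCL_GELU_BLOCK_SIZE; the grid-stride loop inside the kernels keeps the
// result correct even if those constants ever diverge. GEGLU here still uses
// the GELU_QUICK sigmoid approximation (coefficient -1.702); the final patch
// in the series switches SYCL GEGLU to the tanh form for parity with the
// other backends.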
sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) { + gated_op_fused_geglu(x, g, dst, k, n, o, item_ct1); + }); +} + +template +static void reglu_sycl(const T * x, const T* g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o, queue_ptr main_stream) { + const uint32_t num_blocks = ceil_div(k, SYCL_RELU_BLOCK_SIZE); + main_stream->parallel_for( + sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) { + gated_op_fused_reglu(x, g, dst, k, n, o, item_ct1); + }); +} + +template +static void swiglu_sycl(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o, queue_ptr main_stream) { + const uint32_t num_blocks = ceil_div(k, SYCL_SILU_BLOCK_SIZE); + main_stream->parallel_for( + sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) { + gated_op_fused_swiglu(x, g, dst, k, n, o, item_ct1); + }); +} + inline void ggml_sycl_op_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { #if defined (GGML_SYCL_F16) GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); @@ -1444,6 +1502,152 @@ inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, ggml_tensor *dst) acc_f32_sycl(src0_dd, src1_dd, dst_dd, ggml_nelements(dst), dst->src[1]->ne[0], dst->src[1]->ne[1], dst->src[1]->ne[2], nb1, nb2, offset, main_stream); } +inline void ggml_sycl_op_geglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { +#if defined (GGML_SYCL_F16) + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); + +#else + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); +#endif + GGML_ASSERT(dst->src[0]->type == dst->type); + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + const int64_t nc = dst->src[0]->ne[0] / 2; + GGML_ASSERT(dst->ne[0] == nc); + GGML_ASSERT(ggml_is_contiguous_1(dst->src[0])); + GGML_ASSERT(ggml_is_contiguous(dst)); + const int32_t swapped = ((const int32_t *) dst->op_params)[1]; + const void * src0_d = dst->src[0]->data; + void * dst_d = dst->data; + switch (dst->type) { +#if defined (GGML_SYCL_F16) + case GGML_TYPE_F16: + { + geglu_sycl((const sycl::half *) src0_d + (swapped ? nc : 0), + (const sycl::half *)src0_d + (swapped ? 0 : nc), + (sycl::half *) dst_d, + ggml_nelements(dst), + nc, + dst->src[0]->nb[1] / sizeof(sycl::half), + main_stream); + break; + } +#endif + case GGML_TYPE_F32: + { + geglu_sycl((const float *) src0_d + (swapped ? nc : 0), + (const float *)src0_d + (swapped ? 
0 : nc), + (float *) dst_d, + ggml_nelements(dst), + nc, + dst->src[0]->nb[1] / sizeof(float), + main_stream); + break; + } + default: + GGML_ABORT("GGML tensor type not supported!\n"); + } +} + +inline void ggml_sycl_op_reglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { +#if defined (GGML_SYCL_F16) + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); + +#else + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); +#endif + GGML_ASSERT(dst->src[0]->type == dst->type); + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + const int64_t nc = dst->src[0]->ne[0] / 2; + GGML_ASSERT(dst->ne[0] == nc); + GGML_ASSERT(ggml_is_contiguous_1(dst->src[0])); + GGML_ASSERT(ggml_is_contiguous(dst)); + const int32_t swapped = ((const int32_t *) dst->op_params)[1]; + const void * src0_d = dst->src[0]->data; + void * dst_d = dst->data; + switch (dst->type) { +#if defined (GGML_SYCL_F16) + case GGML_TYPE_F16: + { + reglu_sycl((const sycl::half *) src0_d + (swapped ? nc : 0), + (const sycl::half *)src0_d + (swapped ? 0 : nc), + (sycl::half *) dst_d, + ggml_nelements(dst), + nc, + dst->src[0]->nb[1] / sizeof(sycl::half), + main_stream); + break; + } +#endif + case GGML_TYPE_F32: + { + reglu_sycl((const float *) src0_d + (swapped ? nc : 0), + (const float *)src0_d + (swapped ? 0 : nc), + (float *) dst_d, + ggml_nelements(dst), + nc, + dst->src[0]->nb[1] / sizeof(float), + main_stream); + break; + } + default: + GGML_ABORT("GGML tensor type not supported!\n"); + } +} + +inline void ggml_sycl_op_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { +#if defined (GGML_SYCL_F16) + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); + +#else + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); +#endif + GGML_ASSERT(dst->src[0]->type == dst->type); + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + const int64_t nc = dst->src[0]->ne[0] / 2; + GGML_ASSERT(dst->ne[0] == nc); + GGML_ASSERT(ggml_is_contiguous_1(dst->src[0])); + GGML_ASSERT(ggml_is_contiguous(dst)); + const int32_t swapped = ((const int32_t *) dst->op_params)[1]; + const void * src0_d = dst->src[0]->data; + void * dst_d = dst->data; + switch (dst->type) { +#if defined (GGML_SYCL_F16) + case GGML_TYPE_F16: + { + swiglu_sycl((const sycl::half *) src0_d + (swapped ? nc : 0), + (const sycl::half *)src0_d + (swapped ? 0 : nc), + (sycl::half *) dst_d, + ggml_nelements(dst), + nc, + dst->src[0]->nb[1] / sizeof(sycl::half), + main_stream); + break; + } +#endif + case GGML_TYPE_F32: + { + swiglu_sycl((const float *) src0_d + (swapped ? nc : 0), + (const float *)src0_d + (swapped ? 
0 : nc), + (float *) dst_d, + ggml_nelements(dst), + nc, + dst->src[0]->nb[1] / sizeof(float), + main_stream); + break; + } + default: + GGML_ABORT("GGML tensor type not supported!\n"); + } +} void ggml_sycl_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); @@ -1569,3 +1773,20 @@ void ggml_sycl_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); ggml_sycl_op_elu(ctx, dst); } + +void ggml_sycl_geglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); + ggml_sycl_op_geglu(ctx, dst); +} + +void ggml_sycl_reglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); + ggml_sycl_op_reglu(ctx, dst); +} + +void ggml_sycl_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); + ggml_sycl_op_swiglu(ctx, dst); +} + + diff --git a/ggml/src/ggml-sycl/element_wise.hpp b/ggml/src/ggml-sycl/element_wise.hpp index bd40113f09705..f530c9c1e1bdd 100644 --- a/ggml/src/ggml-sycl/element_wise.hpp +++ b/ggml/src/ggml-sycl/element_wise.hpp @@ -24,6 +24,9 @@ typed_data cast_data(ggml_tensor * dst) { }; } +const float GELU_QUICK_COEF = -1.702f; + + void ggml_sycl_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst); void ggml_sycl_sin(ggml_backend_sycl_context & ctx, ggml_tensor * dst); @@ -73,5 +76,10 @@ void ggml_sycl_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst); void ggml_sycl_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst); void ggml_sycl_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +void ggml_sycl_geglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst); +void ggml_sycl_reglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst); +void ggml_sycl_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + #endif // GGML_SYCL_ELEMENTWISE_HPP diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index feb30304fc092..8f215f03b0f8a 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -3685,6 +3685,21 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg return false; } break; + case GGML_OP_GLU: + switch (ggml_get_glu_op(dst)) { + case GGML_GLU_OP_REGLU: + ggml_sycl_reglu(ctx, dst); + break; + case GGML_GLU_OP_GEGLU: + ggml_sycl_geglu(ctx, dst); + break; + case GGML_GLU_OP_SWIGLU: + ggml_sycl_swiglu(ctx, dst); + break; + default: + return false; + } + break; case GGML_OP_NORM: ggml_sycl_norm(ctx, dst); break; @@ -4221,6 +4236,16 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g default: return false; } + case GGML_OP_GLU: + switch (ggml_get_glu_op(op)) { + case GGML_GLU_OP_REGLU: + case GGML_GLU_OP_GEGLU: + case GGML_GLU_OP_SWIGLU: + return ggml_is_contiguous_1(op->src[0]); + default: + return false; + } + break; case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: { From c2af58b06fa446c11ac3e1809482766c0d541ee0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?= Date: Wed, 18 Jun 2025 16:11:07 +0200 Subject: [PATCH 17/17] ggml : implement GLU for split up/gate (#14181) * implement GLU for split up/gate * add tests for ggml_glu_split * Vulkan: Implement glu_split logic and shader support * add split to logging [no ci] * SYCL: refactor element_size ops and add split up and gate 
support to gated kernels * SYCL: switch GEGLU to use tanh approximation --------- Co-authored-by: 0cc4m Co-authored-by: Akarshan --- ggml/include/ggml.h | 23 + ggml/src/ggml-cpu/ops.cpp | 150 +- ggml/src/ggml-cuda/unary.cu | 63 +- ggml/src/ggml-sycl/element_wise.cpp | 1739 +++++------------ ggml/src/ggml-sycl/element_wise.hpp | 17 +- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 38 +- .../src/ggml-vulkan/vulkan-shaders/geglu.comp | 46 +- .../ggml-vulkan/vulkan-shaders/glu_head.comp | 15 + .../ggml-vulkan/vulkan-shaders/glu_main.comp | 31 + .../src/ggml-vulkan/vulkan-shaders/reglu.comp | 37 +- .../ggml-vulkan/vulkan-shaders/swiglu.comp | 39 +- ggml/src/ggml.c | 61 +- src/llama-graph.cpp | 33 +- tests/test-backend-ops.cpp | 57 + 14 files changed, 912 insertions(+), 1437 deletions(-) create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 40ff1c187a831..3991d974f4fab 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -1131,6 +1131,29 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a); + // A: n columns, r rows, + // B: n columns, r rows, + GGML_API struct ggml_tensor * ggml_glu_split( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + enum ggml_glu_op op); + + GGML_API struct ggml_tensor * ggml_reglu_split( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + GGML_API struct ggml_tensor * ggml_geglu_split( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + GGML_API struct ggml_tensor * ggml_swiglu_split( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + // normalize along rows GGML_API struct ggml_tensor * ggml_norm( struct ggml_context * ctx, diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 8c88bf2e7b880..5543addcbdc00 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -3201,14 +3201,24 @@ static void ggml_compute_forward_reglu_f32( ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + char * src0_d = (char *) src0->data; + char * src1_d = (char *) (src1 ? src1->data : src0->data); + const size_t src0_o = src0->nb[1]; + const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1]; GGML_ASSERT(ggml_is_contiguous_1(src0)); GGML_ASSERT(ggml_is_contiguous_1(dst)); + if (src1) { + GGML_ASSERT(ggml_is_contiguous_1(src1)); + GGML_ASSERT(src0->type == src1->type); + } + const int ith = params->ith; const int nth = params->nth; - const int nc = src0->ne[0] / 2; + const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2; const int nr = ggml_nrows(src0); GGML_ASSERT(dst->ne[0] == nc); @@ -3224,10 +3234,15 @@ static void ggml_compute_forward_reglu_f32( const int ir1 = MIN(ir0 + dr, nr); for (int i1 = ir0; i1 < ir1; i1++) { - ggml_vec_reglu_f32(nc, - (float *) ((char *) dst->data + i1*( dst->nb[1])), - (float *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? nc : 0), - (float *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? 0 : nc)); + float * src0_p = (float *) (src0_d + i1*src0_o); + float * src1_p = (float *) (src1_d + i1*src1_o); + + if (!src1) { + src0_p += swapped ? nc : 0; + src1_p += swapped ? 
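// Split-mode rule encoded by the src1 handling above: when dst->src[1] is
// set, src0 is the full up tensor and src1 the full gate tensor, so
// nc = src0->ne[0] and each keeps its own row stride; when src1 is absent,
// both pointers alias src0 and the swapped offset picks which half of the
// 2*nc row is the base and which is the gate.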
0 : nc; + } + + ggml_vec_reglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p); #ifndef NDEBUG for (int k = 0; k < nc; k++) { @@ -3245,14 +3260,24 @@ static void ggml_compute_forward_reglu_f16( ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + char * src0_d = (char *) src0->data; + char * src1_d = (char *) (src1 ? src1->data : src0->data); + const size_t src0_o = src0->nb[1]; + const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1]; GGML_ASSERT(ggml_is_contiguous_1(src0)); GGML_ASSERT(ggml_is_contiguous_1(dst)); + if (src1) { + GGML_ASSERT(ggml_is_contiguous_1(src1)); + GGML_ASSERT(src0->type == src1->type); + } + const int ith = params->ith; const int nth = params->nth; - const int nc = src0->ne[0] / 2; + const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2; const int nr = ggml_nrows(src0); GGML_ASSERT(dst->ne[0] == nc); @@ -3268,10 +3293,15 @@ static void ggml_compute_forward_reglu_f16( const int ir1 = MIN(ir0 + dr, nr); for (int i1 = ir0; i1 < ir1; i1++) { - ggml_vec_reglu_f16(nc, - (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])), - (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? nc : 0), - (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? 0 : nc)); + ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o); + ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o); + + if (!src1) { + src0_p += swapped ? nc : 0; + src1_p += swapped ? 0 : nc; + } + + ggml_vec_reglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p); #ifndef NDEBUG for (int k = 0; k < nc; k++) { @@ -3314,14 +3344,24 @@ static void ggml_compute_forward_geglu_f32( ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + char * src0_d = (char *) src0->data; + char * src1_d = (char *) (src1 ? src1->data : src0->data); + const size_t src0_o = src0->nb[1]; + const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1]; GGML_ASSERT(ggml_is_contiguous_1(src0)); GGML_ASSERT(ggml_is_contiguous_1(dst)); + if (src1) { + GGML_ASSERT(ggml_is_contiguous_1(src1)); + GGML_ASSERT(src0->type == src1->type); + } + const int ith = params->ith; const int nth = params->nth; - const int nc = src0->ne[0] / 2; + const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2; const int nr = ggml_nrows(src0); GGML_ASSERT(dst->ne[0] == nc); @@ -3337,10 +3377,15 @@ static void ggml_compute_forward_geglu_f32( const int ir1 = MIN(ir0 + dr, nr); for (int i1 = ir0; i1 < ir1; i1++) { - ggml_vec_geglu_f32(nc, - (float *) ((char *) dst->data + i1*( dst->nb[1])), - (float *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? nc : 0), - (float *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? 0 : nc)); + float * src0_p = (float *) (src0_d + i1*src0_o); + float * src1_p = (float *) (src1_d + i1*src1_o); + + if (!src1) { + src0_p += swapped ? nc : 0; + src1_p += swapped ? 0 : nc; + } + + ggml_vec_geglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p); #ifndef NDEBUG for (int k = 0; k < nc; k++) { @@ -3358,14 +3403,24 @@ static void ggml_compute_forward_geglu_f16( ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + char * src0_d = (char *) src0->data; + char * src1_d = (char *) (src1 ? src1->data : src0->data); + const size_t src0_o = src0->nb[1]; + const size_t src1_o = src1 ? 
src1->nb[1] : src0->nb[1]; GGML_ASSERT(ggml_is_contiguous_1(src0)); GGML_ASSERT(ggml_is_contiguous_1(dst)); + if (src1) { + GGML_ASSERT(ggml_is_contiguous_1(src1)); + GGML_ASSERT(src0->type == src1->type); + } + const int ith = params->ith; const int nth = params->nth; - const int nc = src0->ne[0] / 2; + const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2; const int nr = ggml_nrows(src0); GGML_ASSERT(dst->ne[0] == nc); @@ -3381,10 +3436,15 @@ static void ggml_compute_forward_geglu_f16( const int ir1 = MIN(ir0 + dr, nr); for (int i1 = ir0; i1 < ir1; i1++) { - ggml_vec_geglu_f16(nc, - (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])), - (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? nc : 0), - (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? 0 : nc)); + ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o); + ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o); + + if (!src1) { + src0_p += swapped ? nc : 0; + src1_p += swapped ? 0 : nc; + } + + ggml_vec_geglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p); #ifndef NDEBUG for (int k = 0; k < nc; k++) { @@ -3427,14 +3487,24 @@ static void ggml_compute_forward_swiglu_f32( ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + char * src0_d = (char *) src0->data; + char * src1_d = (char *) (src1 ? src1->data : src0->data); + const size_t src0_o = src0->nb[1]; + const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1]; GGML_ASSERT(ggml_is_contiguous_1(src0)); GGML_ASSERT(ggml_is_contiguous_1(dst)); + if (src1) { + GGML_ASSERT(ggml_is_contiguous_1(src1)); + GGML_ASSERT(src0->type == src1->type); + } + const int ith = params->ith; const int nth = params->nth; - const int nc = src0->ne[0] / 2; + const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2; const int nr = ggml_nrows(src0); GGML_ASSERT(dst->ne[0] == nc); @@ -3450,10 +3520,15 @@ static void ggml_compute_forward_swiglu_f32( const int ir1 = MIN(ir0 + dr, nr); for (int i1 = ir0; i1 < ir1; i1++) { - ggml_vec_swiglu_f32(nc, - (float *) ((char *) dst->data + i1*( dst->nb[1])), - (float *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? nc : 0), - (float *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? 0 : nc)); + float * src0_p = (float *) (src0_d + i1*src0_o); + float * src1_p = (float *) (src1_d + i1*src1_o); + + if (!src1) { + src0_p += swapped ? nc : 0; + src1_p += swapped ? 0 : nc; + } + + ggml_vec_swiglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p); #ifndef NDEBUG for (int k = 0; k < nc; k++) { @@ -3471,14 +3546,24 @@ static void ggml_compute_forward_swiglu_f16( ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + char * src0_d = (char *) src0->data; + char * src1_d = (char *) (src1 ? src1->data : src0->data); + const size_t src0_o = src0->nb[1]; + const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1]; GGML_ASSERT(ggml_is_contiguous_1(src0)); GGML_ASSERT(ggml_is_contiguous_1(dst)); + if (src1) { + GGML_ASSERT(ggml_is_contiguous_1(src1)); + GGML_ASSERT(src0->type == src1->type); + } + const int ith = params->ith; const int nth = params->nth; - const int nc = src0->ne[0] / 2; + const int nc = src1 ? 
src0->ne[0] : src0->ne[0] / 2; const int nr = ggml_nrows(src0); GGML_ASSERT(dst->ne[0] == nc); @@ -3494,10 +3579,15 @@ static void ggml_compute_forward_swiglu_f16( const int ir1 = MIN(ir0 + dr, nr); for (int i1 = ir0; i1 < ir1; i1++) { - ggml_vec_swiglu_f16(nc, - (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])), - (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? nc : 0), - (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? 0 : nc)); + ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o); + ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o); + + if (!src1) { + src0_p += swapped ? nc : 0; + src1_p += swapped ? 0 : nc; + } + + ggml_vec_swiglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p); #ifndef NDEBUG for (int k = 0; k < nc; k++) { diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu index c991c1d700174..ba3c0f13762b0 100644 --- a/ggml/src/ggml-cuda/unary.cu +++ b/ggml/src/ggml-cuda/unary.cu @@ -199,30 +199,36 @@ void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { /* gated ops */ template -static __global__ void unary_gated_op_kernel(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o) { +static __global__ void unary_gated_op_kernel(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1) { const int64_t i = int64_t(blockDim.x)*blockIdx.x + threadIdx.x; if (i >= k) { return; } - // perform base op on half of the row and multiply with gate in other half - const int64_t j = (i / n) * o + (i % n); - dst[i] = (T)(op((float)x[j]) * (float)g[j]); + // perform base op and multiply with gate (either offset in same tensor or a separate one) + const int64_t j0 = (i / n) * o0 + (i % n); + const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n); + + dst[i] = (T)(op((float)x[j0]) * (float)g[j1]); } template -static void unary_gated_cuda(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o, cudaStream_t stream) { +static void unary_gated_cuda(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1, cudaStream_t stream) { const int64_t num_blocks = (k + CUDA_GLU_BLOCK_SIZE - 1) / CUDA_GLU_BLOCK_SIZE; - unary_gated_op_kernel<<>>(x, g, dst, k, n, o); + unary_gated_op_kernel<<>>(x, g, dst, k, n, o0, o1); } template void ggml_cuda_op_unary_gated(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; - const void * src0_d = src0->data; + const ggml_tensor * src1 = dst->src[1]; + void * src0_d = src0->data; + void * src1_d = src1 ? src1->data : src0->data; + const int64_t src0_o = src0->nb[1]; + const int64_t src1_o = src1 ? src1->nb[1] : src0->nb[1]; void * dst_d = dst->data; - const int64_t nc = src0->ne[0] / 2; + const int64_t nc = src1 ? src0->ne[0] : src0->ne[0] / 2; cudaStream_t stream = ctx.stream(); GGML_ASSERT(ggml_is_contiguous_1(src0)); @@ -235,26 +241,35 @@ void ggml_cuda_op_unary_gated(ggml_backend_cuda_context & ctx, ggml_tensor * dst GGML_ASSERT(dst->ne[0] == nc); GGML_ASSERT(ggml_nrows(dst) == ggml_nrows(src0)); + if (src1) { + GGML_ASSERT(ggml_is_contiguous_1(src1)); + GGML_ASSERT(src1->nb[0] == ggml_element_size(src1)); + GGML_ASSERT(src1->ne[0] == nc); + GGML_ASSERT(src0->type == src1->type); + } + const int32_t swapped = ((const int32_t *) dst->op_params)[1]; if (src0->type == GGML_TYPE_F16) { - unary_gated_cuda( - (const half *)src0_d + (swapped ? 
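// With split tensors the two CUDA offsets above can differ: e.g. an up row
// stride of o0 = 4096 elements and a gate stride of o1 = 5120 (hypothetical
// padded view) give j0 = (i / n) * 4096 + (i % n) and
// j1 = (i / n) * 5120 + (i % n); the o0 == o1 fast path collapses j1 to j0
// for the packed single-tensor case.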
nc : 0), - (const half *)src0_d + (swapped ? 0 : nc), - (half *)dst_d, - ggml_nelements(dst), - nc, - src0->nb[1] / sizeof(half), - stream); + half * src0_p = (half *) src0_d; + half * src1_p = (half *) src1_d; + + if (!src1) { + src0_p += swapped ? nc : 0; + src1_p += swapped ? 0 : nc; + } + + unary_gated_cuda(src0_p, src1_p, (half *)dst_d, ggml_nelements(dst), nc, src0_o / sizeof(half), src1_o / sizeof(half), stream); } else { - unary_gated_cuda( - (const float *)src0_d + (swapped ? nc : 0), - (const float *)src0_d + (swapped ? 0 : nc), - (float *)dst_d, - ggml_nelements(dst), - nc, - src0->nb[1] / sizeof(float), - stream); + float * src0_p = (float *) src0_d; + float * src1_p = (float *) src1_d; + + if (!src1) { + src0_p += swapped ? nc : 0; + src1_p += swapped ? 0 : nc; + } + + unary_gated_cuda(src0_p, src1_p, (float *)dst_d, ggml_nelements(dst), nc, src0_o / sizeof(float), src1_o / sizeof(float), stream); } } diff --git a/ggml/src/ggml-sycl/element_wise.cpp b/ggml/src/ggml-sycl/element_wise.cpp index 7e6b48db7002b..828cea1aa0086 100644 --- a/ggml/src/ggml-sycl/element_wise.cpp +++ b/ggml/src/ggml-sycl/element_wise.cpp @@ -2,14 +2,20 @@ #include "ggml-sycl/presets.hpp" #include "ggml.h" #include "element_wise.hpp" -#include -#include + +// --- Helper Macros for Kernel Indexing --- +#define SYCL_GLOBAL_ID_LOOP(K, ITEM) \ + for (auto i = ITEM.get_global_id(0); i < (size_t)K; i += ITEM.get_global_range(0)) + +#define SYCL_LOCAL_ID_CALC(ITEM, IDX) \ + (ITEM.get_local_range(IDX) * ITEM.get_group(IDX) + ITEM.get_local_id(IDX)) + +// --- Original Kernels (non-_sycl) - Modified to use indexing macros and cast literals --- static void acc_f32(const float * x, const float * y, float * dst, const int ne, const int ne10, const int ne11, const int ne12, - const int nb1, const int nb2, int offset, const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); + const int nb1, const int nb2, int offset, const sycl::nd_item<1> &item_ct1) { + const int i = SYCL_LOCAL_ID_CALC(item_ct1, 0); if (i >= ne) { return; } @@ -25,72 +31,59 @@ static void acc_f32(const float * x, const float * y, float * dst, const int ne, } template -static void sgn(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) { - for(auto i = item_ct1.get_global_id(2); i < (const size_t)k; i += item_ct1.get_global_range(2)) { +static void sgn(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) { + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { dst[i] = x[i] > static_cast(0.f) ? static_cast(1.f) : ((x[i] < static_cast(0.f) ? static_cast(-1.f) : static_cast(0.f))); } } template -static void abs_op(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) { - for(auto i = item_ct1.get_global_id(2); i < (const size_t)k; i += item_ct1.get_global_range(2)) { +static void abs_op(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) { + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { dst[i] = sycl::fabs(x[i]); } } template -static void elu_op(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) { - for(auto i = item_ct1.get_global_id(2); i < (const size_t)k; i += item_ct1.get_global_range(2)) { +static void elu_op(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) { + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { dst[i] = (x[i] > static_cast(0.f)) ? 
x[i] : sycl::expm1(x[i]); } } template static void gelu(const T * x, T * dst, const int k, - const sycl::nd_item<3> &item_ct1) { + const sycl::nd_item<1> &item_ct1) { const T GELU_COEF_A = static_cast(0.044715f); const T SQRT_2_OVER_PI = static_cast(0.79788456080286535587989211986876f); - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - - if (i >= k) { - return; + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { + dst[i] = static_cast(0.5f) * x[i] * + (static_cast(1.0f) + + sycl::tanh(SQRT_2_OVER_PI * x[i] * (static_cast(1.0f) + GELU_COEF_A * x[i] * x[i]))); } - - float xi = x[i]; - dst[i] = static_cast(0.5f) * xi * - (static_cast(1.0f) + - sycl::tanh(SQRT_2_OVER_PI * xi * (static_cast(1.0f) + GELU_COEF_A * xi * xi))); } template static void silu(const T * x, T * dst, const int k, - const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - - if (i >= k) { - return; + const sycl::nd_item<1> &item_ct1) { + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { + dst[i] = x[i] / (static_cast(1.0f) + sycl::native::exp(-x[i])); } - dst[i] = x[i] / (static_cast(1.0f) + sycl::native::exp(-x[i])); } template static void gelu_quick(const T *x, T *dst, int k, - const sycl::nd_item<3> &item_ct1) { - const float GELU_QUICK_COEF = -1.702f; - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - if (i >= k) { - return; + const sycl::nd_item<1> &item_ct1) { + const T GELU_QUICK_COEF_LOCAL = static_cast(-1.702f); + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { + dst[i] = x[i] * (static_cast(1.0f) / (static_cast(1.0f) + sycl::native::exp(GELU_QUICK_COEF_LOCAL * x[i]))); } - dst[i] = x[i] * (static_cast(1.0f) / (static_cast(1.0f) + sycl::native::exp(GELU_QUICK_COEF * x[i]))); } template -static void gelu_erf(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) { +static void gelu_erf(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) { const T SQRT_2_INV = static_cast(0.70710678118654752440084436210484f); - for(auto i = item_ct1.get_global_id(2); i < (const size_t)k; i += item_ct1.get_global_range(2)) { + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { auto x_i = x[i]; dst[i] = static_cast(0.5f) * x_i * (static_cast(1.0f) + sycl::erf(x_i * SQRT_2_INV)); } @@ -98,174 +91,121 @@ static void gelu_erf(const T * x, T * dst, const int k, const sycl::nd_item<3> & template static void tanh(const T *x, T *dst, int k, - const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - if (i >= k) { - return; + const sycl::nd_item<1> &item_ct1) { + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { + dst[i] = sycl::tanh((x[i])); } - dst[i] = sycl::tanh((x[i])); } template static void relu(const T * x, T * dst, const int k, - const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - - if (i >= k) { - return; + const sycl::nd_item<1> &item_ct1) { + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { + dst[i] = sycl::fmax((x[i]), static_cast(0)); } - dst[i] = sycl::fmax((x[i]), static_cast(0)); } template static void sigmoid(const T * x, T * dst, const int k, - const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - - if (i >= k) { - return; + const sycl::nd_item<1> &item_ct1) { + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { + dst[i] = static_cast(1.0f) / (static_cast(1.0f) + 
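// The refactor above replaces the per-kernel "compute one index, bounds-check,
// return" pattern with SYCL_GLOBAL_ID_LOOP's grid-stride loop and flattens the
// nd_item<3> launches to nd_item<1>; behaviour is unchanged when enough
// work-items are launched, and the loop form also stays correct when a kernel
// is launched with fewer work-items than elements.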
sycl::native::exp(-x[i])); } - dst[i] = 1.0f / (static_cast(1.0f) + sycl::native::exp(-x[i])); } template static void sqrt(const T * x, T * dst, const int k, - const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - - if (i >= k) { - return; + const sycl::nd_item<1> &item_ct1) { + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { + dst[i] = sycl::sqrt(x[i]); } - dst[i] = sycl::sqrt(x[i]); } template static void sin(const T * x, T * dst, const int k, - const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - - if (i >= k) { - return; + const sycl::nd_item<1> &item_ct1) { + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { + dst[i] = sycl::sin(x[i]); } - dst[i] = sycl::sin(x[i]); } template static void cos(const T * x, T * dst, const int k, - const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - - if (i >= k) { - return; + const sycl::nd_item<1> &item_ct1) { + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { + dst[i] = sycl::cos(x[i]); } - dst[i] = sycl::cos(x[i]); } template static void hardsigmoid(const T * x, T * dst, const int k, - const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - - if (i >= k) { - return; + const sycl::nd_item<1> &item_ct1) { + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { + dst[i] = sycl::fmin(static_cast(1.0f), sycl::fmax(static_cast(0.0f), (x[i] + static_cast(3.0f)) / static_cast(6.0f))); } - dst[i] = sycl::fmin(static_cast(1.0f), sycl::fmax(static_cast(0.0f), (x[i] + static_cast(3.0f)) / static_cast(6.0f))); } template static void hardswish(const T * x, T * dst, const int k, - const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - - if (i >= k) { - return; + const sycl::nd_item<1> &item_ct1) { + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { + dst[i] = x[i] * sycl::fmin(static_cast(1.0f), sycl::fmax(static_cast(0.0f), (x[i] + static_cast(3.0f)) / static_cast(6.0f))); } - dst[i] = x[i] * sycl::fmin(static_cast(1.0f), sycl::fmax(static_cast(0.0f), (x[i] + static_cast(3.0f)) / static_cast(6.0f))); } template static void exp(const T * x, T * dst, const int k, - const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - - if (i >= k) { - return; + const sycl::nd_item<1> &item_ct1) { + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { + dst[i] = sycl::exp(x[i]); } - dst[i] = sycl::exp(x[i]); } template static void log(const T * x, T * dst, const int k, - const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - - if (i >= k) { - return; - } - T xi = x[i]; - if (xi <= 0) { - dst[i] = neg_infinity(); - } else { - dst[i] = sycl::log(xi); + const sycl::nd_item<1> &item_ct1) { + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { + T xi = x[i]; + if (xi <= static_cast(0)) { + dst[i] = neg_infinity(); + } else { + dst[i] = sycl::log(xi); + } } } template static void neg(const T * x, T * dst, const int k, - const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - - if (i >= k) { - return; + const sycl::nd_item<1> &item_ct1) { + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { + dst[i] = -x[i]; } - dst[i] = -x[i]; } template static void step(const T * x, T * 
dst, const int k, - const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - - if (i >= k) { - return; + const sycl::nd_item<1> &item_ct1) { + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { + dst[i] = (x[i] > static_cast(0.0f)) ? static_cast(1.0f) : static_cast(0.0f); } - dst[i] = x[i] > static_cast(0.0f); } template static void leaky_relu(const T *x, T *dst, const int k, const float negative_slope, - const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - if (i >= k) { - return; + const sycl::nd_item<1> &item_ct1) { + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { + T neg_slope_T = static_cast(negative_slope); + dst[i] = sycl::fmax((x[i]), static_cast(0)) + + sycl::fmin((x[i]), static_cast(0.0f)) * neg_slope_T; } - dst[i] = sycl::fmax((x[i]), static_cast(0)) + - sycl::fmin((x[i]), static_cast(0.0f)) * negative_slope; } template static void sqr(const T * x, T * dst, const int k, - const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - - if (i >= k) { - return; + const sycl::nd_item<1> &item_ct1) { + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { + dst[i] = x[i] * x[i]; } - dst[i] = x[i] * x[i]; } template @@ -284,10 +224,10 @@ static void upscale(const T *x, T *dst, const int nb00, const int nb01, int i12 = (index / (ne10 * ne11)) % ne12; int i13 = (index / (ne10 * ne11 * ne12)) % ne13; - int i00 = i10 / sf0; - int i01 = i11 / sf1; - int i02 = i12 / sf2; - int i03 = i13 / sf3; + int i00 = static_cast(i10 / sf0); + int i01 = static_cast(i11 / sf1); + int i02 = static_cast(i12 / sf2); + int i03 = static_cast(i13 / sf3); dst[index] = *(const T *)((const char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00); } @@ -295,8 +235,7 @@ static void upscale(const T *x, T *dst, const int nb00, const int nb01, template static void pad(const T *x, T *dst, const int ne0, const int ne00, const int ne01, const int ne02, const sycl::nd_item<3> &item_ct1) { - int nidx = item_ct1.get_local_id(2) + - item_ct1.get_group(2) * item_ct1.get_local_range(2); + int nidx = SYCL_LOCAL_ID_CALC(item_ct1, 2); if (nidx >= ne0) { return; } @@ -313,337 +252,75 @@ static void pad(const T *x, T *dst, const int ne0, const int ne00, const int ne } } - template static void clamp(const T * x, T * dst, const float min, const float max, const int k, - const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - - if (i >= k) { - return; + const sycl::nd_item<1> &item_ct1) { + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { + dst[i] = x[i] < static_cast(min) ? static_cast(min) : (x[i] > static_cast(max) ? static_cast(max) : x[i]); } - - dst[i] = x[i] < static_cast(min) ? static_cast(min) : (x[i] > static_cast(max) ? 
static_cast(max) : x[i]); } -// Fused GLU kernels template -static void gated_op_fused_geglu(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o, const sycl::nd_item<1> &item_ct1) { - for (auto i = item_ct1.get_global_id(0); i < k; i += item_ct1.get_global_range(0)) { - const int64_t j = ((i / n) * o) + (i % n); - const T x_val = x[j]; - const T gelu_val = x_val * (static_cast(1.0f) / (static_cast(1.0f) + sycl::native::exp(GELU_QUICK_COEF * x_val))); +static void gated_op_fused_geglu(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o0, const uint64_t o1, const sycl::nd_item<1> &item_ct1) { + const T GELU_COEF_A = static_cast(0.044715f); + const T SQRT_2_OVER_PI = static_cast(0.79788456080286535587989211986876f); + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { + const int64_t j0 = (i / n) * o0 + (i % n); + const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n); + const T x_val = x[j0]; + + const T x_cubed_term = static_cast(1.0f) + GELU_COEF_A * x_val * x_val; + const T tanh_input = SQRT_2_OVER_PI * x_val * x_cubed_term; + const T gelu_val = static_cast(0.5f) * x_val * (static_cast(1.0f) + sycl::tanh(tanh_input)); - dst[i] = gelu_val * g[j]; + dst[i] = gelu_val * g[j1]; } } template -static void gated_op_fused_reglu(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o, const sycl::nd_item<1> &item_ct1) { - for (auto i = item_ct1.get_global_id(0); i < k; i += item_ct1.get_global_range(0)) { - const int64_t j = ((i / n) * o) + (i % n); - dst[i] = sycl::max((x[j]), static_cast(0)) * g[j]; +static void gated_op_fused_reglu(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o0, const uint64_t o1, const sycl::nd_item<1> &item_ct1) { + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { + const int64_t j0 = (i / n) * o0 + (i % n); + const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n); + dst[i] = sycl::max((x[j0]), static_cast(0)) * g[j1]; } } template -static void gated_op_fused_swiglu(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o, const sycl::nd_item<1> &item_ct1) { - for (auto i = item_ct1.get_global_id(0); i < k; i += item_ct1.get_global_range(0)) { - const int64_t j = ((i / n) * o) + (i % n); - dst[i] = (x[j] / (static_cast(1) + sycl::native::exp(-x[j]))) * g[j]; +static void gated_op_fused_swiglu(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o0, const uint64_t o1, const sycl::nd_item<1> &item_ct1) { + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { + const int64_t j0 = (i / n) * o0 + (i % n); + const int64_t j1 = o0 == o1 ? 
j0 : (i / n) * o1 + (i % n); + dst[i] = (x[j0] / (static_cast(1) + sycl::native::exp(-x[j0]))) * g[j1]; } } +// --- Generic SYCL Kernel Launchers --- +namespace ggml_sycl_detail { +// acc_f32_sycl remains specific static void acc_f32_sycl(const float *x, const float *y, float *dst, const int n_elements, const int ne10, const int ne11, const int ne12, const int nb1, const int nb2, const int offset, queue_ptr stream) { - int num_blocks = (n_elements + SYCL_ACC_BLOCK_SIZE - 1) / SYCL_ACC_BLOCK_SIZE; + int num_blocks = ceil_div(n_elements, SYCL_ACC_BLOCK_SIZE); stream->parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * - sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { + sycl::nd_range<1>(sycl::range<1>(num_blocks) * + sycl::range<1>(SYCL_ACC_BLOCK_SIZE), + sycl::range<1>(SYCL_ACC_BLOCK_SIZE)), + [=](sycl::nd_item<1> item_ct1) { acc_f32(x, y, dst, n_elements, ne10, ne11, ne12, nb1, nb2, offset, item_ct1); }); } -template -static void gelu_sycl(const T *x, T *dst, const int k, - queue_ptr stream) { - const int num_blocks = (k + SYCL_GELU_BLOCK_SIZE - 1) / SYCL_GELU_BLOCK_SIZE; - stream->parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * - sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { - gelu(x, dst, k, item_ct1); - }); -} - -template -static void silu_sycl(const T *x, T *dst, const int k, - queue_ptr stream) { - const int num_blocks = (k + SYCL_SILU_BLOCK_SIZE - 1) / SYCL_SILU_BLOCK_SIZE; - stream->parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * - sycl::range<3>(1, 1, SYCL_SILU_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_SILU_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { - silu(x, dst, k, item_ct1); - }); -} - -template -static void sgn_sycl(const T * x, T * dst, const int k, queue_ptr stream) { - // hard code for now - const int num_blocks = ceil_div(k, 256); - stream->parallel_for( - sycl::nd_range<3>((sycl::range<3>(1, 1, num_blocks) * sycl::range(1, 1, 256)), sycl::range(1, 1, 256)), [=](sycl::nd_item<3> item_ct1) { - sgn(x, dst, k, item_ct1); - }); -} - -template -static void abs_sycl(const T * x, T * dst, const int k, queue_ptr stream) { - // hard code for now - const int num_blocks = ceil_div(k, 256); - stream->parallel_for( - sycl::nd_range<3>((sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 256)), sycl::range<3>(1, 1, 256)), [=](sycl::nd_item<3> item_ct1) { - abs_op(x, dst, k, item_ct1); - }); -} - - -template -static void elu_sycl(const T * x, T * dst, const int k, queue_ptr stream) { - // hard code for now - const int num_blocks = ceil_div(k, 256); - stream->parallel_for( - sycl::nd_range<3>((sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 256)), sycl::range<3>(1, 1, 256)), [=](sycl::nd_item<3> item_ct1) { - elu_op(x, dst, k, item_ct1); - }); -} - -template -static void gelu_quick_sycl(const T *x, T *dst, const int k, - queue_ptr stream) { - const int num_blocks = (k + SYCL_GELU_BLOCK_SIZE - 1) / SYCL_GELU_BLOCK_SIZE; - stream->parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * - sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { - gelu_quick(x, dst, k, item_ct1); - }); -} - - -template -static void gelu_erf_sycl(const T *x, T *dst, const int k, - queue_ptr stream) { - const int num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE); - stream->parallel_for( - 
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * - sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { - gelu_erf(x, dst, k, item_ct1); - }); -} - -template -static void tanh_sycl(const T *x, T *dst, const int k, - queue_ptr stream) { - const int num_blocks = (k + SYCL_TANH_BLOCK_SIZE - 1) / SYCL_TANH_BLOCK_SIZE; - stream->parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * - sycl::range<3>(1, 1, SYCL_TANH_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_TANH_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { - tanh(x, dst, k, item_ct1); - }); -} - -template -static void relu_sycl(const T *x, T *dst, const int k, - queue_ptr stream) { - const int num_blocks = (k + SYCL_RELU_BLOCK_SIZE - 1) / SYCL_RELU_BLOCK_SIZE; - stream->parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * - sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { - relu(x, dst, k, item_ct1); - }); -} - -template -static void hardsigmoid_sycl(const T *x, T *dst, const int k, - queue_ptr stream) { - const int num_blocks = (k + SYCL_HARDSIGMOID_BLOCK_SIZE - 1) / SYCL_HARDSIGMOID_BLOCK_SIZE; - stream->parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * - sycl::range<3>(1, 1, SYCL_HARDSIGMOID_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_HARDSIGMOID_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { - hardsigmoid(x, dst, k, item_ct1); - }); -} - -template -static void hardswish_sycl(const T *x, T *dst, const int k, - queue_ptr stream) { - const int num_blocks = (k + SYCL_HARDSWISH_BLOCK_SIZE - 1) / SYCL_HARDSWISH_BLOCK_SIZE; - stream->parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * - sycl::range<3>(1, 1, SYCL_HARDSWISH_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_HARDSWISH_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { - hardswish(x, dst, k, item_ct1); - }); -} - -template -static void exp_sycl(const T *x, T *dst, const int k, - queue_ptr stream) { - const int num_blocks = (k + SYCL_EXP_BLOCK_SIZE - 1) / SYCL_EXP_BLOCK_SIZE; - stream->parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * - sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { - exp(x, dst, k, item_ct1); - }); -} - -template -static void log_sycl(const T *x, T *dst, const int k, - queue_ptr stream) { - const int num_blocks = (k + SYCL_EXP_BLOCK_SIZE - 1) / SYCL_EXP_BLOCK_SIZE; - stream->parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * - sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { - log(x, dst, k, item_ct1); - }); -} - -template -static void neg_sycl(const T *x, T *dst, const int k, - queue_ptr stream) { - const int num_blocks = (k + SYCL_NEG_BLOCK_SIZE - 1) / SYCL_NEG_BLOCK_SIZE; - stream->parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * - sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { - neg(x, dst, k, item_ct1); - }); -} - -template -static void step_sycl(const T *x, T *dst, const int k, - queue_ptr stream) { - const int num_blocks = (k + SYCL_NEG_BLOCK_SIZE - 1) / SYCL_NEG_BLOCK_SIZE; - stream->parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * - sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { 
- step(x, dst, k, item_ct1); - }); -} - -template -static void sigmoid_sycl(const T *x, T *dst, const int k, - queue_ptr stream) { - const int num_blocks = (k + SYCL_SIGMOID_BLOCK_SIZE - 1) / SYCL_SIGMOID_BLOCK_SIZE; - stream->parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * - sycl::range<3>(1, 1, SYCL_SIGMOID_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_SIGMOID_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { - sigmoid(x, dst, k, item_ct1); - }); -} - -template -static void sqrt_sycl(const T *x, T *dst, const int k, - queue_ptr stream) { - const int num_blocks = (k + SYCL_SQRT_BLOCK_SIZE - 1) / SYCL_SQRT_BLOCK_SIZE; - stream->parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * - sycl::range<3>(1, 1, SYCL_SQRT_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_SQRT_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { - sqrt(x, dst, k, item_ct1); - }); -} - -template -static void sin_sycl(const T *x, T *dst, const int k, - queue_ptr stream) { - const int num_blocks = (k + SYCL_SIN_BLOCK_SIZE - 1) / SYCL_SIN_BLOCK_SIZE; - stream->parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * - sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { - sin(x, dst, k, item_ct1); - }); -} - -template -static void cos_sycl(const T *x, T *dst, const int k, - queue_ptr stream) { - const int num_blocks = (k + SYCL_SIN_BLOCK_SIZE - 1) / SYCL_SIN_BLOCK_SIZE; - stream->parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * - sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { - cos(x, dst, k, item_ct1); - }); -} - -template -static void leaky_relu_sycl(const T *x, T *dst, const int k, - const float negative_slope, - queue_ptr stream) { - const int num_blocks = (k + SYCL_RELU_BLOCK_SIZE - 1) / SYCL_RELU_BLOCK_SIZE; - stream->parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * - sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { - leaky_relu(x, dst, k, negative_slope, item_ct1); - }); -} - -template -static void sqr_sycl(const T *x, T *dst, const int k, - queue_ptr stream) { - const int num_blocks = (k + SYCL_SQR_BLOCK_SIZE - 1) / SYCL_SQR_BLOCK_SIZE; - stream->parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * - sycl::range<3>(1, 1, SYCL_SQR_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_SQR_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { - sqr(x, dst, k, item_ct1); - }); -} - +// upscale_sycl remains specific template static void upscale_sycl(const T *x, T *dst, const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11, const int ne12, const int ne13, const float sf0, const float sf1, const float sf2, const float sf3, queue_ptr stream) { int dst_size = ne10 * ne11 * ne12 * ne13; - int num_blocks = (dst_size + SYCL_UPSCALE_BLOCK_SIZE - 1) / SYCL_UPSCALE_BLOCK_SIZE; + int num_blocks = ceil_div(dst_size, SYCL_UPSCALE_BLOCK_SIZE); sycl::range<1> gridDim(num_blocks * SYCL_UPSCALE_BLOCK_SIZE); stream->parallel_for( sycl::nd_range<1>(gridDim, sycl::range<1>(SYCL_UPSCALE_BLOCK_SIZE)), @@ -652,11 +329,12 @@ static void upscale_sycl(const T *x, T *dst, const int nb00, const int nb01, }); } +// pad_sycl remains specific template static void pad_sycl(const T *x, T *dst, const int ne00, const int ne01, const int ne02, const int ne0, const int ne1, const int ne2, queue_ptr stream) { - int num_blocks = 
(ne0 + SYCL_PAD_BLOCK_SIZE - 1) / SYCL_PAD_BLOCK_SIZE; + int num_blocks = ceil_div(ne0, SYCL_PAD_BLOCK_SIZE); sycl::range<3> gridDim(ne2, ne1, num_blocks); stream->parallel_for( sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE), @@ -666,52 +344,13 @@ static void pad_sycl(const T *x, T *dst, const int ne00, }); } -template -static void clamp_sycl(const T *x, T *dst, const float min, - const float max, const int k, - queue_ptr stream) { - const int num_blocks = (k + SYCL_CLAMP_BLOCK_SIZE - 1) / SYCL_CLAMP_BLOCK_SIZE; - stream->parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * - sycl::range<3>(1, 1, SYCL_CLAMP_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_CLAMP_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { - clamp(x, dst, min, max, k, item_ct1); - }); -} - -template -static void geglu_sycl(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o, queue_ptr main_stream) { - const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE); - main_stream->parallel_for( - sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) { - gated_op_fused_geglu(x, g, dst, k, n, o, item_ct1); - }); -} - -template -static void reglu_sycl(const T * x, const T* g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o, queue_ptr main_stream) { - const uint32_t num_blocks = ceil_div(k, SYCL_RELU_BLOCK_SIZE); - main_stream->parallel_for( - sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) { - gated_op_fused_reglu(x, g, dst, k, n, o, item_ct1); - }); -} - -template -static void swiglu_sycl(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o, queue_ptr main_stream) { - const uint32_t num_blocks = ceil_div(k, SYCL_SILU_BLOCK_SIZE); - main_stream->parallel_for( - sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) { - gated_op_fused_swiglu(x, g, dst, k, n, o, item_ct1); - }); -} - -inline void ggml_sycl_op_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { +// Common dispatcher for 1-input, 1-output element-wise ops, handling type switching. +// KernelInvoker is a lambda that takes (const T* src, T* dst, int k, queue_ptr stream, Args...) +template +inline void dispatch_ggml_sycl_op_unary(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... 
args) {
#if defined (GGML_SYCL_F16)
    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
-
#else
    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);
@@ -724,14 +363,14 @@ inline void ggml_sycl_op_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
        case GGML_TYPE_F16:
            {
                auto data_pts = cast_data<sycl::half>(dst);
-               sgn_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
+               kernel_invoker(data_pts.src, data_pts.dst, (int)ggml_nelements(dst->src[0]), main_stream, std::forward<Args>(args)...);
                break;
            }
#endif
        case GGML_TYPE_F32:
            {
                auto data_pts = cast_data<float>(dst);
-               sgn_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
+               kernel_invoker(data_pts.src, data_pts.dst, (int)ggml_nelements(dst->src[0]), main_stream, std::forward<Args>(args)...);
                break;
            }
        default:
@@ -739,11 +378,12 @@ inline void ggml_sycl_op_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
    }
}

-inline void ggml_sycl_op_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+// Dispatcher for fused GLU ops, handling specific input pointer setup and type switching.
+template <typename KernelInvoker, typename... Args>
+inline void dispatch_ggml_sycl_op_fused_glu(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) {
#if defined (GGML_SYCL_F16)
    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
-
#else
    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);
#endif
    GGML_ASSERT(dst->src[0]->type == dst->type);
    dpct::queue_ptr main_stream = ctx.stream();
    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    const int64_t nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_is_contiguous_1(dst->src[0]));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    const int32_t swapped = ((const int32_t *) dst->op_params)[1];
+    void * src0_d = src0->data;
+    void * src1_d = src1 ? src1->data : src0->data;
+    const int64_t src0_o = src0->nb[1];
+    const int64_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+    void * dst_d = dst->data;
+    if (src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src1->nb[0] == ggml_element_size(src1));
+        GGML_ASSERT(src1->ne[0] == nc);
+        GGML_ASSERT(src0->type == src1->type);
+    }
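+    // Illustration (reviewer sketch, not part of the patch): when there is no
+    // separate gate tensor (src1 == NULL), each src0 row of width 2*nc holds the
+    // two halves [a_0..a_{nc-1} | g_0..g_{nc-1}]; swapped == 1 exchanges which
+    // half is activated and which half gates. src0_o/src1_o are the row strides,
+    // so after dividing by sizeof(T) the kernels address row r, column c as
+    // p[r*o + c] for either operand.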
    switch (dst->type) {
#if defined (GGML_SYCL_F16)
        case GGML_TYPE_F16:
            {
-               auto data_pts = cast_data<sycl::half>(dst);
-               abs_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
+               sycl::half * src0_p = (sycl::half *) src0_d;
+               sycl::half * src1_p = (sycl::half *) src1_d;
+
+               if (!src1) {
+                   src0_p += swapped ? nc : 0;
+                   src1_p += swapped ? 0 : nc;
+               }
+               kernel_invoker(src0_p,
+                              src1_p,
+                              (sycl::half *) dst_d,
+                              ggml_nelements(dst),
+                              nc,
+                              src0_o / sizeof(sycl::half),
+                              src1_o / sizeof(sycl::half),
+                              main_stream,
+                              std::forward<Args>(args)...);
                break;
            }
#endif
        case GGML_TYPE_F32:
            {
-               auto data_pts = cast_data<float>(dst);
-               abs_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
+               float * src0_p = (float *) src0_d;
+               float * src1_p = (float *) src1_d;
+
+               if (!src1) {
+                   src0_p += swapped ? nc : 0;
+                   src1_p += swapped ? 0 : nc;
+               }
+
+               kernel_invoker(src0_p,
+                              src1_p,
+                              (float *) dst_d,
+                              ggml_nelements(dst),
+                              nc,
+                              src0_o / sizeof(float),
+                              src1_o / sizeof(float),
+                              main_stream,
+                              std::forward<Args>(args)...);
                break;
            }
        default:
@@ -771,32 +458,42 @@ inline void ggml_sycl_op_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
    }
}

-
-inline void ggml_sycl_op_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+// Dispatcher for upscale
+template <typename KernelInvoker, typename... Args>
+inline void dispatch_ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) {
#if defined (GGML_SYCL_F16)
    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
-
#else
    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);
#endif
    GGML_ASSERT(dst->src[0]->type == dst->type);
+
    dpct::queue_ptr main_stream = ctx.stream();
    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+
+    const float sf0 = (float) dst->ne[0] / dst->src[0]->ne[0];
+    const float sf1 = (float) dst->ne[1] / dst->src[0]->ne[1];
+    const float sf2 = (float) dst->ne[2] / dst->src[0]->ne[2];
+    const float sf3 = (float) dst->ne[3] / dst->src[0]->ne[3];
    switch (dst->type) {
#if defined (GGML_SYCL_F16)
        case GGML_TYPE_F16:
            {
                auto data_pts = cast_data<sycl::half>(dst);
-               elu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
+               kernel_invoker(data_pts.src, data_pts.dst, (int)dst->src[0]->nb[0], (int)dst->src[0]->nb[1], (int)dst->src[0]->nb[2],
+                              (int)dst->src[0]->nb[3], (int)dst->ne[0], (int)dst->ne[1], (int)dst->ne[2], (int)dst->ne[3], sf0, sf1, sf2, sf3,
+                              main_stream, std::forward<Args>(args)...);
                break;
            }
#endif
        case GGML_TYPE_F32:
            {
                auto data_pts = cast_data<float>(dst);
-               elu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
+               kernel_invoker(data_pts.src, data_pts.dst, (int)dst->src[0]->nb[0], (int)dst->src[0]->nb[1], (int)dst->src[0]->nb[2],
+                              (int)dst->src[0]->nb[3], (int)dst->ne[0], (int)dst->ne[1], (int)dst->ne[2], (int)dst->ne[3], sf0, sf1, sf2, sf3,
+                              main_stream, std::forward<Args>(args)...);
                break;
            }
        default:
@@ -804,7 +501,9 @@ inline void ggml_sycl_op_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
    }
}

-inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+// Dispatcher for pad
+template <typename KernelInvoker, typename... Args>
+inline void dispatch_ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&...
args) { #if defined (GGML_SYCL_F16) GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); @@ -813,6 +512,7 @@ inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst GGML_ASSERT(dst->type == GGML_TYPE_F32); #endif GGML_ASSERT(dst->src[0]->type == dst->type); + GGML_ASSERT(dst->src[0]->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors dpct::queue_ptr main_stream = ctx.stream(); SYCL_CHECK(ggml_sycl_set_device(ctx.device)); switch (dst->type) { @@ -820,14 +520,16 @@ inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst case GGML_TYPE_F16: { auto data_pts = cast_data(dst); - silu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + kernel_invoker(data_pts.src, data_pts.dst, (int)dst->src[0]->ne[0], (int)dst->src[0]->ne[1], (int)dst->src[0]->ne[2], (int)dst->ne[0], + (int)dst->ne[1], (int)dst->ne[2], main_stream, std::forward(args)...); break; } #endif case GGML_TYPE_F32: { auto data_pts = cast_data(dst); - silu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + kernel_invoker(data_pts.src, data_pts.dst, (int)dst->src[0]->ne[0], (int)dst->src[0]->ne[1], (int)dst->src[0]->ne[2], (int)dst->ne[0], + (int)dst->ne[1], (int)dst->ne[2], main_stream, std::forward(args)...); break; } default: @@ -835,655 +537,321 @@ inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst } } +} // namespace ggml_sycl_detail + + +// --- Backend Operation Functions (ggml_sycl_op_...) --- + +inline void ggml_sycl_op_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, + [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { + const int num_blocks = ceil_div(k_elements, 256); + stream->parallel_for( + sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256), + sycl::range<1>(256)), + [=](sycl::nd_item<1> item_ct1) { + sgn(src, dst_ptr, k_elements, item_ct1); + }); + }); +} + +inline void ggml_sycl_op_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, + [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { + const int num_blocks = ceil_div(k_elements, 256); + stream->parallel_for( + sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256), + sycl::range<1>(256)), + [=](sycl::nd_item<1> item_ct1) { + abs_op(src, dst_ptr, k_elements, item_ct1); + }); + }); +} + +inline void ggml_sycl_op_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, + [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { + const int num_blocks = ceil_div(k_elements, 256); + stream->parallel_for( + sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256), + sycl::range<1>(256)), + [=](sycl::nd_item<1> item_ct1) { + elu_op(src, dst_ptr, k_elements, item_ct1); + }); + }); +} + +inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, + [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { + const int num_blocks = ceil_div(k_elements, SYCL_SILU_BLOCK_SIZE); + stream->parallel_for( + sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SILU_BLOCK_SIZE), + sycl::range<1>(SYCL_SILU_BLOCK_SIZE)), + [=](sycl::nd_item<1> item_ct1) { + silu(src, dst_ptr, 
k_elements, item_ct1); + }); + }); +} + inline void ggml_sycl_op_gelu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { -#if defined (GGML_SYCL_F16) - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); - GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); -#else - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); -#endif - GGML_ASSERT(dst->src[0]->type == dst->type); - dpct::queue_ptr main_stream = ctx.stream(); - SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - switch (dst->type) { -#if defined (GGML_SYCL_F16) - case GGML_TYPE_F16: - { - auto data_pts = cast_data(dst); - gelu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } -#endif - case GGML_TYPE_F32: - { - auto data_pts = cast_data(dst); - gelu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } - default: - GGML_ABORT("GGML tensor type not supported!\n"); - } + ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, + [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { + const int num_blocks = ceil_div(k_elements, SYCL_GELU_BLOCK_SIZE); + stream->parallel_for( + sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_GELU_BLOCK_SIZE), + sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), + [=](sycl::nd_item<1> item_ct1) { + gelu(src, dst_ptr, k_elements, item_ct1); + }); + }); } inline void ggml_sycl_op_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { -#if defined (GGML_SYCL_F16) - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); - GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); -#else - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); -#endif - GGML_ASSERT(dst->src[0]->type == dst->type); - dpct::queue_ptr main_stream = ctx.stream(); - SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - switch (dst->type) { -#if defined (GGML_SYCL_F16) - case GGML_TYPE_F16: - { - auto data_pts = cast_data(dst); - gelu_quick_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } -#endif - case GGML_TYPE_F32: - { - auto data_pts = cast_data(dst); - gelu_quick_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } - default: - GGML_ABORT("GGML tensor type not supported!\n"); - } + ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, + [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { + const int num_blocks = ceil_div(k_elements, SYCL_GELU_BLOCK_SIZE); + stream->parallel_for( + sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_GELU_BLOCK_SIZE), + sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), + [=](sycl::nd_item<1> item_ct1) { + gelu_quick(src, dst_ptr, k_elements, item_ct1); + }); + }); } inline void ggml_sycl_op_gelu_erf(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { -#if defined (GGML_SYCL_F16) - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); - GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); -#else - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); -#endif - GGML_ASSERT(dst->src[0]->type == dst->type); - dpct::queue_ptr main_stream = ctx.stream(); - SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - switch (dst->type) { -#if defined (GGML_SYCL_F16) - case GGML_TYPE_F16: - { - auto data_pts = cast_data(dst); - gelu_erf_sycl(data_pts.src, 
data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } -#endif - case GGML_TYPE_F32: - { - auto data_pts = cast_data(dst); - gelu_erf_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } - default: - GGML_ABORT("GGML tensor type not supported!\n"); - } + ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, + [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { + const int num_blocks = ceil_div(k_elements, SYCL_GELU_BLOCK_SIZE); + stream->parallel_for( + sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_GELU_BLOCK_SIZE), + sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), + [=](sycl::nd_item<1> item_ct1) { + gelu_erf(src, dst_ptr, k_elements, item_ct1); + }); + }); } - inline void ggml_sycl_op_tanh(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { -#if defined (GGML_SYCL_F16) - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); - GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); -#else - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); -#endif - GGML_ASSERT(dst->src[0]->type == dst->type); - dpct::queue_ptr main_stream = ctx.stream(); - SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - switch (dst->type) { -#if defined (GGML_SYCL_F16) - case GGML_TYPE_F16: - { - auto data_pts = cast_data(dst); - tanh_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } -#endif - case GGML_TYPE_F32: - { - auto data_pts = cast_data(dst); - tanh_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } - default: - GGML_ABORT("GGML tensor type not supported!\n"); - } + ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, + [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { + const int num_blocks = ceil_div(k_elements, SYCL_TANH_BLOCK_SIZE); + stream->parallel_for( + sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_TANH_BLOCK_SIZE), + sycl::range<1>(SYCL_TANH_BLOCK_SIZE)), + [=](sycl::nd_item<1> item_ct1) { + tanh(src, dst_ptr, k_elements, item_ct1); + }); + }); } inline void ggml_sycl_op_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { -#if defined (GGML_SYCL_F16) - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); - GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); -#else - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); -#endif - GGML_ASSERT(dst->src[0]->type == dst->type); - dpct::queue_ptr main_stream = ctx.stream(); - SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - - switch (dst->type) { -#if defined (GGML_SYCL_F16) - case GGML_TYPE_F16: - { - auto data_pts = cast_data(dst); - relu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } -#endif - case GGML_TYPE_F32: - { - auto data_pts = cast_data(dst); - relu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } - default: - GGML_ABORT("GGML tensor type not supported!\n"); - } + ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, + [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { + const int num_blocks = ceil_div(k_elements, SYCL_RELU_BLOCK_SIZE); + stream->parallel_for( + sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_RELU_BLOCK_SIZE), + sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), + [=](sycl::nd_item<1> item_ct1) { + relu(src, dst_ptr, k_elements, 
item_ct1); + }); + }); } inline void ggml_sycl_op_hardsigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { -#if defined (GGML_SYCL_F16) - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); - GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); -#else - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); -#endif - GGML_ASSERT(dst->src[0]->type == dst->type); - - dpct::queue_ptr main_stream = ctx.stream(); - SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - - switch (dst->type) { -#if defined (GGML_SYCL_F16) - case GGML_TYPE_F16: - { - auto data_pts = cast_data(dst); - hardsigmoid_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } -#endif - case GGML_TYPE_F32: - { - auto data_pts = cast_data(dst); - hardsigmoid_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } - default: - GGML_ABORT("GGML tensor type not supported!\n"); - } + ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, + [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { + const int num_blocks = ceil_div(k_elements, SYCL_HARDSIGMOID_BLOCK_SIZE); + stream->parallel_for( + sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_HARDSIGMOID_BLOCK_SIZE), + sycl::range<1>(SYCL_HARDSIGMOID_BLOCK_SIZE)), + [=](sycl::nd_item<1> item_ct1) { + hardsigmoid(src, dst_ptr, k_elements, item_ct1); + }); + }); } inline void ggml_sycl_op_hardswish(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { -#if defined (GGML_SYCL_F16) - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); - GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); -#else - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); -#endif - GGML_ASSERT(dst->src[0]->type == dst->type); - dpct::queue_ptr main_stream = ctx.stream(); - SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - switch (dst->type) { -#if defined (GGML_SYCL_F16) - case GGML_TYPE_F16: - { - auto data_pts = cast_data(dst); - hardswish_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } -#endif - case GGML_TYPE_F32: - { - auto data_pts = cast_data(dst); - hardswish_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } - default: - GGML_ABORT("GGML tensor type not supported!\n"); - } + ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, + [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { + const int num_blocks = ceil_div(k_elements, SYCL_HARDSWISH_BLOCK_SIZE); + stream->parallel_for( + sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_HARDSWISH_BLOCK_SIZE), + sycl::range<1>(SYCL_HARDSWISH_BLOCK_SIZE)), + [=](sycl::nd_item<1> item_ct1) { + hardswish(src, dst_ptr, k_elements, item_ct1); + }); + }); } inline void ggml_sycl_op_exp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { -#if defined (GGML_SYCL_F16) - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); - GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); -#else - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); -#endif - GGML_ASSERT(dst->src[0]->type == dst->type); - dpct::queue_ptr main_stream = ctx.stream(); - SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - switch (dst->type) { -#if defined (GGML_SYCL_F16) - case GGML_TYPE_F16: - { - auto data_pts = 
cast_data(dst); - exp_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } -#endif - case GGML_TYPE_F32: - { - auto data_pts = cast_data(dst); - exp_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } - default: - GGML_ABORT("GGML tensor type not supported!\n"); - } + ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, + [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { + const int num_blocks = ceil_div(k_elements, SYCL_EXP_BLOCK_SIZE); + stream->parallel_for( + sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_EXP_BLOCK_SIZE), + sycl::range<1>(SYCL_EXP_BLOCK_SIZE)), + [=](sycl::nd_item<1> item_ct1) { + exp(src, dst_ptr, k_elements, item_ct1); + }); + }); } inline void ggml_sycl_op_log(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { -#if defined (GGML_SYCL_F16) - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); - GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); -#else - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); -#endif - GGML_ASSERT(dst->src[0]->type == dst->type); - dpct::queue_ptr main_stream = ctx.stream(); - SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - switch (dst->type) { -#if defined (GGML_SYCL_F16) - case GGML_TYPE_F16: - { - auto data_pts = cast_data(dst); - log_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } -#endif - case GGML_TYPE_F32: - { - auto data_pts = cast_data(dst); - log_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } - default: - GGML_ABORT("GGML tensor type not supported!\n"); - } -} - -inline void ggml_sycl_op_sigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { -#if defined (GGML_SYCL_F16) - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); - GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); -#else - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); -#endif - GGML_ASSERT(dst->src[0]->type == dst->type); - dpct::queue_ptr main_stream = ctx.stream(); - SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - switch (dst->type) { -#if defined (GGML_SYCL_F16) - case GGML_TYPE_F16: - { - auto data_pts = cast_data(dst); - sigmoid_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } -#endif - case GGML_TYPE_F32: - { - auto data_pts = cast_data(dst); - sigmoid_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } - default: - GGML_ABORT("GGML tensor type not supported!\n"); - } -} - -inline void ggml_sycl_op_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { -#if defined (GGML_SYCL_F16) - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); - GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); -#else - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); -#endif - GGML_ASSERT(dst->src[0]->type == dst->type); - - dpct::queue_ptr main_stream = ctx.stream(); - SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - switch (dst->type) { -#if defined (GGML_SYCL_F16) - case GGML_TYPE_F16: - { - auto data_pts = cast_data(dst); - sqrt_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } -#endif - case GGML_TYPE_F32: - { - auto data_pts = cast_data(dst); - sqrt_sycl(data_pts.src, 
data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } - default: - GGML_ABORT("GGML tensor type not supported!\n"); - } -} - -inline void ggml_sycl_op_sin(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { -#if defined (GGML_SYCL_F16) - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); - GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); -#else - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); -#endif - GGML_ASSERT(dst->src[0]->type == dst->type); - dpct::queue_ptr main_stream = ctx.stream(); - SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - switch (dst->type) { -#if defined (GGML_SYCL_F16) - case GGML_TYPE_F16: - { - auto data_pts = cast_data(dst); - sin_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } -#endif - case GGML_TYPE_F32: - { - auto data_pts = cast_data(dst); - sin_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } - default: - GGML_ABORT("GGML tensor type not supported!\n"); - } + ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, + [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { + const int num_blocks = ceil_div(k_elements, SYCL_EXP_BLOCK_SIZE); // Using EXP block size + stream->parallel_for( + sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_EXP_BLOCK_SIZE), + sycl::range<1>(SYCL_EXP_BLOCK_SIZE)), + [=](sycl::nd_item<1> item_ct1) { + log(src, dst_ptr, k_elements, item_ct1); + }); + }); } -inline void ggml_sycl_op_cos(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { -#if defined (GGML_SYCL_F16) - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); - GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); -#else - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); -#endif - GGML_ASSERT(dst->src[0]->type == dst->type); - dpct::queue_ptr main_stream = ctx.stream(); - SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - switch (dst->type) { -#if defined (GGML_SYCL_F16) - case GGML_TYPE_F16: - { - auto data_pts = cast_data(dst); - cos_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } -#endif - case GGML_TYPE_F32: - { - auto data_pts = cast_data(dst); - cos_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } - default: - GGML_ABORT("GGML tensor type not supported!\n"); - } +inline void ggml_sycl_op_neg(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, + [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { + const int num_blocks = ceil_div(k_elements, SYCL_NEG_BLOCK_SIZE); + stream->parallel_for( + sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_NEG_BLOCK_SIZE), + sycl::range<1>(SYCL_NEG_BLOCK_SIZE)), + [=](sycl::nd_item<1> item_ct1) { + neg(src, dst_ptr, k_elements, item_ct1); + }); + }); } inline void ggml_sycl_op_step(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { -#if defined (GGML_SYCL_F16) - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); - GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); -#else - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); -#endif - GGML_ASSERT(dst->src[0]->type == dst->type); - dpct::queue_ptr main_stream = ctx.stream(); - 
SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - switch (dst->type) { -#if defined (GGML_SYCL_F16) - case GGML_TYPE_F16: - { - auto data_pts = cast_data(dst); - step_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } -#endif - case GGML_TYPE_F32: - { - auto data_pts = cast_data(dst); - step_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } - default: - GGML_ABORT("GGML tensor type not supported!\n"); - } -} - -inline void ggml_sycl_op_neg(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { -#if defined (GGML_SYCL_F16) - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); - GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); -#else - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); -#endif - GGML_ASSERT(dst->src[0]->type == dst->type); - dpct::queue_ptr main_stream = ctx.stream(); - SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - switch (dst->type) { -#if defined (GGML_SYCL_F16) - case GGML_TYPE_F16: - { - auto data_pts = cast_data(dst); - neg_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } -#endif - case GGML_TYPE_F32: - { - auto data_pts = cast_data(dst); - neg_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } - default: - GGML_ABORT("GGML tensor type not supported!\n"); - } + ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, + [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { + const int num_blocks = ceil_div(k_elements, SYCL_NEG_BLOCK_SIZE); // Using NEG block size + stream->parallel_for( + sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_NEG_BLOCK_SIZE), + sycl::range<1>(SYCL_NEG_BLOCK_SIZE)), + [=](sycl::nd_item<1> item_ct1) { + step(src, dst_ptr, k_elements, item_ct1); + }); + }); } -inline void ggml_sycl_op_leaky_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { -#if defined (GGML_SYCL_F16) - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); - GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); -#else - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); -#endif +inline void ggml_sycl_op_sigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, + [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { + const int num_blocks = ceil_div(k_elements, SYCL_SIGMOID_BLOCK_SIZE); + stream->parallel_for( + sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SIGMOID_BLOCK_SIZE), + sycl::range<1>(SYCL_SIGMOID_BLOCK_SIZE)), + [=](sycl::nd_item<1> item_ct1) { + sigmoid(src, dst_ptr, k_elements, item_ct1); + }); + }); +} - GGML_ASSERT(dst->src[0]->type == dst->type); +inline void ggml_sycl_op_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, + [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { + const int num_blocks = ceil_div(k_elements, SYCL_SQRT_BLOCK_SIZE); + stream->parallel_for( + sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SQRT_BLOCK_SIZE), + sycl::range<1>(SYCL_SQRT_BLOCK_SIZE)), + [=](sycl::nd_item<1> item_ct1) { + sqrt(src, dst_ptr, k_elements, item_ct1); + }); + }); +} + +inline void ggml_sycl_op_sin(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + 
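// Note (reviewer annotation): sin and cos below intentionally share
+    // SYCL_SIN_BLOCK_SIZE, mirroring the removed sin_sycl/cos_sycl launchers,
+    // which also sized both kernels with SYCL_SIN_BLOCK_SIZE.
+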
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, + [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { + const int num_blocks = ceil_div(k_elements, SYCL_SIN_BLOCK_SIZE); + stream->parallel_for( + sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SIN_BLOCK_SIZE), + sycl::range<1>(SYCL_SIN_BLOCK_SIZE)), + [=](sycl::nd_item<1> item_ct1) { + sin(src, dst_ptr, k_elements, item_ct1); + }); + }); +} + +inline void ggml_sycl_op_cos(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, + [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { + const int num_blocks = ceil_div(k_elements, SYCL_SIN_BLOCK_SIZE); // Using SIN block size + stream->parallel_for( + sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SIN_BLOCK_SIZE), + sycl::range<1>(SYCL_SIN_BLOCK_SIZE)), + [=](sycl::nd_item<1> item_ct1) { + cos(src, dst_ptr, k_elements, item_ct1); + }); + }); +} + +inline void ggml_sycl_op_leaky_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { float negative_slope; memcpy(&negative_slope, dst->op_params, sizeof(float)); - dpct::queue_ptr main_stream = ctx.stream(); - SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - switch (dst->type) { -#if defined (GGML_SYCL_F16) - case GGML_TYPE_F16: - { - auto data_pts = cast_data(dst); - leaky_relu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), negative_slope, main_stream); - break; - } -#endif - case GGML_TYPE_F32: - { - auto data_pts = cast_data(dst); - leaky_relu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), negative_slope, main_stream); - break; - } - default: - GGML_ABORT("GGML tensor type not supported!\n"); - } + ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, + [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream, float slope) { + const int num_blocks = ceil_div(k_elements, SYCL_RELU_BLOCK_SIZE); + stream->parallel_for( + sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_RELU_BLOCK_SIZE), + sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), + [=](sycl::nd_item<1> item_ct1) { + leaky_relu(src, dst_ptr, k_elements, slope, item_ct1); + }); + }, negative_slope); } inline void ggml_sycl_op_sqr(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - #if defined (GGML_SYCL_F16) - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); - GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); -#else - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); -#endif - GGML_ASSERT(dst->src[0]->type == dst->type); - dpct::queue_ptr main_stream = ctx.stream(); - SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - switch (dst->type) { -#if defined (GGML_SYCL_F16) - case GGML_TYPE_F16: - { - auto data_pts = cast_data(dst); - sqr_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } -#endif - case GGML_TYPE_F32: - { - auto data_pts = cast_data(dst); - sqr_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } - default: - GGML_ABORT("GGML tensor type not supported!\n"); - } + ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, + [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { + const int num_blocks = ceil_div(k_elements, SYCL_SQR_BLOCK_SIZE); + stream->parallel_for( + sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SQR_BLOCK_SIZE), + sycl::range<1>(SYCL_SQR_BLOCK_SIZE)), + 
[=](sycl::nd_item<1> item_ct1) { + sqr(src, dst_ptr, k_elements, item_ct1); + }); + }); } inline void ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { -#if defined (GGML_SYCL_F16) - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); - GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); -#else - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); -#endif - GGML_ASSERT(dst->src[0]->type == dst->type); - - dpct::queue_ptr main_stream = ctx.stream(); - SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - - const float sf0 = (float) dst->ne[0] / dst->src[0]->ne[0]; - const float sf1 = (float) dst->ne[1] / dst->src[0]->ne[1]; - const float sf2 = (float) dst->ne[2] / dst->src[0]->ne[2]; - const float sf3 = (float) dst->ne[3] / dst->src[0]->ne[3]; - switch (dst->type) { -#if defined (GGML_SYCL_F16) - case GGML_TYPE_F16: - { - auto data_pts = cast_data(dst); - upscale_sycl(data_pts.src, data_pts.dst, dst->src[0]->nb[0], dst->src[0]->nb[1], dst->src[0]->nb[2], - dst->src[0]->nb[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3, - main_stream); - break; - } -#endif - case GGML_TYPE_F32: - { - auto data_pts = cast_data(dst); - upscale_sycl(data_pts.src, data_pts.dst, dst->src[0]->nb[0], dst->src[0]->nb[1], dst->src[0]->nb[2], - dst->src[0]->nb[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3, - main_stream); - break; - } - default: - GGML_ABORT("GGML tensor type not supported!\n"); - } + ggml_sycl_detail::dispatch_ggml_sycl_op_upscale(ctx, dst, + [](const auto* src, auto* dst_ptr, int nb00, int nb01, int nb02, int nb03, + int ne10, int ne11, int ne12, int ne13, float sf0, float sf1, float sf2, float sf3, + queue_ptr stream) { + ggml_sycl_detail::upscale_sycl(src, dst_ptr, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3, stream); + }); } inline void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { -#if defined (GGML_SYCL_F16) - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); - GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); -#else - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); -#endif - GGML_ASSERT(dst->src[0]->type == dst->type); - GGML_ASSERT(dst->src[0]->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors - dpct::queue_ptr main_stream = ctx.stream(); - SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - switch (dst->type) { -#if defined (GGML_SYCL_F16) - case GGML_TYPE_F16: - { - auto data_pts = cast_data(dst); - pad_sycl(data_pts.src, data_pts.dst, dst->src[0]->ne[0], dst->src[0]->ne[1], dst->src[0]->ne[2], dst->ne[0], - dst->ne[1], dst->ne[2], main_stream); - break; - } -#endif - case GGML_TYPE_F32: - { - auto data_pts = cast_data(dst); - pad_sycl(data_pts.src, data_pts.dst, dst->src[0]->ne[0], dst->src[0]->ne[1], dst->src[0]->ne[2], dst->ne[0], - dst->ne[1], dst->ne[2], main_stream); - break; - } - default: - GGML_ABORT("GGML tensor type not supported!\n"); - } + ggml_sycl_detail::dispatch_ggml_sycl_op_pad(ctx, dst, + [](const auto* src, auto* dst_ptr, int ne00, int ne01, int ne02, int ne0, int ne1, int ne2, + queue_ptr stream) { + ggml_sycl_detail::pad_sycl(src, dst_ptr, ne00, ne01, ne02, ne0, ne1, ne2, stream); + }); } inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { -#if defined(GGML_SYCL_F16) - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || 
 
 inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
-#if defined(GGML_SYCL_F16)
-    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
-#else
-
-    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
-#endif
-    GGML_ASSERT(dst->src[0]->type == dst->type);
-    dpct::queue_ptr main_stream = ctx.stream();
-    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
-    float min;
-    float max;
-    memcpy(&min, dst->op_params, sizeof(float));
-    memcpy(&max, (float *) dst->op_params + 1, sizeof(float));
-
-    switch (dst->type) {
-#if defined(GGML_SYCL_F16)
-        case GGML_TYPE_F16:
-            {
-                auto data_pts = cast_data<sycl::half>(dst);
-                clamp_sycl(data_pts.src, data_pts.dst, min, max, ggml_nelements(dst->src[0]), main_stream);
-                break;
-            }
-#endif
-        case GGML_TYPE_F32:
-            {
-                auto data_pts = cast_data<float>(dst);
-                clamp_sycl(data_pts.src, data_pts.dst, min, max, ggml_nelements(dst->src[0]), main_stream);
-                break;
-            }
-        default:
-            GGML_ABORT("GGML tensor type not supported!\n");
-    }
+    float min_val;
+    float max_val;
+    memcpy(&min_val, dst->op_params, sizeof(float));
+    memcpy(&max_val, (float *) dst->op_params + 1, sizeof(float));
+    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream, float min_arg, float max_arg) {
+            const int num_blocks = ceil_div(k_elements, SYCL_CLAMP_BLOCK_SIZE);
+            stream->parallel_for(
+                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_CLAMP_BLOCK_SIZE),
+                                  sycl::range<1>(SYCL_CLAMP_BLOCK_SIZE)),
+                [=](sycl::nd_item<1> item_ct1) {
+                    clamp(src, dst_ptr, min_arg, max_arg, k_elements, item_ct1);
+                });
+        }, min_val, max_val);
 }
 
 inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
-
     GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
     GGML_ASSERT(dst->src[1]->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
@@ -1499,156 +867,43 @@ inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
     //    int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused
     int offset = dst->op_params[3] / 4; // offset in bytes
 
-    acc_f32_sycl(src0_dd, src1_dd, dst_dd, ggml_nelements(dst), dst->src[1]->ne[0], dst->src[1]->ne[1], dst->src[1]->ne[2], nb1, nb2, offset, main_stream);
+    ggml_sycl_detail::acc_f32_sycl(src0_dd, src1_dd, dst_dd, (int)ggml_nelements(dst), (int)dst->src[1]->ne[0], (int)dst->src[1]->ne[1], (int)dst->src[1]->ne[2], nb1, nb2, offset, main_stream);
 }
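The clamp rewrite keeps the existing op_params convention: scalar operands travel bit-copied inside the tensor's small op_params area rather than as graph inputs, floats via memcpy and integers via ggml_set_op_params_i32 (which the GLU ops below use for the op id and the swapped flag). A minimal stand-alone illustration, with a 16-slot buffer standing in for the real size constant from ggml.h:

    #include <cstdint>
    #include <cstring>
    #include <cstdio>

    int main() {
        // stand-in for ggml_tensor::op_params (the real size lives in ggml.h)
        int32_t op_params[16] = {0};

        // clamp-style packing: two floats bit-copied into consecutive slots
        const float min_val = -1.0f, max_val = 1.0f;
        std::memcpy(&op_params[0], &min_val, sizeof(float));
        std::memcpy(&op_params[1], &max_val, sizeof(float));

        // readback, exactly as ggml_sycl_op_clamp does above
        float lo, hi;
        std::memcpy(&lo, &op_params[0], sizeof(float));
        std::memcpy(&hi, (float *) op_params + 1, sizeof(float));
        std::printf("clamp bounds: [%g, %g]\n", lo, hi);

        // GLU-style packing: plain int32 slots, mirroring
        // ggml_set_op_params_i32(result, 0, op) / (result, 1, swapped)
        op_params[0] = 2; // some ggml_glu_op value
        op_params[1] = 1; // swapped flag
        std::printf("glu op=%d swapped=%d\n", op_params[0], op_params[1]);
        return 0;
    }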
 
 inline void ggml_sycl_op_geglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
-#if defined (GGML_SYCL_F16)
-    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
-
-#else
-    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
-#endif
-    GGML_ASSERT(dst->src[0]->type == dst->type);
-    dpct::queue_ptr main_stream = ctx.stream();
-    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
-    const int64_t nc = dst->src[0]->ne[0] / 2;
-    GGML_ASSERT(dst->ne[0] == nc);
-    GGML_ASSERT(ggml_is_contiguous_1(dst->src[0]));
-    GGML_ASSERT(ggml_is_contiguous(dst));
-    const int32_t swapped = ((const int32_t *) dst->op_params)[1];
-    const void * src0_d = dst->src[0]->data;
-    void * dst_d = dst->data;
-    switch (dst->type) {
-#if defined (GGML_SYCL_F16)
-        case GGML_TYPE_F16:
-            {
-                geglu_sycl((const sycl::half *) src0_d + (swapped ? nc : 0),
-                           (const sycl::half *)src0_d + (swapped ? 0 : nc),
-                           (sycl::half *) dst_d,
-                           ggml_nelements(dst),
-                           nc,
-                           dst->src[0]->nb[1] / sizeof(sycl::half),
-                           main_stream);
-                break;
-            }
-#endif
-        case GGML_TYPE_F32:
-            {
-                geglu_sycl((const float *) src0_d + (swapped ? nc : 0),
-                           (const float *)src0_d + (swapped ? 0 : nc),
-                           (float *) dst_d,
-                           ggml_nelements(dst),
-                           nc,
-                           dst->src[0]->nb[1] / sizeof(float),
-                           main_stream);
-                break;
-            }
-        default:
-            GGML_ABORT("GGML tensor type not supported!\n");
-    }
+    ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu(ctx, dst,
+        [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
+            const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);
+            main_stream->parallel_for(
+                sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
+                    gated_op_fused_geglu(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
+                });
+        });
 }
 
 inline void ggml_sycl_op_reglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
-#if defined (GGML_SYCL_F16)
-    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
-
-#else
-    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
-#endif
-    GGML_ASSERT(dst->src[0]->type == dst->type);
-    dpct::queue_ptr main_stream = ctx.stream();
-    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
-    const int64_t nc = dst->src[0]->ne[0] / 2;
-    GGML_ASSERT(dst->ne[0] == nc);
-    GGML_ASSERT(ggml_is_contiguous_1(dst->src[0]));
-    GGML_ASSERT(ggml_is_contiguous(dst));
-    const int32_t swapped = ((const int32_t *) dst->op_params)[1];
-    const void * src0_d = dst->src[0]->data;
-    void * dst_d = dst->data;
-    switch (dst->type) {
-#if defined (GGML_SYCL_F16)
-        case GGML_TYPE_F16:
-            {
-                reglu_sycl((const sycl::half *) src0_d + (swapped ? nc : 0),
-                           (const sycl::half *)src0_d + (swapped ? 0 : nc),
-                           (sycl::half *) dst_d,
-                           ggml_nelements(dst),
-                           nc,
-                           dst->src[0]->nb[1] / sizeof(sycl::half),
-                           main_stream);
-                break;
-            }
-#endif
-        case GGML_TYPE_F32:
-            {
-                reglu_sycl((const float *) src0_d + (swapped ? nc : 0),
-                           (const float *)src0_d + (swapped ? 0 : nc),
-                           (float *) dst_d,
-                           ggml_nelements(dst),
-                           nc,
-                           dst->src[0]->nb[1] / sizeof(float),
-                           main_stream);
-                break;
-            }
-        default:
-            GGML_ABORT("GGML tensor type not supported!\n");
-    }
+    ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu(ctx, dst,
+        [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
+            const uint32_t num_blocks = ceil_div((uint32_t)k, SYCL_RELU_BLOCK_SIZE); // Using RELU block size for reglu
+            main_stream->parallel_for(
+                sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
+                    gated_op_fused_reglu(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
+                });
+        });
 }
 
 inline void ggml_sycl_op_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
-#if defined (GGML_SYCL_F16)
-    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
-
-#else
-    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
-#endif
-    GGML_ASSERT(dst->src[0]->type == dst->type);
-    dpct::queue_ptr main_stream = ctx.stream();
-    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
-    const int64_t nc = dst->src[0]->ne[0] / 2;
-    GGML_ASSERT(dst->ne[0] == nc);
-    GGML_ASSERT(ggml_is_contiguous_1(dst->src[0]));
-    GGML_ASSERT(ggml_is_contiguous(dst));
-    const int32_t swapped = ((const int32_t *) dst->op_params)[1];
-    const void * src0_d = dst->src[0]->data;
-    void * dst_d = dst->data;
-    switch (dst->type) {
-#if defined (GGML_SYCL_F16)
-        case GGML_TYPE_F16:
-            {
-                swiglu_sycl((const sycl::half *) src0_d + (swapped ? nc : 0),
-                            (const sycl::half *)src0_d + (swapped ? 0 : nc),
-                            (sycl::half *) dst_d,
-                            ggml_nelements(dst),
-                            nc,
-                            dst->src[0]->nb[1] / sizeof(sycl::half),
-                            main_stream);
-                break;
-            }
-#endif
-        case GGML_TYPE_F32:
-            {
-                swiglu_sycl((const float *) src0_d + (swapped ? nc : 0),
-                            (const float *)src0_d + (swapped ? 0 : nc),
-                            (float *) dst_d,
-                            ggml_nelements(dst),
-                            nc,
-                            dst->src[0]->nb[1] / sizeof(float),
-                            main_stream);
-                break;
-            }
-        default:
-            GGML_ABORT("GGML tensor type not supported!\n");
-    }
+    ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu(ctx, dst,
+        [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
+            const uint32_t num_blocks = ceil_div((uint32_t)k, SYCL_SILU_BLOCK_SIZE); // Using SILU block size for swiglu
+            main_stream->parallel_for(
+                sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_SILU_BLOCK_SIZE)), sycl::range<1>(SYCL_SILU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
+                    gated_op_fused_swiglu(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
+                });
+        });
 }
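dispatch_ggml_sycl_op_fused_glu and the gated_op_fused_* kernels are defined elsewhere in this patch, so only their call shape is visible here: k output elements, rows of width n, and per-input row strides o0 and o1. Read that way, the addressing reduces to the CPU sketch below; the stride convention is inferred from the arguments, not quoted from the kernels.

    #include <cstdint>
    #include <cmath>
    #include <cstdio>

    // CPU sketch of the fused gated-op addressing implied by the lambda
    // signature above: dst[i] = op(x in row i/n, col i%n) * g at the same
    // row/column, with independent row strides for the two inputs.
    template <typename F>
    void gated_op_reference(const float * x, const float * g, float * dst,
                            uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, F op) {
        for (uint64_t i = 0; i < k; ++i) {
            const uint64_t row = i / n, col = i % n;
            dst[i] = op(x[row * o0 + col]) * g[row * o1 + col];
        }
    }

    int main() {
        // one fused row of 8: first 4 values are gated, last 4 multiply them
        float src[8] = { -1, 0, 1, 2,  10, 10, 10, 10 };
        float out[4];
        // swiglu: silu(x) * g
        gated_op_reference(src, src + 4, out, 4, 4, 8, 8,
                           [](float v) { return v / (1.0f + std::exp(-v)); });
        for (float v : out) printf("%g ", v);
        printf("\n");
        return 0;
    }

With a fused layout, x and g point into the same buffer and both strides equal the full row width; with the split variant they point at two separate buffers whose rows are n wide.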
+
 void ggml_sycl_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
     ggml_sycl_op_sqrt(ctx, dst);
 }
@@ -1788,5 +1043,3 @@ void ggml_sycl_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
     ggml_sycl_op_swiglu(ctx, dst);
 }
-
-
diff --git a/ggml/src/ggml-sycl/element_wise.hpp b/ggml/src/ggml-sycl/element_wise.hpp
index f530c9c1e1bdd..86068b10129ec 100644
--- a/ggml/src/ggml-sycl/element_wise.hpp
+++ b/ggml/src/ggml-sycl/element_wise.hpp
@@ -3,24 +3,24 @@
 #include "common.hpp"
 #include "ggml.h"
 
-#include <limits.h>
+#include <limits> // For std::numeric_limits
 
 template<typename T>
 T neg_infinity() {
     return -std::numeric_limits<T>::infinity();
 }
 
-template<typename T>
+template<typename T_Src, typename T_Dst>
 struct typed_data {
-    const T * src;
-    T * dst;
+    const T_Src * src;
+    T_Dst * dst;
 };
 
-template<typename T>
-typed_data<T> cast_data(ggml_tensor * dst) {
+template<typename T_Src, typename T_Dst>
+typed_data<T_Src, T_Dst> cast_data(ggml_tensor * dst) {
     return {
-        /* .src = */ static_cast<const T *>(dst->src[0]->data),
-        /* .dst = */ static_cast<T *>(dst->data)
+        /* .src = */ static_cast<const T_Src *>(dst->src[0]->data),
+        /* .dst = */ static_cast<T_Dst *>(dst->data)
     };
 }
 
@@ -82,4 +82,3 @@ void ggml_sycl_reglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 void ggml_sycl_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 
 #endif // GGML_SYCL_ELEMENTWISE_HPP
-
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index ee59f3a59957e..4a347bc5efcd6 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -659,6 +659,11 @@ struct vk_op_push_constants {
     float param2;
 };
 
+struct vk_op_glu_push_constants {
+    uint32_t ne00;
+    uint32_t mode; // 0: default, 1: swapped, 2: split
+};
+
 struct vk_op_unary_push_constants {
     uint32_t ne;
     uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
@@ -2733,8 +2738,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
 #undef CREATE_UNARY
 
 #define CREATE_GLU(name) \
-    ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); \
-    ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 3, sizeof(vk_op_glu_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); \
+    ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 3, sizeof(vk_op_glu_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
 
 CREATE_GLU(geglu)
 CREATE_GLU(reglu)
@@ -6947,7 +6952,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
         }
     }
 
-    if (op == GGML_OP_SOFT_MAX) {
+    if (op == GGML_OP_SOFT_MAX || op == GGML_OP_GLU) {
         // Empty src1 is possible in soft_max, but the shader needs a buffer
         vk_subbuffer subbuf_y;
         if (use_src1) {
@@ -7539,12 +7544,23 @@ static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, con
     ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
 }
 
-static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
-    GGML_ASSERT(src0->ne[0] / 2 == dst->ne[0]);
+static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+    const bool swapped = (bool)dst->op_params[1];
+    const bool split = src1 != nullptr;
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    if (!split) {
+        GGML_ASSERT(src0->ne[0] / 2 == dst->ne[0]);
+    } else {
+        GGML_ASSERT(src0->ne[0] == src1->ne[0]);
+        GGML_ASSERT(src0->ne[0] == dst->ne[0]);
+        GGML_ASSERT(src0->type == src1->type);
+    }
 
-    const uint32_t swapped = (uint32_t)dst->op_params[1];
+    const uint32_t mode = split ? 2 : (swapped ? 1 : 0);
 
-    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_GLU, { (uint32_t)src0->ne[0], swapped, 0.0f, 0.0f }, dryrun);
+    ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_GLU, { (uint32_t)src0->ne[0], mode }, dryrun);
 }
 
 static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
@@ -9003,7 +9019,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
         case GGML_GLU_OP_GEGLU:
         case GGML_GLU_OP_REGLU:
         case GGML_GLU_OP_SWIGLU:
-            ggml_vk_glu(ctx, compute_ctx, src0, node, dryrun);
+            ggml_vk_glu(ctx, compute_ctx, src0, src1, node, dryrun);
             break;
         default:
             return false;
@@ -10725,7 +10741,11 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
             GGML_ABORT("fatal error");
         }
     } else if (tensor->op == GGML_OP_GLU) {
-        tensor_clone = ggml_glu(ggml_ctx, src_clone[0], (ggml_glu_op) tensor->op_params[0], tensor->op_params[1]);
+        if (src_clone[1] == nullptr) {
+            tensor_clone = ggml_glu(ggml_ctx, src_clone[0], (ggml_glu_op) tensor->op_params[0], tensor->op_params[1]);
+        } else {
+            tensor_clone = ggml_glu_split(ggml_ctx, src_clone[0], src_clone[1], (ggml_glu_op) tensor->op_params[0]);
+        }
     } else if (tensor->op == GGML_OP_CPY || tensor->op == GGML_OP_DUP) {
         if (src1 == nullptr) {
             tensor_clone = ggml_dup(ggml_ctx, src_clone[0]);
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp b/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp
index e58ac59d9a860..f4268ed24f44c 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp
@@ -1,43 +1,13 @@
 #version 450
 
-#include "generic_head.comp"
-#include "types.comp"
+#include "glu_head.comp"
 
-#extension GL_EXT_control_flow_attributes : enable
+const float GELU_COEF_A = 0.044715f;
+const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
 
-layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
-
-layout (constant_id = 0) const uint BLOCK_SIZE = 32;
-
-void main() {
-    const float GELU_COEF_A = 0.044715f;
-    const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
-
-    const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
-    const uint col = gl_LocalInvocationID.x;
-
-    const uint offset = p.KX / 2;
-
-    const bool swapped = p.KY > 0;
-
-    if (!swapped) {
-        for (uint i = col; i < offset; i += BLOCK_SIZE) {
-            const uint idx = row * p.KX + i;
-
-            const float xi = float(data_a[idx]);
-            const float val = SQRT_2_OVER_PI*xi*(1.0f + GELU_COEF_A*xi*xi);
-            data_d[row * offset + i] = D_TYPE(0.5f*xi*(2.0f - 2.0f / (exp(2 * val) + 1)) * float(data_a[idx + offset]));
-        }
-    } else {
-        for (uint i = col; i < offset; i += BLOCK_SIZE) {
-            const uint idx = row * p.KX + i;
-
-            const float xi = float(data_a[idx + offset]);
-            const float val = SQRT_2_OVER_PI*xi*(1.0f + GELU_COEF_A*xi*xi);
-            data_d[row * offset + i] = D_TYPE(0.5f*xi*(2.0f - 2.0f / (exp(2 * val) + 1)) * float(data_a[idx]));
-        }
-    }
+float op(float a, float b) {
+    const float val = SQRT_2_OVER_PI*a*(1.0f + GELU_COEF_A*a*a);
+    return 0.5f*a*(2.0f - 2.0f / (exp(2 * val) + 1)) * b;
 }
+
+#include "glu_main.comp"
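The new op() preserves the old shader arithmetic exactly: it is the tanh-approximation GELU written through exp, using the identity 2 - 2/(exp(2v) + 1) = 1 + tanh(v). A throwaway check of that identity (illustrative only):

    #include <cmath>
    #include <cstdio>

    // For v = sqrt(2/pi) * a * (1 + 0.044715 * a^2), the shader's
    // 0.5*a*(2 - 2/(exp(2v)+1)) equals the usual 0.5*a*(1 + tanh(v)).
    int main() {
        const float GELU_COEF_A    = 0.044715f;
        const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
        for (float a = -4.0f; a <= 4.0f; a += 0.5f) {
            const float v        = SQRT_2_OVER_PI * a * (1.0f + GELU_COEF_A * a * a);
            const float shader   = 0.5f * a * (2.0f - 2.0f / (std::exp(2.0f * v) + 1.0f));
            const float tanhform = 0.5f * a * (1.0f + std::tanh(v));
            printf("a=%5.2f  shader=%9.6f  tanh=%9.6f  diff=%g\n", a, shader, tanhform, shader - tanhform);
        }
        return 0;
    }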
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp b/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp
new file mode 100644
index 0000000000000..0d65baef38944
--- /dev/null
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp
@@ -0,0 +1,15 @@
+#extension GL_EXT_shader_16bit_storage : require
+
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
+layout (binding = 1) readonly buffer B {A_TYPE data_b[];};
+layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
+
+layout (constant_id = 0) const uint BLOCK_SIZE = 32;
+
+layout (push_constant) uniform parameter
+{
+    uint ne00;
+    uint mode;
+} p;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp b/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp
new file mode 100644
index 0000000000000..24814240365d2
--- /dev/null
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp
@@ -0,0 +1,31 @@
+void main() {
+    const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
+    const uint col = gl_LocalInvocationID.x;
+
+    if (p.mode == 0) {
+        // Default
+        const uint offset = p.ne00 / 2;
+
+        for (uint i = col; i < offset; i += BLOCK_SIZE) {
+            const uint idx = row * p.ne00 + i;
+
+            data_d[row * offset + i] = D_TYPE(op(float(data_a[idx]), float(data_a[idx + offset])));
+        }
+    } else if (p.mode == 1) {
+        // Swapped
+        const uint offset = p.ne00 / 2;
+
+        for (uint i = col; i < offset; i += BLOCK_SIZE) {
+            const uint idx = row * p.ne00 + i;
+
+            data_d[row * offset + i] = D_TYPE(op(float(data_a[idx + offset]), float(data_a[idx])));
+        }
+    } else {
+        // Split
+        for (uint i = col; i < p.ne00; i += BLOCK_SIZE) {
+            const uint idx = row * p.ne00 + i;
+
+            data_d[idx] = D_TYPE(op(float(data_a[idx]), float(data_b[idx])));
+        }
+    }
+}
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp b/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp
index 034481a1f17ef..0073d8f766610 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp
@@ -1,36 +1,9 @@
 #version 450
 
-#include "generic_head.comp"
-#include "types.comp"
+#include "glu_head.comp"
 
-#extension GL_EXT_control_flow_attributes : enable
-
-layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
-
-layout (constant_id = 0) const uint BLOCK_SIZE = 32;
-
-void main() {
-    const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
-    const uint col = gl_LocalInvocationID.x;
-
-    const uint offset = p.KX / 2;
-
-    const bool swapped = p.KY > 0;
-
-    if (!swapped) {
-        for (uint i = col; i < offset; i += BLOCK_SIZE) {
-            const uint idx = row * p.KX + i;
-
-            data_d[row * offset + i] = D_TYPE(max(float(data_a[idx]), 0.0f) * float(data_a[idx + offset]));
-        }
-    } else {
-        for (uint i = col; i < offset; i += BLOCK_SIZE) {
-            const uint idx = row * p.KX + i;
-
-            data_d[row * offset + i] = D_TYPE(max(float(data_a[idx + offset]), 0.0f) * float(data_a[idx]));
-        }
-    }
+float op(float a, float b) {
+    return max(a, 0.0f) * b;
 }
+
+#include "glu_main.comp"
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp b/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp
index e75c1d38aa1ea..a28e7c6cc8660 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp
@@ -1,38 +1,9 @@
 #version 450
 
-#include "generic_head.comp"
-#include "types.comp"
+#include "glu_head.comp"
 
-#extension GL_EXT_control_flow_attributes : enable
-
-layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
-
-layout (constant_id = 0) const uint BLOCK_SIZE = 32;
-
-void main() {
-    const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
-    const uint col = gl_LocalInvocationID.x;
-
-    const uint offset = p.KX / 2;
-
-    const bool swapped = p.KY > 0;
-
-    if (!swapped) {
-        for (uint i = col; i < offset; i += BLOCK_SIZE) {
-            const uint idx = row * p.KX + i;
-
-            const float xi = float(data_a[idx]);
-            data_d[row * offset + i] = D_TYPE(xi / (1.0f + exp(-xi)) * float(data_a[idx + offset]));
-        }
-    } else {
-        for (uint i = col; i < offset; i += BLOCK_SIZE) {
-            const uint idx = row * p.KX + i;
-
-            const float xi = float(data_a[idx + offset]);
-            data_d[row * offset + i] = D_TYPE(xi / (1.0f + exp(-xi)) * float(data_a[idx]));
-        }
-    }
+float op(float a, float b) {
+    return a / (1.0f + exp(-a)) * b;
 }
+
+#include "glu_main.comp"
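For reference, the three addressing modes of glu_main.comp restated as one CPU loop; this is a plain-C++ paraphrase for reading purposes, with op standing in for the per-element function supplied by the including shader.

    #include <cstdint>
    #include <algorithm>
    #include <functional>

    // Modes 0/1 gate one half of each row by the other (output is ne00/2 wide);
    // mode 2 consumes full same-shape rows from two separate buffers.
    void glu_rows_reference(const float * a, const float * b, float * d,
                            uint32_t nrows, uint32_t ne00, uint32_t mode,
                            const std::function<float(float, float)> & op) {
        for (uint32_t row = 0; row < nrows; ++row) {
            if (mode == 0 || mode == 1) {
                const uint32_t offset = ne00 / 2;
                for (uint32_t i = 0; i < offset; ++i) {
                    const uint32_t idx = row * ne00 + i;
                    const float x = (mode == 0) ? a[idx] : a[idx + offset];
                    const float g = (mode == 0) ? a[idx + offset] : a[idx];
                    d[row * offset + i] = op(x, g);
                }
            } else { // split: b supplies the gate
                for (uint32_t i = 0; i < ne00; ++i) {
                    const uint32_t idx = row * ne00 + i;
                    d[idx] = op(a[idx], b[idx]);
                }
            }
        }
    }

    int main() {
        const float a[4] = { -1.0f, 2.0f, 3.0f, 4.0f }; // one row, ne00 = 4
        float d[2];
        glu_rows_reference(a, nullptr, d, 1, 4, 0,
                           [](float x, float g) { return std::max(x, 0.0f) * g; });
        // d == { max(-1,0)*3, max(2,0)*4 } == { 0, 8 }
        return 0;
    }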
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 2ae4e511b543b..8972af5d5b9bb 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -2644,37 +2644,68 @@ struct ggml_tensor * ggml_exp_inplace(
 
 // ggml_glu
 
-struct ggml_tensor * ggml_glu(
+static struct ggml_tensor * ggml_glu_impl(
         struct ggml_context * ctx,
         struct ggml_tensor * a,
+        struct ggml_tensor * b,
         enum ggml_glu_op op,
         bool swapped) {
     GGML_ASSERT(ggml_is_contiguous_1(a));
 
+    if (b) {
+        GGML_ASSERT(ggml_is_contiguous_1(b));
+        GGML_ASSERT(ggml_are_same_shape(a, b));
+        GGML_ASSERT(a->type == b->type);
+    }
+
     int64_t ne[GGML_MAX_DIMS] = { a->ne[0] / 2 };
     for (int i = 1; i < GGML_MAX_DIMS; i++) ne[i] = a->ne[i];
 
-    struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, GGML_MAX_DIMS, ne, NULL, 0);
+    struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, GGML_MAX_DIMS, b ? a->ne : ne, NULL, 0);
 
     ggml_set_op_params_i32(result, 0, (int32_t) op);
     ggml_set_op_params_i32(result, 1, (int32_t) swapped);
 
     result->op     = GGML_OP_GLU;
     result->src[0] = a;
+    result->src[1] = b;
 
     return result;
 }
 
+struct ggml_tensor * ggml_glu(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        enum ggml_glu_op op,
+        bool swapped) {
+    return ggml_glu_impl(ctx, a, NULL, op, swapped);
+}
+
+struct ggml_tensor * ggml_glu_split(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b,
+        enum ggml_glu_op op) {
+    return ggml_glu_impl(ctx, a, b, op, false);
+}
+
 // ggml_reglu
 
 struct ggml_tensor * ggml_reglu(
         struct ggml_context * ctx,
         struct ggml_tensor * a) {
-    return ggml_glu(ctx, a, GGML_GLU_OP_REGLU, false);
+    return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_REGLU, false);
 }
 
 struct ggml_tensor * ggml_reglu_swapped(
         struct ggml_context * ctx,
         struct ggml_tensor * a) {
-    return ggml_glu(ctx, a, GGML_GLU_OP_REGLU, true);
+    return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_REGLU, true);
+}
+
+struct ggml_tensor * ggml_reglu_split(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b) {
+    return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_REGLU, false);
 }
 
 // ggml_geglu
@@ -2682,13 +2713,20 @@ struct ggml_tensor * ggml_reglu_swapped(
 struct ggml_tensor * ggml_geglu(
         struct ggml_context * ctx,
         struct ggml_tensor * a) {
-    return ggml_glu(ctx, a, GGML_GLU_OP_GEGLU, false);
+    return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU, false);
 }
 
 struct ggml_tensor * ggml_geglu_swapped(
         struct ggml_context * ctx,
        struct ggml_tensor * a) {
-    return ggml_glu(ctx, a, GGML_GLU_OP_GEGLU, true);
+    return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU, true);
+}
+
+struct ggml_tensor * ggml_geglu_split(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b) {
+    return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_GEGLU, false);
 }
 
 // ggml_swiglu
@@ -2696,13 +2734,20 @@ struct ggml_tensor * ggml_geglu_swapped(
 struct ggml_tensor * ggml_swiglu(
         struct ggml_context * ctx,
         struct ggml_tensor * a) {
-    return ggml_glu(ctx, a, GGML_GLU_OP_SWIGLU, false);
+    return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_SWIGLU, false);
 }
 
 struct ggml_tensor * ggml_swiglu_swapped(
         struct ggml_context * ctx,
         struct ggml_tensor * a) {
-    return ggml_glu(ctx, a, GGML_GLU_OP_SWIGLU, true);
+    return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_SWIGLU, true);
+}
+
+struct ggml_tensor * ggml_swiglu_split(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b) {
+    return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_SWIGLU, false);
 }
 
 // ggml_norm
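Shape-wise, the two entry points differ in one respect: the fused variants halve ne[0] because one tensor carries both halves along dim 0, while the *_split variants keep the common shape of their two inputs (the "b ? a->ne : ne" above). A usage sketch against the new API; the context is assumed to be created as usual and the sizes are arbitrary:

    #include "ggml.h"

    static void glu_api_sketch(struct ggml_context * ctx) {
        // fused: one tensor holds [x | g] along dim 0; result is 4096 wide
        struct ggml_tensor * xg        = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 8192, 4);
        struct ggml_tensor * out_fused = ggml_swiglu(ctx, xg);         // ne = [4096, 4]

        // split: x and g are separate, same-shape tensors; result keeps ne0
        struct ggml_tensor * x         = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4096, 4);
        struct ggml_tensor * g         = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4096, 4);
        struct ggml_tensor * out_split = ggml_swiglu_split(ctx, x, g); // ne = [4096, 4]

        (void) out_fused; (void) out_split;
    }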
diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index 75420f277d92c..25d08296075a8 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -604,12 +604,20 @@ ggml_tensor * llm_graph_context::build_ffn(
 
     switch (type_op) {
         case LLM_FFN_SILU:
-            {
+            if (gate && type_gate == LLM_FFN_PAR) {
+                cur = ggml_swiglu_split(ctx0, cur, tmp);
+                cb(cur, "ffn_swiglu", il);
+                type_gate = LLM_FFN_SEQ;
+            } else {
                 cur = ggml_silu(ctx0, cur);
                 cb(cur, "ffn_silu", il);
             } break;
         case LLM_FFN_GELU:
-            {
+            if (gate && type_gate == LLM_FFN_PAR) {
+                cur = ggml_geglu_split(ctx0, cur, tmp);
+                cb(cur, "ffn_geglu", il);
+                type_gate = LLM_FFN_SEQ;
+            } else {
                 cur = ggml_gelu(ctx0, cur);
                 cb(cur, "ffn_gelu", il);
                 if (act_scales != NULL) {
@@ -618,7 +626,11 @@ ggml_tensor * llm_graph_context::build_ffn(
                 }
             } break;
         case LLM_FFN_RELU:
-            {
+            if (gate && type_gate == LLM_FFN_PAR) {
+                cur = ggml_reglu_split(ctx0, cur, tmp);
+                cb(cur, "ffn_reglu", il);
+                type_gate = LLM_FFN_SEQ;
+            } else {
                 cur = ggml_relu(ctx0, cur);
                 cb(cur, "ffn_relu", il);
             } break;
@@ -774,12 +786,18 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
 
     switch (type_op) {
         case LLM_FFN_SILU:
-            {
+            if (gate_exps) {
+                cur = ggml_swiglu_split(ctx0, cur, up);
+                cb(cur, "ffn_moe_swiglu", il);
+            } else {
                 cur = ggml_silu(ctx0, cur);
                 cb(cur, "ffn_moe_silu", il);
             } break;
         case LLM_FFN_GELU:
-            {
+            if (gate_exps) {
+                cur = ggml_geglu_split(ctx0, cur, up);
+                cb(cur, "ffn_moe_geglu", il);
+            } else {
                 cur = ggml_gelu(ctx0, cur);
                 cb(cur, "ffn_moe_gelu", il);
             } break;
@@ -787,11 +805,6 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
             GGML_ABORT("fatal error");
     }
 
-    if (gate_exps) {
-        cur = ggml_mul(ctx0, cur, up); // [n_ff, n_expert_used, n_tokens]
-        cb(cur, "ffn_moe_gate_par", il);
-    }
-
     experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens]
     cb(experts, "ffn_moe_down", il);
 
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 0278df1a98d66..757924ac01d70 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -1151,6 +1151,60 @@ struct test_glu : public test_case {
     }
 };
 
+struct test_glu_split : public test_case {
+    const ggml_glu_op op;
+    const ggml_type type;
+    const std::array<int64_t, 4> ne_a;
+    int v; // view (1 : non-contiguous a)
+
+    std::string vars() override {
+        return VARS_TO_STR3(type, ne_a, v) + ",split";
+    }
+
+    test_glu_split(ggml_glu_op op,
+            ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne_a = {128, 2, 2, 2},
+            int v = 0)
+        : op(op), type(type), ne_a(ne_a), v(v) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a;
+        ggml_tensor * b;
+        if (v & 1) {
+            auto ne = ne_a; ne[0] *= 3;
+            a = ggml_new_tensor(ctx, type, 4, ne.data());
+            ggml_set_name(a, "a");
+
+            a = ggml_view_4d(ctx, a, ne_a[0], ne_a[1], ne_a[2], ne_a[3], a->nb[1], a->nb[2], a->nb[3], 0);
+            ggml_set_name(a, "view_of_a");
+
+            b = ggml_new_tensor(ctx, type, 4, ne.data());
+            ggml_set_name(b, "b");
+
+            b = ggml_view_4d(ctx, b, ne_a[0], ne_a[1], ne_a[2], ne_a[3], b->nb[1], b->nb[2], b->nb[3], 0);
+            ggml_set_name(b, "view_of_b");
+        } else {
+            a = ggml_new_tensor(ctx, type, 4, ne_a.data());
+            ggml_set_name(a, "a");
+
+            b = ggml_new_tensor(ctx, type, 4, ne_a.data());
+            ggml_set_name(b, "b");
+        }
+
+        ggml_tensor * out = ggml_glu_split(ctx, a, b, op);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            // test extended range of values to check for NaNs in GELU
+            init_tensor_uniform(t, -150.f, 150.f);
+        }
+    }
+};
+
 // GGML_OP_GET_ROWS
 struct test_get_rows : public test_case {
     const ggml_type type;
@@ -3986,6 +4040,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
                     test_cases.emplace_back(new test_glu((ggml_glu_op) op, type, { 128, 2, 2, 2 }, v, swapped));
                     test_cases.emplace_back(new test_glu((ggml_glu_op) op, type, { 5, 7, 11, 13 }, v, swapped));
                 }
+
+                test_cases.emplace_back(new test_glu_split((ggml_glu_op) op, type, { 128, 2, 2, 2 }, v));
+                test_cases.emplace_back(new test_glu_split((ggml_glu_op) op, type, { 5, 7, 11, 13 }, v));
             }
         }
     }
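Closing reference: the three gated activations this series wires up, evaluated at a single point with the exact formulas used by the shaders and vector ops above (x is the activated half, g the gate). Illustrative only:

    #include <cmath>
    #include <cstdio>

    int main() {
        const float x = 0.7f, g = 2.0f;

        const float reglu  = std::fmax(x, 0.0f) * g;
        const float swiglu = x / (1.0f + std::exp(-x)) * g;

        // geglu via the tanh form, equivalent to the shader's exp form
        const float v      = 0.79788456080286535587989211986876f * x * (1.0f + 0.044715f * x * x);
        const float geglu  = 0.5f * x * (1.0f + std::tanh(v)) * g;

        printf("reglu=%g swiglu=%g geglu=%g\n", reglu, swiglu, geglu);
        return 0;
    }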